Skip to main content

fdars_core/explain_generic/
shap.rs

1use crate::error::FdarError;
2use crate::explain::{
3    accumulate_kernel_shap_sample, compute_column_means, compute_mean_scalar, get_obs_scalar,
4    sample_random_coalition, shapley_kernel_weight, solve_kernel_shap_obs, FpcShapValues,
5};
6use crate::iter_maybe_parallel;
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9
10use super::FpcPredictor;
11
12/// Generic Kernel SHAP values for any FPC-based model.
13///
14/// For nonlinear models uses sampling-based Kernel SHAP; linear models get
15/// the same approximation (which converges to exact with enough samples).
16///
17/// # Errors
18///
19/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows or its
20/// column count does not match the model.
21/// Returns [`FdarError::InvalidParameter`] if `n_samples` is zero or the
22/// model has zero components.
23#[must_use = "expensive computation whose result should not be discarded"]
24pub fn generic_shap_values(
25    model: &dyn FpcPredictor,
26    data: &FdMatrix,
27    scalar_covariates: Option<&FdMatrix>,
28    n_samples: usize,
29    seed: u64,
30) -> Result<FpcShapValues, FdarError> {
31    #[cfg(feature = "parallel")]
32    use rayon::iter::ParallelIterator;
33
34    let (n, m) = data.shape();
35    if n == 0 {
36        return Err(FdarError::InvalidDimension {
37            parameter: "data",
38            expected: "n > 0".into(),
39            actual: "0 rows".into(),
40        });
41    }
42    if m != model.fpca_mean().len() {
43        return Err(FdarError::InvalidDimension {
44            parameter: "data columns",
45            expected: model.fpca_mean().len().to_string(),
46            actual: m.to_string(),
47        });
48    }
49    if n_samples == 0 {
50        return Err(FdarError::InvalidParameter {
51            parameter: "n_samples",
52            message: "n_samples must be > 0".into(),
53        });
54    }
55    let ncomp = model.ncomp();
56    if ncomp == 0 {
57        return Err(FdarError::InvalidParameter {
58            parameter: "ncomp",
59            message: "model has 0 components".into(),
60        });
61    }
62    let p_scalar = scalar_covariates.map_or(0, crate::matrix::FdMatrix::ncols);
63    let scores = model.project(data);
64    let mean_scores = compute_column_means(&scores, ncomp);
65    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
66
67    let base_value = model.predict_from_scores(
68        &mean_scores,
69        if mean_z.is_empty() {
70            None
71        } else {
72            Some(&mean_z)
73        },
74    );
75
76    let rows: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
77        .map(|i| {
78            let mut rng_i = StdRng::seed_from_u64(seed.wrapping_add(i as u64));
79            let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
80            let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
81
82            let mut ata = vec![0.0; ncomp * ncomp];
83            let mut atb = vec![0.0; ncomp];
84            // Pre-allocate coalition scores buffer outside the inner loop
85            let mut coal_scores = vec![0.0; ncomp];
86
87            // Pre-compute f_base once (it is constant across all coalitions)
88            let f_base = model.predict_from_scores(
89                &mean_scores,
90                if obs_z.is_empty() { None } else { Some(&obs_z) },
91            );
92
93            for _ in 0..n_samples {
94                let (coalition, s_size) = sample_random_coalition(&mut rng_i, ncomp);
95                let weight = shapley_kernel_weight(ncomp, s_size);
96                // Reuse pre-allocated buffer instead of allocating a new Vec each iteration
97                for (k, &in_coal) in coalition.iter().enumerate() {
98                    coal_scores[k] = if in_coal {
99                        obs_scores[k]
100                    } else {
101                        mean_scores[k]
102                    };
103                }
104
105                let f_coal = model.predict_from_scores(
106                    &coal_scores,
107                    if obs_z.is_empty() { None } else { Some(&obs_z) },
108                );
109                let y_val = f_coal - f_base;
110
111                accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
112            }
113
114            // Solve locally and return row
115            let mut local_values = FdMatrix::zeros(1, ncomp);
116            solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut local_values, 0);
117            (0..ncomp).map(|k| local_values[(0, k)]).collect()
118        })
119        .collect();
120
121    let mut values = FdMatrix::zeros(n, ncomp);
122    for (i, row) in rows.iter().enumerate() {
123        for (k, &v) in row.iter().enumerate() {
124            values[(i, k)] = v;
125        }
126    }
127
128    Ok(FpcShapValues {
129        values,
130        base_value,
131        mean_scores,
132    })
133}