// fdars_core/explain_generic/shap.rs
use crate::error::FdarError;
2use crate::explain::{
3 accumulate_kernel_shap_sample, compute_column_means, compute_mean_scalar, get_obs_scalar,
4 sample_random_coalition, shapley_kernel_weight, solve_kernel_shap_obs, FpcShapValues,
5};
6use crate::iter_maybe_parallel;
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9
10use super::FpcPredictor;
11
12#[must_use = "expensive computation whose result should not be discarded"]
24pub fn generic_shap_values(
25 model: &dyn FpcPredictor,
26 data: &FdMatrix,
27 scalar_covariates: Option<&FdMatrix>,
28 n_samples: usize,
29 seed: u64,
30) -> Result<FpcShapValues, FdarError> {
31 #[cfg(feature = "parallel")]
32 use rayon::iter::ParallelIterator;
33
34 let (n, m) = data.shape();
35 if n == 0 {
36 return Err(FdarError::InvalidDimension {
37 parameter: "data",
38 expected: "n > 0".into(),
39 actual: "0 rows".into(),
40 });
41 }
42 if m != model.fpca_mean().len() {
43 return Err(FdarError::InvalidDimension {
44 parameter: "data columns",
45 expected: model.fpca_mean().len().to_string(),
46 actual: m.to_string(),
47 });
48 }
49 if n_samples == 0 {
50 return Err(FdarError::InvalidParameter {
51 parameter: "n_samples",
52 message: "n_samples must be > 0".into(),
53 });
54 }
55 let ncomp = model.ncomp();
56 if ncomp == 0 {
57 return Err(FdarError::InvalidParameter {
58 parameter: "ncomp",
59 message: "model has 0 components".into(),
60 });
61 }
62 let p_scalar = scalar_covariates.map_or(0, crate::matrix::FdMatrix::ncols);
63 let scores = model.project(data);
64 let mean_scores = compute_column_means(&scores, ncomp);
65 let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
66
67 let base_value = model.predict_from_scores(
68 &mean_scores,
69 if mean_z.is_empty() {
70 None
71 } else {
72 Some(&mean_z)
73 },
74 );
75
76 let rows: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
77 .map(|i| {
78 let mut rng_i = StdRng::seed_from_u64(seed.wrapping_add(i as u64));
79 let obs_scores: Vec<f64> = (0..ncomp).map(|k| scores[(i, k)]).collect();
80 let obs_z = get_obs_scalar(scalar_covariates, i, p_scalar, &mean_z);
81
82 let mut ata = vec![0.0; ncomp * ncomp];
83 let mut atb = vec![0.0; ncomp];
84 let mut coal_scores = vec![0.0; ncomp];
86
87 let f_base = model.predict_from_scores(
89 &mean_scores,
90 if obs_z.is_empty() { None } else { Some(&obs_z) },
91 );
92
93 for _ in 0..n_samples {
94 let (coalition, s_size) = sample_random_coalition(&mut rng_i, ncomp);
95 let weight = shapley_kernel_weight(ncomp, s_size);
96 for (k, &in_coal) in coalition.iter().enumerate() {
98 coal_scores[k] = if in_coal {
99 obs_scores[k]
100 } else {
101 mean_scores[k]
102 };
103 }
104
105 let f_coal = model.predict_from_scores(
106 &coal_scores,
107 if obs_z.is_empty() { None } else { Some(&obs_z) },
108 );
109 let y_val = f_coal - f_base;
110
111 accumulate_kernel_shap_sample(&mut ata, &mut atb, &coalition, weight, y_val, ncomp);
112 }
113
114 let mut local_values = FdMatrix::zeros(1, ncomp);
116 solve_kernel_shap_obs(&mut ata, &atb, ncomp, &mut local_values, 0);
117 (0..ncomp).map(|k| local_values[(0, k)]).collect()
118 })
119 .collect();
120
121 let mut values = FdMatrix::zeros(n, ncomp);
122 for (i, row) in rows.iter().enumerate() {
123 for (k, &v) in row.iter().enumerate() {
124 values[(i, k)] = v;
125 }
126 }
127
128 Ok(FpcShapValues {
129 values,
130 base_value,
131 mean_scores,
132 })
133}