use crate::error::FdarError;
use crate::explain::{
accumulate_kernel_shap_sample, compute_column_means, compute_mean_scalar, get_obs_scalar,
sample_random_coalition, shapley_kernel_weight, solve_kernel_shap_obs, FpcShapValues,
};
use crate::iter_maybe_parallel;
use crate::matrix::FdMatrix;
use rand::prelude::*;
use super::FpcPredictor;
/// Estimates sampling-based Shapley attributions of a model's prediction to
/// each functional principal component score, one row per observation.
///
/// The rows of `data` are projected to scores with `model.project`. The
/// reference point is the vector of column means of those scores (plus the
/// mean scalar covariates, when any are supplied); `base_value` is the model's
/// prediction at that point.
///
/// For every observation, `n_samples` random coalitions of components are
/// drawn. For each coalition a hybrid score vector is built (the observation's
/// own score for components inside the coalition, the mean score otherwise),
/// the model is evaluated on it, and the difference from the all-means
/// prediction (evaluated with that observation's scalar covariates) is folded
/// into a weighted least-squares system via `accumulate_kernel_shap_sample`.
/// The system is then solved with `solve_kernel_shap_obs`.
///
/// Each observation uses its own RNG seeded with `seed.wrapping_add(i)`, so a
/// row's samples depend only on `seed` and the row index, not on the order in
/// which rows are processed.
///
/// # Arguments
///
/// * `model` - predictor exposing projection and prediction from scores.
/// * `data` - observations in rows; the column count must equal
///   `model.fpca_mean().len()`.
/// * `scalar_covariates` - optional extra covariates passed through to the
///   model. NOTE(review): its row count is not checked against `data` here;
///   confirm the helpers tolerate or reject a mismatch.
/// * `n_samples` - number of random coalitions per observation.
/// * `seed` - base seed for the per-observation RNGs.
///
/// # Errors
///
/// * [`FdarError::InvalidDimension`] if `data` has no rows or its column count
///   differs from the length of the model's FPCA mean.
/// * [`FdarError::InvalidParameter`] if `n_samples` is zero or the model
///   reports zero components.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn generic_shap_values(
    model: &dyn FpcPredictor,
    data: &FdMatrix,
    scalar_covariates: Option<&FdMatrix>,
    n_samples: usize,
    seed: u64,
) -> Result<FpcShapValues, FdarError> {
    #[cfg(feature = "parallel")]
    use rayon::iter::ParallelIterator;

    // ---- Input validation -------------------------------------------------
    let (n, m) = data.shape();
    if n == 0 {
        return Err(FdarError::InvalidDimension {
            parameter: "data",
            expected: "n > 0".into(),
            actual: "0 rows".into(),
        });
    }
    let expected_cols = model.fpca_mean().len();
    if m != expected_cols {
        return Err(FdarError::InvalidDimension {
            parameter: "data columns",
            expected: expected_cols.to_string(),
            actual: m.to_string(),
        });
    }
    if n_samples == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "n_samples",
            message: "n_samples must be > 0".into(),
        });
    }
    let ncomp = model.ncomp();
    if ncomp == 0 {
        return Err(FdarError::InvalidParameter {
            parameter: "ncomp",
            message: "model has 0 components".into(),
        });
    }

    // ---- Reference point: mean scores and mean scalar covariates ----------
    let p_scalar = scalar_covariates.map_or(0, FdMatrix::ncols);
    let scores = model.project(data);
    let mean_scores = compute_column_means(&scores, ncomp);
    let mean_z = compute_mean_scalar(scalar_covariates, p_scalar, n);
    let base_value = model.predict_from_scores(
        &mean_scores,
        if mean_z.is_empty() {
            None
        } else {
            Some(&mean_z)
        },
    );

    // ---- Per-observation attribution (optionally parallel) ----------------
    let rows: Vec<Vec<f64>> = iter_maybe_parallel!(0..n)
        .map(|obs| {
            let own_scores: Vec<f64> = (0..ncomp).map(|comp| scores[(obs, comp)]).collect();
            let own_z = get_obs_scalar(scalar_covariates, obs, p_scalar, &mean_z);

            // Prediction with every component at its mean but this
            // observation's scalar covariates; coalition outputs are measured
            // relative to it.
            let reference_pred = model.predict_from_scores(
                &mean_scores,
                if own_z.is_empty() { None } else { Some(&own_z) },
            );

            // Accumulators for the weighted normal equations.
            let mut gram = vec![0.0; ncomp * ncomp];
            let mut rhs = vec![0.0; ncomp];
            // Scratch buffer reused for every coalition's hybrid score vector.
            let mut hybrid = vec![0.0; ncomp];

            let mut rng = StdRng::seed_from_u64(seed.wrapping_add(obs as u64));
            for _ in 0..n_samples {
                let (coalition, coalition_size) = sample_random_coalition(&mut rng, ncomp);
                let weight = shapley_kernel_weight(ncomp, coalition_size);

                for (comp, included) in coalition.iter().enumerate() {
                    if *included {
                        hybrid[comp] = own_scores[comp];
                    } else {
                        hybrid[comp] = mean_scores[comp];
                    }
                }

                let coalition_pred = model.predict_from_scores(
                    &hybrid,
                    if own_z.is_empty() { None } else { Some(&own_z) },
                );
                accumulate_kernel_shap_sample(
                    &mut gram,
                    &mut rhs,
                    &coalition,
                    weight,
                    coalition_pred - reference_pred,
                    ncomp,
                );
            }

            // The solver writes into a matrix row, so solve into a 1-row
            // matrix and copy the result out as a plain vector.
            let mut solved = FdMatrix::zeros(1, ncomp);
            solve_kernel_shap_obs(&mut gram, &rhs, ncomp, &mut solved, 0);
            let attributions: Vec<f64> = (0..ncomp).map(|comp| solved[(0, comp)]).collect();
            attributions
        })
        .collect();

    // ---- Assemble the n x ncomp result matrix ------------------------------
    let mut values = FdMatrix::zeros(n, ncomp);
    for (obs, attributions) in rows.into_iter().enumerate() {
        for (comp, phi) in attributions.into_iter().enumerate() {
            values[(obs, comp)] = phi;
        }
    }

    Ok(FpcShapValues {
        values,
        base_value,
        mean_scores,
    })
}