impl ShapExplainer {
    /// Creates an explainer backed by `background` reference rows.
    ///
    /// The feature count is taken from the first background row (zero when
    /// `background` is empty) and default names `feature_0`, `feature_1`, ...
    /// are generated for it. The sample count defaults to 100.
    #[must_use]
    pub fn new(background: Vec<Vec<f32>>) -> Self {
        let n_features = background.first().map_or(0, Vec::len);
        Self {
            background,
            nsamples: 100,
            feature_names: (0..n_features).map(|i| format!("feature_{i}")).collect(),
        }
    }

    /// Sets how many coalition samples the kernel estimator draws.
    #[must_use]
    pub fn with_nsamples(mut self, nsamples: usize) -> Self {
        self.nsamples = nsamples;
        self
    }

    /// Replaces the generated feature names with `names`.
    ///
    /// The length of `names` is not checked against the feature count.
    #[must_use]
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = names;
        self
    }

    /// Explains the prediction of `model` for `instance`.
    ///
    /// Uses the tree-based path when the model reports itself as a tree model
    /// and exposes a tree structure; otherwise falls back to the kernel
    /// estimator.
    ///
    /// # Errors
    ///
    /// Returns [`ExplainError::InvalidInput`] when `instance.len()` differs
    /// from `model.n_features()`, and propagates any error produced by the
    /// selected estimator.
    pub fn explain(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
    ) -> Result<ShapExplanation, ExplainError> {
        let expected = model.n_features();
        if instance.len() != expected {
            return Err(ExplainError::InvalidInput {
                expected,
                actual: instance.len(),
            });
        }
        if model.is_tree_model() {
            if let Some(tree_structure) = model.get_tree_structure() {
                return self.tree_shap(tree_structure, instance, model);
            }
        }
        self.kernel_shap(model, instance)
    }

    /// Accumulates per-tree attributions over every tree in `tree_structure`,
    /// dividing each contribution by `tree_structure.n_trees`.
    ///
    /// # Errors
    ///
    /// Returns [`ExplainError::InvalidInput`] when `instance` has more entries
    /// than `tree_structure.n_features` (this used to panic on an
    /// out-of-bounds index). Errors from the per-tree computation, the
    /// expected-value computation and `model.predict` are propagated.
    fn tree_shap(
        &self,
        tree_structure: &TreeStructure,
        instance: &[f32],
        model: &dyn Explainable,
    ) -> Result<ShapExplanation, ExplainError> {
        let n_features = tree_structure.n_features;
        // The accumulator is sized by the tree structure while the per-tree
        // vectors are sized by the instance; a longer instance cannot be
        // attributed. Shorter instances keep the previous behaviour and leave
        // the trailing entries at zero.
        if instance.len() > n_features {
            return Err(ExplainError::InvalidInput {
                expected: n_features,
                actual: instance.len(),
            });
        }

        // NOTE(review): assumes `n_trees` is non-zero whenever `trees` is
        // non-empty - confirm against the code that builds `TreeStructure`.
        let n_trees = tree_structure.n_trees as f32;
        let mut shap_values = vec![0.0_f32; n_features];
        for tree in &tree_structure.trees {
            let per_tree = self.tree_shap_single(tree, instance)?;
            for (total, contribution) in shap_values.iter_mut().zip(&per_tree) {
                *total += contribution / n_trees;
            }
        }

        let base_value = self.compute_expected_value(model)?;
        let prediction = model.predict(instance)?;
        Ok(ShapExplanation::new(base_value, shap_values, prediction)
            .with_feature_names(self.feature_names.clone()))
    }

    /// Computes one attribution per feature for a single tree.
    ///
    /// Each attribution is the tree's prediction for `instance` minus its
    /// prediction when that one feature is replaced by the background mean.
    /// The mean divides by the number of background rows (at least 1), so an
    /// empty background yields a replacement value of zero.
    fn tree_shap_single(
        &self,
        tree: &DecisionTree,
        instance: &[f32],
    ) -> Result<Vec<f32>, ExplainError> {
        // Nothing to attribute; also avoids querying the tree with no features.
        if instance.is_empty() {
            return Ok(Vec::new());
        }

        // The unperturbed prediction does not depend on the feature index, so
        // it is evaluated once instead of once per feature.
        let pred_with = tree.predict(instance);
        let n_background = self.background.len().max(1) as f32;

        // One scratch copy, with the perturbed slot restored after each use,
        // replaces a fresh clone of the instance per feature.
        let mut perturbed = instance.to_vec();
        let mut shap_values = Vec::with_capacity(instance.len());
        for (feature_idx, &original) in instance.iter().enumerate() {
            let background_mean = self
                .background
                .iter()
                .filter_map(|bg| bg.get(feature_idx).copied())
                .sum::<f32>()
                / n_background;
            perturbed[feature_idx] = background_mean;
            let pred_without = tree.predict(&perturbed);
            shap_values.push(pred_with - pred_without);
            perturbed[feature_idx] = original;
        }
        Ok(shap_values)
    }

    /// Model-agnostic estimator driven by sampled feature coalitions.
    ///
    /// For each of `nsamples` iterations a coalition is drawn; empty and full
    /// coalitions are skipped. The background-averaged prediction for the
    /// coalition, multiplied by the kernel weight for its size, is added to
    /// every feature inside the coalition. The sums are finally divided by the
    /// total kernel weight over sizes `1..n_features` when that total is
    /// positive.
    ///
    /// NOTE(review): `sample_coalition` is deterministic, so every iteration
    /// sees the same coalition and the accumulated values grow linearly with
    /// `nsamples`; the final division does not account for the sample count.
    /// Left unchanged here because fixing it alters numeric output.
    ///
    /// # Errors
    ///
    /// Returns [`ExplainError::NoBackground`] when the background set is empty
    /// and propagates any error from `model.predict`.
    fn kernel_shap(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
    ) -> Result<ShapExplanation, ExplainError> {
        if self.background.is_empty() {
            return Err(ExplainError::NoBackground);
        }

        let n_features = instance.len();
        let mut shap_values = vec![0.0_f32; n_features];
        for _ in 0..self.nsamples {
            let coalition = self.sample_coalition(n_features);
            let coalition_size = coalition.iter().filter(|&&b| b).count();
            if coalition_size == 0 || coalition_size == n_features {
                continue;
            }
            let marginal = self.compute_marginal(model, instance, &coalition)?;
            let weight = self.shap_kernel_weight(n_features, coalition_size);
            let contribution = marginal * weight;
            for (value, &in_coalition) in shap_values.iter_mut().zip(&coalition) {
                if in_coalition {
                    *value += contribution;
                }
            }
        }

        let total_weight: f32 = (1..n_features)
            .map(|k| self.shap_kernel_weight(n_features, k))
            .sum();
        if total_weight > 0.0 {
            for v in &mut shap_values {
                *v /= total_weight;
            }
        }

        let base_value = self.compute_expected_value(model)?;
        let prediction = model.predict(instance)?;
        Ok(ShapExplanation::new(base_value, shap_values, prediction)
            .with_feature_names(self.feature_names.clone()))
    }

    /// Returns a membership mask of length `n_features`.
    ///
    /// The mask is deterministic: even indices are in the coalition, odd
    /// indices are not.
    fn sample_coalition(&self, n_features: usize) -> Vec<bool> {
        (0..n_features).map(|i| i % 2 == 0).collect()
    }

    /// Averages the model's prediction over the background rows, taking
    /// coalition features from `instance` and the remaining features from the
    /// background row (0.0 when the row is too short).
    ///
    /// # Errors
    ///
    /// Propagates any error from `model.predict`.
    fn compute_marginal(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
        coalition: &[bool],
    ) -> Result<f32, ExplainError> {
        let n_background = self.background.len().max(1);
        let mut total_pred = 0.0_f32;
        // Reused across background rows to avoid one allocation per row.
        let mut masked: Vec<f32> = Vec::with_capacity(instance.len());
        for bg in &self.background {
            masked.clear();
            masked.extend(instance.iter().zip(coalition).enumerate().map(
                |(i, (&inst_val, &in_coalition))| {
                    if in_coalition {
                        inst_val
                    } else {
                        bg.get(i).copied().unwrap_or(0.0)
                    }
                },
            ));
            total_pred += model.predict(&masked)?;
        }
        Ok(total_pred / n_background as f32)
    }

    /// Mean model prediction over the background rows, or 0.0 when there is no
    /// background.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `model.predict`.
    fn compute_expected_value(&self, model: &dyn Explainable) -> Result<f32, ExplainError> {
        if self.background.is_empty() {
            return Ok(0.0);
        }
        // Summing the `Result`s directly short-circuits on the first error
        // without materialising an intermediate vector.
        let total = self
            .background
            .iter()
            .map(|x| model.predict(x))
            .sum::<Result<f32, _>>()?;
        Ok(total / self.background.len() as f32)
    }

    /// Kernel weight `m / (C(m, s) * s * (m - s))` for a coalition of size `s`
    /// out of `m` features; 0.0 when the denominator is zero.
    fn shap_kernel_weight(&self, n_features: usize, coalition_size: usize) -> f32 {
        let m = n_features as f32;
        let s = coalition_size as f32;
        let denominator = binomial(n_features, coalition_size) as f32 * s * (m - s);
        if denominator == 0.0 {
            0.0
        } else {
            m / denominator
        }
    }
}
/// Computes the binomial coefficient `C(n, k)`.
///
/// Returns 0 when `k > n`. When the exact value does not fit in `usize` the
/// result saturates at `usize::MAX`.
fn binomial(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // C(n, k) == C(n, n - k); iterating over the smaller side keeps the loop
    // short and makes the running value non-decreasing.
    let k = k.min(n - k);

    // The running value is kept in u128 so that `acc * (n - i)` cannot wrap:
    // `acc` is at most `usize::MAX` on entry to each step and the factor is at
    // most `usize::MAX`. The previous version clamped the product and then
    // divided the clamped value, yielding wrong results even when the final
    // coefficient fits.
    let mut acc: u128 = 1;
    for i in 0..k {
        // Widening casts only; the division is exact because
        // C(n, i) * (n - i) is always divisible by (i + 1).
        acc = acc * (n - i) as u128 / (i + 1) as u128;
        if acc > usize::MAX as u128 {
            // Values only grow while i < n / 2, so the final result overflows too.
            return usize::MAX;
        }
    }
    // Guarded above: `acc` fits in usize here.
    acc as usize
}