#![allow(clippy::must_use_candidate)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::unused_self)] #![allow(clippy::unnecessary_wraps)] #![allow(clippy::option_if_let_else)]
use serde::{Deserialize, Serialize};
use std::fmt;
use thiserror::Error;
/// Errors produced while computing model explanations.
#[derive(Debug, Error)]
pub enum ExplainError {
/// The model cannot be explained by the available algorithms.
#[error("Model does not support explainability: {reason}")]
UnsupportedModel {
reason: String,
},
/// The instance length does not match the model's declared feature count.
#[error("Invalid input: expected {expected} features, got {actual}")]
InvalidInput {
expected: usize,
actual: usize,
},
/// KernelSHAP was requested but no background samples were provided.
#[error("Background dataset required for KernelSHAP")]
NoBackground,
/// Catch-all for numeric/algorithmic failures during explanation.
#[error("Computation error: {0}")]
ComputationError(String),
}
/// A model that can be explained: exposes scalar predictions and the
/// expected input width, and may optionally expose a tree structure so the
/// TreeSHAP path can be used instead of the slower KernelSHAP fallback.
pub trait Explainable {
/// Predicts a single scalar output for one instance.
fn predict(&self, instance: &[f32]) -> Result<f32, ExplainError>;
/// Predicts each instance in turn; short-circuits on the first error.
fn predict_batch(&self, instances: &[Vec<f32>]) -> Result<Vec<f32>, ExplainError> {
instances.iter().map(|x| self.predict(x)).collect()
}
/// Number of features the model expects per instance.
fn n_features(&self) -> usize;
/// Whether this model is tree-based (enables the TreeSHAP path).
fn is_tree_model(&self) -> bool {
false
}
/// Tree structure for tree models; `None` for everything else.
fn get_tree_structure(&self) -> Option<&TreeStructure> {
None
}
}
/// Serializable description of a tree ensemble, consumed by the TreeSHAP path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeStructure {
/// Number of trees in the ensemble (attributions are divided by this).
pub n_trees: usize,
/// Number of input features the ensemble expects.
pub n_features: usize,
/// The individual trees of the ensemble.
pub trees: Vec<DecisionTree>,
}
/// A binary decision tree in flat struct-of-arrays form: index `i` of each
/// vector describes node `i`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecisionTree {
/// Split feature index per node; a negative value marks a leaf.
pub feature: Vec<i32>,
/// Split threshold per node (unused at leaves).
pub threshold: Vec<f32>,
/// Left-child index per node (taken when value <= threshold).
pub left: Vec<usize>,
/// Right-child index per node.
pub right: Vec<usize>,
/// Output value per node (read when a leaf is reached).
pub value: Vec<f32>,
}
impl DecisionTree {
    /// Builds a tree from parallel node arrays (split feature, threshold,
    /// child indices, and node values). All vectors must be node-indexed.
    pub fn new(
        feature: Vec<i32>,
        threshold: Vec<f32>,
        left: Vec<usize>,
        right: Vec<usize>,
        value: Vec<f32>,
    ) -> Self {
        Self { feature, threshold, left, right, value }
    }

    /// Total number of nodes in the tree.
    pub fn n_nodes(&self) -> usize {
        self.feature.len()
    }

    /// A node is a leaf when its split feature is negative. Out-of-range
    /// node indices are treated as leaves so traversal cannot index past
    /// the arrays.
    pub fn is_leaf(&self, node: usize) -> bool {
        match self.feature.get(node) {
            Some(&split_feature) => split_feature < 0,
            None => true,
        }
    }

    /// Walks from the root to a leaf and returns the leaf's value.
    ///
    /// At each internal node the walk goes left when the instance value is
    /// <= the node threshold; a missing feature value sends it right. An
    /// out-of-range terminal node yields 0.0.
    pub fn predict(&self, instance: &[f32]) -> f32 {
        let mut node = 0;
        loop {
            if self.is_leaf(node) {
                return self.value.get(node).copied().unwrap_or(0.0);
            }
            let feature_idx = self.feature[node] as usize;
            let goes_left =
                matches!(instance.get(feature_idx), Some(&v) if v <= self.threshold[node]);
            node = if goes_left { self.left[node] } else { self.right[node] };
        }
    }
}
/// The result of explaining one prediction: a base value plus one additive
/// contribution per feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapExplanation {
/// Expected model output (computed from the background data).
pub base_value: f32,
/// Per-feature additive contributions.
pub shap_values: Vec<f32>,
/// Display name for each feature (parallel to `shap_values`).
pub feature_names: Vec<String>,
/// The model's actual output for the explained instance.
pub prediction: f32,
}
impl ShapExplanation {
    /// Creates an explanation with auto-generated feature names
    /// (`feature_0`, `feature_1`, ...), one per SHAP value.
    pub fn new(base_value: f32, shap_values: Vec<f32>, prediction: f32) -> Self {
        let feature_names = (0..shap_values.len())
            .map(|i| format!("feature_{i}"))
            .collect();
        Self {
            base_value,
            shap_values,
            feature_names,
            prediction,
        }
    }

    /// Replaces the auto-generated feature names (builder-style).
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = names;
        self
    }

    /// Returns up to `n` `(name, value)` pairs ordered by descending
    /// absolute SHAP value. Indices without a known name fall back to
    /// the generated `feature_{i}` form.
    pub fn top_features(&self, n: usize) -> Vec<(String, f32)> {
        let mut ranked: Vec<(usize, f32)> =
            self.shap_values.iter().copied().enumerate().collect();
        // NaN-safe ordering: incomparable pairs are treated as equal.
        ranked.sort_by(|a, b| {
            b.1.abs()
                .partial_cmp(&a.1.abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        ranked.truncate(n);
        ranked
            .into_iter()
            .map(|(idx, contribution)| {
                let label = match self.feature_names.get(idx) {
                    Some(name) => name.clone(),
                    None => format!("feature_{idx}"),
                };
                (label, contribution)
            })
            .collect()
    }

    /// Additivity check: `base_value + sum(shap_values)` should match the
    /// prediction to within `tolerance`.
    pub fn verify_consistency(&self, tolerance: f32) -> bool {
        let total: f32 = self.shap_values.iter().sum();
        (self.base_value + total - self.prediction).abs() < tolerance
    }
}
impl fmt::Display for ShapExplanation {
    /// Renders a multi-line human-readable summary: base value, model
    /// prediction, and up to the five strongest feature contributions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "SHAP Explanation:")?;
        writeln!(f, " Base value: {:.4}", self.base_value)?;
        writeln!(f, " Prediction: {:.4}", self.prediction)?;
        writeln!(f, " Top features:")?;
        self.top_features(5)
            .into_iter()
            .try_for_each(|(name, value)| {
                // Non-negative contributions get an explicit '+' prefix.
                let sign = if value >= 0.0 { "+" } else { "" };
                writeln!(f, " {name}: {sign}{value:.4}")
            })
    }
}
/// SHAP explainer holding the background dataset, the KernelSHAP sampling
/// budget, and the feature names attached to produced explanations.
pub struct ShapExplainer {
// Reference dataset used for expected values and feature masking.
background: Vec<Vec<f32>>,
// Number of coalitions sampled by KernelSHAP.
nsamples: usize,
// Names copied onto every produced explanation.
feature_names: Vec<String>,
}
impl ShapExplainer {
    /// Creates an explainer from a background dataset. The background is
    /// used both for the expected (base) value and to fill in "absent"
    /// features when masking. Feature names default to `feature_{i}`,
    /// sized from the first background row (0 features if empty).
    pub fn new(background: Vec<Vec<f32>>) -> Self {
        let n_features = background.first().map_or(0, Vec::len);
        Self {
            background,
            nsamples: 100,
            feature_names: (0..n_features).map(|i| format!("feature_{i}")).collect(),
        }
    }

    /// Sets the number of coalitions sampled by KernelSHAP (builder-style).
    pub fn with_nsamples(mut self, nsamples: usize) -> Self {
        self.nsamples = nsamples;
        self
    }

    /// Overrides the default feature names (builder-style).
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = names;
        self
    }

    /// Explains one prediction.
    ///
    /// Validates the instance width, then dispatches to the TreeSHAP path
    /// when the model is tree-based and exposes its structure; otherwise
    /// falls back to KernelSHAP.
    ///
    /// # Errors
    /// `InvalidInput` on a width mismatch; `NoBackground` when KernelSHAP
    /// is required but no background data was provided.
    pub fn explain(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
    ) -> Result<ShapExplanation, ExplainError> {
        if instance.len() != model.n_features() {
            return Err(ExplainError::InvalidInput {
                expected: model.n_features(),
                actual: instance.len(),
            });
        }
        if model.is_tree_model() {
            if let Some(tree_structure) = model.get_tree_structure() {
                return self.tree_shap(tree_structure, instance, model);
            }
        }
        self.kernel_shap(model, instance)
    }

    /// Approximate TreeSHAP: per-tree attributions averaged over the
    /// ensemble, matching an ensemble prediction of (sum of trees) / n_trees.
    fn tree_shap(
        &self,
        tree_structure: &TreeStructure,
        instance: &[f32],
        model: &dyn Explainable,
    ) -> Result<ShapExplanation, ExplainError> {
        let n_features = tree_structure.n_features;
        let mut shap_values = vec![0.0; n_features];
        // Guard against a malformed structure claiming zero trees, which
        // would otherwise divide by zero below.
        let n_trees = tree_structure.n_trees.max(1) as f32;
        for tree in &tree_structure.trees {
            let tree_attribution = self.tree_shap_single(tree, instance)?;
            // zip() tolerates a mismatch between the declared feature count
            // and the instance length instead of panicking on an index.
            for (acc, v) in shap_values.iter_mut().zip(tree_attribution) {
                *acc += v / n_trees;
            }
        }
        let base_value = self.compute_expected_value(model)?;
        let prediction = model.predict(instance)?;
        Ok(ShapExplanation::new(base_value, shap_values, prediction)
            .with_feature_names(self.feature_names.clone()))
    }

    /// Single-tree attribution: for each feature, the change in the tree's
    /// output when that feature is replaced by its background mean.
    fn tree_shap_single(
        &self,
        tree: &DecisionTree,
        instance: &[f32],
    ) -> Result<Vec<f32>, ExplainError> {
        let n_features = instance.len();
        // Loop-invariant hoisted: the unperturbed prediction is identical
        // for every feature, so compute it once rather than per feature.
        let pred_with = tree.predict(instance);
        let mut shap_values = vec![0.0; n_features];
        for feature_idx in 0..n_features {
            let background_mean = self
                .background
                .iter()
                .filter_map(|bg| bg.get(feature_idx).copied())
                .sum::<f32>()
                / self.background.len().max(1) as f32;
            let mut instance_without = instance.to_vec();
            instance_without[feature_idx] = background_mean;
            shap_values[feature_idx] = pred_with - tree.predict(&instance_without);
        }
        Ok(shap_values)
    }

    /// KernelSHAP approximation: samples feature coalitions, evaluates the
    /// model with absent features replaced by background values, and
    /// accumulates kernel-weighted marginals onto the present features.
    fn kernel_shap(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
    ) -> Result<ShapExplanation, ExplainError> {
        if self.background.is_empty() {
            return Err(ExplainError::NoBackground);
        }
        let n_features = instance.len();
        let mut shap_values = vec![0.0; n_features];
        for sample_idx in 0..self.nsamples {
            // Seed each draw with the sample index so successive samples
            // explore different coalitions (deterministic, reproducible).
            let coalition = self.sample_coalition(n_features, sample_idx);
            let coalition_size = coalition.iter().filter(|&&b| b).count();
            // Empty and full coalitions have infinite kernel weight and
            // carry no marginal information; skip them.
            if coalition_size == 0 || coalition_size == n_features {
                continue;
            }
            let marginal = self.compute_marginal(model, instance, &coalition)?;
            let weight = self.shap_kernel_weight(n_features, coalition_size);
            for (value, &in_coalition) in shap_values.iter_mut().zip(&coalition) {
                if in_coalition {
                    *value += marginal * weight;
                }
            }
        }
        // Normalize by the total kernel mass over all proper coalition sizes.
        let total_weight: f32 = (1..n_features)
            .map(|k| self.shap_kernel_weight(n_features, k))
            .sum();
        if total_weight > 0.0 {
            for v in &mut shap_values {
                *v /= total_weight;
            }
        }
        let base_value = self.compute_expected_value(model)?;
        let prediction = model.predict(instance)?;
        Ok(ShapExplanation::new(base_value, shap_values, prediction)
            .with_feature_names(self.feature_names.clone()))
    }

    /// Draws a pseudo-random coalition (feature subset) for `sample_idx`.
    ///
    /// BUG FIX: the previous version ignored the sample index and always
    /// returned the fixed pattern "even-indexed features", so every one of
    /// the `nsamples` iterations evaluated the same coalition and
    /// odd-indexed features could never receive any attribution. A small
    /// LCG keyed on the sample index keeps the sampler deterministic (and
    /// therefore reproducible) while actually varying coalitions.
    fn sample_coalition(&self, n_features: usize, sample_idx: usize) -> Vec<bool> {
        // Knuth's 64-bit MMIX linear congruential generator constants.
        const MUL: u64 = 6364136223846793005;
        const INC: u64 = 1442695040888963407;
        let mut state = (sample_idx as u64).wrapping_mul(MUL).wrapping_add(INC);
        (0..n_features)
            .map(|_| {
                state = state.wrapping_mul(MUL).wrapping_add(INC);
                // Use a high bit: the low bits of an LCG have short periods.
                (state >> 33) & 1 == 1
            })
            .collect()
    }

    /// Expected model output over the background dataset, averaged across
    /// all background rows (one per-row coalition mix per row).
    fn compute_marginal(
        &self,
        model: &dyn Explainable,
        instance: &[f32],
        coalition: &[bool],
    ) -> Result<f32, ExplainError> {
        let mut total_pred = 0.0;
        let n_background = self.background.len().max(1);
        for bg in &self.background {
            // Present features keep the instance value; absent features are
            // filled from the background row (0.0 if the row is short).
            let mut masked: Vec<f32> = Vec::with_capacity(instance.len());
            for (i, (&inst_val, &in_coalition)) in instance.iter().zip(coalition.iter()).enumerate()
            {
                if in_coalition {
                    masked.push(inst_val);
                } else {
                    masked.push(bg.get(i).copied().unwrap_or(0.0));
                }
            }
            total_pred += model.predict(&masked)?;
        }
        Ok(total_pred / n_background as f32)
    }

    /// Mean model prediction over the background dataset (the SHAP base
    /// value); 0.0 when there is no background data.
    fn compute_expected_value(&self, model: &dyn Explainable) -> Result<f32, ExplainError> {
        if self.background.is_empty() {
            return Ok(0.0);
        }
        let predictions: Result<Vec<f32>, _> =
            self.background.iter().map(|x| model.predict(x)).collect();
        let predictions = predictions?;
        Ok(predictions.iter().sum::<f32>() / predictions.len() as f32)
    }

    /// Shapley kernel weight for a coalition of size `coalition_size` out
    /// of `n_features`. NOTE: the canonical kernel uses (M - 1) in the
    /// numerator; the constant factor is irrelevant here because values are
    /// normalized by the summed weights in `kernel_shap`.
    fn shap_kernel_weight(&self, n_features: usize, coalition_size: usize) -> f32 {
        let m = n_features as f32;
        let s = coalition_size as f32;
        let binom = binomial(n_features, coalition_size) as f32;
        // Degenerate denominators (empty/full coalition, binom overflow to
        // 0 is impossible, but s or m-s can be 0) get zero weight.
        if binom * s * (m - s) == 0.0 {
            0.0
        } else {
            m / (binom * s * (m - s))
        }
    }
}
/// Binomial coefficient C(n, k), 0 when k > n.
///
/// Uses the symmetric form C(n, k) = C(n, n - k) to shorten the product,
/// and multiplies/divides incrementally so every intermediate division is
/// exact. Multiplication saturates instead of overflowing.
fn binomial(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    let k = k.min(n - k);
    (0..k).fold(1usize, |acc, i| acc.saturating_mul(n - i) / (i + 1))
}
// Unit tests: two mock models (linear and single-tree) plus coverage of the
// explanation utilities, tree prediction, explainer dispatch, error
// formatting, the binomial helper, and serde round-trips.
#[cfg(test)]
mod tests {
use super::*;
// Minimal non-tree model computing w·x + b; exercises the KernelSHAP path.
struct LinearModel {
weights: Vec<f32>,
bias: f32,
}
impl LinearModel {
fn new(weights: Vec<f32>, bias: f32) -> Self {
Self { weights, bias }
}
}
impl Explainable for LinearModel {
fn predict(&self, instance: &[f32]) -> Result<f32, ExplainError> {
if instance.len() != self.weights.len() {
return Err(ExplainError::InvalidInput {
expected: self.weights.len(),
actual: instance.len(),
});
}
let dot: f32 = instance.iter().zip(&self.weights).map(|(x, w)| x * w).sum();
Ok(dot + self.bias)
}
fn n_features(&self) -> usize {
self.weights.len()
}
}
// Single-tree model exposing its structure; exercises the TreeSHAP path.
struct SimpleTreeModel {
structure: TreeStructure,
}
impl SimpleTreeModel {
fn new(tree: DecisionTree) -> Self {
// Feature count is inferred from the largest split-feature index used.
let n_features = tree
.feature
.iter()
.filter(|&&f| f >= 0)
.map(|&f| f as usize + 1)
.max()
.unwrap_or(1);
Self {
structure: TreeStructure {
n_trees: 1,
n_features,
trees: vec![tree],
},
}
}
}
impl Explainable for SimpleTreeModel {
fn predict(&self, instance: &[f32]) -> Result<f32, ExplainError> {
// Ensemble output is the mean of the individual tree predictions.
let sum: f32 = self
.structure
.trees
.iter()
.map(|t| t.predict(instance))
.sum();
Ok(sum / self.structure.n_trees as f32)
}
fn n_features(&self) -> usize {
self.structure.n_features
}
fn is_tree_model(&self) -> bool {
true
}
fn get_tree_structure(&self) -> Option<&TreeStructure> {
Some(&self.structure)
}
}
// Default feature names are generated to match the SHAP vector length.
#[test]
fn test_shap_explanation_new() {
let exp = ShapExplanation::new(0.5, vec![0.1, -0.2, 0.3], 0.7);
assert_eq!(exp.base_value, 0.5);
assert_eq!(exp.shap_values.len(), 3);
assert_eq!(exp.prediction, 0.7);
assert_eq!(exp.feature_names.len(), 3);
}
#[test]
fn test_shap_explanation_with_feature_names() {
let exp = ShapExplanation::new(0.5, vec![0.1, -0.2], 0.4)
.with_feature_names(vec!["age".to_string(), "income".to_string()]);
assert_eq!(exp.feature_names, vec!["age", "income"]);
}
// Top features must come back ordered by descending absolute SHAP value.
#[test]
fn test_shap_explanation_top_features() {
let exp = ShapExplanation::new(0.0, vec![0.1, -0.3, 0.2], 0.0).with_feature_names(vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
]);
let top = exp.top_features(2);
assert_eq!(top.len(), 2);
assert_eq!(top[0].0, "b"); assert_eq!(top[1].0, "c"); }
// Additivity check: base + sum(shap) must be close to the prediction.
#[test]
fn test_shap_explanation_verify_consistency() {
let exp = ShapExplanation::new(0.5, vec![0.2, 0.3], 1.0);
assert!(exp.verify_consistency(0.01));
let exp_bad = ShapExplanation::new(0.5, vec![0.2, 0.3], 2.0);
assert!(!exp_bad.verify_consistency(0.01));
}
#[test]
fn test_shap_explanation_display() {
let exp = ShapExplanation::new(0.5, vec![0.1, -0.2], 0.4);
let display = format!("{exp}");
assert!(display.contains("SHAP Explanation"));
assert!(display.contains("Base value"));
assert!(display.contains("Prediction"));
}
// Single split on feature 0 at 0.5: left leaf yields 1.0, right leaf 2.0.
#[test]
fn test_decision_tree_predict_simple() {
let tree = DecisionTree::new(
vec![0, -1, -1], vec![0.5, 0.0, 0.0], vec![1, 0, 0], vec![2, 0, 0], vec![0.0, 1.0, 2.0], );
assert_eq!(tree.predict(&[0.3]), 1.0); assert_eq!(tree.predict(&[0.7]), 2.0); }
#[test]
fn test_decision_tree_n_nodes() {
let tree = DecisionTree::new(
vec![0, -1, -1],
vec![0.5, 0.0, 0.0],
vec![1, 0, 0],
vec![2, 0, 0],
vec![0.0, 1.0, 2.0],
);
assert_eq!(tree.n_nodes(), 3);
}
#[test]
fn test_decision_tree_is_leaf() {
let tree = DecisionTree::new(
vec![0, -1, -1],
vec![0.5, 0.0, 0.0],
vec![1, 0, 0],
vec![2, 0, 0],
vec![0.0, 1.0, 2.0],
);
assert!(!tree.is_leaf(0)); assert!(tree.is_leaf(1)); assert!(tree.is_leaf(2)); }
// Builder defaults and overrides for the explainer.
#[test]
fn test_shap_explainer_new() {
let background = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let explainer = ShapExplainer::new(background);
assert_eq!(explainer.nsamples, 100);
assert_eq!(explainer.feature_names.len(), 2);
}
#[test]
fn test_shap_explainer_with_nsamples() {
let explainer = ShapExplainer::new(vec![vec![1.0]]).with_nsamples(50);
assert_eq!(explainer.nsamples, 50);
}
#[test]
fn test_shap_explainer_with_feature_names() {
let explainer = ShapExplainer::new(vec![vec![1.0, 2.0]])
.with_feature_names(vec!["age".to_string(), "income".to_string()]);
assert_eq!(explainer.feature_names, vec!["age", "income"]);
}
// End-to-end KernelSHAP on a linear model: 1*2 + 2*3 = 8.
#[test]
fn test_shap_explainer_linear_model() {
let model = LinearModel::new(vec![1.0, 2.0], 0.0);
let background = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
let explainer = ShapExplainer::new(background);
let instance = vec![2.0, 3.0]; let explanation = explainer.explain(&model, &instance).unwrap();
assert_eq!(explanation.prediction, 8.0);
assert_eq!(explanation.shap_values.len(), 2);
}
// End-to-end TreeSHAP: 0.3 <= 0.5 goes left, so the prediction is 1.0.
#[test]
fn test_shap_explainer_tree_model() {
let tree = DecisionTree::new(
vec![0, -1, -1],
vec![0.5, 0.0, 0.0],
vec![1, 0, 0],
vec![2, 0, 0],
vec![0.0, 1.0, 2.0],
);
let model = SimpleTreeModel::new(tree);
let background = vec![vec![0.3], vec![0.7]];
let explainer = ShapExplainer::new(background);
let explanation = explainer.explain(&model, &[0.3]).unwrap();
assert_eq!(explanation.prediction, 1.0);
}
// Width mismatch is rejected before any SHAP computation starts.
#[test]
fn test_shap_explainer_invalid_input() {
let model = LinearModel::new(vec![1.0, 2.0], 0.0);
let background = vec![vec![0.0, 0.0]];
let explainer = ShapExplainer::new(background);
let result = explainer.explain(&model, &[1.0, 2.0, 3.0]); assert!(matches!(result, Err(ExplainError::InvalidInput { .. })));
}
// KernelSHAP requires background data; an empty set is an error.
#[test]
fn test_shap_explainer_empty_background() {
let model = LinearModel::new(vec![1.0], 0.0);
let explainer = ShapExplainer::new(vec![]);
let result = explainer.explain(&model, &[1.0]);
assert!(matches!(result, Err(ExplainError::NoBackground)));
}
// Every error variant formats with its key details included.
#[test]
fn test_explain_error_display() {
let err = ExplainError::UnsupportedModel {
reason: "not a tree".to_string(),
};
assert!(err.to_string().contains("not a tree"));
let err = ExplainError::InvalidInput {
expected: 3,
actual: 2,
};
assert!(err.to_string().contains("expected 3"));
assert!(err.to_string().contains("got 2"));
let err = ExplainError::NoBackground;
assert!(err.to_string().contains("Background"));
let err = ExplainError::ComputationError("overflow".to_string());
assert!(err.to_string().contains("overflow"));
}
#[test]
fn test_binomial_coefficient() {
assert_eq!(binomial(5, 0), 1);
assert_eq!(binomial(5, 1), 5);
assert_eq!(binomial(5, 2), 10);
assert_eq!(binomial(5, 3), 10);
assert_eq!(binomial(5, 4), 5);
assert_eq!(binomial(5, 5), 1);
assert_eq!(binomial(5, 6), 0); }
#[test]
fn test_binomial_edge_cases() {
assert_eq!(binomial(0, 0), 1);
assert_eq!(binomial(1, 0), 1);
assert_eq!(binomial(1, 1), 1);
assert_eq!(binomial(10, 5), 252);
}
#[test]
fn test_tree_structure_serialization() {
let tree = DecisionTree::new(
vec![0, -1, -1],
vec![0.5, 0.0, 0.0],
vec![1, 0, 0],
vec![2, 0, 0],
vec![0.0, 1.0, 2.0],
);
let structure = TreeStructure {
n_trees: 1,
n_features: 1,
trees: vec![tree],
};
let json = serde_json::to_string(&structure).unwrap();
assert!(json.contains("n_trees"));
assert!(json.contains("n_features"));
}
// JSON round-trip must preserve all numeric fields.
#[test]
fn test_shap_explanation_serialization() {
let exp = ShapExplanation::new(0.5, vec![0.1, -0.2], 0.4);
let json = serde_json::to_string(&exp).unwrap();
let parsed: ShapExplanation = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.base_value, exp.base_value);
assert_eq!(parsed.shap_values, exp.shap_values);
assert_eq!(parsed.prediction, exp.prediction);
}
// Degenerate input: zero features still satisfies the additivity check.
#[test]
fn test_shap_explanation_empty_values() {
let exp = ShapExplanation::new(0.5, vec![], 0.5);
assert!(exp.verify_consistency(0.01));
assert!(exp.top_features(3).is_empty());
}
// A tree that is a single leaf returns its value even for an empty instance.
#[test]
fn test_decision_tree_empty_instance() {
let tree = DecisionTree::new(
vec![-1], vec![0.0],
vec![0],
vec![0],
vec![5.0],
);
assert_eq!(tree.predict(&[]), 5.0);
}
// Default trait method: batch prediction maps predict over each instance.
#[test]
fn test_linear_model_batch_predict() {
let model = LinearModel::new(vec![1.0, 2.0], 0.0);
let instances = vec![vec![1.0, 1.0], vec![2.0, 2.0]];
let predictions = model.predict_batch(&instances).unwrap();
assert_eq!(predictions, vec![3.0, 6.0]);
}
}