use crate::tree::FlatTree;
/// One slot of the feature path that `recurse` maintains while walking a tree.
///
/// The path vector holds one entry per distinct feature seen so far, preceded by a
/// seed entry at index 0.
#[derive(Clone, Copy, Debug)]
struct PathEntry {
/// Feature this slot refers to. `usize::MAX` is used for the seed entry and for
/// filler slots, so it never matches a real split feature or a valid `phi` index.
feature_index: usize,
/// Fraction passed down by `recurse` as a product of child/parent `node_counts`
/// ratios along the way to this split.
zero_fraction: f64,
/// Fraction passed down by `recurse`: the inherited value on the branch the sample
/// follows, `0.0` on the branch it does not follow.
one_fraction: f64,
/// Weight stored at this position of the path; rewritten by `extend_path` and
/// `unwind_path`, and read by `unwound_path_sum`.
pweight: f64,
}
/// Computes the per-feature SHAP contributions of one tree for one sample.
///
/// The result has one slot per element of `sample`. A tree without nodes yields all
/// zeros. `node_counts` is handed to the recursive walk, which indexes it by node
/// position.
pub fn tree_shap(tree: &FlatTree, sample: &[f64], node_counts: &[usize]) -> Vec<f64> {
    let mut contributions = vec![0.0; sample.len()];
    if !tree.nodes.is_empty() {
        // Seed entry: attached to no feature and carrying the full weight.
        let seed = PathEntry {
            feature_index: usize::MAX,
            zero_fraction: 1.0,
            one_fraction: 1.0,
            pweight: 1.0,
        };
        let mut path = vec![seed];
        // Start at node 0, depth 0; the negative feature tells `recurse` that the
        // seed entry is already in place and must not be extended again.
        recurse(
            tree,
            sample,
            node_counts,
            0,
            &mut path,
            0,
            1.0,
            1.0,
            -1_i64,
            &mut contributions,
        );
    }
    contributions
}
/// Averages the SHAP contributions of several trees for one sample.
///
/// Each element of `trees` pairs a tree with its per-node counts. The result always
/// has `n_features` slots; contributions at indices beyond `n_features` are ignored.
/// An empty `trees` slice yields all zeros.
pub fn ensemble_tree_shap(
    trees: &[(&FlatTree, &[usize])],
    sample: &[f64],
    n_features: usize,
) -> Vec<f64> {
    let mut totals = vec![0.0; n_features];
    if trees.is_empty() {
        // Nothing to average; also avoids dividing by zero below.
        return totals;
    }
    for (tree, counts) in trees.iter().copied() {
        let per_tree = tree_shap(tree, sample, counts);
        // `zip` stops at the shorter side, so anything past `n_features` is dropped.
        for (acc, value) in totals.iter_mut().zip(per_tree) {
            *acc += value;
        }
    }
    let tree_count = trees.len() as f64;
    totals.iter_mut().for_each(|total| *total /= tree_count);
    totals
}
/// Recursive TreeSHAP walk over the subtree rooted at `node_idx`.
///
/// On entry the path is extended with the split that led here (`incoming_feature`
/// with fractions `pz` / `po`), unless `incoming_feature` is negative, which marks the
/// root call whose seed entry is already in `path`.
///
/// * A node whose `right` is `u32::MAX` is treated as a leaf and its `feature_idx` is
///   used as an index into `tree.predictions`. Every path entry `1..=unique_depth`
///   whose feature is a valid `phi` index receives
///   `unwound_path_sum * (one_fraction - zero_fraction) * leaf value`.
/// * Otherwise the node is a split whose left child sits at `node_idx + 1` and whose
///   right child sits at `node.right`. The child the sample follows ("hot") inherits
///   the incoming one-fraction, the other ("cold") gets `0.0`; both get the incoming
///   zero-fraction scaled by `child count / parent count`. If the split feature is
///   already on the path, its entry is unwound first and its fractions are inherited.
///
/// A count missing from `node_counts` (parent or child) is read as `0`; when the
/// parent count is `0` both children get a fraction of `0.5`.
///
/// `path` is mutated by the callees and restored to this call's post-unwind snapshot
/// after each child returns.
///
/// # Panics
///
/// Panics if `node_idx` is out of bounds for `tree.nodes` or a leaf's prediction
/// index is out of bounds for `tree.predictions`.
// The argument list mirrors the recursive state of the algorithm; bundling it into a
// struct would not make the call sites clearer.
#[allow(clippy::too_many_arguments)]
fn recurse(
    tree: &FlatTree,
    sample: &[f64],
    node_counts: &[usize],
    node_idx: usize,
    path: &mut Vec<PathEntry>,
    unique_depth: usize,
    pz: f64,
    po: f64,
    incoming_feature: i64,
    phi: &mut [f64],
) {
    if incoming_feature >= 0 {
        extend_path(path, unique_depth, pz, po, incoming_feature as usize);
    }
    let node = &tree.nodes[node_idx];

    if node.right == u32::MAX {
        // Leaf: attribute its value to every feature currently on the path.
        let leaf_val = tree.predictions[node.feature_idx as usize];
        for i in 1..=unique_depth {
            let w = unwound_path_sum(path, unique_depth, i);
            let entry = &path[i];
            if entry.feature_index < phi.len() {
                let contrib = w * (entry.one_fraction - entry.zero_fraction) * leaf_val;
                phi[entry.feature_index] += contrib;
            }
        }
        return;
    }

    let split_feature = node.feature_idx as usize;
    let threshold = node.threshold;
    let left_idx = node_idx + 1;
    let right_idx = node.right as usize;

    // All count lookups are bounds-checked the same way: a missing count reads as 0.
    let count_at = |idx: usize| node_counts.get(idx).map_or(0.0, |&c| c as f64);
    let parent_count = count_at(node_idx);
    let left_count = count_at(left_idx);
    let right_count = count_at(right_idx);

    // A feature index outside `sample` sends the sample to the right child.
    let goes_left = split_feature < sample.len() && sample[split_feature] <= threshold;
    let (hot_idx, cold_idx, hot_count, cold_count) = if goes_left {
        (left_idx, right_idx, left_count, right_count)
    } else {
        (right_idx, left_idx, right_count, left_count)
    };
    let share_of_parent = |child_count: f64| {
        if parent_count > 0.0 {
            child_count / parent_count
        } else {
            0.5
        }
    };
    let hot_zero_fraction = share_of_parent(hot_count);
    let cold_zero_fraction = share_of_parent(cold_count);

    // If this feature was split on earlier, undo that entry and inherit its fractions
    // so the feature appears on the path only once.
    let found_idx = (1..=unique_depth).find(|&i| path[i].feature_index == split_feature);
    let (incoming_zero, incoming_one, next_depth) = match found_idx {
        Some(fi) => {
            let seen = path[fi];
            unwind_path(path, unique_depth, fi);
            (seen.zero_fraction, seen.one_fraction, unique_depth - 1)
        }
        None => (1.0, 1.0, unique_depth),
    };

    let saved_path = path.clone();
    recurse(
        tree,
        sample,
        node_counts,
        hot_idx,
        path,
        next_depth + 1,
        incoming_zero * hot_zero_fraction,
        incoming_one,
        split_feature as i64,
        phi,
    );
    // Restore the snapshot in place; `clone_from` reuses the existing allocation.
    path.clone_from(&saved_path);
    recurse(
        tree,
        sample,
        node_counts,
        cold_idx,
        path,
        next_depth + 1,
        incoming_zero * cold_zero_fraction,
        0.0,
        split_feature as i64,
        phi,
    );
    *path = saved_path;
}
/// Writes a new entry at position `unique_depth` of `path` and redistributes the
/// weights of all lower positions.
///
/// If `path` is too short it is first padded with filler entries (feature
/// `usize::MAX`, all-zero numbers). The new entry starts with weight `1.0` at depth 0
/// and `0.0` otherwise; then, from the top position downwards, each lower weight
/// passes a `one_fraction`-scaled share up one position and keeps a
/// `zero_fraction`-scaled share for itself.
fn extend_path(
    path: &mut Vec<PathEntry>,
    unique_depth: usize,
    zero_fraction: f64,
    one_fraction: f64,
    feature: usize,
) {
    let depth = unique_depth;
    if path.len() <= depth {
        let filler = PathEntry {
            feature_index: usize::MAX,
            zero_fraction: 0.0,
            one_fraction: 0.0,
            pweight: 0.0,
        };
        path.resize(depth + 1, filler);
    }

    let initial_weight = if depth == 0 { 1.0 } else { 0.0 };
    path[depth] = PathEntry {
        feature_index: feature,
        zero_fraction,
        one_fraction,
        pweight: initial_weight,
    };

    // Empty range at depth 0, so the seed weight is left untouched there.
    let slots = (depth + 1) as f64;
    for lower in (0..depth).rev() {
        let carried = path[lower].pweight;
        path[lower + 1].pweight += one_fraction * carried * ((lower + 1) as f64) / slots;
        path[lower].pweight = zero_fraction * carried * ((depth - lower) as f64) / slots;
    }
}
fn unwind_path(path: &mut [PathEntry], unique_depth: usize, path_index: usize) {
let one_fraction = path[path_index].one_fraction;
let zero_fraction = path[path_index].zero_fraction;
let n = unique_depth;
let mut next_one = path[n].pweight;
for i in (0..n).rev() {
if one_fraction.abs() > 1e-30 {
let tmp = next_one * ((n + 1) as f64) / (((i + 1) as f64) * one_fraction);
next_one = path[i].pweight - tmp * zero_fraction * ((n - i) as f64) / ((n + 1) as f64);
path[i].pweight = tmp;
} else {
path[i].pweight =
path[i].pweight * ((n + 1) as f64) / (zero_fraction * ((n - i) as f64));
}
}
for i in path_index..n {
path[i] = path[i + 1];
}
}
/// Returns the sum of the weights the path would have if the entry at `path_index`
/// were removed, without modifying `path`.
///
/// Yields `0.0` when both fractions of that entry are below `1e-30` in magnitude.
/// Otherwise the weights are reconstructed from the top position downwards, using the
/// entry's `one_fraction` when it is above `1e-30` in magnitude and its
/// `zero_fraction` when it is not.
///
/// # Panics
///
/// Panics if `path_index` or `unique_depth` is out of bounds for `path`.
fn unwound_path_sum(path: &[PathEntry], unique_depth: usize, path_index: usize) -> f64 {
    let removed = &path[path_index];
    let (one, zero) = (removed.one_fraction, removed.zero_fraction);
    let depth = unique_depth;
    let top_weight = path[depth].pweight;

    if one.abs() < 1e-30 && zero.abs() < 1e-30 {
        return 0.0;
    }

    let slots = (depth + 1) as f64;
    if one.abs() > 1e-30 {
        // Peel the removed entry's share off each position, carrying the remainder down.
        let mut carry = top_weight;
        let mut sum = 0.0;
        for pos in (0..depth).rev() {
            let share = carry * slots / (((pos + 1) as f64) * one);
            sum += share;
            carry = path[pos].pweight - share * zero * ((depth - pos) as f64) / slots;
        }
        sum
    } else {
        // With no one-fraction each position only needs rescaling by the zero-fraction.
        (0..depth)
            .rev()
            .map(|pos| (path[pos].pweight / zero) / (((depth - pos) as f64) / slots))
            .fold(0.0, |sum, term| sum + term)
    }
}
/// Unit tests for the SHAP routines, all built on a one-split tree.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::{FlatTree, TreeNode};

    /// Builds a tree that splits feature 0 at 0.5 into leaves predicting 1.0 (left)
    /// and 3.0 (right), 50 samples each, and returns it with its node counts.
    fn simple_tree() -> (FlatTree, Vec<usize>) {
        let leaf = |prediction| TreeNode::Leaf {
            prediction,
            n_samples: 50,
            class_counts: vec![],
            impurity: 0.0,
        };
        let root = TreeNode::Split {
            feature_idx: 0,
            threshold: 0.5,
            left: Box::new(leaf(1.0)),
            right: Box::new(leaf(3.0)),
            n_samples: 100,
            impurity: 0.5,
            class_counts: vec![],
            prediction: 2.0,
        };
        let flat = FlatTree::from_tree_node(&root, 0);
        let counts = flat.node_counts.clone();
        (flat, counts)
    }

    /// A sample below the threshold lands in the 1.0 leaf: contribution -1.0.
    #[test]
    fn test_tree_shap_single_split() {
        let (tree, counts) = simple_tree();
        let phi = tree_shap(&tree, &[0.2], &counts);
        assert_eq!(phi.len(), 1);
        let error = (phi[0] - (-1.0)).abs();
        assert!(
            error < 1e-10,
            "SHAP for left sample should be -1.0, got {}",
            phi[0]
        );
    }

    /// A sample above the threshold lands in the 3.0 leaf: contribution +1.0.
    #[test]
    fn test_tree_shap_right_path() {
        let (tree, counts) = simple_tree();
        let phi = tree_shap(&tree, &[0.8], &counts);
        assert_eq!(phi.len(), 1);
        let error = (phi[0] - 1.0).abs();
        assert!(
            error < 1e-10,
            "SHAP for right sample should be 1.0, got {}",
            phi[0]
        );
    }

    /// A tree without nodes yields one zero per sample feature.
    #[test]
    fn test_tree_shap_empty_tree() {
        let tree = FlatTree {
            nodes: vec![],
            predictions: vec![],
            leaf_probas: vec![],
            n_classes_stored: 0,
            node_counts: vec![],
        };
        let phi = tree_shap(&tree, &[1.0, 2.0], &[]);
        assert_eq!(phi, vec![0.0, 0.0]);
    }

    /// Averaging two identical trees must reproduce the single-tree result.
    #[test]
    fn test_ensemble_tree_shap() {
        let (tree1, counts1) = simple_tree();
        let (tree2, counts2) = simple_tree();
        let trees: Vec<(&FlatTree, &[usize])> = vec![(&tree1, &counts1), (&tree2, &counts2)];
        let sample = vec![0.2];
        let phi = ensemble_tree_shap(&trees, &sample, 1);
        assert_eq!(phi.len(), 1);
        let single_phi = tree_shap(&tree1, &sample, &counts1);
        assert!(
            (phi[0] - single_phi[0]).abs() < 1e-10,
            "Ensemble of identical trees should match single: {} vs {}",
            phi[0],
            single_phi[0]
        );
    }

    /// Contributions must sum to prediction minus the root expectation (2.0).
    #[test]
    fn test_tree_shap_additivity() {
        let (tree, counts) = simple_tree();
        let expected = 2.0;
        // Left-going sample first, then right-going sample.
        let samples = [vec![0.2], vec![0.8]];
        for sample in &samples {
            let phi = tree_shap(&tree, sample, &counts);
            let pred = tree.predict_sample(sample);
            let phi_sum: f64 = phi.iter().sum();
            assert!(
                (phi_sum - (pred - expected)).abs() < 1e-10,
                "SHAP additivity: sum={}, pred-E[f]={}",
                phi_sum,
                pred - expected,
            );
        }
    }
}