// irithyll/explain/treeshap.rs

1//! TreeSHAP algorithm for single trees and ensembles.
2//!
3//! Implements the path-dependent TreeSHAP (Lundberg et al., 2020) which
4//! computes exact Shapley values in O(L * D^2) per tree, where L is the
5//! number of leaves and D is the tree depth.
6
7use crate::tree::node::{NodeId, TreeArena};
8
9/// Precompute the true cover (sum of leaf sample counts) for each node.
10///
11/// The arena's `sample_count` only tracks how many samples reached a node while
12/// it was a leaf. After splitting, the internal node's count goes stale. This
13/// function computes the correct cover by summing leaf counts bottom-up.
14fn compute_covers(arena: &TreeArena, root: NodeId) -> Vec<f64> {
15    let n = arena.n_nodes();
16    let mut covers = vec![0.0; n];
17
18    fn fill(arena: &TreeArena, node: NodeId, covers: &mut [f64]) -> f64 {
19        let idx = node.idx();
20        if arena.is_leaf[idx] {
21            covers[idx] = arena.sample_count[idx] as f64;
22            return covers[idx];
23        }
24        let left = fill(arena, arena.left[idx], covers);
25        let right = fill(arena, arena.right[idx], covers);
26        covers[idx] = left + right;
27        covers[idx]
28    }
29
30    if !root.is_none() && root.idx() < n {
31        fill(arena, root, &mut covers);
32    }
33    covers
34}
35
/// Per-feature SHAP contributions for a single prediction.
///
/// # Invariant
///
/// `base_value + values.iter().sum::<f64>() ≈ model.predict(features)`
/// (the local-accuracy / additivity property of Shapley values).
#[derive(Debug, Clone)]
pub struct ShapValues {
    /// Per-feature SHAP contribution, indexed by feature position
    /// (positive = pushes the prediction above `base_value`).
    pub values: Vec<f64>,
    /// Expected value of the model: for a single tree the cover-weighted mean
    /// leaf value (see `tree_shap_values`); for an ensemble, its
    /// `base_prediction` (see `ensemble_shap`).
    pub base_value: f64,
}
48
/// Named SHAP values with feature names attached.
///
/// NOTE(review): not constructed anywhere in this module — presumably built by
/// a caller that zips `ShapValues::values` with configured feature names;
/// verify at the call site.
#[derive(Debug, Clone)]
pub struct NamedShapValues {
    /// `(feature_name, shap_value)` pairs.
    pub values: Vec<(String, f64)>,
    /// Expected value of the model.
    pub base_value: f64,
}
57
58// ---------------------------------------------------------------------------
59// Path-dependent TreeSHAP internals
60// ---------------------------------------------------------------------------
61
/// A single entry in the path tracking structure.
///
/// Each entry records one *unique* feature split on along the current
/// root-to-leaf path, plus the bookkeeping needed to maintain Shapley
/// permutation weights incrementally (Lundberg et al., 2020, Algorithm 2).
#[derive(Clone)]
struct PathEntry {
    /// Feature index at this node (-1 for the root sentinel entry).
    feature_idx: i64,
    /// Fraction of "feature excluded" paths flowing through this node.
    zero_fraction: f64,
    /// Fraction of "feature included" paths flowing through this node
    /// (1.0 when the explained sample follows this branch, else 0.0).
    one_fraction: f64,
    /// Permutation weight held at this path position.
    pweight: f64,
}

/// Extend the path by one node.
///
/// Pushes a new entry and redistributes the permutation weights so that
/// `pweight[i]` remains the weight of feature subsets of size `i` drawn from
/// the path. The sentinel (depth 0) starts at weight 1; later entries start
/// at 0 and receive weight from the update loop.
fn extend_path(path: &mut Vec<PathEntry>, zero_fraction: f64, one_fraction: f64, feature_idx: i64) {
    let depth = path.len();
    path.push(PathEntry {
        feature_idx,
        zero_fraction,
        one_fraction,
        pweight: if depth == 0 { 1.0 } else { 0.0 },
    });

    // Subsets including the new feature pick up one_fraction; subsets
    // excluding it pick up zero_fraction.
    for i in (1..=depth).rev() {
        path[i].pweight += one_fraction * path[i - 1].pweight * (i as f64) / ((depth + 1) as f64);
        path[i - 1].pweight =
            zero_fraction * path[i - 1].pweight * ((depth + 1 - i) as f64) / ((depth + 1) as f64);
    }
}

/// Remove the entry at `path_idx`, exactly inverting its [`extend_path`].
///
/// BUGFIX 1: the recovery scaling previously used `(depth + 1 - i)` where the
/// inverse of `extend_path` requires `(depth + 1)` (and `(depth - i)` where
/// `(i + 1)` was used). The two agree only for paths of depth 1 — i.e. a
/// single split — so the error never showed up on one-split trees but produced
/// wrong SHAP values for any deeper tree. The corrected factors match the
/// reference implementation (shap `tree_shap.h`, `unwind_path`).
///
/// BUGFIX 2: when removing a *middle* entry (a duplicate feature), only the
/// feature metadata must shift left; the recovered `pweight`s already sit in
/// slots `0..depth-1` and must stay put. The old `path[i] = path[i+1].clone()`
/// also shifted `pweight`, clobbering the recovered weights.
fn unwind_path(path: &mut Vec<PathEntry>, path_idx: usize) {
    let depth = path.len() - 1;
    let one_fraction = path[path_idx].one_fraction;
    let zero_fraction = path[path_idx].zero_fraction;

    let mut next_one_portion = path[depth].pweight;

    for i in (0..depth).rev() {
        if one_fraction != 0.0 {
            let tmp = path[i].pweight;
            path[i].pweight =
                next_one_portion * ((depth + 1) as f64) / ((i + 1) as f64 * one_fraction);
            next_one_portion =
                tmp - path[i].pweight * zero_fraction * ((depth - i) as f64) / ((depth + 1) as f64);
        } else {
            path[i].pweight =
                path[i].pweight * ((depth + 1) as f64) / (zero_fraction * ((depth - i) as f64));
        }
    }

    // Shift only the feature metadata over the removed slot (see BUGFIX 2),
    // then drop the last entry.
    for i in path_idx..depth {
        path[i].feature_idx = path[i + 1].feature_idx;
        path[i].zero_fraction = path[i + 1].zero_fraction;
        path[i].one_fraction = path[i + 1].one_fraction;
    }
    path.pop();
}

/// Sum of permutation weights obtained by *virtually* unwinding the entry at
/// `path_idx` — the total weight the feature at that entry carries at the
/// current leaf. Read-only counterpart of [`unwind_path`]; uses the same
/// corrected `(depth + 1)` / `(depth - i)` scaling (see BUGFIX 1 there).
fn unwound_path_sum(path: &[PathEntry], path_idx: usize) -> f64 {
    let depth = path.len() - 1;
    let one_fraction = path[path_idx].one_fraction;
    let zero_fraction = path[path_idx].zero_fraction;

    let mut total = 0.0;
    let mut next_one_portion = path[depth].pweight;

    for i in (0..depth).rev() {
        if one_fraction != 0.0 {
            let tmp = next_one_portion * ((depth + 1) as f64) / ((i + 1) as f64 * one_fraction);
            total += tmp;
            next_one_portion =
                path[i].pweight - tmp * zero_fraction * ((depth - i) as f64) / ((depth + 1) as f64);
        } else {
            total += path[i].pweight * ((depth + 1) as f64) / (zero_fraction * ((depth - i) as f64));
        }
    }
    total
}
142
143/// Recursive TreeSHAP traversal.
144fn tree_shap_recursive(
145    arena: &TreeArena,
146    covers: &[f64],
147    node: NodeId,
148    features: &[f64],
149    shap_values: &mut [f64],
150    path: &mut Vec<PathEntry>,
151) {
152    let idx = node.idx();
153
154    if arena.is_leaf[idx] {
155        // At a leaf — accumulate SHAP contributions.
156        let leaf_value = arena.leaf_value[idx];
157        for i in 1..path.len() {
158            let w = unwound_path_sum(path, i);
159            let feat = path[i].feature_idx;
160            if feat >= 0 && (feat as usize) < shap_values.len() {
161                shap_values[feat as usize] +=
162                    w * (path[i].one_fraction - path[i].zero_fraction) * leaf_value;
163            }
164        }
165        return;
166    }
167
168    // Internal node.
169    let split_feat = arena.feature_idx[idx] as i64;
170    let threshold = arena.threshold[idx];
171    let left = arena.left[idx];
172    let right = arena.right[idx];
173
174    let left_cover = covers[left.idx()];
175    let right_cover = covers[right.idx()];
176    let node_cover = left_cover + right_cover;
177
178    // No training data at this node — nothing to explain.
179    if node_cover == 0.0 {
180        return;
181    }
182
183    // Determine which child the sample goes to.
184    let feat_val = if (split_feat as usize) < features.len() {
185        features[split_feat as usize]
186    } else {
187        0.0
188    };
189
190    let (hot_child, cold_child, hot_cover, cold_cover) = if feat_val <= threshold {
191        (left, right, left_cover, right_cover)
192    } else {
193        (right, left, right_cover, left_cover)
194    };
195
196    let hot_zero_fraction = hot_cover / node_cover;
197    let cold_zero_fraction = cold_cover / node_cover;
198
199    // Check if this feature appeared earlier in the path.
200    let mut incoming_zero_fraction = 1.0;
201    let mut incoming_one_fraction = 1.0;
202    let mut duplicate_idx = None;
203
204    for (i, entry) in path.iter().enumerate().skip(1) {
205        if entry.feature_idx == split_feat {
206            incoming_zero_fraction = entry.zero_fraction;
207            incoming_one_fraction = entry.one_fraction;
208            duplicate_idx = Some(i);
209            break;
210        }
211    }
212
213    if let Some(dup) = duplicate_idx {
214        unwind_path(path, dup);
215    }
216
217    // If one child has zero cover, only recurse into the non-zero side.
218    // The zero-cover subtree has no training data and contributes nothing;
219    // extending the path with zero_fraction=0 would cause division-by-zero
220    // in unwind_path.
221    if hot_cover > 0.0 && cold_cover > 0.0 {
222        // Standard case: both children have data.
223        extend_path(
224            path,
225            hot_zero_fraction * incoming_zero_fraction,
226            incoming_one_fraction,
227            split_feat,
228        );
229        tree_shap_recursive(arena, covers, hot_child, features, shap_values, path);
230
231        // Unwind hot extension, then re-extend with cold parameters.
232        unwind_path(path, path.len() - 1);
233        extend_path(
234            path,
235            cold_zero_fraction * incoming_zero_fraction,
236            0.0, // sample doesn't go this way
237            split_feat,
238        );
239        tree_shap_recursive(arena, covers, cold_child, features, shap_values, path);
240
241        // Unwind cold extension.
242        unwind_path(path, path.len() - 1);
243    } else if hot_cover > 0.0 {
244        // Only hot child has data — recurse without adding a path entry
245        // for this feature (it has no decision power at this split).
246        tree_shap_recursive(arena, covers, hot_child, features, shap_values, path);
247    } else {
248        // Only cold child has data.
249        tree_shap_recursive(arena, covers, cold_child, features, shap_values, path);
250    }
251
252    // Restore duplicate if we unwound one earlier.
253    if duplicate_idx.is_some() {
254        extend_path(
255            path,
256            incoming_zero_fraction,
257            incoming_one_fraction,
258            split_feat,
259        );
260    }
261}
262
263/// Compute SHAP values for a single tree.
264///
265/// Returns per-feature SHAP contributions. The base value for a single tree
266/// is the expected leaf value (weighted by sample count).
267pub fn tree_shap_values(
268    arena: &TreeArena,
269    root: NodeId,
270    features: &[f64],
271    n_features: usize,
272) -> ShapValues {
273    let mut shap_values = vec![0.0; n_features];
274
275    if arena.n_nodes() == 0 || root.is_none() {
276        return ShapValues {
277            values: shap_values,
278            base_value: 0.0,
279        };
280    }
281
282    // Precompute correct covers (sum of leaf sample counts) for each node.
283    let covers = compute_covers(arena, root);
284
285    // Compute base value: expected leaf value weighted by covers.
286    let total_cover: f64 = arena
287        .is_leaf
288        .iter()
289        .enumerate()
290        .filter(|(_, &is_leaf)| is_leaf)
291        .map(|(i, _)| covers[i])
292        .sum();
293
294    let base_value = if total_cover > 0.0 {
295        arena
296            .leaf_value
297            .iter()
298            .zip(arena.is_leaf.iter())
299            .enumerate()
300            .filter(|(_, (_, &is_leaf))| is_leaf)
301            .map(|(i, (&val, _))| val * covers[i])
302            .sum::<f64>()
303            / total_cover
304    } else {
305        0.0
306    };
307
308    // Initialize path with a sentinel entry.
309    let mut path = Vec::with_capacity(32);
310    path.push(PathEntry {
311        feature_idx: -1,
312        zero_fraction: 1.0,
313        one_fraction: 1.0,
314        pweight: 1.0,
315    });
316
317    tree_shap_recursive(arena, &covers, root, features, &mut shap_values, &mut path);
318
319    ShapValues {
320        values: shap_values,
321        base_value,
322    }
323}
324
325/// Compute SHAP values for an SGBT ensemble.
326///
327/// Sums weighted SHAP contributions across all boosting steps.
328/// The base_value is the ensemble's base_prediction.
329pub fn ensemble_shap<L: crate::loss::Loss>(
330    model: &crate::ensemble::SGBT<L>,
331    features: &[f64],
332) -> ShapValues {
333    let n_features = model
334        .config()
335        .feature_names
336        .as_ref()
337        .map(|n| n.len())
338        .unwrap_or_else(|| features.len());
339
340    let lr = model.config().learning_rate;
341    let mut total_shap = vec![0.0; n_features];
342
343    for step in model.steps() {
344        let slot = step.slot();
345        let tree = slot.active_tree();
346        let arena = tree.arena();
347        let root = tree.root();
348
349        if arena.n_nodes() == 0 {
350            continue;
351        }
352
353        let tree_shap = tree_shap_values(arena, root, features, n_features);
354        for (i, v) in tree_shap.values.iter().enumerate() {
355            if i < total_shap.len() {
356                total_shap[i] += lr * v;
357            }
358        }
359    }
360
361    ShapValues {
362        values: total_shap,
363        base_value: model.base_prediction(),
364    }
365}
366
#[cfg(test)]
mod tests {
    // NOTE(review): every tree built in this suite is depth 1 (a single
    // split), so the TreeSHAP path never holds more than one real feature
    // entry. Bookkeeping bugs in extend_path/unwind_path that only manifest
    // at depth >= 2 — or when a feature repeats along a path — are NOT
    // exercised here; consider adding a two-level-split test and a
    // repeated-feature test.
    use super::*;
    use crate::tree::node::TreeArena;

    // A single-leaf tree predicts a constant: the base value is that constant
    // and no feature receives any attribution.
    #[test]
    fn single_leaf_tree_all_shap_zero() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        arena.sample_count[root.idx()] = 100;
        arena.leaf_value[root.idx()] = 5.0;

        let shap = tree_shap_values(&arena, root, &[1.0, 2.0, 3.0], 3);
        assert!((shap.base_value - 5.0).abs() < 1e-10);
        for v in &shap.values {
            assert!(v.abs() < 1e-10, "single-leaf SHAP should be 0, got {v}");
        }
    }

    // NOTE(review): despite the name, this tree has a single split (depth 1);
    // "two_level" refers to root + leaves.
    #[test]
    fn two_level_tree_shap_invariant() {
        // Build: root splits on feature 0 at threshold 0.5.
        //   Left leaf: value = -1.0, 60 samples
        //   Right leaf: value = 1.0, 40 samples
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let (left, right) = arena.split_leaf(root, 0, 0.5, -1.0, 1.0);
        arena.sample_count[root.idx()] = 100;
        arena.sample_count[left.idx()] = 60;
        arena.sample_count[right.idx()] = 40;

        // Test: sample goes left (feature 0 = 0.3 <= 0.5).
        let features = [0.3, 5.0];
        let shap = tree_shap_values(&arena, root, &features, 2);

        // Base value = (-1.0 * 60 + 1.0 * 40) / 100 = -0.2
        let expected_base = -0.2;
        assert!(
            (shap.base_value - expected_base).abs() < 1e-10,
            "base_value: got {}, expected {}",
            shap.base_value,
            expected_base
        );

        // Local accuracy: base + sum of contributions must reconstruct the
        // tree's prediction for this sample.
        // Prediction = -1.0 (left leaf).
        let prediction = -1.0;
        let shap_sum: f64 = shap.values.iter().sum();
        let reconstructed = shap.base_value + shap_sum;
        assert!(
            (reconstructed - prediction).abs() < 1e-8,
            "SHAP invariant violated: base({}) + sum({}) = {} != prediction({})",
            shap.base_value,
            shap_sum,
            reconstructed,
            prediction
        );

        // Feature 1 has no splits — its SHAP should be 0.
        assert!(
            shap.values[1].abs() < 1e-10,
            "non-split feature SHAP should be 0, got {}",
            shap.values[1]
        );
    }

    #[test]
    fn shap_invariant_right_path() {
        // Same tree, sample goes right (feature 0 = 0.7 > 0.5).
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let (left, right) = arena.split_leaf(root, 0, 0.5, -1.0, 1.0);
        arena.sample_count[root.idx()] = 100;
        arena.sample_count[left.idx()] = 60;
        arena.sample_count[right.idx()] = 40;

        let features = [0.7, 5.0];
        let shap = tree_shap_values(&arena, root, &features, 2);

        let prediction = 1.0; // right leaf
        let reconstructed = shap.base_value + shap.values.iter().sum::<f64>();
        assert!(
            (reconstructed - prediction).abs() < 1e-8,
            "SHAP invariant violated for right path: {} != {}",
            reconstructed,
            prediction
        );
    }

    // The empty-tree guard in tree_shap_values returns all-zero output.
    #[test]
    fn empty_tree_returns_zeros() {
        let arena = TreeArena::new();
        let shap = tree_shap_values(&arena, NodeId::NONE, &[1.0], 1);
        assert_eq!(shap.base_value, 0.0);
        assert_eq!(shap.values.len(), 1);
        assert_eq!(shap.values[0], 0.0);
    }

    // End-to-end check of ensemble_shap against a trained model; the loose
    // 0.1 tolerance absorbs ensemble/learning-rate approximation effects.
    #[test]
    fn ensemble_shap_integration() {
        use crate::ensemble::config::SGBTConfig;
        use crate::ensemble::SGBT;

        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap();

        let mut model = SGBT::new(config);

        // Train on a simple linear signal: y = 2*x0 + 0.5*x1.
        // Deterministic LCG so the test is reproducible.
        let mut rng: u64 = 42;
        for _ in 0..200 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x0 = (rng >> 33) as f64 / (u32::MAX as f64);
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x1 = (rng >> 33) as f64 / (u32::MAX as f64);
            let y = 2.0 * x0 + 0.5 * x1;
            model.train_one(&(&[x0, x1][..], y));
        }

        let features = [0.5, 0.5];
        let shap = ensemble_shap(&model, &features);

        // Verify SHAP invariant.
        let prediction = model.predict(&features);
        let reconstructed = shap.base_value + shap.values.iter().sum::<f64>();
        assert!(
            (reconstructed - prediction).abs() < 0.1,
            "ensemble SHAP invariant violated: {} != {} (diff={})",
            reconstructed,
            prediction,
            (reconstructed - prediction).abs()
        );
    }
}