1use crate::tree::node::{NodeId, TreeArena};
8
9fn compute_covers(arena: &TreeArena, root: NodeId) -> Vec<f64> {
15 let n = arena.n_nodes();
16 let mut covers = vec![0.0; n];
17
18 fn fill(arena: &TreeArena, node: NodeId, covers: &mut [f64]) -> f64 {
19 let idx = node.idx();
20 if arena.is_leaf[idx] {
21 covers[idx] = arena.sample_count[idx] as f64;
22 return covers[idx];
23 }
24 let left = fill(arena, arena.left[idx], covers);
25 let right = fill(arena, arena.right[idx], covers);
26 covers[idx] = left + right;
27 covers[idx]
28 }
29
30 if !root.is_none() && root.idx() < n {
31 fill(arena, root, &mut covers);
32 }
33 covers
34}
35
/// SHAP explanation of a single prediction.
#[derive(Debug, Clone)]
pub struct ShapValues {
    /// Per-feature attributions, indexed by feature index. Producers in this
    /// file size it to the number of features they were asked to explain.
    pub values: Vec<f64>,
    /// Baseline the attributions are measured against: `tree_shap_values`
    /// stores the cover-weighted mean leaf value (or `0.0` for an empty tree),
    /// `ensemble_shap` stores the model's base prediction.
    pub base_value: f64,
}
48
/// SHAP explanation whose attributions are paired with a `String` label.
///
/// NOTE(review): nothing in this file constructs this type; presumably the
/// label is the feature name taken from the model config — confirm at the
/// call site.
#[derive(Debug, Clone)]
pub struct NamedShapValues {
    /// `(label, attribution)` pairs.
    pub values: Vec<(String, f64)>,
    /// Baseline value; presumably the same meaning as `ShapValues::base_value`
    /// — confirm against the code that builds this struct.
    pub base_value: f64,
}
57
/// One element of the feature path maintained while walking a tree.
#[derive(Clone)]
struct PathEntry {
    /// Feature split on at this step; `-1` marks the sentinel entry pushed
    /// before the root is visited.
    feature_idx: i64,
    /// Share of the parent's cover that flows into the branch taken at this
    /// step (multiplied by the share of any earlier split on the same feature).
    zero_fraction: f64,
    /// `0.0` when this step follows the branch the sample does not take;
    /// otherwise inherited from an earlier split on the same feature, or `1.0`.
    one_fraction: f64,
    /// Path weight bookkeeping, updated by `extend_path` and `unwind_path`.
    pweight: f64,
}
74
75fn extend_path(path: &mut Vec<PathEntry>, zero_fraction: f64, one_fraction: f64, feature_idx: i64) {
77 let depth = path.len();
78 path.push(PathEntry {
79 feature_idx,
80 zero_fraction,
81 one_fraction,
82 pweight: if depth == 0 { 1.0 } else { 0.0 },
83 });
84
85 for i in (1..depth + 1).rev() {
87 path[i].pweight += one_fraction * path[i - 1].pweight * (i as f64) / ((depth + 1) as f64);
88 path[i - 1].pweight =
89 zero_fraction * path[i - 1].pweight * ((depth + 1 - i) as f64) / ((depth + 1) as f64);
90 }
91}
92
93fn unwind_path(path: &mut Vec<PathEntry>, path_idx: usize) {
95 let depth = path.len() - 1;
96 let one_fraction = path[path_idx].one_fraction;
97 let zero_fraction = path[path_idx].zero_fraction;
98
99 let mut next_one_portion = path[depth].pweight;
100
101 for i in (0..depth).rev() {
102 if one_fraction != 0.0 {
103 let tmp = path[i].pweight;
104 path[i].pweight =
105 next_one_portion * ((depth + 1 - i) as f64) / ((i + 1) as f64 * one_fraction);
106 next_one_portion =
107 tmp - path[i].pweight * zero_fraction * ((i + 1) as f64) / ((depth + 1 - i) as f64);
108 } else {
109 path[i].pweight =
110 path[i].pweight * ((depth + 1 - i) as f64) / (zero_fraction * (i + 1) as f64);
111 }
112 }
113
114 for i in path_idx..depth {
116 path[i] = path[i + 1].clone();
117 }
118 path.pop();
119}
120
121fn unwound_path_sum(path: &[PathEntry], path_idx: usize) -> f64 {
123 let depth = path.len() - 1;
124 let one_fraction = path[path_idx].one_fraction;
125 let zero_fraction = path[path_idx].zero_fraction;
126
127 let mut total = 0.0;
128 let mut next_one_portion = path[depth].pweight;
129
130 for i in (0..depth).rev() {
131 if one_fraction != 0.0 {
132 let tmp = next_one_portion * ((depth + 1 - i) as f64) / ((i + 1) as f64 * one_fraction);
133 total += tmp;
134 next_one_portion =
135 path[i].pweight - tmp * zero_fraction * ((i + 1) as f64) / ((depth + 1 - i) as f64);
136 } else {
137 total += path[i].pweight / (zero_fraction * (i + 1) as f64) * ((depth + 1 - i) as f64);
138 }
139 }
140 total
141}
142
143fn tree_shap_recursive(
145 arena: &TreeArena,
146 covers: &[f64],
147 node: NodeId,
148 features: &[f64],
149 shap_values: &mut [f64],
150 path: &mut Vec<PathEntry>,
151) {
152 let idx = node.idx();
153
154 if arena.is_leaf[idx] {
155 let leaf_value = arena.leaf_value[idx];
157 for i in 1..path.len() {
158 let w = unwound_path_sum(path, i);
159 let feat = path[i].feature_idx;
160 if feat >= 0 && (feat as usize) < shap_values.len() {
161 shap_values[feat as usize] +=
162 w * (path[i].one_fraction - path[i].zero_fraction) * leaf_value;
163 }
164 }
165 return;
166 }
167
168 let split_feat = arena.feature_idx[idx] as i64;
170 let threshold = arena.threshold[idx];
171 let left = arena.left[idx];
172 let right = arena.right[idx];
173
174 let left_cover = covers[left.idx()];
175 let right_cover = covers[right.idx()];
176 let node_cover = left_cover + right_cover;
177
178 if node_cover == 0.0 {
180 return;
181 }
182
183 let feat_val = if (split_feat as usize) < features.len() {
185 features[split_feat as usize]
186 } else {
187 0.0
188 };
189
190 let (hot_child, cold_child, hot_cover, cold_cover) = if feat_val <= threshold {
191 (left, right, left_cover, right_cover)
192 } else {
193 (right, left, right_cover, left_cover)
194 };
195
196 let hot_zero_fraction = hot_cover / node_cover;
197 let cold_zero_fraction = cold_cover / node_cover;
198
199 let mut incoming_zero_fraction = 1.0;
201 let mut incoming_one_fraction = 1.0;
202 let mut duplicate_idx = None;
203
204 for (i, entry) in path.iter().enumerate().skip(1) {
205 if entry.feature_idx == split_feat {
206 incoming_zero_fraction = entry.zero_fraction;
207 incoming_one_fraction = entry.one_fraction;
208 duplicate_idx = Some(i);
209 break;
210 }
211 }
212
213 if let Some(dup) = duplicate_idx {
214 unwind_path(path, dup);
215 }
216
217 if hot_cover > 0.0 && cold_cover > 0.0 {
222 extend_path(
224 path,
225 hot_zero_fraction * incoming_zero_fraction,
226 incoming_one_fraction,
227 split_feat,
228 );
229 tree_shap_recursive(arena, covers, hot_child, features, shap_values, path);
230
231 unwind_path(path, path.len() - 1);
233 extend_path(
234 path,
235 cold_zero_fraction * incoming_zero_fraction,
236 0.0, split_feat,
238 );
239 tree_shap_recursive(arena, covers, cold_child, features, shap_values, path);
240
241 unwind_path(path, path.len() - 1);
243 } else if hot_cover > 0.0 {
244 tree_shap_recursive(arena, covers, hot_child, features, shap_values, path);
247 } else {
248 tree_shap_recursive(arena, covers, cold_child, features, shap_values, path);
250 }
251
252 if duplicate_idx.is_some() {
254 extend_path(
255 path,
256 incoming_zero_fraction,
257 incoming_one_fraction,
258 split_feat,
259 );
260 }
261}
262
263pub fn tree_shap_values(
268 arena: &TreeArena,
269 root: NodeId,
270 features: &[f64],
271 n_features: usize,
272) -> ShapValues {
273 let mut shap_values = vec![0.0; n_features];
274
275 if arena.n_nodes() == 0 || root.is_none() {
276 return ShapValues {
277 values: shap_values,
278 base_value: 0.0,
279 };
280 }
281
282 let covers = compute_covers(arena, root);
284
285 let total_cover: f64 = arena
287 .is_leaf
288 .iter()
289 .enumerate()
290 .filter(|(_, &is_leaf)| is_leaf)
291 .map(|(i, _)| covers[i])
292 .sum();
293
294 let base_value = if total_cover > 0.0 {
295 arena
296 .leaf_value
297 .iter()
298 .zip(arena.is_leaf.iter())
299 .enumerate()
300 .filter(|(_, (_, &is_leaf))| is_leaf)
301 .map(|(i, (&val, _))| val * covers[i])
302 .sum::<f64>()
303 / total_cover
304 } else {
305 0.0
306 };
307
308 let mut path = Vec::with_capacity(32);
310 path.push(PathEntry {
311 feature_idx: -1,
312 zero_fraction: 1.0,
313 one_fraction: 1.0,
314 pweight: 1.0,
315 });
316
317 tree_shap_recursive(arena, &covers, root, features, &mut shap_values, &mut path);
318
319 ShapValues {
320 values: shap_values,
321 base_value,
322 }
323}
324
325pub fn ensemble_shap<L: crate::loss::Loss>(
330 model: &crate::ensemble::SGBT<L>,
331 features: &[f64],
332) -> ShapValues {
333 let n_features = model
334 .config()
335 .feature_names
336 .as_ref()
337 .map(|n| n.len())
338 .unwrap_or_else(|| features.len());
339
340 let lr = model.config().learning_rate;
341 let mut total_shap = vec![0.0; n_features];
342
343 for step in model.steps() {
344 let slot = step.slot();
345 let tree = slot.active_tree();
346 let arena = tree.arena();
347 let root = tree.root();
348
349 if arena.n_nodes() == 0 {
350 continue;
351 }
352
353 let tree_shap = tree_shap_values(arena, root, features, n_features);
354 for (i, v) in tree_shap.values.iter().enumerate() {
355 if i < total_shap.len() {
356 total_shap[i] += lr * v;
357 }
358 }
359 }
360
361 ShapValues {
362 values: total_shap,
363 base_value: model.base_prediction(),
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::node::TreeArena;

    /// Builds the one-split tree shared by the invariant tests: feature 0 at
    /// threshold 0.5, with 60 samples recorded on the left child and 40 on the
    /// right. The tests below expect leaf values of -1.0 (left) and 1.0 (right).
    fn stump() -> (TreeArena, NodeId) {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let (left, right) = arena.split_leaf(root, 0, 0.5, -1.0, 1.0);
        arena.sample_count[root.idx()] = 100;
        arena.sample_count[left.idx()] = 60;
        arena.sample_count[right.idx()] = 40;
        (arena, root)
    }

    #[test]
    fn single_leaf_tree_all_shap_zero() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        arena.sample_count[root.idx()] = 100;
        arena.leaf_value[root.idx()] = 5.0;

        let shap = tree_shap_values(&arena, root, &[1.0, 2.0, 3.0], 3);

        assert!((shap.base_value - 5.0).abs() < 1e-10);
        for v in &shap.values {
            assert!(v.abs() < 1e-10, "single-leaf SHAP should be 0, got {v}");
        }
    }

    #[test]
    fn two_level_tree_shap_invariant() {
        let (arena, root) = stump();
        let features = [0.3, 5.0];

        let shap = tree_shap_values(&arena, root, &features, 2);

        // Cover-weighted leaf mean: (60 * -1.0 + 40 * 1.0) / 100.
        let expected_base = -0.2;
        assert!(
            (shap.base_value - expected_base).abs() < 1e-10,
            "base_value: got {}, expected {}",
            shap.base_value,
            expected_base
        );

        // 0.3 <= 0.5, so the sample lands in the left leaf.
        let prediction = -1.0;
        let shap_sum: f64 = shap.values.iter().sum();
        let reconstructed = shap.base_value + shap_sum;
        assert!(
            (reconstructed - prediction).abs() < 1e-8,
            "SHAP invariant violated: base({}) + sum({}) = {} != prediction({})",
            shap.base_value,
            shap_sum,
            reconstructed,
            prediction
        );

        // Feature 1 never appears in a split.
        assert!(
            shap.values[1].abs() < 1e-10,
            "non-split feature SHAP should be 0, got {}",
            shap.values[1]
        );
    }

    #[test]
    fn shap_invariant_right_path() {
        let (arena, root) = stump();
        let features = [0.7, 5.0];

        let shap = tree_shap_values(&arena, root, &features, 2);

        // 0.7 > 0.5, so the sample lands in the right leaf.
        let prediction = 1.0;
        let reconstructed = shap.base_value + shap.values.iter().sum::<f64>();
        assert!(
            (reconstructed - prediction).abs() < 1e-8,
            "SHAP invariant violated for right path: {} != {}",
            reconstructed,
            prediction
        );
    }

    #[test]
    fn empty_tree_returns_zeros() {
        let arena = TreeArena::new();

        let shap = tree_shap_values(&arena, NodeId::NONE, &[1.0], 1);

        assert_eq!(shap.base_value, 0.0);
        assert_eq!(shap.values.len(), 1);
        assert_eq!(shap.values[0], 0.0);
    }

    #[test]
    fn ensemble_shap_integration() {
        use crate::ensemble::config::SGBTConfig;
        use crate::ensemble::SGBT;

        let config = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap();
        let mut model = SGBT::new(config);

        // Deterministic LCG so the training stream is reproducible.
        let mut rng: u64 = 42;
        let mut next_unit = || {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            (rng >> 33) as f64 / (u32::MAX as f64)
        };
        for _ in 0..200 {
            let x0 = next_unit();
            let x1 = next_unit();
            let y = 2.0 * x0 + 0.5 * x1;
            model.train_one(&(&[x0, x1][..], y));
        }

        let features = [0.5, 0.5];
        let shap = ensemble_shap(&model, &features);
        let prediction = model.predict(&features);

        let reconstructed = shap.base_value + shap.values.iter().sum::<f64>();
        assert!(
            (reconstructed - prediction).abs() < 0.1,
            "ensemble SHAP invariant violated: {} != {} (diff={})",
            reconstructed,
            prediction,
            (reconstructed - prediction).abs()
        );
    }
}