irithyll/explain/
streaming.rs1use crate::explain::treeshap::ShapValues;
4
5#[derive(Debug, Clone)]
21pub struct StreamingShap {
22 running_mean: Vec<f64>,
23 count: u64,
24}
25
26impl StreamingShap {
27 pub fn new(n_features: usize) -> Self {
29 Self {
30 running_mean: vec![0.0; n_features],
31 count: 0,
32 }
33 }
34
35 pub fn update(&mut self, shap: &ShapValues) {
37 self.count += 1;
38 let n = self.count as f64;
39 for (i, &v) in shap.values.iter().enumerate() {
40 if i < self.running_mean.len() {
41 self.running_mean[i] += (v.abs() - self.running_mean[i]) / n;
42 }
43 }
44 }
45
46 pub fn importances(&self) -> &[f64] {
48 &self.running_mean
49 }
50
51 pub fn count(&self) -> u64 {
53 self.count
54 }
55
56 pub fn reset(&mut self) {
58 self.running_mean.iter_mut().for_each(|v| *v = 0.0);
59 self.count = 0;
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing running means.
    const EPS: f64 = 1e-10;

    /// Builds a `ShapValues` with the given per-feature values and a zero base value.
    fn make_shap(values: &[f64]) -> ShapValues {
        ShapValues {
            values: values.to_vec(),
            base_value: 0.0,
        }
    }

    /// Asserts that `actual` has the same length as `expected` and matches it
    /// element-wise within `EPS`.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "feature count differs");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < EPS,
                "feature {}: expected {}, got {}",
                i,
                e,
                a
            );
        }
    }

    #[test]
    fn streaming_shap_new_is_zeroed() {
        let tracker = StreamingShap::new(3);
        assert_eq!(tracker.count(), 0);
        assert_eq!(tracker.importances(), &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn streaming_shap_basic() {
        let mut tracker = StreamingShap::new(2);

        // Negative SHAP values contribute their magnitude.
        tracker.update(&make_shap(&[1.0, -2.0]));
        assert_eq!(tracker.count(), 1);
        assert_close(tracker.importances(), &[1.0, 2.0]);

        // Means of |values| over both updates: (1 + 3) / 2 and (2 + 4) / 2.
        tracker.update(&make_shap(&[3.0, -4.0]));
        assert_eq!(tracker.count(), 2);
        assert_close(tracker.importances(), &[2.0, 3.0]);
    }

    #[test]
    fn streaming_shap_reset() {
        let mut tracker = StreamingShap::new(2);
        tracker.update(&make_shap(&[5.0, -3.0]));
        tracker.reset();
        assert_eq!(tracker.count(), 0);
        assert_eq!(tracker.importances(), &[0.0, 0.0]);
    }

    #[test]
    fn streaming_shap_ignores_extra_values() {
        // Entries beyond the tracked feature count must not affect the means.
        let mut tracker = StreamingShap::new(2);
        tracker.update(&make_shap(&[1.0, -2.0, 100.0]));
        assert_eq!(tracker.count(), 1);
        assert_close(tracker.importances(), &[1.0, 2.0]);
    }

    #[test]
    fn streaming_shap_short_update_keeps_previous_mean() {
        let mut tracker = StreamingShap::new(2);
        tracker.update(&make_shap(&[2.0, -4.0]));
        // Feature 1 has no value in this update. The counter still advances,
        // but feature 1's mean stays at 4.0 while feature 0 moves to (2 + 6) / 2.
        tracker.update(&make_shap(&[-6.0]));
        assert_eq!(tracker.count(), 2);
        assert_close(tracker.importances(), &[4.0, 4.0]);
    }

    #[test]
    fn streaming_shap_zero_features() {
        let mut tracker = StreamingShap::new(0);
        tracker.update(&make_shap(&[1.0]));
        assert_eq!(tracker.count(), 1);
        assert!(tracker.importances().is_empty());
    }
}