Skip to main content

irithyll/explain/
streaming.rs

1//! Streaming SHAP — running mean of absolute SHAP values for online feature importance.
2
3use crate::explain::treeshap::ShapValues;
4
5/// Running mean of absolute SHAP values for online feature importance tracking.
6///
7/// Feed SHAP explanations as they are computed during training/prediction,
8/// and query the running importance scores at any time.
9///
10/// # Example
11///
12/// ```text
13/// let mut tracker = StreamingShap::new(n_features);
14/// for sample in stream {
15///     let shap = model.explain(&sample.features);
16///     tracker.update(&shap);
17/// }
18/// let importances = tracker.importances(); // running mean |SHAP|
19/// ```
20#[derive(Debug, Clone)]
21pub struct StreamingShap {
22    running_mean: Vec<f64>,
23    count: u64,
24}
25
26impl StreamingShap {
27    /// Create a new tracker for the given number of features.
28    pub fn new(n_features: usize) -> Self {
29        Self {
30            running_mean: vec![0.0; n_features],
31            count: 0,
32        }
33    }
34
35    /// Update with a new SHAP explanation, incrementing the running mean.
36    pub fn update(&mut self, shap: &ShapValues) {
37        self.count += 1;
38        let n = self.count as f64;
39        for (i, &v) in shap.values.iter().enumerate() {
40            if i < self.running_mean.len() {
41                self.running_mean[i] += (v.abs() - self.running_mean[i]) / n;
42            }
43        }
44    }
45
46    /// Current running mean of |SHAP| per feature.
47    pub fn importances(&self) -> &[f64] {
48        &self.running_mean
49    }
50
51    /// Number of explanations processed.
52    pub fn count(&self) -> u64 {
53        self.count
54    }
55
56    /// Reset to zero state.
57    pub fn reset(&mut self) {
58        self.running_mean.iter_mut().for_each(|v| *v = 0.0);
59        self.count = 0;
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing running means.
    const TOLERANCE: f64 = 1e-10;

    /// Build an explanation with the given per-feature values and a zero base value.
    fn explanation(values: Vec<f64>) -> ShapValues {
        ShapValues {
            values,
            base_value: 0.0,
        }
    }

    /// Assert that the leading importances match `expected` within `TOLERANCE`.
    fn assert_importances_close(tracker: &StreamingShap, expected: &[f64]) {
        for (feature, want) in expected.iter().enumerate() {
            assert!((tracker.importances()[feature] - want).abs() < TOLERANCE);
        }
    }

    #[test]
    fn streaming_shap_basic() {
        let mut tracker = StreamingShap::new(2);

        // A single sample: the mean equals the absolute values themselves.
        tracker.update(&explanation(vec![1.0, -2.0]));
        assert_eq!(tracker.count(), 1);
        assert_importances_close(&tracker, &[1.0, 2.0]);

        // Second sample: mean |SHAP| = (1+3)/2 = 2 for feature 0,
        // and (2+4)/2 = 3 for feature 1.
        tracker.update(&explanation(vec![3.0, -4.0]));
        assert_eq!(tracker.count(), 2);
        assert_importances_close(&tracker, &[2.0, 3.0]);
    }

    #[test]
    fn streaming_shap_reset() {
        let mut tracker = StreamingShap::new(2);
        tracker.update(&explanation(vec![5.0, -3.0]));

        tracker.reset();

        // After a reset both the counter and every mean are exactly zero.
        assert_eq!(tracker.count(), 0);
        for feature in 0..2 {
            assert_eq!(tracker.importances()[feature], 0.0);
        }
    }
}
101}