Skip to main content

anomstream_core/
attribution_stability.rs

1//! Inter-tree dispersion of the per-dim attribution vector.
2//!
3//! [`crate::RandomCutForest::attribution`] already returns the mean
4//! [`DiVector`] across every tree in the forest — but the mean hides
5//! how *unanimously* the trees agreed on that answer. An attribution
6//! of `high[4] + low[4] = 10` where every tree saw dim 4 as the top
7//! contributor is very different from the same mean synthesised from
8//! a handful of trees strongly flagging dim 4 and the rest flagging
9//! different dims. The second case is a lucky coincidence, not a
10//! stable signal.
11//!
12//! [`AttributionStability`] exposes the mean *and* the per-dim
13//! variance / stddev across trees, derives a coefficient of
14//! variation and a bounded `confidence ∈ [0, 1]` per dim, and offers
15//! two ways to pick the driver dimension:
16//!
17//! - [`AttributionStability::argmax_mean`] — classic
18//!   [`DiVector::argmax`] behaviour, ignores disagreement.
19//! - [`AttributionStability::argmax_weighted`] — picks the dim that
20//!   maximises `mean × confidence`; downranks dims where the trees
21//!   disagree. Safer for SOC-facing alerts.
22//!
23//! The helper runs one attribution visitor pass per tree, collects
24//! the per-tree [`DiVector`]s into a `Vec`, then computes mean and
25//! variance in two sweeps. For an AWS-default forest (100 trees, D=16)
26//! the extra allocation is ~26 KB.
27
28use alloc::vec;
29use alloc::vec::Vec;
30
31#[cfg(not(feature = "std"))]
32#[allow(unused_imports)]
33use num_traits::Float;
34
35use crate::domain::DiVector;
36use crate::domain::point::ensure_finite;
37use crate::error::{RcfError, RcfResult};
38use crate::forest::RandomCutForest;
39use crate::thresholded::ThresholdedForest;
40use crate::visitor::AttributionVisitor;
41
/// Inter-tree dispersion of the attribution vector, paired with the
/// mean.
///
/// Dispersion is measured on each dimension's total contribution
/// (`high[d] + low[d]`, i.e. `DiVector::per_dim_total`), with one sample
/// per tree that contributed an attribution.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AttributionStability {
    /// Mean per-dim contribution across every tree that produced an
    /// attribution. Matches what [`crate::RandomCutForest::attribution`]
    /// returns, up to floating-point summation order (a parallel
    /// reduction may differ in the last ULP).
    mean: DiVector,
    /// Per-dim population variance of the per-tree totals (not the
    /// unbiased estimator — divides by `tree_count`, not
    /// `tree_count − 1`).
    variance: Vec<f64>,
    /// Per-dim standard deviation (`sqrt(variance)`), cached so
    /// callers that query [`Self::confidence`] in a hot loop do not
    /// re-square-root on every call.
    stddev: Vec<f64>,
    /// Number of trees that actually contributed — trees with an
    /// empty reservoir (no root) are skipped so `tree_count` may be
    /// less than the forest's configured `num_trees`.
    tree_count: usize,
}
63
64impl AttributionStability {
65    /// Mean attribution across trees.
66    #[must_use]
67    pub fn mean(&self) -> &DiVector {
68        &self.mean
69    }
70
71    /// Per-dim population variance of the contributions.
72    #[must_use]
73    pub fn variance(&self) -> &[f64] {
74        &self.variance
75    }
76
77    /// Per-dim standard deviation of the contributions.
78    #[must_use]
79    pub fn stddev(&self) -> &[f64] {
80        &self.stddev
81    }
82
83    /// Number of trees that contributed an attribution.
84    #[must_use]
85    pub fn tree_count(&self) -> usize {
86        self.tree_count
87    }
88
89    /// Per-point dimensionality.
90    #[must_use]
91    pub fn dim(&self) -> usize {
92        self.mean.dim()
93    }
94
95    /// Coefficient of variation for dim `d` — `stddev[d] / |mean[d]|`.
96    /// Returns `0.0` when `|mean[d]| < f64::EPSILON` (no dispersion is
97    /// observable when nothing was attributed).
98    ///
99    /// # Panics
100    ///
101    /// Panics when `d >= self.dim()` — callers size-check first.
102    #[must_use]
103    pub fn coefficient_of_variation(&self, d: usize) -> f64 {
104        let mean_abs = self.mean.per_dim_total(d).abs();
105        if mean_abs < f64::EPSILON {
106            return 0.0;
107        }
108        self.stddev[d] / mean_abs
109    }
110
111    /// Bounded `[0, 1]` confidence that dim `d`'s mean contribution
112    /// is a stable signal rather than a handful of trees agreeing by
113    /// chance. Derived as `1 / (1 + CV)` — `1.0` for perfect
114    /// agreement, falling monotonically as CV rises.
115    ///
116    /// # Panics
117    ///
118    /// Panics when `d >= self.dim()` — callers size-check first.
119    #[must_use]
120    pub fn confidence(&self, d: usize) -> f64 {
121        1.0 / (1.0 + self.coefficient_of_variation(d))
122    }
123
124    /// Classic [`DiVector::argmax`] — dim with the largest mean
125    /// contribution, independent of stability. Returns `None` on an
126    /// empty attribution vector.
127    #[must_use]
128    pub fn argmax_mean(&self) -> Option<usize> {
129        self.mean.argmax()
130    }
131
132    /// Dim maximising `mean × confidence`. Downranks dims where the
133    /// trees disagreed. Returns `None` on an empty attribution vector.
134    #[must_use]
135    pub fn argmax_weighted(&self) -> Option<usize> {
136        if self.dim() == 0 {
137            return None;
138        }
139        let mut best: usize = 0;
140        let mut best_val = self.mean.per_dim_total(0) * self.confidence(0);
141        for d in 1..self.dim() {
142            let v = self.mean.per_dim_total(d) * self.confidence(d);
143            if v > best_val {
144                best = d;
145                best_val = v;
146            }
147        }
148        Some(best)
149    }
150}
151
152/// Collect the per-tree attribution for `point`. Skips trees whose
153/// reservoir is still empty (identical to
154/// [`crate::RandomCutForest::attribution`]).
155fn collect_per_tree<const D: usize>(
156    forest: &RandomCutForest<D>,
157    point: &[f64; D],
158) -> RcfResult<Vec<DiVector>> {
159    let mut out = Vec::with_capacity(forest.num_trees());
160    for (tree, _, _) in forest.trees() {
161        let Some(root) = tree.root() else {
162            continue;
163        };
164        let mass = tree.store().view(root)?.mass();
165        let visitor = AttributionVisitor::new(point, mass)?;
166        let di = tree.traverse(point, visitor)?;
167        out.push(di);
168    }
169    Ok(out)
170}
171
172/// Compute [`AttributionStability`] from a collected per-tree set.
173/// Shared by the forest- and pool-level entry points.
174#[allow(clippy::cast_precision_loss)] // Tree counts are bounded by num_trees <= 1000.
175fn stability_from_collection<const D: usize>(
176    per_tree: &[DiVector],
177) -> RcfResult<AttributionStability> {
178    if per_tree.is_empty() {
179        return Err(RcfError::EmptyForest);
180    }
181    let tree_count = per_tree.len();
182    let divisor = tree_count as f64;
183
184    let mut mean = DiVector::zeros(D);
185    for di in per_tree {
186        mean.accumulate(di)?;
187    }
188    mean.scale(divisor)?;
189
190    let mut variance = vec![0.0_f64; D];
191    for di in per_tree {
192        for (d, var_d) in variance.iter_mut().enumerate().take(D) {
193            let delta = di.per_dim_total(d) - mean.per_dim_total(d);
194            *var_d += delta * delta;
195        }
196    }
197    for v in &mut variance {
198        *v /= divisor;
199    }
200    let stddev: Vec<f64> = variance.iter().map(|v| v.sqrt()).collect();
201
202    Ok(AttributionStability {
203        mean,
204        variance,
205        stddev,
206        tree_count,
207    })
208}
209
210impl<const D: usize> RandomCutForest<D> {
211    /// Inter-tree dispersion of the attribution vector on `point`.
212    ///
213    /// Returns the mean contribution per dim plus the per-dim
214    /// variance and stddev across trees — use
215    /// [`AttributionStability::confidence`] or
216    /// [`AttributionStability::argmax_weighted`] to pick a driver
217    /// dim that downranks tree-level disagreement.
218    ///
219    /// # Errors
220    ///
221    /// - [`RcfError::NaNValue`] when the point contains a non-finite
222    ///   component.
223    /// - [`RcfError::EmptyForest`] when no tree holds any leaf.
224    /// - Any error bubbled up from the per-tree attribution path.
225    pub fn attribution_stability(&self, point: &[f64; D]) -> RcfResult<AttributionStability> {
226        ensure_finite(point)?;
227        // Keep parity with `attribution()` — stored points are in the
228        // forest's scaled space, so the caller query must be scaled
229        // before walking the tree cuts.
230        let scaled = self.scale_point_copy(point);
231        let per_tree = collect_per_tree(self, &scaled)?;
232        stability_from_collection::<D>(&per_tree)
233    }
234}
235
236impl<const D: usize> ThresholdedForest<D> {
237    /// Inter-tree dispersion of the attribution on `point`. Delegates
238    /// to the underlying forest — the threshold layer does not
239    /// influence attribution.
240    ///
241    /// # Errors
242    ///
243    /// Same as [`RandomCutForest::attribution_stability`].
244    pub fn attribution_stability(&self, point: &[f64; D]) -> RcfResult<AttributionStability> {
245        self.forest().attribution_stability(point)
246    }
247}
248
#[cfg(feature = "std")]
impl<K, const D: usize> crate::pool::TenantForestPool<K, D>
where
    K: core::hash::Hash + Eq + Clone,
{
    /// Per-tenant attribution stability. Lazily instantiates the
    /// tenant (like [`Self::process`]).
    ///
    /// If `key` is not yet in the pool, it is first forced in through
    /// [`Self::score_only`]. The resulting score is discarded.
    ///
    /// # Errors
    ///
    /// Same as [`ThresholdedForest::attribution_stability`], plus any
    /// error (including factory errors) returned by [`Self::score_only`]
    /// while instantiating the tenant.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::get_mut`] cannot find the tenant after
    /// [`Self::contains`] reported it present or [`Self::score_only`]
    /// returned `Ok`. That would be an internal pool invariant violation,
    /// not a caller error.
    pub fn attribution_stability(
        &mut self,
        key: &K,
        point: &[f64; D],
    ) -> RcfResult<AttributionStability> {
        let known = self.contains(key);
        if !known {
            // Force a slot for the tenant; only the side effect matters.
            self.score_only(key, point)?;
        }
        self.get_mut(key)
            .expect("tenant was just forced into the pool")
            .attribution_stability(point)
    }
}
282
#[cfg(test)]
#[allow(clippy::float_cmp)] // Tests assert bounds on closed-form quantities.
mod tests {
    // Forest-level tests go through a seeded, trained forest. The
    // accessor/argmax tests build `AttributionStability` directly from
    // hand-picked statistics so expected values are closed-form.
    use super::*;
    use crate::ForestBuilder;

    // Seeded forest (50 trees, sample size 32) trained on 256 points
    // `[x, x + 0.5]` with `x = 0.00, 0.01, …, 2.55`. The `[5.0, 5.0]`
    // probe used below lies outside that range in both dims.
    fn trained() -> RandomCutForest<2> {
        let mut f = ForestBuilder::<2>::new()
            .num_trees(50)
            .sample_size(32)
            .seed(2026)
            .build()
            .unwrap();
        for i in 0_u32..256 {
            let v = f64::from(i) * 0.01;
            f.update([v, v + 0.5]).unwrap();
        }
        f
    }

    // An untrained forest yields no per-tree attribution; that must
    // surface as `EmptyForest` rather than as a zero vector.
    #[test]
    fn empty_forest_errors() {
        let f = ForestBuilder::<2>::new().seed(1).build().unwrap();
        let err = f.attribution_stability(&[0.0, 0.0]).unwrap_err();
        assert!(matches!(err, RcfError::EmptyForest));
    }

    // Non-finite input is rejected up front, before any tree is walked.
    #[test]
    fn non_finite_point_rejected() {
        let f = trained();
        let err = f.attribution_stability(&[f64::NAN, 0.0]).unwrap_err();
        assert!(matches!(err, RcfError::NaNValue));
    }

    // After training, no tree is skipped, so every configured tree counts.
    #[test]
    fn tree_count_matches_forest_size_on_trained_forest() {
        let f = trained();
        let s = f.attribution_stability(&[5.0, 5.0]).unwrap();
        assert_eq!(s.tree_count(), 50);
        assert_eq!(s.dim(), 2);
    }

    #[test]
    fn mean_matches_plain_attribution() {
        let f = trained();
        let probe = [5.0_f64, 5.0];
        let plain = f.attribution(&probe).unwrap();
        let s = f.attribution_stability(&probe).unwrap();
        // Under the `parallel` feature, `attribution()` uses rayon's
        // reorder-safe fold/reduce, which can differ from this
        // helper's serial sum in the last ULP. 1e-10 is orders of
        // magnitude below any observable signal.
        for d in 0..2 {
            let delta = (plain.per_dim_total(d) - s.mean().per_dim_total(d)).abs();
            assert!(delta < 1e-10, "dim {d} drift {delta}");
        }
    }

    // A population variance (and therefore its square root) can never be
    // negative.
    #[test]
    fn variance_is_non_negative_per_dim() {
        let f = trained();
        let s = f.attribution_stability(&[5.0_f64, 5.0]).unwrap();
        for v in s.variance() {
            assert!(*v >= 0.0);
        }
        for sd in s.stddev() {
            assert!(*sd >= 0.0);
        }
    }

    // `stddev` is the cached square root of `variance`.
    #[test]
    fn stddev_is_sqrt_of_variance() {
        let f = trained();
        let s = f.attribution_stability(&[5.0_f64, 5.0]).unwrap();
        for d in 0..s.dim() {
            assert!((s.stddev()[d] - s.variance()[d].sqrt()).abs() < 1e-12);
        }
    }

    #[test]
    fn confidence_is_one_when_variance_zero() {
        // Every tree attributes exactly the same contribution → CV=0 → conf=1.
        // Dim 2 stays at zero and is not asserted.
        let mut mean = DiVector::zeros(3);
        mean.add_high(0, 1.0).unwrap();
        mean.add_low(1, 2.0).unwrap();
        let s = AttributionStability {
            mean,
            variance: vec![0.0, 0.0, 0.0],
            stddev: vec![0.0, 0.0, 0.0],
            tree_count: 10,
        };
        assert!((s.confidence(0) - 1.0).abs() < f64::EPSILON);
        assert!((s.confidence(1) - 1.0).abs() < f64::EPSILON);
    }

    // Equal means, so the dim with the larger stddev has the larger CV
    // and must receive the lower confidence.
    #[test]
    fn confidence_drops_monotonically_with_cv() {
        let mut mean = DiVector::zeros(2);
        mean.add_high(0, 1.0).unwrap();
        mean.add_high(1, 1.0).unwrap();
        let stable = AttributionStability {
            mean: mean.clone(),
            variance: vec![0.01_f64, 0.25],
            stddev: vec![0.1_f64, 0.5],
            tree_count: 10,
        };
        assert!(stable.confidence(0) > stable.confidence(1));
    }

    #[test]
    fn coefficient_of_variation_is_zero_when_mean_zero() {
        let mean = DiVector::zeros(1);
        let s = AttributionStability {
            mean,
            variance: vec![1.0],
            stddev: vec![1.0],
            tree_count: 4,
        };
        // mean[0] == 0 → CV undefined → clamp to 0.
        assert_eq!(s.coefficient_of_variation(0), 0.0);
        assert!((s.confidence(0) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn argmax_weighted_prefers_stable_dim_over_unstable() {
        // mean[0] = 10 but very unstable (stddev=30 → CV=3 → conf ~ 0.25)
        // mean[1] = 5 but stable (stddev=0.1 → CV=0.02 → conf ~ 0.98)
        // weighted: 10 * 0.25 = 2.5  vs  5 * 0.98 = 4.9 → pick 1.
        let mut mean = DiVector::zeros(2);
        mean.add_high(0, 10.0).unwrap();
        mean.add_high(1, 5.0).unwrap();
        let s = AttributionStability {
            mean,
            variance: vec![900.0, 0.01],
            stddev: vec![30.0, 0.1],
            tree_count: 10,
        };
        assert_eq!(s.argmax_mean(), Some(0));
        assert_eq!(s.argmax_weighted(), Some(1));
    }

    // A zero-dimensional result has no driver dim under either strategy.
    #[test]
    fn argmax_weighted_empty_returns_none() {
        let s = AttributionStability {
            mean: DiVector::zeros(0),
            variance: vec![],
            stddev: vec![],
            tree_count: 0,
        };
        assert!(s.argmax_weighted().is_none());
        assert!(s.argmax_mean().is_none());
    }

    // The shared helper guards against an empty collection by itself,
    // independent of the forest-level entry point.
    #[test]
    fn stability_from_collection_rejects_empty() {
        let err = stability_from_collection::<2>(&[]).unwrap_err();
        assert!(matches!(err, RcfError::EmptyForest));
    }
}