anomstream_core/
attribution_stability.rs1use alloc::vec;
29use alloc::vec::Vec;
30
31#[cfg(not(feature = "std"))]
32#[allow(unused_imports)]
33use num_traits::Float;
34
35use crate::domain::DiVector;
36use crate::domain::point::ensure_finite;
37use crate::error::{RcfError, RcfResult};
38use crate::forest::RandomCutForest;
39use crate::thresholded::ThresholdedForest;
40use crate::visitor::AttributionVisitor;
41
42#[derive(Debug, Clone)]
45#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
46pub struct AttributionStability {
47 mean: DiVector,
51 variance: Vec<f64>,
54 stddev: Vec<f64>,
58 tree_count: usize,
62}
63
64impl AttributionStability {
65 #[must_use]
67 pub fn mean(&self) -> &DiVector {
68 &self.mean
69 }
70
71 #[must_use]
73 pub fn variance(&self) -> &[f64] {
74 &self.variance
75 }
76
77 #[must_use]
79 pub fn stddev(&self) -> &[f64] {
80 &self.stddev
81 }
82
83 #[must_use]
85 pub fn tree_count(&self) -> usize {
86 self.tree_count
87 }
88
89 #[must_use]
91 pub fn dim(&self) -> usize {
92 self.mean.dim()
93 }
94
95 #[must_use]
103 pub fn coefficient_of_variation(&self, d: usize) -> f64 {
104 let mean_abs = self.mean.per_dim_total(d).abs();
105 if mean_abs < f64::EPSILON {
106 return 0.0;
107 }
108 self.stddev[d] / mean_abs
109 }
110
111 #[must_use]
120 pub fn confidence(&self, d: usize) -> f64 {
121 1.0 / (1.0 + self.coefficient_of_variation(d))
122 }
123
124 #[must_use]
128 pub fn argmax_mean(&self) -> Option<usize> {
129 self.mean.argmax()
130 }
131
132 #[must_use]
135 pub fn argmax_weighted(&self) -> Option<usize> {
136 if self.dim() == 0 {
137 return None;
138 }
139 let mut best: usize = 0;
140 let mut best_val = self.mean.per_dim_total(0) * self.confidence(0);
141 for d in 1..self.dim() {
142 let v = self.mean.per_dim_total(d) * self.confidence(d);
143 if v > best_val {
144 best = d;
145 best_val = v;
146 }
147 }
148 Some(best)
149 }
150}
151
152fn collect_per_tree<const D: usize>(
156 forest: &RandomCutForest<D>,
157 point: &[f64; D],
158) -> RcfResult<Vec<DiVector>> {
159 let mut out = Vec::with_capacity(forest.num_trees());
160 for (tree, _, _) in forest.trees() {
161 let Some(root) = tree.root() else {
162 continue;
163 };
164 let mass = tree.store().view(root)?.mass();
165 let visitor = AttributionVisitor::new(point, mass)?;
166 let di = tree.traverse(point, visitor)?;
167 out.push(di);
168 }
169 Ok(out)
170}
171
172#[allow(clippy::cast_precision_loss)] fn stability_from_collection<const D: usize>(
176 per_tree: &[DiVector],
177) -> RcfResult<AttributionStability> {
178 if per_tree.is_empty() {
179 return Err(RcfError::EmptyForest);
180 }
181 let tree_count = per_tree.len();
182 let divisor = tree_count as f64;
183
184 let mut mean = DiVector::zeros(D);
185 for di in per_tree {
186 mean.accumulate(di)?;
187 }
188 mean.scale(divisor)?;
189
190 let mut variance = vec![0.0_f64; D];
191 for di in per_tree {
192 for (d, var_d) in variance.iter_mut().enumerate().take(D) {
193 let delta = di.per_dim_total(d) - mean.per_dim_total(d);
194 *var_d += delta * delta;
195 }
196 }
197 for v in &mut variance {
198 *v /= divisor;
199 }
200 let stddev: Vec<f64> = variance.iter().map(|v| v.sqrt()).collect();
201
202 Ok(AttributionStability {
203 mean,
204 variance,
205 stddev,
206 tree_count,
207 })
208}
209
210impl<const D: usize> RandomCutForest<D> {
211 pub fn attribution_stability(&self, point: &[f64; D]) -> RcfResult<AttributionStability> {
226 ensure_finite(point)?;
227 let scaled = self.scale_point_copy(point);
231 let per_tree = collect_per_tree(self, &scaled)?;
232 stability_from_collection::<D>(&per_tree)
233 }
234}
235
236impl<const D: usize> ThresholdedForest<D> {
237 pub fn attribution_stability(&self, point: &[f64; D]) -> RcfResult<AttributionStability> {
245 self.forest().attribution_stability(point)
246 }
247}
248
#[cfg(feature = "std")]
impl<K, const D: usize> crate::pool::TenantForestPool<K, D>
where
    K: core::hash::Hash + Eq + Clone,
{
    /// Attribution stability of `point` for the tenant identified by `key`.
    ///
    /// If `contains(key)` is false, the tenant is first brought into the pool
    /// via `score_only(key, point)`, and that call's return value is
    /// discarded. Note that `score_only` is therefore only invoked for
    /// non-resident tenants. The tenant's detector, obtained with
    /// `get_mut(key)`, then answers through its own `attribution_stability`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `score_only` and from the detector's
    /// `attribution_stability`.
    ///
    /// # Panics
    ///
    /// Panics if `get_mut(key)` returns `None` after the residency step. This
    /// relies on `score_only` leaving the tenant in the pool.
    pub fn attribution_stability(
        &mut self,
        key: &K,
        point: &[f64; D],
    ) -> RcfResult<AttributionStability> {
        let resident = self.contains(key);
        if !resident {
            self.score_only(key, point)?;
        }
        self.get_mut(key)
            .expect("tenant was just forced into the pool")
            .attribution_stability(point)
    }
}
282
#[cfg(test)]
// Exact float equality is deliberate where the 0.0 sentinel is pinned.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::ForestBuilder;

    /// A 50-tree forest (sample size 32, seed 2026) trained on 256 points on
    /// the line `y = x + 0.5`, with `x` running from 0.00 to 2.55 in 0.01 steps.
    fn trained() -> RandomCutForest<2> {
        let mut forest = ForestBuilder::<2>::new()
            .num_trees(50)
            .sample_size(32)
            .seed(2026)
            .build()
            .unwrap();
        for x in (0_u32..256).map(|step| f64::from(step) * 0.01) {
            forest.update([x, x + 0.5]).unwrap();
        }
        forest
    }

    #[test]
    fn empty_forest_errors() {
        let untrained = ForestBuilder::<2>::new().seed(1).build().unwrap();
        let outcome = untrained.attribution_stability(&[0.0, 0.0]);
        assert!(matches!(outcome, Err(RcfError::EmptyForest)));
    }

    #[test]
    fn non_finite_point_rejected() {
        let outcome = trained().attribution_stability(&[f64::NAN, 0.0]);
        assert!(matches!(outcome, Err(RcfError::NaNValue)));
    }

    #[test]
    fn tree_count_matches_forest_size_on_trained_forest() {
        let stability = trained().attribution_stability(&[5.0, 5.0]).unwrap();
        assert_eq!(stability.tree_count(), 50);
        assert_eq!(stability.dim(), 2);
    }

    #[test]
    fn mean_matches_plain_attribution() {
        let forest = trained();
        let probe = [5.0_f64, 5.0];
        let reference = forest.attribution(&probe).unwrap();
        let stability = forest.attribution_stability(&probe).unwrap();
        let averaged = stability.mean();
        for d in 0..2 {
            // The averaged per-tree vectors must reproduce the forest-level
            // attribution up to floating-point rounding.
            let delta = (reference.per_dim_total(d) - averaged.per_dim_total(d)).abs();
            assert!(delta < 1e-10, "dim {d} drift {delta}");
        }
    }

    #[test]
    fn variance_is_non_negative_per_dim() {
        let stability = trained().attribution_stability(&[5.0_f64, 5.0]).unwrap();
        assert!(stability.variance().iter().all(|&var| var >= 0.0));
        assert!(stability.stddev().iter().all(|&sd| sd >= 0.0));
    }

    #[test]
    fn stddev_is_sqrt_of_variance() {
        let stability = trained().attribution_stability(&[5.0_f64, 5.0]).unwrap();
        let pairs = stability.stddev().iter().zip(stability.variance());
        for (&sd, &var) in pairs {
            assert!((sd - var.sqrt()).abs() < 1e-12);
        }
    }

    #[test]
    fn confidence_is_one_when_variance_zero() {
        let mut mean = DiVector::zeros(3);
        mean.add_high(0, 1.0).unwrap();
        mean.add_low(1, 2.0).unwrap();
        let stability = AttributionStability {
            mean,
            variance: vec![0.0; 3],
            stddev: vec![0.0; 3],
            tree_count: 10,
        };
        // Zero spread gives cv = 0, hence a confidence of exactly 1.
        for d in 0..2 {
            assert!((stability.confidence(d) - 1.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn confidence_drops_monotonically_with_cv() {
        let mut mean = DiVector::zeros(2);
        for d in 0..2 {
            mean.add_high(d, 1.0).unwrap();
        }
        // Both dims share the same mean; dim 1 has five times the stddev and
        // hence the larger cv.
        let stability = AttributionStability {
            mean,
            variance: vec![0.01_f64, 0.25],
            stddev: vec![0.1_f64, 0.5],
            tree_count: 10,
        };
        assert!(stability.confidence(0) > stability.confidence(1));
    }

    #[test]
    fn coefficient_of_variation_is_zero_when_mean_zero() {
        let stability = AttributionStability {
            mean: DiVector::zeros(1),
            variance: vec![1.0],
            stddev: vec![1.0],
            tree_count: 4,
        };
        assert_eq!(stability.coefficient_of_variation(0), 0.0);
        // A zero cv maps to full confidence.
        assert!((stability.confidence(0) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn argmax_weighted_prefers_stable_dim_over_unstable() {
        // Dim 0 has the larger mean but a huge spread (weight 10 / 4 = 2.5);
        // dim 1 is smaller but nearly constant (weight 5 / 1.02 ~ 4.9).
        let mut mean = DiVector::zeros(2);
        mean.add_high(0, 10.0).unwrap();
        mean.add_high(1, 5.0).unwrap();
        let stability = AttributionStability {
            mean,
            variance: vec![900.0, 0.01],
            stddev: vec![30.0, 0.1],
            tree_count: 10,
        };
        assert_eq!(stability.argmax_mean(), Some(0));
        assert_eq!(stability.argmax_weighted(), Some(1));
    }

    #[test]
    fn argmax_weighted_empty_returns_none() {
        let stability = AttributionStability {
            mean: DiVector::zeros(0),
            variance: vec![],
            stddev: vec![],
            tree_count: 0,
        };
        assert!(stability.argmax_weighted().is_none());
        assert!(stability.argmax_mean().is_none());
    }

    #[test]
    fn stability_from_collection_rejects_empty() {
        let outcome = stability_from_collection::<2>(&[]);
        assert!(matches!(outcome, Err(RcfError::EmptyForest)));
    }
}