Skip to main content

anomstream_core/domain/
divector.rs

1//! Per-dimension attribution vector.
2//!
3//! [`DiVector`] tracks two `f64` accumulators per dimension: `high[d]`
4//! holds the contribution to the score of cuts where the queried point
5//! lies *above* the cut, `low[d]` of cuts where the point lies *below*
6//! the cut. Summing `high[d] + low[d]` per dimension yields the total
7//! per-feature contribution; the largest component identifies the most
8//! anomalous dimension. The shape mirrors the AWS RCF reference and
9//! Guha et al. (2016) attribution algorithm.
10
11use alloc::format;
12use alloc::vec;
13use alloc::vec::Vec;
14
15use crate::error::{RcfError, RcfResult};
16
/// Two-sided per-dimension attribution accumulator.
///
/// Holds two parallel `Vec<f64>` buffers indexed by dimension. The
/// constructors in this file ([`DiVector::zeros`] and
/// [`DiVector::from_arrays`]) always produce buffers of equal length,
/// and the fields are private so callers can only mutate them through
/// the methods on this type.
///
/// NOTE(review): with the `serde` feature the derived `Deserialize`
/// fills the fields directly and does not go through `from_arrays`, so
/// the equal-length property is presumably not enforced on
/// deserialized values — confirm against the serialization call sites.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DiVector {
    /// Contribution from cuts where the point lies above the cut,
    /// one entry per dimension.
    high: Vec<f64>,
    /// Contribution from cuts where the point lies below the cut,
    /// one entry per dimension.
    low: Vec<f64>,
}
26
27impl DiVector {
28    /// Build a zeroed `DiVector` with `dim` dimensions.
29    ///
30    /// # Examples
31    ///
32    /// ```
33    /// use anomstream_core::domain::DiVector;
34    /// let v = DiVector::zeros(4);
35    /// assert_eq!(v.dim(), 4);
36    /// assert_eq!(v.total(), 0.0);
37    /// ```
38    #[must_use]
39    pub fn zeros(dim: usize) -> Self {
40        Self {
41            high: vec![0.0; dim],
42            low: vec![0.0; dim],
43        }
44    }
45
46    /// Build a `DiVector` from explicit `high` / `low` vectors.
47    /// Useful for tests and downstream code that wants to pipe a
48    /// ready-made attribution back into a triage helper (e.g.
49    /// `anomstream_triage::AlertClusterer`).
50    ///
51    /// # Errors
52    ///
53    /// Returns [`RcfError::DimensionMismatch`] when
54    /// `high.len() != low.len()`.
55    pub fn from_arrays(high: Vec<f64>, low: Vec<f64>) -> RcfResult<Self> {
56        if high.len() != low.len() {
57            return Err(RcfError::DimensionMismatch {
58                expected: high.len(),
59                got: low.len(),
60            });
61        }
62        Ok(Self { high, low })
63    }
64
65    /// Dimensionality.
66    #[must_use]
67    pub fn dim(&self) -> usize {
68        self.high.len()
69    }
70
71    /// Read-only view of the upper-side contributions.
72    #[must_use]
73    pub fn high(&self) -> &[f64] {
74        &self.high
75    }
76
77    /// Read-only view of the lower-side contributions.
78    #[must_use]
79    pub fn low(&self) -> &[f64] {
80        &self.low
81    }
82
83    /// Sum of all `high[d]` and `low[d]` entries.
84    #[must_use]
85    pub fn total(&self) -> f64 {
86        self.high.iter().sum::<f64>() + self.low.iter().sum::<f64>()
87    }
88
89    /// Per-dimension total: `high[d] + low[d]`.
90    ///
91    /// # Panics
92    ///
93    /// Panics when `d >= self.dim()` — call sites size-check first.
94    #[must_use]
95    pub fn per_dim_total(&self, d: usize) -> f64 {
96        self.high[d] + self.low[d]
97    }
98
99    /// Index of the dimension with the largest `high[d] + low[d]`.
100    /// Returns `None` for an empty vector.
101    #[must_use]
102    pub fn argmax(&self) -> Option<usize> {
103        if self.dim() == 0 {
104            return None;
105        }
106        let mut best = 0_usize;
107        let mut best_val = self.per_dim_total(0);
108        for d in 1..self.dim() {
109            let v = self.per_dim_total(d);
110            if v > best_val {
111                best = d;
112                best_val = v;
113            }
114        }
115        Some(best)
116    }
117
118    /// Add `value` to the upper-side contribution for dimension `d`.
119    ///
120    /// # Errors
121    ///
122    /// Returns [`RcfError::OutOfBounds`] when `d >= self.dim()`.
123    pub fn add_high(&mut self, d: usize, value: f64) -> RcfResult<()> {
124        if d >= self.high.len() {
125            return Err(RcfError::OutOfBounds {
126                index: d,
127                len: self.high.len(),
128            });
129        }
130        self.high[d] += value;
131        Ok(())
132    }
133
134    /// Add `value` to the lower-side contribution for dimension `d`.
135    ///
136    /// # Errors
137    ///
138    /// Returns [`RcfError::OutOfBounds`] when `d >= self.dim()`.
139    pub fn add_low(&mut self, d: usize, value: f64) -> RcfResult<()> {
140        if d >= self.low.len() {
141            return Err(RcfError::OutOfBounds {
142                index: d,
143                len: self.low.len(),
144            });
145        }
146        self.low[d] += value;
147        Ok(())
148    }
149
150    /// Element-wise add `other` into `self`.
151    ///
152    /// # Errors
153    ///
154    /// Returns [`RcfError::DimensionMismatch`] when dimensions differ.
155    pub fn accumulate(&mut self, other: &Self) -> RcfResult<()> {
156        if other.dim() != self.dim() {
157            return Err(RcfError::DimensionMismatch {
158                expected: self.dim(),
159                got: other.dim(),
160            });
161        }
162        for d in 0..self.dim() {
163            self.high[d] += other.high[d];
164            self.low[d] += other.low[d];
165        }
166        Ok(())
167    }
168
169    /// Divide every component by `divisor` in place. Used by the
170    /// forest layer to convert a sum of per-tree attributions into a
171    /// mean.
172    ///
173    /// # Errors
174    ///
175    /// Returns [`RcfError::InvalidConfig`] when `divisor` is zero or
176    /// non-finite.
177    pub fn scale(&mut self, divisor: f64) -> RcfResult<()> {
178        if divisor == 0.0 || !divisor.is_finite() {
179            return Err(RcfError::InvalidConfig(
180                format!("DiVector::scale divisor must be non-zero and finite, got {divisor}")
181                    .into(),
182            ));
183        }
184        for d in 0..self.dim() {
185            self.high[d] /= divisor;
186            self.low[d] /= divisor;
187        }
188        Ok(())
189    }
190}
191
#[cfg(test)]
#[allow(clippy::float_cmp)] // Tests assert exact equality on integer-valued accumulations.
mod tests {
    use super::*;

    #[test]
    fn zeros_creates_dim_sized_vector() {
        let v = DiVector::zeros(5);
        assert_eq!(v.dim(), 5);
        assert_eq!(v.high(), &[0.0; 5]);
        assert_eq!(v.low(), &[0.0; 5]);
        assert_eq!(v.total(), 0.0);
    }

    #[test]
    fn from_arrays_accepts_matching_lengths() {
        let v = DiVector::from_arrays(vec![1.0, 2.0], vec![3.0, 4.0]).unwrap();
        assert_eq!(v.dim(), 2);
        assert_eq!(v.high(), &[1.0, 2.0]);
        assert_eq!(v.low(), &[3.0, 4.0]);
        assert_eq!(v.total(), 10.0);
        assert_eq!(v.per_dim_total(1), 6.0);
    }

    #[test]
    fn from_arrays_rejects_length_mismatch() {
        let err = DiVector::from_arrays(vec![1.0, 2.0], vec![3.0]).unwrap_err();
        // `expected` reports the `high` length, `got` the `low` length.
        assert!(matches!(
            err,
            RcfError::DimensionMismatch {
                expected: 2,
                got: 1
            }
        ));
    }

    #[test]
    fn add_high_and_low_accumulate() {
        let mut v = DiVector::zeros(3);
        v.add_high(0, 1.0).unwrap();
        v.add_high(0, 2.0).unwrap();
        v.add_low(2, 4.0).unwrap();
        assert_eq!(v.high(), &[3.0, 0.0, 0.0]);
        assert_eq!(v.low(), &[0.0, 0.0, 4.0]);
        assert_eq!(v.total(), 7.0);
        assert_eq!(v.per_dim_total(0), 3.0);
        assert_eq!(v.per_dim_total(2), 4.0);
    }

    #[test]
    fn add_high_oob() {
        let mut v = DiVector::zeros(2);
        let err = v.add_high(3, 1.0).unwrap_err();
        assert!(matches!(err, RcfError::OutOfBounds { index: 3, len: 2 }));
    }

    #[test]
    fn add_low_oob() {
        let mut v = DiVector::zeros(2);
        assert!(matches!(
            v.add_low(99, 1.0).unwrap_err(),
            RcfError::OutOfBounds { index: 99, len: 2 }
        ));
    }

    #[test]
    fn accumulate_sums_componentwise() {
        let mut a = DiVector::zeros(2);
        a.add_high(0, 1.0).unwrap();
        a.add_low(1, 2.0).unwrap();
        let mut b = DiVector::zeros(2);
        b.add_high(0, 4.0).unwrap();
        b.add_low(1, 8.0).unwrap();
        a.accumulate(&b).unwrap();
        assert_eq!(a.high(), &[5.0, 0.0]);
        assert_eq!(a.low(), &[0.0, 10.0]);
    }

    #[test]
    fn accumulate_rejects_dim_mismatch() {
        let mut a = DiVector::zeros(2);
        let b = DiVector::zeros(3);
        assert!(matches!(
            a.accumulate(&b).unwrap_err(),
            RcfError::DimensionMismatch {
                expected: 2,
                got: 3
            }
        ));
        // A rejected accumulate must not have modified the receiver.
        assert_eq!(a, DiVector::zeros(2));
    }

    #[test]
    fn scale_divides_componentwise() {
        let mut v = DiVector::zeros(2);
        v.add_high(0, 10.0).unwrap();
        v.add_low(1, 6.0).unwrap();
        v.scale(2.0).unwrap();
        assert_eq!(v.high(), &[5.0, 0.0]);
        assert_eq!(v.low(), &[0.0, 3.0]);
    }

    #[test]
    fn scale_rejects_zero() {
        let mut v = DiVector::zeros(1);
        assert!(matches!(
            v.scale(0.0).unwrap_err(),
            RcfError::InvalidConfig(_)
        ));
    }

    #[test]
    fn scale_rejects_nan_infinity() {
        let mut v = DiVector::zeros(1);
        assert!(v.scale(f64::NAN).is_err());
        assert!(v.scale(f64::INFINITY).is_err());
        assert!(v.scale(f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn argmax_picks_largest() {
        let mut v = DiVector::zeros(4);
        v.add_high(2, 5.0).unwrap();
        v.add_low(1, 1.0).unwrap();
        assert_eq!(v.argmax(), Some(2));
    }

    #[test]
    fn argmax_zero_dim_returns_none() {
        let v = DiVector::zeros(0);
        assert!(v.argmax().is_none());
    }

    #[test]
    fn argmax_ties_returns_first() {
        let mut v = DiVector::zeros(3);
        v.add_high(0, 5.0).unwrap();
        v.add_high(2, 5.0).unwrap();
        assert_eq!(v.argmax(), Some(0));
    }
}