Skip to main content

anomstream_core/domain/
bounding_box.rs

1//! Axis-aligned bounding box for `D`-dimensional points.
2//!
3//! [`BoundingBox<D>`] is a value object: every mutating operation
4//! ([`extend`](BoundingBox::extend), [`merge_with`](BoundingBox::merge_with))
5//! takes `&mut self` but the box is otherwise treated as a plain data
6//! container with structural equality.
7//!
8//! Storage is a stack-allocated `[f64; D]` so the compiler can unroll
9//! per-dim loops, vectorise via SIMD, and avoid all heap traffic. The
10//! AWS-default `feature_dim = 16` is the canonical instantiation.
11//!
12//! The cut probability machinery follows Guha et al. (2016), §3:
13//! the probability that a uniform random cut of the box augmented by
14//! `point` would isolate `point` from the rest equals
15//! `Σ_d Δ_d / total_range_after`, where `Δ_d` is the per-dimension
//! extension caused by including `point`. [`probability_of_cut`] returns
//! a tuple `(total, per_dim)` and [`per_dim_cut_probabilities`] returns
//! just the `per_dim` array, so callers (e.g. the future
//! `AttributionVisitor`) can reuse the per-dim breakdown for attribution
//! without recomputing it.
20//!
21//! [`probability_of_cut`]: BoundingBox::probability_of_cut
22//! [`per_dim_cut_probabilities`]: BoundingBox::per_dim_cut_probabilities
23
24use wide::f64x4;
25
26use crate::domain::cut::Cut;
27use crate::error::{RcfError, RcfResult};
28
/// Axis-aligned bounding box for `D`-dimensional points. Storage is
/// stack-allocated `[f64; D]` so the compiler can unroll the
/// per-dim loops, vectorise via SIMD, and avoid any heap traffic.
///
/// Equality is structural: the derived `PartialEq` compares both
/// corners element-wise. Boxes are only ever widened in place
/// (`extend` / `merge_with`), starting from a degenerate single-point
/// box, so `min[d] <= max[d]` holds for finite inputs.
///
/// NOTE(review): the derived `Deserialize` (behind the `serde`
/// feature) does not re-check `min[d] <= max[d]` — confirm whether
/// deserialised boxes need validation.
///
/// # Examples
///
/// ```
/// use anomstream_core::BoundingBox;
///
/// let mut bbox = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
/// bbox.extend(&[3.0, 4.0]).unwrap();
/// assert_eq!(bbox.range_sum(), 7.0);
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BoundingBox<const D: usize> {
    /// Per-dimension lower corner. Serialised through
    /// [`crate::serde_util::fixed_array_f64`] because `serde` does
    /// not yet ship `Deserialize` for `[T; N]` at arbitrary `N`.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_util::fixed_array_f64"))]
    min: [f64; D],
    /// Per-dimension upper corner. Serialised through the same
    /// fixed-array helper as `min`.
    #[cfg_attr(feature = "serde", serde(with = "crate::serde_util::fixed_array_f64"))]
    max: [f64; D],
}
54
55impl<const D: usize> BoundingBox<D> {
56    /// Build a degenerate bounding box from a single point.
57    ///
58    /// # Errors
59    ///
60    /// - [`RcfError::EmptyBoundingBox`] when `D == 0`.
61    /// - [`RcfError::DimensionMismatch`] when `point.len() != D`.
62    pub fn from_point(point: &[f64]) -> RcfResult<Self> {
63        if D == 0 {
64            return Err(RcfError::EmptyBoundingBox);
65        }
66        if point.len() != D {
67            return Err(RcfError::DimensionMismatch {
68                expected: D,
69                got: point.len(),
70            });
71        }
72        let mut min = [0.0_f64; D];
73        let mut max = [0.0_f64; D];
74        min.copy_from_slice(point);
75        max.copy_from_slice(point);
76        Ok(Self { min, max })
77    }
78
79    /// Dimensionality of the box (compile-time constant `D`).
80    #[must_use]
81    #[inline]
82    pub const fn dim(&self) -> usize {
83        D
84    }
85
86    /// Per-dimension lower corner.
87    #[must_use]
88    #[inline]
89    pub fn min(&self) -> &[f64; D] {
90        &self.min
91    }
92
93    /// Per-dimension upper corner.
94    #[must_use]
95    #[inline]
96    pub fn max(&self) -> &[f64; D] {
97        &self.max
98    }
99
100    /// Range (`max_d − min_d`) for dimension `d`.
101    ///
102    /// # Panics
103    ///
104    /// Panics when `d >= D` — call sites are internal and always
105    /// size-checked.
106    #[must_use]
107    #[inline]
108    pub fn range_at(&self, d: usize) -> f64 {
109        self.max[d] - self.min[d]
110    }
111
112    /// Sum of per-dimension ranges (`Σ_d (max_d − min_d)`).
113    ///
114    /// This is the denominator used by [`Cut::random_cut`] to pick a
115    /// dimension weighted by its range. Vectorised in 4-lane f64
116    /// chunks via [`wide::f64x4`] for the AWS-default `D = 16` hot
117    /// path; a scalar tail handles dims that are not a multiple of 4.
118    ///
119    /// [`Cut::random_cut`]: crate::domain::Cut::random_cut
120    #[must_use]
121    #[inline]
122    pub fn range_sum(&self) -> f64 {
123        let chunks = D / 4;
124        let mut acc_simd = f64x4::splat(0.0);
125        for i in 0..chunks {
126            let off = i * 4;
127            let mn = f64x4::from([
128                self.min[off],
129                self.min[off + 1],
130                self.min[off + 2],
131                self.min[off + 3],
132            ]);
133            let mx = f64x4::from([
134                self.max[off],
135                self.max[off + 1],
136                self.max[off + 2],
137                self.max[off + 3],
138            ]);
139            acc_simd += mx - mn;
140        }
141        let mut s = acc_simd.reduce_add();
142        for d in (chunks * 4)..D {
143            s += self.max[d] - self.min[d];
144        }
145        s
146    }
147
148    /// Extend the box in place to include `point`.
149    ///
150    /// # Errors
151    ///
152    /// Returns [`RcfError::DimensionMismatch`] when `point.len() != D`.
153    pub fn extend(&mut self, point: &[f64]) -> RcfResult<()> {
154        if point.len() != D {
155            return Err(RcfError::DimensionMismatch {
156                expected: D,
157                got: point.len(),
158            });
159        }
160        for (d, &v) in point.iter().enumerate() {
161            if v < self.min[d] {
162                self.min[d] = v;
163            }
164            if v > self.max[d] {
165                self.max[d] = v;
166            }
167        }
168        Ok(())
169    }
170
171    /// Merge `other` into `self` in place. Both boxes have the same
172    /// type-level dimensionality so this is infallible.
173    pub fn merge_with(&mut self, other: &Self) {
174        for d in 0..D {
175            if other.min[d] < self.min[d] {
176                self.min[d] = other.min[d];
177            }
178            if other.max[d] > self.max[d] {
179                self.max[d] = other.max[d];
180            }
181        }
182    }
183
184    /// Return a new box equal to the union of `self` and `other`.
185    #[must_use]
186    pub fn merged(&self, other: &Self) -> Self {
187        let mut out = self.clone();
188        out.merge_with(other);
189        out
190    }
191
192    /// Per-dimension extension required to accommodate `point` —
193    /// `Δ_d = max(0, point_d − max_d) + max(0, min_d − point_d)`.
194    ///
195    /// When `point` already lies inside the box every `Δ_d` is `0` and
196    /// the cut probability is `0`.
197    ///
198    /// # Errors
199    ///
200    /// Returns [`RcfError::DimensionMismatch`] when `point.len() != D`.
201    pub fn extension_per_dim(&self, point: &[f64]) -> RcfResult<[f64; D]> {
202        if point.len() != D {
203            return Err(RcfError::DimensionMismatch {
204                expected: D,
205                got: point.len(),
206            });
207        }
208        let mut out = [0.0_f64; D];
209        for d in 0..D {
210            let above = point[d] - self.max[d];
211            let below = self.min[d] - point[d];
212            let mut delta = 0.0;
213            if above > 0.0 {
214                delta += above;
215            }
216            if below > 0.0 {
217                delta += below;
218            }
219            out[d] = delta;
220        }
221        Ok(out)
222    }
223
224    /// Probability that a uniform random cut over the augmented box
225    /// would isolate `point` from the original box.
226    ///
227    /// Returns `(total_probability, per_dim_contributions)` where the
228    /// per-dim array sums to `total_probability`.
229    ///
230    /// # Errors
231    ///
232    /// Returns [`RcfError::DimensionMismatch`] when `point.len() != D`.
233    pub fn probability_of_cut(&self, point: &[f64]) -> RcfResult<(f64, [f64; D])> {
234        let extension = self.extension_per_dim(point)?;
235        let extension_sum: f64 = extension.iter().sum();
236        let denom = self.range_sum() + extension_sum;
237        if denom == 0.0 {
238            return Ok((0.0, [0.0; D]));
239        }
240        let mut per_dim = [0.0_f64; D];
241        for d in 0..D {
242            per_dim[d] = extension[d] / denom;
243        }
244        let total: f64 = per_dim.iter().sum();
245        Ok((total, per_dim))
246    }
247
248    /// Convenience accessor returning only the per-dim contributions
249    /// (ignores the total).
250    ///
251    /// # Errors
252    ///
253    /// Returns [`RcfError::DimensionMismatch`] when `point.len() != D`.
254    pub fn per_dim_cut_probabilities(&self, point: &[f64]) -> RcfResult<[f64; D]> {
255        Ok(self.probability_of_cut(point)?.1)
256    }
257
258    /// Per-dimension range of the bounding box augmented by `point`
259    /// without materialising a fresh [`BoundingBox`].
260    ///
261    /// # Panics
262    ///
263    /// Panics in debug builds when `d >= D` or `point.len() != D`.
264    #[inline]
265    #[must_use]
266    pub fn augmented_range_at(&self, d: usize, point: &[f64]) -> f64 {
267        let lo = self.min[d].min(point[d]);
268        let hi = self.max[d].max(point[d]);
269        hi - lo
270    }
271
272    /// Sum of [`augmented_range_at`](Self::augmented_range_at) over
273    /// every dimension.
274    ///
275    /// # Panics
276    ///
277    /// Panics in debug builds when `point.len() != D`.
278    #[inline]
279    #[must_use]
280    pub fn augmented_range_sum(&self, point: &[f64]) -> f64 {
281        let chunks = D / 4;
282        let mut acc_simd = f64x4::splat(0.0);
283        for i in 0..chunks {
284            let off = i * 4;
285            let p = f64x4::from([point[off], point[off + 1], point[off + 2], point[off + 3]]);
286            let mn = f64x4::from([
287                self.min[off],
288                self.min[off + 1],
289                self.min[off + 2],
290                self.min[off + 3],
291            ]);
292            let mx = f64x4::from([
293                self.max[off],
294                self.max[off + 1],
295                self.max[off + 2],
296                self.max[off + 3],
297            ]);
298            let lo = mn.fast_min(p);
299            let hi = mx.fast_max(p);
300            acc_simd += hi - lo;
301        }
302        let mut s = acc_simd.reduce_add();
303        let tail_start = chunks * 4;
304        for ((&p, &mn), &mx) in point[tail_start..D]
305            .iter()
306            .zip(self.min[tail_start..D].iter())
307            .zip(self.max[tail_start..D].iter())
308        {
309            let lo = mn.min(p);
310            let hi = mx.max(p);
311            s += hi - lo;
312        }
313        s
314    }
315
316    /// Sample a random cut over the bounding box augmented by
317    /// `point` without materialising the augmented box.
318    ///
319    /// # Errors
320    ///
321    /// Returns [`RcfError::EmptyBoundingBox`] when every per-dim
322    /// range of the augmented box is zero.
323    #[inline]
324    pub fn augmented_random_cut<R: rand::Rng + ?Sized>(
325        &self,
326        point: &[f64],
327        rng: &mut R,
328    ) -> RcfResult<Cut> {
329        let total = self.augmented_range_sum(point);
330        if total <= 0.0 {
331            return Err(RcfError::EmptyBoundingBox);
332        }
333        let mut target = rand::RngExt::random::<f64>(rng) * total;
334        let mut chosen = 0_usize;
335        for d in 0..D {
336            let r = self.augmented_range_at(d, point);
337            if target < r {
338                chosen = d;
339                break;
340            }
341            target -= r;
342            chosen = d;
343        }
344        let lo = self.min[chosen].min(point[chosen]);
345        let hi = self.max[chosen].max(point[chosen]);
346        let value = if (hi - lo).abs() < f64::EPSILON {
347            lo
348        } else {
349            lo + rand::RngExt::random::<f64>(rng) * (hi - lo)
350        };
351        Ok(Cut::new(chosen, value))
352    }
353
354    /// Total cut probability without allocating the per-dim
355    /// breakdown — fast path for [`crate::ScalarScoreVisitor`].
356    ///
357    /// Fuses the `range_sum` and extension passes into a single SIMD
358    /// loop so `self.min` / `self.max` are loaded once per chunk. The
359    /// previous split implementation did two passes and ate 2× the L1
360    /// bandwidth on deep tree descents where bbox reload dominates.
361    ///
362    /// # Errors
363    ///
364    /// Returns [`RcfError::DimensionMismatch`] when `point.len() != D`.
365    pub fn total_probability_of_cut(&self, point: &[f64]) -> RcfResult<f64> {
366        if point.len() != D {
367            return Err(RcfError::DimensionMismatch {
368                expected: D,
369                got: point.len(),
370            });
371        }
372        let chunks = D / 4;
373        let zero = f64x4::splat(0.0);
374        let mut range_acc = f64x4::splat(0.0);
375        let mut ext_acc = f64x4::splat(0.0);
376        for i in 0..chunks {
377            let off = i * 4;
378            let p = f64x4::from([point[off], point[off + 1], point[off + 2], point[off + 3]]);
379            let mn = f64x4::from([
380                self.min[off],
381                self.min[off + 1],
382                self.min[off + 2],
383                self.min[off + 3],
384            ]);
385            let mx = f64x4::from([
386                self.max[off],
387                self.max[off + 1],
388                self.max[off + 2],
389                self.max[off + 3],
390            ]);
391            range_acc += mx - mn;
392            let above = (p - mx).fast_max(zero);
393            let below = (mn - p).fast_max(zero);
394            ext_acc += above + below;
395        }
396        let mut range_sum = range_acc.reduce_add();
397        let mut extension_sum = ext_acc.reduce_add();
398        let tail_start = chunks * 4;
399        for ((&p, &mn), &mx) in point[tail_start..D]
400            .iter()
401            .zip(self.min[tail_start..D].iter())
402            .zip(self.max[tail_start..D].iter())
403        {
404            range_sum += mx - mn;
405            let above = p - mx;
406            let below = mn - p;
407            if above > 0.0 {
408                extension_sum += above;
409            }
410            if below > 0.0 {
411                extension_sum += below;
412            }
413        }
414        let denom = range_sum + extension_sum;
415        if denom == 0.0 {
416            return Ok(0.0);
417        }
418        Ok(extension_sum / denom)
419    }
420}
421
#[cfg(test)]
#[allow(clippy::float_cmp)] // expected values below are exactly representable
mod tests {
    use super::*;

    #[test]
    fn from_point_creates_degenerate_box() {
        let b = BoundingBox::<3>::from_point(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(b.dim(), 3);
        assert_eq!(b.min(), &[1.0, 2.0, 3.0]);
        assert_eq!(b.max(), &[1.0, 2.0, 3.0]);
        assert_eq!(b.range_sum(), 0.0);
    }

    #[test]
    fn from_point_rejects_zero_dim() {
        assert!(matches!(
            BoundingBox::<0>::from_point(&[]).unwrap_err(),
            RcfError::EmptyBoundingBox
        ));
    }

    #[test]
    fn from_point_rejects_dim_mismatch() {
        assert!(matches!(
            BoundingBox::<3>::from_point(&[1.0, 2.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn extend_grows_box() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[3.0, -2.0]).unwrap();
        assert_eq!(b.min(), &[0.0, -2.0]);
        assert_eq!(b.max(), &[3.0, 0.0]);
        assert!((b.range_sum() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn extend_rejects_dim_mismatch() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        assert!(matches!(
            b.extend(&[1.0, 2.0, 3.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn range_at_per_dim() {
        let mut b = BoundingBox::<3>::from_point(&[0.0, 0.0, 0.0]).unwrap();
        b.extend(&[2.0, 4.0, 8.0]).unwrap();
        assert_eq!(b.range_at(0), 2.0);
        assert_eq!(b.range_at(1), 4.0);
        assert_eq!(b.range_at(2), 8.0);
        assert_eq!(b.range_sum(), 14.0);
    }

    #[test]
    fn range_sum_covers_simd_chunk_and_scalar_tail() {
        // D = 6 exercises one full 4-lane chunk plus a 2-element tail.
        let mut b = BoundingBox::<6>::from_point(&[0.0; 6]).unwrap();
        b.extend(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(b.range_sum(), 21.0);
        let scalar: f64 = (0..6).map(|d| b.range_at(d)).sum();
        assert_eq!(b.range_sum(), scalar);
    }

    #[test]
    fn merge_with_unions_corners() {
        let mut a = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        a.extend(&[2.0, 2.0]).unwrap();
        let mut b = BoundingBox::<2>::from_point(&[-1.0, 1.0]).unwrap();
        b.extend(&[1.0, 5.0]).unwrap();
        a.merge_with(&b);
        assert_eq!(a.min(), &[-1.0, 0.0]);
        assert_eq!(a.max(), &[2.0, 5.0]);
    }

    #[test]
    fn merged_returns_new_box() {
        let a = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        let b = BoundingBox::<2>::from_point(&[5.0, 5.0]).unwrap();
        let union = a.merged(&b);
        assert_eq!(union.min(), &[0.0, 0.0]);
        assert_eq!(union.max(), &[5.0, 5.0]);
        assert_eq!(a.min(), &[0.0, 0.0]);
        assert_eq!(b.max(), &[5.0, 5.0]);
    }

    #[test]
    fn extension_zero_when_point_inside() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let ext = b.extension_per_dim(&[5.0, 5.0]).unwrap();
        assert_eq!(ext, [0.0, 0.0]);
    }

    #[test]
    fn extension_picks_above_and_below() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let ext = b.extension_per_dim(&[-3.0, 15.0]).unwrap();
        assert_eq!(ext, [3.0, 5.0]);
    }

    #[test]
    fn probability_of_cut_zero_when_inside() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let (p, per_dim) = b.probability_of_cut(&[5.0, 5.0]).unwrap();
        assert_eq!(p, 0.0);
        assert_eq!(per_dim, [0.0, 0.0]);
    }

    #[test]
    fn probability_of_cut_concentrated_on_extending_dim() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[10.0, 10.0]).unwrap();
        let (total, per_dim) = b.probability_of_cut(&[1000.0, 5.0]).unwrap();
        assert!(per_dim[0] > per_dim[1]);
        assert!((per_dim[0] + per_dim[1] - total).abs() < 1e-12);
    }

    #[test]
    fn probability_of_cut_handles_degenerate_box() {
        let b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        let (p, per_dim) = b.probability_of_cut(&[0.0, 0.0]).unwrap();
        assert_eq!(p, 0.0);
        assert_eq!(per_dim, [0.0, 0.0]);
    }

    #[test]
    fn probability_of_cut_per_dim_sums_to_total() {
        let mut b = BoundingBox::<3>::from_point(&[0.0, 0.0, 0.0]).unwrap();
        b.extend(&[1.0, 1.0, 1.0]).unwrap();
        let (total, per_dim) = b.probability_of_cut(&[5.0, -3.0, 0.5]).unwrap();
        let sum: f64 = per_dim.iter().sum();
        assert!((sum - total).abs() < 1e-12);
        assert!(per_dim[0] > 0.0);
        assert!(per_dim[1] > 0.0);
        assert_eq!(per_dim[2], 0.0);
    }

    #[test]
    fn probability_of_cut_rejects_dim_mismatch() {
        let b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        assert!(matches!(
            b.probability_of_cut(&[1.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn per_dim_cut_probabilities_matches_full_call() {
        let mut b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        b.extend(&[1.0, 1.0]).unwrap();
        let (_, full) = b.probability_of_cut(&[5.0, -3.0]).unwrap();
        let only_per_dim = b.per_dim_cut_probabilities(&[5.0, -3.0]).unwrap();
        assert_eq!(full, only_per_dim);
    }

    #[test]
    fn total_probability_of_cut_matches_full_computation() {
        // D = 6 runs the fused SIMD chunk plus the scalar tail.
        let mut b = BoundingBox::<6>::from_point(&[0.0; 6]).unwrap();
        b.extend(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let point = [2.0, -1.0, 1.5, 4.0, 5.0, 10.0];
        let fast = b.total_probability_of_cut(&point).unwrap();
        let (full, _) = b.probability_of_cut(&point).unwrap();
        assert!((fast - full).abs() < 1e-12);
        // Extensions: 1 (above), 1 (below), 0, 0, 0, 4 (above) → 6 / (21 + 6).
        assert!((fast - 6.0 / 27.0).abs() < 1e-12);
    }

    #[test]
    fn total_probability_of_cut_zero_inside_and_degenerate() {
        let mut b = BoundingBox::<4>::from_point(&[0.0; 4]).unwrap();
        b.extend(&[1.0; 4]).unwrap();
        assert_eq!(b.total_probability_of_cut(&[0.5; 4]).unwrap(), 0.0);
        let degenerate = BoundingBox::<4>::from_point(&[2.0; 4]).unwrap();
        assert_eq!(degenerate.total_probability_of_cut(&[2.0; 4]).unwrap(), 0.0);
    }

    #[test]
    fn total_probability_of_cut_rejects_dim_mismatch() {
        let b = BoundingBox::<2>::from_point(&[0.0, 0.0]).unwrap();
        assert!(matches!(
            b.total_probability_of_cut(&[1.0, 2.0, 3.0]).unwrap_err(),
            RcfError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn augmented_range_sum_matches_per_dim_ranges() {
        // D = 5: one SIMD chunk plus a single-element tail.
        let mut b = BoundingBox::<5>::from_point(&[0.0; 5]).unwrap();
        b.extend(&[1.0; 5]).unwrap();
        let point = [2.0, -1.0, 0.5, 3.0, -2.0];
        let per_dim: f64 = (0..5).map(|d| b.augmented_range_at(d, &point)).sum();
        assert_eq!(per_dim, 11.0);
        assert_eq!(b.augmented_range_sum(&point), 11.0);
    }

    #[test]
    fn augmented_range_sum_equals_range_sum_for_inside_point() {
        let mut b = BoundingBox::<4>::from_point(&[0.0; 4]).unwrap();
        b.extend(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(b.augmented_range_sum(&[0.5, 1.0, 1.5, 2.0]), b.range_sum());
    }
}