Skip to main content

fdars_core/
streaming_depth.rs

1//! Streaming / online depth computation for functional data.
2//!
3//! This module decouples reference-set construction from query evaluation,
4//! enabling efficient depth computation in streaming scenarios:
5//!
6//! - [`SortedReferenceState`] pre-sorts reference values per time point for O(log N) rank queries.
7//! - [`StreamingMbd`] uses a rank-based combinatorial identity to compute Modified Band Depth
8//!   in O(T log N) per query instead of O(N² T).
9//! - [`StreamingFraimanMuniz`] computes Fraiman-Muniz depth via binary search on sorted columns.
10//! - [`StreamingBd`] computes Band Depth with decoupled reference and early-exit optimisation.
11//! - [`RollingReference`] maintains a sliding window of reference curves with incremental
12//!   sorted-column updates.
13
14use std::collections::VecDeque;
15
16use crate::iter_maybe_parallel;
17use crate::matrix::FdMatrix;
18#[cfg(feature = "parallel")]
19use rayon::iter::ParallelIterator;
20
21// ---------------------------------------------------------------------------
22// Helper: choose-2 combinator
23// ---------------------------------------------------------------------------
24
/// Binomial coefficient C(k, 2): the number of unordered pairs that can be
/// drawn from `k` items. Yields 0 whenever `k < 2`.
#[inline]
fn c2(k: usize) -> usize {
    if k < 2 {
        0
    } else {
        k * (k - 1) / 2
    }
}
29
30// ===========================================================================
31// SortedReferenceState
32// ===========================================================================
33
34/// Pre-sorted reference values at each time point for O(log N) rank queries.
35///
36/// Constructed once from a column-major reference matrix and then shared
37/// (immutably) by any number of streaming depth estimators.
38pub struct SortedReferenceState {
39    /// `sorted_columns[t]` contains the reference values at time point `t`, sorted ascending.
40    sorted_columns: Vec<Vec<f64>>,
41    nori: usize,
42    n_points: usize,
43}
44
45impl SortedReferenceState {
46    /// Build from a column-major reference matrix.
47    ///
48    /// * `data_ori` – reference matrix of shape `nori × n_points`
49    ///
50    /// Complexity: O(T × N log N)  (parallelised over time points).
51    pub fn from_reference(data_ori: &FdMatrix) -> Self {
52        let nori = data_ori.nrows();
53        let n_points = data_ori.ncols();
54        if nori == 0 || n_points == 0 {
55            return Self {
56                sorted_columns: Vec::new(),
57                nori,
58                n_points,
59            };
60        }
61        let sorted_columns: Vec<Vec<f64>> = iter_maybe_parallel!(0..n_points)
62            .map(|t| {
63                let mut col: Vec<f64> = (0..nori).map(|j| data_ori[(j, t)]).collect();
64                col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
65                col
66            })
67            .collect();
68        Self {
69            sorted_columns,
70            nori,
71            n_points,
72        }
73    }
74
75    /// Returns `(below, above)` — the count of reference values strictly below
76    /// and strictly above `x` at time point `t`.
77    ///
78    /// Complexity: O(log N) via two binary searches.
79    #[inline]
80    pub fn rank_at(&self, t: usize, x: f64) -> (usize, usize) {
81        let col = &self.sorted_columns[t];
82        let below = col.partition_point(|&v| v < x);
83        let at_or_below = col.partition_point(|&v| v <= x);
84        let above = self.nori - at_or_below;
85        (below, above)
86    }
87
88    /// Number of reference observations.
89    #[inline]
90    pub fn nori(&self) -> usize {
91        self.nori
92    }
93
94    /// Number of evaluation points.
95    #[inline]
96    pub fn n_points(&self) -> usize {
97        self.n_points
98    }
99}
100
101// ===========================================================================
102// StreamingDepth trait
103// ===========================================================================
104
105/// Trait for streaming depth estimators backed by a pre-built reference state.
106pub trait StreamingDepth {
107    /// Depth of a single curve given as a contiguous `&[f64]` of length `n_points`.
108    fn depth_one(&self, curve: &[f64]) -> f64;
109
110    /// Batch depth for a matrix of query curves (`nobj × n_points`).
111    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64>;
112
113    /// Number of evaluation points.
114    fn n_points(&self) -> usize;
115
116    /// Number of reference observations backing this estimator.
117    fn n_reference(&self) -> usize;
118}
119
120// ===========================================================================
121// StreamingMbd — rank-based Modified Band Depth, O(T log N) per query
122// ===========================================================================
123
124/// Rank-based Modified Band Depth estimator.
125///
126/// Uses the combinatorial identity: at time t with `b` values strictly below
127/// x(t) and `a` strictly above,
128///
129/// > pairs containing x(t) = C(N,2) − C(b,2) − C(a,2)
130///
131/// MBD(x) = (1 / (C(N,2) × T)) × Σ_t [C(N,2) − C(b_t,2) − C(a_t,2)]
132///
133/// Per-query complexity: **O(T × log N)** instead of O(N² × T).
134pub struct StreamingMbd {
135    state: SortedReferenceState,
136}
137
138impl StreamingMbd {
139    pub fn new(state: SortedReferenceState) -> Self {
140        Self { state }
141    }
142
143    /// Compute MBD for a single row-layout curve using rank formula.
144    #[inline]
145    fn mbd_one_inner(&self, curve: &[f64]) -> f64 {
146        let n = self.state.nori;
147        if n < 2 {
148            return 0.0;
149        }
150        let cn2 = c2(n);
151        let t_len = self.state.n_points;
152        let mut total = 0usize;
153        for t in 0..t_len {
154            let (below, above) = self.state.rank_at(t, curve[t]);
155            total += cn2 - c2(below) - c2(above);
156        }
157        total as f64 / (cn2 as f64 * t_len as f64)
158    }
159}
160
161impl StreamingDepth for StreamingMbd {
162    fn depth_one(&self, curve: &[f64]) -> f64 {
163        self.mbd_one_inner(curve)
164    }
165
166    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64> {
167        let nobj = data_obj.nrows();
168        if nobj == 0 || self.state.n_points == 0 || self.state.nori < 2 {
169            return vec![0.0; nobj];
170        }
171        let n_points = self.state.n_points;
172        iter_maybe_parallel!(0..nobj)
173            .map(|i| {
174                let curve: Vec<f64> = (0..n_points).map(|t| data_obj[(i, t)]).collect();
175                self.mbd_one_inner(&curve)
176            })
177            .collect()
178    }
179
180    fn n_points(&self) -> usize {
181        self.state.n_points
182    }
183
184    fn n_reference(&self) -> usize {
185        self.state.nori
186    }
187}
188
189// ===========================================================================
190// StreamingFraimanMuniz — O(T log N) FM Depth
191// ===========================================================================
192
193/// Streaming Fraiman-Muniz depth estimator.
194///
195/// Uses binary search on sorted columns to compute the empirical CDF at each
196/// time point: Fn(x) = #{ref ≤ x} / N.
197///
198/// Per-query complexity: **O(T × log N)** instead of O(T × N).
199pub struct StreamingFraimanMuniz {
200    state: SortedReferenceState,
201    scale: bool,
202}
203
204impl StreamingFraimanMuniz {
205    pub fn new(state: SortedReferenceState, scale: bool) -> Self {
206        Self { state, scale }
207    }
208
209    #[inline]
210    fn fm_one_inner(&self, curve: &[f64]) -> f64 {
211        let n = self.state.nori;
212        if n == 0 {
213            return 0.0;
214        }
215        let t_len = self.state.n_points;
216        if t_len == 0 {
217            return 0.0;
218        }
219        let scale_factor = if self.scale { 2.0 } else { 1.0 };
220        let mut depth_sum = 0.0;
221        for t in 0..t_len {
222            let col = &self.state.sorted_columns[t];
223            let at_or_below = col.partition_point(|&v| v <= curve[t]);
224            let fn_x = at_or_below as f64 / n as f64;
225            depth_sum += fn_x.min(1.0 - fn_x) * scale_factor;
226        }
227        depth_sum / t_len as f64
228    }
229}
230
231impl StreamingDepth for StreamingFraimanMuniz {
232    fn depth_one(&self, curve: &[f64]) -> f64 {
233        self.fm_one_inner(curve)
234    }
235
236    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64> {
237        let nobj = data_obj.nrows();
238        if nobj == 0 || self.state.n_points == 0 || self.state.nori == 0 {
239            return vec![0.0; nobj];
240        }
241        let n_points = self.state.n_points;
242        iter_maybe_parallel!(0..nobj)
243            .map(|i| {
244                let curve: Vec<f64> = (0..n_points).map(|t| data_obj[(i, t)]).collect();
245                self.fm_one_inner(&curve)
246            })
247            .collect()
248    }
249
250    fn n_points(&self) -> usize {
251        self.state.n_points
252    }
253
254    fn n_reference(&self) -> usize {
255        self.state.nori
256    }
257}
258
259// ===========================================================================
260// FullReferenceState + StreamingBd — Band Depth with decoupled reference
261// ===========================================================================
262
263/// Full reference state that keeps per-curve values alongside sorted columns.
264///
265/// Required by Band Depth (BD), which checks all-or-nothing containment across
266/// ALL time points and therefore cannot decompose into per-point rank queries.
267pub struct FullReferenceState {
268    /// Sorted columns for rank queries (shared with MBD/FM estimators if desired).
269    pub sorted: SortedReferenceState,
270    /// `values_by_curve[j][t]` = reference curve j at time point t (row layout).
271    values_by_curve: Vec<Vec<f64>>,
272}
273
274impl FullReferenceState {
275    /// Build from a column-major reference matrix.
276    pub fn from_reference(data_ori: &FdMatrix) -> Self {
277        let nori = data_ori.nrows();
278        let n_points = data_ori.ncols();
279        let sorted = SortedReferenceState::from_reference(data_ori);
280        let values_by_curve: Vec<Vec<f64>> = (0..nori)
281            .map(|j| (0..n_points).map(|t| data_ori[(j, t)]).collect())
282            .collect();
283        Self {
284            sorted,
285            values_by_curve,
286        }
287    }
288}
289
290/// Streaming Band Depth estimator.
291///
292/// BD requires all-or-nothing containment across ALL time points — it does not
293/// decompose per-point like MBD. The streaming advantage here is **reference
294/// decoupling** (no re-parsing the matrix) and **early-exit per pair** (break
295/// on first time point where x is outside the band), not an asymptotic
296/// improvement.
297pub struct StreamingBd {
298    state: FullReferenceState,
299}
300
301impl StreamingBd {
302    pub fn new(state: FullReferenceState) -> Self {
303        Self { state }
304    }
305
306    #[inline]
307    fn bd_one_inner(&self, curve: &[f64]) -> f64 {
308        let n = self.state.sorted.nori;
309        if n < 2 {
310            return 0.0;
311        }
312        let n_pairs = c2(n);
313        let n_points = self.state.sorted.n_points;
314
315        let mut count_in_band = 0usize;
316        for j in 0..n {
317            for k in (j + 1)..n {
318                let mut inside = true;
319                for t in 0..n_points {
320                    let x_t = curve[t];
321                    let y_j_t = self.state.values_by_curve[j][t];
322                    let y_k_t = self.state.values_by_curve[k][t];
323                    let band_min = y_j_t.min(y_k_t);
324                    let band_max = y_j_t.max(y_k_t);
325                    if x_t < band_min || x_t > band_max {
326                        inside = false;
327                        break;
328                    }
329                }
330                if inside {
331                    count_in_band += 1;
332                }
333            }
334        }
335        count_in_band as f64 / n_pairs as f64
336    }
337}
338
339impl StreamingDepth for StreamingBd {
340    fn depth_one(&self, curve: &[f64]) -> f64 {
341        self.bd_one_inner(curve)
342    }
343
344    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64> {
345        let nobj = data_obj.nrows();
346        let n = self.state.sorted.nori;
347        if nobj == 0 || self.state.sorted.n_points == 0 || n < 2 {
348            return vec![0.0; nobj];
349        }
350        let n_points = self.state.sorted.n_points;
351        iter_maybe_parallel!(0..nobj)
352            .map(|i| {
353                let curve: Vec<f64> = (0..n_points).map(|t| data_obj[(i, t)]).collect();
354                self.bd_one_inner(&curve)
355            })
356            .collect()
357    }
358
359    fn n_points(&self) -> usize {
360        self.state.sorted.n_points
361    }
362
363    fn n_reference(&self) -> usize {
364        self.state.sorted.nori
365    }
366}
367
368// ===========================================================================
369// RollingReference — sliding window with incremental sorted-column updates
370// ===========================================================================
371
372/// Sliding window of reference curves with incrementally maintained sorted columns.
373///
374/// When a new curve is pushed and the window is at capacity, the oldest curve
375/// is evicted. For each time point the old value is removed (binary-search +
376/// `Vec::remove`) and the new value is inserted (binary-search + `Vec::insert`).
377///
378/// Complexity per push: O(T × N) due to element shifting in the sorted vectors.
379pub struct RollingReference {
380    curves: VecDeque<Vec<f64>>,
381    capacity: usize,
382    n_points: usize,
383    sorted_columns: Vec<Vec<f64>>,
384}
385
386impl RollingReference {
387    /// Create an empty rolling window.
388    ///
389    /// * `capacity` – maximum number of curves in the window (must be ≥ 1).
390    /// * `n_points` – number of evaluation points per curve.
391    pub fn new(capacity: usize, n_points: usize) -> Self {
392        assert!(capacity >= 1, "capacity must be at least 1");
393        Self {
394            curves: VecDeque::with_capacity(capacity),
395            capacity,
396            n_points,
397            sorted_columns: (0..n_points)
398                .map(|_| Vec::with_capacity(capacity))
399                .collect(),
400        }
401    }
402
403    /// Push a new curve into the window.
404    ///
405    /// If the window is at capacity, the oldest curve is evicted and returned.
406    /// For each time point, the sorted column is updated incrementally.
407    pub fn push(&mut self, curve: &[f64]) -> Option<Vec<f64>> {
408        assert_eq!(
409            curve.len(),
410            self.n_points,
411            "curve length {} does not match n_points {}",
412            curve.len(),
413            self.n_points
414        );
415
416        let evicted = if self.curves.len() == self.capacity {
417            let old = self.curves.pop_front().unwrap();
418            // Remove old values from sorted columns
419            for t in 0..self.n_points {
420                let col = &mut self.sorted_columns[t];
421                let old_val = old[t];
422                let pos = col.partition_point(|&v| v < old_val);
423                // Find exact match (handles duplicates by scanning nearby)
424                let mut found = false;
425                for idx in pos..col.len() {
426                    if col[idx] == old_val {
427                        col.remove(idx);
428                        found = true;
429                        break;
430                    }
431                    if col[idx] > old_val {
432                        break;
433                    }
434                }
435                if !found {
436                    // Fallback: scan from pos backwards for floating-point edge cases
437                    for idx in (0..pos).rev() {
438                        if col[idx] == old_val {
439                            col.remove(idx);
440                            break;
441                        }
442                        if col[idx] < old_val {
443                            break;
444                        }
445                    }
446                }
447            }
448            Some(old)
449        } else {
450            None
451        };
452
453        // Insert new values into sorted columns
454        let new_curve: Vec<f64> = curve.to_vec();
455        for t in 0..self.n_points {
456            let col = &mut self.sorted_columns[t];
457            let val = new_curve[t];
458            let pos = col.partition_point(|&v| v < val);
459            col.insert(pos, val);
460        }
461        self.curves.push_back(new_curve);
462
463        evicted
464    }
465
466    /// Take a snapshot of the current sorted reference state.
467    ///
468    /// This clones the sorted columns. For repeated queries, prefer
469    /// [`mbd_one`](Self::mbd_one) which queries the window directly.
470    pub fn snapshot(&self) -> SortedReferenceState {
471        SortedReferenceState {
472            sorted_columns: self.sorted_columns.clone(),
473            nori: self.curves.len(),
474            n_points: self.n_points,
475        }
476    }
477
478    /// Compute rank-based MBD for a single curve directly against the current window.
479    ///
480    /// Avoids the overhead of cloning sorted columns into a snapshot.
481    pub fn mbd_one(&self, curve: &[f64]) -> f64 {
482        let n = self.curves.len();
483        if n < 2 || self.n_points == 0 {
484            return 0.0;
485        }
486        assert_eq!(
487            curve.len(),
488            self.n_points,
489            "curve length {} does not match n_points {}",
490            curve.len(),
491            self.n_points
492        );
493        let cn2 = c2(n);
494        let mut total = 0usize;
495        for t in 0..self.n_points {
496            let col = &self.sorted_columns[t];
497            let below = col.partition_point(|&v| v < curve[t]);
498            let at_or_below = col.partition_point(|&v| v <= curve[t]);
499            let above = n - at_or_below;
500            total += cn2 - c2(below) - c2(above);
501        }
502        total as f64 / (cn2 as f64 * self.n_points as f64)
503    }
504
505    /// Number of curves currently in the window.
506    #[inline]
507    pub fn len(&self) -> usize {
508        self.curves.len()
509    }
510
511    /// Whether the window is empty.
512    #[inline]
513    pub fn is_empty(&self) -> bool {
514        self.curves.is_empty()
515    }
516
517    /// Maximum capacity of the window.
518    #[inline]
519    pub fn capacity(&self) -> usize {
520        self.capacity
521    }
522}
523
524// ===========================================================================
525// Tests
526// ===========================================================================
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::depth::{band_1d, fraiman_muniz_1d, modified_band_1d};
    use crate::matrix::FdMatrix;
    use std::f64::consts::PI;

    /// `n` equally spaced points on [0, 1] (callers pass `n >= 2`).
    fn uniform_grid(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Column-major `n × m` data: curve `i` is `sin(2πt)` shifted by `(i − n/2) / n`.
    /// Curves with index near `n / 2` therefore have the smallest shifts.
    fn generate_centered_data(n: usize, m: usize) -> Vec<f64> {
        let argvals = uniform_grid(m);
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            let offset = (i as f64 - n as f64 / 2.0) / (n as f64);
            for j in 0..m {
                data[i + j * n] = (2.0 * PI * argvals[j]).sin() + offset;
            }
        }
        data
    }

    /// Extract a single curve (row i) from column-major data into row layout.
    fn extract_curve(data: &[f64], i: usize, n: usize, m: usize) -> Vec<f64> {
        (0..m).map(|t| data[i + t * n]).collect()
    }

    // ============== Rank correctness ==============

    #[test]
    fn test_rank_basic() {
        // 5 reference curves, 3 time points
        // Column 0: [1, 2, 3, 4, 5]
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, // t=0
            10.0, 20.0, 30.0, 40.0, 50.0, // t=1
            100.0, 200.0, 300.0, 400.0, 500.0, // t=2
        ];
        let mat = FdMatrix::from_column_major(data, 5, 3).unwrap();
        let state = SortedReferenceState::from_reference(&mat);

        // At t=0, x=3.0: below=2 (1,2), above=2 (4,5)
        let (below, above) = state.rank_at(0, 3.0);
        assert_eq!(below, 2);
        assert_eq!(above, 2);

        // At t=1, x=25.0: below=2 (10,20), above=3 (30,40,50)
        let (below, above) = state.rank_at(1, 25.0);
        assert_eq!(below, 2);
        assert_eq!(above, 3);
    }

    #[test]
    fn test_rank_boundary_values() {
        // All values identical
        let data = vec![5.0, 5.0, 5.0, 5.0];
        let mat = FdMatrix::from_column_major(data, 4, 1).unwrap();
        let state = SortedReferenceState::from_reference(&mat);

        // x=5.0 exactly: none strictly below, none strictly above
        let (below, above) = state.rank_at(0, 5.0);
        assert_eq!(below, 0);
        assert_eq!(above, 0);

        // x < all: below=0, above=4
        let (below, above) = state.rank_at(0, 3.0);
        assert_eq!(below, 0);
        assert_eq!(above, 4);

        // x > all: below=4, above=0
        let (below, above) = state.rank_at(0, 7.0);
        assert_eq!(below, 4);
        assert_eq!(above, 0);
    }

    #[test]
    fn test_rank_duplicates() {
        // Values with duplicates: [1, 2, 2, 3, 3, 3]
        let data = vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0];
        let mat = FdMatrix::from_column_major(data, 6, 1).unwrap();
        let state = SortedReferenceState::from_reference(&mat);

        // x=2.0: below=1 (just 1), above=3 (three 3s)
        let (below, above) = state.rank_at(0, 2.0);
        assert_eq!(below, 1);
        assert_eq!(above, 3);

        // x=3.0: below=3 (1,2,2), above=0
        let (below, above) = state.rank_at(0, 3.0);
        assert_eq!(below, 3);
        assert_eq!(above, 0);
    }

    // ============== Batch equivalence ==============

    #[test]
    fn test_streaming_mbd_matches_batch() {
        let n = 15;
        let m = 20;
        let data = generate_centered_data(n, m);

        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
        let batch = modified_band_1d(&mat, &mat);
        let state = SortedReferenceState::from_reference(&mat);
        let streaming = StreamingMbd::new(state);
        let streaming_result = streaming.depth_batch(&mat);

        assert_eq!(batch.len(), streaming_result.len());
        for (b, s) in batch.iter().zip(streaming_result.iter()) {
            assert!(
                (b - s).abs() < 1e-10,
                "MBD mismatch: batch={}, streaming={}",
                b,
                s
            );
        }
    }

    #[test]
    fn test_streaming_fm_matches_batch() {
        let n = 15;
        let m = 20;
        let data = generate_centered_data(n, m);

        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
        for scale in [true, false] {
            let batch = fraiman_muniz_1d(&mat, &mat, scale);
            let state = SortedReferenceState::from_reference(&mat);
            let streaming = StreamingFraimanMuniz::new(state, scale);
            let streaming_result = streaming.depth_batch(&mat);

            assert_eq!(batch.len(), streaming_result.len());
            for (b, s) in batch.iter().zip(streaming_result.iter()) {
                assert!(
                    (b - s).abs() < 1e-10,
                    "FM mismatch (scale={}): batch={}, streaming={}",
                    scale,
                    b,
                    s
                );
            }
        }
    }

    #[test]
    fn test_streaming_bd_matches_batch() {
        let n = 10;
        let m = 20;
        let data = generate_centered_data(n, m);

        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
        let batch = band_1d(&mat, &mat);
        let full_state = FullReferenceState::from_reference(&mat);
        let streaming = StreamingBd::new(full_state);
        let streaming_result = streaming.depth_batch(&mat);

        assert_eq!(batch.len(), streaming_result.len());
        for (b, s) in batch.iter().zip(streaming_result.iter()) {
            assert!(
                (b - s).abs() < 1e-10,
                "BD mismatch: batch={}, streaming={}",
                b,
                s
            );
        }
    }

    // ============== Rolling reference ==============

    #[test]
    fn test_rolling_sorted_columns_maintained() {
        let mut rolling = RollingReference::new(3, 2);

        rolling.push(&[1.0, 10.0]);
        assert_eq!(rolling.sorted_columns[0], vec![1.0]);
        assert_eq!(rolling.sorted_columns[1], vec![10.0]);

        rolling.push(&[3.0, 5.0]);
        assert_eq!(rolling.sorted_columns[0], vec![1.0, 3.0]);
        assert_eq!(rolling.sorted_columns[1], vec![5.0, 10.0]);

        rolling.push(&[2.0, 7.0]);
        assert_eq!(rolling.sorted_columns[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(rolling.sorted_columns[1], vec![5.0, 7.0, 10.0]);

        // Push a 4th — evicts [1.0, 10.0]
        let evicted = rolling.push(&[0.5, 8.0]);
        assert_eq!(evicted, Some(vec![1.0, 10.0]));
        assert_eq!(rolling.sorted_columns[0], vec![0.5, 2.0, 3.0]);
        assert_eq!(rolling.sorted_columns[1], vec![5.0, 7.0, 8.0]);
    }

    #[test]
    fn test_rolling_mbd_matches_batch() {
        let n = 10;
        let m = 15;
        let data = generate_centered_data(n, m);

        // Fill a rolling window with the same curves
        let mut rolling = RollingReference::new(n, m);
        for i in 0..n {
            let curve = extract_curve(&data, i, n, m);
            rolling.push(&curve);
        }

        // mbd_one should match batch for each curve
        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
        let batch = modified_band_1d(&mat, &mat);
        for i in 0..n {
            let curve = extract_curve(&data, i, n, m);
            let rolling_depth = rolling.mbd_one(&curve);
            assert!(
                (batch[i] - rolling_depth).abs() < 1e-10,
                "Rolling MBD mismatch at i={}: batch={}, rolling={}",
                i,
                batch[i],
                rolling_depth
            );
        }
    }

    #[test]
    fn test_rolling_eviction_correctness() {
        let m = 5;
        let mut rolling = RollingReference::new(3, m);

        // Push 5 curves — window should only contain the last 3
        let curves: Vec<Vec<f64>> = (0..5)
            .map(|i| (0..m).map(|t| (i * m + t) as f64).collect())
            .collect();

        for c in &curves {
            rolling.push(c);
        }

        assert_eq!(rolling.len(), 3);

        // Snapshot should match manually-built state from curves 2,3,4
        let snapshot = rolling.snapshot();
        assert_eq!(snapshot.nori(), 3);

        // Build reference data manually from curves 2..5
        let mut ref_data = vec![0.0; 3 * m];
        for (idx, ci) in (2..5).enumerate() {
            for t in 0..m {
                ref_data[idx + t * 3] = curves[ci][t];
            }
        }
        let ref_mat = FdMatrix::from_column_major(ref_data, 3, m).unwrap();
        let expected = SortedReferenceState::from_reference(&ref_mat);

        for t in 0..m {
            assert_eq!(
                snapshot.sorted_columns[t], expected.sorted_columns[t],
                "sorted columns differ at t={}",
                t
            );
        }
    }

    // ============== Properties ==============

    #[test]
    fn test_depth_in_unit_interval() {
        let n = 20;
        let m = 30;
        let data = generate_centered_data(n, m);
        let mat = FdMatrix::from_slice(&data, n, m).unwrap();

        let state_mbd = SortedReferenceState::from_reference(&mat);
        let mbd = StreamingMbd::new(state_mbd);
        for d in mbd.depth_batch(&mat) {
            assert!((0.0..=1.0).contains(&d), "MBD out of range: {}", d);
        }

        let state_fm = SortedReferenceState::from_reference(&mat);
        let fm = StreamingFraimanMuniz::new(state_fm, true);
        for d in fm.depth_batch(&mat) {
            assert!((0.0..=1.0).contains(&d), "FM out of range: {}", d);
        }

        let full = FullReferenceState::from_reference(&mat);
        let bd = StreamingBd::new(full);
        for d in bd.depth_batch(&mat) {
            assert!((0.0..=1.0).contains(&d), "BD out of range: {}", d);
        }
    }

    #[test]
    fn test_central_curves_deeper() {
        let n = 20;
        let m = 30;
        let data = generate_centered_data(n, m);
        let mat = FdMatrix::from_slice(&data, n, m).unwrap();

        let state = SortedReferenceState::from_reference(&mat);
        let mbd = StreamingMbd::new(state);
        let depths = mbd.depth_batch(&mat);

        let central_depth = depths[n / 2];
        let edge_depth = depths[0];
        assert!(
            central_depth > edge_depth,
            "Central curve should be deeper: {} > {}",
            central_depth,
            edge_depth
        );
    }

    #[test]
    fn test_empty_inputs() {
        let empty = FdMatrix::zeros(0, 0);
        let state = SortedReferenceState::from_reference(&empty);
        let mbd = StreamingMbd::new(state);
        assert_eq!(mbd.depth_one(&[]), 0.0);

        let state = SortedReferenceState::from_reference(&empty);
        let fm = StreamingFraimanMuniz::new(state, true);
        assert_eq!(fm.depth_one(&[]), 0.0);
    }

    #[test]
    fn test_depth_one_matches_depth_batch_single() {
        let n = 10;
        let m = 15;
        let data = generate_centered_data(n, m);
        let mat = FdMatrix::from_slice(&data, n, m).unwrap();

        // Build a 1-curve column-major "matrix" from curve 3
        let curve = extract_curve(&data, 3, n, m);
        let single_mat = FdMatrix::from_column_major(curve.clone(), 1, m).unwrap();

        let state = SortedReferenceState::from_reference(&mat);
        let mbd = StreamingMbd::new(state);

        let one = mbd.depth_one(&curve);
        let batch = mbd.depth_batch(&single_mat);
        assert!(
            (one - batch[0]).abs() < 1e-14,
            "depth_one ({}) != depth_batch ({}) for single curve",
            one,
            batch[0]
        );
    }

    // ============== Thread safety ==============

    #[test]
    fn test_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SortedReferenceState>();
        assert_send_sync::<StreamingMbd>();
        assert_send_sync::<StreamingFraimanMuniz>();
        assert_send_sync::<FullReferenceState>();
        assert_send_sync::<StreamingBd>();
        assert_send_sync::<RollingReference>();
    }

    // ============== Edge cases ==============

    #[test]
    fn test_single_reference_curve() {
        // nori=1: C(1,2) = 0, MBD is undefined → returns 0
        let data = vec![1.0, 2.0, 3.0]; // 1 curve, 3 time points
        let mat = FdMatrix::from_column_major(data, 1, 3).unwrap();
        let state = SortedReferenceState::from_reference(&mat);
        let mbd = StreamingMbd::new(state);
        assert_eq!(mbd.depth_one(&[1.0, 2.0, 3.0]), 0.0);

        // BD also needs at least 2
        let full = FullReferenceState::from_reference(&mat);
        let bd = StreamingBd::new(full);
        assert_eq!(bd.depth_one(&[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn test_capacity_one_window() {
        let mut rolling = RollingReference::new(1, 3);

        rolling.push(&[1.0, 2.0, 3.0]);
        assert_eq!(rolling.len(), 1);
        // MBD with 1 curve → 0
        assert_eq!(rolling.mbd_one(&[1.0, 2.0, 3.0]), 0.0);

        let evicted = rolling.push(&[4.0, 5.0, 6.0]);
        assert_eq!(evicted, Some(vec![1.0, 2.0, 3.0]));
        assert_eq!(rolling.len(), 1);
    }

    #[test]
    #[should_panic(expected = "curve length")]
    fn test_curve_length_mismatch() {
        // The window holds 2-point curves, so pushing a 3-point curve must panic.
        let mut rolling = RollingReference::new(5, 2);
        rolling.push(&[1.0, 2.0, 3.0]);
    }

    // ============== Additional: snapshot-based streaming ==============

    #[test]
    fn test_rolling_snapshot_produces_valid_mbd() {
        let n = 8;
        let m = 10;
        let data = generate_centered_data(n, m);

        let mut rolling = RollingReference::new(n, m);
        for i in 0..n {
            let curve = extract_curve(&data, i, n, m);
            rolling.push(&curve);
        }

        let snapshot = rolling.snapshot();
        let mbd = StreamingMbd::new(snapshot);

        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
        let batch_depths = modified_band_1d(&mat, &mat);
        let streaming_depths = mbd.depth_batch(&mat);

        for (b, s) in batch_depths.iter().zip(streaming_depths.iter()) {
            assert!(
                (b - s).abs() < 1e-10,
                "Snapshot MBD mismatch: batch={}, streaming={}",
                b,
                s
            );
        }
    }
}