Skip to main content

fdars_core/
streaming_depth.rs

1//! Streaming / online depth computation for functional data.
2//!
3//! This module decouples reference-set construction from query evaluation,
4//! enabling efficient depth computation in streaming scenarios:
5//!
6//! - [`SortedReferenceState`] pre-sorts reference values per time point for O(log N) rank queries.
7//! - [`StreamingMbd`] uses a rank-based combinatorial identity to compute Modified Band Depth
8//!   in O(T log N) per query instead of O(N² T).
9//! - [`StreamingFraimanMuniz`] computes Fraiman-Muniz depth via binary search on sorted columns.
10//! - [`StreamingBd`] computes Band Depth with decoupled reference and early-exit optimisation.
11//! - [`RollingReference`] maintains a sliding window of reference curves with incremental
12//!   sorted-column updates.
13
14use std::collections::VecDeque;
15
16use crate::iter_maybe_parallel;
17#[cfg(feature = "parallel")]
18use rayon::iter::ParallelIterator;
19
20// ---------------------------------------------------------------------------
21// Helper: choose-2 combinator
22// ---------------------------------------------------------------------------
23
/// Number of unordered pairs that can be drawn from `k` items, i.e. `C(k, 2)`.
///
/// Returns `0` when `k < 2`, since no pair exists.
#[inline]
fn c2(k: usize) -> usize {
    match k {
        0 | 1 => 0,
        _ => k * (k - 1) / 2,
    }
}
28
29// ===========================================================================
30// SortedReferenceState
31// ===========================================================================
32
/// Pre-sorted reference values at each time point for O(log N) rank queries.
///
/// Constructed once from a column-major reference matrix. Estimators such as
/// `StreamingMbd` take the state by value, so `Clone` is derived to let one
/// sorted reference set back several estimators without re-sorting it.
#[derive(Debug, Clone)]
pub struct SortedReferenceState {
    /// `sorted_columns[t]` contains the reference values at time point `t`, sorted ascending.
    sorted_columns: Vec<Vec<f64>>,
    /// Number of reference observations (curves).
    nori: usize,
    /// Number of evaluation points per curve.
    n_points: usize,
}
43
44impl SortedReferenceState {
45    /// Build from a column-major reference matrix.
46    ///
47    /// * `data_ori` – flat column-major data of shape `nori × n_points`
48    /// * `nori` – number of reference observations
49    /// * `n_points` – number of evaluation points
50    ///
51    /// Complexity: O(T × N log N)  (parallelised over time points).
52    pub fn from_reference(data_ori: &[f64], nori: usize, n_points: usize) -> Self {
53        if nori == 0 || n_points == 0 {
54            return Self {
55                sorted_columns: Vec::new(),
56                nori,
57                n_points,
58            };
59        }
60        let sorted_columns: Vec<Vec<f64>> = iter_maybe_parallel!(0..n_points)
61            .map(|t| {
62                let mut col: Vec<f64> = (0..nori).map(|j| data_ori[j + t * nori]).collect();
63                col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
64                col
65            })
66            .collect();
67        Self {
68            sorted_columns,
69            nori,
70            n_points,
71        }
72    }
73
74    /// Returns `(below, above)` — the count of reference values strictly below
75    /// and strictly above `x` at time point `t`.
76    ///
77    /// Complexity: O(log N) via two binary searches.
78    #[inline]
79    pub fn rank_at(&self, t: usize, x: f64) -> (usize, usize) {
80        let col = &self.sorted_columns[t];
81        let below = col.partition_point(|&v| v < x);
82        let at_or_below = col.partition_point(|&v| v <= x);
83        let above = self.nori - at_or_below;
84        (below, above)
85    }
86
87    /// Number of reference observations.
88    #[inline]
89    pub fn nori(&self) -> usize {
90        self.nori
91    }
92
93    /// Number of evaluation points.
94    #[inline]
95    pub fn n_points(&self) -> usize {
96        self.n_points
97    }
98}
99
100// ===========================================================================
101// StreamingDepth trait
102// ===========================================================================
103
/// Common interface of the streaming depth estimators in this module.
///
/// Implementors hold a reference state that was built ahead of time and
/// evaluate depth queries against it.
pub trait StreamingDepth {
    /// How many evaluation points (time points) each curve provides.
    fn n_points(&self) -> usize;

    /// How many reference curves back this estimator.
    fn n_reference(&self) -> usize;

    /// Depth of one curve, supplied as a contiguous slice of `n_points` values.
    fn depth_one(&self, curve: &[f64]) -> f64;

    /// Depths of `nobj` curves stored column-major (`nobj × n_points`) in `data_obj`.
    fn depth_batch(&self, data_obj: &[f64], nobj: usize) -> Vec<f64>;
}
118
119// ===========================================================================
120// StreamingMbd — rank-based Modified Band Depth, O(T log N) per query
121// ===========================================================================
122
123/// Rank-based Modified Band Depth estimator.
124///
125/// Uses the combinatorial identity: at time t with `b` values strictly below
126/// x(t) and `a` strictly above,
127///
128/// > pairs containing x(t) = C(N,2) − C(b,2) − C(a,2)
129///
130/// MBD(x) = (1 / (C(N,2) × T)) × Σ_t [C(N,2) − C(b_t,2) − C(a_t,2)]
131///
132/// Per-query complexity: **O(T × log N)** instead of O(N² × T).
133pub struct StreamingMbd {
134    state: SortedReferenceState,
135}
136
137impl StreamingMbd {
138    pub fn new(state: SortedReferenceState) -> Self {
139        Self { state }
140    }
141
142    /// Compute MBD for a single row-layout curve using rank formula.
143    #[inline]
144    fn mbd_one_inner(&self, curve: &[f64]) -> f64 {
145        let n = self.state.nori;
146        if n < 2 {
147            return 0.0;
148        }
149        let cn2 = c2(n);
150        let t_len = self.state.n_points;
151        let mut total = 0usize;
152        for t in 0..t_len {
153            let (below, above) = self.state.rank_at(t, curve[t]);
154            total += cn2 - c2(below) - c2(above);
155        }
156        total as f64 / (cn2 as f64 * t_len as f64)
157    }
158}
159
160impl StreamingDepth for StreamingMbd {
161    fn depth_one(&self, curve: &[f64]) -> f64 {
162        self.mbd_one_inner(curve)
163    }
164
165    fn depth_batch(&self, data_obj: &[f64], nobj: usize) -> Vec<f64> {
166        if nobj == 0 || self.state.n_points == 0 || self.state.nori < 2 {
167            return vec![0.0; nobj];
168        }
169        let n_points = self.state.n_points;
170        iter_maybe_parallel!(0..nobj)
171            .map(|i| {
172                let curve: Vec<f64> = (0..n_points).map(|t| data_obj[i + t * nobj]).collect();
173                self.mbd_one_inner(&curve)
174            })
175            .collect()
176    }
177
178    fn n_points(&self) -> usize {
179        self.state.n_points
180    }
181
182    fn n_reference(&self) -> usize {
183        self.state.nori
184    }
185}
186
187// ===========================================================================
188// StreamingFraimanMuniz — O(T log N) FM Depth
189// ===========================================================================
190
191/// Streaming Fraiman-Muniz depth estimator.
192///
193/// Uses binary search on sorted columns to compute the empirical CDF at each
194/// time point: Fn(x) = #{ref ≤ x} / N.
195///
196/// Per-query complexity: **O(T × log N)** instead of O(T × N).
197pub struct StreamingFraimanMuniz {
198    state: SortedReferenceState,
199    scale: bool,
200}
201
202impl StreamingFraimanMuniz {
203    pub fn new(state: SortedReferenceState, scale: bool) -> Self {
204        Self { state, scale }
205    }
206
207    #[inline]
208    fn fm_one_inner(&self, curve: &[f64]) -> f64 {
209        let n = self.state.nori;
210        if n == 0 {
211            return 0.0;
212        }
213        let t_len = self.state.n_points;
214        if t_len == 0 {
215            return 0.0;
216        }
217        let scale_factor = if self.scale { 2.0 } else { 1.0 };
218        let mut depth_sum = 0.0;
219        for t in 0..t_len {
220            let col = &self.state.sorted_columns[t];
221            let at_or_below = col.partition_point(|&v| v <= curve[t]);
222            let fn_x = at_or_below as f64 / n as f64;
223            depth_sum += fn_x.min(1.0 - fn_x) * scale_factor;
224        }
225        depth_sum / t_len as f64
226    }
227}
228
229impl StreamingDepth for StreamingFraimanMuniz {
230    fn depth_one(&self, curve: &[f64]) -> f64 {
231        self.fm_one_inner(curve)
232    }
233
234    fn depth_batch(&self, data_obj: &[f64], nobj: usize) -> Vec<f64> {
235        if nobj == 0 || self.state.n_points == 0 || self.state.nori == 0 {
236            return vec![0.0; nobj];
237        }
238        let n_points = self.state.n_points;
239        iter_maybe_parallel!(0..nobj)
240            .map(|i| {
241                let curve: Vec<f64> = (0..n_points).map(|t| data_obj[i + t * nobj]).collect();
242                self.fm_one_inner(&curve)
243            })
244            .collect()
245    }
246
247    fn n_points(&self) -> usize {
248        self.state.n_points
249    }
250
251    fn n_reference(&self) -> usize {
252        self.state.nori
253    }
254}
255
256// ===========================================================================
257// FullReferenceState + StreamingBd — Band Depth with decoupled reference
258// ===========================================================================
259
260/// Full reference state that keeps per-curve values alongside sorted columns.
261///
262/// Required by Band Depth (BD), which checks all-or-nothing containment across
263/// ALL time points and therefore cannot decompose into per-point rank queries.
264pub struct FullReferenceState {
265    /// Sorted columns for rank queries (shared with MBD/FM estimators if desired).
266    pub sorted: SortedReferenceState,
267    /// `values_by_curve[j][t]` = reference curve j at time point t (row layout).
268    values_by_curve: Vec<Vec<f64>>,
269}
270
271impl FullReferenceState {
272    /// Build from a column-major reference matrix.
273    pub fn from_reference(data_ori: &[f64], nori: usize, n_points: usize) -> Self {
274        let sorted = SortedReferenceState::from_reference(data_ori, nori, n_points);
275        let values_by_curve: Vec<Vec<f64>> = (0..nori)
276            .map(|j| (0..n_points).map(|t| data_ori[j + t * nori]).collect())
277            .collect();
278        Self {
279            sorted,
280            values_by_curve,
281        }
282    }
283}
284
285/// Streaming Band Depth estimator.
286///
287/// BD requires all-or-nothing containment across ALL time points — it does not
288/// decompose per-point like MBD. The streaming advantage here is **reference
289/// decoupling** (no re-parsing the matrix) and **early-exit per pair** (break
290/// on first time point where x is outside the band), not an asymptotic
291/// improvement.
292pub struct StreamingBd {
293    state: FullReferenceState,
294}
295
296impl StreamingBd {
297    pub fn new(state: FullReferenceState) -> Self {
298        Self { state }
299    }
300
301    #[inline]
302    fn bd_one_inner(&self, curve: &[f64]) -> f64 {
303        let n = self.state.sorted.nori;
304        if n < 2 {
305            return 0.0;
306        }
307        let n_pairs = c2(n);
308        let n_points = self.state.sorted.n_points;
309
310        let mut count_in_band = 0usize;
311        for j in 0..n {
312            for k in (j + 1)..n {
313                let mut inside = true;
314                for t in 0..n_points {
315                    let x_t = curve[t];
316                    let y_j_t = self.state.values_by_curve[j][t];
317                    let y_k_t = self.state.values_by_curve[k][t];
318                    let band_min = y_j_t.min(y_k_t);
319                    let band_max = y_j_t.max(y_k_t);
320                    if x_t < band_min || x_t > band_max {
321                        inside = false;
322                        break;
323                    }
324                }
325                if inside {
326                    count_in_band += 1;
327                }
328            }
329        }
330        count_in_band as f64 / n_pairs as f64
331    }
332}
333
334impl StreamingDepth for StreamingBd {
335    fn depth_one(&self, curve: &[f64]) -> f64 {
336        self.bd_one_inner(curve)
337    }
338
339    fn depth_batch(&self, data_obj: &[f64], nobj: usize) -> Vec<f64> {
340        let n = self.state.sorted.nori;
341        if nobj == 0 || self.state.sorted.n_points == 0 || n < 2 {
342            return vec![0.0; nobj];
343        }
344        let n_points = self.state.sorted.n_points;
345        iter_maybe_parallel!(0..nobj)
346            .map(|i| {
347                let curve: Vec<f64> = (0..n_points).map(|t| data_obj[i + t * nobj]).collect();
348                self.bd_one_inner(&curve)
349            })
350            .collect()
351    }
352
353    fn n_points(&self) -> usize {
354        self.state.sorted.n_points
355    }
356
357    fn n_reference(&self) -> usize {
358        self.state.sorted.nori
359    }
360}
361
362// ===========================================================================
363// RollingReference — sliding window with incremental sorted-column updates
364// ===========================================================================
365
366/// Sliding window of reference curves with incrementally maintained sorted columns.
367///
368/// When a new curve is pushed and the window is at capacity, the oldest curve
369/// is evicted. For each time point the old value is removed (binary-search +
370/// `Vec::remove`) and the new value is inserted (binary-search + `Vec::insert`).
371///
372/// Complexity per push: O(T × N) due to element shifting in the sorted vectors.
373pub struct RollingReference {
374    curves: VecDeque<Vec<f64>>,
375    capacity: usize,
376    n_points: usize,
377    sorted_columns: Vec<Vec<f64>>,
378}
379
380impl RollingReference {
381    /// Create an empty rolling window.
382    ///
383    /// * `capacity` – maximum number of curves in the window (must be ≥ 1).
384    /// * `n_points` – number of evaluation points per curve.
385    pub fn new(capacity: usize, n_points: usize) -> Self {
386        assert!(capacity >= 1, "capacity must be at least 1");
387        Self {
388            curves: VecDeque::with_capacity(capacity),
389            capacity,
390            n_points,
391            sorted_columns: (0..n_points)
392                .map(|_| Vec::with_capacity(capacity))
393                .collect(),
394        }
395    }
396
397    /// Push a new curve into the window.
398    ///
399    /// If the window is at capacity, the oldest curve is evicted and returned.
400    /// For each time point, the sorted column is updated incrementally.
401    pub fn push(&mut self, curve: &[f64]) -> Option<Vec<f64>> {
402        assert_eq!(
403            curve.len(),
404            self.n_points,
405            "curve length {} does not match n_points {}",
406            curve.len(),
407            self.n_points
408        );
409
410        let evicted = if self.curves.len() == self.capacity {
411            let old = self.curves.pop_front().unwrap();
412            // Remove old values from sorted columns
413            for t in 0..self.n_points {
414                let col = &mut self.sorted_columns[t];
415                let old_val = old[t];
416                let pos = col.partition_point(|&v| v < old_val);
417                // Find exact match (handles duplicates by scanning nearby)
418                let mut found = false;
419                for idx in pos..col.len() {
420                    if col[idx] == old_val {
421                        col.remove(idx);
422                        found = true;
423                        break;
424                    }
425                    if col[idx] > old_val {
426                        break;
427                    }
428                }
429                if !found {
430                    // Fallback: scan from pos backwards for floating-point edge cases
431                    for idx in (0..pos).rev() {
432                        if col[idx] == old_val {
433                            col.remove(idx);
434                            break;
435                        }
436                        if col[idx] < old_val {
437                            break;
438                        }
439                    }
440                }
441            }
442            Some(old)
443        } else {
444            None
445        };
446
447        // Insert new values into sorted columns
448        let new_curve: Vec<f64> = curve.to_vec();
449        for t in 0..self.n_points {
450            let col = &mut self.sorted_columns[t];
451            let val = new_curve[t];
452            let pos = col.partition_point(|&v| v < val);
453            col.insert(pos, val);
454        }
455        self.curves.push_back(new_curve);
456
457        evicted
458    }
459
460    /// Take a snapshot of the current sorted reference state.
461    ///
462    /// This clones the sorted columns. For repeated queries, prefer
463    /// [`mbd_one`](Self::mbd_one) which queries the window directly.
464    pub fn snapshot(&self) -> SortedReferenceState {
465        SortedReferenceState {
466            sorted_columns: self.sorted_columns.clone(),
467            nori: self.curves.len(),
468            n_points: self.n_points,
469        }
470    }
471
472    /// Compute rank-based MBD for a single curve directly against the current window.
473    ///
474    /// Avoids the overhead of cloning sorted columns into a snapshot.
475    pub fn mbd_one(&self, curve: &[f64]) -> f64 {
476        let n = self.curves.len();
477        if n < 2 || self.n_points == 0 {
478            return 0.0;
479        }
480        assert_eq!(
481            curve.len(),
482            self.n_points,
483            "curve length {} does not match n_points {}",
484            curve.len(),
485            self.n_points
486        );
487        let cn2 = c2(n);
488        let mut total = 0usize;
489        for t in 0..self.n_points {
490            let col = &self.sorted_columns[t];
491            let below = col.partition_point(|&v| v < curve[t]);
492            let at_or_below = col.partition_point(|&v| v <= curve[t]);
493            let above = n - at_or_below;
494            total += cn2 - c2(below) - c2(above);
495        }
496        total as f64 / (cn2 as f64 * self.n_points as f64)
497    }
498
499    /// Number of curves currently in the window.
500    #[inline]
501    pub fn len(&self) -> usize {
502        self.curves.len()
503    }
504
505    /// Whether the window is empty.
506    #[inline]
507    pub fn is_empty(&self) -> bool {
508        self.curves.is_empty()
509    }
510
511    /// Maximum capacity of the window.
512    #[inline]
513    pub fn capacity(&self) -> usize {
514        self.capacity
515    }
516}
517
518// ===========================================================================
519// Tests
520// ===========================================================================
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::depth::{band_1d, fraiman_muniz_1d, modified_band_1d};
    use std::f64::consts::PI;

    /// `n` equally spaced points on `[0, 1]`; a single point sits at `0.0`.
    fn uniform_grid(n: usize) -> Vec<f64> {
        if n < 2 {
            // With fewer than two points there is no spacing to divide by;
            // the general formula would compute 0 / 0 = NaN for n == 1.
            return vec![0.0; n];
        }
        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
    }

    /// Column-major `n × m` matrix of vertically shifted sine curves. Curve `i`
    /// is offset by `(i − n/2) / n`, so curves near `n / 2` are the most central.
    fn generate_centered_data(n: usize, m: usize) -> Vec<f64> {
        let argvals = uniform_grid(m);
        let mut data = vec![0.0; n * m];
        for i in 0..n {
            let offset = (i as f64 - n as f64 / 2.0) / (n as f64);
            for j in 0..m {
                data[i + j * n] = (2.0 * PI * argvals[j]).sin() + offset;
            }
        }
        data
    }

    /// Extract a single curve (row i) from column-major data into row layout.
    fn extract_curve(data: &[f64], i: usize, n: usize, m: usize) -> Vec<f64> {
        (0..m).map(|t| data[i + t * n]).collect()
    }

    // ============== Rank correctness ==============

    #[test]
    fn test_rank_basic() {
        // 5 reference curves, 3 time points
        // Column 0: [1, 2, 3, 4, 5]
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, // t=0
            10.0, 20.0, 30.0, 40.0, 50.0, // t=1
            100.0, 200.0, 300.0, 400.0, 500.0, // t=2
        ];
        let state = SortedReferenceState::from_reference(&data, 5, 3);

        // At t=0, x=3.0: below=2 (1,2), above=2 (4,5)
        let (below, above) = state.rank_at(0, 3.0);
        assert_eq!(below, 2);
        assert_eq!(above, 2);

        // At t=1, x=25.0: below=2 (10,20), above=3 (30,40,50)
        let (below, above) = state.rank_at(1, 25.0);
        assert_eq!(below, 2);
        assert_eq!(above, 3);
    }

    #[test]
    fn test_rank_boundary_values() {
        // All values identical
        let data = vec![5.0, 5.0, 5.0, 5.0];
        let state = SortedReferenceState::from_reference(&data, 4, 1);

        // x=5.0 exactly: none strictly below, none strictly above
        let (below, above) = state.rank_at(0, 5.0);
        assert_eq!(below, 0);
        assert_eq!(above, 0);

        // x < all: below=0, above=4
        let (below, above) = state.rank_at(0, 3.0);
        assert_eq!(below, 0);
        assert_eq!(above, 4);

        // x > all: below=4, above=0
        let (below, above) = state.rank_at(0, 7.0);
        assert_eq!(below, 4);
        assert_eq!(above, 0);
    }

    #[test]
    fn test_rank_duplicates() {
        // Values with duplicates: [1, 2, 2, 3, 3, 3]
        let data = vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0];
        let state = SortedReferenceState::from_reference(&data, 6, 1);

        // x=2.0: below=1 (just 1), above=3 (three 3s)
        let (below, above) = state.rank_at(0, 2.0);
        assert_eq!(below, 1);
        assert_eq!(above, 3);

        // x=3.0: below=3 (1,2,2), above=0
        let (below, above) = state.rank_at(0, 3.0);
        assert_eq!(below, 3);
        assert_eq!(above, 0);
    }

    // ============== Batch equivalence ==============

    #[test]
    fn test_streaming_mbd_matches_batch() {
        let n = 15;
        let m = 20;
        let data = generate_centered_data(n, m);

        let batch = modified_band_1d(&data, &data, n, n, m);
        let state = SortedReferenceState::from_reference(&data, n, m);
        let streaming = StreamingMbd::new(state);
        let streaming_result = streaming.depth_batch(&data, n);

        assert_eq!(batch.len(), streaming_result.len());
        for (b, s) in batch.iter().zip(streaming_result.iter()) {
            assert!(
                (b - s).abs() < 1e-10,
                "MBD mismatch: batch={}, streaming={}",
                b,
                s
            );
        }
    }

    #[test]
    fn test_streaming_fm_matches_batch() {
        let n = 15;
        let m = 20;
        let data = generate_centered_data(n, m);

        for scale in [true, false] {
            let batch = fraiman_muniz_1d(&data, &data, n, n, m, scale);
            let state = SortedReferenceState::from_reference(&data, n, m);
            let streaming = StreamingFraimanMuniz::new(state, scale);
            let streaming_result = streaming.depth_batch(&data, n);

            assert_eq!(batch.len(), streaming_result.len());
            for (b, s) in batch.iter().zip(streaming_result.iter()) {
                assert!(
                    (b - s).abs() < 1e-10,
                    "FM mismatch (scale={}): batch={}, streaming={}",
                    scale,
                    b,
                    s
                );
            }
        }
    }

    #[test]
    fn test_streaming_bd_matches_batch() {
        let n = 10;
        let m = 20;
        let data = generate_centered_data(n, m);

        let batch = band_1d(&data, &data, n, n, m);
        let full_state = FullReferenceState::from_reference(&data, n, m);
        let streaming = StreamingBd::new(full_state);
        let streaming_result = streaming.depth_batch(&data, n);

        assert_eq!(batch.len(), streaming_result.len());
        for (b, s) in batch.iter().zip(streaming_result.iter()) {
            assert!(
                (b - s).abs() < 1e-10,
                "BD mismatch: batch={}, streaming={}",
                b,
                s
            );
        }
    }

    // ============== Rolling reference ==============

    #[test]
    fn test_rolling_sorted_columns_maintained() {
        let mut rolling = RollingReference::new(3, 2);

        rolling.push(&[1.0, 10.0]);
        assert_eq!(rolling.sorted_columns[0], vec![1.0]);
        assert_eq!(rolling.sorted_columns[1], vec![10.0]);

        rolling.push(&[3.0, 5.0]);
        assert_eq!(rolling.sorted_columns[0], vec![1.0, 3.0]);
        assert_eq!(rolling.sorted_columns[1], vec![5.0, 10.0]);

        rolling.push(&[2.0, 7.0]);
        assert_eq!(rolling.sorted_columns[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(rolling.sorted_columns[1], vec![5.0, 7.0, 10.0]);

        // Push a 4th — evicts [1.0, 10.0]
        let evicted = rolling.push(&[0.5, 8.0]);
        assert_eq!(evicted, Some(vec![1.0, 10.0]));
        assert_eq!(rolling.sorted_columns[0], vec![0.5, 2.0, 3.0]);
        assert_eq!(rolling.sorted_columns[1], vec![5.0, 7.0, 8.0]);
    }

    #[test]
    fn test_rolling_eviction_with_duplicates() {
        // Evicting one of several equal values must remove exactly one copy.
        let mut rolling = RollingReference::new(3, 1);
        rolling.push(&[1.0]);
        rolling.push(&[1.0]);
        rolling.push(&[1.0]);

        let evicted = rolling.push(&[2.0]);
        assert_eq!(evicted, Some(vec![1.0]));
        assert_eq!(rolling.len(), 3);
        assert_eq!(rolling.sorted_columns[0], vec![1.0, 1.0, 2.0]);
    }

    #[test]
    fn test_rolling_mbd_matches_batch() {
        let n = 10;
        let m = 15;
        let data = generate_centered_data(n, m);

        // Fill a rolling window with the same curves
        let mut rolling = RollingReference::new(n, m);
        for i in 0..n {
            let curve = extract_curve(&data, i, n, m);
            rolling.push(&curve);
        }

        // mbd_one should match batch for each curve
        let batch = modified_band_1d(&data, &data, n, n, m);
        for i in 0..n {
            let curve = extract_curve(&data, i, n, m);
            let rolling_depth = rolling.mbd_one(&curve);
            assert!(
                (batch[i] - rolling_depth).abs() < 1e-10,
                "Rolling MBD mismatch at i={}: batch={}, rolling={}",
                i,
                batch[i],
                rolling_depth
            );
        }
    }

    #[test]
    fn test_rolling_mbd_after_eviction_matches_fresh_state() {
        let n = 12;
        let m = 8;
        let window = 5;
        let data = generate_centered_data(n, m);

        let mut rolling = RollingReference::new(window, m);
        for i in 0..n {
            rolling.push(&extract_curve(&data, i, n, m));
        }

        // Column-major matrix of the last `window` curves, which the window retains.
        let mut ref_data = vec![0.0; window * m];
        for (idx, i) in (n - window..n).enumerate() {
            for t in 0..m {
                ref_data[idx + t * window] = data[i + t * n];
            }
        }
        let fresh = StreamingMbd::new(SortedReferenceState::from_reference(&ref_data, window, m));

        for i in 0..n {
            let curve = extract_curve(&data, i, n, m);
            let rolling_depth = rolling.mbd_one(&curve);
            let fresh_depth = fresh.depth_one(&curve);
            assert!(
                (rolling_depth - fresh_depth).abs() < 1e-12,
                "post-eviction MBD mismatch at i={}: rolling={}, fresh={}",
                i,
                rolling_depth,
                fresh_depth
            );
        }
    }

    #[test]
    fn test_rolling_eviction_correctness() {
        let m = 5;
        let mut rolling = RollingReference::new(3, m);

        // Push 5 curves — window should only contain the last 3
        let curves: Vec<Vec<f64>> = (0..5)
            .map(|i| (0..m).map(|t| (i * m + t) as f64).collect())
            .collect();

        for c in &curves {
            rolling.push(c);
        }

        assert_eq!(rolling.len(), 3);

        // Snapshot should match manually-built state from curves 2,3,4
        let snapshot = rolling.snapshot();
        assert_eq!(snapshot.nori(), 3);

        // Build reference data manually from curves 2..5
        let mut ref_data = vec![0.0; 3 * m];
        for (idx, ci) in (2..5).enumerate() {
            for t in 0..m {
                ref_data[idx + t * 3] = curves[ci][t];
            }
        }
        let expected = SortedReferenceState::from_reference(&ref_data, 3, m);

        for t in 0..m {
            assert_eq!(
                snapshot.sorted_columns[t], expected.sorted_columns[t],
                "sorted columns differ at t={}",
                t
            );
        }
    }

    // ============== Properties ==============

    #[test]
    fn test_depth_in_unit_interval() {
        let n = 20;
        let m = 30;
        let data = generate_centered_data(n, m);

        let state_mbd = SortedReferenceState::from_reference(&data, n, m);
        let mbd = StreamingMbd::new(state_mbd);
        for d in mbd.depth_batch(&data, n) {
            assert!((0.0..=1.0).contains(&d), "MBD out of range: {}", d);
        }

        let state_fm = SortedReferenceState::from_reference(&data, n, m);
        let fm = StreamingFraimanMuniz::new(state_fm, true);
        for d in fm.depth_batch(&data, n) {
            assert!((0.0..=1.0).contains(&d), "FM out of range: {}", d);
        }

        let full = FullReferenceState::from_reference(&data, n, m);
        let bd = StreamingBd::new(full);
        for d in bd.depth_batch(&data, n) {
            assert!((0.0..=1.0).contains(&d), "BD out of range: {}", d);
        }
    }

    #[test]
    fn test_central_curves_deeper() {
        let n = 20;
        let m = 30;
        let data = generate_centered_data(n, m);

        let state = SortedReferenceState::from_reference(&data, n, m);
        let mbd = StreamingMbd::new(state);
        let depths = mbd.depth_batch(&data, n);

        let central_depth = depths[n / 2];
        let edge_depth = depths[0];
        assert!(
            central_depth > edge_depth,
            "Central curve should be deeper: {} > {}",
            central_depth,
            edge_depth
        );
    }

    #[test]
    fn test_empty_inputs() {
        let state = SortedReferenceState::from_reference(&[], 0, 0);
        let mbd = StreamingMbd::new(state);
        assert_eq!(mbd.depth_one(&[]), 0.0);

        let state = SortedReferenceState::from_reference(&[], 0, 0);
        let fm = StreamingFraimanMuniz::new(state, true);
        assert_eq!(fm.depth_one(&[]), 0.0);
    }

    #[test]
    fn test_depth_one_matches_depth_batch_single() {
        let n = 10;
        let m = 15;
        let data = generate_centered_data(n, m);

        // A 1 × m column-major matrix has exactly the layout of the curve itself.
        let curve = extract_curve(&data, 3, n, m);

        let state = SortedReferenceState::from_reference(&data, n, m);
        let mbd = StreamingMbd::new(state);

        let one = mbd.depth_one(&curve);
        let batch = mbd.depth_batch(&curve, 1);
        assert!(
            (one - batch[0]).abs() < 1e-14,
            "depth_one ({}) != depth_batch ({}) for single curve",
            one,
            batch[0]
        );
    }

    // ============== Thread safety ==============

    #[test]
    fn test_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SortedReferenceState>();
        assert_send_sync::<StreamingMbd>();
        assert_send_sync::<StreamingFraimanMuniz>();
        assert_send_sync::<FullReferenceState>();
        assert_send_sync::<StreamingBd>();
        assert_send_sync::<RollingReference>();
    }

    // ============== Edge cases ==============

    #[test]
    fn test_single_reference_curve() {
        // nori=1: C(1,2) = 0, MBD is undefined → returns 0
        let data = vec![1.0, 2.0, 3.0]; // 1 curve, 3 time points
        let state = SortedReferenceState::from_reference(&data, 1, 3);
        let mbd = StreamingMbd::new(state);
        assert_eq!(mbd.depth_one(&[1.0, 2.0, 3.0]), 0.0);

        // BD also needs at least 2
        let full = FullReferenceState::from_reference(&data, 1, 3);
        let bd = StreamingBd::new(full);
        assert_eq!(bd.depth_one(&[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn test_capacity_one_window() {
        let mut rolling = RollingReference::new(1, 3);

        rolling.push(&[1.0, 2.0, 3.0]);
        assert_eq!(rolling.len(), 1);
        // MBD with 1 curve → 0
        assert_eq!(rolling.mbd_one(&[1.0, 2.0, 3.0]), 0.0);

        let evicted = rolling.push(&[4.0, 5.0, 6.0]);
        assert_eq!(evicted, Some(vec![1.0, 2.0, 3.0]));
        assert_eq!(rolling.len(), 1);
    }

    #[test]
    #[should_panic(expected = "curve length")]
    fn test_curve_length_mismatch() {
        // `RollingReference::push` validates the curve length: a 3-point curve
        // does not fit a window configured for 2 points.
        let mut rolling = RollingReference::new(5, 2);
        rolling.push(&[1.0, 2.0, 3.0]);
    }

    // ============== Additional: snapshot-based streaming ==============

    #[test]
    fn test_rolling_snapshot_produces_valid_mbd() {
        let n = 8;
        let m = 10;
        let data = generate_centered_data(n, m);

        let mut rolling = RollingReference::new(n, m);
        for i in 0..n {
            let curve = extract_curve(&data, i, n, m);
            rolling.push(&curve);
        }

        let snapshot = rolling.snapshot();
        let mbd = StreamingMbd::new(snapshot);

        let batch_depths = modified_band_1d(&data, &data, n, n, m);
        let streaming_depths = mbd.depth_batch(&data, n);

        for (b, s) in batch_depths.iter().zip(streaming_depths.iter()) {
            assert!(
                (b - s).abs() < 1e-10,
                "Snapshot MBD mismatch: batch={}, streaming={}",
                b,
                s
            );
        }
    }
}