Skip to main content

fdars_core/
streaming_depth.rs

1//! Streaming / online depth computation for functional data.
2//!
3//! This module decouples reference-set construction from query evaluation,
4//! enabling efficient depth computation in streaming scenarios:
5//!
6//! - [`SortedReferenceState`] pre-sorts reference values per time point for O(log N) rank queries.
7//! - [`StreamingMbd`] uses a rank-based combinatorial identity to compute Modified Band Depth
8//!   in O(T log N) per query instead of O(N² T).
9//! - [`StreamingFraimanMuniz`] computes Fraiman-Muniz depth via binary search on sorted columns.
10//! - [`StreamingBd`] computes Band Depth with decoupled reference and early-exit optimisation.
11//! - [`RollingReference`] maintains a sliding window of reference curves with incremental
12//!   sorted-column updates.
13
14use std::collections::VecDeque;
15
16use crate::iter_maybe_parallel;
17use crate::matrix::FdMatrix;
18#[cfg(feature = "parallel")]
19use rayon::iter::ParallelIterator;
20
21// ---------------------------------------------------------------------------
22// Helper: choose-2 combinator
23// ---------------------------------------------------------------------------
24
/// Number of unordered pairs that can be drawn from `k` items, i.e. `C(k, 2)`.
///
/// Returns 0 for `k < 2`. The even one of the two factors `k` and `k - 1` is
/// halved *before* multiplying, so the computation only overflows when the
/// result itself does not fit in `usize` (the naive `k * (k - 1) / 2`
/// overflows one bit earlier, which is reachable on 32-bit targets).
#[inline]
fn c2(k: usize) -> usize {
    if k < 2 {
        return 0;
    }
    if k % 2 == 0 {
        (k / 2) * (k - 1)
    } else {
        k * ((k - 1) / 2)
    }
}
29
30// ===========================================================================
31// SortedReferenceState
32// ===========================================================================
33
/// Pre-sorted reference values at each time point for O(log N) rank queries.
///
/// Constructed once from a reference matrix. The streaming estimators in this
/// module take the state by value, so the type is `Clone`: build it once and
/// clone it to back several estimators with the same reference set.
#[derive(Debug, Clone)]
pub struct SortedReferenceState {
    /// `sorted_columns[t]` contains the reference values at time point `t`, sorted ascending.
    sorted_columns: Vec<Vec<f64>>,
    /// Number of reference observations (rows of the reference matrix).
    nori: usize,
    /// Number of evaluation points (columns of the reference matrix).
    n_points: usize,
}
44
45impl SortedReferenceState {
46    /// Build from a column-major reference matrix.
47    ///
48    /// * `data_ori` – reference matrix of shape `nori × n_points`
49    ///
50    /// Complexity: O(T × N log N)  (parallelised over time points).
51    pub fn from_reference(data_ori: &FdMatrix) -> Self {
52        let nori = data_ori.nrows();
53        let n_points = data_ori.ncols();
54        if nori == 0 || n_points == 0 {
55            return Self {
56                sorted_columns: Vec::new(),
57                nori,
58                n_points,
59            };
60        }
61        let sorted_columns: Vec<Vec<f64>> = iter_maybe_parallel!(0..n_points)
62            .map(|t| {
63                let mut col = data_ori.column(t).to_vec();
64                col.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
65                col
66            })
67            .collect();
68        Self {
69            sorted_columns,
70            nori,
71            n_points,
72        }
73    }
74
75    /// Returns `(below, above)` — the count of reference values strictly below
76    /// and strictly above `x` at time point `t`.
77    ///
78    /// Complexity: O(log N) via two binary searches.
79    #[inline]
80    pub fn rank_at(&self, t: usize, x: f64) -> (usize, usize) {
81        let col = &self.sorted_columns[t];
82        let below = col.partition_point(|&v| v < x);
83        let at_or_below = col.partition_point(|&v| v <= x);
84        let above = self.nori - at_or_below;
85        (below, above)
86    }
87
88    /// Number of reference observations.
89    #[inline]
90    pub fn nori(&self) -> usize {
91        self.nori
92    }
93
94    /// Number of evaluation points.
95    #[inline]
96    pub fn n_points(&self) -> usize {
97        self.n_points
98    }
99}
100
101// ===========================================================================
102// StreamingDepth trait
103// ===========================================================================
104
105/// Trait for streaming depth estimators backed by a pre-built reference state.
106pub trait StreamingDepth {
107    /// Depth of a single curve given as a contiguous `&[f64]` of length `n_points`.
108    fn depth_one(&self, curve: &[f64]) -> f64;
109
110    /// Batch depth for a matrix of query curves (`nobj × n_points`).
111    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64>;
112
113    /// Number of evaluation points.
114    fn n_points(&self) -> usize;
115
116    /// Number of reference observations backing this estimator.
117    fn n_reference(&self) -> usize;
118}
119
120// ===========================================================================
121// StreamingMbd — rank-based Modified Band Depth, O(T log N) per query
122// ===========================================================================
123
124/// Rank-based Modified Band Depth estimator.
125///
126/// Uses the combinatorial identity: at time t with `b` values strictly below
127/// x(t) and `a` strictly above,
128///
129/// > pairs containing x(t) = C(N,2) − C(b,2) − C(a,2)
130///
131/// MBD(x) = (1 / (C(N,2) × T)) × Σ_t [C(N,2) − C(b_t,2) − C(a_t,2)]
132///
133/// Per-query complexity: **O(T × log N)** instead of O(N² × T).
134pub struct StreamingMbd {
135    state: SortedReferenceState,
136}
137
138impl StreamingMbd {
139    pub fn new(state: SortedReferenceState) -> Self {
140        Self { state }
141    }
142
143    /// Compute MBD for a single row-layout curve using rank formula.
144    #[inline]
145    fn mbd_one_inner(&self, curve: &[f64]) -> f64 {
146        let n = self.state.nori;
147        if n < 2 {
148            return 0.0;
149        }
150        let cn2 = c2(n);
151        let t_len = self.state.n_points;
152        let mut total = 0usize;
153        for t in 0..t_len {
154            let (below, above) = self.state.rank_at(t, curve[t]);
155            total += cn2 - c2(below) - c2(above);
156        }
157        total as f64 / (cn2 as f64 * t_len as f64)
158    }
159
160    /// Compute MBD for row `row` of `data` without allocating a temporary Vec.
161    #[inline]
162    fn mbd_one_from_row(&self, data: &FdMatrix, row: usize) -> f64 {
163        let n = self.state.nori;
164        if n < 2 {
165            return 0.0;
166        }
167        let cn2 = c2(n);
168        let t_len = self.state.n_points;
169        let mut total = 0usize;
170        for t in 0..t_len {
171            let (below, above) = self.state.rank_at(t, data[(row, t)]);
172            total += cn2 - c2(below) - c2(above);
173        }
174        total as f64 / (cn2 as f64 * t_len as f64)
175    }
176}
177
178impl StreamingDepth for StreamingMbd {
179    fn depth_one(&self, curve: &[f64]) -> f64 {
180        self.mbd_one_inner(curve)
181    }
182
183    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64> {
184        let nobj = data_obj.nrows();
185        if nobj == 0 || self.state.n_points == 0 || self.state.nori < 2 {
186            return vec![0.0; nobj];
187        }
188        iter_maybe_parallel!(0..nobj)
189            .map(|i| self.mbd_one_from_row(data_obj, i))
190            .collect()
191    }
192
193    fn n_points(&self) -> usize {
194        self.state.n_points
195    }
196
197    fn n_reference(&self) -> usize {
198        self.state.nori
199    }
200}
201
202// ===========================================================================
203// StreamingFraimanMuniz — O(T log N) FM Depth
204// ===========================================================================
205
206/// Streaming Fraiman-Muniz depth estimator.
207///
208/// Uses binary search on sorted columns to compute the empirical CDF at each
209/// time point: Fn(x) = #{ref ≤ x} / N.
210///
211/// Per-query complexity: **O(T × log N)** instead of O(T × N).
212pub struct StreamingFraimanMuniz {
213    state: SortedReferenceState,
214    scale: bool,
215}
216
217impl StreamingFraimanMuniz {
218    pub fn new(state: SortedReferenceState, scale: bool) -> Self {
219        Self { state, scale }
220    }
221
222    #[inline]
223    fn fm_one_inner(&self, curve: &[f64]) -> f64 {
224        let n = self.state.nori;
225        if n == 0 {
226            return 0.0;
227        }
228        let t_len = self.state.n_points;
229        if t_len == 0 {
230            return 0.0;
231        }
232        let scale_factor = if self.scale { 2.0 } else { 1.0 };
233        let mut depth_sum = 0.0;
234        for t in 0..t_len {
235            let col = &self.state.sorted_columns[t];
236            let at_or_below = col.partition_point(|&v| v <= curve[t]);
237            let fn_x = at_or_below as f64 / n as f64;
238            depth_sum += fn_x.min(1.0 - fn_x) * scale_factor;
239        }
240        depth_sum / t_len as f64
241    }
242
243    /// Compute FM depth for row `row` of `data` without allocating a temporary Vec.
244    #[inline]
245    fn fm_one_from_row(&self, data: &FdMatrix, row: usize) -> f64 {
246        let n = self.state.nori;
247        if n == 0 {
248            return 0.0;
249        }
250        let t_len = self.state.n_points;
251        if t_len == 0 {
252            return 0.0;
253        }
254        let scale_factor = if self.scale { 2.0 } else { 1.0 };
255        let mut depth_sum = 0.0;
256        for t in 0..t_len {
257            let col = &self.state.sorted_columns[t];
258            let at_or_below = col.partition_point(|&v| v <= data[(row, t)]);
259            let fn_x = at_or_below as f64 / n as f64;
260            depth_sum += fn_x.min(1.0 - fn_x) * scale_factor;
261        }
262        depth_sum / t_len as f64
263    }
264}
265
266impl StreamingDepth for StreamingFraimanMuniz {
267    fn depth_one(&self, curve: &[f64]) -> f64 {
268        self.fm_one_inner(curve)
269    }
270
271    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64> {
272        let nobj = data_obj.nrows();
273        if nobj == 0 || self.state.n_points == 0 || self.state.nori == 0 {
274            return vec![0.0; nobj];
275        }
276        iter_maybe_parallel!(0..nobj)
277            .map(|i| self.fm_one_from_row(data_obj, i))
278            .collect()
279    }
280
281    fn n_points(&self) -> usize {
282        self.state.n_points
283    }
284
285    fn n_reference(&self) -> usize {
286        self.state.nori
287    }
288}
289
290// ===========================================================================
291// FullReferenceState + StreamingBd — Band Depth with decoupled reference
292// ===========================================================================
293
294/// Full reference state that keeps per-curve values alongside sorted columns.
295///
296/// Required by Band Depth (BD), which checks all-or-nothing containment across
297/// ALL time points and therefore cannot decompose into per-point rank queries.
298pub struct FullReferenceState {
299    /// Sorted columns for rank queries (shared with MBD/FM estimators if desired).
300    pub sorted: SortedReferenceState,
301    /// `values_by_curve[j][t]` = reference curve j at time point t (row layout).
302    values_by_curve: Vec<Vec<f64>>,
303}
304
305impl FullReferenceState {
306    /// Build from a column-major reference matrix.
307    pub fn from_reference(data_ori: &FdMatrix) -> Self {
308        let nori = data_ori.nrows();
309        let n_points = data_ori.ncols();
310        let sorted = SortedReferenceState::from_reference(data_ori);
311        let values_by_curve: Vec<Vec<f64>> = (0..nori)
312            .map(|j| (0..n_points).map(|t| data_ori[(j, t)]).collect())
313            .collect();
314        Self {
315            sorted,
316            values_by_curve,
317        }
318    }
319}
320
321/// Streaming Band Depth estimator.
322///
323/// BD requires all-or-nothing containment across ALL time points — it does not
324/// decompose per-point like MBD. The streaming advantage here is **reference
325/// decoupling** (no re-parsing the matrix) and **early-exit per pair** (break
326/// on first time point where x is outside the band), not an asymptotic
327/// improvement.
328pub struct StreamingBd {
329    state: FullReferenceState,
330}
331
332impl StreamingBd {
333    pub fn new(state: FullReferenceState) -> Self {
334        Self { state }
335    }
336
337    #[inline]
338    fn bd_one_inner(&self, curve: &[f64]) -> f64 {
339        let n = self.state.sorted.nori;
340        if n < 2 {
341            return 0.0;
342        }
343        let n_pairs = c2(n);
344        let n_points = self.state.sorted.n_points;
345
346        let mut count_in_band = 0usize;
347        for j in 0..n {
348            for k in (j + 1)..n {
349                let mut inside = true;
350                for t in 0..n_points {
351                    let x_t = curve[t];
352                    let y_j_t = self.state.values_by_curve[j][t];
353                    let y_k_t = self.state.values_by_curve[k][t];
354                    let band_min = y_j_t.min(y_k_t);
355                    let band_max = y_j_t.max(y_k_t);
356                    if x_t < band_min || x_t > band_max {
357                        inside = false;
358                        break;
359                    }
360                }
361                if inside {
362                    count_in_band += 1;
363                }
364            }
365        }
366        count_in_band as f64 / n_pairs as f64
367    }
368
369    /// Compute BD for row `row` of `data` without allocating a temporary Vec.
370    #[inline]
371    fn bd_one_from_row(&self, data: &FdMatrix, row: usize) -> f64 {
372        let n = self.state.sorted.nori;
373        if n < 2 {
374            return 0.0;
375        }
376        let n_pairs = c2(n);
377        let n_points = self.state.sorted.n_points;
378
379        let mut count_in_band = 0usize;
380        for j in 0..n {
381            for k in (j + 1)..n {
382                let mut inside = true;
383                for t in 0..n_points {
384                    let x_t = data[(row, t)];
385                    let y_j_t = self.state.values_by_curve[j][t];
386                    let y_k_t = self.state.values_by_curve[k][t];
387                    let band_min = y_j_t.min(y_k_t);
388                    let band_max = y_j_t.max(y_k_t);
389                    if x_t < band_min || x_t > band_max {
390                        inside = false;
391                        break;
392                    }
393                }
394                if inside {
395                    count_in_band += 1;
396                }
397            }
398        }
399        count_in_band as f64 / n_pairs as f64
400    }
401}
402
403impl StreamingDepth for StreamingBd {
404    fn depth_one(&self, curve: &[f64]) -> f64 {
405        self.bd_one_inner(curve)
406    }
407
408    fn depth_batch(&self, data_obj: &FdMatrix) -> Vec<f64> {
409        let nobj = data_obj.nrows();
410        let n = self.state.sorted.nori;
411        if nobj == 0 || self.state.sorted.n_points == 0 || n < 2 {
412            return vec![0.0; nobj];
413        }
414        iter_maybe_parallel!(0..nobj)
415            .map(|i| self.bd_one_from_row(data_obj, i))
416            .collect()
417    }
418
419    fn n_points(&self) -> usize {
420        self.state.sorted.n_points
421    }
422
423    fn n_reference(&self) -> usize {
424        self.state.sorted.nori
425    }
426}
427
428// ===========================================================================
429// RollingReference — sliding window with incremental sorted-column updates
430// ===========================================================================
431
432/// Sliding window of reference curves with incrementally maintained sorted columns.
433///
434/// When a new curve is pushed and the window is at capacity, the oldest curve
435/// is evicted. For each time point the old value is removed (binary-search +
436/// `Vec::remove`) and the new value is inserted (binary-search + `Vec::insert`).
437///
438/// Complexity per push: O(T × N) due to element shifting in the sorted vectors.
439pub struct RollingReference {
440    curves: VecDeque<Vec<f64>>,
441    capacity: usize,
442    n_points: usize,
443    sorted_columns: Vec<Vec<f64>>,
444}
445
446impl RollingReference {
447    /// Create an empty rolling window.
448    ///
449    /// * `capacity` – maximum number of curves in the window (must be ≥ 1).
450    /// * `n_points` – number of evaluation points per curve.
451    pub fn new(capacity: usize, n_points: usize) -> Self {
452        assert!(capacity >= 1, "capacity must be at least 1");
453        Self {
454            curves: VecDeque::with_capacity(capacity),
455            capacity,
456            n_points,
457            sorted_columns: (0..n_points)
458                .map(|_| Vec::with_capacity(capacity))
459                .collect(),
460        }
461    }
462
463    /// Push a new curve into the window.
464    ///
465    /// If the window is at capacity, the oldest curve is evicted and returned.
466    /// For each time point, the sorted column is updated incrementally.
467    pub fn push(&mut self, curve: &[f64]) -> Option<Vec<f64>> {
468        assert_eq!(
469            curve.len(),
470            self.n_points,
471            "curve length {} does not match n_points {}",
472            curve.len(),
473            self.n_points
474        );
475
476        let evicted = if self.curves.len() == self.capacity {
477            let old = self.curves.pop_front().unwrap();
478            // Remove old values from sorted columns
479            for t in 0..self.n_points {
480                let col = &mut self.sorted_columns[t];
481                let old_val = old[t];
482                let pos = col.partition_point(|&v| v < old_val);
483                // Find exact match (handles duplicates by scanning nearby)
484                let mut found = false;
485                for idx in pos..col.len() {
486                    if col[idx] == old_val {
487                        col.remove(idx);
488                        found = true;
489                        break;
490                    }
491                    if col[idx] > old_val {
492                        break;
493                    }
494                }
495                if !found {
496                    // Fallback: scan from pos backwards for floating-point edge cases
497                    for idx in (0..pos).rev() {
498                        if col[idx] == old_val {
499                            col.remove(idx);
500                            break;
501                        }
502                        if col[idx] < old_val {
503                            break;
504                        }
505                    }
506                }
507            }
508            Some(old)
509        } else {
510            None
511        };
512
513        // Insert new values into sorted columns
514        let new_curve: Vec<f64> = curve.to_vec();
515        for t in 0..self.n_points {
516            let col = &mut self.sorted_columns[t];
517            let val = new_curve[t];
518            let pos = col.partition_point(|&v| v < val);
519            col.insert(pos, val);
520        }
521        self.curves.push_back(new_curve);
522
523        evicted
524    }
525
526    /// Take a snapshot of the current sorted reference state.
527    ///
528    /// This clones the sorted columns. For repeated queries, prefer
529    /// [`mbd_one`](Self::mbd_one) which queries the window directly.
530    pub fn snapshot(&self) -> SortedReferenceState {
531        SortedReferenceState {
532            sorted_columns: self.sorted_columns.clone(),
533            nori: self.curves.len(),
534            n_points: self.n_points,
535        }
536    }
537
538    /// Compute rank-based MBD for a single curve directly against the current window.
539    ///
540    /// Avoids the overhead of cloning sorted columns into a snapshot.
541    pub fn mbd_one(&self, curve: &[f64]) -> f64 {
542        let n = self.curves.len();
543        if n < 2 || self.n_points == 0 {
544            return 0.0;
545        }
546        assert_eq!(
547            curve.len(),
548            self.n_points,
549            "curve length {} does not match n_points {}",
550            curve.len(),
551            self.n_points
552        );
553        let cn2 = c2(n);
554        let mut total = 0usize;
555        for t in 0..self.n_points {
556            let col = &self.sorted_columns[t];
557            let below = col.partition_point(|&v| v < curve[t]);
558            let at_or_below = col.partition_point(|&v| v <= curve[t]);
559            let above = n - at_or_below;
560            total += cn2 - c2(below) - c2(above);
561        }
562        total as f64 / (cn2 as f64 * self.n_points as f64)
563    }
564
565    /// Number of curves currently in the window.
566    #[inline]
567    pub fn len(&self) -> usize {
568        self.curves.len()
569    }
570
571    /// Whether the window is empty.
572    #[inline]
573    pub fn is_empty(&self) -> bool {
574        self.curves.is_empty()
575    }
576
577    /// Maximum capacity of the window.
578    #[inline]
579    pub fn capacity(&self) -> usize {
580        self.capacity
581    }
582}
583
584// ===========================================================================
585// Tests
586// ===========================================================================
587
588#[cfg(test)]
589mod tests {
590    use super::*;
591    use crate::depth::{band_1d, fraiman_muniz_1d, modified_band_1d};
592    use crate::matrix::FdMatrix;
593    use std::f64::consts::PI;
594
595    fn uniform_grid(n: usize) -> Vec<f64> {
596        (0..n).map(|i| i as f64 / (n - 1) as f64).collect()
597    }
598
599    fn generate_centered_data(n: usize, m: usize) -> Vec<f64> {
600        let argvals = uniform_grid(m);
601        let mut data = vec![0.0; n * m];
602        for i in 0..n {
603            let offset = (i as f64 - n as f64 / 2.0) / (n as f64);
604            for j in 0..m {
605                data[i + j * n] = (2.0 * PI * argvals[j]).sin() + offset;
606            }
607        }
608        data
609    }
610
611    /// Extract a single curve (row i) from column-major data into row layout.
612    fn extract_curve(data: &[f64], i: usize, n: usize, m: usize) -> Vec<f64> {
613        (0..m).map(|t| data[i + t * n]).collect()
614    }
615
616    // ============== Rank correctness ==============
617
618    #[test]
619    fn test_rank_basic() {
620        // 5 reference curves, 3 time points
621        // Column 0: [1, 2, 3, 4, 5]
622        let data = vec![
623            1.0, 2.0, 3.0, 4.0, 5.0, // t=0
624            10.0, 20.0, 30.0, 40.0, 50.0, // t=1
625            100.0, 200.0, 300.0, 400.0, 500.0, // t=2
626        ];
627        let mat = FdMatrix::from_column_major(data, 5, 3).unwrap();
628        let state = SortedReferenceState::from_reference(&mat);
629
630        // At t=0, x=3.0: below=2 (1,2), above=2 (4,5)
631        let (below, above) = state.rank_at(0, 3.0);
632        assert_eq!(below, 2);
633        assert_eq!(above, 2);
634
635        // At t=1, x=25.0: below=2 (10,20), above=3 (30,40,50)
636        let (below, above) = state.rank_at(1, 25.0);
637        assert_eq!(below, 2);
638        assert_eq!(above, 3);
639    }
640
641    #[test]
642    fn test_rank_boundary_values() {
643        // All values identical
644        let data = vec![5.0, 5.0, 5.0, 5.0];
645        let mat = FdMatrix::from_column_major(data, 4, 1).unwrap();
646        let state = SortedReferenceState::from_reference(&mat);
647
648        // x=5.0 exactly: none strictly below, none strictly above
649        let (below, above) = state.rank_at(0, 5.0);
650        assert_eq!(below, 0);
651        assert_eq!(above, 0);
652
653        // x < all: below=0, above=4
654        let (below, above) = state.rank_at(0, 3.0);
655        assert_eq!(below, 0);
656        assert_eq!(above, 4);
657
658        // x > all: below=4, above=0
659        let (below, above) = state.rank_at(0, 7.0);
660        assert_eq!(below, 4);
661        assert_eq!(above, 0);
662    }
663
664    #[test]
665    fn test_rank_duplicates() {
666        // Values with duplicates: [1, 2, 2, 3, 3, 3]
667        let data = vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0];
668        let mat = FdMatrix::from_column_major(data, 6, 1).unwrap();
669        let state = SortedReferenceState::from_reference(&mat);
670
671        // x=2.0: below=1 (just 1), above=3 (three 3s)
672        let (below, above) = state.rank_at(0, 2.0);
673        assert_eq!(below, 1);
674        assert_eq!(above, 3);
675
676        // x=3.0: below=3 (1,2,2), above=0
677        let (below, above) = state.rank_at(0, 3.0);
678        assert_eq!(below, 3);
679        assert_eq!(above, 0);
680    }
681
682    // ============== Batch equivalence ==============
683
684    #[test]
685    fn test_streaming_mbd_matches_batch() {
686        let n = 15;
687        let m = 20;
688        let data = generate_centered_data(n, m);
689
690        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
691        let batch = modified_band_1d(&mat, &mat);
692        let state = SortedReferenceState::from_reference(&mat);
693        let streaming = StreamingMbd::new(state);
694        let streaming_result = streaming.depth_batch(&mat);
695
696        assert_eq!(batch.len(), streaming_result.len());
697        for (b, s) in batch.iter().zip(streaming_result.iter()) {
698            assert!(
699                (b - s).abs() < 1e-10,
700                "MBD mismatch: batch={}, streaming={}",
701                b,
702                s
703            );
704        }
705    }
706
707    #[test]
708    fn test_streaming_fm_matches_batch() {
709        let n = 15;
710        let m = 20;
711        let data = generate_centered_data(n, m);
712
713        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
714        for scale in [true, false] {
715            let batch = fraiman_muniz_1d(&mat, &mat, scale);
716            let state = SortedReferenceState::from_reference(&mat);
717            let streaming = StreamingFraimanMuniz::new(state, scale);
718            let streaming_result = streaming.depth_batch(&mat);
719
720            assert_eq!(batch.len(), streaming_result.len());
721            for (b, s) in batch.iter().zip(streaming_result.iter()) {
722                assert!(
723                    (b - s).abs() < 1e-10,
724                    "FM mismatch (scale={}): batch={}, streaming={}",
725                    scale,
726                    b,
727                    s
728                );
729            }
730        }
731    }
732
733    #[test]
734    fn test_streaming_bd_matches_batch() {
735        let n = 10;
736        let m = 20;
737        let data = generate_centered_data(n, m);
738
739        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
740        let batch = band_1d(&mat, &mat);
741        let full_state = FullReferenceState::from_reference(&mat);
742        let streaming = StreamingBd::new(full_state);
743        let streaming_result = streaming.depth_batch(&mat);
744
745        assert_eq!(batch.len(), streaming_result.len());
746        for (b, s) in batch.iter().zip(streaming_result.iter()) {
747            assert!(
748                (b - s).abs() < 1e-10,
749                "BD mismatch: batch={}, streaming={}",
750                b,
751                s
752            );
753        }
754    }
755
756    // ============== Rolling reference ==============
757
758    #[test]
759    fn test_rolling_sorted_columns_maintained() {
760        let mut rolling = RollingReference::new(3, 2);
761
762        rolling.push(&[1.0, 10.0]);
763        assert_eq!(rolling.sorted_columns[0], vec![1.0]);
764        assert_eq!(rolling.sorted_columns[1], vec![10.0]);
765
766        rolling.push(&[3.0, 5.0]);
767        assert_eq!(rolling.sorted_columns[0], vec![1.0, 3.0]);
768        assert_eq!(rolling.sorted_columns[1], vec![5.0, 10.0]);
769
770        rolling.push(&[2.0, 7.0]);
771        assert_eq!(rolling.sorted_columns[0], vec![1.0, 2.0, 3.0]);
772        assert_eq!(rolling.sorted_columns[1], vec![5.0, 7.0, 10.0]);
773
774        // Push a 4th — evicts [1.0, 10.0]
775        let evicted = rolling.push(&[0.5, 8.0]);
776        assert_eq!(evicted, Some(vec![1.0, 10.0]));
777        assert_eq!(rolling.sorted_columns[0], vec![0.5, 2.0, 3.0]);
778        assert_eq!(rolling.sorted_columns[1], vec![5.0, 7.0, 8.0]);
779    }
780
781    #[test]
782    fn test_rolling_mbd_matches_batch() {
783        let n = 10;
784        let m = 15;
785        let data = generate_centered_data(n, m);
786
787        // Fill a rolling window with the same curves
788        let mut rolling = RollingReference::new(n, m);
789        for i in 0..n {
790            let curve = extract_curve(&data, i, n, m);
791            rolling.push(&curve);
792        }
793
794        // mbd_one should match batch for each curve
795        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
796        let batch = modified_band_1d(&mat, &mat);
797        for i in 0..n {
798            let curve = extract_curve(&data, i, n, m);
799            let rolling_depth = rolling.mbd_one(&curve);
800            assert!(
801                (batch[i] - rolling_depth).abs() < 1e-10,
802                "Rolling MBD mismatch at i={}: batch={}, rolling={}",
803                i,
804                batch[i],
805                rolling_depth
806            );
807        }
808    }
809
810    #[test]
811    fn test_rolling_eviction_correctness() {
812        let m = 5;
813        let mut rolling = RollingReference::new(3, m);
814
815        // Push 5 curves — window should only contain the last 3
816        let curves: Vec<Vec<f64>> = (0..5)
817            .map(|i| (0..m).map(|t| (i * m + t) as f64).collect())
818            .collect();
819
820        for c in &curves {
821            rolling.push(c);
822        }
823
824        assert_eq!(rolling.len(), 3);
825
826        // Snapshot should match manually-built state from curves 2,3,4
827        let snapshot = rolling.snapshot();
828        assert_eq!(snapshot.nori(), 3);
829
830        // Build reference data manually from curves 2..5
831        let mut ref_data = vec![0.0; 3 * m];
832        for (idx, ci) in (2..5).enumerate() {
833            for t in 0..m {
834                ref_data[idx + t * 3] = curves[ci][t];
835            }
836        }
837        let ref_mat = FdMatrix::from_column_major(ref_data, 3, m).unwrap();
838        let expected = SortedReferenceState::from_reference(&ref_mat);
839
840        for t in 0..m {
841            assert_eq!(
842                snapshot.sorted_columns[t], expected.sorted_columns[t],
843                "sorted columns differ at t={}",
844                t
845            );
846        }
847    }
848
849    // ============== Properties ==============
850
851    #[test]
852    fn test_depth_in_unit_interval() {
853        let n = 20;
854        let m = 30;
855        let data = generate_centered_data(n, m);
856        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
857
858        let state_mbd = SortedReferenceState::from_reference(&mat);
859        let mbd = StreamingMbd::new(state_mbd);
860        for d in mbd.depth_batch(&mat) {
861            assert!((0.0..=1.0).contains(&d), "MBD out of range: {}", d);
862        }
863
864        let state_fm = SortedReferenceState::from_reference(&mat);
865        let fm = StreamingFraimanMuniz::new(state_fm, true);
866        for d in fm.depth_batch(&mat) {
867            assert!((0.0..=1.0).contains(&d), "FM out of range: {}", d);
868        }
869
870        let full = FullReferenceState::from_reference(&mat);
871        let bd = StreamingBd::new(full);
872        for d in bd.depth_batch(&mat) {
873            assert!((0.0..=1.0).contains(&d), "BD out of range: {}", d);
874        }
875    }
876
877    #[test]
878    fn test_central_curves_deeper() {
879        let n = 20;
880        let m = 30;
881        let data = generate_centered_data(n, m);
882        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
883
884        let state = SortedReferenceState::from_reference(&mat);
885        let mbd = StreamingMbd::new(state);
886        let depths = mbd.depth_batch(&mat);
887
888        let central_depth = depths[n / 2];
889        let edge_depth = depths[0];
890        assert!(
891            central_depth > edge_depth,
892            "Central curve should be deeper: {} > {}",
893            central_depth,
894            edge_depth
895        );
896    }
897
898    #[test]
899    fn test_empty_inputs() {
900        let empty = FdMatrix::zeros(0, 0);
901        let state = SortedReferenceState::from_reference(&empty);
902        let mbd = StreamingMbd::new(state);
903        assert_eq!(mbd.depth_one(&[]), 0.0);
904
905        let state = SortedReferenceState::from_reference(&empty);
906        let fm = StreamingFraimanMuniz::new(state, true);
907        assert_eq!(fm.depth_one(&[]), 0.0);
908    }
909
910    #[test]
911    fn test_depth_one_matches_depth_batch_single() {
912        let n = 10;
913        let m = 15;
914        let data = generate_centered_data(n, m);
915        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
916
917        // Build a 1-curve column-major "matrix" from curve 3
918        let curve = extract_curve(&data, 3, n, m);
919        let single_mat = FdMatrix::from_column_major(curve.clone(), 1, m).unwrap();
920
921        let state = SortedReferenceState::from_reference(&mat);
922        let mbd = StreamingMbd::new(state);
923
924        let one = mbd.depth_one(&curve);
925        let batch = mbd.depth_batch(&single_mat);
926        assert!(
927            (one - batch[0]).abs() < 1e-14,
928            "depth_one ({}) != depth_batch ({}) for single curve",
929            one,
930            batch[0]
931        );
932    }
933
934    // ============== Thread safety ==============
935
936    #[test]
937    fn test_send_sync() {
938        fn assert_send_sync<T: Send + Sync>() {}
939        assert_send_sync::<SortedReferenceState>();
940        assert_send_sync::<StreamingMbd>();
941        assert_send_sync::<StreamingFraimanMuniz>();
942        assert_send_sync::<FullReferenceState>();
943        assert_send_sync::<StreamingBd>();
944        assert_send_sync::<RollingReference>();
945    }
946
947    // ============== Edge cases ==============
948
949    #[test]
950    fn test_single_reference_curve() {
951        // nori=1: C(1,2) = 0, MBD is undefined → returns 0
952        let data = vec![1.0, 2.0, 3.0]; // 1 curve, 3 time points
953        let mat = FdMatrix::from_column_major(data, 1, 3).unwrap();
954        let state = SortedReferenceState::from_reference(&mat);
955        let mbd = StreamingMbd::new(state);
956        assert_eq!(mbd.depth_one(&[1.0, 2.0, 3.0]), 0.0);
957
958        // BD also needs at least 2
959        let full = FullReferenceState::from_reference(&mat);
960        let bd = StreamingBd::new(full);
961        assert_eq!(bd.depth_one(&[1.0, 2.0, 3.0]), 0.0);
962    }
963
964    #[test]
965    fn test_capacity_one_window() {
966        let mut rolling = RollingReference::new(1, 3);
967
968        rolling.push(&[1.0, 2.0, 3.0]);
969        assert_eq!(rolling.len(), 1);
970        // MBD with 1 curve → 0
971        assert_eq!(rolling.mbd_one(&[1.0, 2.0, 3.0]), 0.0);
972
973        let evicted = rolling.push(&[4.0, 5.0, 6.0]);
974        assert_eq!(evicted, Some(vec![1.0, 2.0, 3.0]));
975        assert_eq!(rolling.len(), 1);
976    }
977
978    #[test]
979    #[should_panic(expected = "curve length")]
980    fn test_curve_length_mismatch() {
981        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
982        let state = SortedReferenceState::from_reference(&mat);
983        let mbd = StreamingMbd::new(state);
984        // Curve has 3 elements but n_points is 2 — should ideally be caught.
985        // depth_one doesn't assert length (it just indexes), but rolling does.
986        let mut rolling = RollingReference::new(5, 2);
987        rolling.push(&[1.0, 2.0, 3.0]); // panics: length mismatch
988        let _ = mbd; // suppress unused warning
989    }
990
991    // ============== Additional: snapshot-based streaming ==============
992
993    #[test]
994    fn test_rolling_snapshot_produces_valid_mbd() {
995        let n = 8;
996        let m = 10;
997        let data = generate_centered_data(n, m);
998
999        let mut rolling = RollingReference::new(n, m);
1000        for i in 0..n {
1001            let curve = extract_curve(&data, i, n, m);
1002            rolling.push(&curve);
1003        }
1004
1005        let snapshot = rolling.snapshot();
1006        let mbd = StreamingMbd::new(snapshot);
1007
1008        let mat = FdMatrix::from_slice(&data, n, m).unwrap();
1009        let batch_depths = modified_band_1d(&mat, &mat);
1010        let streaming_depths = mbd.depth_batch(&mat);
1011
1012        for (b, s) in batch_depths.iter().zip(streaming_depths.iter()) {
1013            assert!(
1014                (b - s).abs() < 1e-10,
1015                "Snapshot MBD mismatch: batch={}, streaming={}",
1016                b,
1017                s
1018            );
1019        }
1020    }
1021}