// irithyll_core/drift/adwin.rs

//! ADWIN (ADaptive WINdowing) drift detector.
//!
//! ADWIN maintains a variable-length window of recent values using an exponential
//! histogram for compression. It detects concept drift by comparing the means of
//! sub-windows at every bucket boundary. When a statistically significant difference
//! is found, the window is shrunk by discarding the older sub-window.
//!
//! # Algorithm
//!
//! Each new value is inserted as a bucket of order 0. When a row accumulates more
//! than `max_buckets` entries, the two oldest are merged into a single bucket at the
//! next order (capacity doubles). This gives O(log W) memory for a window of width W.
//!
//! Drift is tested by iterating bucket boundaries from newest to oldest, splitting
//! the window into a right (newer) sub-window of size `n1` and a left (older)
//! sub-window of size `n0`. The Hoeffding-based cut uses
//! `epsilon = sqrt((1 / 2m) * ln(4 / delta'))`, where `m = 1 / (1/n0 + 1/n1)` and
//! `delta' = delta / ln(width)`. Drift is declared when the two sub-window means
//! differ by at least `epsilon`; the same bound computed with `2 * delta` in place
//! of `delta` is used to raise a warning.
//!
//! # Reference
//!
//! Bifet, A. & Gavaldà, R. (2007). "Learning from Time-Changing Data with Adaptive
//! Windowing." In *Proceedings of the 2007 SIAM International Conference on Data Mining*.
23
24use alloc::boxed::Box;
25use alloc::vec;
26use alloc::vec::Vec;
27
28use super::{DriftDetector, DriftSignal};
29
30// ---------------------------------------------------------------------------
31// Bucket
32// ---------------------------------------------------------------------------
33
/// One bucket of the exponential histogram.
///
/// A bucket stored in row `i` aggregates `2^i` consecutive observations.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Bucket {
    /// Sum of the observations aggregated by this bucket.
    total: f64,
    /// Sum of squared deviations from the bucket mean (variance times count).
    variance: f64,
    /// How many original observations this bucket aggregates.
    count: u64,
}
47
48impl Bucket {
49    /// Create a leaf bucket for a single observed value.
50    #[inline]
51    fn singleton(value: f64) -> Self {
52        Self {
53            total: value,
54            variance: 0.0,
55            count: 1,
56        }
57    }
58
59    /// Merge two buckets into one, combining their sufficient statistics.
60    ///
61    /// The variance update follows the parallel variance formula:
62    /// `Var(A+B) = Var(A) + Var(B) + (mean_A - mean_B)^2 * n_A * n_B / (n_A + n_B)`
63    fn merge(a: &Bucket, b: &Bucket) -> Self {
64        let count = a.count + b.count;
65        let total = a.total + b.total;
66
67        let mean_a = a.total / a.count as f64;
68        let mean_b = b.total / b.count as f64;
69        let diff = mean_a - mean_b;
70        let variance = a.variance
71            + b.variance
72            + diff * diff * (a.count as f64) * (b.count as f64) / (count as f64);
73
74        Self {
75            total,
76            variance,
77            count,
78        }
79    }
80}
81
82// ---------------------------------------------------------------------------
83// Adwin
84// ---------------------------------------------------------------------------
85
86/// ADWIN (ADaptive WINdowing) drift detector.
87///
88/// Maintains an adaptive-size window over a stream of real-valued observations
89/// using an exponential histogram. Detects distribution shift by comparing
90/// sub-window means with a Hoeffding-style bound.
91///
92/// # Examples
93///
94/// ```
95/// use irithyll_core::drift::{DriftDetector, DriftSignal};
96/// use irithyll_core::drift::adwin::Adwin;
97///
98/// let mut det = Adwin::new();
99/// // Feed stable values
100/// for _ in 0..200 {
101///     let sig = det.update(0.0);
102///     assert_ne!(sig, DriftSignal::Drift);
103/// }
104/// ```
105#[derive(Debug, Clone)]
106#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
107pub struct Adwin {
108    /// Confidence parameter. Smaller values require stronger evidence to declare
109    /// drift. Default: 0.002.
110    delta: f64,
111    /// Maximum number of buckets allowed per row before compression triggers.
112    /// Default: 5.
113    max_buckets: usize,
114    /// Bucket rows. `rows[i]` holds buckets whose capacity is `2^i`.
115    /// Within each row the *oldest* bucket is at the front (index 0) and the
116    /// *newest* is at the back.
117    rows: Vec<Vec<Bucket>>,
118    /// Running sum of all values currently in the window.
119    total: f64,
120    /// Running variance accumulator for the full window.
121    variance: f64,
122    /// Total number of original observations in the window.
123    count: u64,
124    /// Width of the window (mirrors `count`; kept for semantic clarity).
125    width: u64,
126    /// Minimum number of observations before drift checking begins.
127    min_window: u64,
128}
129
130impl Adwin {
131    /// Create a new ADWIN detector with the default confidence `delta = 0.002`.
132    pub fn new() -> Self {
133        Self::with_delta(0.002)
134    }
135
136    /// Create a new ADWIN detector with the given confidence parameter.
137    ///
138    /// `delta` controls sensitivity: smaller values mean fewer false positives but
139    /// slower reaction to real drift. Typical range: 0.0001 to 0.1.
140    ///
141    /// # Panics
142    ///
143    /// Panics if `delta` is not in `(0, 1)`.
144    pub fn with_delta(delta: f64) -> Self {
145        assert!(
146            delta > 0.0 && delta < 1.0,
147            "delta must be in (0, 1), got {delta}"
148        );
149        Self {
150            delta,
151            max_buckets: 5,
152            rows: Vec::new(),
153            total: 0.0,
154            variance: 0.0,
155            count: 0,
156            width: 0,
157            min_window: 32,
158        }
159    }
160
161    /// Set the maximum number of buckets per row (default 5).
162    ///
163    /// Higher values use more memory but give finer granularity for split-point
164    /// testing. Must be at least 2.
165    pub fn with_max_buckets(mut self, m: usize) -> Self {
166        assert!(m >= 2, "max_buckets must be >= 2, got {m}");
167        self.max_buckets = m;
168        self
169    }
170
171    /// Set the minimum window size before drift checking begins (default 32).
172    pub fn with_min_window(mut self, min: u64) -> Self {
173        self.min_window = min;
174        self
175    }
176
177    /// Current window width (number of original observations).
178    #[inline]
179    pub fn width(&self) -> u64 {
180        self.width
181    }
182
183    // -----------------------------------------------------------------------
184    // Internal helpers
185    // -----------------------------------------------------------------------
186
187    /// Insert a new singleton bucket at row 0.
188    fn insert_bucket(&mut self, value: f64) {
189        if self.rows.is_empty() {
190            self.rows.push(Vec::new());
191        }
192        self.rows[0].push(Bucket::singleton(value));
193
194        // Update running stats using Welford's online method for variance.
195        self.count += 1;
196        self.width += 1;
197        let old_mean = if self.count > 1 {
198            (self.total) / (self.count - 1) as f64
199        } else {
200            0.0
201        };
202        self.total += value;
203        let new_mean = self.total / self.count as f64;
204        self.variance += (value - old_mean) * (value - new_mean);
205    }
206
207    /// Compress bucket rows: whenever a row has more than `max_buckets` entries,
208    /// merge the two oldest (front of the vec) into one bucket at the next row.
209    fn compress(&mut self) {
210        let max = self.max_buckets;
211        let mut row_idx = 0;
212        while row_idx < self.rows.len() {
213            if self.rows[row_idx].len() <= max {
214                break; // rows beyond this one can't have overflowed from this insertion
215            }
216            // Merge the two oldest buckets (indices 0 and 1).
217            let b1 = self.rows[row_idx].remove(0);
218            let b2 = self.rows[row_idx].remove(0);
219            let merged = Bucket::merge(&b1, &b2);
220
221            let next_row = row_idx + 1;
222            if next_row >= self.rows.len() {
223                self.rows.push(Vec::new());
224            }
225            self.rows[next_row].push(merged);
226
227            row_idx += 1;
228        }
229    }
230
231    /// Check for drift by scanning split points from newest to oldest.
232    ///
233    /// Returns `(drift_detected, warning_detected)`.
234    fn check_drift(&self) -> (bool, bool) {
235        if self.width <= self.min_window {
236            return (false, false);
237        }
238        // We need at least 2 observations on each side of a split.
239        if self.width < 4 {
240            return (false, false);
241        }
242
243        let ln_width = crate::math::ln(self.width as f64);
244        // Guard: if width is so small that ln(width) <= 0 we skip.
245        if ln_width <= 0.0 {
246            return (false, false);
247        }
248
249        let delta_prime = self.delta / ln_width;
250        let delta_warn = (2.0 * self.delta) / ln_width;
251
252        // Traverse buckets from newest to oldest, accumulating the "right"
253        // (newer) sub-window. The "left" (older) sub-window is the remainder.
254        let mut right_count: u64 = 0;
255        let mut right_total: f64 = 0.0;
256
257        let mut warning_found = false;
258
259        // Iterate rows from 0 (smallest buckets, most recent) upward, and
260        // within each row from the back (newest) to the front (oldest).
261        for row in &self.rows {
262            for bucket in row.iter().rev() {
263                right_count += bucket.count;
264                right_total += bucket.total;
265
266                let left_count = self.count - right_count;
267                if left_count < 1 || right_count < 1 {
268                    continue;
269                }
270
271                let left_total = self.total - right_total;
272
273                let mean_left = left_total / left_count as f64;
274                let mean_right = right_total / right_count as f64;
275                let abs_diff = (mean_left - mean_right).abs();
276
277                // Harmonic mean of the two sub-window sizes.
278                let n0 = left_count as f64;
279                let n1 = right_count as f64;
280                let m = 1.0 / (1.0 / n0 + 1.0 / n1);
281
282                // Drift threshold.
283                let epsilon_drift =
284                    crate::math::sqrt((1.0 / (2.0 * m)) * crate::math::ln(4.0 / delta_prime));
285
286                if abs_diff >= epsilon_drift {
287                    // No need to keep scanning -- we found the outermost drift point.
288                    return (true, true);
289                }
290
291                // Warning threshold (looser).
292                let epsilon_warn =
293                    crate::math::sqrt((1.0 / (2.0 * m)) * crate::math::ln(4.0 / delta_warn));
294                if abs_diff >= epsilon_warn {
295                    warning_found = true;
296                }
297            }
298        }
299
300        (false, warning_found)
301    }
302
303    /// Remove the oldest portion of the window up to (and including) the split
304    /// point where drift was detected. We find the split point again and remove
305    /// everything on the left side.
306    fn shrink_window(&mut self) {
307        let ln_width = crate::math::ln(self.width as f64);
308        if ln_width <= 0.0 {
309            return;
310        }
311        let delta_prime = self.delta / ln_width;
312
313        // Traverse newest-to-oldest just like check_drift, find the first
314        // split where drift is confirmed, then remove the left (older) side.
315        let mut right_count: u64 = 0;
316        let mut right_total: f64 = 0.0;
317        let mut right_variance: f64 = 0.0;
318
319        // We need to record which buckets comprise the right sub-window.
320        // Strategy: find how many buckets (from the newest end) form the right
321        // sub-window, then rebuild rows keeping only those.
322
323        // Collect all buckets in order from newest to oldest.
324        let mut all_buckets: Vec<(usize, usize)> = Vec::new(); // (row_idx, bucket_idx_in_row)
325        for (row_idx, row) in self.rows.iter().enumerate() {
326            for (bucket_idx, _) in row.iter().enumerate().rev() {
327                all_buckets.push((row_idx, bucket_idx));
328            }
329        }
330
331        let mut split_pos = all_buckets.len(); // keep everything by default
332        for (pos, &(row_idx, bucket_idx)) in all_buckets.iter().enumerate() {
333            let bucket = &self.rows[row_idx][bucket_idx];
334            // Update right sub-window variance using parallel formula.
335            if right_count > 0 {
336                let mean_right_old = right_total / right_count as f64;
337                let mean_bucket = bucket.total / bucket.count as f64;
338                let diff = mean_right_old - mean_bucket;
339                right_variance = right_variance
340                    + bucket.variance
341                    + diff * diff * (right_count as f64) * (bucket.count as f64)
342                        / (right_count + bucket.count) as f64;
343            } else {
344                right_variance = bucket.variance;
345            }
346            right_count += bucket.count;
347            right_total += bucket.total;
348
349            let left_count = self.count - right_count;
350            if left_count < 1 || right_count < 1 {
351                continue;
352            }
353            let left_total = self.total - right_total;
354            let mean_left = left_total / left_count as f64;
355            let mean_right = right_total / right_count as f64;
356            let abs_diff = (mean_left - mean_right).abs();
357
358            let n0 = left_count as f64;
359            let n1 = right_count as f64;
360            let m = 1.0 / (1.0 / n0 + 1.0 / n1);
361            let epsilon = crate::math::sqrt((1.0 / (2.0 * m)) * crate::math::ln(4.0 / delta_prime));
362
363            if abs_diff >= epsilon {
364                // The right sub-window (positions 0..=pos) is the part we keep.
365                split_pos = pos + 1;
366                break;
367            }
368        }
369
370        if split_pos >= all_buckets.len() {
371            // No split found (shouldn't happen if called after check_drift),
372            // but just in case, don't remove anything.
373            return;
374        }
375
376        // Rebuild: keep only the buckets in all_buckets[0..split_pos].
377        // These are the *right* (newer) sub-window.
378        let keep_set: Vec<(usize, usize)> = all_buckets[..split_pos].to_vec();
379
380        // Build a set for fast lookup: which (row, idx) to keep.
381        let mut new_rows: Vec<Vec<Bucket>> = Vec::new();
382        // We need to reconstruct in the proper order: within each row,
383        // oldest first (low index) to newest (high index).
384
385        // First, determine max row.
386        let max_row = self.rows.len();
387        new_rows.resize_with(max_row, Vec::new);
388
389        // Mark which buckets to keep. Since keep_set lists them newest-first,
390        // we process them in reverse to restore oldest-first order within rows.
391        let mut keep_flags: Vec<Vec<bool>> =
392            self.rows.iter().map(|row| vec![false; row.len()]).collect();
393        for &(r, b) in &keep_set {
394            keep_flags[r][b] = true;
395        }
396
397        let mut new_total: f64 = 0.0;
398        let mut new_count: u64 = 0;
399        for (row_idx, row) in self.rows.iter().enumerate() {
400            for (bucket_idx, bucket) in row.iter().enumerate() {
401                if keep_flags[row_idx][bucket_idx] {
402                    new_total += bucket.total;
403                    new_count += bucket.count;
404                    new_rows[row_idx].push(bucket.clone());
405                }
406            }
407        }
408
409        // Trim empty trailing rows.
410        while new_rows.last().is_some_and(|r| r.is_empty()) {
411            new_rows.pop();
412        }
413
414        self.rows = new_rows;
415        self.total = new_total;
416        self.count = new_count;
417        self.width = new_count;
418        // Recompute variance from the remaining buckets.
419        self.recompute_variance();
420    }
421
422    /// Recompute the aggregate variance from the bucket tree.
423    ///
424    /// This is called after shrinking the window because we cannot cheaply
425    /// subtract out the removed portion's contribution to variance.
426    fn recompute_variance(&mut self) {
427        if self.count == 0 {
428            self.variance = 0.0;
429            return;
430        }
431        // Merge all buckets' variances using the parallel formula, processing
432        // from oldest to newest.
433        let mut running_total: f64 = 0.0;
434        let mut running_count: u64 = 0;
435        let mut running_var: f64 = 0.0;
436
437        // Process in order: highest row first (oldest, largest buckets),
438        // then within each row from index 0 (oldest) upward.
439        for row in self.rows.iter().rev() {
440            for bucket in row.iter() {
441                if running_count == 0 {
442                    running_total = bucket.total;
443                    running_count = bucket.count;
444                    running_var = bucket.variance;
445                } else {
446                    let combined_count = running_count + bucket.count;
447                    let mean_running = running_total / running_count as f64;
448                    let mean_bucket = bucket.total / bucket.count as f64;
449                    let diff = mean_running - mean_bucket;
450                    running_var = running_var
451                        + bucket.variance
452                        + diff * diff * (running_count as f64) * (bucket.count as f64)
453                            / combined_count as f64;
454                    running_total += bucket.total;
455                    running_count = combined_count;
456                }
457            }
458        }
459        self.variance = running_var;
460    }
461}
462
463impl Default for Adwin {
464    fn default() -> Self {
465        Self::new()
466    }
467}
468
469// ---------------------------------------------------------------------------
470// Display
471// ---------------------------------------------------------------------------
472
473impl core::fmt::Display for Adwin {
474    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
475        write!(
476            f,
477            "Adwin(delta={}, width={}, mean={:.6})",
478            self.delta,
479            self.width,
480            self.estimated_mean()
481        )
482    }
483}
484
485// ---------------------------------------------------------------------------
486// DriftDetector trait
487// ---------------------------------------------------------------------------
488
489impl DriftDetector for Adwin {
490    fn update(&mut self, value: f64) -> DriftSignal {
491        // 1. Insert new observation.
492        self.insert_bucket(value);
493
494        // 2. Compress overflowing rows.
495        self.compress();
496
497        // 3. Check for drift.
498        let (drift, warning) = self.check_drift();
499
500        if drift {
501            // Shrink window: remove the older sub-window.
502            self.shrink_window();
503            DriftSignal::Drift
504        } else if warning {
505            DriftSignal::Warning
506        } else {
507            DriftSignal::Stable
508        }
509    }
510
511    fn reset(&mut self) {
512        self.rows.clear();
513        self.total = 0.0;
514        self.variance = 0.0;
515        self.count = 0;
516        self.width = 0;
517    }
518
519    fn clone_fresh(&self) -> Box<dyn DriftDetector> {
520        Box::new(Self::with_delta(self.delta).with_max_buckets(self.max_buckets))
521    }
522
523    fn clone_boxed(&self) -> Box<dyn DriftDetector> {
524        Box::new(self.clone())
525    }
526
527    fn estimated_mean(&self) -> f64 {
528        if self.count == 0 {
529            0.0
530        } else {
531            self.total / self.count as f64
532        }
533    }
534
535    fn serialize_state(&self) -> Option<super::DriftDetectorState> {
536        use super::{AdwinBucketState, DriftDetectorState};
537        let rows = self
538            .rows
539            .iter()
540            .map(|row| {
541                row.iter()
542                    .map(|b| AdwinBucketState {
543                        total: b.total,
544                        variance: b.variance,
545                        count: b.count,
546                    })
547                    .collect()
548            })
549            .collect();
550        Some(DriftDetectorState::Adwin {
551            rows,
552            total: self.total,
553            variance: self.variance,
554            count: self.count,
555            width: self.width,
556        })
557    }
558
559    fn restore_state(&mut self, state: &super::DriftDetectorState) -> bool {
560        if let super::DriftDetectorState::Adwin {
561            rows,
562            total,
563            variance,
564            count,
565            width,
566        } = state
567        {
568            self.rows = rows
569                .iter()
570                .map(|row| {
571                    row.iter()
572                        .map(|b| Bucket {
573                            total: b.total,
574                            variance: b.variance,
575                            count: b.count,
576                        })
577                        .collect()
578                })
579                .collect();
580            self.total = *total;
581            self.variance = *variance;
582            self.count = *count;
583            self.width = *width;
584            true
585        } else {
586            false
587        }
588    }
589}
590
591// ---------------------------------------------------------------------------
592// Tests
593// ---------------------------------------------------------------------------
594
#[cfg(test)]
mod tests {
    use super::super::{DriftDetector, DriftSignal};
    use super::*;
    use alloc::vec::Vec;

    /// Helper: simple deterministic PRNG for reproducible tests without pulling
    /// in `rand` in the test module. Uses xorshift64.
    struct Xorshift64(u64);

    impl Xorshift64 {
        fn new(seed: u64) -> Self {
            Self(if seed == 0 { 1 } else { seed })
        }

        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }

        /// Uniform f64 in [0, 1).
        fn next_f64(&mut self) -> f64 {
            (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
        }

        /// Approximate normal via Box-Muller transform.
        fn next_normal(&mut self, mean: f64, std: f64) -> f64 {
            let u1 = self.next_f64().max(1e-15); // avoid ln(0)
            let u2 = self.next_f64();
            let z = crate::math::sqrt(-2.0 * crate::math::ln(u1))
                * crate::math::cos(2.0 * core::f64::consts::PI * u2);
            mean + std * z
        }
    }

    // -----------------------------------------------------------------------
    // 1. No false alarm on constant distribution
    // -----------------------------------------------------------------------
    #[test]
    fn no_false_alarm_constant() {
        let mut det = Adwin::with_delta(0.002);
        let mut rng = Xorshift64::new(42);

        let mut drift_count = 0;
        for _ in 0..10_000 {
            let value = 0.5 + rng.next_normal(0.0, 0.01);
            if det.update(value) == DriftSignal::Drift {
                drift_count += 1;
            }
        }
        // With delta=0.002 over 10k samples we might get a rare false positive,
        // but certainly not more than a handful.
        assert!(
            drift_count <= 2,
            "Too many false alarms on constant distribution: {drift_count}"
        );
    }

    // -----------------------------------------------------------------------
    // 2. Detect abrupt shift
    // -----------------------------------------------------------------------
    #[test]
    fn detect_abrupt_shift() {
        let mut det = Adwin::with_delta(0.002);
        let mut rng = Xorshift64::new(123);

        // Phase 1: N(0, 0.1)
        for _ in 0..2000 {
            det.update(rng.next_normal(0.0, 0.1));
        }

        // Phase 2: N(5, 0.1) -- huge shift, should detect quickly.
        let mut detected = false;
        for i in 0..2000 {
            let sig = det.update(rng.next_normal(5.0, 0.1));
            if sig == DriftSignal::Drift {
                detected = true;
                // A 50-sigma jump must be caught well within 500 samples of the
                // new regime (in practice it fires almost immediately).
                assert!(
                    i < 500,
                    "Drift detected too late after abrupt shift: sample {i}"
                );
                break;
            }
        }
        assert!(detected, "Failed to detect abrupt mean shift from 0 to 5");
    }

    // -----------------------------------------------------------------------
    // 3. Detect gradual shift
    // -----------------------------------------------------------------------
    #[test]
    fn detect_gradual_shift() {
        let mut det = Adwin::with_delta(0.01); // slightly more sensitive for gradual
        let mut rng = Xorshift64::new(456);

        let mut detected = false;
        for i in 0..5000 {
            // Mean ramps linearly from 0 to 5.
            let mean = 5.0 * (i as f64) / 5000.0;
            let value = rng.next_normal(mean, 0.1);
            if det.update(value) == DriftSignal::Drift {
                detected = true;
                break;
            }
        }
        assert!(detected, "Failed to detect gradual mean shift from 0 to 5");
    }

    // -----------------------------------------------------------------------
    // 4. estimated_mean tracks recent values
    // -----------------------------------------------------------------------
    #[test]
    fn estimated_mean_tracks() {
        let mut det = Adwin::new();
        // Feed exactly 100 values of 3.0.
        for _ in 0..100 {
            det.update(3.0);
        }
        let mean = det.estimated_mean();
        assert!((mean - 3.0).abs() < 1e-9, "Expected mean ~3.0, got {mean}");

        // Feed 100 values of 7.0.
        for _ in 0..100 {
            det.update(7.0);
        }
        // Mean should be somewhere between 3 and 7 depending on window.
        let mean = det.estimated_mean();
        assert!(
            mean > 2.5 && mean < 7.5,
            "Mean out of expected range: {mean}"
        );
    }

    // -----------------------------------------------------------------------
    // 5. reset clears all state
    // -----------------------------------------------------------------------
    #[test]
    fn reset_clears_state() {
        let mut det = Adwin::new();
        for _ in 0..500 {
            det.update(1.0);
        }
        assert!(det.width() > 0);
        assert!(det.estimated_mean() > 0.0);

        det.reset();

        assert_eq!(det.width(), 0);
        assert_eq!(det.estimated_mean(), 0.0);
        assert!(det.rows.is_empty());
        assert_eq!(det.count, 0);
        assert_eq!(det.total, 0.0);
        assert_eq!(det.variance, 0.0);
    }

    // -----------------------------------------------------------------------
    // 6. clone_fresh returns clean instance with same delta
    // -----------------------------------------------------------------------
    #[test]
    fn clone_fresh_preserves_config() {
        let mut det = Adwin::with_delta(0.05).with_max_buckets(7);
        for _ in 0..200 {
            det.update(42.0);
        }

        let fresh = det.clone_fresh();
        // The fresh detector should have zero state.
        assert_eq!(fresh.estimated_mean(), 0.0);

        // We can't directly inspect the boxed detector's fields, so we verify
        // behaviour: feeding the same constant shouldn't drift.
        // (This implicitly tests that delta was preserved -- if it were near 1.0,
        // we'd get spurious drifts.)
        let mut fresh = fresh;
        let mut drifts = 0;
        for _ in 0..1000 {
            if fresh.update(1.0) == DriftSignal::Drift {
                drifts += 1;
            }
        }
        assert!(
            drifts <= 1,
            "clone_fresh produced detector with too many false alarms: {drifts}"
        );
    }

    // -----------------------------------------------------------------------
    // 7. Small window warmup: no drift check before min_window
    // -----------------------------------------------------------------------
    #[test]
    fn warmup_suppresses_early_detection() {
        let mut det = Adwin::with_delta(0.002).with_min_window(100);

        // Feed an extreme shift within the first 100 samples: 50 zeros then 50
        // very large values. Without warmup this would certainly trigger.
        let mut any_drift = false;
        for _ in 0..50 {
            if det.update(0.0) == DriftSignal::Drift {
                any_drift = true;
            }
        }
        for _ in 0..50 {
            if det.update(100.0) == DriftSignal::Drift {
                any_drift = true;
            }
        }
        assert!(
            !any_drift,
            "Drift should not fire before min_window=100 samples"
        );
    }

    // -----------------------------------------------------------------------
    // 8. Bucket compression keeps memory bounded
    // -----------------------------------------------------------------------
    #[test]
    fn compression_bounds_memory() {
        let mut det = Adwin::with_delta(0.002).with_max_buckets(5);

        for i in 0..10_000 {
            det.update(i as f64);
        }

        // `update` compresses after every insertion and shrinking only removes
        // buckets, so no row may ever exceed `max_buckets`.
        for (row_idx, row) in det.rows.iter().enumerate() {
            assert!(
                row.len() <= det.max_buckets,
                "Row {row_idx} has {} buckets, exceeding max {}",
                row.len(),
                det.max_buckets
            );
        }

        // Total bucket count should be O(log W * max_buckets).
        let total_buckets: usize = det.rows.iter().map(|r| r.len()).sum();
        let expected_max = det.rows.len() * det.max_buckets;
        assert!(
            total_buckets <= expected_max,
            "Total buckets {total_buckets} exceeds expected max {expected_max}"
        );
    }

    // -----------------------------------------------------------------------
    // 9. Window shrinks after drift
    // -----------------------------------------------------------------------
    #[test]
    fn window_shrinks_on_drift() {
        let mut det = Adwin::with_delta(0.002);

        // Build up a big window.
        for _ in 0..2000 {
            det.update(0.0);
        }
        let width_before = det.width();
        assert!(
            width_before >= 1900,
            "Expected large window, got {width_before}"
        );

        // Inject abrupt shift to force drift.
        let mut drifted = false;
        for _ in 0..500 {
            if det.update(100.0) == DriftSignal::Drift {
                drifted = true;
                break;
            }
        }
        assert!(drifted, "Expected drift on extreme shift");

        // After drift, window should be significantly smaller.
        let width_after = det.width();
        assert!(
            width_after < width_before,
            "Window should shrink after drift: before={width_before}, after={width_after}"
        );

        // The discarded part is the old regime, so the mean follows the new one.
        let mean_after = det.estimated_mean();
        assert!(
            mean_after > 50.0,
            "Window after drift should be dominated by the new regime, mean={mean_after}"
        );
    }

    // -----------------------------------------------------------------------
    // 10. Warning/drift signalling on a moderate shift
    // -----------------------------------------------------------------------
    #[test]
    fn warning_precedes_drift() {
        let mut det = Adwin::with_delta(0.002);
        let mut rng = Xorshift64::new(789);

        // Stable phase.
        for _ in 0..1000 {
            det.update(rng.next_normal(0.0, 0.1));
        }

        // Moderate shift. A warning may or may not appear before the drift,
        // depending on how quickly the drift bound is crossed, so only the
        // drift itself is asserted; the run also exercises the warning path.
        let saw_drift = (0..2000)
            .map(|_| det.update(rng.next_normal(2.0, 0.1)))
            .any(|sig| sig == DriftSignal::Drift);

        // We should have detected drift on a shift this large.
        assert!(saw_drift, "Should have detected drift on shift from 0 to 2");
    }

    // -----------------------------------------------------------------------
    // 11. Deterministic reproducibility
    // -----------------------------------------------------------------------
    #[test]
    fn deterministic_for_same_input() {
        let values: Vec<f64> = (0..500)
            .map(|i| crate::math::sin(i as f64 * 0.01))
            .collect();

        let mut det1 = Adwin::with_delta(0.01);
        let mut det2 = Adwin::with_delta(0.01);

        let signals1: Vec<DriftSignal> = values.iter().map(|&v| det1.update(v)).collect();
        let signals2: Vec<DriftSignal> = values.iter().map(|&v| det2.update(v)).collect();

        assert_eq!(
            signals1, signals2,
            "Same input must produce identical signals"
        );
    }

    // -----------------------------------------------------------------------
    // 12. serialize_state / restore_state round trip
    // -----------------------------------------------------------------------
    #[test]
    fn serialize_restore_round_trip() {
        let mut det = Adwin::with_delta(0.01);
        let mut rng = Xorshift64::new(2024);
        for _ in 0..300 {
            det.update(rng.next_normal(1.0, 0.5));
        }
        let state = det
            .serialize_state()
            .expect("Adwin always produces a state snapshot");

        let mut restored = Adwin::with_delta(0.01);
        assert!(restored.restore_state(&state), "Adwin state must be accepted");
        assert_eq!(restored.width(), det.width());
        assert_eq!(restored.estimated_mean(), det.estimated_mean());

        // From the same snapshot, both detectors must evolve identically.
        for _ in 0..200 {
            let v = rng.next_normal(3.0, 0.5);
            assert_eq!(det.update(v), restored.update(v));
        }
        assert_eq!(restored.width(), det.width());
    }
}