1use serde::{Deserialize, Serialize};
7use std::collections::VecDeque;
8
9use crate::errors::{DecisionError, Result};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum DriftStatus {
14 Stable,
16 Warning,
18 Drift,
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24pub enum DriftAlgorithm {
25 ADWIN,
27 PageHinkley,
29 CUSUM,
31 Statistical,
33}
34
/// Sliding-window drift detector that compares the two halves of every
/// candidate split of its buffer.
///
/// `Debug` and `Clone` are derived so detectors can be logged and snapshotted.
#[derive(Debug, Clone)]
pub struct ADWIN {
    /// Confidence parameter; `ADWIN::new` only accepts values inside `(0, 1)`.
    delta: f64,
    /// Most recent observations, oldest at the front.
    window: VecDeque<f64>,
    /// Running sum of the values currently in `window`.
    sum: f64,
    /// Running sum of the squared values currently in `window`.
    sum_squares: f64,
    /// Length at which the oldest observation is evicted before a new one is stored.
    max_window_size: usize,
    /// Whether the most recent `add` call reported `DriftStatus::Drift`.
    drift_detected: bool,
}
52
53impl ADWIN {
54 pub fn new(delta: f64, max_window_size: usize) -> Result<Self> {
56 if delta <= 0.0 || delta >= 1.0 {
57 return Err(DecisionError::InvalidParameter(
58 "Delta must be in (0, 1)".to_string(),
59 ));
60 }
61
62 Ok(Self {
63 delta,
64 window: VecDeque::with_capacity(max_window_size),
65 sum: 0.0,
66 sum_squares: 0.0,
67 max_window_size,
68 drift_detected: false,
69 })
70 }
71
72 pub fn add(&mut self, value: f64) -> DriftStatus {
74 self.drift_detected = false;
75
76 if self.window.len() >= self.max_window_size {
78 if let Some(old) = self.window.pop_front() {
79 self.sum -= old;
80 self.sum_squares -= old * old;
81 }
82 }
83
84 self.window.push_back(value);
85 self.sum += value;
86 self.sum_squares += value * value;
87
88 if self.detect_change() {
90 self.drift_detected = true;
91 DriftStatus::Drift
92 } else if self.window.len() > 10 && self.is_warning() {
93 DriftStatus::Warning
94 } else {
95 DriftStatus::Stable
96 }
97 }
98
99 fn detect_change(&self) -> bool {
101 let n = self.window.len();
102 if n < 10 {
103 return false;
104 }
105
106 for cut in n / 4..=3 * n / 4 {
108 if self.test_split(cut) {
109 return true;
110 }
111 }
112
113 false
114 }
115
116 fn test_split(&self, cut: usize) -> bool {
118 let n = self.window.len();
119
120 let mut sum1 = 0.0;
122 let mut sum_sq1 = 0.0;
123 let mut sum2 = 0.0;
124 let mut sum_sq2 = 0.0;
125
126 for (i, &val) in self.window.iter().enumerate() {
127 if i < cut {
128 sum1 += val;
129 sum_sq1 += val * val;
130 } else {
131 sum2 += val;
132 sum_sq2 += val * val;
133 }
134 }
135
136 let n1 = cut as f64;
137 let n2 = (n - cut) as f64;
138
139 if n1 == 0.0 || n2 == 0.0 {
140 return false;
141 }
142
143 let mean1 = sum1 / n1;
144 let mean2 = sum2 / n2;
145
146 let var1 = (sum_sq1 / n1) - (mean1 * mean1);
147 let var2 = (sum_sq2 / n2) - (mean2 * mean2);
148
149 let m = 1.0 / (1.0 / n1 + 1.0 / n2);
151 let epsilon = ((1.0 / (2.0 * m)) * (4.0 + (n as f64).ln() / self.delta).ln()).sqrt();
152
153 (mean1 - mean2).abs() > epsilon || (var1 - var2).abs() > epsilon
154 }
155
156 fn is_warning(&self) -> bool {
158 if self.window.len() < 5 {
159 return false;
160 }
161
162 let n = self.window.len();
163 let mean = self.sum / n as f64;
164 let variance = (self.sum_squares / n as f64) - (mean * mean);
165
166 let recent_count = (n / 4).max(5);
168 let recent_sum: f64 = self.window.iter().rev().take(recent_count).sum();
169 let recent_mean = recent_sum / recent_count as f64;
170
171 let std_dev = variance.sqrt();
172 if std_dev > 0.0 {
173 (recent_mean - mean).abs() / std_dev > 1.5
174 } else {
175 false
176 }
177 }
178
179 pub fn reset(&mut self) {
181 self.window.clear();
182 self.sum = 0.0;
183 self.sum_squares = 0.0;
184 self.drift_detected = false;
185 }
186
187 pub fn window_size(&self) -> usize {
189 self.window.len()
190 }
191
192 pub fn mean(&self) -> Option<f64> {
194 if self.window.is_empty() {
195 None
196 } else {
197 Some(self.sum / self.window.len() as f64)
198 }
199 }
200
201 pub fn variance(&self) -> Option<f64> {
203 if self.window.len() < 2 {
204 None
205 } else {
206 let n = self.window.len() as f64;
207 let mean = self.sum / n;
208 Some((self.sum_squares / n) - (mean * mean))
209 }
210 }
211}
212
/// Cumulative-deviation detector for upward shifts relative to the first
/// observation it sees.
///
/// `Debug` and `Clone` are derived so detectors can be logged and snapshotted.
#[derive(Debug, Clone)]
pub struct PageHinkley {
    /// Drift is reported when the test statistic exceeds this value.
    threshold: f64,
    /// Tolerance subtracted from every deviation before it is accumulated.
    alpha: f64,
    /// Running sum of `value - reference_mean - alpha`.
    cumsum: f64,
    /// Smallest value `cumsum` has reached so far (never above zero).
    min_cumsum: f64,
    /// First observation seen since construction or the last reset; it is not
    /// updated by later observations.
    reference_mean: f64,
    /// Number of observations consumed since construction or the last reset.
    sample_count: usize,
    /// Whether the most recent `add` call reported `DriftStatus::Drift`.
    drift_detected: bool,
}
232
233impl PageHinkley {
234 pub fn new(threshold: f64, alpha: f64) -> Result<Self> {
236 if threshold <= 0.0 {
237 return Err(DecisionError::InvalidParameter(
238 "Threshold must be positive".to_string(),
239 ));
240 }
241
242 if alpha <= 0.0 || alpha > 1.0 {
243 return Err(DecisionError::InvalidParameter(
244 "Alpha must be in (0, 1]".to_string(),
245 ));
246 }
247
248 Ok(Self {
249 threshold,
250 alpha,
251 cumsum: 0.0,
252 min_cumsum: 0.0,
253 reference_mean: 0.0,
254 sample_count: 0,
255 drift_detected: false,
256 })
257 }
258
259 pub fn add(&mut self, value: f64) -> DriftStatus {
261 self.drift_detected = false;
262
263 if self.sample_count == 0 {
264 self.reference_mean = value;
265 self.sample_count = 1;
266 return DriftStatus::Stable;
267 }
268
269 self.cumsum += value - self.reference_mean - self.alpha;
271
272 if self.cumsum < self.min_cumsum {
274 self.min_cumsum = self.cumsum;
275 }
276
277 let ph_value = self.cumsum - self.min_cumsum;
279
280 self.sample_count += 1;
281
282 if ph_value > self.threshold {
283 self.drift_detected = true;
284 DriftStatus::Drift
285 } else if ph_value > self.threshold * 0.7 {
286 DriftStatus::Warning
287 } else {
288 DriftStatus::Stable
289 }
290 }
291
292 pub fn reset(&mut self) {
294 self.cumsum = 0.0;
295 self.min_cumsum = 0.0;
296 self.reference_mean = 0.0;
297 self.sample_count = 0;
298 self.drift_detected = false;
299 }
300
301 pub fn statistic(&self) -> f64 {
303 self.cumsum - self.min_cumsum
304 }
305
306 pub fn count(&self) -> usize {
308 self.sample_count
309 }
310}
311
/// Two-sided cumulative-sum detector around a fixed target mean.
///
/// `Debug` and `Clone` are derived so detectors can be logged and snapshotted.
#[derive(Debug, Clone)]
pub struct CUSUM {
    /// Drift is reported when either cumulative sum exceeds this value.
    threshold: f64,
    /// Level the observations are compared against.
    target_mean: f64,
    /// Half of this value is subtracted as slack from every deviation.
    delta: f64,
    /// Cumulative sum of upward deviations, floored at zero.
    cumsum_pos: f64,
    /// Cumulative sum of downward deviations, floored at zero.
    cumsum_neg: f64,
    /// Number of observations consumed since construction or the last reset.
    sample_count: usize,
    /// Direction of the drift reported by the most recent `add` call:
    /// `Some(true)` for upward, `Some(false)` for downward, `None` otherwise.
    drift_direction: Option<bool>,
}
329
330impl CUSUM {
331 pub fn new(threshold: f64, target_mean: f64, delta: f64) -> Result<Self> {
333 if threshold <= 0.0 {
334 return Err(DecisionError::InvalidParameter(
335 "Threshold must be positive".to_string(),
336 ));
337 }
338
339 Ok(Self {
340 threshold,
341 target_mean,
342 delta,
343 cumsum_pos: 0.0,
344 cumsum_neg: 0.0,
345 sample_count: 0,
346 drift_direction: None,
347 })
348 }
349
350 pub fn add(&mut self, value: f64) -> DriftStatus {
352 self.drift_direction = None;
353
354 let deviation = value - self.target_mean;
355
356 self.cumsum_pos = (self.cumsum_pos + deviation - self.delta / 2.0).max(0.0);
358
359 self.cumsum_neg = (self.cumsum_neg - deviation - self.delta / 2.0).max(0.0);
361
362 self.sample_count += 1;
363
364 if self.cumsum_pos > self.threshold {
366 self.drift_direction = Some(true);
367 DriftStatus::Drift
368 } else if self.cumsum_neg > self.threshold {
369 self.drift_direction = Some(false);
370 DriftStatus::Drift
371 } else if self.cumsum_pos > self.threshold * 0.7 || self.cumsum_neg > self.threshold * 0.7 {
372 DriftStatus::Warning
373 } else {
374 DriftStatus::Stable
375 }
376 }
377
378 pub fn reset(&mut self) {
380 self.cumsum_pos = 0.0;
381 self.cumsum_neg = 0.0;
382 self.sample_count = 0;
383 self.drift_direction = None;
384 }
385
386 pub fn drift_direction(&self) -> Option<bool> {
388 self.drift_direction
389 }
390
391 pub fn positive_cusum(&self) -> f64 {
393 self.cumsum_pos
394 }
395
396 pub fn negative_cusum(&self) -> f64 {
398 self.cumsum_neg
399 }
400}
401
/// Detector that compares a fixed reference window with a sliding window of
/// the most recent observations.
///
/// `Debug` and `Clone` are derived so detectors can be logged and snapshotted.
#[derive(Debug, Clone)]
pub struct StatisticalDriftDetector {
    /// Baseline observations; filled first, then left untouched until
    /// `update_reference` or `reset` is called.
    reference_window: VecDeque<f64>,
    /// Most recent observations received after the baseline was filled.
    current_window: VecDeque<f64>,
    /// Number of observations each window holds when full.
    window_size: usize,
    /// Significance level; a p-value below it is reported as drift.
    alpha: f64,
    /// Observations pushed into `current_window` since the last
    /// `update_reference` or `reset`. Not read by any code in this file.
    current_count: usize,
}
415
416impl StatisticalDriftDetector {
417 pub fn new(window_size: usize, alpha: f64) -> Result<Self> {
419 if window_size < 2 {
420 return Err(DecisionError::InvalidParameter(
421 "Window size must be at least 2".to_string(),
422 ));
423 }
424
425 if alpha <= 0.0 || alpha >= 1.0 {
426 return Err(DecisionError::InvalidParameter(
427 "Alpha must be in (0, 1)".to_string(),
428 ));
429 }
430
431 Ok(Self {
432 reference_window: VecDeque::with_capacity(window_size),
433 current_window: VecDeque::with_capacity(window_size),
434 window_size,
435 alpha,
436 current_count: 0,
437 })
438 }
439
440 pub fn add(&mut self, value: f64) -> DriftStatus {
442 if self.reference_window.len() < self.window_size {
444 self.reference_window.push_back(value);
445 return DriftStatus::Stable;
446 }
447
448 if self.current_window.len() >= self.window_size {
450 self.current_window.pop_front();
451 }
452 self.current_window.push_back(value);
453 self.current_count += 1;
454
455 if self.current_window.len() < self.window_size {
456 return DriftStatus::Stable;
457 }
458
459 match self.welch_t_test() {
461 Ok(p_value) => {
462 if p_value < self.alpha {
463 DriftStatus::Drift
464 } else if p_value < self.alpha * 2.0 {
465 DriftStatus::Warning
466 } else {
467 DriftStatus::Stable
468 }
469 }
470 Err(_) => DriftStatus::Stable,
471 }
472 }
473
474 fn welch_t_test(&self) -> Result<f64> {
476 let (mean1, var1) = self.mean_variance(&self.reference_window)?;
477 let (mean2, var2) = self.mean_variance(&self.current_window)?;
478
479 let n1 = self.reference_window.len() as f64;
480 let n2 = self.current_window.len() as f64;
481
482 let se = ((var1 / n1) + (var2 / n2)).sqrt();
484 if se == 0.0 {
485 return Ok(1.0); }
487
488 let t = ((mean1 - mean2).abs()) / se;
489
490 let p_value = 2.0 * (1.0 - normal_cdf(t.abs()));
493
494 Ok(p_value.clamp(0.0, 1.0))
495 }
496
497 fn mean_variance(&self, window: &VecDeque<f64>) -> Result<(f64, f64)> {
499 if window.is_empty() {
500 return Err(DecisionError::InvalidState("Empty window".to_string()));
501 }
502
503 let n = window.len() as f64;
504 let sum: f64 = window.iter().sum();
505 let mean = sum / n;
506
507 let variance = if window.len() > 1 {
508 let sum_sq: f64 = window.iter().map(|x| (x - mean).powi(2)).sum();
509 sum_sq / (n - 1.0)
510 } else {
511 0.0
512 };
513
514 Ok((mean, variance))
515 }
516
517 pub fn update_reference(&mut self) {
519 self.reference_window = self.current_window.clone();
520 self.current_window.clear();
521 self.current_count = 0;
522 }
523
524 pub fn reset(&mut self) {
526 self.reference_window.clear();
527 self.current_window.clear();
528 self.current_count = 0;
529 }
530}
531
532fn normal_cdf(x: f64) -> f64 {
534 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
535}
536
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - poly(t) * t * exp(-x^2)` with `t = 1 / (1 + p * |x|)` on the
/// magnitude of `x` and negates the result for negative inputs.
fn erf(x: f64) -> f64 {
    // Polynomial coefficients, highest order first so they can be folded
    // with Horner's scheme.
    const COEFFS: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let poly = COEFFS
        .iter()
        .skip(1)
        .fold(COEFFS[0], |acc, &coeff| acc * t + coeff);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x < 0.0 {
        -y
    } else {
        y
    }
}
554
/// Unit tests for the detectors and numeric helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adwin_creation() {
        let adwin = ADWIN::new(0.002, 100).unwrap();
        assert_eq!(adwin.window_size(), 0);
    }

    #[test]
    fn test_adwin_invalid_delta() {
        assert!(ADWIN::new(0.0, 100).is_err());
        assert!(ADWIN::new(1.0, 100).is_err());
        assert!(ADWIN::new(1.5, 100).is_err());
    }

    #[test]
    fn test_adwin_stable_data() {
        let mut adwin = ADWIN::new(0.002, 100).unwrap();

        for _ in 0..50 {
            let status = adwin.add(1.0);
            assert_eq!(status, DriftStatus::Stable);
        }
    }

    #[test]
    fn test_adwin_drift_detection() {
        let mut adwin = ADWIN::new(0.002, 100).unwrap();

        // Establish a baseline level.
        for _ in 0..30 {
            adwin.add(1.0);
        }

        // Shift the level and expect drift to be reported.
        let drift_detected = (0..30).any(|_| adwin.add(2.0) == DriftStatus::Drift);

        assert!(drift_detected);
    }

    #[test]
    fn test_adwin_statistics() {
        let mut adwin = ADWIN::new(0.002, 100).unwrap();

        for i in 1..=10 {
            adwin.add(i as f64);
        }

        assert!(adwin.mean().is_some());
        assert!(adwin.variance().is_some());
        assert_eq!(adwin.window_size(), 10);
    }

    #[test]
    fn test_adwin_window_eviction() {
        let mut adwin = ADWIN::new(0.002, 5).unwrap();

        for i in 1..=10 {
            adwin.add(i as f64);
        }

        // Only the five newest values (6..=10) remain in the window.
        assert_eq!(adwin.window_size(), 5);
        assert_eq!(adwin.mean(), Some(8.0));
    }

    #[test]
    fn test_page_hinkley_creation() {
        let ph = PageHinkley::new(50.0, 0.005).unwrap();
        assert_eq!(ph.count(), 0);
    }

    #[test]
    fn test_page_hinkley_invalid_params() {
        assert!(PageHinkley::new(0.0, 0.005).is_err());
        assert!(PageHinkley::new(50.0, 0.0).is_err());
        assert!(PageHinkley::new(50.0, 1.5).is_err());
    }

    #[test]
    fn test_page_hinkley_stable() {
        let mut ph = PageHinkley::new(50.0, 0.005).unwrap();

        for _ in 0..20 {
            let status = ph.add(1.0);
            assert_ne!(status, DriftStatus::Drift);
        }
    }

    #[test]
    fn test_page_hinkley_drift() {
        let mut ph = PageHinkley::new(10.0, 0.005).unwrap();

        // Establish the reference level.
        for _ in 0..20 {
            ph.add(1.0);
        }

        // Shift upwards and expect drift to be reported.
        let drift_detected = (0..30).any(|_| ph.add(3.0) == DriftStatus::Drift);

        assert!(drift_detected);
    }

    #[test]
    fn test_cusum_creation() {
        let cusum = CUSUM::new(5.0, 1.0, 0.5).unwrap();
        assert_eq!(cusum.positive_cusum(), 0.0);
        assert_eq!(cusum.negative_cusum(), 0.0);
    }

    #[test]
    fn test_cusum_stable() {
        let mut cusum = CUSUM::new(5.0, 1.0, 0.5).unwrap();

        for _ in 0..20 {
            let status = cusum.add(1.0);
            assert_eq!(status, DriftStatus::Stable);
        }
    }

    #[test]
    fn test_cusum_positive_drift() {
        let mut cusum = CUSUM::new(3.0, 1.0, 0.5).unwrap();

        let mut drift_detected = false;
        for _ in 0..20 {
            let status = cusum.add(2.5);
            if status == DriftStatus::Drift {
                drift_detected = true;
                assert_eq!(cusum.drift_direction(), Some(true));
                break;
            }
        }

        assert!(drift_detected);
    }

    #[test]
    fn test_cusum_negative_drift() {
        let mut cusum = CUSUM::new(3.0, 1.0, 0.5).unwrap();

        let mut drift_detected = false;
        for _ in 0..20 {
            let status = cusum.add(-0.5);
            if status == DriftStatus::Drift {
                drift_detected = true;
                assert_eq!(cusum.drift_direction(), Some(false));
                break;
            }
        }

        assert!(drift_detected);
    }

    #[test]
    fn test_statistical_detector_creation() {
        let detector = StatisticalDriftDetector::new(30, 0.05).unwrap();
        assert!(detector.reference_window.is_empty());
    }

    #[test]
    fn test_statistical_detector_stable() {
        let mut detector = StatisticalDriftDetector::new(20, 0.05).unwrap();

        for _ in 0..60 {
            let status = detector.add(1.0);
            if detector.current_window.len() >= 20 {
                assert_eq!(status, DriftStatus::Stable);
            }
        }
    }

    #[test]
    fn test_statistical_detector_basic() {
        let mut detector = StatisticalDriftDetector::new(20, 0.1).unwrap();

        // The first `window_size` values only fill the reference window.
        for _ in 0..20 {
            let status = detector.add(1.0);
            assert_eq!(status, DriftStatus::Stable);
        }
        assert_eq!(detector.reference_window.len(), 20);
        assert!(detector.current_window.is_empty());

        // Later values go to the current window.
        for _ in 0..20 {
            detector.add(5.0);
        }
        assert_eq!(detector.current_window.len(), 20);
        assert_eq!(detector.current_count, 20);

        // Promoting the current window replaces the reference and empties it.
        detector.update_reference();
        assert_eq!(detector.reference_window.len(), 20);
        assert!(detector.reference_window.iter().all(|&v| v == 5.0));
        assert!(detector.current_window.is_empty());
        assert_eq!(detector.current_count, 0);

        detector.reset();
        assert!(detector.reference_window.is_empty());
        assert!(detector.current_window.is_empty());
    }

    #[test]
    fn test_statistical_detector_detects_shift() {
        let mut detector = StatisticalDriftDetector::new(20, 0.1).unwrap();

        // Baseline alternating around 1.1 so the window variance is non-zero.
        for i in 0..20 {
            let value = if i % 2 == 0 { 1.0 } else { 1.2 };
            assert_eq!(detector.add(value), DriftStatus::Stable);
        }

        // Shifted data alternating around 5.1; no verdict until the current
        // window is full.
        for i in 0..19 {
            let value = if i % 2 == 0 { 5.0 } else { 5.2 };
            assert_eq!(detector.add(value), DriftStatus::Stable);
        }

        assert_eq!(detector.add(5.2), DriftStatus::Drift);
    }

    #[test]
    fn test_normal_cdf() {
        assert!((normal_cdf(0.0) - 0.5).abs() < 0.01);
        assert!(normal_cdf(1.96) > 0.97);
        assert!(normal_cdf(-1.96) < 0.03);
    }

    #[test]
    fn test_adwin_reset() {
        let mut adwin = ADWIN::new(0.002, 100).unwrap();

        for i in 1..=10 {
            adwin.add(i as f64);
        }

        assert_eq!(adwin.window_size(), 10);

        adwin.reset();
        assert_eq!(adwin.window_size(), 0);
        assert!(adwin.mean().is_none());
    }

    #[test]
    fn test_page_hinkley_reset() {
        let mut ph = PageHinkley::new(50.0, 0.005).unwrap();

        for _ in 0..10 {
            ph.add(1.0);
        }

        assert!(ph.count() > 0);

        ph.reset();
        assert_eq!(ph.count(), 0);
    }

    #[test]
    fn test_cusum_reset() {
        let mut cusum = CUSUM::new(5.0, 1.0, 0.5).unwrap();

        for _ in 0..10 {
            cusum.add(2.0);
        }

        assert!(cusum.positive_cusum() > 0.0);

        cusum.reset();
        assert_eq!(cusum.positive_cusum(), 0.0);
        assert_eq!(cusum.negative_cusum(), 0.0);
    }
}