use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use crate::errors::{DecisionError, Result};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DriftStatus {
Stable,
Warning,
Drift,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DriftAlgorithm {
ADWIN,
PageHinkley,
CUSUM,
Statistical,
}
pub struct ADWIN {
delta: f64,
window: VecDeque<f64>,
sum: f64,
sum_squares: f64,
max_window_size: usize,
drift_detected: bool,
}
impl ADWIN {
pub fn new(delta: f64, max_window_size: usize) -> Result<Self> {
if delta <= 0.0 || delta >= 1.0 {
return Err(DecisionError::InvalidParameter(
"Delta must be in (0, 1)".to_string(),
));
}
Ok(Self {
delta,
window: VecDeque::with_capacity(max_window_size),
sum: 0.0,
sum_squares: 0.0,
max_window_size,
drift_detected: false,
})
}
pub fn add(&mut self, value: f64) -> DriftStatus {
self.drift_detected = false;
if self.window.len() >= self.max_window_size {
if let Some(old) = self.window.pop_front() {
self.sum -= old;
self.sum_squares -= old * old;
}
}
self.window.push_back(value);
self.sum += value;
self.sum_squares += value * value;
if self.detect_change() {
self.drift_detected = true;
DriftStatus::Drift
} else if self.window.len() > 10 && self.is_warning() {
DriftStatus::Warning
} else {
DriftStatus::Stable
}
}
fn detect_change(&self) -> bool {
let n = self.window.len();
if n < 10 {
return false;
}
for cut in n / 4..=3 * n / 4 {
if self.test_split(cut) {
return true;
}
}
false
}
fn test_split(&self, cut: usize) -> bool {
let n = self.window.len();
let mut sum1 = 0.0;
let mut sum_sq1 = 0.0;
let mut sum2 = 0.0;
let mut sum_sq2 = 0.0;
for (i, &val) in self.window.iter().enumerate() {
if i < cut {
sum1 += val;
sum_sq1 += val * val;
} else {
sum2 += val;
sum_sq2 += val * val;
}
}
let n1 = cut as f64;
let n2 = (n - cut) as f64;
if n1 == 0.0 || n2 == 0.0 {
return false;
}
let mean1 = sum1 / n1;
let mean2 = sum2 / n2;
let var1 = (sum_sq1 / n1) - (mean1 * mean1);
let var2 = (sum_sq2 / n2) - (mean2 * mean2);
let m = 1.0 / (1.0 / n1 + 1.0 / n2);
let epsilon = ((1.0 / (2.0 * m)) * (4.0 + (n as f64).ln() / self.delta).ln()).sqrt();
(mean1 - mean2).abs() > epsilon || (var1 - var2).abs() > epsilon
}
fn is_warning(&self) -> bool {
if self.window.len() < 5 {
return false;
}
let n = self.window.len();
let mean = self.sum / n as f64;
let variance = (self.sum_squares / n as f64) - (mean * mean);
let recent_count = (n / 4).max(5);
let recent_sum: f64 = self.window.iter().rev().take(recent_count).sum();
let recent_mean = recent_sum / recent_count as f64;
let std_dev = variance.sqrt();
if std_dev > 0.0 {
(recent_mean - mean).abs() / std_dev > 1.5
} else {
false
}
}
pub fn reset(&mut self) {
self.window.clear();
self.sum = 0.0;
self.sum_squares = 0.0;
self.drift_detected = false;
}
pub fn window_size(&self) -> usize {
self.window.len()
}
pub fn mean(&self) -> Option<f64> {
if self.window.is_empty() {
None
} else {
Some(self.sum / self.window.len() as f64)
}
}
pub fn variance(&self) -> Option<f64> {
if self.window.len() < 2 {
None
} else {
let n = self.window.len() as f64;
let mean = self.sum / n;
Some((self.sum_squares / n) - (mean * mean))
}
}
}
pub struct PageHinkley {
threshold: f64,
alpha: f64,
cumsum: f64,
min_cumsum: f64,
reference_mean: f64,
sample_count: usize,
drift_detected: bool,
}
impl PageHinkley {
pub fn new(threshold: f64, alpha: f64) -> Result<Self> {
if threshold <= 0.0 {
return Err(DecisionError::InvalidParameter(
"Threshold must be positive".to_string(),
));
}
if alpha <= 0.0 || alpha > 1.0 {
return Err(DecisionError::InvalidParameter(
"Alpha must be in (0, 1]".to_string(),
));
}
Ok(Self {
threshold,
alpha,
cumsum: 0.0,
min_cumsum: 0.0,
reference_mean: 0.0,
sample_count: 0,
drift_detected: false,
})
}
pub fn add(&mut self, value: f64) -> DriftStatus {
self.drift_detected = false;
if self.sample_count == 0 {
self.reference_mean = value;
self.sample_count = 1;
return DriftStatus::Stable;
}
self.cumsum += value - self.reference_mean - self.alpha;
if self.cumsum < self.min_cumsum {
self.min_cumsum = self.cumsum;
}
let ph_value = self.cumsum - self.min_cumsum;
self.sample_count += 1;
if ph_value > self.threshold {
self.drift_detected = true;
DriftStatus::Drift
} else if ph_value > self.threshold * 0.7 {
DriftStatus::Warning
} else {
DriftStatus::Stable
}
}
pub fn reset(&mut self) {
self.cumsum = 0.0;
self.min_cumsum = 0.0;
self.reference_mean = 0.0;
self.sample_count = 0;
self.drift_detected = false;
}
pub fn statistic(&self) -> f64 {
self.cumsum - self.min_cumsum
}
pub fn count(&self) -> usize {
self.sample_count
}
}
pub struct CUSUM {
threshold: f64,
target_mean: f64,
delta: f64,
cumsum_pos: f64,
cumsum_neg: f64,
sample_count: usize,
drift_direction: Option<bool>, }
impl CUSUM {
pub fn new(threshold: f64, target_mean: f64, delta: f64) -> Result<Self> {
if threshold <= 0.0 {
return Err(DecisionError::InvalidParameter(
"Threshold must be positive".to_string(),
));
}
Ok(Self {
threshold,
target_mean,
delta,
cumsum_pos: 0.0,
cumsum_neg: 0.0,
sample_count: 0,
drift_direction: None,
})
}
pub fn add(&mut self, value: f64) -> DriftStatus {
self.drift_direction = None;
let deviation = value - self.target_mean;
self.cumsum_pos = (self.cumsum_pos + deviation - self.delta / 2.0).max(0.0);
self.cumsum_neg = (self.cumsum_neg - deviation - self.delta / 2.0).max(0.0);
self.sample_count += 1;
if self.cumsum_pos > self.threshold {
self.drift_direction = Some(true);
DriftStatus::Drift
} else if self.cumsum_neg > self.threshold {
self.drift_direction = Some(false);
DriftStatus::Drift
} else if self.cumsum_pos > self.threshold * 0.7 || self.cumsum_neg > self.threshold * 0.7 {
DriftStatus::Warning
} else {
DriftStatus::Stable
}
}
pub fn reset(&mut self) {
self.cumsum_pos = 0.0;
self.cumsum_neg = 0.0;
self.sample_count = 0;
self.drift_direction = None;
}
pub fn drift_direction(&self) -> Option<bool> {
self.drift_direction
}
pub fn positive_cusum(&self) -> f64 {
self.cumsum_pos
}
pub fn negative_cusum(&self) -> f64 {
self.cumsum_neg
}
}
pub struct StatisticalDriftDetector {
reference_window: VecDeque<f64>,
current_window: VecDeque<f64>,
window_size: usize,
alpha: f64,
current_count: usize,
}
impl StatisticalDriftDetector {
pub fn new(window_size: usize, alpha: f64) -> Result<Self> {
if window_size < 2 {
return Err(DecisionError::InvalidParameter(
"Window size must be at least 2".to_string(),
));
}
if alpha <= 0.0 || alpha >= 1.0 {
return Err(DecisionError::InvalidParameter(
"Alpha must be in (0, 1)".to_string(),
));
}
Ok(Self {
reference_window: VecDeque::with_capacity(window_size),
current_window: VecDeque::with_capacity(window_size),
window_size,
alpha,
current_count: 0,
})
}
pub fn add(&mut self, value: f64) -> DriftStatus {
if self.reference_window.len() < self.window_size {
self.reference_window.push_back(value);
return DriftStatus::Stable;
}
if self.current_window.len() >= self.window_size {
self.current_window.pop_front();
}
self.current_window.push_back(value);
self.current_count += 1;
if self.current_window.len() < self.window_size {
return DriftStatus::Stable;
}
match self.welch_t_test() {
Ok(p_value) => {
if p_value < self.alpha {
DriftStatus::Drift
} else if p_value < self.alpha * 2.0 {
DriftStatus::Warning
} else {
DriftStatus::Stable
}
}
Err(_) => DriftStatus::Stable,
}
}
fn welch_t_test(&self) -> Result<f64> {
let (mean1, var1) = self.mean_variance(&self.reference_window)?;
let (mean2, var2) = self.mean_variance(&self.current_window)?;
let n1 = self.reference_window.len() as f64;
let n2 = self.current_window.len() as f64;
let se = ((var1 / n1) + (var2 / n2)).sqrt();
if se == 0.0 {
return Ok(1.0); }
let t = ((mean1 - mean2).abs()) / se;
let p_value = 2.0 * (1.0 - normal_cdf(t.abs()));
Ok(p_value.clamp(0.0, 1.0))
}
fn mean_variance(&self, window: &VecDeque<f64>) -> Result<(f64, f64)> {
if window.is_empty() {
return Err(DecisionError::InvalidState("Empty window".to_string()));
}
let n = window.len() as f64;
let sum: f64 = window.iter().sum();
let mean = sum / n;
let variance = if window.len() > 1 {
let sum_sq: f64 = window.iter().map(|x| (x - mean).powi(2)).sum();
sum_sq / (n - 1.0)
} else {
0.0
};
Ok((mean, variance))
}
pub fn update_reference(&mut self) {
self.reference_window = self.current_window.clone();
self.current_window.clear();
self.current_count = 0;
}
pub fn reset(&mut self) {
self.reference_window.clear();
self.current_window.clear();
self.current_count = 0;
}
}
fn normal_cdf(x: f64) -> f64 {
0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}
fn erf(x: f64) -> f64 {
let a1 = 0.254829592;
let a2 = -0.284496736;
let a3 = 1.421413741;
let a4 = -1.453152027;
let a5 = 1.061405429;
let p = 0.3275911;
let sign = if x < 0.0 { -1.0 } else { 1.0 };
let x = x.abs();
let t = 1.0 / (1.0 + p * x);
let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
sign * y
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_adwin_creation() {
let adwin = ADWIN::new(0.002, 100).unwrap();
assert_eq!(adwin.window_size(), 0);
}
#[test]
fn test_adwin_invalid_delta() {
assert!(ADWIN::new(0.0, 100).is_err());
assert!(ADWIN::new(1.0, 100).is_err());
assert!(ADWIN::new(1.5, 100).is_err());
}
#[test]
fn test_adwin_stable_data() {
let mut adwin = ADWIN::new(0.002, 100).unwrap();
for _ in 0..50 {
let status = adwin.add(1.0);
assert_eq!(status, DriftStatus::Stable);
}
}
#[test]
fn test_adwin_drift_detection() {
let mut adwin = ADWIN::new(0.002, 100).unwrap();
for _ in 0..30 {
adwin.add(1.0);
}
let mut drift_detected = false;
for _ in 0..30 {
let status = adwin.add(2.0);
if status == DriftStatus::Drift {
drift_detected = true;
break;
}
}
assert!(drift_detected);
}
#[test]
fn test_adwin_statistics() {
let mut adwin = ADWIN::new(0.002, 100).unwrap();
for i in 1..=10 {
adwin.add(i as f64);
}
assert!(adwin.mean().is_some());
assert!(adwin.variance().is_some());
assert_eq!(adwin.window_size(), 10);
}
#[test]
fn test_page_hinkley_creation() {
let ph = PageHinkley::new(50.0, 0.005).unwrap();
assert_eq!(ph.count(), 0);
}
#[test]
fn test_page_hinkley_invalid_params() {
assert!(PageHinkley::new(0.0, 0.005).is_err());
assert!(PageHinkley::new(50.0, 0.0).is_err());
assert!(PageHinkley::new(50.0, 1.5).is_err());
}
#[test]
fn test_page_hinkley_stable() {
let mut ph = PageHinkley::new(50.0, 0.005).unwrap();
for _ in 0..20 {
let status = ph.add(1.0);
assert_ne!(status, DriftStatus::Drift);
}
}
#[test]
fn test_page_hinkley_drift() {
let mut ph = PageHinkley::new(10.0, 0.005).unwrap();
for _ in 0..20 {
ph.add(1.0);
}
let mut drift_detected = false;
for _ in 0..30 {
let status = ph.add(3.0);
if status == DriftStatus::Drift {
drift_detected = true;
break;
}
}
assert!(drift_detected);
}
#[test]
fn test_cusum_creation() {
let cusum = CUSUM::new(5.0, 1.0, 0.5).unwrap();
assert_eq!(cusum.positive_cusum(), 0.0);
assert_eq!(cusum.negative_cusum(), 0.0);
}
#[test]
fn test_cusum_stable() {
let mut cusum = CUSUM::new(5.0, 1.0, 0.5).unwrap();
for _ in 0..20 {
let status = cusum.add(1.0);
assert_eq!(status, DriftStatus::Stable);
}
}
#[test]
fn test_cusum_positive_drift() {
let mut cusum = CUSUM::new(3.0, 1.0, 0.5).unwrap();
let mut drift_detected = false;
for _ in 0..20 {
let status = cusum.add(2.5);
if status == DriftStatus::Drift {
drift_detected = true;
assert_eq!(cusum.drift_direction(), Some(true));
break;
}
}
assert!(drift_detected);
}
#[test]
fn test_cusum_negative_drift() {
let mut cusum = CUSUM::new(3.0, 1.0, 0.5).unwrap();
let mut drift_detected = false;
for _ in 0..20 {
let status = cusum.add(-0.5);
if status == DriftStatus::Drift {
drift_detected = true;
assert_eq!(cusum.drift_direction(), Some(false));
break;
}
}
assert!(drift_detected);
}
#[test]
fn test_statistical_detector_creation() {
let detector = StatisticalDriftDetector::new(30, 0.05).unwrap();
assert!(detector.reference_window.is_empty());
}
#[test]
fn test_statistical_detector_stable() {
let mut detector = StatisticalDriftDetector::new(20, 0.05).unwrap();
for _ in 0..60 {
let status = detector.add(1.0);
if detector.current_window.len() >= 20 {
assert_eq!(status, DriftStatus::Stable);
}
}
}
#[test]
fn test_statistical_detector_basic() {
let mut detector = StatisticalDriftDetector::new(20, 0.1).unwrap();
for _ in 0..20 {
let status = detector.add(1.0);
assert!(status == DriftStatus::Stable);
}
for _ in 0..20 {
detector.add(5.0);
}
detector.update_reference();
detector.reset();
}
#[test]
fn test_normal_cdf() {
assert!((normal_cdf(0.0) - 0.5).abs() < 0.01);
assert!(normal_cdf(1.96) > 0.97);
assert!(normal_cdf(-1.96) < 0.03);
}
#[test]
fn test_adwin_reset() {
let mut adwin = ADWIN::new(0.002, 100).unwrap();
for i in 1..=10 {
adwin.add(i as f64);
}
assert_eq!(adwin.window_size(), 10);
adwin.reset();
assert_eq!(adwin.window_size(), 0);
assert!(adwin.mean().is_none());
}
#[test]
fn test_page_hinkley_reset() {
let mut ph = PageHinkley::new(50.0, 0.005).unwrap();
for _ in 0..10 {
ph.add(1.0);
}
assert!(ph.count() > 0);
ph.reset();
assert_eq!(ph.count(), 0);
}
#[test]
fn test_cusum_reset() {
let mut cusum = CUSUM::new(5.0, 1.0, 0.5).unwrap();
for _ in 0..10 {
cusum.add(2.0);
}
assert!(cusum.positive_cusum() > 0.0);
cusum.reset();
assert_eq!(cusum.positive_cusum(), 0.0);
assert_eq!(cusum.negative_cusum(), 0.0);
}
}