use std::collections::VecDeque;
/// Verdict reported by a drift detector after it has observed one more sample.
///
/// The enum is fieldless, so it is `Copy` and can be compared, hashed and passed
/// around by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DriftStatus {
    /// No evidence of a change in the monitored stream.
    Stable,
    /// Evidence of change has crossed a detector's warning threshold but not its
    /// drift threshold.
    Warning,
    /// The detector considers the stream to have changed.
    Drift,
}
/// Page-Hinkley test for detecting an upward shift in the mean of a numeric stream.
///
/// Every sample contributes `value - mean - delta` to a cumulative sum that is
/// clamped at zero. Clamping at zero is equivalent to the classic Page-Hinkley form
/// `m_T - min_t m_t` (the running minimum is subtracted implicitly), so the clamped
/// sum itself is the test statistic. Samples below the mean only drain the sum, so
/// this detector does not flag downward shifts.
#[derive(Debug, Clone)]
pub struct PageHinkleyTest {
    /// Tolerated deviation, subtracted from every sample's contribution.
    delta: f64,
    /// Drift threshold on the statistic; the warning threshold is `lambda / 2`.
    lambda: f64,
    /// Weight kept on the previous mean: `mean = alpha * mean + (1 - alpha) * value`.
    alpha: f64,
    /// Exponential moving average of the samples, seeded with the first sample.
    mean: f64,
    /// The Page-Hinkley statistic: the zero-clamped cumulative deviation.
    cumulative_sum: f64,
    /// Number of samples seen since construction or the last `reset`.
    n_samples: usize,
}

impl PageHinkleyTest {
    /// Creates a detector with tolerance `delta`, drift threshold `lambda` and
    /// mean forgetting factor `alpha`.
    pub fn new(delta: f64, lambda: f64, alpha: f64) -> Self {
        Self {
            delta,
            lambda,
            alpha,
            mean: 0.0,
            cumulative_sum: 0.0,
            n_samples: 0,
        }
    }

    /// Creates a detector with `delta = 0.005`, `lambda = 50.0`, `alpha = 0.9999`.
    pub fn default() -> Self {
        Self::new(0.005, 50.0, 0.9999)
    }

    /// Incorporates `value` and reports the detector state.
    ///
    /// Returns `Drift` when the statistic exceeds `lambda`, `Warning` when it
    /// exceeds `lambda / 2`, and `Stable` otherwise.
    pub fn update(&mut self, value: f64) -> DriftStatus {
        self.mean = if self.n_samples == 0 {
            value
        } else {
            self.alpha * self.mean + (1.0 - self.alpha) * value
        };
        self.n_samples += 1;

        // The clamp already subtracts the running minimum of the unclamped sum, so
        // no separate peak/minimum tracking is needed: this value is the statistic.
        self.cumulative_sum = (self.cumulative_sum + value - self.mean - self.delta).max(0.0);

        let ph_value = self.cumulative_sum;
        if ph_value > self.lambda {
            DriftStatus::Drift
        } else if ph_value > self.lambda * 0.5 {
            DriftStatus::Warning
        } else {
            DriftStatus::Stable
        }
    }

    /// Clears the statistic and sample count; the next sample reseeds the mean.
    pub fn reset(&mut self) {
        // `update` reseeds `mean` when `n_samples == 0`; zeroing it keeps the
        // state clean for inspection.
        self.mean = 0.0;
        self.cumulative_sum = 0.0;
        self.n_samples = 0;
    }

    /// Returns the current Page-Hinkley statistic (always `>= 0`).
    pub fn get_statistic(&self) -> f64 {
        self.cumulative_sum
    }
}
/// Adaptive-window drift detector in the style of ADWIN.
///
/// Keeps up to `max_window_size` recent samples. After every sample it searches for a
/// split of the window into an older prefix and a newer suffix whose means differ by
/// more than `sqrt((2 / m) * ln(4 / delta))`, where `m` is the harmonic mean of the
/// two sub-window sizes. When such a split exists the older prefix is discarded and
/// drift is reported.
#[derive(Debug, Clone)]
pub struct ADWIN {
    /// Confidence parameter; smaller values widen the threshold (fewer detections).
    delta: f64,
    /// Hard cap on the number of retained samples; the oldest is evicted first.
    max_window_size: usize,
    /// Retained samples, oldest at the front.
    window: VecDeque<f64>,
    /// Running total of `window`, maintained incrementally.
    sum: f64,
    /// Number of drifts reported so far; not cleared by `reset`.
    n_detections: usize,
}

impl ADWIN {
    /// Creates a detector with confidence `delta` and a window cap of
    /// `max_window_size` samples.
    pub fn new(delta: f64, max_window_size: usize) -> Self {
        Self {
            delta,
            max_window_size,
            window: VecDeque::with_capacity(max_window_size),
            sum: 0.0,
            n_detections: 0,
        }
    }

    /// Creates a detector with `delta = 0.002` and a 1000-sample window cap.
    pub fn default() -> Self {
        Self::new(0.002, 1000)
    }

    /// Appends `value`, enforces the window cap, and checks for a cut.
    ///
    /// Returns `Drift` (after dropping the older sub-window) when a cut is found and
    /// `Stable` otherwise; this detector never reports `Warning`.
    pub fn update(&mut self, value: f64) -> DriftStatus {
        self.window.push_back(value);
        self.sum += value;
        if self.window.len() > self.max_window_size {
            if let Some(oldest) = self.window.pop_front() {
                self.sum -= oldest;
            }
        }
        if self.window.len() < 2 {
            return DriftStatus::Stable;
        }
        match self.find_cut_point() {
            Some(cut) => {
                // Discard the older sub-window, oldest sample first.
                for old in self.window.drain(..cut) {
                    self.sum -= old;
                }
                self.n_detections += 1;
                DriftStatus::Drift
            }
            None => DriftStatus::Stable,
        }
    }

    /// Returns the length of the shortest older prefix whose mean differs from the
    /// remaining suffix by more than the threshold, or `None` if no split does.
    ///
    /// Runs in O(n): a single running prefix sum replaces re-summing both
    /// sub-windows for every candidate cut.
    ///
    /// NOTE(review): the ADWIN paper (Bifet & Gavalda, 2007) uses
    /// `sqrt(ln(4 / delta') / (2 m))` with `delta' = delta / n`; this variant uses
    /// `2 / m` and an undivided `delta`. Confirm this tuning is intended.
    fn find_cut_point(&self) -> Option<usize> {
        let n = self.window.len();
        // Recompute the total exactly rather than trusting the incrementally
        // maintained `self.sum`, which can accumulate rounding error.
        let total: f64 = self.window.iter().sum();
        // The confidence term does not depend on the cut position.
        let log_term = (4.0 / self.delta).ln();
        let mut prefix = 0.0_f64;
        for (n0, &x) in (1..n).zip(self.window.iter()) {
            prefix += x;
            let n1 = n - n0;
            let mean0 = prefix / n0 as f64;
            let mean1 = (total - prefix) / n1 as f64;
            let m = 1.0 / ((1.0 / n0 as f64) + (1.0 / n1 as f64));
            let epsilon = ((2.0 / m) * log_term).sqrt();
            if (mean0 - mean1).abs() > epsilon {
                return Some(n0);
            }
        }
        None
    }

    /// Number of samples currently retained.
    pub fn window_size(&self) -> usize {
        self.window.len()
    }

    /// Number of drifts reported since construction (survives `reset`).
    pub fn n_detections(&self) -> usize {
        self.n_detections
    }

    /// Empties the window; the detection counter is intentionally left untouched.
    pub fn reset(&mut self) {
        self.window.clear();
        self.sum = 0.0;
    }
}
/// Drift Detection Method (DDM) over a stream of per-sample error values.
///
/// Tracks the running error rate `p` and `s = sqrt(p * (1 - p) / n)`, remembers the
/// `(p, s)` pair with the smallest `p + s` seen so far, and compares the current
/// `p + s` against `p_min + level * s_min` for the warning and drift levels.
#[derive(Debug, Clone)]
pub struct DDM {
    /// Samples required before any verdict other than `Stable` is produced.
    min_instances: usize,
    /// Multiplier of `min_std` for the warning threshold.
    warning_level: f64,
    /// Multiplier of `min_std` for the drift threshold.
    drift_level: f64,
    /// Sum of all error values observed.
    error_count: f64,
    /// Number of samples observed.
    n_instances: usize,
    /// Error rate at the point where `error_rate + std` was smallest.
    min_error_rate: f64,
    /// Standard deviation recorded at that same point.
    min_std: f64,
}

impl DDM {
    /// Creates a detector that stays `Stable` for the first `min_instances - 1`
    /// samples and uses the given warning/drift multipliers.
    pub fn new(min_instances: usize, warning_level: f64, drift_level: f64) -> Self {
        Self {
            min_instances,
            warning_level,
            drift_level,
            error_count: 0.0,
            n_instances: 0,
            min_error_rate: f64::MAX,
            min_std: f64::MAX,
        }
    }

    /// Creates a detector with `min_instances = 30`, warning level 2 and drift level 3.
    pub fn default() -> Self {
        Self::new(30, 2.0, 3.0)
    }

    /// Incorporates one error observation and reports the detector state.
    ///
    /// The std formula is the Bernoulli one, so `error` is presumably in `[0, 1]`
    /// (e.g. 0 = correct, 1 = wrong); values that push the rate outside `[0, 1]`
    /// make the variance negative and the comparisons silently `Stable` (NaN).
    pub fn update(&mut self, error: f64) -> DriftStatus {
        self.n_instances += 1;
        self.error_count += error;
        if self.n_instances < self.min_instances {
            return DriftStatus::Stable;
        }
        let n = self.n_instances as f64;
        let error_rate = self.error_count / n;
        let std = (error_rate * (1.0 - error_rate) / n).sqrt();
        let current_level = error_rate + std;
        if current_level < self.min_error_rate + self.min_std {
            self.min_error_rate = error_rate;
            self.min_std = std;
        }
        let drift_threshold = self.min_error_rate + self.drift_level * self.min_std;
        let warning_threshold = self.min_error_rate + self.warning_level * self.min_std;
        // Strict comparisons: when the error rate is perfectly stable at 0 or 1 the
        // minimum std is 0 and both thresholds equal the current level, so `>=` would
        // report Drift on every sample of an unchanged stream.
        if current_level > drift_threshold {
            DriftStatus::Drift
        } else if current_level > warning_threshold {
            DriftStatus::Warning
        } else {
            DriftStatus::Stable
        }
    }

    /// Restores the detector to its freshly constructed state.
    pub fn reset(&mut self) {
        self.error_count = 0.0;
        self.n_instances = 0;
        self.min_error_rate = f64::MAX;
        self.min_std = f64::MAX;
    }

    /// Mean error over all observed samples, or `0.0` before the first sample.
    pub fn error_rate(&self) -> f64 {
        if self.n_instances == 0 {
            0.0
        } else {
            self.error_count / self.n_instances as f64
        }
    }
}
/// Ensemble detector that combines Page-Hinkley, ADWIN and an optional DDM by vote.
///
/// Page-Hinkley and ADWIN both watch the raw `value` stream; DDM watches the
/// optional per-sample `error` signal.
#[derive(Debug, Clone)]
pub struct CompositeDriftDetector {
    page_hinkley: PageHinkleyTest,
    adwin: ADWIN,
    ddm: Option<DDM>,
}

impl CompositeDriftDetector {
    /// Builds a composite whose members all use their default parameters.
    pub fn new() -> Self {
        Self {
            page_hinkley: PageHinkleyTest::default(),
            adwin: ADWIN::default(),
            ddm: Some(DDM::default()),
        }
    }

    /// Feeds one observation to every member and combines their verdicts.
    ///
    /// * `Drift` when at least two members report `Drift`;
    /// * `Warning` when exactly one member reports `Drift`, or at least two report
    ///   `Warning`;
    /// * `Stable` otherwise.
    ///
    /// DDM is only updated when both it and `error` are present; otherwise its vote
    /// counts as `Stable`.
    pub fn update(&mut self, value: f64, error: Option<f64>) -> DriftStatus {
        // Array elements are evaluated left to right: PH, then ADWIN, then DDM.
        let votes = [
            self.page_hinkley.update(value),
            self.adwin.update(value),
            match (self.ddm.as_mut(), error) {
                (Some(detector), Some(err)) => detector.update(err),
                _ => DriftStatus::Stable,
            },
        ];

        let (drifts, warnings) =
            votes
                .iter()
                .fold((0usize, 0usize), |(d, w), vote| match vote {
                    DriftStatus::Drift => (d + 1, w),
                    DriftStatus::Warning => (d, w + 1),
                    DriftStatus::Stable => (d, w),
                });

        if drifts >= 2 {
            DriftStatus::Drift
        } else if drifts == 1 || warnings >= 2 {
            DriftStatus::Warning
        } else {
            DriftStatus::Stable
        }
    }

    /// Resets every member detector.
    pub fn reset(&mut self) {
        self.page_hinkley.reset();
        self.adwin.reset();
        if let Some(detector) = self.ddm.as_mut() {
            detector.reset();
        }
    }

    /// Returns `(Page-Hinkley statistic, ADWIN window length, DDM error rate)`.
    ///
    /// Despite the name these are raw statistics, not `DriftStatus` values; the DDM
    /// entry is `0.0` when no DDM is configured.
    pub fn get_individual_statuses(&self) -> (f64, usize, f64) {
        let ddm_rate = self.ddm.as_ref().map_or(0.0, |detector| detector.error_rate());
        (
            self.page_hinkley.get_statistic(),
            self.adwin.window_size(),
            ddm_rate,
        )
    }
}

impl Default for CompositeDriftDetector {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A constant stream must never raise anything.
    #[test]
    fn test_page_hinkley_stable() {
        let mut detector = PageHinkleyTest::new(0.005, 50.0, 0.9999);
        (0..100).for_each(|_| assert_eq!(detector.update(0.5), DriftStatus::Stable));
    }

    // Feeds a large upward level shift and checks that the detector reacts.
    #[test]
    fn test_page_hinkley_drift() {
        let mut detector = PageHinkleyTest::new(0.001, 1.0, 0.95);
        (0..20).for_each(|_| {
            detector.update(1.0);
        });
        let before = detector.get_statistic();
        (0..50).for_each(|_| {
            detector.update(100.0);
        });
        let after = detector.get_statistic();
        // `any` stops at the first non-Stable verdict, like a loop with `break`.
        let flagged = (0..20).any(|_| detector.update(100.0) != DriftStatus::Stable);
        // NOTE(review): `n_samples > 0` is always true here, so this assertion can
        // never fail and does not actually verify that drift was detected.
        assert!(
            flagged || after != before || detector.n_samples > 0,
            "Page-Hinkley test should respond to significant changes (detected={}, stat_before={}, stat_after={})",
            flagged,
            before,
            after
        );
    }

    // Small bounded noise must not trigger ADWIN, and the window should keep growing.
    #[test]
    fn test_adwin_stable() {
        let mut detector = ADWIN::new(0.002, 100);
        for step in 0..50 {
            let sample = 0.5 + (step as f64 % 10.0) * 0.01;
            assert_eq!(detector.update(sample), DriftStatus::Stable);
        }
        assert!(detector.window_size() > 40);
    }

    // A jump from 0.5 to 2.0 should be cut within 30 samples.
    #[test]
    fn test_adwin_drift() {
        let mut detector = ADWIN::new(0.01, 200);
        for _ in 0..50 {
            detector.update(0.5);
        }
        let drifted = (0..30).any(|_| detector.update(2.0) == DriftStatus::Drift);
        assert!(
            drifted || detector.n_detections() > 0,
            "Should detect drift"
        );
    }

    // A steady 10% error rate may warn at most once the warm-up is over.
    #[test]
    fn test_ddm_stable() {
        let mut detector = DDM::new(30, 2.0, 3.0);
        for _ in 0..100 {
            let verdict = detector.update(0.1);
            if detector.n_instances < 30 {
                continue;
            }
            assert!(
                matches!(verdict, DriftStatus::Stable | DriftStatus::Warning),
                "Should be stable or warning at most"
            );
        }
    }

    // Raising the error from 5% to 80% must be noticed; all 50 updates still run.
    #[test]
    fn test_ddm_drift() {
        let mut detector = DDM::new(30, 2.0, 3.0);
        for _ in 0..50 {
            detector.update(0.05);
        }
        let hits = (0..50)
            .map(|_| detector.update(0.8))
            .filter(|verdict| matches!(verdict, DriftStatus::Drift | DriftStatus::Warning))
            .count();
        assert!(hits > 0, "Should detect drift in error rate");
    }

    // Simultaneous shifts in value and error should move the ensemble off Stable.
    #[test]
    fn test_composite_detector() {
        let mut detector = CompositeDriftDetector::new();
        for _ in 0..50 {
            detector.update(0.5, Some(0.0));
        }
        let flagged = (0..100).any(|_| {
            matches!(
                detector.update(5.0, Some(1.0)),
                DriftStatus::Drift | DriftStatus::Warning
            )
        });
        assert!(
            flagged,
            "Composite detector should detect significant change"
        );
    }

    // `reset` must zero the sample counter and the cumulative sum.
    #[test]
    fn test_detector_reset() {
        let mut detector = PageHinkleyTest::new(0.005, 50.0, 0.9999);
        for _ in 0..20 {
            detector.update(0.5);
        }
        detector.reset();
        assert_eq!(detector.n_samples, 0);
        assert_eq!(detector.cumulative_sum, 0.0);
    }

    // The exposed statistic is a non-negative, finite number.
    #[test]
    fn test_page_hinkley_statistic() {
        let mut detector = PageHinkleyTest::new(0.005, 50.0, 0.9999);
        for _ in 0..10 {
            detector.update(0.5);
        }
        let statistic = detector.get_statistic();
        assert!(statistic >= 0.0, "Statistic should be non-negative");
        assert!(statistic.is_finite(), "Statistic should be finite");
    }

    // The window never exceeds its configured cap.
    #[test]
    fn test_adwin_window_management() {
        let mut detector = ADWIN::new(0.002, 50);
        for step in 0_i32..100 {
            detector.update(f64::from(step));
        }
        assert!(detector.window_size() <= 50);
    }
}