use super::DriftSignal;
/// Drift detector that tracks the running mean and standard deviation of a
/// stream of values and compares `mean + std` against the smallest value of
/// that sum recorded so far.
///
/// The warning and drift thresholds are `min_p_plus_s + level * min_s`, where
/// `level` is `warning_level` or `drift_level`. No signal other than `Stable`
/// is produced while at most `min_instances` samples have been seen since the
/// last reset or drift.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Ddm {
    /// Multiplier of `min_s` used for the warning threshold.
    warning_level: f64,
    /// Multiplier of `min_s` used for the drift threshold.
    drift_level: f64,
    /// Warm-up length: while `count` is at most this value, `update` reports `Stable`.
    min_instances: u64,
    /// Running mean of the samples seen since the last reset or drift.
    mean: f64,
    /// Running sum of squared deviations from the mean (Welford's M2 accumulator).
    m2: f64,
    /// Number of samples seen since the last reset or drift.
    count: u64,
    /// Smallest `mean + std` recorded after warm-up; `f64::MAX` until one is recorded.
    min_p_plus_s: f64,
    /// Standard deviation at the moment `min_p_plus_s` was recorded; `f64::MAX` until then.
    min_s: f64,
}
impl Ddm {
    /// Creates a detector with the default configuration: warning level `2.0`,
    /// drift level `3.0` and a warm-up of `30` samples.
    pub fn new() -> Self {
        Self::with_params(2.0, 3.0, 30)
    }

    /// Creates a detector with explicit thresholds and warm-up length.
    ///
    /// * `warning_level` – multiplier of the recorded minimum standard deviation
    ///   in the warning threshold `min_p_plus_s + warning_level * min_s`.
    /// * `drift_level` – the same multiplier for the drift threshold.
    /// * `min_instances` – while at most this many samples have been seen since
    ///   the last reset or drift, `update` reports `Stable`.
    ///
    /// The values are stored as given; no validation is performed.
    pub fn with_params(warning_level: f64, drift_level: f64, min_instances: u64) -> Self {
        // `f64::MAX` marks the minima as "not recorded yet".
        let unrecorded = f64::MAX;
        Self {
            warning_level,
            drift_level,
            min_instances,
            count: 0,
            mean: 0.0,
            m2: 0.0,
            min_p_plus_s: unrecorded,
            min_s: unrecorded,
        }
    }

    /// Multiplier used for the warning threshold.
    #[inline]
    pub fn warning_level(&self) -> f64 {
        self.warning_level
    }

    /// Multiplier used for the drift threshold.
    #[inline]
    pub fn drift_level(&self) -> f64 {
        self.drift_level
    }

    /// Warm-up length in samples.
    #[inline]
    pub fn min_instances(&self) -> u64 {
        self.min_instances
    }

    /// Population standard deviation of the samples seen since the last reset
    /// or drift, or `0.0` when no sample has been seen.
    #[inline]
    pub fn std_dev(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => crate::math::sqrt(self.m2 / n as f64),
        }
    }

    /// Smallest `mean + std` recorded after warm-up, or `f64::MAX` if none has
    /// been recorded yet.
    #[inline]
    pub fn min_p_plus_s(&self) -> f64 {
        self.min_p_plus_s
    }

    /// Clears the running mean, M2 accumulator and sample count while keeping
    /// the recorded minima and the configuration.
    fn reset_running_stats(&mut self) {
        self.count = 0;
        self.m2 = 0.0;
        self.mean = 0.0;
    }
}
impl Default for Ddm {
fn default() -> Self {
Self::new()
}
}
impl core::fmt::Display for Ddm {
    /// Renders the configuration together with the number of samples seen since
    /// the last reset or drift.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            warning_level,
            drift_level,
            min_instances,
            count,
            ..
        } = self;
        write!(
            f,
            "Ddm(warn={}, drift={}, min_inst={}, count={})",
            warning_level, drift_level, min_instances, count
        )
    }
}
impl Ddm {
    /// Feeds one observation and returns the resulting drift signal.
    ///
    /// The running mean and population standard deviation of the samples seen
    /// since the last reset or drift are maintained incrementally (Welford's
    /// update). `p_plus_s = mean + std` is then compared with the smallest
    /// `p_plus_s` recorded so far (`min_p_plus_s`, whose standard deviation was
    /// `min_s`):
    ///
    /// * while at most `min_instances` samples have been seen since the last
    ///   reset or drift, the result is always [`DriftSignal::Stable`] and no
    ///   minimum is recorded;
    /// * [`DriftSignal::Drift`] when `p_plus_s > min_p_plus_s + drift_level * min_s`;
    ///   the running statistics are then cleared, but the recorded minimum is kept;
    /// * [`DriftSignal::Warning`] when `p_plus_s > min_p_plus_s + warning_level * min_s`;
    /// * [`DriftSignal::Stable`] otherwise.
    pub fn update(&mut self, value: f64) -> DriftSignal {
        // Welford's online update of the running mean and sum of squared deviations.
        self.count += 1;
        let n = self.count as f64;
        let delta = value - self.mean;
        self.mean += delta / n;
        let delta2 = value - self.mean;
        self.m2 += delta * delta2;
        let std = crate::math::sqrt(self.m2 / n);
        let p_plus_s = self.mean + std;

        if self.count <= self.min_instances {
            return DriftSignal::Stable;
        }

        if p_plus_s < self.min_p_plus_s {
            self.min_p_plus_s = p_plus_s;
            self.min_s = std;
        }

        // Strict `>`: when `min_s == 0` (e.g. a perfectly constant stream), `>=`
        // would already hold whenever `p_plus_s` merely equals the recorded
        // minimum, signalling a spurious drift on a stationary stream.
        if p_plus_s > self.min_p_plus_s + self.drift_level * self.min_s {
            self.reset_running_stats();
            return DriftSignal::Drift;
        }
        if p_plus_s > self.min_p_plus_s + self.warning_level * self.min_s {
            return DriftSignal::Warning;
        }
        DriftSignal::Stable
    }

    /// Returns the detector to the state produced by `with_params` with the same
    /// thresholds. It clears the running statistics and forgets the recorded minima.
    pub fn reset(&mut self) {
        self.reset_running_stats();
        self.min_p_plus_s = f64::MAX;
        self.min_s = f64::MAX;
    }

    /// Running mean of the samples seen since the last reset or drift
    /// (`0.0` when none have been seen).
    pub fn estimated_mean(&self) -> f64 {
        self.mean
    }
}
#[cfg(feature = "alloc")]
impl super::DriftDetector for Ddm {
    /// Forwards to the inherent [`Ddm::update`].
    fn update(&mut self, value: f64) -> DriftSignal {
        Ddm::update(self, value)
    }

    /// Forwards to the inherent [`Ddm::reset`].
    fn reset(&mut self) {
        Ddm::reset(self);
    }

    /// Returns a boxed detector with the same thresholds and warm-up length but
    /// none of this detector's accumulated statistics.
    fn clone_fresh(&self) -> alloc::boxed::Box<dyn super::DriftDetector> {
        let fresh = Self::with_params(self.warning_level, self.drift_level, self.min_instances);
        alloc::boxed::Box::new(fresh)
    }

    /// Returns a boxed copy of this detector, accumulated statistics included.
    fn clone_boxed(&self) -> alloc::boxed::Box<dyn super::DriftDetector> {
        let copy: Self = self.clone();
        alloc::boxed::Box::new(copy)
    }

    /// Forwards to the inherent [`Ddm::estimated_mean`].
    fn estimated_mean(&self) -> f64 {
        Ddm::estimated_mean(self)
    }

    /// Snapshots the running statistics and recorded minima. The thresholds and
    /// warm-up length are not part of the snapshot.
    fn serialize_state(&self) -> Option<super::DriftDetectorState> {
        let Self {
            mean,
            m2,
            count,
            min_p_plus_s,
            min_s,
            ..
        } = *self;
        Some(super::DriftDetectorState::Ddm {
            mean,
            m2,
            count,
            min_p_plus_s,
            min_s,
        })
    }

    /// Loads a `DriftDetectorState::Ddm` snapshot and returns `true`. For any
    /// other `DriftDetectorState` variant it returns `false` and leaves `self`
    /// unchanged.
    fn restore_state(&mut self, state: &super::DriftDetectorState) -> bool {
        match *state {
            super::DriftDetectorState::Ddm {
                mean,
                m2,
                count,
                min_p_plus_s,
                min_s,
            } => {
                self.mean = mean;
                self.m2 = m2;
                self.count = count;
                self.min_p_plus_s = min_p_plus_s;
                self.min_s = min_s;
                true
            }
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    // `Vec` is taken from `alloc` rather than the std prelude; presumably the crate
    // is `no_std` outside of tests — confirm against the crate root.
    extern crate alloc;
    use alloc::vec::Vec;
    // Brings the trait method `clone_fresh` into scope for the `alloc`-gated test.
    #[cfg(feature = "alloc")]
    use super::super::DriftDetector;
    use super::super::DriftSignal;
    use super::*;

    /// Feeds every value to `ddm` in order and collects the signal returned for each.
    fn feed(ddm: &mut Ddm, values: &[f64]) -> Vec<DriftSignal> {
        values.iter().map(|&v| ddm.update(v)).collect()
    }

    /// Deterministic, noise-like sequence of `n` values: `centre` plus a sinusoid
    /// of amplitude `jitter` advancing 0.7 radians per sample.
    fn generate_values(centre: f64, jitter: f64, n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64;
                centre + jitter * crate::math::sin(t * 0.7)
            })
            .collect()
    }

    // 5000 samples oscillating around 0.1 (±0.02) must never report Drift.
    #[test]
    fn stationary_low_error_no_drift() {
        let mut ddm = Ddm::new();
        let values = generate_values(0.1, 0.02, 5000);
        let signals = feed(&mut ddm, &values);
        let drift_count = signals.iter().filter(|&&s| s == DriftSignal::Drift).count();
        assert_eq!(
            drift_count, 0,
            "stationary low error should produce no drift"
        );
    }

    // A jump from ~0.1 to ~0.8 must produce at least one Drift within the high segment.
    #[test]
    fn error_rate_increase_detects_drift() {
        let mut ddm = Ddm::new();
        let low = generate_values(0.1, 0.01, 2000);
        feed(&mut ddm, &low);
        let high = generate_values(0.8, 0.01, 2000);
        let signals = feed(&mut ddm, &high);
        let drift_count = signals.iter().filter(|&&s| s == DriftSignal::Drift).count();
        assert!(
            drift_count >= 1,
            "sudden error increase should trigger at least one drift, got {}",
            drift_count
        );
    }

    // After a low baseline, a linear ramp from 0.05 towards 0.9 over 3000 samples
    // must raise a Warning strictly before the first Drift.
    #[test]
    fn warning_before_drift() {
        let mut ddm = Ddm::new();
        let baseline = generate_values(0.05, 0.005, 200);
        feed(&mut ddm, &baseline);
        let ramp: Vec<f64> = (0..3000)
            .map(|i| {
                let t = i as f64 / 3000.0;
                0.05 + 0.85 * t
            })
            .collect();
        let signals = feed(&mut ddm, &ramp);
        let first_warning = signals.iter().position(|&s| s == DriftSignal::Warning);
        let first_drift = signals.iter().position(|&s| s == DriftSignal::Drift);
        assert!(
            first_warning.is_some(),
            "gradual increase should trigger at least one warning"
        );
        assert!(
            first_drift.is_some(),
            "gradual increase should eventually trigger drift"
        );
        assert!(
            first_warning.unwrap() < first_drift.unwrap(),
            "warning (idx {}) should fire before drift (idx {})",
            first_warning.unwrap(),
            first_drift.unwrap()
        );
    }

    // The first five updates fall inside the default 30-sample warm-up, so no
    // minimum is recorded yet (`f64::MAX`). Feeding low values afterwards must
    // record a smaller one.
    #[test]
    fn minimum_tracking() {
        let mut ddm = Ddm::new();
        for _ in 0..5 {
            ddm.update(0.5);
        }
        let early_min = ddm.min_p_plus_s();
        let low = generate_values(0.05, 0.005, 200);
        feed(&mut ddm, &low);
        let later_min = ddm.min_p_plus_s();
        assert!(
            later_min < early_min,
            "min_p_plus_s should decrease with low errors: early={}, later={}",
            early_min,
            later_min
        );
    }

    // A 100_000-sample warm-up keeps any drift from resetting the running stats,
    // so the mean is the plain average of all 1000/2000/3000 samples fed.
    #[test]
    fn estimated_mean_tracks_correctly() {
        let mut ddm = Ddm::with_params(2.0, 3.0, 100_000);
        for _ in 0..1000 {
            ddm.update(0.3);
        }
        let mean = ddm.estimated_mean();
        assert!(
            (mean - 0.3).abs() < 1e-10,
            "mean should be ~0.3, got {}",
            mean
        );
        for _ in 0..1000 {
            ddm.update(0.7);
        }
        let mean2 = ddm.estimated_mean();
        assert!(
            (mean2 - 0.5).abs() < 1e-10,
            "mean should be ~0.5, got {}",
            mean2
        );
        for _ in 0..1000 {
            ddm.update(0.2);
        }
        let mean3 = ddm.estimated_mean();
        assert!(
            (mean3 - 0.4).abs() < 1e-10,
            "mean should be ~0.4, got {}",
            mean3
        );
    }

    // `reset` must clear both the running statistics and the recorded minima.
    #[test]
    fn reset_clears_state() {
        let mut ddm = Ddm::new();
        let vals = generate_values(0.4, 0.05, 500);
        feed(&mut ddm, &vals);
        assert!(ddm.count > 0);
        ddm.reset();
        assert_eq!(ddm.count, 0);
        assert_eq!(ddm.mean, 0.0);
        assert_eq!(ddm.m2, 0.0);
        assert_eq!(ddm.min_p_plus_s, f64::MAX);
        assert_eq!(ddm.min_s, f64::MAX);
    }

    // `clone_fresh` on a detector that has already seen data must yield a detector
    // that behaves exactly like one built with `with_params` and the same parameters.
    #[cfg(feature = "alloc")]
    #[test]
    fn clone_fresh_same_params() {
        let ddm = Ddm::with_params(1.5, 2.5, 50);
        let mut dirty = ddm.clone();
        let vals = generate_values(0.3, 0.02, 200);
        feed(&mut dirty, &vals);
        let mut fresh = dirty.clone_fresh();
        assert_eq!(fresh.estimated_mean(), 0.0);
        let mut manual_fresh = Ddm::with_params(1.5, 2.5, 50);
        let test_vals = generate_values(0.2, 0.01, 100);
        let signals_a: Vec<DriftSignal> = test_vals.iter().map(|&v| fresh.update(v)).collect();
        let signals_b: Vec<DriftSignal> =
            test_vals.iter().map(|&v| manual_fresh.update(v)).collect();
        assert_eq!(
            signals_a, signals_b,
            "clone_fresh should behave identically to a new instance with same params"
        );
        assert!(
            (fresh.estimated_mean() - manual_fresh.estimated_mean()).abs() < 1e-12,
            "means should match: {} vs {}",
            fresh.estimated_mean(),
            manual_fresh.estimated_mean()
        );
    }

    // Even a maximally volatile 0/1 stream must report Stable for the first
    // `min_instances` updates.
    #[test]
    fn warmup_no_drift() {
        let min_inst = 50u64;
        let mut ddm = Ddm::with_params(2.0, 3.0, min_inst);
        for i in 0..min_inst {
            let value = if i % 2 == 0 { 0.0 } else { 1.0 };
            let signal = ddm.update(value);
            assert_eq!(
                signal,
                DriftSignal::Stable,
                "during warmup (i={}), signal should be Stable, got {:?}",
                i,
                signal
            );
        }
    }

    // Accessors return exactly what `with_params` was given.
    #[test]
    fn custom_params() {
        let ddm = Ddm::with_params(1.0, 4.0, 100);
        assert_eq!(ddm.warning_level(), 1.0);
        assert_eq!(ddm.drift_level(), 4.0);
        assert_eq!(ddm.min_instances(), 100);
    }

    // `Default` and `new` must agree on the configuration.
    #[test]
    fn default_matches_new() {
        let a = Ddm::new();
        let b = Ddm::default();
        assert_eq!(a.warning_level, b.warning_level);
        assert_eq!(a.drift_level, b.drift_level);
        assert_eq!(a.min_instances, b.min_instances);
    }

    // With no samples, `std_dev` takes its explicit zero-count branch.
    #[test]
    fn std_dev_zero_initially() {
        let ddm = Ddm::new();
        assert_eq!(ddm.std_dev(), 0.0);
    }

    // A constant stream must report (numerically) zero spread.
    #[test]
    fn std_dev_zero_for_constant() {
        let mut ddm = Ddm::new();
        for _ in 0..100 {
            ddm.update(0.5);
        }
        assert!(
            ddm.std_dev().abs() < 1e-12,
            "std of constant stream should be ~0, got {}",
            ddm.std_dev()
        );
    }

    // 100 low samples, then up to 500 high samples until the first Drift. A drift
    // must clear the running stats (count, mean) but keep the recorded minimum.
    #[test]
    fn drift_resets_running_keeps_mins() {
        let mut ddm = Ddm::with_params(2.0, 3.0, 10);
        for _ in 0..100 {
            ddm.update(0.05);
        }
        let min_before = ddm.min_p_plus_s();
        let mut got_drift = false;
        for _ in 0..500 {
            if ddm.update(0.95) == DriftSignal::Drift {
                got_drift = true;
                break;
            }
        }
        assert!(got_drift, "should have triggered drift");
        assert_eq!(ddm.count, 0, "count should be 0 after drift reset");
        assert_eq!(ddm.mean, 0.0, "mean should be 0 after drift reset");
        assert_eq!(
            ddm.min_p_plus_s(),
            min_before,
            "min_p_plus_s should be preserved after drift reset"
        );
    }
}