use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use super::{DriftDetector, DriftSignal};
/// Summary of a run of consecutive observations held in the ADWIN window.
///
/// New observations enter as singleton buckets; `Bucket::merge` pools two
/// buckets into one when a row overflows.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Bucket {
/// Sum of the observations summarised by this bucket.
total: f64,
/// Un-normalised sum of squared deviations from the bucket mean
/// (0 for a singleton; pooled by `Bucket::merge`).
variance: f64,
/// Number of observations summarised by this bucket.
count: u64,
}
impl Bucket {
#[inline]
fn singleton(value: f64) -> Self {
Self {
total: value,
variance: 0.0,
count: 1,
}
}
fn merge(a: &Bucket, b: &Bucket) -> Self {
let count = a.count + b.count;
let total = a.total + b.total;
let mean_a = a.total / a.count as f64;
let mean_b = b.total / b.count as f64;
let diff = mean_a - mean_b;
let variance = a.variance
+ b.variance
+ diff * diff * (a.count as f64) * (b.count as f64) / (count as f64);
Self {
total,
variance,
count,
}
}
}
/// Adaptive-windowing (ADWIN-style) drift detector over a stream of `f64` values.
///
/// The window is stored as rows of buckets: row 0 receives the newest
/// singleton buckets, and `compress` merges the two oldest buckets of an
/// overflowing row into the next row, so higher rows hold older, larger
/// buckets. On each update the window is scanned for a split whose two halves
/// have significantly different means; when one is found the older part is
/// dropped.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Adwin {
/// Confidence parameter of the cut test; `with_delta` requires it in (0, 1).
delta: f64,
/// Maximum number of buckets a row may hold after `compress`.
max_buckets: usize,
/// Bucket rows; row 0 is newest and, within a row, newer buckets are at the end.
rows: Vec<Vec<Bucket>>,
/// Sum of all observations currently in the window.
total: f64,
/// Running sum of squared deviations over the window (Welford-style update on
/// insert, recomputed from the buckets after the window shrinks).
variance: f64,
/// Number of observations currently in the window.
count: u64,
/// Current window length; updated alongside `count`.
width: u64,
/// Drift/warning checks are skipped while `width <= min_window`.
min_window: u64,
}
/// One candidate cut of the window into an older ("left") and a newer
/// ("right") sub-window, produced while scanning buckets newest-first.
#[derive(Debug, Clone, Copy)]
struct SplitPoint {
    /// Index, in newest-first bucket order, of the oldest bucket that is still
    /// on the newer side of the cut.
    pos: usize,
    /// `|mean(left) - mean(right)|`.
    gap: f64,
    /// Number of observations on the older side.
    n_left: f64,
    /// Number of observations on the newer side.
    n_right: f64,
}

impl SplitPoint {
    /// Whether the mean gap reaches the Hoeffding-style threshold
    /// `sqrt(1 / (2m) * ln_term)` with `m = 1 / (1/n_left + 1/n_right)`.
    ///
    /// `ln_term` is `ln(4 / delta')`, precomputed once per scan.
    #[inline]
    fn exceeds(&self, ln_term: f64) -> bool {
        let m = 1.0 / (1.0 / self.n_left + 1.0 / self.n_right);
        self.gap >= crate::math::sqrt((1.0 / (2.0 * m)) * ln_term)
    }
}

impl Adwin {
    /// Creates a detector with `delta = 0.002`, `max_buckets = 5` and
    /// `min_window = 32`.
    pub fn new() -> Self {
        Self::with_delta(0.002)
    }

    /// Creates a detector with confidence parameter `delta`, `max_buckets = 5`
    /// and `min_window = 32`.
    ///
    /// A smaller `delta` raises the cut threshold, so drift is reported less
    /// readily.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < delta < 1` (a NaN `delta` is rejected too).
    pub fn with_delta(delta: f64) -> Self {
        assert!(
            delta > 0.0 && delta < 1.0,
            "delta must be in (0, 1), got {delta}"
        );
        Self {
            delta,
            max_buckets: 5,
            rows: Vec::new(),
            total: 0.0,
            variance: 0.0,
            count: 0,
            width: 0,
            min_window: 32,
        }
    }

    /// Sets how many buckets a row may hold before its two oldest buckets are
    /// merged into the next row.
    ///
    /// # Panics
    ///
    /// Panics if `m < 2`.
    pub fn with_max_buckets(mut self, m: usize) -> Self {
        assert!(m >= 2, "max_buckets must be >= 2, got {m}");
        self.max_buckets = m;
        self
    }

    /// Sets the warm-up length: no drift or warning is reported while the
    /// window holds `min` or fewer observations.
    pub fn with_min_window(mut self, min: u64) -> Self {
        self.min_window = min;
        self
    }

    /// Number of observations currently in the window.
    #[inline]
    pub fn width(&self) -> u64 {
        self.width
    }

    /// Appends `value` as a new singleton bucket in row 0 and updates the
    /// window totals and the running squared-deviation sum.
    fn insert_bucket(&mut self, value: f64) {
        let bucket = Bucket::singleton(value);
        match self.rows.first_mut() {
            Some(newest) => newest.push(bucket),
            None => self.rows.push(vec![bucket]),
        }
        self.count += 1;
        self.width += 1;
        // Welford update: the mean before this value, then after it.
        let old_mean = if self.count > 1 {
            self.total / (self.count - 1) as f64
        } else {
            0.0
        };
        self.total += value;
        let new_mean = self.total / self.count as f64;
        self.variance += (value - old_mean) * (value - new_mean);
    }

    /// Restores the row capacity invariant. While a row holds more than
    /// `max_buckets` buckets, its two oldest buckets are merged and pushed as
    /// the newest bucket of the next row.
    fn compress(&mut self) {
        let max = self.max_buckets;
        let mut row_idx = 0;
        while self.rows.get(row_idx).is_some_and(|row| row.len() > max) {
            let row = &mut self.rows[row_idx];
            // Buckets are appended at the end, so the two oldest are at the front.
            let merged = Bucket::merge(&row[0], &row[1]);
            row.drain(..2);
            let next = row_idx + 1;
            if next == self.rows.len() {
                self.rows.push(Vec::new());
            }
            self.rows[next].push(merged);
            row_idx = next;
        }
    }

    /// Iterates over all buckets from newest to oldest: row 0 first and,
    /// inside each row, from the end backwards.
    fn buckets_newest_first(&self) -> impl Iterator<Item = &Bucket> + '_ {
        self.rows.iter().flat_map(|row| row.iter().rev())
    }

    /// Builds the split point whose newer side holds `right_count`
    /// observations summing to `right_total`.
    ///
    /// Returns `None` when either side would be empty.
    fn split_at(&self, pos: usize, right_count: u64, right_total: f64) -> Option<SplitPoint> {
        // `saturating_sub` avoids wrap-around if a restored state holds more
        // items in its buckets than `count` says.
        let left_count = self.count.saturating_sub(right_count);
        if left_count == 0 || right_count == 0 {
            return None;
        }
        let mean_left = (self.total - right_total) / left_count as f64;
        let mean_right = right_total / right_count as f64;
        Some(SplitPoint {
            pos,
            gap: (mean_left - mean_right).abs(),
            n_left: left_count as f64,
            n_right: right_count as f64,
        })
    }

    /// All valid cut points, from the smallest newer side to the largest.
    fn split_points(&self) -> impl Iterator<Item = SplitPoint> + '_ {
        self.buckets_newest_first()
            .scan((0u64, 0.0f64), |acc, bucket| {
                acc.0 += bucket.count;
                acc.1 += bucket.total;
                Some(*acc)
            })
            .enumerate()
            .filter_map(move |(pos, (right_count, right_total))| {
                self.split_at(pos, right_count, right_total)
            })
    }

    /// Computes `ln(4 / delta')` with `delta' = delta / ln(width)`.
    ///
    /// Returns `None` when `ln(width)` is not positive.
    fn ln_confidence_term(&self, delta: f64) -> Option<f64> {
        let ln_width = crate::math::ln(self.width as f64);
        if ln_width <= 0.0 {
            return None;
        }
        Some(crate::math::ln(4.0 / (delta / ln_width)))
    }

    /// Scans the window for a significant mean change and returns
    /// `(drift, warning)`.
    ///
    /// `drift` uses `delta`. `warning` uses the looser `2 * delta`, and is
    /// always `true` when `drift` is. Both are `false` during warm-up
    /// (`width <= min_window`) or while `width < 4`.
    fn check_drift(&self) -> (bool, bool) {
        if self.width <= self.min_window || self.width < 4 {
            return (false, false);
        }
        let (Some(ln_drift), Some(ln_warn)) = (
            self.ln_confidence_term(self.delta),
            self.ln_confidence_term(2.0 * self.delta),
        ) else {
            return (false, false);
        };
        let mut warning = false;
        for split in self.split_points() {
            if split.exceeds(ln_drift) {
                return (true, true);
            }
            warning |= split.exceeds(ln_warn);
        }
        (false, warning)
    }

    /// Drops older data after a detected drift.
    ///
    /// The first cut (smallest newer side) that passes the drift test is
    /// located. Only the buckets on its newer side are kept, and the totals,
    /// count, width and squared-deviation sum are rebuilt from them.
    fn shrink_window(&mut self) {
        let Some(ln_drift) = self.ln_confidence_term(self.delta) else {
            return;
        };
        let Some(cut) = self.split_points().find(|split| split.exceeds(ln_drift)) else {
            return;
        };
        let keep = cut.pos + 1;
        let bucket_total: usize = self.rows.iter().map(Vec::len).sum();
        if keep >= bucket_total {
            return;
        }
        self.retain_newest(keep);
        self.total = self.rows.iter().flatten().fold(0.0, |acc, b| acc + b.total);
        self.count = self.rows.iter().flatten().map(|b| b.count).sum();
        self.width = self.count;
        self.recompute_variance();
    }

    /// Keeps only the `keep` newest buckets (in newest-first order) and
    /// removes any rows left empty at the tail.
    fn retain_newest(&mut self, mut keep: usize) {
        let mut row_idx = 0;
        while row_idx < self.rows.len() && keep > 0 {
            let row = &mut self.rows[row_idx];
            if row.len() > keep {
                // Newer buckets sit at the end of a row, so drop from the front.
                let excess = row.len() - keep;
                row.drain(..excess);
                keep = 0;
            } else {
                keep -= row.len();
            }
            row_idx += 1;
        }
        self.rows.truncate(row_idx);
        while self.rows.last().is_some_and(Vec::is_empty) {
            self.rows.pop();
        }
    }

    /// Recomputes the window's squared-deviation sum by pooling every bucket,
    /// oldest first (highest row first, front of each row first).
    fn recompute_variance(&mut self) {
        if self.count == 0 {
            self.variance = 0.0;
            return;
        }
        let pooled = self
            .rows
            .iter()
            .rev()
            .flatten()
            .fold(None::<Bucket>, |acc, bucket| {
                Some(match acc {
                    Some(running) if running.count > 0 => Bucket::merge(&running, bucket),
                    _ => bucket.clone(),
                })
            });
        self.variance = pooled.map_or(0.0, |b| b.variance);
    }
}
impl Default for Adwin {
fn default() -> Self {
Self::new()
}
}
impl core::fmt::Display for Adwin {
    /// One-line summary of the detector: its `delta`, current window width,
    /// and estimated mean to six decimal places.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let delta = self.delta;
        let width = self.width;
        let mean = self.estimated_mean();
        write!(f, "Adwin(delta={}, width={}, mean={:.6})", delta, width, mean)
    }
}
impl DriftDetector for Adwin {
    /// Feeds one observation and reports the resulting signal.
    ///
    /// The value is inserted and the bucket rows are compressed before the
    /// window is checked. On `Drift`, `shrink_window` is called before
    /// returning.
    fn update(&mut self, value: f64) -> DriftSignal {
        self.insert_bucket(value);
        self.compress();
        let (drift, warning) = self.check_drift();
        if drift {
            self.shrink_window();
            DriftSignal::Drift
        } else if warning {
            DriftSignal::Warning
        } else {
            DriftSignal::Stable
        }
    }

    /// Empties the window. The configuration (`delta`, `max_buckets`,
    /// `min_window`) is left unchanged.
    fn reset(&mut self) {
        self.rows.clear();
        self.total = 0.0;
        self.variance = 0.0;
        self.count = 0;
        self.width = 0;
    }

    /// Returns an empty detector with the same configuration, including
    /// `min_window`. Previously the warm-up length was silently reset to 32.
    fn clone_fresh(&self) -> Box<dyn DriftDetector> {
        Box::new(
            Self::with_delta(self.delta)
                .with_max_buckets(self.max_buckets)
                .with_min_window(self.min_window),
        )
    }

    /// Returns a boxed copy of this detector, window state included.
    fn clone_boxed(&self) -> Box<dyn DriftDetector> {
        Box::new(self.clone())
    }

    /// Mean of the observations currently in the window (0 when empty).
    fn estimated_mean(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total / self.count as f64
        }
    }

    /// Snapshots the window contents: the bucket rows, total, squared-deviation
    /// sum, count and width. The configuration is not part of the snapshot.
    fn serialize_state(&self) -> Option<super::DriftDetectorState> {
        use super::{AdwinBucketState, DriftDetectorState};
        let rows = self
            .rows
            .iter()
            .map(|row| {
                row.iter()
                    .map(|b| AdwinBucketState {
                        total: b.total,
                        variance: b.variance,
                        count: b.count,
                    })
                    .collect()
            })
            .collect();
        Some(DriftDetectorState::Adwin {
            rows,
            total: self.total,
            variance: self.variance,
            count: self.count,
            width: self.width,
        })
    }

    /// Replaces the window contents with a snapshot from `serialize_state`.
    ///
    /// Returns `false`, leaving `self` untouched, when `state` is not the
    /// `Adwin` variant. This detector keeps its own configuration.
    fn restore_state(&mut self, state: &super::DriftDetectorState) -> bool {
        if let super::DriftDetectorState::Adwin {
            rows,
            total,
            variance,
            count,
            width,
        } = state
        {
            self.rows = rows
                .iter()
                .map(|row| {
                    row.iter()
                        .map(|b| Bucket {
                            total: b.total,
                            variance: b.variance,
                            count: b.count,
                        })
                        .collect()
                })
                .collect();
            self.total = *total;
            self.variance = *variance;
            self.count = *count;
            self.width = *width;
            true
        } else {
            false
        }
    }
}
// Unit tests for the ADWIN detector. Randomness comes from a small seeded
// in-file PRNG, so every run sees the same input sequence.
#[cfg(test)]
mod tests {
use super::super::{DriftDetector, DriftSignal};
use super::*;
use alloc::vec::Vec;
// xorshift64 PRNG (shifts 13 / 7 / 17); deterministic, no external crates.
struct Xorshift64(u64);
impl Xorshift64 {
// A zero state would stay zero forever under xorshift, so seed 0 maps to 1.
fn new(seed: u64) -> Self {
Self(if seed == 0 { 1 } else { seed })
}
fn next_u64(&mut self) -> u64 {
let mut x = self.0;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
self.0 = x;
x
}
// Uniform in [0, 1): keeps the top 53 bits and scales by 2^-53.
fn next_f64(&mut self) -> f64 {
(self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
}
// Box-Muller transform; u1 is clamped away from 0 so ln(u1) stays finite.
fn next_normal(&mut self, mean: f64, std: f64) -> f64 {
let u1 = self.next_f64().max(1e-15); let u2 = self.next_f64();
let z = crate::math::sqrt(-2.0 * crate::math::ln(u1))
* crate::math::cos(2.0 * core::f64::consts::PI * u2);
mean + std * z
}
}
// A stationary, low-noise stream should almost never report drift.
#[test]
fn no_false_alarm_constant() {
let mut det = Adwin::with_delta(0.002);
let mut rng = Xorshift64::new(42);
let mut drift_count = 0;
for _ in 0..10_000 {
let value = 0.5 + rng.next_normal(0.0, 0.01);
if det.update(value) == DriftSignal::Drift {
drift_count += 1;
}
}
assert!(
drift_count <= 2,
"Too many false alarms on constant distribution: {drift_count}"
);
}
// A jump of the mean from 0 to 5 must be detected within 500 samples.
#[test]
fn detect_abrupt_shift() {
let mut det = Adwin::with_delta(0.002);
let mut rng = Xorshift64::new(123);
for _ in 0..2000 {
det.update(rng.next_normal(0.0, 0.1));
}
let mut detected = false;
for i in 0..2000 {
let sig = det.update(rng.next_normal(5.0, 0.1));
if sig == DriftSignal::Drift {
detected = true;
assert!(
i < 500,
"Drift detected too late after abrupt shift: sample {i}"
);
break;
}
}
assert!(detected, "Failed to detect abrupt mean shift from 0 to 5");
}
// A linear ramp of the mean (0 -> 5 over 5000 samples) must eventually be detected.
#[test]
fn detect_gradual_shift() {
let mut det = Adwin::with_delta(0.01); let mut rng = Xorshift64::new(456);
let mut detected = false;
for i in 0..5000 {
let mean = 5.0 * (i as f64) / 5000.0;
let value = rng.next_normal(mean, 0.1);
if det.update(value) == DriftSignal::Drift {
detected = true;
break;
}
}
assert!(detected, "Failed to detect gradual mean shift from 0 to 5");
}
// The mean is exact on a constant stream and stays within the observed range after a shift.
#[test]
fn estimated_mean_tracks() {
let mut det = Adwin::new();
for _ in 0..100 {
det.update(3.0);
}
let mean = det.estimated_mean();
assert!((mean - 3.0).abs() < 1e-9, "Expected mean ~3.0, got {mean}");
for _ in 0..100 {
det.update(7.0);
}
let mean = det.estimated_mean();
assert!(
mean > 2.5 && mean < 7.5,
"Mean out of expected range: {mean}"
);
}
// reset() must clear every piece of window state.
#[test]
fn reset_clears_state() {
let mut det = Adwin::new();
for _ in 0..500 {
det.update(1.0);
}
assert!(det.width() > 0);
assert!(det.estimated_mean() > 0.0);
det.reset();
assert_eq!(det.width(), 0);
assert_eq!(det.estimated_mean(), 0.0);
assert!(det.rows.is_empty());
assert_eq!(det.count, 0);
assert_eq!(det.total, 0.0);
assert_eq!(det.variance, 0.0);
}
// clone_fresh() yields an empty detector that still behaves sanely on a constant stream.
#[test]
fn clone_fresh_preserves_config() {
let mut det = Adwin::with_delta(0.05).with_max_buckets(7);
for _ in 0..200 {
det.update(42.0);
}
let fresh = det.clone_fresh();
assert_eq!(fresh.estimated_mean(), 0.0);
let mut fresh = fresh;
let mut drifts = 0;
for _ in 0..1000 {
if fresh.update(1.0) == DriftSignal::Drift {
drifts += 1;
}
}
assert!(
drifts <= 1,
"clone_fresh produced detector with too many false alarms: {drifts}"
);
}
// With min_window = 100, the first 100 samples never report drift, even across a huge jump.
#[test]
fn warmup_suppresses_early_detection() {
let mut det = Adwin::with_delta(0.002).with_min_window(100);
let mut any_drift = false;
for _ in 0..50 {
if det.update(0.0) == DriftSignal::Drift {
any_drift = true;
}
}
for _ in 0..50 {
if det.update(100.0) == DriftSignal::Drift {
any_drift = true;
}
}
assert!(
!any_drift,
"Drift should not fire before min_window=100 samples"
);
}
// Rows stay bounded after many inserts. compress() leaves every row at most
// max_buckets long; the `+ 1` below is extra slack in the assertion.
#[test]
fn compression_bounds_memory() {
let mut det = Adwin::with_delta(0.002).with_max_buckets(5);
for i in 0..10_000 {
det.update(i as f64);
}
for (row_idx, row) in det.rows.iter().enumerate() {
assert!(
row.len() <= det.max_buckets + 1, "Row {row_idx} has {} buckets, exceeding max {}",
row.len(),
det.max_buckets
);
}
let total_buckets: usize = det.rows.iter().map(|r| r.len()).sum();
let expected_max = det.rows.len() * (det.max_buckets + 1);
assert!(
total_buckets <= expected_max,
"Total buckets {total_buckets} exceeds expected max {expected_max}"
);
}
// After a detected drift, the window must be shorter than it was before the shift.
#[test]
fn window_shrinks_on_drift() {
let mut det = Adwin::with_delta(0.002);
for _ in 0..2000 {
det.update(0.0);
}
let width_before = det.width();
assert!(
width_before >= 1900,
"Expected large window, got {width_before}"
);
let mut drifted = false;
for _ in 0..500 {
if det.update(100.0) == DriftSignal::Drift {
drifted = true;
break;
}
}
assert!(drifted, "Expected drift on extreme shift");
let width_after = det.width();
assert!(
width_after < width_before,
"Window should shrink after drift: before={width_before}, after={width_after}"
);
}
// NOTE(review): `_saw_warning` is recorded but never asserted, so this test
// only checks that drift is detected, not that a warning precedes it.
#[test]
fn warning_precedes_drift() {
let mut det = Adwin::with_delta(0.002);
let mut rng = Xorshift64::new(789);
for _ in 0..1000 {
det.update(rng.next_normal(0.0, 0.1));
}
let mut _saw_warning = false;
let mut saw_drift = false;
for _ in 0..2000 {
let sig = det.update(rng.next_normal(2.0, 0.1));
match sig {
DriftSignal::Warning => {
if !saw_drift {
_saw_warning = true;
}
}
DriftSignal::Drift => {
saw_drift = true;
break;
}
_ => {}
}
}
assert!(saw_drift, "Should have detected drift on shift from 0 to 2");
}
// Two detectors fed the same sequence must emit identical signal streams.
#[test]
fn deterministic_for_same_input() {
let values: Vec<f64> = (0..500)
.map(|i| crate::math::sin(i as f64 * 0.01))
.collect();
let mut det1 = Adwin::with_delta(0.01);
let mut det2 = Adwin::with_delta(0.01);
let signals1: Vec<DriftSignal> = values.iter().map(|&v| det1.update(v)).collect();
let signals2: Vec<DriftSignal> = values.iter().map(|&v| det2.update(v)).collect();
assert_eq!(
signals1, signals2,
"Same input must produce identical signals"
);
}
}