use crate::error::{Result, TransformError};
#[derive(Debug, Clone)]
struct Bucket {
count: usize,
total: f64,
variance: f64,
}
#[derive(Debug, Clone)]
pub struct Adwin {
delta: f64,
buckets: Vec<Vec<Bucket>>,
max_buckets: usize,
total_count: usize,
total_sum: f64,
total_variance: f64,
last_change_detected: bool,
min_window_length: usize,
}
impl Adwin {
pub fn new(delta: f64) -> Result<Self> {
if delta <= 0.0 || delta >= 1.0 {
return Err(TransformError::InvalidInput(
"delta must be in (0, 1)".to_string(),
));
}
Ok(Self {
delta,
buckets: Vec::new(),
max_buckets: 5, total_count: 0,
total_sum: 0.0,
total_variance: 0.0,
last_change_detected: false,
min_window_length: 10,
})
}
pub fn set_min_window_length(&mut self, min_len: usize) {
self.min_window_length = min_len;
}
pub fn add_element(&mut self, value: f64) -> Result<bool> {
if !value.is_finite() {
return Err(TransformError::InvalidInput(
"Value must be finite".to_string(),
));
}
self.last_change_detected = false;
let new_bucket = Bucket {
count: 1,
total: value,
variance: 0.0,
};
if self.buckets.is_empty() {
self.buckets.push(Vec::new());
}
self.buckets[0].push(new_bucket);
self.total_count += 1;
self.total_sum += value;
self.total_variance += value * value;
self.compress();
if self.total_count >= self.min_window_length {
self.last_change_detected = self.check_and_cut();
}
Ok(self.last_change_detected)
}
fn compress(&mut self) {
let mut level = 0;
while level < self.buckets.len() {
if self.buckets[level].len() > self.max_buckets + 1 {
if self.buckets[level].len() >= 2 {
let b1 = self.buckets[level].remove(0);
let b2 = self.buckets[level].remove(0);
let merged_count = b1.count + b2.count;
let merged_total = b1.total + b2.total;
let delta_mean =
b2.total / b2.count.max(1) as f64 - b1.total / b1.count.max(1) as f64;
let merged_variance = b1.variance
+ b2.variance
+ delta_mean * delta_mean * (b1.count * b2.count) as f64
/ merged_count.max(1) as f64;
let merged = Bucket {
count: merged_count,
total: merged_total,
variance: merged_variance,
};
if level + 1 >= self.buckets.len() {
self.buckets.push(Vec::new());
}
self.buckets[level + 1].push(merged);
}
}
level += 1;
}
}
fn check_and_cut(&mut self) -> bool {
let mut w1_count: usize = 0;
let mut w1_sum: f64 = 0.0;
let mut _w1_var: f64 = 0.0;
let n_levels = self.buckets.len();
let mut ordered_buckets: Vec<(usize, usize)> = Vec::new(); for level in 0..n_levels {
for idx in (0..self.buckets[level].len()).rev() {
ordered_buckets.push((level, idx));
}
}
for &(level, idx) in ordered_buckets.iter() {
let bucket = &self.buckets[level][idx];
w1_count += bucket.count;
w1_sum += bucket.total;
_w1_var += bucket.variance;
let w0_count = self.total_count - w1_count;
if w0_count < 1 || w1_count < 1 {
continue;
}
let w0_sum = self.total_sum - w1_sum;
let mean0 = w0_sum / w0_count as f64;
let mean1 = w1_sum / w1_count as f64;
let diff = (mean0 - mean1).abs();
let n = self.total_count as f64;
let m = (1.0 / w0_count as f64 + 1.0 / w1_count as f64).min(1.0);
let delta_prime = self.delta / n.ln().max(1.0);
let epsilon = ((m / (2.0 * delta_prime)).ln().max(0.0) * m / 2.0).sqrt();
if diff >= epsilon && w0_count >= 2 && w1_count >= 2 {
self.drop_oldest(w0_count);
return true;
}
}
false
}
fn drop_oldest(&mut self, count: usize) {
let mut remaining = count;
let mut level = self.buckets.len();
while level > 0 && remaining > 0 {
level -= 1;
while !self.buckets[level].is_empty() && remaining > 0 {
let bucket = &self.buckets[level][0];
if bucket.count <= remaining {
let removed = self.buckets[level].remove(0);
remaining -= removed.count;
self.total_count -= removed.count;
self.total_sum -= removed.total;
self.total_variance -=
removed.total * removed.total / removed.count.max(1) as f64;
} else {
break;
}
}
}
while let Some(last) = self.buckets.last() {
if last.is_empty() {
self.buckets.pop();
} else {
break;
}
}
}
pub fn detected_change(&self) -> bool {
self.last_change_detected
}
pub fn current_mean(&self) -> f64 {
if self.total_count == 0 {
0.0
} else {
self.total_sum / self.total_count as f64
}
}
pub fn current_length(&self) -> usize {
self.total_count
}
pub fn current_sum(&self) -> f64 {
self.total_sum
}
pub fn delta(&self) -> f64 {
self.delta
}
pub fn reset(&mut self) {
self.buckets.clear();
self.total_count = 0;
self.total_sum = 0.0;
self.total_variance = 0.0;
self.last_change_detected = false;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_adwin_no_change_stable_data() {
let mut adwin = Adwin::new(0.01).expect("valid delta");
let mut any_change = false;
for i in 0..500 {
let val = 5.0 + (i as f64) * 0.0001; let changed = adwin.add_element(val).expect("add");
if changed {
any_change = true;
}
}
let mean = adwin.current_mean();
assert!(
mean > 4.0 && mean < 6.0,
"Mean should be around 5.0: {}",
mean
);
assert!(adwin.current_length() > 0);
let _ = any_change;
}
#[test]
fn test_adwin_detect_abrupt_change() {
let mut adwin = Adwin::new(0.002).expect("valid delta");
adwin.set_min_window_length(5);
for _ in 0..200 {
adwin.add_element(0.0).expect("add");
}
let mut detected = false;
for _ in 0..200 {
let changed = adwin.add_element(100.0).expect("add");
if changed {
detected = true;
break;
}
}
assert!(
detected,
"ADWIN should detect abrupt mean shift from 0 to 100"
);
}
#[test]
fn test_adwin_window_shrinks_on_change() {
let mut adwin = Adwin::new(0.01).expect("valid delta");
adwin.set_min_window_length(5);
for _ in 0..200 {
adwin.add_element(0.0).expect("add");
}
let len_before = adwin.current_length();
assert!(len_before > 100, "Window should have grown: {}", len_before);
for _ in 0..200 {
let changed = adwin.add_element(50.0).expect("add");
if changed {
break;
}
}
let len_after = adwin.current_length();
assert!(
len_after < len_before,
"Window should shrink after drift: {} -> {}",
len_before,
len_after
);
}
#[test]
fn test_adwin_mean_tracking() {
let mut adwin = Adwin::new(0.05).expect("valid delta");
for _ in 0..100 {
adwin.add_element(10.0).expect("add");
}
let mean = adwin.current_mean();
assert!(
(mean - 10.0).abs() < 1.0,
"Mean should be close to 10.0: {}",
mean
);
}
#[test]
fn test_adwin_reset() {
let mut adwin = Adwin::new(0.01).expect("valid delta");
for _ in 0..50 {
adwin.add_element(1.0).expect("add");
}
assert!(adwin.current_length() > 0);
adwin.reset();
assert_eq!(adwin.current_length(), 0);
assert!((adwin.current_mean()).abs() < 1e-15);
}
#[test]
fn test_adwin_invalid_delta() {
assert!(Adwin::new(0.0).is_err());
assert!(Adwin::new(1.0).is_err());
assert!(Adwin::new(-0.5).is_err());
}
#[test]
fn test_adwin_nan_input() {
let mut adwin = Adwin::new(0.01).expect("valid delta");
assert!(adwin.add_element(f64::NAN).is_err());
assert!(adwin.add_element(f64::INFINITY).is_err());
}
#[test]
fn test_adwin_accessors() {
let adwin = Adwin::new(0.05).expect("valid delta");
assert!((adwin.delta() - 0.05).abs() < 1e-15);
assert_eq!(adwin.current_length(), 0);
assert!(!adwin.detected_change());
}
}