use tokio::time::{Duration, Instant};
/// Decay policy used by [`ExpWeightedMovingAvg`].
#[derive(Debug, Clone)]
pub enum ExpWeightedMovingAvgMode {
    /// Accumulated state halves every `half_life_secs` seconds of elapsed time,
    /// independent of how much weight each update carries.
    TimeDecay {
        /// Half-life in seconds; the `ExpWeightedMovingAvg` constructor requires
        /// it to be finite and strictly positive.
        half_life_secs: f64,
        /// Moment of the most recent update (construction time before the first one).
        last_update: Instant,
    },
    /// Accumulated state halves once `half_life_count` units of new weight have
    /// been added (one unit per sample when using [`ExpWeightedMovingAvg::update`]).
    CountDecay {
        /// Half-life in units of weight; the `ExpWeightedMovingAvg` constructor
        /// requires it to be finite and strictly positive.
        half_life_count: f64,
    },
}

/// Exponentially weighted moving average whose history decays either with
/// elapsed time or with the amount of new weight added.
///
/// Internally it keeps two decayed sums, a numerator (`value`) and a
/// denominator (`weight`). [`value`](Self::value) reports their ratio. Both
/// sums are scaled by the same decay factor on every update.
#[derive(Debug, Clone)]
pub struct ExpWeightedMovingAvg {
    /// How accumulated state decays between updates.
    mode: ExpWeightedMovingAvgMode,
    /// Decayed sum of update weights (denominator).
    weight: f64,
    /// Decayed sum of samples (numerator); samples are added as-is, not scaled by weight.
    value: f64,
}

impl ExpWeightedMovingAvg {
    /// Creates an average whose accumulated state halves every `half_life` of
    /// elapsed time.
    ///
    /// Elapsed time is read with `Instant::now()`. This module uses tokio's
    /// `Instant`, so the decay follows `tokio::time::pause`/`advance` in tests.
    ///
    /// # Panics
    /// Panics if `half_life` is zero.
    pub fn new_time_decay(half_life: Duration) -> Self {
        let hl_secs = half_life.as_secs_f64();
        assert!(hl_secs.is_finite() && hl_secs > 0.0, "half-life must be positive");
        Self {
            mode: ExpWeightedMovingAvgMode::TimeDecay {
                half_life_secs: hl_secs,
                last_update: Instant::now(),
            },
            weight: 0.0,
            value: 0.0,
        }
    }

    /// Creates an average whose accumulated state halves after
    /// `half_life_count` units of newly added weight. With [`update`](Self::update),
    /// this is a half-life measured in samples.
    ///
    /// # Panics
    /// Panics if `half_life_count` is not finite or not strictly positive.
    pub fn new_count_decay(half_life_count: f64) -> Self {
        assert!(half_life_count.is_finite() && half_life_count > 0.0, "half-life must be positive");
        Self {
            mode: ExpWeightedMovingAvgMode::CountDecay { half_life_count },
            weight: 0.0,
            value: 0.0,
        }
    }

    /// Decays the accumulated state, then adds `sample` to the numerator and
    /// `weight` to the denominator.
    ///
    /// `sample` is added as-is, *not* multiplied by `weight`. The result is
    /// therefore a ratio of decayed sums. For example,
    /// `update_with_weight(2000.0, 0.2)` on an empty average yields
    /// `value() == 10_000.0`, which lets the type track a rate (an amount per
    /// unit of weight). To get a weighted mean of values, pass `x * w` as
    /// `sample`.
    ///
    /// The decay factor depends on the mode:
    /// - Time decay: `2^(-elapsed_secs / half_life_secs)`, which does not depend on `weight`.
    /// - Count decay: `2^(-weight / half_life_count)`.
    ///
    /// # Panics
    /// In builds with debug assertions, panics if `weight` is negative or not
    /// finite. A NaN weight would poison both sums permanently. A negative
    /// weight in count mode would grow the history instead of decaying it.
    pub fn update_with_weight(&mut self, sample: f64, weight: f64) {
        debug_assert!(
            weight.is_finite() && weight >= 0.0,
            "weight must be finite and non-negative"
        );
        let decay = match &mut self.mode {
            ExpWeightedMovingAvgMode::TimeDecay {
                half_life_secs,
                last_update,
            } => {
                let now = Instant::now();
                // Saturate so an out-of-order clock reading can never panic; it just means no decay.
                let dt_secs = now.saturating_duration_since(*last_update).as_secs_f64();
                *last_update = now;
                (-dt_secs / *half_life_secs).exp2()
            }
            ExpWeightedMovingAvgMode::CountDecay { half_life_count } => {
                // Adding `weight` ages existing state by weight / half_life_count half-lives.
                (-weight / *half_life_count).exp2()
            }
        };
        self.weight = self.weight * decay + weight;
        self.value = self.value * decay + sample;
    }

    /// Adds `sample` with unit weight (a plain moving-average update).
    pub fn update(&mut self, sample: f64) {
        self.update_with_weight(sample, 1.0);
    }

    /// Returns the decayed sum of samples divided by the decayed sum of
    /// weights. Returns `0.0` when the accumulated weight is exactly zero, for
    /// example before any weighted update.
    #[must_use]
    pub fn value(&self) -> f64 {
        if self.weight == 0.0 {
            0.0
        } else {
            self.value / self.weight
        }
    }
}
#[cfg(test)]
mod tests {
    use tokio::time::{Duration, advance, pause};
    use super::*;

    /// Absolute tolerance used when comparing against closed-form expectations.
    const EPSILON: f64 = 1e-9;

    /// Asserts that `actual` is within `EPSILON` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[tokio::test]
    async fn ewma_decays_with_simulated_time() {
        pause();
        let half_life = Duration::from_secs(2);
        let mut avg = ExpWeightedMovingAvg::new_time_decay(half_life);
        avg.update(10.0);
        assert_eq!(avg.value(), 10.0);
        // Exactly one half-life elapses, so the old sample's weight drops to 0.5.
        advance(half_life).await;
        avg.update(0.0);
        assert_close(avg.value(), (10.0 * 0.5) / (0.5 + 1.0));
    }

    #[tokio::test]
    async fn ewma_multiple_advances() {
        pause();
        let mut avg = ExpWeightedMovingAvg::new_time_decay(Duration::from_secs(4));
        avg.update(8.0);
        advance(Duration::from_secs(2)).await;
        avg.update(8.0);
        // Two identical samples average to that sample regardless of decay.
        assert_close(avg.value(), 8.0);
        advance(Duration::from_secs(4)).await;
        avg.update(0.0);
        // Half a half-life passes between the first two updates, a full one before the last.
        let d1 = (-0.5_f64).exp2();
        let d2 = 0.5;
        let expected = (8.0 * d1 + 8.0) * d2 / ((d1 + 1.0) * d2 + 1.0);
        assert_close(avg.value(), expected);
    }

    #[test]
    fn ewma_count_decay_basic() {
        let half_life_count = 2.0;
        let mut avg = ExpWeightedMovingAvg::new_count_decay(half_life_count);
        avg.update(10.0);
        assert_eq!(avg.value(), 10.0);
        avg.update(0.0);
        // Each unit-weight update ages prior state by 1 / half_life_count half-lives.
        let decay_factor = (-1.0_f64 / half_life_count).exp2();
        assert_close(avg.value(), (10.0 * decay_factor) / (decay_factor + 1.0));
    }

    #[test]
    fn ewma_count_decay_multiple_samples() {
        let mut avg = ExpWeightedMovingAvg::new_count_decay(4.0);
        for _ in 0..3 {
            avg.update(8.0);
        }
        assert_close(avg.value(), 8.0);
        avg.update(0.0);
        // The three 8s have aged by 3, 2 and 1 updates; the trailing zero adds weight only.
        let d = (-0.25_f64).exp2();
        let expected = 8.0 * (d + d * d + d * d * d) / (1.0 + d + d * d + d * d * d);
        assert_close(avg.value(), expected);
    }

    #[tokio::test]
    async fn ewma_time_decay_weighted_rate() {
        pause();
        let half_life = Duration::from_secs(10);
        let mut avg = ExpWeightedMovingAvg::new_time_decay(half_life);
        // 2000 units per 0.2 s of weight is a rate of 10_000 per unit weight.
        advance(Duration::from_millis(200)).await;
        avg.update_with_weight(2000.0, 0.2);
        assert_close(avg.value(), 10_000.0);
        advance(Duration::from_millis(200)).await;
        avg.update_with_weight(2000.0, 0.2);
        assert_close(avg.value(), 10_000.0);
        advance(Duration::from_millis(200)).await;
        avg.update_with_weight(0.0, 0.2);
        // The numerator holds two decayed 2000s; the denominator holds three 0.2 weights.
        let d = (-0.2_f64 / 10.0).exp2();
        let expected = 10_000.0 * (d * d + d) / (d * d + d + 1.0);
        assert_close(avg.value(), expected);
    }

    #[tokio::test]
    async fn ewma_time_decay_half_life_independent_of_weight() {
        pause();
        let half_life = Duration::from_secs(10);
        let mut avg = ExpWeightedMovingAvg::new_time_decay(half_life);
        avg.update_with_weight(100.0, 0.5);
        advance(half_life).await;
        avg.update_with_weight(0.0, 0.5);
        // One half-life halves both sums (100 -> 50, 0.5 -> 0.25), whatever the weights are.
        assert_close(avg.value(), 50.0 / 0.75);
    }

    #[test]
    fn ewma_count_decay_half_life() {
        let half_life_count = 10.0;
        let mut avg = ExpWeightedMovingAvg::new_count_decay(half_life_count);
        avg.update(100.0);
        assert_eq!(avg.value(), 100.0);
        for _ in 0..9 {
            avg.update(0.0);
        }
        // After 9 more unit updates the first sample carries d^9.
        // The denominator is the geometric series 1 + d + ... + d^9.
        let d = (-1.0_f64 / half_life_count).exp2();
        let weight_sum: f64 = (0..10).map(|k| d.powi(k)).sum();
        assert_close(avg.value(), 100.0 * d.powi(9) / weight_sum);
    }
}