use alloc::format;
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;
use crate::error::{RcfError, RcfResult};
/// Exponentially weighted moving mean and variance of a stream of `f64` samples.
///
/// Every accepted sample pulls the mean a fraction `decay` of the way towards
/// itself, so a larger `decay` forgets history faster and `decay == 1.0`
/// tracks only the most recent sample. Non-finite samples are ignored, and
/// updating a valid estimator with finite samples never drives its state to
/// infinity or NaN.
///
/// NOTE(review): with the `serde` feature enabled, deserialization bypasses the
/// `decay` validation performed by [`EmaStats::new`]; confirm that callers
/// only deserialize trusted state or re-validate it.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EmaStats {
    /// Current mean estimate; `0.0` until the first sample is accepted.
    mean: f64,
    /// Current exponentially weighted variance estimate.
    variance: f64,
    /// Weight in `(0.0, 1.0]` given to each new sample.
    decay: f64,
    /// Number of accepted (finite) samples, saturating at `u64::MAX`.
    observations: u64,
}

impl EmaStats {
    /// Creates an empty estimator that weights each new sample by `decay`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] if `decay` is NaN, infinite, or
    /// outside the interval `(0.0, 1.0]`.
    pub fn new(decay: f64) -> RcfResult<Self> {
        if !decay.is_finite() || decay <= 0.0 || decay > 1.0 {
            return Err(RcfError::InvalidConfig(
                format!("EmaStats decay must be in (0.0, 1.0], got {decay}").into(),
            ));
        }
        Ok(Self {
            mean: 0.0,
            variance: 0.0,
            decay,
            observations: 0,
        })
    }

    /// Folds `value` into the running estimates.
    ///
    /// Non-finite values are ignored entirely: they change neither the
    /// estimates nor the observation count. The first accepted sample seeds
    /// the mean exactly with zero variance. Later samples apply the
    /// incremental exponentially weighted update
    /// `mean += decay * delta` and
    /// `variance = (1 - decay) * (variance + decay * delta²)`,
    /// where `delta = value - mean`.
    pub fn update(&mut self, value: f64) {
        // A NaN/inf sample would poison every later estimate; drop it.
        if !value.is_finite() {
            return;
        }
        if self.observations == 0 {
            // Seed from the first sample rather than blending with the zero
            // initial state, which would bias early estimates towards 0.
            self.mean = value;
            self.variance = 0.0;
        } else {
            let delta = value - self.mean;
            let retain = 1.0 - self.decay;

            let stepped = self.mean + self.decay * delta;
            self.mean = if stepped.is_finite() {
                stepped
            } else {
                // `delta` overflowed (finite samples of opposite sign near
                // ±f64::MAX). The weighted-average form never forms the
                // difference; the clamp guards against a final rounding step
                // past f64::MAX.
                (retain * self.mean + self.decay * value).clamp(f64::MIN, f64::MAX)
            };

            self.variance = if retain == 0.0 {
                // decay == 1.0 keeps no history, so the variance is
                // identically zero; computing it would give `0 * inf = NaN`
                // whenever `delta * delta` overflows.
                0.0
            } else {
                // Saturate instead of storing +inf: an infinite variance would
                // never decay back once the stream returns to normal scale.
                (retain * (self.variance + self.decay * delta * delta)).min(f64::MAX)
            };
        }
        self.observations = self.observations.saturating_add(1);
    }

    /// Current mean estimate (`0.0` before any sample has been accepted).
    #[must_use]
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Current variance estimate, clamped so it is never reported as negative.
    #[must_use]
    pub fn variance(&self) -> f64 {
        self.variance.max(0.0)
    }

    /// Standard deviation, i.e. the square root of [`EmaStats::variance`].
    #[must_use]
    pub fn stddev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Number of finite samples accepted so far, saturating at `u64::MAX`.
    #[must_use]
    pub fn observations(&self) -> u64 {
        self.observations
    }

    /// The smoothing factor this estimator was constructed with.
    #[must_use]
    pub fn decay(&self) -> f64 {
        self.decay
    }

    /// Clears the accumulated statistics, keeping the configured `decay`.
    ///
    /// The next accepted sample seeds the mean exactly, as after [`EmaStats::new`].
    pub fn reset(&mut self) {
        self.mean = 0.0;
        self.variance = 0.0;
        self.observations = 0;
    }
}
#[cfg(test)]
// Exact float equality is intentional: every exactly-compared value comes from
// exactly representable arithmetic (seeding, resets, decay == 1.0, halving).
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn new_rejects_non_finite_decay() {
        assert!(EmaStats::new(f64::NAN).is_err());
        assert!(EmaStats::new(f64::INFINITY).is_err());
    }

    #[test]
    fn new_rejects_non_positive_decay() {
        assert!(EmaStats::new(0.0).is_err());
        assert!(EmaStats::new(-0.1).is_err());
    }

    #[test]
    fn new_rejects_decay_above_one() {
        assert!(EmaStats::new(1.001).is_err());
    }

    #[test]
    fn new_accepts_decay_at_one() {
        EmaStats::new(1.0).unwrap();
    }

    #[test]
    fn decay_accessor_returns_configured_value() {
        let s = EmaStats::new(0.25).unwrap();
        assert_eq!(s.decay(), 0.25);
    }

    #[test]
    fn first_update_sets_mean_exactly() {
        let mut s = EmaStats::new(0.1).unwrap();
        s.update(7.0);
        assert_eq!(s.mean(), 7.0);
        assert_eq!(s.variance(), 0.0);
        assert_eq!(s.observations(), 1);
    }

    #[test]
    fn second_update_applies_weighted_step() {
        let mut s = EmaStats::new(0.5).unwrap();
        s.update(0.0);
        s.update(2.0);
        // mean = 0 + 0.5 * 2; variance = 0.5 * (0 + 0.5 * 2 * 2)
        assert_eq!(s.mean(), 1.0);
        assert_eq!(s.variance(), 1.0);
    }

    #[test]
    fn decay_one_tracks_latest_sample_exactly() {
        let mut s = EmaStats::new(1.0).unwrap();
        for v in [3.0, -2.0, 10.5] {
            s.update(v);
            assert_eq!(s.mean(), v);
            assert_eq!(s.variance(), 0.0);
        }
    }

    #[test]
    fn non_finite_update_is_ignored() {
        let mut s = EmaStats::new(0.1).unwrap();
        s.update(f64::NAN);
        s.update(f64::INFINITY);
        assert_eq!(s.observations(), 0);
        assert_eq!(s.mean(), 0.0);
    }

    #[test]
    fn mean_tracks_constant_stream_with_zero_variance() {
        let mut s = EmaStats::new(0.1).unwrap();
        for _ in 0..1000 {
            s.update(5.0);
        }
        assert!((s.mean() - 5.0).abs() < 1e-9);
        assert!(s.variance() < 1e-12);
    }

    #[test]
    fn variance_tracks_spread() {
        let mut s = EmaStats::new(0.05).unwrap();
        for i in 0..5_000 {
            let v = if i % 2 == 0 { 1.0 } else { -1.0 };
            s.update(v);
        }
        assert!(s.mean().abs() < 0.1);
        assert!(s.stddev() > 0.5);
        assert!(s.stddev() < 1.5);
    }

    #[test]
    fn reset_clears_state() {
        let mut s = EmaStats::new(0.1).unwrap();
        for i in 0..10 {
            s.update(f64::from(i));
        }
        assert!(s.observations() > 0);
        s.reset();
        assert_eq!(s.mean(), 0.0);
        assert_eq!(s.variance(), 0.0);
        assert_eq!(s.observations(), 0);
    }

    #[test]
    fn reset_keeps_decay_and_reseeds_on_next_update() {
        let mut s = EmaStats::new(0.25).unwrap();
        s.update(10.0);
        s.update(-3.0);
        s.reset();
        assert_eq!(s.decay(), 0.25);
        s.update(4.0);
        assert_eq!(s.mean(), 4.0);
        assert_eq!(s.variance(), 0.0);
        assert_eq!(s.observations(), 1);
    }

    #[test]
    fn observations_saturates_at_u64_max() {
        let mut s = EmaStats::new(1.0).unwrap();
        s.observations = u64::MAX;
        s.update(1.0);
        assert_eq!(s.observations(), u64::MAX);
    }

    #[test]
    fn variance_is_never_negative() {
        let mut s = EmaStats::new(0.5).unwrap();
        s.update(1e-300);
        s.update(-1e-300);
        assert!(s.variance() >= 0.0);
    }
}