/// Exponentially weighted moving-average (EWMA) error metrics for regression.
///
/// Each call to [`update`](Self::update) folds one `(target, prediction)` pair
/// into running EWMAs of the absolute error, the squared error, the target,
/// and the squared deviation of the target from its running mean. The first
/// accepted sample seeds the averages directly; every later sample is blended
/// in with weight `alpha`.
#[derive(Debug, Clone)]
pub struct EwmaRegressionMetrics {
    /// Weight given to the newest sample, in `(0, 1]`.
    alpha: f64,
    /// Number of samples folded in since construction or the last reset.
    count: u64,
    /// EWMA of `|target - prediction|`.
    ewma_abs_error: f64,
    /// EWMA of `(target - prediction)^2`.
    ewma_sq_error: f64,
    /// EWMA of the target values.
    ewma_target: f64,
    /// EWMA of the squared deviation of each target from the target mean as
    /// it stood *before* that target was folded in.
    ewma_target_sq_dev: f64,
    /// `false` until the first sample has seeded the averages.
    initialized: bool,
}

impl EwmaRegressionMetrics {
    /// Creates a tracker with smoothing factor `alpha = 2 / (span + 1)`.
    ///
    /// # Panics
    ///
    /// Panics if `span` is `0`.
    pub fn new(span: usize) -> Self {
        assert!(span > 0, "span must be > 0");
        Self::with_alpha(2.0 / (span as f64 + 1.0))
    }

    /// Creates a tracker with an explicit smoothing factor.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < alpha <= 1`. A NaN `alpha` fails that comparison and
    /// therefore panics as well.
    pub fn with_alpha(alpha: f64) -> Self {
        assert!(
            alpha > 0.0 && alpha <= 1.0,
            "alpha must be in (0, 1], got {alpha}"
        );
        Self {
            alpha,
            count: 0,
            ewma_abs_error: 0.0,
            ewma_sq_error: 0.0,
            ewma_target: 0.0,
            ewma_target_sq_dev: 0.0,
            initialized: false,
        }
    }

    /// Folds one observation into the running averages.
    ///
    /// Samples where `target` or `prediction` is NaN or infinite are ignored
    /// and not counted: an EWMA never forgets a non-finite value, so a single
    /// bad sample would otherwise corrupt every metric until [`reset`](Self::reset).
    pub fn update(&mut self, target: f64, prediction: f64) {
        if !(target.is_finite() && prediction.is_finite()) {
            return;
        }

        let err = target - prediction;
        let abs_err = err.abs();
        let sq_err = err * err;

        if self.initialized {
            let keep = 1.0 - self.alpha;
            // Deviation is taken against the mean *before* this sample moves it.
            let dev = target - self.ewma_target;

            self.ewma_abs_error = self.alpha * abs_err + keep * self.ewma_abs_error;
            self.ewma_sq_error = self.alpha * sq_err + keep * self.ewma_sq_error;
            self.ewma_target = self.alpha * target + keep * self.ewma_target;
            self.ewma_target_sq_dev = self.alpha * (dev * dev) + keep * self.ewma_target_sq_dev;
        } else {
            // The first sample seeds the averages; there is no spread yet.
            self.ewma_abs_error = abs_err;
            self.ewma_sq_error = sq_err;
            self.ewma_target = target;
            self.ewma_target_sq_dev = 0.0;
            self.initialized = true;
        }
        self.count += 1;
    }

    /// EWMA of the absolute error, or `0.0` before the first sample.
    pub fn mae(&self) -> f64 {
        if self.initialized {
            self.ewma_abs_error
        } else {
            0.0
        }
    }

    /// EWMA of the squared error, or `0.0` before the first sample.
    pub fn mse(&self) -> f64 {
        if self.initialized {
            self.ewma_sq_error
        } else {
            0.0
        }
    }

    /// Square root of [`mse`](Self::mse).
    pub fn rmse(&self) -> f64 {
        self.mse().sqrt()
    }

    /// EWMA-based coefficient of determination:
    /// `1 - ewma_sq_error / ewma_target_sq_dev`.
    ///
    /// Returns `0.0` when fewer than two samples have been folded in or when
    /// the target spread estimate is below `1e-15` (the ratio would be
    /// undefined or numerically meaningless). The result is negative whenever
    /// the squared-error average exceeds the target spread estimate.
    pub fn r_squared(&self) -> f64 {
        if self.count < 2 || self.ewma_target_sq_dev < 1e-15 {
            return 0.0;
        }
        1.0 - self.ewma_sq_error / self.ewma_target_sq_dev
    }

    /// The smoothing factor this tracker was built with.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Number of samples folded in since construction or the last reset.
    pub fn n_samples(&self) -> u64 {
        self.count
    }

    /// Discards all accumulated state while keeping the smoothing factor.
    pub fn reset(&mut self) {
        // Rebuilding through the constructor keeps reset and construction in sync.
        *self = Self::with_alpha(self.alpha);
    }
}
/// Exponentially weighted moving-average (EWMA) metrics for classification.
///
/// Tracks a running EWMA of per-sample correctness (accuracy) and of the
/// binary cross-entropy, where `predicted_proba` is scored against the event
/// `target == 1`. The first accepted sample seeds the averages directly;
/// every later sample is blended in with weight `alpha`.
#[derive(Debug, Clone)]
pub struct EwmaClassificationMetrics {
    /// Weight given to the newest sample, in `(0, 1]`.
    alpha: f64,
    /// Number of samples folded in since construction or the last reset.
    count: u64,
    /// EWMA of the 0/1 indicator `target == predicted`.
    ewma_correct: f64,
    /// EWMA of the per-sample binary cross-entropy.
    ewma_log_loss: f64,
    /// `false` until the first sample has seeded the averages.
    initialized: bool,
}

/// Lower clipping bound for probabilities, keeping `ln(p)` finite.
const CLIP_MIN: f64 = 1e-15;
/// Upper clipping bound for probabilities, keeping `ln(1 - p)` finite.
const CLIP_MAX: f64 = 1.0 - 1e-15;

impl EwmaClassificationMetrics {
    /// Creates a tracker with smoothing factor `alpha = 2 / (span + 1)`.
    ///
    /// # Panics
    ///
    /// Panics if `span` is `0`.
    pub fn new(span: usize) -> Self {
        assert!(span > 0, "span must be > 0");
        Self::with_alpha(2.0 / (span as f64 + 1.0))
    }

    /// Creates a tracker with an explicit smoothing factor.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < alpha <= 1`. A NaN `alpha` fails that comparison and
    /// therefore panics as well.
    pub fn with_alpha(alpha: f64) -> Self {
        assert!(
            alpha > 0.0 && alpha <= 1.0,
            "alpha must be in (0, 1], got {alpha}"
        );
        Self {
            alpha,
            count: 0,
            ewma_correct: 0.0,
            ewma_log_loss: 0.0,
            initialized: false,
        }
    }

    /// Folds one observation into the running averages.
    ///
    /// * `target` / `predicted` - true and predicted class labels; the sample
    ///   counts as correct when they are equal.
    /// * `predicted_proba` - probability scored against `target == 1`. It is
    ///   clipped to `[CLIP_MIN, CLIP_MAX]`, so out-of-range and infinite
    ///   values are tolerated.
    ///
    /// A NaN `predicted_proba` cannot be clipped (`f64::clamp` returns NaN for
    /// NaN input) and would make the log-loss average NaN until
    /// [`reset`](Self::reset). Such samples are ignored entirely and not counted.
    pub fn update(&mut self, target: usize, predicted: usize, predicted_proba: f64) {
        if predicted_proba.is_nan() {
            return;
        }

        let correct = if target == predicted { 1.0 } else { 0.0 };
        let p = predicted_proba.clamp(CLIP_MIN, CLIP_MAX);
        // Binary cross-entropy: -ln(p) for the positive class, -ln(1 - p) otherwise.
        let sample_loss = if target == 1 {
            -p.ln()
        } else {
            -(1.0 - p).ln()
        };

        if self.initialized {
            let keep = 1.0 - self.alpha;
            self.ewma_correct = self.alpha * correct + keep * self.ewma_correct;
            self.ewma_log_loss = self.alpha * sample_loss + keep * self.ewma_log_loss;
        } else {
            // The first sample seeds both averages.
            self.ewma_correct = correct;
            self.ewma_log_loss = sample_loss;
            self.initialized = true;
        }
        self.count += 1;
    }

    /// EWMA of the correctness indicator, or `0.0` before the first sample.
    pub fn accuracy(&self) -> f64 {
        if self.initialized {
            self.ewma_correct
        } else {
            0.0
        }
    }

    /// EWMA of the binary cross-entropy, or `0.0` before the first sample.
    pub fn log_loss(&self) -> f64 {
        if self.initialized {
            self.ewma_log_loss
        } else {
            0.0
        }
    }

    /// The smoothing factor this tracker was built with.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Number of samples folded in since construction or the last reset.
    pub fn n_samples(&self) -> u64 {
        self.count
    }

    /// Discards all accumulated state while keeping the smoothing factor.
    pub fn reset(&mut self) {
        // Rebuilding through the constructor keeps reset and construction in sync.
        *self = Self::with_alpha(self.alpha);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the floating-point comparisons below.
    const EPS: f64 = 1e-6;

    /// Returns `true` when `a` and `b` differ by less than [`EPS`].
    fn approx_eq(a: f64, b: f64) -> bool {
        let diff = a - b;
        diff.abs() < EPS
    }

    // ---- regression metrics -------------------------------------------------

    /// The very first sample seeds the averages with its own errors.
    #[test]
    fn ewma_reg_initial() {
        let mut metrics = EwmaRegressionMetrics::new(10);
        metrics.update(5.0, 3.0);

        assert!(approx_eq(metrics.mae(), 2.0));
        assert!(approx_eq(metrics.mse(), 4.0));
        assert_eq!(metrics.n_samples(), 1);
    }

    /// Old errors fade out once the predictions become exact.
    #[test]
    fn ewma_reg_decay() {
        let mut metrics = EwmaRegressionMetrics::new(5);

        // Constant error of 1.0 keeps the average at 1.0.
        for _ in 0..100 {
            metrics.update(5.0, 4.0);
        }
        assert!(approx_eq(metrics.mae(), 1.0));

        // Perfect predictions drive the average towards zero.
        for _ in 0..100 {
            metrics.update(5.0, 5.0);
        }
        assert!(metrics.mae() < 0.01);
    }

    /// A shorter span forgets an early error faster than a longer one.
    #[test]
    fn ewma_reg_span_controls_rate() {
        let mut fast = EwmaRegressionMetrics::new(2);
        let mut slow = EwmaRegressionMetrics::new(50);

        fast.update(5.0, 4.0);
        slow.update(5.0, 4.0);
        for _ in 0..10 {
            fast.update(5.0, 5.0);
            slow.update(5.0, 5.0);
        }

        assert!(fast.mae() < slow.mae());
    }

    /// `reset` clears both the sample count and the error averages.
    #[test]
    fn ewma_reg_reset() {
        let mut metrics = EwmaRegressionMetrics::new(10);
        metrics.update(1.0, 2.0);
        metrics.reset();

        assert_eq!(metrics.n_samples(), 0);
        assert_eq!(metrics.mae(), 0.0);
    }

    /// Exact predictions on a ramp yield a near-perfect fit (or zero error).
    #[test]
    fn ewma_reg_r_squared_perfect() {
        let mut metrics = EwmaRegressionMetrics::new(20);
        for i in 0..100 {
            let value = i as f64;
            metrics.update(value, value);
        }

        assert!(metrics.r_squared() > 0.99 || metrics.mse() < 1e-10);
    }

    // ---- classification metrics ---------------------------------------------

    /// Only correct predictions keep the accuracy at (essentially) one.
    #[test]
    fn ewma_cls_all_correct() {
        let mut metrics = EwmaClassificationMetrics::new(10);
        for _ in 0..50 {
            metrics.update(1, 1, 0.95);
        }

        assert!(metrics.accuracy() > 0.99);
    }

    /// Accuracy collapses after a long run of wrong predictions.
    #[test]
    fn ewma_cls_accuracy_decays() {
        let mut metrics = EwmaClassificationMetrics::new(5);

        for _ in 0..50 {
            metrics.update(1, 1, 0.9);
        }
        assert!(metrics.accuracy() > 0.99);

        for _ in 0..50 {
            metrics.update(1, 0, 0.1);
        }
        assert!(metrics.accuracy() < 0.01);
    }

    /// A confident, correct prediction gives a small but positive log loss.
    #[test]
    fn ewma_cls_log_loss_near_perfect() {
        let mut metrics = EwmaClassificationMetrics::new(10);
        metrics.update(1, 1, 0.999);

        let loss = metrics.log_loss();
        assert!(loss < 0.01);
        assert!(loss > 0.0);
    }

    /// `reset` clears both the sample count and the accuracy average.
    #[test]
    fn ewma_cls_reset() {
        let mut metrics = EwmaClassificationMetrics::new(10);
        metrics.update(1, 1, 0.9);
        metrics.reset();

        assert_eq!(metrics.n_samples(), 0);
        assert_eq!(metrics.accuracy(), 0.0);
    }

    // ---- constructor contracts ----------------------------------------------

    /// A zero span is rejected by the regression constructor.
    #[test]
    #[should_panic(expected = "span must be > 0")]
    fn ewma_reg_zero_span() {
        EwmaRegressionMetrics::new(0);
    }

    /// A zero span is rejected by the classification constructor.
    #[test]
    #[should_panic(expected = "span must be > 0")]
    fn ewma_cls_zero_span() {
        EwmaClassificationMetrics::new(0);
    }

    /// A span of 9 maps to alpha = 2 / (9 + 1) = 0.2.
    #[test]
    fn ewma_alpha_getter() {
        let metrics = EwmaRegressionMetrics::new(9);
        assert!(approx_eq(metrics.alpha(), 0.2));
    }
}