use crate::stats::DistortionMetrics;
/// Tracks how far an estimator's outputs drift from ground truth.
///
/// Each [`observe`](Self::observe) call computes the signed error
/// `estimated - ground_truth` and folds it into two exponential moving
/// averages (EMAs): one of the squared error (smoothed MSE) and one of the
/// signed error (smoothed bias). The tracker reports itself healthy while the
/// smoothed MSE is at or below a configured threshold.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-support", derive(serde::Serialize, serde::Deserialize))]
pub struct DistortionTracker {
    /// EMA weight given to each new observation; [`DistortionTracker::new`]
    /// clamps it to `[1e-6, 1.0]`.
    alpha: f64,
    /// Largest smoothed MSE (inclusive) that still counts as healthy.
    mse_threshold: f64,
    /// EMA of the squared error.
    ema_mse: f64,
    /// EMA of the signed error.
    ema_bias: f64,
    /// Observations recorded since construction or the last `reset`;
    /// `0` means the averages have not been seeded yet.
    samples: u64,
}

impl DistortionTracker {
    /// Lower bound applied to `alpha` by [`new`](Self::new).
    const MIN_ALPHA: f64 = 1e-6;
    /// Upper bound applied to `alpha` by [`new`](Self::new); also the value
    /// substituted for a `NaN` `alpha`.
    const MAX_ALPHA: f64 = 1.0;

    /// Creates a tracker with no observations.
    ///
    /// # Arguments
    ///
    /// * `alpha` - EMA smoothing factor: the weight given to each observation
    ///   after the first. Values outside `[1e-6, 1.0]` are clamped into that
    ///   range. `NaN` is replaced with `1.0` (no smoothing: for finite inputs
    ///   the averages simply follow the latest observation).
    /// * `mse_threshold` - largest smoothed MSE (inclusive) for which
    ///   [`is_healthy`](Self::is_healthy) returns `true`.
    #[must_use]
    pub fn new(alpha: f64, mse_threshold: f64) -> Self {
        // `f64::clamp` passes NaN through unchanged. A NaN weight would turn
        // every EMA update after the first into NaN and leave the tracker
        // permanently unhealthy, so pin it to a valid value instead.
        let alpha = if alpha.is_nan() {
            Self::MAX_ALPHA
        } else {
            alpha.clamp(Self::MIN_ALPHA, Self::MAX_ALPHA)
        };
        Self {
            alpha,
            mse_threshold,
            ema_mse: 0.0,
            ema_bias: 0.0,
            samples: 0,
        }
    }

    /// Records one estimate / ground-truth pair.
    ///
    /// The error is `estimated - ground_truth`. The first observation since
    /// construction or [`reset`](Self::reset) seeds both averages with the raw
    /// squared and signed error. Later observations are blended in as
    /// `ema = alpha * x + (1 - alpha) * ema`.
    ///
    /// A `NaN` input makes both averages `NaN` from then on, so
    /// [`is_healthy`](Self::is_healthy) returns `false` until `reset` is called.
    pub fn observe(&mut self, estimated: f64, ground_truth: f64) {
        let error = estimated - ground_truth;
        let sq_error = error * error;
        if self.samples == 0 {
            // Seeding from the first sample avoids biasing the EMA towards
            // the initial 0.0 values.
            self.ema_mse = sq_error;
            self.ema_bias = error;
        } else {
            let keep = 1.0 - self.alpha;
            self.ema_mse = self.alpha * sq_error + keep * self.ema_mse;
            self.ema_bias = self.alpha * error + keep * self.ema_bias;
        }
        // Saturate rather than overflow: overflow panics in debug builds, and
        // wrapping back to 0 in release builds would silently re-seed the
        // averages on the next observation.
        self.samples = self.samples.saturating_add(1);
    }

    /// Returns `true` if nothing has been observed yet, or if the smoothed
    /// MSE is at most `mse_threshold`.
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        self.samples == 0 || self.ema_mse <= self.mse_threshold
    }

    /// Returns a snapshot of the smoothed MSE, smoothed bias, sample count
    /// and health status.
    #[must_use]
    pub fn metrics(&self) -> DistortionMetrics {
        DistortionMetrics {
            mse: self.ema_mse,
            bias: self.ema_bias,
            samples: self.samples,
            healthy: self.is_healthy(),
        }
    }

    /// Discards all observations, clearing both averages and the sample
    /// count. `alpha` and `mse_threshold` are kept.
    pub fn reset(&mut self) {
        self.ema_mse = 0.0;
        self.ema_bias = 0.0;
        self.samples = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for values produced by multi-step floating-point arithmetic.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_healthy_before_observations() {
        let tracker = DistortionTracker::new(0.1, 1.0);
        assert!(tracker.is_healthy());
        let m = tracker.metrics();
        assert_eq!(m.samples, 0);
        assert!(m.healthy);
    }

    #[test]
    fn test_healthy_after_zero_error() {
        let mut tracker = DistortionTracker::new(0.1, 1.0);
        for _ in 0..100 {
            tracker.observe(0.5, 0.5);
        }
        assert!(tracker.is_healthy());
        let m = tracker.metrics();
        assert!(m.mse.abs() < 1e-9);
        assert!(m.bias.abs() < 1e-9);
    }

    #[test]
    fn test_unhealthy_after_large_errors() {
        let mut tracker = DistortionTracker::new(0.5, 0.1);
        for _ in 0..50 {
            tracker.observe(10.0, 0.0);
        }
        assert!(!tracker.is_healthy());
        let m = tracker.metrics();
        assert!(!m.healthy);
        assert!(m.mse > 0.1);
    }

    #[test]
    fn test_first_observation_seeds_averages() {
        let mut tracker = DistortionTracker::new(0.1, 1.0);
        tracker.observe(3.0, 1.0);
        let m = tracker.metrics();
        // The first sample is taken as-is rather than blended with 0.0.
        assert_eq!(m.mse, 4.0);
        assert_eq!(m.bias, 2.0);
        assert_eq!(m.samples, 1);
    }

    #[test]
    fn test_ema_update() {
        let mut tracker = DistortionTracker::new(0.5, 10.0);
        tracker.observe(4.0, 2.0);
        tracker.observe(2.0, 2.0);
        let m = tracker.metrics();
        // 0.5 * 0 + 0.5 * 4 = 2 and 0.5 * 0 + 0.5 * 2 = 1.
        assert!((m.mse - 2.0).abs() < EPS, "mse={}", m.mse);
        assert!((m.bias - 1.0).abs() < EPS, "bias={}", m.bias);
    }

    #[test]
    fn test_bias_tracking() {
        let mut tracker = DistortionTracker::new(0.5, 1000.0);
        for _ in 0..200 {
            tracker.observe(1.0, 0.0);
        }
        let m = tracker.metrics();
        // A constant error of +1 keeps the bias EMA at exactly +1.
        assert!((m.bias - 1.0).abs() < EPS, "bias={}", m.bias);
    }

    #[test]
    fn test_negative_bias_tracking() {
        let mut tracker = DistortionTracker::new(0.3, 1000.0);
        for _ in 0..10 {
            tracker.observe(0.0, 2.0);
        }
        let m = tracker.metrics();
        assert!((m.bias + 2.0).abs() < EPS, "bias={}", m.bias);
        assert!((m.mse - 4.0).abs() < EPS, "mse={}", m.mse);
    }

    #[test]
    fn test_threshold_is_inclusive() {
        let mut tracker = DistortionTracker::new(1.0, 4.0);
        tracker.observe(2.0, 0.0);
        assert!(tracker.is_healthy());
        tracker.observe(3.0, 0.0);
        assert!(!tracker.is_healthy());
    }

    #[test]
    fn test_alpha_is_clamped() {
        // Above 1.0 clamps to 1.0: only the latest observation counts.
        let mut high = DistortionTracker::new(5.0, 1.0);
        high.observe(3.0, 0.0);
        high.observe(0.0, 0.0);
        assert_eq!(high.metrics().mse, 0.0);

        // Below 1e-6 clamps to 1e-6: the history barely moves.
        let mut low = DistortionTracker::new(-1.0, 1.0);
        low.observe(1.0, 0.0);
        low.observe(0.0, 0.0);
        let mse = low.metrics().mse;
        assert!((mse - (1.0 - 1e-6)).abs() < EPS, "mse={}", mse);
    }

    #[test]
    fn test_reset() {
        let mut tracker = DistortionTracker::new(0.1, 1.0);
        tracker.observe(5.0, 0.0);
        tracker.reset();
        let m = tracker.metrics();
        assert_eq!(m.samples, 0);
        assert_eq!(m.mse, 0.0);
        assert_eq!(m.bias, 0.0);
        assert!(tracker.is_healthy());

        // After a reset the next observation seeds the averages again.
        tracker.observe(2.0, 0.0);
        let m = tracker.metrics();
        assert_eq!(m.mse, 4.0);
        assert_eq!(m.bias, 2.0);
        assert_eq!(m.samples, 1);
    }

    #[test]
    fn test_samples_count() {
        let mut tracker = DistortionTracker::new(0.1, 1.0);
        for i in 0..42 {
            tracker.observe(f64::from(i), 0.0);
        }
        assert_eq!(tracker.metrics().samples, 42);
    }
}