/// Streaming accumulator for regression error metrics.
///
/// Feed it `(target, prediction)` pairs one at a time with
/// [`update`](RegressionMetrics::update); MAE, MSE, RMSE and R² can be read
/// at any point. Only a fixed set of running sums is stored, so memory use
/// does not grow with the number of samples.
#[derive(Debug, Clone)]
pub struct RegressionMetrics {
    /// Number of `(target, prediction)` pairs seen so far.
    count: u64,
    /// Running sum of `|target - prediction|`.
    sum_abs_error: f64,
    /// Running sum of `(target - prediction)^2` (residual sum of squares).
    sum_sq_error: f64,
    /// Running mean of the targets, maintained by Welford's update.
    target_mean: f64,
    /// Running sum of squared deviations of the targets from their mean
    /// (Welford's `M2`), i.e. the total sum of squares used by R².
    target_m2: f64,
}

impl RegressionMetrics {
    /// Creates an empty accumulator with all running sums at zero.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            count: 0,
            sum_abs_error: 0.0,
            sum_sq_error: 0.0,
            target_mean: 0.0,
            target_m2: 0.0,
        }
    }

    /// Folds one `(target, prediction)` observation into the running sums.
    ///
    /// A NaN in either argument makes the error sums NaN, and they stay NaN
    /// until [`reset`](RegressionMetrics::reset) is called.
    pub fn update(&mut self, target: f64, prediction: f64) {
        // The count must be bumped first: the mean update below divides by
        // the number of samples *including* this one.
        self.count += 1;

        let error = target - prediction;
        self.sum_abs_error += error.abs();
        self.sum_sq_error += error * error;

        // Welford's online algorithm: the deviation is taken once against the
        // old mean and once against the new mean; their product is the
        // increment to M2. This avoids the cancellation of the naive
        // sum-of-squares formula.
        let delta_before = target - self.target_mean;
        // `u64 -> f64` is exact for counts up to 2^53.
        self.target_mean += delta_before / self.count as f64;
        let delta_after = target - self.target_mean;
        self.target_m2 += delta_before * delta_after;
    }

    /// Mean absolute error, or `0.0` if no samples have been seen.
    #[must_use]
    pub fn mae(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.sum_abs_error / self.count as f64
        }
    }

    /// Mean squared error, or `0.0` if no samples have been seen.
    #[must_use]
    pub fn mse(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.sum_sq_error / self.count as f64
        }
    }

    /// Root mean squared error: the square root of [`mse`](RegressionMetrics::mse).
    #[must_use]
    pub fn rmse(&self) -> f64 {
        self.mse().sqrt()
    }

    /// Coefficient of determination, `1 - SS_res / SS_tot`.
    ///
    /// Returns `0.0` when fewer than two samples have been seen or when the
    /// accumulated target variation is exactly zero (for example, when every
    /// target was identical), because the ratio is undefined in those cases.
    /// The result is negative when the squared error exceeds the targets'
    /// total variation around their mean.
    #[must_use]
    pub fn r_squared(&self) -> f64 {
        // A NaN `target_m2` deliberately fails the `== 0.0` test so that NaN
        // inputs propagate to the result instead of being masked as 0.0.
        if self.count < 2 || self.target_m2 == 0.0 {
            return 0.0;
        }
        1.0 - self.sum_sq_error / self.target_m2
    }

    /// Number of samples accumulated since construction or the last reset.
    #[must_use]
    pub fn n_samples(&self) -> u64 {
        self.count
    }

    /// Discards all accumulated state, returning to the freshly-constructed state.
    pub fn reset(&mut self) {
        // Delegating to `new` keeps the two in sync if fields are ever added.
        *self = Self::new();
    }
}
impl Default for RegressionMetrics {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Three `(target, prediction)` pairs with errors -1, 2 and 1, shared by
    /// the MAE / MSE / RMSE tests.
    const THREE_SAMPLES: [(f64, f64); 3] = [(1.0, 2.0), (3.0, 1.0), (5.0, 4.0)];

    /// Asserts that `actual` is within `EPS` of `expected`, reporting both
    /// values on failure. A NaN `actual` always fails.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {} (tolerance {}), got {}",
            expected,
            EPS,
            actual
        );
    }

    /// Builds an accumulator by feeding it every `(target, prediction)` pair in order.
    fn metrics_from(pairs: &[(f64, f64)]) -> RegressionMetrics {
        let mut metrics = RegressionMetrics::new();
        for &(target, prediction) in pairs {
            metrics.update(target, prediction);
        }
        metrics
    }

    #[test]
    fn empty_state_returns_zeros() {
        let m = RegressionMetrics::new();
        assert_eq!(m.mae(), 0.0);
        assert_eq!(m.mse(), 0.0);
        assert_eq!(m.rmse(), 0.0);
        assert_eq!(m.r_squared(), 0.0);
        assert_eq!(m.n_samples(), 0);
    }

    #[test]
    fn single_sample() {
        let m = metrics_from(&[(5.0, 3.0)]);
        assert_close(m.mae(), 2.0);
        assert_close(m.mse(), 4.0);
        assert_close(m.rmse(), 2.0);
        // R² needs at least two samples; a single one reports 0.0.
        assert_eq!(m.r_squared(), 0.0);
        assert_eq!(m.n_samples(), 1);
    }

    #[test]
    fn mae_multiple_samples() {
        let m = metrics_from(&THREE_SAMPLES);
        assert_close(m.mae(), 4.0 / 3.0);
    }

    #[test]
    fn mse_multiple_samples() {
        let m = metrics_from(&THREE_SAMPLES);
        assert_close(m.mse(), 2.0);
    }

    #[test]
    fn rmse_multiple_samples() {
        let m = metrics_from(&THREE_SAMPLES);
        assert_close(m.rmse(), 2.0_f64.sqrt());
    }

    #[test]
    fn r_squared_perfect_prediction() {
        let m = metrics_from(&[(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)]);
        assert_close(m.r_squared(), 1.0);
    }

    #[test]
    fn r_squared_mean_prediction() {
        // Always predicting the target mean yields R² = 0.
        let targets = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = 3.0;
        let mut m = RegressionMetrics::new();
        for &t in &targets {
            m.update(t, mean);
        }
        assert_close(m.r_squared(), 0.0);
    }

    #[test]
    fn r_squared_partial_fit() {
        // SS_res = 4 * 0.25 = 1, SS_tot = 9 + 1 + 1 + 9 = 20, so R² = 0.95.
        let m = metrics_from(&[(2.0, 2.5), (4.0, 3.5), (6.0, 5.5), (8.0, 8.5)]);
        assert_close(m.r_squared(), 0.95);
    }

    #[test]
    fn r_squared_negative() {
        let m = metrics_from(&[(1.0, 10.0), (2.0, 10.0), (3.0, 10.0)]);
        let r2 = m.r_squared();
        assert!(r2 < 0.0, "expected negative R², got {}", r2);
    }

    #[test]
    fn r_squared_constant_target() {
        // Zero target variance makes R² undefined; the accumulator reports 0.0.
        let m = metrics_from(&[(5.0, 5.0), (5.0, 5.0), (5.0, 5.0)]);
        assert_eq!(m.r_squared(), 0.0);
    }

    #[test]
    fn reset_clears_state() {
        let mut m = metrics_from(&[(1.0, 2.0), (3.0, 4.0)]);
        m.reset();
        assert_eq!(m.n_samples(), 0);
        assert_eq!(m.mae(), 0.0);
        assert_eq!(m.mse(), 0.0);
        assert_eq!(m.rmse(), 0.0);
        assert_eq!(m.r_squared(), 0.0);
    }

    #[test]
    fn nan_inputs_propagate() {
        let m = metrics_from(&[(f64::NAN, 1.0)]);
        assert!(m.mae().is_nan());
        assert!(m.mse().is_nan());
        assert!(m.rmse().is_nan());
    }

    #[test]
    fn large_sample_count() {
        let mut m = RegressionMetrics::new();
        for i in 0..10_000 {
            let t = i as f64;
            let p = t + 1.0;
            m.update(t, p);
        }
        assert_close(m.mae(), 1.0);
        assert_close(m.mse(), 1.0);
        assert_close(m.rmse(), 1.0);
        assert_eq!(m.n_samples(), 10_000);
    }

    #[test]
    fn default_is_empty() {
        let m = RegressionMetrics::default();
        assert_eq!(m.n_samples(), 0);
        assert_eq!(m.mae(), 0.0);
    }
}