/// An evaluation metric that is fed one `(prediction, target)` pair at a time and can be
/// queried at any point for its current value.
///
/// The `Send + Sync` supertraits let a metric be moved to, or shared with, other threads.
pub trait StreamingMetric: Send + Sync {
    /// Folds a single observation (`pred` = model output, `actual` = target) into the state.
    fn update(&mut self, pred: f64, actual: f64);

    /// Returns the metric value for everything seen since construction or the last reset.
    fn get(&self) -> f64;

    /// Returns a short, static display name for the metric.
    fn name(&self) -> &'static str;

    /// Discards all accumulated state.
    fn reset(&mut self);

    /// Reports whether larger values of [`get`](Self::get) mean a better model.
    ///
    /// The default is `false`, i.e. the metric behaves like a loss.
    fn higher_is_better(&self) -> bool {
        false
    }
}

/// Never called: it only compiles while `StreamingMetric` remains usable as a trait object.
fn _assert_object_safe(_metric: Box<dyn StreamingMetric>) {}
/// Two metrics fed from the same stream, normally built with `+`
/// (e.g. `MAE::new() + MSE::new()`).
///
/// Every `update` and `reset` call is forwarded to both members. The scalar accessors
/// (`get`, `name`, `higher_is_better`) report only the first member `a`. Read `b` (or
/// nested fields such as `a.a` for longer chains) directly to obtain the other values.
///
/// Derives the same traits as the leaf metrics so a composed metric can be printed,
/// cloned and default-constructed whenever its members can.
#[derive(Debug, Clone, Default)]
pub struct MetricUnion<A, B> {
    /// Primary metric; `get`, `name` and `higher_is_better` report this one.
    pub a: A,
    /// Secondary metric; updated and reset alongside `a`.
    pub b: B,
}

impl<A: StreamingMetric, B: StreamingMetric> StreamingMetric for MetricUnion<A, B> {
    /// Feeds the observation to both members.
    fn update(&mut self, pred: f64, actual: f64) {
        self.a.update(pred, actual);
        self.b.update(pred, actual);
    }

    /// Value of the primary metric `a`.
    fn get(&self) -> f64 {
        self.a.get()
    }

    /// Name of the primary metric `a`.
    ///
    /// A `&'static str` cannot describe both members without allocating or leaking.
    fn name(&self) -> &'static str {
        self.a.name()
    }

    /// Resets both members.
    fn reset(&mut self) {
        self.a.reset();
        self.b.reset();
    }

    /// Direction of the primary metric `a`.
    fn higher_is_better(&self) -> bool {
        self.a.higher_is_better()
    }
}

/// Appending a metric to an existing union nests the union on the left:
/// `(x + y) + z` becomes `MetricUnion { a: MetricUnion { a: x, b: y }, b: z }`.
impl<A, B, C> std::ops::Add<C> for MetricUnion<A, B>
where
    A: StreamingMetric,
    B: StreamingMetric,
    C: StreamingMetric,
{
    type Output = MetricUnion<MetricUnion<A, B>, C>;

    fn add(self, rhs: C) -> Self::Output {
        MetricUnion { a: self, b: rhs }
    }
}
/// Mean absolute error: the running average of `|actual - pred|`.
///
/// Reports `0.0` until the first observation arrives.
#[derive(Debug, Clone, Default)]
pub struct MAE {
    /// Observations folded in so far.
    count: u64,
    /// Running sum of absolute residuals.
    sum_abs_error: f64,
}

impl MAE {
    /// Creates an empty accumulator (same as `MAE::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}

impl StreamingMetric for MAE {
    fn update(&mut self, pred: f64, actual: f64) {
        let abs_residual = (actual - pred).abs();
        self.sum_abs_error += abs_residual;
        self.count += 1;
    }

    fn get(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => self.sum_abs_error / n as f64,
        }
    }

    fn name(&self) -> &'static str {
        "MAE"
    }

    fn reset(&mut self) {
        // The derived default is exactly the empty state.
        *self = Self::default();
    }
}

/// `MAE + other` pairs this metric (as the primary member) with `other` in a `MetricUnion`.
impl<B: StreamingMetric> std::ops::Add<B> for MAE {
    type Output = MetricUnion<MAE, B>;

    fn add(self, rhs: B) -> Self::Output {
        MetricUnion { a: self, b: rhs }
    }
}
/// Mean squared error: the running average of `(actual - pred)^2`.
///
/// Reports `0.0` until the first observation arrives.
#[derive(Debug, Clone, Default)]
pub struct MSE {
    /// Observations folded in so far.
    count: u64,
    /// Running sum of squared residuals.
    sum_sq_error: f64,
}

impl MSE {
    /// Creates an empty accumulator (same as `MSE::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}

impl StreamingMetric for MSE {
    fn update(&mut self, pred: f64, actual: f64) {
        let residual = actual - pred;
        self.sum_sq_error += residual * residual;
        self.count += 1;
    }

    fn get(&self) -> f64 {
        if self.count > 0 {
            self.sum_sq_error / self.count as f64
        } else {
            0.0
        }
    }

    fn name(&self) -> &'static str {
        "MSE"
    }

    fn reset(&mut self) {
        self.sum_sq_error = 0.0;
        self.count = 0;
    }
}

/// `MSE + other` pairs this metric (as the primary member) with `other` in a `MetricUnion`.
impl<B: StreamingMetric> std::ops::Add<B> for MSE {
    type Output = MetricUnion<MSE, B>;

    fn add(self, rhs: B) -> Self::Output {
        MetricUnion { a: self, b: rhs }
    }
}
#[derive(Debug, Clone, Default)]
pub struct RMSE {
inner: MSE,
}
impl RMSE {
pub fn new() -> Self {
Self::default()
}
}
impl StreamingMetric for RMSE {
fn update(&mut self, pred: f64, actual: f64) {
self.inner.update(pred, actual);
}
fn get(&self) -> f64 {
self.inner.get().sqrt()
}
fn name(&self) -> &'static str {
"RMSE"
}
fn reset(&mut self) {
self.inner.reset();
}
}
impl<B: StreamingMetric> std::ops::Add<B> for RMSE {
type Output = MetricUnion<RMSE, B>;
fn add(self, rhs: B) -> Self::Output {
MetricUnion { a: self, b: rhs }
}
}
/// Coefficient of determination, `R² = 1 - SSE / SS_tot`, computed in a single pass.
///
/// `SSE` is the running sum of squared residuals. `SS_tot` (the sum of squared deviations
/// of the targets from their mean) is maintained with Welford's online update, so the
/// target mean need not be known in advance. Higher is better: `1.0` is a perfect fit, and
/// the value goes negative when the model does worse than always predicting the mean.
#[derive(Debug, Clone, Default)]
pub struct R2 {
    /// Observations folded in so far.
    count: u64,
    /// Running sum of squared residuals (`SSE`).
    sum_sq_error: f64,
    /// Running mean of the targets (Welford).
    target_mean: f64,
    /// Running sum of squared deviations of targets from their mean (`SS_tot`, Welford's M2).
    target_m2: f64,
}

impl R2 {
    /// Creates an empty accumulator (same as `R2::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}

impl StreamingMetric for R2 {
    fn update(&mut self, pred: f64, actual: f64) {
        self.count += 1;
        let e = actual - pred;
        self.sum_sq_error += e * e;
        // Welford: shift the mean, then accumulate (x - old_mean) * (x - new_mean) into M2.
        let delta = actual - self.target_mean;
        self.target_mean += delta / self.count as f64;
        let delta2 = actual - self.target_mean;
        self.target_m2 += delta * delta2;
    }

    /// Returns R², with these conventions for degenerate input:
    ///
    /// - fewer than two observations: `0.0`;
    /// - constant targets (`SS_tot == 0`, where R² is undefined): `1.0` if every prediction
    ///   was exact, otherwise `0.0` (scikit-learn's `force_finite` convention).
    fn get(&self) -> f64 {
        if self.count < 2 {
            return 0.0;
        }
        if self.target_m2 == 0.0 {
            // 1 - SSE/0 is undefined. Don't score an exact predictor as 0.0 just because
            // the targets happened not to vary.
            return if self.sum_sq_error == 0.0 { 1.0 } else { 0.0 };
        }
        1.0 - self.sum_sq_error / self.target_m2
    }

    fn name(&self) -> &'static str {
        "R2"
    }

    fn reset(&mut self) {
        self.count = 0;
        self.sum_sq_error = 0.0;
        self.target_mean = 0.0;
        self.target_m2 = 0.0;
    }

    fn higher_is_better(&self) -> bool {
        true
    }
}

/// `R2 + other` pairs this metric (as the primary member) with `other` in a `MetricUnion`.
impl<B: StreamingMetric> std::ops::Add<B> for R2 {
    type Output = MetricUnion<R2, B>;

    fn add(self, rhs: B) -> Self::Output {
        MetricUnion { a: self, b: rhs }
    }
}
/// Pinball (quantile) loss for the `tau`-quantile, averaged over all observations.
///
/// For residual `r = actual - pred` the per-observation loss is `tau * r` when `r >= 0`
/// and `(tau - 1) * r` otherwise. With `tau > 0.5`, under-predictions therefore cost more
/// than over-predictions of the same size. Reports `0.0` until the first observation.
#[derive(Debug, Clone)]
pub struct Pinball {
    /// Target quantile, validated to lie strictly inside `(0, 1)`.
    tau: f64,
    /// Observations folded in so far.
    count: u64,
    /// Running sum of per-observation pinball losses.
    sum_loss: f64,
}

impl Pinball {
    /// Creates an empty accumulator for the `tau`-quantile.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < tau < 1`; a NaN `tau` fails the check as well.
    pub fn new(tau: f64) -> Self {
        let in_open_unit_interval = tau > 0.0 && tau < 1.0;
        assert!(
            in_open_unit_interval,
            "Pinball tau must be in (0, 1), got {tau}"
        );
        Self {
            tau,
            count: 0,
            sum_loss: 0.0,
        }
    }

    /// The quantile this loss targets.
    pub fn tau(&self) -> f64 {
        self.tau
    }
}

impl StreamingMetric for Pinball {
    fn update(&mut self, pred: f64, actual: f64) {
        let residual = actual - pred;
        // Slope of the check function on the side of zero the residual falls on.
        let weight = if residual >= 0.0 {
            self.tau
        } else {
            self.tau - 1.0
        };
        self.sum_loss += weight * residual;
        self.count += 1;
    }

    fn get(&self) -> f64 {
        if self.count > 0 {
            self.sum_loss / self.count as f64
        } else {
            0.0
        }
    }

    fn name(&self) -> &'static str {
        "Pinball"
    }

    fn reset(&mut self) {
        // `tau` is configuration, not accumulated state, so it survives a reset.
        self.sum_loss = 0.0;
        self.count = 0;
    }
}

/// `Pinball + other` pairs this metric (as the primary member) with `other` in a `MetricUnion`.
impl<B: StreamingMetric> std::ops::Add<B> for Pinball {
    type Output = MetricUnion<Pinball, B>;

    fn add(self, rhs: B) -> Self::Output {
        MetricUnion { a: self, b: rhs }
    }
}
/// Binary cross-entropy (log loss) averaged over all observations.
///
/// `pred` is treated as the predicted probability of the positive class. It is clipped to
/// `[LOGLOSS_CLIP_MIN, LOGLOSS_CLIP_MAX]`, so a confidently wrong prediction yields a large
/// but finite loss rather than infinity. `actual` is binarised: values above `0.5` count as
/// the positive class. Reports `0.0` until the first observation arrives.
#[derive(Debug, Clone, Default)]
pub struct LogLoss {
    /// Observations folded in so far.
    count: u64,
    /// Running sum of per-observation log losses.
    sum_loss: f64,
}

impl LogLoss {
    /// Creates an empty accumulator (same as `LogLoss::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}

/// Lower clip bound for predicted probabilities; keeps `ln(p)` finite.
const LOGLOSS_CLIP_MIN: f64 = 1e-15;
/// Upper clip bound for predicted probabilities; keeps `ln(1 - p)` finite.
const LOGLOSS_CLIP_MAX: f64 = 1.0 - 1e-15;

impl StreamingMetric for LogLoss {
    fn update(&mut self, pred: f64, actual: f64) {
        let p = pred.clamp(LOGLOSS_CLIP_MIN, LOGLOSS_CLIP_MAX);
        // Only the term for the observed class survives:
        // -ln(p) for a positive target, -ln(1 - p) for a negative one.
        let loss = if actual > 0.5 {
            -p.ln()
        } else {
            -(1.0 - p).ln()
        };
        self.sum_loss += loss;
        self.count += 1;
    }

    fn get(&self) -> f64 {
        match self.count {
            0 => 0.0,
            n => self.sum_loss / n as f64,
        }
    }

    fn name(&self) -> &'static str {
        "LogLoss"
    }

    fn reset(&mut self) {
        *self = Self::default();
    }
}

/// `LogLoss + other` pairs this metric (as the primary member) with `other` in a `MetricUnion`.
impl<B: StreamingMetric> std::ops::Add<B> for LogLoss {
    type Output = MetricUnion<LogLoss, B>;

    fn add(self, rhs: B) -> Self::Output {
        MetricUnion { a: self, b: rhs }
    }
}
/// Classification accuracy: the fraction of observations whose prediction and target
/// round to the same label.
///
/// Works for binary (`0.0` / `1.0`) and multi-class integer labels alike, and for
/// probabilities in binary problems (`0.8` rounds to label `1`). A NaN on either side is
/// never counted as correct. Reports `0.0` until the first observation arrives.
#[derive(Debug, Clone, Default)]
pub struct Accuracy {
    /// Observations folded in so far.
    n_total: u64,
    /// Observations whose rounded prediction matched the rounded target.
    n_correct: u64,
}

impl Accuracy {
    /// Creates an empty accumulator (same as `Accuracy::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}

impl StreamingMetric for Accuracy {
    /// Counts one observation, scoring it correct when `pred` and `actual` round to the
    /// same value.
    fn update(&mut self, pred: f64, actual: f64) {
        self.n_total += 1;
        // Compare the rounded values as f64 rather than through `as i64` casts. The casts
        // map NaN to 0 (scoring a NaN prediction as correct for label 0) and saturate
        // out-of-range values (making any two huge labels look equal). f64 `==` treats
        // NaN as unequal to everything and still equates -0.0 with 0.0.
        if pred.round() == actual.round() {
            self.n_correct += 1;
        }
    }

    fn get(&self) -> f64 {
        if self.n_total == 0 {
            return 0.0;
        }
        self.n_correct as f64 / self.n_total as f64
    }

    fn name(&self) -> &'static str {
        "Accuracy"
    }

    fn reset(&mut self) {
        self.n_total = 0;
        self.n_correct = 0;
    }

    fn higher_is_better(&self) -> bool {
        true
    }
}

/// `Accuracy + other` pairs this metric (as the primary member) with `other` in a `MetricUnion`.
impl<B: StreamingMetric> std::ops::Add<B> for Accuracy {
    type Output = MetricUnion<Accuracy, B>;

    fn add(self, rhs: B) -> Self::Output {
        MetricUnion { a: self, b: rhs }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-10;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    #[test]
    fn mae_streaming_matches_offline_computation() {
        // (actual, pred) pairs; absolute errors 0.5, 0.5, 1.0, 0.0 -> mean 0.5.
        let data = [(3.0_f64, 2.5_f64), (1.0, 1.5), (5.0, 4.0), (2.0, 2.0)];
        let mut m = MAE::new();
        for &(actual, pred) in &data {
            m.update(pred, actual);
        }
        assert!(approx_eq(m.get(), 0.5), "MAE expected 0.5, got {}", m.get());
    }

    #[test]
    fn mae_empty_returns_zero() {
        let m = MAE::new();
        assert_eq!(m.get(), 0.0, "empty MAE must be 0.0");
    }

    #[test]
    fn mae_reset_clears_state() {
        let mut m = MAE::new();
        m.update(1.0, 2.0);
        m.reset();
        assert_eq!(m.get(), 0.0, "MAE after reset must be 0.0");
    }

    #[test]
    fn mse_streaming_matches_offline_computation() {
        // Squared errors 0.25, 0.25, 1.0, 0.0 -> mean 0.375.
        let data = [(3.0_f64, 2.5_f64), (1.0, 1.5), (5.0, 4.0), (2.0, 2.0)];
        let mut m = MSE::new();
        for &(actual, pred) in &data {
            m.update(pred, actual);
        }
        assert!(
            approx_eq(m.get(), 0.375),
            "MSE expected 0.375, got {}",
            m.get()
        );
    }

    #[test]
    fn mse_empty_returns_zero() {
        assert_eq!(MSE::new().get(), 0.0);
    }

    #[test]
    fn rmse_is_sqrt_of_mse() {
        let data = [(3.0_f64, 2.5_f64), (1.0, 1.5), (5.0, 4.0)];
        let mut rmse = RMSE::new();
        let mut mse = MSE::new();
        for &(actual, pred) in &data {
            rmse.update(pred, actual);
            mse.update(pred, actual);
        }
        let expected = mse.get().sqrt();
        assert!(
            approx_eq(rmse.get(), expected),
            "RMSE = sqrt(MSE): expected {expected}, got {}",
            rmse.get()
        );
    }

    #[test]
    fn rmse_reset_clears_state() {
        let mut m = RMSE::new();
        // Single residual of 3 -> RMSE 3.
        m.update(1.0, 4.0);
        assert!(
            approx_eq(m.get(), 3.0),
            "RMSE on residual=3 expected 3.0, got {}",
            m.get()
        );
        m.reset();
        assert_eq!(m.get(), 0.0, "RMSE after reset must be 0.0");
    }

    #[test]
    fn r_squared_streaming_uses_welford_sstot() {
        let mut m = R2::new();
        m.update(1.0, 1.0);
        m.update(2.0, 2.0);
        m.update(3.0, 3.0);
        m.update(4.0, 4.0);
        assert!(
            approx_eq(m.get(), 1.0),
            "R2 = 1.0 for perfect predictions, got {}",
            m.get()
        );
    }

    #[test]
    fn r2_negative_for_bad_predictions() {
        let mut m = R2::new();
        m.update(10.0, 1.0);
        m.update(10.0, 2.0);
        m.update(10.0, 3.0);
        assert!(
            m.get() < 0.0,
            "R2 must be negative for terrible predictions"
        );
    }

    #[test]
    fn r2_cold_start_returns_zero() {
        let mut m = R2::new();
        m.update(1.0, 1.0);
        assert_eq!(m.get(), 0.0, "R2 with <2 samples must be 0.0");
    }

    #[test]
    fn pinball_loss_correct_at_tau_05() {
        let mut m = Pinball::new(0.5);
        m.update(2.0, 3.0);
        assert!(
            approx_eq(m.get(), 0.5),
            "Pinball(0.5) on residual=1 expected 0.5, got {}",
            m.get()
        );
    }

    #[test]
    fn pinball_asymmetry_at_tau_09() {
        // Over-prediction: residual -2 -> (0.9 - 1) * -2 = 0.2.
        let mut m_over = Pinball::new(0.9);
        m_over.update(5.0, 3.0);
        // Under-prediction: residual +2 -> 0.9 * 2 = 1.8.
        let mut m_under = Pinball::new(0.9);
        m_under.update(3.0, 5.0);
        assert!(
            m_over.get() < m_under.get(),
            "overestimate loss ({}) < underestimate loss ({}) at tau=0.9",
            m_over.get(),
            m_under.get()
        );
        assert!(approx_eq(m_over.get(), 0.2));
        assert!(approx_eq(m_under.get(), 1.8));
    }

    #[test]
    fn pinball_empty_returns_zero() {
        assert_eq!(Pinball::new(0.5).get(), 0.0);
    }

    #[test]
    fn pinball_exposes_tau() {
        assert_eq!(Pinball::new(0.25).tau(), 0.25);
    }

    #[test]
    #[should_panic(expected = "Pinball tau must be in (0, 1)")]
    fn pinball_rejects_tau_out_of_range() {
        let _ = Pinball::new(1.0);
    }

    #[test]
    #[should_panic(expected = "Pinball tau must be in (0, 1)")]
    fn pinball_rejects_nan_tau() {
        let _ = Pinball::new(f64::NAN);
    }

    #[test]
    fn logloss_at_half_equals_ln2() {
        let mut m = LogLoss::new();
        m.update(0.5, 1.0);
        m.update(0.5, 0.0);
        let expected = 2.0_f64.ln();
        assert!(
            approx_eq(m.get(), expected),
            "LogLoss at p=0.5 expected ln(2)={expected}, got {}",
            m.get()
        );
    }

    #[test]
    fn logloss_clamps_extremes() {
        let mut m = LogLoss::new();
        // Confidently wrong in both directions; without clipping this would be ln(0) = -inf.
        m.update(0.0, 1.0);
        m.update(1.0, 0.0);
        assert!(
            m.get().is_finite(),
            "LogLoss must not be inf/NaN at extremes"
        );
    }

    #[test]
    fn accuracy_classification_metric_correct() {
        let mut m = Accuracy::new();
        m.update(1.0, 1.0); // correct
        m.update(0.0, 0.0); // correct
        m.update(1.0, 0.0); // wrong
        m.update(0.0, 1.0); // wrong
        assert!(
            approx_eq(m.get(), 0.5),
            "Accuracy expected 0.5, got {}",
            m.get()
        );
    }

    #[test]
    fn accuracy_handles_multiclass_labels() {
        let mut m = Accuracy::new();
        m.update(2.0, 2.0); // correct
        m.update(1.0, 2.0); // wrong
        m.update(0.0, 0.0); // correct
        assert!(
            approx_eq(m.get(), 2.0 / 3.0),
            "Accuracy with 3-class expected {}, got {}",
            2.0 / 3.0,
            m.get()
        );
    }

    #[test]
    fn composed_metrics_emit_both_values() {
        let mut m = MAE::new() + MSE::new();
        m.update(2.0, 3.0); // residual 1
        m.update(1.0, 3.0); // residual 2
        assert!(
            approx_eq(m.a.get(), 1.5),
            "Union.a (MAE) expected 1.5, got {}",
            m.a.get()
        );
        assert!(
            approx_eq(m.b.get(), 2.5),
            "Union.b (MSE) expected 2.5, got {}",
            m.b.get()
        );
    }

    #[test]
    fn triple_union_receives_all_updates() {
        // (MAE + MSE) + Accuracy nests as MetricUnion { a: MetricUnion { a: MAE, b: MSE }, b: Accuracy }.
        let mut m = MAE::new() + MSE::new() + Accuracy::new();
        m.update(1.0, 1.0);
        m.update(2.0, 3.0);
        assert!(
            approx_eq(m.a.a.get(), 0.5),
            "MAE in triple union: expected 0.5, got {}",
            m.a.a.get()
        );
        assert!(
            approx_eq(m.b.get(), 0.5),
            "Accuracy in triple union: expected 0.5, got {}",
            m.b.get()
        );
    }

    #[test]
    fn union_reset_resets_both() {
        let mut m = MAE::new() + MSE::new();
        m.update(1.0, 2.0);
        m.reset();
        assert_eq!(m.a.get(), 0.0, "MAE must be 0 after union reset");
        assert_eq!(m.b.get(), 0.0, "MSE must be 0 after union reset");
    }

    #[test]
    fn union_scalar_accessors_follow_first_member() {
        let mut m = R2::new() + MAE::new();
        m.update(1.0, 2.0);
        m.update(3.0, 5.0);
        assert_eq!(m.name(), "R2", "union name must come from its first member");
        assert!(
            m.higher_is_better(),
            "union direction must come from its first member"
        );
        assert_eq!(m.get(), m.a.get(), "union value must come from its first member");
        assert!(
            approx_eq(m.b.get(), 1.5),
            "MAE in union expected 1.5, got {}",
            m.b.get()
        );
    }

    #[test]
    fn higher_is_better_flags() {
        assert!(!MAE::new().higher_is_better(), "MAE: lower is better");
        assert!(!MSE::new().higher_is_better(), "MSE: lower is better");
        assert!(!RMSE::new().higher_is_better(), "RMSE: lower is better");
        assert!(R2::new().higher_is_better(), "R2: higher is better");
        assert!(
            Accuracy::new().higher_is_better(),
            "Accuracy: higher is better"
        );
    }

    #[test]
    fn loss_metrics_are_lower_is_better() {
        assert!(
            !Pinball::new(0.5).higher_is_better(),
            "Pinball: lower is better"
        );
        assert!(!LogLoss::new().higher_is_better(), "LogLoss: lower is better");
    }
}