use serde::{Deserialize, Serialize};
/// Loss function selector used both to score a set of residuals (`cost`)
/// and to derive per-residual reweighting factors (`irls_weight`).
///
/// The default variant is [`RobustLoss::Mse`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub enum RobustLoss {
/// Plain mean of squared residuals; every residual gets a unit weight.
#[default]
Mse,
/// Huber loss: `0.5 * r^2` while `|r| <= delta`, and the linear branch
/// `delta * (|r| - 0.5 * delta)` beyond it.
Huber {
/// Transition point between the quadratic and the linear branch.
delta: f64,
},
/// Trimmed mean of squared residuals: only the smallest squared
/// residuals contribute to the cost.
Trimmed {
/// Fraction of residuals to discard; it is clamped to `[0, 1]` before
/// use, and at least one residual is always kept by the cost.
alpha: f64,
},
}
/// Upper bound applied to the weight ratio in `RobustLoss::irls_weight`
/// before its square root is taken (the lower bound there is `0.0`).
const WEIGHT_CLAMP: f64 = 1e6;
/// Steepness of the logistic roll-off in `soft_trim_weight`: it multiplies
/// the distance between `|r|` and the quantile cut-off inside the `exp`.
const TRIM_SHARPNESS: f64 = 3.0;
impl RobustLoss {
    /// Scores `residuals` under the selected loss.
    ///
    /// * `Mse` – mean of the squared residuals.
    /// * `Huber` – mean of the per-point Huber penalty (`huber_point`).
    /// * `Trimmed` – delegated entirely to `trimmed_mse`.
    ///
    /// For `Mse` and `Huber` the divisor is `max(len, 1)`, so an empty slice
    /// yields a zero cost rather than a division by zero.
    #[must_use]
    pub fn cost(self, residuals: &[f64]) -> f64 {
        let summed: f64 = match self {
            // The trimmed loss does its own averaging over the kept points.
            Self::Trimmed { alpha } => return trimmed_mse(residuals, alpha),
            Self::Mse => residuals.iter().map(|r| r * r).sum(),
            Self::Huber { delta } => residuals.iter().map(|&r| huber_point(r, delta)).sum(),
        };
        summed / residuals.len().max(1) as f64
    }

    /// Returns the reweighting factor for a single residual `r`.
    ///
    /// The underlying ratio is `1` for `Mse`, `min(1, delta / |r|)`-style for
    /// `Huber`, and the logistic weight from `soft_trim_weight` for
    /// `Trimmed` (which is the only variant that reads `residuals`). The
    /// ratio is clamped to `[0, WEIGHT_CLAMP]` and its **square root** is
    /// returned.
    // NOTE(review): returning the square root suggests the factor is meant to
    // scale residuals/rows directly rather than squared terms — confirm with
    // the callers.
    #[must_use]
    pub fn irls_weight(self, r: f64, residuals: &[f64]) -> f64 {
        let raw = match self {
            Self::Mse => 1.0,
            Self::Trimmed { alpha } => soft_trim_weight(r, residuals, alpha),
            // Inside the quadratic zone the weight is one; the explicit zero
            // check keeps `delta / |r|` from ever dividing by zero.
            Self::Huber { delta } if r.abs() <= delta || r == 0.0 => 1.0,
            Self::Huber { delta } => delta / r.abs(),
        };
        raw.clamp(0.0, WEIGHT_CLAMP).sqrt()
    }
}
/// Huber penalty for a single residual `r` with transition point `delta`.
///
/// Returns `0.5 * r^2` while `|r| <= delta`; otherwise the linear
/// continuation `delta * (|r| - 0.5 * delta)`.
fn huber_point(r: f64, delta: f64) -> f64 {
    let magnitude = r.abs();
    // Quadratic zone handled first; everything else (including comparisons
    // that are false because of NaN) falls through to the linear formula.
    if magnitude <= delta {
        return 0.5 * r * r;
    }
    delta * (magnitude - 0.5 * delta)
}
/// Mean of the smallest squared residuals after discarding the worst ones.
///
/// `alpha` is clamped to `[0, 1]`. The number of residuals kept is
/// `ceil((1 - alpha) * len)`, forced into `1..=len`, and the result is the
/// mean of the squares of the `keep` residuals with the smallest magnitude.
/// An empty slice yields `0.0`.
///
/// Residuals are ranked with `f64::total_cmp` on their absolute value, so the
/// ordering is total: a NaN residual always ranks as the worst point. It is
/// therefore trimmed first, and only makes the result NaN when it falls
/// inside the kept portion. The outcome no longer depends on where the NaN
/// sits in the input.
fn trimmed_mse(residuals: &[f64], alpha: f64) -> f64 {
    if residuals.is_empty() {
        return 0.0;
    }
    let alpha = alpha.clamp(0.0, 1.0);

    // Rank by magnitude rather than by square: `abs` guarantees a cleared
    // sign bit even for NaN, which is what places NaN last under `total_cmp`.
    let mut magnitudes: Vec<f64> = residuals.iter().map(|r| r.abs()).collect();
    magnitudes.sort_unstable_by(f64::total_cmp);

    // The float-to-usize cast saturates; the clamp keeps at least one point
    // and never more than are available.
    let keep =
        (((1.0 - alpha) * residuals.len() as f64).ceil() as usize).clamp(1, magnitudes.len());

    magnitudes.iter().take(keep).map(|m| m * m).sum::<f64>() / keep as f64
}
/// Smooth (logistic) inlier weight for residual `r` relative to a quantile
/// cut-off taken from `residuals`.
///
/// `alpha` is clamped to `[0, 1]`. The cut-off `q` is the order statistic of
/// `|residuals|` at index `round((1 - alpha) * (len - 1))`. The returned
/// value is `1 / (1 + exp(TRIM_SHARPNESS * (|r| - q)))`: exactly `0.5` at the
/// cut-off, approaching `1` below it and `0` above it. An empty `residuals`
/// slice yields `1.0`.
///
/// Magnitudes are ordered with `f64::total_cmp`, so NaN residuals rank above
/// every finite value and the cut-off does not depend on input order. Only
/// one order statistic is needed, so a selection is used instead of a full
/// sort.
fn soft_trim_weight(r: f64, residuals: &[f64], alpha: f64) -> f64 {
    if residuals.is_empty() {
        return 1.0;
    }
    let alpha = alpha.clamp(0.0, 1.0);

    let mut magnitudes: Vec<f64> = residuals.iter().map(|v| v.abs()).collect();
    let last = magnitudes.len() - 1;

    // The float-to-usize cast saturates; `min` keeps the index in bounds.
    let q_idx = (((1.0 - alpha) * last as f64).round() as usize).min(last);

    // Partial selection places the element that a full sort would put at
    // `q_idx` into that slot, in O(len) instead of O(len log len).
    let (_, cutoff, _) = magnitudes.select_nth_unstable_by(q_idx, f64::total_cmp);
    let q = *cutoff;

    1.0 / (1.0 + (TRIM_SHARPNESS * (r.abs() - q)).exp())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nine small inliers (0.1) followed by a single gross outlier (1000.0).
    fn inliers_with_outlier() -> Vec<f64> {
        let mut values = vec![0.1_f64; 9];
        values.push(1000.0);
        values
    }

    /// Plain MSE leaves every weight at one and averages the squares.
    #[test]
    fn mse_weights_are_unit() {
        let r = [1.0, -5.0, 0.0, 100.0];
        let all_unit = r
            .iter()
            .all(|&ri| (RobustLoss::Mse.irls_weight(ri, &r) - 1.0).abs() < 1e-15);
        assert!(all_unit);

        let expected = (1.0 + 25.0 + 0.0 + 10_000.0) / 4.0;
        let actual = RobustLoss::Mse.cost(&r);
        assert!((actual - expected).abs() < 1e-9);
    }

    /// Huber is quadratic with unit weight inside `delta` and linear with a
    /// shrinking weight outside of it.
    #[test]
    fn huber_caps_large_residuals() {
        let delta = 1.0;
        let loss = RobustLoss::Huber { delta };

        // Inside the quadratic zone.
        let w_small = loss.irls_weight(0.5, &[0.5]);
        assert!((w_small - 1.0).abs() < 1e-12);
        assert!((huber_point(0.5, delta) - 0.125).abs() < 1e-12);

        // Outside: the weight is sqrt(delta / |r|) and the cost grows linearly.
        let w = loss.irls_weight(4.0, &[4.0]);
        let expected_w = (1.0_f64 / 4.0).sqrt();
        assert!(w < 1.0);
        assert!((w - expected_w).abs() < 1e-12);
        assert!((huber_point(4.0, delta) - (4.0 - 0.5)).abs() < 1e-12);
    }

    /// Trimming removes the outlier from the cost and down-weights it.
    #[test]
    fn trimmed_drops_worst_points() {
        let r = inliers_with_outlier();
        let loss = RobustLoss::Trimmed { alpha: 0.1 };

        let full = RobustLoss::Mse.cost(&r);
        let trimmed = loss.cost(&r);
        assert!(
            trimmed < full * 1e-3,
            "trim did not drop the outlier: {trimmed} vs {full}"
        );

        let w_out = loss.irls_weight(1000.0, &r);
        let w_in = loss.irls_weight(0.1, &r);
        assert!(w_out < 0.1, "outlier weight too high: {w_out}");
        assert!(w_in > 0.5, "inlier weight too low: {w_in}");
    }
}