/// Default tuning constant `k` for [`huber_weight`]: residuals with `|u| <= k`
/// receive full weight. 1.345 is the value commonly used for Huber's M-estimator.
pub const HUBER_K: f64 = 1.345_f64;
/// Multiplier applied to the median absolute deviation in [`mad_scale`] so the
/// result estimates a standard deviation for normally distributed residuals
/// (approximately `1 / Φ⁻¹(3/4)`).
pub const MAD_NORMAL_CONST: f64 = 1.4826_f64;
/// Error returned by the robust-statistics routines in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum RobustError {
    /// An input argument failed validation.
    ///
    /// `field` is the name supplied by whichever routine performed the check.
    /// For sample validation performed inside [`median`] it is `"values"`,
    /// even when the call originated from [`mad_scale`].
    #[error("invalid robust statistic {field}: {reason}")]
    InvalidInput {
        /// Name of the argument that failed validation (e.g. `"values"`, `"scale_floor"`).
        field: &'static str,
        /// Short description of the failed check (e.g. `"not finite"`, `"not positive"`).
        reason: &'static str,
    },
}
impl RobustError {
    /// Name of the argument that failed validation.
    pub const fn field(&self) -> &'static str {
        // Single-variant enum, so this destructure is irrefutable; the
        // `&'static str` payload is `Copy` and is read out by value.
        let Self::InvalidInput { field, .. } = *self;
        field
    }

    /// Short description of why validation failed.
    pub const fn reason(&self) -> &'static str {
        let Self::InvalidInput { reason, .. } = *self;
        reason
    }
}
/// Returns the median of `values`.
///
/// Odd-length input returns the middle order statistic. Even-length input
/// returns the midpoint of the two central order statistics. Ordering uses
/// [`f64::total_cmp`]. An empty slice yields `0.0` rather than an error
/// (kept for backward compatibility).
///
/// # Errors
///
/// Returns [`RobustError::InvalidInput`] with field `"values"` and reason
/// `"not finite"` if any element is NaN or infinite.
pub fn median(values: &[f64]) -> Result<f64, RobustError> {
    validate_finite_slice(values, "values")?;
    if values.is_empty() {
        return Ok(0.0);
    }
    let n = values.len();
    let mid = n / 2;
    let mut scratch = values.to_vec();
    // Average O(n) selection instead of a full sort. Afterwards `pivot` holds the
    // element that sorted order would place at index `mid`, and `below` holds
    // exactly the elements that sorted order places before it.
    let (below, pivot, _) = scratch.select_nth_unstable_by(mid, f64::total_cmp);
    let upper = *pivot;
    if n % 2 == 1 {
        return Ok(upper);
    }
    let lower = below
        .iter()
        .copied()
        .max_by(f64::total_cmp)
        .expect("an even-length, non-empty slice has at least one element below the midpoint");
    // `lower + upper` can overflow to infinity for large finite inputs
    // (e.g. two values near `f64::MAX`). In that case halve each operand first,
    // which cannot overflow. Otherwise keep the exact original formula.
    let sum = lower + upper;
    let midpoint = if sum.is_finite() {
        sum / 2.0
    } else {
        lower / 2.0 + upper / 2.0
    };
    Ok(midpoint)
}
/// Robust scale estimate from the median absolute deviation (MAD) of `residuals`.
///
/// Computes `MAD_NORMAL_CONST * median(|r - median(residuals)|)` and returns
/// `scale_floor` whenever that value does not exceed it. An empty slice
/// therefore yields `scale_floor`.
///
/// # Errors
///
/// * `scale_floor` is validated first. A non-finite floor gives field
///   `"scale_floor"` with reason `"not finite"`; a non-positive floor gives
///   reason `"not positive"`.
/// * A non-finite residual is reported by [`median`] with field `"values"`.
///   NOTE(review): an absolute deviation that overflows to infinity is reported
///   the same way.
pub fn mad_scale(residuals: &[f64], scale_floor: f64) -> Result<f64, RobustError> {
    validate_finite_positive(scale_floor, "scale_floor")?;
    let center = median(residuals)?;
    let deviations = residuals
        .iter()
        .map(|&x| (x - center).abs())
        .collect::<Vec<f64>>();
    let raw_mad = median(&deviations)?;
    // `raw_mad` is finite (median validated it), so the product is never NaN
    // and `max` matches a strict `>` comparison against the floor.
    Ok((raw_mad * MAD_NORMAL_CONST).max(scale_floor))
}
/// Huber weight for a residual `u` with threshold `k`.
///
/// Returns `1.0` when `|u| <= k`, otherwise `k / |u|`. `k` is not validated.
/// If either argument is NaN, the comparison fails and the result is NaN.
pub fn huber_weight(u: f64, k: f64) -> f64 {
    match u.abs() {
        // Keep the `<=` test (not `>`) so NaN inputs fall through to the division.
        magnitude if magnitude <= k => 1.0,
        magnitude => k / magnitude,
    }
}
/// Checks that every element of `values` is finite (neither NaN nor infinite).
///
/// An empty slice passes. On failure the error carries `field` and the reason
/// `"not finite"`.
fn validate_finite_slice(values: &[f64], field: &'static str) -> Result<(), RobustError> {
    let has_non_finite = values.iter().any(|x| !x.is_finite());
    if has_non_finite {
        return Err(invalid_input(field, "not finite"));
    }
    Ok(())
}
/// Checks that `value` is finite and strictly greater than zero.
///
/// Finiteness is checked first, so NaN and infinities report `"not finite"`.
/// Zero, `-0.0` and negative values report `"not positive"`.
fn validate_finite_positive(value: f64, field: &'static str) -> Result<(), RobustError> {
    if !value.is_finite() {
        return Err(invalid_input(field, "not finite"));
    }
    if value <= 0.0 {
        return Err(invalid_input(field, "not positive"));
    }
    Ok(())
}
/// Builds a [`RobustError::InvalidInput`] for the given argument name and reason.
fn invalid_input(field: &'static str, reason: &'static str) -> RobustError {
    RobustError::InvalidInput { reason, field }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn median_odd_even() {
        assert_eq!(median(&[3.0, 1.0, 2.0]).unwrap(), 2.0);
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]).unwrap(), 2.5);
        assert_eq!(median(&[]).unwrap(), 0.0);
    }

    #[test]
    fn median_negative_even_sample() {
        assert_eq!(median(&[-1.0, -3.0]).unwrap(), -2.0);
    }

    #[test]
    fn median_rejects_nonfinite_sample() {
        assert_eq!(
            median(&[1.0, f64::NAN]),
            Err(RobustError::InvalidInput {
                field: "values",
                reason: "not finite"
            })
        );
    }

    #[test]
    fn huber_weight_breaks_at_k() {
        assert_eq!(huber_weight(0.0, HUBER_K), 1.0);
        assert_eq!(huber_weight(HUBER_K, HUBER_K), 1.0);
        let w = huber_weight(2.0 * HUBER_K, HUBER_K);
        assert!((w - 0.5).abs() < 1e-15);
    }

    #[test]
    fn huber_weight_symmetric_in_sign() {
        assert_eq!(huber_weight(-HUBER_K, HUBER_K), 1.0);
        let w = huber_weight(-2.0 * HUBER_K, HUBER_K);
        assert!((w - 0.5).abs() < 1e-15);
    }

    #[test]
    fn mad_scale_floored() {
        assert_eq!(mad_scale(&[5.0, 5.0, 5.0], 0.25).unwrap(), 0.25);
    }

    #[test]
    fn mad_scale_typical_sample() {
        // Median is 3; absolute deviations [2, 1, 0, 1, 97] have median 1.
        assert_eq!(
            mad_scale(&[1.0, 2.0, 3.0, 4.0, 100.0], 0.25).unwrap(),
            MAD_NORMAL_CONST
        );
    }

    #[test]
    fn mad_scale_rejects_nonfinite_sample() {
        assert_eq!(
            mad_scale(&[5.0, f64::INFINITY], 0.25),
            Err(RobustError::InvalidInput {
                field: "values",
                reason: "not finite"
            })
        );
    }

    #[test]
    fn mad_scale_rejects_bad_floor() {
        let not_positive = Err(RobustError::InvalidInput {
            field: "scale_floor",
            reason: "not positive",
        });
        let not_finite = Err(RobustError::InvalidInput {
            field: "scale_floor",
            reason: "not finite",
        });
        assert_eq!(mad_scale(&[1.0], 0.0), not_positive);
        assert_eq!(mad_scale(&[1.0], -1.0), not_positive);
        assert_eq!(mad_scale(&[1.0], f64::NAN), not_finite);
        assert_eq!(mad_scale(&[1.0], f64::INFINITY), not_finite);
    }

    #[test]
    fn error_accessors_and_display() {
        let err = median(&[f64::NAN]).unwrap_err();
        assert_eq!(err.field(), "values");
        assert_eq!(err.reason(), "not finite");
        assert_eq!(
            err.to_string(),
            "invalid robust statistic values: not finite"
        );
    }
}