Skip to main content

phop_core/
loss.rs

//! Robust regression losses for constant fitting under outliers.
//!
//! Plain mean-squared error gives every residual quadratic weight, so a handful of gross
//! outliers dominate the fit and the recovered constants drift away from the clean law. A
//! **robust loss** caps (Huber) or discards (trimmed) the influence of large residuals.
//!
//! `oxieml`'s symbolic-regression engine exposes the same two robust variants through its
//! [`oxieml::SymRegLoss`] enum, but the loss math itself (`huber_loss`, `trimmed_mse`, and their
//! gradient factors) is private to that crate. phop re-implements the (short, standard) formulas
//! here so they can drive phop's *own* Levenberg–Marquardt polish via **IRLS** (iteratively
//! reweighted least squares): at each LM iteration the per-row residual and Jacobian are scaled by
//! a weight `w_i = √(grad_factor(r_i)/r_i)`, turning a robust objective into a sequence of weighted
//! least-squares solves. For the trimmed loss a hard cut would make that weight discontinuous, so a
//! smooth sigmoid soft-trim weight stands in for `grad_factor(r_i)/r_i`. With [`RobustLoss::Mse`]
//! every weight is `1`, so the polish reduces exactly to ordinary least squares.
15
16use serde::{Deserialize, Serialize};
17
18/// A loss function for fitting an expression's constants to data.
19#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
20pub enum RobustLoss {
21    /// Mean-squared error (no robustness). Every residual gets full quadratic weight.
22    #[default]
23    Mse,
24    /// Huber loss: quadratic for `|r| ≤ delta`, linear beyond. `delta` is the quadratic-to-linear
25    /// transition; smaller `delta` is more aggressive at down-weighting outliers.
26    Huber {
27        /// Transition point between the quadratic and linear regimes.
28        delta: f64,
29    },
30    /// Trimmed MSE: discard the largest `alpha` fraction of squared residuals before averaging.
31    /// `alpha ∈ [0, 1)`; e.g. `alpha = 0.1` ignores the worst 10% of points.
32    Trimmed {
33        /// Fraction of largest-residual points to trim.
34        alpha: f64,
35    },
36}
37
/// Upper bound applied to the IRLS weight ratio before its square root is taken. It mirrors
/// `oxieml`'s relaxation, so a near-zero residual can never yield an unbounded `grad_factor(r)/r`.
const WEIGHT_CLAMP: f64 = 1.0e6;
/// Steepness of the sigmoid soft-trim that supplies the trimmed loss's gradient weight. It
/// multiplies the raw difference `|r| − q`, so its units are inverse residual units.
const TRIM_SHARPNESS: f64 = 3.0;
43
44impl RobustLoss {
45    /// The scalar loss (mean over points) for a residual vector `r = prediction − target`.
46    #[must_use]
47    pub fn cost(self, residuals: &[f64]) -> f64 {
48        let n = residuals.len().max(1) as f64;
49        match self {
50            RobustLoss::Mse => residuals.iter().map(|r| r * r).sum::<f64>() / n,
51            RobustLoss::Huber { delta } => {
52                residuals
53                    .iter()
54                    .map(|&r| huber_point(r, delta))
55                    .sum::<f64>()
56                    / n
57            }
58            RobustLoss::Trimmed { alpha } => trimmed_mse(residuals, alpha),
59        }
60    }
61
62    /// The IRLS multiplicative weight `w_i` applied to residual row `i` (and the matching Jacobian
63    /// row) so that a weighted least-squares step descends the robust objective. `residuals` is the
64    /// full current residual vector (needed by the trimmed loss to find its quantile threshold).
65    #[must_use]
66    pub fn irls_weight(self, r: f64, residuals: &[f64]) -> f64 {
67        let ratio = match self {
68            RobustLoss::Mse => 1.0,
69            RobustLoss::Huber { delta } => {
70                if r.abs() <= delta || r == 0.0 {
71                    1.0
72                } else {
73                    delta / r.abs()
74                }
75            }
76            RobustLoss::Trimmed { alpha } => soft_trim_weight(r, residuals, alpha),
77        };
78        ratio.clamp(0.0, WEIGHT_CLAMP).sqrt()
79    }
80}
81
/// Huber loss of one residual: `½r²` while `|r| ≤ δ`, and `δ(|r| − ½δ)` once `|r|` exceeds `δ`.
///
/// At `|r| = δ` the two pieces agree in both value (`½δ²`) and slope (`δ`).
fn huber_point(r: f64, delta: f64) -> f64 {
    let magnitude = r.abs();
    if magnitude <= delta {
        return 0.5 * r * r;
    }
    delta * (magnitude - 0.5 * delta)
}
91
/// Trimmed MSE: square the residuals, drop the largest `alpha` fraction, average the rest.
///
/// `alpha` is clamped to `[0, 1]`. The `ceil((1 − alpha)·n)` smallest squared residuals are kept,
/// bounded to at least one and at most `n`, so `alpha = 1` still averages the single smallest one.
/// An empty slice costs `0.0`.
///
/// `NaN` residuals rank as the largest, so they are the first to be trimmed. Any `NaN` that
/// survives the trim makes the result `NaN`.
fn trimmed_mse(residuals: &[f64], alpha: f64) -> f64 {
    if residuals.is_empty() {
        return 0.0;
    }
    let n = residuals.len();
    let alpha = alpha.clamp(0.0, 1.0);

    // Rank by |r| rather than r². Squaring is monotone on non-negative values, so the order is the
    // same. `abs` only clears the sign bit, so every NaN becomes a positive NaN, which `total_cmp`
    // places last. A `partial_cmp(..).unwrap_or(Equal)` comparator is not a total order once NaNs
    // appear, and the standard sorts may panic on such a comparator.
    let mut magnitudes: Vec<f64> = residuals.iter().map(|r| r.abs()).collect();
    magnitudes.sort_unstable_by(f64::total_cmp);

    // `alpha` is rarely exact in binary. For example, 0.7 is stored just below 0.7, so
    // (1 − 0.7)·10 evaluates to 3.0000000000000004, which would let `ceil` keep one point too many.
    // The subtracted tolerance exceeds that rounding error but stays far below one whole point.
    let target = (1.0 - alpha) * n as f64;
    let keep = ((target - n as f64 * 1e-12).ceil() as usize).clamp(1, n);

    magnitudes.iter().take(keep).map(|m| m * m).sum::<f64>() / keep as f64
}
105
106/// Smooth (sigmoid) soft-trim weight in `[0, 1]`: ≈1 for residuals below the `(1-alpha)` quantile
107/// of `|r|`, decaying to 0 beyond it. Smoothness keeps the IRLS iteration well-behaved (a hard cut
108/// would make the weight discontinuous in the parameters).
109fn soft_trim_weight(r: f64, residuals: &[f64], alpha: f64) -> f64 {
110    if residuals.is_empty() {
111        return 1.0;
112    }
113    let alpha = alpha.clamp(0.0, 1.0);
114    let mut abs: Vec<f64> = residuals.iter().map(|v| v.abs()).collect();
115    abs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
116    // (1-alpha) quantile of |r| as the soft threshold.
117    let q_idx = (((1.0 - alpha) * (abs.len() as f64 - 1.0)).round() as usize).min(abs.len() - 1);
118    let q = abs[q_idx];
119    1.0 / (1.0 + (TRIM_SHARPNESS * (r.abs() - q)).exp())
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// MSE assigns unit weight to every row, and its cost is the plain mean of squares.
    #[test]
    fn mse_weights_are_unit() {
        let residuals = [1.0, -5.0, 0.0, 100.0];
        for r in residuals {
            assert!((RobustLoss::Mse.irls_weight(r, &residuals) - 1.0).abs() < 1e-15);
        }
        // (1² + 5² + 0² + 100²) / 4
        let mean_square = (1.0 + 25.0 + 0.0 + 10_000.0) / 4.0;
        assert!((RobustLoss::Mse.cost(&residuals) - mean_square).abs() < 1e-9);
    }

    /// Huber keeps full weight inside `delta` and caps the influence of larger residuals.
    #[test]
    fn huber_caps_large_residuals() {
        let delta = 1.0;
        let loss = RobustLoss::Huber { delta };

        // |r| ≤ δ: quadratic regime, so unit weight and cost ½r².
        let inner = 0.5;
        assert!((loss.irls_weight(inner, &[inner]) - 1.0).abs() < 1e-12);
        assert!((huber_point(inner, delta) - 0.125).abs() < 1e-12);

        // |r| > δ: linear regime, so weight √(δ/|r|) < 1 and cost δ(|r| − ½δ).
        let outer = 4.0;
        let w = loss.irls_weight(outer, &[outer]);
        assert!(w < 1.0);
        assert!((w - (delta / outer).sqrt()).abs() < 1e-12);
        assert!((huber_point(outer, delta) - (outer - 0.5)).abs() < 1e-12);
    }

    /// Nine small residuals and one gross outlier: a 10% trim should drop the outlier.
    #[test]
    fn trimmed_drops_worst_points() {
        let residuals: Vec<f64> = std::iter::repeat(0.1).take(9).chain([1000.0]).collect();
        let loss = RobustLoss::Trimmed { alpha: 0.1 };

        let full = RobustLoss::Mse.cost(&residuals);
        let trimmed = loss.cost(&residuals);
        assert!(
            trimmed < full * 1e-3,
            "trim did not drop the outlier: {trimmed} vs {full}"
        );

        // The outlier's IRLS weight collapses while a clean point stays close to full weight.
        let w_out = loss.irls_weight(1000.0, &residuals);
        let w_in = loss.irls_weight(0.1, &residuals);
        assert!(w_out < 0.1, "outlier weight too high: {w_out}");
        assert!(w_in > 0.5, "inlier weight too low: {w_in}");
    }
}