apex_solver/core/loss_functions.rs

1//! Robust loss functions for outlier rejection in nonlinear least squares optimization.
2//!
3//! Loss functions (also called robust cost functions or M-estimators) reduce the influence of
4//! outlier measurements on the optimization result. In standard least squares, the cost is
5//! the squared norm of residuals: `cost = Σ ||r_i||²`. With a robust loss function ρ(s), the
6//! cost becomes: `cost = Σ ρ(||r_i||²)`.
7//!
8//! # Mathematical Formulation
9//!
10//! Each loss function implements the `Loss` trait, which evaluates:
11//! - **ρ(s)**: The robust cost value
12//! - **ρ'(s)**: First derivative (weight function)
13//! - **ρ''(s)**: Second derivative (for corrector algorithm)
14//!
15//! The input `s = ||r||²` is the squared norm of the residual vector.
16//!
17//! # Usage in Optimization
18//!
19//! Loss functions are applied via the `Corrector` algorithm (see `corrector.rs`), which
20//! modifies the residuals and Jacobians to account for the robust weighting. The optimization
21//! then proceeds as if solving a reweighted least squares problem.
22//!
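//! For intuition, a minimal sketch of how the returned triple can drive an IRLS-style
//! reweighting, where each residual is scaled by `√ρ'(s)` (the actual correction applied by
//! `Corrector` may differ in detail):
//!
//! ```
//! use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
//! # use apex_solver::error::ApexSolverResult;
//! # fn example() -> ApexSolverResult<()> {
//!
//! let loss = HuberLoss::new(1.345)?;
//!
//! let residual = [3.0_f64, 4.0];                     // ||r|| = 5, an outlier for δ = 1.345
//! let s: f64 = residual.iter().map(|r| r * r).sum(); // s = ||r||² = 25
//! let [_rho, rho_prime, _] = loss.evaluate(s);
//!
//! // Scale the residual by √ρ'(s); large residuals are pulled toward zero.
//! let weight = rho_prime.sqrt();
//! let reweighted: Vec<f64> = residual.iter().map(|r| weight * r).collect();
//! assert!(reweighted[0] < residual[0]);
//! # Ok(())
//! # }
//! # example().unwrap();
//! ```
//!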
23//! # Available Loss Functions
24//!
25//! ## Basic Loss Functions
26//! - [`L2Loss`]: Standard least squares (no robustness)
27//! - [`L1Loss`]: Absolute error (simple robust baseline)
28//!
29//! ## Moderate Robustness
30//! - [`HuberLoss`]: Quadratic for inliers, linear for outliers (recommended for general use)
31//! - [`FairLoss`]: Smooth transition with continuous derivatives
32//! - [`CauchyLoss`]: Heavier suppression of large residuals
33//!
34//! ## Strong Robustness
35//! - [`GemanMcClureLoss`]: Very strong outlier rejection
36//! - [`WelschLoss`]: Exponential downweighting
37//! - [`TukeyBiweightLoss`]: Complete outlier suppression (redescending)
38//!
39//! ## Specialized Functions
40//! - [`AndrewsWaveLoss`]: Sine-based redescending M-estimator
41//! - [`RamsayEaLoss`]: Exponential decay weighting
42//! - [`TrimmedMeanLoss`]: Hard threshold cutoff
43//! - [`LpNormLoss`]: Generalized Lp norm (flexible p parameter)
44//!
45//! ## Modern Adaptive
46//! - [`BarronGeneralLoss`]: Unified framework encompassing many loss functions (CVPR 2019)
47//!
48//! # Loss Function Selection Guide
49//!
50//! | Use Case | Recommended Loss | Tuning Constant |
51//! |----------|-----------------|-----------------|
52//! | Clean data, no outliers | `L2Loss` | N/A |
53//! | Few outliers (<5%) | `HuberLoss` | c = 1.345 |
54//! | Moderate outliers (5-10%) | `FairLoss` or `CauchyLoss` | c = 1.3998 / 2.3849 |
55//! | Many outliers (>10%) | `WelschLoss` or `TukeyBiweightLoss` | c = 2.9846 / 4.6851 |
56//! | Severe outliers | `GemanMcClureLoss` | c = 1.0-2.0 |
57//! | Adaptive/unknown | `BarronGeneralLoss` | α adaptive |
58//!
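//! As a rough illustration of the table above, a loss can be selected behind a trait object at
//! run time (the outlier-ratio thresholds here are only illustrative, not part of the crate):
//!
//! ```
//! use apex_solver::core::loss_functions::{LossFunction, L2Loss, HuberLoss, WelschLoss};
//! # use apex_solver::error::ApexSolverResult;
//! # fn example() -> ApexSolverResult<()> {
//!
//! let expected_outlier_ratio = 0.08;
//! let loss: Box<dyn LossFunction> = if expected_outlier_ratio <= 0.0 {
//!     Box::new(L2Loss::new())
//! } else if expected_outlier_ratio < 0.05 {
//!     Box::new(HuberLoss::new(1.345)?)
//! } else {
//!     Box::new(WelschLoss::new(2.9846)?)
//! };
//!
//! let [_, rho_prime, _] = loss.evaluate(25.0);
//! assert!(rho_prime < 1.0); // large residuals are downweighted
//! # Ok(())
//! # }
//! # example().unwrap();
//! ```
//!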
59//! # Example
60//!
61//! ```
62//! use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
63//! # use apex_solver::error::ApexSolverResult;
64//! # fn example() -> ApexSolverResult<()> {
65//!
66//! let huber = HuberLoss::new(1.345)?;
67//!
68//! // Evaluate for an inlier (small residual)
69//! let s_inlier = 0.5;
70//! let [rho, rho_prime, rho_double_prime] = huber.evaluate(s_inlier);
71//! assert_eq!(rho, s_inlier); // Quadratic cost in inlier region
72//! assert_eq!(rho_prime, 1.0); // Full weight
73//!
74//! // Evaluate for an outlier (large residual)
75//! let s_outlier = 10.0;
76//! let [rho, rho_prime, rho_double_prime] = huber.evaluate(s_outlier);
77//! // rho grows linearly instead of quadratically
78//! // rho_prime < 1.0, downweighting the outlier
79//! # Ok(())
80//! # }
81//! # example().unwrap();
82//! ```
83
84use crate::core::CoreError;
85use crate::error::ApexSolverResult;
86
87/// Trait for robust loss functions used in nonlinear least squares optimization.
88///
89/// A loss function transforms the squared residual `s = ||r||²` into a robust cost `ρ(s)`
90/// that reduces the influence of outliers. The trait provides the cost value and its first
91/// two derivatives, which are used by the `Corrector` to modify the optimization problem.
92///
93/// # Returns
94///
95/// The `evaluate` method returns a 3-element array: `[ρ(s), ρ'(s), ρ''(s)]`
96/// - `ρ(s)`: Robust cost value
97/// - `ρ'(s)`: First derivative (weight function)
98/// - `ρ''(s)`: Second derivative
99///
100/// # Implementation Notes
101///
102/// - Loss functions should be smooth (at least C²) for optimization stability
103/// - Typically ρ(0) = 0, ρ'(0) = 1, ρ''(0) = 0 (behaves like standard least squares near zero)
104/// - For outliers, ρ'(s) should decrease to downweight large residuals
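///
/// # Example
///
/// A minimal sketch of a custom implementation (`ScaledL2` is purely illustrative, not part
/// of this crate):
///
/// ```
/// use apex_solver::core::loss_functions::LossFunction;
///
/// struct ScaledL2 {
///     k: f64,
/// }
///
/// impl LossFunction for ScaledL2 {
///     fn evaluate(&self, s: f64) -> [f64; 3] {
///         // ρ(s) = k·s, ρ'(s) = k, ρ''(s) = 0
///         [self.k * s, self.k, 0.0]
///     }
/// }
///
/// let loss = ScaledL2 { k: 0.5 };
/// assert_eq!(loss.evaluate(4.0), [2.0, 0.5, 0.0]);
/// ```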
105pub trait LossFunction: Send + Sync {
106    /// Evaluate the loss function and its first two derivatives at squared residual `s`.
107    ///
108    /// # Arguments
109    ///
110    /// * `s` - The squared norm of the residual: `s = ||r||²` (always non-negative)
111    ///
112    /// # Returns
113    ///
114    /// Array `[ρ(s), ρ'(s), ρ''(s)]` containing the cost, first derivative, and second derivative
115    fn evaluate(&self, s: f64) -> [f64; 3];
116}
117
118/// L2 loss function (standard least squares, no robustness).
119///
120/// The L2 loss is the standard squared error used in ordinary least squares optimization.
121/// It provides no outlier robustness and is optimal only when residuals follow a Gaussian
122/// distribution.
123///
124/// # Mathematical Definition
125///
126/// ```text
127/// ρ(s) = s
128/// ρ'(s) = 1
129/// ρ''(s) = 0
130/// ```
131///
132/// where `s = ||r||²` is the squared residual norm.
133///
134/// # Properties
135///
136/// - **Convex**: Globally optimal solution
137/// - **Not robust**: Outliers have full influence (squared!)
138/// - **Optimal**: For Gaussian noise without outliers
139/// - **Fast**: Simplest to compute
140///
141/// # Use Cases
142///
143/// - Clean data with known Gaussian noise
144/// - Baseline comparison for robust methods
145/// - When outliers are already filtered
146///
147/// # Example
148///
149/// ```
150/// use apex_solver::core::loss_functions::{LossFunction, L2Loss};
151///
152/// let l2 = L2Loss::new();
153///
154/// let [rho, rho_prime, rho_double_prime] = l2.evaluate(4.0);
155/// assert_eq!(rho, 4.0);  // ρ(s) = s
156/// assert_eq!(rho_prime, 1.0);  // Full weight
157/// assert_eq!(rho_double_prime, 0.0);
158/// ```
159#[derive(Debug, Clone, Copy)]
160pub struct L2Loss;
161
162impl L2Loss {
163    /// Create a new L2 loss function (no parameters needed).
164    pub fn new() -> Self {
165        L2Loss
166    }
167}
168
169impl Default for L2Loss {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
175impl LossFunction for L2Loss {
176    fn evaluate(&self, s: f64) -> [f64; 3] {
177        [s, 1.0, 0.0]
178    }
179}
180
181/// L1 loss function (absolute error, simple robust baseline).
182///
183/// The L1 loss uses absolute error instead of squared error, providing basic robustness
184/// to outliers. It is optimal for Laplacian noise distributions.
185///
186/// # Mathematical Definition
187///
188/// ```text
189/// ρ(s) = 2√s
190/// ρ'(s) = 1/√s
191/// ρ''(s) = -1/(2s^(3/2))
192/// ```
193///
194/// where `s = ||r||²` is the squared residual norm.
195///
196/// # Properties
197///
198/// - **Convex**: Globally optimal solution
199/// - **Moderately robust**: Linear growth vs quadratic
200/// - **Unstable at zero**: Derivative undefined at s=0
201/// - **Median estimator**: Minimizes to median instead of mean
202///
203/// # Use Cases
204///
205/// - Simple outlier rejection
206/// - When median is preferred over mean
207/// - Sparse optimization problems
208///
209/// # Example
210///
211/// ```
212/// use apex_solver::core::loss_functions::{LossFunction, L1Loss};
213///
214/// let l1 = L1Loss::new();
215///
216/// let [rho, rho_prime, _] = l1.evaluate(4.0);
217/// assert!((rho - 4.0).abs() < 1e-10);  // ρ(4) = 2√4 = 4
218/// assert!((rho_prime - 0.5).abs() < 1e-10);  // ρ'(4) = 1/√4 = 0.5
219/// ```
220#[derive(Debug, Clone, Copy)]
221pub struct L1Loss;
222
223impl L1Loss {
224    /// Create a new L1 loss function (no parameters needed).
225    pub fn new() -> Self {
226        L1Loss
227    }
228}
229
230impl Default for L1Loss {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
236impl LossFunction for L1Loss {
237    fn evaluate(&self, s: f64) -> [f64; 3] {
238        if s < f64::EPSILON {
239            // Near zero: use L2 to avoid singularity
240            return [s, 1.0, 0.0];
241        }
242        let sqrt_s = s.sqrt();
243        [
244            2.0 * sqrt_s,              // ρ(s) = 2√s
245            1.0 / sqrt_s,              // ρ'(s) = 1/√s
246            -1.0 / (2.0 * s * sqrt_s), // ρ''(s) = -1/(2s√s)
247        ]
248    }
249}
250
251/// Huber loss function for moderate outlier rejection.
252///
253/// The Huber loss is quadratic for small residuals (inliers) and linear for large residuals
254/// (outliers), providing a good balance between robustness and efficiency.
255///
256/// # Mathematical Definition
257///
258/// ```text
259/// ρ(s) = {  s                           if s ≤ δ²
260///        {  2δ√s - δ²                  if s > δ²
261///
262/// ρ'(s) = {  1                          if s ≤ δ²
263///         {  δ / √s                    if s > δ²
264///
265/// ρ''(s) = {  0                         if s ≤ δ²
266///          {  -δ / (2s^(3/2))          if s > δ²
267/// ```
268///
269/// where `δ` is the scale parameter (threshold), and `s = ||r||²` is the squared residual norm.
270///
271/// # Properties
272///
273/// - **Inlier region** (s ≤ δ²): Behaves like standard least squares (quadratic cost)
274/// - **Outlier region** (s > δ²): Cost grows linearly, limiting outlier influence
275/// - **Transition point**: At s = δ², the function switches from quadratic to linear
276///
277/// # Scale Parameter Selection
278///
279/// Common choices for the scale parameter `δ`:
280/// - **1.345**: Approximately 95% efficiency on Gaussian data (most common)
281/// - **0.5-1.0**: More aggressive outlier rejection
282/// - **2.0-3.0**: More lenient, closer to standard least squares
283///
284/// # Example
285///
286/// ```
287/// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
288/// # use apex_solver::error::ApexSolverResult;
289/// # fn example() -> ApexSolverResult<()> {
290///
291/// // Create Huber loss with scale = 1.345 (standard choice)
292/// let huber = HuberLoss::new(1.345)?;
293///
294/// // Small residual (inlier): ||r||² = 0.5
295/// let [rho, rho_prime, rho_double_prime] = huber.evaluate(0.5);
296/// assert_eq!(rho, 0.5);           // Quadratic: ρ(s) = s
297/// assert_eq!(rho_prime, 1.0);     // Full weight
298/// assert_eq!(rho_double_prime, 0.0);
299///
300/// // Large residual (outlier): ||r||² = 10.0
301/// let [rho, rho_prime, rho_double_prime] = huber.evaluate(10.0);
302/// // ρ(10) ≈ 6.69, grows linearly not quadratically
303/// // ρ'(10) ≈ 0.425, downweighted to ~42.5% of original
304/// # Ok(())
305/// # }
306/// # example().unwrap();
307/// ```
308#[derive(Debug, Clone)]
309pub struct HuberLoss {
310    /// Scale parameter δ
311    scale: f64,
312    /// Cached value δ² for efficient computation
313    scale2: f64,
314}
315
316impl HuberLoss {
317    /// Create a new Huber loss function with the given scale parameter.
318    ///
319    /// # Arguments
320    ///
321    /// * `scale` - The threshold δ that separates inliers from outliers (must be positive)
322    ///
323    /// # Returns
324    ///
325    /// `Ok(HuberLoss)` if scale > 0, otherwise an error
326    ///
327    /// # Example
328    ///
329    /// ```
330    /// use apex_solver::core::loss_functions::HuberLoss;
331    /// # use apex_solver::error::ApexSolverResult;
332    /// # fn example() -> ApexSolverResult<()> {
333    ///
334    /// let huber = HuberLoss::new(1.345)?;
335    /// # Ok(())
336    /// # }
337    /// # example().unwrap();
338    /// ```
339    pub fn new(scale: f64) -> ApexSolverResult<Self> {
340        if scale <= 0.0 {
341            return Err(
342                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
343            );
344        }
345        Ok(HuberLoss {
346            scale,
347            scale2: scale * scale,
348        })
349    }
350}
351
352impl LossFunction for HuberLoss {
353    /// Evaluate Huber loss function: ρ(s), ρ'(s), ρ''(s).
354    ///
355    /// # Arguments
356    ///
357    /// * `s` - Squared residual norm: s = ||r||²
358    ///
359    /// # Returns
360    ///
361    /// `[ρ(s), ρ'(s), ρ''(s)]` - Cost, first derivative, second derivative
362    fn evaluate(&self, s: f64) -> [f64; 3] {
363        if s > self.scale2 {
364            // Outlier region: s > δ²
365            // Linear cost: ρ(s) = 2δ√s - δ²
366            let r = s.sqrt(); // r = √s = ||r||
            let rho1 = (self.scale / r).max(f64::MIN_POSITIVE); // ρ'(s) = δ / √s, clamped away from zero
368            [
369                2.0 * self.scale * r - self.scale2, // ρ(s)
370                rho1,                               // ρ'(s)
371                -rho1 / (2.0 * s),                  // ρ''(s) = -δ / (2s√s)
372            ]
373        } else {
374            // Inlier region: s ≤ δ²
375            // Quadratic cost: ρ(s) = s, ρ'(s) = 1, ρ''(s) = 0
376            [s, 1.0, 0.0]
377        }
378    }
379}
380
381/// Cauchy loss function for aggressive outlier rejection.
382///
383/// The Cauchy loss (also called Lorentzian loss) provides stronger suppression of outliers
384/// than Huber loss. It never fully rejects outliers but reduces their weight significantly.
385///
386/// # Mathematical Definition
387///
388/// ```text
389/// ρ(s) = (δ²/2) * log(1 + s/δ²)
390///
391/// ρ'(s) = 1 / (1 + s/δ²)
392///
393/// ρ''(s) = -1 / (δ² * (1 + s/δ²)²)
394/// ```
395///
396/// where `δ` is the scale parameter, and `s = ||r||²` is the squared residual norm.
397///
398/// # Properties
399///
400/// - **Smooth transition**: No sharp boundary between inliers and outliers
401/// - **Logarithmic growth**: Cost grows very slowly for large residuals
402/// - **Strong downweighting**: Large outliers receive very small weights
403/// - **Non-convex**: Can have multiple local minima (harder to optimize than Huber)
404///
405/// # Scale Parameter Selection
406///
407/// Typical values:
408/// - **2.3849**: Approximately 95% efficiency on Gaussian data
409/// - **1.0-2.0**: More aggressive outlier rejection
410/// - **3.0-5.0**: More lenient
411///
412/// # Comparison to Huber Loss
413///
414/// - **Cauchy**: Stronger outlier rejection, smoother, but non-convex (may converge to local minimum)
415/// - **Huber**: Weaker outlier rejection, convex, more predictable convergence
416///
417/// # Example
418///
419/// ```
420/// use apex_solver::core::loss_functions::{LossFunction, CauchyLoss};
421/// # use apex_solver::error::ApexSolverResult;
422/// # fn example() -> ApexSolverResult<()> {
423///
424/// // Create Cauchy loss with scale = 2.3849 (standard choice)
425/// let cauchy = CauchyLoss::new(2.3849)?;
426///
427/// // Small residual: ||r||² = 0.5
428/// let [rho, rho_prime, _] = cauchy.evaluate(0.5);
/// // ρ ≈ 0.24, growing more slowly than the raw squared residual
430/// // ρ' ≈ 0.92, close to 1.0 (near full weight)
431///
432/// // Large residual: ||r||² = 100.0
433/// let [rho, rho_prime, _] = cauchy.evaluate(100.0);
/// // ρ ≈ 8.3, logarithmic growth (much less than 100)
435/// // ρ' ≈ 0.05, heavily downweighted (5% of original)
436/// # Ok(())
437/// # }
438/// # example().unwrap();
439/// ```
#[derive(Debug, Clone)]
pub struct CauchyLoss {
441    /// Cached value δ² (scale squared)
442    scale2: f64,
443    /// Cached value 1/δ² for efficient computation
444    c: f64,
445}
446
447impl CauchyLoss {
448    /// Create a new Cauchy loss function with the given scale parameter.
449    ///
450    /// # Arguments
451    ///
452    /// * `scale` - The scale parameter δ (must be positive)
453    ///
454    /// # Returns
455    ///
456    /// `Ok(CauchyLoss)` if scale > 0, otherwise an error
457    ///
458    /// # Example
459    ///
460    /// ```
461    /// use apex_solver::core::loss_functions::CauchyLoss;
462    /// # use apex_solver::error::ApexSolverResult;
463    /// # fn example() -> ApexSolverResult<()> {
464    ///
465    /// let cauchy = CauchyLoss::new(2.3849)?;
466    /// # Ok(())
467    /// # }
468    /// # example().unwrap();
469    /// ```
470    pub fn new(scale: f64) -> ApexSolverResult<Self> {
471        if scale <= 0.0 {
472            return Err(
473                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
474            );
475        }
476        let scale2 = scale * scale;
477        Ok(CauchyLoss {
478            scale2,
479            c: 1.0 / scale2,
480        })
481    }
482}
483
484impl LossFunction for CauchyLoss {
485    /// Evaluate Cauchy loss function: ρ(s), ρ'(s), ρ''(s).
486    ///
487    /// # Arguments
488    ///
489    /// * `s` - Squared residual norm: s = ||r||²
490    ///
491    /// # Returns
492    ///
493    /// `[ρ(s), ρ'(s), ρ''(s)]` - Cost, first derivative, second derivative
494    fn evaluate(&self, s: f64) -> [f64; 3] {
495        let sum = 1.0 + s * self.c; // 1 + s/δ²
496        let inv = 1.0 / sum; // 1 / (1 + s/δ²)
497
498        // Note: sum and inv are always positive, assuming s ≥ 0
499        [
500            self.scale2 * sum.ln() / 2.0, // ρ(s) = (δ²/2) * ln(1 + s/δ²)
            inv.max(f64::MIN_POSITIVE),   // ρ'(s) = 1 / (1 + s/δ²), clamped away from zero
502            -self.c * (inv * inv),        // ρ''(s) = -1 / (δ² * (1 + s/δ²)²)
503        ]
504    }
505}
506
507/// Fair loss function with continuous smooth derivatives.
508///
509/// The Fair loss provides a good balance between robustness and stability with everywhere-defined
510/// continuous derivatives up to third order. It yields a unique solution and is recommended
511/// for general use when you need guaranteed smoothness.
512///
513/// # Mathematical Definition
514///
515/// ```text
/// ρ(s) = c² * (|x|/c - ln(1 + |x|/c))
/// ρ'(s) = 1 / (2(c + |x|))
/// ρ''(s) = -1 / (4√s (c + |x|)²)
519/// ```
520///
521/// where `c` is the scale parameter, `s = ||r||²`, and `x = √s = ||r||`.
522///
523/// # Properties
524///
525/// - **Smooth**: Continuous derivatives of first three orders
526/// - **Unique solution**: Strictly convex near origin
527/// - **Moderate robustness**: Between Huber and Cauchy
528/// - **Stable**: No discontinuities in optimization
529///
530/// # Scale Parameter Selection
531///
532/// - **1.3998**: Approximately 95% efficiency on Gaussian data (recommended)
533/// - **0.8-1.2**: More aggressive outlier rejection
534/// - **2.0-3.0**: More lenient
535///
536/// # Comparison
537///
538/// - Smoother than Huber (no kink at threshold)
539/// - Less aggressive than Cauchy
540/// - Better behaved numerically than many redescending M-estimators
541///
542/// # Example
543///
544/// ```
545/// use apex_solver::core::loss_functions::{LossFunction, FairLoss};
546/// # use apex_solver::error::ApexSolverResult;
547/// # fn example() -> ApexSolverResult<()> {
548///
549/// let fair = FairLoss::new(1.3998)?;
550///
551/// let [rho, rho_prime, _] = fair.evaluate(4.0);
552/// // Smooth transition, no sharp corners
553/// # Ok(())
554/// # }
555/// # example().unwrap();
556/// ```
557#[derive(Debug, Clone)]
558pub struct FairLoss {
559    scale: f64,
560}
561
562impl FairLoss {
563    /// Create a new Fair loss function with the given scale parameter.
564    ///
565    /// # Arguments
566    ///
567    /// * `scale` - The scale parameter c (must be positive)
568    ///
569    /// # Returns
570    ///
571    /// `Ok(FairLoss)` if scale > 0, otherwise an error
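    ///
    /// # Example
    ///
    /// Constructing the loss; a non-positive scale is rejected:
    ///
    /// ```
    /// use apex_solver::core::loss_functions::FairLoss;
    /// # use apex_solver::error::ApexSolverResult;
    /// # fn example() -> ApexSolverResult<()> {
    ///
    /// let fair = FairLoss::new(1.3998)?;
    /// assert!(FairLoss::new(-1.0).is_err()); // non-positive scales are rejected
    /// # Ok(())
    /// # }
    /// # example().unwrap();
    /// ```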
572    pub fn new(scale: f64) -> ApexSolverResult<Self> {
573        if scale <= 0.0 {
574            return Err(
575                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
576            );
577        }
578        Ok(FairLoss { scale })
579    }
580}
581
582impl LossFunction for FairLoss {
583    fn evaluate(&self, s: f64) -> [f64; 3] {
584        if s < f64::EPSILON {
585            return [s, 1.0, 0.0];
586        }
587
588        let x = s.sqrt(); // ||r||
589        let abs_x = x.abs();
590        let c_plus_x = self.scale + abs_x;
591
592        // ρ(s) = c² * (|x|/c - ln(1 + |x|/c))
593        let rho = self.scale * self.scale * (abs_x / self.scale - (1.0 + abs_x / self.scale).ln());
594
        // ρ'(s) = 1 / (2(c + |x|))
        let rho_prime = 0.5 / c_plus_x;

        // ρ''(s) = -1 / (4√s (c + |x|)²), the derivative of ρ'(s) above
        let rho_double_prime = -1.0 / (4.0 * x * c_plus_x * c_plus_x);
600
601        [rho, rho_prime, rho_double_prime]
602    }
603}
604
605/// Geman-McClure loss function for very strong outlier rejection.
606///
607/// The Geman-McClure loss provides one of the strongest forms of outlier suppression,
608/// with weights that decay rapidly for large residuals.
609///
610/// # Mathematical Definition
611///
612/// ```text
613/// ρ(s) = s / (1 + s/c²)
614/// ρ'(s) = 1 / (1 + s/c²)²
615/// ρ''(s) = -2 / (c² * (1 + s/c²)³)
616/// ```
617///
618/// where `c` is the scale parameter and `s = ||r||²`.
619///
620/// # Properties
621///
622/// - **Very strong rejection**: Weights decay as O(1/s²) for large s
623/// - **Non-convex**: Multiple local minima possible
624/// - **No unique solution**: Requires good initialization
625/// - **Aggressive**: Use when outliers are severe
626///
627/// # Scale Parameter Selection
628///
629/// - **1.0-2.0**: Typical range
630/// - **0.5-1.0**: Very aggressive (use with care)
631/// - **2.0-4.0**: More lenient
632///
633/// # Example
634///
635/// ```
636/// use apex_solver::core::loss_functions::{LossFunction, GemanMcClureLoss};
637/// # use apex_solver::error::ApexSolverResult;
638/// # fn example() -> ApexSolverResult<()> {
639///
640/// let geman = GemanMcClureLoss::new(1.0)?;
641///
642/// let [rho, rho_prime, _] = geman.evaluate(100.0);
643/// // Very small weight for large outliers
644/// # Ok(())
645/// # }
646/// # example().unwrap();
647/// ```
648#[derive(Debug, Clone)]
649pub struct GemanMcClureLoss {
650    c: f64, // 1/scale²
651}
652
653impl GemanMcClureLoss {
654    /// Create a new Geman-McClure loss function.
655    ///
656    /// # Arguments
657    ///
658    /// * `scale` - The scale parameter c (must be positive)
659    pub fn new(scale: f64) -> ApexSolverResult<Self> {
660        if scale <= 0.0 {
661            return Err(
662                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
663            );
664        }
665        let scale2 = scale * scale;
666        Ok(GemanMcClureLoss { c: 1.0 / scale2 })
667    }
668}
669
670impl LossFunction for GemanMcClureLoss {
671    fn evaluate(&self, s: f64) -> [f64; 3] {
672        let denom = 1.0 + s * self.c; // 1 + s/c²
673        let inv = 1.0 / denom;
674        let inv2 = inv * inv;
675
676        [
677            s * inv,                    // ρ(s) = s / (1 + s/c²)
678            inv2,                       // ρ'(s) = 1 / (1 + s/c²)²
679            -2.0 * self.c * inv2 * inv, // ρ''(s) = -2 / (c²(1 + s/c²)³)
680        ]
681    }
682}
683
684/// Welsch loss function with exponential downweighting.
685///
686/// The Welsch loss (also called Leclerc loss) uses exponential decay to strongly
687/// suppress outliers while maintaining smoothness. It completely suppresses very
688/// large outliers (weight → 0 as s → ∞).
689///
690/// # Mathematical Definition
691///
692/// ```text
693/// ρ(s) = c²/2 * (1 - exp(-s/c²))
694/// ρ'(s) = (1/2) * exp(-s/c²)
695/// ρ''(s) = -(1/2c²) * exp(-s/c²)
696/// ```
697///
698/// where `c` is the scale parameter and `s = ||r||²`.
699///
700/// # Properties
701///
702/// - **Redescending**: Weights decrease to zero for large residuals
703/// - **Smooth**: Infinitely differentiable
704/// - **Strong suppression**: Exponential decay
705/// - **Non-convex**: Requires good initialization
706///
707/// # Scale Parameter Selection
708///
709/// - **2.9846**: Approximately 95% efficiency on Gaussian data
710/// - **2.0-2.5**: More aggressive
711/// - **3.5-4.5**: More lenient
712///
713/// # Example
714///
715/// ```
716/// use apex_solver::core::loss_functions::{LossFunction, WelschLoss};
717/// # use apex_solver::error::ApexSolverResult;
718/// # fn example() -> ApexSolverResult<()> {
719///
720/// let welsch = WelschLoss::new(2.9846)?;
721///
722/// let [rho, rho_prime, _] = welsch.evaluate(50.0);
723/// // Weight approaches zero for large residuals
724/// # Ok(())
725/// # }
726/// # example().unwrap();
727/// ```
728#[derive(Debug, Clone)]
729pub struct WelschLoss {
730    scale2: f64,
731    inv_scale2: f64,
732}
733
734impl WelschLoss {
735    /// Create a new Welsch loss function.
736    ///
737    /// # Arguments
738    ///
739    /// * `scale` - The scale parameter c (must be positive)
740    pub fn new(scale: f64) -> ApexSolverResult<Self> {
741        if scale <= 0.0 {
742            return Err(
743                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
744            );
745        }
746        let scale2 = scale * scale;
747        Ok(WelschLoss {
748            scale2,
749            inv_scale2: 1.0 / scale2,
750        })
751    }
752}
753
754impl LossFunction for WelschLoss {
755    fn evaluate(&self, s: f64) -> [f64; 3] {
756        let exp_term = (-s * self.inv_scale2).exp();
757
758        [
759            (self.scale2 / 2.0) * (1.0 - exp_term), // ρ(s) = c²/2 * (1 - exp(-s/c²))
760            0.5 * exp_term,                         // ρ'(s) = (1/2) * exp(-s/c²)
761            -0.5 * self.inv_scale2 * exp_term,      // ρ''(s) = -(1/2c²) * exp(-s/c²)
762        ]
763    }
764}
765
766/// Tukey biweight loss function with complete outlier suppression.
767///
768/// The Tukey biweight (bisquare) loss completely suppresses outliers beyond a threshold,
769/// setting their weight to exactly zero. This is a "redescending" M-estimator.
770///
771/// # Mathematical Definition
772///
773/// For |x| ≤ c:
774/// ```text
775/// ρ(s) = c²/6 * (1 - (1 - (x/c)²)³)
776/// ρ'(s) = (1/2) * (1 - (x/c)²)²
/// ρ''(s) = -(1/c²) * (1 - (x/c)²)
778/// ```
779///
780/// For |x| > c:
781/// ```text
782/// ρ(s) = c²/6
783/// ρ'(s) = 0
784/// ρ''(s) = 0
785/// ```
786///
787/// where `c` is the scale parameter, `x = √s`, and `s = ||r||²`.
788///
789/// # Properties
790///
791/// - **Complete suppression**: Outliers have exactly zero weight
792/// - **Redescending**: Weight goes to zero beyond threshold
793/// - **Non-convex**: Multiple local minima
794/// - **Aggressive**: Best for severe outlier contamination
795///
796/// # Scale Parameter Selection
797///
798/// - **4.6851**: Approximately 95% efficiency on Gaussian data
799/// - **3.5-4.0**: More aggressive
800/// - **5.5-6.5**: More lenient
801///
802/// # Example
803///
804/// ```
805/// use apex_solver::core::loss_functions::{LossFunction, TukeyBiweightLoss};
806/// # use apex_solver::error::ApexSolverResult;
807/// # fn example() -> ApexSolverResult<()> {
808///
809/// let tukey = TukeyBiweightLoss::new(4.6851)?;
810///
811/// let [rho, rho_prime, _] = tukey.evaluate(25.0); // |x| = 5 > 4.6851
812/// assert_eq!(rho_prime, 0.0); // Complete suppression
813/// # Ok(())
814/// # }
815/// # example().unwrap();
816/// ```
817#[derive(Debug, Clone)]
818pub struct TukeyBiweightLoss {
819    scale: f64,
820    scale2: f64,
821}
822
823impl TukeyBiweightLoss {
824    /// Create a new Tukey biweight loss function.
825    ///
826    /// # Arguments
827    ///
828    /// * `scale` - The scale parameter c (must be positive)
829    pub fn new(scale: f64) -> ApexSolverResult<Self> {
830        if scale <= 0.0 {
831            return Err(
832                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
833            );
834        }
835        Ok(TukeyBiweightLoss {
836            scale,
837            scale2: scale * scale,
838        })
839    }
840}
841
842impl LossFunction for TukeyBiweightLoss {
843    fn evaluate(&self, s: f64) -> [f64; 3] {
844        let x = s.sqrt();
845
846        if x > self.scale {
847            // Complete outlier suppression
848            [self.scale2 / 6.0, 0.0, 0.0]
849        } else {
850            let ratio = x / self.scale;
851            let ratio2 = ratio * ratio;
852            let one_minus_ratio2 = 1.0 - ratio2;
853            let one_minus_ratio2_sq = one_minus_ratio2 * one_minus_ratio2;
854
855            [
856                (self.scale2 / 6.0) * (1.0 - one_minus_ratio2 * one_minus_ratio2_sq), // ρ(s)
857                0.5 * one_minus_ratio2_sq,                                            // ρ'(s)
                -one_minus_ratio2 / self.scale2,                                      // ρ''(s)
859            ]
860        }
861    }
862}
863
864/// Andrews sine wave loss function (redescending M-estimator).
865///
866/// The Andrews sine wave loss uses a periodic sine function to create a redescending
867/// M-estimator that completely suppresses outliers beyond π*c.
868///
869/// # Mathematical Definition
870///
871/// For |x| ≤ πc:
872/// ```text
873/// ρ(s) = c² * (1 - cos(x/c))
874/// ρ'(s) = (1/2) * sin(x/c)
875/// ρ''(s) = (1/4c) * cos(x/c) / √s
876/// ```
877///
878/// For |x| > πc:
879/// ```text
880/// ρ(s) = 2c²
881/// ρ'(s) = 0
882/// ρ''(s) = 0
883/// ```
884///
885/// where `c` is the scale parameter, `x = √s`, and `s = ||r||²`.
886///
887/// # Properties
888///
889/// - **Periodic structure**: Sine-based weighting
890/// - **Complete suppression**: Zero weight beyond πc
891/// - **Redescending**: Smooth transition to zero
892/// - **Non-convex**: Requires careful initialization
893///
894/// # Scale Parameter Selection
895///
896/// - **1.339**: Standard tuning constant
897/// - **1.0-1.2**: More aggressive
898/// - **1.5-2.0**: More lenient
899///
900/// # Example
901///
902/// ```
903/// use apex_solver::core::loss_functions::{LossFunction, AndrewsWaveLoss};
904/// # use apex_solver::error::ApexSolverResult;
905/// # fn example() -> ApexSolverResult<()> {
906///
907/// let andrews = AndrewsWaveLoss::new(1.339)?;
908///
909/// let [rho, rho_prime, _] = andrews.evaluate(20.0);
910/// // Weight is zero for large outliers
911/// # Ok(())
912/// # }
913/// # example().unwrap();
914/// ```
915#[derive(Debug, Clone)]
916pub struct AndrewsWaveLoss {
917    scale: f64,
918    scale2: f64,
919    threshold: f64, // π * scale
920}
921
922impl AndrewsWaveLoss {
923    /// Create a new Andrews sine wave loss function.
924    ///
925    /// # Arguments
926    ///
927    /// * `scale` - The scale parameter c (must be positive)
928    pub fn new(scale: f64) -> ApexSolverResult<Self> {
929        if scale <= 0.0 {
930            return Err(
931                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
932            );
933        }
934        Ok(AndrewsWaveLoss {
935            scale,
936            scale2: scale * scale,
937            threshold: std::f64::consts::PI * scale,
938        })
939    }
940}
941
942impl LossFunction for AndrewsWaveLoss {
943    fn evaluate(&self, s: f64) -> [f64; 3] {
944        let x = s.sqrt();
945
946        if x > self.threshold {
947            // Complete suppression beyond π*c
948            [2.0 * self.scale2, 0.0, 0.0]
949        } else {
950            let arg = x / self.scale;
951            let sin_val = arg.sin();
952            let cos_val = arg.cos();
953
954            [
955                self.scale2 * (1.0 - cos_val),                       // ρ(s)
956                0.5 * sin_val,                                       // ρ'(s)
957                (0.25 / self.scale) * cos_val / x.max(f64::EPSILON), // ρ''(s)
958            ]
959        }
960    }
961}
962
963/// Ramsay Ea loss function with exponential decay.
964///
965/// The Ramsay Ea loss uses exponential weighting to provide smooth, strong
966/// downweighting of outliers.
967///
968/// # Mathematical Definition
969///
970/// ```text
/// ρ(s) = a⁻² * (1 - exp(-a|x|) * (1 + a|x|))
/// ρ'(s) = (1/2) * exp(-a|x|)
/// ρ''(s) = -(a / (4|x|)) * exp(-a|x|)
972/// ```
973///
974/// where `a` is the scale parameter, `x = √s`, and `s = ||r||²`.
975///
976/// # Properties
977///
978/// - **Exponential decay**: Smooth weight reduction
979/// - **Strongly robust**: Good for heavy outliers
980/// - **Smooth**: Continuous derivatives
981/// - **Non-convex**: Needs good initialization
982///
983/// # Scale Parameter Selection
984///
985/// - **0.3**: Standard tuning constant
986/// - **0.2-0.25**: More aggressive
987/// - **0.35-0.5**: More lenient
988///
989/// # Example
990///
991/// ```
992/// use apex_solver::core::loss_functions::{LossFunction, RamsayEaLoss};
993/// # use apex_solver::error::ApexSolverResult;
994/// # fn example() -> ApexSolverResult<()> {
995///
996/// let ramsay = RamsayEaLoss::new(0.3)?;
997///
998/// let [rho, rho_prime, _] = ramsay.evaluate(10.0);
999/// // Exponential downweighting
1000/// # Ok(())
1001/// # }
1002/// # example().unwrap();
1003/// ```
1004#[derive(Debug, Clone)]
1005pub struct RamsayEaLoss {
1006    scale: f64,
1007    inv_scale2: f64,
1008}
1009
1010impl RamsayEaLoss {
1011    /// Create a new Ramsay Ea loss function.
1012    ///
1013    /// # Arguments
1014    ///
1015    /// * `scale` - The scale parameter a (must be positive)
1016    pub fn new(scale: f64) -> ApexSolverResult<Self> {
1017        if scale <= 0.0 {
1018            return Err(
1019                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
1020            );
1021        }
1022        Ok(RamsayEaLoss {
1023            scale,
1024            inv_scale2: 1.0 / (scale * scale),
1025        })
1026    }
1027}
1028
1029impl LossFunction for RamsayEaLoss {
1030    fn evaluate(&self, s: f64) -> [f64; 3] {
1031        let x = s.sqrt();
1032        let ax = self.scale * x;
1033        let exp_term = (-ax).exp();
1034
1035        // ρ(s) = a⁻² * (1 - exp(-a|x|) * (1 + a|x|))
1036        let rho = self.inv_scale2 * (1.0 - exp_term * (1.0 + ax));
1037
1038        // ρ'(s) = (1/2) * exp(-a|x|)
1039        let rho_prime = 0.5 * exp_term;
1040
1041        // ρ''(s) = -(a/4|x|) * exp(-a|x|)
1042        let rho_double_prime = -(self.scale / (4.0 * x.max(f64::EPSILON))) * exp_term;
1043
1044        [rho, rho_prime, rho_double_prime]
1045    }
1046}
1047
1048/// Trimmed mean loss function with hard threshold.
1049///
1050/// The trimmed mean loss is the simplest redescending estimator, applying a hard
1051/// cutoff at a threshold. Residuals below the threshold use L2 loss, those above
1052/// contribute a constant.
1053///
1054/// # Mathematical Definition
1055///
1056/// For s ≤ c²:
1057/// ```text
1058/// ρ(s) = s/2
1059/// ρ'(s) = 1/2
1060/// ρ''(s) = 0
1061/// ```
1062///
1063/// For s > c²:
1064/// ```text
1065/// ρ(s) = c²/2
1066/// ρ'(s) = 0
1067/// ρ''(s) = 0
1068/// ```
1069///
1070/// where `c` is the scale parameter and `s = ||r||²`.
1071///
1072/// # Properties
1073///
1074/// - **Simple**: Easiest to understand and implement
1075/// - **Hard cutoff**: Discontinuous weight function
1076/// - **Robust**: Completely ignores large outliers
1077/// - **Unstable**: Discontinuity can cause optimization issues
1078///
1079/// # Scale Parameter Selection
1080///
1081/// - **2.0**: Standard tuning constant
1082/// - **1.5**: More aggressive
1083/// - **3.0**: More lenient
1084///
1085/// # Example
1086///
1087/// ```
1088/// use apex_solver::core::loss_functions::{LossFunction, TrimmedMeanLoss};
1089/// # use apex_solver::error::ApexSolverResult;
1090/// # fn example() -> ApexSolverResult<()> {
1091///
1092/// let trimmed = TrimmedMeanLoss::new(2.0)?;
1093///
1094/// let [rho, rho_prime, _] = trimmed.evaluate(5.0);
1095/// assert_eq!(rho_prime, 0.0); // Beyond threshold
1096/// # Ok(())
1097/// # }
1098/// # example().unwrap();
1099/// ```
1100#[derive(Debug, Clone)]
1101pub struct TrimmedMeanLoss {
1102    scale2: f64,
1103}
1104
1105impl TrimmedMeanLoss {
1106    /// Create a new trimmed mean loss function.
1107    ///
1108    /// # Arguments
1109    ///
1110    /// * `scale` - The scale parameter c (must be positive)
1111    pub fn new(scale: f64) -> ApexSolverResult<Self> {
1112        if scale <= 0.0 {
1113            return Err(
1114                CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
1115            );
1116        }
1117        Ok(TrimmedMeanLoss {
1118            scale2: scale * scale,
1119        })
1120    }
1121}
1122
1123impl LossFunction for TrimmedMeanLoss {
1124    fn evaluate(&self, s: f64) -> [f64; 3] {
1125        if s <= self.scale2 {
1126            [s / 2.0, 0.5, 0.0]
1127        } else {
1128            [self.scale2 / 2.0, 0.0, 0.0]
1129        }
1130    }
1131}
1132
1133/// Generalized Lp norm loss function.
1134///
1135/// The Lp norm loss allows flexible control over robustness through the p parameter.
1136/// It interpolates between L1 (p=1) and L2 (p=2) losses.
1137///
1138/// # Mathematical Definition
1139///
1140/// ```text
1141/// ρ(s) = |x|^p = s^(p/2)
1142/// ρ'(s) = (p/2) * s^(p/2-1)
1143/// ρ''(s) = (p/2) * (p/2-1) * s^(p/2-2)
1144/// ```
1145///
1146/// where `p` is the norm parameter, `x = √s`, and `s = ||r||²`.
1147///
1148/// # Properties
1149///
1150/// - **Flexible**: Tune robustness with p parameter
1151/// - **p=2**: L2 norm (standard least squares)
1152/// - **p=1**: L1 norm (robust median estimator)
1153/// - **p<1**: Very robust (non-convex)
1154/// - **1<p<2**: Compromise between L1 and L2
1155///
1156/// # Parameter Selection
1157///
1158/// - **p = 2.0**: No robustness (L2 loss)
1159/// - **p = 1.5**: Moderate robustness
1160/// - **p = 1.2**: Strong robustness
1161/// - **p = 1.0**: L1 loss (median)
1162///
1163/// # Example
1164///
1165/// ```
1166/// use apex_solver::core::loss_functions::{LossFunction, LpNormLoss};
1167/// # use apex_solver::error::ApexSolverResult;
1168/// # fn example() -> ApexSolverResult<()> {
1169///
1170/// let lp = LpNormLoss::new(1.5)?;
1171///
1172/// let [rho, rho_prime, _] = lp.evaluate(4.0);
1173/// // Between L1 and L2 behavior
1174/// # Ok(())
1175/// # }
1176/// # example().unwrap();
1177/// ```
1178#[derive(Debug, Clone)]
1179pub struct LpNormLoss {
1180    p: f64,
1181}
1182
1183impl LpNormLoss {
1184    /// Create a new Lp norm loss function.
1185    ///
1186    /// # Arguments
1187    ///
1188    /// * `p` - The norm parameter (0 < p ≤ 2 for practical use)
1189    pub fn new(p: f64) -> ApexSolverResult<Self> {
1190        if p <= 0.0 {
1191            return Err(CoreError::InvalidInput("p must be positive".to_string()).into());
1192        }
1193        Ok(LpNormLoss { p })
1194    }
1195}
1196
1197impl LossFunction for LpNormLoss {
1198    fn evaluate(&self, s: f64) -> [f64; 3] {
1199        if s < f64::EPSILON {
1200            return [s, 1.0, 0.0];
1201        }
1202
1203        let exp_rho = self.p / 2.0;
1204        let exp_rho_prime = exp_rho - 1.0;
1205        let exp_rho_double_prime = exp_rho_prime - 1.0;
1206
1207        [
1208            s.powf(exp_rho),                                        // ρ(s) = s^(p/2)
1209            exp_rho * s.powf(exp_rho_prime),                        // ρ'(s) = (p/2) * s^(p/2-1)
1210            exp_rho * exp_rho_prime * s.powf(exp_rho_double_prime), // ρ''(s)
1211        ]
1212    }
1213}
1214
1215/// Barron's general and adaptive robust loss function (CVPR 2019).
1216///
1217/// The Barron loss is a unified framework that encompasses many classic loss functions
1218/// (L2, Charbonnier, Cauchy, Geman-McClure, Welsch) through a single shape parameter α.
1219/// It can also adapt α automatically during optimization.
1220///
1221/// # Mathematical Definition
1222///
1223/// ```text
/// ρ(s, α, c) = (|α|/c²) * (((x/c)² * |α|/2 + 1)^(α/2) - 1)
1225/// ```
1226///
1227/// where `α` controls robustness, `c` is scale, `x = √s`, and `s = ||r||²`.
1228///
1229/// # Special Cases (by α value)
1230///
/// - **α = 2**: L2 loss (no robustness)
/// - **α = 1**: Charbonnier/Pseudo-Huber loss
/// - **α = 0**: Cauchy loss
/// - **α = -2**: Geman-McClure loss
/// - **α → -∞**: Welsch loss
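///
/// As a quick check against this module's own implementations, the α = 2 case reproduces
/// `L2Loss` exactly:
///
/// ```
/// use apex_solver::core::loss_functions::{LossFunction, BarronGeneralLoss, L2Loss};
/// # use apex_solver::error::ApexSolverResult;
/// # fn example() -> ApexSolverResult<()> {
///
/// let barron = BarronGeneralLoss::new(2.0, 1.0)?;
/// assert_eq!(barron.evaluate(4.0), L2Loss::new().evaluate(4.0));
/// # Ok(())
/// # }
/// # example().unwrap();
/// ```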
1237///
1238/// # Properties
1239///
1240/// - **Unified**: Single framework for many loss functions
1241/// - **Adaptive**: Can learn optimal α during training
1242/// - **Smooth**: Continuously differentiable in α
1243/// - **Modern**: State-of-the-art from computer vision research
1244///
1245/// # Parameter Selection
1246///
1247/// **Fixed α (manual tuning):**
1248/// - α = 2.0: Clean data, no outliers
1249/// - α = 0.0 to 1.0: Moderate outliers
1250/// - α = -2.0 to 0.0: Heavy outliers
1251///
1252/// **Scale c:**
1253/// - c = 1.0: Standard choice
1254/// - Adjust based on expected residual magnitude
1255///
1256/// # References
1257///
1258/// Barron, J. T. (2019). A general and adaptive robust loss function.
1259/// IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).
1260///
1261/// # Example
1262///
1263/// ```
1264/// use apex_solver::core::loss_functions::{LossFunction, BarronGeneralLoss};
1265/// # use apex_solver::error::ApexSolverResult;
1266/// # fn example() -> ApexSolverResult<()> {
1267///
1268/// // Cauchy-like behavior
1269/// let barron = BarronGeneralLoss::new(0.0, 1.0)?;
1270///
1271/// let [rho, rho_prime, _] = barron.evaluate(4.0);
1272/// // Behaves like Cauchy loss
1273/// # Ok(())
1274/// # }
1275/// # example().unwrap();
1276/// ```
1277#[derive(Debug, Clone)]
1278pub struct BarronGeneralLoss {
1279    alpha: f64,
1280    scale: f64,
1281    scale2: f64,
1282}
1283
1284impl BarronGeneralLoss {
1285    /// Create a new Barron general robust loss function.
1286    ///
1287    /// # Arguments
1288    ///
1289    /// * `alpha` - The shape parameter (controls robustness)
1290    /// * `scale` - The scale parameter c (must be positive)
1291    pub fn new(alpha: f64, scale: f64) -> ApexSolverResult<Self> {
1292        if scale <= 0.0 {
1293            return Err(CoreError::InvalidInput("scale must be positive".to_string()).into());
1294        }
1295        Ok(BarronGeneralLoss {
1296            alpha,
1297            scale,
1298            scale2: scale * scale,
1299        })
1300    }
1301}
1302
1303impl LossFunction for BarronGeneralLoss {
1304    fn evaluate(&self, s: f64) -> [f64; 3] {
1305        // Handle special case α ≈ 0 (Cauchy loss)
1306        if self.alpha.abs() < 1e-6 {
1307            let denom = 1.0 + s / self.scale2;
1308            let inv = 1.0 / denom;
1309            return [
1310                (self.scale2 / 2.0) * denom.ln(),
                inv.max(f64::MIN_POSITIVE),
1312                -inv * inv / self.scale2,
1313            ];
1314        }
1315
1316        // Handle special case α ≈ 2 (L2 loss)
1317        if (self.alpha - 2.0).abs() < 1e-6 {
1318            return [s, 1.0, 0.0];
1319        }
1320
1321        // General case
1322        let x = s.sqrt();
1323        let normalized = x / self.scale;
1324        let normalized2 = normalized * normalized;
1325
1326        let inner = self.alpha.abs() / 2.0 * normalized2 + 1.0;
1327        let power = inner.powf(self.alpha / 2.0);
1328
1329        // ρ(s) = (|α|/c²) * (power - 1)
1330        let rho = (self.alpha.abs() / self.scale2) * (power - 1.0);
1331
1332        // ρ'(s) = (1/2) * inner^(α/2 - 1)
1333        let rho_prime = 0.5 * inner.powf(self.alpha / 2.0 - 1.0);
1334
1335        // ρ''(s) = (α - 2)/(4c²) * inner^(α/2 - 2)
1336        let rho_double_prime =
1337            (self.alpha - 2.0) / (4.0 * self.scale2) * inner.powf(self.alpha / 2.0 - 2.0);
1338
1339        [rho, rho_prime, rho_double_prime]
1340    }
1341}
1342
1343/// Student's t-distribution loss function (robust M-estimator).
1344///
1345/// The t-distribution loss is derived from the negative log-likelihood of Student's
1346/// t-distribution. It provides heavy tails for robustness against outliers, with the
1347/// degrees of freedom parameter ν controlling the tail heaviness.
1348///
1349/// # Mathematical Definition
1350///
1351/// ```text
1352/// ρ(s) = (ν + 1)/2 · log(1 + s/ν)
1353/// ρ'(s) = (ν + 1)/(2(ν + s))
1354/// ρ''(s) = -(ν + 1)/(2(ν + s)²)
1355/// ```
1356///
1357/// where `ν` is the degrees of freedom and `s = ||r||²` is the squared residual norm.
1358///
1359/// # Properties
1360///
1361/// - **Heavy tails**: Provides robustness through heavier tails than Gaussian
1362/// - **Parameter control**: Small ν → heavy tails (more robust), large ν → Gaussian (less robust)
1363/// - **Well-founded**: Based on maximum likelihood estimation with t-distribution
1364/// - **Smooth**: Continuous derivatives for all s > 0
1365///
1366/// # Degrees of Freedom Selection
1367///
1368/// - **ν = 3-4**: Very robust, heavy outlier suppression
1369/// - **ν = 5**: Recommended default, good balance
1370/// - **ν = 10**: Moderate robustness
1371/// - **ν → ∞**: Converges to Gaussian (L2 loss)
1372///
1373/// # Use Cases
1374///
1375/// - Robust regression with unknown outlier distribution
1376/// - SLAM and pose graph optimization with loop closure outliers
1377/// - Bundle adjustment with incorrect feature matches
1378/// - Any optimization problem with heavy-tailed noise
1379///
1380/// # References
1381///
1382/// - Student's t-distribution is widely used in robust statistics
1383/// - Applied in robust SLAM (e.g., Chebrolu et al. 2021, Agarwal et al.)
1384///
1385/// # Example
1386///
1387/// ```
1388/// use apex_solver::core::loss_functions::{LossFunction, TDistributionLoss};
1389/// # use apex_solver::error::ApexSolverResult;
1390/// # fn example() -> ApexSolverResult<()> {
1391///
1392/// let t_loss = TDistributionLoss::new(5.0)?;
1393///
1394/// let [rho, rho_prime, _] = t_loss.evaluate(4.0);
1395/// // Robust to outliers with heavy tails
1396/// # Ok(())
1397/// # }
1398/// # example().unwrap();
1399/// ```
1400#[derive(Debug, Clone)]
1401pub struct TDistributionLoss {
1402    nu: f64,             // Degrees of freedom
1403    half_nu_plus_1: f64, // (ν + 1)/2 (cached)
1404}
1405
1406impl TDistributionLoss {
1407    /// Create a new Student's t-distribution loss function.
1408    ///
1409    /// # Arguments
1410    ///
1411    /// * `nu` - Degrees of freedom (must be positive)
1412    ///
1413    /// # Recommended Values
1414    ///
1415    /// - ν = 5.0: Default, good balance between robustness and efficiency
1416    /// - ν = 3.0-4.0: More robust to outliers
1417    /// - ν = 10.0: Less aggressive, closer to Gaussian
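    ///
    /// # Example
    ///
    /// ```
    /// use apex_solver::core::loss_functions::TDistributionLoss;
    /// # use apex_solver::error::ApexSolverResult;
    /// # fn example() -> ApexSolverResult<()> {
    ///
    /// let t_loss = TDistributionLoss::new(5.0)?;
    /// assert!(TDistributionLoss::new(0.0).is_err()); // ν must be positive
    /// # Ok(())
    /// # }
    /// # example().unwrap();
    /// ```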
1418    pub fn new(nu: f64) -> ApexSolverResult<Self> {
1419        if nu <= 0.0 {
1420            return Err(
1421                CoreError::InvalidInput("degrees of freedom must be positive".to_string()).into(),
1422            );
1423        }
1424        Ok(TDistributionLoss {
1425            nu,
1426            half_nu_plus_1: (nu + 1.0) / 2.0,
1427        })
1428    }
1429}
1430
1431impl LossFunction for TDistributionLoss {
1432    fn evaluate(&self, s: f64) -> [f64; 3] {
1433        // ρ(s) = (ν + 1)/2 · log(1 + s/ν)
1434        let inner = 1.0 + s / self.nu;
1435        let rho = self.half_nu_plus_1 * inner.ln();
1436
1437        // ρ'(s) = (ν + 1)/(2(ν + s))
1438        let denom = self.nu + s;
1439        let rho_prime = self.half_nu_plus_1 / denom;
1440
1441        // ρ''(s) = -(ν + 1)/(2(ν + s)²)
1442        let rho_double_prime = -self.half_nu_plus_1 / (denom * denom);
1443
1444        [rho, rho_prime, rho_double_prime]
1445    }
1446}
1447
1448/// Adaptive Barron loss function (simplified version).
1449///
1450/// This is a convenience wrapper around `BarronGeneralLoss` with recommended default
1451/// parameters for adaptive robust optimization. Based on Chebrolu et al. (2021) RAL paper
1452/// "Adaptive Robust Kernels for Non-Linear Least Squares Problems".
1453///
1454/// # Mathematical Definition
1455///
1456/// For α ≠ 0:
1457/// ```text
/// ρ(s) = |α - 2|/α · ((s/(c²·|α - 2|) + 1)^(α/2) - 1)
1459/// ```
1460///
1461/// For α = 0 (Cauchy-like):
1462/// ```text
1463/// ρ(s) = log(s/(2c²) + 1)
1464/// ```
1465///
1466/// where `α` is the shape parameter and `c` is the scale parameter.
1467///
1468/// # Properties
1469///
1470/// - **Adaptive**: Can approximate many M-estimators (Huber, Cauchy, Geman-McClure, etc.)
1471/// - **Shape parameter α**: Controls the robustness level
1472/// - **Scale parameter c**: Controls the transition point
1473/// - **Unified framework**: Single loss function family
1474///
1475/// # Parameter Selection
1476///
1477/// **Default (α = 0.0, c = 1.0):**
1478/// - Cauchy-like behavior
1479/// - Good general-purpose robust loss
1480/// - Suitable for moderate to heavy outliers
1481///
1482/// **Other values:**
1483/// - α = 2.0: L2 loss (no robustness)
1484/// - α = 1.0: Pseudo-Huber/Charbonnier-like
1485/// - α = -2.0: Geman-McClure-like (very robust)
1486///
1487/// # Note on Adaptivity
1488///
1489/// This simplified version uses fixed parameters. The full adaptive version from
1490/// Chebrolu et al. requires iterative estimation of α based on residual distribution,
1491/// which would require integration into the optimizer's main loop.
1492///
1493/// # References
1494///
1495/// Chebrolu, N., Läbe, T., Vysotska, O., Behley, J., & Stachniss, C. (2021).
1496/// Adaptive robust kernels for non-linear least squares problems.
1497/// IEEE Robotics and Automation Letters, 6(2), 2240-2247.
1498///
1499/// Barron, J. T. (2019). A general and adaptive robust loss function.
1500/// IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).
1501///
1502/// # Example
1503///
1504/// ```
1505/// use apex_solver::core::loss_functions::{LossFunction, AdaptiveBarronLoss};
1506/// # use apex_solver::error::ApexSolverResult;
1507/// # fn example() -> ApexSolverResult<()> {
1508///
1509/// // Default Cauchy-like behavior
1510/// let adaptive = AdaptiveBarronLoss::new(0.0, 1.0)?;
1511///
1512/// let [rho, rho_prime, _] = adaptive.evaluate(4.0);
1513/// // Adaptive robust behavior
1514/// # Ok(())
1515/// # }
1516/// # example().unwrap();
1517/// ```
1518#[derive(Debug, Clone)]
1519pub struct AdaptiveBarronLoss {
1520    inner: BarronGeneralLoss,
1521}
1522
1523impl AdaptiveBarronLoss {
1524    /// Create a new adaptive Barron loss function.
1525    ///
1526    /// # Arguments
1527    ///
1528    /// * `alpha` - Shape parameter (default: 0.0 for Cauchy-like)
1529    /// * `scale` - Scale parameter c (must be positive)
1530    ///
1531    /// # Recommended Defaults
1532    ///
1533    /// - α = 0.0, c = 1.0: General-purpose robust loss
1534    pub fn new(alpha: f64, scale: f64) -> ApexSolverResult<Self> {
1535        Ok(AdaptiveBarronLoss {
1536            inner: BarronGeneralLoss::new(alpha, scale)?,
1537        })
1538    }
1539
1540    /// Create default instance without validation (alpha=0.0, scale=1.0).
1541    ///
1542    /// This is safe because the default parameters are mathematically valid.
1543    const fn new_default() -> Self {
1544        AdaptiveBarronLoss {
1545            inner: BarronGeneralLoss {
1546                alpha: 0.0,
1547                scale: 1.0,
1548                scale2: 1.0,
1549            },
1550        }
1551    }
1552}
1553
1554impl LossFunction for AdaptiveBarronLoss {
1555    fn evaluate(&self, s: f64) -> [f64; 3] {
1556        self.inner.evaluate(s)
1557    }
1558}
1559
1560impl Default for AdaptiveBarronLoss {
1561    /// Creates default AdaptiveBarronLoss with validated parameters (alpha=0.0, scale=1.0).
1562    fn default() -> Self {
1563        Self::new_default()
1564    }
1565}
1566
1567#[cfg(test)]
1568mod tests {
1569    use super::*;
1570
1571    type TestResult = Result<(), Box<dyn std::error::Error>>;
1572
1573    const EPSILON: f64 = 1e-6;
1574
1575    /// Helper function to test derivatives numerically
1576    fn numerical_derivative(loss: &dyn LossFunction, s: f64, h: f64) -> (f64, f64) {
1577        let [rho_plus, _, _] = loss.evaluate(s + h);
1578        let [rho_minus, _, _] = loss.evaluate(s - h);
1579        let [rho, _, _] = loss.evaluate(s);
1580
1581        // First derivative: f'(x) ≈ (f(x+h) - f(x-h)) / (2h)
1582        let rho_prime_numerical = (rho_plus - rho_minus) / (2.0 * h);
1583
1584        // Second derivative: f''(x) ≈ (f(x+h) - 2f(x) + f(x-h)) / h²
1585        let rho_double_prime_numerical = (rho_plus - 2.0 * rho + rho_minus) / (h * h);
1586
1587        (rho_prime_numerical, rho_double_prime_numerical)
1588    }
1589
1590    #[test]
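    // Supplementary check added alongside the existing tests: HuberLoss's analytic ρ'(s) and
    // ρ''(s) in the outlier region should agree with numerical differentiation of ρ(s) using
    // the helper above.
    #[test]
    fn test_huber_derivative_consistency() -> TestResult {
        let loss = HuberLoss::new(1.345)?;

        let s = 10.0; // well past δ² ≈ 1.81, i.e. in the linear (outlier) branch
        let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
        let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);

        assert!((rho_prime - rho_prime_num).abs() < 1e-4);
        assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);

        Ok(())
    }
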
1591    fn test_l2_loss() -> TestResult {
1592        let loss = L2Loss;
1593
1594        // Test at s = 0
1595        let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
1596        assert_eq!(rho, 0.0);
1597        assert_eq!(rho_prime, 1.0);
1598        assert_eq!(rho_double_prime, 0.0);
1599
1600        // Test at s = 4.0
1601        let [rho, rho_prime, rho_double_prime] = loss.evaluate(4.0);
1602        assert_eq!(rho, 4.0);
1603        assert_eq!(rho_prime, 1.0);
1604        assert_eq!(rho_double_prime, 0.0);
1605
1606        Ok(())
1607    }
1608
1609    #[test]
1610    fn test_l1_loss() -> TestResult {
1611        let loss = L1Loss;
1612
1613        // Test at s = 0 (should handle gracefully)
1614        let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
1615        assert_eq!(rho, 0.0);
1616        assert!(rho_prime.is_finite());
1617        assert!(rho_double_prime.is_finite());
1618
1619        // Test at s = 4.0 (√s = 2.0, ρ(s) = 2√s = 4.0)
1620        let [rho, rho_prime, _] = loss.evaluate(4.0);
1621        assert!((rho - 4.0).abs() < EPSILON); // ρ(s) = 2√s = 2*2 = 4
1622        assert!((rho_prime - 0.5).abs() < EPSILON); // ρ'(s) = 1/√s = 1/2 = 0.5
1623
1624        Ok(())
1625    }
1626
1627    #[test]
1628    fn test_fair_loss() -> TestResult {
1629        let loss = FairLoss::new(1.3999)?;
1630
1631        // Test at s = 0 (special case handling)
1632        let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
1633        assert_eq!(rho, 0.0);
1634        assert_eq!(rho_prime, 1.0);
1635        assert_eq!(rho_double_prime, 0.0);
1636
1637        // Test inlier region (s = 1.0)
1638        let [_, rho_prime, _] = loss.evaluate(1.0);
1639        assert!(rho_prime > 0.2 && rho_prime < 0.25); // ρ'(s) = 1/(2(c+|x|)) where x=1, c≈1.4 → ~0.208
1640
1641        // Test outlier region (s = 100.0)
1642        let [_, rho_prime_outlier, _] = loss.evaluate(100.0);
1643        assert!(rho_prime_outlier < rho_prime); // Weight should decrease
1644
1645        // Test that derivatives are finite and reasonable
1646        let [_, rho_prime_4, rho_double_prime_4] = loss.evaluate(4.0);
1647        assert!(rho_prime_4.is_finite() && rho_prime_4 > 0.0);
1648        assert!(rho_double_prime_4.is_finite() && rho_double_prime_4 < 0.0); // Convex near origin
1649
1650        Ok(())
1651    }
1652
1653    #[test]
1654    fn test_geman_mcclure_loss() -> TestResult {
1655        let loss = GemanMcClureLoss::new(1.0)?;
1656
1657        // Test at s = 0
1658        let [rho, rho_prime, _] = loss.evaluate(0.0);
1659        assert_eq!(rho, 0.0);
1660        assert!((rho_prime - 1.0).abs() < EPSILON);
1661
1662        // Test outlier suppression
1663        let [_, rho_prime_small, _] = loss.evaluate(1.0);
1664        let [_, rho_prime_large, _] = loss.evaluate(100.0);
1665        assert!(rho_prime_large < rho_prime_small);
1666        assert!(rho_prime_large < 0.1); // Strong suppression
1667
1668        // Verify derivatives
1669        let s = 2.0;
1670        let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1671        let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1672        assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1673        assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1674
1675        Ok(())
1676    }
1677
1678    #[test]
1679    fn test_welsch_loss() -> TestResult {
1680        let loss = WelschLoss::new(2.9846)?;
1681
1682        // Test at s = 0: ρ'(0) = 0.5 * exp(0) = 0.5
1683        let [rho, rho_prime, _] = loss.evaluate(0.0);
1684        assert_eq!(rho, 0.0);
1685        assert!((rho_prime - 0.5).abs() < EPSILON); // ρ'(s) = 0.5 * exp(-s/c²), at s=0: 0.5
1686
1687        // Test redescending behavior
1688        let [_, rho_prime_10, _] = loss.evaluate(10.0);
1689        let [_, rho_prime_100, _] = loss.evaluate(100.0);
1690        assert!(rho_prime_100 < rho_prime_10);
1691        assert!(rho_prime_100 < 0.01); // Nearly zero for large outliers
1692
1693        // Verify derivatives
1694        let s = 5.0;
1695        let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1696        let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1697        assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1698        assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1699
1700        Ok(())
1701    }
1702
1703    #[test]
1704    fn test_tukey_biweight_loss() -> TestResult {
1705        let loss = TukeyBiweightLoss::new(4.6851)?;
1706
1707        // Test at s = 0: ρ'(0) = 0.5 * (1-0)^2 = 0.5
1708        let [rho, rho_prime, _] = loss.evaluate(0.0);
1709        assert_eq!(rho, 0.0);
1710        assert!((rho_prime - 0.5).abs() < EPSILON); // ρ'(s) = 0.5 * (1 - ratio²)², at s=0: 0.5
1711
1712        // Test within threshold
1713        let scale2 = 4.6851 * 4.6851;
1714        let [_, rho_prime_in, _] = loss.evaluate(scale2 * 0.5);
1715        assert!(rho_prime_in > 0.05);
1716
1717        // Test beyond threshold (complete suppression)
1718        let [_, rho_prime_out, _] = loss.evaluate(scale2 * 1.5);
1719        assert_eq!(rho_prime_out, 0.0);
1720
1721        // Test that derivatives are finite and reasonable
1722        let [_, rho_prime_5, rho_double_prime_5] = loss.evaluate(5.0);
1723        assert!(rho_prime_5.is_finite() && rho_prime_5 > 0.0);
1724        assert!(rho_double_prime_5.is_finite() && rho_double_prime_5 < 0.0);
1725
1726        Ok(())
1727    }
1728
1729    #[test]
1730    fn test_andrews_wave_loss() -> TestResult {
1731        let loss = AndrewsWaveLoss::new(1.339)?;
1732
1733        // Test at s = 0: ρ'(0) = 0.5 * sin(0) = 0
1734        let [rho, rho_prime, _] = loss.evaluate(0.0);
1735        assert_eq!(rho, 0.0);
1736        assert!(rho_prime.abs() < EPSILON); // ρ'(s) = 0.5 * sin(√s/c), at s = 0: 0
1737
1738        // Test within threshold (small s where sin(√s/c) gives moderate weight)
1739        let [_, rho_prime_in, _] = loss.evaluate(1.0);
1740        assert!(rho_prime_in > 0.33 && rho_prime_in < 0.35); // ~0.3397
1741
1742        // Test beyond threshold
1743        let scale = 1.339;
1744        let [_, rho_prime_out, _] = loss.evaluate((scale * std::f64::consts::PI + 0.1).powi(2));
1745        assert!(rho_prime_out.abs() < 0.01);
1746
1747        // Test that derivatives are finite
1748        let [_, rho_prime_1, rho_double_prime_1] = loss.evaluate(1.0);
1749        assert!(rho_prime_1.is_finite() && rho_prime_1 > 0.0);
1750        assert!(rho_double_prime_1.is_finite());
1751
1752        Ok(())
1753    }
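
    // A minimal sketch spot-checking the ρ'(s) = 0.5 * (1 - s/c²)² formula quoted in the Tukey
    // test above at a few interior points. Assumes that formula holds throughout the inlier
    // region (s < c²), not just at s = 0.
    #[test]
    fn test_tukey_biweight_loss_weight_formula() -> TestResult {
        let scale = 4.6851;
        let loss = TukeyBiweightLoss::new(scale)?;
        let scale2 = scale * scale;

        for &s in &[1.0, 5.0, 15.0] {
            let [_, rho_prime, _] = loss.evaluate(s);
            let expected = 0.5 * (1.0 - s / scale2).powi(2);
            assert!((rho_prime - expected).abs() < 1e-6);
        }

        Ok(())
    }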
1754
1755    #[test]
1756    fn test_ramsay_ea_loss() -> TestResult {
1757        let loss = RamsayEaLoss::new(0.3)?;
1758
1759        // Test at s = 0 (should handle gracefully)
1760        let [rho, _, _] = loss.evaluate(0.0);
1761        assert_eq!(rho, 0.0);
1762
1763        // Test exponential decay behavior
1764        let [_, rho_prime_small, _] = loss.evaluate(1.0);
1765        let [_, rho_prime_large, _] = loss.evaluate(100.0);
1766        assert!(rho_prime_large < rho_prime_small);
1767
1768        // Verify derivatives
1769        let s = 4.0;
1770        let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1771        let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1772        assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1773        assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1774
1775        Ok(())
1776    }
1777
1778    #[test]
1779    fn test_trimmed_mean_loss() -> TestResult {
1780        let loss = TrimmedMeanLoss::new(2.0)?;
1781        let scale2 = 4.0;
1782
1783        // Test below threshold (quadratic region: ρ(s) = s/2)
1784        let [rho, rho_prime, rho_double_prime] = loss.evaluate(2.0);
1785        assert!((rho - 1.0).abs() < EPSILON);
1786        assert!((rho_prime - 0.5).abs() < EPSILON);
1787        assert_eq!(rho_double_prime, 0.0);
1788
1789        // Test above threshold (cost capped at the constant c²/2)
1790        let [rho_out, rho_prime_out, rho_double_prime_out] = loss.evaluate(10.0);
1791        assert!((rho_out - scale2 / 2.0).abs() < EPSILON);
1792        assert_eq!(rho_prime_out, 0.0);
1793        assert_eq!(rho_double_prime_out, 0.0);
1794
1795        Ok(())
1796    }
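
    // A minimal sketch probing continuity at the trimming threshold: the assertions above imply
    // ρ(s) = s/2 below s = c² and ρ(s) = c²/2 above it, so the two branches should meet at
    // s = c². Assumes only the behavior already asserted in the test above.
    #[test]
    fn test_trimmed_mean_loss_continuity_at_threshold() -> TestResult {
        let loss = TrimmedMeanLoss::new(2.0)?;
        let scale2 = 4.0;

        let [rho_below, _, _] = loss.evaluate(scale2 - 1e-6);
        let [rho_above, _, _] = loss.evaluate(scale2 + 1e-6);

        // Both branches evaluate to c²/2 at the threshold, up to the probe step size.
        assert!((rho_below - scale2 / 2.0).abs() < 1e-5);
        assert!((rho_above - scale2 / 2.0).abs() < 1e-5);
        assert!((rho_below - rho_above).abs() < 1e-5);

        Ok(())
    }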
1797
1798    #[test]
1799    fn test_lp_norm_loss() -> TestResult {
1800        // Test L1 (p = 1)
1801        let l1 = LpNormLoss::new(1.0)?;
1802        let [rho_l1, _, _] = l1.evaluate(4.0);
1803        assert!((rho_l1 - 2.0).abs() < EPSILON); // ||r||₁ = 2
1804
1805        // Test L2 (p = 2)
1806        let l2 = LpNormLoss::new(2.0)?;
1807        let [rho_l2, rho_prime_l2, rho_double_prime_l2] = l2.evaluate(4.0);
1808        assert!((rho_l2 - 4.0).abs() < EPSILON);
1809        assert!((rho_prime_l2 - 1.0).abs() < EPSILON);
1810        assert_eq!(rho_double_prime_l2, 0.0);
1811
1812        // Test fractional p (p = 0.5)
1813        let l05 = LpNormLoss::new(0.5)?;
1814        let [_, rho_prime_05, _] = l05.evaluate(4.0);
1815        assert!(rho_prime_05 < 1.0); // Robust behavior
1816
1817        // Verify derivatives for p = 1.5
1818        let loss = LpNormLoss::new(1.5)?;
1819        let s = 4.0;
1820        let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1821        let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1822        assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1823        assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1824
1825        Ok(())
1826    }
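
    // A minimal sketch of the qualitative effect of p: the values checked above are consistent
    // with ρ(s) = s^(p/2), so at a fixed residual the weight ρ'(s) should shrink as p decreases.
    // This assumes that form of the implementation, hence only order comparisons are made.
    #[test]
    fn test_lp_norm_loss_weight_decreases_with_p() -> TestResult {
        let s = 4.0;

        let mut prev = f64::INFINITY;
        for &p in &[2.0, 1.5, 1.0, 0.5] {
            let [_, rho_prime, _] = LpNormLoss::new(p)?.evaluate(s);
            assert!(rho_prime < prev); // Smaller p → stronger downweighting at this residual
            prev = rho_prime;
        }

        Ok(())
    }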
1827
1828    #[test]
1829    fn test_barron_general_loss_special_cases() -> TestResult {
1830        // α = 0 (Cauchy-like)
1831        let cauchy = BarronGeneralLoss::new(0.0, 1.0)?;
1832        let [_, rho_prime_small, _] = cauchy.evaluate(1.0);
1833        let [_, rho_prime_large, _] = cauchy.evaluate(100.0);
1834        assert!(rho_prime_large < rho_prime_small);
1835
1836        // α = 2 (L2)
1837        let l2 = BarronGeneralLoss::new(2.0, 1.0)?;
1838        let [rho, rho_prime, rho_double_prime] = l2.evaluate(4.0);
1839        assert!((rho - 4.0).abs() < EPSILON);
1840        assert!((rho_prime - 1.0).abs() < EPSILON);
1841        assert!(rho_double_prime.abs() < EPSILON);
1842
1843        // α = 1 (Charbonnier-like)
1844        let charbonnier = BarronGeneralLoss::new(1.0, 1.0)?;
1845        let [_, rho_prime_char, _] = charbonnier.evaluate(4.0);
1846        assert!(rho_prime_char > 0.0 && rho_prime_char < 1.0);
1847
1848        // Test α = -2 (Geman-McClure-like) - strong outlier suppression
1849        let gm = BarronGeneralLoss::new(-2.0, 1.0)?;
1850        let [_, rho_prime_small, _] = gm.evaluate(1.0);
1851        let [_, rho_prime_large, _] = gm.evaluate(100.0);
1852        assert!(rho_prime_large < rho_prime_small); // Redescending behavior
1853        assert!(rho_prime_large < 0.1); // Strong suppression
1854
1855        Ok(())
1856    }
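
    // A minimal sketch of the α axis: in Barron's general loss, lowering α at a fixed scale
    // strengthens outlier suppression. Assuming the implementation follows the standard
    // formulation (CVPR 2019), only the ordering of weights at a large residual is asserted.
    #[test]
    fn test_barron_general_loss_alpha_ordering() -> TestResult {
        let s_outlier = 100.0;

        let [_, w_l2, _] = BarronGeneralLoss::new(2.0, 1.0)?.evaluate(s_outlier);
        let [_, w_charbonnier, _] = BarronGeneralLoss::new(1.0, 1.0)?.evaluate(s_outlier);
        let [_, w_cauchy, _] = BarronGeneralLoss::new(0.0, 1.0)?.evaluate(s_outlier);
        let [_, w_gm, _] = BarronGeneralLoss::new(-2.0, 1.0)?.evaluate(s_outlier);

        // Weight at the outlier should decrease as α decreases: 2 > 1 > 0 > -2.
        assert!(w_l2 > w_charbonnier);
        assert!(w_charbonnier > w_cauchy);
        assert!(w_cauchy > w_gm);

        Ok(())
    }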
1857
1858    #[test]
1859    fn test_constructor_validation() -> TestResult {
1860        // Test that negative or zero scale parameters are rejected
1861        assert!(FairLoss::new(0.0).is_err());
1862        assert!(FairLoss::new(-1.0).is_err());
1863        assert!(GemanMcClureLoss::new(0.0).is_err());
1864        assert!(WelschLoss::new(-1.0).is_err());
1865        assert!(TukeyBiweightLoss::new(0.0).is_err());
1866        assert!(AndrewsWaveLoss::new(-1.0).is_err());
1867        assert!(RamsayEaLoss::new(0.0).is_err());
1868        assert!(TrimmedMeanLoss::new(-1.0).is_err());
1869        assert!(BarronGeneralLoss::new(0.0, 0.0).is_err());
1870        assert!(BarronGeneralLoss::new(1.0, -1.0).is_err());
1871
1872        // Test that p ≤ 0 is rejected for LpNormLoss
1873        assert!(LpNormLoss::new(0.0).is_err());
1874        assert!(LpNormLoss::new(-1.0).is_err());
1875
1876        // Test valid constructors
1877        assert!(FairLoss::new(1.0).is_ok());
1878        assert!(LpNormLoss::new(1.5).is_ok());
1879        assert!(BarronGeneralLoss::new(1.0, 1.0).is_ok());
1880
1881        Ok(())
1882    }
1883
1884    #[test]
1885    fn test_loss_comparison() -> TestResult {
1886        // Compare robustness: L2 vs Huber vs Cauchy at outlier
1887        let s_outlier = 100.0;
1888
1889        let l2 = L2Loss;
1890        let huber = HuberLoss::new(1.345)?;
1891        let cauchy = CauchyLoss::new(2.3849)?;
1892
1893        let [_, w_l2, _] = l2.evaluate(s_outlier);
1894        let [_, w_huber, _] = huber.evaluate(s_outlier);
1895        let [_, w_cauchy, _] = cauchy.evaluate(s_outlier);
1896
1897        // L2 should give highest weight (no robustness)
1898        assert!(w_l2 > w_huber);
1899        assert!(w_huber > w_cauchy);
1900
1901        // Cauchy should strongly suppress outliers
1902        assert!(w_cauchy < 0.1);
1903
1904        Ok(())
1905    }
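
    // A minimal sketch contrasting soft and hard outlier rejection: Cauchy keeps a small but
    // positive weight on any residual, while Tukey (per the test above) assigns exactly zero
    // weight beyond its cutoff of c² ≈ 21.95.
    #[test]
    fn test_loss_comparison_redescending() -> TestResult {
        let s_outlier = 100.0; // Well beyond Tukey's cutoff

        let cauchy = CauchyLoss::new(2.3849)?;
        let tukey = TukeyBiweightLoss::new(4.6851)?;

        let [_, w_cauchy, _] = cauchy.evaluate(s_outlier);
        let [_, w_tukey, _] = tukey.evaluate(s_outlier);

        assert!(w_cauchy > 0.0); // Cauchy downweights but never fully discards
        assert_eq!(w_tukey, 0.0); // Tukey discards residuals beyond the cutoff entirely
        assert!(w_tukey < w_cauchy);

        Ok(())
    }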
1906
1907    #[test]
1908    fn test_t_distribution_loss() -> TestResult {
1909        let loss = TDistributionLoss::new(5.0)?;
1910
1911        // Test at s = 0 (should be well-defined)
1912        let [rho, rho_prime, _] = loss.evaluate(0.0);
1913        assert_eq!(rho, 0.0);
1914        assert!((rho_prime - 0.6).abs() < 0.01); // (ν+1)/(2ν) = 6/10 = 0.6
1915
1916        // Test heavy tail behavior (downweighting)
1917        let [_, rho_prime_small, _] = loss.evaluate(1.0);
1918        let [_, rho_prime_large, _] = loss.evaluate(100.0);
1919        assert!(rho_prime_large < rho_prime_small);
1920
1921        // Test that large outliers are strongly downweighted
1922        assert!(rho_prime_large < 0.1);
1923
1924        // Verify derivatives numerically
1925        let s = 4.0;
1926        let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1927        let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1928        assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1929        assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-4);
1930
1931        Ok(())
1932    }
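
    // A minimal sketch extending the ρ'(0) = (ν+1)/(2ν) relation quoted above to other
    // degrees-of-freedom values, assuming the same formula holds for all ν > 0.
    #[test]
    fn test_t_distribution_loss_weight_at_zero() -> TestResult {
        for &nu in &[3.0, 10.0, 30.0] {
            let loss = TDistributionLoss::new(nu)?;
            let [_, rho_prime, _] = loss.evaluate(0.0);
            assert!((rho_prime - (nu + 1.0) / (2.0 * nu)).abs() < 0.01);
        }

        Ok(())
    }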
1933
1934    #[test]
1935    fn test_t_distribution_loss_different_nu() -> TestResult {
1936        // Test that smaller ν is more robust
1937        let t3 = TDistributionLoss::new(3.0)?;
1938        let t10 = TDistributionLoss::new(10.0)?;
1939
1940        let s_outlier = 100.0;
1941        let [_, w_t3, _] = t3.evaluate(s_outlier);
1942        let [_, w_t10, _] = t10.evaluate(s_outlier);
1943
1944        // Smaller ν should downweight more aggressively
1945        assert!(w_t3 < w_t10);
1946
1947        Ok(())
1948    }
1949
1950    #[test]
1951    fn test_adaptive_barron_loss() -> TestResult {
1952        // Cauchy-like configuration (α = 0)
1953        let adaptive = AdaptiveBarronLoss::new(0.0, 1.0)?;
1954
1955        // Test at s = 0
1956        let [rho, _, _] = adaptive.evaluate(0.0);
1957        assert!(rho.abs() < EPSILON);
1958
1959        // Test robustness
1960        let [_, rho_prime_small, _] = adaptive.evaluate(1.0);
1961        let [_, rho_prime_large, _] = adaptive.evaluate(100.0);
1962        assert!(rho_prime_large < rho_prime_small);
1963
1964        // AdaptiveBarron wraps BarronGeneral which is already tested,
1965        // so we just verify the wrapper works correctly
1966        let barron = BarronGeneralLoss::new(0.0, 1.0)?;
1967        let [rho_a, rho_prime_a, rho_double_prime_a] = adaptive.evaluate(4.0);
1968        let [rho_b, rho_prime_b, rho_double_prime_b] = barron.evaluate(4.0);
1969
1970        // Should match the underlying BarronGeneral exactly
1971        assert!((rho_a - rho_b).abs() < EPSILON);
1972        assert!((rho_prime_a - rho_prime_b).abs() < EPSILON);
1973        assert!((rho_double_prime_a - rho_double_prime_b).abs() < EPSILON);
1974
1975        Ok(())
1976    }
1977
1978    #[test]
1979    fn test_adaptive_barron_default() -> TestResult {
1980        // Test default constructor
1981        let adaptive = AdaptiveBarronLoss::default();
1982
1983        // Should behave like Cauchy
1984        let [_, rho_prime, _] = adaptive.evaluate(4.0);
1985        assert!(rho_prime > 0.0 && rho_prime < 1.0);
1986
1987        Ok(())
1988    }
1989
1990    #[test]
1991    fn test_new_loss_constructor_validation() -> TestResult {
1992        // T-distribution: reject non-positive degrees of freedom
1993        assert!(TDistributionLoss::new(0.0).is_err());
1994        assert!(TDistributionLoss::new(-1.0).is_err());
1995        assert!(TDistributionLoss::new(5.0).is_ok());
1996
1997        // Adaptive Barron: reject non-positive scale
1998        assert!(AdaptiveBarronLoss::new(0.0, 0.0).is_err());
1999        assert!(AdaptiveBarronLoss::new(1.0, -1.0).is_err());
2000        assert!(AdaptiveBarronLoss::new(0.0, 1.0).is_ok());
2001
2002        Ok(())
2003    }
2004}