apex_solver/core/loss_functions.rs
1//! Robust loss functions for outlier rejection in nonlinear least squares optimization.
2//!
3//! Loss functions (also called robust cost functions or M-estimators) reduce the influence of
4//! outlier measurements on the optimization result. In standard least squares, the cost is
5//! the squared norm of residuals: `cost = Σ ||r_i||²`. With a robust loss function ρ(s), the
6//! cost becomes: `cost = Σ ρ(||r_i||²)`.
7//!
8//! # Mathematical Formulation
9//!
//! Each loss function implements the [`LossFunction`] trait, which evaluates:
11//! - **ρ(s)**: The robust cost value
12//! - **ρ'(s)**: First derivative (weight function)
13//! - **ρ''(s)**: Second derivative (for corrector algorithm)
14//!
15//! The input `s = ||r||²` is the squared norm of the residual vector.
16//!
17//! # Usage in Optimization
18//!
19//! Loss functions are applied via the `Corrector` algorithm (see `corrector.rs`), which
20//! modifies the residuals and Jacobians to account for the robust weighting. The optimization
21//! then proceeds as if solving a reweighted least squares problem.
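//!
//! As a rough illustration of the reweighting idea (a minimal sketch only; the actual
//! `Corrector` also uses ρ''(s) for a second-order correction of residuals and Jacobians),
//! ρ'(s) can be treated as an IRLS weight and applied to a residual as `√ρ'(s) · r`:
//!
//! ```
//! use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
//! # use apex_solver::error::ApexSolverResult;
//! # fn example() -> ApexSolverResult<()> {
//! let huber = HuberLoss::new(1.345)?;
//! let residual = [3.0_f64, 4.0]; // ||r||² = 25, well into the outlier region
//! let s: f64 = residual.iter().map(|r| r * r).sum();
//! let [_, rho_prime, _] = huber.evaluate(s);
//! // Scale the residual so that its squared norm contributes ρ'(s) · s to the cost.
//! let reweighted: Vec<f64> = residual.iter().map(|r| rho_prime.sqrt() * r).collect();
//! assert!(reweighted.iter().map(|r| r * r).sum::<f64>() < s);
//! # Ok(())
//! # }
//! # example().unwrap();
//! ```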
22//!
23//! # Available Loss Functions
24//!
25//! ## Basic Loss Functions
26//! - [`L2Loss`]: Standard least squares (no robustness)
27//! - [`L1Loss`]: Absolute error (simple robust baseline)
28//!
29//! ## Moderate Robustness
30//! - [`HuberLoss`]: Quadratic for inliers, linear for outliers (recommended for general use)
31//! - [`FairLoss`]: Smooth transition with continuous derivatives
32//! - [`CauchyLoss`]: Heavier suppression of large residuals
33//!
34//! ## Strong Robustness
35//! - [`GemanMcClureLoss`]: Very strong outlier rejection
36//! - [`WelschLoss`]: Exponential downweighting
37//! - [`TukeyBiweightLoss`]: Complete outlier suppression (redescending)
38//!
39//! ## Specialized Functions
40//! - [`AndrewsWaveLoss`]: Sine-based redescending M-estimator
41//! - [`RamsayEaLoss`]: Exponential decay weighting
42//! - [`TrimmedMeanLoss`]: Hard threshold cutoff
43//! - [`LpNormLoss`]: Generalized Lp norm (flexible p parameter)
44//!
45//! ## Modern Adaptive
//! - [`BarronGeneralLoss`]: Unified framework encompassing many loss functions (CVPR 2019)
//! - [`TDistributionLoss`]: Student's t-distribution loss (heavy-tailed robustness)
//! - [`AdaptiveBarronLoss`]: Convenience wrapper around [`BarronGeneralLoss`] with adaptive-friendly defaults
47//!
48//! # Loss Function Selection Guide
49//!
50//! | Use Case | Recommended Loss | Tuning Constant |
51//! |----------|-----------------|-----------------|
52//! | Clean data, no outliers | `L2Loss` | N/A |
53//! | Few outliers (<5%) | `HuberLoss` | c = 1.345 |
54//! | Moderate outliers (5-10%) | `FairLoss` or `CauchyLoss` | c = 1.3998 / 2.3849 |
55//! | Many outliers (>10%) | `WelschLoss` or `TukeyBiweightLoss` | c = 2.9846 / 4.6851 |
56//! | Severe outliers | `GemanMcClureLoss` | c = 1.0-2.0 |
57//! | Adaptive/unknown | `BarronGeneralLoss` | α adaptive |
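//!
//! The entries above can be compared directly through the common [`LossFunction`] trait; the
//! snippet below (with illustrative scale choices) shows how the recommended losses assign
//! progressively smaller weights ρ'(s) to the same outlying residual:
//!
//! ```
//! use apex_solver::core::loss_functions::{LossFunction, L2Loss, HuberLoss, CauchyLoss, TukeyBiweightLoss};
//! # use apex_solver::error::ApexSolverResult;
//! # fn example() -> ApexSolverResult<()> {
//! let losses: Vec<Box<dyn LossFunction>> = vec![
//!     Box::new(L2Loss::new()),
//!     Box::new(HuberLoss::new(1.345)?),
//!     Box::new(CauchyLoss::new(2.3849)?),
//!     Box::new(TukeyBiweightLoss::new(4.6851)?),
//! ];
//! let s_outlier = 100.0;
//! let weights: Vec<f64> = losses.iter().map(|l| l.evaluate(s_outlier)[1]).collect();
//! // L2 keeps full weight, Huber and Cauchy downweight, Tukey suppresses completely.
//! assert!(weights.windows(2).all(|w| w[1] <= w[0]));
//! assert_eq!(*weights.last().unwrap(), 0.0);
//! # Ok(())
//! # }
//! # example().unwrap();
//! ```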
58//!
59//! # Example
60//!
61//! ```
62//! use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
63//! # use apex_solver::error::ApexSolverResult;
64//! # fn example() -> ApexSolverResult<()> {
65//!
66//! let huber = HuberLoss::new(1.345)?;
67//!
68//! // Evaluate for an inlier (small residual)
69//! let s_inlier = 0.5;
70//! let [rho, rho_prime, rho_double_prime] = huber.evaluate(s_inlier);
71//! assert_eq!(rho, s_inlier); // Quadratic cost in inlier region
72//! assert_eq!(rho_prime, 1.0); // Full weight
73//!
74//! // Evaluate for an outlier (large residual)
75//! let s_outlier = 10.0;
76//! let [rho, rho_prime, rho_double_prime] = huber.evaluate(s_outlier);
77//! // rho grows linearly instead of quadratically
78//! // rho_prime < 1.0, downweighting the outlier
79//! # Ok(())
80//! # }
81//! # example().unwrap();
82//! ```
83
84use crate::core::CoreError;
85use crate::error::ApexSolverResult;
86
87/// Trait for robust loss functions used in nonlinear least squares optimization.
88///
89/// A loss function transforms the squared residual `s = ||r||²` into a robust cost `ρ(s)`
90/// that reduces the influence of outliers. The trait provides the cost value and its first
91/// two derivatives, which are used by the `Corrector` to modify the optimization problem.
92///
93/// # Returns
94///
95/// The `evaluate` method returns a 3-element array: `[ρ(s), ρ'(s), ρ''(s)]`
96/// - `ρ(s)`: Robust cost value
97/// - `ρ'(s)`: First derivative (weight function)
98/// - `ρ''(s)`: Second derivative
99///
100/// # Implementation Notes
101///
102/// - Loss functions should be smooth (at least C²) for optimization stability
103/// - Typically ρ(0) = 0, ρ'(0) = 1, ρ''(0) = 0 (behaves like standard least squares near zero)
104/// - For outliers, ρ'(s) should decrease to downweight large residuals
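///
/// # Example
///
/// A minimal sketch of a custom implementation (a trivial scaled-L2 loss used purely for
/// illustration; it is not one of the losses provided by this module):
///
/// ```
/// use apex_solver::core::loss_functions::LossFunction;
///
/// /// Hypothetical loss that simply rescales the squared residual.
/// struct ScaledL2 {
///     weight: f64,
/// }
///
/// impl LossFunction for ScaledL2 {
///     fn evaluate(&self, s: f64) -> [f64; 3] {
///         // ρ(s) = w·s, ρ'(s) = w, ρ''(s) = 0
///         [self.weight * s, self.weight, 0.0]
///     }
/// }
///
/// let loss = ScaledL2 { weight: 0.5 };
/// assert_eq!(loss.evaluate(4.0), [2.0, 0.5, 0.0]);
/// ```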
105pub trait LossFunction: Send + Sync {
106 /// Evaluate the loss function and its first two derivatives at squared residual `s`.
107 ///
108 /// # Arguments
109 ///
110 /// * `s` - The squared norm of the residual: `s = ||r||²` (always non-negative)
111 ///
112 /// # Returns
113 ///
114 /// Array `[ρ(s), ρ'(s), ρ''(s)]` containing the cost, first derivative, and second derivative
115 fn evaluate(&self, s: f64) -> [f64; 3];
116}
117
118/// L2 loss function (standard least squares, no robustness).
119///
120/// The L2 loss is the standard squared error used in ordinary least squares optimization.
121/// It provides no outlier robustness and is optimal only when residuals follow a Gaussian
122/// distribution.
123///
124/// # Mathematical Definition
125///
126/// ```text
127/// ρ(s) = s
128/// ρ'(s) = 1
129/// ρ''(s) = 0
130/// ```
131///
132/// where `s = ||r||²` is the squared residual norm.
133///
134/// # Properties
135///
136/// - **Convex**: Globally optimal solution
137/// - **Not robust**: Outliers have full influence (squared!)
138/// - **Optimal**: For Gaussian noise without outliers
139/// - **Fast**: Simplest to compute
140///
141/// # Use Cases
142///
143/// - Clean data with known Gaussian noise
144/// - Baseline comparison for robust methods
145/// - When outliers are already filtered
146///
147/// # Example
148///
149/// ```
150/// use apex_solver::core::loss_functions::{LossFunction, L2Loss};
151///
152/// let l2 = L2Loss::new();
153///
154/// let [rho, rho_prime, rho_double_prime] = l2.evaluate(4.0);
155/// assert_eq!(rho, 4.0); // ρ(s) = s
156/// assert_eq!(rho_prime, 1.0); // Full weight
157/// assert_eq!(rho_double_prime, 0.0);
158/// ```
159#[derive(Debug, Clone, Copy)]
160pub struct L2Loss;
161
162impl L2Loss {
163 /// Create a new L2 loss function (no parameters needed).
164 pub fn new() -> Self {
165 L2Loss
166 }
167}
168
169impl Default for L2Loss {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175impl LossFunction for L2Loss {
176 fn evaluate(&self, s: f64) -> [f64; 3] {
177 [s, 1.0, 0.0]
178 }
179}
180
181/// L1 loss function (absolute error, simple robust baseline).
182///
183/// The L1 loss uses absolute error instead of squared error, providing basic robustness
184/// to outliers. It is optimal for Laplacian noise distributions.
185///
186/// # Mathematical Definition
187///
188/// ```text
189/// ρ(s) = 2√s
190/// ρ'(s) = 1/√s
191/// ρ''(s) = -1/(2s^(3/2))
192/// ```
193///
194/// where `s = ||r||²` is the squared residual norm.
195///
196/// # Properties
197///
198/// - **Convex**: Globally optimal solution
199/// - **Moderately robust**: Linear growth vs quadratic
200/// - **Unstable at zero**: Derivative undefined at s=0
/// - **Median estimator**: The minimizer is the median rather than the mean
202///
203/// # Use Cases
204///
205/// - Simple outlier rejection
206/// - When median is preferred over mean
207/// - Sparse optimization problems
208///
209/// # Example
210///
211/// ```
212/// use apex_solver::core::loss_functions::{LossFunction, L1Loss};
213///
214/// let l1 = L1Loss::new();
215///
216/// let [rho, rho_prime, _] = l1.evaluate(4.0);
217/// assert!((rho - 4.0).abs() < 1e-10); // ρ(4) = 2√4 = 4
218/// assert!((rho_prime - 0.5).abs() < 1e-10); // ρ'(4) = 1/√4 = 0.5
219/// ```
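///
/// The `evaluate` implementation guards the `s = 0` singularity by falling back to the L2
/// values `[s, 1.0, 0.0]` for very small `s`, so the weight stays finite at zero:
///
/// ```
/// use apex_solver::core::loss_functions::{LossFunction, L1Loss};
///
/// let l1 = L1Loss::new();
/// let [rho, rho_prime, rho_double_prime] = l1.evaluate(0.0);
/// assert_eq!(rho, 0.0);
/// assert_eq!(rho_prime, 1.0);
/// assert_eq!(rho_double_prime, 0.0);
/// ```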
220#[derive(Debug, Clone, Copy)]
221pub struct L1Loss;
222
223impl L1Loss {
224 /// Create a new L1 loss function (no parameters needed).
225 pub fn new() -> Self {
226 L1Loss
227 }
228}
229
230impl Default for L1Loss {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
236impl LossFunction for L1Loss {
237 fn evaluate(&self, s: f64) -> [f64; 3] {
238 if s < f64::EPSILON {
239 // Near zero: use L2 to avoid singularity
240 return [s, 1.0, 0.0];
241 }
242 let sqrt_s = s.sqrt();
243 [
244 2.0 * sqrt_s, // ρ(s) = 2√s
245 1.0 / sqrt_s, // ρ'(s) = 1/√s
246 -1.0 / (2.0 * s * sqrt_s), // ρ''(s) = -1/(2s√s)
247 ]
248 }
249}
250
251/// Huber loss function for moderate outlier rejection.
252///
253/// The Huber loss is quadratic for small residuals (inliers) and linear for large residuals
254/// (outliers), providing a good balance between robustness and efficiency.
255///
256/// # Mathematical Definition
257///
258/// ```text
259/// ρ(s) = { s if s ≤ δ²
260/// { 2δ√s - δ² if s > δ²
261///
262/// ρ'(s) = { 1 if s ≤ δ²
263/// { δ / √s if s > δ²
264///
265/// ρ''(s) = { 0 if s ≤ δ²
266/// { -δ / (2s^(3/2)) if s > δ²
267/// ```
268///
269/// where `δ` is the scale parameter (threshold), and `s = ||r||²` is the squared residual norm.
270///
271/// # Properties
272///
273/// - **Inlier region** (s ≤ δ²): Behaves like standard least squares (quadratic cost)
274/// - **Outlier region** (s > δ²): Cost grows linearly, limiting outlier influence
275/// - **Transition point**: At s = δ², the function switches from quadratic to linear
276///
277/// # Scale Parameter Selection
278///
279/// Common choices for the scale parameter `δ`:
280/// - **1.345**: Approximately 95% efficiency on Gaussian data (most common)
281/// - **0.5-1.0**: More aggressive outlier rejection
282/// - **2.0-3.0**: More lenient, closer to standard least squares
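///
/// Since the inlier/outlier boundary sits at `s = δ²`, the choice of `δ` decides which
/// residuals keep full weight (a small illustrative comparison):
///
/// ```
/// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
/// # use apex_solver::error::ApexSolverResult;
/// # fn example() -> ApexSolverResult<()> {
/// let strict = HuberLoss::new(0.5)?;  // δ² = 0.25
/// let lenient = HuberLoss::new(3.0)?; // δ² = 9.0
/// let s = 2.0;
/// assert!(strict.evaluate(s)[1] < 1.0);    // already treated as an outlier
/// assert_eq!(lenient.evaluate(s)[1], 1.0); // still in the quadratic (inlier) region
/// # Ok(())
/// # }
/// # example().unwrap();
/// ```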
283///
284/// # Example
285///
286/// ```
287/// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
288/// # use apex_solver::error::ApexSolverResult;
289/// # fn example() -> ApexSolverResult<()> {
290///
291/// // Create Huber loss with scale = 1.345 (standard choice)
292/// let huber = HuberLoss::new(1.345)?;
293///
294/// // Small residual (inlier): ||r||² = 0.5
295/// let [rho, rho_prime, rho_double_prime] = huber.evaluate(0.5);
296/// assert_eq!(rho, 0.5); // Quadratic: ρ(s) = s
297/// assert_eq!(rho_prime, 1.0); // Full weight
298/// assert_eq!(rho_double_prime, 0.0);
299///
300/// // Large residual (outlier): ||r||² = 10.0
301/// let [rho, rho_prime, rho_double_prime] = huber.evaluate(10.0);
302/// // ρ(10) ≈ 6.69, grows linearly not quadratically
303/// // ρ'(10) ≈ 0.425, downweighted to ~42.5% of original
304/// # Ok(())
305/// # }
306/// # example().unwrap();
307/// ```
308#[derive(Debug, Clone)]
309pub struct HuberLoss {
310 /// Scale parameter δ
311 scale: f64,
312 /// Cached value δ² for efficient computation
313 scale2: f64,
314}
315
316impl HuberLoss {
317 /// Create a new Huber loss function with the given scale parameter.
318 ///
319 /// # Arguments
320 ///
321 /// * `scale` - The threshold δ that separates inliers from outliers (must be positive)
322 ///
323 /// # Returns
324 ///
325 /// `Ok(HuberLoss)` if scale > 0, otherwise an error
326 ///
327 /// # Example
328 ///
329 /// ```
330 /// use apex_solver::core::loss_functions::HuberLoss;
331 /// # use apex_solver::error::ApexSolverResult;
332 /// # fn example() -> ApexSolverResult<()> {
333 ///
334 /// let huber = HuberLoss::new(1.345)?;
335 /// # Ok(())
336 /// # }
337 /// # example().unwrap();
338 /// ```
339 pub fn new(scale: f64) -> ApexSolverResult<Self> {
340 if scale <= 0.0 {
341 return Err(
342 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
343 );
344 }
345 Ok(HuberLoss {
346 scale,
347 scale2: scale * scale,
348 })
349 }
350}
351
352impl LossFunction for HuberLoss {
353 /// Evaluate Huber loss function: ρ(s), ρ'(s), ρ''(s).
354 ///
355 /// # Arguments
356 ///
357 /// * `s` - Squared residual norm: s = ||r||²
358 ///
359 /// # Returns
360 ///
361 /// `[ρ(s), ρ'(s), ρ''(s)]` - Cost, first derivative, second derivative
362 fn evaluate(&self, s: f64) -> [f64; 3] {
363 if s > self.scale2 {
364 // Outlier region: s > δ²
365 // Linear cost: ρ(s) = 2δ√s - δ²
366 let r = s.sqrt(); // r = √s = ||r||
367 let rho1 = (self.scale / r).max(f64::MIN); // ρ'(s) = δ / √s
368 [
369 2.0 * self.scale * r - self.scale2, // ρ(s)
370 rho1, // ρ'(s)
371 -rho1 / (2.0 * s), // ρ''(s) = -δ / (2s√s)
372 ]
373 } else {
374 // Inlier region: s ≤ δ²
375 // Quadratic cost: ρ(s) = s, ρ'(s) = 1, ρ''(s) = 0
376 [s, 1.0, 0.0]
377 }
378 }
379}
380
381/// Cauchy loss function for aggressive outlier rejection.
382///
383/// The Cauchy loss (also called Lorentzian loss) provides stronger suppression of outliers
384/// than Huber loss. It never fully rejects outliers but reduces their weight significantly.
385///
386/// # Mathematical Definition
387///
388/// ```text
/// ρ(s) = δ² * log(1 + s/δ²)
390///
391/// ρ'(s) = 1 / (1 + s/δ²)
392///
393/// ρ''(s) = -1 / (δ² * (1 + s/δ²)²)
394/// ```
395///
396/// where `δ` is the scale parameter, and `s = ||r||²` is the squared residual norm.
397///
398/// # Properties
399///
400/// - **Smooth transition**: No sharp boundary between inliers and outliers
401/// - **Logarithmic growth**: Cost grows very slowly for large residuals
402/// - **Strong downweighting**: Large outliers receive very small weights
403/// - **Non-convex**: Can have multiple local minima (harder to optimize than Huber)
404///
405/// # Scale Parameter Selection
406///
407/// Typical values:
408/// - **2.3849**: Approximately 95% efficiency on Gaussian data
409/// - **1.0-2.0**: More aggressive outlier rejection
410/// - **3.0-5.0**: More lenient
411///
412/// # Comparison to Huber Loss
413///
414/// - **Cauchy**: Stronger outlier rejection, smoother, but non-convex (may converge to local minimum)
415/// - **Huber**: Weaker outlier rejection, convex, more predictable convergence
416///
417/// # Example
418///
419/// ```
420/// use apex_solver::core::loss_functions::{LossFunction, CauchyLoss};
421/// # use apex_solver::error::ApexSolverResult;
422/// # fn example() -> ApexSolverResult<()> {
423///
424/// // Create Cauchy loss with scale = 2.3849 (standard choice)
425/// let cauchy = CauchyLoss::new(2.3849)?;
426///
427/// // Small residual: ||r||² = 0.5
428/// let [rho, rho_prime, _] = cauchy.evaluate(0.5);
/// // ρ ≈ 0.48, slightly less than 0.5 (mild downweighting)
430/// // ρ' ≈ 0.92, close to 1.0 (near full weight)
431///
432/// // Large residual: ||r||² = 100.0
433/// let [rho, rho_prime, _] = cauchy.evaluate(100.0);
/// // ρ ≈ 16.6, logarithmic growth (much less than 100)
435/// // ρ' ≈ 0.05, heavily downweighted (5% of original)
436/// # Ok(())
437/// # }
438/// # example().unwrap();
439/// ```
#[derive(Debug, Clone)]
pub struct CauchyLoss {
441 /// Cached value δ² (scale squared)
442 scale2: f64,
443 /// Cached value 1/δ² for efficient computation
444 c: f64,
445}
446
447impl CauchyLoss {
448 /// Create a new Cauchy loss function with the given scale parameter.
449 ///
450 /// # Arguments
451 ///
452 /// * `scale` - The scale parameter δ (must be positive)
453 ///
454 /// # Returns
455 ///
456 /// `Ok(CauchyLoss)` if scale > 0, otherwise an error
457 ///
458 /// # Example
459 ///
460 /// ```
461 /// use apex_solver::core::loss_functions::CauchyLoss;
462 /// # use apex_solver::error::ApexSolverResult;
463 /// # fn example() -> ApexSolverResult<()> {
464 ///
465 /// let cauchy = CauchyLoss::new(2.3849)?;
466 /// # Ok(())
467 /// # }
468 /// # example().unwrap();
469 /// ```
470 pub fn new(scale: f64) -> ApexSolverResult<Self> {
471 if scale <= 0.0 {
472 return Err(
473 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
474 );
475 }
476 let scale2 = scale * scale;
477 Ok(CauchyLoss {
478 scale2,
479 c: 1.0 / scale2,
480 })
481 }
482}
483
484impl LossFunction for CauchyLoss {
485 /// Evaluate Cauchy loss function: ρ(s), ρ'(s), ρ''(s).
486 ///
487 /// # Arguments
488 ///
489 /// * `s` - Squared residual norm: s = ||r||²
490 ///
491 /// # Returns
492 ///
493 /// `[ρ(s), ρ'(s), ρ''(s)]` - Cost, first derivative, second derivative
494 fn evaluate(&self, s: f64) -> [f64; 3] {
495 let sum = 1.0 + s * self.c; // 1 + s/δ²
496 let inv = 1.0 / sum; // 1 / (1 + s/δ²)
497
498 // Note: sum and inv are always positive, assuming s ≥ 0
499 [
            self.scale2 * sum.ln(),  // ρ(s) = δ² * ln(1 + s/δ²)
501 inv.max(f64::MIN), // ρ'(s) = 1 / (1 + s/δ²)
502 -self.c * (inv * inv), // ρ''(s) = -1 / (δ² * (1 + s/δ²)²)
503 ]
504 }
505}
506
507/// Fair loss function with continuous smooth derivatives.
508///
509/// The Fair loss provides a good balance between robustness and stability with everywhere-defined
510/// continuous derivatives up to third order. It yields a unique solution and is recommended
511/// for general use when you need guaranteed smoothness.
512///
513/// # Mathematical Definition
514///
515/// ```text
/// ρ(s)   = c² * (|x|/c - ln(1 + |x|/c))
/// ρ'(s)  = c / (2(c + |x|))
/// ρ''(s) = -c / (4√s·(c + |x|)²)
519/// ```
520///
521/// where `c` is the scale parameter, `s = ||r||²`, and `x = √s = ||r||`.
522///
523/// # Properties
524///
525/// - **Smooth**: Continuous derivatives of first three orders
526/// - **Unique solution**: Strictly convex near origin
527/// - **Moderate robustness**: Between Huber and Cauchy
528/// - **Stable**: No discontinuities in optimization
529///
530/// # Scale Parameter Selection
531///
532/// - **1.3998**: Approximately 95% efficiency on Gaussian data (recommended)
533/// - **0.8-1.2**: More aggressive outlier rejection
534/// - **2.0-3.0**: More lenient
535///
536/// # Comparison
537///
538/// - Smoother than Huber (no kink at threshold)
539/// - Less aggressive than Cauchy
540/// - Better behaved numerically than many redescending M-estimators
541///
542/// # Example
543///
544/// ```
545/// use apex_solver::core::loss_functions::{LossFunction, FairLoss};
546/// # use apex_solver::error::ApexSolverResult;
547/// # fn example() -> ApexSolverResult<()> {
548///
549/// let fair = FairLoss::new(1.3998)?;
550///
551/// let [rho, rho_prime, _] = fair.evaluate(4.0);
552/// // Smooth transition, no sharp corners
553/// # Ok(())
554/// # }
555/// # example().unwrap();
556/// ```
557#[derive(Debug, Clone)]
558pub struct FairLoss {
559 scale: f64,
560}
561
562impl FairLoss {
563 /// Create a new Fair loss function with the given scale parameter.
564 ///
565 /// # Arguments
566 ///
567 /// * `scale` - The scale parameter c (must be positive)
568 ///
569 /// # Returns
570 ///
571 /// `Ok(FairLoss)` if scale > 0, otherwise an error
572 pub fn new(scale: f64) -> ApexSolverResult<Self> {
573 if scale <= 0.0 {
574 return Err(
575 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
576 );
577 }
578 Ok(FairLoss { scale })
579 }
580}
581
582impl LossFunction for FairLoss {
583 fn evaluate(&self, s: f64) -> [f64; 3] {
        if s < f64::EPSILON {
            // Near zero the Fair loss behaves like s/2 with weight 1/2.
            return [0.5 * s, 0.5, 0.0];
        }

        let x = s.sqrt(); // ||r|| (non-negative since s ≥ 0)
        let c_plus_x = self.scale + x;

        // ρ(s) = c² * (x/c - ln(1 + x/c))
        let rho = self.scale * self.scale * (x / self.scale - (1.0 + x / self.scale).ln());

        // ρ'(s) = dρ/ds = c / (2(c + x))
        let rho_prime = 0.5 * self.scale / c_plus_x;

        // ρ''(s) = dρ'/ds = -c / (4√s·(c + x)²)
        let rho_double_prime = -self.scale / (4.0 * x * c_plus_x * c_plus_x);
600
601 [rho, rho_prime, rho_double_prime]
602 }
603}
604
605/// Geman-McClure loss function for very strong outlier rejection.
606///
607/// The Geman-McClure loss provides one of the strongest forms of outlier suppression,
608/// with weights that decay rapidly for large residuals.
609///
610/// # Mathematical Definition
611///
612/// ```text
613/// ρ(s) = s / (1 + s/c²)
614/// ρ'(s) = 1 / (1 + s/c²)²
615/// ρ''(s) = -2 / (c² * (1 + s/c²)³)
616/// ```
617///
618/// where `c` is the scale parameter and `s = ||r||²`.
619///
620/// # Properties
621///
622/// - **Very strong rejection**: Weights decay as O(1/s²) for large s
623/// - **Non-convex**: Multiple local minima possible
624/// - **No unique solution**: Requires good initialization
625/// - **Aggressive**: Use when outliers are severe
626///
627/// # Scale Parameter Selection
628///
629/// - **1.0-2.0**: Typical range
630/// - **0.5-1.0**: Very aggressive (use with care)
631/// - **2.0-4.0**: More lenient
632///
633/// # Example
634///
635/// ```
636/// use apex_solver::core::loss_functions::{LossFunction, GemanMcClureLoss};
637/// # use apex_solver::error::ApexSolverResult;
638/// # fn example() -> ApexSolverResult<()> {
639///
640/// let geman = GemanMcClureLoss::new(1.0)?;
641///
642/// let [rho, rho_prime, _] = geman.evaluate(100.0);
643/// // Very small weight for large outliers
644/// # Ok(())
645/// # }
646/// # example().unwrap();
647/// ```
648#[derive(Debug, Clone)]
649pub struct GemanMcClureLoss {
650 c: f64, // 1/scale²
651}
652
653impl GemanMcClureLoss {
654 /// Create a new Geman-McClure loss function.
655 ///
656 /// # Arguments
657 ///
658 /// * `scale` - The scale parameter c (must be positive)
659 pub fn new(scale: f64) -> ApexSolverResult<Self> {
660 if scale <= 0.0 {
661 return Err(
662 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
663 );
664 }
665 let scale2 = scale * scale;
666 Ok(GemanMcClureLoss { c: 1.0 / scale2 })
667 }
668}
669
670impl LossFunction for GemanMcClureLoss {
671 fn evaluate(&self, s: f64) -> [f64; 3] {
672 let denom = 1.0 + s * self.c; // 1 + s/c²
673 let inv = 1.0 / denom;
674 let inv2 = inv * inv;
675
676 [
677 s * inv, // ρ(s) = s / (1 + s/c²)
678 inv2, // ρ'(s) = 1 / (1 + s/c²)²
679 -2.0 * self.c * inv2 * inv, // ρ''(s) = -2 / (c²(1 + s/c²)³)
680 ]
681 }
682}
683
684/// Welsch loss function with exponential downweighting.
685///
686/// The Welsch loss (also called Leclerc loss) uses exponential decay to strongly
687/// suppress outliers while maintaining smoothness. It completely suppresses very
688/// large outliers (weight → 0 as s → ∞).
689///
690/// # Mathematical Definition
691///
692/// ```text
693/// ρ(s) = c²/2 * (1 - exp(-s/c²))
694/// ρ'(s) = (1/2) * exp(-s/c²)
695/// ρ''(s) = -(1/2c²) * exp(-s/c²)
696/// ```
697///
698/// where `c` is the scale parameter and `s = ||r||²`.
699///
700/// # Properties
701///
702/// - **Redescending**: Weights decrease to zero for large residuals
703/// - **Smooth**: Infinitely differentiable
704/// - **Strong suppression**: Exponential decay
705/// - **Non-convex**: Requires good initialization
706///
707/// # Scale Parameter Selection
708///
709/// - **2.9846**: Approximately 95% efficiency on Gaussian data
710/// - **2.0-2.5**: More aggressive
711/// - **3.5-4.5**: More lenient
712///
713/// # Example
714///
715/// ```
716/// use apex_solver::core::loss_functions::{LossFunction, WelschLoss};
717/// # use apex_solver::error::ApexSolverResult;
718/// # fn example() -> ApexSolverResult<()> {
719///
720/// let welsch = WelschLoss::new(2.9846)?;
721///
722/// let [rho, rho_prime, _] = welsch.evaluate(50.0);
723/// // Weight approaches zero for large residuals
724/// # Ok(())
725/// # }
726/// # example().unwrap();
727/// ```
728#[derive(Debug, Clone)]
729pub struct WelschLoss {
730 scale2: f64,
731 inv_scale2: f64,
732}
733
734impl WelschLoss {
735 /// Create a new Welsch loss function.
736 ///
737 /// # Arguments
738 ///
739 /// * `scale` - The scale parameter c (must be positive)
740 pub fn new(scale: f64) -> ApexSolverResult<Self> {
741 if scale <= 0.0 {
742 return Err(
743 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
744 );
745 }
746 let scale2 = scale * scale;
747 Ok(WelschLoss {
748 scale2,
749 inv_scale2: 1.0 / scale2,
750 })
751 }
752}
753
754impl LossFunction for WelschLoss {
755 fn evaluate(&self, s: f64) -> [f64; 3] {
756 let exp_term = (-s * self.inv_scale2).exp();
757
758 [
759 (self.scale2 / 2.0) * (1.0 - exp_term), // ρ(s) = c²/2 * (1 - exp(-s/c²))
760 0.5 * exp_term, // ρ'(s) = (1/2) * exp(-s/c²)
761 -0.5 * self.inv_scale2 * exp_term, // ρ''(s) = -(1/2c²) * exp(-s/c²)
762 ]
763 }
764}
765
766/// Tukey biweight loss function with complete outlier suppression.
767///
768/// The Tukey biweight (bisquare) loss completely suppresses outliers beyond a threshold,
769/// setting their weight to exactly zero. This is a "redescending" M-estimator.
770///
771/// # Mathematical Definition
772///
773/// For |x| ≤ c:
774/// ```text
775/// ρ(s) = c²/6 * (1 - (1 - (x/c)²)³)
776/// ρ'(s) = (1/2) * (1 - (x/c)²)²
/// ρ''(s) = -(1/c²) * (1 - (x/c)²)
778/// ```
779///
780/// For |x| > c:
781/// ```text
782/// ρ(s) = c²/6
783/// ρ'(s) = 0
784/// ρ''(s) = 0
785/// ```
786///
787/// where `c` is the scale parameter, `x = √s`, and `s = ||r||²`.
788///
789/// # Properties
790///
791/// - **Complete suppression**: Outliers have exactly zero weight
792/// - **Redescending**: Weight goes to zero beyond threshold
793/// - **Non-convex**: Multiple local minima
794/// - **Aggressive**: Best for severe outlier contamination
795///
796/// # Scale Parameter Selection
797///
798/// - **4.6851**: Approximately 95% efficiency on Gaussian data
799/// - **3.5-4.0**: More aggressive
800/// - **5.5-6.5**: More lenient
801///
802/// # Example
803///
804/// ```
805/// use apex_solver::core::loss_functions::{LossFunction, TukeyBiweightLoss};
806/// # use apex_solver::error::ApexSolverResult;
807/// # fn example() -> ApexSolverResult<()> {
808///
809/// let tukey = TukeyBiweightLoss::new(4.6851)?;
810///
811/// let [rho, rho_prime, _] = tukey.evaluate(25.0); // |x| = 5 > 4.6851
812/// assert_eq!(rho_prime, 0.0); // Complete suppression
813/// # Ok(())
814/// # }
815/// # example().unwrap();
816/// ```
817#[derive(Debug, Clone)]
818pub struct TukeyBiweightLoss {
819 scale: f64,
820 scale2: f64,
821}
822
823impl TukeyBiweightLoss {
824 /// Create a new Tukey biweight loss function.
825 ///
826 /// # Arguments
827 ///
828 /// * `scale` - The scale parameter c (must be positive)
829 pub fn new(scale: f64) -> ApexSolverResult<Self> {
830 if scale <= 0.0 {
831 return Err(
832 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
833 );
834 }
835 Ok(TukeyBiweightLoss {
836 scale,
837 scale2: scale * scale,
838 })
839 }
840}
841
842impl LossFunction for TukeyBiweightLoss {
843 fn evaluate(&self, s: f64) -> [f64; 3] {
844 let x = s.sqrt();
845
846 if x > self.scale {
847 // Complete outlier suppression
848 [self.scale2 / 6.0, 0.0, 0.0]
849 } else {
850 let ratio = x / self.scale;
851 let ratio2 = ratio * ratio;
852 let one_minus_ratio2 = 1.0 - ratio2;
853 let one_minus_ratio2_sq = one_minus_ratio2 * one_minus_ratio2;
854
855 [
856 (self.scale2 / 6.0) * (1.0 - one_minus_ratio2 * one_minus_ratio2_sq), // ρ(s)
857 0.5 * one_minus_ratio2_sq, // ρ'(s)
                -one_minus_ratio2 / self.scale2, // ρ''(s) = -(1/c²)(1 - (x/c)²)
859 ]
860 }
861 }
862}
863
864/// Andrews sine wave loss function (redescending M-estimator).
865///
866/// The Andrews sine wave loss uses a periodic sine function to create a redescending
867/// M-estimator that completely suppresses outliers beyond π*c.
868///
869/// # Mathematical Definition
870///
871/// For |x| ≤ πc:
872/// ```text
873/// ρ(s) = c² * (1 - cos(x/c))
874/// ρ'(s) = (1/2) * sin(x/c)
875/// ρ''(s) = (1/4c) * cos(x/c) / √s
876/// ```
877///
878/// For |x| > πc:
879/// ```text
880/// ρ(s) = 2c²
881/// ρ'(s) = 0
882/// ρ''(s) = 0
883/// ```
884///
885/// where `c` is the scale parameter, `x = √s`, and `s = ||r||²`.
886///
887/// # Properties
888///
889/// - **Periodic structure**: Sine-based weighting
890/// - **Complete suppression**: Zero weight beyond πc
891/// - **Redescending**: Smooth transition to zero
892/// - **Non-convex**: Requires careful initialization
893///
894/// # Scale Parameter Selection
895///
896/// - **1.339**: Standard tuning constant
897/// - **1.0-1.2**: More aggressive
898/// - **1.5-2.0**: More lenient
899///
900/// # Example
901///
902/// ```
903/// use apex_solver::core::loss_functions::{LossFunction, AndrewsWaveLoss};
904/// # use apex_solver::error::ApexSolverResult;
905/// # fn example() -> ApexSolverResult<()> {
906///
907/// let andrews = AndrewsWaveLoss::new(1.339)?;
908///
909/// let [rho, rho_prime, _] = andrews.evaluate(20.0);
910/// // Weight is zero for large outliers
911/// # Ok(())
912/// # }
913/// # example().unwrap();
914/// ```
915#[derive(Debug, Clone)]
916pub struct AndrewsWaveLoss {
917 scale: f64,
918 scale2: f64,
919 threshold: f64, // π * scale
920}
921
922impl AndrewsWaveLoss {
923 /// Create a new Andrews sine wave loss function.
924 ///
925 /// # Arguments
926 ///
927 /// * `scale` - The scale parameter c (must be positive)
928 pub fn new(scale: f64) -> ApexSolverResult<Self> {
929 if scale <= 0.0 {
930 return Err(
931 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
932 );
933 }
934 Ok(AndrewsWaveLoss {
935 scale,
936 scale2: scale * scale,
937 threshold: std::f64::consts::PI * scale,
938 })
939 }
940}
941
942impl LossFunction for AndrewsWaveLoss {
943 fn evaluate(&self, s: f64) -> [f64; 3] {
944 let x = s.sqrt();
945
946 if x > self.threshold {
947 // Complete suppression beyond π*c
948 [2.0 * self.scale2, 0.0, 0.0]
949 } else {
950 let arg = x / self.scale;
951 let sin_val = arg.sin();
952 let cos_val = arg.cos();
953
954 [
955 self.scale2 * (1.0 - cos_val), // ρ(s)
956 0.5 * sin_val, // ρ'(s)
957 (0.25 / self.scale) * cos_val / x.max(f64::EPSILON), // ρ''(s)
958 ]
959 }
960 }
961}
962
963/// Ramsay Ea loss function with exponential decay.
964///
965/// The Ramsay Ea loss uses exponential weighting to provide smooth, strong
966/// downweighting of outliers.
967///
968/// # Mathematical Definition
969///
970/// ```text
971/// ρ(s) = a⁻² * (1 - exp(-a|x|) * (1 + a|x|))
972/// ```
973///
974/// where `a` is the scale parameter, `x = √s`, and `s = ||r||²`.
975///
976/// # Properties
977///
978/// - **Exponential decay**: Smooth weight reduction
979/// - **Strongly robust**: Good for heavy outliers
980/// - **Smooth**: Continuous derivatives
981/// - **Non-convex**: Needs good initialization
982///
983/// # Scale Parameter Selection
984///
985/// - **0.3**: Standard tuning constant
986/// - **0.2-0.25**: More aggressive
987/// - **0.35-0.5**: More lenient
988///
989/// # Example
990///
991/// ```
992/// use apex_solver::core::loss_functions::{LossFunction, RamsayEaLoss};
993/// # use apex_solver::error::ApexSolverResult;
994/// # fn example() -> ApexSolverResult<()> {
995///
996/// let ramsay = RamsayEaLoss::new(0.3)?;
997///
998/// let [rho, rho_prime, _] = ramsay.evaluate(10.0);
999/// // Exponential downweighting
1000/// # Ok(())
1001/// # }
1002/// # example().unwrap();
1003/// ```
1004#[derive(Debug, Clone)]
1005pub struct RamsayEaLoss {
1006 scale: f64,
1007 inv_scale2: f64,
1008}
1009
1010impl RamsayEaLoss {
1011 /// Create a new Ramsay Ea loss function.
1012 ///
1013 /// # Arguments
1014 ///
1015 /// * `scale` - The scale parameter a (must be positive)
1016 pub fn new(scale: f64) -> ApexSolverResult<Self> {
1017 if scale <= 0.0 {
1018 return Err(
1019 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
1020 );
1021 }
1022 Ok(RamsayEaLoss {
1023 scale,
1024 inv_scale2: 1.0 / (scale * scale),
1025 })
1026 }
1027}
1028
1029impl LossFunction for RamsayEaLoss {
1030 fn evaluate(&self, s: f64) -> [f64; 3] {
1031 let x = s.sqrt();
1032 let ax = self.scale * x;
1033 let exp_term = (-ax).exp();
1034
1035 // ρ(s) = a⁻² * (1 - exp(-a|x|) * (1 + a|x|))
1036 let rho = self.inv_scale2 * (1.0 - exp_term * (1.0 + ax));
1037
1038 // ρ'(s) = (1/2) * exp(-a|x|)
1039 let rho_prime = 0.5 * exp_term;
1040
1041 // ρ''(s) = -(a/4|x|) * exp(-a|x|)
1042 let rho_double_prime = -(self.scale / (4.0 * x.max(f64::EPSILON))) * exp_term;
1043
1044 [rho, rho_prime, rho_double_prime]
1045 }
1046}
1047
1048/// Trimmed mean loss function with hard threshold.
1049///
1050/// The trimmed mean loss is the simplest redescending estimator, applying a hard
1051/// cutoff at a threshold. Residuals below the threshold use L2 loss, those above
1052/// contribute a constant.
1053///
1054/// # Mathematical Definition
1055///
1056/// For s ≤ c²:
1057/// ```text
1058/// ρ(s) = s/2
1059/// ρ'(s) = 1/2
1060/// ρ''(s) = 0
1061/// ```
1062///
1063/// For s > c²:
1064/// ```text
1065/// ρ(s) = c²/2
1066/// ρ'(s) = 0
1067/// ρ''(s) = 0
1068/// ```
1069///
1070/// where `c` is the scale parameter and `s = ||r||²`.
1071///
1072/// # Properties
1073///
1074/// - **Simple**: Easiest to understand and implement
1075/// - **Hard cutoff**: Discontinuous weight function
1076/// - **Robust**: Completely ignores large outliers
1077/// - **Unstable**: Discontinuity can cause optimization issues
1078///
1079/// # Scale Parameter Selection
1080///
1081/// - **2.0**: Standard tuning constant
1082/// - **1.5**: More aggressive
1083/// - **3.0**: More lenient
1084///
1085/// # Example
1086///
1087/// ```
1088/// use apex_solver::core::loss_functions::{LossFunction, TrimmedMeanLoss};
1089/// # use apex_solver::error::ApexSolverResult;
1090/// # fn example() -> ApexSolverResult<()> {
1091///
1092/// let trimmed = TrimmedMeanLoss::new(2.0)?;
1093///
1094/// let [rho, rho_prime, _] = trimmed.evaluate(5.0);
1095/// assert_eq!(rho_prime, 0.0); // Beyond threshold
1096/// # Ok(())
1097/// # }
1098/// # example().unwrap();
1099/// ```
1100#[derive(Debug, Clone)]
1101pub struct TrimmedMeanLoss {
1102 scale2: f64,
1103}
1104
1105impl TrimmedMeanLoss {
1106 /// Create a new trimmed mean loss function.
1107 ///
1108 /// # Arguments
1109 ///
1110 /// * `scale` - The scale parameter c (must be positive)
1111 pub fn new(scale: f64) -> ApexSolverResult<Self> {
1112 if scale <= 0.0 {
1113 return Err(
1114 CoreError::InvalidInput("scale needs to be larger than zero".to_string()).into(),
1115 );
1116 }
1117 Ok(TrimmedMeanLoss {
1118 scale2: scale * scale,
1119 })
1120 }
1121}
1122
1123impl LossFunction for TrimmedMeanLoss {
1124 fn evaluate(&self, s: f64) -> [f64; 3] {
1125 if s <= self.scale2 {
1126 [s / 2.0, 0.5, 0.0]
1127 } else {
1128 [self.scale2 / 2.0, 0.0, 0.0]
1129 }
1130 }
1131}
1132
1133/// Generalized Lp norm loss function.
1134///
1135/// The Lp norm loss allows flexible control over robustness through the p parameter.
1136/// It interpolates between L1 (p=1) and L2 (p=2) losses.
1137///
1138/// # Mathematical Definition
1139///
1140/// ```text
1141/// ρ(s) = |x|^p = s^(p/2)
1142/// ρ'(s) = (p/2) * s^(p/2-1)
1143/// ρ''(s) = (p/2) * (p/2-1) * s^(p/2-2)
1144/// ```
1145///
1146/// where `p` is the norm parameter, `x = √s`, and `s = ||r||²`.
1147///
1148/// # Properties
1149///
1150/// - **Flexible**: Tune robustness with p parameter
1151/// - **p=2**: L2 norm (standard least squares)
1152/// - **p=1**: L1 norm (robust median estimator)
1153/// - **p<1**: Very robust (non-convex)
1154/// - **1<p<2**: Compromise between L1 and L2
1155///
1156/// # Parameter Selection
1157///
1158/// - **p = 2.0**: No robustness (L2 loss)
1159/// - **p = 1.5**: Moderate robustness
1160/// - **p = 1.2**: Strong robustness
1161/// - **p = 1.0**: L1 loss (median)
1162///
1163/// # Example
1164///
1165/// ```
1166/// use apex_solver::core::loss_functions::{LossFunction, LpNormLoss};
1167/// # use apex_solver::error::ApexSolverResult;
1168/// # fn example() -> ApexSolverResult<()> {
1169///
1170/// let lp = LpNormLoss::new(1.5)?;
1171///
1172/// let [rho, rho_prime, _] = lp.evaluate(4.0);
1173/// // Between L1 and L2 behavior
1174/// # Ok(())
1175/// # }
1176/// # example().unwrap();
1177/// ```
1178#[derive(Debug, Clone)]
1179pub struct LpNormLoss {
1180 p: f64,
1181}
1182
1183impl LpNormLoss {
1184 /// Create a new Lp norm loss function.
1185 ///
1186 /// # Arguments
1187 ///
1188 /// * `p` - The norm parameter (0 < p ≤ 2 for practical use)
1189 pub fn new(p: f64) -> ApexSolverResult<Self> {
1190 if p <= 0.0 {
1191 return Err(CoreError::InvalidInput("p must be positive".to_string()).into());
1192 }
1193 Ok(LpNormLoss { p })
1194 }
1195}
1196
1197impl LossFunction for LpNormLoss {
1198 fn evaluate(&self, s: f64) -> [f64; 3] {
1199 if s < f64::EPSILON {
1200 return [s, 1.0, 0.0];
1201 }
1202
1203 let exp_rho = self.p / 2.0;
1204 let exp_rho_prime = exp_rho - 1.0;
1205 let exp_rho_double_prime = exp_rho_prime - 1.0;
1206
1207 [
1208 s.powf(exp_rho), // ρ(s) = s^(p/2)
1209 exp_rho * s.powf(exp_rho_prime), // ρ'(s) = (p/2) * s^(p/2-1)
1210 exp_rho * exp_rho_prime * s.powf(exp_rho_double_prime), // ρ''(s)
1211 ]
1212 }
1213}
1214
1215/// Barron's general and adaptive robust loss function (CVPR 2019).
1216///
1217/// The Barron loss is a unified framework that encompasses many classic loss functions
1218/// (L2, Charbonnier, Cauchy, Geman-McClure, Welsch) through a single shape parameter α.
1219/// It can also adapt α automatically during optimization.
1220///
1221/// # Mathematical Definition
1222///
1223/// ```text
1224/// ρ(s, α, c) = (|α|/c²) * (|(x/c)² * |α|/2 + 1|^(α/2) - 1)
1225/// ```
1226///
1227/// where `α` controls robustness, `c` is scale, `x = √s`, and `s = ||r||²`.
1228///
1229/// # Special Cases (by α value)
1230///
1231/// - **α = 2**: L2 loss (no robustness)
1232/// - **α = 1**: Charbonnier/Pseudo-Huber loss
1233/// - **α = 0**: Cauchy loss
/// - **α = -2**: Geman-McClure loss
/// - **α → -∞**: Welsch loss
1237///
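/// The α = 2 case is handled exactly in `evaluate` and reproduces [`L2Loss`] (illustrative
/// check):
///
/// ```
/// use apex_solver::core::loss_functions::{LossFunction, BarronGeneralLoss, L2Loss};
/// # use apex_solver::error::ApexSolverResult;
/// # fn example() -> ApexSolverResult<()> {
/// let barron_l2 = BarronGeneralLoss::new(2.0, 1.0)?;
/// assert_eq!(barron_l2.evaluate(3.0), L2Loss::new().evaluate(3.0));
/// # Ok(())
/// # }
/// # example().unwrap();
/// ```
///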
1238/// # Properties
1239///
1240/// - **Unified**: Single framework for many loss functions
1241/// - **Adaptive**: Can learn optimal α during training
1242/// - **Smooth**: Continuously differentiable in α
1243/// - **Modern**: State-of-the-art from computer vision research
1244///
1245/// # Parameter Selection
1246///
1247/// **Fixed α (manual tuning):**
1248/// - α = 2.0: Clean data, no outliers
1249/// - α = 0.0 to 1.0: Moderate outliers
1250/// - α = -2.0 to 0.0: Heavy outliers
1251///
1252/// **Scale c:**
1253/// - c = 1.0: Standard choice
1254/// - Adjust based on expected residual magnitude
1255///
1256/// # References
1257///
1258/// Barron, J. T. (2019). A general and adaptive robust loss function.
1259/// IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).
1260///
1261/// # Example
1262///
1263/// ```
1264/// use apex_solver::core::loss_functions::{LossFunction, BarronGeneralLoss};
1265/// # use apex_solver::error::ApexSolverResult;
1266/// # fn example() -> ApexSolverResult<()> {
1267///
1268/// // Cauchy-like behavior
1269/// let barron = BarronGeneralLoss::new(0.0, 1.0)?;
1270///
1271/// let [rho, rho_prime, _] = barron.evaluate(4.0);
1272/// // Behaves like Cauchy loss
1273/// # Ok(())
1274/// # }
1275/// # example().unwrap();
1276/// ```
1277#[derive(Debug, Clone)]
1278pub struct BarronGeneralLoss {
1279 alpha: f64,
1280 scale: f64,
1281 scale2: f64,
1282}
1283
1284impl BarronGeneralLoss {
1285 /// Create a new Barron general robust loss function.
1286 ///
1287 /// # Arguments
1288 ///
1289 /// * `alpha` - The shape parameter (controls robustness)
1290 /// * `scale` - The scale parameter c (must be positive)
1291 pub fn new(alpha: f64, scale: f64) -> ApexSolverResult<Self> {
1292 if scale <= 0.0 {
1293 return Err(CoreError::InvalidInput("scale must be positive".to_string()).into());
1294 }
1295 Ok(BarronGeneralLoss {
1296 alpha,
1297 scale,
1298 scale2: scale * scale,
1299 })
1300 }
1301}
1302
1303impl LossFunction for BarronGeneralLoss {
1304 fn evaluate(&self, s: f64) -> [f64; 3] {
1305 // Handle special case α ≈ 0 (Cauchy loss)
1306 if self.alpha.abs() < 1e-6 {
1307 let denom = 1.0 + s / self.scale2;
1308 let inv = 1.0 / denom;
1309 return [
                self.scale2 * denom.ln(), // matches the Cauchy form δ²·ln(1 + s/δ²)
1311 inv.max(f64::MIN),
1312 -inv * inv / self.scale2,
1313 ];
1314 }
1315
1316 // Handle special case α ≈ 2 (L2 loss)
1317 if (self.alpha - 2.0).abs() < 1e-6 {
1318 return [s, 1.0, 0.0];
1319 }
1320
1321 // General case
1322 let x = s.sqrt();
1323 let normalized = x / self.scale;
1324 let normalized2 = normalized * normalized;
1325
1326 let inner = self.alpha.abs() / 2.0 * normalized2 + 1.0;
1327 let power = inner.powf(self.alpha / 2.0);
1328
1329 // ρ(s) = (|α|/c²) * (power - 1)
1330 let rho = (self.alpha.abs() / self.scale2) * (power - 1.0);
1331
1332 // ρ'(s) = (1/2) * inner^(α/2 - 1)
1333 let rho_prime = 0.5 * inner.powf(self.alpha / 2.0 - 1.0);
1334
1335 // ρ''(s) = (α - 2)/(4c²) * inner^(α/2 - 2)
1336 let rho_double_prime =
1337 (self.alpha - 2.0) / (4.0 * self.scale2) * inner.powf(self.alpha / 2.0 - 2.0);
1338
1339 [rho, rho_prime, rho_double_prime]
1340 }
1341}
1342
1343/// Student's t-distribution loss function (robust M-estimator).
1344///
1345/// The t-distribution loss is derived from the negative log-likelihood of Student's
1346/// t-distribution. It provides heavy tails for robustness against outliers, with the
1347/// degrees of freedom parameter ν controlling the tail heaviness.
1348///
1349/// # Mathematical Definition
1350///
1351/// ```text
1352/// ρ(s) = (ν + 1)/2 · log(1 + s/ν)
1353/// ρ'(s) = (ν + 1)/(2(ν + s))
1354/// ρ''(s) = -(ν + 1)/(2(ν + s)²)
1355/// ```
1356///
1357/// where `ν` is the degrees of freedom and `s = ||r||²` is the squared residual norm.
1358///
1359/// # Properties
1360///
1361/// - **Heavy tails**: Provides robustness through heavier tails than Gaussian
1362/// - **Parameter control**: Small ν → heavy tails (more robust), large ν → Gaussian (less robust)
1363/// - **Well-founded**: Based on maximum likelihood estimation with t-distribution
1364/// - **Smooth**: Continuous derivatives for all s > 0
1365///
1366/// # Degrees of Freedom Selection
1367///
1368/// - **ν = 3-4**: Very robust, heavy outlier suppression
1369/// - **ν = 5**: Recommended default, good balance
1370/// - **ν = 10**: Moderate robustness
1371/// - **ν → ∞**: Converges to Gaussian (L2 loss)
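///
/// The effect of ν on outlier handling can be seen directly (illustrative values):
///
/// ```
/// use apex_solver::core::loss_functions::{LossFunction, TDistributionLoss};
/// # use apex_solver::error::ApexSolverResult;
/// # fn example() -> ApexSolverResult<()> {
/// let heavy_tailed = TDistributionLoss::new(3.0)?;
/// let near_gaussian = TDistributionLoss::new(10.0)?;
/// // At a large squared residual, the smaller ν downweights more aggressively.
/// assert!(heavy_tailed.evaluate(100.0)[1] < near_gaussian.evaluate(100.0)[1]);
/// # Ok(())
/// # }
/// # example().unwrap();
/// ```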
1372///
1373/// # Use Cases
1374///
1375/// - Robust regression with unknown outlier distribution
1376/// - SLAM and pose graph optimization with loop closure outliers
1377/// - Bundle adjustment with incorrect feature matches
1378/// - Any optimization problem with heavy-tailed noise
1379///
1380/// # References
1381///
1382/// - Student's t-distribution is widely used in robust statistics
1383/// - Applied in robust SLAM (e.g., Chebrolu et al. 2021, Agarwal et al.)
1384///
1385/// # Example
1386///
1387/// ```
1388/// use apex_solver::core::loss_functions::{LossFunction, TDistributionLoss};
1389/// # use apex_solver::error::ApexSolverResult;
1390/// # fn example() -> ApexSolverResult<()> {
1391///
1392/// let t_loss = TDistributionLoss::new(5.0)?;
1393///
1394/// let [rho, rho_prime, _] = t_loss.evaluate(4.0);
1395/// // Robust to outliers with heavy tails
1396/// # Ok(())
1397/// # }
1398/// # example().unwrap();
1399/// ```
1400#[derive(Debug, Clone)]
1401pub struct TDistributionLoss {
1402 nu: f64, // Degrees of freedom
1403 half_nu_plus_1: f64, // (ν + 1)/2 (cached)
1404}
1405
1406impl TDistributionLoss {
1407 /// Create a new Student's t-distribution loss function.
1408 ///
1409 /// # Arguments
1410 ///
1411 /// * `nu` - Degrees of freedom (must be positive)
1412 ///
1413 /// # Recommended Values
1414 ///
1415 /// - ν = 5.0: Default, good balance between robustness and efficiency
1416 /// - ν = 3.0-4.0: More robust to outliers
1417 /// - ν = 10.0: Less aggressive, closer to Gaussian
1418 pub fn new(nu: f64) -> ApexSolverResult<Self> {
1419 if nu <= 0.0 {
1420 return Err(
1421 CoreError::InvalidInput("degrees of freedom must be positive".to_string()).into(),
1422 );
1423 }
1424 Ok(TDistributionLoss {
1425 nu,
1426 half_nu_plus_1: (nu + 1.0) / 2.0,
1427 })
1428 }
1429}
1430
1431impl LossFunction for TDistributionLoss {
1432 fn evaluate(&self, s: f64) -> [f64; 3] {
1433 // ρ(s) = (ν + 1)/2 · log(1 + s/ν)
1434 let inner = 1.0 + s / self.nu;
1435 let rho = self.half_nu_plus_1 * inner.ln();
1436
1437 // ρ'(s) = (ν + 1)/(2(ν + s))
1438 let denom = self.nu + s;
1439 let rho_prime = self.half_nu_plus_1 / denom;
1440
1441 // ρ''(s) = -(ν + 1)/(2(ν + s)²)
1442 let rho_double_prime = -self.half_nu_plus_1 / (denom * denom);
1443
1444 [rho, rho_prime, rho_double_prime]
1445 }
1446}
1447
1448/// Adaptive Barron loss function (simplified version).
1449///
1450/// This is a convenience wrapper around `BarronGeneralLoss` with recommended default
1451/// parameters for adaptive robust optimization. Based on Chebrolu et al. (2021) RAL paper
1452/// "Adaptive Robust Kernels for Non-Linear Least Squares Problems".
1453///
1454/// # Mathematical Definition
1455///
1456/// For α ≠ 0:
1457/// ```text
1458/// ρ(s) = |α - 2|/α · ((s/c² + 1)^(α/2) - 1)
1459/// ```
1460///
1461/// For α = 0 (Cauchy-like):
1462/// ```text
1463/// ρ(s) = log(s/(2c²) + 1)
1464/// ```
1465///
/// where `α` is the shape parameter and `c` is the scale parameter. (This is the formulation
/// from the references below; the wrapped [`BarronGeneralLoss`] evaluates its own, slightly
/// different parameterization, documented on that type.)
1467///
1468/// # Properties
1469///
1470/// - **Adaptive**: Can approximate many M-estimators (Huber, Cauchy, Geman-McClure, etc.)
1471/// - **Shape parameter α**: Controls the robustness level
1472/// - **Scale parameter c**: Controls the transition point
1473/// - **Unified framework**: Single loss function family
1474///
1475/// # Parameter Selection
1476///
1477/// **Default (α = 0.0, c = 1.0):**
1478/// - Cauchy-like behavior
1479/// - Good general-purpose robust loss
1480/// - Suitable for moderate to heavy outliers
1481///
1482/// **Other values:**
1483/// - α = 2.0: L2 loss (no robustness)
1484/// - α = 1.0: Pseudo-Huber/Charbonnier-like
1485/// - α = -2.0: Geman-McClure-like (very robust)
1486///
1487/// # Note on Adaptivity
1488///
1489/// This simplified version uses fixed parameters. The full adaptive version from
1490/// Chebrolu et al. requires iterative estimation of α based on residual distribution,
1491/// which would require integration into the optimizer's main loop.
1492///
1493/// # References
1494///
1495/// Chebrolu, N., Läbe, T., Vysotska, O., Behley, J., & Stachniss, C. (2021).
1496/// Adaptive robust kernels for non-linear least squares problems.
1497/// IEEE Robotics and Automation Letters, 6(2), 2240-2247.
1498///
1499/// Barron, J. T. (2019). A general and adaptive robust loss function.
1500/// IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).
1501///
1502/// # Example
1503///
1504/// ```
1505/// use apex_solver::core::loss_functions::{LossFunction, AdaptiveBarronLoss};
1506/// # use apex_solver::error::ApexSolverResult;
1507/// # fn example() -> ApexSolverResult<()> {
1508///
1509/// // Default Cauchy-like behavior
1510/// let adaptive = AdaptiveBarronLoss::new(0.0, 1.0)?;
1511///
1512/// let [rho, rho_prime, _] = adaptive.evaluate(4.0);
1513/// // Adaptive robust behavior
1514/// # Ok(())
1515/// # }
1516/// # example().unwrap();
1517/// ```
1518#[derive(Debug, Clone)]
1519pub struct AdaptiveBarronLoss {
1520 inner: BarronGeneralLoss,
1521}
1522
1523impl AdaptiveBarronLoss {
1524 /// Create a new adaptive Barron loss function.
1525 ///
1526 /// # Arguments
1527 ///
1528 /// * `alpha` - Shape parameter (default: 0.0 for Cauchy-like)
1529 /// * `scale` - Scale parameter c (must be positive)
1530 ///
1531 /// # Recommended Defaults
1532 ///
1533 /// - α = 0.0, c = 1.0: General-purpose robust loss
1534 pub fn new(alpha: f64, scale: f64) -> ApexSolverResult<Self> {
1535 Ok(AdaptiveBarronLoss {
1536 inner: BarronGeneralLoss::new(alpha, scale)?,
1537 })
1538 }
1539
1540 /// Create default instance without validation (alpha=0.0, scale=1.0).
1541 ///
1542 /// This is safe because the default parameters are mathematically valid.
1543 const fn new_default() -> Self {
1544 AdaptiveBarronLoss {
1545 inner: BarronGeneralLoss {
1546 alpha: 0.0,
1547 scale: 1.0,
1548 scale2: 1.0,
1549 },
1550 }
1551 }
1552}
1553
1554impl LossFunction for AdaptiveBarronLoss {
1555 fn evaluate(&self, s: f64) -> [f64; 3] {
1556 self.inner.evaluate(s)
1557 }
1558}
1559
1560impl Default for AdaptiveBarronLoss {
1561 /// Creates default AdaptiveBarronLoss with validated parameters (alpha=0.0, scale=1.0).
1562 fn default() -> Self {
1563 Self::new_default()
1564 }
1565}
1566
1567#[cfg(test)]
1568mod tests {
1569 use super::*;
1570
1571 type TestResult = Result<(), Box<dyn std::error::Error>>;
1572
1573 const EPSILON: f64 = 1e-6;
1574
1575 /// Helper function to test derivatives numerically
1576 fn numerical_derivative(loss: &dyn LossFunction, s: f64, h: f64) -> (f64, f64) {
1577 let [rho_plus, _, _] = loss.evaluate(s + h);
1578 let [rho_minus, _, _] = loss.evaluate(s - h);
1579 let [rho, _, _] = loss.evaluate(s);
1580
1581 // First derivative: f'(x) ≈ (f(x+h) - f(x-h)) / (2h)
1582 let rho_prime_numerical = (rho_plus - rho_minus) / (2.0 * h);
1583
1584 // Second derivative: f''(x) ≈ (f(x+h) - 2f(x) + f(x-h)) / h²
1585 let rho_double_prime_numerical = (rho_plus - 2.0 * rho + rho_minus) / (h * h);
1586
1587 (rho_prime_numerical, rho_double_prime_numerical)
1588 }
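
    // Added sketch: a finite-difference consistency check. Since `evaluate` is documented as
    // returning [ρ(s), ρ'(s), ρ''(s)], the analytic derivatives of a few smooth losses should
    // agree with numerical differentiation of ρ itself. Test name and tolerances are
    // illustrative choices, not part of the original suite.
    #[test]
    fn test_derivative_consistency_numerical() -> TestResult {
        let huber = HuberLoss::new(1.345)?;
        let cauchy = CauchyLoss::new(2.3849)?;
        let fair = FairLoss::new(1.3998)?;
        let losses: [&dyn LossFunction; 3] = [&huber, &cauchy, &fair];

        for &loss in losses.iter() {
            // Check an inlier-sized and an outlier-sized squared residual.
            for &s in &[0.5, 10.0] {
                let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
                let (rho_prime_num, rho_double_prime_num) = numerical_derivative(loss, s, 1e-5);
                assert!((rho_prime - rho_prime_num).abs() < 1e-4);
                assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
            }
        }

        Ok(())
    }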
1589
1590 #[test]
1591 fn test_l2_loss() -> TestResult {
1592 let loss = L2Loss;
1593
1594 // Test at s = 0
1595 let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
1596 assert_eq!(rho, 0.0);
1597 assert_eq!(rho_prime, 1.0);
1598 assert_eq!(rho_double_prime, 0.0);
1599
1600 // Test at s = 4.0
1601 let [rho, rho_prime, rho_double_prime] = loss.evaluate(4.0);
1602 assert_eq!(rho, 4.0);
1603 assert_eq!(rho_prime, 1.0);
1604 assert_eq!(rho_double_prime, 0.0);
1605
1606 Ok(())
1607 }
1608
1609 #[test]
1610 fn test_l1_loss() -> TestResult {
1611 let loss = L1Loss;
1612
1613 // Test at s = 0 (should handle gracefully)
1614 let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
1615 assert_eq!(rho, 0.0);
1616 assert!(rho_prime.is_finite());
1617 assert!(rho_double_prime.is_finite());
1618
1619 // Test at s = 4.0 (√s = 2.0, ρ(s) = 2√s = 4.0)
1620 let [rho, rho_prime, _] = loss.evaluate(4.0);
1621 assert!((rho - 4.0).abs() < EPSILON); // ρ(s) = 2√s = 2*2 = 4
1622 assert!((rho_prime - 0.5).abs() < EPSILON); // ρ'(s) = 1/√s = 1/2 = 0.5
1623
1624 Ok(())
1625 }
1626
1627 #[test]
1628 fn test_fair_loss() -> TestResult {
1629 let loss = FairLoss::new(1.3999)?;
1630
        // Test at s = 0 (special case: quadratic fallback s/2 with weight 1/2)
        let [rho, rho_prime, rho_double_prime] = loss.evaluate(0.0);
        assert_eq!(rho, 0.0);
        assert_eq!(rho_prime, 0.5);
        assert_eq!(rho_double_prime, 0.0);

        // Test inlier region (s = 1.0)
        let [_, rho_prime, _] = loss.evaluate(1.0);
        assert!(rho_prime > 0.28 && rho_prime < 0.31); // ρ'(s) = c/(2(c+|x|)) with x=1, c≈1.4 → ~0.292
1640
1641 // Test outlier region (s = 100.0)
1642 let [_, rho_prime_outlier, _] = loss.evaluate(100.0);
1643 assert!(rho_prime_outlier < rho_prime); // Weight should decrease
1644
1645 // Test that derivatives are finite and reasonable
1646 let [_, rho_prime_4, rho_double_prime_4] = loss.evaluate(4.0);
1647 assert!(rho_prime_4.is_finite() && rho_prime_4 > 0.0);
1648 assert!(rho_double_prime_4.is_finite() && rho_double_prime_4 < 0.0); // Convex near origin
1649
1650 Ok(())
1651 }
1652
1653 #[test]
1654 fn test_geman_mcclure_loss() -> TestResult {
1655 let loss = GemanMcClureLoss::new(1.0)?;
1656
1657 // Test at s = 0
1658 let [rho, rho_prime, _] = loss.evaluate(0.0);
1659 assert_eq!(rho, 0.0);
1660 assert!((rho_prime - 1.0).abs() < EPSILON);
1661
1662 // Test outlier suppression
1663 let [_, rho_prime_small, _] = loss.evaluate(1.0);
1664 let [_, rho_prime_large, _] = loss.evaluate(100.0);
1665 assert!(rho_prime_large < rho_prime_small);
1666 assert!(rho_prime_large < 0.1); // Strong suppression
1667
1668 // Verify derivatives
1669 let s = 2.0;
1670 let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1671 let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1672 assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1673 assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1674
1675 Ok(())
1676 }
1677
1678 #[test]
1679 fn test_welsch_loss() -> TestResult {
1680 let loss = WelschLoss::new(2.9846)?;
1681
1682 // Test at s = 0: ρ'(0) = 0.5 * exp(0) = 0.5
1683 let [rho, rho_prime, _] = loss.evaluate(0.0);
1684 assert_eq!(rho, 0.0);
1685 assert!((rho_prime - 0.5).abs() < EPSILON); // ρ'(s) = 0.5 * exp(-s/c²), at s=0: 0.5
1686
1687 // Test redescending behavior
1688 let [_, rho_prime_10, _] = loss.evaluate(10.0);
1689 let [_, rho_prime_100, _] = loss.evaluate(100.0);
1690 assert!(rho_prime_100 < rho_prime_10);
1691 assert!(rho_prime_100 < 0.01); // Nearly zero for large outliers
1692
1693 // Verify derivatives
1694 let s = 5.0;
1695 let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1696 let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1697 assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1698 assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1699
1700 Ok(())
1701 }
1702
1703 #[test]
1704 fn test_tukey_biweight_loss() -> TestResult {
1705 let loss = TukeyBiweightLoss::new(4.6851)?;
1706
1707 // Test at s = 0: ρ'(0) = 0.5 * (1-0)^2 = 0.5
1708 let [rho, rho_prime, _] = loss.evaluate(0.0);
1709 assert_eq!(rho, 0.0);
1710 assert!((rho_prime - 0.5).abs() < EPSILON); // ρ'(s) = 0.5 * (1 - ratio²)², at s=0: 0.5
1711
1712 // Test within threshold
1713 let scale2 = 4.6851 * 4.6851;
1714 let [_, rho_prime_in, _] = loss.evaluate(scale2 * 0.5);
1715 assert!(rho_prime_in > 0.05);
1716
1717 // Test beyond threshold (complete suppression)
1718 let [_, rho_prime_out, _] = loss.evaluate(scale2 * 1.5);
1719 assert_eq!(rho_prime_out, 0.0);
1720
1721 // Test that derivatives are finite and reasonable
1722 let [_, rho_prime_5, rho_double_prime_5] = loss.evaluate(5.0);
1723 assert!(rho_prime_5.is_finite() && rho_prime_5 > 0.0);
1724 assert!(rho_double_prime_5.is_finite() && rho_double_prime_5 < 0.0);
1725
1726 Ok(())
1727 }
1728
1729 #[test]
1730 fn test_andrews_wave_loss() -> TestResult {
1731 let loss = AndrewsWaveLoss::new(1.339)?;
1732
1733 // Test at s = 0: ρ'(0) = 0.5 * sin(0) = 0
1734 let [rho, rho_prime, _] = loss.evaluate(0.0);
1735 assert_eq!(rho, 0.0);
1736 assert!(rho_prime.abs() < EPSILON); // ρ'(s) = 0.5 * sin(x/c), at s=0: 0
1737
1738 // Test within threshold (small s where sin(x/c) gives moderate weight)
1739 let [_, rho_prime_in, _] = loss.evaluate(1.0);
1740 assert!(rho_prime_in > 0.33 && rho_prime_in < 0.35); // ~0.3397
1741
1742 // Test beyond threshold
1743 let scale = 1.339;
1744 let [_, rho_prime_out, _] = loss.evaluate((scale * std::f64::consts::PI + 0.1).powi(2));
1745 assert!(rho_prime_out.abs() < 0.01);
1746
1747 // Test that derivatives are finite
1748 let [_, rho_prime_1, rho_double_prime_1] = loss.evaluate(1.0);
1749 assert!(rho_prime_1.is_finite() && rho_prime_1 > 0.0);
1750 assert!(rho_double_prime_1.is_finite());
1751
1752 Ok(())
1753 }
1754
1755 #[test]
1756 fn test_ramsay_ea_loss() -> TestResult {
1757 let loss = RamsayEaLoss::new(0.3)?;
1758
1759 // Test at s = 0 (should handle gracefully)
1760 let [rho, _, _] = loss.evaluate(0.0);
1761 assert_eq!(rho, 0.0);
1762
1763 // Test exponential decay behavior
1764 let [_, rho_prime_small, _] = loss.evaluate(1.0);
1765 let [_, rho_prime_large, _] = loss.evaluate(100.0);
1766 assert!(rho_prime_large < rho_prime_small);
1767
1768 // Verify derivatives
1769 let s = 4.0;
1770 let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1771 let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1772 assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1773 assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1774
1775 Ok(())
1776 }
1777
1778 #[test]
1779 fn test_trimmed_mean_loss() -> TestResult {
1780 let loss = TrimmedMeanLoss::new(2.0)?;
1781 let scale2 = 4.0;
1782
1783 // Test below threshold (L2 behavior)
1784 let [rho, rho_prime, rho_double_prime] = loss.evaluate(2.0);
1785 assert!((rho - 1.0).abs() < EPSILON);
1786 assert!((rho_prime - 0.5).abs() < EPSILON);
1787 assert_eq!(rho_double_prime, 0.0);
1788
1789 // Test above threshold (constant)
1790 let [rho_out, rho_prime_out, rho_double_prime_out] = loss.evaluate(10.0);
1791 assert!((rho_out - scale2 / 2.0).abs() < EPSILON);
1792 assert_eq!(rho_prime_out, 0.0);
1793 assert_eq!(rho_double_prime_out, 0.0);
1794
1795 Ok(())
1796 }
1797
1798 #[test]
1799 fn test_lp_norm_loss() -> TestResult {
1800 // Test L1 (p = 1)
1801 let l1 = LpNormLoss::new(1.0)?;
1802 let [rho_l1, _, _] = l1.evaluate(4.0);
1803 assert!((rho_l1 - 2.0).abs() < EPSILON); // ||r||₁ = 2
1804
1805 // Test L2 (p = 2)
1806 let l2 = LpNormLoss::new(2.0)?;
1807 let [rho_l2, rho_prime_l2, rho_double_prime_l2] = l2.evaluate(4.0);
1808 assert!((rho_l2 - 4.0).abs() < EPSILON);
1809 assert!((rho_prime_l2 - 1.0).abs() < EPSILON);
1810 assert_eq!(rho_double_prime_l2, 0.0);
1811
1812 // Test fractional p (p = 0.5)
1813 let l05 = LpNormLoss::new(0.5)?;
1814 let [_, rho_prime_05, _] = l05.evaluate(4.0);
1815 assert!(rho_prime_05 < 1.0); // Robust behavior
1816
1817 // Verify derivatives for p = 1.5
1818 let loss = LpNormLoss::new(1.5)?;
1819 let s = 4.0;
1820 let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1821 let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1822 assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1823 assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-3);
1824
1825 Ok(())
1826 }
1827
1828 #[test]
1829 fn test_barron_general_loss_special_cases() -> TestResult {
1830 // α = 0 (Cauchy-like)
1831 let cauchy = BarronGeneralLoss::new(0.0, 1.0)?;
1832 let [_, rho_prime_small, _] = cauchy.evaluate(1.0);
1833 let [_, rho_prime_large, _] = cauchy.evaluate(100.0);
1834 assert!(rho_prime_large < rho_prime_small);
1835
1836 // α = 2 (L2)
1837 let l2 = BarronGeneralLoss::new(2.0, 1.0)?;
1838 let [rho, rho_prime, rho_double_prime] = l2.evaluate(4.0);
1839 assert!((rho - 4.0).abs() < EPSILON);
1840 assert!((rho_prime - 1.0).abs() < EPSILON);
1841 assert!(rho_double_prime.abs() < EPSILON);
1842
1843 // α = 1 (Charbonnier-like)
1844 let charbonnier = BarronGeneralLoss::new(1.0, 1.0)?;
1845 let [_, rho_prime_char, _] = charbonnier.evaluate(4.0);
1846 assert!(rho_prime_char > 0.0 && rho_prime_char < 1.0);
1847
1848 // Test α = -2 (Geman-McClure-like) - strong outlier suppression
1849 let gm = BarronGeneralLoss::new(-2.0, 1.0)?;
1850 let [_, rho_prime_small, _] = gm.evaluate(1.0);
1851 let [_, rho_prime_large, _] = gm.evaluate(100.0);
1852 assert!(rho_prime_large < rho_prime_small); // Redescending behavior
1853 assert!(rho_prime_large < 0.1); // Strong suppression
1854
1855 Ok(())
1856 }
1857
1858 #[test]
1859 fn test_constructor_validation() -> TestResult {
1860 // Test that negative or zero scale parameters are rejected
1861 assert!(FairLoss::new(0.0).is_err());
1862 assert!(FairLoss::new(-1.0).is_err());
1863 assert!(GemanMcClureLoss::new(0.0).is_err());
1864 assert!(WelschLoss::new(-1.0).is_err());
1865 assert!(TukeyBiweightLoss::new(0.0).is_err());
1866 assert!(AndrewsWaveLoss::new(-1.0).is_err());
1867 assert!(RamsayEaLoss::new(0.0).is_err());
1868 assert!(TrimmedMeanLoss::new(-1.0).is_err());
1869 assert!(BarronGeneralLoss::new(0.0, 0.0).is_err());
1870 assert!(BarronGeneralLoss::new(1.0, -1.0).is_err());
1871
1872 // Test that p ≤ 0 is rejected for LpNormLoss
1873 assert!(LpNormLoss::new(0.0).is_err());
1874 assert!(LpNormLoss::new(-1.0).is_err());
1875
1876 // Test valid constructors
1877 assert!(FairLoss::new(1.0).is_ok());
1878 assert!(LpNormLoss::new(1.5).is_ok());
1879 assert!(BarronGeneralLoss::new(1.0, 1.0).is_ok());
1880
1881 Ok(())
1882 }
1883
1884 #[test]
1885 fn test_loss_comparison() -> TestResult {
1886 // Compare robustness: L2 vs Huber vs Cauchy at outlier
1887 let s_outlier = 100.0;
1888
1889 let l2 = L2Loss;
1890 let huber = HuberLoss::new(1.345)?;
1891 let cauchy = CauchyLoss::new(2.3849)?;
1892
1893 let [_, w_l2, _] = l2.evaluate(s_outlier);
1894 let [_, w_huber, _] = huber.evaluate(s_outlier);
1895 let [_, w_cauchy, _] = cauchy.evaluate(s_outlier);
1896
1897 // L2 should give highest weight (no robustness)
1898 assert!(w_l2 > w_huber);
1899 assert!(w_huber > w_cauchy);
1900
1901 // Cauchy should strongly suppress outliers
1902 assert!(w_cauchy < 0.1);
1903
1904 Ok(())
1905 }
1906
1907 #[test]
1908 fn test_t_distribution_loss() -> TestResult {
1909 let loss = TDistributionLoss::new(5.0)?;
1910
1911 // Test at s = 0 (should be well-defined)
1912 let [rho, rho_prime, _] = loss.evaluate(0.0);
1913 assert_eq!(rho, 0.0);
1914 assert!((rho_prime - 0.6).abs() < 0.01); // (ν+1)/(2ν) = 6/10 = 0.6
1915
1916 // Test heavy tail behavior (downweighting)
1917 let [_, rho_prime_small, _] = loss.evaluate(1.0);
1918 let [_, rho_prime_large, _] = loss.evaluate(100.0);
1919 assert!(rho_prime_large < rho_prime_small);
1920
1921 // Test that large outliers are strongly downweighted
1922 assert!(rho_prime_large < 0.1);
1923
1924 // Verify derivatives numerically
1925 let s = 4.0;
1926 let [_, rho_prime, rho_double_prime] = loss.evaluate(s);
1927 let (rho_prime_num, rho_double_prime_num) = numerical_derivative(&loss, s, 1e-5);
1928 assert!((rho_prime - rho_prime_num).abs() < 1e-4);
1929 assert!((rho_double_prime - rho_double_prime_num).abs() < 1e-4);
1930
1931 Ok(())
1932 }
1933
1934 #[test]
1935 fn test_t_distribution_loss_different_nu() -> TestResult {
1936 // Test that smaller ν is more robust
1937 let t3 = TDistributionLoss::new(3.0)?;
1938 let t10 = TDistributionLoss::new(10.0)?;
1939
1940 let s_outlier = 100.0;
1941 let [_, w_t3, _] = t3.evaluate(s_outlier);
1942 let [_, w_t10, _] = t10.evaluate(s_outlier);
1943
1944 // Smaller ν should downweight more aggressively
1945 assert!(w_t3 < w_t10);
1946
1947 Ok(())
1948 }
1949
1950 #[test]
1951 fn test_adaptive_barron_loss() -> TestResult {
1952 // Test default (Cauchy-like with α = 0)
1953 let adaptive = AdaptiveBarronLoss::new(0.0, 1.0)?;
1954
1955 // Test at s = 0
1956 let [rho, _, _] = adaptive.evaluate(0.0);
1957 assert!(rho.abs() < EPSILON);
1958
1959 // Test robustness
1960 let [_, rho_prime_small, _] = adaptive.evaluate(1.0);
1961 let [_, rho_prime_large, _] = adaptive.evaluate(100.0);
1962 assert!(rho_prime_large < rho_prime_small);
1963
1964 // AdaptiveBarron wraps BarronGeneral which is already tested,
1965 // so we just verify the wrapper works correctly
1966 let barron = BarronGeneralLoss::new(0.0, 1.0)?;
1967 let [rho_a, rho_prime_a, rho_double_prime_a] = adaptive.evaluate(4.0);
1968 let [rho_b, rho_prime_b, rho_double_prime_b] = barron.evaluate(4.0);
1969
1970 // Should match the underlying BarronGeneral exactly
1971 assert!((rho_a - rho_b).abs() < EPSILON);
1972 assert!((rho_prime_a - rho_prime_b).abs() < EPSILON);
1973 assert!((rho_double_prime_a - rho_double_prime_b).abs() < EPSILON);
1974
1975 Ok(())
1976 }
1977
1978 #[test]
1979 fn test_adaptive_barron_default() -> TestResult {
1980 // Test default constructor
1981 let adaptive = AdaptiveBarronLoss::default();
1982
1983 // Should behave like Cauchy
1984 let [_, rho_prime, _] = adaptive.evaluate(4.0);
1985 assert!(rho_prime > 0.0 && rho_prime < 1.0);
1986
1987 Ok(())
1988 }
1989
1990 #[test]
1991 fn test_new_loss_constructor_validation() -> TestResult {
1992 // T-distribution: reject non-positive degrees of freedom
1993 assert!(TDistributionLoss::new(0.0).is_err());
1994 assert!(TDistributionLoss::new(-1.0).is_err());
1995 assert!(TDistributionLoss::new(5.0).is_ok());
1996
1997 // Adaptive Barron: reject non-positive scale
1998 assert!(AdaptiveBarronLoss::new(0.0, 0.0).is_err());
1999 assert!(AdaptiveBarronLoss::new(1.0, -1.0).is_err());
2000 assert!(AdaptiveBarronLoss::new(0.0, 1.0).is_ok());
2001
2002 Ok(())
2003 }
2004}