// apex_solver/core/corrector.rs
//! Corrector algorithm for applying robust loss functions in optimization.
//!
//! The Corrector implements the algorithm from Ceres Solver for transforming a robust loss
//! problem into an equivalent reweighted least squares problem. Instead of modifying the
//! solver internals, the corrector adjusts the residuals and Jacobians before they are
//! passed to the linear solver.
//!
//! # Algorithm Overview
//!
//! Given a residual vector `r` and a robust loss function ρ(s) where `s = ||r||²`, the
//! corrector computes modified residuals and Jacobians such that:
//!
//! ```text
//! minimize Σ ρ(||r_i||²)  ≡  minimize Σ ||r̃_i||²
//! ```
//!
//! where `r̃` are the corrected residuals and the Jacobian is similarly adjusted.
//!
//! # Mathematical Formulation
//!
//! For a residual `r` with Jacobian `J = ∂r/∂x`, the corrector computes:
//!
//! 1. **Square norm**: `s = ||r||² = r^T r`
//! 2. **Loss evaluation**: `[ρ(s), ρ'(s), ρ''(s)]` from the loss function
//! 3. **Scaling factors**:
//!    ```text
//!    √ρ₁ = √(ρ'(s))                 (Jacobian scaling)
//!    α   = 1 - √(1 + 2·s·ρ''/ρ')   (root of ½α² - α - s·ρ''/ρ' = 0)
//!    ```
//!
//! 4. **Corrected residuals**: `r̃ = √ρ₁ / (1 - α) · r`
//! 5. **Corrected Jacobian**:
//!    ```text
//!    J̃ = √ρ₁ · (J - (α/s) · r · r^T · J)
//!    ```
//!
//! When `ρ''(s) ≤ 0` or `s = 0`, α = 0 and both corrections reduce to plain scaling
//! by `√ρ₁`. This ensures that `||r̃||² ≈ ρ(||r||²)` and the gradient is correct.
//!
//! # Reference
//!
//! Based on Ceres Solver implementation:
//! <https://github.com/ceres-solver/ceres-solver/blob/master/internal/ceres/corrector.cc>
//!
//! See also:
//! - Triggs et al., "Bundle Adjustment — A Modern Synthesis" (1999)
//! - Agarwal et al., "Ceres Solver" (<http://ceres-solver.org/>)
//!
//! # Example
//!
//! ```
//! use apex_solver::core::corrector::Corrector;
//! use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
//! use nalgebra::{DVector, DMatrix};
//! # use apex_solver::error::ApexSolverResult;
//! # fn example() -> ApexSolverResult<()> {
//!
//! // Create a robust loss function
//! let loss = HuberLoss::new(1.0)?;
//!
//! // Original residual and Jacobian
//! let residual = DVector::from_vec(vec![2.0, 3.0, 1.0]); // Large residual (outlier)
//! let jacobian = DMatrix::from_row_slice(3, 2, &[
//!     1.0, 0.0,
//!     0.0, 1.0,
//!     1.0, 1.0,
//! ]);
//!
//! // Compute squared norm
//! let squared_norm = residual.dot(&residual);
//!
//! // Create corrector
//! let corrector = Corrector::new(&loss, squared_norm);
//!
//! // Apply corrections
//! let mut corrected_jacobian = jacobian.clone();
//! let mut corrected_residual = residual.clone();
//!
//! corrector.correct_jacobian(&residual, &mut corrected_jacobian);
//! corrector.correct_residuals(&mut corrected_residual);
//!
//! // The corrected values now account for the robust loss function
//! // Outliers have been downweighted appropriately
//! # Ok(())
//! # }
//! # example().unwrap();
//! ```

use crate::core::loss_functions::LossFunction;
use nalgebra::{DMatrix, DVector};

/// Corrector for applying robust loss functions via residual and Jacobian adjustment.
///
/// This struct holds the precomputed scaling factors needed to transform a robust loss
/// problem into an equivalent reweighted least squares problem. It is instantiated once
/// per residual block during each iteration of the optimizer.
///
/// # Fields
///
/// - `sqrt_rho1`: √(ρ'(s)) — square root of the first derivative; the base scaling
///   applied to the Jacobian
/// - `residual_scaling`: the factor applied to the residual. Equals `sqrt_rho1` in the
///   common case (α = 0, i.e. ρ''(s) ≤ 0 or s = 0), and `√(ρ'(s)) / (1 - α)` when the
///   curvature correction is active — NOT always identical to `sqrt_rho1`
/// - `alpha_sq_norm`: α / s — coefficient of the rank-1 curvature correction applied to
///   the Jacobian, where α solves ½α² - α - s·ρ''/ρ' = 0; zero in the common case
///
/// where `s = ||r||²` is the squared norm of the residual.
#[derive(Debug, Clone)]
pub struct Corrector {
    sqrt_rho1: f64,
    residual_scaling: f64,
    alpha_sq_norm: f64,
}
112impl Corrector {
113    /// Create a new Corrector by evaluating the loss function at the given squared norm.
114    ///
115    /// # Arguments
116    ///
117    /// * `loss_function` - The robust loss function ρ(s)
118    /// * `sq_norm` - The squared norm of the residual: `s = ||r||²`
119    ///
120    /// # Returns
121    ///
122    /// A `Corrector` instance with precomputed scaling factors
123    ///
124    /// # Example
125    ///
126    /// ```
127    /// use apex_solver::core::corrector::Corrector;
128    /// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
129    /// use nalgebra::DVector;
130    /// # use apex_solver::error::ApexSolverResult;
131    /// # fn example() -> ApexSolverResult<()> {
132    ///
133    /// let loss = HuberLoss::new(1.0)?;
134    /// let residual = DVector::from_vec(vec![1.0, 2.0, 3.0]);
135    /// let squared_norm = residual.dot(&residual); // 14.0
136    ///
137    /// let corrector = Corrector::new(&loss, squared_norm);
138    /// // corrector is now ready to apply corrections
139    /// # Ok(())
140    /// # }
141    /// # example().unwrap();
142    /// ```
143    pub fn new(loss_function: &dyn LossFunction, sq_norm: f64) -> Self {
144        // Evaluate loss function: [ρ(s), ρ'(s), ρ''(s)]
145        let rho = loss_function.evaluate(sq_norm);
146
147        // Extract derivatives
148        let rho_1 = rho[1]; // ρ'(s)
149        let rho_2 = rho[2]; // ρ''(s)
150
151        // Compute scaling factors
152        let sqrt_rho1 = rho_1.sqrt(); // √(ρ'(s))
153
154        // Handle special cases (common case: rho[2] <= 0)
155        // This occurs when the loss function has no curvature correction needed
156        if sq_norm == 0.0 || rho_2 <= 0.0 {
157            return Self {
158                sqrt_rho1,
159                residual_scaling: sqrt_rho1,
160                alpha_sq_norm: 0.0,
161            };
162        }
163
164        // Compute alpha by solving the quadratic equation:
165        // 0.5·α² - α - (ρ''/ρ')·s = 0
166        //
167        // This gives: α = 1 - √(1 + 2·s·ρ''/ρ')
168        //
169        // Reference: Ceres Solver corrector.cc
170        // https://github.com/ceres-solver/ceres-solver/blob/master/internal/ceres/corrector.cc
171        // Clamp to 0.0 to prevent NaN from sqrt when d is negative
172        // (matches Ceres Solver corrector.cc behavior)
173        let d = (1.0 + 2.0 * sq_norm * rho_2 / rho_1).max(0.0);
174        let alpha = 1.0 - d.sqrt();
175
176        Self {
177            sqrt_rho1,
178            residual_scaling: sqrt_rho1 / (1.0 - alpha),
179            alpha_sq_norm: alpha / sq_norm,
180        }
181    }
182
183    /// Apply correction to the Jacobian matrix.
184    ///
185    /// Transforms the Jacobian `J` into `J̃` according to the Ceres Solver corrector algorithm:
186    ///
187    /// ```text
188    /// J̃ = √(ρ'(s)) · (J - α²·r·r^T·J)
189    /// ```
190    ///
191    /// where:
192    /// - `√(ρ'(s))` scales the Jacobian by the loss function weight
193    /// - `α` is computed by solving the quadratic equation: 0.5·α² - α - (ρ''/ρ')·s = 0
194    /// - The subtractive term `α²·r·r^T·J` is a rank-1 curvature correction
195    ///
196    /// # Arguments
197    ///
198    /// * `residual` - The original residual vector `r`
199    /// * `jacobian` - Mutable reference to the Jacobian matrix (modified in-place)
200    ///
201    /// # Implementation Notes
202    ///
203    /// The correction is applied in-place for efficiency. The algorithm:
204    /// 1. Scales all Jacobian entries by `√(ρ'(s))`
205    /// 2. Adds the outer product correction: `α · (J^T r) · r^T / ||r||`
206    ///
207    /// # Example
208    ///
209    /// ```
210    /// use apex_solver::core::corrector::Corrector;
211    /// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
212    /// use nalgebra::{DVector, DMatrix};
213    /// # use apex_solver::error::ApexSolverResult;
214    /// # fn example() -> ApexSolverResult<()> {
215    ///
216    /// let loss = HuberLoss::new(1.0)?;
217    /// let residual = DVector::from_vec(vec![2.0, 1.0]);
218    /// let squared_norm = residual.dot(&residual);
219    ///
220    /// let corrector = Corrector::new(&loss, squared_norm);
221    ///
222    /// let mut jacobian = DMatrix::from_row_slice(2, 3, &[
223    ///     1.0, 0.0, 1.0,
224    ///     0.0, 1.0, 1.0,
225    /// ]);
226    ///
227    /// corrector.correct_jacobian(&residual, &mut jacobian);
228    /// // jacobian is now corrected to account for the robust loss
229    /// # Ok(())
230    /// # }
231    /// # example().unwrap();
232    /// ```
233    pub fn correct_jacobian(&self, residual: &DVector<f64>, jacobian: &mut DMatrix<f64>) {
234        // Common case (rho[2] <= 0): only apply first-order correction
235        // This is the most common scenario for well-behaved loss functions
236        if self.alpha_sq_norm == 0.0 {
237            *jacobian *= self.sqrt_rho1;
238            return;
239        }
240
241        // Full correction with curvature term:
242        // J̃ = √ρ₁ · (J - α²·r·r^T·J)
243        //
244        // This is the correct Ceres Solver algorithm:
245        // 1. Compute r·r^T·J (outer product of residual with Jacobian)
246        // 2. Subtract α²·r·r^T·J from J
247        // 3. Scale result by √ρ₁
248        //
249        // Reference: Ceres Solver corrector.cc
250        // https://github.com/ceres-solver/ceres-solver/blob/master/internal/ceres/corrector.cc
251
252        let r_rtj = residual * residual.transpose() * jacobian.clone();
253        *jacobian = (jacobian.clone() - r_rtj * self.alpha_sq_norm) * self.sqrt_rho1;
254    }
255
256    /// Apply correction to the residual vector.
257    ///
258    /// Transforms the residual `r` into `r̃` by scaling:
259    ///
260    /// ```text
261    /// r̃ = √(ρ'(s)) · r
262    /// ```
263    ///
264    /// This ensures that `||r̃||² ≈ ρ(||r||²)`, i.e., the squared norm of the corrected
265    /// residual approximates the robust cost.
266    ///
267    /// # Arguments
268    ///
269    /// * `residual` - Mutable reference to the residual vector (modified in-place)
270    ///
271    /// # Example
272    ///
273    /// ```
274    /// use apex_solver::core::corrector::Corrector;
275    /// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
276    /// use nalgebra::DVector;
277    /// # use apex_solver::error::ApexSolverResult;
278    /// # fn example() -> ApexSolverResult<()> {
279    ///
280    /// let loss = HuberLoss::new(1.0)?;
281    /// let mut residual = DVector::from_vec(vec![2.0, 3.0, 1.0]);
282    /// let squared_norm = residual.dot(&residual);
283    ///
284    /// let corrector = Corrector::new(&loss, squared_norm);
285    ///
286    /// corrector.correct_residuals(&mut residual);
287    /// // Outlier residuals are scaled down
288    /// # Ok(())
289    /// # }
290    /// # example().unwrap();
291    /// ```
292    pub fn correct_residuals(&self, residual: &mut DVector<f64>) {
293        // Simple scaling: r̃ = √(ρ'(s)) · r
294        //
295        // This downweights outliers (where ρ'(s) < 1) and leaves inliers
296        // approximately unchanged (where ρ'(s) ≈ 1)
297        *residual *= self.residual_scaling;
298    }
299}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::loss_functions::{CauchyLoss, HuberLoss};

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    #[test]
    fn test_corrector_huber_inlier() -> TestResult {
        // Inside the Huber threshold ρ'(s) = 1 and ρ''(s) = 0, so the corrector
        // must act (numerically) as the identity.
        let loss = HuberLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![0.1, 0.2, 0.1]);
        let sq_norm = residual.dot(&residual); // 0.06, well below the threshold

        let corrector = Corrector::new(&loss, sq_norm);
        assert!((corrector.sqrt_rho1 - 1.0).abs() < 1e-10);
        assert!(corrector.alpha_sq_norm.abs() < 1e-10); // no curvature correction

        let mut corrected = residual.clone();
        corrector.correct_residuals(&mut corrected);
        assert!((corrected - residual).norm() < 1e-10);

        Ok(())
    }

    #[test]
    fn test_corrector_huber_outlier() -> TestResult {
        // Beyond the threshold ρ'(s) < 1, so the residual must be downweighted.
        let loss = HuberLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![5.0, 5.0, 5.0]);
        let sq_norm = residual.dot(&residual); // 75.0

        let corrector = Corrector::new(&loss, sq_norm);
        assert!(corrector.sqrt_rho1 < 1.0 && corrector.sqrt_rho1 > 0.0);

        let mut corrected = residual.clone();
        corrector.correct_residuals(&mut corrected);
        assert!(corrected.norm() < residual.norm());

        Ok(())
    }

    #[test]
    fn test_corrector_cauchy() -> TestResult {
        // Cauchy loss heavily downweights large residuals as well.
        let loss = CauchyLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![2.0, 3.0]);
        let sq_norm = residual.dot(&residual); // 13.0

        let corrector = Corrector::new(&loss, sq_norm);
        assert!(corrector.sqrt_rho1 < 1.0 && corrector.sqrt_rho1 > 0.0);

        let mut corrected = residual.clone();
        corrector.correct_residuals(&mut corrected);
        assert!(corrected.norm() < residual.norm());

        Ok(())
    }

    #[test]
    fn test_corrector_jacobian() -> TestResult {
        // The Jacobian of an outlier block must be modified by the correction.
        // (Exact values depend on the loss function derivatives.)
        let loss = HuberLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![2.0, 1.0]);
        let sq_norm = residual.dot(&residual);

        let corrector = Corrector::new(&loss, sq_norm);
        let original = DMatrix::from_row_slice(2, 3, &[1.0, 0.0, 1.0, 0.0, 1.0, 1.0]);
        let mut jacobian = original.clone();

        corrector.correct_jacobian(&residual, &mut jacobian);
        assert!(jacobian != original);

        Ok(())
    }

    /// Degenerate loss with ρ'(s) = -0.1 and ρ''(s) = 0.5. Because ρ' < 0, the
    /// discriminant d = 1 + 2·s·ρ''/ρ' = 1 - 10·s goes negative for s > 0.1,
    /// which would make `d.sqrt()` NaN without the clamp in `Corrector::new`.
    struct EdgeCaseLoss;

    impl LossFunction for EdgeCaseLoss {
        fn evaluate(&self, _s: f64) -> [f64; 3] {
            [0.5, -0.1, 0.5] // [ρ, ρ', ρ'']
        }
    }

    #[test]
    fn test_corrector_no_nan_on_negative_d() {
        // sq_norm = 100 drives the discriminant far negative; construction must
        // not panic, and the clamp keeps d.sqrt() finite (d = 0, α = 1).
        let loss = EdgeCaseLoss;
        let corrector = Corrector::new(&loss, 100.0);

        // sqrt_rho1 = √(-0.1) is NaN by construction — expected for an invalid
        // loss, since the corrector assumes ρ'(s) ≥ 0. The clamp only protects
        // the α computation, which is what this test exercises.
        assert!(corrector.sqrt_rho1.is_nan());
    }

    #[test]
    fn test_corrector_positive_rho1_large_rho2_ratio() {
        // Valid loss with ρ' > 0 and ρ'' > 0: here d = 1 + 2·s·ρ''/ρ' ≥ 1, so the
        // clamp is inert (it guards only floating-point edge cases) and every
        // precomputed factor must come out finite.
        struct HighCurvatureLoss;
        impl LossFunction for HighCurvatureLoss {
            fn evaluate(&self, _s: f64) -> [f64; 3] {
                [0.5, 0.001, 0.001]
            }
        }

        let loss = HighCurvatureLoss;
        let corrector = Corrector::new(&loss, 10.0);

        assert!(!corrector.sqrt_rho1.is_nan());
        assert!(!corrector.residual_scaling.is_nan());
        assert!(!corrector.alpha_sq_norm.is_nan());
    }
}