// apex_solver/core/corrector.rs
1//! Corrector algorithm for applying robust loss functions in optimization.
2//!
3//! The Corrector implements the algorithm from Ceres Solver for transforming a robust loss
4//! problem into an equivalent reweighted least squares problem. Instead of modifying the
5//! solver internals, the corrector adjusts the residuals and Jacobians before they are
6//! passed to the linear solver.
7//!
8//! # Algorithm Overview
9//!
10//! Given a residual vector `r` and a robust loss function ρ(s) where `s = ||r||²`, the
11//! corrector computes modified residuals and Jacobians such that:
12//!
13//! ```text
14//! minimize Σ ρ(||r_i||²) ≡ minimize Σ ||r̃_i||²
15//! ```
16//!
17//! where `r̃` are the corrected residuals and the Jacobian is similarly adjusted.
18//!
19//! # Mathematical Formulation
20//!
21//! For a residual `r` with Jacobian `J = ∂r/∂x`, the corrector computes:
22//!
23//! 1. **Square norm**: `s = ||r||² = r^T r`
24//! 2. **Loss evaluation**: `[ρ(s), ρ'(s), ρ''(s)]` from the loss function
25//! 3. **Scaling factors**:
26//! ```text
27//! √ρ₁ = √(ρ'(s)) (residual scaling)
28//! α² = ρ''(s) / ρ'(s) (Jacobian correction factor)
29//! ```
30//!
//! 4. **Corrected residuals**: `r̃ = (√ρ₁ / (1 - α)) · r`
//! 5. **Corrected Jacobian**:
//! ```text
//! J̃ = √ρ₁ · (J - (α/||r||²) · r · r^T · J)
//! ```
//!
//! This ensures that `||r̃||² ≈ ρ(||r||²)` and the gradient is correct.
38//!
39//! # Reference
40//!
41//! Based on Ceres Solver implementation:
42//! <https://github.com/ceres-solver/ceres-solver/blob/master/internal/ceres/corrector.cc>
43//!
44//! See also:
45//! - Triggs et al., "Bundle Adjustment — A Modern Synthesis" (1999)
46//! - Agarwal et al., "Ceres Solver" (<http://ceres-solver.org/>)
47//!
48//! # Example
49//!
50//! ```
51//! use apex_solver::core::corrector::Corrector;
52//! use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
53//! use nalgebra::{DVector, DMatrix};
54//! # use apex_solver::error::ApexSolverResult;
55//! # fn example() -> ApexSolverResult<()> {
56//!
57//! // Create a robust loss function
58//! let loss = HuberLoss::new(1.0)?;
59//!
60//! // Original residual and Jacobian
61//! let residual = DVector::from_vec(vec![2.0, 3.0, 1.0]); // Large residual (outlier)
62//! let jacobian = DMatrix::from_row_slice(3, 2, &[
63//! 1.0, 0.0,
64//! 0.0, 1.0,
65//! 1.0, 1.0,
66//! ]);
67//!
68//! // Compute squared norm
69//! let squared_norm = residual.dot(&residual);
70//!
71//! // Create corrector
72//! let corrector = Corrector::new(&loss, squared_norm);
73//!
74//! // Apply corrections
75//! let mut corrected_jacobian = jacobian.clone();
76//! let mut corrected_residual = residual.clone();
77//!
78//! corrector.correct_jacobian(&residual, &mut corrected_jacobian);
79//! corrector.correct_residuals(&mut corrected_residual);
80//!
81//! // The corrected values now account for the robust loss function
82//! // Outliers have been downweighted appropriately
83//! # Ok(())
84//! # }
85//! # example().unwrap();
86//! ```
87
88use crate::core::loss_functions::LossFunction;
89use nalgebra::{DMatrix, DVector};
90
/// Corrector for applying robust loss functions via residual and Jacobian adjustment.
///
/// This struct holds the precomputed scaling factors needed to transform a robust loss
/// problem into an equivalent reweighted least squares problem. It is instantiated once
/// per residual block during each iteration of the optimizer.
///
/// # Fields
///
/// - `sqrt_rho1`: √(ρ'(s)) — square root of the loss function's first derivative,
///   used to scale the Jacobian
/// - `residual_scaling`: √(ρ'(s)) / (1 − α) — multiplier applied to the residual
///   vector; equals `sqrt_rho1` in the common case where no curvature correction
///   is applied (α = 0)
/// - `alpha_sq_norm`: α / s — coefficient of the rank-1 Jacobian correction, where
///   α solves the quadratic 0.5·α² − α − (ρ''/ρ')·s = 0; zero when ρ''(s) ≤ 0 or
///   the residual is zero
///
/// where `s = ||r||²` is the squared norm of the residual.
#[derive(Debug, Clone)]
pub struct Corrector {
    // √(ρ'(s)); NaN only for a degenerate loss reporting ρ'(s) < 0.
    sqrt_rho1: f64,
    // Scaling applied to residuals in `correct_residuals`.
    residual_scaling: f64,
    // α / ||r||²; zero in the common (no curvature correction) case.
    alpha_sq_norm: f64,
}
111
112impl Corrector {
113 /// Create a new Corrector by evaluating the loss function at the given squared norm.
114 ///
115 /// # Arguments
116 ///
117 /// * `loss_function` - The robust loss function ρ(s)
118 /// * `sq_norm` - The squared norm of the residual: `s = ||r||²`
119 ///
120 /// # Returns
121 ///
122 /// A `Corrector` instance with precomputed scaling factors
123 ///
124 /// # Example
125 ///
126 /// ```
127 /// use apex_solver::core::corrector::Corrector;
128 /// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
129 /// use nalgebra::DVector;
130 /// # use apex_solver::error::ApexSolverResult;
131 /// # fn example() -> ApexSolverResult<()> {
132 ///
133 /// let loss = HuberLoss::new(1.0)?;
134 /// let residual = DVector::from_vec(vec![1.0, 2.0, 3.0]);
135 /// let squared_norm = residual.dot(&residual); // 14.0
136 ///
137 /// let corrector = Corrector::new(&loss, squared_norm);
138 /// // corrector is now ready to apply corrections
139 /// # Ok(())
140 /// # }
141 /// # example().unwrap();
142 /// ```
143 pub fn new(loss_function: &dyn LossFunction, sq_norm: f64) -> Self {
144 // Evaluate loss function: [ρ(s), ρ'(s), ρ''(s)]
145 let rho = loss_function.evaluate(sq_norm);
146
147 // Extract derivatives
148 let rho_1 = rho[1]; // ρ'(s)
149 let rho_2 = rho[2]; // ρ''(s)
150
151 // Compute scaling factors
152 let sqrt_rho1 = rho_1.sqrt(); // √(ρ'(s))
153
154 // Handle special cases (common case: rho[2] <= 0)
155 // This occurs when the loss function has no curvature correction needed
156 if sq_norm == 0.0 || rho_2 <= 0.0 {
157 return Self {
158 sqrt_rho1,
159 residual_scaling: sqrt_rho1,
160 alpha_sq_norm: 0.0,
161 };
162 }
163
164 // Compute alpha by solving the quadratic equation:
165 // 0.5·α² - α - (ρ''/ρ')·s = 0
166 //
167 // This gives: α = 1 - √(1 + 2·s·ρ''/ρ')
168 //
169 // Reference: Ceres Solver corrector.cc
170 // https://github.com/ceres-solver/ceres-solver/blob/master/internal/ceres/corrector.cc
171 // Clamp to 0.0 to prevent NaN from sqrt when d is negative
172 // (matches Ceres Solver corrector.cc behavior)
173 let d = (1.0 + 2.0 * sq_norm * rho_2 / rho_1).max(0.0);
174 let alpha = 1.0 - d.sqrt();
175
176 Self {
177 sqrt_rho1,
178 residual_scaling: sqrt_rho1 / (1.0 - alpha),
179 alpha_sq_norm: alpha / sq_norm,
180 }
181 }
182
183 /// Apply correction to the Jacobian matrix.
184 ///
185 /// Transforms the Jacobian `J` into `J̃` according to the Ceres Solver corrector algorithm:
186 ///
187 /// ```text
188 /// J̃ = √(ρ'(s)) · (J - α²·r·r^T·J)
189 /// ```
190 ///
191 /// where:
192 /// - `√(ρ'(s))` scales the Jacobian by the loss function weight
193 /// - `α` is computed by solving the quadratic equation: 0.5·α² - α - (ρ''/ρ')·s = 0
194 /// - The subtractive term `α²·r·r^T·J` is a rank-1 curvature correction
195 ///
196 /// # Arguments
197 ///
198 /// * `residual` - The original residual vector `r`
199 /// * `jacobian` - Mutable reference to the Jacobian matrix (modified in-place)
200 ///
201 /// # Implementation Notes
202 ///
203 /// The correction is applied in-place for efficiency. The algorithm:
204 /// 1. Scales all Jacobian entries by `√(ρ'(s))`
205 /// 2. Adds the outer product correction: `α · (J^T r) · r^T / ||r||`
206 ///
207 /// # Example
208 ///
209 /// ```
210 /// use apex_solver::core::corrector::Corrector;
211 /// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
212 /// use nalgebra::{DVector, DMatrix};
213 /// # use apex_solver::error::ApexSolverResult;
214 /// # fn example() -> ApexSolverResult<()> {
215 ///
216 /// let loss = HuberLoss::new(1.0)?;
217 /// let residual = DVector::from_vec(vec![2.0, 1.0]);
218 /// let squared_norm = residual.dot(&residual);
219 ///
220 /// let corrector = Corrector::new(&loss, squared_norm);
221 ///
222 /// let mut jacobian = DMatrix::from_row_slice(2, 3, &[
223 /// 1.0, 0.0, 1.0,
224 /// 0.0, 1.0, 1.0,
225 /// ]);
226 ///
227 /// corrector.correct_jacobian(&residual, &mut jacobian);
228 /// // jacobian is now corrected to account for the robust loss
229 /// # Ok(())
230 /// # }
231 /// # example().unwrap();
232 /// ```
233 pub fn correct_jacobian(&self, residual: &DVector<f64>, jacobian: &mut DMatrix<f64>) {
234 // Common case (rho[2] <= 0): only apply first-order correction
235 // This is the most common scenario for well-behaved loss functions
236 if self.alpha_sq_norm == 0.0 {
237 *jacobian *= self.sqrt_rho1;
238 return;
239 }
240
241 // Full correction with curvature term:
242 // J̃ = √ρ₁ · (J - α²·r·r^T·J)
243 //
244 // This is the correct Ceres Solver algorithm:
245 // 1. Compute r·r^T·J (outer product of residual with Jacobian)
246 // 2. Subtract α²·r·r^T·J from J
247 // 3. Scale result by √ρ₁
248 //
249 // Reference: Ceres Solver corrector.cc
250 // https://github.com/ceres-solver/ceres-solver/blob/master/internal/ceres/corrector.cc
251
252 let r_rtj = residual * residual.transpose() * jacobian.clone();
253 *jacobian = (jacobian.clone() - r_rtj * self.alpha_sq_norm) * self.sqrt_rho1;
254 }
255
256 /// Apply correction to the residual vector.
257 ///
258 /// Transforms the residual `r` into `r̃` by scaling:
259 ///
260 /// ```text
261 /// r̃ = √(ρ'(s)) · r
262 /// ```
263 ///
264 /// This ensures that `||r̃||² ≈ ρ(||r||²)`, i.e., the squared norm of the corrected
265 /// residual approximates the robust cost.
266 ///
267 /// # Arguments
268 ///
269 /// * `residual` - Mutable reference to the residual vector (modified in-place)
270 ///
271 /// # Example
272 ///
273 /// ```
274 /// use apex_solver::core::corrector::Corrector;
275 /// use apex_solver::core::loss_functions::{LossFunction, HuberLoss};
276 /// use nalgebra::DVector;
277 /// # use apex_solver::error::ApexSolverResult;
278 /// # fn example() -> ApexSolverResult<()> {
279 ///
280 /// let loss = HuberLoss::new(1.0)?;
281 /// let mut residual = DVector::from_vec(vec![2.0, 3.0, 1.0]);
282 /// let squared_norm = residual.dot(&residual);
283 ///
284 /// let corrector = Corrector::new(&loss, squared_norm);
285 ///
286 /// corrector.correct_residuals(&mut residual);
287 /// // Outlier residuals are scaled down
288 /// # Ok(())
289 /// # }
290 /// # example().unwrap();
291 /// ```
292 pub fn correct_residuals(&self, residual: &mut DVector<f64>) {
293 // Simple scaling: r̃ = √(ρ'(s)) · r
294 //
295 // This downweights outliers (where ρ'(s) < 1) and leaves inliers
296 // approximately unchanged (where ρ'(s) ≈ 1)
297 *residual *= self.residual_scaling;
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::loss_functions::{CauchyLoss, HuberLoss};

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    #[test]
    fn test_corrector_huber_inlier() -> TestResult {
        // Test corrector behavior for an inlier (small residual)
        let loss = HuberLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![0.1, 0.2, 0.1]); // Small residual
        let squared_norm = residual.dot(&residual); // 0.06

        let corrector = Corrector::new(&loss, squared_norm);

        // For inliers, ρ'(s) ≈ 1, so scaling should be near 1
        assert!((corrector.sqrt_rho1 - 1.0).abs() < 1e-10);
        assert!((corrector.alpha_sq_norm).abs() < 1e-10); // ρ''(s) ≈ 0 for inliers

        // Corrected residual should be nearly unchanged
        let mut corrected_residual = residual.clone();
        corrector.correct_residuals(&mut corrected_residual);
        assert!((corrected_residual - residual).norm() < 1e-10);

        Ok(())
    }

    #[test]
    fn test_corrector_huber_outlier() -> TestResult {
        // Test corrector behavior for an outlier (large residual)
        let loss = HuberLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![5.0, 5.0, 5.0]); // Large residual
        let squared_norm = residual.dot(&residual); // 75.0

        let corrector = Corrector::new(&loss, squared_norm);

        // For outliers, ρ'(s) < 1, so scaling should be < 1
        assert!(corrector.sqrt_rho1 < 1.0);
        assert!(corrector.sqrt_rho1 > 0.0);

        // Corrected residual should be downweighted
        let mut corrected_residual = residual.clone();
        corrector.correct_residuals(&mut corrected_residual);
        assert!(corrected_residual.norm() < residual.norm());

        Ok(())
    }

    #[test]
    fn test_corrector_cauchy() -> TestResult {
        // Test corrector with Cauchy loss
        let loss = CauchyLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![2.0, 3.0]);
        let squared_norm = residual.dot(&residual); // 13.0

        let corrector = Corrector::new(&loss, squared_norm);

        // Cauchy loss should heavily downweight large residuals
        assert!(corrector.sqrt_rho1 < 1.0);
        assert!(corrector.sqrt_rho1 > 0.0);

        let mut corrected_residual = residual.clone();
        corrector.correct_residuals(&mut corrected_residual);
        assert!(corrected_residual.norm() < residual.norm());

        Ok(())
    }

    #[test]
    fn test_corrector_jacobian() -> TestResult {
        // Test Jacobian correction
        let loss = HuberLoss::new(1.0)?;
        let residual = DVector::from_vec(vec![2.0, 1.0]);
        let squared_norm = residual.dot(&residual);

        let corrector = Corrector::new(&loss, squared_norm);

        let mut jacobian = DMatrix::from_row_slice(2, 3, &[1.0, 0.0, 1.0, 0.0, 1.0, 1.0]);

        let original_jacobian = jacobian.clone();
        corrector.correct_jacobian(&residual, &mut jacobian);

        // Jacobian should be modified
        assert!(jacobian != original_jacobian);

        // Each element should be scaled and corrected
        // (Exact values depend on loss function derivatives)

        Ok(())
    }

    /// Degenerate loss with a negative first derivative (ρ' = -0.1) and a
    /// positive second derivative (ρ'' = 0.5). For large `sq_norm` the
    /// discriminant d = 1 + 2·s·ρ''/ρ' goes negative, exercising the
    /// `.max(0.0)` clamp in `Corrector::new`.
    struct EdgeCaseLoss;

    impl LossFunction for EdgeCaseLoss {
        fn evaluate(&self, _s: f64) -> [f64; 3] {
            [0.5, -0.1, 0.5] // [ρ, ρ', ρ'']
        }
    }

    #[test]
    fn test_corrector_no_nan_on_negative_d() {
        // sq_norm = 100 gives d = 1 + 2·100·0.5/(-0.1) = -999 before clamping.
        let loss = EdgeCaseLoss;
        let sq_norm = 100.0;

        // Construction must not panic.
        let corrector = Corrector::new(&loss, sq_norm);

        // sqrt(-0.1) is NaN — expected for an invalid loss with ρ' < 0; the
        // corrector only guarantees sane values for ρ'(s) >= 0.
        assert!(corrector.sqrt_rho1.is_nan());

        // The key assertion: the clamp keeps alpha_sq_norm finite. Without it,
        // d.sqrt() would be NaN and so would alpha_sq_norm. With d clamped to
        // 0, alpha = 1 and alpha_sq_norm = 1/100 = 0.01.
        assert!(!corrector.alpha_sq_norm.is_nan());
        assert!((corrector.alpha_sq_norm - 0.01).abs() < 1e-12);
    }

    #[test]
    fn test_corrector_positive_rho1_large_rho2_ratio() {
        // Valid loss with ρ' > 0 and ρ'' > 0: here d = 1 + 2·s·ρ''/ρ' >= 1,
        // so the clamp is inactive, but the full curvature-correction path
        // must still produce finite, non-NaN scaling factors.
        struct HighCurvatureLoss;
        impl LossFunction for HighCurvatureLoss {
            fn evaluate(&self, _s: f64) -> [f64; 3] {
                [0.5, 0.001, 0.001]
            }
        }

        let loss = HighCurvatureLoss;
        let corrector = Corrector::new(&loss, 10.0);

        assert!(!corrector.sqrt_rho1.is_nan());
        assert!(!corrector.residual_scaling.is_nan());
        assert!(!corrector.alpha_sq_norm.is_nan());
    }
}
458}