1use ndarray::{Array1, Array2};
11use thiserror::Error;
12
/// Errors reported by [`levenberg_marquardt`].
#[derive(Debug, Error)]
pub enum LMError {
    /// The number of bound pairs does not match the number of parameters in `x0`.
    #[error("bounds/x0 dimension mismatch: x0 has {x0_len}, bounds has {bounds_len}")]
    DimensionMismatch {
        /// Length of the initial guess.
        x0_len: usize,
        /// Length of the `bounds` slice.
        bounds_len: usize,
    },

    /// A bound pair does not describe a valid interval, e.g. `lower > upper`.
    #[error("invalid bounds at index {index}: lower ({lower}) > upper ({upper})")]
    InvalidBounds {
        /// Position of the offending pair in the `bounds` slice.
        index: usize,
        /// Lower limit of the offending pair.
        lower: f64,
        /// Upper limit of the offending pair.
        upper: f64,
    },

    /// The residual function returned a vector whose length differs from the
    /// length observed on its first evaluation.
    #[error("residual dimension changed: was {expected}, now {got}")]
    ResidualDimensionChanged {
        /// Residual count from the first evaluation.
        expected: usize,
        /// Residual count of the offending evaluation.
        got: usize,
    },

    /// The Jacobian is singular and no step can be computed.
    // NOTE(review): nothing in the visible source constructs this variant; a
    // failed linear solve is handled by increasing the damping instead.
    #[error("singular Jacobian — cannot compute step")]
    SingularJacobian,
}
53
/// Convenience alias for results whose error type is [`LMError`].
pub type LMResult<T> = std::result::Result<T, LMError>;
56
/// Decision returned by a progress callback to the solver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LMCallbackAction {
    /// Keep iterating.
    Continue,
    /// Leave the iteration loop; the solver then reports "stopped by callback".
    Stop,
}
69
70pub struct LMIntermediate {
72 pub x: Array1<f64>,
74 pub fun: f64,
76 pub lambda: f64,
78 pub iter: usize,
80}
81
/// Boxed progress callback: receives an [`LMIntermediate`] snapshot and returns
/// whether the solver should continue or stop.
pub type LMCallback = Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>;
88
/// Settings consumed by [`levenberg_marquardt`]; usually produced by
/// [`LMConfigBuilder`].
pub struct LMConfig {
    /// Upper limit on the number of solver iterations.
    pub maxiter: usize,
    /// Relative tolerance: after an accepted step the solver stops when
    /// `|f_old - f_new| < tol * f_old + atol`.
    pub tol: f64,
    /// Absolute tolerance: used in the test above, and the solver also stops
    /// as soon as the objective itself is `<= atol`.
    pub atol: f64,
    /// Initial value of the damping parameter lambda.
    pub lambda_init: f64,
    /// Base step used for the finite-difference Jacobian.
    pub jacobian_epsilon: f64,
    /// Initial guess; it is clamped into the bounds before the first evaluation.
    pub x0: Array1<f64>,
    /// Optional per-residual weights; `None` means every weight is 1.
    pub weights: Option<Array1<f64>>,
    /// When `true`, a progress line is written to stderr on iterations whose
    /// index is a multiple of 10.
    pub disp: bool,
    /// Optional callback invoked at the start of every iteration.
    pub callback: Option<LMCallback>,
}
110
/// Builder for [`LMConfig`].
///
/// Holds the same settings as `LMConfig`, except that `x0` stays optional
/// until `build` is called.
pub struct LMConfigBuilder {
    // Iteration limit.
    maxiter: usize,
    // Relative convergence tolerance.
    tol: f64,
    // Absolute convergence tolerance.
    atol: f64,
    // Starting damping parameter.
    lambda_init: f64,
    // Base finite-difference step.
    jacobian_epsilon: f64,
    // Initial guess; must be supplied before `build`.
    x0: Option<Array1<f64>>,
    // Optional per-residual weights.
    weights: Option<Array1<f64>>,
    // Whether progress is printed.
    disp: bool,
    // Optional progress callback.
    callback: Option<LMCallback>,
}
123
124impl LMConfigBuilder {
125 pub fn new() -> Self {
127 Self {
128 maxiter: 100,
129 tol: 1e-10,
130 atol: 1e-14,
131 lambda_init: 1.0,
132 jacobian_epsilon: 1e-8,
133 x0: None,
134 weights: None,
135 disp: false,
136 callback: None,
137 }
138 }
139
140 pub fn x0(mut self, x0: Array1<f64>) -> Self {
142 self.x0 = Some(x0);
143 self
144 }
145
146 pub fn maxiter(mut self, n: usize) -> Self {
148 self.maxiter = n;
149 self
150 }
151
152 pub fn tol(mut self, t: f64) -> Self {
154 self.tol = t;
155 self
156 }
157
158 pub fn atol(mut self, t: f64) -> Self {
160 self.atol = t;
161 self
162 }
163
164 pub fn lambda_init(mut self, l: f64) -> Self {
166 self.lambda_init = l;
167 self
168 }
169
170 pub fn jacobian_epsilon(mut self, eps: f64) -> Self {
172 self.jacobian_epsilon = eps;
173 self
174 }
175
176 pub fn weights(mut self, w: Array1<f64>) -> Self {
178 self.weights = Some(w);
179 self
180 }
181
182 pub fn disp(mut self, d: bool) -> Self {
184 self.disp = d;
185 self
186 }
187
188 pub fn callback(mut self, cb: Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>) -> Self {
190 self.callback = Some(cb);
191 self
192 }
193
194 pub fn build(self) -> LMConfig {
196 LMConfig {
197 maxiter: self.maxiter,
198 tol: self.tol,
199 atol: self.atol,
200 lambda_init: self.lambda_init,
201 jacobian_epsilon: self.jacobian_epsilon,
202 x0: self.x0.expect("LMConfigBuilder: x0 is required"),
203 weights: self.weights,
204 disp: self.disp,
205 callback: self.callback,
206 }
207 }
208}
209
210impl Default for LMConfigBuilder {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
/// Outcome of a [`levenberg_marquardt`] run.
#[derive(Debug)]
pub struct LMReport {
    /// Last accepted parameter vector.
    pub x: Array1<f64>,
    /// Weighted sum of squared residuals at `x`.
    pub fun: f64,
    /// Residual vector at `x`.
    pub residuals: Array1<f64>,
    /// `true` only when one of the convergence tests fired.
    pub success: bool,
    /// Human-readable reason the solver stopped.
    pub message: String,
    /// Number of iterations started.
    pub nit: usize,
    /// Number of residual-function evaluations, including those spent on the
    /// finite-difference Jacobian.
    pub nfev: usize,
}
238
239pub fn levenberg_marquardt<F>(
252 residual_fn: &F,
253 bounds: &[(f64, f64)],
254 config: LMConfig,
255) -> LMResult<LMReport>
256where
257 F: Fn(&Array1<f64>) -> Array1<f64>,
258{
259 let n_params = config.x0.len();
260
261 if bounds.len() != n_params {
263 return Err(LMError::DimensionMismatch {
264 x0_len: n_params,
265 bounds_len: bounds.len(),
266 });
267 }
268 for (i, &(lo, hi)) in bounds.iter().enumerate() {
269 if lo > hi {
270 return Err(LMError::InvalidBounds {
271 index: i,
272 lower: lo,
273 upper: hi,
274 });
275 }
276 }
277
278 let mut x = project(&config.x0, bounds);
280 let mut r = residual_fn(&x);
281 let n_residuals = r.len();
282 let mut nfev: usize = 1;
283
284 let w = config.weights.unwrap_or_else(|| Array1::ones(n_residuals));
285 let mut f_val = weighted_sos(&r, &w);
286 let mut lambda = config.lambda_init;
287 let eps = config.jacobian_epsilon;
288
289 let mut success = false;
290 let mut message = format!("max iterations ({}) reached", config.maxiter);
291 let mut callback = config.callback;
292 let mut nit: usize = 0;
293
294 for iter in 0..config.maxiter {
295 nit = iter + 1;
296 if let Some(ref mut cb) = callback {
298 let intermediate = LMIntermediate {
299 x: x.clone(),
300 fun: f_val,
301 lambda,
302 iter,
303 };
304 if cb(&intermediate) == LMCallbackAction::Stop {
305 message = "stopped by callback".to_string();
306 break;
307 }
308 }
309
310 if f_val <= config.atol {
312 success = true;
313 message = format!(
314 "converged: objective {:.2e} <= atol {:.2e}",
315 f_val, config.atol
316 );
317 break;
318 }
319
320 let jac = compute_jacobian(residual_fn, &x, &r, n_residuals, eps, &mut nfev)?;
322
323 let jtwj = jtw_j(&jac, &w);
327 let jtwr = jtw_r(&jac, &w, &r); let diag: Array1<f64> = (0..n_params).map(|j| jtwj[[j, j]].max(1e-20)).collect();
331
332 let mut h = jtwj.clone();
334 for j in 0..n_params {
335 h[[j, j]] += lambda * diag[j];
336 }
337
338 let neg_jtwr = jtwr.mapv(|v| -v);
340
341 let delta = match solve_linear_system(&h, &neg_jtwr) {
343 Some(d) => d,
344 None => {
345 lambda *= 4.0;
347 continue;
348 }
349 };
350
351 let x_trial = project(&(&x + &delta), bounds);
353 let r_trial = residual_fn(&x_trial);
354 nfev += 1;
355
356 if r_trial.len() != n_residuals {
357 return Err(LMError::ResidualDimensionChanged {
358 expected: n_residuals,
359 got: r_trial.len(),
360 });
361 }
362
363 let f_trial = weighted_sos(&r_trial, &w);
364
365 let pred: f64 = {
369 let mut p = 0.0;
370 for j in 0..n_params {
371 let mut jtwj_delta_j = 0.0;
372 for k in 0..n_params {
373 jtwj_delta_j += jtwj[[j, k]] * delta[k];
374 }
375 p += delta[j] * jtwj_delta_j + 2.0 * lambda * diag[j] * delta[j] * delta[j];
376 }
377 p
378 };
379
380 let actual = f_val - f_trial;
381 let rho = if pred.abs() > 1e-30 {
382 actual / pred
383 } else {
384 0.0
385 };
386
387 if rho > 0.0 && f_trial < f_val {
388 let f_old = f_val;
390 x = x_trial;
391 r = r_trial;
392 f_val = f_trial;
393 lambda = (lambda / 2.0).max(1e-15);
394
395 if (f_old - f_val).abs() < config.tol * f_old + config.atol {
397 success = true;
398 message = format!(
399 "converged: |df|={:.2e} < tol*f+atol={:.2e}",
400 (f_old - f_val).abs(),
401 config.tol * f_old + config.atol
402 );
403 break;
404 }
405 } else {
406 lambda = (lambda * 2.0).min(1e15);
408 }
409
410 if config.disp && iter % 10 == 0 {
411 eprintln!(
412 "LM iter {}: f={:.6e}, lambda={:.2e}, rho={:.3}",
413 iter, f_val, lambda, rho
414 );
415 }
416 }
417
418 Ok(LMReport {
419 x,
420 fun: f_val,
421 residuals: r,
422 success,
423 message,
424 nit,
425 nfev,
426 })
427}
428
429fn project(x: &Array1<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
435 Array1::from(
436 x.iter()
437 .zip(bounds.iter())
438 .map(|(&xi, &(lo, hi))| xi.clamp(lo, hi))
439 .collect::<Vec<_>>(),
440 )
441}
442
443fn weighted_sos(r: &Array1<f64>, w: &Array1<f64>) -> f64 {
445 r.iter().zip(w.iter()).map(|(&ri, &wi)| wi * ri * ri).sum()
446}
447
448fn compute_jacobian<F>(
450 residual_fn: &F,
451 x: &Array1<f64>,
452 _r0: &Array1<f64>,
453 n_residuals: usize,
454 eps: f64,
455 nfev: &mut usize,
456) -> LMResult<Array2<f64>>
457where
458 F: Fn(&Array1<f64>) -> Array1<f64>,
459{
460 let n_params = x.len();
461 let mut jac = Array2::zeros((n_residuals, n_params));
462
463 for j in 0..n_params {
464 let mut x_plus = x.clone();
465 let mut x_minus = x.clone();
466 let h = jacobian_step(eps, x[j]);
467 x_plus[j] += h;
468 x_minus[j] -= h;
469
470 let r_plus = residual_fn(&x_plus);
471 let r_minus = residual_fn(&x_minus);
472 *nfev += 2;
473
474 if r_plus.len() != n_residuals {
475 return Err(LMError::ResidualDimensionChanged {
476 expected: n_residuals,
477 got: r_plus.len(),
478 });
479 }
480 if r_minus.len() != n_residuals {
481 return Err(LMError::ResidualDimensionChanged {
482 expected: n_residuals,
483 got: r_minus.len(),
484 });
485 }
486
487 let inv_2h = 1.0 / (2.0 * h);
488 for i in 0..n_residuals {
489 jac[[i, j]] = (r_plus[i] - r_minus[i]) * inv_2h;
490 }
491 }
492
493 Ok(jac)
494}
495
/// Finite-difference step for a parameter whose current value is `x`.
///
/// The step is `eps * (1 + |x|)`, so it grows with the magnitude of `x`. The
/// final `max` makes the result fall back to `eps` when that product is NaN,
/// because `f64::max` ignores a NaN operand.
fn jacobian_step(eps: f64, x: f64) -> f64 {
    let magnitude = 1.0 + x.abs();
    let scaled = eps * magnitude;
    f64::max(scaled, eps)
}
501
502fn jtw_j(jac: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
504 let n_params = jac.ncols();
505 let n_res = jac.nrows();
506 let mut result = Array2::zeros((n_params, n_params));
507
508 for j in 0..n_params {
509 for k in j..n_params {
510 let mut val = 0.0;
511 for i in 0..n_res {
512 val += w[i] * jac[[i, j]] * jac[[i, k]];
513 }
514 result[[j, k]] = val;
515 result[[k, j]] = val;
516 }
517 }
518
519 result
520}
521
522fn jtw_r(jac: &Array2<f64>, w: &Array1<f64>, r: &Array1<f64>) -> Array1<f64> {
524 let n_params = jac.ncols();
525 let n_res = jac.nrows();
526 let mut result = Array1::zeros(n_params);
527
528 for j in 0..n_params {
529 let mut val = 0.0;
530 for i in 0..n_res {
531 val += w[i] * jac[[i, j]] * r[i];
532 }
533 result[j] = val;
534 }
535
536 result
537}
538
539fn solve_linear_system(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
542 let n = b.len();
543 debug_assert_eq!(a.nrows(), n);
544 debug_assert_eq!(a.ncols(), n);
545
546 let mut aug = Array2::zeros((n, n + 1));
548 for i in 0..n {
549 for j in 0..n {
550 aug[[i, j]] = a[[i, j]];
551 }
552 aug[[i, n]] = b[i];
553 }
554
555 for col in 0..n {
557 let mut max_val = aug[[col, col]].abs();
559 let mut max_row = col;
560 for row in (col + 1)..n {
561 let val = aug[[row, col]].abs();
562 if val > max_val {
563 max_val = val;
564 max_row = row;
565 }
566 }
567
568 if max_val < 1e-12 {
569 return None; }
571
572 if max_row != col {
574 for j in 0..=n {
575 let tmp = aug[[col, j]];
576 aug[[col, j]] = aug[[max_row, j]];
577 aug[[max_row, j]] = tmp;
578 }
579 }
580
581 let pivot = aug[[col, col]];
583 for row in (col + 1)..n {
584 let factor = aug[[row, col]] / pivot;
585 for j in col..=n {
586 aug[[row, j]] -= factor * aug[[col, j]];
587 }
588 }
589 }
590
591 let mut x = Array1::zeros(n);
593 for col in (0..n).rev() {
594 let mut sum = aug[[col, n]];
595 for j in (col + 1)..n {
596 sum -= aug[[col, j]] * x[j];
597 }
598 x[col] = sum / aug[[col, col]];
599 }
600
601 if x.iter().any(|v| !v.is_finite()) {
603 return None;
604 }
605
606 Some(x)
607}
608
// Unit tests for the solver and its private helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Residual r(x) = x, so the objective is sum(x_i^2) with its minimum at
    // the origin.
    #[test]
    fn test_sphere() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 3];
        let config = LMConfigBuilder::new()
            .x0(array![3.0, -2.0, 1.0])
            .maxiter(50)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            report.fun < 1e-12,
            "objective should be ~0, got {}",
            report.fun
        );
        for &xi in report.x.iter() {
            assert!(xi.abs() < 1e-6, "x should be ~0, got {}", xi);
        }
    }

    // Rosenbrock function written as two residuals; both vanish at (1, 1).
    #[test]
    fn test_rosenbrock_residual() {
        let residual = |x: &Array1<f64>| array![10.0 * (x[1] - x[0] * x[0]), 1.0 - x[0]];
        let bounds = vec![(-5.0, 5.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![-1.0, 1.0])
            .maxiter(200)
            .tol(1e-12)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            (report.x[0] - 1.0).abs() < 1e-4,
            "x0 should be ~1, got {}",
            report.x[0]
        );
        assert!(
            (report.x[1] - 1.0).abs() < 1e-4,
            "x1 should be ~1, got {}",
            report.x[1]
        );
    }

    // The unconstrained minimum (x = 5) lies above the upper bound, so the
    // solution is expected on the bound. Only `x` is checked, not `success`.
    #[test]
    fn test_bounded_solution() {
        let residual = |x: &Array1<f64>| array![x[0] - 5.0];
        let bounds = vec![(-10.0, 3.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(
            (report.x[0] - 3.0).abs() < 1e-6,
            "x should be at bound 3.0, got {}",
            report.x[0]
        );
    }

    // The residual turns into NaN for |x| > 100; the solver must still return
    // `Ok`.
    // NOTE(review): starting from x0 = 0 with the minimum at x = 1, the
    // iterates look like they never leave |x| <= 100, so the NaN branch may
    // never execute — confirm.
    #[test]
    fn test_nan_residual_handled() {
        let residual = |x: &Array1<f64>| {
            if x[0].abs() > 100.0 {
                array![f64::NAN]
            } else {
                array![x[0] - 1.0]
            }
        };
        let bounds = vec![(-200.0, 200.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        let result = levenberg_marquardt(&residual, &bounds, config);
        assert!(result.is_ok());
    }

    // Starting exactly at the optimum: the objective is 0, which is already
    // below the default `atol`.
    #[test]
    fn test_zero_residual() {
        let residual = |x: &Array1<f64>| array![x[0], x[1]];
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .maxiter(10)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "already at optimum: {}", report.message);
        assert!(report.fun < 1e-14);
    }

    // The callback requests a stop once the iteration index reaches 3.
    #[test]
    fn test_callback_stop() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(1000)
            .callback(Box::new(|inter| {
                if inter.iter >= 3 {
                    LMCallbackAction::Stop
                } else {
                    LMCallbackAction::Continue
                }
            }))
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(report.message, "stopped by callback");
    }

    // Two residuals pulling x towards 1 and towards 3. With equal weights the
    // minimum is x = 2; weighting the first residual 10:1 moves it to
    // 13/11 ≈ 1.18.
    #[test]
    fn test_weighted_residuals() {
        let residual = |x: &Array1<f64>| array![x[0] - 1.0, x[0] - 3.0];
        let bounds = vec![(-10.0, 10.0)];

        // Unweighted run.
        let config_unw = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();
        let report_unw = levenberg_marquardt(&residual, &bounds, config_unw).unwrap();

        // Weighted run.
        let config_w = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .weights(array![10.0, 1.0])
            .build();
        let report_w = levenberg_marquardt(&residual, &bounds, config_w).unwrap();

        assert!(
            (report_unw.x[0] - 2.0).abs() < 0.01,
            "unweighted x should be ~2, got {}",
            report_unw.x[0]
        );
        assert!(
            report_w.x[0] < report_unw.x[0],
            "weighted x ({}) should be less than unweighted ({})",
            report_w.x[0],
            report_unw.x[0]
        );
        assert!(
            (report_w.x[0] - 1.0).abs() < 0.5,
            "weighted x should be near 1.0, got {}",
            report_w.x[0]
        );
    }

    // Three bound pairs against a two-parameter x0 must be rejected.
    #[test]
    fn test_dimension_mismatch() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-1.0, 1.0); 3];
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::DimensionMismatch { .. }));
    }

    // `nit` counts solver iterations, while `nfev` also includes the
    // evaluations spent on the Jacobian. The tolerances are set so tight that
    // the run is expected to use all 5 iterations.
    #[test]
    fn test_nit_tracks_iterations_not_nfev() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(5)
            .tol(1e-20)
            .atol(0.0)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(
            report.nit, 5,
            "nit should be maxiter (5), got {}",
            report.nit
        );
        assert!(
            report.nfev > report.nit,
            "nfev ({}) should be much larger than nit ({})",
            report.nfev,
            report.nit
        );
    }

    // A pair with lower (5) above upper (1) must be rejected.
    #[test]
    fn test_invalid_bounds() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(5.0, 1.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::InvalidBounds { .. }));
    }

    // The step equals eps * (1 + |x|): eps at x = 0, 2·eps at x = 1, and it
    // keeps scaling with |x| for very large values.
    #[test]
    fn test_jacobian_step_size_bounded() {
        let eps = 1e-8;
        assert!((jacobian_step(eps, 0.0) - eps).abs() < 1e-15);
        assert!((jacobian_step(eps, 1.0) - 2.0e-8).abs() < 1e-15);
        let h_large = jacobian_step(eps, 1.0e12);
        assert!(h_large >= eps);
        assert!((h_large - eps * (1.0 + 1.0e12)).abs() < 1.0);
    }

    // x0 = 1e8 lies outside the bounds and is clamped to 2e6 before the first
    // evaluation; the minimum is at 1e6.
    #[test]
    fn test_large_x0_lm_converges() {
        let residual = |x: &Array1<f64>| array![x[0] - 1.0e6];
        let bounds = vec![(0.0, 2.0e6)];
        let config = LMConfigBuilder::new()
            .x0(array![1.0e8])
            .maxiter(100)
            .build();
        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            (report.x[0] - 1.0e6).abs() < 1.0,
            "x should be ~1e6, got {}",
            report.x[0]
        );
    }

    // Rows differ only by about 1e-15, so the second pivot falls below the
    // solver's 1e-12 threshold.
    #[test]
    fn test_linear_solver_rejects_near_singular() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 1.0, 1.0 + 1e-15]).unwrap();
        let b = Array1::from(vec![1.0, 1.0]);
        assert!(
            solve_linear_system(&a, &b).is_none(),
            "near-singular matrix should be rejected"
        );
    }
}