1use ndarray::{Array1, Array2};
11use thiserror::Error;
12
/// Errors reported by the Levenberg-Marquardt solver in this file.
#[derive(Debug, Error)]
pub enum LMError {
    /// The number of bound pairs differs from the number of entries in `x0`.
    #[error("bounds/x0 dimension mismatch: x0 has {x0_len}, bounds has {bounds_len}")]
    DimensionMismatch {
        /// Length of the initial parameter vector.
        x0_len: usize,
        /// Length of the bounds slice.
        bounds_len: usize,
    },

    /// A bound pair whose lower limit exceeds its upper limit.
    #[error("invalid bounds at index {index}: lower ({lower}) > upper ({upper})")]
    InvalidBounds {
        /// Position of the offending pair in the bounds slice.
        index: usize,
        /// Lower limit of that pair.
        lower: f64,
        /// Upper limit of that pair.
        upper: f64,
    },

    /// The residual function returned a vector whose length differs from the length of its
    /// first evaluation.
    #[error("residual dimension changed: was {expected}, now {got}")]
    ResidualDimensionChanged {
        /// Residual count observed on the first evaluation.
        expected: usize,
        /// Residual count observed on the later evaluation.
        got: usize,
    },

    /// The linear system for the step could not be solved.
    ///
    /// NOTE(review): nothing visible in this file constructs this variant; the solver loop
    /// reacts to an unsolvable system by increasing the damping instead.
    #[error("singular Jacobian — cannot compute step")]
    SingularJacobian,
}
53
/// Result alias whose error type is [`LMError`].
pub type LMResult<T> = std::result::Result<T, LMError>;
56
/// Decision returned by a user callback after inspecting the current solver state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LMCallbackAction {
    /// Keep iterating.
    Continue,
    /// Abort the iteration loop; the solver reports "stopped by callback".
    Stop,
}
69
70pub struct LMIntermediate {
72 pub x: Array1<f64>,
74 pub fun: f64,
76 pub lambda: f64,
78 pub iter: usize,
80}
81
/// Boxed user callback: receives the current [`LMIntermediate`] and decides whether the solver
/// continues or stops.
pub type LMCallback = Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>;
88
/// Complete set of options consumed by `levenberg_marquardt`.
pub struct LMConfig {
    /// Maximum number of solver iterations.
    pub maxiter: usize,
    /// Relative tolerance: after an accepted step the solver stops once
    /// `|f_old - f_new| < tol * f_old + atol`.
    pub tol: f64,
    /// Absolute tolerance: used in the test above, and the solver also stops as soon as the
    /// objective is `<= atol`.
    pub atol: f64,
    /// Initial value of the damping parameter.
    pub lambda_init: f64,
    /// Base step used to size the finite-difference perturbations of the Jacobian.
    pub jacobian_epsilon: f64,
    /// Starting point; it is projected onto the bounds before the first evaluation.
    pub x0: Array1<f64>,
    /// Optional per-residual weights; when absent every residual has weight one.
    pub weights: Option<Array1<f64>>,
    /// When true, progress is written to stderr on every tenth iteration.
    pub disp: bool,
    /// Optional callback invoked at the start of each iteration.
    pub callback: Option<LMCallback>,
}
110
/// Builder for [`LMConfig`]; every field except `x0` has a default.
pub struct LMConfigBuilder {
    // Same meaning as the identically named `LMConfig` fields.
    maxiter: usize,
    tol: f64,
    atol: f64,
    lambda_init: f64,
    jacobian_epsilon: f64,
    // Stays `None` until `x0(..)` is called; `build` requires it to be set.
    x0: Option<Array1<f64>>,
    weights: Option<Array1<f64>>,
    disp: bool,
    callback: Option<LMCallback>,
}
123
124impl LMConfigBuilder {
125 pub fn new() -> Self {
127 Self {
128 maxiter: 100,
129 tol: 1e-10,
130 atol: 1e-14,
131 lambda_init: 1.0,
132 jacobian_epsilon: 1e-8,
133 x0: None,
134 weights: None,
135 disp: false,
136 callback: None,
137 }
138 }
139
140 pub fn x0(mut self, x0: Array1<f64>) -> Self {
142 self.x0 = Some(x0);
143 self
144 }
145
146 pub fn maxiter(mut self, n: usize) -> Self {
148 self.maxiter = n;
149 self
150 }
151
152 pub fn tol(mut self, t: f64) -> Self {
154 self.tol = t;
155 self
156 }
157
158 pub fn atol(mut self, t: f64) -> Self {
160 self.atol = t;
161 self
162 }
163
164 pub fn lambda_init(mut self, l: f64) -> Self {
166 self.lambda_init = l;
167 self
168 }
169
170 pub fn jacobian_epsilon(mut self, eps: f64) -> Self {
172 self.jacobian_epsilon = eps;
173 self
174 }
175
176 pub fn weights(mut self, w: Array1<f64>) -> Self {
178 self.weights = Some(w);
179 self
180 }
181
182 pub fn disp(mut self, d: bool) -> Self {
184 self.disp = d;
185 self
186 }
187
188 pub fn callback(mut self, cb: Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>) -> Self {
190 self.callback = Some(cb);
191 self
192 }
193
194 pub fn build(self) -> LMConfig {
196 LMConfig {
197 maxiter: self.maxiter,
198 tol: self.tol,
199 atol: self.atol,
200 lambda_init: self.lambda_init,
201 jacobian_epsilon: self.jacobian_epsilon,
202 x0: self.x0.expect("LMConfigBuilder: x0 is required"),
203 weights: self.weights,
204 disp: self.disp,
205 callback: self.callback,
206 }
207 }
208}
209
210impl Default for LMConfigBuilder {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
/// Outcome of a `levenberg_marquardt` run.
#[derive(Debug)]
pub struct LMReport {
    /// Final parameter vector.
    pub x: Array1<f64>,
    /// Objective (weighted sum of squared residuals) at `x`.
    pub fun: f64,
    /// Residual vector at `x`, as returned by the residual function.
    pub residuals: Array1<f64>,
    /// True only when one of the convergence tests fired.
    pub success: bool,
    /// Human-readable reason the solver stopped.
    pub message: String,
    /// Number of solver iterations started.
    pub nit: usize,
    /// Number of residual-function evaluations, including those spent on finite differences.
    pub nfev: usize,
}
238
239pub fn levenberg_marquardt<F>(
252 residual_fn: &F,
253 bounds: &[(f64, f64)],
254 config: LMConfig,
255) -> LMResult<LMReport>
256where
257 F: Fn(&Array1<f64>) -> Array1<f64>,
258{
259 let n_params = config.x0.len();
260
261 if bounds.len() != n_params {
263 return Err(LMError::DimensionMismatch {
264 x0_len: n_params,
265 bounds_len: bounds.len(),
266 });
267 }
268 for (i, &(lo, hi)) in bounds.iter().enumerate() {
269 if lo > hi {
270 return Err(LMError::InvalidBounds {
271 index: i,
272 lower: lo,
273 upper: hi,
274 });
275 }
276 }
277
278 let mut x = project(&config.x0, bounds);
280 let mut r = residual_fn(&x);
281 let n_residuals = r.len();
282 let mut nfev: usize = 1;
283
284 let w = config.weights.unwrap_or_else(|| Array1::ones(n_residuals));
285 let mut f_val = weighted_sos(&r, &w);
286 let mut lambda = config.lambda_init;
287 let eps = config.jacobian_epsilon;
288
289 let mut success = false;
290 let mut message = format!("max iterations ({}) reached", config.maxiter);
291 let mut callback = config.callback;
292 let mut nit: usize = 0;
293
294 for iter in 0..config.maxiter {
295 nit = iter + 1;
296 if let Some(ref mut cb) = callback {
298 let intermediate = LMIntermediate {
299 x: x.clone(),
300 fun: f_val,
301 lambda,
302 iter,
303 };
304 if cb(&intermediate) == LMCallbackAction::Stop {
305 message = "stopped by callback".to_string();
306 break;
307 }
308 }
309
310 if f_val <= config.atol {
312 success = true;
313 message = format!(
314 "converged: objective {:.2e} <= atol {:.2e}",
315 f_val, config.atol
316 );
317 break;
318 }
319
320 let jac = compute_jacobian(residual_fn, &x, &r, n_residuals, eps, &mut nfev)?;
322
323 let jtwj = jtw_j(&jac, &w);
327 let jtwr = jtw_r(&jac, &w, &r); let diag: Array1<f64> = (0..n_params).map(|j| jtwj[[j, j]].max(1e-20)).collect();
331
332 let mut h = jtwj.clone();
334 for j in 0..n_params {
335 h[[j, j]] += lambda * diag[j];
336 }
337
338 let neg_jtwr = jtwr.mapv(|v| -v);
340
341 let delta = match solve_linear_system(&h, &neg_jtwr) {
343 Some(d) => d,
344 None => {
345 lambda *= 4.0;
347 continue;
348 }
349 };
350
351 let x_trial = project(&(&x + &delta), bounds);
353 let r_trial = residual_fn(&x_trial);
354 nfev += 1;
355
356 if r_trial.len() != n_residuals {
357 return Err(LMError::ResidualDimensionChanged {
358 expected: n_residuals,
359 got: r_trial.len(),
360 });
361 }
362
363 let f_trial = weighted_sos(&r_trial, &w);
364
365 let pred: f64 = {
369 let mut p = 0.0;
370 for j in 0..n_params {
371 let mut jtwj_delta_j = 0.0;
372 for k in 0..n_params {
373 jtwj_delta_j += jtwj[[j, k]] * delta[k];
374 }
375 p += delta[j] * jtwj_delta_j + 2.0 * lambda * diag[j] * delta[j] * delta[j];
376 }
377 p
378 };
379
380 let actual = f_val - f_trial;
381 let rho = if pred.abs() > 1e-30 {
382 actual / pred
383 } else {
384 0.0
385 };
386
387 if rho > 0.0 && f_trial < f_val {
388 let f_old = f_val;
390 x = x_trial;
391 r = r_trial;
392 f_val = f_trial;
393 lambda = (lambda / 2.0).max(1e-15);
394
395 if (f_old - f_val).abs() < config.tol * f_old + config.atol {
397 success = true;
398 message = format!(
399 "converged: |df|={:.2e} < tol*f+atol={:.2e}",
400 (f_old - f_val).abs(),
401 config.tol * f_old + config.atol
402 );
403 break;
404 }
405 } else {
406 lambda = (lambda * 2.0).min(1e15);
408 }
409
410 if config.disp && iter % 10 == 0 {
411 eprintln!(
412 "LM iter {}: f={:.6e}, lambda={:.2e}, rho={:.3}",
413 iter, f_val, lambda, rho
414 );
415 }
416 }
417
418 Ok(LMReport {
419 x,
420 fun: f_val,
421 residuals: r,
422 success,
423 message,
424 nit,
425 nfev,
426 })
427}
428
429fn project(x: &Array1<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
435 Array1::from(
436 x.iter()
437 .zip(bounds.iter())
438 .map(|(&xi, &(lo, hi))| xi.clamp(lo, hi))
439 .collect::<Vec<_>>(),
440 )
441}
442
443fn weighted_sos(r: &Array1<f64>, w: &Array1<f64>) -> f64 {
445 r.iter().zip(w.iter()).map(|(&ri, &wi)| wi * ri * ri).sum()
446}
447
448fn compute_jacobian<F>(
450 residual_fn: &F,
451 x: &Array1<f64>,
452 _r0: &Array1<f64>,
453 n_residuals: usize,
454 eps: f64,
455 nfev: &mut usize,
456) -> LMResult<Array2<f64>>
457where
458 F: Fn(&Array1<f64>) -> Array1<f64>,
459{
460 let n_params = x.len();
461 let mut jac = Array2::zeros((n_residuals, n_params));
462
463 for j in 0..n_params {
464 let mut x_plus = x.clone();
465 let mut x_minus = x.clone();
466 let h = jacobian_step(eps, x[j]);
467 x_plus[j] += h;
468 x_minus[j] -= h;
469
470 let r_plus = residual_fn(&x_plus);
471 let r_minus = residual_fn(&x_minus);
472 *nfev += 2;
473
474 if r_plus.len() != n_residuals {
475 return Err(LMError::ResidualDimensionChanged {
476 expected: n_residuals,
477 got: r_plus.len(),
478 });
479 }
480 if r_minus.len() != n_residuals {
481 return Err(LMError::ResidualDimensionChanged {
482 expected: n_residuals,
483 got: r_minus.len(),
484 });
485 }
486
487 let inv_2h = 1.0 / (2.0 * h);
488 for i in 0..n_residuals {
489 jac[[i, j]] = (r_plus[i] - r_minus[i]) * inv_2h;
490 }
491 }
492
493 Ok(jac)
494}
495
/// Finite-difference step for a parameter of value `x`: `eps * (1 + |x|)`, never below `eps`.
///
/// The step grows with the magnitude of `x`, so large parameters get a proportionally larger
/// perturbation. If the scaled value is NaN, `f64::max` falls back to `eps`.
fn jacobian_step(eps: f64, x: f64) -> f64 {
    let magnitude = 1.0 + x.abs();
    let scaled = eps * magnitude;
    f64::max(scaled, eps)
}
504
505fn jtw_j(jac: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
507 let n_params = jac.ncols();
508 let n_res = jac.nrows();
509 let mut result = Array2::zeros((n_params, n_params));
510
511 for j in 0..n_params {
512 for k in j..n_params {
513 let mut val = 0.0;
514 for i in 0..n_res {
515 val += w[i] * jac[[i, j]] * jac[[i, k]];
516 }
517 result[[j, k]] = val;
518 result[[k, j]] = val;
519 }
520 }
521
522 result
523}
524
525fn jtw_r(jac: &Array2<f64>, w: &Array1<f64>, r: &Array1<f64>) -> Array1<f64> {
527 let n_params = jac.ncols();
528 let n_res = jac.nrows();
529 let mut result = Array1::zeros(n_params);
530
531 for j in 0..n_params {
532 let mut val = 0.0;
533 for i in 0..n_res {
534 val += w[i] * jac[[i, j]] * r[i];
535 }
536 result[j] = val;
537 }
538
539 result
540}
541
542fn solve_linear_system(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
545 let n = b.len();
546 debug_assert_eq!(a.nrows(), n);
547 debug_assert_eq!(a.ncols(), n);
548
549 let mut aug = Array2::zeros((n, n + 1));
551 for i in 0..n {
552 for j in 0..n {
553 aug[[i, j]] = a[[i, j]];
554 }
555 aug[[i, n]] = b[i];
556 }
557
558 for col in 0..n {
560 let mut max_val = aug[[col, col]].abs();
562 let mut max_row = col;
563 for row in (col + 1)..n {
564 let val = aug[[row, col]].abs();
565 if val > max_val {
566 max_val = val;
567 max_row = row;
568 }
569 }
570
571 if max_val < 1e-12 {
572 return None; }
574
575 if max_row != col {
577 for j in 0..=n {
578 let tmp = aug[[col, j]];
579 aug[[col, j]] = aug[[max_row, j]];
580 aug[[max_row, j]] = tmp;
581 }
582 }
583
584 let pivot = aug[[col, col]];
586 for row in (col + 1)..n {
587 let factor = aug[[row, col]] / pivot;
588 for j in col..=n {
589 aug[[row, j]] -= factor * aug[[col, j]];
590 }
591 }
592 }
593
594 let mut x = Array1::zeros(n);
596 for col in (0..n).rev() {
597 let mut sum = aug[[col, n]];
598 for j in (col + 1)..n {
599 sum -= aug[[col, j]] * x[j];
600 }
601 x[col] = sum / aug[[col, col]];
602 }
603
604 if x.iter().any(|v| !v.is_finite()) {
606 return None;
607 }
608
609 Some(x)
610}
611
/// Unit tests for the solver entry point and its numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Residual r(x) = x, so the objective is ||x||² with its minimum at the origin.
    #[test]
    fn test_sphere() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 3];
        let config = LMConfigBuilder::new()
            .x0(array![3.0, -2.0, 1.0])
            .maxiter(50)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            report.fun < 1e-12,
            "objective should be ~0, got {}",
            report.fun
        );
        for &xi in report.x.iter() {
            assert!(xi.abs() < 1e-6, "x should be ~0, got {}", xi);
        }
    }

    /// Rosenbrock function written as two residuals; both vanish at (1, 1).
    #[test]
    fn test_rosenbrock_residual() {
        let residual = |x: &Array1<f64>| array![10.0 * (x[1] - x[0] * x[0]), 1.0 - x[0]];
        let bounds = vec![(-5.0, 5.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![-1.0, 1.0])
            .maxiter(200)
            .tol(1e-12)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            (report.x[0] - 1.0).abs() < 1e-4,
            "x0 should be ~1, got {}",
            report.x[0]
        );
        assert!(
            (report.x[1] - 1.0).abs() < 1e-4,
            "x1 should be ~1, got {}",
            report.x[1]
        );
    }

    /// The unconstrained minimizer (x = 5) lies above the upper bound, so the result must sit
    /// on the bound at 3.
    #[test]
    fn test_bounded_solution() {
        let residual = |x: &Array1<f64>| array![x[0] - 5.0];
        let bounds = vec![(-10.0, 3.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(
            (report.x[0] - 3.0).abs() < 1e-6,
            "x should be at bound 3.0, got {}",
            report.x[0]
        );
    }

    /// A residual function that yields NaN for |x| > 100 must not make the solver return an
    /// error.
    #[test]
    fn test_nan_residual_handled() {
        let residual = |x: &Array1<f64>| {
            if x[0].abs() > 100.0 {
                array![f64::NAN]
            } else {
                array![x[0] - 1.0]
            }
        };
        let bounds = vec![(-200.0, 200.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        let result = levenberg_marquardt(&residual, &bounds, config);
        // Only the absence of an error is checked here, not the returned point.
        assert!(result.is_ok());
    }

    /// Starting exactly at a zero-residual point must be reported as success.
    #[test]
    fn test_zero_residual() {
        let residual = |x: &Array1<f64>| array![x[0], x[1]];
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .maxiter(10)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "already at optimum: {}", report.message);
        assert!(report.fun < 1e-14);
    }

    /// A callback returning `Stop` ends the run with the "stopped by callback" message.
    #[test]
    fn test_callback_stop() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(1000)
            .callback(Box::new(|inter| {
                // Request a stop from the fourth iteration (index 3) onwards.
                if inter.iter >= 3 {
                    LMCallbackAction::Stop
                } else {
                    LMCallbackAction::Continue
                }
            }))
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(report.message, "stopped by callback");
    }

    /// Two residuals pull x towards 1 and 3; weighting the first one more heavily must move
    /// the solution from the midpoint towards 1.
    #[test]
    fn test_weighted_residuals() {
        let residual = |x: &Array1<f64>| array![x[0] - 1.0, x[0] - 3.0];
        let bounds = vec![(-10.0, 10.0)];

        // Unweighted run: the minimizer is the midpoint x = 2.
        let config_unw = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();
        let report_unw = levenberg_marquardt(&residual, &bounds, config_unw).unwrap();

        // Weighted run: weight 10 on the residual that vanishes at x = 1.
        let config_w = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .weights(array![10.0, 1.0])
            .build();
        let report_w = levenberg_marquardt(&residual, &bounds, config_w).unwrap();

        assert!(
            (report_unw.x[0] - 2.0).abs() < 0.01,
            "unweighted x should be ~2, got {}",
            report_unw.x[0]
        );
        assert!(
            report_w.x[0] < report_unw.x[0],
            "weighted x ({}) should be less than unweighted ({})",
            report_w.x[0],
            report_unw.x[0]
        );
        assert!(
            (report_w.x[0] - 1.0).abs() < 0.5,
            "weighted x should be near 1.0, got {}",
            report_w.x[0]
        );
    }

    /// Three bound pairs with a two-element x0 must be rejected.
    #[test]
    fn test_dimension_mismatch() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-1.0, 1.0); 3];
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::DimensionMismatch { .. }));
    }

    /// `nit` counts solver iterations, while `nfev` also includes the finite-difference
    /// evaluations and is therefore larger.
    #[test]
    fn test_nit_tracks_iterations_not_nfev() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        // Tiny tolerances are meant to keep the convergence tests from firing, so the loop
        // runs for the full `maxiter`.
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(5)
            .tol(1e-20)
            .atol(0.0)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(
            report.nit, 5,
            "nit should be maxiter (5), got {}",
            report.nit
        );
        assert!(
            report.nfev > report.nit,
            "nfev ({}) should be much larger than nit ({})",
            report.nfev,
            report.nit
        );
    }

    /// A bound pair with lower > upper must be rejected.
    #[test]
    fn test_invalid_bounds() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(5.0, 1.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::InvalidBounds { .. }));
    }

    /// The finite-difference step is eps·(1 + |x|) and never smaller than eps.
    #[test]
    fn test_jacobian_step_size_uses_relative_scale() {
        let eps = 1e-8;
        assert!((jacobian_step(eps, 0.0) - eps).abs() < 1e-15);
        assert!((jacobian_step(eps, 1.0) - 2.0e-8).abs() < 1e-15);
        let h_large = jacobian_step(eps, 1.0e12);
        assert!(h_large >= eps);
        assert!((h_large - eps * (1.0 + 1.0e12)).abs() < 1.0);
    }

    /// Large-magnitude parameters: x0 = 1e8 lies outside the bounds and is projected onto them
    /// before the solver converges to 1e6.
    #[test]
    fn test_large_x0_lm_converges() {
        let residual = |x: &Array1<f64>| array![x[0] - 1.0e6];
        let bounds = vec![(0.0, 2.0e6)];
        let config = LMConfigBuilder::new()
            .x0(array![1.0e8])
            .maxiter(100)
            .build();
        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            (report.x[0] - 1.0e6).abs() < 1.0,
            "x should be ~1e6, got {}",
            report.x[0]
        );
    }

    /// A matrix that differs from an exactly singular one by about 1e-15 in a single entry must
    /// be rejected by the linear solver.
    #[test]
    fn test_linear_solver_rejects_near_singular() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 1.0, 1.0 + 1e-15]).unwrap();
        let b = Array1::from(vec![1.0, 1.0]);
        assert!(
            solve_linear_system(&a, &b).is_none(),
            "near-singular matrix should be rejected"
        );
    }
}