1use ndarray::{Array1, Array2};
11use thiserror::Error;
12
/// Errors that the Levenberg–Marquardt solver in this file can report.
#[derive(Debug, Error)]
pub enum LMError {
    /// The starting point and the bounds slice do not have the same length.
    #[error("bounds/x0 dimension mismatch: x0 has {x0_len}, bounds has {bounds_len}")]
    DimensionMismatch {
        /// Number of elements in `x0`.
        x0_len: usize,
        /// Number of `(lower, upper)` pairs supplied.
        bounds_len: usize,
    },

    /// A `(lower, upper)` pair was rejected during bounds validation.
    #[error("invalid bounds at index {index}: lower ({lower}) > upper ({upper})")]
    InvalidBounds {
        /// Position of the offending pair in the bounds slice.
        index: usize,
        /// Lower bound of the offending pair.
        lower: f64,
        /// Upper bound of the offending pair.
        upper: f64,
    },

    /// The residual function returned a vector whose length differs from the
    /// length observed on the first evaluation.
    #[error("residual dimension changed: was {expected}, now {got}")]
    ResidualDimensionChanged {
        /// Length of the residual vector at the first evaluation.
        expected: usize,
        /// Length of the residual vector that triggered the error.
        got: usize,
    },

    /// The Jacobian is singular and no step can be computed.
    ///
    /// NOTE(review): no code visible in this file constructs this variant;
    /// the solver raises the damping and retries instead.
    #[error("singular Jacobian — cannot compute step")]
    SingularJacobian,
}
53
/// Convenience alias for results whose error type is [`LMError`].
pub type LMResult<T> = std::result::Result<T, LMError>;
56
/// Value returned by a user callback to tell the solver whether to go on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LMCallbackAction {
    /// Keep iterating.
    Continue,
    /// Leave the iteration loop; the solver reports "stopped by callback".
    Stop,
}
69
70pub struct LMIntermediate {
72 pub x: Array1<f64>,
74 pub fun: f64,
76 pub lambda: f64,
78 pub iter: usize,
80}
81
/// Boxed user callback: receives the current [`LMIntermediate`] and decides
/// whether the solver continues or stops.
pub type LMCallback = Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>;
88
/// Complete solver configuration consumed by `levenberg_marquardt`.
///
/// Usually produced through `LMConfigBuilder`.
pub struct LMConfig {
    /// Upper limit on the number of iterations of the main loop.
    pub maxiter: usize,
    /// Relative tolerance in the convergence test
    /// `|f_old - f_new| < tol * f_old + atol`.
    pub tol: f64,
    /// Absolute tolerance: used in the test above, and the solver also stops
    /// as soon as the objective is `<= atol`.
    pub atol: f64,
    /// Initial value of the damping parameter.
    pub lambda_init: f64,
    /// Scale of the finite-difference step used for the Jacobian; the step
    /// for parameter `j` is `max(eps, eps * |x_j|)`.
    pub jacobian_epsilon: f64,
    /// Starting point; it is clamped into the bounds before the first
    /// residual evaluation.
    pub x0: Array1<f64>,
    /// Optional per-residual weights; when `None`, all weights are one.
    pub weights: Option<Array1<f64>>,
    /// When `true`, progress is printed to stderr on every tenth iteration.
    pub disp: bool,
    /// Optional callback invoked at the start of every iteration.
    pub callback: Option<LMCallback>,
}
110
/// Builder for [`LMConfig`].
///
/// Every field except `x0` has a default (see `LMConfigBuilder::new`);
/// `x0` must be supplied before calling `build`.
pub struct LMConfigBuilder {
    // Mirrors of the `LMConfig` fields; `x0` stays `None` until the caller
    // provides a starting point.
    maxiter: usize,
    tol: f64,
    atol: f64,
    lambda_init: f64,
    jacobian_epsilon: f64,
    x0: Option<Array1<f64>>,
    weights: Option<Array1<f64>>,
    disp: bool,
    callback: Option<LMCallback>,
}
123
124impl LMConfigBuilder {
125 pub fn new() -> Self {
127 Self {
128 maxiter: 100,
129 tol: 1e-10,
130 atol: 1e-14,
131 lambda_init: 1.0,
132 jacobian_epsilon: 1e-8,
133 x0: None,
134 weights: None,
135 disp: false,
136 callback: None,
137 }
138 }
139
140 pub fn x0(mut self, x0: Array1<f64>) -> Self {
142 self.x0 = Some(x0);
143 self
144 }
145
146 pub fn maxiter(mut self, n: usize) -> Self {
148 self.maxiter = n;
149 self
150 }
151
152 pub fn tol(mut self, t: f64) -> Self {
154 self.tol = t;
155 self
156 }
157
158 pub fn atol(mut self, t: f64) -> Self {
160 self.atol = t;
161 self
162 }
163
164 pub fn lambda_init(mut self, l: f64) -> Self {
166 self.lambda_init = l;
167 self
168 }
169
170 pub fn jacobian_epsilon(mut self, eps: f64) -> Self {
172 self.jacobian_epsilon = eps;
173 self
174 }
175
176 pub fn weights(mut self, w: Array1<f64>) -> Self {
178 self.weights = Some(w);
179 self
180 }
181
182 pub fn disp(mut self, d: bool) -> Self {
184 self.disp = d;
185 self
186 }
187
188 pub fn callback(mut self, cb: Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>) -> Self {
190 self.callback = Some(cb);
191 self
192 }
193
194 pub fn build(self) -> LMConfig {
196 LMConfig {
197 maxiter: self.maxiter,
198 tol: self.tol,
199 atol: self.atol,
200 lambda_init: self.lambda_init,
201 jacobian_epsilon: self.jacobian_epsilon,
202 x0: self.x0.expect("LMConfigBuilder: x0 is required"),
203 weights: self.weights,
204 disp: self.disp,
205 callback: self.callback,
206 }
207 }
208}
209
210impl Default for LMConfigBuilder {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
/// Outcome of a `levenberg_marquardt` run.
#[derive(Debug)]
pub struct LMReport {
    /// Final parameter vector.
    pub x: Array1<f64>,
    /// Weighted sum-of-squares objective at `x`.
    pub fun: f64,
    /// Residual vector at `x`.
    pub residuals: Array1<f64>,
    /// `true` when one of the convergence tests was satisfied; `false` when
    /// the iteration limit was hit or the callback requested a stop.
    pub success: bool,
    /// Human-readable description of why the solver stopped.
    pub message: String,
    /// Number of iterations started.
    pub nit: usize,
    /// Number of residual-function evaluations, including those spent on the
    /// finite-difference Jacobian.
    pub nfev: usize,
}
238
239pub fn levenberg_marquardt<F>(
252 residual_fn: &F,
253 bounds: &[(f64, f64)],
254 config: LMConfig,
255) -> LMResult<LMReport>
256where
257 F: Fn(&Array1<f64>) -> Array1<f64>,
258{
259 let n_params = config.x0.len();
260
261 if bounds.len() != n_params {
263 return Err(LMError::DimensionMismatch {
264 x0_len: n_params,
265 bounds_len: bounds.len(),
266 });
267 }
268 for (i, &(lo, hi)) in bounds.iter().enumerate() {
269 if lo > hi {
270 return Err(LMError::InvalidBounds {
271 index: i,
272 lower: lo,
273 upper: hi,
274 });
275 }
276 }
277
278 let mut x = project(&config.x0, bounds);
280 let mut r = residual_fn(&x);
281 let n_residuals = r.len();
282 let mut nfev: usize = 1;
283
284 let w = config.weights.unwrap_or_else(|| Array1::ones(n_residuals));
285 let mut f_val = weighted_sos(&r, &w);
286 let mut lambda = config.lambda_init;
287 let eps = config.jacobian_epsilon;
288
289 let mut success = false;
290 let mut message = format!("max iterations ({}) reached", config.maxiter);
291 let mut callback = config.callback;
292 let mut nit: usize = 0;
293
294 for iter in 0..config.maxiter {
295 nit = iter + 1;
296 if let Some(ref mut cb) = callback {
298 let intermediate = LMIntermediate {
299 x: x.clone(),
300 fun: f_val,
301 lambda,
302 iter,
303 };
304 if cb(&intermediate) == LMCallbackAction::Stop {
305 message = "stopped by callback".to_string();
306 break;
307 }
308 }
309
310 if f_val <= config.atol {
312 success = true;
313 message = format!(
314 "converged: objective {:.2e} <= atol {:.2e}",
315 f_val, config.atol
316 );
317 break;
318 }
319
320 let jac = compute_jacobian(residual_fn, &x, &r, n_residuals, eps, &mut nfev)?;
322
323 let jtwj = jtw_j(&jac, &w);
327 let jtwr = jtw_r(&jac, &w, &r); let diag: Array1<f64> = (0..n_params).map(|j| jtwj[[j, j]].max(1e-20)).collect();
331
332 let mut h = jtwj.clone();
334 for j in 0..n_params {
335 h[[j, j]] += lambda * diag[j];
336 }
337
338 let neg_jtwr = jtwr.mapv(|v| -v);
340
341 let delta = match solve_linear_system(&h, &neg_jtwr) {
343 Some(d) => d,
344 None => {
345 lambda *= 4.0;
347 continue;
348 }
349 };
350
351 let x_trial = project(&(&x + &delta), bounds);
353 let r_trial = residual_fn(&x_trial);
354 nfev += 1;
355
356 if r_trial.len() != n_residuals {
357 return Err(LMError::ResidualDimensionChanged {
358 expected: n_residuals,
359 got: r_trial.len(),
360 });
361 }
362
363 let f_trial = weighted_sos(&r_trial, &w);
364
365 let pred: f64 = {
369 let mut p = 0.0;
370 for j in 0..n_params {
371 let mut jtwj_delta_j = 0.0;
372 for k in 0..n_params {
373 jtwj_delta_j += jtwj[[j, k]] * delta[k];
374 }
375 p += delta[j] * jtwj_delta_j + 2.0 * lambda * diag[j] * delta[j] * delta[j];
376 }
377 p
378 };
379
380 let actual = f_val - f_trial;
381 let rho = if pred.abs() > 1e-30 {
382 actual / pred
383 } else {
384 0.0
385 };
386
387 if rho > 0.0 && f_trial < f_val {
388 let f_old = f_val;
390 x = x_trial;
391 r = r_trial;
392 f_val = f_trial;
393 lambda = (lambda / 2.0).max(1e-15);
394
395 if (f_old - f_val).abs() < config.tol * f_old + config.atol {
397 success = true;
398 message = format!(
399 "converged: |df|={:.2e} < tol*f+atol={:.2e}",
400 (f_old - f_val).abs(),
401 config.tol * f_old + config.atol
402 );
403 break;
404 }
405 } else {
406 lambda = (lambda * 2.0).min(1e15);
408 }
409
410 if config.disp && iter % 10 == 0 {
411 eprintln!(
412 "LM iter {}: f={:.6e}, lambda={:.2e}, rho={:.3}",
413 iter, f_val, lambda, rho
414 );
415 }
416 }
417
418 Ok(LMReport {
419 x,
420 fun: f_val,
421 residuals: r,
422 success,
423 message,
424 nit,
425 nfev,
426 })
427}
428
429fn project(x: &Array1<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
435 Array1::from(
436 x.iter()
437 .zip(bounds.iter())
438 .map(|(&xi, &(lo, hi))| xi.clamp(lo, hi))
439 .collect::<Vec<_>>(),
440 )
441}
442
443fn weighted_sos(r: &Array1<f64>, w: &Array1<f64>) -> f64 {
445 r.iter().zip(w.iter()).map(|(&ri, &wi)| wi * ri * ri).sum()
446}
447
448fn compute_jacobian<F>(
450 residual_fn: &F,
451 x: &Array1<f64>,
452 _r0: &Array1<f64>,
453 n_residuals: usize,
454 eps: f64,
455 nfev: &mut usize,
456) -> LMResult<Array2<f64>>
457where
458 F: Fn(&Array1<f64>) -> Array1<f64>,
459{
460 let n_params = x.len();
461 let mut jac = Array2::zeros((n_residuals, n_params));
462
463 for j in 0..n_params {
464 let mut x_plus = x.clone();
465 let mut x_minus = x.clone();
466 let h = eps.max(eps * x[j].abs());
467 x_plus[j] += h;
468 x_minus[j] -= h;
469
470 let r_plus = residual_fn(&x_plus);
471 let r_minus = residual_fn(&x_minus);
472 *nfev += 2;
473
474 if r_plus.len() != n_residuals {
475 return Err(LMError::ResidualDimensionChanged {
476 expected: n_residuals,
477 got: r_plus.len(),
478 });
479 }
480 if r_minus.len() != n_residuals {
481 return Err(LMError::ResidualDimensionChanged {
482 expected: n_residuals,
483 got: r_minus.len(),
484 });
485 }
486
487 let inv_2h = 1.0 / (2.0 * h);
488 for i in 0..n_residuals {
489 jac[[i, j]] = (r_plus[i] - r_minus[i]) * inv_2h;
490 }
491 }
492
493 Ok(jac)
494}
495
496fn jtw_j(jac: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
498 let n_params = jac.ncols();
499 let n_res = jac.nrows();
500 let mut result = Array2::zeros((n_params, n_params));
501
502 for j in 0..n_params {
503 for k in j..n_params {
504 let mut val = 0.0;
505 for i in 0..n_res {
506 val += w[i] * jac[[i, j]] * jac[[i, k]];
507 }
508 result[[j, k]] = val;
509 result[[k, j]] = val;
510 }
511 }
512
513 result
514}
515
516fn jtw_r(jac: &Array2<f64>, w: &Array1<f64>, r: &Array1<f64>) -> Array1<f64> {
518 let n_params = jac.ncols();
519 let n_res = jac.nrows();
520 let mut result = Array1::zeros(n_params);
521
522 for j in 0..n_params {
523 let mut val = 0.0;
524 for i in 0..n_res {
525 val += w[i] * jac[[i, j]] * r[i];
526 }
527 result[j] = val;
528 }
529
530 result
531}
532
533fn solve_linear_system(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
536 let n = b.len();
537 debug_assert_eq!(a.nrows(), n);
538 debug_assert_eq!(a.ncols(), n);
539
540 let mut aug = Array2::zeros((n, n + 1));
542 for i in 0..n {
543 for j in 0..n {
544 aug[[i, j]] = a[[i, j]];
545 }
546 aug[[i, n]] = b[i];
547 }
548
549 for col in 0..n {
551 let mut max_val = aug[[col, col]].abs();
553 let mut max_row = col;
554 for row in (col + 1)..n {
555 let val = aug[[row, col]].abs();
556 if val > max_val {
557 max_val = val;
558 max_row = row;
559 }
560 }
561
562 if max_val < 1e-30 {
563 return None; }
565
566 if max_row != col {
568 for j in 0..=n {
569 let tmp = aug[[col, j]];
570 aug[[col, j]] = aug[[max_row, j]];
571 aug[[max_row, j]] = tmp;
572 }
573 }
574
575 let pivot = aug[[col, col]];
577 for row in (col + 1)..n {
578 let factor = aug[[row, col]] / pivot;
579 for j in col..=n {
580 aug[[row, j]] -= factor * aug[[col, j]];
581 }
582 }
583 }
584
585 let mut x = Array1::zeros(n);
587 for col in (0..n).rev() {
588 let mut sum = aug[[col, n]];
589 for j in (col + 1)..n {
590 sum -= aug[[col, j]] * x[j];
591 }
592 x[col] = sum / aug[[col, col]];
593 }
594
595 if x.iter().any(|v| !v.is_finite()) {
597 return None;
598 }
599
600 Some(x)
601}
602
// Unit tests for the bounded Levenberg–Marquardt solver.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Residual `r(x) = x`: the objective is `sum x_i^2`, minimised at the
    /// origin, which lies inside the bounds.
    #[test]
    fn test_sphere() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 3];
        let config = LMConfigBuilder::new()
            .x0(array![3.0, -2.0, 1.0])
            .maxiter(50)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            report.fun < 1e-12,
            "objective should be ~0, got {}",
            report.fun
        );
        for &xi in report.x.iter() {
            assert!(xi.abs() < 1e-6, "x should be ~0, got {}", xi);
        }
    }

    /// Rosenbrock in residual form: both residuals vanish at `(1, 1)`.
    #[test]
    fn test_rosenbrock_residual() {
        let residual = |x: &Array1<f64>| array![10.0 * (x[1] - x[0] * x[0]), 1.0 - x[0]];
        let bounds = vec![(-5.0, 5.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![-1.0, 1.0])
            .maxiter(200)
            .tol(1e-12)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            (report.x[0] - 1.0).abs() < 1e-4,
            "x0 should be ~1, got {}",
            report.x[0]
        );
        assert!(
            (report.x[1] - 1.0).abs() < 1e-4,
            "x1 should be ~1, got {}",
            report.x[1]
        );
    }

    /// The unconstrained minimiser (x = 5) lies above the upper bound, so
    /// the solution is expected on the bound at 3. Only the location is
    /// asserted, not `success`.
    #[test]
    fn test_bounded_solution() {
        let residual = |x: &Array1<f64>| array![x[0] - 5.0];
        let bounds = vec![(-10.0, 3.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(
            (report.x[0] - 3.0).abs() < 1e-6,
            "x should be at bound 3.0, got {}",
            report.x[0]
        );
    }

    /// The residual returns NaN for |x| > 100; the solver must still return
    /// `Ok` rather than an error.
    #[test]
    fn test_nan_residual_handled() {
        let residual = |x: &Array1<f64>| {
            if x[0].abs() > 100.0 {
                array![f64::NAN]
            } else {
                array![x[0] - 1.0]
            }
        };
        let bounds = vec![(-200.0, 200.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        let result = levenberg_marquardt(&residual, &bounds, config);
        assert!(result.is_ok());
    }

    /// Starting exactly at the optimum: the objective is already zero, so
    /// the `atol` test should report success.
    #[test]
    fn test_zero_residual() {
        let residual = |x: &Array1<f64>| array![x[0], x[1]];
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .maxiter(10)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "already at optimum: {}", report.message);
        assert!(report.fun < 1e-14);
    }

    /// A callback that answers `Stop` once `iter >= 3` must end the run with
    /// the "stopped by callback" message.
    #[test]
    fn test_callback_stop() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(1000)
            .callback(Box::new(|inter| {
                if inter.iter >= 3 {
                    LMCallbackAction::Stop
                } else {
                    LMCallbackAction::Continue
                }
            }))
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(report.message, "stopped by callback");
    }

    /// Residuals `x - 1` and `x - 3`: with equal weights the minimiser is 2;
    /// weighting the first residual by 10 pulls it towards 1.
    #[test]
    fn test_weighted_residuals() {
        let residual = |x: &Array1<f64>| array![x[0] - 1.0, x[0] - 3.0];
        let bounds = vec![(-10.0, 10.0)];

        // Unweighted run.
        let config_unw = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();
        let report_unw = levenberg_marquardt(&residual, &bounds, config_unw).unwrap();

        // Weighted run: the first residual counts ten times as much.
        let config_w = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .weights(array![10.0, 1.0])
            .build();
        let report_w = levenberg_marquardt(&residual, &bounds, config_w).unwrap();

        assert!(
            (report_unw.x[0] - 2.0).abs() < 0.01,
            "unweighted x should be ~2, got {}",
            report_unw.x[0]
        );
        assert!(
            report_w.x[0] < report_unw.x[0],
            "weighted x ({}) should be less than unweighted ({})",
            report_w.x[0],
            report_unw.x[0]
        );
        assert!(
            (report_w.x[0] - 1.0).abs() < 0.5,
            "weighted x should be near 1.0, got {}",
            report_w.x[0]
        );
    }

    /// Three bounds for a two-element starting point must be rejected.
    #[test]
    fn test_dimension_mismatch() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-1.0, 1.0); 3];
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::DimensionMismatch { .. }));
    }

    /// `nit` counts solver iterations, while `nfev` also includes the
    /// residual evaluations spent on the Jacobian. The tolerances are chosen
    /// so that neither convergence test is expected to fire within the five
    /// allowed iterations.
    #[test]
    fn test_nit_tracks_iterations_not_nfev() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(5)
            .tol(1e-20)
            .atol(0.0)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(
            report.nit, 5,
            "nit should be maxiter (5), got {}",
            report.nit
        );
        assert!(
            report.nfev > report.nit,
            "nfev ({}) should be much larger than nit ({})",
            report.nfev,
            report.nit
        );
    }

    /// A pair with lower > upper must be rejected before any evaluation.
    #[test]
    fn test_invalid_bounds() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(5.0, 1.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::InvalidBounds { .. }));
    }
}