1use ndarray::{Array1, Array2};
11use thiserror::Error;
12
13#[derive(Debug, Error)]
19pub enum LMError {
20 #[error("bounds/x0 dimension mismatch: x0 has {x0_len}, bounds has {bounds_len}")]
22 DimensionMismatch {
23 x0_len: usize,
25 bounds_len: usize,
27 },
28
29 #[error("invalid bounds at index {index}: lower ({lower}) > upper ({upper})")]
31 InvalidBounds {
32 index: usize,
34 lower: f64,
36 upper: f64,
38 },
39
40 #[error("residual dimension changed: was {expected}, now {got}")]
42 ResidualDimensionChanged {
43 expected: usize,
45 got: usize,
47 },
48
49 #[error("singular Jacobian — cannot compute step")]
51 SingularJacobian,
52}
53
54pub type LMResult<T> = std::result::Result<T, LMError>;
56
/// Verdict returned by a user callback after inspecting the solver state.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LMCallbackAction {
    /// Keep iterating.
    Continue,
    /// Abandon the iteration loop.
    Stop,
}
69
70pub struct LMIntermediate {
72 pub x: Array1<f64>,
74 pub fun: f64,
76 pub lambda: f64,
78 pub iter: usize,
80}
81
82pub struct LMConfig {
88 pub maxiter: usize,
90 pub tol: f64,
92 pub atol: f64,
94 pub lambda_init: f64,
96 pub jacobian_epsilon: f64,
98 pub x0: Array1<f64>,
100 pub weights: Option<Array1<f64>>,
102 pub disp: bool,
104 pub callback: Option<Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>>,
106}
107
108pub struct LMConfigBuilder {
110 maxiter: usize,
111 tol: f64,
112 atol: f64,
113 lambda_init: f64,
114 jacobian_epsilon: f64,
115 x0: Option<Array1<f64>>,
116 weights: Option<Array1<f64>>,
117 disp: bool,
118 callback: Option<Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>>,
119}
120
121impl LMConfigBuilder {
122 pub fn new() -> Self {
124 Self {
125 maxiter: 100,
126 tol: 1e-10,
127 atol: 1e-14,
128 lambda_init: 1.0,
129 jacobian_epsilon: 1e-8,
130 x0: None,
131 weights: None,
132 disp: false,
133 callback: None,
134 }
135 }
136
137 pub fn x0(mut self, x0: Array1<f64>) -> Self {
139 self.x0 = Some(x0);
140 self
141 }
142
143 pub fn maxiter(mut self, n: usize) -> Self {
145 self.maxiter = n;
146 self
147 }
148
149 pub fn tol(mut self, t: f64) -> Self {
151 self.tol = t;
152 self
153 }
154
155 pub fn atol(mut self, t: f64) -> Self {
157 self.atol = t;
158 self
159 }
160
161 pub fn lambda_init(mut self, l: f64) -> Self {
163 self.lambda_init = l;
164 self
165 }
166
167 pub fn jacobian_epsilon(mut self, eps: f64) -> Self {
169 self.jacobian_epsilon = eps;
170 self
171 }
172
173 pub fn weights(mut self, w: Array1<f64>) -> Self {
175 self.weights = Some(w);
176 self
177 }
178
179 pub fn disp(mut self, d: bool) -> Self {
181 self.disp = d;
182 self
183 }
184
185 pub fn callback(mut self, cb: Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>) -> Self {
187 self.callback = Some(cb);
188 self
189 }
190
191 pub fn build(self) -> LMConfig {
193 LMConfig {
194 maxiter: self.maxiter,
195 tol: self.tol,
196 atol: self.atol,
197 lambda_init: self.lambda_init,
198 jacobian_epsilon: self.jacobian_epsilon,
199 x0: self.x0.expect("LMConfigBuilder: x0 is required"),
200 weights: self.weights,
201 disp: self.disp,
202 callback: self.callback,
203 }
204 }
205}
206
207impl Default for LMConfigBuilder {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213#[derive(Debug)]
219pub struct LMReport {
220 pub x: Array1<f64>,
222 pub fun: f64,
224 pub residuals: Array1<f64>,
226 pub success: bool,
228 pub message: String,
230 pub nit: usize,
232 pub nfev: usize,
234}
235
236pub fn levenberg_marquardt<F>(
249 residual_fn: &F,
250 bounds: &[(f64, f64)],
251 config: LMConfig,
252) -> LMResult<LMReport>
253where
254 F: Fn(&Array1<f64>) -> Array1<f64>,
255{
256 let n_params = config.x0.len();
257
258 if bounds.len() != n_params {
260 return Err(LMError::DimensionMismatch {
261 x0_len: n_params,
262 bounds_len: bounds.len(),
263 });
264 }
265 for (i, &(lo, hi)) in bounds.iter().enumerate() {
266 if lo > hi {
267 return Err(LMError::InvalidBounds {
268 index: i,
269 lower: lo,
270 upper: hi,
271 });
272 }
273 }
274
275 let mut x = project(&config.x0, bounds);
277 let mut r = residual_fn(&x);
278 let n_residuals = r.len();
279 let mut nfev: usize = 1;
280
281 let w = config.weights.unwrap_or_else(|| Array1::ones(n_residuals));
282 let mut f_val = weighted_sos(&r, &w);
283 let mut lambda = config.lambda_init;
284 let eps = config.jacobian_epsilon;
285
286 let mut success = false;
287 let mut message = format!("max iterations ({}) reached", config.maxiter);
288 let mut callback = config.callback;
289 let mut nit: usize = 0;
290
291 for iter in 0..config.maxiter {
292 nit = iter + 1;
293 if let Some(ref mut cb) = callback {
295 let intermediate = LMIntermediate {
296 x: x.clone(),
297 fun: f_val,
298 lambda,
299 iter,
300 };
301 if cb(&intermediate) == LMCallbackAction::Stop {
302 message = "stopped by callback".to_string();
303 break;
304 }
305 }
306
307 if f_val <= config.atol {
309 success = true;
310 message = format!("converged: objective {:.2e} <= atol {:.2e}", f_val, config.atol);
311 break;
312 }
313
314 let jac = compute_jacobian(residual_fn, &x, &r, n_residuals, eps, &mut nfev)?;
316
317 let jtwj = jtw_j(&jac, &w);
321 let jtwr = jtw_r(&jac, &w, &r); let diag: Array1<f64> = (0..n_params)
325 .map(|j| jtwj[[j, j]].max(1e-20))
326 .collect();
327
328 let mut h = jtwj.clone();
330 for j in 0..n_params {
331 h[[j, j]] += lambda * diag[j];
332 }
333
334 let neg_jtwr = jtwr.mapv(|v| -v);
336
337 let delta = match solve_linear_system(&h, &neg_jtwr) {
339 Some(d) => d,
340 None => {
341 lambda *= 4.0;
343 continue;
344 }
345 };
346
347 let x_trial = project(&(&x + &delta), bounds);
349 let r_trial = residual_fn(&x_trial);
350 nfev += 1;
351
352 if r_trial.len() != n_residuals {
353 return Err(LMError::ResidualDimensionChanged {
354 expected: n_residuals,
355 got: r_trial.len(),
356 });
357 }
358
359 let f_trial = weighted_sos(&r_trial, &w);
360
361 let pred: f64 = {
365 let mut p = 0.0;
366 for j in 0..n_params {
367 let mut jtwj_delta_j = 0.0;
368 for k in 0..n_params {
369 jtwj_delta_j += jtwj[[j, k]] * delta[k];
370 }
371 p += delta[j] * jtwj_delta_j + 2.0 * lambda * diag[j] * delta[j] * delta[j];
372 }
373 p
374 };
375
376 let actual = f_val - f_trial;
377 let rho = if pred.abs() > 1e-30 { actual / pred } else { 0.0 };
378
379 if rho > 0.0 && f_trial < f_val {
380 let f_old = f_val;
382 x = x_trial;
383 r = r_trial;
384 f_val = f_trial;
385 lambda = (lambda / 2.0).max(1e-15);
386
387 if (f_old - f_val).abs() < config.tol * f_old + config.atol {
389 success = true;
390 message = format!(
391 "converged: |df|={:.2e} < tol*f+atol={:.2e}",
392 (f_old - f_val).abs(),
393 config.tol * f_old + config.atol
394 );
395 break;
396 }
397 } else {
398 lambda = (lambda * 2.0).min(1e15);
400 }
401
402 if config.disp && iter % 10 == 0 {
403 eprintln!(
404 "LM iter {}: f={:.6e}, lambda={:.2e}, rho={:.3}",
405 iter, f_val, lambda, rho
406 );
407 }
408 }
409
410 Ok(LMReport {
411 x,
412 fun: f_val,
413 residuals: r,
414 success,
415 message,
416 nit,
417 nfev,
418 })
419}
420
421fn project(x: &Array1<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
427 Array1::from(
428 x.iter()
429 .zip(bounds.iter())
430 .map(|(&xi, &(lo, hi))| xi.clamp(lo, hi))
431 .collect::<Vec<_>>(),
432 )
433}
434
435fn weighted_sos(r: &Array1<f64>, w: &Array1<f64>) -> f64 {
437 r.iter().zip(w.iter()).map(|(&ri, &wi)| wi * ri * ri).sum()
438}
439
440fn compute_jacobian<F>(
442 residual_fn: &F,
443 x: &Array1<f64>,
444 _r0: &Array1<f64>,
445 n_residuals: usize,
446 eps: f64,
447 nfev: &mut usize,
448) -> LMResult<Array2<f64>>
449where
450 F: Fn(&Array1<f64>) -> Array1<f64>,
451{
452 let n_params = x.len();
453 let mut jac = Array2::zeros((n_residuals, n_params));
454
455 for j in 0..n_params {
456 let mut x_plus = x.clone();
457 let mut x_minus = x.clone();
458 let h = eps.max(eps * x[j].abs());
459 x_plus[j] += h;
460 x_minus[j] -= h;
461
462 let r_plus = residual_fn(&x_plus);
463 let r_minus = residual_fn(&x_minus);
464 *nfev += 2;
465
466 if r_plus.len() != n_residuals {
467 return Err(LMError::ResidualDimensionChanged {
468 expected: n_residuals,
469 got: r_plus.len(),
470 });
471 }
472 if r_minus.len() != n_residuals {
473 return Err(LMError::ResidualDimensionChanged {
474 expected: n_residuals,
475 got: r_minus.len(),
476 });
477 }
478
479 let inv_2h = 1.0 / (2.0 * h);
480 for i in 0..n_residuals {
481 jac[[i, j]] = (r_plus[i] - r_minus[i]) * inv_2h;
482 }
483 }
484
485 Ok(jac)
486}
487
488fn jtw_j(jac: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
490 let n_params = jac.ncols();
491 let n_res = jac.nrows();
492 let mut result = Array2::zeros((n_params, n_params));
493
494 for j in 0..n_params {
495 for k in j..n_params {
496 let mut val = 0.0;
497 for i in 0..n_res {
498 val += w[i] * jac[[i, j]] * jac[[i, k]];
499 }
500 result[[j, k]] = val;
501 result[[k, j]] = val;
502 }
503 }
504
505 result
506}
507
508fn jtw_r(jac: &Array2<f64>, w: &Array1<f64>, r: &Array1<f64>) -> Array1<f64> {
510 let n_params = jac.ncols();
511 let n_res = jac.nrows();
512 let mut result = Array1::zeros(n_params);
513
514 for j in 0..n_params {
515 let mut val = 0.0;
516 for i in 0..n_res {
517 val += w[i] * jac[[i, j]] * r[i];
518 }
519 result[j] = val;
520 }
521
522 result
523}
524
525fn solve_linear_system(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
528 let n = b.len();
529 debug_assert_eq!(a.nrows(), n);
530 debug_assert_eq!(a.ncols(), n);
531
532 let mut aug = Array2::zeros((n, n + 1));
534 for i in 0..n {
535 for j in 0..n {
536 aug[[i, j]] = a[[i, j]];
537 }
538 aug[[i, n]] = b[i];
539 }
540
541 for col in 0..n {
543 let mut max_val = aug[[col, col]].abs();
545 let mut max_row = col;
546 for row in (col + 1)..n {
547 let val = aug[[row, col]].abs();
548 if val > max_val {
549 max_val = val;
550 max_row = row;
551 }
552 }
553
554 if max_val < 1e-30 {
555 return None; }
557
558 if max_row != col {
560 for j in 0..=n {
561 let tmp = aug[[col, j]];
562 aug[[col, j]] = aug[[max_row, j]];
563 aug[[max_row, j]] = tmp;
564 }
565 }
566
567 let pivot = aug[[col, col]];
569 for row in (col + 1)..n {
570 let factor = aug[[row, col]] / pivot;
571 for j in col..=n {
572 aug[[row, j]] -= factor * aug[[col, j]];
573 }
574 }
575 }
576
577 let mut x = Array1::zeros(n);
579 for col in (0..n).rev() {
580 let mut sum = aug[[col, n]];
581 for j in (col + 1)..n {
582 sum -= aug[[col, j]] * x[j];
583 }
584 x[col] = sum / aug[[col, col]];
585 }
586
587 if x.iter().any(|v| !v.is_finite()) {
589 return None;
590 }
591
592 Some(x)
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Identity residuals: the minimiser is the origin.
    #[test]
    fn test_sphere() {
        let identity = |p: &Array1<f64>| p.clone();
        let limits = vec![(-10.0, 10.0); 3];
        let cfg = LMConfigBuilder::new()
            .x0(array![3.0, -2.0, 1.0])
            .maxiter(50)
            .build();

        let outcome = levenberg_marquardt(&identity, &limits, cfg).unwrap();

        assert!(outcome.success, "should converge: {}", outcome.message);
        assert!(outcome.fun < 1e-12, "objective should be ~0, got {}", outcome.fun);
        for &xi in outcome.x.iter() {
            assert!(xi.abs() < 1e-6, "x should be ~0, got {}", xi);
        }
    }

    /// Rosenbrock written as two residuals; the minimiser is (1, 1).
    #[test]
    fn test_rosenbrock_residual() {
        let rosenbrock = |p: &Array1<f64>| array![10.0 * (p[1] - p[0] * p[0]), 1.0 - p[0]];
        let limits = vec![(-5.0, 5.0); 2];
        let cfg = LMConfigBuilder::new()
            .x0(array![-1.0, 1.0])
            .maxiter(200)
            .tol(1e-12)
            .build();

        let outcome = levenberg_marquardt(&rosenbrock, &limits, cfg).unwrap();

        assert!(outcome.success, "should converge: {}", outcome.message);
        assert!((outcome.x[0] - 1.0).abs() < 1e-4, "x0 should be ~1, got {}", outcome.x[0]);
        assert!((outcome.x[1] - 1.0).abs() < 1e-4, "x1 should be ~1, got {}", outcome.x[1]);
    }

    /// The unconstrained minimiser (5) lies above the upper bound (3).
    #[test]
    fn test_bounded_solution() {
        let shifted = |p: &Array1<f64>| array![p[0] - 5.0];
        let limits = vec![(-10.0, 3.0)];
        let cfg = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .build();

        let outcome = levenberg_marquardt(&shifted, &limits, cfg).unwrap();

        assert!(
            (outcome.x[0] - 3.0).abs() < 1e-6,
            "x should be at bound 3.0, got {}",
            outcome.x[0]
        );
    }

    /// A residual function that yields NaN far from the origin must not make
    /// the solver return an error.
    #[test]
    fn test_nan_residual_handled() {
        let sometimes_nan = |p: &Array1<f64>| {
            if p[0].abs() > 100.0 {
                array![f64::NAN]
            } else {
                array![p[0] - 1.0]
            }
        };
        let limits = vec![(-200.0, 200.0)];
        let cfg = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .build();

        let outcome = levenberg_marquardt(&sometimes_nan, &limits, cfg);

        assert!(outcome.is_ok());
    }

    /// Starting exactly at the optimum must be reported as success.
    #[test]
    fn test_zero_residual() {
        let pass_through = |p: &Array1<f64>| array![p[0], p[1]];
        let limits = vec![(-10.0, 10.0); 2];
        let cfg = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .maxiter(10)
            .build();

        let outcome = levenberg_marquardt(&pass_through, &limits, cfg).unwrap();

        assert!(outcome.success, "already at optimum: {}", outcome.message);
        assert!(outcome.fun < 1e-14);
    }

    /// A callback that answers `Stop` ends the run with the matching message.
    #[test]
    fn test_callback_stop() {
        let identity = |p: &Array1<f64>| p.clone();
        let limits = vec![(-10.0, 10.0); 2];
        let stop_after_three = Box::new(|state: &LMIntermediate| {
            if state.iter >= 3 {
                LMCallbackAction::Stop
            } else {
                LMCallbackAction::Continue
            }
        });
        let cfg = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(1000)
            .callback(stop_after_three)
            .build();

        let outcome = levenberg_marquardt(&identity, &limits, cfg).unwrap();

        assert_eq!(outcome.message, "stopped by callback");
    }

    /// Weighting the first residual more heavily pulls the solution towards
    /// its root (1) and away from the unweighted midpoint (2).
    #[test]
    fn test_weighted_residuals() {
        let two_targets = |p: &Array1<f64>| array![p[0] - 1.0, p[0] - 3.0];
        let limits = vec![(-10.0, 10.0)];

        let plain_cfg = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .build();
        let plain = levenberg_marquardt(&two_targets, &limits, plain_cfg).unwrap();

        let weighted_cfg = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .weights(array![10.0, 1.0])
            .build();
        let weighted = levenberg_marquardt(&two_targets, &limits, weighted_cfg).unwrap();

        assert!(
            (plain.x[0] - 2.0).abs() < 0.01,
            "unweighted x should be ~2, got {}",
            plain.x[0]
        );
        assert!(
            weighted.x[0] < plain.x[0],
            "weighted x ({}) should be less than unweighted ({})",
            weighted.x[0],
            plain.x[0]
        );
        assert!(
            (weighted.x[0] - 1.0).abs() < 0.5,
            "weighted x should be near 1.0, got {}",
            weighted.x[0]
        );
    }

    /// Three bounds for a two-parameter start must be rejected.
    #[test]
    fn test_dimension_mismatch() {
        let identity = |p: &Array1<f64>| p.clone();
        let limits = vec![(-1.0, 1.0); 3];
        let cfg = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .build();

        let failure = levenberg_marquardt(&identity, &limits, cfg).unwrap_err();

        assert!(matches!(failure, LMError::DimensionMismatch { .. }));
    }

    /// `nit` counts solver iterations, while `nfev` also counts Jacobian probes.
    #[test]
    fn test_nit_tracks_iterations_not_nfev() {
        let identity = |p: &Array1<f64>| p.clone();
        let limits = vec![(-10.0, 10.0); 2];
        let cfg = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(5)
            .tol(1e-20)
            .atol(0.0)
            .build();

        let outcome = levenberg_marquardt(&identity, &limits, cfg).unwrap();

        assert_eq!(outcome.nit, 5, "nit should be maxiter (5), got {}", outcome.nit);
        assert!(
            outcome.nfev > outcome.nit,
            "nfev ({}) should be much larger than nit ({})",
            outcome.nfev,
            outcome.nit
        );
    }

    /// A pair whose lower bound exceeds its upper bound must be rejected.
    #[test]
    fn test_invalid_bounds() {
        let identity = |p: &Array1<f64>| p.clone();
        let limits = vec![(5.0, 1.0)];
        let cfg = LMConfigBuilder::new().x0(array![0.0]).build();

        let failure = levenberg_marquardt(&identity, &limits, cfg).unwrap_err();

        assert!(matches!(failure, LMError::InvalidBounds { .. }));
    }
}