Skip to main content

math_audio_optimisation/
levenberg_marquardt.rs

1//! Levenberg-Marquardt bounded nonlinear least-squares solver.
2//!
3//! Solves `min_x  ||r(x)||^2_W` subject to `lb <= x <= ub`, where `r(x)` is a
4//! vector-valued residual function and `W` is a diagonal weight matrix.
5//!
6//! The solver interpolates between Gauss-Newton (fast near the minimum) and
7//! gradient descent (robust far from it) via an adaptive damping parameter `lambda`.
8//! Bounds are enforced by projecting the trial step.
9
10use ndarray::{Array1, Array2};
11use thiserror::Error;
12
13// ---------------------------------------------------------------------------
14// Error types
15// ---------------------------------------------------------------------------
16
17/// Errors from the Levenberg-Marquardt solver.
18#[derive(Debug, Error)]
19pub enum LMError {
20    /// Lower and upper bounds have different lengths, or don't match x0.
21    #[error("bounds/x0 dimension mismatch: x0 has {x0_len}, bounds has {bounds_len}")]
22    DimensionMismatch {
23        /// Length of x0
24        x0_len: usize,
25        /// Number of bound pairs
26        bounds_len: usize,
27    },
28
29    /// A lower bound exceeds its upper bound.
30    #[error("invalid bounds at index {index}: lower ({lower}) > upper ({upper})")]
31    InvalidBounds {
32        /// Index of the bad pair
33        index: usize,
34        /// Lower bound
35        lower: f64,
36        /// Upper bound
37        upper: f64,
38    },
39
40    /// Residual function returned different sizes on successive calls.
41    #[error("residual dimension changed: was {expected}, now {got}")]
42    ResidualDimensionChanged {
43        /// First observed size
44        expected: usize,
45        /// New size
46        got: usize,
47    },
48
49    /// The Jacobian-derived system is singular (QR rank-deficient).
50    #[error("singular Jacobian — cannot compute step")]
51    SingularJacobian,
52}
53
54/// Result alias for LM operations.
55pub type LMResult<T> = std::result::Result<T, LMError>;
56
57// ---------------------------------------------------------------------------
58// Callback
59// ---------------------------------------------------------------------------
60
/// Action returned by a progress callback.
///
/// The solver calls the callback at the start of every iteration and compares
/// the returned value against [`LMCallbackAction::Stop`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LMCallbackAction {
    /// Keep iterating.
    Continue,
    /// Stop early: the solver leaves its loop and reports the message
    /// `"stopped by callback"`.
    Stop,
}
69
70/// Snapshot of solver state passed to the callback.
71pub struct LMIntermediate {
72    /// Current parameter vector.
73    pub x: Array1<f64>,
74    /// Current objective value (weighted sum of squared residuals).
75    pub fun: f64,
76    /// Current damping parameter.
77    pub lambda: f64,
78    /// Iteration number.
79    pub iter: usize,
80}
81
82// ---------------------------------------------------------------------------
83// Config
84// ---------------------------------------------------------------------------
85
/// Progress callback type for the LM solver.
///
/// A boxed `FnMut` closure that receives a borrowed [`LMIntermediate`]
/// snapshot and returns an [`LMCallbackAction`].
pub type LMCallback = Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>;
88
/// Configuration for the LM solver.
///
/// Usually created through [`LMConfigBuilder`], which supplies the defaults
/// quoted on the fields below.
pub struct LMConfig {
    /// Maximum iterations (default 100).
    pub maxiter: usize,
    /// Relative convergence tolerance on the objective (default 1e-10).
    ///
    /// After an accepted step the solver stops when
    /// `|f_old - f_new| < tol * f_old + atol`.
    pub tol: f64,
    /// Absolute convergence tolerance on the objective (default 1e-14).
    ///
    /// Besides entering the test above, the solver also stops at the start of
    /// an iteration when the objective is already `<= atol`.
    pub atol: f64,
    /// Initial damping parameter (default 1.0).
    pub lambda_init: f64,
    /// Finite-difference step for Jacobian approximation (default 1e-8).
    pub jacobian_epsilon: f64,
    /// Initial guess (required).
    pub x0: Array1<f64>,
    /// Per-residual weights (optional; default = uniform, i.e. all ones).
    pub weights: Option<Array1<f64>>,
    /// Print progress to stderr.
    pub disp: bool,
    /// Optional progress callback.
    pub callback: Option<LMCallback>,
}
110
111/// Fluent builder for [`LMConfig`].
112pub struct LMConfigBuilder {
113    maxiter: usize,
114    tol: f64,
115    atol: f64,
116    lambda_init: f64,
117    jacobian_epsilon: f64,
118    x0: Option<Array1<f64>>,
119    weights: Option<Array1<f64>>,
120    disp: bool,
121    callback: Option<LMCallback>,
122}
123
124impl LMConfigBuilder {
125    /// Create a new builder with defaults.
126    pub fn new() -> Self {
127        Self {
128            maxiter: 100,
129            tol: 1e-10,
130            atol: 1e-14,
131            lambda_init: 1.0,
132            jacobian_epsilon: 1e-8,
133            x0: None,
134            weights: None,
135            disp: false,
136            callback: None,
137        }
138    }
139
140    /// Set the initial guess (required).
141    pub fn x0(mut self, x0: Array1<f64>) -> Self {
142        self.x0 = Some(x0);
143        self
144    }
145
146    /// Maximum number of iterations.
147    pub fn maxiter(mut self, n: usize) -> Self {
148        self.maxiter = n;
149        self
150    }
151
152    /// Relative convergence tolerance.
153    pub fn tol(mut self, t: f64) -> Self {
154        self.tol = t;
155        self
156    }
157
158    /// Absolute convergence tolerance.
159    pub fn atol(mut self, t: f64) -> Self {
160        self.atol = t;
161        self
162    }
163
164    /// Initial damping factor.
165    pub fn lambda_init(mut self, l: f64) -> Self {
166        self.lambda_init = l;
167        self
168    }
169
170    /// Finite-difference epsilon for Jacobian.
171    pub fn jacobian_epsilon(mut self, eps: f64) -> Self {
172        self.jacobian_epsilon = eps;
173        self
174    }
175
176    /// Per-residual weights.
177    pub fn weights(mut self, w: Array1<f64>) -> Self {
178        self.weights = Some(w);
179        self
180    }
181
182    /// Print progress.
183    pub fn disp(mut self, d: bool) -> Self {
184        self.disp = d;
185        self
186    }
187
188    /// Progress callback.
189    pub fn callback(mut self, cb: Box<dyn FnMut(&LMIntermediate) -> LMCallbackAction>) -> Self {
190        self.callback = Some(cb);
191        self
192    }
193
194    /// Build the config. Panics if `x0` was not set.
195    pub fn build(self) -> LMConfig {
196        LMConfig {
197            maxiter: self.maxiter,
198            tol: self.tol,
199            atol: self.atol,
200            lambda_init: self.lambda_init,
201            jacobian_epsilon: self.jacobian_epsilon,
202            x0: self.x0.expect("LMConfigBuilder: x0 is required"),
203            weights: self.weights,
204            disp: self.disp,
205            callback: self.callback,
206        }
207    }
208}
209
210impl Default for LMConfigBuilder {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216// ---------------------------------------------------------------------------
217// Report
218// ---------------------------------------------------------------------------
219
/// Result of a Levenberg-Marquardt optimization.
///
/// Returned inside `Ok(..)` even when the solver did not converge; inspect
/// `success` and `message` to tell the cases apart.
#[derive(Debug)]
pub struct LMReport {
    /// Final parameter vector (the last accepted iterate).
    pub x: Array1<f64>,
    /// Final objective value (weighted sum of squared residuals).
    pub fun: f64,
    /// Final residual vector.
    pub residuals: Array1<f64>,
    /// Whether the solver converged. Stays `false` when the iteration limit
    /// was reached or the callback requested a stop.
    pub success: bool,
    /// Human-readable termination message.
    pub message: String,
    /// Number of iterations performed.
    pub nit: usize,
    /// Number of residual function evaluations.
    pub nfev: usize,
}
238
239// ---------------------------------------------------------------------------
240// Solver
241// ---------------------------------------------------------------------------
242
243/// Bounded Levenberg-Marquardt nonlinear least-squares solver.
244///
245/// Minimizes `sum_i  w_i * r_i(x)^2` subject to `bounds[j].0 <= x[j] <= bounds[j].1`.
246///
247/// # Arguments
248/// * `residual_fn` — returns the residual vector `r(x)`.
249/// * `bounds` — per-parameter `(lower, upper)` pairs.
250/// * `config` — solver configuration (must include `x0`).
251pub fn levenberg_marquardt<F>(
252    residual_fn: &F,
253    bounds: &[(f64, f64)],
254    config: LMConfig,
255) -> LMResult<LMReport>
256where
257    F: Fn(&Array1<f64>) -> Array1<f64>,
258{
259    let n_params = config.x0.len();
260
261    // --- Validate inputs ---
262    if bounds.len() != n_params {
263        return Err(LMError::DimensionMismatch {
264            x0_len: n_params,
265            bounds_len: bounds.len(),
266        });
267    }
268    for (i, &(lo, hi)) in bounds.iter().enumerate() {
269        if lo > hi {
270            return Err(LMError::InvalidBounds {
271                index: i,
272                lower: lo,
273                upper: hi,
274            });
275        }
276    }
277
278    // --- Initialize ---
279    let mut x = project(&config.x0, bounds);
280    let mut r = residual_fn(&x);
281    let n_residuals = r.len();
282    let mut nfev: usize = 1;
283
284    let w = config.weights.unwrap_or_else(|| Array1::ones(n_residuals));
285    let mut f_val = weighted_sos(&r, &w);
286    let mut lambda = config.lambda_init;
287    let eps = config.jacobian_epsilon;
288
289    let mut success = false;
290    let mut message = format!("max iterations ({}) reached", config.maxiter);
291    let mut callback = config.callback;
292    let mut nit: usize = 0;
293
294    for iter in 0..config.maxiter {
295        nit = iter + 1;
296        // --- Callback ---
297        if let Some(ref mut cb) = callback {
298            let intermediate = LMIntermediate {
299                x: x.clone(),
300                fun: f_val,
301                lambda,
302                iter,
303            };
304            if cb(&intermediate) == LMCallbackAction::Stop {
305                message = "stopped by callback".to_string();
306                break;
307            }
308        }
309
310        // --- Check for already-converged ---
311        if f_val <= config.atol {
312            success = true;
313            message = format!(
314                "converged: objective {:.2e} <= atol {:.2e}",
315                f_val, config.atol
316            );
317            break;
318        }
319
320        // --- Compute Jacobian via central finite differences ---
321        let jac = compute_jacobian(residual_fn, &x, &r, n_residuals, eps, &mut nfev)?;
322
323        // --- Form damped normal equations ---
324        // LM step solves: (J^T W J + λ·D) δ = −J^T W r
325        // where D = diag(J^T W J) (Marquardt scaling).
326        let jtwj = jtw_j(&jac, &w);
327        let jtwr = jtw_r(&jac, &w, &r); // gradient direction
328
329        // Marquardt diagonal scaling
330        let diag: Array1<f64> = (0..n_params).map(|j| jtwj[[j, j]].max(1e-20)).collect();
331
332        // H = J^T W J + λ·diag(D)
333        let mut h = jtwj.clone();
334        for j in 0..n_params {
335            h[[j, j]] += lambda * diag[j];
336        }
337
338        // RHS = −J^T W r  (descent direction)
339        let neg_jtwr = jtwr.mapv(|v| -v);
340
341        // Solve H · δ = −J^T W r
342        let delta = match solve_linear_system(&h, &neg_jtwr) {
343            Some(d) => d,
344            None => {
345                // Increase damping and retry
346                lambda *= 4.0;
347                continue;
348            }
349        };
350
351        // --- Trial step with bounds projection ---
352        let x_trial = project(&(&x + &delta), bounds);
353        let r_trial = residual_fn(&x_trial);
354        nfev += 1;
355
356        if r_trial.len() != n_residuals {
357            return Err(LMError::ResidualDimensionChanged {
358                expected: n_residuals,
359                got: r_trial.len(),
360            });
361        }
362
363        let f_trial = weighted_sos(&r_trial, &w);
364
365        // --- Gain ratio ---
366        // Model reduction for f(x) = r^T W r with step δ solving (JTWJ + λD)δ = −jtwr:
367        //   pred = δ^T JTWJ δ + 2λ δ^T D δ
368        let pred: f64 = {
369            let mut p = 0.0;
370            for j in 0..n_params {
371                let mut jtwj_delta_j = 0.0;
372                for k in 0..n_params {
373                    jtwj_delta_j += jtwj[[j, k]] * delta[k];
374                }
375                p += delta[j] * jtwj_delta_j + 2.0 * lambda * diag[j] * delta[j] * delta[j];
376            }
377            p
378        };
379
380        let actual = f_val - f_trial;
381        let rho = if pred.abs() > 1e-30 {
382            actual / pred
383        } else {
384            0.0
385        };
386
387        if rho > 0.0 && f_trial < f_val {
388            // Accept step
389            let f_old = f_val;
390            x = x_trial;
391            r = r_trial;
392            f_val = f_trial;
393            lambda = (lambda / 2.0).max(1e-15);
394
395            // Convergence check
396            if (f_old - f_val).abs() < config.tol * f_old + config.atol {
397                success = true;
398                message = format!(
399                    "converged: |df|={:.2e} < tol*f+atol={:.2e}",
400                    (f_old - f_val).abs(),
401                    config.tol * f_old + config.atol
402                );
403                break;
404            }
405        } else {
406            // Reject step — increase damping
407            lambda = (lambda * 2.0).min(1e15);
408        }
409
410        if config.disp && iter % 10 == 0 {
411            eprintln!(
412                "LM iter {}: f={:.6e}, lambda={:.2e}, rho={:.3}",
413                iter, f_val, lambda, rho
414            );
415        }
416    }
417
418    Ok(LMReport {
419        x,
420        fun: f_val,
421        residuals: r,
422        success,
423        message,
424        nit,
425        nfev,
426    })
427}
428
429// ---------------------------------------------------------------------------
430// Internal helpers
431// ---------------------------------------------------------------------------
432
433/// Project `x` onto the box `[lb, ub]`.
434fn project(x: &Array1<f64>, bounds: &[(f64, f64)]) -> Array1<f64> {
435    Array1::from(
436        x.iter()
437            .zip(bounds.iter())
438            .map(|(&xi, &(lo, hi))| xi.clamp(lo, hi))
439            .collect::<Vec<_>>(),
440    )
441}
442
443/// Weighted sum of squares: sum(w_i * r_i^2).
444fn weighted_sos(r: &Array1<f64>, w: &Array1<f64>) -> f64 {
445    r.iter().zip(w.iter()).map(|(&ri, &wi)| wi * ri * ri).sum()
446}
447
448/// Compute Jacobian via central finite differences.
449fn compute_jacobian<F>(
450    residual_fn: &F,
451    x: &Array1<f64>,
452    _r0: &Array1<f64>,
453    n_residuals: usize,
454    eps: f64,
455    nfev: &mut usize,
456) -> LMResult<Array2<f64>>
457where
458    F: Fn(&Array1<f64>) -> Array1<f64>,
459{
460    let n_params = x.len();
461    let mut jac = Array2::zeros((n_residuals, n_params));
462
463    for j in 0..n_params {
464        let mut x_plus = x.clone();
465        let mut x_minus = x.clone();
466        let h = eps.max(eps * x[j].abs());
467        x_plus[j] += h;
468        x_minus[j] -= h;
469
470        let r_plus = residual_fn(&x_plus);
471        let r_minus = residual_fn(&x_minus);
472        *nfev += 2;
473
474        if r_plus.len() != n_residuals {
475            return Err(LMError::ResidualDimensionChanged {
476                expected: n_residuals,
477                got: r_plus.len(),
478            });
479        }
480        if r_minus.len() != n_residuals {
481            return Err(LMError::ResidualDimensionChanged {
482                expected: n_residuals,
483                got: r_minus.len(),
484            });
485        }
486
487        let inv_2h = 1.0 / (2.0 * h);
488        for i in 0..n_residuals {
489            jac[[i, j]] = (r_plus[i] - r_minus[i]) * inv_2h;
490        }
491    }
492
493    Ok(jac)
494}
495
496/// Compute J^T W J (n_params x n_params).
497fn jtw_j(jac: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
498    let n_params = jac.ncols();
499    let n_res = jac.nrows();
500    let mut result = Array2::zeros((n_params, n_params));
501
502    for j in 0..n_params {
503        for k in j..n_params {
504            let mut val = 0.0;
505            for i in 0..n_res {
506                val += w[i] * jac[[i, j]] * jac[[i, k]];
507            }
508            result[[j, k]] = val;
509            result[[k, j]] = val;
510        }
511    }
512
513    result
514}
515
516/// Compute J^T W r (n_params vector).
517fn jtw_r(jac: &Array2<f64>, w: &Array1<f64>, r: &Array1<f64>) -> Array1<f64> {
518    let n_params = jac.ncols();
519    let n_res = jac.nrows();
520    let mut result = Array1::zeros(n_params);
521
522    for j in 0..n_params {
523        let mut val = 0.0;
524        for i in 0..n_res {
525            val += w[i] * jac[[i, j]] * r[i];
526        }
527        result[j] = val;
528    }
529
530    result
531}
532
533/// Solve a dense symmetric positive-definite system Ax = b via Gaussian
534/// elimination with partial pivoting. Returns `None` if the matrix is singular.
535fn solve_linear_system(a: &Array2<f64>, b: &Array1<f64>) -> Option<Array1<f64>> {
536    let n = b.len();
537    debug_assert_eq!(a.nrows(), n);
538    debug_assert_eq!(a.ncols(), n);
539
540    // Augmented matrix [A | b]
541    let mut aug = Array2::zeros((n, n + 1));
542    for i in 0..n {
543        for j in 0..n {
544            aug[[i, j]] = a[[i, j]];
545        }
546        aug[[i, n]] = b[i];
547    }
548
549    // Forward elimination with partial pivoting
550    for col in 0..n {
551        // Find pivot
552        let mut max_val = aug[[col, col]].abs();
553        let mut max_row = col;
554        for row in (col + 1)..n {
555            let val = aug[[row, col]].abs();
556            if val > max_val {
557                max_val = val;
558                max_row = row;
559            }
560        }
561
562        if max_val < 1e-30 {
563            return None; // Singular
564        }
565
566        // Swap rows
567        if max_row != col {
568            for j in 0..=n {
569                let tmp = aug[[col, j]];
570                aug[[col, j]] = aug[[max_row, j]];
571                aug[[max_row, j]] = tmp;
572            }
573        }
574
575        // Eliminate below
576        let pivot = aug[[col, col]];
577        for row in (col + 1)..n {
578            let factor = aug[[row, col]] / pivot;
579            for j in col..=n {
580                aug[[row, j]] -= factor * aug[[col, j]];
581            }
582        }
583    }
584
585    // Back substitution
586    let mut x = Array1::zeros(n);
587    for col in (0..n).rev() {
588        let mut sum = aug[[col, n]];
589        for j in (col + 1)..n {
590            sum -= aug[[col, j]] * x[j];
591        }
592        x[col] = sum / aug[[col, col]];
593    }
594
595    // NaN check
596    if x.iter().any(|v| !v.is_finite()) {
597        return None;
598    }
599
600    Some(x)
601}
602
603// ---------------------------------------------------------------------------
604// Tests
605// ---------------------------------------------------------------------------
606
// Unit tests for the solver: convergence on small analytic problems, bound
// handling, callback early-stop, weighting, iteration accounting, and the two
// input-validation errors.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Linear residual `r = x`: the solver must reach the origin.
    #[test]
    fn test_sphere() {
        // r_i = x_i, minimum at origin
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 3];
        let config = LMConfigBuilder::new()
            .x0(array![3.0, -2.0, 1.0])
            .maxiter(50)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            report.fun < 1e-12,
            "objective should be ~0, got {}",
            report.fun
        );
        for &xi in report.x.iter() {
            assert!(xi.abs() < 1e-6, "x should be ~0, got {}", xi);
        }
    }

    /// Rosenbrock function written in residual form; minimum at (1, 1).
    #[test]
    fn test_rosenbrock_residual() {
        // r = [10*(x1 - x0^2), 1 - x0]  =>  min at (1, 1)
        let residual = |x: &Array1<f64>| array![10.0 * (x[1] - x[0] * x[0]), 1.0 - x[0]];
        let bounds = vec![(-5.0, 5.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![-1.0, 1.0])
            .maxiter(200)
            .tol(1e-12)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "should converge: {}", report.message);
        assert!(
            (report.x[0] - 1.0).abs() < 1e-4,
            "x0 should be ~1, got {}",
            report.x[0]
        );
        assert!(
            (report.x[1] - 1.0).abs() < 1e-4,
            "x1 should be ~1, got {}",
            report.x[1]
        );
    }

    /// The unconstrained minimum lies outside the box, so the answer must sit
    /// on the upper bound. Only `x` is checked, not `report.success`.
    #[test]
    fn test_bounded_solution() {
        // r = [x - 5]  =>  unconstrained min at x=5, but bound x <= 3
        let residual = |x: &Array1<f64>| array![x[0] - 5.0];
        let bounds = vec![(-10.0, 3.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(
            (report.x[0] - 3.0).abs() < 1e-6,
            "x should be at bound 3.0, got {}",
            report.x[0]
        );
    }

    /// The solver must return `Ok` for a residual that can produce NaN.
    // NOTE(review): starting from x0 = 0 with residual x - 1, the iterates
    // presumably never leave |x| <= 100, so the NaN branch may not actually be
    // exercised by this test — confirm.
    #[test]
    fn test_nan_residual_handled() {
        // Residual that returns NaN for large x
        let residual = |x: &Array1<f64>| {
            if x[0].abs() > 100.0 {
                array![f64::NAN]
            } else {
                array![x[0] - 1.0]
            }
        };
        let bounds = vec![(-200.0, 200.0)];
        let config = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();

        // Should not panic
        let result = levenberg_marquardt(&residual, &bounds, config);
        assert!(result.is_ok());
    }

    /// Starting exactly at the optimum must be reported as converged.
    #[test]
    fn test_zero_residual() {
        // Already at the optimum
        let residual = |x: &Array1<f64>| array![x[0], x[1]];
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0])
            .maxiter(10)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert!(report.success, "already at optimum: {}", report.message);
        assert!(report.fun < 1e-14);
    }

    /// A callback that returns `Stop` must end the run with the callback message.
    #[test]
    fn test_callback_stop() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(1000)
            .callback(Box::new(|inter| {
                if inter.iter >= 3 {
                    LMCallbackAction::Stop
                } else {
                    LMCallbackAction::Continue
                }
            }))
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(report.message, "stopped by callback");
    }

    /// Weights must shift the solution toward the more heavily weighted residual.
    #[test]
    fn test_weighted_residuals() {
        // r = [x - 1, x - 3], unweighted min at x=2
        // With w = [10, 1], should pull x closer to 1
        let residual = |x: &Array1<f64>| array![x[0] - 1.0, x[0] - 3.0];
        let bounds = vec![(-10.0, 10.0)];

        // Unweighted
        let config_unw = LMConfigBuilder::new().x0(array![0.0]).maxiter(50).build();
        let report_unw = levenberg_marquardt(&residual, &bounds, config_unw).unwrap();

        // Weighted (10x weight on first residual)
        let config_w = LMConfigBuilder::new()
            .x0(array![0.0])
            .maxiter(50)
            .weights(array![10.0, 1.0])
            .build();
        let report_w = levenberg_marquardt(&residual, &bounds, config_w).unwrap();

        // Unweighted should land at ~2.0
        assert!(
            (report_unw.x[0] - 2.0).abs() < 0.01,
            "unweighted x should be ~2, got {}",
            report_unw.x[0]
        );
        // Weighted should be closer to 1.0
        assert!(
            report_w.x[0] < report_unw.x[0],
            "weighted x ({}) should be less than unweighted ({})",
            report_w.x[0],
            report_unw.x[0]
        );
        assert!(
            (report_w.x[0] - 1.0).abs() < 0.5,
            "weighted x should be near 1.0, got {}",
            report_w.x[0]
        );
    }

    /// A bounds slice whose length differs from `x0` must be rejected.
    #[test]
    fn test_dimension_mismatch() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-1.0, 1.0); 3]; // 3 bounds
        let config = LMConfigBuilder::new()
            .x0(array![0.0, 0.0]) // 2 params
            .build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::DimensionMismatch { .. }));
    }

    /// `nit` must count outer iterations while `nfev` counts evaluations.
    #[test]
    fn test_nit_tracks_iterations_not_nfev() {
        // nit should count outer iterations, nfev counts residual evaluations
        // For a 2-param problem, each iteration does 1 trial + 4 FD evals = 5+ nfev
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(-10.0, 10.0); 2];
        let config = LMConfigBuilder::new()
            .x0(array![5.0, 5.0])
            .maxiter(5)
            .tol(1e-20) // don't converge early
            .atol(0.0)
            .build();

        let report = levenberg_marquardt(&residual, &bounds, config).unwrap();
        assert_eq!(
            report.nit, 5,
            "nit should be maxiter (5), got {}",
            report.nit
        );
        assert!(
            report.nfev > report.nit,
            "nfev ({}) should be much larger than nit ({})",
            report.nfev,
            report.nit
        );
    }

    /// A bound pair with lower > upper must be rejected before any evaluation.
    #[test]
    fn test_invalid_bounds() {
        let residual = |x: &Array1<f64>| x.clone();
        let bounds = vec![(5.0, 1.0)]; // lower > upper
        let config = LMConfigBuilder::new().x0(array![0.0]).build();

        let err = levenberg_marquardt(&residual, &bounds, config).unwrap_err();
        assert!(matches!(err, LMError::InvalidBounds { .. }));
    }
}