use crate::core::math::{
AddDiagonalVectorInPlace, ComponentDivAssign, ComponentMaxAssign, ComponentMulAssign, Dot,
FloorZerosInPlace, GramMatrix, LinearSolveSpd, MatDiagonal, MatTransposeVec, NegInPlace,
NormInfinity, NormSquared, Scalar, ScaleInPlace, ScaledAdd,
};
use crate::core::problem::{Jacobian, Problem, Residual};
use crate::core::solver::Solver;
use crate::core::state::NllsState;
use crate::core::termination::TerminationReason;
/// Levenberg–Marquardt solver: user-configurable tolerances plus the working
/// state that is carried from one iteration to the next.
///
/// `V` is the vector type, `M` the matrix type and `F` the scalar type
/// (`f64` unless specified otherwise).
pub struct LevenbergMarquardt<V, M, F = f64> {
/// Absolute gradient tolerance: `next_iter` reports convergence when the
/// infinity norm of the gradient vector is at most this value. `0` disables the test.
tol_grad: F,
/// Relative gradient tolerance used by the scaled-gradient test in `next_iter`.
/// `0` disables the test.
tol_grad_rel: F,
/// Relative cost-change tolerance checked after each step attempt. `0` disables the test.
tol_cost_rel: F,
/// Relative step-size tolerance checked after each step attempt. `0` disables the test.
tol_step_rel: F,
/// Value assigned to the damping multiplier `mu` by `init`.
tau: F,
/// Upper bound on failed linear solves within a single `next_iter` call before
/// the solver reports failure.
max_inner_attempts: u32,
/// Current damping multiplier applied to `diag`; `None` until `init` has run.
mu: Option<F>,
/// Factor by which `mu` is multiplied when a solve fails or a step is rejected.
nu: F,
/// Diagonal scaling vector. `init` seeds it from the Gram-matrix diagonal and
/// `next_iter` folds the current diagonal in with `component_max_assign`.
/// `None` until `init` has run.
diag: Option<V>,
/// Residual at the current parameters, kept to avoid re-evaluating it.
r_cache: Option<V>,
/// Result of `gram()` on the Jacobian at the current parameters, if still valid.
gram_cache: Option<M>,
/// Result of `mat_transpose_vec` (Jacobian applied to the residual) at the
/// current parameters, if still valid.
jtr_cache: Option<V>,
}
impl<V, M> Default for LevenbergMarquardt<V, M> {
fn default() -> Self {
Self::new()
}
}
impl<V, M> LevenbergMarquardt<V, M> {
    /// Creates an `f64` solver with the default configuration.
    ///
    /// Only the absolute gradient tolerance is active by default (`1e-8`); the
    /// relative gradient, cost and step tolerances start at `0`, which the
    /// iteration treats as "test disabled". All working state is empty until
    /// `Solver::init` runs.
    pub fn new() -> Self {
        Self {
            // Working state: nothing cached, damping not yet initialised.
            jtr_cache: None,
            gram_cache: None,
            r_cache: None,
            diag: None,
            mu: None,
            nu: 2.0,
            // Damping configuration.
            tau: 1e-3,
            max_inner_attempts: 50,
            // Stopping tolerances.
            tol_step_rel: 0.0,
            tol_cost_rel: 0.0,
            tol_grad_rel: 0.0,
            tol_grad: 1e-8,
        }
    }
}
impl<V, M, F: Scalar> LevenbergMarquardt<V, M, F> {
    /// Sets the absolute gradient tolerance (`0` disables the test).
    ///
    /// # Panics
    ///
    /// Panics unless `tol >= 0`.
    pub fn with_tol_grad(self, tol: F) -> Self {
        assert!(tol >= F::zero(), "tol_grad must be ≥ 0");
        Self {
            tol_grad: tol,
            ..self
        }
    }

    /// Sets the relative gradient tolerance (`0` disables the test).
    ///
    /// # Panics
    ///
    /// Panics unless `tol >= 0`.
    pub fn with_tol_grad_rel(self, tol: F) -> Self {
        assert!(tol >= F::zero(), "tol_grad_rel must be ≥ 0");
        Self {
            tol_grad_rel: tol,
            ..self
        }
    }

    /// Sets the relative cost-change tolerance (`0` disables the test).
    ///
    /// # Panics
    ///
    /// Panics unless `tol >= 0`.
    pub fn with_tol_cost_rel(self, tol: F) -> Self {
        assert!(tol >= F::zero(), "tol_cost_rel must be ≥ 0");
        Self {
            tol_cost_rel: tol,
            ..self
        }
    }

    /// Sets the relative step-size tolerance (`0` disables the test).
    ///
    /// # Panics
    ///
    /// Panics unless `tol >= 0`.
    pub fn with_tol_step_rel(self, tol: F) -> Self {
        assert!(tol >= F::zero(), "tol_step_rel must be ≥ 0");
        Self {
            tol_step_rel: tol,
            ..self
        }
    }

    /// Sets `tau`, the value the damping multiplier starts from when the
    /// solver is initialised.
    ///
    /// # Panics
    ///
    /// Panics unless `tau > 0`.
    pub fn with_tau(self, tau: F) -> Self {
        assert!(tau > F::zero(), "tau must be > 0");
        Self { tau, ..self }
    }

    /// Sets how many failed linear solves a single iteration tolerates before
    /// giving up.
    ///
    /// # Panics
    ///
    /// Panics if `n` is `0`.
    pub fn with_max_inner_attempts(self, n: u32) -> Self {
        assert!(n > 0, "max_inner_attempts must be > 0");
        Self {
            max_inner_attempts: n,
            ..self
        }
    }
}
/// `Solver` implementation for `LevenbergMarquardt` operating on an `NllsState`.
///
/// `init` evaluates the problem once at the starting parameters and seeds the
/// damping state and caches. `next_iter` attempts one damped step and reports
/// whether a termination condition was met.
impl<P, V, M, F> Solver<P, NllsState<V, F>> for LevenbergMarquardt<V, M, F>
where
F: Scalar,
P: Residual<Param = V, Output = V> + Jacobian<Jacobian = M>,
V: ScaledAdd<F>
+ NormSquared<F>
+ NormInfinity<F>
+ NegInPlace
+ Dot<F>
+ ScaleInPlace<F>
+ ComponentMulAssign
+ ComponentDivAssign
+ ComponentMaxAssign
+ FloorZerosInPlace<F>
+ Clone,
M: GramMatrix
+ MatTransposeVec<V>
+ LinearSolveSpd<V>
+ AddDiagonalVectorInPlace<V>
+ MatDiagonal<V>
+ Clone,
{
/// Error type of the problem's `Residual` implementation; failures from
/// `problem` calls are propagated with `?`.
type Error = <P as Residual>::Error;
/// Prepares the solver for a run starting at `state.param`.
///
/// Sets `state.cost`, the scaling vector, the damping multiplier (`tau`), the
/// growth factor (2) and fills all three caches.
///
/// # Errors
///
/// Returns the error from `problem.residual_and_jacobian`.
fn init(
&mut self,
problem: &mut Problem<P>,
mut state: NllsState<V, F>,
) -> Result<NllsState<V, F>, Self::Error> {
let (r, j) = problem.residual_and_jacobian(&state.param)?;
// Cost is half the squared norm of the residual.
state.cost = Some(F::from_f64(0.5).unwrap() * r.norm_squared());
// `gram()` presumably yields JᵀJ — confirm against `GramMatrix`.
let a = j.gram();
// Initial scaling vector: the diagonal of `a` passed through
// `floor_zeros_in_place(1)` (presumably replaces zero entries by 1 — confirm).
let mut d = a.diagonal();
d.floor_zeros_in_place(F::one());
self.diag = Some(d);
// Damping starts at `tau`; the growth factor is reset to 2.
self.mu = Some(self.tau);
self.nu = F::from_f64(2.0).unwrap();
// Cache everything computed here so the first `next_iter` call does not
// have to re-evaluate the problem.
self.jtr_cache = Some(j.mat_transpose_vec(&r));
self.gram_cache = Some(a);
self.r_cache = Some(r);
Ok(state)
}
/// Attempts one damped step from `state.param`.
///
/// Returns the (possibly updated) state together with
/// `Some(TerminationReason::SolverConverged)` when a gradient, cost or step
/// test passes, `Some(TerminationReason::SolverFailed)` when the damped system
/// could not be solved, or `None` to continue iterating.
///
/// # Errors
///
/// Propagates errors from `problem.residual` and `problem.jacobian`.
///
/// # Panics
///
/// Panics if `init` has not run (`diag`, `mu` or `state.cost` unset).
fn next_iter(
&mut self,
problem: &mut Problem<P>,
mut state: NllsState<V, F>,
) -> Result<(NllsState<V, F>, Option<TerminationReason>), Self::Error> {
// Reuse the cached residual if there is one; otherwise evaluate it.
let r = match self.r_cache.take() {
Some(r) => r,
None => problem.residual(&state.param)?,
};
// The Gram matrix and gradient vector are reused only when both are cached;
// otherwise both are rebuilt from a fresh Jacobian.
let (a, g) = match (self.gram_cache.take(), self.jtr_cache.take()) {
(Some(a), Some(g)) => (a, g),
_ => {
let j = problem.jacobian(&state.param)?;
(j.gram(), j.mat_transpose_vec(&r))
}
};
let diag_cur = a.diagonal();
// Absolute gradient test; a tolerance of 0 switches it off.
let abs_converged = self.tol_grad > F::zero() && g.norm_infinity() <= self.tol_grad;
// Relative gradient test: the largest g_i² / denom_i is compared with
// tol_grad_rel² · ‖r‖², where denom is the current diagonal after
// `floor_zeros_in_place(1)`. Only evaluated when the tolerance is non-zero.
let rel_converged = self.tol_grad_rel > F::zero() && {
let mut cos_sq = g.clone();
cos_sq.component_mul_assign(&g);
let mut denom = diag_cur.clone();
denom.floor_zeros_in_place(F::one());
cos_sq.component_div_assign(&denom);
cos_sq.norm_infinity() <= self.tol_grad_rel * self.tol_grad_rel * r.norm_squared()
};
if abs_converged || rel_converged {
// Put the taken values back so the caches stay populated.
self.r_cache = Some(r);
self.gram_cache = Some(a);
self.jtr_cache = Some(g);
return Ok((state, Some(TerminationReason::SolverConverged)));
}
// Right-hand side of the damped linear system.
let mut neg_g = g.clone();
neg_g.neg_in_place();
// The scaling vector is taken out of `self` here and written back on the
// return paths below.
let mut d = self
.diag
.take()
.expect("diag not set: Solver::init must run before next_iter");
// Combine the stored scaling with the current diagonal
// (presumably an element-wise maximum — confirm against `ComponentMaxAssign`).
d.component_max_assign(&diag_cur);
let mut mu = self
.mu
.expect("mu not set: Solver::init must run before next_iter");
let mut nu = self.nu;
let two = F::from_f64(2.0).unwrap();
let half = F::from_f64(0.5).unwrap();
let one_third = F::from_f64(1.0 / 3.0).unwrap();
let h;
let mut attempts: u32 = 0;
// Solve (a + mu·d on the diagonal) · h = -g, increasing the damping each
// time the solve fails.
loop {
let mut a_damped = a.clone();
let mut damping = d.clone();
damping.scale_in_place(mu);
a_damped.add_diagonal_vector_in_place(&damping);
match a_damped.solve_spd(&neg_g) {
Ok(step) => {
h = step;
break;
}
Err(_) => {
attempts += 1;
// Give up after too many failed solves or once mu is no longer
// finite; all working state is stored back before returning.
if attempts >= self.max_inner_attempts || !mu.is_finite() {
self.mu = Some(mu);
self.nu = nu;
self.diag = Some(d);
self.r_cache = Some(r);
self.gram_cache = Some(a);
self.jtr_cache = Some(g);
return Ok((state, Some(TerminationReason::SolverFailed)));
}
mu = mu * nu;
nu = nu * two;
}
}
}
// l_diff = ½ · (mu · hᵀ(d∘h) − hᵀg). It is the denominator of the gain
// ratio below; presumably the cost reduction predicted by the local model.
let mut dh = d.clone();
dh.component_mul_assign(&h);
let l_diff = half * (mu * h.dot(&dh) - h.dot(&g));
// Trial point: current parameters with `h` added via `scaled_add(1, h)`.
let mut x_trial = state.param.clone();
x_trial.scaled_add(F::one(), &h);
// NOTE(review): if this evaluation fails, `?` returns while `self.diag` is
// still taken and `r`, `a`, `g` are dropped, so none are written back to `self`.
let r_trial = problem.residual(&x_trial)?;
state.cost_evals += 1;
let f_trial = half * r_trial.norm_squared();
let prev_cost = state
.cost
.expect("cost not set: Solver::init must run before next_iter");
let actual_diff = prev_cost - f_trial;
// Gain ratio: actual over predicted reduction; forced to 0 when the
// predicted reduction is not strictly positive.
let rho = if l_diff > F::zero() {
actual_diff / l_diff
} else {
F::zero()
};
if rho > F::zero() {
// Step accepted: move to the trial point and rescale the damping by
// max(1/3, 1 − (2·rho − 1)³); the growth factor goes back to 2.
state.param = x_trial;
state.cost = Some(f_trial);
let factor = F::one() - (two * rho - F::one()).powi(3);
mu = mu * factor.max(one_third);
nu = two;
// The trial residual is the residual at the new parameters; the Gram
// matrix and gradient vector belong to the old ones and are dropped.
self.r_cache = Some(r_trial);
self.gram_cache = None;
self.jtr_cache = None;
} else {
// Step rejected: increase the damping, double the growth factor, and
// keep the caches since the parameters did not change.
mu = mu * nu;
nu = nu * two;
self.r_cache = Some(r);
self.gram_cache = Some(a);
self.jtr_cache = Some(g);
}
self.mu = Some(mu);
self.nu = nu;
self.diag = Some(d);
// Relative cost test: both the actual and the predicted change must be
// within tol_cost_rel · prev_cost, and rho must not exceed 2.
let cost_rel_converged = self.tol_cost_rel > F::zero()
&& actual_diff.abs() <= self.tol_cost_rel * prev_cost
&& l_diff <= self.tol_cost_rel * prev_cost
&& rho <= two;
// Relative step test on squared norms; `state.param` here is already the
// new point if the step was accepted.
let step_rel_converged = self.tol_step_rel > F::zero()
&& h.norm_squared()
<= self.tol_step_rel * self.tol_step_rel * state.param.norm_squared();
if cost_rel_converged || step_rel_converged {
return Ok((state, Some(TerminationReason::SolverConverged)));
}
Ok((state, None))
}
}