use crate::{MatrixType, SolverError, SolverResult, VectorType, DEFAULT_ITERMAX, DEFAULT_TOL};
use nalgebra::{allocator::Allocator, ComplexField, DefaultAllocator, Dim, UniformNorm, U1};
use num_traits::{Float, Signed};
use std::marker::PhantomData;
/// Starting value of the damping parameter `mu_0` handed to a freshly
/// constructed [`LevenbergMarquardt`] solver.
pub(super) const DEFAULT_DAMPING_INITIAL_VALUE: f64 = 1e-2;
/// Default factor `beta` by which the solver divides (after a step that
/// reduces the residual norm) or multiplies (otherwise) its damping.
pub(super) const DEFAULT_DAMPING_DECAY_FACTOR: f64 = 10.0;
/// Levenberg–Marquardt solver state: the residual/Jacobian closures plus the
/// tuning knobs used by `solve`.
///
/// Type parameters: `T` is the scalar type, `R`/`C` are the (nalgebra)
/// dimensions of the residual and of the unknown vector, and `F`/`J` are the
/// closure types for the residual and its `R × C` matrix (presumably the
/// Jacobian of `F` — the solver cannot check this).
pub struct LevenbergMarquardt<T, R, C, F, J> {
    // --- problem definition ---
    /// Residual closure evaluated at a candidate point.
    f: F,
    /// Closure returning the `R × C` matrix used to build each step.
    j: J,

    // --- damping schedule ---
    /// Damping value used for the first iteration.
    mu_0: T,
    /// Multiplicative factor applied to the damping after every iteration.
    beta: T,

    // --- stopping criteria ---
    /// Threshold on the max-norm of a step below which iteration stops.
    tolerance: T,
    /// Upper bound on the number of iterations.
    iter_max: usize,

    // --- zero-sized markers tying the otherwise-unused dimension types ---
    r_phantom: PhantomData<R>,
    c_phantom: PhantomData<C>,
}
impl<T, R, C, F, J> LevenbergMarquardt<T, R, C, F, J>
where
    T: Float + ComplexField<RealField = T> + Signed,
    R: Dim,
    C: Dim,
    F: Fn(VectorType<T, C>) -> VectorType<T, R>,
    J: Fn(VectorType<T, C>) -> MatrixType<T, R, C>,
    DefaultAllocator: Allocator<C>
        + Allocator<R>
        + Allocator<R, C>
        + Allocator<C, R>
        + Allocator<C, C>
        + Allocator<U1, C>,
{
    /// Creates a solver for the residual closure `f` and its matrix closure `j`.
    ///
    /// Tolerance, iteration limit, initial damping and damping factor start at
    /// `DEFAULT_TOL`, `DEFAULT_ITERMAX`, `DEFAULT_DAMPING_INITIAL_VALUE` and
    /// `DEFAULT_DAMPING_DECAY_FACTOR` respectively.
    ///
    /// # Panics
    /// Panics if one of the `f64` defaults cannot be represented in `T`.
    pub fn new(f: F, j: J) -> Self {
        Self {
            f,
            j,
            tolerance: T::from(DEFAULT_TOL).expect("DEFAULT_TOL must be representable in T"),
            mu_0: T::from(DEFAULT_DAMPING_INITIAL_VALUE)
                .expect("DEFAULT_DAMPING_INITIAL_VALUE must be representable in T"),
            beta: T::from(DEFAULT_DAMPING_DECAY_FACTOR)
                .expect("DEFAULT_DAMPING_DECAY_FACTOR must be representable in T"),
            iter_max: DEFAULT_ITERMAX,
            r_phantom: PhantomData,
            c_phantom: PhantomData,
        }
    }

    /// Sets the convergence tolerance: [`solve`](Self::solve) stops once the
    /// max-norm of a step is at most `tol`.
    pub fn with_tol(&mut self, tol: T) -> &mut Self {
        self.tolerance = tol;
        self
    }

    /// Sets the maximum number of iterations (accepted and rejected steps both
    /// count) before [`solve`](Self::solve) gives up.
    pub fn with_itermax(&mut self, max: usize) -> &mut Self {
        self.iter_max = max;
        self
    }

    /// Sets the damping value used for the first iteration.
    pub fn with_initial_damping_factor(&mut self, mu_0: T) -> &mut Self {
        self.mu_0 = mu_0;
        self
    }

    /// Misspelled alias of
    /// [`with_initial_damping_factor`](Self::with_initial_damping_factor),
    /// kept so existing callers keep compiling.
    pub fn with_intial_damping_factor(&mut self, mu_0: T) -> &mut Self {
        self.with_initial_damping_factor(mu_0)
    }

    /// Sets the factor by which the damping is divided after a step that
    /// reduces the residual norm and multiplied after one that does not.
    pub fn with_damping_decay_factor(&mut self, beta: T) -> &mut Self {
        self.beta = beta;
        self
    }

    /// Runs Levenberg–Marquardt iterations starting from `x0`.
    ///
    /// Each iteration solves `(JᵀJ + μI)·dv = -Jᵀ·f(x)` for a trial step `dv`,
    /// where `J` is the matrix returned by the `j` closure at the current point
    /// and `μ` the current damping. The step is accepted only if it strictly
    /// reduces the Euclidean norm of `f`; then `μ` is divided by the decay
    /// factor. Otherwise the point is kept, `μ` is multiplied by the decay
    /// factor, and the next iteration retries with a more strongly damped step.
    /// Iteration stops once the max-norm of the latest trial step is at most
    /// the tolerance.
    ///
    /// # Errors
    /// - [`SolverError::BadJacobian`] if `JᵀJ + μI` cannot be inverted.
    /// - [`SolverError::MaxIterReached`] if the step has not shrunk to the
    ///   tolerance within `iter_max` iterations.
    pub fn solve(&self, mut x0: VectorType<T, C>) -> SolverResult<VectorType<T, C>> {
        // `x0 * x0ᵀ` is only used to get a C×C matrix of the right (possibly
        // dynamic) size; its contents are immediately replaced by I.
        let mut identity = &x0 * x0.transpose();
        identity.fill_with_identity();

        let mut iter = 1;
        let mut damping = self.mu_0;
        let mut fx = (self.f)(x0.clone());
        let mut fx_norm = fx.norm();
        // Matrix from `j` at the current `x0`. It is invalidated only when `x0`
        // moves, so a rejected step does not re-evaluate it at the same point.
        let mut jac: Option<MatrixType<T, R, C>> = None;
        let mut converged = false;

        while !converged && iter <= self.iter_max {
            let j = jac.get_or_insert_with(|| (self.j)(x0.clone()));
            let jt = j.transpose();
            let lhs = &jt * &*j + &identity * damping;
            let Some(lhs_inv) = lhs.try_inverse() else {
                return Err(SolverError::BadJacobian);
            };
            let dv = lhs_inv * -(&jt * &fx);

            let mut x_trial = x0.clone();
            x_trial += &dv;
            let fx_trial = (self.f)(x_trial.clone());
            let fx_trial_norm = fx_trial.norm();
            if fx_trial_norm < fx_norm {
                // Improvement: move to the trial point and trust the model more.
                x0 = x_trial;
                fx = fx_trial;
                fx_norm = fx_trial_norm;
                jac = None;
                damping /= self.beta;
            } else {
                // No improvement: stay put and take a more cautious step next time.
                damping *= self.beta;
            }

            // `<=` rather than `!(>)`, so a NaN step is never mistaken for convergence.
            converged = dv.apply_norm(&UniformNorm) <= self.tolerance;
            iter += 1;
        }

        // Decide on the actual convergence state: comparing `iter` against
        // `iter_max` would misreport a run that converged on its last allowed
        // iterations.
        if !converged {
            return Err(SolverError::MaxIterReached(iter));
        }
        Ok(x0)
    }
}