use crate::core::math::{
GramMatrix, LinearSolveSpd, MatTransposeVec, NegInPlace, NormInfinity, NormSquared, Scalar,
ScaledAdd,
};
use crate::core::problem::{Jacobian, Problem, Residual};
use crate::core::solver::Solver;
use crate::core::state::BasicState;
use crate::core::termination::TerminationReason;
/// Gauss–Newton solver: a gradient tolerance plus cached evaluations.
///
/// `V` is the type of the cached residual, `M` the type of the cached
/// Jacobian, and `F` the scalar type of the tolerance (`f64` unless specified).
pub struct GaussNewton<V, M, F = f64> {
    /// Threshold compared against the infinity norm of the gradient in the
    /// convergence test; the test only runs when this is strictly positive.
    tol_grad: F,
    /// Residual carried over from a previous evaluation, if any.
    r_cache: Option<V>,
    /// Jacobian carried over from a previous evaluation, if any.
    j_cache: Option<M>,
}
impl<V, M> Default for GaussNewton<V, M> {
fn default() -> Self {
Self::new()
}
}
impl<V, M> GaussNewton<V, M> {
    /// Creates an `f64` solver with a gradient tolerance of `1e-8` and no
    /// cached residual or Jacobian.
    pub fn new() -> Self {
        // Tolerance used until the caller overrides it.
        const DEFAULT_TOL_GRAD: f64 = 1e-8;

        Self {
            tol_grad: DEFAULT_TOL_GRAD,
            r_cache: None,
            j_cache: None,
        }
    }
}
impl<V, M, F> GaussNewton<V, M, F>
where
    F: Scalar,
{
    /// Replaces the gradient tolerance and hands the solver back, so calls
    /// can be chained.
    ///
    /// # Panics
    ///
    /// Panics with `"tol_grad must be ≥ 0"` unless `tol >= F::zero()` holds.
    pub fn tol_grad(self, tol: F) -> Self {
        assert!(tol >= F::zero(), "tol_grad must be ≥ 0");
        // Keep the caches as they are; only the tolerance changes.
        Self {
            tol_grad: tol,
            ..self
        }
    }
}
impl<P, V, M, F> Solver<P, BasicState<V, F>> for GaussNewton<V, M, F>
where
    F: Scalar,
    P: Residual<Param = V, Output = V> + Jacobian<Jacobian = M>,
    V: ScaledAdd<F> + NormSquared<F> + NormInfinity<F> + NegInPlace + Clone,
    M: GramMatrix + MatTransposeVec<V> + LinearSolveSpd<V>,
{
    type Error = <P as Residual>::Error;

    /// Evaluates the residual and Jacobian at the starting parameters.
    ///
    /// The state's cost is set to half the squared norm of the residual, and
    /// both evaluations are cached so the first call to `next_iter` does not
    /// have to recompute them.
    ///
    /// # Errors
    ///
    /// Propagates any error from `problem.residual_and_jacobian`.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V, F>,
    ) -> Result<BasicState<V, F>, Self::Error> {
        let (residual, jacobian) = problem.residual_and_jacobian(&state.param)?;

        let half = F::from_f64(0.5).unwrap();
        state.cost = Some(half * residual.norm_squared());

        self.r_cache = Some(residual);
        self.j_cache = Some(jacobian);
        Ok(state)
    }

    /// Performs one iteration: gradient check, linear solve, parameter update.
    ///
    /// 1. Takes the residual and Jacobian from the caches, evaluating whichever
    ///    one is missing at the current parameters.
    /// 2. Forms the gradient with `mat_transpose_vec` (presumably `Jᵀ r` —
    ///    confirm against the trait). If `tol_grad` is strictly positive and
    ///    the gradient's infinity norm is at most `tol_grad`, stops with
    ///    `SolverConverged`.
    /// 3. Solves `gram * delta = -gradient` with `solve_spd`, where `gram`
    ///    comes from `Jacobian::gram()` (presumably `Jᵀ J` — confirm against
    ///    the trait). A failed solve stops with `SolverFailed`.
    /// 4. Adds the full step to the parameters, re-evaluates the residual,
    ///    updates the cost, caches the new residual and clears the Jacobian
    ///    cache so it is recomputed on the next call.
    ///
    /// On both early-termination paths the residual and Jacobian are put back
    /// into the caches and the state is returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from `problem.residual` or `problem.jacobian`.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicState<V, F>,
    ) -> Result<(BasicState<V, F>, Option<TerminationReason>), Self::Error> {
        // Reuse cached evaluations when present; otherwise evaluate now.
        let residual = if let Some(cached) = self.r_cache.take() {
            cached
        } else {
            problem.residual(&state.param)?
        };
        let jacobian = if let Some(cached) = self.j_cache.take() {
            cached
        } else {
            problem.jacobian(&state.param)?
        };

        let gradient = jacobian.mat_transpose_vec(&residual);

        // A non-positive tolerance switches the gradient test off entirely.
        let converged =
            self.tol_grad > F::zero() && gradient.norm_infinity() <= self.tol_grad;
        if converged {
            self.r_cache = Some(residual);
            self.j_cache = Some(jacobian);
            return Ok((state, Some(TerminationReason::SolverConverged)));
        }

        let normal_matrix = jacobian.gram();
        let mut rhs = gradient;
        rhs.neg_in_place();

        let step = match normal_matrix.solve_spd(&rhs) {
            Err(_) => {
                // Leave the solver as it was before this call.
                self.r_cache = Some(residual);
                self.j_cache = Some(jacobian);
                return Ok((state, Some(TerminationReason::SolverFailed)));
            }
            Ok(step) => step,
        };

        // Full step with unit scale factor.
        state.param.scaled_add(F::one(), &step);

        let next_residual = problem.residual(&state.param)?;
        let half = F::from_f64(0.5).unwrap();
        state.cost = Some(half * next_residual.norm_squared());

        // The residual is valid for the new parameters; the Jacobian is stale.
        self.r_cache = Some(next_residual);
        self.j_cache = None;
        Ok((state, None))
    }
}