eqsolver/solvers/multivariable/
gaussnewton.rs1use crate::{MatrixType, SolverError, SolverResult, VectorType, DEFAULT_ITERMAX, DEFAULT_TOL};
2use nalgebra::{allocator::Allocator, ComplexField, DefaultAllocator, Dim, UniformNorm};
3use num_traits::{Float, Signed};
4use std::marker::PhantomData;
5
/// Gauss-Newton solver for a system `F(x) = 0` whose Jacobian `J` is supplied by the caller.
///
/// Each step solves the normal equations `(JᵀJ) dx = Jᵀ F(x)`, so the system may have more
/// equations (`R`) than unknowns (`C`). Tolerance and iteration limit are configurable through
/// the builder-style setters on the `impl` block.
pub struct GaussNewton<T, R, C, F, J> {
    /// System function `F`, evaluated at the current iterate.
    f: F,
    /// Jacobian `J` of the system function, evaluated at the current iterate.
    j: J,
    /// Upper bound on the number of Gauss-Newton steps.
    iter_max: usize,
    /// Stopping threshold for the size of an update step.
    tolerance: T,
    /// Zero-sized marker recording the output dimension `R`.
    r_phantom: PhantomData<R>,
    /// Zero-sized marker recording the input dimension `C`.
    c_phantom: PhantomData<C>,
}
21
22impl<T, R, C, F, J> GaussNewton<T, R, C, F, J>
23where
24 T: Float + ComplexField<RealField = T> + Signed,
25 R: Dim,
26 C: Dim,
27 F: Fn(VectorType<T, C>) -> VectorType<T, R>,
28 J: Fn(VectorType<T, C>) -> MatrixType<T, R, C>,
29 DefaultAllocator:
30 Allocator<C> + Allocator<R> + Allocator<R, C> + Allocator<C, R> + Allocator<C, C>,
31{
32 pub fn new(f: F, j: J) -> Self {
36 Self {
37 f,
38 j,
39 tolerance: T::from(DEFAULT_TOL).unwrap(),
40 iter_max: DEFAULT_ITERMAX,
41 r_phantom: PhantomData,
42 c_phantom: PhantomData,
43 }
44 }
45
46 pub fn with_tol(&mut self, tol: T) -> &mut Self {
50 self.tolerance = tol;
51 self
52 }
53
54 pub fn with_itermax(&mut self, max: usize) -> &mut Self {
58 self.iter_max = max;
59 self
60 }
61
62 pub fn solve(&self, mut x0: VectorType<T, C>) -> SolverResult<VectorType<T, C>> {
66 let mut dv = x0.clone().add_scalar(T::max_value()); let mut iter = 1;
68
69 while dv.apply_norm(&UniformNorm) > self.tolerance && iter <= self.iter_max {
71 let j = (self.j)(x0.clone());
72 let jt = j.transpose();
73 if let Some(jtj_inv) = (&jt * j).try_inverse() {
74 dv = jtj_inv * jt * (self.f)(x0.clone());
75 x0 -= &dv;
76 iter += 1;
77 } else {
78 return Err(SolverError::BadJacobian);
79 }
80 }
81
82 if iter >= self.iter_max {
83 return Err(SolverError::MaxIterReached(iter));
84 }
85
86 Ok(x0)
87 }
88}