use crate::{MatrixType, SolverError, SolverResult, VectorType, DEFAULT_ITERMAX, DEFAULT_TOL};
use nalgebra::{allocator::Allocator, ComplexField, DefaultAllocator, Dim, Scalar, UniformNorm};
use num_traits::{Float, Signed};
use std::marker::PhantomData;
/// Newton solver for a system of equations in several variables.
///
/// Bundles the system function, its Jacobian, and the stopping criteria
/// (step tolerance and iteration limit) used when solving.
///
/// Type parameters:
/// * `T` - scalar type of the tolerance.
/// * `D` - dimension marker; it is not stored, only tracked via `PhantomData`.
/// * `F` - type of the system function.
/// * `J` - type of the Jacobian function.
pub struct MultiVarNewton<T, D, F, J> {
/// System function whose root is sought.
f: F,
/// Function producing the Jacobian matrix of `f`.
j: J,
/// Threshold compared against the norm of the Newton step to decide when to stop.
tolerance: T,
/// Upper bound on the number of Newton iterations.
iter_max: usize,
/// Zero-sized marker tying the otherwise unused dimension parameter `D` to the struct.
d_phantom: PhantomData<D>,
}
impl<T, D, F, J> MultiVarNewton<T, D, F, J>
where
    T: Float + Scalar + ComplexField<RealField = T> + Signed,
    D: Dim,
    J: Fn(VectorType<T, D>) -> MatrixType<T, D, D>,
    F: Fn(VectorType<T, D>) -> VectorType<T, D>,
    DefaultAllocator: Allocator<D, D> + Allocator<D>,
{
    /// Creates a solver for the system `f` with Jacobian `j`.
    ///
    /// The tolerance is initialised from `DEFAULT_TOL` and the iteration
    /// limit from `DEFAULT_ITERMAX`.
    ///
    /// # Panics
    ///
    /// Panics if `DEFAULT_TOL` cannot be converted to the scalar type `T`.
    pub fn new(f: F, j: J) -> Self {
        Self {
            f,
            j,
            tolerance: T::from(DEFAULT_TOL).expect("DEFAULT_TOL must be representable in T"),
            iter_max: DEFAULT_ITERMAX,
            d_phantom: PhantomData,
        }
    }

    /// Sets the convergence tolerance on the norm of the Newton step.
    ///
    /// Returns `self` so that configuration calls can be chained.
    pub fn with_tol(&mut self, tol: T) -> &mut Self {
        self.tolerance = tol;
        self
    }

    /// Sets the maximum number of Newton iterations.
    ///
    /// Returns `self` so that configuration calls can be chained.
    pub fn with_itermax(&mut self, max: usize) -> &mut Self {
        self.iter_max = max;
        self
    }

    /// Runs Newton's method starting from `x0`.
    ///
    /// Each iteration inverts the Jacobian at the current point, computes the
    /// step `J(x)^-1 * f(x)` and subtracts it from the current point. The
    /// iteration stops successfully as soon as the uniform norm of the step
    /// is less than or equal to the tolerance.
    ///
    /// # Errors
    ///
    /// * `SolverError::BadJacobian` if the Jacobian cannot be inverted at
    ///   some iterate.
    /// * `SolverError::MaxIterReached(n)` if no step satisfied the tolerance
    ///   within the iteration limit; `n` is the number of iterations
    ///   performed.
    pub fn solve(&self, mut x0: VectorType<T, D>) -> SolverResult<VectorType<T, D>> {
        for _ in 0..self.iter_max {
            let j_inv = match (self.j)(x0.clone()).try_inverse() {
                Some(inv) => inv,
                None => return Err(SolverError::BadJacobian),
            };
            let step = j_inv * (self.f)(x0.clone());
            x0 -= &step;
            // Convergence is decided by the step size alone, not by the
            // iteration counter, so converging on the final permitted
            // iteration is still a success. Written as `<=` so that a NaN
            // norm is never accepted as converged.
            if step.apply_norm(&UniformNorm) <= self.tolerance {
                return Ok(x0);
            }
        }
        Err(SolverError::MaxIterReached(self.iter_max))
    }
}