use crate::{
finite_differences::{backward, central, forward, FiniteDifferenceType},
{SolverError, SolverResult, DEFAULT_ITERMAX, DEFAULT_TOL},
};
use num_traits::Float;
use std::ops::Fn;
/// Newton–Raphson root finder whose derivative is approximated by finite differences.
///
/// Construct with [`FDNewton::new`] and adjust the settings through the `with_*` methods.
#[derive(Debug, Clone)]
pub struct FDNewton<T, F> {
    /// Function whose root is sought.
    f: F,
    /// Finite-difference scheme approximating `f'`, invoked as `finite_diff(f, x, h)`.
    finite_diff: fn(F, T, T) -> T,
    /// Step length `h` handed to `finite_diff`.
    fd_step_length: T,
    /// Iteration stops once the magnitude of the Newton step is at or below this value.
    tolerance: T,
    /// Maximum number of Newton steps attempted.
    iter_max: usize,
}
impl<T, F> FDNewton<T, F>
where
    T: Float,
    F: Fn(T) -> T + Copy,
{
    /// Creates a solver for a root of `f` with default settings.
    ///
    /// The defaults are:
    /// * central finite differences;
    /// * a step length of `sqrt(T::epsilon())`;
    /// * a tolerance of `DEFAULT_TOL`;
    /// * at most `DEFAULT_ITERMAX` iterations.
    ///
    /// # Panics
    ///
    /// Panics if `DEFAULT_TOL` cannot be converted to `T`.
    pub fn new(f: F) -> Self {
        Self {
            f,
            finite_diff: central,
            fd_step_length: T::epsilon().sqrt(),
            tolerance: T::from(DEFAULT_TOL).expect("DEFAULT_TOL must be representable as T"),
            iter_max: DEFAULT_ITERMAX,
        }
    }

    /// Sets the convergence tolerance on the magnitude of the Newton step.
    pub fn with_tol(&mut self, tol: T) -> &mut Self {
        self.tolerance = tol;
        self
    }

    /// Sets the maximum number of Newton steps.
    pub fn with_itermax(&mut self, max: usize) -> &mut Self {
        self.iter_max = max;
        self
    }

    /// Sets the step length `h` used by the finite-difference derivative.
    pub fn with_fd_step_length(&mut self, h: T) -> &mut Self {
        self.fd_step_length = h;
        self
    }

    /// Selects the finite-difference scheme (central, forward or backward).
    pub fn with_finite_difference(&mut self, fd_type: FiniteDifferenceType) -> &mut Self {
        match fd_type {
            FiniteDifferenceType::Central => self.finite_diff = central,
            FiniteDifferenceType::Forward => self.finite_diff = forward,
            FiniteDifferenceType::Backward => self.finite_diff = backward,
        }
        self
    }

    /// Searches for a root of `f` with Newton's method, starting from `x0`.
    ///
    /// Each step computes `dx = f(x) / d`. Here `d` approximates `f'(x)` and comes from
    /// the configured scheme, called as `finite_diff(f, x, fd_step_length)`. The step
    /// then updates `x -= dx`. Iteration stops once `|dx| <= tolerance` or after
    /// `iter_max` steps, whichever happens first.
    ///
    /// # Errors
    ///
    /// * [`SolverError::NotANumber`] if the final iterate is NaN.
    /// * [`SolverError::MaxIterReached`] if `|dx|` is still above the tolerance after
    ///   `iter_max` steps. The payload is the internal 1-based step counter at that
    ///   point, i.e. `iter_max + 1`.
    pub fn solve(&self, mut x0: T) -> SolverResult<T> {
        let mut dx = T::max_value();
        let mut iter = 1;
        while dx.abs() > self.tolerance && iter <= self.iter_max {
            dx = (self.f)(x0) / (self.finite_diff)(self.f, x0, self.fd_step_length);
            x0 = x0 - dx;
            iter += 1;
        }
        // A NaN step also makes the loop condition false. Rule NaN out before treating
        // the loop exit as convergence.
        if x0.is_nan() {
            return Err(SolverError::NotANumber);
        }
        // Judge success by the last step size, not by the counter. Converging on the
        // final permitted step (or the one before it) is still a success.
        if dx.abs() > self.tolerance {
            return Err(SolverError::MaxIterReached(iter));
        }
        Ok(x0)
    }
}