use crate::traits::FloatScalar;
use super::{OptimError, RootResult};
/// Termination controls shared by the scalar root finders in this file
/// (`brent` and `newton_1d`).
#[derive(Debug, Clone, Copy)]
pub struct RootSettings<T> {
/// Tolerance on the abscissa. `brent` stops once the bracket width
/// `|b - a|` falls below it; `newton_1d` stops once a single step
/// `|x_new - x|` falls below it.
pub x_tol: T,
/// Tolerance on the residual: both solvers stop once `|f(x)|` falls below it.
pub f_tol: T,
/// Upper bound on the number of solver loop iterations; when it is exhausted
/// the solvers report `OptimError::MaxIterations`.
pub max_iter: usize,
}
impl Default for RootSettings<f64> {
fn default() -> Self {
Self {
x_tol: 1e-12,
f_tol: 1e-12,
max_iter: 100,
}
}
}
impl Default for RootSettings<f32> {
fn default() -> Self {
Self {
x_tol: 1e-6,
f_tol: 1e-6,
max_iter: 100,
}
}
}
/// Finds a root of `f` inside the bracket `[a, b]` using Brent's method.
///
/// Each step proposes a point by inverse quadratic interpolation (or by the
/// secant rule when two of the three stored function values coincide) and
/// falls back to bisection when the proposal lies outside the accepted range
/// or the steps are not shrinking fast enough.
///
/// # Arguments
/// * `f` - function whose root is sought; it is called once per endpoint and
///   once per iteration.
/// * `a`, `b` - bracket endpoints; `f(a)` and `f(b)` must have strictly
///   opposite signs unless one of them is exactly zero.
/// * `settings` - tolerances and iteration budget.
///
/// # Returns
/// A `RootResult` holding the best estimate `x`, its residual `fx`, the number
/// of completed iterations and the number of calls made to `f`. The search
/// stops as soon as `|fx| < f_tol`, the bracket width drops below `x_tol`, or
/// an evaluation returns exactly zero.
///
/// # Errors
/// * `OptimError::BracketInvalid` - `f(a)` and `f(b)` do not have strictly
///   opposite signs (this includes either value being NaN).
/// * `OptimError::MaxIterations` - `settings.max_iter` iterations were
///   completed and the final iterate still fails both tolerances.
pub fn brent<T: FloatScalar>(
    mut f: impl FnMut(T) -> T,
    a: T,
    b: T,
    settings: &RootSettings<T>,
) -> Result<RootResult<T>, OptimError> {
    let zero = T::zero();
    let mut a = a;
    let mut b = b;
    let mut fa = f(a);
    let mut fb = f(b);
    let mut evals = 2usize;

    // An endpoint that is exactly a root needs no iteration. Handling it here
    // also keeps a zero from being mistaken for a negative value by the sign
    // test below.
    if fa == zero {
        return Ok(RootResult {
            x: a,
            fx: fa,
            iterations: 0,
            evals,
        });
    }
    if fb == zero {
        return Ok(RootResult {
            x: b,
            fx: fb,
            iterations: 0,
            evals,
        });
    }

    // Require a strict sign change. Spelling out both orderings makes a NaN
    // endpoint fail the test instead of silently counting as "negative".
    let sign_change = (fa < zero && fb > zero) || (fa > zero && fb < zero);
    if !sign_change {
        return Err(OptimError::BracketInvalid);
    }

    // Loop-invariant constants; small integers are expected to convert for
    // any float scalar.
    let two = T::from(2).unwrap();
    let three = T::from(3).unwrap();
    let four = T::from(4).unwrap();

    // Keep `b` as the endpoint with the smaller residual (the current best).
    if fa.abs() < fb.abs() {
        core::mem::swap(&mut a, &mut b);
        core::mem::swap(&mut fa, &mut fb);
    }

    let mut c = a;
    let mut fc = fa;
    // `d` is only read after a non-bisection step, by which time it has been
    // overwritten; the initial value is a placeholder.
    let mut d = c;
    let mut mflag = true;
    let mut iterations = 0usize;

    loop {
        // Test convergence before consulting the budget so that the result of
        // the final permitted update is not thrown away.
        if fb.abs() < settings.f_tol || (b - a).abs() < settings.x_tol {
            return Ok(RootResult {
                x: b,
                fx: fb,
                iterations,
                evals,
            });
        }
        if iterations >= settings.max_iter {
            return Err(OptimError::MaxIterations);
        }

        let candidate = if fa != fc && fb != fc {
            // Inverse quadratic interpolation through (a, fa), (b, fb), (c, fc).
            a * fb * fc / ((fa - fb) * (fa - fc))
                + b * fa * fc / ((fb - fa) * (fb - fc))
                + c * fa * fb / ((fc - fa) * (fc - fb))
        } else {
            // Secant step.
            b - fb * (b - a) / (fb - fa)
        };

        // The candidate must lie between (3a + b) / 4 and b. Phrasing this as
        // a positive containment test means a NaN candidate is rejected too.
        let three_quarter = (three * a + b) / four;
        let in_range = if three_quarter < b {
            candidate >= three_quarter && candidate <= b
        } else {
            candidate <= three_quarter && candidate >= b
        };

        // Compare against the previous step if the last move was a bisection,
        // otherwise against the step before that.
        let reference_step = if mflag {
            (b - c).abs()
        } else {
            (c - d).abs()
        };
        let too_slow = (candidate - b).abs() >= reference_step / two
            || reference_step < settings.x_tol;

        let s = if !in_range || too_slow {
            mflag = true;
            (a + b) / two
        } else {
            mflag = false;
            candidate
        };

        let fs = f(s);
        evals += 1;
        iterations += 1;

        // An exact root ends the search regardless of the tolerances, which
        // matters when `f_tol` is zero.
        if fs == zero {
            return Ok(RootResult {
                x: s,
                fx: fs,
                iterations,
                evals,
            });
        }

        d = c;
        c = b;
        fc = fb;

        // Replace whichever endpoint keeps the sign change inside [a, b].
        if (fa > zero) != (fs > zero) {
            b = s;
            fb = fs;
        } else {
            a = s;
            fa = fs;
        }

        if fa.abs() < fb.abs() {
            core::mem::swap(&mut a, &mut b);
            core::mem::swap(&mut fa, &mut fb);
        }
    }
}
/// Finds a root of `f` by Newton's method, starting from `x0`.
///
/// # Arguments
/// * `f` - function whose root is sought.
/// * `df` - derivative of `f`.
/// * `x0` - initial guess.
/// * `settings` - tolerances and iteration budget.
///
/// # Returns
/// A `RootResult` holding the final iterate `x`, its residual `fx`, the number
/// of Newton steps taken, and the combined number of calls made to `f` and
/// `df`. The iteration stops when `|fx| < f_tol` or when a step is shorter
/// than `x_tol`.
///
/// # Errors
/// * `OptimError::Singular` - the derivative magnitude is below `T::epsilon()`
///   or is NaN, so the Newton step is not usable.
/// * `OptimError::MaxIterations` - `settings.max_iter` steps were taken and
///   the final residual still exceeds `f_tol`.
pub fn newton_1d<T: FloatScalar>(
    mut f: impl FnMut(T) -> T,
    mut df: impl FnMut(T) -> T,
    x0: T,
    settings: &RootSettings<T>,
) -> Result<RootResult<T>, OptimError> {
    let mut x = x0;
    let mut fx = f(x);
    let mut evals = 1usize;

    for iter in 0..settings.max_iter {
        if fx.abs() < settings.f_tol {
            return Ok(RootResult {
                x,
                fx,
                iterations: iter,
                evals,
            });
        }

        let dfx = df(x);
        evals += 1;

        // Written as a positive test so that a NaN derivative is rejected too;
        // otherwise NaN would propagate through every later iterate until the
        // budget ran out.
        let slope_usable = dfx.abs() >= T::epsilon();
        if !slope_usable {
            return Err(OptimError::Singular);
        }

        let x_new = x - fx / dfx;
        let step = (x_new - x).abs();

        x = x_new;
        fx = f(x);
        evals += 1;

        if step < settings.x_tol {
            return Ok(RootResult {
                x,
                fx,
                iterations: iter + 1,
                evals,
            });
        }
    }

    // The residual of the last iterate has been computed but not yet tested;
    // honour it before reporting that the budget was exhausted.
    if fx.abs() < settings.f_tol {
        return Ok(RootResult {
            x,
            fx,
            iterations: settings.max_iter,
            evals,
        });
    }

    Err(OptimError::MaxIterations)
}