use crate::error::{MathError, MathResult};
use crate::solvers::{SolverConfig, SolverResult};
/// Finds a root of `f` inside the bracket `[a, b]` using Brent's method.
///
/// Each iteration first proposes an interpolation step:
/// - inverse quadratic interpolation through the three most recent points, or
/// - a secant step when only two distinct function values are available.
///
/// The proposed point is accepted only if it lies strictly between
/// `(3a + b) / 4` and `b`, and is smaller than half of the step before last.
/// Otherwise a bisection step is taken. The root stays bracketed between `a`
/// and `b` throughout, and the endpoints may be given in either order.
///
/// # Arguments
///
/// * `f` - function whose root is sought.
/// * `a`, `b` - bracket endpoints. `f(a)` and `f(b)` must have opposite
///   signs, or at least one of them must be zero.
/// * `config` - `max_iterations` bounds the loop. `tolerance` is used both as
///   the threshold on the residual `|f(x)|` and as the threshold on the
///   bracket width `|b - a|`.
///
/// # Returns
///
/// A [`SolverResult`] containing:
/// - `root`: the current best estimate,
/// - `iterations`: the number of completed iterations,
/// - `residual`: the signed value of `f` at `root`.
///
/// # Errors
///
/// * [`MathError::InvalidBracket`] if `f(a)` and `f(b)` do not bracket a
///   root, including when either value is NaN.
/// * The error built by `MathError::convergence_failed` when neither stopping
///   criterion is met within `config.max_iterations` iterations.
#[allow(clippy::many_single_char_names)] // a, b, c, d, e, s follow the textbook notation
pub fn brent<F>(f: F, a: f64, b: f64, config: &SolverConfig) -> MathResult<SolverResult>
where
    F: Fn(f64) -> f64,
{
    let mut a = a;
    let mut b = b;
    let mut fa = f(a);
    let mut fb = f(b);

    // Compare signs directly instead of testing `fa * fb > 0.0`.
    // The product underflows to zero for tiny values, which would accept a
    // non-bracket. NaN must also be rejected, and every comparison with NaN
    // is false, so NaN fails both clauses below.
    let brackets_root = (fa <= 0.0 && fb >= 0.0) || (fa >= 0.0 && fb <= 0.0);
    if !brackets_root {
        return Err(MathError::InvalidBracket { a, b, fa, fb });
    }

    // Keep `b` as the best estimate (smallest |f|) and `a` as the contrapoint,
    // so the root stays bracketed by `a` and `b`.
    if fa.abs() < fb.abs() {
        std::mem::swap(&mut a, &mut b);
        std::mem::swap(&mut fa, &mut fb);
    }

    // `c` is the previous value of `b`, `d` the most recent step and `e` the
    // step before that.
    let mut c = a;
    let mut fc = fa;
    let mut d = b - a;
    let mut e = d;

    for iteration in 0..config.max_iterations {
        if fb == 0.0 || fb.abs() < config.tolerance || (b - a).abs() < config.tolerance {
            return Ok(SolverResult {
                root: b,
                iterations: iteration,
                residual: fb,
            });
        }

        let candidate = if fa != fc && fb != fc {
            // Inverse quadratic interpolation (Lagrange form) through
            // (a, fa), (b, fb) and (c, fc), evaluated at f = 0.
            a * fb * fc / ((fa - fb) * (fa - fc))
                + b * fa * fc / ((fb - fa) * (fb - fc))
                + c * fa * fb / ((fc - fa) * (fc - fb))
        } else {
            // Secant step through (a, fa) and (b, fb).
            b - fb * (b - a) / (fb - fa)
        };

        // Brent's safeguard. The candidate must:
        // - lie strictly between (3a + b) / 4 and b, and
        // - be smaller than half the step before last.
        // A non-finite candidate (from a zero denominator) fails these
        // comparisons and falls through to bisection.
        let quarter = (3.0 * a + b) / 4.0;
        let accept = candidate > quarter.min(b)
            && candidate < quarter.max(b)
            && (candidate - b).abs() < e.abs() / 2.0;

        let s = if accept { candidate } else { (a + b) / 2.0 };
        // After an interpolation step, `e` inherits the previous step.
        // After a bisection, both records are reset to the bisection step
        // actually taken.
        e = if accept { d } else { s - b };
        d = s - b;

        let fs = f(s);
        c = b;
        fc = fb;

        // Replace whichever endpoint keeps the root bracketed. As above, a
        // sign comparison avoids underflow in the product `fa * fs`.
        if (fa < 0.0) != (fs < 0.0) {
            b = s;
            fb = fs;
        } else {
            a = s;
            fa = fs;
        }

        if fa.abs() < fb.abs() {
            std::mem::swap(&mut a, &mut b);
            std::mem::swap(&mut fa, &mut fb);
        }
    }

    Err(MathError::convergence_failed(
        config.max_iterations,
        fb.abs(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_sqrt_2() {
        let f = |x: f64| x * x - 2.0;
        let result = brent(f, 1.0, 2.0, &SolverConfig::default()).unwrap();
        assert_relative_eq!(result.root, std::f64::consts::SQRT_2, epsilon = 1e-10);
    }

    #[test]
    fn test_cubic() {
        let f = |x: f64| x * x * x - x - 2.0;
        let result = brent(f, 1.0, 2.0, &SolverConfig::default()).unwrap();
        assert!(f(result.root).abs() < 1e-10);
        assert_relative_eq!(result.root, 1.521_379_706_804_568, epsilon = 1e-10);
    }

    #[test]
    fn test_sin() {
        let f = |x: f64| x.sin();
        let result = brent(f, 3.0, 4.0, &SolverConfig::default()).unwrap();
        assert_relative_eq!(result.root, std::f64::consts::PI, epsilon = 1e-10);
    }

    #[test]
    fn test_reversed_bracket() {
        // Giving the endpoints as (high, low) must work like (low, high).
        let f = |x: f64| x * x - 2.0;
        let result = brent(f, 2.0, 1.0, &SolverConfig::default()).unwrap();
        assert_relative_eq!(result.root, std::f64::consts::SQRT_2, epsilon = 1e-10);
    }

    #[test]
    fn test_root_at_endpoint() {
        // f(1) is exactly zero, so the solver should stop before iterating.
        let f = |x: f64| x - 1.0;
        let result = brent(f, 1.0, 2.0, &SolverConfig::default()).unwrap();
        assert_eq!(result.iterations, 0);
        assert_relative_eq!(result.root, 1.0);
    }

    #[test]
    fn test_invalid_bracket() {
        // Both endpoint values are positive. This must be reported as a bad
        // bracket specifically, not as any other error.
        let f = |x: f64| x * x - 2.0;
        let result = brent(f, 2.0, 3.0, &SolverConfig::default());
        assert!(matches!(result, Err(MathError::InvalidBracket { .. })));
    }

    #[test]
    fn test_faster_than_bisection() {
        // Plain bisection halves [1, 2] once per iteration and needs roughly
        // log2(1 / tolerance) iterations. A working Brent solver should
        // finish well below the bound asserted here.
        let f = |x: f64| x * x - 2.0;
        let config = SolverConfig::default();
        let brent_result = brent(f, 1.0, 2.0, &config).unwrap();
        assert!(brent_result.iterations < 20);
    }
}