use thiserror::Error;
#[derive(Debug, Error, PartialEq)]
pub enum RootError {
#[error("interval [{a}, {b}] does not bracket a root: f(a)={fa}, f(b)={fb}")]
NoBracket { a: f64, b: f64, fa: f64, fb: f64 },
#[error("did not converge in {max_iter} iterations (last residual {last_residual})")]
NoConvergence { max_iter: usize, last_residual: f64 },
#[error("f returned NaN at x = {x}")]
NanEvaluation { x: f64 },
}
/// Finds a root of `f` inside the interval spanned by `a` and `b` using Brent's method.
///
/// Each step proposes a trial point by inverse quadratic interpolation, or by a secant
/// step when only two distinct function values are available. It falls back to
/// bisection whenever the trial point lies outside the safe region, is not finite, or
/// is not shrinking fast enough. The bracket always contains a sign change, so the
/// iteration cannot wander away from the root.
///
/// # Arguments
///
/// * `f` - function whose root is sought.
/// * `a`, `b` - interval endpoints, in either order.
///   - `f(a)` and `f(b)` must not be both strictly positive or both strictly negative.
///   - An endpoint where `f` is exactly zero is returned as the root.
/// * `xtol` - absolute tolerance. The current estimate is accepted once half the
///   bracket width drops below `2 * f64::EPSILON * |estimate| + xtol / 2`.
/// * `max_iter` - maximum number of interpolation/bisection steps, i.e. evaluations
///   of `f` beyond the two endpoint evaluations.
///   - With `0`, only the initial bracket is checked.
///
/// # Errors
///
/// * [`RootError::NanEvaluation`] if `f` returns NaN at an evaluated point.
/// * [`RootError::NoBracket`] if `f(a)` and `f(b)` share the same strict sign.
/// * [`RootError::NoConvergence`] if the tolerance is not met within `max_iter` steps.
///   `last_residual` is `f` at the bracket endpoint with the smaller `|f|`.
pub fn brent<F>(mut f: F, a: f64, b: f64, xtol: f64, max_iter: usize) -> Result<f64, RootError>
where
    F: FnMut(f64) -> f64,
{
    // Signs are compared directly instead of testing `x * y`. For tiny but nonzero
    // values that product underflows to +/-0.0 and reports the wrong relation.
    fn same_strict_sign(x: f64, y: f64) -> bool {
        (x > 0.0 && y > 0.0) || (x < 0.0 && y < 0.0)
    }
    fn opposite_strict_sign(x: f64, y: f64) -> bool {
        (x > 0.0 && y < 0.0) || (x < 0.0 && y > 0.0)
    }

    let mut fa = f(a);
    let mut fb = f(b);
    if fa.is_nan() {
        return Err(RootError::NanEvaluation { x: a });
    }
    if fb.is_nan() {
        return Err(RootError::NanEvaluation { x: b });
    }
    if same_strict_sign(fa, fb) {
        return Err(RootError::NoBracket { a, b, fa, fb });
    }

    let (mut a, mut b) = (a, b);
    // Invariant: `b` is the endpoint with the smaller |f| (the current best estimate),
    // and `a` is the contrapoint on the other side of the sign change.
    if fa.abs() < fb.abs() {
        std::mem::swap(&mut a, &mut b);
        std::mem::swap(&mut fa, &mut fb);
    }
    // `c` is the previous value of `b` and `d` the one before that.
    // `mflag` records whether the last step was a bisection.
    let mut c = a;
    let mut fc = fa;
    let mut d = c;
    let mut mflag = true;

    // Half-width tolerance. The relative term avoids demanding more precision than
    // f64 offers near `b`; the second term is half of the caller's absolute tolerance.
    let tol_at = |x: f64| 2.0 * f64::EPSILON * x.abs() + 0.5 * xtol;
    let converged = |a: f64, b: f64, fb: f64| fb == 0.0 || (0.5 * (a - b)).abs() < tol_at(b);

    for _ in 0..max_iter {
        if converged(a, b, fb) {
            return Ok(b);
        }
        let tol = tol_at(b);

        // Trial point: inverse quadratic interpolation through (a, b, c) when the three
        // function values are distinct, otherwise the secant through (a, b).
        let trial = if fa != fc && fb != fc {
            a * fb * fc / ((fa - fb) * (fa - fc))
                + b * fa * fc / ((fb - fa) * (fb - fc))
                + c * fa * fb / ((fc - fa) * (fc - fb))
        } else {
            b - fb * (b - a) / (fb - fa)
        };

        // Accept the trial point only if both hold:
        //   - it lies between (3a + b) / 4 and b;
        //   - steps are shrinking fast enough compared with earlier ones.
        // The range test is written as `!contains`, so a NaN trial point (e.g. from
        // underflowing interpolation products) is rejected instead of being passed to `f`.
        let bound = (3.0 * a + b) / 4.0;
        let (low, high) = if bound < b { (bound, b) } else { (b, bound) };
        let bisect = !(low..=high).contains(&trial)
            || (mflag && (trial - b).abs() >= (b - c).abs() / 2.0)
            || (!mflag && (trial - b).abs() >= (c - d).abs() / 2.0)
            || (mflag && (b - c).abs() < tol)
            || (!mflag && (c - d).abs() < tol);
        mflag = bisect;
        let s = if bisect { 0.5 * (a + b) } else { trial };

        let fs = f(s);
        if fs.is_nan() {
            return Err(RootError::NanEvaluation { x: s });
        }
        d = c;
        c = b;
        fc = fb;
        // Replace whichever endpoint keeps the sign change inside the bracket.
        if opposite_strict_sign(fa, fs) {
            b = s;
            fb = fs;
        } else {
            a = s;
            fa = fs;
        }
        if fa.abs() < fb.abs() {
            std::mem::swap(&mut a, &mut b);
            std::mem::swap(&mut fa, &mut fb);
        }
    }

    // The bracket produced by the last step has not been tested yet. When
    // `max_iter == 0`, this is the initial bracket.
    if converged(a, b, fb) {
        return Ok(b);
    }
    Err(RootError::NoConvergence {
        max_iter,
        last_residual: fb,
    })
}
/// Finds a root of `f` inside the interval spanned by `a` and `b` using the Illinois
/// variant of regula falsi (false position).
///
/// Each step evaluates `f` where the chord between the two bracket endpoints crosses
/// zero. When the same endpoint is replaced on two consecutive steps, the stored
/// function value of the *other* endpoint is halved. This stops one end of the bracket
/// from stagnating.
///
/// # Arguments
///
/// * `f` - function whose root is sought.
/// * `a`, `b` - interval endpoints.
///   - `f(a)` and `f(b)` must not be both strictly positive or both strictly negative.
///   - An endpoint where `f` is exactly zero is returned as is.
/// * `xtol` - absolute tolerance. An iterate is accepted when either holds:
///   - the bracket it was computed from is narrower than `xtol`;
///   - `f` is exactly zero there.
/// * `max_iter` - maximum number of evaluations of `f` after the two endpoint
///   evaluations.
///
/// # Errors
///
/// * [`RootError::NanEvaluation`] if `f` returns NaN at an evaluated point.
/// * [`RootError::NoBracket`] if `f(a)` and `f(b)` share the same strict sign.
/// * [`RootError::NoConvergence`] if no iterate is accepted within `max_iter` steps.
///   `last_residual` is `f` at the last iterate. When `max_iter == 0`, it is the
///   endpoint value with the smaller magnitude.
pub fn illinois<F>(mut f: F, a: f64, b: f64, xtol: f64, max_iter: usize) -> Result<f64, RootError>
where
    F: FnMut(f64) -> f64,
{
    // Signs are compared directly instead of testing `x * y`. For tiny but nonzero
    // values that product underflows to +/-0.0 and reports the wrong relation.
    fn same_strict_sign(x: f64, y: f64) -> bool {
        (x > 0.0 && y > 0.0) || (x < 0.0 && y < 0.0)
    }
    fn opposite_strict_sign(x: f64, y: f64) -> bool {
        (x > 0.0 && y < 0.0) || (x < 0.0 && y > 0.0)
    }

    // Which bracket endpoint the previous step replaced.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Endpoint {
        A,
        B,
    }

    let mut fa = f(a);
    let mut fb = f(b);
    if fa.is_nan() {
        return Err(RootError::NanEvaluation { x: a });
    }
    if fb.is_nan() {
        return Err(RootError::NanEvaluation { x: b });
    }
    if same_strict_sign(fa, fb) {
        return Err(RootError::NoBracket { a, b, fa, fb });
    }
    // An endpoint that is already an exact root needs no iteration.
    if fa == 0.0 {
        return Ok(a);
    }
    if fb == 0.0 {
        return Ok(b);
    }

    let (mut a, mut b) = (a, b);
    let mut last_replaced: Option<Endpoint> = None;
    let mut last_residual = if fa.abs() <= fb.abs() { fa } else { fb };
    for _ in 0..max_iter {
        // False-position point. A non-finite result falls back to the midpoint. This
        // covers a zero denominator and infinite endpoint values that give inf/inf.
        let chord_root = (a * fb - b * fa) / (fb - fa);
        let c = if chord_root.is_finite() {
            chord_root
        } else {
            0.5 * (a + b)
        };
        let fc = f(c);
        if fc.is_nan() {
            return Err(RootError::NanEvaluation { x: c });
        }
        if fc == 0.0 || (b - a).abs() < xtol {
            return Ok(c);
        }
        last_residual = fc;
        if opposite_strict_sign(fc, fb) {
            a = c;
            fa = fc;
            // `a` replaced twice in a row: damp the retained endpoint's value.
            if last_replaced == Some(Endpoint::A) {
                fb *= 0.5;
            }
            last_replaced = Some(Endpoint::A);
        } else {
            b = c;
            fb = fc;
            // `b` replaced twice in a row: damp the retained endpoint's value.
            if last_replaced == Some(Endpoint::B) {
                fa *= 0.5;
            }
            last_replaced = Some(Endpoint::B);
        }
    }
    Err(RootError::NoConvergence {
        max_iter,
        last_residual,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Dottie number: the real solution of `cos(x) = x`.
    const DOTTIE: f64 = 0.739_085_133_215_160_7;
    /// Real root of Wallis's cubic `x^3 - 2x - 5`.
    const WALLIS_ROOT: f64 = 2.094_551_481_542_326_5;

    /// Panics unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let deviation = (actual - expected).abs();
        assert!(deviation < tol, "expected {expected}, got {actual}");
    }

    /// `cos(x) - x`, positive at 0 and negative at 1.
    fn dottie_residual(x: f64) -> f64 {
        x.cos() - x
    }

    /// Wallis's test cubic `x^3 - 2x - 5`, negative at 2 and positive at 3.
    fn wallis_cubic(x: f64) -> f64 {
        x.powi(3) - 2.0 * x - 5.0
    }

    #[test]
    fn brent_finds_dottie_number() {
        assert_close(brent(dottie_residual, 0.0, 1.0, 1e-12, 100).unwrap(), DOTTIE, 1e-12);
    }

    #[test]
    fn brent_finds_wallis_root() {
        assert_close(brent(wallis_cubic, 2.0, 3.0, 1e-12, 100).unwrap(), WALLIS_ROOT, 1e-12);
    }

    #[test]
    fn brent_rejects_unbracketed_interval() {
        let err = brent(|x: f64| x * x, 1.0, 2.0, 1e-9, 50).unwrap_err();
        assert!(matches!(err, RootError::NoBracket { .. }));
    }

    #[test]
    fn brent_handles_root_on_endpoint() {
        assert_close(brent(|x: f64| x - 3.0, 3.0, 4.0, 1e-12, 100).unwrap(), 3.0, 1e-12);
    }

    #[test]
    fn brent_reports_nan_evaluation() {
        let err = brent(f64::ln, -1.0, 1.0, 1e-9, 50).unwrap_err();
        assert!(matches!(err, RootError::NanEvaluation { .. }));
    }

    #[test]
    fn illinois_finds_dottie_number() {
        assert_close(illinois(dottie_residual, 0.0, 1.0, 1e-12, 100).unwrap(), DOTTIE, 1e-9);
    }

    #[test]
    fn illinois_finds_wallis_root() {
        assert_close(illinois(wallis_cubic, 2.0, 3.0, 1e-12, 100).unwrap(), WALLIS_ROOT, 1e-9);
    }

    #[test]
    fn illinois_rejects_unbracketed_interval() {
        let err = illinois(|x: f64| x * x + 1.0, -1.0, 1.0, 1e-9, 50).unwrap_err();
        assert!(matches!(err, RootError::NoBracket { .. }));
    }

    #[test]
    fn illinois_handles_steep_function() {
        let steep = |x: f64| x.powi(15) + 1.0;
        assert_close(illinois(steep, -2.0, 0.5, 1e-9, 100).unwrap(), -1.0, 1e-6);
    }
}