use core::qm;
use std::f64::NAN;
use std::f64::EPSILON;
pub fn zbrent<F>(x1: f64, x2: f64, tol: f64, max_iter: u32, func: &mut F)
-> Result<f64, qm::Error>
where F: FnMut(f64) -> Result<f64, qm::Error> {
let mut a = x1;
let mut b = x2;
let mut c = x2;
let mut fa = func(a)?;
let mut fb = func(b)?;
if (fa > 0.0 && fb > 0.0) || (fa < 0.0 && fb < 0.0) {
return Err(qm::Error::new("Root must be bracketed in zbrent"))
}
let mut d = NAN;
let mut e = NAN;
let mut q;
let mut r;
let mut p;
let mut fc = fb;
for _ in 0..max_iter {
if (fb > 0.0 && fc > 0.0) || (fb < 0.0 && fc < 0.0) {
c = a;
fc = fa;
d = b - a;
e = d;
}
if fc.abs() < fb.abs() {
a = b;
b = c;
c = a;
fa = fb;
fb = fc;
fc = fa;
}
let tol1 = 2.0 * EPSILON * b.abs() + 0.5 * tol;
let xm = 0.5 * (c - b);
if xm.abs() <= tol1 || fb == 0.0 {
return Ok(b)
}
if e.abs() >= tol1 && fa.abs() > fb.abs() {
let s = fb / fa;
if a == c {
p = 2.0 * xm * s;
q = 1.0 - s;
} else {
q = fa / fc;
r = fb / fc;
p = s * (2.0 * xm * q * (q - r) - (b - a) * (r - 1.0));
q = (q - 1.0) * (r - 1.0) * (s - 1.0);
}
if p > 0.0 {
q = -q;
}
p = p.abs();
let min1 = 3.0 * xm * q * (tol1 - q).abs();
let min2 = (e * q).abs();
if 2.0 * p < min1.min(min2) {
e = d;
d = p / q;
} else {
d = xm;
e = d;
}
} else {
d = xm;
e = d;
}
a = b;
fa = fb;
if d.abs() > tol1 {
b += d;
} else {
b += tol1.abs() * xm.signum();
}
fb = func(b)?;
}
Err(qm::Error::new("Maximum number of iterations exceeded in zbrent"))
}
#[cfg(test)]
mod tests {
use super::*;
use math::numerics::approx_eq;
use std::f64::consts::PI;
#[test]
fn brent_arcsin() {
let samples = vec![0.0, 0.1, 0.2, 0.3];
let min = 0.0;
let max = PI * 0.5;
let tol = 1e-10;
let max_iter = 100;
for v in samples.iter() {
let y = zbrent(min, max, tol, max_iter, &mut |x : f64| Ok(x.sin() - *v)).unwrap();
let expected = v.asin();
assert!(approx_eq(y, expected, tol), "result={} expected={}", v, expected);
}
}
#[test]
fn problem_of_the_day() {
let r = 6.371e6_f64;
let min = 0.0;
let max = 1e8;
let max_iter = 100;
let tol = 1e-10;
let y = zbrent(min, max, tol, max_iter, &mut |h|
Ok(1.0 + r * (r / (r + h)).acos() - (h * (2.0 * r + h)).sqrt())).unwrap();
let expected = 192.80752497643797_f64;
assert!(approx_eq(y, expected, tol), "result={} expected={}", y, expected);
let test_expected = 1.0 + r * (r / (r + expected)).acos()
- (expected * (2.0 * r + expected)).sqrt();
assert!(approx_eq(test_expected, 0.0, 1e-8), "result={} expected={}", test_expected, 0.0);
}
}