use crate::time::{Interval, TimeInstant};
use qtty::{Days, Quantity, Unit};
/// Default absolute tolerance on the abscissa, in the solver's `T` units
/// (days for the time-based entry points), used when the caller gives none.
const DEFAULT_TOL: f64 = 1.0e-9;
/// A function value whose magnitude is below this is accepted as a root.
const F_EPS: f64 = 1.0e-12;
/// Abscissa distance below which Brent treats `a` and `c` as the same point
/// and takes a secant step instead of inverse quadratic interpolation.
const COINCIDENCE_EPS: f64 = 1.0e-14;
/// Iteration cap for Brent's method; the current best estimate is returned
/// once it is reached.
const BRENT_MAX_ITER: usize = 100;
/// Iteration cap for bisection; the midpoint of the remaining bracket is
/// returned once it is reached.
const BISECT_MAX_ITER: usize = 80;
/// Returns `true` only when `a` and `b` are both strictly positive or both
/// strictly negative.
///
/// Zero (of either sign) and NaN never count as sharing a sign with anything.
/// The test is done on the raw values rather than via `signum`. `f64::signum`
/// returns `1.0` for `+0.0` and `-1.0` for `-0.0`, so it cannot tell zero
/// apart from a signed value.
fn same_sign<V: Unit>(a: Quantity<V>, b: Quantity<V>) -> bool {
    let (x, y) = (a.value(), b.value());
    (x > 0.0 && y > 0.0) || (x < 0.0 && y < 0.0)
}
/// Ratio of two quantities of the same unit as a plain `f64`; the unit
/// cancels, so only the raw magnitudes are divided.
fn dimensionless_ratio<V: Unit>(num: Quantity<V>, den: Quantity<V>) -> f64 {
    let numerator = num.value();
    let denominator = den.value();
    numerator / denominator
}
/// Finds a zero of `f` between `lo` and `hi` with Brent's method.
///
/// Both endpoints are evaluated first (`f(lo)`, then `f(hi)`). The search uses
/// `DEFAULT_TOL` as its absolute abscissa tolerance, in the units of `T`.
///
/// Returns `None` when neither endpoint value is within `F_EPS` of zero and
/// both values share the same sign.
pub fn brent<T, V, F>(lo: Quantity<T>, hi: Quantity<T>, f: F) -> Option<Quantity<T>>
where
    T: Unit,
    V: Unit,
    F: Fn(Quantity<T>) -> Quantity<V>,
{
    let tolerance: Quantity<T> = Quantity::new(DEFAULT_TOL);
    let (value_at_lo, value_at_hi) = (f(lo), f(hi));
    brent_engine(lo, hi, value_at_lo, value_at_hi, &f, tolerance)
}
pub fn brent_with_values<T, V, F>(
period: Interval<T>,
f_lo: Quantity<V>,
f_hi: Quantity<V>,
f: F,
) -> Option<T>
where
T: TimeInstant<Duration = Days>,
V: Unit,
F: Fn(T) -> Quantity<V>,
{
brent_core(period, f_lo, f_hi, &f, Days::new(DEFAULT_TOL))
}
/// Brent root search over `period` with a caller-chosen abscissa tolerance.
///
/// `f_lo` and `f_hi` are expected to equal `f(period.start)` and
/// `f(period.end)` and are not recomputed. The search stops when one of these
/// holds:
/// * the bracket half-width is at most `tolerance` days;
/// * `|f|` at the best estimate drops below `F_EPS`;
/// * the iteration cap is reached.
///
/// Returns `None` when neither endpoint value is within `F_EPS` of zero and
/// both values share the same sign.
pub fn brent_tol<T, V, F>(
    period: Interval<T>,
    f_lo: Quantity<V>,
    f_hi: Quantity<V>,
    f: F,
    tolerance: Days,
) -> Option<T>
where
    T: TimeInstant<Duration = Days>,
    V: Unit,
    F: Fn(T) -> Quantity<V>,
{
    let objective = &f;
    brent_core(period, f_lo, f_hi, objective, tolerance)
}
/// Shared driver for the time-based Brent entry points.
///
/// The solver only sees `Days` offsets from `period.start`. It searches
/// `[0, period.end - period.start]`, `f` is evaluated at `start + offset`, and
/// the offset of the root found is mapped back to an instant.
fn brent_core<T, V, F>(
    period: Interval<T>,
    f_lo: Quantity<V>,
    f_hi: Quantity<V>,
    f: &F,
    tol: Days,
) -> Option<T>
where
    T: TimeInstant<Duration = Days>,
    V: Unit,
    F: Fn(T) -> Quantity<V>,
{
    let origin = period.start;
    let span = period.end.difference(&origin);
    let at_offset = |offset: Days| f(origin.add_duration(offset));
    let root_offset = brent_engine(Days::zero(), span, f_lo, f_hi, &at_offset, tol)?;
    Some(origin.add_duration(root_offset))
}
/// Core of Brent's method on a plain scalar axis.
///
/// Searches between `lo` and `hi` for a zero of `f`. The two endpoints may be
/// given in either order. `f_lo` and `f_hi` must be the already-evaluated
/// endpoint values; the endpoints are never re-evaluated here.
///
/// Each iteration tries an interpolated step: inverse quadratic
/// interpolation, or a secant step when `a` and `c` coincide. It falls back to
/// bisection when that step is not acceptable.
///
/// # Returns
/// * `Some(endpoint)` if an endpoint value has magnitude below `F_EPS`.
/// * `None` if the endpoint values share the same sign, or if either is NaN,
///   since no sign change can then be established.
/// * Otherwise `Some(b)`, the best estimate, returned when `|f(b)| < F_EPS`,
///   when the bracket half-width is `<= tol` (an absolute tolerance), or
///   after `BRENT_MAX_ITER` iterations. Hitting the cap is not reported as a
///   failure.
///
/// NOTE(review): a NaN produced by `f` strictly inside the bracket is not
/// detected here.
fn brent_engine<T, V, F>(
    lo: Quantity<T>,
    hi: Quantity<T>,
    f_lo: Quantity<V>,
    f_hi: Quantity<V>,
    f: &F,
    tol: Quantity<T>,
) -> Option<Quantity<T>>
where
    T: Unit,
    V: Unit,
    F: Fn(Quantity<T>) -> Quantity<V>,
{
    let f_eps: Quantity<V> = Quantity::new(F_EPS);
    let coincidence: Quantity<T> = Quantity::new(COINCIDENCE_EPS);

    let mut a = lo;
    let mut b = hi;
    let mut fa = f_lo;
    let mut fb = f_hi;

    // An endpoint that is already (numerically) a root is returned as-is.
    if fa.abs() < f_eps {
        return Some(a);
    }
    if fb.abs() < f_eps {
        return Some(b);
    }
    // A NaN endpoint cannot certify a sign change. `same_sign` returns false
    // for NaN, so without this guard the iteration would drift onto the NaN
    // endpoint and report it as a root.
    if fa.value().is_nan() || fb.value().is_nan() {
        return None;
    }
    if same_sign(fa, fb) {
        return None;
    }

    // Keep `b` as the endpoint with the smaller residual (the best estimate).
    if fa.abs() < fb.abs() {
        std::mem::swap(&mut a, &mut b);
        std::mem::swap(&mut fa, &mut fb);
    }
    // `c` is the counterpoint: `b` and `c` are kept on opposite sides of zero.
    let mut c = a;
    let mut fc = fa;
    // `d` is the latest step and `e` the one before it.
    let mut d = b - a;
    let mut e = d;

    for _ in 0..BRENT_MAX_ITER {
        let m = (c - b) * 0.5;
        if fb.abs() < f_eps || m.abs() <= tol {
            return Some(b);
        }

        let use_bisection = e.abs() < tol || fa.abs() <= fb.abs();
        let (new_e, new_d) = if use_bisection {
            (m, m)
        } else {
            let s = dimensionless_ratio(fb, fa);
            let (p, q): (Quantity<T>, f64) = if (a - c).abs() < coincidence {
                // Only two distinct points: secant (linear interpolation).
                (m * (2.0 * s), 1.0 - s)
            } else {
                // Three distinct points: inverse quadratic interpolation.
                let q_val = dimensionless_ratio(fa, fc);
                let r = dimensionless_ratio(fb, fc);
                let p = m * (2.0 * s * q_val * (q_val - r)) - (b - a) * (s * (r - 1.0));
                let q = (q_val - 1.0) * (r - 1.0) * (s - 1.0);
                (p, q)
            };
            // Normalise so that p >= 0; the candidate step is p / q.
            let (p, q) = if p > Quantity::<T>::zero() {
                (p, -q)
            } else {
                (-p, q)
            };
            // Accept the interpolated step only if it stays within three
            // quarters of the way from b toward c and is smaller than half of
            // the step before last; otherwise bisect.
            if p * 2.0 < m * (3.0 * q) - (tol * q).abs() && p < (e * (0.5 * q)).abs() {
                (d, p / q)
            } else {
                (m, m)
            }
        };
        e = new_e;
        d = new_d;

        a = b;
        fa = fb;
        // Never move by less than `tol`, stepping toward the counterpoint.
        b += if d.abs() > tol {
            d
        } else if m > Quantity::<T>::zero() {
            tol
        } else {
            -tol
        };
        fb = f(b);

        // Restore the sign change between b and c.
        if same_sign(fb, fc) {
            c = a;
            fc = fa;
            e = b - a;
            d = e;
        }
        // Keep b as the point with the smaller residual.
        if fc.abs() < fb.abs() {
            a = b;
            b = c;
            c = a;
            fa = fb;
            fb = fc;
            fc = fa;
        }
    }
    Some(b)
}
/// Finds a root of `f` between `lo` and `hi` by plain bisection.
///
/// Both endpoints are evaluated first (`f(lo)`, then `f(hi)`). An endpoint
/// whose value has magnitude below `F_EPS` is returned directly.
///
/// Each step evaluates the midpoint. It returns that midpoint when
/// `|f(mid)| < F_EPS` or when the current bracket is narrower than
/// `DEFAULT_TOL`. Otherwise it keeps the half across which the sign changes.
/// After `BISECT_MAX_ITER` steps the midpoint of the remaining bracket is
/// returned.
///
/// Returns `None` when the endpoint values share the same sign, or when either
/// is NaN, since no sign change can then be established.
pub fn bisection<T, V, F>(lo: Quantity<T>, hi: Quantity<T>, f: F) -> Option<Quantity<T>>
where
    T: Unit,
    V: Unit,
    F: Fn(Quantity<T>) -> Quantity<V>,
{
    let tol: Quantity<T> = Quantity::new(DEFAULT_TOL);
    let f_eps: Quantity<V> = Quantity::new(F_EPS);
    let mut a = lo;
    let mut b = hi;
    let mut fa = f(a);
    let fb = f(b);
    if fa.abs() < f_eps {
        return Some(a);
    }
    if fb.abs() < f_eps {
        return Some(b);
    }
    // `same_sign` is false for NaN, so without the NaN checks the loop below
    // can converge onto a NaN endpoint as if it were a root.
    if fa.value().is_nan() || fb.value().is_nan() || same_sign(fa, fb) {
        return None;
    }
    for _ in 0..BISECT_MAX_ITER {
        let mid = a.mean(b);
        let fm = f(mid);
        if fm.abs() < f_eps || (b - a).abs() < tol {
            return Some(mid);
        }
        // Only `fa` is needed to decide which half keeps the sign change, so
        // the value at `b` is not tracked after the initial check.
        if same_sign(fa, fm) {
            a = mid;
            fa = fm;
        } else {
            b = mid;
        }
    }
    Some(a.mean(b))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::time::{Interval, ModifiedJulianDate};
    use qtty::{Day, Radian};

    type Days = Quantity<Day>;
    type Mjd = ModifiedJulianDate;
    type Radians = Quantity<Radian>;

    /// Raw numeric value of a `Days` quantity.
    fn day_f64(t: Days) -> f64 {
        t.value()
    }

    /// Raw numeric MJD value of an instant.
    fn mjd_f64(t: Mjd) -> f64 {
        t.quantity().value()
    }

    #[test]
    fn brent_finds_sine_root_near_pi() {
        let root = brent(Days::new(3.0), Days::new(4.0), |t: Days| {
            Radians::new(day_f64(t).sin())
        })
        .expect("should find π");
        assert!((root - Days::new(std::f64::consts::PI)).abs() < Days::new(1e-10));
    }

    #[test]
    fn brent_accepts_reversed_bracket() {
        // The engine never assumes lo < hi, so swapped endpoints must work too.
        let root = brent(Days::new(4.0), Days::new(3.0), |t: Days| {
            Radians::new(day_f64(t).sin())
        })
        .expect("should find π with a reversed bracket");
        assert!((root - Days::new(std::f64::consts::PI)).abs() < Days::new(1e-8));
    }

    #[test]
    fn brent_finds_linear_root() {
        let root = brent(Days::new(0.0), Days::new(10.0), |t: Days| {
            Radians::new(day_f64(t) - 5.0)
        })
        .expect("should find 5");
        assert!((root - Days::new(5.0)).abs() < Days::new(1e-10));
    }

    #[test]
    fn brent_returns_none_for_invalid_bracket() {
        assert!(brent(Days::new(0.0), Days::new(1.0), |_: Days| Radians::new(42.0)).is_none());
    }

    #[test]
    fn brent_returns_endpoint_when_exact() {
        let root = brent(Days::new(0.0), Days::new(5.0), |t: Days| {
            Radians::new(day_f64(t) - 5.0)
        })
        .expect("endpoint");
        assert!((root - Days::new(5.0)).abs() < Days::new(F_EPS));
    }

    #[test]
    fn brent_with_values_saves_evaluations() {
        use std::cell::Cell;
        let count = Cell::new(0usize);
        let f = |t: Mjd| -> Radians {
            count.set(count.get() + 1);
            Radians::new(mjd_f64(t).sin())
        };
        let f_lo = Radians::new((3.0_f64).sin());
        let f_hi = Radians::new((4.0_f64).sin());
        let _ = brent_with_values(Interval::new(Mjd::new(3.0), Mjd::new(4.0)), f_lo, f_hi, f);
        let with_vals = count.get();
        count.set(0);
        let _ = brent(Days::new(3.0), Days::new(4.0), |t: Days| {
            count.set(count.get() + 1);
            Radians::new(day_f64(t).sin())
        });
        let without = count.get();
        // `brent` spends two extra calls on the endpoints. The two runs iterate
        // on different numbers, though: offsets from the interval start versus
        // absolute days. Their iteration counts therefore need not match
        // exactly, so only assert that reusing the endpoint values never costs
        // more evaluations.
        assert!(with_vals <= without);
    }

    #[test]
    fn brent_tol_respects_relaxed_tolerance() {
        let root = brent_tol(
            Interval::new(Mjd::new(3.0), Mjd::new(4.0)),
            Radians::new((3.0_f64).sin()),
            Radians::new((4.0_f64).sin()),
            |t: Mjd| Radians::new(mjd_f64(t).sin()),
            Days::new(1e-3),
        )
        .expect("relaxed");
        assert!((root - Mjd::new(std::f64::consts::PI)).abs() < Days::new(2e-3));
    }

    #[test]
    fn brent_handles_step_function() {
        let root = brent(Days::new(-1.0), Days::new(1.0), |t: Days| {
            Radians::new(if t < 0.0 { -1.0 } else { 1.0 })
        })
        .expect("step");
        assert!(root.abs() < Days::new(1e-6));
    }

    #[test]
    fn brent_cubic() {
        let root = brent(Days::new(1.0), Days::new(2.0), |t: Days| {
            Radians::new(day_f64(t).powi(3) - 2.0)
        })
        .expect("cbrt 2");
        assert!((root - Days::new(2.0_f64.powf(1.0 / 3.0))).abs() < Days::new(1e-9));
    }

    #[test]
    fn bisection_finds_sine_root() {
        let root = bisection(Days::new(3.0), Days::new(4.0), |t: Days| {
            Radians::new(day_f64(t).sin())
        })
        .expect("π");
        assert!((root - Days::new(std::f64::consts::PI)).abs() < Days::new(1e-8));
    }

    #[test]
    fn bisection_returns_none_for_invalid_bracket() {
        assert!(bisection(Days::new(0.0), Days::new(1.0), |_: Days| Radians::new(42.0)).is_none());
    }

    #[test]
    fn bisection_handles_step_function() {
        let root = bisection(Days::new(-1.0), Days::new(1.0), |t: Days| {
            Radians::new(if t < 0.0 { -5.0 } else { 5.0 })
        })
        .expect("step");
        assert!(root.abs() < Days::new(1e-6));
    }

    #[test]
    fn bisection_endpoint_root() {
        let root = bisection(Days::new(0.0), Days::new(5.0), |t: Days| {
            Radians::new(day_f64(t))
        })
        .expect("root at 0");
        assert!(root.abs() < Days::new(F_EPS));
    }
}