use crate::core::constraint::BoxConstraints;
use crate::core::math::Scalar;
use crate::core::problem::{CostFunction, Gradient, Problem};
use crate::core::solver::Solver;
use crate::core::state::ScalarGradientState;
use crate::core::termination::TerminationReason;
/// One-dimensional minimiser over a bounded interval that uses both cost and
/// gradient evaluations (a Brent-style bracketing search whose trial steps are
/// secant steps on the derivative, with bisection as the fallback).
///
/// `F` is the scalar type and defaults to `f64`.
pub struct BrentDerivative<F = f64> {
/// Relative tolerance; multiplied by `|x|` when the working tolerance is formed.
tol_rel: F,
/// Absolute tolerance; added to the relative part of the working tolerance.
tol_abs: F,
/// Iteration state; `None` until the solver's `init` has run.
inner: Option<Inner<F>>,
}
/// Working state of the search between iterations.
///
/// Three evaluated points are tracked, each with its cost (`f*`) and gradient
/// (`d*`). When a probe is accepted as the new `x`, the old `x` becomes `w` and
/// the old `w` becomes `v`.
#[derive(Clone, Copy)]
struct Inner<F> {
/// Lower end of the current bracket.
a: F,
/// Upper end of the current bracket.
b: F,
/// Current reference point: the probe most recently accepted because its cost
/// was not larger than the previous reference point's.
x: F,
/// Cost at `x`.
fx: F,
/// Gradient at `x`.
dx: F,
/// Second tracked point (the previous `x`, or a rejected probe that ranked next).
w: F,
/// Cost at `w`.
fw: F,
/// Gradient at `w`.
dw: F,
/// Third tracked point (the previous `w`, or a rejected probe).
v: F,
/// Cost at `v`.
fv: F,
/// Gradient at `v`.
dv: F,
/// Step chosen in the most recent iteration (before the minimum-size clamp).
d: F,
/// Step-history value compared against the tolerance to decide whether a secant
/// step is attempted, and used to bound the size of an accepted secant step.
e: F,
}
/// Returns the golden-section fraction `(3 - sqrt(5)) / 2` (about `0.381966`)
/// converted to the scalar type `F`.
///
/// # Panics
/// Panics if `F::from_f64` fails to convert the constant.
fn golden_c<F: Scalar>() -> F {
    const GOLDEN_FRACTION: f64 = 0.381_966_011_250_105_2;
    let converted = F::from_f64(GOLDEN_FRACTION);
    converted.unwrap()
}
impl<F: Scalar> Default for BrentDerivative<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Scalar> BrentDerivative<F> {
pub fn new() -> Self {
Self {
tol_rel: F::epsilon().sqrt(),
tol_abs: F::from_f64(1e-12).unwrap(),
inner: None,
}
}
pub fn with_tol(tol_rel: F, tol_abs: F) -> Self {
assert!(tol_rel > F::zero(), "tol_rel must be > 0");
assert!(tol_abs > F::zero(), "tol_abs must be > 0");
Self {
tol_rel,
tol_abs,
inner: None,
}
}
}
/// [`Solver`] implementation that minimises a scalar cost over the interval given
/// by the problem's [`BoxConstraints`], using cost and gradient evaluations.
impl<P, F> Solver<P, ScalarGradientState<F>> for BrentDerivative<F>
where
F: Scalar,
P: CostFunction<Param = F, Output = F> + Gradient<Gradient = F> + BoxConstraints,
{
/// Errors are those of the wrapped problem, passed through unchanged.
type Error = P::Error;
/// Sets up the bracket from the problem's bounds and evaluates the start point.
///
/// The supplied parameter is clamped into `[lower, upper]`; if the result sits
/// exactly on either bound it is replaced by `lower + golden_c() * (upper - lower)`.
/// The returned state carries that point with its cost and gradient.
///
/// # Panics
/// Panics if either bound is not finite or if `lower >= upper`.
///
/// # Errors
/// Propagates any error from `problem.cost_and_gradient`.
fn init(
&mut self,
problem: &mut Problem<P>,
mut state: ScalarGradientState<F>,
) -> Result<ScalarGradientState<F>, Self::Error> {
// The initial bracket is the problem's box.
let a = *problem.inner().lower();
let b = *problem.inner().upper();
assert!(
a.is_finite() && b.is_finite() && a < b,
"BrentDerivative requires a finite, ordered bracket: lower < upper"
);
// Clamp the caller's start into the bracket.
let mut x = state.param.max(a).min(b);
// A start lying exactly on a bound is moved to an interior point.
if x == a || x == b {
x = a + golden_c::<F>() * (b - a);
}
let (fx, dx) = problem.cost_and_gradient(&x)?;
// All three tracked points start out identical and there is no step history.
self.inner = Some(Inner {
a,
b,
x,
fx,
dx,
w: x,
fw: fx,
dw: dx,
v: x,
fv: fx,
dv: dx,
d: F::zero(),
e: F::zero(),
});
state.param = x;
state.cost = Some(fx);
state.gradient = Some(dx);
Ok(state)
}
/// Performs one iteration: picks a step (secant on the gradient if acceptable,
/// otherwise bisection towards the downhill bracket end), evaluates the probe,
/// and updates the bracket and the three tracked points.
///
/// The returned state holds the probe that was just evaluated, which is not
/// necessarily the best point found so far. The termination slot of the
/// returned tuple is always `None`.
///
/// # Panics
/// Panics if `init` has not been called first.
///
/// # Errors
/// Propagates any error from `problem.cost_and_gradient`.
fn next_iter(
&mut self,
problem: &mut Problem<P>,
mut state: ScalarGradientState<F>,
) -> Result<(ScalarGradientState<F>, Option<TerminationReason>), Self::Error> {
let s = self
.inner
.as_mut()
.expect("BrentDerivative::init must run first");
let half = F::from_f64(0.5).unwrap();
let two = F::from_f64(2.0).unwrap();
// Bracket midpoint and the working tolerances at the current point.
let m = half * (s.a + s.b);
let tol1 = self.tol_rel * s.x.abs() + self.tol_abs;
let tol2 = two * tol1;
// Bisection fallback: returns `(e, d)` where `e` is the signed distance from
// `x` to the bracket end on the downhill side according to the gradient sign
// (towards `a` when `dx >= 0`, towards `b` otherwise) and `d` is half of it.
let bisect = |s: &Inner<F>| -> (F, F) {
let e = if s.dx >= F::zero() {
s.a - s.x
} else {
s.b - s.x
};
(e, half * e)
};
// A secant step is only attempted when the step history is above tolerance.
if s.e.abs() > tol1 {
// Placeholder step used when a secant cannot be formed (equal gradients).
// NOTE(review): presumably chosen so that `x + big` fails the in-bracket
// test below — confirm.
let big = two * (s.b - s.a);
let mut d1 = big;
let mut d2 = big;
// Secant estimate of the gradient's zero through (x, w) ...
if s.dw != s.dx {
d1 = (s.w - s.x) * s.dx / (s.dx - s.dw);
}
// ... and through (x, v).
if s.dv != s.dx {
d2 = (s.v - s.x) * s.dx / (s.dx - s.dv);
}
let u1 = s.x + d1;
let u2 = s.x + d2;
// A candidate is usable if it lies strictly inside (a, b) and the step does
// not point in the direction of increasing cost (`dx * d <= 0`).
let ok1 = (s.a - u1) * (u1 - s.b) > F::zero() && s.dx * d1 <= F::zero();
let ok2 = (s.a - u2) * (u2 - s.b) > F::zero() && s.dx * d2 <= F::zero();
// Remember the old history value, then shift the previous step into it.
let olde = s.e;
s.e = s.d;
// With two usable candidates take the shorter step.
let chosen = if ok1 && ok2 {
if d1.abs() < d2.abs() {
Some(d1)
} else {
Some(d2)
}
} else if ok1 {
Some(d1)
} else if ok2 {
Some(d2)
} else {
None
};
match chosen {
// Accept the secant step only if it is at most half the old history value.
Some(d) if d.abs() <= (half * olde).abs() => {
s.d = d;
let u = s.x + d;
// Too close to a bracket end: take a `tol1`-sized step towards the midpoint.
if u - s.a < tol2 || s.b - u < tol2 {
s.d = if m - s.x >= F::zero() { tol1 } else { -tol1 };
}
}
// No usable candidate, or the step is too long: bisect.
_ => {
let (e, d) = bisect(s);
s.e = e;
s.d = d;
}
}
} else {
let (e, d) = bisect(s);
s.e = e;
s.d = d;
}
// Never step by less than `tol1`; keep the sign of the chosen step.
let step = if s.d.abs() >= tol1 {
s.d
} else if s.d >= F::zero() {
tol1
} else {
-tol1
};
let u = s.x + step;
let (fu, du) = problem.cost_and_gradient(&u)?;
if fu <= s.fx {
// The probe is at least as good: the old `x` becomes the bracket end on the
// side the probe moved away from.
if u >= s.x {
s.a = s.x;
} else {
s.b = s.x;
}
// Rotate the tracked points: v <- w, w <- x, x <- u.
s.v = s.w;
s.fv = s.fw;
s.dv = s.dw;
s.w = s.x;
s.fw = s.fx;
s.dw = s.dx;
s.x = u;
s.fx = fu;
s.dx = du;
} else {
// The probe is worse: it becomes the bracket end on its own side of `x`.
if u < s.x {
s.a = u;
} else {
s.b = u;
}
// Keep the probe as `w` or `v` if it ranks among the tracked points, or if
// those slots still duplicate another tracked point.
if fu <= s.fw || s.w == s.x {
s.v = s.w;
s.fv = s.fw;
s.dv = s.dw;
s.w = u;
s.fw = fu;
s.dw = du;
} else if fu <= s.fv || s.v == s.x || s.v == s.w {
s.v = u;
s.fv = fu;
s.dv = du;
}
}
// Report the probe that was just evaluated (it may be worse than `s.x`).
state.param = u;
state.cost = Some(fu);
state.gradient = Some(du);
// NOTE(review): no termination is signalled from here, so convergence relies
// entirely on `terminate`. The classic derivative-based Brent routine also
// stops when a `tol1`-sized step goes uphill — confirm that omitting this exit
// cannot leave the solver re-probing the same point until the iteration cap.
Ok((state, None))
}
/// Reports convergence once `|x - midpoint| + (b - a) / 2 <= 2 * tol`, where
/// `tol = tol_rel * |x| + tol_abs`; returns `None` otherwise, and also when
/// `init` has not run yet. The supplied state is not consulted.
fn terminate(&self, _state: &ScalarGradientState<F>) -> Option<TerminationReason> {
let s = self.inner.as_ref()?;
let half = F::from_f64(0.5).unwrap();
let two = F::from_f64(2.0).unwrap();
let m = half * (s.a + s.b);
let tol = self.tol_rel * s.x.abs() + self.tol_abs;
if (s.x - m).abs() + half * (s.b - s.a) <= two * tol {
Some(TerminationReason::SolverConverged)
} else {
None
}
}
}
// Unit tests driving `BrentDerivative` through the crate's `Executor` on two
// small analytic problems.
#[cfg(test)]
mod tests {
use super::*;
use crate::core::executor::Executor;
use crate::core::state::State;
use crate::core::termination::{GradientTolerance, TerminationReason};
// Test problem f(x) = (x - 2)^2 on the interval [lo, hi].
struct Quadratic {
lo: f64,
hi: f64,
}
impl CostFunction for Quadratic {
type Param = f64;
type Output = f64;
type Error = std::convert::Infallible;
fn cost(&self, x: &f64) -> Result<f64, Self::Error> {
Ok((x - 2.0).powi(2))
}
}
impl Gradient for Quadratic {
type Gradient = f64;
// f'(x) = 2 (x - 2)
fn gradient(&self, x: &f64) -> Result<f64, Self::Error> {
Ok(2.0 * (x - 2.0))
}
}
impl BoxConstraints for Quadratic {
fn lower(&self) -> &f64 {
&self.lo
}
fn upper(&self) -> &f64 {
&self.hi
}
}
// The minimiser x = 2 lies inside [0, 5]: the solver must report convergence,
// land within 1e-7 of it, and stay inside the bounds.
#[test]
fn quadratic_finds_interior_min() {
let r = Executor::new(
Quadratic { lo: 0.0, hi: 5.0 },
BrentDerivative::new(),
ScalarGradientState::new(2.5),
)
.max_iter(100)
.run()
.unwrap();
assert_eq!(r.reason, TerminationReason::SolverConverged);
assert!((r.param() - 2.0).abs() < 1e-7, "x = {}", r.param());
assert!(*r.param() >= 0.0 && *r.param() <= 5.0);
}
// On [3, 5] the quadratic is increasing, so the constrained minimum is the
// lower bound x = 3; the final parameter must be within 1e-5 of it.
#[test]
fn monotonic_function_converges_to_boundary() {
let r = Executor::new(
Quadratic { lo: 3.0, hi: 5.0 },
BrentDerivative::new(),
ScalarGradientState::new(4.0),
)
.max_iter(200)
.run()
.unwrap();
assert!((r.param() - 3.0).abs() < 1e-5, "x = {}", r.param());
}
// Test problem f(x) = x^3 - 3x on the interval [lo, hi].
struct Cubic {
lo: f64,
hi: f64,
}
impl CostFunction for Cubic {
type Param = f64;
type Output = f64;
type Error = std::convert::Infallible;
fn cost(&self, x: &f64) -> Result<f64, Self::Error> {
Ok(x.powi(3) - 3.0 * x)
}
}
impl Gradient for Cubic {
type Gradient = f64;
// f'(x) = 3 x^2 - 3
fn gradient(&self, x: &f64) -> Result<f64, Self::Error> {
Ok(3.0 * x * x - 3.0)
}
}
impl BoxConstraints for Cubic {
fn lower(&self) -> &f64 {
&self.lo
}
fn upper(&self) -> &f64 {
&self.hi
}
}
// On [0, 2] the cubic has its minimum at x = 1 with f = -2. Checks the reported
// best point and cost, the termination reason, and that fewer than 25 cost
// evaluations were used.
#[test]
fn cubic_unimodal_on_interval() {
let r = Executor::new(
Cubic { lo: 0.0, hi: 2.0 },
BrentDerivative::new(),
ScalarGradientState::new(0.5),
)
.max_iter(100)
.run()
.unwrap();
assert_eq!(r.reason, TerminationReason::SolverConverged);
assert!(
(r.best_param() - 1.0).abs() < 1e-6,
"x = {}",
r.best_param()
);
assert!((r.best_cost() + 2.0).abs() < 1e-10, "f = {}", r.best_cost());
assert!(
r.state.cost_evals() < 25,
"evals = {}",
r.state.cost_evals()
);
}
// With an extra `GradientTolerance(1e-4)` criterion the run must end with that
// reason, and the best point must be within 1e-3 of x = 1.
#[test]
fn gradient_tolerance_stops() {
let r = Executor::new(
Cubic { lo: 0.0, hi: 2.0 },
BrentDerivative::new(),
ScalarGradientState::new(0.5),
)
.max_iter(200)
.terminate_on(GradientTolerance(1e-4))
.run()
.unwrap();
assert_eq!(r.reason, TerminationReason::GradientTolerance);
assert!(
(r.best_param() - 1.0).abs() < 1e-3,
"best_x = {}",
r.best_param()
);
}
// With a `CostTolerance::new(1e-12)` criterion attached, the best point and best
// cost must still reach the minimum, and the best-point bookkeeping must show
// progress past the initial state.
// NOTE(review): judging by the test name, this guards against the cost
// tolerance firing on a probe that did not improve the best cost — confirm
// against the `CostTolerance` implementation.
#[test]
fn cost_tolerance_does_not_fire_on_non_improving_probe() {
use crate::core::termination::CostTolerance;
let r = Executor::new(
Cubic { lo: 0.0, hi: 2.0 },
BrentDerivative::new(),
ScalarGradientState::new(0.5),
)
.max_iter(200)
.terminate_on(CostTolerance::new(1e-12))
.run()
.unwrap();
assert!(
(r.best_param() - 1.0).abs() < 1e-5,
"best_x = {}, reason = {:?}",
r.best_param(),
r.reason
);
assert!(
(r.best_cost() + 2.0).abs() < 1e-9,
"best_cost = {}, reason = {:?}",
r.best_cost(),
r.reason
);
assert!(r.best_iter() > 0, "best_iter = {}", r.best_iter());
assert!(
r.best_cost_evals() > 0,
"best_cost_evals = {}",
r.best_cost_evals()
);
}
}