use crate::core::{
ArgminFloat, CostFunction, Error, IterState, Problem, Solver, State, TerminationReason, KV,
};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Brent's method for minimizing a scalar function of one variable without derivatives.
///
/// The solver keeps a bracket `[a, b]` and three points `x`, `w`, `v` (with their
/// costs) from which it takes either parabolic-interpolation or golden-section steps.
/// Field names follow the notation of Brent's original algorithm.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct BrentOpt<F> {
    /// Relative tolerance; each iteration uses `tol = eps * |x| + t`.
    eps: F,
    /// Absolute tolerance; each iteration uses `tol = eps * |x| + t`.
    t: F,
    /// Lower end of the current bracket.
    a: F,
    /// Upper end of the current bracket.
    b: F,
    /// Most recent trial point at which the cost was evaluated.
    u: F,
    /// Third point of the interpolation triple; usually the previous value of `w`.
    v: F,
    /// Point with the next-lowest cost seen so far (per the update rules of the solver).
    w: F,
    /// Point with the lowest cost seen so far; reported as the solver's parameter.
    x: F,
    /// Cost at `v`.
    fv: F,
    /// Cost at `w`.
    fw: F,
    /// Cost at `x`.
    fx: F,
    /// Step bookkeeping from an earlier iteration, used to reject parabolic steps that
    /// do not shrink fast enough.
    e: F,
    /// Most recent step taken from `x`.
    d: F,
    /// Golden-section step fraction `(3 - sqrt(5)) / 2`.
    c: F,
}
impl<F: ArgminFloat> BrentOpt<F> {
    /// Constructs a new Brent minimizer for the interval spanned by `min` and `max`.
    ///
    /// The bounds may be passed in either order; they are stored sorted so that the
    /// bracket satisfies `a <= b`. Defaults: relative tolerance `eps = sqrt(F::epsilon())`
    /// and absolute tolerance `t = 1e-5` (change them with [`BrentOpt::set_tolerance`]).
    pub fn new(min: F, max: F) -> Self {
        // With `a > b` the convergence test `|x - m| <= 2 * tol - (b - a) / 2` holds
        // trivially on the first iteration (the right-hand side grows with `|b - a|`),
        // which would report an unoptimized point as converged. Sort the bounds here.
        // `max < min` is false for NaN, so NaN inputs keep their original order.
        let (a, b) = if max < min { (max, min) } else { (min, max) };
        BrentOpt {
            eps: F::epsilon().sqrt(),
            t: float!(1e-5),
            a,
            b,
            // Points and costs are filled in by `init`.
            u: F::nan(),
            v: F::nan(),
            w: F::nan(),
            x: F::nan(),
            fv: F::nan(),
            fw: F::nan(),
            fx: F::nan(),
            e: F::zero(),
            d: F::zero(),
            // (3 - sqrt(5)) / 2 ~= 0.382: the golden-section step fraction.
            c: float!((3f64 - 5f64.sqrt()) / 2f64),
        }
    }

    /// Sets the relative tolerance `eps` and the absolute tolerance `t`.
    ///
    /// Each iteration uses `tol = eps * |x| + t` around the current best point `x`,
    /// and stops once both bracket ends lie within `2 * tol` of `x`. Keep `t` positive:
    /// with a zero `tol` the convergence test can only pass once the bracket has
    /// collapsed to a single point.
    #[must_use]
    pub fn set_tolerance(mut self, eps: F, t: F) -> Self {
        self.eps = eps;
        self.t = t;
        self
    }
}
/// Brent's method for minimizing a scalar function of one variable without
/// derivatives, operating on the bracket `[a, b]` stored in the solver.
///
/// Every iteration either takes a parabolic-interpolation step through the points
/// `x`, `w` and `v`, or falls back to a golden-section step. Each iteration evaluates
/// the cost function at most once; the iteration that detects convergence evaluates
/// it zero times. The reported parameter and cost are always `x` and `fx`.
impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for BrentOpt<F>
where
    O: CostFunction<Param = F, Output = F>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "BrentOpt"
    }

    /// Evaluates the cost once at `a + c * (b - a)` and uses that point as the
    /// initial `x`, `w` and `v`, so `fx`, `fw` and `fv` start out equal.
    ///
    /// NOTE(review): `a`, `b`, `e` and `d` are not reset here, so initializing an
    /// instance that has already been run continues from its previous (shrunken)
    /// bracket. Confirm this is intended.
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<F, (), (), (), (), F>,
    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
        let u = self.a + self.c * (self.b - self.a);
        self.v = u;
        self.w = u;
        self.x = u;
        let f = problem.cost(&u)?;
        self.fv = f;
        self.fw = f;
        self.fx = f;
        Ok((state.param(self.x).cost(self.fx), None))
    }

    /// Performs one Brent iteration. It tests for convergence, chooses a trial point
    /// `u`, evaluates the cost there, and updates the bracket and the points `x`, `w`, `v`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the cost function.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<F, (), (), (), (), F>,
    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
        let two = float!(2f64);
        // Combined relative/absolute tolerance around the current best point.
        let tol = self.eps * self.x.abs() + self.t;
        // Midpoint of the current bracket.
        let m = (self.a + self.b) / two;
        // For `a <= x <= b`, `|x - m| + (b - a) / 2 == max(x - a, b - x)`. So this stops
        // once both bracket ends lie within `2 * tol` of `x`.
        if (self.x - m).abs() <= two * tol - (self.b - self.a) / two {
            return Ok((
                state
                    .terminate_with(TerminationReason::SolverConverged)
                    .param(self.x)
                    .cost(self.fx),
                None,
            ));
        }
        // Parabola through (x, fx), (w, fw) and (v, fv): its vertex is at `x + p / q`.
        let p = (self.x - self.v) * (self.x - self.v) * (self.fx - self.fw)
            - (self.x - self.w) * (self.x - self.w) * (self.fx - self.fv);
        let q = two
            * ((self.x - self.w) * (self.fx - self.fv) - (self.x - self.v) * (self.fx - self.fw));
        // Flip both signs so that `q >= 0`; the step `p / q` is unchanged.
        let (p, q) = if q >= F::zero() { (p, q) } else { (-p, -q) };
        // Fall back to a golden-section step when any of these holds:
        // - `|e| <= tol`;
        // - the parabolic point `x + p / q` would fall outside `[a, b]`;
        // - `|p / q| >= |e| / 2`, i.e. parabolic steps are not shrinking fast enough.
        // `e` is read here before either branch overwrites it. The whole right-hand
        // side is evaluated before `self.d` is assigned, so `self.d` inside the
        // branches is still the previous step.
        self.d = if self.e.abs() <= tol
            || p < q * (self.a - self.x)
            || p > q * (self.b - self.x)
            || two * p.abs() >= q * self.e.abs()
        {
            // Golden section: `e` becomes the signed distance from `x` to the farther
            // bracket end, and the step is the fraction `c` of it.
            self.e = if self.x < m { self.b } else { self.a } - self.x;
            self.c * self.e
        } else {
            // Parabolic step: remember the previous step in `e`, then take `p / q`.
            // If the trial point would land within `2 * tol` of `a` or `b`, step `tol`
            // toward the midpoint instead.
            self.e = self.d;
            let d = p / q;
            if self.x + d - self.a < two * tol || self.b - self.x - d < two * tol {
                (m - self.x).signum() * tol
            } else {
                d
            }
        };
        // Never evaluate closer than `tol` to `x`: shorter steps are replaced by a step
        // of length `tol` with the sign of `d`.
        self.u = self.x
            + if self.d.abs() >= tol {
                self.d
            } else {
                self.d.signum() * tol
            };
        let fu = problem.cost(&self.u)?;
        if fu <= self.fx {
            // `u` is at least as good as `x`: keep the part of the bracket on u's side
            // of `x`, then shift the history v <- w <- x <- u.
            if self.u < self.x {
                self.b = self.x;
            } else {
                self.a = self.x;
            }
            self.v = self.w;
            self.fv = self.fw;
            self.w = self.x;
            self.fw = self.fx;
            self.x = self.u;
            self.fx = fu;
        } else {
            // `u` is worse than `x`: it becomes the new bracket end on its side of `x`.
            if self.u < self.x {
                self.a = self.u;
            } else {
                self.b = self.u;
            }
            // Keep `u` as `w` if it beats the current `w` (or `w` still coincides with `x`).
            // Otherwise keep it as `v` if it beats `v` (or `v` duplicates `x` or `w`).
            if fu <= self.fw || self.w == self.x {
                self.v = self.w;
                self.fv = self.fw;
                self.w = self.u;
                self.fw = fu;
            } else if fu <= self.fv || self.v == self.x || self.v == self.w {
                self.v = self.u;
                self.fv = fu;
            }
        }
        Ok((state.param(self.x).cost(self.fx), None))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Executor, TerminationStatus};
    use approx::assert_relative_eq;

    test_trait_impl!(brent, BrentOpt<f64>);

    /// Test objective `f(x) = exp(-x) - exp(5 - x / 2)`.
    ///
    /// Solving `f'(x) = 0` gives the single stationary point `x* = 2 ln 2 - 10 ~= -8.6137`.
    /// It is a minimum with `f(x*) = -e^10 / 4 ~= -5506.6`.
    struct TestFunc {}

    impl CostFunction for TestFunc {
        type Param = f64;
        type Output = f64;

        fn cost(&self, x: &Self::Param) -> Result<Self::Output, Error> {
            let decaying = (-x).exp();
            let growing = (5. - x / 2.).exp();
            Ok(decaying - growing)
        }
    }

    /// Runs Brent's method on `[-10, 10]` and pins the final, previous and best
    /// iterates, their costs, and the iteration and cost-evaluation counts.
    #[test]
    fn test_brent() {
        let solver = BrentOpt::new(-10., 10.);
        let res = Executor::new(TestFunc {}, solver)
            .configure(|state| state.counting(true).max_iters(13))
            .run()
            .unwrap();
        let st = res.state();

        let tol = f64::EPSILON.sqrt();
        let close = |actual: f64, expected: f64| {
            assert_relative_eq!(actual, expected, epsilon = tol);
        };

        assert_eq!(
            st.termination_status,
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        );

        let x_min = -8.613701289624956;
        let f_min = -5506.616448675639;

        close(st.param.unwrap(), x_min);
        close(st.prev_param.unwrap(), x_min);
        close(st.best_param.unwrap(), x_min);
        close(st.prev_best_param.unwrap(), -8.613570813317839);

        close(st.cost, f_min);
        close(st.best_cost, f_min);
        close(st.prev_cost, f_min);
        close(st.prev_best_cost, -5506.616423678641);

        assert_eq!(st.iter, 13);
        assert_eq!(st.get_func_counts()["cost_count"], 13);
    }
}