use std;
use errors::*;
use ndarray::{Array1, arr1};
use ArgminSolver;
use result::ArgminResult;
use termination::TerminationReason;
/// Backtracking line search driven by the Armijo sufficient-decrease condition.
///
/// Starting from an initial step length, the step is repeatedly multiplied by
/// `tau` until the cost decreases "enough" (as controlled by `c`) or the
/// iteration budget `max_iters` is exhausted.
pub struct BacktrackingLineSearch<'a> {
    /// Objective function being minimised.
    cost_function: &'a Fn(&Array1<f64>) -> f64,
    /// Gradient of `cost_function`; evaluated once, at the starting point.
    gradient: &'a Fn(&Array1<f64>) -> Array1<f64>,
    /// Control parameter of the sufficient-decrease test (0 < c < 1).
    c: f64,
    /// Factor by which a rejected step length is shrunk (0 < tau < 1).
    tau: f64,
    /// Step length tried first.
    alpha: f64,
    /// Upper bound on the number of trial steps.
    max_iters: u64,
    /// Per-run state; `None` until `init` has been called.
    state: Option<BacktrackingLineSearchState>,
}
/// Mutable bookkeeping for a single backtracking line search run.
pub struct BacktrackingLineSearchState {
    /// Point the line search starts from.
    x: Array1<f64>,
    /// Search direction; trial points are `x + alpha * p`.
    p: Array1<f64>,
    /// Cost at the starting point `x`.
    fx: f64,
    /// Sufficient-decrease slope `-c * (p · ∇f(x))`, computed in `init`.
    t: f64,
    /// Current step length.
    alpha: f64,
    /// Cost at the most recent trial point (NaN before the first trial).
    cost: f64,
    /// Number of trial steps evaluated so far.
    iter: u64,
}
impl<'a> BacktrackingLineSearch<'a> {
    /// Creates a line search for the given cost function and its gradient.
    ///
    /// Defaults: initial step length `alpha = 1.0`, `max_iters = 100`,
    /// shrink factor `tau = 0.5` and sufficient-decrease parameter `c = 0.5`.
    pub fn new(
        cost_function: &'a Fn(&Array1<f64>) -> f64,
        gradient: &'a Fn(&Array1<f64>) -> Array1<f64>,
    ) -> Self {
        BacktrackingLineSearch {
            cost_function: cost_function,
            gradient: gradient,
            alpha: 1.0,
            max_iters: 100,
            tau: 0.5,
            c: 0.5,
            state: None,
        }
    }

    /// Sets the initial step length tried by the search.
    ///
    /// NOTE(review): the value is not validated; a non-positive or NaN step
    /// length is accepted as-is.
    pub fn alpha(&mut self, alpha: f64) -> &mut Self {
        self.alpha = alpha;
        self
    }

    /// Sets the maximum number of trial steps.
    pub fn max_iters(&mut self, max_iters: u64) -> &mut Self {
        self.max_iters = max_iters;
        self
    }

    /// Sets the sufficient-decrease control parameter `c`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidParameter` unless `0 < c < 1`. NaN is
    /// rejected too.
    pub fn c(&mut self, c: f64) -> Result<&mut Self> {
        // NaN compares false against every bound, so it must be rejected explicitly.
        if c.is_nan() || c >= 1.0 || c <= 0.0 {
            return Err(ErrorKind::InvalidParameter(
                "BacktrackingLineSearch: Parameter `c` must satisfy 0 < c < 1.".into(),
            ).into());
        }
        self.c = c;
        Ok(self)
    }

    /// Sets the factor `tau` by which a rejected step length is multiplied.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidParameter` unless `0 < tau < 1`. NaN is
    /// rejected too.
    pub fn tau(&mut self, tau: f64) -> Result<&mut Self> {
        // NaN compares false against every bound, so it must be rejected explicitly.
        if tau.is_nan() || tau >= 1.0 || tau <= 0.0 {
            return Err(ErrorKind::InvalidParameter(
                "BacktrackingLineSearch: Parameter `tau` must satisfy 0 < tau < 1.".into(),
            ).into());
        }
        self.tau = tau;
        Ok(self)
    }
}
impl<'a> ArgminSolver<'a> for BacktrackingLineSearch<'a> {
    type Parameter = Array1<f64>;
    type CostValue = f64;
    type Hessian = Array1<f64>;
    type StartingPoints = Array1<f64>;
    type ProblemDefinition = Array1<f64>;

    /// Prepares a line search from `x` along the search direction `p`.
    ///
    /// Evaluates the cost and the gradient at `x` once. It then stores the
    /// sufficient-decrease slope `t = -c * (p · ∇f(x))`. For a descent
    /// direction, `p · ∇f(x)` is negative, so `t` is positive.
    ///
    /// NOTE(review): `p` is not checked to be a descent direction. The lengths
    /// of `p` and `x` are not checked to match either.
    fn init(&mut self, p: Self::ProblemDefinition, x: &Self::StartingPoints) -> Result<()> {
        let m: f64 = p.t().dot(&((self.gradient)(x)));
        self.state = Some(BacktrackingLineSearchState {
            cost: std::f64::NAN,
            p: p,
            x: x.to_owned(),
            t: -self.c * m,
            fx: (self.cost_function)(x),
            iter: 0,
            alpha: self.alpha,
        });
        Ok(())
    }

    /// Evaluates one trial step `x + alpha * p`.
    ///
    /// The first trial uses the initial step length. Every later trial first
    /// shrinks the previously rejected step by `tau`. Shrinking happens
    /// *before* the evaluation, so the state always holds a matching
    /// `(alpha, cost)` pair. The termination check therefore tests the step
    /// that was actually evaluated, and the reported step length is that same
    /// step.
    ///
    /// The result's parameter is the one-element array `[alpha]`. Its cost is
    /// reported as NaN, as before.
    ///
    /// # Panics
    ///
    /// Panics if called before `init`.
    fn next_iter(&mut self) -> Result<ArgminResult<Self::Parameter, Self::CostValue>> {
        let mut state = self
            .state
            .take()
            .expect("BacktrackingLineSearch: `init` must be called before `next_iter`");
        // Backtrack from the previously rejected step (none on the first trial).
        if state.iter > 0 {
            state.alpha *= self.tau;
        }
        let param = &state.x + &(state.alpha * &state.p);
        state.cost = (self.cost_function)(&param);
        state.iter += 1;
        let mut out = ArgminResult::new(arr1(&[state.alpha]), std::f64::NAN, state.iter);
        self.state = Some(state);
        out.set_termination_reason(self.terminate());
        Ok(out)
    }

    // Stop after `max_iters` trials, or as soon as the Armijo sufficient-decrease
    // condition f(x) - f(x + alpha * p) >= alpha * t holds for the step that was
    // just evaluated.
    make_terminate!(self,
        self.state.as_ref().unwrap().iter >= self.max_iters, TerminationReason::MaxItersReached;
        self.state.as_ref().unwrap().fx - self.state.as_ref().unwrap().cost
            >= self.state.as_ref().unwrap().alpha * self.state.as_ref().unwrap().t,
            TerminationReason::TargetCostReached;
    );

    make_run!(
        Self::ProblemDefinition,
        Self::StartingPoints,
        Self::Parameter,
        Self::CostValue
    );
}