use crate::prelude::*;
use std;
/// An `(x, f(x), g(x))` triple, where `g` is the directional derivative of the
/// cost along the search direction (see `calc` / `calc_grad` below).
type Triplet = (f64, f64, f64);
/// Hager-Zhang line search.
///
/// Brackets a step length along `search_direction` satisfying the (standard or
/// approximate) Wolfe conditions. The `#[stop(...)]` attributes are consumed by
/// the `ArgminSolver` derive and encode the termination tests as stringified
/// expressions over the fields below — field names are therefore part of the
/// macro contract and must not be renamed:
///  1. sufficient-decrease (Armijo-type) test,
///  2. curvature test,
///  3. approximate Wolfe test (used when the cost is flat near the minimizer).
#[derive(ArgminSolver)]
#[stop("self.best_f - self.finit < self.delta * self.best_x * self.dginit" => LineSearchConditionMet)]
#[stop("self.best_g > self.sigma * self.dginit" => LineSearchConditionMet)]
#[stop("(2.0*self.delta - 1.0)*self.dginit >= self.best_g && self.best_g >= self.sigma * self.dginit && self.best_f <= self.finit + self.epsilon_k" => LineSearchConditionMet)]
pub struct HagerZhangLineSearch<O>
where
    O: ArgminOp<Output = f64>,
    <O as ArgminOp>::Param: ArgminSub<<O as ArgminOp>::Param, <O as ArgminOp>::Param>
        + ArgminDot<<O as ArgminOp>::Param, f64>
        + ArgminScaledAdd<<O as ArgminOp>::Param, f64, <O as ArgminOp>::Param>,
{
    // Sufficient-decrease parameter, 0 < delta < 1 (enforced by `set_delta`).
    delta: f64,
    // Curvature parameter, delta <= sigma < 1 (enforced by `set_sigma`).
    sigma: f64,
    // Relative cost tolerance; epsilon_k = epsilon * |finit| is computed in `init`.
    epsilon: f64,
    epsilon_k: f64,
    // Bisection weight used in the inner loop of `update`, 0 < theta < 1.
    theta: f64,
    // Required bracket-shrink factor per iteration (checked in `next_iter`).
    gamma: f64,
    // Settable via `set_eta` but not read anywhere in this file — presumably
    // consumed elsewhere (e.g. by a CG beta update); TODO confirm.
    eta: f64,
    // Lower bracket endpoint: initial value, current x, cost, directional derivative.
    a_x_init: f64,
    a_x: f64,
    a_f: f64,
    a_g: f64,
    // Upper bracket endpoint, same layout as `a_*`.
    b_x_init: f64,
    b_x: f64,
    b_f: f64,
    b_g: f64,
    // Trial point inside the bracket, same layout as `a_*`.
    c_x_init: f64,
    c_x: f64,
    c_f: f64,
    c_g: f64,
    // Best (lowest-cost) of the three current points; referenced by the
    // `#[stop]` expressions above.
    best_x: f64,
    best_f: f64,
    best_g: f64,
    // `*_b` fields stage caller-supplied inputs before `init` unwraps them.
    init_param_b: Option<<O as ArgminOp>::Param>,
    finit_b: Option<f64>,
    init_grad_b: Option<<O as ArgminOp>::Param>,
    search_direction_b: Option<<O as ArgminOp>::Param>,
    // Unwrapped copies used during iteration (populated by `init`).
    init_param: <O as ArgminOp>::Param,
    finit: f64,
    init_grad: <O as ArgminOp>::Param,
    search_direction: <O as ArgminOp>::Param,
    // Initial directional derivative: init_grad . search_direction.
    dginit: f64,
    // Solver bookkeeping (operator, current/best param, counters).
    base: ArgminBase<O>,
}
impl<O> HagerZhangLineSearch<O>
where
    O: ArgminOp<Output = f64>,
    <O as ArgminOp>::Param: ArgminSub<<O as ArgminOp>::Param, <O as ArgminOp>::Param>
        + ArgminDot<<O as ArgminOp>::Param, f64>
        + ArgminScaledAdd<<O as ArgminOp>::Param, f64, <O as ArgminOp>::Param>,
{
    /// Constructs the line search with default parameters
    /// (`delta = 0.1`, `sigma = 0.9`, `epsilon = 1e-6`, `theta = 0.5`,
    /// `gamma = 0.66`, `eta = 0.01`), initial bracket `[EPSILON, 100.0]`
    /// and initial trial step `1.0`. Runtime state is seeded with
    /// NaN/infinity sentinels that `init` overwrites.
    pub fn new(operator: O) -> Self {
        HagerZhangLineSearch {
            delta: 0.1,
            sigma: 0.9,
            epsilon: 1e-6,
            // Computed in `init` as epsilon * |finit|.
            epsilon_k: std::f64::NAN,
            theta: 0.5,
            gamma: 0.66,
            eta: 0.01,
            a_x_init: std::f64::EPSILON,
            a_x: std::f64::NAN,
            a_f: std::f64::NAN,
            a_g: std::f64::NAN,
            b_x_init: 100.0,
            b_x: std::f64::NAN,
            b_f: std::f64::NAN,
            b_g: std::f64::NAN,
            c_x_init: 1.0,
            c_x: std::f64::NAN,
            c_f: std::f64::NAN,
            c_g: std::f64::NAN,
            best_x: 0.0,
            best_f: std::f64::INFINITY,
            best_g: std::f64::NAN,
            init_param_b: None,
            finit_b: None,
            init_grad_b: None,
            search_direction_b: None,
            init_param: <O as ArgminOp>::Param::default(),
            init_grad: <O as ArgminOp>::Param::default(),
            search_direction: <O as ArgminOp>::Param::default(),
            dginit: std::f64::NAN,
            finit: std::f64::INFINITY,
            base: ArgminBase::new(operator, <O as ArgminOp>::Param::default()),
        }
    }

    /// Forwards the current gradient to the solver base.
    pub fn set_cur_grad(&mut self, grad: <O as ArgminOp>::Param) -> &mut Self {
        self.base.set_cur_grad(grad);
        self
    }

    /// Sets the sufficient-decrease parameter `delta`.
    ///
    /// # Errors
    /// Returns `InvalidParameter` unless `0.0 < delta < 1.0`.
    pub fn set_delta(&mut self, delta: f64) -> Result<&mut Self, Error> {
        if delta <= 0.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: delta must be > 0.0.".to_string(),
            }
            .into());
        }
        if delta >= 1.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: delta must be < 1.0.".to_string(),
            }
            .into());
        }
        self.delta = delta;
        Ok(self)
    }

    /// Sets the curvature parameter `sigma`.
    ///
    /// # Errors
    /// Returns `InvalidParameter` unless `delta <= sigma < 1.0`.
    pub fn set_sigma(&mut self, sigma: f64) -> Result<&mut Self, Error> {
        if sigma < self.delta {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: sigma must be >= delta.".to_string(),
            }
            .into());
        }
        if sigma >= 1.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: sigma must be < 1.0.".to_string(),
            }
            .into());
        }
        self.sigma = sigma;
        Ok(self)
    }

    /// Sets the relative tolerance `epsilon` used to compute `epsilon_k`.
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `epsilon < 0.0`.
    pub fn set_epsilon(&mut self, epsilon: f64) -> Result<&mut Self, Error> {
        if epsilon < 0.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: epsilon must be >= 0.0.".to_string(),
            }
            .into());
        }
        self.epsilon = epsilon;
        Ok(())
        .map(|_| self)
    }

    /// Sets the bisection weight `theta` used in `update`'s inner loop.
    ///
    /// # Errors
    /// Returns `InvalidParameter` unless `0.0 < theta < 1.0`.
    pub fn set_theta(&mut self, theta: f64) -> Result<&mut Self, Error> {
        if theta <= 0.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: theta must be > 0.0.".to_string(),
            }
            .into());
        }
        if theta >= 1.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: theta must be < 1.0.".to_string(),
            }
            .into());
        }
        self.theta = theta;
        Ok(self)
    }

    /// Sets the bracket-shrink factor `gamma` (see `next_iter`).
    ///
    /// # Errors
    /// Returns `InvalidParameter` unless `0.0 < gamma < 1.0`.
    pub fn set_gamma(&mut self, gamma: f64) -> Result<&mut Self, Error> {
        if gamma <= 0.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: gamma must be > 0.0.".to_string(),
            }
            .into());
        }
        if gamma >= 1.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: gamma must be < 1.0.".to_string(),
            }
            .into());
        }
        self.gamma = gamma;
        Ok(self)
    }

    /// Sets `eta` (stored but not read in this file; presumably consumed
    /// elsewhere — TODO confirm).
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `eta <= 0.0`.
    pub fn set_eta(&mut self, eta: f64) -> Result<&mut Self, Error> {
        if eta <= 0.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: eta must be > 0.0.".to_string(),
            }
            .into());
        }
        self.eta = eta;
        Ok(self)
    }

    /// Sets the initial bracket `[alpha_min, alpha_max]`.
    ///
    /// # Errors
    /// Returns `InvalidParameter` if `alpha_min < 0.0` or
    /// `alpha_max <= alpha_min`.
    pub fn set_alpha_min_max(
        &mut self,
        alpha_min: f64,
        alpha_max: f64,
    ) -> Result<&mut Self, Error> {
        if alpha_min < 0.0 {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: alpha_min must be >= 0.0.".to_string(),
            }
            .into());
        }
        if alpha_max <= alpha_min {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: alpha_min must be smaller than alpha_max.".to_string(),
            }
            .into());
        }
        self.a_x_init = alpha_min;
        self.b_x_init = alpha_max;
        Ok(self)
    }

    /// Interval update (rules U0–U3 of the Hager-Zhang line search): given the
    /// current bracket `(a, b)` and a trial point `c`, returns a new, tighter
    /// bracket. Each argument is an `(x, f, g)` triplet with `g` the directional
    /// derivative at `x`.
    fn update(
        &mut self,
        (a_x, a_f, a_g): Triplet,
        (b_x, b_f, b_g): Triplet,
        (c_x, c_f, c_g): Triplet,
    ) -> Result<(Triplet, Triplet), Error> {
        // U0: trial point outside the bracket — keep the bracket unchanged.
        if c_x <= a_x || c_x >= b_x {
            return Ok(((a_x, a_f, a_g), (b_x, b_f, b_g)));
        }
        // U1: positive slope at c — c becomes the new upper endpoint.
        if c_g >= 0.0 {
            return Ok(((a_x, a_f, a_g), (c_x, c_f, c_g)));
        }
        // U2: negative slope and acceptable cost — c becomes the new lower endpoint.
        if c_g < 0.0 && c_f <= self.finit + self.epsilon_k {
            return Ok(((c_x, c_f, c_g), (b_x, b_f, b_g)));
        }
        // U3: negative slope but cost too high — theta-bisect inside (a, c)
        // until an acceptable endpoint is found.
        if c_g < 0.0 && c_f > self.finit + self.epsilon_k {
            let mut ah_x = a_x;
            let mut ah_f = a_f;
            let mut ah_g = a_g;
            let mut bh_x = c_x;
            loop {
                let d_x = (1.0 - self.theta) * ah_x + self.theta * bh_x;
                let d_f = self.calc(d_x)?;
                let d_g = self.calc_grad(d_x)?;
                if d_g >= 0.0 {
                    return Ok(((ah_x, ah_f, ah_g), (d_x, d_f, d_g)));
                }
                if d_g < 0.0 && d_f <= self.finit + self.epsilon_k {
                    ah_x = d_x;
                    ah_f = d_f;
                    ah_g = d_g;
                }
                if d_g < 0.0 && d_f > self.finit + self.epsilon_k {
                    bh_x = d_x;
                }
            }
        }
        // Only reachable when c_f or c_g is NaN, so none of the guards above fire.
        Err(ArgminError::InvalidParameter {
            text: "HagerZhangLineSearch: Reached unreachable point in `update` method.".to_string(),
        }
        .into())
    }

    /// Secant step: root of the linear interpolant of the directional
    /// derivative through `(a_x, a_g)` and `(b_x, b_g)`.
    fn secant(&self, a_x: f64, a_g: f64, b_x: f64, b_g: f64) -> f64 {
        (a_x * b_g - b_x * a_g) / (b_g - a_g)
    }

    /// Double secant step (rules S1–S4): performs a secant step, updates the
    /// bracket, and — if the secant point became one of the new endpoints —
    /// performs a second secant step for faster convergence.
    fn secant2(
        &mut self,
        (a_x, a_f, a_g): Triplet,
        (b_x, b_f, b_g): Triplet,
    ) -> Result<(Triplet, Triplet), Error> {
        // S1: initial secant step and bracket update.
        let c_x = self.secant(a_x, a_g, b_x, b_g);
        let c_f = self.calc(c_x)?;
        let c_g = self.calc_grad(c_x)?;
        let mut c_bar_x: f64 = 0.0;
        let ((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)) =
            self.update((a_x, a_f, a_g), (b_x, b_f, b_g), (c_x, c_f, c_g))?;
        // S2: c became the new upper endpoint — secant between old and new b.
        if (c_x - bb_x).abs() < std::f64::EPSILON {
            c_bar_x = self.secant(b_x, b_g, bb_x, bb_g);
        }
        // S3: c became the new lower endpoint — secant between old and new a.
        if (c_x - aa_x).abs() < std::f64::EPSILON {
            c_bar_x = self.secant(a_x, a_g, aa_x, aa_g);
        }
        // S4: if a second secant step was taken, update the bracket once more.
        if (c_x - aa_x).abs() < std::f64::EPSILON || (c_x - bb_x).abs() < std::f64::EPSILON {
            let c_bar_f = self.calc(c_bar_x)?;
            let c_bar_g = self.calc_grad(c_bar_x)?;
            let (a_bar, b_bar) = self.update(
                (aa_x, aa_f, aa_g),
                (bb_x, bb_f, bb_g),
                (c_bar_x, c_bar_f, c_bar_g),
            )?;
            Ok((a_bar, b_bar))
        } else {
            Ok(((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)))
        }
    }

    /// Evaluates the cost at `init_param + alpha * search_direction`.
    fn calc(&mut self, alpha: f64) -> Result<f64, Error> {
        let tmp = self.init_param.scaled_add(&alpha, &self.search_direction);
        self.apply(&tmp)
    }

    /// Evaluates the directional derivative (gradient projected onto the
    /// search direction) at `init_param + alpha * search_direction`.
    fn calc_grad(&mut self, alpha: f64) -> Result<f64, Error> {
        let tmp = self.init_param.scaled_add(&alpha, &self.search_direction);
        let grad = self.gradient(&tmp)?;
        Ok(self.search_direction.dot(&grad))
    }

    /// Stores the lowest-cost of the three current points (a, b, c) in
    /// `best_x`/`best_f`/`best_g`, breaking ties in favor of `a`, then `b`.
    ///
    /// Fix: the previous implementation used strict `<` in all three branches,
    /// so when two points had exactly equal cost no branch fired and the stale
    /// previous best was silently kept. The exhaustive selection below always
    /// updates the best point.
    fn set_best(&mut self) {
        if self.a_f <= self.b_f && self.a_f <= self.c_f {
            self.best_x = self.a_x;
            self.best_f = self.a_f;
            self.best_g = self.a_g;
        } else if self.b_f <= self.c_f {
            self.best_x = self.b_x;
            self.best_f = self.b_f;
            self.best_g = self.b_g;
        } else {
            self.best_x = self.c_x;
            self.best_f = self.c_f;
            self.best_g = self.c_g;
        }
    }
}
impl<O> ArgminLineSearch for HagerZhangLineSearch<O>
where
    O: ArgminOp<Output = f64>,
    <O as ArgminOp>::Param: ArgminSub<<O as ArgminOp>::Param, <O as ArgminOp>::Param>
        + ArgminDot<<O as ArgminOp>::Param, f64>
        + ArgminScaledAdd<<O as ArgminOp>::Param, f64, <O as ArgminOp>::Param>,
{
    /// Stage the search direction; it is unwrapped and validated in `init`.
    fn set_search_direction(&mut self, direction: <O as ArgminOp>::Param) {
        self.search_direction_b = Some(direction);
    }

    /// Stage the starting point and mirror it into the solver base as the
    /// current parameter vector.
    fn set_initial_parameter(&mut self, p: <O as ArgminOp>::Param) {
        self.init_param_b = Some(p.clone());
        self.set_cur_param(p);
    }

    /// Stage the cost value at the starting point.
    fn set_initial_cost(&mut self, cost: f64) {
        self.finit_b = Some(cost);
    }

    /// Stage the gradient at the starting point.
    fn set_initial_gradient(&mut self, grad: <O as ArgminOp>::Param) {
        self.init_grad_b = Some(grad);
    }

    /// Evaluate the cost at the current parameter vector and stage it.
    fn calc_initial_cost(&mut self) -> Result<(), Error> {
        let cur = self.cur_param();
        let cost = self.apply(&cur)?;
        self.finit_b = Some(cost);
        Ok(())
    }

    /// Evaluate the gradient at the current parameter vector and stage it.
    fn calc_initial_gradient(&mut self) -> Result<(), Error> {
        let cur = self.cur_param();
        let grad = self.gradient(&cur)?;
        self.init_grad_b = Some(grad);
        Ok(())
    }

    /// Use `alpha` as the first trial step length.
    fn set_initial_alpha(&mut self, alpha: f64) -> Result<(), Error> {
        self.c_x_init = alpha;
        Ok(())
    }
}
impl<O> ArgminIter for HagerZhangLineSearch<O>
where
    O: ArgminOp<Output = f64>,
    <O as ArgminOp>::Param: ArgminSub<<O as ArgminOp>::Param, <O as ArgminOp>::Param>
        + ArgminDot<<O as ArgminOp>::Param, f64>
        + ArgminScaledAdd<<O as ArgminOp>::Param, f64, <O as ArgminOp>::Param>,
{
    type Param = <O as ArgminOp>::Param;
    type Output = f64;
    type Hessian = <O as ArgminOp>::Hessian;

    /// Validates the configuration, unwraps the staged inputs, evaluates the
    /// initial bracket endpoints `a`/`b` and trial point `c`, and records the
    /// initial best point.
    ///
    /// # Errors
    /// Fails if `sigma < delta` or if any required input (initial parameter,
    /// cost, gradient, search direction) was not supplied.
    fn init(&mut self) -> Result<(), Error> {
        if self.sigma < self.delta {
            return Err(ArgminError::InvalidParameter {
                text: "HagerZhangLineSearch: sigma must be >= delta.".to_string(),
            }
            .into());
        }
        // Fix: these error messages previously referenced misspelled /
        // nonexistent methods (`calc_inital_cost`, `set_initial_grad`,
        // `calc_inital_grad`); they now name the actual trait methods.
        self.init_param = check_param!(
            self.init_param_b,
            "HagerZhangLineSearch: Initial parameter not initialized. Call `set_initial_parameter`."
        );
        self.finit = check_param!(
            self.finit_b,
            "HagerZhangLineSearch: Initial cost not computed. Call `set_initial_cost` or `calc_initial_cost`."
        );
        self.init_grad = check_param!(
            self.init_grad_b,
            "HagerZhangLineSearch: Initial gradient not computed. Call `set_initial_gradient` or `calc_initial_gradient`."
        );
        self.search_direction = check_param!(
            self.search_direction_b,
            "HagerZhangLineSearch: Search direction not initialized. Call `set_search_direction`."
        );
        // Evaluate cost and directional derivative at both bracket endpoints
        // and at the initial trial point.
        self.a_x = self.a_x_init;
        self.b_x = self.b_x_init;
        self.c_x = self.c_x_init;
        let at = self.a_x;
        self.a_f = self.calc(at)?;
        self.a_g = self.calc_grad(at)?;
        let bt = self.b_x;
        self.b_f = self.calc(bt)?;
        self.b_g = self.calc_grad(bt)?;
        let ct = self.c_x;
        self.c_f = self.calc(ct)?;
        self.c_g = self.calc_grad(ct)?;
        // epsilon_k scales the cost tolerance with the magnitude of the
        // initial cost; dginit is the slope at alpha = 0.
        self.epsilon_k = self.epsilon * self.finit.abs();
        self.dginit = self.init_grad.dot(&self.search_direction);
        self.set_best();
        let new_param = self
            .init_param
            .scaled_add(&self.best_x, &self.search_direction);
        self.set_best_param(new_param);
        let best_f = self.best_f;
        self.set_best_cost(best_f);
        Ok(())
    }

    /// One line-search iteration: a double secant step, followed by a
    /// bisection step if the bracket did not shrink by at least `gamma`.
    fn next_iter(&mut self) -> Result<ArgminIterData<Self::Param>, Error> {
        let aa = (self.a_x, self.a_f, self.a_g);
        let bb = (self.b_x, self.b_f, self.b_g);
        let ((mut at_x, mut at_f, mut at_g), (mut bt_x, mut bt_f, mut bt_g)) =
            self.secant2(aa, bb)?;
        if bt_x - at_x > self.gamma * (self.b_x - self.a_x) {
            // Bisect the bracket and update once more. Uses the `calc` /
            // `calc_grad` helpers instead of duplicating their bodies inline,
            // keeping the evaluation logic in one place.
            let c_x = (at_x + bt_x) / 2.0;
            let c_f = self.calc(c_x)?;
            let c_g = self.calc_grad(c_x)?;
            let ((an_x, an_f, an_g), (bn_x, bn_f, bn_g)) =
                self.update((at_x, at_f, at_g), (bt_x, bt_f, bt_g), (c_x, c_f, c_g))?;
            at_x = an_x;
            at_f = an_f;
            at_g = an_g;
            bt_x = bn_x;
            bt_f = bn_f;
            bt_g = bn_g;
        }
        // Commit the new bracket and report the best point found so far.
        self.a_x = at_x;
        self.a_f = at_f;
        self.a_g = at_g;
        self.b_x = bt_x;
        self.b_f = bt_f;
        self.b_g = bt_g;
        self.set_best();
        let new_param = self
            .init_param
            .scaled_add(&self.best_x, &self.search_direction);
        let out = ArgminIterData::new(new_param, self.best_f);
        Ok(out)
    }
}