use crate::{
algorithms::gradient::GradientStatus,
core::Bounds,
error::{GaneshError, GaneshResult},
traits::{linesearch::LineSearchOutput, Gradient, LineSearch},
DVector, Float,
};
#[derive(Clone)]
pub struct BacktrackingLineSearch {
rho: Float,
c: Float,
}
impl Default for BacktrackingLineSearch {
fn default() -> Self {
Self { rho: 0.5, c: 1e-4 }
}
}
impl BacktrackingLineSearch {
pub fn with_rho(mut self, rho: Float) -> GaneshResult<Self> {
if !(0.0 < rho && rho < 1.0) {
return Err(GaneshError::ConfigError(
"BacktrackingLineSearch requires 0 < rho < 1".to_string(),
));
}
self.rho = rho;
Ok(self)
}
pub fn with_c(mut self, c: Float) -> GaneshResult<Self> {
if !(0.0 < c && c < 1.0) {
return Err(GaneshError::ConfigError(
"BacktrackingLineSearch requires 0 < c < 1".to_string(),
));
}
self.c = c;
Ok(self)
}
}
impl<U, E> LineSearch<GradientStatus, U, E> for BacktrackingLineSearch {
fn search(
&mut self,
x: &DVector<Float>,
p: &DVector<Float>,
max_step: Option<Float>,
problem: &dyn Gradient<U, E>,
_bounds: Option<&Bounds>,
args: &U,
status: &mut GradientStatus,
) -> Result<Result<LineSearchOutput, LineSearchOutput>, E> {
let mut alpha_i = max_step.map_or(1.0, |max_alpha| max_alpha);
let phi = |alpha: Float, args: &U, st: &mut GradientStatus| -> Result<Float, E> {
st.inc_n_f_evals();
problem.evaluate(&(x + p.scale(alpha)), args)
};
status.inc_n_f_evals();
status.inc_n_g_evals();
let (phi_0, g_0) = problem.evaluate_with_gradient(x, args)?;
let mut phi_alpha_i = phi(alpha_i, args, status)?;
let dphi_0 = g_0.dot(p);
loop {
let armijo = phi_alpha_i <= (self.c * alpha_i).mul_add(dphi_0, phi_0);
if armijo {
status.inc_n_g_evals();
let g_alpha_i = problem.gradient(&(x + p.scale(alpha_i)), args)?;
return Ok(Ok(LineSearchOutput {
alpha: alpha_i,
fx: phi_alpha_i,
g: g_alpha_i,
}));
}
alpha_i *= self.rho;
phi_alpha_i = phi(alpha_i, args, status)?;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn with_rho_sets_value() {
let ls = BacktrackingLineSearch::default().with_rho(0.7).unwrap();
assert_eq!(ls.rho, 0.7);
}
#[test]
fn with_c_sets_value() {
let ls = BacktrackingLineSearch::default().with_c(1e-3).unwrap();
assert_eq!(ls.c, 1e-3);
}
#[test]
fn with_rho_errors_when_out_of_range_low() {
assert!(BacktrackingLineSearch::default().with_rho(0.0).is_err());
}
#[test]
fn with_rho_errors_when_out_of_range_high() {
assert!(BacktrackingLineSearch::default().with_rho(1.0).is_err());
}
#[test]
fn with_c_errors_when_out_of_range_low() {
assert!(BacktrackingLineSearch::default().with_c(0.0).is_err());
}
#[test]
fn with_c_errors_when_out_of_range_high() {
assert!(BacktrackingLineSearch::default().with_c(1.0).is_err());
}
}