use crate::{
algorithms::gradient::GradientStatus,
core::Bounds,
error::{GaneshError, GaneshResult},
traits::{linesearch::LineSearchOutput, Gradient, LineSearch},
DVector, Float,
};
/// Line search that looks for a step length satisfying the strong Wolfe conditions.
///
/// The search first grows a trial step until a suitable interval is bracketed and then
/// narrows that interval by bisection. All fields are configured through the `with_*`
/// builder methods; [`Default`] provides the stock settings.
#[derive(Debug, Clone)]
pub struct MoreThuenteLineSearch {
    /// Maximum number of iterations of the bracketing phase.
    max_iters: usize,
    /// Maximum number of bisection iterations once an interval has been bracketed.
    max_zoom: usize,
    /// Sufficient-decrease constant; the builders enforce `0 < c1 < c2`.
    c1: Float,
    /// Curvature constant; the builders enforce `c1 < c2 < 1`.
    c2: Float,
}
impl Default for MoreThuenteLineSearch {
fn default() -> Self {
Self {
max_iters: 100,
max_zoom: 100,
c1: 1e-4,
c2: 0.9,
}
}
}
impl MoreThuenteLineSearch {
    /// Replaces the iteration limit of the bracketing phase.
    pub const fn with_max_iterations(self, max_iters: usize) -> Self {
        Self { max_iters, ..self }
    }

    /// Replaces the iteration limit of the zoom phase.
    pub const fn with_max_zoom(self, max_zoom: usize) -> Self {
        Self { max_zoom, ..self }
    }

    /// Replaces `c1`, keeping the current `c2`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `0 < c1 < c2` (a NaN `c1` is rejected).
    pub fn with_c1(mut self, c1: Float) -> GaneshResult<Self> {
        // Positive comparisons so that NaN falls through to the error path.
        let in_range = 0.0 < c1 && c1 < self.c2;
        if in_range {
            self.c1 = c1;
            return Ok(self);
        }
        Err(GaneshError::ConfigError(String::from(
            "MoreThuenteLineSearch requires 0 < c1 < c2",
        )))
    }

    /// Replaces `c2`, keeping the current `c1`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `c1 < c2 < 1` (a NaN `c2` is rejected).
    pub fn with_c2(mut self, c2: Float) -> GaneshResult<Self> {
        let in_range = self.c1 < c2 && c2 < 1.0;
        if in_range {
            self.c2 = c2;
            return Ok(self);
        }
        Err(GaneshError::ConfigError(String::from(
            "MoreThuenteLineSearch requires c1 < c2 < 1",
        )))
    }

    /// Replaces both constants at once, validating them against each other rather than
    /// against the values currently stored.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `0 < c1 < c2 < 1`.
    pub fn with_c1_c2(mut self, c1: Float, c2: Float) -> GaneshResult<Self> {
        let in_range = 0.0 < c1 && c1 < c2 && c2 < 1.0;
        if in_range {
            self.c1 = c1;
            self.c2 = c2;
            return Ok(self);
        }
        Err(GaneshError::ConfigError(String::from(
            "MoreThuenteLineSearch requires 0 < c1 < c2 < 1",
        )))
    }
}
impl MoreThuenteLineSearch {
    /// Evaluates the objective at `x` and records one function evaluation on `status`.
    fn f_eval<U, E>(
        &self,
        func: &dyn Gradient<U, E>,
        x: &DVector<Float>,
        args: &U,
        status: &mut GradientStatus,
    ) -> Result<Float, E> {
        status.inc_n_f_evals();
        func.evaluate(x, args)
    }

    /// Evaluates the gradient at `x` and records one gradient evaluation on `status`.
    fn g_eval<U, E>(
        &self,
        func: &dyn Gradient<U, E>,
        x: &DVector<Float>,
        args: &U,
        status: &mut GradientStatus,
    ) -> Result<DVector<Float>, E> {
        status.inc_n_g_evals();
        func.gradient(x, args)
    }

    /// Evaluates objective and gradient together at `x`, recording one evaluation of each
    /// on `status`.
    fn f_g_eval<U, E>(
        &self,
        func: &dyn Gradient<U, E>,
        x: &DVector<Float>,
        args: &U,
        status: &mut GradientStatus,
    ) -> Result<(Float, DVector<Float>), E> {
        status.inc_n_f_evals();
        status.inc_n_g_evals();
        func.evaluate_with_gradient(x, args)
    }

    /// Narrows the interval between `alpha_lo` and `alpha_hi` by bisection until a step
    /// satisfying both the sufficient-decrease and the curvature test is found.
    ///
    /// `f0` and `g0` are the objective value and gradient at `x0`, and `p` is the search
    /// direction. The curvature test compares against `-c2 * g0·p`, so it presumes that
    /// `p` is a descent direction (`g0·p < 0`).
    ///
    /// # Returns
    ///
    /// * `Ok(Ok(out))` if a step passing both tests was found, or if the iteration limit
    ///   was reached while the last trial step at least passed the sufficient-decrease
    ///   test.
    /// * `Ok(Err(out))` if the iteration limit was reached and the last trial step was
    ///   rejected.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `func`.
    #[allow(clippy::too_many_arguments)]
    fn zoom<U, E>(
        &self,
        func: &dyn Gradient<U, E>,
        x0: &DVector<Float>,
        args: &U,
        f0: Float,
        g0: &DVector<Float>,
        p: &DVector<Float>,
        alpha_lo: Float,
        alpha_hi: Float,
        status: &mut GradientStatus,
    ) -> Result<Result<LineSearchOutput, LineSearchOutput>, E> {
        let mut alpha_lo = alpha_lo;
        let mut alpha_hi = alpha_hi;
        let dphi0 = g0.dot(p);
        // The objective at the low end only changes when `alpha_lo` moves to a trial step
        // whose value is already known, so evaluate it once here and carry it along
        // instead of recomputing it on every iteration.
        let x_lo = x0 + p.scale(alpha_lo);
        let mut f_lo = self.f_eval(func, &x_lo, args, status)?;
        let mut i = 0;
        loop {
            let alpha_i = (alpha_lo + alpha_hi) / 2.0;
            let x = x0 + p.scale(alpha_i);
            let f_i = self.f_eval(func, &x, args, status)?;
            let armijo_bound = (self.c1 * alpha_i).mul_add(dphi0, f0);
            // `Some(gradient)` when the trial step passed the sufficient-decrease test in
            // this iteration, `None` when it was rejected.
            let accepted_gradient = if f_i > armijo_bound || f_i >= f_lo {
                alpha_hi = alpha_i;
                None
            } else {
                let g_i = self.g_eval(func, &x, args, status)?;
                let dphi = g_i.dot(p);
                if Float::abs(dphi) <= -self.c2 * dphi0 {
                    return Ok(Ok(LineSearchOutput {
                        alpha: alpha_i,
                        fx: f_i,
                        g: g_i,
                    }));
                }
                // Keep the minimiser bracketed: flip the far end when the slope at the
                // trial step points away from it.
                if dphi * (alpha_hi - alpha_lo) >= 0.0 {
                    alpha_hi = alpha_lo;
                }
                alpha_lo = alpha_i;
                f_lo = f_i;
                Some(g_i)
            };
            i += 1;
            if i > self.max_zoom {
                return Ok(match accepted_gradient {
                    // Reuse the gradient computed above rather than evaluating it again.
                    Some(g_i) => Ok(LineSearchOutput {
                        alpha: alpha_i,
                        fx: f_i,
                        g: g_i,
                    }),
                    None => {
                        let g_i = self.g_eval(func, &x, args, status)?;
                        Err(LineSearchOutput {
                            alpha: alpha_i,
                            fx: f_i,
                            g: g_i,
                        })
                    }
                });
            }
        }
    }
}
impl<U, E> LineSearch<GradientStatus, U, E> for MoreThuenteLineSearch {
    /// Searches along `p` from `x0` for a step length satisfying the strong Wolfe
    /// conditions.
    ///
    /// The trial step starts at `min(1, max_step)` (`max_step` defaults to `1.0`) and is
    /// moved 80% of the remaining distance towards `max_step` on every iteration until
    /// one of the following happens:
    ///
    /// * the step violates sufficient decrease, or the objective stopped decreasing
    ///   relative to the previous trial: the interval between the previous and the
    ///   current step is handed to `zoom`;
    /// * the step passes the curvature test: it is returned as `Ok(Ok(..))`;
    /// * the slope along `p` became non-negative: the interval is handed to `zoom`;
    /// * the iteration limit is hit or the step cannot grow any further: the last
    ///   evaluated step is returned as `Ok(Err(..))`.
    ///
    /// `_bounds` is not used by this line search.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `problem`.
    fn search(
        &mut self,
        x0: &DVector<Float>,
        p: &DVector<Float>,
        max_step: Option<Float>,
        problem: &dyn Gradient<U, E>,
        _bounds: Option<&Bounds>,
        args: &U,
        status: &mut GradientStatus,
    ) -> Result<Result<LineSearchOutput, LineSearchOutput>, E> {
        let (f0, g0) = self.f_g_eval(problem, x0, args, status)?;
        let alpha_max = max_step.unwrap_or(1.0);
        let mut alpha_im1 = 0.0;
        // Never start beyond the caller-supplied maximum step.
        let mut alpha_i = alpha_max.min(1.0);
        let mut f_im1 = f0;
        let dphi0 = g0.dot(p);
        let mut i = 0;
        loop {
            let x = x0 + p.scale(alpha_i);
            let f_i = self.f_eval(problem, &x, args, status)?;
            // Sufficient decrease is measured against f0 + c1 * alpha * phi'(0); the
            // "no longer decreasing" test applies from the second trial step onwards.
            let armijo_bound = (self.c1 * alpha_i).mul_add(dphi0, f0);
            if f_i > armijo_bound || (i > 0 && f_i >= f_im1) {
                return self.zoom(problem, x0, args, f0, &g0, p, alpha_im1, alpha_i, status);
            }
            let g_i = self.g_eval(problem, &x, args, status)?;
            let dphi = g_i.dot(p);
            if Float::abs(dphi) <= self.c2 * Float::abs(dphi0) {
                return Ok(Ok(LineSearchOutput {
                    alpha: alpha_i,
                    fx: f_i,
                    g: g_i,
                }));
            }
            if dphi >= 0.0 {
                return self.zoom(problem, x0, args, f0, &g0, p, alpha_i, alpha_im1, status);
            }
            i += 1;
            let alpha_next = alpha_i + 0.8 * (alpha_max - alpha_i);
            // Once the step has reached `alpha_max` (or `alpha_max` is NaN) another
            // iteration would only re-evaluate the same point.
            let can_grow = alpha_next > alpha_i;
            if i > self.max_iters || !can_grow {
                // Report the step that `f_i` and `g_i` were actually evaluated at.
                return Ok(Err(LineSearchOutput {
                    alpha: alpha_i,
                    fx: f_i,
                    g: g_i,
                }));
            }
            alpha_im1 = alpha_i;
            f_im1 = f_i;
            alpha_i = alpha_next;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Line search with the stock configuration, the starting point of every test.
    fn stock() -> MoreThuenteLineSearch {
        MoreThuenteLineSearch::default()
    }

    // --- accepted values -------------------------------------------------------------

    #[test]
    fn with_c1_sets_value() {
        let configured = stock().with_c1(1e-3).unwrap();
        assert_eq!(configured.c1, 1e-3);
        assert!(configured.c1 > 0.0 && configured.c1 < configured.c2);
    }

    #[test]
    fn with_c2_sets_value() {
        let configured = stock().with_c2(0.8).unwrap();
        assert_eq!(configured.c2, 0.8);
        assert!(configured.c2 < 1.0 && configured.c2 > configured.c1);
    }

    #[test]
    fn with_c1_c2_sets_both() {
        let configured = stock().with_c1_c2(1e-5, 0.7).unwrap();
        assert_eq!(configured.c1, 1e-5);
        assert_eq!(configured.c2, 0.7);
        assert!(configured.c1 > 0.0 && configured.c2 < 1.0 && configured.c1 < configured.c2);
    }

    // --- rejected values -------------------------------------------------------------

    #[test]
    fn with_c1_errors_when_nonpositive() {
        let outcome = stock().with_c1(0.0);
        assert!(outcome.is_err());
    }

    #[test]
    fn with_c1_errors_when_not_less_than_c2() {
        let tightened = stock().with_c2(0.2).unwrap();
        let outcome = tightened.with_c1(0.3);
        assert!(outcome.is_err());
    }

    #[test]
    fn with_c2_errors_when_not_less_than_one() {
        let outcome = stock().with_c2(1.0);
        assert!(outcome.is_err());
    }

    #[test]
    fn with_c2_errors_when_not_greater_than_c1() {
        let configured = stock().with_c1(1e-4).unwrap();
        let outcome = configured.with_c2(1e-5);
        assert!(outcome.is_err());
    }

    #[test]
    fn with_c1_c2_errors_when_bad_ordering() {
        let outcome = stock().with_c1_c2(0.9, 0.1);
        assert!(outcome.is_err());
    }

    #[test]
    fn with_c1_c2_errors_when_c2_not_less_than_one() {
        let outcome = stock().with_c1_c2(1e-4, 1.0);
        assert!(outcome.is_err());
    }

    #[test]
    fn with_c1_c2_errors_when_c1_not_positive() {
        let outcome = stock().with_c1_c2(0.0, 0.5);
        assert!(outcome.is_err());
    }
}