use crate::primitives::Vector;
/// Strategy for picking a step size along a search direction.
///
/// Implementors receive the objective, its gradient, the current point and a
/// direction, and return the scalar step `alpha` to apply as `x + alpha * d`.
pub trait LineSearch {
    /// Computes a step size along `d` starting from `x`.
    ///
    /// # Arguments
    ///
    /// * `f` - Objective function evaluated at a point.
    /// * `grad` - Gradient of the objective evaluated at a point.
    /// * `x` - Current point.
    /// * `d` - Search direction.
    ///
    /// # Returns
    ///
    /// The step size chosen by the implementing strategy.
    fn search<F, G>(&self, f: &F, grad: &G, x: &Vector<f32>, d: &Vector<f32>) -> f32
    where
        F: Fn(&Vector<f32>) -> f32,
        G: Fn(&Vector<f32>) -> Vector<f32>;
}
/// Backtracking line search driven by a sufficient-decrease (Armijo) test.
///
/// The search starts from a unit step and multiplies it by `rho` after every
/// rejected trial, for at most `max_iter` trials.
#[derive(Debug, Clone)]
pub struct BacktrackingLineSearch {
    /// Coefficient of the sufficient-decrease test
    /// `f(x + a*d) <= f(x) + c1 * a * (grad(x) . d)`.
    pub(crate) c1: f32,
    /// Factor the step is multiplied by after a rejected trial.
    pub(crate) rho: f32,
    /// Maximum number of trial steps evaluated.
    pub(crate) max_iter: usize,
}

impl BacktrackingLineSearch {
    /// Creates a backtracking line search from its three parameters.
    ///
    /// The values are stored as given; no range validation is performed.
    /// NOTE(review): a shrinking search presumably expects `0 < rho < 1` and
    /// `0 < c1 < 1` — confirm against callers.
    #[must_use]
    pub fn new(c1: f32, rho: f32, max_iter: usize) -> Self {
        BacktrackingLineSearch { c1, rho, max_iter }
    }
}

impl Default for BacktrackingLineSearch {
    /// Returns a search with `c1 = 1e-4`, `rho = 0.5` and `max_iter = 50`.
    fn default() -> Self {
        const C1: f32 = 1e-4;
        const RHO: f32 = 0.5;
        const MAX_ITER: usize = 50;
        Self::new(C1, RHO, MAX_ITER)
    }
}
impl LineSearch for BacktrackingLineSearch {
    /// Finds a step size along `d` by repeatedly shrinking a unit step.
    ///
    /// Starting from a step of `1.0`, the trial point `x + step * d` is accepted
    /// as soon as `f(x + step * d) <= f(x) + c1 * step * (grad(x) . d)`;
    /// otherwise the step is multiplied by `rho` and the test is repeated.
    ///
    /// # Returns
    ///
    /// The first accepted step. If all `max_iter` trials are rejected, the step
    /// left after the final shrink is returned even though it was never
    /// evaluated. With `max_iter == 0` the initial step `1.0` is returned.
    ///
    /// `d` and the vector returned by `grad` are indexed over `0..x.len()`.
    fn search<F, G>(&self, f: &F, grad: &G, x: &Vector<f32>, d: &Vector<f32>) -> f32
    where
        F: Fn(&Vector<f32>) -> f32,
        G: Fn(&Vector<f32>) -> Vector<f32>,
    {
        let dim = x.len();

        // Builds the trial point `x + step * d`.
        let point_at = |step: f32| {
            let mut point: Vector<f32> = Vector::zeros(dim);
            for i in 0..dim {
                point[i] = x[i] + step * d[i];
            }
            point
        };

        let f0 = f(x);
        let g0 = grad(x);
        // Slope of the objective along `d` at `x`, accumulated in index order.
        let slope = (0..dim).fold(0.0_f32, |acc, i| acc + g0[i] * d[i]);

        let mut step = 1.0_f32;
        let mut trials_left = self.max_iter;
        while trials_left > 0 {
            let accepted = f(&point_at(step)) <= f0 + self.c1 * step * slope;
            if accepted {
                break;
            }
            step *= self.rho;
            trials_left -= 1;
        }
        step
    }
}
/// Line search that accepts a step once both a sufficient-decrease test and a
/// curvature test on the directional derivative hold.
#[derive(Debug, Clone)]
pub struct WolfeLineSearch {
    /// Coefficient of the sufficient-decrease test
    /// `f(x + a*d) <= f(x) + c1 * a * (grad(x) . d)`.
    pub(crate) c1: f32,
    /// Coefficient of the curvature test
    /// `|grad(x + a*d) . d| <= c2 * |grad(x) . d|`.
    pub(crate) c2: f32,
    /// Maximum number of trial steps evaluated.
    pub(crate) max_iter: usize,
}

impl WolfeLineSearch {
    /// Creates a Wolfe line search from its three parameters.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < c1 < c2 < 1`. NaN coefficients fail the check and
    /// therefore panic as well.
    #[must_use]
    pub fn new(c1: f32, c2: f32, max_iter: usize) -> Self {
        let ordered = 0.0 < c1 && c1 < c2 && c2 < 1.0;
        assert!(ordered, "Wolfe conditions require 0 < c1 < c2 < 1");
        WolfeLineSearch { c1, c2, max_iter }
    }
}

impl Default for WolfeLineSearch {
    /// Returns a search with `c1 = 1e-4`, `c2 = 0.9` and `max_iter = 50`.
    fn default() -> Self {
        const C1: f32 = 1e-4;
        const C2: f32 = 0.9;
        const MAX_ITER: usize = 50;
        Self::new(C1, C2, MAX_ITER)
    }
}
impl LineSearch for WolfeLineSearch {
    /// Searches along `d` from `x` for a step satisfying both the
    /// sufficient-decrease test and the curvature test.
    ///
    /// A trial step `alpha` (initially `1.0`) is processed as follows:
    ///
    /// 1. If `f(x + alpha*d) <= f(x) + c1 * alpha * (grad(x) . d)` does not
    ///    hold, `alpha` becomes the upper bracket and the next trial is the
    ///    midpoint of the bracket.
    /// 2. Otherwise, if `|grad(x + alpha*d) . d| <= c2 * |grad(x) . d|`, `alpha`
    ///    is returned.
    /// 3. Otherwise `alpha` becomes the upper bracket when the new directional
    ///    derivative is positive and the lower bracket when it is not; the next
    ///    trial is the bracket midpoint, or `2 * alpha` while no upper bracket
    ///    has been found.
    ///
    /// The gradient at a trial point is only evaluated once the trial has
    /// passed step 1, so rejected trials cost a single function evaluation.
    /// A trial whose function value is NaN fails step 1 and shrinks the step.
    ///
    /// # Returns
    ///
    /// The first step passing both tests. If `max_iter` trials pass without
    /// acceptance, the next trial step (not yet evaluated) is returned. With
    /// `max_iter == 0` the initial step `1.0` is returned.
    ///
    /// `d` and the vectors returned by `grad` are indexed over `0..x.len()`.
    fn search<F, G>(&self, f: &F, grad: &G, x: &Vector<f32>, d: &Vector<f32>) -> f32
    where
        F: Fn(&Vector<f32>) -> f32,
        G: Fn(&Vector<f32>) -> Vector<f32>,
    {
        let n = x.len();

        // Directional derivative `g . d`, accumulated in index order.
        let slope_along_d = |g: &Vector<f32>| (0..n).fold(0.0_f32, |acc, i| acc + g[i] * d[i]);

        let fx = f(x);
        let dir_deriv = slope_along_d(&grad(x));

        let mut alpha = 1.0_f32;
        let mut alpha_lo = 0.0_f32;
        let mut alpha_hi = f32::INFINITY;

        for _ in 0..self.max_iter {
            let mut x_new: Vector<f32> = Vector::zeros(n);
            for i in 0..n {
                x_new[i] = x[i] + alpha * d[i];
            }

            // Phrased as "decrease holds" rather than "decrease violated" so that
            // a NaN function value is treated as a rejected trial.
            let sufficient_decrease = f(&x_new) <= fx + self.c1 * alpha * dir_deriv;
            if !sufficient_decrease {
                alpha_hi = alpha;
                alpha = f32::midpoint(alpha_lo, alpha_hi);
                continue;
            }

            // The gradient is only needed for trials that passed the decrease test.
            let dir_deriv_new = slope_along_d(&grad(&x_new));
            if dir_deriv_new.abs() <= self.c2 * dir_deriv.abs() {
                return alpha;
            }

            if dir_deriv_new > 0.0 {
                alpha_hi = alpha;
            } else {
                alpha_lo = alpha;
            }

            alpha = if alpha_hi.is_finite() {
                f32::midpoint(alpha_lo, alpha_hi)
            } else {
                // No upper bracket yet: keep expanding.
                alpha * 2.0
            };
        }
        alpha
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional objective `f(x) = x^2`.
    fn square(p: &Vector<f32>) -> f32 {
        p[0] * p[0]
    }

    /// Gradient of the one-dimensional objective `f(x) = x^2`.
    fn square_grad(p: &Vector<f32>) -> Vector<f32> {
        Vector::from_slice(&[2.0 * p[0]])
    }

    /// Backtracking on a 2-D quadratic returns a step in `(0, 1]`.
    #[test]
    fn test_backtracking_quadratic() {
        let objective = |p: &Vector<f32>| p[0] * p[0] + p[1] * p[1];
        let gradient = |p: &Vector<f32>| Vector::from_slice(&[2.0 * p[0], 2.0 * p[1]]);
        let start = Vector::from_slice(&[1.0, 1.0]);
        let direction = Vector::from_slice(&[-2.0, -2.0]);

        let searcher = BacktrackingLineSearch::default();
        let step = searcher.search(&objective, &gradient, &start, &direction);

        assert!(step > 0.0);
        assert!(step <= 1.0);
    }

    /// The Wolfe search on a 1-D quadratic returns a positive step.
    #[test]
    fn test_wolfe_quadratic() {
        let start = Vector::from_slice(&[1.0]);
        let direction = Vector::from_slice(&[-2.0]);

        let searcher = WolfeLineSearch::default();
        let step = searcher.search(&square, &square_grad, &start, &direction);

        assert!(step > 0.0);
    }

    /// Moving along the negative gradient by the returned step lowers the objective.
    #[test]
    fn test_backtracking_ensures_decrease() {
        let searcher = BacktrackingLineSearch::new(1e-4, 0.5, 100);
        let start = Vector::from_slice(&[5.0]);
        let slope = square_grad(&start);
        let direction = Vector::from_slice(&[-slope[0]]);

        let step = searcher.search(&square, &square_grad, &start, &direction);
        let moved = Vector::from_slice(&[start[0] + step * direction[0]]);

        assert!(square(&moved) < square(&start));
    }
}