use numra_core::{NumraError, Scalar};
/// Reasons a Wolfe line search can fail.
///
/// `PartialEq`/`Eq` are derived so callers can compare errors directly
/// (e.g. `assert_eq!(err, LineSearchError::MaxIterations)`). `Copy` is
/// intentionally not derived so that variants carrying data can be added later.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSearchError {
    /// The directional derivative `g0 · d` at the starting point is non-negative,
    /// so moving along `d` does not decrease the objective to first order.
    NotDescentDirection,
    /// The interval of candidate step lengths collapsed to (numerically) zero
    /// width, so no further progress is possible.
    BracketCollapsed,
    /// The iteration limit (`WolfeOptions::max_iter`) was exhausted before an
    /// acceptable step was found.
    MaxIterations,
}

impl std::fmt::Display for LineSearchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotDescentDirection => "search direction is not a descent direction",
            Self::BracketCollapsed => "zoom bracket collapsed",
            Self::MaxIterations => "max line search iterations reached",
        };
        f.write_str(message)
    }
}

impl std::error::Error for LineSearchError {}
impl From<LineSearchError> for NumraError {
fn from(e: LineSearchError) -> Self {
NumraError::LineSearch(e.to_string())
}
}
/// Tuning parameters for the strong-Wolfe line search.
///
/// The constants are not validated here; the conventional requirement for the
/// strong Wolfe conditions is `0 < c1 < c2 < 1`.
#[derive(Debug, Clone)]
pub struct WolfeOptions<S: Scalar> {
    /// Sufficient-decrease (Armijo) constant: a trial step `alpha` is rejected when
    /// `f(x + alpha*d) > f0 + c1 * alpha * (g0 · d)`.
    pub c1: S,
    /// Curvature constant: a step is accepted when
    /// `|grad(x + alpha*d) · d| <= -c2 * (g0 · d)`.
    pub c2: S,
    /// Upper bound on the trial step; the expansion phase clamps `alpha` to it.
    pub max_step: S,
    /// Iteration limit, applied to the expansion phase and, separately, to the
    /// zoom phase.
    pub max_iter: usize,
}

impl<S: Scalar> Default for WolfeOptions<S> {
    /// `c1 = 1e-4`, `c2 = 0.9`, `max_step = 1e20`, `max_iter = 40`.
    fn default() -> Self {
        let c1 = S::from_f64(1.0e-4);
        let c2 = S::from_f64(0.9);
        let max_step = S::from_f64(1.0e20);
        WolfeOptions {
            c1,
            c2,
            max_step,
            max_iter: 40,
        }
    }
}
/// Outcome of a successful line search.
#[derive(Debug, Clone)]
pub struct LineSearchResult<S: Scalar> {
    /// Accepted step length `alpha` along the search direction.
    pub step: S,
    /// Objective value at `x + step * d`.
    pub f_new: S,
    /// Number of objective-function evaluations performed (gradient evaluations
    /// are not counted).
    pub n_eval: usize,
}
/// Inner product `Σ a[i] * b[i]`, accumulated left to right from `S::ZERO`.
///
/// If the slices differ in length, the surplus elements of the longer one are
/// ignored (pairs are formed with `zip`).
fn dot<S: Scalar>(a: &[S], b: &[S]) -> S {
    let mut total = S::ZERO;
    for (&lhs, &rhs) in a.iter().zip(b) {
        total = total + lhs * rhs;
    }
    total
}
/// Absolute width below which the zoom phase treats its bracketing interval as
/// collapsed and gives up with `LineSearchError::BracketCollapsed`.
const BRACKET_COLLAPSE_TOL: f64 = 1.0e-16;
/// Refines a bracketing interval until it finds a step satisfying the strong
/// Wolfe conditions. This is the "zoom" phase of Nocedal & Wright's
/// Algorithm 3.6, using plain bisection to choose each trial step.
///
/// * `f`, `grad`, `x`, `d`, `f0`, `opts` – as passed to `wolfe_line_search`.
/// * `dg0` – the directional derivative `g0 · d` at `x` (negative).
/// * `alpha_lo`, `f_lo` – the endpoint kept as the current best step and its
///   objective value. It is only replaced by trial steps that satisfy sufficient
///   decrease and have an objective value below `f_lo`.
/// * `alpha_hi` – the other endpoint; it may lie on either side of `alpha_lo`.
/// * `n_eval` – running count of objective evaluations. It is incremented once
///   per trial step and copied into the returned result.
///
/// # Errors
///
/// * [`LineSearchError::BracketCollapsed`] if the bracket becomes narrower than
///   [`BRACKET_COLLAPSE_TOL`]. It is also returned when the bracket's midpoint no
///   longer lies strictly inside it, meaning the floating-point resolution of `S`
///   is exhausted.
/// * [`LineSearchError::MaxIterations`] if `opts.max_iter` trial steps do not
///   produce an acceptable one.
// Internal helper: the argument list mirrors the state carried by Algorithm 3.6.
#[allow(clippy::too_many_arguments)]
fn zoom<S, F, G>(
    f: &F,
    grad: &G,
    x: &[S],
    d: &[S],
    f0: S,
    dg0: S,
    mut alpha_lo: S,
    mut f_lo: S,
    mut alpha_hi: S,
    opts: &WolfeOptions<S>,
    n_eval: &mut usize,
) -> Result<LineSearchResult<S>, LineSearchError>
where
    S: Scalar,
    F: Fn(&[S]) -> S,
    G: Fn(&[S], &mut [S]),
{
    let n = x.len();
    let mut x_trial = vec![S::ZERO; n];
    let mut g_trial = vec![S::ZERO; n];
    for _ in 0..opts.max_iter {
        if (alpha_hi - alpha_lo).abs() < S::from_f64(BRACKET_COLLAPSE_TOL) {
            return Err(LineSearchError::BracketCollapsed);
        }
        let alpha_j = (alpha_lo + alpha_hi) / S::TWO;

        // The absolute tolerance above is far below the resolution of `f32`, and of
        // `f64` for steps near 1. Once the endpoints are adjacent representable
        // values, the midpoint rounds onto one of them and bisection stalls. Treat
        // that case (or a NaN/overflowed endpoint) as a collapsed bracket instead of
        // spending the remaining iterations on the same point.
        let (lower, upper) = if alpha_lo < alpha_hi {
            (alpha_lo, alpha_hi)
        } else {
            (alpha_hi, alpha_lo)
        };
        let strictly_inside = lower < alpha_j && alpha_j < upper;
        if !strictly_inside {
            return Err(LineSearchError::BracketCollapsed);
        }

        // Slicing `d` to `n` keeps the out-of-bounds panic for a direction shorter
        // than `x`, rather than silently leaving stale entries in `x_trial`.
        for ((xt, &xi), &di) in x_trial.iter_mut().zip(x).zip(&d[..n]) {
            *xt = xi + alpha_j * di;
        }
        let f_j = f(&x_trial);
        *n_eval += 1;

        // Stated positively so that a NaN objective value (which compares false with
        // everything) fails sufficient decrease and shrinks the bracket away from it.
        let sufficient_decrease = f_j <= f0 + opts.c1 * alpha_j * dg0;
        if !sufficient_decrease || f_j >= f_lo {
            alpha_hi = alpha_j;
        } else {
            grad(&x_trial, &mut g_trial);
            let dg_j = dot(&g_trial, d);
            if dg_j.abs() <= -opts.c2 * dg0 {
                return Ok(LineSearchResult {
                    step: alpha_j,
                    f_new: f_j,
                    n_eval: *n_eval,
                });
            }
            // If f increases from `alpha_j` toward `alpha_hi`, the minimizer lies
            // between `alpha_j` and the old `alpha_lo`, so that endpoint becomes
            // the new `alpha_hi`.
            if dg_j * (alpha_hi - alpha_lo) >= S::ZERO {
                alpha_hi = alpha_lo;
            }
            alpha_lo = alpha_j;
            f_lo = f_j;
        }
    }
    Err(LineSearchError::MaxIterations)
}
/// Finds a step length along `d` that satisfies the strong Wolfe conditions.
///
/// This is the bracketing phase of Nocedal & Wright's Algorithm 3.5. Starting
/// from `alpha = 1`, the trial step is doubled (and clamped to `opts.max_step`)
/// until one of two things happens:
/// * an acceptable step is found, or
/// * an interval containing one is bracketed, which is then refined by
///   bisection in `zoom`.
///
/// A step `alpha` is accepted when both of these hold:
/// * `f(x + alpha*d) <= f0 + c1 * alpha * (g0 · d)` (sufficient decrease);
/// * `|grad(x + alpha*d) · d| <= -c2 * (g0 · d)` (strong curvature).
///
/// A NaN objective value fails the sufficient-decrease test, so the search
/// backtracks away from it instead of extrapolating further.
///
/// # Arguments
///
/// * `f` – objective function.
/// * `grad` – writes the gradient at its first argument into its second argument.
/// * `x` – starting point.
/// * `d` – search direction; expected to have the same length as `x`.
/// * `f0` – objective value at `x`. It is used as given and not recomputed.
/// * `g0` – gradient at `x`; expected to have the same length as `x`.
/// * `opts` – line search parameters.
///
/// # Errors
///
/// * [`LineSearchError::NotDescentDirection`] if `g0 · d` is not strictly
///   negative (including when it is NaN).
/// * [`LineSearchError::BracketCollapsed`] in either of these cases:
///   * the zoom interval collapses;
///   * the step reached `opts.max_step` on a previous iteration without being
///     accepted, so it cannot grow any further.
/// * [`LineSearchError::MaxIterations`] if the expansion phase or the zoom phase
///   exhausts `opts.max_iter` iterations.
///
/// # Panics
///
/// Debug builds panic if `d` or `g0` has a different length from `x`. In release
/// builds, a `d` shorter than `x` panics on indexing.
pub fn wolfe_line_search<S, F, G>(
    f: F,
    grad: G,
    x: &[S],
    d: &[S],
    f0: S,
    g0: &[S],
    opts: &WolfeOptions<S>,
) -> Result<LineSearchResult<S>, LineSearchError>
where
    S: Scalar,
    F: Fn(&[S]) -> S,
    G: Fn(&[S], &mut [S]),
{
    debug_assert_eq!(d.len(), x.len(), "direction and point lengths differ");
    debug_assert_eq!(g0.len(), x.len(), "gradient and point lengths differ");

    let dg0 = dot(g0, d);
    // Stated positively so that a NaN directional derivative is rejected as well.
    let is_descent = dg0 < S::ZERO;
    if !is_descent {
        return Err(LineSearchError::NotDescentDirection);
    }

    let n = x.len();
    let mut x_trial = vec![S::ZERO; n];
    let mut g_trial = vec![S::ZERO; n];
    let mut alpha_prev = S::ZERO;
    let mut f_prev = f0;
    let mut alpha = S::ONE;
    let mut n_eval: usize = 0;

    for i in 1..=opts.max_iter {
        if alpha > opts.max_step {
            alpha = opts.max_step;
        }
        // The capped step was already evaluated on the previous iteration without
        // being accepted. Re-evaluating the same point would only hand `zoom` an
        // empty bracket, so report the collapse directly.
        if i > 1 && alpha <= alpha_prev {
            return Err(LineSearchError::BracketCollapsed);
        }

        // Slicing `d` to `n` keeps the out-of-bounds panic for a short direction.
        for ((xt, &xi), &di) in x_trial.iter_mut().zip(x).zip(&d[..n]) {
            *xt = xi + alpha * di;
        }
        let f_alpha = f(&x_trial);
        n_eval += 1;

        // NaN compares false with everything. Stating the Armijo test positively
        // makes a NaN objective value count as a failure and forces backtracking.
        let sufficient_decrease = f_alpha <= f0 + opts.c1 * alpha * dg0;
        if !sufficient_decrease || (i > 1 && f_alpha >= f_prev) {
            return zoom(
                &f,
                &grad,
                x,
                d,
                f0,
                dg0,
                alpha_prev,
                f_prev,
                alpha,
                opts,
                &mut n_eval,
            );
        }

        grad(&x_trial, &mut g_trial);
        let dg_alpha = dot(&g_trial, d);
        if dg_alpha.abs() <= -opts.c2 * dg0 {
            return Ok(LineSearchResult {
                step: alpha,
                f_new: f_alpha,
                n_eval,
            });
        }

        // The slope is no longer negative (or is NaN). A minimizer then lies between
        // the previous step and this one, with `alpha` as the low endpoint.
        let still_descending = dg_alpha < S::ZERO;
        if !still_descending {
            return zoom(
                &f,
                &grad,
                x,
                d,
                f0,
                dg0,
                alpha,
                f_alpha,
                alpha_prev,
                opts,
                &mut n_eval,
            );
        }

        alpha_prev = alpha;
        f_prev = f_alpha;
        alpha *= S::TWO;
    }
    Err(LineSearchError::MaxIterations)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional quadratic `f(x) = x0^2`.
    fn square(x: &[f64]) -> f64 {
        x[0] * x[0]
    }

    /// Gradient of [`square`].
    fn square_grad(x: &[f64], g: &mut [f64]) {
        g[0] = 2.0 * x[0];
    }

    /// Two-dimensional Rosenbrock function `(1 - x0)^2 + 100 (x1 - x0^2)^2`.
    fn rosenbrock(x: &[f64]) -> f64 {
        let a = 1.0 - x[0];
        let b = x[1] - x[0] * x[0];
        a * a + 100.0 * b * b
    }

    /// Gradient of [`rosenbrock`].
    fn rosenbrock_grad(x: &[f64], g: &mut [f64]) {
        g[0] = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
        g[1] = 200.0 * (x[1] - x[0] * x[0]);
    }

    /// Checks the properties every successful `f64` search must have.
    fn assert_accepted_step(res: &LineSearchResult<f64>, f0: f64) {
        assert!(res.step > 0.0, "step must be positive");
        assert!(
            res.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            res.f_new,
            f0
        );
    }

    #[test]
    fn test_wolfe_quadratic() {
        let start = [2.0];
        let direction = [-1.0];
        let gradient = [4.0];
        let f0 = square(&start);
        let opts = WolfeOptions::default();
        let res = wolfe_line_search(
            square,
            square_grad,
            &start,
            &direction,
            f0,
            &gradient,
            &opts,
        )
        .unwrap();
        assert_accepted_step(&res, f0);
    }

    #[test]
    fn test_wolfe_rosenbrock() {
        let start = [-1.0, 1.0];
        let f0 = rosenbrock(&start);
        let mut gradient = [0.0; 2];
        rosenbrock_grad(&start, &mut gradient);
        // Steepest descent direction.
        let direction = [-gradient[0], -gradient[1]];
        let opts = WolfeOptions::default();
        let res = wolfe_line_search(
            rosenbrock,
            rosenbrock_grad,
            &start,
            &direction,
            f0,
            &gradient,
            &opts,
        )
        .unwrap();
        assert_accepted_step(&res, f0);
    }

    #[test]
    fn test_wolfe_not_descent() {
        let start = [2.0];
        // Points uphill: the gradient at `start` is +4.
        let direction = [1.0];
        let f0 = square(&start);
        let gradient = [4.0];
        let opts = WolfeOptions::default();
        let outcome = wolfe_line_search(
            square,
            square_grad,
            &start,
            &direction,
            f0,
            &gradient,
            &opts,
        );
        assert!(outcome.is_err(), "must reject non-descent direction");
        assert!(
            matches!(outcome.unwrap_err(), LineSearchError::NotDescentDirection),
            "error should be NotDescentDirection"
        );
    }

    #[test]
    fn test_wolfe_f32() {
        let objective = |x: &[f32]| x[0] * x[0];
        let gradient_fn = |x: &[f32], g: &mut [f32]| {
            g[0] = 2.0 * x[0];
        };
        let start = [2.0f32];
        let direction = [-1.0f32];
        let f0 = objective(&start);
        let gradient = [4.0f32];
        let opts = WolfeOptions::default();
        let res = wolfe_line_search(
            objective,
            gradient_fn,
            &start,
            &direction,
            f0,
            &gradient,
            &opts,
        )
        .unwrap();
        assert!(res.step > 0.0, "step must be positive");
        assert!(res.f_new < f0, "function must decrease");
    }
}