use crate::core::math::{Dot, ScaledAdd};
use crate::core::problem::{CostFunction, Gradient, Problem};
use crate::line_search::LineSearch;
/// Configuration for a line search that enforces the strong Wolfe conditions.
///
/// Write `φ(α) = f(x + α·d)` for a step `α` along direction `d` from `x`. An accepted step
/// satisfies both of these:
///
/// - sufficient decrease (Armijo): `φ(α) <= φ(0) + c1 · α · φ'(0)`;
/// - strong curvature: `|φ'(α)| <= c2 · |φ'(0)|`.
///
/// For a smooth `f` bounded below along the ray, a step meeting both conditions exists
/// whenever `0 < c1 < c2 < 1`. The builder methods check `c1` and `c2` against `(0, 1)`
/// individually. They do **not** check `c1 < c2`, because the builder calls can come in
/// any order.
#[derive(Debug, Clone, PartialEq)]
pub struct Wolfe {
    /// Sufficient-decrease (Armijo) constant.
    pub c1: f64,
    /// Strong-curvature constant.
    pub c2: f64,
    /// First trial step. It is clamped to `alpha_max` when a search starts.
    pub alpha_init: f64,
    /// Upper bound on the step size tried during the bracketing (expansion) phase.
    pub alpha_max: f64,
    /// Iteration budget. It applies separately to the bracketing phase and to the zoom phase.
    pub max_iter: u32,
}

impl Default for Wolfe {
    /// Returns `c1 = 1e-4`, `c2 = 0.9`, `alpha_init = 1.0`, `alpha_max = 10.0`, `max_iter = 25`.
    fn default() -> Self {
        Self {
            c1: 1e-4,
            c2: 0.9,
            alpha_init: 1.0,
            alpha_max: 10.0,
            max_iter: 25,
        }
    }
}

impl Wolfe {
    /// Creates a line search with the [`Default`] parameters.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the sufficient-decrease constant `c1`.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < c1 < 1`. This includes the case where `c1` is NaN.
    #[must_use = "builder methods return the updated configuration"]
    pub fn c1(mut self, c1: f64) -> Self {
        assert!(0.0 < c1 && c1 < 1.0, "c1 must be in (0, 1)");
        self.c1 = c1;
        self
    }

    /// Sets the strong-curvature constant `c2`.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < c2 < 1`. This includes the case where `c2` is NaN.
    #[must_use = "builder methods return the updated configuration"]
    pub fn c2(mut self, c2: f64) -> Self {
        assert!(0.0 < c2 && c2 < 1.0, "c2 must be in (0, 1)");
        self.c2 = c2;
        self
    }

    /// Sets the first trial step.
    ///
    /// # Panics
    ///
    /// Panics unless `alpha_init > 0`. This includes the case where `alpha_init` is NaN.
    #[must_use = "builder methods return the updated configuration"]
    pub fn alpha_init(mut self, alpha_init: f64) -> Self {
        assert!(alpha_init > 0.0, "alpha_init must be > 0");
        self.alpha_init = alpha_init;
        self
    }

    /// Sets the largest step the bracketing phase may try.
    ///
    /// # Panics
    ///
    /// Panics unless `alpha_max > 0`. This includes the case where `alpha_max` is NaN.
    #[must_use = "builder methods return the updated configuration"]
    pub fn alpha_max(mut self, alpha_max: f64) -> Self {
        assert!(alpha_max > 0.0, "alpha_max must be > 0");
        self.alpha_max = alpha_max;
        self
    }

    /// Sets the per-phase iteration budget. A budget of `0` performs no function evaluations.
    #[must_use = "builder methods return the updated configuration"]
    pub fn max_iter(mut self, max_iter: u32) -> Self {
        self.max_iter = max_iter;
        self
    }
}
impl<P, V> LineSearch<P, V> for Wolfe
where
    P: CostFunction<Param = V, Output = f64> + Gradient<Gradient = V>,
    V: ScaledAdd<f64> + Dot + Clone,
{
    type Error = P::Error;

    /// Bracketing phase of the strong Wolfe line search.
    ///
    /// The search starts at `min(alpha_init, alpha_max)` and doubles the step, capped at
    /// `alpha_max`. It stops when one of the following happens:
    ///
    /// - the step satisfies both strong Wolfe conditions, and that step is returned;
    /// - an interval containing such a step is bracketed, and `zoom` refines it;
    /// - the step cannot grow past `alpha_max`, and that evaluated step is returned.
    ///
    /// The structure mirrors the bracketing algorithm in Nocedal & Wright, Algorithm 3.5.
    ///
    /// Returns `Ok(0.0)` when `gradient · direction` is non-negative or NaN, i.e. when
    /// `direction` is not a descent direction. If the iteration budget runs out, it returns
    /// the last evaluated step that passed sufficient decrease, or `0.0` if none was
    /// evaluated.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `problem.cost` or `problem.gradient`.
    fn next(
        &mut self,
        problem: &mut Problem<P>,
        param: &V,
        cost: f64,
        gradient: &V,
        direction: &V,
    ) -> Result<f64, Self::Error> {
        let phi0 = cost;
        let phi0_prime = gradient.dot(direction);
        // Without a negative directional derivative, no positive step can satisfy
        // sufficient decrease, so report a zero step.
        if phi0_prime >= 0.0 || phi0_prime.is_nan() {
            return Ok(0.0);
        }

        // (alpha_prev, phi_prev) is the last evaluated step that passed sufficient decrease.
        // The starting point (alpha = 0) trivially qualifies.
        let mut alpha_prev = 0.0;
        let mut phi_prev = phi0;
        let mut alpha = self.alpha_init.min(self.alpha_max);

        for i in 0..self.max_iter {
            let mut trial = param.clone();
            trial.scaled_add(alpha, direction);
            let phi = problem.cost(&trial)?;

            // A NaN cost compares false against everything. Treat it explicitly as a
            // sufficient-decrease failure so [alpha_prev, alpha] is handed to zoom instead of
            // the NaN point being accepted and expanded from.
            if phi.is_nan()
                || phi > phi0 + self.c1 * alpha * phi0_prime
                || (i > 0 && phi >= phi_prev)
            {
                return self.zoom(
                    problem, param, direction, phi0, phi0_prime, alpha_prev, phi_prev, alpha,
                );
            }

            let g_trial = problem.gradient(&trial)?;
            let phi_prime = g_trial.dot(direction);
            if phi_prime.abs() <= -self.c2 * phi0_prime {
                return Ok(alpha);
            }
            // The slope has turned non-negative, so a minimizer lies between alpha and
            // alpha_prev. alpha (lower cost) becomes the "lo" end.
            if phi_prime >= 0.0 {
                return self.zoom(
                    problem, param, direction, phi0, phi0_prime, alpha, phi, alpha_prev,
                );
            }

            alpha_prev = alpha;
            phi_prev = phi;
            let next_alpha = (alpha * 2.0).min(self.alpha_max);
            if next_alpha == alpha {
                // Already at alpha_max. This step was evaluated and passed sufficient decrease.
                return Ok(alpha);
            }
            alpha = next_alpha;
        }

        // Budget exhausted. `alpha` now holds an expanded step that was never evaluated
        // (it may violate sufficient decrease), so return the last step that was verified.
        Ok(alpha_prev)
    }
}
impl Wolfe {
    /// Zoom phase: narrows a bracketing interval by bisection until a trial step satisfies
    /// the strong Wolfe conditions. The structure mirrors Nocedal & Wright, Algorithm 3.6,
    /// with bisection in place of interpolation.
    ///
    /// The update rules maintain these properties:
    /// - `alpha_lo` (with cost `phi_lo`) satisfies sufficient decrease. It is replaced only
    ///   by a trial that also satisfies it and has strictly lower cost.
    /// - `alpha_hi` lies on the side of `alpha_lo` where the function initially decreases.
    ///   Note that `alpha_hi` may be smaller than `alpha_lo`.
    ///
    /// Returns the first trial step that satisfies both conditions. If the iteration budget
    /// runs out, or the interval collapses to floating-point resolution, it returns
    /// `alpha_lo` instead.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `problem.cost` or `problem.gradient`.
    // This private helper receives the whole state of the bracketing phase. Bundling that
    // state into a struct would add boilerplate without making the call sites clearer.
    #[allow(clippy::too_many_arguments)]
    fn zoom<P, V>(
        &self,
        problem: &mut Problem<P>,
        param: &V,
        direction: &V,
        phi0: f64,
        phi0_prime: f64,
        mut alpha_lo: f64,
        mut phi_lo: f64,
        mut alpha_hi: f64,
    ) -> Result<f64, P::Error>
    where
        P: CostFunction<Param = V, Output = f64> + Gradient<Gradient = V>,
        V: ScaledAdd<f64> + Dot + Clone,
    {
        for _ in 0..self.max_iter {
            let alpha_j = 0.5 * (alpha_lo + alpha_hi);
            let mut trial = param.clone();
            trial.scaled_add(alpha_j, direction);
            let phi_j = problem.cost(&trial)?;

            // A NaN cost is treated as "too far": shrink toward alpha_lo. This keeps a NaN
            // from ever becoming phi_lo, which would disable the `phi_j >= phi_lo` test.
            if phi_j.is_nan()
                || phi_j > phi0 + self.c1 * alpha_j * phi0_prime
                || phi_j >= phi_lo
            {
                alpha_hi = alpha_j;
            } else {
                let g_j = problem.gradient(&trial)?;
                let phi_j_prime = g_j.dot(direction);
                if phi_j_prime.abs() <= -self.c2 * phi0_prime {
                    return Ok(alpha_j);
                }
                // Suppose the function is non-decreasing from alpha_j toward alpha_hi. Then a
                // minimizer lies between alpha_lo and alpha_j, so the old alpha_lo becomes
                // the far end of the interval.
                if phi_j_prime * (alpha_hi - alpha_lo) >= 0.0 {
                    alpha_hi = alpha_lo;
                }
                alpha_lo = alpha_j;
                phi_lo = phi_j;
            }

            // Stop once bisection can no longer produce a distinct midpoint.
            if (alpha_hi - alpha_lo).abs() <= f64::EPSILON * alpha_hi.abs().max(1.0) {
                break;
            }
        }
        Ok(alpha_lo)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional quadratic `f(x) = (x₀ - 3)²`, minimized at `x₀ = 3`.
    struct Quadratic;

    impl CostFunction for Quadratic {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;
        fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
            Ok((x[0] - 3.0).powi(2))
        }
    }

    impl Gradient for Quadratic {
        type Gradient = Vec<f64>;
        fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, std::convert::Infallible> {
            Ok(vec![2.0 * (x[0] - 3.0)])
        }
    }

    #[test]
    fn satisfies_strong_wolfe_on_quadratic() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        // Steepest-descent direction.
        let d = vec![-g[0]];
        let mut ls = Wolfe::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!(alpha > 0.0);

        // Check against the constants the search actually used, not copies of the defaults.
        let (c1, c2) = (ls.c1, ls.c2);
        let mut x_new = x.clone();
        x_new[0] += alpha * d[0];
        let f_new = p.cost(&x_new).unwrap();
        let g_new = p.gradient(&x_new).unwrap();
        let g0_dot_d = g[0] * d[0];
        let gnew_dot_d = g_new[0] * d[0];
        assert!(
            f_new <= f0 + c1 * alpha * g0_dot_d + 1e-12,
            "Armijo failed: f_new={f_new}, threshold={}",
            f0 + c1 * alpha * g0_dot_d,
        );
        assert!(
            gnew_dot_d.abs() <= -c2 * g0_dot_d + 1e-12,
            "Strong curvature failed: |g_new·d|={}, threshold={}",
            gnew_dot_d.abs(),
            -c2 * g0_dot_d,
        );
    }

    #[test]
    fn zoom_returns_minimizer_when_unit_step_overshoots() {
        // From x = 0 along d = 6, the unit step lands on x = 6. Its cost (9) equals the
        // initial cost, so sufficient decrease fails and zoom bisects [0, 1]. The first
        // midpoint, α = 0.5, lands exactly on the minimizer x = 3 (zero slope) and is accepted.
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![6.0];
        let mut ls = Wolfe::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!((alpha - 0.5).abs() < 1e-12, "expected α = 0.5, got {alpha}");
    }

    #[test]
    fn ascent_direction_yields_zero_step() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        // Uphill: g · d > 0, so no positive step can decrease the cost.
        let d = vec![g[0]];
        let mut ls = Wolfe::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert_eq!(alpha, 0.0);
    }
}