use crate::core::math::{Dot, Scalar, ScaledAdd};
use crate::core::problem::{CostFunction, Gradient, Problem};
use crate::line_search::LineSearch;
/// Moré–Thuente line search for step lengths satisfying the strong Wolfe conditions.
///
/// The algorithm follows the MINPACK-2 `dcsrch` driver together with its `dcstep`
/// helper. Build one with [`Default`] / `new()` and the validating builder methods,
/// or assign the public fields directly (direct assignment performs no validation).
///
/// Throughout, `f'(stp)` denotes the directional derivative of the objective along
/// the search direction at step length `stp`.
#[derive(Debug, Clone)]
pub struct MoreThuente<F = f64> {
    /// Sufficient-decrease (Armijo) parameter: an accepted step satisfies
    /// `f(stp) <= f(0) + ftol * stp * f'(0)`. Default `1e-3`.
    pub ftol: F,
    /// Curvature parameter: an accepted step satisfies
    /// `|f'(stp)| <= gtol * |f'(0)|`. Default `0.9`.
    pub gtol: F,
    /// Relative tolerance on the bracketing interval; once a minimizer is bracketed
    /// the search stops when `stmax - stmin <= xtol * stmax`. Default `0.1`.
    pub xtol: F,
    /// First trial step length. Must lie in `[stpmin, stpmax]`, otherwise the
    /// search returns `0` immediately. Default `1`.
    pub alpha_init: F,
    /// Smallest step length the search will try. Default `0`.
    pub stpmin: F,
    /// Largest step length the search will try. Default `1e10`.
    pub stpmax: F,
    /// Maximum number of cost-and-gradient evaluations per search. Default `20`.
    pub maxfev: u32,
}
impl<F: Scalar> Default for MoreThuente<F> {
    /// Builds the standard configuration: `ftol = 1e-3`, `gtol = 0.9`, `xtol = 0.1`,
    /// a first trial step of `1`, step bounds `[0, 1e10]`, and at most 20 evaluations.
    fn default() -> Self {
        // Defaults other than 0 and 1 are converted from `f64` literals.
        let lit = |value: f64| F::from_f64(value).unwrap();
        Self {
            maxfev: 20,
            stpmax: lit(1.0e10),
            stpmin: F::zero(),
            alpha_init: F::one(),
            xtol: lit(0.1),
            gtol: lit(0.9),
            ftol: lit(1.0e-3),
        }
    }
}
impl<F: Scalar> MoreThuente<F> {
    /// Creates a line search with the [`Default`] configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the sufficient-decrease (Armijo) parameter `ftol`.
    ///
    /// The MINPACK-2 `dcsrch` documentation notes that a step satisfying both
    /// conditions exists (e.g. for functions bounded below) when `ftol < gtol`;
    /// that relation is not enforced here.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < ftol < 1` (NaN is rejected).
    #[must_use]
    pub fn ftol(mut self, ftol: F) -> Self {
        assert!(
            F::zero() < ftol && ftol < F::one(),
            "ftol must be in (0, 1)"
        );
        self.ftol = ftol;
        self
    }

    /// Sets the curvature parameter `gtol`.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < gtol < 1` (NaN is rejected).
    #[must_use]
    pub fn gtol(mut self, gtol: F) -> Self {
        assert!(
            F::zero() < gtol && gtol < F::one(),
            "gtol must be in (0, 1)"
        );
        self.gtol = gtol;
        self
    }

    /// Sets the relative tolerance on the width of the bracketing interval.
    ///
    /// # Panics
    ///
    /// Panics if `xtol` is negative or NaN.
    #[must_use]
    pub fn xtol(mut self, xtol: F) -> Self {
        assert!(xtol >= F::zero(), "xtol must be ≥ 0");
        self.xtol = xtol;
        self
    }

    /// Sets the first trial step length.
    ///
    /// If it lies outside `[stpmin, stpmax]` when a search starts, that search
    /// returns `0` without evaluating the problem.
    ///
    /// # Panics
    ///
    /// Panics unless `alpha_init > 0` (NaN is rejected).
    #[must_use]
    pub fn alpha_init(mut self, alpha_init: F) -> Self {
        assert!(alpha_init > F::zero(), "alpha_init must be > 0");
        self.alpha_init = alpha_init;
        self
    }

    /// Sets the smallest step length the search may try.
    ///
    /// # Panics
    ///
    /// Panics if `stpmin` is negative or NaN.
    #[must_use]
    pub fn stpmin(mut self, stpmin: F) -> Self {
        assert!(stpmin >= F::zero(), "stpmin must be ≥ 0");
        self.stpmin = stpmin;
        self
    }

    /// Sets the largest step length the search may try.
    ///
    /// # Panics
    ///
    /// Panics unless `stpmax > 0` (NaN is rejected).
    #[must_use]
    pub fn stpmax(mut self, stpmax: F) -> Self {
        assert!(stpmax > F::zero(), "stpmax must be > 0");
        self.stpmax = stpmax;
        self
    }

    /// Sets the maximum number of cost-and-gradient evaluations per search.
    ///
    /// With `0`, every search returns `0` without evaluating the problem.
    #[must_use]
    pub fn maxfev(mut self, maxfev: u32) -> Self {
        self.maxfev = maxfev;
        self
    }
}
impl<P, V, F> LineSearch<P, V, F> for MoreThuente<F>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F> + Gradient<Gradient = V>,
    V: ScaledAdd<F> + Dot<F> + Clone,
{
    type Error = P::Error;

    /// Finds a step length along `direction` satisfying the strong Wolfe conditions.
    ///
    /// The iteration follows the MINPACK-2 `dcsrch` driver. Each trial step is
    /// evaluated and tested for convergence or termination. If neither applies,
    /// `dcstep` proposes a new safeguarded trial step.
    ///
    /// `cost` and `gradient` are expected to be the objective value and gradient
    /// at `param`.
    ///
    /// # Returns
    ///
    /// * The first trial step with `f <= f0 + ftol * stp * g0` and
    ///   `|g| <= gtol * |g0|`, where `g` and `g0` are directional derivatives along
    ///   `direction`. Alternatively, the step at which a termination warning fires
    ///   (interval narrower than `xtol`, no further progress possible, or a step
    ///   bound reached).
    /// * `0`, without evaluating the problem, if `direction` is not a descent
    ///   direction, if the initial cost or directional derivative is not finite,
    ///   or if `alpha_init` lies outside `[stpmin, stpmax]`.
    /// * The step `stx` retained as best so far (possibly `0`) in two cases: when a
    ///   trial point yields a non-finite cost or directional derivative, and when
    ///   `maxfev` evaluations have been used.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the problem's cost or gradient evaluation.
    fn next(
        &mut self,
        problem: &mut Problem<P>,
        param: &V,
        cost: F,
        gradient: &V,
        direction: &V,
    ) -> Result<F, Self::Error> {
        let zero = F::zero();
        let p5 = F::from_f64(0.5).unwrap();
        let p66 = F::from_f64(0.66).unwrap();
        // Extrapolation range (as multiples of the last step) before bracketing.
        let xtrapl = F::from_f64(1.1).unwrap();
        let xtrapu = F::from_f64(4.0).unwrap();

        let finit = cost;
        let ginit = gradient.dot(direction);
        // A zero step reports failure. A NaN or infinite starting cost makes every
        // sufficient-decrease test meaningless. The directional derivative must be
        // finite and strictly negative (a descent direction).
        if !finit.is_finite() || !ginit.is_finite() || ginit >= zero {
            return Ok(zero);
        }
        if !(self.alpha_init >= self.stpmin && self.alpha_init <= self.stpmax) {
            return Ok(zero);
        }

        let mut brackt = false;
        // Stage 1 works with the modified function until a step with
        // f <= ftest and g >= 0 has been seen; stage 2 then uses f itself.
        let mut stage: u8 = 1;
        let gtest = self.ftol * ginit;
        // Interval widths from the last two iterations. They are used to force a
        // bisection when the bracket is not shrinking fast enough.
        let mut width = self.stpmax - self.stpmin;
        let mut width1 = width / p5;
        // (stx, fx, gx): best step so far; (sty, fy, gy): other interval endpoint.
        let mut stx = zero;
        let mut fx = finit;
        let mut gx = ginit;
        let mut sty = zero;
        let mut fy = finit;
        let mut gy = ginit;
        // Limits handed to `dcstep`. Once bracketed they are the bracketing
        // interval; before that they are an extrapolation range beyond `stp`.
        let mut stmin = zero;
        let mut stmax = self.alpha_init + xtrapu * self.alpha_init;
        let mut stp = self.alpha_init;

        // Scratch buffer for trial points. `clone_from` lets types such as `Vec`
        // reuse their allocation across iterations.
        let mut trial = param.clone();
        for _ in 0..self.maxfev {
            trial.clone_from(param);
            trial.scaled_add(stp, direction);
            let f = problem.cost(&trial)?;
            let g_full = problem.gradient(&trial)?;
            let g = g_full.dot(direction);

            // Non-finite values would poison the interpolation in `dcstep` and could
            // turn `stp`/`stx` into NaN, so stop at the best step found so far.
            if !f.is_finite() || !g.is_finite() {
                return Ok(stx);
            }

            let ftest = finit + stp * gtest;
            if stage == 1 && f <= ftest && g >= zero {
                stage = 2;
            }

            // Termination tests: the warnings and the convergence test of `dcsrch`.
            let warn_rounding = brackt && (stp <= stmin || stp >= stmax);
            let warn_xtol = brackt && stmax - stmin <= self.xtol * stmax;
            let warn_stpmax = stp == self.stpmax && f <= ftest && g <= gtest;
            let warn_stpmin = stp == self.stpmin && (f > ftest || g >= gtest);
            let converged = f <= ftest && g.abs() <= self.gtol * (-ginit);
            if warn_rounding || warn_xtol || warn_stpmax || warn_stpmin || converged {
                return Ok(stp);
            }

            if stage == 1 && f <= fx && f > ftest {
                // f decreased, but not sufficiently. Interpolate the modified
                // function f(a) - a * gtest instead, then map the endpoint values
                // back to f.
                let fm = f - stp * gtest;
                let mut fxm = fx - stx * gtest;
                let mut fym = fy - sty * gtest;
                let gm = g - gtest;
                let mut gxm = gx - gtest;
                let mut gym = gy - gtest;
                dcstep::<F>(
                    &mut stx,
                    &mut fxm,
                    &mut gxm,
                    &mut sty,
                    &mut fym,
                    &mut gym,
                    &mut stp,
                    fm,
                    gm,
                    &mut brackt,
                    stmin,
                    stmax,
                );
                fx = fxm + stx * gtest;
                fy = fym + sty * gtest;
                gx = gxm + gtest;
                gy = gym + gtest;
            } else {
                dcstep::<F>(
                    &mut stx,
                    &mut fx,
                    &mut gx,
                    &mut sty,
                    &mut fy,
                    &mut gy,
                    &mut stp,
                    f,
                    g,
                    &mut brackt,
                    stmin,
                    stmax,
                );
            }

            // Bisect when the bracket is not below 66% of its width two steps ago.
            if brackt {
                if (sty - stx).abs() >= p66 * width1 {
                    stp = stx + p5 * (sty - stx);
                }
                width1 = width;
                width = (sty - stx).abs();
            }

            if brackt {
                stmin = stx.min(sty);
                stmax = stx.max(sty);
            } else {
                stmin = stp + xtrapl * (stp - stx);
                stmax = stp + xtrapu * (stp - stx);
            }

            stp = stp.max(self.stpmin).min(self.stpmax);
            // No further progress possible: retry at the best step so far; the
            // termination tests then fire on the next evaluation.
            if brackt && (stp <= stmin || stp >= stmax || stmax - stmin <= self.xtol * stmax) {
                stp = stx;
            }
        }
        // Evaluation budget exhausted: fall back to the best step found so far.
        Ok(stx)
    }
}
/// Computes the next safeguarded trial step and updates the interval of uncertainty.
///
/// This mirrors the MINPACK-2 `dcstep` subroutine. Depending on how the trial point
/// compares with the best point so far, one of four cases picks the new step. The
/// candidates come from cubic, quadratic, and secant interpolants of the endpoint
/// function values and derivatives.
///
/// * `stx`, `fx`, `dx`: step, function value and derivative at the endpoint with the
///   lowest function value passed in so far; updated in place.
/// * `sty`, `fy`, `dy`: step, function value and derivative at the other endpoint;
///   updated in place.
/// * `stp`: on entry the current trial step; on exit the new trial step.
/// * `fp`, `dp`: function value and derivative at the current trial step.
/// * `brackt`: whether a minimizer is known to lie between `stx` and `sty`. Cases 1
///   and 2 set it to `true`; it is never cleared here.
/// * `stpmin`, `stpmax`: limits used as the extrapolated step in cases 3 and 4, and
///   for clamping the result in case 3 when nothing is bracketed.
// The argument list mirrors the MINPACK-2 `dcstep` signature, hence the allow.
#[allow(clippy::too_many_arguments)]
fn dcstep<F: Scalar>(
    stx: &mut F,
    fx: &mut F,
    dx: &mut F,
    sty: &mut F,
    fy: &mut F,
    dy: &mut F,
    stp: &mut F,
    fp: F,
    dp: F,
    brackt: &mut bool,
    stpmin: F,
    stpmax: F,
) {
    let zero = F::zero();
    let two = F::from_f64(2.0).unwrap();
    let three = F::from_f64(3.0).unwrap();
    let p66 = F::from_f64(0.66).unwrap();
    // Sign of `dp` relative to `dx`: negative when the derivatives have opposite
    // signs. If `dx` is zero this is NaN, so every `sgnd < zero` test below is false.
    let sgnd = dp * (*dx / dx.abs());
    let stpf;
    if fp > *fx {
        // Case 1: higher function value, so a minimizer lies between `stx` and `stp`.
        // Take the cubic step if it is closer to `stx` than the quadratic step,
        // otherwise take the midpoint of the two.
        let theta = three * (*fx - fp) / (*stp - *stx) + *dx + dp;
        // Dividing by `s` (the largest magnitude) keeps each ratio in [-1, 1]
        // before squaring.
        let s = theta.abs().max(dx.abs()).max(dp.abs());
        let mut gamma = s * ((theta / s).powi(2) - (*dx / s) * (dp / s)).sqrt();
        if *stp < *stx {
            gamma = -gamma;
        }
        let p = (gamma - *dx) + theta;
        let q = ((gamma - *dx) + gamma) + dp;
        let r = p / q;
        // Cubic step.
        let stpc = *stx + r * (*stp - *stx);
        // Minimizer of the quadratic through `fx`, `dx` (at `stx`) and `fp` (at `stp`).
        let stpq = *stx + ((*dx / ((*fx - fp) / (*stp - *stx) + *dx)) / two) * (*stp - *stx);
        stpf = if (stpc - *stx).abs() < (stpq - *stx).abs() {
            stpc
        } else {
            stpc + (stpq - stpc) / two
        };
        *brackt = true;
    } else if sgnd < zero {
        // Case 2: lower function value and derivatives of opposite sign, so a
        // minimizer is bracketed. Take the cubic step if it is farther from `stp`
        // than the secant step, otherwise the secant step.
        let theta = three * (*fx - fp) / (*stp - *stx) + *dx + dp;
        let s = theta.abs().max(dx.abs()).max(dp.abs());
        let mut gamma = s * ((theta / s).powi(2) - (*dx / s) * (dp / s)).sqrt();
        if *stp > *stx {
            gamma = -gamma;
        }
        let p = (gamma - dp) + theta;
        let q = ((gamma - dp) + gamma) + *dx;
        let r = p / q;
        let stpc = *stp + r * (*stx - *stp);
        // Secant step: zero of the linear interpolant of the derivatives.
        let stpq = *stp + (dp / (dp - *dx)) * (*stx - *stp);
        stpf = if (stpc - *stp).abs() > (stpq - *stp).abs() {
            stpc
        } else {
            stpq
        };
        *brackt = true;
    } else if dp.abs() < dx.abs() {
        // Case 3: lower function value, derivatives of the same sign, and the
        // derivative magnitude decreases.
        let theta = three * (*fx - fp) / (*stp - *stx) + *dx + dp;
        let s = theta.abs().max(dx.abs()).max(dp.abs());
        // `max(zero, …)` keeps the square root's argument non-negative here;
        // `gamma` is then zero when that argument would have been negative.
        let mut gamma = s * (zero.max((theta / s).powi(2) - (*dx / s) * (dp / s))).sqrt();
        if *stp > *stx {
            gamma = -gamma;
        }
        let p = (gamma - dp) + theta;
        let q = (gamma + (*dx - dp)) + gamma;
        let r = p / q;
        // The cubic step is used only when `r < 0` and `gamma != 0`. Otherwise the
        // step extrapolates to `stpmax`/`stpmin` in the direction from `stx` to
        // `stp`. (MINPACK-2 describes this as: the cubic tends to infinity in the
        // step direction, or its minimum lies beyond `stp`.)
        let stpc = if r < zero && gamma != zero {
            *stp + r * (*stx - *stp)
        } else if *stp > *stx {
            stpmax
        } else {
            stpmin
        };
        // Secant step.
        let stpq = *stp + (dp / (dp - *dx)) * (*stx - *stp);
        stpf = if *brackt {
            // Bracketed: take whichever step is closer to `stp`, but move at most
            // 66% of the way from `stp` toward `sty`.
            let cand = if (stpc - *stp).abs() < (stpq - *stp).abs() {
                stpc
            } else {
                stpq
            };
            if *stp > *stx {
                (*stp + p66 * (*sty - *stp)).min(cand)
            } else {
                (*stp + p66 * (*sty - *stp)).max(cand)
            }
        } else {
            // Not bracketed: take whichever step is farther from `stp`, clamped
            // to `[stpmin, stpmax]`.
            let cand = if (stpc - *stp).abs() > (stpq - *stp).abs() {
                stpc
            } else {
                stpq
            };
            cand.min(stpmax).max(stpmin)
        };
    } else {
        // Case 4: lower function value, derivatives of the same sign, and the
        // derivative magnitude does not decrease. If bracketed, take the cubic
        // step built from `stp` and `sty`; otherwise jump to `stpmax` or `stpmin`
        // in the direction from `stx` to `stp`.
        stpf = if *brackt {
            let theta = three * (fp - *fy) / (*sty - *stp) + *dy + dp;
            let s = theta.abs().max(dy.abs()).max(dp.abs());
            let mut gamma = s * ((theta / s).powi(2) - (*dy / s) * (dp / s)).sqrt();
            if *stp > *sty {
                gamma = -gamma;
            }
            let p = (gamma - dp) + theta;
            let q = ((gamma - dp) + gamma) + *dy;
            let r = p / q;
            *stp + r * (*sty - *stp)
        } else if *stp > *stx {
            stpmax
        } else {
            stpmin
        };
    }
    // Update the interval. A higher value makes `stp` the new `sty`. Otherwise
    // `stp` becomes the new best point `stx`, and if the derivatives had opposite
    // signs the previous `stx` becomes `sty`.
    if fp > *fx {
        *sty = *stp;
        *fy = fp;
        *dy = dp;
    } else {
        if sgnd < zero {
            *sty = *stx;
            *fy = *fx;
            *dy = *dx;
        }
        *stx = *stp;
        *fx = fp;
        *dx = dp;
    }
    *stp = stpf;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `f(x) = (x - 3)^2`, minimised at `x = 3`.
    struct Quadratic;

    impl CostFunction for Quadratic {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
            Ok((x[0] - 3.0).powi(2))
        }
    }

    impl Gradient for Quadratic {
        type Gradient = Vec<f64>;

        fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, std::convert::Infallible> {
            Ok(vec![2.0 * (x[0] - 3.0)])
        }
    }

    /// `f(x) = t^3 - 3t` with `t = x - 2`: local minimum at `x = 3`, local maximum at `x = 1`.
    struct Cubic;

    impl CostFunction for Cubic {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
            let t = x[0] - 2.0;
            Ok(t.powi(3) - 3.0 * t)
        }
    }

    impl Gradient for Cubic {
        type Gradient = Vec<f64>;

        fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, std::convert::Infallible> {
            let t = x[0] - 2.0;
            Ok(vec![3.0 * t.powi(2) - 3.0])
        }
    }

    #[test]
    fn satisfies_strong_wolfe_on_quadratic() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![-g[0]];
        let mut ls = MoreThuente::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!(alpha > 0.0);
        let mut x_new = x.clone();
        x_new[0] += alpha * d[0];
        let f_new = p.cost(&x_new).unwrap();
        let g_new = p.gradient(&x_new).unwrap();
        let g0_dot_d = g[0] * d[0];
        let gnew_dot_d = g_new[0] * d[0];
        assert!(
            f_new <= f0 + ls.ftol * alpha * g0_dot_d + 1e-12,
            "Armijo failed"
        );
        assert!(
            gnew_dot_d.abs() <= -ls.gtol * g0_dot_d + 1e-12,
            "Strong curvature failed"
        );
    }

    #[test]
    fn interpolates_to_quadratic_minimizer_when_unit_step_overshoots() {
        // The unit step lands at x = 6, where f equals f0, so sufficient decrease
        // fails. Interpolation on the quadratic then recovers its minimiser
        // (α ≈ 0.5; the modified function shifts it to ≈ 0.4995).
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![6.0];
        let mut ls = MoreThuente::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!(
            (alpha - 0.5).abs() < 1e-2,
            "expected α near 0.5, got {alpha}"
        );
        let mut x_new = x.clone();
        x_new[0] += alpha * d[0];
        let f_new = p.cost(&x_new).unwrap();
        assert!(f_new <= f0 + ls.ftol * alpha * (g[0] * d[0]) + 1e-12);
    }

    #[test]
    fn ascent_direction_returns_zero_step() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let baseline = *p.counts();
        let d = vec![g[0]];
        let mut ls = MoreThuente::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert_eq!(alpha, 0.0);
        assert_eq!(p.counts().cost_evals, baseline.cost_evals);
        assert_eq!(p.counts().gradient_evals, baseline.gradient_evals);
    }

    #[test]
    fn zero_maxfev_returns_zero_step_without_evaluations() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let baseline = *p.counts();
        let d = vec![-g[0]];
        let mut ls = MoreThuente::new().maxfev(0);
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert_eq!(alpha, 0.0);
        assert_eq!(p.counts().cost_evals, baseline.cost_evals);
        assert_eq!(p.counts().gradient_evals, baseline.gradient_evals);
    }

    #[test]
    fn alpha_init_above_stpmax_returns_zero_step() {
        // `alpha_init` keeps its default of 1, which lies outside [0, 0.5].
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let baseline = *p.counts();
        let d = vec![-g[0]];
        let mut ls = MoreThuente::new().stpmax(0.5);
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert_eq!(alpha, 0.0);
        assert_eq!(p.counts().cost_evals, baseline.cost_evals);
        assert_eq!(p.counts().gradient_evals, baseline.gradient_evals);
    }

    #[test]
    fn cubic_satisfies_wolfe_on_nontrivial_function() {
        let mut p = Problem::new(Cubic);
        let x = vec![5.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![-1.0];
        let mut ls = MoreThuente::new().alpha_init(3.0);
        let alpha = LineSearch::<Cubic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!(alpha > 0.0);
        let mut x_new = x.clone();
        x_new[0] += alpha * d[0];
        let f_new = p.cost(&x_new).unwrap();
        let g_new = p.gradient(&x_new).unwrap();
        let g0_dot_d = g[0] * d[0];
        let gnew_dot_d = g_new[0] * d[0];
        assert!(
            f_new <= f0 + ls.ftol * alpha * g0_dot_d + 1e-12,
            "Armijo failed at α={alpha}: f_new={f_new}, threshold={}",
            f0 + ls.ftol * alpha * g0_dot_d,
        );
        assert!(
            gnew_dot_d.abs() <= -ls.gtol * g0_dot_d + 1e-12,
            "Strong curvature failed at α={alpha}: |g·d|={}, threshold={}",
            gnew_dot_d.abs(),
            -ls.gtol * g0_dot_d,
        );
    }

    #[test]
    fn respects_stpmax_when_minimum_is_beyond() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![6.0];
        let mut ls = MoreThuente::new().stpmax(0.1).alpha_init(0.1);
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!((alpha - 0.1).abs() < 1e-12, "expected α=0.1, got {alpha}");
    }
}