use crate::core::math::{Dot, ScaledAdd};
use crate::core::problem::{CostFunction, Gradient, Problem};
use crate::line_search::LineSearch;
/// Configuration of the Moré–Thuente line search.
///
/// The search looks for a step `α` along a descent direction that satisfies the strong
/// Wolfe conditions, with `ftol` controlling sufficient decrease and `gtol` controlling
/// curvature. Build it with [`MoreThuente::new`] (same as [`Default`]) and the builder
/// methods below.
#[derive(Debug, Clone, PartialEq)]
pub struct MoreThuente {
    /// Sufficient-decrease tolerance: accept `φ(α) ≤ φ(0) + ftol · α · φ'(0)`. Default `1e-3`.
    pub ftol: f64,
    /// Curvature tolerance: accept `|φ'(α)| ≤ gtol · |φ'(0)|`. Default `0.9`.
    pub gtol: f64,
    /// Relative tolerance on the width of the bracketing interval: once a minimizer is
    /// bracketed, the search stops when `max − min ≤ xtol · max`. Default `0.1`.
    pub xtol: f64,
    /// First trial step. If it lies outside `[stpmin, stpmax]` the search returns `0`.
    /// Default `1.0`.
    pub alpha_init: f64,
    /// Smallest step the search may try. Default `0.0`.
    pub stpmin: f64,
    /// Largest step the search may try. Default `1e10`.
    pub stpmax: f64,
    /// Maximum number of trial evaluations (each evaluates cost and gradient). Default `20`.
    pub maxfev: u32,
}

impl Default for MoreThuente {
    fn default() -> Self {
        Self {
            ftol: 1.0e-3,
            gtol: 0.9,
            xtol: 0.1,
            alpha_init: 1.0,
            stpmin: 0.0,
            stpmax: 1.0e10,
            maxfev: 20,
        }
    }
}

impl MoreThuente {
    /// Creates a line search with the default parameters (see [`Default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the sufficient-decrease tolerance.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < ftol < 1` (a NaN is rejected too).
    #[must_use]
    pub fn ftol(mut self, ftol: f64) -> Self {
        assert!(0.0 < ftol && ftol < 1.0, "ftol must be in (0, 1)");
        self.ftol = ftol;
        self
    }

    /// Sets the curvature tolerance.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < gtol < 1` (a NaN is rejected too).
    #[must_use]
    pub fn gtol(mut self, gtol: f64) -> Self {
        assert!(0.0 < gtol && gtol < 1.0, "gtol must be in (0, 1)");
        self.gtol = gtol;
        self
    }

    /// Sets the relative bracket-width tolerance.
    ///
    /// # Panics
    ///
    /// Panics if `xtol` is negative or NaN.
    #[must_use]
    pub fn xtol(mut self, xtol: f64) -> Self {
        assert!(xtol >= 0.0, "xtol must be ≥ 0");
        self.xtol = xtol;
        self
    }

    /// Sets the first trial step.
    ///
    /// # Panics
    ///
    /// Panics unless `alpha_init > 0` (a NaN is rejected too).
    #[must_use]
    pub fn alpha_init(mut self, alpha_init: f64) -> Self {
        assert!(alpha_init > 0.0, "alpha_init must be > 0");
        self.alpha_init = alpha_init;
        self
    }

    /// Sets the lower step bound.
    ///
    /// # Panics
    ///
    /// Panics if `stpmin` is negative or NaN.
    #[must_use]
    pub fn stpmin(mut self, stpmin: f64) -> Self {
        assert!(stpmin >= 0.0, "stpmin must be ≥ 0");
        self.stpmin = stpmin;
        self
    }

    /// Sets the upper step bound.
    ///
    /// # Panics
    ///
    /// Panics unless `stpmax > 0` (a NaN is rejected too).
    #[must_use]
    pub fn stpmax(mut self, stpmax: f64) -> Self {
        assert!(stpmax > 0.0, "stpmax must be > 0");
        self.stpmax = stpmax;
        self
    }

    /// Sets the maximum number of trial evaluations per search.
    #[must_use]
    pub fn maxfev(mut self, maxfev: u32) -> Self {
        self.maxfev = maxfev;
        self
    }
}
/// Moré–Thuente line search, following the structure of MINPACK-2's `dcsrch`.
///
/// With `φ(α) = cost(param + α · direction)`, the search looks for a step `α` satisfying
/// the strong Wolfe conditions:
///
/// * sufficient decrease: `φ(α) ≤ φ(0) + ftol · α · φ'(0)`
/// * curvature: `|φ'(α)| ≤ gtol · |φ'(0)|`
impl<P, V> LineSearch<P, V> for MoreThuente
where
    P: CostFunction<Param = V, Output = f64> + Gradient<Gradient = V>,
    V: ScaledAdd<f64> + Dot + Clone,
{
    type Error = P::Error;

    /// Runs one line search from `param` along `direction` and returns the step length.
    ///
    /// `cost` and `gradient` are the objective value and gradient at `param`.
    ///
    /// Returns:
    /// * `0.0` without evaluating the problem when `direction` is not a descent direction
    ///   (`gradient · direction` is `≥ 0` or not finite), or when `alpha_init` is outside
    ///   `[stpmin, stpmax]`;
    /// * the current trial step once the strong Wolfe conditions hold, or when a
    ///   termination safeguard fires. The safeguards are: rounding prevents progress, the
    ///   bracket is narrower than `xtol`, or the step is stuck at `stpmax` / `stpmin`;
    /// * otherwise, after `maxfev` trial evaluations, the best step found so far.
    ///
    /// Trial points where the cost or the directional derivative is not finite are never
    /// used for interpolation. Instead, the step is pulled halfway back toward the best
    /// step found so far.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the problem's cost or gradient evaluation.
    fn next(
        &mut self,
        problem: &mut Problem<P>,
        param: &V,
        cost: f64,
        gradient: &V,
        direction: &V,
    ) -> Result<f64, Self::Error> {
        let finit = cost;
        let ginit = gradient.dot(direction);
        // Not a descent direction: no positive step can decrease the cost.
        if !ginit.is_finite() || ginit >= 0.0 {
            return Ok(0.0);
        }
        // Negated conjunction so that a NaN `alpha_init` is rejected as well.
        if !(self.alpha_init >= self.stpmin && self.alpha_init <= self.stpmax) {
            return Ok(0.0);
        }

        // `brackt`: whether the interval between `stx` and `sty` is known to contain a
        // minimizer. `stage`: 1 until a step with sufficient decrease and a non-negative
        // directional derivative has been seen, 2 afterwards.
        let mut brackt = false;
        let mut stage: u8 = 1;
        let gtest = self.ftol * ginit;
        // Bracket widths of the last two bracketing iterations, used to force bisection.
        let mut width = self.stpmax - self.stpmin;
        let mut width1 = width / P5;

        // (stx, fx, gx): best step found so far; (sty, fy, gy): other interval endpoint.
        let mut stx = 0.0;
        let mut fx = finit;
        let mut gx = ginit;
        let mut sty = 0.0;
        let mut fy = finit;
        let mut gy = ginit;
        // Range in which the current trial step is expected to lie.
        let mut stmin = 0.0;
        let mut stmax = self.alpha_init + XTRAPU * self.alpha_init;
        let mut stp = self.alpha_init;

        for _ in 0..self.maxfev {
            let mut trial = param.clone();
            trial.scaled_add(stp, direction);
            let f = problem.cost(&trial)?;
            let g = problem.gradient(&trial)?.dot(direction);

            // Only finite values may reach `dcstep`. A NaN/inf there would poison the
            // interpolation state (e.g. `fx` becoming NaN), and the search could end on a
            // point where the objective is undefined. Treat the trial as "too far" and
            // retreat halfway toward the best step found so far.
            if !(f.is_finite() && g.is_finite()) {
                let retreat = (stx + P5 * (stp - stx)).max(self.stpmin).min(self.stpmax);
                if retreat == stp {
                    // The step cannot shrink any further: settle for the best step.
                    return Ok(stx);
                }
                stp = retreat;
                continue;
            }

            let ftest = finit + stp * gtest;
            if stage == 1 && f <= ftest && g >= 0.0 {
                stage = 2;
            }

            // Termination: MINPACK's safeguard warnings, then strong Wolfe convergence.
            let warn_rounding = brackt && (stp <= stmin || stp >= stmax);
            let warn_xtol = brackt && stmax - stmin <= self.xtol * stmax;
            let warn_stpmax = stp == self.stpmax && f <= ftest && g <= gtest;
            let warn_stpmin = stp == self.stpmin && (f > ftest || g >= gtest);
            let converged = f <= ftest && g.abs() <= self.gtol * (-ginit);
            if warn_rounding || warn_xtol || warn_stpmax || warn_stpmin || converged {
                return Ok(stp);
            }

            if stage == 1 && f <= fx && f > ftest {
                // Lower cost but no sufficient decrease yet: interpolate on the modified
                // function ψ(α) = φ(α) − α·gtest, then map the endpoints back to φ.
                let fm = f - stp * gtest;
                let mut fxm = fx - stx * gtest;
                let mut fym = fy - sty * gtest;
                let gm = g - gtest;
                let mut gxm = gx - gtest;
                let mut gym = gy - gtest;
                dcstep(
                    &mut stx,
                    &mut fxm,
                    &mut gxm,
                    &mut sty,
                    &mut fym,
                    &mut gym,
                    &mut stp,
                    fm,
                    gm,
                    &mut brackt,
                    stmin,
                    stmax,
                );
                fx = fxm + stx * gtest;
                fy = fym + sty * gtest;
                gx = gxm + gtest;
                gy = gym + gtest;
            } else {
                dcstep(
                    &mut stx,
                    &mut fx,
                    &mut gx,
                    &mut sty,
                    &mut fy,
                    &mut gy,
                    &mut stp,
                    f,
                    g,
                    &mut brackt,
                    stmin,
                    stmax,
                );
            }

            if brackt {
                // Force a bisection if the bracket has not shrunk enough recently.
                if (sty - stx).abs() >= P66 * width1 {
                    stp = stx + P5 * (sty - stx);
                }
                width1 = width;
                width = (sty - stx).abs();
                stmin = stx.min(sty);
                stmax = stx.max(sty);
            } else {
                // Still extrapolating: allow the next step inside a growing window.
                stmin = stp + XTRAPL * (stp - stx);
                stmax = stp + XTRAPU * (stp - stx);
            }

            // `max`/`min` (not `clamp`) so that a NaN step collapses to `stpmin`.
            stp = stp.max(self.stpmin).min(self.stpmax);
            // If no further progress is possible, fall back to the best step so far.
            if brackt && (stp <= stmin || stp >= stmax || stmax - stmin <= self.xtol * stmax) {
                stp = stx;
            }
        }
        Ok(stx)
    }
}
/// Lower extrapolation factor: while no minimizer is bracketed, the step window starts at
/// `stp + XTRAPL · (stp − stx)`.
const XTRAPL: f64 = 1.1;
/// Upper extrapolation factor: while no minimizer is bracketed, the step window ends at
/// `stp + XTRAPU · (stp − stx)`; it also sizes the window around the first step.
const XTRAPU: f64 = 4.0;
/// One half, used for bisecting the bracket.
const P5: f64 = 0.5;
/// Shrink factor: a bisection is forced when the bracket has not dropped below this
/// fraction of an earlier width; `dcstep` also uses it to cap steps toward `sty`.
const P66: f64 = 0.66;
/// Safeguarded step update for the Moré–Thuente line search (MINPACK-2 `dcstep`).
///
/// Inputs:
/// * `stx`, `fx`, `dx`: the best step so far, with its value and derivative.
/// * `sty`, `fy`, `dy`: the other endpoint of the search interval.
/// * `stp`, `fp`, `dp`: the current trial step, with its value and derivative.
///
/// The function:
///
/// 1. picks a new trial step from cubic / quadratic / secant interpolants, using one of
///    four cases below;
/// 2. moves the interval endpoints `(stx, sty)` together with their values and
///    derivatives; and
/// 3. writes the new trial step into `stp`.
///
/// `brackt` is set to `true` in cases 1 and 2, where a minimizer is known to be bracketed.
/// `stpmin`/`stpmax` bound extrapolated steps while no minimizer is bracketed. The caller
/// passes its local window (`stmin`/`stmax`) here, not the user-configured bounds.
///
/// NOTE(review): if `*dx == 0.0`, `sgnd` is NaN, so every `sgnd < 0.0` test below is false.
// The long parameter list mirrors MINPACK's `dcstep` argument-for-argument.
#[allow(clippy::too_many_arguments)]
fn dcstep(
    stx: &mut f64,
    fx: &mut f64,
    dx: &mut f64,
    sty: &mut f64,
    fy: &mut f64,
    dy: &mut f64,
    stp: &mut f64,
    fp: f64,
    dp: f64,
    brackt: &mut bool,
    stpmin: f64,
    stpmax: f64,
) {
    // Negative when the derivatives at `stx` and `stp` have opposite signs.
    let sgnd = dp * (*dx / dx.abs());
    let stpf;
    if fp > *fx {
        // Case 1: higher function value, so the minimizer is bracketed. Take the cubic
        // step if it is closer to `stx` than the quadratic step; otherwise take their
        // average.
        let theta = 3.0 * (*fx - fp) / (*stp - *stx) + *dx + dp;
        // Scale by the largest magnitude to avoid overflow in the square root.
        let s = theta.abs().max(dx.abs()).max(dp.abs());
        let mut gamma = s * ((theta / s).powi(2) - (*dx / s) * (dp / s)).sqrt();
        if *stp < *stx {
            gamma = -gamma;
        }
        let p = (gamma - *dx) + theta;
        let q = ((gamma - *dx) + gamma) + dp;
        let r = p / q;
        let stpc = *stx + r * (*stp - *stx);
        let stpq = *stx + ((*dx / ((*fx - fp) / (*stp - *stx) + *dx)) / 2.0) * (*stp - *stx);
        stpf = if (stpc - *stx).abs() < (stpq - *stx).abs() {
            stpc
        } else {
            stpc + (stpq - stpc) / 2.0
        };
        *brackt = true;
    } else if sgnd < 0.0 {
        // Case 2: lower function value and derivatives of opposite sign, so the minimizer
        // is bracketed. Take the cubic step if it is farther from `stp` than the secant
        // step; otherwise take the secant step.
        let theta = 3.0 * (*fx - fp) / (*stp - *stx) + *dx + dp;
        let s = theta.abs().max(dx.abs()).max(dp.abs());
        let mut gamma = s * ((theta / s).powi(2) - (*dx / s) * (dp / s)).sqrt();
        if *stp > *stx {
            gamma = -gamma;
        }
        let p = (gamma - dp) + theta;
        let q = ((gamma - dp) + gamma) + *dx;
        let r = p / q;
        let stpc = *stp + r * (*stx - *stp);
        let stpq = *stp + (dp / (dp - *dx)) * (*stx - *stp);
        stpf = if (stpc - *stp).abs() > (stpq - *stp).abs() {
            stpc
        } else {
            stpq
        };
        *brackt = true;
    } else if dp.abs() < dx.abs() {
        // Case 3: lower function value, derivatives of the same sign, and the derivative
        // magnitude decreases.
        let theta = 3.0 * (*fx - fp) / (*stp - *stx) + *dx + dp;
        let s = theta.abs().max(dx.abs()).max(dp.abs());
        // The radicand is clamped at zero; gamma == 0 makes the cubic step unusable below.
        let mut gamma = s
            * (0.0_f64)
                .max((theta / s).powi(2) - (*dx / s) * (dp / s))
                .sqrt();
        if *stp > *stx {
            gamma = -gamma;
        }
        let p = (gamma - dp) + theta;
        let q = (gamma + (*dx - dp)) + gamma;
        let r = p / q;
        // Use the cubic step only when `r < 0` and `gamma != 0`. Otherwise fall back to
        // the step bound on the side of `stp` away from `stx`.
        let stpc = if r < 0.0 && gamma != 0.0 {
            *stp + r * (*stx - *stp)
        } else if *stp > *stx {
            stpmax
        } else {
            stpmin
        };
        let stpq = *stp + (dp / (dp - *dx)) * (*stx - *stp);
        stpf = if *brackt {
            // Bracketed: take whichever of cubic/secant is closer to `stp`, but move at
            // most P66 of the way from `stp` toward `sty`.
            let cand = if (stpc - *stp).abs() < (stpq - *stp).abs() {
                stpc
            } else {
                stpq
            };
            if *stp > *stx {
                (*stp + P66 * (*sty - *stp)).min(cand)
            } else {
                (*stp + P66 * (*sty - *stp)).max(cand)
            }
        } else {
            // Not bracketed: take whichever is farther from `stp`, limited to
            // [stpmin, stpmax].
            let cand = if (stpc - *stp).abs() > (stpq - *stp).abs() {
                stpc
            } else {
                stpq
            };
            cand.min(stpmax).max(stpmin)
        };
    } else {
        // Case 4: lower function value, derivatives of the same sign, and the derivative
        // magnitude does not decrease. If bracketed, take the cubic step built from `stp`
        // and `sty`; otherwise jump to the step bound away from `stx`.
        stpf = if *brackt {
            let theta = 3.0 * (fp - *fy) / (*sty - *stp) + *dy + dp;
            let s = theta.abs().max(dy.abs()).max(dp.abs());
            let mut gamma = s * ((theta / s).powi(2) - (*dy / s) * (dp / s)).sqrt();
            if *stp > *sty {
                gamma = -gamma;
            }
            let p = (gamma - dp) + theta;
            let q = ((gamma - dp) + gamma) + *dy;
            let r = p / q;
            *stp + r * (*sty - *stp)
        } else if *stp > *stx {
            stpmax
        } else {
            stpmin
        };
    }
    // Update the interval. A higher value makes `stp` the new `sty`. Otherwise `stp`
    // becomes the new best point `stx`; if the derivatives had opposite signs, the old
    // `stx` first moves to `sty`.
    if fp > *fx {
        *sty = *stp;
        *fy = fp;
        *dy = dp;
    } else {
        if sgnd < 0.0 {
            *sty = *stx;
            *fy = *fx;
            *dy = *dx;
        }
        *stx = *stp;
        *fx = fp;
        *dx = dp;
    }
    *stp = stpf;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// φ(x) = (x − 3)², minimized at x = 3.
    struct Quadratic;

    impl CostFunction for Quadratic {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;
        fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
            Ok((x[0] - 3.0).powi(2))
        }
    }

    impl Gradient for Quadratic {
        type Gradient = Vec<f64>;
        fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, std::convert::Infallible> {
            Ok(vec![2.0 * (x[0] - 3.0)])
        }
    }

    /// φ(x) = t³ − 3t with t = x − 2; local minimum at x = 3.
    struct Cubic;

    impl CostFunction for Cubic {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;
        fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
            let t = x[0] - 2.0;
            Ok(t.powi(3) - 3.0 * t)
        }
    }

    impl Gradient for Cubic {
        type Gradient = Vec<f64>;
        fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, std::convert::Infallible> {
            let t = x[0] - 2.0;
            Ok(vec![3.0 * t.powi(2) - 3.0])
        }
    }

    #[test]
    fn satisfies_strong_wolfe_on_quadratic() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![-g[0]];
        let mut ls = MoreThuente::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!(alpha > 0.0);
        let mut x_new = x.clone();
        x_new[0] += alpha * d[0];
        let f_new = p.cost(&x_new).unwrap();
        let g_new = p.gradient(&x_new).unwrap();
        let g0_dot_d = g[0] * d[0];
        let gnew_dot_d = g_new[0] * d[0];
        assert!(
            f_new <= f0 + ls.ftol * alpha * g0_dot_d + 1e-12,
            "Armijo failed",
        );
        assert!(
            gnew_dot_d.abs() <= -ls.gtol * g0_dot_d + 1e-12,
            "Strong curvature failed",
        );
    }

    /// The unit step overshoots to x = 6, where φ equals φ(0) = 9, so sufficient decrease
    /// fails. One interpolation step lands at the minimizer of the modified function,
    /// α = (36 − 0.036) / 72 = 0.4995, which satisfies the strong Wolfe conditions.
    #[test]
    fn overlong_first_step_interpolates_to_quadratic_minimizer() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![6.0];
        let mut ls = MoreThuente::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!(
            (alpha - 0.5).abs() < 1e-2,
            "expected α near 0.5, got {alpha}",
        );
        let mut x_new = x.clone();
        x_new[0] += alpha * d[0];
        let f_new = p.cost(&x_new).unwrap();
        assert!(f_new <= f0 + ls.ftol * alpha * (g[0] * d[0]) + 1e-12);
    }

    #[test]
    fn ascent_direction_returns_zero_step() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let baseline = *p.counts();
        let d = vec![g[0]];
        let mut ls = MoreThuente::new();
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert_eq!(alpha, 0.0);
        assert_eq!(p.counts().cost_evals, baseline.cost_evals);
        assert_eq!(p.counts().gradient_evals, baseline.gradient_evals);
    }

    #[test]
    fn zero_maxfev_returns_zero_without_evaluating() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let baseline = *p.counts();
        let d = vec![-g[0]];
        let mut ls = MoreThuente::new().maxfev(0);
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert_eq!(alpha, 0.0);
        assert_eq!(p.counts().cost_evals, baseline.cost_evals);
        assert_eq!(p.counts().gradient_evals, baseline.gradient_evals);
    }

    #[test]
    fn alpha_init_outside_bounds_returns_zero() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![-g[0]];
        // Default alpha_init is 1.0, which exceeds stpmax.
        let mut ls = MoreThuente::new().stpmax(0.5);
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert_eq!(alpha, 0.0);
    }

    #[test]
    fn cubic_satisfies_wolfe_on_nontrivial_function() {
        let mut p = Problem::new(Cubic);
        let x = vec![5.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![-1.0];
        let mut ls = MoreThuente::new().alpha_init(3.0);
        let alpha = LineSearch::<Cubic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!(alpha > 0.0);
        let mut x_new = x.clone();
        x_new[0] += alpha * d[0];
        let f_new = p.cost(&x_new).unwrap();
        let g_new = p.gradient(&x_new).unwrap();
        let g0_dot_d = g[0] * d[0];
        let gnew_dot_d = g_new[0] * d[0];
        assert!(
            f_new <= f0 + ls.ftol * alpha * g0_dot_d + 1e-12,
            "Armijo failed at α={alpha}: f_new={f_new}, threshold={}",
            f0 + ls.ftol * alpha * g0_dot_d,
        );
        assert!(
            gnew_dot_d.abs() <= -ls.gtol * g0_dot_d + 1e-12,
            "Strong curvature failed at α={alpha}: |g·d|={}, threshold={}",
            gnew_dot_d.abs(),
            -ls.gtol * g0_dot_d,
        );
    }

    #[test]
    fn respects_stpmax_when_minimum_is_beyond() {
        let mut p = Problem::new(Quadratic);
        let x = vec![0.0];
        let f0 = p.cost(&x).unwrap();
        let g = p.gradient(&x).unwrap();
        let d = vec![6.0];
        let mut ls = MoreThuente::new().stpmax(0.1).alpha_init(0.1);
        let alpha =
            LineSearch::<Quadratic, Vec<f64>>::next(&mut ls, &mut p, &x, f0, &g, &d).unwrap();
        assert!((alpha - 0.1).abs() < 1e-12, "expected α=0.1, got {alpha}",);
    }
}