use crate::core::math::{Dot, ScaledAdd};
use crate::core::problem::{CostFunction, Problem};
use crate::line_search::LineSearch;
/// Backtracking line search enforcing the Armijo (sufficient-decrease) condition.
///
/// Starting from [`alpha_init`](Self::alpha_init), the step length is multiplied by
/// [`rho`](Self::rho) after every rejected trial until
/// `f(x + α·d) <= f(x) + c·α·(∇f(x) · d)` holds or [`max_iter`](Self::max_iter)
/// trials have been evaluated.
///
/// Construct with [`Backtracking::new`] (or [`Default`]) and adjust the settings with
/// the builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct Backtracking {
    /// Step length evaluated on the first trial.
    pub alpha_init: f64,
    /// Factor the step length is multiplied by after each rejected trial; values in
    /// `(0, 1)` shrink the step.
    pub rho: f64,
    /// Armijo constant scaling the predicted decrease `α·(∇f(x) · d)`.
    pub c: f64,
    /// Maximum number of trial steps (one cost evaluation each) per search.
    pub max_iter: u32,
}
impl Default for Backtracking {
fn default() -> Self {
Self {
alpha_init: 1.0,
rho: 0.5,
c: 1e-4,
max_iter: 50,
}
}
}
impl Backtracking {
    /// Creates a line search with the settings from the [`Default`] impl.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the step length evaluated on the first trial.
    #[must_use]
    pub fn alpha_init(mut self, alpha_init: f64) -> Self {
        self.alpha_init = alpha_init;
        self
    }

    /// Sets the factor the step length is multiplied by after each rejected trial.
    ///
    /// Values in `(0, 1)` shrink the step; the value is not validated here.
    #[must_use]
    pub fn rho(mut self, rho: f64) -> Self {
        self.rho = rho;
        self
    }

    /// Sets the Armijo sufficient-decrease constant `c`.
    #[must_use]
    pub fn c(mut self, c: f64) -> Self {
        self.c = c;
        self
    }

    /// Sets the maximum number of trial steps (one cost evaluation each) per search.
    #[must_use]
    pub fn max_iter(mut self, max_iter: u32) -> Self {
        self.max_iter = max_iter;
        self
    }
}
impl<P, V> LineSearch<P, V> for Backtracking
where
    P: CostFunction<Param = V, Output = f64>,
    V: ScaledAdd<f64> + Dot + Clone,
{
    type Error = P::Error;

    /// Searches along `direction` for a step length satisfying the Armijo condition.
    ///
    /// Trial steps `α = alpha_init, alpha_init·rho, alpha_init·rho², …` are evaluated
    /// at `param + α·direction`. The first one with
    /// `f(param + α·direction) <= cost + c·α·(gradient · direction)` is returned.
    ///
    /// `cost` and `gradient` are used as given (they are not recomputed); they are
    /// expected to be the cost and gradient at `param`.
    ///
    /// If none of the `max_iter` trials is accepted, the step is shrunk once more and
    /// `alpha_init·rho^max_iter` is returned without being evaluated. Callers therefore
    /// cannot assume the returned step satisfies the Armijo condition. With
    /// `max_iter == 0` this is `alpha_init`, and no cost is evaluated.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by the problem's cost function.
    fn next(
        &mut self,
        problem: &mut Problem<P>,
        param: &V,
        cost: f64,
        gradient: &V,
        direction: &V,
    ) -> Result<f64, Self::Error> {
        // Directional derivative ∇f(x)·d; negative when `direction` is a descent direction.
        let slope = gradient.dot(direction);
        let mut alpha = self.alpha_init;
        // Single trial buffer, reset from `param` after each rejection so that a new
        // `V` is not allocated for every trial.
        let mut trial = param.clone();
        for _ in 0..self.max_iter {
            trial.scaled_add(alpha, direction);
            let trial_cost = problem.cost(&trial)?;
            if trial_cost <= cost + self.c * alpha * slope {
                return Ok(alpha);
            }
            alpha *= self.rho;
            trial.clone_from(param);
        }
        Ok(alpha)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional quadratic `f(x) = (x₀ − 3)²`, minimised at `x₀ = 3`.
    struct Quadratic;

    impl CostFunction for Quadratic {
        type Param = Vec<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
            Ok((x[0] - 3.0).powi(2))
        }
    }

    /// Runs one search of `ls` on [`Quadratic`] from `x`.
    ///
    /// Returns the step it produced and the number of cost evaluations made by the
    /// search itself; the initial evaluation of `f(x)` is excluded.
    fn run(ls: &mut Backtracking, x: &[f64], grad: &[f64], dir: &[f64]) -> (f64, u64) {
        let mut p = Problem::new(Quadratic);
        let x = x.to_vec();
        let f0 = p.cost(&x).unwrap();
        let baseline = p.counts().cost_evals;
        let g = grad.to_vec();
        let d = dir.to_vec();
        let alpha = ls.next(&mut p, &x, f0, &g, &d).unwrap();
        (alpha, p.counts().cost_evals - baseline)
    }

    #[test]
    fn accepts_alpha_init_when_armijo_holds() {
        let mut ls = Backtracking::new().alpha_init(0.5);
        let (alpha, cost_evals) = run(&mut ls, &[2.0], &[-2.0], &[1.0]);
        assert_eq!(alpha, 0.5, "expected α_init accepted on first try");
        assert_eq!(cost_evals, 1);
    }

    #[test]
    fn backtracks_when_initial_alpha_overshoots() {
        let mut ls = Backtracking::new();
        let (alpha, cost_evals) = run(&mut ls, &[0.0], &[-6.0], &[6.0]);
        // From x = 0 along d = 6 the trial point is 6α.
        let f0 = 9.0;
        let f_new = (alpha * 6.0 - 3.0).powi(2);
        let g_dot_d = (-6.0_f64) * 6.0;
        // Check against the configured constant rather than repeating its default value.
        let threshold = f0 + ls.c * alpha * g_dot_d;
        assert!(
            f_new <= threshold,
            "Armijo violated: f_new={f_new}, threshold={threshold}",
        );
        assert!(alpha < 1.0, "expected backtrack, got α={alpha}");
        assert!(cost_evals > 1);
    }

    #[test]
    fn reports_cost_eval_count() {
        let mut ls = Backtracking::new().rho(0.5);
        let (_, cost_evals) = run(&mut ls, &[0.0], &[-6.0], &[6.0]);
        assert!(cost_evals >= 1);
        assert!(
            cost_evals <= u64::from(ls.max_iter),
            "cost_evals={cost_evals} exceeds max_iter={}",
            ls.max_iter
        );
    }

    #[test]
    fn caps_at_max_iter_when_armijo_never_holds() {
        // ∇f·d = (-6)·(-6) = 36 > 0: an ascent direction, so every trial is rejected.
        let mut ls = Backtracking::new().max_iter(5);
        let (alpha, cost_evals) = run(&mut ls, &[0.0], &[-6.0], &[-6.0]);
        assert_eq!(cost_evals, 5);
        // Trials 1, 1/2, …, 1/16 are rejected; the step is then halved once more.
        assert!(
            (alpha - 1.0 / 32.0).abs() < 1e-12,
            "expected α=1/32, got {alpha}",
        );
    }
}