//! Backtracking Armijo line search for `echidna_optim`.
use num_traits::Float;

use crate::convergence::dot;
use crate::objective::Objective;

/// Parameters for the backtracking Armijo line search.
///
/// `Default` implementations are provided for `f64` and `f32`, with an
/// `alpha_min` floor matched to each precision.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArmijoParams<F> {
    /// Sufficient decrease parameter (default: 1e-4).
    pub c: F,
    /// Backtracking factor (default: 0.5).
    pub rho: F,
    /// Initial step size (default: 1.0).
    pub alpha_init: F,
    /// Minimum step size before declaring failure
    /// (default: 1e-16 for `f64`, 1e-8 for `f32`).
    pub alpha_min: F,
}

19impl Default for ArmijoParams<f64> {
20    fn default() -> Self {
21        ArmijoParams {
22            c: 1e-4,
23            rho: 0.5,
24            alpha_init: 1.0,
25            alpha_min: 1e-16,
26        }
27    }
28}
29
30impl Default for ArmijoParams<f32> {
31    fn default() -> Self {
32        ArmijoParams {
33            c: 1e-4,
34            rho: 0.5,
35            alpha_init: 1.0,
36            alpha_min: 1e-8,
37        }
38    }
39}
40
/// Result of a successful line search.
#[derive(Debug, Clone, PartialEq)]
pub struct LineSearchResult<F> {
    /// The accepted step size.
    pub alpha: F,
    /// Objective value at `x + alpha * d`.
    pub value: F,
    /// Gradient at `x + alpha * d`.
    pub gradient: Vec<F>,
    /// Number of function evaluations used.
    pub evals: usize,
}

54/// Backtracking line search satisfying the Armijo (sufficient decrease) condition.
55///
56/// Searches for `alpha` such that `f(x + alpha*d) <= f(x) + c * alpha * g^T d`.
57///
58/// Returns `None` if `alpha` falls below `alpha_min` (line search failure),
59/// which includes the case where every trial point returned a non-finite
60/// objective value or gradient — treated as infeasible and backtracked past.
61pub fn backtracking_armijo<F: Float, O: Objective<F>>(
62    obj: &mut O,
63    x: &[F],
64    d: &[F],
65    f_x: F,
66    grad_x: &[F],
67    params: &ArmijoParams<F>,
68) -> Option<LineSearchResult<F>> {
69    let n = x.len();
70    let dg = dot(grad_x, d);
71
72    // Not a descent direction — caller should handle this
73    if dg >= F::zero() {
74        return None;
75    }
76
77    let mut alpha = params.alpha_init;
78    let mut x_new = vec![F::zero(); n];
79    let mut evals = 0;
80
81    loop {
82        if alpha < params.alpha_min {
83            return None;
84        }
85
86        for i in 0..n {
87            x_new[i] = x[i] + alpha * d[i];
88        }
89
90        let (f_new, g_new) = obj.eval_grad(&x_new);
91        evals += 1;
92
93        // Reject infeasible trial points: `-Inf <= anything` is trivially
94        // true, so the Armijo check would accept a step off the domain and
95        // the solver would walk toward -Inf indefinitely. A NaN `f_new`
96        // falls through (`NaN <= x` is false) but is rejected here for
97        // symmetry. Either case is treated as "backtrack past this α".
98        if !f_new.is_finite() || !g_new.iter().all(|g| g.is_finite()) {
99            alpha = alpha * params.rho;
100            continue;
101        }
102
103        // Armijo condition: f(x + alpha*d) <= f(x) + c * alpha * g^T d
104        if f_new <= f_x + params.c * alpha * dg {
105            return Some(LineSearchResult {
106                alpha,
107                value: f_new,
108                gradient: g_new,
109                evals,
110            });
111        }
112
113        alpha = alpha * params.rho;
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple quadratic objective for testing: f(x) = 0.5 * (x0^2 + x1^2)
    struct Quadratic;

    impl Objective<f64> for Quadratic {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            // For f = 0.5 * ||x||^2 the gradient is x itself.
            let value = 0.5 * (x[0] * x[0] + x[1] * x[1]);
            (value, vec![x[0], x[1]])
        }
    }

    #[test]
    fn armijo_quadratic_descent() {
        let mut quad = Quadratic;
        let start = vec![2.0, 3.0];
        let (f0, g0) = quad.eval_grad(&start);
        // Search along the steepest-descent direction.
        let dir: Vec<f64> = g0.iter().map(|g| -g).collect();

        let res = backtracking_armijo(&mut quad, &start, &dir, f0, &g0, &ArmijoParams::default())
            .unwrap();

        assert!(res.alpha > 0.0);
        assert!(res.value < f0, "line search should decrease objective");
    }

    #[test]
    fn armijo_full_step_on_quadratic() {
        let mut quad = Quadratic;
        let start = vec![2.0, 3.0];
        let (f0, g0) = quad.eval_grad(&start);
        let dir: Vec<f64> = g0.iter().map(|g| -g).collect();

        let res = backtracking_armijo(&mut quad, &start, &dir, f0, &g0, &ArmijoParams::default())
            .unwrap();

        // For a quadratic, steepest descent with alpha=1 satisfies Armijo with c=1e-4
        assert!(
            (res.alpha - 1.0).abs() < 1e-12,
            "full step should be accepted on quadratic, got alpha={}",
            res.alpha
        );
    }

    #[test]
    fn armijo_non_descent_returns_none() {
        let mut quad = Quadratic;
        let start = vec![2.0, 3.0];
        let (f0, g0) = quad.eval_grad(&start);
        // The gradient itself is an ascent direction.
        let up = g0.clone();

        assert!(
            backtracking_armijo(&mut quad, &start, &up, f0, &g0, &ArmijoParams::default())
                .is_none()
        );
    }
}