Skip to main content

numra_nonlinear/
line_search.rs

//! Wolfe line search for globalized optimization.
//!
//! Implements the strong Wolfe line search of Nocedal & Wright, *Numerical
//! Optimization* (Algorithms 3.5 and 3.6): a bracket phase that finds an
//! interval containing a step satisfying the strong Wolfe conditions,
//! followed by a zoom (bisection) phase that narrows the interval.
//!
//! Author: Moussa Leblouba
//! Date: 5 March 2026
//! Modified: 2 May 2026
11
12use numra_core::Scalar;
13
/// Errors that can occur during line search.
///
/// The variants carry no payload, so values can be compared directly with
/// `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSearchError {
    /// The search direction is not a descent direction.
    NotDescentDirection,
    /// The zoom bracket became narrower than the collapse tolerance.
    BracketCollapsed,
    /// Maximum iterations reached without satisfying Wolfe conditions.
    MaxIterations,
}

impl std::fmt::Display for LineSearchError {
    /// Writes a short, lowercase description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NotDescentDirection => "search direction is not a descent direction",
            Self::BracketCollapsed => "zoom bracket collapsed",
            Self::MaxIterations => "max line search iterations reached",
        };
        f.write_str(message)
    }
}

impl std::error::Error for LineSearchError {}
36
37/// Options controlling the Wolfe line search.
38#[derive(Debug, Clone)]
39pub struct WolfeOptions<S: Scalar> {
40    /// Sufficient-decrease (Armijo) parameter, typically 1e-4.
41    pub c1: S,
42    /// Curvature parameter, typically 0.9 for quasi-Newton, 0.1 for CG.
43    pub c2: S,
44    /// Maximum allowed step length.
45    pub max_step: S,
46    /// Maximum number of bracket-phase iterations.
47    pub max_iter: usize,
48}
49
50impl<S: Scalar> Default for WolfeOptions<S> {
51    fn default() -> Self {
52        Self {
53            c1: S::from_f64(1e-4),
54            c2: S::from_f64(0.9),
55            max_step: S::from_f64(1e20),
56            max_iter: 40,
57        }
58    }
59}
60
61/// Result returned by [`wolfe_line_search`].
62#[derive(Debug, Clone)]
63pub struct LineSearchResult<S: Scalar> {
64    /// Accepted step length.
65    pub step: S,
66    /// Function value at the accepted point.
67    pub f_new: S,
68    /// Total number of function evaluations performed.
69    pub n_eval: usize,
70}
71
72/// Inner product of two slices.
73fn dot<S: Scalar>(a: &[S], b: &[S]) -> S {
74    a.iter()
75        .zip(b.iter())
76        .map(|(&ai, &bi)| ai * bi)
77        .fold(S::ZERO, |acc, x| acc + x)
78}
79
80/// Bracket collapse tolerance for detecting degenerate intervals.
81const BRACKET_COLLAPSE_TOL: f64 = 1e-16;
82
83/// Zoom phase (Algorithm 3.6, lines 12-21 in Nocedal & Wright).
84///
85/// Bisects `[alpha_lo, alpha_hi]` until the strong Wolfe conditions are met.
86#[allow(clippy::too_many_arguments)]
87fn zoom<S, F, G>(
88    f: &F,
89    grad: &G,
90    x: &[S],
91    d: &[S],
92    f0: S,
93    dg0: S,
94    mut alpha_lo: S,
95    mut f_lo: S,
96    mut alpha_hi: S,
97    opts: &WolfeOptions<S>,
98    n_eval: &mut usize,
99) -> Result<LineSearchResult<S>, LineSearchError>
100where
101    S: Scalar,
102    F: Fn(&[S]) -> S,
103    G: Fn(&[S], &mut [S]),
104{
105    let n = x.len();
106    let mut x_trial = vec![S::ZERO; n];
107    let mut g_trial = vec![S::ZERO; n];
108
109    for _ in 0..opts.max_iter {
110        if (alpha_hi - alpha_lo).abs() < S::from_f64(BRACKET_COLLAPSE_TOL) {
111            return Err(LineSearchError::BracketCollapsed);
112        }
113
114        let alpha_j = (alpha_lo + alpha_hi) / S::TWO;
115
116        // Evaluate f at trial point
117        for i in 0..n {
118            x_trial[i] = x[i] + alpha_j * d[i];
119        }
120        let f_j = f(&x_trial);
121        *n_eval += 1;
122
123        if f_j > f0 + opts.c1 * alpha_j * dg0 || f_j >= f_lo {
124            // Armijo fails or no improvement over lo — shrink from the right
125            alpha_hi = alpha_j;
126        } else {
127            // Armijo holds — check curvature
128            grad(&x_trial, &mut g_trial);
129            let dg_j = dot(&g_trial, d);
130
131            if dg_j.abs() <= -opts.c2 * dg0 {
132                // Strong Wolfe satisfied
133                return Ok(LineSearchResult {
134                    step: alpha_j,
135                    f_new: f_j,
136                    n_eval: *n_eval,
137                });
138            }
139
140            if dg_j * (alpha_hi - alpha_lo) >= S::ZERO {
141                alpha_hi = alpha_lo;
142            }
143
144            alpha_lo = alpha_j;
145            f_lo = f_j;
146        }
147    }
148
149    // Exhausted iterations — return the best we found
150    Err(LineSearchError::MaxIterations)
151}
152
153/// Wolfe line search (Nocedal & Wright Algorithm 3.6).
154///
155/// Given a current point `x`, a descent direction `d`, the current function
156/// value `f0`, and the current gradient `g0`, find a step length `alpha`
157/// such that the **strong Wolfe conditions** hold:
158///
159/// 1. Sufficient decrease (Armijo): `f(x + alpha*d) <= f0 + c1 * alpha * g0^T d`
160/// 2. Curvature: `|grad f(x + alpha*d)^T d| <= c2 * |g0^T d|`
161///
162/// # Errors
163///
164/// Returns `Err` if `d` is not a descent direction (`g0^T d >= 0`) or if
165/// the algorithm fails to find a satisfactory step within `max_iter`
166/// iterations.
167pub fn wolfe_line_search<S, F, G>(
168    f: F,
169    grad: G,
170    x: &[S],
171    d: &[S],
172    f0: S,
173    g0: &[S],
174    opts: &WolfeOptions<S>,
175) -> Result<LineSearchResult<S>, LineSearchError>
176where
177    S: Scalar,
178    F: Fn(&[S]) -> S,
179    G: Fn(&[S], &mut [S]),
180{
181    let dg0 = dot(g0, d);
182    if dg0 >= S::ZERO {
183        return Err(LineSearchError::NotDescentDirection);
184    }
185
186    let n = x.len();
187    let mut x_trial = vec![S::ZERO; n];
188    let mut g_trial = vec![S::ZERO; n];
189
190    let mut alpha_prev = S::ZERO;
191    let mut f_prev = f0;
192    let mut alpha = S::ONE;
193    let mut n_eval: usize = 0;
194
195    for i in 1..=opts.max_iter {
196        // Clamp to max_step
197        if alpha > opts.max_step {
198            alpha = opts.max_step;
199        }
200
201        // Evaluate f at trial point x + alpha * d
202        for j in 0..n {
203            x_trial[j] = x[j] + alpha * d[j];
204        }
205        let f_alpha = f(&x_trial);
206        n_eval += 1;
207
208        // Check Armijo condition or if function increased relative to previous step
209        if f_alpha > f0 + opts.c1 * alpha * dg0 || (i > 1 && f_alpha >= f_prev) {
210            return zoom(
211                &f,
212                &grad,
213                x,
214                d,
215                f0,
216                dg0,
217                alpha_prev,
218                f_prev,
219                alpha,
220                opts,
221                &mut n_eval,
222            );
223        }
224
225        // Armijo holds — check curvature (strong Wolfe)
226        grad(&x_trial, &mut g_trial);
227        let dg_alpha = dot(&g_trial, d);
228
229        if dg_alpha.abs() <= -opts.c2 * dg0 {
230            return Ok(LineSearchResult {
231                step: alpha,
232                f_new: f_alpha,
233                n_eval,
234            });
235        }
236
237        // If slope is positive, we overshot the minimum — zoom backwards
238        if dg_alpha >= S::ZERO {
239            return zoom(
240                &f,
241                &grad,
242                x,
243                d,
244                f0,
245                dg0,
246                alpha,
247                f_alpha,
248                alpha_prev,
249                opts,
250                &mut n_eval,
251            );
252        }
253
254        // Otherwise expand the step
255        alpha_prev = alpha;
256        f_prev = f_alpha;
257        alpha *= S::TWO;
258    }
259
260    Err(LineSearchError::MaxIterations)
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `res` reports the true objective value at the accepted
    /// point and that the accepted step satisfies both strong Wolfe
    /// conditions for the problem `(f, grad, x, d, g0)`.
    fn assert_strong_wolfe<F, G>(
        f: F,
        grad: G,
        x: &[f64],
        d: &[f64],
        g0: &[f64],
        opts: &WolfeOptions<f64>,
        res: &LineSearchResult<f64>,
    ) where
        F: Fn(&[f64]) -> f64,
        G: Fn(&[f64], &mut [f64]),
    {
        let slope0 = dot(g0, d);
        let x_new: Vec<f64> = x
            .iter()
            .zip(d)
            .map(|(&xi, &di)| xi + res.step * di)
            .collect();
        let mut g_new = vec![0.0; x.len()];
        grad(&x_new, &mut g_new);
        let slope_new = dot(&g_new, d);

        assert_eq!(
            res.f_new,
            f(&x_new),
            "f_new must be the objective at the accepted point"
        );
        assert!(
            res.f_new <= f(x) + opts.c1 * res.step * slope0,
            "sufficient decrease violated at step {}",
            res.step
        );
        assert!(
            slope_new.abs() <= -opts.c2 * slope0,
            "curvature condition violated at step {}: slope {} vs initial {}",
            res.step,
            slope_new,
            slope0
        );
    }

    #[test]
    fn test_wolfe_quadratic() {
        // f(x) = x^2, grad = 2x
        let f = |x: &[f64]| x[0] * x[0];
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0];
        let d = [-1.0];
        let f0 = f(&x);
        let g0 = [4.0]; // grad at x=2

        let opts = WolfeOptions::<f64>::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(
            res.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            res.f_new,
            f0
        );
        assert!(res.n_eval >= 1, "at least one evaluation is required");
        assert_strong_wolfe(f, grad, &x, &d, &g0, &opts, &res);
    }

    #[test]
    fn test_wolfe_rosenbrock() {
        // Rosenbrock: f(x) = (1-x0)^2 + 100*(x1 - x0^2)^2
        let f = |x: &[f64]| {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            a * a + 100.0 * b * b
        };
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
            g[1] = 200.0 * (x[1] - x[0] * x[0]);
        };

        let x = [-1.0, 1.0];
        let f0 = f(&x);
        let mut g0 = [0.0; 2];
        grad(&x, &mut g0);

        // Steepest descent direction
        let d = [-g0[0], -g0[1]];

        let opts = WolfeOptions::<f64>::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(
            res.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            res.f_new,
            f0
        );
        assert_strong_wolfe(f, grad, &x, &d, &g0, &opts, &res);
    }

    #[test]
    fn test_wolfe_not_descent() {
        let f = |x: &[f64]| x[0] * x[0];
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0];
        let d = [1.0]; // ascending direction
        let f0 = f(&x);
        let g0 = [4.0];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts);

        assert!(res.is_err(), "must reject non-descent direction");
        assert!(
            matches!(res.unwrap_err(), LineSearchError::NotDescentDirection),
            "error should be NotDescentDirection"
        );
    }

    #[test]
    fn test_wolfe_unbounded_reports_max_iterations() {
        // f(x) = -x decreases forever along d = +1 with constant slope -1, so
        // sufficient decrease always holds but the curvature condition never
        // does, and the bracket phase runs out of iterations.
        let f = |x: &[f64]| -x[0];
        let grad = |_x: &[f64], g: &mut [f64]| {
            g[0] = -1.0;
        };

        let x = [0.0];
        let d = [1.0];
        let f0 = f(&x);
        let g0 = [-1.0];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts);

        assert!(
            matches!(res, Err(LineSearchError::MaxIterations)),
            "unbounded descent should exhaust the iteration budget"
        );
    }

    #[test]
    fn test_wolfe_f32() {
        // Test with f32 to verify generic Scalar works
        let f = |x: &[f32]| x[0] * x[0];
        let grad = |x: &[f32], g: &mut [f32]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0f32];
        let d = [-1.0f32];
        let f0 = f(&x);
        let g0 = [4.0f32];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(res.f_new < f0, "function must decrease");
    }
}