Skip to main content

numra_nonlinear/
line_search.rs

1//! Wolfe line search for globalized optimization.
2//!
//! Implements the strong Wolfe line search from Nocedal & Wright, *Numerical
//! Optimization*: a bracket phase (Algorithm 3.5) that finds an interval
//! containing a step satisfying the strong Wolfe conditions, followed by a
//! zoom phase (Algorithm 3.6) that narrows the interval, here by bisection.
7//!
8//! Author: Moussa Leblouba
9//! Date: 5 March 2026
10//! Modified: 2 May 2026
11
12use numra_core::{NumraError, Scalar};
13
/// Errors that can occur during line search.
///
/// The variants carry no payload, so values can be compared directly with
/// `==` (for example in tests, or when a caller reacts to one failure mode).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LineSearchError {
    /// The search direction is not a descent direction, i.e. the directional
    /// derivative `g0^T d` is not negative.
    NotDescentDirection,
    /// The zoom bracket shrank to (numerically) zero width before a step was
    /// accepted.
    BracketCollapsed,
    /// Maximum iterations reached without satisfying Wolfe conditions.
    MaxIterations,
}
24
25impl std::fmt::Display for LineSearchError {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        match self {
28            Self::NotDescentDirection => write!(f, "search direction is not a descent direction"),
29            Self::BracketCollapsed => write!(f, "zoom bracket collapsed"),
30            Self::MaxIterations => write!(f, "max line search iterations reached"),
31        }
32    }
33}
34
35impl std::error::Error for LineSearchError {}
36
37impl From<LineSearchError> for NumraError {
38    fn from(e: LineSearchError) -> Self {
39        NumraError::LineSearch(e.to_string())
40    }
41}
42
43/// Options controlling the Wolfe line search.
44#[derive(Debug, Clone)]
45pub struct WolfeOptions<S: Scalar> {
46    /// Sufficient-decrease (Armijo) parameter, typically 1e-4.
47    pub c1: S,
48    /// Curvature parameter, typically 0.9 for quasi-Newton, 0.1 for CG.
49    pub c2: S,
50    /// Maximum allowed step length.
51    pub max_step: S,
52    /// Maximum number of bracket-phase iterations.
53    pub max_iter: usize,
54}
55
56impl<S: Scalar> Default for WolfeOptions<S> {
57    fn default() -> Self {
58        Self {
59            c1: S::from_f64(1e-4),
60            c2: S::from_f64(0.9),
61            max_step: S::from_f64(1e20),
62            max_iter: 40,
63        }
64    }
65}
66
67/// Result returned by [`wolfe_line_search`].
68#[derive(Debug, Clone)]
69pub struct LineSearchResult<S: Scalar> {
70    /// Accepted step length.
71    pub step: S,
72    /// Function value at the accepted point.
73    pub f_new: S,
74    /// Total number of function evaluations performed.
75    pub n_eval: usize,
76}
77
78/// Inner product of two slices.
79fn dot<S: Scalar>(a: &[S], b: &[S]) -> S {
80    a.iter()
81        .zip(b.iter())
82        .map(|(&ai, &bi)| ai * bi)
83        .fold(S::ZERO, |acc, x| acc + x)
84}
85
86/// Bracket collapse tolerance for detecting degenerate intervals.
87const BRACKET_COLLAPSE_TOL: f64 = 1e-16;
88
89/// Zoom phase (Algorithm 3.6, lines 12-21 in Nocedal & Wright).
90///
91/// Bisects `[alpha_lo, alpha_hi]` until the strong Wolfe conditions are met.
92#[allow(clippy::too_many_arguments)]
93fn zoom<S, F, G>(
94    f: &F,
95    grad: &G,
96    x: &[S],
97    d: &[S],
98    f0: S,
99    dg0: S,
100    mut alpha_lo: S,
101    mut f_lo: S,
102    mut alpha_hi: S,
103    opts: &WolfeOptions<S>,
104    n_eval: &mut usize,
105) -> Result<LineSearchResult<S>, LineSearchError>
106where
107    S: Scalar,
108    F: Fn(&[S]) -> S,
109    G: Fn(&[S], &mut [S]),
110{
111    let n = x.len();
112    let mut x_trial = vec![S::ZERO; n];
113    let mut g_trial = vec![S::ZERO; n];
114
115    for _ in 0..opts.max_iter {
116        if (alpha_hi - alpha_lo).abs() < S::from_f64(BRACKET_COLLAPSE_TOL) {
117            return Err(LineSearchError::BracketCollapsed);
118        }
119
120        let alpha_j = (alpha_lo + alpha_hi) / S::TWO;
121
122        // Evaluate f at trial point
123        for i in 0..n {
124            x_trial[i] = x[i] + alpha_j * d[i];
125        }
126        let f_j = f(&x_trial);
127        *n_eval += 1;
128
129        if f_j > f0 + opts.c1 * alpha_j * dg0 || f_j >= f_lo {
130            // Armijo fails or no improvement over lo — shrink from the right
131            alpha_hi = alpha_j;
132        } else {
133            // Armijo holds — check curvature
134            grad(&x_trial, &mut g_trial);
135            let dg_j = dot(&g_trial, d);
136
137            if dg_j.abs() <= -opts.c2 * dg0 {
138                // Strong Wolfe satisfied
139                return Ok(LineSearchResult {
140                    step: alpha_j,
141                    f_new: f_j,
142                    n_eval: *n_eval,
143                });
144            }
145
146            if dg_j * (alpha_hi - alpha_lo) >= S::ZERO {
147                alpha_hi = alpha_lo;
148            }
149
150            alpha_lo = alpha_j;
151            f_lo = f_j;
152        }
153    }
154
155    // Exhausted iterations — return the best we found
156    Err(LineSearchError::MaxIterations)
157}
158
159/// Wolfe line search (Nocedal & Wright Algorithm 3.6).
160///
161/// Given a current point `x`, a descent direction `d`, the current function
162/// value `f0`, and the current gradient `g0`, find a step length `alpha`
163/// such that the **strong Wolfe conditions** hold:
164///
165/// 1. Sufficient decrease (Armijo): `f(x + alpha*d) <= f0 + c1 * alpha * g0^T d`
166/// 2. Curvature: `|grad f(x + alpha*d)^T d| <= c2 * |g0^T d|`
167///
168/// # Errors
169///
170/// Returns `Err` if `d` is not a descent direction (`g0^T d >= 0`) or if
171/// the algorithm fails to find a satisfactory step within `max_iter`
172/// iterations.
173pub fn wolfe_line_search<S, F, G>(
174    f: F,
175    grad: G,
176    x: &[S],
177    d: &[S],
178    f0: S,
179    g0: &[S],
180    opts: &WolfeOptions<S>,
181) -> Result<LineSearchResult<S>, LineSearchError>
182where
183    S: Scalar,
184    F: Fn(&[S]) -> S,
185    G: Fn(&[S], &mut [S]),
186{
187    let dg0 = dot(g0, d);
188    if dg0 >= S::ZERO {
189        return Err(LineSearchError::NotDescentDirection);
190    }
191
192    let n = x.len();
193    let mut x_trial = vec![S::ZERO; n];
194    let mut g_trial = vec![S::ZERO; n];
195
196    let mut alpha_prev = S::ZERO;
197    let mut f_prev = f0;
198    let mut alpha = S::ONE;
199    let mut n_eval: usize = 0;
200
201    for i in 1..=opts.max_iter {
202        // Clamp to max_step
203        if alpha > opts.max_step {
204            alpha = opts.max_step;
205        }
206
207        // Evaluate f at trial point x + alpha * d
208        for j in 0..n {
209            x_trial[j] = x[j] + alpha * d[j];
210        }
211        let f_alpha = f(&x_trial);
212        n_eval += 1;
213
214        // Check Armijo condition or if function increased relative to previous step
215        if f_alpha > f0 + opts.c1 * alpha * dg0 || (i > 1 && f_alpha >= f_prev) {
216            return zoom(
217                &f,
218                &grad,
219                x,
220                d,
221                f0,
222                dg0,
223                alpha_prev,
224                f_prev,
225                alpha,
226                opts,
227                &mut n_eval,
228            );
229        }
230
231        // Armijo holds — check curvature (strong Wolfe)
232        grad(&x_trial, &mut g_trial);
233        let dg_alpha = dot(&g_trial, d);
234
235        if dg_alpha.abs() <= -opts.c2 * dg0 {
236            return Ok(LineSearchResult {
237                step: alpha,
238                f_new: f_alpha,
239                n_eval,
240            });
241        }
242
243        // If slope is positive, we overshot the minimum — zoom backwards
244        if dg_alpha >= S::ZERO {
245            return zoom(
246                &f,
247                &grad,
248                x,
249                d,
250                f0,
251                dg0,
252                alpha,
253                f_alpha,
254                alpha_prev,
255                opts,
256                &mut n_eval,
257            );
258        }
259
260        // Otherwise expand the step
261        alpha_prev = alpha;
262        f_prev = f_alpha;
263        alpha *= S::TWO;
264    }
265
266    Err(LineSearchError::MaxIterations)
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Re-evaluates the objective and gradient at the accepted point and
    /// asserts that both strong Wolfe conditions hold there.
    ///
    /// The expressions mirror the ones used by the search itself, so the
    /// checks are exact and need no floating-point tolerance.
    fn assert_strong_wolfe<S, F, G>(
        f: &F,
        grad: &G,
        x: &[S],
        d: &[S],
        g0: &[S],
        opts: &WolfeOptions<S>,
        res: &LineSearchResult<S>,
    ) where
        S: Scalar,
        F: Fn(&[S]) -> S,
        G: Fn(&[S], &mut [S]),
    {
        let f0 = f(x);
        let dg0 = dot(g0, d);

        let x_new: Vec<S> = x
            .iter()
            .zip(d.iter())
            .map(|(&xi, &di)| xi + res.step * di)
            .collect();
        let f_at_step = f(&x_new);
        let mut g_new = vec![S::ZERO; x.len()];
        grad(&x_new, &mut g_new);
        let dg_new = dot(&g_new, d);

        assert!(res.n_eval >= 1, "at least one evaluation must be reported");
        assert!(
            res.f_new == f_at_step,
            "f_new must equal f evaluated at x + step * d"
        );
        assert!(
            f_at_step <= f0 + opts.c1 * res.step * dg0,
            "sufficient-decrease (Armijo) condition violated"
        );
        assert!(
            dg_new.abs() <= -opts.c2 * dg0,
            "strong curvature condition violated"
        );
    }

    #[test]
    fn test_wolfe_quadratic() {
        // f(x) = x^2, grad = 2x
        let f = |x: &[f64]| x[0] * x[0];
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0];
        let d = [-1.0];
        let f0 = f(&x);
        let g0 = [4.0]; // grad at x=2

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(
            res.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            res.f_new,
            f0
        );
        assert_strong_wolfe(&f, &grad, &x, &d, &g0, &opts, &res);
    }

    #[test]
    fn test_wolfe_rosenbrock() {
        // Rosenbrock: f(x) = (1-x0)^2 + 100*(x1 - x0^2)^2
        let f = |x: &[f64]| {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            a * a + 100.0 * b * b
        };
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
            g[1] = 200.0 * (x[1] - x[0] * x[0]);
        };

        let x = [-1.0, 1.0];
        let f0 = f(&x);
        let mut g0 = [0.0; 2];
        grad(&x, &mut g0);

        // Steepest descent direction
        let d = [-g0[0], -g0[1]];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(
            res.f_new < f0,
            "function must decrease: f_new={} vs f0={}",
            res.f_new,
            f0
        );
        assert_strong_wolfe(&f, &grad, &x, &d, &g0, &opts, &res);
    }

    #[test]
    fn test_wolfe_not_descent() {
        let f = |x: &[f64]| x[0] * x[0];
        let grad = |x: &[f64], g: &mut [f64]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0];
        let d = [1.0]; // ascending direction
        let f0 = f(&x);
        let g0 = [4.0];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts);

        assert!(res.is_err(), "must reject non-descent direction");
        assert!(
            matches!(res.unwrap_err(), LineSearchError::NotDescentDirection),
            "error should be NotDescentDirection"
        );
    }

    #[test]
    fn test_wolfe_f32() {
        // Test with f32 to verify generic Scalar works
        let f = |x: &[f32]| x[0] * x[0];
        let grad = |x: &[f32], g: &mut [f32]| {
            g[0] = 2.0 * x[0];
        };

        let x = [2.0f32];
        let d = [-1.0f32];
        let f0 = f(&x);
        let g0 = [4.0f32];

        let opts = WolfeOptions::default();
        let res = wolfe_line_search(f, grad, &x, &d, f0, &g0, &opts).unwrap();

        assert!(res.step > 0.0, "step must be positive");
        assert!(res.f_new < f0, "function must decrease");
        assert_strong_wolfe(&f, &grad, &x, &d, &g0, &opts, &res);
    }
}