Skip to main content

numra_optim/
lbfgs.rs

1//! L-BFGS (Limited-memory BFGS) optimizer.
2//!
//! Computes search directions with the two-loop recursion, using only
//! the last `m` (s, y) correction pairs and therefore O(mn) storage instead of O(n²).
5//!
6//! Author: Moussa Leblouba
7//! Date: 8 February 2026
8//! Modified: 2 May 2026
9
10use numra_core::Scalar;
11
12use crate::error::OptimError;
13use crate::types::{IterationRecord, OptimOptions, OptimResult, OptimStatus};
14use numra_nonlinear::line_search::{wolfe_line_search, WolfeOptions};
15use std::collections::VecDeque;
16
17/// Options specific to L-BFGS.
18#[derive(Clone, Debug)]
19pub struct LbfgsOptions<S: Scalar> {
20    /// Base optimization options.
21    pub base: OptimOptions<S>,
22    /// Number of correction pairs to store (default: 10).
23    pub memory: usize,
24}
25
26impl<S: Scalar> Default for LbfgsOptions<S> {
27    fn default() -> Self {
28        Self {
29            base: OptimOptions::default(),
30            memory: 10,
31        }
32    }
33}
34
35impl<S: Scalar> LbfgsOptions<S> {
36    pub fn memory(mut self, m: usize) -> Self {
37        self.memory = m;
38        self
39    }
40    pub fn max_iter(mut self, n: usize) -> Self {
41        self.base.max_iter = n;
42        self
43    }
44    pub fn gtol(mut self, tol: S) -> Self {
45        self.base.gtol = tol;
46        self
47    }
48}
49
50/// L-BFGS optimizer.
51pub struct Lbfgs<S: Scalar> {
52    options: LbfgsOptions<S>,
53}
54
55impl<S: Scalar> Lbfgs<S> {
56    pub fn new(options: LbfgsOptions<S>) -> Self {
57        Self { options }
58    }
59
60    pub fn minimize<F, G>(&self, f: F, grad: G, x0: &[S]) -> Result<OptimResult<S>, OptimError>
61    where
62        F: Fn(&[S]) -> S,
63        G: Fn(&[S], &mut [S]),
64    {
65        lbfgs_minimize(f, grad, x0, &self.options)
66    }
67}
68
69/// Minimize f(x) using L-BFGS.
70pub fn lbfgs_minimize<S: Scalar, F, G>(
71    f: F,
72    grad: G,
73    x0: &[S],
74    opts: &LbfgsOptions<S>,
75) -> Result<OptimResult<S>, OptimError>
76where
77    F: Fn(&[S]) -> S,
78    G: Fn(&[S], &mut [S]),
79{
80    let start = std::time::Instant::now();
81    let n = x0.len();
82    let m = opts.memory;
83    let mut x = x0.to_vec();
84    let mut g = vec![S::ZERO; n];
85    let mut g_new = vec![S::ZERO; n];
86
87    // History buffers
88    let mut s_hist: VecDeque<Vec<S>> = VecDeque::with_capacity(m);
89    let mut y_hist: VecDeque<Vec<S>> = VecDeque::with_capacity(m);
90    let mut rho_hist: VecDeque<S> = VecDeque::with_capacity(m);
91
92    let mut fval = f(&x);
93    grad(&x, &mut g);
94    let mut n_feval = 1_usize;
95    let mut n_geval = 1_usize;
96
97    let wolfe_opts = WolfeOptions::default();
98    let mut history = Vec::new();
99
100    for iter in 0..opts.base.max_iter {
101        let g_norm = g.iter().copied().map(|gi| gi * gi).sum::<S>().sqrt();
102        if g_norm < opts.base.gtol {
103            return Ok((OptimResult {
104                history,
105                ..OptimResult::unconstrained(
106                    x,
107                    fval,
108                    g,
109                    iter,
110                    n_feval,
111                    n_geval,
112                    true,
113                    format!("Converged: gradient norm {:.2e}", g_norm.to_f64()),
114                    OptimStatus::GradientConverged,
115                )
116            })
117            .with_wall_time(start));
118        }
119
120        // Two-loop recursion to compute d = -H_k * g
121        let d = two_loop_recursion(&g, &s_hist, &y_hist, &rho_hist);
122
123        // Line search
124        let ls_result = wolfe_line_search(&f, &grad, &x, &d, fval, &g, &wolfe_opts)?;
125        n_feval += ls_result.n_eval;
126
127        let alpha = ls_result.step;
128
129        // s = alpha * d
130        let s: Vec<S> = d.iter().map(|di| alpha * *di).collect();
131
132        // Update x
133        for i in 0..n {
134            x[i] += s[i];
135        }
136
137        let f_new = ls_result.f_new;
138
139        if (fval - f_new).abs() < opts.base.ftol * (S::ONE + fval.abs()) {
140            grad(&x, &mut g);
141            n_geval += 1;
142            let g_norm_new = g.iter().copied().map(|gi| gi * gi).sum::<S>().sqrt();
143            history.push(IterationRecord {
144                iteration: iter,
145                objective: f_new,
146                gradient_norm: g_norm_new,
147                step_size: alpha,
148                constraint_violation: S::ZERO,
149            });
150            return Ok((OptimResult {
151                history,
152                ..OptimResult::unconstrained(
153                    x,
154                    f_new,
155                    g,
156                    iter + 1,
157                    n_feval,
158                    n_geval,
159                    true,
160                    format!(
161                        "Converged: function change {:.2e}",
162                        (fval - f_new).abs().to_f64()
163                    ),
164                    OptimStatus::FunctionConverged,
165                )
166            })
167            .with_wall_time(start));
168        }
169
170        fval = f_new;
171
172        // New gradient
173        grad(&x, &mut g_new);
174        n_geval += 1;
175
176        let g_new_norm = g_new.iter().copied().map(|gi| gi * gi).sum::<S>().sqrt();
177        history.push(IterationRecord {
178            iteration: iter,
179            objective: fval,
180            gradient_norm: g_new_norm,
181            step_size: alpha,
182            constraint_violation: S::ZERO,
183        });
184
185        // y = g_new - g
186        let y: Vec<S> = g_new
187            .iter()
188            .zip(g.iter())
189            .map(|(gn, go)| *gn - *go)
190            .collect();
191        let sy: S = s.iter().zip(y.iter()).map(|(si, yi)| *si * *yi).sum();
192
193        if sy > S::from_f64(1e-16) {
194            if s_hist.len() == m {
195                s_hist.pop_front();
196                y_hist.pop_front();
197                rho_hist.pop_front();
198            }
199            rho_hist.push_back(S::ONE / sy);
200            s_hist.push_back(s);
201            y_hist.push_back(y);
202        }
203
204        g.copy_from_slice(&g_new);
205    }
206
207    Ok((OptimResult {
208        history,
209        ..OptimResult::unconstrained(
210            x,
211            fval,
212            g,
213            opts.base.max_iter,
214            n_feval,
215            n_geval,
216            false,
217            format!("Maximum iterations ({}) reached", opts.base.max_iter),
218            OptimStatus::MaxIterations,
219        )
220    })
221    .with_wall_time(start))
222}
223
224/// Two-loop recursion for L-BFGS direction computation.
225pub(crate) fn two_loop_recursion<S: Scalar>(
226    g: &[S],
227    s_hist: &VecDeque<Vec<S>>,
228    y_hist: &VecDeque<Vec<S>>,
229    rho_hist: &VecDeque<S>,
230) -> Vec<S> {
231    let k = s_hist.len();
232    let n = g.len();
233    let mut q = g.to_vec();
234    let mut alpha_vec = vec![S::ZERO; k];
235
236    // Forward pass
237    for i in (0..k).rev() {
238        let a: S = rho_hist[i]
239            * s_hist[i]
240                .iter()
241                .zip(q.iter())
242                .map(|(si, qi)| *si * *qi)
243                .sum::<S>();
244        alpha_vec[i] = a;
245        for j in 0..n {
246            q[j] -= a * y_hist[i][j];
247        }
248    }
249
250    // Initial Hessian scaling: H_0 = gamma * I
251    let gamma = if k > 0 {
252        let sy: S = s_hist[k - 1]
253            .iter()
254            .zip(y_hist[k - 1].iter())
255            .map(|(s, y)| *s * *y)
256            .sum();
257        let yy: S = y_hist[k - 1].iter().copied().map(|y| y * y).sum();
258        sy / yy
259    } else {
260        S::ONE
261    };
262
263    let mut r: Vec<S> = q.iter().map(|qi| gamma * *qi).collect();
264
265    // Backward pass
266    for i in 0..k {
267        let b: S = rho_hist[i]
268            * y_hist[i]
269                .iter()
270                .zip(r.iter())
271                .map(|(yi, ri)| *yi * *ri)
272                .sum::<S>();
273        for j in 0..n {
274            r[j] += s_hist[i][j] * (alpha_vec[i] - b);
275        }
276    }
277
278    // Negate for descent direction
279    for ri in r.iter_mut() {
280        *ri = -*ri;
281    }
282
283    r
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lbfgs_quadratic() {
        let f = |x: &[f64]| x[0] * x[0] + 4.0 * x[1] * x[1];
        let g = |x: &[f64], grad: &mut [f64]| {
            grad[0] = 2.0 * x[0];
            grad[1] = 8.0 * x[1];
        };

        let result = lbfgs_minimize(f, g, &[5.0, 3.0], &LbfgsOptions::default()).unwrap();
        assert!(result.converged);
        assert!(result.f < 1e-12);
    }

    #[test]
    fn test_lbfgs_rosenbrock() {
        let f = |x: &[f64]| (1.0 - x[0]).powi(2) + 100.0 * (x[1] - x[0] * x[0]).powi(2);
        let g = |x: &[f64], grad: &mut [f64]| {
            grad[0] = -2.0 * (1.0 - x[0]) - 400.0 * x[0] * (x[1] - x[0] * x[0]);
            grad[1] = 200.0 * (x[1] - x[0] * x[0]);
        };

        let result = lbfgs_minimize(f, g, &[-1.0, 1.0], &LbfgsOptions::default()).unwrap();
        assert!(
            result.converged,
            "L-BFGS did not converge: {}",
            result.message
        );
        assert!((result.x[0] - 1.0).abs() < 1e-4);
        assert!((result.x[1] - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_lbfgs_large_scale() {
        // n=100, f(x) = sum(x_i^2)
        let n = 100;
        let f = |x: &[f64]| x.iter().copied().map(|xi| xi * xi).sum::<f64>();
        let g = |x: &[f64], grad: &mut [f64]| {
            for (gi, xi) in grad.iter_mut().zip(x) {
                *gi = 2.0 * *xi;
            }
        };

        let x0: Vec<f64> = (1..=n).map(|i| i as f64 * 0.1).collect();
        let result = lbfgs_minimize(f, g, &x0, &LbfgsOptions::default()).unwrap();
        assert!(result.converged);
        assert!(result.f < 1e-10);
    }

    #[test]
    fn test_two_loop_empty_history_is_steepest_descent() {
        // With no correction pairs, H_0 = I, so the direction is exactly -g.
        let s_hist: std::collections::VecDeque<Vec<f64>> = std::collections::VecDeque::new();
        let y_hist: std::collections::VecDeque<Vec<f64>> = std::collections::VecDeque::new();
        let rho_hist: std::collections::VecDeque<f64> = std::collections::VecDeque::new();
        let d = two_loop_recursion(&[1.5, -2.0, 0.0], &s_hist, &y_hist, &rho_hist);
        assert_eq!(d, vec![-1.5, 2.0, 0.0]);
    }

    #[test]
    fn test_two_loop_single_pair_matches_bfgs_update() {
        // s = (1, 0), y = (2, 0): rho = 1/2, gamma = (s.y)/(y.y) = 1/2.
        // One BFGS update of H_0 = gamma * I yields H = 0.5 * I (which satisfies
        // the secant condition H y = s), so d = -H g = -0.5 * g.
        let s_hist = std::collections::VecDeque::from(vec![vec![1.0_f64, 0.0]]);
        let y_hist = std::collections::VecDeque::from(vec![vec![2.0_f64, 0.0]]);
        let rho_hist = std::collections::VecDeque::from(vec![0.5_f64]);
        let d = two_loop_recursion(&[2.0, 4.0], &s_hist, &y_hist, &rho_hist);
        assert_eq!(d, vec![-1.0, -2.0]);
    }
}