echidna_optim/solvers/lbfgs.rs

use num_traits::Float;

use crate::convergence::{dot, norm, ConvergenceParams};
use crate::line_search::{backtracking_armijo, ArmijoParams};
use crate::objective::Objective;
use crate::result::{OptimResult, TerminationReason};

/// Configuration for the L-BFGS solver.
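///
/// # Example
///
/// A minimal configuration sketch: shrink the history and keep the remaining
/// defaults (relies on the `Default` impl for `f64` below).
///
/// ```ignore
/// let config: LbfgsConfig<f64> = LbfgsConfig {
///     memory: 5,
///     ..LbfgsConfig::default()
/// };
/// ```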
#[derive(Debug, Clone)]
pub struct LbfgsConfig<F> {
    /// Number of recent (s, y) pairs to store (default: 10).
    pub memory: usize,
    /// Convergence parameters.
    pub convergence: ConvergenceParams<F>,
    /// Line search parameters.
    pub line_search: ArmijoParams<F>,
}

impl Default for LbfgsConfig<f64> {
    fn default() -> Self {
        LbfgsConfig {
            memory: 10,
            convergence: ConvergenceParams::default(),
            line_search: ArmijoParams::default(),
        }
    }
}

impl Default for LbfgsConfig<f32> {
    fn default() -> Self {
        LbfgsConfig {
            memory: 10,
            convergence: ConvergenceParams::default(),
            line_search: ArmijoParams::default(),
        }
    }
}

/// L-BFGS optimization.
///
/// Minimizes `obj` starting from `x0` using the limited-memory BFGS method
/// with two-loop recursion and backtracking Armijo line search.
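///
/// # Example
///
/// A usage sketch against a simple convex quadratic, assuming `Objective`
/// only requires `dim` and `eval_grad` (as in the tests below):
///
/// ```ignore
/// struct Quadratic;
///
/// impl Objective<f64> for Quadratic {
///     fn dim(&self) -> usize {
///         2
///     }
///
///     fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
///         // f(x) = x0^2 + 2*x1^2, with gradient [2*x0, 4*x1]
///         (x[0] * x[0] + 2.0 * x[1] * x[1], vec![2.0 * x[0], 4.0 * x[1]])
///     }
/// }
///
/// let mut obj = Quadratic;
/// let result = lbfgs(&mut obj, &[1.0, 1.0], &LbfgsConfig::default());
/// assert_eq!(result.termination, TerminationReason::GradientNorm);
/// ```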
pub fn lbfgs<F: Float, O: Objective<F>>(
    obj: &mut O,
    x0: &[F],
    config: &LbfgsConfig<F>,
) -> OptimResult<F> {
    let n = x0.len();

    // Config validation
    if config.memory == 0 || config.convergence.max_iter == 0 {
        return OptimResult {
            x: x0.to_vec(),
            value: F::nan(),
            gradient: vec![F::nan(); n],
            gradient_norm: F::nan(),
            iterations: 0,
            func_evals: 0,
            termination: TerminationReason::NumericalError,
        };
    }

    let mut x = x0.to_vec();
    let (mut f_val, mut grad) = obj.eval_grad(&x);
    let mut func_evals = 1usize;
    let mut grad_norm = norm(&grad);

    // NaN/Inf detection
    if !grad_norm.is_finite() || !f_val.is_finite() {
        return OptimResult {
            x,
            value: f_val,
            gradient: grad,
            gradient_norm: grad_norm,
            iterations: 0,
            func_evals,
            termination: TerminationReason::NumericalError,
        };
    }

    // Check initial convergence
    if grad_norm < config.convergence.grad_tol {
        return OptimResult {
            x,
            value: f_val,
            gradient: grad,
            gradient_norm: grad_norm,
            iterations: 0,
            func_evals,
            termination: TerminationReason::GradientNorm,
        };
    }

    // L-BFGS history buffers: the most recent `m` (s, y) pairs, oldest first
    let m = config.memory;
    let mut s_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut y_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut rho_hist: Vec<F> = Vec::with_capacity(m);

    for iter in 0..config.convergence.max_iter {
        // Two-loop recursion to compute the search direction d = -H_k * g_k
        let d = two_loop_recursion(&grad, &s_hist, &y_hist, &rho_hist);

        // Line search
        let ls = match backtracking_armijo(obj, &x, &d, f_val, &grad, &config.line_search) {
            Some(ls) => ls,
            None => {
                return OptimResult {
                    x,
                    value: f_val,
                    gradient: grad,
                    gradient_norm: grad_norm,
                    iterations: iter,
                    func_evals,
                    termination: TerminationReason::LineSearchFailed,
                };
            }
        };
        func_evals += ls.evals;

        // Compute s = x_new - x, y = g_new - g
        let mut s = vec![F::zero(); n];
        let mut y = vec![F::zero(); n];
        for i in 0..n {
            // Compute s = alpha * d directly instead of (x + alpha*d) - x
            // to avoid cancellation when ||x|| >> alpha*||d||
            s[i] = ls.alpha * d[i];
            y[i] = ls.gradient[i] - grad[i];
            x[i] = x[i] + s[i];
        }

        let f_prev = f_val;
        f_val = ls.value;
        grad = ls.gradient;
        grad_norm = norm(&grad);

        // Update history (skip pairs with near-zero curvature to prevent rho overflow)
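        // The scale-invariant test `s^T y > eps * ||y||^2` keeps `rho = 1 / s^T y`
        // finite and positive, preserving positive curvature in the stored pairs.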
        let sy = dot(&s, &y);
        let yy = dot(&y, &y);
        if sy > F::epsilon() * yy {
            if s_hist.len() == m {
                s_hist.remove(0);
                y_hist.remove(0);
                rho_hist.remove(0);
            }
            rho_hist.push(F::one() / sy);
            s_hist.push(s);
            y_hist.push(y);
        }

        // NaN/Inf detection
        if !grad_norm.is_finite() || !f_val.is_finite() {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::NumericalError,
            };
        }

        // Convergence checks
        if grad_norm < config.convergence.grad_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::GradientNorm,
            };
        }

        let step_norm = norm_step(ls.alpha, &d);
        if step_norm < config.convergence.step_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::StepSize,
            };
        }

        if config.convergence.func_tol > F::zero()
            && (f_prev - f_val).abs() < config.convergence.func_tol
        {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::FunctionChange,
            };
        }
    }

    OptimResult {
        x,
        value: f_val,
        gradient: grad,
        gradient_norm: grad_norm,
        iterations: config.convergence.max_iter,
        func_evals,
        termination: TerminationReason::MaxIterations,
    }
}

/// L-BFGS two-loop recursion: compute d = -H_k * g_k.
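///
/// First pass (newest to oldest): `alpha_i = rho_i * s_i^T q`, then `q -= alpha_i * y_i`.
/// The result is scaled by the initial Hessian guess `H_0 = gamma * I`, with
/// `gamma = s^T y / y^T y` taken from the most recent pair. Second pass
/// (oldest to newest): `beta = rho_i * y_i^T r`, then `r += (alpha_i - beta) * s_i`.
/// Finally the direction is negated.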
fn two_loop_recursion<F: Float>(
    grad: &[F],
    s_hist: &[Vec<F>],
    y_hist: &[Vec<F>],
    rho_hist: &[F],
) -> Vec<F> {
    let k = s_hist.len();
    let n = grad.len();

    // q = g
    let mut q: Vec<F> = grad.to_vec();

    // First loop: newest to oldest
    let mut alpha = vec![F::zero(); k];
    for i in (0..k).rev() {
        alpha[i] = rho_hist[i] * dot(&s_hist[i], &q);
        for j in 0..n {
            q[j] = q[j] - alpha[i] * y_hist[i][j];
        }
    }

    // Initial Hessian approximation: H_0 = gamma * I
    // gamma = s^T y / y^T y (from the most recent pair)
    let mut r = q;
    if k > 0 {
        let sy = dot(&s_hist[k - 1], &y_hist[k - 1]);
        let yy = dot(&y_hist[k - 1], &y_hist[k - 1]);
        if yy > F::epsilon() {
            let gamma = sy / yy;
            if gamma.is_finite() {
                for v in r.iter_mut() {
                    *v = *v * gamma;
                }
            }
        }
    }

    // Second loop: oldest to newest
    for i in 0..k {
        let beta = rho_hist[i] * dot(&y_hist[i], &r);
        for j in 0..n {
            r[j] = r[j] + (alpha[i] - beta) * s_hist[i][j];
        }
    }

    // Negate: d = -H * g
    for v in r.iter_mut() {
        *v = F::zero() - *v;
    }

    r
}

/// Euclidean norm of the step `alpha * d`, computed without allocating the
/// step vector.
fn norm_step<F: Float>(alpha: F, d: &[F]) -> F {
    let mut s = F::zero();
    for &di in d {
        let step = alpha * di;
        s = s + step * step;
    }
    s.sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    struct Rosenbrock;

    impl Objective<f64> for Rosenbrock {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            let f = a * a + 100.0 * b * b;
            let g0 = -2.0 * a - 400.0 * x[0] * b;
            let g1 = 200.0 * b;
            (f, vec![g0, g1])
        }
    }

    #[test]
    fn lbfgs_rosenbrock() {
        let mut obj = Rosenbrock;
        let config = LbfgsConfig::default();
        let result = lbfgs(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert!(
            (result.x[0] - 1.0).abs() < 1e-6,
            "x[0] = {}, expected 1.0",
            result.x[0]
        );
        assert!(
            (result.x[1] - 1.0).abs() < 1e-6,
            "x[1] = {}, expected 1.0",
            result.x[1]
        );
        assert!(result.gradient_norm < 1e-8);
    }

    #[test]
    fn lbfgs_already_converged() {
        let mut obj = Rosenbrock;
        let config = LbfgsConfig::default();
        let result = lbfgs(&mut obj, &[1.0, 1.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert_eq!(result.iterations, 0);
    }
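
    // Sketch of an additional guard test, based only on the config-validation
    // block above: a zero-sized history is rejected before any evaluation.
    #[test]
    fn lbfgs_zero_memory_rejected() {
        let mut obj = Rosenbrock;
        let config = LbfgsConfig {
            memory: 0,
            ..LbfgsConfig::default()
        };
        let result = lbfgs(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(result.termination, TerminationReason::NumericalError);
        assert_eq!(result.iterations, 0);
        assert_eq!(result.func_evals, 0);
    }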
}