
echidna_optim/solvers/lbfgs.rs

use num_traits::Float;

use crate::convergence::{dot, norm, ConvergenceParams};
use crate::line_search::{backtracking_armijo, ArmijoParams};
use crate::objective::Objective;
use crate::result::{OptimResult, TerminationReason};

/// Configuration for the L-BFGS solver.
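///
/// A custom configuration can be built from the defaults. A minimal sketch
/// (marked `ignore` since it leans on the `f64` `Default` impl below):
///
/// ```ignore
/// let config = LbfgsConfig {
///     memory: 5,
///     ..LbfgsConfig::<f64>::default()
/// };
/// ```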
#[derive(Debug, Clone)]
pub struct LbfgsConfig<F> {
    /// Number of recent (s, y) pairs to store (default: 10).
    pub memory: usize,
    /// Convergence parameters.
    pub convergence: ConvergenceParams<F>,
    /// Line search parameters.
    pub line_search: ArmijoParams<F>,
}

impl Default for LbfgsConfig<f64> {
    fn default() -> Self {
        LbfgsConfig {
            memory: 10,
            convergence: ConvergenceParams::default(),
            line_search: ArmijoParams::default(),
        }
    }
}

impl Default for LbfgsConfig<f32> {
    fn default() -> Self {
        LbfgsConfig {
            memory: 10,
            convergence: ConvergenceParams::default(),
            line_search: ArmijoParams::default(),
        }
    }
}

/// L-BFGS optimization.
///
/// Minimizes `obj` starting from `x0` using the limited-memory BFGS method
/// with two-loop recursion and backtracking Armijo line search.
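///
/// # Example
///
/// A minimal sketch of typical usage. The `Quadratic` objective is
/// hypothetical and the block is marked `ignore` because the import paths
/// are assumptions about this crate's layout:
///
/// ```ignore
/// use echidna_optim::objective::Objective;
/// use echidna_optim::solvers::lbfgs::{lbfgs, LbfgsConfig};
///
/// // f(x) = x^2 with gradient 2x; minimized at x = 0.
/// struct Quadratic;
///
/// impl Objective<f64> for Quadratic {
///     fn dim(&self) -> usize {
///         1
///     }
///
///     fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
///         (x[0] * x[0], vec![2.0 * x[0]])
///     }
/// }
///
/// let mut obj = Quadratic;
/// let result = lbfgs(&mut obj, &[3.0], &LbfgsConfig::default());
/// assert!(result.x[0].abs() < 1e-6);
/// ```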
pub fn lbfgs<F: Float, O: Objective<F>>(
    obj: &mut O,
    x0: &[F],
    config: &LbfgsConfig<F>,
) -> OptimResult<F> {
    let n = x0.len();

    // Reject a degenerate configuration (zero memory or zero iteration
    // budget) up front rather than running with it
    if config.memory == 0 || config.convergence.max_iter == 0 {
        return OptimResult {
            x: x0.to_vec(),
            value: F::nan(),
            gradient: vec![F::nan(); n],
            gradient_norm: F::nan(),
            iterations: 0,
            func_evals: 0,
            termination: TerminationReason::NumericalError,
        };
    }

    let mut x = x0.to_vec();
    let (mut f_val, mut grad) = obj.eval_grad(&x);
    let mut func_evals = 1usize;
    let mut grad_norm = norm(&grad);

    // Check initial convergence
    if grad_norm < config.convergence.grad_tol {
        return OptimResult {
            x,
            value: f_val,
            gradient: grad,
            gradient_norm: grad_norm,
            iterations: 0,
            func_evals,
            termination: TerminationReason::GradientNorm,
        };
    }

    // L-BFGS history buffers: store the most recent `m` (s, y) pairs
    let m = config.memory;
    let mut s_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut y_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut rho_hist: Vec<F> = Vec::with_capacity(m);

    for iter in 0..config.convergence.max_iter {
        // Two-loop recursion computes the search direction d = -H_k * g_k
        let d = two_loop_recursion(&grad, &s_hist, &y_hist, &rho_hist);

        // Line search
        let ls = match backtracking_armijo(obj, &x, &d, f_val, &grad, &config.line_search) {
            Some(ls) => ls,
            None => {
                return OptimResult {
                    x,
                    value: f_val,
                    gradient: grad,
                    gradient_norm: grad_norm,
                    iterations: iter,
                    func_evals,
                    termination: TerminationReason::LineSearchFailed,
                };
            }
        };
        func_evals += ls.evals;

        // Compute s = x_new - x, y = g_new - g
        let mut s = vec![F::zero(); n];
        let mut y = vec![F::zero(); n];
        for i in 0..n {
            let x_new_i = x[i] + ls.alpha * d[i];
            s[i] = x_new_i - x[i];
            y[i] = ls.gradient[i] - grad[i];
            x[i] = x_new_i;
        }

        let f_prev = f_val;
        f_val = ls.value;
        grad = ls.gradient;
        grad_norm = norm(&grad);

        // Update history: keep the pair only when the curvature condition
        // s^T y > 0 holds, so the implicit Hessian stays positive definite
        let sy = dot(&s, &y);
        if sy > F::zero() {
            if s_hist.len() == m {
                // Evict the oldest pair once `m` pairs are stored
                s_hist.remove(0);
                y_hist.remove(0);
                rho_hist.remove(0);
            }
            rho_hist.push(F::one() / sy);
            s_hist.push(s);
            y_hist.push(y);
        }

        // Convergence checks
        if grad_norm < config.convergence.grad_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::GradientNorm,
            };
        }

        let step_norm = norm_step(ls.alpha, &d);
        if step_norm < config.convergence.step_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::StepSize,
            };
        }

        if config.convergence.func_tol > F::zero()
            && (f_prev - f_val).abs() < config.convergence.func_tol
        {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::FunctionChange,
            };
        }
    }

    OptimResult {
        x,
        value: f_val,
        gradient: grad,
        gradient_norm: grad_norm,
        iterations: config.convergence.max_iter,
        func_evals,
        termination: TerminationReason::MaxIterations,
    }
}

/// L-BFGS two-loop recursion: compute d = -H_k * g_k.
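///
/// With `q` initialized to `g_k`, the first loop (newest to oldest pair)
/// computes `alpha_i = rho_i * s_i^T q` and updates `q -= alpha_i * y_i`.
/// The result is then scaled by `gamma = s^T y / y^T y` from the most recent
/// pair, which serves as the initial Hessian approximation `H_0 = gamma * I`.
/// The second loop (oldest to newest) computes `beta = rho_i * y_i^T r` and
/// updates `r += (alpha_i - beta) * s_i`; negating `r` yields the direction.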
fn two_loop_recursion<F: Float>(
    grad: &[F],
    s_hist: &[Vec<F>],
    y_hist: &[Vec<F>],
    rho_hist: &[F],
) -> Vec<F> {
    let k = s_hist.len();
    let n = grad.len();

    // q = g
    let mut q: Vec<F> = grad.to_vec();

    // First loop: newest to oldest
    let mut alpha = vec![F::zero(); k];
    for i in (0..k).rev() {
        alpha[i] = rho_hist[i] * dot(&s_hist[i], &q);
        for j in 0..n {
            q[j] = q[j] - alpha[i] * y_hist[i][j];
        }
    }

    // Initial Hessian approximation: H_0 = gamma * I, with
    // gamma = s^T y / y^T y (from the most recent pair)
    let mut r = q;
    if k > 0 {
        let sy = dot(&s_hist[k - 1], &y_hist[k - 1]);
        let yy = dot(&y_hist[k - 1], &y_hist[k - 1]);
        if yy > F::zero() {
            let gamma = sy / yy;
            for v in r.iter_mut() {
                *v = *v * gamma;
            }
        }
    }

    // Second loop: oldest to newest
    for i in 0..k {
        let beta = rho_hist[i] * dot(&y_hist[i], &r);
        for j in 0..n {
            r[j] = r[j] + (alpha[i] - beta) * s_hist[i][j];
        }
    }

    // Negate: d = -H * g
    for v in r.iter_mut() {
        *v = -*v;
    }

    r
}

/// Euclidean norm of the step `alpha * d`.
fn norm_step<F: Float>(alpha: F, d: &[F]) -> F {
    let mut s = F::zero();
    for &di in d {
        let step = alpha * di;
        s = s + step * step;
    }
    s.sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Rosenbrock function f(x) = (1 - x0)^2 + 100 (x1 - x0^2)^2,
    /// whose minimum lies at (1, 1).
    struct Rosenbrock;

    impl Objective<f64> for Rosenbrock {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            let f = a * a + 100.0 * b * b;
            let g0 = -2.0 * a - 400.0 * x[0] * b;
            let g1 = 200.0 * b;
            (f, vec![g0, g1])
        }
    }

    #[test]
    fn lbfgs_rosenbrock() {
        let mut obj = Rosenbrock;
        let config = LbfgsConfig::default();
        let result = lbfgs(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert!(
            (result.x[0] - 1.0).abs() < 1e-6,
            "x[0] = {}, expected 1.0",
            result.x[0]
        );
        assert!(
            (result.x[1] - 1.0).abs() < 1e-6,
            "x[1] = {}, expected 1.0",
            result.x[1]
        );
        assert!(result.gradient_norm < 1e-8);
    }

    #[test]
    fn lbfgs_already_converged() {
        let mut obj = Rosenbrock;
        let config = LbfgsConfig::default();
        let result = lbfgs(&mut obj, &[1.0, 1.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert_eq!(result.iterations, 0);
    }
}