use num_traits::Float;

use crate::convergence::{dot, norm, ConvergenceParams};
use crate::line_search::{backtracking_armijo, ArmijoParams};
use crate::objective::Objective;
use crate::result::{OptimResult, TerminationReason};

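/// Configuration for the limited-memory BFGS (L-BFGS) solver.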
#[derive(Debug, Clone)]
pub struct LbfgsConfig<F> {
    /// Number of recent curvature pairs `(s, y)` to retain; must be positive.
    pub memory: usize,
    /// Convergence tolerances and the iteration budget.
    pub convergence: ConvergenceParams<F>,
    /// Parameters for the backtracking Armijo line search.
    pub line_search: ArmijoParams<F>,
}

impl<F> Default for LbfgsConfig<F>
where
    ConvergenceParams<F>: Default,
    ArmijoParams<F>: Default,
{
    fn default() -> Self {
        LbfgsConfig {
            memory: 10,
            convergence: ConvergenceParams::default(),
            line_search: ArmijoParams::default(),
        }
    }
}

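/// Minimizes `obj` with the limited-memory BFGS method, starting from `x0`.
///
/// Returns immediately with [`TerminationReason::NumericalError`] when
/// `config.memory` or `config.convergence.max_iter` is zero. A minimal
/// call-site sketch (`MyObjective` is a hypothetical [`Objective`] impl):
///
/// ```ignore
/// let mut obj = MyObjective::new();
/// let result = lbfgs(&mut obj, &[0.0, 0.0], &LbfgsConfig::default());
/// println!("{:?} after {} iterations", result.termination, result.iterations);
/// ```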
pub fn lbfgs<F: Float, O: Objective<F>>(
    obj: &mut O,
    x0: &[F],
    config: &LbfgsConfig<F>,
) -> OptimResult<F> {
    let n = x0.len();

    // Reject degenerate configurations up front.
    if config.memory == 0 || config.convergence.max_iter == 0 {
        return OptimResult {
            x: x0.to_vec(),
            value: F::nan(),
            gradient: vec![F::nan(); n],
            gradient_norm: F::nan(),
            iterations: 0,
            func_evals: 0,
            termination: TerminationReason::NumericalError,
        };
    }

    let mut x = x0.to_vec();
    let (mut f_val, mut grad) = obj.eval_grad(&x);
    let mut func_evals = 1usize;
    let mut grad_norm = norm(&grad);

    // The starting point may already satisfy the gradient tolerance.
    if grad_norm < config.convergence.grad_tol {
        return OptimResult {
            x,
            value: f_val,
            gradient: grad,
            gradient_norm: grad_norm,
            iterations: 0,
            func_evals,
            termination: TerminationReason::GradientNorm,
        };
    }

    // Curvature-pair history: s = x_{k+1} - x_k, y = g_{k+1} - g_k,
    // rho = 1 / (s^T y).
    let m = config.memory;
    let mut s_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut y_hist: Vec<Vec<F>> = Vec::with_capacity(m);
    let mut rho_hist: Vec<F> = Vec::with_capacity(m);

    for iter in 0..config.convergence.max_iter {
        // Search direction -H_k * grad via the two-loop recursion.
        let d = two_loop_recursion(&grad, &s_hist, &y_hist, &rho_hist);

        let ls = match backtracking_armijo(obj, &x, &d, f_val, &grad, &config.line_search) {
            Some(ls) => ls,
            None => {
                return OptimResult {
                    x,
                    value: f_val,
                    gradient: grad,
                    gradient_norm: grad_norm,
                    iterations: iter,
                    func_evals,
                    termination: TerminationReason::LineSearchFailed,
                };
            }
        };
        func_evals += ls.evals;

        // Take the step, recording the displacement s and gradient change y.
        let mut s = vec![F::zero(); n];
        let mut y = vec![F::zero(); n];
        for i in 0..n {
            let x_new_i = x[i] + ls.alpha * d[i];
            s[i] = x_new_i - x[i];
            y[i] = ls.gradient[i] - grad[i];
            x[i] = x_new_i;
        }

        let f_prev = f_val;
        f_val = ls.value;
        grad = ls.gradient;
        grad_norm = norm(&grad);

        // Store the pair only when the curvature condition s^T y > 0 holds,
        // which keeps the implicit inverse-Hessian approximation positive
        // definite; evict the oldest pair once the memory is full.
        let sy = dot(&s, &y);
        if sy > F::zero() {
            if s_hist.len() == m {
                s_hist.remove(0);
                y_hist.remove(0);
                rho_hist.remove(0);
            }
            rho_hist.push(F::one() / sy);
            s_hist.push(s);
            y_hist.push(y);
        }

        if grad_norm < config.convergence.grad_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::GradientNorm,
            };
        }

        let step_norm = norm_step(ls.alpha, &d);
        if step_norm < config.convergence.step_tol {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::StepSize,
            };
        }

        if config.convergence.func_tol > F::zero()
            && (f_prev - f_val).abs() < config.convergence.func_tol
        {
            return OptimResult {
                x,
                value: f_val,
                gradient: grad,
                gradient_norm: grad_norm,
                iterations: iter + 1,
                func_evals,
                termination: TerminationReason::FunctionChange,
            };
        }
    }

    OptimResult {
        x,
        value: f_val,
        gradient: grad,
        gradient_norm: grad_norm,
        iterations: config.convergence.max_iter,
        func_evals,
        termination: TerminationReason::MaxIterations,
    }
}

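/// Two-loop recursion for the L-BFGS search direction: returns `-H_k * grad`,
/// where `H_k` is the inverse-Hessian approximation implied by the stored
/// `(s, y, rho)` triples. With an empty history this reduces to steepest descent.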
fn two_loop_recursion<F: Float>(
    grad: &[F],
    s_hist: &[Vec<F>],
    y_hist: &[Vec<F>],
    rho_hist: &[F],
) -> Vec<F> {
    let k = s_hist.len();
    let n = grad.len();

    // Work on a copy of the gradient.
    let mut q: Vec<F> = grad.to_vec();

    // First (backward) loop.
    let mut alpha = vec![F::zero(); k];
    for i in (0..k).rev() {
        alpha[i] = rho_hist[i] * dot(&s_hist[i], &q);
        for j in 0..n {
            q[j] = q[j] - alpha[i] * y_hist[i][j];
        }
    }

    // Scale by gamma = s^T y / y^T y, the standard initial
    // inverse-Hessian estimate H_0 = gamma * I.
    let mut r = q;
    if k > 0 {
        let sy = dot(&s_hist[k - 1], &y_hist[k - 1]);
        let yy = dot(&y_hist[k - 1], &y_hist[k - 1]);
        if yy > F::zero() {
            let gamma = sy / yy;
            for v in r.iter_mut() {
                *v = *v * gamma;
            }
        }
    }

    // Second (forward) loop.
    for i in 0..k {
        let beta = rho_hist[i] * dot(&y_hist[i], &r);
        for j in 0..n {
            r[j] = r[j] + (alpha[i] - beta) * s_hist[i][j];
        }
    }

    // Negate to turn r = H_k * grad into a descent direction.
    for v in r.iter_mut() {
        *v = -*v;
    }

    r
}

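/// Euclidean norm of the step `alpha * d`, computed without materializing the
/// step vector.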
fn norm_step<F: Float>(alpha: F, d: &[F]) -> F {
    let mut s = F::zero();
    for &di in d {
        let step = alpha * di;
        s = s + step * step;
    }
    s.sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    struct Rosenbrock;

    impl Objective<f64> for Rosenbrock {
        fn dim(&self) -> usize {
            2
        }

        fn eval_grad(&mut self, x: &[f64]) -> (f64, Vec<f64>) {
            let a = 1.0 - x[0];
            let b = x[1] - x[0] * x[0];
            let f = a * a + 100.0 * b * b;
            let g0 = -2.0 * a - 400.0 * x[0] * b;
            let g1 = 200.0 * b;
            (f, vec![g0, g1])
        }
    }

    #[test]
    fn lbfgs_rosenbrock() {
        let mut obj = Rosenbrock;
        let config = LbfgsConfig::default();
        let result = lbfgs(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert!(
            (result.x[0] - 1.0).abs() < 1e-6,
            "x[0] = {}, expected 1.0",
            result.x[0]
        );
        assert!(
            (result.x[1] - 1.0).abs() < 1e-6,
            "x[1] = {}, expected 1.0",
            result.x[1]
        );
        assert!(result.gradient_norm < 1e-8);
    }

    #[test]
    fn lbfgs_already_converged() {
        let mut obj = Rosenbrock;
        let config = LbfgsConfig::default();
        let result = lbfgs(&mut obj, &[1.0, 1.0], &config);

        assert_eq!(result.termination, TerminationReason::GradientNorm);
        assert_eq!(result.iterations, 0);
    }
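
    // Extra edge-case sketches against the code paths above: a zero memory
    // (or zero iteration budget) is rejected up front, and with no stored
    // curvature pairs the two-loop recursion reduces to steepest descent.
    #[test]
    fn lbfgs_rejects_zero_memory() {
        let mut obj = Rosenbrock;
        let mut config = LbfgsConfig::default();
        config.memory = 0;
        let result = lbfgs(&mut obj, &[0.0, 0.0], &config);

        assert_eq!(result.termination, TerminationReason::NumericalError);
        assert_eq!(result.iterations, 0);
        assert_eq!(result.func_evals, 0);
    }

    #[test]
    fn two_loop_empty_history_is_steepest_descent() {
        // Negation is exact in floating point, so exact equality is safe here.
        let grad = vec![3.0f64, -4.0];
        let d = two_loop_recursion(&grad, &[], &[], &[]);
        assert_eq!(d, vec![-3.0, 4.0]);
    }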
}