optirs_core/optimizers/lbfgs.rs

use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::VecDeque;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Limited-memory BFGS (L-BFGS) optimizer using the standard two-loop recursion
/// over a bounded history of gradient differences and parameter steps.
#[derive(Debug, Clone)]
pub struct LBFGS<A: Float + ScalarOperand + Debug> {
    /// Step size applied to the computed search direction
    learning_rate: A,
    /// Maximum number of (y, s) pairs kept in the history
    history_size: usize,
    /// Gradient-norm threshold below which no update is performed
    tolerance_grad: A,
    /// Sufficient-decrease line-search constant (currently unused)
    #[allow(dead_code)]
    c1: A,
    /// Curvature line-search constant (currently unused)
    #[allow(dead_code)]
    c2: A,
    /// Maximum number of line-search iterations (currently unused)
    #[allow(dead_code)]
    max_ls: usize,
    /// Stored gradient differences y_k = g_{k+1} - g_k
    old_dirs: VecDeque<Array1<A>>,
    /// Stored parameter steps s_k
    old_stps: VecDeque<Array1<A>>,
    /// Stored curvature reciprocals rho_k = 1 / (y_k . s_k)
    ro: VecDeque<A>,
    /// Gradient from the previous step
    prev_grad: Option<Array1<A>>,
    /// Scaling of the initial inverse-Hessian approximation
    h_diag: A,
    /// Number of completed steps
    n_iter: usize,
    /// Scratch buffer for the two-loop recursion coefficients
    alpha: Vec<A>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> LBFGS<A> {
    /// Creates a new L-BFGS optimizer with default settings.
    pub fn new(learning_rate: A) -> Self {
        Self::new_with_config(
            learning_rate,
            100,                    // history_size
            A::from(1e-7).unwrap(), // tolerance_grad
            A::from(1e-4).unwrap(), // c1
            A::from(0.9).unwrap(),  // c2
            25,                     // max_ls
        )
    }

    /// Creates a new L-BFGS optimizer with a custom configuration.
    pub fn new_with_config(
        learning_rate: A,
        history_size: usize,
        tolerance_grad: A,
        c1: A,
        c2: A,
        max_ls: usize,
    ) -> Self {
        Self {
            learning_rate,
            history_size,
            tolerance_grad,
            c1,
            c2,
            max_ls,
            old_dirs: VecDeque::with_capacity(history_size),
            old_stps: VecDeque::with_capacity(history_size),
            ro: VecDeque::with_capacity(history_size),
            prev_grad: None,
            h_diag: A::one(),
            n_iter: 0,
            alpha: vec![A::zero(); history_size],
        }
    }
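
    // For illustration (a sketch; the values below are arbitrary, not recommended defaults):
    //
    //     let opt: LBFGS<f64> = LBFGS::new_with_config(
    //         1.0,   // learning_rate
    //         10,    // history_size
    //         1e-5,  // tolerance_grad
    //         1e-4,  // c1
    //         0.9,   // c2
    //         25,    // max_ls
    //     );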

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the learning rate.
    pub fn set_lr(&mut self, lr: A) {
        self.learning_rate = lr;
    }

    /// Clears all accumulated curvature history and iteration state.
    pub fn reset(&mut self) {
        self.old_dirs.clear();
        self.old_stps.clear();
        self.ro.clear();
        self.prev_grad = None;
        self.h_diag = A::one();
        self.n_iter = 0;
        self.alpha.fill(A::zero());
    }

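    // `compute_direction` implements the standard L-BFGS two-loop recursion.
    // With q = -g and, for each stored pair (y_i, s_i), rho_i = 1 / (y_i . s_i):
    //
    //   first loop (newest to oldest):  alpha_i = rho_i * (s_i . q);  q -= alpha_i * y_i
    //   scale by the initial guess:     r = h_diag * q
    //   second loop (oldest to newest): beta = rho_i * (y_i . r);  r += (alpha_i - beta) * s_i
    //
    // Because q starts at -g, the returned r is already a descent direction.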
    fn compute_direction(&mut self, gradient: &Array1<A>) -> Array1<A> {
        // No curvature information yet: fall back to steepest descent.
        if self.n_iter == 0 {
            return gradient.mapv(|x| -x);
        }

        let num_old = self.old_dirs.len();

        let mut q = gradient.mapv(|x| -x);

        // First loop: newest to oldest stored pair.
        for i in (0..num_old).rev() {
            self.alpha[i] = self.old_stps[i].dot(&q) * self.ro[i];
            q = &q - &self.old_dirs[i] * self.alpha[i];
        }

        // Scale by the initial inverse-Hessian approximation.
        let mut r = q * self.h_diag;

        // Second loop: oldest to newest stored pair.
        for i in 0..num_old {
            let beta = self.old_dirs[i].dot(&r) * self.ro[i];
            r = &r + &self.old_stps[i] * (self.alpha[i] - beta);
        }

        r
    }

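    // A (y, s) pair is accepted only when the curvature condition y . s > 0 holds
    // (checked against a small positive threshold), which keeps the implicit
    // inverse-Hessian approximation positive definite. Each accepted pair also
    // rescales the initial approximation via h_diag = (y . s) / (y . y).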
    fn update_history(&mut self, y: Array1<A>, s: Array1<A>) {
        let ys = y.dot(&s);

        // Skip pairs that violate the curvature condition.
        if ys > A::from(1e-10).unwrap() {
            // Drop the oldest pair once the history is full.
            if self.old_dirs.len() >= self.history_size {
                self.old_dirs.pop_front();
                self.old_stps.pop_front();
                self.ro.pop_front();
            }

            self.old_dirs.push_back(y.clone());
            self.old_stps.push_back(s);
            self.ro.push_back(A::one() / ys);

            // Rescale the initial inverse-Hessian approximation.
            let yy = y.dot(&y);
            if yy > A::zero() {
                self.h_diag = ys / yy;
            }
        }
    }
}

impl<A, D> Optimizer<A, D> for LBFGS<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Work on flattened 1-D copies of the parameters and gradients.
        let params_flat = params
            .to_owned()
            .into_shape_with_order(params.len())
            .unwrap();
        let gradients_flat = gradients
            .to_owned()
            .into_shape_with_order(gradients.len())
            .unwrap();

        // If the gradient norm is already below the tolerance, do nothing.
        let grad_norm = gradients_flat.dot(&gradients_flat).sqrt();
        if grad_norm <= self.tolerance_grad {
            return Ok(params.clone());
        }

        // Reconstruct the previous step from the stored gradient and record the
        // (y, s) curvature pair before computing the new direction.
        if let Some(prev_grad) = self.prev_grad.clone() {
            let y = &gradients_flat - &prev_grad;

            if self.n_iter > 0 {
                let direction = self.compute_direction(&prev_grad);
                let step_size = if self.n_iter == 1 {
                    self.learning_rate / (A::one() + grad_norm)
                } else {
                    self.learning_rate
                };
                let s = direction * step_size;
                self.update_history(y, s);
            }
        }

        // Compute the search direction for the current gradient.
        let direction = self.compute_direction(&gradients_flat);

        // Damp the very first step, where no curvature information exists yet.
        let step_size = if self.n_iter == 0 {
            self.learning_rate / (A::one() + grad_norm)
        } else {
            self.learning_rate
        };

        let new_params_flat = &params_flat + &(&direction * step_size);

        self.prev_grad = Some(gradients_flat.clone());
        self.n_iter += 1;

        // Restore the original shape of the parameter array.
        let new_params = new_params_flat
            .into_shape_with_order(params.raw_dim())
            .unwrap();

        Ok(new_params)
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}

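// Usage sketch. The import paths below assume the crate is named `optirs_core` and
// re-exports `LBFGS` and `Optimizer` from `optirs_core::optimizers`; adjust to the
// actual public API.
//
//     use optirs_core::optimizers::{LBFGS, Optimizer};
//     use scirs2_core::ndarray::Array1;
//
//     // Minimize f(x) = x^2 starting from x = 10.
//     let mut opt: LBFGS<f64> = LBFGS::new(0.1);
//     let mut params = Array1::from_vec(vec![10.0_f64]);
//     for _ in 0..50 {
//         let grads = params.mapv(|x| 2.0 * x); // gradient of f
//         params = opt.step(&params, &grads).unwrap();
//     }
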
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lbfgs_basic_creation() {
        let optimizer: LBFGS<f64> = LBFGS::new(1.0);
        assert_abs_diff_eq!(optimizer.learning_rate(), 1.0);
        assert_eq!(optimizer.history_size, 100);
        assert_abs_diff_eq!(optimizer.tolerance_grad, 1e-7);
    }

    #[test]
    fn test_lbfgs_convergence() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);

        let mut params = Array1::from_vec(vec![10.0]);

        for _ in 0..50 {
            let gradients = Array1::from_vec(vec![2.0 * params[0]]);
            params = optimizer.step(&params, &gradients).unwrap();
        }

        assert!(params[0].abs() < 0.1);
    }

    #[test]
    fn test_lbfgs_2d() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);

        let mut params = Array1::from_vec(vec![5.0, 3.0]);

        for _ in 0..50 {
            let gradients = Array1::from_vec(vec![2.0 * params[0], 2.0 * params[1]]);
            params = optimizer.step(&params, &gradients).unwrap();
        }

        assert!(params[0].abs() < 0.1);
        assert!(params[1].abs() < 0.1);
    }

    #[test]
    fn test_lbfgs_reset() {
        let mut optimizer: LBFGS<f64> = LBFGS::new(0.1);

        let mut params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![2.0]);
        params = optimizer.step(&params, &gradients).unwrap();

        let gradients2 = Array1::from_vec(vec![1.5]);
        params = optimizer.step(&params, &gradients2).unwrap();

        let gradients3 = Array1::from_vec(vec![1.0]);
        let _ = optimizer.step(&params, &gradients3).unwrap();

        assert!(!optimizer.old_dirs.is_empty());
        assert!(optimizer.n_iter > 0);

        optimizer.reset();

        assert!(optimizer.old_dirs.is_empty());
        assert!(optimizer.old_stps.is_empty());
        assert!(optimizer.ro.is_empty());
        assert!(optimizer.prev_grad.is_none());
        assert_eq!(optimizer.n_iter, 0);
    }
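
    // A sketch of an extra test exercising `new_with_config`; the chosen values are
    // arbitrary and only check that the configuration is stored as given.
    #[test]
    fn test_lbfgs_custom_config() {
        let optimizer: LBFGS<f64> = LBFGS::new_with_config(0.5, 10, 1e-5, 1e-4, 0.9, 20);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.5);
        assert_eq!(optimizer.history_size, 10);
        assert_abs_diff_eq!(optimizer.tolerance_grad, 1e-5);
        assert_eq!(optimizer.max_ls, 20);
    }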
}