// optirs_core/optimizers/lamb.rs

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::optimizers::Optimizer;

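/// LAMB (Layer-wise Adaptive Moments optimizer for Batch training).
///
/// Combines Adam-style first/second moment estimates with a per-layer trust
/// ratio, `||w|| / ||update||`, that rescales the learning rate for each
/// parameter tensor (You et al., 2019, "Large Batch Optimization for Deep
/// Learning: Training BERT in 76 Minutes").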
#[derive(Debug, Clone)]
pub struct LAMB<A: Float + ScalarOperand + Debug> {
    /// Base learning rate
    learning_rate: A,
    /// Exponential decay rate for the first moment estimate
    beta1: A,
    /// Exponential decay rate for the second moment estimate
    beta2: A,
    /// Small constant added to the denominator for numerical stability
    epsilon: A,
    /// Weight decay coefficient (zero disables weight decay)
    weight_decay: A,
    /// Whether to apply Adam-style bias correction to the moment estimates
    bias_correction: bool,
    /// First moment estimates (initialized lazily on the first step)
    m: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Second moment estimates (initialized lazily on the first step)
    v: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
    /// Timestep counter used for bias correction
    t: usize,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> LAMB<A> {
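    /// Creates a LAMB optimizer with the given learning rate and default
    /// hyperparameters (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-6,
    /// no weight decay, bias correction enabled).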
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-6).unwrap(),
            weight_decay: A::zero(),
            bias_correction: true,
            m: None,
            v: None,
            t: 0,
        }
    }

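    /// Creates a LAMB optimizer with explicit values for every hyperparameter.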
    pub fn new_with_config(
        learning_rate: A,
        beta1: A,
        beta2: A,
        epsilon: A,
        weight_decay: A,
        bias_correction: bool,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            bias_correction,
            m: None,
            v: None,
            t: 0,
        }
    }
118
119 pub fn set_beta1(&mut self, beta1: A) -> &mut Self {
121 self.beta1 = beta1;
122 self
123 }
124
125 pub fn get_beta1(&self) -> A {
127 self.beta1
128 }
129
130 pub fn set_beta2(&mut self, beta2: A) -> &mut Self {
132 self.beta2 = beta2;
133 self
134 }
135
136 pub fn get_beta2(&self) -> A {
138 self.beta2
139 }
140
141 pub fn set_epsilon(&mut self, epsilon: A) -> &mut Self {
143 self.epsilon = epsilon;
144 self
145 }
146
147 pub fn get_epsilon(&self) -> A {
149 self.epsilon
150 }
151
152 pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
154 self.weight_decay = weight_decay;
155 self
156 }
157
158 pub fn get_weight_decay(&self) -> A {
160 self.weight_decay
161 }
162
163 pub fn learning_rate(&self) -> A {
165 self.learning_rate
166 }
167
168 pub fn set_lr(&mut self, lr: A) {
170 self.learning_rate = lr;
171 }
172
173 pub fn reset(&mut self) {
175 self.m = None;
176 self.v = None;
177 self.t = 0;
178 }
179}
180
impl<A, D> Optimizer<A, D> for LAMB<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
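    /// Performs a single LAMB update and returns the new parameter values.
    ///
    /// The update follows:
    ///   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    ///   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    ///   update = m_hat / (sqrt(v_hat) + epsilon) + weight_decay * w
    ///   trust_ratio = ||w|| / ||update||   (1 if either norm is zero)
    ///   w <- w - learning_rate * trust_ratio * update
    ///
    /// where m_hat and v_hat are the bias-corrected moments when
    /// `bias_correction` is enabled, and the raw moments otherwise.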
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        // Lazily initialize the optimizer state on the first step.
        if self.m.is_none() {
            self.m = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.v = Some(vec![Array::zeros(params_dyn.raw_dim())]);
            self.t = 0;
        }

        let m = self.m.as_mut().unwrap();
        let v = self.v.as_mut().unwrap();

        // Re-initialize the state if it is empty or the parameter shape has changed.
        if m.is_empty() {
            m.push(Array::zeros(params_dyn.raw_dim()));
            v.push(Array::zeros(params_dyn.raw_dim()));
        } else if m[0].raw_dim() != params_dyn.raw_dim() {
            m[0] = Array::zeros(params_dyn.raw_dim());
            v[0] = Array::zeros(params_dyn.raw_dim());
        }

        self.t += 1;

        // Update the biased first and second moment estimates.
        m[0] = &m[0] * self.beta1 + &gradients_dyn * (A::one() - self.beta1);
        v[0] = &v[0] * self.beta2 + &(&gradients_dyn * &gradients_dyn * (A::one() - self.beta2));

        // Optionally apply Adam-style bias correction.
        let (m_hat, v_hat) = if self.bias_correction {
            let bias1 = A::one() - self.beta1.powi(self.t as i32);
            let bias2 = A::one() - self.beta2.powi(self.t as i32);
            (&m[0] / bias1, &v[0] / bias2)
        } else {
            (m[0].clone(), v[0].clone())
        };

        // Adam-style adaptive update direction.
        let v_hat_sqrt = v_hat.mapv(|x| x.sqrt());
        let adaptive_term = &m_hat / &(&v_hat_sqrt + self.epsilon);

        // Add weight decay to the update direction if enabled.
        let normalized_gradient = if self.weight_decay > A::zero() {
            &adaptive_term + &(&params_dyn * self.weight_decay)
        } else {
            adaptive_term
        };

        // L2 norms of the parameters and of the update direction.
        let weight_norm = {
            let norm_sq = params_dyn
                .iter()
                .map(|x| *x * *x)
                .fold(A::zero(), |acc, x| acc + x);
            norm_sq.sqrt()
        };
        let gradient_norm = {
            let norm_sq = normalized_gradient
                .iter()
                .map(|x| *x * *x)
                .fold(A::zero(), |acc, x| acc + x);
            norm_sq.sqrt()
        };

        // Layer-wise trust ratio; fall back to 1 when either norm is zero.
        let trust_ratio = if weight_norm > A::zero() && gradient_norm > A::zero() {
            weight_norm / gradient_norm
        } else {
            A::one()
        };

        let step = &normalized_gradient * (self.learning_rate * trust_ratio);
        let updated_params = &params_dyn - step;

        Ok(updated_params.into_dimensionality::<D>().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
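
// A minimal usage sketch (illustrative only; the variable names below are
// hypothetical, and `Array1` is the same ndarray type used in the tests):
//
//     let mut opt: LAMB<f64> = LAMB::new(0.01);
//     let params = Array1::from_vec(vec![1.0, -2.0]);
//     let gradients = Array1::from_vec(vec![0.3, -0.1]);
//     let updated = opt.step(&params, &gradients)?;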

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_lamb_basic_creation() {
        let optimizer: LAMB<f64> = LAMB::new(0.001);
        assert_abs_diff_eq!(optimizer.learning_rate(), 0.001);
        assert_abs_diff_eq!(optimizer.get_beta1(), 0.9);
        assert_abs_diff_eq!(optimizer.get_beta2(), 0.999);
        assert_abs_diff_eq!(optimizer.get_epsilon(), 1e-6);
        assert_abs_diff_eq!(optimizer.get_weight_decay(), 0.0);
        assert!(optimizer.bias_correction);
    }

    #[test]
    fn test_lamb_convergence() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);

        // Minimize f(x) = x1^2 + x2^2; the gradient is 2 * x.
        let mut params = Array1::from_vec(vec![5.0, 3.0]);

        for _ in 0..50 {
            let gradients = Array1::from_vec(vec![2.0 * params[0], 2.0 * params[1]]);
            params = optimizer.step(&params, &gradients).unwrap();
        }

        assert!(params[0].abs() < 1.0);
        assert!(params[1].abs() < 1.0);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let mut optimizer: LAMB<f64> = LAMB::new_with_config(0.1, 0.9, 0.999, 1e-6, 0.1, true);

        let mut params = Array1::from_vec(vec![1.0, 1.0]);

        for _ in 0..20 {
            let gradients = Array1::from_vec(vec![0.1, 0.1]);
            params = optimizer.step(&params, &gradients).unwrap();
        }

        assert!(params[0] < 1.0);
        assert!(params[1] < 1.0);
    }

    #[test]
    fn test_lamb_reset() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);

        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![0.5]);
        let _ = optimizer.step(&params, &gradients).unwrap();

        assert!(optimizer.m.is_some());
        assert!(optimizer.v.is_some());
        assert_eq!(optimizer.t, 1);

        optimizer.reset();

        assert!(optimizer.m.is_none());
        assert!(optimizer.v.is_none());
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        let mut optimizer: LAMB<f64> = LAMB::new(0.1);
        let params = Array1::from_vec(vec![2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.4, 0.6]);

        let new_params = optimizer.step(&params, &gradients).unwrap();

        // The update must change the parameters and move them toward zero.
        assert_ne!(new_params[0], params[0]);
        assert_ne!(new_params[1], params[1]);
        assert!(new_params[0] < params[0]);
        assert!(new_params[1] < params[1]);
    }
}