use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

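/// Per-parameter LAMB state: exponential moving averages of the gradient and
/// of its square (first and second moments), plus the number of steps taken.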
#[derive(Debug, Clone)]
struct LambState {
    exp_avg: Vec<f32>,
    exp_avg_sq: Vec<f32>,
    step: usize,
}

impl LambState {
    fn new(size: usize) -> Self {
        Self {
            exp_avg: vec![0.0; size],
            exp_avg_sq: vec![0.0; size],
            step: 0,
        }
    }
}

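/// LAMB optimizer: layer-wise adaptive moments for large-batch training.
///
/// LAMB follows the Adam update (exponential moving averages of the gradient
/// and of its square) and then rescales each parameter's update by a trust
/// ratio, the ratio of the parameter norm to the update norm. This layer-wise
/// scaling keeps updates proportionate to the weights they modify, which is
/// what makes very large batch sizes practical (You et al., 2019).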
pub struct LAMB {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Base learning rate.
    lr: f32,
    /// Exponential decay rate for the first moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second moment estimate.
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    eps: f32,
    /// Weight-decay coefficient; the decay term is added to the update direction.
    weight_decay: f32,
    /// Whether to apply Adam-style bias correction to the moment estimates.
    bias_correction: bool,
    /// Per-parameter optimizer state, allocated lazily on the first call to `step`.
    state: Vec<LambState>,
}

impl LAMB {
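    /// Creates a LAMB optimizer for `params` with learning rate `lr` and
    /// default hyperparameters: betas = (0.9, 0.999), eps = 1e-6,
    /// no weight decay, bias correction enabled.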
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

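    /// Creates a LAMB optimizer with custom `(beta1, beta2)` coefficients;
    /// other hyperparameters keep their defaults.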
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

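    /// Creates a LAMB optimizer with explicit betas, epsilon, and weight decay.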
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

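    /// Sets the `(beta1, beta2)` moment-decay coefficients (builder style).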
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

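    /// Sets the numerical-stability epsilon (builder style).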
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

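    /// Sets the weight-decay coefficient (builder style).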
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

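    /// Enables or disables bias correction of the moment estimates (builder style).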
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

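    /// Allocates per-parameter moment buffers on the first call to `step`.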
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| LambState::new(p.numel()))
                .collect();
        }
    }

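    /// Computes the Euclidean (L2) norm of a slice.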
    fn l2_norm(vec: &[f32]) -> f32 {
        vec.iter().map(|x| x * x).sum::<f32>().sqrt()
    }
}

impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let param_vec = param_data.to_vec();

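            // Update the biased first moment estimate: m = beta1 * m + (1 - beta1) * g.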
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

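            // Update the biased second moment estimate: v = beta2 * v + (1 - beta2) * g^2.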
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

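            // Bias-correction factors 1 - beta^t (both 1.0 when correction is disabled).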
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

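            // Adam-style update direction: m_hat / (sqrt(v_hat) + eps).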
            let mut update: Vec<f32> = state
                .exp_avg
                .iter()
                .zip(state.exp_avg_sq.iter())
                .map(|(m, v)| {
                    let m_hat = m / bias_correction1;
                    let v_hat = v / bias_correction2;
                    m_hat / (v_hat.sqrt() + self.eps)
                })
                .collect();

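            // Weight decay is added to the update direction, so it is also
            // scaled by the trust ratio below.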
            if self.weight_decay > 0.0 {
                for (u, p) in update.iter_mut().zip(param_vec.iter()) {
                    *u += self.weight_decay * p;
                }
            }

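            // Layer-wise trust ratio ||w|| / ||update||, falling back to 1.0
            // when either norm is zero.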
            let weight_norm = Self::l2_norm(&param_vec);
            let update_norm = Self::l2_norm(&update);

            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

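            // Apply the scaled update: w = w - lr * trust_ratio * update.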
            let effective_lr = self.lr * trust_ratio;
            let new_data: Vec<f32> = param_vec
                .iter()
                .zip(update.iter())
                .map(|(p, u)| p - effective_lr * u)
                .collect();

            let new_tensor = Tensor::from_vec(new_data, param_data.shape()).unwrap();
            param.update_data(new_tensor);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        optimizer.zero_grad();
    }

    #[test]
    fn test_l2_norm() {
        let vec = vec![3.0, 4.0];
        let norm = LAMB::l2_norm(&vec);
        assert!((norm - 5.0).abs() < 1e-6);
    }
}