//! LAMB optimizer: layer-wise adaptive large-batch optimization.

use axonml_core::Device;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

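/// Per-parameter LAMB state: exponential moving averages of the gradient and
/// of the squared gradient, plus the step count used for bias correction.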
#[derive(Debug, Clone)]
struct LambState {
    /// First moment estimate (EMA of gradients).
    exp_avg: Tensor<f32>,
    /// Second moment estimate (EMA of squared gradients).
    exp_avg_sq: Tensor<f32>,
    /// Number of steps taken for this parameter.
    step: usize,
}

impl LambState {
    /// Creates zero-initialized moment buffers of the given shape, transferring
    /// them to `device` when it is a GPU device.
    fn new(shape: &[usize], device: Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

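/// LAMB (Layer-wise Adaptive Moments for Batch training) optimizer, after
/// You et al., "Large Batch Optimization for Deep Learning: Training BERT
/// in 76 Minutes" (2019).
///
/// Each step computes an Adam-style update per parameter tensor and rescales
/// it by a layer-wise trust ratio:
///
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
/// r_t = m_hat_t / (sqrt(v_hat_t) + eps)   // bias-corrected Adam direction
/// u_t = r_t + weight_decay * w            // decoupled weight decay
/// w  -= lr * (||w|| / ||u_t||) * u_t      // trust ratio; 1.0 if a norm is 0
/// ```
///
/// # Example
///
/// A minimal usage sketch, assuming `params` holds the `Parameter`s of a model
/// and a backward pass has already populated their gradients:
///
/// ```ignore
/// let mut optimizer = LAMB::new(params, 1e-3).weight_decay(0.01);
/// optimizer.step();
/// optimizer.zero_grad();
/// ```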
pub struct LAMB {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second moment estimate.
    beta2: f32,
    /// Term added to the denominator for numerical stability.
    eps: f32,
    /// Decoupled weight-decay coefficient.
    weight_decay: f32,
    /// Whether to bias-correct the moment estimates.
    bias_correction: bool,
    /// Per-parameter state, created lazily on the first call to `step()`.
    state: Vec<LambState>,
}

impl LAMB {
    /// Creates a LAMB optimizer with defaults: `beta1 = 0.9`, `beta2 = 0.999`,
    /// `eps = 1e-6`, no weight decay, bias correction enabled.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates a LAMB optimizer with custom `(beta1, beta2)` coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates a LAMB optimizer with explicit betas, epsilon, and weight decay;
    /// bias correction is enabled.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Sets the `(beta1, beta2)` coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Sets the numerical-stability epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the decoupled weight-decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables bias correction of the moment estimates (builder style).
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

    /// Lazily creates one `LambState` per parameter on the first call to `step()`.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    LambState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

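        // Each parameter tensor is treated as its own "layer": Adam-style
        // moments are updated per element, then the step is rescaled by the
        // per-tensor trust ratio ||w|| / ||update||.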
        for (i, param) in self.params.iter().enumerate() {
            // Skip frozen parameters and parameters without a gradient.
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // First moment: m = beta1 * m + (1 - beta1) * g.
            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta1)
                .add(&grad.mul_scalar(1.0 - self.beta1))
                .unwrap();

            // Second moment: v = beta2 * v + (1 - beta2) * g^2.
            let grad_sq = grad.mul(&grad).unwrap();
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta2)
                .add(&grad_sq.mul_scalar(1.0 - self.beta2))
                .unwrap();

            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            let m_hat = state.exp_avg.mul_scalar(1.0 / bias_correction1);
            let v_hat = state.exp_avg_sq.mul_scalar(1.0 / bias_correction2);

            // Adam-style direction: m_hat / (sqrt(v_hat) + eps).
            let adam_update = m_hat.div(&v_hat.sqrt().add_scalar(self.eps)).unwrap();

            // Decoupled weight decay is added to the update, not the gradient.
            let update = if self.weight_decay > 0.0 {
                adam_update
                    .add(&param_data.mul_scalar(self.weight_decay))
                    .unwrap()
            } else {
                adam_update
            };

            // Layer-wise trust ratio: ||w|| / ||update||, falling back to 1.0
            // when either norm is zero.
            let weight_norm_sq = param_data.mul(&param_data).unwrap().sum();
            let update_norm_sq = update.mul(&update).unwrap().sum();

            let weight_norm = weight_norm_sq.to_vec()[0].sqrt();
            let update_norm = update_norm_sq.to_vec()[0].sqrt();

            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

            let effective_lr = self.lr * trust_ratio;
            let new_param = param_data.sub(&update.mul_scalar(effective_lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // The first element should have moved away from its initial value.
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        let var = Variable::new(
            Tensor::from_vec(vec![3.0, 4.0], &[2]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Both elements should have been updated.
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        // zero_grad() should run without panicking on parameters with gradients.
        optimizer.zero_grad();
    }

    #[test]
    fn test_l2_norm_via_tensor() {
        let t = Tensor::from_vec(vec![3.0f32, 4.0], &[2]).expect("tensor creation failed");
        let norm_sq = t.mul(&t).unwrap().sum();
        let norm = norm_sq.to_vec()[0].sqrt();
        assert!((norm - 5.0).abs() < 1e-6);
    }
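
    // A hand-computed check of the first LAMB step (a sketch grounded in the
    // update rule above): with w = [3, 4], g = [1, 1], lr = 0.1, and bias
    // correction on, m_hat = g and v_hat = g^2 after one step, so the Adam
    // direction is ~[1, 1], the trust ratio is ||w|| / ||u|| = 5 / sqrt(2),
    // and each element moves by ~0.1 * 5 / sqrt(2) ≈ 0.3536.
    #[test]
    fn test_lamb_first_step_matches_hand_computation() {
        let var = Variable::new(
            Tensor::from_vec(vec![3.0, 4.0], &[2]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let expected_step = 0.1 * 5.0 / 2.0f32.sqrt();
        let new_data = param.data().to_vec();
        assert!((new_data[0] - (3.0 - expected_step)).abs() < 1e-3);
        assert!((new_data[1] - (4.0 - expected_step)).abs() < 1e-3);
    }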
}