use axonml_core::Device;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

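/// Per-parameter optimizer state: exponential moving averages of the
/// gradient (first moment) and of the squared gradient (second moment),
/// plus the number of steps taken for this parameter.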
#[derive(Debug, Clone)]
struct LambState {
    exp_avg: Tensor<f32>,
    exp_avg_sq: Tensor<f32>,
    step: usize,
}

impl LambState {
    fn new(shape: &[usize], device: Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        // Keep optimizer state on the same device as the parameter.
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).unwrap();
            exp_avg_sq = exp_avg_sq.to_device(device).unwrap();
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

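/// LAMB (Layer-wise Adaptive Moments for Batch training) optimizer, after
/// You et al., "Large Batch Optimization for Deep Learning: Training BERT
/// in 76 Minutes".
///
/// LAMB computes an Adam-style update per parameter, then rescales it by a
/// layer-wise trust ratio `||w|| / ||update||`, so each layer takes a step
/// proportional to its own weight norm. This keeps large-batch training
/// stable where plain Adam tends to diverge.
///
/// A minimal usage sketch (import paths assumed; mirrors the tests below):
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = LAMB::new(vec![param], 0.001)
///     .betas(0.9, 0.999)
///     .weight_decay(0.01);
/// // ... backward pass populates gradients ...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```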
pub struct LAMB {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first moment estimates.
    beta1: f32,
    /// Exponential decay rate for the second moment estimates.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay coefficient.
    weight_decay: f32,
    /// Whether to bias-correct the moment estimates.
    bias_correction: bool,
    /// Per-parameter state, created lazily on the first `step`.
    state: Vec<LambState>,
}

impl LAMB {
    /// Creates a LAMB optimizer with the given parameters and learning
    /// rate, using the defaults `betas = (0.9, 0.999)`, `eps = 1e-6`,
    /// `weight_decay = 0.0`, and bias correction enabled.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates a LAMB optimizer with custom `(beta1, beta2)` coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates a LAMB optimizer with every hyperparameter specified.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Sets the moment decay coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Sets the numerical-stability epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables bias correction (builder style).
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

    /// Lazily allocates per-parameter state on the first `step`, matching
    /// each parameter's shape and device.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    LambState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

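// The per-parameter update implemented in `step`, written out for
// reference (w: weights, g: gradient, t: step count):
//
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
//   m_hat = m_t / (1 - beta1^t)    v_hat = v_t / (1 - beta2^t)
//   r = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
//   w <- w - lr * (||w|| / ||r||) * r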
impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            // Skip frozen parameters and parameters that have no gradient.
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

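            // Biased first moment: m = beta1 * m + (1 - beta1) * g.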
            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta1)
                .add(&grad.mul_scalar(1.0 - self.beta1))
                .unwrap();

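            // Biased second moment: v = beta2 * v + (1 - beta2) * g^2.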
            let grad_sq = grad.mul(&grad).unwrap();
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta2)
                .add(&grad_sq.mul_scalar(1.0 - self.beta2))
                .unwrap();

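            // Bias-correct the moments as in Adam; with correction disabled
            // the raw moving averages are used unchanged.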
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            let m_hat = state.exp_avg.mul_scalar(1.0 / bias_correction1);
            let v_hat = state.exp_avg_sq.mul_scalar(1.0 / bias_correction2);

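            // Adam-style direction: m_hat / (sqrt(v_hat) + eps).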
            let adam_update = m_hat.div(&v_hat.sqrt().add_scalar(self.eps)).unwrap();

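            // Weight decay is added to the update term itself (decoupled
            // from the gradient moments), so it is also scaled by the
            // trust ratio below.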
            let update = if self.weight_decay > 0.0 {
                adam_update
                    .add(&param_data.mul_scalar(self.weight_decay))
                    .unwrap()
            } else {
                adam_update
            };

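            // Layer-wise trust ratio ||w|| / ||update||; fall back to 1.0
            // when either norm is zero (e.g. zero-initialized weights),
            // which reduces this step to plain Adam.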
            let weight_norm_sq = param_data.mul(&param_data).unwrap().sum();
            let update_norm_sq = update.mul(&update).unwrap().sum();

            let weight_norm = weight_norm_sq.to_vec()[0].sqrt();
            let update_norm = update_norm_sq.to_vec()[0].sqrt();

            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

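            // Apply the update: w <- w - lr * trust_ratio * update.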
            let effective_lr = self.lr * trust_ratio;
            let new_param = param_data.sub(&update.mul_scalar(effective_lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        // The parameter should have moved away from its initial value.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        // Weights [3, 4] have L2 norm 5, so the trust ratio is nontrivial.
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Both components should move under the trust-ratio-scaled update.
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        optimizer.zero_grad();
    }
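
    // Additional check (uses only APIs exercised above): set_lr should
    // update the stored learning rate returned by get_lr.
    #[test]
    fn test_lamb_set_lr() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut optimizer = LAMB::new(vec![param], 0.001);

        optimizer.set_lr(0.01);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
    }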

    #[test]
    fn test_l2_norm_via_tensor() {
        // Sanity check of the norm computation used by the trust ratio:
        // ||[3, 4]||_2 = 5.
        let t = Tensor::from_vec(vec![3.0f32, 4.0], &[2]).unwrap();
        let norm_sq = t.mul(&t).unwrap().sum();
        let norm = norm_sq.to_vec()[0].sqrt();
        assert!((norm - 5.0).abs() < 1e-6);
    }
}