//! LAMB optimizer: layer-wise adaptive moments for large-batch training.

use axonml_core;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

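/// Per-parameter optimizer state: Adam-style moment estimates plus a step
/// counter used for bias correction.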
#[derive(Debug, Clone)]
struct LambState {
    /// Exponential moving average of the gradient (first moment, `m_t`).
    exp_avg: Tensor<f32>,
    /// Exponential moving average of the squared gradient (second moment, `v_t`).
    exp_avg_sq: Tensor<f32>,
    /// Number of optimizer steps taken for this parameter.
    step: usize,
}

impl LambState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        // Moment buffers are created on the host; move them to the
        // parameter's device so updates stay device-local.
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

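/// LAMB (Layer-wise Adaptive Moments) optimizer.
///
/// Computes an Adam-style update per parameter, then rescales it by the
/// trust ratio `||w|| / ||update||` so each parameter tensor takes a step
/// proportional to its own weight norm, following You et al., "Large Batch
/// Optimization for Deep Learning: Training BERT in 76 Minutes" (2019).
///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest); assumes `params`
/// was collected from a model elsewhere:
///
/// ```ignore
/// let mut opt = LAMB::new(params, 1e-3).weight_decay(0.01);
/// opt.zero_grad();
/// // ... run the backward pass to populate gradients ...
/// opt.step();
/// ```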
pub struct LAMB {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Base learning rate, scaled per-layer by the trust ratio.
    lr: f32,
    /// Decay rate for the first-moment estimate.
    beta1: f32,
    /// Decay rate for the second-moment estimate.
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    eps: f32,
    /// Weight-decay coefficient applied to the update direction.
    weight_decay: f32,
    /// Whether to apply Adam-style bias correction to the moments.
    bias_correction: bool,
    /// Per-parameter state, lazily initialized on the first `step()`.
    state: Vec<LambState>,
}

impl LAMB {
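    /// Creates a LAMB optimizer with the given learning rate and defaults:
    /// betas `(0.9, 0.999)`, eps `1e-6`, no weight decay, bias correction on.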
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

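    /// Creates a LAMB optimizer with custom `(beta1, beta2)` coefficients.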
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

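    /// Creates a LAMB optimizer with full control over betas, epsilon, and
    /// weight decay.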
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

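    /// Sets the moment decay coefficients (builder style).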
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

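    /// Sets the numerical-stability epsilon (builder style).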
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

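    /// Sets the weight-decay coefficient (builder style).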
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

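    /// Enables or disables bias correction of the moment estimates.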
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

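    /// Lazily allocates per-parameter state on the first call to `step()`.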
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    LambState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
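            // Skip parameters that are frozen or have no gradient yet.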
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

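            // First moment: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t.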
            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta1)
                .add(&grad.mul_scalar(1.0 - self.beta1))
                .unwrap();

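            // Second moment: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2.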
            let grad_sq = grad.mul(&grad).unwrap();
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta2)
                .add(&grad_sq.mul_scalar(1.0 - self.beta2))
                .unwrap();

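            // Adam-style bias correction compensates for the zero-initialized
            // moments during the first few steps.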
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            let m_hat = state.exp_avg.mul_scalar(1.0 / bias_correction1);
            let v_hat = state.exp_avg_sq.mul_scalar(1.0 / bias_correction2);

            // Adam direction: m_hat / (sqrt(v_hat) + eps).
            let adam_update = m_hat.div(&v_hat.sqrt().add_scalar(self.eps)).unwrap();

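            // Weight decay is added to the Adam direction (AdamW style), so it
            // is scaled by the trust ratio below, as in the LAMB paper.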
            let update = if self.weight_decay > 0.0 {
                adam_update
                    .add(&param_data.mul_scalar(self.weight_decay))
                    .unwrap()
            } else {
                adam_update
            };

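            // Layer-wise trust ratio: ||w|| / ||update||. Falls back to 1.0
            // when either norm is zero (e.g. zero-initialized weights).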
            let weight_norm_sq = param_data.mul(&param_data).unwrap().sum();
            let update_norm_sq = update.mul(&update).unwrap().sum();

            let weight_norm = weight_norm_sq.to_vec()[0].sqrt();
            let update_norm = update_norm_sq.to_vec()[0].sqrt();

            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

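            // w_{t+1} = w_t - lr * trust_ratio * update.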
            let effective_lr = self.lr * trust_ratio;
            let new_param = param_data.sub(&update.mul_scalar(effective_lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        // The parameter must have moved away from its initial value.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        // Weights [3, 4] have L2 norm 5, so the trust ratio is well defined.
        let var = Variable::new(
            Tensor::from_vec(vec![3.0, 4.0], &[2]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Both coordinates must move under a nonzero trust-ratio-scaled step.
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        // Clearing gradients must not panic.
        optimizer.zero_grad();
    }

    #[test]
    fn test_l2_norm_via_tensor() {
        let t = Tensor::from_vec(vec![3.0f32, 4.0], &[2]).expect("tensor creation failed");
        let norm_sq = t.mul(&t).unwrap().sum();
        let norm = norm_sq.to_vec()[0].sqrt();
        assert!((norm - 5.0).abs() < 1e-6);
    }
}