//! LAMB optimizer (Layer-wise Adaptive Moments for Batch training).

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;
#[derive(Debug, Clone)]
struct LambState {
    /// First moment estimate (exponential moving average of the gradient).
    exp_avg: Vec<f32>,
    /// Second moment estimate (exponential moving average of the squared gradient).
    exp_avg_sq: Vec<f32>,
    /// Number of optimizer steps taken for this parameter.
    step: usize,
}

impl LambState {
    fn new(size: usize) -> Self {
        Self {
            exp_avg: vec![0.0; size],
            exp_avg_sq: vec![0.0; size],
            step: 0,
        }
    }
}

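/// LAMB (Layer-wise Adaptive Moments for Batch training) optimizer,
/// after You et al., "Large Batch Optimization for Deep Learning:
/// Training BERT in 76 Minutes" (2019).
///
/// LAMB combines Adam-style moment estimates with a layer-wise trust
/// ratio `||w|| / ||update||` that rescales the learning rate per
/// parameter tensor, keeping training stable at very large batch sizes.
///
/// A minimal usage sketch (illustrative only, hence `ignore`; assumes
/// `params` has been collected from a model elsewhere):
///
/// ```ignore
/// let mut opt = LAMB::new(params, 1e-3).weight_decay(0.01);
/// // ... forward and backward pass populate gradients ...
/// opt.step();
/// opt.zero_grad();
/// ```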
pub struct LAMB {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second moment estimate.
    beta2: f32,
    /// Term added to the denominator for numerical stability.
    eps: f32,
    /// Weight-decay coefficient.
    weight_decay: f32,
    /// Whether to apply bias correction to the moment estimates.
    bias_correction: bool,
    /// Per-parameter optimizer state, allocated lazily on the first step.
    state: Vec<LambState>,
}

impl LAMB {
    /// Creates a LAMB optimizer with the default hyper-parameters:
    /// `beta1 = 0.9`, `beta2 = 0.999`, `eps = 1e-6`, no weight decay,
    /// and bias correction enabled.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates a LAMB optimizer with custom `(beta1, beta2)` coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates a LAMB optimizer with fully custom hyper-parameters.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Sets the exponential decay rates for the moment estimates.
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Sets the term added to the denominator for numerical stability.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the weight-decay coefficient.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables bias correction of the moment estimates.
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

    /// Lazily allocates per-parameter state the first time `step` runs.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| LambState::new(p.numel()))
                .collect();
        }
    }

    /// Euclidean (L2) norm of a slice.
    fn l2_norm(vec: &[f32]) -> f32 {
        vec.iter().map(|x| x * x).sum::<f32>().sqrt()
    }
}

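// The update applied by `step` follows the LAMB rule from You et al. (2019).
// For each parameter tensor `w` with gradient `g` at step `t`:
//
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat = m_t / (1 - beta1^t)        v_hat = v_t / (1 - beta2^t)
//   u = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
//   w <- w - lr * (||w|| / ||u||) * u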
impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Update the biased first moment estimate: m = beta1*m + (1-beta1)*g.
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update the biased second moment estimate: v = beta2*v + (1-beta2)*g^2.
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Bias-correction factors 1 - beta^t (or 1.0 when correction is disabled).
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            let eps = self.eps;
            let wd = self.weight_decay;
            let has_wd = wd > 0.0;
            let n = param_vec.len();

            // Build the Adam-style update direction, reusing the gradient buffer,
            // and accumulate the norms needed for the trust ratio.
            let mut update = grad_vec;
            let mut weight_norm_sq: f32 = 0.0;
            let mut update_norm_sq: f32 = 0.0;

            for i in 0..n {
                let m_hat = state.exp_avg[i] / bias_correction1;
                let v_hat = state.exp_avg_sq[i] / bias_correction2;
                let mut u = m_hat / (v_hat.sqrt() + eps);
                if has_wd {
                    // Weight decay is folded into the update direction,
                    // as in the LAMB paper.
                    u += wd * param_vec[i];
                }
                update[i] = u;
                weight_norm_sq += param_vec[i] * param_vec[i];
                update_norm_sq += u * u;
            }

            let weight_norm = weight_norm_sq.sqrt();
            let update_norm = update_norm_sq.sqrt();

            // Layer-wise trust ratio ||w|| / ||u||; fall back to 1.0 when either
            // norm vanishes (e.g. zero-initialized weights).
            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

            let effective_lr = self.lr * trust_ratio;
            for i in 0..n {
                param_vec[i] -= effective_lr * update[i];
            }

            // Rebuild the tensor on the host, then move it back to the original
            // device if the parameter lives on the GPU.
            let mut new_tensor = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            let device = param_data.device();
            if device.is_gpu() {
                new_tensor = new_tensor.to_device(device).unwrap();
            }
            param.update_data(new_tensor);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        // The parameter must have moved away from its initial value.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Both elements must have moved; the trust ratio scales the step.
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        optimizer.zero_grad();
    }

    #[test]
    fn test_l2_norm() {
        let vec = vec![3.0, 4.0];
        let norm = LAMB::l2_norm(&vec);
        assert!((norm - 5.0).abs() < 1e-6);
    }
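
    // A sketch check of the trust-ratio scaling, assuming the default
    // hyper-parameters: on the first bias-corrected step the update direction
    // is approximately sign(g), so with a uniform gradient each element moves
    // by about lr * ||w|| / sqrt(n).
    #[test]
    fn test_lamb_first_step_magnitude() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        // ||w|| = 5 and n = 2, so each element should move by roughly
        // 0.1 * 5 / sqrt(2) ~= 0.3536.
        let expected = 0.1 * 5.0 / 2.0_f32.sqrt();
        let new_data = param.data().to_vec();
        assert!((3.0 - new_data[0] - expected).abs() < 1e-3);
        assert!((4.0 - new_data[1] - expected).abs() < 1e-3);
    }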
}