//! LAMB Optimizer - Layer-wise Adaptive Moments
//!
//! # File
//! `crates/axonml-optim/src/lamb.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_core;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// LAMB State
// =============================================================================

/// Per-parameter state for the LAMB optimizer.
///
/// Stores momentum tensors on the same device as parameters (CPU or GPU).
/// When parameters are on GPU, all state stays GPU-resident — zero CPU round-trips.
#[derive(Debug, Clone)]
struct LambState {
    /// First moment (exponential moving average of gradient) — on same device as param.
    exp_avg: Tensor<f32>,
    /// Second moment (exponential moving average of squared gradient) — on same device as param.
    exp_avg_sq: Tensor<f32>,
    /// Step count for bias correction
    step: usize,
}

impl LambState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

// =============================================================================
// LAMB Optimizer
// =============================================================================

/// LAMB optimizer for large batch training.
///
/// LAMB extends Adam by adding a layer-wise trust ratio that scales
/// the update based on the ratio of parameter norm to update norm.
/// This enables stable training with very large batch sizes.
///
/// The update rule is:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// r = m_hat / (sqrt(v_hat) + eps) + weight_decay * param
/// trust_ratio = ||param|| / ||r||  (layer-wise)
/// param = param - lr * trust_ratio * r
/// ```
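///
/// # Example
///
/// A minimal usage sketch, mirroring the API exercised in the tests at the
/// bottom of this file (crate re-export paths are assumed and may need
/// adjusting for your project layout):
///
/// ```rust,ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{Optimizer, LAMB};
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
///     true,
/// );
/// let param = Parameter::from_variable(var);
///
/// // Decoupled weight decay on top of the default betas/eps.
/// let mut optimizer = LAMB::new(vec![param.clone()], 1e-3).weight_decay(0.01);
///
/// // Inside the training loop, once gradients have been populated:
/// optimizer.step();
/// optimizer.zero_grad();
/// ```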
pub struct LAMB {
    /// Parameters to optimize
    params: Vec<Parameter>,
    /// Learning rate
    lr: f32,
    /// First moment decay rate
    beta1: f32,
    /// Second moment decay rate
    beta2: f32,
    /// Small constant for numerical stability
    eps: f32,
    /// Weight decay coefficient (decoupled)
    weight_decay: f32,
    /// Whether to use bias correction
    bias_correction: bool,
    /// Per-parameter state
    state: Vec<LambState>,
}

impl LAMB {
    /// Creates a new LAMB optimizer with default hyperparameters.
    ///
    /// Defaults:
    /// - betas: (0.9, 0.999)
    /// - eps: 1e-6
    /// - weight_decay: 0.0
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Builder: set betas (momentum decay rates)
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Builder: set epsilon
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder: set weight decay
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder: set bias correction
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

    /// Lazily creates the per-parameter moment buffers on the first call to `step()`.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    LambState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

        // ============================================================
        // Tensor-op path: works on both CPU and GPU without to_vec()
        // All ops (add, mul, mul_scalar, div, sqrt, add_scalar, sub)
        // dispatch to CUDA when the tensors are GPU-resident.
        // ============================================================

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // Update biased first moment: m = beta1 * m + (1 - beta1) * grad
            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta1)
                .add(&grad.mul_scalar(1.0 - self.beta1))
                .unwrap();

            // Update biased second moment: v = beta2 * v + (1 - beta2) * grad^2
            let grad_sq = grad.mul(&grad).unwrap();
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta2)
                .add(&grad_sq.mul_scalar(1.0 - self.beta2))
                .unwrap();

            // Compute bias-corrected moments
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            // m_hat = m / bc1, v_hat = v / bc2
            let m_hat = state.exp_avg.mul_scalar(1.0 / bias_correction1);
            let v_hat = state.exp_avg_sq.mul_scalar(1.0 / bias_correction2);

            // adam_update = m_hat / (sqrt(v_hat) + eps)
            let adam_update = m_hat.div(&v_hat.sqrt().add_scalar(self.eps)).unwrap();

            // update = adam_update + weight_decay * param (decoupled weight decay)
            let update = if self.weight_decay > 0.0 {
                adam_update
                    .add(&param_data.mul_scalar(self.weight_decay))
                    .unwrap()
            } else {
                adam_update
            };

            // Compute trust ratio: ||param|| / ||update||
            // norm = sqrt(sum(x^2))  using Tensor ops
            let weight_norm_sq = param_data.mul(&param_data).unwrap().sum();
            let update_norm_sq = update.mul(&update).unwrap().sum();

            // Extract scalar norms (single element tensors)
            let weight_norm = weight_norm_sq.to_vec()[0].sqrt();
            let update_norm = update_norm_sq.to_vec()[0].sqrt();

            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

            // param = param - lr * trust_ratio * update
            let effective_lr = self.lr * trust_ratio;
            let new_param = param_data.sub(&update.mul_scalar(effective_lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        // Test that the trust-ratio-scaled update changes the parameters
        let var = Variable::new(
            Tensor::from_vec(vec![3.0, 4.0], &[2]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Weight norm = sqrt(9 + 16) = 5
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        // After one step, parameters should change based on trust ratio
        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Verify parameters changed
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        optimizer.zero_grad();
        // Grad might be zeroed or None depending on implementation
    }

    #[test]
    fn test_l2_norm_via_tensor() {
        let t = Tensor::from_vec(vec![3.0f32, 4.0], &[2]).expect("tensor creation failed");
        let norm_sq = t.mul(&t).unwrap().sum();
        let norm = norm_sq.to_vec()[0].sqrt();
        assert!((norm - 5.0).abs() < 1e-6);
    }
}