//! LAMB Optimizer - Layer-wise Adaptive Moments
//!
//! # File
//! `crates/axonml-optim/src/lamb.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_core;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// LAMB State
// =============================================================================

/// Per-parameter state for the LAMB optimizer.
///
/// Stores the momentum tensors on the same device as the parameters (CPU or GPU).
/// When parameters are on GPU, both moment tensors stay GPU-resident and never
/// round-trip through the CPU.
#[derive(Debug, Clone)]
struct LambState {
    /// First moment (exponential moving average of gradient) — on same device as param.
    exp_avg: Tensor<f32>,
    /// Second moment (exponential moving average of squared gradient) — on same device as param.
    exp_avg_sq: Tensor<f32>,
    /// Step count for bias correction.
    step: usize,
}

impl LambState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).unwrap();
            exp_avg_sq = exp_avg_sq.to_device(device).unwrap();
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

// =============================================================================
// LAMB Optimizer
// =============================================================================

/// LAMB optimizer for large batch training.
///
/// LAMB extends Adam by adding a layer-wise trust ratio that scales
/// the update based on the ratio of parameter norm to update norm.
/// This enables stable training with very large batch sizes.
///
/// The update rule is:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// r = m_hat / (sqrt(v_hat) + eps) + weight_decay * param
/// trust_ratio = ||param|| / ||r||  (layer-wise)
/// param = param - lr * trust_ratio * r
/// ```
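///
/// # Example
///
/// A minimal usage sketch, assuming `LAMB` and the [`Optimizer`] trait are
/// re-exported from the crate root; in practice the gradients are populated by
/// a backward pass before each `step()`:
///
/// ```ignore
/// use axonml_optim::{Optimizer, LAMB};
///
/// // `params` is a Vec<Parameter> collected from a model.
/// let mut optimizer = LAMB::new(params, 1e-3)
///     .betas(0.9, 0.999)
///     .weight_decay(0.01);
///
/// // ... forward + backward pass fills in gradients ...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```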
pub struct LAMB {
    /// Parameters to optimize
    params: Vec<Parameter>,
    /// Learning rate
    lr: f32,
    /// First moment decay rate
    beta1: f32,
    /// Second moment decay rate
    beta2: f32,
    /// Small constant for numerical stability
    eps: f32,
    /// Weight decay coefficient (decoupled)
    weight_decay: f32,
    /// Whether to use bias correction
    bias_correction: bool,
    /// Per-parameter state
    state: Vec<LambState>,
}

impl LAMB {
    /// Creates a new LAMB optimizer with default hyperparameters.
    ///
    /// Defaults:
    /// - betas: (0.9, 0.999)
    /// - eps: 1e-6
    /// - weight_decay: 0.0
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with all options.
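    ///
    /// A sketch of a typical call, given a `params: Vec<Parameter>`; the
    /// hyperparameter values are only illustrative:
    ///
    /// ```ignore
    /// let optimizer = LAMB::with_options(params, 1e-3, (0.9, 0.999), 1e-6, 0.01);
    /// ```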
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Builder: set betas (momentum decay rates)
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Builder: set epsilon
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder: set weight decay
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder: set bias correction
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

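    /// Lazily allocates per-parameter state on the first call. State is matched to
    /// parameters by position, so the parameter list should not change once
    /// `step()` has been called.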
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    LambState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for LAMB {
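    /// Applies one LAMB update to every parameter that currently has a gradient;
    /// parameters with `requires_grad()` disabled or with no gradient are skipped.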
    fn step(&mut self) {
        self.ensure_state_initialized();

        // ============================================================
        // Tensor-op path: works on both CPU and GPU without pulling
        // full tensors back to the host. All element-wise ops (add,
        // mul, mul_scalar, div, sqrt, add_scalar, sub) dispatch to
        // CUDA when the tensors are GPU-resident; only the two scalar
        // norms below are read back with to_vec().
        // ============================================================

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // Update biased first moment: m = beta1 * m + (1 - beta1) * grad
            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta1)
                .add(&grad.mul_scalar(1.0 - self.beta1))
                .unwrap();

            // Update biased second moment: v = beta2 * v + (1 - beta2) * grad^2
            let grad_sq = grad.mul(&grad).unwrap();
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta2)
                .add(&grad_sq.mul_scalar(1.0 - self.beta2))
                .unwrap();

            // Compute bias-corrected moments
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            // m_hat = m / bc1, v_hat = v / bc2
            let m_hat = state.exp_avg.mul_scalar(1.0 / bias_correction1);
            let v_hat = state.exp_avg_sq.mul_scalar(1.0 / bias_correction2);

            // adam_update = m_hat / (sqrt(v_hat) + eps)
            let adam_update = m_hat.div(&v_hat.sqrt().add_scalar(self.eps)).unwrap();

            // update = adam_update + weight_decay * param (decoupled weight decay)
            let update = if self.weight_decay > 0.0 {
                adam_update
                    .add(&param_data.mul_scalar(self.weight_decay))
                    .unwrap()
            } else {
                adam_update
            };

            // Compute trust ratio: ||param|| / ||update||
            // norm = sqrt(sum(x^2))  using Tensor ops
            let weight_norm_sq = param_data.mul(&param_data).unwrap().sum();
            let update_norm_sq = update.mul(&update).unwrap().sum();

            // Extract scalar norms (single element tensors)
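            // Note: `to_vec()` on these single-element results copies the value back
            // to the host; on GPU these two scalar reads are the only device-to-host
            // transfers in the loop.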
            let weight_norm = weight_norm_sq.to_vec()[0].sqrt();
            let update_norm = update_norm_sq.to_vec()[0].sqrt();

            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

            // param = param - lr * trust_ratio * update
            let effective_lr = self.lr * trust_ratio;
            let new_param = param_data.sub(&update.mul_scalar(effective_lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        // The trust-ratio-scaled step should change the parameters
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Weight norm = sqrt(9 + 16) = 5
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        // After one step, parameters should change based on trust ratio
        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Verify parameters changed
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

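    // Worked first-step example derived from the update rule documented on `LAMB`,
    // using the default hyperparameters: with grad = [1, 1] and param = [3, 4] the
    // bias-corrected Adam update is roughly [1, 1], the trust ratio is
    // ||param|| / ||update|| (about 5 / sqrt(2)), and each element moves by about
    // lr * 5 / sqrt(2) (the eps term cancels out of the scaled step up to rounding).
    #[test]
    fn test_lamb_first_step_magnitude() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let expected_delta = 0.1 * 5.0 / 2.0_f32.sqrt();
        let new_data = param.data().to_vec();
        assert!((new_data[0] - (3.0 - expected_delta)).abs() < 1e-3);
        assert!((new_data[1] - (4.0 - expected_delta)).abs() < 1e-3);
    }
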
    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        optimizer.zero_grad();
        // Grad might be zeroed or None depending on implementation
    }

    #[test]
    fn test_l2_norm_via_tensor() {
        let t = Tensor::from_vec(vec![3.0f32, 4.0], &[2]).unwrap();
        let norm_sq = t.mul(&t).unwrap().sum();
        let norm = norm_sq.to_vec()[0].sqrt();
        assert!((norm - 5.0).abs() < 1e-6);
    }
}