
//! `LAMB` — Layer-wise Adaptive Moments for large-batch training.
//!
//! Extends Adam with a per-layer trust ratio: `phi(||w||) / ||adam_update||`
//! scales the step size per parameter tensor, enabling stable training at
//! batch sizes of 32k+ without warmup hacks. Config mirrors Adam
//! (beta1/beta2/epsilon/weight_decay) plus a bias-correction toggle.
//!
//! # File
//! `crates/axonml-optim/src/lamb.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
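//!
//! # Example
//!
//! A minimal usage sketch assembled from the constructors, builder methods, and
//! tests in this file. The import paths and the backward pass are illustrative
//! assumptions, so the snippet is marked `ignore` rather than compiled as a doctest.
//!
//! ```ignore
//! // Wrap a tensor in a trainable parameter (construction mirrors the tests below;
//! // re-export paths are assumed).
//! let var = Variable::new(
//!     Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
//!     true,
//! );
//! let param = Parameter::from_variable(var);
//!
//! // LAMB with defaults (betas = (0.9, 0.999), eps = 1e-6, no weight decay),
//! // tuned through the builder methods.
//! let mut optimizer = LAMB::new(vec![param.clone()], 0.01).weight_decay(0.01);
//!
//! // Per iteration: clear old gradients, run backprop (not shown), then step.
//! optimizer.zero_grad();
//! // ... forward + backward fills `param.grad()` ...
//! optimizer.step();
//! ```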

use axonml_core;
use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// LAMB State
// =============================================================================

/// Per-parameter state for LAMB optimizer.
///
/// Stores momentum tensors on the same device as parameters (CPU or GPU).
/// When parameters are on GPU, all state stays GPU-resident — zero CPU round-trips.
#[derive(Debug, Clone)]
struct LambState {
    /// First moment (exponential moving average of gradient) — on same device as param.
    exp_avg: Tensor<f32>,
    /// Second moment (exponential moving average of squared gradient) — on same device as param.
    exp_avg_sq: Tensor<f32>,
    /// Step count for bias correction.
    step: usize,
}

impl LambState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

// =============================================================================
// LAMB Optimizer
// =============================================================================

/// LAMB optimizer for large batch training.
///
/// LAMB extends Adam by adding a layer-wise trust ratio that scales
/// the update based on the ratio of parameter norm to update norm.
/// This enables stable training with very large batch sizes.
///
/// The update rule is:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// r = m_hat / (sqrt(v_hat) + eps) + weight_decay * param
/// trust_ratio = ||param|| / ||r||  (layer-wise)
/// param = param - lr * trust_ratio * r
/// ```
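///
/// For intuition: if a layer's weights have norm `||param|| = 5.0` and the
/// combined update has norm `||r|| = 0.5`, the trust ratio is `10.0` and the
/// effective step is `10 * lr * r`; layers whose updates are already large
/// relative to their weights are correspondingly damped. (Illustrative numbers,
/// not taken from a real run.)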
pub struct LAMB {
    /// Parameters to optimize
    params: Vec<Parameter>,
    /// Learning rate
    lr: f32,
    /// First moment decay rate
    beta1: f32,
    /// Second moment decay rate
    beta2: f32,
    /// Small constant for numerical stability
    eps: f32,
    /// Weight decay coefficient (decoupled)
    weight_decay: f32,
    /// Whether to use bias correction
    bias_correction: bool,
    /// Per-parameter state
    state: Vec<LambState>,
}

impl LAMB {
    /// Creates a new LAMB optimizer with default hyperparameters.
    ///
    /// Defaults:
    /// - betas: (0.9, 0.999)
    /// - eps: 1e-6
    /// - weight_decay: 0.0
    /// - bias_correction: true
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Builder: set betas (momentum decay rates).
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Builder: set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder: set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder: set bias correction.
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    LambState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

        // ============================================================
        // Tensor-op path: works on both CPU and GPU. All element-wise
        // ops (add, mul, mul_scalar, div, sqrt, add_scalar, sub)
        // dispatch to CUDA when the tensors are GPU-resident; the only
        // host transfer is the two scalar norms for the trust ratio.
        // ============================================================

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // Update biased first moment: m = beta1 * m + (1 - beta1) * grad
            state.exp_avg = state
                .exp_avg
                .mul_scalar(self.beta1)
                .add(&grad.mul_scalar(1.0 - self.beta1))
                .unwrap();

            // Update biased second moment: v = beta2 * v + (1 - beta2) * grad^2
            let grad_sq = grad.mul(&grad).unwrap();
            state.exp_avg_sq = state
                .exp_avg_sq
                .mul_scalar(self.beta2)
                .add(&grad_sq.mul_scalar(1.0 - self.beta2))
                .unwrap();

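            // The moment EMAs start at zero, so early estimates are biased toward
            // zero; the standard Adam terms (1 - beta^t) compensate for that bias.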
            // Compute bias-corrected moments
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            // m_hat = m / bc1, v_hat = v / bc2
            let m_hat = state.exp_avg.mul_scalar(1.0 / bias_correction1);
            let v_hat = state.exp_avg_sq.mul_scalar(1.0 / bias_correction2);

            // adam_update = m_hat / (sqrt(v_hat) + eps)
            let adam_update = m_hat.div(&v_hat.sqrt().add_scalar(self.eps)).unwrap();

            // update = adam_update + weight_decay * param (decoupled weight decay)
            let update = if self.weight_decay > 0.0 {
                adam_update
                    .add(&param_data.mul_scalar(self.weight_decay))
                    .unwrap()
            } else {
                adam_update
            };

            // Compute trust ratio: ||param|| / ||update||
            // norm = sqrt(sum(x^2))  using Tensor ops
            let weight_norm_sq = param_data.mul(&param_data).unwrap().sum();
            let update_norm_sq = update.mul(&update).unwrap().sum();

            // Extract scalar norms (single element tensors)
            let weight_norm = weight_norm_sq.to_vec()[0].sqrt();
            let update_norm = update_norm_sq.to_vec()[0].sqrt();

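            // If either norm is zero (e.g. freshly zeroed weights or a vanishing
            // update), fall back to a neutral trust ratio of 1.0 so the step
            // degrades to plain Adam scaling instead of dividing by zero.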
            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

            // param = param - lr * trust_ratio * update
            let effective_lr = self.lr * trust_ratio;
            let new_param = param_data.sub(&update.mul_scalar(effective_lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        // Exercise the trust-ratio path and verify it produces a usable update.
        let var = Variable::new(
            Tensor::from_vec(vec![3.0, 4.0], &[2]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Weight norm = sqrt(9 + 16) = 5
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        // After one step, parameters should change based on trust ratio
        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Verify parameters changed
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        optimizer.zero_grad();
        // Grad might be zeroed or None depending on implementation
    }

    #[test]
    fn test_l2_norm_via_tensor() {
        let t = Tensor::from_vec(vec![3.0f32, 4.0], &[2]).expect("tensor creation failed");
        let norm_sq = t.mul(&t).unwrap().sum();
        let norm = norm_sq.to_vec()[0].sqrt();
        assert!((norm - 5.0).abs() < 1e-6);
    }
}