axonml_optim/lamb.rs

//! LAMB Optimizer - Layer-wise Adaptive Moments
//!
//! Implements the LAMB (Layer-wise Adaptive Moments optimizer for Batch training)
//! algorithm, which enables training with very large batch sizes while
//! maintaining accuracy.
//!
//! Reference: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes"
//! https://arxiv.org/abs/1904.00962
//!
//! # Example
//! ```rust,ignore
//! use axonml_optim::LAMB;
//!
//! let mut optimizer = LAMB::new(model.parameters(), 0.001)
//!     .weight_decay(0.01)
//!     .betas(0.9, 0.999);
//!
//! for epoch in 0..100 {
//!     optimizer.zero_grad();
//!     let loss = model.forward(&input).mse_loss(&target);
//!     loss.backward();
//!     optimizer.step();
//! }
//! ```
//!
//! @version 0.1.0

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// LAMB State
// =============================================================================

/// Per-parameter state for the LAMB optimizer.
#[derive(Debug, Clone)]
struct LambState {
    /// First moment (exponential moving average of gradient)
    exp_avg: Vec<f32>,
    /// Second moment (exponential moving average of squared gradient)
    exp_avg_sq: Vec<f32>,
    /// Step count for bias correction
    step: usize,
}

impl LambState {
    fn new(size: usize) -> Self {
        Self {
            exp_avg: vec![0.0; size],
            exp_avg_sq: vec![0.0; size],
            step: 0,
        }
    }
}

// =============================================================================
// LAMB Optimizer
// =============================================================================

/// LAMB optimizer for large batch training.
///
/// LAMB extends Adam by adding a layer-wise trust ratio that scales
/// the update based on the ratio of parameter norm to update norm.
/// This enables stable training with very large batch sizes.
///
/// The update rule is:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// r = m_hat / (sqrt(v_hat) + eps) + weight_decay * param
/// trust_ratio = ||param|| / ||r||  (layer-wise)
/// param = param - lr * trust_ratio * r
/// ```
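///
/// As a rough worked example (default hyperparameters, no weight decay, and
/// ignoring the tiny `eps` term): one step on a parameter `[3.0, 4.0]` with a
/// uniform gradient `[1.0, 1.0]` and `lr = 0.1` gives
/// ```text
/// m_1 = [0.1, 0.1]      v_1 = [0.001, 0.001]
/// m_hat = [1, 1]        v_hat = [1, 1]           (bias correction at t = 1)
/// r ~= [1, 1]
/// ||param|| = 5,  ||r|| ~= 1.414,  trust_ratio ~= 3.536
/// param <- param - 0.1 * 3.536 * r ~= [2.646, 3.646]
/// ```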
pub struct LAMB {
    /// Parameters to optimize
    params: Vec<Parameter>,
    /// Learning rate
    lr: f32,
    /// First moment decay rate
    beta1: f32,
    /// Second moment decay rate
    beta2: f32,
    /// Small constant for numerical stability
    eps: f32,
    /// Weight decay coefficient (decoupled)
    weight_decay: f32,
    /// Whether to use bias correction
    bias_correction: bool,
    /// Per-parameter state
    state: Vec<LambState>,
}

impl LAMB {
    /// Creates a new LAMB optimizer with default hyperparameters.
    ///
    /// Defaults:
    /// - betas: (0.9, 0.999)
    /// - eps: 1e-6
    /// - weight_decay: 0.0
    /// - bias_correction: true
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self {
            params,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-6,
            weight_decay: 0.0,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Creates LAMB with all options.
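    ///
    /// A minimal usage sketch (argument values are illustrative) showing the
    /// argument order `(params, lr, (beta1, beta2), eps, weight_decay)`:
    /// ```rust,ignore
    /// let optimizer = LAMB::with_options(model.parameters(), 1e-3, (0.9, 0.999), 1e-6, 0.01);
    /// ```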
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            bias_correction: true,
            state: Vec::new(),
        }
    }

    /// Builder: set betas (momentum decay rates)
    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    /// Builder: set epsilon
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder: set weight decay
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder: enable or disable bias correction of the moment estimates
    /// (enabled by default)
    #[must_use]
    pub fn bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }

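    /// Lazily creates per-parameter state on the first call to `step()`,
    /// sizing each entry to match the corresponding parameter's element count.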
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| LambState::new(p.numel()))
                .collect();
        }
    }

    /// Computes the L2 norm of a vector.
    fn l2_norm(vec: &[f32]) -> f32 {
        vec.iter().map(|x| x * x).sum::<f32>().sqrt()
    }
}

impl Optimizer for LAMB {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
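            // Skip parameters that are frozen or have no gradient this step.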
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let param_vec = param_data.to_vec();

            // Update biased first moment estimate
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second moment estimate
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Compute bias-corrected moments
            let (bias_correction1, bias_correction2) = if self.bias_correction {
                (
                    1.0 - self.beta1.powi(state.step as i32),
                    1.0 - self.beta2.powi(state.step as i32),
                )
            } else {
                (1.0, 1.0)
            };

            // Compute Adam update direction: m_hat / (sqrt(v_hat) + eps)
            let mut update: Vec<f32> = state
                .exp_avg
                .iter()
                .zip(state.exp_avg_sq.iter())
                .map(|(m, v)| {
                    let m_hat = m / bias_correction1;
                    let v_hat = v / bias_correction2;
                    m_hat / (v_hat.sqrt() + self.eps)
                })
                .collect();

            // Add decoupled weight decay
            if self.weight_decay > 0.0 {
                for (u, p) in update.iter_mut().zip(param_vec.iter()) {
                    *u += self.weight_decay * p;
                }
            }

            // Compute layer-wise trust ratio
            let weight_norm = Self::l2_norm(&param_vec);
            let update_norm = Self::l2_norm(&update);

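            // If either norm is zero (e.g. zero-initialized parameters or a zero
            // update), fall back to a trust ratio of 1.0, i.e. a plain Adam-style step.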
            let trust_ratio = if weight_norm > 0.0 && update_norm > 0.0 {
                weight_norm / update_norm
            } else {
                1.0
            };

            // Apply update with trust ratio
            let effective_lr = self.lr * trust_ratio;
            let new_data: Vec<f32> = param_vec
                .iter()
                .zip(update.iter())
                .map(|(p, u)| p - effective_lr * u)
                .collect();

            let new_tensor = Tensor::from_vec(new_data, param_data.shape()).unwrap();
            param.update_data(new_tensor);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_lamb_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = LAMB::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_with_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1).weight_decay(0.01);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_lamb_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = LAMB::new(vec![param], 0.001)
            .betas(0.95, 0.9999)
            .eps(1e-7)
            .weight_decay(0.01);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_lamb_trust_ratio() {
        // Test that trust ratio is computed correctly
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Weight norm = sqrt(9 + 16) = 5
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);

        // After one step, parameters should change based on trust ratio
        let old_data = param.data().to_vec();
        optimizer.step();
        let new_data = param.data().to_vec();

        // Verify parameters changed
        assert!((new_data[0] - old_data[0]).abs() > 1e-6);
        assert!((new_data[1] - old_data[1]).abs() > 1e-6);
    }
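
    // A numerical spot-check of the first step, mirroring the worked example in the
    // struct-level docs: with a uniform gradient, no weight decay, and the default
    // hyperparameters, the per-element change is lr * ||w|| / sqrt(2)
    // = 0.1 * 5 / sqrt(2) ~= 0.3536 (the 1e-6 epsilon shifts this only negligibly).
    #[test]
    fn test_lamb_first_step_matches_worked_example() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.1);
        optimizer.step();

        let expected_delta = 0.1 * 5.0 / 2.0_f32.sqrt();
        let new_data = param.data().to_vec();
        assert!((new_data[0] - (3.0 - expected_delta)).abs() < 1e-3);
        assert!((new_data[1] - (4.0 - expected_delta)).abs() < 1e-3);
    }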

    #[test]
    fn test_lamb_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = LAMB::new(vec![param.clone()], 0.001);
        assert!(param.grad().is_some());

        optimizer.zero_grad();
        // Grad might be zeroed or None depending on implementation
    }

    #[test]
    fn test_l2_norm() {
        let vec = vec![3.0, 4.0];
        let norm = LAMB::l2_norm(&vec);
        assert!((norm - 5.0).abs() < 1e-6);
    }
}