Skip to main content

rlx_optim/
lamb.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! LAMB — Layer-wise Adaptive Moments for Batch training (You et al.,
17//! 2019, "Large Batch Optimization for Deep Learning: Training BERT
18//! in 76 minutes").
19//!
20//! # Idea
21//!
22//! Naïve large-batch training stalls because the per-coordinate Adam
23//! step doesn't account for the magnitude difference between
24//! different layers' weights. LAMB rescales each tensor's Adam-style
25//! update by the **trust ratio** `‖θ‖ / ‖u‖`, so that the per-step
26//! relative change `‖Δθ‖ / ‖θ‖` is bounded and identical across layers.
27//!
28//! # Update rule
29//!
30//! For each tensor (and its flat parameter vector θ):
31//!
32//! ```text
33//! m_t = β₁·m_{t-1} + (1 − β₁)·g_t
34//! v_t = β₂·v_{t-1} + (1 − β₂)·g_t²
35//! u_t = m̂_t / (√v̂_t + ε) + λ·θ_{t-1}        // raw update
36//! r_t = ‖θ_{t-1}‖₂ / ‖u_t‖₂                  // trust ratio
37//! θ_t = θ_{t-1} − lr · r_t · u_t
38//! ```
39//!
40//! `r_t` is clamped to `1.0` when either norm is zero (warm-up edge
41//! case). LAMB's headline result is that this rescaling makes very
42//! large batch sizes (32k–64k) viable without quality loss.
43//!
44//! # When to use
45//!
46//! Large-batch pre-training — BERT/ViT/Llama-scale data-parallel
47//! runs. State cost = Adam's two moment buffers + a cached update scratch buffer.
48
49use std::collections::HashMap;
50
51use crate::Optimizer;
52use crate::common::{l2_norm, zeros_entry};
53
/// LAMB optimizer: Adam-style moments with a per-tensor trust ratio.
///
/// Per-tensor state is three `f32` buffers keyed by tensor name: the
/// first moment, the second moment, and a cached scratch buffer that
/// holds the raw update `u_t` (its norm is the trust-ratio
/// denominator). All three are filled in by [`Optimizer::step`].
#[derive(Clone, Debug)]
pub struct Lamb {
    /// Base learning rate, multiplied by the trust ratio at each step.
    pub lr: f32,
    /// EMA decay β₁ for the first moment. `Lamb::new` sets `0.9`.
    pub beta1: f32,
    /// EMA decay β₂ for the second moment. `Lamb::new` sets `0.999`.
    pub beta2: f32,
    /// Stability constant ε added to `√v̂`. `Lamb::new` sets `1e-6`
    /// (looser than Adam's `1e-8` — matches NVIDIA's reference).
    pub eps: f32,
    /// Weight-decay coefficient λ; `λ·θ` is added to the raw update
    /// before the trust ratio is taken. `Lamb::new` sets `0.01`.
    pub weight_decay: f32,
    /// When `true`, the moments are divided by `1 − βᵗ` before use.
    /// `Lamb::new` sets `true` (matches NVIDIA's reference impl); the
    /// original paper omits it.
    pub bias_correction: bool,
    /// Completed iterations; advanced by `end_iteration`, not by `step`.
    step: u64,
    /// First-moment buffers, one per tensor name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment buffers, one per tensor name.
    v: HashMap<String, Vec<f32>>,
    /// Per-tensor buffer for the raw update `u_t`, kept between steps
    /// so it is not reallocated on every call.
    scratch: HashMap<String, Vec<f32>>,
}
81
82impl Lamb {
83    /// Construct with `(β₁, β₂, ε, λ) = (0.9, 0.999, 1e-6, 0.01)`.
84    pub fn new(lr: f32) -> Self {
85        Self {
86            lr,
87            beta1: 0.9,
88            beta2: 0.999,
89            eps: 1e-6,
90            weight_decay: 0.01,
91            bias_correction: true,
92            step: 0,
93            m: HashMap::new(),
94            v: HashMap::new(),
95            scratch: HashMap::new(),
96        }
97    }
98
99    /// Override (β₁, β₂).
100    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
101        self.beta1 = b1;
102        self.beta2 = b2;
103        self
104    }
105
106    /// Override the decoupled-decay coefficient.
107    pub fn with_weight_decay(mut self, wd: f32) -> Self {
108        self.weight_decay = wd;
109        self
110    }
111}
112
113impl Optimizer for Lamb {
114    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
115        debug_assert_eq!(param.len(), grad.len());
116        let t = (self.step + 1) as f64;
117        let b1 = self.beta1 as f64;
118        let b2 = self.beta2 as f64;
119        let (bc1, bc2) = if self.bias_correction {
120            (1.0 - b1.powf(t), 1.0 - b2.powf(t))
121        } else {
122            (1.0, 1.0)
123        };
124        let eps = self.eps as f64;
125        let lr = self.lr;
126        let wd = self.weight_decay as f64;
127        let m = zeros_entry(&mut self.m, name, param.len());
128        let v = zeros_entry(&mut self.v, name, param.len());
129        let update = zeros_entry(&mut self.scratch, name, param.len());
130        // First pass: update m/v, build `r_i = m_hat / (sqrt(v_hat) + eps) + wd * w`.
131        for i in 0..param.len() {
132            let g = grad[i] as f64;
133            let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
134            let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
135            m[i] = mi as f32;
136            v[i] = vi as f32;
137            let m_hat = mi / bc1;
138            let v_hat = vi / bc2;
139            update[i] = (m_hat / (v_hat.sqrt() + eps) + wd * param[i] as f64) as f32;
140        }
141        let w_norm = l2_norm(param);
142        let r_norm = l2_norm(update);
143        let trust = if w_norm > 0.0 && r_norm > 0.0 {
144            w_norm / r_norm
145        } else {
146            1.0
147        };
148        let step_size = lr * trust;
149        for i in 0..param.len() {
150            param[i] -= step_size * update[i];
151        }
152    }
153
154    fn end_iteration(&mut self) {
155        self.step += 1;
156    }
157}