rlx_optim/lamb.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! LAMB — Layer-wise Adaptive Moments for Batch training (You et al.,
17//! 2019, "Large Batch Optimization for Deep Learning: Training BERT
18//! in 76 minutes").
19//!
20//! # Idea
21//!
22//! Naïve large-batch training stalls because the per-coordinate Adam
23//! step doesn't account for the magnitude difference between
24//! different layers' weights. LAMB rescales each tensor's Adam-style
25//! update by the **trust ratio** `‖θ‖ / ‖u‖`, so that the per-step
26//! relative change `‖Δθ‖ / ‖θ‖` is bounded and identical across layers.
27//!
28//! # Update rule
29//!
30//! For each tensor (and its flat parameter vector θ):
31//!
32//! ```text
33//! m_t = β₁·m_{t-1} + (1 − β₁)·g_t
34//! v_t = β₂·v_{t-1} + (1 − β₂)·g_t²
35//! u_t = m̂_t / (√v̂_t + ε) + λ·θ_{t-1} // raw update
36//! r_t = ‖θ_{t-1}‖₂ / ‖u_t‖₂ // trust ratio
37//! θ_t = θ_{t-1} − lr · r_t · u_t
38//! ```
39//!
//! `r_t` falls back to `1.0` when either norm is zero (warm-up edge
//! case). LAMB's headline result is that this rescaling makes very
//! large batch sizes (32k–64k) viable without quality loss.
43//!
44//! # When to use
45//!
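//! Large-batch pre-training — BERT/ViT/Llama-scale data-parallel
//! runs. Persistent state cost = Adam (two moment buffers per tensor);
//! this implementation additionally caches one scratch buffer per
//! tensor for the raw update.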
48
49use std::collections::HashMap;
50
51use crate::Optimizer;
52use crate::common::{l2_norm, zeros_entry};
53
/// Layer-wise Adaptive Moments for Batch training.
///
/// Per-tensor state: the first- and second-moment `f32` buffers (`m`,
/// `v`) plus a cached scratch buffer holding the raw update `u_t`. The
/// norm of that scratch buffer is the denominator of the trust ratio
/// `‖θ‖ / ‖u‖` computed in [`Optimizer::step`].
#[derive(Debug, Clone)]
pub struct Lamb {
    /// Learning rate.
    pub lr: f32,
    /// First-moment EMA decay β₁. Default `0.9`.
    pub beta1: f32,
    /// Second-moment EMA decay β₂. Default `0.999`.
    pub beta2: f32,
    /// Denominator stability constant ε. Default `1e-6` (looser than
    /// Adam's `1e-8` — matches NVIDIA's reference).
    pub eps: f32,
    /// Decoupled weight-decay coefficient λ. Default `0.01`.
    pub weight_decay: f32,
    /// If `true`, divide each moment by its bias-correction factor
    /// `1 − βᵗ`. Defaults to `true` (matches NVIDIA's reference impl);
    /// the original paper omits it.
    pub bias_correction: bool,
    /// Number of completed iterations; starts at `0`.
    step: u64,
    /// First-moment (EMA of gradients) buffers, one per tensor.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment (EMA of squared gradients) buffers, one per tensor.
    v: HashMap<String, Vec<f32>>,
    /// Reusable per-tensor scratch buffer for the raw update `u_t`
    /// (the trust-ratio denominator is its norm). Cached so we don't
    /// allocate every step.
    scratch: HashMap<String, Vec<f32>>,
}

impl Lamb {
    /// Construct with the given learning rate and
    /// `(β₁, β₂, ε, λ) = (0.9, 0.999, 1e-6, 0.01)`.
    ///
    /// Bias correction is enabled, the step counter is `0`, and no
    /// per-tensor state exists yet.
    #[must_use]
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-6,
            weight_decay: 0.01,
            bias_correction: true,
            step: 0,
            m: HashMap::new(),
            v: HashMap::new(),
            scratch: HashMap::new(),
        }
    }

    /// Override (β₁, β₂).
    #[must_use]
    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
        self.beta1 = b1;
        self.beta2 = b2;
        self
    }

    /// Override the decoupled-decay coefficient λ.
    #[must_use]
    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Override the denominator stability constant ε.
    #[must_use]
    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Enable or disable bias correction of the moment estimates.
    #[must_use]
    pub fn with_bias_correction(mut self, enabled: bool) -> Self {
        self.bias_correction = enabled;
        self
    }
}
112
113impl Optimizer for Lamb {
114 fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
115 debug_assert_eq!(param.len(), grad.len());
116 let t = (self.step + 1) as f64;
117 let b1 = self.beta1 as f64;
118 let b2 = self.beta2 as f64;
119 let (bc1, bc2) = if self.bias_correction {
120 (1.0 - b1.powf(t), 1.0 - b2.powf(t))
121 } else {
122 (1.0, 1.0)
123 };
124 let eps = self.eps as f64;
125 let lr = self.lr;
126 let wd = self.weight_decay as f64;
127 let m = zeros_entry(&mut self.m, name, param.len());
128 let v = zeros_entry(&mut self.v, name, param.len());
129 let update = zeros_entry(&mut self.scratch, name, param.len());
130 // First pass: update m/v, build `r_i = m_hat / (sqrt(v_hat) + eps) + wd * w`.
131 for i in 0..param.len() {
132 let g = grad[i] as f64;
133 let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
134 let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
135 m[i] = mi as f32;
136 v[i] = vi as f32;
137 let m_hat = mi / bc1;
138 let v_hat = vi / bc2;
139 update[i] = (m_hat / (v_hat.sqrt() + eps) + wd * param[i] as f64) as f32;
140 }
141 let w_norm = l2_norm(param);
142 let r_norm = l2_norm(update);
143 let trust = if w_norm > 0.0 && r_norm > 0.0 {
144 w_norm / r_norm
145 } else {
146 1.0
147 };
148 let step_size = lr * trust;
149 for i in 0..param.len() {
150 param[i] -= step_size * update[i];
151 }
152 }
153
154 fn end_iteration(&mut self) {
155 self.step += 1;
156 }
157}