rlx_optim/adam.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Adam — Adaptive Moment Estimation (Kingma & Ba, 2014).
17//!
18//! # Update rule
19//!
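//! For each parameter, with `t` the 1-based iteration index and `lr`,
//! `β₁` (`beta1`), `β₂` (`beta2`), `ε` (`eps`) and `λ` (`weight_decay`)
//! the struct fields: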
21//!
22//! ```text
23//! g_t = ∇L(θ_{t-1}) + λ·θ_{t-1} // L2 decay folded in
24//! m_t = β₁·m_{t-1} + (1 − β₁)·g_t
25//! v_t = β₂·v_{t-1} + (1 − β₂)·g_t²
26//! m̂_t = m_t / (1 − β₁ᵗ) // bias correction
27//! v̂_t = v_t / (1 − β₂ᵗ)
28//! θ_t = θ_{t-1} − lr · m̂_t / (√v̂_t + ε)
29//! ```
30//!
31//! # When to use
32//!
//! A reliable default for transformer pre-training and most non-vision
//! workloads. Optimizer state costs **2×** the parameter memory: one `m`
//! and one `v` buffer per parameter tensor. If you need decoupled weight
//! decay (recommended for transformers), use [`crate::AdamW`] instead.
37
38use std::collections::HashMap;
39
40use crate::Optimizer;
41use crate::common::{zeros_entry, zip4_for_each};
42
/// Adam optimizer: per-element adaptive step sizes derived from
/// bias-corrected running averages of the gradient and of its square.
///
/// State kept per parameter tensor (keyed by the parameter name passed to
/// `step`): a first-moment buffer `m` and a second-moment buffer `v`, both
/// `f32` and both as long as the parameter itself.
#[derive(Clone, Debug)]
pub struct Adam {
    /// Step size (learning rate). Common choices: `1e-3` when training CNNs
    /// from scratch, `1e-4` when fine-tuning transformers.
    pub lr: f32,
    /// Decay rate β₁ of the first-moment average; expected in \[0, 1).
    /// [`Adam::new`] sets it to `0.9`.
    pub beta1: f32,
    /// Decay rate β₂ of the second-moment average; expected in \[0, 1).
    /// [`Adam::new`] sets it to `0.999`.
    pub beta2: f32,
    /// Small constant added to the denominator to avoid division by zero.
    /// [`Adam::new`] sets it to `1e-8`.
    pub eps: f32,
    /// Coefficient λ of classic L2 weight decay: `λ·θ` is added to the
    /// gradient before the moment updates. Use [`crate::AdamW`] for the
    /// decoupled variant. [`Adam::new`] sets it to `0.0`.
    pub weight_decay: f32,
    /// Number of completed iterations; the next update uses `t = step + 1`.
    step: u64,
    /// First-moment buffers, keyed by parameter name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment buffers, keyed by parameter name.
    v: HashMap<String, Vec<f32>>,
}
65}
66
67impl Adam {
68 /// Construct with the given learning rate and the standard
69 /// (β₁, β₂, ε) = (0.9, 0.999, 1e-8) defaults.
70 pub fn new(lr: f32) -> Self {
71 Self {
72 lr,
73 beta1: 0.9,
74 beta2: 0.999,
75 eps: 1e-8,
76 weight_decay: 0.0,
77 step: 0,
78 m: HashMap::new(),
79 v: HashMap::new(),
80 }
81 }
82
83 /// Override (β₁, β₂).
84 pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
85 self.beta1 = b1;
86 self.beta2 = b2;
87 self
88 }
89
90 /// Override the denominator stability constant ε.
91 pub fn with_eps(mut self, eps: f32) -> Self {
92 self.eps = eps;
93 self
94 }
95
96 /// Override the L2 weight-decay coefficient.
97 pub fn with_weight_decay(mut self, wd: f32) -> Self {
98 self.weight_decay = wd;
99 self
100 }
101
102 /// 1-based iteration counter. Starts at 1 (so the first call to
103 /// `step()` sees `t=1`), advances on [`Optimizer::end_iteration`].
104 pub fn current_step(&self) -> u64 {
105 self.step + 1
106 }
107}
108
109impl Optimizer for Adam {
110 fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
111 debug_assert_eq!(param.len(), grad.len());
112 let t = (self.step + 1) as f64;
113 let b1 = self.beta1 as f64;
114 let b2 = self.beta2 as f64;
115 let bc1 = 1.0 - b1.powf(t);
116 let bc2 = 1.0 - b2.powf(t);
117 let eps = self.eps as f64;
118 let lr = self.lr as f64;
119 let wd = self.weight_decay;
120 let n = param.len();
121 // `self.m` / `self.v` are distinct fields, so the two
122 // `zeros_entry` calls borrow disjoint regions of `self` and
123 // their results can coexist.
124 let m = zeros_entry(&mut self.m, name, n);
125 let v = zeros_entry(&mut self.v, name, n);
126 zip4_for_each(param, m, v, grad, |p, mi, vi, gi| {
127 let g = (gi + wd * *p) as f64;
128 let new_m = b1 * *mi as f64 + (1.0 - b1) * g;
129 let new_v = b2 * *vi as f64 + (1.0 - b2) * g * g;
130 *mi = new_m as f32;
131 *vi = new_v as f32;
132 let m_hat = new_m / bc1;
133 let v_hat = new_v / bc2;
134 *p -= (lr * m_hat / (v_hat.sqrt() + eps)) as f32;
135 });
136 }
137
138 fn end_iteration(&mut self) {
139 self.step += 1;
140 }
141}