Skip to main content

rlx_optim/
adam.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Adam — Adaptive Moment Estimation (Kingma & Ba, 2014).
//!
//! # Update rule
//!
//! For each parameter, with `t` the 1-based iteration index:
//!
//! ```text
//! g_t   = ∇L(θ_{t-1}) + λ·θ_{t-1}            // L2 decay folded in
//! m_t   = β₁·m_{t-1} + (1 − β₁)·g_t
//! v_t   = β₂·v_{t-1} + (1 − β₂)·g_t²
//! m̂_t   = m_t / (1 − β₁ᵗ)                    // bias correction
//! v̂_t   = v_t / (1 − β₂ᵗ)
//! θ_t   = θ_{t-1} − lr · m̂_t / (√v̂_t + ε)
//! ```
//!
//! # When to use
//!
//! Reliable default for transformer pre-training and most non-vision
//! workloads. Per-parameter memory is **2×** the parameter size (one
//! `m`, one `v`). If you need decoupled weight decay (recommended for
//! transformers) use [`crate::AdamW`] instead.
37
38use std::collections::HashMap;
39
40use crate::Optimizer;
41use crate::common::{zeros_entry, zip4_for_each};
42
/// Adam optimizer: bias-corrected first- and second-moment estimates.
///
/// Per-tensor state: two `f32` buffers (`m`, `v`) of the same shape as
/// the parameter, stored under the `name` passed to `step`.
///
/// Hyper-parameters are public fields and may be changed between
/// iterations; the iteration counter and moment buffers are private.
#[derive(Debug, Clone)]
pub struct Adam {
    /// Learning rate. Typical: `1e-3` for from-scratch CNNs, `1e-4`
    /// for transformer fine-tuning.
    pub lr: f32,
    /// First-moment EMA decay β₁ ∈ \[0, 1). Default `0.9`.
    pub beta1: f32,
    /// Second-moment EMA decay β₂ ∈ \[0, 1). Default `0.999`.
    pub beta2: f32,
    /// Stability constant added to the denominator. Default `1e-8`.
    pub eps: f32,
    /// L2 weight decay coefficient. **Folded into the gradient**
    /// (the "classic Adam" rule); use [`crate::AdamW`] for decoupled
    /// decay. Default `0.0`.
    pub weight_decay: f32,
    /// Number of completed iterations; the 1-based index exposed by
    /// [`Adam::current_step`] is `step + 1`.
    step: u64,
    /// First-moment buffers, one per parameter name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment buffers, one per parameter name.
    v: HashMap<String, Vec<f32>>,
}

impl Adam {
    /// Creates an optimizer with learning rate `lr` and the standard
    /// (β₁, β₂, ε) = (0.9, 0.999, 1e-8) defaults.
    ///
    /// Weight decay starts at `0.0`, the iteration counter at zero and
    /// both moment maps empty.
    #[must_use]
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            step: 0,
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }

    /// Returns `self` with (β₁, β₂) replaced by (`b1`, `b2`).
    ///
    /// The values are stored as given; no range check is performed, so
    /// the caller is responsible for keeping them inside \[0, 1).
    #[must_use = "builder methods return the updated optimizer"]
    pub fn with_betas(self, b1: f32, b2: f32) -> Self {
        Self {
            beta1: b1,
            beta2: b2,
            ..self
        }
    }

    /// Returns `self` with the denominator stability constant ε
    /// replaced by `eps`.
    #[must_use = "builder methods return the updated optimizer"]
    pub fn with_eps(self, eps: f32) -> Self {
        Self { eps, ..self }
    }

    /// Returns `self` with the L2 weight-decay coefficient replaced by
    /// `wd`.
    #[must_use = "builder methods return the updated optimizer"]
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }

    /// 1-based iteration counter. Starts at 1 (so the first call to
    /// `step()` sees `t=1`), advances on [`Optimizer::end_iteration`].
    #[must_use]
    pub fn current_step(&self) -> u64 {
        self.step + 1
    }
}
108
109impl Optimizer for Adam {
110    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
111        debug_assert_eq!(param.len(), grad.len());
112        let t = (self.step + 1) as f64;
113        let b1 = self.beta1 as f64;
114        let b2 = self.beta2 as f64;
115        let bc1 = 1.0 - b1.powf(t);
116        let bc2 = 1.0 - b2.powf(t);
117        let eps = self.eps as f64;
118        let lr = self.lr as f64;
119        let wd = self.weight_decay;
120        let n = param.len();
121        // `self.m` / `self.v` are distinct fields, so the two
122        // `zeros_entry` calls borrow disjoint regions of `self` and
123        // their results can coexist.
124        let m = zeros_entry(&mut self.m, name, n);
125        let v = zeros_entry(&mut self.v, name, n);
126        zip4_for_each(param, m, v, grad, |p, mi, vi, gi| {
127            let g = (gi + wd * *p) as f64;
128            let new_m = b1 * *mi as f64 + (1.0 - b1) * g;
129            let new_v = b2 * *vi as f64 + (1.0 - b2) * g * g;
130            *mi = new_m as f32;
131            *vi = new_v as f32;
132            let m_hat = new_m / bc1;
133            let v_hat = new_v / bc2;
134            *p -= (lr * m_hat / (v_hat.sqrt() + eps)) as f32;
135        });
136    }
137
138    fn end_iteration(&mut self) {
139        self.step += 1;
140    }
141}