// rlx_optim/adamw.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! AdamW — Adam with decoupled weight decay (Loshchilov & Hutter, 2017).
//!
//! # Why "decoupled"?
//!
//! In classical Adam, an L2 penalty `λ·θ` is folded into the gradient
//! before the EMAs, so the decay term is divided by `√v̂_t` together with
//! the gradient. As a result, parameters with large gradients get *less*
//! L2 regularization. AdamW separates the two: decay multiplies the
//! parameter directly, *outside* the adaptive step.
//!
//! # Update rule
//!
//! ```text
//! m_t = β₁·m_{t-1} + (1 − β₁)·g_t            // g_t is the raw grad
//! v_t = β₂·v_{t-1} + (1 − β₂)·g_t²
//! m̂_t = m_t / (1 − β₁ᵗ)
//! v̂_t = v_t / (1 − β₂ᵗ)
//! θ_t = θ_{t-1} − lr · ( m̂_t/(√v̂_t + ε) + λ·θ_{t-1} )
//! ```
//!
//! Note that the decay term is scaled by `lr`. Each step therefore shrinks
//! the parameter by a factor of `1 − lr·λ` in addition to the adaptive step.
//!
//! # When to use
//!
//! The default for transformer pre-training. Pair with a cosine or
//! linear LR schedule and `weight_decay = 0.1` for LLMs, or `0.01–0.05`
//! for vision transformers.

42use std::collections::HashMap;
43
44use crate::Optimizer;
45use crate::common::{zeros_entry, zip4_for_each};
46
/// Adam with decoupled weight decay.
///
/// Per-tensor state identical to [`crate::Adam`] (two `f32` buffers).
///
/// Build one with [`AdamW::new`] and adjust it with the consuming `with_*`
/// builder methods. Every hyper-parameter is also a public field.
#[derive(Debug, Clone)]
pub struct AdamW {
    /// Learning rate. Typical LLM pre-training value: `1e-4` to `3e-4`.
    pub lr: f32,
    /// First-moment EMA decay. Default `0.9`.
    pub beta1: f32,
    /// Second-moment EMA decay. Default `0.999` (matches Adam);
    /// `0.95` is common for very long pre-training runs.
    pub beta2: f32,
    /// Denominator stability constant. Default `1e-8`.
    pub eps: f32,
    /// **Decoupled** weight-decay coefficient λ. Applied to the parameter
    /// directly instead of being folded into the gradient. It is scaled by
    /// `lr`, so each step shrinks the parameter by a factor of `1 − lr·λ`.
    /// `0.01–0.1` typical. Defaults to `0.01`.
    pub weight_decay: f32,
    /// Completed-iteration counter used for bias correction; starts at 0.
    step: u64,
    /// First-moment (`m`) buffers, one per tensor name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment (`v`) buffers, one per tensor name.
    v: HashMap<String, Vec<f32>>,
}

impl AdamW {
    /// Construct with the given learning rate and the standard
    /// (β₁, β₂, ε, λ) = (0.9, 0.999, 1e-8, 0.01) defaults.
    ///
    /// The iteration counter starts at zero and both moment-buffer maps
    /// start empty.
    #[must_use]
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
            step: 0,
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }

    /// Override (β₁, β₂).
    #[must_use = "builder methods consume the optimizer and return the updated one"]
    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
        self.beta1 = b1;
        self.beta2 = b2;
        self
    }

    /// Override the decoupled-decay coefficient λ.
    #[must_use = "builder methods consume the optimizer and return the updated one"]
    pub fn with_weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    /// Override the denominator ε.
    #[must_use = "builder methods consume the optimizer and return the updated one"]
    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }
}
105
106impl Optimizer for AdamW {
107    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
108        debug_assert_eq!(param.len(), grad.len());
109        let t = (self.step + 1) as f64;
110        let b1 = self.beta1 as f64;
111        let b2 = self.beta2 as f64;
112        let bc1 = 1.0 - b1.powf(t);
113        let bc2 = 1.0 - b2.powf(t);
114        let eps = self.eps as f64;
115        let lr = self.lr as f64;
116        let wd = self.weight_decay as f64;
117        let m = zeros_entry(&mut self.m, name, param.len());
118        let v = zeros_entry(&mut self.v, name, param.len());
119        zip4_for_each(param, m, v, grad, |p, mi, vi, gi| {
120            let g = gi as f64;
121            let new_m = b1 * *mi as f64 + (1.0 - b1) * g;
122            let new_v = b2 * *vi as f64 + (1.0 - b2) * g * g;
123            *mi = new_m as f32;
124            *vi = new_v as f32;
125            let m_hat = new_m / bc1;
126            let v_hat = new_v / bc2;
127            // Decoupled decay: applied to the parameter, then the
128            // adaptive step is subtracted.
129            let pf = *p as f64;
130            *p = (pf - lr * (m_hat / (v_hat.sqrt() + eps) + wd * pf)) as f32;
131        });
132    }
133
134    fn end_iteration(&mut self) {
135        self.step += 1;
136    }
137}