Skip to main content

rlx_optim/
qhadamw.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! QHAdamW — Quasi-Hyperbolic Adam (Ma & Yarats, 2019) with decoupled
17//! weight decay.
18//!
19//! # Idea
20//!
21//! Adam can be viewed as "the second moment EMA scales the gradient,
22//! the first moment EMA *replaces* the gradient." The quasi-hyperbolic
23//! family says: don't *replace* — *interpolate*. Mix the EMA with the
24//! raw current gradient, controlled by per-moment scalars `ν₁, ν₂`.
25//!
26//! # Update rule
27//!
//! ```text
//! m_t = β₁·m_{t-1} + (1 − β₁)·g_t
//! v_t = β₂·v_{t-1} + (1 − β₂)·g_t²
//! m̂_t = m_t / (1 − β₁ᵗ)
//! v̂_t = v_t / (1 − β₂ᵗ)
//! num = (1 − ν₁)·g_t   + ν₁·m̂_t
//! den = √((1 − ν₂)·g_t² + ν₂·v̂_t) + ε
//! θ_t = θ_{t-1} − lr · ( num / den + λ·θ_{t-1} )
//! ```
35//!
36//! Setting `ν₁ = ν₂ = 1` recovers standard AdamW; `ν₁ = β₁` recovers
37//! Nesterov-style behavior in the limit; `ν₁ < 1` injects more of the
38//! current gradient and tends to be more robust on noisy losses.
39//!
40//! # When to use
41//!
42//! When you've found AdamW too sluggish on noisy / heavy-tail
43//! gradients (RL, GAN training) and an LR sweep didn't help.
44
45use std::collections::HashMap;
46
47use crate::Optimizer;
48use crate::common::zeros_entry;
49
/// Quasi-hyperbolic AdamW optimizer.
///
/// Holds the hyper-parameters plus, for every named tensor, two `f32`
/// buffers (first- and second-moment EMAs).
#[derive(Debug, Clone)]
pub struct QHAdamW {
    /// Learning rate.
    pub lr: f32,
    /// Decay β₁ of the first-moment EMA. Ma & Yarats recommend `0.995`
    /// (much closer to 1 than vanilla Adam) because the QH interpolation
    /// already keeps weight on the current gradient in the numerator.
    pub beta1: f32,
    /// Decay β₂ of the second-moment EMA. Default `0.999`.
    pub beta2: f32,
    /// Interpolation coefficient ν₁ ∈ \[0, 1\] for the numerator.
    /// `1.0` uses only the EMA (standard Adam first moment); `0.0` uses
    /// only the current gradient (no momentum). Default `0.7`.
    pub nu1: f32,
    /// Interpolation coefficient ν₂ for the denominator. `1.0` gives the
    /// standard Adam denominator. Default `1.0`.
    pub nu2: f32,
    /// Constant added to the denominator for stability. Default `1e-8`.
    pub eps: f32,
    /// Decoupled weight-decay coefficient λ. Default `0.01`.
    pub weight_decay: f32,
    /// Iteration counter; starts at zero.
    step: u64,
    /// First-moment buffers, keyed by tensor name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment buffers, keyed by tensor name.
    v: HashMap<String, Vec<f32>>,
}

impl QHAdamW {
    // Defaults applied by `new`.
    const DEFAULT_BETA1: f32 = 0.995;
    const DEFAULT_BETA2: f32 = 0.999;
    const DEFAULT_NU1: f32 = 0.7;
    const DEFAULT_NU2: f32 = 1.0;
    const DEFAULT_EPS: f32 = 1e-8;
    const DEFAULT_WEIGHT_DECAY: f32 = 0.01;

    /// Create an optimizer with the given learning rate and
    /// `(β₁, β₂, ν₁, ν₂, ε, λ) = (0.995, 0.999, 0.7, 1.0, 1e-8, 0.01)`.
    ///
    /// The step counter starts at zero and no per-tensor state exists yet.
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: Self::DEFAULT_BETA1,
            beta2: Self::DEFAULT_BETA2,
            nu1: Self::DEFAULT_NU1,
            nu2: Self::DEFAULT_NU2,
            eps: Self::DEFAULT_EPS,
            weight_decay: Self::DEFAULT_WEIGHT_DECAY,
            step: 0,
            m: HashMap::new(),
            v: HashMap::new(),
        }
    }

    /// Replace the EMA decays (β₁, β₂); every other field is kept.
    pub fn with_betas(self, b1: f32, b2: f32) -> Self {
        Self {
            beta1: b1,
            beta2: b2,
            ..self
        }
    }

    /// Replace the quasi-hyperbolic coefficients (ν₁, ν₂); every other
    /// field is kept.
    pub fn with_nus(self, n1: f32, n2: f32) -> Self {
        Self {
            nu1: n1,
            nu2: n2,
            ..self
        }
    }

    /// Replace the decoupled weight-decay coefficient λ; every other field
    /// is kept.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }
}
114
115impl Optimizer for QHAdamW {
116    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
117        debug_assert_eq!(param.len(), grad.len());
118        let t = (self.step + 1) as f64;
119        let b1 = self.beta1 as f64;
120        let b2 = self.beta2 as f64;
121        let bc1 = 1.0 - b1.powf(t);
122        let bc2 = 1.0 - b2.powf(t);
123        let n1 = self.nu1 as f64;
124        let n2 = self.nu2 as f64;
125        let eps = self.eps as f64;
126        let lr = self.lr as f64;
127        let wd = self.weight_decay as f64;
128        let m = zeros_entry(&mut self.m, name, param.len());
129        let v = zeros_entry(&mut self.v, name, param.len());
130        for i in 0..param.len() {
131            let g = grad[i] as f64;
132            let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
133            let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
134            m[i] = mi as f32;
135            v[i] = vi as f32;
136            let m_hat = mi / bc1;
137            let v_hat = vi / bc2;
138            // Quasi-hyperbolic numerator & denominator (Ma & Yarats Alg. 2).
139            let num = (1.0 - n1) * g + n1 * m_hat;
140            let den = ((1.0 - n2) * g * g + n2 * v_hat).sqrt() + eps;
141            let p = param[i] as f64;
142            param[i] = (p - lr * (num / den + wd * p)) as f32;
143        }
144    }
145
146    fn end_iteration(&mut self) {
147        self.step += 1;
148    }
149}