Skip to main content

rlx_optim/
radam.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RAdam — Rectified Adam (Liu et al., 2019, "On the Variance of the
17//! Adaptive Learning Rate and Beyond").
18//!
19//! # Motivation
20//!
21//! Early in training, `v_t` is built from very few samples and its
22//! variance is huge — which makes Adam's effective learning rate
23//! noisy at the same iterations where stability matters most.
24//! Practitioners "fix" this with an LR warm-up; RAdam derives a
25//! *closed-form* warm-up from the variance of the inverse-square-root
26//! of `v̂_t`.
27//!
28//! # Update rule
29//!
//! Let `ρ_∞ = 2/(1−β₂) − 1` and let the length of the approximated simple
//! moving average (the "SMA length") be `ρ_t = ρ_∞ − 2t·β₂ᵗ / (1 − β₂ᵗ)`,
//! where `t` is the 1-based step count. Define the rectification term
32//!
33//! ```text
34//! r_t = √( ((ρ_t − 4)(ρ_t − 2)·ρ_∞) / ((ρ_∞ − 4)(ρ_∞ − 2)·ρ_t) )
35//! ```
36//!
37//! When `ρ_t > 4` the second moment is "stable enough" — use the
38//! corrected Adam step with `r_t` scaling:
39//!
40//! ```text
41//! θ_t = θ_{t-1} − lr · r_t · m̂_t / (√v̂_t + ε)
42//! ```
43//!
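//! Otherwise (`ρ_t ≤ 4`, early steps), fall back to SGD-with-momentum
//! using the bias-corrected first moment alone. The second moment `v_t`
//! is still accumulated during these steps, so it is ready once `ρ_t`
//! crosses the threshold: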
46//!
47//! ```text
48//! θ_t = θ_{t-1} − lr · m̂_t
49//! ```
50//!
51//! # When to use
52//!
53//! Drop-in replacement for Adam when you don't want to hand-tune a
54//! warm-up schedule. Same memory cost as Adam.
55
56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::zeros_entry;
60
/// Rectified Adam optimizer: hyperparameters plus per-tensor moment state.
///
/// For every parameter tensor (keyed by its name) the optimizer keeps two
/// `f32` buffers — the first- and second-moment EMAs. A single step counter
/// is shared by all tensors.
#[derive(Clone, Debug)]
pub struct RAdam {
    /// Step size applied to every update.
    pub lr: f32,
    /// Decay rate β₁ of the first-moment (gradient mean) EMA.
    /// [`RAdam::new`] sets it to `0.9`.
    pub beta1: f32,
    /// Decay rate β₂ of the second-moment (squared-gradient) EMA.
    /// [`RAdam::new`] sets it to `0.999`.
    pub beta2: f32,
    /// Small constant added to `√v̂_t` so the division stays finite.
    /// [`RAdam::new`] sets it to `1e-8`.
    pub eps: f32,
    /// L2 penalty coefficient. It is added to the raw gradient before the
    /// moment updates (coupled, classical-Adam style — *not* AdamW-style
    /// decoupled decay). [`RAdam::new`] sets it to `0.0`.
    pub weight_decay: f32,
    /// Number of completed iterations; advanced by `end_iteration`.
    step: u64,
    /// First-moment buffers, one per parameter name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment buffers, one per parameter name.
    v: HashMap<String, Vec<f32>>,
}
79
80impl RAdam {
81    /// Construct with `(β₁, β₂, ε, λ) = (0.9, 0.999, 1e-8, 0.0)`.
82    pub fn new(lr: f32) -> Self {
83        Self {
84            lr,
85            beta1: 0.9,
86            beta2: 0.999,
87            eps: 1e-8,
88            weight_decay: 0.0,
89            step: 0,
90            m: HashMap::new(),
91            v: HashMap::new(),
92        }
93    }
94
95    /// Override (β₁, β₂).
96    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
97        self.beta1 = b1;
98        self.beta2 = b2;
99        self
100    }
101
102    /// Override the L2 weight-decay coefficient.
103    pub fn with_weight_decay(mut self, wd: f32) -> Self {
104        self.weight_decay = wd;
105        self
106    }
107}
108
109impl Optimizer for RAdam {
110    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
111        debug_assert_eq!(param.len(), grad.len());
112        let t = (self.step + 1) as f64;
113        let b1 = self.beta1 as f64;
114        let b2 = self.beta2 as f64;
115        let bc1 = 1.0 - b1.powf(t);
116        let bc2 = 1.0 - b2.powf(t);
117        let rho_inf = 2.0 / (1.0 - b2) - 1.0;
118        let rho_t = rho_inf - 2.0 * t * b2.powf(t) / bc2;
119        let eps = self.eps as f64;
120        let lr = self.lr as f64;
121        let wd = self.weight_decay;
122        // Variance-rectification term `r_t` (Liu et al. eq. 14).
123        let r_t = if rho_t > 4.0 {
124            (((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
125                / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
126                .sqrt()
127        } else {
128            0.0
129        };
130        let m = zeros_entry(&mut self.m, name, param.len());
131        let v = zeros_entry(&mut self.v, name, param.len());
132        for i in 0..param.len() {
133            let g = (grad[i] + wd * param[i]) as f64;
134            let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
135            let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
136            m[i] = mi as f32;
137            v[i] = vi as f32;
138            let m_hat = mi / bc1;
139            let update = if rho_t > 4.0 {
140                let v_hat = (vi / bc2).sqrt();
141                r_t * m_hat / (v_hat + eps)
142            } else {
143                m_hat
144            };
145            param[i] -= (lr * update) as f32;
146        }
147    }
148
149    fn end_iteration(&mut self) {
150        self.step += 1;
151    }
152}