rlx_optim/radam.rs
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! RAdam — Rectified Adam (Liu et al., 2019, "On the Variance of the
//! Adaptive Learning Rate and Beyond").
//!
//! # Motivation
//!
//! Early in training, `v_t` is built from very few samples and its
//! variance is huge — which makes Adam's effective learning rate
//! noisy at the same iterations where stability matters most.
//! Practitioners "fix" this with an LR warm-up; RAdam derives a
//! *closed-form* warm-up from the variance of the inverse-square-root
//! of `v̂_t`.
//!
//! # Update rule
//!
//! Let `ρ_∞ = 2/(1−β₂) − 1` and the "SMA length"
//! `ρ_t = ρ_∞ − 2t·β₂ᵗ / (1 − β₂ᵗ)`. Define the rectification term
//!
//! ```text
//! r_t = √( ((ρ_t − 4)(ρ_t − 2)·ρ_∞) / ((ρ_∞ − 4)(ρ_∞ − 2)·ρ_t) )
//! ```
//!
//! When `ρ_t > 4` the second moment is "stable enough" — use the
//! corrected Adam step with `r_t` scaling:
//!
//! ```text
//! θ_t = θ_{t-1} − lr · r_t · m̂_t / (√v̂_t + ε)
//! ```
//!
//! Otherwise (`ρ_t ≤ 4`, early steps), fall back to SGD-with-momentum
//! using the first moment alone:
//!
//! ```text
//! θ_t = θ_{t-1} − lr · m̂_t
//! ```
//!
//! # When to use
//!
//! Drop-in replacement for Adam when you don't want to hand-tune a
//! warm-up schedule. Same memory cost as Adam.

56use std::collections::HashMap;
57
58use crate::Optimizer;
59use crate::common::zeros_entry;
60
/// RAdam (Rectified Adam) optimizer: hyper-parameters plus running state.
///
/// For every tensor name it has seen, the optimizer keeps two `f32`
/// moment buffers (first and second moment).
#[derive(Clone, Debug)]
pub struct RAdam {
    /// Step size multiplied into every update.
    pub lr: f32,
    /// Decay rate β₁ of the first-moment EMA; [`RAdam::new`] sets `0.9`.
    pub beta1: f32,
    /// Decay rate β₂ of the second-moment EMA; [`RAdam::new`] sets `0.999`.
    pub beta2: f32,
    /// Small constant added to the denominator; [`RAdam::new`] sets `1e-8`.
    pub eps: f32,
    /// L2 weight-decay coefficient. It is added to the gradient before the
    /// moment updates (classical Adam style, **not** decoupled);
    /// [`RAdam::new`] sets `0.0`.
    pub weight_decay: f32,
    /// Iteration counter; starts at zero.
    step: u64,
    /// First-moment buffers, keyed by tensor name.
    m: HashMap<String, Vec<f32>>,
    /// Second-moment buffers, keyed by tensor name.
    v: HashMap<String, Vec<f32>>,
}

impl RAdam {
    /// Create an optimizer with learning rate `lr` and the defaults
    /// `(β₁, β₂, ε, λ) = (0.9, 0.999, 1e-8, 0.0)`.
    ///
    /// The iteration counter starts at zero and no moment buffers exist yet.
    pub fn new(lr: f32) -> Self {
        let (beta1, beta2) = (0.9, 0.999);
        Self {
            step: 0,
            m: HashMap::new(),
            v: HashMap::new(),
            lr,
            beta1,
            beta2,
            eps: 1e-8,
            weight_decay: 0.0,
        }
    }

    /// Builder-style setter replacing both EMA decay rates (β₁, β₂).
    pub fn with_betas(self, b1: f32, b2: f32) -> Self {
        Self {
            beta1: b1,
            beta2: b2,
            ..self
        }
    }

    /// Builder-style setter replacing the L2 weight-decay coefficient.
    pub fn with_weight_decay(self, wd: f32) -> Self {
        Self {
            weight_decay: wd,
            ..self
        }
    }
}
108
109impl Optimizer for RAdam {
110 fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
111 debug_assert_eq!(param.len(), grad.len());
112 let t = (self.step + 1) as f64;
113 let b1 = self.beta1 as f64;
114 let b2 = self.beta2 as f64;
115 let bc1 = 1.0 - b1.powf(t);
116 let bc2 = 1.0 - b2.powf(t);
117 let rho_inf = 2.0 / (1.0 - b2) - 1.0;
118 let rho_t = rho_inf - 2.0 * t * b2.powf(t) / bc2;
119 let eps = self.eps as f64;
120 let lr = self.lr as f64;
121 let wd = self.weight_decay;
122 // Variance-rectification term `r_t` (Liu et al. eq. 14).
123 let r_t = if rho_t > 4.0 {
124 (((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
125 / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
126 .sqrt()
127 } else {
128 0.0
129 };
130 let m = zeros_entry(&mut self.m, name, param.len());
131 let v = zeros_entry(&mut self.v, name, param.len());
132 for i in 0..param.len() {
133 let g = (grad[i] + wd * param[i]) as f64;
134 let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
135 let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
136 m[i] = mi as f32;
137 v[i] = vi as f32;
138 let m_hat = mi / bc1;
139 let update = if rho_t > 4.0 {
140 let v_hat = (vi / bc2).sqrt();
141 r_t * m_hat / (v_hat + eps)
142 } else {
143 m_hat
144 };
145 param[i] -= (lr * update) as f32;
146 }
147 }
148
149 fn end_iteration(&mut self) {
150 self.step += 1;
151 }
152}