Skip to main content

rlx_optim/
nadamw.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! NAdamW — Nesterov-accelerated AdamW.
17//!
18//! Combines Dozat (2016, "Incorporating Nesterov Momentum into Adam")
//! with the decoupled-decay rule of [`crate::AdamW`]. The Nesterov
//! trick is a *lookahead*: instead of using the bias-corrected first
//! moment `m̂_t` directly, we use the weighted sum
//! `β₁ · m_t / (1 − β₁^{t+1}) + (1 − β₁) · g_t / (1 − β₁ᵗ)`. The first
//! term is the current moment debiased one step ahead, standing in for
//! `m̂_{t+1}`. The second is the current gradient debiased at step `t`.
23//!
24//! # Update rule
25//!
26//! ```text
27//! m_t   = β₁·m_{t-1} + (1 − β₁)·g_t
28//! v_t   = β₂·v_{t-1} + (1 − β₂)·g_t²
29//! m̄_t   = β₁ · m_t / (1 − β₁^{t+1}) + (1 − β₁) · g_t / (1 − β₁ᵗ)
30//! v̂_t   = v_t / (1 − β₂ᵗ)
31//! θ_t   = θ_{t-1} − lr · ( m̄_t/(√v̂_t + ε) + λ·θ_{t-1} )
32//! ```
//!
//! Here `t` is the 1-based step index (`t = 1` on the first iteration).
//!
34//! # When to use
35//!
36//! Slightly more aggressive than AdamW on the early steps; the
37//! lookahead first-moment occasionally helps escape flat regions.
38//! Same state cost as AdamW.
39
40use std::collections::HashMap;
41
42use crate::Optimizer;
43use crate::common::zeros_entry;
44
45/// Nesterov AdamW. Per-tensor state: two `f32` buffers.
46#[derive(Debug, Clone)]
47pub struct NAdamW {
48    /// Learning rate.
49    pub lr: f32,
50    /// First-moment EMA decay β₁. Default `0.9`.
51    pub beta1: f32,
52    /// Second-moment EMA decay β₂. Default `0.999`.
53    pub beta2: f32,
54    /// Denominator stability constant. Default `1e-8`.
55    pub eps: f32,
56    /// Decoupled weight-decay coefficient λ. Default `0.01`.
57    pub weight_decay: f32,
58    step: u64,
59    m: HashMap<String, Vec<f32>>,
60    v: HashMap<String, Vec<f32>>,
61}
62
63impl NAdamW {
64    /// Construct with `(β₁, β₂, ε, λ) = (0.9, 0.999, 1e-8, 0.01)`.
65    pub fn new(lr: f32) -> Self {
66        Self {
67            lr,
68            beta1: 0.9,
69            beta2: 0.999,
70            eps: 1e-8,
71            weight_decay: 0.01,
72            step: 0,
73            m: HashMap::new(),
74            v: HashMap::new(),
75        }
76    }
77
78    /// Override (β₁, β₂).
79    pub fn with_betas(mut self, b1: f32, b2: f32) -> Self {
80        self.beta1 = b1;
81        self.beta2 = b2;
82        self
83    }
84
85    /// Override the decoupled-decay coefficient.
86    pub fn with_weight_decay(mut self, wd: f32) -> Self {
87        self.weight_decay = wd;
88        self
89    }
90}
91
92impl Optimizer for NAdamW {
93    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
94        debug_assert_eq!(param.len(), grad.len());
95        let t = (self.step + 1) as f64;
96        let b1 = self.beta1 as f64;
97        let b2 = self.beta2 as f64;
98        let bc1 = 1.0 - b1.powf(t);
99        let bc1_next = 1.0 - b1.powf(t + 1.0);
100        let bc2 = 1.0 - b2.powf(t);
101        let eps = self.eps as f64;
102        let lr = self.lr as f64;
103        let wd = self.weight_decay as f64;
104        let m = zeros_entry(&mut self.m, name, param.len());
105        let v = zeros_entry(&mut self.v, name, param.len());
106        for i in 0..param.len() {
107            let g = grad[i] as f64;
108            let mi = b1 * m[i] as f64 + (1.0 - b1) * g;
109            let vi = b2 * v[i] as f64 + (1.0 - b2) * g * g;
110            m[i] = mi as f32;
111            v[i] = vi as f32;
112            // Nesterov-corrected first moment (Dozat eq. 6):
113            //   m_bar = b1 * m_hat_{t+1} + (1-b1)/bc1 * g
114            let m_hat = mi / bc1_next;
115            let m_bar = b1 * m_hat + (1.0 - b1) * g / bc1;
116            let v_hat = vi / bc2;
117            let p = param[i] as f64;
118            param[i] = (p - lr * (m_bar / (v_hat.sqrt() + eps) + wd * p)) as f32;
119        }
120    }
121
122    fn end_iteration(&mut self) {
123        self.step += 1;
124    }
125}