Skip to main content

rlx_optim/
sgd.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Stochastic Gradient Descent with optional momentum and classical
//! (coupled) L2 weight decay.
//!
//! # Update rules
//!
//! Vanilla SGD (`momentum = 0`):
//!
//! ```text
//! θ_{t+1} = θ_t − lr · (g_t + λ·θ_t)
//! ```
//!
//! Polyak momentum (`momentum = μ`, `nesterov = false`):
//!
//! ```text
//! v_{t+1} = μ·v_t + (g_t + λ·θ_t)
//! θ_{t+1} = θ_t − lr · v_{t+1}
//! ```
//!
//! Nesterov-accelerated SGD (`nesterov = true`):
//!
//! ```text
//! v_{t+1} = μ·v_t + (g_t + λ·θ_t)
//! θ_{t+1} = θ_t − lr · (g_t + λ·θ_t + μ·v_{t+1})
//! ```
//!
//! In every variant the decay term λ·θ_t is added to the gradient itself,
//! so it also feeds the momentum buffer. This is classical L2
//! regularisation, not AdamW-style decoupled decay.
//!
//! # When to use
//!
//! The default choice when training CNNs from scratch; with a
//! well-tuned `lr` schedule it still beats Adam on many vision
//! benchmarks. Cheap state: at most one velocity buffer per tensor,
//! needed only when `momentum > 0`.
46
47use std::collections::HashMap;
48
49use crate::Optimizer;
50use crate::common::zeros_entry;
51
52/// SGD with momentum / Nesterov / L2 weight decay.
53///
54/// All hyperparameters are public so callers can hot-swap them between
55/// iterations (e.g. for a warm-up schedule). State is keyed by
56/// parameter name; the same `Sgd` instance can drive every tensor in
57/// a model.
58#[derive(Debug, Clone)]
59pub struct Sgd {
60    /// Learning rate. No default — pass it to [`Sgd::new`].
61    pub lr: f32,
62    /// Polyak momentum coefficient ∈ \[0, 1\). `0.0` disables momentum
63    /// entirely (and the per-tensor velocity buffer is still allocated
64    /// but unused — set via [`Sgd::with_momentum`] if you want it on).
65    pub momentum: f32,
66    /// Use Nesterov-accelerated momentum. Only meaningful when
67    /// `momentum > 0`.
68    pub nesterov: bool,
69    /// L2 weight decay coefficient λ. Folded into the gradient
70    /// *before* the momentum EMA (classical, **not** decoupled).
71    /// Use [`crate::AdamW`]-style decoupling if you need that.
72    pub weight_decay: f32,
73    v: HashMap<String, Vec<f32>>,
74}
75
76impl Sgd {
77    /// Construct with `lr` and momentum / decay disabled.
78    pub fn new(lr: f32) -> Self {
79        Self {
80            lr,
81            momentum: 0.0,
82            nesterov: false,
83            weight_decay: 0.0,
84            v: HashMap::new(),
85        }
86    }
87
88    /// Enable Polyak (or Nesterov) momentum.
89    pub fn with_momentum(mut self, momentum: f32, nesterov: bool) -> Self {
90        self.momentum = momentum;
91        self.nesterov = nesterov;
92        self
93    }
94
95    /// Set the L2 weight-decay coefficient.
96    pub fn with_weight_decay(mut self, wd: f32) -> Self {
97        self.weight_decay = wd;
98        self
99    }
100}
101
102impl Optimizer for Sgd {
103    fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
104        debug_assert_eq!(param.len(), grad.len());
105        let v = zeros_entry(&mut self.v, name, param.len());
106        let mu = self.momentum;
107        let wd = self.weight_decay;
108        let lr = self.lr;
109        for i in 0..param.len() {
110            let g = grad[i] + wd * param[i];
111            if mu == 0.0 {
112                param[i] -= lr * g;
113            } else {
114                v[i] = mu * v[i] + g;
115                let update = if self.nesterov { g + mu * v[i] } else { v[i] };
116                param[i] -= lr * update;
117            }
118        }
119    }
120}