rlx_optim/sgd.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Stochastic Gradient Descent with optional momentum and classical
//! (coupled, **not** decoupled) L2 weight decay.
18//!
19//! # Update rules
20//!
21//! Vanilla SGD (`momentum = 0`):
22//!
23//! ```text
24//! θ_{t+1} = θ_t − lr · (g_t + λ·θ_t)
25//! ```
26//!
27//! Polyak momentum (`momentum = μ`, `nesterov = false`):
28//!
29//! ```text
30//! v_{t+1} = μ·v_t + (g_t + λ·θ_t)
31//! θ_{t+1} = θ_t − lr · v_{t+1}
32//! ```
33//!
34//! Nesterov-accelerated SGD (`nesterov = true`):
35//!
36//! ```text
37//! v_{t+1} = μ·v_t + (g_t + λ·θ_t)
38//! θ_{t+1} = θ_t − lr · (g_t + λ·θ_t + μ·v_{t+1})
39//! ```
40//!
41//! # When to use
42//!
//! The default choice when training CNNs from scratch; with a
//! well-tuned `lr` schedule it still beats Adam on many vision
//! benchmarks. Cheap state: one velocity buffer per tensor (it is
//! requested on every step, but only read and written when
//! `momentum != 0`).
46
47use std::collections::HashMap;
48
49use crate::Optimizer;
50use crate::common::zeros_entry;
51
/// SGD with momentum / Nesterov / L2 weight decay.
///
/// All hyperparameters are public so callers can hot-swap them between
/// iterations (e.g. for a warm-up schedule). State is keyed by
/// parameter name, so the same `Sgd` instance can drive every tensor in
/// a model as long as each tensor is stepped under a distinct name.
#[derive(Debug, Clone)]
pub struct Sgd {
    /// Learning rate. No default — pass it to [`Sgd::new`].
    pub lr: f32,
    /// Polyak momentum coefficient, intended to lie in \[0, 1\) (the
    /// range is not validated). Exactly `0.0` selects the plain SGD
    /// update; the per-tensor velocity buffer is still requested on
    /// every step but is neither read nor written. Turn momentum on
    /// with [`Sgd::with_momentum`] or by assigning this field.
    pub momentum: f32,
    /// Use Nesterov-accelerated momentum. Ignored when
    /// `momentum == 0.0`.
    pub nesterov: bool,
    /// L2 weight decay coefficient λ. Folded into the gradient
    /// *before* the momentum accumulation (classical, **not**
    /// decoupled). Use [`crate::AdamW`]-style decoupling if you need
    /// that.
    pub weight_decay: f32,
    /// Velocity buffers, one per parameter name passed to `step`.
    v: HashMap<String, Vec<f32>>,
}
75
76impl Sgd {
77 /// Construct with `lr` and momentum / decay disabled.
78 pub fn new(lr: f32) -> Self {
79 Self {
80 lr,
81 momentum: 0.0,
82 nesterov: false,
83 weight_decay: 0.0,
84 v: HashMap::new(),
85 }
86 }
87
88 /// Enable Polyak (or Nesterov) momentum.
89 pub fn with_momentum(mut self, momentum: f32, nesterov: bool) -> Self {
90 self.momentum = momentum;
91 self.nesterov = nesterov;
92 self
93 }
94
95 /// Set the L2 weight-decay coefficient.
96 pub fn with_weight_decay(mut self, wd: f32) -> Self {
97 self.weight_decay = wd;
98 self
99 }
100}
101
102impl Optimizer for Sgd {
103 fn step(&mut self, name: &str, _shape: &[usize], param: &mut [f32], grad: &[f32]) {
104 debug_assert_eq!(param.len(), grad.len());
105 let v = zeros_entry(&mut self.v, name, param.len());
106 let mu = self.momentum;
107 let wd = self.weight_decay;
108 let lr = self.lr;
109 for i in 0..param.len() {
110 let g = grad[i] + wd * param[i];
111 if mu == 0.0 {
112 param[i] -= lr * g;
113 } else {
114 v[i] = mu * v[i] + g;
115 let update = if self.nesterov { g + mu * v[i] } else { v[i] };
116 param[i] -= lr * update;
117 }
118 }
119 }
120}