//! `SGD` — Stochastic Gradient Descent with optional momentum and Nesterov.
//!
//! `SGD::new(params, lr)`, `.momentum(m)`, `.nesterov(true)`,
//! `.weight_decay(wd)`, `.dampening(d)`. Standard PyTorch-equivalent
//! update rule with a velocity buffer for the momentum variants.
//!
//! # File
//! `crates/axonml-optim/src/sgd.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty
//! of any kind, express or implied. The author and AutomataNexus shall not be
//! held liable for any damages arising from the use of this software.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// SGD
// =============================================================================

/// Stochastic Gradient Descent optimizer.
///
/// Supports momentum, dampening, weight decay (L2), and Nesterov
/// acceleration.
///
/// Update rule (with momentum):
/// ```text
/// v_t   = momentum * v_{t-1} + (1 - dampening) * grad
/// param = param - lr * v_t
/// ```
///
/// Update rule (with Nesterov, where dampening is conventionally 0):
/// ```text
/// v_t   = momentum * v_{t-1} + grad
/// param = param - lr * (grad + momentum * v_t)
/// ```
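///
/// # Example
///
/// A minimal construction sketch (marked `ignore`; it assumes the crate
/// re-exports `SGD` and `Optimizer`, and a `params: Vec<Parameter>` built
/// as in the tests below):
///
/// ```ignore
/// use axonml_optim::{Optimizer, SGD};
///
/// let mut opt = SGD::new(params, 0.01)
///     .momentum(0.9)
///     .weight_decay(1e-4)
///     .nesterov(true);
///
/// opt.step();      // one PyTorch-style SGD update
/// opt.zero_grad(); // clear gradients for the next iteration
/// ```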
pub struct SGD {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor.
    momentum: f32,
    /// Weight decay (L2 regularization).
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Dampening factor for momentum.
    dampening: f32,
    /// Per-parameter `Tensor`-based momentum buffers (GPU or CPU),
    /// lazily initialized on the first step when `momentum != 0`.
    momentum_buffers: Vec<Option<Tensor<f32>>>,
}

impl SGD {
    /// Creates a new `SGD` optimizer with default settings: no momentum,
    /// no weight decay, no dampening, and Nesterov disabled.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            momentum_buffers: vec![None; num_params],
        }
    }

    /// Creates `SGD` with the given momentum factor.
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            momentum_buffers: vec![None; num_params],
        }
    }

    /// Creates `SGD` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            momentum_buffers: vec![None; num_params],
        }
    }

    /// Builder method to set momentum.
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable Nesterov momentum.
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

    /// Builder method to set dampening.
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let Some(grad) = param.grad() else {
                continue;
            };

            let param_data = param.data();

            // ============================================================
            // Tensor-op path: works on both CPU and GPU without to_vec().
            // All ops (add, mul_scalar, sub) dispatch to CUDA when the
            // tensors are GPU-resident.
            // ============================================================

            // Apply weight decay: d = grad + weight_decay * param
            let d = if self.weight_decay == 0.0 {
                grad.clone()
            } else {
                grad.add(&param_data.mul_scalar(self.weight_decay)).unwrap()
            };
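            // (Folding `weight_decay * param` into the gradient is the
            // classic coupled L2 penalty, as in `torch.optim.SGD`;
            // decoupled decay, as in AdamW, would instead shrink the
            // parameter directly.)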

            // Apply momentum
            let update_dir = if self.momentum == 0.0 {
                d
            } else {
                let buf = &mut self.momentum_buffers[i];

                // First step seeds the buffer with d itself (dampening is
                // not applied on the first step, matching PyTorch);
                // afterwards: buf = momentum * buf + (1 - dampening) * d.
                let new_buf = match buf.take() {
                    None => d.clone(),
                    Some(old) => old
                        .mul_scalar(self.momentum)
                        .add(&d.mul_scalar(1.0 - self.dampening))
                        .unwrap(),
                };

                let dir = if self.nesterov {
                    // Nesterov look-ahead: effective = d + momentum * buf
                    d.add(&new_buf.mul_scalar(self.momentum)).unwrap()
                } else {
                    new_buf.clone()
                };
                *buf = Some(new_buf);
                dir
            };

            // param = param - lr * update_dir
            let new_param = param_data.sub(&update_dir.mul_scalar(self.lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Manually set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // param = param - lr * grad = [1, 2, 3] - 0.1 * [0.1, 0.2, 0.3]
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }
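
    // Two steps with a constant gradient exercise the velocity buffer: the
    // first step seeds v_1 = grad, the second uses v_2 = 0.9 * v_1 + grad.
    // A sketch of the momentum math; it assumes `set_grad` overwrites the
    // previous gradient, as in the tests above.
    #[test]
    fn test_sgd_momentum_two_steps() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0], &[1]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let mut optimizer = SGD::with_momentum(vec![param.clone()], 0.1, 0.9);

        // Step 1: v_1 = 1.0, so param = 1.0 - 0.1 * 1.0 = 0.9
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0], &[1]).expect("tensor creation failed"));
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.9).abs() < 1e-5);

        // Step 2: v_2 = 0.9 * 1.0 + 1.0 = 1.9, so param = 0.9 - 0.1 * 1.9 = 0.71
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0], &[1]).expect("tensor creation failed"));
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.71).abs() < 1e-5);
    }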

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        // Verify gradient exists
        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // Gradient should be cleared or zeroed
        if let Some(g) = param.grad() {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
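
    // Weight decay folds wd * param into the gradient before the update:
    // d = 0.5 + 0.1 * 2.0 = 0.7, so param = 2.0 - 0.1 * 0.7 = 1.93.
    // A sketch of the coupled-L2 path, using only APIs from the tests above.
    #[test]
    fn test_sgd_weight_decay_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![2.0], &[1]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.5], &[1]).expect("tensor creation failed"));

        let mut optimizer = SGD::new(vec![param.clone()], 0.1).weight_decay(0.1);
        optimizer.step();

        assert!((param.data().to_vec()[0] - 1.93).abs() < 1e-5);
    }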
}