
//! SGD Optimizer - Stochastic Gradient Descent
//!
//! # File
//! `crates/axonml-optim/src/sgd.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of
//! any kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// SGD
// =============================================================================

/// Stochastic Gradient Descent optimizer.
///
/// Supports momentum, dampening, weight decay (L2 regularization), and
/// Nesterov acceleration.
///
/// Update rule (with momentum), where `grad` already includes the weight-decay
/// term `weight_decay * param` when weight decay is enabled:
/// ```text
/// v_t = momentum * v_{t-1} + (1 - dampening) * grad
/// param = param - lr * v_t
/// ```
///
/// Update rule (with Nesterov momentum):
/// ```text
/// v_t = momentum * v_{t-1} + (1 - dampening) * grad
/// param = param - lr * (momentum * v_t + grad)
/// ```
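///
/// # Example
///
/// A minimal usage sketch, mirroring the constructors exercised in the tests
/// below; marked `ignore` because it needs gradients populated by a backward
/// pass before `step` has any effect:
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// let mut optimizer = SGD::new(vec![param.clone()], 0.01)
///     .momentum(0.9)
///     .weight_decay(1e-4)
///     .nesterov(true);
///
/// // ... run the forward and backward pass, then:
/// optimizer.step();
/// optimizer.zero_grad();
/// ```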
pub struct SGD {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor.
    momentum: f32,
    /// Weight decay (L2 regularization).
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Dampening factor for momentum.
    dampening: f32,
    /// Per-parameter Tensor-based momentum buffers (GPU or CPU).
    /// Lazily initialized on first step when momentum != 0.
    momentum_buffers: Vec<Option<Tensor<f32>>>,
}

impl SGD {
    /// Creates a new SGD optimizer with default settings.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            momentum_buffers: vec![None; num_params],
        }
    }

    /// Creates SGD with momentum.
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            momentum_buffers: vec![None; num_params],
        }
    }

    /// Creates SGD with all options.
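    ///
    /// Argument order, shown as a sketch (the values are illustrative only):
    ///
    /// ```ignore
    /// // (params, lr, momentum, weight_decay, dampening, nesterov)
    /// let optimizer = SGD::with_options(params, 0.01, 0.9, 1e-4, 0.0, true);
    /// ```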
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            momentum_buffers: vec![None; num_params],
        }
    }

    /// Builder method to set momentum.
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable Nesterov momentum.
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

    /// Builder method to set dampening.
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let param_data = param.data();

            // ============================================================
            // Tensor-op path: works on both CPU and GPU without to_vec()
            // All ops (add, mul, mul_scalar, sub) dispatch to CUDA when
            // the tensors are GPU-resident.
            // ============================================================

            // Apply weight decay: d = grad + weight_decay * param
            let d = if self.weight_decay == 0.0 {
                grad.clone()
            } else {
                grad.add(&param_data.mul_scalar(self.weight_decay)).unwrap()
            };

            // Apply momentum
            let update_dir = if self.momentum == 0.0 {
                d
            } else {
                let buf = &mut self.momentum_buffers[i];

                if buf.is_none() {
                    // First iteration: momentum buffer = d
                    *buf = Some(d.clone());
                } else {
                    // buf = momentum * buf + (1 - dampening) * d
                    let old = buf.as_ref().unwrap();
                    let new_buf = old
                        .mul_scalar(self.momentum)
                        .add(&d.mul_scalar(1.0 - self.dampening))
                        .unwrap();
                    *buf = Some(new_buf);
                }

                let buf_ref = buf.as_ref().unwrap();

                if self.nesterov {
                    // effective = d + momentum * buf
                    d.add(&buf_ref.mul_scalar(self.momentum)).unwrap()
                } else {
                    buf_ref.clone()
                }
            };

            // param = param - lr * update_dir
            let new_param = param_data.sub(&update_dir.mul_scalar(self.lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Manually set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // param = param - lr * grad = [1, 2, 3] - 0.1 * [0.1, 0.2, 0.3]
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        // Verify gradient exists
        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // Gradient should be zeroed
        let grad = param.grad();
        if let Some(g) = grad {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
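
    // Two-step momentum check, added as a sketch of the buffer update
    // `buf = momentum * buf + (1 - dampening) * d` used in `step`. It assumes,
    // like the tests above, that `Parameter::clone` shares the underlying data
    // and that gradients are (re)settable via `set_grad` after `zero_grad`.
    #[test]
    fn test_sgd_momentum_two_steps() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut optimizer = SGD::with_momentum(vec![param.clone()], 0.1, 0.9);

        // Step 1: buf = 1.0, param = 1.0 - 0.1 * 1.0 = 0.9
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.9).abs() < 1e-5);

        // Step 2: buf = 0.9 * 1.0 + 1.0 = 1.9, param = 0.9 - 0.1 * 1.9 = 0.71
        optimizer.zero_grad();
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.71).abs() < 1e-5);
    }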
}