axonml_optim/sgd.rs

//! SGD Optimizer - Stochastic Gradient Descent
//!
//! Implements SGD with optional momentum and Nesterov acceleration.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::{Optimizer, ParamState};

// =============================================================================
// SGD
// =============================================================================

/// Stochastic Gradient Descent optimizer.
///
/// Supports momentum, Nesterov acceleration, dampening, and L2 weight decay.
/// When `weight_decay` is non-zero, `weight_decay * param` is added to the
/// gradient before the update below.
///
/// Update rule (with momentum):
/// ```text
/// v_t = momentum * v_{t-1} + (1 - dampening) * grad
/// param = param - lr * v_t
/// ```
///
/// Update rule (with Nesterov momentum):
/// ```text
/// v_t = momentum * v_{t-1} + (1 - dampening) * grad
/// param = param - lr * (grad + momentum * v_t)
/// ```
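///
/// # Example
///
/// A minimal usage sketch (illustrative only; assumes `params` is a
/// `Vec<Parameter>` collected from a model and that gradients have already
/// been populated by a backward pass):
///
/// ```ignore
/// // SGD with momentum and weight decay via the builder methods.
/// let mut opt = SGD::new(params, 0.01).momentum(0.9).weight_decay(1e-4);
///
/// // One optimization step, then clear gradients for the next iteration.
/// opt.step();
/// opt.zero_grad();
/// ```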
pub struct SGD {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor.
    momentum: f32,
    /// Weight decay (L2 regularization).
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Dampening factor for momentum.
    dampening: f32,
    /// Per-parameter state (momentum buffers).
    state: Vec<ParamState>,
}

impl SGD {
    /// Creates a new SGD optimizer with default settings.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Creates SGD with momentum.
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Creates SGD with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Builder method to set momentum.
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable Nesterov momentum.
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

    /// Builder method to set dampening.
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();
            let mut grad_vec = grad.to_vec();

            // Apply weight decay, fused into grad_vec (no extra allocation)
            if self.weight_decay != 0.0 {
                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            // Apply momentum
            if self.momentum != 0.0 {
                let state = &mut self.state[i];

                if state.momentum_buffer.is_none() {
                    // First iteration: initialize momentum buffer
                    state.init_momentum(grad_vec.len());
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    buf.copy_from_slice(&grad_vec);
                } else {
                    // Subsequent iterations: update momentum buffer in place
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    for (b, g) in buf.iter_mut().zip(grad_vec.iter()) {
                        *b = self.momentum * *b + (1.0 - self.dampening) * *g;
                    }
                }

                let buf = state.momentum_buffer.as_ref().unwrap();

                if self.nesterov {
                    // Nesterov: reuse grad_vec instead of allocating new Vec
                    for (g, b) in grad_vec.iter_mut().zip(buf.iter()) {
                        *g += self.momentum * *b;
                    }
                    // grad_vec now holds the Nesterov effective gradient:
                    // grad + momentum * buf (weight decay is already folded into grad_vec).
                } else {
                    // Standard momentum: copy buf into grad_vec (reuse allocation)
                    grad_vec.copy_from_slice(buf);
                }
            }

            // Update parameters in place: param = param - lr * grad
            let lr = self.lr;
            for (p, g) in param_vec.iter_mut().zip(grad_vec.iter()) {
                *p -= lr * g;
            }

            let mut update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            let device = param_data.device();
            if device.is_gpu() {
                update = update.to_device(device).unwrap();
            }
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Manually set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // param = param - lr * grad = [1, 2, 3] - 0.1 * [0.1, 0.2, 0.3]
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        // Verify gradient exists
        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // Gradient should be zeroed
        let grad = param.grad();
        if let Some(g) = grad {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
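
    // Sketch of the momentum update over two steps (lr = 0.1, momentum = 0.9,
    // dampening = 0). Assumes the gradient set via `set_grad` persists across
    // `step()` calls, since `zero_grad()` is not called in between (as in
    // `test_sgd_step` above).
    #[test]
    fn test_sgd_momentum_two_steps() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1], &[1]).unwrap());

        let mut optimizer = SGD::with_momentum(vec![param.clone()], 0.1, 0.9);

        // Step 1: buf = 0.1, param = 1.0 - 0.1 * 0.1 = 0.99
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.99).abs() < 1e-5);

        // Step 2: buf = 0.9 * 0.1 + 0.1 = 0.19, param = 0.99 - 0.1 * 0.19 = 0.971
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.971).abs() < 1e-5);
    }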
}