axonml_optim/sgd.rs

//! SGD Optimizer - Stochastic Gradient Descent
//!
//! Implements SGD with optional momentum and Nesterov acceleration.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::{Optimizer, ParamState};

// =============================================================================
// SGD
// =============================================================================

/// Stochastic Gradient Descent optimizer.
///
/// Supports momentum, Nesterov acceleration, weight decay (L2), and dampening.
///
/// Update rule (with momentum and dampening `d`; the buffer is seeded with the
/// raw gradient on the first step):
/// ```text
/// v_t = momentum * v_{t-1} + (1 - d) * grad
/// param = param - lr * v_t
/// ```
///
/// Update rule (with Nesterov momentum):
/// ```text
/// v_t = momentum * v_{t-1} + (1 - d) * grad
/// param = param - lr * (momentum * v_t + grad)
/// ```
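///
/// # Examples
///
/// A minimal usage sketch, assuming gradients were already populated by a
/// backward pass and that the `Optimizer` trait is in scope; `params` is a
/// placeholder `Vec<Parameter>`, and the block is marked `ignore` so it is not
/// compiled as a doctest:
///
/// ```ignore
/// let mut opt = SGD::with_momentum(params, 0.01, 0.9).weight_decay(1e-4);
///
/// // One training iteration: apply the update, then clear the gradients.
/// opt.step();
/// opt.zero_grad();
/// ```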
pub struct SGD {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor.
    momentum: f32,
    /// Weight decay (L2 regularization).
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Dampening factor for momentum.
    dampening: f32,
    /// Per-parameter state (momentum buffers).
    state: Vec<ParamState>,
}

impl SGD {
    /// Creates a new SGD optimizer with default settings.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Creates SGD with momentum.
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Creates SGD with all options.
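    ///
    /// A sketch of the positional argument order (`params` is a placeholder
    /// `Vec<Parameter>`; the block is marked `ignore` so it is not compiled as
    /// a doctest):
    ///
    /// ```ignore
    /// // lr = 0.01, momentum = 0.9, weight_decay = 1e-4, dampening = 0.0, nesterov = true
    /// let opt = SGD::with_options(params, 0.01, 0.9, 1e-4, 0.0, true);
    /// ```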
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Builder method to set momentum.
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable Nesterov momentum.
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

    /// Builder method to set dampening.
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let mut grad_vec = grad.to_vec();

            // Apply weight decay
            if self.weight_decay != 0.0 {
                let param_vec = param.data().to_vec();
                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            // Apply momentum
            if self.momentum != 0.0 {
                let state = &mut self.state[i];

                if state.momentum_buffer.is_none() {
                    // First iteration: seed the momentum buffer with the raw
                    // gradient (dampening is not applied on the first step).
                    state.init_momentum(grad_vec.len());
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    buf.copy_from_slice(&grad_vec);
                } else {
                    // Subsequent iterations: buf = momentum * buf + (1 - dampening) * grad
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    for (b, g) in buf.iter_mut().zip(grad_vec.iter()) {
                        *b = self.momentum * *b + (1.0 - self.dampening) * *g;
                    }
                }

                let buf = state.momentum_buffer.as_ref().unwrap();

                if self.nesterov {
                    // Nesterov: use momentum * buf + grad
                    let nesterov_grad: Vec<f32> = buf
                        .iter()
                        .zip(grad_vec.iter())
                        .map(|(b, g)| self.momentum * *b + *g)
                        .collect();
                    grad_vec = nesterov_grad;
                } else {
                    // Standard momentum: use buf directly
                    grad_vec = buf.clone();
                }
            }

            // Update parameters: param = param - lr * grad
            let param_data = param.data();
            let param_vec = param_data.to_vec();
            let new_data: Vec<f32> = param_vec
                .iter()
                .zip(grad_vec.iter())
                .map(|(p, g)| p - self.lr * g)
                .collect();

            let update = Tensor::from_vec(new_data, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Manually set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // param = param - lr * grad = [1, 2, 3] - 0.1 * [0.1, 0.2, 0.3]
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        // Verify gradient exists
        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // Gradient should be zeroed
        let grad = param.grad();
        if let Some(g) = grad {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
}