
axonml_optim/sgd.rs

//! SGD Optimizer - Stochastic Gradient Descent
//!
//! Implements SGD with optional momentum, dampening, weight decay, and
//! Nesterov acceleration.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::{Optimizer, ParamState};

// =============================================================================
// SGD
// =============================================================================

/// Stochastic Gradient Descent optimizer.
///
/// Supports momentum, dampening, weight decay, and Nesterov acceleration.
///
/// Update rule (with momentum):
/// ```text
/// v_t = momentum * v_{t-1} + (1 - dampening) * grad
/// param = param - lr * v_t
/// ```
///
/// Update rule (with Nesterov):
/// ```text
/// v_t = momentum * v_{t-1} + (1 - dampening) * grad
/// param = param - lr * (momentum * v_t + grad)
/// ```
///
/// On the first step the momentum buffer is seeded with the raw gradient,
/// so dampening only takes effect from the second step onward.
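///
/// # Examples
///
/// A minimal training-loop sketch. Marked `ignore` because the import path
/// and the `model`/`input`/`loss` API are assumptions for illustration, not
/// part of this file:
///
/// ```ignore
/// use axonml_optim::{Optimizer, SGD};
///
/// // `model` and `input` are assumed stand-ins for user code.
/// let mut optimizer = SGD::new(model.parameters(), 0.01)
///     .momentum(0.9)
///     .weight_decay(1e-4);
///
/// for _ in 0..10 {
///     optimizer.zero_grad();
///     let loss = model.forward(&input);
///     loss.backward();
///     optimizer.step();
/// }
/// ```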
pub struct SGD {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor.
    momentum: f32,
    /// Weight decay (L2 regularization).
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Dampening factor for momentum.
    dampening: f32,
    /// Per-parameter state (momentum buffers).
    state: Vec<ParamState>,
}

impl SGD {
    /// Creates a new SGD optimizer with default settings.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Creates SGD with momentum.
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Creates SGD with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Builder method to set momentum.
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable Nesterov momentum.
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

    /// Builder method to set dampening.
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let mut grad_vec = grad.to_vec();

            // Apply weight decay: fold the L2 penalty into the gradient
            // (grad += weight_decay * param) before any momentum handling.
            if self.weight_decay != 0.0 {
                let param_vec = param.data().to_vec();
                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            // Apply momentum
            if self.momentum != 0.0 {
                let state = &mut self.state[i];

                if state.momentum_buffer.is_none() {
                    // First iteration: seed the buffer with the raw gradient
                    // (dampening is not applied on this step).
                    state.init_momentum(grad_vec.len());
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    buf.copy_from_slice(&grad_vec);
                } else {
                    // Subsequent iterations:
                    // buf = momentum * buf + (1 - dampening) * grad
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    for (b, g) in buf.iter_mut().zip(grad_vec.iter()) {
                        *b = self.momentum * *b + (1.0 - self.dampening) * *g;
                    }
                }

                let buf = state.momentum_buffer.as_ref().unwrap();

                if self.nesterov {
                    // Nesterov: look ahead with momentum * buf + grad
                    let nesterov_grad: Vec<f32> = buf
                        .iter()
                        .zip(grad_vec.iter())
                        .map(|(b, g)| self.momentum * *b + *g)
                        .collect();
                    grad_vec = nesterov_grad;
                } else {
                    // Standard momentum: the buffer is the update direction
                    grad_vec = buf.clone();
                }
            }

            // Update parameters: param = param - lr * grad
            let param_data = param.data();
            let param_vec = param_data.to_vec();
            let new_data: Vec<f32> = param_vec
                .iter()
                .zip(grad_vec.iter())
                .map(|(p, g)| p - self.lr * g)
                .collect();

            let update = Tensor::from_vec(new_data, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Manually set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // param = param - lr * grad = [1, 2, 3] - 0.1 * [0.1, 0.2, 0.3]
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }
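
    // Illustrative sketch: two momentum steps hand-computed from the
    // documented rule v_t = momentum * v_{t-1} + grad, with lr = 0.1 and
    // momentum = 0.9 (dampening = 0). Assumes `set_grad` replaces the
    // stored gradient, as in `test_sgd_step` above.
    //   step 1: v = 1.0,             param = 1.0 - 0.1 * 1.0 = 0.9
    //   step 2: v = 0.9 * 1.0 + 1.0, param = 0.9 - 0.1 * 1.9 = 0.71
    #[test]
    fn test_sgd_momentum_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let mut optimizer = SGD::with_momentum(vec![param.clone()], 0.1, 0.9);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.9).abs() < 1e-5);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.71).abs() < 1e-5);
    }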

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        // Verify gradient exists
        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // Gradient should be zeroed
        let grad = param.grad();
        if let Some(g) = grad {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }
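
    // Illustrative sketch: one Nesterov step hand-computed from the
    // documented rule. The buffer is seeded with v = 1.0 on the first step,
    // so the applied gradient is momentum * v + grad = 0.9 + 1.0 = 1.9 and
    // param = 1.0 - 0.1 * 1.9 = 0.81.
    #[test]
    fn test_sgd_nesterov_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());

        let mut optimizer = SGD::with_options(vec![param.clone()], 0.1, 0.9, 0.0, 0.0, true);
        optimizer.step();

        assert!((param.data().to_vec()[0] - 0.81).abs() < 1e-5);
    }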

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
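
    // Illustrative sketch: weight decay folds an L2 term into the gradient
    // before the update. With grad = 0, lr = 0.1, and weight_decay = 0.1,
    // each parameter shrinks to 0.99 of its value: param -= lr * wd * param.
    #[test]
    fn test_sgd_weight_decay_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1).weight_decay(0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }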
}