Skip to main content

entrenar/optim/
sgd.rs

1//! Stochastic Gradient Descent optimizer
2
3use super::Optimizer;
4use crate::Tensor;
5use ndarray::Array1;
6
7/// SGD optimizer with optional momentum
8pub struct SGD {
9    lr: f32,
10    momentum: f32,
11    velocities: Vec<Option<Array1<f32>>>,
12}
13
14impl SGD {
15    /// Create a new SGD optimizer
16    pub fn new(lr: f32, momentum: f32) -> Self {
17        Self { lr, momentum, velocities: Vec::new() }
18    }
19
20    /// Initialize velocities if needed
21    fn ensure_velocities(&mut self, params: &[Tensor]) {
22        if self.velocities.is_empty() {
23            self.velocities = params.iter().map(|_| None).collect();
24        }
25    }
26}
27
28impl Optimizer for SGD {
29    fn step(&mut self, params: &mut [Tensor]) {
30        self.ensure_velocities(params);
31
32        for (i, param) in params.iter_mut().enumerate() {
33            if let Some(grad) = param.grad() {
34                // Use SIMD for large tensors (>= 16 elements for meaningful speedup)
35                if grad.len() >= 16 {
36                    let grad_slice = grad.as_slice().expect("grad array is contiguous");
37                    let param_slice =
38                        param.data_mut().as_slice_mut().expect("param array is contiguous");
39
40                    if self.momentum > 0.0 {
41                        // Initialize velocity if needed
42                        if self.velocities[i].is_none() {
43                            self.velocities[i] = Some(Array1::zeros(grad.len()));
44                        }
45
46                        let velocity =
47                            self.velocities[i].as_mut().expect("velocity buffer initialized above");
48                        let velocity_slice =
49                            velocity.as_slice_mut().expect("velocity array is contiguous");
50
51                        // v = momentum * v - lr * grad (using SIMD)
52                        // First scale velocity by momentum
53                        for v in velocity_slice.iter_mut() {
54                            *v *= self.momentum;
55                        }
56
57                        // Then add -lr * grad using SIMD axpy
58                        super::simd::simd_axpy(-self.lr, grad_slice, velocity_slice);
59
60                        // param = param + velocity (using SIMD axpy with a=1.0)
61                        super::simd::simd_axpy(1.0, velocity_slice, param_slice);
62                    } else {
63                        // Simple SGD: param -= lr * grad (using SIMD axpy)
64                        super::simd::simd_axpy(-self.lr, grad_slice, param_slice);
65                    }
66                } else {
67                    // Fallback to scalar implementation for small tensors
68                    if self.momentum > 0.0 {
69                        // v = momentum * v - lr * grad
70                        let velocity = if let Some(v) = &self.velocities[i] {
71                            v * self.momentum - &grad * self.lr
72                        } else {
73                            &grad * (-self.lr)
74                        };
75
76                        *param.data_mut() = param.data() + &velocity;
77                        self.velocities[i] = Some(velocity);
78                    } else {
79                        // Simple SGD: param -= lr * grad
80                        *param.data_mut() = param.data() - &(&grad * self.lr);
81                    }
82                }
83            }
84        }
85    }
86
87    fn lr(&self) -> f32 {
88        self.lr
89    }
90
91    fn set_lr(&mut self, lr: f32) {
92        self.lr = lr;
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Copy a tensor's current values into a plain `Vec` for comparison.
    fn values(tensor: &Tensor) -> Vec<f32> {
        tensor.data().iter().copied().collect()
    }

    /// Assert that `actual` equals `expected` element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (idx, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!((a - e).abs() < tol, "index {}: got {}, expected {}", idx, a, e);
        }
    }

    #[test]
    fn test_sgd_small_tensor_no_momentum() {
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        let mut opt = SGD::new(0.1, 0.0);
        let mut params = [param];
        opt.step(&mut params);

        // Small tensor path, no momentum: param -= lr * grad
        assert_close(&values(&params[0]), &[0.99, 1.98, 2.97], 1e-5);
    }

    #[test]
    fn test_sgd_small_tensor_with_momentum() {
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        let mut opt = SGD::new(0.1, 0.9);
        let mut params = [param];

        // First step initializes velocity from scratch: v = -lr * grad
        opt.step(&mut params);
        assert_close(&values(&params[0]), &[0.99, 1.98, 2.97], 1e-5);

        // Second step uses existing velocity: v = 0.9 * v - lr * grad
        params[0].set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
        opt.step(&mut params);
        assert_close(&values(&params[0]), &[0.971, 1.942, 2.913], 1e-5);
    }

    #[test]
    fn test_sgd_large_tensor_with_momentum() {
        // >= 16 elements to trigger SIMD path
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let grad: Vec<f32> = vec![0.1; 20];

        let param = Tensor::from_vec(data, true);
        param.set_grad(Array1::from_vec(grad.clone()));

        let mut opt = SGD::new(0.1, 0.9);
        let mut params = [param];
        opt.step(&mut params);

        // After the first step every element has moved by v = -0.01.
        let after_first: Vec<f32> = (0..20).map(|i| i as f32 - 0.01).collect();
        assert_close(&values(&params[0]), &after_first, 1e-4);

        // Second step with existing velocity: v = 0.9 * -0.01 - 0.01 = -0.019
        params[0].set_grad(Array1::from_vec(grad));
        opt.step(&mut params);

        let after_second: Vec<f32> = (0..20).map(|i| i as f32 - 0.029).collect();
        assert_close(&values(&params[0]), &after_second, 1e-4);
    }

    #[test]
    fn test_sgd_lr_getter_setter() {
        let mut opt = SGD::new(0.1, 0.0);
        assert!((opt.lr() - 0.1).abs() < 1e-6);
        opt.set_lr(0.01);
        assert!((opt.lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_no_grad_skips() {
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        // No gradient set

        let mut opt = SGD::new(0.1, 0.0);
        let mut params = [param];
        opt.step(&mut params); // Should not panic

        // Without a gradient the parameter must be left untouched.
        assert_close(&values(&params[0]), &[1.0, 2.0, 3.0], 1e-6);
    }
}