1use super::Optimizer;
4use crate::Tensor;
5use ndarray::Array1;
6
7pub struct SGD {
9 lr: f32,
10 momentum: f32,
11 velocities: Vec<Option<Array1<f32>>>,
12}
13
14impl SGD {
15 pub fn new(lr: f32, momentum: f32) -> Self {
17 Self { lr, momentum, velocities: Vec::new() }
18 }
19
20 fn ensure_velocities(&mut self, params: &[Tensor]) {
22 if self.velocities.is_empty() {
23 self.velocities = params.iter().map(|_| None).collect();
24 }
25 }
26}
27
28impl Optimizer for SGD {
29 fn step(&mut self, params: &mut [Tensor]) {
30 self.ensure_velocities(params);
31
32 for (i, param) in params.iter_mut().enumerate() {
33 if let Some(grad) = param.grad() {
34 if grad.len() >= 16 {
36 let grad_slice = grad.as_slice().expect("grad array is contiguous");
37 let param_slice =
38 param.data_mut().as_slice_mut().expect("param array is contiguous");
39
40 if self.momentum > 0.0 {
41 if self.velocities[i].is_none() {
43 self.velocities[i] = Some(Array1::zeros(grad.len()));
44 }
45
46 let velocity =
47 self.velocities[i].as_mut().expect("velocity buffer initialized above");
48 let velocity_slice =
49 velocity.as_slice_mut().expect("velocity array is contiguous");
50
51 for v in velocity_slice.iter_mut() {
54 *v *= self.momentum;
55 }
56
57 super::simd::simd_axpy(-self.lr, grad_slice, velocity_slice);
59
60 super::simd::simd_axpy(1.0, velocity_slice, param_slice);
62 } else {
63 super::simd::simd_axpy(-self.lr, grad_slice, param_slice);
65 }
66 } else {
67 if self.momentum > 0.0 {
69 let velocity = if let Some(v) = &self.velocities[i] {
71 v * self.momentum - &grad * self.lr
72 } else {
73 &grad * (-self.lr)
74 };
75
76 *param.data_mut() = param.data() + &velocity;
77 self.velocities[i] = Some(velocity);
78 } else {
79 *param.data_mut() = param.data() - &(&grad * self.lr);
81 }
82 }
83 }
84 }
85 }
86
87 fn lr(&self) -> f32 {
88 self.lr
89 }
90
91 fn set_lr(&mut self, lr: f32) {
92 self.lr = lr;
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies a tensor's data into a plain vector for comparison.
    fn values(tensor: &Tensor) -> Vec<f32> {
        tensor.data().iter().copied().collect()
    }

    /// Asserts that two sequences have equal length and agree element-wise
    /// within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() <= tol, "index {idx}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_sgd_small_tensor_no_momentum() {
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        // Keep ownership of the stepped tensor so the result can be inspected.
        let mut params = [param];
        let mut opt = SGD::new(0.1, 0.0);
        opt.step(&mut params);

        // param - lr * grad
        assert_close(&values(&params[0]), &[0.99, 1.98, 2.97], 1e-5);
    }

    #[test]
    fn test_sgd_small_tensor_with_momentum() {
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        let mut params = [param];
        let mut opt = SGD::new(0.1, 0.9);

        // First step: no velocity yet, so v = -lr * grad.
        opt.step(&mut params);
        assert_close(&values(&params[0]), &[0.99, 1.98, 2.97], 1e-5);

        // Second step: v = 0.9 * v - lr * grad.
        params[0].set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
        opt.step(&mut params);
        assert_close(&values(&params[0]), &[0.971, 1.942, 2.913], 1e-5);
    }

    #[test]
    fn test_sgd_large_tensor_with_momentum() {
        // 20 elements is above the 16-element threshold of the slice path.
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let grad: Vec<f32> = vec![0.1; 20];

        let param = Tensor::from_vec(data.clone(), true);
        param.set_grad(Array1::from_vec(grad.clone()));

        let mut params = [param];
        let mut opt = SGD::new(0.1, 0.9);

        // First step: v = -0.01 for every element.
        opt.step(&mut params);
        let after_first: Vec<f32> = data.iter().map(|x| x - 0.01).collect();
        assert_close(&values(&params[0]), &after_first, 1e-4);

        // Second step: v = 0.9 * -0.01 - 0.01 = -0.019, total shift -0.029.
        params[0].set_grad(Array1::from_vec(grad));
        opt.step(&mut params);
        let after_second: Vec<f32> = data.iter().map(|x| x - 0.029).collect();
        assert_close(&values(&params[0]), &after_second, 1e-4);
    }

    #[test]
    fn test_sgd_lr_getter_setter() {
        let mut opt = SGD::new(0.1, 0.0);
        assert!((opt.lr() - 0.1).abs() < 1e-6);
        opt.set_lr(0.01);
        assert!((opt.lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_no_grad_skips() {
        let param = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);

        let mut params = [param];
        let mut opt = SGD::new(0.1, 0.0);
        opt.step(&mut params);

        // No gradient was set, so the data must be untouched.
        assert_eq!(values(&params[0]), vec![1.0, 2.0, 3.0]);
    }
}