1use axonml_nn::Parameter;
18use axonml_tensor::Tensor;
19
20use crate::optimizer::Optimizer;
21
/// Stochastic Gradient Descent optimizer with optional momentum,
/// weight decay, dampening, and Nesterov acceleration.
///
/// Update rule applied by `step()` (PyTorch-style SGD):
///   d    = grad + weight_decay * param
///   buf  = momentum * buf + (1 - dampening) * d   (buf = d on first use)
///   step = d + momentum * buf   if nesterov, else   buf
///   param -= lr * step
pub struct SGD {
    // Parameters this optimizer updates on each `step()`.
    params: Vec<Parameter>,
    // Learning rate (step size).
    lr: f32,
    // Momentum factor; 0.0 disables momentum entirely.
    momentum: f32,
    // L2 penalty coefficient folded into the gradient; 0.0 disables it.
    weight_decay: f32,
    // Use Nesterov momentum instead of classic momentum.
    nesterov: bool,
    // Dampening applied to the gradient's contribution to the buffer.
    dampening: f32,
    // One slot per parameter; lazily filled on the first step with
    // momentum enabled (stays `None` while momentum == 0.0).
    momentum_buffers: Vec<Option<Tensor<f32>>>,
}
58
59impl SGD {
60 #[must_use]
62 pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
63 let num_params = params.len();
64 Self {
65 params,
66 lr,
67 momentum: 0.0,
68 weight_decay: 0.0,
69 nesterov: false,
70 dampening: 0.0,
71 momentum_buffers: vec![None; num_params],
72 }
73 }
74
75 #[must_use]
77 pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
78 let num_params = params.len();
79 Self {
80 params,
81 lr,
82 momentum,
83 weight_decay: 0.0,
84 nesterov: false,
85 dampening: 0.0,
86 momentum_buffers: vec![None; num_params],
87 }
88 }
89
90 #[must_use]
92 pub fn with_options(
93 params: Vec<Parameter>,
94 lr: f32,
95 momentum: f32,
96 weight_decay: f32,
97 dampening: f32,
98 nesterov: bool,
99 ) -> Self {
100 let num_params = params.len();
101 Self {
102 params,
103 lr,
104 momentum,
105 weight_decay,
106 nesterov,
107 dampening,
108 momentum_buffers: vec![None; num_params],
109 }
110 }
111
112 #[must_use]
114 pub fn momentum(mut self, momentum: f32) -> Self {
115 self.momentum = momentum;
116 self
117 }
118
119 #[must_use]
121 pub fn weight_decay(mut self, weight_decay: f32) -> Self {
122 self.weight_decay = weight_decay;
123 self
124 }
125
126 #[must_use]
128 pub fn nesterov(mut self, nesterov: bool) -> Self {
129 self.nesterov = nesterov;
130 self
131 }
132
133 #[must_use]
135 pub fn dampening(mut self, dampening: f32) -> Self {
136 self.dampening = dampening;
137 self
138 }
139}
140
141impl Optimizer for SGD {
142 fn step(&mut self) {
143 for (i, param) in self.params.iter().enumerate() {
144 if !param.requires_grad() {
145 continue;
146 }
147
148 let grad = match param.grad() {
149 Some(g) => g,
150 None => continue,
151 };
152
153 let param_data = param.data();
154
155 let d = if self.weight_decay == 0.0 {
163 grad.clone()
164 } else {
165 grad.add(¶m_data.mul_scalar(self.weight_decay)).unwrap()
166 };
167
168 let update_dir = if self.momentum == 0.0 {
170 d
171 } else {
172 let buf = &mut self.momentum_buffers[i];
173
174 if buf.is_none() {
175 *buf = Some(d.clone());
177 } else {
178 let old = buf.as_ref().unwrap();
180 let new_buf = old
181 .mul_scalar(self.momentum)
182 .add(&d.mul_scalar(1.0 - self.dampening))
183 .unwrap();
184 *buf = Some(new_buf);
185 }
186
187 let buf_ref = buf.as_ref().unwrap();
188
189 if self.nesterov {
190 d.add(&buf_ref.mul_scalar(self.momentum)).unwrap()
192 } else {
193 buf_ref.clone()
194 }
195 };
196
197 let new_param = param_data.sub(&update_dir.mul_scalar(self.lr)).unwrap();
199 param.update_data(new_param);
200 }
201 }
202
203 fn zero_grad(&mut self) {
204 for param in &self.params {
205 param.zero_grad();
206 }
207 }
208
209 fn get_lr(&self) -> f32 {
210 self.lr
211 }
212
213 fn set_lr(&mut self, lr: f32) {
214 self.lr = lr;
215 }
216
217 fn parameters(&self) -> &[Parameter] {
218 &self.params
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    // Builds a single trainable parameter holding [1.0, 2.0, 3.0].
    fn sample_param() -> Parameter {
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed");
        Parameter::from_variable(Variable::new(tensor, true))
    }

    // Attaches the gradient [0.1, 0.2, 0.3] to `param`.
    fn attach_grad(param: &Parameter) {
        param.variable().set_grad(
            Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"),
        );
    }

    #[test]
    fn test_sgd_creation() {
        let optimizer = SGD::new(vec![sample_param()], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let optimizer = SGD::with_momentum(vec![sample_param()], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let param = sample_param();
        attach_grad(&param);

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        // Expected: x - lr * grad = [0.99, 1.98, 2.97].
        let updated = param.data().to_vec();
        assert!((updated[0] - 0.99).abs() < 1e-5);
        assert!((updated[1] - 1.98).abs() < 1e-5);
        assert!((updated[2] - 2.97).abs() < 1e-5);
    }

    #[test]
    fn test_sgd_zero_grad() {
        let param = sample_param();
        attach_grad(&param);

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // After zeroing, any remaining gradient must be all zeros.
        if let Some(g) = param.grad() {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let optimizer = SGD::new(vec![sample_param()], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
}