use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;
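/// Stochastic gradient descent (SGD) optimizer.
///
/// Supports optional momentum, weight decay (L2 penalty), dampening, and
/// Nesterov acceleration.
///
/// A minimal usage sketch, mirroring the constructors exercised in the
/// tests below (marked `ignore` since it is not wired up as a doctest):
///
/// ```ignore
/// use axonml_autograd::Variable;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = SGD::with_momentum(vec![param.clone()], 0.1, 0.9);
///
/// // ...forward/backward pass populates `param`'s gradient...
///
/// optimizer.step();      // apply one SGD update
/// optimizer.zero_grad(); // clear gradients for the next iteration
/// ```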
pub struct SGD {
    /// Parameters to be updated by this optimizer.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor; `0.0` disables momentum.
    momentum: f32,
    /// Weight decay (L2 penalty) coefficient.
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Dampening applied to the gradient term of the momentum update.
    dampening: f32,
    /// Per-parameter momentum buffers, lazily initialized on the first step.
    momentum_buffers: Vec<Option<Tensor<f32>>>,
}

impl SGD {
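    /// Creates an SGD optimizer with the given learning rate; momentum,
    /// weight decay, and dampening all default to zero.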
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            momentum_buffers: vec![None; num_params],
        }
    }

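    /// Creates an SGD optimizer with the given learning rate and momentum
    /// factor.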
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            momentum_buffers: vec![None; num_params],
        }
    }

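    /// Creates an SGD optimizer with every option specified explicitly.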
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            momentum_buffers: vec![None; num_params],
        }
    }

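    /// Sets the momentum factor (builder style).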
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

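    /// Sets the weight decay (L2 penalty) coefficient (builder style).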
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

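    /// Enables or disables Nesterov momentum (builder style).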
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

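    /// Sets the dampening factor applied to the gradient term of the
    /// momentum update (builder style).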
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

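// The step below follows the common PyTorch-style SGD formulation:
//   d      = grad + weight_decay * param
//   buf    = momentum * buf + (1 - dampening) * d
//   dir    = buf                   (classic momentum)
//            d + momentum * buf    (Nesterov)
//   param <- param - lr * dir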
impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            // Skip parameters that have no gradient yet.
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let param_data = param.data();

            // Weight decay (L2 penalty): d = grad + weight_decay * param.
            let d = if self.weight_decay == 0.0 {
                grad.clone()
            } else {
                grad.add(&param_data.mul_scalar(self.weight_decay)).unwrap()
            };

            let update_dir = if self.momentum == 0.0 {
                d
            } else {
                let buf = &mut self.momentum_buffers[i];

                if buf.is_none() {
                    // First step: seed the buffer with the (decayed) gradient.
                    *buf = Some(d.clone());
                } else {
                    // buf = momentum * buf + (1 - dampening) * d.
                    let old = buf.as_ref().unwrap();
                    let new_buf = old
                        .mul_scalar(self.momentum)
                        .add(&d.mul_scalar(1.0 - self.dampening))
                        .unwrap();
                    *buf = Some(new_buf);
                }

                let buf_ref = buf.as_ref().unwrap();

                if self.nesterov {
                    // Nesterov: step along d + momentum * buf.
                    d.add(&buf_ref.mul_scalar(self.momentum)).unwrap()
                } else {
                    buf_ref.clone()
                }
            };

            // param <- param - lr * update_dir.
            let new_param = param_data.sub(&update_dir.mul_scalar(self.lr)).unwrap();
            param.update_data(new_param);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        // Each element moves by -lr * grad, e.g. 1.0 - 0.1 * 0.1 = 0.99.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // zero_grad may either clear the gradient or zero it in place,
        // depending on the Parameter implementation; accept both.
        let grad = param.grad();
        if let Some(g) = grad {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
}