1use axonml_nn::Parameter;
23use axonml_tensor::Tensor;
24
25use crate::optimizer::Optimizer;
26
/// Stochastic gradient descent optimizer with optional momentum,
/// weight decay, dampening, and Nesterov momentum.
pub struct SGD {
    /// Parameters updated by this optimizer.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor; `0.0` disables momentum entirely.
    momentum: f32,
    /// L2 penalty coefficient; `0.0` disables weight decay.
    weight_decay: f32,
    /// Whether Nesterov (lookahead) momentum is applied.
    nesterov: bool,
    /// Dampening applied to the gradient term of the momentum update.
    dampening: f32,
    /// Per-parameter momentum buffers, one slot per entry in `params`;
    /// each slot is initialized lazily on the first `step` that uses momentum.
    momentum_buffers: Vec<Option<Tensor<f32>>>,
}
63
64impl SGD {
65 #[must_use]
67 pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
68 let num_params = params.len();
69 Self {
70 params,
71 lr,
72 momentum: 0.0,
73 weight_decay: 0.0,
74 nesterov: false,
75 dampening: 0.0,
76 momentum_buffers: vec![None; num_params],
77 }
78 }
79
80 #[must_use]
82 pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
83 let num_params = params.len();
84 Self {
85 params,
86 lr,
87 momentum,
88 weight_decay: 0.0,
89 nesterov: false,
90 dampening: 0.0,
91 momentum_buffers: vec![None; num_params],
92 }
93 }
94
95 #[must_use]
97 pub fn with_options(
98 params: Vec<Parameter>,
99 lr: f32,
100 momentum: f32,
101 weight_decay: f32,
102 dampening: f32,
103 nesterov: bool,
104 ) -> Self {
105 let num_params = params.len();
106 Self {
107 params,
108 lr,
109 momentum,
110 weight_decay,
111 nesterov,
112 dampening,
113 momentum_buffers: vec![None; num_params],
114 }
115 }
116
117 #[must_use]
119 pub fn momentum(mut self, momentum: f32) -> Self {
120 self.momentum = momentum;
121 self
122 }
123
124 #[must_use]
126 pub fn weight_decay(mut self, weight_decay: f32) -> Self {
127 self.weight_decay = weight_decay;
128 self
129 }
130
131 #[must_use]
133 pub fn nesterov(mut self, nesterov: bool) -> Self {
134 self.nesterov = nesterov;
135 self
136 }
137
138 #[must_use]
140 pub fn dampening(mut self, dampening: f32) -> Self {
141 self.dampening = dampening;
142 self
143 }
144}
145
146impl Optimizer for SGD {
147 fn step(&mut self) {
148 for (i, param) in self.params.iter().enumerate() {
149 if !param.requires_grad() {
150 continue;
151 }
152
153 let grad = match param.grad() {
154 Some(g) => g,
155 None => continue,
156 };
157
158 let param_data = param.data();
159
160 let d = if self.weight_decay == 0.0 {
168 grad.clone()
169 } else {
170 grad.add(¶m_data.mul_scalar(self.weight_decay)).unwrap()
171 };
172
173 let update_dir = if self.momentum == 0.0 {
175 d
176 } else {
177 let buf = &mut self.momentum_buffers[i];
178
179 if buf.is_none() {
180 *buf = Some(d.clone());
182 } else {
183 let old = buf.as_ref().unwrap();
185 let new_buf = old
186 .mul_scalar(self.momentum)
187 .add(&d.mul_scalar(1.0 - self.dampening))
188 .unwrap();
189 *buf = Some(new_buf);
190 }
191
192 let buf_ref = buf.as_ref().unwrap();
193
194 if self.nesterov {
195 d.add(&buf_ref.mul_scalar(self.momentum)).unwrap()
197 } else {
198 buf_ref.clone()
199 }
200 };
201
202 let new_param = param_data.sub(&update_dir.mul_scalar(self.lr)).unwrap();
204 param.update_data(new_param);
205 }
206 }
207
208 fn zero_grad(&mut self) {
209 for param in &self.params {
210 param.zero_grad();
211 }
212 }
213
214 fn get_lr(&self) -> f32 {
215 self.lr
216 }
217
218 fn set_lr(&mut self, lr: f32) {
219 self.lr = lr;
220 }
221
222 fn parameters(&self) -> &[Parameter] {
223 &self.params
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    /// Builds a single trainable parameter holding `[1.0, 2.0, 3.0]`.
    fn make_param() -> Parameter {
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed");
        Parameter::from_variable(Variable::new(tensor, true))
    }

    #[test]
    fn test_sgd_creation() {
        let optimizer = SGD::new(vec![make_param()], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let optimizer = SGD::with_momentum(vec![make_param()], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let param = make_param();
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        // Expected: x - lr * grad = [1 - 0.01, 2 - 0.02, 3 - 0.03].
        let updated = param.data().to_vec();
        for (value, expected) in updated.iter().zip([0.99_f32, 1.98, 2.97]) {
            assert!((value - expected).abs() < 1e-5);
        }
    }

    #[test]
    fn test_sgd_zero_grad() {
        let param = make_param();
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // After zero_grad, any remaining gradient must be all zeros.
        if let Some(g) = param.grad() {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let optimizer = SGD::new(vec![make_param()], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
}