use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::{Optimizer, ParamState};

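/// Stochastic gradient descent (SGD) optimizer with optional momentum,
/// Nesterov acceleration, weight decay (L2 penalty), and dampening.
///
/// # Examples
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file;
/// the `axonml_optim` crate path is an assumption.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{Optimizer, SGD};
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// let mut optimizer = SGD::new(vec![param], 0.01);
/// // ... run a forward/backward pass to populate gradients ...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```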
pub struct SGD {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor; `0.0` disables momentum.
    momentum: f32,
    /// Weight decay (L2 penalty) coefficient.
    weight_decay: f32,
    /// Whether to apply Nesterov momentum.
    nesterov: bool,
    /// Dampening applied to new gradients in the momentum update.
    dampening: f32,
    /// Per-parameter state (momentum buffers).
    state: Vec<ParamState>,
}

impl SGD {
    /// Creates a plain SGD optimizer: no momentum, no weight decay,
    /// no dampening, and no Nesterov acceleration.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

    /// Creates an SGD optimizer with classical momentum.
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

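    /// Creates a fully configured SGD optimizer.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the `axonml_optim` crate path is an assumption.
    ///
    /// ```ignore
    /// # use axonml_autograd::Variable;
    /// # use axonml_nn::Parameter;
    /// # use axonml_optim::SGD;
    /// # use axonml_tensor::Tensor;
    /// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
    /// let param = Parameter::from_variable(var);
    /// // lr = 0.01, momentum = 0.9, weight_decay = 1e-4, dampening = 0.0, nesterov = true
    /// let optimizer = SGD::with_options(vec![param], 0.01, 0.9, 1e-4, 0.0, true);
    /// ```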
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            state: vec![ParamState::new(); num_params],
        }
    }

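    /// Sets the momentum factor (builder style).
    ///
    /// # Examples
    ///
    /// A minimal sketch of builder-style configuration; the
    /// `axonml_optim` crate path is an assumption.
    ///
    /// ```ignore
    /// # use axonml_autograd::Variable;
    /// # use axonml_nn::Parameter;
    /// # use axonml_optim::SGD;
    /// # use axonml_tensor::Tensor;
    /// let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
    /// let param = Parameter::from_variable(var);
    /// let optimizer = SGD::new(vec![param], 0.01)
    ///     .momentum(0.9)
    ///     .weight_decay(1e-4)
    ///     .nesterov(true);
    /// ```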
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Sets the weight decay (L2 penalty) coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables Nesterov momentum (builder style).
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

    /// Sets the dampening applied to new gradients in the momentum update
    /// (builder style).
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

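// `step` below follows the PyTorch-style SGD recipe. For parameter p,
// gradient g, momentum buffer b, momentum mu, dampening d, weight decay wd,
// and learning rate lr:
//
//   g <- g + wd * p                  (weight decay / L2 penalty)
//   b <- mu * b + (1 - d) * g        (after the first step; b starts as g)
//   g <- g + mu * b  if nesterov, else  g <- b
//   p <- p - lr * g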
impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            // Skip parameters that have no gradient yet.
            let Some(grad) = param.grad() else {
                continue;
            };

            let mut grad_vec = grad.to_vec();

            // Apply weight decay (L2 penalty): g += wd * p.
            if self.weight_decay != 0.0 {
                let param_vec = param.data().to_vec();
                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            if self.momentum != 0.0 {
                let state = &mut self.state[i];

                if state.momentum_buffer.is_none() {
                    // First step: initialize the buffer with the raw gradient
                    // (dampening is not applied on initialization).
                    state.init_momentum(grad_vec.len());
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    buf.copy_from_slice(&grad_vec);
                } else {
                    // b = mu * b + (1 - d) * g
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    for (b, g) in buf.iter_mut().zip(grad_vec.iter()) {
                        *b = self.momentum * *b + (1.0 - self.dampening) * *g;
                    }
                }

                let buf = state.momentum_buffer.as_ref().unwrap();

                if self.nesterov {
                    // Nesterov: g = g + mu * b
                    let nesterov_grad: Vec<f32> = buf
                        .iter()
                        .zip(grad_vec.iter())
                        .map(|(b, g)| self.momentum * *b + *g)
                        .collect();
                    grad_vec = nesterov_grad;
                } else {
                    grad_vec = buf.clone();
                }
            }

            // p = p - lr * g
            let param_data = param.data();
            let param_vec = param_data.to_vec();
            let new_data: Vec<f32> = param_vec
                .iter()
                .zip(grad_vec.iter())
                .map(|(p, g)| p - self.lr * g)
                .collect();

            let update = Tensor::from_vec(new_data, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        // p - lr * g: 1.0 - 0.1 * 0.1 = 0.99, and so on elementwise.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // zero_grad may clear the gradient entirely or reset it to zeros;
        // both outcomes pass here.
        let grad = param.grad();
        if let Some(g) = grad {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
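
    // A sketch of a two-step momentum check. It assumes, as the tests above
    // do, that `set_grad` overwrites any existing gradient.
    #[test]
    fn test_sgd_momentum_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut optimizer = SGD::with_momentum(vec![param.clone()], 0.1, 0.9);

        // Step 1: buffer = g = 0.1, so p = 1.0 - 0.1 * 0.1 = 0.99.
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1], &[1]).unwrap());
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.99).abs() < 1e-5);

        // Step 2: buffer = 0.9 * 0.1 + 0.1 = 0.19, so p = 0.99 - 0.1 * 0.19 = 0.971.
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1], &[1]).unwrap());
        optimizer.step();
        assert!((param.data().to_vec()[0] - 0.971).abs() < 1e-5);
    }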
}