//! Stochastic Gradient Descent (SGD) optimizer, with support for momentum,
//! weight decay, dampening, and Nesterov acceleration.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::{Optimizer, ParamState};

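/// Stochastic Gradient Descent optimizer.
///
/// Performs the update `p = p - lr * g`, optionally augmented with momentum,
/// weight decay (L2 penalty), dampening, and Nesterov acceleration.
///
/// # Example
///
/// ```ignore
/// // Sketch only: assumes this crate is imported as `axonml_optim`;
/// // adjust the path to match your workspace.
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::SGD;
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = SGD::new(vec![param], 0.01).momentum(0.9).nesterov(true);
/// optimizer.step();
/// optimizer.zero_grad();
/// ```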
pub struct SGD {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor (`0.0` disables momentum).
    momentum: f32,
    /// Weight decay (L2 penalty) coefficient.
    weight_decay: f32,
    /// Whether to apply Nesterov momentum.
    nesterov: bool,
    /// Dampening applied to new gradients in the momentum update.
    dampening: f32,
    /// Per-parameter state (momentum buffers), indexed like `params`.
    state: Vec<ParamState>,
}

impl SGD {
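    /// Creates an SGD optimizer with the given parameters and learning rate.
    ///
    /// Momentum, weight decay, and dampening default to `0.0`;
    /// Nesterov momentum is disabled.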
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

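    /// Creates an SGD optimizer with momentum enabled.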
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

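    /// Creates an SGD optimizer with every hyperparameter specified explicitly.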
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            state: vec![ParamState::new(); num_params],
        }
    }

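    /// Sets the momentum factor (builder-style).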
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

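    /// Sets the weight decay (L2 penalty) coefficient (builder-style).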
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

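    /// Enables or disables Nesterov momentum (builder-style).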
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

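    /// Sets the dampening factor for the momentum update (builder-style).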
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

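// `step` applies the standard SGD update (PyTorch-style ordering):
//
//     g   = g + weight_decay * p                      (if weight_decay != 0)
//     buf = momentum * buf + (1 - dampening) * g      (if momentum != 0; the
//           buffer is seeded with g on the first step)
//     g   = g + momentum * buf  if nesterov, else  g = buf
//     p   = p - lr * g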
impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();
            let mut grad_vec = grad.to_vec();

            if self.weight_decay != 0.0 {
                // Apply weight decay: g = g + weight_decay * p.
                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            if self.momentum != 0.0 {
                let state = &mut self.state[i];

                if state.momentum_buffer.is_none() {
                    // First step: seed the buffer with the (decayed) gradient.
                    state.init_momentum(grad_vec.len());
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    buf.copy_from_slice(&grad_vec);
                } else {
                    // buf = momentum * buf + (1 - dampening) * g
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    for (b, g) in buf.iter_mut().zip(grad_vec.iter()) {
                        *b = self.momentum * *b + (1.0 - self.dampening) * *g;
                    }
                }

                let buf = state.momentum_buffer.as_ref().unwrap();

                if self.nesterov {
                    // Nesterov: step along g + momentum * buf.
                    for (g, b) in grad_vec.iter_mut().zip(buf.iter()) {
                        *g += self.momentum * *b;
                    }
                } else {
                    // Standard momentum: the buffer is the update direction.
                    grad_vec.copy_from_slice(buf);
                }
            }

            // Gradient descent step: p = p - lr * g.
            let lr = self.lr;
            for (p, g) in param_vec.iter_mut().zip(grad_vec.iter()) {
                *p -= lr * g;
            }

            // Write the updated values back, preserving device placement.
            let mut update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            let device = param_data.device();
            if device.is_gpu() {
                update = update.to_device(device).unwrap();
            }
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        // Expect p - lr * g: 1.0 - 0.1 * 0.1, 2.0 - 0.1 * 0.2, 3.0 - 0.1 * 0.3.
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }

270
271 #[test]
272 fn test_sgd_zero_grad() {
273 let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
274 let param = Parameter::from_variable(var);
275
276 param
278 .variable()
279 .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
280
281 let mut optimizer = SGD::new(vec![param.clone()], 0.1);
282
283 assert!(param.grad().is_some());
285
286 optimizer.zero_grad();
287
288 let grad = param.grad();
290 if let Some(g) = grad {
291 assert!(g.to_vec().iter().all(|&x| x == 0.0));
292 }
293 }
294
295 #[test]
296 fn test_sgd_builder_pattern() {
297 let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
298 let param = Parameter::from_variable(var);
299
300 let optimizer = SGD::new(vec![param], 0.01)
301 .momentum(0.9)
302 .weight_decay(0.0001)
303 .nesterov(true);
304
305 assert!((optimizer.momentum - 0.9).abs() < 1e-6);
306 assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
307 assert!(optimizer.nesterov);
308 }
309}