use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::{Optimizer, ParamState};

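/// Stochastic gradient descent optimizer with optional momentum, weight decay
/// (L2 regularization), dampening, and Nesterov momentum.
///
/// When weight decay is enabled the gradient is first adjusted to
/// `g <- g + weight_decay * p`. With momentum enabled, each step then updates a
/// per-parameter buffer `b <- momentum * b + (1 - dampening) * g` (the buffer is
/// initialized with the raw gradient on the first step) and applies
/// `p <- p - lr * b`, or `p <- p - lr * (g + momentum * b)` when Nesterov
/// momentum is on. Without momentum the update is plain `p <- p - lr * g`.
///
/// # Example
///
/// A minimal usage sketch (illustrative only; the surrounding training loop and
/// the crate paths for `Variable`, `Parameter`, and `Tensor` are assumed from
/// this file's imports and tests):
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = SGD::new(vec![param], 0.01).momentum(0.9).weight_decay(1e-4);
/// // ... forward pass, loss computation, backward pass ...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```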
pub struct SGD {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor (0.0 disables momentum).
    momentum: f32,
    /// Weight decay (L2 penalty) coefficient.
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Dampening factor applied to the gradient term of the momentum update.
    dampening: f32,
    /// Per-parameter optimizer state (momentum buffers).
    state: Vec<ParamState>,
}

impl SGD {
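    /// Creates a plain SGD optimizer with the given learning rate and no
    /// momentum, weight decay, or dampening.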
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

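    /// Creates an SGD optimizer with classical (heavy-ball) momentum.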
    #[must_use]
    pub fn with_momentum(params: Vec<Parameter>, lr: f32, momentum: f32) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            dampening: 0.0,
            state: vec![ParamState::new(); num_params],
        }
    }

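    /// Creates an SGD optimizer with every hyperparameter set explicitly.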
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        momentum: f32,
        weight_decay: f32,
        dampening: f32,
        nesterov: bool,
    ) -> Self {
        let num_params = params.len();
        Self {
            params,
            lr,
            momentum,
            weight_decay,
            nesterov,
            dampening,
            state: vec![ParamState::new(); num_params],
        }
    }

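    /// Sets the momentum factor (builder style).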
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

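    /// Sets the weight decay (L2 penalty) coefficient (builder style).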
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

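    /// Enables or disables Nesterov momentum (builder style).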
    #[must_use]
    pub fn nesterov(mut self, nesterov: bool) -> Self {
        self.nesterov = nesterov;
        self
    }

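    /// Sets the dampening factor applied to the gradient term of the momentum
    /// update (builder style).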
    #[must_use]
    pub fn dampening(mut self, dampening: f32) -> Self {
        self.dampening = dampening;
        self
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            // Skip parameters that have no gradient yet.
            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let mut grad_vec = grad.to_vec();

            // Weight decay: g <- g + weight_decay * p.
            if self.weight_decay != 0.0 {
                let param_vec = param.data().to_vec();
                for (g, p) in grad_vec.iter_mut().zip(param_vec.iter()) {
                    *g += self.weight_decay * p;
                }
            }

            if self.momentum != 0.0 {
                let state = &mut self.state[i];

                if state.momentum_buffer.is_none() {
                    // First step: initialize the buffer with the current gradient.
                    state.init_momentum(grad_vec.len());
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    buf.copy_from_slice(&grad_vec);
                } else {
                    // Subsequent steps: b <- momentum * b + (1 - dampening) * g.
                    let buf = state.momentum_buffer.as_mut().unwrap();
                    for (b, g) in buf.iter_mut().zip(grad_vec.iter()) {
                        *b = self.momentum * *b + (1.0 - self.dampening) * *g;
                    }
                }

                let buf = state.momentum_buffer.as_ref().unwrap();

                if self.nesterov {
                    // Nesterov momentum: g <- g + momentum * b.
                    let nesterov_grad: Vec<f32> = buf
                        .iter()
                        .zip(grad_vec.iter())
                        .map(|(b, g)| self.momentum * *b + *g)
                        .collect();
                    grad_vec = nesterov_grad;
                } else {
                    // Plain momentum: the buffer itself is the update direction.
                    grad_vec = buf.clone();
                }
            }

            // Gradient descent update: p <- p - lr * g.
            let param_data = param.data();
            let param_vec = param_data.to_vec();
            let new_data: Vec<f32> = param_vec
                .iter()
                .zip(grad_vec.iter())
                .map(|(p, g)| p - self.lr * g)
                .collect();

            let update = Tensor::from_vec(new_data, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_sgd_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::new(vec![param], 0.01);

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);
        assert_eq!(optimizer.num_parameters(), 1);
    }

    #[test]
    fn test_sgd_with_momentum() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = SGD::with_momentum(vec![param], 0.01, 0.9);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);
        optimizer.step();

        // Expected: p - lr * g = [1.0 - 0.01, 2.0 - 0.02, 3.0 - 0.03].
        let new_data = param.data().to_vec();
        assert!((new_data[0] - 0.99).abs() < 1e-5);
        assert!((new_data[1] - 1.98).abs() < 1e-5);
        assert!((new_data[2] - 2.97).abs() < 1e-5);
    }

    #[test]
    fn test_sgd_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = SGD::new(vec![param.clone()], 0.1);

        assert!(param.grad().is_some());

        optimizer.zero_grad();

        // After zero_grad the gradient is either cleared or all zeros.
        let grad = param.grad();
        if let Some(g) = grad {
            assert!(g.to_vec().iter().all(|&x| x == 0.0));
        }
    }

    #[test]
    fn test_sgd_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = SGD::new(vec![param], 0.01)
            .momentum(0.9)
            .weight_decay(0.0001)
            .nesterov(true);

        assert!((optimizer.momentum - 0.9).abs() < 1e-6);
        assert!((optimizer.weight_decay - 0.0001).abs() < 1e-6);
        assert!(optimizer.nesterov);
    }
}