use zyx::Tensor;
use zyx_derive::Module;

/// Stochastic gradient descent optimizer with optional momentum, dampening,
/// weight decay and Nesterov acceleration.
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct SGD {
    /// Step size applied to the (possibly momentum-adjusted) gradient.
    pub learning_rate: f32,
    /// Momentum factor; `0.0` disables momentum entirely.
    pub momentum: f32,
    /// Weight decay (L2 penalty) factor; `0.0` disables it.
    pub weight_decay: f32,
    /// Dampening applied to the gradient term of the momentum update.
    pub dampening: f32,
    /// If `true`, uses Nesterov momentum.
    pub nesterov: bool,
    /// If `true`, ascends the gradient (maximizes the objective).
    pub maximize: bool,
    /// Momentum buffers, stored positionally, one per parameter.
    pub bias: Vec<Tensor>,
}

impl Default for SGD {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            momentum: 0.0,
            weight_decay: 0.0,
            dampening: 0.0,
            nesterov: false,
            maximize: false,
            bias: Vec::new(),
        }
    }
}
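
// Construction sketch (illustrative, not part of the original source): since
// `SGD` is a plain struct with public fields, callers can override selected
// hyperparameters and take the rest from `Default` via struct update syntax:
//
//     let mut opt = SGD {
//         learning_rate: 0.01,
//         momentum: 0.9,
//         nesterov: true,
//         ..SGD::default()
//     };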

impl SGD {
    /// Performs a single optimization step over the given parameters.
    ///
    /// For each parameter `p` with gradient `g`, learning rate `lr`,
    /// momentum `m`, dampening `d` and weight decay `wd`:
    /// 1. `g = g + wd * p` if weight decay is enabled,
    /// 2. `b = m * b + (1 - d) * g` (momentum buffer, initialized to `g`),
    /// 3. `g = g + m * b` with Nesterov momentum, otherwise `g = b`,
    /// 4. `p = p - lr * g` (or `p = p + lr * g` when maximizing).
    ///
    /// Parameters whose gradient is `None` are skipped.
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        let params: Vec<&mut Tensor> = parameters.into_iter().collect();
        let grads: Vec<Option<Tensor>> = gradients.into_iter().collect();

        // Every parameter must be paired with a gradient slot, even if that
        // slot is `None`.
        assert_eq!(
            params.len(),
            grads.len(),
            "Number of parameters != number of gradients."
        );

        for (i, (param, grad)) in params.into_iter().zip(grads).enumerate() {
            if let Some(mut grad) = grad {
                // Fold the L2 penalty into the gradient.
                if self.weight_decay != 0.0 {
                    grad = grad + param.clone() * self.weight_decay;
                }
                if self.momentum != 0.0 {
                    // Update the momentum buffer, or initialize it with the
                    // raw gradient on the first step. Buffers are stored
                    // positionally, so this assumes the `Some`/`None` pattern
                    // of gradients stays the same across calls.
                    let buf = if let Some(bias) = self.bias.get_mut(i) {
                        *bias = bias.clone() * self.momentum
                            + grad.clone() * (1.0 - self.dampening);
                        bias.clone()
                    } else {
                        self.bias.push(grad.clone());
                        grad.clone()
                    };
                    grad = if self.nesterov {
                        // Nesterov momentum looks ahead along the buffer.
                        grad + buf * self.momentum
                    } else {
                        buf
                    };
                }
                // Take the step, casting back to the parameter's dtype since
                // arithmetic with `f32` scalars may promote the result.
                if self.maximize {
                    *param = (&*param + grad * self.learning_rate).cast(param.dtype());
                } else {
                    *param = (&*param - grad * self.learning_rate).cast(param.dtype());
                }
            }
        }
    }
}
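
// A minimal sketch of exercising `update` (assumption: producing real
// gradients depends on zyx's autograd API, which is not shown in this file,
// so this test only checks the call shape and the empty no-op case).
#[cfg(test)]
mod tests {
    use super::SGD;
    use zyx::Tensor;

    #[test]
    fn update_with_no_parameters_is_a_no_op() {
        let mut opt = SGD::default();
        // Zero parameters paired with zero gradients must not panic and
        // must not allocate any momentum buffers.
        opt.update(Vec::<&mut Tensor>::new(), Vec::<Option<Tensor>>::new());
        assert!(opt.bias.is_empty());
    }
}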