use zyx::Tensor;
use zyx_derive::Module;

#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Adam {
    /// Learning rate
    pub learning_rate: f32,
    /// Coefficients for the running averages of the gradient and its square
    pub betas: (f32, f32),
    /// Term added to the denominator for numerical stability
    pub eps: f32,
    /// Weight decay (L2 penalty)
    pub weight_decay: f32,
    /// Whether to use the AMSGrad variant
    pub amsgrad: bool,
    /// First moment estimates, one per parameter
    pub m: Vec<Tensor>,
    /// Second moment estimates, one per parameter
    pub v: Vec<Tensor>,
    /// Maxima of second moment estimates (AMSGrad only)
    pub vm: Vec<Tensor>,
    /// Step counter
    pub t: usize,
}

impl Default for Adam {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            m: Vec::new(),
            v: Vec::new(),
            vm: Vec::new(),
            t: 0,
        }
    }
}

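// The update below follows Adam (Kingma & Ba, 2015):
//   m_t = b1 * m_{t-1} + (1 - b1) * g_t
//   v_t = b2 * v_{t-1} + (1 - b2) * g_t^2
//   m_hat = m_t / (1 - b1^t),  v_hat = v_t / (1 - b2^t)
//   param -= lr * m_hat / (sqrt(v_hat) + eps)
// With AMSGrad, sqrt(v_hat) is replaced by the square root of the
// elementwise maximum of all v_hat seen so far.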
impl Adam {
    /// Updates `parameters` with `gradients`, advancing the step counter.
    /// Parameters whose gradient is `None` are skipped.
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        use zyx::Scalar;
        self.t += 1;

        for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
            let Some(mut grad) = grad else {
                // No gradient for this parameter: push zero-initialized moment
                // buffers so indices stay aligned with the parameter order.
                if self.m.len() <= i {
                    self.m.push(Tensor::zeros_like(&*param));
                    self.v.push(Tensor::zeros_like(&*param));
                    if self.amsgrad {
                        self.vm.push(Tensor::zeros_like(&*param));
                    }
                }
                continue;
            };
            // Classic (non-decoupled) L2 penalty folded into the gradient.
            if self.weight_decay != 0.0 {
                grad = grad + &*param * self.weight_decay;
            }
            // First moment: exponential moving average of the gradient.
            if let Some(m) = self.m.get_mut(i) {
                *m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
            } else {
                self.m.push(&grad * (1.0 - self.betas.0));
            }
            // Second moment: exponential moving average of the squared gradient.
            if let Some(v) = self.v.get_mut(i) {
                *v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
            } else {
                self.v.push(&grad * &grad * (1.0 - self.betas.1));
            }
            // Bias-corrected moment estimates.
            let mh = &self.m[i] / (1.0 - self.betas.0.pow(self.t as f32));
            let vh = &self.v[i] / (1.0 - self.betas.1.pow(self.t as f32));
            if self.amsgrad {
                // Keep the elementwise maximum of all v_hat seen so far.
                if let Some(vm) = self.vm.get_mut(i) {
                    *vm = vm.cmplt(&vh).unwrap().where_(vh, &*vm).unwrap();
                } else {
                    self.vm.push(vh);
                }
                *param = (&*param - self.learning_rate * mh / (self.vm[i].sqrt() + self.eps))
                    .cast(param.dtype());
            } else {
                *param = (&*param - self.learning_rate * mh / (vh.sqrt() + self.eps))
                    .cast(param.dtype());
            }
        }
    }
}
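
// A minimal usage sketch, not part of the original file. It assumes
// `Tensor::from` accepts nested arrays (check the zyx docs for the exact
// constructors); the gradient here is built by hand, where
// `Tensor::backward` would normally supply it.
#[cfg(test)]
mod tests {
    use super::Adam;
    use zyx::Tensor;

    #[test]
    fn single_adam_step() {
        let mut w = Tensor::from([[1.0f32, 2.0], [3.0, 4.0]]);
        // Hand-written gradient of L = sum(w^2), i.e. 2 * w.
        let g = &w * 2.0f32;
        let mut opt = Adam::default();
        opt.update([&mut w], [Some(g)]);
        // One step taken, one set of moment buffers allocated.
        assert_eq!(opt.t, 1);
        assert_eq!(opt.m.len(), 1);
    }
}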