// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only
use zyx::Tensor;
use zyx_derive::Module;
/// # Adaptive momentum estimation optimizer
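///
/// Adam keeps exponential moving averages of the gradient (`m`) and of the
/// squared gradient (`v`), corrects both for initialization bias, and then
/// steps each parameter against the bias-corrected ratio:
///
/// ```text
/// m     = beta1 * m + (1 - beta1) * grad
/// v     = beta2 * v + (1 - beta2) * grad * grad
/// m_hat = m / (1 - beta1^t)
/// v_hat = v / (1 - beta2^t)
/// param = param - learning_rate * m_hat / (sqrt(v_hat) + eps)
/// ```
///
/// With `amsgrad` enabled, the running element-wise maximum of `v_hat` is
/// used in the denominator instead of `v_hat` itself.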
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Adam {
    /// learning rate (default: 1e-3)
    pub learning_rate: f32,
    /// coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999))
    pub betas: (f32, f32),
    /// term added to the denominator to improve numerical stability (default: 1e-8)
    pub eps: f32,
    /// weight decay (L2 penalty) (default: 0)
    pub weight_decay: f32,
    /// whether to use the AMSGrad variant of this algorithm from the paper
    /// "On the Convergence of Adam and Beyond" (default: false)
    pub amsgrad: bool,
    /// first moment estimates (running averages of gradients), one per parameter
    pub m: Vec<Tensor>,
    /// second moment estimates (running averages of squared gradients), one per parameter
    pub v: Vec<Tensor>,
    /// running element-wise maxima of the bias-corrected second moments (AMSGrad state)
    pub vm: Vec<Tensor>,
    /// number of update steps taken so far, used for bias correction
    pub t: usize,
}

impl Default for Adam {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            m: Vec::new(),
            v: Vec::new(),
            vm: Vec::new(),
            t: 0,
        }
    }
}

impl Adam {
    /// Updates parameters with gradients.
    /// The number of parameters must be the same as the number of gradients.
    /// Gradients can be `None`; those are simply skipped.
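    ///
    /// A minimal usage sketch (not compiled as a doctest); how `params` and
    /// `grads` are produced is up to the caller - the only assumption is one
    /// `Option<Tensor>` gradient per parameter, in matching order:
    ///
    /// ```ignore
    /// let mut optim = Adam { learning_rate: 1e-3, ..Default::default() };
    /// // params: Vec<Tensor> (trainable tensors), grads: Vec<Option<Tensor>>
    /// optim.update(params.iter_mut(), grads);
    /// ```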
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        use zyx::Scalar;
        //let params: Vec<&mut Tensor> = parameters.into_iter().collect();
        //let grads: Vec<Option<Tensor>> = gradients.into_iter().collect();
        /*assert_eq!(
            params.len(),
            grads.len(),
            "Number of parameters != number of gradients."
        );*/
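        // t counts update steps; it drives the bias-correction terms below.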
        self.t += 1;
        for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
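            // Parameters without a gradient are skipped, but zeroed state is
            // still pushed so the optimizer state stays index-aligned with the parameters.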
            let Some(mut grad) = grad else {
                if self.m.len() <= i {
                    self.m.push(Tensor::zeros_like(&*param));
                    self.v.push(Tensor::zeros_like(&*param));
                    // Also keep vm aligned, otherwise AMSGrad indexing drifts
                    // when some parameters have no gradient.
                    self.vm.push(Tensor::zeros_like(&*param));
                }
                continue;
            };
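            // Classic (non-decoupled) L2 penalty: fold weight decay into the gradient.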
            if self.weight_decay != 0.0 {
                grad = grad + &*param * self.weight_decay;
            }
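            // Biased first moment estimate: m = beta1 * m + (1 - beta1) * grad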
            if let Some(m) = self.m.get_mut(i) {
                *m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
            } else {
                self.m.push(&grad * (1.0 - self.betas.0));
            }
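            // Biased second moment estimate: v = beta2 * v + (1 - beta2) * grad * grad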
            if let Some(v) = self.v.get_mut(i) {
                *v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
            } else {
                self.v.push(&grad * &grad * (1.0 - self.betas.1));
            }
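            // Bias-corrected estimates: m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t)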
            let mh = &self.m[i] / (1.0 - self.betas.0.pow(self.t as f32));
            let vh = &self.v[i] / (1.0 - self.betas.1.pow(self.t as f32));
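            // AMSGrad keeps the element-wise maximum of every v_hat seen so far
            // and uses it in the denominator instead of the current v_hat.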
            if self.amsgrad {
                if let Some(vm) = self.vm.get_mut(i) {
                    *vm = vm.cmplt(&vh).unwrap().where_(vh, &*vm).unwrap();
                } else {
                    self.vm.push(vh);
                }
                *param = (&*param - self.learning_rate * mh / (self.vm[i].sqrt() + self.eps))
                    .cast(param.dtype());
            } else {
                *param = (&*param - self.learning_rate * mh / (vh.sqrt() + self.eps))
                    .cast(param.dtype());
            }
        }
    }
}