
zyx_optim/adam.rs

// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use zyx::Tensor;
use zyx_derive::Module;

/// # Adaptive moment estimation (Adam) optimizer
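///
/// A minimal usage sketch (not compiled as a doctest); it assumes the optimizer
/// is reachable as `zyx_optim::Adam` and that `params: Vec<Tensor>` and
/// `grads: Vec<Option<Tensor>>` come from your model and backward pass:
///
/// ```ignore
/// use zyx_optim::Adam;
///
/// // Fields not listed fall back to the defaults below.
/// let mut optim = Adam {
///     learning_rate: 1e-3,
///     amsgrad: true,
///     ..Adam::default()
/// };
///
/// // One optimization step: pairs each parameter with its (optional) gradient.
/// optim.update(params.iter_mut(), grads);
/// ```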
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Adam {
    /// learning rate (default: 1e-3)
    pub learning_rate: f32,
    /// coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
    pub betas: (f32, f32),
    /// term added to the denominator to improve numerical stability (default: 1e-8)
    pub eps: f32,
    /// weight decay (L2 penalty) (default: 0)
    pub weight_decay: f32,
    /// whether to use the AMSGrad variant of this algorithm, from the paper *On the Convergence of Adam and Beyond* (default: false)
    pub amsgrad: bool,
    /// First moment estimates (exponential moving average of gradients), one per parameter
    pub m: Vec<Tensor>,
    /// Second moment estimates (exponential moving average of squared gradients), one per parameter
    pub v: Vec<Tensor>,
    /// Running elementwise maximum of the bias-corrected second moment estimates (only used when `amsgrad` is true)
    pub vm: Vec<Tensor>,
    /// Step counter, incremented once per call to [`Adam::update`]
    pub t: usize,
}

impl Default for Adam {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            m: Vec::new(),
            v: Vec::new(),
            vm: Vec::new(),
            t: 0,
        }
    }
}

impl Adam {
    /// Updates parameters with gradients.
    /// The number of parameters must be the same as the number of gradients.
    /// Gradients can be `None`; those are simply skipped.
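    ///
    /// One call performs the bias-corrected Adam step implemented below, written
    /// here as pseudocode with `g` the gradient, `b1`/`b2` the betas and `lr` the
    /// learning rate:
    ///
    /// ```text
    /// m     <- b1 * m + (1 - b1) * g
    /// v     <- b2 * v + (1 - b2) * g^2
    /// m_hat  = m / (1 - b1^t)
    /// v_hat  = v / (1 - b2^t)
    /// theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
    /// ```
    ///
    /// With `amsgrad`, `v_hat` in the last line is replaced by a running
    /// elementwise maximum of all past `v_hat` values.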
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        use zyx::Scalar;

        self.t += 1;

        for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
            let Some(mut grad) = grad else {
                // No gradient for this parameter this step: still push zeroed
                // state so that m/v/vm stay index-aligned with the parameters.
                if self.m.len() <= i {
                    self.m.push(Tensor::zeros_like(&*param));
                    self.v.push(Tensor::zeros_like(&*param));
                    if self.amsgrad {
                        self.vm.push(Tensor::zeros_like(&*param));
                    }
                }
                continue;
            };
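            // Coupled L2 weight decay as in the original Adam formulation:
            // fold `weight_decay * param` into the gradient (this is not
            // AdamW-style decoupled weight decay).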
            if self.weight_decay != 0.0 {
                grad = grad + &*param * self.weight_decay;
            }
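            // First moment: exponential moving average of the gradient,
            // m <- betas.0 * m + (1 - betas.0) * grad.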
            if let Some(m) = self.m.get_mut(i) {
                *m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
            } else {
                self.m.push(&grad * (1.0 - self.betas.0));
            }
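            // Second moment: exponential moving average of the squared gradient,
            // v <- betas.1 * v + (1 - betas.1) * grad^2.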
            if let Some(v) = self.v.get_mut(i) {
                *v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
            } else {
                self.v.push(&grad * &grad * (1.0 - self.betas.1));
            }
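            // Bias-corrected estimates: the running averages start at zero, so
            // early steps are scaled up by 1 / (1 - beta^t).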
            let mh = &self.m[i] / (1.0 - self.betas.0.pow(self.t as f32));
            let vh = &self.v[i] / (1.0 - self.betas.1.pow(self.t as f32));
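            // Parameter update: param <- param - lr * mh / (sqrt(vh) + eps).
            // With AMSGrad the denominator uses the running elementwise maximum
            // of all bias-corrected second moments seen so far instead of `vh`.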
            if self.amsgrad {
                if let Some(vm) = self.vm.get_mut(i) {
                    *vm = vm.cmplt(&vh).unwrap().where_(vh, &*vm).unwrap();
                } else {
                    self.vm.push(vh);
                }
                *param = (&*param - self.learning_rate * mh / (self.vm[i].sqrt() + self.eps))
                    .cast(param.dtype());
            } else {
                *param = (&*param - self.learning_rate * mh / (vh.sqrt() + self.eps))
                    .cast(param.dtype());
            }
        }
    }
}