zyx_optim/adamw.rs

// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use zyx::Tensor;
use zyx_derive::Module;

/// # Adaptive momentum estimation optimizer with decoupled weight decay (AdamW)
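///
/// A minimal construction sketch: override only the fields you need and take the
/// rest from [`Default`].
///
/// ```ignore
/// let mut optim = AdamW {
///     learning_rate: 3e-4,
///     weight_decay: 0.01,
///     ..AdamW::default()
/// };
/// ```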
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct AdamW {
    /// learning rate (default: 1e-3)
    pub learning_rate: f32,
    /// coefficients used for computing running averages of the gradient and its square (default: (0.9, 0.999))
    pub betas: (f32, f32),
    /// term added to the denominator to improve numerical stability (default: 1e-8)
    pub eps: f32,
    /// weight decay, applied in decoupled form as in AdamW (default: 0)
    pub weight_decay: f32,
    /// whether to use the AMSGrad variant of this algorithm, from the paper "On the Convergence of Adam and Beyond" (default: false)
    pub amsgrad: bool,
    /// first moment estimates (running averages of gradients), one per parameter
    pub m: Vec<Tensor>,
    /// second moment estimates (running averages of squared gradients), one per parameter
    pub v: Vec<Tensor>,
    /// element-wise running maximum of the bias-corrected second moments, used by AMSGrad
    pub vm: Vec<Tensor>,
    /// number of update steps taken so far, used for bias correction
    pub t: usize,
}

impl Default for AdamW {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            m: Vec::new(),
            v: Vec::new(),
            vm: Vec::new(),
            t: 0,
        }
    }
}

impl AdamW {
    /// Updates parameters with their gradients.
    /// The number of parameters must match the number of gradients.
    /// A gradient may be `None`; the corresponding parameter is simply skipped.
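    ///
    /// In terms of the fields above, one call performs the following for each
    /// parameter `param` with gradient `g`:
    ///
    /// ```text
    /// m     <- betas.0 * m + (1 - betas.0) * g
    /// v     <- betas.1 * v + (1 - betas.1) * g^2
    /// m_hat  = m / (1 - betas.0^t)
    /// v_hat  = v / (1 - betas.1^t)     (amsgrad tracks the running max of v_hat)
    /// param <- param - learning_rate * m_hat / (sqrt(v_hat) + eps)
    /// param <- param * (1 - learning_rate * weight_decay)
    /// ```
    ///
    /// A minimal call sketch, assuming `params: Vec<Tensor>` and a precomputed
    /// `grads: Vec<Option<Tensor>>` in matching order (these names are illustrative,
    /// not part of this crate):
    ///
    /// ```ignore
    /// optim.update(params.iter_mut(), grads);
    /// ```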
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        use zyx::Scalar;
        self.t += 1;
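        // Moment buffers (m, v, vm) are matched to parameters by position, so the
        // caller must pass parameters in the same order on every call.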
        for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
            let Some(grad) = grad else {
                // Initialize moment estimates for new params (lazy)
                if self.m.len() <= i {
                    self.m.push(Tensor::zeros_like(&*param));
                    self.v.push(Tensor::zeros_like(&*param));
                    if self.amsgrad {
                        self.vm.push(Tensor::zeros_like(&*param));
                    }
                }
                continue;
            };

            // Update biased first moment estimate
            if let Some(m) = self.m.get_mut(i) {
                *m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
            } else {
                self.m.push(&grad * (1.0 - self.betas.0));
            }

            // Update biased second moment estimate
            if let Some(v) = self.v.get_mut(i) {
                *v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
            } else {
                self.v.push(&grad * &grad * (1.0 - self.betas.1));
            }

            // Compute bias-corrected moments; dividing by (1 - beta^t) compensates
            // for initializing m and v at zero
            let mh = &self.m[i] / (1.0 - self.betas.0.pow(self.t as f32));
            let vh = &self.v[i] / (1.0 - self.betas.1.pow(self.t as f32));

            if self.amsgrad {
                if let Some(vm) = self.vm.get_mut(i) {
                    // Keep the element-wise maximum of the bias-corrected second moment
                    *vm = vm.cmplt(&vh).unwrap().where_(vh, &*vm).unwrap();
                } else {
                    self.vm.push(vh);
                }
                // Parameter update with AMSGrad max
                *param = (&*param - mh * self.learning_rate / (self.vm[i].sqrt() + self.eps))
                    .cast(param.dtype());
            } else {
                // Standard AdamW parameter update
                *param = (&*param - mh * self.learning_rate / (vh.sqrt() + self.eps))
                    .cast(param.dtype());
            }

            // Decoupled weight decay step applied directly to parameter
            if self.weight_decay != 0.0 {
                *param = &*param * (1.0 - self.learning_rate * self.weight_decay);
            }
        }
    }
}