// zenu_optimizer/adam.rs

use std::{cell::RefCell, collections::HashMap, rc::Rc};

use zenu_autograd::{creator::zeros::zeros_like, Variable};
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};

use crate::Optimizer;

/// Adam optimizer.
///
/// `m` and `v` hold the exponential moving averages of each parameter's
/// gradient and squared gradient, keyed by the same names returned by
/// `Parameters::parameters`. `step` counts completed updates and drives the
/// bias correction of both moment estimates.
pub struct Adam<T: Num, D: Device> {
    learning_rate: T,
    /// Decay rate of the first-moment estimate (commonly 0.9).
    beta1: T,
    /// Decay rate of the second-moment estimate (commonly 0.999).
    beta2: T,
    /// Small constant added to the denominator for numerical stability.
    epsilon: T,
    /// Number of update steps taken so far.
    step: Rc<RefCell<usize>>,
    /// First-moment (mean) estimate for each parameter.
    pub m: HashMap<String, Variable<T, D>>,
    /// Second-moment (uncentered variance) estimate for each parameter.
    pub v: HashMap<String, Variable<T, D>>,
}

impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for Adam<T, D> {
    fn update(&self, parameters: &P) {
        // Advance the step counter; the bias-correction terms depend on it.
        *self.step.borrow_mut() += 1;
        let step = T::from_usize(*self.step.borrow());

        let beta1_t = self.beta1.powf(step);
        let beta2_t = self.beta2.powf(step);

        // Only parameters that received a gradient in the last backward pass
        // are updated.
        let parameters = parameters
            .parameters()
            .iter()
            .filter_map(|(key, value)| {
                value
                    .get_grad()
                    .map(|grad| (key.clone(), (value.clone(), grad.clone())))
            })
            .collect::<Vec<_>>();

        for (k, (data, grad)) in &parameters {
            let v = self.v.get(k).unwrap();
            let m = self.m.get(k).unwrap();
            let mut v = v.get_as_mut();
            let mut m = m.get_as_mut();
            let grad = grad.get_as_ref();

            // m = beta1 * m + (1 - beta1) * grad
            m *= self.beta1;
            m += grad.to_ref() * (T::one() - self.beta1);

            // v = beta2 * v + (1 - beta2) * grad^2
            v *= self.beta2;
            v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);

            // Bias-corrected moment estimates.
            let m_hat = m.clone() / (T::one() - beta1_t);
            let v_hat = v.clone() / (T::one() - beta2_t);

            // param -= lr * m_hat / (sqrt(v_hat) + epsilon)
            let m_v_hat = m_hat / (v_hat.sqrt() + self.epsilon);
            let lr_mv_hat = m_v_hat * self.learning_rate;

            data.get_as_mut().sub_assign(&lr_mv_hat.to_ref());
        }
    }
}
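
// For reference, `update` above implements the standard Adam step. The
// function below is a minimal scalar sketch of the same rule (the name
// `adam_step_scalar` and the plain-f32 signature are illustrative only and
// not part of the zenu API); it exists purely to spell out the math that the
// tensor code performs element-wise.
#[allow(dead_code)]
fn adam_step_scalar(
    param: &mut f32,
    grad: f32,
    m: &mut f32,
    v: &mut f32,
    t: i32, // 1-based step count
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
) {
    // Exponential moving averages of the gradient and its square.
    *m = beta1 * *m + (1.0 - beta1) * grad;
    *v = beta2 * *v + (1.0 - beta2) * grad * grad;
    // Bias correction: m and v start at zero, so early estimates are scaled up.
    let m_hat = *m / (1.0 - beta1.powi(t));
    let v_hat = *v / (1.0 - beta2.powi(t));
    // Descend along the bias-corrected, variance-normalized gradient.
    *param -= lr * m_hat / (v_hat.sqrt() + eps);
}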

impl<T: Num, D: Device> Adam<T, D> {
    /// Creates a new `Adam` optimizer, initializing the first- and
    /// second-moment buffers to zeros with the same shapes as the model's
    /// parameters.
    pub fn new(
        learning_rate: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        model: &impl Parameters<T, D>,
    ) -> Self {
        let m = model
            .parameters()
            .iter()
            .map(|(key, value)| (key.clone(), zeros_like(value)))
            .collect();
        let v = model
            .parameters()
            .iter()
            .map(|(key, value)| (key.clone(), zeros_like(value)))
            .collect();
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            step: Rc::new(RefCell::new(0)),
            m,
            v,
        }
    }
}
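
// A minimal usage sketch, assuming a model type that implements
// `Parameters<f32, D>` and an autograd backward pass that populates the
// parameter gradients (`MyModel`, `data`, `forward`, and `backward` below
// are placeholders, not items defined in this crate):
//
//     let model = MyModel::new();
//     let optimizer = Adam::new(1e-3, 0.9, 0.999, 1e-8, &model);
//     for batch in data {
//         let loss = model.forward(&batch);
//         loss.backward();          // fill in gradients
//         optimizer.update(&model); // one Adam step over all parameters
//     }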