// stochastic_optimizers/optimizers/adam.rs

use crate::{Parameters, Optimizer};
use num_traits::{Float, AsPrimitive};

/// Implements the Adam algorithm
#[derive(Debug)]
pub struct Adam<P : Parameters> {
    parameters : P,
    learning_rate : P::Scalar,
    beta1 : P::Scalar,
    beta2 : P::Scalar,
    epsilon : P::Scalar,
    /// Number of steps taken so far, stored as a scalar so it can be used in the bias corrections
    timestep : P::Scalar,
    /// First moment estimate (exponentially decaying average of past gradients)
    m0 : P,
    /// Second moment estimate (exponentially decaying average of past squared gradients)
    v0 : P
}

/// Builder for [`Adam`], see [`Adam::builder`](Adam::builder)
pub struct AdamBuilder<P : Parameters>(Adam<P>);

impl<Scalar, P : Parameters<Scalar = Scalar>> Adam<P>
where
    Scalar : Float + 'static,
    f64 : AsPrimitive<Scalar>
{
    /// Creates a new Adam optimizer for the given parameters and learning rate.
    /// It uses the default values beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
    /// If you want different values, use the [`builder`](Adam::builder) function.
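    ///
    /// A short usage sketch, assuming `f64` implements [`Parameters`] as in the
    /// [`builder`](Adam::builder) example below:
    /// ```
    /// use stochastic_optimizers::{Adam, Optimizer};
    ///
    /// // minimise (x - 4)^2 starting from x = -3
    /// let mut optimizer = Adam::new(-3.0f64, 0.1);
    /// for _ in 0..100 {
    ///     let x = *optimizer.parameters();
    ///     optimizer.step(&(2.0 * x - 8.0));
    /// }
    /// ```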
    pub fn new(parameters : P, learning_rate : Scalar) -> Adam<P> {
        let m0 = parameters.zeros();
        let v0 = parameters.zeros();
        Adam { parameters, learning_rate, beta1: 0.9.as_(), beta2: 0.999.as_(), epsilon: 1e-8.as_(), timestep: 0.0.as_(), m0, v0}
    }

    /// Creates a builder for Adam. It uses the default values learning_rate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
    /// These can be changed by calling methods on [`AdamBuilder`](AdamBuilder).
    /// ```
    /// use stochastic_optimizers::Adam;
    /// let init = 0.5;
    ///
    /// let optimizer = Adam::builder(init).learning_rate(0.1).beta2(0.99).build();
    /// ```
    pub fn builder(parameters : P) -> AdamBuilder<P> {
        let m0 = parameters.zeros();
        let v0 = parameters.zeros();
        let adam = Adam { parameters, learning_rate : 0.001.as_(), beta1: 0.9.as_(), beta2: 0.999.as_(), epsilon: 1e-8.as_(), timestep: 0.0.as_(), m0, v0};

        AdamBuilder(adam)
    }
}

impl<P : Parameters> AdamBuilder<P> {
    /// Sets the learning rate (default 0.001)
    pub fn learning_rate(mut self, learning_rate : P::Scalar) -> AdamBuilder<P> {
        self.0.learning_rate = learning_rate;
        self
    }

    /// Sets beta1, the decay rate of the first moment estimate (default 0.9)
    pub fn beta1(mut self, beta1 : P::Scalar) -> AdamBuilder<P> {
        self.0.beta1 = beta1;
        self
    }

    /// Sets beta2, the decay rate of the second moment estimate (default 0.999)
    pub fn beta2(mut self, beta2 : P::Scalar) -> AdamBuilder<P> {
        self.0.beta2 = beta2;
        self
    }

    /// Sets epsilon, the small constant added to the denominator for numerical stability (default 1e-8)
    pub fn epsilon(mut self, epsilon : P::Scalar) -> AdamBuilder<P> {
        self.0.epsilon = epsilon;
        self
    }

    /// Consumes the builder and returns the configured [`Adam`] optimizer
    pub fn build(self) -> Adam<P> {
        self.0
    }
}

impl<Scalar, P : Parameters<Scalar = Scalar>> Optimizer for Adam<P>
where
    Scalar : Float
{
    type P = P;
    fn step(&mut self, gradients : &P) {
        self.timestep = self.timestep + Scalar::one();

        // m_t = beta_1 * m_{t-1} + (1 - beta_1) * g
        self.m0.zip_mut_with(gradients, |m, &g| *m = self.beta1 * *m + (Scalar::one() - self.beta1) * g);
        // v_t = beta_2 * v_{t-1} + (1 - beta_2) * g^2
        self.v0.zip_mut_with(gradients, |v, &g| *v = self.beta2 * *v + (Scalar::one() - self.beta2) * g * g);

        let bias_correction1 = Scalar::one() - self.beta1.powf(self.timestep);
        let bias_correction2_sqrt = (Scalar::one() - self.beta2.powf(self.timestep)).sqrt();

        let alpha_t = self.learning_rate / bias_correction1;

        // equivalent to p - learning_rate * m_hat / (sqrt(v_hat) + epsilon); epsilon is added
        // after bias-correcting the denominator, matching the PyTorch comparison tests below
        self.parameters.zip2_mut_with(&self.m0, &self.v0, |p,&m,&v| *p = *p - alpha_t * m / (v.sqrt() / bias_correction2_sqrt + self.epsilon));
    }

    fn parameters(&self) -> &P {
        &self.parameters
    }

    fn parameters_mut(&mut self) -> &mut P {
        &mut self.parameters
    }

    fn into_parameters(self) -> P {
        self.parameters
    }

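    // Lets callers adjust the learning rate between steps, e.g. to follow a schedule.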
    fn change_learning_rate(&mut self, learning_rate : Scalar) {
        self.learning_rate = learning_rate;
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    use tch::COptimizer;

    #[test]
    fn quadratic_function() {
        let start = -3.0;
        let mut optimizer = Adam::new(start, 0.1);

        for _ in 0..10000 {
            let current_parameter = optimizer.parameters();

            // d/dx (x-4)^2 = 2x - 8
            let gradient = 2.0 * current_parameter - 8.0;
            optimizer.step(&gradient);
        }

        assert_eq!(optimizer.into_parameters(), 4.0);
    }

    #[test]
    fn pytorch_compare() {
        let init = vec![3.0, 1.0, 4.0, 1.0, 5.0];

        let optimizer = Adam::new(init, 0.005);

        let optimizer_torch = COptimizer::adam(0.005, 0.9, 0.999, 0.0, 1e-8, false).unwrap();

        assert!(crate::test_utils::compare_optimizers(optimizer, optimizer_torch));
    }

    #[test]
    fn builder() {
        let init = vec![3.0, 1.0, 4.0, 1.0, 5.0];

        let optimizer = Adam::builder(init).learning_rate(0.005).build();

        let optimizer_torch = COptimizer::adam(0.005, 0.9, 0.999, 0.0, 1e-8, false).unwrap();

        assert!(crate::test_utils::compare_optimizers(optimizer, optimizer_torch));
    }
}