stochastic_optimizers/optimizers/adam.rs

use crate::{Parameters, Optimizer};
use num_traits::{Float, AsPrimitive};

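/// Adam (adaptive moment estimation) optimizer. It keeps exponential moving
/// averages of the gradient (`m0`) and of its element-wise square (`v0`) and
/// applies bias-corrected updates to the wrapped parameters.
///
/// A minimal usage sketch, assuming (as the tests below do) that `f64`
/// implements [`Parameters`] and that `Adam` and [`Optimizer`] are available
/// at the crate root:
///
/// ```
/// use stochastic_optimizers::{Adam, Optimizer};
///
/// // Minimise f(x) = (x - 4)^2 starting from x = -3.
/// let mut optimizer = Adam::new(-3.0_f64, 0.1);
/// for _ in 0..1000 {
///     let x = *optimizer.parameters();
///     let gradient = 2.0 * x - 8.0;
///     optimizer.step(&gradient);
/// }
/// let _minimum = optimizer.into_parameters(); // approaches 4.0
/// ```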
#[derive(Debug)]
pub struct Adam<P: Parameters> {
    parameters: P,
    learning_rate: P::Scalar,
    beta1: P::Scalar,
    beta2: P::Scalar,
    epsilon: P::Scalar,
    timestep: P::Scalar,
    m0: P,
    v0: P,
}

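/// Builder returned by [`Adam::builder`] that allows individual
/// hyperparameters to be overridden before the optimizer is constructed.
///
/// A sketch, assuming (as the builder test below does) that `Vec<f64>`
/// implements [`Parameters`]:
///
/// ```
/// use stochastic_optimizers::Adam;
///
/// let init = vec![3.0, 1.0, 4.0, 1.0, 5.0];
/// let optimizer = Adam::builder(init)
///     .learning_rate(0.005)
///     .beta1(0.9)
///     .beta2(0.999)
///     .epsilon(1e-8)
///     .build();
/// ```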
pub struct AdamBuilder<P: Parameters>(Adam<P>);

impl<Scalar, P: Parameters<Scalar = Scalar>> Adam<P>
where
    Scalar: Float + 'static,
    f64: AsPrimitive<Scalar>,
{
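    /// Creates a new optimizer with the given initial `parameters` and
    /// `learning_rate`, using the defaults `beta1 = 0.9`, `beta2 = 0.999`
    /// and `epsilon = 1e-8`.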
    pub fn new(parameters: P, learning_rate: Scalar) -> Adam<P> {
        let m0 = parameters.zeros();
        let v0 = parameters.zeros();
        Adam {
            parameters,
            learning_rate,
            beta1: 0.9.as_(),
            beta2: 0.999.as_(),
            epsilon: 1e-8.as_(),
            timestep: 0.0.as_(),
            m0,
            v0,
        }
    }

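    /// Starts building an optimizer for `parameters` with default
    /// hyperparameters (`learning_rate = 0.001`, `beta1 = 0.9`,
    /// `beta2 = 0.999`, `epsilon = 1e-8`); see [`AdamBuilder`] for
    /// overriding individual values.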
    pub fn builder(parameters: P) -> AdamBuilder<P> {
        let m0 = parameters.zeros();
        let v0 = parameters.zeros();
        let adam = Adam {
            parameters,
            learning_rate: 0.001.as_(),
            beta1: 0.9.as_(),
            beta2: 0.999.as_(),
            epsilon: 1e-8.as_(),
            timestep: 0.0.as_(),
            m0,
            v0,
        };

        AdamBuilder(adam)
    }
}

impl<P: Parameters> AdamBuilder<P> {
    pub fn learning_rate(mut self, learning_rate: P::Scalar) -> AdamBuilder<P> {
        self.0.learning_rate = learning_rate;
        self
    }

    pub fn beta1(mut self, beta1: P::Scalar) -> AdamBuilder<P> {
        self.0.beta1 = beta1;
        self
    }

    pub fn beta2(mut self, beta2: P::Scalar) -> AdamBuilder<P> {
        self.0.beta2 = beta2;
        self
    }

    pub fn epsilon(mut self, epsilon: P::Scalar) -> AdamBuilder<P> {
        self.0.epsilon = epsilon;
        self
    }

    pub fn build(self) -> Adam<P> {
        self.0
    }
}

impl<Scalar, P: Parameters<Scalar = Scalar>> Optimizer for Adam<P>
where
    Scalar: Float,
{
    type P = P;

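    // One Adam update step, written to match PyTorch's formulation
    // (see the `pytorch_compare` test below):
    //   m <- beta1 * m + (1 - beta1) * g
    //   v <- beta2 * v + (1 - beta2) * g^2
    //   p <- p - (lr / (1 - beta1^t)) * m / (sqrt(v) / sqrt(1 - beta2^t) + eps)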
    fn step(&mut self, gradients: &P) {
        self.timestep = self.timestep + Scalar::one();

        self.m0
            .zip_mut_with(gradients, |m, &g| *m = self.beta1 * *m + (Scalar::one() - self.beta1) * g);
        self.v0
            .zip_mut_with(gradients, |v, &g| *v = self.beta2 * *v + (Scalar::one() - self.beta2) * g * g);

        let bias_correction1 = Scalar::one() - self.beta1.powf(self.timestep);
        let bias_correction2_sqrt = (Scalar::one() - self.beta2.powf(self.timestep)).sqrt();

        let alpha_t = self.learning_rate / bias_correction1;

        self.parameters.zip2_mut_with(&self.m0, &self.v0, |p, &m, &v| {
            *p = *p - alpha_t * m / (v.sqrt() / bias_correction2_sqrt + self.epsilon);
        });
    }

    fn parameters(&self) -> &P {
        &self.parameters
    }

    fn parameters_mut(&mut self) -> &mut P {
        &mut self.parameters
    }

    fn into_parameters(self) -> P {
        self.parameters
    }

    fn change_learning_rate(&mut self, learning_rate: Scalar) {
        self.learning_rate = learning_rate;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use tch::COptimizer;

    #[test]
    fn quadratic_function() {
        let start = -3.0;
        let mut optimizer = Adam::new(start, 0.1);

        for _ in 0..10000 {
            let current_parameter = optimizer.parameters();

            let gradient = 2.0 * current_parameter - 8.0;
            optimizer.step(&gradient);
        }

        assert_eq!(optimizer.into_parameters(), 4.0);
    }

    #[test]
    fn pytorch_compare() {
        let init = vec![3.0, 1.0, 4.0, 1.0, 5.0];

        let optimizer = Adam::new(init, 0.005);

        let optimizer_torch = COptimizer::adam(0.005, 0.9, 0.999, 0.0, 1e-8, false).unwrap();

        assert!(crate::test_utils::compare_optimizers(optimizer, optimizer_torch));
    }

    #[test]
    fn builder() {
        let init = vec![3.0, 1.0, 4.0, 1.0, 5.0];

        let optimizer = Adam::builder(init).learning_rate(0.005).build();

        let optimizer_torch = COptimizer::adam(0.005, 0.9, 0.999, 0.0, 1e-8, false).unwrap();

        assert!(crate::test_utils::compare_optimizers(optimizer, optimizer_torch));
    }
}