1use super::Optimizer;
4use crate::Tensor;
5use ndarray::Array1;
6
7pub struct Adam {
9 lr: f32,
10 beta1: f32,
11 beta2: f32,
12 epsilon: f32,
13 t: u64,
14 m: Vec<Option<Array1<f32>>>, v: Vec<Option<Array1<f32>>>, }
17
18impl Adam {
19 pub fn new(lr: f32, beta1: f32, beta2: f32, epsilon: f32) -> Self {
21 Self { lr, beta1, beta2, epsilon, t: 0, m: Vec::new(), v: Vec::new() }
22 }
23
24 pub fn default_params(lr: f32) -> Self {
26 Self::new(lr, 0.9, 0.999, 1e-8)
27 }
28
29 fn ensure_moments(&mut self, params: &[Tensor]) {
31 if self.m.is_empty() {
32 self.m = params.iter().map(|_| None).collect();
33 self.v = params.iter().map(|_| None).collect();
34 }
35 }
36}
37
38impl Optimizer for Adam {
39 fn step(&mut self, params: &mut [Tensor]) {
40 self.ensure_moments(params);
41 self.t += 1;
42
43 let lr_t = self.lr
45 * ((1.0 - self.beta2.powi(self.t as i32)).sqrt()
46 / (1.0 - self.beta1.powi(self.t as i32)));
47
48 for (i, param) in params.iter_mut().enumerate() {
49 if let Some(grad) = param.grad() {
50 if grad.len() >= 16 {
52 if self.m[i].is_none() {
54 self.m[i] = Some(Array1::zeros(grad.len()));
55 self.v[i] = Some(Array1::zeros(grad.len()));
56 }
57
58 let m = self.m[i].as_mut().expect("momentum buffer initialized above");
59 let v = self.v[i].as_mut().expect("velocity buffer initialized above");
60
61 let grad_slice = grad.as_slice().expect("grad array is contiguous");
63 let m_slice = m.as_slice_mut().expect("momentum array is contiguous");
64 let v_slice = v.as_slice_mut().expect("velocity array is contiguous");
65 let param_slice =
66 param.data_mut().as_slice_mut().expect("param array is contiguous");
67
68 super::simd::simd_adam_update(
70 grad_slice,
71 m_slice,
72 v_slice,
73 param_slice,
74 self.beta1,
75 self.beta2,
76 lr_t,
77 self.epsilon,
78 );
79 } else {
80 let m_t = if let Some(m) = &self.m[i] {
83 m * self.beta1 + &grad * (1.0 - self.beta1)
84 } else {
85 &grad * (1.0 - self.beta1)
86 };
87
88 let grad_sq = &grad * &grad;
90 let v_t = if let Some(v) = &self.v[i] {
91 v * self.beta2 + &grad_sq * (1.0 - self.beta2)
92 } else {
93 &grad_sq * (1.0 - self.beta2)
94 };
95
96 let update = &m_t / &(v_t.mapv(f32::sqrt) + self.epsilon) * lr_t;
98 *param.data_mut() = param.data() - &update;
99
100 self.m[i] = Some(m_t);
101 self.v[i] = Some(v_t);
102 }
103 }
104 }
105 }
106
107 fn lr(&self) -> f32 {
108 self.lr
109 }
110
111 fn set_lr(&mut self, lr: f32) {
112 self.lr = lr;
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::*;

    /// Drives a three-element parameter with the gradient `2 * x` for 100
    /// Adam steps and checks that every component ends up within 0.5 of zero.
    #[test]
    fn test_adam_quadratic_convergence() {
        const STEPS: usize = 100;
        const TOLERANCE: f32 = 0.5;

        let mut optimizer = Adam::default_params(0.1);
        let mut params = vec![Tensor::from_vec(vec![5.0, -3.0, 2.0], true)];

        for _step in 0..STEPS {
            // Gradient of the sum of squares of the current values.
            let gradient = params[0].data().mapv(|component| 2.0 * component);
            params[0].set_grad(gradient);
            optimizer.step(&mut params);
        }

        let final_values = params[0].data();
        for &val in final_values {
            assert!(val.abs() < TOLERANCE, "Value {val} did not converge");
        }
    }
}