Skip to main content

entrenar/optim/
adam.rs

1//! Adam optimizer
2
3use super::Optimizer;
4use crate::Tensor;
5use ndarray::Array1;
6
7/// Adam optimizer (Adaptive Moment Estimation)
8pub struct Adam {
9    lr: f32,
10    beta1: f32,
11    beta2: f32,
12    epsilon: f32,
13    t: u64,
14    m: Vec<Option<Array1<f32>>>, // First moment
15    v: Vec<Option<Array1<f32>>>, // Second moment
16}
17
18impl Adam {
19    /// Create a new Adam optimizer
20    pub fn new(lr: f32, beta1: f32, beta2: f32, epsilon: f32) -> Self {
21        Self { lr, beta1, beta2, epsilon, t: 0, m: Vec::new(), v: Vec::new() }
22    }
23
24    /// Create Adam with default parameters
25    pub fn default_params(lr: f32) -> Self {
26        Self::new(lr, 0.9, 0.999, 1e-8)
27    }
28
29    /// Initialize moments if needed
30    fn ensure_moments(&mut self, params: &[Tensor]) {
31        if self.m.is_empty() {
32            self.m = params.iter().map(|_| None).collect();
33            self.v = params.iter().map(|_| None).collect();
34        }
35    }
36}
37
38impl Optimizer for Adam {
39    fn step(&mut self, params: &mut [Tensor]) {
40        self.ensure_moments(params);
41        self.t += 1;
42
43        // Bias adjust factors
44        let lr_t = self.lr
45            * ((1.0 - self.beta2.powi(self.t as i32)).sqrt()
46                / (1.0 - self.beta1.powi(self.t as i32)));
47
48        for (i, param) in params.iter_mut().enumerate() {
49            if let Some(grad) = param.grad() {
50                // Use SIMD for large tensors (>= 16 elements for meaningful speedup)
51                if grad.len() >= 16 {
52                    // Initialize moments if needed
53                    if self.m[i].is_none() {
54                        self.m[i] = Some(Array1::zeros(grad.len()));
55                        self.v[i] = Some(Array1::zeros(grad.len()));
56                    }
57
58                    let m = self.m[i].as_mut().expect("momentum buffer initialized above");
59                    let v = self.v[i].as_mut().expect("velocity buffer initialized above");
60
61                    // Get mutable slices (arrays are always contiguous)
62                    let grad_slice = grad.as_slice().expect("grad array is contiguous");
63                    let m_slice = m.as_slice_mut().expect("momentum array is contiguous");
64                    let v_slice = v.as_slice_mut().expect("velocity array is contiguous");
65                    let param_slice =
66                        param.data_mut().as_slice_mut().expect("param array is contiguous");
67
68                    // Use SIMD-accelerated update
69                    super::simd::simd_adam_update(
70                        grad_slice,
71                        m_slice,
72                        v_slice,
73                        param_slice,
74                        self.beta1,
75                        self.beta2,
76                        lr_t,
77                        self.epsilon,
78                    );
79                } else {
80                    // Fallback to scalar implementation for small tensors
81                    // m_t = β1 * m_{t-1} + (1 - β1) * g
82                    let m_t = if let Some(m) = &self.m[i] {
83                        m * self.beta1 + &grad * (1.0 - self.beta1)
84                    } else {
85                        &grad * (1.0 - self.beta1)
86                    };
87
88                    // v_t = β2 * v_{t-1} + (1 - β2) * g²
89                    let grad_sq = &grad * &grad;
90                    let v_t = if let Some(v) = &self.v[i] {
91                        v * self.beta2 + &grad_sq * (1.0 - self.beta2)
92                    } else {
93                        &grad_sq * (1.0 - self.beta2)
94                    };
95
96                    // θ_t = θ_{t-1} - lr_t * m_t / (√v_t + ε)
97                    let update = &m_t / &(v_t.mapv(f32::sqrt) + self.epsilon) * lr_t;
98                    *param.data_mut() = param.data() - &update;
99
100                    self.m[i] = Some(m_t);
101                    self.v[i] = Some(v_t);
102                }
103            }
104        }
105    }
106
107    fn lr(&self) -> f32 {
108        self.lr
109    }
110
111    fn set_lr(&mut self, lr: f32) {
112        self.lr = lr;
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::*;

    /// Minimising f(x) = x² from a small 3-element start point should drive
    /// every coordinate towards zero.
    #[test]
    fn test_adam_quadratic_convergence() {
        let mut optimizer = Adam::default_params(0.1);
        let mut params = vec![Tensor::from_vec(vec![5.0, -3.0, 2.0], true)];

        for _iteration in 0..100 {
            // Analytic gradient of x² is 2x.
            let gradient = params[0].data().mapv(|x| 2.0 * x);
            params[0].set_grad(gradient);
            optimizer.step(&mut params);
        }

        // After 100 steps each coordinate must lie within 0.5 of the minimum.
        params[0]
            .data()
            .iter()
            .for_each(|&val| assert!(val.abs() < 0.5, "Value {val} did not converge"));
    }
}