axonml_optim/adam.rs

//! Adam Optimizer - Adaptive Moment Estimation
//!
//! Implements Adam and `AdamW` (Adam with decoupled weight decay).
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// Adam
// =============================================================================

/// Adam optimizer.
///
/// Maintains per-parameter adaptive learning rates using first and
/// second moment estimates of gradients.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
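///
/// # Example
///
/// A minimal usage sketch (ignored doctest) mirroring the constructors used in
/// the tests at the bottom of this file; it assumes `Adam` and `Optimizer` are
/// re-exported from the `axonml_optim` crate root.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{Adam, Optimizer};
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// let mut optimizer = Adam::new(vec![param.clone()], 0.001)
///     .weight_decay(0.01)
///     .amsgrad(true);
///
/// // After a backward pass has populated the parameter's gradient:
/// optimizer.step();
/// optimizer.zero_grad();
/// ```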
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization for standard Adam).
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

/// State for Adam optimizer.
#[derive(Debug, Clone)]
struct AdamState {
    /// First moment (mean of gradients).
    exp_avg: Vec<f32>,
    /// Second moment (variance of gradients).
    exp_avg_sq: Vec<f32>,
    /// Max second moment for `AMSGrad`.
    max_exp_avg_sq: Option<Vec<f32>>,
    /// Step count for bias correction.
    step: usize,
}

impl AdamState {
    fn new(size: usize, amsgrad: bool) -> Self {
        Self {
            exp_avg: vec![0.0; size],
            exp_avg_sq: vec![0.0; size],
            max_exp_avg_sq: if amsgrad { Some(vec![0.0; size]) } else { None },
            step: 0,
        }
    }
}

impl Adam {
    /// Creates a new Adam optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates Adam with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates Adam with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Apply L2 regularization to gradient (standard Adam weight decay)
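            // Equivalent to adding the gradient of the L2 penalty (weight_decay / 2) * ||param||^2: g + weight_decay * p.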
            let grad_vec: Vec<f32> = if self.weight_decay == 0.0 {
                grad_vec
            } else {
                grad_vec
                    .iter()
                    .zip(param_vec.iter())
                    .map(|(g, p)| g + self.weight_decay * p)
                    .collect()
            };

            // Update biased first moment estimate
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second moment estimate
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Bias correction
            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            // Compute step size
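            // Folding the m_hat bias correction into the learning rate: lr * m_hat / denom == (lr / bias_correction1) * m / denom.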
            let step_size = self.lr / bias_correction1;

            // Update parameters
            if self.amsgrad {
                // AMSGrad variant
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                // Standard Adam
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// AdamW
// =============================================================================

/// `AdamW` optimizer (Adam with decoupled weight decay).
///
/// Unlike standard Adam, which applies L2 regularization to the gradient,
/// `AdamW` applies weight decay directly to the parameters.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
/// ```
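///
/// # Example
///
/// A minimal usage sketch (ignored doctest) under the same assumptions as the
/// `Adam` example above; note that `AdamW::new` defaults to a decoupled weight
/// decay of 0.01.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{AdamW, Optimizer};
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// let mut optimizer = AdamW::new(vec![param.clone()], 1e-3).weight_decay(0.1);
///
/// // After a backward pass has populated the parameter's gradient:
/// optimizer.step();
/// optimizer.zero_grad();
/// ```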
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient.
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates a new `AdamW` optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates `AdamW` with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01, // Default weight decay for AdamW
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates `AdamW` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Decoupled weight decay (applied directly to parameters)
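            // Multiplicative form of param -= lr * weight_decay * param from the update rule above.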
            if self.weight_decay != 0.0 {
                for p in &mut param_vec {
                    *p *= 1.0 - self.lr * self.weight_decay;
                }
            }

            // Update biased first moment estimate
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second moment estimate
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Bias correction
            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            // Compute step size
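            // As in Adam above, the m_hat bias correction is folded into the step size: lr * m_hat / denom == (lr / bias_correction1) * m / denom.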
            let step_size = self.lr / bias_correction1;

            // Update parameters
            if self.amsgrad {
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
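
    #[test]
    fn test_adamw_decoupled_weight_decay() {
        // Minimal sketch of the decoupled-decay path: with an all-zero gradient the
        // moment estimates stay at zero, so a single step should only scale each
        // parameter by (1 - lr * weight_decay).
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 4.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap());

        let mut optimizer = AdamW::new(vec![param.clone()], 0.1).weight_decay(0.5);
        optimizer.step();

        let new_data = param.data().to_vec();
        // 1.0 * (1 - 0.1 * 0.5) = 0.95
        assert!((new_data[0] - 0.95).abs() < 1e-5);
    }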
}