axonml_optim/adam.rs

//! Adam Optimizer - Adaptive Moment Estimation
//!
//! Implements Adam and `AdamW` (Adam with decoupled weight decay).
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// Adam
// =============================================================================

/// Adam optimizer.
///
/// Maintains per-parameter adaptive learning rates using first and
/// second moment estimates of gradients.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
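///
/// # Example
///
/// A minimal usage sketch (marked `ignore` rather than compiled as a doctest);
/// it assumes the `Variable`/`Parameter` construction used in the tests at the
/// bottom of this file, with gradients populated by a backward pass elsewhere.
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = Adam::new(vec![param.clone()], 1e-3)
///     .weight_decay(1e-2)
///     .amsgrad(true);
///
/// optimizer.zero_grad();
/// // ... forward/backward pass fills in `param`'s gradient here ...
/// optimizer.step();
/// ```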
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization for standard Adam).
    weight_decay: f32,
    /// Whether to use the `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

/// State for Adam optimizer.
#[derive(Debug, Clone)]
struct AdamState {
    /// First moment (mean of gradients).
    exp_avg: Vec<f32>,
    /// Second moment (mean of squared gradients).
    exp_avg_sq: Vec<f32>,
    /// Max second moment for `AMSGrad`.
    max_exp_avg_sq: Option<Vec<f32>>,
    /// Step count for bias correction.
    step: usize,
}

impl AdamState {
    fn new(size: usize, amsgrad: bool) -> Self {
        Self {
            exp_avg: vec![0.0; size],
            exp_avg_sq: vec![0.0; size],
            max_exp_avg_sq: if amsgrad { Some(vec![0.0; size]) } else { None },
            step: 0,
        }
    }
}

impl Adam {
    /// Creates a new Adam optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates Adam with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates Adam with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Apply L2 regularization to gradient (standard Adam weight decay)
            let grad_vec: Vec<f32> = if self.weight_decay == 0.0 {
                grad_vec
            } else {
                grad_vec
                    .iter()
                    .zip(param_vec.iter())
                    .map(|(g, p)| g + self.weight_decay * p)
                    .collect()
            };

            // Update biased first moment estimate
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second moment estimate
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Bias correction
            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            // Compute step size
            let step_size = self.lr / bias_correction1;

            // Update parameters
            if self.amsgrad {
                // AMSGrad variant
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                // Standard Adam
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// AdamW
// =============================================================================

/// `AdamW` optimizer (Adam with decoupled weight decay).
///
/// Unlike standard Adam, which applies L2 regularization to the gradient,
/// `AdamW` applies weight decay directly to the parameters.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
/// ```
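///
/// # Example
///
/// A minimal usage sketch (marked `ignore` rather than compiled as a doctest);
/// it assumes the same `Variable`/`Parameter` setup as the tests below. Note
/// that `AdamW::new` defaults to `weight_decay = 0.01`, which can be
/// overridden through the builder.
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = AdamW::new(vec![param.clone()], 1e-3).weight_decay(0.1);
///
/// optimizer.zero_grad();
/// // ... forward/backward pass fills in `param`'s gradient here ...
/// optimizer.step();
/// ```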
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient.
    weight_decay: f32,
    /// Whether to use the `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates a new `AdamW` optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates `AdamW` with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01, // Default weight decay for AdamW
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates `AdamW` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| AdamState::new(p.numel(), self.amsgrad))
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let grad_vec = grad.to_vec();
            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();
            let mut param_vec = param_data.to_vec();

            // Decoupled weight decay (applied directly to parameters)
            if self.weight_decay != 0.0 {
                for p in &mut param_vec {
                    *p *= 1.0 - self.lr * self.weight_decay;
                }
            }

            // Update biased first moment estimate
            for (m, g) in state.exp_avg.iter_mut().zip(grad_vec.iter()) {
                *m = self.beta1 * *m + (1.0 - self.beta1) * g;
            }

            // Update biased second moment estimate
            for (v, g) in state.exp_avg_sq.iter_mut().zip(grad_vec.iter()) {
                *v = self.beta2 * *v + (1.0 - self.beta2) * g * g;
            }

            // Bias correction
            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

            // Compute step size
            let step_size = self.lr / bias_correction1;

            // Update parameters
            if self.amsgrad {
                let max_exp_avg_sq = state.max_exp_avg_sq.as_mut().unwrap();
                for (max_v, v) in max_exp_avg_sq.iter_mut().zip(state.exp_avg_sq.iter()) {
                    *max_v = max_v.max(*v);
                }
                for (p, (m, max_v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(max_exp_avg_sq.iter()))
                {
                    let denom = (max_v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            } else {
                for (p, (m, v)) in param_vec
                    .iter_mut()
                    .zip(state.exp_avg.iter().zip(state.exp_avg_sq.iter()))
                {
                    let denom = (v / bias_correction2).sqrt() + self.eps;
                    *p -= step_size * m / denom;
                }
            }

            let update = Tensor::from_vec(param_vec, param_data.shape()).unwrap();
            param.update_data(update);
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
}