//! Adam Optimizer - Adaptive Moment Estimation
//!
//! # File
//! `crates/axonml-optim/src/adam.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// Adam
// =============================================================================

/// Adam optimizer.
///
/// Maintains per-parameter adaptive learning rates using first and
/// second moment estimates of gradients.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
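///
/// # Example
///
/// Usage sketch (not compiled as a doctest; it assumes the `Parameter` /
/// `Variable` construction used by this crate's tests at the bottom of this
/// file, with imports omitted):
///
/// ```ignore
/// // Wrap a tensor in a trainable parameter (same pattern as the tests below).
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
/// let param = Parameter::from_variable(var);
///
/// // Adam with default betas (0.9, 0.999); builder methods tweak the rest.
/// let mut optimizer = Adam::new(vec![param.clone()], 1e-3)
///     .weight_decay(0.01)
///     .amsgrad(true);
///
/// // Per training iteration: clear old gradients, run the backward pass
/// // elsewhere to populate `param`'s gradient, then apply the update.
/// optimizer.zero_grad();
/// // ... forward + backward pass ...
/// optimizer.step();
/// ```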
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization for standard Adam).
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

/// State for Adam optimizer.
///
/// Stores momentum tensors on the same device as parameters (CPU or GPU).
/// When parameters are on GPU, all state stays on GPU — zero CPU round-trips.
#[derive(Debug, Clone)]
struct AdamState {
    /// First moment (mean of gradients) — on same device as param.
    exp_avg: Tensor<f32>,
    /// Second moment (variance of gradients) — on same device as param.
    exp_avg_sq: Tensor<f32>,
    /// Maximum of all past exp_avg_sq values (for AMSGrad).
    max_exp_avg_sq: Option<Tensor<f32>>,
    /// Step count for bias correction.
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq.to_device(device).expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            max_exp_avg_sq: None, // Initialized on first use if amsgrad=true
            step: 0,
        }
    }
}

impl Adam {
    /// Creates a new Adam optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates Adam with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates Adam with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: fused CUDA kernel — single launch per parameter, zero CPU copies
            // Note: `adam_step_inplace` is not passed `max_exp_avg_sq` state, so the
            // `AMSGrad` variant is only honored by the CPU fallback path below.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Auto-migrate gradient to GPU if backward produced CPU gradients
                // (happens when backward functions use CPU fallback computation)
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("Adam: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // In-place fused Adam update on GPU
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                // No need for update_data — the kernel modified the GPU buffer in-place
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            // AMSGrad: track max of all past exp_avg_sq values
            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map(|t| t.to_vec())
                    .unwrap_or_else(|| vec![0.0f32; param_vec.len()])
            } else {
                Vec::new()
            };

            for i in 0..param_vec.len() {
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;

                let v_hat = if self.amsgrad {
                    max_sq_vec[i] = max_sq_vec[i].max(exp_avg_sq_vec[i]);
                    max_sq_vec[i] / bias_correction2
                } else {
                    exp_avg_sq_vec[i] / bias_correction2
                };

                let denom = v_hat.sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(Tensor::from_vec(max_sq_vec, param_data.shape()).expect("tensor creation failed"));
            }
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"));
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// AdamW
// =============================================================================

/// `AdamW` optimizer (Adam with decoupled weight decay).
///
/// Unlike standard Adam, which applies L2 regularization to the gradient,
/// `AdamW` applies weight decay directly to the parameters.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
/// ```
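///
/// # Example
///
/// Usage sketch (not compiled as a doctest; it assumes the `Parameter` /
/// `Variable` construction used by this crate's tests below, with imports
/// omitted):
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
/// let param = Parameter::from_variable(var);
///
/// // AdamW defaults to weight_decay = 0.01; override it via the builder.
/// let mut optimizer = AdamW::new(vec![param.clone()], 1e-3).weight_decay(0.1);
///
/// optimizer.zero_grad();
/// // ... forward + backward pass fills `param`'s gradient ...
/// optimizer.step();
/// ```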
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient.
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates a new `AdamW` optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates `AdamW` with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01, // Default weight decay for AdamW
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates `AdamW` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: decoupled weight decay + fused Adam step
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Auto-migrate gradient to GPU if backward produced CPU gradients
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("AdamW: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };

                // DECOUPLED weight decay: param *= (1 - lr * wd)
                // This is the key difference from Adam's L2 regularization.
                // Applied BEFORE the Adam update, directly to parameters.
                if self.weight_decay > 0.0 {
                    let decay_factor = 1.0 - self.lr * self.weight_decay;
                    let decayed = param_data.mul_scalar(decay_factor);
                    param.update_data(decayed);
                }

                // Re-read param_data after potential decay update
                let param_data = param.data();

                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // Adam step with wd=0 (decay already applied above)
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    0.0, // wd=0: decoupled decay already applied
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

            // AMSGrad: track max of all past exp_avg_sq values (mirrors the Adam CPU path)
            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map(|t| t.to_vec())
                    .unwrap_or_else(|| vec![0.0f32; param_vec.len()])
            } else {
                Vec::new()
            };

            for i in 0..param_vec.len() {
                // Decoupled weight decay: apply directly to param
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;

                let v_hat = if self.amsgrad {
                    max_sq_vec[i] = max_sq_vec[i].max(exp_avg_sq_vec[i]);
                    max_sq_vec[i] / bias_correction2
                } else {
                    exp_avg_sq_vec[i] / bias_correction2
                };

                let denom = v_hat.sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(Tensor::from_vec(max_sq_vec, param_data.shape()).expect("tensor creation failed"));
            }
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"));
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }
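
    // Added sketch: mirrors `test_adam_step` for `AdamW`, exercising the CPU
    // update path (decoupled decay plus Adam update) via the same
    // Parameter/Variable API used by the tests above.
    #[test]
    fn test_adamw_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = AdamW::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }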

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
}