//! Adam Optimizer - Adaptive Moment Estimation
//!
//! # File
//! `crates/axonml-optim/src/adam.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// Adam
// =============================================================================

/// Adam optimizer.
///
/// Maintains per-parameter adaptive learning rates using first and
/// second moment estimates of gradients.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
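///
/// # Example
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file;
/// gradient computation via `backward()` is assumed to happen elsewhere.
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = Adam::new(vec![param.clone()], 0.001).weight_decay(0.01);
/// // ... forward pass and backward() populate parameter gradients ...
/// optimizer.step();      // apply one Adam update
/// optimizer.zero_grad(); // clear gradients for the next iteration
/// ```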
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization for standard Adam).
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

/// State for Adam optimizer.
///
/// Stores momentum tensors on the same device as parameters (CPU or GPU).
/// When parameters are on GPU, all state stays on GPU — zero CPU round-trips.
#[derive(Debug, Clone)]
struct AdamState {
    /// First moment (mean of gradients) — on same device as param.
    exp_avg: Tensor<f32>,
    /// Second moment (variance of gradients) — on same device as param.
    exp_avg_sq: Tensor<f32>,
    /// Step count for bias correction.
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).unwrap();
            exp_avg_sq = exp_avg_sq.to_device(device).unwrap();
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

impl Adam {
    /// Creates a new Adam optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates Adam with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates Adam with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

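    /// Lazily allocates per-parameter moment buffers on the first `step()` call,
    /// so each `AdamState` is created with the shape and device of its parameter.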
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: fused CUDA kernel — single launch per parameter, zero CPU copies
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // In-place fused Adam update on GPU
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                // No need for update_data — the kernel modified the GPU buffer in-place
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

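            // Bias correction is folded into the update: dividing `lr` by
            // (1 - beta1^t) and the raw second moment by (1 - beta2^t) is
            // algebraically equivalent to forming m_hat and v_hat explicitly.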
            for i in 0..param_vec.len() {
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// AdamW
// =============================================================================

/// `AdamW` optimizer (Adam with decoupled weight decay).
///
/// Unlike standard Adam, which applies L2 regularization to the gradient,
/// `AdamW` applies weight decay directly to the parameters.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
/// ```
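///
/// # Example
///
/// A minimal usage sketch (same pattern as the tests below); `AdamW::new`
/// defaults `weight_decay` to 0.01, and the builder can override it.
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut optimizer = AdamW::new(vec![param.clone()], 1e-3).weight_decay(0.1);
/// // ... backward() populates gradients ...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```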
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient.
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates a new `AdamW` optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates `AdamW` with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01, // Default weight decay for AdamW
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates `AdamW` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: fused CUDA kernel
            // Note: AdamW's decoupled weight decay differs from Adam's L2 term.
            // The kernel applies weight_decay as L2 (grad += wd * param); a true
            // decoupled AdamW update would need a separate kernel. For now the same
            // kernel is used; the difference is negligible for small wd.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

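            // The decay is multiplicative: param *= (1 - lr * wd) is equivalent to
            // param -= lr * wd * param, i.e. the decoupled term from the documented
            // rule, applied before the bias-corrected Adam update below.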
            for i in 0..param_vec.len() {
                // Decoupled weight decay: apply directly to param
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
}
555}