
axonml_optim/adam.rs

//! Adam Optimizer - Adaptive Moment Estimation
//!
//! Implements Adam and `AdamW` (Adam with decoupled weight decay).
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// Adam
// =============================================================================

/// Adam optimizer.
///
/// Maintains per-parameter adaptive learning rates using first and
/// second moment estimates of gradients.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
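///
/// # Example
///
/// A minimal usage sketch modeled on this crate's tests; the import paths and
/// re-exports are assumptions and may differ in your build, so the example is
/// marked `ignore`:
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{Adam, Optimizer};
/// use axonml_tensor::Tensor;
///
/// // Wrap a tensor as a trainable parameter.
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// let mut optimizer = Adam::new(vec![param.clone()], 1e-3);
///
/// // Each training step: populate gradients, apply the update, clear gradients.
/// param.variable().set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
/// optimizer.step();
/// optimizer.zero_grad();
/// ```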
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization for standard Adam).
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

/// State for Adam optimizer.
///
/// Stores momentum tensors on the same device as parameters (CPU or GPU).
/// When parameters are on GPU, all state stays on GPU with zero CPU round-trips.
#[derive(Debug, Clone)]
struct AdamState {
    /// First moment (mean of gradients), on the same device as the parameter.
    exp_avg: Tensor<f32>,
    /// Second moment (variance of gradients), on the same device as the parameter.
    exp_avg_sq: Tensor<f32>,
    /// Step count for bias correction.
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        let mut exp_avg_sq = Tensor::from_vec(vec![0.0f32; size], shape).unwrap();
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device.clone()).unwrap();
            exp_avg_sq = exp_avg_sq.to_device(device).unwrap();
        }
        Self {
            exp_avg,
            exp_avg_sq,
            step: 0,
        }
    }
}

impl Adam {
    /// Creates a new Adam optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates Adam with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates Adam with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    ///
    /// Note: the flag is stored, but the `AMSGrad` maximum is not yet applied
    /// in `step`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: fused CUDA kernel, single launch per parameter, zero CPU copies
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // In-place fused Adam update on GPU
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                // No need for update_data: the kernel modified the GPU buffer in-place
                continue;
            }

            // CPU fallback: fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            for j in 0..param_vec.len() {
                let g = if wd == 0.0 {
                    grad_vec[j]
                } else {
                    grad_vec[j] + wd * param_vec[j]
                };
                exp_avg_vec[j] = beta1 * exp_avg_vec[j] + one_minus_beta1 * g;
                exp_avg_sq_vec[j] = beta2 * exp_avg_sq_vec[j] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[j] / bias_correction2).sqrt() + eps;
                param_vec[j] -= step_size * exp_avg_vec[j] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// AdamW
// =============================================================================

/// `AdamW` optimizer (Adam with decoupled weight decay).
///
/// Unlike standard Adam, which applies L2 regularization through the gradient,
/// `AdamW` applies weight decay directly to the parameters.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
/// ```
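///
/// # Example
///
/// A minimal sketch showing the decoupled decay coefficient; the construction
/// mirrors this crate's tests and the import paths are assumptions, so the
/// example is marked `ignore`:
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{AdamW, Optimizer};
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// // Defaults to weight_decay = 0.01; override via the builder.
/// let mut optimizer = AdamW::new(vec![param.clone()], 1e-3).weight_decay(0.05);
///
/// param.variable().set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
/// optimizer.step();
/// ```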
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient.
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates a new `AdamW` optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates `AdamW` with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01, // Default weight decay for AdamW
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates `AdamW` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    ///
    /// Note: the flag is stored, but the `AMSGrad` maximum is not yet applied
    /// in `step`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: fused CUDA kernel
            // Note: AdamW's decoupled weight decay is different from Adam's L2.
            // The kernel applies weight_decay as L2 (grad += wd * param), so on
            // GPU this currently behaves like Adam with L2 regularization; a
            // separate kernel would be needed for true decoupled AdamW. For
            // small weight_decay values the difference is usually minor.
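            // For reference (matching the struct-level docs):
            //   Adam + L2: g = grad + weight_decay * param, then the usual Adam update
            //   AdamW:     param *= 1 - lr * weight_decay, then the usual Adam update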
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback: fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

            for j in 0..param_vec.len() {
                // Decoupled weight decay: apply directly to the parameter
                if has_wd {
                    param_vec[j] *= wd_factor;
                }
                let g = grad_vec[j];
                exp_avg_vec[j] = beta1 * exp_avg_vec[j] + one_minus_beta1 * g;
                exp_avg_sq_vec[j] = beta2 * exp_avg_sq_vec[j] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[j] / bias_correction2).sqrt() + eps;
                param_vec[j] -= step_size * exp_avg_vec[j] / denom;
            }

            state.exp_avg = Tensor::from_vec(exp_avg_vec, param_data.shape()).unwrap();
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape()).unwrap();
            param.update_data(Tensor::from_vec(param_vec, param_data.shape()).unwrap());
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

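    // A step test for AdamW, mirroring test_adam_step above; it reuses the same
    // Parameter/Variable API and only checks that the parameters moved.
    #[test]
    fn test_adamw_step() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());

        let mut optimizer = AdamW::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed (decoupled weight decay also shrinks them)
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }
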
    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }
}