//! Adam Optimizer - Adaptive Moment Estimation
//!
//! # File
//! `crates/axonml-optim/src/adam.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty
//! of any kind, express or implied. The author and AutomataNexus shall not be
//! held liable for any damages arising from the use of this software.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// Adam
// =============================================================================

/// Adam optimizer.
///
/// Maintains per-parameter adaptive learning rates using first and
/// second moment estimates of gradients.
///
/// Update rule:
/// ```text
/// g_t = grad + weight_decay * param   (L2 regularization when weight_decay > 0)
/// m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
///
/// With `amsgrad = true`, `v_hat` is computed from the running maximum of all
/// past `v_t` values instead of the current `v_t`.
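///
/// # Example
///
/// A minimal single-step sketch, mirroring this file's tests. Import paths
/// are illustrative (it assumes `Adam` and the `Optimizer` trait are
/// re-exported at the crate root), so the example is marked `ignore`.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{Adam, Optimizer};
/// use axonml_tensor::Tensor;
///
/// // Wrap a tensor as a trainable parameter.
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// // Configure Adam through the builder methods.
/// let mut opt = Adam::new(vec![param.clone()], 1e-3)
///     .betas((0.9, 0.999))
///     .eps(1e-8)
///     .weight_decay(0.01);
///
/// // One optimization step: set a gradient, update, clear.
/// param.set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).unwrap());
/// opt.step();
/// opt.zero_grad();
/// ```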
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization for standard Adam).
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

/// State for Adam optimizer.
///
/// Stores momentum tensors on the same device as parameters (CPU or GPU).
/// When parameters are on GPU, all state stays on GPU — zero CPU round-trips.
#[derive(Debug, Clone)]
struct AdamState {
    /// First moment (mean of gradients) — on same device as param.
    exp_avg: Tensor<f32>,
    /// Second moment (variance of gradients) — on same device as param.
    exp_avg_sq: Tensor<f32>,
    /// Maximum of all past exp_avg_sq values (for AMSGrad).
    max_exp_avg_sq: Option<Tensor<f32>>,
    /// Step count for bias correction.
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            max_exp_avg_sq: None, // Initialized on first use if amsgrad=true
            step: 0,
        }
    }
}

impl Adam {
    /// Creates a new Adam optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates Adam with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates Adam with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: fused CUDA kernel — single launch per parameter, zero CPU copies
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Auto-migrate gradient to GPU if backward produced CPU gradients
                // (happens when backward functions use CPU fallback computation)
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("Adam: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // In-place fused Adam update on GPU
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
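                // Note: `max_exp_avg_sq` is not passed to the fused kernel, so
                // the `amsgrad` option currently has no effect on this GPU path.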
                // No need for update_data — the kernel modified the GPU buffer in-place
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            // AMSGrad: track max of all past exp_avg_sq values
            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map_or_else(|| vec![0.0f32; param_vec.len()], |t| t.to_vec())
            } else {
                Vec::new()
            };

            for i in 0..param_vec.len() {
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;

                let v_hat = if self.amsgrad {
                    max_sq_vec[i] = max_sq_vec[i].max(exp_avg_sq_vec[i]);
                    max_sq_vec[i] / bias_correction2
                } else {
                    exp_avg_sq_vec[i] / bias_correction2
                };

                let denom = v_hat.sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(
                    Tensor::from_vec(max_sq_vec, param_data.shape())
                        .expect("tensor creation failed"),
                );
            }
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// AdamW
// =============================================================================

/// `AdamW` optimizer (Adam with decoupled weight decay).
///
/// Unlike standard Adam, which applies L2 regularization through the gradient,
/// `AdamW` applies weight decay directly to the parameters.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
/// ```
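///
/// # Example
///
/// A minimal training-loop sketch, mirroring the convergence tests below.
/// Import paths are illustrative (it assumes `AdamW` and the `Optimizer`
/// trait are re-exported at the crate root), so the example is marked
/// `ignore`.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{AdamW, Optimizer};
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// // Default weight_decay is 0.01 (decoupled).
/// let mut opt = AdamW::new(vec![param.clone()], 0.1);
///
/// // Minimize f(x) = x^2 via autograd.
/// for _ in 0..200 {
///     opt.zero_grad();
///     let x = param.variable();
///     let loss = x.mul_var(&x).sum();
///     loss.backward();
///     opt.step();
/// }
/// ```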
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient.
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates a new `AdamW` optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates `AdamW` with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01, // Default weight decay for AdamW
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates `AdamW` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: decoupled weight decay + fused Adam step
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Auto-migrate gradient to GPU if backward produced CPU gradients
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("AdamW: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };

                // DECOUPLED weight decay: param *= (1 - lr * wd)
                // This is the key difference from Adam's L2 regularization.
                // Applied BEFORE the Adam update, directly to parameters.
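                // For example, lr = 0.1 and weight_decay = 0.01 give
                // decay_factor = 1 - 0.1 * 0.01 = 0.999 per step.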
                if self.weight_decay > 0.0 {
                    let decay_factor = 1.0 - self.lr * self.weight_decay;
                    let decayed = param_data.mul_scalar(decay_factor);
                    param.update_data(decayed);
                }

                // Re-read param_data after potential decay update
                let param_data = param.data();

                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // Adam step with wd=0 (decay already applied above)
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    0.0, // wd=0: decoupled decay already applied
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

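            // Note: `amsgrad` is not consulted in this CPU loop and
            // `max_exp_avg_sq` is never updated, so enabling AMSGrad is
            // currently a no-op for `AdamW`.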
            for i in 0..param_vec.len() {
                // Decoupled weight decay: apply directly to param
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }

    // =========================================================================
    // Adam Step Correctness Tests
    // =========================================================================

    /// Verify the Adam update behaves as the formula predicts.
    /// After one step with grad=[1,1], lr=0.1, betas=(0.9,0.999):
    ///   m = 0.1*[1,1], v = 0.001*[1,1]
    ///   m_hat = m/0.1 = [1,1], v_hat = v/0.001 = [1,1]
    ///   param -= 0.1 * 1.0 / (1.0 + 1e-8) ≈ 0.1
    #[test]
    fn test_adam_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![0.5, -0.3], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        // Both params should decrease (positive gradient → decrease)
        assert!(
            after[0] < before[0],
            "param[0] should decrease: {} -> {}",
            before[0],
            after[0]
        );
        assert!(
            after[1] < before[1],
            "param[1] should decrease: {} -> {}",
            before[1],
            after[1]
        );

        // After one Adam step with uniform gradient, both should change by the same amount
        let delta0 = before[0] - after[0];
        let delta1 = before[1] - after[1];
        assert!(
            (delta0 - delta1).abs() < 1e-6,
            "Uniform gradient should produce uniform update: {} vs {}",
            delta0,
            delta1
        );
    }

    /// Verify Adam converges on a simple quadratic: minimize f(x) = x^2.
    /// Uses autograd for proper gradient computation.
    #[test]
    fn test_adam_converges_on_quadratic() {
        let var = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            // f(x) = x^2 → loss, compute gradient via autograd
            let x = param.variable();
            let loss = x.mul_var(&x).sum(); // x^2
            loss.backward();
            opt.step();
        }

        let final_x = param.data().to_vec()[0];
        assert!(
            final_x.abs() < 0.1,
            "Adam should converge near 0 for f(x)=x^2, got {}",
            final_x
        );
    }

    /// Verify zero_grad actually clears all gradients.
    #[test]
    fn test_adam_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap());
        assert!(param.grad().is_some());

        let mut opt = Adam::new(vec![param.clone()], 0.01);
        opt.zero_grad();
        // After zero_grad, gradient should be None or all zeros
        if let Some(g) = param.grad() {
            let gv = g.to_vec();
            assert!(
                gv.iter().all(|&v| v.abs() < 1e-10),
                "Gradients should be zero after zero_grad: {:?}",
                gv
            );
        }
    }

    /// Verify set_lr / get_lr work correctly.
    #[test]
    fn test_adam_lr_management() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param], 0.001);

        assert!((opt.get_lr() - 0.001).abs() < 1e-8);
        opt.set_lr(0.01);
        assert!((opt.get_lr() - 0.01).abs() < 1e-8);
    }

    /// Verify Adam handles no-grad params gracefully (skips them).
    #[test]
    fn test_adam_skips_frozen_params() {
        let trainable = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            true,
        ));
        let frozen = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![2.0], &[1]).unwrap(),
            false,
        ));

        trainable.set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![trainable.clone(), frozen.clone()], 0.1);
        opt.step();

        // Trainable should change, frozen should not
        assert!((trainable.data().to_vec()[0] - 1.0).abs() > 1e-6);
        assert!((frozen.data().to_vec()[0] - 2.0).abs() < 1e-8);
    }

    /// Verify Adam with weight decay actually decays weights.
    #[test]
    fn test_adam_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![10.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        // Set zero gradient — only weight decay should modify params
        param.set_grad(Tensor::from_vec(vec![0.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1).weight_decay(0.1);
        let before = param.data().to_vec()[0];
        opt.step();
        let after = param.data().to_vec()[0];

        // With weight_decay, even zero gradient should shrink params
        // (grad_effective = grad + wd * param = 0 + 0.1 * 10.0 = 1.0)
        assert!(
            after < before,
            "Weight decay should shrink large params: {} -> {}",
            before,
            after
        );
    }

    /// Verify multiple Adam steps produce improvement on a simple loss using autograd.
    #[test]
    fn test_adam_multiple_steps_improve() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, -2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.05);

        let mut losses = Vec::new();
        for _ in 0..50 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum(); // ||x||^2
            losses.push(loss.data().to_vec()[0]);
            loss.backward();
            opt.step();
        }

        // First loss should be much higher than last loss
        let first = losses[0];
        let last = *losses.last().unwrap();
        assert!(
            last < first * 0.5,
            "Loss should decrease significantly: first={}, last={}",
            first,
            last
        );
    }

    // =========================================================================
    // AdamW Tests
    // =========================================================================

    /// Verify a basic AdamW step: positive grads decrease params, negative grads increase them.
    #[test]
    fn test_adamw_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![5.0, -3.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, -1.0], &[2]).unwrap());

        let mut opt = AdamW::new(vec![param.clone()], 0.01);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        // Positive grad → decrease, negative grad → increase
        assert!(after[0] < before[0], "Positive grad should decrease param");
        assert!(after[1] > before[1], "Negative grad should increase param");
    }

    /// Verify AdamW converges using autograd.
    #[test]
    fn test_adamw_converges() {
        let var = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = AdamW::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            loss.backward();
            opt.step();
        }

        assert!(
            param.data().to_vec()[0].abs() < 0.1,
            "AdamW should converge near 0, got {}",
            param.data().to_vec()[0]
        );
    }
}