//! `Adam` and `AdamW` — adaptive moment estimation optimizers.
//!
//! `Adam::new(params, lr)` with beta1/beta2/epsilon/weight_decay config.
//! `AdamW` variant implements decoupled weight decay (applies decay to
//! parameters directly, not through the gradient). Both track first/second
//! moment estimates per parameter with bias correction.
//!
//! # File
//! `crates/axonml-optim/src/adam.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EDT
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

// =============================================================================
// Adam
// =============================================================================

/// Adam optimizer.
///
/// Maintains per-parameter adaptive learning rates using first and
/// second moment estimates of gradients.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
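///
/// # Example
///
/// A minimal usage sketch (marked `ignore`; it assumes `Adam` and `Optimizer`
/// are re-exported at the crate root, and builds `Parameter`/`Variable` the
/// same way this file's tests do):
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::Parameter;
/// use axonml_optim::{Adam, Optimizer};
/// use axonml_tensor::Tensor;
///
/// let var = Variable::new(Tensor::from_vec(vec![1.0f32, 2.0], &[2]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut opt = Adam::new(vec![param.clone()], 1e-3)
///     .weight_decay(0.01)
///     .amsgrad(true);
///
/// // one training step: clear grads, backward fills them, then update
/// opt.zero_grad();
/// param.set_grad(Tensor::from_vec(vec![0.1f32, 0.2], &[2]).unwrap());
/// opt.step();
/// ```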
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Weight decay (L2 regularization for standard Adam).
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

/// State for Adam optimizer.
///
/// Stores momentum tensors on the same device as parameters (CPU or GPU).
/// When parameters are on GPU, all state stays on GPU — zero CPU round-trips.
#[derive(Debug, Clone)]
struct AdamState {
    /// First moment (mean of gradients) — on same device as param.
    exp_avg: Tensor<f32>,
    /// Second moment (variance of gradients) — on same device as param.
    exp_avg_sq: Tensor<f32>,
    /// Maximum of all past exp_avg_sq values (for AMSGrad).
    max_exp_avg_sq: Option<Tensor<f32>>,
    /// Step count for bias correction.
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            max_exp_avg_sq: None, // Initialized on first use if amsgrad=true
            step: 0,
        }
    }
}

impl Adam {
    /// Creates a new Adam optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates Adam with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates Adam with all options.
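    ///
    /// A sketch of the argument order (illustrative values only):
    ///
    /// ```ignore
    /// // params, lr, (beta1, beta2), eps, weight_decay, amsgrad
    /// let opt = Adam::with_options(params, 1e-3, (0.9, 0.999), 1e-8, 0.01, false);
    /// ```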
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: fused CUDA kernel — single launch per parameter, zero CPU copies
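            // Note: `adam_step_inplace` is not passed `max_exp_avg_sq`, so the
            // AMSGrad max-tracking applies only on the CPU path below; with
            // `amsgrad = true` the GPU path performs the plain Adam update.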
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Auto-migrate gradient to GPU if backward produced CPU gradients
                // (happens when backward functions use CPU fallback computation)
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("Adam: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // In-place fused Adam update on GPU
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                // No need for update_data — the kernel modified the GPU buffer in-place
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
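            // `step_size` folds m_hat's bias correction into the learning rate
            // (lr * m_t / bias_correction1 == step_size * m_t); v_hat's correction
            // is applied per element inside the loop below.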
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            // AMSGrad: track max of all past exp_avg_sq values
            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map_or_else(|| vec![0.0f32; param_vec.len()], |t| t.to_vec())
            } else {
                Vec::new()
            };

            for i in 0..param_vec.len() {
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;

                let v_hat = if self.amsgrad {
                    max_sq_vec[i] = max_sq_vec[i].max(exp_avg_sq_vec[i]);
                    max_sq_vec[i] / bias_correction2
                } else {
                    exp_avg_sq_vec[i] / bias_correction2
                };

                let denom = v_hat.sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(
                    Tensor::from_vec(max_sq_vec, param_data.shape())
                        .expect("tensor creation failed"),
                );
            }
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// AdamW
// =============================================================================

/// `AdamW` optimizer (Adam with decoupled weight decay).
///
/// Unlike standard Adam which applies L2 regularization to the gradient,
/// `AdamW` applies weight decay directly to the parameters.
///
/// Update rule:
/// ```text
/// m_t = beta1 * m_{t-1} + (1 - beta1) * grad
/// v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
/// m_hat = m_t / (1 - beta1^t)
/// v_hat = v_t / (1 - beta2^t)
/// param = param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)
/// ```
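///
/// # Example
///
/// A minimal usage sketch (marked `ignore`; assumes the same crate-root
/// re-exports as the `Adam` example above):
///
/// ```ignore
/// use axonml_optim::{AdamW, Optimizer};
///
/// // weight_decay defaults to 0.01 for AdamW; override it via the builder
/// let mut opt = AdamW::new(params, 1e-3).weight_decay(0.05);
/// opt.zero_grad();
/// // ... backward pass populates gradients ...
/// opt.step();
/// ```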
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// First moment decay rate.
    beta1: f32,
    /// Second moment decay rate.
    beta2: f32,
    /// Small constant for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient.
    weight_decay: f32,
    /// Whether to use `AMSGrad` variant.
    amsgrad: bool,
    /// Per-parameter state.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates a new `AdamW` optimizer with default hyperparameters.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates `AdamW` with specified betas.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01, // Default weight decay for AdamW
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates `AdamW` with all options.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Builder method to set betas.
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Builder method to set epsilon.
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Builder method to set weight decay.
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to enable `AMSGrad`.
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // GPU path: decoupled weight decay + fused Adam step
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Auto-migrate gradient to GPU if backward produced CPU gradients
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("AdamW: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };

                // DECOUPLED weight decay: param *= (1 - lr * wd)
                // This is the key difference from Adam's L2 regularization.
                // Applied BEFORE the Adam update, directly to parameters.
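                // e.g. lr = 1e-3, weight_decay = 0.01 gives decay_factor = 0.99999:
                // a small multiplicative shrink each step, independent of the gradient.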
                if self.weight_decay > 0.0 {
                    let decay_factor = 1.0 - self.lr * self.weight_decay;
                    let decayed = param_data.mul_scalar(decay_factor);
                    param.update_data(decayed);
                }

                // Re-read param_data after potential decay update
                let param_data = param.data();

                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                // Adam step with wd=0 (decay already applied above)
                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    0.0, // wd=0: decoupled decay already applied
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback — fused single-loop update for cache locality
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

            // AMSGrad: running max of exp_avg_sq (mirrors the Adam path above)
            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map_or_else(|| vec![0.0f32; param_vec.len()], |t| t.to_vec())
            } else {
                Vec::new()
            };

            for i in 0..param_vec.len() {
                // Decoupled weight decay: apply directly to param
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;

                // AMSGrad: use the running max of the second moment in the denominator
                let v = if self.amsgrad {
                    max_sq_vec[i] = max_sq_vec[i].max(exp_avg_sq_vec[i]);
                    max_sq_vec[i]
                } else {
                    exp_avg_sq_vec[i]
                };
                let denom = (v / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(
                    Tensor::from_vec(max_sq_vec, param_data.shape())
                        .expect("tensor creation failed"),
                );
            }
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        // Set gradient
        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        // Parameters should have changed
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }

    // =========================================================================
    // Adam Step Correctness Tests
    // =========================================================================

    /// Verify Adam update matches the mathematical formula exactly.
    /// After one step with grad=[1,1], lr=0.1, betas=(0.9,0.999):
    ///   m = 0.1*[1,1], v = 0.001*[1,1]
    ///   m_hat = m/0.1 = [1,1], v_hat = v/0.001 = [1,1]
    ///   param -= 0.1 * 1.0 / (1.0 + 1e-8) ≈ 0.1
    #[test]
    fn test_adam_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![0.5, -0.3], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        // Both params should decrease (positive gradient → decrease)
        assert!(
            after[0] < before[0],
            "param[0] should decrease: {} -> {}",
            before[0],
            after[0]
        );
        assert!(
            after[1] < before[1],
            "param[1] should decrease: {} -> {}",
            before[1],
            after[1]
        );

        // After one Adam step with uniform gradient, both should change by the same amount
        let delta0 = before[0] - after[0];
        let delta1 = before[1] - after[1];
        assert!(
            (delta0 - delta1).abs() < 1e-6,
            "Uniform gradient should produce uniform update: {} vs {}",
            delta0,
            delta1
        );

        // And the step size should match the closed form above: ≈ lr = 0.1
        assert!(
            (delta0 - 0.1).abs() < 1e-5,
            "First-step update should be ≈ lr: got {}",
            delta0
        );
    }

    /// Verify Adam converges on a simple quadratic: minimize f(x) = x^2.
    /// Uses autograd for proper gradient computation.
    #[test]
    fn test_adam_converges_on_quadratic() {
        let var = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            // f(x) = x^2 → loss, compute gradient via autograd
            let x = param.variable();
            let loss = x.mul_var(&x).sum(); // x^2
            loss.backward();
            opt.step();
        }

        let final_x = param.data().to_vec()[0];
        assert!(
            final_x.abs() < 0.1,
            "Adam should converge near 0 for f(x)=x^2, got {}",
            final_x
        );
    }

    /// Verify zero_grad actually clears all gradients.
    #[test]
    fn test_adam_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap());
        assert!(param.grad().is_some());

        let mut opt = Adam::new(vec![param.clone()], 0.01);
        opt.zero_grad();
        // After zero_grad, gradient should be None or all zeros
        if let Some(g) = param.grad() {
            let gv = g.to_vec();
            assert!(
                gv.iter().all(|&v| v.abs() < 1e-10),
                "Gradients should be zero after zero_grad: {:?}",
                gv
            );
        }
    }

    /// Verify set_lr / get_lr work correctly.
    #[test]
    fn test_adam_lr_management() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param], 0.001);

        assert!((opt.get_lr() - 0.001).abs() < 1e-8);
        opt.set_lr(0.01);
        assert!((opt.get_lr() - 0.01).abs() < 1e-8);
    }

    /// Verify Adam handles no-grad params gracefully (skips them).
    #[test]
    fn test_adam_skips_frozen_params() {
        let trainable = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            true,
        ));
        let frozen = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![2.0], &[1]).unwrap(),
            false,
        ));

        trainable.set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![trainable.clone(), frozen.clone()], 0.1);
        opt.step();

        // Trainable should change, frozen should not
        assert!((trainable.data().to_vec()[0] - 1.0).abs() > 1e-6);
        assert!((frozen.data().to_vec()[0] - 2.0).abs() < 1e-8);
    }

    /// Verify Adam with weight decay actually decays weights.
    #[test]
    fn test_adam_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![10.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        // Set zero gradient — only weight decay should modify params
        param.set_grad(Tensor::from_vec(vec![0.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1).weight_decay(0.1);
        let before = param.data().to_vec()[0];
        opt.step();
        let after = param.data().to_vec()[0];

        // With weight_decay, even zero gradient should shrink params
        // (grad_effective = grad + wd * param = 0 + 0.1 * 10.0 = 1.0)
        assert!(
            after < before,
            "Weight decay should shrink large params: {} -> {}",
            before,
            after
        );
    }

    /// Verify multiple Adam steps produce improvement on a simple loss using autograd.
    #[test]
    fn test_adam_multiple_steps_improve() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, -2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.05);

        let mut losses = Vec::new();
        for _ in 0..50 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum(); // ||x||^2
            losses.push(loss.data().to_vec()[0]);
            loss.backward();
            opt.step();
        }

        // First loss should be much higher than last loss
        let first = losses[0];
        let last = *losses.last().unwrap();
        assert!(
            last < first * 0.5,
            "Loss should decrease significantly: first={}, last={}",
            first,
            last
        );
    }

    // =========================================================================
    // AdamW Tests
    // =========================================================================

    /// Verify the AdamW step moves each parameter against the sign of its gradient.
    #[test]
    fn test_adamw_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![5.0, -3.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, -1.0], &[2]).unwrap());

        let mut opt = AdamW::new(vec![param.clone()], 0.01);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        // Positive grad → decrease, negative grad → increase
        assert!(after[0] < before[0], "Positive grad should decrease param");
        assert!(after[1] > before[1], "Negative grad should increase param");
    }

    /// Verify AdamW converges using autograd.
    #[test]
    fn test_adamw_converges() {
        let var = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = AdamW::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            loss.backward();
            opt.step();
        }

        assert!(
            param.data().to_vec()[0].abs() < 0.1,
            "AdamW should converge near 0, got {}",
            param.data().to_vec()[0]
        );
    }
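
    /// A sketch test for the `AMSGrad` bookkeeping (it reads the parent
    /// module's private `state` field, which a child `tests` module may do):
    /// after one step with `amsgrad(true)`, `max_exp_avg_sq` must be populated.
    #[test]
    fn test_adam_amsgrad_tracks_max_state() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, -1.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1).amsgrad(true);
        opt.step();

        // max_exp_avg_sq is created lazily, so it must exist after one step
        let max_sq = opt.state[0]
            .max_exp_avg_sq
            .as_ref()
            .expect("AMSGrad state should be populated after step");
        // first v = (1 - beta2) * g^2 > 0, so every running-max entry is positive
        assert!(max_sq.to_vec().iter().all(|&v| v > 0.0));
    }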
}