Skip to main content

entrenar/optim/convergence_tests/
adamw_tests.rs

//! AdamW optimizer convergence tests

#[cfg(test)]
mod tests {
    use super::super::helpers::*;
    use crate::optim::*;
    use crate::Tensor;
    use proptest::prelude::*;
    use proptest::test_runner::Config;

    /// Returns the Euclidean (L2) norm of `param`'s current values.
    ///
    /// Shared by the deterministic tests so they compare the same quantity.
    fn l2_norm(param: &Tensor) -> f32 {
        param.data().iter().map(|&x| x * x).sum::<f32>().sqrt()
    }

    // Property tests that run under proptest's default configuration
    // (no `proptest_config` is set for this group).
    proptest! {
        // For every learning rate in [0.05, 0.5) the shared quadratic-convergence
        // helper must report success.
        // NOTE(review): `100` and `1.5` are presumably the iteration budget and the
        // tolerance -- confirm against `helpers::test_quadratic_convergence`.
        #[test]
        fn prop_adamw_converges_quadratic(
            lr in 0.05f32..0.5
        ) {
            let optimizer = AdamW::default_params(lr);
            prop_assert!(test_quadratic_convergence(optimizer, 100, 1.5));
        }

        // For every learning rate in [0.01, 0.3) the shared loss-decrease helper
        // must report success.
        // NOTE(review): `30` is presumably the number of optimizer steps -- confirm
        // against `helpers::test_loss_decreases`.
        #[test]
        fn prop_adamw_loss_decreases(
            lr in 0.01f32..0.3
        ) {
            let optimizer = AdamW::default_params(lr);
            prop_assert!(test_loss_decreases(optimizer, 30));
        }
    }

    /// AdamW (weight decay 0.01) and Adam are driven with the same constant gradient
    /// from the same starting point; AdamW must finish with the smaller L2 norm.
    #[test]
    fn test_adamw_weight_decay_effect() {
        // Identical starting parameters for both optimizers.
        let mut params_adamw = vec![Tensor::from_vec(vec![2.0, 2.0], true)];
        let mut params_adam = vec![Tensor::from_vec(vec![2.0, 2.0], true)];

        // Same hyperparameters; the only difference is AdamW's extra weight-decay term.
        let mut adamw = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.01);
        let mut adam = Adam::new(0.1, 0.9, 0.999, 1e-8);

        for _ in 0..50 {
            // Same small gradient for both
            let grad = ndarray::arr1(&[0.1, 0.1]);
            params_adamw[0].set_grad(grad.clone());
            params_adam[0].set_grad(grad);

            adamw.step(&mut params_adamw);
            adam.step(&mut params_adam);
        }

        // AdamW should have smaller weights due to weight decay
        let adamw_norm = l2_norm(&params_adamw[0]);
        let adam_norm = l2_norm(&params_adam[0]);

        // Report both values so a failure is diagnosable without re-running.
        assert!(
            adamw_norm < adam_norm,
            "expected AdamW weight norm ({}) to be smaller than Adam weight norm ({})",
            adamw_norm,
            adam_norm
        );
    }

    // ========================================================================
    // EXTENDED PROPERTY TESTS - High iteration counts for quality validation
    // ========================================================================

    proptest! {
        // Each property in this group is exercised with 1000 generated cases.
        #![proptest_config(Config::with_cases(1000))]

        // For every (lr, weight_decay) in [0.05, 0.2) x [0.0, 0.05) the shared
        // ill-conditioned convergence helper must report success.
        // NOTE(review): `300` and `10.0` are presumably the iteration budget and the
        // tolerance -- confirm against `helpers::test_ill_conditioned_convergence`.
        #[test]
        fn prop_adamw_ill_conditioned(
            lr in 0.05f32..0.2,
            weight_decay in 0.0f32..0.05
        ) {
            let optimizer = AdamW::new(lr, 0.9, 0.999, 1e-8, weight_decay);
            prop_assert!(test_ill_conditioned_convergence(optimizer, 300, 10.0));
        }

        // For every (lr, weight_decay) in [0.001, 0.5) x [0.0, 0.5) the shared
        // large-gradient stability helper must report success.
        #[test]
        fn prop_numerical_stability_adamw(
            lr in 0.001f32..0.5,
            weight_decay in 0.0f32..0.5
        ) {
            let optimizer = AdamW::new(lr, 0.9, 0.999, 1e-8, weight_decay);
            prop_assert!(test_large_gradient_stability(optimizer));
        }
    }

    // ========================================================================
    // DETERMINISTIC CONVERGENCE TESTS
    // ========================================================================

    /// Two AdamW optimizers that differ only in weight decay (0.1 vs 0.001) are driven
    /// with the same constant gradient; the stronger decay must yield the smaller L2 norm.
    #[test]
    fn test_adamw_regularization_strength() {
        // Higher weight decay = smaller final weights
        let mut params_high = vec![Tensor::from_vec(vec![5.0, 5.0], true)];
        let mut params_low = vec![Tensor::from_vec(vec![5.0, 5.0], true)];

        let mut opt_high = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.1);
        let mut opt_low = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.001);

        for _ in 0..100 {
            // Constant small gradient
            let grad = ndarray::arr1(&[0.01, 0.01]);
            params_high[0].set_grad(grad.clone());
            params_low[0].set_grad(grad);
            opt_high.step(&mut params_high);
            opt_low.step(&mut params_low);
        }

        // True L2 norms (with the square root), matching `test_adamw_weight_decay_effect`.
        let norm_high = l2_norm(&params_high[0]);
        let norm_low = l2_norm(&params_low[0]);

        // Report both values so a failure is diagnosable without re-running.
        assert!(
            norm_high < norm_low,
            "expected norm with weight decay 0.1 ({}) to be smaller than with 0.001 ({})",
            norm_high,
            norm_low
        );
    }
}