Skip to main content

entrenar/optim/
adamw.rs

//! AdamW optimizer (Adam with decoupled weight decay)
2
3use super::Optimizer;
4use crate::Tensor;
5use ndarray::Array1;
6use provable_contracts_macros::requires;
7
8/// AdamW optimizer
9///
10/// AdamW decouples weight decay from the gradient-based update, making it more
11/// effective than L2 regularization. Instead of adding weight decay to the gradient,
12/// it applies weight decay directly to the parameters.
13///
14/// Standard Adam with L2: θ_t = θ_{t-1} - lr * (m_t / (√v_t + ε) + λ * θ_{t-1})
15/// AdamW: θ_t = (1 - lr * λ) * θ_{t-1} - lr * m_t / (√v_t + ε)
16pub struct AdamW {
17    lr: f32,
18    beta1: f32,
19    beta2: f32,
20    epsilon: f32,
21    weight_decay: f32,
22    t: u64,
23    m: Vec<Option<Array1<f32>>>, // First moment
24    v: Vec<Option<Array1<f32>>>, // Second moment
25}
26
27impl AdamW {
28    /// Create a new AdamW optimizer
29    #[allow(clippy::manual_range_contains)]
30    #[requires(lr > 0.0 && beta1 >= 0.0 && beta1 < 1.0 && beta2 >= 0.0 && beta2 < 1.0 && epsilon > 0.0 && weight_decay >= 0.0)]
31    pub fn new(lr: f32, beta1: f32, beta2: f32, epsilon: f32, weight_decay: f32) -> Self {
32        Self { lr, beta1, beta2, epsilon, weight_decay, t: 0, m: Vec::new(), v: Vec::new() }
33    }
34
35    /// Create AdamW with default parameters (weight_decay = 0.01)
36    pub fn default_params(lr: f32) -> Self {
37        Self::new(lr, 0.9, 0.999, 1e-8, 0.01)
38    }
39
40    /// Initialize moments if needed
41    fn ensure_moments(&mut self, params: &[Tensor]) {
42        if self.m.is_empty() {
43            self.m = params.iter().map(|_| None).collect();
44            self.v = params.iter().map(|_| None).collect();
45        }
46    }
47
48    // ── Checkpoint state accessors (F-CKPT-004) ────────────────────────
49
50    /// Get optimizer step counter.
51    #[must_use]
52    pub fn step_count(&self) -> u64 {
53        self.t
54    }
55
56    /// Set optimizer step counter (for checkpoint resume).
57    pub fn set_step_count(&mut self, t: u64) {
58        self.t = t;
59    }
60
61    /// Get first moment buffers (m) as f32 slices.
62    #[must_use]
63    pub fn first_moments(&self) -> &[Option<Array1<f32>>] {
64        &self.m
65    }
66
67    /// Get second moment buffers (v) as f32 slices.
68    #[must_use]
69    pub fn second_moments(&self) -> &[Option<Array1<f32>>] {
70        &self.v
71    }
72
73    /// Set first moment buffer at index.
74    pub fn set_first_moment(&mut self, idx: usize, data: Array1<f32>) {
75        if idx >= self.m.len() {
76            self.m.resize(idx + 1, None);
77        }
78        self.m[idx] = Some(data);
79    }
80
81    /// Set second moment buffer at index.
82    pub fn set_second_moment(&mut self, idx: usize, data: Array1<f32>) {
83        if idx >= self.v.len() {
84            self.v.resize(idx + 1, None);
85        }
86        self.v[idx] = Some(data);
87    }
88
89    /// Get beta1 hyperparameter.
90    #[must_use]
91    pub fn beta1(&self) -> f32 {
92        self.beta1
93    }
94
95    /// Get beta2 hyperparameter.
96    #[must_use]
97    pub fn beta2(&self) -> f32 {
98        self.beta2
99    }
100
101    /// Get weight decay hyperparameter.
102    #[must_use]
103    pub fn weight_decay(&self) -> f32 {
104        self.weight_decay
105    }
106}
107
108impl Optimizer for AdamW {
109    #[requires(!params.is_empty())]
110    fn step(&mut self, params: &mut [Tensor]) {
111        self.ensure_moments(params);
112        self.t += 1;
113
114        // Bias adjust factors
115        let lr_t = self.lr
116            * ((1.0 - self.beta2.powi(self.t as i32)).sqrt()
117                / (1.0 - self.beta1.powi(self.t as i32)));
118
119        for (i, param) in params.iter_mut().enumerate() {
120            if let Some(grad) = param.grad() {
121                // Use SIMD for large tensors (>= 16 elements for meaningful speedup)
122                if grad.len() >= 16 {
123                    // Initialize moments if needed
124                    if self.m[i].is_none() {
125                        self.m[i] = Some(Array1::zeros(grad.len()));
126                        self.v[i] = Some(Array1::zeros(grad.len()));
127                    }
128
129                    let m = self.m[i].as_mut().expect("momentum buffer initialized above");
130                    let v = self.v[i].as_mut().expect("velocity buffer initialized above");
131
132                    // Get mutable slices (arrays are always contiguous)
133                    let grad_slice = grad.as_slice().expect("grad array is contiguous");
134                    let m_slice = m.as_slice_mut().expect("momentum array is contiguous");
135                    let v_slice = v.as_slice_mut().expect("velocity array is contiguous");
136                    let param_slice =
137                        param.data_mut().as_slice_mut().expect("param array is contiguous");
138
139                    // Use SIMD-accelerated update
140                    super::simd::simd_adamw_update(
141                        grad_slice,
142                        m_slice,
143                        v_slice,
144                        param_slice,
145                        self.beta1,
146                        self.beta2,
147                        self.lr,
148                        lr_t,
149                        self.weight_decay,
150                        self.epsilon,
151                    );
152                } else {
153                    // Fallback to scalar implementation for small tensors
154                    // m_t = β1 * m_{t-1} + (1 - β1) * g
155                    let m_t = if let Some(m) = &self.m[i] {
156                        m * self.beta1 + &grad * (1.0 - self.beta1)
157                    } else {
158                        &grad * (1.0 - self.beta1)
159                    };
160
161                    // v_t = β2 * v_{t-1} + (1 - β2) * g²
162                    let grad_sq = &grad * &grad;
163                    let v_t = if let Some(v) = &self.v[i] {
164                        v * self.beta2 + &grad_sq * (1.0 - self.beta2)
165                    } else {
166                        &grad_sq * (1.0 - self.beta2)
167                    };
168
169                    // AdamW update with decoupled weight decay:
170                    // θ_t = (1 - lr * λ) * θ_{t-1} - lr_t * m_t / (√v_t + ε)
171                    let adaptive_update = &m_t / &(v_t.mapv(f32::sqrt) + self.epsilon) * lr_t;
172
173                    // Apply weight decay directly to parameters (decoupled)
174                    let weight_decay_factor = 1.0 - self.lr * self.weight_decay;
175                    *param.data_mut() = param.data() * weight_decay_factor - &adaptive_update;
176
177                    self.m[i] = Some(m_t);
178                    self.v[i] = Some(v_t);
179                }
180            }
181        }
182    }
183
184    fn step_refs(&mut self, params: &mut [&mut Tensor]) {
185        contract_pre_weight_update!();
186        // Ensure moments are sized correctly
187        if self.m.len() < params.len() {
188            self.m.resize(params.len(), None);
189            self.v.resize(params.len(), None);
190        }
191        self.t += 1;
192
193        // Bias adjust factors
194        let lr_t = self.lr
195            * ((1.0 - self.beta2.powi(self.t as i32)).sqrt()
196                / (1.0 - self.beta1.powi(self.t as i32)));
197
198        for (i, param) in params.iter_mut().enumerate() {
199            if let Some(grad) = param.grad() {
200                // Use SIMD for large tensors (>= 16 elements for meaningful speedup)
201                if grad.len() >= 16 {
202                    // Initialize moments if needed
203                    if self.m[i].is_none() {
204                        self.m[i] = Some(Array1::zeros(grad.len()));
205                        self.v[i] = Some(Array1::zeros(grad.len()));
206                    }
207
208                    let m = self.m[i].as_mut().expect("momentum buffer initialized above");
209                    let v = self.v[i].as_mut().expect("velocity buffer initialized above");
210
211                    // Get mutable slices (arrays are always contiguous)
212                    let grad_slice = grad.as_slice().expect("grad array is contiguous");
213                    let m_slice = m.as_slice_mut().expect("momentum array is contiguous");
214                    let v_slice = v.as_slice_mut().expect("velocity array is contiguous");
215                    let param_slice =
216                        param.data_mut().as_slice_mut().expect("param array is contiguous");
217
218                    // Use SIMD-accelerated update
219                    super::simd::simd_adamw_update(
220                        grad_slice,
221                        m_slice,
222                        v_slice,
223                        param_slice,
224                        self.beta1,
225                        self.beta2,
226                        self.lr,
227                        lr_t,
228                        self.weight_decay,
229                        self.epsilon,
230                    );
231                } else {
232                    // Fallback to scalar implementation for small tensors
233                    let m_t = if let Some(m) = &self.m[i] {
234                        m * self.beta1 + &grad * (1.0 - self.beta1)
235                    } else {
236                        &grad * (1.0 - self.beta1)
237                    };
238
239                    let grad_sq = &grad * &grad;
240                    let v_t = if let Some(v) = &self.v[i] {
241                        v * self.beta2 + &grad_sq * (1.0 - self.beta2)
242                    } else {
243                        &grad_sq * (1.0 - self.beta2)
244                    };
245
246                    let adaptive_update = &m_t / &(v_t.mapv(f32::sqrt) + self.epsilon) * lr_t;
247                    let weight_decay_factor = 1.0 - self.lr * self.weight_decay;
248                    *param.data_mut() = param.data() * weight_decay_factor - &adaptive_update;
249
250                    self.m[i] = Some(m_t);
251                    self.v[i] = Some(v_t);
252                }
253            }
254        }
255    }
256
257    fn lr(&self) -> f32 {
258        self.lr
259    }
260
261    fn set_lr(&mut self, lr: f32) {
262        self.lr = lr;
263    }
264}
265
/// Unit, falsification, and property tests for the AdamW optimizer.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_adamw_quadratic_convergence() {
        // Test convergence on f(x) = x²
        // (3 elements < 16, so this exercises the scalar fallback path)
        let mut params = vec![Tensor::from_vec(vec![5.0, -3.0, 2.0], true)];
        let mut optimizer = AdamW::default_params(0.1);

        for _ in 0..100 {
            // Compute gradient: ∇(x²) = 2x
            let grad = params[0].data().mapv(|x| 2.0 * x);
            params[0].set_grad(grad);

            optimizer.step(&mut params);
        }

        // Should converge close to 0
        for &val in params[0].data() {
            assert!(val.abs() < 0.5, "Value {val} did not converge");
        }
    }

    #[test]
    fn test_adamw_weight_decay() {
        // Test that weight decay is properly applied
        let mut params = vec![Tensor::from_vec(vec![1.0], true)];
        let mut optimizer = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.1);

        // Zero gradient - only weight decay should apply
        // (m_t = v_t = 0, so the adaptive term is 0 / (0 + ε) = 0)
        let grad = ndarray::arr1(&[0.0]);
        params[0].set_grad(grad);

        let initial_value = params[0].data()[0];
        optimizer.step(&mut params);
        let after_step = params[0].data()[0];

        // With zero gradient, weight decay should reduce the parameter
        // θ_t = (1 - lr * λ) * θ_{t-1} = (1 - 0.1 * 0.1) * 1.0 = 0.99
        assert!(after_step < initial_value);
        assert_abs_diff_eq!(after_step, 0.99, epsilon = 1e-6);
    }

    #[test]
    fn test_adamw_vs_adam_difference() {
        // AdamW and Adam should behave differently with weight decay
        // NOTE(review): relies on `Adam::default_params` applying no weight
        // decay — confirm against the Adam implementation.
        let mut params_adamw = vec![Tensor::from_vec(vec![2.0, -2.0], true)];
        let mut params_adam = vec![Tensor::from_vec(vec![2.0, -2.0], true)];

        let mut adamw = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.1);
        let mut adam = super::super::Adam::default_params(0.1);

        for _ in 0..10 {
            // Same gradient for both
            let grad = ndarray::arr1(&[1.0, -1.0]);

            params_adamw[0].set_grad(grad.clone());
            params_adam[0].set_grad(grad.clone());

            adamw.step(&mut params_adamw);
            adam.step(&mut params_adam);
        }

        // With weight decay, AdamW should have smaller absolute values
        // (weight decay shrinks parameters toward zero)
        assert!(params_adamw[0].data()[0].abs() < params_adam[0].data()[0].abs());
        assert!(params_adamw[0].data()[1].abs() < params_adam[0].data()[1].abs());
    }

    // =========================================================================
    // Additional Coverage Tests
    // =========================================================================

    #[test]
    fn test_adamw_simd_path() {
        // Test with >= 16 elements to exercise SIMD path
        let data: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut params = vec![Tensor::from_vec(data, true)];
        let mut optimizer = AdamW::default_params(0.01);

        for _ in 0..10 {
            let grad = params[0].data().mapv(|x| 2.0 * x);
            params[0].set_grad(grad);
            optimizer.step(&mut params);
        }

        // Just verify it runs without panic
        assert_eq!(params[0].data().len(), 32);
    }

    #[test]
    fn test_adamw_simd_convergence() {
        // Test convergence with SIMD path (32 elements)
        let data: Vec<f32> = (0..32).map(|i| (i as f32) - 16.0).collect();
        let mut params = vec![Tensor::from_vec(data.clone(), true)];
        let mut optimizer = AdamW::default_params(0.1);

        // Mean absolute value is the progress metric (target is 0).
        let initial_mean: f32 = data.iter().map(|x| x.abs()).sum::<f32>() / 32.0;
        for _ in 0..100 {
            let grad = params[0].data().mapv(|x| 2.0 * x);
            params[0].set_grad(grad);
            optimizer.step(&mut params);
        }

        // Should make progress toward 0
        let final_mean: f32 = params[0].data().iter().map(|x| x.abs()).sum::<f32>() / 32.0;
        assert!(final_mean < initial_mean, "Mean {final_mean} did not improve from {initial_mean}");
    }

    #[test]
    fn test_adamw_lr_getter_setter() {
        let mut optimizer = AdamW::default_params(0.1);
        assert_abs_diff_eq!(optimizer.lr(), 0.1, epsilon = 1e-6);

        optimizer.set_lr(0.01);
        assert_abs_diff_eq!(optimizer.lr(), 0.01, epsilon = 1e-6);
    }

    #[test]
    fn test_adamw_multiple_params() {
        // Each tensor gets its own moment slot (index 0 and 1).
        let mut params =
            vec![Tensor::from_vec(vec![1.0, 2.0], true), Tensor::from_vec(vec![3.0, 4.0], true)];
        let mut optimizer = AdamW::default_params(0.1);

        // Set gradients for both
        params[0].set_grad(ndarray::arr1(&[0.1, 0.2]));
        params[1].set_grad(ndarray::arr1(&[0.3, 0.4]));

        optimizer.step(&mut params);

        // Both params should be updated
        assert!(params[0].data()[0] < 1.0);
        assert!(params[1].data()[0] < 3.0);
    }

    #[test]
    fn test_adamw_no_grad() {
        // Assumes `grad()` is None for a tensor built with requires_grad=false,
        // so the optimizer skips it entirely.
        let mut params = vec![Tensor::from_vec(vec![1.0, 2.0], false)]; // requires_grad=false
        let mut optimizer = AdamW::default_params(0.1);

        let initial = params[0].data().clone();
        optimizer.step(&mut params);

        // No gradient, so params unchanged
        assert_eq!(params[0].data(), &initial);
    }

    #[test]
    fn test_adamw_momentum_accumulation() {
        let mut params = vec![Tensor::from_vec(vec![5.0], true)];
        let mut optimizer = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.0); // No weight decay

        let initial = params[0].data()[0];
        // Multiple steps with same gradient should accumulate momentum
        for _ in 0..5 {
            params[0].set_grad(ndarray::arr1(&[1.0]));
            optimizer.step(&mut params);
        }

        // Should have moved due to gradient (direction depends on sign)
        assert!(params[0].data()[0] != initial, "Parameter did not change");
    }

    #[test]
    fn test_adamw_simd_multiple_steps() {
        // Test multiple steps with SIMD to cover momentum accumulation
        // (20 elements >= 16, so the SIMD kernel handles every step)
        let data: Vec<f32> = vec![1.0; 20];
        let mut params = vec![Tensor::from_vec(data, true)];
        let mut optimizer = AdamW::default_params(0.1);

        for step in 0..5 {
            let grad = params[0].data().mapv(|_| 1.0);
            params[0].set_grad(grad);
            optimizer.step(&mut params);

            // Verify progress
            // (each step must push below a threshold that drops by 0.05 per step)
            assert!(
                params[0].data()[0] < 1.0 - (step as f32 * 0.05),
                "Step {step} did not make progress"
            );
        }
    }

    #[test]
    fn test_adamw_zero_weight_decay() {
        let mut params = vec![Tensor::from_vec(vec![1.0], true)];
        let mut optimizer = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.0); // Zero weight decay

        // Zero gradient
        params[0].set_grad(ndarray::arr1(&[0.0]));
        let initial = params[0].data()[0];
        optimizer.step(&mut params);

        // With zero gradient and zero weight decay, param should be unchanged
        assert_abs_diff_eq!(params[0].data()[0], initial, epsilon = 1e-6);
    }

    #[test]
    fn test_adamw_bias_adjust() {
        // Test that bias adjust is applied correctly
        let mut params = vec![Tensor::from_vec(vec![0.0], true)];
        let mut optimizer = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.0);

        // First step should have large bias adjust
        // At t = 1: lr_t = lr·√(1-β2)/(1-β1), m_1 = (1-β1)·g, v_1 = (1-β2)·g², so
        // lr_t·m_1/(√v_1 + ε) ≈ lr·sign(g), i.e. a step of about 0.1.
        params[0].set_grad(ndarray::arr1(&[1.0]));
        optimizer.step(&mut params);
        let after_first = params[0].data()[0];

        // Step size should be close to lr due to bias adjust
        assert!(after_first.abs() > 0.05, "Bias adjust not applied");
    }

    // =========================================================================
    // FALSIFY-AW: adamw-kernel-v1.yaml contract (entrenar AdamW)
    //
    // Five-Whys (PMAT-354):
    //   Why 1: entrenar had 11 AdamW tests but zero FALSIFY-AW-* tests
    //   Why 2: tests verify convergence/params, not optimizer invariants
    //   Why 3: no mapping from adamw-kernel-v1.yaml to entrenar test names
    //   Why 4: entrenar predates the provable-contracts YAML convention
    //   Why 5: AdamW was "obviously correct" (standard implementation)
    //
    // References:
    //   - provable-contracts/contracts/adamw-kernel-v1.yaml
    //   - Loshchilov & Hutter (2019) "Decoupled Weight Decay Regularization"
    // =========================================================================

    /// FALSIFY-AW-002e: Second moment non-negativity
    #[test]
    fn falsify_aw_002e_second_moment_non_negative() {
        let mut params = vec![Tensor::from_vec(vec![5.0, -3.0, 2.0, -1.0], true)];
        let mut optimizer = AdamW::default_params(0.01);

        // Deterministic, sign-changing pseudo-random gradients.
        for step in 0..50 {
            let grad = params[0].data().mapv(|x| ((x + step as f32) * 0.37).sin() * 5.0);
            params[0].set_grad(grad);
            optimizer.step(&mut params);
        }

        // Check v (second moment) is non-negative
        // (reads the private `v` field directly; allocated slots only via flatten)
        for v_arr in optimizer.v.iter().flatten() {
            for (j, &v_val) in v_arr.iter().enumerate() {
                assert!(v_val >= 0.0, "FALSIFIED AW-002e: v[{j}] = {v_val} < 0 after 50 steps");
            }
        }
    }

    /// FALSIFY-AW-003e: Bias adjust factor > 1
    #[test]
    fn falsify_aw_003e_bias_adjust() {
        for &beta in &[0.9_f32, 0.99, 0.999] {
            for t in 1..=100i32 {
                let adjust = 1.0 / (1.0 - beta.powi(t));
                assert!(adjust > 1.0, "FALSIFIED AW-003e: 1/(1-{beta}^{t}) = {adjust} not > 1");
            }
        }
    }

    /// FALSIFY-AW-004e: Update finiteness with extreme values
    #[test]
    fn falsify_aw_004e_update_finiteness() {
        // Mixes very large (1e6) and very small (1e-6) magnitudes in one tensor.
        let mut params = vec![Tensor::from_vec(vec![1e6, -1e6, 1e-6, -1e-6], true)];
        let mut optimizer = AdamW::default_params(0.001);

        let grad = params[0].data().mapv(|x| 2.0 * x);
        params[0].set_grad(grad);
        optimizer.step(&mut params);

        for (i, &val) in params[0].data().iter().enumerate() {
            assert!(val.is_finite(), "FALSIFIED AW-004e: param[{i}] = {val} (not finite)");
        }
    }

    /// FALSIFY-AW-006e: Zero gradient — only weight decay modifies theta
    #[test]
    fn falsify_aw_006e_zero_gradient_weight_decay_only() {
        let init_vals = vec![5.0, -3.0, 2.0];
        let mut params = vec![Tensor::from_vec(init_vals.clone(), true)];
        let lr = 0.01;
        let wd = 0.1;
        let mut optimizer = AdamW::new(lr, 0.9, 0.999, 1e-8, wd);

        // Set zero gradient
        params[0].set_grad(ndarray::Array1::zeros(3));
        optimizer.step(&mut params);

        // With zero gradient, only weight decay: theta_new ≈ theta * (1 - lr*wd)
        let factor = 1.0 - lr * wd;
        for (i, (&val, &init)) in params[0].data().iter().zip(init_vals.iter()).enumerate() {
            let expected = init * factor;
            let diff = (val - expected).abs();
            assert!(
                diff < 1e-4,
                "FALSIFIED AW-006e: param[{i}] = {val}, expected {expected} (only wd)"
            );
        }
    }

    #[test]
    fn test_adamw_checkpoint_accessors() {
        let mut opt = AdamW::default_params(0.01);
        assert_eq!(opt.step_count(), 0);
        opt.set_step_count(42);
        assert_eq!(opt.step_count(), 42);
        assert_eq!(opt.beta1(), 0.9);
        assert_eq!(opt.beta2(), 0.999);
        assert!((opt.weight_decay() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adamw_moment_set_get() {
        let mut opt = AdamW::default_params(0.01);
        // Initially empty
        assert!(opt.first_moments().is_empty());
        assert!(opt.second_moments().is_empty());
        // Set at index 0
        opt.set_first_moment(0, ndarray::arr1(&[1.0, 2.0]));
        opt.set_second_moment(0, ndarray::arr1(&[0.5, 0.5]));
        assert_eq!(opt.first_moments().len(), 1);
        assert_eq!(opt.second_moments().len(), 1);
        // Set at index 3 (should resize)
        // (the gap slots 1..=2 are filled with None)
        opt.set_first_moment(3, ndarray::arr1(&[3.0]));
        assert_eq!(opt.first_moments().len(), 4);
        assert!(opt.first_moments()[1].is_none());
        assert!(opt.first_moments()[3].is_some());
    }

    #[test]
    fn test_adamw_scalar_fallback_path() {
        // Small tensor (<16 elements) triggers scalar fallback
        let mut params = vec![Tensor::from_vec(vec![2.0, -1.0], true)];
        let mut optimizer = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.01);

        // Multi-step to hit all moment update paths
        // (first step: moments are None; later steps: moments are Some)
        for _ in 0..3 {
            let grad = params[0].data().mapv(|x| 2.0 * x);
            params[0].set_grad(grad);
            optimizer.step(&mut params);
        }
        // Params should converge toward 0
        assert!(params[0].data()[0].abs() < 2.0);
    }

    mod aw_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // FALSIFY-AW-002e-prop: Second moment non-negative for random gradients
        // NOTE(review): this property re-implements the `v` recurrence locally
        // instead of driving `AdamW`, so it checks the formula rather than the
        // optimizer; `falsify_aw_002e_second_moment_non_negative` covers the
        // optimizer itself.
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(50))]

            #[test]
            fn falsify_aw_002e_prop_second_moment_non_negative(
                seed in 0..500u32,
            ) {
                let beta2 = 0.999_f32;
                let n = 4;
                let mut v = vec![0.0_f32; n];

                for step in 0..20 {
                    let g: Vec<f32> = (0..n)
                        .map(|i| ((i as f32 + seed as f32 + step as f32 * 13.0) * 0.37).sin() * 10.0)
                        .collect();
                    for i in 0..n {
                        v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i];
                    }
                }

                for (i, &vi) in v.iter().enumerate() {
                    prop_assert!(vi >= 0.0, "FALSIFIED AW-002e-prop: v[{}] = {} < 0", i, vi);
                }
            }
        }

        // FALSIFY-AW-004e-prop: Update finiteness for random initial params
        // (4 elements, so this runs the scalar fallback path)
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(50))]

            #[test]
            fn falsify_aw_004e_prop_update_finiteness(
                seed in 0..500u32,
            ) {
                let data: Vec<f32> = (0..4)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin() * 100.0)
                    .collect();
                let mut params = vec![Tensor::from_vec(data.clone(), true)];
                let mut optimizer = AdamW::default_params(0.001);

                let grad_data: Vec<f32> = data.iter().map(|&x| 2.0 * x).collect();
                params[0].set_grad(ndarray::Array1::from(grad_data));
                optimizer.step(&mut params);

                for (i, &val) in params[0].data().iter().enumerate() {
                    prop_assert!(
                        val.is_finite(),
                        "FALSIFIED AW-004e-prop: param[{}] = {} (not finite)",
                        i, val
                    );
                }
            }
        }
    }
}