1use super::Optimizer;
4use crate::Tensor;
5use ndarray::Array1;
6use provable_contracts_macros::requires;
7
8pub struct AdamW {
17 lr: f32,
18 beta1: f32,
19 beta2: f32,
20 epsilon: f32,
21 weight_decay: f32,
22 t: u64,
23 m: Vec<Option<Array1<f32>>>, v: Vec<Option<Array1<f32>>>, }
26
27impl AdamW {
28 #[allow(clippy::manual_range_contains)]
30 #[requires(lr > 0.0 && beta1 >= 0.0 && beta1 < 1.0 && beta2 >= 0.0 && beta2 < 1.0 && epsilon > 0.0 && weight_decay >= 0.0)]
31 pub fn new(lr: f32, beta1: f32, beta2: f32, epsilon: f32, weight_decay: f32) -> Self {
32 Self { lr, beta1, beta2, epsilon, weight_decay, t: 0, m: Vec::new(), v: Vec::new() }
33 }
34
35 pub fn default_params(lr: f32) -> Self {
37 Self::new(lr, 0.9, 0.999, 1e-8, 0.01)
38 }
39
40 fn ensure_moments(&mut self, params: &[Tensor]) {
42 if self.m.is_empty() {
43 self.m = params.iter().map(|_| None).collect();
44 self.v = params.iter().map(|_| None).collect();
45 }
46 }
47
48 #[must_use]
52 pub fn step_count(&self) -> u64 {
53 self.t
54 }
55
56 pub fn set_step_count(&mut self, t: u64) {
58 self.t = t;
59 }
60
61 #[must_use]
63 pub fn first_moments(&self) -> &[Option<Array1<f32>>] {
64 &self.m
65 }
66
67 #[must_use]
69 pub fn second_moments(&self) -> &[Option<Array1<f32>>] {
70 &self.v
71 }
72
73 pub fn set_first_moment(&mut self, idx: usize, data: Array1<f32>) {
75 if idx >= self.m.len() {
76 self.m.resize(idx + 1, None);
77 }
78 self.m[idx] = Some(data);
79 }
80
81 pub fn set_second_moment(&mut self, idx: usize, data: Array1<f32>) {
83 if idx >= self.v.len() {
84 self.v.resize(idx + 1, None);
85 }
86 self.v[idx] = Some(data);
87 }
88
89 #[must_use]
91 pub fn beta1(&self) -> f32 {
92 self.beta1
93 }
94
95 #[must_use]
97 pub fn beta2(&self) -> f32 {
98 self.beta2
99 }
100
101 #[must_use]
103 pub fn weight_decay(&self) -> f32 {
104 self.weight_decay
105 }
106}
107
108impl Optimizer for AdamW {
109 #[requires(!params.is_empty())]
110 fn step(&mut self, params: &mut [Tensor]) {
111 self.ensure_moments(params);
112 self.t += 1;
113
114 let lr_t = self.lr
116 * ((1.0 - self.beta2.powi(self.t as i32)).sqrt()
117 / (1.0 - self.beta1.powi(self.t as i32)));
118
119 for (i, param) in params.iter_mut().enumerate() {
120 if let Some(grad) = param.grad() {
121 if grad.len() >= 16 {
123 if self.m[i].is_none() {
125 self.m[i] = Some(Array1::zeros(grad.len()));
126 self.v[i] = Some(Array1::zeros(grad.len()));
127 }
128
129 let m = self.m[i].as_mut().expect("momentum buffer initialized above");
130 let v = self.v[i].as_mut().expect("velocity buffer initialized above");
131
132 let grad_slice = grad.as_slice().expect("grad array is contiguous");
134 let m_slice = m.as_slice_mut().expect("momentum array is contiguous");
135 let v_slice = v.as_slice_mut().expect("velocity array is contiguous");
136 let param_slice =
137 param.data_mut().as_slice_mut().expect("param array is contiguous");
138
139 super::simd::simd_adamw_update(
141 grad_slice,
142 m_slice,
143 v_slice,
144 param_slice,
145 self.beta1,
146 self.beta2,
147 self.lr,
148 lr_t,
149 self.weight_decay,
150 self.epsilon,
151 );
152 } else {
153 let m_t = if let Some(m) = &self.m[i] {
156 m * self.beta1 + &grad * (1.0 - self.beta1)
157 } else {
158 &grad * (1.0 - self.beta1)
159 };
160
161 let grad_sq = &grad * &grad;
163 let v_t = if let Some(v) = &self.v[i] {
164 v * self.beta2 + &grad_sq * (1.0 - self.beta2)
165 } else {
166 &grad_sq * (1.0 - self.beta2)
167 };
168
169 let adaptive_update = &m_t / &(v_t.mapv(f32::sqrt) + self.epsilon) * lr_t;
172
173 let weight_decay_factor = 1.0 - self.lr * self.weight_decay;
175 *param.data_mut() = param.data() * weight_decay_factor - &adaptive_update;
176
177 self.m[i] = Some(m_t);
178 self.v[i] = Some(v_t);
179 }
180 }
181 }
182 }
183
184 fn step_refs(&mut self, params: &mut [&mut Tensor]) {
185 contract_pre_weight_update!();
186 if self.m.len() < params.len() {
188 self.m.resize(params.len(), None);
189 self.v.resize(params.len(), None);
190 }
191 self.t += 1;
192
193 let lr_t = self.lr
195 * ((1.0 - self.beta2.powi(self.t as i32)).sqrt()
196 / (1.0 - self.beta1.powi(self.t as i32)));
197
198 for (i, param) in params.iter_mut().enumerate() {
199 if let Some(grad) = param.grad() {
200 if grad.len() >= 16 {
202 if self.m[i].is_none() {
204 self.m[i] = Some(Array1::zeros(grad.len()));
205 self.v[i] = Some(Array1::zeros(grad.len()));
206 }
207
208 let m = self.m[i].as_mut().expect("momentum buffer initialized above");
209 let v = self.v[i].as_mut().expect("velocity buffer initialized above");
210
211 let grad_slice = grad.as_slice().expect("grad array is contiguous");
213 let m_slice = m.as_slice_mut().expect("momentum array is contiguous");
214 let v_slice = v.as_slice_mut().expect("velocity array is contiguous");
215 let param_slice =
216 param.data_mut().as_slice_mut().expect("param array is contiguous");
217
218 super::simd::simd_adamw_update(
220 grad_slice,
221 m_slice,
222 v_slice,
223 param_slice,
224 self.beta1,
225 self.beta2,
226 self.lr,
227 lr_t,
228 self.weight_decay,
229 self.epsilon,
230 );
231 } else {
232 let m_t = if let Some(m) = &self.m[i] {
234 m * self.beta1 + &grad * (1.0 - self.beta1)
235 } else {
236 &grad * (1.0 - self.beta1)
237 };
238
239 let grad_sq = &grad * &grad;
240 let v_t = if let Some(v) = &self.v[i] {
241 v * self.beta2 + &grad_sq * (1.0 - self.beta2)
242 } else {
243 &grad_sq * (1.0 - self.beta2)
244 };
245
246 let adaptive_update = &m_t / &(v_t.mapv(f32::sqrt) + self.epsilon) * lr_t;
247 let weight_decay_factor = 1.0 - self.lr * self.weight_decay;
248 *param.data_mut() = param.data() * weight_decay_factor - &adaptive_update;
249
250 self.m[i] = Some(m_t);
251 self.v[i] = Some(v_t);
252 }
253 }
254 }
255 }
256
257 fn lr(&self) -> f32 {
258 self.lr
259 }
260
261 fn set_lr(&mut self, lr: f32) {
262 self.lr = lr;
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::*;
    use approx::assert_abs_diff_eq;

    /// Runs `steps` optimizer steps, each time setting the gradient of the first tensor to
    /// `2 * x` (the gradient of `x^2`) computed from its current values.
    fn descend_quadratic(optimizer: &mut AdamW, params: &mut [Tensor], steps: usize) {
        for _ in 0..steps {
            let grad = params[0].data().mapv(|x| 2.0 * x);
            params[0].set_grad(grad);
            optimizer.step(params);
        }
    }

    /// Minimizing `x^2` on a short vector brings every component close to zero.
    #[test]
    fn test_adamw_quadratic_convergence() {
        let mut tensors = vec![Tensor::from_vec(vec![5.0, -3.0, 2.0], true)];
        let mut opt = AdamW::default_params(0.1);

        descend_quadratic(&mut opt, &mut tensors, 100);

        for val in tensors[0].data().iter().copied() {
            assert!(val.abs() < 0.5, "Value {val} did not converge");
        }
    }

    /// With a zero gradient only weight decay acts: 1.0 * (1 - 0.1 * 0.1) = 0.99.
    #[test]
    fn test_adamw_weight_decay() {
        let mut tensors = vec![Tensor::from_vec(vec![1.0], true)];
        let mut opt = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.1);
        tensors[0].set_grad(ndarray::arr1(&[0.0]));

        let before = tensors[0].data()[0];
        opt.step(&mut tensors);
        let after = tensors[0].data()[0];

        assert!(after < before);
        assert_abs_diff_eq!(after, 0.99, epsilon = 1e-6);
    }

    /// Given identical gradients, AdamW with weight decay ends with smaller magnitudes
    /// than Adam.
    #[test]
    fn test_adamw_vs_adam_difference() {
        let mut decayed = vec![Tensor::from_vec(vec![2.0, -2.0], true)];
        let mut plain = vec![Tensor::from_vec(vec![2.0, -2.0], true)];

        let mut adamw = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.1);
        let mut adam = super::super::Adam::default_params(0.1);

        let grad = ndarray::arr1(&[1.0, -1.0]);
        for _ in 0..10 {
            decayed[0].set_grad(grad.clone());
            plain[0].set_grad(grad.clone());

            adamw.step(&mut decayed);
            adam.step(&mut plain);
        }

        for idx in 0..2 {
            assert!(decayed[0].data()[idx].abs() < plain[0].data()[idx].abs());
        }
    }

    /// A 32-element tensor (at least 16 elements) goes through the SIMD branch without
    /// changing length.
    #[test]
    fn test_adamw_simd_path() {
        let values: Vec<f32> = (0..32).map(|i| i as f32).collect();
        let mut tensors = vec![Tensor::from_vec(values, true)];
        let mut opt = AdamW::default_params(0.01);

        descend_quadratic(&mut opt, &mut tensors, 10);

        assert_eq!(tensors[0].data().len(), 32);
    }

    /// On the SIMD branch, minimizing `x^2` reduces the mean absolute value.
    #[test]
    fn test_adamw_simd_convergence() {
        fn mean_abs<'a>(values: impl Iterator<Item = &'a f32>) -> f32 {
            values.map(|x| x.abs()).sum::<f32>() / 32.0
        }

        let values: Vec<f32> = (0..32).map(|i| (i as f32) - 16.0).collect();
        let initial_mean: f32 = mean_abs(values.iter());

        let mut tensors = vec![Tensor::from_vec(values, true)];
        let mut opt = AdamW::default_params(0.1);
        descend_quadratic(&mut opt, &mut tensors, 100);

        let final_mean: f32 = mean_abs(tensors[0].data().iter());
        assert!(final_mean < initial_mean, "Mean {final_mean} did not improve from {initial_mean}");
    }

    /// `set_lr` is reflected by `lr`.
    #[test]
    fn test_adamw_lr_getter_setter() {
        let mut opt = AdamW::default_params(0.1);
        assert_abs_diff_eq!(opt.lr(), 0.1, epsilon = 1e-6);

        opt.set_lr(0.01);
        assert_abs_diff_eq!(opt.lr(), 0.01, epsilon = 1e-6);
    }

    /// Every tensor in the slice that has a gradient is updated.
    #[test]
    fn test_adamw_multiple_params() {
        let first = Tensor::from_vec(vec![1.0, 2.0], true);
        let second = Tensor::from_vec(vec![3.0, 4.0], true);
        let mut tensors = vec![first, second];
        let mut opt = AdamW::default_params(0.1);

        tensors[0].set_grad(ndarray::arr1(&[0.1, 0.2]));
        tensors[1].set_grad(ndarray::arr1(&[0.3, 0.4]));

        opt.step(&mut tensors);

        assert!(tensors[0].data()[0] < 1.0);
        assert!(tensors[1].data()[0] < 3.0);
    }

    /// A tensor with no gradient set is left untouched by a step.
    #[test]
    fn test_adamw_no_grad() {
        let mut tensors = vec![Tensor::from_vec(vec![1.0, 2.0], false)];
        let mut opt = AdamW::default_params(0.1);

        let snapshot = tensors[0].data().clone();
        opt.step(&mut tensors);

        assert_eq!(tensors[0].data(), &snapshot);
    }

    /// Repeated steps with a constant gradient move the parameter.
    #[test]
    fn test_adamw_momentum_accumulation() {
        let mut tensors = vec![Tensor::from_vec(vec![5.0], true)];
        let mut opt = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.0);
        let start = tensors[0].data()[0];

        for _ in 0..5 {
            tensors[0].set_grad(ndarray::arr1(&[1.0]));
            opt.step(&mut tensors);
        }

        assert!(tensors[0].data()[0] != start, "Parameter did not change");
    }

    /// On the SIMD branch, each step with a constant gradient keeps lowering the value.
    #[test]
    fn test_adamw_simd_multiple_steps() {
        let ones: Vec<f32> = vec![1.0; 20];
        let mut tensors = vec![Tensor::from_vec(ones, true)];
        let mut opt = AdamW::default_params(0.1);

        for step in 0..5 {
            let grad = tensors[0].data().mapv(|_| 1.0);
            tensors[0].set_grad(grad);
            opt.step(&mut tensors);

            let ceiling = 1.0 - (step as f32 * 0.05);
            assert!(tensors[0].data()[0] < ceiling, "Step {step} did not make progress");
        }
    }

    /// With zero weight decay and a zero gradient the parameter does not move.
    #[test]
    fn test_adamw_zero_weight_decay() {
        let mut tensors = vec![Tensor::from_vec(vec![1.0], true)];
        let mut opt = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.0);
        tensors[0].set_grad(ndarray::arr1(&[0.0]));

        let start = tensors[0].data()[0];
        opt.step(&mut tensors);

        assert_abs_diff_eq!(tensors[0].data()[0], start, epsilon = 1e-6);
    }

    /// The very first step already has a sizeable magnitude thanks to bias correction.
    #[test]
    fn test_adamw_bias_adjust() {
        let mut tensors = vec![Tensor::from_vec(vec![0.0], true)];
        let mut opt = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.0);

        tensors[0].set_grad(ndarray::arr1(&[1.0]));
        opt.step(&mut tensors);

        let after_first = tensors[0].data()[0];
        assert!(after_first.abs() > 0.05, "Bias adjust not applied");
    }

    /// AW-002e: second-moment estimates stay non-negative under varying gradients.
    #[test]
    fn falsify_aw_002e_second_moment_non_negative() {
        let mut tensors = vec![Tensor::from_vec(vec![5.0, -3.0, 2.0, -1.0], true)];
        let mut opt = AdamW::default_params(0.01);

        for step in 0..50 {
            let offset = step as f32;
            let grad = tensors[0].data().mapv(|x| ((x + offset) * 0.37).sin() * 5.0);
            tensors[0].set_grad(grad);
            opt.step(&mut tensors);
        }

        for v_arr in opt.second_moments().iter().flatten() {
            for (j, v_val) in v_arr.iter().copied().enumerate() {
                assert!(v_val >= 0.0, "FALSIFIED AW-002e: v[{j}] = {v_val} < 0 after 50 steps");
            }
        }
    }

    /// AW-003e: the bias-correction factor `1 / (1 - beta^t)` exceeds 1 for the first
    /// hundred steps.
    #[test]
    fn falsify_aw_003e_bias_adjust() {
        for beta in [0.9_f32, 0.99, 0.999].iter().copied() {
            for t in 1..=100i32 {
                let adjust = 1.0 / (1.0 - beta.powi(t));
                assert!(adjust > 1.0, "FALSIFIED AW-003e: 1/(1-{beta}^{t}) = {adjust} not > 1");
            }
        }
    }

    /// AW-004e: one step on very large and very small values yields finite parameters.
    #[test]
    fn falsify_aw_004e_update_finiteness() {
        let mut tensors = vec![Tensor::from_vec(vec![1e6, -1e6, 1e-6, -1e-6], true)];
        let mut opt = AdamW::default_params(0.001);

        descend_quadratic(&mut opt, &mut tensors, 1);

        for (i, val) in tensors[0].data().iter().copied().enumerate() {
            assert!(val.is_finite(), "FALSIFIED AW-004e: param[{i}] = {val} (not finite)");
        }
    }

    /// AW-006e: with a zero gradient the step reduces to scaling by `1 - lr * wd`.
    #[test]
    fn falsify_aw_006e_zero_gradient_weight_decay_only() {
        let init_vals = vec![5.0, -3.0, 2.0];
        let (lr, wd) = (0.01, 0.1);
        let mut tensors = vec![Tensor::from_vec(init_vals.clone(), true)];
        let mut opt = AdamW::new(lr, 0.9, 0.999, 1e-8, wd);

        tensors[0].set_grad(ndarray::Array1::zeros(3));
        opt.step(&mut tensors);

        let factor = 1.0 - lr * wd;
        let updated = tensors[0].data().iter().copied();
        for (i, (val, init)) in updated.zip(init_vals.iter().copied()).enumerate() {
            let expected = init * factor;
            assert!(
                (val - expected).abs() < 1e-4,
                "FALSIFIED AW-006e: param[{i}] = {val}, expected {expected} (only wd)"
            );
        }
    }

    /// Step-count and hyperparameter accessors report what was configured.
    #[test]
    fn test_adamw_checkpoint_accessors() {
        let mut optimizer = AdamW::default_params(0.01);
        assert_eq!(optimizer.step_count(), 0);

        optimizer.set_step_count(42);
        assert_eq!(optimizer.step_count(), 42);

        assert_eq!(optimizer.beta1(), 0.9);
        assert_eq!(optimizer.beta2(), 0.999);
        assert!((optimizer.weight_decay() - 0.01).abs() < 1e-6);
    }

    /// Moment setters grow their list on demand and leave skipped slots empty.
    #[test]
    fn test_adamw_moment_set_get() {
        let mut optimizer = AdamW::default_params(0.01);
        assert!(optimizer.first_moments().is_empty());
        assert!(optimizer.second_moments().is_empty());

        optimizer.set_first_moment(0, ndarray::arr1(&[1.0, 2.0]));
        optimizer.set_second_moment(0, ndarray::arr1(&[0.5, 0.5]));
        assert_eq!(optimizer.first_moments().len(), 1);
        assert_eq!(optimizer.second_moments().len(), 1);

        optimizer.set_first_moment(3, ndarray::arr1(&[3.0]));
        let first = optimizer.first_moments();
        assert_eq!(first.len(), 4);
        assert!(first[1].is_none());
        assert!(first[3].is_some());
    }

    /// A two-element tensor (fewer than 16 elements) goes through the non-SIMD branch.
    #[test]
    fn test_adamw_scalar_fallback_path() {
        let mut tensors = vec![Tensor::from_vec(vec![2.0, -1.0], true)];
        let mut opt = AdamW::new(0.1, 0.9, 0.999, 1e-8, 0.01);

        descend_quadratic(&mut opt, &mut tensors, 3);

        assert!(tensors[0].data()[0].abs() < 2.0);
    }

    mod aw_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(50))]

            /// AW-002e (property): the second-moment recurrence never goes negative.
            #[test]
            fn falsify_aw_002e_prop_second_moment_non_negative(
                seed in 0..500u32,
            ) {
                let beta2 = 0.999_f32;
                let n = 4;
                let mut v = vec![0.0_f32; n];

                for step in 0..20 {
                    let g: Vec<f32> = (0..n)
                        .map(|i| ((i as f32 + seed as f32 + step as f32 * 13.0) * 0.37).sin() * 10.0)
                        .collect();
                    for (slot, gi) in v.iter_mut().zip(g.iter().copied()) {
                        *slot = beta2 * *slot + (1.0 - beta2) * gi * gi;
                    }
                }

                for (i, vi) in v.iter().copied().enumerate() {
                    prop_assert!(vi >= 0.0, "FALSIFIED AW-002e-prop: v[{}] = {} < 0", i, vi);
                }
            }

            /// AW-004e (property): a single step keeps every parameter finite.
            #[test]
            fn falsify_aw_004e_prop_update_finiteness(
                seed in 0..500u32,
            ) {
                let data: Vec<f32> = (0..4)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin() * 100.0)
                    .collect();
                let grad_data: Vec<f32> = data.iter().map(|&x| 2.0 * x).collect();

                let mut tensors = vec![Tensor::from_vec(data, true)];
                let mut opt = AdamW::default_params(0.001);

                tensors[0].set_grad(ndarray::Array1::from(grad_data));
                opt.step(&mut tensors);

                for (i, val) in tensors[0].data().iter().copied().enumerate() {
                    prop_assert!(
                        val.is_finite(),
                        "FALSIFIED AW-004e-prop: param[{}] = {} (not finite)",
                        i, val
                    );
                }
            }
        }
    }
}