use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

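/// Adam optimizer (adaptive moment estimation).
///
/// For each parameter the optimizer keeps an exponential moving average of the
/// gradient (`exp_avg`) and of the squared gradient (`exp_avg_sq`), corrects
/// both for initialization bias, and then updates the parameter as
/// `theta -= lr * m_hat / (sqrt(v_hat) + eps)`. Classic (L2-style) weight decay
/// is folded into the gradient, and the AMSGrad variant tracks the running
/// maximum of the second moment instead of the current value.
///
/// Usage sketch, mirroring the unit tests at the bottom of this file
/// (the surrounding types come from `axonml_nn`, `axonml_autograd`, and
/// `axonml_tensor`):
///
/// ```ignore
/// let param = Parameter::from_variable(Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(),
///     true,
/// ));
/// let mut optimizer = Adam::new(vec![param], 0.001)
///     .weight_decay(0.01)
///     .amsgrad(true);
/// // ...compute gradients via backward()...
/// optimizer.step();
/// optimizer.zero_grad();
/// ```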
pub struct Adam {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate.
    beta2: f32,
    /// Term added to the denominator for numerical stability.
    eps: f32,
    /// L2-style weight decay coefficient.
    weight_decay: f32,
    /// Whether to use the AMSGrad variant.
    amsgrad: bool,
    /// Per-parameter state, allocated lazily on the first `step()`.
    state: Vec<AdamState>,
}

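/// Per-parameter optimizer state: first moment (`exp_avg`), second moment
/// (`exp_avg_sq`), the AMSGrad running maximum (allocated lazily), and the
/// step counter used for bias correction.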
#[derive(Debug, Clone)]
struct AdamState {
    exp_avg: Tensor<f32>,
    exp_avg_sq: Tensor<f32>,
    max_exp_avg_sq: Option<Tensor<f32>>,
    step: usize,
}

impl AdamState {
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        // Allocate zeroed moment buffers on the CPU, then move them to the
        // parameter's device if that device is a GPU.
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            max_exp_avg_sq: None,
            step: 0,
        }
    }
}

impl Adam {
    /// Creates an Adam optimizer with the default betas `(0.9, 0.999)`.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an Adam optimizer with custom `(beta1, beta2)` coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates an Adam optimizer with every hyperparameter given explicitly.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Sets the `(beta1, beta2)` coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Sets the denominator epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables the AMSGrad variant (builder style).
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    /// Allocates per-parameter state on the first `step()` call.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

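// Bias-corrected moments: m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t).
// The CPU path below folds the first correction into
// `step_size = lr / bias_correction1` and applies the second when forming the
// denominator `sqrt(v_hat) + eps`.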
impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // Fused GPU path: delegate to the in-place Adam kernel.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // The gradient may still live on the CPU; move it first.
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("Adam: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU path: materialize the buffers and update element-wise.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map_or_else(|| vec![0.0f32; param_vec.len()], |t| t.to_vec())
            } else {
                Vec::new()
            };

            for i in 0..param_vec.len() {
                // Classic (L2-style) decay: add `wd * theta` to the gradient.
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;

                // AMSGrad uses the running maximum of the second moment.
                let v_hat = if self.amsgrad {
                    max_sq_vec[i] = max_sq_vec[i].max(exp_avg_sq_vec[i]);
                    max_sq_vec[i] / bias_correction2
                } else {
                    exp_avg_sq_vec[i] / bias_correction2
                };

                let denom = v_hat.sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(
                    Tensor::from_vec(max_sq_vec, param_data.shape())
                        .expect("tensor creation failed"),
                );
            }
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

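/// AdamW optimizer: Adam with decoupled weight decay.
///
/// Where [`Adam`] folds `weight_decay * theta` into the gradient, AdamW shrinks
/// the parameter directly before the Adam update, `theta *= 1 - lr * weight_decay`,
/// so the decay does not flow through the moment estimates. The default decay
/// is `0.01`.
///
/// Usage sketch, following the same pattern as [`Adam`]:
///
/// ```ignore
/// let mut optimizer = AdamW::new(params, 0.001).weight_decay(0.05);
/// optimizer.step();
/// optimizer.zero_grad();
/// ```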
pub struct AdamW {
    /// Parameters being optimized.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate.
    beta2: f32,
    /// Term added to the denominator for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient (default `0.01`).
    weight_decay: f32,
    /// Whether to use the AMSGrad variant (stored; not currently applied in `step()`).
    amsgrad: bool,
    /// Per-parameter state, allocated lazily on the first `step()`.
    state: Vec<AdamState>,
}

impl AdamW {
    /// Creates an AdamW optimizer with the default betas `(0.9, 0.999)` and
    /// weight decay `0.01`.
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an AdamW optimizer with custom `(beta1, beta2)` coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates an AdamW optimizer with every hyperparameter given explicitly.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Sets the `(beta1, beta2)` coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Sets the denominator epsilon (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the decoupled weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables the AMSGrad variant (builder style).
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    /// Allocates per-parameter state on the first `step()` call.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

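// Decoupled decay: each step first scales the parameter by
// `1 - lr * weight_decay` and then applies the plain Adam update, so the fused
// GPU kernel is called with a weight decay of 0.0.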
impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // Fused GPU path: delegate to the in-place Adam kernel.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // The gradient may still live on the CPU; move it first.
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("AdamW: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };

                // Apply decoupled weight decay before the Adam update.
                if self.weight_decay > 0.0 {
                    let decay_factor = 1.0 - self.lr * self.weight_decay;
                    let decayed = param_data.mul_scalar(decay_factor);
                    param.update_data(decayed);
                }

                // Re-fetch the (possibly decayed) parameter data.
                let param_data = param.data();

                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    0.0, // decay was already applied in decoupled form above
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU path: materialize the buffers and update element-wise.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

            for i in 0..param_vec.len() {
                // Decoupled weight decay: shrink the parameter directly.
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }

    #[test]
    fn test_adam_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![0.5, -0.3], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        assert!(
            after[0] < before[0],
            "param[0] should decrease: {} -> {}",
            before[0],
            after[0]
        );
        assert!(
            after[1] < before[1],
            "param[1] should decrease: {} -> {}",
            before[1],
            after[1]
        );

        let delta0 = before[0] - after[0];
        let delta1 = before[1] - after[1];
        assert!(
            (delta0 - delta1).abs() < 1e-6,
            "Uniform gradient should produce uniform update: {} vs {}",
            delta0,
            delta1
        );
    }

    #[test]
    fn test_adam_converges_on_quadratic() {
        let var = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            loss.backward();
            opt.step();
        }

        let final_x = param.data().to_vec()[0];
        assert!(
            final_x.abs() < 0.1,
            "Adam should converge near 0 for f(x)=x^2, got {}",
            final_x
        );
    }

    #[test]
    fn test_adam_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap());
        assert!(param.grad().is_some());

        let mut opt = Adam::new(vec![param.clone()], 0.01);
        opt.zero_grad();
        if let Some(g) = param.grad() {
            let gv = g.to_vec();
            assert!(
                gv.iter().all(|&v| v.abs() < 1e-10),
                "Gradients should be zero after zero_grad: {:?}",
                gv
            );
        }
    }

    #[test]
    fn test_adam_lr_management() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param], 0.001);

        assert!((opt.get_lr() - 0.001).abs() < 1e-8);
        opt.set_lr(0.01);
        assert!((opt.get_lr() - 0.01).abs() < 1e-8);
    }

    #[test]
    fn test_adam_skips_frozen_params() {
        let trainable = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            true,
        ));
        let frozen = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![2.0], &[1]).unwrap(),
            false,
        ));

        trainable.set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![trainable.clone(), frozen.clone()], 0.1);
        opt.step();

        assert!((trainable.data().to_vec()[0] - 1.0).abs() > 1e-6);
        assert!((frozen.data().to_vec()[0] - 2.0).abs() < 1e-8);
    }

    #[test]
    fn test_adam_weight_decay() {
        let var = Variable::new(Tensor::from_vec(vec![10.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        // Zero gradient, so only the weight-decay term drives the update.
        param.set_grad(Tensor::from_vec(vec![0.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1).weight_decay(0.1);
        let before = param.data().to_vec()[0];
        opt.step();
        let after = param.data().to_vec()[0];

        assert!(
            after < before,
            "Weight decay should shrink large params: {} -> {}",
            before,
            after
        );
    }

    #[test]
    fn test_adam_multiple_steps_improve() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, -2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.05);

        let mut losses = Vec::new();
        for _ in 0..50 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            losses.push(loss.data().to_vec()[0]);
            loss.backward();
            opt.step();
        }

        let first = losses[0];
        let last = *losses.last().unwrap();
        assert!(
            last < first * 0.5,
            "Loss should decrease significantly: first={}, last={}",
            first,
            last
        );
    }

    #[test]
    fn test_adamw_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![5.0, -3.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, -1.0], &[2]).unwrap());

        let mut opt = AdamW::new(vec![param.clone()], 0.01);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        assert!(after[0] < before[0], "Positive grad should decrease param");
        assert!(after[1] > before[1], "Negative grad should increase param");
    }

    #[test]
    fn test_adamw_converges() {
        let var = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = AdamW::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            loss.backward();
            opt.step();
        }

        assert!(
            param.data().to_vec()[0].abs() < 0.1,
            "AdamW should converge near 0, got {}",
            param.data().to_vec()[0]
        );
    }
}