//! Adam and AdamW optimizers.

use axonml_nn::Parameter;
use axonml_tensor::Tensor;

use crate::optimizer::Optimizer;

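/// Adam optimizer (adaptive moment estimation).
///
/// Maintains exponential moving averages of each parameter's gradient and of
/// its square, with optional coupled L2 weight decay and an AMSGrad variant.
///
/// A usage sketch, mirroring the tests at the bottom of this file:
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
/// let param = Parameter::from_variable(var);
///
/// let mut opt = Adam::new(vec![param.clone()], 1e-3)
///     .weight_decay(0.01)
///     .amsgrad(true);
///
/// opt.zero_grad();
/// // ... forward pass, then loss.backward() ...
/// opt.step();
/// ```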
pub struct Adam {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate.
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    eps: f32,
    /// Coupled (L2) weight decay coefficient.
    weight_decay: f32,
    /// Whether to use the AMSGrad variant.
    amsgrad: bool,
    /// Per-parameter state, initialized lazily on the first `step`.
    state: Vec<AdamState>,
}

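/// Per-parameter optimizer state shared by [`Adam`] and [`AdamW`].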
#[derive(Debug, Clone)]
struct AdamState {
    /// Exponential moving average of the gradient (first moment).
    exp_avg: Tensor<f32>,
    /// Exponential moving average of the squared gradient (second moment).
    exp_avg_sq: Tensor<f32>,
    /// Running maximum of the second moment, used only when AMSGrad is enabled.
    max_exp_avg_sq: Option<Tensor<f32>>,
    /// Number of optimization steps taken for this parameter.
    step: usize,
}

impl AdamState {
    /// Creates zero-initialized moment tensors of the given shape, moving them
    /// to `device` when it is a GPU device.
    fn new(shape: &[usize], device: axonml_core::Device) -> Self {
        let size: usize = shape.iter().product();
        let mut exp_avg =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        let mut exp_avg_sq =
            Tensor::from_vec(vec![0.0f32; size], shape).expect("tensor creation failed");
        if device.is_gpu() {
            exp_avg = exp_avg.to_device(device).expect("device transfer failed");
            exp_avg_sq = exp_avg_sq
                .to_device(device)
                .expect("device transfer failed");
        }
        Self {
            exp_avg,
            exp_avg_sq,
            max_exp_avg_sq: None,
            step: 0,
        }
    }
}

impl Adam {
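    /// Creates an Adam optimizer with the given learning rate and the defaults
    /// `betas = (0.9, 0.999)`, `eps = 1e-8`, `weight_decay = 0.0`, `amsgrad = false`.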
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an Adam optimizer with custom `(beta1, beta2)` coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates an Adam optimizer with every hyperparameter given explicitly.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Sets the `(beta1, beta2)` coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Sets the numerical-stability term `eps` (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the coupled (L2) weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables the AMSGrad variant (builder style).
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    /// Lazily allocates per-parameter state the first time `step` is called.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

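// Adam update rule (Kingma & Ba, 2015), as implemented in `step` below:
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat = m_t / (1 - beta1^t),   v_hat = v_t / (1 - beta2^t)
//   theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
// With `weight_decay`, the raw gradient is first augmented to g_t + wd * theta
// (coupled L2 decay); with `amsgrad`, v_hat uses the running maximum of v_t.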
impl Optimizer for Adam {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // Fused GPU path: the whole update runs in a single kernel.
            // Note: this path does not apply the AMSGrad maximum.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Make sure the gradient lives on the same device as the parameter.
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("Adam: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };
                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    self.weight_decay,
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback: run the update element-wise on host vectors.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd = self.weight_decay;

            // AMSGrad keeps a running maximum of the second-moment estimate.
            let mut max_sq_vec = if self.amsgrad {
                state
                    .max_exp_avg_sq
                    .as_ref()
                    .map_or_else(|| vec![0.0f32; param_vec.len()], |t| t.to_vec())
            } else {
                Vec::new()
            };

            for i in 0..param_vec.len() {
                // Coupled L2 weight decay is folded into the gradient.
                let g = if wd == 0.0 {
                    grad_vec[i]
                } else {
                    grad_vec[i] + wd * param_vec[i]
                };
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;

                let v_hat = if self.amsgrad {
                    max_sq_vec[i] = max_sq_vec[i].max(exp_avg_sq_vec[i]);
                    max_sq_vec[i] / bias_correction2
                } else {
                    exp_avg_sq_vec[i] / bias_correction2
                };

                let denom = v_hat.sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            // Write the updated moments and parameters back into tensors.
            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            if self.amsgrad {
                state.max_exp_avg_sq = Some(
                    Tensor::from_vec(max_sq_vec, param_data.shape())
                        .expect("tensor creation failed"),
                );
            }
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

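/// AdamW optimizer: Adam with decoupled weight decay (Loshchilov & Hutter, 2019).
///
/// Unlike [`Adam`], the weight-decay term is applied directly to the parameters
/// before the moment update instead of being added to the gradient, and it
/// defaults to `0.01`.
///
/// A usage sketch, mirroring `test_adamw_converges` below:
///
/// ```ignore
/// let var = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
/// let param = Parameter::from_variable(var);
/// let mut opt = AdamW::new(vec![param.clone()], 0.1);
///
/// for _ in 0..200 {
///     opt.zero_grad();
///     let x = param.variable();
///     let loss = x.mul_var(&x).sum();
///     loss.backward();
///     opt.step();
/// }
/// ```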
pub struct AdamW {
    /// Parameters to optimize.
    params: Vec<Parameter>,
    /// Learning rate.
    lr: f32,
    /// Exponential decay rate for the first-moment estimate.
    beta1: f32,
    /// Exponential decay rate for the second-moment estimate.
    beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    eps: f32,
    /// Decoupled weight decay coefficient (defaults to 0.01).
    weight_decay: f32,
    /// Whether to use the AMSGrad variant.
    amsgrad: bool,
    /// Per-parameter state, initialized lazily on the first `step`.
    state: Vec<AdamState>,
}

impl AdamW {
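    /// Creates an AdamW optimizer with the given learning rate and the defaults
    /// `betas = (0.9, 0.999)`, `eps = 1e-8`, `weight_decay = 0.01`, `amsgrad = false`.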
    #[must_use]
    pub fn new(params: Vec<Parameter>, lr: f32) -> Self {
        Self::with_betas(params, lr, (0.9, 0.999))
    }

    /// Creates an AdamW optimizer with custom `(beta1, beta2)` coefficients.
    #[must_use]
    pub fn with_betas(params: Vec<Parameter>, lr: f32, betas: (f32, f32)) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps: 1e-8,
            weight_decay: 0.01,
            amsgrad: false,
            state: Vec::new(),
        }
    }

    /// Creates an AdamW optimizer with every hyperparameter given explicitly.
    #[must_use]
    pub fn with_options(
        params: Vec<Parameter>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        Self {
            params,
            lr,
            beta1: betas.0,
            beta2: betas.1,
            eps,
            weight_decay,
            amsgrad,
            state: Vec::new(),
        }
    }

    /// Sets the `(beta1, beta2)` coefficients (builder style).
    #[must_use]
    pub fn betas(mut self, betas: (f32, f32)) -> Self {
        self.beta1 = betas.0;
        self.beta2 = betas.1;
        self
    }

    /// Sets the numerical-stability term `eps` (builder style).
    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Sets the decoupled weight decay coefficient (builder style).
    #[must_use]
    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Enables or disables the AMSGrad variant (builder style).
    #[must_use]
    pub fn amsgrad(mut self, amsgrad: bool) -> Self {
        self.amsgrad = amsgrad;
        self
    }

    /// Lazily allocates per-parameter state the first time `step` is called.
    fn ensure_state_initialized(&mut self) {
        if self.state.is_empty() {
            self.state = self
                .params
                .iter()
                .map(|p| {
                    let data = p.data();
                    AdamState::new(data.shape(), data.device())
                })
                .collect();
        }
    }
}

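// AdamW (Loshchilov & Hutter, 2019) decouples weight decay from the gradient:
// each step first scales the parameter by (1 - lr * weight_decay) and then
// applies the standard Adam moment update to the raw gradient, so the decay
// strength does not depend on the adaptive denominator. (As written, the
// `amsgrad` flag is stored but not used by this AdamW update.)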
impl Optimizer for AdamW {
    fn step(&mut self) {
        self.ensure_state_initialized();

        for (i, param) in self.params.iter().enumerate() {
            if !param.requires_grad() {
                continue;
            }

            let grad = match param.grad() {
                Some(g) => g,
                None => continue,
            };

            let state = &mut self.state[i];
            state.step += 1;

            let param_data = param.data();

            // Fused GPU path: apply decoupled weight decay, then a single Adam kernel.
            #[cfg(feature = "cuda")]
            if param_data.device().is_gpu() {
                // Make sure the gradient lives on the same device as the parameter.
                let grad = if !grad.device().is_gpu() {
                    grad.to_device(param_data.device())
                        .expect("AdamW: failed to migrate CPU gradient to GPU")
                } else {
                    grad
                };

                // Decoupled weight decay: shrink the parameter before the moment update.
                if self.weight_decay > 0.0 {
                    let decay_factor = 1.0 - self.lr * self.weight_decay;
                    let decayed = param_data.mul_scalar(decay_factor);
                    param.update_data(decayed);
                }

                // Re-read the (possibly decayed) parameter data.
                let param_data = param.data();

                let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
                let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);

                param_data.adam_step_inplace(
                    &grad,
                    &state.exp_avg,
                    &state.exp_avg_sq,
                    self.lr,
                    self.beta1,
                    self.beta2,
                    self.eps,
                    0.0, // weight decay was already applied above (decoupled)
                    bias_correction1,
                    bias_correction2,
                );
                continue;
            }

            // CPU fallback: run the update element-wise on host vectors.
            let grad_vec = grad.to_vec();
            let mut param_vec = param_data.to_vec();
            let mut exp_avg_vec = state.exp_avg.to_vec();
            let mut exp_avg_sq_vec = state.exp_avg_sq.to_vec();

            let bias_correction1 = 1.0 - self.beta1.powi(state.step as i32);
            let bias_correction2 = 1.0 - self.beta2.powi(state.step as i32);
            let step_size = self.lr / bias_correction1;
            let beta1 = self.beta1;
            let beta2 = self.beta2;
            let one_minus_beta1 = 1.0 - beta1;
            let one_minus_beta2 = 1.0 - beta2;
            let eps = self.eps;
            let wd_factor = 1.0 - self.lr * self.weight_decay;
            let has_wd = self.weight_decay != 0.0;

            for i in 0..param_vec.len() {
                // Decoupled weight decay: shrink the parameter multiplicatively.
                if has_wd {
                    param_vec[i] *= wd_factor;
                }
                let g = grad_vec[i];
                exp_avg_vec[i] = beta1 * exp_avg_vec[i] + one_minus_beta1 * g;
                exp_avg_sq_vec[i] = beta2 * exp_avg_sq_vec[i] + one_minus_beta2 * g * g;
                let denom = (exp_avg_sq_vec[i] / bias_correction2).sqrt() + eps;
                param_vec[i] -= step_size * exp_avg_vec[i] / denom;
            }

            // Write the updated moments and parameters back into tensors.
            state.exp_avg =
                Tensor::from_vec(exp_avg_vec, param_data.shape()).expect("tensor creation failed");
            state.exp_avg_sq = Tensor::from_vec(exp_avg_sq_vec, param_data.shape())
                .expect("tensor creation failed");
            param.update_data(
                Tensor::from_vec(param_vec, param_data.shape()).expect("tensor creation failed"),
            );
        }
    }

    fn zero_grad(&mut self) {
        for param in &self.params {
            param.zero_grad();
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Parameter] {
        &self.params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use axonml_autograd::Variable;

    #[test]
    fn test_adam_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = Adam::new(vec![param], 0.001);

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
        assert!((optimizer.beta1 - 0.9).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.999).abs() < 1e-6);
    }

    #[test]
    fn test_adam_step() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        param
            .variable()
            .set_grad(Tensor::from_vec(vec![0.1, 0.2, 0.3], &[3]).expect("tensor creation failed"));

        let mut optimizer = Adam::new(vec![param.clone()], 0.1);
        optimizer.step();

        let new_data = param.data().to_vec();
        assert!((new_data[0] - 1.0).abs() > 1e-6);
    }

    #[test]
    fn test_adamw_creation() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);
        let optimizer = AdamW::new(vec![param], 0.001);

        assert!((optimizer.weight_decay - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_adam_builder_pattern() {
        let var = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let param = Parameter::from_variable(var);

        let optimizer = Adam::new(vec![param], 0.001)
            .betas((0.95, 0.9999))
            .eps(1e-7)
            .weight_decay(0.01)
            .amsgrad(true);

        assert!((optimizer.beta1 - 0.95).abs() < 1e-6);
        assert!((optimizer.beta2 - 0.9999).abs() < 1e-6);
        assert!((optimizer.eps - 1e-7).abs() < 1e-9);
        assert!(optimizer.amsgrad);
    }

    #[test]
    fn test_adam_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![0.5, -0.3], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        assert!(
            after[0] < before[0],
            "param[0] should decrease: {} -> {}",
            before[0],
            after[0]
        );
        assert!(
            after[1] < before[1],
            "param[1] should decrease: {} -> {}",
            before[1],
            after[1]
        );

        let delta0 = before[0] - after[0];
        let delta1 = before[1] - after[1];
        assert!(
            (delta0 - delta1).abs() < 1e-6,
            "Uniform gradient should produce uniform update: {} vs {}",
            delta0,
            delta1
        );
    }

    #[test]
    fn test_adam_converges_on_quadratic() {
        // Minimize f(x) = x^2 starting from x = 5.0.
        let var = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            loss.backward();
            opt.step();
        }

        let final_x = param.data().to_vec()[0];
        assert!(
            final_x.abs() < 0.1,
            "Adam should converge near 0 for f(x)=x^2, got {}",
            final_x
        );
    }

    #[test]
    fn test_adam_zero_grad() {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap());
        assert!(param.grad().is_some());

        let mut opt = Adam::new(vec![param.clone()], 0.01);
        opt.zero_grad();

        if let Some(g) = param.grad() {
            let gv = g.to_vec();
            assert!(
                gv.iter().all(|&v| v.abs() < 1e-10),
                "Gradients should be zero after zero_grad: {:?}",
                gv
            );
        }
    }

    #[test]
    fn test_adam_lr_management() {
        let var = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param], 0.001);

        assert!((opt.get_lr() - 0.001).abs() < 1e-8);
        opt.set_lr(0.01);
        assert!((opt.get_lr() - 0.01).abs() < 1e-8);
    }

    #[test]
    fn test_adam_skips_frozen_params() {
        let trainable = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            true,
        ));
        let frozen = Parameter::from_variable(Variable::new(
            Tensor::from_vec(vec![2.0], &[1]).unwrap(),
            false,
        ));

        trainable.set_grad(Tensor::from_vec(vec![1.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![trainable.clone(), frozen.clone()], 0.1);
        opt.step();

        // The trainable parameter moves; the frozen one must stay untouched.
        assert!((trainable.data().to_vec()[0] - 1.0).abs() > 1e-6);
        assert!((frozen.data().to_vec()[0] - 2.0).abs() < 1e-8);
    }

    #[test]
    fn test_adam_weight_decay() {
        // With a zero gradient, only the weight-decay term moves the parameter.
        let var = Variable::new(Tensor::from_vec(vec![10.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![0.0], &[1]).unwrap());

        let mut opt = Adam::new(vec![param.clone()], 0.1).weight_decay(0.1);
        let before = param.data().to_vec()[0];
        opt.step();
        let after = param.data().to_vec()[0];

        assert!(
            after < before,
            "Weight decay should shrink large params: {} -> {}",
            before,
            after
        );
    }

    #[test]
    fn test_adam_multiple_steps_improve() {
        let var = Variable::new(Tensor::from_vec(vec![3.0, -2.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = Adam::new(vec![param.clone()], 0.05);

        // Record the loss f(x) = sum(x^2) at every step.
        let mut losses = Vec::new();
        for _ in 0..50 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            losses.push(loss.data().to_vec()[0]);
            loss.backward();
            opt.step();
        }

        let first = losses[0];
        let last = *losses.last().unwrap();
        assert!(
            last < first * 0.5,
            "Loss should decrease significantly: first={}, last={}",
            first,
            last
        );
    }

    #[test]
    fn test_adamw_step_correctness() {
        let var = Variable::new(Tensor::from_vec(vec![5.0, -3.0], &[2]).unwrap(), true);
        let param = Parameter::from_variable(var);
        param.set_grad(Tensor::from_vec(vec![1.0, -1.0], &[2]).unwrap());

        let mut opt = AdamW::new(vec![param.clone()], 0.01);
        let before = param.data().to_vec();
        opt.step();
        let after = param.data().to_vec();

        assert!(after[0] < before[0], "Positive grad should decrease param");
        assert!(after[1] > before[1], "Negative grad should increase param");
    }

    #[test]
    fn test_adamw_converges() {
        let var = Variable::new(Tensor::from_vec(vec![4.0], &[1]).unwrap(), true);
        let param = Parameter::from_variable(var);
        let mut opt = AdamW::new(vec![param.clone()], 0.1);

        for _ in 0..200 {
            opt.zero_grad();
            let x = param.variable();
            let loss = x.mul_var(&x).sum();
            loss.backward();
            opt.step();
        }

        assert!(
            param.data().to_vec()[0].abs() < 0.1,
            "AdamW should converge near 0, got {}",
            param.data().to_vec()[0]
        );
    }
}