use crate::autograd::{get_grad, Tensor, TensorId};

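/// Common interface implemented by every optimizer in this module.
///
/// Note that for the optimizers below the trait-level `step()` only advances
/// internal bookkeeping (the step counter / `initialized` flag); the actual
/// parameter updates are applied by each optimizer's `step_with_params`,
/// which takes mutable references to the tensors being trained.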
pub trait Optimizer {
    fn step(&mut self);

    fn zero_grad(&mut self);

    fn lr(&self) -> f32;

    fn set_lr(&mut self, lr: f32);
}

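/// Stochastic gradient descent with optional momentum, Nesterov momentum,
/// and L2 weight decay.
///
/// A minimal usage sketch (it mirrors the tests at the bottom of this file
/// and assumes the same `Tensor` helpers they use):
///
/// ```ignore
/// let mut w = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();
/// let mut sgd = SGD::with_momentum(vec![&mut w], 0.1, 0.9).weight_decay(1e-4);
///
/// let loss = w.pow(2.0).sum();
/// loss.backward();
/// sgd.step_with_params(&mut [&mut w]); // apply the update
/// sgd.zero_grad();                     // clear gradients for the next step
/// ```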
#[derive(Debug)]
pub struct SGD {
    /// IDs of the parameters managed by this optimizer (used by `zero_grad`).
    param_ids: Vec<TensorId>,
    /// Learning rate.
    lr: f32,
    /// Momentum factor; `0.0` disables momentum.
    momentum: f32,
    /// L2 weight-decay coefficient.
    weight_decay: f32,
    /// Whether to use Nesterov momentum.
    nesterov: bool,
    /// Per-parameter velocity buffers, allocated lazily on the first step.
    velocities: Vec<Vec<f32>>,
    /// Whether the velocity buffers have been initialized.
    initialized: bool,
}

impl SGD {
    /// Creates a plain SGD optimizer with the given learning rate.
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            momentum: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            velocities: Vec::new(),
            initialized: false,
        }
    }

    /// Creates an SGD optimizer with the given learning rate and momentum.
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn with_momentum(params: Vec<&mut Tensor>, lr: f32, momentum: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            momentum,
            weight_decay: 0.0,
            nesterov: false,
            velocities: Vec::new(),
            initialized: false,
        }
    }

    /// Enables Nesterov momentum.
    #[must_use]
    pub fn nesterov(mut self) -> Self {
        self.nesterov = true;
        self
    }

    /// Sets the L2 weight-decay coefficient.
    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    #[allow(clippy::if_not_else)]
    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        // Parameters without a recorded gradient are skipped.
        let Some(grad) = get_grad(param.id()) else {
            return;
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Lazily allocate the velocity buffer for this parameter.
        if !self.initialized || idx >= self.velocities.len() {
            if idx >= self.velocities.len() {
                self.velocities.resize(idx + 1, Vec::new());
            }
            self.velocities[idx] = vec![0.0; param_data.len()];
        }

        let velocity = &mut self.velocities[idx];

        for i in 0..param_data.len() {
            let mut g = grad_data[i];

            // L2 weight decay: add `wd * param` to the gradient.
            if self.weight_decay != 0.0 {
                g += self.weight_decay * param_data[i];
            }

            if self.momentum != 0.0 {
                velocity[i] = self.momentum * velocity[i] + g;

                if self.nesterov {
                    param_data[i] -= self.lr * (self.momentum * velocity[i] + g);
                } else {
                    param_data[i] -= self.lr * velocity[i];
                }
            } else {
                param_data[i] -= self.lr * g;
            }
        }
    }
}

impl Optimizer for SGD {
    fn step(&mut self) {
        // Parameter updates need mutable access to the tensors, which the
        // trait signature does not provide; use `step_with_params` for the
        // actual update. Here we only mark the optimizer state as initialized.
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

impl SGD {
    /// Applies one SGD update to every parameter in `params`, using the
    /// gradients recorded by the autograd graph.
    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

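/// Adam optimizer (adaptive moment estimation).
///
/// The per-element update implemented by `update_param` below is the
/// standard Adam rule:
///
/// ```text
/// m = beta1 * m + (1 - beta1) * g
/// v = beta2 * v + (1 - beta2) * g^2
/// m_hat = m / (1 - beta1^t)
/// v_hat = v / (1 - beta2^t)
/// param -= lr * m_hat / (sqrt(v_hat) + eps)
/// ```
///
/// Weight decay, if set, is added to the gradient (classic L2 regularization);
/// see `AdamW` for the decoupled variant.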
#[derive(Debug)]
pub struct Adam {
    param_ids: Vec<TensorId>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    /// First-moment (mean) estimates, one buffer per parameter.
    m: Vec<Vec<f32>>,
    /// Second-moment (uncentered variance) estimates, one buffer per parameter.
    v: Vec<Vec<f32>>,
    /// Step counter used for bias correction.
    t: usize,
    initialized: bool,
}

impl Adam {
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
            initialized: false,
        }
    }

    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return;
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Lazily allocate the moment buffers for this parameter.
        if !self.initialized || idx >= self.m.len() {
            if idx >= self.m.len() {
                self.m.resize(idx + 1, Vec::new());
                self.v.resize(idx + 1, Vec::new());
            }
            self.m[idx] = vec![0.0; param_data.len()];
            self.v[idx] = vec![0.0; param_data.len()];
        }

        let m = &mut self.m[idx];
        let v = &mut self.v[idx];

        // Bias corrections for the running moment estimates.
        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..param_data.len() {
            let mut g = grad_data[i];

            // Classic (coupled) L2 weight decay: added to the gradient.
            if self.weight_decay != 0.0 {
                g += self.weight_decay * param_data[i];
            }

            m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * g;
            v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * g * g;

            let m_hat = m[i] / bias_correction1;
            let v_hat = v[i] / bias_correction2;

            param_data[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }

    /// Applies one Adam update to every parameter in `params`, advancing the
    /// step counter used for bias correction.
    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        self.t += 1;
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

impl Optimizer for Adam {
    fn step(&mut self) {
        // See `step_with_params` for the actual parameter updates; the trait
        // method only advances the step counter and marks the state as ready.
        self.t += 1;
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

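/// AdamW optimizer: Adam with decoupled weight decay.
///
/// Unlike `Adam`, the weight-decay term is not folded into the gradient;
/// each step first shrinks the parameter directly,
///
/// ```text
/// param -= lr * weight_decay * param
/// ```
///
/// and then applies the usual bias-corrected Adam update. The default
/// `weight_decay` is `0.01`.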
#[derive(Debug)]
pub struct AdamW {
    param_ids: Vec<TensorId>,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    /// First-moment estimates, one buffer per parameter.
    m: Vec<Vec<f32>>,
    /// Second-moment estimates, one buffer per parameter.
    v: Vec<Vec<f32>>,
    /// Step counter used for bias correction.
    t: usize,
    initialized: bool,
}

impl AdamW {
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
            m: Vec::new(),
            v: Vec::new(),
            t: 0,
            initialized: false,
        }
    }

    #[must_use]
    pub fn betas(mut self, beta1: f32, beta2: f32) -> Self {
        self.beta1 = beta1;
        self.beta2 = beta2;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return;
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Lazily allocate the moment buffers for this parameter.
        if !self.initialized || idx >= self.m.len() {
            if idx >= self.m.len() {
                self.m.resize(idx + 1, Vec::new());
                self.v.resize(idx + 1, Vec::new());
            }
            self.m[idx] = vec![0.0; param_data.len()];
            self.v[idx] = vec![0.0; param_data.len()];
        }

        let m = &mut self.m[idx];
        let v = &mut self.v[idx];

        let bias_correction1 = 1.0 - self.beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.t as i32);

        for i in 0..param_data.len() {
            let g = grad_data[i];

            m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * g;
            v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * g * g;

            let m_hat = m[i] / bias_correction1;
            let v_hat = v[i] / bias_correction2;

            // Decoupled weight decay: shrink the parameter directly, then
            // apply the bias-corrected Adam step.
            param_data[i] -= self.lr * self.weight_decay * param_data[i];
            param_data[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }

    /// Applies one AdamW update to every parameter in `params`, advancing the
    /// step counter used for bias correction.
    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        self.t += 1;
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

impl Optimizer for AdamW {
    fn step(&mut self) {
        // See `step_with_params` for the actual parameter updates.
        self.t += 1;
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

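/// RMSprop optimizer.
///
/// Keeps a running average of squared gradients and divides each gradient by
/// its root:
///
/// ```text
/// v = alpha * v + (1 - alpha) * g^2
/// param -= lr * g / (sqrt(v) + eps)
/// ```
///
/// With `momentum > 0.0` the normalized update is additionally accumulated
/// into a momentum buffer before being applied.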
#[derive(Debug)]
pub struct RMSprop {
    param_ids: Vec<TensorId>,
    lr: f32,
    alpha: f32,
    eps: f32,
    weight_decay: f32,
    momentum: f32,
    /// Running averages of squared gradients, one buffer per parameter.
    v: Vec<Vec<f32>>,
    /// Momentum buffers, one per parameter (used when `momentum > 0.0`).
    buffer: Vec<Vec<f32>>,
    initialized: bool,
}

impl RMSprop {
    #[allow(clippy::needless_pass_by_value)]
    #[must_use]
    pub fn new(params: Vec<&mut Tensor>, lr: f32) -> Self {
        let param_ids: Vec<TensorId> = params.iter().map(|p| p.id()).collect();
        Self {
            param_ids,
            lr,
            alpha: 0.99,
            eps: 1e-8,
            weight_decay: 0.0,
            momentum: 0.0,
            v: Vec::new(),
            buffer: Vec::new(),
            initialized: false,
        }
    }

    #[must_use]
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    #[must_use]
    pub fn eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    #[must_use]
    pub fn weight_decay(mut self, wd: f32) -> Self {
        self.weight_decay = wd;
        self
    }

    fn update_param(&mut self, param: &mut Tensor, idx: usize) {
        let Some(grad) = get_grad(param.id()) else {
            return;
        };

        let grad_data = grad.data();
        let param_data = param.data_mut();

        // Lazily allocate the state buffers for this parameter.
        if !self.initialized || idx >= self.v.len() {
            if idx >= self.v.len() {
                self.v.resize(idx + 1, Vec::new());
                self.buffer.resize(idx + 1, Vec::new());
            }
            self.v[idx] = vec![0.0; param_data.len()];
            self.buffer[idx] = vec![0.0; param_data.len()];
        }

        let v = &mut self.v[idx];
        let buffer = &mut self.buffer[idx];

        for i in 0..param_data.len() {
            let mut g = grad_data[i];

            // L2 weight decay: add `wd * param` to the gradient.
            if self.weight_decay != 0.0 {
                g += self.weight_decay * param_data[i];
            }

            // Running average of squared gradients.
            v[i] = self.alpha * v[i] + (1.0 - self.alpha) * g * g;

            let update = g / (v[i].sqrt() + self.eps);

            if self.momentum > 0.0 {
                buffer[i] = self.momentum * buffer[i] + update;
                param_data[i] -= self.lr * buffer[i];
            } else {
                param_data[i] -= self.lr * update;
            }
        }
    }

    /// Applies one RMSprop update to every parameter in `params`.
    pub fn step_with_params(&mut self, params: &mut [&mut Tensor]) {
        for (idx, param) in params.iter_mut().enumerate() {
            self.update_param(param, idx);
        }
        self.initialized = true;
    }
}

impl Optimizer for RMSprop {
    fn step(&mut self) {
        // See `step_with_params` for the actual parameter updates.
        self.initialized = true;
    }

    fn zero_grad(&mut self) {
        for &id in &self.param_ids {
            crate::autograd::clear_grad(id);
        }
    }

    fn lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::clear_graph;

    #[test]
    fn test_sgd_basic() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let grad = get_grad(param_id).expect("Should have gradient");
        assert_eq!(grad.data(), &[2.0, 4.0, 6.0]);

        let mut sgd = SGD::new(vec![&mut param], 0.1);
        sgd.step_with_params(&mut [&mut param]);

        let expected = [0.8, 1.6, 2.4];
        for (p, e) in param.data().iter().zip(expected.iter()) {
            assert!((p - e).abs() < 1e-5, "Expected {e}, got {p}");
        }
    }

    #[test]
    fn test_sgd_with_momentum() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9);
        sgd.step_with_params(&mut [&mut param]);

        // First step: grad = 2.0, velocity = 2.0, param = 1.0 - 0.1 * 2.0 = 0.8.
        assert!((param.data()[0] - 0.8).abs() < 1e-5);

        clear_graph();
        let loss = param.pow(2.0).sum();
        loss.backward();

        sgd.step_with_params(&mut [&mut param]);

        // Second step: grad = 1.6, velocity = 0.9 * 2.0 + 1.6 = 3.4,
        // param = 0.8 - 0.1 * 3.4 = 0.46.
        assert!((param.data()[0] - 0.46).abs() < 1e-5);
    }

    #[test]
    fn test_adam_basic() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 1.0);
        assert!(param.data()[1] < 2.0);
    }

    #[test]
    fn test_adam_convergence() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.5);

        for _ in 0..100 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            adam.step_with_params(&mut [&mut param]);
        }

        assert!(
            param.data()[0].abs() < 0.1,
            "Parameter should converge to 0, got {}",
            param.data()[0]
        );
    }

    #[test]
    fn test_adamw_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[10.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adamw = AdamW::new(vec![&mut param], 0.1).weight_decay(0.1);
        adamw.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 10.0);
    }

    #[test]
    fn test_rmsprop_basic() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 3.0);
    }

    #[test]
    fn test_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut sgd = SGD::new(vec![&mut param], 0.1);
        sgd.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_learning_rate_change() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut sgd = SGD::new(vec![&mut param], 0.1);

        assert!((sgd.lr() - 0.1).abs() < 1e-6);

        sgd.set_lr(0.01);
        assert!((sgd.lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_nesterov() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9).nesterov();
        sgd.step_with_params(&mut [&mut param]);

        // grad = 4.0, velocity = 4.0; Nesterov update:
        // param = 2.0 - 0.1 * (0.9 * 4.0 + 4.0) = 1.24.
        assert!(
            (param.data()[0] - 1.24).abs() < 1e-5,
            "Nesterov update failed: {}",
            param.data()[0]
        );
    }

    #[test]
    fn test_sgd_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::new(vec![&mut param], 0.1).weight_decay(0.1);
        sgd.step_with_params(&mut [&mut param]);

        // grad = 10.0, plus weight decay 0.1 * 5.0 = 0.5;
        // param = 5.0 - 0.1 * 10.5 = 3.95.
        assert!(
            (param.data()[0] - 3.95).abs() < 1e-5,
            "Weight decay update failed: {}",
            param.data()[0]
        );
    }

    #[test]
    fn test_adam_with_custom_betas() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1).betas(0.8, 0.99);
        adam.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 1.0);
    }

    #[test]
    fn test_adam_with_eps() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1).eps(1e-6);
        adam.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 1.0);
    }

    #[test]
    fn test_adam_with_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[10.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam_wd = Adam::new(vec![&mut param], 0.1).weight_decay(0.1);
        adam_wd.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 10.0);
    }

    #[test]
    fn test_adamw_with_custom_betas_and_eps() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adamw = AdamW::new(vec![&mut param], 0.1)
            .betas(0.85, 0.995)
            .eps(1e-7);
        adamw.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 3.0);
    }

    #[test]
    fn test_adamw_lr_methods() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.01);

        assert!((adamw.lr() - 0.01).abs() < 1e-6);
        adamw.set_lr(0.001);
        assert!((adamw.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adamw_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adamw = AdamW::new(vec![&mut param], 0.1);
        adamw.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_adamw_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.1);

        adamw.step();
        assert!(adamw.initialized);
        assert_eq!(adamw.t, 1);
    }

    #[test]
    fn test_rmsprop_with_alpha() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).alpha(0.9);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 2.0);
    }

    #[test]
    fn test_rmsprop_with_eps() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).eps(1e-6);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 2.0);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).momentum(0.9);
        rmsprop.step_with_params(&mut [&mut param]);

        let after_first = param.data()[0];
        assert!(after_first < 3.0);

        clear_graph();
        let loss = param.pow(2.0).sum();
        loss.backward();

        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < after_first);
    }

    #[test]
    fn test_rmsprop_with_weight_decay() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).weight_decay(0.1);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 5.0);
    }

    #[test]
    fn test_rmsprop_lr_methods() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.01);

        assert!((rmsprop.lr() - 0.01).abs() < 1e-6);
        rmsprop.set_lr(0.001);
        assert!((rmsprop.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_rmsprop_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);

        rmsprop.step();
        assert!(rmsprop.initialized);
    }

    #[test]
    fn test_sgd_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut sgd = SGD::new(vec![&mut param], 0.1);

        sgd.step();
        assert!(sgd.initialized);
    }

    #[test]
    fn test_adam_step_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.1);

        adam.step();
        assert!(adam.initialized);
        assert_eq!(adam.t, 1);
    }

    #[test]
    fn test_adam_lr_methods() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.01);

        assert!((adam.lr() - 0.01).abs() < 1e-6);
        adam.set_lr(0.001);
        assert!((adam.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adam_zero_grad() {
        clear_graph();

        let mut param = Tensor::from_slice(&[2.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_sgd_multi_element_tensor() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::new(vec![&mut param], 0.1);
        sgd.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 1.0);
        assert!(param.data()[1] < 2.0);
        assert!(param.data()[2] < 3.0);
        assert!(param.data()[3] < 4.0);
    }

    #[test]
    fn test_adam_multi_element_tensor() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 1.0);
        assert!(param.data()[1] < 2.0);
        assert!(param.data()[2] < 3.0);
    }

    #[test]
    fn test_adamw_multi_step() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.5).weight_decay(0.01);

        for _ in 0..10 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            adamw.step_with_params(&mut [&mut param]);
        }

        assert!(param.data()[0] < 1.0);
    }

    #[test]
    fn test_rmsprop_convergence() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.5);

        for _ in 0..10 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            rmsprop.step_with_params(&mut [&mut param]);
        }

        assert!(param.data()[0] < 1.0);
    }

    #[test]
    fn test_sgd_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let sgd = SGD::new(vec![&mut param], 0.05);
        assert!((sgd.lr() - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_adam_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adam = Adam::new(vec![&mut param], 0.001);
        assert!((adam.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adamw_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adamw = AdamW::new(vec![&mut param], 0.002);
        assert!((adamw.lr() - 0.002).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_lr_accessor() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let rmsprop = RMSprop::new(vec![&mut param], 0.003);
        assert!((rmsprop.lr() - 0.003).abs() < 1e-6);
    }

    #[test]
    fn test_adam_set_lr() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.set_lr(0.001);
        assert!((adam.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adamw_set_lr() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adamw = AdamW::new(vec![&mut param], 0.1);
        adamw.set_lr(0.001);
        assert!((adamw.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_rmsprop_set_lr() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.set_lr(0.001);
        assert!((rmsprop.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_adam_zero_grad_clears() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adam = Adam::new(vec![&mut param], 0.1);
        adam.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_adamw_zero_grad_clears() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut adamw = AdamW::new(vec![&mut param], 0.1);
        adamw.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_rmsprop_zero_grad_clears() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let param_id = param.id();

        let loss = param.pow(2.0).sum();
        loss.backward();

        assert!(get_grad(param_id).is_some());

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1);
        rmsprop.zero_grad();

        assert!(get_grad(param_id).is_none());
    }

    #[test]
    fn test_sgd_multiple_params() {
        clear_graph();

        let mut param1 = Tensor::from_slice(&[1.0]).requires_grad();
        let mut param2 = Tensor::from_slice(&[2.0]).requires_grad();

        let loss1 = param1.pow(2.0).sum();
        let loss2 = param2.pow(2.0).sum();
        let loss = loss1.add(&loss2);
        loss.backward();

        let mut sgd = SGD::new(vec![&mut param1, &mut param2], 0.1);
        sgd.step_with_params(&mut [&mut param1, &mut param2]);

        assert!(param1.data()[0] < 1.0);
        assert!(param2.data()[0] < 2.0);
    }

    #[test]
    fn test_adam_multiple_params() {
        clear_graph();

        let mut param1 = Tensor::from_slice(&[1.0]).requires_grad();
        let mut param2 = Tensor::from_slice(&[2.0]).requires_grad();

        let loss1 = param1.pow(2.0).sum();
        let loss2 = param2.pow(2.0).sum();
        let loss = loss1.add(&loss2);
        loss.backward();

        let mut adam = Adam::new(vec![&mut param1, &mut param2], 0.1);
        adam.step_with_params(&mut [&mut param1, &mut param2]);

        assert!(param1.data()[0] < 1.0);
        assert!(param2.data()[0] < 2.0);
    }

    #[test]
    fn test_rmsprop_alpha_builder() {
        clear_graph();

        let mut param = Tensor::from_slice(&[5.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut rmsprop = RMSprop::new(vec![&mut param], 0.1).alpha(0.9);
        rmsprop.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 5.0);
    }

    #[test]
    fn test_sgd_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let sgd = SGD::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", sgd);
        assert!(debug_str.contains("SGD"));
    }

    #[test]
    fn test_adam_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adam = Adam::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", adam);
        assert!(debug_str.contains("Adam"));
    }

    #[test]
    fn test_adamw_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let adamw = AdamW::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", adamw);
        assert!(debug_str.contains("AdamW"));
    }

    #[test]
    fn test_rmsprop_debug_trait() {
        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let rmsprop = RMSprop::new(vec![&mut param], 0.1);
        let debug_str = format!("{:?}", rmsprop);
        assert!(debug_str.contains("RMSprop"));
    }

    #[test]
    fn test_sgd_empty_params() {
        let sgd = SGD::new(vec![], 0.1);
        assert!((sgd.lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_adam_empty_params() {
        let adam = Adam::new(vec![], 0.1);
        assert!((adam.lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_momentum_initialization() {
        clear_graph();

        let mut param = Tensor::from_slice(&[3.0, 4.0]).requires_grad();

        let loss = param.pow(2.0).sum();
        loss.backward();

        let mut sgd = SGD::with_momentum(vec![&mut param], 0.1, 0.9);
        sgd.step_with_params(&mut [&mut param]);

        assert!(param.data()[0] < 3.0);
        assert!(param.data()[1] < 4.0);
    }

    #[test]
    fn test_adam_step_counter() {
        clear_graph();

        let mut param = Tensor::from_slice(&[1.0]).requires_grad();
        let mut adam = Adam::new(vec![&mut param], 0.1);

        for _ in 0..3 {
            clear_graph();
            let loss = param.pow(2.0).sum();
            loss.backward();
            adam.step_with_params(&mut [&mut param]);
        }

        assert!(param.data()[0] < 1.0);
    }
}