1use crate::error::{OptimizeError, OptimizeResult};
26use std::f64::consts::PI;
27
/// Learning-rate schedule evaluated as a pure function of the step index.
///
/// Every variant carries the parameters of one schedule shape; call
/// [`LrSchedule::lr_at`] with a zero-based step counter to obtain the rate.
#[derive(Debug, Clone)]
pub enum LrSchedule {
    /// The same learning rate at every step.
    Constant(f64),

    /// Geometric decay: `initial * decay^step`.
    ExponentialDecay {
        /// Learning rate returned at step 0.
        initial: f64,
        /// Factor the rate is multiplied by for each elapsed step.
        decay: f64,
    },

    /// Cosine curve that starts at `lr_max`, reaches `lr_min` at step `t_max`
    /// and is back at `lr_max` at step `2 * t_max`, repeating with that period.
    CosineAnnealing {
        /// Value at the start of every period.
        lr_max: f64,
        /// Value at the midpoint of every period (step `t_max`).
        lr_min: f64,
        /// Half-period in steps; `0` is treated as `1`.
        t_max: usize,
    },

    /// Linear ramp from 0 towards `lr_peak` during the first `warmup_steps`
    /// steps, then a cosine curve from `lr_peak` that reaches `lr_min` at
    /// `total_steps`.
    WarmupCosine {
        /// Number of steps in the linear ramp.
        warmup_steps: usize,
        /// Rate at the end of the ramp / start of the cosine phase.
        lr_peak: f64,
        /// Rate reached when the cosine phase completes.
        lr_min: f64,
        /// Step at which the cosine phase completes; values not larger than
        /// `warmup_steps` are treated as `warmup_steps + 1`.
        total_steps: usize,
    },

    /// Piecewise-constant decay: `initial * gamma^(step / step_size)`.
    StepLr {
        /// Learning rate for the first `step_size` steps.
        initial: f64,
        /// Number of steps between two decays; `0` is treated as `1`.
        step_size: usize,
        /// Factor applied at every decay.
        gamma: f64,
    },
}

impl LrSchedule {
    /// Returns the learning rate for the zero-based `step`.
    ///
    /// The function is pure: it reads only the variant's parameters and `step`.
    ///
    /// Edge cases:
    /// * `CosineAnnealing` with `t_max == 0` behaves as if `t_max == 1`
    ///   (it previously divided by zero and produced NaN).
    /// * Exponents larger than `i32::MAX` saturate instead of wrapping to a
    ///   negative value, so very large steps can no longer *increase* the rate
    ///   of a decaying schedule.
    /// * `WarmupCosine` does not clamp its progress: for steps past
    ///   `total_steps` the cosine term keeps evolving.
    pub fn lr_at(&self, step: usize) -> f64 {
        match self {
            LrSchedule::Constant(lr) => *lr,

            LrSchedule::ExponentialDecay { initial, decay } => {
                initial * decay.powi(Self::saturating_exponent(step))
            }

            LrSchedule::CosineAnnealing {
                lr_max,
                lr_min,
                t_max,
            } => {
                // Clamp once and use the clamped value for both the period and
                // the divisor so that `t_max == 0` cannot produce 0/0 or 1/0.
                let half_period = (*t_max).max(1);
                let t = (step % half_period.saturating_mul(2)) as f64;
                let cos_inner = PI * t / half_period as f64;
                lr_min + 0.5 * (lr_max - lr_min) * (1.0 + cos_inner.cos())
            }

            LrSchedule::WarmupCosine {
                warmup_steps,
                lr_peak,
                lr_min,
                total_steps,
            } => {
                let ws = *warmup_steps;
                if step < ws {
                    // `ws > step >= 0` in this branch, so `ws` is non-zero.
                    lr_peak * step as f64 / ws as f64
                } else {
                    // Length of the cosine phase, never less than one step.
                    let decay_span = (*total_steps).saturating_sub(ws).max(1);
                    let progress = (step - ws) as f64 / decay_span as f64;
                    lr_min + 0.5 * (lr_peak - lr_min) * (1.0 + (PI * progress).cos())
                }
            }

            LrSchedule::StepLr {
                initial,
                step_size,
                gamma,
            } => {
                let n_decays = step / (*step_size).max(1);
                initial * gamma.powi(Self::saturating_exponent(n_decays))
            }
        }
    }

    /// Converts a step count into a `powi` exponent, saturating at `i32::MAX`
    /// rather than wrapping the way a bare `as i32` cast does.
    fn saturating_exponent(n: usize) -> i32 {
        n.min(i32::MAX as usize) as i32
    }
}
138
139#[derive(Debug, Clone)]
151pub struct Sgd {
152 pub learning_rate: f64,
154 pub momentum: f64,
156 pub weight_decay: f64,
158 pub nesterov: bool,
160 velocity: Vec<f64>,
162}
163
164impl Sgd {
165 pub fn new(learning_rate: f64, momentum: f64) -> Self {
167 Self {
168 learning_rate,
169 momentum,
170 weight_decay: 0.0,
171 nesterov: false,
172 velocity: Vec::new(),
173 }
174 }
175
176 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
182 if params.len() != grad.len() {
183 return Err(OptimizeError::InvalidInput(format!(
184 "params length {} != grad length {}",
185 params.len(),
186 grad.len()
187 )));
188 }
189
190 let n = params.len();
191 if self.velocity.len() != n {
192 self.velocity = vec![0.0; n];
193 }
194
195 let lr = self.learning_rate;
196 let mu = self.momentum;
197 let wd = self.weight_decay;
198
199 if self.nesterov {
200 for i in 0..n {
201 let g = grad[i] + wd * params[i];
202 self.velocity[i] = mu * self.velocity[i] + g;
203 params[i] -= lr * (mu * self.velocity[i] + g);
204 }
205 } else {
206 for i in 0..n {
207 let g = grad[i] + wd * params[i];
208 self.velocity[i] = mu * self.velocity[i] + g;
209 params[i] -= lr * self.velocity[i];
210 }
211 }
212 Ok(())
213 }
214
215 pub fn zero_velocity(&mut self, n: usize) {
217 self.velocity = vec![0.0; n];
218 }
219}
220
221#[derive(Debug, Clone)]
233pub struct Adam {
234 pub lr: f64,
236 pub beta1: f64,
238 pub beta2: f64,
240 pub eps: f64,
242 pub weight_decay: f64,
244 m: Vec<f64>,
246 v: Vec<f64>,
248 t: usize,
250}
251
252impl Adam {
253 pub fn new(lr: f64) -> Self {
255 Self {
256 lr,
257 beta1: 0.9,
258 beta2: 0.999,
259 eps: 1e-8,
260 weight_decay: 0.0,
261 m: Vec::new(),
262 v: Vec::new(),
263 t: 0,
264 }
265 }
266
267 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
269 if params.len() != grad.len() {
270 return Err(OptimizeError::InvalidInput(format!(
271 "params length {} != grad length {}",
272 params.len(),
273 grad.len()
274 )));
275 }
276
277 let n = params.len();
278 if self.m.len() != n {
279 self.m = vec![0.0; n];
280 self.v = vec![0.0; n];
281 }
282
283 self.t += 1;
284 let t = self.t as f64;
285 let bias_corr1 = 1.0 - self.beta1.powf(t);
286 let bias_corr2 = 1.0 - self.beta2.powf(t);
287
288 for i in 0..n {
289 let g = grad[i] + self.weight_decay * params[i];
290 self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
291 self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
292 let m_hat = self.m[i] / bias_corr1;
293 let v_hat = self.v[i] / bias_corr2;
294 params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
295 }
296 Ok(())
297 }
298
299 pub fn reset_state(&mut self) {
301 self.m.clear();
302 self.v.clear();
303 self.t = 0;
304 }
305}
306
307#[derive(Debug, Clone)]
318pub struct AdaGrad {
319 pub lr: f64,
321 pub eps: f64,
323 pub weight_decay: f64,
325 sum_sq_grad: Vec<f64>,
327}
328
329impl AdaGrad {
330 pub fn new(lr: f64) -> Self {
332 Self {
333 lr,
334 eps: 1e-8,
335 weight_decay: 0.0,
336 sum_sq_grad: Vec::new(),
337 }
338 }
339
340 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
342 if params.len() != grad.len() {
343 return Err(OptimizeError::InvalidInput(format!(
344 "params/grad length mismatch: {} vs {}",
345 params.len(),
346 grad.len()
347 )));
348 }
349
350 let n = params.len();
351 if self.sum_sq_grad.len() != n {
352 self.sum_sq_grad = vec![0.0; n];
353 }
354
355 for i in 0..n {
356 let g = grad[i] + self.weight_decay * params[i];
357 self.sum_sq_grad[i] += g * g;
358 params[i] -= self.lr * g / (self.sum_sq_grad[i].sqrt() + self.eps);
359 }
360 Ok(())
361 }
362
363 pub fn reset_state(&mut self) {
365 self.sum_sq_grad.clear();
366 }
367}
368
369#[derive(Debug, Clone)]
383pub struct RmsProp {
384 pub lr: f64,
386 pub alpha: f64,
388 pub eps: f64,
390 pub momentum: f64,
392 sq_avg: Vec<f64>,
394 velocity: Vec<f64>,
396}
397
398impl RmsProp {
399 pub fn new(lr: f64) -> Self {
401 Self {
402 lr,
403 alpha: 0.99,
404 eps: 1e-8,
405 momentum: 0.0,
406 sq_avg: Vec::new(),
407 velocity: Vec::new(),
408 }
409 }
410
411 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
413 if params.len() != grad.len() {
414 return Err(OptimizeError::InvalidInput(format!(
415 "params/grad length mismatch: {} vs {}",
416 params.len(),
417 grad.len()
418 )));
419 }
420
421 let n = params.len();
422 if self.sq_avg.len() != n {
423 self.sq_avg = vec![0.0; n];
424 self.velocity = vec![0.0; n];
425 }
426
427 for i in 0..n {
428 let g = grad[i];
429 self.sq_avg[i] = self.alpha * self.sq_avg[i] + (1.0 - self.alpha) * g * g;
430 let denom = self.sq_avg[i].sqrt() + self.eps;
431 if self.momentum > 0.0 {
432 self.velocity[i] = self.momentum * self.velocity[i] + self.lr * g / denom;
433 params[i] -= self.velocity[i];
434 } else {
435 params[i] -= self.lr * g / denom;
436 }
437 }
438 Ok(())
439 }
440
441 pub fn reset_state(&mut self) {
443 self.sq_avg.clear();
444 self.velocity.clear();
445 }
446}
447
448#[derive(Debug, Clone)]
459pub struct AdamW {
460 pub lr: f64,
462 pub beta1: f64,
464 pub beta2: f64,
466 pub eps: f64,
468 pub weight_decay: f64,
470 m: Vec<f64>,
472 v: Vec<f64>,
474 t: usize,
476}
477
478impl AdamW {
479 pub fn new(lr: f64) -> Self {
482 Self {
483 lr,
484 beta1: 0.9,
485 beta2: 0.999,
486 eps: 1e-8,
487 weight_decay: 0.01,
488 m: Vec::new(),
489 v: Vec::new(),
490 t: 0,
491 }
492 }
493
494 pub fn step(&mut self, params: &mut Vec<f64>, grad: &[f64]) -> OptimizeResult<()> {
496 if params.len() != grad.len() {
497 return Err(OptimizeError::InvalidInput(format!(
498 "params/grad length mismatch: {} vs {}",
499 params.len(),
500 grad.len()
501 )));
502 }
503
504 let n = params.len();
505 if self.m.len() != n {
506 self.m = vec![0.0; n];
507 self.v = vec![0.0; n];
508 }
509
510 self.t += 1;
511 let t = self.t as f64;
512 let bc1 = 1.0 - self.beta1.powf(t);
513 let bc2 = 1.0 - self.beta2.powf(t);
514
515 for i in 0..n {
516 params[i] *= 1.0 - self.lr * self.weight_decay;
518
519 let g = grad[i];
520 self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
521 self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
522 let m_hat = self.m[i] / bc1;
523 let v_hat = self.v[i] / bc2;
524 params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
525 }
526 Ok(())
527 }
528
529 pub fn reset_state(&mut self) {
531 self.m.clear();
532 self.v.clear();
533 self.t = 0;
534 }
535}
536
537#[derive(Debug, Clone)]
553pub struct Svrg {
554 pub lr: f64,
556 pub n: usize,
558 pub update_freq: usize,
560 snapshot_params: Vec<f64>,
562 snapshot_grad: Vec<f64>,
564 inner_t: usize,
566}
567
568impl Svrg {
569 pub fn new(lr: f64, n: usize, update_freq: usize) -> Self {
576 Self {
577 lr,
578 n,
579 update_freq,
580 snapshot_params: Vec::new(),
581 snapshot_grad: Vec::new(),
582 inner_t: 0,
583 }
584 }
585
586 pub fn step(
597 &mut self,
598 params: &mut Vec<f64>,
599 stochastic_grad: &[f64],
600 snapshot_grad_i: &[f64],
601 ) -> OptimizeResult<()> {
602 let n = params.len();
603
604 if stochastic_grad.len() != n || snapshot_grad_i.len() != n {
605 return Err(OptimizeError::InvalidInput(format!(
606 "SVRG gradient/param length mismatch: params={}, sg={}, sgi={}",
607 n,
608 stochastic_grad.len(),
609 snapshot_grad_i.len()
610 )));
611 }
612
613 if self.snapshot_grad.len() != n {
614 return Err(OptimizeError::InvalidInput(
615 "SVRG: snapshot not initialised — call update_snapshot first".to_string(),
616 ));
617 }
618
619 for i in 0..n {
621 let g_tilde = stochastic_grad[i] - snapshot_grad_i[i] + self.snapshot_grad[i];
622 params[i] -= self.lr * g_tilde;
623 }
624
625 self.inner_t += 1;
626 Ok(())
627 }
628
629 pub fn update_snapshot(&mut self, params: &[f64], full_grad: &[f64]) {
634 self.snapshot_params = params.to_vec();
635 self.snapshot_grad = full_grad.to_vec();
636 self.inner_t = 0;
637 }
638
639 pub fn needs_snapshot_update(&self) -> bool {
641 self.inner_t >= self.update_freq
642 }
643
644 pub fn snapshot_params(&self) -> &[f64] {
646 &self.snapshot_params
647 }
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Gradient of `f(x) = sum(x_i^2)`: `2 * x` element-wise.
    fn quadratic_grad(params: &[f64]) -> Vec<f64> {
        params.iter().map(|x| 2.0 * *x).collect()
    }

    /// Runs `iters` optimiser steps on the quadratic bowl, feeding each step
    /// the gradient at the current point and panicking if a step fails.
    fn descend<F, E>(params: &mut Vec<f64>, iters: usize, mut step: F)
    where
        F: FnMut(&mut Vec<f64>, &[f64]) -> Result<(), E>,
        E: std::fmt::Debug,
    {
        for _ in 0..iters {
            let g = quadratic_grad(params.as_slice());
            step(&mut *params, g.as_slice()).expect("step failed");
        }
    }

    // ---- learning-rate schedules ----

    /// A constant schedule returns its value at every step.
    #[test]
    fn test_constant_schedule() {
        let schedule = LrSchedule::Constant(0.01);
        for &step in &[0usize, 1000] {
            assert_abs_diff_eq!(schedule.lr_at(step), 0.01, epsilon = 1e-14);
        }
    }

    /// Exponential decay multiplies by `decay` once per step.
    #[test]
    fn test_exponential_decay_schedule() {
        let schedule = LrSchedule::ExponentialDecay {
            initial: 0.1,
            decay: 0.9,
        };
        assert_abs_diff_eq!(schedule.lr_at(0), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(schedule.lr_at(1), 0.09, epsilon = 1e-10);
        let after_ten = 0.1 * 0.9_f64.powi(10);
        assert_abs_diff_eq!(schedule.lr_at(10), after_ten, epsilon = 1e-10);
    }

    /// Cosine annealing starts at `lr_max`.
    #[test]
    fn test_cosine_annealing_at_zero() {
        let schedule = LrSchedule::CosineAnnealing {
            lr_max: 0.1,
            lr_min: 0.0,
            t_max: 100,
        };
        let lr0 = schedule.lr_at(0);
        assert_abs_diff_eq!(lr0, 0.1, epsilon = 1e-10);
    }

    /// Cosine annealing reaches `lr_min` at step `t_max`.
    #[test]
    fn test_cosine_annealing_at_t_max() {
        let schedule = LrSchedule::CosineAnnealing {
            lr_max: 0.1,
            lr_min: 0.001,
            t_max: 50,
        };
        let lr_mid = schedule.lr_at(50);
        assert_abs_diff_eq!(lr_mid, 0.001, epsilon = 1e-10);
    }

    /// The warm-up ramp is linear and ends near the peak rate.
    #[test]
    fn test_warmup_cosine_warmup_phase() {
        let schedule = LrSchedule::WarmupCosine {
            warmup_steps: 10,
            lr_peak: 0.1,
            lr_min: 0.0,
            total_steps: 110,
        };
        assert_abs_diff_eq!(schedule.lr_at(5), 0.05, epsilon = 1e-10);

        let lr10 = schedule.lr_at(10);
        let near_peak = (0.09..=0.1 + 1e-9).contains(&lr10);
        assert!(near_peak, "lr at warmup end ≈ peak, got {}", lr10);
    }

    /// Step decay halves the rate every `step_size` steps.
    #[test]
    fn test_step_lr_schedule() {
        let schedule = LrSchedule::StepLr {
            initial: 0.1,
            step_size: 10,
            gamma: 0.5,
        };
        let expected = [(0usize, 0.1), (9, 0.1), (10, 0.05), (20, 0.025)];
        for &(step, lr) in expected.iter() {
            assert_abs_diff_eq!(schedule.lr_at(step), lr, epsilon = 1e-12);
        }
    }

    // ---- SGD ----

    /// Plain SGD converges on the quadratic bowl.
    #[test]
    fn test_sgd_converges_quadratic() {
        let mut opt = Sgd::new(0.1, 0.0);
        let mut point = vec![1.0, -2.0];
        descend(&mut point, 200, |x, g| opt.step(x, g));
        assert_abs_diff_eq!(point[0], 0.0, epsilon = 1e-4);
        assert_abs_diff_eq!(point[1], 0.0, epsilon = 1e-4);
    }

    /// SGD with momentum converges on the quadratic bowl.
    #[test]
    fn test_sgd_momentum_converges() {
        let mut opt = Sgd::new(0.05, 0.9);
        let mut point = vec![2.0, -1.5];
        descend(&mut point, 500, |x, g| opt.step(x, g));
        assert_abs_diff_eq!(point[0], 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(point[1], 0.0, epsilon = 1e-3);
    }

    /// SGD with Nesterov momentum converges on the quadratic bowl.
    #[test]
    fn test_sgd_nesterov() {
        let mut opt = Sgd::new(0.05, 0.9);
        opt.nesterov = true;
        let mut point = vec![1.0, 1.0];
        descend(&mut point, 500, |x, g| opt.step(x, g));
        assert_abs_diff_eq!(point[0], 0.0, epsilon = 1e-3);
    }

    /// With a zero gradient, weight decay alone shrinks the parameter.
    #[test]
    fn test_sgd_weight_decay() {
        let mut opt = Sgd::new(0.01, 0.0);
        opt.weight_decay = 0.1;
        let mut point = vec![1.0];
        opt.step(&mut point, &[0.0]).expect("step failed");
        assert!(point[0] < 1.0, "weight decay should shrink param");
    }

    /// Mismatched parameter/gradient lengths are rejected.
    #[test]
    fn test_sgd_length_mismatch() {
        let mut opt = Sgd::new(0.01, 0.0);
        let mut point = vec![1.0, 2.0];
        let outcome = opt.step(&mut point, &[0.1]);
        assert!(outcome.is_err());
    }

    /// `zero_velocity` installs a zeroed buffer of the requested length.
    #[test]
    fn test_sgd_zero_velocity() {
        let mut opt = Sgd::new(0.01, 0.9);
        opt.zero_velocity(5);
        assert_eq!(opt.velocity.len(), 5);
        assert!(opt.velocity.iter().all(|&v| v == 0.0));
    }

    // ---- Adam ----

    /// Adam converges on the quadratic bowl.
    #[test]
    fn test_adam_converges() {
        let mut opt = Adam::new(0.01);
        let mut point = vec![3.0, -3.0];
        descend(&mut point, 1000, |x, g| opt.step(x, g));
        assert_abs_diff_eq!(point[0], 0.0, epsilon = 1e-2);
        assert_abs_diff_eq!(point[1], 0.0, epsilon = 1e-2);
    }

    /// `reset_state` clears the moments and the step counter.
    #[test]
    fn test_adam_reset_state() {
        let mut opt = Adam::new(0.01);
        let mut point = vec![1.0];
        opt.step(&mut point, &[0.5]).expect("step failed");
        assert_eq!(opt.t, 1);

        opt.reset_state();
        assert_eq!(opt.t, 0);
        assert!(opt.m.is_empty());
        assert!(opt.v.is_empty());
    }

    /// Coupled weight decay moves the parameter even with a zero gradient.
    #[test]
    fn test_adam_weight_decay_coupled() {
        let mut opt = Adam::new(0.001);
        opt.weight_decay = 0.01;
        let mut point = vec![1.0];
        let before = point[0];
        opt.step(&mut point, &[0.0]).expect("step failed");
        assert!(point[0] < before, "weight decay should reduce param");
    }

    // ---- AdaGrad ----

    /// AdaGrad makes progress towards the minimum of the quadratic bowl.
    #[test]
    fn test_adagrad_converges() {
        let mut opt = AdaGrad::new(0.5);
        let mut point = vec![3.0, -2.0];
        descend(&mut point, 2000, |x, g| opt.step(x, g));
        assert!(
            point[0].abs() < 0.5,
            "adagrad should converge, p[0]={}",
            point[0]
        );
    }

    /// `reset_state` empties the accumulator.
    #[test]
    fn test_adagrad_reset() {
        let mut opt = AdaGrad::new(0.1);
        let mut point = vec![1.0];
        opt.step(&mut point, &[1.0]).expect("step failed");
        assert_eq!(opt.sum_sq_grad.len(), 1);

        opt.reset_state();
        assert!(opt.sum_sq_grad.is_empty());
    }

    // ---- RMSProp ----

    /// RMSProp makes progress towards the minimum of the quadratic bowl.
    #[test]
    fn test_rmsprop_converges() {
        let mut opt = RmsProp::new(0.01);
        let mut point = vec![2.0, -2.0];
        descend(&mut point, 1000, |x, g| opt.step(x, g));
        assert!(point[0].abs() < 0.1, "rmsprop p[0]={}", point[0]);
    }

    /// RMSProp with momentum makes progress towards the minimum.
    #[test]
    fn test_rmsprop_with_momentum() {
        let mut opt = RmsProp::new(0.01);
        opt.momentum = 0.9;
        let mut point = vec![1.0, 1.0];
        descend(&mut point, 500, |x, g| opt.step(x, g));
        assert!(point[0].abs() < 0.5, "rmsprop+momentum p[0]={}", point[0]);
    }

    /// Mismatched parameter/gradient lengths are rejected.
    #[test]
    fn test_rmsprop_length_mismatch() {
        let mut opt = RmsProp::new(0.01);
        let mut point = vec![1.0, 2.0];
        let outcome = opt.step(&mut point, &[0.1]);
        assert!(outcome.is_err());
    }

    // ---- AdamW ----

    /// Decoupled weight decay shrinks the parameter even with a zero gradient.
    #[test]
    fn test_adamw_decoupled_wd() {
        let mut opt = AdamW::new(0.001);
        opt.weight_decay = 0.1;
        let mut point = vec![1.0];
        let before = point[0];
        opt.step(&mut point, &[0.0]).expect("step failed");
        assert!(point[0] < before, "decoupled WD should shrink param");
    }

    /// AdamW without weight decay makes progress towards the minimum.
    #[test]
    fn test_adamw_converges() {
        let mut opt = AdamW::new(0.01);
        opt.weight_decay = 0.0;
        let mut point = vec![2.0, -2.0];
        descend(&mut point, 1000, |x, g| opt.step(x, g));
        assert!(point[0].abs() < 0.1, "adamw p[0]={}", point[0]);
    }

    /// `reset_state` clears the step counter and the first moment.
    #[test]
    fn test_adamw_reset() {
        let mut opt = AdamW::new(0.001);
        let mut point = vec![1.0];
        opt.step(&mut point, &[0.5]).expect("step failed");
        assert_eq!(opt.t, 1);

        opt.reset_state();
        assert_eq!(opt.t, 0);
        assert!(opt.m.is_empty());
    }

    // ---- SVRG ----

    /// Stepping before any snapshot has been stored is an error.
    #[test]
    fn test_svrg_needs_snapshot() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut point = vec![1.0, 2.0];
        let stochastic = vec![0.1, 0.2];
        let at_snapshot = vec![0.05, 0.1];
        let outcome = svrg.step(&mut point, &stochastic, &at_snapshot);
        assert!(outcome.is_err());
    }

    /// After a snapshot, the step uses the variance-reduced gradient.
    #[test]
    fn test_svrg_step_after_snapshot() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut point = vec![1.0, 1.0];
        let full_grad = vec![2.0, 2.0];
        svrg.update_snapshot(&point, &full_grad);

        let stochastic = vec![2.1, 1.9];
        let at_snapshot = vec![2.0, 2.0];
        svrg.step(&mut point, &stochastic, &at_snapshot)
            .expect("step failed");
        assert_abs_diff_eq!(point[0], 1.0 - 0.01 * 2.1, epsilon = 1e-12);
    }

    /// A snapshot refresh is requested after `update_freq` inner steps.
    #[test]
    fn test_svrg_update_freq() {
        let mut svrg = Svrg::new(0.01, 100, 3);
        let mut point = vec![1.0];
        svrg.update_snapshot(&point, &[0.0]);
        assert!(!svrg.needs_snapshot_update());

        let mut taken = 0;
        while taken < 3 {
            svrg.step(&mut point, &[0.0], &[0.0]).expect("step");
            taken += 1;
        }
        assert!(svrg.needs_snapshot_update());
    }

    /// The stored snapshot parameters are returned unchanged.
    #[test]
    fn test_svrg_snapshot_params() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let snapshot = vec![3.0, 4.0];
        svrg.update_snapshot(&snapshot, &[0.0, 0.0]);
        assert_eq!(svrg.snapshot_params(), &[3.0, 4.0]);
    }

    /// A stochastic gradient of the wrong length is rejected.
    #[test]
    fn test_svrg_length_mismatch() {
        let mut svrg = Svrg::new(0.01, 100, 10);
        let mut point = vec![1.0, 2.0];
        svrg.update_snapshot(&point, &[0.0, 0.0]);
        let outcome = svrg.step(&mut point, &[0.1], &[0.0, 0.0]);
        assert!(outcome.is_err());
    }
}