1use super::optim::Optimizer;
31
/// A learning-rate schedule driven by calling [`LRScheduler::step`] once per
/// epoch (or per optimizer step, depending on how the caller drives it).
///
/// NOTE(review): `step` is generic over the optimizer type, so this trait is
/// not object-safe — schedulers cannot be stored as `dyn LRScheduler`.
pub trait LRScheduler {
    /// Advances the schedule by one tick and writes the newly computed
    /// learning rate into `optimizer` via `set_lr`.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Returns the learning rate most recently computed by this schedule.
    fn get_lr(&self) -> f32;

    /// Returns how many times this schedule has been stepped.
    fn last_epoch(&self) -> usize;
}
43
/// Multiplies the learning rate by `gamma` once every `step_size` epochs.
#[derive(Debug, Clone)]
pub struct StepLR {
    /// Base learning rate; `0.0` is a sentinel meaning "read it from the
    /// optimizer on the first call to `step`".
    initial_lr: f32,
    /// Number of epochs between consecutive decays.
    step_size: usize,
    /// Multiplicative decay factor applied every `step_size` epochs.
    gamma: f32,
    /// Learning rate after the most recent step.
    current_lr: f32,
    /// Number of steps taken so far.
    last_epoch: usize,
}
59
60impl StepLR {
61 #[must_use]
68 pub fn new(step_size: usize, gamma: f32) -> Self {
69 Self {
70 initial_lr: 0.0, step_size,
72 gamma,
73 current_lr: 0.0,
74 last_epoch: 0,
75 }
76 }
77
78 #[must_use]
80 pub fn with_lr(initial_lr: f32, step_size: usize, gamma: f32) -> Self {
81 Self {
82 initial_lr,
83 step_size,
84 gamma,
85 current_lr: initial_lr,
86 last_epoch: 0,
87 }
88 }
89}
90
91impl LRScheduler for StepLR {
92 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
93 if self.last_epoch == 0 && self.initial_lr == 0.0 {
95 self.initial_lr = optimizer.lr();
96 self.current_lr = self.initial_lr;
97 }
98
99 self.last_epoch += 1;
100
101 if self.last_epoch % self.step_size == 0 {
103 self.current_lr *= self.gamma;
104 optimizer.set_lr(self.current_lr);
105 }
106 }
107
108 fn get_lr(&self) -> f32 {
109 self.current_lr
110 }
111
112 fn last_epoch(&self) -> usize {
113 self.last_epoch
114 }
115}
116
/// Multiplies the learning rate by `gamma` on every step: `lr(t) = lr(0) * gamma^t`.
#[derive(Debug, Clone)]
pub struct ExponentialLR {
    /// Base learning rate; `0.0` means "read it from the optimizer on the
    /// first call to `step`".
    initial_lr: f32,
    /// Multiplicative decay factor applied once per step.
    gamma: f32,
    /// Learning rate after the most recent step.
    current_lr: f32,
    /// Number of steps taken so far.
    last_epoch: usize,
}
131
132impl ExponentialLR {
133 #[must_use]
139 pub fn new(gamma: f32) -> Self {
140 Self {
141 initial_lr: 0.0,
142 gamma,
143 current_lr: 0.0,
144 last_epoch: 0,
145 }
146 }
147
148 #[must_use]
149 pub fn with_lr(initial_lr: f32, gamma: f32) -> Self {
150 Self {
151 initial_lr,
152 gamma,
153 current_lr: initial_lr,
154 last_epoch: 0,
155 }
156 }
157}
158
159impl LRScheduler for ExponentialLR {
160 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
161 if self.last_epoch == 0 && self.initial_lr == 0.0 {
162 self.initial_lr = optimizer.lr();
163 self.current_lr = self.initial_lr;
164 }
165
166 self.last_epoch += 1;
167 self.current_lr *= self.gamma;
168 optimizer.set_lr(self.current_lr);
169 }
170
171 fn get_lr(&self) -> f32 {
172 self.current_lr
173 }
174
175 fn last_epoch(&self) -> usize {
176 self.last_epoch
177 }
178}
179
/// Anneals the learning rate from its initial value down to `min_lr`
/// following half a cosine wave over `t_max` steps.
#[derive(Debug, Clone)]
pub struct CosineAnnealingLR {
    /// Peak learning rate; `0.0` means "read it from the optimizer on the
    /// first call to `step`".
    initial_lr: f32,
    /// Floor the learning rate anneals toward.
    min_lr: f32,
    /// Length of the annealing window in steps.
    t_max: usize,
    /// Learning rate after the most recent step.
    current_lr: f32,
    /// Number of steps taken so far.
    last_epoch: usize,
}
195
196impl CosineAnnealingLR {
197 #[must_use]
204 pub fn new(t_max: usize) -> Self {
205 Self {
206 initial_lr: 0.0,
207 min_lr: 0.0,
208 t_max,
209 current_lr: 0.0,
210 last_epoch: 0,
211 }
212 }
213
214 #[must_use]
215 pub fn with_min_lr(t_max: usize, min_lr: f32) -> Self {
216 Self {
217 initial_lr: 0.0,
218 min_lr,
219 t_max,
220 current_lr: 0.0,
221 last_epoch: 0,
222 }
223 }
224
225 #[must_use]
226 pub fn with_lr(initial_lr: f32, t_max: usize, min_lr: f32) -> Self {
227 Self {
228 initial_lr,
229 min_lr,
230 t_max,
231 current_lr: initial_lr,
232 last_epoch: 0,
233 }
234 }
235}
236
237impl LRScheduler for CosineAnnealingLR {
238 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
239 if self.last_epoch == 0 && self.initial_lr == 0.0 {
240 self.initial_lr = optimizer.lr();
241 self.current_lr = self.initial_lr;
242 }
243
244 self.last_epoch += 1;
245
246 let progress = self.last_epoch as f32 / self.t_max as f32;
248 let cosine = (std::f32::consts::PI * progress).cos();
249 self.current_lr = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + cosine);
250
251 optimizer.set_lr(self.current_lr);
252 }
253
254 fn get_lr(&self) -> f32 {
255 self.current_lr
256 }
257
258 fn last_epoch(&self) -> usize {
259 self.last_epoch
260 }
261}
262
/// Linearly ramps the learning rate up to its target over `warmup_steps`
/// steps, then holds it constant.
#[derive(Debug, Clone)]
pub struct LinearWarmup {
    /// Target learning rate after warmup; `0.0` means "read it from the
    /// optimizer on the first call to `step`".
    initial_lr: f32,
    /// Number of steps over which to ramp up.
    warmup_steps: usize,
    /// Learning rate after the most recent step; stays `0.0` until the
    /// scheduler has been stepped at least once.
    current_lr: f32,
    /// Number of steps taken so far.
    last_epoch: usize,
}
280
281impl LinearWarmup {
282 #[must_use]
288 pub fn new(warmup_steps: usize) -> Self {
289 Self {
290 initial_lr: 0.0,
291 warmup_steps,
292 current_lr: 0.0,
293 last_epoch: 0,
294 }
295 }
296
297 #[must_use]
298 pub fn with_lr(initial_lr: f32, warmup_steps: usize) -> Self {
299 Self {
300 initial_lr,
301 warmup_steps,
302 current_lr: 0.0,
303 last_epoch: 0,
304 }
305 }
306}
307
308impl LRScheduler for LinearWarmup {
309 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
310 if self.last_epoch == 0 && self.initial_lr == 0.0 {
311 self.initial_lr = optimizer.lr();
312 }
313
314 self.last_epoch += 1;
315
316 if self.last_epoch <= self.warmup_steps {
317 self.current_lr = self.initial_lr * (self.last_epoch as f32 / self.warmup_steps as f32);
319 } else {
320 self.current_lr = self.initial_lr;
321 }
322
323 optimizer.set_lr(self.current_lr);
324 }
325
326 fn get_lr(&self) -> f32 {
327 self.current_lr
328 }
329
330 fn last_epoch(&self) -> usize {
331 self.last_epoch
332 }
333}
334
/// Combines a linear warmup over `warmup_steps` with cosine annealing down
/// to `min_lr` over the remaining `total_steps - warmup_steps` steps.
#[derive(Debug, Clone)]
pub struct WarmupCosineScheduler {
    /// Peak learning rate reached at the end of warmup; `0.0` means "read it
    /// from the optimizer on the first call to `step`".
    initial_lr: f32,
    /// Floor the learning rate decays toward after warmup.
    min_lr: f32,
    /// Number of linear-ramp steps at the start of the schedule.
    warmup_steps: usize,
    /// Total length of the schedule (warmup + decay) in steps.
    total_steps: usize,
    /// Learning rate after the most recent step.
    current_lr: f32,
    /// Number of steps taken so far.
    last_epoch: usize,
}
355
356impl WarmupCosineScheduler {
357 #[must_use]
364 pub fn new(warmup_steps: usize, total_steps: usize) -> Self {
365 Self {
366 initial_lr: 0.0,
367 min_lr: 0.0,
368 warmup_steps,
369 total_steps,
370 current_lr: 0.0,
371 last_epoch: 0,
372 }
373 }
374
375 #[must_use]
376 pub fn with_min_lr(warmup_steps: usize, total_steps: usize, min_lr: f32) -> Self {
377 Self {
378 initial_lr: 0.0,
379 min_lr,
380 warmup_steps,
381 total_steps,
382 current_lr: 0.0,
383 last_epoch: 0,
384 }
385 }
386}
387
388impl LRScheduler for WarmupCosineScheduler {
389 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
390 if self.last_epoch == 0 && self.initial_lr == 0.0 {
391 self.initial_lr = optimizer.lr();
392 }
393
394 self.last_epoch += 1;
395
396 if self.last_epoch <= self.warmup_steps {
397 self.current_lr = self.initial_lr * (self.last_epoch as f32 / self.warmup_steps as f32);
399 } else {
400 let decay_steps = self.total_steps - self.warmup_steps;
402 let decay_epoch = self.last_epoch - self.warmup_steps;
403 let progress = decay_epoch as f32 / decay_steps as f32;
404 let cosine = (std::f32::consts::PI * progress).cos();
405 self.current_lr = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + cosine);
406 }
407
408 optimizer.set_lr(self.current_lr);
409 }
410
411 fn get_lr(&self) -> f32 {
412 self.current_lr
413 }
414
415 fn last_epoch(&self) -> usize {
416 self.last_epoch
417 }
418}
419
/// Reduces the learning rate by `factor` once a monitored metric has stopped
/// improving for `patience` consecutive observations.
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateau {
    /// Multiplicative reduction applied when a plateau is detected.
    factor: f32,
    /// Number of consecutive non-improving observations tolerated before reducing.
    patience: usize,
    /// Lower bound the learning rate is clamped to when reduced.
    min_lr: f32,
    /// Minimum metric change that counts as an improvement.
    threshold: f32,
    /// Current learning rate; `0.0` until the first `step_with_metric` reads
    /// it from the optimizer.
    current_lr: f32,
    /// Best metric seen so far (seeded to +/- infinity so the first
    /// observation always improves).
    best_metric: f32,
    /// Consecutive observations without improvement.
    num_bad_epochs: usize,
    /// Number of steps taken so far.
    last_epoch: usize,
    /// Whether lower or higher metric values are better.
    mode: PlateauMode,
}
435
/// Direction in which the monitored metric is considered to improve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlateauMode {
    /// Lower values are better (e.g. a loss).
    Min,
    /// Higher values are better (e.g. an accuracy).
    Max,
}
444
445impl ReduceLROnPlateau {
446 #[must_use]
454 pub fn new(mode: PlateauMode, factor: f32, patience: usize) -> Self {
455 let best_metric = match mode {
456 PlateauMode::Min => f32::INFINITY,
457 PlateauMode::Max => f32::NEG_INFINITY,
458 };
459
460 Self {
461 factor,
462 patience,
463 min_lr: 1e-8,
464 threshold: 1e-4,
465 current_lr: 0.0,
466 best_metric,
467 num_bad_epochs: 0,
468 last_epoch: 0,
469 mode,
470 }
471 }
472
473 #[must_use]
475 pub fn min_lr(mut self, min_lr: f32) -> Self {
476 self.min_lr = min_lr;
477 self
478 }
479
480 #[must_use]
482 pub fn threshold(mut self, threshold: f32) -> Self {
483 self.threshold = threshold;
484 self
485 }
486
487 pub fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
489 if self.last_epoch == 0 && self.current_lr == 0.0 {
490 self.current_lr = optimizer.lr();
491 }
492
493 self.last_epoch += 1;
494
495 let is_better = match self.mode {
497 PlateauMode::Min => metric < self.best_metric - self.threshold,
498 PlateauMode::Max => metric > self.best_metric + self.threshold,
499 };
500
501 if is_better {
502 self.best_metric = metric;
503 self.num_bad_epochs = 0;
504 } else {
505 self.num_bad_epochs += 1;
506 }
507
508 if self.num_bad_epochs >= self.patience {
510 let new_lr = (self.current_lr * self.factor).max(self.min_lr);
511 if new_lr < self.current_lr {
512 self.current_lr = new_lr;
513 optimizer.set_lr(self.current_lr);
514 self.num_bad_epochs = 0;
515 }
516 }
517 }
518}
519
impl LRScheduler for ReduceLROnPlateau {
    /// Metric-free step: only advances the epoch counter and leaves the
    /// learning rate untouched. Use `step_with_metric` to drive reductions.
    fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
        self.last_epoch += 1;
    }

    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    fn last_epoch(&self) -> usize {
        self.last_epoch
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal optimizer stub that only stores a learning rate.
    struct MockOptimizer {
        lr: f32,
    }

    impl MockOptimizer {
        fn new(lr: f32) -> Self {
            Self { lr }
        }
    }

    impl Optimizer for MockOptimizer {
        fn step(&mut self) {}
        fn zero_grad(&mut self) {}
        fn lr(&self) -> f32 {
            self.lr
        }
        fn set_lr(&mut self, lr: f32) {
            self.lr = lr;
        }
    }

    #[test]
    fn test_step_lr() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = StepLR::new(3, 0.1);

        // Unchanged until `step_size` epochs have elapsed.
        sched.step(&mut opt);
        assert!((opt.lr() - 0.1).abs() < 1e-6);
        sched.step(&mut opt);
        assert!((opt.lr() - 0.1).abs() < 1e-6);

        // Third epoch: decayed by gamma.
        sched.step(&mut opt);
        assert!((opt.lr() - 0.01).abs() < 1e-6);

        // Sixth epoch: decayed once more.
        sched.step(&mut opt);
        sched.step(&mut opt);
        sched.step(&mut opt);
        assert!((opt.lr() - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = ExponentialLR::new(0.9);

        sched.step(&mut opt);
        assert!((opt.lr() - 0.09).abs() < 1e-6);

        sched.step(&mut opt);
        assert!((opt.lr() - 0.081).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_annealing() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = CosineAnnealingLR::new(10);

        // One step in: slightly below the initial rate.
        sched.step(&mut opt);
        assert!(opt.lr() < 0.1);
        assert!(opt.lr() > 0.09);

        // Halfway (epoch 5): cos(pi/2) = 0, so roughly half the initial rate.
        for _ in 0..4 {
            sched.step(&mut opt);
        }
        assert!((opt.lr() - 0.05).abs() < 0.01);

        // End of the window (epoch 10): near zero.
        for _ in 0..5 {
            sched.step(&mut opt);
        }
        assert!(opt.lr() < 0.01);
    }

    #[test]
    fn test_linear_warmup() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = LinearWarmup::new(5);

        // Ramp: 1/5 then 2/5 of the target.
        sched.step(&mut opt);
        assert!((opt.lr() - 0.02).abs() < 1e-6);
        sched.step(&mut opt);
        assert!((opt.lr() - 0.04).abs() < 1e-6);

        // Finish the ramp.
        for _ in 0..3 {
            sched.step(&mut opt);
        }
        assert!((opt.lr() - 0.1).abs() < 1e-6);

        // Held constant after warmup.
        sched.step(&mut opt);
        assert!((opt.lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_warmup_cosine() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = WarmupCosineScheduler::new(5, 20);

        // First warmup step: 1/5 of the peak.
        sched.step(&mut opt);
        assert!((opt.lr() - 0.02).abs() < 1e-6);

        // Peak reached at the end of warmup.
        for _ in 0..4 {
            sched.step(&mut opt);
        }
        assert!((opt.lr() - 0.1).abs() < 1e-6);

        // First decay step drops below the peak.
        sched.step(&mut opt);
        assert!(opt.lr() < 0.1);
    }

    #[test]
    fn test_reduce_on_plateau() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 3);

        sched.step_with_metric(&mut opt, 1.0);
        assert!((opt.lr() - 0.1).abs() < 1e-6);

        sched.step_with_metric(&mut opt, 0.9);
        assert!((opt.lr() - 0.1).abs() < 1e-6);

        // Three non-improving observations exhaust the patience.
        sched.step_with_metric(&mut opt, 0.9);
        sched.step_with_metric(&mut opt, 0.9);
        sched.step_with_metric(&mut opt, 0.9);
        assert!((opt.lr() - 0.01).abs() < 1e-6);
    }

    #[test]
    fn test_reduce_on_plateau_max_mode() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = ReduceLROnPlateau::new(PlateauMode::Max, 0.5, 2);

        // Rising metric: no reduction.
        sched.step_with_metric(&mut opt, 0.5);
        sched.step_with_metric(&mut opt, 0.6);
        assert!((opt.lr() - 0.1).abs() < 1e-6);

        // Two stagnant observations trigger a reduction.
        sched.step_with_metric(&mut opt, 0.6);
        sched.step_with_metric(&mut opt, 0.6);
        assert!((opt.lr() - 0.05).abs() < 1e-6);
    }

    #[test]
    fn test_step_lr_with_lr() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = StepLR::with_lr(0.2, 2, 0.5);

        assert_eq!(sched.get_lr(), 0.2);
        assert_eq!(sched.last_epoch(), 0);

        sched.step(&mut opt);
        assert_eq!(sched.last_epoch(), 1);
        sched.step(&mut opt);
        assert!((sched.get_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_lr_with_lr() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = ExponentialLR::with_lr(0.5, 0.8);

        assert_eq!(sched.get_lr(), 0.5);
        assert_eq!(sched.last_epoch(), 0);

        sched.step(&mut opt);
        assert!((sched.get_lr() - 0.4).abs() < 1e-6);
        assert_eq!(sched.last_epoch(), 1);
    }

    #[test]
    fn test_cosine_annealing_with_min_lr() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = CosineAnnealingLR::with_min_lr(10, 0.01);

        sched.step(&mut opt);
        assert!(sched.get_lr() > 0.01);
        assert!(sched.get_lr() < 0.1);
    }

    #[test]
    fn test_cosine_annealing_with_lr() {
        let mut opt = MockOptimizer::new(0.05);
        let mut sched = CosineAnnealingLR::with_lr(0.2, 10, 0.02);

        // Explicit peak takes precedence over the optimizer's rate.
        assert_eq!(sched.get_lr(), 0.2);
        sched.step(&mut opt);
        assert!(sched.get_lr() < 0.2);
        assert!(sched.get_lr() > 0.02);
    }

    #[test]
    fn test_linear_warmup_with_lr() {
        let mut opt = MockOptimizer::new(0.05);
        let mut sched = LinearWarmup::with_lr(0.2, 4);

        // Rate is 0.0 until the first step is taken.
        assert_eq!(sched.get_lr(), 0.0);
        sched.step(&mut opt);
        assert!((sched.get_lr() - 0.05).abs() < 1e-6);
        assert_eq!(sched.last_epoch(), 1);
    }

    #[test]
    fn test_warmup_cosine_with_min_lr() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = WarmupCosineScheduler::with_min_lr(5, 20, 0.001);

        // Peak at the end of warmup.
        for _ in 0..5 {
            sched.step(&mut opt);
        }
        assert!((sched.get_lr() - 0.1).abs() < 1e-6);

        // Decaying, but still above the floor.
        sched.step(&mut opt);
        assert!(sched.get_lr() < 0.1);
        assert!(sched.get_lr() > 0.001);
        assert_eq!(sched.last_epoch(), 6);
    }

    #[test]
    fn test_reduce_on_plateau_min_lr_builder() {
        let sched = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 3).min_lr(0.0001);
        assert!((sched.min_lr - 0.0001).abs() < 1e-8);
    }

    #[test]
    fn test_reduce_on_plateau_threshold_builder() {
        let sched = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 3).threshold(0.001);
        assert!((sched.threshold - 0.001).abs() < 1e-8);
    }

    #[test]
    fn test_reduce_on_plateau_step_without_metric() {
        let mut opt = MockOptimizer::new(0.1);
        let mut sched = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 2);

        // The metric-free `step` only advances the counter.
        sched.step(&mut opt);
        assert_eq!(sched.last_epoch(), 1);
        sched.step(&mut opt);
        assert_eq!(sched.last_epoch(), 2);
    }

    #[test]
    fn test_reduce_on_plateau_min_lr_clamp() {
        let mut opt = MockOptimizer::new(0.001);
        let mut sched = ReduceLROnPlateau::new(PlateauMode::Min, 0.1, 1).min_lr(0.0005);

        sched.step_with_metric(&mut opt, 1.0);
        sched.step_with_metric(&mut opt, 1.0);
        // Reduction is clamped to the configured floor.
        assert!(sched.get_lr() >= 0.0005);
    }

    #[test]
    fn test_step_lr_getters() {
        let sched = StepLR::with_lr(0.1, 5, 0.9);
        assert_eq!(sched.get_lr(), 0.1);
        assert_eq!(sched.last_epoch(), 0);
    }

    #[test]
    fn test_exponential_lr_getters() {
        let sched = ExponentialLR::with_lr(0.1, 0.9);
        assert_eq!(sched.get_lr(), 0.1);
        assert_eq!(sched.last_epoch(), 0);
    }

    #[test]
    fn test_cosine_annealing_getters() {
        let sched = CosineAnnealingLR::with_lr(0.1, 10, 0.01);
        assert_eq!(sched.get_lr(), 0.1);
        assert_eq!(sched.last_epoch(), 0);
    }

    #[test]
    fn test_linear_warmup_getters() {
        let sched = LinearWarmup::with_lr(0.1, 5);
        assert_eq!(sched.get_lr(), 0.0);
        assert_eq!(sched.last_epoch(), 0);
    }

    #[test]
    fn test_warmup_cosine_getters() {
        let sched = WarmupCosineScheduler::with_min_lr(5, 20, 0.01);
        assert_eq!(sched.get_lr(), 0.0);
        assert_eq!(sched.last_epoch(), 0);
    }

    #[test]
    fn test_reduce_on_plateau_getters() {
        let sched = ReduceLROnPlateau::new(PlateauMode::Max, 0.5, 3);
        assert_eq!(sched.get_lr(), 0.0);
        assert_eq!(sched.last_epoch(), 0);
    }

    #[test]
    fn test_plateau_mode_eq() {
        assert_eq!(PlateauMode::Min, PlateauMode::Min);
        assert_eq!(PlateauMode::Max, PlateauMode::Max);
        assert_ne!(PlateauMode::Min, PlateauMode::Max);
    }

    #[test]
    fn test_scheduler_clone() {
        let sched = StepLR::with_lr(0.1, 5, 0.9);
        let cloned = sched.clone();
        assert_eq!(sched.get_lr(), cloned.get_lr());
    }

    #[test]
    fn test_scheduler_debug() {
        let sched = StepLR::with_lr(0.1, 5, 0.9);
        let debug = format!("{sched:?}");
        assert!(debug.contains("StepLR"));
    }
}