1use crate::optimizer::Optimizer;
18
/// Common interface for learning-rate schedulers.
///
/// A scheduler adjusts the learning rate of an [`Optimizer`] over the course
/// of training. Call [`step`](LRScheduler::step) once per scheduling period
/// (epoch or batch, depending on the scheduler).
pub trait LRScheduler {
    /// Advances the schedule by one step and writes the new LR into `optimizer`.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Advances the schedule using a validation metric.
    ///
    /// The default implementation ignores `_metric` and forwards to
    /// [`step`](LRScheduler::step); metric-driven schedulers
    /// (e.g. `ReduceLROnPlateau`) override this.
    fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, _metric: f32) {
        self.step(optimizer);
    }

    /// Returns the learning rate most recently computed by this scheduler.
    fn get_last_lr(&self) -> f32;

    /// Returns how many times `step`/`step_with_metric` has been called.
    fn get_step(&self) -> usize;
}
42
/// Decays the learning rate by `gamma` every `step_size` steps:
/// `lr = initial_lr * gamma^(step / step_size)`.
pub struct StepLR {
    initial_lr: f32,     // LR captured from the optimizer at construction
    step_size: usize,    // number of steps between decays
    gamma: f32,          // multiplicative decay factor
    current_step: usize, // how many times step() has been called
    last_lr: f32,        // most recently applied LR
}
57
58impl StepLR {
59 pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
61 let initial_lr = optimizer.get_lr();
62 Self {
63 initial_lr,
64 step_size,
65 gamma,
66 current_step: 0,
67 last_lr: initial_lr,
68 }
69 }
70}
71
72impl LRScheduler for StepLR {
73 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
74 self.current_step += 1;
75 let num_decays = self.current_step / self.step_size;
76 let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
77 optimizer.set_lr(new_lr);
78 self.last_lr = new_lr;
79 }
80
81 fn get_last_lr(&self) -> f32 {
82 self.last_lr
83 }
84
85 fn get_step(&self) -> usize {
86 self.current_step
87 }
88}
89
/// Decays the learning rate by `gamma` each time a milestone step count is
/// reached: `lr = initial_lr * gamma^(#milestones passed)`.
pub struct MultiStepLR {
    initial_lr: f32,        // LR captured from the optimizer at construction
    milestones: Vec<usize>, // sorted step counts at which to decay
    gamma: f32,             // multiplicative decay factor
    current_step: usize,    // how many times step() has been called
    last_lr: f32,           // most recently applied LR
    milestone_idx: usize,   // number of milestones already passed
}
103
104impl MultiStepLR {
105 pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
107 let initial_lr = optimizer.get_lr();
108 milestones.sort_unstable();
109 Self {
110 initial_lr,
111 milestones,
112 gamma,
113 current_step: 0,
114 last_lr: initial_lr,
115 milestone_idx: 0,
116 }
117 }
118}
119
120impl LRScheduler for MultiStepLR {
121 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
122 self.current_step += 1;
123
124 while self.milestone_idx < self.milestones.len()
126 && self.current_step >= self.milestones[self.milestone_idx]
127 {
128 self.milestone_idx += 1;
129 }
130
131 let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
132 optimizer.set_lr(new_lr);
133 self.last_lr = new_lr;
134 }
135
136 fn get_last_lr(&self) -> f32 {
137 self.last_lr
138 }
139
140 fn get_step(&self) -> usize {
141 self.current_step
142 }
143}
144
/// Decays the learning rate by `gamma` on every step:
/// `lr = initial_lr * gamma^step`.
pub struct ExponentialLR {
    initial_lr: f32,     // LR captured from the optimizer at construction
    gamma: f32,          // per-step multiplicative decay factor
    current_step: usize, // how many times step() has been called
    last_lr: f32,        // most recently applied LR
}
158
159impl ExponentialLR {
160 pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
162 let initial_lr = optimizer.get_lr();
163 Self {
164 initial_lr,
165 gamma,
166 current_step: 0,
167 last_lr: initial_lr,
168 }
169 }
170}
171
172impl LRScheduler for ExponentialLR {
173 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
174 self.current_step += 1;
175 let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
176 optimizer.set_lr(new_lr);
177 self.last_lr = new_lr;
178 }
179
180 fn get_last_lr(&self) -> f32 {
181 self.last_lr
182 }
183
184 fn get_step(&self) -> usize {
185 self.current_step
186 }
187}
188
/// Anneals the learning rate from its initial value down to `eta_min`
/// following a half-cosine curve over `t_max` steps.
pub struct CosineAnnealingLR {
    initial_lr: f32,     // LR captured from the optimizer at construction
    t_max: usize,        // number of steps for one half-cosine period
    eta_min: f32,        // minimum LR reached at t_max
    current_step: usize, // how many times step() has been called
    last_lr: f32,        // most recently applied LR
}
203
204impl CosineAnnealingLR {
205 pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
207 Self::with_eta_min(optimizer, t_max, 0.0)
208 }
209
210 pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
212 let initial_lr = optimizer.get_lr();
213 Self {
214 initial_lr,
215 t_max,
216 eta_min,
217 current_step: 0,
218 last_lr: initial_lr,
219 }
220 }
221}
222
223impl LRScheduler for CosineAnnealingLR {
224 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
225 self.current_step += 1;
226
227 let progress = self.current_step as f32 / self.t_max as f32;
228 let new_lr = self.eta_min
229 + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
230 / 2.0;
231
232 optimizer.set_lr(new_lr);
233 self.last_lr = new_lr;
234 }
235
236 fn get_last_lr(&self) -> f32 {
237 self.last_lr
238 }
239
240 fn get_step(&self) -> usize {
241 self.current_step
242 }
243}
244
/// Reduces the learning rate by `factor` when a monitored metric has
/// stopped improving for more than `patience` consecutive steps.
pub struct ReduceLROnPlateau {
    mode: String,            // "min" (lower metric is better) or "max"
    factor: f32,             // multiplicative LR reduction, e.g. 0.1
    patience: usize,         // bad steps tolerated before reducing
    threshold: f32,          // relative improvement required to count
    cooldown: usize,         // steps to skip monitoring after a reduction
    min_lr: f32,             // lower bound for the LR
    best: f32,               // best metric value seen so far
    num_bad_epochs: usize,   // consecutive steps without improvement
    cooldown_counter: usize, // remaining cooldown steps
    current_step: usize,     // how many times step()/step_with_metric() ran
    last_lr: f32,            // most recently applied LR
}
263
264impl ReduceLROnPlateau {
265 pub fn new<O: Optimizer>(optimizer: &O) -> Self {
267 Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
268 }
269
270 pub fn with_options<O: Optimizer>(
272 optimizer: &O,
273 mode: &str,
274 factor: f32,
275 patience: usize,
276 threshold: f32,
277 cooldown: usize,
278 min_lr: f32,
279 ) -> Self {
280 let initial_lr = optimizer.get_lr();
281 let best = if mode == "min" {
282 f32::INFINITY
283 } else {
284 f32::NEG_INFINITY
285 };
286 Self {
287 mode: mode.to_string(),
288 factor,
289 patience,
290 threshold,
291 cooldown,
292 min_lr,
293 best,
294 num_bad_epochs: 0,
295 cooldown_counter: 0,
296 current_step: 0,
297 last_lr: initial_lr,
298 }
299 }
300
301 fn step_metric_impl<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
303 self.current_step += 1;
304
305 if self.cooldown_counter > 0 {
307 self.cooldown_counter -= 1;
308 return;
309 }
310
311 let improved = if self.mode == "min" {
313 metric < self.best * (1.0 - self.threshold)
314 } else {
315 metric > self.best * (1.0 + self.threshold)
316 };
317
318 if improved {
319 self.best = metric;
320 self.num_bad_epochs = 0;
321 } else {
322 self.num_bad_epochs += 1;
323 }
324
325 if self.num_bad_epochs > self.patience {
327 let current_lr = optimizer.get_lr();
328 let new_lr = (current_lr * self.factor).max(self.min_lr);
329 optimizer.set_lr(new_lr);
330 self.last_lr = new_lr;
331 self.cooldown_counter = self.cooldown;
332 self.num_bad_epochs = 0;
333 }
334 }
335}
336
337impl LRScheduler for ReduceLROnPlateau {
338 fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
339 self.current_step += 1;
341 }
342
343 fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
344 self.step_metric_impl(optimizer, metric);
345 }
346
347 fn get_last_lr(&self) -> f32 {
348 self.last_lr
349 }
350
351 fn get_step(&self) -> usize {
352 self.current_step
353 }
354}
355
/// One-cycle schedule: linear warm-up from `max_lr / div_factor` to `max_lr`
/// over the first `pct_start` fraction of `total_steps`, then cosine
/// annealing down to `max_lr / final_div_factor`.
pub struct OneCycleLR {
    max_lr: f32,           // peak learning rate at the end of warm-up
    total_steps: usize,    // total length of the cycle in steps
    pct_start: f32,        // fraction of the cycle spent warming up
    div_factor: f32,       // initial_lr = max_lr / div_factor
    final_div_factor: f32, // min_lr = max_lr / final_div_factor
    current_step: usize,   // how many times step() has been called
    last_lr: f32,          // most recently applied LR
}
372
373impl OneCycleLR {
374 pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
376 Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
377 }
378
379 pub fn with_options<O: Optimizer>(
381 _optimizer: &O,
382 max_lr: f32,
383 total_steps: usize,
384 pct_start: f32,
385 div_factor: f32,
386 final_div_factor: f32,
387 ) -> Self {
388 let initial_lr = max_lr / div_factor;
389 Self {
390 max_lr,
391 total_steps,
392 pct_start,
393 div_factor,
394 final_div_factor,
395 current_step: 0,
396 last_lr: initial_lr,
397 }
398 }
399}
400
401impl LRScheduler for OneCycleLR {
402 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
403 self.current_step += 1;
404
405 let step_ratio = self.current_step as f32 / self.total_steps as f32;
406 let initial_lr = self.max_lr / self.div_factor;
407 let min_lr = self.max_lr / self.final_div_factor;
408
409 let new_lr = if step_ratio <= self.pct_start {
410 let phase_ratio = step_ratio / self.pct_start;
412 initial_lr + (self.max_lr - initial_lr) * phase_ratio
413 } else {
414 let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
416 min_lr
417 + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
418 };
419
420 optimizer.set_lr(new_lr);
421 self.last_lr = new_lr;
422 }
423
424 fn get_last_lr(&self) -> f32 {
425 self.last_lr
426 }
427
428 fn get_step(&self) -> usize {
429 self.current_step
430 }
431}
432
/// Linearly warms the learning rate up from 0 to the optimizer's initial LR
/// over `warmup_steps` steps, then holds it constant.
pub struct WarmupLR {
    initial_lr: f32,      // target LR reached at the end of warm-up
    warmup_steps: usize,  // length of the linear ramp in steps
    current_step: usize,  // how many times step() has been called
    last_lr: f32,         // most recently applied LR (0.0 before any step)
}
446
447impl WarmupLR {
448 pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
450 let initial_lr = optimizer.get_lr();
451 Self {
452 initial_lr,
453 warmup_steps,
454 current_step: 0,
455 last_lr: 0.0,
456 }
457 }
458}
459
460impl LRScheduler for WarmupLR {
461 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
462 self.current_step += 1;
463
464 let new_lr = if self.current_step <= self.warmup_steps {
465 self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
466 } else {
467 self.initial_lr
468 };
469
470 optimizer.set_lr(new_lr);
471 self.last_lr = new_lr;
472 }
473
474 fn get_last_lr(&self) -> f32 {
475 self.last_lr
476 }
477
478 fn get_step(&self) -> usize {
479 self.current_step
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    /// Builds a single-parameter SGD optimizer with lr = 0.1 for the tests.
    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    // StepLR: LR drops by 10x every 10 steps.
    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        // Construction must not touch the optimizer's LR.
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // MultiStepLR: LR drops at each milestone step (5 and 15).
    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // ExponentialLR: LR is multiplied by gamma on every step.
    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    // CosineAnnealingLR: halfway through the cycle the LR is near half the
    // initial value; at the end it approaches 0.
    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    // WarmupLR: LR ramps linearly to the target over 10 steps and then stays.
    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        // After step 1 of 10 the LR is 1/10 of the target.
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        // Past warm-up the LR must hold the target value.
        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    // OneCycleLR: starts at max_lr / div_factor and nears max_lr at 30%.
    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        // initial lr = 0.1 / 25 = 0.004 (reported by the scheduler only).
        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        // End of the 30% warm-up phase: LR should be near max_lr.
        assert!(optimizer.get_lr() > 0.08);
    }

    // ReduceLROnPlateau (min mode): LR drops after patience is exceeded.
    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Improving metrics: no reduction.
        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        // Three non-improving steps exceed patience = 2.
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        assert!(optimizer.get_lr() < initial_lr);
    }

    // ReduceLROnPlateau (max mode): same behavior with the comparison flipped.
    #[test]
    fn test_reduce_lr_on_plateau_max_mode() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "max", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        // Rising metrics count as improvement in max mode.
        scheduler.step_with_metric(&mut optimizer, 0.8);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);

        assert!(
            optimizer.get_lr() < initial_lr,
            "LR should reduce on plateau in max mode"
        );
    }

    // ReduceLROnPlateau: repeated reductions never push the LR below min_lr.
    #[test]
    fn test_reduce_lr_on_plateau_min_lr_floor() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler =
            ReduceLROnPlateau::with_options(&optimizer, "min", 0.1, 0, 0.0, 0, 0.001);

        for _ in 0..50 {
            scheduler.step_with_metric(&mut optimizer, 999.0);
        }

        assert!(
            optimizer.get_lr() >= 0.001,
            "LR should not go below min_lr, got {}",
            optimizer.get_lr()
        );
    }

    // ReduceLROnPlateau: after a reduction, the cooldown window suppresses
    // further reductions.
    #[test]
    fn test_reduce_lr_cooldown() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 0, 0.0, 3, 0.0);

        let initial_lr = optimizer.get_lr();

        // First bad metric after the baseline triggers a reduction
        // (patience = 0).
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        let lr_after_first_reduce = optimizer.get_lr();
        assert!(lr_after_first_reduce < initial_lr);

        // Three further bad metrics fall inside cooldown = 3: no change.
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        assert!(
            (optimizer.get_lr() - lr_after_first_reduce).abs() < 1e-8,
            "LR should not change during cooldown"
        );
    }

    // OneCycleLR: over a full cycle the LR peaks near max_lr around the end
    // of warm-up and finishes very small.
    #[test]
    fn test_one_cycle_lr_full_cycle() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        let max_lr = lrs.iter().cloned().fold(f32::MIN, f32::max);
        let final_lr = *lrs.last().unwrap();

        assert!(
            max_lr > 0.08,
            "Peak should be near max_lr=0.1, got {}",
            max_lr
        );
        assert!(
            final_lr < 0.001,
            "Final LR should be very small, got {}",
            final_lr
        );

        // With pct_start = 0.3 the peak lands around step 30.
        let peak_idx = lrs
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert!(
            peak_idx >= 25 && peak_idx <= 35,
            "Peak should be around step 30, was at step {}",
            peak_idx
        );
    }

    // OneCycleLR: LR is non-decreasing through warm-up and non-increasing
    // through annealing (small tolerance for float noise).
    #[test]
    fn test_one_cycle_lr_monotonic_phases() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        // Warm-up phase (steps 1..29, avoiding the phase boundary).
        for i in 1..29 {
            assert!(
                lrs[i] >= lrs[i - 1] - 1e-6,
                "Warmup should increase: step {} lr={} < step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }

        // Annealing phase (steps 32..99, past the phase boundary).
        for i in 32..99 {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Annealing should decrease: step {} lr={} > step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }
    }

    // CosineAnnealingLR with eta_min: LR bottoms out at eta_min, not 0.
    #[test]
    fn test_cosine_annealing_with_eta_min() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::with_eta_min(&optimizer, 100, 0.001);

        for _ in 0..100 {
            scheduler.step(&mut optimizer);
        }

        assert!(
            (optimizer.get_lr() - 0.001).abs() < 0.002,
            "Should reach eta_min at end, got {}",
            optimizer.get_lr()
        );
    }

    // CosineAnnealingLR: LR decreases monotonically over one period and
    // never goes negative.
    #[test]
    fn test_cosine_annealing_monotonic_decrease() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        for i in 1..lrs.len() {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Cosine should decrease: step {} lr={} > step {} lr={}",
                i + 1,
                lrs[i],
                i,
                lrs[i - 1]
            );
        }

        assert!(
            lrs.iter().all(|lr| *lr >= 0.0),
            "LRs should be non-negative"
        );
    }

    // WarmupLR: after warm-up completes, the LR stays pinned at the target
    // no matter how many more steps are taken.
    #[test]
    fn test_warmup_lr_stays_constant_after() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 5);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        let target = optimizer.get_lr();

        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            assert!(
                (optimizer.get_lr() - target).abs() < 1e-8,
                "LR should stay at {} after warmup, got {}",
                target,
                optimizer.get_lr()
            );
        }
    }
}