1use crate::optimizer::Optimizer;
26
/// Common interface for learning-rate schedulers.
///
/// A scheduler owns the schedule state (step counter, last LR) and pushes the
/// newly computed learning rate into the optimizer on every step.
pub trait LRScheduler {
    /// Advances the schedule by one step and writes the new LR into `optimizer`.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);

    /// Metric-aware step. The default implementation ignores the metric and
    /// forwards to [`LRScheduler::step`]; metric-driven schedulers (e.g.
    /// `ReduceLROnPlateau`) override this.
    fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, _metric: f32) {
        self.step(optimizer);
    }

    /// Returns the last learning rate this scheduler computed (before the first
    /// call to `step`, whatever value the scheduler was initialized with).
    fn get_last_lr(&self) -> f32;

    /// Returns the number of steps taken so far.
    fn get_step(&self) -> usize;
}
50
/// Decays the learning rate by `gamma` once every `step_size` steps.
pub struct StepLR {
    // LR read from the optimizer at construction; decays are computed from
    // this base value, not from the optimizer's current LR.
    initial_lr: f32,
    // Number of steps between successive decays.
    step_size: usize,
    // Multiplicative decay factor applied at each decay point.
    gamma: f32,
    // How many times `step` has been called.
    current_step: usize,
    // Most recent LR written to the optimizer.
    last_lr: f32,
}
65
66impl StepLR {
67 pub fn new<O: Optimizer>(optimizer: &O, step_size: usize, gamma: f32) -> Self {
69 let initial_lr = optimizer.get_lr();
70 Self {
71 initial_lr,
72 step_size,
73 gamma,
74 current_step: 0,
75 last_lr: initial_lr,
76 }
77 }
78}
79
80impl LRScheduler for StepLR {
81 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
82 self.current_step += 1;
83 let num_decays = self.current_step / self.step_size;
84 let new_lr = self.initial_lr * self.gamma.powi(num_decays as i32);
85 optimizer.set_lr(new_lr);
86 self.last_lr = new_lr;
87 }
88
89 fn get_last_lr(&self) -> f32 {
90 self.last_lr
91 }
92
93 fn get_step(&self) -> usize {
94 self.current_step
95 }
96}
97
/// Decays the learning rate by `gamma` each time the step count reaches one
/// of the given milestones.
pub struct MultiStepLR {
    // LR read from the optimizer at construction; the decayed LR is always
    // computed relative to this base.
    initial_lr: f32,
    // Step counts at which a decay occurs; sorted ascending in `new`.
    milestones: Vec<usize>,
    // Multiplicative decay factor per milestone passed.
    gamma: f32,
    // How many times `step` has been called.
    current_step: usize,
    // Most recent LR written to the optimizer.
    last_lr: f32,
    // Number of milestones already passed (index into `milestones`).
    milestone_idx: usize,
}
111
112impl MultiStepLR {
113 pub fn new<O: Optimizer>(optimizer: &O, mut milestones: Vec<usize>, gamma: f32) -> Self {
115 let initial_lr = optimizer.get_lr();
116 milestones.sort_unstable();
117 Self {
118 initial_lr,
119 milestones,
120 gamma,
121 current_step: 0,
122 last_lr: initial_lr,
123 milestone_idx: 0,
124 }
125 }
126}
127
128impl LRScheduler for MultiStepLR {
129 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
130 self.current_step += 1;
131
132 while self.milestone_idx < self.milestones.len()
134 && self.current_step >= self.milestones[self.milestone_idx]
135 {
136 self.milestone_idx += 1;
137 }
138
139 let new_lr = self.initial_lr * self.gamma.powi(self.milestone_idx as i32);
140 optimizer.set_lr(new_lr);
141 self.last_lr = new_lr;
142 }
143
144 fn get_last_lr(&self) -> f32 {
145 self.last_lr
146 }
147
148 fn get_step(&self) -> usize {
149 self.current_step
150 }
151}
152
/// Decays the learning rate by `gamma` on every step: lr(t) = lr0 * gamma^t.
pub struct ExponentialLR {
    // LR read from the optimizer at construction; the base of the decay.
    initial_lr: f32,
    // Per-step multiplicative decay factor.
    gamma: f32,
    // How many times `step` has been called.
    current_step: usize,
    // Most recent LR written to the optimizer.
    last_lr: f32,
}
166
167impl ExponentialLR {
168 pub fn new<O: Optimizer>(optimizer: &O, gamma: f32) -> Self {
170 let initial_lr = optimizer.get_lr();
171 Self {
172 initial_lr,
173 gamma,
174 current_step: 0,
175 last_lr: initial_lr,
176 }
177 }
178}
179
180impl LRScheduler for ExponentialLR {
181 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
182 self.current_step += 1;
183 let new_lr = self.initial_lr * self.gamma.powi(self.current_step as i32);
184 optimizer.set_lr(new_lr);
185 self.last_lr = new_lr;
186 }
187
188 fn get_last_lr(&self) -> f32 {
189 self.last_lr
190 }
191
192 fn get_step(&self) -> usize {
193 self.current_step
194 }
195}
196
/// Anneals the learning rate from its initial value down to `eta_min`
/// following a half cosine over `t_max` steps.
pub struct CosineAnnealingLR {
    // LR read from the optimizer at construction; the top of the cosine.
    initial_lr: f32,
    // Number of steps over which the LR descends to `eta_min`.
    t_max: usize,
    // Floor value the LR approaches at step `t_max`.
    eta_min: f32,
    // How many times `step` has been called.
    current_step: usize,
    // Most recent LR written to the optimizer.
    last_lr: f32,
}
211
212impl CosineAnnealingLR {
213 pub fn new<O: Optimizer>(optimizer: &O, t_max: usize) -> Self {
215 Self::with_eta_min(optimizer, t_max, 0.0)
216 }
217
218 pub fn with_eta_min<O: Optimizer>(optimizer: &O, t_max: usize, eta_min: f32) -> Self {
220 let initial_lr = optimizer.get_lr();
221 Self {
222 initial_lr,
223 t_max,
224 eta_min,
225 current_step: 0,
226 last_lr: initial_lr,
227 }
228 }
229}
230
231impl LRScheduler for CosineAnnealingLR {
232 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
233 self.current_step += 1;
234
235 let progress = self.current_step as f32 / self.t_max as f32;
236 let new_lr = self.eta_min
237 + (self.initial_lr - self.eta_min) * (1.0 + (std::f32::consts::PI * progress).cos())
238 / 2.0;
239
240 optimizer.set_lr(new_lr);
241 self.last_lr = new_lr;
242 }
243
244 fn get_last_lr(&self) -> f32 {
245 self.last_lr
246 }
247
248 fn get_step(&self) -> usize {
249 self.current_step
250 }
251}
252
/// Multiplies the learning rate by `factor` when a monitored metric stops
/// improving for more than `patience` consecutive metric steps.
pub struct ReduceLROnPlateau {
    // "min" = lower metric is better; any other string behaves as "max"
    // (see `with_options`).
    mode: String,
    // Multiplicative LR reduction applied on plateau.
    factor: f32,
    // Number of non-improving steps tolerated before reducing.
    patience: usize,
    // Relative improvement threshold used in the comparison against `best`.
    threshold: f32,
    // Number of steps after a reduction during which metrics are ignored.
    cooldown: usize,
    // Lower bound the LR is clamped to when reducing.
    min_lr: f32,
    // Best metric seen so far (+inf for "min" mode, -inf otherwise).
    best: f32,
    // Consecutive steps without improvement.
    num_bad_epochs: usize,
    // Remaining cooldown steps; metric updates are skipped while > 0.
    cooldown_counter: usize,
    // How many metric/plain steps have been taken.
    current_step: usize,
    // Most recent LR written to the optimizer.
    last_lr: f32,
}
271
impl ReduceLROnPlateau {
    /// Creates a scheduler with PyTorch-like defaults:
    /// mode="min", factor=0.1, patience=10, threshold=1e-4, cooldown=0, min_lr=0.0.
    pub fn new<O: Optimizer>(optimizer: &O) -> Self {
        Self::with_options(optimizer, "min", 0.1, 10, 1e-4, 0, 0.0)
    }

    /// Fully parameterized constructor.
    ///
    /// NOTE(review): `mode` is stringly-typed — any value other than "min"
    /// silently behaves as "max". An enum would make this a compile-time check.
    pub fn with_options<O: Optimizer>(
        optimizer: &O,
        mode: &str,
        factor: f32,
        patience: usize,
        threshold: f32,
        cooldown: usize,
        min_lr: f32,
    ) -> Self {
        let initial_lr = optimizer.get_lr();
        // Seed `best` so that the first metric always counts as an improvement.
        let best = if mode == "min" {
            f32::INFINITY
        } else {
            f32::NEG_INFINITY
        };
        Self {
            mode: mode.to_string(),
            factor,
            patience,
            threshold,
            cooldown,
            min_lr,
            best,
            num_bad_epochs: 0,
            cooldown_counter: 0,
            current_step: 0,
            last_lr: initial_lr,
        }
    }

    /// Core plateau logic shared by `step_with_metric`.
    fn step_metric_impl<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
        self.current_step += 1;

        // During cooldown the metric is ignored entirely — `best` and the
        // bad-epoch counter are not updated, only the counter decremented.
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return;
        }

        // Relative-threshold improvement test.
        // NOTE(review): for negative `best` values the `(1.0 - threshold)`
        // scaling loosens rather than tightens the bound — confirm monitored
        // metrics are expected to be non-negative.
        let improved = if self.mode == "min" {
            metric < self.best * (1.0 - self.threshold)
        } else {
            metric > self.best * (1.0 + self.threshold)
        };

        if improved {
            self.best = metric;
            self.num_bad_epochs = 0;
        } else {
            self.num_bad_epochs += 1;
        }

        // Strictly more than `patience` bad steps triggers a reduction,
        // clamped at `min_lr`, and starts the cooldown window.
        if self.num_bad_epochs > self.patience {
            let current_lr = optimizer.get_lr();
            let new_lr = (current_lr * self.factor).max(self.min_lr);
            optimizer.set_lr(new_lr);
            self.last_lr = new_lr;
            self.cooldown_counter = self.cooldown;
            self.num_bad_epochs = 0;
        }
    }
}
344
345impl LRScheduler for ReduceLROnPlateau {
346 fn step<O: Optimizer>(&mut self, _optimizer: &mut O) {
347 self.current_step += 1;
349 }
350
351 fn step_with_metric<O: Optimizer>(&mut self, optimizer: &mut O, metric: f32) {
352 self.step_metric_impl(optimizer, metric);
353 }
354
355 fn get_last_lr(&self) -> f32 {
356 self.last_lr
357 }
358
359 fn get_step(&self) -> usize {
360 self.current_step
361 }
362}
363
/// One-cycle policy: linear warmup from `max_lr / div_factor` up to `max_lr`,
/// then cosine annealing down to `max_lr / final_div_factor`.
pub struct OneCycleLR {
    // Peak learning rate reached at the end of the warmup phase.
    max_lr: f32,
    // Total number of steps in the full cycle.
    total_steps: usize,
    // Fraction of the cycle spent in the warmup phase (e.g. 0.3).
    pct_start: f32,
    // Warmup starts at max_lr / div_factor.
    div_factor: f32,
    // Annealing ends at max_lr / final_div_factor.
    final_div_factor: f32,
    // How many times `step` has been called.
    current_step: usize,
    // Most recent LR written to the optimizer (initialized to the warmup start).
    last_lr: f32,
}
380
381impl OneCycleLR {
382 pub fn new<O: Optimizer>(optimizer: &O, max_lr: f32, total_steps: usize) -> Self {
384 Self::with_options(optimizer, max_lr, total_steps, 0.3, 25.0, 1e4)
385 }
386
387 pub fn with_options<O: Optimizer>(
389 _optimizer: &O,
390 max_lr: f32,
391 total_steps: usize,
392 pct_start: f32,
393 div_factor: f32,
394 final_div_factor: f32,
395 ) -> Self {
396 let initial_lr = max_lr / div_factor;
397 Self {
398 max_lr,
399 total_steps,
400 pct_start,
401 div_factor,
402 final_div_factor,
403 current_step: 0,
404 last_lr: initial_lr,
405 }
406 }
407}
408
409impl LRScheduler for OneCycleLR {
410 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
411 self.current_step += 1;
412
413 let step_ratio = self.current_step as f32 / self.total_steps as f32;
414 let initial_lr = self.max_lr / self.div_factor;
415 let min_lr = self.max_lr / self.final_div_factor;
416
417 let new_lr = if step_ratio <= self.pct_start {
418 let phase_ratio = step_ratio / self.pct_start;
420 initial_lr + (self.max_lr - initial_lr) * phase_ratio
421 } else {
422 let phase_ratio = (step_ratio - self.pct_start) / (1.0 - self.pct_start);
424 min_lr
425 + (self.max_lr - min_lr) * (1.0 + (std::f32::consts::PI * phase_ratio).cos()) / 2.0
426 };
427
428 optimizer.set_lr(new_lr);
429 self.last_lr = new_lr;
430 }
431
432 fn get_last_lr(&self) -> f32 {
433 self.last_lr
434 }
435
436 fn get_step(&self) -> usize {
437 self.current_step
438 }
439}
440
/// Linearly ramps the learning rate from 0 up to the optimizer's initial LR
/// over `warmup_steps` steps, then holds it constant.
pub struct WarmupLR {
    // Target LR, read from the optimizer at construction.
    initial_lr: f32,
    // Number of steps over which the LR ramps from 0 to `initial_lr`.
    warmup_steps: usize,
    // How many times `step` has been called.
    current_step: usize,
    // Most recent LR written to the optimizer. Starts at 0.0 even though the
    // optimizer still holds `initial_lr` until the first step — intentional?
    // NOTE(review): differs from the other schedulers, which seed this with
    // the optimizer's LR; confirm callers of get_last_lr expect 0 pre-step.
    last_lr: f32,
}
454
455impl WarmupLR {
456 pub fn new<O: Optimizer>(optimizer: &O, warmup_steps: usize) -> Self {
458 let initial_lr = optimizer.get_lr();
459 Self {
460 initial_lr,
461 warmup_steps,
462 current_step: 0,
463 last_lr: 0.0,
464 }
465 }
466}
467
468impl LRScheduler for WarmupLR {
469 fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
470 self.current_step += 1;
471
472 let new_lr = if self.current_step <= self.warmup_steps {
473 self.initial_lr * (self.current_step as f32 / self.warmup_steps as f32)
474 } else {
475 self.initial_lr
476 };
477
478 optimizer.set_lr(new_lr);
479 self.last_lr = new_lr;
480 }
481
482 fn get_last_lr(&self) -> f32 {
483 self.last_lr
484 }
485
486 fn get_step(&self) -> usize {
487 self.current_step
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGD;
    use axonml_autograd::Variable;
    use axonml_nn::Parameter;
    use axonml_tensor::Tensor;

    // Builds an SGD optimizer over one 3-element parameter with lr = 0.1.
    fn create_test_optimizer() -> SGD {
        let var = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), true);
        let param = Parameter::from_variable(var);
        SGD::new(vec![param], 0.1)
    }

    // StepLR: LR is multiplied by gamma after each block of `step_size` steps.
    #[test]
    fn test_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = StepLR::new(&optimizer, 10, 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }

        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // MultiStepLR: decays happen exactly at the milestone steps (5 and 15).
    #[test]
    fn test_multi_step_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = MultiStepLR::new(&optimizer, vec![5, 15], 0.1);

        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..10 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.001).abs() < 1e-6);
    }

    // ExponentialLR: LR shrinks by gamma on every single step.
    #[test]
    fn test_exponential_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ExponentialLR::new(&optimizer, 0.9);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.09).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.081).abs() < 1e-6);
    }

    // Cosine: halfway through the cycle the LR is near the midpoint, and by
    // t_max it is close to zero.
    #[test]
    fn test_cosine_annealing_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.05).abs() < 0.01);

        for _ in 0..50 {
            scheduler.step(&mut optimizer);
        }
        assert!(optimizer.get_lr() < 0.01);
    }

    // Warmup: LR ramps linearly and stays at the target after warmup ends.
    #[test]
    fn test_warmup_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 10);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.01).abs() < 1e-6);

        for _ in 0..9 {
            scheduler.step(&mut optimizer);
        }
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);

        scheduler.step(&mut optimizer);
        assert!((optimizer.get_lr() - 0.1).abs() < 1e-6);
    }

    // OneCycle: starts at max_lr / div_factor (0.1 / 25 = 0.004) and nears
    // max_lr by the end of the warmup phase (30% of 100 steps).
    #[test]
    fn test_one_cycle_lr() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        assert!((scheduler.get_last_lr() - 0.004).abs() < 0.001);

        for _ in 0..30 {
            scheduler.step(&mut optimizer);
        }

        assert!(optimizer.get_lr() > 0.08);
    }

    // Plateau (min mode): improving metrics keep the LR; patience=2 means the
    // third consecutive non-improving metric triggers a reduction.
    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        scheduler.step_with_metric(&mut optimizer, 1.0);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);
        scheduler.step_with_metric(&mut optimizer, 0.91);

        assert!(optimizer.get_lr() < initial_lr);
    }

    // Plateau (max mode): the comparison direction flips — stagnating or
    // decreasing metrics eventually reduce the LR.
    #[test]
    fn test_reduce_lr_on_plateau_max_mode() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "max", 0.5, 2, 0.0, 0, 0.0);

        let initial_lr = optimizer.get_lr();

        scheduler.step_with_metric(&mut optimizer, 0.8);
        scheduler.step_with_metric(&mut optimizer, 0.9);
        assert!((optimizer.get_lr() - initial_lr).abs() < 1e-6);

        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);
        scheduler.step_with_metric(&mut optimizer, 0.85);

        assert!(
            optimizer.get_lr() < initial_lr,
            "LR should reduce on plateau in max mode"
        );
    }

    // Plateau: repeated reductions must clamp at min_lr, never below.
    #[test]
    fn test_reduce_lr_on_plateau_min_lr_floor() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler =
            ReduceLROnPlateau::with_options(&optimizer, "min", 0.1, 0, 0.0, 0, 0.001);

        for _ in 0..50 {
            scheduler.step_with_metric(&mut optimizer, 999.0);
        }

        assert!(
            optimizer.get_lr() >= 0.001,
            "LR should not go below min_lr, got {}",
            optimizer.get_lr()
        );
    }

    // Plateau: after a reduction, bad metrics inside the cooldown window
    // (cooldown=3) must not trigger further reductions.
    #[test]
    fn test_reduce_lr_cooldown() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = ReduceLROnPlateau::with_options(&optimizer, "min", 0.5, 0, 0.0, 3, 0.0);

        let initial_lr = optimizer.get_lr();

        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        let lr_after_first_reduce = optimizer.get_lr();
        assert!(lr_after_first_reduce < initial_lr);

        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        scheduler.step_with_metric(&mut optimizer, 999.0);
        assert!(
            (optimizer.get_lr() - lr_after_first_reduce).abs() < 1e-8,
            "LR should not change during cooldown"
        );
    }

    // OneCycle over a full cycle: peak near max_lr around the pct_start
    // boundary (step ~30), ending far below the start.
    #[test]
    fn test_one_cycle_lr_full_cycle() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        let max_lr = lrs.iter().copied().fold(f32::MIN, f32::max);
        let final_lr = *lrs.last().unwrap();

        assert!(
            max_lr > 0.08,
            "Peak should be near max_lr=0.1, got {}",
            max_lr
        );
        assert!(
            final_lr < 0.001,
            "Final LR should be very small, got {}",
            final_lr
        );

        let peak_idx = lrs
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert!(
            (25..=35).contains(&peak_idx),
            "Peak should be around step 30, was at step {}",
            peak_idx
        );
    }

    // OneCycle: warmup phase is non-decreasing, annealing phase is
    // non-increasing (small epsilon absorbs float noise).
    #[test]
    fn test_one_cycle_lr_monotonic_phases() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = OneCycleLR::new(&optimizer, 0.1, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        for i in 1..29 {
            assert!(
                lrs[i] >= lrs[i - 1] - 1e-6,
                "Warmup should increase: step {} lr={} < step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }

        for i in 32..99 {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Annealing should decrease: step {} lr={} > step {} lr={}",
                i,
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }
    }

    // Cosine with a floor: the LR should land on eta_min at step t_max.
    #[test]
    fn test_cosine_annealing_with_eta_min() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::with_eta_min(&optimizer, 100, 0.001);

        for _ in 0..100 {
            scheduler.step(&mut optimizer);
        }

        assert!(
            (optimizer.get_lr() - 0.001).abs() < 0.002,
            "Should reach eta_min at end, got {}",
            optimizer.get_lr()
        );
    }

    // Cosine: within [0, t_max] the schedule is monotonically non-increasing
    // and never negative.
    #[test]
    fn test_cosine_annealing_monotonic_decrease() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = CosineAnnealingLR::new(&optimizer, 100);

        let mut lrs = Vec::new();
        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            lrs.push(optimizer.get_lr());
        }

        for i in 1..lrs.len() {
            assert!(
                lrs[i] <= lrs[i - 1] + 1e-6,
                "Cosine should decrease: step {} lr={} > step {} lr={}",
                i + 1,
                lrs[i],
                i,
                lrs[i - 1]
            );
        }

        assert!(
            lrs.iter().all(|lr| *lr >= 0.0),
            "LRs should be non-negative"
        );
    }

    // Warmup: once the ramp completes, further steps must leave the LR fixed.
    #[test]
    fn test_warmup_lr_stays_constant_after() {
        let mut optimizer = create_test_optimizer();
        let mut scheduler = WarmupLR::new(&optimizer, 5);

        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        let target = optimizer.get_lr();

        for _ in 0..100 {
            scheduler.step(&mut optimizer);
            assert!(
                (optimizer.get_lr() - target).abs() < 1e-8,
                "LR should stay at {} after warmup, got {}",
                target,
                optimizer.get_lr()
            );
        }
    }
}