use std::f64::consts::PI;

/// Common interface for learning-rate schedules.
pub trait LearningRateScheduler {
    /// Returns the learning rate to use at `epoch`, derived from `base_lr`.
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64;

    /// Clears any internal state so the scheduler can be reused.
    fn reset(&mut self);

    /// Human-readable name of the schedule.
    fn name(&self) -> &'static str;
}

/// Keeps the learning rate fixed at `base_lr` for the whole run.
#[derive(Clone, Debug)]
pub struct ConstantLR;

impl LearningRateScheduler for ConstantLR {
    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
        base_lr
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "ConstantLR"
    }
}

/// Multiplies the learning rate by `gamma` once every `step_size` epochs.
#[derive(Clone, Debug)]
pub struct StepLR {
    step_size: usize,
    gamma: f64,
}

impl StepLR {
    pub fn new(step_size: usize, gamma: f64) -> Self {
        // Guard against a division by zero in `get_lr`.
        assert!(step_size > 0, "step_size must be positive");
        StepLR { step_size, gamma }
    }
}

impl LearningRateScheduler for StepLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        // Number of completed decay intervals so far.
        let steps = epoch / self.step_size;
        base_lr * self.gamma.powi(steps as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "StepLR"
    }
}
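
// A quick worked example of the step decay above (values follow from
// base_lr * gamma^(epoch / step_size)): StepLR::new(10, 0.1) with a base
// learning rate of 0.01 yields 0.01 for epochs 0..=9, 0.001 for epochs
// 10..=19, and 0.0001 from epoch 20 onward.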

/// Multiplies the learning rate by `gamma` at each listed milestone epoch.
#[derive(Clone, Debug)]
pub struct MultiStepLR {
    milestones: Vec<usize>,
    gamma: f64,
}

impl MultiStepLR {
    pub fn new(milestones: Vec<usize>, gamma: f64) -> Self {
        MultiStepLR { milestones, gamma }
    }
}

impl LearningRateScheduler for MultiStepLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        // Apply one factor of `gamma` per milestone already passed.
        let num_reductions = self.milestones.iter()
            .filter(|&&milestone| epoch >= milestone)
            .count();
        base_lr * self.gamma.powi(num_reductions as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "MultiStepLR"
    }
}
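
// For instance, MultiStepLR::new(vec![30, 80], 0.1) with base_lr = 0.1 keeps
// the rate at 0.1 until epoch 29, drops it to 0.01 at epoch 30, and to 0.001
// at epoch 80 (each milestone contributes one factor of gamma).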

/// Decays the learning rate by a factor of `gamma` every epoch.
#[derive(Clone, Debug)]
pub struct ExponentialLR {
    gamma: f64,
}

impl ExponentialLR {
    pub fn new(gamma: f64) -> Self {
        ExponentialLR { gamma }
    }
}

impl LearningRateScheduler for ExponentialLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        base_lr * self.gamma.powi(epoch as i32)
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "ExponentialLR"
    }
}
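
// The decay compounds quickly: with gamma = 0.95 and base_lr = 0.1, the rate
// after 10 epochs is 0.1 * 0.95^10 ≈ 0.0599, and after 50 epochs ≈ 0.0077.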

/// Anneals the learning rate from `base_lr` down to `eta_min` along a cosine
/// curve with period `t_max`.
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
    t_max: usize,
    eta_min: f64,
    last_epoch: usize,
}

impl CosineAnnealingLR {
    pub fn new(t_max: usize, eta_min: f64) -> Self {
        CosineAnnealingLR {
            t_max,
            eta_min,
            last_epoch: 0,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        self.last_epoch = epoch;
        if epoch == 0 {
            return base_lr;
        }

        // The modulo makes the schedule restart from `base_lr` every `t_max`
        // epochs rather than staying at `eta_min`.
        let t = epoch % self.t_max;
        self.eta_min + (base_lr - self.eta_min) *
            (1.0 + (PI * t as f64 / self.t_max as f64).cos()) / 2.0
    }

    fn reset(&mut self) {
        self.last_epoch = 0;
    }

    fn name(&self) -> &'static str {
        "CosineAnnealingLR"
    }
}
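
// Worked values for CosineAnnealingLR::new(100, 0.0) with base_lr = 0.1:
// epoch 0 gives 0.1, epoch 50 gives 0.1 * (1 + cos(pi/2)) / 2 = 0.05, epoch 99
// is close to 0.0, and epoch 100 wraps back to 0.1 because of the modulo.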

/// Cosine annealing with warm restarts: each cycle anneals from `base_lr`
/// down to `eta_min`, and every restart multiplies the cycle length by
/// `t_mult`.
#[derive(Clone, Debug)]
pub struct CosineAnnealingWarmRestarts {
    t_0: usize,
    t_mult: usize,
    eta_min: f64,
    last_restart: usize,
    restart_count: usize,
}

impl CosineAnnealingWarmRestarts {
    pub fn new(t_0: usize, t_mult: usize, eta_min: f64) -> Self {
        CosineAnnealingWarmRestarts {
            t_0,
            t_mult,
            eta_min,
            last_restart: 0,
            restart_count: 0,
        }
    }
}

impl LearningRateScheduler for CosineAnnealingWarmRestarts {
    // Note: this scheduler is stateful and assumes `get_lr` is called with
    // monotonically non-decreasing epochs, once per epoch.
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch == 0 {
            return base_lr;
        }

        // Position within the current cycle and the current cycle's length.
        let t_cur = epoch - self.last_restart;
        let t_i = self.t_0 * self.t_mult.pow(self.restart_count as u32);

        if t_cur >= t_i {
            self.last_restart = epoch;
            self.restart_count += 1;
            return base_lr;
        }

        self.eta_min + (base_lr - self.eta_min) *
            (1.0 + (PI * t_cur as f64 / t_i as f64).cos()) / 2.0
    }

    fn reset(&mut self) {
        self.last_restart = 0;
        self.restart_count = 0;
    }

    fn name(&self) -> &'static str {
        "CosineAnnealingWarmRestarts"
    }
}
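
// Restart timeline for CosineAnnealingWarmRestarts::new(10, 2, 0.0), assuming
// one call per epoch: the first cycle spans epochs 0..10, and the rate snaps
// back to base_lr at epochs 10, 30, and 70 as the cycle length grows to 20
// and then 40.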

/// One-cycle policy: the rate warms up from `max_lr / div_factor` to `max_lr`
/// over the first `pct_start` fraction of training, then anneals down to
/// `max_lr / final_div_factor`.
#[derive(Clone, Debug)]
pub struct OneCycleLR {
    max_lr: f64,
    total_steps: usize,
    pct_start: f64,
    anneal_strategy: AnnealStrategy,
    div_factor: f64,
    final_div_factor: f64,
}

/// Shape of the annealing phase in `OneCycleLR`.
#[derive(Clone, Debug)]
pub enum AnnealStrategy {
    Cos,
    Linear,
}

impl OneCycleLR {
    pub fn new(max_lr: f64, total_steps: usize) -> Self {
        OneCycleLR {
            max_lr,
            total_steps,
            pct_start: 0.3,
            anneal_strategy: AnnealStrategy::Cos,
            div_factor: 25.0,
            final_div_factor: 10000.0,
        }
    }

    pub fn with_params(
        max_lr: f64,
        total_steps: usize,
        pct_start: f64,
        anneal_strategy: AnnealStrategy,
        div_factor: f64,
        final_div_factor: f64,
    ) -> Self {
        OneCycleLR {
            max_lr,
            total_steps,
            pct_start,
            anneal_strategy,
            div_factor,
            final_div_factor,
        }
    }
}

impl LearningRateScheduler for OneCycleLR {
    fn get_lr(&mut self, epoch: usize, _base_lr: f64) -> f64 {
        if epoch >= self.total_steps {
            return self.max_lr / self.final_div_factor;
        }

        let warmup_steps = (self.total_steps as f64 * self.pct_start) as usize;

        if epoch < warmup_steps {
            // Linear warmup from max_lr / div_factor up to max_lr.
            let warmup_ratio = epoch as f64 / warmup_steps as f64;
            (self.max_lr / self.div_factor) +
                (self.max_lr - self.max_lr / self.div_factor) * warmup_ratio
        } else {
            // Anneal from max_lr down to max_lr / final_div_factor.
            let anneal_ratio = (epoch - warmup_steps) as f64 /
                (self.total_steps - warmup_steps) as f64;

            match self.anneal_strategy {
                AnnealStrategy::Cos => {
                    let cos_factor = (1.0 + (PI * anneal_ratio).cos()) / 2.0;
                    (self.max_lr / self.final_div_factor) +
                        (self.max_lr - self.max_lr / self.final_div_factor) * cos_factor
                },
                AnnealStrategy::Linear => {
                    self.max_lr - (self.max_lr - self.max_lr / self.final_div_factor) * anneal_ratio
                }
            }
        }
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "OneCycleLR"
    }
}
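
// Phase boundaries for OneCycleLR::new(0.1, 100) (defaults: pct_start = 0.3,
// div_factor = 25, final_div_factor = 10000): the rate starts at
// 0.1 / 25 = 0.004, peaks at 0.1 when the warmup ends at epoch 30, and
// anneals to 0.1 / 10000 = 1e-5 by epoch 100.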

/// Multiplies the learning rate by `factor` when the validation loss has not
/// improved by more than `threshold` for `patience` consecutive steps.
#[derive(Clone, Debug)]
pub struct ReduceLROnPlateau {
    factor: f64,
    patience: usize,
    threshold: f64,
    cooldown: usize,
    min_lr: f64,
    best_loss: f64,
    wait_count: usize,
    cooldown_counter: usize,
    current_lr: f64,
}

impl ReduceLROnPlateau {
    pub fn new(factor: f64, patience: usize) -> Self {
        ReduceLROnPlateau {
            factor,
            patience,
            threshold: 1e-4,
            cooldown: 0,
            min_lr: 0.0,
            best_loss: f64::INFINITY,
            wait_count: 0,
            cooldown_counter: 0,
            current_lr: 0.0,
        }
    }

    pub fn with_params(
        factor: f64,
        patience: usize,
        threshold: f64,
        cooldown: usize,
        min_lr: f64,
    ) -> Self {
        ReduceLROnPlateau {
            factor,
            patience,
            threshold,
            cooldown,
            min_lr,
            best_loss: f64::INFINITY,
            wait_count: 0,
            cooldown_counter: 0,
            current_lr: 0.0,
        }
    }

    pub fn step(&mut self, val_loss: f64, base_lr: f64) -> f64 {
        // `current_lr == 0.0` doubles as the "not yet initialized" sentinel.
        if self.current_lr == 0.0 {
            self.current_lr = base_lr;
        }

        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return self.current_lr;
        }

        if val_loss < self.best_loss - self.threshold {
            self.best_loss = val_loss;
            self.wait_count = 0;
        } else {
            self.wait_count += 1;

            if self.wait_count >= self.patience {
                let new_lr = self.current_lr * self.factor;
                self.current_lr = new_lr.max(self.min_lr);
                self.wait_count = 0;
                self.cooldown_counter = self.cooldown;
                println!("ReduceLROnPlateau: reducing learning rate to {:.2e}", self.current_lr);
            }
        }

        self.current_lr
    }
}

impl LearningRateScheduler for ReduceLROnPlateau {
    // `get_lr` only reports the cached rate; loss-driven updates happen in `step`.
    fn get_lr(&mut self, _epoch: usize, base_lr: f64) -> f64 {
        if self.current_lr == 0.0 {
            self.current_lr = base_lr;
        }
        self.current_lr
    }

    fn reset(&mut self) {
        self.best_loss = f64::INFINITY;
        self.wait_count = 0;
        self.cooldown_counter = 0;
        self.current_lr = 0.0;
    }

    fn name(&self) -> &'static str {
        "ReduceLROnPlateau"
    }
}
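
// A plausible call pattern (hypothetical loss values): with
// ReduceLROnPlateau::new(0.5, 3) and base_lr = 0.01, three consecutive calls
// to `step` whose loss fails to beat the best seen by more than the 1e-4
// threshold halve the rate to 0.005; `get_lr` then keeps reporting the
// reduced value until `reset` is called.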

/// Interpolates the learning-rate multiplier linearly from `start_factor` to
/// `end_factor` over `total_iters` epochs.
#[derive(Clone, Debug)]
pub struct LinearLR {
    start_factor: f64,
    end_factor: f64,
    total_iters: usize,
}

impl LinearLR {
    pub fn new(start_factor: f64, end_factor: f64, total_iters: usize) -> Self {
        LinearLR {
            start_factor,
            end_factor,
            total_iters,
        }
    }
}

impl LearningRateScheduler for LinearLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch >= self.total_iters {
            return base_lr * self.end_factor;
        }

        let progress = epoch as f64 / self.total_iters as f64;
        let factor = self.start_factor +
            (self.end_factor - self.start_factor) * progress;

        base_lr * factor
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "LinearLR"
    }
}
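
// LinearLR also works as a simple warmup: LinearLR::new(0.5, 1.0, 10) with
// base_lr = 0.1 starts at 0.05, reaches 0.075 at epoch 5, and holds 0.1 from
// epoch 10 onward.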

/// Decays the learning rate from `base_lr` to `end_lr` following
/// `(1 - epoch / total_iters)^power`.
#[derive(Clone, Debug)]
pub struct PolynomialLR {
    total_iters: usize,
    power: f64,
    end_lr: f64,
}

impl PolynomialLR {
    pub fn new(total_iters: usize, power: f64, end_lr: f64) -> Self {
        PolynomialLR {
            total_iters,
            power,
            end_lr,
        }
    }
}

impl LearningRateScheduler for PolynomialLR {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch >= self.total_iters {
            return self.end_lr;
        }

        let factor = (1.0 - epoch as f64 / self.total_iters as f64).powf(self.power);
        self.end_lr + (base_lr - self.end_lr) * factor
    }

    fn reset(&mut self) {}

    fn name(&self) -> &'static str {
        "PolynomialLR"
    }
}
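
// The exponent controls the decay shape: PolynomialLR::new(100, 1.0, 0.0) is
// a straight line from base_lr down to 0, while power = 2.0 front-loads the
// decay (with base_lr = 0.1, epoch 50 gives 0.05 at power 1.0 but 0.025 at
// power 2.0).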

/// Cycles the learning rate between `base_lr` and `max_lr` in a triangular
/// wave with half-period `step_size`, optionally shrinking the amplitude
/// over time.
#[derive(Clone, Debug)]
pub struct CyclicalLR {
    base_lr: f64,
    max_lr: f64,
    step_size: usize,
    mode: CyclicalMode,
    gamma: f64,
    scale_mode: ScaleMode,
    last_step: usize,
}

/// Amplitude policy: constant, halved each cycle, or exponentially decayed.
#[derive(Clone, Debug)]
pub enum CyclicalMode {
    Triangular,
    Triangular2,
    ExpRange,
}

/// Whether amplitude scaling is applied per cycle or per iteration.
#[derive(Clone, Debug)]
pub enum ScaleMode {
    Cycle,
    Iterations,
}

impl CyclicalLR {
    pub fn new(base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        CyclicalLR {
            base_lr,
            max_lr,
            step_size,
            mode: CyclicalMode::Triangular,
            gamma: 1.0,
            scale_mode: ScaleMode::Cycle,
            last_step: 0,
        }
    }

    pub fn with_mode(mut self, mode: CyclicalMode) -> Self {
        self.mode = mode;
        self
    }

    pub fn with_gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    pub fn with_scale_mode(mut self, scale_mode: ScaleMode) -> Self {
        self.scale_mode = scale_mode;
        self
    }
}

impl LearningRateScheduler for CyclicalLR {
    fn get_lr(&mut self, epoch: usize, _base_lr: f64) -> f64 {
        self.last_step = epoch;

        // Zero-based cycle index and triangular position x in [0, 1].
        let cycle = (epoch as f64 / (2.0 * self.step_size as f64)).floor() as usize;
        let x = (epoch as f64 / self.step_size as f64 - 2.0 * cycle as f64 - 1.0).abs();

        let scale_factor = match self.scale_mode {
            ScaleMode::Cycle => match self.mode {
                CyclicalMode::Triangular => 1.0,
                // `cycle` is zero-based here, so the first cycle must get full
                // amplitude: 1, 1/2, 1/4, ...
                CyclicalMode::Triangular2 => 1.0 / 2.0_f64.powi(cycle as i32),
                CyclicalMode::ExpRange => self.gamma.powi(epoch as i32),
            },
            ScaleMode::Iterations => self.gamma.powi(epoch as i32),
        };

        self.base_lr + (self.max_lr - self.base_lr) * (1.0 - x).max(0.0) * scale_factor
    }

    fn reset(&mut self) {
        self.last_step = 0;
    }

    fn name(&self) -> &'static str {
        "CyclicalLR"
    }
}
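
// Shape check for CyclicalLR::new(0.1, 1.0, 10) in the default triangular
// mode: the rate climbs linearly from 0.1 at epoch 0 to 1.0 at epoch 10,
// falls back to 0.1 by epoch 20, and repeats with period 20. Builder example:
// CyclicalLR::new(0.1, 1.0, 10).with_mode(CyclicalMode::Triangular2) halves
// the amplitude on each successive cycle, so the peaks are 1.0, 0.55, 0.325, ...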

/// Wraps another scheduler with a linear warmup phase: the rate ramps from
/// `warmup_start_lr` to `base_lr` over `warmup_epochs`, then defers to the
/// inner scheduler.
#[derive(Clone, Debug)]
pub struct WarmupScheduler<S: LearningRateScheduler> {
    warmup_epochs: usize,
    base_scheduler: S,
    warmup_start_lr: f64,
}

impl<S: LearningRateScheduler> WarmupScheduler<S> {
    pub fn new(warmup_epochs: usize, base_scheduler: S, warmup_start_lr: f64) -> Self {
        WarmupScheduler {
            warmup_epochs,
            base_scheduler,
            warmup_start_lr,
        }
    }
}

impl<S: LearningRateScheduler> LearningRateScheduler for WarmupScheduler<S> {
    fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 {
        if epoch < self.warmup_epochs {
            let warmup_factor = epoch as f64 / self.warmup_epochs as f64;
            self.warmup_start_lr + (base_lr - self.warmup_start_lr) * warmup_factor
        } else {
            // The inner scheduler sees epochs relative to the end of warmup.
            self.base_scheduler.get_lr(epoch - self.warmup_epochs, base_lr)
        }
    }

    fn reset(&mut self) {
        self.base_scheduler.reset();
    }

    fn name(&self) -> &'static str {
        "WarmupScheduler"
    }
}
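
// Composition sketch: a 5-epoch warmup in front of cosine annealing. The inner
// cosine schedule receives shifted epochs, so its full period still runs after
// the warmup ends.
//
//     let mut scheduler = WarmupScheduler::new(
//         5,
//         CosineAnnealingLR::new(95, 0.0),
//         1e-4,
//     );
//     // Still warming up at epoch 3: 1e-4 + 0.6 * (0.1 - 1e-4) = 0.060.
//     let lr = scheduler.get_lr(3, 0.1);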

/// Helper for sampling and printing a learning-rate schedule as ASCII art.
pub struct LRScheduleVisualizer;

impl LRScheduleVisualizer {
    /// Samples `scheduler` at every epoch in `0..epochs`.
    pub fn generate_schedule<S: LearningRateScheduler>(
        mut scheduler: S,
        base_lr: f64,
        epochs: usize,
    ) -> Vec<(usize, f64)> {
        let mut schedule = Vec::new();

        for epoch in 0..epochs {
            let lr = scheduler.get_lr(epoch, base_lr);
            schedule.push((epoch, lr));
        }

        schedule
    }

    /// Prints a `width` x `height` ASCII plot of the schedule; `height`
    /// should be at least 2 so the y-axis spans more than one row.
    pub fn print_schedule<S: LearningRateScheduler>(
        scheduler: S,
        base_lr: f64,
        epochs: usize,
        width: usize,
        height: usize,
    ) {
        let schedule = Self::generate_schedule(scheduler, base_lr, epochs);

        if schedule.is_empty() {
            return;
        }

        let min_lr = schedule.iter().map(|(_, lr)| *lr).fold(f64::INFINITY, f64::min);
        let max_lr = schedule.iter().map(|(_, lr)| *lr).fold(0.0, f64::max);

        println!("Learning Rate Schedule Visualization ({}x{})", width, height);
        println!("Min LR: {:.2e}, Max LR: {:.2e}", min_lr, max_lr);
        println!("┌{}┐", "─".repeat(width));

        for row in 0..height {
            // Map the row to a learning-rate value; the top row is max_lr.
            let y_value = max_lr - (max_lr - min_lr) * row as f64 / (height - 1) as f64;
            print!("│");

            for col in 0..width {
                // Map the column to an epoch index.
                let epoch_idx = col * epochs / width;
                let lr = if epoch_idx < schedule.len() {
                    schedule[epoch_idx].1
                } else {
                    min_lr
                };

                if (lr - y_value).abs() < (max_lr - min_lr) / height as f64 {
                    print!("█");
                } else {
                    print!(" ");
                }
            }

            println!("│ {:.2e}", y_value);
        }

        println!("└{}┘", "─".repeat(width));
        print!(" ");
        for i in 0..=4 {
            let epoch = i * epochs / 4;
            print!("{:>width$}", epoch, width = width / 5);
        }
        println!();
    }
}
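
// Rendering sketch: plot one full cosine cycle on a 60x10 grid. Note that
// `print_schedule` consumes the scheduler, so pass a clone if you still
// need it afterwards.
//
//     LRScheduleVisualizer::print_schedule(
//         CosineAnnealingLR::new(100, 0.0),
//         0.1, // base_lr
//         100, // epochs
//         60,  // width
//         10,  // height
//     );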

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_lr() {
        let mut scheduler = ConstantLR;
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(10, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(100, base_lr), base_lr);
    }

    #[test]
    fn test_step_lr() {
        let mut scheduler = StepLR::new(10, 0.1);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert_eq!(scheduler.get_lr(9, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_exponential_lr() {
        let mut scheduler = ExponentialLR::new(0.9);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(1, base_lr) - base_lr * 0.9).abs() < 1e-10);
        assert!((scheduler.get_lr(2, base_lr) - base_lr * 0.81).abs() < 1e-10);
    }

    #[test]
    fn test_multi_step_lr() {
        let mut scheduler = MultiStepLR::new(vec![10, 20], 0.1);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(5, base_lr), base_lr);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(15, base_lr) - base_lr * 0.1).abs() < 1e-15);
        assert!((scheduler.get_lr(20, base_lr) - base_lr * 0.01).abs() < 1e-15);
    }
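
    // Expected values below follow directly from the two cosine formulas
    // above with eta_min = 0.0.
    #[test]
    fn test_cosine_annealing_lr() {
        let mut scheduler = CosineAnnealingLR::new(10, 0.0);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        // Halfway through the cycle the cosine factor is 0.5.
        assert!((scheduler.get_lr(5, base_lr) - 0.05).abs() < 1e-10);
        // The `epoch % t_max` wrap restarts the cycle at t_max.
        assert!((scheduler.get_lr(10, base_lr) - base_lr).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_annealing_warm_restarts() {
        let mut scheduler = CosineAnnealingWarmRestarts::new(10, 2, 0.0);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(5, base_lr) - 0.05).abs() < 1e-10);
        // First restart; afterwards the cycle length doubles to 20.
        assert!((scheduler.get_lr(10, base_lr) - base_lr).abs() < 1e-10);
        let lr_15 = scheduler.get_lr(15, base_lr);
        assert!(lr_15 > 0.05 && lr_15 < base_lr);
    }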

    #[test]
    fn test_one_cycle_lr() {
        let mut scheduler = OneCycleLR::new(0.1, 100);
        let base_lr = 0.01;

        let lr_0 = scheduler.get_lr(0, base_lr);
        let lr_30 = scheduler.get_lr(30, base_lr);
        let lr_100 = scheduler.get_lr(100, base_lr);

        assert!(lr_0 < lr_30);
        assert!(lr_100 < lr_0);
        assert!(lr_30 <= 0.1);
    }

    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut scheduler = ReduceLROnPlateau::new(0.5, 2);
        let base_lr = 0.01;

        let lr1 = scheduler.step(1.0, base_lr);
        assert_eq!(lr1, base_lr);

        let lr2 = scheduler.step(0.8, base_lr);
        assert_eq!(lr2, base_lr);

        let _lr3 = scheduler.step(0.9, base_lr);
        let _lr4 = scheduler.step(0.9, base_lr);
        let lr5 = scheduler.step(0.9, base_lr);

        assert!(lr5 < base_lr);
        assert!((lr5 - base_lr * 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_linear_lr() {
        let mut scheduler = LinearLR::new(1.0, 0.1, 10);
        let base_lr = 0.01;

        assert_eq!(scheduler.get_lr(0, base_lr), base_lr);
        assert!((scheduler.get_lr(5, base_lr) - base_lr * 0.55).abs() < 1e-10);
        assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_polynomial_lr() {
        let mut scheduler = PolynomialLR::new(100, 2.0, 0.01);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), 0.1);
        assert!((scheduler.get_lr(50, base_lr) - 0.0325).abs() < 1e-10);
        assert!((scheduler.get_lr(100, base_lr) - 0.01).abs() < 1e-10);
    }

    #[test]
    fn test_cyclical_lr() {
        let mut scheduler = CyclicalLR::new(0.1, 1.0, 10);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), 0.1);
        assert!((scheduler.get_lr(5, base_lr) - 0.55).abs() < 1e-10);
        assert_eq!(scheduler.get_lr(10, base_lr), 1.0);
    }

    #[test]
    fn test_warmup_scheduler() {
        let base_scheduler = ConstantLR;
        let mut scheduler = WarmupScheduler::new(10, base_scheduler, 0.01);
        let base_lr = 0.1;

        assert_eq!(scheduler.get_lr(0, base_lr), 0.01);
        assert!((scheduler.get_lr(5, base_lr) - 0.055).abs() < 1e-10);
        assert_eq!(scheduler.get_lr(10, base_lr), 0.1);
    }
}