// entrenar/train/callback/scheduler.rs
use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
use crate::optim::LRScheduler;

6pub struct LRSchedulerCallback<S: LRScheduler + Send> {
21 scheduler: S,
22 per_step: bool,
23 initial_lr: Option<f32>,
24}
25
26impl<S: LRScheduler + Send> LRSchedulerCallback<S> {
27 pub fn per_epoch(scheduler: S) -> Self {
29 Self { scheduler, per_step: false, initial_lr: None }
30 }
31
32 pub fn per_step(scheduler: S) -> Self {
34 Self { scheduler, per_step: true, initial_lr: None }
35 }
36
37 pub fn current_lr(&self) -> f32 {
39 self.scheduler.get_lr()
40 }
41}
42
43impl<S: LRScheduler + Send> TrainerCallback for LRSchedulerCallback<S> {
44 fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
45 self.initial_lr = Some(ctx.lr);
46 CallbackAction::Continue
47 }
48
49 fn on_epoch_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
50 if !self.per_step {
51 self.scheduler.step();
52 }
53 CallbackAction::Continue
54 }
55
56 fn on_step_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
57 if self.per_step {
58 self.scheduler.step();
59 }
60 CallbackAction::Continue
61 }
62
63 fn name(&self) -> &'static str {
64 "LRSchedulerCallback"
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::StepDecayLR;

    /// Per-epoch mode: `on_train_begin` records the context's learning rate,
    /// `on_step_end` leaves the scheduler alone, and every hook returns
    /// `Continue`.
    #[test]
    fn test_lr_scheduler_callback_per_epoch() {
        let scheduler = StepDecayLR::new(0.1, 10, 0.5);
        let mut cb = LRSchedulerCallback::per_epoch(scheduler);
        let ctx = CallbackContext { lr: 0.1, ..Default::default() };

        assert!(matches!(cb.on_train_begin(&ctx), CallbackAction::Continue));
        assert_eq!(cb.initial_lr, Some(0.1));

        // The step hook must not advance a per-epoch scheduler.
        let lr_before = cb.current_lr();
        assert!(matches!(cb.on_step_end(&ctx), CallbackAction::Continue));
        assert!((cb.current_lr() - lr_before).abs() < f32::EPSILON);

        assert!(matches!(cb.on_epoch_end(&ctx), CallbackAction::Continue));
    }

    /// Per-step mode: `on_epoch_end` leaves the scheduler alone, and every
    /// hook returns `Continue`.
    #[test]
    fn test_lr_scheduler_callback_per_step() {
        let scheduler = StepDecayLR::new(0.1, 10, 0.5);
        let mut cb = LRSchedulerCallback::per_step(scheduler);
        let ctx = CallbackContext::default();

        // The epoch hook must not advance a per-step scheduler.
        let lr_before = cb.current_lr();
        assert!(matches!(cb.on_epoch_end(&ctx), CallbackAction::Continue));
        assert!((cb.current_lr() - lr_before).abs() < f32::EPSILON);

        assert!(matches!(cb.on_step_end(&ctx), CallbackAction::Continue));
    }

    /// `current_lr` reports the scheduler's rate before any stepping.
    #[test]
    fn test_lr_scheduler_callback_current_lr() {
        let scheduler = StepDecayLR::new(0.1, 10, 0.5);
        let cb = LRSchedulerCallback::per_epoch(scheduler);
        assert!((cb.current_lr() - 0.1).abs() < 1e-6);
    }

    /// `name` returns the fixed callback name.
    #[test]
    fn test_lr_scheduler_callback_name() {
        let scheduler = StepDecayLR::new(0.1, 10, 0.5);
        let cb = LRSchedulerCallback::per_epoch(scheduler);
        assert_eq!(cb.name(), "LRSchedulerCallback");
    }
}