Skip to main content

entrenar/train/callback/
scheduler.rs

//! Learning-rate scheduler callback.
2
3use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4use crate::optim::LRScheduler;
5
6/// Callback that applies a learning rate scheduler during training
7///
8/// Can schedule per-step or per-epoch updates.
9///
10/// # Example
11///
12/// ```rust,ignore
13/// use entrenar::train::LRSchedulerCallback;
14/// use entrenar::optim::CosineAnnealingLR;
15///
16/// let scheduler = CosineAnnealingLR::new(0.001, 100, 0.0);
17/// let callback = LRSchedulerCallback::per_epoch(scheduler);
18/// trainer.add_callback(callback);
19/// ```
20pub struct LRSchedulerCallback<S: LRScheduler + Send> {
21    scheduler: S,
22    per_step: bool,
23    initial_lr: Option<f32>,
24}
25
26impl<S: LRScheduler + Send> LRSchedulerCallback<S> {
27    /// Create callback that steps scheduler per epoch
28    pub fn per_epoch(scheduler: S) -> Self {
29        Self { scheduler, per_step: false, initial_lr: None }
30    }
31
32    /// Create callback that steps scheduler per step
33    pub fn per_step(scheduler: S) -> Self {
34        Self { scheduler, per_step: true, initial_lr: None }
35    }
36
37    /// Get current learning rate from scheduler
38    pub fn current_lr(&self) -> f32 {
39        self.scheduler.get_lr()
40    }
41}
42
43impl<S: LRScheduler + Send> TrainerCallback for LRSchedulerCallback<S> {
44    fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
45        self.initial_lr = Some(ctx.lr);
46        CallbackAction::Continue
47    }
48
49    fn on_epoch_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
50        if !self.per_step {
51            self.scheduler.step();
52        }
53        CallbackAction::Continue
54    }
55
56    fn on_step_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
57        if self.per_step {
58            self.scheduler.step();
59        }
60        CallbackAction::Continue
61    }
62
63    fn name(&self) -> &'static str {
64        "LRSchedulerCallback"
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::StepDecayLR;

    /// Scheduler used by every test; its initial learning rate is 0.1.
    fn make_scheduler() -> StepDecayLR {
        StepDecayLR::new(0.1, 10, 0.5)
    }

    #[test]
    fn test_lr_scheduler_callback_per_epoch() {
        let mut cb = LRSchedulerCallback::per_epoch(make_scheduler());
        // Nothing is recorded until training begins.
        assert_eq!(cb.initial_lr, None);
        let ctx = CallbackContext { lr: 0.1, ..Default::default() };
        assert!(matches!(cb.on_train_begin(&ctx), CallbackAction::Continue));
        assert_eq!(cb.initial_lr, Some(0.1));
        assert!(matches!(cb.on_epoch_end(&ctx), CallbackAction::Continue));
    }

    #[test]
    fn test_lr_scheduler_callback_per_step() {
        let mut cb = LRSchedulerCallback::per_step(make_scheduler());
        let ctx = CallbackContext::default();
        assert!(matches!(cb.on_step_end(&ctx), CallbackAction::Continue));
    }

    #[test]
    fn test_lr_scheduler_callback_per_epoch_ignores_step_end() {
        let mut cb = LRSchedulerCallback::per_epoch(make_scheduler());
        let ctx = CallbackContext::default();
        for _ in 0..100 {
            assert!(matches!(cb.on_step_end(&ctx), CallbackAction::Continue));
        }
        // The scheduler was never stepped, so the rate is still the initial one.
        assert!((cb.current_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_lr_scheduler_callback_per_step_ignores_epoch_end() {
        let mut cb = LRSchedulerCallback::per_step(make_scheduler());
        let ctx = CallbackContext::default();
        for _ in 0..100 {
            assert!(matches!(cb.on_epoch_end(&ctx), CallbackAction::Continue));
        }
        assert!((cb.current_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_lr_scheduler_callback_decays_on_matching_hook() {
        let ctx = CallbackContext::default();

        let mut per_epoch = LRSchedulerCallback::per_epoch(make_scheduler());
        for _ in 0..100 {
            per_epoch.on_epoch_end(&ctx);
        }
        assert!(per_epoch.current_lr() < 0.1);

        let mut per_step = LRSchedulerCallback::per_step(make_scheduler());
        for _ in 0..100 {
            per_step.on_step_end(&ctx);
        }
        assert!(per_step.current_lr() < 0.1);
    }

    #[test]
    fn test_lr_scheduler_callback_current_lr() {
        let cb = LRSchedulerCallback::per_epoch(make_scheduler());
        assert!((cb.current_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_lr_scheduler_callback_name() {
        let cb = LRSchedulerCallback::per_epoch(make_scheduler());
        assert_eq!(cb.name(), "LRSchedulerCallback");
    }
}