use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
use crate::optim::LRScheduler;
/// Training callback that drives a learning-rate scheduler from trainer events.
///
/// Depending on how it is constructed, the wrapped scheduler is stepped either at the
/// end of every epoch ([`LRSchedulerCallback::per_epoch`]) or at the end of every
/// training step ([`LRSchedulerCallback::per_step`]).
///
/// Trait bounds are intentionally placed on the `impl` blocks rather than on the
/// struct itself, so naming the type does not force callers to repeat them.
#[derive(Debug)]
pub struct LRSchedulerCallback<S> {
    /// Scheduler advanced by this callback.
    scheduler: S,
    /// `true` to step on `on_step_end`, `false` to step on `on_epoch_end`.
    per_step: bool,
    /// Learning rate read from the callback context at train start; `None` until then.
    initial_lr: Option<f32>,
}
impl<S: LRScheduler + Send> LRSchedulerCallback<S> {
    /// Creates a callback that advances `scheduler` once per epoch, from `on_epoch_end`.
    pub fn per_epoch(scheduler: S) -> Self {
        Self { scheduler, per_step: false, initial_lr: None }
    }

    /// Creates a callback that advances `scheduler` once per training step, from
    /// `on_step_end`.
    pub fn per_step(scheduler: S) -> Self {
        Self { scheduler, per_step: true, initial_lr: None }
    }

    /// Returns the learning rate currently reported by the wrapped scheduler.
    pub fn current_lr(&self) -> f32 {
        self.scheduler.get_lr()
    }

    /// Returns the learning rate recorded when training began, or `None` if
    /// `on_train_begin` has not been called yet.
    ///
    /// Exposing this also keeps the recorded value from being write-only state
    /// (which would otherwise trip rustc's `dead_code` lint in non-test builds).
    pub fn initial_lr(&self) -> Option<f32> {
        self.initial_lr
    }
}
// NOTE(review): this callback only advances the scheduler; it never writes the new
// learning rate into an optimizer itself. Presumably the trainer applies it elsewhere
// (e.g. via `current_lr`) — confirm against the trainer loop.
impl<S: LRScheduler + Send> TrainerCallback for LRSchedulerCallback<S> {
    /// Records the learning rate reported by the context when training starts.
    fn on_train_begin(&mut self, ctx: &CallbackContext) -> CallbackAction {
        let starting_lr = ctx.lr;
        self.initial_lr = Some(starting_lr);
        CallbackAction::Continue
    }

    /// Steps the scheduler only for callbacks built with `per_epoch`.
    fn on_epoch_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        let is_epoch_driven = !self.per_step;
        if is_epoch_driven {
            self.scheduler.step();
        }
        CallbackAction::Continue
    }

    /// Steps the scheduler only for callbacks built with `per_step`.
    fn on_step_end(&mut self, _ctx: &CallbackContext) -> CallbackAction {
        let is_step_driven = self.per_step;
        if is_step_driven {
            self.scheduler.step();
        }
        CallbackAction::Continue
    }

    /// Identifier used for this callback.
    fn name(&self) -> &'static str {
        const NAME: &str = "LRSchedulerCallback";
        NAME
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::StepDecayLR;

    /// Fresh scheduler shared by the tests: starts at 0.1 (see `current_lr` test).
    fn scheduler() -> StepDecayLR {
        StepDecayLR::new(0.1, 10, 0.5)
    }

    #[test]
    fn test_lr_scheduler_callback_per_epoch() {
        let mut cb = LRSchedulerCallback::per_epoch(scheduler());
        let ctx = CallbackContext { lr: 0.1, ..Default::default() };
        assert!(matches!(cb.on_train_begin(&ctx), CallbackAction::Continue));
        assert_eq!(cb.initial_lr, Some(0.1));
        assert!(matches!(cb.on_epoch_end(&ctx), CallbackAction::Continue));
    }

    #[test]
    fn test_lr_scheduler_callback_per_step() {
        let mut cb = LRSchedulerCallback::per_step(scheduler());
        assert!(matches!(
            cb.on_step_end(&CallbackContext::default()),
            CallbackAction::Continue
        ));
    }

    #[test]
    fn test_per_epoch_ignores_step_end() {
        let mut cb = LRSchedulerCallback::per_epoch(scheduler());
        let ctx = CallbackContext::default();
        for _ in 0..25 {
            cb.on_step_end(&ctx);
        }
        // The scheduler must not have been advanced, so the LR is still the initial one.
        assert!((cb.current_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_per_step_ignores_epoch_end() {
        let mut cb = LRSchedulerCallback::per_step(scheduler());
        let ctx = CallbackContext::default();
        for _ in 0..25 {
            cb.on_epoch_end(&ctx);
        }
        assert!((cb.current_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_per_epoch_decays_on_epoch_end() {
        let mut cb = LRSchedulerCallback::per_epoch(scheduler());
        let ctx = CallbackContext::default();
        // Well past the decay interval of 10, so at least one decay must have applied.
        for _ in 0..25 {
            cb.on_epoch_end(&ctx);
        }
        assert!(cb.current_lr() < 0.1 - 1e-6);
    }

    #[test]
    fn test_per_step_decays_on_step_end() {
        let mut cb = LRSchedulerCallback::per_step(scheduler());
        let ctx = CallbackContext::default();
        for _ in 0..25 {
            cb.on_step_end(&ctx);
        }
        assert!(cb.current_lr() < 0.1 - 1e-6);
    }

    #[test]
    fn test_lr_scheduler_callback_current_lr() {
        let cb = LRSchedulerCallback::per_epoch(scheduler());
        assert!((cb.current_lr() - 0.1).abs() < 1e-6);
    }

    #[test]
    fn test_lr_scheduler_callback_name() {
        let cb = LRSchedulerCallback::per_epoch(scheduler());
        assert_eq!(cb.name(), "LRSchedulerCallback");
    }
}