use super::LRScheduler;
use crate::optim::Optimizer;
/// Step-decay learning-rate schedule.
///
/// The rate it reports is `lr_initial * gamma^(current_epoch / step_size)`,
/// using integer division, so the rate changes only once every `step_size`
/// recorded epochs.
#[derive(Debug, Clone, PartialEq)]
pub struct StepDecayLR {
    /// Learning rate reported before any decay has been applied.
    lr_initial: f32,
    /// Multiplicative factor applied once per completed `step_size` epochs.
    gamma: f32,
    /// Number of recorded epochs between successive decays; `0` disables decay.
    step_size: usize,
    /// Number of epochs recorded so far; starts at `0`.
    current_epoch: usize,
}
impl StepDecayLR {
    /// Builds a scheduler that starts at epoch `0`.
    ///
    /// * `lr_initial` - learning rate reported before any decay.
    /// * `step_size` - number of epochs between successive decays.
    /// * `gamma` - factor the learning rate is multiplied by at each decay.
    pub fn new(lr_initial: f32, step_size: usize, gamma: f32) -> Self {
        let current_epoch = 0;
        Self {
            current_epoch,
            step_size,
            gamma,
            lr_initial,
        }
    }

    /// Writes the scheduler's current learning rate into `optimizer`.
    ///
    /// The scheduler itself is not advanced by this call.
    pub fn apply<O>(&self, optimizer: &mut O)
    where
        O: Optimizer,
    {
        let current_lr = self.get_lr();
        optimizer.set_lr(current_lr);
    }
}
impl LRScheduler for StepDecayLR {
    /// Returns the learning rate for the current epoch:
    /// `lr_initial * gamma^(current_epoch / step_size)` with integer division.
    ///
    /// A `step_size` of `0` disables decay, so `lr_initial` is returned
    /// unchanged (this also avoids a division by zero).
    fn get_lr(&self) -> f32 {
        if self.step_size == 0 {
            return self.lr_initial;
        }
        let num_decays = self.current_epoch / self.step_size;
        // Clamp before narrowing: a plain `as i32` wraps to a negative exponent
        // once the count exceeds `i32::MAX`, which would make the rate grow
        // instead of decay. After the clamp the cast is lossless.
        let exponent = num_decays.min(i32::MAX as usize) as i32;
        self.lr_initial * self.gamma.powi(exponent)
    }

    /// Records that one more epoch has elapsed.
    fn step(&mut self) {
        // Saturate so an overflowing counter can neither panic nor wrap to 0
        // (which would silently restart the schedule).
        self.current_epoch = self.current_epoch.saturating_add(1);
    }
}