/// A learning-rate schedule that is advanced explicitly by the caller.
pub trait LRScheduler {
    /// Returns the learning rate for the schedule's current position.
    fn get_lr(&self) -> f32;

    /// Replaces the base (initial) learning rate, keeping the schedule's
    /// current position.
    fn set_initial_lr(&mut self, lr: f32);

    /// Advances the schedule by one step.
    fn step(&mut self);
}
/// Step-decay scheduler: the learning rate is multiplied by `gamma` once
/// every `step_size` epochs.
#[derive(Debug, Clone)]
pub struct StepLR {
    /// Base learning rate at epoch 0.
    initial_lr: f32,
    /// Number of epochs between successive decays (non-zero; enforced by `new`).
    step_size: usize,
    /// Multiplicative decay factor applied at each decay boundary.
    gamma: f32,
    /// Number of `step` calls made so far.
    current_epoch: usize,
    /// Learning rate for `current_epoch`.
    current_lr: f32,
}
impl StepLR {
    /// Builds a step-decay scheduler positioned at epoch 0, whose current
    /// learning rate starts out equal to `initial_lr`.
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is zero.
    pub fn new(initial_lr: f32, step_size: usize, gamma: f32) -> Self {
        assert!(step_size != 0, "StepLR: step_size must be > 0");
        let current_lr = initial_lr;
        Self {
            current_epoch: 0,
            current_lr,
            gamma,
            step_size,
            initial_lr,
        }
    }
}
impl LRScheduler for StepLR {
    /// Advances one epoch and updates the learning rate.
    ///
    /// The rate is recomputed in closed form from `initial_lr` rather than
    /// multiplied in place. This avoids accumulating rounding error and keeps
    /// the result identical to what `set_initial_lr` computes for the same epoch.
    fn step(&mut self) {
        self.current_epoch += 1;
        self.current_lr = step_lr_decayed(
            self.initial_lr,
            self.gamma,
            self.current_epoch / self.step_size,
        );
    }

    /// Returns the learning rate for the current epoch.
    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    /// Replaces the base learning rate and recomputes the current rate for
    /// the epoch already reached, so the decay position is preserved.
    fn set_initial_lr(&mut self, lr: f32) {
        self.initial_lr = lr;
        self.current_lr = step_lr_decayed(lr, self.gamma, self.current_epoch / self.step_size);
    }
}

/// Returns `initial_lr * gamma^decays`.
///
/// `decays` saturates at `i32::MAX` for `powi`. A plain `as i32` cast would
/// instead wrap to a negative exponent and make the rate jump back up on an
/// extremely long schedule.
fn step_lr_decayed(initial_lr: f32, gamma: f32, decays: usize) -> f32 {
    let exponent = decays.min(i32::MAX as usize) as i32;
    initial_lr * gamma.powi(exponent)
}
/// Cosine-annealing scheduler: the learning rate follows half a cosine wave
/// from `initial_lr` down to `eta_min` over `t_max` epochs, then holds at
/// `eta_min`.
#[derive(Debug, Clone)]
pub struct CosineAnnealingLR {
    /// Learning rate at epoch 0 (the top of the cosine curve).
    initial_lr: f32,
    /// Number of epochs to anneal over (non-zero; enforced by `new`).
    t_max: usize,
    /// Floor learning rate reached at epoch `t_max` and held afterwards.
    eta_min: f32,
    /// Number of `step` calls made so far.
    current_epoch: usize,
    /// Learning rate for `current_epoch`.
    current_lr: f32,
}
impl CosineAnnealingLR {
    /// Builds a cosine-annealing scheduler positioned at epoch 0, whose
    /// current learning rate starts out equal to `initial_lr`.
    ///
    /// # Panics
    ///
    /// Panics if `t_max` is zero.
    pub fn new(initial_lr: f32, t_max: usize, eta_min: f32) -> Self {
        assert!(t_max != 0, "CosineAnnealingLR: t_max must be > 0");
        let current_lr = initial_lr;
        Self {
            current_epoch: 0,
            current_lr,
            eta_min,
            t_max,
            initial_lr,
        }
    }
}
impl LRScheduler for CosineAnnealingLR {
    /// Advances one epoch and recomputes the annealed learning rate.
    fn step(&mut self) {
        self.current_epoch += 1;
        self.current_lr =
            cosine_annealed_lr(self.initial_lr, self.eta_min, self.current_epoch, self.t_max);
    }

    /// Returns the learning rate for the current epoch.
    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    /// Replaces the top of the cosine curve and recomputes the current rate
    /// for the epoch already reached.
    fn set_initial_lr(&mut self, lr: f32) {
        self.initial_lr = lr;
        self.current_lr = cosine_annealed_lr(lr, self.eta_min, self.current_epoch, self.t_max);
    }
}

/// Closed-form cosine annealing, shared by `step` and `set_initial_lr` so the
/// two cannot drift apart.
///
/// Returns `eta_min + 0.5 * (initial_lr - eta_min) * (1 + cos(pi * epoch / t_max))`
/// for `epoch < t_max`, and `eta_min` from `t_max` onward (no warm restart).
fn cosine_annealed_lr(initial_lr: f32, eta_min: f32, epoch: usize, t_max: usize) -> f32 {
    if epoch >= t_max {
        return eta_min;
    }
    let angle = std::f32::consts::PI * epoch as f32 / t_max as f32;
    eta_min + 0.5 * (initial_lr - eta_min) * (1.0 + angle.cos())
}