/// A learning-rate schedule driven by a step counter.
///
/// Implementors map the current training step to the learning rate that
/// should be used at that step. The trait is object-safe, so schedules can be
/// stored as `Box<dyn Scheduler>`.
pub trait Scheduler {
    /// Returns the learning rate to use at `step`.
    fn lr(&self, step: usize) -> f64;
}
/// Step-wise exponential decay schedule.
///
/// The learning rate is multiplied by `gamma` once every `step_size` steps:
/// `lr(step) = base_lr * gamma^(step / step_size)` (integer division).
#[derive(Debug, Clone)]
pub struct StepDecay {
    base_lr: f64,
    step_size: usize,
    gamma: f64,
}

impl StepDecay {
    /// Creates a step-decay schedule starting at `base_lr` and multiplying by
    /// `gamma` every `step_size` steps.
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is zero: the number of elapsed decay periods
    /// (`step / step_size`) would be undefined. Failing here surfaces the
    /// mistake at construction instead of on the first `lr` call.
    pub fn new(base_lr: f64, step_size: usize, gamma: f64) -> Self {
        assert!(step_size > 0, "StepDecay: step_size must be greater than zero");
        StepDecay {
            base_lr,
            step_size,
            gamma,
        }
    }

    /// Returns the learning rate at `step`.
    pub fn lr(&self, step: usize) -> f64 {
        let decays = step / self.step_size;
        // `powi` takes an `i32`. A plain `as` cast would wrap large decay
        // counts (e.g. 2^31 -> i32::MIN, giving an infinite rate, or
        // 2^32 -> 0, resetting to `base_lr`). For counts that don't fit, use
        // `powf` so the schedule keeps decaying monotonically.
        let factor = if decays <= i32::MAX as usize {
            self.gamma.powi(decays as i32)
        } else {
            self.gamma.powf(decays as f64)
        };
        self.base_lr * factor
    }
}
impl Scheduler for StepDecay {
    /// Forwards to the inherent [`StepDecay::lr`]; inherent associated items
    /// take priority over trait items, so this does not recurse.
    fn lr(&self, step: usize) -> f64 {
        Self::lr(self, step)
    }
}
/// Cosine-annealing schedule from `base_lr` down to `min_lr`.
///
/// Over `total_steps` steps the rate follows half a cosine period. At and
/// after `total_steps` it stays at `min_lr`.
#[derive(Debug, Clone)]
pub struct CosineScheduler {
    base_lr: f64,
    min_lr: f64,
    total_steps: usize,
}

impl CosineScheduler {
    /// Creates a cosine schedule that anneals from `base_lr` to `min_lr`
    /// over `total_steps` steps.
    pub fn new(base_lr: f64, min_lr: f64, total_steps: usize) -> Self {
        CosineScheduler {
            base_lr,
            min_lr,
            total_steps,
        }
    }

    /// Returns the learning rate at `step`:
    /// `min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * step / total_steps))`.
    ///
    /// Steps at or beyond `total_steps` return `min_lr`. This includes every
    /// step when `total_steps == 0`, which would otherwise compute the
    /// progress fraction as `0.0 / 0.0` and yield NaN.
    pub fn lr(&self, step: usize) -> f64 {
        if step >= self.total_steps {
            // End of the schedule: cos(pi) == -1, so the formula reduces to min_lr.
            return self.min_lr;
        }
        let t = step as f64 / self.total_steps as f64;
        self.min_lr
            + 0.5 * (self.base_lr - self.min_lr) * (1.0 + (t * std::f64::consts::PI).cos())
    }
}
impl Scheduler for CosineScheduler {
    /// Forwards to the inherent [`CosineScheduler::lr`]; inherent associated
    /// items take priority over trait items, so this does not recurse.
    fn lr(&self, step: usize) -> f64 {
        Self::lr(self, step)
    }
}
pub struct WarmupScheduler<S: Scheduler> {
inner: S,
target_lr: f64,
warmup_steps: usize,
}
impl<S: Scheduler> WarmupScheduler<S> {
pub fn new(inner: S, target_lr: f64, warmup_steps: usize) -> Self {
WarmupScheduler {
inner,
target_lr,
warmup_steps,
}
}
pub fn lr(&self, step: usize) -> f64 {
if step < self.warmup_steps {
self.target_lr * (step as f64 + 1.0) / (self.warmup_steps as f64)
} else {
self.inner.lr(step - self.warmup_steps)
}
}
}
impl<S: Scheduler> Scheduler for WarmupScheduler<S> {
    /// Forwards to the inherent [`WarmupScheduler::lr`]; inherent associated
    /// items take priority over trait items, so this does not recurse.
    fn lr(&self, step: usize) -> f64 {
        Self::lr(self, step)
    }
}
/// Reduces the learning rate when a monitored metric stops improving.
///
/// Lower metric values count as improvements. After `patience` consecutive
/// observations without improvement, the rate is multiplied by `factor`
/// (never going below `min_lr`) and the counter restarts.
pub struct PlateauScheduler {
    // Consecutive non-improving observations tolerated before reducing.
    patience: usize,
    // Multiplier applied to the rate on each reduction.
    factor: f64,
    // Lower clamp applied whenever the rate is reduced.
    min_lr: f64,
    // Rate currently in effect.
    current_lr: f64,
    // Lowest metric seen so far (starts at +infinity).
    best: f64,
    // Non-improving observations since the last improvement or reduction.
    wait: usize,
}

impl PlateauScheduler {
    /// Creates a plateau scheduler starting at `base_lr`.
    pub fn new(base_lr: f64, patience: usize, factor: f64, min_lr: f64) -> Self {
        Self {
            current_lr: base_lr,
            best: f64::INFINITY,
            wait: 0,
            patience,
            factor,
            min_lr,
        }
    }

    /// Records `metric` and returns the (possibly reduced) learning rate.
    ///
    /// A metric strictly below the best seen so far resets the counter. Any
    /// other value, including NaN (which never compares less), counts as a
    /// non-improving observation.
    pub fn observe(&mut self, metric: f64) -> f64 {
        if metric < self.best {
            self.best = metric;
            self.wait = 0;
            return self.current_lr;
        }

        self.wait += 1;
        if self.wait >= self.patience {
            self.wait = 0;
            self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
        }
        self.current_lr
    }

    /// Returns the learning rate currently in effect, without observing.
    pub fn lr(&self) -> f64 {
        self.current_lr
    }
}