use core::f64::consts::PI;
/// A learning-rate schedule: a mapping from a training epoch to the learning
/// rate that should be used during that epoch.
///
/// The method takes `&self`, so evaluating a schedule does not require mutable
/// access to it.
pub trait Scheduler {
    /// Returns the learning rate for `epoch`, where the first epoch is `0`.
    fn lr(&self, epoch: usize) -> f64;
}
/// A schedule that ignores the epoch and always returns the same learning rate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ConstantLr {
    /// The learning rate returned for every epoch.
    pub lr: f64,
}

impl Scheduler for ConstantLr {
    fn lr(&self, _epoch: usize) -> f64 {
        self.lr
    }
}
/// Step decay: the learning rate is multiplied by `gamma` once every
/// `step_size` epochs.
///
/// `lr(epoch) = initial_lr * gamma^(epoch / step_size)` using integer division.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StepLr {
    /// Learning rate used for epochs `0..step_size`.
    pub initial_lr: f64,
    /// Number of epochs between successive decays. Must be non-zero.
    pub step_size: usize,
    /// Multiplicative factor applied at each decay step.
    pub gamma: f64,
}

impl Scheduler for StepLr {
    /// # Panics
    ///
    /// Panics if `step_size` is zero.
    fn lr(&self, epoch: usize) -> f64 {
        // Integer division by zero would panic anyway; fail with a message that
        // names the misconfigured field instead of "attempt to divide by zero".
        assert!(self.step_size > 0, "StepLr::step_size must be non-zero");
        // Saturate rather than `as i32`, which truncates/wraps for enormous step
        // counts (e.g. `usize::MAX as i32 == -1`) and would flip decay into growth.
        let steps = (epoch / self.step_size).min(i32::MAX as usize) as i32;
        self.initial_lr * self.gamma.powi(steps)
    }
}
/// Cosine annealing from `initial_lr` down to `min_lr` over `total_epochs` epochs.
///
/// For `epoch < total_epochs` the rate is
/// `min_lr + (initial_lr - min_lr) * (1 + cos(pi * epoch / total_epochs)) / 2`.
/// From `total_epochs` onward it stays at `min_lr`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CosineAnnealingLr {
    /// Learning rate at epoch 0.
    pub initial_lr: f64,
    /// Floor reached at `total_epochs` and held afterwards.
    pub min_lr: f64,
    /// Length of the annealing phase in epochs. Zero means the schedule is
    /// always at `min_lr`.
    pub total_epochs: usize,
}

impl Scheduler for CosineAnnealingLr {
    fn lr(&self, epoch: usize) -> f64 {
        // Past the end, the raw formula climbs back toward `initial_lr` because
        // cos is periodic. With `total_epochs == 0` it would compute 0/0 = NaN.
        // Hold the floor in both cases.
        if epoch >= self.total_epochs {
            return self.min_lr;
        }
        let t = epoch as f64 / self.total_epochs as f64;
        self.min_lr + (self.initial_lr - self.min_lr) * 0.5 * (1.0 + (PI * t).cos())
    }
}
/// Linear warmup followed by cosine annealing.
///
/// * `epoch < warmup_epochs`: rises linearly as
///   `initial_lr * epoch / warmup_epochs`, so epoch 0 yields `0.0`.
/// * `warmup_epochs <= epoch < total_epochs`: cosine-anneals from `initial_lr`
///   to `min_lr` over the remaining `total_epochs - warmup_epochs` epochs.
/// * Once warmup is over and `epoch >= total_epochs`: holds at `min_lr`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WarmupCosine {
    /// Peak learning rate, reached at the end of warmup.
    pub initial_lr: f64,
    /// Floor reached at the end of the decay phase and held afterwards.
    pub min_lr: f64,
    /// Number of linear-warmup epochs.
    pub warmup_epochs: usize,
    /// Total schedule length (warmup + decay) in epochs.
    pub total_epochs: usize,
}

impl Scheduler for WarmupCosine {
    fn lr(&self, epoch: usize) -> f64 {
        if epoch < self.warmup_epochs {
            return self.initial_lr * (epoch as f64 / self.warmup_epochs as f64);
        }
        // `saturating_sub` avoids an underflow (panic in debug, wrap-around in
        // release) when `total_epochs < warmup_epochs`. Such a schedule simply
        // has no decay phase.
        let decay_epochs = self.total_epochs.saturating_sub(self.warmup_epochs);
        let progress = epoch - self.warmup_epochs;
        // Checking this before dividing also rules out 0/0 when `decay_epochs == 0`.
        // It also stops the cosine from rising again past `total_epochs`.
        if progress >= decay_epochs {
            return self.min_lr;
        }
        let t = progress as f64 / decay_epochs as f64;
        self.min_lr + (self.initial_lr - self.min_lr) * 0.5 * (1.0 + (PI * t).cos())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const EPS: f64 = 1e-12;

    /// Asserts that `actual` lies strictly within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPS);
    }

    #[test]
    fn constant_lr_returns_fixed() {
        let sched = ConstantLr { lr: 0.01 };
        for &epoch in &[0, 100] {
            assert_close(sched.lr(epoch), 0.01);
        }
    }

    #[test]
    fn step_lr_halves_every_10() {
        let sched = StepLr {
            initial_lr: 1.0,
            step_size: 10,
            gamma: 0.5,
        };
        let expectations = [
            (0, 1.0),
            (9, 1.0),
            (10, 0.5),
            (19, 0.5),
            (20, 0.25),
            (30, 0.125),
        ];
        for &(epoch, want) in &expectations {
            assert_close(sched.lr(epoch), want);
        }
    }

    #[test]
    fn cosine_annealing_starts_high_ends_at_min() {
        let sched = CosineAnnealingLr {
            initial_lr: 1.0,
            min_lr: 0.0,
            total_epochs: 100,
        };
        for &(epoch, want) in &[(0, 1.0), (50, 0.5), (100, 0.0)] {
            assert_close(sched.lr(epoch), want);
        }
    }

    #[test]
    fn cosine_annealing_with_nonzero_min() {
        let sched = CosineAnnealingLr {
            initial_lr: 1.0,
            min_lr: 0.1,
            total_epochs: 100,
        };
        for &(epoch, want) in &[(0, 1.0), (100, 0.1)] {
            assert_close(sched.lr(epoch), want);
        }
    }

    #[test]
    fn warmup_cosine_increases_then_decreases() {
        let sched = WarmupCosine {
            initial_lr: 1.0,
            min_lr: 0.0,
            warmup_epochs: 10,
            total_epochs: 110,
        };
        let checkpoints = [(0, 0.0), (5, 0.5), (10, 1.0), (60, 0.5), (110, 0.0)];
        for &(epoch, want) in &checkpoints {
            assert_close(sched.lr(epoch), want);
        }
        // Rising during warmup, falling during the cosine phase.
        let (early, late) = (sched.lr(3), sched.lr(7));
        assert!(early < late);
        let (decay_start, decay_end) = (sched.lr(20), sched.lr(80));
        assert!(decay_start > decay_end);
    }
}