/// A learning-rate schedule: maps a training step index (starting at 0) to a
/// learning rate.
///
/// The `Send + Sync` supertraits mean a schedule, including a `dyn Scheduler`
/// trait object, may be moved to and shared between threads.
pub trait Scheduler: Send + Sync {
    /// Returns the learning rate to use at training step `step`.
    fn lr(&self, step: usize) -> f64;
}
/// Step-decay learning-rate schedule.
///
/// The learning rate is multiplied by `gamma` once every `step_size` steps:
/// `lr(step) = base_lr * gamma^(step / step_size)` (integer division).
#[derive(Debug, Clone)]
pub struct StepDecay {
    base_lr: f64,
    step_size: usize,
    gamma: f64,
}

impl StepDecay {
    /// Creates a schedule that starts at `base_lr` and is multiplied by `gamma`
    /// every `step_size` steps.
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is zero: the decay interval would be undefined and
    /// [`StepDecay::lr`] would have to divide by zero.
    pub fn new(base_lr: f64, step_size: usize, gamma: f64) -> Self {
        assert!(step_size > 0, "StepDecay: step_size must be greater than zero");
        StepDecay {
            base_lr,
            step_size,
            gamma,
        }
    }

    /// Returns the learning rate for `step`.
    pub fn lr(&self, step: usize) -> f64 {
        let decays = step / self.step_size;
        // A bare `decays as i32` wraps to a negative exponent once `decays` exceeds
        // `i32::MAX` (e.g. gamma^i32::MIN), so counts that do not fit use `powf`.
        let factor = if decays <= i32::MAX as usize {
            self.gamma.powi(decays as i32)
        } else {
            self.gamma.powf(decays as f64)
        };
        self.base_lr * factor
    }
}

impl Scheduler for StepDecay {
    fn lr(&self, step: usize) -> f64 {
        StepDecay::lr(self, step)
    }
}
/// Cosine-annealing schedule from `base_lr` down to `min_lr` over `total_steps` steps.
///
/// Steps at or beyond `total_steps` return `min_lr`.
#[derive(Debug, Clone)]
pub struct CosineScheduler {
    base_lr: f64,
    min_lr: f64,
    total_steps: usize,
}

impl CosineScheduler {
    /// Creates a schedule annealing from `base_lr` to `min_lr` over `total_steps`.
    ///
    /// A `total_steps` of zero describes a schedule that is already complete, so
    /// every step returns `min_lr`.
    pub fn new(base_lr: f64, min_lr: f64, total_steps: usize) -> Self {
        CosineScheduler {
            base_lr,
            min_lr,
            total_steps,
        }
    }

    /// Returns the learning rate for `step`.
    pub fn lr(&self, step: usize) -> f64 {
        // Without this guard the progress fraction below would be 0.0 / 0.0 = NaN.
        if self.total_steps == 0 {
            return self.min_lr;
        }
        let t = (step.min(self.total_steps) as f64) / (self.total_steps as f64);
        self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1.0 + (t * std::f64::consts::PI).cos())
    }
}

impl Scheduler for CosineScheduler {
    fn lr(&self, step: usize) -> f64 {
        CosineScheduler::lr(self, step)
    }
}
/// Linear warm-up wrapper around another [`Scheduler`].
///
/// For steps `0..warmup_steps` the rate is `target_lr * (step + 1) / warmup_steps`,
/// so the last warm-up step reaches `target_lr`. From `warmup_steps` on, `inner` is
/// queried with the step re-based to zero (`inner.lr(step - warmup_steps)`).
pub struct WarmupScheduler<S: Scheduler> {
    inner: S,
    target_lr: f64,
    warmup_steps: usize,
}

impl<S: Scheduler> WarmupScheduler<S> {
    /// Wraps `inner` with a linear warm-up of `warmup_steps` steps ending at `target_lr`.
    pub fn new(inner: S, target_lr: f64, warmup_steps: usize) -> Self {
        Self {
            inner,
            target_lr,
            warmup_steps,
        }
    }

    /// Returns the learning rate for `step`.
    pub fn lr(&self, step: usize) -> f64 {
        // `checked_sub` yields `None` exactly while still inside the warm-up window.
        match step.checked_sub(self.warmup_steps) {
            Some(after_warmup) => self.inner.lr(after_warmup),
            None => {
                let completed = step as f64 + 1.0;
                self.target_lr * completed / (self.warmup_steps as f64)
            }
        }
    }
}

impl<S: Scheduler> Scheduler for WarmupScheduler<S> {
    fn lr(&self, step: usize) -> f64 {
        WarmupScheduler::lr(self, step)
    }
}
/// Reduce-on-plateau schedule driven by an observed metric where lower is better.
///
/// Each call to [`PlateauScheduler::observe`] either records a new best metric
/// (resetting the wait counter) or counts one non-improving observation. When the
/// counter reaches `patience`, the learning rate is multiplied by `factor`, floored
/// at `min_lr`, and the counter restarts from zero. A NaN metric never compares as
/// an improvement, so it counts as non-improving.
pub struct PlateauScheduler {
    patience: usize,
    factor: f64,
    min_lr: f64,
    current_lr: f64,
    best: f64,
    wait: usize,
}

impl PlateauScheduler {
    /// Creates a scheduler starting at `base_lr` with no best metric recorded yet.
    pub fn new(
        base_lr: f64,
        patience: usize,
        factor: f64,
        min_lr: f64,
    ) -> Self {
        Self {
            patience,
            factor,
            min_lr,
            current_lr: base_lr,
            best: f64::INFINITY,
            wait: 0,
        }
    }

    /// Feeds one metric observation and returns the (possibly reduced) learning rate.
    pub fn observe(&mut self, metric: f64) -> f64 {
        let improved = metric < self.best;
        if improved {
            self.best = metric;
            self.wait = 0;
            return self.current_lr;
        }

        self.wait += 1;
        if self.wait >= self.patience {
            let reduced = self.current_lr * self.factor;
            self.current_lr = reduced.max(self.min_lr);
            self.wait = 0;
        }
        self.current_lr
    }

    /// Returns the current learning rate without observing anything.
    pub fn lr(&self) -> f64 {
        self.current_lr
    }
}
/// Exponential-decay schedule: `lr(step) = base_lr * gamma^step`.
#[derive(Debug, Clone)]
pub struct ExponentialLR {
    base_lr: f64,
    gamma: f64,
}

impl ExponentialLR {
    /// Creates a schedule starting at `base_lr` and multiplied by `gamma` every step.
    pub fn new(base_lr: f64, gamma: f64) -> Self {
        ExponentialLR { base_lr, gamma }
    }

    /// Returns the learning rate for `step`.
    pub fn lr(&self, step: usize) -> f64 {
        // A bare `step as i32` wraps to a negative exponent once `step` exceeds
        // `i32::MAX` (gamma^i32::MIN overflows to infinity for gamma < 1), so
        // steps that do not fit use `powf`.
        let factor = if step <= i32::MAX as usize {
            self.gamma.powi(step as i32)
        } else {
            self.gamma.powf(step as f64)
        };
        self.base_lr * factor
    }
}

impl Scheduler for ExponentialLR {
    fn lr(&self, step: usize) -> f64 {
        ExponentialLR::lr(self, step)
    }
}
/// Multi-step decay: the learning rate is multiplied by `gamma` at each milestone.
///
/// `lr(step) = base_lr * gamma^k`, where `k` is the number of milestones `<= step`.
/// Milestones are sorted on construction and duplicates are kept, so a milestone
/// listed twice applies `gamma` twice.
pub struct MultiStepLR {
    base_lr: f64,
    milestones: Vec<usize>,
    gamma: f64,
}

impl MultiStepLR {
    /// Creates a schedule from `base_lr`, the decay `milestones` (any order) and `gamma`.
    pub fn new(base_lr: f64, milestones: &[usize], gamma: f64) -> Self {
        let mut sorted = milestones.to_vec();
        sorted.sort_unstable();
        Self {
            base_lr,
            milestones: sorted,
            gamma,
        }
    }

    /// Returns the learning rate for `step`.
    pub fn lr(&self, step: usize) -> f64 {
        // `milestones` is sorted ascending, so every milestone already reached sits
        // in the prefix where `m <= step` holds.
        let reached = self.milestones.partition_point(|&m| m <= step);
        self.base_lr * self.gamma.powi(reached as i32)
    }
}

impl Scheduler for MultiStepLR {
    fn lr(&self, step: usize) -> f64 {
        MultiStepLR::lr(self, step)
    }
}
/// One-cycle schedule.
///
/// Warms up linearly from `max_lr / 25` to `max_lr` over the first `warmup_steps`
/// steps, then cosine-anneals back down to `max_lr / 25` at `total_steps`. Steps
/// past `total_steps` are clamped, so the schedule holds its final value.
pub struct OneCycleLR {
    max_lr: f64,
    total_steps: usize,
    warmup_steps: usize,
}

impl OneCycleLR {
    /// Ratio between the peak learning rate and the starting/final learning rate.
    const DIV_FACTOR: f64 = 25.0;
    /// Fraction of `total_steps` spent warming up when built via [`OneCycleLR::new`].
    const DEFAULT_WARMUP_FRAC: f64 = 0.3;

    /// Creates a one-cycle schedule that spends 30% of `total_steps` warming up.
    pub fn new(max_lr: f64, total_steps: usize) -> Self {
        Self::with_warmup_frac(max_lr, total_steps, Self::DEFAULT_WARMUP_FRAC)
    }

    /// Creates a one-cycle schedule whose warm-up covers `warmup_frac` of
    /// `total_steps`. The fraction is clamped to `[0, 1]` and the step count rounded.
    pub fn with_warmup_frac(max_lr: f64, total_steps: usize, warmup_frac: f64) -> Self {
        let frac = warmup_frac.clamp(0.0, 1.0);
        Self {
            max_lr,
            total_steps,
            warmup_steps: (total_steps as f64 * frac).round() as usize,
        }
    }

    /// Returns the learning rate for `step`.
    pub fn lr(&self, step: usize) -> f64 {
        let floor = self.max_lr / Self::DIV_FACTOR;
        let span = self.max_lr - floor;
        let step = step.min(self.total_steps);
        match step.checked_sub(self.warmup_steps) {
            // Still warming up: linear ramp (`.max(1)` guards a zero-length warm-up).
            None => {
                let frac = step as f64 / self.warmup_steps.max(1) as f64;
                floor + frac * span
            }
            // Annealing: cosine from the peak back down to `floor`.
            Some(elapsed) => {
                let decay_steps = self.total_steps.saturating_sub(self.warmup_steps).max(1);
                let t = elapsed as f64 / decay_steps as f64;
                floor + 0.5 * span * (1.0 + (t * std::f64::consts::PI).cos())
            }
        }
    }
}

impl Scheduler for OneCycleLR {
    fn lr(&self, step: usize) -> f64 {
        OneCycleLR::lr(self, step)
    }
}
/// Triangular cyclic schedule oscillating between `base_lr` and `max_lr`.
///
/// Each cycle rises linearly from `base_lr` to `max_lr` over `step_size_up` steps,
/// then falls linearly back over `step_size_down` steps, and repeats.
#[derive(Debug, Clone)]
pub struct CyclicLR {
    base_lr: f64,
    max_lr: f64,
    step_size_up: usize,
    step_size_down: usize,
}

impl CyclicLR {
    /// Creates a symmetric cycle with `step_size` steps up and `step_size` steps down.
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is zero (the cycle would have no length).
    pub fn new(base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        Self::asymmetric(base_lr, max_lr, step_size, step_size)
    }

    /// Creates a cycle with separate rise (`step_size_up`) and fall
    /// (`step_size_down`) lengths. A zero-length rise starts each cycle at `max_lr`.
    ///
    /// # Panics
    ///
    /// Panics if both lengths are zero: the cycle would have no length and
    /// [`CyclicLR::lr`] would have to take a remainder by zero.
    pub fn asymmetric(base_lr: f64, max_lr: f64, step_size_up: usize, step_size_down: usize) -> Self {
        assert!(
            step_size_up + step_size_down > 0,
            "CyclicLR: step_size_up + step_size_down must be greater than zero"
        );
        CyclicLR { base_lr, max_lr, step_size_up, step_size_down }
    }

    /// Returns the learning rate for `step`.
    pub fn lr(&self, step: usize) -> f64 {
        let cycle_len = self.step_size_up + self.step_size_down;
        let pos = step % cycle_len;
        let scale = if pos <= self.step_size_up {
            if self.step_size_up == 0 {
                // Zero-length rise: the cycle begins at the peak (avoids 0 / 0 = NaN).
                1.0
            } else {
                pos as f64 / self.step_size_up as f64
            }
        } else {
            1.0 - (pos - self.step_size_up) as f64 / self.step_size_down as f64
        };
        self.base_lr + (self.max_lr - self.base_lr) * scale
    }
}

impl Scheduler for CyclicLR {
    fn lr(&self, step: usize) -> f64 {
        CyclicLR::lr(self, step)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- StepDecay ----

    #[test]
    fn test_step_decay() {
        let sched = StepDecay::new(0.1, 10, 0.5);
        assert!((sched.lr(0) - 0.1).abs() < 1e-10);
        assert!((sched.lr(9) - 0.1).abs() < 1e-10);
        assert!((sched.lr(10) - 0.05).abs() < 1e-10);
        assert!((sched.lr(20) - 0.025).abs() < 1e-10);
    }

    #[test]
    fn test_step_decay_scheduler_trait() {
        let sched = StepDecay::new(0.1, 5, 0.5);
        let s: &dyn Scheduler = &sched;
        assert!((s.lr(5) - 0.05).abs() < 1e-10);
    }

    // ---- CosineScheduler ----

    #[test]
    fn test_cosine_scheduler() {
        let sched = CosineScheduler::new(0.1, 0.001, 100);
        assert!((sched.lr(0) - 0.1).abs() < 1e-5);
        assert!((sched.lr(100) - 0.001).abs() < 1e-5);
        let mid = sched.lr(50);
        assert!(mid > 0.04 && mid < 0.06, "mid={}", mid);
    }

    #[test]
    fn test_cosine_clamps_after_total_steps() {
        let sched = CosineScheduler::new(0.1, 0.001, 100);
        assert_eq!(sched.lr(250), sched.lr(100));
        let s: &dyn Scheduler = &sched;
        assert_eq!(s.lr(40), sched.lr(40));
    }

    // ---- ExponentialLR ----

    #[test]
    fn test_exponential_lr() {
        let sched = ExponentialLR::new(0.1, 0.9);
        assert!((sched.lr(0) - 0.1).abs() < 1e-10);
        assert!((sched.lr(1) - 0.09).abs() < 1e-10);
        assert!((sched.lr(10) - 0.1 * 0.9f64.powi(10)).abs() < 1e-10);
    }

    #[test]
    fn test_exponential_lr_scheduler_trait() {
        let sched = ExponentialLR::new(0.1, 0.95);
        let s: &dyn Scheduler = &sched;
        assert!((s.lr(0) - 0.1).abs() < 1e-10);
    }

    // ---- MultiStepLR ----

    #[test]
    fn test_multi_step_lr() {
        let sched = MultiStepLR::new(0.1, &[30, 60, 90], 0.1);
        assert!((sched.lr(0) - 0.1).abs() < 1e-10);
        assert!((sched.lr(29) - 0.1).abs() < 1e-10);
        assert!((sched.lr(30) - 0.01).abs() < 1e-10);
        assert!((sched.lr(59) - 0.01).abs() < 1e-10);
        assert!((sched.lr(60) - 0.001).abs() < 1e-10);
        assert!((sched.lr(89) - 0.001).abs() < 1e-10);
        assert!((sched.lr(90) - 0.0001).abs() < 1e-10);
    }

    #[test]
    fn test_multi_step_lr_unsorted_milestones() {
        let sched = MultiStepLR::new(0.1, &[60, 30, 90], 0.5);
        assert!((sched.lr(29) - 0.1).abs() < 1e-10);
        assert!((sched.lr(30) - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_multi_step_lr_scheduler_trait() {
        let sched = MultiStepLR::new(0.1, &[10], 0.5);
        let s: &dyn Scheduler = &sched;
        assert!((s.lr(10) - 0.05).abs() < 1e-10);
    }

    // ---- OneCycleLR ----

    #[test]
    fn test_one_cycle_lr_shape() {
        let sched = OneCycleLR::new(0.01, 100);
        let min_lr = 0.01 / 25.0;
        assert!((sched.lr(0) - min_lr).abs() < 1e-8, "start={}", sched.lr(0));
        let peak = sched.lr(30);
        assert!((peak - 0.01).abs() < 1e-6, "peak={}", peak);
        let end = sched.lr(100);
        assert!((end - min_lr).abs() < 1e-6, "end={}", end);
    }

    #[test]
    fn test_one_cycle_lr_monotonic_warmup() {
        let sched = OneCycleLR::new(0.01, 100);
        let mut prev = 0.0;
        for step in 0..=30 {
            let lr = sched.lr(step);
            assert!(lr >= prev, "LR should increase during warmup: step={}, lr={}, prev={}", step, lr, prev);
            prev = lr;
        }
    }

    #[test]
    fn test_one_cycle_lr_monotonic_decay() {
        let sched = OneCycleLR::new(0.01, 100);
        let mut prev = f64::MAX;
        for step in 30..=100 {
            let lr = sched.lr(step);
            assert!(lr <= prev + 1e-10, "LR should decrease during decay: step={}, lr={}, prev={}", step, lr, prev);
            prev = lr;
        }
    }

    #[test]
    fn test_one_cycle_lr_custom_warmup() {
        let sched = OneCycleLR::with_warmup_frac(0.01, 100, 0.1);
        let peak = sched.lr(10);
        assert!((peak - 0.01).abs() < 1e-6, "peak={}", peak);
    }

    #[test]
    fn test_one_cycle_lr_scheduler_trait() {
        let sched = OneCycleLR::new(0.01, 100);
        let s: &dyn Scheduler = &sched;
        assert!(s.lr(30) > s.lr(0));
    }

    #[test]
    fn test_one_cycle_lr_clamps_after_total_steps() {
        let sched = OneCycleLR::new(0.01, 100);
        assert_eq!(sched.lr(500), sched.lr(100));
    }

    // ---- PlateauScheduler ----

    #[test]
    fn test_plateau_scheduler() {
        let mut sched = PlateauScheduler::new(0.1, 3, 0.5, 1e-6);
        assert!((sched.observe(1.0) - 0.1).abs() < 1e-10);
        assert!((sched.observe(1.1) - 0.1).abs() < 1e-10);
        assert!((sched.observe(1.2) - 0.1).abs() < 1e-10);
        assert!((sched.observe(1.3) - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_plateau_improvement_resets_wait() {
        let mut sched = PlateauScheduler::new(0.1, 2, 0.5, 0.0);
        assert!((sched.observe(1.0) - 0.1).abs() < 1e-10);
        assert!((sched.observe(1.1) - 0.1).abs() < 1e-10);
        // A new best resets the counter, so one more bad epoch is not enough.
        assert!((sched.observe(0.9) - 0.1).abs() < 1e-10);
        assert!((sched.observe(1.0) - 0.1).abs() < 1e-10);
        assert!((sched.observe(1.0) - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_plateau_respects_min_lr() {
        let mut sched = PlateauScheduler::new(0.1, 1, 0.1, 0.05);
        assert!((sched.observe(1.0) - 0.1).abs() < 1e-10);
        assert!((sched.observe(2.0) - 0.05).abs() < 1e-10);
        assert!((sched.observe(3.0) - 0.05).abs() < 1e-10);
        assert!((sched.lr() - 0.05).abs() < 1e-10);
    }

    // ---- WarmupScheduler ----

    #[test]
    fn test_warmup_cosine() {
        let inner = CosineScheduler::new(0.1, 0.001, 90);
        let sched = WarmupScheduler::new(inner, 0.1, 10);
        assert!(sched.lr(0) < 0.02);
        assert!((sched.lr(9) - 0.1).abs() < 1e-5);
        assert!((sched.lr(10) - 0.1).abs() < 0.01);
    }

    #[test]
    fn test_warmup_linear_ramp() {
        let sched = WarmupScheduler::new(ExponentialLR::new(0.1, 0.5), 0.1, 4);
        assert!((sched.lr(0) - 0.025).abs() < 1e-10);
        assert!((sched.lr(1) - 0.05).abs() < 1e-10);
        assert!((sched.lr(3) - 0.1).abs() < 1e-10);
        // After warm-up the inner schedule restarts from its own step 0.
        assert!((sched.lr(4) - 0.1).abs() < 1e-10);
        assert!((sched.lr(5) - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_warmup_zero_steps_delegates() {
        let sched = WarmupScheduler::new(ExponentialLR::new(0.1, 0.5), 0.1, 0);
        assert!((sched.lr(0) - 0.1).abs() < 1e-10);
        assert!((sched.lr(2) - 0.025).abs() < 1e-10);
    }

    // ---- CyclicLR ----

    #[test]
    fn test_cyclic_lr() {
        let sched = CyclicLR::new(0.001, 0.01, 10);
        assert!((sched.lr(0) - 0.001).abs() < 1e-6);
        assert!((sched.lr(5) - 0.0055).abs() < 1e-4);
        assert!((sched.lr(10) - 0.01).abs() < 1e-6);
        assert!((sched.lr(15) - 0.0055).abs() < 1e-4);
        assert!((sched.lr(20) - 0.001).abs() < 1e-6);
    }

    #[test]
    fn test_cyclic_lr_asymmetric() {
        let sched = CyclicLR::asymmetric(0.0, 1.0, 2, 4);
        assert!((sched.lr(0) - 0.0).abs() < 1e-10);
        assert!((sched.lr(1) - 0.5).abs() < 1e-10);
        assert!((sched.lr(2) - 1.0).abs() < 1e-10);
        assert!((sched.lr(4) - 0.5).abs() < 1e-10);
        assert!((sched.lr(5) - 0.25).abs() < 1e-10);
        assert!((sched.lr(6) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_cyclic_lr_scheduler_trait() {
        let sched = CyclicLR::new(0.001, 0.01, 10);
        let s: &dyn Scheduler = &sched;
        assert!((s.lr(10) - 0.01).abs() < 1e-10);
    }
}