use crate::config::LrScheduler;
/// Learning-rate schedule with a linear warmup followed by a configurable decay.
///
/// The rate ramps linearly from `0` to `base_lr` over the first `warmup_steps`
/// steps, then follows `scheduler_type` over the remaining
/// `total_steps - warmup_steps` steps (at least one step).
pub struct LrSchedule {
    /// Peak learning rate, reached at the end of warmup.
    base_lr: f64,
    /// Number of steps in the linear warmup ramp (may be 0).
    warmup_steps: u64,
    /// Step at which the `Linear`/`Cosine` decay reaches zero
    /// (when greater than `warmup_steps`).
    total_steps: u64,
    /// Shape of the post-warmup decay.
    scheduler_type: LrScheduler,
}

impl LrSchedule {
    /// Creates a schedule from its peak rate, warmup length, total length and decay shape.
    ///
    /// No validation is performed. A `total_steps` that is not larger than
    /// `warmup_steps` is tolerated and treated as a one-step decay phase.
    pub fn new(
        base_lr: f64,
        warmup_steps: u64,
        total_steps: u64,
        scheduler_type: LrScheduler,
    ) -> Self {
        Self {
            base_lr,
            warmup_steps,
            total_steps,
            scheduler_type,
        }
    }

    /// Returns the learning rate to use at `step`.
    ///
    /// * `step == 0` yields `0.0`.
    /// * `1..=warmup_steps` ramps linearly: `base_lr * step / warmup_steps`.
    /// * After warmup, let `decay_step = step - warmup_steps`,
    ///   `decay_total = max(total_steps - warmup_steps, 1)` (saturating) and
    ///   `progress = min(decay_step / decay_total, 1)`:
    ///   * `Constant`: `base_lr`.
    ///   * `Linear`: `base_lr * (1 - progress)`.
    ///   * `Cosine`: `base_lr * 0.5 * (1 + cos(pi * progress))`.
    ///   * `CosineWarmRestarts`: a cosine curve over a period of
    ///     `t_0 = max(decay_total / 2, 1)` steps that restarts at `base_lr`
    ///     at the start of every period. It is not clamped at `total_steps`.
    ///
    /// Because `progress` saturates at 1, `Linear` and `Cosine` stay at `0.0`
    /// for every step at or beyond `total_steps` instead of rising again.
    pub fn get_lr(&self, step: u64) -> f64 {
        if step == 0 {
            return 0.0;
        }
        if step <= self.warmup_steps {
            // Here `step >= 1`, so `warmup_steps >= 1`; `max(1)` is purely defensive.
            return self.base_lr * (step as f64 / self.warmup_steps.max(1) as f64);
        }

        let decay_step = step - self.warmup_steps;
        // `saturating_sub` + `max(1)` keep the divisor positive even when
        // `total_steps <= warmup_steps`.
        let decay_total = self.total_steps.saturating_sub(self.warmup_steps).max(1);
        // Clamp so the one-shot decays bottom out at zero past `total_steps`;
        // unclamped, cos(pi * progress) climbs back towards 1 for progress in (1, 2].
        let progress = (decay_step as f64 / decay_total as f64).min(1.0);

        match self.scheduler_type {
            LrScheduler::Constant => self.base_lr,
            LrScheduler::Linear => self.base_lr * (1.0 - progress),
            LrScheduler::Cosine => {
                self.base_lr * 0.5 * (1.0 + (std::f64::consts::PI * progress).cos())
            }
            LrScheduler::CosineWarmRestarts => {
                // The period is half the decay phase (at least one step); the
                // rate jumps back to `base_lr` whenever `t_cur` wraps to 0.
                let t_0 = (decay_total as f64 / 2.0).max(1.0);
                let t_cur = decay_step as f64 % t_0;
                self.base_lr * 0.5 * (1.0 + (std::f64::consts::PI * t_cur / t_0).cos())
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// Warmup ramps to the peak, after which `Constant` holds it.
    #[test]
    fn test_constant_schedule() {
        let schedule = LrSchedule::new(1e-4, 10, 100, LrScheduler::Constant);
        assert_close(schedule.get_lr(5), 5e-5, 1e-10);
        for step in [10, 50, 100] {
            assert_close(schedule.get_lr(step), 1e-4, 1e-10);
        }
    }

    /// With no warmup, `Linear` falls from the peak to zero at `total_steps`.
    #[test]
    fn test_linear_schedule() {
        let schedule = LrSchedule::new(1e-4, 0, 100, LrScheduler::Linear);
        assert_close(schedule.get_lr(1), 1e-4 * 0.99, 1e-10);
        assert_close(schedule.get_lr(50), 1e-4 * 0.5, 1e-10);
        assert!(schedule.get_lr(100) < 1e-10);
    }

    /// `Cosine` sits at half the peak mid-decay and near zero at `total_steps`.
    #[test]
    fn test_cosine_schedule() {
        let schedule = LrSchedule::new(1e-4, 10, 110, LrScheduler::Cosine);
        let midpoint = schedule.get_lr(60);
        assert_close(midpoint, 1e-4 * 0.5, 1e-6);
        assert!(schedule.get_lr(110) < 1e-8);
    }

    /// Warmup starts at zero and rises linearly to the peak at `warmup_steps`.
    #[test]
    fn test_warmup_ramp() {
        let schedule = LrSchedule::new(1e-3, 100, 1000, LrScheduler::Cosine);
        assert_eq!(schedule.get_lr(0), 0.0);
        assert_close(schedule.get_lr(50), 5e-4, 1e-10);
        assert_close(schedule.get_lr(100), 1e-3, 1e-10);
    }
}