use std::f64::consts::PI;
/// A learning-rate schedule: maps an epoch index and a base learning rate to
/// the learning rate that should be used for that epoch.
///
/// Implementors are required to be `Send + Sync`.
pub trait LrSchedule: Send + Sync {
/// Returns the learning rate for `epoch`.
///
/// `base_lr` is the caller-supplied reference rate. How it is used is up to
/// the implementor; an implementation may ignore it entirely.
fn get_lr(&self, epoch: usize, base_lr: f64) -> f64;
}
/// Step-wise decay: the learning rate is multiplied by `gamma` once every
/// `step_size` epochs.
#[derive(Debug, Clone)]
pub struct StepDecay {
    /// Number of epochs between two successive decays.
    pub step_size: usize,
    /// Multiplicative factor applied at each decay.
    pub gamma: f64,
}

impl StepDecay {
    /// Creates a step-decay schedule.
    ///
    /// Both arguments are stored as given; no validation is performed here.
    pub fn new(step_size: usize, gamma: f64) -> Self {
        StepDecay { step_size, gamma }
    }
}
impl LrSchedule for StepDecay {
    /// Returns `base_lr * gamma^(epoch / step_size)`, using integer division
    /// for the number of completed steps.
    ///
    /// A `step_size` of zero is treated as one so the division cannot panic.
    /// The exponent saturates at `i32::MAX` for extremely large step counts.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
        let steps = epoch / self.step_size.max(1);
        // `powi` takes an `i32`. A plain `as i32` cast wraps counts above
        // `i32::MAX` to a negative exponent, which would turn decay into
        // growth, so saturate before converting.
        let exponent = steps.min(i32::MAX as usize) as i32;
        base_lr * self.gamma.powi(exponent)
    }
}
/// Cosine annealing: the learning rate follows a half cosine from the base
/// rate down to `eta_min` over `t_max` epochs.
#[derive(Debug, Clone)]
pub struct CosineAnnealing {
    /// Number of epochs over which the rate goes from the base rate to `eta_min`.
    pub t_max: usize,
    /// Learning rate reached at epoch `t_max`.
    pub eta_min: f64,
}

impl CosineAnnealing {
    /// Creates a cosine-annealing schedule.
    ///
    /// Both arguments are stored as given; no validation is performed here.
    pub fn new(t_max: usize, eta_min: f64) -> Self {
        CosineAnnealing { t_max, eta_min }
    }
}
impl LrSchedule for CosineAnnealing {
    /// Returns `eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * epoch / t_max))`.
    ///
    /// A `t_max` of zero is treated as one. The epoch is not clamped, so the
    /// cosine keeps oscillating for `epoch > t_max`.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
        let horizon = self.t_max.max(1) as f64;
        let angle = PI * epoch as f64 / horizon;
        let amplitude = base_lr - self.eta_min;
        self.eta_min + 0.5 * amplitude * (1.0 + angle.cos())
    }
}
/// Interpolation curve used by [`OneCycle`] when moving from one learning
/// rate to another within a phase.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AnnealStrategy {
/// Half-cosine interpolation between the two rates.
Cos,
/// Straight-line interpolation between the two rates.
Linear,
}
/// One-cycle policy: ramp up from a low starting rate to `max_lr`, then
/// anneal down to an even lower final rate.
#[derive(Debug, Clone)]
pub struct OneCycle {
    /// Peak learning rate, reached at the end of the ramp-up phase.
    pub max_lr: f64,
    /// Fraction of the cycle spent ramping up; `new` clamps it to `[0, 1]`.
    pub pct_start: f64,
    /// Curve used to interpolate within each phase.
    pub anneal_strategy: AnnealStrategy,
    /// Length of the whole cycle in epochs.
    pub total_epochs: usize,
    /// Divisor applied to the base rate passed to `get_lr` to obtain the starting rate.
    pub div_factor: f64,
    /// Divisor applied to the starting rate to obtain the final rate.
    pub final_div_factor: f64,
}

impl OneCycle {
    /// Creates a one-cycle schedule.
    ///
    /// `pct_start` is clamped to `[0, 1]`. `div_factor` and
    /// `final_div_factor` are initialised to `25.0` and `1e4`; they are
    /// public fields and can be changed afterwards.
    pub fn new(
        max_lr: f64,
        pct_start: f64,
        anneal_strategy: AnnealStrategy,
        total_epochs: usize,
    ) -> Self {
        let pct_start = pct_start.clamp(0.0, 1.0);
        OneCycle {
            max_lr,
            pct_start,
            anneal_strategy,
            total_epochs,
            div_factor: 25.0,
            final_div_factor: 1e4,
        }
    }

    /// Interpolates from `start` (at `pct == 0`) to `end` (at `pct == 1`)
    /// along the configured curve. `pct` is clamped to `[0, 1]` first.
    fn anneal(&self, start: f64, end: f64, pct: f64) -> f64 {
        let progress = pct.clamp(0.0, 1.0);
        match self.anneal_strategy {
            AnnealStrategy::Linear => start + (end - start) * progress,
            AnnealStrategy::Cos => {
                let half_span = (start - end) / 2.0;
                end + half_span * (1.0 + (PI * progress).cos())
            }
        }
    }
}
impl LrSchedule for OneCycle {
    /// Returns the one-cycle learning rate for `epoch`.
    ///
    /// Progress through the cycle is `epoch / total_epochs` (a zero
    /// `total_epochs` is treated as one). While progress is at most
    /// `pct_start` the rate is annealed from `base_lr / div_factor` up to
    /// `max_lr`; afterwards it is annealed from `max_lr` down to
    /// `base_lr / div_factor / final_div_factor`.
    ///
    /// NOTE(review): the starting rate is derived from `base_lr`, not from
    /// `max_lr` as in PyTorch's `OneCycleLR` — confirm this is intended.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
        let horizon = self.total_epochs.max(1) as f64;
        let progress = epoch as f64 / horizon;
        let start_lr = base_lr / self.div_factor;
        let end_lr = start_lr / self.final_div_factor;

        let (from, to, phase) = if progress <= self.pct_start {
            // Ramp-up. With no ramp-up phase at all, jump straight to the peak.
            let phase = if self.pct_start > 0.0 {
                progress / self.pct_start
            } else {
                1.0
            };
            (start_lr, self.max_lr, phase)
        } else {
            // Cool-down. The floor keeps the divisor away from zero when
            // `pct_start` is 1.
            let tail = (1.0 - self.pct_start).max(1e-9);
            (self.max_lr, end_lr, (progress - self.pct_start) / tail)
        };

        self.anneal(from, to, phase)
    }
}
/// Linear warm-up followed by cosine decay.
#[derive(Debug, Clone)]
pub struct WarmupCosine {
    /// Number of initial epochs during which the rate ramps linearly from
    /// zero towards the base rate.
    pub warmup_steps: usize,
    /// Epoch at which the cosine decay reaches `min_lr`.
    pub total_steps: usize,
    /// Learning rate at the bottom of the cosine decay.
    pub min_lr: f64,
}

impl WarmupCosine {
    /// Creates a warm-up + cosine schedule.
    ///
    /// All arguments are stored as given; no validation is performed here.
    pub fn new(warmup_steps: usize, total_steps: usize, min_lr: f64) -> Self {
        WarmupCosine {
            warmup_steps,
            total_steps,
            min_lr,
        }
    }
}
impl LrSchedule for WarmupCosine {
    /// Returns the learning rate for `epoch`.
    ///
    /// * `epoch < warmup_steps`: linear ramp, `base_lr * epoch / warmup_steps`.
    /// * otherwise: cosine decay from `base_lr` to `min_lr` over
    ///   `total_steps - warmup_steps` epochs (at least one).
    ///
    /// Once the decay window is exhausted the rate stays at `min_lr`; it
    /// does not climb back up the cosine.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
        if epoch < self.warmup_steps {
            // `warmup_steps > epoch >= 0` here, so the divisor is never zero.
            return base_lr * epoch as f64 / self.warmup_steps as f64;
        }

        // `saturating_sub` covers `total_steps < warmup_steps`; `max(1)`
        // keeps the divisor non-zero when the two are equal.
        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps).max(1);
        // Clamp so epochs past the end of the schedule hold at `min_lr`.
        let step = (epoch - self.warmup_steps).min(decay_steps);
        let cos_val = (PI * step as f64 / decay_steps as f64).cos();
        self.min_lr + 0.5 * (base_lr - self.min_lr) * (1.0 + cos_val)
    }
}
/// Exponential decay: the learning rate is multiplied by `gamma` every epoch.
#[derive(Debug, Clone)]
pub struct ExponentialDecay {
    /// Per-epoch multiplicative factor.
    pub gamma: f64,
}

impl ExponentialDecay {
    /// Creates an exponential-decay schedule with the given per-epoch factor.
    ///
    /// The factor is stored as given; no validation is performed here.
    pub fn new(gamma: f64) -> Self {
        ExponentialDecay { gamma }
    }
}
impl LrSchedule for ExponentialDecay {
    /// Returns `base_lr * gamma^epoch`.
    ///
    /// The exponent saturates at `i32::MAX` for extremely large epochs.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
        // `powi` takes an `i32`. A plain `as i32` cast wraps epochs above
        // `i32::MAX` to a negative exponent, which would turn decay into
        // growth, so saturate before converting.
        let exponent = epoch.min(i32::MAX as usize) as i32;
        base_lr * self.gamma.powi(exponent)
    }
}
/// Schedule that keeps the learning rate fixed at the base rate.
#[derive(Debug, Clone, Default)]
pub struct ConstantLr;
impl LrSchedule for ConstantLr {
/// Returns `base_lr` unchanged; the epoch is not used.
fn get_lr(&self, _epoch: usize, base_lr: f64) -> f64 {
base_lr
}
}
/// Polynomial decay from the base rate to `end_lr` over `total_epochs` epochs.
#[derive(Debug, Clone)]
pub struct PolynomialDecay {
    /// Number of epochs over which the rate decays; from this epoch on the
    /// schedule returns `end_lr`.
    pub total_epochs: usize,
    /// Exponent applied to the remaining fraction of the schedule.
    pub power: f64,
    /// Learning rate returned once the schedule has finished.
    pub end_lr: f64,
}

impl PolynomialDecay {
    /// Creates a polynomial-decay schedule.
    ///
    /// All arguments are stored as given; no validation is performed here.
    pub fn new(total_epochs: usize, power: f64, end_lr: f64) -> Self {
        PolynomialDecay {
            total_epochs,
            power,
            end_lr,
        }
    }
}
impl LrSchedule for PolynomialDecay {
    /// Returns `(base_lr - end_lr) * (1 - epoch / total_epochs)^power + end_lr`,
    /// never less than `end_lr`.
    ///
    /// A `total_epochs` of zero is treated as one. From `total_epochs`
    /// onwards the result is exactly `end_lr`.
    fn get_lr(&self, epoch: usize, base_lr: f64) -> f64 {
        let horizon = self.total_epochs.max(1);
        if epoch < horizon {
            let remaining = 1.0 - epoch as f64 / horizon as f64;
            let span = base_lr - self.end_lr;
            let decayed = span * remaining.powf(self.power) + self.end_lr;
            // Floor at `end_lr`, which also applies when `base_lr < end_lr`.
            decayed.max(self.end_lr)
        } else {
            self.end_lr
        }
    }
}
/// Triangular cyclic schedule oscillating between `base_lr` and `max_lr`.
#[derive(Debug, Clone)]
pub struct CyclicLr {
    /// Learning rate at the bottom of each cycle.
    pub base_lr: f64,
    /// Learning rate at the top of each cycle.
    pub max_lr: f64,
    /// Number of epochs in half a cycle (bottom to top, or top to bottom).
    pub step_size: usize,
}

impl CyclicLr {
    /// Creates a cyclic schedule.
    ///
    /// A `step_size` of zero is raised to one; the rates are stored as given.
    pub fn new(base_lr: f64, max_lr: f64, step_size: usize) -> Self {
        let step_size = if step_size == 0 { 1 } else { step_size };
        CyclicLr {
            base_lr,
            max_lr,
            step_size,
        }
    }
}
impl LrSchedule for CyclicLr {
    /// Returns the triangular cyclic learning rate for `epoch`.
    ///
    /// The rate rises linearly from the schedule's own `base_lr` to `max_lr`
    /// over `step_size` epochs, falls back over the next `step_size`, and
    /// repeats. The `_base_lr` argument is ignored.
    ///
    /// A `step_size` of zero is treated as one, so a value built with a
    /// struct literal cannot cause a divide-by-zero panic.
    fn get_lr(&self, epoch: usize, _base_lr: f64) -> f64 {
        // `step_size` is a public field, so `new`'s floor cannot be relied on.
        let step_size = self.step_size.max(1);
        // Saturate so an enormous `step_size` cannot overflow the cycle length.
        let cycle = epoch / step_size.saturating_mul(2);
        // `x` runs from -1 to 1 across one cycle and is 0 at the peak.
        let x = (epoch as f64 / step_size as f64) - 2.0 * cycle as f64 - 1.0;
        let scale = (1.0 - x.abs()).max(0.0);
        self.base_lr + (self.max_lr - self.base_lr) * scale
    }
}
/// Unit tests pinning each schedule at a few hand-computed epochs.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The rate halves every 10 epochs and is flat in between.
    #[test]
    fn test_step_decay() {
        let sched = StepDecay::new(10, 0.5);
        assert_abs_diff_eq!(sched.get_lr(0, 0.1), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(9, 0.1), 0.1, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 0.1), 0.05, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(20, 0.1), 0.025, epsilon = 1e-12);
    }

    /// Start, midpoint and end of a single half cosine.
    #[test]
    fn test_cosine_annealing() {
        let sched = CosineAnnealing::new(100, 0.0);
        let lr_start = sched.get_lr(0, 1.0);
        let lr_mid = sched.get_lr(50, 1.0);
        let lr_end = sched.get_lr(100, 1.0);
        assert_abs_diff_eq!(lr_start, 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(lr_mid, 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(lr_end, 0.0, epsilon = 1e-12);
    }

    /// The ramp-up phase ends exactly at `max_lr`, above the starting rate.
    #[test]
    fn test_one_cycle_warmup_peak() {
        let sched = OneCycle::new(0.1, 0.3, AnnealStrategy::Cos, 100);
        let lr_start = sched.get_lr(0, 0.01);
        let lr_peak = sched.get_lr(30, 0.01);
        // Strict comparison, matching the failure message.
        assert!(lr_peak > lr_start, "peak must exceed start");
        assert_abs_diff_eq!(lr_peak, sched.max_lr, epsilon = 1e-10);
    }

    /// Linear ramp during warm-up, then decay that stays within bounds.
    #[test]
    fn test_warmup_cosine() {
        let sched = WarmupCosine::new(10, 100, 0.0);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(5, 1.0), 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 1.0), 1.0, epsilon = 1e-12);
        let lr_after = sched.get_lr(55, 1.0);
        assert!(lr_after < 1.0, "should decay after warmup");
        assert!(lr_after >= 0.0, "should not go below min_lr");
    }

    /// Each epoch multiplies the rate by `gamma`.
    #[test]
    fn test_exponential_decay() {
        let sched = ExponentialDecay::new(0.9);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(1, 1.0), 0.9, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(2, 1.0), 0.81, epsilon = 1e-12);
    }

    /// The constant schedule returns the base rate at every epoch.
    #[test]
    fn test_constant_lr() {
        let sched = ConstantLr;
        for epoch in 0..100 {
            assert_abs_diff_eq!(sched.get_lr(epoch, 0.01), 0.01, epsilon = 1e-12);
        }
    }

    /// With `power == 1` the decay is linear and ends (and stays) at `end_lr`.
    #[test]
    fn test_polynomial_decay() {
        let sched = PolynomialDecay::new(10, 1.0, 0.0);
        assert_abs_diff_eq!(sched.get_lr(0, 1.0), 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(5, 1.0), 0.5, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(10, 1.0), 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(sched.get_lr(25, 1.0), 0.0, epsilon = 1e-12);
    }

    /// Peak after one half-cycle, back to the bottom after a full cycle.
    #[test]
    fn test_cyclic_lr() {
        let sched = CyclicLr::new(0.001, 0.01, 5);
        let lr0 = sched.get_lr(0, 0.0);
        let lr5 = sched.get_lr(5, 0.0);
        assert_abs_diff_eq!(lr5, sched.max_lr, epsilon = 1e-10);
        let lr10 = sched.get_lr(10, 0.0);
        assert_abs_diff_eq!(lr10, sched.base_lr, epsilon = 1e-10);
        assert!(lr5 > lr0, "peak should exceed start");
    }
}