use super::LRScheduler;
use crate::optim::Optimizer;
use std::f32::consts::PI;
/// Learning-rate schedule with a linear warmup followed by a cosine decay.
///
/// For the first `warmup_steps` steps the rate ramps linearly from `0.0` up
/// toward `lr_max`. It then follows half a cosine wave from `lr_max` down to
/// `lr_min`, reaching `lr_min` at `total_steps` and staying there afterwards.
#[derive(Debug, Clone, PartialEq)]
pub struct WarmupCosineDecayLR {
    /// Peak learning rate, reached at the end of warmup.
    lr_max: f32,
    /// Floor learning rate at the end of (and after) the decay phase.
    lr_min: f32,
    /// Number of linear-warmup steps.
    warmup_steps: usize,
    /// Step at which the decay phase ends, counted from step 0 (warmup included).
    total_steps: usize,
    /// Number of `step()` calls made so far.
    current_step: usize,
}
impl WarmupCosineDecayLR {
    /// Creates a scheduler positioned at step 0.
    ///
    /// * `lr_max` - learning rate reached at the end of warmup.
    /// * `lr_min` - learning rate at and after `total_steps`.
    /// * `warmup_steps` - number of linear-warmup steps.
    /// * `total_steps` - step at which the cosine decay finishes, counted from 0.
    pub fn new(lr_max: f32, lr_min: f32, warmup_steps: usize, total_steps: usize) -> Self {
        Self {
            current_step: 0,
            lr_max,
            lr_min,
            warmup_steps,
            total_steps,
        }
    }

    /// Writes the learning rate for the current step into `optimizer`.
    ///
    /// Takes `&self`, so it never advances the schedule; call `step`
    /// separately to move to the next step.
    pub fn apply<O: Optimizer>(&self, optimizer: &mut O) {
        let lr = self.get_lr();
        optimizer.set_lr(lr);
    }
}
impl LRScheduler for WarmupCosineDecayLR {
    /// Returns the learning rate for the current step.
    ///
    /// * Warmup (`current_step < warmup_steps`): rises linearly from `0.0`
    ///   at step 0 toward `lr_max`.
    /// * Decay (`warmup_steps <= current_step < total_steps`): follows half a
    ///   cosine wave from `lr_max` down to `lr_min`.
    /// * Afterwards, or when `total_steps <= warmup_steps` (no decay phase):
    ///   returns `lr_min`.
    fn get_lr(&self) -> f32 {
        if self.current_step < self.warmup_steps {
            // Here `warmup_steps > current_step >= 0`, so the divisor is never zero.
            let progress = self.current_step as f32 / self.warmup_steps as f32;
            return self.lr_max * progress;
        }

        // Cannot underflow: the warmup branch above handled `current_step < warmup_steps`.
        let decay_step = self.current_step - self.warmup_steps;
        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
        // This also covers `decay_steps == 0` (no decay phase), which would
        // otherwise produce a 0/0 progress value below.
        if decay_step >= decay_steps {
            return self.lr_min;
        }

        let progress = decay_step as f32 / decay_steps as f32;
        let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
        self.lr_min + (self.lr_max - self.lr_min) * cosine_decay
    }

    /// Advances the schedule by one step.
    ///
    /// Saturates at `usize::MAX` instead of overflowing. This avoids a panic in
    /// debug builds and a wrap back to 0 (restarting warmup) in release builds;
    /// the schedule simply stays at `lr_min`.
    fn step(&mut self) {
        self.current_step = self.current_step.saturating_add(1);
    }
}