use super::LRScheduler;
use crate::optim::Optimizer;
/// Learning-rate scheduler that ramps linearly from `0.0` up to a target rate.
///
/// After `warmup_steps` calls to `step()` the learning rate stays at
/// `lr_target`. With `warmup_steps == 0` the target rate is used immediately.
#[derive(Debug, Clone, PartialEq)]
pub struct LinearWarmupLR {
    /// Learning rate reached at the end of warmup and held afterwards.
    lr_target: f32,
    /// Number of steps over which the rate ramps up; `0` disables warmup.
    warmup_steps: usize,
    /// Number of `step()` calls made so far.
    current_step: usize,
}
impl LinearWarmupLR {
    /// Creates a scheduler that ramps linearly to `lr_target` over `warmup_steps` steps.
    ///
    /// The step counter starts at 0. The initial learning rate is therefore `0.0`,
    /// unless `warmup_steps` is 0, in which case it is `lr_target` right away.
    #[must_use]
    pub fn new(lr_target: f32, warmup_steps: usize) -> Self {
        Self {
            lr_target,
            warmup_steps,
            current_step: 0,
        }
    }

    /// Writes the scheduler's current learning rate into `optimizer` via `set_lr`.
    ///
    /// This does not advance the schedule; call `step()` separately.
    /// `O` is allowed to be unsized. If the `Optimizer` trait is object-safe,
    /// a `&mut dyn Optimizer` can therefore be passed as well as a concrete type.
    pub fn apply<O: Optimizer + ?Sized>(&self, optimizer: &mut O) {
        optimizer.set_lr(self.get_lr());
    }
}
impl LRScheduler for LinearWarmupLR {
    /// Returns `lr_target * min(current_step / warmup_steps, 1.0)`.
    ///
    /// Before the first `step()` this is `0.0`, unless `warmup_steps` is 0.
    /// Once `current_step >= warmup_steps`, `lr_target` is returned unchanged.
    /// That condition also covers `warmup_steps == 0`, which avoids a `0.0 / 0.0` NaN.
    fn get_lr(&self) -> f32 {
        if self.current_step >= self.warmup_steps {
            return self.lr_target;
        }
        // The `as f32` conversions can round for very large counts;
        // `.min(1.0)` keeps the warmup factor within [0, 1] regardless.
        let progress = (self.current_step as f32 / self.warmup_steps as f32).min(1.0);
        self.lr_target * progress
    }

    /// Advances the schedule by one step.
    ///
    /// The counter saturates at `usize::MAX`. A plain `+= 1` would panic in
    /// debug builds, or wrap to 0 in release builds and silently restart warmup.
    fn step(&mut self) {
        self.current_step = self.current_step.saturating_add(1);
    }
}