use burn::lr_scheduler::LrScheduler;
use burn::record::Record;
use burn::tensor::backend::Backend;
/// Learning-rate schedule with a linear warmup, an inverse-square-root decay, and an optional
/// linear cooldown to zero.
///
/// For a 0-based step `t` (see [`RsqrtScheduler::lr_at`]):
///
/// * `t == 0` → `0.0`;
/// * `1 ..= warmup_steps` → linear ramp `peak_lr * t / warmup_steps`, reaching `peak_lr` at the
///   end of warmup;
/// * afterwards → `peak_lr * sqrt(max(warmup_steps, 1) / t)`;
/// * once the cooldown has started (if one is configured) → the decayed value at the cooldown
///   start, scaled linearly down to `0.0` at `total_steps`, and `0.0` from then on.
#[derive(Clone, Debug)]
pub struct RsqrtScheduler {
    /// Learning rate reached at the end of warmup; the schedule never exceeds it.
    peak_lr: f64,
    /// Length of the linear warmup, in steps (may be 0).
    warmup_steps: usize,
    /// Requested first step of the cooldown; equals `total_steps` when there is no cooldown.
    cooldown_start: usize,
    /// Step at which the cooldown reaches zero.
    total_steps: usize,
    /// Step counter advanced by the `LrScheduler::step` implementation.
    current_step: usize,
}

impl RsqrtScheduler {
    /// Creates a scheduler for a run of `total_steps` steps.
    ///
    /// * `peak_lr` — learning rate at the end of warmup.
    /// * `total_steps` — step at which a configured cooldown reaches zero.
    /// * `warmup_fraction` — fraction of `total_steps` spent in linear warmup, rounded to the
    ///   nearest step.
    /// * `cooldown_fraction` — fraction of `total_steps` spent in linear cooldown, rounded to
    ///   the nearest step; the cooldown occupies the last steps before `total_steps`.
    ///
    /// Negative or NaN fractions yield zero steps (float-to-integer `as` casts saturate), and a
    /// cooldown longer than the run is clamped to start at step 0.
    #[must_use]
    pub fn new(
        peak_lr: f64,
        total_steps: usize,
        warmup_fraction: f64,
        cooldown_fraction: f64,
    ) -> Self {
        let warmup_steps = ((total_steps as f64) * warmup_fraction).round() as usize;
        let cooldown_steps = ((total_steps as f64) * cooldown_fraction).round() as usize;
        let cooldown_start = total_steps.saturating_sub(cooldown_steps);
        Self {
            peak_lr,
            warmup_steps,
            cooldown_start,
            total_steps,
            current_step: 0,
        }
    }

    /// Returns the learning rate for the 0-based `step`. Does not read or modify the step
    /// counter.
    ///
    /// Edge cases:
    /// * With no warmup, the decay is anchored at step 1, so `lr_at(1) == peak_lr` rather than
    ///   the schedule collapsing to zero.
    /// * A cooldown that would begin before warmup ends is deferred to the end of warmup. It
    ///   still reaches zero at `total_steps`. This keeps the schedule continuous, never above
    ///   `peak_lr`, and free of division by zero.
    pub fn lr_at(&self, step: usize) -> f64 {
        if step == 0 {
            return 0.0;
        }
        if step <= self.warmup_steps {
            // Only reachable when warmup_steps >= 1, so the division is well defined.
            return self.peak_lr * step as f64 / self.warmup_steps as f64;
        }
        // Past warmup from here on: step > warmup_steps, hence step >= 1.
        let has_cooldown = self.cooldown_start < self.total_steps;
        let cooldown_start = self.cooldown_start.max(self.warmup_steps);
        if has_cooldown && step >= cooldown_start {
            if step >= self.total_steps {
                return 0.0;
            }
            // cooldown_start <= step < total_steps: denominator >= 1 and 0 <= frac < 1.
            let lr_at_cooldown_start = self.rsqrt_decay_lr(cooldown_start);
            let frac = (step - cooldown_start) as f64
                / (self.total_steps - cooldown_start) as f64;
            return lr_at_cooldown_start * (1.0 - frac);
        }
        self.rsqrt_decay_lr(step)
    }

    /// Inverse-square-root decay that equals `peak_lr` at the end of warmup, or at step 1 when
    /// there is no warmup. Steps before that anchor are clamped to it.
    fn rsqrt_decay_lr(&self, step: usize) -> f64 {
        let anchor = self.warmup_steps.max(1);
        self.peak_lr * (anchor as f64 / step.max(anchor) as f64).sqrt()
    }
}
/// Persisted state of an `RsqrtScheduler`: the step counter only. The schedule parameters are
/// not part of the record.
#[derive(Clone, Debug)]
pub struct RsqrtSchedulerRecord {
    /// Value of the scheduler's step counter when the record was taken.
    current_step: usize,
}

impl<B: Backend> Record<B> for RsqrtSchedulerRecord {
    // The serialized item is the bare counter; the precision settings `S` are unused.
    type Item<S: burn::record::PrecisionSettings> = usize;

    fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
        let Self { current_step } = self;
        current_step
    }

    fn from_item<S: burn::record::PrecisionSettings>(
        saved_step: Self::Item<S>,
        _device: &B::Device,
    ) -> Self {
        Self {
            current_step: saved_step,
        }
    }
}
impl<B: Backend> LrScheduler<B> for RsqrtScheduler {
    type Record = RsqrtSchedulerRecord;

    /// Returns the learning rate for the current step counter, then advances the counter by
    /// one. The counter is read before it is incremented, so the first call returns
    /// `lr_at(0)`, which is `0.0`.
    fn step(&mut self) -> f64 {
        let step_index = self.current_step;
        self.current_step = step_index + 1;
        self.lr_at(step_index)
    }

    /// Snapshots the step counter; the schedule parameters are not recorded.
    fn to_record(&self) -> Self::Record {
        let current_step = self.current_step;
        RsqrtSchedulerRecord { current_step }
    }

    /// Replaces the step counter with the recorded one, keeping every schedule parameter.
    fn load_record(self, record: Self::Record) -> Self {
        Self {
            current_step: record.current_step,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 1000-step schedule with 20% warmup and 20% cooldown that two tests share.
    fn twenty_percent_schedule() -> RsqrtScheduler {
        RsqrtScheduler::new(1e-3, 1000, 0.2, 0.2)
    }

    #[test]
    fn test_scheduler_warmup() {
        let sched = twenty_percent_schedule();
        assert_eq!(sched.lr_at(0), 0.0);
        // 20% of 1000 steps = 200 warmup steps, so step 200 should sit at the peak.
        let distance_from_peak = (sched.lr_at(200) - 1e-3).abs();
        assert!(distance_from_peak < 1e-10, "LR at end of warmup should be peak");
    }

    #[test]
    fn test_scheduler_rsqrt_decay() {
        // 1% warmup (100 steps) and no cooldown: steps 400 and 1600 lie in the pure decay.
        let sched = RsqrtScheduler::new(1e-3, 10_000, 0.01, 0.0);
        let ratio = sched.lr_at(400) / sched.lr_at(1600);
        assert!(
            (ratio - 2.0).abs() < 0.01,
            "Rsqrt: quadrupling step should halve LR; got ratio={ratio:.4}"
        );
    }

    #[test]
    fn test_scheduler_cooldown_reaches_zero() {
        assert_eq!(twenty_percent_schedule().lr_at(1000), 0.0);
    }
}