use burn::{LearningRate, lr_scheduler::LrScheduler, tensor::backend::Backend};
#[derive(Clone, Debug)]
pub(crate) struct CosineAnnealingLR {
t_max: f64,
eta_min: f64,
init_lr: LearningRate,
step_count: f64,
current_lr: LearningRate,
}
impl CosineAnnealingLR {
pub const fn init(t_max: f64, init_lr: LearningRate) -> Self {
Self {
t_max,
eta_min: 0.0,
init_lr,
step_count: -1.0,
current_lr: init_lr,
}
}
}
impl LrScheduler for CosineAnnealingLR {
type Record<B: Backend> = usize;
fn step(&mut self) -> LearningRate {
self.step_count += 1.0;
use std::f64::consts::PI;
fn cosine_annealing_lr(
init_lr: LearningRate,
lr: LearningRate,
step_count: f64,
t_max: f64,
eta_min: f64,
) -> LearningRate {
if step_count == 0.0 {
init_lr
} else if (step_count - 1.0 - t_max) % (2.0 * t_max) == 0.0 {
(init_lr - eta_min) * (1.0 - f64::cos(PI / t_max)) / 2.0
} else {
((1.0 + f64::cos(PI * step_count / t_max))
/ (1.0 + f64::cos(PI * (step_count - 1.0) / t_max)))
.mul_add(lr - eta_min, eta_min)
}
}
self.current_lr = cosine_annealing_lr(
self.init_lr,
self.current_lr,
self.step_count,
self.t_max,
self.eta_min,
);
self.current_lr
}
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.step_count as usize
}
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.step_count = record as LearningRate;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_helpers::TestHelper;
#[test]
fn lr_scheduler() {
let mut lr_scheduler = CosineAnnealingLR::init(5.0, 4e-2);
let lrs = (1..=11)
.map(|_| {
LrScheduler::step(&mut lr_scheduler);
lr_scheduler.current_lr
})
.step_by(1)
.collect::<Vec<_>>();
lrs.assert_approx_eq([
0.04,
0.03618033988749895,
0.026180339887498946,
0.013819660112501051,
0.0038196601125010526,
0.0,
0.003819660112501051,
0.013819660112501048,
0.026180339887498943,
0.03618033988749895,
0.039999999999999994,
]);
}
}