use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::schedulers::LearningRateScheduler;
/// Cosine-annealing learning-rate scheduler.
///
/// Each `step` moves the rate along a cosine curve that starts at `initial_lr`
/// and decays toward `min_lr` over a period of `t_max` steps (see the
/// `LearningRateScheduler` implementation for the exact formula).
#[derive(Debug, Clone)]
pub struct CosineAnnealing<A>
where
    A: Float + Debug,
{
    /// Rate at the beginning of every annealing period.
    initial_lr: A,
    /// Rate the cosine curve decays toward.
    min_lr: A,
    /// Length of one annealing period, in steps.
    t_max: usize,
    /// When set, the step counter is wrapped back to 0 at the end of each period.
    warm_restart: bool,
    /// Number of steps taken since construction, the last reset, or the last wrap.
    step: usize,
    /// Rate computed by the most recent step (or `initial_lr` before any step).
    current_lr: A,
}
impl<A> CosineAnnealing<A>
where
    A: Float + Debug + Send + Sync,
{
    /// Creates a scheduler whose current rate starts at `initial_lr`.
    ///
    /// * `initial_lr` - rate at the start of each annealing period.
    /// * `min_lr` - rate the cosine curve decays toward.
    /// * `t_max` - number of steps in one annealing period.
    /// * `warm_restart` - whether the step counter is wrapped to 0 at the end
    ///   of each period.
    ///
    /// The internal step counter starts at 0.
    pub fn new(initial_lr: A, min_lr: A, t_max: usize, warm_restart: bool) -> Self {
        // Before the first `step`, the reported rate is the initial one.
        let current_lr = initial_lr;
        Self {
            current_lr,
            step: 0,
            warm_restart,
            t_max,
            min_lr,
            initial_lr,
        }
    }
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A>
    for CosineAnnealing<A>
{
    /// Returns the rate produced by the most recent `step`, or `initial_lr` if
    /// `step` has not been called since construction or the last `reset`.
    fn get_learning_rate(&self) -> A {
        self.current_lr
    }

    /// Advances the schedule by one step and returns the new learning rate.
    ///
    /// For `t_max > 0` the rate is
    /// `min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(pi * t_cur / t_max))`
    /// with `t_cur = step % t_max`, so the curve starts over from `initial_lr`
    /// every `t_max` steps.
    ///
    /// NOTE(review): because `t_cur` is taken modulo `t_max` in both modes, the
    /// returned sequence is identical whether or not `warm_restart` is set; the
    /// flag only controls whether the stored counter is wrapped back to 0.
    /// Confirm whether non-restart mode was meant to hold at `min_lr` after
    /// `t_max` steps instead.
    ///
    /// With `t_max == 0` there is no annealing period, so the rate is held at
    /// `initial_lr`. (Previously the formula evaluated `0 / 0` and every rate
    /// became NaN.)
    ///
    /// # Panics
    /// Panics if `PI`, `0.5`, `t_cur`, or `t_max` cannot be converted to `A`.
    fn step(&mut self) -> A {
        self.step += 1;

        // The counter is >= 1 after the increment, so only the period needs
        // guarding; a zero-length period never triggers a restart.
        if self.warm_restart && self.t_max > 0 && self.step.is_multiple_of(self.t_max) {
            self.step = 0;
        }

        self.current_lr = if self.t_max == 0 {
            // Degenerate schedule: dividing by a zero-length period is undefined.
            self.initial_lr
        } else {
            let t_cur = self.step % self.t_max;
            let pi = A::from(std::f64::consts::PI).expect("unwrap failed");
            // Evaluated as `(pi * t_cur) / t_max`, matching the previous
            // expression exactly so results stay bit-for-bit identical.
            let angle = pi * A::from(t_cur).expect("unwrap failed")
                / A::from(self.t_max).expect("unwrap failed");
            let cos_term = A::one() + angle.cos();
            self.min_lr
                + A::from(0.5).expect("unwrap failed") * (self.initial_lr - self.min_lr) * cos_term
        };
        self.current_lr
    }

    /// Rewinds the step counter to 0 and restores `initial_lr` as the current rate.
    fn reset(&mut self) {
        self.step = 0;
        self.current_lr = self.initial_lr;
    }
}