use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::schedulers::LearningRateScheduler;
/// Learning-rate scheduler that decays the rate linearly from `initial_lr` to
/// `final_lr` over `total_steps` calls to `step`, then holds it at `final_lr`.
#[derive(Debug, Clone)]
pub struct LinearDecay<A: Float + Debug> {
    /// Learning rate at step 0; `reset` restores `current_lr` to this value.
    initial_lr: A,
    /// Learning rate reached once `total_steps` steps have been taken.
    final_lr: A,
    /// Number of steps over which the decay is spread.
    total_steps: usize,
    /// Number of `step` calls since construction or the last `reset`.
    step: usize,
    /// Learning rate produced by the most recent update (returned by `get_learning_rate`).
    current_lr: A,
}
impl<A: Float + Debug + Send + Sync> LinearDecay<A> {
    /// Builds a scheduler positioned at step 0.
    ///
    /// # Arguments
    ///
    /// * `initial_lr` - rate reported before any step is taken.
    /// * `final_lr` - rate reached after `total_steps` steps.
    /// * `total_steps` - length of the decay; `0` makes the first step jump
    ///   straight to `final_lr`.
    pub fn new(initial_lr: A, final_lr: A, total_steps: usize) -> Self {
        // The schedule starts at its initial rate.
        let current_lr = initial_lr;
        Self {
            current_lr,
            step: 0,
            total_steps,
            final_lr,
            initial_lr,
        }
    }
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A> for LinearDecay<A> {
    /// Returns the rate computed by the latest `step`, or `initial_lr` if no
    /// step has been taken since construction or the last `reset`.
    fn get_learning_rate(&self) -> A {
        self.current_lr
    }

    /// Advances the schedule by one step and returns the new learning rate.
    ///
    /// For `step < total_steps` the rate is interpolated linearly between
    /// `initial_lr` and `final_lr`. From `total_steps` onwards (including
    /// immediately when `total_steps == 0`) it is exactly `final_lr`.
    ///
    /// # Panics
    ///
    /// Panics if a step count cannot be converted into `A`.
    fn step(&mut self) -> A {
        // Saturate instead of overflowing: once past `total_steps` the exact
        // count no longer affects the result.
        self.step = self.step.saturating_add(1);
        self.current_lr = if self.step >= self.total_steps {
            // Assign directly: evaluating `initial - (initial - final) * 1`
            // in floating point can leave a rounding residue, so the rate
            // would not compare equal to `final_lr` at the end of the schedule.
            self.final_lr
        } else {
            // Here 0 < step < total_steps, so progress lies strictly in (0, 1).
            let progress = A::from(self.step)
                .expect("step count must be representable in the float type")
                / A::from(self.total_steps)
                    .expect("total_steps must be representable in the float type");
            self.initial_lr - (self.initial_lr - self.final_lr) * progress
        };
        self.current_lr
    }

    /// Rewinds the schedule to step 0 and restores `initial_lr`.
    fn reset(&mut self) {
        self.step = 0;
        self.current_lr = self.initial_lr;
    }
}