use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::schedulers::LearningRateScheduler;
/// Learning-rate scheduler that decays the rate exponentially with the step count.
///
/// After `step` advances, the learning rate is
/// `initial_lr * decay_rate^(step / decay_steps)`. The exponent is a floating-point
/// ratio, so the decay is continuous rather than a staircase.
#[derive(Debug, Clone)]
pub struct ExponentialDecay<A>
where
    A: Float + Debug,
{
    /// Learning rate before any step has been taken (and after a reset).
    initial_lr: A,
    /// Factor the learning rate is multiplied by once every `decay_steps` steps.
    decay_rate: A,
    /// Number of steps over which one full `decay_rate` factor is applied.
    decay_steps: usize,
    /// Number of steps taken since construction or the last reset.
    step: usize,
    /// Cached learning rate for the current step.
    current_lr: A,
}
impl<A: Float + Debug + Send + Sync> ExponentialDecay<A> {
    /// Creates a scheduler that starts at `initial_lr` and is multiplied by
    /// `decay_rate` once every `decay_steps` steps (continuously interpolated
    /// between those points).
    ///
    /// # Panics
    ///
    /// Panics if `decay_steps` is zero. The decay exponent is `step / decay_steps`,
    /// so a zero period would divide by zero. The learning rate would then silently
    /// collapse to zero (for `decay_rate < 1`) or blow up (for `decay_rate > 1`).
    pub fn new(initial_lr: A, decay_rate: A, decay_steps: usize) -> Self {
        assert!(decay_steps > 0, "ExponentialDecay: decay_steps must be greater than zero");
        Self {
            initial_lr,
            decay_rate,
            decay_steps,
            step: 0,
            // No steps taken yet, so the current rate is the initial one.
            current_lr: initial_lr,
        }
    }
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A>
    for ExponentialDecay<A>
{
    /// Returns the learning rate for the current step without advancing the schedule.
    fn get_learning_rate(&self) -> A {
        self.current_lr
    }

    /// Advances the schedule by one step and returns the new learning rate,
    /// `initial_lr * decay_rate^(step / decay_steps)`.
    ///
    /// # Panics
    ///
    /// Panics if the step counter or `decay_steps` cannot be converted to `A`.
    /// NOTE(review): presumably unreachable for `f32`/`f64`; confirm for any
    /// custom `Float` implementations used with this scheduler.
    fn step(&mut self) -> A {
        self.step += 1;
        let elapsed = A::from(self.step)
            .expect("ExponentialDecay: step count must be representable as the float type");
        let period = A::from(self.decay_steps)
            .expect("ExponentialDecay: decay_steps must be representable as the float type");
        // A fractional exponent yields smooth (non-staircase) decay between periods.
        let power = elapsed / period;
        self.current_lr = self.initial_lr * self.decay_rate.powf(power);
        self.current_lr
    }

    /// Restores the scheduler to its freshly-constructed state (step 0, `initial_lr`).
    fn reset(&mut self) {
        self.step = 0;
        self.current_lr = self.initial_lr;
    }
}