use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::schedulers::LearningRateScheduler;
/// Step-decay learning-rate schedule.
///
/// Every `step_size` calls to `step`, the learning rate is multiplied by `gamma`:
/// after `n` steps the rate is `initial_lr * gamma^(n / step_size)` (integer division).
///
/// The type itself carries no trait bounds; the numeric requirements live on the
/// `impl` blocks that actually need them.
#[derive(Debug, Clone)]
pub struct StepDecay<A> {
    /// Learning rate at construction and after a reset.
    initial_lr: A,
    /// Multiplicative factor applied once every `step_size` steps.
    gamma: A,
    /// Number of steps between successive decays.
    step_size: usize,
    /// Number of `step` calls since construction or the last reset.
    step: usize,
    /// Learning rate produced by the most recent step (returned by `get_learning_rate`).
    current_lr: A,
}
impl<A: Float + Debug + Send + Sync> StepDecay<A> {
    /// Creates a step-decay scheduler starting at step 0.
    ///
    /// # Arguments
    ///
    /// * `initial_lr` - learning rate used until the first decay.
    /// * `gamma` - factor the rate is multiplied by every `step_size` steps.
    /// * `step_size` - number of steps between decays; must be non-zero.
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is zero. The decay exponent is computed as
    /// `step / step_size`, so a zero step size has no meaningful schedule; failing
    /// here surfaces the mistake at construction rather than on the first `step`.
    pub fn new(initial_lr: A, gamma: A, step_size: usize) -> Self {
        assert!(step_size > 0, "StepDecay: step_size must be non-zero");
        Self {
            initial_lr,
            gamma,
            step_size,
            step: 0,
            current_lr: initial_lr,
        }
    }
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A> for StepDecay<A> {
fn get_learning_rate(&self) -> A {
self.current_lr
}
fn step(&mut self) -> A {
self.step += 1;
let exponent = self.step / self.step_size;
self.current_lr = self.initial_lr * self.gamma.powi(exponent as i32);
self.current_lr
}
fn reset(&mut self) {
self.step = 0;
self.current_lr = self.initial_lr;
}
}