use crate::error::Result;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// A learning-rate schedule that can be queried by training progress.
///
/// `F` is the floating-point type of the returned learning rate.
pub trait LearningRateScheduler<F: Float + Debug + ScalarOperand + NumAssign> {
    /// Returns the learning rate to use at the given `progress`.
    ///
    /// NOTE(review): `progress` is presumably the fraction of training completed
    /// (0.0 ..= 1.0). Each implementation decides how (or whether) to use it, so
    /// confirm the expected range with callers.
    ///
    /// Takes `&mut self` so that implementations may update internal state while
    /// answering the query.
    ///
    /// # Errors
    ///
    /// Implementations may report failures through the crate-level `Result`.
    fn get_learning_rate(&mut self, progress: f64) -> Result<F>;

    /// Resets any internal scheduler state.
    ///
    /// The default implementation is a no-op. Implementors that keep state
    /// should override it.
    fn reset(&mut self) {}
}
/// Adapts [`super::StepDecay`] to the progress-based [`LearningRateScheduler`] API.
impl<F: Float + Debug + ScalarOperand + NumAssign> LearningRateScheduler<F>
    for super::StepDecay<F>
{
    /// Advances the scheduler to the step implied by `progress` and returns its rate.
    ///
    /// The step index is `floor(progress * 100)`, so progress is effectively
    /// quantised into hundredths. Rust's float-to-int `as` cast saturates, so a
    /// negative or NaN `progress` yields step 0, and an oversized value clamps to
    /// `usize::MAX`.
    ///
    /// The step is passed to `update_lr`, and the rate reported afterwards by
    /// `get_lr` is returned. This never produces `Err`.
    fn get_learning_rate(&mut self, progress: f64) -> Result<F> {
        let scaled = progress * 100.0;
        let current_step = scaled.floor() as usize;
        self.update_lr(current_step);
        let rate = self.get_lr();
        Ok(rate)
    }
}
/// Adapts [`super::CosineAnnealingLR`] to the progress-based [`LearningRateScheduler`] API.
impl<F: Float + Debug + ScalarOperand + NumAssign> LearningRateScheduler<F>
    for super::CosineAnnealingLR<F>
{
    /// Returns the rate computed by `calculate_lr` at the step implied by `progress`.
    ///
    /// The step index is `floor(progress * total_steps)`. The float-to-int `as`
    /// cast saturates, so a negative or NaN `progress` maps to step 0.
    ///
    /// This impl calls no explicit state-advancing method. Whether `calculate_lr`
    /// mutates anything is not visible here. This never produces `Err`.
    fn get_learning_rate(&mut self, progress: f64) -> Result<F> {
        let total = self.total_steps as f64;
        Ok(self.calculate_lr((progress * total).floor() as usize))
    }
}
/// Adapts [`super::ReduceOnPlateau`] to the progress-based [`LearningRateScheduler`] API.
impl<F: Float + Debug + ScalarOperand + NumAssign> LearningRateScheduler<F>
    for super::ReduceOnPlateau<F>
{
    /// Returns the scheduler's current learning rate.
    ///
    /// `_progress` is ignored. This impl simply reports `get_current_lr()`, so the
    /// rate does not depend on training progress here. This never produces `Err`.
    fn get_learning_rate(&mut self, _progress: f64) -> Result<F> {
        Ok(self.get_current_lr())
    }

    /// Resets the scheduler by delegating to `ReduceOnPlateau`'s inherent `reset`.
    ///
    /// The call goes through the type path rather than `self.reset()`. Path
    /// resolution always prefers an inherent associated function over the trait
    /// method, whatever receiver form (`&self` / `&mut self`) the inherent method
    /// uses. Method-call syntax instead reaches the trait method first when the
    /// inherent one takes `&self`, which would recurse forever.
    ///
    /// NOTE(review): this assumes an inherent `reset` exists on `ReduceOnPlateau`
    /// (outside this file). If it is ever removed, this call resolves back to the
    /// trait method and recurses. rustc's `unconditional_recursion` lint flags that.
    fn reset(&mut self) {
        super::ReduceOnPlateau::<F>::reset(self);
    }
}