use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::schedulers::LearningRateScheduler;
/// Learning-rate scheduler that lowers the rate when a monitored metric stops
/// improving.
///
/// Metrics are reported through `step_with_metric`. Once `patience`
/// consecutive reports fail to improve on the best value seen so far, the
/// current rate is multiplied by `factor` and clamped from below by `min_lr`.
#[derive(Debug, Clone)]
pub struct ReduceOnPlateau<A: Float + Debug> {
/// Learning rate currently in effect; this is the value handed back to callers.
current_lr: A,
/// Multiplier applied to `current_lr` each time a reduction is triggered.
factor: A,
/// Number of consecutive non-improving reports that triggers a reduction.
patience: usize,
/// Lower bound enforced on `current_lr` after a reduction.
min_lr: A,
/// Consecutive non-improving reports since the last improvement or reduction.
stagnation_count: usize,
/// Best metric recorded so far; `None` until a metric is reported and again
/// after `reset`.
best_metric: Option<A>,
/// Relative margin by which a metric has to beat `best_metric` to count as an
/// improvement. `new` initialises it to `1e-4`.
threshold: A,
/// `true` when smaller metric values are better, `false` when larger ones are.
mode_is_min: bool,
}
impl<A: Float + Debug + Send + Sync> ReduceOnPlateau<A> {
    /// Creates a scheduler that starts at `initial_lr`.
    ///
    /// * `factor` - multiplier applied to the learning rate on every reduction.
    /// * `patience` - number of consecutive non-improving metric reports that
    ///   triggers a reduction.
    /// * `min_lr` - lower bound applied to the learning rate after a reduction.
    ///
    /// The scheduler starts in "minimise" mode with a relative threshold of
    /// `1e-4`; use [`mode_max`](Self::mode_max) and
    /// [`set_threshold`](Self::set_threshold) to change that.
    ///
    /// # Panics
    ///
    /// Panics if `1e-4` cannot be converted into `A`.
    pub fn new(initial_lr: A, factor: A, patience: usize, min_lr: A) -> Self {
        Self {
            current_lr: initial_lr,
            factor,
            patience,
            min_lr,
            stagnation_count: 0,
            best_metric: None,
            threshold: A::from(1e-4).expect("unwrap failed"),
            mode_is_min: true,
        }
    }

    /// Treats smaller metric values as better (the default). Returns `self`
    /// so calls can be chained.
    pub fn mode_min(&mut self) -> &mut Self {
        self.mode_is_min = true;
        self
    }

    /// Treats larger metric values as better. Returns `self` so calls can be
    /// chained.
    pub fn mode_max(&mut self) -> &mut Self {
        self.mode_is_min = false;
        self
    }

    /// Sets the relative margin by which a metric has to beat the best value
    /// seen so far to count as an improvement. Returns `self` so calls can be
    /// chained.
    pub fn set_threshold(&mut self, threshold: A) -> &mut Self {
        self.threshold = threshold;
        self
    }

    /// Reports a new metric value and returns the learning rate now in effect.
    ///
    /// The first non-NaN metric becomes the baseline. Afterwards a metric is
    /// an improvement only if it beats the best value by more than
    /// `|best| * threshold` in the configured direction. An improvement
    /// replaces the baseline and clears the stagnation counter; anything else
    /// increments the counter. When the counter reaches `patience`, the
    /// learning rate is multiplied by `factor`, clamped from below by
    /// `min_lr`, and the counter restarts from zero.
    ///
    /// A NaN metric is never an improvement and is never stored as the
    /// baseline, so a single bad report cannot disable future comparisons.
    pub fn step_with_metric(&mut self, metric: A) -> A {
        let is_improvement = !metric.is_nan()
            && match self.best_metric {
                None => true,
                Some(best) => {
                    // The target must sit on the "better" side of `best`
                    // regardless of its sign: shrinking a positive value
                    // lowers it, but a negative value has to be grown in
                    // magnitude to become lower (and vice versa for max mode).
                    let shrink = self.mode_is_min == (best >= A::zero());
                    let scale = if shrink {
                        A::one() - self.threshold
                    } else {
                        A::one() + self.threshold
                    };
                    let target = best * scale;
                    if self.mode_is_min {
                        metric < target
                    } else {
                        metric > target
                    }
                }
            };

        if is_improvement {
            self.best_metric = Some(metric);
            self.stagnation_count = 0;
        } else {
            self.stagnation_count += 1;
            if self.stagnation_count >= self.patience {
                self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
                // Start a fresh patience window after each reduction.
                self.stagnation_count = 0;
            }
        }
        self.current_lr
    }
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A>
    for ReduceOnPlateau<A>
{
    /// Returns the learning rate currently in effect.
    fn get_learning_rate(&self) -> A {
        self.current_lr
    }

    /// Returns the current learning rate without changing any state.
    ///
    /// No metric is available through this entry point, so nothing is
    /// evaluated here; the rate only changes via `step_with_metric`.
    fn step(&mut self) -> A {
        self.get_learning_rate()
    }

    /// Forgets the best metric and clears the stagnation counter.
    ///
    /// NOTE(review): `current_lr` is left untouched, so a rate that has
    /// already been reduced survives a reset - confirm this is intended.
    fn reset(&mut self) {
        self.best_metric = None;
        self.stagnation_count = 0;
    }
}