use crate::callbacks::{Callback, CallbackContext, CallbackTiming};
use crate::error::Result;
use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::marker::PhantomData;
/// Selects which training event advances a schedule's step counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScheduleMethod {
/// The schedule is stepped once per epoch.
Epoch,
/// The schedule is stepped once per batch.
Batch,
}
/// Learning-rate schedule that multiplies the rate by a fixed factor every
/// `step_size` steps, never letting it fall below `min_lr`.
pub struct StepDecay<F: Float + Debug + ScalarOperand + NumAssign> {
/// Rate the schedule starts from and returns to on reset.
initial_lr: F,
/// Multiplier applied to the current rate at each decay point.
factor: F,
/// Number of steps between two consecutive decays.
step_size: usize,
/// Whether a step is counted per epoch or per batch.
method: ScheduleMethod,
/// Rate currently in effect.
current_lr: F,
/// Lower bound the rate is clamped to after every decay.
min_lr: F,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> StepDecay<F> {
pub fn new(
initial_lr: F,
factor: F,
step_size: usize,
method: ScheduleMethod,
min_lr: F,
) -> Self {
Self {
initial_lr,
factor,
step_size,
method,
current_lr: initial_lr,
min_lr,
}
}
pub fn get_initial_lr(&self) -> F {
self.initial_lr
}
pub fn get_lr(&self) -> F {
self.current_lr
}
pub fn update_lr(&mut self, step: usize) {
if step > 0 && step.is_multiple_of(self.step_size) {
self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
}
}
pub fn reset_to_initial(&mut self) {
self.current_lr = self.initial_lr;
}
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Callback<F> for StepDecay<F> {
    /// Advances the schedule on the event that matches the configured method.
    ///
    /// * `Epoch` method: on `BeforeEpoch`, steps with the epoch index and
    ///   prints the resulting rate.
    /// * `Batch` method: on `BeforeBatch`, steps with the global batch index
    ///   `epoch * total_batches + batch` (nothing is printed).
    ///
    /// Every other combination is ignored. Always returns `Ok(())`.
    fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
        match self.method {
            ScheduleMethod::Epoch if timing == CallbackTiming::BeforeEpoch => {
                let epoch = context.epoch;
                self.update_lr(epoch);
                println!("Setting learning rate to: {:.6?}", self.current_lr);
            }
            ScheduleMethod::Batch if timing == CallbackTiming::BeforeBatch => {
                // Batches completed in earlier epochs plus the index within this one.
                let batches_before = context.epoch * context.total_batches;
                self.update_lr(batches_before + context.batch);
            }
            _ => {}
        }
        Ok(())
    }
}
/// Learning-rate schedule that lowers the rate when the monitored loss has
/// stopped improving for `patience` consecutive evaluations.
pub struct ReduceOnPlateau<F: Float + Debug + ScalarOperand + NumAssign> {
/// Rate the schedule starts from and returns to on reset.
initial_lr: F,
/// Multiplier applied to the current rate on each reduction.
factor: F,
/// Number of non-improving evaluations that triggers a reduction.
patience: usize,
/// Amount by which the loss must drop below the best value to count as an improvement.
threshold: F,
/// Lower bound the rate is clamped to after every reduction.
min_lr: F,
/// Rate currently in effect.
current_lr: F,
/// `true` to monitor the validation loss, `false` to monitor the epoch (training) loss.
monitor_val_loss: bool,
/// Best monitored value seen so far; `None` until the first value is observed.
best_value: Option<F>,
/// Non-improving evaluations counted since the last improvement or reduction.
patience_counter: usize,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> ReduceOnPlateau<F> {
    /// Builds a plateau-based schedule that monitors the validation loss.
    ///
    /// The current rate starts at `initial_lr`; no best value is recorded yet
    /// and the patience counter starts at zero.
    pub fn new(initial_lr: F, factor: F, patience: usize, threshold: F, min_lr: F) -> Self {
        Self {
            current_lr: initial_lr,
            best_value: None,
            patience_counter: 0,
            monitor_val_loss: true,
            initial_lr,
            factor,
            patience,
            threshold,
            min_lr,
        }
    }

    /// Switches the schedule to monitor the epoch (training) loss instead of
    /// the validation loss. Consumes and returns the schedule, builder-style.
    pub fn monitor_train_loss(self) -> Self {
        Self {
            monitor_val_loss: false,
            ..self
        }
    }

    /// Forgets the best value, clears the patience counter and restores the
    /// current rate to the initial one.
    pub fn reset(&mut self) {
        self.current_lr = self.initial_lr;
        self.patience_counter = 0;
        self.best_value = None;
    }

    /// Returns the learning rate currently in effect.
    pub fn get_current_lr(&self) -> F {
        self.current_lr
    }
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Callback<F> for ReduceOnPlateau<F> {
fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
if timing == CallbackTiming::AfterEpoch {
let current_value = if self.monitor_val_loss {
context.val_loss
} else {
context.epoch_loss
};
if let Some(current) = current_value {
match self.best_value {
None => {
self.best_value = Some(current);
self.patience_counter = 0;
}
Some(best) => {
if current < best - self.threshold {
self.best_value = Some(current);
self.patience_counter = 0;
} else {
self.patience_counter += 1;
if self.patience_counter >= self.patience {
self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
self.patience_counter = 0;
println!(
"ReduceOnPlateau: No improvement detected, reducing learning rate to {:.6?}",
self.current_lr
);
}
}
}
}
}
}
Ok(())
}
}
/// Learning-rate schedule that follows a cosine curve from `max_lr` towards
/// `min_lr` and restarts every `cycle_epochs` steps.
pub struct CosineAnnealingLR<F: Float + Debug + ScalarOperand + NumAssign> {
/// Rate at the start of every cycle.
max_lr: F,
/// Rate the cosine curve descends towards within a cycle.
min_lr: F,
/// Length of one cycle, in steps.
cycle_epochs: usize,
/// Rate most recently computed by the schedule.
current_lr: F,
/// Value passed to `new`; it is stored but not read by the code in this section.
// NOTE(review): presumably the total number of training steps — confirm against callers.
pub total_steps: usize,
/// Marker for the float type parameter.
_phantom: PhantomData<F>,
}
impl<F: Float + Debug + ScalarOperand + NumAssign> CosineAnnealingLR<F> {
    /// Builds a cosine-annealing schedule whose current rate starts at `max_lr`.
    ///
    /// `cycle_epochs` is the length of one cycle in steps. A value of `0` is
    /// accepted and makes [`calculate_lr`](Self::calculate_lr) return `max_lr`
    /// for every step. `total_steps` is stored unchanged.
    pub fn new(max_lr: F, min_lr: F, cycle_epochs: usize, total_steps: usize) -> Self {
        Self {
            max_lr,
            min_lr,
            cycle_epochs,
            current_lr: max_lr,
            total_steps,
            _phantom: PhantomData,
        }
    }

    /// Returns the rate at the start of a cycle (`max_lr`).
    pub fn get_initial_lr(&self) -> F {
        self.max_lr
    }

    /// Computes the learning rate for `step`.
    ///
    /// The position inside the current cycle is `step % cycle_epochs`; the
    /// result is `min_lr + (max_lr - min_lr) * (1 + cos(pi * position / cycle_epochs)) / 2`,
    /// so each cycle starts at `max_lr` and restarts when `step` is a multiple
    /// of `cycle_epochs`.
    ///
    /// # Panics
    ///
    /// Panics if `F` cannot represent the integer or constant operands
    /// (the `F::from` conversions return `None`).
    pub fn calculate_lr(&self, step: usize) -> F {
        // `step % 0` would panic. A zero-length cycle has no curve to follow,
        // so stay at `max_lr` — the same value a cycle of length 1 produces.
        if self.cycle_epochs == 0 {
            return self.max_lr;
        }
        let cycle = step % self.cycle_epochs;
        // Fraction of the current cycle already completed, in [0, 1).
        let percent = F::from(cycle).expect("Failed to convert to float")
            / F::from(self.cycle_epochs).expect("Failed to convert to float");
        let cosine = (F::one()
            + (percent * F::from(std::f64::consts::PI).expect("Failed to convert to float")).cos())
            / F::from(2.0).expect("Failed to convert constant to float");
        self.min_lr + (self.max_lr - self.min_lr) * cosine
    }
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Callback<F> for CosineAnnealingLR<F> {
    /// Recomputes the rate from the epoch index before each epoch and prints it.
    ///
    /// All other events are ignored. Always returns `Ok(())`.
    fn on_event(&mut self, timing: CallbackTiming, context: &mut CallbackContext<F>) -> Result<()> {
        if timing != CallbackTiming::BeforeEpoch {
            return Ok(());
        }
        let epoch = context.epoch;
        let lr = self.calculate_lr(epoch);
        self.current_lr = lr;
        println!("Cosine annealing LR: {:.6?}", self.current_lr);
        Ok(())
    }
}