use crate::callbacks::LearningRateScheduler;
use crate::error::Result;
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::NumAssign;
use std::fmt::Debug;
/// Optimizer wrapper that drives the inner optimizer's learning rate
/// from a [`LearningRateScheduler`] on every `update` call.
///
/// Progress is reported to the scheduler as `step / total_steps`
/// (see the `Optimizer::update` impl), so `total_steps` should be the
/// planned number of training steps.
pub struct LRSchedulerOptimizer<O, S, F>
where
F: Float + Debug + ScalarOperand + NumAssign,
O: Optimizer<F>,
S: LearningRateScheduler<F>,
{
// Wrapped optimizer whose learning rate is overwritten each step.
optimizer: O,
// Schedule queried with the current training progress in [0, 1].
scheduler: S,
// Count of successful `update` calls so far.
step: usize,
// Denominator used to normalize `step` into a progress fraction.
total_steps: usize,
// Marks logical ownership of the float type `F` (not stored directly).
_phantom: std::marker::PhantomData<F>,
}
impl<O, S, F> LRSchedulerOptimizer<O, S, F>
where
F: Float + Debug + ScalarOperand + NumAssign,
O: Optimizer<F>,
S: LearningRateScheduler<F>,
{
/// Wraps `optimizer` so that `scheduler` controls its learning rate.
///
/// `total_steps` is the step count used to normalize progress when
/// querying the scheduler; the internal step counter starts at zero.
pub fn new(optimizer: O, scheduler: S, total_steps: usize) -> Self {
let _phantom = std::marker::PhantomData;
Self {
optimizer,
scheduler,
step: 0,
total_steps,
_phantom,
}
}
/// Shared reference to the wrapped optimizer.
pub fn optimizer(&self) -> &O {
&self.optimizer
}
/// Exclusive reference to the wrapped optimizer.
pub fn optimizer_mut(&mut self) -> &mut O {
&mut self.optimizer
}
/// Shared reference to the learning-rate scheduler.
pub fn scheduler(&self) -> &S {
&self.scheduler
}
/// Exclusive reference to the learning-rate scheduler.
pub fn scheduler_mut(&mut self) -> &mut S {
&mut self.scheduler
}
/// Rewinds the step counter to zero and resets the scheduler's state.
pub fn reset(&mut self) {
self.scheduler.reset();
self.step = 0;
}
/// Overrides the current step counter (e.g. when resuming training).
pub fn set_step(&mut self, step: usize) {
self.step = step;
}
/// Number of successful update steps taken so far.
pub fn get_step(&self) -> usize {
self.step
}
}
impl<O, S, F> Optimizer<F> for LRSchedulerOptimizer<O, S, F>
where
F: Float + Debug + ScalarOperand + NumAssign,
O: Optimizer<F>,
S: LearningRateScheduler<F>,
{
/// Queries the scheduler at the current progress, applies the returned
/// learning rate to the inner optimizer, then delegates the parameter
/// update. The step counter advances only when the inner update succeeds.
fn update(
&mut self,
params: &mut [Array<F, scirs2_core::ndarray::IxDyn>],
grads: &[Array<F, scirs2_core::ndarray::IxDyn>],
) -> Result<()> {
// Fraction of training completed; guard against dividing by zero.
// NOTE(review): progress can exceed 1.0 once `step > total_steps`
// — confirm downstream schedulers tolerate values outside [0, 1].
let progress = match self.total_steps {
0 => 0.0,
total => self.step as f64 / total as f64,
};
let lr = self.scheduler.get_learning_rate(progress)?;
self.optimizer.set_learning_rate(lr);
match self.optimizer.update(params, grads) {
Ok(()) => {
self.step += 1;
Ok(())
}
Err(err) => Err(err),
}
}
/// Learning rate currently set on the inner optimizer.
fn get_learning_rate(&self) -> F {
self.optimizer.get_learning_rate()
}
/// Sets the inner optimizer's learning rate directly. Note that the
/// next `update` call overwrites it with the scheduler's value.
fn set_learning_rate(&mut self, lr: F) {
self.optimizer.set_learning_rate(lr);
}
fn name(&self) -> &'static str {
"LRSchedulerOptimizer"
}
}
#[allow(dead_code)]
pub fn with_step_decay<O, F>(
optimizer: O,
initial_lr: F,
factor: F,
step_size: usize,
min_lr: F,
total_steps: usize,
) -> LRSchedulerOptimizer<O, crate::callbacks::StepDecay<F>, F>
where
F: Float + Debug + ScalarOperand + NumAssign,
O: Optimizer<F>,
{
let scheduler = crate::callbacks::StepDecay::new(
initial_lr,
factor,
step_size,
crate::callbacks::ScheduleMethod::Epoch,
min_lr,
);
LRSchedulerOptimizer::new(optimizer, scheduler, total_steps)
}
#[allow(dead_code)]
pub fn with_cosine_annealing<O, F>(
optimizer: O,
max_lr: F,
cycle_epochs: usize,
total_steps: usize,
) -> LRSchedulerOptimizer<O, crate::callbacks::CosineAnnealingLR<F>, F>
where
F: Float + Debug + ScalarOperand + NumAssign,
O: Optimizer<F>,
{
let min_lr = F::zero();
let scheduler =
crate::callbacks::CosineAnnealingLR::new(max_lr, min_lr, cycle_epochs, total_steps);
LRSchedulerOptimizer::new(optimizer, scheduler, total_steps)
}