use crate::optim::Optimizer;
use num_traits::Float;
use std::fmt::Debug;
/// Common interface implemented by every learning-rate scheduler in this file,
/// generic over the floating-point type `T`.
pub trait LRScheduler<T: Float> {
/// Returns the learning rates the scheduler currently holds.
fn get_lr(&self) -> Vec<T>;
/// Advances the scheduler by one step.
fn step(&mut self);
/// Advances the scheduler while supplying an observed metric.
///
/// The default implementation ignores `_metric` and simply calls `step`.
fn step_with_metric(&mut self, _metric: T) {
self.step();
}
/// Returns the scheduler's epoch counter.
fn last_epoch(&self) -> i32;
/// Captures the scheduler's state as a `SchedulerState` snapshot.
fn state_dict(&self) -> SchedulerState<T>;
/// Restores the scheduler from a `SchedulerState` snapshot.
fn load_state_dict(&mut self, state: SchedulerState<T>);
}
/// Snapshot of a scheduler's state, produced by `LRScheduler::state_dict` and
/// consumed by `LRScheduler::load_state_dict`.
#[derive(Debug, Clone)]
pub struct SchedulerState<T: Float> {
/// Epoch counter of the scheduler when the snapshot was taken.
pub last_epoch: i32,
/// Base learning rates held by the scheduler when the snapshot was taken.
pub base_lrs: Vec<T>,
/// Step counter.
/// NOTE(review): every scheduler visible in this file writes `0` here and none reads it back.
pub step_count: usize,
/// Best metric observed so far; only `ReduceLROnPlateau` fills this in, the
/// other schedulers in this file store `None`.
pub best_metric: Option<T>,
/// Count of consecutive non-improving epochs; only `ReduceLROnPlateau` fills
/// this in, the other schedulers in this file store `0`.
pub num_bad_epochs: usize,
/// Remaining cooldown epochs; only `ReduceLROnPlateau` fills this in, the
/// other schedulers in this file store `0`.
pub cooldown_counter: usize,
}
/// Step-decay scheduler: scales each base learning rate by `gamma` raised to a
/// whole-number exponent that grows as epochs pass in units of `step_size`.
#[derive(Debug)]
pub struct StepLR<T: Float> {
/// Number of epochs between successive decays.
step_size: usize,
/// Multiplicative decay factor.
gamma: T,
/// Epoch counter; `-1` when the scheduler is built without an explicit value.
last_epoch: i32,
/// Undecayed learning rates the schedule is computed from.
base_lrs: Vec<T>,
/// Learning rates handed out by `get_lr`.
current_lrs: Vec<T>,
}
impl<T: Float + Copy + From<f32>> StepLR<T> {
    /// Creates a step-decay scheduler.
    ///
    /// * `_optimizer` - currently unused.
    /// * `step_size` - number of epochs between successive decays.
    /// * `gamma` - multiplicative decay factor.
    /// * `last_epoch` - starting epoch counter; `None` means `-1`.
    ///
    /// NOTE(review): the base learning rate is fixed at `0.01` instead of being
    /// read from `_optimizer` — confirm whether this placeholder is intended.
    pub fn new(
        _optimizer: &mut dyn Optimizer<T>,
        step_size: usize,
        gamma: T,
        last_epoch: Option<i32>,
    ) -> Self {
        let last_epoch = last_epoch.unwrap_or(-1);
        let base_lrs = vec![<T as From<f32>>::from(0.01f32)];
        let current_lrs = base_lrs.clone();
        StepLR {
            step_size,
            gamma,
            last_epoch,
            base_lrs,
            current_lrs,
        }
    }

    /// Returns `base_lr * gamma^floor((epoch + 1) / step_size)`.
    ///
    /// The quotient is computed with exact integer floor-division rather than
    /// in `f32`, so it stays correct for epoch counts above 2^24 and cannot
    /// overflow when `epoch` is `i32::MAX`. A `step_size` of `0` is treated as
    /// `1` (decay every epoch) instead of dividing by zero.
    fn calculate_lr(&self, base_lr: T, epoch: i32) -> T {
        // Widened so that `epoch + 1` cannot overflow.
        let elapsed = i128::from(epoch) + 1;
        // A `usize` always fits in `i128`, so this cast is lossless.
        let period = self.step_size.max(1) as i128;
        // With a positive divisor, Euclidean division is floor division.
        let decays = elapsed.div_euclid(period);
        // `decays` lies in [-(2^31 - 1), 2^31]; only the top value needs clamping.
        let exponent = decays.min(i128::from(i32::MAX)) as i32;
        base_lr * self.gamma.powi(exponent)
    }
}
impl<T: Float + Copy + From<f32>> LRScheduler<T> for StepLR<T> {
    /// Returns a copy of the learning rates currently held by the scheduler.
    fn get_lr(&self) -> Vec<T> {
        self.current_lrs.clone()
    }

    /// Advances the epoch counter by one and recomputes every learning rate
    /// from its base rate.
    fn step(&mut self) {
        self.last_epoch += 1;
        let epoch = self.last_epoch;
        self.current_lrs = self
            .base_lrs
            .iter()
            .map(|&base_lr| self.calculate_lr(base_lr, epoch))
            .collect();
    }

    /// Returns the epoch counter.
    fn last_epoch(&self) -> i32 {
        self.last_epoch
    }

    /// Snapshots the epoch counter and base learning rates; the remaining
    /// `SchedulerState` fields are not used by this scheduler and are zeroed.
    fn state_dict(&self) -> SchedulerState<T> {
        SchedulerState {
            last_epoch: self.last_epoch,
            base_lrs: self.base_lrs.clone(),
            step_count: 0,
            best_metric: None,
            num_bad_epochs: 0,
            cooldown_counter: 0,
        }
    }

    /// Restores the epoch counter and base learning rates from `state`.
    ///
    /// The snapshot does not carry the current learning rates, so they are
    /// recomputed here; otherwise `get_lr` would keep returning the values
    /// from before the load until the next `step`.
    fn load_state_dict(&mut self, state: SchedulerState<T>) {
        self.last_epoch = state.last_epoch;
        self.base_lrs = state.base_lrs;
        let epoch = self.last_epoch;
        self.current_lrs = self
            .base_lrs
            .iter()
            .map(|&base_lr| self.calculate_lr(base_lr, epoch))
            .collect();
    }
}
/// Exponential-decay scheduler: scales each base learning rate by a power of
/// `gamma` that grows by one with every step.
#[derive(Debug)]
pub struct ExponentialLR<T: Float> {
/// Multiplicative decay factor applied once per step.
gamma: T,
/// Epoch counter; `-1` when the scheduler is built without an explicit value.
last_epoch: i32,
/// Undecayed learning rates the schedule is computed from.
base_lrs: Vec<T>,
/// Learning rates handed out by `get_lr`.
current_lrs: Vec<T>,
}
impl<T: Float + Copy + From<f32>> ExponentialLR<T> {
    /// Builds an exponential-decay scheduler.
    ///
    /// * `_optimizer` - currently unused.
    /// * `gamma` - multiplicative decay factor.
    /// * `last_epoch` - starting epoch counter; `None` means `-1`.
    ///
    /// NOTE(review): the base learning rate is fixed at `0.01` instead of being
    /// read from `_optimizer` — confirm whether this placeholder is intended.
    pub fn new(
        _optimizer: &mut dyn Optimizer<T>,
        gamma: T,
        last_epoch: Option<i32>,
    ) -> Self {
        let initial_lr: T = <T as From<f32>>::from(0.01f32);
        let base_lrs = vec![initial_lr];
        Self {
            gamma,
            last_epoch: last_epoch.unwrap_or(-1),
            // The current rates start out as a copy of the base rates.
            current_lrs: base_lrs.clone(),
            base_lrs,
        }
    }
}
impl<T: Float + Copy + From<f32>> LRScheduler<T> for ExponentialLR<T> {
    /// Returns a copy of the learning rates currently held by the scheduler.
    fn get_lr(&self) -> Vec<T> {
        self.current_lrs.clone()
    }

    /// Advances the epoch counter by one and sets every learning rate to
    /// `base_lr * gamma^(last_epoch + 1)`.
    fn step(&mut self) {
        self.last_epoch += 1;
        let decay = self.gamma.powi(self.last_epoch + 1);
        self.current_lrs = self.base_lrs.iter().map(|&base_lr| base_lr * decay).collect();
    }

    /// Returns the epoch counter.
    fn last_epoch(&self) -> i32 {
        self.last_epoch
    }

    /// Snapshots the epoch counter and base learning rates; the remaining
    /// `SchedulerState` fields are not used by this scheduler and are zeroed.
    fn state_dict(&self) -> SchedulerState<T> {
        SchedulerState {
            last_epoch: self.last_epoch,
            base_lrs: self.base_lrs.clone(),
            step_count: 0,
            best_metric: None,
            num_bad_epochs: 0,
            cooldown_counter: 0,
        }
    }

    /// Restores the epoch counter and base learning rates from `state`.
    ///
    /// The snapshot does not carry the current learning rates, so they are
    /// recomputed here with the same formula `step` uses; otherwise `get_lr`
    /// would keep returning the values from before the load.
    fn load_state_dict(&mut self, state: SchedulerState<T>) {
        self.last_epoch = state.last_epoch;
        self.base_lrs = state.base_lrs;
        let decay = self.gamma.powi(self.last_epoch + 1);
        self.current_lrs = self.base_lrs.iter().map(|&base_lr| base_lr * decay).collect();
    }
}
/// Cosine-annealing scheduler: interpolates each learning rate between its base
/// value and `eta_min` along a cosine of `pi * epoch / t_max`.
#[derive(Debug)]
pub struct CosineAnnealingLR<T: Float> {
/// Divisor of the epoch inside the cosine term.
t_max: usize,
/// Endpoint the cosine interpolates toward; zero when not supplied to `new`.
eta_min: T,
/// Epoch counter; `-1` when the scheduler is built without an explicit value.
last_epoch: i32,
/// Learning rates the schedule starts from.
base_lrs: Vec<T>,
/// Learning rates handed out by `get_lr`.
current_lrs: Vec<T>,
}
impl<T: Float + Copy + From<f32>> CosineAnnealingLR<T> {
    /// Creates a cosine-annealing scheduler.
    ///
    /// * `_optimizer` - currently unused.
    /// * `t_max` - divisor of the epoch inside the cosine term.
    /// * `eta_min` - endpoint of the interpolation; `None` means zero.
    /// * `last_epoch` - starting epoch counter; `None` means `-1`.
    ///
    /// NOTE(review): the base learning rate is fixed at `0.01` instead of being
    /// read from `_optimizer` — confirm whether this placeholder is intended.
    pub fn new(
        _optimizer: &mut dyn Optimizer<T>,
        t_max: usize,
        eta_min: Option<T>,
        last_epoch: Option<i32>,
    ) -> Self {
        let last_epoch = last_epoch.unwrap_or(-1);
        let eta_min = eta_min.unwrap_or_else(|| <T as From<f32>>::from(0.0f32));
        let base_lrs = vec![<T as From<f32>>::from(0.01f32)];
        let current_lrs = base_lrs.clone();
        CosineAnnealingLR {
            t_max,
            eta_min,
            last_epoch,
            base_lrs,
            current_lrs,
        }
    }

    /// Returns `eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2`.
    ///
    /// Negative epochs return `base_lr` unchanged. A `t_max` of `0` returns
    /// `eta_min` directly, because the cosine argument would otherwise be a
    /// division by zero and the resulting rate would be NaN.
    ///
    /// NOTE(review): unlike `StepLR` and `ExponentialLR`, which use
    /// `epoch + 1`, this formula uses `epoch` as-is, so the first `step` leaves
    /// the rate at `base_lr` — confirm the offset is intended.
    fn calculate_lr(&self, base_lr: T, epoch: i32) -> T {
        if epoch < 0 {
            return base_lr;
        }
        if self.t_max == 0 {
            return self.eta_min;
        }
        let t_cur = epoch as f32;
        let t_max = self.t_max as f32;
        let pi = std::f32::consts::PI;
        let cosine_factor = <T as From<f32>>::from((1.0 + (pi * t_cur / t_max).cos()) / 2.0);
        self.eta_min + (base_lr - self.eta_min) * cosine_factor
    }
}
impl<T: Float + Copy + From<f32>> LRScheduler<T> for CosineAnnealingLR<T> {
    /// Returns a copy of the learning rates currently held by the scheduler.
    fn get_lr(&self) -> Vec<T> {
        self.current_lrs.clone()
    }

    /// Advances the epoch counter by one and recomputes every learning rate
    /// from its base rate.
    fn step(&mut self) {
        self.last_epoch += 1;
        let epoch = self.last_epoch;
        self.current_lrs = self
            .base_lrs
            .iter()
            .map(|&base_lr| self.calculate_lr(base_lr, epoch))
            .collect();
    }

    /// Returns the epoch counter.
    fn last_epoch(&self) -> i32 {
        self.last_epoch
    }

    /// Snapshots the epoch counter and base learning rates; the remaining
    /// `SchedulerState` fields are not used by this scheduler and are zeroed.
    fn state_dict(&self) -> SchedulerState<T> {
        SchedulerState {
            last_epoch: self.last_epoch,
            base_lrs: self.base_lrs.clone(),
            step_count: 0,
            best_metric: None,
            num_bad_epochs: 0,
            cooldown_counter: 0,
        }
    }

    /// Restores the epoch counter and base learning rates from `state`.
    ///
    /// The snapshot does not carry the current learning rates, so they are
    /// recomputed here; otherwise `get_lr` would keep returning the values
    /// from before the load until the next `step`.
    fn load_state_dict(&mut self, state: SchedulerState<T>) {
        self.last_epoch = state.last_epoch;
        self.base_lrs = state.base_lrs;
        let epoch = self.last_epoch;
        self.current_lrs = self
            .base_lrs
            .iter()
            .map(|&base_lr| self.calculate_lr(base_lr, epoch))
            .collect();
    }
}
/// Metric-driven scheduler: multiplies the learning rates by `factor` once the
/// observed metric has failed to improve for more than `patience` epochs.
#[derive(Debug)]
pub struct ReduceLROnPlateau<T: Float> {
/// Whether a smaller (`Min`) or larger (`Max`) metric counts as better.
mode: PlateauMode,
/// Multiplier applied to each learning rate on a reduction.
factor: T,
/// A reduction happens once the bad-epoch count exceeds this value.
patience: usize,
/// Margin by which a metric must beat the best value to count as an improvement.
threshold: T,
/// Whether `threshold` is applied relative to the best value (`Rel`) or as an absolute offset (`Abs`).
threshold_mode: ThresholdMode,
/// Value the cooldown counter is reset to after a reduction.
cooldown: usize,
/// Lower bound applied to every reduced learning rate.
min_lr: T,
/// A reduction is applied only when it changes the rate by more than this amount.
eps: T,
/// Epoch counter; starts at `0`.
last_epoch: i32,
/// Overwritten with the reduced rate whenever a reduction is applied.
base_lrs: Vec<T>,
/// Learning rates handed out by `get_lr`.
current_lrs: Vec<T>,
/// Best metric observed so far; `None` until the first metric arrives.
best_metric: Option<T>,
/// Number of non-improving epochs counted since the last improvement or reduction.
num_bad_epochs: usize,
/// Remaining cooldown epochs.
cooldown_counter: usize,
}
/// Direction in which the monitored metric is considered to improve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlateauMode {
    /// A smaller metric is better.
    Min,
    /// A larger metric is better.
    Max,
}
/// How the improvement threshold is combined with the best metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThresholdMode {
    /// The threshold scales the best metric (`best * (1 ± threshold)`).
    Rel,
    /// The threshold is added to or subtracted from the best metric.
    Abs,
}
impl<T: Float + Copy + From<f32>> ReduceLROnPlateau<T> {
    /// Builds a plateau scheduler, substituting a default for every `None`.
    ///
    /// Defaults: `mode = Min`, `factor = 0.1`, `patience = 10`,
    /// `threshold = 1e-4`, `threshold_mode = Rel`, `cooldown = 0`,
    /// `min_lr = 0.0`, `eps = 1e-8`. `_optimizer` is currently unused.
    ///
    /// NOTE(review): the base learning rate is fixed at `0.01` instead of being
    /// read from `_optimizer` — confirm whether this placeholder is intended.
    pub fn new(
        _optimizer: &mut dyn Optimizer<T>,
        mode: Option<PlateauMode>,
        factor: Option<T>,
        patience: Option<usize>,
        threshold: Option<T>,
        threshold_mode: Option<ThresholdMode>,
        cooldown: Option<usize>,
        min_lr: Option<T>,
        eps: Option<T>,
    ) -> Self {
        // Single place where `f32` literals are lifted into `T`.
        let lift = |value: f32| -> T { <T as From<f32>>::from(value) };

        let starting_lrs = vec![lift(0.01f32)];

        Self {
            mode: mode.unwrap_or(PlateauMode::Min),
            factor: factor.unwrap_or_else(|| lift(0.1f32)),
            patience: patience.unwrap_or(10),
            threshold: threshold.unwrap_or_else(|| lift(1e-4f32)),
            threshold_mode: threshold_mode.unwrap_or(ThresholdMode::Rel),
            cooldown: cooldown.unwrap_or(0),
            min_lr: min_lr.unwrap_or_else(|| lift(0.0f32)),
            eps: eps.unwrap_or_else(|| lift(1e-8f32)),
            last_epoch: 0,
            current_lrs: starting_lrs.clone(),
            base_lrs: starting_lrs,
            best_metric: None,
            num_bad_epochs: 0,
            cooldown_counter: 0,
        }
    }

    /// Reports whether `current` beats `best` by more than the configured
    /// threshold, in the direction selected by `mode`.
    fn is_better(&self, current: T, best: T) -> bool {
        match self.mode {
            PlateauMode::Min => {
                // The metric has to drop strictly below this bound.
                let bound = match self.threshold_mode {
                    ThresholdMode::Rel => best * (T::one() - self.threshold),
                    ThresholdMode::Abs => best - self.threshold,
                };
                current < bound
            }
            PlateauMode::Max => {
                // The metric has to rise strictly above this bound.
                let bound = match self.threshold_mode {
                    ThresholdMode::Rel => best * (T::one() + self.threshold),
                    ThresholdMode::Abs => best + self.threshold,
                };
                current > bound
            }
        }
    }
}
impl<T: Float + Copy + From<f32>> LRScheduler<T> for ReduceLROnPlateau<T> {
    /// Returns a copy of the learning rates currently held by the scheduler.
    fn get_lr(&self) -> Vec<T> {
        self.current_lrs.clone()
    }

    /// Does nothing: this scheduler changes its rates only in response to a
    /// metric, via `step_with_metric`.
    fn step(&mut self) {}

    /// Records one epoch's `metric` and reduces the learning rates when the
    /// metric has failed to improve for more than `patience` epochs.
    ///
    /// The improvement check runs before the cooldown bookkeeping, so a bad
    /// epoch observed while the cooldown counter is still positive is
    /// discarded. With the opposite order the bad epoch seen on the last
    /// cooldown step would be counted, making `cooldown = 1` behave exactly
    /// like `cooldown = 0`.
    fn step_with_metric(&mut self, metric: T) {
        self.last_epoch += 1;

        match self.best_metric {
            Some(best) if !self.is_better(metric, best) => self.num_bad_epochs += 1,
            // First metric ever seen, or a genuine improvement.
            _ => {
                self.best_metric = Some(metric);
                self.num_bad_epochs = 0;
            }
        }

        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            // Bad epochs observed during cooldown do not count.
            self.num_bad_epochs = 0;
        }

        // During cooldown the counter above is zero, so this cannot fire.
        if self.num_bad_epochs > self.patience {
            self.reduce_lr();
            self.cooldown_counter = self.cooldown;
            self.num_bad_epochs = 0;
        }
    }

    /// Returns the number of metrics processed since construction (or the
    /// value restored by `load_state_dict` plus those processed afterwards).
    fn last_epoch(&self) -> i32 {
        self.last_epoch
    }

    /// Snapshots the epoch counter, learning rates, best metric, bad-epoch
    /// count and cooldown counter. `step_count` is not used and is zeroed.
    fn state_dict(&self) -> SchedulerState<T> {
        SchedulerState {
            last_epoch: self.last_epoch,
            base_lrs: self.base_lrs.clone(),
            step_count: 0,
            best_metric: self.best_metric,
            num_bad_epochs: self.num_bad_epochs,
            cooldown_counter: self.cooldown_counter,
        }
    }

    /// Restores the scheduler from `state`.
    ///
    /// `reduce_lr` writes every reduced rate into both `current_lrs` and
    /// `base_lrs`, so the snapshot's `base_lrs` holds the live rates. They are
    /// restored into both vectors so that `get_lr` reflects the loaded state
    /// and the two vectors keep the same length.
    fn load_state_dict(&mut self, state: SchedulerState<T>) {
        self.last_epoch = state.last_epoch;
        self.current_lrs = state.base_lrs.clone();
        self.base_lrs = state.base_lrs;
        self.best_metric = state.best_metric;
        self.num_bad_epochs = state.num_bad_epochs;
        self.cooldown_counter = state.cooldown_counter;
    }
}
impl<T: Float + Copy + From<f32>> ReduceLROnPlateau<T> {
    /// Multiplies every current learning rate by `factor`, clamps it to at
    /// least `min_lr`, and applies the result only when it differs from the
    /// old rate by more than `eps`.
    ///
    /// Each applied rate is mirrored into the matching `base_lrs` slot when
    /// that slot exists. Indexing `base_lrs` directly would panic if it were
    /// shorter than `current_lrs`, which can happen after `load_state_dict`
    /// replaces `base_lrs`.
    fn reduce_lr(&mut self) {
        let (factor, floor, eps) = (self.factor, self.min_lr, self.eps);
        for (index, current) in self.current_lrs.iter_mut().enumerate() {
            let reduced = (*current * factor).max(floor);
            if (*current - reduced).abs() > eps {
                *current = reduced;
                if let Some(base) = self.base_lrs.get_mut(index) {
                    *base = reduced;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::SGD;
    use crate::autograd::Variable;
    use crate::tensor::Tensor;

    /// A freshly built `StepLR` keeps the default epoch and the given step size.
    #[test]
    fn test_step_lr_creation() {
        let params = vec![Variable::new(Tensor::ones(&[2, 2]), true)];
        let mut optimizer = SGD::new(params, 0.01, Some(0.9), None, None, None);
        let scheduler = StepLR::new(&mut optimizer, 30, 0.1, None);
        assert_eq!(scheduler.last_epoch(), -1);
        assert_eq!(scheduler.step_size, 30);
    }

    /// The rate is unchanged before a full `step_size` has elapsed and decays
    /// once it has.
    #[test]
    fn test_step_lr_calculation() {
        let params = vec![Variable::new(Tensor::ones(&[2, 2]), true)];
        let mut optimizer = SGD::new(params, 0.01, Some(0.9), None, None, None);
        let mut scheduler = StepLR::new(&mut optimizer, 30, 0.1, None);
        let initial_lrs = scheduler.get_lr();
        assert_eq!(initial_lrs.len(), 1);

        scheduler.step();
        assert_eq!(scheduler.last_epoch(), 0);
        // One step out of thirty: the exponent is still zero.
        assert_eq!(scheduler.get_lr()[0], initial_lrs[0]);

        for _ in 0..29 {
            scheduler.step();
        }
        assert_eq!(scheduler.last_epoch(), 29);
        // Thirty steps completed: the first decay has been applied.
        assert!(scheduler.get_lr()[0] < initial_lrs[0]);
    }

    /// A single step already lowers the rate when `gamma < 1`.
    #[test]
    fn test_exponential_lr() {
        let params = vec![Variable::new(Tensor::ones(&[2, 2]), true)];
        let mut optimizer = SGD::new(params, 0.01, Some(0.9), None, None, None);
        let mut scheduler = ExponentialLR::new(&mut optimizer, 0.9, None);
        let initial_lrs = scheduler.get_lr();
        scheduler.step();
        let step1_lrs = scheduler.get_lr();
        assert!(step1_lrs[0] < initial_lrs[0]);
    }

    /// Halfway through the cosine period the rate is below its starting value.
    #[test]
    fn test_cosine_annealing_lr() {
        let params = vec![Variable::new(Tensor::ones(&[2, 2]), true)];
        let mut optimizer = SGD::new(params, 0.01, Some(0.9), None, None, None);
        let mut scheduler = CosineAnnealingLR::new(&mut optimizer, 100, None, None);
        let initial_lrs = scheduler.get_lr();
        for _ in 0..50 {
            scheduler.step();
        }
        let mid_lrs = scheduler.get_lr();
        assert!(mid_lrs[0] < initial_lrs[0]);
    }

    /// With `patience = 2` the rate holds through two bad epochs and is
    /// reduced on the third.
    #[test]
    fn test_reduce_lr_on_plateau() {
        let params = vec![Variable::new(Tensor::ones(&[2, 2]), true)];
        let mut optimizer = SGD::new(params, 0.01, Some(0.9), None, None, None);
        let mut scheduler = ReduceLROnPlateau::new(
            &mut optimizer,
            Some(PlateauMode::Min),
            Some(0.5),
            Some(2),
            None,
            None,
            None,
            None,
            None,
        );
        let initial_lrs = scheduler.get_lr();

        scheduler.step_with_metric(1.0);
        scheduler.step_with_metric(1.1);
        scheduler.step_with_metric(1.2);
        // Two bad epochs do not exceed the patience of two.
        assert_eq!(scheduler.num_bad_epochs, 2);
        assert_eq!(scheduler.get_lr()[0], initial_lrs[0]);

        scheduler.step_with_metric(1.3);
        let reduced_lrs = scheduler.get_lr();
        assert!(reduced_lrs[0] < initial_lrs[0]);
        assert_eq!(scheduler.num_bad_epochs, 0);
    }
}