use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::schedulers::{
CosineAnnealing, ExponentialDecay, LearningRateScheduler, LinearDecay, StepDecay,
};
/// Decay behaviour that `LinearWarmupDecay` switches to once its warmup phase is over.
///
/// Each variant carries the strategy-specific parameters that
/// `LinearWarmupDecay::initialize_decay_scheduler` forwards to the matching
/// scheduler from `crate::schedulers`. In every case the decay scheduler is
/// created with the warmup target (`initial_lr`) as its starting rate.
#[derive(Debug, Clone)]
pub enum DecayStrategy<A: Float + Debug> {
/// Decay via `LinearDecay`, configured with `final_lr` and the total decay step count.
Linear {
/// Rate passed to `LinearDecay::new` as the end value; the tests in this
/// file expect it to be reached after `total_decay_steps` decay steps.
final_lr: A,
},
/// Decay via `ExponentialDecay`, configured with `decay_rate` and the total decay step count.
Exponential {
/// Rate parameter forwarded unchanged to `ExponentialDecay::new`.
decay_rate: A,
},
/// Decay via `StepDecay`; the total decay step count is not used by this strategy.
Step {
/// Factor forwarded to `StepDecay::new`; the tests in this file expect the
/// rate to be multiplied by it (0.1 -> 0.05 -> 0.025 for 0.5).
decay_rate: A,
/// Interval forwarded to `StepDecay::new`; the tests in this file expect
/// one reduction every `step_size` decay steps.
step_size: usize,
},
/// Decay via `CosineAnnealing`, configured with `min_lr` and the total decay step count.
Cosine {
/// Lower rate forwarded to `CosineAnnealing::new` (independent of the
/// warmup start rate stored on `LinearWarmupDecay`).
min_lr: A,
},
/// Hold the rate at `initial_lr`: implemented as a `LinearDecay` whose start
/// and end rates are both `initial_lr`.
Constant,
}
/// Learning-rate scheduler with a linear warmup followed by a configurable decay.
///
/// During the first `warmup_steps` calls to `step` the rate is interpolated
/// linearly from `min_lr` to `initial_lr`. Once warmup completes, a decay
/// scheduler chosen by `decay_strategy` is created and every further `step`
/// is delegated to it.
#[derive(Debug)]
pub struct LinearWarmupDecay<A: Float + Debug> {
/// Rate reached at the end of warmup; also the starting rate of the decay scheduler.
initial_lr: A,
/// Rate reported before the first step and the starting point of the warmup ramp.
min_lr: A,
/// Number of steps in the warmup ramp; zero skips warmup entirely.
warmup_steps: usize,
/// Step count forwarded to the linear, exponential and cosine decay schedulers
/// (also used for the constant strategy; not used by the step strategy).
total_decay_steps: usize,
/// Number of `step` calls since construction or the last `reset`.
step: usize,
/// Rate produced by the most recent `step`, or `min_lr` if none has run yet.
current_lr: A,
/// Decay behaviour to instantiate when warmup completes.
decay_strategy: DecayStrategy<A>,
/// Set once the warmup ramp has finished (or on the first step when `warmup_steps` is zero).
warmup_complete: bool,
/// Decay scheduler created when warmup completes; `None` during warmup and after `reset`.
#[allow(clippy::missing_docs_in_private_items)]
inner_scheduler: Option<InnerScheduler<A>>,
}
/// Concrete decay scheduler owned by `LinearWarmupDecay` after warmup.
///
/// One variant per scheduler type that `LinearWarmupDecay::step` dispatches on.
/// `DecayStrategy::Constant` has no variant of its own: it is stored as
/// `Linear` with identical start and end rates.
#[allow(clippy::missing_docs_in_private_items)]
#[derive(Debug)]
enum InnerScheduler<A: Float + Debug> {
/// Wraps a `LinearDecay` (used for the linear and constant strategies).
Linear(LinearDecay<A>),
/// Wraps an `ExponentialDecay`.
Exponential(ExponentialDecay<A>),
/// Wraps a `StepDecay`.
Step(StepDecay<A>),
/// Wraps a `CosineAnnealing`.
Cosine(CosineAnnealing<A>),
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LinearWarmupDecay<A> {
    /// Creates a scheduler that has not taken any step yet.
    ///
    /// # Arguments
    ///
    /// * `initial_lr` - rate reached at the end of warmup and starting rate of the decay phase.
    /// * `min_lr` - rate reported before the first step; the warmup ramp starts from it.
    /// * `warmup_steps` - number of steps in the linear warmup ramp.
    /// * `total_decay_steps` - step count handed to the decay scheduler
    ///   (ignored by `DecayStrategy::Step`).
    /// * `decay_strategy` - decay behaviour to use once warmup has completed.
    ///
    /// The decay scheduler itself is not built here; it is created when warmup
    /// completes.
    pub fn new(
        initial_lr: A,
        min_lr: A,
        warmup_steps: usize,
        total_decay_steps: usize,
        decay_strategy: DecayStrategy<A>,
    ) -> Self {
        Self {
            // Mutable progress state: nothing has run yet, so the reported
            // rate is the warmup starting point.
            step: 0,
            current_lr: min_lr,
            warmup_complete: false,
            inner_scheduler: None,
            // Fixed configuration.
            initial_lr,
            min_lr,
            warmup_steps,
            total_decay_steps,
            decay_strategy,
        }
    }

    /// Builds the decay scheduler selected by `decay_strategy` and stores it in
    /// `inner_scheduler`, replacing whatever was there before.
    ///
    /// Every decay scheduler starts from `initial_lr`, i.e. the rate the warmup
    /// ramp ends on.
    fn initialize_decay_scheduler(&mut self) {
        let start_lr = self.initial_lr;
        let decay_steps = self.total_decay_steps;

        self.inner_scheduler = Some(match self.decay_strategy {
            DecayStrategy::Linear { final_lr } => {
                InnerScheduler::Linear(LinearDecay::new(start_lr, final_lr, decay_steps))
            }
            DecayStrategy::Exponential { decay_rate } => {
                InnerScheduler::Exponential(ExponentialDecay::new(start_lr, decay_rate, decay_steps))
            }
            // The step strategy is driven by its own `step_size`, not by `total_decay_steps`.
            DecayStrategy::Step {
                decay_rate,
                step_size,
            } => InnerScheduler::Step(StepDecay::new(start_lr, decay_rate, step_size)),
            // NOTE(review): the trailing `false` presumably disables warm restarts —
            // confirm against `CosineAnnealing::new`.
            DecayStrategy::Cosine { min_lr } => {
                InnerScheduler::Cosine(CosineAnnealing::new(start_lr, min_lr, decay_steps, false))
            }
            // A constant rate is expressed as a linear decay whose start and end coincide.
            DecayStrategy::Constant => {
                InnerScheduler::Linear(LinearDecay::new(start_lr, start_lr, decay_steps))
            }
        });
    }
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> LearningRateScheduler<A>
    for LinearWarmupDecay<A>
{
    /// Returns the rate produced by the most recent `step`, or `min_lr` if no
    /// step has been taken since construction or the last `reset`.
    fn get_learning_rate(&self) -> A {
        self.current_lr
    }

    /// Advances the schedule by one step and returns the new learning rate.
    ///
    /// While warming up, the rate is `min_lr + (initial_lr - min_lr) * step / warmup_steps`.
    /// The step on which `step == warmup_steps` yields `initial_lr` and creates
    /// the decay scheduler; every later step is delegated to that scheduler.
    /// With `warmup_steps == 0` the decay scheduler is created on the first
    /// call and stepped immediately.
    ///
    /// # Panics
    ///
    /// Panics if the step counter or `warmup_steps` cannot be converted to `A`.
    fn step(&mut self) -> A {
        self.step += 1;

        // Zero-length warmup: switch to the decay phase right away so that this
        // very call already advances the decay scheduler.
        if self.warmup_steps == 0 && !self.warmup_complete {
            self.warmup_complete = true;
            self.initialize_decay_scheduler();
        }

        if self.warmup_complete {
            if let Some(scheduler) = self.inner_scheduler.as_mut() {
                self.current_lr = match scheduler {
                    InnerScheduler::Linear(s) => s.step(),
                    InnerScheduler::Exponential(s) => s.step(),
                    InnerScheduler::Step(s) => s.step(),
                    InnerScheduler::Cosine(s) => s.step(),
                };
            }
        } else if self.step <= self.warmup_steps {
            // `warmup_steps` is non-zero here: a zero-length warmup was marked
            // complete above, so the division below is always well defined and
            // no fallback value for the progress fraction is needed.
            let progress = A::from(self.step).expect("unwrap failed")
                / A::from(self.warmup_steps).expect("unwrap failed");
            self.current_lr = self.min_lr + (self.initial_lr - self.min_lr) * progress;

            // Last warmup step: prepare the decay scheduler for the next call.
            if self.step == self.warmup_steps {
                self.warmup_complete = true;
                self.initialize_decay_scheduler();
            }
        }

        self.current_lr
    }

    /// Returns the scheduler to its freshly constructed state: step counter at
    /// zero, rate back at `min_lr`, warmup pending and the decay scheduler dropped.
    fn reset(&mut self) {
        self.step = 0;
        self.current_lr = self.min_lr;
        self.warmup_complete = false;
        self.inner_scheduler = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Advances `scheduler` by `count` steps, ignoring the returned rates.
    fn advance(scheduler: &mut LinearWarmupDecay<f64>, count: usize) {
        for _ in 0..count {
            scheduler.step();
        }
    }

    /// Advances `scheduler` by `count` steps and returns the rate produced by each one.
    fn record(scheduler: &mut LinearWarmupDecay<f64>, count: usize) -> Vec<f64> {
        (0..count).map(|_| scheduler.step()).collect()
    }

    /// Warmup rises strictly to the target rate, then linear decay falls strictly to `final_lr`.
    #[test]
    fn test_linear_warmup_linear_decay() {
        let mut scheduler = LinearWarmupDecay::new(
            0.1f64,
            0.01,
            10,
            90,
            DecayStrategy::Linear { final_lr: 0.001 },
        );
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.01);

        let warmup_lrs = record(&mut scheduler, 10);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);
        for pair in warmup_lrs.windows(2) {
            assert!(pair[1] > pair[0]);
        }

        let decay_lrs = record(&mut scheduler, 90);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.001, epsilon = 1e-6);
        for pair in decay_lrs.windows(2) {
            assert!(pair[1] < pair[0]);
        }
    }

    /// After warmup, exponential decay brings the rate well below the warmup target.
    #[test]
    fn test_linear_warmup_exponential_decay() {
        let mut scheduler = LinearWarmupDecay::new(
            0.1f64,
            0.01,
            10,
            10,
            DecayStrategy::Exponential { decay_rate: 0.1 },
        );

        advance(&mut scheduler, 10);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);

        let lrs = record(&mut scheduler, 10);
        let final_lr = *lrs.last().expect("unwrap failed");
        assert!(
            final_lr < 0.05,
            "Final learning rate {:.6} should be significantly less than initial 0.1",
            final_lr
        );
    }

    /// Step decay holds the rate for nine steps and halves it on every tenth.
    #[test]
    fn test_linear_warmup_step_decay() {
        let mut scheduler = LinearWarmupDecay::new(
            0.1f64,
            0.01,
            10,
            40,
            DecayStrategy::Step {
                decay_rate: 0.5,
                step_size: 10,
            },
        );

        advance(&mut scheduler, 10);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);

        // First plateau, then the first drop.
        advance(&mut scheduler, 9);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);
        advance(&mut scheduler, 1);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.05);

        // Second plateau, then the second drop.
        advance(&mut scheduler, 9);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.05);
        advance(&mut scheduler, 1);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.025);
    }

    /// Cosine decay starts dropping immediately and bottoms out near its `min_lr`.
    #[test]
    fn test_linear_warmup_cosine_decay() {
        let mut scheduler = LinearWarmupDecay::new(
            0.1f64,
            0.01,
            10,
            90,
            DecayStrategy::Cosine { min_lr: 0.001 },
        );

        advance(&mut scheduler, 10);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);

        let lrs = record(&mut scheduler, 90);
        assert!(lrs[0] < 0.1);

        let min_lr = lrs.iter().copied().fold(1.0, f64::min);
        assert_abs_diff_eq!(min_lr, 0.001, epsilon = 1e-2);
    }

    /// The constant strategy keeps the warmup target rate for the whole decay phase.
    #[test]
    fn test_linear_warmup_constant() {
        let mut scheduler = LinearWarmupDecay::new(0.1f64, 0.01, 10, 90, DecayStrategy::Constant);

        advance(&mut scheduler, 10);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);

        advance(&mut scheduler, 90);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);
    }

    /// `reset` restores the initial state and the warmup can be replayed afterwards.
    #[test]
    fn test_reset() {
        let mut scheduler = LinearWarmupDecay::new(
            0.1f64,
            0.01,
            5,
            15,
            DecayStrategy::Linear { final_lr: 0.001 },
        );

        advance(&mut scheduler, 15);
        scheduler.reset();

        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.01);
        assert_eq!(scheduler.step, 0);
        assert!(!scheduler.warmup_complete);
        assert!(scheduler.inner_scheduler.is_none());

        advance(&mut scheduler, 5);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.1);
    }

    /// With no warmup, the first step already enters the decay phase.
    #[test]
    fn test_zero_warmup() {
        let mut scheduler = LinearWarmupDecay::new(
            0.1f64,
            0.01,
            0,
            10,
            DecayStrategy::Linear { final_lr: 0.001 },
        );
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.01);

        advance(&mut scheduler, 1);
        assert!(scheduler.warmup_complete);

        advance(&mut scheduler, 9);
        assert_abs_diff_eq!(scheduler.get_learning_rate(), 0.001, epsilon = 1e-6);
    }
}