use serde::{Deserialize, Serialize};
/// Common interface for learning-rate schedules.
///
/// Implementors must be `Send + Sync`.
pub trait Scheduler: Send + Sync {
/// Returns the learning rate to use at the given `step`.
fn get_lr(&self, step: usize) -> f64;
/// Returns the name of the schedule.
fn name(&self) -> &str;
}
/// Configuration consumed by `OneCycleLR::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OneCycleLRConfig {
/// Peak learning rate of the cycle.
pub max_lr: f64,
/// Total number of steps in the schedule.
pub total_steps: usize,
/// Fraction of `total_steps` spent in the warmup phase.
pub pct_start: f64,
/// Divisor applied to `max_lr` to obtain the starting learning rate.
pub div_factor: f64,
/// Divisor applied to `max_lr` to obtain the final learning rate.
pub final_div_factor: f64,
}
impl Default for OneCycleLRConfig {
fn default() -> Self {
Self {
max_lr: 1e-3,
total_steps: 1000,
pct_start: 0.3,
div_factor: 25.0,
final_div_factor: 10000.0,
}
}
}
/// One-cycle learning-rate schedule: a linear warmup followed by cosine annealing.
///
/// The rate climbs linearly from `max_lr / div_factor` to `max_lr` during the
/// first `warmup_steps` steps, then follows a half cosine from `max_lr` down
/// towards `max_lr / final_div_factor` over the remaining steps.
#[derive(Debug, Clone)]
pub struct OneCycleLR {
    /// Configuration the schedule was built from.
    config: OneCycleLRConfig,
    /// Rate at the start of the warmup phase (`max_lr / div_factor`).
    initial_lr: f64,
    /// Rate the annealing phase decays towards (`max_lr / final_div_factor`).
    final_lr: f64,
    /// Length of the warmup phase (`total_steps * pct_start`, truncated).
    warmup_steps: usize,
}

impl OneCycleLR {
    /// Builds a schedule from `config`, precomputing the start/end rates and
    /// the warmup length.
    pub fn new(config: OneCycleLRConfig) -> Self {
        let initial_lr = config.max_lr / config.div_factor;
        let final_lr = config.max_lr / config.final_div_factor;
        // Float-to-integer `as` casts truncate toward zero and saturate, so a
        // negative or NaN `pct_start` results in a warmup of zero steps.
        let warmup_steps = (config.total_steps as f64 * config.pct_start) as usize;
        Self {
            config,
            initial_lr,
            final_lr,
            warmup_steps,
        }
    }

    /// Builds a schedule with the given peak rate and length, taking every
    /// other setting from `OneCycleLRConfig::default()`.
    pub fn simple(max_lr: f64, total_steps: usize) -> Self {
        Self::new(OneCycleLRConfig {
            max_lr,
            total_steps,
            ..Default::default()
        })
    }
}

impl Scheduler for OneCycleLR {
    /// Returns the learning rate at `step`.
    ///
    /// Steps at or beyond `total_steps` are clamped to `total_steps - 1`, so
    /// the schedule holds its last value once it has finished. An empty
    /// schedule (`total_steps == 0`) returns `max_lr` instead of NaN.
    fn get_lr(&self, step: usize) -> f64 {
        let last_step = self.config.total_steps.saturating_sub(1);
        let step = step.min(last_step);

        if step < self.warmup_steps {
            let progress = step as f64 / self.warmup_steps as f64;
            self.initial_lr + (self.config.max_lr - self.initial_lr) * progress
        } else {
            let annealing_steps = self.config.total_steps - self.warmup_steps;
            // The annealing phase is empty only when `total_steps == 0`;
            // dividing then would be 0/0 = NaN, so treat it as "no progress".
            let progress = if annealing_steps == 0 {
                0.0
            } else {
                (step - self.warmup_steps) as f64 / annealing_steps as f64
            };
            let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
            self.final_lr + (self.config.max_lr - self.final_lr) * cosine
        }
    }

    /// Returns `"OneCycleLR"`.
    fn name(&self) -> &str {
        "OneCycleLR"
    }
}
/// Cosine-annealing schedule: a half-cosine decay from `initial_lr` at step 0
/// to `min_lr` at step `total_steps`.
#[derive(Debug, Clone)]
pub struct CosineAnnealingLR {
    /// Rate returned at step 0.
    initial_lr: f64,
    /// Rate returned once the schedule has completed.
    min_lr: f64,
    /// Number of steps over which the rate decays.
    total_steps: usize,
}

impl CosineAnnealingLR {
    /// Creates a schedule decaying from `initial_lr` to `min_lr` over
    /// `total_steps` steps.
    pub fn new(initial_lr: f64, min_lr: f64, total_steps: usize) -> Self {
        Self {
            initial_lr,
            min_lr,
            total_steps,
        }
    }
}

impl Scheduler for CosineAnnealingLR {
    /// Returns the learning rate at `step`.
    ///
    /// Steps at or beyond `total_steps` return `min_lr`. An empty schedule
    /// (`total_steps == 0`) is treated as already complete and also returns
    /// `min_lr`.
    fn get_lr(&self, step: usize) -> f64 {
        if self.total_steps == 0 {
            // Avoids 0/0 = NaN below.
            return self.min_lr;
        }
        // Clamp to `total_steps` (not `total_steps - 1`) so that progress can
        // reach 1.0 and the schedule actually bottoms out at `min_lr`.
        let step = step.min(self.total_steps);
        let progress = step as f64 / self.total_steps as f64;
        let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        self.min_lr + (self.initial_lr - self.min_lr) * cosine
    }

    /// Returns `"CosineAnnealingLR"`.
    fn name(&self) -> &str {
        "CosineAnnealingLR"
    }
}
/// Step-decay schedule: the rate is multiplied by `gamma` once every
/// `step_size` steps.
#[derive(Debug, Clone)]
pub struct StepLR {
    /// Rate before any decay has been applied.
    initial_lr: f64,
    /// Number of steps between decays; always at least 1.
    step_size: usize,
    /// Multiplicative decay factor.
    gamma: f64,
}

impl StepLR {
    /// Creates a schedule starting at `initial_lr` that multiplies the rate
    /// by `gamma` every `step_size` steps.
    ///
    /// A `step_size` of 0 is treated as 1 so that `get_lr` never divides by
    /// zero.
    pub fn new(initial_lr: f64, step_size: usize, gamma: f64) -> Self {
        Self {
            initial_lr,
            step_size: step_size.max(1),
            gamma,
        }
    }
}

impl Scheduler for StepLR {
    /// Returns `initial_lr * gamma^(step / step_size)`.
    fn get_lr(&self, step: usize) -> f64 {
        let n_decays = step / self.step_size;
        // `powi` takes an `i32`; a plain `as i32` cast would wrap large counts
        // to negative exponents and make the rate grow, so saturate instead.
        let exponent = n_decays.min(i32::MAX as usize) as i32;
        self.initial_lr * self.gamma.powi(exponent)
    }

    /// Returns `"StepLR"`.
    fn name(&self) -> &str {
        "StepLR"
    }
}
/// Schedule that yields the same learning rate at every step.
#[derive(Debug, Clone)]
pub struct ConstantLR {
    /// The fixed learning rate.
    lr: f64,
}

impl ConstantLR {
    /// Creates a schedule pinned to `lr`.
    pub fn new(lr: f64) -> Self {
        ConstantLR { lr }
    }
}

impl Scheduler for ConstantLR {
    /// Returns the fixed rate; the step is ignored.
    fn get_lr(&self, _: usize) -> f64 {
        let fixed_rate = self.lr;
        fixed_rate
    }

    /// Returns `"ConstantLR"`.
    fn name(&self) -> &str {
        "ConstantLR"
    }
}
/// Exponential-decay schedule: the rate is multiplied by `gamma` at every step.
#[derive(Debug, Clone)]
pub struct ExponentialLR {
    /// Rate returned at step 0.
    initial_lr: f64,
    /// Per-step multiplicative decay factor.
    gamma: f64,
}

impl ExponentialLR {
    /// Creates a schedule starting at `initial_lr` and decaying by `gamma`
    /// each step.
    pub fn new(initial_lr: f64, gamma: f64) -> Self {
        Self { initial_lr, gamma }
    }
}

impl Scheduler for ExponentialLR {
    /// Returns `initial_lr * gamma^step`.
    fn get_lr(&self, step: usize) -> f64 {
        // `powi` takes an `i32`; a plain `as i32` cast would wrap large steps
        // to negative exponents and make the rate grow, so saturate instead.
        let exponent = step.min(i32::MAX as usize) as i32;
        self.initial_lr * self.gamma.powi(exponent)
    }

    /// Returns `"ExponentialLR"`.
    fn name(&self) -> &str {
        "ExponentialLR"
    }
}
/// Polynomial-decay schedule from `initial_lr` to `end_lr` over `total_steps`
/// steps, shaped by `(1 - progress)^power`.
#[derive(Debug, Clone)]
pub struct PolynomialLR {
    /// Rate returned at step 0.
    initial_lr: f64,
    /// Rate returned once the schedule has completed.
    end_lr: f64,
    /// Number of steps over which the rate decays.
    total_steps: usize,
    /// Exponent applied to the remaining fraction of the schedule.
    power: f64,
}

impl PolynomialLR {
    /// Creates a schedule decaying from `initial_lr` to `end_lr` over
    /// `total_steps` steps with the given `power`.
    pub fn new(initial_lr: f64, end_lr: f64, total_steps: usize, power: f64) -> Self {
        Self {
            initial_lr,
            end_lr,
            total_steps,
            power,
        }
    }

    /// Creates a linear decay, i.e. a polynomial schedule with `power == 1.0`.
    pub fn linear(initial_lr: f64, end_lr: f64, total_steps: usize) -> Self {
        Self::new(initial_lr, end_lr, total_steps, 1.0)
    }
}

impl Scheduler for PolynomialLR {
    /// Returns the learning rate at `step`.
    ///
    /// Steps beyond `total_steps` are clamped to `total_steps`. An empty
    /// schedule (`total_steps == 0`) is treated as already complete and
    /// returns `end_lr`.
    fn get_lr(&self, step: usize) -> f64 {
        if self.total_steps == 0 {
            // Avoids 0/0 = NaN below.
            return self.end_lr;
        }
        let step = step.min(self.total_steps);
        let progress = step as f64 / self.total_steps as f64;
        let decay = (1.0 - progress).powf(self.power);
        self.end_lr + (self.initial_lr - self.end_lr) * decay
    }

    /// Returns `"PolynomialLR"`.
    fn name(&self) -> &str {
        "PolynomialLR"
    }
}
/// Cosine annealing with warm restarts.
///
/// Each cycle decays along a half cosine from `initial_lr` towards `min_lr`,
/// then restarts at `initial_lr`. The first cycle lasts `t_0` steps and every
/// subsequent cycle is `t_mult` times longer than the previous one.
#[derive(Debug, Clone)]
pub struct CosineAnnealingWarmRestarts {
    /// Rate at the start of every cycle.
    initial_lr: f64,
    /// Rate each cycle decays towards.
    min_lr: f64,
    /// Length of the first cycle; always at least 1.
    t_0: usize,
    /// Growth factor of the cycle length; always at least 1.
    t_mult: usize,
}

impl CosineAnnealingWarmRestarts {
    /// Creates a schedule with a first cycle of `t_0` steps whose length is
    /// multiplied by `t_mult` after each restart.
    ///
    /// Both `t_0` and `t_mult` are raised to 1 if given as 0, so a cycle is
    /// never empty and never shrinks.
    pub fn new(initial_lr: f64, min_lr: f64, t_0: usize, t_mult: usize) -> Self {
        Self {
            initial_lr,
            min_lr,
            t_0: t_0.max(1),
            t_mult: t_mult.max(1),
        }
    }
}

impl Scheduler for CosineAnnealingWarmRestarts {
    /// Returns the learning rate at `step`, based on the position of `step`
    /// inside its cycle.
    fn get_lr(&self, step: usize) -> f64 {
        let (t_cur, cycle_len) = if self.t_mult == 1 {
            // Fixed-length cycles: the position is a plain remainder, which
            // stays correct for arbitrarily many restarts.
            (step % self.t_0, self.t_0)
        } else {
            // Growing cycles: walk forward cycle by cycle. The length at least
            // doubles each time, so this runs at most ~64 iterations.
            let mut cycle_start: usize = 0;
            let mut cycle_len = self.t_0;
            // If the cycle end overflows `usize`, `step` necessarily lies
            // inside the current cycle, so the loop simply stops.
            while let Some(cycle_end) = cycle_start.checked_add(cycle_len) {
                if step < cycle_end {
                    break;
                }
                cycle_start = cycle_end;
                cycle_len = cycle_len.saturating_mul(self.t_mult);
            }
            (step - cycle_start, cycle_len)
        };

        let progress = t_cur as f64 / cycle_len as f64;
        let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        self.min_lr + (self.initial_lr - self.min_lr) * cosine
    }

    /// Returns `"CosineAnnealingWarmRestarts"`.
    fn name(&self) -> &str {
        "CosineAnnealingWarmRestarts"
    }
}
/// Wraps another scheduler with a linear warmup phase.
///
/// For the first `warmup_steps` steps the rate is interpolated linearly from
/// `warmup_start_lr` to the rate the inner scheduler reports at
/// `warmup_steps`; afterwards the inner scheduler is used directly.
#[derive(Debug, Clone)]
pub struct LinearWarmup<S: Scheduler> {
    /// Number of steps in the warmup phase.
    warmup_steps: usize,
    /// Rate at step 0 of the warmup.
    warmup_start_lr: f64,
    /// Scheduler that takes over once the warmup is done.
    inner: S,
}

impl<S: Scheduler> LinearWarmup<S> {
    /// Creates a warmup of `warmup_steps` steps starting at `warmup_start_lr`
    /// in front of `inner`.
    pub fn new(warmup_steps: usize, warmup_start_lr: f64, inner: S) -> Self {
        LinearWarmup {
            warmup_steps,
            warmup_start_lr,
            inner,
        }
    }
}

impl<S: Scheduler> Scheduler for LinearWarmup<S> {
    /// Returns the interpolated warmup rate while `step < warmup_steps`, and
    /// the inner scheduler's rate for `step` otherwise.
    fn get_lr(&self, step: usize) -> f64 {
        // Past the warmup: defer to the wrapped scheduler with the same step.
        if step >= self.warmup_steps {
            return self.inner.get_lr(step);
        }

        // The warmup aims at whatever the inner scheduler yields at the
        // hand-over step.
        let fraction = step as f64 / self.warmup_steps as f64;
        let target_lr = self.inner.get_lr(self.warmup_steps);
        self.warmup_start_lr + (target_lr - self.warmup_start_lr) * fraction
    }

    /// Returns `"LinearWarmup"`.
    fn name(&self) -> &str {
        "LinearWarmup"
    }
}
/// Metric-driven schedule that lowers the learning rate when a monitored
/// value stops improving.
///
/// Unlike the step-indexed schedulers, the rate only changes through calls to
/// [`ReduceLROnPlateau::step`].
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateau {
    /// Rate currently in effect.
    current_lr: f64,
    /// Multiplier applied to the rate on each reduction.
    factor: f64,
    /// Number of non-improving calls tolerated before a reduction.
    patience: usize,
    /// Lower bound for the rate.
    min_lr: f64,
    /// Whether smaller or larger metric values count as better.
    mode: ReduceMode,
    /// Best metric value seen so far.
    best_value: f64,
    /// Consecutive non-improving calls since the last improvement or reduction.
    num_bad_epochs: usize,
    /// Number of calls to skip after each reduction.
    cooldown: usize,
    /// Remaining calls to skip in the current cooldown.
    cooldown_counter: usize,
}

/// Direction in which the monitored metric is considered to improve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceMode {
    /// A lower metric is better.
    Min,
    /// A higher metric is better.
    Max,
}

impl ReduceLROnPlateau {
    /// Creates a scheduler starting at `initial_lr` with no cooldown.
    ///
    /// The best value starts at the worst possible value for `mode`, so any
    /// ordinary first metric counts as an improvement.
    pub fn new(
        initial_lr: f64,
        factor: f64,
        patience: usize,
        min_lr: f64,
        mode: ReduceMode,
    ) -> Self {
        let worst_possible = match mode {
            ReduceMode::Max => f64::NEG_INFINITY,
            ReduceMode::Min => f64::INFINITY,
        };

        ReduceLROnPlateau {
            current_lr: initial_lr,
            factor,
            patience,
            min_lr,
            mode,
            best_value: worst_possible,
            num_bad_epochs: 0,
            cooldown: 0,
            cooldown_counter: 0,
        }
    }

    /// Sets the number of `step` calls to skip after each reduction.
    #[must_use]
    pub fn with_cooldown(mut self, cooldown: usize) -> Self {
        self.cooldown = cooldown;
        self
    }

    /// Feeds one metric observation and returns `true` if the rate was lowered.
    ///
    /// While a cooldown is active the metric is ignored entirely. Otherwise an
    /// improvement resets the bad-epoch counter; once more than `patience`
    /// non-improving calls accumulate, the rate is multiplied by `factor`
    /// (but not below `min_lr`).
    pub fn step(&mut self, metric: f64) -> bool {
        // Cooling down: consume one skipped call and do nothing else.
        if self.cooldown_counter > 0 {
            self.cooldown_counter -= 1;
            return false;
        }

        let is_better = match self.mode {
            ReduceMode::Max => metric > self.best_value,
            ReduceMode::Min => metric < self.best_value,
        };
        if is_better {
            self.best_value = metric;
            self.num_bad_epochs = 0;
            return false;
        }

        self.num_bad_epochs += 1;
        if self.num_bad_epochs <= self.patience {
            return false;
        }

        // Only report a reduction when the rate actually drops; once it sits
        // at `min_lr` the counter keeps growing and nothing else changes.
        let candidate = (self.current_lr * self.factor).max(self.min_lr);
        let lowered = candidate < self.current_lr;
        if lowered {
            self.current_lr = candidate;
            self.num_bad_epochs = 0;
            self.cooldown_counter = self.cooldown;
        }
        lowered
    }

    /// Returns the rate currently in effect.
    pub fn current_lr(&self) -> f64 {
        self.current_lr
    }
}

impl Scheduler for ReduceLROnPlateau {
    /// Returns the rate currently in effect; the step is ignored.
    fn get_lr(&self, _: usize) -> f64 {
        self.current_lr
    }

    /// Returns `"ReduceLROnPlateau"`.
    fn name(&self) -> &str {
        "ReduceLROnPlateau"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// The one-cycle schedule starts low, peaks after warmup and ends below
    /// its starting rate.
    #[test]
    fn test_one_cycle_lr() {
        let scheduler = OneCycleLR::simple(1e-3, 1000);
        let start_lr = scheduler.get_lr(0);
        let peak_lr = scheduler.get_lr(300);
        let end_lr = scheduler.get_lr(999);

        assert!(start_lr < 1e-3);
        assert!(close(peak_lr, 1e-3, 1e-6));
        assert!(end_lr < start_lr);
    }

    /// Cosine annealing starts at the initial rate and decays towards the
    /// minimum.
    #[test]
    fn test_cosine_annealing() {
        let scheduler = CosineAnnealingLR::new(1e-3, 1e-5, 100);

        assert!(close(scheduler.get_lr(0), 1e-3, 1e-8));
        assert!(scheduler.get_lr(50) < 1e-3);
        assert!(scheduler.get_lr(100) < 1e-4);
    }

    /// Step decay multiplies the rate by `gamma` every `step_size` steps.
    #[test]
    fn test_step_lr() {
        let scheduler = StepLR::new(1e-3, 10, 0.1);

        assert!(close(scheduler.get_lr(0), 1e-3, 1e-8));
        assert!(close(scheduler.get_lr(10), 1e-4, 1e-9));
        assert!(close(scheduler.get_lr(20), 1e-5, 1e-10));
    }

    /// Exponential decay multiplies the rate by `gamma` every step.
    #[test]
    fn test_exponential_lr() {
        let scheduler = ExponentialLR::new(1e-3, 0.9);

        assert!(close(scheduler.get_lr(0), 1e-3, 1e-8));
        assert!(close(scheduler.get_lr(1), 9e-4, 1e-8));
        assert!(scheduler.get_lr(10) < 4e-4);
    }

    /// Linear polynomial decay runs from the initial rate to the end rate.
    #[test]
    fn test_polynomial_lr() {
        let scheduler = PolynomialLR::linear(1e-3, 1e-5, 100);
        let mid_lr = scheduler.get_lr(50);

        assert!(close(scheduler.get_lr(0), 1e-3, 1e-8));
        assert!(mid_lr < 1e-3 && mid_lr > 1e-5);
        assert!(close(scheduler.get_lr(100), 1e-5, 1e-7));
    }

    /// Warm restarts decay within a cycle and jump back up at the restart.
    #[test]
    fn test_cosine_warm_restarts() {
        let scheduler = CosineAnnealingWarmRestarts::new(1e-3, 1e-5, 10, 2);

        assert!(close(scheduler.get_lr(0), 1e-3, 1e-8));
        assert!(scheduler.get_lr(9) < 2e-4);
        assert!(scheduler.get_lr(10) > 9e-4);
    }

    /// The plateau scheduler keeps the rate while improving and reduces it
    /// once patience is exceeded.
    #[test]
    fn test_reduce_lr_on_plateau() {
        let mut scheduler = ReduceLROnPlateau::new(1e-3, 0.1, 2, 1e-6, ReduceMode::Min);
        assert!(close(scheduler.current_lr(), 1e-3, 1e-8));

        // Two improving observations leave the rate untouched.
        scheduler.step(1.0);
        scheduler.step(0.9);
        assert!(close(scheduler.current_lr(), 1e-3, 1e-8));

        // Three non-improving observations exceed a patience of 2.
        scheduler.step(0.95);
        scheduler.step(0.95);
        let reduced = scheduler.step(0.95);
        assert!(reduced);
        assert!(close(scheduler.current_lr(), 1e-4, 1e-9));
    }

    /// Linear warmup ramps from the start rate to the inner scheduler's rate.
    #[test]
    fn test_linear_warmup() {
        let inner = ConstantLR::new(1e-3);
        let scheduler = LinearWarmup::new(10, 0.0, inner);

        assert!(scheduler.get_lr(0) < 1e-5);
        assert!(close(scheduler.get_lr(5), 5e-4, 1e-6));
        assert!(close(scheduler.get_lr(10), 1e-3, 1e-8));
    }
}