use serde::{Deserialize, Serialize};
/// A learning-rate schedule: maps a training step to the learning rate to use at that step.
pub trait LRScheduler {
    /// Computes the learning rate for training step `step`.
    ///
    /// Takes `&self`, so computing a rate never mutates the scheduler.
    fn get_lr(&self, step: usize) -> f64;

    /// Returns the learning rate the scheduler has stored as its last value.
    ///
    /// The implementations in this module return a field that is set at construction and
    /// not updated by `get_lr`.
    fn last_lr(&self) -> f64;
}
/// Learning-rate schedule that returns the same rate at every step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstantScheduler {
    /// The rate returned for every step (and by `last_lr()`).
    lr: f64,
}
impl ConstantScheduler {
pub fn new(lr: f64) -> Self {
Self { lr }
}
}
impl LRScheduler for ConstantScheduler {
    /// Returns the fixed rate; the step index plays no role.
    fn get_lr(&self, _: usize) -> f64 {
        self.last_lr()
    }

    /// Returns the fixed rate the scheduler was built with.
    fn last_lr(&self) -> f64 {
        self.lr
    }
}
/// Linear warmup followed by linear interpolation from `initial_lr` to `final_lr`.
///
/// Built with [`LinearScheduler::new`]. Note that the constructor takes `total_steps`
/// before `warmup_steps`, the reverse of the field order below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearScheduler {
    /// Rate approached at the end of warmup; the decay phase starts from this value.
    initial_lr: f64,
    /// Rate the decay phase interpolates towards, reached at `total_steps`.
    final_lr: f64,
    /// Number of leading steps during which the rate ramps linearly up from `0.0`.
    warmup_steps: usize,
    /// Step at which the decay phase reaches `final_lr`.
    total_steps: usize,
    /// Value reported by `last_lr()`; `new` sets it to `initial_lr`. Nothing in this file
    /// updates it afterwards, because `get_lr` takes `&self`.
    last_lr: f64,
}
impl LinearScheduler {
pub fn new(initial_lr: f64, final_lr: f64, total_steps: usize, warmup_steps: usize) -> Self {
Self {
initial_lr,
final_lr,
warmup_steps,
total_steps,
last_lr: initial_lr,
}
}
}
impl LRScheduler for LinearScheduler {
    /// Returns the learning rate for `step`.
    ///
    /// * `step < warmup_steps`: ramps linearly from `0.0` towards `initial_lr`.
    /// * `warmup_steps <= step <= total_steps`: interpolates linearly from `initial_lr`
    ///   to `final_lr`.
    /// * `step > total_steps`: holds at `final_lr` rather than extrapolating past it,
    ///   which could otherwise drive the rate negative.
    ///
    /// An empty decay phase (`total_steps <= warmup_steps`) returns `final_lr` for every
    /// post-warmup step. This avoids a `usize` underflow or a 0/0 division.
    fn get_lr(&self, step: usize) -> f64 {
        if step < self.warmup_steps {
            return self.initial_lr * (step as f64 / self.warmup_steps as f64);
        }
        // Saturating: a misconfigured `total_steps < warmup_steps` must not underflow.
        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if decay_steps == 0 {
            return self.final_lr;
        }
        // Clamp so steps beyond `total_steps` stay at `final_lr`.
        let progress = ((step - self.warmup_steps) as f64 / decay_steps as f64).min(1.0);
        self.initial_lr + (self.final_lr - self.initial_lr) * progress
    }

    /// Returns the stored `last_lr` field (set to `initial_lr` at construction).
    fn last_lr(&self) -> f64 {
        self.last_lr
    }
}
/// Linear warmup followed by cosine annealing from `max_lr` down to `min_lr`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CosineScheduler {
    /// Peak rate approached at the end of warmup; annealing starts from this value.
    max_lr: f64,
    /// Floor the cosine anneals towards. `new` sets it to `0.0`; change it with `with_min_lr`.
    min_lr: f64,
    /// Step at which the annealing reaches `min_lr`.
    total_steps: usize,
    /// Number of leading steps during which the rate ramps linearly up from `0.0`.
    warmup_steps: usize,
    /// Value reported by `last_lr()`; `new` sets it to `max_lr`. Nothing in this file
    /// updates it afterwards, because `get_lr` takes `&self`.
    last_lr: f64,
}
impl CosineScheduler {
pub fn new(max_lr: f64, total_steps: usize, warmup_steps: usize) -> Self {
Self {
max_lr,
min_lr: 0.0,
total_steps,
warmup_steps,
last_lr: max_lr,
}
}
pub fn with_min_lr(mut self, min_lr: f64) -> Self {
self.min_lr = min_lr;
self
}
}
impl LRScheduler for CosineScheduler {
    /// Returns the learning rate for `step`.
    ///
    /// * `step < warmup_steps`: ramps linearly from `0.0` towards `max_lr`.
    /// * `warmup_steps <= step <= total_steps`: follows half a cosine period from `max_lr`
    ///   down to `min_lr`.
    /// * `step > total_steps`: holds at `min_lr`. Without the clamp the cosine would rise
    ///   back towards `max_lr` (an unintended restart).
    ///
    /// An empty annealing phase (`total_steps <= warmup_steps`) returns `min_lr` for every
    /// post-warmup step. This avoids a `usize` underflow or a 0/0 division.
    fn get_lr(&self, step: usize) -> f64 {
        if step < self.warmup_steps {
            return self.max_lr * (step as f64 / self.warmup_steps as f64);
        }
        // Saturating: a misconfigured `total_steps < warmup_steps` must not underflow.
        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if decay_steps == 0 {
            return self.min_lr;
        }
        // Clamp so steps beyond `total_steps` stay at the bottom of the cosine.
        let progress = ((step - self.warmup_steps) as f64 / decay_steps as f64).min(1.0);
        let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        self.min_lr + (self.max_lr - self.min_lr) * cosine
    }

    /// Returns the stored `last_lr` field (set to `max_lr` at construction).
    fn last_lr(&self) -> f64 {
        self.last_lr
    }
}
/// Piecewise-constant schedule: multiplies the rate by `decay_factor` at each milestone.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepScheduler {
    /// Rate used before any milestone has been reached.
    initial_lr: f64,
    /// Multiplier applied once for every milestone reached.
    decay_factor: f64,
    /// Steps at which a decay takes effect. A milestone counts once `step >= milestone`.
    /// `get_lr` counts every matching entry, so storage order does not matter.
    milestones: Vec<usize>,
    /// Value reported by `last_lr()`; `new` sets it to `initial_lr`. Nothing in this file
    /// updates it afterwards, because `get_lr` takes `&self`.
    last_lr: f64,
}
impl StepScheduler {
pub fn new(initial_lr: f64, decay_factor: f64, milestones: Vec<usize>) -> Self {
Self {
initial_lr,
decay_factor,
milestones,
last_lr: initial_lr,
}
}
}
impl LRScheduler for StepScheduler {
    /// Returns `initial_lr * decay_factor^k`, where `k` is the number of milestones
    /// `m` with `m <= step`.
    fn get_lr(&self, step: usize) -> f64 {
        // Every milestone already reached contributes one decay, whatever its position
        // in the vector.
        let mut reached: usize = 0;
        for &milestone in &self.milestones {
            if milestone <= step {
                reached += 1;
            }
        }
        self.initial_lr * self.decay_factor.powi(reached as i32)
    }

    /// Returns the stored `last_lr` field (set to `initial_lr` at construction).
    fn last_lr(&self) -> f64 {
        self.last_lr
    }
}
/// Staircase exponential decay: multiplies the rate by `decay_rate` once per
/// `decay_steps`-step interval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExponentialScheduler {
    /// Rate used during the first interval (steps `0..decay_steps`).
    initial_lr: f64,
    /// Multiplier applied once per completed interval.
    decay_rate: f64,
    /// Length of one decay interval, in steps.
    decay_steps: usize,
    /// Value reported by `last_lr()`; `new` sets it to `initial_lr`. Nothing in this file
    /// updates it afterwards, because `get_lr` takes `&self`.
    last_lr: f64,
}
impl ExponentialScheduler {
pub fn new(initial_lr: f64, decay_rate: f64, decay_steps: usize) -> Self {
Self {
initial_lr,
decay_rate,
decay_steps,
last_lr: initial_lr,
}
}
}
impl LRScheduler for ExponentialScheduler {
    /// Returns `initial_lr * decay_rate^(step / decay_steps)`, using integer division.
    ///
    /// Edge cases:
    /// * `decay_steps == 0` is treated as `1` (decay on every step) instead of panicking
    ///   on integer division by zero.
    /// * The exponent saturates at `i32::MAX` instead of wrapping to a negative value,
    ///   which would make the rate grow.
    fn get_lr(&self, step: usize) -> f64 {
        let num_decays = step / self.decay_steps.max(1);
        // `powi` takes an `i32`; a plain `as` cast would wrap very large counts.
        let exponent = num_decays.min(i32::MAX as usize) as i32;
        self.initial_lr * self.decay_rate.powi(exponent)
    }

    /// Returns the stored `last_lr` field (set to `initial_lr` at construction).
    fn last_lr(&self) -> f64 {
        self.last_lr
    }
}
/// One-cycle schedule: linear warmup from `initial_lr` to `max_lr`, then cosine annealing
/// from `max_lr` down to `final_lr`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OneCycleScheduler {
    /// Rate at step 0. `new` sets it to `max_lr / 25.0`; change it with `with_div_factor`.
    initial_lr: f64,
    /// Peak rate, approached at the end of warmup.
    max_lr: f64,
    /// Floor of the annealing phase. `new` sets it to `max_lr / 10000.0`; change it with
    /// `with_final_div_factor`.
    final_lr: f64,
    /// Length of the whole cycle, in steps.
    total_steps: usize,
    /// Fraction of `total_steps` spent warming up. `new` sets it to `0.3`, and
    /// `with_warmup_pct` clamps it into `[0.0, 1.0]`.
    warmup_pct: f64,
    /// Value reported by `last_lr()`; `new` sets it to `max_lr / 25.0`. Nothing in this
    /// file updates it from `get_lr`, which takes `&self`.
    last_lr: f64,
}
impl OneCycleScheduler {
    /// Creates a one-cycle schedule that peaks at `max_lr` over `total_steps` steps.
    ///
    /// Defaults:
    /// * starting rate `max_lr / 25`;
    /// * final rate `max_lr / 10000`;
    /// * 30% of the steps spent warming up.
    pub fn new(max_lr: f64, total_steps: usize) -> Self {
        let initial_lr = max_lr / 25.0;
        Self {
            initial_lr,
            max_lr,
            final_lr: max_lr / 10000.0,
            total_steps,
            warmup_pct: 0.3,
            last_lr: initial_lr,
        }
    }

    /// Sets the fraction of `total_steps` spent warming up, clamped into `[0.0, 1.0]`.
    pub fn with_warmup_pct(mut self, warmup_pct: f64) -> Self {
        self.warmup_pct = warmup_pct.clamp(0.0, 1.0);
        self
    }

    /// Sets the starting rate to `max_lr / div_factor`.
    ///
    /// `last_lr` is updated as well, so it keeps matching the starting rate as it does
    /// right after `new`.
    pub fn with_div_factor(mut self, div_factor: f64) -> Self {
        self.initial_lr = self.max_lr / div_factor;
        self.last_lr = self.initial_lr;
        self
    }

    /// Sets the annealing floor to `max_lr / final_div_factor`.
    pub fn with_final_div_factor(mut self, final_div_factor: f64) -> Self {
        self.final_lr = self.max_lr / final_div_factor;
        self
    }
}
impl LRScheduler for OneCycleScheduler {
    /// Returns the learning rate for `step`.
    ///
    /// With `warmup = floor(total_steps * warmup_pct)`:
    /// * `step < warmup`: interpolates linearly from `initial_lr` to `max_lr`.
    /// * `warmup <= step <= total_steps`: follows half a cosine period from `max_lr`
    ///   down to `final_lr`.
    /// * `step > total_steps`: holds at `final_lr`. Without the clamp the cosine would
    ///   rise back towards `max_lr`.
    ///
    /// An empty annealing phase (for example `warmup_pct == 1.0`, or `total_steps == 0`)
    /// returns `final_lr` for every post-warmup step instead of producing NaN.
    fn get_lr(&self, step: usize) -> f64 {
        // The float -> usize cast truncates toward zero, so partial warmup steps are dropped.
        let warmup_steps = (self.total_steps as f64 * self.warmup_pct) as usize;
        if step < warmup_steps {
            let progress = step as f64 / warmup_steps as f64;
            return self.initial_lr + (self.max_lr - self.initial_lr) * progress;
        }
        // Saturating: `warmup_pct > 1.0` set without the builder must not underflow.
        let anneal_steps = self.total_steps.saturating_sub(warmup_steps);
        if anneal_steps == 0 {
            return self.final_lr;
        }
        // Clamp so steps beyond `total_steps` stay at the bottom of the cosine.
        let progress = ((step - warmup_steps) as f64 / anneal_steps as f64).min(1.0);
        let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        self.final_lr + (self.max_lr - self.final_lr) * cosine
    }

    /// Returns the stored `last_lr` field.
    fn last_lr(&self) -> f64 {
        self.last_lr
    }
}
/// Polynomial decay from `initial_lr` to `final_lr` over `total_steps` steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolynomialScheduler {
    /// Rate at step 0.
    initial_lr: f64,
    /// Rate returned at and after `total_steps`.
    final_lr: f64,
    /// Length of the decay, in steps.
    total_steps: usize,
    /// Exponent applied to the remaining fraction `1 - step / total_steps`.
    power: f64,
    /// Value reported by `last_lr()`; `new` sets it to `initial_lr`. Nothing in this file
    /// updates it afterwards, because `get_lr` takes `&self`.
    last_lr: f64,
}
impl PolynomialScheduler {
pub fn new(initial_lr: f64, final_lr: f64, total_steps: usize, power: f64) -> Self {
Self {
initial_lr,
final_lr,
total_steps,
power,
last_lr: initial_lr,
}
}
}
impl LRScheduler for PolynomialScheduler {
    /// Returns `final_lr + (initial_lr - final_lr) * (1 - step / total_steps)^power`.
    ///
    /// Returns `final_lr` once `step >= total_steps`; this also covers `total_steps == 0`.
    fn get_lr(&self, step: usize) -> f64 {
        if step < self.total_steps {
            let fraction_done = step as f64 / self.total_steps as f64;
            let remaining = (1.0 - fraction_done).powf(self.power);
            self.final_lr + (self.initial_lr - self.final_lr) * remaining
        } else {
            self.final_lr
        }
    }

    /// Returns the stored `last_lr` field (set to `initial_lr` at construction).
    fn last_lr(&self) -> f64 {
        self.last_lr
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    #[test]
    fn test_constant_scheduler() {
        let scheduler = ConstantScheduler::new(1e-3);
        for step in [0, 100, 1000] {
            assert_eq!(scheduler.get_lr(step), 1e-3);
        }
    }

    #[test]
    fn test_linear_scheduler() {
        let scheduler = LinearScheduler::new(1e-3, 1e-5, 1000, 100);
        let lr = |step| scheduler.get_lr(step);
        // Warmup: strictly increasing up to the end of warmup.
        assert!(lr(0) < lr(50));
        assert!(lr(50) < lr(100));
        // Decay: strictly decreasing towards the final rate.
        assert!(lr(500) > lr(750));
        assert!(lr(750) > lr(1000));
    }

    #[test]
    fn test_cosine_scheduler() {
        let scheduler = CosineScheduler::new(1e-3, 1000, 100).with_min_lr(1e-5);
        let lr_0 = scheduler.get_lr(0);
        let lr_50 = scheduler.get_lr(50);
        let lr_100 = scheduler.get_lr(100);
        let lr_500 = scheduler.get_lr(500);
        let lr_1000 = scheduler.get_lr(1000);

        // Warmup ramps up.
        assert!(lr_0 < lr_50);
        assert!(lr_50 < lr_100);
        // Annealing comes down and ends at the floor.
        assert!(lr_500 > lr_1000);
        assert_close(lr_1000, 1e-5, 1e-6);
    }

    #[test]
    fn test_step_scheduler() {
        let scheduler = StepScheduler::new(1.0, 0.1, vec![100, 200, 300]);
        let expectations = [
            (0, 1.0),
            (99, 1.0),
            (100, 0.1),
            (199, 0.1),
            (200, 0.01),
            (300, 0.001),
        ];
        for (step, expected) in expectations {
            assert_close(scheduler.get_lr(step), expected, 1e-10);
        }
    }

    #[test]
    fn test_exponential_scheduler() {
        let scheduler = ExponentialScheduler::new(1.0, 0.96, 100);
        assert_eq!(scheduler.get_lr(0), 1.0);
        assert_close(scheduler.get_lr(100), 0.96, 1e-6);
        assert_close(scheduler.get_lr(200), 0.96 * 0.96, 1e-6);
    }

    #[test]
    fn test_onecycle_scheduler() {
        let scheduler = OneCycleScheduler::new(1e-3, 1000).with_warmup_pct(0.3);
        let lr_0 = scheduler.get_lr(0);
        let lr_150 = scheduler.get_lr(150);
        let lr_300 = scheduler.get_lr(300);
        let lr_650 = scheduler.get_lr(650);
        let lr_1000 = scheduler.get_lr(1000);

        // Rises during warmup, then falls during annealing.
        assert!(lr_0 < lr_150);
        assert!(lr_150 < lr_300);
        assert!(lr_300 > lr_650);
        assert!(lr_650 > lr_1000);
        // The peak sits at the end of warmup.
        assert_close(lr_300, 1e-3, 1e-4);
    }

    #[test]
    fn test_polynomial_scheduler() {
        let scheduler = PolynomialScheduler::new(1.0, 0.1, 1000, 2.0);
        let lr_0 = scheduler.get_lr(0);
        let lr_250 = scheduler.get_lr(250);
        let lr_500 = scheduler.get_lr(500);
        let lr_750 = scheduler.get_lr(750);
        let lr_1000 = scheduler.get_lr(1000);

        assert_eq!(lr_0, 1.0);
        assert!(lr_500 > lr_1000);
        assert_close(lr_1000, 0.1, 1e-6);
        // With power 2 the first quarter loses more than the third quarter.
        assert!(lr_0 - lr_250 > lr_500 - lr_750);
    }
}