use core::f64::consts::PI;
/// Common interface for learning-rate schedules.
///
/// An implementation maps a training step (and, for loss-aware schedules,
/// the most recent loss value) to a learning rate. `Send + Sync` so a
/// boxed scheduler can be shared across threads.
pub trait LRScheduler: Send + Sync {
/// Returns the learning rate to use at `step`.
///
/// `current_loss` is only consulted by loss-aware schedulers (e.g. the
/// plateau-based one); step-based schedulers ignore it. Takes `&mut self`
/// because some implementations update internal state on every call.
fn learning_rate(&mut self, step: u64, current_loss: f64) -> f64;
/// Restores the scheduler to its freshly-constructed state.
fn reset(&mut self);
}
/// Schedule that reports the same learning rate at every step.
#[derive(Clone, Debug)]
pub struct ConstantLR {
    lr: f64,
}

impl ConstantLR {
    /// Creates a scheduler that always yields `lr`.
    pub fn new(lr: f64) -> Self {
        ConstantLR { lr }
    }
}
impl LRScheduler for ConstantLR {
    /// Hands back the configured rate; both step and loss are ignored.
    #[inline]
    fn learning_rate(&mut self, _step: u64, _loss: f64) -> f64 {
        self.lr
    }

    /// Stateless, so resetting is a no-op.
    fn reset(&mut self) {}
}
/// Schedule that interpolates linearly from `initial_lr` down to
/// `final_lr` over `decay_steps` steps, then holds at `final_lr`.
#[derive(Clone, Debug)]
pub struct LinearDecayLR {
    initial_lr: f64,
    final_lr: f64,
    decay_steps: u64,
}

impl LinearDecayLR {
    /// Builds a linear schedule. With `decay_steps == 0` the rate is
    /// `final_lr` from the very first step.
    pub fn new(initial_lr: f64, final_lr: f64, decay_steps: u64) -> Self {
        LinearDecayLR { initial_lr, final_lr, decay_steps }
    }
}
impl LRScheduler for LinearDecayLR {
    /// Straight-line blend between the two rates: reaches `final_lr`
    /// exactly at `decay_steps` and stays clamped there afterwards.
    #[inline]
    fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
        // Fraction of the decay completed, clamped to [0, 1].
        let t = match self.decay_steps {
            0 => 1.0,
            n => (step as f64 / n as f64).min(1.0),
        };
        self.initial_lr - (self.initial_lr - self.final_lr) * t
    }

    /// Purely a function of `step`; nothing to clear.
    fn reset(&mut self) {}
}
/// Geometric decay schedule: the rate at step `n` is `initial_lr * gamma^n`.
#[derive(Clone, Debug)]
pub struct ExponentialDecayLR {
    initial_lr: f64,
    gamma: f64,
}

impl ExponentialDecayLR {
    /// Creates an exponential schedule with per-step decay factor `gamma`
    /// (typically in `(0, 1)`).
    pub fn new(initial_lr: f64, gamma: f64) -> Self {
        ExponentialDecayLR { initial_lr, gamma }
    }
}
impl LRScheduler for ExponentialDecayLR {
    /// Returns `initial_lr * gamma^step`, floored at `1e-8` so the rate
    /// never underflows to zero.
    #[inline]
    fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
        // `powi` takes an i32 exponent. Saturate instead of `as`-casting:
        // `step as i32` wraps for steps beyond i32::MAX, yielding a
        // negative exponent and an exploding learning rate. For any
        // gamma in (0, 1) the clamped result is already at the floor.
        let exponent = step.min(i32::MAX as u64) as i32;
        crate::math::fmax(
            self.initial_lr * crate::math::powi(self.gamma, exponent),
            1e-8,
        )
    }

    /// Purely a function of `step`; nothing to reset.
    fn reset(&mut self) {}
}
/// Cosine-annealing schedule with warm restarts: oscillates between
/// `max_lr` (at each period start) and `min_lr` (at each period midpoint),
/// repeating every `period` steps.
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
    max_lr: f64,
    min_lr: f64,
    period: u64,
}

impl CosineAnnealingLR {
    /// Builds a cosine schedule cycling over `period` steps.
    pub fn new(max_lr: f64, min_lr: f64, period: u64) -> Self {
        CosineAnnealingLR { max_lr, min_lr, period }
    }
}
impl LRScheduler for CosineAnnealingLR {
    /// Maps `step` onto a position within the current period and evaluates
    /// a full cosine wave there: `max_lr` at phase 0, `min_lr` at phase 0.5.
    #[inline]
    fn learning_rate(&mut self, step: u64, _current_loss: f64) -> f64 {
        // A zero period degenerates to a constant `max_lr`.
        let phase = match self.period {
            0 => 0.0,
            p => (step % p) as f64 / p as f64,
        };
        let wave = crate::math::cos(2.0 * PI * phase);
        self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + wave)
    }

    /// Purely a function of `step`; nothing to clear.
    fn reset(&mut self) {}
}
/// Reduce-on-plateau schedule: shrinks the rate by `factor` whenever the
/// observed loss fails to improve for more than `patience` consecutive
/// calls, never dropping below `min_lr`.
#[derive(Clone, Debug)]
pub struct PlateauLR {
    initial_lr: f64,
    factor: f64,
    patience: u64,
    min_lr: f64,
    best_loss: f64,
    steps_without_improvement: u64,
    current_lr: f64,
}

impl PlateauLR {
    /// Builds a plateau scheduler starting at `initial_lr` with no loss
    /// history (best loss starts at `+inf`, so the first loss always
    /// counts as an improvement).
    pub fn new(initial_lr: f64, factor: f64, patience: u64, min_lr: f64) -> Self {
        PlateauLR {
            initial_lr,
            factor,
            patience,
            min_lr,
            best_loss: f64::INFINITY,
            steps_without_improvement: 0,
            current_lr: initial_lr,
        }
    }
}
impl LRScheduler for PlateauLR {
    /// Tracks the best loss seen so far. A strict improvement resets the
    /// stall counter; otherwise the counter grows, and once it exceeds
    /// `patience` the rate is multiplied by `factor` (floored at `min_lr`)
    /// and the counter starts over.
    fn learning_rate(&mut self, _step: u64, current_loss: f64) -> f64 {
        if current_loss < self.best_loss {
            // New best: remember it and clear the stall counter.
            self.best_loss = current_loss;
            self.steps_without_improvement = 0;
            return self.current_lr;
        }
        self.steps_without_improvement += 1;
        if self.steps_without_improvement > self.patience {
            self.steps_without_improvement = 0;
            self.current_lr = (self.current_lr * self.factor).max(self.min_lr);
        }
        self.current_lr
    }

    /// Forget all observed losses and return to `initial_lr`.
    fn reset(&mut self) {
        self.best_loss = f64::INFINITY;
        self.steps_without_improvement = 0;
        self.current_lr = self.initial_lr;
    }
}
// Unit tests for every scheduler. Tolerances are deliberately tight
// (1e-10..1e-12) because the expected values are exact closed-form
// results of the same f64 arithmetic the schedulers perform.
#[cfg(test)]
mod tests {
use super::*;
use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
// ConstantLR must ignore both the step and the loss argument.
#[test]
fn test_constant_lr() {
let mut sched = ConstantLR::new(0.05);
for step in 0..100 {
let lr = sched.learning_rate(step, 999.0);
assert!(
(lr - 0.05).abs() < f64::EPSILON,
"ConstantLR should always return 0.05, got {} at step {}",
lr,
step,
);
}
}
// Spot-check the linear interpolation at the start, midpoint, and end.
#[test]
fn test_linear_decay() {
let mut sched = LinearDecayLR::new(0.1, 0.01, 100);
let lr0 = sched.learning_rate(0, 0.0);
assert!(
(lr0 - 0.1).abs() < 1e-12,
"step 0 should be initial_lr (0.1), got {}",
lr0,
);
let lr50 = sched.learning_rate(50, 0.0);
let expected_50 = 0.1 - (0.1 - 0.01) * 0.5;
assert!(
(lr50 - expected_50).abs() < 1e-12,
"step 50 should be {}, got {}",
expected_50,
lr50,
);
let lr100 = sched.learning_rate(100, 0.0);
assert!(
(lr100 - 0.01).abs() < 1e-12,
"step 100 should be final_lr (0.01), got {}",
lr100,
);
}
// Past decay_steps the rate must clamp at final_lr rather than keep falling.
#[test]
fn test_linear_decay_clamps() {
let mut sched = LinearDecayLR::new(0.1, 0.01, 50);
let lr_before = sched.learning_rate(50, 0.0);
let lr_after = sched.learning_rate(200, 0.0);
assert!(
(lr_before - 0.01).abs() < 1e-12,
"at decay_steps should be final_lr, got {}",
lr_before,
);
assert!(
(lr_after - 0.01).abs() < 1e-12,
"beyond decay_steps should still be final_lr, got {}",
lr_after,
);
}
// rate(step) == initial_lr * gamma^step for the first few steps.
#[test]
fn test_exponential_decay() {
let mut sched = ExponentialDecayLR::new(1.0, 0.9);
let lr0 = sched.learning_rate(0, 0.0);
assert!(
(lr0 - 1.0).abs() < 1e-12,
"step 0 should be initial_lr (1.0), got {}",
lr0,
);
let lr1 = sched.learning_rate(1, 0.0);
assert!(
(lr1 - 0.9).abs() < 1e-12,
"step 1 should be 0.9, got {}",
lr1,
);
let lr2 = sched.learning_rate(2, 0.0);
assert!(
(lr2 - 0.81).abs() < 1e-12,
"step 2 should be 0.81, got {}",
lr2,
);
let lr10 = sched.learning_rate(10, 0.0);
let expected_10 = 0.9_f64.powi(10);
assert!(
(lr10 - expected_10).abs() < 1e-10,
"step 10 should be {}, got {}",
expected_10,
lr10,
);
}
// At extreme steps the decay must clamp to the 1e-8 floor, not reach zero.
#[test]
fn test_exponential_floor() {
let mut sched = ExponentialDecayLR::new(1.0, 0.01);
let lr = sched.learning_rate(10_000, 0.0);
assert!(
lr >= 1e-8,
"exponential decay should floor at 1e-8, got {}",
lr,
);
assert!(
(lr - 1e-8).abs() < 1e-15,
"at extreme steps the rate should equal the floor, got {}",
lr,
);
}
// Period start = max_lr, midpoint = min_lr, quarter point matches the
// closed-form cosine expression.
#[test]
fn test_cosine_annealing() {
let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);
let lr0 = sched.learning_rate(0, 0.0);
assert!(
(lr0 - 0.1).abs() < 1e-12,
"period start should be max_lr (0.1), got {}",
lr0,
);
let lr50 = sched.learning_rate(50, 0.0);
assert!(
(lr50 - 0.01).abs() < 1e-12,
"period midpoint should be min_lr (0.01), got {}",
lr50,
);
let lr25 = sched.learning_rate(25, 0.0);
let expected_25 = 0.01 + 0.5 * (0.1 - 0.01) * (1.0 + (2.0 * PI * 0.25).cos());
assert!(
(lr25 - expected_25).abs() < 1e-12,
"quarter-period should be {}, got {}",
expected_25,
lr25,
);
}
// The schedule must wrap: step == period restarts at max_lr, and the
// second period's midpoint is min_lr again.
#[test]
fn test_cosine_boundaries() {
let mut sched = CosineAnnealingLR::new(0.1, 0.01, 100);
let at_boundary = sched.learning_rate(100, 0.0);
assert!(
(at_boundary - 0.1).abs() < 1e-12,
"step==period should wrap to max_lr, got {}",
at_boundary,
);
let second_mid = sched.learning_rate(150, 0.0);
assert!(
(second_mid - 0.01).abs() < 1e-12,
"second period midpoint should be min_lr, got {}",
second_mid,
);
}
// With patience 3, the 4th consecutive non-improving step (counter hits 4,
// exceeding patience) must halve the rate.
#[test]
fn test_plateau_reduces() {
let mut sched = PlateauLR::new(0.1, 0.5, 3, 0.001);
let lr = sched.learning_rate(0, 1.0);
assert!(
(lr - 0.1).abs() < 1e-12,
"initial rate should be 0.1, got {}",
lr
);
sched.learning_rate(1, 1.0); sched.learning_rate(2, 1.0); sched.learning_rate(3, 1.0);
let lr_reduced = sched.learning_rate(4, 1.0);
assert!(
(lr_reduced - 0.05).abs() < 1e-12,
"after patience exceeded, rate should be 0.1*0.5 = 0.05, got {}",
lr_reduced,
);
}
// An improvement (loss 0.5 at step 3) resets the stall counter, so the
// later non-improving steps never exceed patience and the rate stays put.
#[test]
fn test_plateau_improvement_resets() {
let mut sched = PlateauLR::new(0.1, 0.5, 2, 0.001);
sched.learning_rate(0, 1.0);
sched.learning_rate(1, 1.5);
sched.learning_rate(2, 1.5);
sched.learning_rate(3, 0.5);
sched.learning_rate(4, 0.6);
let lr = sched.learning_rate(5, 0.6);
assert!(
(lr - 0.1).abs() < 1e-12,
"improvement should have reset counter; rate should be 0.1, got {}",
lr,
);
}
// With patience 0 every stalled step reduces; the floor must hold at min_lr.
#[test]
fn test_plateau_min_lr() {
let mut sched = PlateauLR::new(0.1, 0.1, 0, 0.05);
sched.learning_rate(0, 1.0);
sched.learning_rate(1, 1.0);
let lr = sched.learning_rate(2, 1.0);
assert!(
lr >= 0.05 - 1e-12,
"rate should never drop below min_lr (0.05), got {}",
lr,
);
}
// reset() must restore initial_lr even after reductions, and a worse loss
// immediately after reset must not carry over the old best_loss.
#[test]
fn test_plateau_reset() {
let mut sched = PlateauLR::new(0.1, 0.5, 1, 0.001);
sched.learning_rate(0, 1.0);
sched.learning_rate(1, 1.0);
sched.learning_rate(2, 1.0);
let lr_before_reset = sched.current_lr;
assert!(
lr_before_reset < 0.1,
"rate should have decreased before reset, got {}",
lr_before_reset,
);
sched.reset();
let lr_after = sched.learning_rate(0, 10.0);
assert!(
(lr_after - 0.1).abs() < 1e-12,
"after reset, rate should be back to initial_lr (0.1), got {}",
lr_after,
);
}
// Invariant shared by every scheduler: the rate stays finite and strictly
// positive over a long run (exercised through the trait object interface).
#[test]
fn test_all_positive() {
let mut schedulers: Vec<Box<dyn LRScheduler>> = vec![
Box::new(ConstantLR::new(0.05)),
Box::new(LinearDecayLR::new(0.1, 0.001, 100)),
Box::new(ExponentialDecayLR::new(1.0, 0.99)),
Box::new(CosineAnnealingLR::new(0.1, 0.001, 50)),
Box::new(PlateauLR::new(0.1, 0.5, 5, 0.001)),
];
for (i, sched) in schedulers.iter_mut().enumerate() {
for step in 0..500 {
let lr = sched.learning_rate(step, 1.0);
assert!(
lr > 0.0,
"scheduler {} returned non-positive lr {} at step {}",
i,
lr,
step,
);
assert!(
lr.is_finite(),
"scheduler {} returned non-finite lr at step {}",
i,
step,
);
}
}
}
}