use scivex_core::Float;
/// A learning-rate schedule queried by training step.
///
/// Both methods take `&self`, so a schedule is looked up rather than advanced; the
/// first training step is `0`.
pub trait LrScheduler<T>
where
    T: Float,
{
    /// Returns the learning rate the schedule was constructed with.
    fn base_lr(&self) -> T;

    /// Returns the learning rate to use at training step `step`.
    fn get_lr(&self, step: usize) -> T;
}
/// Decays the learning rate by `gamma` once every `step_size` steps.
///
/// The rate at `step` is `base_lr * gamma^(step / step_size)` using integer (floor)
/// division, so it is constant within each block of `step_size` steps.
#[derive(Debug, Clone)]
pub struct StepLR<T: Float> {
    base: T,
    step_size: usize,
    gamma: T,
}

impl<T: Float> StepLR<T> {
    /// Creates a step-decay schedule.
    ///
    /// * `base_lr` – rate for steps `0..step_size`.
    /// * `step_size` – number of steps between successive decays.
    /// * `gamma` – multiplicative factor applied at each decay.
    ///
    /// # Panics
    ///
    /// Panics if `step_size` is zero: the number of elapsed decay periods would
    /// otherwise require dividing by zero in [`LrScheduler::get_lr`].
    pub fn new(base_lr: T, step_size: usize, gamma: T) -> Self {
        assert!(step_size > 0, "StepLR: step_size must be greater than zero");
        Self {
            base: base_lr,
            step_size,
            gamma,
        }
    }
}

impl<T: Float> LrScheduler<T> for StepLR<T> {
    fn get_lr(&self, step: usize) -> T {
        // Completed decay periods; `new` guarantees `step_size > 0`.
        let decays = step / self.step_size;
        self.base * self.gamma.powf(T::from_usize(decays))
    }

    fn base_lr(&self) -> T {
        self.base
    }
}
/// Multiplies the learning rate by `gamma` at every step: `base_lr * gamma^step`.
#[derive(Debug, Clone)]
pub struct ExponentialLR<T: Float> {
    base: T,
    gamma: T,
}

impl<T: Float> ExponentialLR<T> {
    /// Creates an exponential-decay schedule.
    ///
    /// * `base_lr` – rate at step `0`.
    /// * `gamma` – per-step multiplicative factor.
    pub fn new(base_lr: T, gamma: T) -> Self {
        Self {
            base: base_lr,
            gamma,
        }
    }
}

impl<T: Float> LrScheduler<T> for ExponentialLR<T> {
    fn get_lr(&self, step: usize) -> T {
        let exponent = T::from_usize(step);
        self.base * self.gamma.powf(exponent)
    }

    fn base_lr(&self) -> T {
        self.base
    }
}
/// Cosine annealing from `base_lr` to `eta_min` over `t_max` steps.
///
/// For `step <= t_max` the rate is
/// `eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2`;
/// for `step > t_max` it is held at `eta_min` (there is no warm restart).
#[derive(Debug, Clone)]
pub struct CosineAnnealingLR<T: Float> {
    base: T,
    t_max: usize,
    eta_min: T,
}

impl<T: Float> CosineAnnealingLR<T> {
    /// Creates a cosine-annealing schedule.
    ///
    /// * `base_lr` – rate at step `0`.
    /// * `t_max` – number of steps over which to anneal; `0` yields `eta_min` at
    ///   every step.
    /// * `eta_min` – rate reached at step `t_max` and held afterwards.
    pub fn new(base_lr: T, t_max: usize, eta_min: T) -> Self {
        Self {
            base: base_lr,
            t_max,
            eta_min,
        }
    }
}

impl<T: Float> LrScheduler<T> for CosineAnnealingLR<T> {
    fn get_lr(&self, step: usize) -> T {
        // With a zero-length window every step is already past the end of the
        // annealing period; return the floor instead of evaluating 0/0 (NaN).
        if self.t_max == 0 {
            return self.eta_min;
        }
        let clamped = step.min(self.t_max);
        let pi = T::from_f64(std::f64::consts::PI);
        let ratio = T::from_usize(clamped) / T::from_usize(self.t_max);
        let cos_val = (pi * ratio).cos();
        self.eta_min + T::from_f64(0.5) * (self.base - self.eta_min) * (T::one() + cos_val)
    }

    fn base_lr(&self) -> T {
        self.base
    }
}
/// Scales `base_lr` by a factor that moves linearly from `start_factor` to
/// `end_factor` over `total_steps` steps.
///
/// The factor at `step` is
/// `start_factor + (end_factor - start_factor) * min(step, total_steps) / total_steps`,
/// so it stays at `end_factor` once `step >= total_steps`.
#[derive(Debug, Clone)]
pub struct LinearLR<T: Float> {
    base: T,
    start_factor: T,
    end_factor: T,
    total_steps: usize,
}

impl<T: Float> LinearLR<T> {
    /// Creates a linear-ramp schedule.
    ///
    /// * `base_lr` – rate the factors are applied to.
    /// * `start_factor` – multiplier at step `0`.
    /// * `end_factor` – multiplier from step `total_steps` onward.
    /// * `total_steps` – length of the ramp; `0` yields `base_lr * end_factor` at
    ///   every step.
    pub fn new(base_lr: T, start_factor: T, end_factor: T, total_steps: usize) -> Self {
        Self {
            base: base_lr,
            start_factor,
            end_factor,
            total_steps,
        }
    }
}

impl<T: Float> LrScheduler<T> for LinearLR<T> {
    fn get_lr(&self, step: usize) -> T {
        // A zero-length ramp means every step is at (or past) its end; short-circuit
        // instead of computing 0/0, which would produce NaN.
        if self.total_steps == 0 {
            return self.base * self.end_factor;
        }
        let clamped = step.min(self.total_steps);
        let progress = T::from_usize(clamped) / T::from_usize(self.total_steps);
        let factor = self.start_factor + (self.end_factor - self.start_factor) * progress;
        self.base * factor
    }

    fn base_lr(&self) -> T {
        self.base
    }
}
/// Linear warmup to `base_lr` followed by cosine decay to `eta_min`.
///
/// * Steps `0..warmup_steps`: `base_lr * (step + 1) / warmup_steps`, reaching
///   `base_lr` on step `warmup_steps - 1`.
/// * Later steps: cosine decay from `base_lr` to `eta_min` over
///   `total_steps - warmup_steps` steps, then held at `eta_min`.
/// * If there is no decay phase (`total_steps <= warmup_steps`), the rate stays at
///   `base_lr` after warmup.
#[derive(Debug, Clone)]
pub struct WarmupCosineDecay<T: Float> {
    base: T,
    warmup_steps: usize,
    total_steps: usize,
    eta_min: T,
}

impl<T: Float> WarmupCosineDecay<T> {
    /// Creates a warmup + cosine-decay schedule.
    ///
    /// * `base_lr` – peak rate reached at the end of warmup.
    /// * `warmup_steps` – number of linear warmup steps (`0` disables warmup).
    /// * `total_steps` – step at which the decay reaches `eta_min`.
    /// * `eta_min` – rate held once decay completes.
    pub fn new(base_lr: T, warmup_steps: usize, total_steps: usize, eta_min: T) -> Self {
        Self {
            base: base_lr,
            warmup_steps,
            total_steps,
            eta_min,
        }
    }
}

impl<T: Float> LrScheduler<T> for WarmupCosineDecay<T> {
    fn get_lr(&self, step: usize) -> T {
        if step < self.warmup_steps {
            // `step + 1` cannot overflow here because `step < warmup_steps`.
            return self.base * T::from_usize(step + 1) / T::from_usize(self.warmup_steps);
        }
        // A warmup longer than the whole schedule leaves no decay phase; a plain
        // subtraction would underflow (panic in debug, wrap to a huge length in release).
        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
        if decay_steps == 0 {
            return self.base;
        }
        let progress = (step - self.warmup_steps).min(decay_steps);
        let pi = T::from_f64(std::f64::consts::PI);
        let ratio = T::from_usize(progress) / T::from_usize(decay_steps);
        let cos_val = (pi * ratio).cos();
        self.eta_min + T::from_f64(0.5) * (self.base - self.eta_min) * (T::one() + cos_val)
    }

    fn base_lr(&self) -> T {
        self.base
    }
}
/// Reduces the learning rate when a monitored metric stops improving.
///
/// The scheduler is driven by metric values passed to [`ReduceLROnPlateau::report`]
/// rather than by a step index. A report is an improvement only when it is strictly
/// better than the best value seen so far: lower when `mode_min` is `true`, higher
/// otherwise. After `patience` consecutive non-improving reports, the rate is
/// multiplied by `factor` and floored at `min_lr`, and the counter is reset.
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateau<T: Float> {
    current_lr: T,
    factor: T,
    patience: usize,
    min_lr: T,
    /// Best metric observed so far; `None` until the first non-NaN report.
    best: Option<T>,
    /// Consecutive non-improving reports since the last improvement or reduction.
    num_bad_epochs: usize,
    /// `true`: a smaller metric is better; `false`: a larger metric is better.
    mode_min: bool,
}

impl<T: Float> ReduceLROnPlateau<T> {
    /// Creates a plateau scheduler.
    ///
    /// * `initial_lr` – rate in effect until the first reduction.
    /// * `factor` – multiplier applied on each reduction; must be in `(0, 1)` for
    ///   the rate to actually decrease.
    /// * `patience` – consecutive non-improving reports that trigger a reduction;
    ///   `0` reduces on every non-improving report.
    /// * `min_lr` – lower bound applied to reductions.
    /// * `mode_min` – `true` if a smaller metric is better (e.g. a loss).
    pub fn new(initial_lr: T, factor: T, patience: usize, min_lr: T, mode_min: bool) -> Self {
        Self {
            current_lr: initial_lr,
            factor,
            patience,
            min_lr,
            best: None,
            num_bad_epochs: 0,
            mode_min,
        }
    }

    /// Records `metric` for the latest epoch and returns the learning rate now in effect.
    ///
    /// A NaN metric never counts as an improvement and is never stored as the best value.
    /// A reduction never raises the rate: if the floored value is not below the
    /// current rate (for example, because `current_lr` already started below `min_lr`),
    /// the rate is left unchanged.
    pub fn report(&mut self, metric: T) -> T {
        // NaN is unordered even with itself. Recording it as `best` would make every
        // later comparison false, so no future report could ever be an improvement.
        let is_nan = metric.partial_cmp(&metric).is_none();
        let improved = !is_nan
            && match self.best {
                None => true,
                Some(best) if self.mode_min => metric < best,
                Some(best) => metric > best,
            };

        if improved {
            self.best = Some(metric);
            self.num_bad_epochs = 0;
            return self.current_lr;
        }

        self.num_bad_epochs += 1;
        if self.num_bad_epochs >= self.patience {
            let scaled = self.current_lr * self.factor;
            let floored = if scaled >= self.min_lr {
                scaled
            } else {
                self.min_lr
            };
            // Only ever move the rate down; clamping up to `min_lr` must not raise a
            // rate that is already below it.
            if floored < self.current_lr {
                self.current_lr = floored;
            }
            self.num_bad_epochs = 0;
        }
        self.current_lr
    }

    /// Returns the learning rate currently in effect.
    pub fn current_lr(&self) -> T {
        self.current_lr
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`, printing both
    /// values on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (tolerance {tol}), got {actual}"
        );
    }

    #[test]
    fn step_lr_basic() {
        let s = StepLR::new(0.1_f64, 10, 0.1);
        assert_close(s.get_lr(0), 0.1, 1e-10);
        assert_close(s.get_lr(9), 0.1, 1e-10);
        assert_close(s.get_lr(10), 0.01, 1e-10);
        assert_close(s.get_lr(20), 0.001, 1e-10);
    }

    #[test]
    fn exponential_lr_basic() {
        let s = ExponentialLR::new(1.0_f64, 0.9);
        assert_close(s.get_lr(0), 1.0, 1e-10);
        assert_close(s.get_lr(1), 0.9, 1e-10);
        assert_close(s.get_lr(2), 0.81, 1e-10);
    }

    #[test]
    fn cosine_annealing_endpoints() {
        let s = CosineAnnealingLR::new(0.1_f64, 100, 0.0);
        assert_close(s.get_lr(0), 0.1, 1e-10);
        assert_close(s.get_lr(100), 0.0, 1e-10);
        assert_close(s.get_lr(50), 0.05, 1e-6);
    }

    #[test]
    fn linear_lr_ramp() {
        let s = LinearLR::new(0.1_f64, 0.1, 1.0, 10);
        assert_close(s.get_lr(0), 0.01, 1e-10);
        assert_close(s.get_lr(10), 0.1, 1e-10);
        assert_close(s.get_lr(5), 0.055, 1e-10);
    }

    #[test]
    fn warmup_cosine_decay_phases() {
        let s = WarmupCosineDecay::new(0.1_f64, 10, 110, 0.0);
        // Warmup phase: base * (step + 1) / warmup_steps.
        assert_close(s.get_lr(0), 0.01, 1e-10);
        assert_close(s.get_lr(4), 0.05, 1e-10);
        assert_close(s.get_lr(9), 0.1, 1e-10);
        // Decay phase starts at the peak and ends at eta_min.
        assert_close(s.get_lr(10), 0.1, 1e-6);
        assert_close(s.get_lr(110), 0.0, 1e-6);
    }

    #[test]
    fn reduce_on_plateau_reduces() {
        let mut s = ReduceLROnPlateau::new(0.1_f64, 0.5, 3, 0.001, true);
        assert_close(s.report(1.0), 0.1, 1e-10); // first report is always an improvement
        assert_close(s.report(0.9), 0.1, 1e-10); // improvement
        assert_close(s.report(0.95), 0.1, 1e-10); // bad epoch 1
        assert_close(s.report(0.95), 0.1, 1e-10); // bad epoch 2
        assert_close(s.report(0.95), 0.05, 1e-10); // bad epoch 3 reaches patience: halve
    }

    #[test]
    fn reduce_on_plateau_respects_min_lr() {
        let mut s = ReduceLROnPlateau::new(0.01_f64, 0.1, 1, 0.005, true);
        s.report(1.0);
        // 0.01 * 0.1 = 0.001 is below min_lr, so the rate is floored at 0.005.
        assert_close(s.report(2.0), 0.005, 1e-10);
    }
}