use std::f32::consts::PI;
/// One-cycle learning-rate schedule: linear warmup followed by cosine annealing.
///
/// For steps `0..warmup_steps` the rate rises linearly from `min_lr` towards
/// `max_lr`. At step `warmup_steps` it equals `max_lr` (provided
/// `warmup_steps < total_steps`). From there it follows a half-cosine back down
/// to `min_lr`, which it reaches at `total_steps`. [`OneCycleLr::step`] never
/// advances the counter past `total_steps`, so the schedule then holds at its
/// final value.
#[derive(Debug, Clone)]
pub struct OneCycleLr {
    /// Peak learning rate, reached at the first step after warmup.
    pub max_lr: f32,
    /// Learning rate at the start of warmup and at the end of the decay.
    pub min_lr: f32,
    /// Length of the whole schedule in steps.
    pub total_steps: usize,
    /// Number of steps spent in the linear warmup phase.
    pub warmup_steps: usize,
    /// Steps taken so far; `step()` clamps it to `total_steps`.
    step: usize,
}

impl OneCycleLr {
    /// Creates a schedule that warms up for 30% of `total_steps` (truncated)
    /// and starts/ends at `max_lr / 10_000`.
    pub fn new(max_lr: f32, total_steps: usize) -> Self {
        let warmup_steps = (total_steps as f32 * 0.3) as usize;
        let min_lr = max_lr / 10_000.0_f32;
        Self {
            max_lr,
            min_lr,
            total_steps,
            warmup_steps,
            step: 0,
        }
    }

    /// Sets the warmup length to `fraction * total_steps`, truncated toward zero.
    ///
    /// `fraction` is clamped to `[0.0, 1.0]`, so the computed warmup never
    /// exceeds `total_steps`.
    #[must_use]
    pub fn with_warmup_fraction(mut self, fraction: f32) -> Self {
        let fraction = fraction.clamp(0.0, 1.0);
        self.warmup_steps = (self.total_steps as f32 * fraction) as usize;
        self
    }

    /// Overrides the learning rate used at the start of warmup and at the end of decay.
    #[must_use]
    pub fn with_min_lr(mut self, min_lr: f32) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Returns the learning rate for the current step without advancing the schedule.
    pub fn current_lr(&self) -> f32 {
        let s = self.step.min(self.total_steps);
        if s < self.warmup_steps {
            // `s < warmup_steps` already guarantees `warmup_steps >= 1`, so the
            // division is well-defined and no separate zero check is needed.
            let t = s as f32 / self.warmup_steps as f32;
            self.min_lr + t * (self.max_lr - self.min_lr)
        } else {
            let decay_steps = self.total_steps.saturating_sub(self.warmup_steps);
            if decay_steps == 0 {
                // Warmup consumed the whole schedule: there is no decay phase.
                return self.min_lr;
            }
            let elapsed = s.saturating_sub(self.warmup_steps);
            let progress = (elapsed as f32 / decay_steps as f32).min(1.0);
            // Half-cosine from max_lr (progress = 0) down to min_lr (progress = 1).
            self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + (PI * progress).cos())
        }
    }

    /// Returns the learning rate for the current step, then advances by one.
    ///
    /// The counter saturates at `total_steps`; further calls keep returning
    /// the final learning rate.
    pub fn step(&mut self) -> f32 {
        let lr = self.current_lr();
        self.step = (self.step + 1).min(self.total_steps);
        lr
    }

    /// Fraction of the schedule completed, in `[0.0, 1.0]`.
    ///
    /// An empty schedule (`total_steps == 0`) reports `1.0`.
    pub fn progress(&self) -> f32 {
        if self.total_steps == 0 {
            return 1.0;
        }
        (self.step as f32 / self.total_steps as f32).min(1.0)
    }
}
/// Which direction of a monitored metric counts as an improvement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlateauMode {
    /// Smaller metric values are better.
    Min,
    /// Larger metric values are better.
    Max,
}

/// Scheduler that shrinks the learning rate when a metric stops improving.
///
/// Each call to [`ReduceOnPlateau::step`] reports one metric observation.
/// After `patience` consecutive non-improving observations, the learning rate
/// is multiplied by `factor`. The result is never allowed to drop below `min_lr`.
#[derive(Debug, Clone)]
pub struct ReduceOnPlateau {
    /// Current learning rate.
    lr: f32,
    /// Multiplier applied on each reduction (default `0.5`).
    factor: f32,
    /// Consecutive non-improving observations that trigger a reduction.
    patience: usize,
    /// Lower bound on the learning rate (default `1e-8`).
    min_lr: f32,
    /// Best metric value observed so far.
    best_metric: f32,
    /// Non-improving observations since the last improvement or reduction attempt.
    bad_steps: usize,
    /// Whether lower or higher metrics are better.
    mode: PlateauMode,
    /// Number of times the learning rate was actually lowered.
    reduction_count: usize,
}

impl ReduceOnPlateau {
    /// Builds a scheduler starting at `initial_lr` with `factor = 0.5` and `min_lr = 1e-8`.
    pub fn new(initial_lr: f32, patience: usize, mode: PlateauMode) -> Self {
        // Seed the best metric with the worst possible value so that any finite
        // first observation registers as an improvement.
        let worst_possible = match mode {
            PlateauMode::Min => f32::INFINITY,
            PlateauMode::Max => f32::NEG_INFINITY,
        };
        Self {
            mode,
            patience,
            lr: initial_lr,
            min_lr: 1e-8,
            factor: 0.5,
            best_metric: worst_possible,
            bad_steps: 0,
            reduction_count: 0,
        }
    }

    /// Replaces the multiplicative reduction factor.
    pub fn with_factor(self, factor: f32) -> Self {
        Self { factor, ..self }
    }

    /// Replaces the learning-rate floor.
    pub fn with_min_lr(self, min_lr: f32) -> Self {
        Self { min_lr, ..self }
    }

    /// Records one metric observation and returns the (possibly reduced) learning rate.
    ///
    /// Only strict improvements reset the patience counter. A `NaN` metric never
    /// compares as better, so it counts as a non-improving observation. With
    /// `patience == 0`, every non-improving observation triggers a reduction attempt.
    pub fn step(&mut self, metric: f32) -> f32 {
        let better = match self.mode {
            PlateauMode::Min => self.best_metric > metric,
            PlateauMode::Max => self.best_metric < metric,
        };
        if better {
            self.best_metric = metric;
            self.bad_steps = 0;
            return self.lr;
        }

        self.bad_steps += 1;
        if self.bad_steps < self.patience {
            return self.lr;
        }

        // Patience exhausted: restart the count and try to shrink the LR. Once the
        // floor is hit, the candidate equals the current LR and nothing is counted.
        self.bad_steps = 0;
        let shrunk = (self.lr * self.factor).max(self.min_lr);
        if shrunk < self.lr {
            self.lr = shrunk;
            self.reduction_count += 1;
        }
        self.lr
    }

    /// Current learning rate, without recording an observation.
    pub fn current_lr(&self) -> f32 {
        self.lr
    }

    /// How many times the learning rate has actually been lowered.
    pub fn times_reduced(&self) -> usize {
        self.reduction_count
    }
}
/// Linear warmup from zero to `max_lr`, followed by cosine decay to `min_lr`.
///
/// The warmup ramp always starts at `0.0`, independent of `min_lr`; `min_lr`
/// is only the floor the cosine phase decays to. If `warmup_steps` exceeds
/// `total_steps`, the schedule never leaves the warmup ramp.
#[derive(Debug, Clone)]
pub struct LinearWarmupCosineDecay {
    /// Learning rate at the first step after warmup.
    pub max_lr: f32,
    /// Learning rate at the end of the cosine phase (default `0.0`).
    pub min_lr: f32,
    /// Number of steps in the linear warmup ramp.
    pub warmup_steps: usize,
    /// Length of the whole schedule in steps.
    pub total_steps: usize,
    /// Steps taken so far; `step()` clamps it to `total_steps`.
    step: usize,
}

impl LinearWarmupCosineDecay {
    /// Creates a schedule with `min_lr = 0.0`.
    pub fn new(max_lr: f32, warmup_steps: usize, total_steps: usize) -> Self {
        Self {
            max_lr,
            min_lr: 0.0,
            warmup_steps,
            total_steps,
            step: 0,
        }
    }

    /// Overrides the learning rate reached at the end of the cosine phase.
    #[must_use]
    pub fn with_min_lr(mut self, min_lr: f32) -> Self {
        self.min_lr = min_lr;
        self
    }

    /// Returns the learning rate for the current step without advancing the schedule.
    pub fn current_lr(&self) -> f32 {
        let s = self.step.min(self.total_steps);
        if s < self.warmup_steps {
            // `s < warmup_steps` already guarantees `warmup_steps >= 1`, so the
            // division is well-defined and no separate zero check is needed.
            self.max_lr * (s as f32 / self.warmup_steps as f32)
        } else {
            let cosine_steps = self.total_steps.saturating_sub(self.warmup_steps);
            if cosine_steps == 0 {
                // Warmup consumed the whole schedule: there is no decay phase.
                return self.min_lr;
            }
            let elapsed = s.saturating_sub(self.warmup_steps);
            let progress = (elapsed as f32 / cosine_steps as f32).min(1.0);
            // Half-cosine from max_lr (progress = 0) down to min_lr (progress = 1).
            self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + (PI * progress).cos())
        }
    }

    /// Returns the learning rate for the current step, then advances by one.
    ///
    /// The counter saturates at `total_steps`; further calls keep returning
    /// the final learning rate.
    pub fn step(&mut self) -> f32 {
        let lr = self.current_lr();
        self.step = (self.step + 1).min(self.total_steps);
        lr
    }
}
/// Polynomial decay from `initial_lr` to `end_lr` over `total_steps` steps.
///
/// At step `k < total_steps` the rate is
/// `end_lr + (initial_lr - end_lr) * (1 - k / total_steps)^power`;
/// from `total_steps` onwards it is exactly `end_lr`.
#[derive(Debug, Clone)]
pub struct PolynomialDecay {
    /// Learning rate at step 0.
    pub initial_lr: f32,
    /// Learning rate once the schedule is exhausted.
    pub end_lr: f32,
    /// Number of steps over which the decay happens.
    pub total_steps: usize,
    /// Exponent of the decay curve; `1.0` gives a straight line.
    pub power: f32,
    /// Steps taken so far; `step()` clamps it to `total_steps`.
    step: usize,
}

impl PolynomialDecay {
    /// Creates a decay schedule starting at step 0.
    pub fn new(initial_lr: f32, end_lr: f32, total_steps: usize, power: f32) -> Self {
        Self {
            step: 0,
            initial_lr,
            end_lr,
            total_steps,
            power,
        }
    }

    /// Returns the learning rate for the current step without advancing the schedule.
    pub fn current_lr(&self) -> f32 {
        // `step >= total_steps` also covers the empty schedule (`total_steps == 0`).
        if self.step >= self.total_steps {
            return self.end_lr;
        }
        let fraction_done = self.step as f32 / self.total_steps as f32;
        let remaining_weight = (1.0 - fraction_done).powf(self.power);
        self.end_lr + (self.initial_lr - self.end_lr) * remaining_weight
    }

    /// Returns the learning rate for the current step, then advances by one
    /// (saturating at `total_steps`).
    pub fn step(&mut self) -> f32 {
        let previous = self.current_lr();
        self.step = self.total_steps.min(self.step + 1);
        previous
    }
}
/// Triangular cyclic learning-rate schedule.
///
/// The rate climbs linearly from `base_lr` to `max_lr` over `step_size` steps,
/// then descends back to `base_lr` over the next `step_size` steps, and repeats
/// with period `2 * step_size`. A `step_size` of zero pins the rate at `base_lr`.
#[derive(Debug, Clone)]
pub struct CyclicLr {
    /// Learning rate at the bottom of each cycle.
    pub base_lr: f32,
    /// Learning rate at the top of each cycle.
    pub max_lr: f32,
    /// Length of each half-cycle (rising or falling) in steps.
    pub step_size: usize,
    /// Total steps taken so far.
    step: usize,
}

impl CyclicLr {
    /// Creates a schedule positioned at the start of its first cycle.
    pub fn new(base_lr: f32, max_lr: f32, step_size: usize) -> Self {
        Self {
            step: 0,
            base_lr,
            max_lr,
            step_size,
        }
    }

    /// Returns `(offset within the current cycle, cycle length)`, or `None`
    /// when `step_size == 0` and no cycle is defined.
    fn phase(&self) -> Option<(usize, usize)> {
        (self.step_size > 0).then(|| {
            let cycle_len = 2 * self.step_size;
            (self.step % cycle_len, cycle_len)
        })
    }

    /// Position within the current cycle as a fraction in `[0.0, 1.0)`;
    /// `0.0` when `step_size == 0`.
    pub fn cycle_position(&self) -> f32 {
        self.phase()
            .map_or(0.0, |(offset, cycle_len)| offset as f32 / cycle_len as f32)
    }

    /// Returns the learning rate for the current step without advancing.
    pub fn current_lr(&self) -> f32 {
        match self.phase() {
            None => self.base_lr,
            Some((offset, _)) => {
                let half = self.step_size;
                // Rising half: 0 → 1; falling half: 1 → 0.
                let t = if offset < half {
                    offset as f32 / half as f32
                } else {
                    1.0 - (offset - half) as f32 / half as f32
                };
                self.base_lr + t * (self.max_lr - self.base_lr)
            }
        }
    }

    /// Returns the learning rate for the current step, then advances by one.
    pub fn step(&mut self) -> f32 {
        let current = self.current_lr();
        self.step += 1;
        current
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for float comparisons.
    const EPS: f32 = 1e-5;

    /// Returns `true` if `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn onecycle_starts_at_min_lr() {
        let sched = OneCycleLr::new(1.0, 100)
            .with_min_lr(0.01)
            .with_warmup_fraction(0.3);
        let lr = sched.current_lr();
        assert!(
            approx_eq(lr, 0.01, 1e-3),
            "first LR should be ~min_lr, got {lr}"
        );
    }

    #[test]
    fn onecycle_peaks_at_warmup() {
        let total = 100_usize;
        let mut sched = OneCycleLr::new(1.0, total)
            .with_min_lr(0.0)
            .with_warmup_fraction(0.3);
        let warmup_steps = sched.warmup_steps;
        let lrs: Vec<f32> = (0..total).map(|_| sched.step()).collect();
        // The maximum is the first decay step (index `warmup_steps`), where the
        // cosine term is exactly 1; the last warmup step is still below max_lr.
        let (peak_idx, peak_lr) = lrs.iter().copied().enumerate().fold(
            (0, f32::NEG_INFINITY),
            |best, (i, lr)| if lr > best.1 { (i, lr) } else { best },
        );
        assert_eq!(
            peak_idx, warmup_steps,
            "LR should peak exactly at the warmup boundary"
        );
        assert!(
            approx_eq(peak_lr, 1.0, EPS),
            "peak LR should equal max_lr, got {peak_lr}"
        );
    }

    #[test]
    fn onecycle_ends_at_min_lr() {
        let total = 100_usize;
        let max_lr = 1.0_f32;
        let min_lr = 1e-4_f32;
        let sched = {
            let mut s = OneCycleLr::new(max_lr, total)
                .with_min_lr(min_lr)
                .with_warmup_fraction(0.3);
            for _ in 0..total {
                s.step();
            }
            s
        };
        let lr = sched.current_lr();
        assert!(
            approx_eq(lr, min_lr, min_lr * 10.0),
            "final LR should be ~min_lr, got {lr}"
        );
    }

    #[test]
    fn onecycle_progress_monotone() {
        let total = 50_usize;
        let mut sched = OneCycleLr::new(1.0, total);
        let mut prev = sched.progress();
        for _ in 0..total {
            sched.step();
            let p = sched.progress();
            assert!(p >= prev, "progress must be non-decreasing: {prev} → {p}");
            prev = p;
        }
        assert!(approx_eq(prev, 1.0, EPS), "progress must reach 1.0 at end");
    }

    #[test]
    fn reduce_plateau_min_mode_reduces_lr() {
        let patience = 3_usize;
        let mut sched = ReduceOnPlateau::new(1e-2, patience, PlateauMode::Min);
        // The first observation always improves on the initial +inf best.
        sched.step(1.0);
        // One short of patience: no reduction yet.
        for _ in 1..patience {
            sched.step(1.0);
        }
        assert_eq!(
            sched.times_reduced(),
            0,
            "must not reduce before patience is exhausted"
        );
        // The `patience`-th non-improving observation triggers the reduction.
        sched.step(1.0);
        assert_eq!(
            sched.times_reduced(),
            1,
            "should have reduced once after patience steps"
        );
        assert!(sched.current_lr() < 1e-2, "LR should have decreased");
    }

    #[test]
    fn reduce_plateau_improvement_keeps_lr() {
        let mut sched = ReduceOnPlateau::new(1e-2, 3, PlateauMode::Min);
        for i in 0..20_usize {
            sched.step(1.0 / (i + 1) as f32);
        }
        assert_eq!(
            sched.times_reduced(),
            0,
            "should not reduce when metric improves"
        );
        assert!(approx_eq(sched.current_lr(), 1e-2, EPS));
    }

    #[test]
    fn reduce_plateau_min_lr_floor() {
        let min_lr = 1e-5_f32;
        let mut sched = ReduceOnPlateau::new(1e-3, 1, PlateauMode::Min).with_min_lr(min_lr);
        for _ in 0..100 {
            sched.step(1.0);
        }
        let lr = sched.current_lr();
        assert!(lr >= min_lr, "LR must never go below min_lr, got {lr}");
        // 99 plateau steps with patience 1 are far more than enough halvings
        // to drive 1e-3 down to the floor.
        assert!(
            approx_eq(lr, min_lr, 1e-9),
            "LR should settle exactly at min_lr, got {lr}"
        );
    }

    #[test]
    fn linear_warmup_cosine_warmup_phase_increases() {
        let warmup = 10_usize;
        let total = 100_usize;
        let mut sched = LinearWarmupCosineDecay::new(1.0, warmup, total);
        let mut prev = -1.0_f32;
        for _ in 0..warmup {
            let lr = sched.step();
            assert!(lr >= prev, "LR must increase during warmup: {prev} → {lr}");
            prev = lr;
        }
    }

    #[test]
    fn linear_warmup_cosine_decay_phase_decreases() {
        let warmup = 10_usize;
        let total = 100_usize;
        let mut sched = LinearWarmupCosineDecay::new(1.0, warmup, total).with_min_lr(0.0);
        for _ in 0..warmup {
            sched.step();
        }
        let mut prev = f32::INFINITY;
        for _ in warmup..total {
            let lr = sched.step();
            assert!(
                lr <= prev + EPS,
                "LR must decrease (or stay) during decay: {prev} → {lr}"
            );
            prev = lr;
        }
    }

    #[test]
    fn polynomial_decay_starts_at_initial_lr() {
        let sched = PolynomialDecay::new(1e-3, 1e-6, 1000, 1.0);
        let first = sched.current_lr();
        assert!(
            approx_eq(first, 1e-3, 1e-7),
            "should start at initial_lr, got {first}"
        );
    }

    #[test]
    fn polynomial_decay_ends_at_end_lr() {
        let end_lr = 1e-6_f32;
        let mut sched = PolynomialDecay::new(1e-3, end_lr, 100, 1.0);
        for _ in 0..100 {
            sched.step();
        }
        let last = sched.current_lr();
        assert!(
            approx_eq(last, end_lr, 1e-9),
            "should end at end_lr, got {last}"
        );
    }

    #[test]
    fn cyclic_lr_oscillates() {
        let base = 1e-4_f32;
        let max = 1e-2_f32;
        let step_size = 10_usize;
        let mut sched = CyclicLr::new(base, max, step_size);
        let lrs: Vec<f32> = (0..2 * step_size).map(|_| sched.step()).collect();
        for i in 1..step_size {
            assert!(
                lrs[i] >= lrs[i - 1] - EPS,
                "should rise in first half: lrs[{i}]={} < lrs[{}]={}",
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }
        for i in (step_size + 1)..(2 * step_size) {
            assert!(
                lrs[i] <= lrs[i - 1] + EPS,
                "should fall in second half: lrs[{i}]={} > lrs[{}]={}",
                lrs[i],
                i - 1,
                lrs[i - 1]
            );
        }
    }

    #[test]
    fn cyclic_lr_period_is_two_step_size() {
        let step_size = 20_usize;
        let mut sched = CyclicLr::new(0.0, 1.0, step_size);
        let lrs_first: Vec<f32> = (0..2 * step_size).map(|_| sched.step()).collect();
        let lrs_second: Vec<f32> = (0..2 * step_size).map(|_| sched.step()).collect();
        for (a, b) in lrs_first.iter().zip(lrs_second.iter()) {
            assert!(
                approx_eq(*a, *b, EPS),
                "cyclic LR must repeat with period 2*step_size: {a} vs {b}"
            );
        }
    }
}