/// Common interface shared by every learning-rate schedule in this module.
///
/// The `Send + Sync` supertraits let boxed schedulers (`Box<dyn LRScheduler>`)
/// be moved to or shared with other threads.
pub trait LRScheduler: Sync + Send {
    /// Returns the learning rate for the given training `step`.
    fn get_lr(&self, step: usize) -> f32;

    /// Advances the scheduler's internal state by one training step.
    fn step(&mut self);
}
/// Linear warmup followed by linear decay to zero.
///
/// For `step < warmup_steps` the rate ramps linearly from `0` towards
/// `base_lr`; afterwards it falls linearly from `base_lr`, reaching `0` at
/// `total_steps` and staying at `0` for every later step.
#[derive(Debug)]
pub struct LinearScheduler {
    base_lr: f32,
    warmup_steps: usize,
    total_steps: usize,
    current_step: usize,
}

impl LinearScheduler {
    /// Creates a linear warmup/decay schedule.
    ///
    /// If `total_steps <= warmup_steps` the decay phase is treated as one step
    /// long: the rate is `base_lr` at `warmup_steps` and `0` afterwards.
    pub fn new(base_lr: f32, warmup_steps: usize, total_steps: usize) -> Self {
        Self {
            base_lr,
            warmup_steps,
            total_steps,
            current_step: 0,
        }
    }
}

impl LRScheduler for LinearScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        if step < self.warmup_steps {
            self.base_lr * (step as f32) / (self.warmup_steps as f32)
        } else {
            // `saturating_sub` prevents a usize underflow panic when
            // total_steps < warmup_steps, and `max(1)` prevents 0/0 = NaN
            // when the two are equal.
            let decay_steps = self.total_steps.saturating_sub(self.warmup_steps).max(1);
            let progress = (step - self.warmup_steps) as f32 / decay_steps as f32;
            self.base_lr * (1.0 - progress).max(0.0)
        }
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
/// Linear warmup followed by cosine decay from `base_lr` down to `min_lr`.
///
/// The decay spans `total_steps - warmup_steps` steps; for every step at or
/// beyond `total_steps` the rate stays at `min_lr`.
#[derive(Debug)]
pub struct CosineScheduler {
    base_lr: f32,
    warmup_steps: usize,
    total_steps: usize,
    current_step: usize,
    min_lr: f32,
}

impl CosineScheduler {
    /// Creates a cosine schedule.
    ///
    /// If `total_steps <= warmup_steps` the decay phase is treated as one step
    /// long: `base_lr` at `warmup_steps`, `min_lr` afterwards.
    pub fn new(base_lr: f32, warmup_steps: usize, total_steps: usize, min_lr: f32) -> Self {
        Self {
            base_lr,
            warmup_steps,
            total_steps,
            current_step: 0,
            min_lr,
        }
    }
}

impl LRScheduler for CosineScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        use std::f32::consts::PI;
        if step < self.warmup_steps {
            self.base_lr * (step as f32) / (self.warmup_steps as f32)
        } else {
            // Guard the denominator against usize underflow and 0/0 = NaN.
            let decay_steps = self.total_steps.saturating_sub(self.warmup_steps).max(1);
            // Clamp: without it the cosine turns back upward after
            // total_steps and returns to base_lr at twice the decay length.
            let progress = ((step - self.warmup_steps) as f32 / decay_steps as f32).min(1.0);
            let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
            self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
        }
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
/// Linear warmup followed by polynomial decay from `base_lr` to `min_lr`.
///
/// After warmup the rate is
/// `min_lr + (base_lr - min_lr) * (1 - progress)^power`, where `progress` is
/// the clamped fraction of the decay phase completed.
#[derive(Debug)]
pub struct PolynomialScheduler {
    base_lr: f32,
    warmup_steps: usize,
    total_steps: usize,
    current_step: usize,
    min_lr: f32,
    power: f32,
}

impl PolynomialScheduler {
    /// Creates a polynomial-decay schedule.
    ///
    /// If `total_steps <= warmup_steps` the decay phase is treated as one step
    /// long: `base_lr` at `warmup_steps`, `min_lr` afterwards.
    pub fn new(
        base_lr: f32,
        warmup_steps: usize,
        total_steps: usize,
        min_lr: f32,
        power: f32,
    ) -> Self {
        Self {
            base_lr,
            warmup_steps,
            total_steps,
            current_step: 0,
            min_lr,
            power,
        }
    }
}

impl LRScheduler for PolynomialScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        if step < self.warmup_steps {
            self.base_lr * (step as f32) / (self.warmup_steps as f32)
        } else {
            // Guard the denominator against usize underflow and 0/0 = NaN.
            let decay_steps = self.total_steps.saturating_sub(self.warmup_steps).max(1);
            let progress = (step - self.warmup_steps) as f32 / decay_steps as f32;
            let decay_factor = (1.0 - progress.min(1.0)).powf(self.power);
            self.min_lr + (self.base_lr - self.min_lr) * decay_factor
        }
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
/// Ramps linearly from `0` up to `base_lr` over `warmup_steps`, then holds
/// `base_lr` for the rest of training.
#[derive(Debug)]
pub struct ConstantWithWarmupScheduler {
    base_lr: f32,
    warmup_steps: usize,
    current_step: usize,
}

impl ConstantWithWarmupScheduler {
    /// Builds the scheduler; `warmup_steps == 0` means the rate is flat from
    /// the very first step.
    pub fn new(base_lr: f32, warmup_steps: usize) -> Self {
        Self {
            current_step: 0,
            warmup_steps,
            base_lr,
        }
    }
}

impl LRScheduler for ConstantWithWarmupScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        // Past the warmup window the rate is flat.
        if step >= self.warmup_steps {
            return self.base_lr;
        }
        let ramp = step as f32;
        self.base_lr * ramp / (self.warmup_steps as f32)
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
/// Linear warmup followed by staircase exponential decay.
///
/// After warmup the rate is `base_lr * decay_rate^k`, where `k` is the number
/// of complete `decay_steps`-long periods elapsed since warmup ended.
#[derive(Debug)]
pub struct ExponentialScheduler {
    base_lr: f32,
    warmup_steps: usize,
    current_step: usize,
    decay_rate: f32,
    decay_steps: usize,
}

impl ExponentialScheduler {
    /// Creates the schedule. A `decay_steps` of `0` is treated as `1`
    /// (decay every step) rather than causing a divide-by-zero panic.
    pub fn new(base_lr: f32, warmup_steps: usize, decay_rate: f32, decay_steps: usize) -> Self {
        Self {
            base_lr,
            warmup_steps,
            current_step: 0,
            decay_rate,
            decay_steps,
        }
    }
}

impl LRScheduler for ExponentialScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        if step < self.warmup_steps {
            self.base_lr * (step as f32) / (self.warmup_steps as f32)
        } else {
            // Integer division by zero would panic; clamp the period to 1.
            let period = self.decay_steps.max(1);
            let decay_step = (step - self.warmup_steps) / period;
            self.base_lr * self.decay_rate.powf(decay_step as f32)
        }
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
/// Linear warmup followed by step decay: the rate is multiplied by `gamma`
/// after every `step_size` post-warmup steps.
#[derive(Debug)]
pub struct StepScheduler {
    base_lr: f32,
    warmup_steps: usize,
    current_step: usize,
    step_size: usize,
    gamma: f32,
}

impl StepScheduler {
    /// Creates the schedule. A `step_size` of `0` is treated as `1`
    /// (decay every step) rather than causing a divide-by-zero panic.
    pub fn new(base_lr: f32, warmup_steps: usize, step_size: usize, gamma: f32) -> Self {
        Self {
            base_lr,
            warmup_steps,
            current_step: 0,
            step_size,
            gamma,
        }
    }
}

impl LRScheduler for StepScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        if step < self.warmup_steps {
            self.base_lr * (step as f32) / (self.warmup_steps as f32)
        } else {
            // Integer division by zero would panic; clamp the interval to 1.
            let interval = self.step_size.max(1);
            let decay_step = (step - self.warmup_steps) / interval;
            self.base_lr * self.gamma.powf(decay_step as f32)
        }
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
/// One-cycle schedule: a cosine ramp from `final_lr` up to `max_lr` over the
/// first `pct_start` fraction of `total_steps`, then cosine annealing back
/// down to `final_lr`. Steps past `total_steps` keep the final value.
#[derive(Debug)]
pub struct OneCycleScheduler {
    max_lr: f32,
    final_lr: f32,
    total_steps: usize,
    pct_start: f32,
    current_step: usize,
}

impl OneCycleScheduler {
    /// Creates the schedule; `pct_start` is clamped into `[0, 1]`.
    ///
    /// `pct_start == 0` skips the ramp (the schedule starts at `max_lr`), and
    /// `total_steps == 0` is treated as a one-step schedule instead of
    /// producing NaN.
    pub fn new(max_lr: f32, total_steps: usize, pct_start: f32, final_lr: f32) -> Self {
        Self {
            max_lr,
            final_lr,
            total_steps,
            pct_start: pct_start.clamp(0.0, 1.0),
            current_step: 0,
        }
    }
}

impl LRScheduler for OneCycleScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        use std::f32::consts::PI;
        // A zero-length schedule would otherwise compute 0/0 = NaN.
        let total = self.total_steps.max(1);
        let pct = step.min(total) as f32 / total as f32;
        if self.pct_start > 0.0 && pct <= self.pct_start {
            // Ramp up: cosine from final_lr (pct = 0) to max_lr (pct = pct_start).
            let phase_pct = pct / self.pct_start;
            let cosine_term = 0.5 * (1.0 - (PI * phase_pct).cos());
            self.final_lr + (self.max_lr - self.final_lr) * cosine_term
        } else {
            // Anneal down. This branch implies pct > pct_start or
            // pct_start == 0, so 1 - pct_start is never zero here.
            let remaining_pct = (pct - self.pct_start) / (1.0 - self.pct_start);
            let cosine_term = 0.5 * (1.0 + (PI * remaining_pct).cos());
            self.final_lr + (self.max_lr - self.final_lr) * cosine_term
        }
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
/// Cosine annealing from `base_lr` to `min_lr` with periodic warm restarts.
///
/// The first cycle lasts `t_0` steps and each following cycle is `t_mult`
/// times longer (truncated to an integer, never below one step). At the start
/// of every cycle the rate jumps back to `base_lr`.
#[derive(Debug)]
pub struct CosineWithRestartsScheduler {
    base_lr: f32,
    min_lr: f32,
    t_0: usize,
    t_mult: f32,
    current_step: usize,
    next_restart: usize,
    current_t: usize,
}

impl CosineWithRestartsScheduler {
    /// Creates the schedule. `t_0 == 0` is treated as a one-step first cycle.
    pub fn new(base_lr: f32, min_lr: f32, t_0: usize, t_mult: f32) -> Self {
        Self {
            base_lr,
            min_lr,
            t_0,
            t_mult,
            current_step: 0,
            next_restart: t_0,
            current_t: t_0,
        }
    }

    /// Length of the cycle that follows a cycle of `length` steps.
    ///
    /// Clamped to at least 1: a zero-length cycle (from `t_0 == 0` or a
    /// `t_mult` below 1) would otherwise make `get_lr` loop forever.
    fn scaled_cycle_length(length: usize, t_mult: f32) -> usize {
        ((length as f32 * t_mult) as usize).max(1)
    }
}

impl LRScheduler for CosineWithRestartsScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        use std::f32::consts::PI;
        let mut step_in_cycle = step;
        let mut cycle_length = self.t_0.max(1);
        while step_in_cycle >= cycle_length {
            let next_length = Self::scaled_cycle_length(cycle_length, self.t_mult);
            if next_length == cycle_length {
                // Fixed point: every later cycle has this same length, so the
                // remaining whole cycles can be skipped in O(1) instead of
                // one loop iteration per cycle.
                step_in_cycle %= cycle_length;
                break;
            }
            step_in_cycle -= cycle_length;
            cycle_length = next_length;
        }
        let progress = step_in_cycle as f32 / cycle_length as f32;
        let cosine_decay = 0.5 * (1.0 + (PI * progress).cos());
        self.min_lr + (self.base_lr - self.min_lr) * cosine_decay
    }

    fn step(&mut self) {
        self.current_step += 1;
        if self.current_step >= self.next_restart {
            self.current_t = Self::scaled_cycle_length(self.current_t, self.t_mult);
            self.next_restart = self.next_restart.saturating_add(self.current_t);
        }
    }
}
/// Cyclical learning rate that oscillates between `base_lr` and `max_lr`.
///
/// Each cycle rises linearly over `step_size_up` steps and falls linearly over
/// `step_size_down` steps; [`CyclicalMode`] controls how the amplitude is
/// scaled over time.
#[derive(Debug)]
pub struct CyclicalScheduler {
    base_lr: f32,
    max_lr: f32,
    step_size_up: usize,
    step_size_down: usize,
    current_step: usize,
    mode: CyclicalMode,
}

/// Amplitude policy for [`CyclicalScheduler`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CyclicalMode {
    /// Same amplitude in every cycle.
    Triangular,
    /// Amplitude halves with each completed cycle.
    Triangular2,
    /// Amplitude multiplied by `gamma^step` (per step, not per cycle).
    ExpRange(f32),
}

impl CyclicalScheduler {
    /// Creates the schedule. A zero `step_size_up` starts each cycle at its
    /// peak. If both step sizes are zero the rate is `max_lr` on every step.
    pub fn new(
        base_lr: f32,
        max_lr: f32,
        step_size_up: usize,
        step_size_down: usize,
        mode: CyclicalMode,
    ) -> Self {
        Self {
            base_lr,
            max_lr,
            step_size_up,
            step_size_down,
            current_step: 0,
            mode,
        }
    }
}

impl LRScheduler for CyclicalScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        let up = self.step_size_up;
        let down = self.step_size_down;
        // A zero-length cycle would make the `/` and `%` below panic.
        let cycle_length = up.saturating_add(down).max(1);
        let cycle = step / cycle_length + 1;
        let x = (step % cycle_length) as f32;
        let amplitude = if x <= up as f32 {
            // With up == 0 the cycle starts at its peak (avoids 0/0 = NaN).
            if up == 0 {
                1.0
            } else {
                x / up as f32
            }
        } else {
            // Reaching here means x > up, so down >= 1.
            (down as f32 - (x - up as f32)) / down as f32
        };
        // `powi` takes an i32: saturate instead of letting `as` wrap to a
        // negative exponent, which would blow the scale up to infinity.
        let exponent = |n: usize| n.min(i32::MAX as usize) as i32;
        let scale_factor = match self.mode {
            CyclicalMode::Triangular => 1.0,
            CyclicalMode::Triangular2 => 1.0 / 2.0_f32.powi(exponent(cycle - 1)),
            CyclicalMode::ExpRange(gamma) => gamma.powi(exponent(step)),
        };
        self.base_lr + (self.max_lr - self.base_lr) * amplitude * scale_factor
    }

    fn step(&mut self) {
        self.current_step += 1;
    }
}
// Unit tests for the step-driven schedulers defined above; each test builds a
// scheduler and checks `get_lr` at hand-picked steps.
#[cfg(test)]
mod tests {
    use super::*;
    // Linear: 100-step warmup to 1e-3, then linear decay reaching 0 at step 1000.
    #[test]
    fn test_linear_scheduler() {
        let scheduler = LinearScheduler::new(1e-3, 100, 1000);
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);
        // Step 550 is halfway through the 900-step decay.
        assert_eq!(scheduler.get_lr(550), 5e-4);
        assert_eq!(scheduler.get_lr(1000), 0.0);
    }
    // Cosine: same warmup, then cosine decay down to min_lr = 1e-5 at step 1000.
    #[test]
    fn test_cosine_scheduler() {
        let scheduler = CosineScheduler::new(1e-3, 100, 1000, 1e-5);
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);
        let mid_lr = scheduler.get_lr(550);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);
        // Compared with a tolerance because cos(PI) is only approximately -1 in f32.
        let end_lr = scheduler.get_lr(1000);
        assert!((end_lr - 1e-5).abs() < 1e-6);
    }
    // Polynomial (power 2): only checks warmup endpoints and that the middle lies in range.
    #[test]
    fn test_polynomial_scheduler() {
        let scheduler = PolynomialScheduler::new(1e-3, 100, 1000, 1e-5, 2.0);
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);
        let mid_lr = scheduler.get_lr(550);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);
    }
    // Constant-with-warmup: ramps for 100 steps, then stays at base_lr indefinitely.
    #[test]
    fn test_constant_with_warmup_scheduler() {
        let scheduler = ConstantWithWarmupScheduler::new(1e-3, 100);
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(50), 5e-4);
        assert_eq!(scheduler.get_lr(100), 1e-3);
        assert_eq!(scheduler.get_lr(200), 1e-3);
        assert_eq!(scheduler.get_lr(1000), 1e-3);
    }
    // Exponential: after a 100-step warmup, multiply by 0.9 every 100 steps.
    #[test]
    fn test_exponential_scheduler() {
        let scheduler = ExponentialScheduler::new(1e-3, 100, 0.9, 100);
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);
        assert_eq!(scheduler.get_lr(200), 1e-3 * 0.9);
        assert_eq!(scheduler.get_lr(300), 1e-3 * 0.9 * 0.9);
    }
    // Step: after a 100-step warmup, halve every 200 steps
    // (step 250 is still in the first interval; 300 and 500 close the 1st and 2nd).
    #[test]
    fn test_step_scheduler() {
        let scheduler = StepScheduler::new(1e-3, 100, 200, 0.5);
        assert_eq!(scheduler.get_lr(0), 0.0);
        assert_eq!(scheduler.get_lr(100), 1e-3);
        assert_eq!(scheduler.get_lr(250), 1e-3);
        assert_eq!(scheduler.get_lr(300), 1e-3 * 0.5);
        assert_eq!(scheduler.get_lr(500), 1e-3 * 0.5 * 0.5);
    }
    // One-cycle: starts at final_lr. Step 150 is halfway up the ramp, which ends at
    // step 300 (so this is not the peak). Then anneals back to final_lr at step 1000.
    #[test]
    fn test_onecycle_scheduler() {
        let scheduler = OneCycleScheduler::new(1e-2, 1000, 0.3, 1e-5);
        assert_eq!(scheduler.get_lr(0), 1e-5);
        let peak_lr = scheduler.get_lr(150);
        assert!(peak_lr > 5e-3);
        let end_lr = scheduler.get_lr(1000);
        assert!((end_lr - 1e-5).abs() < 1e-6);
    }
    // Restarts: first cycle is 100 steps, so the rate is near min_lr at step 99
    // and jumps back up to base_lr at step 100.
    #[test]
    fn test_cosine_with_restarts_scheduler() {
        let scheduler = CosineWithRestartsScheduler::new(1e-3, 1e-5, 100, 2.0);
        assert!((scheduler.get_lr(0) - 1e-3).abs() < 1e-6);
        let mid_lr = scheduler.get_lr(50);
        assert!(mid_lr > 1e-5 && mid_lr < 1e-3);
        let near_end_lr = scheduler.get_lr(99);
        assert!(near_end_lr < 2e-4);
        let restart_lr = scheduler.get_lr(100);
        assert!(restart_lr > 5e-4);
    }
    // Triangular cyclical: 100-step cycles, peaking at max_lr mid-cycle (steps 50, 150).
    #[test]
    fn test_cyclical_scheduler() {
        let scheduler = CyclicalScheduler::new(1e-4, 1e-3, 50, 50, CyclicalMode::Triangular);
        assert!((scheduler.get_lr(0) - 1e-4).abs() < 1e-6);
        assert!((scheduler.get_lr(50) - 1e-3).abs() < 1e-6);
        assert!((scheduler.get_lr(100) - 1e-4).abs() < 1e-6);
        assert!((scheduler.get_lr(150) - 1e-3).abs() < 1e-6);
    }
}
/// Reduce-on-plateau scheduler driven by an externally supplied metric.
///
/// Each call to [`AdaptiveScheduler::step_with_metric`] compares the metric
/// with the best value seen so far, using a relative `threshold`. After
/// `patience` consecutive calls without improvement the learning rate is
/// multiplied by `factor`, but never drops below `min_lr`.
#[derive(Debug, Clone)]
pub struct AdaptiveScheduler {
    current_lr: f32,
    factor: f32,
    patience: usize,
    threshold: f32,
    min_lr: f32,
    // "min" (lower metric is better) or "max" (higher metric is better).
    mode: String,
    epochs_since_improvement: usize,
    best_metric: Option<f32>,
    current_step: usize,
}

impl AdaptiveScheduler {
    /// Creates a scheduler starting at `initial_lr`.
    ///
    /// # Panics
    ///
    /// Panics if `factor` is not strictly between 0 and 1, `patience` is 0,
    /// `threshold` or `min_lr` is negative, or `mode` is neither `"min"` nor
    /// `"max"`.
    pub fn new(
        initial_lr: f32,
        factor: f32,
        patience: usize,
        threshold: f32,
        min_lr: f32,
        mode: &str,
    ) -> Self {
        assert!(
            factor > 0.0 && factor < 1.0,
            "Factor must be between 0 and 1"
        );
        assert!(patience > 0, "Patience must be positive");
        assert!(threshold >= 0.0, "Threshold must be non-negative");
        assert!(min_lr >= 0.0, "Min LR must be non-negative");
        assert!(mode == "min" || mode == "max", "Mode must be min or max");
        Self {
            current_lr: initial_lr,
            factor,
            patience,
            threshold,
            min_lr,
            mode: mode.to_string(),
            epochs_since_improvement: 0,
            best_metric: None,
            current_step: 0,
        }
    }

    /// Records `metric` and reduces the learning rate if patience has run out.
    ///
    /// Returns the (possibly updated) learning rate and whether this call
    /// reduced it. Non-finite metrics (NaN / ±inf) never count as an
    /// improvement and are never stored as the best value.
    pub fn step_with_metric(&mut self, metric: f32) -> (f32, bool) {
        self.current_step += 1;
        if self.record_metric(metric) {
            self.epochs_since_improvement = 0;
            return (self.current_lr, false);
        }
        self.epochs_since_improvement += 1;
        if self.epochs_since_improvement < self.patience {
            return (self.current_lr, false);
        }
        let new_lr = (self.current_lr * self.factor).max(self.min_lr);
        if new_lr < self.current_lr {
            self.current_lr = new_lr;
            // Start a fresh patience window after every reduction.
            self.epochs_since_improvement = 0;
            return (self.current_lr, true);
        }
        // No reduction possible: the rate is already at or below min_lr.
        (self.current_lr, false)
    }

    /// Compares `metric` with the best value so far and updates the best on
    /// improvement. Returns whether `metric` counts as an improvement.
    fn record_metric(&mut self, metric: f32) -> bool {
        // A non-finite value stored as the best would make every later
        // relative comparison NaN (and therefore false), so the scheduler
        // would never see an improvement again.
        if !metric.is_finite() {
            return false;
        }
        let improved = match self.best_metric {
            None => true,
            Some(best) => {
                let gain = if self.mode == "min" {
                    best - metric
                } else {
                    metric - best
                };
                // Relative improvement; the 1e-8 floor avoids dividing by
                // zero when best == 0.
                gain / best.abs().max(1e-8) > self.threshold
            },
        };
        if improved {
            self.best_metric = Some(metric);
        }
        improved
    }

    /// Current learning rate.
    pub fn get_current_lr(&self) -> f32 {
        self.current_lr
    }

    /// Best finite metric recorded so far, if any.
    pub fn get_best_metric(&self) -> Option<f32> {
        self.best_metric
    }

    /// Number of consecutive non-improving calls since the last improvement
    /// or reduction.
    pub fn get_epochs_since_improvement(&self) -> usize {
        self.epochs_since_improvement
    }

    /// Forgets the best metric and patience/step counters; the current
    /// learning rate is kept.
    pub fn reset(&mut self) {
        self.epochs_since_improvement = 0;
        self.best_metric = None;
        self.current_step = 0;
    }

    /// Overrides the current learning rate (no validation against `min_lr`).
    pub fn set_lr(&mut self, lr: f32) {
        self.current_lr = lr;
    }
}

impl LRScheduler for AdaptiveScheduler {
    /// Ignores `step` and returns the rate maintained by `step_with_metric`.
    fn get_lr(&self, _step: usize) -> f32 {
        self.current_lr
    }

    /// No-op: this scheduler only changes through `step_with_metric`.
    fn step(&mut self) {}
}
/// Chains several schedulers, each owning a contiguous range of global steps.
///
/// `step_boundaries[i]` is the exclusive global step at which scheduler `i`
/// hands over to scheduler `i + 1`. Each child is queried with a local step
/// counted from the start of its own range. Boundaries are expected to be
/// ascending (not validated). Steps at or beyond the last boundary stay with
/// the final scheduler, so the last boundary never changes which scheduler is
/// used.
pub struct CompositeScheduler {
    schedulers: Vec<Box<dyn LRScheduler>>,
    step_boundaries: Vec<usize>,
    current_step: usize,
    // Not read anywhere yet; kept so the struct's fields stay unchanged.
    #[allow(dead_code)]
    global_step_offset: usize,
}

impl CompositeScheduler {
    /// Builds a composite from parallel lists of schedulers and boundaries.
    ///
    /// # Panics
    ///
    /// Panics if the two lists differ in length or are empty.
    pub fn new(schedulers: Vec<Box<dyn LRScheduler>>, step_boundaries: Vec<usize>) -> Self {
        assert_eq!(
            schedulers.len(),
            step_boundaries.len(),
            "Number of schedulers must match number of boundaries"
        );
        assert!(
            !schedulers.is_empty(),
            "Must provide at least one scheduler"
        );
        Self {
            schedulers,
            step_boundaries,
            current_step: 0,
            global_step_offset: 0,
        }
    }

    /// Index of the first scheduler whose boundary lies beyond `step`,
    /// falling back to the last scheduler.
    fn get_active_scheduler_index(&self, step: usize) -> usize {
        self.step_boundaries
            .iter()
            .position(|&boundary| step < boundary)
            .unwrap_or(self.schedulers.len() - 1)
    }

    /// Converts a global step into the local step of `scheduler_index`.
    fn get_local_step(&self, global_step: usize, scheduler_index: usize) -> usize {
        match scheduler_index.checked_sub(1) {
            None => global_step,
            // Cannot underflow for indices produced by
            // `get_active_scheduler_index`; saturate defensively anyway.
            Some(previous) => global_step.saturating_sub(self.step_boundaries[previous]),
        }
    }
}

impl LRScheduler for CompositeScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        let scheduler_idx = self.get_active_scheduler_index(step);
        let local_step = self.get_local_step(step, scheduler_idx);
        self.schedulers[scheduler_idx].get_lr(local_step)
    }

    /// Forwards the step to the child that was active for the step just
    /// completed, so stateful children advance in lockstep with their local
    /// step, then advances the global counter.
    fn step(&mut self) {
        let finished_idx = self.get_active_scheduler_index(self.current_step);
        self.schedulers[finished_idx].step();
        self.current_step += 1;
    }
}
/// Runs a sequence of named phases, each driven by its own scheduler.
///
/// The active phase follows an internal counter advanced by
/// [`LRScheduler::step`]; the `step` passed to `get_lr` does not select the
/// phase. It is converted to a phase-local step, and the phase scheduler's
/// result is scaled by that phase's `lr_multiplier`. Once every phase has run,
/// the learning rate is `0.0`.
pub struct PhaseBasedScheduler {
    phases: Vec<Phase>,
    current_phase: usize,
    current_step: usize,
    phase_start_step: usize,
}

/// One stage of a [`PhaseBasedScheduler`].
pub struct Phase {
    /// Label returned by `PhaseBasedScheduler::get_current_phase`.
    pub name: String,
    /// Schedule used while this phase is active (queried with phase-local steps).
    pub scheduler: Box<dyn LRScheduler>,
    /// Number of steps the phase lasts; a phase of length `0` is skipped.
    pub duration_steps: usize,
    /// Factor applied to `scheduler`'s learning rate.
    pub lr_multiplier: f32,
}

impl PhaseBasedScheduler {
    /// Creates the scheduler positioned at the first non-empty phase.
    ///
    /// # Panics
    ///
    /// Panics if `phases` is empty.
    pub fn new(phases: Vec<Phase>) -> Self {
        assert!(!phases.is_empty(), "Must provide at least one phase");
        let mut scheduler = Self {
            phases,
            current_phase: 0,
            current_step: 0,
            phase_start_step: 0,
        };
        // Skip leading zero-length phases so step 0 uses a phase that
        // actually covers it.
        scheduler.update_phase(0);
        scheduler
    }

    /// Name of the active phase, or `"completed"` once every phase has run.
    pub fn get_current_phase(&self) -> &str {
        match self.phases.get(self.current_phase) {
            Some(phase) => &phase.name,
            None => "completed",
        }
    }

    /// Index of the active phase (equals the number of phases when complete).
    pub fn get_current_phase_index(&self) -> usize {
        self.current_phase
    }

    /// Whether every phase has run to completion.
    pub fn is_complete(&self) -> bool {
        self.current_phase >= self.phases.len()
    }

    /// Advances past every phase that ends at or before `step`.
    fn update_phase(&mut self, step: usize) {
        while let Some(phase) = self.phases.get(self.current_phase) {
            let phase_end = self.phase_start_step.saturating_add(phase.duration_steps);
            if step < phase_end {
                break;
            }
            self.current_phase += 1;
            self.phase_start_step = phase_end;
        }
    }
}

impl LRScheduler for PhaseBasedScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        match self.phases.get(self.current_phase) {
            // All phases finished.
            None => 0.0,
            Some(phase) => {
                // Saturate so a step earlier than the current phase's start
                // maps to the phase's first step instead of underflowing.
                let phase_step = step.saturating_sub(self.phase_start_step);
                phase.scheduler.get_lr(phase_step) * phase.lr_multiplier
            },
        }
    }

    /// Steps the active phase's scheduler, then advances the global counter
    /// and moves to the next phase when the current one is exhausted.
    fn step(&mut self) {
        if let Some(phase) = self.phases.get_mut(self.current_phase) {
            phase.scheduler.step();
        }
        self.current_step += 1;
        self.update_phase(self.current_step);
    }
}
/// Switches once, permanently, from a primary to a fallback scheduler when
/// its [`SwitchCondition`] is met.
///
/// Metric-based conditions look at a sliding window of the most recent
/// `window_size` values passed to [`DynamicScheduler::update_metric`]. The
/// condition is checked after every metric update and after every
/// [`LRScheduler::step`].
pub struct DynamicScheduler {
    primary_scheduler: Box<dyn LRScheduler>,
    fallback_scheduler: Box<dyn LRScheduler>,
    // 0 while the primary scheduler is active, 1 after switching.
    current_scheduler: usize,
    switch_condition: SwitchCondition,
    metrics_window: Vec<f32>,
    window_size: usize,
    current_step: usize,
}

/// Trigger for [`DynamicScheduler`]'s switch to its fallback scheduler.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SwitchCondition {
    /// Switch when `mean(last n metrics) >= mean(older metrics in window) * 0.995`,
    /// i.e. (for positive losses) the recent mean has not dropped by more
    /// than 0.5%. Can only fire when `window_size > n`.
    LossPlateauSteps(usize),
    /// Not implemented: never triggers a switch.
    GradientNormThreshold(f32),
    /// Switch once the step counter reaches this value.
    StepThreshold(usize),
    /// Switch when the latest metric exceeds the previous one times this factor.
    LossIncreaseFactor(f32),
}

impl DynamicScheduler {
    /// Creates a scheduler that starts on `primary_scheduler`.
    pub fn new(
        primary_scheduler: Box<dyn LRScheduler>,
        fallback_scheduler: Box<dyn LRScheduler>,
        switch_condition: SwitchCondition,
        window_size: usize,
    ) -> Self {
        Self {
            primary_scheduler,
            fallback_scheduler,
            current_scheduler: 0,
            switch_condition,
            metrics_window: Vec::with_capacity(window_size),
            window_size,
            current_step: 0,
        }
    }

    /// Pushes `metric` into the window, dropping the oldest value beyond
    /// `window_size`, then switches to the fallback if the condition holds.
    pub fn update_metric(&mut self, metric: f32) {
        self.metrics_window.push(metric);
        if self.metrics_window.len() > self.window_size {
            self.metrics_window.remove(0);
        }
        self.switch_if_needed();
    }

    /// Performs the one-way primary -> fallback switch when the condition holds.
    fn switch_if_needed(&mut self) {
        if self.current_scheduler == 0 && self.should_switch() {
            self.current_scheduler = 1;
        }
    }

    fn should_switch(&self) -> bool {
        match &self.switch_condition {
            SwitchCondition::LossPlateauSteps(steps) => {
                let steps = *steps;
                let len = self.metrics_window.len();
                if len < steps {
                    return false;
                }
                let older_len = len - steps;
                // Both means need at least one sample; with zero samples
                // they would be 0/0 = NaN, which compares false anyway.
                if steps == 0 || older_len == 0 {
                    return false;
                }
                let recent_avg =
                    self.metrics_window.iter().rev().take(steps).sum::<f32>() / steps as f32;
                let older_avg =
                    self.metrics_window.iter().take(older_len).sum::<f32>() / older_len as f32;
                // Plateau if the recent mean is not at least 0.5% below the older mean.
                recent_avg >= older_avg * 0.995
            },
            SwitchCondition::StepThreshold(step) => self.current_step >= *step,
            SwitchCondition::LossIncreaseFactor(factor) => {
                let len = self.metrics_window.len();
                if len < 2 {
                    return false;
                }
                let latest = self.metrics_window[len - 1];
                let previous = self.metrics_window[len - 2];
                latest > previous * factor
            },
            // No gradient-norm input exists yet, so this never fires.
            SwitchCondition::GradientNormThreshold(_) => false,
        }
    }

    /// `"primary"` or `"fallback"`, depending on which scheduler is active.
    pub fn get_active_scheduler(&self) -> &str {
        if self.current_scheduler == 0 {
            "primary"
        } else {
            "fallback"
        }
    }
}

impl LRScheduler for DynamicScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        if self.current_scheduler == 0 {
            self.primary_scheduler.get_lr(step)
        } else {
            self.fallback_scheduler.get_lr(step)
        }
    }

    fn step(&mut self) {
        self.current_step += 1;
        if self.current_scheduler == 0 {
            self.primary_scheduler.step();
        } else {
            self.fallback_scheduler.step();
        }
        // Re-check so StepThreshold can fire without any metric updates.
        // Metric-based conditions give the same answer as at the last
        // update_metric call, because the window is unchanged.
        self.switch_if_needed();
    }
}
/// Wraps a preset scheduler chosen from a [`TaskType`].
pub struct TaskSpecificScheduler {
    scheduler: Box<dyn LRScheduler>,
    task_type: TaskType,
    current_step: usize,
}

/// Training regimes that have a built-in schedule preset.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskType {
    LanguageModelPretraining,
    FineTuning,
    ComputerVision,
    ReinforcementLearning,
    GANTraining,
}

impl TaskSpecificScheduler {
    /// Builds the preset schedule for `task_type`:
    ///
    /// * `LanguageModelPretraining`: cosine decay, 6% warmup, floor at
    ///   `0.1 * base_lr`.
    /// * `FineTuning`: linear warmup (10%) and decay, peaking at `0.1 * base_lr`.
    /// * `ComputerVision`: 5% warmup, then ×0.1 every third of `total_steps`.
    /// * `ReinforcementLearning`: reduce-on-plateau (factor 0.5, patience 10,
    ///   threshold 1e-4, floor `base_lr * 1e-3`, `"max"` mode).
    /// * `GANTraining`: 2% warmup, then constant.
    ///
    /// NOTE(review): the RL preset only changes its rate through
    /// `AdaptiveScheduler::step_with_metric`, which cannot be reached via the
    /// boxed `LRScheduler` stored here. Through this wrapper its rate
    /// therefore stays at `base_lr`.
    pub fn new(task_type: TaskType, base_lr: f32, total_steps: usize) -> Self {
        // Fraction of the total run, truncated to whole steps.
        let fraction_of_total = |fraction: f32| (total_steps as f32 * fraction) as usize;
        let scheduler: Box<dyn LRScheduler> = match task_type {
            TaskType::LanguageModelPretraining => Box::new(CosineScheduler::new(
                base_lr,
                fraction_of_total(0.06),
                total_steps,
                base_lr * 0.1,
            )),
            TaskType::FineTuning => Box::new(LinearScheduler::new(
                base_lr * 0.1,
                fraction_of_total(0.1),
                total_steps,
            )),
            TaskType::ComputerVision => Box::new(StepScheduler::new(
                base_lr,
                fraction_of_total(0.05),
                // StepScheduler divides by its step size; total_steps < 3
                // would otherwise give 0 and panic at the first decay.
                (total_steps / 3).max(1),
                0.1,
            )),
            TaskType::ReinforcementLearning => Box::new(AdaptiveScheduler::new(
                base_lr,
                0.5,
                10,
                1e-4,
                base_lr * 1e-3,
                "max",
            )),
            TaskType::GANTraining => Box::new(ConstantWithWarmupScheduler::new(
                base_lr,
                fraction_of_total(0.02),
            )),
        };
        Self {
            scheduler,
            task_type,
            current_step: 0,
        }
    }

    /// Task type this scheduler was built for.
    pub fn get_task_type(&self) -> &TaskType {
        &self.task_type
    }
}

impl LRScheduler for TaskSpecificScheduler {
    fn get_lr(&self, step: usize) -> f32 {
        self.scheduler.get_lr(step)
    }

    fn step(&mut self) {
        self.current_step += 1;
        self.scheduler.step();
    }
}