use crate::Optimizer;
/// Common interface for learning-rate schedulers that adjust an optimizer's
/// learning rate over the course of training.
pub trait LrScheduler {
/// Advances the schedule by one step/epoch and pushes the new LR into `optimizer`.
fn step(&mut self, optimizer: &mut dyn Optimizer);
/// Returns the scheduler's current learning rate.
fn get_lr(&self) -> f64;
/// Serializes scheduler state as a flat `name -> f64` map (for checkpointing).
fn state_dict(&self) -> std::collections::HashMap<String, f64>;
/// Restores scheduler state from a map produced by `state_dict`.
fn load_state_dict(
&mut self,
state: &std::collections::HashMap<String, f64>,
) -> crate::TrainResult<()>;
}
/// Multiplies the learning rate by `gamma` every `step_size` epochs
/// (analogous to PyTorch's `StepLR`).
#[derive(Debug, Clone)]
pub struct StepLrScheduler {
/// Learning rate at epoch 0.
pub initial_lr: f64,
/// Number of epochs between successive decays.
pub step_size: usize,
/// Multiplicative decay factor applied at each decay point.
pub gamma: f64,
// Epochs elapsed so far (incremented by `step`).
current_epoch: usize,
// Most recently computed learning rate.
current_lr: f64,
}
impl StepLrScheduler {
pub fn new(initial_lr: f64, step_size: usize, gamma: f64) -> Self {
Self {
initial_lr,
step_size,
gamma,
current_epoch: 0,
current_lr: initial_lr,
}
}
}
impl LrScheduler for StepLrScheduler {
    /// Advances one epoch; every `step_size` epochs the LR is multiplied by
    /// `gamma` and pushed into the optimizer.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;
        if self.current_epoch.is_multiple_of(self.step_size) {
            self.current_lr *= self.gamma;
            optimizer.set_lr(self.current_lr);
        }
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("step_size".to_string(), self.step_size as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }
    /// Restores mutable state (`current_lr`, `current_epoch`); hyperparameters
    /// keep their constructed values. Missing keys are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}
/// Decays the learning rate exponentially: `lr = initial_lr * gamma^epoch`.
#[derive(Debug, Clone)]
pub struct ExponentialLrScheduler {
/// Learning rate at epoch 0.
pub initial_lr: f64,
/// Per-epoch multiplicative decay factor.
pub gamma: f64,
// Epochs elapsed so far (incremented by `step`).
current_epoch: usize,
// Most recently computed learning rate.
current_lr: f64,
}
impl ExponentialLrScheduler {
pub fn new(initial_lr: f64, gamma: f64) -> Self {
Self {
initial_lr,
gamma,
current_epoch: 0,
current_lr: initial_lr,
}
}
}
impl LrScheduler for ExponentialLrScheduler {
    /// Advances one epoch and sets `lr = initial_lr * gamma^epoch`.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;
        self.current_lr = self.initial_lr * self.gamma.powi(self.current_epoch as i32);
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }
    /// Restores mutable state (`current_lr`, `current_epoch`); missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}
/// Anneals the learning rate from `initial_lr` to `min_lr` along a half
/// cosine over `t_max` epochs.
#[derive(Debug, Clone)]
pub struct CosineAnnealingLrScheduler {
/// Learning rate at epoch 0 (maximum of the cosine curve).
pub initial_lr: f64,
/// Floor of the schedule, reached at epoch `t_max`.
pub min_lr: f64,
/// Number of epochs over which to anneal.
pub t_max: usize,
// Epochs elapsed so far (incremented by `step`).
current_epoch: usize,
// Most recently computed learning rate.
current_lr: f64,
}
impl CosineAnnealingLrScheduler {
pub fn new(initial_lr: f64, min_lr: f64, t_max: usize) -> Self {
Self {
initial_lr,
min_lr,
t_max,
current_epoch: 0,
current_lr: initial_lr,
}
}
}
impl LrScheduler for CosineAnnealingLrScheduler {
    /// Advances one epoch and applies the cosine interpolation between
    /// `initial_lr` and `min_lr`. Past `t_max` the cosine continues, so the
    /// LR rises again (callers typically stop at `t_max` or use SGDR).
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;
        let progress = (self.current_epoch as f64) / (self.t_max as f64);
        let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
        self.current_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay;
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("t_max".to_string(), self.t_max as f64);
        state
    }
    /// Restores mutable state (`current_lr`, `current_epoch`); missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        Ok(())
    }
}
/// Linearly ramps the learning rate from 0 to `target_lr` over
/// `warmup_steps` steps, then holds it constant.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct WarmupScheduler {
/// Learning rate reached after warmup completes.
pub target_lr: f64,
/// Number of steps in the linear ramp.
pub warmup_steps: usize,
// Steps taken so far (incremented by `step`).
current_step: usize,
// Most recently computed learning rate (starts at 0.0).
current_lr: f64,
}
impl WarmupScheduler {
#[allow(dead_code)]
pub fn new(target_lr: f64, warmup_steps: usize) -> Self {
Self {
target_lr,
warmup_steps,
current_step: 0,
current_lr: 0.0,
}
}
}
impl LrScheduler for WarmupScheduler {
    /// Advances one step; LR grows linearly until `warmup_steps`, after
    /// which it stays at `target_lr`.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        if self.current_step < self.warmup_steps {
            self.current_lr =
                self.target_lr * (self.current_step as f64) / (self.warmup_steps as f64);
        } else {
            self.current_lr = self.target_lr;
        }
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("target_lr".to_string(), self.target_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state
    }
    /// Restores mutable state (`current_lr`, `current_step`); missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}
/// One-cycle policy: LR rises linearly from `initial_lr` to `max_lr` during
/// the first `pct_start` fraction of `total_steps`, then cosine-anneals down
/// to `min_lr` for the remainder.
#[derive(Debug, Clone)]
pub struct OneCycleLrScheduler {
/// Learning rate at step 0.
pub initial_lr: f64,
/// Peak learning rate reached at the end of the ramp-up phase.
pub max_lr: f64,
/// Final learning rate at the end of the cycle.
pub min_lr: f64,
/// Total number of steps in the cycle.
pub total_steps: usize,
/// Fraction of the cycle (0..1) spent ramping up.
pub pct_start: f64,
// Steps taken so far (incremented by `step`).
current_step: usize,
// Most recently computed learning rate.
current_lr: f64,
}
impl OneCycleLrScheduler {
pub fn new(
initial_lr: f64,
max_lr: f64,
min_lr: f64,
total_steps: usize,
pct_start: f64,
) -> Self {
Self {
initial_lr,
max_lr,
min_lr,
total_steps,
pct_start,
current_step: 0,
current_lr: initial_lr,
}
}
}
impl LrScheduler for OneCycleLrScheduler {
    /// Advances one step through the cycle. Steps beyond `total_steps` are
    /// clamped, so the LR stays at `min_lr` once the cycle is complete.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        let step_num = self.current_step.min(self.total_steps);
        let pct = step_num as f64 / self.total_steps as f64;
        if pct < self.pct_start {
            // Warmup phase: linear interpolation initial_lr -> max_lr.
            let phase_pct = pct / self.pct_start;
            self.current_lr = self.initial_lr + (self.max_lr - self.initial_lr) * phase_pct;
        } else {
            // Annealing phase: cosine decay max_lr -> min_lr.
            let phase_pct = (pct - self.pct_start) / (1.0 - self.pct_start);
            let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * phase_pct).cos());
            self.current_lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay;
        }
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state.insert("pct_start".to_string(), self.pct_start);
        state
    }
    /// Restores mutable state (`current_lr`, `current_step`); missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}
/// Polynomial decay from `initial_lr` to `final_lr` over `decay_steps`:
/// `lr = (initial - final) * (1 - step/decay_steps)^power + final`.
#[derive(Debug, Clone)]
pub struct PolynomialDecayLrScheduler {
/// Learning rate at step 0.
pub initial_lr: f64,
/// Learning rate reached at (and held after) `decay_steps`.
pub final_lr: f64,
/// Exponent of the decay polynomial (1.0 = linear).
pub power: f64,
/// Number of steps over which to decay.
pub decay_steps: usize,
// Steps taken so far (incremented by `step`).
current_step: usize,
// Most recently computed learning rate.
current_lr: f64,
}
impl PolynomialDecayLrScheduler {
pub fn new(initial_lr: f64, final_lr: f64, power: f64, decay_steps: usize) -> Self {
Self {
initial_lr,
final_lr,
power,
decay_steps,
current_step: 0,
current_lr: initial_lr,
}
}
}
impl LrScheduler for PolynomialDecayLrScheduler {
    /// Advances one step; steps beyond `decay_steps` are clamped so the LR
    /// settles at `final_lr`.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        let step_num = self.current_step.min(self.decay_steps);
        let decay_factor = (1.0 - (step_num as f64 / self.decay_steps as f64)).powf(self.power);
        self.current_lr = (self.initial_lr - self.final_lr) * decay_factor + self.final_lr;
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("final_lr".to_string(), self.final_lr);
        state.insert("power".to_string(), self.power);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("decay_steps".to_string(), self.decay_steps as f64);
        state
    }
    /// Restores mutable state (`current_lr`, `current_step`); missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}
/// Amplitude policy for [`CyclicLrScheduler`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CyclicLrMode {
/// Constant amplitude every cycle.
Triangular,
/// Amplitude halves each cycle (1 / 2^cycle).
Triangular2,
/// Amplitude shrinks by `gamma^step` each step.
ExpRange,
}
/// Cyclical learning rate that oscillates between `base_lr` and `max_lr`
/// with half-cycle length `step_size` (Smith's CLR policy).
#[derive(Debug, Clone)]
pub struct CyclicLrScheduler {
/// Lower bound of the cycle.
pub base_lr: f64,
/// Upper bound of the cycle (before amplitude scaling).
pub max_lr: f64,
/// Number of steps in each half cycle (up or down).
pub step_size: usize,
/// Amplitude policy (triangular / triangular2 / exp_range).
pub mode: CyclicLrMode,
/// Per-step amplitude decay; only used by `ExpRange` mode.
pub gamma: f64,
// Steps taken so far (incremented by `step`).
current_step: usize,
// Most recently computed learning rate.
current_lr: f64,
// Index of the current full cycle (updated by `step`).
cycle: usize,
}
impl CyclicLrScheduler {
pub fn new(base_lr: f64, max_lr: f64, step_size: usize, mode: CyclicLrMode) -> Self {
Self {
base_lr,
max_lr,
step_size,
mode,
gamma: 0.99994,
current_step: 0,
current_lr: base_lr,
cycle: 0,
}
}
pub fn new_exp_range(base_lr: f64, max_lr: f64, step_size: usize, gamma: f64) -> Self {
Self {
base_lr,
max_lr,
step_size,
mode: CyclicLrMode::ExpRange,
gamma,
current_step: 0,
current_lr: base_lr,
cycle: 0,
}
}
}
impl LrScheduler for CyclicLrScheduler {
    /// Advances one step along the triangular wave. `x` is the position
    /// within the current cycle in [0, 2): rising on [0, 1], falling after.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        let cycle = (self.current_step - 1) / (2 * self.step_size);
        let x = ((self.current_step - 1) as f64 / self.step_size as f64).abs() % 2.0;
        // Amplitude scaling depends on the mode.
        let scale_fn = match self.mode {
            CyclicLrMode::Triangular => 1.0,
            CyclicLrMode::Triangular2 => 1.0 / 2.0_f64.powi(cycle as i32),
            CyclicLrMode::ExpRange => self.gamma.powi(self.current_step as i32),
        };
        if x <= 1.0 {
            self.current_lr = self.base_lr + (self.max_lr - self.base_lr) * x * scale_fn;
        } else {
            self.current_lr = self.base_lr + (self.max_lr - self.base_lr) * (2.0 - x) * scale_fn;
        }
        self.cycle = cycle;
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    /// Note: `mode` is not serialized; it is fixed at construction.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("base_lr".to_string(), self.base_lr);
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("step_size".to_string(), self.step_size as f64);
        state.insert("cycle".to_string(), self.cycle as f64);
        state.insert("gamma".to_string(), self.gamma);
        state
    }
    /// Restores mutable state (`current_lr`, `current_step`, `cycle`);
    /// missing keys are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&cycle) = state.get("cycle") {
            self.cycle = cycle as usize;
        }
        Ok(())
    }
}
/// Linear warmup from 0 to `target_lr` over `warmup_steps`, followed by
/// cosine annealing down to `min_lr` at `total_steps`.
#[derive(Debug, Clone)]
pub struct WarmupCosineLrScheduler {
/// Peak learning rate, reached at the end of warmup.
pub target_lr: f64,
/// Floor of the cosine phase, reached at `total_steps`.
pub min_lr: f64,
/// Number of steps in the linear warmup phase.
pub warmup_steps: usize,
/// Total number of steps (warmup + annealing).
pub total_steps: usize,
// Steps taken so far (incremented by `step`).
current_step: usize,
// Most recently computed learning rate (starts at 0.0).
current_lr: f64,
}
impl WarmupCosineLrScheduler {
pub fn new(target_lr: f64, min_lr: f64, warmup_steps: usize, total_steps: usize) -> Self {
Self {
target_lr,
min_lr,
warmup_steps,
total_steps,
current_step: 0,
current_lr: 0.0,
}
}
}
impl LrScheduler for WarmupCosineLrScheduler {
    /// Advances one step: linear ramp while `current_step <= warmup_steps`,
    /// cosine annealing toward `min_lr` afterwards.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        if self.current_step <= self.warmup_steps {
            self.current_lr =
                self.target_lr * (self.current_step as f64 / self.warmup_steps as f64);
        } else {
            // Guard against total_steps <= warmup_steps, which would divide
            // by zero (or underflow) and poison the LR with NaN/inf.
            let annealing_steps = self.total_steps.saturating_sub(self.warmup_steps).max(1);
            let progress =
                (self.current_step - self.warmup_steps) as f64 / annealing_steps as f64;
            let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
            self.current_lr = self.min_lr + (self.target_lr - self.min_lr) * cosine_decay;
        }
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("target_lr".to_string(), self.target_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state
    }
    /// Restores mutable state (`current_lr`, `current_step`); missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        Ok(())
    }
}
/// Transformer "Noam" schedule (Vaswani et al.): LR grows linearly for
/// `warmup_steps`, then decays with the inverse square root of the step.
#[derive(Debug, Clone)]
pub struct NoamScheduler {
// Model dimension, stored as f64 for the d_model^-0.5 factor.
model_dim: f64,
// Number of linear warmup steps.
warmup_steps: usize,
// Overall multiplier applied to the schedule.
scale_factor: f64,
// Steps taken so far (incremented by `step`).
current_step: usize,
// Most recently computed learning rate.
current_lr: f64,
}
impl NoamScheduler {
pub fn new(model_dim: usize, warmup_steps: usize, scale_factor: f64) -> Self {
let model_dim_f64 = model_dim as f64;
let current_lr = scale_factor * model_dim_f64.powf(-0.5);
Self {
model_dim: model_dim_f64,
warmup_steps,
scale_factor,
current_step: 0,
current_lr,
}
}
fn compute_lr(&self) -> f64 {
let step = (self.current_step + 1) as f64; let warmup = self.warmup_steps as f64;
self.scale_factor
* self.model_dim.powf(-0.5)
* step.powf(-0.5).min(step * warmup.powf(-1.5))
}
}
impl LrScheduler for NoamScheduler {
    /// Advances one step and recomputes the Noam LR.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_step += 1;
        self.current_lr = self.compute_lr();
        optimizer.set_lr(self.current_lr);
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("model_dim".to_string(), self.model_dim);
        state.insert("warmup_steps".to_string(), self.warmup_steps as f64);
        state.insert("scale_factor".to_string(), self.scale_factor);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("current_lr".to_string(), self.current_lr);
        state
    }
    /// Restores mutable state (`current_step`, `current_lr`); missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        Ok(())
    }
}
/// Multiplies the learning rate by `gamma` at each epoch listed in
/// `milestones` (analogous to PyTorch's `MultiStepLR`).
#[derive(Debug, Clone)]
pub struct MultiStepLrScheduler {
/// Learning rate at epoch 0.
pub initial_lr: f64,
/// Epochs at which to decay; sorted ascending by the constructor.
pub milestones: Vec<usize>,
/// Multiplicative decay factor applied at each milestone.
pub gamma: f64,
// Epochs elapsed so far (incremented by `step`).
current_epoch: usize,
// Most recently computed learning rate.
current_lr: f64,
// Index of the next milestone to apply.
next_milestone_idx: usize,
}
impl MultiStepLrScheduler {
pub fn new(initial_lr: f64, mut milestones: Vec<usize>, gamma: f64) -> Self {
milestones.sort_unstable();
Self {
initial_lr,
milestones,
gamma,
current_epoch: 0,
current_lr: initial_lr,
next_milestone_idx: 0,
}
}
}
impl LrScheduler for MultiStepLrScheduler {
    /// Advances one epoch, applying `gamma` once per milestone reached.
    /// Using a loop (instead of a single `if`) means duplicate milestones
    /// compound in the same epoch, matching PyTorch's `MultiStepLR`.
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        self.current_epoch += 1;
        let mut decayed = false;
        while self.next_milestone_idx < self.milestones.len()
            && self.current_epoch >= self.milestones[self.next_milestone_idx]
        {
            self.current_lr *= self.gamma;
            self.next_milestone_idx += 1;
            decayed = true;
        }
        // Only touch the optimizer when the LR actually changed.
        if decayed {
            optimizer.set_lr(self.current_lr);
        }
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    /// Note: `milestones` is not serialized; it is fixed at construction.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("initial_lr".to_string(), self.initial_lr);
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("current_epoch".to_string(), self.current_epoch as f64);
        state.insert("gamma".to_string(), self.gamma);
        state.insert(
            "next_milestone_idx".to_string(),
            self.next_milestone_idx as f64,
        );
        state
    }
    /// Restores mutable state; missing keys are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&current_epoch) = state.get("current_epoch") {
            self.current_epoch = current_epoch as usize;
        }
        if let Some(&next_milestone_idx) = state.get("next_milestone_idx") {
            self.next_milestone_idx = next_milestone_idx as usize;
        }
        Ok(())
    }
}
/// Reduces the learning rate by `factor` when a monitored metric stops
/// improving for `patience` consecutive evaluations (analogous to PyTorch's
/// `ReduceLROnPlateau`). Drive it with `step_with_metric`; the trait's
/// `step` is a no-op because this scheduler needs a metric value.
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateauScheduler {
// Most recent learning rate.
current_lr: f64,
/// Multiplicative factor applied on each reduction (e.g. 0.5).
pub factor: f64,
/// Number of non-improving evaluations tolerated before reducing.
pub patience: usize,
/// Lower bound; the LR is never reduced below this.
pub min_lr: f64,
/// Relative improvement required to count as "better".
pub threshold: f64,
/// Evaluations to skip entirely after each reduction.
pub cooldown: usize,
// Best metric seen so far (None until the first evaluation).
best_metric: Option<f64>,
// Consecutive evaluations without sufficient improvement.
num_bad_epochs: usize,
// Remaining cooldown evaluations to skip.
cooldown_counter: usize,
// Whether lower (`Min`) or higher (`Max`) metric values are better.
mode: PlateauMode,
}
/// Direction of improvement for [`ReduceLROnPlateauScheduler`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PlateauMode {
/// Lower metric values are better (e.g. loss).
Min,
/// Higher metric values are better (e.g. accuracy).
Max,
}
impl ReduceLROnPlateauScheduler {
pub fn new(
initial_lr: f64,
mode: PlateauMode,
factor: f64,
patience: usize,
threshold: f64,
min_lr: f64,
cooldown: usize,
) -> Self {
Self {
current_lr: initial_lr,
factor,
patience,
min_lr,
threshold,
cooldown,
best_metric: None,
num_bad_epochs: 0,
cooldown_counter: 0,
mode,
}
}
pub fn step_with_metric(&mut self, optimizer: &mut dyn Optimizer, metric: f64) {
if self.cooldown_counter > 0 {
self.cooldown_counter -= 1;
return;
}
let is_better = match self.best_metric {
None => true, Some(best) => match self.mode {
PlateauMode::Min => metric < best * (1.0 - self.threshold),
PlateauMode::Max => metric > best * (1.0 + self.threshold),
},
};
if is_better {
self.best_metric = Some(metric);
self.num_bad_epochs = 0;
} else {
self.num_bad_epochs += 1;
if self.num_bad_epochs >= self.patience {
let new_lr = (self.current_lr * self.factor).max(self.min_lr);
if new_lr < self.current_lr {
self.current_lr = new_lr;
optimizer.set_lr(self.current_lr);
self.cooldown_counter = self.cooldown;
self.num_bad_epochs = 0;
}
}
}
}
}
impl LrScheduler for ReduceLROnPlateauScheduler {
    /// No-op: this scheduler needs a metric value, so callers must use
    /// `step_with_metric` instead.
    fn step(&mut self, _optimizer: &mut dyn Optimizer) {
    }
    fn get_lr(&self) -> f64 {
        self.current_lr
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    /// `best_metric: None` is encoded as NaN; `mode` as 0.0 (Min) / 1.0 (Max).
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("current_lr".to_string(), self.current_lr);
        state.insert("factor".to_string(), self.factor);
        state.insert("patience".to_string(), self.patience as f64);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("threshold".to_string(), self.threshold);
        state.insert("cooldown".to_string(), self.cooldown as f64);
        state.insert(
            "best_metric".to_string(),
            self.best_metric.unwrap_or(f64::NAN),
        );
        state.insert("num_bad_epochs".to_string(), self.num_bad_epochs as f64);
        state.insert("cooldown_counter".to_string(), self.cooldown_counter as f64);
        state.insert(
            "mode".to_string(),
            match self.mode {
                PlateauMode::Min => 0.0,
                PlateauMode::Max => 1.0,
            },
        );
        state
    }
    /// Restores mutable state; NaN `best_metric` decodes back to `None`.
    /// Missing keys are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&current_lr) = state.get("current_lr") {
            self.current_lr = current_lr;
        }
        if let Some(&best_metric) = state.get("best_metric") {
            self.best_metric = if best_metric.is_nan() {
                None
            } else {
                Some(best_metric)
            };
        }
        if let Some(&num_bad_epochs) = state.get("num_bad_epochs") {
            self.num_bad_epochs = num_bad_epochs as usize;
        }
        if let Some(&cooldown_counter) = state.get("cooldown_counter") {
            self.cooldown_counter = cooldown_counter as usize;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{OptimizerConfig, SgdOptimizer};
// StepLR: gamma=0.5 every 2 epochs -> 0.1, 0.1, 0.05, 0.05, 0.025.
#[test]
fn test_step_lr_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.1,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = StepLrScheduler::new(0.1, 2, 0.5);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step(&mut optimizer);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step(&mut optimizer);
assert_eq!(scheduler.get_lr(), 0.05);
scheduler.step(&mut optimizer);
assert_eq!(scheduler.get_lr(), 0.05);
scheduler.step(&mut optimizer);
assert_eq!(scheduler.get_lr(), 0.025);
}
// Exponential: lr = 0.1 * 0.9^epoch.
#[test]
fn test_exponential_lr_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.1,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = ExponentialLrScheduler::new(0.1, 0.9);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step(&mut optimizer);
assert!((scheduler.get_lr() - 0.09).abs() < 1e-6);
scheduler.step(&mut optimizer);
assert!((scheduler.get_lr() - 0.081).abs() < 1e-6);
}
// Cosine annealing: halfway through t_max the LR is near the midpoint.
#[test]
fn test_cosine_annealing_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.1,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = CosineAnnealingLrScheduler::new(0.1, 0.01, 10);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step(&mut optimizer);
assert!(scheduler.get_lr() < 0.1);
assert!(scheduler.get_lr() > 0.01);
for _ in 1..5 {
scheduler.step(&mut optimizer);
}
let halfway_lr = scheduler.get_lr();
assert!((halfway_lr - 0.055).abs() < 0.01); }
// Warmup: linear ramp to 0.1 over 10 steps, then held constant.
#[test]
fn test_warmup_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.0,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = WarmupScheduler::new(0.1, 10);
assert_eq!(scheduler.get_lr(), 0.0);
scheduler.step(&mut optimizer);
assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);
for _ in 1..10 {
scheduler.step(&mut optimizer);
}
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step(&mut optimizer);
assert_eq!(scheduler.get_lr(), 0.1); }
// One-cycle: rises toward max_lr in the first 30%, then anneals down.
#[test]
fn test_one_cycle_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.01,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = OneCycleLrScheduler::new(0.01, 0.1, 0.001, 100, 0.3);
assert_eq!(scheduler.get_lr(), 0.01);
for _ in 0..30 {
scheduler.step(&mut optimizer);
}
assert!(scheduler.get_lr() > 0.01);
assert!(scheduler.get_lr() <= 0.1);
for _ in 30..100 {
scheduler.step(&mut optimizer);
}
assert!(scheduler.get_lr() < 0.1);
}
// Polynomial decay reaches final_lr after decay_steps.
#[test]
fn test_polynomial_decay_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.1,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = PolynomialDecayLrScheduler::new(0.1, 0.001, 2.0, 100);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step(&mut optimizer);
assert!(scheduler.get_lr() < 0.1);
for _ in 1..100 {
scheduler.step(&mut optimizer);
}
assert!((scheduler.get_lr() - 0.001).abs() < 1e-6);
}
// Cyclic: rises during the first half-cycle, falls during the second.
#[test]
fn test_cyclic_lr_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.01,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = CyclicLrScheduler::new(0.01, 0.1, 10, CyclicLrMode::Triangular);
assert_eq!(scheduler.get_lr(), 0.01);
for _ in 0..10 {
scheduler.step(&mut optimizer);
}
assert!(scheduler.get_lr() > 0.01);
for _ in 10..20 {
scheduler.step(&mut optimizer);
}
assert!(scheduler.get_lr() < 0.1);
}
// Warmup+cosine: peak at step 10, then cosine decay toward 0.001.
#[test]
fn test_warmup_cosine_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.0,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = WarmupCosineLrScheduler::new(0.1, 0.001, 10, 100);
assert_eq!(scheduler.get_lr(), 0.0);
for _ in 0..10 {
scheduler.step(&mut optimizer);
}
assert!((scheduler.get_lr() - 0.1).abs() < 1e-6);
for _ in 10..50 {
scheduler.step(&mut optimizer);
}
assert!(scheduler.get_lr() < 0.1);
assert!(scheduler.get_lr() > 0.001);
for _ in 50..100 {
scheduler.step(&mut optimizer);
}
assert!(scheduler.get_lr() < 0.1);
assert!((scheduler.get_lr() - 0.001).abs() < 0.01);
}
// Noam: LR rises until warmup_steps (4000), then decays.
#[test]
fn test_noam_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.0,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = NoamScheduler::new(512, 4000, 1.0);
let initial_lr = scheduler.get_lr();
assert!(initial_lr > 0.0);
scheduler.step(&mut optimizer);
let step1_lr = scheduler.get_lr();
assert!(step1_lr != initial_lr);
for _ in 1..4000 {
scheduler.step(&mut optimizer);
}
let peak_lr = scheduler.get_lr();
for _ in 4000..8000 {
scheduler.step(&mut optimizer);
}
assert!(scheduler.get_lr() < peak_lr);
}
// MultiStep: decays by 0.1 at epochs 10, 20, 30.
#[test]
fn test_multistep_lr_scheduler() {
let config = OptimizerConfig {
learning_rate: 0.1,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = MultiStepLrScheduler::new(0.1, vec![10, 20, 30], 0.1);
assert_eq!(scheduler.get_lr(), 0.1);
for _ in 0..9 {
scheduler.step(&mut optimizer);
}
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step(&mut optimizer);
assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);
for _ in 10..19 {
scheduler.step(&mut optimizer);
}
assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);
scheduler.step(&mut optimizer);
assert!((scheduler.get_lr() - 0.001).abs() < 1e-6);
for _ in 20..29 {
scheduler.step(&mut optimizer);
}
scheduler.step(&mut optimizer);
assert!((scheduler.get_lr() - 0.0001).abs() < 1e-6);
}
// Plateau (Min): reduces after patience=3 bad epochs, then cooldown=2
// swallows the next two observations.
#[test]
fn test_reduce_lr_on_plateau_min_mode() {
let config = OptimizerConfig {
learning_rate: 0.1,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = ReduceLROnPlateauScheduler::new(
0.1, PlateauMode::Min, 0.5, 3, 0.01, 0.001, 2, );
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 1.0);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.9);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.9);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.9);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.9);
assert_eq!(scheduler.get_lr(), 0.05);
scheduler.step_with_metric(&mut optimizer, 1.0);
assert_eq!(scheduler.get_lr(), 0.05);
scheduler.step_with_metric(&mut optimizer, 1.0);
assert_eq!(scheduler.get_lr(), 0.05);
}
// Plateau (Max): stalled accuracy triggers reduction after patience=2.
#[test]
fn test_reduce_lr_on_plateau_max_mode() {
let config = OptimizerConfig {
learning_rate: 0.1,
..Default::default()
};
let mut optimizer = SgdOptimizer::new(config);
let mut scheduler = ReduceLROnPlateauScheduler::new(
0.1,
PlateauMode::Max, 0.1,
2,
0.01,
0.001,
0,
);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.5);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.6);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.6);
assert_eq!(scheduler.get_lr(), 0.1);
scheduler.step_with_metric(&mut optimizer, 0.6);
assert!((scheduler.get_lr() - 0.01).abs() < 1e-6);
}
// SGDR: LR decays within a cycle, restarts high, and the period doubles
// (t_0=10, t_mult=2.0 -> current_period == 20 after the first restart).
#[test]
fn test_sgdr_scheduler() {
let mut scheduler = SgdrScheduler::new(0.1, 0.001, 10, 2.0);
let mut optimizer = SgdOptimizer::new(OptimizerConfig::default());
let initial_lr = scheduler.get_current_lr();
assert!((initial_lr - 0.1).abs() < 1e-6);
for _ in 0..5 {
scheduler.step(&mut optimizer);
}
let mid_lr = scheduler.get_lr();
assert!(mid_lr < initial_lr);
for _ in 5..10 {
scheduler.step(&mut optimizer);
}
scheduler.step(&mut optimizer);
let restart_lr = scheduler.get_lr();
assert!(restart_lr > mid_lr);
assert_eq!(scheduler.current_period, 20);
}
}
/// Cosine annealing with warm restarts (SGDR, Loshchilov & Hutter): the LR
/// follows a cosine from `max_lr` down to `min_lr` over one cycle, then
/// restarts; each cycle length is multiplied by `t_mult`.
#[derive(Debug, Clone)]
pub struct SgdrScheduler {
/// Learning rate at the start of each cycle.
pub max_lr: f64,
/// Floor of the cosine within each cycle.
pub min_lr: f64,
/// Length of the first cycle, in steps.
pub t_0: usize,
/// Multiplier applied to the cycle length at each restart.
pub t_mult: f64,
// Steps taken within the current cycle (reset at restarts).
current_step: usize,
// Length of the current cycle, in steps.
current_period: usize,
// Total steps taken across all cycles.
total_steps: usize,
}
impl SgdrScheduler {
pub fn new(max_lr: f64, min_lr: f64, t_0: usize, t_mult: f64) -> Self {
Self {
max_lr,
min_lr,
t_0,
t_mult,
current_step: 0,
current_period: t_0,
total_steps: 0,
}
}
fn get_current_lr(&self) -> f64 {
let progress = self.current_step as f64 / self.current_period as f64;
let cosine_factor = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
self.min_lr + (self.max_lr - self.min_lr) * cosine_factor
}
}
impl LrScheduler for SgdrScheduler {
    /// Applies the LR for the current position, then advances. When the
    /// cycle completes, the position resets and the period grows by `t_mult`
    /// (a warm restart).
    fn step(&mut self, optimizer: &mut dyn Optimizer) {
        let lr = self.get_current_lr();
        optimizer.set_lr(lr);
        self.current_step += 1;
        self.total_steps += 1;
        if self.current_step >= self.current_period {
            self.current_step = 0;
            self.current_period = (self.current_period as f64 * self.t_mult) as usize;
        }
    }
    /// Recomputed on demand from the current cycle position.
    fn get_lr(&self) -> f64 {
        self.get_current_lr()
    }
    /// Serializes hyperparameters and mutable state as flat f64 entries.
    fn state_dict(&self) -> std::collections::HashMap<String, f64> {
        let mut state = std::collections::HashMap::new();
        state.insert("max_lr".to_string(), self.max_lr);
        state.insert("min_lr".to_string(), self.min_lr);
        state.insert("t_0".to_string(), self.t_0 as f64);
        state.insert("t_mult".to_string(), self.t_mult);
        state.insert("current_step".to_string(), self.current_step as f64);
        state.insert("current_period".to_string(), self.current_period as f64);
        state.insert("total_steps".to_string(), self.total_steps as f64);
        state
    }
    /// Restores both hyperparameters and mutable state (unlike the other
    /// schedulers, SGDR round-trips its full configuration). Missing keys
    /// are silently ignored.
    fn load_state_dict(
        &mut self,
        state: &std::collections::HashMap<String, f64>,
    ) -> crate::TrainResult<()> {
        if let Some(&max_lr) = state.get("max_lr") {
            self.max_lr = max_lr;
        }
        if let Some(&min_lr) = state.get("min_lr") {
            self.min_lr = min_lr;
        }
        if let Some(&t_0) = state.get("t_0") {
            self.t_0 = t_0 as usize;
        }
        if let Some(&t_mult) = state.get("t_mult") {
            self.t_mult = t_mult;
        }
        if let Some(&current_step) = state.get("current_step") {
            self.current_step = current_step as usize;
        }
        if let Some(&current_period) = state.get("current_period") {
            self.current_period = current_period as usize;
        }
        if let Some(&total_steps) = state.get("total_steps") {
            self.total_steps = total_steps as usize;
        }
        Ok(())
    }
}