use super::optim::Optimizer;
/// Common interface for schedulers that adjust an optimizer's learning rate as training progresses.
pub trait LRScheduler {
    /// Advances the schedule by one step and updates the optimizer's learning rate.
    fn step<O: Optimizer>(&mut self, optimizer: &mut O);
    /// Returns the most recently computed learning rate.
    fn get_lr(&self) -> f32;
    /// Returns the number of times `step` has been called.
    fn last_epoch(&self) -> usize;
}
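/// Multiplies the learning rate by `gamma` every `step_size` epochs.
///
/// When constructed with [`StepLR::new`], the initial learning rate is read
/// from the optimizer on the first call to `step`.
///
/// A minimal usage sketch (not compiled here; `Sgd` is a placeholder for any
/// type implementing [`Optimizer`]):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.1);        // hypothetical optimizer with lr = 0.1
/// let mut scheduler = StepLR::new(10, 0.5); // halve the lr every 10 epochs
/// for _epoch in 0..30 {
///     // ... train for one epoch ...
///     scheduler.step(&mut optimizer);
/// }
/// assert!((scheduler.get_lr() - 0.0125).abs() < 1e-6); // 0.1 * 0.5^3
/// ```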
#[derive(Debug, Clone)]
pub struct StepLR {
    initial_lr: f32,
    step_size: usize,
    gamma: f32,
    current_lr: f32,
    last_epoch: usize,
}
impl StepLR {
    /// Creates a scheduler that reads the initial learning rate from the
    /// optimizer on the first call to `step`.
    #[must_use]
    pub fn new(step_size: usize, gamma: f32) -> Self {
        Self {
            initial_lr: 0.0,
            step_size,
            gamma,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }

    /// Creates a scheduler with an explicit initial learning rate.
    #[must_use]
    pub fn with_lr(initial_lr: f32, step_size: usize, gamma: f32) -> Self {
        Self {
            initial_lr,
            step_size,
            gamma,
            current_lr: initial_lr,
            last_epoch: 0,
        }
    }
}
impl LRScheduler for StepLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        // Lazily capture the optimizer's learning rate if none was provided.
        if self.last_epoch == 0 && self.initial_lr == 0.0 {
            self.initial_lr = optimizer.lr();
            self.current_lr = self.initial_lr;
        }
        self.last_epoch += 1;
        // Decay by `gamma` every `step_size` epochs.
        if self.last_epoch.is_multiple_of(self.step_size) {
            self.current_lr *= self.gamma;
            optimizer.set_lr(self.current_lr);
        }
    }

    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    fn last_epoch(&self) -> usize {
        self.last_epoch
    }
}
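/// Multiplies the learning rate by `gamma` after every call to `step`.
///
/// With [`ExponentialLR::new`] the initial learning rate is read from the
/// optimizer on the first `step`.
///
/// A minimal usage sketch (not compiled here; `Sgd` is a placeholder for any
/// [`Optimizer`] implementation):
///
/// ```ignore
/// let mut optimizer = Sgd::new(1.0);
/// let mut scheduler = ExponentialLR::with_lr(1.0, 0.9);
/// for _epoch in 0..3 {
///     scheduler.step(&mut optimizer);
/// }
/// assert!((scheduler.get_lr() - 0.729).abs() < 1e-6); // 1.0 * 0.9^3
/// ```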
#[derive(Debug, Clone)]
pub struct ExponentialLR {
    initial_lr: f32,
    gamma: f32,
    current_lr: f32,
    last_epoch: usize,
}
impl ExponentialLR {
    /// Creates a scheduler that reads the initial learning rate from the
    /// optimizer on the first call to `step`.
    #[must_use]
    pub fn new(gamma: f32) -> Self {
        Self {
            initial_lr: 0.0,
            gamma,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }

    /// Creates a scheduler with an explicit initial learning rate.
    #[must_use]
    pub fn with_lr(initial_lr: f32, gamma: f32) -> Self {
        Self {
            initial_lr,
            gamma,
            current_lr: initial_lr,
            last_epoch: 0,
        }
    }
}
impl LRScheduler for ExponentialLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        // Lazily capture the optimizer's learning rate if none was provided.
        if self.last_epoch == 0 && self.initial_lr == 0.0 {
            self.initial_lr = optimizer.lr();
            self.current_lr = self.initial_lr;
        }
        self.last_epoch += 1;
        // Decay by `gamma` on every step.
        self.current_lr *= self.gamma;
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    fn last_epoch(&self) -> usize {
        self.last_epoch
    }
}
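/// Anneals the learning rate from the initial value down to `min_lr` over
/// `t_max` steps following a cosine curve:
///
/// `lr_t = min_lr + 0.5 * (initial_lr - min_lr) * (1 + cos(pi * t / t_max))`
///
/// Stepping past `t_max` is not clamped, so the learning rate follows the
/// cosine back up again, as in the closed-form cosine-annealing formula.
///
/// A minimal usage sketch (not compiled here; `Sgd` is a placeholder
/// optimizer):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.1);
/// let mut scheduler = CosineAnnealingLR::with_lr(0.1, 100, 0.001);
/// for _epoch in 0..50 {
///     scheduler.step(&mut optimizer);
/// }
/// // Halfway through, cos(pi * 0.5) = 0, so the lr sits at the midpoint.
/// assert!((scheduler.get_lr() - 0.0505).abs() < 1e-6);
/// ```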
#[derive(Debug, Clone)]
pub struct CosineAnnealingLR {
    initial_lr: f32,
    min_lr: f32,
    t_max: usize,
    current_lr: f32,
    last_epoch: usize,
}
impl CosineAnnealingLR {
    /// Creates a scheduler that anneals to zero, reading the initial learning
    /// rate from the optimizer on the first call to `step`.
    #[must_use]
    pub fn new(t_max: usize) -> Self {
        Self {
            initial_lr: 0.0,
            min_lr: 0.0,
            t_max,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }

    /// Creates a scheduler that anneals to `min_lr`, reading the initial
    /// learning rate from the optimizer on the first call to `step`.
    #[must_use]
    pub fn with_min_lr(t_max: usize, min_lr: f32) -> Self {
        Self {
            initial_lr: 0.0,
            min_lr,
            t_max,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }

    /// Creates a scheduler with an explicit initial learning rate.
    #[must_use]
    pub fn with_lr(initial_lr: f32, t_max: usize, min_lr: f32) -> Self {
        Self {
            initial_lr,
            min_lr,
            t_max,
            current_lr: initial_lr,
            last_epoch: 0,
        }
    }
}
impl LRScheduler for CosineAnnealingLR {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        // Lazily capture the optimizer's learning rate if none was provided.
        if self.last_epoch == 0 && self.initial_lr == 0.0 {
            self.initial_lr = optimizer.lr();
            self.current_lr = self.initial_lr;
        }
        self.last_epoch += 1;
        // lr = min_lr + (initial_lr - min_lr) * (1 + cos(pi * t / t_max)) / 2
        let progress = self.last_epoch as f32 / self.t_max as f32;
        let cosine = (std::f32::consts::PI * progress).cos();
        self.current_lr = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + cosine);
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    fn last_epoch(&self) -> usize {
        self.last_epoch
    }
}
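/// Linearly ramps the learning rate from 0 up to the target value over
/// `warmup_steps` steps, then holds it constant.
///
/// A minimal usage sketch (not compiled here; `Sgd` is a placeholder
/// optimizer):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.1);
/// let mut scheduler = LinearWarmup::with_lr(0.1, 10);
/// scheduler.step(&mut optimizer);
/// assert!((scheduler.get_lr() - 0.01).abs() < 1e-6); // 1/10 of the target lr
/// for _ in 0..20 {
///     scheduler.step(&mut optimizer);
/// }
/// assert!((scheduler.get_lr() - 0.1).abs() < 1e-6);  // warmup finished, lr held at target
/// ```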
#[derive(Debug, Clone)]
pub struct LinearWarmup {
    initial_lr: f32,
    warmup_steps: usize,
    current_lr: f32,
    last_epoch: usize,
}
impl LinearWarmup {
    /// Creates a scheduler that reads the target learning rate from the
    /// optimizer on the first call to `step`.
    #[must_use]
    pub fn new(warmup_steps: usize) -> Self {
        Self {
            initial_lr: 0.0,
            warmup_steps,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }

    /// Creates a scheduler with an explicit target learning rate. The current
    /// learning rate starts at zero until the first call to `step`.
    #[must_use]
    pub fn with_lr(initial_lr: f32, warmup_steps: usize) -> Self {
        Self {
            initial_lr,
            warmup_steps,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }
}
impl LRScheduler for LinearWarmup {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        // Lazily capture the optimizer's learning rate if none was provided.
        if self.last_epoch == 0 && self.initial_lr == 0.0 {
            self.initial_lr = optimizer.lr();
        }
        self.last_epoch += 1;
        if self.last_epoch <= self.warmup_steps {
            // Ramp linearly from 0 up to the target learning rate.
            self.current_lr = self.initial_lr * (self.last_epoch as f32 / self.warmup_steps as f32);
        } else {
            // Warmup finished: hold the target learning rate.
            self.current_lr = self.initial_lr;
        }
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    fn last_epoch(&self) -> usize {
        self.last_epoch
    }
}
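/// Linear warmup for `warmup_steps` steps followed by cosine decay down to
/// `min_lr` at `total_steps`.
///
/// A minimal usage sketch (not compiled here; `Sgd` is a placeholder
/// optimizer):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.1);
/// let mut scheduler = WarmupCosineScheduler::new(10, 110);
/// for _step in 0..60 {
///     scheduler.step(&mut optimizer);
/// }
/// // 50 steps into the 100-step cosine decay: halfway between 0.1 and 0.0.
/// assert!((scheduler.get_lr() - 0.05).abs() < 1e-6);
/// ```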
#[derive(Debug, Clone)]
pub struct WarmupCosineScheduler {
    initial_lr: f32,
    min_lr: f32,
    warmup_steps: usize,
    total_steps: usize,
    current_lr: f32,
    last_epoch: usize,
}
impl WarmupCosineScheduler {
    /// Creates a scheduler that decays to zero, reading the peak learning rate
    /// from the optimizer on the first call to `step`.
    #[must_use]
    pub fn new(warmup_steps: usize, total_steps: usize) -> Self {
        Self {
            initial_lr: 0.0,
            min_lr: 0.0,
            warmup_steps,
            total_steps,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }

    /// Creates a scheduler that decays to `min_lr` instead of zero.
    #[must_use]
    pub fn with_min_lr(warmup_steps: usize, total_steps: usize, min_lr: f32) -> Self {
        Self {
            initial_lr: 0.0,
            min_lr,
            warmup_steps,
            total_steps,
            current_lr: 0.0,
            last_epoch: 0,
        }
    }
}
impl LRScheduler for WarmupCosineScheduler {
    fn step<O: Optimizer>(&mut self, optimizer: &mut O) {
        // Lazily capture the optimizer's learning rate if none was provided.
        if self.last_epoch == 0 && self.initial_lr == 0.0 {
            self.initial_lr = optimizer.lr();
        }
        self.last_epoch += 1;
        if self.last_epoch <= self.warmup_steps {
            // Linear warmup from 0 up to the peak learning rate.
            self.current_lr = self.initial_lr * (self.last_epoch as f32 / self.warmup_steps as f32);
        } else {
            // Cosine decay from the peak down to `min_lr`; progress is clamped
            // so stepping past `total_steps` holds the learning rate at `min_lr`.
            let decay_steps = self.total_steps - self.warmup_steps;
            let decay_epoch = self.last_epoch - self.warmup_steps;
            let progress = (decay_epoch as f32 / decay_steps as f32).min(1.0);
            let cosine = (std::f32::consts::PI * progress).cos();
            self.current_lr = self.min_lr + 0.5 * (self.initial_lr - self.min_lr) * (1.0 + cosine);
        }
        optimizer.set_lr(self.current_lr);
    }

    fn get_lr(&self) -> f32 {
        self.current_lr
    }

    fn last_epoch(&self) -> usize {
        self.last_epoch
    }
}
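/// Reduces the learning rate by `factor` once a monitored metric has stopped
/// improving by more than `threshold` for `patience` consecutive epochs,
/// never going below `min_lr`. Whether "improving" means decreasing or
/// increasing is controlled by [`PlateauMode`].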
#[derive(Debug, Clone)]
pub struct ReduceLROnPlateau {
    factor: f32,
    patience: usize,
    min_lr: f32,
    threshold: f32,
    current_lr: f32,
    best_metric: f32,
    num_bad_epochs: usize,
    last_epoch: usize,
    mode: PlateauMode,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlateauMode {
    /// Lower metric values count as improvements (e.g. validation loss).
    Min,
    /// Higher metric values count as improvements (e.g. accuracy).
    Max,
}
mod improvement;