use std::f32;
/// A schedule that yields the learning rate to use while training.
///
/// Rates can be queried per epoch or per global step. Both query methods take
/// `&mut self`, so an implementation is free to update internal state whenever
/// it is asked for a rate.
pub trait LearningRateSchedule {
    /// Returns the base learning rate the schedule currently works from.
    fn initial_lr(&self) -> f32;

    /// Replaces the base learning rate with `lr`.
    fn set_initial_lr(&mut self, lr: f32);

    /// Returns the learning rate for the given `global_step`.
    fn compute_step_learning_rate(&mut self, global_step: usize) -> f32;

    /// Returns the learning rate for the given `epoch`.
    ///
    /// `last_score` is a score supplied by the caller; each implementation
    /// decides whether and how it influences the returned rate.
    fn compute_epoch_learning_rate(&mut self, epoch: usize, last_score: f32) -> f32;
}
/// Learning-rate schedule that holds a fixed rate, optionally preceded by a
/// linear warm-up over the first `warmup_steps` global steps.
///
/// `Clone` is derived so that this schedule can be cloned just like the other
/// schedules in this file, including when it is wrapped in a generic schedule
/// whose own `Clone` implementation requires the wrapped type to be `Clone`.
#[derive(Clone)]
pub struct ConstantLearningRate {
    /// The fixed learning rate (also reported as the initial learning rate).
    lr: f32,
    /// Number of leading global steps during which the step rate ramps up
    /// linearly towards `lr`.
    warmup_steps: usize,
}
impl ConstantLearningRate {
pub fn new(lr: f32, warmup_steps: usize) -> Self {
assert!(lr > 0.0, "Learning rate must be a positive value");
ConstantLearningRate { lr, warmup_steps }
}
}
impl LearningRateSchedule for ConstantLearningRate {
    /// Returns the stored rate; the epoch and score are ignored.
    fn compute_epoch_learning_rate(&mut self, _epoch: usize, _last_score: f32) -> f32 {
        self.lr
    }

    /// Returns the stored rate once `global_step` has reached `warmup_steps`.
    ///
    /// For earlier steps the rate grows linearly with the step number:
    /// it is `0` at step 0 and increases by `lr / warmup_steps` per step.
    fn compute_step_learning_rate(&mut self, global_step: usize) -> f32 {
        if global_step >= self.warmup_steps {
            self.lr
        } else {
            // This branch is only reachable when `warmup_steps > 0`, so the
            // division below never divides by zero.
            let per_step = self.lr / (self.warmup_steps as f32);
            per_step * global_step as f32
        }
    }

    /// Returns the stored rate.
    fn initial_lr(&self) -> f32 {
        self.lr
    }

    /// Overwrites the stored rate with `lr` (no validation is performed).
    fn set_initial_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}
/// Learning-rate schedule that decays a base rate exponentially with the
/// global step, optionally after a linear warm-up phase.
#[derive(Clone)]
pub struct ExponentialDecay {
    /// Base rate that the decay factor is applied to.
    initial_lr: f32,
    /// Multiplicative decay applied once per `decay_steps` steps.
    decay_rate: f32,
    /// Number of post-warm-up steps over which one full `decay_rate` factor
    /// is applied.
    decay_steps: usize,
    /// When `true`, the decay exponent is the whole number of completed
    /// `decay_steps` periods; otherwise it is the fractional step ratio.
    staircase: bool,
    /// Number of leading global steps used for linear warm-up.
    warmup_steps: usize,
    /// Last decayed rate computed after warm-up; initialised to `initial_lr`.
    lr: f32,
}
impl ExponentialDecay {
pub fn new(
initial_lr: f32,
decay_rate: f32,
decay_steps: usize,
staircase: bool,
warmup_steps: usize,
) -> Self {
assert!(
initial_lr > 0.0,
"The initial learning rate must be a positive value."
);
assert!(
decay_rate > 0.0 && decay_rate < 1.0,
"The decay rate must be in (0, 1)."
);
assert!(
decay_steps > 0,
"The number decay steps should be non-zero."
);
ExponentialDecay {
lr: initial_lr,
initial_lr,
decay_rate,
decay_steps,
staircase,
warmup_steps,
}
}
}
impl LearningRateSchedule for ExponentialDecay {
    /// Returns the learning rate for `global_step`.
    ///
    /// While `global_step < warmup_steps` the rate rises linearly from `0`
    /// in increments of `initial_lr / warmup_steps`; the cached current rate
    /// is left untouched during this phase. Afterwards the rate is
    /// `initial_lr * decay_rate^e`, where `e` is the number of post-warm-up
    /// steps divided by `decay_steps` (integer division when `staircase` is
    /// set). The decayed rate is cached for `compute_epoch_learning_rate`.
    fn compute_step_learning_rate(&mut self, global_step: usize) -> f32 {
        if global_step < self.warmup_steps {
            return (self.initial_lr / (self.warmup_steps as f32)) * global_step as f32;
        }

        // Decay is counted from the end of the warm-up phase.
        let step = global_step - self.warmup_steps;
        let exponent = if self.staircase {
            // Integer division: the exponent only changes every `decay_steps`.
            (step / self.decay_steps) as f32
        } else {
            step as f32 / self.decay_steps as f32
        };

        self.lr = self.initial_lr * self.decay_rate.powf(exponent);
        self.lr
    }

    /// Returns the cached current rate; the epoch and score are ignored.
    fn compute_epoch_learning_rate(&mut self, _epoch: usize, _last_score: f32) -> f32 {
        self.lr
    }

    /// Returns the base rate that decay is applied to.
    fn initial_lr(&self) -> f32 {
        self.initial_lr
    }

    /// Replaces the base rate and keeps the cached current rate in sync.
    ///
    /// The cached rate equals `base * decay_factor`, so changing the base
    /// must rescale it by `new_base / old_base`. Without this, a caller that
    /// lowers the base rate and then asks for the epoch rate would still be
    /// handed the rate derived from the old base until the next step query.
    fn set_initial_lr(&mut self, lr: f32) {
        if self.initial_lr.is_normal() {
            self.lr *= lr / self.initial_lr;
        } else {
            // The old base is zero, subnormal, infinite or NaN, so the ratio
            // is meaningless; restart the cached rate from the new base.
            self.lr = lr;
        }
        self.initial_lr = lr;
    }
}
/// Wrapper schedule that shrinks the base rate of another schedule when the
/// reported score stops improving.
#[derive(Clone)]
pub struct PlateauLearningRate<I> {
    /// The wrapped schedule whose base rate is adjusted.
    lr_schedule: I,
    /// Factor the wrapped schedule's base rate is multiplied by on a plateau.
    scale: f32,
    /// Number of non-improving epochs at which the base rate is scaled.
    max_patience: usize,
    /// Non-improving epochs counted since the last improvement or scaling.
    patience: usize,
    /// Highest score seen so far.
    best_score: f32,
}
impl<I> PlateauLearningRate<I> {
pub fn new(lr_schedule: I, scale: f32, max_patience: usize) -> Self {
PlateauLearningRate {
lr_schedule,
scale,
best_score: -f32::INFINITY,
patience: 0,
max_patience,
}
}
}
impl<I> LearningRateSchedule for PlateauLearningRate<I>
where
    I: LearningRateSchedule,
{
    /// Updates the plateau state with `last_score` and returns the wrapped
    /// schedule's epoch rate.
    ///
    /// A score strictly greater than the best seen so far counts as an
    /// improvement: it becomes the new best and resets the patience counter.
    /// Any other score increments the counter; when the counter then equals
    /// `max_patience`, the wrapped schedule's base rate is multiplied by
    /// `scale` and the counter is reset.
    fn compute_epoch_learning_rate(&mut self, epoch: usize, last_score: f32) -> f32 {
        let improved = last_score > self.best_score;

        if improved {
            self.best_score = last_score;
            self.patience = 0;
        } else {
            self.patience += 1;
        }

        // Scaling is only ever considered right after a non-improving epoch.
        if !improved && self.patience == self.max_patience {
            let scaled = self.lr_schedule.initial_lr() * self.scale;
            self.lr_schedule.set_initial_lr(scaled);
            self.patience = 0;
        }

        self.lr_schedule
            .compute_epoch_learning_rate(epoch, last_score)
    }

    /// Forwards to the wrapped schedule; plateau state is not touched.
    fn compute_step_learning_rate(&mut self, global_step: usize) -> f32 {
        self.lr_schedule.compute_step_learning_rate(global_step)
    }

    /// Returns the wrapped schedule's base rate.
    fn initial_lr(&self) -> f32 {
        self.lr_schedule.initial_lr()
    }

    /// Sets the wrapped schedule's base rate.
    fn set_initial_lr(&mut self, lr: f32) {
        self.lr_schedule.set_initial_lr(lr);
    }
}
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::{
        ConstantLearningRate, ExponentialDecay, LearningRateSchedule, PlateauLearningRate,
    };

    /// Queries the step rate for each `(global_step, expected)` pair, in order.
    fn check_steps<S: LearningRateSchedule>(schedule: &mut S, cases: &[(usize, f32)]) {
        for &(step, expected) in cases {
            assert_relative_eq!(schedule.compute_step_learning_rate(step), expected);
        }
    }

    /// Queries the epoch rate for each `(epoch, score, expected)` triple, in order.
    fn check_epochs<S: LearningRateSchedule>(schedule: &mut S, cases: &[(usize, f32, f32)]) {
        for &(epoch, score, expected) in cases {
            assert_relative_eq!(
                schedule.compute_epoch_learning_rate(epoch, score),
                expected
            );
        }
    }

    #[test]
    pub fn constant_lr() {
        let mut constant = ConstantLearningRate::new(0.1, 0);
        check_epochs(
            &mut constant,
            &[
                (0, 0., 0.1),
                (1, 0., 0.1),
                (5, 0., 0.1),
                (15, 0., 0.1),
                (25, 0., 0.1),
            ],
        );
    }

    #[test]
    pub fn exponential_decay_lr() {
        // Staircase decay: the rate only drops every 10 steps.
        let mut decay1 = ExponentialDecay::new(0.1, 0.2, 10, true, 0);
        check_steps(
            &mut decay1,
            &[(0, 0.1), (1, 0.1), (5, 0.1), (15, 0.02), (25, 0.004)],
        );

        // Continuous decay: the rate drops a little on every step.
        let mut decay2 = ExponentialDecay::new(0.1, 0.2, 10, false, 0);
        check_steps(
            &mut decay2,
            &[
                (0, 0.1),
                (1, 0.085133992),
                (5, 0.044721359),
                (15, 0.008944271),
                (25, 0.001788854),
            ],
        );
    }

    #[test]
    pub fn exponential_decay_lr_warmup() {
        let mut decay1 = ExponentialDecay::new(0.1, 0.2, 10, true, 5);

        // Linear warm-up over the first five steps, then the full rate.
        check_steps(
            &mut decay1,
            &[
                (0, 0.0),
                (1, 0.02),
                (2, 0.04),
                (3, 0.06),
                (4, 0.08),
                (5, 0.1),
                (6, 0.1),
            ],
        );

        check_epochs(&mut decay1, &[(0, 0., 0.1)]);

        // Decay periods are counted from the end of the warm-up.
        check_steps(&mut decay1, &[(7, 0.1), (11, 0.1), (21, 0.02), (31, 0.004)]);
    }

    #[test]
    fn plateau_lr() {
        let mut plateau = PlateauLearningRate::new(ConstantLearningRate::new(0.1, 0), 0.5, 2);
        check_epochs(
            &mut plateau,
            &[
                (0, 1.0, 0.1),
                (1, 2.0, 0.1),
                (2, 2.0, 0.1),
                (3, 2.0, 0.05),
                (4, 2.0, 0.05),
                (5, 2.0, 0.025),
                (6, 3.0, 0.025),
                (6, 4.0, 0.025),
                (6, 5.0, 0.025),
            ],
        );
    }

    #[test]
    fn plateau_lr_warmup() {
        let mut plateau = PlateauLearningRate::new(ConstantLearningRate::new(0.1, 2), 0.5, 2);

        check_steps(&mut plateau, &[(0, 0.0), (1, 0.05), (2, 0.1)]);
        check_epochs(&mut plateau, &[(0, 1.0, 0.1)]);
        check_steps(&mut plateau, &[(3, 0.1), (4, 0.1), (5, 0.1)]);
        check_epochs(
            &mut plateau,
            &[
                (1, 2.0, 0.1),
                (2, 2.0, 0.1),
                (3, 2.0, 0.05),
                (4, 2.0, 0.05),
                (5, 2.0, 0.025),
                (6, 3.0, 0.025),
                (6, 4.0, 0.025),
                (6, 5.0, 0.025),
            ],
        );
    }
}