/// Learning-rate scheduler that rewrites `learning_rate` once per epoch according
/// to an ordered list of [`OptimizingBehavior`] rules.
pub struct Optimizer {
    // ----- configuration supplied at construction -----
    /// Length of the run; `tick` stops adjusting the learning rate once the epoch
    /// counter reaches this value.
    pub epochs: usize,
    /// Learning rate the run starts from.
    pub initial_learning_rate: f64,
    /// Optional loss target.
    /// NOTE(review): not read by any code in this file — confirm where it is consumed.
    pub stop_on_loss: Option<f64>,
    /// Adjustment rules, applied in order on every tick.
    pub behaviors: Vec<OptimizingBehavior>,

    // ----- state that evolves while training -----
    /// Learning rate currently in effect.
    pub learning_rate: f64,
    /// Loss recorded by the most recent tick that applied the rules, if any.
    pub last_loss: Option<f64>,
    /// Number of ticks performed so far.
    current_epoch: usize,
}
impl Optimizer {
    /// Creates an optimizer whose learning rate starts at `initial_learning_rate`.
    ///
    /// The epoch counter starts at 0 and no loss has been recorded yet.
    pub fn new(
        epochs: usize,
        initial_learning_rate: f64,
        stop_on_loss: Option<f64>,
        behaviors: Vec<OptimizingBehavior>,
    ) -> Optimizer {
        Optimizer {
            epochs,
            initial_learning_rate,
            stop_on_loss,
            behaviors,
            current_epoch: 0,
            learning_rate: initial_learning_rate,
            last_loss: None,
        }
    }

    /// Advances the optimizer by one epoch and adjusts `learning_rate`.
    ///
    /// `loss` is the loss observed for the epoch that just finished. Every rule in
    /// `behaviors` is applied in order, each seeing the learning rate left by the
    /// previous one. Afterwards `loss` is stored in `last_loss` so the
    /// loss-comparing rules can use it on the next call.
    ///
    /// Once the epoch counter reaches `epochs`, a call only increments the counter.
    /// No rule runs and `last_loss` is left untouched. `stop_on_loss` is not
    /// consulted here.
    ///
    /// Degenerate rule parameters are ignored instead of panicking or producing NaN:
    /// - `SubEveryNEpochs` with `every_n == 0` never fires.
    /// - `Sine` / `Sawtooth` with `period == 0` leave the learning rate unchanged.
    pub fn tick(&mut self, loss: f64) {
        self.current_epoch += 1;
        if self.current_epoch >= self.epochs {
            return;
        }
        let epoch = self.current_epoch;

        for behavior in &self.behaviors {
            // All variant fields are `Copy`, so bind them by value.
            match *behavior {
                OptimizingBehavior::Decay { decrease_by } => {
                    self.learning_rate *= decrease_by;
                }
                OptimizingBehavior::DecreaseOnLossDecrease { decrease_by } => {
                    if matches!(self.last_loss, Some(prev) if loss < prev) {
                        self.learning_rate *= decrease_by;
                    }
                }
                OptimizingBehavior::IncreaseOnLossIncrease { increase_by } => {
                    if matches!(self.last_loss, Some(prev) if loss > prev) {
                        self.learning_rate *= increase_by;
                    }
                }
                OptimizingBehavior::PreventStagnation { threshold, mul_by } => {
                    // Fires when the loss failed to drop by more than `threshold`.
                    if matches!(self.last_loss, Some(prev) if loss >= prev - threshold) {
                        self.learning_rate *= mul_by;
                    }
                }
                OptimizingBehavior::SubEveryNEpochs {
                    subtract_by,
                    every_n,
                } => {
                    // `checked_rem` yields `None` for `every_n == 0`, where `%` would panic.
                    // The `>` guard keeps the learning rate from reaching zero or going negative.
                    if epoch.checked_rem(every_n) == Some(0) && self.learning_rate > subtract_by {
                        self.learning_rate -= subtract_by;
                    }
                }
                OptimizingBehavior::Sine { period, min, max } => {
                    // A zero period would divide by zero and turn the rate into NaN.
                    if period > 0 {
                        let t = epoch as f64 / period as f64;
                        self.learning_rate =
                            min + (max - min) * (t * std::f64::consts::PI).sin().abs();
                    }
                }
                OptimizingBehavior::Sawtooth { period, min, max } => {
                    // A zero period would divide by zero and turn the rate into NaN.
                    if period > 0 {
                        let t = epoch as f64 / period as f64;
                        // `1 - fract(t)` is 1 at each multiple of `period`, then falls linearly toward 0.
                        self.learning_rate = min + (max - min) * (1.0 - t.fract());
                    }
                }
            }
        }
        self.last_loss = Some(loss);
    }
}
/// A single learning-rate adjustment rule applied by [`Optimizer::tick`].
///
/// Rules run in the order they appear in `Optimizer::behaviors`. Each one sees the
/// learning rate left by the rule before it. "Previous loss" below means the loss
/// passed to the previous tick that applied the rules.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizingBehavior {
    /// Multiply the learning rate by `decrease_by` on every tick.
    Decay { decrease_by: f64 },
    /// Multiply the learning rate by `decrease_by` when the loss is lower than the
    /// previous loss.
    DecreaseOnLossDecrease { decrease_by: f64 },
    /// Multiply the learning rate by `increase_by` when the loss is higher than the
    /// previous loss.
    IncreaseOnLossIncrease { increase_by: f64 },
    /// Multiply the learning rate by `mul_by` when the loss did not drop by more than
    /// `threshold` (i.e. `loss >= previous - threshold`).
    PreventStagnation { threshold: f64, mul_by: f64 },
    /// Subtract `subtract_by` on epochs divisible by `every_n`, but only while the
    /// learning rate is still greater than `subtract_by`.
    SubEveryNEpochs { subtract_by: f64, every_n: usize },
    /// Set the learning rate to `min + (max - min) * |sin(pi * epoch / period)|`,
    /// a curve that repeats every `period` epochs.
    Sine { period: usize, min: f64, max: f64 },
    /// Set the learning rate to a sawtooth that equals `max` at multiples of `period`
    /// and falls linearly toward `min` in between.
    Sawtooth { period: usize, min: f64, max: f64 },
}