astrai 2.2.0

A pretty bad neural network library
Documentation
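
The Optimizer tracks the learning rate over a fixed number of epochs and adjusts it once per tick according to a list of OptimizingBehavior values, applied in order.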
pub struct Optimizer {
    /// Total number of epochs to train for.
    pub epochs: usize,
    /// Learning rate the optimizer starts from.
    pub initial_learning_rate: f64,
    /// Optional loss target; note that `tick` itself never checks it,
    /// so the training loop is expected to enforce it.
    pub stop_on_loss: Option<f64>,
    /// Behaviors applied in order on every `tick`.
    pub behaviors: Vec<OptimizingBehavior>,
    current_epoch: usize,
    /// Current learning rate, adjusted by `tick`.
    pub learning_rate: f64,
    /// Loss passed to the previous `tick`, if there was one.
    pub last_loss: Option<f64>,
}

impl Optimizer {
    /// Creates an optimizer with `learning_rate` set to `initial_learning_rate`.
    pub fn new(
        epochs: usize,
        initial_learning_rate: f64,
        stop_on_loss: Option<f64>,
        behaviors: Vec<OptimizingBehavior>,
    ) -> Optimizer {
        Optimizer {
            epochs,
            initial_learning_rate,
            stop_on_loss,
            behaviors,
            current_epoch: 0,
            learning_rate: initial_learning_rate,
            last_loss: None,
        }
    }

    /// Advances one epoch and applies every behavior to the learning rate.
    pub fn tick(&mut self, loss: f64) {
        self.current_epoch += 1;
        // Once the final epoch is reached, leave the learning rate alone.
        if self.current_epoch >= self.epochs {
            return;
        }

        for behavior in self.behaviors.iter() {
            match behavior {
                OptimizingBehavior::Decay { decrease_by } => {
                    // Exponential decay: scale the rate down every epoch.
                    self.learning_rate *= decrease_by;
                }
                OptimizingBehavior::DecreaseOnLossDecrease { decrease_by } => {
                    // When the loss improved since last epoch, shrink the rate.
                    if let Some(last_loss) = self.last_loss {
                        if loss < last_loss {
                            self.learning_rate *= decrease_by;
                        }
                    }
                }
                OptimizingBehavior::IncreaseOnLossIncrease { increase_by } => {
                    // When the loss got worse since last epoch, grow the rate.
                    if let Some(last_loss) = self.last_loss {
                        if loss > last_loss {
                            self.learning_rate *= increase_by;
                        }
                    }
                }
                OptimizingBehavior::PreventStagnation { threshold, mul_by } => {
                    // When the loss failed to improve by more than `threshold`,
                    // scale the rate by `mul_by` to push past the plateau.
                    if let Some(last_loss) = self.last_loss {
                        if loss >= last_loss - threshold {
                            self.learning_rate *= mul_by;
                        }
                    }
                }
                OptimizingBehavior::SubEveryNEpochs {
                    subtract_by,
                    every_n,
                } => {
                    // Every `every_n` epochs, subtract a flat amount, but only
                    // while the rate stays above `subtract_by` (and thus positive).
                    if self.current_epoch % every_n == 0 && self.learning_rate > *subtract_by {
                        self.learning_rate -= subtract_by;
                    }
                }
                OptimizingBehavior::Sine { period, min, max } => {
                    // f(x, n, a, p) = n + (a - n) * |sin((x / p) * π)|
                    // Oscillates between `min` and `max`, completing one
                    // min -> max -> min cycle every `period` epochs.
                    let t = self.current_epoch as f64 / *period as f64;
                    let value = min + (max - min) * (t * std::f64::consts::PI).sin().abs();
                    self.learning_rate = value;
                }
                OptimizingBehavior::Sawtooth { period, min, max } => {
                    // s(x, n, a, p) = n + (a - n) * (-x + floor(x) + 1)
                    // Descending ramp from `max` to `min`, restarting every
                    // `period` epochs.
                    let t = self.current_epoch as f64 / *period as f64;
                    let value = min + (max - min) * (-t + t.floor() + 1.0);
                    self.learning_rate = value;
                }
            }
        }

        // Remember this epoch's loss for the loss-sensitive behaviors.
        self.last_loss = Some(loss);
    }
}

pub enum OptimizingBehavior {
    /// Multiply the learning rate by `decrease_by` every epoch.
    Decay { decrease_by: f64 },
    /// Multiply by `decrease_by` whenever the loss went down.
    DecreaseOnLossDecrease { decrease_by: f64 },
    /// Multiply by `increase_by` whenever the loss went up.
    IncreaseOnLossIncrease { increase_by: f64 },
    /// Multiply by `mul_by` when the loss improved by less than `threshold`.
    PreventStagnation { threshold: f64, mul_by: f64 },
    /// Subtract `subtract_by` every `every_n` epochs while the rate stays above it.
    SubEveryNEpochs { subtract_by: f64, every_n: usize },
    /// Oscillate between `min` and `max` on a |sin| wave over `period` epochs.
    Sine { period: usize, min: f64, max: f64 },
    /// Ramp from `max` down to `min`, restarting every `period` epochs.
    Sawtooth { period: usize, min: f64, max: f64 },
}