jiro_nn 0.7.1

Neural Networks framework with model building & data preprocessing features.
Documentation
use serde::{Deserialize, Serialize};

use crate::linalg::Scalar;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InverseTimeDecay {
    pub initial_learning_rate: Scalar,
    pub decay_steps: Scalar,
    pub decay_rate: Scalar,
    #[serde(default)]
    pub staircase: bool,
}

impl InverseTimeDecay {
    pub fn new(
        initial_learning_rate: Scalar,
        decay_steps: Scalar,
        decay_rate: Scalar,
        staircase: bool,
    ) -> Self {
        Self {
            initial_learning_rate,
            decay_steps,
            decay_rate,
            staircase,
        }
    }

    pub fn get_learning_rate(&self, epoch: usize) -> Scalar {
        let mut learning_rate = self.initial_learning_rate
            / (1. + self.decay_rate * (epoch as Scalar / self.decay_steps));
        if self.staircase {
            learning_rate = self.initial_learning_rate
                / (1. + self.decay_rate * (epoch as Scalar / self.decay_steps).floor());
        }
        learning_rate
    }
}