jiro_nn 0.5.0

Neural Networks framework with model specification & data preprocessing features.
Documentation
use serde::{Deserialize, Serialize};

use crate::linalg::Scalar;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PiecewiseConstant {
    pub boundaries: Vec<usize>,
    pub values: Vec<Scalar>,
}

impl PiecewiseConstant {
    pub fn new(boundaries: Vec<usize>, values: Vec<Scalar>) -> Self {
        Self { boundaries, values }
    }

    pub fn get_learning_rate(&self, epoch: usize) -> Scalar {
        let mut learning_rate = self.values[0];
        for (i, boundary) in self.boundaries.iter().enumerate() {
            if epoch >= *boundary {
                learning_rate = self.values[i + 1];
            }
        }
        learning_rate
    }
}