jiro_nn 0.8.0

Neural network framework with model-building and data-preprocessing features.
Documentation
use serde::{Deserialize, Serialize};

use super::image::Image;
use crate::{linalg::Scalar, vision::image::ImageTrait};

/// Initialization strategies for convolution parameters stored as [`Image`]s.
///
/// The values themselves are produced by [`ConvInitializers::gen_image`].
///
/// Variant order is part of the serialized form for index-based serde formats,
/// so new variants should be appended rather than inserted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConvInitializers {
    /// Every element is set to zero.
    Zeros,
    /// Elements are drawn uniformly between `0.0` and `1.0`.
    Uniform,
    /// Elements are drawn uniformly between `-1.0` and `1.0`.
    UniformSigned,
    /// Glorot (Xavier) uniform: elements are drawn symmetrically around zero, with a
    /// bound that shrinks as the filter grows.
    GlorotUniform,
}

impl ConvInitializers {
    /// Generates an [`Image`] of shape `(nrow, ncol, nchan, nsample)` filled according to
    /// this initializer.
    ///
    /// - [`ConvInitializers::Zeros`]: every element is `0`.
    /// - [`ConvInitializers::Uniform`]: `Image::random_uniform` with bounds `0.0` and `1.0`.
    /// - [`ConvInitializers::UniformSigned`]: `Image::random_uniform` with bounds `-1.0` and `1.0`.
    /// - [`ConvInitializers::GlorotUniform`]: `Image::random_uniform` with bounds `±limit`, where
    ///   `limit = sqrt(6 / (fan_in + fan_out))`, `fan_in = nrow * ncol * nchan` and
    ///   `fan_out = nrow * ncol * nsample`.
    ///
    /// The Glorot fans follow the usual convolution-kernel convention (receptive field times
    /// input / output channels). This assumes `nchan` is the number of input channels and
    /// `nsample` the number of kernels (output channels). NOTE(review): confirm against the
    /// layers that call this.
    pub fn gen_image(&self, nrow: usize, ncol: usize, nchan: usize, nsample: usize) -> Image {
        match self {
            ConvInitializers::Zeros => Image::zeros(nrow, ncol, nchan, nsample),
            ConvInitializers::Uniform => {
                Image::random_uniform(nrow, ncol, nchan, nsample, 0.0, 1.0)
            }
            ConvInitializers::UniformSigned => {
                Image::random_uniform(nrow, ncol, nchan, nsample, -1.0, 1.0)
            }
            ConvInitializers::GlorotUniform => {
                // Previously both fans were just `nrow * ncol`, which ignored the channel
                // dimensions and over-sized the range for multi-channel kernels.
                let receptive_field = nrow * ncol;
                let fan_in = receptive_field * nchan;
                let fan_out = receptive_field * nsample;
                // A zero fan sum only happens when the image has no elements. Clamp it so
                // the limit stays finite instead of passing `±inf` to `random_uniform`.
                let fan_sum = (fan_in + fan_out).max(1) as Scalar;
                let limit = (6.0 / fan_sum).sqrt();
                Image::random_uniform(nrow, ncol, nchan, nsample, -limit, limit)
            }
        }
    }
}