use serde::{Deserialize, Serialize};
use crate::linalg::{Matrix, MatrixTrait, Scalar};
/// Strategy used to fill newly generated matrices and column vectors with
/// initial values; see [`Initializers::gen_matrix`] and
/// [`Initializers::gen_vector`] for how each variant is applied.
///
/// `PartialEq`/`Eq` are derived so callers can compare configured
/// initializers with `==` instead of matching on every variant.
#[derive(Serialize, Debug, Deserialize, Clone, PartialEq, Eq)]
pub enum Initializers {
    /// Entries are produced by `Matrix::zeros`.
    Zeros,
    /// Entries are drawn via `Matrix::random_uniform` with bounds `0.0` and `1.0`.
    Uniform,
    /// Entries are drawn via `Matrix::random_uniform` with bounds `-1.0` and `1.0`.
    UniformSigned,
    /// Glorot (Xavier) uniform: symmetric bounds `±sqrt(6 / fan)`, where `fan`
    /// is derived from the requested shape by the generating method.
    GlorotUniform,
}
impl Initializers {
    /// Builds an `nrow` x `ncol` matrix filled according to this initializer.
    ///
    /// - `Zeros`: `Matrix::zeros(nrow, ncol)`.
    /// - `Uniform`: `Matrix::random_uniform` with bounds `0.0` and `1.0`.
    /// - `UniformSigned`: bounds `-1.0` and `1.0`.
    /// - `GlorotUniform`: symmetric bounds `±sqrt(6 / (nrow + ncol))`.
    pub fn gen_matrix(&self, nrow: usize, ncol: usize) -> Matrix {
        self.gen_shaped(nrow, ncol, nrow + ncol)
    }

    /// Builds an `nrow` x 1 column vector filled according to this initializer.
    ///
    /// Behaves like [`Initializers::gen_matrix`] with `ncol = 1`, except that
    /// `GlorotUniform` uses the bounds `±sqrt(6 / nrow)` (denominator `nrow`,
    /// not `nrow + 1`).
    pub fn gen_vector(&self, nrow: usize) -> Matrix {
        // NOTE(review): the Glorot denominator for vectors is `nrow` alone,
        // preserved from the original implementation; confirm this is the
        // intended fan for column vectors.
        self.gen_shaped(nrow, 1, nrow)
    }

    /// Shared generator behind both public entry points.
    ///
    /// `glorot_fan` is the denominator of the Glorot limit
    /// `sqrt(6 / glorot_fan)`; every other variant ignores it.
    fn gen_shaped(&self, nrow: usize, ncol: usize, glorot_fan: usize) -> Matrix {
        match self {
            Initializers::Zeros => Matrix::zeros(nrow, ncol),
            Initializers::Uniform => Matrix::random_uniform(nrow, ncol, 0.0, 1.0),
            Initializers::UniformSigned => Matrix::random_uniform(nrow, ncol, -1.0, 1.0),
            Initializers::GlorotUniform => {
                // Both callers pass a zero fan only for an empty shape, where
                // there is nothing to sample. Clamp to 1 so the bounds stay
                // finite rather than `±sqrt(6 / 0) = ±inf`, which a uniform
                // sampler may reject.
                let limit = (6.0 / glorot_fan.max(1) as Scalar).sqrt();
                Matrix::random_uniform(nrow, ncol, -limit, limit)
            }
        }
    }
}