use super::Layer;
use intricate_macros::FromForAllUnnamedVariants;
use rand::prelude::*;
use rand_distr::Normal;
use rand_distr::Uniform;
use savefile::{Introspect, Serialize};
use savefile_derive::Savefile;
use std::ops::Range;
/// Behaviour shared by every parameter initializer in this module.
///
/// Implementors only have to provide [`InitializerTrait::initialize_0d`]. The
/// higher-rank methods come with default implementations that build nested
/// `Vec`s by repeatedly calling the next lower rank.
pub trait InitializerTrait: std::fmt::Debug + Serialize + Introspect {
    /// Produces a single value for a parameter belonging to `layer`.
    fn initialize_0d<'a>(&self, layer: &dyn Layer<'a>) -> f32;

    /// Produces `count` values; by default one `initialize_0d` call per element.
    fn initialize_1d<'a>(&self, count: usize, layer: &dyn Layer<'a>) -> Vec<f32> {
        let mut values = Vec::with_capacity(count);
        values.resize_with(count, || self.initialize_0d(layer));
        values
    }

    /// Produces `shape.0` rows of `shape.1` values each; by default one
    /// `initialize_1d` call per row.
    fn initialize_2d<'a>(&self, shape: (usize, usize), layer: &dyn Layer<'a>) -> Vec<Vec<f32>> {
        let (rows, columns) = shape;
        let mut matrix = Vec::with_capacity(rows);
        matrix.resize_with(rows, || self.initialize_1d(columns, layer));
        matrix
    }

    /// Produces `shape.0` matrices of `shape.1` x `shape.2` values; by default
    /// one `initialize_2d` call per matrix.
    fn initialize_3d<'a>(
        &self,
        shape: (usize, usize, usize),
        layer: &dyn Layer<'a>,
    ) -> Vec<Vec<Vec<f32>>> {
        let (depth, rows, columns) = shape;
        let mut volume = Vec::with_capacity(depth);
        volume.resize_with(depth, || self.initialize_2d((rows, columns), layer));
        volume
    }
}
/// An initializer that hands out the same `constant` for every value it is
/// asked to produce.
#[derive(Debug, Clone, Savefile)]
pub struct ConstantInitializer {
/// The value returned for every initialized parameter.
pub constant: f32,
}
impl ConstantInitializer {
pub fn new(constant: f32) -> Self {
ConstantInitializer { constant }
}
}
impl InitializerTrait for ConstantInitializer {
    /// Returns the stored constant; the layer is not consulted.
    fn initialize_0d<'a>(&self, _layer: &dyn Layer<'a>) -> f32 {
        let Self { constant } = self;
        *constant
    }
}
/// An initializer that draws each value from the thread-local RNG, restricted
/// to `limit_interval`.
#[derive(Debug, Clone, Savefile)]
pub struct LimitedRandomInitializer {
/// The half-open range (`start..end`) that generated values are drawn from.
pub limit_interval: Range<f32>,
}
impl LimitedRandomInitializer {
pub fn new(limit_interval: Range<f32>) -> Self {
LimitedRandomInitializer { limit_interval }
}
}
impl InitializerTrait for LimitedRandomInitializer {
    /// Draws one value from `limit_interval` using the thread-local RNG; the
    /// layer is not consulted.
    ///
    /// NOTE(review): `gen_range` is expected to panic on an empty interval —
    /// confirm against the `rand` version in use.
    fn initialize_0d<'a>(&self, _layer: &dyn Layer<'a>) -> f32 {
        // `Range<f32>` is not `Copy`, so a fresh copy is handed to `gen_range`.
        let interval = self.limit_interval.clone();
        thread_rng().gen_range(interval)
    }
}
/// An initializer that samples values from a normal distribution with the
/// given `mean` and `standard_deviation`.
#[derive(Debug, Clone, Savefile)]
pub struct NormalRandomInitializer {
/// Mean of the normal distribution that is sampled.
pub mean: f32,
/// Standard deviation of the normal distribution that is sampled.
pub standard_deviation: f32,
}
impl NormalRandomInitializer {
pub fn new(mean: f32, std_dev: f32) -> Self {
NormalRandomInitializer {
mean,
standard_deviation: std_dev,
}
}
}
impl InitializerTrait for NormalRandomInitializer {
fn initialize_0d<'a>(&self, _layer: &dyn Layer<'a>) -> f32 {
let distribution = Normal::new(self.mean, self.standard_deviation)
.expect("Unable to create Normal distribution for the NormaRandomInitalizer");
let mut rng = thread_rng();
distribution.sample(&mut rng)
}
fn initialize_1d<'a>(&self, count: usize, _layer: &dyn Layer<'a>) -> Vec<f32> {
let distribution = Normal::new(self.mean, self.standard_deviation)
.expect("Unable to create Normal distribution for the NormaRandomInitalizer");
let mut rng = thread_rng();
(0..count).map(|_| distribution.sample(&mut rng)).collect()
}
fn initialize_2d<'a>(&self, shape: (usize, usize), _layer: &dyn Layer<'a>) -> Vec<Vec<f32>> {
let distribution = Normal::new(self.mean, self.standard_deviation)
.expect("Unable to create Normal distribution for the NormaRandomInitalizer");
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
}
fn initialize_3d<'a>(
&self,
shape: (usize, usize, usize),
_layer: &dyn Layer<'a>,
) -> Vec<Vec<Vec<f32>>> {
let distribution = Normal::new(self.mean, self.standard_deviation)
.expect("Unable to create Normal distribution for the NormaRandomInitalizer");
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| {
(0..shape.2)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
})
.collect()
}
}
/// An initializer that samples values from a uniform distribution over
/// `interval`.
#[derive(Debug, Clone, Savefile)]
pub struct UniformRandomInitializer {
/// The range whose `start` and `end` bound the uniform distribution.
pub interval: Range<f32>,
}
impl UniformRandomInitializer {
pub fn new(interval: Range<f32>) -> Self {
UniformRandomInitializer { interval }
}
}
impl InitializerTrait for UniformRandomInitializer {
fn initialize_0d<'a>(&self, _layer: &dyn Layer<'a>) -> f32 {
let distribution = Uniform::new(self.interval.start, self.interval.end);
let mut rng = thread_rng();
distribution.sample(&mut rng)
}
fn initialize_1d<'a>(&self, count: usize, _layer: &dyn Layer<'a>) -> Vec<f32> {
let distribution = Uniform::new(self.interval.start, self.interval.end);
let mut rng = thread_rng();
(0..count).map(|_| distribution.sample(&mut rng)).collect()
}
fn initialize_2d<'a>(&self, shape: (usize, usize), _layer: &dyn Layer<'a>) -> Vec<Vec<f32>> {
let distribution = Uniform::new(self.interval.start, self.interval.end);
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
}
fn initialize_3d<'a>(
&self,
shape: (usize, usize, usize),
_layer: &dyn Layer<'a>,
) -> Vec<Vec<Vec<f32>>> {
let distribution = Uniform::new(self.interval.start, self.interval.end);
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| {
(0..shape.2)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
})
.collect()
}
}
/// An initializer that samples uniformly from `[-limit, limit)`, where
/// `limit = sqrt(6 / (fan_in + fan_out))` and the fan values are the layer's
/// input and output amounts. It carries no configuration of its own.
#[derive(Debug, Clone, Savefile)]
pub struct GlorotUniformInitializer();
impl GlorotUniformInitializer {
pub fn new() -> Self {
GlorotUniformInitializer()
}
}
impl InitializerTrait for GlorotUniformInitializer {
fn initialize_0d<'a>(&self, layer: &dyn Layer<'a>) -> f32 {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let limit = (6.0 / (fan_in + fan_out)).sqrt();
let distribution = Uniform::new(-limit, limit);
let mut rng = thread_rng();
distribution.sample(&mut rng) as f32
}
fn initialize_1d<'a>(&self, count: usize, layer: &dyn Layer<'a>) -> Vec<f32> {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let limit = (6.0 / (fan_in + fan_out)).sqrt();
let distribution = Uniform::new(-limit, limit);
let mut rng = thread_rng();
(0..count).map(|_| distribution.sample(&mut rng)).collect()
}
fn initialize_2d<'a>(&self, shape: (usize, usize), layer: &dyn Layer<'a>) -> Vec<Vec<f32>> {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let limit = (6.0 / (fan_in + fan_out)).sqrt();
let distribution = Uniform::new(-limit, limit);
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
}
fn initialize_3d<'a>(
&self,
shape: (usize, usize, usize),
layer: &dyn Layer<'a>,
) -> Vec<Vec<Vec<f32>>> {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let limit = (6.0 / (fan_in + fan_out)).sqrt();
let distribution = Uniform::new(-limit, limit);
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| {
(0..shape.2)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
})
.collect()
}
}
/// An initializer that samples from a normal distribution with mean `0` and
/// standard deviation `sqrt(2 / (fan_in + fan_out))`, where the fan values are
/// the layer's input and output amounts. It carries no configuration of its own.
#[derive(Debug, Clone, Savefile)]
pub struct GlorotNormalInitializer();
impl GlorotNormalInitializer {
pub fn new() -> Self {
GlorotNormalInitializer()
}
}
impl InitializerTrait for GlorotNormalInitializer {
fn initialize_0d<'a>(&self, layer: &dyn Layer<'a>) -> f32 {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let mean = 0.0f32;
let std_dev = (2.0f32 / (fan_in + fan_out)).sqrt();
let distribution = Normal::new(mean, std_dev)
.expect("Unable to create Normal distribution for the GlorotNormalInitializer");
let mut rng = thread_rng();
distribution.sample(&mut rng) as f32
}
fn initialize_1d<'a>(&self, count: usize, layer: &dyn Layer<'a>) -> Vec<f32> {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let mean = 0.0f32;
let std_dev = (2.0f32 / (fan_in + fan_out)).sqrt();
let distribution = Normal::new(mean, std_dev)
.expect("Unable to create Normal distribution for the GlorotNormalInitializer");
let mut rng = thread_rng();
(0..count).map(|_| distribution.sample(&mut rng)).collect()
}
fn initialize_2d<'a>(&self, shape: (usize, usize), layer: &dyn Layer<'a>) -> Vec<Vec<f32>> {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let mean = 0.0f32;
let std_dev = (2.0f32 / (fan_in + fan_out)).sqrt();
let distribution = Normal::new(mean, std_dev)
.expect("Unable to create Normal distribution for the GlorotNormalInitializer");
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
}
fn initialize_3d<'a>(
&self,
shape: (usize, usize, usize),
layer: &dyn Layer<'a>,
) -> Vec<Vec<Vec<f32>>> {
let fan_in = layer.get_inputs_amount() as f32;
let fan_out = layer.get_outputs_amount() as f32;
let mean = 0.0f32;
let std_dev = (2.0f32 / (fan_in + fan_out)).sqrt();
let distribution = Normal::new(mean, std_dev)
.expect("Unable to create Normal distribution for the GlorotNormalInitializer");
let mut rng = thread_rng();
(0..shape.0)
.map(|_| {
(0..shape.1)
.map(|_| {
(0..shape.2)
.map(|_| distribution.sample(&mut rng))
.collect()
})
.collect()
})
.collect()
}
}
/// A single type that can hold any of the initializers defined in this module.
///
/// Each variant wraps the initializer struct of the matching name: constant,
/// limited random, uniform random, normal random, Glorot normal and Glorot
/// uniform.
#[derive(Debug, Clone, Savefile, FromForAllUnnamedVariants)]
pub enum Initializer {
Constant(ConstantInitializer),
LimitedRandom(LimitedRandomInitializer),
UniformRandom(UniformRandomInitializer),
NormalRandom(NormalRandomInitializer),
GlorotNormal(GlorotNormalInitializer),
GlorotUniform(GlorotUniformInitializer),
}
/// Forwards every initialization request to the wrapped initializer.
///
/// All four methods are forwarded explicitly. If only `initialize_0d` were
/// forwarded, the 1d/2d/3d trait defaults would call it once per element and
/// bypass the bulk implementations the wrapped initializers provide.
impl InitializerTrait for Initializer {
    /// Produces a single value using the wrapped initializer.
    fn initialize_0d<'a>(&self, layer: &dyn Layer<'a>) -> f32 {
        match self {
            Initializer::Constant(i) => i.initialize_0d(layer),
            Initializer::LimitedRandom(i) => i.initialize_0d(layer),
            Initializer::UniformRandom(i) => i.initialize_0d(layer),
            Initializer::NormalRandom(i) => i.initialize_0d(layer),
            Initializer::GlorotNormal(i) => i.initialize_0d(layer),
            Initializer::GlorotUniform(i) => i.initialize_0d(layer),
        }
    }

    /// Produces `count` values using the wrapped initializer.
    fn initialize_1d<'a>(&self, count: usize, layer: &dyn Layer<'a>) -> Vec<f32> {
        match self {
            Initializer::Constant(i) => i.initialize_1d(count, layer),
            Initializer::LimitedRandom(i) => i.initialize_1d(count, layer),
            Initializer::UniformRandom(i) => i.initialize_1d(count, layer),
            Initializer::NormalRandom(i) => i.initialize_1d(count, layer),
            Initializer::GlorotNormal(i) => i.initialize_1d(count, layer),
            Initializer::GlorotUniform(i) => i.initialize_1d(count, layer),
        }
    }

    /// Produces a `shape.0` x `shape.1` matrix using the wrapped initializer.
    fn initialize_2d<'a>(&self, shape: (usize, usize), layer: &dyn Layer<'a>) -> Vec<Vec<f32>> {
        match self {
            Initializer::Constant(i) => i.initialize_2d(shape, layer),
            Initializer::LimitedRandom(i) => i.initialize_2d(shape, layer),
            Initializer::UniformRandom(i) => i.initialize_2d(shape, layer),
            Initializer::NormalRandom(i) => i.initialize_2d(shape, layer),
            Initializer::GlorotNormal(i) => i.initialize_2d(shape, layer),
            Initializer::GlorotUniform(i) => i.initialize_2d(shape, layer),
        }
    }

    /// Produces a `shape.0` x `shape.1` x `shape.2` volume using the wrapped
    /// initializer.
    fn initialize_3d<'a>(
        &self,
        shape: (usize, usize, usize),
        layer: &dyn Layer<'a>,
    ) -> Vec<Vec<Vec<f32>>> {
        match self {
            Initializer::Constant(i) => i.initialize_3d(shape, layer),
            Initializer::LimitedRandom(i) => i.initialize_3d(shape, layer),
            Initializer::UniformRandom(i) => i.initialize_3d(shape, layer),
            Initializer::NormalRandom(i) => i.initialize_3d(shape, layer),
            Initializer::GlorotNormal(i) => i.initialize_3d(shape, layer),
            Initializer::GlorotUniform(i) => i.initialize_3d(shape, layer),
        }
    }
}
impl Default for Initializer {
fn default() -> Self {
Self::GlorotUniform(GlorotUniformInitializer::new())
}
}