use crate::models::{DeepParamsBase, ShallowParamsBase};
use crate::ModelFeatures;
use concision_init::distr as init;
use concision_init::{NdRandom, rand_distr};
use concision_params::ParamsBase;
use ndarray::{DataOwned, Ix2};
use num_traits::{Float, FromPrimitive};
use rand_distr::uniform::{SampleUniform, Uniform};
use rand_distr::{Distribution, StandardNormal};
impl<A, S> ShallowParamsBase<S, Ix2, A>
where
    S: DataOwned<Elem = A>,
{
    /// Creates a new set of shallow parameters, drawing every layer from a caller-chosen
    /// distribution.
    ///
    /// `distr` is invoked once per layer (input, hidden, output — in that order) with the
    /// `(rows, cols)` shape reported by `features` for that layer. The distribution it returns
    /// is passed to `ParamsBase::rand` together with the same shape.
    pub fn rand_with<G, Ds>(features: ModelFeatures, distr: G) -> Self
    where
        G: Fn((usize, usize)) -> Ds,
        Ds: Clone + Distribution<A>,
    {
        Self {
            input: ParamsBase::rand(features.dim_input(), distr(features.dim_input())),
            hidden: ParamsBase::rand(features.dim_hidden(), distr(features.dim_hidden())),
            output: ParamsBase::rand(features.dim_output(), distr(features.dim_output())),
        }
    }

    /// Creates shallow parameters with every layer sampled from `init::XavierNormal`.
    ///
    /// Each layer's distribution is built from that layer's `(rows, cols)` shape.
    pub fn glorot_normal(features: ModelFeatures) -> Self
    where
        A: Float + FromPrimitive,
        StandardNormal: Distribution<A>,
    {
        Self::rand_with(features, |(rows, cols)| init::XavierNormal::new(rows, cols))
    }

    /// Creates shallow parameters with every layer sampled from `init::XavierUniform`.
    ///
    /// Each layer's distribution is built from that layer's `(rows, cols)` shape.
    ///
    /// # Panics
    ///
    /// Panics with `"failed to create distribution"` if `init::XavierUniform::new` returns an
    /// error for any layer's shape.
    pub fn glorot_uniform(features: ModelFeatures) -> Self
    where
        A: Float + FromPrimitive + SampleUniform,
        <A as SampleUniform>::Sampler: Clone,
        Uniform<A>: Distribution<A>,
    {
        Self::rand_with(features, |(rows, cols)| {
            init::XavierUniform::new(rows, cols).expect("failed to create distribution")
        })
    }

    /// Re-initializes every layer with `ParamsBase::glorot_normal`, keeping each existing
    /// layer's dimensions.
    ///
    /// Only the current layers' `dim()` values are read. Their contents are discarded, and
    /// `self` is consumed.
    pub fn init(self) -> Self
    where
        // NOTE(review): the `SampleUniform` bound is not needed by anything visible here; it is
        // kept in case `ParamsBase::glorot_normal` requires it — confirm before relaxing.
        A: Float + FromPrimitive + SampleUniform,
        StandardNormal: Distribution<A>,
    {
        let hidden = ParamsBase::glorot_normal(self.hidden().dim());
        let input = ParamsBase::glorot_normal(self.input().dim());
        let output = ParamsBase::glorot_normal(self.output().dim());
        Self {
            input,
            hidden,
            output,
        }
    }
}
impl<A, S> DeepParamsBase<S, Ix2, A>
where
    S: DataOwned<Elem = A>,
{
    /// Creates deep parameters, drawing every layer from a caller-chosen distribution.
    ///
    /// The result has:
    /// - one input layer shaped by `features.dim_input()`;
    /// - `features.layers()` hidden layers, all shaped by `features.dim_hidden()`;
    /// - one output layer shaped by `features.dim_output()`.
    ///
    /// `distr` is invoked once per layer with that layer's `(rows, cols)` shape, so each hidden
    /// layer gets a freshly built distribution. The result is passed to `ParamsBase::rand`
    /// alongside the same shape, and the layers are assembled with `Self::new`.
    pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
    where
        G: Fn((usize, usize)) -> Ds,
        Ds: Clone + Distribution<A>,
    {
        let input = ParamsBase::rand(features.dim_input(), distr(features.dim_input()));
        // The hidden shape is the same for every hidden layer, so look it up once.
        let hidden_dim = features.dim_hidden();
        let hidden = (0..features.layers())
            .map(|_| ParamsBase::rand(hidden_dim, distr(hidden_dim)))
            .collect::<Vec<_>>();
        let output = ParamsBase::rand(features.dim_output(), distr(features.dim_output()));
        Self::new(input, hidden, output)
    }

    /// Creates deep parameters with every layer sampled from `init::XavierNormal`.
    ///
    /// Each layer's distribution is built from that layer's `(rows, cols)` shape.
    pub fn glorot_normal(features: ModelFeatures) -> Self
    where
        A: Float + FromPrimitive,
        StandardNormal: Distribution<A>,
    {
        Self::init_rand(features, |(rows, cols)| init::XavierNormal::new(rows, cols))
    }

    /// Creates deep parameters with every layer sampled from `init::XavierUniform`.
    ///
    /// Each layer's distribution is built from that layer's `(rows, cols)` shape.
    ///
    /// # Panics
    ///
    /// Panics with `"failed to create distribution"` if `init::XavierUniform::new` returns an
    /// error for any layer's shape.
    pub fn glorot_uniform(features: ModelFeatures) -> Self
    where
        // `S::Elem == A` is fixed by the impl header, so the bounds are written in terms of `A`
        // to match the shallow counterpart. `Clone` is already implied by `Float`.
        A: Float + FromPrimitive + SampleUniform,
        <A as SampleUniform>::Sampler: Clone,
        Uniform<A>: Distribution<A>,
    {
        Self::init_rand(features, |(rows, cols)| {
            init::XavierUniform::new(rows, cols).expect("failed to create distribution")
        })
    }
}