concision_neural/model/impls/
impl_model_params_rand.rs

1/*
2    appellation: impl_model_params_rand <module>
3    authors: @FL03
4*/
5
6use crate::model::{ModelFeatures, ModelParamsBase};
7
8use cnc::init::{self, Initialize};
9use cnc::params::ParamsBase;
10use cnc::rand_distr;
11
12use ndarray::DataOwned;
13use num_traits::{Float, FromPrimitive};
14use rand_distr::uniform::{SampleUniform, Uniform};
15use rand_distr::{Distribution, StandardNormal};
16
17impl<A, S> ModelParamsBase<S>
18where
19    S: DataOwned<Elem = A>,
20{
21    /// returns a new instance of the model initialized with the given features and random
22    /// distribution
23    pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
24    where
25        G: Fn((usize, usize)) -> Ds,
26        Ds: Clone + Distribution<A>,
27        S: DataOwned,
28    {
29        let input = ParamsBase::rand(features.dim_input(), distr(features.dim_input()));
30        let hidden = (0..features.layers())
31            .map(|_| ParamsBase::rand(features.dim_hidden(), distr(features.dim_hidden())))
32            .collect::<Vec<_>>();
33
34        let output = ParamsBase::rand(features.dim_output(), distr(features.dim_output()));
35
36        Self::new(input, hidden, output)
37    }
38    /// initialize the model parameters using a glorot normal distribution
39    pub fn glorot_normal(features: ModelFeatures) -> Self
40    where
41        A: Float + FromPrimitive,
42        StandardNormal: Distribution<A>,
43    {
44        Self::init_rand(features, |(rows, cols)| {
45            cnc::init::XavierNormal::new(rows, cols)
46        })
47    }
48    /// initialize the model parameters using a glorot uniform distribution
49    pub fn glorot_uniform(features: ModelFeatures) -> Self
50    where
51        A: Clone + Float + FromPrimitive + SampleUniform,
52        <S::Elem as SampleUniform>::Sampler: Clone,
53        Uniform<S::Elem>: Distribution<S::Elem>,
54    {
55        Self::init_rand(features, |(rows, cols)| {
56            init::XavierUniform::new(rows, cols).expect("failed to create distribution")
57        })
58    }
59}