concision_neural/model/impls/
impl_model_params_rand.rs1use crate::model::{ModelFeatures, ModelParamsBase};
7
8use cnc::init::{self, Initialize};
9use cnc::params::ParamsBase;
10use cnc::rand_distr;
11
12use ndarray::DataOwned;
13use num_traits::{Float, FromPrimitive};
14use rand_distr::uniform::{SampleUniform, Uniform};
15use rand_distr::{Distribution, StandardNormal};
16
17impl<A, S> ModelParamsBase<S>
18where
19 S: DataOwned<Elem = A>,
20{
21 pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
24 where
25 G: Fn((usize, usize)) -> Ds,
26 Ds: Clone + Distribution<A>,
27 S: DataOwned,
28 {
29 let input = ParamsBase::rand(features.dim_input(), distr(features.dim_input()));
30 let hidden = (0..features.layers())
31 .map(|_| ParamsBase::rand(features.dim_hidden(), distr(features.dim_hidden())))
32 .collect::<Vec<_>>();
33
34 let output = ParamsBase::rand(features.dim_output(), distr(features.dim_output()));
35
36 Self::new(input, hidden, output)
37 }
38 pub fn glorot_normal(features: ModelFeatures) -> Self
40 where
41 A: Float + FromPrimitive,
42 StandardNormal: Distribution<A>,
43 {
44 Self::init_rand(features, |(rows, cols)| {
45 cnc::init::XavierNormal::new(rows, cols)
46 })
47 }
48 pub fn glorot_uniform(features: ModelFeatures) -> Self
50 where
51 A: Clone + Float + FromPrimitive + SampleUniform,
52 <S::Elem as SampleUniform>::Sampler: Clone,
53 Uniform<S::Elem>: Distribution<S::Elem>,
54 {
55 Self::init_rand(features, |(rows, cols)| {
56 init::XavierUniform::new(rows, cols).expect("failed to create distribution")
57 })
58 }
59}