concision_neural/model/params/
impl_model_params_rand.rs1use crate::model::params::{DeepParamsBase, ShallowParamsBase};
7
8use crate::model::ModelFeatures;
9use cnc::init::{self, Initialize};
10use cnc::params::ParamsBase;
11use ndarray::{DataOwned, Ix2};
12use num_traits::{Float, FromPrimitive};
13
14use rand_distr::uniform::{SampleUniform, Uniform};
15use rand_distr::{Distribution, StandardNormal};
16
17impl<A, S> ShallowParamsBase<S, Ix2>
18where
19 S: DataOwned<Elem = A>,
20{
21 pub fn init(self) -> Self
23 where
24 A: Float + num_traits::FromPrimitive + rand_distr::uniform::SampleUniform,
25 rand_distr::StandardNormal: rand_distr::Distribution<A>,
26 {
27 let hidden = ParamsBase::glorot_normal(self.hidden().dim());
29 let input = ParamsBase::glorot_normal(self.input().dim());
31 let output = ParamsBase::glorot_normal(self.output().dim());
33 Self {
35 hidden,
36 input,
37 output,
38 }
39 }
40 pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
43 where
44 G: Fn((usize, usize)) -> Ds,
45 Ds: Clone + Distribution<A>,
46 S: DataOwned,
47 {
48 Self {
49 input: ParamsBase::rand(features.dim_input(), distr(features.dim_input())),
50 hidden: ParamsBase::rand(features.dim_hidden(), distr(features.dim_hidden())),
51 output: ParamsBase::rand(features.dim_output(), distr(features.dim_output())),
52 }
53 }
54 pub fn glorot_normal(features: ModelFeatures) -> Self
56 where
57 A: Float + FromPrimitive,
58 StandardNormal: Distribution<A>,
59 {
60 Self::init_rand(features, |(rows, cols)| {
61 cnc::init::XavierNormal::new(rows, cols)
62 })
63 }
64 pub fn glorot_uniform(features: ModelFeatures) -> Self
66 where
67 A: Float + FromPrimitive + SampleUniform,
68 <A as SampleUniform>::Sampler: Clone,
69 Uniform<A>: Distribution<A>,
70 {
71 Self::init_rand(features, |(rows, cols)| {
72 init::XavierUniform::new(rows, cols).expect("failed to create distribution")
73 })
74 }
75}
76
77impl<A, S> DeepParamsBase<S, Ix2>
78where
79 S: DataOwned<Elem = A>,
80{
81 pub fn init_rand<G, Ds>(features: ModelFeatures, distr: G) -> Self
84 where
85 G: Fn((usize, usize)) -> Ds,
86 Ds: Clone + Distribution<A>,
87 S: DataOwned,
88 {
89 let input = ParamsBase::rand(features.dim_input(), distr(features.dim_input()));
90 let hidden = (0..features.layers())
91 .map(|_| ParamsBase::rand(features.dim_hidden(), distr(features.dim_hidden())))
92 .collect::<Vec<_>>();
93
94 let output = ParamsBase::rand(features.dim_output(), distr(features.dim_output()));
95
96 Self::new(input, hidden, output)
97 }
98 pub fn glorot_normal(features: ModelFeatures) -> Self
100 where
101 A: Float + FromPrimitive,
102 StandardNormal: Distribution<A>,
103 {
104 Self::init_rand(features, |(rows, cols)| {
105 cnc::init::XavierNormal::new(rows, cols)
106 })
107 }
108 pub fn glorot_uniform(features: ModelFeatures) -> Self
110 where
111 A: Clone + Float + FromPrimitive + SampleUniform,
112 <S::Elem as SampleUniform>::Sampler: Clone,
113 Uniform<S::Elem>: Distribution<S::Elem>,
114 {
115 Self::init_rand(features, |(rows, cols)| {
116 init::XavierUniform::new(rows, cols).expect("failed to create distribution")
117 })
118 }
119}