concision_init/distr/
xavier.rs1use rand_distr::uniform::{SampleUniform, Uniform};
12use rand_distr::{Distribution, StandardNormal};
13
/// Xavier (Glorot) normal initializer.
///
/// Holds the standard deviation of a zero-mean normal distribution. When built with
/// `XavierNormal::new(inputs, outputs)` the standard deviation is
/// `sqrt(2 / (inputs + outputs))`.
///
/// Trait bounds such as `StandardNormal: Distribution<T>` live on the `impl` blocks that
/// need them rather than on the struct. This lets the type itself be named, copied,
/// compared and hashed for any `T`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct XavierNormal<T> {
    /// Standard deviation of the zero-mean normal distribution.
    std: T,
}
27
28pub struct XavierUniform<T>
31where
32 T: SampleUniform,
33{
34 distr: Uniform<T>,
35}
36
37mod impl_normal {
42 use super::XavierNormal;
43 use num_traits::{Float, FromPrimitive};
44 use rand::RngCore;
45 use rand_distr::{Distribution, Normal, StandardNormal};
46
47 fn std_dev<T>(inputs: usize, outputs: usize) -> T
48 where
49 T: FromPrimitive + Float,
50 {
51 let numerator = T::from_usize(2).unwrap();
52 let denominator = T::from_usize(inputs + outputs).unwrap();
53 (numerator / denominator).sqrt()
54 }
55
56 impl<T> XavierNormal<T>
57 where
58 T: Float,
59 StandardNormal: Distribution<T>,
60 {
61 pub fn new(inputs: usize, outputs: usize) -> Self
62 where
63 T: FromPrimitive,
64 {
65 Self {
66 std: std_dev(inputs, outputs),
67 }
68 }
69 pub fn distr(&self) -> crate::Result<Normal<T>> {
72 Normal::new(T::zero(), self.std_dev()).map_err(Into::into)
73 }
74 pub const fn std_dev(&self) -> T {
76 self.std
77 }
78 }
79
80 impl<T> Distribution<T> for XavierNormal<T>
81 where
82 T: Float,
83 StandardNormal: Distribution<T>,
84 {
85 fn sample<R>(&self, rng: &mut R) -> T
86 where
87 R: RngCore + ?Sized,
88 {
89 self.distr().unwrap().sample(rng)
90 }
91 }
92}
93
94mod impl_uniform {
95 use super::XavierUniform;
96 use num_traits::{Float, FromPrimitive};
97 use rand::RngCore;
98 use rand_distr::Distribution;
99 use rand_distr::uniform::{SampleUniform, Uniform};
100
101 fn boundary<U>(inputs: usize, outputs: usize) -> U
102 where
103 U: FromPrimitive + Float,
104 {
105 let numer = <U>::from_usize(6).unwrap();
106 let denom = <U>::from_usize(inputs + outputs).unwrap();
107 (numer / denom).sqrt()
108 }
109
110 impl<T> XavierUniform<T>
111 where
112 T: SampleUniform,
113 {
114 pub fn new(inputs: usize, outputs: usize) -> crate::Result<Self>
115 where
116 T: Float + FromPrimitive,
117 {
118 let limit = boundary::<T>(inputs, outputs);
120 let distr = Uniform::new(-limit, limit)?;
122 Ok(Self { distr })
123 }
124 pub(crate) const fn distr(&self) -> &Uniform<T> {
126 &self.distr
127 }
128 }
129
130 impl<T> Distribution<T> for XavierUniform<T>
131 where
132 T: Float + SampleUniform,
133 {
134 fn sample<R>(&self, rng: &mut R) -> T
135 where
136 R: RngCore + ?Sized,
137 {
138 self.distr().sample(rng)
139 }
140 }
141
142 impl<T> Clone for XavierUniform<T>
143 where
144 T: Clone + SampleUniform,
145 <T as SampleUniform>::Sampler: Clone,
146 {
147 fn clone(&self) -> Self {
148 Self {
149 distr: self.distr.clone(),
150 }
151 }
152 }
153
154 impl<T> Copy for XavierUniform<T>
155 where
156 T: Copy + SampleUniform,
157 <T as SampleUniform>::Sampler: Copy,
158 {
159 }
160
161 impl<T> Eq for XavierUniform<T>
162 where
163 T: Eq + SampleUniform,
164 <T as SampleUniform>::Sampler: Eq,
165 {
166 }
167
168 impl<T> PartialEq for XavierUniform<T>
169 where
170 T: PartialEq + SampleUniform,
171 <T as SampleUniform>::Sampler: PartialEq,
172 {
173 fn eq(&self, other: &Self) -> bool {
174 self.distr == other.distr
175 }
176 }
177}