concision_init/distr/
xavier.rs1use rand_distr::uniform::{SampleUniform, Uniform};
12use rand_distr::{Distribution, StandardNormal};
13
14#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
17pub struct XavierNormal<T>
18where
19 StandardNormal: Distribution<T>,
20{
21 std: T,
22}
23
24pub struct XavierUniform<T>
27where
28 T: SampleUniform,
29{
30 distr: Uniform<T>,
31}
32
33mod impl_normal {
38 use super::XavierNormal;
39 use num_traits::{Float, FromPrimitive};
40 use rand::RngCore;
41 use rand_distr::{Distribution, Normal, StandardNormal};
42
43 fn std_dev<T>(inputs: usize, outputs: usize) -> T
44 where
45 T: FromPrimitive + Float,
46 {
47 let numerator = T::from_usize(2).unwrap();
48 let denominator = T::from_usize(inputs + outputs).unwrap();
49 (numerator / denominator).sqrt()
50 }
51
52 impl<T> XavierNormal<T>
53 where
54 T: Float,
55 StandardNormal: Distribution<T>,
56 {
57 pub fn new(inputs: usize, outputs: usize) -> Self
58 where
59 T: FromPrimitive,
60 {
61 Self {
62 std: std_dev(inputs, outputs),
63 }
64 }
65 pub fn distr(&self) -> crate::Result<Normal<T>> {
68 Normal::new(T::zero(), self.std_dev()).map_err(Into::into)
69 }
70 pub const fn std_dev(&self) -> T {
72 self.std
73 }
74 }
75
76 impl<T> Distribution<T> for XavierNormal<T>
77 where
78 T: Float,
79 StandardNormal: Distribution<T>,
80 {
81 fn sample<R>(&self, rng: &mut R) -> T
82 where
83 R: RngCore + ?Sized,
84 {
85 self.distr().unwrap().sample(rng)
86 }
87 }
88}
89
90mod impl_uniform {
91 use super::XavierUniform;
92 use num_traits::{Float, FromPrimitive};
93 use rand::RngCore;
94 use rand_distr::Distribution;
95 use rand_distr::uniform::{SampleUniform, Uniform};
96
97 fn boundary<U>(inputs: usize, outputs: usize) -> U
98 where
99 U: FromPrimitive + Float,
100 {
101 let numer = <U>::from_usize(6).unwrap();
102 let denom = <U>::from_usize(inputs + outputs).unwrap();
103 (numer / denom).sqrt()
104 }
105
106 impl<T> XavierUniform<T>
107 where
108 T: SampleUniform,
109 {
110 pub fn new(inputs: usize, outputs: usize) -> crate::Result<Self>
111 where
112 T: Float + FromPrimitive,
113 {
114 let limit = boundary::<T>(inputs, outputs);
116 let distr = Uniform::new(-limit, limit)?;
118 Ok(Self { distr })
119 }
120 pub(crate) const fn distr(&self) -> &Uniform<T> {
122 &self.distr
123 }
124 }
125
126 impl<T> Distribution<T> for XavierUniform<T>
127 where
128 T: Float + SampleUniform,
129 {
130 fn sample<R>(&self, rng: &mut R) -> T
131 where
132 R: RngCore + ?Sized,
133 {
134 self.distr().sample(rng)
135 }
136 }
137
138 impl<T> Clone for XavierUniform<T>
139 where
140 T: Clone + SampleUniform,
141 <T as SampleUniform>::Sampler: Clone,
142 {
143 fn clone(&self) -> Self {
144 Self {
145 distr: self.distr.clone(),
146 }
147 }
148 }
149
150 impl<T> Copy for XavierUniform<T>
151 where
152 T: Copy + SampleUniform,
153 <T as SampleUniform>::Sampler: Copy,
154 {
155 }
156
157 impl<T> Eq for XavierUniform<T>
158 where
159 T: Eq + SampleUniform,
160 <T as SampleUniform>::Sampler: Eq,
161 {
162 }
163
164 impl<T> PartialEq for XavierUniform<T>
165 where
166 T: PartialEq + SampleUniform,
167 <T as SampleUniform>::Sampler: PartialEq,
168 {
169 fn eq(&self, other: &Self) -> bool {
170 self.distr == other.distr
171 }
172 }
173}