concision_init/distr/
xavier.rs

1/*
2    Appellation: xavier <distr>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5//! # Xavier
6//!
//! Xavier (also known as Glorot) initialization was introduced in 2010 by Xavier Glorot and
//! Yoshua Bengio. It is designed to initialize the weights of a neural network in a way that
//! mitigates the vanishing and exploding gradient problems. The technique is provided here in
//! two variants: [XavierNormal] and [XavierUniform].
11use rand_distr::uniform::{SampleUniform, Uniform};
12use rand_distr::{Distribution, StandardNormal};
13
14/// Normal Xavier initializers leverage a normal distribution centered around `0` and using a
15/// standard deviation ($\sigma$) computed by:
16///
17/// ```math
18/// \sigma = \sqrt{\frac{2}{d_{in} + d_{out}}}
19/// ```
20#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
21pub struct XavierNormal<T>
22where
23    StandardNormal: Distribution<T>,
24{
25    std: T,
26}
27
28/// Uniform Xavier initializers use a uniform distribution to initialize the weights of a neural network
29/// within a given range.
30pub struct XavierUniform<T>
31where
32    T: SampleUniform,
33{
34    distr: Uniform<T>,
35}
36
37/*
38 ************* Implementations *************
39*/
40
41mod impl_normal {
42    use super::XavierNormal;
43    use num_traits::{Float, FromPrimitive};
44    use rand::RngCore;
45    use rand_distr::{Distribution, Normal, StandardNormal};
46
47    fn std_dev<T>(inputs: usize, outputs: usize) -> T
48    where
49        T: FromPrimitive + Float,
50    {
51        let numerator = T::from_usize(2).unwrap();
52        let denominator = T::from_usize(inputs + outputs).unwrap();
53        (numerator / denominator).sqrt()
54    }
55
56    impl<T> XavierNormal<T>
57    where
58        T: Float,
59        StandardNormal: Distribution<T>,
60    {
61        pub fn new(inputs: usize, outputs: usize) -> Self
62        where
63            T: FromPrimitive,
64        {
65            Self {
66                std: std_dev(inputs, outputs),
67            }
68        }
69        /// tries creating a new [`Normal`] distribution with a mean of 0 and the computed
70        /// standard deviation ($\sigma$) based on the number of inputs and outputs.
71        pub fn distr(&self) -> crate::Result<Normal<T>> {
72            Normal::new(T::zero(), self.std_dev()).map_err(Into::into)
73        }
74        /// returns a reference to the standard deviation of the distribution
75        pub const fn std_dev(&self) -> T {
76            self.std
77        }
78    }
79
80    impl<T> Distribution<T> for XavierNormal<T>
81    where
82        T: Float,
83        StandardNormal: Distribution<T>,
84    {
85        fn sample<R>(&self, rng: &mut R) -> T
86        where
87            R: RngCore + ?Sized,
88        {
89            self.distr().unwrap().sample(rng)
90        }
91    }
92}
93
94mod impl_uniform {
95    use super::XavierUniform;
96    use num_traits::{Float, FromPrimitive};
97    use rand::RngCore;
98    use rand_distr::Distribution;
99    use rand_distr::uniform::{SampleUniform, Uniform};
100
101    fn boundary<U>(inputs: usize, outputs: usize) -> U
102    where
103        U: FromPrimitive + Float,
104    {
105        let numer = <U>::from_usize(6).unwrap();
106        let denom = <U>::from_usize(inputs + outputs).unwrap();
107        (numer / denom).sqrt()
108    }
109
110    impl<T> XavierUniform<T>
111    where
112        T: SampleUniform,
113    {
114        pub fn new(inputs: usize, outputs: usize) -> crate::Result<Self>
115        where
116            T: Float + FromPrimitive,
117        {
118            // calculate the boundary for the uniform distribution
119            let limit = boundary::<T>(inputs, outputs);
120            // create a uniform distribution with the calculated limit
121            let distr = Uniform::new(-limit, limit)?;
122            Ok(Self { distr })
123        }
124        /// returns an immutable reference to the underlying uniform distribution
125        pub(crate) const fn distr(&self) -> &Uniform<T> {
126            &self.distr
127        }
128    }
129
130    impl<T> Distribution<T> for XavierUniform<T>
131    where
132        T: Float + SampleUniform,
133    {
134        fn sample<R>(&self, rng: &mut R) -> T
135        where
136            R: RngCore + ?Sized,
137        {
138            self.distr().sample(rng)
139        }
140    }
141
142    impl<T> Clone for XavierUniform<T>
143    where
144        T: Clone + SampleUniform,
145        <T as SampleUniform>::Sampler: Clone,
146    {
147        fn clone(&self) -> Self {
148            Self {
149                distr: self.distr.clone(),
150            }
151        }
152    }
153
154    impl<T> Copy for XavierUniform<T>
155    where
156        T: Copy + SampleUniform,
157        <T as SampleUniform>::Sampler: Copy,
158    {
159    }
160
161    impl<T> Eq for XavierUniform<T>
162    where
163        T: Eq + SampleUniform,
164        <T as SampleUniform>::Sampler: Eq,
165    {
166    }
167
168    impl<T> PartialEq for XavierUniform<T>
169    where
170        T: PartialEq + SampleUniform,
171        <T as SampleUniform>::Sampler: PartialEq,
172    {
173        fn eq(&self, other: &Self) -> bool {
174            self.distr == other.distr
175        }
176    }
177}