concision_init/distr/
xavier.rs

1/*
2    Appellation: xavier <distr>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
//! # Xavier
//!
//! Xavier initialization techniques were developed in 2010 by Xavier Glorot.
//! These methods are designed to initialize the weights of a neural network in a way that
//! mitigates the vanishing and exploding gradient problems. The technique comes in two
//! variants, each backed by its own distribution: [`XavierNormal`] and [`XavierUniform`].
11use rand_distr::uniform::{SampleUniform, Uniform};
12use rand_distr::{Distribution, StandardNormal};
13
14/// Normal Xavier initializers leverage a normal distribution with a mean of 0 and a standard deviation (`σ`)
15/// computed by the formula: $`σ = sqrt(2/(d_in + d_out))`$
16#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
17pub struct XavierNormal<T>
18where
19    StandardNormal: Distribution<T>,
20{
21    std: T,
22}
23
24/// Uniform Xavier initializers use a uniform distribution to initialize the weights of a neural network
25/// within a given range.
26pub struct XavierUniform<T>
27where
28    T: SampleUniform,
29{
30    distr: Uniform<T>,
31}
32
33/*
34 ************* Implementations *************
35*/
36
37mod impl_normal {
38    use super::XavierNormal;
39    use num_traits::{Float, FromPrimitive};
40    use rand::RngCore;
41    use rand_distr::{Distribution, Normal, StandardNormal};
42
43    fn std_dev<T>(inputs: usize, outputs: usize) -> T
44    where
45        T: FromPrimitive + Float,
46    {
47        let numerator = T::from_usize(2).unwrap();
48        let denominator = T::from_usize(inputs + outputs).unwrap();
49        (numerator / denominator).sqrt()
50    }
51
52    impl<T> XavierNormal<T>
53    where
54        T: Float,
55        StandardNormal: Distribution<T>,
56    {
57        pub fn new(inputs: usize, outputs: usize) -> Self
58        where
59            T: FromPrimitive,
60        {
61            Self {
62                std: std_dev(inputs, outputs),
63            }
64        }
65        /// tries creating a new [`Normal`] distribution with a mean of 0 and the standard
66        /// deviation computed by the formula: $`σ = sqrt(2/(d_in + d_out))`$
67        pub fn distr(&self) -> crate::Result<Normal<T>> {
68            Normal::new(T::zero(), self.std_dev()).map_err(Into::into)
69        }
70        /// returns a reference to the standard deviation of the distribution
71        pub const fn std_dev(&self) -> T {
72            self.std
73        }
74    }
75
76    impl<T> Distribution<T> for XavierNormal<T>
77    where
78        T: Float,
79        StandardNormal: Distribution<T>,
80    {
81        fn sample<R>(&self, rng: &mut R) -> T
82        where
83            R: RngCore + ?Sized,
84        {
85            self.distr().unwrap().sample(rng)
86        }
87    }
88}
89
90mod impl_uniform {
91    use super::XavierUniform;
92    use num_traits::{Float, FromPrimitive};
93    use rand::RngCore;
94    use rand_distr::Distribution;
95    use rand_distr::uniform::{SampleUniform, Uniform};
96
97    fn boundary<U>(inputs: usize, outputs: usize) -> U
98    where
99        U: FromPrimitive + Float,
100    {
101        let numer = <U>::from_usize(6).unwrap();
102        let denom = <U>::from_usize(inputs + outputs).unwrap();
103        (numer / denom).sqrt()
104    }
105
106    impl<T> XavierUniform<T>
107    where
108        T: SampleUniform,
109    {
110        pub fn new(inputs: usize, outputs: usize) -> crate::Result<Self>
111        where
112            T: Float + FromPrimitive,
113        {
114            // calculate the boundary for the uniform distribution
115            let limit = boundary::<T>(inputs, outputs);
116            // create a uniform distribution with the calculated limit
117            let distr = Uniform::new(-limit, limit)?;
118            Ok(Self { distr })
119        }
120        /// returns an immutable reference to the underlying uniform distribution
121        pub(crate) const fn distr(&self) -> &Uniform<T> {
122            &self.distr
123        }
124    }
125
126    impl<T> Distribution<T> for XavierUniform<T>
127    where
128        T: Float + SampleUniform,
129    {
130        fn sample<R>(&self, rng: &mut R) -> T
131        where
132            R: RngCore + ?Sized,
133        {
134            self.distr().sample(rng)
135        }
136    }
137
138    impl<T> Clone for XavierUniform<T>
139    where
140        T: Clone + SampleUniform,
141        <T as SampleUniform>::Sampler: Clone,
142    {
143        fn clone(&self) -> Self {
144            Self {
145                distr: self.distr.clone(),
146            }
147        }
148    }
149
150    impl<T> Copy for XavierUniform<T>
151    where
152        T: Copy + SampleUniform,
153        <T as SampleUniform>::Sampler: Copy,
154    {
155    }
156
157    impl<T> Eq for XavierUniform<T>
158    where
159        T: Eq + SampleUniform,
160        <T as SampleUniform>::Sampler: Eq,
161    {
162    }
163
164    impl<T> PartialEq for XavierUniform<T>
165    where
166        T: PartialEq + SampleUniform,
167        <T as SampleUniform>::Sampler: PartialEq,
168    {
169        fn eq(&self, other: &Self) -> bool {
170            self.distr == other.distr
171        }
172    }
173}