concision_init/distr/
lecun.rs

1/*
2    Appellation: lecun <distr>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::TruncatedNormal;
6use num_traits::Float;
7use rand::RngCore;
8use rand_distr::{Distribution, StandardNormal};
9
/// [LecunNormal] describes a zero-centered, truncated
/// [normal](rand_distr::Normal) distribution whose spread depends only on the
/// number of input units feeding a layer.
///
/// The standard deviation is given by:
///
/// $$
/// \sigma = {n_{in}}^{-\frac{1}{2}}
/// $$
///
/// where $`n_{in}`$ denotes the number of input units.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LecunNormal {
    /// The number of input units ($`n_{in}`$).
    n: usize,
}
22
23impl LecunNormal {
24    pub const fn new(n: usize) -> Self {
25        Self { n }
26    }
27    /// Create a [truncated normal](TruncatedNormal) [distribution](Distribution) centered at 0;
28    /// See [Self::std_dev] for the standard deviation calculations.
29    pub fn distr<F>(&self) -> crate::Result<TruncatedNormal<F>>
30    where
31        F: Float,
32        StandardNormal: Distribution<F>,
33    {
34        TruncatedNormal::new(F::zero(), self.std_dev())
35    }
36    /// Calculate the standard deviation ($`\sigma`$) of the distribution.
37    /// This is done by computing the root of the reciprocal of the number of inputs
38    /// ($`n_{in}`$) as follows:
39    ///
40    /// ```math
41    /// \sigma = {n_{in}}^{-\frac{1}{2}}
42    /// ```
43    pub fn std_dev<F>(&self) -> F
44    where
45        F: Float,
46    {
47        F::from(self.n).unwrap().recip().sqrt()
48    }
49}
50
51impl<F> Distribution<F> for LecunNormal
52where
53    F: Float,
54    StandardNormal: Distribution<F>,
55{
56    fn sample<R>(&self, rng: &mut R) -> F
57    where
58        R: RngCore + ?Sized,
59    {
60        self.distr().expect("NormalError").sample(rng)
61    }
62}