concision_init/distr/
lecun.rs

/*
    Appellation: lecun <distr>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use super::TruncatedNormal;
6use num_traits::Float;
7use rand::RngCore;
8use rand_distr::{Distribution, StandardNormal};
9
/// [LecunNormal] describes a LeCun-style initializer: a truncated
/// [normal](rand_distr::Normal) distribution with mean `0` and a standard
/// deviation of $`σ = sqrt(1/n_in)`$, where $`n_in`$ is the number of input
/// units feeding the layer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LecunNormal {
    /// The number of input units ($`n_in`$) used to scale the standard deviation.
    n: usize,
}
17
18impl LecunNormal {
19    pub const fn new(n: usize) -> Self {
20        Self { n }
21    }
22    /// Create a [truncated normal](TruncatedNormal) [distribution](Distribution) centered at 0;
23    /// See [Self::std_dev] for the standard deviation calculations.
24    pub fn distr<F>(&self) -> crate::Result<TruncatedNormal<F>>
25    where
26        F: Float,
27        StandardNormal: Distribution<F>,
28    {
29        TruncatedNormal::new(F::zero(), self.std_dev())
30    }
31    /// Calculate the standard deviation ($`σ`$) of the distribution.
32    /// This is done by computing the root of the reciprocal of the number of inputs
33    ///
34    /// Symbolically: $`σ = sqrt(1/n)`$
35    pub fn std_dev<F>(&self) -> F
36    where
37        F: Float,
38    {
39        F::from(self.n).unwrap().recip().sqrt()
40    }
41}
42
43impl<F> Distribution<F> for LecunNormal
44where
45    F: Float,
46    StandardNormal: Distribution<F>,
47{
48    fn sample<R>(&self, rng: &mut R) -> F
49    where
50        R: RngCore + ?Sized,
51    {
52        self.distr().expect("NormalError").sample(rng)
53    }
54}