concision_init/distr/lecun.rs
1/*
2 Appellation: lecun <distr>
3 Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::TruncatedNormal;
6use num_traits::Float;
7use rand::RngCore;
8use rand_distr::{Distribution, StandardNormal};
9
/// The LeCun normal initializer: a truncated [normal](rand_distr::Normal)
/// distribution with mean zero whose standard deviation depends only on the
/// fan-in of the layer being initialized:
///
/// ```math
/// \sigma = {n_{in}}^{-\frac{1}{2}}
/// ```
///
/// Here $`n_{in}`$ denotes the number of input units.
#[derive(Clone, Copy, Debug)]
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct LecunNormal {
    /// Number of input units ($`n_{in}`$) from which the standard deviation is derived.
    n: usize,
}
22
23impl LecunNormal {
24 pub const fn new(n: usize) -> Self {
25 Self { n }
26 }
27 /// Create a [truncated normal](TruncatedNormal) [distribution](Distribution) centered at 0;
28 /// See [Self::std_dev] for the standard deviation calculations.
29 pub fn distr<F>(&self) -> crate::Result<TruncatedNormal<F>>
30 where
31 F: Float,
32 StandardNormal: Distribution<F>,
33 {
34 TruncatedNormal::new(F::zero(), self.std_dev())
35 }
36 /// Calculate the standard deviation ($`\sigma`$) of the distribution.
37 /// This is done by computing the root of the reciprocal of the number of inputs
38 /// ($`n_{in}`$) as follows:
39 ///
40 /// ```math
41 /// \sigma = {n_{in}}^{-\frac{1}{2}}
42 /// ```
43 pub fn std_dev<F>(&self) -> F
44 where
45 F: Float,
46 {
47 F::from(self.n).unwrap().recip().sqrt()
48 }
49}
50
51impl<F> Distribution<F> for LecunNormal
52where
53 F: Float,
54 StandardNormal: Distribution<F>,
55{
56 fn sample<R>(&self, rng: &mut R) -> F
57 where
58 R: RngCore + ?Sized,
59 {
60 self.distr().expect("NormalError").sample(rng)
61 }
62}