concision_init/distr/
lecun.rs1use super::TruncatedNormal;
6use num_traits::Float;
7use rand::RngCore;
8use rand_distr::{Distribution, StandardNormal};
9
/// LeCun-normal initializer parameterised by a single count `n`.
///
/// `n` is presumably the fan-in (number of input units) of the layer being
/// initialised — TODO confirm against callers.
#[derive(Clone, Copy, Debug)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct LecunNormal {
    n: usize,
}
17
18impl LecunNormal {
19 pub const fn new(n: usize) -> Self {
20 Self { n }
21 }
22 pub fn distr<F>(&self) -> crate::Result<TruncatedNormal<F>>
25 where
26 F: Float,
27 StandardNormal: Distribution<F>,
28 {
29 TruncatedNormal::new(F::zero(), self.std_dev())
30 }
31 pub fn std_dev<F>(&self) -> F
36 where
37 F: Float,
38 {
39 F::from(self.n).unwrap().recip().sqrt()
40 }
41}
42
43impl<F> Distribution<F> for LecunNormal
44where
45 F: Float,
46 StandardNormal: Distribution<F>,
47{
48 fn sample<R>(&self, rng: &mut R) -> F
49 where
50 R: RngCore + ?Sized,
51 {
52 self.distr().expect("NormalError").sample(rng)
53 }
54}