concision_core/init/distr/
xavier.rs1use num::Float;
13use rand::Rng;
14use rand_distr::uniform::{SampleUniform, Uniform};
15use rand_distr::{Distribution, Normal, NormalError, StandardNormal};
16
17pub(crate) fn std_dev<F>(inputs: usize, outputs: usize) -> F
18where
19 F: Float,
20{
21 (F::from(2).unwrap() / F::from(inputs + outputs).unwrap()).sqrt()
22}
23
24pub(crate) fn boundary<F>(inputs: usize, outputs: usize) -> F
25where
26 F: Float,
27{
28 (F::from(6).unwrap() / F::from(inputs + outputs).unwrap()).sqrt()
29}
/// Xavier normal initializer: draws from a zero-mean normal distribution whose
/// standard deviation is fixed at construction time.
///
/// The struct definition carries no trait bounds; the `Float` and
/// `StandardNormal: Distribution<F>` requirements live on the `impl` blocks
/// that actually need them.
// `Eq`, `Hash` and `Ord` are only implemented when `F` implements them, which
// is never the case for the primitive floats `f32` / `f64`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct XavierNormal<F> {
    /// Standard deviation of the zero-mean normal distribution.
    std: F,
}
40
41impl<F> XavierNormal<F>
42where
43 F: Float,
44 StandardNormal: Distribution<F>,
45{
46 pub fn new(inputs: usize, outputs: usize) -> Self {
47 Self {
48 std: std_dev(inputs, outputs),
49 }
50 }
51
52 pub fn distr(&self) -> Result<Normal<F>, NormalError> {
53 Normal::new(F::zero(), self.std_dev())
54 }
55
56 pub fn std_dev(&self) -> F {
57 self.std
58 }
59}
60
61impl<F> Distribution<F> for XavierNormal<F>
62where
63 F: Float,
64 StandardNormal: Distribution<F>,
65{
66 fn sample<R>(&self, rng: &mut R) -> F
67 where
68 R: Rng + ?Sized,
69 {
70 self.distr().unwrap().sample(rng)
71 }
72}
73
/// Xavier uniform initializer: draws uniformly from the symmetric interval
/// `[-boundary, boundary)`, where `boundary` is fixed at construction time.
///
/// The struct definition carries no trait bounds; the `Float` and
/// `SampleUniform` requirements live on the `impl` blocks that need them.
// `Eq`, `Hash` and `Ord` are only implemented when `X` implements them, which
// is never the case for the primitive floats `f32` / `f64`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct XavierUniform<X> {
    /// Half-width of the sampling interval.
    boundary: X,
}
83
84impl<X> XavierUniform<X>
85where
86 X: Float + SampleUniform,
87{
88 pub fn new(inputs: usize, outputs: usize) -> Self {
89 Self {
90 boundary: boundary(inputs, outputs),
91 }
92 }
93
94 pub fn boundary(&self) -> X {
95 self.boundary
96 }
97
98 pub fn distr(&self) -> Uniform<X>
99 where
100 X: Float,
101 {
102 let bnd = self.boundary();
103 Uniform::new(-bnd, bnd)
104 }
105}
106
107impl<X> Distribution<X> for XavierUniform<X>
108where
109 X: Float + SampleUniform,
110{
111 fn sample<R>(&self, rng: &mut R) -> X
112 where
113 R: Rng + ?Sized,
114 {
115 self.distr().sample(rng)
116 }
117}