// concision_core/init/distr/xavier.rs
use num_traits::{Float, FromPrimitive};
use rand::Rng;
use rand_distr::uniform::{SampleUniform, Uniform};
use rand_distr::{Distribution, Normal, NormalError, StandardNormal};
16
17pub(crate) fn std_dev<F>(inputs: usize, outputs: usize) -> F
18where
19 F: Float + FromPrimitive,
20{
21 (F::from_usize(2).unwrap() / F::from_usize(inputs + outputs).unwrap()).sqrt()
22}
23
24pub(crate) fn boundary<F>(inputs: usize, outputs: usize) -> F
25where
26 F: Float + FromPrimitive,
27{
28 (F::from_usize(6).unwrap() / F::from_usize(inputs + outputs).unwrap()).sqrt()
29}
30#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
33pub struct XavierNormal<F>
34where
35 StandardNormal: Distribution<F>,
36{
37 std: F,
38}
39
40impl<F> XavierNormal<F>
41where
42 F: Float,
43 StandardNormal: Distribution<F>,
44{
45 pub fn new(inputs: usize, outputs: usize) -> Self
46 where
47 F: FromPrimitive,
48 {
49 Self {
50 std: std_dev(inputs, outputs),
51 }
52 }
53
54 pub fn distr(&self) -> Result<Normal<F>, NormalError> {
55 Normal::new(F::zero(), self.std_dev())
56 }
57
58 pub fn std_dev(&self) -> F {
59 self.std
60 }
61}
62
63impl<F> Distribution<F> for XavierNormal<F>
64where
65 F: Float,
66 StandardNormal: Distribution<F>,
67{
68 fn sample<R>(&self, rng: &mut R) -> F
69 where
70 R: Rng + ?Sized,
71 {
72 self.distr().unwrap().sample(rng)
73 }
74}
75
76pub struct XavierUniform<X>
79where
80 X: Float + SampleUniform,
81{
82 distr: Uniform<X>,
83}
84
85impl<X> XavierUniform<X>
86where
87 X: Float + SampleUniform,
88{
89 pub fn new(inputs: usize, outputs: usize) -> Result<Uniform<X>, rand_distr::uniform::Error>
90 where
91 X: FromPrimitive,
92 {
93 let limit = boundary::<X>(inputs, outputs);
94 Uniform::new(-limit, limit)
95 }
96
97 pub const fn distr(&self) -> &Uniform<X> {
98 &self.distr
99 }
100}
101
102impl<X> Distribution<X> for XavierUniform<X>
103where
104 X: Float + SampleUniform,
105{
106 fn sample<R>(&self, rng: &mut R) -> X
107 where
108 R: Rng + ?Sized,
109 {
110 self.distr().sample(rng)
111 }
112}