1use super::distr::*;
6
7use core::ops::Neg;
8use ndarray::{ArrayBase, DataOwned, Dimension, RawData, Shape, ShapeBuilder};
9use num_traits::{Float, FromPrimitive};
10use rand::rngs::{SmallRng, StdRng};
11use rand::{Rng, SeedableRng};
12use rand_distr::uniform::{SampleUniform, Uniform};
13use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
14
15pub trait Initialize<S, D>: Sized
22where
23 D: Dimension,
24 S: RawData,
25{
26 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
27 where
28 Ds: Distribution<S::Elem>,
29 Sh: ShapeBuilder<Dim = D>,
30 S: DataOwned;
31
32 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
33 where
34 R: Rng + ?Sized,
35 Ds: Distribution<S::Elem>,
36 Sh: ShapeBuilder<Dim = D>,
37 S: DataOwned;
38
39 fn bernoulli<Sh>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
40 where
41 Bernoulli: Distribution<S::Elem>,
42 S: DataOwned,
43 Sh: ShapeBuilder<Dim = D>,
44 {
45 let dist = Bernoulli::new(p)?;
46 Ok(Self::rand(shape, dist))
47 }
48 fn glorot_normal<Sh>(shape: Sh, inputs: usize, outputs: usize) -> Self
50 where
51 S::Elem: Float + FromPrimitive,
52 StandardNormal: Distribution<S::Elem>,
53 S: DataOwned,
54 Sh: ShapeBuilder<Dim = D>,
55 {
56 let distr = XavierNormal::new(inputs, outputs);
57 Self::rand(shape, distr)
58 }
59 fn glorot_uniform<Sh>(shape: Sh, inputs: usize, outputs: usize) -> super::UniformResult<Self>
61 where
62 S: DataOwned,
63 Sh: ShapeBuilder<Dim = D>,
64 S::Elem: Float + FromPrimitive + SampleUniform,
65 <S::Elem as SampleUniform>::Sampler: Clone,
66 {
67 let distr = XavierUniform::new(inputs, outputs)?;
68 Ok(Self::rand(shape, distr))
69 }
70 fn lecun_normal<Sh>(shape: Sh) -> Self
75 where
76 StandardNormal: Distribution<S::Elem>,
77 S: DataOwned,
78 Sh: ShapeBuilder<Dim = D>,
79 S::Elem: Float,
80 {
81 let shape = shape.into_shape_with_order();
82 let distr = LecunNormal::new(shape.size());
83 Self::rand(shape, distr)
84 }
85 fn normal<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
87 where
88 StandardNormal: Distribution<S::Elem>,
89 S: DataOwned,
90 Sh: ShapeBuilder<Dim = D>,
91 S::Elem: Float,
92 {
93 let distr = Normal::new(mean, std)?;
94 Ok(Self::rand(shape, distr))
95 }
96 #[cfg(feature = "complex")]
97 fn randc<Sh>(shape: Sh, re: S::Elem, im: S::Elem) -> Self
98 where
99 S: DataOwned,
100 Sh: ShapeBuilder<Dim = D>,
101 num_complex::ComplexDistribution<S::Elem, S::Elem>: Distribution<S::Elem>,
102 {
103 let distr = num_complex::ComplexDistribution::new(re, im);
104 Self::rand(shape, &distr)
105 }
106 fn stdnorm<Sh>(shape: Sh) -> Self
108 where
109 StandardNormal: Distribution<S::Elem>,
110 S: DataOwned,
111 Sh: ShapeBuilder<Dim = D>,
112 {
113 Self::rand(shape, StandardNormal)
114 }
115 fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
117 where
118 StandardNormal: Distribution<S::Elem>,
119 S: DataOwned,
120 Sh: ShapeBuilder<Dim = D>,
121 {
122 Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
123 }
124 fn truncnorm<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
126 where
127 StandardNormal: Distribution<S::Elem>,
128 S: DataOwned,
129 Sh: ShapeBuilder<Dim = D>,
130 S::Elem: Float,
131 {
132 let distr = TruncatedNormal::new(mean, std)?;
133 Ok(Self::rand(shape, distr))
134 }
135 fn uniform<Sh>(shape: Sh, dk: S::Elem) -> super::UniformResult<Self>
137 where
138 S: DataOwned,
139 Sh: ShapeBuilder<Dim = D>,
140 S::Elem: Clone + Neg<Output = S::Elem> + SampleUniform,
141 <S::Elem as SampleUniform>::Sampler: Clone,
142 {
143 Self::uniform_between(shape, dk.clone().neg(), dk)
144 }
145 fn uniform_from_seed<Sh>(
148 shape: Sh,
149 start: S::Elem,
150 stop: S::Elem,
151 key: u64,
152 ) -> super::UniformResult<Self>
153 where
154 S: DataOwned,
155 Sh: ShapeBuilder<Dim = D>,
156 S::Elem: Clone + SampleUniform,
157 <S::Elem as SampleUniform>::Sampler: Clone,
158 {
159 let distr = Uniform::new(start, stop)?;
160 Ok(Self::rand_with(
161 shape,
162 distr,
163 &mut StdRng::seed_from_u64(key),
164 ))
165 }
166 fn uniform_along<Sh>(shape: Sh, axis: usize) -> super::UniformResult<Self>
170 where
171 Sh: ShapeBuilder<Dim = D>,
172 S: DataOwned,
173 S::Elem: Float + FromPrimitive + SampleUniform,
174 <S::Elem as SampleUniform>::Sampler: Clone,
175 {
176 let shape: Shape<D> = shape.into_shape_with_order();
178 let dim: D = shape.raw_dim().clone();
179 let dk = S::Elem::from_usize(dim[axis]).map(|i| i.recip()).unwrap();
180 Self::uniform(dim, dk)
181 }
182 fn uniform_between<Sh>(shape: Sh, a: S::Elem, b: S::Elem) -> super::UniformResult<Self>
185 where
186 Sh: ShapeBuilder<Dim = D>,
187 S: DataOwned,
188 S::Elem: Clone + SampleUniform,
189 <S::Elem as SampleUniform>::Sampler: Clone,
190 {
191 let distr = Uniform::new(a, b)?;
192 Ok(Self::rand(shape, distr))
193 }
194}
195impl<A, S, D> Initialize<S, D> for ArrayBase<S, D>
200where
201 D: Dimension,
202 S: RawData<Elem = A>,
203{
204 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
205 where
206 Ds: Distribution<S::Elem>,
207 Sh: ShapeBuilder<Dim = D>,
208 S: DataOwned,
209 {
210 Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
211 }
212
213 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
214 where
215 R: Rng + ?Sized,
216 Ds: Distribution<S::Elem>,
217 Sh: ShapeBuilder<Dim = D>,
218 S: DataOwned,
219 {
220 Self::from_shape_simple_fn(shape, move || distr.sample(rng))
221 }
222}