1use super::distr::*;
6
7use core::ops::Neg;
8use ndarray::{ArrayBase, DataOwned, Dimension, RawData, Shape, ShapeBuilder};
9use num_traits::{Float, FromPrimitive};
10use rand::rngs::{SmallRng, StdRng};
11use rand::{Rng, SeedableRng};
12use rand_distr::uniform::{SampleUniform, Uniform};
13use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
14
15fn _extract_xy_from_shape<D>(dim: &D, x: usize, y: usize) -> (usize, usize)
16where
17 D: Dimension,
18{
19 let tmp = dim.as_array_view();
20 let a = tmp.get(x).copied().unwrap_or(1_usize);
21 let b = tmp.get(y).copied().unwrap_or(1_usize);
22 (a, b)
23}
24
25pub trait Initialize<S, D>: Sized
32where
33 D: Dimension,
34 S: RawData,
35{
36 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
37 where
38 Ds: Distribution<S::Elem>,
39 Sh: ShapeBuilder<Dim = D>,
40 S: DataOwned;
41
42 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
43 where
44 R: Rng + ?Sized,
45 Ds: Distribution<S::Elem>,
46 Sh: ShapeBuilder<Dim = D>,
47 S: DataOwned;
48
49 fn bernoulli<Sh>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
50 where
51 Bernoulli: Distribution<S::Elem>,
52 S: DataOwned,
53 Sh: ShapeBuilder<Dim = D>,
54 {
55 let dist = Bernoulli::new(p)?;
56 Ok(Self::rand(shape, dist))
57 }
58 fn glorot_normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
60 where
61 StandardNormal: Distribution<S::Elem>,
62 S: DataOwned,
63 S::Elem: Float + FromPrimitive,
64 {
65 let shape = shape.into_shape_with_order();
66 let (inputs, outputs) = _extract_xy_from_shape(shape.raw_dim(), 0, 1);
67 let distr = XavierNormal::new(inputs, outputs);
68 Self::rand(shape, distr)
69 }
70 fn glorot_uniform<Sh>(shape: Sh) -> super::UniformResult<Self>
72 where
73 S: DataOwned,
74 Sh: ShapeBuilder<Dim = D>,
75 S::Elem: Float + FromPrimitive + SampleUniform,
76 <S::Elem as SampleUniform>::Sampler: Clone,
77 {
78 let shape = shape.into_shape_with_order();
79 let (inputs, outputs) = _extract_xy_from_shape(shape.raw_dim(), 0, 1);
80 let distr = XavierUniform::new(inputs, outputs)?;
81 Ok(Self::rand(shape, distr))
82 }
83 fn lecun_normal<Sh>(shape: Sh) -> Self
88 where
89 StandardNormal: Distribution<S::Elem>,
90 S: DataOwned,
91 Sh: ShapeBuilder<Dim = D>,
92 S::Elem: Float,
93 {
94 let shape = shape.into_shape_with_order();
95 let distr = LecunNormal::new(shape.size());
96 Self::rand(shape, distr)
97 }
98 fn normal<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
100 where
101 StandardNormal: Distribution<S::Elem>,
102 S: DataOwned,
103 Sh: ShapeBuilder<Dim = D>,
104 S::Elem: Float,
105 {
106 let distr = Normal::new(mean, std)?;
107 Ok(Self::rand(shape, distr))
108 }
109 #[cfg(feature = "complex")]
110 fn randc<Sh>(shape: Sh, re: S::Elem, im: S::Elem) -> Self
111 where
112 S: DataOwned,
113 Sh: ShapeBuilder<Dim = D>,
114 num_complex::ComplexDistribution<S::Elem, S::Elem>: Distribution<S::Elem>,
115 {
116 let distr = num_complex::ComplexDistribution::new(re, im);
117 Self::rand(shape, &distr)
118 }
119 fn stdnorm<Sh>(shape: Sh) -> Self
121 where
122 StandardNormal: Distribution<S::Elem>,
123 S: DataOwned,
124 Sh: ShapeBuilder<Dim = D>,
125 {
126 Self::rand(shape, StandardNormal)
127 }
128 fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
130 where
131 StandardNormal: Distribution<S::Elem>,
132 S: DataOwned,
133 Sh: ShapeBuilder<Dim = D>,
134 {
135 Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
136 }
137 fn truncnorm<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
139 where
140 StandardNormal: Distribution<S::Elem>,
141 S: DataOwned,
142 Sh: ShapeBuilder<Dim = D>,
143 S::Elem: Float,
144 {
145 let distr = TruncatedNormal::new(mean, std)?;
146 Ok(Self::rand(shape, distr))
147 }
148 fn uniform<Sh>(shape: Sh, dk: S::Elem) -> super::UniformResult<Self>
150 where
151 S: DataOwned,
152 Sh: ShapeBuilder<Dim = D>,
153 S::Elem: Clone + Neg<Output = S::Elem> + SampleUniform,
154 <S::Elem as SampleUniform>::Sampler: Clone,
155 {
156 Self::uniform_between(shape, dk.clone().neg(), dk)
157 }
158 fn uniform_from_seed<Sh>(
161 shape: Sh,
162 start: S::Elem,
163 stop: S::Elem,
164 key: u64,
165 ) -> super::UniformResult<Self>
166 where
167 S: DataOwned,
168 Sh: ShapeBuilder<Dim = D>,
169 S::Elem: Clone + SampleUniform,
170 <S::Elem as SampleUniform>::Sampler: Clone,
171 {
172 let distr = Uniform::new(start, stop)?;
173 Ok(Self::rand_with(
174 shape,
175 distr,
176 &mut StdRng::seed_from_u64(key),
177 ))
178 }
179 fn uniform_along<Sh>(shape: Sh, axis: usize) -> super::UniformResult<Self>
183 where
184 Sh: ShapeBuilder<Dim = D>,
185 S: DataOwned,
186 S::Elem: Float + FromPrimitive + SampleUniform,
187 <S::Elem as SampleUniform>::Sampler: Clone,
188 {
189 let shape: Shape<D> = shape.into_shape_with_order();
191 let dim: D = shape.raw_dim().clone();
192 let dk = S::Elem::from_usize(dim[axis]).map(|i| i.recip()).unwrap();
193 Self::uniform(dim, dk)
194 }
195 fn uniform_between<Sh>(shape: Sh, a: S::Elem, b: S::Elem) -> super::UniformResult<Self>
198 where
199 Sh: ShapeBuilder<Dim = D>,
200 S: DataOwned,
201 S::Elem: Clone + SampleUniform,
202 <S::Elem as SampleUniform>::Sampler: Clone,
203 {
204 let distr = Uniform::new(a, b)?;
205 Ok(Self::rand(shape, distr))
206 }
207}
208impl<A, S, D> Initialize<S, D> for ArrayBase<S, D>
213where
214 D: Dimension,
215 S: RawData<Elem = A>,
216{
217 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
218 where
219 Ds: Distribution<S::Elem>,
220 Sh: ShapeBuilder<Dim = D>,
221 S: DataOwned,
222 {
223 Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
224 }
225
226 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
227 where
228 R: Rng + ?Sized,
229 Ds: Distribution<S::Elem>,
230 Sh: ShapeBuilder<Dim = D>,
231 S: DataOwned,
232 {
233 Self::from_shape_simple_fn(shape, move || distr.sample(rng))
234 }
235}