1use crate::init::distr::*;
6
7use core::ops::Neg;
8use nd::{ArrayBase, DataOwned, Dimension, RawData, ShapeBuilder};
9use ndrand::RandomExt;
10use num::complex::ComplexDistribution;
11use num::traits::Float;
12use rand::rngs::StdRng;
13use rand::{Rng, SeedableRng};
14use rand_distr::uniform::{SampleUniform, Uniform};
15use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
16
17pub trait Initialize<A, D>
21where
22 D: Dimension,
23{
24 type Data: RawData<Elem = A>;
25 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
27 where
28 Ds: Clone + Distribution<A>,
29 Sh: ShapeBuilder<Dim = D>,
30 Self::Data: DataOwned;
31 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
33 where
34 R: Rng + ?Sized,
35 Ds: Clone + Distribution<A>,
36 Sh: ShapeBuilder<Dim = D>,
37 Self::Data: DataOwned;
38 fn init_rand<Ds>(self, distr: Ds) -> Self
40 where
41 Ds: Clone + Distribution<A>,
42 Self: Sized,
43 Self::Data: DataOwned;
44 fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> Self
46 where
47 R: Rng + ?Sized,
48 Ds: Clone + Distribution<A>,
49 Self::Data: DataOwned;
50}
51
52pub trait InitializeExt<A, S, D>: Initialize<A, D, Data = S> + Sized
54where
55 A: Clone,
56 D: Dimension,
57 S: RawData<Elem = A>,
58{
59 fn bernoulli<Sh>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
60 where
61 S: DataOwned,
62 Sh: ShapeBuilder<Dim = D>,
63 Bernoulli: Distribution<A>,
64 {
65 let dist = Bernoulli::new(p)?;
66 Ok(Self::rand(shape, dist))
67 }
68 fn lecun_normal<Sh>(shape: Sh, n: usize) -> Self
73 where
74 A: Float,
75 S: DataOwned,
76 Sh: ShapeBuilder<Dim = D>,
77 StandardNormal: Distribution<A>,
78 {
79 let distr = LecunNormal::new(n);
80 Self::rand(shape, distr)
81 }
82 fn normal<Sh>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
84 where
85 A: Float,
86 S: DataOwned,
87 Sh: ShapeBuilder<Dim = D>,
88 StandardNormal: Distribution<A>,
89 {
90 let distr = Normal::new(mean, std)?;
91 Ok(Self::rand(shape, distr))
92 }
93
94 fn randc<Sh>(shape: Sh, re: A, im: A) -> Self
95 where
96 S: DataOwned,
97 Sh: ShapeBuilder<Dim = D>,
98 ComplexDistribution<A, A>: Distribution<A>,
99 {
100 let distr = ComplexDistribution::new(re, im);
101 Self::rand(shape, distr)
102 }
103 fn stdnorm<Sh>(shape: Sh) -> Self
105 where
106 S: DataOwned,
107 Sh: ShapeBuilder<Dim = D>,
108 StandardNormal: Distribution<A>,
109 {
110 Self::rand(shape, StandardNormal)
111 }
112 fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
114 where
115 S: DataOwned,
116 Sh: ShapeBuilder<Dim = D>,
117 StandardNormal: Distribution<A>,
118 {
119 Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
120 }
121 fn truncnorm<Sh>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
123 where
124 A: Float,
125 S: DataOwned,
126 Sh: ShapeBuilder<Dim = D>,
127 StandardNormal: Distribution<A>,
128 {
129 let distr = TruncatedNormal::new(mean, std)?;
130 Ok(Self::rand(shape, distr))
131 }
132 fn uniform<Sh>(shape: Sh, dk: A) -> Self
134 where
135 A: Neg<Output = A> + SampleUniform,
136 S: DataOwned,
137 Sh: ShapeBuilder<Dim = D>,
138 <A as SampleUniform>::Sampler: Clone,
139 {
140 Self::rand(shape, Uniform::new(dk.clone().neg(), dk))
141 }
142
143 fn uniform_from_seed<Sh>(shape: Sh, start: A, stop: A, key: u64) -> Self
144 where
145 A: SampleUniform,
146 S: DataOwned,
147 Sh: ShapeBuilder<Dim = D>,
148 <A as SampleUniform>::Sampler: Clone,
149 {
150 Self::rand_with(
151 shape,
152 Uniform::new(start, stop),
153 &mut StdRng::seed_from_u64(key),
154 )
155 }
156 fn uniform_along<Sh>(shape: Sh, axis: usize) -> Self
158 where
159 A: Copy + Float + SampleUniform,
160 S: DataOwned,
161 Sh: ShapeBuilder<Dim = D>,
162 <A as SampleUniform>::Sampler: Clone,
163 {
164 let dim = shape.into_shape().raw_dim().clone();
165 let dk = A::from(dim[axis]).unwrap().recip();
166 Self::uniform(dim, dk)
167 }
168 fn uniform_between<Sh>(shape: Sh, a: A, b: A) -> Self
170 where
171 A: SampleUniform,
172 S: DataOwned,
173 Sh: ShapeBuilder<Dim = D>,
174 <A as SampleUniform>::Sampler: Clone,
175 {
176 Self::rand(shape, Uniform::new(a, b))
177 }
178}
179impl<A, S, D> Initialize<A, D> for ArrayBase<S, D>
183where
184 D: Dimension,
185 S: RawData<Elem = A>,
186 ArrayBase<S, D>: RandomExt<S, A, D>,
187{
188 type Data = S;
189
190 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> ArrayBase<S, D>
191 where
192 S: DataOwned,
193 Ds: Clone + Distribution<S::Elem>,
194 Sh: ShapeBuilder<Dim = D>,
195 {
196 Self::random(shape, distr)
197 }
198
199 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> ArrayBase<S, D>
200 where
201 R: Rng + ?Sized,
202 S: DataOwned,
203 Ds: Clone + Distribution<S::Elem>,
204 Sh: ShapeBuilder<Dim = D>,
205 {
206 Self::random_using(shape, distr, rng)
207 }
208
209 fn init_rand<Ds>(self, distr: Ds) -> ArrayBase<S, D>
210 where
211 S: DataOwned,
212 Ds: Clone + Distribution<S::Elem>,
213 {
214 Self::rand(self.dim(), distr)
215 }
216
217 fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> ArrayBase<S, D>
218 where
219 R: Rng + ?Sized,
220 S: DataOwned,
221 Ds: Clone + Distribution<S::Elem>,
222 {
223 Self::rand_with(self.dim(), distr, rng)
224 }
225}
226
227impl<U, A, S, D> InitializeExt<A, S, D> for U
228where
229 A: Clone,
230 D: Dimension,
231 S: RawData<Elem = A>,
232 U: Initialize<A, D, Data = S>,
233{
234}