1use super::distr::*;
6
7use core::ops::Neg;
8use ndarray::{ArrayBase, DataOwned, Dimension, RawData, ShapeBuilder};
9use num_complex::ComplexDistribution;
10use num_traits::{Float, FromPrimitive};
11use rand::{
12 Rng, SeedableRng,
13 rngs::{SmallRng, StdRng},
14};
15use rand_distr::uniform::{SampleUniform, Uniform};
16use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
17
18pub trait Initialize<A, D>
24where
25 D: Dimension,
26{
27 type Data: RawData<Elem = A>;
28
29 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
30 where
31 Ds: Distribution<A>,
32 Sh: ShapeBuilder<Dim = D>,
33 Self::Data: DataOwned;
34
35 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
36 where
37 R: Rng + ?Sized,
38 Ds: Distribution<A>,
39 Sh: ShapeBuilder<Dim = D>,
40 Self::Data: DataOwned;
41}
42
43pub trait InitializeExt<A, D>
45where
46 D: Dimension,
47 Self: Initialize<A, D> + Sized,
48 Self::Data: DataOwned<Elem = A>,
49{
50 fn bernoulli<Sh: ShapeBuilder<Dim = D>>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
51 where
52 Bernoulli: Distribution<A>,
53 {
54 let dist = Bernoulli::new(p)?;
55 Ok(Self::rand(shape, dist))
56 }
57 fn glorot_normal<Sh>(shape: Sh, inputs: usize, outputs: usize) -> Self
59 where
60 A: Float + FromPrimitive,
61 Sh: ShapeBuilder<Dim = D>,
62 StandardNormal: Distribution<A>,
63 Self::Data: DataOwned<Elem = A>,
64 {
65 let distr = XavierNormal::new(inputs, outputs);
66 Self::rand(shape, distr)
67 }
68 fn glorot_uniform<Sh>(shape: Sh, inputs: usize, outputs: usize) -> super::UniformResult<Self>
70 where
71 A: Float + FromPrimitive + SampleUniform,
72 Sh: ShapeBuilder<Dim = D>,
73 <A as SampleUniform>::Sampler: Clone,
74 Self::Data: DataOwned<Elem = A>,
75 {
76 let distr = XavierUniform::new(inputs, outputs)?;
77 Ok(Self::rand(shape, distr))
78 }
79 fn lecun_normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
84 where
85 A: Float,
86 StandardNormal: Distribution<A>,
87 Self::Data: DataOwned<Elem = A>,
88 {
89 let shape = shape.into_shape_with_order();
90 let distr = LecunNormal::new(shape.size());
91 Self::rand(shape, distr)
92 }
93 fn normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
95 where
96 A: Float,
97 StandardNormal: Distribution<A>,
98 Self::Data: DataOwned<Elem = A>,
99 {
100 let distr = Normal::new(mean, std)?;
101 Ok(Self::rand(shape, distr))
102 }
103
104 fn randc<Sh: ShapeBuilder<Dim = D>>(shape: Sh, re: A, im: A) -> Self
105 where
106 ComplexDistribution<A, A>: Distribution<A>,
107 Self::Data: DataOwned<Elem = A>,
108 {
109 let distr = ComplexDistribution::new(re, im);
110 Self::rand(shape, &distr)
111 }
112 fn stdnorm<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
114 where
115 StandardNormal: Distribution<A>,
116 Self::Data: DataOwned<Elem = A>,
117 {
118 Self::rand(shape, StandardNormal)
119 }
120 fn stdnorm_from_seed<Sh: ShapeBuilder<Dim = D>>(shape: Sh, seed: u64) -> Self
122 where
123 StandardNormal: Distribution<A>,
124 Self::Data: DataOwned<Elem = A>,
125 {
126 Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
127 }
128 fn truncnorm<Sh: ShapeBuilder<Dim = D>>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
130 where
131 A: Float,
132 StandardNormal: Distribution<A>,
133 Self::Data: DataOwned<Elem = A>,
134 {
135 let distr = TruncatedNormal::new(mean, std)?;
136 Ok(Self::rand(shape, distr))
137 }
138 fn uniform<Sh>(shape: Sh, dk: A) -> super::UniformResult<Self>
140 where
141 A: Copy + Neg<Output = A> + SampleUniform,
142 Sh: ShapeBuilder<Dim = D>,
143 <A as SampleUniform>::Sampler: Clone,
144 Self::Data: DataOwned<Elem = A>,
145 {
146 Uniform::new(dk.neg(), dk).map(|distr| Self::rand(shape, distr))
147 }
148
149 fn uniform_from_seed<Sh>(shape: Sh, start: A, stop: A, key: u64) -> super::UniformResult<Self>
150 where
151 A: Clone + SampleUniform,
152 Sh: ShapeBuilder<Dim = D>,
153 <A as SampleUniform>::Sampler: Clone,
154 Self::Data: DataOwned<Elem = A>,
155 {
156 Uniform::new(start, stop)
157 .map(|distr| Self::rand_with(shape, distr, &mut StdRng::seed_from_u64(key)))
158 }
159 fn uniform_along<Sh>(shape: Sh, axis: usize) -> super::UniformResult<Self>
161 where
162 A: Copy + Float + SampleUniform,
163 Sh: ShapeBuilder<Dim = D>,
164 <A as SampleUniform>::Sampler: Clone,
165 Self::Data: DataOwned<Elem = A>,
166 {
167 let dim = shape.into_shape_with_order().raw_dim().clone();
168 let dk = A::from(dim[axis]).unwrap().recip();
169 Self::uniform(dim, dk)
170 }
171 fn uniform_between<Sh>(shape: Sh, a: A, b: A) -> super::UniformResult<Self>
173 where
174 A: Clone + SampleUniform,
175 Sh: ShapeBuilder<Dim = D>,
176 <A as SampleUniform>::Sampler: Clone,
177 Self::Data: DataOwned<Elem = A>,
178 {
179 Uniform::new(a, b).map(|distr| Self::rand(shape, distr))
180 }
181}
182impl<S, A, D> Initialize<A, D> for ArrayBase<S, D>
187where
188 D: Dimension,
189 S: RawData<Elem = A>,
190{
191 type Data = S;
192
193 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
194 where
195 Ds: Distribution<A>,
196 Sh: ShapeBuilder<Dim = D>,
197 S: DataOwned,
198 {
199 Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
200 }
201
202 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
203 where
204 R: Rng + ?Sized,
205 Ds: Distribution<A>,
206 Sh: ShapeBuilder<Dim = D>,
207 S: DataOwned,
208 {
209 Self::from_shape_simple_fn(shape, move || distr.sample(rng))
210 }
211}
212
213impl<U, A, S, D> InitializeExt<A, D> for U
214where
215 A: Clone,
216 D: Dimension,
217 S: DataOwned<Elem = A>,
218 U: Initialize<A, D, Data = S> + Sized,
219{
220}