concision_core/init/
traits.rs

1/*
2    Appellation: initialize <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::init::distr::*;
6
7use core::ops::Neg;
8use nd::{ArrayBase, DataOwned, Dimension, RawData, ShapeBuilder};
9use ndrand::RandomExt;
10use num::complex::ComplexDistribution;
11use num::traits::Float;
12use rand::rngs::StdRng;
13use rand::{Rng, SeedableRng};
14use rand_distr::uniform::{SampleUniform, Uniform};
15use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
16
17/// This trait provides the base methods required for initializing an [ndarray](ndarray::ArrayBase) with random values.
18/// [Initialize] is similar to [RandomExt](ndarray_rand::RandomExt), however, it focuses on flexibility while implementing additional
19/// features geared towards machine-learning models; such as lecun_normal initialization.
20pub trait Initialize<A, D>
21where
22    D: Dimension,
23{
24    type Data: RawData<Elem = A>;
25    /// Generate a random array using the given distribution
26    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
27    where
28        Ds: Clone + Distribution<A>,
29        Sh: ShapeBuilder<Dim = D>,
30        Self::Data: DataOwned;
31    /// Generate a random array using the given distribution and random number generator
32    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
33    where
34        R: Rng + ?Sized,
35        Ds: Clone + Distribution<A>,
36        Sh: ShapeBuilder<Dim = D>,
37        Self::Data: DataOwned;
38    /// Initialize an array with random values using the given distribution and current shape
39    fn init_rand<Ds>(self, distr: Ds) -> Self
40    where
41        Ds: Clone + Distribution<A>,
42        Self: Sized,
43        Self::Data: DataOwned;
44    /// Initialize an array with random values from the current shape using the given distribution and random number generator
45    fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> Self
46    where
47        R: Rng + ?Sized,
48        Ds: Clone + Distribution<A>,
49        Self::Data: DataOwned;
50}
51
52/// This trait extends the [Initialize] trait with methods for generating random arrays from various distributions.
53pub trait InitializeExt<A, S, D>: Initialize<A, D, Data = S> + Sized
54where
55    A: Clone,
56    D: Dimension,
57    S: RawData<Elem = A>,
58{
59    fn bernoulli<Sh>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
60    where
61        S: DataOwned,
62        Sh: ShapeBuilder<Dim = D>,
63        Bernoulli: Distribution<A>,
64    {
65        let dist = Bernoulli::new(p)?;
66        Ok(Self::rand(shape, dist))
67    }
68    /// Initialize the object according to the Lecun Initialization scheme.
69    /// LecunNormal distributions are truncated [Normal](rand_distr::Normal)
70    /// distributions centered at 0 with a standard deviation equal to the
71    /// square root of the reciprocal of the number of inputs.
72    fn lecun_normal<Sh>(shape: Sh, n: usize) -> Self
73    where
74        A: Float,
75        S: DataOwned,
76        Sh: ShapeBuilder<Dim = D>,
77        StandardNormal: Distribution<A>,
78    {
79        let distr = LecunNormal::new(n);
80        Self::rand(shape, distr)
81    }
82    /// Given a shape, mean, and standard deviation generate a new object using the [Normal](rand_distr::Normal) distribution
83    fn normal<Sh>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
84    where
85        A: Float,
86        S: DataOwned,
87        Sh: ShapeBuilder<Dim = D>,
88        StandardNormal: Distribution<A>,
89    {
90        let distr = Normal::new(mean, std)?;
91        Ok(Self::rand(shape, distr))
92    }
93
94    fn randc<Sh>(shape: Sh, re: A, im: A) -> Self
95    where
96        S: DataOwned,
97        Sh: ShapeBuilder<Dim = D>,
98        ComplexDistribution<A, A>: Distribution<A>,
99    {
100        let distr = ComplexDistribution::new(re, im);
101        Self::rand(shape, distr)
102    }
103    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution
104    fn stdnorm<Sh>(shape: Sh) -> Self
105    where
106        S: DataOwned,
107        Sh: ShapeBuilder<Dim = D>,
108        StandardNormal: Distribution<A>,
109    {
110        Self::rand(shape, StandardNormal)
111    }
112    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution with a given seed
113    fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
114    where
115        S: DataOwned,
116        Sh: ShapeBuilder<Dim = D>,
117        StandardNormal: Distribution<A>,
118    {
119        Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
120    }
121    /// Initialize the object using the [TruncatedNormal](crate::init::distr::TruncatedNormal) distribution
122    fn truncnorm<Sh>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
123    where
124        A: Float,
125        S: DataOwned,
126        Sh: ShapeBuilder<Dim = D>,
127        StandardNormal: Distribution<A>,
128    {
129        let distr = TruncatedNormal::new(mean, std)?;
130        Ok(Self::rand(shape, distr))
131    }
132    /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk)
133    fn uniform<Sh>(shape: Sh, dk: A) -> Self
134    where
135        A: Neg<Output = A> + SampleUniform,
136        S: DataOwned,
137        Sh: ShapeBuilder<Dim = D>,
138        <A as SampleUniform>::Sampler: Clone,
139    {
140        Self::rand(shape, Uniform::new(dk.clone().neg(), dk))
141    }
142
143    fn uniform_from_seed<Sh>(shape: Sh, start: A, stop: A, key: u64) -> Self
144    where
145        A: SampleUniform,
146        S: DataOwned,
147        Sh: ShapeBuilder<Dim = D>,
148        <A as SampleUniform>::Sampler: Clone,
149    {
150        Self::rand_with(
151            shape,
152            Uniform::new(start, stop),
153            &mut StdRng::seed_from_u64(key),
154        )
155    }
156    /// Generate a random array with values between u(-a, a) where a is the reciprocal of the value at the given axis
157    fn uniform_along<Sh>(shape: Sh, axis: usize) -> Self
158    where
159        A: Copy + Float + SampleUniform,
160        S: DataOwned,
161        Sh: ShapeBuilder<Dim = D>,
162        <A as SampleUniform>::Sampler: Clone,
163    {
164        let dim = shape.into_shape().raw_dim().clone();
165        let dk = A::from(dim[axis]).unwrap().recip();
166        Self::uniform(dim, dk)
167    }
168    /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk)
169    fn uniform_between<Sh>(shape: Sh, a: A, b: A) -> Self
170    where
171        A: SampleUniform,
172        S: DataOwned,
173        Sh: ShapeBuilder<Dim = D>,
174        <A as SampleUniform>::Sampler: Clone,
175    {
176        Self::rand(shape, Uniform::new(a, b))
177    }
178}
179/*
180 ************ Implementations ************
181*/
182impl<A, S, D> Initialize<A, D> for ArrayBase<S, D>
183where
184    D: Dimension,
185    S: RawData<Elem = A>,
186    ArrayBase<S, D>: RandomExt<S, A, D>,
187{
188    type Data = S;
189
190    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> ArrayBase<S, D>
191    where
192        S: DataOwned,
193        Ds: Clone + Distribution<S::Elem>,
194        Sh: ShapeBuilder<Dim = D>,
195    {
196        Self::random(shape, distr)
197    }
198
199    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> ArrayBase<S, D>
200    where
201        R: Rng + ?Sized,
202        S: DataOwned,
203        Ds: Clone + Distribution<S::Elem>,
204        Sh: ShapeBuilder<Dim = D>,
205    {
206        Self::random_using(shape, distr, rng)
207    }
208
209    fn init_rand<Ds>(self, distr: Ds) -> ArrayBase<S, D>
210    where
211        S: DataOwned,
212        Ds: Clone + Distribution<S::Elem>,
213    {
214        Self::rand(self.dim(), distr)
215    }
216
217    fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> ArrayBase<S, D>
218    where
219        R: Rng + ?Sized,
220        S: DataOwned,
221        Ds: Clone + Distribution<S::Elem>,
222    {
223        Self::rand_with(self.dim(), distr, rng)
224    }
225}
226
227impl<U, A, S, D> InitializeExt<A, S, D> for U
228where
229    A: Clone,
230    D: Dimension,
231    S: RawData<Elem = A>,
232    U: Initialize<A, D, Data = S>,
233{
234}