concision_core/init/
initialize.rs

1/*
2    Appellation: initialize <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::distr::*;
6
7use core::ops::Neg;
8use ndarray::{ArrayBase, DataOwned, Dimension, RawData, ShapeBuilder};
9use num_complex::ComplexDistribution;
10use num_traits::{Float, FromPrimitive};
11use rand::{
12    Rng, SeedableRng,
13    rngs::{SmallRng, StdRng},
14};
15use rand_distr::uniform::{SampleUniform, Uniform};
16use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
17
18/// This trait facilitates the initialization of tensors with random values. [Initialize]
19///
20/// This trait provides the base methods required for initializing an [ndarray](ndarray::ArrayBase) with random values.
21/// [Initialize] is similar to [RandomExt](ndarray_rand::RandomExt), however, it focuses on flexibility while implementing additional
22/// features geared towards machine-learning models; such as lecun_normal initialization.
23pub trait Initialize<A, D>
24where
25    D: Dimension,
26{
27    type Data: RawData<Elem = A>;
28
29    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
30    where
31        Ds: Distribution<A>,
32        Sh: ShapeBuilder<Dim = D>,
33        Self::Data: DataOwned;
34
35    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
36    where
37        R: Rng + ?Sized,
38        Ds: Distribution<A>,
39        Sh: ShapeBuilder<Dim = D>,
40        Self::Data: DataOwned;
41}
42
43/// This trait extends the [Initialize] trait with methods for generating random arrays from various distributions.
44pub trait InitializeExt<A, D>
45where
46    D: Dimension,
47    Self: Initialize<A, D> + Sized,
48    Self::Data: DataOwned<Elem = A>,
49{
50    fn bernoulli<Sh: ShapeBuilder<Dim = D>>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
51    where
52        Bernoulli: Distribution<A>,
53    {
54        let dist = Bernoulli::new(p)?;
55        Ok(Self::rand(shape, dist))
56    }
57    /// Initialize the object according to the Glorot Initialization scheme.
58    fn glorot_normal<Sh>(shape: Sh, inputs: usize, outputs: usize) -> Self
59    where
60        A: Float + FromPrimitive,
61        Sh: ShapeBuilder<Dim = D>,
62        StandardNormal: Distribution<A>,
63        Self::Data: DataOwned<Elem = A>,
64    {
65        let distr = XavierNormal::new(inputs, outputs);
66        Self::rand(shape, distr)
67    }
68    /// Initialize the object according to the Glorot Initialization scheme.
69    fn glorot_uniform<Sh>(shape: Sh, inputs: usize, outputs: usize) -> super::UniformResult<Self>
70    where
71        A: Float + FromPrimitive + SampleUniform,
72        Sh: ShapeBuilder<Dim = D>,
73        <A as SampleUniform>::Sampler: Clone,
74        Self::Data: DataOwned<Elem = A>,
75    {
76        let distr = XavierUniform::new(inputs, outputs)?;
77        Ok(Self::rand(shape, distr))
78    }
79    /// Initialize the object according to the Lecun Initialization scheme.
80    /// LecunNormal distributions are truncated [Normal](rand_distr::Normal)
81    /// distributions centered at 0 with a standard deviation equal to the
82    /// square root of the reciprocal of the number of inputs.
83    fn lecun_normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
84    where
85        A: Float,
86        StandardNormal: Distribution<A>,
87        Self::Data: DataOwned<Elem = A>,
88    {
89        let shape = shape.into_shape_with_order();
90        let distr = LecunNormal::new(shape.size());
91        Self::rand(shape, distr)
92    }
93    /// Given a shape, mean, and standard deviation generate a new object using the [Normal](rand_distr::Normal) distribution
94    fn normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
95    where
96        A: Float,
97        StandardNormal: Distribution<A>,
98        Self::Data: DataOwned<Elem = A>,
99    {
100        let distr = Normal::new(mean, std)?;
101        Ok(Self::rand(shape, distr))
102    }
103
104    fn randc<Sh: ShapeBuilder<Dim = D>>(shape: Sh, re: A, im: A) -> Self
105    where
106        ComplexDistribution<A, A>: Distribution<A>,
107        Self::Data: DataOwned<Elem = A>,
108    {
109        let distr = ComplexDistribution::new(re, im);
110        Self::rand(shape, &distr)
111    }
112    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution
113    fn stdnorm<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
114    where
115        StandardNormal: Distribution<A>,
116        Self::Data: DataOwned<Elem = A>,
117    {
118        Self::rand(shape, StandardNormal)
119    }
120    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution with a given seed
121    fn stdnorm_from_seed<Sh: ShapeBuilder<Dim = D>>(shape: Sh, seed: u64) -> Self
122    where
123        StandardNormal: Distribution<A>,
124        Self::Data: DataOwned<Elem = A>,
125    {
126        Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
127    }
128    /// Initialize the object using the [TruncatedNormal](crate::init::distr::TruncatedNormal) distribution
129    fn truncnorm<Sh: ShapeBuilder<Dim = D>>(shape: Sh, mean: A, std: A) -> Result<Self, NormalError>
130    where
131        A: Float,
132        StandardNormal: Distribution<A>,
133        Self::Data: DataOwned<Elem = A>,
134    {
135        let distr = TruncatedNormal::new(mean, std)?;
136        Ok(Self::rand(shape, distr))
137    }
138    /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk)
139    fn uniform<Sh>(shape: Sh, dk: A) -> super::UniformResult<Self>
140    where
141        A: Copy + Neg<Output = A> + SampleUniform,
142        Sh: ShapeBuilder<Dim = D>,
143        <A as SampleUniform>::Sampler: Clone,
144        Self::Data: DataOwned<Elem = A>,
145    {
146        Uniform::new(dk.neg(), dk).map(|distr| Self::rand(shape, distr))
147    }
148
149    fn uniform_from_seed<Sh>(shape: Sh, start: A, stop: A, key: u64) -> super::UniformResult<Self>
150    where
151        A: Clone + SampleUniform,
152        Sh: ShapeBuilder<Dim = D>,
153        <A as SampleUniform>::Sampler: Clone,
154        Self::Data: DataOwned<Elem = A>,
155    {
156        Uniform::new(start, stop)
157            .map(|distr| Self::rand_with(shape, distr, &mut StdRng::seed_from_u64(key)))
158    }
159    /// Generate a random array with values between u(-a, a) where a is the reciprocal of the value at the given axis
160    fn uniform_along<Sh>(shape: Sh, axis: usize) -> super::UniformResult<Self>
161    where
162        A: Copy + Float + SampleUniform,
163        Sh: ShapeBuilder<Dim = D>,
164        <A as SampleUniform>::Sampler: Clone,
165        Self::Data: DataOwned<Elem = A>,
166    {
167        let dim = shape.into_shape_with_order().raw_dim().clone();
168        let dk = A::from(dim[axis]).unwrap().recip();
169        Self::uniform(dim, dk)
170    }
171    /// A [uniform](rand_distr::uniform::Uniform) generator with values between u(-dk, dk)
172    fn uniform_between<Sh>(shape: Sh, a: A, b: A) -> super::UniformResult<Self>
173    where
174        A: Clone + SampleUniform,
175        Sh: ShapeBuilder<Dim = D>,
176        <A as SampleUniform>::Sampler: Clone,
177        Self::Data: DataOwned<Elem = A>,
178    {
179        Uniform::new(a, b).map(|distr| Self::rand(shape, distr))
180    }
181}
182/*
183 ************ Implementations ************
184*/
185
186impl<S, A, D> Initialize<A, D> for ArrayBase<S, D>
187where
188    D: Dimension,
189    S: RawData<Elem = A>,
190{
191    type Data = S;
192
193    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
194    where
195        Ds: Distribution<A>,
196        Sh: ShapeBuilder<Dim = D>,
197        S: DataOwned,
198    {
199        Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
200    }
201
202    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
203    where
204        R: Rng + ?Sized,
205        Ds: Distribution<A>,
206        Sh: ShapeBuilder<Dim = D>,
207        S: DataOwned,
208    {
209        Self::from_shape_simple_fn(shape, move || distr.sample(rng))
210    }
211}
212
213impl<U, A, S, D> InitializeExt<A, D> for U
214where
215    A: Clone,
216    D: Dimension,
217    S: DataOwned<Elem = A>,
218    U: Initialize<A, D, Data = S> + Sized,
219{
220}