concision_core/init/
initialize.rs

1/*
2    Appellation: initialize <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::distr::*;
6
7use core::ops::Neg;
8use ndarray::{ArrayBase, DataOwned, Dimension, RawData, Shape, ShapeBuilder};
9use num_traits::{Float, FromPrimitive};
10use rand::rngs::{SmallRng, StdRng};
11use rand::{Rng, SeedableRng};
12use rand_distr::uniform::{SampleUniform, Uniform};
13use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
14
15/// This trait provides the base methods required for initializing tensors with random values.
16/// The trait is similar to the `RandomExt` trait provided by the `ndarray_rand` crate,
17/// however, it is designed to be more generic, extensible, and optimized for neural network
18/// initialization routines. [Initialize] is implemented for [`ArrayBase`] as well as
19/// [`ParamsBase`](crate::ParamsBase) allowing you to randomly initialize new tensors and
20/// parameters.
21pub trait Initialize<S, D>: Sized
22where
23    D: Dimension,
24    S: RawData,
25{
26    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
27    where
28        Ds: Distribution<S::Elem>,
29        Sh: ShapeBuilder<Dim = D>,
30        S: DataOwned;
31
32    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
33    where
34        R: Rng + ?Sized,
35        Ds: Distribution<S::Elem>,
36        Sh: ShapeBuilder<Dim = D>,
37        S: DataOwned;
38
39    fn bernoulli<Sh>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
40    where
41        Bernoulli: Distribution<S::Elem>,
42        S: DataOwned,
43        Sh: ShapeBuilder<Dim = D>,
44    {
45        let dist = Bernoulli::new(p)?;
46        Ok(Self::rand(shape, dist))
47    }
48    /// Initialize the object according to the Glorot Initialization scheme.
49    fn glorot_normal<Sh>(shape: Sh, inputs: usize, outputs: usize) -> Self
50    where
51        S::Elem: Float + FromPrimitive,
52        StandardNormal: Distribution<S::Elem>,
53        S: DataOwned,
54        Sh: ShapeBuilder<Dim = D>,
55    {
56        let distr = XavierNormal::new(inputs, outputs);
57        Self::rand(shape, distr)
58    }
59    /// Initialize the object according to the Glorot Initialization scheme.
60    fn glorot_uniform<Sh>(shape: Sh, inputs: usize, outputs: usize) -> super::UniformResult<Self>
61    where
62        S: DataOwned,
63        Sh: ShapeBuilder<Dim = D>,
64        S::Elem: Float + FromPrimitive + SampleUniform,
65        <S::Elem as SampleUniform>::Sampler: Clone,
66    {
67        let distr = XavierUniform::new(inputs, outputs)?;
68        Ok(Self::rand(shape, distr))
69    }
70    /// Initialize the object according to the Lecun Initialization scheme.
71    /// LecunNormal distributions are truncated [Normal](rand_distr::Normal)
72    /// distributions centered at 0 with a standard deviation equal to the
73    /// square root of the reciprocal of the number of inputs.
74    fn lecun_normal<Sh>(shape: Sh) -> Self
75    where
76        StandardNormal: Distribution<S::Elem>,
77        S: DataOwned,
78        Sh: ShapeBuilder<Dim = D>,
79        S::Elem: Float,
80    {
81        let shape = shape.into_shape_with_order();
82        let distr = LecunNormal::new(shape.size());
83        Self::rand(shape, distr)
84    }
85    /// Given a shape, mean, and standard deviation generate a new object using the [Normal](rand_distr::Normal) distribution
86    fn normal<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
87    where
88        StandardNormal: Distribution<S::Elem>,
89        S: DataOwned,
90        Sh: ShapeBuilder<Dim = D>,
91        S::Elem: Float,
92    {
93        let distr = Normal::new(mean, std)?;
94        Ok(Self::rand(shape, distr))
95    }
96    #[cfg(feature = "complex")]
97    fn randc<Sh>(shape: Sh, re: S::Elem, im: S::Elem) -> Self
98    where
99        S: DataOwned,
100        Sh: ShapeBuilder<Dim = D>,
101        num_complex::ComplexDistribution<S::Elem, S::Elem>: Distribution<S::Elem>,
102    {
103        let distr = num_complex::ComplexDistribution::new(re, im);
104        Self::rand(shape, &distr)
105    }
106    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution
107    fn stdnorm<Sh>(shape: Sh) -> Self
108    where
109        StandardNormal: Distribution<S::Elem>,
110        S: DataOwned,
111        Sh: ShapeBuilder<Dim = D>,
112    {
113        Self::rand(shape, StandardNormal)
114    }
115    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution with a given seed
116    fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
117    where
118        StandardNormal: Distribution<S::Elem>,
119        S: DataOwned,
120        Sh: ShapeBuilder<Dim = D>,
121    {
122        Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
123    }
124    /// Initialize the object using the [TruncatedNormal](crate::init::distr::TruncatedNormal) distribution
125    fn truncnorm<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
126    where
127        StandardNormal: Distribution<S::Elem>,
128        S: DataOwned,
129        Sh: ShapeBuilder<Dim = D>,
130        S::Elem: Float,
131    {
132        let distr = TruncatedNormal::new(mean, std)?;
133        Ok(Self::rand(shape, distr))
134    }
135    /// initialize the object using the [`Uniform`] distribution with values bounded by `+/- dk`
136    fn uniform<Sh>(shape: Sh, dk: S::Elem) -> super::UniformResult<Self>
137    where
138        S: DataOwned,
139        Sh: ShapeBuilder<Dim = D>,
140        S::Elem: Clone + Neg<Output = S::Elem> + SampleUniform,
141        <S::Elem as SampleUniform>::Sampler: Clone,
142    {
143        Self::uniform_between(shape, dk.clone().neg(), dk)
144    }
145    /// randomly initialize the object using the [`Uniform`] distribution with values between
146    /// the `start` and `stop` params using some random seed.
147    fn uniform_from_seed<Sh>(
148        shape: Sh,
149        start: S::Elem,
150        stop: S::Elem,
151        key: u64,
152    ) -> super::UniformResult<Self>
153    where
154        S: DataOwned,
155        Sh: ShapeBuilder<Dim = D>,
156        S::Elem: Clone + SampleUniform,
157        <S::Elem as SampleUniform>::Sampler: Clone,
158    {
159        let distr = Uniform::new(start, stop)?;
160        Ok(Self::rand_with(
161            shape,
162            distr,
163            &mut StdRng::seed_from_u64(key),
164        ))
165    }
166    /// initialize the object using the [`Uniform`] distribution with values bounded by the
167    /// size of the specified axis.
168    /// The values are bounded by `+/- dk` where `dk = 1 / size(axis)`.
169    fn uniform_along<Sh>(shape: Sh, axis: usize) -> super::UniformResult<Self>
170    where
171        Sh: ShapeBuilder<Dim = D>,
172        S: DataOwned,
173        S::Elem: Float + FromPrimitive + SampleUniform,
174        <S::Elem as SampleUniform>::Sampler: Clone,
175    {
176        // extract the shape
177        let shape: Shape<D> = shape.into_shape_with_order();
178        let dim: D = shape.raw_dim().clone();
179        let dk = S::Elem::from_usize(dim[axis]).map(|i| i.recip()).unwrap();
180        Self::uniform(dim, dk)
181    }
182    /// initialize the object using the [`Uniform`] distribution with values between then given
183    /// bounds, `a` and `b`.
184    fn uniform_between<Sh>(shape: Sh, a: S::Elem, b: S::Elem) -> super::UniformResult<Self>
185    where
186        Sh: ShapeBuilder<Dim = D>,
187        S: DataOwned,
188        S::Elem: Clone + SampleUniform,
189        <S::Elem as SampleUniform>::Sampler: Clone,
190    {
191        let distr = Uniform::new(a, b)?;
192        Ok(Self::rand(shape, distr))
193    }
194}
195/*
196 ************ Implementations ************
197*/
198
199impl<A, S, D> Initialize<S, D> for ArrayBase<S, D>
200where
201    D: Dimension,
202    S: RawData<Elem = A>,
203{
204    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
205    where
206        Ds: Distribution<S::Elem>,
207        Sh: ShapeBuilder<Dim = D>,
208        S: DataOwned,
209    {
210        Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
211    }
212
213    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
214    where
215        R: Rng + ?Sized,
216        Ds: Distribution<S::Elem>,
217        Sh: ShapeBuilder<Dim = D>,
218        S: DataOwned,
219    {
220        Self::from_shape_simple_fn(shape, move || distr.sample(rng))
221    }
222}