concision_core/init/
initialize.rs

1/*
2    Appellation: initialize <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::distr::*;
6
7use core::ops::Neg;
8use ndarray::{ArrayBase, DataOwned, Dimension, RawData, Shape, ShapeBuilder};
9use num_traits::{Float, FromPrimitive};
10use rand::rngs::{SmallRng, StdRng};
11use rand::{Rng, SeedableRng};
12use rand_distr::uniform::{SampleUniform, Uniform};
13use rand_distr::{Bernoulli, BernoulliError, Distribution, Normal, NormalError, StandardNormal};
14
15fn _extract_xy_from_shape<D>(dim: &D, x: usize, y: usize) -> (usize, usize)
16where
17    D: Dimension,
18{
19    let tmp = dim.as_array_view();
20    let a = tmp.get(x).copied().unwrap_or(1_usize);
21    let b = tmp.get(y).copied().unwrap_or(1_usize);
22    (a, b)
23}
24
25/// This trait provides the base methods required for initializing tensors with random values.
26/// The trait is similar to the `RandomExt` trait provided by the `ndarray_rand` crate,
27/// however, it is designed to be more generic, extensible, and optimized for neural network
28/// initialization routines. [Initialize] is implemented for [`ArrayBase`] as well as
29/// [`ParamsBase`](crate::ParamsBase) allowing you to randomly initialize new tensors and
30/// parameters.
31pub trait Initialize<S, D>: Sized
32where
33    D: Dimension,
34    S: RawData,
35{
36    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
37    where
38        Ds: Distribution<S::Elem>,
39        Sh: ShapeBuilder<Dim = D>,
40        S: DataOwned;
41
42    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
43    where
44        R: Rng + ?Sized,
45        Ds: Distribution<S::Elem>,
46        Sh: ShapeBuilder<Dim = D>,
47        S: DataOwned;
48
49    fn bernoulli<Sh>(shape: Sh, p: f64) -> Result<Self, BernoulliError>
50    where
51        Bernoulli: Distribution<S::Elem>,
52        S: DataOwned,
53        Sh: ShapeBuilder<Dim = D>,
54    {
55        let dist = Bernoulli::new(p)?;
56        Ok(Self::rand(shape, dist))
57    }
58    /// Initialize the object according to the Glorot Initialization scheme.
59    fn glorot_normal<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self
60    where
61        StandardNormal: Distribution<S::Elem>,
62        S: DataOwned,
63        S::Elem: Float + FromPrimitive,
64    {
65        let shape = shape.into_shape_with_order();
66        let (inputs, outputs) = _extract_xy_from_shape(shape.raw_dim(), 0, 1);
67        let distr = XavierNormal::new(inputs, outputs);
68        Self::rand(shape, distr)
69    }
70    /// Initialize the object according to the Glorot Initialization scheme.
71    fn glorot_uniform<Sh>(shape: Sh) -> super::UniformResult<Self>
72    where
73        S: DataOwned,
74        Sh: ShapeBuilder<Dim = D>,
75        S::Elem: Float + FromPrimitive + SampleUniform,
76        <S::Elem as SampleUniform>::Sampler: Clone,
77    {
78        let shape = shape.into_shape_with_order();
79        let (inputs, outputs) = _extract_xy_from_shape(shape.raw_dim(), 0, 1);
80        let distr = XavierUniform::new(inputs, outputs)?;
81        Ok(Self::rand(shape, distr))
82    }
83    /// Initialize the object according to the Lecun Initialization scheme.
84    /// LecunNormal distributions are truncated [Normal](rand_distr::Normal)
85    /// distributions centered at 0 with a standard deviation equal to the
86    /// square root of the reciprocal of the number of inputs.
87    fn lecun_normal<Sh>(shape: Sh) -> Self
88    where
89        StandardNormal: Distribution<S::Elem>,
90        S: DataOwned,
91        Sh: ShapeBuilder<Dim = D>,
92        S::Elem: Float,
93    {
94        let shape = shape.into_shape_with_order();
95        let distr = LecunNormal::new(shape.size());
96        Self::rand(shape, distr)
97    }
98    /// Given a shape, mean, and standard deviation generate a new object using the [Normal](rand_distr::Normal) distribution
99    fn normal<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
100    where
101        StandardNormal: Distribution<S::Elem>,
102        S: DataOwned,
103        Sh: ShapeBuilder<Dim = D>,
104        S::Elem: Float,
105    {
106        let distr = Normal::new(mean, std)?;
107        Ok(Self::rand(shape, distr))
108    }
109    #[cfg(feature = "complex")]
110    fn randc<Sh>(shape: Sh, re: S::Elem, im: S::Elem) -> Self
111    where
112        S: DataOwned,
113        Sh: ShapeBuilder<Dim = D>,
114        num_complex::ComplexDistribution<S::Elem, S::Elem>: Distribution<S::Elem>,
115    {
116        let distr = num_complex::ComplexDistribution::new(re, im);
117        Self::rand(shape, &distr)
118    }
119    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution
120    fn stdnorm<Sh>(shape: Sh) -> Self
121    where
122        StandardNormal: Distribution<S::Elem>,
123        S: DataOwned,
124        Sh: ShapeBuilder<Dim = D>,
125    {
126        Self::rand(shape, StandardNormal)
127    }
128    /// Generate a random array using the [StandardNormal](rand_distr::StandardNormal) distribution with a given seed
129    fn stdnorm_from_seed<Sh>(shape: Sh, seed: u64) -> Self
130    where
131        StandardNormal: Distribution<S::Elem>,
132        S: DataOwned,
133        Sh: ShapeBuilder<Dim = D>,
134    {
135        Self::rand_with(shape, StandardNormal, &mut StdRng::seed_from_u64(seed))
136    }
137    /// Initialize the object using the [TruncatedNormal](crate::init::distr::TruncatedNormal) distribution
138    fn truncnorm<Sh>(shape: Sh, mean: S::Elem, std: S::Elem) -> Result<Self, NormalError>
139    where
140        StandardNormal: Distribution<S::Elem>,
141        S: DataOwned,
142        Sh: ShapeBuilder<Dim = D>,
143        S::Elem: Float,
144    {
145        let distr = TruncatedNormal::new(mean, std)?;
146        Ok(Self::rand(shape, distr))
147    }
148    /// initialize the object using the [`Uniform`] distribution with values bounded by `+/- dk`
149    fn uniform<Sh>(shape: Sh, dk: S::Elem) -> super::UniformResult<Self>
150    where
151        S: DataOwned,
152        Sh: ShapeBuilder<Dim = D>,
153        S::Elem: Clone + Neg<Output = S::Elem> + SampleUniform,
154        <S::Elem as SampleUniform>::Sampler: Clone,
155    {
156        Self::uniform_between(shape, dk.clone().neg(), dk)
157    }
158    /// randomly initialize the object using the [`Uniform`] distribution with values between
159    /// the `start` and `stop` params using some random seed.
160    fn uniform_from_seed<Sh>(
161        shape: Sh,
162        start: S::Elem,
163        stop: S::Elem,
164        key: u64,
165    ) -> super::UniformResult<Self>
166    where
167        S: DataOwned,
168        Sh: ShapeBuilder<Dim = D>,
169        S::Elem: Clone + SampleUniform,
170        <S::Elem as SampleUniform>::Sampler: Clone,
171    {
172        let distr = Uniform::new(start, stop)?;
173        Ok(Self::rand_with(
174            shape,
175            distr,
176            &mut StdRng::seed_from_u64(key),
177        ))
178    }
179    /// initialize the object using the [`Uniform`] distribution with values bounded by the
180    /// size of the specified axis.
181    /// The values are bounded by `+/- dk` where `dk = 1 / size(axis)`.
182    fn uniform_along<Sh>(shape: Sh, axis: usize) -> super::UniformResult<Self>
183    where
184        Sh: ShapeBuilder<Dim = D>,
185        S: DataOwned,
186        S::Elem: Float + FromPrimitive + SampleUniform,
187        <S::Elem as SampleUniform>::Sampler: Clone,
188    {
189        // extract the shape
190        let shape: Shape<D> = shape.into_shape_with_order();
191        let dim: D = shape.raw_dim().clone();
192        let dk = S::Elem::from_usize(dim[axis]).map(|i| i.recip()).unwrap();
193        Self::uniform(dim, dk)
194    }
195    /// initialize the object using the [`Uniform`] distribution with values between then given
196    /// bounds, `a` and `b`.
197    fn uniform_between<Sh>(shape: Sh, a: S::Elem, b: S::Elem) -> super::UniformResult<Self>
198    where
199        Sh: ShapeBuilder<Dim = D>,
200        S: DataOwned,
201        S::Elem: Clone + SampleUniform,
202        <S::Elem as SampleUniform>::Sampler: Clone,
203    {
204        let distr = Uniform::new(a, b)?;
205        Ok(Self::rand(shape, distr))
206    }
207}
208/*
209 ************ Implementations ************
210*/
211
212impl<A, S, D> Initialize<S, D> for ArrayBase<S, D>
213where
214    D: Dimension,
215    S: RawData<Elem = A>,
216{
217    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
218    where
219        Ds: Distribution<S::Elem>,
220        Sh: ShapeBuilder<Dim = D>,
221        S: DataOwned,
222    {
223        Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
224    }
225
226    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
227    where
228        R: Rng + ?Sized,
229        Ds: Distribution<S::Elem>,
230        Sh: ShapeBuilder<Dim = D>,
231        S: DataOwned,
232    {
233        Self::from_shape_simple_fn(shape, move || distr.sample(rng))
234    }
235}