concision_core/params/impls/
impl_params_init.rs

1/*
2    appellation: impl_params_init <module>
3    authors: @FL03
4*/
5use crate::params::ParamsBase;
6
7use crate::init::Initialize;
8use ndarray::{
9    ArrayBase, Axis, DataOwned, Dimension, RawData, RemoveAxis, ScalarOperand, ShapeBuilder,
10};
11use num_traits::{Float, FromPrimitive};
12use rand::rngs::SmallRng;
13use rand_distr::Distribution;
14
15impl<A, S, D> ParamsBase<S, D>
16where
17    A: Float + FromPrimitive + ScalarOperand,
18    D: Dimension,
19    S: RawData<Elem = A>,
20{
21    /// generates a randomly initialized set of parameters with the given shape using the
22    /// output of the given distribution function `G`
23    pub fn init_rand<G, Dst, Sh>(shape: Sh, distr: G) -> Self
24    where
25        D: RemoveAxis,
26        S: DataOwned,
27        Sh: ShapeBuilder<Dim = D>,
28        Dst: Clone + Distribution<A>,
29        G: Fn(&Sh) -> Dst,
30    {
31        let dist = distr(&shape);
32        Self::rand(shape, dist)
33    }
34}
35
36impl<A, S, D> Initialize<S, D> for ParamsBase<S, D>
37where
38    D: RemoveAxis,
39    S: RawData<Elem = A>,
40{
41    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
42    where
43        Ds: Distribution<A>,
44        Sh: ShapeBuilder<Dim = D>,
45        S: DataOwned,
46    {
47        use rand::SeedableRng;
48        Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
49    }
50
51    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
52    where
53        R: rand::RngCore + ?Sized,
54        Ds: Distribution<A>,
55        Sh: ShapeBuilder<Dim = D>,
56        S: DataOwned,
57    {
58        let shape = shape.into_shape_with_order();
59        let bias_shape = shape.raw_dim().remove_axis(Axis(0));
60        let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
61        let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
62        Self { bias, weights }
63    }
64}