concision_core/params/impls/
impl_params_init.rs1use crate::params::ParamsBase;
6
7use crate::init::Initialize;
8use ndarray::{
9 ArrayBase, Axis, DataOwned, Dimension, RawData, RemoveAxis, ScalarOperand, ShapeBuilder,
10};
11use num_traits::{Float, FromPrimitive};
12use rand::rngs::SmallRng;
13use rand_distr::Distribution;
14
15impl<A, S, D> ParamsBase<S, D>
16where
17 A: Float + FromPrimitive + ScalarOperand,
18 D: Dimension,
19 S: RawData<Elem = A>,
20{
21 pub fn init_rand<G, Dst, Sh>(shape: Sh, distr: G) -> Self
24 where
25 D: RemoveAxis,
26 S: DataOwned,
27 Sh: ShapeBuilder<Dim = D>,
28 Dst: Clone + Distribution<A>,
29 G: Fn(&Sh) -> Dst,
30 {
31 let dist = distr(&shape);
32 Self::rand(shape, dist)
33 }
34}
35
36impl<A, S, D> Initialize<S, D> for ParamsBase<S, D>
37where
38 D: RemoveAxis,
39 S: RawData<Elem = A>,
40{
41 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
42 where
43 Ds: Distribution<A>,
44 Sh: ShapeBuilder<Dim = D>,
45 S: DataOwned,
46 {
47 use rand::SeedableRng;
48 Self::rand_with(shape, distr, &mut SmallRng::from_rng(&mut rand::rng()))
49 }
50
51 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
52 where
53 R: rand::RngCore + ?Sized,
54 Ds: Distribution<A>,
55 Sh: ShapeBuilder<Dim = D>,
56 S: DataOwned,
57 {
58 let shape = shape.into_shape_with_order();
59 let bias_shape = shape.raw_dim().remove_axis(Axis(0));
60 let bias = ArrayBase::from_shape_fn(bias_shape, |_| distr.sample(rng));
61 let weights = ArrayBase::from_shape_fn(shape, |_| distr.sample(rng));
62 Self { bias, weights }
63 }
64}