concision_linear/impls/
impl_rand.rs

1/*
2    Appellation: impl_rand <impls>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5#![cfg(feature = "rand")]
6
7use crate::params::{LinearParams, ParamMode, ParamsBase};
8use crate::{bias_dim, Linear};
9use concision::init::rand::Rng;
10use concision::init::rand_distr::{uniform::SampleUniform, Distribution, StandardNormal};
11use concision::{Initialize, InitializeExt};
12use nd::*;
13use num::Float;
14
15impl<A, S, D, K> Linear<A, K, D, S>
16where
17    A: Clone + Float,
18    D: RemoveAxis,
19    K: ParamMode,
20    S: DataOwned<Elem = A>,
21    StandardNormal: Distribution<A>,
22{
23    pub fn uniform(self) -> Linear<A, K, D, OwnedRepr<A>>
24    where
25        A: SampleUniform,
26        <A as SampleUniform>::Sampler: Clone,
27    {
28        Linear {
29            config: self.config,
30            params: self.params.uniform(),
31        }
32    }
33}
34
35impl<A, S, D, K> ParamsBase<S, D, K>
36where
37    A: Clone + Float + SampleUniform,
38    D: RemoveAxis,
39    K: ParamMode,
40    S: RawData<Elem = A>,
41    StandardNormal: Distribution<A>,
42    <A as SampleUniform>::Sampler: Clone,
43{
44    /// Computes the reciprocal of the input features.
45    pub(crate) fn dk(&self) -> A {
46        A::from(self.in_features()).unwrap().recip()
47    }
48    /// Computes the square root of the reciprical of the input features.
49    pub(crate) fn dk_sqrt(&self) -> A {
50        self.dk().sqrt()
51    }
52
53    pub fn uniform(self) -> LinearParams<A, K, D>
54    where
55        S: DataOwned,
56    {
57        let dk = self.dk_sqrt();
58        self.uniform_between(-dk, dk)
59    }
60
61    pub fn uniform_between(self, low: A, high: A) -> LinearParams<A, K, D>
62    where
63        S: DataOwned,
64    {
65        let weight = Array::uniform_between(self.raw_dim(), low, high);
66        let bias = if self.is_biased() && !self.bias.is_some() {
67            let b_dim = bias_dim(self.raw_dim());
68            Some(Array::uniform_between(b_dim, low, high))
69        } else if !self.is_biased() && self.bias.is_some() {
70            None
71        } else {
72            self.bias
73                .as_ref()
74                .map(|b| Array::uniform_between(b.raw_dim(), low, high))
75        };
76        LinearParams {
77            weight,
78            bias,
79            _mode: core::marker::PhantomData::<K>,
80        }
81    }
82}
83
84impl<A, S, D, K> Initialize<A, D> for Linear<A, K, D, S>
85where
86    D: RemoveAxis,
87    K: ParamMode,
88    S: DataOwned<Elem = A>,
89    StandardNormal: Distribution<A>,
90{
91    type Data = OwnedRepr<A>;
92    fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
93    where
94        Sh: ShapeBuilder<Dim = D>,
95        Ds: Clone + Distribution<A>,
96    {
97        Self::from_params(ParamsBase::rand(shape, distr))
98    }
99
100    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
101    where
102        R: Rng + ?Sized,
103        Ds: Clone + Distribution<A>,
104        Sh: ShapeBuilder<Dim = D>,
105    {
106        Self::from_params(ParamsBase::rand_with(shape, distr, rng))
107    }
108
109    fn init_rand<Ds>(self, distr: Ds) -> Self
110    where
111        Ds: Clone + Distribution<A>,
112        Self: Sized,
113    {
114        Self::rand(self.dim(), distr)
115    }
116
117    fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> Self
118    where
119        R: Rng + ?Sized,
120        Ds: Clone + Distribution<A>,
121    {
122        Self::rand_with(self.dim(), distr, rng)
123    }
124}
125
126impl<A, S, D, K> Initialize<A, D> for ParamsBase<S, D, K>
127where
128    D: RemoveAxis,
129    K: ParamMode,
130    S: DataOwned<Elem = A>,
131    StandardNormal: Distribution<A>,
132{
133    type Data = S;
134    fn rand<Sh, Dstr>(shape: Sh, distr: Dstr) -> Self
135    where
136        Sh: ShapeBuilder<Dim = D>,
137        Dstr: Clone + Distribution<A>,
138    {
139        let dim = shape.into_shape().raw_dim().clone();
140        let bias = if K::BIASED {
141            Some(ArrayBase::rand(bias_dim(dim.clone()), distr.clone()))
142        } else {
143            None
144        };
145        Self {
146            weight: ArrayBase::rand(dim, distr),
147            bias,
148            _mode: core::marker::PhantomData::<K>,
149        }
150    }
151
152    fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
153    where
154        R: Rng + ?Sized,
155        S: DataOwned,
156        Ds: Clone + Distribution<A>,
157        Sh: ShapeBuilder<Dim = D>,
158    {
159        let dim = shape.into_shape().raw_dim().clone();
160        let bias = if K::BIASED {
161            Some(ArrayBase::rand_with(
162                bias_dim(dim.clone()),
163                distr.clone(),
164                rng,
165            ))
166        } else {
167            None
168        };
169        Self {
170            weight: ArrayBase::rand_with(dim, distr, rng),
171            bias,
172            _mode: core::marker::PhantomData::<K>,
173        }
174    }
175
176    fn init_rand<Ds>(self, distr: Ds) -> Self
177    where
178        S: DataOwned,
179        Ds: Clone + Distribution<A>,
180        Self: Sized,
181    {
182        Self::rand(self.dim(), distr)
183    }
184
185    fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> Self
186    where
187        R: Rng + ?Sized,
188        S: DataOwned,
189        Ds: Clone + Distribution<A>,
190    {
191        Self::rand_with(self.dim(), distr, rng)
192    }
193}