1#![cfg(feature = "rand")]
6
7use crate::params::{LinearParams, ParamMode, ParamsBase};
8use crate::{bias_dim, Linear};
9use concision::init::rand::Rng;
10use concision::init::rand_distr::{uniform::SampleUniform, Distribution, StandardNormal};
11use concision::{Initialize, InitializeExt};
12use nd::*;
13use num::Float;
14
15impl<A, S, D, K> Linear<A, K, D, S>
16where
17 A: Clone + Float,
18 D: RemoveAxis,
19 K: ParamMode,
20 S: DataOwned<Elem = A>,
21 StandardNormal: Distribution<A>,
22{
23 pub fn uniform(self) -> Linear<A, K, D, OwnedRepr<A>>
24 where
25 A: SampleUniform,
26 <A as SampleUniform>::Sampler: Clone,
27 {
28 Linear {
29 config: self.config,
30 params: self.params.uniform(),
31 }
32 }
33}
34
35impl<A, S, D, K> ParamsBase<S, D, K>
36where
37 A: Clone + Float + SampleUniform,
38 D: RemoveAxis,
39 K: ParamMode,
40 S: RawData<Elem = A>,
41 StandardNormal: Distribution<A>,
42 <A as SampleUniform>::Sampler: Clone,
43{
44 pub(crate) fn dk(&self) -> A {
46 A::from(self.in_features()).unwrap().recip()
47 }
48 pub(crate) fn dk_sqrt(&self) -> A {
50 self.dk().sqrt()
51 }
52
53 pub fn uniform(self) -> LinearParams<A, K, D>
54 where
55 S: DataOwned,
56 {
57 let dk = self.dk_sqrt();
58 self.uniform_between(-dk, dk)
59 }
60
61 pub fn uniform_between(self, low: A, high: A) -> LinearParams<A, K, D>
62 where
63 S: DataOwned,
64 {
65 let weight = Array::uniform_between(self.raw_dim(), low, high);
66 let bias = if self.is_biased() && !self.bias.is_some() {
67 let b_dim = bias_dim(self.raw_dim());
68 Some(Array::uniform_between(b_dim, low, high))
69 } else if !self.is_biased() && self.bias.is_some() {
70 None
71 } else {
72 self.bias
73 .as_ref()
74 .map(|b| Array::uniform_between(b.raw_dim(), low, high))
75 };
76 LinearParams {
77 weight,
78 bias,
79 _mode: core::marker::PhantomData::<K>,
80 }
81 }
82}
83
84impl<A, S, D, K> Initialize<A, D> for Linear<A, K, D, S>
85where
86 D: RemoveAxis,
87 K: ParamMode,
88 S: DataOwned<Elem = A>,
89 StandardNormal: Distribution<A>,
90{
91 type Data = OwnedRepr<A>;
92 fn rand<Sh, Ds>(shape: Sh, distr: Ds) -> Self
93 where
94 Sh: ShapeBuilder<Dim = D>,
95 Ds: Clone + Distribution<A>,
96 {
97 Self::from_params(ParamsBase::rand(shape, distr))
98 }
99
100 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
101 where
102 R: Rng + ?Sized,
103 Ds: Clone + Distribution<A>,
104 Sh: ShapeBuilder<Dim = D>,
105 {
106 Self::from_params(ParamsBase::rand_with(shape, distr, rng))
107 }
108
109 fn init_rand<Ds>(self, distr: Ds) -> Self
110 where
111 Ds: Clone + Distribution<A>,
112 Self: Sized,
113 {
114 Self::rand(self.dim(), distr)
115 }
116
117 fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> Self
118 where
119 R: Rng + ?Sized,
120 Ds: Clone + Distribution<A>,
121 {
122 Self::rand_with(self.dim(), distr, rng)
123 }
124}
125
126impl<A, S, D, K> Initialize<A, D> for ParamsBase<S, D, K>
127where
128 D: RemoveAxis,
129 K: ParamMode,
130 S: DataOwned<Elem = A>,
131 StandardNormal: Distribution<A>,
132{
133 type Data = S;
134 fn rand<Sh, Dstr>(shape: Sh, distr: Dstr) -> Self
135 where
136 Sh: ShapeBuilder<Dim = D>,
137 Dstr: Clone + Distribution<A>,
138 {
139 let dim = shape.into_shape().raw_dim().clone();
140 let bias = if K::BIASED {
141 Some(ArrayBase::rand(bias_dim(dim.clone()), distr.clone()))
142 } else {
143 None
144 };
145 Self {
146 weight: ArrayBase::rand(dim, distr),
147 bias,
148 _mode: core::marker::PhantomData::<K>,
149 }
150 }
151
152 fn rand_with<Sh, Ds, R>(shape: Sh, distr: Ds, rng: &mut R) -> Self
153 where
154 R: Rng + ?Sized,
155 S: DataOwned,
156 Ds: Clone + Distribution<A>,
157 Sh: ShapeBuilder<Dim = D>,
158 {
159 let dim = shape.into_shape().raw_dim().clone();
160 let bias = if K::BIASED {
161 Some(ArrayBase::rand_with(
162 bias_dim(dim.clone()),
163 distr.clone(),
164 rng,
165 ))
166 } else {
167 None
168 };
169 Self {
170 weight: ArrayBase::rand_with(dim, distr, rng),
171 bias,
172 _mode: core::marker::PhantomData::<K>,
173 }
174 }
175
176 fn init_rand<Ds>(self, distr: Ds) -> Self
177 where
178 S: DataOwned,
179 Ds: Clone + Distribution<A>,
180 Self: Sized,
181 {
182 Self::rand(self.dim(), distr)
183 }
184
185 fn init_rand_with<Ds, R>(self, distr: Ds, rng: &mut R) -> Self
186 where
187 R: Rng + ?Sized,
188 S: DataOwned,
189 Ds: Clone + Distribution<A>,
190 {
191 Self::rand_with(self.dim(), distr, rng)
192 }
193}