use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};

use crate as burn;

#[cfg(not(feature = "std"))]
use num_traits::Float;

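/// Specifies how the parameters of a module are initialized.
///
/// A minimal usage sketch; the backend type `B` and `device` are stand-ins for whatever
/// [`Backend`] and device the surrounding module uses:
///
/// ```rust,ignore
/// let (fan_in, fan_out) = (32, 64);
/// let initializer = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };
/// // Kaiming initialization needs fan information, so `init_with` is used instead of `init`.
/// let weight: Param<Tensor<B, 2>> = initializer.init_with([fan_out, fan_in], Some(fan_in), None, &device);
/// ```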
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills the tensor with the given constant value.
    Constant {
        /// The value to fill the tensor with.
        value: f64,
    },
    /// Fills the tensor with ones.
    Ones,
    /// Fills the tensor with zeros.
    Zeros,
    /// Fills the tensor with values drawn uniformly between `min` and `max`.
    Uniform {
        /// The minimum value.
        min: f64,
        /// The maximum value.
        max: f64,
    },
    /// Fills the tensor with values drawn from a normal distribution.
    Normal {
        /// The mean of the distribution.
        mean: f64,
        /// The standard deviation of the distribution.
        std: f64,
    },
    /// Fills the tensor according to the Kaiming (He) uniform scheme.
    KaimingUniform {
        /// The gain used in the initialization formula.
        gain: f64,
        /// Whether to compute the bound from fan-out only (fan-in otherwise).
        fan_out_only: bool,
    },
    /// Fills the tensor according to the Kaiming (He) normal scheme.
    KaimingNormal {
        /// The gain used in the initialization formula.
        gain: f64,
        /// Whether to compute the standard deviation from fan-out only (fan-in otherwise).
        fan_out_only: bool,
    },
    /// Fills the tensor according to the Xavier (Glorot) uniform scheme.
    XavierUniform {
        /// The gain used in the initialization formula.
        gain: f64,
    },
    /// Fills the tensor according to the Xavier (Glorot) normal scheme.
    XavierNormal {
        /// The gain used in the initialization formula.
        gain: f64,
    },
}

impl Initializer {
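    /// Initializes a parameter of the given shape on the given device, without fan
    /// information. The Kaiming and Xavier variants need fans and will panic once the
    /// parameter is materialized; use [`Initializer::init_with`] for those.
    ///
    /// A minimal sketch, with `B` and `device` standing in for the caller's backend and device:
    ///
    /// ```rust,ignore
    /// let bias: Param<Tensor<B, 1>> = Initializer::Zeros.init([16], &device);
    /// ```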
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

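    /// Initializes a parameter of the given shape on the given device, with explicit
    /// fan-in / fan-out values for the Kaiming and Xavier variants.
    ///
    /// A minimal sketch; the shape, fan values, backend `B`, and `device` are illustrative:
    ///
    /// ```rust,ignore
    /// let (fan_in, fan_out) = (32, 64);
    /// let weight: Param<Tensor<B, 2>> = Initializer::XavierUniform { gain: 1.0 }
    ///     .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device);
    /// ```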
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();

        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                if require_grad {
                    tensor = tensor.require_grad();
                }

                tensor
            },
            device,
            true,
        )
    }

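    /// Creates the actual tensor for this initializer. The Kaiming and Xavier variants
    /// panic if the fan values they need are `None`.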
    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
        }
    }

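    /// Base standard deviation for the Kaiming variants: `1 / sqrt(fan)`, where `fan` is
    /// `fan_out` when `fan_out_only` is set and `fan_in` otherwise. The caller scales it
    /// by `gain` (and by `sqrt(3)` to obtain the uniform bound).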
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

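    /// Standard deviation for the Xavier (Glorot) variants: `sqrt(2 / (fan_in + fan_out))`.
    /// The caller scales it by `gain` (and by `sqrt(3)` to obtain the uniform bound).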
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

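/// Draws a tensor of the given shape from a uniform distribution between `low` and `high`.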
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

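/// Draws a tensor of the given shape from a normal distribution with the given mean and
/// standard deviation.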
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;

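    /// Checks that the column-wise mean and variance of `tensor` (reduced over dim 0) are
    /// within 0.1 of the expected values.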
    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} ± 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} ± 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        TB::seed(0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor.into_data().assert_within_range(min..max);
    }

    #[test]
    fn initializer_normal_init() {
        TB::seed(0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 ± 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 ± 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants
            .sum()
            .to_data()
            .assert_approx_eq(&TensorData::from([value as f32 * 16.0]), 3);
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq(&TensorData::from([0.0]), 3);
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq(&TensorData::from([16.0]), 3);
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        TB::seed(0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_out as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = gain * (6. / (fan_in + fan_out) as f64).sqrt();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }
}