use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor, s};

use crate as burn;

#[cfg(not(feature = "std"))]
use num_traits::Float;

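/// Strategy used to initialize a parameter tensor.
///
/// A minimal usage sketch; the `burn::nn::Initializer` import path assumes
/// the crate's usual re-exports:
///
/// ```ignore
/// use burn::nn::Initializer;
///
/// let init = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };
/// ```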
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills the tensor with the specified value.
    Constant {
        /// The value to fill the tensor with.
        value: f64,
    },
    /// Fills the tensor with ones.
    Ones,
    /// Fills the tensor with zeros.
    Zeros,
    /// Fills the tensor with values drawn uniformly from `[min, max]`.
    Uniform {
        /// The minimum of the range.
        min: f64,
        /// The maximum of the range.
        max: f64,
    },
    /// Fills the tensor with values drawn from a normal distribution with the
    /// specified mean and standard deviation.
    Normal {
        /// The mean of the distribution.
        mean: f64,
        /// The standard deviation of the distribution.
        std: f64,
    },
    /// Kaiming (He) uniform initialization.
    KaimingUniform {
        /// The gain applied in the initialization formula.
        gain: f64,
        /// Whether to use the fan-out instead of the fan-in.
        fan_out_only: bool,
    },
    /// Kaiming (He) normal initialization.
    KaimingNormal {
        /// The gain applied in the initialization formula.
        gain: f64,
        /// Whether to use the fan-out instead of the fan-in.
        fan_out_only: bool,
    },
    /// Xavier (Glorot) uniform initialization.
    XavierUniform {
        /// The gain applied in the initialization formula.
        gain: f64,
    },
    /// Xavier (Glorot) normal initialization.
    XavierNormal {
        /// The gain applied in the initialization formula.
        gain: f64,
    },
    /// Orthogonal initialization in the style of Saxe et al. (2013),
    /// <https://arxiv.org/abs/1312.6120>.
    Orthogonal {
        /// The gain the orthogonal matrix is scaled by.
        gain: f64,
    },
}

impl Initializer {
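    /// Initialize a new parameter tensor of the given shape on the given
    /// device, without specifying the fan-in or fan-out.
    ///
    /// A minimal sketch, assuming some backend `B` and a `device` in scope:
    ///
    /// ```ignore
    /// let weight: Param<Tensor<B, 2>> = Initializer::Zeros.init([3, 3], &device);
    /// ```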
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

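    /// Initialize a new parameter tensor of the given shape on the given
    /// device, providing the fan-in and/or fan-out required by the Kaiming
    /// and Xavier strategies.
    ///
    /// A minimal sketch, assuming some backend `B` and a `device` in scope:
    ///
    /// ```ignore
    /// let init = Initializer::XavierUniform { gain: 1.0 };
    /// let weight: Param<Tensor<B, 2>> =
    ///     init.init_with([6, 5], Some(5), Some(6), &device);
    /// ```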
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();

        // The parameter is created uninitialized: the closure produces the
        // tensor on demand.
        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                if require_grad {
                    tensor = tensor.require_grad();
                }

                tensor
            },
            device,
            true,
        )
    }

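    /// Draw a tensor of the given shape according to the configured strategy.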
    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::Orthogonal { gain } => {
                assert!(
                    D >= 2,
                    "Expected D (in Tensor<B, D>) to be greater than or equal to 2 (D >= 2)"
                );

                // Flatten the requested shape into a 2D matrix and fill it
                // with standard normal samples.
                let rows: usize = shape.dims::<D>()[0];
                let cols: usize = shape.num_elements() / rows;

                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);

                // The QR decomposition below expects rows >= cols, so wide
                // matrices are transposed first.
                if rows < cols {
                    t = t.transpose();
                }

                let (q, r) = qr_decomposition(t, device);
                let [r_rows, r_cols] = r.clone().dims();

                // Extract the diagonal of `r` as a [1, r_cols] row vector.
                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));

                // Multiply each column of `q` by the sign of the matching
                // diagonal entry of `r` to make the decomposition unique.
                let ph = diag_r.clone().sign();

                let mut q = q.mul(ph);

                if rows < cols {
                    q = q.transpose();
                }

                q.reshape(shape).mul_scalar(*gain)
            }
        }
    }

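    /// Reciprocal square root of the fan: `1 / sqrt(fan)`.
    ///
    /// The caller multiplies by the gain, yielding the Kaiming standard
    /// deviation `gain / sqrt(fan)` (and `gain * sqrt(3 / fan)` as the bound
    /// of the uniform variant).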
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying the fan. Use the init_with \
             method.",
        );

        1.0 / (fan as f64).sqrt()
    }

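    /// Xavier (Glorot) standard deviation without the gain:
    /// `sqrt(2 / (fan_in + fan_out))`. The caller multiplies by the gain.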
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan_in. Use the init_with \
             method and provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan_out. Use the init_with \
             method and provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

/// Draw a tensor of the given shape from a uniform distribution over
/// `[low, high]`.
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

/// Draw a tensor of the given shape from a normal distribution with the given
/// mean and standard deviation.
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

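/// Thin QR decomposition of `a` via modified Gram-Schmidt: returns `(q, r)`
/// where `q` is `[m, n]` with orthonormal columns, `r` is `[n, n]` upper
/// triangular, and `a = q.matmul(r)`.
///
/// Assumes `m >= n`; callers transpose wide matrices first (see the
/// `Orthogonal` branch above).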
fn qr_decomposition<B: Backend>(
    a: Tensor<B, 2>,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let [m, n] = a.clone().dims();
    let mut q = Tensor::<B, 2>::zeros([m, n], device);
    let mut r = Tensor::<B, 2>::zeros([n, n], device);

    for j in 0..n {
        // Start from the j-th column of `a` and remove its components along
        // the orthonormal columns computed so far.
        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze(1);

        for i in 0..j {
            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze(1);
            let r_ij = q_i.clone().mul(v.clone()).sum();

            r = r
                .clone()
                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());

            v = v - q_i.mul(r_ij);
        }

        // The norm of the residual becomes the diagonal entry r[j, j].
        let r_jj = v.clone().powf_scalar(2.0).sum().sqrt();

        r = r
            .clone()
            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());

        // Normalizing the residual yields the j-th column of `q`.
        let q_j = v / r_jj;

        q = q
            .clone()
            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
    }

    (q, r)
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TB>;

    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} ± 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} ± 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        TB::seed(0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        TB::seed(0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 ± 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 ± 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        TB::seed(0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        TB::seed(0);

        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &Default::default(),
        );
        let qr = qr_decomposition(a.clone(), &Default::default());

        let q_matmul_r = qr.0.clone().matmul(qr.1.clone());

        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        TB::seed(0);

        let gain = 1.;

        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &Default::default())
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &Default::default());

        // For an orthogonal matrix, Q^T Q should be the identity.
        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        TB::seed(0);

        let gain = 1.;

        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the output tensor's shape to match the requested shape ({shape:?} vs {dims:?})"
        );

        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the output tensor's shape to match the requested shape ({shape:?} vs {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        TB::seed(0);
        let gain = 1.;

        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
    }
}