use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor, s};

use crate as burn;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

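/// Specifies with what values a parameter tensor should be initialized.
///
/// Illustrative usage (a minimal sketch; the backend `B`, the `device`, and the
/// `fan_in`/`fan_out` sizes are assumed to be defined elsewhere):
///
/// ```ignore
/// let initializer = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };
/// let weight: Param<Tensor<B, 2>> =
///     initializer.init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device);
/// ```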
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills the tensor with the specified value everywhere.
    Constant {
        /// The value to fill the tensor with.
        value: f64,
    },
    /// Fills the tensor with ones.
    Ones,
    /// Fills the tensor with zeros.
    Zeros,
    /// Fills the tensor with values drawn uniformly between `min` and `max`.
    Uniform {
        /// The minimum value to draw from.
        min: f64,
        /// The maximum value to draw from.
        max: f64,
    },
    /// Fills the tensor with values drawn from a normal distribution with the
    /// specified mean and standard deviation.
    Normal {
        /// The mean of the normal distribution.
        mean: f64,
        /// The standard deviation of the normal distribution.
        std: f64,
    },
    /// Fills the tensor with values according to the uniform version of Kaiming initialization.
    KaimingUniform {
        /// The gain to use in the initialization formula.
        gain: f64,
        /// Whether to use fan out (instead of fan in) in the initialization formula.
        fan_out_only: bool,
    },
    /// Fills the tensor with values according to the normal version of Kaiming initialization.
    KaimingNormal {
        /// The gain to use in the initialization formula.
        gain: f64,
        /// Whether to use fan out (instead of fan in) in the initialization formula.
        fan_out_only: bool,
    },
    /// Fills the tensor with values according to the uniform version of Xavier Glorot initialization.
    XavierUniform {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
    /// Fills the tensor with values according to the normal version of Xavier Glorot initialization.
    XavierNormal {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
    /// Fills the tensor with a (semi-)orthogonal matrix obtained from the QR
    /// decomposition of a random normal draw.
    Orthogonal {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
}

impl Initializer {
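    /// Inits a tensor parameter of the given shape, with values depending on the initializer kind.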
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

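    /// Inits a tensor parameter of the given shape, providing the fan in and/or fan out
    /// needed by the fan-dependent initializers (Kaiming, Xavier).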
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();
        let shape_for_closure = shape.clone();

        // Initialization is deferred: the closure below runs when the parameter
        // is first materialized.
        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                B::memory_persistent_allocations(device, (), move |_| {
                    let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                    if require_grad {
                        tensor = tensor.require_grad();
                    }

                    tensor
                })
            },
            device,
            true,
            shape_for_closure,
        )
    }

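    /// Draws the tensor values for the given shape according to the initializer kind.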
    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::Orthogonal { gain } => {
                assert!(
                    D >= 2,
                    "Expected D (in Tensor<B, D>) to be greater than or equal to 2 (D >= 2)"
                );

                // Flatten the requested shape into a 2D matrix and orthogonalize
                // a standard normal draw via QR decomposition.
                let rows: usize = shape.dims::<D>()[0];
                let cols: usize = shape.num_elements() / rows;

                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);

                // Gram-Schmidt QR needs at least as many rows as columns.
                if rows < cols {
                    t = t.transpose();
                }

                let (q, r) = qr_decomposition(t, device);
                let [r_rows, r_cols] = r.dims();

                // Extract the diagonal of `r` as a `1 x n` row vector: mask `r` with the
                // identity, then sum each column by left-multiplying with a row of ones.
                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r));

                // Multiply `q` by the signs of `r`'s diagonal to resolve the sign
                // ambiguity of the QR decomposition.
                let ph = diag_r.sign();

                let mut q = q.mul(ph);

                if rows < cols {
                    q = q.transpose();
                }

                q.reshape(shape).mul_scalar(*gain)
            }
        }
    }

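    /// Standard deviation term for Kaiming initialization: `1 / sqrt(fan)`, where `fan`
    /// is the fan in (or the fan out when `fan_out_only` is set). The gain is applied by
    /// the caller.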
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

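    /// Standard deviation term for Xavier Glorot initialization:
    /// `sqrt(2 / (fan_in + fan_out))`. The gain is applied by the caller.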
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

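/// Draws a random tensor of the given shape from a uniform distribution between `low` and `high`.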
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

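/// Draws a random tensor of the given shape from a normal distribution with the given
/// mean and standard deviation.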
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

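/// Computes a QR decomposition of `a` (`m x n`, with `m >= n`) via classical Gram-Schmidt,
/// returning `(q, r)` such that `a = q * r`, where `q` has orthonormal columns and `r` is
/// upper triangular. Used by [`Initializer::Orthogonal`].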
fn qr_decomposition<B: Backend>(
    a: Tensor<B, 2>,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let [m, n] = a.dims();
    let mut q = Tensor::<B, 2>::zeros([m, n], device);
    let mut r = Tensor::<B, 2>::zeros([n, n], device);

    for j in 0..n {
        // Start from the j-th column of `a` and remove its projections onto the
        // already-computed orthonormal columns of `q`.
        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);

        for i in 0..j {
            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
            let r_ij = q_i.clone().mul(v.clone()).sum();

            r = r
                .clone()
                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());

            v = v - q_i.mul(r_ij);
        }

        // Normalize the residual to obtain the next orthonormal column.
        let r_jj = v.clone().powf_scalar(2.0).sum().sqrt();

        r = r
            .clone()
            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());

        let q_j = v / r_jj;

        q = q
            .clone()
            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
    }

    (q, r)
}

#[cfg(test)]
mod tests {
    use super::*;

    use burn_tensor::{ElementConversion, TensorData};
    use burn_tensor::{Tolerance, ops::FloatElem};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    type FT = FloatElem<TB>;

    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} ± 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} ± 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &device).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &device)
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 ± 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 ± 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &device)
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &device)
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        let device = Default::default();
        TB::seed(&device, 0);

        // Checks that the factors reconstruct the input: q * r == a.
        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &device,
        );
        let (q, r) = qr_decomposition(a.clone(), &device);

        let q_matmul_r = q.matmul(r);

        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // For an orthogonal q, q^T * q should be the identity.
        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &device)
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &device);

        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the output tensor to match the requested shape. ({shape:?}, {dims:?})"
        );

        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the output tensor to match the requested shape. ({shape:?}, {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
    }
}