use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Shape, Tensor, s};

use crate as burn;

// In `no_std` builds, float methods such as `sqrt` come from `num_traits`.
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Enum specifying with what values a tensor should be initialized.
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills the tensor with the specified value everywhere.
    Constant {
        /// The value to fill the tensor with.
        value: f64,
    },
    /// Fills the tensor with ones everywhere.
    Ones,
    /// Fills the tensor with zeros everywhere.
    Zeros,
    /// Fills the tensor with values drawn uniformly between the specified bounds.
    Uniform {
        /// The minimum value to draw from.
        min: f64,

        /// The maximum value to draw from.
        max: f64,
    },
    /// Fills the tensor with values drawn from a normal distribution with the
    /// specified mean and standard deviation.
    Normal {
        /// The mean of the normal distribution.
        mean: f64,

        /// The standard deviation of the normal distribution.
        std: f64,
    },
    /// Fills the tensor with values according to the uniform version of Kaiming initialization.
    KaimingUniform {
        /// The gain to use in the initialization formula.
        gain: f64,

        /// Whether to use fan out only in the initialization formula.
        fan_out_only: bool,
    },
    /// Fills the tensor with values according to the normal version of Kaiming initialization.
    KaimingNormal {
        /// The gain to use in the initialization formula.
        gain: f64,

        /// Whether to use fan out only in the initialization formula.
        fan_out_only: bool,
    },
    /// Fills the tensor with values according to the uniform version of Xavier Glorot
    /// initialization described in [Understanding the difficulty of training deep feedforward
    /// neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf).
    XavierUniform {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
    /// Fills the tensor with values according to the normal version of Xavier Glorot
    /// initialization described in [Understanding the difficulty of training deep feedforward
    /// neural networks](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf).
    XavierNormal {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
    /// Fills the tensor with a (semi-)orthogonal matrix as described in
    /// [Exact solutions to the nonlinear dynamics of learning in deep linear neural
    /// networks](https://arxiv.org/abs/1312.6120).
    Orthogonal {
        /// The gain to use in the initialization formula.
        gain: f64,
    },
}
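
// Construction sketch: the variants are plain struct-like variants, so an
// initializer is built directly (the values below are illustrative only):
//
//     let normal = Initializer::Normal { mean: 0.0, std: 0.02 };
//     let kaiming = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };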

impl Initializer {
    /// Inits a tensor parameter of the given shape with values depending on the initializer kind.
    ///
    /// # Params
    ///
    /// - shape: Shape of the tensor to initialize.
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

    /// Inits a tensor parameter of the given shape with values depending on the initializer kind,
    /// with the possibility of specifying fan in and fan out.
    ///
    /// # Params
    ///
    /// - shape: Shape of the tensor to initialize.
    /// - fan_in: Fan in to use in the initialization formula, if needed.
    /// - fan_out: Fan out to use in the initialization formula, if needed.
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();
        let shape_for_closure = shape.clone();

        // Initialization is lazy: the closure only runs when the parameter is
        // first materialized, on whatever device it is materialized on.
        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                let config = config.clone();
                let shape = shape.clone();
                B::memory_persistent_allocations(device, (), move |_| {
                    let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                    if require_grad {
                        tensor = tensor.require_grad();
                    }

                    tensor
                })
            },
            device,
            true,
            shape_for_closure,
        )
    }
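
    // Usage sketch (assuming a backend `B` and a matching `device` are in
    // scope; the shape and fan values are illustrative only):
    //
    //     let init = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };
    //     let weight: Param<Tensor<B, 2>> =
    //         init.init_with([6, 5], Some(5), Some(6), &device);
    //
    // `init` alone suffices for initializers that need no fan information:
    //
    //     let bias: Param<Tensor<B, 1>> = Initializer::Zeros.init([6], &device);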

    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::Orthogonal { gain } => {
                assert!(
                    D >= 2,
                    "Expected D (in Tensor<B, D>) to be greater than or equal to 2 (D >= 2)"
                );

                // Flatten the requested shape into a `rows x cols` matrix.
                let rows: usize = shape.dims::<D>()[0];
                let cols: usize = shape.num_elements() / rows;

                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);

                // QR yields orthonormal columns, so work on the tall orientation.
                if rows < cols {
                    t = t.transpose();
                }

                let (q, r) = qr_decomposition(t, device);
                let [r_rows, r_cols] = r.clone().dims();

                // Extract the diagonal of R as a `1 x n` row vector.
                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));

                // Multiply each column of Q by the sign of the matching diagonal
                // entry of R, making the decomposition (and the draw) unique.
                let ph = diag_r.sign();

                let mut q = q.mul(ph);

                if rows < cols {
                    q = q.transpose();
                }

                q.reshape(shape).mul_scalar(*gain)
            }
        }
    }

    /// Standard deviation for Kaiming initialization: `1 / sqrt(fan)`.
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use the init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

    /// Standard deviation for Xavier Glorot initialization: `sqrt(2 / (fan_in + fan_out))`.
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use the init_with method \
             and provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use the init_with \
             method and provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}
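
// Worked example of the scale computations above, with illustrative numbers
// fan_in = 5 and fan_out = 6:
//
//     kaiming_std (fan_out_only = false) = 1 / sqrt(5)        ~= 0.447
//     KaimingUniform bound               = sqrt(3) * gain / sqrt(5)
//     xavier_std                         = sqrt(2 / (5 + 6))  ~= 0.426
//     XavierUniform bound                = sqrt(3) * gain * xavier_std
//                                        = gain * sqrt(6 / 11)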

fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

/// Computes a QR decomposition of `a` via a Gram-Schmidt process, returning
/// `(q, r)` with `q` column-orthonormal and `r` upper triangular.
fn qr_decomposition<B: Backend>(
    a: Tensor<B, 2>,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    let [m, n] = a.clone().dims();
    let mut q = Tensor::<B, 2>::zeros([m, n], device);
    let mut r = Tensor::<B, 2>::zeros([n, n], device);

    for j in 0..n {
        // Start from the j-th column of `a`.
        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);

        for i in 0..j {
            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
            // r[i, j] = <q_i, v>, projecting against the running residual.
            let r_ij = q_i.clone().mul(v.clone()).sum();

            r = r
                .clone()
                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());

            // Subtract the projection of `v` onto `q_i`.
            v = v - q_i.mul(r_ij);
        }

        // r[j, j] = ||v||
        let r_jj = v
            .clone()
            .powf(Tensor::from_floats([2.0], device))
            .sum()
            .sqrt();

        r = r
            .clone()
            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());

        // q[:, j] = v / ||v||
        let q_j = v / r_jj;

        q = q
            .clone()
            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
    }

    (q, r)
}
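
// Behavior sketch for `qr_decomposition` (illustrative, assuming a backend
// `B` and a matching `device` are in scope): factoring a random Gaussian
// matrix yields a column-orthonormal `q`, which `Initializer::Orthogonal`
// keeps as the parameter.
//
//     let a: Tensor<B, 2> = normal_draw([4, 3], 0.0, 1.0, &device);
//     let (q, r) = qr_decomposition(a.clone(), &device);
//     // q.clone().transpose().matmul(q) ~= eye(3), and q.matmul(r) ~= a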

#[cfg(test)]
mod tests {
    use super::*;

    use burn_tensor::{ElementConversion, TensorData, Tolerance, ops::FloatElem};
    use num_traits::Pow;

    pub type TB = burn_flex::Flex;
    type FT = FloatElem<TB>;

    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();

        for i in 0..tensor.shape()[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} +- 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} +- 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &device).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([10000], &device)
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 +- 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 +- 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &device)
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &device)
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &device)
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device)
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &device)
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        let device = Default::default();
        TB::seed(&device, 0);

        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &device,
        );
        let qr = qr_decomposition(a.clone(), &device);

        // Q * R should reconstruct the original matrix.
        let q_matmul_r = qr.0.clone().matmul(qr.1.clone());

        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // For a square matrix, Q should satisfy Q^T * Q = I.
        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &device)
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &device);

        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // A wide 2-D shape is preserved.
        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the output tensor to match the requested shape. ({shape:?}, {dims:?})"
        );

        // A 3-D shape is flattened to 2-D internally, then restored.
        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the output tensor to match the requested shape. ({shape:?}, {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // Orthogonal initialization requires at least two dimensions.
        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &device)
            .into_value();
    }
}