burn_core/module/initializer.rs

use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor, s};

use crate as burn;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Enum specifying the values with which a tensor should be initialized
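///
/// # Example
///
/// A minimal usage sketch (the `NdArray` backend is an illustrative assumption and the example
/// is not compiled here):
///
/// ```rust,ignore
/// // Hypothetical backend choice, for the example only.
/// type B = burn_ndarray::NdArray<f32>;
///
/// let device = Default::default();
/// // Lazily initialized 2x2 parameter drawn from N(0, 0.01^2).
/// let param: Param<Tensor<B, 2>> = Initializer::Normal { mean: 0.0, std: 0.01 }
///     .init([2, 2], &device);
/// ```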
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills tensor with specified value everywhere
    Constant {
        /// The value to fill the tensor with
        value: f64,
    },
    /// Fills tensor with 1s everywhere
    Ones,
    /// Fills tensor with 0s everywhere
    Zeros,
    /// Fills tensor with values drawn uniformly between specified values
    Uniform {
        /// The minimum value to draw from
        min: f64,

        /// The maximum value to draw from
        max: f64,
    },
    /// Fills tensor with values drawn from normal distribution with specified mean and std
    Normal {
        /// The mean of the normal distribution
        mean: f64,

        /// The standard deviation of the normal distribution
        std: f64,
    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
    KaimingUniform {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the normal version of Kaiming initialization
    KaimingNormal {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierUniform {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierNormal {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the (semi) orthogonal initialization
    /// described in [Exact solutions to the nonlinear dynamics of learning in deep linear
    /// neural networks - Saxe, A. et al. (2013)](https://arxiv.org/abs/1312.6120)
    Orthogonal {
        /// The gain to use in initialization formula
        gain: f64,
    },
}

impl Initializer {
    /// Initializes a tensor parameter of the given shape with values depending on the initializer kind.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initialized tensor.
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

    /// Initializes a tensor parameter of the given shape with values depending on the initializer kind,
    /// using the provided fan values where the initialization formula requires them.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initialized tensor.
    /// - fan_in: Fan-in of the layer, required by the Kaiming and Xavier initializers.
    /// - fan_out: Fan-out of the layer, required by the Xavier initializers and by Kaiming with `fan_out_only` set.
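    ///
    /// # Example
    ///
    /// A minimal sketch; the backend choice is an illustrative assumption and the example is
    /// not compiled here:
    ///
    /// ```rust,ignore
    /// // Hypothetical backend choice, for the example only.
    /// type B = burn_ndarray::NdArray<f32>;
    ///
    /// let device = Default::default();
    /// let (fan_in, fan_out) = (5, 6);
    /// // Xavier needs both fan values, so pass them explicitly via `init_with`.
    /// let weight: Param<Tensor<B, 2>> = Initializer::XavierUniform { gain: 1.0 }
    ///     .init_with([fan_out, fan_in], Some(fan_in), Some(fan_out), &device);
    /// ```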
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();
        let shape_for_closure = shape.clone();

        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                B::memory_persistent_allocations(device, (), move |_| {
                    let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                    if require_grad {
                        tensor = tensor.require_grad();
                    }

                    tensor
                })
            },
            device,
            true,
            shape_for_closure,
        )
    }

    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::Orthogonal { gain } => {
                // following the implementation in pytorch:
                // https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/init.py#L574
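                //
                // Sketch of the steps below:
                //  1. Flatten the requested shape to a 2D matrix with `rows` rows.
                //  2. Draw a standard-normal matrix and compute its QR decomposition.
                //  3. Multiply Q by the signs of R's diagonal so the factorization is
                //     consistent (the same sign correction PyTorch applies).
                //  4. Reshape back to the requested shape and scale by `gain`.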

                assert!(
                    D >= 2,
                    "Expected D (in Tensor<B, D>) to be greater than or equal to 2 (D >= 2)"
                );

                let rows: usize = shape.dims::<D>()[0];
                let cols: usize = shape.num_elements() / rows;

                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);

                if rows < cols {
                    t = t.transpose();
                }

                let (q, r) = qr_decomposition(t, device);
                let [r_rows, r_cols] = r.clone().dims();

                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));

                let ph = diag_r.clone().sign();

                let mut q = q.mul(ph);

                if rows < cols {
                    q = q.transpose();
                }

                q.reshape(shape).mul_scalar(*gain)
            }
        }
    }

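    /// Standard deviation used by the Kaiming initializers: `1 / sqrt(fan)`, where `fan` is
    /// `fan_out` when `fan_out_only` is set and `fan_in` otherwise. The caller scales it by `gain`.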
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

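    /// Standard deviation used by the Xavier/Glorot initializers: `sqrt(2 / (fan_in + fan_out))`.
    /// The caller scales it by `gain`.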
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

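/// Draws a tensor of the given shape from a uniform distribution between `low` and `high`.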
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

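/// Draws a tensor of the given shape from a normal distribution with the given mean and
/// standard deviation.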
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

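/// Computes a QR decomposition of `a` (of shape `m x n`) with the Gram-Schmidt process,
/// returning `(q, r)` where `q` has shape `m x n` and `r` has shape `n x n` (upper triangular).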
fn qr_decomposition<B: Backend>(
    a: Tensor<B, 2>,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    // Calculate the QR decomposition using the Gram-Schmidt process: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process

    let [m, n] = a.clone().dims();
    let mut q = Tensor::<B, 2>::zeros([m, n], device);
    let mut r = Tensor::<B, 2>::zeros([n, n], device);

    for j in 0..n {
        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);

        for i in 0..j {
            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
            let r_ij = q_i.clone().mul(v.clone()).sum();

            r = r
                .clone()
                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());

            v = v - q_i.mul(r_ij);
        }

        // norm of v
        let r_jj = v
            .clone()
            .powf(Tensor::from_floats([2.0], device))
            .sum()
            .sqrt();

        r = r
            .clone()
            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());

        let q_j = v / r_jj;

        q = q
            .clone()
            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
    }

    (q, r)
}

#[cfg(test)]
mod tests {
    use super::*;

    use burn_tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TB>;

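    /// Asserts that the sample mean and variance of `tensor`, computed over dim 0, are within
    /// 0.1 of the expected values.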
    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} +/- 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} +/- 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        // seed random generator
        let device = Default::default();
        TB::seed(&device, 0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 +/- 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 +/- 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        let device = Default::default();
        TB::seed(&device, 0);

        // test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr
        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &Default::default(),
        );
        let qr = qr_decomposition(a.clone(), &Default::default());

        // Q @ R should reconstruct input `a`
        let q_matmul_r = qr.0.clone().matmul(qr.1.clone());

        // assert that the difference between input (`a`) and Q @ R is (almost) zero
        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &Default::default())
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &Default::default());

        // Q.T @ Q should be close to identity matrix
        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the requested shape to match the shape of the output tensor. ({shape:?} vs {dims:?})"
        );

        // test 3D tensor
        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the requested shape to match the shape of the output tensor. ({shape:?} vs {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 1D tensor
        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
    }
}
627}