// burn_core/module/initializer.rs

1use crate::tensor::Shape;
2
3use crate::config::Config;
4use crate::module::{Param, ParamId};
5use crate::tensor::backend::Backend;
6use crate::tensor::{Distribution, Tensor, s};
7
8use crate as burn;
9
10#[cfg(not(feature = "std"))]
11#[allow(unused_imports)]
12use num_traits::Float as _;
13
/// Enum specifying with what values a tensor should be initialized
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills tensor with specified value everywhere
    Constant {
        /// The value to fill the tensor with
        value: f64,
    },
    /// Fills tensor with 1s everywhere
    Ones,
    /// Fills tensor with 0s everywhere
    Zeros,
    /// Fills tensor with values drawn uniformly between specified values
    Uniform {
        /// The minimum value to draw from
        min: f64,

        /// The maximum value to draw from
        max: f64,
    },
    /// Fills tensor with values drawn from normal distribution with specified mean and std
    Normal {
        /// The mean of the normal distribution
        mean: f64,

        /// The standard deviation of the normal distribution
        std: f64,
    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
    KaimingUniform {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the normal version of Kaiming initialization
    KaimingNormal {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierUniform {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierNormal {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the (semi) orthogonal initialization
    /// described in [Exact solutions to the nonlinear dynamics of learning in deep linear
    /// neural networks - Saxe, A. et al. (2013)](https://arxiv.org/abs/1312.6120)
    Orthogonal {
        /// The gain to use in initialization formula
        gain: f64,
    },
}
80
impl Initializer {
    /// Inits a tensor parameter of given shape with values depending on initializer kind.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initiated tensor.
    ///
    /// No fan-in/fan-out is provided here, so the Kaiming and Xavier variants
    /// will panic when the parameter is materialized (see `kaiming_std` /
    /// `xavier_std`); use [`Initializer::init_with`] for those.
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

    /// Inits a tensor parameter of given shape with values depending on initializer kind.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initiated tensor.
    /// - fan_in: Fan-in of the layer; consulted by the Xavier variants and by
    ///   Kaiming when `fan_out_only == false`.
    /// - fan_out: Fan-out of the layer; consulted by the Xavier variants and by
    ///   Kaiming when `fan_out_only == true`.
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();
        let shape_for_closure = shape.clone();

        // Initialization is lazy: the closure below only runs when the
        // parameter is actually materialized.
        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                // Clone captured state so the closure can be invoked without
                // consuming it.
                let config = config.clone();
                let shape = shape.clone();
                // NOTE(review): presumably scopes the allocation as persistent
                // parameter memory on the backend — confirm against the
                // `Backend::memory_persistent_allocations` contract.
                B::memory_persistent_allocations(device, (), move |_| {
                    let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                    if require_grad {
                        tensor = tensor.require_grad();
                    }

                    tensor
                })
            },
            device,
            true,
            shape_for_closure,
        )
    }

    /// Builds the actual tensor for this initializer kind.
    ///
    /// `fan_in`/`fan_out` are only consulted by the Kaiming and Xavier
    /// variants, which panic (via `kaiming_std`/`xavier_std`) when the
    /// required fan is `None`.
    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                // Uniform(-a, a) with a = sqrt(3) * gain / sqrt(fan).
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                // Normal(0, std) with std = gain / sqrt(fan).
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                // Uniform(-a, a) with a = sqrt(3) * gain * sqrt(2 / (fan_in + fan_out)).
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::Orthogonal { gain } => {
                // following the implementation in pytorch:
                // https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/init.py#L574

                assert!(
                    D >= 2,
                    "Expected D (in Tensor<B, D>) to be greater or equal 2; (D >= 2)"
                );

                // Flatten the trailing dimensions: treat the tensor as [rows, cols].
                let rows: usize = shape.dims::<D>()[0];
                let cols: usize = shape.num_elements() / rows;

                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);

                // QR works on tall matrices; transpose wide inputs first.
                if rows < cols {
                    t = t.transpose();
                }

                let (q, r) = qr_decomposition(t, device);
                let [r_rows, r_cols] = r.clone().dims();

                // Extract the diagonal of R as a [1, r_cols] row vector:
                // eye * R keeps only the diagonal, ones @ (...) sums each column.
                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));

                // Multiply each column of Q by the sign of the matching
                // diagonal entry of R (same convention as the pytorch impl).
                let ph = diag_r.clone().sign();

                let mut q = q.mul(ph);

                // Undo the transpose applied above for wide inputs.
                if rows < cols {
                    q = q.transpose();
                }

                q.reshape(shape).mul_scalar(*gain)
            }
        }
    }

    /// Returns the Kaiming standard deviation `1 / sqrt(fan)`; the gain is
    /// applied by the caller.
    ///
    /// # Panics
    ///
    /// Panics if the selected fan (`fan_out` when `fan_out_only`, `fan_in`
    /// otherwise) is `None`.
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

    /// Returns the Xavier/Glorot standard deviation `sqrt(2 / (fan_in + fan_out))`;
    /// the gain is applied by the caller.
    ///
    /// # Panics
    ///
    /// Panics if either `fan_in` or `fan_out` is `None`.
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}
226
227fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
228    shape: S,
229    low: f64,
230    high: f64,
231    device: &B::Device,
232) -> Tensor<B, D> {
233    let distribution = Distribution::Uniform(low, high);
234    Tensor::<B, D>::random(shape, distribution, device)
235}
236
237fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
238    shape: S,
239    mean: f64,
240    std: f64,
241    device: &B::Device,
242) -> Tensor<B, D> {
243    let distribution = Distribution::Normal(mean, std);
244    Tensor::<B, D>::random(shape, distribution, device)
245}
246
247fn qr_decomposition<B: Backend>(
248    a: Tensor<B, 2>,
249    device: &B::Device,
250) -> (Tensor<B, 2>, Tensor<B, 2>) {
251    // Calculate the QR decomposition using Gram-Schmidt-process: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
252
253    let [m, n] = a.clone().dims();
254    let mut q = Tensor::<B, 2>::zeros([m, n], device);
255    let mut r = Tensor::<B, 2>::zeros([n, n], device);
256
257    for j in 0..n {
258        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);
259
260        for i in 0..j {
261            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);
262            let r_ij = q_i.clone().mul(v.clone()).sum();
263
264            r = r
265                .clone()
266                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());
267
268            v = v - q_i.mul(r_ij);
269        }
270
271        // norm of v
272        let r_jj = v
273            .clone()
274            .powf(Tensor::from_floats([2.0], device))
275            .sum()
276            .sqrt();
277
278        r = r
279            .clone()
280            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());
281
282        let q_j = v / r_jj;
283
284        q = q
285            .clone()
286            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
287    }
288
289    (q, r)
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    use burn_tensor::{ElementConversion, TensorData};

    pub type TB = burn_flex::Flex;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TB>;

    /// Asserts that the per-column mean and variance of `tensor` (computed
    /// over dim 0) are within 0.1 of the expected distribution parameters.
    ///
    /// NOTE(review): only the first `tensor.shape()[0]` entries of the
    /// reduced statistics are checked — confirm this subset is intentional.
    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();

        for i in 0..tensor.shape()[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} +/- 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} +/- 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        // seed random generator
        let device = Default::default();
        TB::seed(&device, 0);

        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([10000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 +/- 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 +/- 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        // Expected bound: gain * sqrt(3 / fan_in).
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        // Expected variance: (gain / sqrt(fan_in))^2.
        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        // With fan_out_only, the bound uses fan_out instead of fan_in.
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        // `init` provides no fan, so materializing the value must panic.
        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        // Expected bound: gain * sqrt(6 / (fan_in + fan_out)).
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        // Expected variance: (gain * sqrt(2 / (fan_in + fan_out)))^2.
        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        let device = Default::default();
        TB::seed(&device, 0);

        // test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr
        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &Default::default(),
        );
        let qr = qr_decomposition(a.clone(), &Default::default());

        // Q @ R should reconstruct input `a`
        let q_matmul_r = qr.0.clone().matmul(qr.1.clone());

        // assert that the difference between input (`a`) and Q @ R is (almost) zero
        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &Default::default())
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &Default::default());

        // Q.T @ Q should be close to identity matrix
        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 2D tensor
        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
        );

        // test 3D tensor
        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        let device = Default::default();
        TB::seed(&device, 0);

        let gain = 1.;

        // test 1D tensor
        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
    }
}