burn_core/nn/initializer.rs

use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor, s};

use crate as burn;

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Enum specifying the values with which a tensor should be initialized.
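///
/// # Example
///
/// A minimal sketch of constructing initializers (field values chosen purely for illustration):
///
/// ```rust,ignore
/// let zeros = Initializer::Zeros;
/// let uniform = Initializer::Uniform { min: -1.0, max: 1.0 };
/// let kaiming = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };
/// ```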
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills tensor with specified value everywhere
    Constant {
        /// The value to fill the tensor with
        value: f64,
    },
    /// Fills tensor with 1s everywhere
    Ones,
    /// Fills tensor with 0s everywhere
    Zeros,
    /// Fills tensor with values drawn uniformly between specified values
    Uniform {
        /// The minimum value to draw from
        min: f64,

        /// The maximum value to draw from
        max: f64,
    },
    /// Fills tensor with values drawn from normal distribution with specified mean and std
    Normal {
        /// The mean of the normal distribution
        mean: f64,

        /// The standard deviation of the normal distribution
        std: f64,
    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
    /// described in [Delving Deep into Rectifiers: Surpassing Human-Level Performance on
    /// ImageNet Classification - He, K. et al. (2015)](https://arxiv.org/abs/1502.01852)
    KaimingUniform {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the normal version of Kaiming initialization
    /// described in [Delving Deep into Rectifiers: Surpassing Human-Level Performance on
    /// ImageNet Classification - He, K. et al. (2015)](https://arxiv.org/abs/1502.01852)
    KaimingNormal {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierUniform {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierNormal {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the (semi) orthogonal initialization
    /// described in [Exact solutions to the nonlinear dynamics of learning in deep linear
    /// neural networks - Saxe, A. et al. (2013)](https://arxiv.org/abs/1312.6120)
    Orthogonal {
        /// The gain to use in initialization formula
        gain: f64,
    },
}

impl Initializer {
    /// Initializes a tensor parameter of the given shape, with values depending on the
    /// initializer kind.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initialized tensor.
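    ///
    /// # Example
    ///
    /// A minimal usage sketch (the `NdArray` backend used by the tests below is assumed
    /// purely for illustration):
    ///
    /// ```rust,ignore
    /// let device = Default::default();
    /// let zeros: Param<Tensor<NdArray<f32>, 2>> = Initializer::Zeros.init([3, 4], &device);
    /// ```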
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

    /// Initializes a tensor parameter of the given shape, with values depending on the
    /// initializer kind and, where required, on the provided fan values.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initialized tensor.
    /// - fan_in: Fan-in of the layer; required by the Kaiming and Xavier initializers.
    /// - fan_out: Fan-out of the layer; required by the Xavier initializers, and by the
    ///   Kaiming initializers when `fan_out_only` is set.
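    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the tests below (`NdArray` backend assumed):
    ///
    /// ```rust,ignore
    /// let (fan_in, fan_out) = (5, 6);
    /// let param: Param<Tensor<NdArray<f32>, 2>> = Initializer::KaimingUniform {
    ///     gain: 1.0,
    ///     fan_out_only: false,
    /// }
    /// .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default());
    /// ```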
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();

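        // Initialization is deferred: the closure below only runs when the parameter is
        // actually materialized.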
        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                if require_grad {
                    tensor = tensor.require_grad();
                }

                tensor
            },
            device,
            true,
        )
    }

    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
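                // A uniform distribution on [-a, a] has std a / sqrt(3), so the target std
                // is scaled by sqrt(3) to obtain the bound a.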
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
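                // Same sqrt(3) factor as above, converting the target std into a uniform bound.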
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::Orthogonal { gain } => {
                // following the implementation in pytorch:
                // https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/init.py#L574

                assert!(
                    D >= 2,
                    "Expected D (in Tensor<B, D>) to be greater than or equal to 2 (D >= 2)"
                );

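                // Flatten all trailing dimensions so the tensor can be treated as a 2D
                // (rows x cols) matrix, keeping the leading dimension as rows.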
                let rows: usize = shape.dims::<D>()[0];
                let cols: usize = shape.num_elements() / rows;

                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);

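                // The Gram-Schmidt QR below assumes at least as many rows as columns, so a
                // wide matrix is orthogonalized through its transpose.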
                if rows < cols {
                    t = t.transpose();
                }

                let (q, r) = qr_decomposition(t, device);
                let [r_rows, r_cols] = r.clone().dims();

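                // Extract R's diagonal as a row vector: the elementwise product with eye
                // keeps only the diagonal, and the matmul with a row of ones sums each column.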
                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)
                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));

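                // Scaling Q's columns by the signs of R's diagonal makes the factorization
                // unique (it forces R's diagonal non-negative), matching the PyTorch reference.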
                let ph = diag_r.clone().sign();

                let mut q = q.mul(ph);

                if rows < cols {
                    q = q.transpose();
                }

                q.reshape(shape).mul_scalar(*gain)
            }
        }
    }

    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use the init_with method.",
        );

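        // He et al. (2015): std = gain / sqrt(fan); the gain factor is applied by the caller.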
        1.0 / (fan as f64).sqrt()
    }

    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use the init_with \
             method and provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use the init_with \
             method and provide fan_out.",
        );
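        // Glorot & Bengio (2010): std = gain * sqrt(2 / (fan_in + fan_out)); the gain
        // factor is applied by the caller.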
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

fn qr_decomposition<B: Backend>(
    a: Tensor<B, 2>,
    device: &B::Device,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    // Calculate the QR decomposition using the Gram-Schmidt process:
    // https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process

    let [m, n] = a.clone().dims();
    let mut q = Tensor::<B, 2>::zeros([m, n], device);
    let mut r = Tensor::<B, 2>::zeros([n, n], device);

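    // Classical Gram-Schmidt: build Q column by column by orthogonalizing each column of A
    // against the columns already computed; R records the projection coefficients.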
    for j in 0..n {
        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze(1);

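        // Subtract from v its projection onto each earlier column q_i, where
        // r_ij = <q_i, v> is the coefficient stored in R.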
        for i in 0..j {
            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze(1);
            let r_ij = q_i.clone().mul(v.clone()).sum();

            r = r
                .clone()
                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());

            v = v - q_i.mul(r_ij);
        }

        // r_jj is the Euclidean (L2) norm of the orthogonalized column v
        let r_jj = v.clone().powf_scalar(2.0).sum().sqrt();

        r = r
            .clone()
            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());

        let q_j = v / r_jj;

        q = q
            .clone()
            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));
    }

    (q, r)
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TB>;

    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars.as_slice::<FT>().unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means.as_slice::<FT>().unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} ± 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} ± 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        TB::seed(0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor
            .into_data()
            .assert_within_range::<FT>(min.elem()..max.elem());
    }

    #[test]
    fn initializer_normal_init() {
        // seed random generator
        TB::seed(0);
        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 ± 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 ± 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        TB::seed(0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }

    #[test]
    fn test_qr_decomposition() {
        TB::seed(0);

        // test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr
        let a = Tensor::<TB, 2>::from_floats(
            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],
            &Default::default(),
        );
        let qr = qr_decomposition(a.clone(), &Default::default());

        // Q @ R should reconstruct the input `a`
        let q_matmul_r = qr.0.clone().matmul(qr.1.clone());

        // assert that the difference between the input (`a`) and Q @ R is (almost) zero
        q_matmul_r
            .into_data()
            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_correct() {
        TB::seed(0);

        let gain = 1.;

        // test 2D tensor
        let size = 10;
        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init([size, size], &Default::default())
            .into_value();
        let eye = Tensor::<TB, 2>::eye(size, &Default::default());

        // Q.T @ Q should be close to the identity matrix
        q.clone()
            .transpose()
            .matmul(q)
            .into_data()
            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));
    }

    #[test]
    fn initializer_orthogonal_init() {
        TB::seed(0);

        let gain = 1.;

        // test 2D tensor
        let shape = [25, 30];
        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the output tensor to match the requested shape. ({shape:?}, {dims:?})"
        );

        // test 3D tensor
        let shape = [24, 6, 85];
        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
        let dims = t.dims();
        assert_eq!(
            shape, dims,
            "Expected the shape of the output tensor to match the requested shape. ({shape:?}, {dims:?})"
        );
    }

    #[test]
    #[should_panic]
    fn initializer_orthogonal_init_1d() {
        TB::seed(0);
        let gain = 1.;

        // test 1D tensor
        let shape = [3];
        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }
            .init(shape, &Default::default())
            .into_value();
    }
}