burn_core/nn/initializer.rs

use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};

use crate as burn;

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Enum specifying the values with which a tensor should be initialized.
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills tensor with specified value everywhere
    Constant {
        /// The value to fill the tensor with
        value: f64,
    },
    /// Fills tensor with 1s everywhere
    Ones,
    /// Fills tensor with 0s everywhere
    Zeros,
    /// Fills tensor with values drawn uniformly between specified values
    Uniform {
        /// The minimum value to draw from
        min: f64,

        /// The maximum value to draw from
        max: f64,
    },
    /// Fills tensor with values drawn from normal distribution with specified mean and std
    Normal {
        /// The mean of the normal distribution
        mean: f64,

        /// The standard deviation of the normal distribution
        std: f64,
    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
    KaimingUniform {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the normal version of Kaiming initialization
    KaimingNormal {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierUniform {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
    XavierNormal {
        /// The gain to use in initialization formula
        gain: f64,
    },
}

impl Initializer {
    /// Initializes a tensor parameter of the given shape, with values determined by the
    /// initializer kind.
    ///
    /// # Params
    ///
    /// - shape: Shape of the tensor to initialize.
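    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming the `burn-ndarray` backend that the tests below use:
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let weight: Param<Tensor<burn_ndarray::NdArray<f32>, 2>> =
    ///     Initializer::Normal { mean: 0.0, std: 1.0 }.init([3, 4], &device);
    /// ```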
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

    /// Initializes a tensor parameter of the given shape, with values determined by the
    /// initializer kind.
    ///
    /// # Params
    ///
    /// - shape: Shape of the tensor to initialize.
    /// - fan_in: Fan-in of the layer, required by the Kaiming and Xavier initializers.
    /// - fan_out: Fan-out of the layer, required by the Xavier initializers and by
    ///   Kaiming initialization with `fan_out_only`.
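    ///
    /// # Example
    ///
    /// A sketch of a Kaiming-uniform weight, which needs an explicit fan (mirroring the
    /// tests below):
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let (fan_in, fan_out) = (5, 6);
    /// let weight: Param<Tensor<burn_ndarray::NdArray<f32>, 2>> =
    ///     Initializer::KaimingUniform { gain: 1.0, fan_out_only: false }
    ///         .init_with([fan_out, fan_in], Some(fan_in), None, &device);
    /// ```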
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();

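        // Initialization is deferred: the closure below only runs once the
        // parameter value is actually requested (e.g. via `into_value`).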
        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                if require_grad {
                    tensor = tensor.require_grad();
                }

                tensor
            },
            device,
            true,
        )
    }

    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
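                // Uniform bound: a = sqrt(3) * gain / sqrt(fan)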
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
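                // Normal std: gain / sqrt(fan)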
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
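                // Uniform bound: a = gain * sqrt(6 / (fan_in + fan_out))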
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
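                // Normal std: gain * sqrt(2 / (fan_in + fan_out))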
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
        }
    }

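    /// Standard deviation used by the Kaiming initializers: `1 / sqrt(fan)`, where `fan`
    /// is `fan_out` when `fan_out_only` is set and `fan_in` otherwise.
    ///
    /// Panics if the required fan was not provided.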
    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

        1.0 / (fan as f64).sqrt()
    }

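    /// Standard deviation used by the Xavier Glorot initializers:
    /// `sqrt(2 / (fan_in + fan_out))`.
    ///
    /// Panics if either fan was not provided.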
    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

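/// Draws a tensor of the given shape from `Uniform(low, high)` on the given device.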
fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

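/// Draws a tensor of the given shape from `Normal(mean, std)` on the given device.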
fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;

    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} ± 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} ± 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        TB::seed(0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor.into_data().assert_within_range(min..max);
    }

    #[test]
    fn initializer_normal_init() {
        // seed random generator
        TB::seed(0);
        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 ± 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 ± 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants
            .sum()
            .to_data()
            .assert_approx_eq(&TensorData::from([value as f32 * 16.0]), 3);
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq(&TensorData::from([0.0]), 3);
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq(&TensorData::from([16.0]), 3);
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        TB::seed(0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_out as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = gain * (6. / (fan_in + fan_out) as f64).sqrt();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }
}