burn_core/nn/initializer.rs

use crate::tensor::Shape;

use crate::config::Config;
use crate::module::{Param, ParamId};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};

use crate as burn;

#[cfg(not(feature = "std"))]
use num_traits::Float;
/// Enum specifying the values with which a tensor should be initialized
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills tensor with specified value everywhere
    Constant {
        /// The value to fill the tensor with
        value: f64,
    },
    /// Fills tensor with 1s everywhere
    Ones,
    /// Fills tensor with 0s everywhere
    Zeros,
    /// Fills tensor with values drawn uniformly between specified values
    Uniform {
        /// The minimum value to draw from
        min: f64,

        /// The maximum value to draw from
        max: f64,
    },
    /// Fills tensor with values drawn from normal distribution with specified mean and std
    Normal {
        /// The mean of the normal distribution
        mean: f64,

        /// The standard deviation of the normal distribution
        std: f64,
    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
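    /// (values are drawn from `U(-bound, bound)` with `bound = gain * sqrt(3 / fan)`)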
    KaimingUniform {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the normal version of Kaiming initialization
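    /// (values are drawn from `N(0, std^2)` with `std = gain / sqrt(fan)`)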
    KaimingNormal {
        /// The gain to use in initialization formula
        gain: f64,

        /// Whether to use fan out only in initialization formula
        fan_out_only: bool,
    },
    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
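    /// (values are drawn from `U(-bound, bound)` with `bound = gain * sqrt(6 / (fan_in + fan_out))`)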
    XavierUniform {
        /// The gain to use in initialization formula
        gain: f64,
    },
    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
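    /// (values are drawn from `N(0, std^2)` with `std = gain * sqrt(2 / (fan_in + fan_out))`)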
    XavierNormal {
        /// The gain to use in initialization formula
        gain: f64,
    },
}

impl Initializer {
    /// Initializes a tensor parameter of the given shape with values depending on the
    /// initializer kind.
    ///
    /// Initializers that need fan values (Kaiming, Xavier) panic through this method;
    /// use [`init_with`](Self::init_with) for those.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initialized tensor.
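    ///
    /// # Example
    ///
    /// A minimal usage sketch; the backend type `B` and the `device` value are assumed
    /// to be in scope:
    ///
    /// ```ignore
    /// let weight: Param<Tensor<B, 2>> = Initializer::Uniform { min: -0.5, max: 0.5 }
    ///     .init([3, 4], &device);
    /// ```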
    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        self.init_with(shape, None, None, device)
    }

    /// Initializes a tensor parameter of the given shape with values depending on the
    /// initializer kind, using the provided fan values where the initialization formula
    /// requires them.
    ///
    /// # Params
    ///
    /// - shape: Shape of the initialized tensor.
    /// - fan_in: Fan in of the layer; required by Xavier, and by Kaiming unless
    ///   `fan_out_only` is set.
    /// - fan_out: Fan out of the layer; required by Xavier, and by Kaiming when
    ///   `fan_out_only` is set.
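    ///
    /// # Example
    ///
    /// A minimal sketch of a fan-aware initialization; the backend type `B` and the
    /// `device` value are assumed to be in scope:
    ///
    /// ```ignore
    /// let init = Initializer::KaimingUniform { gain: 1.0, fan_out_only: false };
    /// let weight: Param<Tensor<B, 2>> = init.init_with([6, 5], Some(5), Some(6), &device);
    /// ```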
    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Param<Tensor<B, D>> {
        let device = device.clone();
        let shape: Shape = shape.into();
        let config = self.clone();

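        // Initialization is deferred: the closure below runs when the parameter is
        // materialized, and marks the tensor for autodiff if gradients are required.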
        Param::uninitialized(
            ParamId::new(),
            move |device, require_grad| {
                let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);

                if require_grad {
                    tensor = tensor.require_grad();
                }

                tensor
            },
            device,
            true,
        )
    }

    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(
        &self,
        shape: S,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
        device: &B::Device,
    ) -> Tensor<B, D> {
        let shape = shape.into();
        match self {
            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),
            Initializer::Ones => Tensor::<B, D>::ones(shape, device),
            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),
            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),
            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),
            Initializer::KaimingUniform { gain, fan_out_only } => {
                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::KaimingNormal { gain, fan_out_only } => {
                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
            Initializer::XavierUniform { gain } => {
                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);
                uniform_draw(shape, -a, a, device)
            }
            Initializer::XavierNormal { gain } => {
                let std = *gain * self.xavier_std(fan_in, fan_out);
                normal_draw(shape, 0.0, std, device)
            }
        }
    }

    fn kaiming_std(
        &self,
        fan_out_only: bool,
        fan_in: Option<usize>,
        fan_out: Option<usize>,
    ) -> f64 {
        let fan = if fan_out_only { fan_out } else { fan_in };
        let fan = fan.expect(
            "Can't use Kaiming initialization without specifying fan. Use init_with method.",
        );

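        // Base Kaiming standard deviation; callers scale it by `gain`
        // (and by `sqrt(3)` for the uniform variant).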
        1.0 / (fan as f64).sqrt()
    }

    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let fan_in = fan_in.expect(
            "Can't use Xavier initialization without specifying fan in. Use init_with method and \
             provide fan_in.",
        );
        let fan_out = fan_out.expect(
            "Can't use Xavier initialization without specifying fan out. Use init_with method and \
             provide fan_out.",
        );
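        // Base Xavier/Glorot standard deviation; callers scale it by `gain`
        // (and by `sqrt(3)` for the uniform variant).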
        (2.0 / (fan_in + fan_out) as f64).sqrt()
    }
}

fn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    low: f64,
    high: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Uniform(low, high);
    Tensor::<B, D>::random(shape, distribution, device)
}

fn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(
    shape: S,
    mean: f64,
    std: f64,
    device: &B::Device,
) -> Tensor<B, D> {
    let distribution = Distribution::Normal(mean, std);
    Tensor::<B, D>::random(shape, distribution, device)
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::tensor::{ElementConversion, TensorData};
    use num_traits::Pow;

    pub type TB = burn_ndarray::NdArray<f32>;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TB>;

    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {
        let (actual_vars, actual_means) = tensor.clone().var_mean(0);
        let actual_vars = actual_vars.to_data();
        let actual_vars = actual_vars
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();
        let actual_means = actual_means.to_data();
        let actual_means = actual_means
            .as_slice::<<TB as Backend>::FloatElem>()
            .unwrap();

        for i in 0..tensor.shape().dims[0] {
            let actual_var = actual_vars[i] as f64;
            let actual_mean = actual_means[i] as f64;

            assert!(
                (expected_var - actual_var).abs() <= 0.1,
                "Expected variance to be within {expected_var} +/- 0.1, but got {actual_var}"
            );
            assert!(
                (expected_mean - actual_mean).abs() <= 0.1,
                "Expected mean to be within {expected_mean} +/- 0.1, but got {actual_mean}"
            );
        }
    }

    #[test]
    fn initializer_uniform_init() {
        TB::seed(0);

        let (min, max) = (0.0, 1.0);
        let uniform = Initializer::Uniform { min, max };
        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();

        tensor.into_data().assert_within_range(min..max);
    }

    #[test]
    fn initializer_normal_init() {
        // seed random generator
        TB::seed(0);
        let (mean, std) = (0.0, 1.0);
        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }
            .init([1000], &Default::default())
            .into_value();
        let (var_act, mean_act) = normal.var_mean(0);

        let var_act: f32 = var_act.into_scalar().elem();
        let mean_act: f32 = mean_act.into_scalar().elem();

        assert!(
            var_act > 0.9 && var_act < 1.1,
            "Expected variance to be within 1.0 +/- 0.1, but got {var_act}"
        );
        assert!(
            mean_act > -0.1 && mean_act < 0.1,
            "Expected mean to be within 0.0 +/- 0.1, but got {mean_act}"
        );
    }

    #[test]
    fn initializer_constant_init() {
        let value = 5.0;
        let constants: Tensor<TB, 4> = Initializer::Constant { value }
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        constants.sum().to_data().assert_approx_eq::<FT>(
            &TensorData::from([value as f32 * 16.0]),
            Tolerance::default(),
        );
    }

    #[test]
    fn initializer_zeros_init() {
        let zeros: Tensor<TB, 4> = Initializer::Zeros
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        zeros
            .sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());
    }

    #[test]
    fn initializer_ones_init() {
        let ones: Tensor<TB, 4> = Initializer::Ones
            .init([2, 2, 2, 2], &Default::default())
            .into_value();
        ones.sum()
            .to_data()
            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());
    }

    #[test]
    fn initializer_kaiming_uniform_init() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);
        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {
            gain,
            fan_out_only: false,
        }
        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())
        .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    fn initializer_kaiming_uniform_init_bias() {
        TB::seed(0);

        let gain = 2_f64;
        let shape = [3];
        let fan_in = 5;
        let k = gain * (3.0 / fan_in as f64).sqrt();

        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init_with(shape, Some(fan_in), None, &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    fn initializer_kaiming_uniform_init_fan_out() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);
        let k = gain * (3.0 / fan_out as f64).sqrt();

        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: true,
        }
        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())
        .into_value();
        tensor.into_data().assert_within_range(-k..k);
    }

    #[test]
    #[should_panic]
    fn initializer_kaiming_uniform_no_fan() {
        TB::seed(0);

        let gain = 2_f64;
        let (fan_in, fan_out) = (5, 6);

        let _: Tensor<TB, 2> = Initializer::KaimingUniform {
            gain,
            fan_out_only: false,
        }
        .init([fan_out, fan_in], &Default::default())
        .into_value();
    }

    #[test]
    fn initializer_xavier_uniform_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let bound = gain * (6. / (fan_in + fan_out) as f64).sqrt();
        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();

        tensor.into_data().assert_within_range(-bound..bound);
    }

    #[test]
    fn initializer_xavier_normal_init() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (1000, 10);
        let expected_mean = 0_f64;

        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);
        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }
            .init_with(
                [fan_out, fan_in],
                Some(fan_in),
                Some(fan_out),
                &Default::default(),
            )
            .into_value();
        assert_normal_init(expected_mean, expected_var, &tensor)
    }

    #[test]
    #[should_panic]
    fn initializer_xavier_uniform_no_fan() {
        TB::seed(0);

        let gain = 2.;
        let (fan_in, fan_out) = (5, 6);
        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }
            .init([fan_out, fan_in], &Default::default())
            .into_value();
    }
}