//! Batch normalization (`burn_nn/modules/norm/batch.rs`).
1use burn_core as burn;
2
3use burn::module::Initializer;
4use burn::module::{Content, DisplaySettings, ModuleDisplay};
5use burn::tensor::{Tensor, backend::Backend};
6use burn::{
7    config::Config,
8    module::{Module, Param, RunningState},
9};
10
11/// [`BatchNorm`] Configuration.
12///
13/// Used to create a [`BatchNorm`] layer using the [`BatchNormConfig::init`].
14#[derive(Config, Debug)]
15pub struct BatchNormConfig {
16    /// The number of features.
17    pub num_features: usize,
18    /// A value required for numerical stability. Default: 1e-5
19    #[config(default = 1e-5)]
20    pub epsilon: f64,
21    /// Momentum used to update the metrics. Default: 0.1
22    #[config(default = 0.1)]
23    pub momentum: f64,
24}
25
26/// Applies Batch Normalization over a tensor.
27///
28/// Based upon the paper [Batch Normalization](https://arxiv.org/abs/1502.03167).
29///
30/// Assumes input tensor is of shape ``[batch_size, channels, ...]``.
31///
32/// `Y = norm(X) * γ + β`
33///
34/// Where:
35/// - `X` is the input tensor
36/// - `Y` is the output tensor
37/// - `norm` is the normalization function
38/// - `γ` is the learnable weight
39/// - `β` is the learnable bias
40///
41/// Should be created using [`BatchNormConfig`].
42#[derive(Module, Debug)]
43#[module(custom_display)]
44pub struct BatchNorm<B: Backend> {
45    /// The learnable weight gamma.
46    pub gamma: Param<Tensor<B, 1>>,
47    /// The learnable weight beta.
48    pub beta: Param<Tensor<B, 1>>,
49    /// The running mean.
50    pub running_mean: RunningState<Tensor<B, 1>>,
51    /// The running variance.
52    pub running_var: RunningState<Tensor<B, 1>>,
53    /// Momentum used to update the metrics.
54    pub momentum: f64,
55    /// A value required for numerical stability.
56    pub epsilon: f64,
57}
58
59impl BatchNormConfig {
60    /// Initializes a new [batch norm](BatchNorm) module.
61    pub fn init<B: Backend>(&self, device: &B::Device) -> BatchNorm<B> {
62        let gamma = Initializer::Ones.init([self.num_features], device);
63        let beta = Initializer::Zeros.init([self.num_features], device);
64
65        let running_mean = Tensor::zeros([self.num_features], device);
66        let running_var = Tensor::ones([self.num_features], device);
67
68        BatchNorm {
69            gamma,
70            beta,
71            running_mean: RunningState::new(running_mean),
72            running_var: RunningState::new(running_var),
73            momentum: self.momentum,
74            epsilon: self.epsilon,
75        }
76    }
77}
78
79impl<B: Backend> BatchNorm<B> {
80    /// Applies the forward pass on the input tensor.
81    ///
82    /// See [`BatchNorm`] for more information.
83    ///
84    /// # Shapes
85    ///
86    /// - `input`: ``[batch_size, channels, ...]``
87    /// - `output`: ``[batch_size, channels, ...]``
88    ///
89    /// # Panics
90    ///
91    /// This function will panic if the input tensor has rank < 2.
92    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
93        // Should be move to a compilation error when const generic support that kind of
94        // validation. https://github.com/rust-lang/rust/issues/76560
95        if D < 2 {
96            panic!(
97                "BatchNorm can only be applied on tensors of rank >= 2 with the following shape \
98                 [batch_size, channels, ...], received {}D tensor",
99                D
100            );
101        }
102
103        match B::ad_enabled(&input.device()) {
104            true => self.forward_train(input),
105            false => self.forward_inference(input),
106        }
107    }
108
109    fn forward_inference<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
110        let device = input.device();
111        let channels = input.dims()[1];
112        let mean = self.running_mean.value().to_device(&device);
113        let var = self.running_var.value().to_device(&device);
114
115        let mut shape = [1; D];
116        shape[1] = channels;
117
118        self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
119    }
120
121    fn forward_train<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
122        let device = input.device();
123        let dims = input.dims();
124        let batch_size = dims[0];
125        let channels = dims[1];
126
127        let mut shape_unsqueeze = [1; D];
128        let mut flatten_size = batch_size;
129        shape_unsqueeze[1] = channels;
130
131        for dim in dims.iter().take(D).skip(2) {
132            flatten_size *= dim;
133        }
134
135        let mean = input
136            .clone()
137            .swap_dims(0, 1)
138            .reshape([channels, flatten_size])
139            .mean_dim(1)
140            .reshape(shape_unsqueeze);
141
142        let var = input
143            .clone()
144            .sub(mean.clone())
145            .square()
146            .swap_dims(0, 1)
147            .reshape([channels, flatten_size])
148            .mean_dim(1)
149            .reshape(shape_unsqueeze);
150
151        let running_mean = self.running_mean.value_sync().to_device(&device);
152        let running_var = self.running_var.value_sync().to_device(&device);
153
154        let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
155            mean.clone()
156                .detach()
157                .mul_scalar(self.momentum)
158                .reshape([channels]),
159        );
160        let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
161            var.clone()
162                .detach()
163                .mul_scalar(self.momentum)
164                .reshape([channels]),
165        );
166
167        self.running_mean.update(running_mean.detach());
168        self.running_var.update(running_var.detach());
169
170        self.forward_shared(input, mean, var)
171    }
172
173    fn forward_shared<const D: usize>(
174        &self,
175        x: Tensor<B, D>,
176        mean: Tensor<B, D>,
177        var: Tensor<B, D>,
178    ) -> Tensor<B, D> {
179        let channels = x.dims()[1];
180        let mut shape = [1; D];
181        shape[1] = channels;
182
183        let std = var.add_scalar(self.epsilon).sqrt();
184
185        let x = x.sub(mean);
186        let x = x.div(std);
187
188        let x = x.mul(self.gamma.val().reshape(shape));
189
190        x.add(self.beta.val().reshape(shape))
191    }
192}
193
194impl<B: Backend> ModuleDisplay for BatchNorm<B> {
195    fn custom_settings(&self) -> Option<DisplaySettings> {
196        DisplaySettings::new()
197            .with_new_line_after_attribute(false)
198            .optional()
199    }
200
201    fn custom_content(&self, content: Content) -> Option<Content> {
202        let [num_features] = self.beta.shape().dims();
203
204        content
205            .add("num_features", &num_features)
206            .add("momentum", &self.momentum)
207            .add("epsilon", &self.epsilon)
208            .optional()
209    }
210}
211
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// Shared input batch of shape `[2, 3, 2]`.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        Tensor::<B, 3>::from_floats(
            [
                [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
                [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
            ],
            device,
        )
    }

    /// Expected output when normalizing with batch statistics (training mode).
    fn expected_train() -> TensorData {
        TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ])
    }

    /// Expected output when normalizing with running statistics (inference mode).
    fn expected_valid() -> TensorData {
        TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ])
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let out = bn.forward(input_tensor(&device));

        out.to_data()
            .assert_approx_eq::<FT>(&expected_train(), Tolerance::rel_abs(0.1, 0.001));
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // One training pass populates the running statistics.
        bn.forward(input_tensor(&device));
        let bn = bn.valid();
        let out = bn.forward(input_tensor(&device));

        out.to_data()
            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());
    }

    /// Round-trips between inference and training modes and checks both outputs.
    #[test]
    fn batch_norm_forward_train_inference() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        bn.forward(input_tensor(&device));
        let bn = bn.valid();
        let out = bn.forward(input_tensor(&device));

        out.to_data()
            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());

        let bn = bn.train::<TestAutodiffBackend>();
        let out = bn.forward(input_tensor(&device));
        out.to_data()
            .assert_approx_eq::<FT>(&expected_train(), Tolerance::default());
    }
}
300
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// Shared input batch of shape `[2, 3, 2, 2]`.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        Tensor::<B, 4>::from_floats(
            [
                [
                    [[0.9601, 0.7277], [0.1270, 0.5441]],
                    [[0.6272, 0.9034], [0.4066, 0.7179]],
                    [[0.9378, 0.7230], [0.3544, 0.9591]],
                ],
                [
                    [[0.6356, 0.1362], [0.1333, 0.7287]],
                    [[0.0249, 0.9509], [0.3791, 0.2481]],
                    [[0.6600, 0.5945], [0.5424, 0.4767]],
                ],
            ],
            device,
        )
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let out = bn.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // One training pass populates the running statistics.
        bn.forward(input_tensor(&device));
        let bn = bn.valid();
        let out = bn.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _out = bn.forward(input_tensor(&device));

        let running_mean = bn.running_mean.value_sync();

        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
        running_mean
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _out = bn.forward(input_tensor(&device));

        let running_var = bn.running_var.value_sync();

        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
        running_var
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The valid (inner) module must observe the same running mean as the
    /// autodiff module it was derived from.
    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _out = bn.forward(input_tensor(&device));

        let bn_valid = bn.valid();
        let running_mean = bn_valid.running_mean.value();
        let running_mean_after = bn.running_mean.value();

        running_mean_after
            .into_data()
            .assert_approx_eq::<FT>(&running_mean.into_data(), Tolerance::default());
    }

    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let bn = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
        let input = input_tensor(&device).require_grad();

        let out = bn.forward(input.clone());

        let grads = out.backward();

        let tolerance = Tolerance::rel_abs(0.1, 0.001);

        // γ gradients are ~0 because the normalized activations sum to ~0.
        let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
        bn.gamma
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // β gradient is the per-channel element count (2 * 2 * 2 = 8).
        let expected = TensorData::from([8., 8., 8.]);
        bn.beta
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        input
            .grad(&grads)
            .unwrap()
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn display() {
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&Default::default());

        assert_eq!(
            format!("{batch_norm}"),
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}