burn_nn/modules/norm/
batch.rs

1use burn_core as burn;
2
3use burn::module::Initializer;
4use burn::module::{Content, DisplaySettings, ModuleDisplay};
5use burn::tensor::{Tensor, backend::Backend};
6use burn::{
7    config::Config,
8    module::{Module, Param, RunningState},
9};
10
/// [`BatchNorm`] Configuration.
///
/// Used to create a [`BatchNorm`] layer using the [`BatchNormConfig::init`].
#[derive(Config, Debug)]
pub struct BatchNormConfig {
    /// The number of features, i.e. the number of channels (the size of axis 1 of the input).
    pub num_features: usize,
    /// Added to the variance before taking its square root, for numerical stability.
    /// Default: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// Weight given to the current batch statistic when updating the running mean and
    /// variance: `running = (1 - momentum) * running + momentum * batch_statistic`.
    /// Default: 0.1
    #[config(default = 0.1)]
    pub momentum: f64,
}
25
/// Applies Batch Normalization over a tensor.
///
/// Based upon the paper [Batch Normalization](https://arxiv.org/abs/1502.03167).
///
/// Assumes input tensor is of shape ``[batch_size, channels, ...]``.
///
/// `Y = norm(X) * γ + β`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `norm` is the normalization function
/// - `γ` is the learnable weight
/// - `β` is the learnable bias
///
/// During training, `norm` uses the per-channel statistics of the current batch and the
/// running statistics are updated. Otherwise, the stored running statistics are used.
///
/// Should be created using [`BatchNormConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend> {
    /// The learnable weight gamma (one scale per channel).
    pub gamma: Param<Tensor<B, 1>>,
    /// The learnable bias beta (one shift per channel).
    pub beta: Param<Tensor<B, 1>>,
    /// The running mean, one value per channel, used outside of training.
    pub running_mean: RunningState<Tensor<B, 1>>,
    /// The running variance, one value per channel, used outside of training.
    pub running_var: RunningState<Tensor<B, 1>>,
    /// Weight of the current batch statistic in the running-statistics update.
    pub momentum: f64,
    /// Added to the variance before the square root, for numerical stability.
    pub epsilon: f64,
}
58
59impl BatchNormConfig {
60    /// Initializes a new [batch norm](BatchNorm) module.
61    pub fn init<B: Backend>(&self, device: &B::Device) -> BatchNorm<B> {
62        let gamma = Initializer::Ones.init([self.num_features], device);
63        let beta = Initializer::Zeros.init([self.num_features], device);
64
65        let running_mean = Tensor::zeros([self.num_features], device);
66        let running_var = Tensor::ones([self.num_features], device);
67
68        BatchNorm {
69            gamma,
70            beta,
71            running_mean: RunningState::new(running_mean),
72            running_var: RunningState::new(running_var),
73            momentum: self.momentum,
74            epsilon: self.epsilon,
75        }
76    }
77}
78
79impl<B: Backend> BatchNorm<B> {
80    /// Applies the forward pass on the input tensor.
81    ///
82    /// See [`BatchNorm`] for more information.
83    ///
84    /// # Shapes
85    ///
86    /// - `input`: ``[batch_size, channels, ...]``
87    /// - `output`: ``[batch_size, channels, ...]``
88    ///
89    /// # Panics
90    ///
91    /// This function will panic if the input tensor has rank < 2.
92    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
93        // Should be move to a compilation error when const generic support that kind of
94        // validation. https://github.com/rust-lang/rust/issues/76560
95        if D < 2 {
96            panic!(
97                "BatchNorm can only be applied on tensors of rank >= 2 with the following shape \
98                 [batch_size, channels, ...], received {}D tensor",
99                D
100            );
101        }
102
103        match B::ad_enabled() {
104            true => self.forward_train(input),
105            false => self.forward_inference(input),
106        }
107    }
108
109    fn forward_inference<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
110        let device = input.device();
111        let channels = input.dims()[1];
112        let mean = self.running_mean.value().to_device(&device);
113        let var = self.running_var.value().to_device(&device);
114
115        let mut shape = [1; D];
116        shape[1] = channels;
117
118        self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
119    }
120
121    fn forward_train<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
122        let device = input.device();
123        let dims = input.dims();
124        let batch_size = dims[0];
125        let channels = dims[1];
126
127        let mut shape_unsqueeze = [1; D];
128        let mut flatten_size = batch_size;
129        shape_unsqueeze[1] = channels;
130
131        for dim in dims.iter().take(D).skip(2) {
132            flatten_size *= dim;
133        }
134
135        let mean = input
136            .clone()
137            .swap_dims(0, 1)
138            .reshape([channels, flatten_size])
139            .mean_dim(1)
140            .reshape(shape_unsqueeze);
141
142        let var = input
143            .clone()
144            .sub(mean.clone())
145            .square()
146            .swap_dims(0, 1)
147            .reshape([channels, flatten_size])
148            .mean_dim(1)
149            .reshape(shape_unsqueeze);
150
151        let running_mean = self.running_mean.value_sync().to_device(&device);
152        let running_var = self.running_var.value_sync().to_device(&device);
153
154        let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
155            mean.clone()
156                .detach()
157                .mul_scalar(self.momentum)
158                .reshape([channels]),
159        );
160        let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
161            var.clone()
162                .detach()
163                .mul_scalar(self.momentum)
164                .reshape([channels]),
165        );
166
167        self.running_mean.update(running_mean.detach());
168        self.running_var.update(running_var.detach());
169
170        self.forward_shared(input, mean, var)
171    }
172
173    fn forward_shared<const D: usize>(
174        &self,
175        x: Tensor<B, D>,
176        mean: Tensor<B, D>,
177        var: Tensor<B, D>,
178    ) -> Tensor<B, D> {
179        let channels = x.dims()[1];
180        let mut shape = [1; D];
181        shape[1] = channels;
182
183        let std = var.add_scalar(self.epsilon).sqrt();
184
185        let x = x.sub(mean);
186        let x = x.div(std);
187
188        let x = x.mul(self.gamma.val().reshape(shape));
189
190        x.add(self.beta.val().reshape(shape))
191    }
192}
193
194impl<B: Backend> ModuleDisplay for BatchNorm<B> {
195    fn custom_settings(&self) -> Option<DisplaySettings> {
196        DisplaySettings::new()
197            .with_new_line_after_attribute(false)
198            .optional()
199    }
200
201    fn custom_content(&self, content: Content) -> Option<Content> {
202        let [num_features] = self.beta.shape().dims();
203
204        content
205            .add("num_features", &num_features)
206            .add("momentum", &self.momentum)
207            .add("epsilon", &self.epsilon)
208            .optional()
209    }
210}
211
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
    // Tests for `BatchNorm` on rank-3 input shaped `[batch_size = 2, channels = 3, length = 2]`.
    // NOTE(review): the expected tensors look like they were generated by an external reference
    // implementation (presumably PyTorch) — confirm the source before regenerating them.
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// On an autodiff backend, `forward` normalizes with the current batch statistics.
    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let output = module.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
    }

    /// After one training step, the `valid()` (non-autodiff) module normalizes with the
    /// updated running statistics.
    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // One training pass: its only purpose is to update the running statistics.
        module.forward(input_tensor(&device));
        let module = module.valid();
        let output = module.forward(input_tensor(&device));

        let expected = TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Fixed `[2, 3, 2]` input shared by the tests above.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        Tensor::<B, 3>::from_floats(
            [
                [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
                [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
            ],
            device,
        )
    }
}
274
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
    // Tests for `BatchNorm` on rank-4 input shaped `[batch_size = 2, channels = 3, 2, 2]`.
    // NOTE(review): the expected tensors look like they were generated by an external reference
    // implementation (presumably PyTorch) — confirm the source before regenerating them.
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// On an autodiff backend, `forward` normalizes with the current batch statistics.
    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let output = module.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
    }

    /// After one training step, the `valid()` (non-autodiff) module normalizes with the
    /// updated running statistics.
    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // One training pass: its only purpose is to update the running statistics.
        module.forward(input_tensor(&device));
        let module = module.valid();
        let output = module.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// One training step moves the running mean from 0 to `momentum * batch_mean`.
    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = module.forward(input_tensor(&device));

        let running_mean = module.running_mean.value_sync();

        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
        running_mean
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// One training step moves the running variance from 1 towards the batch variance.
    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = module.forward(input_tensor(&device));

        let running_var = module.running_var.value_sync();

        // NOTE(review): by hand, channel 0 gives 0.9 + 0.1 * unbiased_var ≈ 0.9106, versus
        // ≈ 0.9093 with the biased (divide-by-n) variance. These expectations therefore encode
        // the unbiased running-variance update.
        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
        running_var
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The `valid()` copy of the module observes the running mean updated by training.
    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = module.forward(input_tensor(&device));

        let module_valid = module.valid();
        let running_mean = module_valid.running_mean.value();
        let running_mean_after = module.running_mean.value();

        running_mean_after
            .into_data()
            .assert_approx_eq::<FT>(&running_mean.into_data(), Tolerance::default());
    }

    /// Gradients of `output.sum()` (implicitly, via `backward`) with respect to gamma, beta
    /// and the input.
    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
        let input = input_tensor(&device).require_grad();

        let output = module.forward(input.clone());

        let grads = output.backward();

        let tolerance = Tolerance::rel_abs(0.1, 0.001);
        let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
        module
            .gamma
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        // Each channel holds 2 * 2 * 2 = 8 values, so the bias gradient is 8 per channel.
        let expected = TensorData::from([8., 8., 8.]);
        module
            .beta
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        input
            .grad(&grads)
            .unwrap()
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    /// Fixed `[2, 3, 2, 2]` input shared by the tests in this module.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        Tensor::<B, 4>::from_floats(
            [
                [
                    [[0.9601, 0.7277], [0.1270, 0.5441]],
                    [[0.6272, 0.9034], [0.4066, 0.7179]],
                    [[0.9378, 0.7230], [0.3544, 0.9591]],
                ],
                [
                    [[0.6356, 0.1362], [0.1333, 0.7287]],
                    [[0.0249, 0.9509], [0.3791, 0.2481]],
                    [[0.6600, 0.5945], [0.5424, 0.4767]],
                ],
            ],
            device,
        )
    }

    /// The custom display prints the hyper-parameters and the parameter count on one line.
    #[test]
    fn display() {
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&Default::default());

        assert_eq!(
            format!("{batch_norm}"),
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}