burn_core/nn/norm/
batch.rs

1use crate as burn;
2use crate::module::{Content, DisplaySettings, ModuleDisplay};
3
4use crate::nn::Initializer;
5use crate::{
6    config::Config,
7    module::{Module, Param, RunningState},
8    tensor::{Tensor, backend::Backend},
9};
10
11/// Configuration to create a [BatchNorm](BatchNorm) layer using the [init function](BatchNormConfig::init).
12#[derive(Config, Debug)]
13pub struct BatchNormConfig {
14    /// The number of features.
15    pub num_features: usize,
16    /// A value required for numerical stability. Default: 1e-5
17    #[config(default = 1e-5)]
18    pub epsilon: f64,
19    /// Momentum used to update the metrics. Default: 0.1
20    #[config(default = 0.1)]
21    pub momentum: f64,
22}
23
24/// Applies Batch Normalization over a tensor as described in the paper [Batch Normalization](https://arxiv.org/abs/1502.03167)
25///
26/// `Y = norm(X) * γ + β`
27///
28/// Where:
29/// - `X` is the input tensor
30/// - `Y` is the output tensor
31/// - `norm` is the normalization function
32/// - `γ` is the learnable weight
33/// - `β` is the learnable bias
34///
35/// Should be created using [BatchNormConfig].
36#[derive(Module, Debug)]
37#[module(custom_display)]
38pub struct BatchNorm<B: Backend, const D: usize> {
39    /// The learnable weight gamma.
40    pub gamma: Param<Tensor<B, 1>>,
41    /// The learnable weight beta.
42    pub beta: Param<Tensor<B, 1>>,
43    /// The running mean.
44    pub running_mean: RunningState<Tensor<B, 1>>,
45    /// The running variance.
46    pub running_var: RunningState<Tensor<B, 1>>,
47    /// Momentum used to update the metrics.
48    pub momentum: f64,
49    /// A value required for numerical stability.
50    pub epsilon: f64,
51}
52
53impl BatchNormConfig {
54    /// Initializes a new [batch norm](BatchNorm) module.
55    pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> BatchNorm<B, D> {
56        let gamma = Initializer::Ones.init([self.num_features], device);
57        let beta = Initializer::Zeros.init([self.num_features], device);
58
59        let running_mean = Tensor::zeros([self.num_features], device);
60        let running_var = Tensor::ones([self.num_features], device);
61
62        BatchNorm {
63            gamma,
64            beta,
65            running_mean: RunningState::new(running_mean),
66            running_var: RunningState::new(running_var),
67            momentum: self.momentum,
68            epsilon: self.epsilon,
69        }
70    }
71}
72
73impl<const D: usize, B: Backend> BatchNorm<B, D> {
74    /// Applies the forward pass on the input tensor.
75    ///
76    /// See [BatchNorm](BatchNorm) for more information.
77    ///
78    /// # Shapes
79    ///
80    /// - input: `[batch_size, channels, ...]`
81    /// - output: `[batch_size, channels, ...]`
82    ///
83    /// # Panics
84    ///
85    /// This function will panic if the input tensor has a dimension different from `D + 2`.
86    pub fn forward<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
87        // Should be move to a compilation error when const generic support that kind of
88        // validation. https://github.com/rust-lang/rust/issues/76560
89        if D + 2 != DI {
90            panic!(
91                "BatchNorm{}D can only be applied on tensors of size {} with the following shape \
92                 [batch_size, channels, ...], received {}D tensor",
93                D,
94                D + 2,
95                DI
96            );
97        }
98
99        match B::ad_enabled() {
100            true => self.forward_train(input),
101            false => self.forward_inference(input),
102        }
103    }
104
105    fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
106        let device = input.device();
107        let channels = input.dims()[1];
108        let mean = self.running_mean.value().to_device(&device);
109        let var = self.running_var.value().to_device(&device);
110
111        let mut shape = [1; DI];
112        shape[1] = channels;
113
114        self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
115    }
116
117    fn forward_train<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
118        let device = input.device();
119        let dims = input.dims();
120        let batch_size = dims[0];
121        let channels = dims[1];
122
123        let mut shape_unsqueeze = [1; DI];
124        let mut flatten_size = batch_size;
125        shape_unsqueeze[1] = channels;
126
127        for dim in dims.iter().take(DI).skip(2) {
128            flatten_size *= dim;
129        }
130
131        let mean = input
132            .clone()
133            .swap_dims(0, 1)
134            .reshape([channels, flatten_size])
135            .mean_dim(1)
136            .reshape(shape_unsqueeze);
137
138        let var = input
139            .clone()
140            .sub(mean.clone())
141            .powi_scalar(2)
142            .swap_dims(0, 1)
143            .reshape([channels, flatten_size])
144            .mean_dim(1)
145            .reshape(shape_unsqueeze);
146
147        let running_mean = self.running_mean.value_sync().to_device(&device);
148        let running_var = self.running_var.value_sync().to_device(&device);
149
150        let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
151            mean.clone()
152                .detach()
153                .mul_scalar(self.momentum)
154                .reshape([channels]),
155        );
156        let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
157            var.clone()
158                .detach()
159                .mul_scalar(self.momentum)
160                .reshape([channels]),
161        );
162
163        self.running_mean.update(running_mean.detach());
164        self.running_var.update(running_var.detach());
165
166        self.forward_shared(input, mean, var)
167    }
168
169    fn forward_shared<const DI: usize>(
170        &self,
171        x: Tensor<B, DI>,
172        mean: Tensor<B, DI>,
173        var: Tensor<B, DI>,
174    ) -> Tensor<B, DI> {
175        let channels = x.dims()[1];
176        let mut shape = [1; DI];
177        shape[1] = channels;
178
179        let std = var.add_scalar(self.epsilon).sqrt();
180
181        let x = x.sub(mean);
182        let x = x.div(std);
183
184        let x = x.mul(self.gamma.val().reshape(shape));
185
186        x.add(self.beta.val().reshape(shape))
187    }
188}
189
190impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
191    fn custom_settings(&self) -> Option<DisplaySettings> {
192        DisplaySettings::new()
193            .with_new_line_after_attribute(false)
194            .optional()
195    }
196
197    fn custom_content(&self, content: Content) -> Option<Content> {
198        let [num_features] = self.beta.shape().dims();
199
200        content
201            .add("num_features", &num_features)
202            .add("momentum", &self.momentum)
203            .add("epsilon", &self.epsilon)
204            .optional()
205    }
206}
207
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{TestAutodiffBackend, module::AutodiffModule};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// Fixed input of shape `[2, 3, 2]` shared by the tests of this module.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        Tensor::<B, 3>::from_floats(
            [
                [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
                [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
            ],
            device,
        )
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(&device);

        let actual = batch_norm.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ]);
        let tolerance = Tolerance::rel_abs(0.1, 0.001);
        actual.assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(&device);

        // One pass on the autodiff module first; its output is not needed.
        batch_norm.forward(input_tensor(&device));

        // Second pass on the module returned by `valid()`.
        let inner = batch_norm.valid();
        let actual = inner.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }
}
269
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{TestAutodiffBackend, module::AutodiffModule};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestAutodiffBackend>;

    /// Fixed input of shape `[2, 3, 2, 2]` shared by the tests of this module.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        Tensor::<B, 4>::from_floats(
            [
                [
                    [[0.9601, 0.7277], [0.1270, 0.5441]],
                    [[0.6272, 0.9034], [0.4066, 0.7179]],
                    [[0.9378, 0.7230], [0.3544, 0.9591]],
                ],
                [
                    [[0.6356, 0.1362], [0.1333, 0.7287]],
                    [[0.0249, 0.9509], [0.3791, 0.2481]],
                    [[0.6600, 0.5945], [0.5424, 0.4767]],
                ],
            ],
            device,
        )
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let actual = batch_norm.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        let tolerance = Tolerance::rel_abs(0.1, 0.001);
        actual.assert_approx_eq::<FT>(&expected, tolerance);
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        // One pass on the autodiff module first; its output is not needed.
        batch_norm.forward(input_tensor(&device));

        // Second pass on the module returned by `valid()`.
        let inner = batch_norm.valid();
        let actual = inner.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let _normalized = batch_norm.forward(input_tensor(&device));

        let actual = batch_norm
            .running_mean
            .value_sync()
            .reshape([3])
            .into_data();

        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let _normalized = batch_norm.forward(input_tensor(&device));

        let actual = batch_norm
            .running_var
            .value_sync()
            .reshape([3])
            .into_data();

        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let _normalized = batch_norm.forward(input_tensor(&device));

        // The running mean of the module returned by `valid()` must match the one held by
        // the autodiff module it was created from.
        let inner = batch_norm.valid();
        let inner_mean = inner.running_mean.value().into_data();
        let outer_mean = batch_norm.running_mean.value().into_data();

        outer_mean.assert_approx_eq::<FT>(&inner_mean, Tolerance::default());
    }

    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);
        let input = input_tensor(&device).require_grad();

        let grads = batch_norm.forward(input.clone()).backward();

        let tolerance = Tolerance::rel_abs(0.1, 0.001);

        // Gradient with respect to gamma.
        let gamma_grad = batch_norm
            .gamma
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data();
        let expected_gamma = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
        gamma_grad.assert_approx_eq::<FT>(&expected_gamma, tolerance);

        // Gradient with respect to beta.
        let beta_grad = batch_norm
            .beta
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data();
        let expected_beta = TensorData::from([8., 8., 8.]);
        beta_grad.assert_approx_eq::<FT>(&expected_beta, tolerance);

        // Gradient with respect to the input.
        let input_grad = input.grad(&grads).unwrap().into_data();
        let expected_input = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        input_grad.assert_approx_eq::<FT>(&expected_input, tolerance);
    }

    #[test]
    fn display() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let rendered = format!("{batch_norm}");

        assert_eq!(
            rendered,
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}