burn_core/nn/norm/
batch.rs

1use crate as burn;
2use crate::module::{Content, DisplaySettings, ModuleDisplay};
3
4use crate::nn::Initializer;
5use crate::{
6    config::Config,
7    module::{Module, Param, RunningState},
8    tensor::{backend::Backend, Tensor},
9};
10
11/// Configuration to create a [BatchNorm](BatchNorm) layer using the [init function](BatchNormConfig::init).
12#[derive(Config, Debug)]
13pub struct BatchNormConfig {
14    /// The number of features.
15    pub num_features: usize,
16    /// A value required for numerical stability. Default: 1e-5
17    #[config(default = 1e-5)]
18    pub epsilon: f64,
19    /// Momentum used to update the metrics. Default: 0.1
20    #[config(default = 0.1)]
21    pub momentum: f64,
22}
23
24/// Applies Batch Normalization over a tensor as described in the paper [Batch Normalization](https://arxiv.org/abs/1502.03167)
25///
26/// `Y = norm(X) * γ + β`
27///
28/// Where:
29/// - `X` is the input tensor
30/// - `Y` is the output tensor
31/// - `norm` is the normalization function
32/// - `γ` is the learnable weight
33/// - `β` is the learnable bias
34///
35/// Should be created using [BatchNormConfig].
36#[derive(Module, Debug)]
37#[module(custom_display)]
38pub struct BatchNorm<B: Backend, const D: usize> {
39    /// The learnable weight gamma.
40    pub gamma: Param<Tensor<B, 1>>,
41    /// The learnable weight beta.
42    pub beta: Param<Tensor<B, 1>>,
43    /// The running mean.
44    pub running_mean: RunningState<Tensor<B, 1>>,
45    /// The running variance.
46    pub running_var: RunningState<Tensor<B, 1>>,
47    /// Momentum used to update the metrics.
48    pub momentum: f64,
49    /// A value required for numerical stability.
50    pub epsilon: f64,
51}
52
53impl BatchNormConfig {
54    /// Initializes a new [batch norm](BatchNorm) module.
55    pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> BatchNorm<B, D> {
56        let gamma = Initializer::Ones.init([self.num_features], device);
57        let beta = Initializer::Zeros.init([self.num_features], device);
58
59        let running_mean = Tensor::zeros([self.num_features], device);
60        let running_var = Tensor::ones([self.num_features], device);
61
62        BatchNorm {
63            gamma,
64            beta,
65            running_mean: RunningState::new(running_mean),
66            running_var: RunningState::new(running_var),
67            momentum: self.momentum,
68            epsilon: self.epsilon,
69        }
70    }
71}
72
73impl<const D: usize, B: Backend> BatchNorm<B, D> {
74    /// Applies the forward pass on the input tensor.
75    ///
76    /// See [BatchNorm](BatchNorm) for more information.
77    ///
78    /// # Shapes
79    ///
80    /// - input: `[batch_size, channels, ...]`
81    /// - output: `[batch_size, channels, ...]`
82    ///
83    /// # Panics
84    ///
85    /// This function will panic if the input tensor has a dimension different from `D + 2`.
86    pub fn forward<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
87        // Should be move to a compilation error when const generic support that kind of
88        // validation. https://github.com/rust-lang/rust/issues/76560
89        if D + 2 != DI {
90            panic!(
91                "BatchNorm{}D can only be applied on tensors of size {} with the following shape \
92                 [batch_size, channels, ...], received {}D tensor",
93                D,
94                D + 2,
95                DI
96            );
97        }
98
99        match B::ad_enabled() {
100            true => self.forward_train(input),
101            false => self.forward_inference(input),
102        }
103    }
104
105    fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
106        let device = input.device();
107        let channels = input.dims()[1];
108        let mean = self.running_mean.value().to_device(&device);
109        let var = self.running_var.value().to_device(&device);
110
111        let mut shape = [1; DI];
112        shape[1] = channels;
113
114        self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
115    }
116
117    fn forward_train<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
118        let device = input.device();
119        let dims = input.dims();
120        let batch_size = dims[0];
121        let channels = dims[1];
122
123        let mut shape_unsqueeze = [1; DI];
124        let mut flatten_size = batch_size;
125        shape_unsqueeze[1] = channels;
126
127        for dim in dims.iter().take(DI).skip(2) {
128            flatten_size *= dim;
129        }
130
131        let mean = input
132            .clone()
133            .swap_dims(0, 1)
134            .reshape([channels, flatten_size])
135            .mean_dim(1)
136            .reshape(shape_unsqueeze);
137
138        let var = input
139            .clone()
140            .sub(mean.clone())
141            .powf_scalar(2.0)
142            .swap_dims(0, 1)
143            .reshape([channels, flatten_size])
144            .mean_dim(1)
145            .reshape(shape_unsqueeze);
146
147        let running_mean = self.running_mean.value_sync().to_device(&device);
148        let running_var = self.running_var.value_sync().to_device(&device);
149
150        let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
151            mean.clone()
152                .detach()
153                .mul_scalar(self.momentum)
154                .reshape([channels]),
155        );
156        let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
157            var.clone()
158                .detach()
159                .mul_scalar(self.momentum)
160                .reshape([channels]),
161        );
162
163        self.running_mean.update(running_mean.detach());
164        self.running_var.update(running_var.detach());
165
166        self.forward_shared(input, mean, var)
167    }
168
169    fn forward_shared<const DI: usize>(
170        &self,
171        x: Tensor<B, DI>,
172        mean: Tensor<B, DI>,
173        var: Tensor<B, DI>,
174    ) -> Tensor<B, DI> {
175        let channels = x.dims()[1];
176        let mut shape = [1; DI];
177        shape[1] = channels;
178
179        let std = var.add_scalar(self.epsilon).sqrt();
180
181        let x = x.sub(mean);
182        let x = x.div(std);
183
184        let x = x.mul(self.gamma.val().reshape(shape));
185
186        x.add(self.beta.val().reshape(shape))
187    }
188}
189
190impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
191    fn custom_settings(&self) -> Option<DisplaySettings> {
192        DisplaySettings::new()
193            .with_new_line_after_attribute(false)
194            .optional()
195    }
196
197    fn custom_content(&self, content: Content) -> Option<Content> {
198        let [num_features] = self.beta.shape().dims();
199
200        content
201            .add("num_features", &num_features)
202            .add("momentum", &self.momentum)
203            .add("epsilon", &self.epsilon)
204            .optional()
205    }
206}
207
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{module::AutodiffModule, TestAutodiffBackend};

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(&device);

        let actual = batch_norm.forward(input_tensor(&device));

        // Training mode normalizes with the batch statistics.
        let expected = TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let trained = BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(&device);

        // A single training step populates the running statistics used at inference.
        trained.forward(input_tensor(&device));
        let inference = trained.valid();
        let actual = inference.forward(input_tensor(&device));

        let expected = TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    /// Shared `[batch_size = 2, channels = 3, length = 2]` fixture.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        let values = [
            [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
            [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
        ];
        Tensor::<B, 3>::from_floats(values, device)
    }
}
263
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{module::AutodiffModule, TestAutodiffBackend};

    /// Fresh 3-feature `BatchNorm2d` on the autodiff test backend.
    fn new_module(
        device: &<TestAutodiffBackend as Backend>::Device,
    ) -> BatchNorm<TestAutodiffBackend, 2> {
        BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(device)
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = new_module(&device);

        let actual = batch_norm.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let trained = new_module(&device);

        // A single training step populates the running statistics used at inference.
        trained.forward(input_tensor(&device));
        let inference = trained.valid();
        let actual = inference.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let batch_norm = new_module(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
        let actual = batch_norm.running_mean.value_sync().reshape([3]).into_data();
        actual.assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let batch_norm = new_module(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
        let actual = batch_norm.running_var.value_sync().reshape([3]).into_data();
        actual.assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let batch_norm = new_module(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        // The inference copy must observe the same running mean as the training module.
        let inner = batch_norm.valid();
        let inner_mean = inner.running_mean.value();
        let outer_mean = batch_norm.running_mean.value();

        outer_mean
            .into_data()
            .assert_approx_eq(&inner_mean.into_data(), 3);
    }

    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let batch_norm = new_module(&device);
        let input = input_tensor(&device).require_grad();

        let output = batch_norm.forward(input.clone());
        let grads = output.backward();

        let gamma_grad = batch_norm.gamma.grad(&grads).unwrap().reshape([3]);
        gamma_grad
            .into_data()
            .assert_approx_eq(&TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]), 3);

        let beta_grad = batch_norm.beta.grad(&grads).unwrap().reshape([3]);
        beta_grad
            .into_data()
            .assert_approx_eq(&TensorData::from([8., 8., 8.]), 3);

        let expected_input_grad = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        input
            .grad(&grads)
            .unwrap()
            .into_data()
            .assert_approx_eq(&expected_input_grad, 4);
    }

    /// Shared `[batch_size = 2, channels = 3, height = 2, width = 2]` fixture.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        let values = [
            [
                [[0.9601, 0.7277], [0.1270, 0.5441]],
                [[0.6272, 0.9034], [0.4066, 0.7179]],
                [[0.9378, 0.7230], [0.3544, 0.9591]],
            ],
            [
                [[0.6356, 0.1362], [0.1333, 0.7287]],
                [[0.0249, 0.9509], [0.3791, 0.2481]],
                [[0.6600, 0.5945], [0.5424, 0.4767]],
            ],
        ];
        Tensor::<B, 4>::from_floats(values, device)
    }

    #[test]
    fn display() {
        let device = Default::default();
        let batch_norm = new_module(&device);

        let rendered = format!("{}", batch_norm);
        assert_eq!(
            rendered,
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}