1use burn_core as burn;
2
3use burn::module::Initializer;
4use burn::module::{Content, DisplaySettings, ModuleDisplay};
5use burn::tensor::{Tensor, backend::Backend};
6use burn::{
7 config::Config,
8 module::{Module, Param, RunningState},
9};
10
/// Configuration to create a [BatchNorm] layer using the [init function](BatchNormConfig::init).
#[derive(Config, Debug)]
pub struct BatchNormConfig {
    /// The number of features (channels) expected in the input tensor's dimension 1.
    pub num_features: usize,
    /// A small constant added to the variance for numerical stability. Default: 1e-5.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// Momentum used for the exponential moving average of the running statistics. Default: 0.1.
    #[config(default = 0.1)]
    pub momentum: f64,
}
25
/// Applies batch normalization over a tensor of shape `[batch_size, channels, ...]`.
///
/// Should be created with [BatchNormConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend> {
    /// Learnable per-channel scale, shape `[num_features]` (initialized to ones).
    pub gamma: Param<Tensor<B, 1>>,
    /// Learnable per-channel shift, shape `[num_features]` (initialized to zeros).
    pub beta: Param<Tensor<B, 1>>,
    /// Running estimate of the per-channel mean, updated during training forwards.
    pub running_mean: RunningState<Tensor<B, 1>>,
    /// Running estimate of the per-channel variance, updated during training forwards.
    pub running_var: RunningState<Tensor<B, 1>>,
    /// Momentum used for the exponential moving average of the running statistics.
    pub momentum: f64,
    /// Small constant added to the variance for numerical stability.
    pub epsilon: f64,
}
58
59impl BatchNormConfig {
60 pub fn init<B: Backend>(&self, device: &B::Device) -> BatchNorm<B> {
62 let gamma = Initializer::Ones.init([self.num_features], device);
63 let beta = Initializer::Zeros.init([self.num_features], device);
64
65 let running_mean = Tensor::zeros([self.num_features], device);
66 let running_var = Tensor::ones([self.num_features], device);
67
68 BatchNorm {
69 gamma,
70 beta,
71 running_mean: RunningState::new(running_mean),
72 running_var: RunningState::new(running_var),
73 momentum: self.momentum,
74 epsilon: self.epsilon,
75 }
76 }
77}
78
79impl<B: Backend> BatchNorm<B> {
80 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
93 if D < 2 {
96 panic!(
97 "BatchNorm can only be applied on tensors of rank >= 2 with the following shape \
98 [batch_size, channels, ...], received {}D tensor",
99 D
100 );
101 }
102
103 match B::ad_enabled(&input.device()) {
104 true => self.forward_train(input),
105 false => self.forward_inference(input),
106 }
107 }
108
109 fn forward_inference<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
110 let device = input.device();
111 let channels = input.dims()[1];
112 let mean = self.running_mean.value().to_device(&device);
113 let var = self.running_var.value().to_device(&device);
114
115 let mut shape = [1; D];
116 shape[1] = channels;
117
118 self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
119 }
120
121 fn forward_train<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
122 let device = input.device();
123 let dims = input.dims();
124 let batch_size = dims[0];
125 let channels = dims[1];
126
127 let mut shape_unsqueeze = [1; D];
128 let mut flatten_size = batch_size;
129 shape_unsqueeze[1] = channels;
130
131 for dim in dims.iter().take(D).skip(2) {
132 flatten_size *= dim;
133 }
134
135 let mean = input
136 .clone()
137 .swap_dims(0, 1)
138 .reshape([channels, flatten_size])
139 .mean_dim(1)
140 .reshape(shape_unsqueeze);
141
142 let var = input
143 .clone()
144 .sub(mean.clone())
145 .square()
146 .swap_dims(0, 1)
147 .reshape([channels, flatten_size])
148 .mean_dim(1)
149 .reshape(shape_unsqueeze);
150
151 let running_mean = self.running_mean.value_sync().to_device(&device);
152 let running_var = self.running_var.value_sync().to_device(&device);
153
154 let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
155 mean.clone()
156 .detach()
157 .mul_scalar(self.momentum)
158 .reshape([channels]),
159 );
160 let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
161 var.clone()
162 .detach()
163 .mul_scalar(self.momentum)
164 .reshape([channels]),
165 );
166
167 self.running_mean.update(running_mean.detach());
168 self.running_var.update(running_var.detach());
169
170 self.forward_shared(input, mean, var)
171 }
172
173 fn forward_shared<const D: usize>(
174 &self,
175 x: Tensor<B, D>,
176 mean: Tensor<B, D>,
177 var: Tensor<B, D>,
178 ) -> Tensor<B, D> {
179 let channels = x.dims()[1];
180 let mut shape = [1; D];
181 shape[1] = channels;
182
183 let std = var.add_scalar(self.epsilon).sqrt();
184
185 let x = x.sub(mean);
186 let x = x.div(std);
187
188 let x = x.mul(self.gamma.val().reshape(shape));
189
190 x.add(self.beta.val().reshape(shape))
191 }
192}
193
194impl<B: Backend> ModuleDisplay for BatchNorm<B> {
195 fn custom_settings(&self) -> Option<DisplaySettings> {
196 DisplaySettings::new()
197 .with_new_line_after_attribute(false)
198 .optional()
199 }
200
201 fn custom_content(&self, content: Content) -> Option<Content> {
202 let [num_features] = self.beta.shape().dims();
203
204 content
205 .add("num_features", &num_features)
206 .add("momentum", &self.momentum)
207 .add("epsilon", &self.epsilon)
208 .optional()
209 }
210}
211
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
    //! Tests for BatchNorm on rank-3 inputs, shape `[batch, channels, length]`.
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    // Float element type of the test backend, used by the tolerance comparisons.
    type FT = FloatElem<TestAutodiffBackend>;

    /// With autodiff enabled, forward uses batch statistics (training path).
    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let output = module.forward(input_tensor(&device));

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected_train(), Tolerance::rel_abs(0.1, 0.001));
    }

    /// After one training pass, `valid()` switches to the inference path,
    /// which normalizes with the accumulated running statistics.
    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // Training pass updates the running mean/variance.
        module.forward(input_tensor(&device));
        let module = module.valid();
        let output = module.forward(input_tensor(&device));

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());
    }

    // Expected inference output after a single training pass.
    fn expected_valid() -> TensorData {
        TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ])
    }

    // Expected training output (normalized with batch statistics).
    fn expected_train() -> TensorData {
        TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ])
    }

    // Fixed input fixture of shape [2, 3, 2] = [batch, channels, length].
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        Tensor::<B, 3>::from_floats(
            [
                [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
                [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
            ],
            device,
        )
    }

    /// Round-trips train -> valid -> train and checks each mode produces the
    /// matching expected output.
    #[test]
    fn batch_norm_forward_train_inference() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        module.forward(input_tensor(&device));
        let module = module.valid();
        let output = module.forward(input_tensor(&device));

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());

        // Back to training mode: batch statistics should be used again.
        let module = module.train::<TestAutodiffBackend>();
        let output = module.forward(input_tensor(&device));
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected_train(), Tolerance::default());
    }
}
300
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
    //! Tests for BatchNorm on rank-4 inputs, shape `[batch, channels, height, width]`.
    use super::*;
    use crate::TestAutodiffBackend;
    use burn::module::AutodiffModule;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    // Float element type of the test backend, used by the tolerance comparisons.
    type FT = FloatElem<TestAutodiffBackend>;

    /// Training path on a 4D input: normalization with batch statistics.
    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let output = module.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
    }

    /// Inference path on a 4D input after one training pass.
    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        // Training pass updates the running mean/variance.
        module.forward(input_tensor(&device));
        let module = module.valid();
        let output = module.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// One training pass with momentum 0.1 moves the running mean from zero
    /// toward the batch means.
    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = module.forward(input_tensor(&device));

        let running_mean = module.running_mean.value_sync();

        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
        running_mean
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// One training pass moves the running variance from one toward the batch
    /// variances.
    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = module.forward(input_tensor(&device));

        let running_var = module.running_var.value_sync();

        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
        running_var
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The `valid()` view shares the same running statistics as the original
    /// training module.
    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);

        let _output = module.forward(input_tensor(&device));

        let module_valid = module.valid();
        let running_mean = module_valid.running_mean.value();
        let running_mean_after = module.running_mean.value();

        running_mean_after
            .into_data()
            .assert_approx_eq::<FT>(&running_mean.into_data(), Tolerance::default());
    }

    /// Gradients through the training forward: beta receives the element count
    /// per channel, gamma's gradient is ~0 for normalized outputs summed.
    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
        let input = input_tensor(&device).require_grad();

        let output = module.forward(input.clone());

        let grads = output.backward();

        let tolerance = Tolerance::rel_abs(0.1, 0.001);
        let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
        module
            .gamma
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([8., 8., 8.]);
        module
            .beta
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);

        let expected = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        input
            .grad(&grads)
            .unwrap()
            .into_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }

    // Fixed input fixture of shape [2, 3, 2, 2] = [batch, channels, h, w].
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        Tensor::<B, 4>::from_floats(
            [
                [
                    [[0.9601, 0.7277], [0.1270, 0.5441]],
                    [[0.6272, 0.9034], [0.4066, 0.7179]],
                    [[0.9378, 0.7230], [0.3544, 0.9591]],
                ],
                [
                    [[0.6356, 0.1362], [0.1333, 0.7287]],
                    [[0.0249, 0.9509], [0.3791, 0.2481]],
                    [[0.6600, 0.5945], [0.5424, 0.4767]],
                ],
            ],
            device,
        )
    }

    /// The custom display shows config values, not raw tensors.
    #[test]
    fn display() {
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&Default::default());

        assert_eq!(
            format!("{batch_norm}"),
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}