1use burn_core as burn;
2
3use burn::module::Initializer;
4use burn::module::{Content, DisplaySettings, ModuleDisplay};
5use burn::tensor::{Tensor, backend::Backend};
6use burn::{
7 config::Config,
8 module::{Module, Param, RunningState},
9};
10
/// Configuration to create a [BatchNorm] layer using the
/// [init function](BatchNormConfig::init).
#[derive(Config, Debug)]
pub struct BatchNormConfig {
    /// The number of features (channels) of the expected input `[batch_size, channels, ...]`.
    pub num_features: usize,
    /// A value added to the variance for numerical stability. Default: 1e-5.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// Momentum used to update the running mean and variance metrics. Default: 0.1.
    #[config(default = 0.1)]
    pub momentum: f64,
}
25
/// Applies Batch Normalization over a tensor of shape `[batch_size, channels, ...]`.
///
/// Should be created with [BatchNormConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend> {
    /// The learnable per-channel scale, shape `[num_features]` (initialized to ones).
    pub gamma: Param<Tensor<B, 1>>,
    /// The learnable per-channel shift, shape `[num_features]` (initialized to zeros).
    pub beta: Param<Tensor<B, 1>>,
    /// Running estimate of the per-channel mean, updated during training passes.
    pub running_mean: RunningState<Tensor<B, 1>>,
    /// Running estimate of the per-channel variance, updated during training passes.
    pub running_var: RunningState<Tensor<B, 1>>,
    /// Momentum used when updating the running statistics.
    pub momentum: f64,
    /// Value added to the variance for numerical stability.
    pub epsilon: f64,
}
58
59impl BatchNormConfig {
60 pub fn init<B: Backend>(&self, device: &B::Device) -> BatchNorm<B> {
62 let gamma = Initializer::Ones.init([self.num_features], device);
63 let beta = Initializer::Zeros.init([self.num_features], device);
64
65 let running_mean = Tensor::zeros([self.num_features], device);
66 let running_var = Tensor::ones([self.num_features], device);
67
68 BatchNorm {
69 gamma,
70 beta,
71 running_mean: RunningState::new(running_mean),
72 running_var: RunningState::new(running_var),
73 momentum: self.momentum,
74 epsilon: self.epsilon,
75 }
76 }
77}
78
impl<B: Backend> BatchNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// Dispatches to the training path when autodiff is enabled on the
    /// backend (`B::ad_enabled()`), otherwise uses the running statistics
    /// for inference.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, channels, ...]` (rank >= 2)
    /// - output: same shape as the input
    ///
    /// # Panics
    ///
    /// Panics if the input tensor rank `D` is lower than 2, since batch
    /// norm needs at least a batch and a channel dimension.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        if D < 2 {
            panic!(
                "BatchNorm can only be applied on tensors of rank >= 2 with the following shape \
                [batch_size, channels, ...], received {}D tensor",
                D
            );
        }

        match B::ad_enabled() {
            true => self.forward_train(input),
            false => self.forward_inference(input),
        }
    }

    /// Inference path: normalizes with the stored running mean/variance.
    fn forward_inference<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let device = input.device();
        let channels = input.dims()[1];
        // Move the running statistics to the input's device before use.
        let mean = self.running_mean.value().to_device(&device);
        let var = self.running_var.value().to_device(&device);

        // Reshape the per-channel stats `[channels]` to `[1, channels, 1, ...]`
        // so they broadcast over the batch and spatial dimensions.
        let mut shape = [1; D];
        shape[1] = channels;

        self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
    }

    /// Training path: normalizes with the current batch statistics and
    /// updates the running mean/variance with an exponential moving average.
    fn forward_train<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let device = input.device();
        let dims = input.dims();
        let batch_size = dims[0];
        let channels = dims[1];

        // Broadcast shape `[1, channels, 1, ...]` for the per-channel stats.
        let mut shape_unsqueeze = [1; D];
        // Number of elements per channel: batch * all trailing (spatial) dims.
        let mut flatten_size = batch_size;
        shape_unsqueeze[1] = channels;

        for dim in dims.iter().take(D).skip(2) {
            flatten_size *= dim;
        }

        // Batch mean per channel: bring channels to dim 0, flatten everything
        // else, average, then reshape for broadcasting.
        let mean = input
            .clone()
            .swap_dims(0, 1)
            .reshape([channels, flatten_size])
            .mean_dim(1)
            .reshape(shape_unsqueeze);

        // Biased batch variance per channel: mean of squared deviations,
        // computed with the same flatten-and-average scheme.
        let var = input
            .clone()
            .sub(mean.clone())
            .square()
            .swap_dims(0, 1)
            .reshape([channels, flatten_size])
            .mean_dim(1)
            .reshape(shape_unsqueeze);

        // NOTE(review): `value_sync()` (rather than `value()`) is used here,
        // presumably to synchronize any pending state update before reading
        // — confirm against RunningState's documentation.
        let running_mean = self.running_mean.value_sync().to_device(&device);
        let running_var = self.running_var.value_sync().to_device(&device);

        // Exponential moving average:
        // running = (1 - momentum) * running + momentum * batch_stat.
        // The batch stats are detached so the EMA update does not enter
        // the autodiff graph.
        let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
            mean.clone()
                .detach()
                .mul_scalar(self.momentum)
                .reshape([channels]),
        );
        let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
            var.clone()
                .detach()
                .mul_scalar(self.momentum)
                .reshape([channels]),
        );

        self.running_mean.update(running_mean.detach());
        self.running_var.update(running_var.detach());

        // Normalization itself uses the *non-detached* batch stats so
        // gradients flow through mean and variance.
        self.forward_shared(input, mean, var)
    }

    /// Shared normalization: `gamma * (x - mean) / sqrt(var + epsilon) + beta`.
    ///
    /// `mean` and `var` are expected already reshaped to the broadcastable
    /// `[1, channels, 1, ...]` shape.
    fn forward_shared<const D: usize>(
        &self,
        x: Tensor<B, D>,
        mean: Tensor<B, D>,
        var: Tensor<B, D>,
    ) -> Tensor<B, D> {
        let channels = x.dims()[1];
        // Broadcast shape for the learnable per-channel parameters.
        let mut shape = [1; D];
        shape[1] = channels;

        // epsilon inside the sqrt guards against division by zero.
        let std = var.add_scalar(self.epsilon).sqrt();

        let x = x.sub(mean);
        let x = x.div(std);

        let x = x.mul(self.gamma.val().reshape(shape));

        x.add(self.beta.val().reshape(shape))
    }
}
193
194impl<B: Backend> ModuleDisplay for BatchNorm<B> {
195 fn custom_settings(&self) -> Option<DisplaySettings> {
196 DisplaySettings::new()
197 .with_new_line_after_attribute(false)
198 .optional()
199 }
200
201 fn custom_content(&self, content: Content) -> Option<Content> {
202 let [num_features] = self.beta.shape().dims();
203
204 content
205 .add("num_features", &num_features)
206 .add("momentum", &self.momentum)
207 .add("epsilon", &self.epsilon)
208 .optional()
209 }
210}
211
212#[cfg(feature = "std")]
213#[cfg(test)]
214mod tests_1d {
215 use super::*;
216 use crate::TestAutodiffBackend;
217 use burn::module::AutodiffModule;
218 use burn::tensor::TensorData;
219 use burn::tensor::{Tolerance, ops::FloatElem};
220 type FT = FloatElem<TestAutodiffBackend>;
221
222 #[test]
223 fn batch_norm_forward_train() {
224 let device = Default::default();
225 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
226
227 let output = module.forward(input_tensor(&device));
228
229 let expected = TensorData::from([
230 [
231 [1.1483e+00, 3.7521e-01],
232 [1.6272e-03, 7.5067e-01],
233 [1.6204e+00, -4.5168e-02],
234 ],
235 [
236 [6.8856e-02, -1.5923e+00],
237 [-1.6318e+00, 8.7949e-01],
238 [-5.3368e-01, -1.0416e+00],
239 ],
240 ]);
241 output
242 .to_data()
243 .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
244 }
245
246 #[test]
247 fn batch_norm_forward_inference() {
248 let device = Default::default();
249 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
250
251 module.forward(input_tensor(&device));
252 let module = module.valid();
253 let output = module.forward(input_tensor(&device));
254
255 let expected = TensorData::from([
256 [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
257 [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
258 ]);
259 output
260 .to_data()
261 .assert_approx_eq::<FT>(&expected, Tolerance::default());
262 }
263
264 fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
265 Tensor::<B, 3>::from_floats(
266 [
267 [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
268 [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
269 ],
270 device,
271 )
272 }
273}
274
275#[cfg(feature = "std")]
276#[cfg(test)]
277mod tests_2d {
278 use super::*;
279 use crate::TestAutodiffBackend;
280 use burn::module::AutodiffModule;
281 use burn::tensor::TensorData;
282 use burn::tensor::{Tolerance, ops::FloatElem};
283 type FT = FloatElem<TestAutodiffBackend>;
284
285 #[test]
286 fn batch_norm_forward_train() {
287 let device = Default::default();
288 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
289
290 let output = module.forward(input_tensor(&device));
291
292 let expected = TensorData::from([
293 [
294 [[1.5136, 0.7506], [-1.2216, 0.1477]],
295 [[0.3135, 1.2252], [-0.4150, 0.6130]],
296 [[1.4186, 0.3372], [-1.5183, 1.5262]],
297 ],
298 [
299 [[0.4483, -1.1914], [-1.2010, 0.7537]],
300 [[-1.6752, 1.3822], [-0.5058, -0.9381]],
301 [[0.0200, -0.3097], [-0.5715, -0.9026]],
302 ],
303 ]);
304 output
305 .to_data()
306 .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
307 }
308
309 #[test]
310 fn batch_norm_forward_inference() {
311 let device = Default::default();
312 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
313
314 module.forward(input_tensor(&device));
315 let module = module.valid();
316 let output = module.forward(input_tensor(&device));
317
318 let expected = TensorData::from([
319 [
320 [[0.9538, 0.7103], [0.0808, 0.5179]],
321 [[0.6015, 0.8910], [0.3703, 0.6966]],
322 [[0.9171, 0.6912], [0.3037, 0.9395]],
323 ],
324 [
325 [[0.6138, 0.0904], [0.0874, 0.7113]],
326 [[-0.0297, 0.9408], [0.3415, 0.2042]],
327 [[0.6250, 0.5561], [0.5013, 0.4323]],
328 ],
329 ]);
330 output
331 .to_data()
332 .assert_approx_eq::<FT>(&expected, Tolerance::default());
333 }
334
335 #[test]
336 fn batch_norm_running_mean() {
337 let device = Default::default();
338 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
339
340 let _output = module.forward(input_tensor(&device));
341
342 let running_mean = module.running_mean.value_sync();
343
344 let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
345 running_mean
346 .reshape([3])
347 .into_data()
348 .assert_approx_eq::<FT>(&expected, Tolerance::default());
349 }
350
351 #[test]
352 fn batch_norm_running_var() {
353 let device = Default::default();
354 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
355
356 let _output = module.forward(input_tensor(&device));
357
358 let running_var = module.running_var.value_sync();
359
360 let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
361 running_var
362 .reshape([3])
363 .into_data()
364 .assert_approx_eq::<FT>(&expected, Tolerance::default());
365 }
366
367 #[test]
368 fn batch_norm_running_mean_inner_module() {
369 let device = Default::default();
370 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
371
372 let _output = module.forward(input_tensor(&device));
373
374 let module_valid = module.valid();
375 let running_mean = module_valid.running_mean.value();
376 let running_mean_after = module.running_mean.value();
377
378 running_mean_after
379 .into_data()
380 .assert_approx_eq::<FT>(&running_mean.into_data(), Tolerance::default());
381 }
382
383 #[test]
384 fn batch_norm_grads() {
385 let device = Default::default();
386 let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);
387 let input = input_tensor(&device).require_grad();
388
389 let output = module.forward(input.clone());
390
391 let grads = output.backward();
392
393 let tolerance = Tolerance::rel_abs(0.1, 0.001);
394 let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
395 module
396 .gamma
397 .grad(&grads)
398 .unwrap()
399 .reshape([3])
400 .into_data()
401 .assert_approx_eq::<FT>(&expected, tolerance);
402
403 let expected = TensorData::from([8., 8., 8.]);
404 module
405 .beta
406 .grad(&grads)
407 .unwrap()
408 .reshape([3])
409 .into_data()
410 .assert_approx_eq::<FT>(&expected, tolerance);
411
412 let expected = TensorData::from([
413 [
414 [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
415 [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
416 [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
417 ],
418 [
419 [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
420 [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
421 [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
422 ],
423 ]);
424 input
425 .grad(&grads)
426 .unwrap()
427 .into_data()
428 .assert_approx_eq::<FT>(&expected, tolerance);
429 }
430
431 fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
432 Tensor::<B, 4>::from_floats(
433 [
434 [
435 [[0.9601, 0.7277], [0.1270, 0.5441]],
436 [[0.6272, 0.9034], [0.4066, 0.7179]],
437 [[0.9378, 0.7230], [0.3544, 0.9591]],
438 ],
439 [
440 [[0.6356, 0.1362], [0.1333, 0.7287]],
441 [[0.0249, 0.9509], [0.3791, 0.2481]],
442 [[0.6600, 0.5945], [0.5424, 0.4767]],
443 ],
444 ],
445 device,
446 )
447 }
448
449 #[test]
450 fn display() {
451 let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&Default::default());
452
453 assert_eq!(
454 format!("{batch_norm}"),
455 "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
456 );
457 }
458}