1use crate as burn;
2use crate::module::{Content, DisplaySettings, ModuleDisplay};
3
4use crate::nn::Initializer;
5use crate::{
6 config::Config,
7 module::{Module, Param, RunningState},
8 tensor::{Tensor, backend::Backend},
9};
10
11#[derive(Config, Debug)]
13pub struct BatchNormConfig {
14 pub num_features: usize,
16 #[config(default = 1e-5)]
18 pub epsilon: f64,
19 #[config(default = 0.1)]
21 pub momentum: f64,
22}
23
24#[derive(Module, Debug)]
37#[module(custom_display)]
38pub struct BatchNorm<B: Backend, const D: usize> {
39 pub gamma: Param<Tensor<B, 1>>,
41 pub beta: Param<Tensor<B, 1>>,
43 pub running_mean: RunningState<Tensor<B, 1>>,
45 pub running_var: RunningState<Tensor<B, 1>>,
47 pub momentum: f64,
49 pub epsilon: f64,
51}
52
53impl BatchNormConfig {
54 pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> BatchNorm<B, D> {
56 let gamma = Initializer::Ones.init([self.num_features], device);
57 let beta = Initializer::Zeros.init([self.num_features], device);
58
59 let running_mean = Tensor::zeros([self.num_features], device);
60 let running_var = Tensor::ones([self.num_features], device);
61
62 BatchNorm {
63 gamma,
64 beta,
65 running_mean: RunningState::new(running_mean),
66 running_var: RunningState::new(running_var),
67 momentum: self.momentum,
68 epsilon: self.epsilon,
69 }
70 }
71}
72
73impl<const D: usize, B: Backend> BatchNorm<B, D> {
74 pub fn forward<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
87 if D + 2 != DI {
90 panic!(
91 "BatchNorm{}D can only be applied on tensors of size {} with the following shape \
92 [batch_size, channels, ...], received {}D tensor",
93 D,
94 D + 2,
95 DI
96 );
97 }
98
99 match B::ad_enabled() {
100 true => self.forward_train(input),
101 false => self.forward_inference(input),
102 }
103 }
104
105 fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
106 let device = input.device();
107 let channels = input.dims()[1];
108 let mean = self.running_mean.value().to_device(&device);
109 let var = self.running_var.value().to_device(&device);
110
111 let mut shape = [1; DI];
112 shape[1] = channels;
113
114 self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
115 }
116
117 fn forward_train<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
118 let device = input.device();
119 let dims = input.dims();
120 let batch_size = dims[0];
121 let channels = dims[1];
122
123 let mut shape_unsqueeze = [1; DI];
124 let mut flatten_size = batch_size;
125 shape_unsqueeze[1] = channels;
126
127 for dim in dims.iter().take(DI).skip(2) {
128 flatten_size *= dim;
129 }
130
131 let mean = input
132 .clone()
133 .swap_dims(0, 1)
134 .reshape([channels, flatten_size])
135 .mean_dim(1)
136 .reshape(shape_unsqueeze);
137
138 let var = input
139 .clone()
140 .sub(mean.clone())
141 .powi_scalar(2)
142 .swap_dims(0, 1)
143 .reshape([channels, flatten_size])
144 .mean_dim(1)
145 .reshape(shape_unsqueeze);
146
147 let running_mean = self.running_mean.value_sync().to_device(&device);
148 let running_var = self.running_var.value_sync().to_device(&device);
149
150 let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
151 mean.clone()
152 .detach()
153 .mul_scalar(self.momentum)
154 .reshape([channels]),
155 );
156 let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
157 var.clone()
158 .detach()
159 .mul_scalar(self.momentum)
160 .reshape([channels]),
161 );
162
163 self.running_mean.update(running_mean.detach());
164 self.running_var.update(running_var.detach());
165
166 self.forward_shared(input, mean, var)
167 }
168
169 fn forward_shared<const DI: usize>(
170 &self,
171 x: Tensor<B, DI>,
172 mean: Tensor<B, DI>,
173 var: Tensor<B, DI>,
174 ) -> Tensor<B, DI> {
175 let channels = x.dims()[1];
176 let mut shape = [1; DI];
177 shape[1] = channels;
178
179 let std = var.add_scalar(self.epsilon).sqrt();
180
181 let x = x.sub(mean);
182 let x = x.div(std);
183
184 let x = x.mul(self.gamma.val().reshape(shape));
185
186 x.add(self.beta.val().reshape(shape))
187 }
188}
189
190impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
191 fn custom_settings(&self) -> Option<DisplaySettings> {
192 DisplaySettings::new()
193 .with_new_line_after_attribute(false)
194 .optional()
195 }
196
197 fn custom_content(&self, content: Content) -> Option<Content> {
198 let [num_features] = self.beta.shape().dims();
199
200 content
201 .add("num_features", &num_features)
202 .add("momentum", &self.momentum)
203 .add("epsilon", &self.epsilon)
204 .optional()
205 }
206}
207
#[cfg(all(test, feature = "std"))]
mod tests_1d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{TestAutodiffBackend, module::AutodiffModule};
    use burn_tensor::{Tolerance, ops::FloatElem};

    /// Float element type of the autodiff test backend, used for approximate comparisons.
    type FT = FloatElem<TestAutodiffBackend>;

    /// Fixed `[2, 3, 2]` input shared by the tests in this module.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        Tensor::<B, 3>::from_floats(
            [
                [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
                [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
            ],
            device,
        )
    }

    /// The forward pass on the autodiff backend matches the reference values.
    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(&device);

        let actual = batch_norm.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
    }

    /// After one forward pass on the autodiff module, the module returned by
    /// `valid()` produces the reference values.
    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(&device);

        // First pass on the autodiff module; its output is not inspected.
        batch_norm.forward(input_tensor(&device));

        let batch_norm = batch_norm.valid();
        let actual = batch_norm.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }
}
269
#[cfg(all(test, feature = "std"))]
mod tests_2d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{TestAutodiffBackend, module::AutodiffModule};
    use burn_tensor::{Tolerance, ops::FloatElem};

    /// Float element type of the autodiff test backend, used for approximate comparisons.
    type FT = FloatElem<TestAutodiffBackend>;

    /// Fixed `[2, 3, 2, 2]` input shared by the tests in this module.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        Tensor::<B, 4>::from_floats(
            [
                [
                    [[0.9601, 0.7277], [0.1270, 0.5441]],
                    [[0.6272, 0.9034], [0.4066, 0.7179]],
                    [[0.9378, 0.7230], [0.3544, 0.9591]],
                ],
                [
                    [[0.6356, 0.1362], [0.1333, 0.7287]],
                    [[0.0249, 0.9509], [0.3791, 0.2481]],
                    [[0.6600, 0.5945], [0.5424, 0.4767]],
                ],
            ],
            device,
        )
    }

    /// The forward pass on the autodiff backend matches the reference values.
    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let actual = batch_norm.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));
    }

    /// After one forward pass on the autodiff module, the module returned by
    /// `valid()` produces the reference values.
    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        // First pass on the autodiff module; its output is not inspected.
        batch_norm.forward(input_tensor(&device));

        let batch_norm = batch_norm.valid();
        let actual = batch_norm.forward(input_tensor(&device)).to_data();

        let expected = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// One forward pass on the autodiff backend moves the running mean to the
    /// reference values.
    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        let actual = batch_norm
            .running_mean
            .value_sync()
            .reshape([3])
            .into_data();

        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// One forward pass on the autodiff backend moves the running variance to the
    /// reference values.
    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        let actual = batch_norm
            .running_var
            .value_sync()
            .reshape([3])
            .into_data();

        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
        actual.assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The running mean of the module returned by `valid()` matches the running
    /// mean of the autodiff module it was derived from.
    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        let inner = batch_norm.valid();
        let inner_running_mean = inner.running_mean.value();
        let outer_running_mean = batch_norm.running_mean.value();

        let expected = inner_running_mean.into_data();
        outer_running_mean
            .into_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Gradients with respect to `gamma`, `beta` and the input match the
    /// reference values.
    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);
        let input = input_tensor(&device).require_grad();

        let grads = batch_norm.forward(input.clone()).backward();

        let tolerance = Tolerance::rel_abs(0.1, 0.001);

        let gamma_grad = batch_norm
            .gamma
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data();
        let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
        gamma_grad.assert_approx_eq::<FT>(&expected, tolerance);

        let beta_grad = batch_norm
            .beta
            .grad(&grads)
            .unwrap()
            .reshape([3])
            .into_data();
        let expected = TensorData::from([8., 8., 8.]);
        beta_grad.assert_approx_eq::<FT>(&expected, tolerance);

        let input_grad = input.grad(&grads).unwrap().into_data();
        let expected = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        input_grad.assert_approx_eq::<FT>(&expected, tolerance);
    }

    /// The `Display` output lists the configured attributes and the parameter count.
    #[test]
    fn display() {
        let device = Default::default();
        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(&device);

        let rendered = format!("{batch_norm}");

        assert_eq!(
            rendered,
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}