1use crate as burn;
2use crate::module::{Content, DisplaySettings, ModuleDisplay};
3
4use crate::nn::Initializer;
5use crate::{
6 config::Config,
7 module::{Module, Param, RunningState},
8 tensor::{backend::Backend, Tensor},
9};
10
/// Configuration used to create a [`BatchNorm`] layer through [`BatchNormConfig::init`].
#[derive(Config, Debug)]
pub struct BatchNormConfig {
    /// Number of features (channels); sets the length of the scale, shift and running-statistic
    /// tensors created by `init`.
    pub num_features: usize,
    /// Added to the variance before taking the square root, to avoid division by zero.
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// Weight given to the current batch statistics when updating the running mean/variance
    /// (`running = (1 - momentum) * running + momentum * batch`).
    #[config(default = 0.1)]
    pub momentum: f64,
}
23
/// Batch normalization layer applied over tensors shaped `[batch_size, channels, ...]`.
///
/// `D` is the number of dimensions after the channel axis, so inputs must have rank `D + 2`.
/// Create it with [`BatchNormConfig::init`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend, const D: usize> {
    /// Learnable per-channel scale, initialized to ones by `BatchNormConfig::init`.
    pub gamma: Param<Tensor<B, 1>>,
    /// Learnable per-channel shift, initialized to zeros by `BatchNormConfig::init`.
    pub beta: Param<Tensor<B, 1>>,
    /// Per-channel running mean used in inference mode; starts at zeros.
    pub running_mean: RunningState<Tensor<B, 1>>,
    /// Per-channel running variance used in inference mode; starts at ones.
    pub running_var: RunningState<Tensor<B, 1>>,
    /// Weight of the current batch statistics in the running-statistics update.
    pub momentum: f64,
    /// Added to the variance before the square root for numerical stability.
    pub epsilon: f64,
}
52
53impl BatchNormConfig {
54 pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> BatchNorm<B, D> {
56 let gamma = Initializer::Ones.init([self.num_features], device);
57 let beta = Initializer::Zeros.init([self.num_features], device);
58
59 let running_mean = Tensor::zeros([self.num_features], device);
60 let running_var = Tensor::ones([self.num_features], device);
61
62 BatchNorm {
63 gamma,
64 beta,
65 running_mean: RunningState::new(running_mean),
66 running_var: RunningState::new(running_var),
67 momentum: self.momentum,
68 epsilon: self.epsilon,
69 }
70 }
71}
72
73impl<const D: usize, B: Backend> BatchNorm<B, D> {
74 pub fn forward<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
87 if D + 2 != DI {
90 panic!(
91 "BatchNorm{}D can only be applied on tensors of size {} with the following shape \
92 [batch_size, channels, ...], received {}D tensor",
93 D,
94 D + 2,
95 DI
96 );
97 }
98
99 match B::ad_enabled() {
100 true => self.forward_train(input),
101 false => self.forward_inference(input),
102 }
103 }
104
105 fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
106 let device = input.device();
107 let channels = input.dims()[1];
108 let mean = self.running_mean.value().to_device(&device);
109 let var = self.running_var.value().to_device(&device);
110
111 let mut shape = [1; DI];
112 shape[1] = channels;
113
114 self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
115 }
116
117 fn forward_train<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
118 let device = input.device();
119 let dims = input.dims();
120 let batch_size = dims[0];
121 let channels = dims[1];
122
123 let mut shape_unsqueeze = [1; DI];
124 let mut flatten_size = batch_size;
125 shape_unsqueeze[1] = channels;
126
127 for dim in dims.iter().take(DI).skip(2) {
128 flatten_size *= dim;
129 }
130
131 let mean = input
132 .clone()
133 .swap_dims(0, 1)
134 .reshape([channels, flatten_size])
135 .mean_dim(1)
136 .reshape(shape_unsqueeze);
137
138 let var = input
139 .clone()
140 .sub(mean.clone())
141 .powf_scalar(2.0)
142 .swap_dims(0, 1)
143 .reshape([channels, flatten_size])
144 .mean_dim(1)
145 .reshape(shape_unsqueeze);
146
147 let running_mean = self.running_mean.value_sync().to_device(&device);
148 let running_var = self.running_var.value_sync().to_device(&device);
149
150 let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
151 mean.clone()
152 .detach()
153 .mul_scalar(self.momentum)
154 .reshape([channels]),
155 );
156 let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
157 var.clone()
158 .detach()
159 .mul_scalar(self.momentum)
160 .reshape([channels]),
161 );
162
163 self.running_mean.update(running_mean.detach());
164 self.running_var.update(running_var.detach());
165
166 self.forward_shared(input, mean, var)
167 }
168
169 fn forward_shared<const DI: usize>(
170 &self,
171 x: Tensor<B, DI>,
172 mean: Tensor<B, DI>,
173 var: Tensor<B, DI>,
174 ) -> Tensor<B, DI> {
175 let channels = x.dims()[1];
176 let mut shape = [1; DI];
177 shape[1] = channels;
178
179 let std = var.add_scalar(self.epsilon).sqrt();
180
181 let x = x.sub(mean);
182 let x = x.div(std);
183
184 let x = x.mul(self.gamma.val().reshape(shape));
185
186 x.add(self.beta.val().reshape(shape))
187 }
188}
189
190impl<const D: usize, B: Backend> ModuleDisplay for BatchNorm<B, D> {
191 fn custom_settings(&self) -> Option<DisplaySettings> {
192 DisplaySettings::new()
193 .with_new_line_after_attribute(false)
194 .optional()
195 }
196
197 fn custom_content(&self, content: Content) -> Option<Content> {
198 let [num_features] = self.beta.shape().dims();
199
200 content
201 .add("num_features", &num_features)
202 .add("momentum", &self.momentum)
203 .add("epsilon", &self.epsilon)
204 .optional()
205 }
206}
207
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{module::AutodiffModule, TestAutodiffBackend};

    /// Builds a 1D batch norm over 3 features on the autodiff test backend.
    fn batch_norm_1d(
        device: &<TestAutodiffBackend as Backend>::Device,
    ) -> BatchNorm<TestAutodiffBackend, 1> {
        BatchNormConfig::new(3).init::<TestAutodiffBackend, 1>(device)
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = batch_norm_1d(&device);

        let actual = batch_norm.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [1.1483e+00, 3.7521e-01],
                [1.6272e-03, 7.5067e-01],
                [1.6204e+00, -4.5168e-02],
            ],
            [
                [6.8856e-02, -1.5923e+00],
                [-1.6318e+00, 8.7949e-01],
                [-5.3368e-01, -1.0416e+00],
            ],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let trained = batch_norm_1d(&device);

        // A single training step populates the running statistics.
        trained.forward(input_tensor(&device));
        let frozen = trained.valid();
        let actual = frozen.forward(input_tensor(&device));

        let expected = TensorData::from([
            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    /// Fixed `[batch=2, channels=3, length=2]` input shared by the tests above.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
        let values = [
            [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
            [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
        ];
        Tensor::<B, 3>::from_floats(values, device)
    }
}
263
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
    use super::*;
    use crate::tensor::TensorData;
    use crate::{module::AutodiffModule, TestAutodiffBackend};

    /// Builds a 2D batch norm over 3 features on the autodiff test backend.
    fn batch_norm_2d(
        device: &<TestAutodiffBackend as Backend>::Device,
    ) -> BatchNorm<TestAutodiffBackend, 2> {
        BatchNormConfig::new(3).init::<TestAutodiffBackend, 2>(device)
    }

    #[test]
    fn batch_norm_forward_train() {
        let device = Default::default();
        let batch_norm = batch_norm_2d(&device);

        let actual = batch_norm.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[1.5136, 0.7506], [-1.2216, 0.1477]],
                [[0.3135, 1.2252], [-0.4150, 0.6130]],
                [[1.4186, 0.3372], [-1.5183, 1.5262]],
            ],
            [
                [[0.4483, -1.1914], [-1.2010, 0.7537]],
                [[-1.6752, 1.3822], [-0.5058, -0.9381]],
                [[0.0200, -0.3097], [-0.5715, -0.9026]],
            ],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_forward_inference() {
        let device = Default::default();
        let trained = batch_norm_2d(&device);

        // A single training step populates the running statistics.
        trained.forward(input_tensor(&device));
        let frozen = trained.valid();
        let actual = frozen.forward(input_tensor(&device));

        let expected = TensorData::from([
            [
                [[0.9538, 0.7103], [0.0808, 0.5179]],
                [[0.6015, 0.8910], [0.3703, 0.6966]],
                [[0.9171, 0.6912], [0.3037, 0.9395]],
            ],
            [
                [[0.6138, 0.0904], [0.0874, 0.7113]],
                [[-0.0297, 0.9408], [0.3415, 0.2042]],
                [[0.6250, 0.5561], [0.5013, 0.4323]],
            ],
        ]);
        actual.to_data().assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_running_mean() {
        let device = Default::default();
        let batch_norm = batch_norm_2d(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);
        let actual = batch_norm.running_mean.value_sync().reshape([3]).into_data();
        actual.assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_running_var() {
        let device = Default::default();
        let batch_norm = batch_norm_2d(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);
        let actual = batch_norm.running_var.value_sync().reshape([3]).into_data();
        actual.assert_approx_eq(&expected, 2);
    }

    #[test]
    fn batch_norm_running_mean_inner_module() {
        let device = Default::default();
        let batch_norm = batch_norm_2d(&device);

        let _output = batch_norm.forward(input_tensor(&device));

        // The inference copy must see the running mean produced by the training step.
        let inner = batch_norm.valid();
        let inner_mean = inner.running_mean.value();
        let outer_mean = batch_norm.running_mean.value();

        outer_mean
            .into_data()
            .assert_approx_eq(&inner_mean.into_data(), 3);
    }

    #[test]
    fn batch_norm_grads() {
        let device = Default::default();
        let batch_norm = batch_norm_2d(&device);
        let input = input_tensor(&device).require_grad();

        let output = batch_norm.forward(input.clone());
        let grads = output.backward();

        let expected_gamma = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);
        let gamma_grad = batch_norm.gamma.grad(&grads).unwrap().reshape([3]);
        gamma_grad.into_data().assert_approx_eq(&expected_gamma, 3);

        let expected_beta = TensorData::from([8., 8., 8.]);
        let beta_grad = batch_norm.beta.grad(&grads).unwrap().reshape([3]);
        beta_grad.into_data().assert_approx_eq(&expected_beta, 3);

        let expected_input = TensorData::from([
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
            ],
            [
                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
            ],
        ]);
        let input_grad = input.grad(&grads).unwrap();
        input_grad
            .into_data()
            .assert_approx_eq(&expected_input, 4);
    }

    /// Fixed `[batch=2, channels=3, height=2, width=2]` input shared by the tests above.
    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
        let values = [
            [
                [[0.9601, 0.7277], [0.1270, 0.5441]],
                [[0.6272, 0.9034], [0.4066, 0.7179]],
                [[0.9378, 0.7230], [0.3544, 0.9591]],
            ],
            [
                [[0.6356, 0.1362], [0.1333, 0.7287]],
                [[0.0249, 0.9509], [0.3791, 0.2481]],
                [[0.6600, 0.5945], [0.5424, 0.4767]],
            ],
        ];
        Tensor::<B, 4>::from_floats(values, device)
    }

    #[test]
    fn display() {
        let batch_norm = batch_norm_2d(&Default::default());

        assert_eq!(
            format!("{}", batch_norm),
            "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}"
        );
    }
}
440}