use crate as burn;
use crate::{
config::Config,
module::{Module, Param, RunningState},
tensor::{backend::Backend, Tensor},
};
#[derive(Config)]
pub struct BatchNormConfig {
pub num_features: usize,
#[config(default = 1e-5)]
pub epsilon: f64,
#[config(default = 0.1)]
pub momentum: f64,
}
#[derive(Module, Debug)]
pub struct BatchNorm<B: Backend, const D: usize> {
gamma: Param<Tensor<B, 1>>,
beta: Param<Tensor<B, 1>>,
running_mean: RunningState<Tensor<B, 1>>,
running_var: RunningState<Tensor<B, 1>>,
momentum: f64,
epsilon: f64,
}
impl BatchNormConfig {
pub fn init<B: Backend, const D: usize>(&self) -> BatchNorm<B, D> {
let gamma = Tensor::ones([self.num_features]);
let beta = Tensor::zeros([self.num_features]);
let running_mean = Tensor::zeros([self.num_features]);
let running_var = Tensor::ones([self.num_features]);
BatchNorm {
gamma: Param::from(gamma),
beta: Param::from(beta),
running_mean: RunningState::new(running_mean),
running_var: RunningState::new(running_var),
momentum: self.momentum,
epsilon: self.epsilon,
}
}
pub fn init_with<B: Backend, const D: usize>(
&self,
record: BatchNormRecord<B, D>,
) -> BatchNorm<B, D> {
BatchNorm {
gamma: record.gamma,
beta: record.beta,
running_mean: RunningState::from_record(record.running_mean),
running_var: RunningState::from_record(record.running_var),
momentum: self.momentum,
epsilon: self.epsilon,
}
}
}
impl<const D: usize, B: Backend> BatchNorm<B, D> {
pub fn forward<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
if D + 2 != DI {
panic!("BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor", D, D+2, DI);
}
match B::ad_enabled() {
true => self.forward_train(input),
false => self.forward_inference(input),
}
}
fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
let channels = input.dims()[1];
let mean = self.running_mean.value();
let var = self.running_var.value();
let mut shape = [1; DI];
shape[1] = channels;
self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
}
fn forward_train<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
let dims = input.dims();
let batch_size = dims[0];
let channels = dims[1];
let mut shape_unsqueeze = [1; DI];
let mut flatten_size = batch_size;
shape_unsqueeze[1] = channels;
for dim in dims.iter().take(DI).skip(2) {
flatten_size *= dim;
}
let mean = input
.clone()
.swap_dims(0, 1)
.reshape([channels, flatten_size])
.mean_dim(1)
.reshape(shape_unsqueeze);
let var = input
.clone()
.sub(mean.clone())
.powf(2.0)
.swap_dims(0, 1)
.reshape([channels, flatten_size])
.mean_dim(1)
.reshape(shape_unsqueeze);
let running_mean = self.running_mean.value_sync();
let running_var = self.running_var.value_sync();
let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
mean.clone()
.detach()
.mul_scalar(self.momentum)
.reshape([channels]),
);
let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
var.clone()
.detach()
.mul_scalar(self.momentum)
.reshape([channels]),
);
self.running_mean.update(running_mean.detach());
self.running_var.update(running_var.detach());
self.forward_shared(input, mean, var)
}
fn forward_shared<const DI: usize>(
&self,
x: Tensor<B, DI>,
mean: Tensor<B, DI>,
var: Tensor<B, DI>,
) -> Tensor<B, DI> {
let channels = x.dims()[1];
let mut shape = [1; DI];
shape[1] = channels;
let std = var.add_scalar(self.epsilon).sqrt();
let x = x.sub(mean);
let x = x.div(std);
let x = x.mul(self.gamma.val().reshape(shape));
x.add(self.beta.val().reshape(shape))
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_1d {
use super::*;
use crate::{module::ADModule, TestADBackend};
use burn_tensor::Data;
#[test]
fn batch_norm_forward_train() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 1>();
let output = module.forward(input_tensor());
output.to_data().assert_approx_eq(
&Data::from([
[
[1.1483e+00, 3.7521e-01],
[1.6272e-03, 7.5067e-01],
[1.6204e+00, -4.5168e-02],
],
[
[6.8856e-02, -1.5923e+00],
[-1.6318e+00, 8.7949e-01],
[-5.3368e-01, -1.0416e+00],
],
]),
2,
);
}
#[test]
fn batch_norm_forward_inference() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 1>();
module.forward(input_tensor());
let module = module.valid();
let output = module.forward(input_tensor());
output.to_data().assert_approx_eq(
&Data::from([
[[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],
[[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],
]),
2,
);
}
fn input_tensor<B: Backend>() -> Tensor<B, 3> {
Tensor::<B, 3>::from_floats([
[[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],
[[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],
])
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests_2d {
use super::*;
use crate::{module::ADModule, TestADBackend};
use burn_tensor::Data;
#[test]
fn batch_norm_forward_train() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();
let output = module.forward(input_tensor());
output.to_data().assert_approx_eq(
&Data::from([
[
[[1.5136, 0.7506], [-1.2216, 0.1477]],
[[0.3135, 1.2252], [-0.4150, 0.6130]],
[[1.4186, 0.3372], [-1.5183, 1.5262]],
],
[
[[0.4483, -1.1914], [-1.2010, 0.7537]],
[[-1.6752, 1.3822], [-0.5058, -0.9381]],
[[0.0200, -0.3097], [-0.5715, -0.9026]],
],
]),
2,
);
}
#[test]
fn batch_norm_forward_inference() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();
module.forward(input_tensor());
let module = module.valid();
let output = module.forward(input_tensor());
output.to_data().assert_approx_eq(
&Data::from([
[
[[0.9538, 0.7103], [0.0808, 0.5179]],
[[0.6015, 0.8910], [0.3703, 0.6966]],
[[0.9171, 0.6912], [0.3037, 0.9395]],
],
[
[[0.6138, 0.0904], [0.0874, 0.7113]],
[[-0.0297, 0.9408], [0.3415, 0.2042]],
[[0.6250, 0.5561], [0.5013, 0.4323]],
],
]),
2,
);
}
#[test]
fn batch_norm_running_mean() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();
let _output = module.forward(input_tensor());
let running_mean = module.running_mean.value_sync();
running_mean
.reshape([3])
.into_data()
.assert_approx_eq(&Data::from([0.0499, 0.0532, 0.0656]), 2);
}
#[test]
fn batch_norm_running_var() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();
let _output = module.forward(input_tensor());
let running_var = module.running_var.value_sync();
running_var
.reshape([3])
.into_data()
.assert_approx_eq(&Data::from([0.9106, 0.9105, 0.9045]), 2);
}
#[test]
fn batch_norm_running_mean_inner_module() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();
let _output = module.forward(input_tensor());
let module_valid = module.valid();
let running_mean = module_valid.running_mean.value();
let running_mean_after = module.running_mean.value();
running_mean_after
.into_data()
.assert_approx_eq(&running_mean.into_data(), 3);
}
#[test]
fn batch_norm_grads() {
let module = BatchNormConfig::new(3).init::<TestADBackend, 2>();
let input = input_tensor().require_grad();
let output = module.forward(input.clone());
let grads = output.backward();
module
.gamma
.grad(&grads)
.unwrap()
.reshape([3])
.into_data()
.assert_approx_eq(&Data::from([0.0000e+00, -5.9035e-07, -6.0011e-07]), 3);
module
.beta
.grad(&grads)
.unwrap()
.reshape([3])
.into_data()
.assert_approx_eq(&Data::from([8., 8., 8.]), 3);
input.grad(&grads).unwrap().into_data().assert_approx_eq(
&Data::from([
[
[[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
[[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],
[[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],
],
[
[[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],
[[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],
[[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],
],
]),
4,
);
}
fn input_tensor<B: Backend>() -> Tensor<B, 4> {
Tensor::<B, 4>::from_floats([
[
[[0.9601, 0.7277], [0.1270, 0.5441]],
[[0.6272, 0.9034], [0.4066, 0.7179]],
[[0.9378, 0.7230], [0.3544, 0.9591]],
],
[
[[0.6356, 0.1362], [0.1333, 0.7287]],
[[0.0249, 0.9509], [0.3791, 0.2481]],
[[0.6600, 0.5945], [0.5424, 0.4767]],
],
])
}
}