use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
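
/// Configuration to create a [GroupNorm](GroupNorm) layer using the
/// [init function](GroupNormConfig::init).
///
/// A construction sketch (not a doctest) using the setters generated by the
/// `Config` derive:
///
/// ```ignore
/// let config = GroupNormConfig::new(2, 6) // 2 groups, 6 channels
///     .with_epsilon(1e-6)
///     .with_affine(false);
/// ```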
#[derive(Debug, Config)]
pub struct GroupNormConfig {
    /// The number of groups to separate the channels into.
    pub num_groups: usize,
    /// The number of channels expected in the input.
    pub num_channels: usize,
    /// A value required for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// When `true`, the module has learnable per-channel affine parameters of
    /// size `num_channels`, initialized to ones (weights) and zeros (biases).
    /// Default: `true`
    #[config(default = true)]
    pub affine: bool,
}
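
/// Applies Group Normalization over a mini-batch of inputs as described in the
/// paper [Group Normalization](https://arxiv.org/abs/1803.08494).
///
/// `Y = groupnorm(X) * γ + β`
///
/// where `γ` is the learnable per-channel weight and `β` is the learnable
/// per-channel bias; both are present only when `affine` is `true`.
///
/// Should be created using [GroupNormConfig].
///
/// A minimal usage sketch (not a doctest; it assumes a backend `B` is in scope):
///
/// ```ignore
/// let device = Default::default();
/// let norm = GroupNormConfig::new(2, 6).init::<B>(&device);
/// let input = Tensor::<B, 3>::zeros([8, 6, 10], &device);
/// let output = norm.forward(input); // same shape: [8, 6, 10]
/// ```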
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct GroupNorm<B: Backend> {
    /// The learnable per-channel weight of shape `[num_channels]`; `None` when
    /// `affine` is `false`.
    pub gamma: Option<Param<Tensor<B, 1>>>,
    /// The learnable per-channel bias of shape `[num_channels]`; `None` when
    /// `affine` is `false`.
    pub beta: Option<Param<Tensor<B, 1>>>,
    /// The number of groups the channels are separated into.
    pub num_groups: usize,
    /// The number of channels expected in the input.
    pub num_channels: usize,
    /// A value required for numerical stability.
    pub epsilon: f64,
    /// Whether the module has learnable affine parameters.
    pub affine: bool,
}
impl<B: Backend> ModuleDisplay for GroupNorm<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("num_groups", &self.num_groups)
.add("num_channels", &self.num_channels)
.add("epsilon", &self.epsilon)
.add("affine", &self.affine)
.optional()
}
}
impl GroupNormConfig {
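    /// Initialize a new [group norm](GroupNorm) module.
    ///
    /// # Panics
    ///
    /// Panics if `num_channels` is not divisible by `num_groups`.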
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
assert_eq!(
self.num_channels % self.num_groups,
0,
"The number of channels must be divisible by the number of groups"
);
let (gamma, beta) = if self.affine {
let gamma = Initializer::Ones.init([self.num_channels], device);
let beta = Initializer::Zeros.init([self.num_channels], device);
(Some(gamma), Some(beta))
} else {
(None, None)
};
GroupNorm {
num_groups: self.num_groups,
num_channels: self.num_channels,
gamma,
beta,
epsilon: self.epsilon,
affine: self.affine,
}
}
}
impl<B: Backend> GroupNorm<B> {
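    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, num_channels, ...]`
    /// - output: `[batch_size, num_channels, ...]` (same as the input)
    ///
    /// # Panics
    ///
    /// Panics if the second dimension of the input is not `num_channels`.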
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
if input.shape().dims[1] != self.num_channels {
panic!(
"The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}",
self.num_channels,
input.shape().dims[1]
);
}
let gamma = self.gamma.as_ref().map(|x| x.val());
let beta = self.beta.as_ref().map(|x| x.val());
group_norm(
input,
gamma,
beta,
self.num_groups,
self.epsilon,
self.affine,
)
}
}
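
/// Applies Group Normalization over a mini-batch of inputs.
///
/// When `affine` is `true`, `gamma` and `beta` must both be `Some`; they are
/// broadcast over the channel dimension after normalization.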
pub(crate) fn group_norm<B: Backend, const D: usize>(
input: Tensor<B, D>,
gamma: Option<Tensor<B, 1>>,
beta: Option<Tensor<B, 1>>,
num_groups: usize,
epsilon: f64,
affine: bool,
) -> Tensor<B, D> {
if (beta.is_none() || gamma.is_none()) && affine {
panic!("Affine is set to true, but gamma or beta is None");
}
let shape = input.shape();
    if D < 3 {
        panic!("input rank for GroupNorm should be at least 3, but got {}", D);
    }
let batch_size = shape.dims[0];
let num_channels = shape.dims[1];
    // Each group normalizes `num_channels / num_groups` channels together with
    // all of the trailing (spatial) dimensions.
    let hidden_size = shape.dims[2..].iter().product::<usize>() * num_channels / num_groups;
let input = input.reshape([batch_size, num_groups, hidden_size]);
    let mean = input.clone().sum_dim(2) / hidden_size as f64;
    let input = input.sub(mean);

    // Biased variance over each group; epsilon belongs inside the square root
    // for numerical stability: (x - mean) / sqrt(var + eps).
    let var = input.clone().powf_scalar(2.).sum_dim(2) / hidden_size as f64;
    let input_normalized = input.div(var.add_scalar(epsilon).sqrt());
    if affine {
        // Reshape the per-channel weight and bias so they broadcast over the
        // channel dimension only.
        let mut affine_shape = [1; D];
        affine_shape[1] = num_channels;

        input_normalized
            .reshape(shape)
            .mul(gamma.unwrap().reshape(affine_shape))
            .add(beta.unwrap().reshape(affine_shape))
} else {
input_normalized.reshape(shape)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::TensorData;
use crate::TestBackend;
use alloc::format;
#[test]
fn group_norm_forward_affine_false() {
let device = Default::default();
let module = GroupNormConfig::new(2, 6)
.with_affine(false)
.init::<TestBackend>(&device);
assert!(module.gamma.is_none());
assert!(module.beta.is_none());
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([
[
[-0.3034, 0.2726, -0.9659],
[-1.1845, -1.3236, 0.0172],
[1.9507, 1.2554, -0.8625],
[1.0682, 0.3604, 0.3985],
[-0.4957, -0.4461, -0.9721],
[1.5157, -0.1546, -0.5596],
],
[
[-1.6698, -0.4040, -0.7927],
[0.3736, -0.0975, -0.1351],
[-0.9461, 0.5461, -0.6334],
[-1.0919, -0.1158, 0.1213],
[-0.9535, 0.1281, 0.4372],
[-0.2845, 0.3488, 0.5641],
],
]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([
[
[-0.1653, 0.3748, -0.7866],
[-0.9916, -1.1220, 0.1353],
[1.9485, 1.2965, -0.6896],
[1.2769, 0.3628, 0.4120],
[-0.7427, -0.6786, -1.3578],
[1.8547, -0.3022, -0.8252],
],
[
[-1.9342, 0.0211, -0.5793],
[1.2223, 0.4945, 0.4365],
[-0.8163, 1.4887, -0.3333],
[-1.7960, -0.0392, 0.3875],
[-1.5469, 0.3998, 0.9561],
[-0.3428, 0.7970, 1.1845],
],
]);
output.to_data().assert_approx_eq(&expected, 3);
}
#[test]
fn group_norm_forward_affine_true() {
let device = Default::default();
let module = GroupNormConfig::new(3, 6)
.with_affine(true)
.init::<TestBackend>(&device);
module
.gamma
.as_ref()
.expect("gamma should not be None")
.val()
.to_data()
.assert_approx_eq(&TensorData::ones::<f32, _>([6]), 3);
module
.beta
.as_ref()
.expect("beta should not be None")
.val()
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>([6]), 3);
let input = Tensor::<TestBackend, 3>::from_data(
TensorData::from([
[
[0.3345, 0.4429, 0.6639],
[0.5041, 0.4175, 0.8437],
[0.6159, 0.3758, 0.4071],
[0.5417, 0.5785, 0.7671],
[0.3837, 0.9883, 0.0420],
[0.4808, 0.8989, 0.6144],
],
[
[0.3930, 0.2098, 0.0602],
[0.2298, 0.9425, 0.0333],
[0.7409, 0.8172, 0.8879],
[0.4846, 0.0486, 0.2029],
[0.6741, 0.9765, 0.6864],
[0.2827, 0.5534, 0.2125],
],
]),
&device,
);
let output = module.forward(input);
let expected = TensorData::from([
[
[-1.1694, -0.5353, 0.7572],
[-0.1775, -0.6838, 1.8087],
[0.5205, -1.3107, -1.0723],
[-0.0459, 0.2351, 1.6734],
[-0.5796, 1.3218, -1.6544],
[-0.2744, 1.0406, 0.1459],
],
[
[0.2665, -0.3320, -0.8205],
[-0.2667, 2.0612, -0.9085],
[0.6681, 0.9102, 1.1345],
[-0.1453, -1.5287, -1.0389],
[0.4253, 1.5962, 0.4731],
[-1.0903, -0.0419, -1.3623],
],
]);
output.to_data().assert_approx_eq(&expected, 3);
}
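    // A shape-preservation sanity check, added as a sketch: group norm accepts
    // any input of rank >= 3 and returns a tensor of the same shape (here a
    // rank-4, NCHW-style input).
    #[test]
    fn group_norm_forward_preserves_shape() {
        let device = Default::default();
        let module = GroupNormConfig::new(2, 6).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 4>::ones([2, 6, 3, 4], &device);
        let output = module.forward(input);
        assert_eq!(output.shape().dims, [2, 6, 3, 4]);
    }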
#[test]
fn display() {
let config = GroupNormConfig::new(3, 6);
let group_norm = config.init::<TestBackend>(&Default::default());
assert_eq!(
format!("{}", group_norm),
"GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}"
);
}
}