use burn_core as burn;
use crate::{
BatchNorm, BatchNormConfig, GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig,
LayerNorm, LayerNormConfig, RmsNorm, RmsNormConfig,
};
use burn::prelude::{Config, Module};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
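
/// Configuration for the [Normalization] wrapper layer.
///
/// Wraps the configuration of each supported normalization technique so the
/// technique can be selected at runtime. Each inner config converts via
/// [From], e.g. `BatchNormConfig::new(12).into()`.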
#[derive(Config, Debug)]
#[non_exhaustive]
pub enum NormalizationConfig {
    /// Batch normalization, configured by [BatchNormConfig].
    Batch(BatchNormConfig),
    /// Group normalization, configured by [GroupNormConfig].
    Group(GroupNormConfig),
    /// Instance normalization, configured by [InstanceNormConfig].
    Instance(InstanceNormConfig),
    /// Layer normalization, configured by [LayerNormConfig].
    Layer(LayerNormConfig),
    /// RMS normalization, configured by [RmsNormConfig].
    Rms(RmsNormConfig),
}
impl From<BatchNormConfig> for NormalizationConfig {
fn from(config: BatchNormConfig) -> Self {
Self::Batch(config)
}
}
impl From<GroupNormConfig> for NormalizationConfig {
fn from(config: GroupNormConfig) -> Self {
Self::Group(config)
}
}
impl From<InstanceNormConfig> for NormalizationConfig {
fn from(config: InstanceNormConfig) -> Self {
Self::Instance(config)
}
}
impl From<LayerNormConfig> for NormalizationConfig {
fn from(config: LayerNormConfig) -> Self {
Self::Layer(config)
}
}
impl From<RmsNormConfig> for NormalizationConfig {
fn from(config: RmsNormConfig) -> Self {
Self::Rms(config)
}
}
impl NormalizationConfig {
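    /// Initialize the configured [Normalization] layer on the given device.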
pub fn init<B: Backend>(&self, device: &B::Device) -> Normalization<B> {
match self {
NormalizationConfig::Batch(config) => config.init(device).into(),
NormalizationConfig::Group(config) => config.init(device).into(),
NormalizationConfig::Instance(config) => config.init(device).into(),
NormalizationConfig::Layer(config) => config.init(device).into(),
NormalizationConfig::Rms(config) => config.init(device).into(),
}
}
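
    /// Return a copy of this config with the feature dimension replaced.
    ///
    /// The field updated depends on the variant: `num_features` for batch
    /// norm, `num_channels` for group and instance norm, and `d_model` for
    /// layer and RMS norm.
    ///
    /// # Example
    ///
    /// A minimal sketch (this mirrors `test_match_feature_size` below):
    ///
    /// ```ignore
    /// let config: NormalizationConfig = LayerNormConfig::new(0).into();
    /// let config = config.with_num_features(12);
    /// assert_eq!(config.num_features(), 12);
    /// ```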
pub fn with_num_features(self, num_features: usize) -> Self {
match self {
NormalizationConfig::Batch(config) => BatchNormConfig {
num_features,
..config
}
.into(),
NormalizationConfig::Group(config) => GroupNormConfig {
num_channels: num_features,
..config
}
.into(),
NormalizationConfig::Instance(config) => InstanceNormConfig {
num_channels: num_features,
..config
}
.into(),
NormalizationConfig::Layer(config) => LayerNormConfig {
d_model: num_features,
..config
}
.into(),
NormalizationConfig::Rms(config) => RmsNormConfig {
d_model: num_features,
..config
}
.into(),
}
}
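
    /// The configured feature dimension (`num_features`, `num_channels`, or
    /// `d_model`, depending on the variant).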
pub fn num_features(&self) -> usize {
match self {
NormalizationConfig::Batch(config) => config.num_features,
NormalizationConfig::Group(config) => config.num_channels,
NormalizationConfig::Instance(config) => config.num_channels,
NormalizationConfig::Layer(config) => config.d_model,
NormalizationConfig::Rms(config) => config.d_model,
}
}
}
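
/// A normalization layer, wrapping the supported normalization techniques.
///
/// Create one with [NormalizationConfig::init], or convert a concrete layer
/// via [From].
///
/// # Example
///
/// A minimal sketch; `B` stands in for any [Backend] implementation, and
/// `device` and `input` for a device and a float tensor on that backend:
///
/// ```ignore
/// let norm: Normalization<B> = LayerNormConfig::new(12).init(&device).into();
/// let output = norm.forward(input);
/// ```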
#[derive(Module, Debug)]
#[non_exhaustive]
pub enum Normalization<B: Backend> {
    /// [BatchNorm] layer.
    Batch(BatchNorm<B>),
    /// [GroupNorm] layer.
    Group(GroupNorm<B>),
    /// [InstanceNorm] layer.
    Instance(InstanceNorm<B>),
    /// [LayerNorm] layer.
    Layer(LayerNorm<B>),
    /// [RmsNorm] layer.
    Rms(RmsNorm<B>),
}
impl<B: Backend> From<BatchNorm<B>> for Normalization<B> {
fn from(layer: BatchNorm<B>) -> Self {
Self::Batch(layer)
}
}
impl<B: Backend> From<GroupNorm<B>> for Normalization<B> {
fn from(layer: GroupNorm<B>) -> Self {
Self::Group(layer)
}
}
impl<B: Backend> From<InstanceNorm<B>> for Normalization<B> {
fn from(layer: InstanceNorm<B>) -> Self {
Self::Instance(layer)
}
}
impl<B: Backend> From<LayerNorm<B>> for Normalization<B> {
fn from(layer: LayerNorm<B>) -> Self {
Self::Layer(layer)
}
}
impl<B: Backend> From<RmsNorm<B>> for Normalization<B> {
fn from(layer: RmsNorm<B>) -> Self {
Self::Rms(layer)
}
}
impl<B: Backend> Normalization<B> {
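    /// Apply the wrapped normalization to `input`, delegating to the inner
    /// layer's `forward`.
    ///
    /// The expected input layout depends on the variant: batch, group, and
    /// instance norm normalize over the channel axis at dimension 1, while
    /// layer and RMS norm normalize over the last dimension (see the tests
    /// below).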
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
match self {
Normalization::Batch(norm) => norm.forward(input),
Normalization::Group(norm) => norm.forward(input),
Normalization::Instance(norm) => norm.forward(input),
Normalization::Layer(norm) => norm.forward(input),
Normalization::Rms(norm) => norm.forward(input),
}
}
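
    /// The feature dimension of the wrapped layer, recovered from its
    /// parameters (the length of `gamma`) or its stored channel count.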
pub fn num_features(&self) -> usize {
match self {
Normalization::Batch(norm) => norm.gamma.shape().dims[0],
Normalization::Group(norm) => norm.num_channels,
Normalization::Instance(norm) => norm.num_channels,
Normalization::Layer(norm) => norm.gamma.shape().dims[0],
Normalization::Rms(norm) => norm.gamma.shape().dims[0],
}
}
}
#[cfg(feature = "std")]
#[cfg(test)]
mod tests {
use super::*;
use crate::TestAutodiffBackend;
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestAutodiffBackend>;
#[test]
fn test_match_feature_size() {
let config: NormalizationConfig = BatchNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = GroupNormConfig::new(4, 0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = InstanceNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = LayerNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
let config: NormalizationConfig = RmsNormConfig::new(0).into();
assert_eq!(config.num_features(), 0);
let config = config.with_num_features(12);
assert_eq!(config.num_features(), 12);
}
#[test]
fn test_batch_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
let config: NormalizationConfig = BatchNormConfig::new(num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Batch(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output.to_data().assert_eq(&expected.to_data(), true);
}
#[test]
fn test_group_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
let config: NormalizationConfig = GroupNormConfig::new(3, num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Group(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn test_instance_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);
let config: NormalizationConfig = InstanceNormConfig::new(num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Instance(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn test_layer_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);
let config: NormalizationConfig = LayerNormConfig::new(num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Layer(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
#[test]
fn test_rms_norm() {
type B = TestAutodiffBackend;
let device = Default::default();
let num_features = 12;
let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);
let config: NormalizationConfig = RmsNormConfig::new(num_features).into();
let layer: Normalization<B> = config.init(&device);
assert_eq!(layer.num_features(), 12);
let expected = match &layer {
Normalization::Rms(inner) => inner.forward(input.clone()),
_ => panic!("Unexpected layer type"),
};
let output = layer.forward(input);
output
.to_data()
.assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
}
}