pub mod advanced;
pub mod batch;
pub mod common;
pub mod instance;
pub mod layer_group;
pub mod weight_based;
pub use common::{utils, NormalizationConfig, NormalizationStats};
pub use batch::{BatchNorm1d, BatchNorm2d, BatchNorm3d};
pub use instance::{InstanceNorm1d, InstanceNorm2d, InstanceNorm3d};
pub use layer_group::{GroupNorm, LayerNorm, RMSNorm};
pub use weight_based::{SpectralNorm, WeightNorm, WeightStandardization};
pub use advanced::SwitchableNorm2d;
pub use BatchNorm2d as BatchNorm;
pub use GroupNorm as GN;
pub use InstanceNorm2d as InstanceNorm;
pub use LayerNorm as LN;
pub struct NormalizationFactory;
impl NormalizationFactory {
pub fn batch_norm(num_features: usize) -> torsh_core::error::Result<BatchNorm2d> {
BatchNorm2d::new(num_features)
}
pub fn layer_norm(normalized_shape: Vec<usize>) -> torsh_core::error::Result<LayerNorm> {
LayerNorm::new(normalized_shape)
}
pub fn group_norm(
num_groups: usize,
num_channels: usize,
) -> torsh_core::error::Result<GroupNorm> {
GroupNorm::new(num_groups, num_channels)
}
pub fn instance_norm(num_features: usize) -> torsh_core::error::Result<InstanceNorm2d> {
InstanceNorm2d::new(num_features)
}
pub fn switchable_norm(num_features: usize) -> torsh_core::error::Result<SwitchableNorm2d> {
SwitchableNorm2d::new(num_features)
}
pub fn rms_norm(normalized_shape: Vec<usize>) -> torsh_core::error::Result<RMSNorm> {
RMSNorm::new(normalized_shape)
}
pub fn batch_norm_training(num_features: usize) -> torsh_core::error::Result<BatchNorm2d> {
BatchNorm2d::with_config(num_features, NormalizationConfig::training())
}
pub fn batch_norm_inference(num_features: usize) -> torsh_core::error::Result<BatchNorm2d> {
BatchNorm2d::with_config(num_features, NormalizationConfig::inference())
}
pub fn layer_norm_non_affine(
normalized_shape: Vec<usize>,
) -> torsh_core::error::Result<LayerNorm> {
LayerNorm::with_config(normalized_shape, NormalizationConfig::non_affine())
}
}
/// Ready-made normalization layers with hyper-parameters fixed per preset.
///
/// Every constructor returns the underlying layer constructor's
/// `torsh_core::error::Result` unchanged, so validation errors from the layer
/// are propagated to the caller.
pub struct NormalizationPresets;

impl NormalizationPresets {
    /// Upper bound on the group count chosen by `small_batch_group_norm`.
    const MAX_SMALL_BATCH_GROUPS: usize = 32;

    /// [`BatchNorm2d`] over `num_features` channels with momentum `0.1`.
    pub fn resnet_batch_norm(num_features: usize) -> torsh_core::error::Result<BatchNorm2d> {
        BatchNorm2d::with_config(num_features, NormalizationConfig::with_momentum(0.1))
    }

    /// [`LayerNorm`] over a single dimension of size `hidden_size` with `eps = 1e-12`.
    pub fn transformer_layer_norm(hidden_size: usize) -> torsh_core::error::Result<LayerNorm> {
        LayerNorm::with_config(vec![hidden_size], NormalizationConfig::with_eps(1e-12))
    }

    /// [`InstanceNorm2d`] over `num_features` channels built with
    /// [`NormalizationConfig::non_affine`].
    pub fn style_transfer_instance_norm(
        num_features: usize,
    ) -> torsh_core::error::Result<InstanceNorm2d> {
        InstanceNorm2d::with_config(num_features, NormalizationConfig::non_affine())
    }

    /// [`GroupNorm`] over `num_channels` channels, using at most 32 groups.
    ///
    /// Group normalization partitions channels into equally sized groups, so the
    /// group count must divide `num_channels`. This picks the largest divisor of
    /// `num_channels` that does not exceed 32:
    /// - multiples of 32 get 32 groups;
    /// - counts below 32 get one channel per group (`num_channels` groups);
    /// - other counts (e.g. 48 -> 24, 80 -> 20) get a valid grouping instead of
    ///   the invalid `(32, num_channels)` pair.
    ///
    /// `num_channels == 0` is forwarded unchanged so `GroupNorm::new` reports it.
    pub fn small_batch_group_norm(num_channels: usize) -> torsh_core::error::Result<GroupNorm> {
        let num_groups = Self::largest_group_count(num_channels, Self::MAX_SMALL_BATCH_GROUPS);
        GroupNorm::new(num_groups, num_channels)
    }

    /// Largest divisor of `num_channels` that is `<= max_groups`.
    ///
    /// Returns `num_channels` itself when no candidate exists (only possible when
    /// `num_channels` or `max_groups` is 0), leaving validation to the caller.
    fn largest_group_count(num_channels: usize, max_groups: usize) -> usize {
        (1..=num_channels.min(max_groups))
            .rev()
            .find(|&groups| num_channels % groups == 0)
            .unwrap_or(num_channels)
    }

    /// [`RMSNorm`] over a single dimension of size `hidden_size` with `eps = 1e-6`.
    ///
    /// NOTE(review): the trailing `true` passed to `RMSNorm::with_config` presumably
    /// enables a learnable scale — confirm against `RMSNorm::with_config`.
    pub fn llama_rms_norm(hidden_size: usize) -> torsh_core::error::Result<RMSNorm> {
        RMSNorm::with_config(vec![hidden_size], 1e-6, true)
    }

    /// [`RMSNorm`] over a single dimension of size `hidden_size` with `eps = 1e-5`.
    ///
    /// NOTE(review): the trailing `true` passed to `RMSNorm::with_config` presumably
    /// enables a learnable scale — confirm against `RMSNorm::with_config`.
    pub fn gpt_rms_norm(hidden_size: usize) -> torsh_core::error::Result<RMSNorm> {
        RMSNorm::with_config(vec![hidden_size], 1e-5, true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Module;
    use torsh_tensor::creation::zeros;

    #[test]
    fn test_normalization_factory() {
        let bn =
            NormalizationFactory::batch_norm(64).expect("Normalization Factory should succeed");
        assert_eq!(bn.num_features(), 64);
        let ln = NormalizationFactory::layer_norm(vec![128])
            .expect("Normalization Factory should succeed");
        assert_eq!(ln.normalized_shape(), &[128]);
        let gn =
            NormalizationFactory::group_norm(8, 64).expect("Normalization Factory should succeed");
        assert_eq!(gn.num_groups(), 8);
        assert_eq!(gn.num_channels(), 64);
        let inn =
            NormalizationFactory::instance_norm(32).expect("Normalization Factory should succeed");
        assert_eq!(inn.num_features(), 32);
        let sn = NormalizationFactory::switchable_norm(16)
            .expect("Normalization Factory should succeed");
        assert_eq!(sn.num_features(), 16);
    }

    /// Covers the config-based factory constructors and `rms_norm`, which the
    /// basic factory test does not exercise.
    #[test]
    fn test_normalization_factory_config_variants() {
        let bn_train = NormalizationFactory::batch_norm_training(64)
            .expect("Normalization Factory should succeed");
        assert_eq!(bn_train.num_features(), 64);
        let bn_infer = NormalizationFactory::batch_norm_inference(64)
            .expect("Normalization Factory should succeed");
        assert_eq!(bn_infer.num_features(), 64);
        let ln = NormalizationFactory::layer_norm_non_affine(vec![128])
            .expect("Normalization Factory should succeed");
        assert_eq!(ln.normalized_shape(), &[128]);
        // A non-affine LayerNorm has no learnable weight/bias to register.
        assert!(ln.parameters().is_empty());
        assert!(NormalizationFactory::rms_norm(vec![64]).is_ok());
    }

    #[test]
    fn test_normalization_presets() {
        let resnet_bn = NormalizationPresets::resnet_batch_norm(64)
            .expect("Normalization Presets should succeed");
        assert_eq!(resnet_bn.momentum(), 0.1);
        let transformer_ln = NormalizationPresets::transformer_layer_norm(768)
            .expect("Normalization Presets should succeed");
        assert_eq!(transformer_ln.eps(), 1e-12);
        let style_in = NormalizationPresets::style_transfer_instance_norm(64)
            .expect("Normalization Presets should succeed");
        assert!(style_in.parameters().is_empty());
        let small_batch_gn = NormalizationPresets::small_batch_group_norm(64)
            .expect("Normalization Presets should succeed");
        assert_eq!(small_batch_gn.num_groups(), 32);
    }

    /// Fewer than 32 channels: one channel per group.
    #[test]
    fn test_small_batch_group_norm_few_channels() {
        let gn = NormalizationPresets::small_batch_group_norm(16)
            .expect("Normalization Presets should succeed");
        assert_eq!(gn.num_groups(), 16);
        assert_eq!(gn.num_channels(), 16);
    }

    #[test]
    fn test_rms_norm_presets() {
        assert!(NormalizationPresets::llama_rms_norm(4096).is_ok());
        assert!(NormalizationPresets::gpt_rms_norm(768).is_ok());
    }

    #[test]
    fn test_module_integration() {
        let input_2d = zeros(&[4, 64]).expect("zeros should succeed");
        let input_4d = zeros(&[4, 64, 32, 32]).expect("zeros should succeed");
        let bn2d = BatchNorm2d::new(64).expect("Batch Norm2d should succeed");
        assert!(bn2d.forward(&input_4d).is_ok());
        let bn1d = BatchNorm1d::new(64).expect("Batch Norm1d should succeed");
        assert!(bn1d.forward(&input_2d).is_ok());
        let ln = LayerNorm::new(vec![64]).expect("Layer Norm should succeed");
        assert!(ln.forward(&input_2d).is_ok());
        let gn = GroupNorm::new(8, 64).expect("Group Norm should succeed");
        assert!(gn.forward(&input_4d).is_ok());
        let in2d = InstanceNorm2d::new(64).expect("Instance Norm2d should succeed");
        assert!(in2d.forward(&input_4d).is_ok());
    }

    #[test]
    fn test_backward_compatibility_aliases() {
        let bn = BatchNorm::new(32).expect("Batch Norm should succeed");
        assert_eq!(bn.num_features(), 32);
        let ln = LN::new(vec![128]).expect("LN should succeed");
        assert_eq!(ln.normalized_shape(), &[128]);
        let gn = GN::new(4, 32).expect("GN should succeed");
        assert_eq!(gn.num_groups(), 4);
        let inn = InstanceNorm::new(16).expect("Instance Norm should succeed");
        assert_eq!(inn.num_features(), 16);
    }

    #[test]
    fn test_configuration_variants() {
        let training_config = NormalizationConfig::training();
        assert!(training_config.track_running_stats);
        assert!(training_config.affine);
        let inference_config = NormalizationConfig::inference();
        assert!(!inference_config.track_running_stats);
        let non_affine_config = NormalizationConfig::non_affine();
        assert!(!non_affine_config.affine);
        let custom_eps_config = NormalizationConfig::with_eps(1e-8);
        assert_eq!(custom_eps_config.eps, 1e-8);
        let custom_momentum_config = NormalizationConfig::with_momentum(0.05);
        assert_eq!(custom_momentum_config.momentum, 0.05);
    }
}