use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
use super::common::{utils, NormalizationConfig};
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Switchable Normalization over 4-D `(N, C, H, W)` inputs.
///
/// Every channel is normalised with a softmax-weighted mix of batch-norm,
/// instance-norm and layer-norm statistics. The per-channel mixing logits live
/// in the learnable `switch_weight` parameter (shape `[3, num_features]`).
pub struct SwitchableNorm2d {
    /// Parameter/buffer storage and train/eval state.
    base: ModuleBase,
    /// Number of input channels `C` this layer expects.
    num_features: usize,
    /// Epsilon, affine and running-statistics options.
    config: NormalizationConfig,
    /// Mirrors `config.track_running_stats`. It is currently never read, because
    /// the forward pass does not use running statistics yet; the lint allowance
    /// keeps the field without a dead-code warning.
    #[allow(dead_code)]
    using_movavg: bool,
}
impl SwitchableNorm2d {
pub fn new(num_features: usize) -> Result<Self> {
Self::with_config(num_features, NormalizationConfig::default())
}
pub fn with_config(num_features: usize, config: NormalizationConfig) -> Result<Self> {
let mut base = ModuleBase::new();
let switch_weight = ones(&[3, num_features])?; base.register_parameter("switch_weight".to_string(), Parameter::new(switch_weight));
if config.affine {
let weight = ones(&[num_features])?;
let bias = zeros(&[num_features])?;
base.register_parameter("weight".to_string(), Parameter::new(weight));
base.register_parameter("bias".to_string(), Parameter::new(bias));
}
if config.track_running_stats {
let running_mean = zeros(&[num_features])?;
let running_var = ones(&[num_features])?;
base.register_buffer("running_mean".to_string(), running_mean);
base.register_buffer("running_var".to_string(), running_var);
base.register_buffer("num_batches_tracked".to_string(), zeros(&[1])?);
}
let using_movavg = config.track_running_stats;
Ok(Self {
base,
num_features,
config,
using_movavg,
})
}
pub fn num_features(&self) -> usize {
self.num_features
}
pub fn eps(&self) -> f32 {
self.config.eps
}
fn compute_batch_norm_stats(&self, input: &Tensor) -> Result<(Tensor, Tensor)> {
utils::compute_channel_mean(input)
.and_then(|mean| utils::compute_channel_variance(input, &mean).map(|var| (mean, var)))
}
fn compute_instance_norm_stats(&self, input: &Tensor) -> Result<(Tensor, Tensor)> {
let input_shape = input.shape();
let dims = input_shape.dims();
let batch_size = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let input_data = input.to_vec()?;
let mut means = vec![0.0f32; batch_size * channels];
let mut vars = vec![0.0f32; batch_size * channels];
let spatial_size = (height * width) as f32;
for batch in 0..batch_size {
for c in 0..channels {
let mut sum = 0.0;
let mut sum_sq = 0.0;
for h in 0..height {
for w in 0..width {
let idx = batch * (channels * height * width)
+ c * (height * width)
+ h * width
+ w;
let val = input_data[idx];
sum += val;
sum_sq += val * val;
}
}
let mean = sum / spatial_size;
let var = (sum_sq / spatial_size) - (mean * mean);
let stat_idx = batch * channels + c;
means[stat_idx] = mean;
vars[stat_idx] = var;
}
}
let mean_tensor =
Tensor::from_data(means, vec![batch_size, channels, 1, 1], input.device())?;
let var_tensor = Tensor::from_data(vars, vec![batch_size, channels, 1, 1], input.device())?;
Ok((mean_tensor, var_tensor))
}
fn compute_layer_norm_stats(&self, input: &Tensor) -> Result<(Tensor, Tensor)> {
let input_shape = input.shape();
let dims = input_shape.dims();
let batch_size = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let input_data = input.to_vec()?;
let mut means = vec![0.0f32; batch_size];
let mut vars = vec![0.0f32; batch_size];
let layer_size = (channels * height * width) as f32;
for batch in 0..batch_size {
let mut sum = 0.0;
let mut sum_sq = 0.0;
let batch_start = batch * (channels * height * width);
for i in 0..(channels * height * width) {
let val = input_data[batch_start + i];
sum += val;
sum_sq += val * val;
}
let mean = sum / layer_size;
let var = (sum_sq / layer_size) - (mean * mean);
means[batch] = mean;
vars[batch] = var;
}
let mean_tensor = Tensor::from_data(means, vec![batch_size, 1, 1, 1], input.device())?;
let var_tensor = Tensor::from_data(vars, vec![batch_size, 1, 1, 1], input.device())?;
Ok((mean_tensor, var_tensor))
}
fn apply_switchable_norm(&self, input: &Tensor) -> Result<Tensor> {
let (bn_mean, bn_var) = self.compute_batch_norm_stats(input)?;
let (in_mean, in_var) = self.compute_instance_norm_stats(input)?;
let (ln_mean, ln_var) = self.compute_layer_norm_stats(input)?;
let switch_weight = self.base.parameters.get("switch_weight").ok_or_else(|| {
torsh_core::error::TorshError::InvalidOperation(
"Switch weight parameter not found".to_string(),
)
})?;
let switch_data = switch_weight.tensor().read().to_vec()?;
let mut normalized_weights = vec![0.0f32; switch_data.len()];
for c in 0..self.num_features {
let mut max_val = switch_data[c];
for norm_type in 1..3 {
let idx = norm_type * self.num_features + c;
if switch_data[idx] > max_val {
max_val = switch_data[idx];
}
}
let mut sum = 0.0;
for norm_type in 0..3 {
let idx = norm_type * self.num_features + c;
let exp_val = (switch_data[idx] - max_val).exp();
normalized_weights[idx] = exp_val;
sum += exp_val;
}
for norm_type in 0..3 {
let idx = norm_type * self.num_features + c;
normalized_weights[idx] /= sum;
}
}
let bn_mean_expanded = bn_mean.unsqueeze(0)?.unsqueeze(2)?.unsqueeze(3)?;
let bn_var_expanded = bn_var.unsqueeze(0)?.unsqueeze(2)?.unsqueeze(3)?;
let input_shape = input.shape();
let dims = input_shape.dims();
let mut combined_mean_data = vec![0.0f32; dims.iter().product()];
let mut combined_var_data = vec![0.0f32; dims.iter().product()];
let bn_mean_data = bn_mean_expanded.to_vec()?;
let bn_var_data = bn_var_expanded.to_vec()?;
let in_mean_data = in_mean.to_vec()?;
let in_var_data = in_var.to_vec()?;
let ln_mean_data = ln_mean.to_vec()?;
let ln_var_data = ln_var.to_vec()?;
let batch_size = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
for batch in 0..batch_size {
for c in 0..channels {
let bn_weight = normalized_weights[c];
let in_weight = normalized_weights[self.num_features + c];
let ln_weight = normalized_weights[2 * self.num_features + c];
for h in 0..height {
for w in 0..width {
let idx = batch * (channels * height * width)
+ c * (height * width)
+ h * width
+ w;
let bn_idx = c;
let in_idx = batch * channels + c;
let ln_idx = batch;
combined_mean_data[idx] = bn_weight * bn_mean_data[bn_idx]
+ in_weight * in_mean_data[in_idx]
+ ln_weight * ln_mean_data[ln_idx];
combined_var_data[idx] = bn_weight * bn_var_data[bn_idx]
+ in_weight * in_var_data[in_idx]
+ ln_weight * ln_var_data[ln_idx];
}
}
}
}
let combined_mean = Tensor::from_data(combined_mean_data, dims.to_vec(), input.device())?;
let combined_var = Tensor::from_data(combined_var_data, dims.to_vec(), input.device())?;
let weight = if self.config.affine {
self.base.parameters.get("weight")
} else {
None
};
let bias = if self.config.affine {
self.base.parameters.get("bias")
} else {
None
};
let weight_tensor = weight.as_ref().map(|p| p.tensor().read().clone());
let bias_tensor = bias.as_ref().map(|p| p.tensor().read().clone());
utils::apply_normalization(
input,
&combined_mean,
&combined_var,
weight_tensor.as_ref(),
bias_tensor.as_ref(),
self.config.eps,
)
}
}
impl Module for SwitchableNorm2d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let input_shape = input.shape();
let dims = input_shape.dims();
if dims.len() != 4 {
return Err(torsh_core::error::TorshError::InvalidShape(format!(
"SwitchableNorm2d expects 4D input (N, C, H, W), got shape {:?}",
dims
)));
}
if dims[1] != self.num_features {
return Err(torsh_core::error::TorshError::InvalidShape(format!(
"Expected {} features, got {}",
self.num_features, dims[1]
)));
}
self.apply_switchable_norm(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_switchable_norm_creation() {
        let switchable_norm = SwitchableNorm2d::new(64).expect("Switchable Norm2d should succeed");
        assert_eq!(switchable_norm.num_features(), 64);
        assert_eq!(switchable_norm.eps(), 1e-5);
    }

    #[test]
    fn test_switchable_norm_shape_validation() {
        let switchable_norm = SwitchableNorm2d::new(3).expect("Switchable Norm2d should succeed");
        let input = zeros(&[2, 3, 32, 32]).expect("zeros should succeed");
        assert!(switchable_norm.forward(&input).is_ok());
        let input_3d = zeros(&[2, 3, 32]).expect("zeros should succeed");
        assert!(switchable_norm.forward(&input_3d).is_err());
        let input_wrong_channels = zeros(&[2, 4, 32, 32]).expect("zeros should succeed");
        assert!(switchable_norm.forward(&input_wrong_channels).is_err());
    }

    /// A constant input has zero variance under every normaliser. The output
    /// must keep the input's shape, and no statistic may turn it into NaN or inf.
    #[test]
    fn test_switchable_norm_forward_output() {
        let switchable_norm = SwitchableNorm2d::new(2).expect("Switchable Norm2d should succeed");
        let input = ones(&[2, 2, 4, 4]).expect("ones should succeed");
        let output = switchable_norm
            .forward(&input)
            .expect("forward should succeed");
        assert_eq!(output.shape().dims(), input.shape().dims());
        let values = output.to_vec().expect("to_vec should succeed");
        assert_eq!(values.len(), 2 * 2 * 4 * 4);
        assert!(values.iter().all(|v| v.is_finite()));
    }
}