use super::core::{validation, FunctionalConfig};
use crate::{func_error, validate_inputs};
use torsh_core::error::Result;
use torsh_tensor::Tensor;
/// Batch normalization over the channel axis (dim 1), using the default
/// [`FunctionalConfig`] returned by `super::core::default_config()`.
///
/// Convenience wrapper: every argument is forwarded unchanged to
/// [`batch_norm_2d_with_config`], which documents their meaning.
#[allow(clippy::too_many_arguments)]
pub fn batch_norm_2d(
    input: &Tensor,
    weight: Option<&Tensor>,
    bias: Option<&Tensor>,
    running_mean: Option<&Tensor>,
    running_var: Option<&Tensor>,
    training: bool,
    momentum: f32,
    eps: f32,
) -> Result<Tensor> {
    let default_cfg = super::core::default_config();
    batch_norm_2d_with_config(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
        &default_cfg,
    )
}
pub fn batch_norm_2d_with_config(
input: &Tensor,
weight: Option<&Tensor>,
bias: Option<&Tensor>,
running_mean: Option<&Tensor>,
running_var: Option<&Tensor>,
training: bool,
momentum: f32,
eps: f32,
config: &FunctionalConfig,
) -> Result<Tensor> {
validate_inputs!(
config,
validation::validate_not_empty(input, "input"),
validation::validate_min_ndim(input, 4, "input"),
validation::validate_positive(eps, "eps"),
validation::validate_range(momentum, 0.0, 1.0, "momentum")
);
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
let batch_size = input_shape[0];
let channels = input_shape[1];
if training {
let spatial_dims: usize = input_shape[2..].iter().product();
let total_spatial = batch_size * spatial_dims;
let reshaped = input.view(&[total_spatial as i32, channels as i32])?;
let mean = reshaped.mean(Some(&[0]), false)?;
let centered = reshaped.sub(&mean.unsqueeze(0)?)?;
let variance = centered.pow_scalar(2.0)?.mean(Some(&[0]), false)?;
let eps_tensor = torsh_tensor::creation::full(&[channels], eps)?;
let stable_var = variance.add(&eps_tensor)?;
let inv_std = stable_var.rsqrt()?;
let normalized = centered.mul_op(&inv_std.unsqueeze(0)?)?;
let input_shape_i32: Vec<i32> = input_shape.iter().map(|&x| x as i32).collect();
let output = normalized.view(&input_shape_i32)?;
let mut result = output;
if let Some(w) = weight {
let weight_expanded = w.view(&[1, channels as i32, 1, 1])?;
result = result.mul_op(&weight_expanded)?;
}
if let Some(b) = bias {
let bias_expanded = b.view(&[1, channels as i32, 1, 1])?;
result = result.add(&bias_expanded)?;
}
if let (Some(r_mean), Some(r_var)) = (running_mean, running_var) {
let momentum_tensor = torsh_tensor::creation::full(&[1], momentum)?;
let one_minus_momentum = torsh_tensor::creation::full(&[1], 1.0 - momentum)?;
let _new_running_mean = r_mean
.mul_op(&one_minus_momentum)?
.add(&mean.mul_op(&momentum_tensor)?)?;
let _new_running_var = r_var
.mul_op(&one_minus_momentum)?
.add(&variance.mul_op(&momentum_tensor)?)?;
}
Ok(result)
} else {
let default_mean = torsh_tensor::creation::zeros(&[channels])?;
let default_var = torsh_tensor::creation::ones(&[channels])?;
let r_mean = running_mean.unwrap_or(&default_mean);
let r_var = running_var.unwrap_or(&default_var);
let eps_tensor = torsh_tensor::creation::full(&[channels], eps)?;
let stable_var = r_var.add(&eps_tensor)?;
let inv_std = stable_var.rsqrt()?;
let mean_expanded = r_mean.view(&[1, channels as i32, 1, 1])?;
let inv_std_expanded = inv_std.view(&[1, channels as i32, 1, 1])?;
let normalized = input.sub(&mean_expanded)?.mul_op(&inv_std_expanded)?;
let mut result = normalized;
if let Some(w) = weight {
let weight_expanded = w.view(&[1, channels as i32, 1, 1])?;
result = result.mul_op(&weight_expanded)?;
}
if let Some(b) = bias {
let bias_expanded = b.view(&[1, channels as i32, 1, 1])?;
result = result.add(&bias_expanded)?;
}
Ok(result)
}
}
/// Batch normalization for `[N, C, L]` inputs, normalizing each channel
/// (dim 1) over the batch and length axes.
///
/// * `training == true`: per-channel mean and biased variance come from the
///   batch itself.
/// * `training == false`: `running_mean` / `running_var` (each `C` elements)
///   are used, defaulting to zeros / ones when absent.
///
/// Optional `weight` / `bias` (each `C` elements) are applied per channel.
/// The output has the same shape as `input`.
///
/// # Errors
/// * `ShapeMismatch` if `input` is not 3-D. The zeros in `expected` stand for
///   "any size".
/// * Errors from tensor ops, e.g. affine or running-statistics tensors that do
///   not hold `C` elements.
///
/// # Note
/// Running statistics are read only and never updated; `momentum` is unused.
#[allow(clippy::too_many_arguments)]
pub fn batch_norm_1d(
    input: &Tensor,
    weight: Option<&Tensor>,
    bias: Option<&Tensor>,
    running_mean: Option<&Tensor>,
    running_var: Option<&Tensor>,
    training: bool,
    momentum: f32,
    eps: f32,
) -> Result<Tensor> {
    // Running statistics are not updated here, so `momentum` has no effect.
    let _ = momentum;
    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.len() != 3 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0],
            got: input_shape.to_vec(),
        });
    }
    let channels = input_shape[1];
    // Input is already [N, C, L]. Reduce over dims 0 and 2 and keep channels
    // separate. Flattening to [N*L, C] would interleave channels whenever L > 1.
    let channel_shape = [1, channels as i32, 1];
    let (centered, variance) = if training {
        // Mean over L, then over N: equal group sizes make this the mean over all N*L values.
        let mean = input.mean(Some(&[2]), true)?.mean(Some(&[0]), true)?;
        let centered = input.sub(&mean)?;
        let variance = centered
            .pow_scalar(2.0)?
            .mean(Some(&[2]), true)?
            .mean(Some(&[0]), true)?;
        (centered, variance)
    } else {
        let default_mean = torsh_tensor::creation::zeros(&[channels])?;
        let default_var = torsh_tensor::creation::ones(&[channels])?;
        let r_mean = running_mean.unwrap_or(&default_mean);
        let r_var = running_var.unwrap_or(&default_var);
        let centered = input.sub(&r_mean.view(&channel_shape)?)?;
        (centered, r_var.view(&channel_shape)?)
    };
    let eps_tensor = torsh_tensor::creation::full(&[1], eps)?;
    let inv_std = variance.add(&eps_tensor)?.rsqrt()?;
    let mut result = centered.mul_op(&inv_std)?;
    if let Some(w) = weight {
        result = result.mul_op(&w.view(&channel_shape)?)?;
    }
    if let Some(b) = bias {
        result = result.add(&b.view(&channel_shape)?)?;
    }
    Ok(result)
}
/// Batch normalization for `[N, C, D, H, W]` inputs, normalizing each channel
/// (dim 1) over the batch and all volumetric positions.
///
/// * `training == true`: per-channel mean and biased variance come from the
///   batch itself.
/// * `training == false`: `running_mean` / `running_var` (each `C` elements)
///   are used, defaulting to zeros / ones when absent.
///
/// Optional `weight` / `bias` (each `C` elements) are applied per channel.
/// The output has the same shape as `input`.
///
/// # Errors
/// * `ShapeMismatch` if `input` is not 5-D. The zeros in `expected` stand for
///   "any size".
/// * Errors from tensor ops, e.g. affine or running-statistics tensors that do
///   not hold `C` elements.
///
/// # Note
/// Running statistics are read only and never updated; `momentum` is unused.
#[allow(clippy::too_many_arguments)]
pub fn batch_norm_3d(
    input: &Tensor,
    weight: Option<&Tensor>,
    bias: Option<&Tensor>,
    running_mean: Option<&Tensor>,
    running_var: Option<&Tensor>,
    training: bool,
    momentum: f32,
    eps: f32,
) -> Result<Tensor> {
    // Running statistics are not updated here, so `momentum` has no effect.
    let _ = momentum;
    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.len() != 5 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0, 0, 0],
            got: input_shape.to_vec(),
        });
    }
    let batch_size = input_shape[0];
    let channels = input_shape[1];
    let volume = input_shape[2] * input_shape[3] * input_shape[4];
    // A row-major reshape to [N, C, D*H*W] keeps channels separate.
    // Flattening to [N*D*H*W, C] would interleave them.
    let grouped = input.view(&[batch_size as i32, channels as i32, volume as i32])?;
    let channel_shape = [1, channels as i32, 1];
    let (centered, variance) = if training {
        // Mean over the volume, then over N: equal group sizes make this the global per-channel mean.
        let mean = grouped.mean(Some(&[2]), true)?.mean(Some(&[0]), true)?;
        let centered = grouped.sub(&mean)?;
        let variance = centered
            .pow_scalar(2.0)?
            .mean(Some(&[2]), true)?
            .mean(Some(&[0]), true)?;
        (centered, variance)
    } else {
        let default_mean = torsh_tensor::creation::zeros(&[channels])?;
        let default_var = torsh_tensor::creation::ones(&[channels])?;
        let r_mean = running_mean.unwrap_or(&default_mean);
        let r_var = running_var.unwrap_or(&default_var);
        let centered = grouped.sub(&r_mean.view(&channel_shape)?)?;
        (centered, r_var.view(&channel_shape)?)
    };
    let eps_tensor = torsh_tensor::creation::full(&[1], eps)?;
    let inv_std = variance.add(&eps_tensor)?.rsqrt()?;
    let mut result = centered.mul_op(&inv_std)?;
    if let Some(w) = weight {
        result = result.mul_op(&w.view(&channel_shape)?)?;
    }
    if let Some(b) = bias {
        result = result.add(&b.view(&channel_shape)?)?;
    }
    let input_shape_i32: Vec<i32> = input_shape.iter().map(|&d| d as i32).collect();
    result.view(&input_shape_i32)
}
#[allow(clippy::too_many_arguments)]
pub fn batch_norm(
input: &Tensor,
running_mean: Option<&Tensor>,
running_var: Option<&Tensor>,
weight: Option<&Tensor>,
bias: Option<&Tensor>,
training: bool,
momentum: f32,
eps: f32,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_dims = input_shape_obj.dims();
match input_dims.len() {
3 => {
batch_norm_1d(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
}
4 => {
batch_norm_2d(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
}
5 => {
batch_norm_3d(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
}
2 => {
let batch_size = input_dims[0] as i32;
let features = input_dims[1] as i32;
let reshaped_input = input.view(&[batch_size, features, 1])?;
let result = batch_norm_1d(
&reshaped_input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)?;
result.view(&[batch_size, features])
}
_ => {
Err(torsh_core::error::TorshError::InvalidArgument(format!(
"batch_norm: expected 2D, 3D, 4D, or 5D input, got {}D",
input_dims.len()
)))
}
}
}
pub fn layer_norm_enhanced(
input: &Tensor,
normalized_shape: &[usize],
weight: Option<&Tensor>,
bias: Option<&Tensor>,
eps: f32,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
let norm_dims = normalized_shape.len();
let _norm_size: usize = normalized_shape.iter().product();
let norm_dim_indices: Vec<usize> = (input_shape.len() - norm_dims..input_shape.len()).collect();
let mean = input.mean(Some(&norm_dim_indices), true)?;
let centered = input.sub(&mean)?;
let variance = centered
.pow_scalar(2.0)?
.mean(Some(&norm_dim_indices), true)?;
let eps_tensor = torsh_tensor::creation::full(&[1], eps)?;
let stable_var = variance.add(&eps_tensor)?;
let inv_std = stable_var.rsqrt()?;
let normalized = centered.mul_op(&inv_std)?;
let mut result = normalized;
if let Some(w) = weight {
result = result.mul_op(w)?;
}
if let Some(b) = bias {
result = result.add(b)?;
}
Ok(result)
}
/// Layer normalization. This is an alias for [`layer_norm_enhanced`] and takes
/// the same arguments: it normalizes over the trailing `normalized_shape`
/// dims, then applies the optional `weight` / `bias`.
pub fn layer_norm(
    input: &Tensor,
    normalized_shape: &[usize],
    weight: Option<&Tensor>,
    bias: Option<&Tensor>,
    eps: f32,
) -> Result<Tensor> {
    let normalized = layer_norm_enhanced(input, normalized_shape, weight, bias, eps)?;
    Ok(normalized)
}
/// Layer normalization with configuration-controlled input validation.
///
/// First runs `validate_inputs!` under `config`, checking that `input` is
/// non-empty and `eps` is positive. It then delegates to
/// [`layer_norm_enhanced`], wrapped in `func_error!` with the label
/// "Layer normalization".
///
/// NOTE(review): whether validation actually runs for a given `config`, and how
/// `func_error!` decorates a failing result, are defined by those crate macros
/// elsewhere. Confirm the behavior there.
pub fn layer_norm_configured(
    input: &Tensor,
    normalized_shape: &[usize],
    weight: Option<&Tensor>,
    bias: Option<&Tensor>,
    eps: f32,
    config: &FunctionalConfig,
) -> Result<Tensor> {
    validate_inputs!(
        config,
        validation::validate_not_empty(input, "input"),
        validation::validate_positive(eps, "eps")
    );
    func_error!(
        layer_norm_enhanced(input, normalized_shape, weight, bias, eps),
        "Layer normalization"
    )
}
/// Group normalization.
///
/// The channels (dim 1) of an `[N, C, *spatial]` input are split into
/// `num_groups` contiguous groups. Each `(sample, group)` block is normalized
/// to zero mean and unit (biased) variance over its channels and spatial
/// positions. Optional `weight` / `bias` (each `C` elements) are then applied
/// per channel. The output has the same shape as `input`.
///
/// # Errors
/// `InvalidArgument` if:
/// * the input has fewer than 2 dims;
/// * `num_groups` is zero (previously a division-by-zero panic);
/// * `C` is not divisible by `num_groups`.
///
/// Otherwise returns errors from tensor ops.
pub fn group_norm(
    input: &Tensor,
    num_groups: usize,
    weight: Option<&Tensor>,
    bias: Option<&Tensor>,
    eps: f32,
) -> Result<Tensor> {
    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.len() < 2 {
        return Err(torsh_core::error::TorshError::InvalidArgument(
            "Input must have at least 2 dimensions for group normalization".to_string(),
        ));
    }
    let batch_size = input_shape[0];
    let channels = input_shape[1];
    // `channels % 0` below would panic; reject zero groups explicitly.
    if num_groups == 0 {
        return Err(torsh_core::error::TorshError::InvalidArgument(
            "Number of groups must be positive".to_string(),
        ));
    }
    if channels % num_groups != 0 {
        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
            "Number of channels ({}) must be divisible by number of groups ({})",
            channels, num_groups
        )));
    }
    let spatial_size: usize = input_shape[2..].iter().product();
    let group_size = (channels / num_groups) * spatial_size;
    // In row-major layout each (sample, group) block is contiguous, so it maps
    // directly to one row of this 2-D view.
    let rows = input.view(&[(batch_size * num_groups) as i32, group_size as i32])?;
    let mean = rows.mean(Some(&[1]), true)?;
    let centered = rows.sub(&mean)?;
    let variance = centered.pow_scalar(2.0)?.mean(Some(&[1]), true)?;
    let eps_tensor = torsh_tensor::creation::full(&[1], eps)?;
    let inv_std = variance.add(&eps_tensor)?.rsqrt()?;
    let normalized = centered.mul_op(&inv_std)?;
    let input_shape_i32: Vec<i32> = input_shape.iter().map(|&d| d as i32).collect();
    let mut result = normalized.view(&input_shape_i32)?;
    // Per-channel affine parameters broadcast as [1, C, 1, ..., 1].
    let mut affine_shape = vec![1i32; input_shape.len()];
    affine_shape[1] = channels as i32;
    if let Some(w) = weight {
        result = result.mul_op(&w.view(&affine_shape)?)?;
    }
    if let Some(b) = bias {
        result = result.add(&b.view(&affine_shape)?)?;
    }
    Ok(result)
}
/// Instance normalization.
///
/// Every `(sample, channel)` slice of an `[N, C, *spatial]` input is
/// normalized independently over its spatial positions to zero mean and unit
/// (biased) variance. Optional per-channel `weight` / `bias` (each `C`
/// elements) are then applied. The output has the same shape as `input`.
///
/// # Errors
/// `InvalidArgument` for inputs with fewer than 3 dims; otherwise errors from
/// the underlying tensor ops.
pub fn instance_norm(
    input: &Tensor,
    weight: Option<&Tensor>,
    bias: Option<&Tensor>,
    eps: f32,
) -> Result<Tensor> {
    let shape_binding = input.shape();
    let dims = shape_binding.dims();
    if dims.len() < 3 {
        return Err(torsh_core::error::TorshError::InvalidArgument(
            "Input must have at least 3 dimensions for instance normalization".to_string(),
        ));
    }
    let (n, c) = (dims[0], dims[1]);
    let per_instance: usize = dims[2..].iter().product();

    // One row per (sample, channel) pair; row-major layout keeps each row contiguous.
    let rows = input.view(&[(n * c) as i32, per_instance as i32])?;
    let row_mean = rows.mean(Some(&[1]), true)?;
    let deviations = rows.sub(&row_mean)?;
    let row_var = deviations.pow_scalar(2.0)?.mean(Some(&[1]), true)?;
    let inv_std = row_var
        .add(&torsh_tensor::creation::full(&[1], eps)?)?
        .rsqrt()?;

    let original_shape: Vec<i32> = dims.iter().map(|&d| d as i32).collect();
    let mut out = deviations.mul_op(&inv_std)?.view(&original_shape)?;

    // [1, C, 1, ..., 1] so per-channel parameters broadcast over N and spatial dims.
    let mut affine_shape = vec![1i32; dims.len()];
    affine_shape[1] = c as i32;
    if let Some(gamma) = weight {
        out = out.mul_op(&gamma.view(&affine_shape)?)?;
    }
    if let Some(beta) = bias {
        out = out.add(&beta.view(&affine_shape)?)?;
    }
    Ok(out)
}
pub fn local_response_norm(
input: &Tensor,
size: usize,
alpha: f32,
beta: f32,
k: f32,
) -> Result<Tensor> {
let input_shape_binding = input.shape();
let input_shape = input_shape_binding.dims();
if input_shape.len() != 4 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"LRN requires 4D input [batch, channels, height, width]".to_string(),
));
}
let batch_size = input_shape[0];
let channels = input_shape[1];
let height = input_shape[2];
let width = input_shape[3];
let input_data = input.to_vec()?;
let mut output_data = vec![0.0f32; input_data.len()];
let half_size = size / 2;
for b in 0..batch_size {
for c in 0..channels {
for h in 0..height {
for w in 0..width {
let c_start = if c >= half_size { c - half_size } else { 0 };
let c_end = (c + half_size + 1).min(channels);
let mut sum_squares = 0.0f32;
for c_neighbor in c_start..c_end {
let idx = b * channels * height * width
+ c_neighbor * height * width
+ h * width
+ w;
sum_squares += input_data[idx] * input_data[idx];
}
let input_idx =
b * channels * height * width + c * height * width + h * width + w;
let scale = k + alpha * sum_squares;
output_data[input_idx] = input_data[input_idx] / scale.powf(beta);
}
}
}
}
Tensor::from_vec(output_data, input_shape)
}
/// Spectral normalization of a 2-D weight matrix via power iteration.
///
/// `weight` is read as a row-major `[out_features, in_features]` matrix `W`,
/// and `u` as a length-`out_features` vector estimate. Each of the
/// `n_power_iterations` steps sets `v = normalize(Wᵀ u)` and then
/// `u = normalize(W v)`. A final `v` is derived from the refined `u`, the
/// spectral norm is estimated as `sigma = uᵀ W v`, and the function returns
/// `(W / max(sigma, eps), u)`. `eps` is also added to every vector norm to
/// avoid division by zero.
///
/// # Errors
/// `InvalidArgument` if `weight` is not 2-D or `u` does not hold
/// `out_features` elements; otherwise errors from `to_vec` / `from_vec`.
pub fn spectral_norm(
    weight: &Tensor,
    u: &Tensor,
    n_power_iterations: usize,
    eps: f32,
) -> Result<(Tensor, Tensor)> {
    /// `Wᵀ x` for a row-major `rows x cols` matrix (result has `cols` entries).
    fn transpose_mul(w: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
        (0..cols)
            .map(|j| (0..rows).fold(0.0f32, |acc, i| acc + w[i * cols + j] * x[i]))
            .collect()
    }
    /// `W x` for a row-major `rows x cols` matrix (result has `rows` entries).
    fn mul(w: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
        (0..rows)
            .map(|i| (0..cols).fold(0.0f32, |acc, j| acc + w[i * cols + j] * x[j]))
            .collect()
    }
    /// Divides `x` in place by its L2 norm plus `eps`.
    fn normalize(x: &mut [f32], eps: f32) {
        let norm = (x.iter().map(|&a| a * a).sum::<f32>()).sqrt() + eps;
        x.iter_mut().for_each(|a| *a /= norm);
    }

    let shape_binding = weight.shape();
    let dims = shape_binding.dims();
    if dims.len() != 2 {
        return Err(torsh_core::error::TorshError::InvalidArgument(
            "Spectral normalization requires 2D weight matrix".to_string(),
        ));
    }
    let (rows, cols) = (dims[0], dims[1]);
    let w = weight.to_vec()?;
    let mut u_vec = u.to_vec()?;
    if u_vec.len() != rows {
        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
            "u vector must have size {}, got {}",
            rows,
            u_vec.len()
        )));
    }

    for _ in 0..n_power_iterations {
        let mut v_vec = transpose_mul(&w, rows, cols, &u_vec);
        normalize(&mut v_vec, eps);
        u_vec = mul(&w, rows, cols, &v_vec);
        normalize(&mut u_vec, eps);
    }

    let mut v_vec = transpose_mul(&w, rows, cols, &u_vec);
    normalize(&mut v_vec, eps);
    let w_v = mul(&w, rows, cols, &v_vec);
    let sigma = u_vec
        .iter()
        .zip(&w_v)
        .fold(0.0f32, |acc, (&ui, &wv)| acc + ui * wv);

    let denom = sigma.max(eps);
    let scaled: Vec<f32> = w.iter().map(|&x| x / denom).collect();
    let normalized_weight = Tensor::from_vec(scaled, dims)?;
    let new_u = Tensor::from_vec(u_vec, &[rows])?;
    Ok((normalized_weight, new_u))
}
/// Weight normalization: `g * weight / ||weight||`, where `||weight||` is the
/// L2 norm taken over *all* elements of `weight`.
///
/// NOTE: `_dim` is currently ignored. The norm is not computed per slice
/// along a dimension.
pub fn weight_norm(weight: &Tensor, g: &Tensor, _dim: i32) -> Result<Tensor> {
    let l2_norm = weight.pow(2.0)?.sum()?.sqrt()?;
    weight.div(&l2_norm)?.mul_op(g)
}
pub fn rms_norm(input: &Tensor, weight: &Tensor, eps: f32) -> Result<Tensor> {
let squared = input.pow(2.0)?;
let last_dim = input.shape().dims().len() - 1;
let mean_squared = squared.mean(Some(&[last_dim]), true)?;
let eps_tensor = torsh_tensor::creation::full_like(&mean_squared, eps)?;
let rms = mean_squared.add(&eps_tensor)?.sqrt()?;
let normalized = input.div(&rms)?;
normalized.mul(weight)
}
/// Configuration-aware entry points for the normalization functions.
pub mod configured {
    use super::*;

    /// Batch normalization with an explicit [`FunctionalConfig`].
    ///
    /// Forwards every argument unchanged to [`batch_norm_2d_with_config`].
    #[allow(clippy::too_many_arguments)]
    pub fn batch_norm_configured(
        input: &Tensor,
        weight: Option<&Tensor>,
        bias: Option<&Tensor>,
        running_mean: Option<&Tensor>,
        running_var: Option<&Tensor>,
        training: bool,
        momentum: f32,
        eps: f32,
        config: &FunctionalConfig,
    ) -> Result<Tensor> {
        let output = batch_norm_2d_with_config(
            input,
            weight,
            bias,
            running_mean,
            running_var,
            training,
            momentum,
            eps,
            config,
        )?;
        Ok(output)
    }
}
/// Number of elements in one normalized slice: the product of
/// `normalized_shape`, which is 1 for an empty shape.
///
/// `_input_shape` is accepted for API symmetry but not used.
pub fn get_norm_features(_input_shape: &[usize], normalized_shape: &[usize]) -> usize {
    normalized_shape
        .iter()
        .fold(1usize, |count, &extent| count * extent)
}
/// Validates parameters for a normalization over the trailing dims of an input.
///
/// # Errors
/// `InvalidArgument` if:
/// * `eps` is not strictly positive (NaN is rejected too);
/// * `normalized_shape` has more dimensions than `input_shape`;
/// * `normalized_shape` does not equal the trailing dimensions of
///   `input_shape`.
pub fn validate_norm_params(
    input_shape: &[usize],
    normalized_shape: &[usize],
    eps: f32,
) -> Result<()> {
    // `eps <= 0.0` alone is false for NaN, which would slip through.
    if eps.is_nan() || eps <= 0.0 {
        return Err(torsh_core::error::TorshError::InvalidArgument(
            "Epsilon must be positive".to_string(),
        ));
    }
    if normalized_shape.len() > input_shape.len() {
        return Err(torsh_core::error::TorshError::InvalidArgument(
            "Normalized shape cannot have more dimensions than input".to_string(),
        ));
    }
    let input_suffix = &input_shape[input_shape.len() - normalized_shape.len()..];
    if input_suffix != normalized_shape {
        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
            "Normalized shape {:?} doesn't match input shape suffix {:?}",
            normalized_shape, input_suffix
        )));
    }
    Ok(())
}
/// Creates `(weight, bias)` tensors of `shape`, filled with `init_weight` and
/// `init_bias` respectively.
///
/// # Errors
/// Propagates any error from tensor creation.
pub fn create_affine_params(
    shape: &[usize],
    init_weight: f32,
    init_bias: f32,
) -> Result<(Tensor, Tensor)> {
    use torsh_tensor::creation::full;
    Ok((full(shape, init_weight)?, full(shape, init_bias)?))
}
#[cfg(test)]
mod tests {
    // Numerical checks for the normalization functions in this module.
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_layer_norm_basic() -> Result<()> {
        // Two samples of three features; each row is normalized independently.
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
        let output = layer_norm(&input, &[3], None, None, 1e-5)?;
        assert_eq!(output.shape().dims(), &[2, 3]);
        let output_data = output.to_vec()?;
        let sample1_mean = (output_data[0] + output_data[1] + output_data[2]) / 3.0;
        let sample2_mean = (output_data[3] + output_data[4] + output_data[5]) / 3.0;
        assert_relative_eq!(sample1_mean, 0.0, epsilon = 1e-5);
        assert_relative_eq!(sample2_mean, 0.0, epsilon = 1e-5);
        // With zero mean, the average of squares is the (biased) variance.
        let sample1_var =
            (output_data[0].powi(2) + output_data[1].powi(2) + output_data[2].powi(2)) / 3.0;
        let sample2_var =
            (output_data[3].powi(2) + output_data[4].powi(2) + output_data[5].powi(2)) / 3.0;
        assert_relative_eq!(sample1_var, 1.0, epsilon = 1e-4);
        assert_relative_eq!(sample2_var, 1.0, epsilon = 1e-4);
        Ok(())
    }

    #[test]
    fn test_layer_norm_with_affine() -> Result<()> {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4])?;
        let weight = Tensor::from_vec(vec![2.0, 2.0, 2.0, 2.0], &[4])?;
        let bias = Tensor::from_vec(vec![0.5, 0.5, 0.5, 0.5], &[4])?;
        let output = layer_norm(&input, &[4], Some(&weight), Some(&bias), 1e-5)?;
        let output_data = output.to_vec()?;
        assert_eq!(output_data.len(), 4);
        // y = 2 * x_hat + 0.5 with x_hat zero-mean / unit-variance, so y has
        // mean 0.5 and (biased) variance ~4.
        let mean: f32 = output_data.iter().sum::<f32>() / 4.0;
        let var: f32 = output_data
            .iter()
            .map(|&x| (x - mean).powi(2))
            .sum::<f32>()
            / 4.0;
        assert_relative_eq!(mean, 0.5, epsilon = 1e-4);
        assert_relative_eq!(var, 4.0, epsilon = 1e-3);
        // A positive scale preserves the ordering of the input.
        assert!(output_data.windows(2).all(|pair| pair[0] < pair[1]));
        Ok(())
    }

    #[test]
    fn test_layer_norm_multidimensional() -> Result<()> {
        let input = Tensor::from_vec(
            vec![
                1.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 6.0, 4.0, 5.0, 6.0, 7.0,
                5.0, 6.0, 7.0, 8.0, 6.0, 7.0, 8.0, 9.0,
            ],
            &[2, 3, 4],
        )?;
        let output = layer_norm(&input, &[4], None, None, 1e-5)?;
        assert_eq!(output.shape().dims(), &[2, 3, 4]);
        let output_data = output.to_vec()?;
        // Every length-4 row along the last axis is normalized on its own.
        for batch in 0..2 {
            for pos in 0..3 {
                let start_idx = (batch * 3 + pos) * 4;
                let slice = &output_data[start_idx..start_idx + 4];
                let mean: f32 = slice.iter().sum::<f32>() / 4.0;
                let var: f32 = slice.iter().map(|x| x.powi(2)).sum::<f32>() / 4.0;
                assert_relative_eq!(mean, 0.0, epsilon = 1e-4);
                assert_relative_eq!(var, 1.0, epsilon = 1e-4);
            }
        }
        Ok(())
    }

    #[test]
    fn test_layer_norm_single_value() -> Result<()> {
        // Normalizing a single element always yields 0 (x - mean == 0).
        let input = Tensor::from_vec(vec![5.0, 10.0, 15.0], &[3, 1])?;
        let output = layer_norm(&input, &[1], None, None, 1e-5)?;
        let output_data = output.to_vec()?;
        for &val in output_data.iter() {
            assert_relative_eq!(val, 0.0, epsilon = 1e-5);
        }
        Ok(())
    }

    #[test]
    fn test_layer_norm_zeros_input() -> Result<()> {
        // Zero variance: eps keeps the result finite and near zero.
        let input = Tensor::from_vec(vec![0.0, 0.0, 0.0, 0.0], &[1, 4])?;
        let output = layer_norm(&input, &[4], None, None, 1e-5)?;
        let output_data = output.to_vec()?;
        for &val in output_data.iter() {
            assert!(val.abs() < 1e-2);
        }
        Ok(())
    }

    #[test]
    fn test_layer_norm_consistency() -> Result<()> {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])?;
        let output1 = layer_norm(&input, &[4], None, None, 1e-5)?;
        let output2 = layer_norm_enhanced(&input, &[4], None, None, 1e-5)?;
        let data1 = output1.to_vec()?;
        let data2 = output2.to_vec()?;
        assert_eq!(data1.len(), data2.len());
        for (v1, v2) in data1.iter().zip(data2.iter()) {
            assert_relative_eq!(v1, v2, epsilon = 1e-6);
        }
        Ok(())
    }

    #[test]
    fn test_group_norm_basic() -> Result<()> {
        let batch_size = 2;
        let channels = 4;
        let height = 3;
        let width = 3;
        let num_groups = 2;
        let channels_per_group = channels / num_groups;
        let input_data: Vec<f32> = (0..batch_size * channels * height * width)
            .map(|i| (i as f32) * 0.1)
            .collect();
        let input = Tensor::from_vec(input_data, &[batch_size, channels, height, width])?;
        let output = group_norm(&input, num_groups, None, None, 1e-5)?;
        assert_eq!(
            output.shape().dims(),
            &[batch_size, channels, height, width]
        );
        let output_data = output.to_vec()?;
        // The first (sample, group) block is contiguous in row-major layout.
        let group_size = channels_per_group * height * width;
        let first_group: Vec<f32> = output_data[0..group_size].to_vec();
        let mean: f32 = first_group.iter().sum::<f32>() / first_group.len() as f32;
        let variance: f32 =
            first_group.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / first_group.len() as f32;
        assert_relative_eq!(mean, 0.0, epsilon = 1e-4);
        assert_relative_eq!(variance, 1.0, epsilon = 1e-3);
        Ok(())
    }

    #[test]
    fn test_group_norm_with_affine() -> Result<()> {
        let batch_size = 1;
        let channels = 6;
        let spatial_dim = 4;
        let num_groups = 3;
        let input_data: Vec<f32> = (0..batch_size * channels * spatial_dim)
            .map(|i| (i as f32) * 0.2 + 1.0)
            .collect();
        let input = Tensor::from_vec(input_data, &[batch_size, channels, spatial_dim])?;
        let weight = Tensor::from_vec(vec![2.0, 1.5, 1.0, 0.5, 2.5, 3.0], &[channels])?;
        let bias = Tensor::from_vec(vec![0.1, -0.1, 0.2, -0.2, 0.3, -0.3], &[channels])?;
        let output = group_norm(&input, num_groups, Some(&weight), Some(&bias), 1e-5)?;
        assert_eq!(output.shape().dims(), &[batch_size, channels, spatial_dim]);
        let output_data = output.to_vec()?;
        assert!(output_data.len() == batch_size * channels * spatial_dim);
        Ok(())
    }

    #[test]
    fn test_group_norm_invalid_groups() {
        // 5 channels cannot be split into 3 equal groups.
        let input = Tensor::from_vec(vec![1.0; 2 * 5 * 3 * 3], &[2, 5, 3, 3])
            .expect("Tensor should succeed");
        let result = group_norm(&input, 3, None, None, 1e-5);
        assert!(result.is_err());
    }

    #[test]
    fn test_instance_norm_basic() -> Result<()> {
        let batch_size = 2;
        let channels = 3;
        let spatial_size = 4;
        let input_data: Vec<f32> = (0..batch_size * channels * spatial_size)
            .map(|i| (i as f32) * 0.5 + 2.0)
            .collect();
        let input = Tensor::from_vec(input_data, &[batch_size, channels, spatial_size])?;
        let output = instance_norm(&input, None, None, 1e-5)?;
        assert_eq!(output.shape().dims(), &[batch_size, channels, spatial_size]);
        let output_data = output.to_vec()?;
        // First instance = (sample 0, channel 0), the first `spatial_size` values.
        let first_instance: Vec<f32> = output_data[0..spatial_size].to_vec();
        let mean: f32 = first_instance.iter().sum::<f32>() / first_instance.len() as f32;
        let variance: f32 = first_instance
            .iter()
            .map(|&x| (x - mean).powi(2))
            .sum::<f32>()
            / first_instance.len() as f32;
        assert_relative_eq!(mean, 0.0, epsilon = 1e-4);
        assert_relative_eq!(variance, 1.0, epsilon = 1e-3);
        Ok(())
    }

    #[test]
    fn test_instance_norm_with_affine() -> Result<()> {
        let batch_size = 2;
        let channels = 4;
        let height = 3;
        let width = 3;
        let input_data: Vec<f32> = (0..batch_size * channels * height * width)
            .map(|i| (i as f32) * 0.1)
            .collect();
        let input = Tensor::from_vec(input_data, &[batch_size, channels, height, width])?;
        let weight = Tensor::from_vec(vec![1.5, 2.0, 0.5, 3.0], &[channels])?;
        let bias = Tensor::from_vec(vec![0.2, -0.2, 0.1, -0.1], &[channels])?;
        let output = instance_norm(&input, Some(&weight), Some(&bias), 1e-5)?;
        assert_eq!(
            output.shape().dims(),
            &[batch_size, channels, height, width]
        );
        let output_data = output.to_vec()?;
        assert!(output_data.len() == batch_size * channels * height * width);
        Ok(())
    }

    #[test]
    fn test_instance_norm_2d_image() -> Result<()> {
        let batch_size = 1;
        let channels = 3;
        let height = 4;
        let width = 4;
        let input_data: Vec<f32> = (0..batch_size * channels * height * width)
            .map(|i| (i as f32) / 10.0)
            .collect();
        let input = Tensor::from_vec(input_data, &[batch_size, channels, height, width])?;
        let output = instance_norm(&input, None, None, 1e-5)?;
        assert_eq!(
            output.shape().dims(),
            &[batch_size, channels, height, width]
        );
        Ok(())
    }

    #[test]
    fn test_instance_norm_invalid_dims() {
        // Instance norm needs at least one spatial dimension.
        let input =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Tensor should succeed");
        let result = instance_norm(&input, None, None, 1e-5);
        assert!(result.is_err());
    }

    #[test]
    fn test_weight_norm_basic() -> Result<()> {
        // ||(3, 4)|| = 5, so g = 10 scales the unit vector (0.6, 0.8) to (6, 8).
        let weight = Tensor::from_vec(vec![3.0, 4.0], &[2])?;
        let g = Tensor::from_vec(vec![10.0], &[1])?;
        let output = weight_norm(&weight, &g, 0)?;
        let output_data = output.to_vec()?;
        assert_relative_eq!(output_data[0], 6.0, epsilon = 1e-5);
        assert_relative_eq!(output_data[1], 8.0, epsilon = 1e-5);
        Ok(())
    }

    #[test]
    fn test_weight_norm_matrix() -> Result<()> {
        let weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
        let g = Tensor::from_vec(vec![5.0], &[1])?;
        let output = weight_norm(&weight, &g, 0)?;
        assert_eq!(output.shape().dims(), &[2, 2]);
        let output_data = output.to_vec()?;
        assert!(output_data.iter().all(|&x| x != 0.0));
        Ok(())
    }

    #[test]
    fn test_weight_norm_preserve_direction() -> Result<()> {
        // ||(6, 8)|| = 10, so g = 20 gives (12, 16); direction is unchanged.
        let weight = Tensor::from_vec(vec![6.0, 8.0], &[2])?;
        let g = Tensor::from_vec(vec![20.0], &[1])?;
        let output = weight_norm(&weight, &g, 0)?;
        let output_data = output.to_vec()?;
        assert_relative_eq!(output_data[0], 12.0, epsilon = 1e-4);
        assert_relative_eq!(output_data[1], 16.0, epsilon = 1e-4);
        let input_ratio = 6.0 / 8.0;
        let output_ratio = output_data[0] / output_data[1];
        assert_relative_eq!(input_ratio, output_ratio, epsilon = 1e-5);
        Ok(())
    }
}