use crate::random_ops::randn;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Divides `weight` by an estimate of its largest singular value.
///
/// The weight is treated as a matrix whose rows are indexed by the first
/// dimension and whose columns are the product of all remaining dimensions.
/// The singular value (`sigma`) is estimated with power iteration starting
/// from random vectors, and the original tensor is divided by `sigma + eps`.
///
/// Weights with fewer than two dimensions are handled by counting the missing
/// dimensions as 1: a 1-D weight of length `k` is viewed as a `k x 1` matrix
/// and a 0-D weight as a `1 x 1` matrix, instead of panicking on an
/// out-of-bounds shape index.
///
/// # Arguments
/// * `weight` - tensor to normalize; it is not modified.
/// * `n_power_iterations` - number of power-iteration steps. With `0` the
///   estimate is computed directly from the random, unnormalized starting
///   vectors, so callers normally want at least one step.
/// * `eps` - added (as `f32`) to every norm and to `sigma` before dividing.
///
/// # Errors
/// Propagates any error returned by the underlying tensor operations
/// (`view`, `t`, `matmul`, `norm`, `data`, `div_scalar`) or by `randn`.
pub fn spectral_norm(weight: &Tensor, n_power_iterations: usize, eps: f64) -> TorshResult<Tensor> {
    let weight_shape_binding = weight.shape();
    let weight_shape = weight_shape_binding.dims();

    // Matrix dimensions: rows = leading dimension, columns = everything else.
    // Missing dimensions count as 1, so low-rank inputs never index past the
    // end of the shape.
    let m = weight_shape.first().copied().unwrap_or(1);
    let n: usize = weight_shape.iter().skip(1).product();

    let weight_2d = if weight_shape.len() == 2 {
        weight.clone()
    } else {
        weight.view(&[m as i32, n as i32])?
    };

    let eps = eps as f32;

    // Random starting vectors for the left (`u`) and right (`v`) singular vectors.
    let mut u = randn(&[m, 1], None, None, None)?;
    let mut v = randn(&[n, 1], None, None, None)?;

    for _ in 0..n_power_iterations {
        // v <- W^T u / (||W^T u|| + eps)
        let wt_u = weight_2d.t()?.matmul(&u)?;
        let wt_u_norm = wt_u.norm()?.data()?[0] + eps;
        v = wt_u.div_scalar(wt_u_norm)?;

        // u <- W v / (||W v|| + eps)
        let w_v = weight_2d.matmul(&v)?;
        let w_v_norm = w_v.norm()?.data()?[0] + eps;
        u = w_v.div_scalar(w_v_norm)?;
    }

    // sigma = u^T W v; eps keeps the division below away from an exact zero.
    let sigma = u.t()?.matmul(&weight_2d)?.matmul(&v)?;
    let sigma_value = sigma.data()?[0] + eps;

    // Divide the original (un-flattened) tensor so the caller gets its own layout back.
    let normalized_weight = weight.div_scalar(sigma_value)?;
    if weight_shape.len() > 2 {
        normalized_weight.view(&weight_shape.iter().map(|&x| x as i32).collect::<Vec<_>>())
    } else {
        Ok(normalized_weight)
    }
}
/// Standardizes `weight` separately for every index of its first dimension.
///
/// The flat data is split into `dims[0]` consecutive groups of
/// `product(dims[1..])` values. Within each group every value becomes
/// `(value - mean) / sqrt(variance + eps)`, where the variance is the
/// population variance (divided by the group size).
///
/// # Arguments
/// * `weight` - tensor to standardize; it is not modified.
/// * `eps` - added (as `f32`) to the variance before taking the square root.
///
/// # Errors
/// Returns an invalid-argument error when the tensor has no dimensions, and
/// propagates errors from reading the data or building the result tensor.
pub fn weight_standardization(weight: &Tensor, eps: f64) -> TorshResult<Tensor> {
    let shape = weight.shape();
    let dims = shape.dims();

    // The leading dimension is the number of groups; without it there is nothing to do.
    let out_channels = match dims.first() {
        Some(&count) => count,
        None => {
            return Err(TorshError::invalid_argument_with_context(
                "Weight tensor cannot be empty",
                "weight_standardization",
            ))
        }
    };

    let per_channel: usize = dims.iter().skip(1).product();
    let values = weight.data()?;

    let eps = eps as f32;
    let group_len = per_channel as f32;
    let mut standardized: Vec<f32> = Vec::with_capacity(values.len());

    for ch in 0..out_channels {
        let offset = ch * per_channel;
        let channel = &values[offset..offset + per_channel];

        let mean = channel.iter().sum::<f32>() / group_len;
        let variance = channel.iter().map(|&w| (w - mean).powi(2)).sum::<f32>() / group_len;
        let std_dev = (variance + eps).sqrt();

        standardized.extend(channel.iter().map(|&w| (w - mean) / std_dev));
    }

    Tensor::from_data(standardized, dims.to_vec(), weight.device())
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    /// Spectral normalization of a 2-D weight must return a tensor of the same shape.
    #[test]
    fn test_spectral_norm_basic() -> TorshResult<()> {
        let input = randn(&[4, 6])?;
        let output = spectral_norm(&input, 3, 1e-12)?;

        let expected_shape = input.shape();
        let actual_shape = output.shape();
        assert_eq!(expected_shape.dims(), actual_shape.dims());
        Ok(())
    }

    /// Weight standardization of a 3-D weight must return a tensor of the same shape.
    #[test]
    fn test_weight_standardization_basic() -> TorshResult<()> {
        let input = randn(&[3, 2, 2])?;
        let output = weight_standardization(&input, 1e-5)?;

        let expected_shape = input.shape();
        let actual_shape = output.shape();
        assert_eq!(expected_shape.dims(), actual_shape.dims());
        Ok(())
    }
}