use crate::{CooTensor, CsrTensor, CscTensor, SparseTensor, SparseFormat, TorshResult};
use scirs2_core::random::{Random, rng};
use std::collections::HashMap;
use torsh_core::{Shape, TorshError};
use torsh_tensor::{
creation::{randn, zeros},
Tensor,
};
/// Splits COO triplets `(row, col, value)` into three parallel vectors.
///
/// The returned vectors have the same length and ordering as `triplets`.
fn unzip_triplets(triplets: Vec<(usize, usize, f32)>) -> (Vec<usize>, Vec<usize>, Vec<f32>) {
    // The output lengths are known up front, so preallocate instead of
    // letting each vector grow-and-copy repeatedly during the loop.
    let n = triplets.len();
    let mut rows = Vec::with_capacity(n);
    let mut cols = Vec::with_capacity(n);
    let mut vals = Vec::with_capacity(n);
    for (r, c, v) in triplets {
        rows.push(r);
        cols.push(c);
        vals.push(v);
    }
    (rows, cols, vals)
}
/// Batch normalization for sparse tensors.
///
/// Tracks running per-feature mean/variance over the explicitly stored
/// values (the column index of each triplet is treated as the feature
/// index) and normalizes stored values using those running statistics.
#[derive(Debug, Clone)]
pub struct SparseBatchNorm {
    // Number of feature columns tracked by the running statistics.
    num_features: usize,
    // Added to the variance before the square root, for numerical stability.
    eps: f32,
    // Weight of the new batch statistics in the running-average update:
    // new = (1 - momentum) * old + momentum * batch.
    momentum: f32,
    // When true, a learnable per-feature scale (`weight`) and shift (`bias`) are applied.
    affine: bool,
    // Training-mode flag; running statistics are only updated while true.
    training: bool,
    // Running per-feature mean, shape `[num_features]` (starts at 0.0).
    running_mean: Tensor,
    // Running per-feature variance, shape `[num_features]` (starts at 1.0).
    running_var: Tensor,
    // Per-feature scale, present only when `affine` is true (initialized to 1.0).
    weight: Option<Tensor>,
    // Per-feature shift, present only when `affine` is true (initialized to 0.0).
    bias: Option<Tensor>,
}
impl SparseBatchNorm {
    /// Creates a new sparse batch-norm layer, starting in training mode.
    ///
    /// Running mean starts at 0.0 and running variance at 1.0, so the first
    /// normalization before any statistics update is close to a no-op.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `num_features` is zero, `eps` is
    /// negative, or `momentum` lies outside `[0.0, 1.0]`.
    pub fn new(num_features: usize, eps: f32, momentum: f32, affine: bool) -> TorshResult<Self> {
        if num_features == 0 {
            return Err(TorshError::InvalidArgument(
                "Number of features must be greater than 0".to_string(),
            ));
        }
        if eps < 0.0 {
            return Err(TorshError::InvalidArgument(
                "Epsilon must be non-negative".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&momentum) {
            return Err(TorshError::InvalidArgument(
                "Momentum must be between 0.0 and 1.0".to_string(),
            ));
        }
        let running_mean = zeros::<f32>(&[num_features])?;
        // Variance is initialized to 1.0 element-wise.
        let mut running_var = zeros::<f32>(&[num_features])?;
        for i in 0..num_features {
            running_var.set(&[i], 1.0)?;
        }
        // Affine parameters: identity scale (1.0) and zero shift.
        let (weight, bias) = if affine {
            let mut weight = zeros::<f32>(&[num_features])?;
            for i in 0..num_features {
                weight.set(&[i], 1.0)?;
            }
            let bias = zeros::<f32>(&[num_features])?;
            (Some(weight), Some(bias))
        } else {
            (None, None)
        };
        Ok(Self {
            num_features,
            eps,
            momentum,
            affine,
            training: true,
            running_mean,
            running_var,
            weight,
            bias,
        })
    }

    /// Switches between training (`true`) and evaluation (`false`) mode.
    pub fn train(&mut self, mode: bool) {
        self.training = mode;
    }

    /// Returns `true` while the layer is in training mode.
    pub fn training(&self) -> bool {
        self.training
    }

    /// Normalizes the stored values of `input` feature-wise and returns the
    /// result in the same sparse format as the input.
    ///
    /// In training mode the running statistics are first updated from this
    /// batch; normalization then always uses the running statistics.
    pub fn forward(&mut self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let coo = input.to_coo()?;
        let triplets = coo.triplets();
        let shape = input.shape().clone();
        if self.training {
            self.update_statistics(&triplets)?;
        }
        let normalized_triplets = self.normalize_triplets(&triplets)?;
        let (rows, cols, values) = unzip_triplets(normalized_triplets);
        let normalized_coo = CooTensor::new(rows, cols, values, shape)?;
        // Convert the result back to the caller's original storage format.
        match input.format() {
            SparseFormat::Coo => Ok(Box::new(normalized_coo)),
            SparseFormat::Csr => Ok(Box::new(CsrTensor::from_coo(&normalized_coo)?)),
            SparseFormat::Csc => Ok(Box::new(CscTensor::from_coo(&normalized_coo)?)),
        }
    }

    /// Folds this batch's per-feature mean/variance — computed over the
    /// explicitly stored values only — into the running statistics.
    fn update_statistics(&mut self, triplets: &[(usize, usize, f32)]) -> TorshResult<()> {
        let mut feature_sums = vec![0.0f32; self.num_features];
        let mut feature_counts = vec![0usize; self.num_features];
        let mut feature_sq_sums = vec![0.0f32; self.num_features];
        for &(_, col, val) in triplets {
            // Columns beyond num_features are silently ignored.
            if col < self.num_features {
                feature_sums[col] += val;
                feature_counts[col] += 1;
                feature_sq_sums[col] += val * val;
            }
        }
        for i in 0..self.num_features {
            // Features with no stored values keep their previous statistics.
            if feature_counts[i] > 0 {
                let count = feature_counts[i] as f32;
                let mean = feature_sums[i] / count;
                // Biased (population) variance: E[x^2] - E[x]^2.
                let var = (feature_sq_sums[i] / count) - mean * mean;
                let old_mean = self.running_mean.get(&[i])?;
                let new_mean = (1.0 - self.momentum) * old_mean + self.momentum * mean;
                self.running_mean.set(&[i], new_mean)?;
                let old_var = self.running_var.get(&[i])?;
                let new_var = (1.0 - self.momentum) * old_var + self.momentum * var;
                self.running_var.set(&[i], new_var)?;
            }
        }
        Ok(())
    }

    /// Normalizes each triplet's value by the running statistics of its
    /// feature column, applying the affine transform when enabled.
    /// Out-of-range columns are dropped, and values that normalize to
    /// (near) zero are dropped to preserve sparsity.
    fn normalize_triplets(
        &self,
        triplets: &[(usize, usize, f32)],
    ) -> TorshResult<Vec<(usize, usize, f32)>> {
        // Resolve the affine parameters once, outside the per-triplet loop,
        // instead of re-checking the Options for every stored value.
        let affine_params = if self.affine {
            let weight = self.weight.as_ref().ok_or_else(|| {
                TorshError::InvalidState(
                    "Weight not initialized for affine transformation".to_string(),
                )
            })?;
            let bias = self.bias.as_ref().ok_or_else(|| {
                TorshError::InvalidState(
                    "Bias not initialized for affine transformation".to_string(),
                )
            })?;
            Some((weight, bias))
        } else {
            None
        };
        let mut normalized = Vec::with_capacity(triplets.len());
        for &(row, col, val) in triplets {
            // Mirror update_statistics: ignore columns beyond num_features.
            if col < self.num_features {
                let mean = self.running_mean.get(&[col])?;
                let var = self.running_var.get(&[col])?;
                let normalized_val = (val - mean) / (var + self.eps).sqrt();
                let final_val = match affine_params {
                    Some((weight, bias)) => {
                        normalized_val * weight.get(&[col])? + bias.get(&[col])?
                    }
                    None => normalized_val,
                };
                // Keep the output sparse: drop values that normalized to ~0.
                if final_val.abs() > 1e-10 {
                    normalized.push((row, col, final_val));
                }
            }
        }
        Ok(normalized)
    }

    /// Number of feature columns.
    pub fn num_features(&self) -> usize {
        self.num_features
    }

    /// Numerical-stability epsilon.
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Running-average update weight for new batch statistics.
    pub fn momentum(&self) -> f32 {
        self.momentum
    }

    /// Whether the learnable affine transform is enabled.
    pub fn affine(&self) -> bool {
        self.affine
    }

    /// Running per-feature mean, shape `[num_features]`.
    pub fn running_mean(&self) -> &Tensor {
        &self.running_mean
    }

    /// Running per-feature variance, shape `[num_features]`.
    pub fn running_var(&self) -> &Tensor {
        &self.running_var
    }
}
/// Layer normalization for sparse tensors.
///
/// Each row of the input is one normalization group; mean and variance are
/// computed over that row's explicitly stored values only.
#[derive(Debug, Clone)]
pub struct SparseLayerNorm {
    // Shape being normalized over; the product of its elements sizes the
    // flattened affine parameter vectors.
    normalized_shape: Vec<usize>,
    // Added to the variance before the square root, for numerical stability.
    eps: f32,
    // When true, a learnable per-element scale (`weight`) and shift (`bias`) are applied.
    elementwise_affine: bool,
    // Flattened scale, present only when `elementwise_affine` is true (initialized to 1.0).
    weight: Option<Tensor>,
    // Flattened shift, present only when `elementwise_affine` is true (initialized to 0.0).
    bias: Option<Tensor>,
}
impl SparseLayerNorm {
    /// Creates a new sparse layer-norm layer.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `normalized_shape` is empty or `eps`
    /// is negative.
    pub fn new(
        normalized_shape: Vec<usize>,
        eps: f32,
        elementwise_affine: bool,
    ) -> TorshResult<Self> {
        if normalized_shape.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Normalized shape cannot be empty".to_string(),
            ));
        }
        if eps < 0.0 {
            return Err(TorshError::InvalidArgument(
                "Epsilon must be non-negative".to_string(),
            ));
        }
        let total_elements: usize = normalized_shape.iter().product();
        // Affine parameters are stored flattened: identity scale, zero shift.
        let (weight, bias) = if elementwise_affine {
            let mut weight = zeros::<f32>(&[total_elements])?;
            for i in 0..total_elements {
                weight.set(&[i], 1.0)?;
            }
            let bias = zeros::<f32>(&[total_elements])?;
            (Some(weight), Some(bias))
        } else {
            (None, None)
        };
        Ok(Self {
            normalized_shape,
            eps,
            elementwise_affine,
            weight,
            bias,
        })
    }

    /// Normalizes the stored values of `input` row-wise and returns the
    /// result in the same sparse format as the input.
    pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
        let coo = input.to_coo()?;
        let triplets = coo.triplets();
        let shape = input.shape().clone();
        let normalized_triplets = self.normalize_by_groups(&triplets, &shape)?;
        let (rows, cols, values) = unzip_triplets(normalized_triplets);
        let normalized_coo = CooTensor::new(rows, cols, values, shape)?;
        // Convert the result back to the caller's original storage format.
        match input.format() {
            SparseFormat::Coo => Ok(Box::new(normalized_coo)),
            SparseFormat::Csr => Ok(Box::new(CsrTensor::from_coo(&normalized_coo)?)),
            SparseFormat::Csc => Ok(Box::new(CscTensor::from_coo(&normalized_coo)?)),
        }
    }

    /// Groups triplets by row index and normalizes each group to zero mean
    /// and unit variance over its stored values, then applies the
    /// elementwise affine transform when enabled. Values that normalize to
    /// (near) zero are dropped to preserve sparsity.
    fn normalize_by_groups(
        &self,
        triplets: &[(usize, usize, f32)],
        _shape: &Shape,
    ) -> TorshResult<Vec<(usize, usize, f32)>> {
        let mut groups: HashMap<usize, Vec<(usize, usize, f32)>> = HashMap::new();
        for &triplet in triplets {
            // The group key is the row index: layer norm is applied per row.
            groups.entry(triplet.0).or_default().push(triplet);
        }
        // Resolve the affine parameters (and their lengths) once, instead of
        // unwrapping the Options twice for every stored value as before.
        let affine_params = if self.elementwise_affine {
            let weight = self.weight.as_ref().ok_or_else(|| {
                TorshError::InvalidState(
                    "Weight not initialized for elementwise affine transformation".to_string(),
                )
            })?;
            let bias = self.bias.as_ref().ok_or_else(|| {
                TorshError::InvalidState(
                    "Bias not initialized for elementwise affine transformation".to_string(),
                )
            })?;
            Some((weight, weight.shape().numel(), bias, bias.shape().numel()))
        } else {
            None
        };
        // Process rows in sorted order so the output triplet ordering is
        // deterministic (HashMap iteration order varies between runs).
        let mut row_keys: Vec<usize> = groups.keys().copied().collect();
        row_keys.sort_unstable();
        let mut normalized = Vec::with_capacity(triplets.len());
        for key in row_keys {
            let group_triplets = &groups[&key];
            if group_triplets.is_empty() {
                continue;
            }
            let n = group_triplets.len() as f32;
            let mean = group_triplets.iter().map(|&(_, _, v)| v).sum::<f32>() / n;
            // Biased (population) variance over the stored values.
            let variance = group_triplets
                .iter()
                .map(|&(_, _, v)| (v - mean).powi(2))
                .sum::<f32>()
                / n;
            let std_dev = (variance + self.eps).sqrt();
            for &(row, col, val) in group_triplets {
                let normalized_val = (val - mean) / std_dev;
                let final_val = match affine_params {
                    Some((weight, w_len, bias, b_len)) => {
                        // Parameters are indexed modulo their flattened length
                        // so any column maps onto the parameter vectors.
                        normalized_val * weight.get(&[col % w_len])? + bias.get(&[col % b_len])?
                    }
                    None => normalized_val,
                };
                // Keep the output sparse: drop values that normalized to ~0.
                if final_val.abs() > 1e-10 {
                    normalized.push((row, col, final_val));
                }
            }
        }
        Ok(normalized)
    }

    /// Shape the layer normalizes over.
    pub fn normalized_shape(&self) -> &[usize] {
        &self.normalized_shape
    }

    /// Numerical-stability epsilon.
    pub fn eps(&self) -> f32 {
        self.eps
    }

    /// Whether the learnable elementwise affine transform is enabled.
    pub fn elementwise_affine(&self) -> bool {
        self.elementwise_affine
    }
}
#[cfg(test)]
mod tests {
    // `super::*` already brings SparseFormat and everything else into scope;
    // the previous explicit `use crate::sparse_tensor::SparseFormat;` was unused.
    use super::*;

    #[test]
    fn test_sparse_batch_norm_creation() {
        let bn =
            SparseBatchNorm::new(64, 1e-5, 0.1, true).expect("Sparse Batch Norm should succeed");
        assert_eq!(bn.num_features(), 64);
        assert_eq!(bn.eps(), 1e-5);
        assert_eq!(bn.momentum(), 0.1);
        assert!(bn.affine());
        assert!(bn.training());
    }

    #[test]
    fn test_sparse_layer_norm_creation() {
        let ln =
            SparseLayerNorm::new(vec![128], 1e-5, true).expect("Sparse Layer Norm should succeed");
        assert_eq!(ln.normalized_shape(), &[128]);
        assert_eq!(ln.eps(), 1e-5);
        assert!(ln.elementwise_affine());
    }

    #[test]
    fn test_batch_norm_training_mode() {
        let mut bn =
            SparseBatchNorm::new(32, 1e-5, 0.1, false).expect("Sparse Batch Norm should succeed");
        assert!(bn.training());
        bn.train(false);
        assert!(!bn.training());
        bn.train(true);
        assert!(bn.training());
    }

    #[test]
    fn test_invalid_parameters() {
        // Zero features, negative epsilon, and out-of-range momentum must all fail.
        assert!(SparseBatchNorm::new(0, 1e-5, 0.1, true).is_err());
        assert!(SparseBatchNorm::new(64, -1e-5, 0.1, true).is_err());
        assert!(SparseBatchNorm::new(64, 1e-5, 1.5, true).is_err());
        // Empty shape and negative epsilon must fail for layer norm too.
        assert!(SparseLayerNorm::new(vec![], 1e-5, true).is_err());
        assert!(SparseLayerNorm::new(vec![128], -1e-5, true).is_err());
    }

    #[test]
    fn test_unzip_triplets() {
        let triplets = vec![(0, 1, 2.0), (1, 0, 3.0), (2, 2, 4.0)];
        let (rows, cols, values) = unzip_triplets(triplets);
        assert_eq!(rows, vec![0, 1, 2]);
        assert_eq!(cols, vec![1, 0, 2]);
        assert_eq!(values, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_batch_norm_statistics() {
        let mut bn =
            SparseBatchNorm::new(3, 1e-5, 0.1, false).expect("Sparse Batch Norm should succeed");
        let triplets = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 1, 3.0), (1, 1, 4.0)];
        bn.update_statistics(&triplets)
            .expect("statistics update should succeed");
        let mean0 = bn
            .running_mean()
            .get(&[0])
            .expect("element retrieval should succeed for valid index");
        let mean1 = bn
            .running_mean()
            .get(&[1])
            .expect("element retrieval should succeed for valid index");
        // Positive inputs must pull the running means above their initial 0.0.
        assert!(mean0 > 0.0);
        assert!(mean1 > 0.0);
    }

    #[test]
    fn test_layer_norm_groups() {
        let ln =
            SparseLayerNorm::new(vec![4], 1e-5, false).expect("Sparse Layer Norm should succeed");
        let triplets = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];
        let shape = Shape::new(vec![2, 4]);
        let normalized = ln
            .normalize_by_groups(&triplets, &shape)
            .expect("group normalization should succeed");
        assert_eq!(normalized.len(), 4);
    }

    #[test]
    fn test_sparsity_preservation() {
        let bn =
            SparseBatchNorm::new(4, 1e-5, 0.1, false).expect("Sparse Batch Norm should succeed");
        let triplets = vec![(0, 0, 1.0), (0, 2, 2.0)];
        let normalized = bn
            .normalize_triplets(&triplets)
            .expect("triplet normalization should succeed");
        // Normalization may drop near-zero values but never adds entries.
        assert!(normalized.len() <= triplets.len());
    }
}