use crate::{CooTensor, CscTensor, CsrTensor, SparseTensor, TorshResult};
use torsh_core::{Shape, TorshError};
use torsh_tensor::{creation::zeros, Tensor};
fn unzip_triplets(triplets: Vec<(usize, usize, f32)>) -> (Vec<usize>, Vec<usize>, Vec<f32>) {
triplets.into_iter().fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut rows, mut cols, mut vals), (r, c, v)| {
rows.push(r);
cols.push(c);
vals.push(v);
(rows, cols, vals)
},
)
}
pub struct SparseBatchNorm {
num_features: usize,
eps: f32,
momentum: f32,
affine: bool,
training: bool,
running_mean: Tensor,
running_var: Tensor,
weight: Option<Tensor>,
bias: Option<Tensor>,
}
impl SparseBatchNorm {
pub fn new(num_features: usize, eps: f32, momentum: f32, affine: bool) -> TorshResult<Self> {
let running_mean = zeros::<f32>(&[num_features])?;
let running_var = zeros::<f32>(&[num_features])?;
let (weight, bias) = if affine {
let weight = zeros::<f32>(&[num_features])?;
for i in 0..num_features {
weight.set(&[i], 1.0)?;
}
let bias = zeros::<f32>(&[num_features])?;
(Some(weight), Some(bias))
} else {
(None, None)
};
Ok(Self {
num_features,
eps,
momentum,
affine,
training: true,
running_mean,
running_var,
weight,
bias,
})
}
pub fn train(&mut self, mode: bool) {
self.training = mode;
}
pub fn forward(&mut self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
let coo = input.to_coo()?;
let triplets = coo.triplets();
let shape = input.shape().clone();
if self.training {
self.update_statistics(&triplets)?;
}
let normalized_triplets = self.normalize_triplets(&triplets)?;
let (rows, cols, values) = unzip_triplets(normalized_triplets);
let normalized_coo = CooTensor::new(rows, cols, values, shape)?;
match input.format() {
crate::SparseFormat::Coo => Ok(Box::new(normalized_coo)),
crate::SparseFormat::Csr => Ok(Box::new(CsrTensor::from_coo(&normalized_coo)?)),
crate::SparseFormat::Csc => Ok(Box::new(CscTensor::from_coo(&normalized_coo)?)),
_ => Ok(Box::new(normalized_coo)), }
}
fn update_statistics(&mut self, triplets: &[(usize, usize, f32)]) -> TorshResult<()> {
let mut feature_sums = vec![0.0f32; self.num_features];
let mut feature_counts = vec![0usize; self.num_features];
let mut feature_sq_sums = vec![0.0f32; self.num_features];
for &(_, col, val) in triplets {
if col < self.num_features {
feature_sums[col] += val;
feature_counts[col] += 1;
feature_sq_sums[col] += val * val;
}
}
for i in 0..self.num_features {
if feature_counts[i] > 0 {
let mean = feature_sums[i] / feature_counts[i] as f32;
let var = (feature_sq_sums[i] / feature_counts[i] as f32) - mean * mean;
let old_mean = self.running_mean.get(&[i])?;
let new_mean = (1.0 - self.momentum) * old_mean + self.momentum * mean;
self.running_mean.set(&[i], new_mean)?;
let old_var = self.running_var.get(&[i])?;
let new_var = (1.0 - self.momentum) * old_var + self.momentum * var;
self.running_var.set(&[i], new_var)?;
}
}
Ok(())
}
fn normalize_triplets(
&self,
triplets: &[(usize, usize, f32)],
) -> TorshResult<Vec<(usize, usize, f32)>> {
let mut normalized = Vec::with_capacity(triplets.len());
for &(row, col, val) in triplets {
if col < self.num_features {
let mean = self.running_mean.get(&[col])?;
let var = self.running_var.get(&[col])?;
let normalized_val = (val - mean) / (var + self.eps).sqrt();
let final_val = if self.affine {
let weight = self
.weight
.as_ref()
.ok_or_else(|| {
TorshError::InvalidState(
"Weight not initialized for affine transformation".to_string(),
)
})?
.get(&[col])?;
let bias = self
.bias
.as_ref()
.ok_or_else(|| {
TorshError::InvalidState(
"Bias not initialized for affine transformation".to_string(),
)
})?
.get(&[col])?;
normalized_val * weight + bias
} else {
normalized_val
};
normalized.push((row, col, final_val));
}
}
Ok(normalized)
}
}
pub struct SparseLayerNorm {
#[allow(dead_code)]
normalized_shape: Vec<usize>,
eps: f32,
elementwise_affine: bool,
weight: Option<Tensor>,
bias: Option<Tensor>,
}
impl SparseLayerNorm {
pub fn new(
normalized_shape: Vec<usize>,
eps: f32,
elementwise_affine: bool,
) -> TorshResult<Self> {
let total_elements: usize = normalized_shape.iter().product();
let (weight, bias) = if elementwise_affine {
let weight = zeros::<f32>(&[total_elements])?;
for i in 0..total_elements {
weight.set(&[i], 1.0)?;
}
let bias = zeros::<f32>(&[total_elements])?;
(Some(weight), Some(bias))
} else {
(None, None)
};
Ok(Self {
normalized_shape,
eps,
elementwise_affine,
weight,
bias,
})
}
pub fn forward(&self, input: &dyn SparseTensor) -> TorshResult<Box<dyn SparseTensor>> {
let coo = input.to_coo()?;
let triplets = coo.triplets();
let shape = input.shape().clone();
let normalized_triplets = self.normalize_by_groups(&triplets, &shape)?;
let (rows, cols, values) = unzip_triplets(normalized_triplets);
let normalized_coo = CooTensor::new(rows, cols, values, shape)?;
match input.format() {
crate::SparseFormat::Coo => Ok(Box::new(normalized_coo)),
crate::SparseFormat::Csr => Ok(Box::new(CsrTensor::from_coo(&normalized_coo)?)),
crate::SparseFormat::Csc => Ok(Box::new(CscTensor::from_coo(&normalized_coo)?)),
_ => Ok(Box::new(normalized_coo)), }
}
fn normalize_by_groups(
&self,
triplets: &[(usize, usize, f32)],
_shape: &Shape,
) -> TorshResult<Vec<(usize, usize, f32)>> {
let mut groups: std::collections::HashMap<usize, Vec<(usize, usize, f32)>> =
std::collections::HashMap::new();
for &triplet in triplets {
let group_key = triplet.0; groups.entry(group_key).or_default().push(triplet);
}
let mut normalized = Vec::new();
for (_, group_triplets) in groups {
let values: Vec<f32> = group_triplets.iter().map(|&(_, _, v)| v).collect();
let mean = values.iter().sum::<f32>() / values.len() as f32;
let variance =
values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / values.len() as f32;
let std_dev = (variance + self.eps).sqrt();
for (row, col, val) in group_triplets {
let normalized_val = (val - mean) / std_dev;
let final_val = if self.elementwise_affine {
let weight = self
.weight
.as_ref()
.ok_or_else(|| {
TorshError::InvalidState(
"Weight not initialized for elementwise affine transformation"
.to_string(),
)
})?
.get(&[col])?;
let bias = self
.bias
.as_ref()
.ok_or_else(|| {
TorshError::InvalidState(
"Bias not initialized for elementwise affine transformation"
.to_string(),
)
})?
.get(&[col])?;
normalized_val * weight + bias
} else {
normalized_val
};
normalized.push((row, col, final_val));
}
}
Ok(normalized)
}
}