//! Sparse tensor formats: N-dimensional COO (`SparseTensor`), 2-D COO
//! (`CooTensor`), CSR (`CsrTensor`), and 2:4 semi-structured sparsity
//! (`SemiStructuredSparseTensor`).
use std::collections::HashMap;
use std::fmt;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
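/// N-dimensional sparse tensor in COO (coordinate) format: `indices[i]` holds
/// the full coordinate of `values[i]`. Duplicate coordinates are permitted;
/// [`Self::coalesce`] merges them and [`Self::to_dense`] sums them.
///
/// A sketch of typical usage (doctest ignored; the exact import path depends
/// on how the crate re-exports this type):
///
/// ```ignore
/// let sp = SparseTensor::new(
///     vec![vec![0, 1], vec![1, 2]], // coordinates
///     vec![1.0f32, 2.0],            // values
///     vec![3, 3],                   // dense shape
/// )?;
/// assert_eq!(sp.nnz(), 2);
/// let dense = sp.to_dense()?;
/// ```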
pub struct SparseTensor<T: Float> {
indices: Vec<Vec<usize>>,
values: Vec<T>,
shape: Vec<usize>,
nnz: usize,
}
impl<T: Float> SparseTensor<T> {
pub fn new(
indices: Vec<Vec<usize>>,
values: Vec<T>,
shape: Vec<usize>,
) -> FerrotorchResult<Self> {
if indices.len() != values.len() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"indices length ({}) must equal values length ({})",
indices.len(),
values.len()
),
});
}
let ndim = shape.len();
for (i, idx) in indices.iter().enumerate() {
if idx.len() != ndim {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"index {} has {} dimensions but shape has {}",
i,
idx.len(),
ndim
),
});
}
for (axis, &coord) in idx.iter().enumerate() {
if coord >= shape[axis] {
return Err(FerrotorchError::IndexOutOfBounds {
index: coord,
axis,
size: shape[axis],
});
}
}
}
let nnz = values.len();
Ok(Self {
indices,
values,
shape,
nnz,
})
}
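/// Builds a COO tensor from `tensor`, keeping entries whose absolute value is
/// strictly greater than `threshold` (so `threshold = 0.0` drops exact zeros).
/// Flat indices are unraveled assuming row-major (C-contiguous) layout.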
pub fn from_dense(tensor: &Tensor<T>, threshold: T) -> FerrotorchResult<Self> {
let data = tensor.data()?;
let shape = tensor.shape().to_vec();
let ndim = shape.len();
let mut indices = Vec::new();
let mut values = Vec::new();
for (flat_idx, &val) in data.iter().enumerate() {
if val.abs() > threshold {
let mut coord = vec![0usize; ndim];
let mut remaining = flat_idx;
for d in (0..ndim).rev() {
if shape[d] > 0 {
coord[d] = remaining % shape[d];
remaining /= shape[d];
}
}
indices.push(coord);
values.push(val);
}
}
let nnz = values.len();
Ok(Self {
indices,
values,
shape,
nnz,
})
}
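/// Materializes a dense tensor in row-major layout. Values at duplicate
/// coordinates are summed, so uncoalesced tensors densify correctly.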
pub fn to_dense(&self) -> FerrotorchResult<Tensor<T>> {
let numel: usize = self.shape.iter().product();
let mut data = vec![<T as num_traits::Zero>::zero(); numel];
let ndim = self.shape.len();
for (idx, &val) in self.indices.iter().zip(self.values.iter()) {
let mut flat = 0usize;
let mut stride = 1usize;
for d in (0..ndim).rev() {
flat += idx[d] * stride;
stride *= self.shape[d];
}
data[flat] += val;
}
Tensor::from_storage(TensorStorage::cpu(data), self.shape.clone(), false)
}
#[inline]
pub fn nnz(&self) -> usize {
self.nnz
}
#[inline]
pub fn shape(&self) -> &[usize] {
&self.shape
}
#[inline]
pub fn ndim(&self) -> usize {
self.shape.len()
}
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
#[inline]
pub fn indices(&self) -> &[Vec<usize>] {
&self.indices
}
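/// Sparse-dense matrix multiply: `self` is `[m, k]` sparse, `dense` is
/// `[k, n]`, and the result is `[m, n]` dense. Each stored entry `(i, j, v)`
/// scatters `v * dense[j, :]` into output row `i`, for O(nnz * n) work.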
pub fn spmm(&self, dense: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
if self.ndim() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!("spmm requires 2-D sparse tensor, got {}-D", self.ndim()),
});
}
if dense.ndim() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!("spmm requires 2-D dense tensor, got {}-D", dense.ndim()),
});
}
let m = self.shape[0];
let k_sparse = self.shape[1];
let dense_shape = dense.shape();
let k_dense = dense_shape[0];
let n = dense_shape[1];
if k_sparse != k_dense {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"spmm inner dimensions mismatch: sparse [{}, {}] @ dense [{}, {}]",
m, k_sparse, k_dense, n
),
});
}
let dense_data = dense.data()?;
let mut output = vec![<T as num_traits::Zero>::zero(); m * n];
for (idx, &v) in self.indices.iter().zip(self.values.iter()) {
let i = idx[0];
let j = idx[1];
for col in 0..n {
output[i * n + col] += v * dense_data[j * n + col];
}
}
Tensor::from_storage(TensorStorage::cpu(output), vec![m, n], false)
}
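/// Scales every stored value by `scalar`; indices and shape are cloned
/// unchanged, so the sparsity structure is preserved.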
pub fn mul_scalar(&self, scalar: T) -> Self {
let new_values: Vec<T> = self.values.iter().map(|&v| v * scalar).collect();
Self {
indices: self.indices.clone(),
values: new_values,
shape: self.shape.clone(),
nnz: self.nnz,
}
}
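/// Adds two sparse tensors by concatenating their entry lists. The result may
/// contain duplicate coordinates; call [`Self::coalesce`] to merge them.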
pub fn add(&self, other: &SparseTensor<T>) -> FerrotorchResult<SparseTensor<T>> {
if self.shape != other.shape {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"cannot add sparse tensors with shapes {:?} and {:?}",
self.shape, other.shape
),
});
}
let mut indices = self.indices.clone();
indices.extend_from_slice(&other.indices);
let mut values = self.values.clone();
values.extend_from_slice(&other.values);
let nnz = values.len();
Ok(SparseTensor {
indices,
values,
shape: self.shape.clone(),
nnz,
})
}
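/// Merges duplicate coordinates by summation, drops entries whose sum is
/// exactly zero, and sorts the survivors lexicographically by coordinate so
/// the output order is deterministic.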
pub fn coalesce(&self) -> SparseTensor<T> {
let mut map: HashMap<Vec<usize>, T> = HashMap::new();
for (idx, &val) in self.indices.iter().zip(self.values.iter()) {
let entry = map
.entry(idx.clone())
.or_insert_with(<T as num_traits::Zero>::zero);
*entry += val;
}
let mut pairs: Vec<(Vec<usize>, T)> = map
.into_iter()
.filter(|(_, val)| !<T as num_traits::Zero>::is_zero(val))
.collect();
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
let mut indices = Vec::with_capacity(pairs.len());
let mut values = Vec::with_capacity(pairs.len());
for (idx, val) in pairs {
indices.push(idx);
values.push(val);
}
let nnz = values.len();
SparseTensor {
indices,
values,
shape: self.shape.clone(),
nnz,
}
}
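/// Transposes a 2-D sparse tensor in O(nnz) by swapping the two coordinates
/// of every entry and the two shape dimensions.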
pub fn t(&self) -> FerrotorchResult<SparseTensor<T>> {
if self.ndim() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"transpose requires a 2-D sparse tensor, got {}-D",
self.ndim()
),
});
}
let new_indices: Vec<Vec<usize>> = self
.indices
.iter()
.map(|idx| vec![idx[1], idx[0]])
.collect();
let new_shape = vec![self.shape[1], self.shape[0]];
Ok(SparseTensor {
indices: new_indices,
values: self.values.clone(),
shape: new_shape,
nnz: self.nnz,
})
}
}
impl<T: Float> Clone for SparseTensor<T> {
fn clone(&self) -> Self {
Self {
indices: self.indices.clone(),
values: self.values.clone(),
shape: self.shape.clone(),
nnz: self.nnz,
}
}
}
impl<T: Float> fmt::Debug for SparseTensor<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SparseTensor")
.field("shape", &self.shape)
.field("nnz", &self.nnz)
.field("ndim", &self.shape.len())
.finish()
}
}
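/// 2-D sparse matrix in COO format with parallel `row_indices`,
/// `col_indices`, and `values` arrays. `is_coalesced` records whether
/// duplicate `(row, col)` entries have been merged.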
#[derive(Debug, Clone)]
pub struct CooTensor<T: Float> {
row_indices: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
nrows: usize,
ncols: usize,
is_coalesced: bool,
}
impl<T: Float> CooTensor<T> {
pub fn new(
row_indices: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
nrows: usize,
ncols: usize,
) -> FerrotorchResult<Self> {
if row_indices.len() != values.len() || col_indices.len() != values.len() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"row_indices ({}), col_indices ({}), and values ({}) must have equal length",
row_indices.len(),
col_indices.len(),
values.len()
),
});
}
for (&r, &c) in row_indices.iter().zip(col_indices.iter()) {
if r >= nrows {
return Err(FerrotorchError::IndexOutOfBounds {
index: r,
axis: 0,
size: nrows,
});
}
if c >= ncols {
return Err(FerrotorchError::IndexOutOfBounds {
index: c,
axis: 1,
size: ncols,
});
}
}
Ok(Self {
row_indices,
col_indices,
values,
nrows,
ncols,
is_coalesced: false,
})
}
#[inline]
pub fn nnz(&self) -> usize {
self.values.len()
}
#[inline]
pub fn is_coalesced(&self) -> bool {
self.is_coalesced
}
#[inline]
pub fn row_indices(&self) -> &[usize] {
&self.row_indices
}
#[inline]
pub fn col_indices(&self) -> &[usize] {
&self.col_indices
}
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
#[inline]
pub fn nrows(&self) -> usize {
self.nrows
}
#[inline]
pub fn ncols(&self) -> usize {
self.ncols
}
pub fn coalesce(&self) -> Self {
let mut map: HashMap<(usize, usize), T> = HashMap::new();
for i in 0..self.values.len() {
let key = (self.row_indices[i], self.col_indices[i]);
let entry = map.entry(key).or_insert_with(<T as num_traits::Zero>::zero);
*entry += self.values[i];
}
let mut pairs: Vec<((usize, usize), T)> = map
.into_iter()
.filter(|(_, val)| !<T as num_traits::Zero>::is_zero(val))
.collect();
pairs.sort_by_key(|&((r, c), _)| (r, c));
let mut row_indices = Vec::with_capacity(pairs.len());
let mut col_indices = Vec::with_capacity(pairs.len());
let mut values = Vec::with_capacity(pairs.len());
for ((r, c), v) in pairs {
row_indices.push(r);
col_indices.push(c);
values.push(v);
}
Self {
row_indices,
col_indices,
values,
nrows: self.nrows,
ncols: self.ncols,
is_coalesced: true,
}
}
pub fn to_dense(&self) -> FerrotorchResult<Tensor<T>> {
let mut data = vec![<T as num_traits::Zero>::zero(); self.nrows * self.ncols];
for i in 0..self.values.len() {
let flat = self.row_indices[i] * self.ncols + self.col_indices[i];
data[flat] += self.values[i];
}
Tensor::from_storage(
TensorStorage::cpu(data),
vec![self.nrows, self.ncols],
false,
)
}
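/// Converts CSR to COO by expanding each row's
/// `row_ptrs[row]..row_ptrs[row + 1]` slice. `is_coalesced` is left `false`
/// because [`CsrTensor::new`] does not validate that column indices within a
/// row are sorted and unique.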
pub fn from_csr(csr: &CsrTensor<T>) -> Self {
let mut row_indices = Vec::new();
let mut col_indices = Vec::new();
let mut values = Vec::new();
for row in 0..csr.nrows {
let start = csr.row_ptrs[row];
let end = csr.row_ptrs[row + 1];
for j in start..end {
row_indices.push(row);
col_indices.push(csr.col_indices[j]);
values.push(csr.values[j]);
}
}
Self {
row_indices,
col_indices,
values,
nrows: csr.nrows,
ncols: csr.ncols,
is_coalesced: false,
}
}
}
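/// 2-D sparse matrix in CSR (compressed sparse row) format: `row_ptrs` has
/// `nrows + 1` entries, and row `r`'s column indices and values occupy
/// `row_ptrs[r]..row_ptrs[r + 1]` of `col_indices` and `values`.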
#[derive(Debug, Clone)]
pub struct CsrTensor<T: Float> {
row_ptrs: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
nrows: usize,
ncols: usize,
}
impl<T: Float> CsrTensor<T> {
pub fn new(
row_ptrs: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
nrows: usize,
ncols: usize,
) -> FerrotorchResult<Self> {
if row_ptrs.len() != nrows + 1 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"row_ptrs length ({}) must be nrows + 1 ({})",
row_ptrs.len(),
nrows + 1
),
});
}
if col_indices.len() != values.len() {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"col_indices length ({}) must equal values length ({})",
col_indices.len(),
values.len()
),
});
}
Ok(Self {
row_ptrs,
col_indices,
values,
nrows,
ncols,
})
}
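/// Converts COO to CSR, coalescing along the way: duplicate `(row, col)`
/// entries are summed, zero sums are dropped, and each row's columns are
/// sorted ascending.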
pub fn from_coo(coo: &CooTensor<T>) -> FerrotorchResult<Self> {
let nrows = coo.nrows;
let ncols = coo.ncols;
let mut row_entries: Vec<HashMap<usize, T>> = vec![HashMap::new(); nrows];
for i in 0..coo.values.len() {
let r = coo.row_indices[i];
let c = coo.col_indices[i];
let entry = row_entries[r]
.entry(c)
.or_insert_with(<T as num_traits::Zero>::zero);
*entry += coo.values[i];
}
let mut row_ptrs = Vec::with_capacity(nrows + 1);
let mut col_indices = Vec::new();
let mut values = Vec::new();
row_ptrs.push(0);
for entry in row_entries.iter_mut() {
let mut cols: Vec<(usize, T)> = entry
.drain()
.filter(|(_, v)| !<T as num_traits::Zero>::is_zero(v))
.collect();
cols.sort_by_key(|&(c, _)| c);
for (c, v) in cols {
col_indices.push(c);
values.push(v);
}
row_ptrs.push(col_indices.len());
}
Ok(Self {
row_ptrs,
col_indices,
values,
nrows,
ncols,
})
}
#[inline]
pub fn nnz(&self) -> usize {
self.values.len()
}
#[inline]
pub fn row_ptrs(&self) -> &[usize] {
&self.row_ptrs
}
#[inline]
pub fn col_indices(&self) -> &[usize] {
&self.col_indices
}
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
pub fn to_dense(&self) -> FerrotorchResult<Tensor<T>> {
let mut data = vec![<T as num_traits::Zero>::zero(); self.nrows * self.ncols];
for row in 0..self.nrows {
let start = self.row_ptrs[row];
let end = self.row_ptrs[row + 1];
for j in start..end {
let flat = row * self.ncols + self.col_indices[j];
data[flat] += self.values[j];
}
}
Tensor::from_storage(
TensorStorage::cpu(data),
vec![self.nrows, self.ncols],
false,
)
}
}
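/// 2:4 semi-structured sparsity: within every contiguous group of 4 elements,
/// exactly 2 are kept. `values` stores the kept elements (two per group, in
/// position order) and `mask` packs one 4-bit keep-mask per group, two groups
/// per byte (group `g` occupies bits `(g % 2) * 4..(g % 2) * 4 + 4` of byte
/// `g / 2`).
///
/// A round-trip sketch (doctest ignored; import paths depend on the crate's
/// re-exports):
///
/// ```ignore
/// let sp = SemiStructuredSparseTensor::compress(&dense)?; // numel % 4 == 0
/// let restored = sp.decompress()?; // pruned positions come back as zeros
/// ```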
#[derive(Debug, Clone)]
pub struct SemiStructuredSparseTensor<T: Float> {
values: Vec<T>,
mask: Vec<u8>,
shape: Vec<usize>,
}
impl<T: Float> SemiStructuredSparseTensor<T> {
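/// Compresses `dense` into 2:4 form by keeping the two largest-magnitude
/// elements of each group of 4; magnitude ties are broken toward the lower
/// position. Fails unless the element count is a multiple of 4.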
pub fn compress(dense: &Tensor<T>) -> FerrotorchResult<Self> {
let data = dense.data_vec()?;
let numel = data.len();
if !numel.is_multiple_of(4) {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"SemiStructuredSparseTensor::compress: numel must be a \
multiple of 4, got {numel}"
),
});
}
let num_groups = numel / 4;
let mut values = Vec::with_capacity(num_groups * 2);
let mut mask = vec![0u8; num_groups.div_ceil(2)];
for g in 0..num_groups {
let base = g * 4;
let mut mags: [(usize, T); 4] = [
(0, data[base].abs()),
(1, data[base + 1].abs()),
(2, data[base + 2].abs()),
(3, data[base + 3].abs()),
];
mags.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.0.cmp(&b.0))
});
let mut kept = [mags[0].0, mags[1].0];
kept.sort();
values.push(data[base + kept[0]]);
values.push(data[base + kept[1]]);
let nibble: u8 = (1 << kept[0]) | (1 << kept[1]);
let byte = g / 2;
let shift = (g % 2) * 4;
mask[byte] |= nibble << shift;
}
Ok(Self {
values,
mask,
shape: dense.shape().to_vec(),
})
}
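/// Expands back to a dense tensor, writing kept values at the positions
/// recorded in the mask and zeros elsewhere. Relies on every group nibble
/// having exactly two set bits, which `compress` guarantees.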
pub fn decompress(&self) -> FerrotorchResult<Tensor<T>> {
let numel = self.shape.iter().product::<usize>();
let mut out = vec![<T as num_traits::Zero>::zero(); numel];
let num_groups = numel / 4;
for g in 0..num_groups {
let byte = g / 2;
let shift = (g % 2) * 4;
let nibble = (self.mask[byte] >> shift) & 0xF;
let mut val_idx = g * 2;
for pos in 0..4 {
if (nibble >> pos) & 1 != 0 {
out[g * 4 + pos] = self.values[val_idx];
val_idx += 1;
}
}
}
Tensor::from_storage(TensorStorage::cpu(out), self.shape.clone(), false)
}
#[inline]
pub fn shape(&self) -> &[usize] {
&self.shape
}
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
#[inline]
pub fn mask(&self) -> &[u8] {
&self.mask
}
#[inline]
pub fn num_groups(&self) -> usize {
self.shape.iter().product::<usize>() / 4
}
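/// Ratio of compressed bytes (kept values plus packed mask) to dense bytes.
/// For `f32`, each group of four 4-byte elements compresses to two values
/// plus half a byte of mask: (8 + 0.5) / 16 = 17/32 ≈ 0.53.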
pub fn compression_ratio(&self) -> f64 {
let dense_bytes = (self.shape.iter().product::<usize>()) * std::mem::size_of::<T>();
if dense_bytes == 0 {
return 1.0;
}
let compressed =
self.values.len() * std::mem::size_of::<T>() + self.mask.len();
compressed as f64 / dense_bytes as f64
}
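/// Returns the 4-bit keep-mask for group `g`; bit `p` set means element
/// `g * 4 + p` of the dense tensor was kept.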
pub fn group_mask(&self, g: usize) -> u8 {
let byte = g / 2;
let shift = (g % 2) * 4;
(self.mask[byte] >> shift) & 0xF
}
}
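/// Reference multiply `a @ b` for a 2:4-compressed right operand: `b` is
/// decompressed and multiplied with a naive triple loop. A production kernel
/// would consume the packed values and mask directly; this routine serves as
/// a correctness baseline.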
pub fn sparse_matmul_24<T: Float>(
a: &Tensor<T>,
b: &SemiStructuredSparseTensor<T>,
) -> FerrotorchResult<Tensor<T>> {
if a.shape().len() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"sparse_matmul_24: `a` must be 2-D, got shape {:?}",
a.shape()
),
});
}
if b.shape().len() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"sparse_matmul_24: `b` must be 2-D, got shape {:?}",
b.shape()
),
});
}
let m = a.shape()[0];
let k = a.shape()[1];
let kb = b.shape()[0];
let n = b.shape()[1];
if k != kb {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"sparse_matmul_24: inner dims mismatch: a.shape[1]={k} != b.shape[0]={kb}"
),
});
}
let b_dense = b.decompress()?;
let a_data = a.data_vec()?;
let b_data = b_dense.data_vec()?;
let mut out = vec![<T as num_traits::Zero>::zero(); m * n];
for i in 0..m {
for j in 0..n {
let mut acc = <T as num_traits::Zero>::zero();
for kk in 0..k {
acc += a_data[i * k + kk] * b_data[kk * n + j];
}
out[i * n + j] = acc;
}
}
Tensor::from_storage(TensorStorage::cpu(out), vec![m, n], false)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_construction_and_accessors() {
let indices = vec![vec![0, 1], vec![1, 2], vec![2, 0]];
let values = vec![1.0f32, 2.0, 3.0];
let shape = vec![3, 3];
let sp = SparseTensor::new(indices.clone(), values.clone(), shape.clone()).unwrap();
assert_eq!(sp.nnz(), 3);
assert_eq!(sp.shape(), &[3, 3]);
assert_eq!(sp.ndim(), 2);
assert_eq!(sp.values(), &[1.0, 2.0, 3.0]);
assert_eq!(sp.indices(), &indices);
}
#[test]
fn test_from_dense_with_threshold() {
let data = vec![0.0f32, 0.0, 5.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0];
let tensor = Tensor::from_storage(TensorStorage::cpu(data), vec![3, 3], false).unwrap();
let sp = SparseTensor::from_dense(&tensor, 0.0).unwrap();
assert_eq!(sp.nnz(), 2);
assert_eq!(sp.shape(), &[3, 3]);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d[0 * 3 + 2], 5.0);
assert_eq!(d[2 * 3 + 0], 3.0);
}
#[test]
fn test_from_dense_threshold_filters_small() {
let data = vec![0.5f32, 1.5, 0.1, 2.0];
let tensor = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 2], false).unwrap();
let sp = SparseTensor::from_dense(&tensor, 1.0).unwrap();
assert_eq!(sp.nnz(), 2);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d[0], 0.0);
assert_eq!(d[1], 1.5);
assert_eq!(d[2], 0.0);
assert_eq!(d[3], 2.0);
}
#[test]
fn test_to_dense_round_trip() {
let data = vec![1.0f64, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];
let original =
Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![3, 3], false).unwrap();
let sp = SparseTensor::from_dense(&original, 0.0).unwrap();
let reconstructed = sp.to_dense().unwrap();
let orig_data = original.data().unwrap();
let recon_data = reconstructed.data().unwrap();
for (a, b) in orig_data.iter().zip(recon_data.iter()) {
assert!((*a - *b).abs() < 1e-10, "mismatch: {} vs {}", a, b);
}
}
#[test]
fn test_spmm_matches_dense_mm() {
let sp = SparseTensor::new(
vec![vec![0, 0], vec![0, 2], vec![1, 1]],
vec![1.0f32, 2.0, 3.0],
vec![2, 3],
)
.unwrap();
let dense = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0]),
vec![3, 2],
false,
)
.unwrap();
let result = sp.spmm(&dense).unwrap();
let d = result.data().unwrap();
assert_eq!(result.shape(), &[2, 2]);
assert!((d[0] - 7.0).abs() < 1e-6);
assert!((d[1] - 16.0).abs() < 1e-6);
assert!((d[2] - 6.0).abs() < 1e-6);
assert!((d[3] - 15.0).abs() < 1e-6);
}
#[test]
fn test_spmm_identity() {
let sp = SparseTensor::new(
vec![vec![0, 0], vec![1, 1], vec![2, 2]],
vec![1.0f32; 3],
vec![3, 3],
)
.unwrap();
let dense = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]),
vec![3, 2],
false,
)
.unwrap();
let result = sp.spmm(&dense).unwrap();
let d = result.data().unwrap();
let expected = dense.data().unwrap();
assert_eq!(result.shape(), &[3, 2]);
for (a, b) in d.iter().zip(expected.iter()) {
assert!((a - b).abs() < 1e-6);
}
}
#[test]
fn test_coalesce_merges_duplicates() {
let sp = SparseTensor::new(
vec![vec![0, 0], vec![0, 1], vec![0, 1]],
vec![1.0f32, 3.0, 4.0],
vec![1, 3],
)
.unwrap();
let coalesced = sp.coalesce();
assert_eq!(coalesced.nnz(), 2);
let dense = coalesced.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-6);
assert!((d[1] - 7.0).abs() < 1e-6);
assert!((d[2] - 0.0).abs() < 1e-6);
}
#[test]
fn test_coalesce_removes_zero_sum() {
let sp = SparseTensor::new(vec![vec![0, 0], vec![0, 0]], vec![5.0f32, -5.0], vec![1, 1])
.unwrap();
let coalesced = sp.coalesce();
assert_eq!(coalesced.nnz(), 0);
}
#[test]
fn test_transpose() {
let sp =
SparseTensor::new(vec![vec![0, 1], vec![2, 0]], vec![5.0f32, 3.0], vec![3, 4]).unwrap();
let transposed = sp.t().unwrap();
assert_eq!(transposed.shape(), &[4, 3]);
assert_eq!(transposed.nnz(), 2);
assert_eq!(transposed.indices()[0], vec![1, 0]);
assert_eq!(transposed.indices()[1], vec![0, 2]);
assert_eq!(transposed.values(), &[5.0, 3.0]);
}
#[test]
fn test_transpose_not_2d() {
let sp = SparseTensor::new(vec![vec![0, 1, 2]], vec![1.0f32], vec![3, 3, 3]).unwrap();
assert!(sp.t().is_err());
}
#[test]
fn test_mul_scalar() {
let sp =
SparseTensor::new(vec![vec![0, 0], vec![1, 1]], vec![2.0f64, 3.0], vec![2, 2]).unwrap();
let scaled = sp.mul_scalar(10.0);
assert_eq!(scaled.values(), &[20.0, 30.0]);
assert_eq!(scaled.nnz(), 2);
assert_eq!(scaled.shape(), &[2, 2]);
assert_eq!(scaled.indices(), sp.indices());
}
#[test]
fn test_add_sparse_tensors() {
let a =
SparseTensor::new(vec![vec![0, 0], vec![0, 1]], vec![1.0f32, 2.0], vec![2, 2]).unwrap();
let b =
SparseTensor::new(vec![vec![0, 1], vec![1, 0]], vec![3.0, 4.0], vec![2, 2]).unwrap();
let sum = a.add(&b).unwrap();
assert_eq!(sum.nnz(), 4);
let coalesced = sum.coalesce();
assert_eq!(coalesced.nnz(), 3);
let dense = coalesced.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-6);
assert!((d[1] - 5.0).abs() < 1e-6);
assert!((d[2] - 4.0).abs() < 1e-6);
assert!((d[3] - 0.0).abs() < 1e-6);
}
#[test]
fn test_add_shape_mismatch() {
let a = SparseTensor::<f32>::new(vec![], vec![], vec![2, 3]).unwrap();
let b = SparseTensor::<f32>::new(vec![], vec![], vec![3, 2]).unwrap();
assert!(a.add(&b).is_err());
}
#[test]
fn test_index_out_of_bounds() {
let result = SparseTensor::new(
vec![vec![3, 0]],
vec![1.0f32],
vec![3, 3],
);
assert!(result.is_err());
let err = result.unwrap_err();
match err {
FerrotorchError::IndexOutOfBounds { index, axis, size } => {
assert_eq!(index, 3);
assert_eq!(axis, 0);
assert_eq!(size, 3);
}
other => panic!("expected IndexOutOfBounds, got: {other:?}"),
}
}
#[test]
fn test_index_out_of_bounds_second_axis() {
let result = SparseTensor::new(
vec![vec![0, 5]],
vec![1.0f64],
vec![3, 3],
);
assert!(result.is_err());
match result.unwrap_err() {
FerrotorchError::IndexOutOfBounds { index, axis, size } => {
assert_eq!(index, 5);
assert_eq!(axis, 1);
assert_eq!(size, 3);
}
other => panic!("expected IndexOutOfBounds, got: {other:?}"),
}
}
#[test]
fn test_empty_sparse_tensor() {
let sp = SparseTensor::<f32>::new(vec![], vec![], vec![5, 5]).unwrap();
assert_eq!(sp.nnz(), 0);
assert_eq!(sp.shape(), &[5, 5]);
let dense = sp.to_dense().unwrap();
assert!(dense.data().unwrap().iter().all(|&x| x == 0.0));
}
#[test]
fn test_indices_values_length_mismatch() {
let result = SparseTensor::new(
vec![vec![0, 0], vec![1, 1]],
vec![1.0f32],
vec![2, 2],
);
assert!(result.is_err());
}
#[test]
fn test_spmm_dimension_mismatch() {
let sp = SparseTensor::new(vec![vec![0, 0]], vec![1.0f32], vec![2, 3]).unwrap();
let dense =
Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 8]), vec![4, 2], false).unwrap();
assert!(sp.spmm(&dense).is_err());
}
#[test]
fn test_debug_format() {
let sp = SparseTensor::new(vec![vec![0, 0]], vec![1.0f32], vec![3, 3]).unwrap();
let debug = format!("{sp:?}");
assert!(debug.contains("SparseTensor"));
assert!(debug.contains("nnz: 1"));
}
#[test]
fn test_clone() {
let sp = SparseTensor::new(vec![vec![0, 1]], vec![42.0f32], vec![2, 2]).unwrap();
let sp2 = sp.clone();
assert_eq!(sp2.values(), &[42.0]);
assert_eq!(sp2.indices(), &[vec![0, 1]]);
assert_eq!(sp2.shape(), &[2, 2]);
}
#[test]
fn test_coo_coalesce_uses_tuple_key() {
let coo =
CooTensor::new(vec![0, 0, 1], vec![1, 1, 0], vec![3.0f32, 4.0, 5.0], 2, 2).unwrap();
let coalesced = coo.coalesce();
assert!(coalesced.is_coalesced());
assert_eq!(coalesced.nnz(), 2);
let dense = coalesced.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[1] - 7.0).abs() < 1e-6);
assert!((d[2] - 5.0).abs() < 1e-6);
}
#[test]
fn test_coo_from_csr_not_coalesced() {
let csr = CsrTensor::new(vec![0, 1, 2], vec![0, 1], vec![1.0f32, 2.0], 2, 2).unwrap();
let coo = CooTensor::from_csr(&csr);
assert!(!coo.is_coalesced());
assert_eq!(coo.nnz(), 2);
}
#[test]
fn test_csr_from_coo_with_duplicates() {
let coo =
CooTensor::new(vec![0, 0, 1], vec![0, 0, 1], vec![1.0f32, 2.0, 3.0], 2, 2).unwrap();
let csr = CsrTensor::from_coo(&coo).unwrap();
assert_eq!(csr.nnz(), 2);
let dense = csr.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[0] - 3.0).abs() < 1e-6);
assert!((d[3] - 3.0).abs() < 1e-6);
}
#[test]
fn test_coalesce_deterministic_order() {
let sp = SparseTensor::new(
vec![vec![1, 0], vec![0, 1], vec![0, 0]],
vec![3.0f32, 2.0, 1.0],
vec![2, 2],
)
.unwrap();
let coalesced = sp.coalesce();
assert_eq!(coalesced.indices()[0], vec![0, 0]);
assert_eq!(coalesced.indices()[1], vec![0, 1]);
assert_eq!(coalesced.indices()[2], vec![1, 0]);
}
#[test]
fn test_1d_sparse_tensor() {
let sp = SparseTensor::new(vec![vec![1], vec![4]], vec![10.0f32, 20.0], vec![5]).unwrap();
assert_eq!(sp.ndim(), 1);
assert_eq!(sp.nnz(), 2);
assert_eq!(sp.shape(), &[5]);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d.len(), 5);
assert_eq!(d[0], 0.0);
assert_eq!(d[1], 10.0);
assert_eq!(d[2], 0.0);
assert_eq!(d[3], 0.0);
assert_eq!(d[4], 20.0);
}
#[test]
fn test_3d_sparse_tensor() {
let sp = SparseTensor::new(
vec![vec![0, 1, 2], vec![1, 0, 0]],
vec![5.0f64, 7.0],
vec![2, 2, 3],
)
.unwrap();
assert_eq!(sp.ndim(), 3);
assert_eq!(sp.nnz(), 2);
assert_eq!(sp.shape(), &[2, 2, 3]);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d.len(), 12);
assert!((d[5] - 5.0).abs() < 1e-10);
assert!((d[6] - 7.0).abs() < 1e-10);
}
#[test]
fn test_zero_dimension_sparse_tensor() {
let sp = SparseTensor::<f32>::new(vec![], vec![], vec![0, 5]).unwrap();
assert_eq!(sp.ndim(), 2);
assert_eq!(sp.nnz(), 0);
assert_eq!(sp.shape(), &[0, 5]);
let dense = sp.to_dense().unwrap();
assert_eq!(dense.numel(), 0);
assert!(dense.data().unwrap().is_empty());
}
fn mk(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
}
#[test]
fn semi24_compress_keeps_two_largest_magnitudes_per_group() {
let t = mk(vec![1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0], vec![8]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[4.0, 3.0, -5.0, 2.0]);
assert_eq!(sp.mask(), &[0x3A]);
assert_eq!(sp.num_groups(), 2);
assert_eq!(sp.group_mask(0), 0xA);
assert_eq!(sp.group_mask(1), 0x3);
}
#[test]
fn semi24_decompress_roundtrips_compressed_values() {
let t = mk(vec![1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0], vec![8]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
let dense = sp.decompress().unwrap();
let data = dense.data().unwrap();
assert_eq!(data, &[0.0, 4.0, 0.0, 3.0, -5.0, 2.0, 0.0, 0.0]);
assert_eq!(dense.shape(), &[8]);
}
#[test]
fn semi24_compress_decompress_preserves_shape() {
let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
let t = mk(data, vec![2, 8]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.shape(), &[2, 8]);
let dense = sp.decompress().unwrap();
assert_eq!(dense.shape(), &[2, 8]);
}
#[test]
fn semi24_rejects_non_multiple_of_4() {
let t = mk(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5]);
let result = SemiStructuredSparseTensor::compress(&t);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("multiple of 4"));
}
#[test]
fn semi24_tie_breaking_prefers_lower_position() {
let t = mk(vec![1.0, 1.0, 1.0, 1.0], vec![4]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.group_mask(0), 0x3);
assert_eq!(sp.values(), &[1.0, 1.0]);
}
#[test]
fn semi24_compression_ratio_is_roughly_half() {
let n = 1024usize;
let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
let t = mk(data, vec![n]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
let ratio = sp.compression_ratio();
assert!(ratio > 0.5 && ratio < 0.6, "unexpected ratio: {ratio}");
}
#[test]
fn semi24_zero_tensor_has_deterministic_mask() {
let t = mk(vec![0.0; 16], vec![16]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[0.0; 8]);
for g in 0..4 {
assert_eq!(sp.group_mask(g), 0x3);
}
}
#[test]
fn semi24_negative_and_positive_by_magnitude() {
let t = mk(vec![-10.0, 1.0, -2.0, 3.0], vec![4]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[-10.0, 3.0]);
assert_eq!(sp.group_mask(0), 0x9);
}
#[test]
fn semi24_sparse_matmul_matches_dense_matmul() {
let a = mk(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let b_dense = mk(vec![1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0], vec![2, 4]);
let b_sparse = SemiStructuredSparseTensor::compress(&b_dense).unwrap();
let out = sparse_matmul_24(&a, &b_sparse).unwrap();
assert_eq!(out.shape(), &[2, 4]);
let b_masked = b_sparse.decompress().unwrap();
let b_m = b_masked.data().unwrap();
let d = out.data().unwrap();
assert_eq!(d[0], 1.0 * b_m[0] + 2.0 * b_m[4]);
assert_eq!(d[1], 1.0 * b_m[1] + 2.0 * b_m[5]);
assert_eq!(d[2], 1.0 * b_m[2] + 2.0 * b_m[6]);
assert_eq!(d[3], 1.0 * b_m[3] + 2.0 * b_m[7]);
assert_eq!(d[4], 3.0 * b_m[0] + 4.0 * b_m[4]);
assert_eq!(d[5], 3.0 * b_m[1] + 4.0 * b_m[5]);
assert_eq!(d[6], 3.0 * b_m[2] + 4.0 * b_m[6]);
assert_eq!(d[7], 3.0 * b_m[3] + 4.0 * b_m[7]);
}
#[test]
fn semi24_sparse_matmul_rejects_non_2d_a() {
let a = mk(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
let b_dense = mk(vec![1.0; 16], vec![4, 4]);
let b_sparse = SemiStructuredSparseTensor::compress(&b_dense).unwrap();
let result = sparse_matmul_24(&a, &b_sparse);
assert!(result.is_err());
}
#[test]
fn semi24_sparse_matmul_rejects_inner_dim_mismatch() {
let a = mk(vec![1.0, 2.0, 3.0], vec![1, 3]);
let b_dense = mk(vec![1.0; 16], vec![4, 4]);
let b_sparse = SemiStructuredSparseTensor::compress(&b_dense).unwrap();
let result = sparse_matmul_24(&a, &b_sparse);
assert!(result.is_err());
}
#[test]
fn semi24_compress_then_decompress_matches_apply_2_4_mask() {
let t = mk(
vec![0.1, 0.9, 0.3, 0.5, -0.8, 0.2, 0.7, -0.4, 1.5, -2.0, 0.1, 0.3],
vec![12],
);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
let sp_dense = sp.decompress().unwrap();
let mask_result = crate::pruning::apply_2_4_mask(&t).unwrap();
assert_eq!(
sp_dense.data().unwrap(),
mask_result.data().unwrap(),
"compress+decompress must match apply_2_4_mask output"
);
}
#[test]
fn semi24_f64_parity() {
let t = Tensor::<f64>::from_storage(
TensorStorage::cpu(vec![1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0]),
vec![8],
false,
)
.unwrap();
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[4.0, 3.0, -5.0, 2.0]);
let dense = sp.decompress().unwrap();
let data = dense.data().unwrap();
assert_eq!(data, &[0.0, 4.0, 0.0, 3.0, -5.0, 2.0, 0.0, 0.0]);
}
}