use std::collections::{HashMap, HashSet};
use std::fmt;
use crate::device::Device;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Reinterprets a `Vec<U>` as a `Vec<T>` without copying, for the case where
/// `T` and `U` are the same concrete type that the compiler cannot see through
/// a generic parameter.
///
/// # Panics
///
/// Panics if `T` and `U` are not the same type. The check is a hard `assert!`
/// (not `debug_assert!`) because this is a safe function: a mismatched call in
/// a release build must not be allowed to reach the `unsafe` block below.
#[inline]
fn vals_to_t<T, U>(vals: Vec<U>) -> Vec<T>
where
    T: 'static,
    U: 'static,
{
    assert_eq!(
        std::any::TypeId::of::<T>(),
        std::any::TypeId::of::<U>(),
        "vals_to_t: TypeId mismatch — caller must ensure T == U"
    );
    // Take over the allocation manually so the original `Vec<U>` does not free it.
    let mut source = std::mem::ManuallyDrop::new(vals);
    let len = source.len();
    let cap = source.capacity();
    let ptr = source.as_mut_ptr().cast::<T>();
    // SAFETY: the assertion above proves `T` and `U` are the same type, so the
    // pointer, length, capacity, size and alignment all describe a valid
    // `Vec<T>` allocation. `source` is wrapped in `ManuallyDrop`, so ownership
    // of the buffer is transferred exactly once.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Expands CSR row pointers and column indices into one `[row, col]`
/// coordinate per stored value, returning the coordinates together with the
/// (unchanged) values.
///
/// An empty `crow_indices` is treated as a matrix with zero rows.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when
/// * `col_indices` and `values` differ in length,
/// * a row's `[start, end)` range is decreasing or reaches past the stored
///   entries (previously an out-of-bounds panic), or
/// * the row ranges do not cover exactly `values.len()` entries (previously a
///   silently inconsistent result).
fn csr_to_coo_t<T: Float>(
    crow_indices: &[u32],
    col_indices: &[u32],
    values: Vec<T>,
) -> FerrotorchResult<(Vec<Vec<usize>>, Vec<T>)> {
    let nnz = values.len();
    if col_indices.len() != nnz {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "csr_to_coo: col_indices ({}) and values ({}) must have equal length",
                col_indices.len(),
                nnz
            ),
        });
    }
    let mut indices: Vec<Vec<usize>> = Vec::with_capacity(nnz);
    // `windows(2)` yields nothing for an empty or single-element pointer array.
    for (row, bounds) in crow_indices.windows(2).enumerate() {
        let start = bounds[0] as usize;
        let end = bounds[1] as usize;
        let row_cols = col_indices.get(start..end).ok_or_else(|| {
            FerrotorchError::InvalidArgument {
                message: format!(
                    "csr_to_coo: row {row} has invalid range {start}..{end} for {nnz} stored entries"
                ),
            }
        })?;
        indices.extend(row_cols.iter().map(|&col| vec![row, col as usize]));
    }
    if indices.len() != nnz {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "csr_to_coo: row pointers describe {} entries but {} values were supplied",
                indices.len(),
                nnz
            ),
        });
    }
    Ok((indices, values))
}
/// N-dimensional sparse tensor stored as a list of coordinates with one value
/// per coordinate.
pub struct SparseTensor<T: Float> {
/// One coordinate vector per stored entry.
indices: Vec<Vec<usize>>,
/// Stored values, parallel to `indices`.
values: Vec<T>,
/// Extent of each dimension of the logical dense tensor.
shape: Vec<usize>,
/// Cached number of stored entries.
nnz: usize,
}
impl<T: Float> SparseTensor<T> {
/// Builds a sparse tensor from explicit coordinates, values and shape.
///
/// # Errors
///
/// * `InvalidArgument` if `indices` and `values` differ in length, or if any
///   coordinate does not have exactly `shape.len()` components.
/// * `IndexOutOfBounds` if any coordinate component is `>=` the extent of its
///   axis.
pub fn new(
    indices: Vec<Vec<usize>>,
    values: Vec<T>,
    shape: Vec<usize>,
) -> FerrotorchResult<Self> {
    let nnz = values.len();
    if indices.len() != nnz {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "indices length ({}) must equal values length ({})",
                indices.len(),
                nnz
            ),
        });
    }
    let rank = shape.len();
    for (entry, coords) in indices.iter().enumerate() {
        if coords.len() != rank {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "index {} has {} dimensions but shape has {}",
                    entry,
                    coords.len(),
                    rank
                ),
            });
        }
        // Lengths match here, so the zip visits every axis; report the first
        // axis whose coordinate is out of range.
        let offending = coords
            .iter()
            .zip(shape.iter())
            .enumerate()
            .find(|(_, (coord, size))| coord >= size);
        if let Some((axis, (&coord, &size))) = offending {
            return Err(FerrotorchError::IndexOutOfBounds {
                index: coord,
                axis,
                size,
            });
        }
    }
    Ok(Self {
        indices,
        values,
        shape,
        nnz,
    })
}
/// Converts a dense tensor into a sparse one, keeping every element whose
/// absolute value is strictly greater than `threshold`.
///
/// When the tensor is on CUDA, `threshold` is zero, the tensor is 2-D and a
/// GPU backend is registered, `f32`/`f64` tensors are converted through the
/// backend's dense-to-CSR routine. Every other case uses the host loop below.
///
/// # Errors
///
/// Propagates errors from the tensor accessors, the backend, and
/// `csr_to_coo_t`.
pub fn from_dense(tensor: &Tensor<T>, threshold: T) -> FerrotorchResult<Self> {
    use std::any::TypeId;

    if tensor.is_cuda()
        && <T as num_traits::Zero>::is_zero(&threshold)
        && tensor.ndim() == 2
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
    {
        let contiguous = tensor.contiguous()?;
        let handle = contiguous.gpu_handle()?;
        let rows = contiguous.shape()[0];
        let cols = contiguous.shape()[1];
        let elem = TypeId::of::<T>();
        let converted: Option<(Vec<Vec<usize>>, Vec<T>)> = if elem == TypeId::of::<f32>() {
            let (crow, col, raw) = backend.dense_to_sparse_csr_f32(handle, rows, cols)?;
            Some(csr_to_coo_t::<T>(&crow, &col, vals_to_t::<T, f32>(raw))?)
        } else if elem == TypeId::of::<f64>() {
            let (crow, col, raw) = backend.dense_to_sparse_csr_f64(handle, rows, cols)?;
            Some(csr_to_coo_t::<T>(&crow, &col, vals_to_t::<T, f64>(raw))?)
        } else {
            // Unsupported element type: fall through to the host path.
            None
        };
        if let Some((indices, values)) = converted {
            return Ok(Self {
                nnz: values.len(),
                indices,
                values,
                shape: vec![rows, cols],
            });
        }
    }

    let data = tensor.data()?;
    let shape = tensor.shape().to_vec();
    let rank = shape.len();
    // Turns a row-major flat offset back into per-axis coordinates, peeling
    // off the innermost axis first. Zero-extent axes are left at 0.
    let unravel = |flat: usize| -> Vec<usize> {
        let mut coord = vec![0usize; rank];
        let mut rest = flat;
        for (slot, &extent) in coord.iter_mut().zip(shape.iter()).rev() {
            if extent > 0 {
                *slot = rest % extent;
                rest /= extent;
            }
        }
        coord
    };
    let (indices, values): (Vec<Vec<usize>>, Vec<T>) = data
        .iter()
        .enumerate()
        .filter(|&(_, &val)| val.abs() > threshold)
        .map(|(flat, &val)| (unravel(flat), val))
        .unzip();
    let nnz = values.len();
    Ok(Self {
        indices,
        values,
        shape,
        nnz,
    })
}
/// Materializes the tensor densely on the CPU; shorthand for
/// `to_dense_on(Device::Cpu)`.
pub fn to_dense(&self) -> FerrotorchResult<Tensor<T>> {
self.to_dense_on(Device::Cpu)
}
/// Materializes the tensor densely on `device`.
///
/// For a CUDA device, a 2-D `f32`/`f64` tensor is coalesced, converted to CSR
/// on the host and expanded by the GPU backend. If no backend is registered,
/// the tensor is not 2-D, the element type is unsupported, or an index does
/// not fit in the backend's `u32` index type, the dense tensor is built on the
/// CPU and then moved with `Tensor::to`. Non-CUDA devices use the CPU path
/// directly.
///
/// # Errors
///
/// Returns `IndexOutOfBounds` if a stored row index is `>=` the row count, and
/// propagates backend / tensor construction errors.
pub fn to_dense_on(&self, device: Device) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let Device::Cuda(device_ord) = device else {
        return self.to_dense_cpu();
    };
    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    if self.ndim() == 2
        && (is_f32 || is_f64)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
    {
        let m = self.shape[0];
        let n = self.shape[1];
        // Coalescing sorts entries lexicographically, i.e. by row then column,
        // which is the order CSR needs.
        let coalesced = self.coalesce();

        // Count entries per row in `usize`, then prefix-sum into row pointers.
        let mut row_ptrs = vec![0usize; m + 1];
        for idx in &coalesced.indices {
            let row = idx[0];
            if row >= m {
                return Err(FerrotorchError::IndexOutOfBounds {
                    index: row,
                    axis: 0,
                    size: m,
                });
            }
            row_ptrs[row + 1] += 1;
        }
        for r in 0..m {
            row_ptrs[r + 1] += row_ptrs[r];
        }

        // Checked narrowing: a plain `as u32` would silently wrap large
        // indices and hand the backend a corrupt matrix.
        let crow = row_ptrs
            .iter()
            .map(|&p| u32::try_from(p))
            .collect::<Result<Vec<u32>, _>>();
        let col = coalesced
            .indices
            .iter()
            .map(|idx| u32::try_from(idx[1]))
            .collect::<Result<Vec<u32>, _>>();

        if let (Ok(crow_indices), Ok(col_indices)) = (crow, col) {
            let values = &coalesced.values;
            let out_handle = if is_f32 {
                // SAFETY: `is_f32` proves `T` is exactly `f32`, so the buffer
                // is a valid `[f32]` of the same length; it is only borrowed
                // while `coalesced` is alive.
                let values_f32 = unsafe {
                    std::slice::from_raw_parts(values.as_ptr().cast::<f32>(), values.len())
                };
                backend.sparse_to_dense_csr_f32(
                    &crow_indices,
                    &col_indices,
                    values_f32,
                    device_ord,
                    m,
                    n,
                )?
            } else {
                // SAFETY: not `f32` and `is_f32 || is_f64` held, so `T` is
                // exactly `f64`; same reasoning as above.
                let values_f64 = unsafe {
                    std::slice::from_raw_parts(values.as_ptr().cast::<f64>(), values.len())
                };
                backend.sparse_to_dense_csr_f64(
                    &crow_indices,
                    &col_indices,
                    values_f64,
                    device_ord,
                    m,
                    n,
                )?
            };
            let storage = TensorStorage::gpu(out_handle);
            return Tensor::from_storage(storage, vec![m, n], false);
        }
        // Indices do not fit in u32: fall through to the host path.
    }
    self.to_dense_cpu()?.to(device)
}
/// Builds the dense representation in host memory.
///
/// Entries that share a coordinate are summed, so an uncoalesced tensor gives
/// the same dense result as its coalesced form.
fn to_dense_cpu(&self) -> FerrotorchResult<Tensor<T>> {
    let total: usize = self.shape.iter().product();
    let mut dense = vec![<T as num_traits::Zero>::zero(); total];
    for (coords, &value) in self.indices.iter().zip(self.values.iter()) {
        // Row-major flattening: walk from the innermost axis outward while
        // carrying the running stride.
        let (offset, _stride) = coords.iter().zip(self.shape.iter()).rev().fold(
            (0usize, 1usize),
            |(offset, stride), (&coord, &extent)| (offset + coord * stride, stride * extent),
        );
        dense[offset] += value;
    }
    Tensor::from_storage(TensorStorage::cpu(dense), self.shape.clone(), false)
}
/// Returns the cached number of stored entries.
#[inline]
pub fn nnz(&self) -> usize {
self.nnz
}
/// Returns the extent of each dimension.
#[inline]
pub fn shape(&self) -> &[usize] {
&self.shape
}
/// Returns the number of dimensions (`shape.len()`).
#[inline]
pub fn ndim(&self) -> usize {
self.shape.len()
}
/// Returns the stored values, parallel to [`indices`](Self::indices).
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
/// Returns the coordinate vector of every stored entry.
#[inline]
pub fn indices(&self) -> &[Vec<usize>] {
&self.indices
}
/// Sparse × dense matrix product: `self [m, k] @ dense [k, n] -> [m, n]`.
///
/// When `dense` is on CUDA, a backend is registered and `T` is `f32`/`f64`,
/// the sparse operand is coalesced, converted to CSR on the host and the
/// product is computed by the backend. Otherwise (including when an index
/// does not fit the backend's `u32` index type) the product is accumulated on
/// the host.
///
/// # Errors
///
/// * `InvalidArgument` if either operand is not 2-D.
/// * `ShapeMismatch` if the inner dimensions differ.
/// * `IndexOutOfBounds` if a stored row index is `>=` `m` on the GPU path.
/// * Any error propagated from the backend or tensor accessors.
pub fn spmm(&self, dense: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    if self.ndim() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("spmm requires 2-D sparse tensor, got {}-D", self.ndim()),
        });
    }
    if dense.ndim() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("spmm requires 2-D dense tensor, got {}-D", dense.ndim()),
        });
    }
    let m = self.shape[0];
    let k_sparse = self.shape[1];
    let dense_shape = dense.shape();
    let k_dense = dense_shape[0];
    let n = dense_shape[1];
    if k_sparse != k_dense {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "spmm inner dimensions mismatch: sparse [{m}, {k_sparse}] @ dense [{k_dense}, {n}]"
            ),
        });
    }

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    // The dtype is checked up front so unsupported element types skip the
    // coalesce / CSR build instead of discarding that work afterwards.
    if dense.is_cuda()
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
        && (is_f32 || is_f64)
    {
        let dense_contig = dense.contiguous()?;
        let dense_handle = dense_contig.gpu_handle()?;
        // Coalescing sorts by row then column, the order CSR requires.
        let coalesced = self.coalesce();

        let mut row_ptrs = vec![0usize; m + 1];
        for idx in &coalesced.indices {
            let row = idx[0];
            if row >= m {
                return Err(FerrotorchError::IndexOutOfBounds {
                    index: row,
                    axis: 0,
                    size: m,
                });
            }
            row_ptrs[row + 1] += 1;
        }
        for r in 0..m {
            row_ptrs[r + 1] += row_ptrs[r];
        }

        // Checked narrowing instead of `as u32`, which would silently wrap.
        let crow = row_ptrs
            .iter()
            .map(|&p| u32::try_from(p))
            .collect::<Result<Vec<u32>, _>>();
        let col = coalesced
            .indices
            .iter()
            .map(|idx| u32::try_from(idx[1]))
            .collect::<Result<Vec<u32>, _>>();

        if let (Ok(crow_indices), Ok(col_indices)) = (crow, col) {
            let values = &coalesced.values;
            let out_handle = if is_f32 {
                // SAFETY: `is_f32` proves `T` is exactly `f32`; the slice
                // borrows `coalesced.values`, which outlives the call.
                let values_f32 = unsafe {
                    std::slice::from_raw_parts(values.as_ptr().cast::<f32>(), values.len())
                };
                backend.spmm_csr_f32(
                    &crow_indices,
                    &col_indices,
                    values_f32,
                    dense_handle,
                    m,
                    k_sparse,
                    n,
                )?
            } else {
                // SAFETY: `T` is exactly `f64` here (not f32, and one of the
                // two was required to enter this branch).
                let values_f64 = unsafe {
                    std::slice::from_raw_parts(values.as_ptr().cast::<f64>(), values.len())
                };
                backend.spmm_csr_f64(
                    &crow_indices,
                    &col_indices,
                    values_f64,
                    dense_handle,
                    m,
                    k_sparse,
                    n,
                )?
            };
            let storage = TensorStorage::gpu(out_handle);
            return Tensor::from_storage(storage, vec![m, n], false);
        }
        // Indices do not fit in u32: fall through to the host path.
    }

    // Host path: each stored entry (i, j, v) adds v * dense[j, :] to out[i, :].
    let dense_data = dense.data()?;
    let mut output = vec![<T as num_traits::Zero>::zero(); m * n];
    for (idx, &v) in self.indices.iter().zip(self.values.iter()) {
        let i = idx[0];
        let j = idx[1];
        for col in 0..n {
            output[i * n + col] += v * dense_data[j * n + col];
        }
    }
    Tensor::from_storage(TensorStorage::cpu(output), vec![m, n], false)
}
/// Returns a copy of this tensor with every stored value multiplied by
/// `scalar`. Coordinates, shape and entry count are carried over unchanged.
pub fn mul_scalar(&self, scalar: T) -> Self {
    let scaled: Vec<T> = self.values.iter().copied().map(|v| v * scalar).collect();
    Self {
        indices: self.indices.clone(),
        values: scaled,
        shape: self.shape.clone(),
        nnz: self.nnz,
    }
}
/// Adds two sparse tensors of identical shape by concatenating their entries
/// (`self` first, then `other`). The result is not coalesced: coordinates
/// present in both operands appear twice.
///
/// # Errors
///
/// Returns `ShapeMismatch` if the shapes differ.
pub fn add(&self, other: &SparseTensor<T>) -> FerrotorchResult<SparseTensor<T>> {
    if self.shape != other.shape {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "cannot add sparse tensors with shapes {:?} and {:?}",
                self.shape, other.shape
            ),
        });
    }
    let indices: Vec<Vec<usize>> = self
        .indices
        .iter()
        .chain(other.indices.iter())
        .cloned()
        .collect();
    let values: Vec<T> = self
        .values
        .iter()
        .chain(other.values.iter())
        .copied()
        .collect();
    let nnz = values.len();
    Ok(SparseTensor {
        indices,
        values,
        shape: self.shape.clone(),
        nnz,
    })
}
/// Returns an equivalent tensor in which every coordinate appears at most
/// once: duplicate coordinates are summed, entries whose sum is exactly zero
/// are dropped, and the remainder is sorted lexicographically by coordinate.
pub fn coalesce(&self) -> SparseTensor<T> {
    let mut sums: HashMap<Vec<usize>, T> = HashMap::new();
    for (coords, &value) in self.indices.iter().zip(self.values.iter()) {
        *sums
            .entry(coords.clone())
            .or_insert_with(<T as num_traits::Zero>::zero) += value;
    }
    let mut merged: Vec<(Vec<usize>, T)> = sums
        .into_iter()
        .filter(|(_, total)| !<T as num_traits::Zero>::is_zero(total))
        .collect();
    merged.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
    let (indices, values): (Vec<Vec<usize>>, Vec<T>) = merged.into_iter().unzip();
    let nnz = values.len();
    SparseTensor {
        indices,
        values,
        shape: self.shape.clone(),
        nnz,
    }
}
/// Returns the transpose of a 2-D sparse tensor: both components of each
/// coordinate and both extents of the shape are swapped; values are copied
/// as-is.
///
/// # Errors
///
/// Returns `InvalidArgument` if the tensor is not 2-D.
pub fn t(&self) -> FerrotorchResult<SparseTensor<T>> {
    let (rows, cols) = match self.shape.as_slice() {
        &[rows, cols] => (rows, cols),
        _ => {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "transpose requires a 2-D sparse tensor, got {}-D",
                    self.ndim()
                ),
            });
        }
    };
    let swapped: Vec<Vec<usize>> = self
        .indices
        .iter()
        .map(|coords| vec![coords[1], coords[0]])
        .collect();
    Ok(SparseTensor {
        indices: swapped,
        values: self.values.clone(),
        shape: vec![cols, rows],
        nnz: self.nnz,
    })
}
}
/// Field-by-field deep copy of the coordinates, values and shape.
impl<T: Float> Clone for SparseTensor<T> {
    fn clone(&self) -> Self {
        // Destructure so adding a field without handling it here is a compile error.
        let Self {
            indices,
            values,
            shape,
            nnz,
        } = self;
        Self {
            indices: indices.clone(),
            values: values.clone(),
            shape: shape.clone(),
            nnz: *nnz,
        }
    }
}
/// Compact debug output: prints `shape`, `nnz` and `ndim` only, and marks the
/// struct as non-exhaustive instead of dumping every coordinate and value.
impl<T: Float> fmt::Debug for SparseTensor<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ndim = self.shape.len();
        let mut out = f.debug_struct("SparseTensor");
        out.field("shape", &self.shape);
        out.field("nnz", &self.nnz);
        out.field("ndim", &ndim);
        out.finish_non_exhaustive()
    }
}
/// 2-D sparse matrix in coordinate (COO) form: parallel row / column / value
/// arrays plus the matrix dimensions.
#[derive(Debug, Clone)]
pub struct CooTensor<T: Float> {
/// Row index of each stored entry.
row_indices: Vec<usize>,
/// Column index of each stored entry.
col_indices: Vec<usize>,
/// Value of each stored entry.
values: Vec<T>,
/// Number of rows of the logical matrix.
nrows: usize,
/// Number of columns of the logical matrix.
ncols: usize,
/// Set to `true` only by `coalesce`; the other constructors set it to `false`.
is_coalesced: bool,
}
impl<T: Float> CooTensor<T> {
/// Builds a COO matrix from parallel row / column / value arrays.
///
/// The result is always marked as not coalesced, even if the input happens
/// to contain no duplicates.
///
/// # Errors
///
/// * `InvalidArgument` if the three arrays differ in length.
/// * `IndexOutOfBounds` (axis 0 / axis 1) if a row index is `>= nrows` or a
///   column index is `>= ncols`.
pub fn new(
    row_indices: Vec<usize>,
    col_indices: Vec<usize>,
    values: Vec<T>,
    nrows: usize,
    ncols: usize,
) -> FerrotorchResult<Self> {
    if row_indices.len() != values.len() || col_indices.len() != values.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "row_indices ({}), col_indices ({}), and values ({}) must have equal length",
                row_indices.len(),
                col_indices.len(),
                values.len()
            ),
        });
    }
    for (&r, &c) in row_indices.iter().zip(col_indices.iter()) {
        if r >= nrows {
            return Err(FerrotorchError::IndexOutOfBounds {
                index: r,
                axis: 0,
                size: nrows,
            });
        }
        if c >= ncols {
            return Err(FerrotorchError::IndexOutOfBounds {
                index: c,
                axis: 1,
                size: ncols,
            });
        }
    }
    Ok(Self {
        row_indices,
        col_indices,
        values,
        nrows,
        ncols,
        is_coalesced: false,
    })
}
/// Returns the number of stored entries (`values.len()`).
#[inline]
pub fn nnz(&self) -> usize {
self.values.len()
}
/// Returns whether this matrix was produced by `coalesce`.
#[inline]
pub fn is_coalesced(&self) -> bool {
self.is_coalesced
}
/// Returns the row index of each stored entry.
#[inline]
pub fn row_indices(&self) -> &[usize] {
&self.row_indices
}
/// Returns the column index of each stored entry.
#[inline]
pub fn col_indices(&self) -> &[usize] {
&self.col_indices
}
/// Returns the stored values.
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
/// Returns the number of rows.
#[inline]
pub fn nrows(&self) -> usize {
self.nrows
}
/// Returns the number of columns.
#[inline]
pub fn ncols(&self) -> usize {
self.ncols
}
/// Returns an equivalent matrix with duplicate `(row, col)` entries summed,
/// exact-zero sums removed, and entries sorted by row then column. The
/// result is flagged as coalesced.
pub fn coalesce(&self) -> Self {
    let mut sums: HashMap<(usize, usize), T> = HashMap::new();
    let coords = self.row_indices.iter().zip(self.col_indices.iter());
    for ((&r, &c), &v) in coords.zip(self.values.iter()) {
        *sums
            .entry((r, c))
            .or_insert_with(<T as num_traits::Zero>::zero) += v;
    }
    let mut merged: Vec<((usize, usize), T)> = sums
        .into_iter()
        .filter(|(_, total)| !<T as num_traits::Zero>::is_zero(total))
        .collect();
    merged.sort_by_key(|entry| entry.0);
    let (row_indices, col_indices): (Vec<usize>, Vec<usize>) =
        merged.iter().map(|&(coord, _)| coord).unzip();
    let values: Vec<T> = merged.iter().map(|&(_, v)| v).collect();
    Self {
        row_indices,
        col_indices,
        values,
        nrows: self.nrows,
        ncols: self.ncols,
        is_coalesced: true,
    }
}
/// Materializes the matrix densely on the CPU; shorthand for
/// `to_dense_on(Device::Cpu)`.
pub fn to_dense(&self) -> FerrotorchResult<Tensor<T>> {
self.to_dense_on(Device::Cpu)
}
/// Materializes the matrix densely on `device`.
///
/// For a CUDA device with a registered backend and `f32`/`f64` elements, the
/// matrix is coalesced, converted to CSR by the backend and expanded by the
/// backend. The CSR triple returned by `coo_to_csr_*` — including its value
/// array — is what gets expanded, so correctness does not depend on the
/// backend preserving the input ordering. In every other case (including
/// indices that do not fit the backend's `u32` index type) the dense matrix
/// is accumulated on the host, summing duplicates, and moved to `device` if
/// it is not the CPU.
///
/// # Errors
///
/// Propagates backend and tensor construction / transfer errors.
pub fn to_dense_on(&self, device: Device) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    if let Device::Cuda(ord) = device
        && (is_f32 || is_f64)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
    {
        let coalesced = self.coalesce();
        let (nrows, ncols) = (coalesced.nrows, coalesced.ncols);
        // Checked narrowing instead of `as u32`, which would silently wrap.
        let rows = coalesced
            .row_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>();
        let cols = coalesced
            .col_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>();

        if let (Ok(row_u32), Ok(col_u32)) = (rows, cols) {
            let out_handle = if is_f32 {
                // SAFETY: `is_f32` proves `T` is exactly `f32`; the slice
                // borrows `coalesced.values`, which outlives both calls.
                let vals_f32 = unsafe {
                    std::slice::from_raw_parts(
                        coalesced.values.as_ptr().cast::<f32>(),
                        coalesced.values.len(),
                    )
                };
                let (crow, col, csr_vals) =
                    backend.coo_to_csr_f32(&row_u32, &col_u32, vals_f32, ord, nrows, ncols)?;
                backend.sparse_to_dense_csr_f32(&crow, &col, &csr_vals, ord, nrows, ncols)?
            } else {
                // SAFETY: `T` is exactly `f64` here (not f32, and one of the
                // two was required to enter this branch).
                let vals_f64 = unsafe {
                    std::slice::from_raw_parts(
                        coalesced.values.as_ptr().cast::<f64>(),
                        coalesced.values.len(),
                    )
                };
                let (crow, col, csr_vals) =
                    backend.coo_to_csr_f64(&row_u32, &col_u32, vals_f64, ord, nrows, ncols)?;
                backend.sparse_to_dense_csr_f64(&crow, &col, &csr_vals, ord, nrows, ncols)?
            };
            let storage = TensorStorage::gpu(out_handle);
            return Tensor::from_storage(storage, vec![nrows, ncols], false);
        }
        // Indices do not fit in u32: fall through to the host path.
    }

    let mut data = vec![<T as num_traits::Zero>::zero(); self.nrows * self.ncols];
    let coords = self.row_indices.iter().zip(self.col_indices.iter());
    for ((&r, &c), &v) in coords.zip(self.values.iter()) {
        // `+=` so duplicate coordinates are summed.
        data[r * self.ncols + c] += v;
    }
    let cpu_tensor = Tensor::from_storage(
        TensorStorage::cpu(data),
        vec![self.nrows, self.ncols],
        false,
    )?;
    if matches!(device, Device::Cpu) {
        Ok(cpu_tensor)
    } else {
        cpu_tensor.to(device)
    }
}
/// Expands a CSR matrix into COO form, emitting entries row by row in the
/// order they are stored. The result is marked as not coalesced.
pub fn from_csr(csr: &CsrTensor<T>) -> Self {
    let mut row_indices = Vec::new();
    let mut col_indices = Vec::new();
    let mut values = Vec::new();
    let mut row = 0usize;
    while row < csr.nrows {
        let span = csr.row_ptrs[row]..csr.row_ptrs[row + 1];
        for entry in span {
            row_indices.push(row);
            col_indices.push(csr.col_indices[entry]);
            values.push(csr.values[entry]);
        }
        row += 1;
    }
    Self {
        row_indices,
        col_indices,
        values,
        nrows: csr.nrows,
        ncols: csr.ncols,
        is_coalesced: false,
    }
}
}
/// 2-D sparse matrix in compressed sparse row (CSR) form.
#[derive(Debug, Clone)]
pub struct CsrTensor<T: Float> {
/// Row pointers: entries of row `r` occupy `row_ptrs[r]..row_ptrs[r + 1]`
/// in `col_indices` / `values`. `new` requires length `nrows + 1`.
row_ptrs: Vec<usize>,
/// Column index of each stored entry.
col_indices: Vec<usize>,
/// Value of each stored entry, parallel to `col_indices`.
values: Vec<T>,
/// Number of rows of the logical matrix.
nrows: usize,
/// Number of columns of the logical matrix.
ncols: usize,
}
impl<T: Float> CsrTensor<T> {
/// Builds a CSR matrix from its raw arrays, validating the CSR invariants so
/// that later conversions cannot index out of bounds.
///
/// # Errors
///
/// * `InvalidArgument` if `row_ptrs.len() != nrows + 1`, if `col_indices` and
///   `values` differ in length, or if `row_ptrs` does not start at 0, is not
///   non-decreasing, or does not end at `values.len()`.
/// * `IndexOutOfBounds` (axis 1) if a column index is `>= ncols`.
pub fn new(
    row_ptrs: Vec<usize>,
    col_indices: Vec<usize>,
    values: Vec<T>,
    nrows: usize,
    ncols: usize,
) -> FerrotorchResult<Self> {
    if row_ptrs.len() != nrows + 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "row_ptrs length ({}) must be nrows + 1 ({})",
                row_ptrs.len(),
                nrows + 1
            ),
        });
    }
    if col_indices.len() != values.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "col_indices length ({}) must equal values length ({})",
                col_indices.len(),
                values.len()
            ),
        });
    }
    // `row_ptrs` has at least one element here (length is nrows + 1).
    let nnz = values.len();
    let well_formed = row_ptrs[0] == 0
        && row_ptrs[nrows] == nnz
        && row_ptrs.windows(2).all(|pair| pair[0] <= pair[1]);
    if !well_formed {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "row_ptrs must start at 0, be non-decreasing, and end at nnz ({nnz})"
            ),
        });
    }
    if let Some(&bad) = col_indices.iter().find(|&&c| c >= ncols) {
        return Err(FerrotorchError::IndexOutOfBounds {
            index: bad,
            axis: 1,
            size: ncols,
        });
    }
    Ok(Self {
        row_ptrs,
        col_indices,
        values,
        nrows,
        ncols,
    })
}
/// Converts a COO matrix to CSR on the host.
///
/// Duplicate `(row, col)` entries are summed, entries whose sum is exactly
/// zero are dropped, and the columns within each row are emitted in
/// ascending order.
///
/// # Errors
///
/// Currently never fails; the `Result` return type is kept for interface
/// compatibility.
pub fn from_coo(coo: &CooTensor<T>) -> FerrotorchResult<Self> {
    let nrows = coo.nrows;
    let ncols = coo.ncols;

    // One accumulator map per row: column -> summed value.
    let mut row_entries: Vec<HashMap<usize, T>> = vec![HashMap::new(); nrows];
    let coords = coo.row_indices.iter().zip(coo.col_indices.iter());
    for ((&r, &c), &v) in coords.zip(coo.values.iter()) {
        *row_entries[r]
            .entry(c)
            .or_insert_with(<T as num_traits::Zero>::zero) += v;
    }

    let mut row_ptrs = Vec::with_capacity(nrows + 1);
    let mut col_indices = Vec::with_capacity(coo.nnz());
    let mut values = Vec::with_capacity(coo.nnz());
    row_ptrs.push(0);
    for entries in row_entries {
        let mut cols: Vec<(usize, T)> = entries
            .into_iter()
            .filter(|(_, v)| !<T as num_traits::Zero>::is_zero(v))
            .collect();
        // Keys are unique within a row, so an unstable sort is deterministic.
        cols.sort_unstable_by_key(|&(c, _)| c);
        for (c, v) in cols {
            col_indices.push(c);
            values.push(v);
        }
        row_ptrs.push(col_indices.len());
    }
    Ok(Self {
        row_ptrs,
        col_indices,
        values,
        nrows,
        ncols,
    })
}
/// Returns the number of stored entries (`values.len()`).
#[inline]
pub fn nnz(&self) -> usize {
self.values.len()
}
/// Returns the row pointer array.
#[inline]
pub fn row_ptrs(&self) -> &[usize] {
&self.row_ptrs
}
/// Returns the column index of each stored entry.
#[inline]
pub fn col_indices(&self) -> &[usize] {
&self.col_indices
}
/// Returns the stored values.
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
/// Materializes the matrix densely on the CPU; shorthand for
/// `to_dense_on(Device::Cpu)`.
pub fn to_dense(&self) -> FerrotorchResult<Tensor<T>> {
self.to_dense_on(Device::Cpu)
}
/// Materializes the matrix densely on `device`.
///
/// For a CUDA device with a registered backend and `f32`/`f64` elements, the
/// CSR arrays are handed to the backend's CSR-to-dense routine. If that is
/// not possible — including when a pointer or index does not fit the
/// backend's `u32` index type — the dense matrix is accumulated on the host
/// (summing duplicates) and moved to `device` if it is not the CPU.
///
/// # Errors
///
/// Propagates backend and tensor construction / transfer errors.
pub fn to_dense_on(&self, device: Device) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    // The two trailing `let Ok(..)` clauses do a checked usize -> u32
    // narrowing; a plain `as u32` would silently wrap large indices.
    if let Device::Cuda(ord) = device
        && (is_f32 || is_f64)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
        && let Ok(crow) = self
            .row_ptrs
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
        && let Ok(col) = self
            .col_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
    {
        let out_handle = if is_f32 {
            // SAFETY: `is_f32` proves `T` is exactly `f32`; the slice borrows
            // `self.values` for the duration of the call only.
            let values_f32 = unsafe {
                std::slice::from_raw_parts(
                    self.values.as_ptr().cast::<f32>(),
                    self.values.len(),
                )
            };
            backend
                .sparse_to_dense_csr_f32(&crow, &col, values_f32, ord, self.nrows, self.ncols)?
        } else {
            // SAFETY: `T` is exactly `f64` here (not f32, and one of the two
            // was required to enter this branch).
            let values_f64 = unsafe {
                std::slice::from_raw_parts(
                    self.values.as_ptr().cast::<f64>(),
                    self.values.len(),
                )
            };
            backend
                .sparse_to_dense_csr_f64(&crow, &col, values_f64, ord, self.nrows, self.ncols)?
        };
        let storage = TensorStorage::gpu(out_handle);
        return Tensor::from_storage(storage, vec![self.nrows, self.ncols], false);
    }

    let mut data = vec![<T as num_traits::Zero>::zero(); self.nrows * self.ncols];
    for row in 0..self.nrows {
        let start = self.row_ptrs[row];
        let end = self.row_ptrs[row + 1];
        for j in start..end {
            let flat = row * self.ncols + self.col_indices[j];
            // `+=` so duplicate coordinates are summed.
            data[flat] += self.values[j];
        }
    }
    let cpu_tensor = Tensor::from_storage(
        TensorStorage::cpu(data),
        vec![self.nrows, self.ncols],
        false,
    )?;
    if matches!(device, Device::Cpu) {
        Ok(cpu_tensor)
    } else {
        cpu_tensor.to(device)
    }
}
/// Converts a COO matrix to CSR, using the GPU backend for the conversion
/// when `device` is CUDA, a backend is registered and `T` is `f32`/`f64`.
///
/// The input is coalesced before being handed to the backend. When the GPU
/// route is unavailable — including when an index does not fit the backend's
/// `u32` index type — this falls back to the host conversion
/// [`from_coo`](Self::from_coo).
///
/// # Errors
///
/// Propagates backend errors and validation errors from `new`.
pub fn from_coo_on(coo: &CooTensor<T>, device: Device) -> FerrotorchResult<Self> {
    use std::any::TypeId;

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    if let Device::Cuda(ord) = device
        && (is_f32 || is_f64)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
    {
        let coalesced = coo.coalesce();
        // Checked narrowing instead of `as u32`, which would silently wrap.
        let rows = coalesced
            .row_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>();
        let cols = coalesced
            .col_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>();

        if let (Ok(row_u32), Ok(col_u32)) = (rows, cols) {
            let (crow, col, vals_t) = if is_f32 {
                // SAFETY: `is_f32` proves `T` is exactly `f32`; the slice
                // borrows `coalesced.values`, which outlives the call.
                let vals_f32 = unsafe {
                    std::slice::from_raw_parts(
                        coalesced.values.as_ptr().cast::<f32>(),
                        coalesced.values.len(),
                    )
                };
                let (cr, ci, vals) = backend.coo_to_csr_f32(
                    &row_u32,
                    &col_u32,
                    vals_f32,
                    ord,
                    coalesced.nrows,
                    coalesced.ncols,
                )?;
                (cr, ci, vals_to_t::<T, f32>(vals))
            } else {
                // SAFETY: `T` is exactly `f64` here (not f32, and one of the
                // two was required to enter this branch).
                let vals_f64 = unsafe {
                    std::slice::from_raw_parts(
                        coalesced.values.as_ptr().cast::<f64>(),
                        coalesced.values.len(),
                    )
                };
                let (cr, ci, vals) = backend.coo_to_csr_f64(
                    &row_u32,
                    &col_u32,
                    vals_f64,
                    ord,
                    coalesced.nrows,
                    coalesced.ncols,
                )?;
                (cr, ci, vals_to_t::<T, f64>(vals))
            };
            // u32 -> usize is a widening conversion on supported targets.
            let row_ptrs: Vec<usize> = crow.into_iter().map(|v| v as usize).collect();
            let col_indices: Vec<usize> = col.into_iter().map(|v| v as usize).collect();
            return Self::new(
                row_ptrs,
                col_indices,
                vals_t,
                coalesced.nrows,
                coalesced.ncols,
            );
        }
        // Indices do not fit in u32: fall through to the host conversion.
    }
    Self::from_coo(coo)
}
}
/// Compressed tensor in which every group of four consecutive elements keeps
/// exactly two values plus a 4-bit mask recording which positions were kept.
#[derive(Debug, Clone)]
pub struct SemiStructuredSparseTensor<T: Float> {
/// Kept values, two per group, in ascending position order within a group.
values: Vec<T>,
/// Packed position masks: two 4-bit nibbles per byte, even groups in the
/// low nibble and odd groups in the high nibble.
mask: Vec<u8>,
/// Shape of the original dense tensor.
shape: Vec<usize>,
}
impl<T: Float> SemiStructuredSparseTensor<T> {
/// Compresses a dense tensor by keeping, in every group of four consecutive
/// elements of its flattened data, the two entries of largest magnitude.
///
/// Ties in magnitude are broken in favour of the lower position. Magnitudes
/// that cannot be ordered (`partial_cmp` returns `None`) are treated as equal
/// and therefore also fall back to position order.
///
/// # Errors
///
/// Returns `InvalidArgument` if the element count is not a multiple of 4, and
/// propagates errors from `data_vec`.
pub fn compress(dense: &Tensor<T>) -> FerrotorchResult<Self> {
    let data = dense.data_vec()?;
    let numel = data.len();
    if numel % 4 != 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "SemiStructuredSparseTensor::compress: numel must be a multiple of 4, got {numel}"
            ),
        });
    }
    let num_groups = numel / 4;
    let mut values = Vec::with_capacity(num_groups * 2);
    // Two 4-bit masks per byte.
    let mut mask = vec![0u8; num_groups.div_ceil(2)];
    for (g, group) in data.chunks_exact(4).enumerate() {
        let mut ranked: [(usize, T); 4] = std::array::from_fn(|pos| (pos, group[pos].abs()));
        // Descending magnitude, then ascending position.
        ranked.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });
        let (first, second) = (ranked[0].0, ranked[1].0);
        // Store the two survivors in ascending position order.
        let (lo, hi) = if first < second {
            (first, second)
        } else {
            (second, first)
        };
        values.push(group[lo]);
        values.push(group[hi]);
        let nibble: u8 = (1 << lo) | (1 << hi);
        // Even groups use the low nibble of their byte, odd groups the high one.
        mask[g / 2] |= nibble << ((g % 2) * 4);
    }
    Ok(Self {
        values,
        mask,
        shape: dense.shape().to_vec(),
    })
}
/// Rebuilds the dense tensor: kept values are written back to the positions
/// recorded in the mask and every other element is zero.
///
/// # Errors
///
/// Propagates errors from `Tensor::from_storage`.
pub fn decompress(&self) -> FerrotorchResult<Tensor<T>> {
    let numel = self.shape.iter().product::<usize>();
    let mut out = vec![<T as num_traits::Zero>::zero(); numel];
    for (g, slots) in out.chunks_exact_mut(4).enumerate() {
        let nibble = (self.mask[g / 2] >> ((g % 2) * 4)) & 0xF;
        // Each group owns two consecutive entries of `values`.
        let mut next = g * 2;
        for (pos, slot) in slots.iter_mut().enumerate() {
            if (nibble >> pos) & 1 != 0 {
                *slot = self.values[next];
                next += 1;
            }
        }
    }
    Tensor::from_storage(TensorStorage::cpu(out), self.shape.clone(), false)
}
/// Returns the shape of the original dense tensor.
#[inline]
pub fn shape(&self) -> &[usize] {
&self.shape
}
/// Returns the kept values (two per group).
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
/// Returns the packed position masks (two 4-bit masks per byte).
#[inline]
pub fn mask(&self) -> &[u8] {
&self.mask
}
/// Returns the number of 4-element groups: the product of `shape` divided by 4.
#[inline]
pub fn num_groups(&self) -> usize {
self.shape.iter().product::<usize>() / 4
}
/// Returns compressed size divided by dense size, both in bytes. The
/// compressed size counts the kept values plus the mask bytes. Returns `1.0`
/// when the dense size is zero.
pub fn compression_ratio(&self) -> f64 {
    let elem_bytes = std::mem::size_of::<T>();
    let dense_bytes = self.shape.iter().product::<usize>() * elem_bytes;
    if dense_bytes == 0 {
        return 1.0;
    }
    let compressed_bytes = self.values.len() * elem_bytes + self.mask.len();
    compressed_bytes as f64 / dense_bytes as f64
}
/// Returns the 4-bit position mask of group `g`; bit `p` is set when position
/// `p` of the group was kept.
///
/// # Panics
///
/// Panics if `g / 2` is past the end of the mask array.
pub fn group_mask(&self, g: usize) -> u8 {
    let packed = self.mask[g / 2];
    // Even groups live in the low nibble, odd groups in the high nibble.
    let nibble = if g % 2 == 0 { packed } else { packed >> 4 };
    nibble & 0xF
}
}
/// Matrix product `a [m, k] @ b [k, n]` where `b` is stored in the
/// semi-structured compressed format.
///
/// `b` is decompressed exactly once. If `a` is on CUDA, a backend is
/// registered and `T` is `f32`, the decompressed operand is uploaded and the
/// backend kernel is tried; if the kernel reports an error the function
/// deliberately falls back to the host triple loop rather than failing.
///
/// # Errors
///
/// * `InvalidArgument` if either operand is not 2-D.
/// * `ShapeMismatch` if `a.shape[1] != b.shape[0]`.
/// * Any error propagated from decompression, data access or the upload.
pub fn sparse_matmul_24<T: Float>(
    a: &Tensor<T>,
    b: &SemiStructuredSparseTensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    if a.shape().len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "sparse_matmul_24: `a` must be 2-D, got shape {:?}",
                a.shape()
            ),
        });
    }
    if b.shape().len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "sparse_matmul_24: `b` must be 2-D, got shape {:?}",
                b.shape()
            ),
        });
    }
    let m = a.shape()[0];
    let k = a.shape()[1];
    let kb = b.shape()[0];
    let n = b.shape()[1];
    if k != kb {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "sparse_matmul_24: inner dims mismatch: a.shape[1]={k} != b.shape[0]={kb}"
            ),
        });
    }

    // Both the GPU attempt and the host fallback need the dense form of `b`,
    // so build it once up front instead of once per path.
    let b_dense = b.decompress()?;
    let b_data = b_dense.data_vec()?;

    if a.is_cuda()
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
        && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
    {
        let a_handle = a.gpu_handle()?;
        let ordinal = a_handle.device_ordinal();
        // SAFETY: `T` is exactly `f32` (checked above), so `b_data` is a
        // contiguous buffer of `len * size_of::<f32>()` initialized bytes;
        // the byte view only lives for the upload call.
        let b_bytes = unsafe {
            std::slice::from_raw_parts(
                b_data.as_ptr().cast::<u8>(),
                b_data.len() * std::mem::size_of::<f32>(),
            )
        };
        let b_handle = backend.cpu_to_gpu(b_bytes, crate::dtype::DType::F32, ordinal)?;
        // Best effort: a kernel failure is not fatal, the host path below
        // still produces the result.
        if let Ok(out_handle) = backend.sparse_matmul_24_f32(a_handle, &b_handle, m, k, n) {
            let storage = TensorStorage::gpu(out_handle);
            return Tensor::from_storage(storage, vec![m, n], false);
        }
    }

    let a_data = a.data_vec()?;
    let mut out = vec![<T as num_traits::Zero>::zero(); m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = <T as num_traits::Zero>::zero();
            for kk in 0..k {
                acc += a_data[i * k + kk] * b_data[kk * n + j];
            }
            out[i * n + j] = acc;
        }
    }
    Tensor::from_storage(TensorStorage::cpu(out), vec![m, n], false)
}
/// 2-D sparse matrix in compressed sparse column (CSC) form.
#[derive(Debug, Clone)]
pub struct CscTensor<T: Float> {
/// Column pointers: entries of column `c` occupy
/// `col_ptrs[c]..col_ptrs[c + 1]` in `row_indices` / `values`. `new`
/// requires length `ncols + 1`.
col_ptrs: Vec<usize>,
/// Row index of each stored entry.
row_indices: Vec<usize>,
/// Value of each stored entry, parallel to `row_indices`.
values: Vec<T>,
/// Number of rows of the logical matrix.
nrows: usize,
/// Number of columns of the logical matrix.
ncols: usize,
}
impl<T: Float> CscTensor<T> {
/// Builds a CSC matrix from its raw arrays, validating the CSC invariants so
/// that later conversions cannot index out of bounds.
///
/// # Errors
///
/// Returns `InvalidArgument` if `col_ptrs.len() != ncols + 1`, if
/// `row_indices` and `values` differ in length, if `col_ptrs` does not start
/// at 0, is not non-decreasing, or does not end at `values.len()`, or if a
/// row index is `>= nrows`.
pub fn new(
    col_ptrs: Vec<usize>,
    row_indices: Vec<usize>,
    values: Vec<T>,
    nrows: usize,
    ncols: usize,
) -> FerrotorchResult<Self> {
    if col_ptrs.len() != ncols + 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "CscTensor: col_ptrs length ({}) must be ncols + 1 ({})",
                col_ptrs.len(),
                ncols + 1
            ),
        });
    }
    if row_indices.len() != values.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "CscTensor: row_indices length ({}) must equal values length ({})",
                row_indices.len(),
                values.len()
            ),
        });
    }
    // `col_ptrs` has at least one element here (length is ncols + 1).
    let nnz = values.len();
    let well_formed = col_ptrs[0] == 0
        && col_ptrs[ncols] == nnz
        && col_ptrs.windows(2).all(|pair| pair[0] <= pair[1]);
    if !well_formed {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "CscTensor: col_ptrs must start at 0, be non-decreasing, and end at nnz ({nnz})"
            ),
        });
    }
    if let Some(&r) = row_indices.iter().find(|&&r| r >= nrows) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("CscTensor: row index {r} >= nrows {nrows}"),
        });
    }
    Ok(Self {
        col_ptrs,
        row_indices,
        values,
        nrows,
        ncols,
    })
}
/// Converts a CSR matrix to CSC on the host with a counting sort over
/// columns: count the entries per column, prefix-sum the counts into column
/// pointers, then scatter each entry to the next free slot of its column.
/// Rows are visited in ascending order, so row indices within a column come
/// out ascending too.
pub fn from_csr(csr: &CsrTensor<T>) -> Self {
    let (nrows, ncols) = (csr.nrows, csr.ncols);

    let mut per_column = vec![0usize; ncols];
    for &c in &csr.col_indices {
        per_column[c] += 1;
    }

    let mut col_ptrs = Vec::with_capacity(ncols + 1);
    col_ptrs.push(0usize);
    let mut running = 0usize;
    for count in per_column {
        running += count;
        col_ptrs.push(running);
    }

    let nnz = csr.values.len();
    let mut row_indices = vec![0usize; nnz];
    let mut values = vec![<T as num_traits::Zero>::zero(); nnz];
    // `next_slot[c]` is where the next entry of column `c` gets written.
    let mut next_slot = col_ptrs.clone();
    for r in 0..nrows {
        for entry in csr.row_ptrs[r]..csr.row_ptrs[r + 1] {
            let c = csr.col_indices[entry];
            let dst = next_slot[c];
            next_slot[c] += 1;
            row_indices[dst] = r;
            values[dst] = csr.values[entry];
        }
    }
    Self {
        col_ptrs,
        row_indices,
        values,
        nrows,
        ncols,
    }
}
/// Converts this CSC matrix to CSR on the host with a counting sort over
/// rows: count the entries per row, prefix-sum into row pointers, then
/// scatter each entry to the next free slot of its row.
///
/// # Panics
///
/// Panics if `CsrTensor::new` rejects the rebuilt arrays.
pub fn to_csr(&self) -> CsrTensor<T> {
    let mut per_row = vec![0usize; self.nrows];
    for &r in &self.row_indices {
        per_row[r] += 1;
    }

    let mut row_ptrs = Vec::with_capacity(self.nrows + 1);
    row_ptrs.push(0usize);
    let mut running = 0usize;
    for count in per_row {
        running += count;
        row_ptrs.push(running);
    }

    let nnz = self.values.len();
    let mut col_indices = vec![0usize; nnz];
    let mut values = vec![<T as num_traits::Zero>::zero(); nnz];
    // `next_slot[r]` is where the next entry of row `r` gets written.
    let mut next_slot = row_ptrs.clone();
    for c in 0..self.ncols {
        for entry in self.col_ptrs[c]..self.col_ptrs[c + 1] {
            let r = self.row_indices[entry];
            let dst = next_slot[r];
            next_slot[r] += 1;
            col_indices[dst] = c;
            values[dst] = self.values[entry];
        }
    }
    CsrTensor::new(row_ptrs, col_indices, values, self.nrows, self.ncols)
        .expect("CSR rebuild from CSC always satisfies invariants")
}
/// Materializes the matrix densely on the CPU; shorthand for
/// `to_dense_on(Device::Cpu)`.
pub fn to_dense(&self) -> FerrotorchResult<Tensor<T>> {
self.to_dense_on(Device::Cpu)
}
/// Materializes the matrix densely on `device`.
///
/// For a CUDA device with a registered backend and `f32`/`f64` elements, the
/// CSC arrays are handed to the backend's CSC-to-dense routine. If that is
/// not possible — including when a pointer or index does not fit the
/// backend's `u32` index type — the dense matrix is built on the host and
/// moved to `device` if it is not the CPU.
///
/// On the host path duplicate coordinates are summed, matching the COO and
/// CSR `to_dense_on` implementations (previously the last duplicate silently
/// overwrote the earlier ones).
///
/// # Errors
///
/// Propagates backend and tensor construction / transfer errors.
pub fn to_dense_on(&self, device: Device) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    // The two trailing `let Ok(..)` clauses do a checked usize -> u32
    // narrowing; a plain `as u32` would silently wrap large indices.
    if let Device::Cuda(ord) = device
        && (is_f32 || is_f64)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
        && let Ok(col_ptrs_u32) = self
            .col_ptrs
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
        && let Ok(row_idx_u32) = self
            .row_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
    {
        let out_handle = if is_f32 {
            // SAFETY: `is_f32` proves `T` is exactly `f32`; the slice borrows
            // `self.values` for the duration of the call only.
            let values_f32 = unsafe {
                std::slice::from_raw_parts(
                    self.values.as_ptr().cast::<f32>(),
                    self.values.len(),
                )
            };
            backend.csc_to_dense_f32(
                &col_ptrs_u32,
                &row_idx_u32,
                values_f32,
                ord,
                self.nrows,
                self.ncols,
            )?
        } else {
            // SAFETY: `T` is exactly `f64` here (not f32, and one of the two
            // was required to enter this branch).
            let values_f64 = unsafe {
                std::slice::from_raw_parts(
                    self.values.as_ptr().cast::<f64>(),
                    self.values.len(),
                )
            };
            backend.csc_to_dense_f64(
                &col_ptrs_u32,
                &row_idx_u32,
                values_f64,
                ord,
                self.nrows,
                self.ncols,
            )?
        };
        let storage = TensorStorage::gpu(out_handle);
        return Tensor::from_storage(storage, vec![self.nrows, self.ncols], false);
    }

    let mut data = vec![<T as num_traits::Zero>::zero(); self.nrows * self.ncols];
    for c in 0..self.ncols {
        let start = self.col_ptrs[c];
        let end = self.col_ptrs[c + 1];
        for k in start..end {
            let r = self.row_indices[k];
            // The dense output is row-major; accumulate so duplicates sum.
            data[r * self.ncols + c] += self.values[k];
        }
    }
    let cpu_tensor = Tensor::from_storage(
        TensorStorage::cpu(data),
        vec![self.nrows, self.ncols],
        false,
    )?;
    if matches!(device, Device::Cpu) {
        Ok(cpu_tensor)
    } else {
        cpu_tensor.to(device)
    }
}
/// Converts a CSR matrix to CSC, using the GPU backend for the conversion
/// when `device` is CUDA, a backend is registered and `T` is `f32`/`f64`.
///
/// When the GPU route is unavailable — including when a pointer or index
/// does not fit the backend's `u32` index type — this falls back to the host
/// conversion [`from_csr`](Self::from_csr).
///
/// # Errors
///
/// Propagates backend errors and validation errors from `new`.
pub fn from_csr_on(csr: &CsrTensor<T>, device: Device) -> FerrotorchResult<Self> {
    use std::any::TypeId;

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    // The two trailing `let Ok(..)` clauses do a checked usize -> u32
    // narrowing; a plain `as u32` would silently wrap large indices.
    if let Device::Cuda(ord) = device
        && (is_f32 || is_f64)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
        && let Ok(crow) = csr
            .row_ptrs
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
        && let Ok(col) = csr
            .col_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
    {
        let (col_ptrs_u32, row_idx_u32, values_t) = if is_f32 {
            // SAFETY: `is_f32` proves `T` is exactly `f32`; the slice borrows
            // `csr.values` for the duration of the call only.
            let vals_f32 = unsafe {
                std::slice::from_raw_parts(csr.values.as_ptr().cast::<f32>(), csr.values.len())
            };
            let (cp, ri, v) =
                backend.csr_to_csc_f32(&crow, &col, vals_f32, ord, csr.nrows, csr.ncols)?;
            (cp, ri, vals_to_t::<T, f32>(v))
        } else {
            // SAFETY: `T` is exactly `f64` here (not f32, and one of the two
            // was required to enter this branch).
            let vals_f64 = unsafe {
                std::slice::from_raw_parts(csr.values.as_ptr().cast::<f64>(), csr.values.len())
            };
            let (cp, ri, v) =
                backend.csr_to_csc_f64(&crow, &col, vals_f64, ord, csr.nrows, csr.ncols)?;
            (cp, ri, vals_to_t::<T, f64>(v))
        };
        // u32 -> usize is a widening conversion on supported targets.
        let col_ptrs: Vec<usize> = col_ptrs_u32.into_iter().map(|v| v as usize).collect();
        let row_indices: Vec<usize> = row_idx_u32.into_iter().map(|v| v as usize).collect();
        return Self::new(col_ptrs, row_indices, values_t, csr.nrows, csr.ncols);
    }
    Ok(Self::from_csr(csr))
}
/// Converts this CSC matrix to CSR, using the GPU backend when `device` is
/// CUDA, a backend is registered and `T` is `f32`/`f64`.
///
/// The GPU route reuses the backend's CSR→CSC routine: the CSC arrays are
/// passed in as if they were the CSR arrays of the transposed matrix (note
/// the swapped `ncols` / `nrows` arguments), and the output is taken as the
/// CSR arrays of this matrix. When the GPU route is unavailable — including
/// when a pointer or index does not fit the backend's `u32` index type —
/// this falls back to the host conversion [`to_csr`](Self::to_csr).
///
/// # Errors
///
/// Propagates backend errors and validation errors from `CsrTensor::new`.
pub fn to_csr_on(&self, device: Device) -> FerrotorchResult<CsrTensor<T>> {
    use std::any::TypeId;

    let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
    let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
    // The two trailing `let Ok(..)` clauses do a checked usize -> u32
    // narrowing; a plain `as u32` would silently wrap large indices.
    if let Device::Cuda(ord) = device
        && (is_f32 || is_f64)
        && let Some(backend) = crate::gpu_dispatch::gpu_backend()
        && let Ok(col_ptrs_u32) = self
            .col_ptrs
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
        && let Ok(row_idx_u32) = self
            .row_indices
            .iter()
            .map(|&v| u32::try_from(v))
            .collect::<Result<Vec<u32>, _>>()
    {
        let (crow_u32, col_u32, values_t) = if is_f32 {
            // SAFETY: `is_f32` proves `T` is exactly `f32`; the slice borrows
            // `self.values` for the duration of the call only.
            let vals_f32 = unsafe {
                std::slice::from_raw_parts(
                    self.values.as_ptr().cast::<f32>(),
                    self.values.len(),
                )
            };
            let (cr, ci, v) = backend.csr_to_csc_f32(
                &col_ptrs_u32,
                &row_idx_u32,
                vals_f32,
                ord,
                self.ncols,
                self.nrows,
            )?;
            (cr, ci, vals_to_t::<T, f32>(v))
        } else {
            // SAFETY: `T` is exactly `f64` here (not f32, and one of the two
            // was required to enter this branch).
            let vals_f64 = unsafe {
                std::slice::from_raw_parts(
                    self.values.as_ptr().cast::<f64>(),
                    self.values.len(),
                )
            };
            let (cr, ci, v) = backend.csr_to_csc_f64(
                &col_ptrs_u32,
                &row_idx_u32,
                vals_f64,
                ord,
                self.ncols,
                self.nrows,
            )?;
            (cr, ci, vals_to_t::<T, f64>(v))
        };
        // u32 -> usize is a widening conversion on supported targets.
        let row_ptrs: Vec<usize> = crow_u32.into_iter().map(|v| v as usize).collect();
        let col_indices: Vec<usize> = col_u32.into_iter().map(|v| v as usize).collect();
        return CsrTensor::new(row_ptrs, col_indices, values_t, self.nrows, self.ncols);
    }
    Ok(self.to_csr())
}
pub fn nnz(&self) -> usize {
self.values.len()
}
pub fn nrows(&self) -> usize {
self.nrows
}
pub fn ncols(&self) -> usize {
self.ncols
}
/// Borrows the column-pointer array.
pub fn col_ptrs(&self) -> &[usize] {
    &self.col_ptrs[..]
}
/// Borrows the row-index array (one entry per stored value).
pub fn row_indices(&self) -> &[usize] {
    &self.row_indices[..]
}
/// Borrows the stored values.
pub fn values(&self) -> &[T] {
    &self.values[..]
}
}
/// Sparse gradient: a list of leading-axis indices, each paired with one
/// dense "slab" of values.
///
/// `SparseGrad::new` enforces `values.len() == indices.len() * slab_size`,
/// where `slab_size` is the product of `slab_shape` (clamped to at least 1).
/// The same index may appear more than once; `coalesce` sums such duplicates.
#[derive(Debug, Clone)]
pub struct SparseGrad<T: Float> {
/// Index along the parameter's first dimension for each stored slab
/// (`apply_sgd` checks every entry against `param.shape()[0]`).
indices: Vec<usize>,
/// Slab data laid out back to back: slab `k` occupies
/// `values[k * slab_size .. (k + 1) * slab_size]`.
values: Vec<T>,
/// Shape of a single slab; `apply_sgd` requires it to equal the
/// parameter's trailing dimensions (`shape[1..]`).
slab_shape: Vec<usize>,
}
impl<T: Float> SparseGrad<T> {
    /// Number of elements in one slab of shape `shape`, never less than 1.
    ///
    /// Uses saturating multiplication so an absurdly large shape cannot
    /// overflow (panic in debug builds, wrap in release builds).
    fn slab_len(shape: &[usize]) -> usize {
        shape
            .iter()
            .fold(1usize, |acc, &dim| acc.saturating_mul(dim))
            .max(1)
    }

    /// Builds a sparse gradient from row indices and their value slabs.
    ///
    /// `values` must hold exactly `indices.len()` slabs of
    /// `slab_shape.iter().product().max(1)` elements each, stored back to back.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::ShapeMismatch` when `values.len()` does not
    /// equal `indices.len() * slab_size`.
    pub fn new(
        indices: Vec<usize>,
        values: Vec<T>,
        slab_shape: Vec<usize>,
    ) -> FerrotorchResult<Self> {
        let slab_size = Self::slab_len(&slab_shape);
        // Saturating: an overflowing product can never equal a real Vec length,
        // so such inputs are rejected instead of wrapping into a false match.
        let expected_len = indices.len().saturating_mul(slab_size);
        if values.len() != expected_len {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "SparseGrad: values.len()={} != indices.len()={} * slab_size={}",
                    values.len(),
                    indices.len(),
                    slab_size
                ),
            });
        }
        Ok(Self {
            indices,
            values,
            slab_shape,
        })
    }

    /// Number of stored slabs (duplicates counted individually).
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Leading-axis index of each stored slab.
    pub fn indices(&self) -> &[usize] {
        &self.indices
    }

    /// Flattened slab values; slab `k` starts at `k * slab_size()`.
    pub fn values(&self) -> &[T] {
        &self.values
    }

    /// Shape of a single slab.
    pub fn slab_shape(&self) -> &[usize] {
        &self.slab_shape
    }

    /// Number of elements per slab: the product of `slab_shape`, at least 1.
    pub fn slab_size(&self) -> usize {
        Self::slab_len(&self.slab_shape)
    }

    /// Always `true` for this type.
    #[inline]
    #[must_use]
    pub fn is_sparse(&self) -> bool {
        true
    }

    /// Returns a copy in which every index appears once, in ascending order,
    /// with the slabs of duplicate indices summed element-wise.
    ///
    /// Slabs are accumulated in their original order, starting from zero.
    pub fn coalesce(&self) -> Self {
        let slab_size = self.slab_size();
        // BTreeMap keeps keys sorted, which yields the ascending output order.
        let mut merged: std::collections::BTreeMap<usize, Vec<T>> =
            std::collections::BTreeMap::new();
        for (&idx, slab) in self.indices.iter().zip(self.values.chunks_exact(slab_size)) {
            let acc = merged
                .entry(idx)
                .or_insert_with(|| vec![<T as num_traits::Zero>::zero(); slab_size]);
            for (dst, &src) in acc.iter_mut().zip(slab) {
                *dst += src;
            }
        }

        let mut indices = Vec::with_capacity(merged.len());
        let mut values = Vec::with_capacity(merged.len() * slab_size);
        for (idx, slab) in merged {
            indices.push(idx);
            values.extend(slab);
        }
        Self {
            indices,
            values,
            slab_shape: self.slab_shape.clone(),
        }
    }

    /// Applies one SGD step, `param[idx] -= lr * slab`, for every stored slab.
    ///
    /// `param` must have shape `[leading, slab_shape...]`. Duplicate indices
    /// are applied one after another, so their contributions accumulate.
    ///
    /// When `param` lives on a CUDA device, a GPU backend is available and `T`
    /// is `f32` or `f64`, the update is computed on the device; otherwise the
    /// parameter data is read via `data_vec`, updated on the host and written
    /// back as CPU storage.
    ///
    /// # Errors
    ///
    /// * `ShapeMismatch` if `param`'s rank or trailing dimensions do not match
    ///   `slab_shape`.
    /// * `InvalidArgument` if any index is `>= param.shape()[0]`, if (GPU path
    ///   only) an index is too large to be encoded exactly as `f32`, or if
    ///   `lr` cannot be converted to the kernel's float type.
    /// * Any error propagated from the backend or tensor construction.
    pub fn apply_sgd(&self, param: &mut Tensor<T>, lr: T) -> FerrotorchResult<()> {
        let shape = param.shape().to_vec();
        if shape.len() != 1 + self.slab_shape.len() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "SparseGrad::apply_sgd: param shape {:?} incompatible with slab shape {:?}",
                    shape, self.slab_shape
                ),
            });
        }
        if shape[1..] != self.slab_shape[..] {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "SparseGrad::apply_sgd: param trailing dims {:?} != slab_shape {:?}",
                    &shape[1..],
                    self.slab_shape
                ),
            });
        }

        let slab_size = self.slab_size();
        let leading = shape[0];
        // Validate every index up front so no partial update is ever applied.
        if let Some((k, &idx)) = self
            .indices
            .iter()
            .enumerate()
            .find(|&(_, &idx)| idx >= leading)
        {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("SparseGrad::apply_sgd: index {idx} >= {leading} (slot {k})"),
            });
        }

        if param.is_cuda()
            && let Some(backend) = crate::gpu_dispatch::gpu_backend()
        {
            use std::any::TypeId;
            if self.indices.is_empty() {
                return Ok(());
            }
            let is_f32 = TypeId::of::<T>() == TypeId::of::<f32>();
            let is_f64 = TypeId::of::<T>() == TypeId::of::<f64>();
            if is_f32 || is_f64 {
                // Row indices are uploaded as f32. An f32 has a 24-bit
                // significand, so integers above 2^24 are rounded and would
                // silently address the wrong row. Refuse instead.
                const MAX_EXACT_F32_INDEX: usize = 1 << 24;
                if let Some(&idx) = self.indices.iter().find(|&&i| i > MAX_EXACT_F32_INDEX) {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!(
                            "SparseGrad::apply_sgd: index {} exceeds {} and cannot be encoded exactly as f32 for the GPU kernel",
                            idx, MAX_EXACT_F32_INDEX
                        ),
                    });
                }

                let ordinal = param.gpu_handle()?.device_ordinal();

                // SAFETY: `self.values` is a live, initialized `[T]` with `T`
                // being f32 or f64 (no padding bytes); viewing it as bytes
                // of the same total size is valid for the borrow's duration.
                let values_bytes = unsafe {
                    std::slice::from_raw_parts(
                        self.values.as_ptr().cast::<u8>(),
                        std::mem::size_of_val(self.values.as_slice()),
                    )
                };
                let values_dtype = if is_f32 {
                    crate::dtype::DType::F32
                } else {
                    crate::dtype::DType::F64
                };
                let values_handle = backend.cpu_to_gpu(values_bytes, values_dtype, ordinal)?;

                let indices_f32: Vec<f32> = self.indices.iter().map(|&i| i as f32).collect();
                // SAFETY: `indices_f32` is a live, initialized `[f32]`; the
                // byte view covers exactly its storage and does not outlive it.
                let idx_bytes = unsafe {
                    std::slice::from_raw_parts(
                        indices_f32.as_ptr().cast::<u8>(),
                        std::mem::size_of_val(indices_f32.as_slice()),
                    )
                };
                let idx_handle =
                    backend.cpu_to_gpu(idx_bytes, crate::dtype::DType::F32, ordinal)?;

                // Densify the gradient on the device, scale by lr, subtract.
                let updated = if is_f32 {
                    let dense_grad = backend.scatter_add_rows_f32(
                        &values_handle,
                        &idx_handle,
                        leading,
                        slab_size,
                    )?;
                    let lr_f32: f32 = num_traits::ToPrimitive::to_f32(&lr).ok_or_else(|| {
                        FerrotorchError::InvalidArgument {
                            message: "SparseGrad::apply_sgd: lr not representable as f32".into(),
                        }
                    })?;
                    let scaled = backend.scale_f32(&dense_grad, lr_f32)?;
                    let param_handle = param.gpu_handle()?;
                    backend.sub_f32(param_handle, &scaled)?
                } else {
                    let dense_grad = backend.scatter_add_rows_f64(
                        &values_handle,
                        &idx_handle,
                        leading,
                        slab_size,
                    )?;
                    let lr_f64: f64 = num_traits::ToPrimitive::to_f64(&lr).ok_or_else(|| {
                        FerrotorchError::InvalidArgument {
                            message: "SparseGrad::apply_sgd: lr not representable as f64".into(),
                        }
                    })?;
                    let scaled = backend.scale_f64(&dense_grad, lr_f64)?;
                    let param_handle = param.gpu_handle()?;
                    backend.sub_f64(param_handle, &scaled)?
                };

                let storage = TensorStorage::gpu(updated);
                *param = Tensor::from_storage(storage, shape, param.requires_grad())?;
                return Ok(());
            }
        }

        // Host path: update only the rows named by `indices`.
        let mut data = param.data_vec()?;
        for (&idx, grad_slab) in self.indices.iter().zip(self.values.chunks_exact(slab_size)) {
            let row = &mut data[idx * slab_size..(idx + 1) * slab_size];
            for (weight, &grad) in row.iter_mut().zip(grad_slab) {
                *weight = *weight - lr * grad;
            }
        }
        *param = Tensor::from_storage(TensorStorage::cpu(data), shape, param.requires_grad())?;
        Ok(())
    }
}
// Unit tests for the sparse tensor types in this file: `SparseTensor`,
// `CooTensor`, `CsrTensor`, `CscTensor`, `SemiStructuredSparseTensor`
// (2:4 sparsity) and `SparseGrad`. All tests run on CPU storage.
#[cfg(test)]
mod tests {
use super::*;
// --- SparseTensor: construction, dense conversion, arithmetic ---

// Constructor stores indices/values/shape and the accessors return them.
#[test]
fn test_construction_and_accessors() {
let indices = vec![vec![0, 1], vec![1, 2], vec![2, 0]];
let values = vec![1.0f32, 2.0, 3.0];
let shape = vec![3, 3];
let sp = SparseTensor::new(indices.clone(), values.clone(), shape.clone()).unwrap();
assert_eq!(sp.nnz(), 3);
assert_eq!(sp.shape(), &[3, 3]);
assert_eq!(sp.ndim(), 2);
assert_eq!(sp.values(), &[1.0, 2.0, 3.0]);
assert_eq!(sp.indices(), &indices);
}
// With threshold 0.0 only the two non-zero entries are kept, at their
// original positions.
#[test]
#[allow(clippy::float_cmp)]
fn test_from_dense_with_threshold() {
let data = vec![0.0f32, 0.0, 5.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0];
let tensor = Tensor::from_storage(TensorStorage::cpu(data), vec![3, 3], false).unwrap();
let sp = SparseTensor::from_dense(&tensor, 0.0).unwrap();
assert_eq!(sp.nnz(), 2);
assert_eq!(sp.shape(), &[3, 3]);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
let idx = |r: usize, c: usize| r * 3 + c;
assert_eq!(d[idx(0, 2)], 5.0);
assert_eq!(d[idx(2, 0)], 3.0);
}
// With threshold 1.0 the entries 0.5 and 0.1 are dropped (become 0.0).
#[test]
#[allow(clippy::float_cmp)]
fn test_from_dense_threshold_filters_small() {
let data = vec![0.5f32, 1.5, 0.1, 2.0];
let tensor = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 2], false).unwrap();
let sp = SparseTensor::from_dense(&tensor, 1.0).unwrap();
assert_eq!(sp.nnz(), 2);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d[0], 0.0); assert_eq!(d[1], 1.5); assert_eq!(d[2], 0.0); assert_eq!(d[3], 2.0); }
// dense -> sparse -> dense reproduces the original data (f64).
#[test]
fn test_to_dense_round_trip() {
let data = vec![1.0f64, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];
let original =
Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![3, 3], false).unwrap();
let sp = SparseTensor::from_dense(&original, 0.0).unwrap();
let reconstructed = sp.to_dense().unwrap();
let orig_data = original.data().unwrap();
let recon_data = reconstructed.data().unwrap();
for (a, b) in orig_data.iter().zip(recon_data.iter()) {
assert!((*a - *b).abs() < 1e-10, "mismatch: {a} vs {b}");
}
}
// [[1,0,2],[0,3,0]] x [[1,4],[2,5],[3,6]] == [[7,16],[6,15]].
#[test]
fn test_spmm_matches_dense_mm() {
let sp = SparseTensor::new(
vec![vec![0, 0], vec![0, 2], vec![1, 1]],
vec![1.0f32, 2.0, 3.0],
vec![2, 3],
)
.unwrap();
let dense = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0]),
vec![3, 2],
false,
)
.unwrap();
let result = sp.spmm(&dense).unwrap();
let d = result.data().unwrap();
assert_eq!(result.shape(), &[2, 2]);
assert!((d[0] - 7.0).abs() < 1e-6);
assert!((d[1] - 16.0).abs() < 1e-6);
assert!((d[2] - 6.0).abs() < 1e-6);
assert!((d[3] - 15.0).abs() < 1e-6);
}
// Multiplying by a sparse identity returns the dense operand unchanged.
#[test]
fn test_spmm_identity() {
let sp = SparseTensor::new(
vec![vec![0, 0], vec![1, 1], vec![2, 2]],
vec![1.0f32; 3],
vec![3, 3],
)
.unwrap();
let dense = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]),
vec![3, 2],
false,
)
.unwrap();
let result = sp.spmm(&dense).unwrap();
let d = result.data().unwrap();
let expected = dense.data().unwrap();
assert_eq!(result.shape(), &[3, 2]);
for (a, b) in d.iter().zip(expected.iter()) {
assert!((a - b).abs() < 1e-6);
}
}
// Two entries at (0, 1) are summed into one (3 + 4 = 7).
#[test]
fn test_coalesce_merges_duplicates() {
let sp = SparseTensor::new(
vec![vec![0, 0], vec![0, 1], vec![0, 1]],
vec![1.0f32, 3.0, 4.0],
vec![1, 3],
)
.unwrap();
let coalesced = sp.coalesce();
assert_eq!(coalesced.nnz(), 2);
let dense = coalesced.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-6);
assert!((d[1] - 7.0).abs() < 1e-6);
assert!((d[2] - 0.0).abs() < 1e-6);
}
// Duplicates that cancel (5 + -5) leave no stored entry.
#[test]
fn test_coalesce_removes_zero_sum() {
let sp = SparseTensor::new(vec![vec![0, 0], vec![0, 0]], vec![5.0f32, -5.0], vec![1, 1])
.unwrap();
let coalesced = sp.coalesce();
assert_eq!(coalesced.nnz(), 0);
}
// Transpose swaps the shape and each index pair; values keep their order.
#[test]
fn test_transpose() {
let sp =
SparseTensor::new(vec![vec![0, 1], vec![2, 0]], vec![5.0f32, 3.0], vec![3, 4]).unwrap();
let transposed = sp.t().unwrap();
assert_eq!(transposed.shape(), &[4, 3]);
assert_eq!(transposed.nnz(), 2);
assert_eq!(transposed.indices()[0], vec![1, 0]);
assert_eq!(transposed.indices()[1], vec![0, 2]);
assert_eq!(transposed.values(), &[5.0, 3.0]);
}
// `t()` fails on a 3-D tensor.
#[test]
fn test_transpose_not_2d() {
let sp = SparseTensor::new(vec![vec![0, 1, 2]], vec![1.0f32], vec![3, 3, 3]).unwrap();
assert!(sp.t().is_err());
}
// Scalar multiply scales the values and leaves indices/shape untouched.
#[test]
fn test_mul_scalar() {
let sp =
SparseTensor::new(vec![vec![0, 0], vec![1, 1]], vec![2.0f64, 3.0], vec![2, 2]).unwrap();
let scaled = sp.mul_scalar(10.0);
assert_eq!(scaled.values(), &[20.0, 30.0]);
assert_eq!(scaled.nnz(), 2);
assert_eq!(scaled.shape(), &[2, 2]);
assert_eq!(scaled.indices(), sp.indices());
}
// `add` yields 4 entries (2 + 2); coalescing merges the shared
// position (0, 1) down to 3.
#[test]
fn test_add_sparse_tensors() {
let a =
SparseTensor::new(vec![vec![0, 0], vec![0, 1]], vec![1.0f32, 2.0], vec![2, 2]).unwrap();
let b =
SparseTensor::new(vec![vec![0, 1], vec![1, 0]], vec![3.0, 4.0], vec![2, 2]).unwrap();
let sum = a.add(&b).unwrap();
assert_eq!(sum.nnz(), 4);
let coalesced = sum.coalesce();
assert_eq!(coalesced.nnz(), 3);
let dense = coalesced.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-6); assert!((d[1] - 5.0).abs() < 1e-6); assert!((d[2] - 4.0).abs() < 1e-6); assert!((d[3] - 0.0).abs() < 1e-6); }
// Adding tensors of shapes [2, 3] and [3, 2] is an error.
#[test]
fn test_add_shape_mismatch() {
let a = SparseTensor::<f32>::new(vec![], vec![], vec![2, 3]).unwrap();
let b = SparseTensor::<f32>::new(vec![], vec![], vec![3, 2]).unwrap();
assert!(a.add(&b).is_err());
}
// --- SparseTensor: validation and edge cases ---

// Index 3 on axis 0 of size 3 is reported as IndexOutOfBounds.
#[test]
fn test_index_out_of_bounds() {
let result = SparseTensor::new(
vec![vec![3, 0]], vec![1.0f32],
vec![3, 3],
);
assert!(result.is_err());
let err = result.unwrap_err();
match err {
FerrotorchError::IndexOutOfBounds { index, axis, size } => {
assert_eq!(index, 3);
assert_eq!(axis, 0);
assert_eq!(size, 3);
}
other => panic!("expected IndexOutOfBounds, got: {other:?}"),
}
}
// Same check, but the offending index (5) is on axis 1.
#[test]
fn test_index_out_of_bounds_second_axis() {
let result = SparseTensor::new(
vec![vec![0, 5]], vec![1.0f64],
vec![3, 3],
);
assert!(result.is_err());
match result.unwrap_err() {
FerrotorchError::IndexOutOfBounds { index, axis, size } => {
assert_eq!(index, 5);
assert_eq!(axis, 1);
assert_eq!(size, 3);
}
other => panic!("expected IndexOutOfBounds, got: {other:?}"),
}
}
// A tensor with no entries densifies to all zeros.
#[test]
fn test_empty_sparse_tensor() {
let sp = SparseTensor::<f32>::new(vec![], vec![], vec![5, 5]).unwrap();
assert_eq!(sp.nnz(), 0);
assert_eq!(sp.shape(), &[5, 5]);
let dense = sp.to_dense().unwrap();
assert!(dense.data().unwrap().iter().all(|&x| x == 0.0));
}
// Two index tuples with a single value are rejected.
#[test]
fn test_indices_values_length_mismatch() {
let result = SparseTensor::new(
vec![vec![0, 0], vec![1, 1]],
vec![1.0f32], vec![2, 2],
);
assert!(result.is_err());
}
// [2, 3] x [4, 2]: inner dimensions differ, so `spmm` fails.
#[test]
fn test_spmm_dimension_mismatch() {
let sp = SparseTensor::new(vec![vec![0, 0]], vec![1.0f32], vec![2, 3]).unwrap();
let dense =
Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 8]), vec![4, 2], false).unwrap();
assert!(sp.spmm(&dense).is_err());
}
// Debug output names the type and includes the nnz count.
#[test]
fn test_debug_format() {
let sp = SparseTensor::new(vec![vec![0, 0]], vec![1.0f32], vec![3, 3]).unwrap();
let debug = format!("{sp:?}");
assert!(debug.contains("SparseTensor"));
assert!(debug.contains("nnz: 1"));
}
// A clone carries the same values, indices and shape.
#[test]
fn test_clone() {
let sp = SparseTensor::new(vec![vec![0, 1]], vec![42.0f32], vec![2, 2]).unwrap();
let sp2 = sp.clone();
assert_eq!(sp2.values(), &[42.0]);
assert_eq!(sp2.indices(), &[vec![0, 1]]);
assert_eq!(sp2.shape(), &[2, 2]);
}
// --- CooTensor / CsrTensor ---

// Duplicate (0, 1) entries are summed (3 + 4 = 7); result is flagged
// as coalesced.
#[test]
fn test_coo_coalesce_uses_tuple_key() {
let coo =
CooTensor::new(vec![0, 0, 1], vec![1, 1, 0], vec![3.0f32, 4.0, 5.0], 2, 2).unwrap();
let coalesced = coo.coalesce();
assert!(coalesced.is_coalesced());
assert_eq!(coalesced.nnz(), 2);
let dense = coalesced.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[1] - 7.0).abs() < 1e-6); assert!((d[2] - 5.0).abs() < 1e-6); }
// A COO tensor built from CSR is not marked as coalesced.
#[test]
fn test_coo_from_csr_not_coalesced() {
let csr = CsrTensor::new(vec![0, 1, 2], vec![0, 1], vec![1.0f32, 2.0], 2, 2).unwrap();
let coo = CooTensor::from_csr(&csr);
assert!(!coo.is_coalesced());
assert_eq!(coo.nnz(), 2);
}
// Converting COO with a duplicate (0, 0) to CSR merges it (1 + 2 = 3).
#[test]
fn test_csr_from_coo_with_duplicates() {
let coo =
CooTensor::new(vec![0, 0, 1], vec![0, 0, 1], vec![1.0f32, 2.0, 3.0], 2, 2).unwrap();
let csr = CsrTensor::from_coo(&coo).unwrap();
assert_eq!(csr.nnz(), 2);
let dense = csr.to_dense().unwrap();
let d = dense.data().unwrap();
assert!((d[0] - 3.0).abs() < 1e-6); assert!((d[3] - 3.0).abs() < 1e-6); }
// Coalesced indices come out in ascending (row, col) order regardless
// of input order.
#[test]
fn test_coalesce_deterministic_order() {
let sp = SparseTensor::new(
vec![vec![1, 0], vec![0, 1], vec![0, 0]],
vec![3.0f32, 2.0, 1.0],
vec![2, 2],
)
.unwrap();
let coalesced = sp.coalesce();
assert_eq!(coalesced.indices()[0], vec![0, 0]);
assert_eq!(coalesced.indices()[1], vec![0, 1]);
assert_eq!(coalesced.indices()[2], vec![1, 0]);
}
// --- SparseTensor: other ranks ---

// 1-D tensor of length 5 with entries at positions 1 and 4.
#[test]
#[allow(clippy::float_cmp)]
fn test_1d_sparse_tensor() {
let sp = SparseTensor::new(vec![vec![1], vec![4]], vec![10.0f32, 20.0], vec![5]).unwrap();
assert_eq!(sp.ndim(), 1);
assert_eq!(sp.nnz(), 2);
assert_eq!(sp.shape(), &[5]);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d.len(), 5);
assert_eq!(d[0], 0.0);
assert_eq!(d[1], 10.0);
assert_eq!(d[2], 0.0);
assert_eq!(d[3], 0.0);
assert_eq!(d[4], 20.0);
}
// 3-D tensor of shape [2, 2, 3]: (0,1,2) lands at flat offset 5 and
// (1,0,0) at flat offset 6.
#[test]
fn test_3d_sparse_tensor() {
let sp = SparseTensor::new(
vec![vec![0, 1, 2], vec![1, 0, 0]],
vec![5.0f64, 7.0],
vec![2, 2, 3],
)
.unwrap();
assert_eq!(sp.ndim(), 3);
assert_eq!(sp.nnz(), 2);
assert_eq!(sp.shape(), &[2, 2, 3]);
let dense = sp.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d.len(), 12);
assert!((d[5] - 5.0).abs() < 1e-10);
assert!((d[6] - 7.0).abs() < 1e-10);
}
// Shape [0, 5] is accepted and densifies to an empty tensor.
#[test]
fn test_zero_dimension_sparse_tensor() {
let sp = SparseTensor::<f32>::new(vec![], vec![], vec![0, 5]).unwrap();
assert_eq!(sp.ndim(), 2);
assert_eq!(sp.nnz(), 0);
assert_eq!(sp.shape(), &[0, 5]);
let dense = sp.to_dense().unwrap();
assert_eq!(dense.numel(), 0);
assert!(dense.data().unwrap().is_empty());
}
// --- SemiStructuredSparseTensor (2:4 sparsity) ---

// Helper: builds a CPU f32 tensor without gradient tracking.
fn mk(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
}
// Each group of 4 keeps its two largest-magnitude values. The asserts
// show the mask byte packing group 0 in the low nibble (0xA) and
// group 1 in the high nibble (0x3), i.e. 0x3A.
#[test]
fn semi24_compress_keeps_two_largest_magnitudes_per_group() {
let t = mk(vec![1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0], vec![8]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[4.0, 3.0, -5.0, 2.0]);
assert_eq!(sp.mask(), &[0x3A]);
assert_eq!(sp.num_groups(), 2);
assert_eq!(sp.group_mask(0), 0xA);
assert_eq!(sp.group_mask(1), 0x3);
}
// Decompression puts kept values back in place and zeros the rest.
#[test]
fn semi24_decompress_roundtrips_compressed_values() {
let t = mk(vec![1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0], vec![8]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
let dense = sp.decompress().unwrap();
let data = dense.data().unwrap();
assert_eq!(data, &[0.0, 4.0, 0.0, 3.0, -5.0, 2.0, 0.0, 0.0]);
assert_eq!(dense.shape(), &[8]);
}
// The original [2, 8] shape survives compress and decompress.
#[test]
fn semi24_compress_decompress_preserves_shape() {
let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
let t = mk(data, vec![2, 8]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.shape(), &[2, 8]);
let dense = sp.decompress().unwrap();
assert_eq!(dense.shape(), &[2, 8]);
}
// A 5-element tensor is rejected with a "multiple of 4" message.
#[test]
fn semi24_rejects_non_multiple_of_4() {
let t = mk(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5]);
let result = SemiStructuredSparseTensor::compress(&t);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("multiple of 4"));
}
// With all magnitudes equal, positions 0 and 1 are kept (mask 0x3).
#[test]
fn semi24_tie_breaking_prefers_lower_position() {
let t = mk(vec![1.0, 1.0, 1.0, 1.0], vec![4]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.group_mask(0), 0x3);
assert_eq!(sp.values(), &[1.0, 1.0]);
}
// For 1024 elements the reported ratio lies strictly between 0.5 and 0.6.
#[test]
fn semi24_compression_ratio_is_roughly_half() {
let n = 1024usize;
let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
let t = mk(data, vec![n]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
let ratio = sp.compression_ratio();
assert!(ratio > 0.5 && ratio < 0.6, "unexpected ratio: {ratio}");
}
// An all-zero tensor yields mask 0x3 for every group.
#[test]
fn semi24_zero_tensor_has_deterministic_mask() {
let t = mk(vec![0.0; 16], vec![16]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[0.0; 8]);
for g in 0..4 {
assert_eq!(sp.group_mask(g), 0x3);
}
}
// Selection is by absolute value: -10 and 3 are kept (positions 0 and 3).
#[test]
fn semi24_negative_and_positive_by_magnitude() {
let t = mk(vec![-10.0, 1.0, -2.0, 3.0], vec![4]);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[-10.0, 3.0]);
assert_eq!(sp.group_mask(0), 0x9);
}
// `sparse_matmul_24(a, b)` equals a hand-computed product of `a` with
// the decompressed (masked) `b`.
#[test]
#[allow(clippy::float_cmp)]
fn semi24_sparse_matmul_matches_dense_matmul() {
let a = mk(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let b_data = vec![
1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0, ];
let b_dense = mk(b_data.clone(), vec![2, 4]);
let b_sparse = SemiStructuredSparseTensor::compress(&b_dense).unwrap();
let out = sparse_matmul_24(&a, &b_sparse).unwrap();
assert_eq!(out.shape(), &[2, 4]);
let b_masked = b_sparse.decompress().unwrap();
let b_m = b_masked.data().unwrap();
let d = out.data().unwrap();
assert_eq!(d[0], 1.0 * b_m[0] + 2.0 * b_m[4]);
assert_eq!(d[1], 1.0 * b_m[1] + 2.0 * b_m[5]);
assert_eq!(d[2], 1.0 * b_m[2] + 2.0 * b_m[6]);
assert_eq!(d[3], 1.0 * b_m[3] + 2.0 * b_m[7]);
assert_eq!(d[4], 3.0 * b_m[0] + 4.0 * b_m[4]);
assert_eq!(d[5], 3.0 * b_m[1] + 4.0 * b_m[5]);
assert_eq!(d[6], 3.0 * b_m[2] + 4.0 * b_m[6]);
assert_eq!(d[7], 3.0 * b_m[3] + 4.0 * b_m[7]);
}
// A 1-D left operand is rejected.
#[test]
fn semi24_sparse_matmul_rejects_non_2d_a() {
let a = mk(vec![1.0, 2.0, 3.0, 4.0], vec![4]); let b_dense = mk(vec![1.0; 16], vec![4, 4]);
let b_sparse = SemiStructuredSparseTensor::compress(&b_dense).unwrap();
let result = sparse_matmul_24(&a, &b_sparse);
assert!(result.is_err());
}
// [1, 3] x [4, 4]: inner dimensions differ, so the call fails.
#[test]
fn semi24_sparse_matmul_rejects_inner_dim_mismatch() {
let a = mk(vec![1.0, 2.0, 3.0], vec![1, 3]); let b_dense = mk(vec![1.0; 16], vec![4, 4]); let b_sparse = SemiStructuredSparseTensor::compress(&b_dense).unwrap();
let result = sparse_matmul_24(&a, &b_sparse);
assert!(result.is_err());
}
// compress + decompress must agree with `crate::pruning::apply_2_4_mask`.
#[test]
fn semi24_compress_then_decompress_matches_apply_2_4_mask() {
let t = mk(
vec![
0.1, 0.9, 0.3, 0.5, -0.8, 0.2, 0.7, -0.4, 1.5, -2.0, 0.1, 0.3,
],
vec![12],
);
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
let sp_dense = sp.decompress().unwrap();
let mask_result = crate::pruning::apply_2_4_mask(&t).unwrap();
assert_eq!(
sp_dense.data().unwrap(),
mask_result.data().unwrap(),
"compress+decompress must match apply_2_4_mask output"
);
}
// Same fixture as the f32 compress/decompress tests, run with f64.
#[test]
fn semi24_f64_parity() {
let t = Tensor::<f64>::from_storage(
TensorStorage::cpu(vec![1.0, 4.0, 2.0, 3.0, -5.0, 2.0, 0.0, 1.0]),
vec![8],
false,
)
.unwrap();
let sp = SemiStructuredSparseTensor::compress(&t).unwrap();
assert_eq!(sp.values(), &[4.0, 3.0, -5.0, 2.0]);
let dense = sp.decompress().unwrap();
let data = dense.data().unwrap();
assert_eq!(data, &[0.0, 4.0, 0.0, 3.0, -5.0, 2.0, 0.0, 0.0]);
}
// --- CscTensor ---

// CSR -> CSC -> CSR returns the values in their original order.
#[test]
fn csc_from_csr_roundtrip() {
let csr = CsrTensor::new(
vec![0, 2, 3, 5],
vec![0, 2, 2, 0, 1],
vec![1.0_f32, 2.0, 3.0, 4.0, 5.0],
3,
3,
)
.unwrap();
let csc = CscTensor::from_csr(&csr);
assert_eq!(csc.nrows(), 3);
assert_eq!(csc.ncols(), 3);
assert_eq!(csc.nnz(), 5);
let csr2 = csc.to_csr();
assert_eq!(csr2.values().to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
// Densifying the CSC form gives the expected row-major 3x3 matrix.
#[test]
fn csc_to_dense_matches_csr() {
let csr = CsrTensor::new(
vec![0, 2, 3, 5],
vec![0, 2, 2, 0, 1],
vec![1.0_f32, 2.0, 3.0, 4.0, 5.0],
3,
3,
)
.unwrap();
let csc = CscTensor::from_csr(&csr);
let dense = csc.to_dense().unwrap();
let d = dense.data().unwrap();
assert_eq!(d, &[1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0]);
}
// col_ptrs of length 2 with ncols = 3 is rejected.
#[test]
fn csc_rejects_bad_col_ptrs_length() {
let err = CscTensor::new(vec![0, 1], vec![0], vec![1.0_f32], 2, 3).unwrap_err();
assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
}
// Row index 5 with nrows = 2 is rejected.
#[test]
fn csc_rejects_oob_row_index() {
let err = CscTensor::new(vec![0, 1], vec![5], vec![1.0_f32], 2, 1).unwrap_err();
assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
}
// --- SparseGrad ---

// 2 indices x slab of 3 needs 6 values; 5 values gives ShapeMismatch.
#[test]
fn sparse_grad_construction_validates_size() {
let err = SparseGrad::<f32>::new(vec![0, 2], vec![1.0; 5], vec![3]).unwrap_err();
assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
}
// `is_sparse` reports true.
#[test]
fn sparse_grad_is_sparse_predicate() {
let g = SparseGrad::<f32>::new(vec![0], vec![1.0, 2.0], vec![2]).unwrap();
assert!(g.is_sparse());
}
// Index 0 appears twice: [1, 2] + [3, 4] = [4, 6]; index 1 keeps [5, 6].
#[test]
fn sparse_grad_coalesce_sums_duplicate_indices() {
let g = SparseGrad::<f32>::new(vec![0, 1, 0], vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0], vec![2])
.unwrap();
let c = g.coalesce();
assert_eq!(c.indices(), &[0, 1]);
assert_eq!(c.values(), &[4.0, 6.0, 5.0, 6.0]);
}
// With lr = 1, rows 1 and 3 become the negated gradient; rows 0 and 2
// stay zero.
#[test]
fn sparse_grad_apply_sgd_updates_only_affected_rows() {
let mut param =
Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.0; 12]), vec![4, 3], false)
.unwrap();
let grad = SparseGrad::<f32>::new(vec![1, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3])
.unwrap();
grad.apply_sgd(&mut param, 1.0).unwrap();
let d = param.data().unwrap();
assert_eq!(&d[0..3], &[0.0, 0.0, 0.0]);
assert_eq!(&d[3..6], &[-1.0, -2.0, -3.0]);
assert_eq!(&d[6..9], &[0.0, 0.0, 0.0]);
assert_eq!(&d[9..12], &[-4.0, -5.0, -6.0]);
}
// Index 5 on a parameter with 2 rows gives InvalidArgument.
#[test]
fn sparse_grad_apply_sgd_rejects_oob_index() {
let mut param =
Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.0; 6]), vec![2, 3], false)
.unwrap();
let grad = SparseGrad::<f32>::new(vec![5], vec![1.0, 2.0, 3.0], vec![3]).unwrap();
let err = grad.apply_sgd(&mut param, 1.0).unwrap_err();
assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
}
// Slab shape [4] against parameter trailing dims [3] gives ShapeMismatch.
#[test]
fn sparse_grad_apply_sgd_rejects_shape_mismatch() {
let mut param =
Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.0; 6]), vec![2, 3], false)
.unwrap();
let grad = SparseGrad::<f32>::new(vec![0], vec![1.0; 4], vec![4]).unwrap();
let err = grad.apply_sgd(&mut param, 1.0).unwrap_err();
assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
}
}