use crate::dtype::DType;
use crate::error::TorshError;
use crate::shape::Shape;
use std::fmt;
use std::sync::Arc;
/// Sparse tensor storage layouts supported by the crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseFormat {
    /// Coordinate list: parallel row/column index arrays.
    COO,
    /// Compressed sparse row.
    CSR,
    /// Compressed sparse column.
    CSC,
    /// Block compressed sparse row.
    BSR,
    /// Diagonal storage.
    DIA,
    /// ELLPACK storage.
    ELL,
}

impl SparseFormat {
    /// Canonical short name of the format.
    pub fn name(self) -> &'static str {
        match self {
            SparseFormat::COO => "COO",
            SparseFormat::CSR => "CSR",
            SparseFormat::CSC => "CSC",
            SparseFormat::BSR => "BSR",
            SparseFormat::DIA => "DIA",
            SparseFormat::ELL => "ELL",
        }
    }

    /// True for layouts that allow efficient per-row slicing.
    pub fn supports_row_access(self) -> bool {
        matches!(self, SparseFormat::CSR | SparseFormat::BSR)
    }

    /// True for layouts that allow efficient per-column slicing.
    pub fn supports_column_access(self) -> bool {
        matches!(self, SparseFormat::CSC)
    }

    /// True for layouts flagged as GPU-friendly (regular access patterns).
    pub fn is_gpu_friendly(self) -> bool {
        matches!(self, SparseFormat::CSR | SparseFormat::ELL | SparseFormat::BSR)
    }
}

impl fmt::Display for SparseFormat {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
/// Descriptive metadata attached to every sparse storage instance:
/// layout, population counts and compression statistics.
#[derive(Debug, Clone)]
pub struct SparseMetadata {
    // Storage layout of the tensor.
    format: SparseFormat,
    // Number of explicitly stored (non-zero) elements.
    nnz: usize,
    // Fraction of zero elements, in [0, 1].
    sparsity: f32,
    // True once indices are known to be in canonical sorted order.
    indices_sorted: bool,
    // True once duplicate coordinates have been coalesced (summed).
    duplicates_summed: bool,
    // Block dimensions (rows, cols) for BSR; None for other formats.
    block_size: Option<(usize, usize)>,
    // Number of stored diagonals for DIA; None for other formats.
    num_diagonals: Option<usize>,
    // Per-row width for ELL; None for other formats.
    ell_width: Option<usize>,
    // Dense-vs-sparse footprint comparison.
    compression_stats: CompressionStats,
}
/// Byte-level footprint comparison between dense and sparse layouts.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    // Bytes the tensor would occupy in a dense layout.
    dense_size_bytes: usize,
    // Bytes occupied by the sparse representation (values + indices).
    sparse_size_bytes: usize,
    // dense / sparse; > 1.0 means the sparse layout is smaller.
    compression_ratio: f32,
    // Approximate bytes spent on index bookkeeping (currently unread).
    #[allow(dead_code)]
    index_overhead_bytes: usize,
}
impl SparseMetadata {
    /// Builds metadata for a sparse tensor.
    ///
    /// `total_elements` is the dense element count; `dense_size_bytes` and
    /// `sparse_size_bytes` are the dense/sparse footprints used for the
    /// compression statistics.
    ///
    /// Both divisions are guarded so degenerate inputs (zero elements or a
    /// zero-byte sparse buffer) yield 0.0 instead of NaN/inf, which would
    /// otherwise poison threshold comparisons such as `is_beneficial`.
    pub fn new(
        format: SparseFormat,
        nnz: usize,
        total_elements: usize,
        dense_size_bytes: usize,
        sparse_size_bytes: usize,
    ) -> Self {
        let sparsity = if total_elements == 0 {
            0.0
        } else {
            1.0 - (nnz as f32 / total_elements as f32)
        };
        let compression_ratio = if sparse_size_bytes == 0 {
            0.0
        } else {
            dense_size_bytes as f32 / sparse_size_bytes as f32
        };
        Self {
            format,
            nnz,
            sparsity,
            indices_sorted: false,
            duplicates_summed: false,
            block_size: None,
            num_diagonals: None,
            ell_width: None,
            compression_stats: CompressionStats {
                dense_size_bytes,
                sparse_size_bytes,
                compression_ratio,
                // Rough estimate assuming 4-byte indices; saturating so a
                // small sparse buffer cannot underflow (debug panic).
                // NOTE(review): storage types index with usize (8 bytes on
                // 64-bit) — confirm whether 4 is intentional here.
                index_overhead_bytes: sparse_size_bytes.saturating_sub(nnz * 4),
            },
        }
    }

    /// Storage format of the tensor.
    pub fn format(&self) -> SparseFormat {
        self.format
    }

    /// Number of explicitly stored (non-zero) elements.
    pub fn nnz(&self) -> usize {
        self.nnz
    }

    /// Fraction of elements that are zero, in [0, 1].
    pub fn sparsity(&self) -> f32 {
        self.sparsity
    }

    /// Fraction of elements that are non-zero.
    pub fn density(&self) -> f32 {
        1.0 - self.sparsity
    }

    /// Whether indices are known to be in canonical sorted order.
    pub fn indices_sorted(&self) -> bool {
        self.indices_sorted
    }

    pub fn set_indices_sorted(&mut self, sorted: bool) {
        self.indices_sorted = sorted;
    }

    /// Whether duplicate coordinates have been coalesced (summed).
    pub fn duplicates_summed(&self) -> bool {
        self.duplicates_summed
    }

    pub fn set_duplicates_summed(&mut self, summed: bool) {
        self.duplicates_summed = summed;
    }

    /// Block dimensions for BSR-style formats, when known.
    pub fn block_size(&self) -> Option<(usize, usize)> {
        self.block_size
    }

    pub fn set_block_size(&mut self, size: (usize, usize)) {
        self.block_size = Some(size);
    }

    /// Compression statistics comparing dense vs sparse footprints.
    pub fn compression_stats(&self) -> &CompressionStats {
        &self.compression_stats
    }

    /// True when the sparse representation saves meaningfully over dense
    /// storage (compression ratio above 1.2x).
    pub fn is_beneficial(&self) -> bool {
        self.compression_stats.compression_ratio > 1.2
    }

    /// Bytes saved by the sparse representation; negative when the sparse
    /// layout is actually larger than dense.
    pub fn memory_savings_bytes(&self) -> i64 {
        self.compression_stats.dense_size_bytes as i64
            - self.compression_stats.sparse_size_bytes as i64
    }

    /// Format name annotated with format-specific parameters when known,
    /// e.g. `BSR(4x4)`, `DIA(3)`, `ELL(16)`.
    pub fn format_info(&self) -> String {
        match self.format {
            SparseFormat::BSR => match self.block_size {
                Some((bm, bn)) => format!("BSR({}x{})", bm, bn),
                None => "BSR".to_string(),
            },
            SparseFormat::DIA => match self.num_diagonals {
                Some(ndiag) => format!("DIA({})", ndiag),
                None => "DIA".to_string(),
            },
            SparseFormat::ELL => match self.ell_width {
                Some(width) => format!("ELL({})", width),
                None => "ELL".to_string(),
            },
            _ => self.format.name().to_string(),
        }
    }
}
impl fmt::Display for SparseMetadata {
    /// Compact one-line summary, e.g.
    /// `SparseMetadata(CSR, nnz=10, sparsity=90.00%, compression=5.0x)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let info = self.format_info();
        let percent = self.sparsity * 100.0;
        let ratio = self.compression_stats.compression_ratio;
        write!(
            f,
            "SparseMetadata({}, nnz={}, sparsity={:.2}%, compression={:.1}x)",
            info, self.nnz, percent, ratio
        )
    }
}
/// Coordinate-format (COO) index arrays.
///
/// `rows`/`cols` hold the first two dimensions; indices for any further
/// dimensions live in `extra_dims` (one Vec per extra dimension). All
/// arrays are parallel and have length `nnz`.
#[derive(Debug, Clone)]
pub struct CooIndices {
    pub rows: Vec<usize>,
    pub cols: Vec<usize>,
    pub extra_dims: Vec<Vec<usize>>,
}

impl CooIndices {
    /// Creates 2-D COO indices.
    ///
    /// # Panics
    /// Panics if `rows` and `cols` differ in length.
    pub fn new_2d(rows: Vec<usize>, cols: Vec<usize>) -> Self {
        assert_eq!(
            rows.len(),
            cols.len(),
            "Row and column indices must have same length"
        );
        Self {
            rows,
            cols,
            extra_dims: Vec::new(),
        }
    }

    /// Creates N-D COO indices from one index vector per dimension.
    ///
    /// # Panics
    /// Panics if any dimension's index array length differs from the first
    /// one, or if fewer than 2 dimensions are supplied.
    pub fn new_nd(indices: Vec<Vec<usize>>) -> Self {
        let nnz = indices.first().map_or(0, |dim| dim.len());
        for (i, dim_indices) in indices.iter().enumerate() {
            assert_eq!(
                dim_indices.len(),
                nnz,
                "Dimension {} indices length mismatch: expected {}, got {}",
                i,
                nnz,
                dim_indices.len()
            );
        }
        assert!(
            indices.len() >= 2,
            "N-D tensor must have at least 2 dimensions"
        );
        // Take ownership of the per-dimension arrays instead of cloning them.
        let mut dims = indices.into_iter();
        let rows = dims.next().expect("checked: at least 2 dimensions");
        let cols = dims.next().expect("checked: at least 2 dimensions");
        Self {
            rows,
            cols,
            extra_dims: dims.collect(),
        }
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.rows.len()
    }

    /// Number of tensor dimensions these indices address.
    pub fn ndim(&self) -> usize {
        2 + self.extra_dims.len()
    }

    /// Lexicographic comparison of entries `a` and `b` over
    /// (row, col, extra dims...). Shared by `is_sorted` and `sort` so the
    /// two always agree on the ordering.
    fn cmp_entries(&self, a: usize, b: usize) -> std::cmp::Ordering {
        self.rows[a]
            .cmp(&self.rows[b])
            .then_with(|| self.cols[a].cmp(&self.cols[b]))
            .then_with(|| {
                self.extra_dims
                    .iter()
                    .map(|dim| dim[a].cmp(&dim[b]))
                    .find(|ord| *ord != std::cmp::Ordering::Equal)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
    }

    /// True when entries are in non-decreasing lexicographic order.
    ///
    /// Fixed to also compare `extra_dims`, matching the order produced by
    /// [`CooIndices::sort`]; previously ties in (row, col) were declared
    /// sorted even when `sort` would still reorder them.
    pub fn is_sorted(&self) -> bool {
        (1..self.nnz()).all(|i| self.cmp_entries(i - 1, i) != std::cmp::Ordering::Greater)
    }

    /// Sorts entries lexicographically by (row, col, extra dims...) in place.
    ///
    /// Returns the permutation mapping new positions to old ones, so callers
    /// can reorder parallel value arrays: `new_values[i] = old_values[perm[i]]`.
    /// The sort is stable, so equal entries keep their relative order.
    pub fn sort(&mut self) -> Vec<usize> {
        let mut perm: Vec<usize> = (0..self.nnz()).collect();
        perm.sort_by(|&a, &b| self.cmp_entries(a, b));
        // Apply the permutation to every index array.
        let new_rows: Vec<usize> = perm.iter().map(|&p| self.rows[p]).collect();
        let new_cols: Vec<usize> = perm.iter().map(|&p| self.cols[p]).collect();
        let new_extra: Vec<Vec<usize>> = self
            .extra_dims
            .iter()
            .map(|dim| perm.iter().map(|&p| dim[p]).collect())
            .collect();
        self.rows = new_rows;
        self.cols = new_cols;
        self.extra_dims = new_extra;
        perm
    }
}
/// Compressed sparse row (CSR) index arrays.
///
/// `row_ptrs` has `nrows + 1` entries; row `r`'s column indices live in
/// `col_indices[row_ptrs[r]..row_ptrs[r + 1]]`.
#[derive(Debug, Clone)]
pub struct CsrIndices {
    pub row_ptrs: Vec<usize>,
    pub col_indices: Vec<usize>,
}

impl CsrIndices {
    /// Creates CSR indices from raw arrays.
    ///
    /// # Panics
    /// Panics if the last row pointer does not equal `col_indices.len()`,
    /// or if the row pointers are not non-decreasing.
    pub fn new(row_ptrs: Vec<usize>, col_indices: Vec<usize>) -> Self {
        let nnz = col_indices.len();
        assert_eq!(
            *row_ptrs.last().unwrap_or(&0),
            nnz,
            "Last row pointer must equal nnz"
        );
        for i in 1..row_ptrs.len() {
            assert!(
                row_ptrs[i] >= row_ptrs[i - 1],
                "Row pointers must be non-decreasing"
            );
        }
        Self {
            row_ptrs,
            col_indices,
        }
    }

    /// Converts COO indices to CSR via a counting sort.
    ///
    /// Unlike the previous implementation (which copied `coo.cols` verbatim
    /// and therefore silently produced wrong indices for COO input that was
    /// not already row-sorted), column indices are scattered into their row
    /// buckets, stable within each row. Note this only reorders indices:
    /// callers holding a parallel value array must apply the matching
    /// reordering themselves (e.g. sort the COO indices first).
    ///
    /// # Panics
    /// Panics if any row index is `>= nrows`. (Previously such entries were
    /// silently dropped from the counts, which then tripped the
    /// "Last row pointer must equal nnz" assertion with a confusing message.)
    pub fn from_coo(coo: &CooIndices, nrows: usize) -> Self {
        let nnz = coo.nnz();
        // Count entries per row (offset by one for the prefix-sum below).
        let mut row_ptrs = vec![0usize; nrows + 1];
        for &row in &coo.rows {
            assert!(
                row < nrows,
                "Row index {} out of bounds for {} rows",
                row,
                nrows
            );
            row_ptrs[row + 1] += 1;
        }
        // Prefix-sum the counts into row start offsets.
        for i in 1..=nrows {
            row_ptrs[i] += row_ptrs[i - 1];
        }
        // Scatter the column indices into their row buckets.
        let mut col_indices = vec![0usize; nnz];
        let mut next = row_ptrs.clone();
        for (&row, &col) in coo.rows.iter().zip(&coo.cols) {
            col_indices[next[row]] = col;
            next[row] += 1;
        }
        Self::new(row_ptrs, col_indices)
    }

    /// Number of rows described by these indices.
    pub fn nrows(&self) -> usize {
        self.row_ptrs.len().saturating_sub(1)
    }

    /// Number of stored entries.
    pub fn nnz(&self) -> usize {
        self.col_indices.len()
    }

    /// Half-open range of `col_indices` positions belonging to `row`,
    /// or `None` when `row` is out of bounds.
    pub fn row_range(&self, row: usize) -> Option<std::ops::Range<usize>> {
        if row >= self.nrows() {
            return None;
        }
        Some(self.row_ptrs[row]..self.row_ptrs[row + 1])
    }
}
/// Common interface over the concrete sparse storage layouts in this module.
pub trait SparseStorage: Send + Sync + std::fmt::Debug {
    /// Metadata describing format, population and compression statistics.
    fn metadata(&self) -> &SparseMetadata;
    /// Number of explicitly stored (non-zero) elements.
    fn nnz(&self) -> usize {
        self.metadata().nnz()
    }
    /// Storage format of this tensor.
    fn format(&self) -> SparseFormat {
        self.metadata().format()
    }
    /// True when the sparse representation beats dense storage
    /// (per the metadata's compression threshold).
    fn is_beneficial(&self) -> bool {
        self.metadata().is_beneficial()
    }
    /// Converts to COO storage (copies the underlying buffers).
    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError>;
    /// Converts to CSR storage; errors for non-2D tensors.
    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError>;
    /// Total bytes used by value and index buffers.
    fn memory_usage(&self) -> usize;
}
/// COO-format sparse tensor storage: parallel index and raw value buffers.
#[derive(Debug)]
pub struct CooStorage {
    // Format/nnz/compression metadata for this tensor.
    metadata: SparseMetadata,
    // Coordinate arrays, one entry per stored element.
    indices: CooIndices,
    // Raw value bytes, `nnz * dtype.size()` long.
    values: Vec<u8>,
    // Element type of `values`.
    dtype: DType,
    // Logical dense shape of the tensor.
    shape: Shape,
}
impl CooStorage {
    /// Creates COO storage after validating that the value buffer length
    /// matches `nnz * dtype.size()`.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` when the value buffer size
    /// does not match the index count.
    pub fn new(
        indices: CooIndices,
        values: Vec<u8>,
        dtype: DType,
        shape: Shape,
    ) -> Result<Self, TorshError> {
        let nnz = indices.nnz();
        let expected_value_size = nnz * dtype.size();
        if values.len() != expected_value_size {
            return Err(TorshError::InvalidArgument(format!(
                "Value buffer size mismatch: expected {}, got {}",
                expected_value_size,
                values.len()
            )));
        }
        let total_elements: usize = shape.dims().iter().product();
        let dense_size = total_elements * dtype.size();
        // Index footprint: rows + cols + any extra dimensions, each stored
        // as usize. Previously extra dimensions were omitted and the index
        // width was hard-coded to 8 bytes; this now matches memory_usage().
        let index_size = std::mem::size_of::<usize>();
        let sparse_size = values.len()
            + (indices.rows.len() + indices.cols.len()) * index_size
            + indices
                .extra_dims
                .iter()
                .map(|dim| dim.len() * index_size)
                .sum::<usize>();
        let metadata = SparseMetadata::new(
            SparseFormat::COO,
            nnz,
            total_elements,
            dense_size,
            sparse_size,
        );
        Ok(Self {
            metadata,
            indices,
            values,
            dtype,
            shape,
        })
    }

    /// Coordinate arrays of the stored elements.
    pub fn indices(&self) -> &CooIndices {
        &self.indices
    }

    /// Mutable access to the coordinate arrays (e.g. for sorting).
    pub fn indices_mut(&mut self) -> &mut CooIndices {
        &mut self.indices
    }

    /// Raw value buffer (`nnz * dtype.size()` bytes).
    pub fn values_bytes(&self) -> &[u8] {
        &self.values
    }

    /// Element type of the value buffer.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Logical dense shape of the tensor.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }
}
impl SparseStorage for CooStorage {
    fn metadata(&self) -> &SparseMetadata {
        &self.metadata
    }

    /// Already COO; returns a deep copy.
    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        Ok(Arc::new(CooStorage {
            metadata: self.metadata.clone(),
            indices: self.indices.clone(),
            values: self.values.clone(),
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    /// Converts to CSR.
    ///
    /// A copy of the indices is first sorted by (row, col) and the value
    /// buffer is permuted to match. (Previously the values were copied
    /// unpermuted, so any COO input that was not already row-sorted produced
    /// a CSR tensor whose values no longer lined up with its columns.)
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` for non-2D tensors.
    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        if self.shape.ndim() != 2 {
            return Err(TorshError::InvalidArgument(
                "CSR format only supports 2D tensors".to_string(),
            ));
        }
        let nrows = self.shape.dims()[0];
        // Sort a copy of the indices; `perm` maps new positions to old ones.
        let mut sorted_indices = self.indices.clone();
        let perm = sorted_indices.sort();
        // Permute the value bytes with the same permutation, element-wise.
        let elem = self.dtype.size();
        let mut values = vec![0u8; self.values.len()];
        for (new_pos, &old_pos) in perm.iter().enumerate() {
            values[new_pos * elem..(new_pos + 1) * elem]
                .copy_from_slice(&self.values[old_pos * elem..(old_pos + 1) * elem]);
        }
        let csr_indices = CsrIndices::from_coo(&sorted_indices, nrows);
        Ok(Arc::new(CsrStorage {
            metadata: {
                let mut meta = self.metadata.clone();
                meta.format = SparseFormat::CSR;
                // Indices were sorted above, so the flag is now accurate.
                meta.set_indices_sorted(true);
                meta
            },
            indices: csr_indices,
            values,
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    /// Bytes used by the value buffer plus all index arrays.
    fn memory_usage(&self) -> usize {
        let index_size = std::mem::size_of::<usize>();
        self.values.len()
            + self.indices.rows.len() * index_size
            + self.indices.cols.len() * index_size
            + self
                .indices
                .extra_dims
                .iter()
                .map(|dim| dim.len() * index_size)
                .sum::<usize>()
    }
}
/// CSR-format sparse matrix storage (2-D tensors only).
#[derive(Debug)]
pub struct CsrStorage {
    // Format/nnz/compression metadata for this matrix.
    metadata: SparseMetadata,
    // Row pointers and column indices.
    indices: CsrIndices,
    // Raw value bytes, `nnz * dtype.size()` long.
    values: Vec<u8>,
    // Element type of `values`.
    dtype: DType,
    // Logical dense shape (always 2-D).
    shape: Shape,
}
impl CsrStorage {
    /// Creates CSR storage after validating the shape, the row-pointer
    /// count and the value buffer length.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` when the shape is not 2D, when
    /// the row pointers do not describe exactly `shape.dims()[0]` rows, or
    /// when the value buffer size does not match `nnz * dtype.size()`.
    pub fn new(
        indices: CsrIndices,
        values: Vec<u8>,
        dtype: DType,
        shape: Shape,
    ) -> Result<Self, TorshError> {
        if shape.ndim() != 2 {
            return Err(TorshError::InvalidArgument(
                "CSR format only supports 2D tensors".to_string(),
            ));
        }
        // The row pointers must cover exactly the tensor's rows; a mismatch
        // previously went undetected and corrupted later row lookups.
        let nrows = shape.dims()[0];
        if indices.nrows() != nrows {
            return Err(TorshError::InvalidArgument(format!(
                "Row pointer count mismatch: expected {} rows, got {}",
                nrows,
                indices.nrows()
            )));
        }
        let nnz = indices.nnz();
        let expected_value_size = nnz * dtype.size();
        if values.len() != expected_value_size {
            return Err(TorshError::InvalidArgument(format!(
                "Value buffer size mismatch: expected {}, got {}",
                expected_value_size,
                values.len()
            )));
        }
        let total_elements: usize = shape.dims().iter().product();
        let dense_size = total_elements * dtype.size();
        // Index footprint: row pointers + column indices, stored as usize
        // (was hard-coded to 8 bytes; now matches memory_usage()).
        let index_size = std::mem::size_of::<usize>();
        let sparse_size = values.len()
            + (indices.row_ptrs.len() + indices.col_indices.len()) * index_size;
        let metadata = SparseMetadata::new(
            SparseFormat::CSR,
            nnz,
            total_elements,
            dense_size,
            sparse_size,
        );
        Ok(Self {
            metadata,
            indices,
            values,
            dtype,
            shape,
        })
    }

    /// CSR index arrays (row pointers and column indices).
    pub fn indices(&self) -> &CsrIndices {
        &self.indices
    }
}
impl SparseStorage for CsrStorage {
    fn metadata(&self) -> &SparseMetadata {
        &self.metadata
    }

    /// Expands the compressed row pointers back into explicit (row, col)
    /// coordinate pairs and returns a COO copy of this storage.
    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        let nnz = self.nnz();
        let mut rows = Vec::with_capacity(nnz);
        let mut cols = Vec::with_capacity(nnz);
        // Each adjacent row-pointer pair delimits one row's entries.
        for (row, bounds) in self.indices.row_ptrs.windows(2).enumerate() {
            for idx in bounds[0]..bounds[1] {
                rows.push(row);
                cols.push(self.indices.col_indices[idx]);
            }
        }
        let mut meta = self.metadata.clone();
        meta.format = SparseFormat::COO;
        Ok(Arc::new(CooStorage {
            metadata: meta,
            indices: CooIndices::new_2d(rows, cols),
            values: self.values.clone(),
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    /// Already CSR; returns a deep copy.
    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        Ok(Arc::new(CsrStorage {
            metadata: self.metadata.clone(),
            indices: self.indices.clone(),
            values: self.values.clone(),
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    /// Bytes used by the value buffer plus both index arrays.
    fn memory_usage(&self) -> usize {
        let index_size = std::mem::size_of::<usize>();
        self.values.len()
            + index_size * (self.indices.row_ptrs.len() + self.indices.col_indices.len())
    }
}
pub mod utils {
use super::*;
/// Scans a dense f32 buffer and reports sparsity statistics plus simple
/// structural pattern hints (diagonal / coarse block alignment).
///
/// Empty input yields `sparsity = 0.0` rather than NaN: previously the
/// `0 / 0` division produced NaN, which silently falls through every
/// threshold comparison in `recommend_format`.
pub fn analyze_sparsity(data: &[f32], shape: &[usize]) -> SparseAnalysis {
    let total_elements = data.len();
    let mut nnz = 0;
    let mut pattern_info = PatternInfo::default();
    for (idx, &value) in data.iter().enumerate() {
        if value != 0.0 {
            nnz += 1;
            pattern_info.update(idx, shape);
        }
    }
    let sparsity = if total_elements == 0 {
        0.0
    } else {
        1.0 - (nnz as f32 / total_elements as f32)
    };
    SparseAnalysis {
        sparsity,
        nnz,
        total_elements,
        pattern_info,
    }
}
/// Suggests a storage format for the analyzed data; `format: None` means
/// "stay dense". Each recommendation carries a justification and a
/// heuristic confidence score.
pub fn recommend_format(analysis: &SparseAnalysis, shape: &[usize]) -> FormatRecommendation {
    let pick = |format, reason: &str, confidence| FormatRecommendation {
        format,
        reason: reason.to_string(),
        confidence,
    };
    // Below ~50% zeros, sparse bookkeeping costs more than it saves.
    if analysis.sparsity < 0.5 {
        return pick(
            None,
            "Low sparsity, dense representation more efficient",
            0.9,
        );
    }
    // Only COO generalizes beyond matrices.
    if shape.len() != 2 {
        return pick(
            Some(SparseFormat::COO),
            "Multi-dimensional tensor, COO supports arbitrary dimensions",
            0.8,
        );
    }
    let (nrows, ncols) = (shape[0], shape[1]);
    if analysis.pattern_info.has_structured_rows {
        pick(
            Some(SparseFormat::CSR),
            "Good row locality, CSR optimal for row-wise operations",
            0.8,
        )
    } else if analysis.pattern_info.has_structured_cols {
        pick(
            Some(SparseFormat::CSC),
            "Good column locality, CSC optimal for column-wise operations",
            0.8,
        )
    } else if analysis.nnz < (nrows + ncols) * 10 {
        pick(
            Some(SparseFormat::COO),
            "Very sparse matrix, COO has lowest overhead",
            0.7,
        )
    } else {
        pick(
            Some(SparseFormat::CSR),
            "General 2D sparse matrix, CSR is default choice",
            0.6,
        )
    }
}
pub fn densify_to_sparse<T>(
data: &[T],
shape: &Shape,
dtype: DType,
threshold: Option<f64>,
) -> Result<Arc<dyn SparseStorage>, TorshError>
where
T: Clone + PartialEq + Into<f64> + Default,
{
let threshold = threshold.unwrap_or(1e-12);
let zero = T::default();
let mut indices = Vec::new();
let mut values = Vec::new();
for (linear_idx, value) in data.iter().enumerate() {
let abs_val = value.clone().into().abs();
if abs_val > threshold && *value != zero {
let multi_idx = linear_to_multidim(linear_idx, shape.dims());
indices.push(multi_idx);
values.push(value.clone());
}
}
if indices.is_empty() {
return Err(TorshError::InvalidArgument(
"No non-zero elements found".to_string(),
));
}
let value_bytes: Vec<u8> = values
.iter()
.flat_map(|v| {
let val_f64 = v.clone().into();
val_f64.to_ne_bytes()
})
.collect();
let dims = shape.dims();
match dims.len() {
1 => {
let rows: Vec<usize> = indices.iter().map(|idx| idx[0]).collect();
let cols = vec![0; rows.len()]; let coo_indices = CooIndices::new_2d(rows, cols);
CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
.map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
}
2 => {
let rows: Vec<usize> = indices.iter().map(|idx| idx[0]).collect();
let cols: Vec<usize> = indices.iter().map(|idx| idx[1]).collect();
let coo_indices = CooIndices::new_2d(rows, cols);
CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
.map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
}
_ => {
let transposed_indices: Vec<Vec<usize>> = (0..dims.len())
.map(|dim| indices.iter().map(|idx| idx[dim]).collect())
.collect();
let coo_indices = CooIndices::new_nd(transposed_indices);
CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
.map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
}
}
}
/// Converts a flat (row-major) index into per-dimension coordinates.
fn linear_to_multidim(linear_idx: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0usize; shape.len()];
    let mut remaining = linear_idx;
    // Peel off the fastest-varying (last) dimension first.
    for (slot, &dim_size) in coords.iter_mut().zip(shape).rev() {
        *slot = remaining % dim_size;
        remaining /= dim_size;
    }
    coords
}
/// Result of scanning a dense buffer for sparsity (see `analyze_sparsity`).
#[derive(Debug, Clone)]
pub struct SparseAnalysis {
    // Fraction of zero elements, in [0, 1].
    pub sparsity: f32,
    // Count of non-zero elements.
    pub nnz: usize,
    // Total element count of the scanned dense buffer.
    pub total_elements: usize,
    // Structural hints (diagonal/block patterns) gathered during the scan.
    pub pattern_info: PatternInfo,
}
/// Structural hints accumulated while scanning non-zero positions.
#[derive(Debug, Clone, Default)]
pub struct PatternInfo {
    pub has_structured_rows: bool,
    pub has_structured_cols: bool,
    pub has_diagonal_structure: bool,
    pub has_block_structure: bool,
    pub block_size: Option<(usize, usize)>,
}

impl PatternInfo {
    /// Folds one non-zero element's flat index into the pattern flags.
    /// Only 2-D shapes contribute; other ranks leave the flags untouched.
    fn update(&mut self, idx: usize, shape: &[usize]) {
        let &[_nrows, ncols] = shape else {
            return;
        };
        let (row, col) = (idx / ncols, idx % ncols);
        // Any non-zero on the main diagonal flags diagonal structure.
        if row == col {
            self.has_diagonal_structure = true;
        }
        // Non-zeros aligned to a 4x4 grid flag coarse block structure.
        if row % 4 == 0 && col % 4 == 0 {
            self.has_block_structure = true;
            self.block_size = Some((4, 4));
        }
    }
}
/// A suggested storage format with a human-readable justification.
#[derive(Debug, Clone)]
pub struct FormatRecommendation {
    // Recommended sparse format; `None` means "stay dense".
    pub format: Option<SparseFormat>,
    // Human-readable explanation of the choice.
    pub reason: String,
    // Heuristic confidence in (0, 1].
    pub confidence: f32,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    // Derived metadata quantities (sparsity, benefit) from raw counts.
    #[test]
    fn test_sparse_metadata_creation() {
        let metadata = SparseMetadata::new(
            SparseFormat::COO,
            1000,  // nnz
            10000, // total elements
            40000, // dense bytes
            8000,  // sparse bytes
        );
        assert_eq!(metadata.format(), SparseFormat::COO);
        assert_eq!(metadata.nnz(), 1000);
        assert_eq!(metadata.sparsity(), 0.9);
        // 40000 / 8000 = 5x compression, well above the 1.2x threshold.
        assert!(metadata.is_beneficial());
    }

    // 2-D constructor stores the arrays verbatim.
    #[test]
    fn test_coo_indices_creation() {
        let rows = vec![0, 1, 2, 1];
        let cols = vec![1, 0, 2, 2];
        let indices = CooIndices::new_2d(rows.clone(), cols.clone());
        assert_eq!(indices.nnz(), 4);
        assert_eq!(indices.ndim(), 2);
        assert_eq!(indices.rows, rows);
        assert_eq!(indices.cols, cols);
    }

    // Sorting orders entries lexicographically by (row, col).
    #[test]
    fn test_coo_indices_sorting() {
        let mut indices = CooIndices::new_2d(
            vec![2, 1, 0, 1],
            vec![0, 2, 1, 0],
        );
        assert!(!indices.is_sorted());
        let _perm = indices.sort();
        assert_eq!(indices.rows, vec![0, 1, 1, 2]);
        assert_eq!(indices.cols, vec![1, 0, 2, 0]);
        assert!(indices.is_sorted());
    }

    // COO -> CSR conversion: rows [0,0,1,2,2] yield counts 2/1/2.
    #[test]
    fn test_csr_from_coo() {
        let coo_indices = CooIndices::new_2d(
            vec![0, 0, 1, 2, 2],
            vec![1, 2, 0, 1, 2],
        );
        let csr_indices = CsrIndices::from_coo(&coo_indices, 3);
        assert_eq!(csr_indices.nrows(), 3);
        assert_eq!(csr_indices.nnz(), 5);
        assert_eq!(csr_indices.row_ptrs, vec![0, 2, 3, 5]);
        assert_eq!(csr_indices.col_indices, vec![1, 2, 0, 1, 2]);
    }

    // Storage construction validates the value buffer against nnz * dtype size.
    #[test]
    fn test_coo_storage_creation() {
        let indices = CooIndices::new_2d(vec![0, 1], vec![1, 0]);
        let values = [1.0_f32.to_ne_bytes(), 2.0_f32.to_ne_bytes()].concat();
        let shape = Shape::new(vec![2, 2]);
        let storage = CooStorage::new(indices, values, DType::F32, shape)
            .expect("CooStorage::new should succeed");
        assert_eq!(storage.nnz(), 2);
        assert_eq!(storage.format(), SparseFormat::COO);
        assert_eq!(storage.dtype(), DType::F32);
    }

    // Round-trip COO -> CSR -> COO keeps the format tags consistent.
    #[test]
    fn test_format_conversion() {
        let indices = CooIndices::new_2d(vec![0, 1], vec![1, 0]);
        let values = [1.0_f32.to_ne_bytes(), 2.0_f32.to_ne_bytes()].concat();
        let shape = Shape::new(vec![2, 2]);
        let coo_storage = CooStorage::new(indices, values, DType::F32, shape)
            .expect("CooStorage::new should succeed");
        let csr_storage = coo_storage.to_csr().expect("to_csr should succeed");
        assert_eq!(csr_storage.format(), SparseFormat::CSR);
        let coo_again = csr_storage.to_coo().expect("to_coo should succeed");
        assert_eq!(coo_again.format(), SparseFormat::COO);
    }

    // 3 non-zeros out of 9 elements -> sparsity 2/3.
    #[test]
    fn test_sparsity_analysis() {
        let data = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];
        let shape = vec![3, 3];
        let analysis = utils::analyze_sparsity(&data, &shape);
        assert_eq!(analysis.nnz, 3);
        assert_eq!(analysis.total_elements, 9);
        assert!((analysis.sparsity - 2.0 / 3.0).abs() < 1e-6);
    }

    // A high-sparsity 2-D matrix should get some recommendation.
    #[test]
    fn test_format_recommendation() {
        let analysis = utils::SparseAnalysis {
            sparsity: 0.9,
            nnz: 100,
            total_elements: 1000,
            pattern_info: utils::PatternInfo::default(),
        };
        let shape = vec![100, 10];
        let recommendation = utils::recommend_format(&analysis, &shape);
        assert!(recommendation.format.is_some());
        assert!(recommendation.confidence > 0.0);
    }
}