use crate::sparse::core::SparseTensor;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Convert a 2D COO `SparseTensor` into its CSR representation.
///
/// Returns `(values, col_indices, row_ptrs)`:
/// * `values` — the non-zero values ordered row by row (shape `[nnz]`),
/// * `col_indices` — the column index of each value (shape `[nnz]`),
/// * `row_ptrs` — prefix offsets per row (shape `[rows + 1]`), so row `r`
///   owns the half-open slice `row_ptrs[r]..row_ptrs[r + 1]`.
///
/// # Errors
/// Returns an error when `sparse` is not 2D or when a stored row index is
/// out of bounds for `sparse.shape[0]`.
pub fn sparse_to_csr(sparse: &SparseTensor) -> TorshResult<(Tensor, Tensor, Tensor)> {
    // CSR is only defined for matrices.
    if sparse.ndim != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "CSR format only supports 2D tensors",
            "sparse_to_csr",
        ));
    }

    let nnz = sparse.nnz;
    let num_rows = sparse.shape[0];

    // COO indices are stored as a flat [2, nnz] buffer: the first `nnz`
    // entries are row indices, the next `nnz` entries are column indices.
    let indices_data = sparse.indices.to_vec()?;
    let rows: Vec<usize> = indices_data[..nnz].iter().map(|&r| r as usize).collect();

    // Reject out-of-range rows up front; unchecked they would corrupt (or
    // panic while indexing) `row_ptrs` below.
    if rows.iter().any(|&r| r >= num_rows) {
        return Err(TorshError::invalid_argument_with_context(
            "row index out of bounds for CSR conversion",
            "sparse_to_csr",
        ));
    }

    // `row_ptrs` is only meaningful when entries are grouped by ascending
    // row. COO input is not guaranteed to be sorted, so compute a stable
    // row-sorting permutation (column order within a row is preserved).
    // Already-sorted input takes the identity fast path.
    let already_sorted = rows.windows(2).all(|w| w[0] <= w[1]);
    let perm: Vec<usize> = if already_sorted {
        (0..nnz).collect()
    } else {
        let mut p: Vec<usize> = (0..nnz).collect();
        p.sort_by_key(|&i| rows[i]);
        p
    };

    // Columns (and, when reordering was needed, values) follow the permutation.
    let col_indices: Vec<f32> = perm.iter().map(|&i| indices_data[nnz + i]).collect();
    let values_tensor = if already_sorted {
        // Fast path: reuse the original value tensor unchanged.
        sparse.values.clone()
    } else {
        let values_data = sparse.values.to_vec()?;
        let permuted: Vec<f32> = perm.iter().map(|&i| values_data[i]).collect();
        Tensor::from_data(permuted, vec![nnz], sparse.values.device())?
    };

    // Build row_ptrs: row_ptrs[r] is the offset of the first entry of row r.
    let mut row_ptrs = vec![0.0f32; num_rows + 1];
    let mut current_row = 0usize;
    for (offset, &entry) in perm.iter().enumerate() {
        let row = rows[entry];
        while current_row <= row {
            row_ptrs[current_row] = offset as f32;
            current_row += 1;
        }
    }
    // Trailing empty rows and the final sentinel all point at nnz.
    while current_row <= num_rows {
        row_ptrs[current_row] = nnz as f32;
        current_row += 1;
    }

    let col_indices_tensor =
        Tensor::from_data(col_indices, vec![nnz], sparse.values.device())?;
    let row_ptrs_tensor =
        Tensor::from_data(row_ptrs, vec![num_rows + 1], sparse.values.device())?;
    Ok((values_tensor, col_indices_tensor, row_ptrs_tensor))
}
/// Convert CSR components back into a COO-format `SparseTensor`.
///
/// # Arguments
/// * `values` — non-zero values, shape `[nnz]`.
/// * `col_indices` — column index per value, shape `[nnz]`.
/// * `row_ptrs` — per-row prefix offsets, shape `[shape[0] + 1]`.
/// * `shape` — dense shape of the matrix; must have exactly two dims.
///
/// # Errors
/// Returns an error when `shape` is not 2D or when the component lengths /
/// row pointers are structurally inconsistent (instead of panicking on
/// out-of-range indexing as before).
pub fn csr_to_sparse(
    values: &Tensor,
    col_indices: &Tensor,
    row_ptrs: &Tensor,
    shape: &[usize],
) -> TorshResult<SparseTensor> {
    if shape.len() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "CSR to COO conversion only supports 2D tensors",
            "csr_to_sparse",
        ));
    }

    let values_data = values.to_vec()?;
    let col_indices_data = col_indices.to_vec()?;
    let row_ptrs_data = row_ptrs.to_vec()?;
    let nnz = values_data.len();

    // Validate structural invariants up front; the expansion loop below
    // would otherwise panic on malformed input.
    if col_indices_data.len() != nnz {
        return Err(TorshError::invalid_argument_with_context(
            "col_indices length must match number of values",
            "csr_to_sparse",
        ));
    }
    if row_ptrs_data.len() != shape[0] + 1 {
        return Err(TorshError::invalid_argument_with_context(
            "row_ptrs length must equal number of rows + 1",
            "csr_to_sparse",
        ));
    }

    // Expand the compressed row pointers into one explicit row coordinate
    // per non-zero entry (the COO row index).
    let mut row_indices = Vec::with_capacity(nnz);
    for row in 0..shape[0] {
        let start = row_ptrs_data[row] as usize;
        let end = row_ptrs_data[row + 1] as usize;
        if start > end || end > nnz {
            return Err(TorshError::invalid_argument_with_context(
                "row_ptrs must be non-decreasing and bounded by nnz",
                "csr_to_sparse",
            ));
        }
        for _ in start..end {
            row_indices.push(row as f32);
        }
    }

    // COO indices layout is [2, nnz]: row coordinates first, then columns.
    let mut indices_data = Vec::with_capacity(2 * nnz);
    indices_data.extend_from_slice(&row_indices);
    indices_data.extend_from_slice(&col_indices_data);
    let indices_tensor = Tensor::from_data(indices_data, vec![2, nnz], values.device())?;
    SparseTensor::new(values.clone(), indices_tensor, shape.to_vec())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::core::sparse_coo_tensor;

    /// A 3x3 matrix with five non-zeros: the CSR components must come back
    /// with the expected lengths and row-pointer boundary values.
    #[test]
    fn test_sparse_to_csr() -> TorshResult<()> {
        let values = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![5],
            torsh_core::DeviceType::Cpu,
        )?;
        // Rows [0, 0, 1, 2, 2] followed by columns [0, 2, 1, 0, 2].
        let indices = Tensor::from_data(
            vec![0.0, 0.0, 1.0, 2.0, 2.0, 0.0, 2.0, 1.0, 0.0, 2.0],
            vec![2, 5],
            torsh_core::DeviceType::Cpu,
        )?;
        let sparse = sparse_coo_tensor(&indices, &values, &[3, 3])?;

        let (csr_values, col_indices, row_ptrs) = sparse_to_csr(&sparse)?;

        assert_eq!(csr_values.to_vec()?.len(), 5);
        assert_eq!(col_indices.to_vec()?.len(), 5);
        let row_ptrs_data = row_ptrs.to_vec()?;
        assert_eq!(row_ptrs_data.len(), 4);
        assert_eq!(row_ptrs_data[0], 0.0);
        assert_eq!(row_ptrs_data[3], 5.0);
        Ok(())
    }

    /// COO -> CSR -> COO must reproduce the same dense matrix.
    #[test]
    fn test_csr_to_sparse_roundtrip() -> TorshResult<()> {
        let values = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            torsh_core::DeviceType::Cpu,
        )?;
        let indices = Tensor::from_data(
            vec![0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
            vec![2, 3],
            torsh_core::DeviceType::Cpu,
        )?;
        let original = sparse_coo_tensor(&indices, &values, &[3, 3])?;

        let (csr_values, col_indices, row_ptrs) = sparse_to_csr(&original)?;
        let rebuilt = csr_to_sparse(&csr_values, &col_indices, &row_ptrs, &[3, 3])?;

        let expected = original.to_dense()?.to_vec()?;
        let actual = rebuilt.to_dense()?.to_vec()?;
        for (e, a) in expected.iter().zip(actual.iter()) {
            assert!((e - a).abs() < 1e-6);
        }
        Ok(())
    }
}