use crate::dtype::DType;
use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::Tensor;
use super::super::coo::CooData;
use super::super::csc::CscData;
use super::super::csr::CsrData;
use super::super::format::{SparseFormat, SparseStorage};
/// A two-dimensional sparse tensor held in exactly one of three storage layouts.
///
/// The active variant owns its storage. The inherent accessors (`shape`, `nnz`,
/// `dtype`, ...) simply forward to whichever variant is active.
#[derive(Debug, Clone)]
pub enum SparseTensor<R: Runtime> {
    /// Coordinate storage: one (row index, column index, value) triplet per entry.
    Coo(CooData<R>),
    /// Compressed sparse row storage: row pointers plus column indices and values.
    Csr(CsrData<R>),
    /// Compressed sparse column storage: column pointers plus row indices and values.
    Csc(CscData<R>),
}
impl<R: Runtime<DType = DType>> SparseTensor<R> {
pub fn from_coo(data: CooData<R>) -> Self {
SparseTensor::Coo(data)
}
pub fn from_csr(data: CsrData<R>) -> Self {
SparseTensor::Csr(data)
}
pub fn from_csc(data: CscData<R>) -> Self {
SparseTensor::Csc(data)
}
pub fn empty(
shape: [usize; 2],
dtype: DType,
format: SparseFormat,
device: &R::Device,
) -> Self {
match format {
SparseFormat::Coo => SparseTensor::Coo(CooData::empty(shape, dtype, device)),
SparseFormat::Csr => SparseTensor::Csr(CsrData::empty(shape, dtype, device)),
SparseFormat::Csc => SparseTensor::Csc(CscData::empty(shape, dtype, device)),
}
}
pub fn from_coo_slices<T: crate::dtype::Element>(
rows: &[i64],
cols: &[i64],
values: &[T],
shape: [usize; 2],
device: &R::Device,
) -> Result<Self> {
let coo = CooData::from_slices(rows, cols, values, shape, device)?;
Ok(SparseTensor::Coo(coo))
}
pub fn from_csr_slices<T: crate::dtype::Element>(
row_ptrs: &[i64],
col_indices: &[i64],
values: &[T],
shape: [usize; 2],
device: &R::Device,
) -> Result<Self> {
let csr = CsrData::from_slices(row_ptrs, col_indices, values, shape, device)?;
Ok(SparseTensor::Csr(csr))
}
pub fn from_csc_slices<T: crate::dtype::Element>(
col_ptrs: &[i64],
row_indices: &[i64],
values: &[T],
shape: [usize; 2],
device: &R::Device,
) -> Result<Self> {
let csc = CscData::from_slices(col_ptrs, row_indices, values, shape, device)?;
Ok(SparseTensor::Csc(csc))
}
pub fn from_dense<C>(client: &C, tensor: &Tensor<R>, threshold: f64) -> Result<Self>
where
C: crate::sparse::SparseOps<R>,
{
client.dense_to_sparse(tensor, threshold)
}
pub fn format(&self) -> SparseFormat {
match self {
SparseTensor::Coo(d) => d.format(),
SparseTensor::Csr(d) => d.format(),
SparseTensor::Csc(d) => d.format(),
}
}
pub fn shape(&self) -> [usize; 2] {
match self {
SparseTensor::Coo(d) => d.shape(),
SparseTensor::Csr(d) => d.shape(),
SparseTensor::Csc(d) => d.shape(),
}
}
pub fn nrows(&self) -> usize {
self.shape()[0]
}
pub fn ncols(&self) -> usize {
self.shape()[1]
}
pub fn nnz(&self) -> usize {
match self {
SparseTensor::Coo(d) => d.nnz(),
SparseTensor::Csr(d) => d.nnz(),
SparseTensor::Csc(d) => d.nnz(),
}
}
pub fn dtype(&self) -> DType {
match self {
SparseTensor::Coo(d) => d.dtype(),
SparseTensor::Csr(d) => d.dtype(),
SparseTensor::Csc(d) => d.dtype(),
}
}
pub fn sparsity(&self) -> f64 {
match self {
SparseTensor::Coo(d) => d.sparsity(),
SparseTensor::Csr(d) => d.sparsity(),
SparseTensor::Csc(d) => d.sparsity(),
}
}
pub fn density(&self) -> f64 {
match self {
SparseTensor::Coo(d) => d.density(),
SparseTensor::Csr(d) => d.density(),
SparseTensor::Csc(d) => d.density(),
}
}
pub fn is_empty(&self) -> bool {
self.nnz() == 0
}
pub fn memory_usage(&self) -> usize {
match self {
SparseTensor::Coo(d) => d.memory_usage(),
SparseTensor::Csr(d) => d.memory_usage(),
SparseTensor::Csc(d) => d.memory_usage(),
}
}
pub fn is_coo(&self) -> bool {
matches!(self, SparseTensor::Coo(_))
}
pub fn is_csr(&self) -> bool {
matches!(self, SparseTensor::Csr(_))
}
pub fn is_csc(&self) -> bool {
matches!(self, SparseTensor::Csc(_))
}
pub fn as_coo(&self) -> Option<&CooData<R>> {
match self {
SparseTensor::Coo(d) => Some(d),
_ => None,
}
}
pub fn as_csr(&self) -> Option<&CsrData<R>> {
match self {
SparseTensor::Csr(d) => Some(d),
_ => None,
}
}
pub fn as_csc(&self) -> Option<&CscData<R>> {
match self {
SparseTensor::Csc(d) => Some(d),
_ => None,
}
}
}
impl<R: Runtime<DType = DType>> std::fmt::Display for SparseTensor<R> {
    /// Writes a one-line summary: shape, stored-entry count, storage format,
    /// element type, and sparsity scaled by 100 with one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shape = self.shape();
        let stored = self.nnz();
        let storage_format = self.format();
        let element_type = self.dtype();
        let sparsity_pct = self.sparsity() * 100.0;

        write!(
            f,
            "SparseTensor({:?}, nnz={}, format={}, dtype={}, sparsity={:.1}%)",
            shape, stored, storage_format, element_type, sparsity_pct
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::{CpuClient, CpuRuntime};
    use crate::sparse::SparseFormat;
    use crate::tensor::Tensor;

    /// Device type of the CPU runtime, aliased to keep signatures short.
    type CpuDevice = <CpuRuntime as Runtime>::Device;

    /// Default CPU device used by every test in this module.
    fn cpu_device() -> CpuDevice {
        CpuDevice::default()
    }

    // Deliberately unsorted COO triplets describing the 3x3 matrix
    //   [1 0 2]
    //   [0 0 3]
    //   [4 5 0]
    // shared by the COO->CSR and COO->CSC conversion tests.
    fn unsorted_coo_fixture(device: &CpuDevice) -> SparseTensor<CpuRuntime> {
        SparseTensor::<CpuRuntime>::from_coo_slices(
            &[2i64, 0, 1, 0, 2],
            &[1i64, 0, 2, 2, 0],
            &[5.0f32, 1.0, 3.0, 2.0, 4.0],
            [3, 3],
            device,
        )
        .unwrap()
    }

    #[test]
    fn test_sparse_tensor_coo() {
        let device = cpu_device();
        let tensor = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[0i64, 1, 2],
            &[1i64, 0, 2],
            &[1.0f32, 2.0, 3.0],
            [3, 3],
            &device,
        )
        .unwrap();

        assert!(tensor.is_coo());
        assert_eq!(tensor.format(), SparseFormat::Coo);
        assert_eq!(tensor.nnz(), 3);
        assert_eq!(tensor.shape(), [3, 3]);
    }

    #[test]
    fn test_sparse_tensor_csr() {
        let device = cpu_device();
        let tensor = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 2, 3, 5],
            &[0i64, 2, 2, 0, 1],
            &[1.0f32, 2.0, 3.0, 4.0, 5.0],
            [3, 3],
            &device,
        )
        .unwrap();

        assert!(tensor.is_csr());
        assert_eq!(tensor.format(), SparseFormat::Csr);
        assert_eq!(tensor.nnz(), 5);
    }

    #[test]
    fn test_sparse_tensor_coo_to_csr_conversion() {
        let device = cpu_device();
        let source = unsorted_coo_fixture(&device);
        assert!(source.is_coo());
        assert_eq!(source.nnz(), 5);

        let converted = source.to_csr().unwrap();
        assert!(converted.is_csr());
        assert_eq!(converted.format(), SparseFormat::Csr);
        assert_eq!(converted.nnz(), 5);
        assert_eq!(converted.shape(), [3, 3]);

        // After compression, entries are laid out row by row, columns ascending.
        let storage = converted.as_csr().unwrap();
        let row_ptrs: Vec<i64> = storage.row_ptrs().to_vec();
        let col_indices: Vec<i64> = storage.col_indices().to_vec();
        let values: Vec<f32> = storage.values().to_vec();
        assert_eq!(row_ptrs, [0, 2, 3, 5]);
        assert_eq!(col_indices, [0, 2, 2, 0, 1]);
        assert_eq!(values, [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sparse_tensor_coo_to_csc_conversion() {
        let device = cpu_device();
        let source = unsorted_coo_fixture(&device);
        assert!(source.is_coo());

        let converted = source.to_csc().unwrap();
        assert!(converted.is_csc());
        assert_eq!(converted.format(), SparseFormat::Csc);
        assert_eq!(converted.nnz(), 5);
        assert_eq!(converted.shape(), [3, 3]);

        // After compression, entries are laid out column by column, rows ascending.
        let storage = converted.as_csc().unwrap();
        let col_ptrs: Vec<i64> = storage.col_ptrs().to_vec();
        let row_indices: Vec<i64> = storage.row_indices().to_vec();
        let values: Vec<f32> = storage.values().to_vec();
        assert_eq!(col_ptrs, [0, 2, 3, 5]);
        assert_eq!(row_indices, [0, 2, 2, 0, 1]);
        assert_eq!(values, [1.0, 4.0, 5.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sparse_tensor_display() {
        let device = cpu_device();
        let tensor = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[0i64, 1],
            &[0i64, 1],
            &[1.0f32, 2.0],
            [10, 10],
            &device,
        )
        .unwrap();

        let rendered = tensor.to_string();
        for needle in ["SparseTensor", "nnz=2", "COO"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_from_dense() {
        let device = cpu_device();
        let client = CpuClient::new(device.clone());
        let data = vec![1.0f32, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0];
        let dense = Tensor::<CpuRuntime>::from_slice(&data, &[3, 3], &device);

        let tensor = SparseTensor::from_dense(&client, &dense, 1e-10).unwrap();
        assert!(tensor.is_coo());
        assert_eq!(tensor.nnz(), 5);
        assert_eq!(tensor.shape(), [3, 3]);

        let storage = tensor.as_coo().unwrap();
        let rows: Vec<i64> = storage.row_indices().to_vec();
        let cols: Vec<i64> = storage.col_indices().to_vec();
        let values: Vec<f32> = storage.values().to_vec();
        assert_eq!(rows, [0, 0, 1, 2, 2]);
        assert_eq!(cols, [0, 2, 2, 0, 1]);
        assert_eq!(values, [1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(storage.is_sorted());
    }

    #[test]
    fn test_from_dense_empty() {
        let device = cpu_device();
        let client = CpuClient::new(device.clone());
        let data = vec![0.0f32; 9];
        let dense = Tensor::<CpuRuntime>::from_slice(&data, &[3, 3], &device);

        let tensor = SparseTensor::from_dense(&client, &dense, 1e-10).unwrap();
        assert_eq!(tensor.nnz(), 0);
        assert_eq!(tensor.shape(), [3, 3]);
    }

    #[test]
    fn test_from_dense_with_threshold() {
        let device = cpu_device();
        let client = CpuClient::new(device.clone());
        // 0.001 and 0.002 sit below the 0.01 threshold and must be dropped.
        let data = vec![1.0f32, 0.001, 2.0, 0.0, 0.0, 0.002];
        let dense = Tensor::<CpuRuntime>::from_slice(&data, &[2, 3], &device);

        let tensor = SparseTensor::from_dense(&client, &dense, 0.01).unwrap();
        assert_eq!(tensor.nnz(), 2);

        let values: Vec<f32> = tensor.as_coo().unwrap().values().to_vec();
        assert_eq!(values, [1.0, 2.0]);
    }

    #[test]
    fn test_from_dense_single_element() {
        let device = cpu_device();
        let client = CpuClient::new(device.clone());
        let data = vec![0.0f32, 0.0, 0.0, 42.0];
        let dense = Tensor::<CpuRuntime>::from_slice(&data, &[2, 2], &device);

        let tensor = SparseTensor::from_dense(&client, &dense, 1e-10).unwrap();
        assert_eq!(tensor.nnz(), 1);

        let storage = tensor.as_coo().unwrap();
        let rows: Vec<i64> = storage.row_indices().to_vec();
        let cols: Vec<i64> = storage.col_indices().to_vec();
        let values: Vec<f32> = storage.values().to_vec();
        assert_eq!(rows, [1]);
        assert_eq!(cols, [1]);
        assert_eq!(values, [42.0]);
    }

    #[test]
    fn test_from_dense_f64() {
        let device = cpu_device();
        let client = CpuClient::new(device.clone());
        let data = vec![1.0f64, 0.0, 2.0, 3.0];
        let dense = Tensor::<CpuRuntime>::from_slice(&data, &[2, 2], &device);

        let tensor = SparseTensor::from_dense(&client, &dense, 1e-10).unwrap();
        assert_eq!(tensor.nnz(), 3);
        assert_eq!(tensor.dtype(), DType::F64);
    }
}