use super::CsrData;
use crate::dtype::DType;
use crate::error::Result;
use crate::runtime::Runtime;
use crate::sparse::{CooData, CscData, SparseStorage};
use crate::tensor::Tensor;
impl<R: Runtime<DType = DType>> CsrData<R> {
    /// Converts this CSR matrix into coordinate (COO) form.
    ///
    /// The compressed row pointers are expanded into one explicit row index
    /// per stored entry; column indices and values are carried over in their
    /// existing order. The resulting tensors are created on the device that
    /// holds `self.values`, and the result is flagged as sorted.
    ///
    /// A matrix with no stored entries short-circuits to `CooData::empty`
    /// with the same shape, dtype and device.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `CooData::new`.
    /// NOTE(review): `dispatch_dtype!` presumably also yields an error
    /// (labelled "CSR to COO conversion") for dtypes it cannot dispatch —
    /// confirm against the macro definition.
    ///
    /// # Panics
    ///
    /// Panics if `row_ptrs` holds fewer than `nrows + 1` entries, because the
    /// row pointers are indexed directly.
    pub fn to_coo(&self) -> Result<CooData<R>> {
        let [row_count, _] = self.shape;
        let entry_count = self.nnz();

        // Nothing stored: hand back an empty COO matrix right away.
        if entry_count == 0 {
            return Ok(CooData::empty(
                self.shape,
                self.dtype(),
                self.values.device(),
            ));
        }

        let offsets: Vec<i64> = self.row_ptrs.to_vec();
        let columns: Vec<i64> = self.col_indices.to_vec();

        // Row `r` owns the entries in `offsets[r]..offsets[r + 1]`, so emit
        // `r` once for every position in that half-open range.
        let mut rows: Vec<i64> = Vec::with_capacity(entry_count);
        rows.extend((0..row_count).flat_map(|row| {
            let begin = offsets[row] as usize;
            let finish = offsets[row + 1] as usize;
            (begin..finish).map(move |_| row as i64)
        }));

        let device = self.values.device();

        // Every dispatch arm returns from the function, which is why the
        // macro invocation is a statement rather than the tail expression.
        crate::dispatch_dtype!(self.dtype(), T => {
            let host_values: Vec<T> = self.values.to_vec();

            let row_tensor = Tensor::from_slice(&rows, &[rows.len()], device);
            let col_tensor = Tensor::from_slice(&columns, &[columns.len()], device);
            let value_tensor = Tensor::from_slice(&host_values, &[host_values.len()], device);

            let mut converted = CooData::new(row_tensor, col_tensor, value_tensor, self.shape)?;

            // SAFETY: the row indices above are generated in non-decreasing
            // row order, and columns/values keep the order of the CSR buffers.
            // NOTE(review): presumably `set_sorted` also expects columns to be
            // ordered within each row, which relies on the CSR input already
            // satisfying that — confirm against `CooData::set_sorted`.
            unsafe { converted.set_sorted(true); }
            return Ok(converted);
        }, "CSR to COO conversion");
    }

    /// Converts this CSR matrix into CSC form.
    ///
    /// The conversion goes through COO: `to_coo` is applied first and the
    /// intermediate result's own `to_csc` produces the final matrix.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by either conversion step.
    pub fn to_csc(&self) -> Result<CscData<R>> {
        self.to_coo().and_then(|coo| coo.to_csc())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::cpu::CpuRuntime;

    /// Device handle type of the CPU runtime used by every test here.
    type CpuDevice = <CpuRuntime as Runtime>::Device;

    /// Builds the 3x3 CSR fixture shared by the populated-matrix tests:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 0 3]
    /// [4 5 0]
    /// ```
    fn sample_csr(device: &CpuDevice) -> CsrData<CpuRuntime> {
        let offsets = vec![0i64, 2, 3, 5];
        let columns = vec![0i64, 2, 2, 0, 1];
        let entries = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        CsrData::<CpuRuntime>::from_slices(&offsets, &columns, &entries, [3, 3], device).unwrap()
    }

    /// A populated CSR matrix expands to the expected sorted COO triplets.
    #[test]
    fn test_csr_to_coo() {
        let device = CpuDevice::default();
        let coo = sample_csr(&device).to_coo().unwrap();

        assert_eq!(coo.nnz(), 5);
        assert_eq!(coo.shape(), [3, 3]);
        assert!(coo.is_sorted());

        let rows: Vec<i64> = coo.row_indices().to_vec();
        let cols: Vec<i64> = coo.col_indices().to_vec();
        let vals: Vec<f32> = coo.values().to_vec();
        assert_eq!(rows, [0, 0, 1, 2, 2]);
        assert_eq!(cols, [0, 2, 2, 0, 1]);
        assert_eq!(vals, [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// An empty CSR matrix converts to an empty COO matrix that is sorted.
    #[test]
    fn test_csr_to_coo_empty() {
        let device = CpuDevice::default();
        let empty_csr = CsrData::<CpuRuntime>::empty([3, 3], DType::F32, &device);

        let coo = empty_csr.to_coo().unwrap();

        assert_eq!(coo.nnz(), 0);
        assert_eq!(coo.shape(), [3, 3]);
        assert!(coo.is_sorted());
    }

    /// A lone entry at (1, 2), surrounded by empty rows, keeps its position.
    #[test]
    fn test_csr_to_coo_single_element() {
        let device = CpuDevice::default();
        let offsets = vec![0i64, 0, 1, 1];
        let columns = vec![2i64];
        let entries = vec![42.0f32];
        let single =
            CsrData::<CpuRuntime>::from_slices(&offsets, &columns, &entries, [3, 3], &device)
                .unwrap();

        let coo = single.to_coo().unwrap();
        assert_eq!(coo.nnz(), 1);

        let rows: Vec<i64> = coo.row_indices().to_vec();
        let cols: Vec<i64> = coo.col_indices().to_vec();
        let vals: Vec<f32> = coo.values().to_vec();
        assert_eq!(rows, [1]);
        assert_eq!(cols, [2]);
        assert_eq!(vals, [42.0]);
    }

    /// CSR to CSC regroups the same entries by column.
    #[test]
    fn test_csr_to_csc() {
        let device = CpuDevice::default();
        let csc = sample_csr(&device).to_csc().unwrap();

        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.shape(), [3, 3]);

        let col_ptrs: Vec<i64> = csc.col_ptrs().to_vec();
        let row_idx: Vec<i64> = csc.row_indices().to_vec();
        let vals: Vec<f32> = csc.values().to_vec();
        assert_eq!(col_ptrs, [0, 2, 3, 5]);
        assert_eq!(row_idx, [0, 2, 2, 0, 1]);
        assert_eq!(vals, [1.0, 4.0, 5.0, 2.0, 3.0]);
    }

    /// An empty CSR matrix converts to an empty CSC matrix of the same shape.
    #[test]
    fn test_csr_to_csc_empty() {
        let device = CpuDevice::default();
        let empty_csr = CsrData::<CpuRuntime>::empty([3, 3], DType::F32, &device);

        let csc = empty_csr.to_csc().unwrap();

        assert_eq!(csc.nnz(), 0);
        assert_eq!(csc.shape(), [3, 3]);
    }
}