use super::CooData;
use crate::dtype::DType;
use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::Tensor;
impl<R: Runtime<DType = DType>> CooData<R> {
    /// Sparse matrix–vector product, computed by way of the CSR format.
    ///
    /// This matrix is converted with `to_csr` and the product is delegated to
    /// the CSR `spmv`. Whatever that method requires of `x` (shape, dtype,
    /// device) applies here unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from the COO→CSR conversion or from the CSR `spmv`.
    pub fn spmv(&self, x: &Tensor<R>) -> Result<Tensor<R>>
    where
        R::Client: crate::sparse::SparseOps<R>,
    {
        // NOTE(review): the CSR form is rebuilt on every call. Callers doing
        // many products with the same matrix may want to convert once and reuse it.
        self.to_csr()?.spmv(x)
    }

    /// Sparse matrix–matrix product, computed by way of the CSR format.
    ///
    /// This matrix is converted with `to_csr` and the product is delegated to
    /// the CSR `spmm`. Whatever that method requires of `b` applies here unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from the COO→CSR conversion or from the CSR `spmm`.
    pub fn spmm(&self, b: &Tensor<R>) -> Result<Tensor<R>>
    where
        R::Client: crate::sparse::SparseOps<R>,
    {
        self.to_csr()?.spmm(b)
    }

    /// Returns the transpose of this matrix as a new COO matrix.
    ///
    /// Each stored entry `(i, j)` becomes `(j, i)` because the row and column
    /// index tensors are swapped. The value tensor keeps its original entry
    /// order, and the shape `[nrows, ncols]` becomes `[ncols, nrows]`.
    ///
    /// The result is always flagged as unsorted. Swapping the index arrays
    /// generally destroys any ordering the source entries had.
    ///
    /// The index and value tensors are cloned. The cost of that depends on
    /// `Tensor`'s `Clone` implementation, which is not visible here.
    pub fn transpose(&self) -> Self {
        let [rows, cols] = self.shape;
        let transposed_shape = [cols, rows];

        Self {
            // Swapping the roles of the two index arrays mirrors every entry.
            row_indices: self.col_indices.clone(),
            col_indices: self.row_indices.clone(),
            values: self.values.clone(),
            sorted: false,
            shape: transposed_shape,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::cpu::CpuRuntime;
    use crate::sparse::SparseStorage;

    /// Transposing a 2x3 matrix swaps the index arrays, keeps the values in
    /// entry order, reverses the shape, and marks the result unsorted.
    #[test]
    fn test_coo_transpose() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let rows = vec![0i64, 0, 1];
        let cols = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0];
        let coo =
            CooData::<CpuRuntime>::from_slices(&rows, &cols, &values, [2, 3], &device).unwrap();

        let coo_t = coo.transpose();
        assert_eq!(coo_t.shape(), [3, 2]);
        assert_eq!(coo_t.nnz(), 3);
        // `transpose` always clears the sorted flag, since swapping indices
        // generally breaks the source ordering.
        assert!(!coo_t.is_sorted());

        let t_rows: Vec<i64> = coo_t.row_indices().to_vec();
        let t_cols: Vec<i64> = coo_t.col_indices().to_vec();
        let t_values: Vec<f32> = coo_t.values().to_vec();
        assert_eq!(t_rows, vec![0, 2, 1]);
        assert_eq!(t_cols, vec![0, 0, 1]);
        assert_eq!(t_values, vec![1.0, 2.0, 3.0]);
    }

    /// An empty matrix transposes to an empty matrix with the reversed shape.
    #[test]
    fn test_coo_transpose_empty() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let coo = CooData::<CpuRuntime>::empty([3, 5], DType::F32, &device);

        let coo_t = coo.transpose();
        assert_eq!(coo_t.shape(), [5, 3]);
        assert_eq!(coo_t.nnz(), 0);
    }

    /// A square matrix keeps its shape, swaps indices, and preserves values.
    #[test]
    fn test_coo_transpose_square() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let rows = vec![0i64, 1, 2];
        let cols = vec![1i64, 2, 0];
        let values = vec![1.0f32, 2.0, 3.0];
        let coo =
            CooData::<CpuRuntime>::from_slices(&rows, &cols, &values, [3, 3], &device).unwrap();

        let coo_t = coo.transpose();
        assert_eq!(coo_t.shape(), [3, 3]);
        assert_eq!(coo_t.nnz(), 3);

        let t_rows: Vec<i64> = coo_t.row_indices().to_vec();
        let t_cols: Vec<i64> = coo_t.col_indices().to_vec();
        let t_values: Vec<f32> = coo_t.values().to_vec();
        assert_eq!(t_rows, vec![1, 2, 0]);
        assert_eq!(t_cols, vec![0, 1, 2]);
        assert_eq!(t_values, values);
    }

    /// Transposing twice restores the original shape, indices, and values.
    #[test]
    fn test_coo_transpose_double() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let rows = vec![0i64, 0, 1];
        let cols = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0];
        let coo =
            CooData::<CpuRuntime>::from_slices(&rows, &cols, &values, [2, 3], &device).unwrap();

        let coo_tt = coo.transpose().transpose();
        assert_eq!(coo_tt.shape(), [2, 3]);
        assert_eq!(coo_tt.nnz(), 3);

        let tt_rows: Vec<i64> = coo_tt.row_indices().to_vec();
        let tt_cols: Vec<i64> = coo_tt.col_indices().to_vec();
        let tt_values: Vec<f32> = coo_tt.values().to_vec();
        assert_eq!(tt_rows, rows);
        assert_eq!(tt_cols, cols);
        // Double transpose must be the identity on values too, not just indices.
        assert_eq!(tt_values, values);
    }
}