use crate::dtype::DType;
use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::Tensor;
use super::{CscData, CsrData, SparseTensor};
pub trait SparseOps<R: Runtime<DType = DType>>: Sized {
fn spmv_csr<T: crate::dtype::Element>(
&self,
row_ptrs: &Tensor<R>,
col_indices: &Tensor<R>,
values: &Tensor<R>,
x: &Tensor<R>,
shape: [usize; 2],
) -> Result<Tensor<R>>;
fn spmm_csr<T: crate::dtype::Element>(
&self,
row_ptrs: &Tensor<R>,
col_indices: &Tensor<R>,
values: &Tensor<R>,
b: &Tensor<R>,
shape: [usize; 2],
) -> Result<Tensor<R>>;
fn add_csr<T: crate::dtype::Element>(
&self,
a_row_ptrs: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_ptrs: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn sub_csr<T: crate::dtype::Element>(
&self,
a_row_ptrs: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_ptrs: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn mul_csr<T: crate::dtype::Element>(
&self,
a_row_ptrs: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_ptrs: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn div_csr<T: crate::dtype::Element>(
&self,
a_row_ptrs: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_ptrs: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn add_csc<T: crate::dtype::Element>(
&self,
a_col_ptrs: &Tensor<R>,
a_row_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_col_ptrs: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn sub_csc<T: crate::dtype::Element>(
&self,
a_col_ptrs: &Tensor<R>,
a_row_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_col_ptrs: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn mul_csc<T: crate::dtype::Element>(
&self,
a_col_ptrs: &Tensor<R>,
a_row_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_col_ptrs: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn div_csc<T: crate::dtype::Element>(
&self,
a_col_ptrs: &Tensor<R>,
a_row_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_col_ptrs: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn add_coo<T: crate::dtype::Element>(
&self,
a_row_indices: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn sub_coo<T: crate::dtype::Element>(
&self,
a_row_indices: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn mul_coo<T: crate::dtype::Element>(
&self,
a_row_indices: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn div_coo<T: crate::dtype::Element>(
&self,
a_row_indices: &Tensor<R>,
a_col_indices: &Tensor<R>,
a_values: &Tensor<R>,
b_row_indices: &Tensor<R>,
b_col_indices: &Tensor<R>,
b_values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn spmv(&self, a: &SparseTensor<R>, x: &Tensor<R>) -> Result<Tensor<R>>;
fn spmm(&self, a: &SparseTensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>;
fn dsmm(&self, a: &Tensor<R>, b: &SparseTensor<R>) -> Result<Tensor<R>>;
fn sparse_add(&self, a: &SparseTensor<R>, b: &SparseTensor<R>) -> Result<SparseTensor<R>>;
fn sparse_sub(&self, a: &SparseTensor<R>, b: &SparseTensor<R>) -> Result<SparseTensor<R>>;
fn sparse_matmul(&self, a: &SparseTensor<R>, b: &SparseTensor<R>) -> Result<SparseTensor<R>>;
fn sparse_mul(&self, a: &SparseTensor<R>, b: &SparseTensor<R>) -> Result<SparseTensor<R>>;
fn sparse_scale(&self, a: &SparseTensor<R>, scalar: f64) -> Result<SparseTensor<R>>;
fn sparse_add_scalar(&self, a: &SparseTensor<R>, scalar: f64) -> Result<SparseTensor<R>>;
fn sparse_sum(&self, a: &SparseTensor<R>) -> Result<Tensor<R>>;
fn sparse_sum_rows(&self, a: &SparseTensor<R>) -> Result<Tensor<R>>;
fn sparse_sum_cols(&self, a: &SparseTensor<R>) -> Result<Tensor<R>>;
fn sparse_nnz_per_row(&self, a: &SparseTensor<R>) -> Result<Tensor<R>>;
fn sparse_nnz_per_col(&self, a: &SparseTensor<R>) -> Result<Tensor<R>>;
fn sparse_to_dense(&self, a: &SparseTensor<R>) -> Result<Tensor<R>>;
fn dense_to_sparse(&self, a: &Tensor<R>, threshold: f64) -> Result<SparseTensor<R>>;
fn dense_to_csr(&self, a: &Tensor<R>, threshold: f64) -> Result<CsrData<R>> {
let sparse = self.dense_to_sparse(a, threshold)?;
let csr = sparse.to_csr()?;
match csr {
SparseTensor::Csr(data) => Ok(data),
_ => unreachable!("to_csr() always returns SparseTensor::Csr"),
}
}
fn dense_to_csc(&self, a: &Tensor<R>, threshold: f64) -> Result<CscData<R>> {
let sparse = self.dense_to_sparse(a, threshold)?;
let csc = sparse.to_csc()?;
match csc {
SparseTensor::Csc(data) => Ok(data),
_ => unreachable!("to_csc() always returns SparseTensor::Csc"),
}
}
fn coo_to_csr<T: crate::dtype::Element>(
&self,
row_indices: &Tensor<R>,
col_indices: &Tensor<R>,
values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn coo_to_csc<T: crate::dtype::Element>(
&self,
row_indices: &Tensor<R>,
col_indices: &Tensor<R>,
values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn csr_to_coo<T: crate::dtype::Element>(
&self,
row_ptrs: &Tensor<R>,
col_indices: &Tensor<R>,
values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn csc_to_coo<T: crate::dtype::Element>(
&self,
col_ptrs: &Tensor<R>,
row_indices: &Tensor<R>,
values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn csr_to_csc<T: crate::dtype::Element>(
&self,
row_ptrs: &Tensor<R>,
col_indices: &Tensor<R>,
values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn csc_to_csr<T: crate::dtype::Element>(
&self,
col_ptrs: &Tensor<R>,
row_indices: &Tensor<R>,
values: &Tensor<R>,
shape: [usize; 2],
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
fn sparse_transpose(&self, a: &SparseTensor<R>) -> Result<SparseTensor<R>>;
fn extract_diagonal_csr<T: crate::dtype::Element>(
&self,
row_ptrs: &Tensor<R>,
col_indices: &Tensor<R>,
values: &Tensor<R>,
shape: [usize; 2],
) -> Result<Tensor<R>>;
fn sparse_extract_diagonal(&self, a: &SparseTensor<R>) -> Result<Tensor<R>> {
let csr = match a {
SparseTensor::Csr(data) => data.clone(),
SparseTensor::Coo(data) => data.to_csr()?,
SparseTensor::Csc(data) => data.to_csr()?,
};
let shape = csr.shape;
let dtype = csr.values.dtype();
crate::dispatch_dtype!(dtype, T => {
self.extract_diagonal_csr::<T>(
&csr.row_ptrs,
&csr.col_indices,
&csr.values,
shape,
)
}, "sparse_extract_diagonal")
}
}
/// Selects which vector norm `SparseScaling::row_norms` / `col_norms` compute.
///
/// Going by the variant names, these are presumably the standard norms:
/// * `L1`: sum of absolute values;
/// * `L2`: Euclidean norm;
/// * `Linf`: maximum absolute value.
///
/// Confirm this against the implementations. `Hash` is derived alongside
/// `Eq` so the selector can key hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NormType {
    /// L1 norm.
    L1,
    /// L2 (Euclidean) norm.
    L2,
    /// L-infinity (maximum) norm.
    Linf,
}
/// Norm computation and diagonal row/column scaling for a sparse matrix type
/// (`Self`) on runtime `R`.
///
/// Every method takes its element type `T` as a caller-chosen generic
/// parameter. The norm variant is picked with [`NormType`].
///
/// NOTE(review): the per-method descriptions are inferred from the method
/// names and trait bounds. Confirm them against the implementations.
pub trait SparseScaling<R: Runtime> {
    /// One norm value per row of `self`, using the norm selected by `norm`.
    fn row_norms<T>(&self, norm: NormType) -> Result<Tensor<R>>
    where
        T: crate::dtype::Element + Default + Copy;

    /// One norm value per column of `self`, using the norm selected by `norm`.
    fn col_norms<T>(&self, norm: NormType) -> Result<Tensor<R>>
    where
        T: crate::dtype::Element + Default + Copy;

    /// Returns a new matrix whose rows are multiplied by `scales`. There is
    /// presumably one factor per row.
    fn scale_rows<T>(&self, scales: &[T]) -> Result<Self>
    where
        T: crate::dtype::Element + Default + Copy + std::ops::Mul<Output = T>,
        Self: Sized;

    /// Returns a new matrix whose columns are multiplied by `scales`. There is
    /// presumably one factor per column.
    fn scale_cols<T>(&self, scales: &[T]) -> Result<Self>
    where
        T: crate::dtype::Element + Default + Copy + std::ops::Mul<Output = T>,
        Self: Sized;

    /// Returns a rescaled copy of `self` together with two factor vectors.
    /// These are presumably `(equilibrated, row_scales, col_scales)`.
    fn equilibrate<T>(&self) -> Result<(Self, Vec<T>, Vec<T>)>
    where
        T: crate::dtype::Element + Default + Copy + num_traits::Float,
        Self: Sized;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `SparseOps` can be used as a generic bound
    /// together with a runtime whose `DType` is the crate's `DType`. The helper
    /// is never called; the test passes as long as this module compiles.
    #[test]
    fn test_sparse_ops_trait_exists() {
        fn _accepts_sparse_ops<R: Runtime<DType = DType>, T: SparseOps<R>>(_: &T) {}
    }

    /// `NormType` variants are pairwise distinct, are `Copy`, and
    /// debug-print as their variant names.
    #[test]
    fn test_norm_type() {
        let all = [NormType::L1, NormType::L2, NormType::Linf];
        // Each variant equals itself and differs from every other variant.
        for (i, a) in all.iter().enumerate() {
            for (j, b) in all.iter().enumerate() {
                assert_eq!(i == j, a == b, "{:?} vs {:?}", a, b);
            }
        }

        // `Copy`: the source value stays usable after being copied.
        let original = NormType::L2;
        let copied = original;
        assert_eq!(original, copied);

        assert_eq!(format!("{:?}", NormType::L1), "L1");
        assert_eq!(format!("{:?}", NormType::L2), "L2");
        assert_eq!(format!("{:?}", NormType::Linf), "Linf");
    }
}