use super::super::CscData;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::sparse::{SparseOps, SparseStorage};
impl<R: Runtime<DType = DType>> CscData<R> {
    /// Multiplies two CSC matrices element-wise and returns the result as a
    /// new CSC matrix.
    ///
    /// Both operands must agree in shape and dtype. The computation itself is
    /// delegated to the `mul_csc` kernel of the default client for the device
    /// that holds `self.values`; the returned matrix carries the shape of
    /// `self`.
    ///
    /// # Errors
    ///
    /// * [`Error::ShapeMismatch`] when `self` and `other` have different shapes.
    /// * [`Error::DTypeMismatch`] when `self` and `other` have different dtypes.
    /// * Any error propagated from the `mul_csc` call.
    pub fn mul(&self, other: &Self) -> Result<Self>
    where
        R::Client: SparseOps<R>,
    {
        // Guard 1: the shape check runs before anything else.
        let shape = self.shape;
        if shape != other.shape {
            return Err(Error::ShapeMismatch {
                expected: shape.to_vec(),
                got: other.shape.to_vec(),
            });
        }

        // Guard 2: both operands must store the same element type.
        let lhs_dtype = self.dtype();
        let rhs_dtype = other.dtype();
        if lhs_dtype != rhs_dtype {
            return Err(Error::DTypeMismatch {
                lhs: lhs_dtype,
                rhs: rhs_dtype,
            });
        }

        let client = R::default_client(self.values.device());

        // Resolve the runtime dtype to a concrete element type `T` and run
        // the kernel for it.
        crate::dispatch_dtype!(lhs_dtype, T => {
            let (col_ptrs, row_indices, values) = client.mul_csc::<T>(
                &self.col_ptrs,
                &self.row_indices,
                &self.values,
                &other.col_ptrs,
                &other.row_indices,
                &other.values,
                shape,
            )?;
            Ok(Self {
                col_ptrs,
                row_indices,
                values,
                shape,
            })
        }, "csc_mul")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::cpu::CpuRuntime;

    /// Operands whose sparsity patterns partially overlap.
    ///
    /// Dense view: `A = [[2, 3], [0, 5]]`, `B = [[4, 0], [6, 7]]`.
    /// Only (0, 0) and (1, 1) are stored in both, so the product keeps
    /// exactly those two entries: 2 * 4 = 8 and 5 * 7 = 35.
    #[test]
    fn test_csc_mul_overlapping() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 3],
            &[0i64, 0, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CscData::<CpuRuntime>::from_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&b).unwrap();

        assert_eq!(c.shape(), [2, 2]);
        assert_eq!(c.nnz(), 2);
        let col_ptrs: Vec<i64> = c.col_ptrs().to_vec();
        let row_indices: Vec<i64> = c.row_indices().to_vec();
        let vals: Vec<f32> = c.values().to_vec();
        assert_eq!(col_ptrs, vec![0, 1, 2]);
        assert_eq!(row_indices, vec![0, 1]);
        assert_eq!(vals, vec![8.0, 35.0]);
    }

    /// Operands with no stored position in common.
    ///
    /// Dense view: `A = [[1, 0], [0, 3]]`, `B = [[0, 2], [4, 0]]`.
    /// The product has no stored entries.
    #[test]
    fn test_csc_mul_disjoint() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[1.0f32, 3.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[1i64, 0],
            &[4.0f32, 2.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&b).unwrap();

        assert_eq!(c.nnz(), 0);
    }

    /// Multiplying by an empty matrix yields an empty matrix, whichever side
    /// the empty operand is on.
    #[test]
    fn test_csc_mul_empty() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[1.0f32, 2.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CscData::<CpuRuntime>::empty([2, 2], DType::F32, &device);

        let c = a.mul(&b).unwrap();
        assert_eq!(c.nnz(), 0);

        let c2 = b.mul(&a).unwrap();
        assert_eq!(c2.nnz(), 0);
    }

    /// Operands of different shapes are rejected with `Error::ShapeMismatch`.
    #[test]
    fn test_csc_mul_shape_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::empty([2, 3], DType::F32, &device);
        let b = CscData::<CpuRuntime>::empty([3, 2], DType::F32, &device);

        let result = a.mul(&b);

        // Pin the variant so an unrelated failure cannot make this test pass.
        assert!(matches!(result, Err(Error::ShapeMismatch { .. })));
    }

    /// Operands with identical sparsity patterns.
    ///
    /// Dense view: `A = [[2, 4], [3, 0]]`, `B = [[5, 7], [6, 0]]`.
    /// Every stored entry survives, and the structure matches the inputs.
    #[test]
    fn test_csc_mul_same_positions() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::from_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 0],
            &[2.0f32, 3.0, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CscData::<CpuRuntime>::from_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 0],
            &[5.0f32, 6.0, 7.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&b).unwrap();

        assert_eq!(c.shape(), [2, 2]);
        assert_eq!(c.nnz(), 3);
        // Check the structure as well as the values, so a result with the
        // right numbers in the wrong positions is caught.
        let col_ptrs: Vec<i64> = c.col_ptrs().to_vec();
        let row_indices: Vec<i64> = c.row_indices().to_vec();
        let vals: Vec<f32> = c.values().to_vec();
        assert_eq!(col_ptrs, vec![0, 2, 3]);
        assert_eq!(row_indices, vec![0, 1, 0]);
        assert_eq!(vals, vec![10.0, 18.0, 28.0]);
    }
}