use super::super::CscData;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::sparse::{SparseOps, SparseStorage};
impl<R: Runtime<DType = DType>> CscData<R> {
pub fn add(&self, other: &Self) -> Result<Self>
where
R::Client: SparseOps<R>,
{
if self.shape != other.shape {
return Err(Error::ShapeMismatch {
expected: vec![self.shape[0], self.shape[1]],
got: vec![other.shape[0], other.shape[1]],
});
}
if self.dtype() != other.dtype() {
return Err(Error::DTypeMismatch {
lhs: self.dtype(),
rhs: other.dtype(),
});
}
let dtype = self.dtype();
let device = self.values.device();
let client = R::default_client(device);
crate::dispatch_dtype!(dtype, T => {
let (out_col_ptrs, out_row_indices, out_values) = client.add_csc::<T>(
&self.col_ptrs,
&self.row_indices,
&self.values,
&other.col_ptrs,
&other.row_indices,
&other.values,
self.shape,
)?;
Ok(Self {
col_ptrs: out_col_ptrs,
row_indices: out_row_indices,
values: out_values,
shape: self.shape,
})
}, "csc_add")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::cpu::CpuRuntime;

    /// Disjoint sparsity patterns: every entry of both operands appears in
    /// the result, with row indices ordered within each column.
    #[test]
    fn test_csc_add_disjoint() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[1.0f32, 3.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[1i64, 0],
            &[4.0f32, 2.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.shape(), [2, 2]);
        assert_eq!(c.nnz(), 4);
        let col_ptrs: Vec<i64> = c.col_ptrs().to_vec();
        let row_indices: Vec<i64> = c.row_indices().to_vec();
        let vals: Vec<f32> = c.values().to_vec();
        assert_eq!(col_ptrs, vec![0, 2, 4]);
        assert_eq!(row_indices, vec![0, 1, 0, 1]);
        assert_eq!(vals, vec![1.0, 4.0, 2.0, 3.0]);
    }

    /// Overlapping entry (0, 0) is summed; the remaining entries are merged.
    #[test]
    fn test_csc_add_overlapping() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[0i64, 0],
            &[1.0f32, 2.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[3.0f32, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.nnz(), 3);
        let col_ptrs: Vec<i64> = c.col_ptrs().to_vec();
        let row_indices: Vec<i64> = c.row_indices().to_vec();
        let vals: Vec<f32> = c.values().to_vec();
        assert_eq!(col_ptrs, vec![0, 1, 3]);
        assert_eq!(row_indices, vec![0, 0, 1]);
        assert_eq!(vals, vec![4.0, 2.0, 4.0]);
    }

    /// Adding an empty matrix (in either operand position) must reproduce the
    /// non-empty operand exactly, not just its entry count.
    #[test]
    fn test_csc_add_empty() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::empty([2, 2], DType::F32, &device);
        let b = CscData::<CpuRuntime>::from_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[1.0f32, 2.0],
            [2, 2],
            &device,
        )
        .unwrap();

        for c in [a.add(&b).unwrap(), b.add(&a).unwrap()] {
            assert_eq!(c.shape(), [2, 2]);
            assert_eq!(c.nnz(), 2);
            let col_ptrs: Vec<i64> = c.col_ptrs().to_vec();
            let row_indices: Vec<i64> = c.row_indices().to_vec();
            let vals: Vec<f32> = c.values().to_vec();
            assert_eq!(col_ptrs, vec![0, 1, 2]);
            assert_eq!(row_indices, vec![0, 1]);
            assert_eq!(vals, vec![1.0, 2.0]);
        }
    }

    /// A shape mismatch must surface as `Error::ShapeMismatch` carrying the
    /// left operand's shape as `expected` and the right operand's as `got`;
    /// a bare `is_err()` check would also accept an unrelated failure.
    #[test]
    fn test_csc_add_shape_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CscData::<CpuRuntime>::empty([2, 3], DType::F32, &device);
        let b = CscData::<CpuRuntime>::empty([3, 2], DType::F32, &device);
        match a.add(&b) {
            Err(Error::ShapeMismatch { expected, got }) => {
                assert_eq!(expected, vec![2, 3]);
                assert_eq!(got, vec![3, 2]);
            }
            other => panic!(
                "expected Error::ShapeMismatch, got {:?}",
                other.map(|_| ())
            ),
        }
    }
}