use super::super::CooData;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::sparse::{SparseOps, SparseStorage};
impl<R: Runtime<DType = DType>> CooData<R> {
    /// Computes the element-wise difference `self - other` of two sparse COO matrices.
    ///
    /// Both operands must have identical shapes and identical dtypes. The
    /// subtraction itself is delegated to the runtime client's `sub_coo`
    /// kernel (from [`SparseOps`]), instantiated for the concrete element type
    /// selected from the operands' dtype. The client is obtained from the
    /// device that holds `self`'s values.
    ///
    /// The result has the same shape as the inputs and is flagged as sorted.
    /// This relies on `sub_coo` emitting its triplets in sorted order.
    /// NOTE(review): confirm this against the kernel's contract.
    ///
    /// # Errors
    ///
    /// - [`Error::ShapeMismatch`] if the shapes differ. `expected` holds
    ///   `self`'s shape and `got` holds `other`'s.
    /// - [`Error::DTypeMismatch`] if the dtypes differ. `lhs` is `self`'s dtype
    ///   and `rhs` is `other`'s.
    /// - Any error propagated from `sub_coo`. `dispatch_dtype!` may also error
    ///   for a dtype it does not handle; this depends on the macro's definition.
    pub fn sub(&self, other: &Self) -> Result<Self>
    where
        R::Client: SparseOps<R>,
    {
        // Guard 1: shapes must agree exactly (checked before dtype).
        let shape = self.shape;
        if shape != other.shape {
            return Err(Error::ShapeMismatch {
                expected: vec![shape[0], shape[1]],
                got: vec![other.shape[0], other.shape[1]],
            });
        }

        // Guard 2: element types must agree.
        let (lhs_dtype, rhs_dtype) = (self.dtype(), other.dtype());
        if lhs_dtype != rhs_dtype {
            return Err(Error::DTypeMismatch {
                lhs: lhs_dtype,
                rhs: rhs_dtype,
            });
        }

        let client = R::default_client(self.values.device());

        crate::dispatch_dtype!(lhs_dtype, T => {
            let (rows, cols, vals) = client.sub_coo::<T>(
                &self.row_indices,
                &self.col_indices,
                &self.values,
                &other.row_indices,
                &other.col_indices,
                &other.values,
                shape,
            )?;
            Ok(Self {
                row_indices: rows,
                col_indices: cols,
                values: vals,
                shape,
                sorted: true,
            })
        }, "coo_sub")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::cpu::CpuRuntime;

    /// Subtracting two same-shaped f32 matrices succeeds, keeps the shape and
    /// marks the result as sorted.
    #[test]
    fn test_coo_sub_basic() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CooData::<CpuRuntime>::from_slices(
            &[0i64, 1],
            &[0i64, 1],
            &[5.0f32, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CooData::<CpuRuntime>::from_slices(
            &[0i64, 0, 1],
            &[0i64, 1, 1],
            &[2.0f32, 1.0, 3.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let c = a.sub(&b).unwrap();
        assert_eq!(c.shape(), [2, 2]);
        assert!(c.is_sorted());
    }

    /// Differing shapes are rejected with `ShapeMismatch`. `expected` carries
    /// the left operand's shape and `got` carries the right operand's.
    #[test]
    fn test_coo_sub_shape_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CooData::<CpuRuntime>::from_slices(
            &[0i64, 1],
            &[0i64, 1],
            &[1.0f32, 2.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let b = CooData::<CpuRuntime>::from_slices(
            &[0i64, 2],
            &[0i64, 2],
            &[1.0f32, 2.0],
            [3, 3],
            &device,
        )
        .unwrap();
        // `matches!` avoids requiring `CooData: Debug` (which `unwrap_err` would need).
        assert!(matches!(
            a.sub(&b),
            Err(Error::ShapeMismatch { expected, got })
                if expected == vec![2, 2] && got == vec![3, 3]
        ));
    }

    /// Same shape but different element types is rejected with `DTypeMismatch`.
    #[test]
    fn test_coo_sub_dtype_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let a = CooData::<CpuRuntime>::from_slices(&[0i64], &[0i64], &[1.0f32], [2, 2], &device)
            .unwrap();
        let b = CooData::<CpuRuntime>::from_slices(&[0i64], &[0i64], &[1.0f64], [2, 2], &device)
            .unwrap();
        assert!(matches!(a.sub(&b), Err(Error::DTypeMismatch { .. })));
    }
}