use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::sparse::{SparseOps, SparseTensor};
impl<R: Runtime<DType = DType>> SparseTensor<R> {
    /// Element-wise division of two sparse tensors (`self / other`).
    ///
    /// Both operands must have identical shapes and dtypes.
    ///
    /// When both operands use the same storage format (COO, CSR or CSC), the work is
    /// delegated to that format's own `div`, and the result keeps the format. For any
    /// mixed pair of formats, both sides are first converted to COO and the result is
    /// returned as COO.
    ///
    /// The result's sparsity pattern is decided by the format-specific `div`. With the
    /// CPU runtime, only positions stored in *both* operands are kept, so operands with
    /// disjoint patterns give an empty result.
    ///
    /// # Errors
    ///
    /// * [`Error::ShapeMismatch`] if the two shapes differ.
    /// * [`Error::DTypeMismatch`] if the two dtypes differ.
    /// * Any error raised while converting to COO or by the underlying `div`.
    pub fn div(&self, other: &SparseTensor<R>) -> Result<SparseTensor<R>>
    where
        R::Client: SparseOps<R>,
    {
        if self.shape() != other.shape() {
            return Err(Error::ShapeMismatch {
                expected: vec![self.shape()[0], self.shape()[1]],
                got: vec![other.shape()[0], other.shape()[1]],
            });
        }
        if self.dtype() != other.dtype() {
            return Err(Error::DTypeMismatch {
                lhs: self.dtype(),
                rhs: other.dtype(),
            });
        }
        match (self, other) {
            (SparseTensor::Coo(a), SparseTensor::Coo(b)) => Ok(SparseTensor::Coo(a.div(b)?)),
            (SparseTensor::Csr(a), SparseTensor::Csr(b)) => Ok(SparseTensor::Csr(a.div(b)?)),
            (SparseTensor::Csc(a), SparseTensor::Csc(b)) => Ok(SparseTensor::Csc(a.div(b)?)),
            // Formats differ: normalise both operands to COO so a single kernel applies.
            _ => {
                let coo_a = self.to_coo()?;
                let coo_b = other.to_coo()?;
                // Invariant: `to_coo()` always produces the `Coo` variant, so these
                // accessors cannot fail; a panic here indicates a bug in `to_coo`.
                let a = coo_a
                    .as_coo()
                    .expect("to_coo() must produce a COO-format tensor");
                let b = coo_b
                    .as_coo()
                    .expect("to_coo() must produce a COO-format tensor");
                Ok(SparseTensor::Coo(a.div(b)?))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::{CpuClient, CpuRuntime};
    use crate::tensor::Tensor;

    type CpuDevice = <CpuRuntime as Runtime>::Device;
    type CpuSparse = SparseTensor<CpuRuntime>;

    /// Row-major dense form of `diag(8, 35) / [[2, 5], [7, 7]]`.
    const DIAG_QUOTIENT: [f32; 4] = [4.0, 0.0, 0.0, 5.0];

    /// The default CPU device used by every test.
    fn cpu_device() -> CpuDevice {
        Default::default()
    }

    /// Consumes `t`, densifies it on `device` and returns its values in row-major order.
    fn dense_values(t: CpuSparse, device: &CpuDevice) -> Vec<f32> {
        t.to_dense(device).unwrap().to_vec()
    }

    /// `diag(8, 35)` stored as COO.
    fn diag_coo(device: &CpuDevice) -> CpuSparse {
        CpuSparse::from_coo_slices(
            &[0i64, 1],
            &[0i64, 1],
            &[8.0f32, 35.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// Fully populated `[[2, 5], [7, 7]]` stored as COO.
    fn full_coo(device: &CpuDevice) -> CpuSparse {
        CpuSparse::from_coo_slices(
            &[0i64, 0, 1, 1],
            &[0i64, 1, 0, 1],
            &[2.0f32, 5.0, 7.0, 7.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// `diag(8, 35)` stored as CSR.
    fn diag_csr(device: &CpuDevice) -> CpuSparse {
        CpuSparse::from_csr_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[8.0f32, 35.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// Fully populated `[[2, 5], [7, 7]]` stored as CSR (row-major values).
    fn full_csr(device: &CpuDevice) -> CpuSparse {
        CpuSparse::from_csr_slices(
            &[0i64, 2, 4],
            &[0i64, 1, 0, 1],
            &[2.0f32, 5.0, 7.0, 7.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// `diag(8, 35)` stored as CSC.
    fn diag_csc(device: &CpuDevice) -> CpuSparse {
        CpuSparse::from_csc_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[8.0f32, 35.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// Fully populated `[[2, 5], [7, 7]]` stored as CSC (column-major values).
    fn full_csc(device: &CpuDevice) -> CpuSparse {
        CpuSparse::from_csc_slices(
            &[0i64, 2, 4],
            &[0i64, 1, 0, 1],
            &[2.0f32, 7.0, 5.0, 7.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// CSR tensor with the upper-triangular pattern {(0,0), (0,1), (1,1)} holding `values`.
    fn upper_csr(values: &[f32], device: &CpuDevice) -> CpuSparse {
        CpuSparse::from_csr_slices(&[0i64, 2, 3], &[0i64, 1, 1], values, [2, 2], device)
            .unwrap()
    }

    #[test]
    fn test_div_coo_coo() {
        let device = cpu_device();
        let quotient = diag_coo(&device).div(&full_coo(&device)).unwrap();
        assert_eq!(quotient.nnz(), 2);
        assert_eq!(dense_values(quotient, &device), DIAG_QUOTIENT);
    }

    #[test]
    fn test_div_csr_csr() {
        let device = cpu_device();
        let quotient = diag_csr(&device).div(&full_csr(&device)).unwrap();
        assert_eq!(quotient.nnz(), 2);
        assert_eq!(dense_values(quotient, &device), DIAG_QUOTIENT);
    }

    #[test]
    fn test_div_csc_csc() {
        let device = cpu_device();
        let quotient = diag_csc(&device).div(&full_csc(&device)).unwrap();
        assert_eq!(quotient.nnz(), 2);
        assert_eq!(dense_values(quotient, &device), DIAG_QUOTIENT);
    }

    #[test]
    fn test_div_mixed_formats() {
        let device = cpu_device();
        let quotient = diag_csr(&device).div(&full_coo(&device)).unwrap();
        // Mixed operand formats always fall back to a COO result.
        assert!(matches!(quotient, SparseTensor::Coo(_)));
        assert_eq!(quotient.nnz(), 2);
        assert_eq!(dense_values(quotient, &device), DIAG_QUOTIENT);
    }

    #[test]
    fn test_div_disjoint() {
        let device = cpu_device();
        let lhs = CpuSparse::from_coo_slices(
            &[0i64, 1],
            &[0i64, 1],
            &[1.0f32, 3.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let rhs = CpuSparse::from_coo_slices(
            &[0i64, 1],
            &[1i64, 0],
            &[2.0f32, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();
        // No stored position is shared, so nothing survives the division.
        assert_eq!(lhs.div(&rhs).unwrap().nnz(), 0);
    }

    #[test]
    fn test_div_shape_mismatch() {
        let device = cpu_device();
        let lhs = CpuSparse::from_coo_slices(&[0i64], &[0i64], &[1.0f32], [2, 3], &device).unwrap();
        let rhs = CpuSparse::from_coo_slices(&[0i64], &[0i64], &[1.0f32], [3, 2], &device).unwrap();
        assert!(lhs.div(&rhs).is_err());
    }

    #[test]
    fn test_div_dtype_mismatch() {
        let device = cpu_device();
        let lhs = CpuSparse::from_coo_slices(&[0i64], &[0i64], &[1.0f32], [2, 2], &device).unwrap();
        let rhs = CpuSparse::from_coo_slices(&[0i64], &[0i64], &[1.0f64], [2, 2], &device).unwrap();
        assert!(lhs.div(&rhs).is_err());
    }

    #[test]
    fn test_div_from_dense() {
        let device = cpu_device();
        let client = CpuClient::new(device.clone());
        let numerator =
            Tensor::<CpuRuntime>::from_slice(&[10.0f32, 0.0, 0.0, 20.0], &[2, 2], &device);
        let denominator =
            Tensor::<CpuRuntime>::from_slice(&[2.0f32, 0.0, 5.0, 4.0], &[2, 2], &device);
        let lhs = SparseTensor::from_dense(&client, &numerator, 1e-10).unwrap();
        let rhs = SparseTensor::from_dense(&client, &denominator, 1e-10).unwrap();
        let quotient = lhs.div(&rhs).unwrap();
        assert_eq!(quotient.nnz(), 2);
        assert_eq!(dense_values(quotient, &device), [5.0, 0.0, 0.0, 5.0]);
    }

    #[test]
    fn test_div_self() {
        let device = cpu_device();
        let tensor = upper_csr(&[2.0, 3.0, 5.0], &device);
        let quotient = tensor.div(&tensor).unwrap();
        assert_eq!(quotient.nnz(), 3);
        assert_eq!(dense_values(quotient, &device), [1.0, 1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_div_by_ones() {
        let device = cpu_device();
        let tensor = upper_csr(&[2.0, 3.0, 5.0], &device);
        let ones = upper_csr(&[1.0; 3], &device);
        let quotient = tensor.div(&ones).unwrap();
        assert_eq!(quotient.nnz(), 3);
        assert_eq!(dense_values(quotient, &device), [2.0, 3.0, 0.0, 5.0]);
    }
}