sparsetools/coo/std_ops.rs

1use crate::coo::Coo;
2use crate::csr::CSR;
3use crate::traits::{Integer, Scalar};
4use std::ops::{Add, Neg};
5
6#[opimps::impl_ops(Add)]
7fn add<I: Integer, T: Scalar>(self: Coo<I, T>, rhs: Coo<I, T>) -> CSR<I, T> {
8    assert_eq!(self.rows(), rhs.rows());
9    assert_eq!(self.cols(), rhs.cols());
10
11    let k = self.nnz();
12    let nnz = k + rhs.nnz();
13
14    let mut rowidx = Vec::with_capacity(nnz);
15    let mut colidx = Vec::with_capacity(nnz);
16    let mut values = Vec::with_capacity(nnz);
17
18    rowidx.extend(self.rowidx());
19    colidx.extend(self.colidx());
20    values.extend(self.values());
21
22    rowidx.extend(rhs.rowidx());
23    colidx.extend(rhs.colidx());
24    values.extend(rhs.values());
25
26    let a_mat = Coo::new(self.rows(), self.cols(), rowidx, colidx, values).unwrap();
27    a_mat.to_csr() // Duplicate entries are summed.
28}
29
30#[opimps::impl_uni_op(Neg)]
31fn neg<I: Integer, T: Scalar + Neg<Output = T>>(self: Coo<I, T>) -> Coo<I, T> {
32    Coo::new(
33        self.rows(),
34        self.cols(),
35        self.rowidx,
36        self.colidx,
37        self.values.iter().map(|&d| -d).collect(),
38    )
39    .unwrap()
40}