spalinalg/csc/ops/
mul.rs

1use std::ops::Mul;
2
3use crate::{scalar::Scalar, CscMatrix};
4
5impl<T: Scalar> Mul for &CscMatrix<T> {
6    type Output = CscMatrix<T>;
7
8    fn mul(self, rhs: Self) -> Self::Output {
9        assert_eq!(self.ncols(), rhs.nrows());
10
11        // Transpose inputs
12        let (lhs, rhs) = (rhs.transpose(), self.transpose());
13
14        // Allocate output
15        let mut colptr = Vec::with_capacity(rhs.ncols() + 1);
16        let cap = lhs.nnz() + rhs.nnz();
17        let mut rowind = Vec::with_capacity(cap);
18        let mut values = Vec::with_capacity(cap);
19
20        // Allocate workspace
21        let mut set = vec![0; rhs.ncols()];
22        let mut vec = vec![T::zero(); rhs.ncols()];
23
24        // Multiply
25        let mut nz = 0;
26        for col in 0..rhs.ncols() {
27            colptr.push(nz);
28            for rhsptr in rhs.colptr[col]..rhs.colptr[col + 1] {
29                let rhsrow = rhs.rowind[rhsptr];
30                for lhsptr in lhs.colptr[rhsrow]..lhs.colptr[rhsrow + 1] {
31                    let lhsrow = lhs.rowind[lhsptr];
32                    if set[lhsrow] < col + 1 {
33                        set[lhsrow] = col + 1;
34                        rowind.push(lhsrow);
35                        vec[lhsrow] = rhs.values[rhsptr] * lhs.values[lhsptr];
36                        nz += 1;
37                    } else {
38                        vec[lhsrow] += rhs.values[rhsptr] * lhs.values[lhsptr];
39                    }
40                }
41            }
42            for ptr in colptr[col]..nz {
43                let value = vec[rowind[ptr]];
44                values.push(value)
45            }
46        }
47        colptr.push(nz);
48
49        // Construct matrix
50        let output = CscMatrix {
51            nrows: lhs.nrows(),
52            ncols: rhs.ncols(),
53            colptr,
54            rowind,
55            values,
56        };
57
58        // Transpose output
59        output.transpose()
60    }
61}
62
63#[cfg(test)]
64mod tests {
65    use super::*;
66
67    #[test]
68    fn mul() {
69        let lhs = CscMatrix::new(
70            5,
71            3,
72            vec![0, 3, 4, 6],
73            vec![0, 1, 4, 3, 1, 2],
74            vec![1.0, -5.0, 4.0, 3.0, 7.0, 2.0],
75        );
76        let rhs = CscMatrix::new(
77            3,
78            4,
79            vec![0, 3, 4, 5, 6],
80            vec![0, 1, 2, 2, 0, 1],
81            vec![1.0, -5.0, 7.0, 3.0, -2.0, 4.0],
82        );
83        let mat = &lhs * &rhs;
84        assert_eq!(mat.nrows, 5);
85        assert_eq!(mat.ncols, 4);
86        assert_eq!(mat.colptr, [0, 5, 7, 10, 11]);
87        assert_eq!(mat.rowind, [0, 1, 2, 3, 4, 1, 2, 0, 1, 4, 3]);
88        assert_eq!(
89            mat.values,
90            [1.0, 44.0, 14.0, -15.0, 4.0, 21.0, 6.0, -2.0, 10.0, -8.0, 12.0]
91        );
92        assert_eq!(mat.colptr.capacity(), mat.ncols() + 1);
93        assert_eq!(mat.rowind.capacity(), mat.nnz());
94        assert_eq!(mat.values.capacity(), mat.nnz());
95    }
96}