1use std::ops::Mul;
2
3use crate::{scalar::Scalar, CscMatrix};
4
5impl<T: Scalar> Mul for &CscMatrix<T> {
6 type Output = CscMatrix<T>;
7
8 fn mul(self, rhs: Self) -> Self::Output {
9 assert_eq!(self.ncols(), rhs.nrows());
10
11 let (lhs, rhs) = (rhs.transpose(), self.transpose());
13
14 let mut colptr = Vec::with_capacity(rhs.ncols() + 1);
16 let cap = lhs.nnz() + rhs.nnz();
17 let mut rowind = Vec::with_capacity(cap);
18 let mut values = Vec::with_capacity(cap);
19
20 let mut set = vec![0; rhs.ncols()];
22 let mut vec = vec![T::zero(); rhs.ncols()];
23
24 let mut nz = 0;
26 for col in 0..rhs.ncols() {
27 colptr.push(nz);
28 for rhsptr in rhs.colptr[col]..rhs.colptr[col + 1] {
29 let rhsrow = rhs.rowind[rhsptr];
30 for lhsptr in lhs.colptr[rhsrow]..lhs.colptr[rhsrow + 1] {
31 let lhsrow = lhs.rowind[lhsptr];
32 if set[lhsrow] < col + 1 {
33 set[lhsrow] = col + 1;
34 rowind.push(lhsrow);
35 vec[lhsrow] = rhs.values[rhsptr] * lhs.values[lhsptr];
36 nz += 1;
37 } else {
38 vec[lhsrow] += rhs.values[rhsptr] * lhs.values[lhsptr];
39 }
40 }
41 }
42 for ptr in colptr[col]..nz {
43 let value = vec[rowind[ptr]];
44 values.push(value)
45 }
46 }
47 colptr.push(nz);
48
49 let output = CscMatrix {
51 nrows: lhs.nrows(),
52 ncols: rhs.ncols(),
53 colptr,
54 rowind,
55 values,
56 };
57
58 output.transpose()
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplies a 5x3 matrix by a 3x4 matrix and checks the shape, the exact
    /// CSC arrays of the product, and that the product's buffers carry no spare
    /// capacity.
    #[test]
    fn mul() {
        let (outer_rows, inner, outer_cols) = (5, 3, 4);

        // Left operand: 5x3.
        let left = CscMatrix::new(
            outer_rows,
            inner,
            vec![0, 3, 4, 6],
            vec![0, 1, 4, 3, 1, 2],
            vec![1.0, -5.0, 4.0, 3.0, 7.0, 2.0],
        );
        // Right operand: 3x4.
        let right = CscMatrix::new(
            inner,
            outer_cols,
            vec![0, 3, 4, 5, 6],
            vec![0, 1, 2, 2, 0, 1],
            vec![1.0, -5.0, 7.0, 3.0, -2.0, 4.0],
        );

        let product = &left * &right;

        assert_eq!(product.nrows, outer_rows);
        assert_eq!(product.ncols, outer_cols);
        assert_eq!(product.colptr, [0, 5, 7, 10, 11]);
        assert_eq!(product.rowind, [0, 1, 2, 3, 4, 1, 2, 0, 1, 4, 3]);
        assert_eq!(
            product.values,
            [1.0, 44.0, 14.0, -15.0, 4.0, 21.0, 6.0, -2.0, 10.0, -8.0, 12.0]
        );

        // The product's storage should be sized exactly, with no slack.
        let nnz = product.nnz();
        assert_eq!(product.colptr.capacity(), product.ncols() + 1);
        assert_eq!(product.rowind.capacity(), nnz);
        assert_eq!(product.values.capacity(), nnz);
    }
}