1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
use super::SparseMatrix;
use crate::number::Number;
use std::{collections::HashMap, ops::Mul};

/// Computes the sparse matrix product `slf * rhs`.
///
/// Entries are read from each operand's `elems` map, keyed by `(row, col)`.
/// The result has `slf.rows` rows and `rhs.cols` columns. It stores only
/// non-zero entries: a zero is any value equal to `T::default()`.
///
/// # Panics
///
/// Panics with `"Dimension mismatch."` if `slf.cols != rhs.rows`.
fn mul<T>(slf: &SparseMatrix<T>, rhs: &SparseMatrix<T>) -> SparseMatrix<T>
where
    T: Number,
{
    if slf.cols != rhs.rows {
        panic!("Dimension mismatch.");
    }

    // Group the right-hand entries by row once. Each left-hand entry (i, j)
    // then visits only row j of `rhs`, instead of filtering every entry of
    // `rhs` again (which was O(nnz(slf) * nnz(rhs))).
    let mut rhs_rows: HashMap<_, Vec<_>> = HashMap::new();
    for (&(j, k), &r) in rhs.elems.iter() {
        rhs_rows.entry(j).or_default().push((k, r));
    }

    let mut new_matrix = SparseMatrix::new(slf.rows, rhs.cols, HashMap::new());

    for (&(i, j), &s) in slf.elems.iter() {
        if let Some(row) = rhs_rows.get(&j) {
            for &(k, r) in row {
                let sr = s * r;
                // Skip zero products so they never create an entry.
                if sr == T::default() {
                    continue;
                }
                *new_matrix.elems.entry((i, k)).or_insert(T::default()) += sr;
            }
        }
    }

    // Non-zero partial products can still cancel out (e.g. 2 + -2). The check
    // above cannot see that, so drop any accumulated zeros. This keeps the
    // result consistent with never storing zero entries.
    new_matrix.elems.retain(|_, v| *v != T::default());

    new_matrix
}

/// Owned × owned sparse-matrix product (`a * b`).
///
/// Both operands are consumed. The free function `mul` does the work on
/// borrowed views of them.
///
/// # Panics
///
/// Panics with `"Dimension mismatch."` when the left operand's column count
/// differs from the right operand's row count.
impl<T: Number> Mul<SparseMatrix<T>> for SparseMatrix<T> {
    type Output = Self;

    fn mul(self, rhs: SparseMatrix<T>) -> SparseMatrix<T> {
        let (lhs_ref, rhs_ref) = (&self, &rhs);
        mul(lhs_ref, rhs_ref)
    }
}

/// Owned × borrowed sparse-matrix product (`a * &b`).
///
/// The left operand is consumed and the right one is only borrowed. The free
/// function `mul` does the work.
///
/// # Panics
///
/// Panics with `"Dimension mismatch."` when the left operand's column count
/// differs from the right operand's row count.
impl<T: Number> Mul<&SparseMatrix<T>> for SparseMatrix<T> {
    type Output = Self;

    fn mul(self, rhs: &SparseMatrix<T>) -> SparseMatrix<T> {
        let lhs_ref = &self;
        mul(lhs_ref, rhs)
    }
}

/// Borrowed × owned sparse-matrix product (`&a * b`).
///
/// The left operand is only borrowed and the right one is consumed. The free
/// function `mul` does the work.
///
/// # Panics
///
/// Panics with `"Dimension mismatch."` when the left operand's column count
/// differs from the right operand's row count.
impl<T: Number> Mul<SparseMatrix<T>> for &SparseMatrix<T> {
    type Output = SparseMatrix<T>;

    fn mul(self, rhs: SparseMatrix<T>) -> SparseMatrix<T> {
        let rhs_ref = &rhs;
        mul(self, rhs_ref)
    }
}

/// Borrowed × borrowed sparse-matrix product (`&a * &b`).
///
/// Neither operand is consumed. The references go straight to the free
/// function `mul`.
///
/// # Panics
///
/// Panics with `"Dimension mismatch."` when the left operand's column count
/// differs from the right operand's row count.
impl<T: Number> Mul<&SparseMatrix<T>> for &SparseMatrix<T> {
    type Output = SparseMatrix<T>;

    fn mul(self, rhs: &SparseMatrix<T>) -> SparseMatrix<T> {
        let (lhs_ref, rhs_ref) = (self, rhs);
        mul(lhs_ref, rhs_ref)
    }
}