// opensrdk_linear_algebra/matrix/ss/mul.rs

1use super::SparseMatrix;
2use crate::number::Number;
3use std::{collections::HashMap, ops::Mul};
4
5fn mul<T>(lhs: &SparseMatrix<T>, rhs: &SparseMatrix<T>) -> SparseMatrix<T>
6where
7    T: Number,
8{
9    if lhs.cols != rhs.rows {
10        panic!("Dimension mismatch.");
11    }
12
13    let elems_orig = lhs
14        .elems
15        .iter()
16        .map(|(&(l_row, l_col), &l)| {
17            let elems_orig = rhs
18                .elems
19                .iter()
20                .filter(|(&(r_row, _r_col), &_r)| l_col == r_row)
21                .map(|(&(_r_row, r_col), &r)| {
22                    let elem = r * l;
23                    ((l_row, r_col), elem)
24                })
25                .collect::<Vec<((usize, usize), T)>>();
26            elems_orig
27        })
28        .collect::<Vec<Vec<((usize, usize), T)>>>()
29        .concat();
30
31    let elems_hash = elems_orig.clone().into_iter().collect::<HashMap<_, _>>();
32
33    let elems = elems_hash
34        .iter()
35        .map(|((row, col), _)| {
36            let mut elems_same = elems_orig.clone();
37            elems_same.retain(|((row_v, col_v), _value)| (row_v, col_v) == (row, col));
38            let value = elems_same
39                .iter()
40                .map(|((_row, _col), value)| *value)
41                .sum::<T>();
42            ((*row, *col), value)
43        })
44        .collect::<HashMap<_, _>>();
45
46    let new_matrix = SparseMatrix::from(lhs.rows, rhs.cols, elems);
47
48    new_matrix
49}
50
51impl<T> Mul<SparseMatrix<T>> for SparseMatrix<T>
52where
53    T: Number,
54{
55    type Output = SparseMatrix<T>;
56
57    fn mul(self, rhs: SparseMatrix<T>) -> Self::Output {
58        mul(&self, &rhs)
59    }
60}
61
62impl<T> Mul<&SparseMatrix<T>> for SparseMatrix<T>
63where
64    T: Number,
65{
66    type Output = SparseMatrix<T>;
67
68    fn mul(self, rhs: &SparseMatrix<T>) -> Self::Output {
69        mul(&self, rhs)
70    }
71}
72
73impl<T> Mul<SparseMatrix<T>> for &SparseMatrix<T>
74where
75    T: Number,
76{
77    type Output = SparseMatrix<T>;
78
79    fn mul(self, rhs: SparseMatrix<T>) -> Self::Output {
80        mul(self, &rhs)
81    }
82}
83
84impl<T> Mul<&SparseMatrix<T>> for &SparseMatrix<T>
85where
86    T: Number,
87{
88    type Output = SparseMatrix<T>;
89
90    fn mul(self, rhs: &SparseMatrix<T>) -> Self::Output {
91        mul(self, rhs)
92    }
93}
94
#[cfg(test)]
mod tests {
    use crate::*;

    /// Checks the sparse product against hand-computed values.
    ///
    /// a (3x2) = [[1, 0], [0, 0], [0, 2]] and b (2x2) = [[3, 0], [4, 0]], so
    /// a * b = [[3, 0], [0, 0], [8, 0]]: only (0, 0) = 1*3 and (2, 0) = 2*4
    /// receive any partial products.
    #[test]
    fn it_works() {
        let mut a = SparseMatrix::new(3, 2);
        a[(0, 0)] = 1.0;
        a[(2, 1)] = 2.0;
        println!("a {:#?}", a);
        let mut b = SparseMatrix::new(2, 2);
        b[(0, 0)] = 3.0;
        b[(1, 0)] = 4.0;
        println!("b {:#?}", b);
        let c = a * b;
        println!("c {:#?}", c);

        // Previously commented out, which left this test unable to fail.
        assert_eq!((c.rows, c.cols), (3, 2));
        assert_eq!(c[(0, 0)], 3.0);
        assert_eq!(c[(2, 0)], 8.0);
        // Only positions that received at least one partial product are stored.
        assert_eq!(c.elems.len(), 2);

        // The same product through the dense path, printed for comparison.
        let d = mat![
            1.0, 0.0;
            0.0, 0.0;
            0.0, 2.0
        ];
        let e = mat![
            3.0, 0.0;
            4.0, 0.0];
        println!("row {:#?}", d.rows());
        let f = d.dot(&e);
        println!("f {:#?}", f);
    }
}
126}