opensrdk_linear_algebra/matrix/ss/
mul.rs1use super::SparseMatrix;
2use crate::number::Number;
3use std::{collections::HashMap, ops::Mul};
4
5fn mul<T>(lhs: &SparseMatrix<T>, rhs: &SparseMatrix<T>) -> SparseMatrix<T>
6where
7 T: Number,
8{
9 if lhs.cols != rhs.rows {
10 panic!("Dimension mismatch.");
11 }
12
13 let elems_orig = lhs
14 .elems
15 .iter()
16 .map(|(&(l_row, l_col), &l)| {
17 let elems_orig = rhs
18 .elems
19 .iter()
20 .filter(|(&(r_row, _r_col), &_r)| l_col == r_row)
21 .map(|(&(_r_row, r_col), &r)| {
22 let elem = r * l;
23 ((l_row, r_col), elem)
24 })
25 .collect::<Vec<((usize, usize), T)>>();
26 elems_orig
27 })
28 .collect::<Vec<Vec<((usize, usize), T)>>>()
29 .concat();
30
31 let elems_hash = elems_orig.clone().into_iter().collect::<HashMap<_, _>>();
32
33 let elems = elems_hash
34 .iter()
35 .map(|((row, col), _)| {
36 let mut elems_same = elems_orig.clone();
37 elems_same.retain(|((row_v, col_v), _value)| (row_v, col_v) == (row, col));
38 let value = elems_same
39 .iter()
40 .map(|((_row, _col), value)| *value)
41 .sum::<T>();
42 ((*row, *col), value)
43 })
44 .collect::<HashMap<_, _>>();
45
46 let new_matrix = SparseMatrix::from(lhs.rows, rhs.cols, elems);
47
48 new_matrix
49}
50
51impl<T> Mul<SparseMatrix<T>> for SparseMatrix<T>
52where
53 T: Number,
54{
55 type Output = SparseMatrix<T>;
56
57 fn mul(self, rhs: SparseMatrix<T>) -> Self::Output {
58 mul(&self, &rhs)
59 }
60}
61
62impl<T> Mul<&SparseMatrix<T>> for SparseMatrix<T>
63where
64 T: Number,
65{
66 type Output = SparseMatrix<T>;
67
68 fn mul(self, rhs: &SparseMatrix<T>) -> Self::Output {
69 mul(&self, rhs)
70 }
71}
72
73impl<T> Mul<SparseMatrix<T>> for &SparseMatrix<T>
74where
75 T: Number,
76{
77 type Output = SparseMatrix<T>;
78
79 fn mul(self, rhs: SparseMatrix<T>) -> Self::Output {
80 mul(self, &rhs)
81 }
82}
83
84impl<T> Mul<&SparseMatrix<T>> for &SparseMatrix<T>
85where
86 T: Number,
87{
88 type Output = SparseMatrix<T>;
89
90 fn mul(self, rhs: &SparseMatrix<T>) -> Self::Output {
91 mul(self, rhs)
92 }
93}
94
#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn it_works() {
        let mut a = SparseMatrix::new(3, 2);
        a[(0, 0)] = 1.0;
        a[(2, 1)] = 2.0;
        let mut b = SparseMatrix::new(2, 2);
        b[(0, 0)] = 3.0;
        b[(1, 0)] = 4.0;
        let c = a * b;

        // a * b = [[3, 0], [0, 0], [8, 0]].
        assert_eq!((c.rows, c.cols), (3, 2));
        assert_eq!(c[(0, 0)], 3.0);
        assert_eq!(c[(2, 0)], 8.0);

        // Dense computation of the same product, printed for manual comparison.
        let d = mat![
            1.0, 0.0;
            0.0, 0.0;
            0.0, 2.0
        ];
        let e = mat![
            3.0, 0.0;
            4.0, 0.0];
        println!("row {:#?}", d.rows());
        let f = d.dot(&e);
        println!("f {:#?}", f);
    }

    #[test]
    fn accumulates_partial_products() {
        let mut a = SparseMatrix::new(1, 2);
        a[(0, 0)] = 1.0;
        a[(0, 1)] = 2.0;
        let mut b = SparseMatrix::new(2, 1);
        b[(0, 0)] = 3.0;
        b[(1, 0)] = 4.0;
        let c = &a * &b;

        // Two partial products land on (0, 0): 1 * 3 + 2 * 4.
        assert_eq!((c.rows, c.cols), (1, 1));
        assert_eq!(c[(0, 0)], 11.0);
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn panics_on_dimension_mismatch() {
        let a = SparseMatrix::<f64>::new(2, 3);
        let b = SparseMatrix::<f64>::new(2, 2);
        let _ = &a * &b;
    }
}