del_ls/
sparse_matrix_multiplication.rs1use num_traits::AsPrimitive;
2
/// Computes the sparsity pattern of the product `C = A * B` in CSR form.
///
/// `A` and `B` are given as CSR row pointers (`*_row2idx`, one more entry than rows) and
/// column indices (`*_idx2col`). When `a_is_square` / `b_is_square` is `true`, that operand
/// additionally has a diagonal that is *not* listed in its `idx2col`; the diagonal's
/// contribution to the pattern is added implicitly. When `c_is_square` is `true`, diagonal
/// entries of `C` are left out of the returned pattern.
///
/// Returns `(c_row2idx, c_idx2col)`. Each column appears at most once per row; columns
/// within a row are in first-encounter order, not sorted.
///
/// # Panics
/// Panics if `a_row2idx` is empty, or if an index is out of range for the slice it is used
/// with (e.g. a column of `B` that is `>= b_num_column`).
fn symbolic_multiplication(
    a_row2idx: &[usize],
    a_idx2col: &[usize],
    a_is_square: bool,
    b_row2idx: &[usize],
    b_idx2col: &[usize],
    b_is_square: bool,
    b_num_column: usize,
    c_is_square: bool,
) -> (Vec<usize>, Vec<usize>) {
    assert!(
        !a_row2idx.is_empty(),
        "a_row2idx must hold at least one row pointer"
    );
    let nrow_a = a_row2idx.len() - 1;
    let mut c_row2idx = Vec::with_capacity(nrow_a + 1);
    c_row2idx.push(0_usize);
    let mut c_idx2col: Vec<usize> = Vec::new();
    // `col2flag[j] == i` marks column `j` as already emitted for row `i`; using the row
    // index as the flag value means the array never needs clearing between rows.
    let mut col2flag = vec![usize::MAX; b_num_column];
    for (irow_a, a_bounds) in a_row2idx.windows(2).enumerate() {
        // Appends `jcol` to the current row unless it is a duplicate or an excluded diagonal.
        let mut emit = |jcol: usize| {
            if col2flag[jcol] == irow_a || (jcol == irow_a && c_is_square) {
                return;
            }
            col2flag[jcol] = irow_a;
            c_idx2col.push(jcol);
        };
        for &k in &a_idx2col[a_bounds[0]..a_bounds[1]] {
            // Off-diagonal A(i,k) times the listed entries of row k of B.
            for &jcol_b in &b_idx2col[b_row2idx[k]..b_row2idx[k + 1]] {
                emit(jcol_b);
            }
            if b_is_square {
                // A(i,k) times the implicit diagonal B(k,k) lands in column k.
                emit(k);
            }
        }
        if a_is_square {
            // Implicit diagonal A(i,i) times the listed entries of row i of B.
            for &jcol_b in &b_idx2col[b_row2idx[irow_a]..b_row2idx[irow_a + 1]] {
                emit(jcol_b);
            }
            if b_is_square {
                // A(i,i) * B(i,i) lands on the diagonal of C.
                emit(irow_a);
            }
        }
        c_row2idx.push(c_idx2col.len());
    }
    (c_row2idx, c_idx2col)
}
109
110pub fn mult_square_matrices<T>(
111 m0: &crate::sparse_square::Matrix<T>,
112 m1: &crate::sparse_square::Matrix<T>,
113) -> crate::sparse_square::Matrix<T>
114where
115 T: 'static
116 + Copy
117 + std::ops::Mul<Output = T>
118 + std::ops::AddAssign
119 + std::clone::Clone
120 + num_traits::Zero,
121 f32: AsPrimitive<T>,
122{
123 let num_blk = m0.num_blk;
124 assert_eq!(num_blk, m1.num_blk);
125 let (row2idx, idx2col) = symbolic_multiplication(
126 &m0.row2idx,
127 &m0.idx2col,
128 true,
129 &m1.row2idx,
130 &m1.idx2col,
131 true,
132 m1.num_blk,
133 true,
134 );
135
136 let mut idx2val = vec![T::zero(); idx2col.len()];
137 let mut row2val = vec![T::zero(); num_blk];
138
139 let mut col2idx = vec![usize::MAX; num_blk];
140 for irow0 in 0..m0.num_blk {
141 for idx0 in row2idx[irow0]..row2idx[irow0 + 1] {
142 let icol0 = idx2col[idx0];
143 col2idx[icol0] = idx0;
144 }
145 for m0_idx in m0.row2idx[irow0]..m0.row2idx[irow0 + 1] {
147 let k = m0.idx2col[m0_idx];
148 for m1_idx in m1.row2idx[k]..m1.row2idx[k + 1] {
149 let jcol0 = m1.idx2col[m1_idx];
150 if irow0 == jcol0 {
151 row2val[irow0] += m0.idx2val[m0_idx] * m1.idx2val[m1_idx];
152 } else {
153 let idx0 = col2idx[jcol0];
154 idx2val[idx0] += m0.idx2val[m0_idx] * m1.idx2val[m1_idx];
155 }
156 }
157 if irow0 == k {
158 row2val[irow0] += m0.idx2val[m0_idx] * m1.row2val[k];
159 } else {
160 let idx0 = col2idx[k];
161 idx2val[idx0] += m0.idx2val[m0_idx] * m1.row2val[k];
162 }
163 }
164 for m1_idx in m1.row2idx[irow0]..m1.row2idx[irow0 + 1] {
165 let jcol0 = m1.idx2col[m1_idx];
166 if irow0 == jcol0 {
167 row2val[irow0] += m0.row2val[irow0] * m1.idx2val[m1_idx];
168 } else {
169 let idx0 = col2idx[jcol0];
170 idx2val[idx0] += m0.row2val[irow0] * m1.idx2val[m1_idx];
171 }
172 }
173 row2val[irow0] += m0.row2val[irow0] * m1.row2val[irow0];
174 for idx0 in row2idx[irow0]..row2idx[irow0 + 1] {
176 let icol0 = idx2col[idx0];
177 col2idx[icol0] = usize::MAX;
178 }
179 }
180
181 crate::sparse_square::Matrix::<T> {
182 num_blk,
183 row2idx,
184 idx2col,
185 idx2val,
186 row2val,
187 }
188}