del_ls/
sparse_matrix_multiplication.rs

1use num_traits::AsPrimitive;
2
/// Non-zero pattern of sparse matrix multiplication (`c = a * b`).
///
/// Matrices are given in compressed-row storage (CRS): the column indices of row `i`
/// of a matrix `m` are `m_idx2col[m_row2idx[i]..m_row2idx[i + 1]]`.
///
/// * `a_row2idx`, `a_idx2col` - CRS pattern of `a`; `a` has `a_row2idx.len() - 1` rows
/// * `a_is_square` - matrix `a` has diagonal entry that is not listed in the CRS
/// * `b_row2idx`, `b_idx2col` - CRS pattern of `b`
/// * `b_is_square` - matrix `b` has diagonal entry that is not listed in the CRS
/// * `b_num_column` - number of columns of `b` (and therefore of `c`)
/// * `c_is_square` - if we exclude diagonal entry of matrix `c` from CRS
///
/// Returns `(c_row2idx, c_idx2col)`, the CRS pattern of `c` with one row per row of `a`.
/// Each column appears at most once per row. Columns within a row are listed in the
/// order they are discovered, not sorted.
///
/// # Panics
///
/// Panics if `a_row2idx` is empty. Also panics if any index in the inputs is out of
/// range for the given slices or for `b_num_column`.
fn symbolic_multiplication(
    a_row2idx: &[usize],
    a_idx2col: &[usize],
    a_is_square: bool,
    b_row2idx: &[usize],
    b_idx2col: &[usize],
    b_is_square: bool,
    b_num_column: usize,
    c_is_square: bool,
) -> (Vec<usize>, Vec<usize>) {
    assert!(
        !a_row2idx.is_empty(),
        "`a_row2idx` must contain at least one entry (the leading 0)"
    );
    let nrow_a = a_row2idx.len() - 1;
    let mut c_row2idx = Vec::with_capacity(nrow_a + 1);
    c_row2idx.push(0_usize);
    let mut c_idx2col: Vec<usize> = Vec::new();
    // `col2flag[j] == i` means column `j` has already been emitted for row `i` of `c`.
    // `usize::MAX` never equals a row index, so it marks "not seen yet".
    let mut col2flag = vec![usize::MAX; b_num_column];
    for irow_a in 0..nrow_a {
        // Emit column `jcol` for row `irow_a` unless it is a duplicate, or it is the
        // diagonal and the diagonal of `c` is stored outside the CRS.
        let mut add_column = |jcol: usize| {
            if col2flag[jcol] == irow_a || (jcol == irow_a && c_is_square) {
                return;
            }
            col2flag[jcol] = irow_a;
            c_idx2col.push(jcol);
        };
        for &k in &a_idx2col[a_row2idx[irow_a]..a_row2idx[irow_a + 1]] {
            // listed entry a(i, k) times the listed entries of row k of `b`
            for &jcol_b in &b_idx2col[b_row2idx[k]..b_row2idx[k + 1]] {
                add_column(jcol_b);
            }
            if b_is_square {
                // listed entry a(i, k) times the unlisted diagonal b(k, k)
                add_column(k);
            }
        }
        if a_is_square {
            // unlisted diagonal a(i, i) times the listed entries of row i of `b`
            for &jcol_b in &b_idx2col[b_row2idx[irow_a]..b_row2idx[irow_a + 1]] {
                add_column(jcol_b);
            }
            if b_is_square {
                // unlisted diagonal a(i, i) times unlisted diagonal b(i, i)
                add_column(irow_a);
            }
        }
        // End of row `irow_a` (= start of the next row) in `c_idx2col`.
        c_row2idx.push(c_idx2col.len());
    }
    (c_row2idx, c_idx2col)
}
109
110pub fn mult_square_matrices<T>(
111    m0: &crate::sparse_square::Matrix<T>,
112    m1: &crate::sparse_square::Matrix<T>,
113) -> crate::sparse_square::Matrix<T>
114where
115    T: 'static
116        + Copy
117        + std::ops::Mul<Output = T>
118        + std::ops::AddAssign
119        + std::clone::Clone
120        + num_traits::Zero,
121    f32: AsPrimitive<T>,
122{
123    let num_blk = m0.num_blk;
124    assert_eq!(num_blk, m1.num_blk);
125    let (row2idx, idx2col) = symbolic_multiplication(
126        &m0.row2idx,
127        &m0.idx2col,
128        true,
129        &m1.row2idx,
130        &m1.idx2col,
131        true,
132        m1.num_blk,
133        true,
134    );
135
136    let mut idx2val = vec![T::zero(); idx2col.len()];
137    let mut row2val = vec![T::zero(); num_blk];
138
139    let mut col2idx = vec![usize::MAX; num_blk];
140    for irow0 in 0..m0.num_blk {
141        for idx0 in row2idx[irow0]..row2idx[irow0 + 1] {
142            let icol0 = idx2col[idx0];
143            col2idx[icol0] = idx0;
144        }
145        // ----
146        for m0_idx in m0.row2idx[irow0]..m0.row2idx[irow0 + 1] {
147            let k = m0.idx2col[m0_idx];
148            for m1_idx in m1.row2idx[k]..m1.row2idx[k + 1] {
149                let jcol0 = m1.idx2col[m1_idx];
150                if irow0 == jcol0 {
151                    row2val[irow0] += m0.idx2val[m0_idx] * m1.idx2val[m1_idx];
152                } else {
153                    let idx0 = col2idx[jcol0];
154                    idx2val[idx0] += m0.idx2val[m0_idx] * m1.idx2val[m1_idx];
155                }
156            }
157            if irow0 == k {
158                row2val[irow0] += m0.idx2val[m0_idx] * m1.row2val[k];
159            } else {
160                let idx0 = col2idx[k];
161                idx2val[idx0] += m0.idx2val[m0_idx] * m1.row2val[k];
162            }
163        }
164        for m1_idx in m1.row2idx[irow0]..m1.row2idx[irow0 + 1] {
165            let jcol0 = m1.idx2col[m1_idx];
166            if irow0 == jcol0 {
167                row2val[irow0] += m0.row2val[irow0] * m1.idx2val[m1_idx];
168            } else {
169                let idx0 = col2idx[jcol0];
170                idx2val[idx0] += m0.row2val[irow0] * m1.idx2val[m1_idx];
171            }
172        }
173        row2val[irow0] += m0.row2val[irow0] * m1.row2val[irow0];
174        // ----
175        for idx0 in row2idx[irow0]..row2idx[irow0 + 1] {
176            let icol0 = idx2col[idx0];
177            col2idx[icol0] = usize::MAX;
178        }
179    }
180
181    crate::sparse_square::Matrix::<T> {
182        num_blk,
183        row2idx,
184        idx2col,
185        idx2val,
186        row2val,
187    }
188}