/// Square sparse matrix stored in a CRS-like layout with the diagonal kept separately.
///
/// Row `i` owns the indices `row2idx[i]..row2idx[i + 1]` of `idx2col` (column of each
/// stored entry) and `idx2val` (its value). The diagonal entry `(i, i)` is stored in
/// `row2val[i]` rather than in the CRS arrays.
#[derive(Debug)]
pub struct Matrix<MAT> {
    /// Number of rows, which equals the number of columns (the matrix is square).
    pub num_blk: usize,
    /// Row pointer; has `num_blk + 1` entries once initialized.
    pub row2idx: Vec<usize>,
    /// Column index of each stored entry, grouped by row via `row2idx`.
    pub idx2col: Vec<usize>,
    /// Value of each stored entry, parallel to `idx2col`.
    pub idx2val: Vec<MAT>,
    /// Diagonal value of each row.
    pub row2val: Vec<MAT>,
}
13
14impl<MAT> Matrix<MAT>
15where
16 MAT: 'static
17 + num_traits::Zero
18 + std::default::Default
19 + std::ops::AddAssign + Copy
21 + std::fmt::Display,
22{
23 pub fn new() -> Self {
24 Matrix {
25 num_blk: 0,
26 row2idx: vec![0],
27 idx2col: Vec::<usize>::new(),
28 idx2val: Vec::<MAT>::new(),
29 row2val: Vec::<MAT>::new(),
30 }
31 }
32
33 pub fn clone(&self) -> Self {
34 Matrix {
35 num_blk: self.num_blk,
36 row2idx: self.row2idx.clone(),
37 idx2col: self.idx2col.clone(),
38 idx2val: self.idx2val.clone(),
39 row2val: self.row2val.clone(),
40 }
41 }
42
43 pub fn symbolic_initialization(&mut self, row2idx: &Vec<usize>, idx2col: &Vec<usize>) {
45 self.num_blk = row2idx.len() - 1;
46 self.row2idx = row2idx.clone();
47 self.idx2col = idx2col.clone();
48 let num_idx = self.row2idx[self.num_blk];
49 assert_eq!(num_idx, idx2col.len());
50 self.idx2val.resize_with(num_idx, Default::default);
51 self.row2val.resize_with(self.num_blk, Default::default);
52 }
53
54 pub fn set_zero(&mut self) {
56 assert_eq!(self.idx2val.len(), self.idx2col.len());
57 for m in self.row2val.iter_mut() {
58 m.set_zero()
59 }
60 for m in self.idx2val.iter_mut() {
61 m.set_zero()
62 }
63 }
64
65 pub fn merge(
67 &mut self,
68 node2row: &[usize],
69 node2col: &[usize],
70 emat: &[MAT],
71 merge_buffer: &mut Vec<usize>,
72 ) {
73 assert_eq!(emat.len(), node2row.len() * node2col.len());
74 merge_buffer.resize(self.num_blk, usize::MAX);
75 let col2idx = merge_buffer;
76 for inode in 0..node2row.len() {
77 let i_row = node2row[inode];
78 assert!(i_row < self.num_blk);
79 for ij_idx in self.row2idx[i_row]..self.row2idx[i_row + 1] {
80 assert!(ij_idx < self.idx2col.len());
81 let j_col = self.idx2col[ij_idx];
82 col2idx[j_col] = ij_idx;
83 }
84 for jnode in 0..node2col.len() {
85 let j_col = node2col[jnode];
86 assert!(j_col < self.num_blk);
87 if i_row == j_col {
88 self.row2val[i_row] += emat[inode * node2col.len() + jnode];
90 } else {
91 assert!(col2idx[j_col] < self.idx2col.len());
93 let ij_idx = col2idx[j_col];
94 assert_eq!(self.idx2col[ij_idx], j_col);
95 self.idx2val[ij_idx] += emat[inode * node2col.len() + jnode];
96 }
97 }
98 for ij_idx in self.row2idx[i_row]..self.row2idx[i_row + 1] {
99 assert!(ij_idx < self.idx2col.len());
100 let j_col = self.idx2col[ij_idx];
101 col2idx[j_col] = usize::MAX;
102 }
103 }
104 }
105}
106
107pub fn mult_vec<T>(y_vec: &mut Vec<T>, beta: T, alpha: T, a_mat: &Matrix<T>, x_vec: &Vec<T>)
111where
112 T: std::ops::MulAssign + std::ops::Mul<Output = T> + std::ops::AddAssign + 'static
116 + Copy, f32: num_traits::AsPrimitive<T>,
118{
119 assert_eq!(y_vec.len(), a_mat.num_blk);
120 for m in y_vec.iter_mut() {
121 *m *= beta;
122 }
123 for iblk in 0..a_mat.num_blk {
124 for icrs in a_mat.row2idx[iblk]..a_mat.row2idx[iblk + 1] {
125 assert!(icrs < a_mat.idx2col.len());
126 let jblk0 = a_mat.idx2col[icrs];
127 assert!(jblk0 < a_mat.num_blk);
128 y_vec[iblk] += alpha * a_mat.idx2val[icrs] * x_vec[jblk0];
129 }
130 y_vec[iblk] += alpha * a_mat.row2val[iblk] * x_vec[iblk];
131 }
132}
133
134pub fn mult_mat<T>(y_mat: &mut [T], beta: T, alpha: T, a_mat: &Matrix<T>, x_mat: &[T])
135where
136 T: std::ops::MulAssign + std::ops::Mul<Output = T> + std::ops::AddAssign + 'static
140 + Copy, f32: num_traits::AsPrimitive<T>,
142{
143 let num_row = a_mat.row2idx.len() - 1;
144 assert_eq!(y_mat.len(), x_mat.len());
145 let num_dim = y_mat.len() / num_row;
146 assert_eq!(y_mat.len(), num_dim * num_row);
147 for val_y in y_mat.iter_mut() {
148 *val_y *= beta;
149 }
150 for i_row in 0..num_row {
151 for idx in a_mat.row2idx[i_row]..a_mat.row2idx[i_row + 1] {
152 let j_col = a_mat.idx2col[idx];
153 for y in 0..num_dim {
154 y_mat[i_row * num_dim + y] +=
155 alpha * a_mat.idx2val[idx] * x_mat[j_col * num_dim + y];
156 }
157 }
158 for y in 0..num_dim {
159 y_mat[i_row * num_dim + y] += alpha * a_mat.row2val[i_row] * x_mat[i_row * num_dim + y];
160 }
161 }
162}
163
#[test]
fn test_scalar() {
    // 4x4 tridiagonal pattern. Only off-diagonal columns go into idx2col because the
    // diagonal is stored separately in row2val.
    let row2idx: Vec<usize> = vec![0, 1, 3, 5, 6];
    let idx2col: Vec<usize> = vec![1, 0, 2, 1, 3, 2];
    let mut sparse = Matrix::<f32>::new();
    sparse.symbolic_initialization(&row2idx, &idx2col);
    sparse.set_zero();
    {
        // Row-major 2x2 block coupling nodes 0 and 1.
        let emat = [2.0_f32, -1.0, -1.0, 2.0];
        let mut tmp_buffer = Vec::<usize>::new();
        sparse.merge(&[0, 1], &[0, 1], &emat, &mut tmp_buffer);
        // merge must hand the scratch buffer back fully reset for reuse.
        assert!(tmp_buffer.iter().all(|&v| v == usize::MAX));
    }
    assert_eq!(sparse.row2val, vec![2.0_f32, 2.0, 0.0, 0.0]);
    assert_eq!(sparse.idx2val, vec![-1.0_f32, -1.0, 0.0, 0.0, 0.0, 0.0]);

    let nblk = row2idx.len() - 1;
    let rhs: Vec<f32> = (1..=nblk).map(|i| i as f32).collect();
    let mut lhs = vec![0.0_f32; nblk];
    mult_vec(&mut lhs, 1.0, 1.0, &sparse, &rhs);
    assert_eq!(lhs, vec![0.0_f32, 3.0, 0.0, 0.0]);
}
182}