1use crate::csr::CsrMatrix;
7use crate::error::SparseError;
8use crate::ops::SparseOps;
9
/// Sparse matrix in Block Sparse Row (BSR) format.
///
/// Non-zero entries are stored as dense `block_size x block_size` tiles,
/// addressed with CSR-style offsets over block rows.
#[derive(Debug, Clone)]
pub struct BsrMatrix {
    // Number of block rows (logical row count = block_rows * block_size).
    block_rows: usize,
    // Number of block columns (logical col count = block_cols * block_size).
    block_cols: usize,
    // Edge length of each square dense block.
    block_size: usize,
    // CSR-style offsets over block rows: offsets[i]..offsets[i + 1] is the
    // range of stored blocks for block row i. Length is block_rows + 1.
    offsets: Vec<u32>,
    // Block-column index of each stored block, one entry per block.
    col_indices: Vec<u32>,
    // Dense block payloads, row-major within each block; length is
    // col_indices.len() * block_size * block_size.
    values: Vec<f32>,
}
30
31impl BsrMatrix {
32 pub fn new(
46 block_rows: usize,
47 block_cols: usize,
48 block_size: usize,
49 offsets: Vec<u32>,
50 col_indices: Vec<u32>,
51 values: Vec<f32>,
52 ) -> Result<Self, SparseError> {
53 if offsets.len() != block_rows + 1 {
54 return Err(SparseError::InvalidOffsetsLength {
55 actual: offsets.len(),
56 expected: block_rows + 1,
57 });
58 }
59 let nnz_blocks = col_indices.len();
60 let expected_vals = nnz_blocks * block_size * block_size;
61 if values.len() != expected_vals {
62 return Err(SparseError::LengthMismatch {
63 col_len: expected_vals,
64 val_len: values.len(),
65 });
66 }
67 Ok(Self {
68 block_rows,
69 block_cols,
70 block_size,
71 offsets,
72 col_indices,
73 values,
74 })
75 }
76
77 pub fn from_dense(data: &[f32], rows: usize, cols: usize, block_size: usize) -> Self {
82 let br = rows.div_ceil(block_size);
83 let bc = cols.div_ceil(block_size);
84
85 let mut offsets = vec![0u32; br + 1];
86 let mut col_indices = Vec::new();
87 let mut values = Vec::new();
88 let bs2 = block_size * block_size;
89
90 for bi in 0..br {
91 for bj in 0..bc {
92 let mut block = vec![0.0f32; bs2];
93 let mut has_nonzero = false;
94 for li in 0..block_size {
95 for lj in 0..block_size {
96 let gi = bi * block_size + li;
97 let gj = bj * block_size + lj;
98 if gi < rows && gj < cols {
99 let val = data[gi * cols + gj];
100 block[li * block_size + lj] = val;
101 if val != 0.0 {
102 has_nonzero = true;
103 }
104 }
105 }
106 }
107 if has_nonzero {
108 col_indices.push(bj as u32);
109 values.extend_from_slice(&block);
110 }
111 }
112 offsets[bi + 1] = col_indices.len() as u32;
113 }
114
115 Self {
116 block_rows: br,
117 block_cols: bc,
118 block_size,
119 offsets,
120 col_indices,
121 values,
122 }
123 }
124
125 pub fn to_csr(&self) -> Result<CsrMatrix<f32>, SparseError> {
131 let rows = self.block_rows * self.block_size;
132 let cols = self.block_cols * self.block_size;
133 let bs = self.block_size;
134 let bs2 = bs * bs;
135
136 let mut csr_offsets = vec![0u32; rows + 1];
137 let mut csr_cols = Vec::new();
138 let mut csr_vals = Vec::new();
139
140 for bi in 0..self.block_rows {
141 let blk_start = self.offsets[bi] as usize;
142 let blk_end = self.offsets[bi + 1] as usize;
143
144 for li in 0..bs {
145 let global_row = bi * bs + li;
146 if global_row >= rows {
147 break;
148 }
149 for blk_idx in blk_start..blk_end {
150 let bj = self.col_indices[blk_idx] as usize;
151 for lj in 0..bs {
152 let global_col = bj * bs + lj;
153 if global_col >= cols {
154 continue;
155 }
156 let val = self.values[blk_idx * bs2 + li * bs + lj];
157 if val != 0.0 {
158 csr_cols.push(global_col as u32);
159 csr_vals.push(val);
160 }
161 }
162 }
163 csr_offsets[global_row + 1] = csr_cols.len() as u32;
164 }
165 }
166
167 CsrMatrix::new(rows, cols, csr_offsets, csr_cols, csr_vals)
168 }
169
170 pub fn rows(&self) -> usize {
172 self.block_rows * self.block_size
173 }
174
175 pub fn cols(&self) -> usize {
177 self.block_cols * self.block_size
178 }
179
180 pub fn nnz_blocks(&self) -> usize {
182 self.col_indices.len()
183 }
184
185 pub fn block_size(&self) -> usize {
187 self.block_size
188 }
189}
190
191impl SparseOps for BsrMatrix {
192 fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
193 if x.len() != self.cols() {
194 return Err(SparseError::SpMVDimensionMismatch {
195 matrix_cols: self.cols(),
196 x_len: x.len(),
197 });
198 }
199 if y.len() != self.rows() {
200 return Err(SparseError::SpMVOutputDimensionMismatch {
201 matrix_rows: self.rows(),
202 y_len: y.len(),
203 });
204 }
205
206 let bs = self.block_size;
207 let bs2 = bs * bs;
208
209 for yi in y.iter_mut() {
211 *yi *= beta;
212 }
213
214 for bi in 0..self.block_rows {
216 let blk_start = self.offsets[bi] as usize;
217 let blk_end = self.offsets[bi + 1] as usize;
218
219 for blk_idx in blk_start..blk_end {
220 let bj = self.col_indices[blk_idx] as usize;
221 let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
222
223 for li in 0..bs {
224 let gi = bi * bs + li;
225 if gi >= y.len() {
226 break;
227 }
228 let mut sum = 0.0f32;
229 for lj in 0..bs {
230 let gj = bj * bs + lj;
231 if gj < x.len() {
232 sum += block[li * bs + lj] * x[gj];
233 }
234 }
235 y[gi] += alpha * sum;
236 }
237 }
238 }
239
240 Ok(())
241 }
242
243 fn spmm(
244 &self,
245 alpha: f32,
246 b: &[f32],
247 b_cols: usize,
248 beta: f32,
249 c: &mut [f32],
250 ) -> Result<(), SparseError> {
251 if b.len() != self.cols() * b_cols {
252 return Err(SparseError::SpMVDimensionMismatch {
253 matrix_cols: self.cols(),
254 x_len: b.len(),
255 });
256 }
257 if c.len() != self.rows() * b_cols {
258 return Err(SparseError::SpMVOutputDimensionMismatch {
259 matrix_rows: self.rows(),
260 y_len: c.len(),
261 });
262 }
263
264 let bs = self.block_size;
265 let bs2 = bs * bs;
266
267 for ci in c.iter_mut() {
269 *ci *= beta;
270 }
271
272 for bi in 0..self.block_rows {
274 let blk_start = self.offsets[bi] as usize;
275 let blk_end = self.offsets[bi + 1] as usize;
276
277 for blk_idx in blk_start..blk_end {
278 let bj = self.col_indices[blk_idx] as usize;
279 let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
280
281 for li in 0..bs {
282 let gi = bi * bs + li;
283 if gi >= self.rows() {
284 break;
285 }
286 for lj in 0..bs {
287 let gj = bj * bs + lj;
288 if gj >= self.cols() {
289 continue;
290 }
291 let a_val = alpha * block[li * bs + lj];
292 for k in 0..b_cols {
293 c[gi * b_cols + k] += a_val * b[gj * b_cols + k];
294 }
295 }
296 }
297 }
298 }
299
300 Ok(())
301 }
302}