1use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{SparseElement, Zero};
8
/// A sparse matrix in Block Sparse Row (BSR) format.
///
/// Non-zero content is stored as dense `r x c` blocks placed on a grid of
/// `block_rows` block rows. Block rows are addressed CSR-style: the blocks of
/// block row `i` are `data[indptr[i]..indptr[i + 1]]`, and `indices[k][0]` is
/// the block column of stored block `k`.
pub struct BsrMatrix<T> {
    /// Row count of the logical (dense) matrix.
    rows: usize,
    /// Column count of the logical (dense) matrix.
    cols: usize,
    /// Shape `(r, c)` shared by every stored block.
    block_size: (usize, usize),
    /// Number of block rows: `rows / r`, rounded up.
    block_rows: usize,
    /// Number of block columns: `cols / c`, rounded up.
    // Only written by the constructors; no method reads it back yet, hence the allow.
    #[allow(dead_code)]
    block_cols: usize,
    /// Stored blocks; each is a `Vec` of block rows, each holding the row's
    /// values (`from_blocks` checks these are exactly `r` rows of `c` values).
    data: Vec<Vec<Vec<T>>>,
    /// Column information for each stored block, parallel to `data`.
    indices: Vec<Vec<usize>>,
    /// Block-row pointer into `data` / `indices`.
    indptr: Vec<usize>,
}
32
33impl<T> BsrMatrix<T>
34where
35 T: Clone + Copy + Zero + std::cmp::PartialEq + SparseElement,
36{
37 pub fn new(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
57 let (rows, cols) = shape;
58 let (r, c) = block_size;
59
60 if r == 0 || c == 0 {
61 return Err(SparseError::ValueError(
62 "Block dimensions must be positive".to_string(),
63 ));
64 }
65
66 #[allow(clippy::manual_div_ceil)]
68 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
70 let block_cols = (cols + c - 1) / c; let data = Vec::new();
74 let indices = Vec::new();
75 let indptr = vec![0]; Ok(BsrMatrix {
78 rows,
79 cols,
80 block_size,
81 block_rows,
82 block_cols,
83 data,
84 indices,
85 indptr,
86 })
87 }
88
89 pub fn from_blocks(
103 data: Vec<Vec<Vec<T>>>,
104 indices: Vec<Vec<usize>>,
105 indptr: Vec<usize>,
106 shape: (usize, usize),
107 block_size: (usize, usize),
108 ) -> SparseResult<Self> {
109 let (rows, cols) = shape;
110 let (r, c) = block_size;
111
112 if r == 0 || c == 0 {
113 return Err(SparseError::ValueError(
114 "Block dimensions must be positive".to_string(),
115 ));
116 }
117
118 #[allow(clippy::manual_div_ceil)]
120 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
122 let block_cols = (cols + c - 1) / c; if indptr.len() != block_rows + 1 {
126 return Err(SparseError::DimensionMismatch {
127 expected: block_rows + 1,
128 found: indptr.len(),
129 });
130 }
131
132 if data.len() != indptr[block_rows] {
133 return Err(SparseError::DimensionMismatch {
134 expected: indptr[block_rows],
135 found: data.len(),
136 });
137 }
138
139 if indices.len() != data.len() {
140 return Err(SparseError::DimensionMismatch {
141 expected: data.len(),
142 found: indices.len(),
143 });
144 }
145
146 for block in data.iter() {
147 if block.len() != r {
148 return Err(SparseError::DimensionMismatch {
149 expected: r,
150 found: block.len(),
151 });
152 }
153
154 for row in block.iter() {
155 if row.len() != c {
156 return Err(SparseError::DimensionMismatch {
157 expected: c,
158 found: row.len(),
159 });
160 }
161 }
162 }
163
164 for &idx in indices.iter().flatten() {
165 if idx >= block_cols {
166 return Err(SparseError::ValueError(format!(
167 "index {} out of bounds (max {})",
168 idx,
169 block_cols - 1
170 )));
171 }
172 }
173
174 Ok(BsrMatrix {
175 rows,
176 cols,
177 block_size,
178 block_rows,
179 block_cols,
180 data,
181 indices,
182 indptr,
183 })
184 }
185
186 pub fn rows(&self) -> usize {
188 self.rows
189 }
190
191 pub fn cols(&self) -> usize {
193 self.cols
194 }
195
196 pub fn shape(&self) -> (usize, usize) {
198 (self.rows, self.cols)
199 }
200
201 pub fn block_size(&self) -> (usize, usize) {
203 self.block_size
204 }
205
206 pub fn indptr(&self) -> &[usize] {
212 &self.indptr
213 }
214
215 pub fn indices(&self) -> &[Vec<usize>] {
221 &self.indices
222 }
223
224 pub fn data_mut(&mut self) -> &mut [Vec<Vec<T>>] {
230 &mut self.data
231 }
232
233 pub fn nnz_blocks(&self) -> usize {
235 self.data.len()
236 }
237
238 pub fn nnz(&self) -> usize {
240 let mut count = 0;
242
243 for block in &self.data {
244 for row in block {
245 for &val in row {
246 if val != T::sparse_zero() {
247 count += 1;
248 }
249 }
250 }
251 }
252
253 count
254 }
255
256 pub fn to_dense(&self) -> Vec<Vec<T>>
258 where
259 T: Zero + Copy + SparseElement,
260 {
261 let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
262 let (r, c) = self.block_size;
263
264 for block_row in 0..self.block_rows {
265 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
266 let block_col = self.indices[k][0];
267 let block = &self.data[k];
268
269 for (i, block_row_data) in block.iter().enumerate().take(r) {
271 let row = block_row * r + i;
272 if row < self.rows {
273 for (j, &value) in block_row_data.iter().enumerate().take(c) {
274 let col = block_col * c + j;
275 if col < self.cols {
276 result[row][col] = value;
277 }
278 }
279 }
280 }
281 }
282 }
283
284 result
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bsr_create() {
        let matrix = BsrMatrix::<f64>::new((6, 6), (2, 2)).unwrap();

        assert_eq!(matrix.shape(), (6, 6));
        assert_eq!(matrix.block_size(), (2, 2));
        assert_eq!(matrix.nnz_blocks(), 0);
        assert_eq!(matrix.nnz(), 0);
    }

    #[test]
    fn test_bsr_from_blocks() {
        let block1 = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let block2 = vec![vec![5.0, 6.0], vec![7.0, 8.0]];

        let data = vec![block1, block2];
        let indices = vec![vec![0], vec![1]];
        let indptr = vec![0, 1, 2];

        let matrix = BsrMatrix::from_blocks(data, indices, indptr, (4, 4), (2, 2)).unwrap();

        assert_eq!(matrix.shape(), (4, 4));
        assert_eq!(matrix.block_size(), (2, 2));
        assert_eq!(matrix.nnz_blocks(), 2);
        assert_eq!(matrix.nnz(), 8);

        let dense = matrix.to_dense();

        let expected = vec![
            vec![1.0, 2.0, 0.0, 0.0],
            vec![3.0, 4.0, 0.0, 0.0],
            vec![0.0, 0.0, 5.0, 6.0],
            vec![0.0, 0.0, 7.0, 8.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_bsr_zero_block_size_rejected() {
        assert!(matches!(
            BsrMatrix::<f64>::new((4, 4), (0, 2)),
            Err(SparseError::ValueError(_))
        ));
        assert!(matches!(
            BsrMatrix::<f64>::from_blocks(Vec::new(), Vec::new(), vec![0], (4, 4), (2, 0)),
            Err(SparseError::ValueError(_))
        ));
    }

    #[test]
    fn test_bsr_from_blocks_indptr_length_mismatch() {
        // A 4x4 matrix with 2x2 blocks has 2 block rows, so indptr needs 3 entries.
        let result = BsrMatrix::<f64>::from_blocks(Vec::new(), Vec::new(), vec![0, 0], (4, 4), (2, 2));
        assert!(matches!(
            result,
            Err(SparseError::DimensionMismatch {
                expected: 3,
                found: 2
            })
        ));
    }

    #[test]
    fn test_bsr_from_blocks_wrong_block_shape() {
        // Block has the right number of rows but each row is one value short.
        let data = vec![vec![vec![1.0], vec![2.0]]];
        let result = BsrMatrix::from_blocks(data, vec![vec![0]], vec![0, 1, 1], (4, 4), (2, 2));
        assert!(matches!(
            result,
            Err(SparseError::DimensionMismatch {
                expected: 2,
                found: 1
            })
        ));
    }

    #[test]
    fn test_bsr_from_blocks_index_out_of_bounds() {
        // Only block columns 0 and 1 exist for a 4x4 matrix with 2x2 blocks.
        let data = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]]];
        let result = BsrMatrix::from_blocks(data, vec![vec![2]], vec![0, 1, 1], (4, 4), (2, 2));
        assert!(matches!(result, Err(SparseError::ValueError(_))));
    }

    #[test]
    fn test_bsr_to_dense_trims_partial_blocks() {
        // 3x3 matrix with 2x2 blocks: the second block hangs off the bottom-right edge.
        let data = vec![
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![vec![5.0, 6.0], vec![7.0, 8.0]],
        ];
        let matrix =
            BsrMatrix::from_blocks(data, vec![vec![0], vec![1]], vec![0, 1, 2], (3, 3), (2, 2))
                .unwrap();

        assert_eq!(matrix.nnz_blocks(), 2);

        let expected = vec![
            vec![1.0, 2.0, 0.0],
            vec![3.0, 4.0, 0.0],
            vec![0.0, 0.0, 5.0],
        ];
        assert_eq!(matrix.to_dense(), expected);
    }
}