scirs2_sparse/
bsr.rs

//! Block Sparse Row (BSR) matrix format
//!
//! This module implements the BSR sparse matrix format. BSR stores dense
//! blocks in compressed-sparse-row order, which makes it efficient for
//! matrices whose non-zeros cluster into blocks.
5
6use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{SparseElement, Zero};
8
/// Block Sparse Row (BSR) matrix
///
/// The matrix is tiled into `r x c` dense blocks. Only the stored blocks are
/// kept, and they are indexed in compressed-sparse-row order over the
/// block grid (block rows x block columns).
pub struct BsrMatrix<T> {
    /// Logical number of rows in the matrix
    rows: usize,
    /// Logical number of columns in the matrix
    cols: usize,
    /// Dimensions `(r, c)` of every dense block
    block_size: (usize, usize),
    /// Number of block rows in the block grid
    block_rows: usize,
    /// Number of block columns in the block grid (kept for validation and
    /// bookkeeping; not read by the current methods)
    #[allow(dead_code)]
    block_cols: usize,
    /// Block-row pointers: block row `i` owns stored blocks
    /// `indptr[i]..indptr[i + 1]`
    indptr: Vec<usize>,
    /// Block-column index list for each stored block
    indices: Vec<Vec<usize>>,
    /// Dense contents of each stored block, one `Vec<Vec<T>>` per block
    data: Vec<Vec<Vec<T>>>,
}
32
33impl<T> BsrMatrix<T>
34where
35    T: Clone + Copy + Zero + std::cmp::PartialEq + SparseElement,
36{
37    /// Create a new BSR matrix
38    ///
39    /// # Arguments
40    ///
41    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
42    /// * `block_size` - Tuple containing the block dimensions (r, c)
43    ///
44    /// # Returns
45    ///
46    /// * A new empty BSR matrix
47    ///
48    /// # Examples
49    ///
50    /// ```
51    /// use scirs2_sparse::bsr::BsrMatrix;
52    ///
53    /// // Create a 6x6 sparse matrix with 2x2 blocks
54    /// let matrix = BsrMatrix::<f64>::new((6, 6), (2, 2)).unwrap();
55    /// ```
56    pub fn new(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
57        let (rows, cols) = shape;
58        let (r, c) = block_size;
59
60        if r == 0 || c == 0 {
61            return Err(SparseError::ValueError(
62                "Block dimensions must be positive".to_string(),
63            ));
64        }
65
66        // Calculate block dimensions
67        #[allow(clippy::manual_div_ceil)]
68        let block_rows = (rows + r - 1) / r; // Ceiling division
69        #[allow(clippy::manual_div_ceil)]
70        let block_cols = (cols + c - 1) / c; // Ceiling division
71
72        // Initialize empty BSR matrix
73        let data = Vec::new();
74        let indices = Vec::new();
75        let indptr = vec![0]; // Initial indptr
76
77        Ok(BsrMatrix {
78            rows,
79            cols,
80            block_size,
81            block_rows,
82            block_cols,
83            data,
84            indices,
85            indptr,
86        })
87    }
88
89    /// Create a BSR matrix from block data
90    ///
91    /// # Arguments
92    ///
93    /// * `data` - Block data (blocks stored row by row)
94    /// * `indices` - Column indices for each block
95    /// * `indptr` - Row pointers
96    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
97    /// * `block_size` - Tuple containing the block dimensions (r, c)
98    ///
99    /// # Returns
100    ///
101    /// * A new BSR matrix
102    pub fn from_blocks(
103        data: Vec<Vec<Vec<T>>>,
104        indices: Vec<Vec<usize>>,
105        indptr: Vec<usize>,
106        shape: (usize, usize),
107        block_size: (usize, usize),
108    ) -> SparseResult<Self> {
109        let (rows, cols) = shape;
110        let (r, c) = block_size;
111
112        if r == 0 || c == 0 {
113            return Err(SparseError::ValueError(
114                "Block dimensions must be positive".to_string(),
115            ));
116        }
117
118        // Calculate block dimensions
119        #[allow(clippy::manual_div_ceil)]
120        let block_rows = (rows + r - 1) / r; // Ceiling division
121        #[allow(clippy::manual_div_ceil)]
122        let block_cols = (cols + c - 1) / c; // Ceiling division
123
124        // Validate input
125        if indptr.len() != block_rows + 1 {
126            return Err(SparseError::DimensionMismatch {
127                expected: block_rows + 1,
128                found: indptr.len(),
129            });
130        }
131
132        if data.len() != indptr[block_rows] {
133            return Err(SparseError::DimensionMismatch {
134                expected: indptr[block_rows],
135                found: data.len(),
136            });
137        }
138
139        if indices.len() != data.len() {
140            return Err(SparseError::DimensionMismatch {
141                expected: data.len(),
142                found: indices.len(),
143            });
144        }
145
146        for block in data.iter() {
147            if block.len() != r {
148                return Err(SparseError::DimensionMismatch {
149                    expected: r,
150                    found: block.len(),
151                });
152            }
153
154            for row in block.iter() {
155                if row.len() != c {
156                    return Err(SparseError::DimensionMismatch {
157                        expected: c,
158                        found: row.len(),
159                    });
160                }
161            }
162        }
163
164        for &idx in indices.iter().flatten() {
165            if idx >= block_cols {
166                return Err(SparseError::ValueError(format!(
167                    "index {} out of bounds (max {})",
168                    idx,
169                    block_cols - 1
170                )));
171            }
172        }
173
174        Ok(BsrMatrix {
175            rows,
176            cols,
177            block_size,
178            block_rows,
179            block_cols,
180            data,
181            indices,
182            indptr,
183        })
184    }
185
186    /// Get the number of rows in the matrix
187    pub fn rows(&self) -> usize {
188        self.rows
189    }
190
191    /// Get the number of columns in the matrix
192    pub fn cols(&self) -> usize {
193        self.cols
194    }
195
196    /// Get the shape (dimensions) of the matrix
197    pub fn shape(&self) -> (usize, usize) {
198        (self.rows, self.cols)
199    }
200
201    /// Get the block size
202    pub fn block_size(&self) -> (usize, usize) {
203        self.block_size
204    }
205
206    /// Get immutable access to the row pointers (indptr) array
207    ///
208    /// The indptr array indicates where each block row starts in the indices
209    /// and data arrays. Specifically, block row `i` contains blocks
210    /// `indptr[i]..indptr[i+1]`.
211    pub fn indptr(&self) -> &[usize] {
212        &self.indptr
213    }
214
215    /// Get immutable access to the column indices array
216    ///
217    /// The indices array contains the block column indices for each stored block.
218    /// The indices for block row `i` are stored in
219    /// `indices[indptr[i]..indptr[i+1]]`.
220    pub fn indices(&self) -> &[Vec<usize>] {
221        &self.indices
222    }
223
224    /// Get mutable access to the block data array
225    ///
226    /// The data array contains the non-zero blocks of the matrix, stored in
227    /// row-major order. Each block is a 2D array (Vec<Vec<T>>) with dimensions
228    /// matching the block size of the matrix.
229    pub fn data_mut(&mut self) -> &mut [Vec<Vec<T>>] {
230        &mut self.data
231    }
232
233    /// Get the number of non-zero blocks in the matrix
234    pub fn nnz_blocks(&self) -> usize {
235        self.data.len()
236    }
237
238    /// Get the number of non-zero elements in the matrix
239    pub fn nnz(&self) -> usize {
240        // Count non-zeros in all blocks
241        let mut count = 0;
242
243        for block in &self.data {
244            for row in block {
245                for &val in row {
246                    if val != T::sparse_zero() {
247                        count += 1;
248                    }
249                }
250            }
251        }
252
253        count
254    }
255
256    /// Convert to dense matrix (as Vec<Vec<T>>)
257    pub fn to_dense(&self) -> Vec<Vec<T>>
258    where
259        T: Zero + Copy + SparseElement,
260    {
261        let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
262        let (r, c) = self.block_size;
263
264        for block_row in 0..self.block_rows {
265            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
266                let block_col = self.indices[k][0];
267                let block = &self.data[k];
268
269                // Copy block to dense matrix
270                for (i, block_row_data) in block.iter().enumerate().take(r) {
271                    let row = block_row * r + i;
272                    if row < self.rows {
273                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
274                            let col = block_col * c + j;
275                            if col < self.cols {
276                                result[row][col] = value;
277                            }
278                        }
279                    }
280                }
281            }
282        }
283
284        result
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bsr_create() {
        // An empty 6x6 matrix partitioned into 2x2 blocks.
        let empty = BsrMatrix::<f64>::new((6, 6), (2, 2)).expect("valid block size");

        assert_eq!(empty.shape(), (6, 6));
        assert_eq!(empty.block_size(), (2, 2));
        assert_eq!(empty.nnz_blocks(), 0);
        assert_eq!(empty.nnz(), 0);
    }

    #[test]
    fn test_bsr_from_blocks() {
        // Block-diagonal 4x4 matrix assembled from two dense 2x2 blocks:
        // [1 2 0 0]
        // [3 4 0 0]
        // [0 0 5 6]
        // [0 0 7 8]
        let upper_left = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let lower_right = vec![vec![5.0, 6.0], vec![7.0, 8.0]];

        let matrix = BsrMatrix::<f64>::from_blocks(
            vec![upper_left, lower_right],
            vec![vec![0], vec![1]], // block-column index of each stored block
            vec![0, 1, 2],          // exactly one block in each block row
            (4, 4),
            (2, 2),
        )
        .expect("well-formed block data");

        assert_eq!(matrix.shape(), (4, 4));
        assert_eq!(matrix.block_size(), (2, 2));
        assert_eq!(matrix.nnz_blocks(), 2);
        // Every entry inside the two stored blocks is non-zero.
        assert_eq!(matrix.nnz(), 8);

        let expected: Vec<Vec<f64>> = [
            [1.0, 2.0, 0.0, 0.0],
            [3.0, 4.0, 0.0, 0.0],
            [0.0, 0.0, 5.0, 6.0],
            [0.0, 0.0, 7.0, 8.0],
        ]
        .iter()
        .map(|row| row.to_vec())
        .collect();

        assert_eq!(matrix.to_dense(), expected);
    }
}