Skip to main content

trueno_sparse/
bsr.rs

//! Block Sparse Row (BSR) format.
//!
//! Stores sparse matrices as blocks of dense sub-matrices, aligned on a
//! regular block grid. Efficient for finite element method (FEM) and other structured sparsity patterns.
5
6use crate::csr::CsrMatrix;
7use crate::error::SparseError;
8use crate::ops::SparseOps;
9
/// Block Sparse Row matrix.
///
/// A matrix of shape `(block_rows * block_size) × (block_cols * block_size)`,
/// where non-zero blocks are stored in CSR-of-blocks layout: `offsets`
/// delimits, per block row, a run of entries in `col_indices`, and each such
/// entry owns one dense `block_size × block_size` tile in `values`.
///
/// Equality (`PartialEq`) compares the dimensions and all three storage
/// arrays element-wise. Because the values are `f32`, a matrix containing
/// NaN never compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct BsrMatrix {
    /// Number of block rows.
    block_rows: usize,
    /// Number of block columns.
    block_cols: usize,
    /// Block dimension (blocks are block_size × block_size).
    block_size: usize,
    /// Row offsets for block CSR (length = block_rows + 1).
    ///
    /// Block row `i` owns the stored blocks `offsets[i]..offsets[i + 1]`.
    offsets: Vec<u32>,
    /// Block column index of each stored block (length = number of stored blocks).
    col_indices: Vec<u32>,
    /// Dense block values, stored row-major per block.
    /// Length = nnz_blocks * block_size * block_size.
    ///
    /// Element `(li, lj)` of stored block `k` lives at
    /// `k * block_size * block_size + li * block_size + lj`.
    values: Vec<f32>,
}
30
31impl BsrMatrix {
32    /// Create a new BSR matrix.
33    ///
34    /// # Arguments
35    ///
36    /// - `block_rows`, `block_cols`: number of block rows/columns
37    /// - `block_size`: dimension of each square block
38    /// - `offsets`: CSR-style row offsets for blocks
39    /// - `col_indices`: block column indices
40    /// - `values`: dense block data (row-major per block)
41    ///
42    /// # Errors
43    ///
44    /// Returns error if structure is invalid.
45    pub fn new(
46        block_rows: usize,
47        block_cols: usize,
48        block_size: usize,
49        offsets: Vec<u32>,
50        col_indices: Vec<u32>,
51        values: Vec<f32>,
52    ) -> Result<Self, SparseError> {
53        if offsets.len() != block_rows + 1 {
54            return Err(SparseError::InvalidOffsetsLength {
55                actual: offsets.len(),
56                expected: block_rows + 1,
57            });
58        }
59        let nnz_blocks = col_indices.len();
60        let expected_vals = nnz_blocks * block_size * block_size;
61        if values.len() != expected_vals {
62            return Err(SparseError::LengthMismatch {
63                col_len: expected_vals,
64                val_len: values.len(),
65            });
66        }
67        Ok(Self {
68            block_rows,
69            block_cols,
70            block_size,
71            offsets,
72            col_indices,
73            values,
74        })
75    }
76
77    /// Create BSR from a dense matrix.
78    ///
79    /// Pads the matrix if dimensions aren't divisible by block_size.
80    /// Only stores blocks with at least one non-zero element.
81    pub fn from_dense(data: &[f32], rows: usize, cols: usize, block_size: usize) -> Self {
82        let br = rows.div_ceil(block_size);
83        let bc = cols.div_ceil(block_size);
84
85        let mut offsets = vec![0u32; br + 1];
86        let mut col_indices = Vec::new();
87        let mut values = Vec::new();
88        let bs2 = block_size * block_size;
89
90        for bi in 0..br {
91            for bj in 0..bc {
92                let mut block = vec![0.0f32; bs2];
93                let mut has_nonzero = false;
94                for li in 0..block_size {
95                    for lj in 0..block_size {
96                        let gi = bi * block_size + li;
97                        let gj = bj * block_size + lj;
98                        if gi < rows && gj < cols {
99                            let val = data[gi * cols + gj];
100                            block[li * block_size + lj] = val;
101                            if val != 0.0 {
102                                has_nonzero = true;
103                            }
104                        }
105                    }
106                }
107                if has_nonzero {
108                    col_indices.push(bj as u32);
109                    values.extend_from_slice(&block);
110                }
111            }
112            offsets[bi + 1] = col_indices.len() as u32;
113        }
114
115        Self {
116            block_rows: br,
117            block_cols: bc,
118            block_size,
119            offsets,
120            col_indices,
121            values,
122        }
123    }
124
125    /// Convert to CSR format.
126    ///
127    /// # Errors
128    ///
129    /// Returns error if the internal conversion produces invalid CSR.
130    pub fn to_csr(&self) -> Result<CsrMatrix<f32>, SparseError> {
131        let rows = self.block_rows * self.block_size;
132        let cols = self.block_cols * self.block_size;
133        let bs = self.block_size;
134        let bs2 = bs * bs;
135
136        let mut csr_offsets = vec![0u32; rows + 1];
137        let mut csr_cols = Vec::new();
138        let mut csr_vals = Vec::new();
139
140        for bi in 0..self.block_rows {
141            let blk_start = self.offsets[bi] as usize;
142            let blk_end = self.offsets[bi + 1] as usize;
143
144            for li in 0..bs {
145                let global_row = bi * bs + li;
146                if global_row >= rows {
147                    break;
148                }
149                for blk_idx in blk_start..blk_end {
150                    let bj = self.col_indices[blk_idx] as usize;
151                    for lj in 0..bs {
152                        let global_col = bj * bs + lj;
153                        if global_col >= cols {
154                            continue;
155                        }
156                        let val = self.values[blk_idx * bs2 + li * bs + lj];
157                        if val != 0.0 {
158                            csr_cols.push(global_col as u32);
159                            csr_vals.push(val);
160                        }
161                    }
162                }
163                csr_offsets[global_row + 1] = csr_cols.len() as u32;
164            }
165        }
166
167        CsrMatrix::new(rows, cols, csr_offsets, csr_cols, csr_vals)
168    }
169
170    /// Total matrix rows.
171    pub fn rows(&self) -> usize {
172        self.block_rows * self.block_size
173    }
174
175    /// Total matrix columns.
176    pub fn cols(&self) -> usize {
177        self.block_cols * self.block_size
178    }
179
180    /// Number of non-zero blocks.
181    pub fn nnz_blocks(&self) -> usize {
182        self.col_indices.len()
183    }
184
185    /// Block size.
186    pub fn block_size(&self) -> usize {
187        self.block_size
188    }
189}
190
191impl SparseOps for BsrMatrix {
192    fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
193        if x.len() != self.cols() {
194            return Err(SparseError::SpMVDimensionMismatch {
195                matrix_cols: self.cols(),
196                x_len: x.len(),
197            });
198        }
199        if y.len() != self.rows() {
200            return Err(SparseError::SpMVOutputDimensionMismatch {
201                matrix_rows: self.rows(),
202                y_len: y.len(),
203            });
204        }
205
206        let bs = self.block_size;
207        let bs2 = bs * bs;
208
209        // y = beta * y
210        for yi in y.iter_mut() {
211            *yi *= beta;
212        }
213
214        // y += alpha * A * x
215        for bi in 0..self.block_rows {
216            let blk_start = self.offsets[bi] as usize;
217            let blk_end = self.offsets[bi + 1] as usize;
218
219            for blk_idx in blk_start..blk_end {
220                let bj = self.col_indices[blk_idx] as usize;
221                let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
222
223                for li in 0..bs {
224                    let gi = bi * bs + li;
225                    if gi >= y.len() {
226                        break;
227                    }
228                    let mut sum = 0.0f32;
229                    for lj in 0..bs {
230                        let gj = bj * bs + lj;
231                        if gj < x.len() {
232                            sum += block[li * bs + lj] * x[gj];
233                        }
234                    }
235                    y[gi] += alpha * sum;
236                }
237            }
238        }
239
240        Ok(())
241    }
242
243    fn spmm(
244        &self,
245        alpha: f32,
246        b: &[f32],
247        b_cols: usize,
248        beta: f32,
249        c: &mut [f32],
250    ) -> Result<(), SparseError> {
251        if b.len() != self.cols() * b_cols {
252            return Err(SparseError::SpMVDimensionMismatch {
253                matrix_cols: self.cols(),
254                x_len: b.len(),
255            });
256        }
257        if c.len() != self.rows() * b_cols {
258            return Err(SparseError::SpMVOutputDimensionMismatch {
259                matrix_rows: self.rows(),
260                y_len: c.len(),
261            });
262        }
263
264        let bs = self.block_size;
265        let bs2 = bs * bs;
266
267        // Scale C by beta
268        for ci in c.iter_mut() {
269            *ci *= beta;
270        }
271
272        // C += alpha * A * B
273        for bi in 0..self.block_rows {
274            let blk_start = self.offsets[bi] as usize;
275            let blk_end = self.offsets[bi + 1] as usize;
276
277            for blk_idx in blk_start..blk_end {
278                let bj = self.col_indices[blk_idx] as usize;
279                let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
280
281                for li in 0..bs {
282                    let gi = bi * bs + li;
283                    if gi >= self.rows() {
284                        break;
285                    }
286                    for lj in 0..bs {
287                        let gj = bj * bs + lj;
288                        if gj >= self.cols() {
289                            continue;
290                        }
291                        let a_val = alpha * block[li * bs + lj];
292                        for k in 0..b_cols {
293                            c[gi * b_cols + k] += a_val * b[gj * b_cols + k];
294                        }
295                    }
296                }
297            }
298        }
299
300        Ok(())
301    }
302}