oxiblas_sparse/
bsr.rs

1//! Block Sparse Row (BSR) matrix format.
2//!
3//! BSR stores matrix data as dense blocks, using:
4//! - `data`: Array of dense r×c blocks
5//! - `indices`: Block column indices (like CSR but for block columns)
6//! - `indptr`: Block row pointers (like CSR but for block rows)
7//!
8//! For an m×n matrix with r×c blocks:
9//! - Number of block rows: mb = ceil(m/r)
10//! - Number of block columns: nb = ceil(n/c)
11//!
12//! # When to Use BSR
13//!
14//! BSR format is optimal for:
15//! - Block-structured matrices (FEM, structural mechanics)
16//! - Matrices with dense subblocks
17//! - When block size matches SIMD register width
18//! - Matrix-vector products with vectorized block operations
19//!
20//! BSR is NOT efficient for:
21//! - Matrices without block structure
22//! - Irregular sparsity patterns
23//! - Very small matrices
24
25use oxiblas_core::scalar::{Field, Scalar};
26
/// Errors reported while building or validating a [`BsrMatrix`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BsrError {
    /// A block dimension was rejected (for example, zero).
    InvalidBlockSize {
        /// Rows per block that were requested.
        block_rows: usize,
        /// Columns per block that were requested.
        block_cols: usize,
    },
    /// The matrix shape cannot be tiled by the requested block shape.
    IncompatibleDimensions {
        /// Row count of the matrix.
        nrows: usize,
        /// Column count of the matrix.
        ncols: usize,
        /// Rows per block.
        block_rows: usize,
        /// Columns per block.
        block_cols: usize,
    },
    /// The block row pointer array is malformed.
    InvalidIndptr {
        /// Value that was required.
        expected: usize,
        /// Value that was found.
        actual: usize,
    },
    /// `data` and `indices` do not describe the same number of blocks.
    DataIndicesMismatch {
        /// Length of the block data array.
        num_blocks: usize,
        /// Length of the block column index array.
        num_indices: usize,
    },
    /// A block column index is not below the number of block columns.
    InvalidBlockIndex {
        /// Offending block column index.
        index: usize,
        /// Number of block columns in the matrix.
        nb_cols: usize,
    },
    /// The block row pointers decrease somewhere.
    InvalidIndptrOrder,
    /// A block holds the wrong number of scalars.
    InvalidBlockData {
        /// Position of the block in the data array.
        block_idx: usize,
        /// Scalars required per block.
        expected: usize,
        /// Scalars actually stored in the block.
        actual: usize,
    },
}

impl core::fmt::Display for BsrError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every payload field is `usize` (Copy), so matching on `*self` binds by value.
        match *self {
            Self::InvalidBlockSize {
                block_rows,
                block_cols,
            } => write!(f, "Invalid block size: {block_rows}×{block_cols}"),
            Self::IncompatibleDimensions {
                nrows,
                ncols,
                block_rows,
                block_cols,
            } => write!(
                f,
                "Matrix {nrows}×{ncols} incompatible with {block_rows}×{block_cols} blocks"
            ),
            Self::InvalidIndptr { expected, actual } => write!(
                f,
                "Invalid indptr length: expected {expected}, got {actual}"
            ),
            Self::DataIndicesMismatch {
                num_blocks,
                num_indices,
            } => write!(f, "Mismatch: {num_blocks} blocks but {num_indices} indices"),
            Self::InvalidBlockIndex { index, nb_cols } => write!(
                f,
                "Block column index {index} out of bounds (nb_cols={nb_cols})"
            ),
            Self::InvalidIndptrOrder => f.write_str("Indptr must be monotonically increasing"),
            Self::InvalidBlockData {
                block_idx,
                expected,
                actual,
            } => write!(
                f,
                "Block {block_idx}: expected {expected} elements, got {actual}"
            ),
        }
    }
}

impl std::error::Error for BsrError {}
138
139/// A dense block stored in row-major order.
140#[derive(Debug, Clone)]
141pub struct DenseBlock<T: Scalar> {
142    /// Block data in row-major order.
143    data: Vec<T>,
144    /// Number of rows in block.
145    rows: usize,
146    /// Number of columns in block.
147    cols: usize,
148}
149
150impl<T: Scalar + Clone> DenseBlock<T> {
151    /// Creates a new dense block.
152    pub fn new(rows: usize, cols: usize, data: Vec<T>) -> Self {
153        debug_assert_eq!(data.len(), rows * cols);
154        Self { data, rows, cols }
155    }
156
157    /// Creates a zero block.
158    pub fn zeros(rows: usize, cols: usize) -> Self
159    where
160        T: Field,
161    {
162        Self {
163            data: vec![T::zero(); rows * cols],
164            rows,
165            cols,
166        }
167    }
168
169    /// Gets element at (i, j) within the block.
170    #[inline]
171    pub fn get(&self, i: usize, j: usize) -> &T {
172        &self.data[i * self.cols + j]
173    }
174
175    /// Gets mutable element at (i, j) within the block.
176    #[inline]
177    pub fn get_mut(&mut self, i: usize, j: usize) -> &mut T {
178        &mut self.data[i * self.cols + j]
179    }
180
181    /// Returns block dimensions.
182    #[inline]
183    pub fn shape(&self) -> (usize, usize) {
184        (self.rows, self.cols)
185    }
186
187    /// Returns the data as a slice.
188    #[inline]
189    pub fn data(&self) -> &[T] {
190        &self.data
191    }
192
193    /// Returns mutable data slice.
194    #[inline]
195    pub fn data_mut(&mut self) -> &mut [T] {
196        &mut self.data
197    }
198
199    /// Block matrix-vector product: y += A * x.
200    pub fn matvec_add(&self, x: &[T], y: &mut [T])
201    where
202        T: Field,
203    {
204        for i in 0..self.rows {
205            for j in 0..self.cols {
206                y[i] = y[i].clone() + self.get(i, j).clone() * x[j].clone();
207            }
208        }
209    }
210
211    /// Scales the block by a scalar.
212    pub fn scale(&mut self, alpha: T) {
213        for val in &mut self.data {
214            *val = val.clone() * alpha.clone();
215        }
216    }
217
218    /// Returns the Frobenius norm squared.
219    pub fn frobenius_norm_sq(&self) -> T
220    where
221        T: Field,
222    {
223        self.data
224            .iter()
225            .fold(T::zero(), |acc, val| acc + val.clone() * val.clone())
226    }
227}
228
229/// Block Sparse Row matrix format.
230///
231/// Efficient for:
232/// - Block-structured problems (FEM, etc.)
233/// - Vectorized block operations
234/// - Dense subblocks
235///
236/// # Storage
237///
238/// Stores sparse matrices as a collection of dense blocks arranged in CSR-like
239/// structure. Each block is r×c dense matrix.
240///
241/// # Example
242///
243/// ```
244/// use oxiblas_sparse::{BsrMatrix, DenseBlock};
245///
246/// // 4×4 matrix with 2×2 blocks:
247/// // [1 2 | 0 0]
248/// // [3 4 | 0 0]
249/// // [----+----]
250/// // [0 0 | 5 6]
251/// // [0 0 | 7 8]
252/// let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
253/// let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
254///
255/// let bsr = BsrMatrix::new(
256///     4, 4,           // matrix dimensions
257///     2, 2,           // block dimensions
258///     vec![0, 1, 2],  // indptr (2 block rows)
259///     vec![0, 1],     // indices (block columns)
260///     vec![block1, block2],
261/// ).unwrap();
262///
263/// assert_eq!(bsr.nblocks(), 2);
264/// ```
265#[derive(Debug, Clone)]
266pub struct BsrMatrix<T: Scalar> {
267    /// Number of matrix rows.
268    nrows: usize,
269    /// Number of matrix columns.
270    ncols: usize,
271    /// Block row size.
272    block_rows: usize,
273    /// Block column size.
274    block_cols: usize,
275    /// Number of block rows.
276    mb: usize,
277    /// Number of block columns.
278    nb: usize,
279    /// Row pointers for blocks.
280    indptr: Vec<usize>,
281    /// Block column indices.
282    indices: Vec<usize>,
283    /// Dense blocks.
284    data: Vec<DenseBlock<T>>,
285}
286
287impl<T: Scalar + Clone> BsrMatrix<T> {
288    /// Creates a new BSR matrix from raw components.
289    ///
290    /// # Arguments
291    ///
292    /// * `nrows` - Number of matrix rows
293    /// * `ncols` - Number of matrix columns
294    /// * `block_rows` - Number of rows per block
295    /// * `block_cols` - Number of columns per block
296    /// * `indptr` - Block row pointers (length mb + 1)
297    /// * `indices` - Block column indices
298    /// * `data` - Dense blocks
299    ///
300    /// # Errors
301    ///
302    /// Returns an error if the input is invalid.
303    pub fn new(
304        nrows: usize,
305        ncols: usize,
306        block_rows: usize,
307        block_cols: usize,
308        indptr: Vec<usize>,
309        indices: Vec<usize>,
310        data: Vec<DenseBlock<T>>,
311    ) -> Result<Self, BsrError> {
312        // Validate block size
313        if block_rows == 0 || block_cols == 0 {
314            return Err(BsrError::InvalidBlockSize {
315                block_rows,
316                block_cols,
317            });
318        }
319
320        // Calculate number of block rows/cols
321        let mb = nrows.div_ceil(block_rows);
322        let nb = ncols.div_ceil(block_cols);
323
324        // Validate indptr length
325        if indptr.len() != mb + 1 {
326            return Err(BsrError::InvalidIndptr {
327                expected: mb + 1,
328                actual: indptr.len(),
329            });
330        }
331
332        // Validate indptr is monotonically increasing
333        for i in 1..indptr.len() {
334            if indptr[i] < indptr[i - 1] {
335                return Err(BsrError::InvalidIndptrOrder);
336            }
337        }
338
339        // Validate data and indices count
340        let nnz_blocks = data.len();
341        if indices.len() != nnz_blocks {
342            return Err(BsrError::DataIndicesMismatch {
343                num_blocks: nnz_blocks,
344                num_indices: indices.len(),
345            });
346        }
347
348        // Validate indptr[mb] equals nnz_blocks
349        if indptr[mb] != nnz_blocks {
350            return Err(BsrError::InvalidIndptr {
351                expected: nnz_blocks,
352                actual: indptr[mb],
353            });
354        }
355
356        // Validate block column indices
357        for &idx in &indices {
358            if idx >= nb {
359                return Err(BsrError::InvalidBlockIndex {
360                    index: idx,
361                    nb_cols: nb,
362                });
363            }
364        }
365
366        // Validate block sizes
367        let block_size = block_rows * block_cols;
368        for (i, block) in data.iter().enumerate() {
369            if block.data.len() != block_size {
370                return Err(BsrError::InvalidBlockData {
371                    block_idx: i,
372                    expected: block_size,
373                    actual: block.data.len(),
374                });
375            }
376        }
377
378        Ok(Self {
379            nrows,
380            ncols,
381            block_rows,
382            block_cols,
383            mb,
384            nb,
385            indptr,
386            indices,
387            data,
388        })
389    }
390
391    /// Creates a BSR matrix without validation (unsafe but faster).
392    ///
393    /// # Safety
394    ///
395    /// The caller must ensure all invariants hold.
396    #[inline]
397    pub unsafe fn new_unchecked(
398        nrows: usize,
399        ncols: usize,
400        block_rows: usize,
401        block_cols: usize,
402        indptr: Vec<usize>,
403        indices: Vec<usize>,
404        data: Vec<DenseBlock<T>>,
405    ) -> Self {
406        let mb = nrows.div_ceil(block_rows);
407        let nb = ncols.div_ceil(block_cols);
408        Self {
409            nrows,
410            ncols,
411            block_rows,
412            block_cols,
413            mb,
414            nb,
415            indptr,
416            indices,
417            data,
418        }
419    }
420
421    /// Creates an empty BSR matrix with given dimensions.
422    pub fn zeros(nrows: usize, ncols: usize, block_rows: usize, block_cols: usize) -> Self {
423        let mb = nrows.div_ceil(block_rows);
424        Self {
425            nrows,
426            ncols,
427            block_rows,
428            block_cols,
429            mb,
430            nb: ncols.div_ceil(block_cols),
431            indptr: vec![0; mb + 1],
432            indices: Vec::new(),
433            data: Vec::new(),
434        }
435    }
436
437    /// Creates an identity matrix in BSR format.
438    pub fn eye(n: usize, block_size: usize) -> Self
439    where
440        T: Field,
441    {
442        let mb = n.div_ceil(block_size);
443        let mut indptr = Vec::with_capacity(mb + 1);
444        let mut indices = Vec::with_capacity(mb);
445        let mut data = Vec::with_capacity(mb);
446
447        indptr.push(0);
448
449        for bi in 0..mb {
450            indices.push(bi);
451
452            // Create identity block
453            let mut block_data = vec![T::zero(); block_size * block_size];
454            for i in 0..block_size {
455                let global_row = bi * block_size + i;
456                if global_row < n {
457                    block_data[i * block_size + i] = T::one();
458                }
459            }
460            data.push(DenseBlock::new(block_size, block_size, block_data));
461
462            indptr.push(data.len());
463        }
464
465        Self {
466            nrows: n,
467            ncols: n,
468            block_rows: block_size,
469            block_cols: block_size,
470            mb,
471            nb: mb,
472            indptr,
473            indices,
474            data,
475        }
476    }
477
478    /// Returns the number of matrix rows.
479    #[inline]
480    pub fn nrows(&self) -> usize {
481        self.nrows
482    }
483
484    /// Returns the number of matrix columns.
485    #[inline]
486    pub fn ncols(&self) -> usize {
487        self.ncols
488    }
489
490    /// Returns the shape (nrows, ncols).
491    #[inline]
492    pub fn shape(&self) -> (usize, usize) {
493        (self.nrows, self.ncols)
494    }
495
496    /// Returns the block dimensions (block_rows, block_cols).
497    #[inline]
498    pub fn block_shape(&self) -> (usize, usize) {
499        (self.block_rows, self.block_cols)
500    }
501
502    /// Returns the number of block rows.
503    #[inline]
504    pub fn nblock_rows(&self) -> usize {
505        self.mb
506    }
507
508    /// Returns the number of block columns.
509    #[inline]
510    pub fn nblock_cols(&self) -> usize {
511        self.nb
512    }
513
514    /// Returns the number of non-zero blocks.
515    #[inline]
516    pub fn nblocks(&self) -> usize {
517        self.data.len()
518    }
519
520    /// Returns the number of non-zero scalar elements.
521    ///
522    /// Note: This counts actual non-zeros within blocks, not stored values.
523    pub fn nnz(&self) -> usize
524    where
525        T: Field,
526    {
527        let eps = <T as Scalar>::epsilon();
528        let mut count = 0;
529
530        for block in &self.data {
531            for val in block.data() {
532                if Scalar::abs(val.clone()) > eps {
533                    count += 1;
534                }
535            }
536        }
537
538        count
539    }
540
541    /// Returns the total stored values.
542    #[inline]
543    pub fn nstored(&self) -> usize {
544        self.data.len() * self.block_rows * self.block_cols
545    }
546
547    /// Returns the block row pointers.
548    #[inline]
549    pub fn indptr(&self) -> &[usize] {
550        &self.indptr
551    }
552
553    /// Returns the block column indices.
554    #[inline]
555    pub fn indices(&self) -> &[usize] {
556        &self.indices
557    }
558
559    /// Returns the block data.
560    #[inline]
561    pub fn data(&self) -> &[DenseBlock<T>] {
562        &self.data
563    }
564
565    /// Returns mutable block data.
566    #[inline]
567    pub fn data_mut(&mut self) -> &mut [DenseBlock<T>] {
568        &mut self.data
569    }
570
571    /// Gets the block at block position (bi, bj), if present.
572    pub fn get_block(&self, bi: usize, bj: usize) -> Option<&DenseBlock<T>> {
573        if bi >= self.mb || bj >= self.nb {
574            return None;
575        }
576
577        let start = self.indptr[bi];
578        let end = self.indptr[bi + 1];
579
580        for k in start..end {
581            if self.indices[k] == bj {
582                return Some(&self.data[k]);
583            }
584        }
585
586        None
587    }
588
589    /// Gets the scalar value at (row, col).
590    pub fn get(&self, row: usize, col: usize) -> Option<T>
591    where
592        T: Field,
593    {
594        if row >= self.nrows || col >= self.ncols {
595            return None;
596        }
597
598        let bi = row / self.block_rows;
599        let bj = col / self.block_cols;
600        let local_i = row % self.block_rows;
601        let local_j = col % self.block_cols;
602
603        self.get_block(bi, bj)
604            .map(|block| block.get(local_i, local_j).clone())
605    }
606
607    /// Gets the scalar value at (row, col), returning zero if not present.
608    pub fn get_or_zero(&self, row: usize, col: usize) -> T
609    where
610        T: Field,
611    {
612        self.get(row, col).unwrap_or_else(T::zero)
613    }
614
615    /// Returns an iterator over non-zero blocks as (block_row, block_col, &block).
616    pub fn block_iter(&self) -> impl Iterator<Item = (usize, usize, &DenseBlock<T>)> + '_ {
617        (0..self.mb).flat_map(move |bi| {
618            let start = self.indptr[bi];
619            let end = self.indptr[bi + 1];
620
621            (start..end).map(move |k| (bi, self.indices[k], &self.data[k]))
622        })
623    }
624
625    /// Returns an iterator over all non-zero scalars as (row, col, value).
626    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, T)> + '_
627    where
628        T: Field,
629    {
630        let eps = <T as Scalar>::epsilon();
631        let br = self.block_rows;
632        let bc = self.block_cols;
633        let nrows = self.nrows;
634        let ncols = self.ncols;
635
636        self.block_iter().flat_map(move |(bi, bj, block)| {
637            let base_row = bi * br;
638            let base_col = bj * bc;
639
640            (0..br).flat_map(move |i| {
641                (0..bc).filter_map(move |j| {
642                    let global_row = base_row + i;
643                    let global_col = base_col + j;
644
645                    if global_row < nrows && global_col < ncols {
646                        let val = block.get(i, j).clone();
647                        if Scalar::abs(val.clone()) > eps {
648                            return Some((global_row, global_col, val));
649                        }
650                    }
651                    None
652                })
653            })
654        })
655    }
656
657    /// Matrix-vector product: y = A * x.
658    pub fn matvec(&self, x: &[T], y: &mut [T])
659    where
660        T: Field,
661    {
662        assert_eq!(x.len(), self.ncols, "x length must equal ncols");
663        assert_eq!(y.len(), self.nrows, "y length must equal nrows");
664
665        // Initialize y to zero
666        for yi in y.iter_mut() {
667            *yi = T::zero();
668        }
669
670        // Process each block row
671        for bi in 0..self.mb {
672            let start = self.indptr[bi];
673            let end = self.indptr[bi + 1];
674            let row_start = bi * self.block_rows;
675            let row_end = (row_start + self.block_rows).min(self.nrows);
676
677            for k in start..end {
678                let bj = self.indices[k];
679                let block = &self.data[k];
680                let col_start = bj * self.block_cols;
681                let col_end = (col_start + self.block_cols).min(self.ncols);
682
683                // Block matrix-vector product
684                for (i, yi) in y[row_start..row_end].iter_mut().enumerate() {
685                    for j in 0..(col_end - col_start) {
686                        *yi = yi.clone() + block.get(i, j).clone() * x[col_start + j].clone();
687                    }
688                }
689            }
690        }
691    }
692
693    /// Matrix-vector product returning a new vector.
694    pub fn mul_vec(&self, x: &[T]) -> Vec<T>
695    where
696        T: Field,
697    {
698        let mut y = vec![T::zero(); self.nrows];
699        self.matvec(x, &mut y);
700        y
701    }
702
703    /// Converts to CSR format.
704    pub fn to_csr(&self) -> crate::csr::CsrMatrix<T>
705    where
706        T: Field,
707    {
708        let eps = <T as Scalar>::epsilon();
709
710        let mut row_ptrs = vec![0usize; self.nrows + 1];
711        let mut col_indices = Vec::new();
712        let mut values = Vec::new();
713
714        for row in 0..self.nrows {
715            let bi = row / self.block_rows;
716            let local_i = row % self.block_rows;
717
718            let block_start = self.indptr[bi];
719            let block_end = self.indptr[bi + 1];
720
721            let mut row_entries: Vec<(usize, T)> = Vec::new();
722
723            for k in block_start..block_end {
724                let bj = self.indices[k];
725                let block = &self.data[k];
726
727                for j in 0..self.block_cols {
728                    let global_col = bj * self.block_cols + j;
729                    if global_col < self.ncols {
730                        let val = block.get(local_i, j).clone();
731                        if Scalar::abs(val.clone()) > eps {
732                            row_entries.push((global_col, val));
733                        }
734                    }
735                }
736            }
737
738            // Sort by column
739            row_entries.sort_by_key(|(col, _)| *col);
740
741            for (col, val) in row_entries {
742                col_indices.push(col);
743                values.push(val);
744            }
745            row_ptrs[row + 1] = values.len();
746        }
747
748        // Safety: we constructed valid CSR data
749        unsafe {
750            crate::csr::CsrMatrix::new_unchecked(
751                self.nrows,
752                self.ncols,
753                row_ptrs,
754                col_indices,
755                values,
756            )
757        }
758    }
759
760    /// Converts to dense matrix.
761    pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
762    where
763        T: Field + bytemuck::Zeroable,
764    {
765        let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
766
767        for bi in 0..self.mb {
768            let start = self.indptr[bi];
769            let end = self.indptr[bi + 1];
770            let row_start = bi * self.block_rows;
771
772            for k in start..end {
773                let bj = self.indices[k];
774                let block = &self.data[k];
775                let col_start = bj * self.block_cols;
776
777                for i in 0..self.block_rows {
778                    let global_row = row_start + i;
779                    if global_row >= self.nrows {
780                        break;
781                    }
782
783                    for j in 0..self.block_cols {
784                        let global_col = col_start + j;
785                        if global_col >= self.ncols {
786                            break;
787                        }
788
789                        dense[(global_row, global_col)] = block.get(i, j).clone();
790                    }
791                }
792            }
793        }
794
795        dense
796    }
797
798    /// Creates a BSR matrix from a dense matrix.
799    ///
800    /// # Arguments
801    ///
802    /// * `dense` - Source dense matrix
803    /// * `block_rows` - Block row size
804    /// * `block_cols` - Block column size
805    pub fn from_dense(
806        dense: &oxiblas_matrix::MatRef<'_, T>,
807        block_rows: usize,
808        block_cols: usize,
809    ) -> Self
810    where
811        T: Field,
812    {
813        let (nrows, ncols) = dense.shape();
814        let eps = <T as Scalar>::epsilon();
815
816        let mb = nrows.div_ceil(block_rows);
817        let nb = ncols.div_ceil(block_cols);
818
819        let mut indptr = Vec::with_capacity(mb + 1);
820        let mut indices = Vec::new();
821        let mut data = Vec::new();
822
823        indptr.push(0);
824
825        for bi in 0..mb {
826            let row_start = bi * block_rows;
827            let row_end = (row_start + block_rows).min(nrows);
828
829            for bj in 0..nb {
830                let col_start = bj * block_cols;
831                let col_end = (col_start + block_cols).min(ncols);
832
833                // Check if block has any non-zeros
834                let mut has_nonzero = false;
835                for i in row_start..row_end {
836                    for j in col_start..col_end {
837                        if Scalar::abs(dense[(i, j)].clone()) > eps {
838                            has_nonzero = true;
839                            break;
840                        }
841                    }
842                    if has_nonzero {
843                        break;
844                    }
845                }
846
847                if has_nonzero {
848                    // Extract block
849                    let mut block_data = vec![T::zero(); block_rows * block_cols];
850                    for i in 0..block_rows {
851                        let global_row = row_start + i;
852                        if global_row >= nrows {
853                            break;
854                        }
855                        for j in 0..block_cols {
856                            let global_col = col_start + j;
857                            if global_col >= ncols {
858                                break;
859                            }
860                            block_data[i * block_cols + j] =
861                                dense[(global_row, global_col)].clone();
862                        }
863                    }
864
865                    indices.push(bj);
866                    data.push(DenseBlock::new(block_rows, block_cols, block_data));
867                }
868            }
869
870            indptr.push(data.len());
871        }
872
873        Self {
874            nrows,
875            ncols,
876            block_rows,
877            block_cols,
878            mb,
879            nb,
880            indptr,
881            indices,
882            data,
883        }
884    }
885
886    /// Creates a BSR matrix from CSR format.
887    ///
888    /// # Arguments
889    ///
890    /// * `csr` - Source CSR matrix
891    /// * `block_rows` - Block row size
892    /// * `block_cols` - Block column size
893    pub fn from_csr(csr: &crate::csr::CsrMatrix<T>, block_rows: usize, block_cols: usize) -> Self
894    where
895        T: Field,
896    {
897        let (nrows, ncols) = csr.shape();
898        let eps = <T as Scalar>::epsilon();
899
900        let mb = nrows.div_ceil(block_rows);
901        let nb = ncols.div_ceil(block_cols);
902
903        let mut indptr = Vec::with_capacity(mb + 1);
904        let mut indices = Vec::new();
905        let mut data = Vec::new();
906
907        indptr.push(0);
908
909        for bi in 0..mb {
910            let row_start = bi * block_rows;
911            let row_end = (row_start + block_rows).min(nrows);
912
913            // Collect all block columns that have entries in this block row
914            let mut block_cols_present = std::collections::HashSet::new();
915            for row in row_start..row_end {
916                for (col, _) in csr.row_iter(row) {
917                    let bj = col / block_cols;
918                    block_cols_present.insert(bj);
919                }
920            }
921
922            // Sort block columns
923            let mut sorted_bjs: Vec<_> = block_cols_present.into_iter().collect();
924            sorted_bjs.sort();
925
926            for bj in sorted_bjs {
927                let col_start = bj * block_cols;
928
929                // Extract block data
930                let mut block_data = vec![T::zero(); block_rows * block_cols];
931                let mut has_nonzero = false;
932
933                for row in row_start..row_end {
934                    let local_i = row - row_start;
935                    for (col, val) in csr.row_iter(row) {
936                        if col >= col_start && col < col_start + block_cols {
937                            let local_j = col - col_start;
938                            if Scalar::abs(val.clone()) > eps {
939                                block_data[local_i * block_cols + local_j] = val.clone();
940                                has_nonzero = true;
941                            }
942                        }
943                    }
944                }
945
946                if has_nonzero {
947                    indices.push(bj);
948                    data.push(DenseBlock::new(block_rows, block_cols, block_data));
949                }
950            }
951
952            indptr.push(data.len());
953        }
954
955        Self {
956            nrows,
957            ncols,
958            block_rows,
959            block_cols,
960            mb,
961            nb,
962            indptr,
963            indices,
964            data,
965        }
966    }
967
968    /// Scales all values by a scalar.
969    pub fn scale(&mut self, alpha: T) {
970        for block in &mut self.data {
971            block.scale(alpha.clone());
972        }
973    }
974
975    /// Returns a scaled copy of this matrix.
976    pub fn scaled(&self, alpha: T) -> Self {
977        let mut result = self.clone();
978        result.scale(alpha);
979        result
980    }
981
982    /// Returns the transpose of this matrix.
983    pub fn transpose(&self) -> Self
984    where
985        T: Field,
986    {
987        // Convert to CSR, transpose, convert back
988        let csr = self.to_csr();
989        let csr_t = csr.transpose();
990        Self::from_csr(&csr_t, self.block_cols, self.block_rows)
991    }
992}
993
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bsr_new() {
        // 4×4 matrix with 2×2 blocks:
        // [1 2 | 0 0]
        // [3 4 | 0 0]
        // [----+----]
        // [0 0 | 5 6]
        // [0 0 | 7 8]
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsr =
            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        assert_eq!(bsr.nrows(), 4);
        assert_eq!(bsr.ncols(), 4);
        assert_eq!(bsr.nblocks(), 2);
        assert_eq!(bsr.block_shape(), (2, 2));
    }

    #[test]
    fn test_bsr_get() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsr =
            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        // Block 1
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(0, 1), Some(2.0));
        assert_eq!(bsr.get(1, 0), Some(3.0));
        assert_eq!(bsr.get(1, 1), Some(4.0));

        // Block 2
        assert_eq!(bsr.get(2, 2), Some(5.0));
        assert_eq!(bsr.get(2, 3), Some(6.0));
        assert_eq!(bsr.get(3, 2), Some(7.0));
        assert_eq!(bsr.get(3, 3), Some(8.0));

        // Zero blocks
        assert_eq!(bsr.get(0, 2), None);
        assert_eq!(bsr.get(2, 0), None);
    }

    #[test]
    fn test_bsr_matvec() {
        // [1 2 0 0]   [1]   [3]
        // [3 4 0 0] * [1] = [7]
        // [0 0 5 6]   [1]   [11]
        // [0 0 7 8]   [1]   [15]
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsr =
            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        let x = vec![1.0, 1.0, 1.0, 1.0];
        let y = bsr.mul_vec(&x);

        assert!((y[0] - 3.0).abs() < 1e-10);
        assert!((y[1] - 7.0).abs() < 1e-10);
        assert!((y[2] - 11.0).abs() < 1e-10);
        assert!((y[3] - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_bsr_eye() {
        let bsr: BsrMatrix<f64> = BsrMatrix::eye(4, 2);

        assert_eq!(bsr.nrows(), 4);
        assert_eq!(bsr.ncols(), 4);
        assert_eq!(bsr.nblocks(), 2);

        for i in 0..4 {
            assert_eq!(bsr.get(i, i), Some(1.0));
        }
    }

    #[test]
    fn test_bsr_to_dense() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsr =
            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        let dense = bsr.to_dense();

        assert!((dense[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((dense[(0, 1)] - 2.0).abs() < 1e-10);
        assert!((dense[(1, 0)] - 3.0).abs() < 1e-10);
        assert!((dense[(1, 1)] - 4.0).abs() < 1e-10);
        assert!((dense[(0, 2)] - 0.0).abs() < 1e-10);
        assert!((dense[(2, 2)] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_bsr_to_csr() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsr =
            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        let csr = bsr.to_csr();

        assert_eq!(csr.nrows(), 4);
        assert_eq!(csr.ncols(), 4);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_bsr_from_dense() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[
            &[1.0f64, 2.0, 0.0, 0.0],
            &[3.0, 4.0, 0.0, 0.0],
            &[0.0, 0.0, 5.0, 6.0],
            &[0.0, 0.0, 7.0, 8.0],
        ]);

        let bsr = BsrMatrix::from_dense(&dense.as_ref(), 2, 2);

        assert_eq!(bsr.nblocks(), 2);
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(2, 2), Some(5.0));
    }

    #[test]
    fn test_bsr_from_csr() {
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let col_indices = vec![0, 1, 0, 1, 2, 3, 2, 3];
        let row_ptrs = vec![0, 2, 4, 6, 8];

        let csr = crate::csr::CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap();
        let bsr = BsrMatrix::from_csr(&csr, 2, 2);

        assert_eq!(bsr.nblocks(), 2);
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(3, 3), Some(8.0));
    }

    #[test]
    fn test_bsr_scale() {
        let block = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let mut bsr = BsrMatrix::new(2, 2, 2, 2, vec![0, 1], vec![0], vec![block]).unwrap();

        bsr.scale(2.0);

        assert_eq!(bsr.get(0, 0), Some(2.0));
        assert_eq!(bsr.get(1, 1), Some(8.0));
    }

    #[test]
    fn test_bsr_scaled_leaves_original_untouched() {
        let block = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let bsr = BsrMatrix::new(2, 2, 2, 2, vec![0, 1], vec![0], vec![block]).unwrap();

        let tripled = bsr.scaled(3.0);

        // The returned copy is scaled...
        assert_eq!(tripled.get(0, 0), Some(3.0));
        assert_eq!(tripled.get(1, 0), Some(9.0));
        // ...while the source matrix keeps its original values.
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(1, 0), Some(3.0));
    }

    #[test]
    fn test_bsr_transpose() {
        // [1 2]       [1 3]
        // [3 4]  ->   [2 4]
        let block = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let bsr = BsrMatrix::new(2, 2, 2, 2, vec![0, 1], vec![0], vec![block]).unwrap();

        let bsr_t = bsr.transpose();
        let dense = bsr.to_dense();
        let dense_t = bsr_t.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert!((dense[(i, j)] - dense_t[(j, i)]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_bsr_transpose_non_square_blocks() {
        // 6×4 matrix with 3×2 blocks on the block diagonal. Its transpose is a
        // 4×6 matrix, and the block shape must be swapped to 2×3 as well.
        let block1 = DenseBlock::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let block2 = DenseBlock::new(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let bsr =
            BsrMatrix::new(6, 4, 3, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        let bsr_t = bsr.transpose();

        assert_eq!(bsr_t.nrows(), 4);
        assert_eq!(bsr_t.ncols(), 6);
        assert_eq!(bsr_t.block_shape(), (2, 3));

        let dense = bsr.to_dense();
        let dense_t = bsr_t.to_dense();
        for i in 0..6 {
            for j in 0..4 {
                assert!((dense[(i, j)] - dense_t[(j, i)]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_bsr_zeros() {
        let bsr: BsrMatrix<f64> = BsrMatrix::zeros(6, 8, 2, 4);

        assert_eq!(bsr.nrows(), 6);
        assert_eq!(bsr.ncols(), 8);
        assert_eq!(bsr.nblocks(), 0);
        assert_eq!(bsr.nblock_rows(), 3);
        assert_eq!(bsr.nblock_cols(), 2);
    }

    #[test]
    fn test_bsr_non_square_blocks() {
        // 6×4 matrix with 3×2 blocks
        let block1 = DenseBlock::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let block2 = DenseBlock::new(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);

        let bsr =
            BsrMatrix::new(6, 4, 3, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        assert_eq!(bsr.nrows(), 6);
        assert_eq!(bsr.ncols(), 4);
        assert_eq!(bsr.nblock_rows(), 2);
        assert_eq!(bsr.nblock_cols(), 2);

        // Check values
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(2, 1), Some(6.0));
        assert_eq!(bsr.get(3, 2), Some(7.0));
    }

    #[test]
    fn test_bsr_block_iter() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsr =
            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        let blocks: Vec<_> = bsr.block_iter().map(|(bi, bj, _)| (bi, bj)).collect();
        assert_eq!(blocks, vec![(0, 0), (1, 1)]);
    }

    #[test]
    fn test_dense_block() {
        let mut block = DenseBlock::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        assert_eq!(block.shape(), (2, 3));
        assert_eq!(*block.get(0, 0), 1.0);
        assert_eq!(*block.get(0, 2), 3.0);
        assert_eq!(*block.get(1, 1), 5.0);

        *block.get_mut(1, 1) = 10.0;
        assert_eq!(*block.get(1, 1), 10.0);
    }
}