scirs2_sparse/
bsr_array.rs

1// BSR Array implementation
2//
3// This module provides the BSR (Block Sparse Row) array format,
4// which is efficient for matrices with block-structured sparsity patterns.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csc_array::CscArray;
13use crate::csr_array::CsrArray;
14use crate::dia_array::DiaArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::{SparseArray, SparseSum};
19
20/// BSR Array format
21///
22/// The BSR (Block Sparse Row) format stores a sparse matrix as a sparse matrix
23/// of dense blocks. It's particularly efficient for matrices with block-structured
24/// sparsity patterns, such as those arising in finite element methods.
25///
26/// # Notes
27///
28/// - Very efficient for matrices with block structure
29/// - Fast matrix-vector products for block-structured matrices
30/// - Reduced indexing overhead compared to CSR for block-structured problems
31/// - Not efficient for general sparse matrices
32/// - Difficult to modify once constructed
33///
34#[derive(Clone)]
35pub struct BsrArray<T>
36where
37    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
38{
39    /// Number of rows
40    rows: usize,
41    /// Number of columns
42    cols: usize,
43    /// Block size (r, c)
44    block_size: (usize, usize),
45    /// Number of block rows
46    block_rows: usize,
47    /// Number of block columns (needed for internal calculations)
48    #[allow(dead_code)]
49    block_cols: usize,
50    /// Data array (blocks stored row by row)
51    data: Vec<Vec<Vec<T>>>,
52    /// Column indices for each block
53    indices: Vec<Vec<usize>>,
54    /// Row pointers (indptr)
55    indptr: Vec<usize>,
56}
57
58impl<T> BsrArray<T>
59where
60    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
61{
62    /// Create a new BSR array from raw data
63    ///
64    /// # Arguments
65    ///
66    /// * `data` - Block data (blocks stored row by row)
67    /// * `indices` - Column indices for each block
68    /// * `indptr` - Row pointers
69    /// * `shape` - Tuple containing the array dimensions (rows, cols)
70    /// * `block_size` - Tuple containing the block dimensions (r, c)
71    ///
72    /// # Returns
73    ///
74    /// * A new BSR array
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use scirs2_sparse::bsr_array::BsrArray;
80    /// use scirs2_sparse::sparray::SparseArray;
81    ///
82    /// // Create a 4x4 sparse array with 2x2 blocks
83    /// // [1 2 0 0]
84    /// // [3 4 0 0]
85    /// // [0 0 5 6]
86    /// // [0 0 7 8]
87    ///
88    /// let block1 = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
89    /// let block2 = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
90    ///
91    /// let data = vec![block1, block2];
92    /// let indices = vec![vec![0], vec![1]];
93    /// let indptr = vec![0, 1, 2];
94    ///
95    /// let array = BsrArray::new(data, indices, indptr, (4, 4), (2, 2)).unwrap();
96    /// assert_eq!(array.shape(), (4, 4));
97    /// assert_eq!(array.nnz(), 8); // All elements in the blocks are non-zero
98    /// ```
99    pub fn new(
100        data: Vec<Vec<Vec<T>>>,
101        indices: Vec<Vec<usize>>,
102        indptr: Vec<usize>,
103        shape: (usize, usize),
104        block_size: (usize, usize),
105    ) -> SparseResult<Self> {
106        let (rows, cols) = shape;
107        let (r, c) = block_size;
108
109        if r == 0 || c == 0 {
110            return Err(SparseError::ValueError(
111                "Block dimensions must be positive".to_string(),
112            ));
113        }
114
115        // Calculate block dimensions
116        #[allow(clippy::manual_div_ceil)]
117        let block_rows = (rows + r - 1) / r; // Ceiling division
118        #[allow(clippy::manual_div_ceil)]
119        let block_cols = (cols + c - 1) / c; // Ceiling division
120
121        // Validate input
122        if indptr.len() != block_rows + 1 {
123            return Err(SparseError::DimensionMismatch {
124                expected: block_rows + 1,
125                found: indptr.len(),
126            });
127        }
128
129        if data.len() != indptr[block_rows] {
130            return Err(SparseError::DimensionMismatch {
131                expected: indptr[block_rows],
132                found: data.len(),
133            });
134        }
135
136        if indices.len() != data.len() {
137            return Err(SparseError::DimensionMismatch {
138                expected: data.len(),
139                found: indices.len(),
140            });
141        }
142
143        for block in data.iter() {
144            if block.len() != r {
145                return Err(SparseError::DimensionMismatch {
146                    expected: r,
147                    found: block.len(),
148                });
149            }
150
151            for row in block.iter() {
152                if row.len() != c {
153                    return Err(SparseError::DimensionMismatch {
154                        expected: c,
155                        found: row.len(),
156                    });
157                }
158            }
159        }
160
161        for idx_vec in indices.iter() {
162            if idx_vec.len() != 1 {
163                return Err(SparseError::ValueError(
164                    "Each index vector must contain exactly one block column index".to_string(),
165                ));
166            }
167            if idx_vec[0] >= block_cols {
168                return Err(SparseError::ValueError(format!(
169                    "index {} out of bounds (max {})",
170                    idx_vec[0],
171                    block_cols - 1
172                )));
173            }
174        }
175
176        Ok(BsrArray {
177            rows,
178            cols,
179            block_size,
180            block_rows,
181            block_cols,
182            data,
183            indices,
184            indptr,
185        })
186    }
187
188    /// Create a new empty BSR array
189    ///
190    /// # Arguments
191    ///
192    /// * `shape` - Tuple containing the array dimensions (rows, cols)
193    /// * `block_size` - Tuple containing the block dimensions (r, c)
194    ///
195    /// # Returns
196    ///
197    /// * A new empty BSR array
198    pub fn empty(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
199        let (rows, cols) = shape;
200        let (r, c) = block_size;
201
202        if r == 0 || c == 0 {
203            return Err(SparseError::ValueError(
204                "Block dimensions must be positive".to_string(),
205            ));
206        }
207
208        // Calculate block dimensions
209        #[allow(clippy::manual_div_ceil)]
210        let block_rows = (rows + r - 1) / r; // Ceiling division
211        #[allow(clippy::manual_div_ceil)]
212        let block_cols = (cols + c - 1) / c; // Ceiling division
213
214        // Initialize empty BSR array
215        let data = Vec::new();
216        let indices = Vec::new();
217        let indptr = vec![0; block_rows + 1];
218
219        Ok(BsrArray {
220            rows,
221            cols,
222            block_size,
223            block_rows,
224            block_cols,
225            data,
226            indices,
227            indptr,
228        })
229    }
230
231    /// Convert triplets to BSR format
232    ///
233    /// # Arguments
234    ///
235    /// * `row` - Row indices
236    /// * `col` - Column indices
237    /// * `data` - Data values
238    /// * `shape` - Shape of the array
239    /// * `block_size` - Size of the blocks
240    ///
241    /// # Returns
242    ///
243    /// * A new BSR array
244    pub fn from_triplets(
245        row: &[usize],
246        col: &[usize],
247        data: &[T],
248        shape: (usize, usize),
249        block_size: (usize, usize),
250    ) -> SparseResult<Self> {
251        if row.len() != col.len() || row.len() != data.len() {
252            return Err(SparseError::InconsistentData {
253                reason: "Lengths of row, col, and data arrays must be equal".to_string(),
254            });
255        }
256
257        let (rows, cols) = shape;
258        let (r, c) = block_size;
259
260        if r == 0 || c == 0 {
261            return Err(SparseError::ValueError(
262                "Block dimensions must be positive".to_string(),
263            ));
264        }
265
266        // Calculate block dimensions
267        #[allow(clippy::manual_div_ceil)]
268        let block_rows = (rows + r - 1) / r; // Ceiling division
269        #[allow(clippy::manual_div_ceil)]
270        let block_cols = (cols + c - 1) / c; // Ceiling division
271
272        // First, we'll construct a temporary DOK-like representation for the blocks
273        let mut block_data = std::collections::HashMap::new();
274
275        // Assign each element to its corresponding block
276        for (&row_idx, (&col_idx, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
277            if row_idx >= rows || col_idx >= cols {
278                return Err(SparseError::IndexOutOfBounds {
279                    index: (row_idx, col_idx),
280                    shape,
281                });
282            }
283
284            // Calculate block indices
285            let block_row = row_idx / r;
286            let block_col = col_idx / c;
287
288            // Calculate position within block
289            let block_row_pos = row_idx % r;
290            let block_col_pos = col_idx % c;
291
292            // Create or get the block
293            let block = block_data.entry((block_row, block_col)).or_insert_with(|| {
294                let block = vec![vec![T::sparse_zero(); c]; r];
295                block
296            });
297
298            // Set the value in the block
299            block[block_row_pos][block_col_pos] = val;
300        }
301
302        // Now convert the DOK-like format to BSR
303        let mut rowswith_blocks: Vec<usize> = block_data.keys().map(|&(row_, _)| row_).collect();
304        rowswith_blocks.sort();
305        rowswith_blocks.dedup();
306
307        // Create indptr array
308        let mut indptr = vec![0; block_rows + 1];
309        let mut current_nnz = 0;
310
311        // Sorted blocks data and indices
312        let mut data = Vec::new();
313        let mut indices = Vec::new();
314
315        for row_idx in 0..block_rows {
316            if rowswith_blocks.contains(&row_idx) {
317                // Get all blocks for this row
318                let mut row_blocks: Vec<(usize, Vec<Vec<T>>)> = block_data
319                    .iter()
320                    .filter(|&(&(r, _), _)| r == row_idx)
321                    .map(|(&(_, c), block)| (c, block.clone()))
322                    .collect();
323
324                // Sort by column index
325                row_blocks.sort_by_key(|&(col_, _)| col_);
326
327                // Add to data and indices
328                for (col, block) in row_blocks {
329                    data.push(block);
330                    indices.push(vec![col]);
331                    current_nnz += 1;
332                }
333            }
334
335            indptr[row_idx + 1] = current_nnz;
336        }
337
338        // Create the BSR array
339        BsrArray::new(data, indices, indptr, shape, block_size)
340    }
341
342    /// Convert to COO format triplets
343    fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
344        let (r, c) = self.block_size;
345        let mut row_indices = Vec::new();
346        let mut col_indices = Vec::new();
347        let mut values = Vec::new();
348
349        for block_row in 0..self.block_rows {
350            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
351                let block_col = self.indices[k][0];
352                let block = &self.data[k];
353
354                // For each element in the block
355                for (i, block_row_data) in block.iter().enumerate().take(r) {
356                    let row = block_row * r + i;
357                    if row < self.rows {
358                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
359                            let col = block_col * c + j;
360                            if col < self.cols && !SparseElement::is_zero(&value) {
361                                row_indices.push(row);
362                                col_indices.push(col);
363                                values.push(value);
364                            }
365                        }
366                    }
367                }
368            }
369        }
370
371        (row_indices, col_indices, values)
372    }
373}
374
375impl<T> SparseArray<T> for BsrArray<T>
376where
377    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
378{
379    fn shape(&self) -> (usize, usize) {
380        (self.rows, self.cols)
381    }
382
383    fn nnz(&self) -> usize {
384        let mut count = 0;
385
386        for block in &self.data {
387            for row in block {
388                for &val in row {
389                    if !SparseElement::is_zero(&val) {
390                        count += 1;
391                    }
392                }
393            }
394        }
395
396        count
397    }
398
399    fn dtype(&self) -> &str {
400        "float" // Placeholder; ideally would return the actual type
401    }
402
403    fn to_array(&self) -> Array2<T> {
404        let mut result = Array2::zeros((self.rows, self.cols));
405        let (r, c) = self.block_size;
406
407        for block_row in 0..self.block_rows {
408            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
409                let block_col = self.indices[k][0];
410                let block = &self.data[k];
411
412                for (i, block_row_data) in block.iter().enumerate().take(r) {
413                    let row = block_row * r + i;
414                    if row < self.rows {
415                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
416                            let col = block_col * c + j;
417                            if col < self.cols {
418                                result[[row, col]] = value;
419                            }
420                        }
421                    }
422                }
423            }
424        }
425
426        result
427    }
428
429    fn toarray(&self) -> Array2<T> {
430        self.to_array()
431    }
432
433    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
434        let (row_indices, col_indices, values) = self.to_coo_internal();
435        CooArray::from_triplets(
436            &row_indices,
437            &col_indices,
438            &values,
439            (self.rows, self.cols),
440            false,
441        )
442        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
443    }
444
445    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
446        let (row_indices, col_indices, values) = self.to_coo_internal();
447        CsrArray::from_triplets(
448            &row_indices,
449            &col_indices,
450            &values,
451            (self.rows, self.cols),
452            false,
453        )
454        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
455    }
456
457    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
458        let (row_indices, col_indices, values) = self.to_coo_internal();
459        CscArray::from_triplets(
460            &row_indices,
461            &col_indices,
462            &values,
463            (self.rows, self.cols),
464            false,
465        )
466        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
467    }
468
469    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
470        let (row_indices, col_indices, values) = self.to_coo_internal();
471        DokArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
472            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
473    }
474
475    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
476        let (row_indices, col_indices, values) = self.to_coo_internal();
477        LilArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
478            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
479    }
480
481    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
482        let (row_indices, col_indices, values) = self.to_coo_internal();
483        DiaArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
484            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
485    }
486
487    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
488        Ok(Box::new(self.clone()))
489    }
490
491    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
492        // For efficiency, convert both to CSR for addition
493        let csr_self = self.to_csr()?;
494        let csr_other = other.to_csr()?;
495        csr_self.add(&*csr_other)
496    }
497
498    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
499        // For efficiency, convert both to CSR for subtraction
500        let csr_self = self.to_csr()?;
501        let csr_other = other.to_csr()?;
502        csr_self.sub(&*csr_other)
503    }
504
505    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
506        // For efficiency, convert both to CSR for element-wise multiplication
507        let csr_self = self.to_csr()?;
508        let csr_other = other.to_csr()?;
509        csr_self.mul(&*csr_other)
510    }
511
512    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
513        // For efficiency, convert both to CSR for element-wise division
514        let csr_self = self.to_csr()?;
515        let csr_other = other.to_csr()?;
516        csr_self.div(&*csr_other)
517    }
518
519    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
520        let (_, n) = self.shape();
521        let (p, q) = other.shape();
522
523        if n != p {
524            return Err(SparseError::DimensionMismatch {
525                expected: n,
526                found: p,
527            });
528        }
529
530        // If other is a vector (thin matrix), we can use optimized BSR-Vector multiplication
531        if q == 1 {
532            // Get the vector from other
533            let other_array = other.to_array();
534            let vec_view = other_array.column(0);
535
536            // Perform BSR-Vector multiplication
537            let result = self.dot_vector(&vec_view)?;
538
539            // Convert to a matrix - create a COO from triplets
540            let mut rows = Vec::new();
541            let mut cols = Vec::new();
542            let mut values = Vec::new();
543
544            for (i, &val) in result.iter().enumerate() {
545                if !SparseElement::is_zero(&val) {
546                    rows.push(i);
547                    cols.push(0);
548                    values.push(val);
549                }
550            }
551
552            CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
553                .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
554        } else {
555            // For general matrix-matrix multiplication, convert to CSR
556            let csr_self = self.to_csr()?;
557            csr_self.dot(other)
558        }
559    }
560
561    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
562        let (rows, cols) = self.shape();
563        let (r, c) = self.block_size;
564
565        if cols != other.len() {
566            return Err(SparseError::DimensionMismatch {
567                expected: cols,
568                found: other.len(),
569            });
570        }
571
572        let mut result = Array1::zeros(rows);
573
574        for block_row in 0..self.block_rows {
575            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
576                let block_col = self.indices[k][0];
577                let block = &self.data[k];
578
579                // For each element in the block
580                for (i, block_row_data) in block.iter().enumerate().take(r) {
581                    let row = block_row * r + i;
582                    if row < self.rows {
583                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
584                            let col = block_col * c + j;
585                            if col < self.cols {
586                                result[row] += value * other[col];
587                            }
588                        }
589                    }
590                }
591            }
592        }
593
594        Ok(result)
595    }
596
597    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
598        // For efficiency, convert to COO, transpose, then convert back to BSR
599        self.to_coo()?.transpose()?.to_bsr()
600    }
601
602    fn copy(&self) -> Box<dyn SparseArray<T>> {
603        Box::new(self.clone())
604    }
605
606    fn get(&self, i: usize, j: usize) -> T {
607        if i >= self.rows || j >= self.cols {
608            return T::sparse_zero();
609        }
610
611        let (r, c) = self.block_size;
612        let block_row = i / r;
613        let block_col = j / c;
614        let block_row_pos = i % r;
615        let block_col_pos = j % c;
616
617        // Search for the block in the row
618        for k in self.indptr[block_row]..self.indptr[block_row + 1] {
619            if self.indices[k][0] == block_col {
620                return self.data[k][block_row_pos][block_col_pos];
621            }
622        }
623
624        T::sparse_zero()
625    }
626
627    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
628        if i >= self.rows || j >= self.cols {
629            return Err(SparseError::IndexOutOfBounds {
630                index: (i, j),
631                shape: (self.rows, self.cols),
632            });
633        }
634
635        let (r, c) = self.block_size;
636        let block_row = i / r;
637        let block_col = j / c;
638        let block_row_pos = i % r;
639        let block_col_pos = j % c;
640
641        // Search for the block in the row
642        for k in self.indptr[block_row]..self.indptr[block_row + 1] {
643            if self.indices[k][0] == block_col {
644                // Block exists, update value
645                self.data[k][block_row_pos][block_col_pos] = value;
646                return Ok(());
647            }
648        }
649
650        // Block doesn't exist, we need to create it
651        if !SparseElement::is_zero(&value) {
652            // Find position to insert
653            let pos = self.indptr[block_row + 1];
654
655            // Create new block
656            let mut block = vec![vec![T::sparse_zero(); c]; r];
657            block[block_row_pos][block_col_pos] = value;
658
659            // Insert block, indices
660            self.data.insert(pos, block);
661            self.indices.insert(pos, vec![block_col]);
662
663            // Update indptr for subsequent rows
664            for k in (block_row + 1)..=self.block_rows {
665                self.indptr[k] += 1;
666            }
667
668            Ok(())
669        } else {
670            // If value is zero and block doesn't exist, do nothing
671            Ok(())
672        }
673    }
674
675    fn eliminate_zeros(&mut self) {
676        // No need to use block_size variables here
677        let mut new_data = Vec::new();
678        let mut new_indices = Vec::new();
679        let mut new_indptr = vec![0];
680        let mut current_nnz = 0;
681
682        for block_row in 0..self.block_rows {
683            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
684                let block_col = self.indices[k][0];
685                let block = &self.data[k];
686
687                // Check if block has any non-zero elements
688                let mut has_nonzero = false;
689                for row in block {
690                    for &val in row {
691                        if !SparseElement::is_zero(&val) {
692                            has_nonzero = true;
693                            break;
694                        }
695                    }
696                    if has_nonzero {
697                        break;
698                    }
699                }
700
701                if has_nonzero {
702                    new_data.push(block.clone());
703                    new_indices.push(vec![block_col]);
704                    current_nnz += 1;
705                }
706            }
707
708            new_indptr.push(current_nnz);
709        }
710
711        self.data = new_data;
712        self.indices = new_indices;
713        self.indptr = new_indptr;
714    }
715
716    fn sort_indices(&mut self) {
717        // No need to use block_size variables here
718        let mut new_data = Vec::new();
719        let mut new_indices = Vec::new();
720        let mut new_indptr = vec![0];
721        let mut current_nnz = 0;
722
723        for block_row in 0..self.block_rows {
724            // Get blocks for this row
725            let mut row_blocks = Vec::new();
726            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
727                row_blocks.push((self.indices[k][0], self.data[k].clone()));
728            }
729
730            // Sort by column index
731            row_blocks.sort_by_key(|&(col_, _)| col_);
732
733            // Add sorted blocks to new data structures
734            for (col, block) in row_blocks {
735                new_data.push(block);
736                new_indices.push(vec![col]);
737                current_nnz += 1;
738            }
739
740            new_indptr.push(current_nnz);
741        }
742
743        self.data = new_data;
744        self.indices = new_indices;
745        self.indptr = new_indptr;
746    }
747
748    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
749        let mut result = self.clone();
750        result.sort_indices();
751        Box::new(result)
752    }
753
754    fn has_sorted_indices(&self) -> bool {
755        for block_row in 0..self.block_rows {
756            let mut prev_col = None;
757
758            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
759                let col = self.indices[k][0];
760
761                if let Some(prev) = prev_col {
762                    if col <= prev {
763                        return false;
764                    }
765                }
766
767                prev_col = Some(col);
768            }
769        }
770
771        true
772    }
773
774    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
775        match axis {
776            None => {
777                // Sum all elements
778                let mut total = T::sparse_zero();
779
780                for block in &self.data {
781                    for row in block {
782                        for &val in row {
783                            total += val;
784                        }
785                    }
786                }
787
788                Ok(SparseSum::Scalar(total))
789            }
790            Some(0) => {
791                // Sum along rows (result is 1 x cols)
792                let mut result = vec![T::sparse_zero(); self.cols];
793                let (r, c) = self.block_size;
794
795                for block_row in 0..self.block_rows {
796                    for k in self.indptr[block_row]..self.indptr[block_row + 1] {
797                        let block_col = self.indices[k][0];
798                        let block = &self.data[k];
799
800                        for block_row_data in block.iter().take(r) {
801                            for (j, &value) in block_row_data.iter().enumerate().take(c) {
802                                let col = block_col * c + j;
803                                if col < self.cols {
804                                    result[col] += value;
805                                }
806                            }
807                        }
808                    }
809                }
810
811                // Create a sparse array from the result
812                let mut row_indices = Vec::new();
813                let mut col_indices = Vec::new();
814                let mut values = Vec::new();
815
816                for (j, &val) in result.iter().enumerate() {
817                    if !SparseElement::is_zero(&val) {
818                        row_indices.push(0);
819                        col_indices.push(j);
820                        values.push(val);
821                    }
822                }
823
824                match CooArray::from_triplets(
825                    &row_indices,
826                    &col_indices,
827                    &values,
828                    (1, self.cols),
829                    false,
830                ) {
831                    Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
832                    Err(e) => Err(e),
833                }
834            }
835            Some(1) => {
836                // Sum along columns (result is rows x 1)
837                let mut result = vec![T::sparse_zero(); self.rows];
838                let (r, c) = self.block_size;
839
840                for block_row in 0..self.block_rows {
841                    for k in self.indptr[block_row]..self.indptr[block_row + 1] {
842                        let block = &self.data[k];
843
844                        for (i, block_row_data) in block.iter().enumerate().take(r) {
845                            let row = block_row * r + i;
846                            if row < self.rows {
847                                for &value in block_row_data.iter().take(c) {
848                                    result[row] += value;
849                                }
850                            }
851                        }
852                    }
853                }
854
855                // Create a sparse array from the result
856                let mut row_indices = Vec::new();
857                let mut col_indices = Vec::new();
858                let mut values = Vec::new();
859
860                for (i, &val) in result.iter().enumerate() {
861                    if !SparseElement::is_zero(&val) {
862                        row_indices.push(i);
863                        col_indices.push(0);
864                        values.push(val);
865                    }
866                }
867
868                match CooArray::from_triplets(
869                    &row_indices,
870                    &col_indices,
871                    &values,
872                    (self.rows, 1),
873                    false,
874                ) {
875                    Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
876                    Err(e) => Err(e),
877                }
878            }
879            _ => Err(SparseError::InvalidAxis),
880        }
881    }
882
883    fn max(&self) -> T {
884        let mut max_val = T::neg_infinity();
885
886        for block in &self.data {
887            for row in block {
888                for &val in row {
889                    max_val = max_val.max(val);
890                }
891            }
892        }
893
894        // If no elements or all negative infinity, return zero
895        if max_val == T::neg_infinity() {
896            T::sparse_zero()
897        } else {
898            max_val
899        }
900    }
901
902    fn min(&self) -> T {
903        let mut min_val = T::sparse_zero();
904        let mut has_nonzero = false;
905
906        for block in &self.data {
907            for row in block {
908                for &val in row {
909                    if !SparseElement::is_zero(&val) {
910                        has_nonzero = true;
911                        min_val = min_val.min(val);
912                    }
913                }
914            }
915        }
916
917        // If no non-zero elements, return zero
918        if !has_nonzero {
919            T::sparse_zero()
920        } else {
921            min_val
922        }
923    }
924
925    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
926        let (row_indices, col_indices, values) = self.to_coo_internal();
927
928        (
929            Array1::from_vec(row_indices),
930            Array1::from_vec(col_indices),
931            Array1::from_vec(values),
932        )
933    }
934
935    fn slice(
936        &self,
937        row_range: (usize, usize),
938        col_range: (usize, usize),
939    ) -> SparseResult<Box<dyn SparseArray<T>>> {
940        let (start_row, end_row) = row_range;
941        let (start_col, end_col) = col_range;
942        let (rows, cols) = self.shape();
943
944        if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
945            return Err(SparseError::IndexOutOfBounds {
946                index: (start_row.max(end_row), start_col.max(end_col)),
947                shape: (rows, cols),
948            });
949        }
950
951        if start_row >= end_row || start_col >= end_col {
952            return Err(SparseError::InvalidSliceRange);
953        }
954
955        // Convert to COO, slice, then convert back to BSR
956        let coo = self.to_coo()?;
957        coo.slice(row_range, col_range)?.to_bsr()
958    }
959
960    fn as_any(&self) -> &dyn std::any::Any {
961        self
962    }
963}
964
965// Implement Display for BsrArray for better debugging
966impl<T> fmt::Display for BsrArray<T>
967where
968    T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
969{
970    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
971        writeln!(
972            f,
973            "BsrArray of shape {:?} with {} stored elements",
974            (self.rows, self.cols),
975            self.nnz()
976        )?;
977        writeln!(f, "Block size: {:?}", self.block_size)?;
978        writeln!(f, "Number of blocks: {}", self.data.len())?;
979
980        if self.data.len() <= 5 {
981            for block_row in 0..self.block_rows {
982                for k in self.indptr[block_row]..self.indptr[block_row + 1] {
983                    let block_col = self.indices[k][0];
984                    let block = &self.data[k];
985
986                    writeln!(f, "Block at ({block_row}, {block_col}): ")?;
987                    for row in block {
988                        write!(f, "  [")?;
989                        for (j, &val) in row.iter().enumerate() {
990                            if j > 0 {
991                                write!(f, ", ")?;
992                            }
993                            write!(f, "{val:?}")?;
994                        }
995                        writeln!(f, "]")?;
996                    }
997                }
998            }
999        } else {
1000            writeln!(f, "({} blocks total)", self.data.len())?;
1001        }
1002
1003        Ok(())
1004    }
1005}
1006
#[cfg(test)]
mod tests {
    use super::*;

    /// 4x4 block-diagonal fixture built from two dense 2x2 blocks:
    ///
    /// ```text
    /// [1 2 0 0]
    /// [3 4 0 0]
    /// [0 0 5 6]
    /// [0 0 7 8]
    /// ```
    fn block_diagonal_4x4() -> BsrArray<f64> {
        let upper = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let lower = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
        BsrArray::new(
            vec![upper, lower],
            vec![vec![0], vec![1]],
            vec![0, 1, 2],
            (4, 4),
            (2, 2),
        )
        .unwrap()
    }

    /// 2x2 array holding exactly one dense 2x2 block.
    fn single_block_2x2(block: Vec<Vec<f64>>) -> BsrArray<f64> {
        BsrArray::new(vec![block], vec![vec![0]], vec![0, 1], (2, 2), (2, 2)).unwrap()
    }

    /// Checks every stored entry of the block-diagonal fixture, plus one
    /// position that lies outside both blocks.
    fn assert_block_diagonal_entries(array: &BsrArray<f64>) {
        let expected = [
            (0, 0, 1.0),
            (0, 1, 2.0),
            (1, 0, 3.0),
            (1, 1, 4.0),
            (2, 2, 5.0),
            (2, 3, 6.0),
            (3, 2, 7.0),
            (3, 3, 8.0),
            (0, 2, 0.0), // implicit zero
        ];
        for (i, j, value) in expected {
            assert_eq!(array.get(i, j), value, "entry ({i}, {j})");
        }
    }

    #[test]
    fn test_bsr_array_create() {
        let array = block_diagonal_4x4();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        // Every entry inside the two blocks is non-zero.
        assert_eq!(array.nnz(), 8);
        assert_block_diagonal_entries(&array);
    }

    #[test]
    fn test_bsr_array_from_triplets() {
        let row_idx = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let col_idx = vec![0, 1, 0, 1, 2, 3, 2, 3];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let array =
            BsrArray::from_triplets(&row_idx, &col_idx, &values, (4, 4), (2, 2)).unwrap();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        assert_block_diagonal_entries(&array);
    }

    #[test]
    fn test_bsr_array_conversion() {
        let array = block_diagonal_4x4();

        let as_coo = array.to_coo().unwrap();
        assert_eq!(as_coo.shape(), (4, 4));
        assert_eq!(as_coo.nnz(), 8);

        let as_csr = array.to_csr().unwrap();
        assert_eq!(as_csr.shape(), (4, 4));
        assert_eq!(as_csr.nnz(), 8);

        let as_dense = array.to_array();
        let want = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        assert_eq!(as_dense, want);
    }

    #[test]
    fn test_bsr_array_operations() {
        let lhs = single_block_2x2(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let rhs = single_block_2x2(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);

        // Element-wise addition.
        let total = lhs.add(&rhs).unwrap();
        assert_eq!(total.shape(), (2, 2));
        assert_eq!(total.get(0, 0), 6.0); // 1+5
        assert_eq!(total.get(0, 1), 8.0); // 2+6
        assert_eq!(total.get(1, 0), 10.0); // 3+7
        assert_eq!(total.get(1, 1), 12.0); // 4+8

        // Element-wise multiplication.
        let hadamard = lhs.mul(&rhs).unwrap();
        assert_eq!(hadamard.shape(), (2, 2));
        assert_eq!(hadamard.get(0, 0), 5.0); // 1*5
        assert_eq!(hadamard.get(0, 1), 12.0); // 2*6
        assert_eq!(hadamard.get(1, 0), 21.0); // 3*7
        assert_eq!(hadamard.get(1, 1), 32.0); // 4*8

        // Matrix product.
        let matmul = lhs.dot(&rhs).unwrap();
        assert_eq!(matmul.shape(), (2, 2));
        assert_eq!(matmul.get(0, 0), 19.0); // 1*5 + 2*7
        assert_eq!(matmul.get(0, 1), 22.0); // 1*6 + 2*8
        assert_eq!(matmul.get(1, 0), 43.0); // 3*5 + 4*7
        assert_eq!(matmul.get(1, 1), 50.0); // 3*6 + 4*8
    }

    #[test]
    fn test_bsr_array_dot_vector() {
        let array = block_diagonal_4x4();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let y = array.dot_vector(&x.view()).unwrap();

        // Each output entry is a matrix row dotted with x:
        // [1*1 + 2*2, 3*1 + 4*2, 5*3 + 6*4, 7*3 + 8*4] = [5, 11, 39, 53]
        let want = Array1::from_vec(vec![5.0, 11.0, 39.0, 53.0]);
        assert_eq!(y, want);
    }

    #[test]
    fn test_bsr_array_sum() {
        let array = block_diagonal_4x4();

        // Sum of every element.
        match array.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 36.0), // 1+2+...+8
            _ => panic!("Expected SparseSum::Scalar"),
        }

        // Axis 0 collapses the rows, leaving a 1 x 4 result.
        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(per_col) => {
                assert_eq!(per_col.shape(), (1, 4));
                assert_eq!(per_col.get(0, 0), 4.0); // 1+3
                assert_eq!(per_col.get(0, 1), 6.0); // 2+4
                assert_eq!(per_col.get(0, 2), 12.0); // 5+7
                assert_eq!(per_col.get(0, 3), 14.0); // 6+8
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }

        // Axis 1 collapses the columns, leaving a 4 x 1 result.
        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(per_row) => {
                assert_eq!(per_row.shape(), (4, 1));
                assert_eq!(per_row.get(0, 0), 3.0); // 1+2
                assert_eq!(per_row.get(1, 0), 7.0); // 3+4
                assert_eq!(per_row.get(2, 0), 11.0); // 5+6
                assert_eq!(per_row.get(3, 0), 15.0); // 7+8
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }
    }
}