1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::{Float, SparseElement};
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csc_array::CscArray;
13use crate::csr_array::CsrArray;
14use crate::dia_array::DiaArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::{SparseArray, SparseSum};
19
20#[derive(Clone)]
35pub struct BsrArray<T>
36where
37 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
38{
39 rows: usize,
41 cols: usize,
43 block_size: (usize, usize),
45 block_rows: usize,
47 #[allow(dead_code)]
49 block_cols: usize,
50 data: Vec<Vec<Vec<T>>>,
52 indices: Vec<Vec<usize>>,
54 indptr: Vec<usize>,
56}
57
58impl<T> BsrArray<T>
59where
60 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
61{
62 pub fn new(
100 data: Vec<Vec<Vec<T>>>,
101 indices: Vec<Vec<usize>>,
102 indptr: Vec<usize>,
103 shape: (usize, usize),
104 block_size: (usize, usize),
105 ) -> SparseResult<Self> {
106 let (rows, cols) = shape;
107 let (r, c) = block_size;
108
109 if r == 0 || c == 0 {
110 return Err(SparseError::ValueError(
111 "Block dimensions must be positive".to_string(),
112 ));
113 }
114
115 #[allow(clippy::manual_div_ceil)]
117 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
119 let block_cols = (cols + c - 1) / c; if indptr.len() != block_rows + 1 {
123 return Err(SparseError::DimensionMismatch {
124 expected: block_rows + 1,
125 found: indptr.len(),
126 });
127 }
128
129 if data.len() != indptr[block_rows] {
130 return Err(SparseError::DimensionMismatch {
131 expected: indptr[block_rows],
132 found: data.len(),
133 });
134 }
135
136 if indices.len() != data.len() {
137 return Err(SparseError::DimensionMismatch {
138 expected: data.len(),
139 found: indices.len(),
140 });
141 }
142
143 for block in data.iter() {
144 if block.len() != r {
145 return Err(SparseError::DimensionMismatch {
146 expected: r,
147 found: block.len(),
148 });
149 }
150
151 for row in block.iter() {
152 if row.len() != c {
153 return Err(SparseError::DimensionMismatch {
154 expected: c,
155 found: row.len(),
156 });
157 }
158 }
159 }
160
161 for idx_vec in indices.iter() {
162 if idx_vec.len() != 1 {
163 return Err(SparseError::ValueError(
164 "Each index vector must contain exactly one block column index".to_string(),
165 ));
166 }
167 if idx_vec[0] >= block_cols {
168 return Err(SparseError::ValueError(format!(
169 "index {} out of bounds (max {})",
170 idx_vec[0],
171 block_cols - 1
172 )));
173 }
174 }
175
176 Ok(BsrArray {
177 rows,
178 cols,
179 block_size,
180 block_rows,
181 block_cols,
182 data,
183 indices,
184 indptr,
185 })
186 }
187
188 pub fn empty(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
199 let (rows, cols) = shape;
200 let (r, c) = block_size;
201
202 if r == 0 || c == 0 {
203 return Err(SparseError::ValueError(
204 "Block dimensions must be positive".to_string(),
205 ));
206 }
207
208 #[allow(clippy::manual_div_ceil)]
210 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
212 let block_cols = (cols + c - 1) / c; let data = Vec::new();
216 let indices = Vec::new();
217 let indptr = vec![0; block_rows + 1];
218
219 Ok(BsrArray {
220 rows,
221 cols,
222 block_size,
223 block_rows,
224 block_cols,
225 data,
226 indices,
227 indptr,
228 })
229 }
230
231 pub fn from_triplets(
245 row: &[usize],
246 col: &[usize],
247 data: &[T],
248 shape: (usize, usize),
249 block_size: (usize, usize),
250 ) -> SparseResult<Self> {
251 if row.len() != col.len() || row.len() != data.len() {
252 return Err(SparseError::InconsistentData {
253 reason: "Lengths of row, col, and data arrays must be equal".to_string(),
254 });
255 }
256
257 let (rows, cols) = shape;
258 let (r, c) = block_size;
259
260 if r == 0 || c == 0 {
261 return Err(SparseError::ValueError(
262 "Block dimensions must be positive".to_string(),
263 ));
264 }
265
266 #[allow(clippy::manual_div_ceil)]
268 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
270 let block_cols = (cols + c - 1) / c; let mut block_data = std::collections::HashMap::new();
274
275 for (&row_idx, (&col_idx, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
277 if row_idx >= rows || col_idx >= cols {
278 return Err(SparseError::IndexOutOfBounds {
279 index: (row_idx, col_idx),
280 shape,
281 });
282 }
283
284 let block_row = row_idx / r;
286 let block_col = col_idx / c;
287
288 let block_row_pos = row_idx % r;
290 let block_col_pos = col_idx % c;
291
292 let block = block_data.entry((block_row, block_col)).or_insert_with(|| {
294 let block = vec![vec![T::sparse_zero(); c]; r];
295 block
296 });
297
298 block[block_row_pos][block_col_pos] = val;
300 }
301
302 let mut rowswith_blocks: Vec<usize> = block_data.keys().map(|&(row_, _)| row_).collect();
304 rowswith_blocks.sort();
305 rowswith_blocks.dedup();
306
307 let mut indptr = vec![0; block_rows + 1];
309 let mut current_nnz = 0;
310
311 let mut data = Vec::new();
313 let mut indices = Vec::new();
314
315 for row_idx in 0..block_rows {
316 if rowswith_blocks.contains(&row_idx) {
317 let mut row_blocks: Vec<(usize, Vec<Vec<T>>)> = block_data
319 .iter()
320 .filter(|&(&(r, _), _)| r == row_idx)
321 .map(|(&(_, c), block)| (c, block.clone()))
322 .collect();
323
324 row_blocks.sort_by_key(|&(col_, _)| col_);
326
327 for (col, block) in row_blocks {
329 data.push(block);
330 indices.push(vec![col]);
331 current_nnz += 1;
332 }
333 }
334
335 indptr[row_idx + 1] = current_nnz;
336 }
337
338 BsrArray::new(data, indices, indptr, shape, block_size)
340 }
341
342 fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
344 let (r, c) = self.block_size;
345 let mut row_indices = Vec::new();
346 let mut col_indices = Vec::new();
347 let mut values = Vec::new();
348
349 for block_row in 0..self.block_rows {
350 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
351 let block_col = self.indices[k][0];
352 let block = &self.data[k];
353
354 for (i, block_row_data) in block.iter().enumerate().take(r) {
356 let row = block_row * r + i;
357 if row < self.rows {
358 for (j, &value) in block_row_data.iter().enumerate().take(c) {
359 let col = block_col * c + j;
360 if col < self.cols && !SparseElement::is_zero(&value) {
361 row_indices.push(row);
362 col_indices.push(col);
363 values.push(value);
364 }
365 }
366 }
367 }
368 }
369 }
370
371 (row_indices, col_indices, values)
372 }
373}
374
375impl<T> SparseArray<T> for BsrArray<T>
376where
377 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
378{
379 fn shape(&self) -> (usize, usize) {
380 (self.rows, self.cols)
381 }
382
383 fn nnz(&self) -> usize {
384 let mut count = 0;
385
386 for block in &self.data {
387 for row in block {
388 for &val in row {
389 if !SparseElement::is_zero(&val) {
390 count += 1;
391 }
392 }
393 }
394 }
395
396 count
397 }
398
399 fn dtype(&self) -> &str {
400 "float" }
402
403 fn to_array(&self) -> Array2<T> {
404 let mut result = Array2::zeros((self.rows, self.cols));
405 let (r, c) = self.block_size;
406
407 for block_row in 0..self.block_rows {
408 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
409 let block_col = self.indices[k][0];
410 let block = &self.data[k];
411
412 for (i, block_row_data) in block.iter().enumerate().take(r) {
413 let row = block_row * r + i;
414 if row < self.rows {
415 for (j, &value) in block_row_data.iter().enumerate().take(c) {
416 let col = block_col * c + j;
417 if col < self.cols {
418 result[[row, col]] = value;
419 }
420 }
421 }
422 }
423 }
424 }
425
426 result
427 }
428
429 fn toarray(&self) -> Array2<T> {
430 self.to_array()
431 }
432
433 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
434 let (row_indices, col_indices, values) = self.to_coo_internal();
435 CooArray::from_triplets(
436 &row_indices,
437 &col_indices,
438 &values,
439 (self.rows, self.cols),
440 false,
441 )
442 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
443 }
444
445 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
446 let (row_indices, col_indices, values) = self.to_coo_internal();
447 CsrArray::from_triplets(
448 &row_indices,
449 &col_indices,
450 &values,
451 (self.rows, self.cols),
452 false,
453 )
454 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
455 }
456
457 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
458 let (row_indices, col_indices, values) = self.to_coo_internal();
459 CscArray::from_triplets(
460 &row_indices,
461 &col_indices,
462 &values,
463 (self.rows, self.cols),
464 false,
465 )
466 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
467 }
468
469 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
470 let (row_indices, col_indices, values) = self.to_coo_internal();
471 DokArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
472 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
473 }
474
475 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
476 let (row_indices, col_indices, values) = self.to_coo_internal();
477 LilArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
478 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
479 }
480
481 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
482 let (row_indices, col_indices, values) = self.to_coo_internal();
483 DiaArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
484 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
485 }
486
487 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
488 Ok(Box::new(self.clone()))
489 }
490
491 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
492 let csr_self = self.to_csr()?;
494 let csr_other = other.to_csr()?;
495 csr_self.add(&*csr_other)
496 }
497
498 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
499 let csr_self = self.to_csr()?;
501 let csr_other = other.to_csr()?;
502 csr_self.sub(&*csr_other)
503 }
504
505 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
506 let csr_self = self.to_csr()?;
508 let csr_other = other.to_csr()?;
509 csr_self.mul(&*csr_other)
510 }
511
512 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
513 let csr_self = self.to_csr()?;
515 let csr_other = other.to_csr()?;
516 csr_self.div(&*csr_other)
517 }
518
519 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
520 let (_, n) = self.shape();
521 let (p, q) = other.shape();
522
523 if n != p {
524 return Err(SparseError::DimensionMismatch {
525 expected: n,
526 found: p,
527 });
528 }
529
530 if q == 1 {
532 let other_array = other.to_array();
534 let vec_view = other_array.column(0);
535
536 let result = self.dot_vector(&vec_view)?;
538
539 let mut rows = Vec::new();
541 let mut cols = Vec::new();
542 let mut values = Vec::new();
543
544 for (i, &val) in result.iter().enumerate() {
545 if !SparseElement::is_zero(&val) {
546 rows.push(i);
547 cols.push(0);
548 values.push(val);
549 }
550 }
551
552 CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
553 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
554 } else {
555 let csr_self = self.to_csr()?;
557 csr_self.dot(other)
558 }
559 }
560
561 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
562 let (rows, cols) = self.shape();
563 let (r, c) = self.block_size;
564
565 if cols != other.len() {
566 return Err(SparseError::DimensionMismatch {
567 expected: cols,
568 found: other.len(),
569 });
570 }
571
572 let mut result = Array1::zeros(rows);
573
574 for block_row in 0..self.block_rows {
575 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
576 let block_col = self.indices[k][0];
577 let block = &self.data[k];
578
579 for (i, block_row_data) in block.iter().enumerate().take(r) {
581 let row = block_row * r + i;
582 if row < self.rows {
583 for (j, &value) in block_row_data.iter().enumerate().take(c) {
584 let col = block_col * c + j;
585 if col < self.cols {
586 result[row] += value * other[col];
587 }
588 }
589 }
590 }
591 }
592 }
593
594 Ok(result)
595 }
596
597 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
598 self.to_coo()?.transpose()?.to_bsr()
600 }
601
602 fn copy(&self) -> Box<dyn SparseArray<T>> {
603 Box::new(self.clone())
604 }
605
606 fn get(&self, i: usize, j: usize) -> T {
607 if i >= self.rows || j >= self.cols {
608 return T::sparse_zero();
609 }
610
611 let (r, c) = self.block_size;
612 let block_row = i / r;
613 let block_col = j / c;
614 let block_row_pos = i % r;
615 let block_col_pos = j % c;
616
617 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
619 if self.indices[k][0] == block_col {
620 return self.data[k][block_row_pos][block_col_pos];
621 }
622 }
623
624 T::sparse_zero()
625 }
626
627 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
628 if i >= self.rows || j >= self.cols {
629 return Err(SparseError::IndexOutOfBounds {
630 index: (i, j),
631 shape: (self.rows, self.cols),
632 });
633 }
634
635 let (r, c) = self.block_size;
636 let block_row = i / r;
637 let block_col = j / c;
638 let block_row_pos = i % r;
639 let block_col_pos = j % c;
640
641 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
643 if self.indices[k][0] == block_col {
644 self.data[k][block_row_pos][block_col_pos] = value;
646 return Ok(());
647 }
648 }
649
650 if !SparseElement::is_zero(&value) {
652 let pos = self.indptr[block_row + 1];
654
655 let mut block = vec![vec![T::sparse_zero(); c]; r];
657 block[block_row_pos][block_col_pos] = value;
658
659 self.data.insert(pos, block);
661 self.indices.insert(pos, vec![block_col]);
662
663 for k in (block_row + 1)..=self.block_rows {
665 self.indptr[k] += 1;
666 }
667
668 Ok(())
669 } else {
670 Ok(())
672 }
673 }
674
675 fn eliminate_zeros(&mut self) {
676 let mut new_data = Vec::new();
678 let mut new_indices = Vec::new();
679 let mut new_indptr = vec![0];
680 let mut current_nnz = 0;
681
682 for block_row in 0..self.block_rows {
683 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
684 let block_col = self.indices[k][0];
685 let block = &self.data[k];
686
687 let mut has_nonzero = false;
689 for row in block {
690 for &val in row {
691 if !SparseElement::is_zero(&val) {
692 has_nonzero = true;
693 break;
694 }
695 }
696 if has_nonzero {
697 break;
698 }
699 }
700
701 if has_nonzero {
702 new_data.push(block.clone());
703 new_indices.push(vec![block_col]);
704 current_nnz += 1;
705 }
706 }
707
708 new_indptr.push(current_nnz);
709 }
710
711 self.data = new_data;
712 self.indices = new_indices;
713 self.indptr = new_indptr;
714 }
715
716 fn sort_indices(&mut self) {
717 let mut new_data = Vec::new();
719 let mut new_indices = Vec::new();
720 let mut new_indptr = vec![0];
721 let mut current_nnz = 0;
722
723 for block_row in 0..self.block_rows {
724 let mut row_blocks = Vec::new();
726 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
727 row_blocks.push((self.indices[k][0], self.data[k].clone()));
728 }
729
730 row_blocks.sort_by_key(|&(col_, _)| col_);
732
733 for (col, block) in row_blocks {
735 new_data.push(block);
736 new_indices.push(vec![col]);
737 current_nnz += 1;
738 }
739
740 new_indptr.push(current_nnz);
741 }
742
743 self.data = new_data;
744 self.indices = new_indices;
745 self.indptr = new_indptr;
746 }
747
748 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
749 let mut result = self.clone();
750 result.sort_indices();
751 Box::new(result)
752 }
753
754 fn has_sorted_indices(&self) -> bool {
755 for block_row in 0..self.block_rows {
756 let mut prev_col = None;
757
758 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
759 let col = self.indices[k][0];
760
761 if let Some(prev) = prev_col {
762 if col <= prev {
763 return false;
764 }
765 }
766
767 prev_col = Some(col);
768 }
769 }
770
771 true
772 }
773
774 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
775 match axis {
776 None => {
777 let mut total = T::sparse_zero();
779
780 for block in &self.data {
781 for row in block {
782 for &val in row {
783 total += val;
784 }
785 }
786 }
787
788 Ok(SparseSum::Scalar(total))
789 }
790 Some(0) => {
791 let mut result = vec![T::sparse_zero(); self.cols];
793 let (r, c) = self.block_size;
794
795 for block_row in 0..self.block_rows {
796 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
797 let block_col = self.indices[k][0];
798 let block = &self.data[k];
799
800 for block_row_data in block.iter().take(r) {
801 for (j, &value) in block_row_data.iter().enumerate().take(c) {
802 let col = block_col * c + j;
803 if col < self.cols {
804 result[col] += value;
805 }
806 }
807 }
808 }
809 }
810
811 let mut row_indices = Vec::new();
813 let mut col_indices = Vec::new();
814 let mut values = Vec::new();
815
816 for (j, &val) in result.iter().enumerate() {
817 if !SparseElement::is_zero(&val) {
818 row_indices.push(0);
819 col_indices.push(j);
820 values.push(val);
821 }
822 }
823
824 match CooArray::from_triplets(
825 &row_indices,
826 &col_indices,
827 &values,
828 (1, self.cols),
829 false,
830 ) {
831 Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
832 Err(e) => Err(e),
833 }
834 }
835 Some(1) => {
836 let mut result = vec![T::sparse_zero(); self.rows];
838 let (r, c) = self.block_size;
839
840 for block_row in 0..self.block_rows {
841 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
842 let block = &self.data[k];
843
844 for (i, block_row_data) in block.iter().enumerate().take(r) {
845 let row = block_row * r + i;
846 if row < self.rows {
847 for &value in block_row_data.iter().take(c) {
848 result[row] += value;
849 }
850 }
851 }
852 }
853 }
854
855 let mut row_indices = Vec::new();
857 let mut col_indices = Vec::new();
858 let mut values = Vec::new();
859
860 for (i, &val) in result.iter().enumerate() {
861 if !SparseElement::is_zero(&val) {
862 row_indices.push(i);
863 col_indices.push(0);
864 values.push(val);
865 }
866 }
867
868 match CooArray::from_triplets(
869 &row_indices,
870 &col_indices,
871 &values,
872 (self.rows, 1),
873 false,
874 ) {
875 Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
876 Err(e) => Err(e),
877 }
878 }
879 _ => Err(SparseError::InvalidAxis),
880 }
881 }
882
883 fn max(&self) -> T {
884 let mut max_val = T::neg_infinity();
885
886 for block in &self.data {
887 for row in block {
888 for &val in row {
889 max_val = max_val.max(val);
890 }
891 }
892 }
893
894 if max_val == T::neg_infinity() {
896 T::sparse_zero()
897 } else {
898 max_val
899 }
900 }
901
902 fn min(&self) -> T {
903 let mut min_val = T::sparse_zero();
904 let mut has_nonzero = false;
905
906 for block in &self.data {
907 for row in block {
908 for &val in row {
909 if !SparseElement::is_zero(&val) {
910 has_nonzero = true;
911 min_val = min_val.min(val);
912 }
913 }
914 }
915 }
916
917 if !has_nonzero {
919 T::sparse_zero()
920 } else {
921 min_val
922 }
923 }
924
925 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
926 let (row_indices, col_indices, values) = self.to_coo_internal();
927
928 (
929 Array1::from_vec(row_indices),
930 Array1::from_vec(col_indices),
931 Array1::from_vec(values),
932 )
933 }
934
935 fn slice(
936 &self,
937 row_range: (usize, usize),
938 col_range: (usize, usize),
939 ) -> SparseResult<Box<dyn SparseArray<T>>> {
940 let (start_row, end_row) = row_range;
941 let (start_col, end_col) = col_range;
942 let (rows, cols) = self.shape();
943
944 if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
945 return Err(SparseError::IndexOutOfBounds {
946 index: (start_row.max(end_row), start_col.max(end_col)),
947 shape: (rows, cols),
948 });
949 }
950
951 if start_row >= end_row || start_col >= end_col {
952 return Err(SparseError::InvalidSliceRange);
953 }
954
955 let coo = self.to_coo()?;
957 coo.slice(row_range, col_range)?.to_bsr()
958 }
959
960 fn as_any(&self) -> &dyn std::any::Any {
961 self
962 }
963}
964
965impl<T> fmt::Display for BsrArray<T>
967where
968 T: SparseElement + Div<Output = T> + Float + 'static + std::ops::AddAssign,
969{
970 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
971 writeln!(
972 f,
973 "BsrArray of shape {:?} with {} stored elements",
974 (self.rows, self.cols),
975 self.nnz()
976 )?;
977 writeln!(f, "Block size: {:?}", self.block_size)?;
978 writeln!(f, "Number of blocks: {}", self.data.len())?;
979
980 if self.data.len() <= 5 {
981 for block_row in 0..self.block_rows {
982 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
983 let block_col = self.indices[k][0];
984 let block = &self.data[k];
985
986 writeln!(f, "Block at ({block_row}, {block_col}): ")?;
987 for row in block {
988 write!(f, " [")?;
989 for (j, &val) in row.iter().enumerate() {
990 if j > 0 {
991 write!(f, ", ")?;
992 }
993 write!(f, "{val:?}")?;
994 }
995 writeln!(f, "]")?;
996 }
997 }
998 }
999 } else {
1000 writeln!(f, "({} blocks total)", self.data.len())?;
1001 }
1002
1003 Ok(())
1004 }
1005}
1006
#[cfg(test)]
mod tests {
    use super::*;

    /// 4x4 matrix holding two dense 2x2 blocks on its block diagonal:
    /// `[[1, 2, 0, 0], [3, 4, 0, 0], [0, 0, 5, 6], [0, 0, 7, 8]]`.
    fn block_diagonal_4x4() -> BsrArray<f64> {
        let top_left = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let bottom_right = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
        BsrArray::new(
            vec![top_left, bottom_right],
            vec![vec![0], vec![1]],
            vec![0, 1, 2],
            (4, 4),
            (2, 2),
        )
        .unwrap()
    }

    /// 2x2 matrix stored as one 2x2 block.
    fn single_block_2x2(block: Vec<Vec<f64>>) -> BsrArray<f64> {
        BsrArray::new(vec![block], vec![vec![0]], vec![0, 1], (2, 2), (2, 2)).unwrap()
    }

    /// Asserts `array.get(i, j) == value` for every expected `(i, j, value)`.
    fn assert_entries(array: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(i, j, value) in expected {
            assert_eq!(array.get(i, j), value, "mismatch at ({i}, {j})");
        }
    }

    /// Entries of `block_diagonal_4x4`, plus one implicit zero at `(0, 2)`.
    const BLOCK_DIAGONAL_ENTRIES: [(usize, usize, f64); 9] = [
        (0, 0, 1.0),
        (0, 1, 2.0),
        (1, 0, 3.0),
        (1, 1, 4.0),
        (2, 2, 5.0),
        (2, 3, 6.0),
        (3, 2, 7.0),
        (3, 3, 8.0),
        (0, 2, 0.0),
    ];

    #[test]
    fn test_bsr_array_create() {
        let array = block_diagonal_4x4();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        assert_entries(&array, &BLOCK_DIAGONAL_ENTRIES);
    }

    #[test]
    fn test_bsr_array_from_triplets() {
        let rows = [0, 0, 1, 1, 2, 2, 3, 3];
        let cols = [0, 1, 0, 1, 2, 3, 2, 3];
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let array = BsrArray::from_triplets(&rows, &cols, &values, (4, 4), (2, 2)).unwrap();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        assert_entries(&array, &BLOCK_DIAGONAL_ENTRIES);
    }

    #[test]
    fn test_bsr_array_conversion() {
        let array = block_diagonal_4x4();

        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (4, 4));
        assert_eq!(coo.nnz(), 8);

        let csr = array.to_csr().unwrap();
        assert_eq!(csr.shape(), (4, 4));
        assert_eq!(csr.nnz(), 8);

        let expected = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 2.0, 0.0, 0.0, //
                3.0, 4.0, 0.0, 0.0, //
                0.0, 0.0, 5.0, 6.0, //
                0.0, 0.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_bsr_array_operations() {
        let array1 = single_block_2x2(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let array2 = single_block_2x2(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);

        let sum = array1.add(&array2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_entries(
            &*sum,
            &[(0, 0, 6.0), (0, 1, 8.0), (1, 0, 10.0), (1, 1, 12.0)],
        );

        let product = array1.mul(&array2).unwrap();
        assert_eq!(product.shape(), (2, 2));
        assert_entries(
            &*product,
            &[(0, 0, 5.0), (0, 1, 12.0), (1, 0, 21.0), (1, 1, 32.0)],
        );

        let dot = array1.dot(&array2).unwrap();
        assert_eq!(dot.shape(), (2, 2));
        assert_entries(
            &*dot,
            &[(0, 0, 19.0), (0, 1, 22.0), (1, 0, 43.0), (1, 1, 50.0)],
        );
    }

    #[test]
    fn test_bsr_array_dot_vector() {
        let array = block_diagonal_4x4();
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = array.dot_vector(&vector.view()).unwrap();

        assert_eq!(result, Array1::from_vec(vec![5.0, 11.0, 39.0, 53.0]));
    }

    #[test]
    fn test_bsr_array_sum() {
        let array = block_diagonal_4x4();

        match array.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 36.0),
            _ => panic!("Expected SparseSum::Scalar"),
        }

        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(column_sums) => {
                assert_eq!(column_sums.shape(), (1, 4));
                assert_entries(
                    &*column_sums,
                    &[(0, 0, 4.0), (0, 1, 6.0), (0, 2, 12.0), (0, 3, 14.0)],
                );
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }

        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(row_sums) => {
                assert_eq!(row_sums.shape(), (4, 1));
                assert_entries(
                    &*row_sums,
                    &[(0, 0, 3.0), (1, 0, 7.0), (2, 0, 11.0), (3, 0, 15.0)],
                );
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }
    }
}