1use oxiblas_core::scalar::{Field, Scalar};
26
/// Errors describing why the raw components of a BSR matrix were rejected.
///
/// Every payload field is a plain `usize`, which is what allows the `Eq` derive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BsrError {
    /// The requested block shape is unusable (the constructor rejects a zero dimension).
    InvalidBlockSize {
        /// Requested number of rows per block.
        block_rows: usize,
        /// Requested number of columns per block.
        block_cols: usize,
    },
    /// The matrix shape and the block shape do not fit together.
    ///
    /// NOTE(review): no code visible in this file constructs this variant.
    IncompatibleDimensions {
        /// Number of matrix rows.
        nrows: usize,
        /// Number of matrix columns.
        ncols: usize,
        /// Number of rows per block.
        block_rows: usize,
        /// Number of columns per block.
        block_cols: usize,
    },
    /// The block-row pointer array is inconsistent with the rest of the input.
    ///
    /// Reported for a wrong `indptr` length, and also reused when an `indptr`
    /// entry does not have the value the other components require.
    InvalidIndptr {
        /// The value that was required.
        expected: usize,
        /// The value that was found.
        actual: usize,
    },
    /// The number of blocks differs from the number of block column indices.
    DataIndicesMismatch {
        /// Number of blocks supplied.
        num_blocks: usize,
        /// Number of block column indices supplied.
        num_indices: usize,
    },
    /// A block column index lies outside the block grid.
    InvalidBlockIndex {
        /// The offending block column index.
        index: usize,
        /// Number of block columns in the grid.
        nb_cols: usize,
    },
    /// An `indptr` entry is smaller than the entry before it.
    InvalidIndptrOrder,
    /// A stored block does not fit the block shape of the matrix.
    InvalidBlockData {
        /// Position of the offending block in the block list.
        block_idx: usize,
        /// Number of elements a block must hold.
        expected: usize,
        /// Number of elements the block actually holds.
        actual: usize,
    },
}
81
82impl core::fmt::Display for BsrError {
83 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
84 match self {
85 Self::InvalidBlockSize {
86 block_rows,
87 block_cols,
88 } => {
89 write!(f, "Invalid block size: {block_rows}×{block_cols}")
90 }
91 Self::IncompatibleDimensions {
92 nrows,
93 ncols,
94 block_rows,
95 block_cols,
96 } => {
97 write!(
98 f,
99 "Matrix {nrows}×{ncols} incompatible with {block_rows}×{block_cols} blocks"
100 )
101 }
102 Self::InvalidIndptr { expected, actual } => {
103 write!(
104 f,
105 "Invalid indptr length: expected {expected}, got {actual}"
106 )
107 }
108 Self::DataIndicesMismatch {
109 num_blocks,
110 num_indices,
111 } => {
112 write!(f, "Mismatch: {num_blocks} blocks but {num_indices} indices")
113 }
114 Self::InvalidBlockIndex { index, nb_cols } => {
115 write!(
116 f,
117 "Block column index {index} out of bounds (nb_cols={nb_cols})"
118 )
119 }
120 Self::InvalidIndptrOrder => {
121 write!(f, "Indptr must be monotonically increasing")
122 }
123 Self::InvalidBlockData {
124 block_idx,
125 expected,
126 actual,
127 } => {
128 write!(
129 f,
130 "Block {block_idx}: expected {expected} elements, got {actual}"
131 )
132 }
133 }
134 }
135}
136
// Marker impl: `BsrError` wraps no underlying error, so every default method of
// `std::error::Error` is kept as-is.
impl std::error::Error for BsrError {}
138
/// A small dense matrix used as one stored block of a BSR matrix.
///
/// Element `(i, j)` lives at `data[i * cols + j]`, i.e. the block is stored row by row.
#[derive(Debug, Clone)]
pub struct DenseBlock<T: Scalar> {
    /// Block elements in row-major order.
    data: Vec<T>,
    /// Number of rows in the block.
    rows: usize,
    /// Number of columns in the block; also the stride between consecutive rows.
    cols: usize,
}
149
150impl<T: Scalar + Clone> DenseBlock<T> {
151 pub fn new(rows: usize, cols: usize, data: Vec<T>) -> Self {
153 debug_assert_eq!(data.len(), rows * cols);
154 Self { data, rows, cols }
155 }
156
157 pub fn zeros(rows: usize, cols: usize) -> Self
159 where
160 T: Field,
161 {
162 Self {
163 data: vec![T::zero(); rows * cols],
164 rows,
165 cols,
166 }
167 }
168
169 #[inline]
171 pub fn get(&self, i: usize, j: usize) -> &T {
172 &self.data[i * self.cols + j]
173 }
174
175 #[inline]
177 pub fn get_mut(&mut self, i: usize, j: usize) -> &mut T {
178 &mut self.data[i * self.cols + j]
179 }
180
181 #[inline]
183 pub fn shape(&self) -> (usize, usize) {
184 (self.rows, self.cols)
185 }
186
187 #[inline]
189 pub fn data(&self) -> &[T] {
190 &self.data
191 }
192
193 #[inline]
195 pub fn data_mut(&mut self) -> &mut [T] {
196 &mut self.data
197 }
198
199 pub fn matvec_add(&self, x: &[T], y: &mut [T])
201 where
202 T: Field,
203 {
204 for i in 0..self.rows {
205 for j in 0..self.cols {
206 y[i] = y[i].clone() + self.get(i, j).clone() * x[j].clone();
207 }
208 }
209 }
210
211 pub fn scale(&mut self, alpha: T) {
213 for val in &mut self.data {
214 *val = val.clone() * alpha.clone();
215 }
216 }
217
218 pub fn frobenius_norm_sq(&self) -> T
220 where
221 T: Field,
222 {
223 self.data
224 .iter()
225 .fold(T::zero(), |acc, val| acc + val.clone() * val.clone())
226 }
227}
228
/// Sparse matrix in Block Sparse Row (BSR) form: a grid of fixed-size blocks in
/// which only some blocks are stored, each as a [`DenseBlock`].
///
/// Block row `bi` owns the stored blocks at positions `indptr[bi]..indptr[bi + 1]`
/// of `indices` and `data`.
#[derive(Debug, Clone)]
pub struct BsrMatrix<T: Scalar> {
    /// Number of scalar rows.
    nrows: usize,
    /// Number of scalar columns.
    ncols: usize,
    /// Number of rows in each block.
    block_rows: usize,
    /// Number of columns in each block.
    block_cols: usize,
    /// Number of block rows, computed as `nrows.div_ceil(block_rows)`.
    mb: usize,
    /// Number of block columns, computed as `ncols.div_ceil(block_cols)`.
    nb: usize,
    /// Block-row pointers (`mb + 1` entries) into `indices` / `data`.
    indptr: Vec<usize>,
    /// Block column index of each stored block.
    indices: Vec<usize>,
    /// The stored blocks, parallel to `indices`.
    data: Vec<DenseBlock<T>>,
}
286
287impl<T: Scalar + Clone> BsrMatrix<T> {
288 pub fn new(
304 nrows: usize,
305 ncols: usize,
306 block_rows: usize,
307 block_cols: usize,
308 indptr: Vec<usize>,
309 indices: Vec<usize>,
310 data: Vec<DenseBlock<T>>,
311 ) -> Result<Self, BsrError> {
312 if block_rows == 0 || block_cols == 0 {
314 return Err(BsrError::InvalidBlockSize {
315 block_rows,
316 block_cols,
317 });
318 }
319
320 let mb = nrows.div_ceil(block_rows);
322 let nb = ncols.div_ceil(block_cols);
323
324 if indptr.len() != mb + 1 {
326 return Err(BsrError::InvalidIndptr {
327 expected: mb + 1,
328 actual: indptr.len(),
329 });
330 }
331
332 for i in 1..indptr.len() {
334 if indptr[i] < indptr[i - 1] {
335 return Err(BsrError::InvalidIndptrOrder);
336 }
337 }
338
339 let nnz_blocks = data.len();
341 if indices.len() != nnz_blocks {
342 return Err(BsrError::DataIndicesMismatch {
343 num_blocks: nnz_blocks,
344 num_indices: indices.len(),
345 });
346 }
347
348 if indptr[mb] != nnz_blocks {
350 return Err(BsrError::InvalidIndptr {
351 expected: nnz_blocks,
352 actual: indptr[mb],
353 });
354 }
355
356 for &idx in &indices {
358 if idx >= nb {
359 return Err(BsrError::InvalidBlockIndex {
360 index: idx,
361 nb_cols: nb,
362 });
363 }
364 }
365
366 let block_size = block_rows * block_cols;
368 for (i, block) in data.iter().enumerate() {
369 if block.data.len() != block_size {
370 return Err(BsrError::InvalidBlockData {
371 block_idx: i,
372 expected: block_size,
373 actual: block.data.len(),
374 });
375 }
376 }
377
378 Ok(Self {
379 nrows,
380 ncols,
381 block_rows,
382 block_cols,
383 mb,
384 nb,
385 indptr,
386 indices,
387 data,
388 })
389 }
390
391 #[inline]
397 pub unsafe fn new_unchecked(
398 nrows: usize,
399 ncols: usize,
400 block_rows: usize,
401 block_cols: usize,
402 indptr: Vec<usize>,
403 indices: Vec<usize>,
404 data: Vec<DenseBlock<T>>,
405 ) -> Self {
406 let mb = nrows.div_ceil(block_rows);
407 let nb = ncols.div_ceil(block_cols);
408 Self {
409 nrows,
410 ncols,
411 block_rows,
412 block_cols,
413 mb,
414 nb,
415 indptr,
416 indices,
417 data,
418 }
419 }
420
421 pub fn zeros(nrows: usize, ncols: usize, block_rows: usize, block_cols: usize) -> Self {
423 let mb = nrows.div_ceil(block_rows);
424 Self {
425 nrows,
426 ncols,
427 block_rows,
428 block_cols,
429 mb,
430 nb: ncols.div_ceil(block_cols),
431 indptr: vec![0; mb + 1],
432 indices: Vec::new(),
433 data: Vec::new(),
434 }
435 }
436
437 pub fn eye(n: usize, block_size: usize) -> Self
439 where
440 T: Field,
441 {
442 let mb = n.div_ceil(block_size);
443 let mut indptr = Vec::with_capacity(mb + 1);
444 let mut indices = Vec::with_capacity(mb);
445 let mut data = Vec::with_capacity(mb);
446
447 indptr.push(0);
448
449 for bi in 0..mb {
450 indices.push(bi);
451
452 let mut block_data = vec![T::zero(); block_size * block_size];
454 for i in 0..block_size {
455 let global_row = bi * block_size + i;
456 if global_row < n {
457 block_data[i * block_size + i] = T::one();
458 }
459 }
460 data.push(DenseBlock::new(block_size, block_size, block_data));
461
462 indptr.push(data.len());
463 }
464
465 Self {
466 nrows: n,
467 ncols: n,
468 block_rows: block_size,
469 block_cols: block_size,
470 mb,
471 nb: mb,
472 indptr,
473 indices,
474 data,
475 }
476 }
477
    /// Returns the number of scalar rows of the matrix.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }
483
    /// Returns the number of scalar columns of the matrix.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }
489
490 #[inline]
492 pub fn shape(&self) -> (usize, usize) {
493 (self.nrows, self.ncols)
494 }
495
496 #[inline]
498 pub fn block_shape(&self) -> (usize, usize) {
499 (self.block_rows, self.block_cols)
500 }
501
    /// Returns the number of block rows in the block grid.
    #[inline]
    pub fn nblock_rows(&self) -> usize {
        self.mb
    }
507
    /// Returns the number of block columns in the block grid.
    #[inline]
    pub fn nblock_cols(&self) -> usize {
        self.nb
    }
513
    /// Returns the number of blocks that are actually stored.
    #[inline]
    pub fn nblocks(&self) -> usize {
        self.data.len()
    }
519
520 pub fn nnz(&self) -> usize
524 where
525 T: Field,
526 {
527 let eps = <T as Scalar>::epsilon();
528 let mut count = 0;
529
530 for block in &self.data {
531 for val in block.data() {
532 if Scalar::abs(val.clone()) > eps {
533 count += 1;
534 }
535 }
536 }
537
538 count
539 }
540
541 #[inline]
543 pub fn nstored(&self) -> usize {
544 self.data.len() * self.block_rows * self.block_cols
545 }
546
    /// Returns the block-row pointer array; block row `bi` owns the stored blocks
    /// at positions `indptr[bi]..indptr[bi + 1]`.
    #[inline]
    pub fn indptr(&self) -> &[usize] {
        &self.indptr
    }
552
    /// Returns the block column index of every stored block, in storage order.
    #[inline]
    pub fn indices(&self) -> &[usize] {
        &self.indices
    }
558
    /// Returns the stored blocks, in storage order (parallel to `indices()`).
    #[inline]
    pub fn data(&self) -> &[DenseBlock<T>] {
        &self.data
    }
564
    /// Returns the stored blocks for in-place modification of their values.
    ///
    /// Only a slice is exposed, so blocks cannot be added or removed through it.
    #[inline]
    pub fn data_mut(&mut self) -> &mut [DenseBlock<T>] {
        &mut self.data
    }
570
571 pub fn get_block(&self, bi: usize, bj: usize) -> Option<&DenseBlock<T>> {
573 if bi >= self.mb || bj >= self.nb {
574 return None;
575 }
576
577 let start = self.indptr[bi];
578 let end = self.indptr[bi + 1];
579
580 for k in start..end {
581 if self.indices[k] == bj {
582 return Some(&self.data[k]);
583 }
584 }
585
586 None
587 }
588
589 pub fn get(&self, row: usize, col: usize) -> Option<T>
591 where
592 T: Field,
593 {
594 if row >= self.nrows || col >= self.ncols {
595 return None;
596 }
597
598 let bi = row / self.block_rows;
599 let bj = col / self.block_cols;
600 let local_i = row % self.block_rows;
601 let local_j = col % self.block_cols;
602
603 self.get_block(bi, bj)
604 .map(|block| block.get(local_i, local_j).clone())
605 }
606
607 pub fn get_or_zero(&self, row: usize, col: usize) -> T
609 where
610 T: Field,
611 {
612 self.get(row, col).unwrap_or_else(T::zero)
613 }
614
615 pub fn block_iter(&self) -> impl Iterator<Item = (usize, usize, &DenseBlock<T>)> + '_ {
617 (0..self.mb).flat_map(move |bi| {
618 let start = self.indptr[bi];
619 let end = self.indptr[bi + 1];
620
621 (start..end).map(move |k| (bi, self.indices[k], &self.data[k]))
622 })
623 }
624
625 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, T)> + '_
627 where
628 T: Field,
629 {
630 let eps = <T as Scalar>::epsilon();
631 let br = self.block_rows;
632 let bc = self.block_cols;
633 let nrows = self.nrows;
634 let ncols = self.ncols;
635
636 self.block_iter().flat_map(move |(bi, bj, block)| {
637 let base_row = bi * br;
638 let base_col = bj * bc;
639
640 (0..br).flat_map(move |i| {
641 (0..bc).filter_map(move |j| {
642 let global_row = base_row + i;
643 let global_col = base_col + j;
644
645 if global_row < nrows && global_col < ncols {
646 let val = block.get(i, j).clone();
647 if Scalar::abs(val.clone()) > eps {
648 return Some((global_row, global_col, val));
649 }
650 }
651 None
652 })
653 })
654 })
655 }
656
657 pub fn matvec(&self, x: &[T], y: &mut [T])
659 where
660 T: Field,
661 {
662 assert_eq!(x.len(), self.ncols, "x length must equal ncols");
663 assert_eq!(y.len(), self.nrows, "y length must equal nrows");
664
665 for yi in y.iter_mut() {
667 *yi = T::zero();
668 }
669
670 for bi in 0..self.mb {
672 let start = self.indptr[bi];
673 let end = self.indptr[bi + 1];
674 let row_start = bi * self.block_rows;
675 let row_end = (row_start + self.block_rows).min(self.nrows);
676
677 for k in start..end {
678 let bj = self.indices[k];
679 let block = &self.data[k];
680 let col_start = bj * self.block_cols;
681 let col_end = (col_start + self.block_cols).min(self.ncols);
682
683 for (i, yi) in y[row_start..row_end].iter_mut().enumerate() {
685 for j in 0..(col_end - col_start) {
686 *yi = yi.clone() + block.get(i, j).clone() * x[col_start + j].clone();
687 }
688 }
689 }
690 }
691 }
692
693 pub fn mul_vec(&self, x: &[T]) -> Vec<T>
695 where
696 T: Field,
697 {
698 let mut y = vec![T::zero(); self.nrows];
699 self.matvec(x, &mut y);
700 y
701 }
702
703 pub fn to_csr(&self) -> crate::csr::CsrMatrix<T>
705 where
706 T: Field,
707 {
708 let eps = <T as Scalar>::epsilon();
709
710 let mut row_ptrs = vec![0usize; self.nrows + 1];
711 let mut col_indices = Vec::new();
712 let mut values = Vec::new();
713
714 for row in 0..self.nrows {
715 let bi = row / self.block_rows;
716 let local_i = row % self.block_rows;
717
718 let block_start = self.indptr[bi];
719 let block_end = self.indptr[bi + 1];
720
721 let mut row_entries: Vec<(usize, T)> = Vec::new();
722
723 for k in block_start..block_end {
724 let bj = self.indices[k];
725 let block = &self.data[k];
726
727 for j in 0..self.block_cols {
728 let global_col = bj * self.block_cols + j;
729 if global_col < self.ncols {
730 let val = block.get(local_i, j).clone();
731 if Scalar::abs(val.clone()) > eps {
732 row_entries.push((global_col, val));
733 }
734 }
735 }
736 }
737
738 row_entries.sort_by_key(|(col, _)| *col);
740
741 for (col, val) in row_entries {
742 col_indices.push(col);
743 values.push(val);
744 }
745 row_ptrs[row + 1] = values.len();
746 }
747
748 unsafe {
750 crate::csr::CsrMatrix::new_unchecked(
751 self.nrows,
752 self.ncols,
753 row_ptrs,
754 col_indices,
755 values,
756 )
757 }
758 }
759
760 pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
762 where
763 T: Field + bytemuck::Zeroable,
764 {
765 let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
766
767 for bi in 0..self.mb {
768 let start = self.indptr[bi];
769 let end = self.indptr[bi + 1];
770 let row_start = bi * self.block_rows;
771
772 for k in start..end {
773 let bj = self.indices[k];
774 let block = &self.data[k];
775 let col_start = bj * self.block_cols;
776
777 for i in 0..self.block_rows {
778 let global_row = row_start + i;
779 if global_row >= self.nrows {
780 break;
781 }
782
783 for j in 0..self.block_cols {
784 let global_col = col_start + j;
785 if global_col >= self.ncols {
786 break;
787 }
788
789 dense[(global_row, global_col)] = block.get(i, j).clone();
790 }
791 }
792 }
793 }
794
795 dense
796 }
797
798 pub fn from_dense(
806 dense: &oxiblas_matrix::MatRef<'_, T>,
807 block_rows: usize,
808 block_cols: usize,
809 ) -> Self
810 where
811 T: Field,
812 {
813 let (nrows, ncols) = dense.shape();
814 let eps = <T as Scalar>::epsilon();
815
816 let mb = nrows.div_ceil(block_rows);
817 let nb = ncols.div_ceil(block_cols);
818
819 let mut indptr = Vec::with_capacity(mb + 1);
820 let mut indices = Vec::new();
821 let mut data = Vec::new();
822
823 indptr.push(0);
824
825 for bi in 0..mb {
826 let row_start = bi * block_rows;
827 let row_end = (row_start + block_rows).min(nrows);
828
829 for bj in 0..nb {
830 let col_start = bj * block_cols;
831 let col_end = (col_start + block_cols).min(ncols);
832
833 let mut has_nonzero = false;
835 for i in row_start..row_end {
836 for j in col_start..col_end {
837 if Scalar::abs(dense[(i, j)].clone()) > eps {
838 has_nonzero = true;
839 break;
840 }
841 }
842 if has_nonzero {
843 break;
844 }
845 }
846
847 if has_nonzero {
848 let mut block_data = vec![T::zero(); block_rows * block_cols];
850 for i in 0..block_rows {
851 let global_row = row_start + i;
852 if global_row >= nrows {
853 break;
854 }
855 for j in 0..block_cols {
856 let global_col = col_start + j;
857 if global_col >= ncols {
858 break;
859 }
860 block_data[i * block_cols + j] =
861 dense[(global_row, global_col)].clone();
862 }
863 }
864
865 indices.push(bj);
866 data.push(DenseBlock::new(block_rows, block_cols, block_data));
867 }
868 }
869
870 indptr.push(data.len());
871 }
872
873 Self {
874 nrows,
875 ncols,
876 block_rows,
877 block_cols,
878 mb,
879 nb,
880 indptr,
881 indices,
882 data,
883 }
884 }
885
886 pub fn from_csr(csr: &crate::csr::CsrMatrix<T>, block_rows: usize, block_cols: usize) -> Self
894 where
895 T: Field,
896 {
897 let (nrows, ncols) = csr.shape();
898 let eps = <T as Scalar>::epsilon();
899
900 let mb = nrows.div_ceil(block_rows);
901 let nb = ncols.div_ceil(block_cols);
902
903 let mut indptr = Vec::with_capacity(mb + 1);
904 let mut indices = Vec::new();
905 let mut data = Vec::new();
906
907 indptr.push(0);
908
909 for bi in 0..mb {
910 let row_start = bi * block_rows;
911 let row_end = (row_start + block_rows).min(nrows);
912
913 let mut block_cols_present = std::collections::HashSet::new();
915 for row in row_start..row_end {
916 for (col, _) in csr.row_iter(row) {
917 let bj = col / block_cols;
918 block_cols_present.insert(bj);
919 }
920 }
921
922 let mut sorted_bjs: Vec<_> = block_cols_present.into_iter().collect();
924 sorted_bjs.sort();
925
926 for bj in sorted_bjs {
927 let col_start = bj * block_cols;
928
929 let mut block_data = vec![T::zero(); block_rows * block_cols];
931 let mut has_nonzero = false;
932
933 for row in row_start..row_end {
934 let local_i = row - row_start;
935 for (col, val) in csr.row_iter(row) {
936 if col >= col_start && col < col_start + block_cols {
937 let local_j = col - col_start;
938 if Scalar::abs(val.clone()) > eps {
939 block_data[local_i * block_cols + local_j] = val.clone();
940 has_nonzero = true;
941 }
942 }
943 }
944 }
945
946 if has_nonzero {
947 indices.push(bj);
948 data.push(DenseBlock::new(block_rows, block_cols, block_data));
949 }
950 }
951
952 indptr.push(data.len());
953 }
954
955 Self {
956 nrows,
957 ncols,
958 block_rows,
959 block_cols,
960 mb,
961 nb,
962 indptr,
963 indices,
964 data,
965 }
966 }
967
968 pub fn scale(&mut self, alpha: T) {
970 for block in &mut self.data {
971 block.scale(alpha.clone());
972 }
973 }
974
975 pub fn scaled(&self, alpha: T) -> Self {
977 let mut result = self.clone();
978 result.scale(alpha);
979 result
980 }
981
982 pub fn transpose(&self) -> Self
984 where
985 T: Field,
986 {
987 let csr = self.to_csr();
989 let csr_t = csr.transpose();
990 Self::from_csr(&csr_t, self.block_cols, self.block_rows)
991 }
992}
993
#[cfg(test)]
mod tests {
    use super::*;

    /// 4×4 fixture with two stored 2×2 blocks on the block diagonal:
    ///
    /// ```text
    /// [1 2 . .]
    /// [3 4 . .]
    /// [. . 5 6]
    /// [. . 7 8]
    /// ```
    fn two_block_diagonal() -> BsrMatrix<f64> {
        let upper_left = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let lower_right = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        BsrMatrix::new(
            4,
            4,
            2,
            2,
            vec![0, 1, 2],
            vec![0, 1],
            vec![upper_left, lower_right],
        )
        .unwrap()
    }

    /// Asserts that two floats agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_bsr_new() {
        let bsr = two_block_diagonal();

        assert_eq!(bsr.nrows(), 4);
        assert_eq!(bsr.ncols(), 4);
        assert_eq!(bsr.nblocks(), 2);
        assert_eq!(bsr.block_shape(), (2, 2));
    }

    #[test]
    fn test_bsr_get() {
        let bsr = two_block_diagonal();

        // Elements of the two stored blocks.
        let stored = [
            (0, 0, 1.0),
            (0, 1, 2.0),
            (1, 0, 3.0),
            (1, 1, 4.0),
            (2, 2, 5.0),
            (2, 3, 6.0),
            (3, 2, 7.0),
            (3, 3, 8.0),
        ];
        for (row, col, expected) in stored {
            assert_eq!(bsr.get(row, col), Some(expected));
        }

        // Positions in blocks that are not stored.
        assert_eq!(bsr.get(0, 2), None);
        assert_eq!(bsr.get(2, 0), None);
    }

    #[test]
    fn test_bsr_matvec() {
        let bsr = two_block_diagonal();

        let x = vec![1.0, 1.0, 1.0, 1.0];
        let y = bsr.mul_vec(&x);

        // With x all ones, each result is the sum of the corresponding matrix row.
        assert_close(y[0], 3.0);
        assert_close(y[1], 7.0);
        assert_close(y[2], 11.0);
        assert_close(y[3], 15.0);
    }

    #[test]
    fn test_bsr_eye() {
        let bsr: BsrMatrix<f64> = BsrMatrix::eye(4, 2);

        assert_eq!(bsr.nrows(), 4);
        assert_eq!(bsr.ncols(), 4);
        assert_eq!(bsr.nblocks(), 2);

        for i in 0..4 {
            assert_eq!(bsr.get(i, i), Some(1.0));
        }
    }

    #[test]
    fn test_bsr_to_dense() {
        let dense = two_block_diagonal().to_dense();

        assert_close(dense[(0, 0)], 1.0);
        assert_close(dense[(0, 1)], 2.0);
        assert_close(dense[(1, 0)], 3.0);
        assert_close(dense[(1, 1)], 4.0);
        assert_close(dense[(0, 2)], 0.0);
        assert_close(dense[(2, 2)], 5.0);
    }

    #[test]
    fn test_bsr_to_csr() {
        let csr = two_block_diagonal().to_csr();

        assert_eq!(csr.nrows(), 4);
        assert_eq!(csr.ncols(), 4);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_bsr_from_dense() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[
            &[1.0f64, 2.0, 0.0, 0.0],
            &[3.0, 4.0, 0.0, 0.0],
            &[0.0, 0.0, 5.0, 6.0],
            &[0.0, 0.0, 7.0, 8.0],
        ]);

        let bsr = BsrMatrix::from_dense(&dense.as_ref(), 2, 2);

        // The two all-zero off-diagonal blocks must not be stored.
        assert_eq!(bsr.nblocks(), 2);
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(2, 2), Some(5.0));
    }

    #[test]
    fn test_bsr_from_csr() {
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let col_indices = vec![0, 1, 0, 1, 2, 3, 2, 3];
        let row_ptrs = vec![0, 2, 4, 6, 8];

        let csr = crate::csr::CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap();
        let bsr = BsrMatrix::from_csr(&csr, 2, 2);

        assert_eq!(bsr.nblocks(), 2);
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(3, 3), Some(8.0));
    }

    #[test]
    fn test_bsr_scale() {
        let block = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let mut bsr = BsrMatrix::new(2, 2, 2, 2, vec![0, 1], vec![0], vec![block]).unwrap();

        bsr.scale(2.0);

        assert_eq!(bsr.get(0, 0), Some(2.0));
        assert_eq!(bsr.get(1, 1), Some(8.0));
    }

    #[test]
    fn test_bsr_transpose() {
        let block = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let bsr = BsrMatrix::new(2, 2, 2, 2, vec![0, 1], vec![0], vec![block]).unwrap();

        let bsr_t = bsr.transpose();
        let dense = bsr.to_dense();
        let dense_t = bsr_t.to_dense();

        // Element (i, j) of the original must equal element (j, i) of the transpose.
        for i in 0..2 {
            for j in 0..2 {
                assert_close(dense[(i, j)], dense_t[(j, i)]);
            }
        }
    }

    #[test]
    fn test_bsr_zeros() {
        let bsr: BsrMatrix<f64> = BsrMatrix::zeros(6, 8, 2, 4);

        assert_eq!(bsr.nrows(), 6);
        assert_eq!(bsr.ncols(), 8);
        assert_eq!(bsr.nblocks(), 0);
        assert_eq!(bsr.nblock_rows(), 3);
        assert_eq!(bsr.nblock_cols(), 2);
    }

    #[test]
    fn test_bsr_non_square_blocks() {
        // Two 3×2 blocks on the block diagonal of a 6×4 matrix.
        let first = DenseBlock::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let second = DenseBlock::new(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);

        let bsr =
            BsrMatrix::new(6, 4, 3, 2, vec![0, 1, 2], vec![0, 1], vec![first, second]).unwrap();

        assert_eq!(bsr.nrows(), 6);
        assert_eq!(bsr.ncols(), 4);
        assert_eq!(bsr.nblock_rows(), 2);
        assert_eq!(bsr.nblock_cols(), 2);

        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(2, 1), Some(6.0));
        assert_eq!(bsr.get(3, 2), Some(7.0));
    }

    #[test]
    fn test_bsr_block_iter() {
        let bsr = two_block_diagonal();

        let blocks: Vec<_> = bsr.block_iter().map(|(bi, bj, _)| (bi, bj)).collect();
        assert_eq!(blocks, vec![(0, 0), (1, 1)]);
    }

    #[test]
    fn test_dense_block() {
        let mut block = DenseBlock::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        assert_eq!(block.shape(), (2, 3));
        assert_eq!(*block.get(0, 0), 1.0);
        assert_eq!(*block.get(0, 2), 3.0);
        assert_eq!(*block.get(1, 1), 5.0);

        *block.get_mut(1, 1) = 10.0;
        assert_eq!(*block.get(1, 1), 10.0);
    }
}