1use crate::bsr::{BsrMatrix, DenseBlock};
26use oxiblas_core::scalar::{Field, Scalar};
27
/// Errors reported when building a [`BscMatrix`] from raw block-compressed-column arrays.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BscError {
    /// One of the block dimensions is zero.
    InvalidBlockSize {
        /// Requested number of rows per block.
        block_rows: usize,
        /// Requested number of columns per block.
        block_cols: usize,
    },
    /// The matrix dimensions do not fit the requested block shape.
    ///
    /// NOTE(review): no constructor in this file currently returns this variant.
    IncompatibleDimensions {
        /// Number of matrix rows.
        nrows: usize,
        /// Number of matrix columns.
        ncols: usize,
        /// Rows per block.
        block_rows: usize,
        /// Columns per block.
        block_cols: usize,
    },
    /// `indptr` is malformed.
    ///
    /// `BscMatrix::new` returns this for a wrong `indptr` length. It also returns it when the
    /// final `indptr` entry disagrees with the number of stored blocks.
    InvalidIndptr {
        /// Value that was required.
        expected: usize,
        /// Value that was found.
        actual: usize,
    },
    /// `indices` and `data` have different lengths.
    DataIndicesMismatch {
        /// Number of stored blocks (`data.len()`).
        num_blocks: usize,
        /// Number of block-row indices (`indices.len()`).
        num_indices: usize,
    },
    /// A block-row index is not below the number of block rows.
    InvalidBlockIndex {
        /// The offending block-row index.
        index: usize,
        /// Number of block rows in the matrix.
        mb_rows: usize,
    },
    /// `indptr` decreases somewhere.
    InvalidIndptrOrder,
    /// A stored block does not hold `block_rows * block_cols` elements.
    InvalidBlockData {
        /// Position of the offending block in `data`.
        block_idx: usize,
        /// Required element count.
        expected: usize,
        /// Actual element count.
        actual: usize,
    },
}

impl core::fmt::Display for BscError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every field is `usize`, so matching on `*self` copies the values out.
        match *self {
            Self::InvalidBlockSize {
                block_rows: r,
                block_cols: c,
            } => write!(f, "Invalid block size: {r}×{c}"),
            Self::IncompatibleDimensions {
                nrows: m,
                ncols: n,
                block_rows: r,
                block_cols: c,
            } => write!(f, "Matrix {m}×{n} incompatible with {r}×{c} blocks"),
            Self::InvalidIndptr {
                expected: want,
                actual: got,
            } => write!(f, "Invalid indptr length: expected {want}, got {got}"),
            Self::DataIndicesMismatch {
                num_blocks: blocks,
                num_indices: idx,
            } => write!(f, "Mismatch: {blocks} blocks but {idx} indices"),
            Self::InvalidBlockIndex {
                index: bad,
                mb_rows: limit,
            } => write!(f, "Block row index {bad} out of bounds (mb_rows={limit})"),
            Self::InvalidIndptrOrder => f.write_str("Indptr must be monotonically increasing"),
            Self::InvalidBlockData {
                block_idx: which,
                expected: want,
                actual: got,
            } => write!(f, "Block {which}: expected {want} elements, got {got}"),
        }
    }
}

impl std::error::Error for BscError {}
139
/// Block Sparse Column (BSC) matrix.
///
/// The `nrows × ncols` matrix is tiled into `block_rows × block_cols` blocks. Only the
/// stored (non-zero) blocks are kept, grouped by block column in a CSC-like layout over
/// blocks. The matrix dimensions need not be multiples of the block shape; the last block
/// row/column then extends past the matrix, and those padding cells are ignored by
/// element-level accessors.
#[derive(Debug, Clone)]
pub struct BscMatrix<T: Scalar> {
    /// Number of rows of the logical matrix.
    nrows: usize,
    /// Number of columns of the logical matrix.
    ncols: usize,
    /// Rows per stored block.
    block_rows: usize,
    /// Columns per stored block.
    block_cols: usize,
    /// Number of block rows, `nrows.div_ceil(block_rows)`.
    mb: usize,
    /// Number of block columns, `ncols.div_ceil(block_cols)`.
    nb: usize,
    /// Block-column pointers (`nb + 1` entries): the blocks of block column `bj` are at
    /// positions `indptr[bj]..indptr[bj + 1]` of `indices`/`data`.
    indptr: Vec<usize>,
    /// Block-row index of each stored block (parallel to `data`).
    indices: Vec<usize>,
    /// Stored dense blocks; element `(i, j)` of a block is read with `block.get(i, j)`.
    data: Vec<DenseBlock<T>>,
}
197
198impl<T: Scalar + Clone> BscMatrix<T> {
199 pub fn new(
215 nrows: usize,
216 ncols: usize,
217 block_rows: usize,
218 block_cols: usize,
219 indptr: Vec<usize>,
220 indices: Vec<usize>,
221 data: Vec<DenseBlock<T>>,
222 ) -> Result<Self, BscError> {
223 if block_rows == 0 || block_cols == 0 {
225 return Err(BscError::InvalidBlockSize {
226 block_rows,
227 block_cols,
228 });
229 }
230
231 let mb = nrows.div_ceil(block_rows);
233 let nb = ncols.div_ceil(block_cols);
234
235 if indptr.len() != nb + 1 {
237 return Err(BscError::InvalidIndptr {
238 expected: nb + 1,
239 actual: indptr.len(),
240 });
241 }
242
243 for i in 1..indptr.len() {
245 if indptr[i] < indptr[i - 1] {
246 return Err(BscError::InvalidIndptrOrder);
247 }
248 }
249
250 let nnz_blocks = data.len();
252 if indices.len() != nnz_blocks {
253 return Err(BscError::DataIndicesMismatch {
254 num_blocks: nnz_blocks,
255 num_indices: indices.len(),
256 });
257 }
258
259 if indptr[nb] != nnz_blocks {
261 return Err(BscError::InvalidIndptr {
262 expected: nnz_blocks,
263 actual: indptr[nb],
264 });
265 }
266
267 for &idx in &indices {
269 if idx >= mb {
270 return Err(BscError::InvalidBlockIndex {
271 index: idx,
272 mb_rows: mb,
273 });
274 }
275 }
276
277 let block_size = block_rows * block_cols;
279 for (i, block) in data.iter().enumerate() {
280 if block.data().len() != block_size {
281 return Err(BscError::InvalidBlockData {
282 block_idx: i,
283 expected: block_size,
284 actual: block.data().len(),
285 });
286 }
287 }
288
289 Ok(Self {
290 nrows,
291 ncols,
292 block_rows,
293 block_cols,
294 mb,
295 nb,
296 indptr,
297 indices,
298 data,
299 })
300 }
301
302 #[inline]
308 pub unsafe fn new_unchecked(
309 nrows: usize,
310 ncols: usize,
311 block_rows: usize,
312 block_cols: usize,
313 indptr: Vec<usize>,
314 indices: Vec<usize>,
315 data: Vec<DenseBlock<T>>,
316 ) -> Self {
317 let mb = nrows.div_ceil(block_rows);
318 let nb = ncols.div_ceil(block_cols);
319 Self {
320 nrows,
321 ncols,
322 block_rows,
323 block_cols,
324 mb,
325 nb,
326 indptr,
327 indices,
328 data,
329 }
330 }
331
332 pub fn zeros(nrows: usize, ncols: usize, block_rows: usize, block_cols: usize) -> Self {
334 let nb = ncols.div_ceil(block_cols);
335 Self {
336 nrows,
337 ncols,
338 block_rows,
339 block_cols,
340 mb: nrows.div_ceil(block_rows),
341 nb,
342 indptr: vec![0; nb + 1],
343 indices: Vec::new(),
344 data: Vec::new(),
345 }
346 }
347
348 pub fn eye(n: usize, block_size: usize) -> Self
350 where
351 T: Field,
352 {
353 let nb = n.div_ceil(block_size);
354 let mut indptr = Vec::with_capacity(nb + 1);
355 let mut indices = Vec::with_capacity(nb);
356 let mut data = Vec::with_capacity(nb);
357
358 indptr.push(0);
359
360 for bj in 0..nb {
361 indices.push(bj); let mut block_data = vec![T::zero(); block_size * block_size];
365 for i in 0..block_size {
366 let global_idx = bj * block_size + i;
367 if global_idx < n {
368 block_data[i * block_size + i] = T::one();
369 }
370 }
371 data.push(DenseBlock::new(block_size, block_size, block_data));
372
373 indptr.push(data.len());
374 }
375
376 Self {
377 nrows: n,
378 ncols: n,
379 block_rows: block_size,
380 block_cols: block_size,
381 mb: nb,
382 nb,
383 indptr,
384 indices,
385 data,
386 }
387 }
388
389 #[inline]
391 pub fn nrows(&self) -> usize {
392 self.nrows
393 }
394
395 #[inline]
397 pub fn ncols(&self) -> usize {
398 self.ncols
399 }
400
401 #[inline]
403 pub fn shape(&self) -> (usize, usize) {
404 (self.nrows, self.ncols)
405 }
406
407 #[inline]
409 pub fn block_shape(&self) -> (usize, usize) {
410 (self.block_rows, self.block_cols)
411 }
412
413 #[inline]
415 pub fn nblock_rows(&self) -> usize {
416 self.mb
417 }
418
419 #[inline]
421 pub fn nblock_cols(&self) -> usize {
422 self.nb
423 }
424
425 #[inline]
427 pub fn nblocks(&self) -> usize {
428 self.data.len()
429 }
430
431 pub fn nnz(&self) -> usize
435 where
436 T: Field,
437 {
438 let eps = <T as Scalar>::epsilon();
439 let mut count = 0;
440
441 for block in &self.data {
442 for val in block.data() {
443 if Scalar::abs(val.clone()) > eps {
444 count += 1;
445 }
446 }
447 }
448
449 count
450 }
451
452 #[inline]
454 pub fn nstored(&self) -> usize {
455 self.data.len() * self.block_rows * self.block_cols
456 }
457
458 #[inline]
460 pub fn indptr(&self) -> &[usize] {
461 &self.indptr
462 }
463
464 #[inline]
466 pub fn indices(&self) -> &[usize] {
467 &self.indices
468 }
469
470 #[inline]
472 pub fn data(&self) -> &[DenseBlock<T>] {
473 &self.data
474 }
475
476 #[inline]
478 pub fn data_mut(&mut self) -> &mut [DenseBlock<T>] {
479 &mut self.data
480 }
481
482 pub fn get_block(&self, bi: usize, bj: usize) -> Option<&DenseBlock<T>> {
484 if bi >= self.mb || bj >= self.nb {
485 return None;
486 }
487
488 let start = self.indptr[bj];
489 let end = self.indptr[bj + 1];
490
491 for k in start..end {
492 if self.indices[k] == bi {
493 return Some(&self.data[k]);
494 }
495 }
496
497 None
498 }
499
500 pub fn get(&self, row: usize, col: usize) -> Option<T>
502 where
503 T: Field,
504 {
505 if row >= self.nrows || col >= self.ncols {
506 return None;
507 }
508
509 let bi = row / self.block_rows;
510 let bj = col / self.block_cols;
511 let local_i = row % self.block_rows;
512 let local_j = col % self.block_cols;
513
514 self.get_block(bi, bj)
515 .map(|block| block.get(local_i, local_j).clone())
516 }
517
518 pub fn get_or_zero(&self, row: usize, col: usize) -> T
520 where
521 T: Field,
522 {
523 self.get(row, col).unwrap_or_else(T::zero)
524 }
525
526 pub fn block_iter(&self) -> impl Iterator<Item = (usize, usize, &DenseBlock<T>)> + '_ {
528 (0..self.nb).flat_map(move |bj| {
529 let start = self.indptr[bj];
530 let end = self.indptr[bj + 1];
531
532 (start..end).map(move |k| (self.indices[k], bj, &self.data[k]))
533 })
534 }
535
536 pub fn col_block_iter(&self, bj: usize) -> impl Iterator<Item = (usize, &DenseBlock<T>)> + '_ {
538 let start = if bj < self.nb { self.indptr[bj] } else { 0 };
539 let end = if bj < self.nb { self.indptr[bj + 1] } else { 0 };
540
541 (start..end).map(move |k| (self.indices[k], &self.data[k]))
542 }
543
544 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, T)> + '_
546 where
547 T: Field,
548 {
549 let eps = <T as Scalar>::epsilon();
550 let br = self.block_rows;
551 let bc = self.block_cols;
552 let nrows = self.nrows;
553 let ncols = self.ncols;
554
555 self.block_iter().flat_map(move |(bi, bj, block)| {
556 let base_row = bi * br;
557 let base_col = bj * bc;
558
559 (0..br).flat_map(move |i| {
560 (0..bc).filter_map(move |j| {
561 let global_row = base_row + i;
562 let global_col = base_col + j;
563
564 if global_row < nrows && global_col < ncols {
565 let val = block.get(i, j).clone();
566 if Scalar::abs(val.clone()) > eps {
567 return Some((global_row, global_col, val));
568 }
569 }
570 None
571 })
572 })
573 })
574 }
575
576 pub fn matvec(&self, x: &[T], y: &mut [T])
578 where
579 T: Field,
580 {
581 assert_eq!(x.len(), self.ncols, "x length must equal ncols");
582 assert_eq!(y.len(), self.nrows, "y length must equal nrows");
583
584 for yi in y.iter_mut() {
586 *yi = T::zero();
587 }
588
589 for bj in 0..self.nb {
591 let start = self.indptr[bj];
592 let end = self.indptr[bj + 1];
593 let col_start = bj * self.block_cols;
594 let col_end = (col_start + self.block_cols).min(self.ncols);
595
596 for k in start..end {
597 let bi = self.indices[k];
598 let block = &self.data[k];
599 let row_start = bi * self.block_rows;
600 let row_end = (row_start + self.block_rows).min(self.nrows);
601
602 for (i, yi) in y[row_start..row_end].iter_mut().enumerate() {
604 for j in 0..(col_end - col_start) {
605 *yi = yi.clone() + block.get(i, j).clone() * x[col_start + j].clone();
606 }
607 }
608 }
609 }
610 }
611
612 pub fn matvec_transpose(&self, x: &[T], y: &mut [T])
614 where
615 T: Field,
616 {
617 assert_eq!(x.len(), self.nrows, "x length must equal nrows");
618 assert_eq!(y.len(), self.ncols, "y length must equal ncols");
619
620 for yi in y.iter_mut() {
622 *yi = T::zero();
623 }
624
625 for bj in 0..self.nb {
627 let start = self.indptr[bj];
628 let end = self.indptr[bj + 1];
629 let col_start = bj * self.block_cols;
630 let col_end = (col_start + self.block_cols).min(self.ncols);
631
632 for k in start..end {
633 let bi = self.indices[k];
634 let block = &self.data[k];
635 let row_start = bi * self.block_rows;
636 let row_end = (row_start + self.block_rows).min(self.nrows);
637
638 for j in 0..(col_end - col_start) {
640 for i in 0..(row_end - row_start) {
641 y[col_start + j] = y[col_start + j].clone()
642 + block.get(i, j).clone() * x[row_start + i].clone();
643 }
644 }
645 }
646 }
647 }
648
649 pub fn mul_vec(&self, x: &[T]) -> Vec<T>
651 where
652 T: Field,
653 {
654 let mut y = vec![T::zero(); self.nrows];
655 self.matvec(x, &mut y);
656 y
657 }
658
659 pub fn to_bsr(&self) -> BsrMatrix<T>
661 where
662 T: Field,
663 {
664 let mut block_entries: Vec<(usize, usize, DenseBlock<T>)> =
665 Vec::with_capacity(self.nblocks());
666
667 for bj in 0..self.nb {
669 let start = self.indptr[bj];
670 let end = self.indptr[bj + 1];
671 for k in start..end {
672 let bi = self.indices[k];
673 block_entries.push((bi, bj, self.data[k].clone()));
674 }
675 }
676
677 block_entries.sort_by_key(|(bi, bj, _)| (*bi, *bj));
679
680 let mut indptr = Vec::with_capacity(self.mb + 1);
682 let mut indices = Vec::with_capacity(block_entries.len());
683 let mut data = Vec::with_capacity(block_entries.len());
684
685 indptr.push(0);
686
687 let mut entry_idx = 0;
688 for bi in 0..self.mb {
689 while entry_idx < block_entries.len() && block_entries[entry_idx].0 == bi {
690 let (_, bj, block) = &block_entries[entry_idx];
691 indices.push(*bj);
692 data.push(block.clone());
693 entry_idx += 1;
694 }
695 indptr.push(data.len());
696 }
697
698 unsafe {
700 BsrMatrix::new_unchecked(
701 self.nrows,
702 self.ncols,
703 self.block_rows,
704 self.block_cols,
705 indptr,
706 indices,
707 data,
708 )
709 }
710 }
711
712 pub fn from_bsr(bsr: &BsrMatrix<T>) -> Self
714 where
715 T: Field,
716 {
717 let (nrows, ncols) = bsr.shape();
718 let (block_rows, block_cols) = bsr.block_shape();
719 let mb = bsr.nblock_rows();
720 let nb = bsr.nblock_cols();
721
722 let mut block_entries: Vec<(usize, usize, DenseBlock<T>)> =
724 Vec::with_capacity(bsr.nblocks());
725 for (bi, bj, block) in bsr.block_iter() {
726 block_entries.push((bi, bj, block.clone()));
727 }
728
729 block_entries.sort_by_key(|(bi, bj, _)| (*bj, *bi));
731
732 let mut indptr = Vec::with_capacity(nb + 1);
734 let mut indices = Vec::with_capacity(block_entries.len());
735 let mut data = Vec::with_capacity(block_entries.len());
736
737 indptr.push(0);
738
739 let mut entry_idx = 0;
740 for bj in 0..nb {
741 while entry_idx < block_entries.len() && block_entries[entry_idx].1 == bj {
742 let (bi, _, block) = &block_entries[entry_idx];
743 indices.push(*bi);
744 data.push(block.clone());
745 entry_idx += 1;
746 }
747 indptr.push(data.len());
748 }
749
750 Self {
751 nrows,
752 ncols,
753 block_rows,
754 block_cols,
755 mb,
756 nb,
757 indptr,
758 indices,
759 data,
760 }
761 }
762
763 pub fn to_csc(&self) -> crate::csc::CscMatrix<T>
765 where
766 T: Field,
767 {
768 let eps = <T as Scalar>::epsilon();
769
770 let mut col_ptrs = vec![0usize; self.ncols + 1];
771 let mut row_indices = Vec::new();
772 let mut values = Vec::new();
773
774 for col in 0..self.ncols {
775 let bj = col / self.block_cols;
776 let local_j = col % self.block_cols;
777
778 let block_start = self.indptr[bj];
779 let block_end = self.indptr[bj + 1];
780
781 let mut col_entries: Vec<(usize, T)> = Vec::new();
782
783 for k in block_start..block_end {
784 let bi = self.indices[k];
785 let block = &self.data[k];
786
787 for i in 0..self.block_rows {
788 let global_row = bi * self.block_rows + i;
789 if global_row < self.nrows {
790 let val = block.get(i, local_j).clone();
791 if Scalar::abs(val.clone()) > eps {
792 col_entries.push((global_row, val));
793 }
794 }
795 }
796 }
797
798 col_entries.sort_by_key(|(row, _)| *row);
800
801 for (row, val) in col_entries {
802 row_indices.push(row);
803 values.push(val);
804 }
805 col_ptrs[col + 1] = values.len();
806 }
807
808 unsafe {
810 crate::csc::CscMatrix::new_unchecked(
811 self.nrows,
812 self.ncols,
813 col_ptrs,
814 row_indices,
815 values,
816 )
817 }
818 }
819
820 pub fn from_csc(csc: &crate::csc::CscMatrix<T>, block_rows: usize, block_cols: usize) -> Self
822 where
823 T: Field,
824 {
825 let (nrows, ncols) = csc.shape();
826 let eps = <T as Scalar>::epsilon();
827
828 let mb = nrows.div_ceil(block_rows);
829 let nb = ncols.div_ceil(block_cols);
830
831 let mut indptr = Vec::with_capacity(nb + 1);
832 let mut indices = Vec::new();
833 let mut data = Vec::new();
834
835 indptr.push(0);
836
837 for bj in 0..nb {
838 let col_start = bj * block_cols;
839 let col_end = (col_start + block_cols).min(ncols);
840
841 let mut block_rows_present = std::collections::HashSet::new();
843 for col in col_start..col_end {
844 for (row, _) in csc.col_iter(col) {
845 let bi = row / block_rows;
846 block_rows_present.insert(bi);
847 }
848 }
849
850 let mut sorted_bis: Vec<_> = block_rows_present.into_iter().collect();
852 sorted_bis.sort_unstable();
853
854 for bi in sorted_bis {
855 let row_start = bi * block_rows;
856
857 let mut block_data = vec![T::zero(); block_rows * block_cols];
859 let mut has_nonzero = false;
860
861 for col in col_start..col_end {
862 let local_j = col - col_start;
863 for (row, val) in csc.col_iter(col) {
864 if row >= row_start && row < row_start + block_rows {
865 let local_i = row - row_start;
866 if Scalar::abs(val.clone()) > eps {
867 block_data[local_i * block_cols + local_j] = val.clone();
868 has_nonzero = true;
869 }
870 }
871 }
872 }
873
874 if has_nonzero {
875 indices.push(bi);
876 data.push(DenseBlock::new(block_rows, block_cols, block_data));
877 }
878 }
879
880 indptr.push(data.len());
881 }
882
883 Self {
884 nrows,
885 ncols,
886 block_rows,
887 block_cols,
888 mb,
889 nb,
890 indptr,
891 indices,
892 data,
893 }
894 }
895
896 pub fn to_dense(&self) -> oxiblas_matrix::Mat<T>
898 where
899 T: Field + bytemuck::Zeroable,
900 {
901 let mut dense = oxiblas_matrix::Mat::zeros(self.nrows, self.ncols);
902
903 for bj in 0..self.nb {
904 let start = self.indptr[bj];
905 let end = self.indptr[bj + 1];
906 let col_start = bj * self.block_cols;
907
908 for k in start..end {
909 let bi = self.indices[k];
910 let block = &self.data[k];
911 let row_start = bi * self.block_rows;
912
913 for i in 0..self.block_rows {
914 let global_row = row_start + i;
915 if global_row >= self.nrows {
916 break;
917 }
918
919 for j in 0..self.block_cols {
920 let global_col = col_start + j;
921 if global_col >= self.ncols {
922 break;
923 }
924
925 dense[(global_row, global_col)] = block.get(i, j).clone();
926 }
927 }
928 }
929 }
930
931 dense
932 }
933
934 pub fn from_dense(
942 dense: &oxiblas_matrix::MatRef<'_, T>,
943 block_rows: usize,
944 block_cols: usize,
945 ) -> Self
946 where
947 T: Field,
948 {
949 let (nrows, ncols) = dense.shape();
950 let eps = <T as Scalar>::epsilon();
951
952 let mb = nrows.div_ceil(block_rows);
953 let nb = ncols.div_ceil(block_cols);
954
955 let mut indptr = Vec::with_capacity(nb + 1);
956 let mut indices = Vec::new();
957 let mut data = Vec::new();
958
959 indptr.push(0);
960
961 for bj in 0..nb {
962 let col_start = bj * block_cols;
963 let col_end = (col_start + block_cols).min(ncols);
964
965 for bi in 0..mb {
966 let row_start = bi * block_rows;
967 let row_end = (row_start + block_rows).min(nrows);
968
969 let mut has_nonzero = false;
971 'outer: for i in row_start..row_end {
972 for j in col_start..col_end {
973 if Scalar::abs(dense[(i, j)].clone()) > eps {
974 has_nonzero = true;
975 break 'outer;
976 }
977 }
978 }
979
980 if has_nonzero {
981 let mut block_data = vec![T::zero(); block_rows * block_cols];
983 for i in 0..block_rows {
984 let global_row = row_start + i;
985 if global_row >= nrows {
986 break;
987 }
988 for j in 0..block_cols {
989 let global_col = col_start + j;
990 if global_col >= ncols {
991 break;
992 }
993 block_data[i * block_cols + j] =
994 dense[(global_row, global_col)].clone();
995 }
996 }
997
998 indices.push(bi);
999 data.push(DenseBlock::new(block_rows, block_cols, block_data));
1000 }
1001 }
1002
1003 indptr.push(data.len());
1004 }
1005
1006 Self {
1007 nrows,
1008 ncols,
1009 block_rows,
1010 block_cols,
1011 mb,
1012 nb,
1013 indptr,
1014 indices,
1015 data,
1016 }
1017 }
1018
1019 pub fn scale(&mut self, alpha: T) {
1021 for block in &mut self.data {
1022 block.scale(alpha.clone());
1023 }
1024 }
1025
1026 pub fn scaled(&self, alpha: T) -> Self {
1028 let mut result = self.clone();
1029 result.scale(alpha);
1030 result
1031 }
1032
1033 pub fn transpose(&self) -> Self
1037 where
1038 T: Field,
1039 {
1040 let mut block_entries: Vec<(usize, usize, DenseBlock<T>)> =
1042 Vec::with_capacity(self.nblocks());
1043
1044 for bj in 0..self.nb {
1045 let start = self.indptr[bj];
1046 let end = self.indptr[bj + 1];
1047 for k in start..end {
1048 let bi = self.indices[k];
1049 let block = &self.data[k];
1050
1051 let mut transposed_data = vec![T::zero(); self.block_rows * self.block_cols];
1053 for i in 0..self.block_rows {
1054 for j in 0..self.block_cols {
1055 transposed_data[j * self.block_rows + i] = block.get(i, j).clone();
1056 }
1057 }
1058 let transposed_block =
1059 DenseBlock::new(self.block_cols, self.block_rows, transposed_data);
1060
1061 block_entries.push((bj, bi, transposed_block));
1063 }
1064 }
1065
1066 block_entries.sort_by_key(|(bi, bj, _)| (*bj, *bi));
1068
1069 let new_nb = self.mb;
1074 let new_mb = self.nb;
1075
1076 let mut indptr = Vec::with_capacity(new_nb + 1);
1077 let mut indices = Vec::with_capacity(block_entries.len());
1078 let mut data = Vec::with_capacity(block_entries.len());
1079
1080 indptr.push(0);
1081
1082 let mut entry_idx = 0;
1083 for bj in 0..new_nb {
1084 while entry_idx < block_entries.len() && block_entries[entry_idx].1 == bj {
1085 let (bi, _, block) = &block_entries[entry_idx];
1086 indices.push(*bi);
1087 data.push(block.clone());
1088 entry_idx += 1;
1089 }
1090 indptr.push(data.len());
1091 }
1092
1093 Self {
1094 nrows: self.ncols,
1095 ncols: self.nrows,
1096 block_rows: self.block_cols,
1097 block_cols: self.block_rows,
1098 mb: new_mb,
1099 nb: new_nb,
1100 indptr,
1101 indices,
1102 data,
1103 }
1104 }
1105}
1106
#[cfg(test)]
mod tests {
    use super::*;

    // Most tests use the same 4×4 fixture with two 2×2 diagonal blocks:
    // block column 0 stores block row 0, and block column 1 stores block row 1.
    // Block data is row-major, so [1, 2, 3, 4] means [[1, 2], [3, 4]].

    #[test]
    fn test_bsc_new() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        assert_eq!(bsc.nrows(), 4);
        assert_eq!(bsc.ncols(), 4);
        assert_eq!(bsc.nblocks(), 2);
        assert_eq!(bsc.block_shape(), (2, 2));
    }

    #[test]
    fn test_bsc_get() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        // Top-left block (block row 0, block column 0).
        assert_eq!(bsc.get(0, 0), Some(1.0));
        assert_eq!(bsc.get(0, 1), Some(2.0));
        assert_eq!(bsc.get(1, 0), Some(3.0));
        assert_eq!(bsc.get(1, 1), Some(4.0));

        // Bottom-right block (block row 1, block column 1).
        assert_eq!(bsc.get(2, 2), Some(5.0));
        assert_eq!(bsc.get(2, 3), Some(6.0));
        assert_eq!(bsc.get(3, 2), Some(7.0));
        assert_eq!(bsc.get(3, 3), Some(8.0));

        // Positions inside unstored off-diagonal blocks report None, not zero.
        assert_eq!(bsc.get(0, 2), None);
        assert_eq!(bsc.get(2, 0), None);
    }

    #[test]
    fn test_bsc_matvec() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        // A * ones gives the row sums: [1+2, 3+4, 5+6, 7+8].
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let y = bsc.mul_vec(&x);

        assert!((y[0] - 3.0).abs() < 1e-10);
        assert!((y[1] - 7.0).abs() < 1e-10);
        assert!((y[2] - 11.0).abs() < 1e-10);
        assert!((y[3] - 15.0).abs() < 1e-10);
    }

    #[test]
    fn test_bsc_matvec_transpose() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        // Aᵀ * ones gives the column sums: [1+3, 2+4, 5+7, 6+8].
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let mut y = vec![0.0; 4];
        bsc.matvec_transpose(&x, &mut y);

        assert!((y[0] - 4.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 12.0).abs() < 1e-10);
        assert!((y[3] - 14.0).abs() < 1e-10);
    }

    #[test]
    fn test_bsc_eye() {
        // 4×4 identity with 2×2 blocks stores one block per diagonal block position.
        let bsc: BscMatrix<f64> = BscMatrix::eye(4, 2);

        assert_eq!(bsc.nrows(), 4);
        assert_eq!(bsc.ncols(), 4);
        assert_eq!(bsc.nblocks(), 2);

        for i in 0..4 {
            assert_eq!(bsc.get(i, i), Some(1.0));
        }
    }

    #[test]
    fn test_bsc_to_dense() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        let dense = bsc.to_dense();

        // Stored blocks are copied in; the unstored block area stays zero.
        assert!((dense[(0, 0)] - 1.0).abs() < 1e-10);
        assert!((dense[(0, 1)] - 2.0).abs() < 1e-10);
        assert!((dense[(1, 0)] - 3.0).abs() < 1e-10);
        assert!((dense[(1, 1)] - 4.0).abs() < 1e-10);
        assert!((dense[(0, 2)] - 0.0).abs() < 1e-10);
        assert!((dense[(2, 2)] - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_bsc_to_bsr_roundtrip() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        // BSC -> BSR -> BSC must reproduce every element (unstored positions compare as 0).
        let bsr = bsc.to_bsr();
        let bsc2 = BscMatrix::from_bsr(&bsr);

        for i in 0..4 {
            for j in 0..4 {
                let v1 = bsc.get_or_zero(i, j);
                let v2 = bsc2.get_or_zero(i, j);
                assert!((v1 - v2).abs() < 1e-10, "Mismatch at ({i}, {j})");
            }
        }
    }

    #[test]
    fn test_bsc_from_dense() {
        use oxiblas_matrix::Mat;

        let dense = Mat::from_rows(&[
            &[1.0f64, 2.0, 0.0, 0.0],
            &[3.0, 4.0, 0.0, 0.0],
            &[0.0, 0.0, 5.0, 6.0],
            &[0.0, 0.0, 7.0, 8.0],
        ]);

        // The two all-zero off-diagonal 2×2 blocks must not be stored.
        let bsc = BscMatrix::from_dense(&dense.as_ref(), 2, 2);

        assert_eq!(bsc.nblocks(), 2);
        assert_eq!(bsc.get(0, 0), Some(1.0));
        assert_eq!(bsc.get(2, 2), Some(5.0));
    }

    #[test]
    fn test_bsc_scale() {
        let block = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let mut bsc = BscMatrix::new(2, 2, 2, 2, vec![0, 1], vec![0], vec![block]).unwrap();

        bsc.scale(2.0);

        assert_eq!(bsc.get(0, 0), Some(2.0));
        assert_eq!(bsc.get(1, 1), Some(8.0));
    }

    #[test]
    fn test_bsc_transpose() {
        // Single non-symmetric 2×2 block, so a missing in-block transpose would be caught.
        let block = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let bsc = BscMatrix::new(2, 2, 2, 2, vec![0, 1], vec![0], vec![block]).unwrap();

        let bsc_t = bsc.transpose();
        let dense = bsc.to_dense();
        let dense_t = bsc_t.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert!((dense[(i, j)] - dense_t[(j, i)]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_bsc_zeros() {
        // 6×8 with 2×4 blocks: 3 block rows, 2 block columns, nothing stored.
        let bsc: BscMatrix<f64> = BscMatrix::zeros(6, 8, 2, 4);

        assert_eq!(bsc.nrows(), 6);
        assert_eq!(bsc.ncols(), 8);
        assert_eq!(bsc.nblocks(), 0);
        assert_eq!(bsc.nblock_rows(), 3);
        assert_eq!(bsc.nblock_cols(), 2);
    }

    #[test]
    fn test_bsc_non_square_blocks() {
        // 2×3 blocks stored row-major: [1..6] is [[1, 2, 3], [4, 5, 6]].
        let block1 = DenseBlock::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let block2 = DenseBlock::new(2, 3, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);

        let bsc =
            BscMatrix::new(4, 6, 2, 3, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        assert_eq!(bsc.nrows(), 4);
        assert_eq!(bsc.ncols(), 6);
        assert_eq!(bsc.nblock_rows(), 2);
        assert_eq!(bsc.nblock_cols(), 2);

        // (1, 2) is element [1][2] of block (0, 0); (2, 3) is the top-left of block (1, 1).
        assert_eq!(bsc.get(0, 0), Some(1.0));
        assert_eq!(bsc.get(1, 2), Some(6.0));
        assert_eq!(bsc.get(2, 3), Some(7.0));
    }

    #[test]
    fn test_bsc_block_iter() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        // (block_row, block_col) pairs in storage order.
        let blocks: Vec<_> = bsc.block_iter().map(|(bi, bj, _)| (bi, bj)).collect();
        assert_eq!(blocks, vec![(0, 0), (1, 1)]);
    }

    #[test]
    fn test_bsc_col_block_iter() {
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsc =
            BscMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        // Block column 0 stores only block row 0.
        let col0_blocks: Vec<_> = bsc.col_block_iter(0).map(|(bi, _)| bi).collect();
        assert_eq!(col0_blocks, vec![0]);

        // Block column 1 stores only block row 1.
        let col1_blocks: Vec<_> = bsc.col_block_iter(1).map(|(bi, _)| bi).collect();
        assert_eq!(col1_blocks, vec![1]);
    }
}