1use crate::bsc::BscMatrix;
19use crate::bsr::BsrMatrix;
20use crate::coo::CooMatrix;
21use crate::csc::CscMatrix;
22use crate::csr::CsrMatrix;
23use crate::dia::DiaMatrix;
24use crate::ell::EllMatrix;
25use crate::hyb::{HybMatrix, HybWidthStrategy};
26use crate::sell::{SellMatrix, SliceSize};
27use oxiblas_core::scalar::{Field, Real, Scalar};
28
29pub fn csr_to_csc<T: Scalar + Clone>(csr: &CsrMatrix<T>) -> CscMatrix<T> {
34 let nrows = csr.nrows();
35 let ncols = csr.ncols();
36 let nnz = csr.nnz();
37
38 if nnz == 0 {
39 return CscMatrix::zeros(nrows, ncols);
40 }
41
42 let mut col_counts = vec![0usize; ncols];
44 for &col in csr.col_indices() {
45 col_counts[col] += 1;
46 }
47
48 let mut col_ptrs = vec![0usize; ncols + 1];
50 for i in 0..ncols {
51 col_ptrs[i + 1] = col_ptrs[i] + col_counts[i];
52 }
53
54 let mut row_indices = vec![0usize; nnz];
56 let mut values = vec![T::zero(); nnz];
57 let mut write_pos = col_ptrs.clone();
58
59 for row in 0..nrows {
60 let start = csr.row_ptrs()[row];
61 let end = csr.row_ptrs()[row + 1];
62
63 for i in start..end {
64 let col = csr.col_indices()[i];
65 let pos = write_pos[col];
66
67 row_indices[pos] = row;
68 values[pos] = csr.values()[i].clone();
69
70 write_pos[col] += 1;
71 }
72 }
73
74 unsafe { CscMatrix::new_unchecked(nrows, ncols, col_ptrs, row_indices, values) }
76}
77
78pub fn csc_to_csr<T: Scalar + Clone>(csc: &CscMatrix<T>) -> CsrMatrix<T> {
83 let nrows = csc.nrows();
84 let ncols = csc.ncols();
85 let nnz = csc.nnz();
86
87 if nnz == 0 {
88 return CsrMatrix::zeros(nrows, ncols);
89 }
90
91 let mut row_counts = vec![0usize; nrows];
93 for &row in csc.row_indices() {
94 row_counts[row] += 1;
95 }
96
97 let mut row_ptrs = vec![0usize; nrows + 1];
99 for i in 0..nrows {
100 row_ptrs[i + 1] = row_ptrs[i] + row_counts[i];
101 }
102
103 let mut col_indices = vec![0usize; nnz];
105 let mut values = vec![T::zero(); nnz];
106 let mut write_pos = row_ptrs.clone();
107
108 for col in 0..ncols {
109 let start = csc.col_ptrs()[col];
110 let end = csc.col_ptrs()[col + 1];
111
112 for i in start..end {
113 let row = csc.row_indices()[i];
114 let pos = write_pos[row];
115
116 col_indices[pos] = col;
117 values[pos] = csc.values()[i].clone();
118
119 write_pos[row] += 1;
120 }
121 }
122
123 unsafe { CsrMatrix::new_unchecked(nrows, ncols, row_ptrs, col_indices, values) }
125}
126
127pub fn coo_to_csr<T: Scalar<Real = T> + Clone + Field + Real>(coo: &CooMatrix<T>) -> CsrMatrix<T> {
132 let nrows = coo.nrows();
133 let ncols = coo.ncols();
134
135 if coo.is_empty() {
136 return CsrMatrix::zeros(nrows, ncols);
137 }
138
139 let mut indices: Vec<usize> = (0..coo.len()).collect();
141 indices.sort_by_key(|&i| (coo.row_indices()[i], coo.col_indices()[i]));
142
143 let mut row_ptrs = Vec::with_capacity(nrows + 1);
145 let mut col_indices = Vec::with_capacity(coo.len());
146 let mut values: Vec<T> = Vec::with_capacity(coo.len());
147
148 row_ptrs.push(0);
149 let mut current_row = 0;
150
151 for &idx in &indices {
152 let row = coo.row_indices()[idx];
153 let col = coo.col_indices()[idx];
154 let val = coo.values()[idx].clone();
155
156 while current_row < row {
158 row_ptrs.push(values.len());
159 current_row += 1;
160 }
161
162 if !values.is_empty() && col_indices.last() == Some(&col) && current_row == row {
164 let last = values.len() - 1;
166 values[last] = values[last].clone() + val;
167 } else {
168 if !values.is_empty() {
170 let last = values.len() - 1;
171 if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
172 values.pop();
173 col_indices.pop();
174 }
175 }
176 if Scalar::abs(val.clone()) > <T as Scalar>::epsilon() {
178 col_indices.push(col);
179 values.push(val);
180 }
181 }
182 }
183
184 if !values.is_empty() {
186 let last = values.len() - 1;
187 if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
188 values.pop();
189 col_indices.pop();
190 }
191 }
192
193 while current_row < nrows {
195 row_ptrs.push(values.len());
196 current_row += 1;
197 }
198 row_ptrs.push(values.len());
199
200 unsafe { CsrMatrix::new_unchecked(nrows, ncols, row_ptrs, col_indices, values) }
202}
203
204pub fn coo_to_csc<T: Scalar<Real = T> + Clone + Field + Real>(coo: &CooMatrix<T>) -> CscMatrix<T> {
209 let nrows = coo.nrows();
210 let ncols = coo.ncols();
211
212 if coo.is_empty() {
213 return CscMatrix::zeros(nrows, ncols);
214 }
215
216 let mut indices: Vec<usize> = (0..coo.len()).collect();
218 indices.sort_by_key(|&i| (coo.col_indices()[i], coo.row_indices()[i]));
219
220 let mut col_ptrs = Vec::with_capacity(ncols + 1);
222 let mut row_indices = Vec::with_capacity(coo.len());
223 let mut values: Vec<T> = Vec::with_capacity(coo.len());
224
225 col_ptrs.push(0);
226 let mut current_col = 0;
227
228 for &idx in &indices {
229 let row = coo.row_indices()[idx];
230 let col = coo.col_indices()[idx];
231 let val = coo.values()[idx].clone();
232
233 while current_col < col {
235 col_ptrs.push(values.len());
236 current_col += 1;
237 }
238
239 if !values.is_empty() && row_indices.last() == Some(&row) && current_col == col {
241 let last = values.len() - 1;
243 values[last] = values[last].clone() + val;
244 } else {
245 if !values.is_empty() {
247 let last = values.len() - 1;
248 if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
249 values.pop();
250 row_indices.pop();
251 }
252 }
253 if Scalar::abs(val.clone()) > <T as Scalar>::epsilon() {
255 row_indices.push(row);
256 values.push(val);
257 }
258 }
259 }
260
261 if !values.is_empty() {
263 let last = values.len() - 1;
264 if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
265 values.pop();
266 row_indices.pop();
267 }
268 }
269
270 while current_col < ncols {
272 col_ptrs.push(values.len());
273 current_col += 1;
274 }
275 col_ptrs.push(values.len());
276
277 unsafe { CscMatrix::new_unchecked(nrows, ncols, col_ptrs, row_indices, values) }
279}
280
281pub fn csr_to_coo<T: Scalar + Clone>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
283 let nrows = csr.nrows();
284 let ncols = csr.ncols();
285 let nnz = csr.nnz();
286
287 let mut row_indices = Vec::with_capacity(nnz);
288 let mut col_indices = Vec::with_capacity(nnz);
289 let mut values = Vec::with_capacity(nnz);
290
291 for row in 0..nrows {
292 let start = csr.row_ptrs()[row];
293 let end = csr.row_ptrs()[row + 1];
294
295 for i in start..end {
296 row_indices.push(row);
297 col_indices.push(csr.col_indices()[i]);
298 values.push(csr.values()[i].clone());
299 }
300 }
301
302 unsafe { CooMatrix::new_unchecked(nrows, ncols, row_indices, col_indices, values) }
304}
305
306pub fn csc_to_coo<T: Scalar + Clone>(csc: &CscMatrix<T>) -> CooMatrix<T> {
308 let nrows = csc.nrows();
309 let ncols = csc.ncols();
310 let nnz = csc.nnz();
311
312 let mut row_indices = Vec::with_capacity(nnz);
313 let mut col_indices = Vec::with_capacity(nnz);
314 let mut values = Vec::with_capacity(nnz);
315
316 for col in 0..ncols {
317 let start = csc.col_ptrs()[col];
318 let end = csc.col_ptrs()[col + 1];
319
320 for i in start..end {
321 row_indices.push(csc.row_indices()[i]);
322 col_indices.push(col);
323 values.push(csc.values()[i].clone());
324 }
325 }
326
327 unsafe { CooMatrix::new_unchecked(nrows, ncols, row_indices, col_indices, values) }
329}
330
331pub fn csr_to_dia<T: Scalar + Clone + Field>(
344 csr: &CsrMatrix<T>,
345 offsets: Option<Vec<isize>>,
346) -> DiaMatrix<T> {
347 let (nrows, ncols) = csr.shape();
348 let eps = <T as Scalar>::epsilon();
349
350 let offsets = offsets.unwrap_or_else(|| {
352 let mut found = std::collections::HashSet::new();
353 for (row, col, val) in csr.iter() {
354 if Scalar::abs(val.clone()) > eps {
355 found.insert(col as isize - row as isize);
356 }
357 }
358 let mut offsets: Vec<_> = found.into_iter().collect();
359 offsets.sort();
360 offsets
361 });
362
363 if offsets.is_empty() {
364 return DiaMatrix::zeros(nrows, ncols);
365 }
366
367 let diag_len = nrows.min(ncols);
368 let mut data = Vec::with_capacity(offsets.len());
369
370 for &offset in &offsets {
371 let mut diag = vec![T::zero(); diag_len];
372
373 for (row, col, val) in csr.iter() {
377 let expected_col = (row as isize + offset) as usize;
378 if col == expected_col && row < nrows && col < ncols {
379 let idx = (row as isize + offset) as usize;
381 if idx < diag_len {
382 diag[idx] = val.clone();
383 }
384 }
385 }
386
387 data.push(diag);
388 }
389
390 unsafe { DiaMatrix::new_unchecked(nrows, ncols, offsets, data) }
392}
393
394pub fn dia_to_csr<T: Scalar + Clone + Field>(dia: &DiaMatrix<T>) -> CsrMatrix<T> {
398 dia.to_csr()
399}
400
401pub fn csr_to_ell<T: Scalar + Clone + Field>(
414 csr: &CsrMatrix<T>,
415 max_width: Option<usize>,
416) -> Result<EllMatrix<T>, crate::ell::EllError> {
417 EllMatrix::from_csr(csr, max_width)
418}
419
420pub fn ell_to_csr<T: Scalar + Clone + Field>(ell: &EllMatrix<T>) -> CsrMatrix<T> {
424 ell.to_csr()
425}
426
427pub fn csr_to_bsr<T: Scalar + Clone + Field>(
441 csr: &CsrMatrix<T>,
442 block_rows: usize,
443 block_cols: usize,
444) -> BsrMatrix<T> {
445 BsrMatrix::from_csr(csr, block_rows, block_cols)
446}
447
448pub fn bsr_to_csr<T: Scalar + Clone + Field>(bsr: &BsrMatrix<T>) -> CsrMatrix<T> {
452 bsr.to_csr()
453}
454
455pub fn dia_to_ell<T: Scalar + Clone + Field>(
461 dia: &DiaMatrix<T>,
462 max_width: Option<usize>,
463) -> Result<EllMatrix<T>, crate::ell::EllError> {
464 let csr = dia.to_csr();
465 EllMatrix::from_csr(&csr, max_width)
466}
467
468pub fn ell_to_dia<T: Scalar + Clone + Field>(
470 ell: &EllMatrix<T>,
471 offsets: Option<Vec<isize>>,
472) -> DiaMatrix<T> {
473 let csr = ell.to_csr();
474 csr_to_dia(&csr, offsets)
475}
476
477pub fn dia_to_bsr<T: Scalar + Clone + Field>(
479 dia: &DiaMatrix<T>,
480 block_rows: usize,
481 block_cols: usize,
482) -> BsrMatrix<T> {
483 let csr = dia.to_csr();
484 BsrMatrix::from_csr(&csr, block_rows, block_cols)
485}
486
487pub fn bsr_to_dia<T: Scalar + Clone + Field>(
489 bsr: &BsrMatrix<T>,
490 offsets: Option<Vec<isize>>,
491) -> DiaMatrix<T> {
492 let csr = bsr.to_csr();
493 csr_to_dia(&csr, offsets)
494}
495
496pub fn ell_to_bsr<T: Scalar + Clone + Field>(
498 ell: &EllMatrix<T>,
499 block_rows: usize,
500 block_cols: usize,
501) -> BsrMatrix<T> {
502 let csr = ell.to_csr();
503 BsrMatrix::from_csr(&csr, block_rows, block_cols)
504}
505
506pub fn bsr_to_ell<T: Scalar + Clone + Field>(
508 bsr: &BsrMatrix<T>,
509 max_width: Option<usize>,
510) -> Result<EllMatrix<T>, crate::ell::EllError> {
511 let csr = bsr.to_csr();
512 EllMatrix::from_csr(&csr, max_width)
513}
514
515pub fn csr_to_bsc<T: Scalar + Clone + Field>(
527 csr: &CsrMatrix<T>,
528 block_rows: usize,
529 block_cols: usize,
530) -> BscMatrix<T> {
531 let bsr = BsrMatrix::from_csr(csr, block_rows, block_cols);
532 BscMatrix::from_bsr(&bsr)
533}
534
535pub fn bsc_to_csr<T: Scalar + Clone + Field>(bsc: &BscMatrix<T>) -> CsrMatrix<T> {
537 let bsr = bsc.to_bsr();
538 bsr.to_csr()
539}
540
541pub fn bsc_to_bsr<T: Scalar + Clone + Field>(bsc: &BscMatrix<T>) -> BsrMatrix<T> {
543 bsc.to_bsr()
544}
545
546pub fn bsr_to_bsc<T: Scalar + Clone + Field>(bsr: &BsrMatrix<T>) -> BscMatrix<T> {
548 BscMatrix::from_bsr(bsr)
549}
550
551pub fn csr_to_hyb<T: Scalar + Clone + Field>(
562 csr: &CsrMatrix<T>,
563 strategy: HybWidthStrategy,
564) -> HybMatrix<T> {
565 HybMatrix::from_csr(csr, strategy)
566}
567
568pub fn hyb_to_csr<T: Scalar + Clone + Field>(hyb: &HybMatrix<T>) -> CsrMatrix<T> {
570 hyb.to_csr()
571}
572
573pub fn ell_to_hyb<T: Scalar + Clone + Field>(ell: &EllMatrix<T>) -> HybMatrix<T> {
575 HybMatrix::from_ell(ell)
576}
577
578pub fn hyb_to_ell<T: Scalar + Clone + Field>(hyb: &HybMatrix<T>) -> EllMatrix<T> {
580 hyb.to_ell()
581}
582
583pub fn csr_to_sell<T: Scalar + Clone + Field>(
594 csr: &CsrMatrix<T>,
595 slice_size: SliceSize,
596) -> SellMatrix<T> {
597 SellMatrix::from_csr(csr, slice_size)
598}
599
600pub fn sell_to_csr<T: Scalar + Clone + Field>(sell: &SellMatrix<T>) -> CsrMatrix<T> {
602 sell.to_csr()
603}
604
/// Sparse storage format suggested by [`analyze_sparsity_pattern`].
///
/// Each variant names one of the matrix types this module converts between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecommendedFormat {
    /// Compressed sparse row (`CsrMatrix`); the general-purpose fallback.
    Csr,
    /// Compressed sparse column (`CscMatrix`). Not currently selected by
    /// `analyze_sparsity_pattern`.
    Csc,
    /// Diagonal storage (`DiaMatrix`), chosen when nonzeros lie on few diagonals.
    Dia,
    /// ELL storage (`EllMatrix`), chosen when row lengths are nearly uniform.
    Ell,
    /// Hybrid storage (`HybMatrix`), chosen when row lengths vary moderately.
    Hyb,
    /// Sliced ELL storage (`SellMatrix`), chosen for highly irregular row lengths.
    Sell,
    /// Block sparse row (`BsrMatrix`), chosen when dense square blocks are detected.
    Bsr,
    /// Block sparse column (`BscMatrix`). Not currently selected by
    /// `analyze_sparsity_pattern`.
    Bsc,
}
629
/// Summary statistics of a sparse matrix's nonzero pattern, together with a
/// suggested storage format. Produced by [`analyze_sparsity_pattern`].
#[derive(Debug, Clone)]
pub struct SparsityAnalysis {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Number of stored entries as reported by `CsrMatrix::nnz`, including any
    /// explicitly stored values with magnitude `<= Scalar::epsilon()`.
    pub nnz: usize,
    /// `nnz / (nrows * ncols)`; `0.0` when either dimension is zero.
    pub density: f64,
    /// Largest per-row count of entries with magnitude `> Scalar::epsilon()`.
    pub max_row_length: usize,
    /// Smallest per-row count of entries with magnitude `> Scalar::epsilon()`.
    pub min_row_length: usize,
    /// Mean of those per-row counts.
    pub avg_row_length: f64,
    /// Population standard deviation (divided by `nrows`) of those per-row counts.
    pub row_length_stddev: f64,
    /// Number of distinct offsets `col - row` that hold an entry with magnitude
    /// `> Scalar::epsilon()`.
    pub num_diagonals: usize,
    /// Whether a uniform square block size with mostly dense blocks was found.
    pub has_block_structure: bool,
    /// The detected `(block_rows, block_cols)`, if any.
    pub detected_block_size: Option<(usize, usize)>,
    /// Format suggested by the analysis heuristics.
    pub recommended_format: RecommendedFormat,
}
658
659pub fn analyze_sparsity_pattern<T: Scalar + Clone + Field>(csr: &CsrMatrix<T>) -> SparsityAnalysis {
665 let (nrows, ncols) = csr.shape();
666 let nnz = csr.nnz();
667 let eps = <T as Scalar>::epsilon();
668
669 if nrows == 0 || ncols == 0 {
670 return SparsityAnalysis {
671 nrows,
672 ncols,
673 nnz,
674 density: 0.0,
675 max_row_length: 0,
676 min_row_length: 0,
677 avg_row_length: 0.0,
678 row_length_stddev: 0.0,
679 num_diagonals: 0,
680 has_block_structure: false,
681 detected_block_size: None,
682 recommended_format: RecommendedFormat::Csr,
683 };
684 }
685
686 let mut row_lengths = Vec::with_capacity(nrows);
688 for row in 0..nrows {
689 let mut count = 0;
690 for (_, val) in csr.row_iter(row) {
691 if Scalar::abs(val.clone()) > eps {
692 count += 1;
693 }
694 }
695 row_lengths.push(count);
696 }
697
698 let max_row_length = row_lengths.iter().max().copied().unwrap_or(0);
699 let min_row_length = row_lengths.iter().min().copied().unwrap_or(0);
700 let avg_row_length = if nrows > 0 {
701 row_lengths.iter().sum::<usize>() as f64 / nrows as f64
702 } else {
703 0.0
704 };
705
706 let variance: f64 = row_lengths
708 .iter()
709 .map(|&x| {
710 let diff = x as f64 - avg_row_length;
711 diff * diff
712 })
713 .sum::<f64>()
714 / nrows.max(1) as f64;
715 let row_length_stddev = variance.sqrt();
716
717 let mut diagonals = std::collections::HashSet::new();
719 for (row, col, val) in csr.iter() {
720 if Scalar::abs(val.clone()) > eps {
721 diagonals.insert(col as isize - row as isize);
722 }
723 }
724 let num_diagonals = diagonals.len();
725
726 let (has_block_structure, detected_block_size) = detect_block_structure(csr);
728
729 let density = if nrows * ncols > 0 {
730 nnz as f64 / (nrows * ncols) as f64
731 } else {
732 0.0
733 };
734
735 let recommended_format = determine_recommended_format(
737 nrows,
738 ncols,
739 nnz,
740 max_row_length,
741 min_row_length,
742 row_length_stddev,
743 num_diagonals,
744 has_block_structure,
745 );
746
747 SparsityAnalysis {
748 nrows,
749 ncols,
750 nnz,
751 density,
752 max_row_length,
753 min_row_length,
754 avg_row_length,
755 row_length_stddev,
756 num_diagonals,
757 has_block_structure,
758 detected_block_size,
759 recommended_format,
760 }
761}
762
763fn detect_block_structure<T: Scalar + Clone + Field>(
765 csr: &CsrMatrix<T>,
766) -> (bool, Option<(usize, usize)>) {
767 let (nrows, ncols) = csr.shape();
768 let eps = <T as Scalar>::epsilon();
769
770 if nrows < 4 || ncols < 4 {
771 return (false, None);
772 }
773
774 for block_size in [2, 3, 4, 6, 8] {
776 if nrows % block_size != 0 || ncols % block_size != 0 {
777 continue;
778 }
779
780 let _num_block_rows = nrows / block_size;
781 let _num_block_cols = ncols / block_size;
782
783 let block_aligned = true;
785 let mut blocks_found = std::collections::HashSet::new();
786
787 for (row, col, val) in csr.iter() {
788 if Scalar::abs(val.clone()) > eps {
789 let block_row = row / block_size;
790 let block_col = col / block_size;
791 blocks_found.insert((block_row, block_col));
792 }
793 }
794
795 let mut dense_blocks = 0;
797 for &(br, bc) in &blocks_found {
798 let mut count = 0;
799 for i in 0..block_size {
800 for j in 0..block_size {
801 let row = br * block_size + i;
802 let col = bc * block_size + j;
803 if let Some(val) = csr.get(row, col) {
804 if Scalar::abs(val.clone()) > eps {
805 count += 1;
806 }
807 }
808 }
809 }
810 if count * 2 >= block_size * block_size {
812 dense_blocks += 1;
813 }
814 }
815
816 if !blocks_found.is_empty() && dense_blocks * 10 >= blocks_found.len() * 7 {
818 return (true, Some((block_size, block_size)));
819 }
820 if !block_aligned {
821 continue;
823 }
824 }
825
826 (false, None)
827}
828
829fn determine_recommended_format(
831 nrows: usize,
832 ncols: usize,
833 nnz: usize,
834 max_row_length: usize,
835 min_row_length: usize,
836 row_length_stddev: f64,
837 num_diagonals: usize,
838 has_block_structure: bool,
839) -> RecommendedFormat {
840 if nnz == 0 || nrows <= 10 || ncols <= 10 {
842 return RecommendedFormat::Csr;
843 }
844
845 let avg_row_length = nnz as f64 / nrows.max(1) as f64;
846
847 if has_block_structure {
849 return RecommendedFormat::Bsr;
850 }
851
852 if num_diagonals <= 10 && num_diagonals * 2 <= nrows.max(1) {
855 return RecommendedFormat::Dia;
856 }
857
858 let coefficient_of_variation = row_length_stddev / avg_row_length.max(1.0);
860
861 if coefficient_of_variation < 0.3 {
862 return RecommendedFormat::Ell;
864 }
865
866 if coefficient_of_variation < 0.8 {
867 return RecommendedFormat::Hyb;
869 }
870
871 if max_row_length > min_row_length * 10 {
873 return RecommendedFormat::Sell;
875 }
876
877 RecommendedFormat::Csr
879}
880
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bsr::DenseBlock;

    #[test]
    fn test_csr_to_csc() {
        // 3x3 matrix given in CSR form:
        //   [1 0 2]
        //   [0 3 0]
        //   [4 0 5]
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let col_indices = vec![0, 2, 1, 0, 2];
        let row_ptrs = vec![0, 2, 3, 5];

        let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
        let csc = csr_to_csc(&csr);

        // Every stored entry must survive the conversion at the same (row, col).
        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.get(0, 0), Some(&1.0));
        assert_eq!(csc.get(0, 2), Some(&2.0));
        assert_eq!(csc.get(1, 1), Some(&3.0));
        assert_eq!(csc.get(2, 0), Some(&4.0));
        assert_eq!(csc.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_csc_to_csr() {
        // 3x3 matrix given in CSC form:
        //   [1 0 4]
        //   [0 3 0]
        //   [2 0 5]
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let row_indices = vec![0, 2, 1, 0, 2];
        let col_ptrs = vec![0, 2, 3, 5];

        let csc = CscMatrix::new(3, 3, col_ptrs, row_indices, values).unwrap();
        let csr = csc_to_csr(&csc);

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(0, 2), Some(&4.0));
        assert_eq!(csr.get(1, 1), Some(&3.0));
        assert_eq!(csr.get(2, 0), Some(&2.0));
        assert_eq!(csr.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_coo_to_csr() {
        // Triplets deliberately not in row-major order; the matrix is
        //   [1 0 2]
        //   [0 3 0]
        //   [4 0 5]
        let row_indices = vec![0, 1, 2, 0, 2];
        let col_indices = vec![0, 1, 0, 2, 2];
        let values = vec![1.0f64, 3.0, 4.0, 2.0, 5.0];

        let coo = CooMatrix::new(3, 3, row_indices, col_indices, values).unwrap();
        let csr = coo_to_csr(&coo);

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(0, 2), Some(&2.0));
        assert_eq!(csr.get(1, 1), Some(&3.0));
        assert_eq!(csr.get(2, 0), Some(&4.0));
        assert_eq!(csr.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_coo_to_csr_duplicates() {
        // (0, 0) is listed twice (1.0 and 2.0); the duplicates must be summed
        // into a single stored entry of 3.0.
        let row_indices = vec![0, 0, 1];
        let col_indices = vec![0, 0, 1];
        let values = vec![1.0f64, 2.0, 3.0];

        let coo = CooMatrix::new(2, 2, row_indices, col_indices, values).unwrap();
        let csr = coo_to_csr(&coo);

        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 0), Some(&3.0));
        assert_eq!(csr.get(1, 1), Some(&3.0));
    }

    #[test]
    fn test_coo_to_csc() {
        // Same triplets as `test_coo_to_csr`, converted to column storage.
        let row_indices = vec![0, 1, 2, 0, 2];
        let col_indices = vec![0, 1, 0, 2, 2];
        let values = vec![1.0f64, 3.0, 4.0, 2.0, 5.0];

        let coo = CooMatrix::new(3, 3, row_indices, col_indices, values).unwrap();
        let csc = coo_to_csc(&coo);

        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.get(0, 0), Some(&1.0));
        assert_eq!(csc.get(0, 2), Some(&2.0));
        assert_eq!(csc.get(1, 1), Some(&3.0));
        assert_eq!(csc.get(2, 0), Some(&4.0));
        assert_eq!(csc.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_roundtrip_csr_csc_csr() {
        // CSR -> CSC -> CSR must reproduce every cell, including absent ones.
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let col_indices = vec![0, 2, 1, 0, 2];
        let row_ptrs = vec![0, 2, 3, 5];

        let csr1 = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
        let csc = csr_to_csc(&csr1);
        let csr2 = csc_to_csr(&csc);

        assert_eq!(csr1.nnz(), csr2.nnz());
        for row in 0..3 {
            for col in 0..3 {
                assert_eq!(csr1.get(row, col), csr2.get(row, col));
            }
        }
    }

    #[test]
    fn test_csr_to_coo() {
        // diag(1, 2, 3); the COO output lists triplets in row order.
        let values = vec![1.0f64, 2.0, 3.0];
        let col_indices = vec![0, 1, 2];
        let row_ptrs = vec![0, 1, 2, 3];

        let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
        let coo = csr_to_coo(&csr);

        assert_eq!(coo.len(), 3);
        let entries: Vec<_> = coo.iter().map(|(r, c, v)| (r, c, *v)).collect();
        assert_eq!(entries, vec![(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)]);
    }

    #[test]
    fn test_empty_matrix_conversion() {
        // A non-square matrix with no stored entries keeps its shape.
        let csr: CsrMatrix<f64> = CsrMatrix::zeros(5, 3);
        let csc = csr_to_csc(&csr);

        assert_eq!(csc.nrows(), 5);
        assert_eq!(csc.ncols(), 3);
        assert_eq!(csc.nnz(), 0);
    }

    #[test]
    fn test_csr_to_dia_tridiagonal() {
        // Tridiagonal 3x3:
        //   [4 1 0]
        //   [2 5 1]
        //   [0 3 6]
        // Auto-detected offsets should be exactly -1, 0 and 1.
        let values = vec![4.0f64, 1.0, 2.0, 5.0, 1.0, 3.0, 6.0];
        let col_indices = vec![0, 1, 0, 1, 2, 1, 2];
        let row_ptrs = vec![0, 2, 5, 7];

        let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
        let dia = csr_to_dia(&csr, None);

        assert_eq!(dia.ndiag(), 3);
        assert_eq!(dia.get(0, 0), Some(&4.0));
        assert_eq!(dia.get(0, 1), Some(&1.0));
        assert_eq!(dia.get(1, 0), Some(&2.0));
        assert_eq!(dia.get(1, 1), Some(&5.0));
        assert_eq!(dia.get(2, 2), Some(&6.0));
    }

    #[test]
    fn test_dia_to_csr() {
        // Offsets -1, 0, 1. The assertions imply each diagonal is indexed by
        // column: A(1, 0) comes from data[0][0] = 2.0.
        let offsets = vec![-1, 0, 1];
        let data = vec![
            vec![2.0, 3.0, 0.0],
            vec![4.0, 5.0, 6.0],
            vec![0.0, 1.0, 1.0],
        ];

        let dia = DiaMatrix::new(3, 3, offsets, data).unwrap();
        let csr = dia_to_csr(&dia);

        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.get(0, 0), Some(&4.0));
        assert_eq!(csr.get(1, 0), Some(&2.0));
    }

    #[test]
    fn test_csr_dia_roundtrip() {
        // Matrix
        //   [1 2 0]
        //   [0 3 0]
        //   [4 0 5]
        // uses offsets 0, 1 and -2. Absent cells are compared as 0.0.
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let col_indices = vec![0, 1, 1, 0, 2];
        let row_ptrs = vec![0, 2, 3, 5];

        let csr1 = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
        let dia = csr_to_dia(&csr1, None);
        let csr2 = dia_to_csr(&dia);

        for row in 0..3 {
            for col in 0..3 {
                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
                assert!((v1 - v2).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_csr_to_ell() {
        // 3x4 matrix with exactly two entries per row, so the default width is 2.
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let col_indices = vec![0, 1, 1, 2, 0, 3];
        let row_ptrs = vec![0, 2, 4, 6];

        let csr = CsrMatrix::new(3, 4, row_ptrs, col_indices, values).unwrap();
        let ell = csr_to_ell(&csr, None).unwrap();

        assert_eq!(ell.width(), 2);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(1, 2), Some(&4.0));
    }

    #[test]
    fn test_ell_to_csr() {
        // 2x3 ELL with width 2: row 0 holds cols 0,1; row 1 holds cols 1,2.
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let indices = vec![vec![0, 1], vec![1, 2]];

        let ell = EllMatrix::new(2, 3, 2, data, indices).unwrap();
        let csr = ell_to_csr(&ell);

        assert_eq!(csr.nrows(), 2);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(1, 2), Some(&4.0));
    }

    #[test]
    fn test_csr_ell_roundtrip() {
        // 2x3 matrix [[1 2 0], [0 3 4]]; absent cells are compared as 0.0.
        let values = vec![1.0f64, 2.0, 3.0, 4.0];
        let col_indices = vec![0, 1, 1, 2];
        let row_ptrs = vec![0, 2, 4];

        let csr1 = CsrMatrix::new(2, 3, row_ptrs, col_indices, values).unwrap();
        let ell = csr_to_ell(&csr1, None).unwrap();
        let csr2 = ell_to_csr(&ell);

        for row in 0..2 {
            for col in 0..3 {
                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
                assert!((v1 - v2).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_csr_to_bsr() {
        // 4x4 block-diagonal matrix made of two dense 2x2 blocks, so 2x2
        // blocking yields exactly two blocks. `BsrMatrix::get` returns the
        // value by value (`Option<T>`), not by reference.
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let col_indices = vec![0, 1, 0, 1, 2, 3, 2, 3];
        let row_ptrs = vec![0, 2, 4, 6, 8];

        let csr = CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap();
        let bsr = csr_to_bsr(&csr, 2, 2);

        assert_eq!(bsr.nblocks(), 2);
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(3, 3), Some(8.0));
    }

    #[test]
    fn test_bsr_to_csr() {
        // Two 2x2 blocks on the block diagonal: block-row pointers [0, 1, 2],
        // block columns [0, 1].
        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        let bsr =
            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();

        let csr = bsr_to_csr(&bsr);

        assert_eq!(csr.nrows(), 4);
        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_csr_bsr_roundtrip() {
        // Same block-diagonal fixture as `test_csr_to_bsr`.
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let col_indices = vec![0, 1, 0, 1, 2, 3, 2, 3];
        let row_ptrs = vec![0, 2, 4, 6, 8];

        let csr1 = CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap();
        let bsr = csr_to_bsr(&csr1, 2, 2);
        let csr2 = bsr_to_csr(&bsr);

        for row in 0..4 {
            for col in 0..4 {
                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
                assert!((v1 - v2).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_dia_to_ell() {
        // Pure main diagonal diag(1, 2, 3): one entry per row, so width 1.
        let offsets = vec![0];
        let data = vec![vec![1.0, 2.0, 3.0]];

        let dia = DiaMatrix::new(3, 3, offsets, data).unwrap();
        let ell = dia_to_ell(&dia, None).unwrap();

        assert_eq!(ell.width(), 1);
        assert_eq!(ell.get(0, 0), Some(&1.0));
        assert_eq!(ell.get(1, 1), Some(&2.0));
    }

    #[test]
    fn test_dia_to_bsr() {
        // diag(1, 2, 3, 4) blocked into 2x2 tiles.
        let offsets = vec![0];
        let data = vec![vec![1.0, 2.0, 3.0, 4.0]];

        let dia = DiaMatrix::new(4, 4, offsets, data).unwrap();
        let bsr = dia_to_bsr(&dia, 2, 2);

        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(1, 1), Some(2.0));
    }

    #[test]
    fn test_ell_to_bsr() {
        // 4x4 block-diagonal matrix given as width-2 ELL: rows 0-1 use cols
        // 0-1, rows 2-3 use cols 2-3.
        let data = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
        ];
        let indices = vec![vec![0, 1], vec![0, 1], vec![2, 3], vec![2, 3]];

        let ell = EllMatrix::new(4, 4, 2, data, indices).unwrap();
        let bsr = ell_to_bsr(&ell, 2, 2);

        assert_eq!(bsr.nrows(), 4);
        assert_eq!(bsr.get(0, 0), Some(1.0));
        assert_eq!(bsr.get(3, 3), Some(8.0));
    }
}