1use crate::CooMatrixBuilder;
23use crate::csr::CsrMatrix;
24
/// Error returned by the test-matrix generators in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TestMatrixError {
    /// A size parameter was rejected (the generators here reject a size of zero).
    InvalidDimension {
        /// Name of the offending parameter, e.g. `"nx"`.
        param: String,
        /// The rejected value.
        value: usize,
    },
    /// A density outside the closed interval `[0, 1]` was supplied.
    InvalidDensity {
        /// The rejected density, already rendered as text.
        density: String,
    },
    /// The matrix could not be assembled.
    ConstructionError {
        /// Human-readable reason for the failure.
        description: String,
    },
}

impl core::fmt::Display for TestMatrixError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        use TestMatrixError::{ConstructionError, InvalidDensity, InvalidDimension};

        match self {
            InvalidDimension { param, value } => {
                write!(f, "Invalid dimension for {param}: {value}")
            }
            InvalidDensity { density } => {
                write!(f, "Invalid density: {density} (must be in [0, 1])")
            }
            ConstructionError { description } => {
                write!(f, "Matrix construction error: {description}")
            }
        }
    }
}

impl std::error::Error for TestMatrixError {}
64
65pub fn laplacian_2d(nx: usize, ny: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
86 if nx == 0 {
87 return Err(TestMatrixError::InvalidDimension {
88 param: "nx".to_string(),
89 value: 0,
90 });
91 }
92 if ny == 0 {
93 return Err(TestMatrixError::InvalidDimension {
94 param: "ny".to_string(),
95 value: 0,
96 });
97 }
98
99 let n = nx * ny;
100 let mut builder = CooMatrixBuilder::new(n, n);
101
102 for j in 0..ny {
103 for i in 0..nx {
104 let idx = j * nx + i;
105
106 builder.add(idx, idx, 4.0);
108
109 if i > 0 {
111 builder.add(idx, idx - 1, -1.0);
112 }
113
114 if i < nx - 1 {
116 builder.add(idx, idx + 1, -1.0);
117 }
118
119 if j > 0 {
121 builder.add(idx, idx - nx, -1.0);
122 }
123
124 if j < ny - 1 {
126 builder.add(idx, idx + nx, -1.0);
127 }
128 }
129 }
130
131 Ok(builder.build().to_csr())
132}
133
134pub fn laplacian_3d(nx: usize, ny: usize, nz: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
153 if nx == 0 {
154 return Err(TestMatrixError::InvalidDimension {
155 param: "nx".to_string(),
156 value: 0,
157 });
158 }
159 if ny == 0 {
160 return Err(TestMatrixError::InvalidDimension {
161 param: "ny".to_string(),
162 value: 0,
163 });
164 }
165 if nz == 0 {
166 return Err(TestMatrixError::InvalidDimension {
167 param: "nz".to_string(),
168 value: 0,
169 });
170 }
171
172 let n = nx * ny * nz;
173 let nxy = nx * ny;
174 let mut builder = CooMatrixBuilder::new(n, n);
175
176 for k in 0..nz {
177 for j in 0..ny {
178 for i in 0..nx {
179 let idx = k * nxy + j * nx + i;
180
181 builder.add(idx, idx, 6.0);
183
184 if i > 0 {
186 builder.add(idx, idx - 1, -1.0);
187 }
188 if i < nx - 1 {
189 builder.add(idx, idx + 1, -1.0);
190 }
191
192 if j > 0 {
194 builder.add(idx, idx - nx, -1.0);
195 }
196 if j < ny - 1 {
197 builder.add(idx, idx + nx, -1.0);
198 }
199
200 if k > 0 {
202 builder.add(idx, idx - nxy, -1.0);
203 }
204 if k < nz - 1 {
205 builder.add(idx, idx + nxy, -1.0);
206 }
207 }
208 }
209 }
210
211 Ok(builder.build().to_csr())
212}
213
214pub fn tridiagonal(
232 n: usize,
233 sub: f64,
234 diag: f64,
235 sup: f64,
236) -> Result<CsrMatrix<f64>, TestMatrixError> {
237 if n == 0 {
238 return Err(TestMatrixError::InvalidDimension {
239 param: "n".to_string(),
240 value: 0,
241 });
242 }
243
244 let mut builder = CooMatrixBuilder::new(n, n);
245
246 for i in 0..n {
247 builder.add(i, i, diag);
249
250 if i > 0 {
252 builder.add(i, i - 1, sub);
253 }
254
255 if i < n - 1 {
257 builder.add(i, i + 1, sup);
258 }
259 }
260
261 Ok(builder.build().to_csr())
262}
263
264pub fn diagonal(n: usize, value: f64) -> Result<CsrMatrix<f64>, TestMatrixError> {
278 if n == 0 {
279 return Err(TestMatrixError::InvalidDimension {
280 param: "n".to_string(),
281 value: 0,
282 });
283 }
284
285 let mut builder = CooMatrixBuilder::new(n, n);
286
287 for i in 0..n {
288 builder.add(i, i, value);
289 }
290
291 Ok(builder.build().to_csr())
292}
293
294pub fn arrow_matrix(n: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
316 if n == 0 {
317 return Err(TestMatrixError::InvalidDimension {
318 param: "n".to_string(),
319 value: 0,
320 });
321 }
322
323 if n == 1 {
324 let mut builder = CooMatrixBuilder::new(1, 1);
325 builder.add(0, 0, 1.0);
326 return Ok(builder.build().to_csr());
327 }
328
329 let mut builder = CooMatrixBuilder::new(n, n);
330
331 builder.add(0, 0, (n + 1) as f64);
334
335 for i in 1..n {
336 builder.add(0, i, 1.0);
338 builder.add(i, 0, 1.0);
340 builder.add(i, i, (n + 1 - i) as f64);
343 }
344
345 Ok(builder.build().to_csr())
346}
347
348pub fn random_spd(n: usize, density: f64) -> Result<CsrMatrix<f64>, TestMatrixError> {
365 if n == 0 {
366 return Err(TestMatrixError::InvalidDimension {
367 param: "n".to_string(),
368 value: 0,
369 });
370 }
371 if !(0.0..=1.0).contains(&density) {
372 return Err(TestMatrixError::InvalidDensity {
373 density: format!("{density}"),
374 });
375 }
376
377 let mut builder = CooMatrixBuilder::new(n, n);
381
382 for i in 0..n {
384 builder.add(i, i, n as f64);
385 }
386
387 if density <= 0.0 || n == 1 {
388 return Ok(builder.build().to_csr());
389 }
390
391 let max_lt_entries = n * (n - 1) / 2;
393 let target_lt_nnz = ((max_lt_entries as f64) * density).ceil() as usize;
394
395 let mut seed: u64 = 0x517cc1b727220a95;
397 let mut generated = 0usize;
398 let max_attempts = max_lt_entries * 3; for attempt in 0..max_attempts {
401 if generated >= target_lt_nnz {
402 break;
403 }
404
405 seed ^= seed.wrapping_shl(13);
407 seed ^= seed.wrapping_shr(7);
408 seed ^= seed.wrapping_shl(17);
409 seed = seed.wrapping_add(attempt as u64);
410
411 let row = ((seed >> 16) as usize) % n;
412 let col = ((seed >> 32) as usize) % n;
413
414 if row > col {
415 let val = 0.1 + 0.9 * ((seed & 0xFF) as f64) / 255.0;
417 builder.add(row, col, val);
419 builder.add(col, row, val);
420 builder.add(row, row, val);
422 builder.add(col, col, val);
423 generated += 1;
424 }
425 }
426
427 let mut coo = builder.build();
428 coo.sum_duplicates();
429 Ok(coo.to_csr())
430}
431
432pub fn poisson_1d(n: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
455 tridiagonal(n, -1.0, 2.0, -1.0)
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for entries the generators produce exactly (small integers and their sums).
    const EXACT_TOL: f64 = 1e-14;

    /// Expected `.err()` value for a zero-valued size parameter named `param`.
    fn zero_dim(param: &str) -> Option<TestMatrixError> {
        Some(TestMatrixError::InvalidDimension {
            param: param.to_string(),
            value: 0,
        })
    }

    /// Asserts that every stored entry equals its transpose within `tol`.
    fn assert_symmetric(mat: &CsrMatrix<f64>, tol: f64, label: &str) {
        for (row, col, &val) in mat.iter() {
            let transpose_val = mat.get_or_zero(col, row);
            assert!(
                (val - transpose_val).abs() < tol,
                "{label} not symmetric at ({row}, {col}): {val} vs {transpose_val}"
            );
        }
    }

    /// Asserts that, in every row, the diagonal entry is at least the sum of the absolute
    /// values of the off-diagonal entries (weak row diagonal dominance).
    fn assert_diagonally_dominant(mat: &CsrMatrix<f64>, label: &str) {
        for i in 0..mat.nrows() {
            let diag = mat.get_or_zero(i, i);
            let mut off_diag_sum = 0.0;
            for (col, val) in mat.row_iter(i) {
                if col != i {
                    off_diag_sum += val.abs();
                }
            }
            assert!(
                diag >= off_diag_sum,
                "{label} row {i}: diagonal {diag} < off-diagonal sum {off_diag_sum}"
            );
        }
    }

    /// Asserts that `a` and `b` have the same shape and nnz and agree entrywise.
    fn assert_same_entries(a: &CsrMatrix<f64>, b: &CsrMatrix<f64>, label: &str) {
        assert_eq!(a.nrows(), b.nrows(), "{label}: row counts differ");
        assert_eq!(a.ncols(), b.ncols(), "{label}: column counts differ");
        assert_eq!(a.nnz(), b.nnz(), "{label}: nnz differs");
        for i in 0..a.nrows() {
            for j in 0..a.ncols() {
                let (av, bv) = (a.get_or_zero(i, j), b.get_or_zero(i, j));
                assert!(
                    (av - bv).abs() < EXACT_TOL,
                    "{label} differs at ({i}, {j}): {av} vs {bv}"
                );
            }
        }
    }

    /// Counts the entries stored in row `i`.
    fn row_nnz(mat: &CsrMatrix<f64>, i: usize) -> usize {
        let mut count = 0;
        for _ in mat.row_iter(i) {
            count += 1;
        }
        count
    }

    // ---- TestMatrixError ----------------------------------------------------------------

    #[test]
    fn test_error_display() {
        let dim = TestMatrixError::InvalidDimension {
            param: "nx".to_string(),
            value: 0,
        };
        assert_eq!(dim.to_string(), "Invalid dimension for nx: 0");

        let density = TestMatrixError::InvalidDensity {
            density: "1.5".to_string(),
        };
        assert_eq!(density.to_string(), "Invalid density: 1.5 (must be in [0, 1])");

        let construction = TestMatrixError::ConstructionError {
            description: "boom".to_string(),
        };
        assert_eq!(construction.to_string(), "Matrix construction error: boom");
    }

    // ---- laplacian_2d -------------------------------------------------------------------

    #[test]
    fn test_laplacian_2d_basic() {
        let mat = laplacian_2d(3, 3).expect("Failed to create 2D Laplacian");
        assert_eq!(mat.nrows(), 9);
        assert_eq!(mat.ncols(), 9);
        // 9 diagonal entries + 12 horizontal + 12 vertical neighbour couplings.
        assert_eq!(mat.nnz(), 33);
    }

    #[test]
    fn test_laplacian_2d_symmetry() {
        let mat = laplacian_2d(4, 3).expect("Failed to create 2D Laplacian");
        assert_eq!(mat.nrows(), 12);
        assert_symmetric(&mat, EXACT_TOL, "Laplacian 2D");
    }

    #[test]
    fn test_laplacian_2d_diagonal_dominance() {
        // Together with symmetry and a positive diagonal this implies positive
        // semi-definiteness (Gershgorin); it does not by itself prove definiteness.
        let mat = laplacian_2d(5, 5).expect("Failed to create 2D Laplacian");
        assert_diagonally_dominant(&mat, "Laplacian 2D");
    }

    #[test]
    fn test_laplacian_2d_no_wraparound() {
        // nx = 3: point (2, 0) is index 2 and (0, 1) is index 3. They are adjacent in
        // numbering but not on the grid, so they must not be coupled.
        let mat = laplacian_2d(3, 2).expect("Failed to create 2D Laplacian");
        assert!(mat.get_or_zero(2, 3).abs() < EXACT_TOL);
        assert!(mat.get_or_zero(3, 2).abs() < EXACT_TOL);
        // Vertical coupling (0, 0) <-> (0, 1) is index 0 <-> 3.
        assert!((mat.get_or_zero(0, 3) - (-1.0)).abs() < EXACT_TOL);
    }

    #[test]
    fn test_laplacian_2d_1x1() {
        let mat = laplacian_2d(1, 1).expect("Failed to create 1x1 Laplacian");
        assert_eq!(mat.nrows(), 1);
        assert_eq!(mat.ncols(), 1);
        assert_eq!(mat.nnz(), 1);
        assert!((mat.get_or_zero(0, 0) - 4.0).abs() < EXACT_TOL);
    }

    #[test]
    fn test_laplacian_2d_zero_dimension() {
        assert_eq!(laplacian_2d(0, 5).err(), zero_dim("nx"));
        assert_eq!(laplacian_2d(5, 0).err(), zero_dim("ny"));
    }

    #[test]
    fn test_laplacian_2d_1d_consistency() {
        // With ny == 1 there are no vertical neighbours: the stencil is tridiag(-1, 4, -1).
        let mat = laplacian_2d(5, 1).expect("laplacian_2d(5,1)");
        let expected = tridiagonal(5, -1.0, 4.0, -1.0).expect("tridiagonal");
        assert_same_entries(&mat, &expected, "laplacian_2d(5, 1)");
    }

    // ---- laplacian_3d -------------------------------------------------------------------

    #[test]
    fn test_laplacian_3d_basic() {
        let mat = laplacian_3d(3, 3, 3).expect("Failed to create 3D Laplacian");
        assert_eq!(mat.nrows(), 27);
        assert_eq!(mat.ncols(), 27);
        // 27 diagonal entries + 3 directions * 18 couplings * 2 (both triangles).
        assert_eq!(mat.nnz(), 135);
    }

    #[test]
    fn test_laplacian_3d_symmetry() {
        let mat = laplacian_3d(3, 2, 2).expect("Failed to create 3D Laplacian");
        assert_symmetric(&mat, EXACT_TOL, "Laplacian 3D");
    }

    #[test]
    fn test_laplacian_3d_diagonal_dominance() {
        let mat = laplacian_3d(3, 3, 3).expect("Failed to create 3D Laplacian");
        assert_diagonally_dominant(&mat, "Laplacian 3D");
    }

    #[test]
    fn test_laplacian_3d_zero_dimension() {
        assert_eq!(laplacian_3d(0, 3, 3).err(), zero_dim("nx"));
        assert_eq!(laplacian_3d(3, 0, 3).err(), zero_dim("ny"));
        assert_eq!(laplacian_3d(3, 3, 0).err(), zero_dim("nz"));
    }

    #[test]
    fn test_laplacian_3d_1d_consistency() {
        // With ny == nz == 1 only x-neighbours remain: the stencil is tridiag(-1, 6, -1).
        let mat = laplacian_3d(5, 1, 1).expect("laplacian_3d(5,1,1)");
        let expected = tridiagonal(5, -1.0, 6.0, -1.0).expect("tridiagonal");
        assert_same_entries(&mat, &expected, "laplacian_3d(5, 1, 1)");
    }

    // ---- tridiagonal --------------------------------------------------------------------

    #[test]
    fn test_tridiagonal_basic() {
        let mat = tridiagonal(5, -1.0, 2.0, -1.0).expect("Failed to create tridiagonal");
        assert_eq!(mat.nrows(), 5);
        assert_eq!(mat.ncols(), 5);
        // 5 diagonal + 4 sub + 4 super.
        assert_eq!(mat.nnz(), 13);
    }

    #[test]
    fn test_tridiagonal_values() {
        let mat = tridiagonal(4, -1.0, 3.0, -2.0).expect("Failed to create tridiagonal");

        for i in 0..4 {
            assert!((mat.get_or_zero(i, i) - 3.0).abs() < EXACT_TOL);
        }
        for i in 1..4 {
            assert!((mat.get_or_zero(i, i - 1) - (-1.0)).abs() < EXACT_TOL);
        }
        for i in 0..3 {
            assert!((mat.get_or_zero(i, i + 1) - (-2.0)).abs() < EXACT_TOL);
        }

        // Outside the band.
        assert!(mat.get_or_zero(0, 2).abs() < EXACT_TOL);
        assert!(mat.get_or_zero(0, 3).abs() < EXACT_TOL);
        assert!(mat.get_or_zero(3, 0).abs() < EXACT_TOL);
    }

    #[test]
    fn test_tridiagonal_symmetric() {
        let mat =
            tridiagonal(5, -1.0, 4.0, -1.0).expect("Failed to create symmetric tridiagonal");
        assert_symmetric(&mat, EXACT_TOL, "Symmetric tridiagonal");
    }

    #[test]
    fn test_tridiagonal_size_1() {
        let mat = tridiagonal(1, -1.0, 5.0, -1.0).expect("Failed to create 1x1 tridiagonal");
        assert_eq!(mat.nrows(), 1);
        assert_eq!(mat.nnz(), 1);
        assert!((mat.get_or_zero(0, 0) - 5.0).abs() < EXACT_TOL);
    }

    #[test]
    fn test_tridiagonal_zero() {
        assert_eq!(tridiagonal(0, -1.0, 2.0, -1.0).err(), zero_dim("n"));
    }

    // ---- diagonal -----------------------------------------------------------------------

    #[test]
    fn test_diagonal_basic() {
        let mat = diagonal(5, 3.0).expect("Failed to create diagonal");
        assert_eq!(mat.nrows(), 5);
        assert_eq!(mat.ncols(), 5);
        assert_eq!(mat.nnz(), 5);

        for i in 0..5 {
            assert!((mat.get_or_zero(i, i) - 3.0).abs() < EXACT_TOL);
        }
    }

    #[test]
    fn test_diagonal_off_diagonal_zeros() {
        let mat = diagonal(4, 2.0).expect("Failed to create diagonal");

        for i in 0..4 {
            for j in (0..4).filter(|&j| j != i) {
                assert!(
                    mat.get_or_zero(i, j).abs() < EXACT_TOL,
                    "Off-diagonal entry ({i}, {j}) is not zero"
                );
            }
        }
    }

    #[test]
    fn test_diagonal_symmetry() {
        let mat = diagonal(10, 7.5).expect("Failed to create diagonal");
        assert_symmetric(&mat, EXACT_TOL, "Diagonal matrix");
    }

    #[test]
    fn test_diagonal_zero() {
        assert_eq!(diagonal(0, 1.0).err(), zero_dim("n"));
    }

    // ---- arrow_matrix -------------------------------------------------------------------

    #[test]
    fn test_arrow_basic() {
        let mat = arrow_matrix(5).expect("Failed to create arrow matrix");
        assert_eq!(mat.nrows(), 5);
        assert_eq!(mat.ncols(), 5);
        // 5 diagonal + 4 first-row + 4 first-column entries.
        assert_eq!(mat.nnz(), 13);
    }

    #[test]
    fn test_arrow_values() {
        let n = 5;
        let mat = arrow_matrix(n).expect("Failed to create arrow matrix");

        assert!((mat.get_or_zero(0, 0) - (n + 1) as f64).abs() < EXACT_TOL);
        for i in 1..n {
            assert!((mat.get_or_zero(0, i) - 1.0).abs() < EXACT_TOL);
            assert!((mat.get_or_zero(i, 0) - 1.0).abs() < EXACT_TOL);
            assert!((mat.get_or_zero(i, i) - (n + 1 - i) as f64).abs() < EXACT_TOL);
        }
    }

    #[test]
    fn test_arrow_symmetry() {
        let mat = arrow_matrix(8).expect("Failed to create arrow matrix");
        assert_symmetric(&mat, EXACT_TOL, "Arrow matrix");
    }

    #[test]
    fn test_arrow_structure() {
        let n = 5;
        let mat = arrow_matrix(n).expect("Failed to create arrow matrix");

        // The first row is dense.
        assert_eq!(row_nnz(&mat, 0), n, "First row should have {n} entries");

        // Every other row holds only its first-column entry and its diagonal.
        for i in 1..n {
            let count = row_nnz(&mat, i);
            assert_eq!(count, 2, "Row {i} should have exactly 2 entries, got {count}");
        }
    }

    #[test]
    fn test_arrow_diagonal_dominance() {
        let mat = arrow_matrix(10).expect("Failed to create arrow matrix");
        assert_diagonally_dominant(&mat, "Arrow");
    }

    #[test]
    fn test_arrow_size_1() {
        let mat = arrow_matrix(1).expect("Failed to create 1x1 arrow");
        assert_eq!(mat.nrows(), 1);
        assert_eq!(mat.nnz(), 1);
        assert!((mat.get_or_zero(0, 0) - 1.0).abs() < EXACT_TOL);
    }

    #[test]
    fn test_arrow_zero() {
        assert_eq!(arrow_matrix(0).err(), zero_dim("n"));
    }

    // ---- random_spd ---------------------------------------------------------------------

    #[test]
    fn test_random_spd_basic() {
        let mat = random_spd(10, 0.3).expect("Failed to create random SPD");
        assert_eq!(mat.nrows(), 10);
        assert_eq!(mat.ncols(), 10);
        assert!(mat.nnz() >= 10, "Should have at least n entries (diagonal)");
    }

    #[test]
    fn test_random_spd_symmetry() {
        let mat = random_spd(20, 0.2).expect("Failed to create random SPD");
        assert_symmetric(&mat, 1e-10, "Random SPD");
    }

    #[test]
    fn test_random_spd_positive_diagonal() {
        let mat = random_spd(15, 0.4).expect("Failed to create random SPD");

        for i in 0..mat.nrows() {
            let diag = mat.get_or_zero(i, i);
            assert!(
                diag > 0.0,
                "Random SPD diagonal at {i} should be positive, got {diag}"
            );
        }
    }

    #[test]
    fn test_random_spd_diagonal_dominance() {
        // Each accepted pair adds its weight to both diagonals and to both mirrored
        // off-diagonals, on top of the n * I shift.
        let mat = random_spd(15, 0.4).expect("Failed to create random SPD");
        assert_diagonally_dominant(&mat, "Random SPD");
    }

    #[test]
    fn test_random_spd_is_deterministic() {
        let first = random_spd(12, 0.5).expect("first random SPD");
        let second = random_spd(12, 0.5).expect("second random SPD");
        assert_same_entries(&first, &second, "random_spd(12, 0.5)");
    }

    #[test]
    fn test_random_spd_zero_density() {
        let mat = random_spd(5, 0.0).expect("Failed to create diagonal SPD");
        assert_eq!(mat.nnz(), 5, "Zero density should yield diagonal matrix");
    }

    #[test]
    fn test_random_spd_invalid() {
        assert_eq!(random_spd(0, 0.5).err(), zero_dim("n"));
        assert_eq!(
            random_spd(5, -0.1).err(),
            Some(TestMatrixError::InvalidDensity {
                density: "-0.1".to_string()
            })
        );
        assert_eq!(
            random_spd(5, 1.1).err(),
            Some(TestMatrixError::InvalidDensity {
                density: "1.1".to_string()
            })
        );
        assert!(random_spd(5, f64::NAN).is_err(), "NaN density must be rejected");
    }

    // ---- poisson_1d ---------------------------------------------------------------------

    #[test]
    fn test_poisson_1d_basic() {
        let mat = poisson_1d(5).expect("Failed to create 1D Poisson");
        assert_eq!(mat.nrows(), 5);
        assert_eq!(mat.ncols(), 5);
        assert_eq!(mat.nnz(), 13);
    }

    #[test]
    fn test_poisson_1d_values() {
        let mat = poisson_1d(4).expect("Failed to create 1D Poisson");

        for i in 0..4 {
            assert!((mat.get_or_zero(i, i) - 2.0).abs() < EXACT_TOL);
        }
        for i in 0..3 {
            assert!((mat.get_or_zero(i, i + 1) - (-1.0)).abs() < EXACT_TOL);
            assert!((mat.get_or_zero(i + 1, i) - (-1.0)).abs() < EXACT_TOL);
        }
    }

    #[test]
    fn test_poisson_1d_symmetry() {
        let mat = poisson_1d(10).expect("Failed to create 1D Poisson");
        assert_symmetric(&mat, EXACT_TOL, "Poisson 1D");
    }

    #[test]
    fn test_poisson_1d_diagonal_dominance() {
        let mat = poisson_1d(10).expect("Failed to create 1D Poisson");
        assert_diagonally_dominant(&mat, "Poisson 1D");
    }

    #[test]
    fn test_poisson_1d_zero() {
        assert_eq!(poisson_1d(0).err(), zero_dim("n"));
    }

    #[test]
    fn test_poisson_1d_equals_tridiagonal() {
        let poisson = poisson_1d(10).expect("poisson_1d");
        let tri = tridiagonal(10, -1.0, 2.0, -1.0).expect("tridiagonal");
        assert_same_entries(&poisson, &tri, "poisson_1d vs tridiagonal");
    }

    // ---- structural CSR invariants --------------------------------------------------------

    #[test]
    fn test_generators_produce_valid_csr() {
        let matrices = [
            ("lap2d", laplacian_2d(4, 4).expect("lap2d")),
            ("lap3d", laplacian_3d(3, 3, 3).expect("lap3d")),
            ("tridiag", tridiagonal(10, -1.0, 2.0, -1.0).expect("tridiag")),
            ("diag", diagonal(10, 5.0).expect("diag")),
            ("arrow", arrow_matrix(10).expect("arrow")),
            ("rspd", random_spd(10, 0.3).expect("rspd")),
            ("poisson", poisson_1d(10).expect("poisson")),
        ];

        for (name, mat) in &matrices {
            let row_ptrs = mat.row_ptrs();

            assert_eq!(
                row_ptrs.len(),
                mat.nrows() + 1,
                "{name}: row_ptrs length must be nrows + 1"
            );
            assert_eq!(row_ptrs[0], 0, "{name}: row_ptrs must start at 0");
            assert!(
                row_ptrs.windows(2).all(|w| w[0] <= w[1]),
                "{name}: row pointers not monotonic"
            );
            assert_eq!(
                mat.nnz(),
                row_ptrs[mat.nrows()],
                "{name}: nnz mismatch with row_ptrs"
            );

            let mut stored = 0;
            for &col in mat.col_indices() {
                assert!(col < mat.ncols(), "{name}: column index {col} out of bounds");
                stored += 1;
            }
            assert_eq!(stored, mat.nnz(), "{name}: col_indices count differs from nnz");
        }
    }
}