Skip to main content

oxiblas_sparse/
test_matrices.rs

1//! Standard test matrix generators for sparse operations.
2//!
//! Provides well-known test matrices commonly used in numerical linear algebra
//! benchmarks and validation. All generators return
//! `Result<CsrMatrix<f64>, TestMatrixError>`.
5//!
6//! # Available Generators
7//!
8//! - [`laplacian_2d`] - 2D Laplacian with 5-point stencil
9//! - [`laplacian_3d`] - 3D Laplacian with 7-point stencil
10//! - [`tridiagonal`] - General tridiagonal matrix
11//! - [`diagonal`] - Diagonal matrix
12//! - [`arrow_matrix`] - Arrow/arrowhead matrix
13//! - [`random_spd`] - Random symmetric positive definite matrix
14//! - [`poisson_1d`] - 1D Poisson matrix
15//!
16//! # Properties
17//!
18//! Most generators produce matrices with known mathematical properties
19//! (symmetry, positive definiteness, specific eigenvalue distributions)
20//! making them suitable for validating solvers and preconditioners.
21
22use crate::CooMatrixBuilder;
23use crate::csr::CsrMatrix;
24
25/// Error type for test matrix generation.
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub enum TestMatrixError {
28    /// Invalid dimensions (zero or too large).
29    InvalidDimension {
30        /// Description of the invalid parameter.
31        param: String,
32        /// The invalid value.
33        value: usize,
34    },
35    /// Density out of valid range [0, 1].
36    InvalidDensity {
37        /// The invalid density value.
38        density: String,
39    },
40    /// Matrix construction failed internally.
41    ConstructionError {
42        /// Description of the failure.
43        description: String,
44    },
45}
46
47impl core::fmt::Display for TestMatrixError {
48    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
49        match self {
50            Self::InvalidDimension { param, value } => {
51                write!(f, "Invalid dimension for {param}: {value}")
52            }
53            Self::InvalidDensity { density } => {
54                write!(f, "Invalid density: {density} (must be in [0, 1])")
55            }
56            Self::ConstructionError { description } => {
57                write!(f, "Matrix construction error: {description}")
58            }
59        }
60    }
61}
62
63impl std::error::Error for TestMatrixError {}
64
65/// Creates a 2D Laplacian matrix using the 5-point stencil.
66///
67/// For an `nx` x `ny` grid, produces an `(nx*ny)` x `(nx*ny)` sparse matrix
68/// representing the discrete Laplacian operator with Dirichlet boundary conditions.
69///
70/// The stencil at interior point (i,j) is:
71/// ```text
72///       -1
73///   -1   4  -1
74///       -1
75/// ```
76///
77/// # Properties
78/// - Symmetric positive definite (SPD)
79/// - Eigenvalues: 4 - 2*cos(pi*k/(nx+1)) - 2*cos(pi*l/(ny+1)) for k=1..nx, l=1..ny
80/// - Condition number grows as O(max(nx,ny)^2)
81/// - nnz = 5*nx*ny - 2*(nx+ny)
82///
83/// # Errors
84/// Returns error if `nx` or `ny` is zero.
85pub fn laplacian_2d(nx: usize, ny: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
86    if nx == 0 {
87        return Err(TestMatrixError::InvalidDimension {
88            param: "nx".to_string(),
89            value: 0,
90        });
91    }
92    if ny == 0 {
93        return Err(TestMatrixError::InvalidDimension {
94            param: "ny".to_string(),
95            value: 0,
96        });
97    }
98
99    let n = nx * ny;
100    let mut builder = CooMatrixBuilder::new(n, n);
101
102    for j in 0..ny {
103        for i in 0..nx {
104            let idx = j * nx + i;
105
106            // Diagonal: 4
107            builder.add(idx, idx, 4.0);
108
109            // Left neighbor: -1
110            if i > 0 {
111                builder.add(idx, idx - 1, -1.0);
112            }
113
114            // Right neighbor: -1
115            if i < nx - 1 {
116                builder.add(idx, idx + 1, -1.0);
117            }
118
119            // Bottom neighbor: -1
120            if j > 0 {
121                builder.add(idx, idx - nx, -1.0);
122            }
123
124            // Top neighbor: -1
125            if j < ny - 1 {
126                builder.add(idx, idx + nx, -1.0);
127            }
128        }
129    }
130
131    Ok(builder.build().to_csr())
132}
133
134/// Creates a 3D Laplacian matrix using the 7-point stencil.
135///
136/// For an `nx` x `ny` x `nz` grid, produces an `(nx*ny*nz)` x `(nx*ny*nz)` sparse matrix
137/// representing the discrete 3D Laplacian operator with Dirichlet boundary conditions.
138///
139/// The stencil at interior point (i,j,k) is:
140/// ```text
141/// z-1: -1    y-1: -1    center: -1  6  -1    y+1: -1    z+1: -1
142///                                x-1     x+1
143/// ```
144///
145/// # Properties
146/// - Symmetric positive definite (SPD)
147/// - Diagonal value: 6, off-diagonal values: -1
148/// - nnz = 7*nx*ny*nz - 2*(nx*ny + ny*nz + nx*nz)
149///
150/// # Errors
151/// Returns error if `nx`, `ny`, or `nz` is zero.
152pub fn laplacian_3d(nx: usize, ny: usize, nz: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
153    if nx == 0 {
154        return Err(TestMatrixError::InvalidDimension {
155            param: "nx".to_string(),
156            value: 0,
157        });
158    }
159    if ny == 0 {
160        return Err(TestMatrixError::InvalidDimension {
161            param: "ny".to_string(),
162            value: 0,
163        });
164    }
165    if nz == 0 {
166        return Err(TestMatrixError::InvalidDimension {
167            param: "nz".to_string(),
168            value: 0,
169        });
170    }
171
172    let n = nx * ny * nz;
173    let nxy = nx * ny;
174    let mut builder = CooMatrixBuilder::new(n, n);
175
176    for k in 0..nz {
177        for j in 0..ny {
178            for i in 0..nx {
179                let idx = k * nxy + j * nx + i;
180
181                // Diagonal: 6
182                builder.add(idx, idx, 6.0);
183
184                // x-direction neighbors
185                if i > 0 {
186                    builder.add(idx, idx - 1, -1.0);
187                }
188                if i < nx - 1 {
189                    builder.add(idx, idx + 1, -1.0);
190                }
191
192                // y-direction neighbors
193                if j > 0 {
194                    builder.add(idx, idx - nx, -1.0);
195                }
196                if j < ny - 1 {
197                    builder.add(idx, idx + nx, -1.0);
198                }
199
200                // z-direction neighbors
201                if k > 0 {
202                    builder.add(idx, idx - nxy, -1.0);
203                }
204                if k < nz - 1 {
205                    builder.add(idx, idx + nxy, -1.0);
206                }
207            }
208        }
209    }
210
211    Ok(builder.build().to_csr())
212}
213
214/// Creates a tridiagonal matrix with specified sub-diagonal, diagonal, and super-diagonal values.
215///
216/// Produces an `n` x `n` matrix:
217/// ```text
218/// [ diag  sup   0    0  ... ]
219/// [ sub   diag  sup  0  ... ]
220/// [ 0     sub   diag sup ... ]
221/// [ ...                      ]
222/// ```
223///
224/// # Properties
225/// - Bandwidth = 1
226/// - SPD if `sub == sup` and `diag > 2*|sub|` (diagonal dominance)
227/// - nnz = 3n - 2 for n >= 2
228///
229/// # Errors
230/// Returns error if `n` is zero.
231pub fn tridiagonal(
232    n: usize,
233    sub: f64,
234    diag: f64,
235    sup: f64,
236) -> Result<CsrMatrix<f64>, TestMatrixError> {
237    if n == 0 {
238        return Err(TestMatrixError::InvalidDimension {
239            param: "n".to_string(),
240            value: 0,
241        });
242    }
243
244    let mut builder = CooMatrixBuilder::new(n, n);
245
246    for i in 0..n {
247        // Main diagonal
248        builder.add(i, i, diag);
249
250        // Sub-diagonal
251        if i > 0 {
252            builder.add(i, i - 1, sub);
253        }
254
255        // Super-diagonal
256        if i < n - 1 {
257            builder.add(i, i + 1, sup);
258        }
259    }
260
261    Ok(builder.build().to_csr())
262}
263
264/// Creates a diagonal matrix with a constant value on the main diagonal.
265///
266/// Produces an `n` x `n` matrix with `value` on all diagonal entries and zeros elsewhere.
267///
268/// # Properties
269/// - Symmetric
270/// - Positive definite if `value > 0`
271/// - All eigenvalues equal `value`
272/// - nnz = n
273/// - Condition number = 1
274///
275/// # Errors
276/// Returns error if `n` is zero.
277pub fn diagonal(n: usize, value: f64) -> Result<CsrMatrix<f64>, TestMatrixError> {
278    if n == 0 {
279        return Err(TestMatrixError::InvalidDimension {
280            param: "n".to_string(),
281            value: 0,
282        });
283    }
284
285    let mut builder = CooMatrixBuilder::new(n, n);
286
287    for i in 0..n {
288        builder.add(i, i, value);
289    }
290
291    Ok(builder.build().to_csr())
292}
293
294/// Creates an arrow (arrowhead) matrix.
295///
296/// The arrow matrix has a dense first row, dense first column, and a diagonal:
297/// ```text
298/// [ d0  a1  a2  a3 ... ]
299/// [ a1  d1  0   0  ... ]
300/// [ a2  0   d2  0  ... ]
301/// [ a3  0   0   d3 ... ]
302/// [ ...                 ]
303/// ```
304///
305/// where `d_i = n + 1 - i` (diagonal values) and `a_i = 1` (arrow entries).
306///
307/// # Properties
308/// - Symmetric
309/// - Positive definite (with chosen diagonal values)
310/// - Tests sparse solvers with dense row/column structure
311/// - nnz = 3n - 2 for n >= 2
312///
313/// # Errors
314/// Returns error if `n` is zero.
315pub fn arrow_matrix(n: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
316    if n == 0 {
317        return Err(TestMatrixError::InvalidDimension {
318            param: "n".to_string(),
319            value: 0,
320        });
321    }
322
323    if n == 1 {
324        let mut builder = CooMatrixBuilder::new(1, 1);
325        builder.add(0, 0, 1.0);
326        return Ok(builder.build().to_csr());
327    }
328
329    let mut builder = CooMatrixBuilder::new(n, n);
330
331    // First row and column (arrow entries)
332    // Diagonal d0 = n + 1 to ensure diagonal dominance
333    builder.add(0, 0, (n + 1) as f64);
334
335    for i in 1..n {
336        // First row: a_i = 1
337        builder.add(0, i, 1.0);
338        // First column: a_i = 1 (symmetric)
339        builder.add(i, 0, 1.0);
340        // Diagonal: d_i = n + 1 - i to ensure positive definiteness
341        // We need d_i > 1/d_0 * sum(...) for SPD, so use generous values
342        builder.add(i, i, (n + 1 - i) as f64);
343    }
344
345    Ok(builder.build().to_csr())
346}
347
348/// Creates a random symmetric positive definite (SPD) matrix with a given density.
349///
350/// Constructs A = L * L^T + n*I where L is a sparse lower triangular matrix
351/// with entries drawn from a simple deterministic pseudo-random pattern.
352/// The `n*I` shift guarantees positive definiteness.
353///
354/// # Arguments
355/// - `n` - Matrix dimension
356/// - `density` - Target density of non-zeros in [0.0, 1.0]
357///
358/// # Properties
359/// - Symmetric positive definite (guaranteed)
360/// - Approximately `density * n * n` non-zeros
361///
362/// # Errors
363/// Returns error if `n` is zero or `density` is not in [0, 1].
364pub fn random_spd(n: usize, density: f64) -> Result<CsrMatrix<f64>, TestMatrixError> {
365    if n == 0 {
366        return Err(TestMatrixError::InvalidDimension {
367            param: "n".to_string(),
368            value: 0,
369        });
370    }
371    if !(0.0..=1.0).contains(&density) {
372        return Err(TestMatrixError::InvalidDensity {
373            density: format!("{density}"),
374        });
375    }
376
377    // Use a simple deterministic hash-based pseudo-random generator
378    // to avoid needing the rand crate in production code.
379    // We build a lower triangular L, then compute A = L*L^T + n*I.
380    let mut builder = CooMatrixBuilder::new(n, n);
381
382    // Add n*I for guaranteed positive definiteness
383    for i in 0..n {
384        builder.add(i, i, n as f64);
385    }
386
387    if density <= 0.0 || n == 1 {
388        return Ok(builder.build().to_csr());
389    }
390
391    // Target number of lower-triangular non-zeros (excluding diagonal)
392    let max_lt_entries = n * (n - 1) / 2;
393    let target_lt_nnz = ((max_lt_entries as f64) * density).ceil() as usize;
394
395    // Deterministic pseudo-random entry generation using a simple hash
396    let mut seed: u64 = 0x517cc1b727220a95;
397    let mut generated = 0usize;
398    let max_attempts = max_lt_entries * 3; // prevent infinite loop on high density
399
400    for attempt in 0..max_attempts {
401        if generated >= target_lt_nnz {
402            break;
403        }
404
405        // Simple xorshift-style hash
406        seed ^= seed.wrapping_shl(13);
407        seed ^= seed.wrapping_shr(7);
408        seed ^= seed.wrapping_shl(17);
409        seed = seed.wrapping_add(attempt as u64);
410
411        let row = ((seed >> 16) as usize) % n;
412        let col = ((seed >> 32) as usize) % n;
413
414        if row > col {
415            // Deterministic value in range [0.1, 1.0]
416            let val = 0.1 + 0.9 * ((seed & 0xFF) as f64) / 255.0;
417            // Add both (row, col) and (col, row) for symmetry
418            builder.add(row, col, val);
419            builder.add(col, row, val);
420            // Also strengthen the diagonal to maintain SPD
421            builder.add(row, row, val);
422            builder.add(col, col, val);
423            generated += 1;
424        }
425    }
426
427    let mut coo = builder.build();
428    coo.sum_duplicates();
429    Ok(coo.to_csr())
430}
431
432/// Creates the 1D Poisson matrix (second-difference operator).
433///
434/// Produces an `n` x `n` tridiagonal matrix:
435/// ```text
436/// [ 2  -1   0   0  ... ]
437/// [-1   2  -1   0  ... ]
438/// [ 0  -1   2  -1  ... ]
439/// [ ...                 ]
440/// [ 0   0  ... -1   2  ]
441/// ```
442///
443/// This is the standard discretization of -u''(x) = f(x) on \[0,1\]
444/// with uniform mesh spacing h = 1/(n+1), scaled by h^2.
445///
446/// # Properties
447/// - Symmetric positive definite (SPD)
448/// - Eigenvalues: 2 - 2*cos(k*pi/(n+1)) for k=1..n
449/// - Condition number: O(n^2)
450/// - nnz = 3n - 2 for n >= 2
451///
452/// # Errors
453/// Returns error if `n` is zero.
454pub fn poisson_1d(n: usize) -> Result<CsrMatrix<f64>, TestMatrixError> {
455    tridiagonal(n, -1.0, 2.0, -1.0)
456}
457
458#[cfg(test)]
459mod tests {
460    use super::*;
461
462    // ========================================================================
463    // laplacian_2d tests
464    // ========================================================================
465
466    #[test]
467    fn test_laplacian_2d_basic() {
468        let mat = laplacian_2d(3, 3).expect("Failed to create 2D Laplacian");
469        assert_eq!(mat.nrows(), 9);
470        assert_eq!(mat.ncols(), 9);
471        // Interior points have 4 neighbors, boundary points have 2-3
472        // 5-point stencil on 3x3 grid:
473        // corners: 2 off-diag + 1 diag = 3 entries each (4 corners)
474        // edges: 3 off-diag + 1 diag = 4 entries each (4 edges for 3x3 minus corners = 4)
475        // center: 4 off-diag + 1 diag = 5 entries (1 center)
476        // Total: 4*3 + 4*4 + 1*5 = 12 + 16 + 5 = 33? No...
477        // Actually: nnz = 5*nx*ny - 2*(nx+ny) = 5*9 - 2*6 = 45 - 12 = 33
478        // But the formula includes the diagonal, so nnz = 5*3*3 - 2*(3+3) = 33
479        assert_eq!(mat.nnz(), 33);
480    }
481
482    #[test]
483    fn test_laplacian_2d_symmetry() {
484        let mat = laplacian_2d(4, 3).expect("Failed to create 2D Laplacian");
485        let n = mat.nrows();
486
487        // Check A[i,j] == A[j,i] for all entries
488        for (row, col, &val) in mat.iter() {
489            let transpose_val = mat.get_or_zero(col, row);
490            assert!(
491                (val - transpose_val).abs() < 1e-14,
492                "Laplacian 2D not symmetric at ({row}, {col}): {val} vs {transpose_val}"
493            );
494        }
495        // Also check dimensions
496        assert_eq!(n, 12);
497    }
498
499    #[test]
500    fn test_laplacian_2d_positive_definiteness() {
501        // Verify diagonal dominance (sufficient for SPD)
502        let mat = laplacian_2d(5, 5).expect("Failed to create 2D Laplacian");
503        let n = mat.nrows();
504
505        for i in 0..n {
506            let diag = mat.get_or_zero(i, i);
507            let mut off_diag_sum = 0.0;
508            for (col, val) in mat.row_iter(i) {
509                if col != i {
510                    off_diag_sum += val.abs();
511                }
512            }
513            assert!(
514                diag >= off_diag_sum,
515                "Row {i}: diagonal {diag} < off-diagonal sum {off_diag_sum}"
516            );
517        }
518    }
519
520    #[test]
521    fn test_laplacian_2d_1x1() {
522        let mat = laplacian_2d(1, 1).expect("Failed to create 1x1 Laplacian");
523        assert_eq!(mat.nrows(), 1);
524        assert_eq!(mat.ncols(), 1);
525        assert_eq!(mat.nnz(), 1);
526        assert!((mat.get_or_zero(0, 0) - 4.0).abs() < 1e-14);
527    }
528
529    #[test]
530    fn test_laplacian_2d_zero_dimension() {
531        assert!(laplacian_2d(0, 5).is_err());
532        assert!(laplacian_2d(5, 0).is_err());
533    }
534
535    // ========================================================================
536    // laplacian_3d tests
537    // ========================================================================
538
539    #[test]
540    fn test_laplacian_3d_basic() {
541        let mat = laplacian_3d(3, 3, 3).expect("Failed to create 3D Laplacian");
542        assert_eq!(mat.nrows(), 27);
543        assert_eq!(mat.ncols(), 27);
544        // nnz = 7*nx*ny*nz - 2*(nx*ny + ny*nz + nx*nz)
545        // = 7*27 - 2*(9 + 9 + 9) = 189 - 54 = 135
546        assert_eq!(mat.nnz(), 135);
547    }
548
549    #[test]
550    fn test_laplacian_3d_symmetry() {
551        let mat = laplacian_3d(3, 2, 2).expect("Failed to create 3D Laplacian");
552
553        for (row, col, &val) in mat.iter() {
554            let transpose_val = mat.get_or_zero(col, row);
555            assert!(
556                (val - transpose_val).abs() < 1e-14,
557                "Laplacian 3D not symmetric at ({row}, {col}): {val} vs {transpose_val}"
558            );
559        }
560    }
561
562    #[test]
563    fn test_laplacian_3d_diagonal_dominance() {
564        let mat = laplacian_3d(3, 3, 3).expect("Failed to create 3D Laplacian");
565        let n = mat.nrows();
566
567        for i in 0..n {
568            let diag = mat.get_or_zero(i, i);
569            let mut off_diag_sum = 0.0;
570            for (col, val) in mat.row_iter(i) {
571                if col != i {
572                    off_diag_sum += val.abs();
573                }
574            }
575            assert!(
576                diag >= off_diag_sum,
577                "Row {i}: diagonal {diag} < off-diagonal sum {off_diag_sum}"
578            );
579        }
580    }
581
582    #[test]
583    fn test_laplacian_3d_zero_dimension() {
584        assert!(laplacian_3d(0, 3, 3).is_err());
585        assert!(laplacian_3d(3, 0, 3).is_err());
586        assert!(laplacian_3d(3, 3, 0).is_err());
587    }
588
589    // ========================================================================
590    // tridiagonal tests
591    // ========================================================================
592
593    #[test]
594    fn test_tridiagonal_basic() {
595        let mat = tridiagonal(5, -1.0, 2.0, -1.0).expect("Failed to create tridiagonal");
596        assert_eq!(mat.nrows(), 5);
597        assert_eq!(mat.ncols(), 5);
598        // 3*5 - 2 = 13
599        assert_eq!(mat.nnz(), 13);
600    }
601
602    #[test]
603    fn test_tridiagonal_values() {
604        let mat = tridiagonal(4, -1.0, 3.0, -2.0).expect("Failed to create tridiagonal");
605
606        // Check diagonal
607        for i in 0..4 {
608            assert!((mat.get_or_zero(i, i) - 3.0).abs() < 1e-14);
609        }
610
611        // Check sub-diagonal
612        for i in 1..4 {
613            assert!((mat.get_or_zero(i, i - 1) - (-1.0)).abs() < 1e-14);
614        }
615
616        // Check super-diagonal
617        for i in 0..3 {
618            assert!((mat.get_or_zero(i, i + 1) - (-2.0)).abs() < 1e-14);
619        }
620
621        // Check zeros
622        assert!((mat.get_or_zero(0, 2)).abs() < 1e-14);
623        assert!((mat.get_or_zero(0, 3)).abs() < 1e-14);
624        assert!((mat.get_or_zero(3, 0)).abs() < 1e-14);
625    }
626
627    #[test]
628    fn test_tridiagonal_symmetric() {
629        let mat = tridiagonal(5, -1.0, 4.0, -1.0).expect("Failed to create symmetric tridiagonal");
630
631        for (row, col, &val) in mat.iter() {
632            let transpose_val = mat.get_or_zero(col, row);
633            assert!(
634                (val - transpose_val).abs() < 1e-14,
635                "Symmetric tridiagonal not symmetric at ({row}, {col})"
636            );
637        }
638    }
639
640    #[test]
641    fn test_tridiagonal_size_1() {
642        let mat = tridiagonal(1, -1.0, 5.0, -1.0).expect("Failed to create 1x1 tridiagonal");
643        assert_eq!(mat.nrows(), 1);
644        assert_eq!(mat.nnz(), 1);
645        assert!((mat.get_or_zero(0, 0) - 5.0).abs() < 1e-14);
646    }
647
648    #[test]
649    fn test_tridiagonal_zero() {
650        assert!(tridiagonal(0, -1.0, 2.0, -1.0).is_err());
651    }
652
653    // ========================================================================
654    // diagonal tests
655    // ========================================================================
656
657    #[test]
658    fn test_diagonal_basic() {
659        let mat = diagonal(5, 3.0).expect("Failed to create diagonal");
660        assert_eq!(mat.nrows(), 5);
661        assert_eq!(mat.ncols(), 5);
662        assert_eq!(mat.nnz(), 5);
663
664        for i in 0..5 {
665            assert!((mat.get_or_zero(i, i) - 3.0).abs() < 1e-14);
666        }
667    }
668
669    #[test]
670    fn test_diagonal_off_diagonal_zeros() {
671        let mat = diagonal(4, 2.0).expect("Failed to create diagonal");
672
673        for i in 0..4 {
674            for j in 0..4 {
675                if i != j {
676                    assert!(
677                        mat.get_or_zero(i, j).abs() < 1e-14,
678                        "Off-diagonal entry ({i}, {j}) is not zero"
679                    );
680                }
681            }
682        }
683    }
684
685    #[test]
686    fn test_diagonal_symmetry() {
687        let mat = diagonal(10, 7.5).expect("Failed to create diagonal");
688
689        for (row, col, &val) in mat.iter() {
690            let transpose_val = mat.get_or_zero(col, row);
691            assert!(
692                (val - transpose_val).abs() < 1e-14,
693                "Diagonal matrix not symmetric at ({row}, {col})"
694            );
695        }
696    }
697
698    #[test]
699    fn test_diagonal_zero() {
700        assert!(diagonal(0, 1.0).is_err());
701    }
702
703    // ========================================================================
704    // arrow_matrix tests
705    // ========================================================================
706
707    #[test]
708    fn test_arrow_basic() {
709        let mat = arrow_matrix(5).expect("Failed to create arrow matrix");
710        assert_eq!(mat.nrows(), 5);
711        assert_eq!(mat.ncols(), 5);
712        // nnz = n (diagonal) + 2*(n-1) (first row/col excluding diagonal) = 3n - 2
713        assert_eq!(mat.nnz(), 13);
714    }
715
716    #[test]
717    fn test_arrow_symmetry() {
718        let mat = arrow_matrix(8).expect("Failed to create arrow matrix");
719
720        for (row, col, &val) in mat.iter() {
721            let transpose_val = mat.get_or_zero(col, row);
722            assert!(
723                (val - transpose_val).abs() < 1e-14,
724                "Arrow matrix not symmetric at ({row}, {col}): {val} vs {transpose_val}"
725            );
726        }
727    }
728
729    #[test]
730    fn test_arrow_structure() {
731        let n = 5;
732        let mat = arrow_matrix(n).expect("Failed to create arrow matrix");
733
734        // First row should have entries in all columns
735        let mut first_row_cols = Vec::new();
736        for (col, _) in mat.row_iter(0) {
737            first_row_cols.push(col);
738        }
739        assert_eq!(first_row_cols.len(), n, "First row should have {n} entries");
740
741        // Other rows should have exactly 2 entries (diagonal + first column)
742        for i in 1..n {
743            let mut row_nnz = 0;
744            for _ in mat.row_iter(i) {
745                row_nnz += 1;
746            }
747            assert_eq!(
748                row_nnz, 2,
749                "Row {i} should have exactly 2 entries, got {row_nnz}"
750            );
751        }
752    }
753
754    #[test]
755    fn test_arrow_diagonal_dominance() {
756        let mat = arrow_matrix(10).expect("Failed to create arrow matrix");
757        let n = mat.nrows();
758
759        for i in 0..n {
760            let diag = mat.get_or_zero(i, i);
761            let mut off_diag_sum = 0.0;
762            for (col, val) in mat.row_iter(i) {
763                if col != i {
764                    off_diag_sum += val.abs();
765                }
766            }
767            assert!(
768                diag >= off_diag_sum,
769                "Arrow row {i}: diagonal {diag} < off-diagonal sum {off_diag_sum}"
770            );
771        }
772    }
773
774    #[test]
775    fn test_arrow_size_1() {
776        let mat = arrow_matrix(1).expect("Failed to create 1x1 arrow");
777        assert_eq!(mat.nrows(), 1);
778        assert_eq!(mat.nnz(), 1);
779    }
780
781    #[test]
782    fn test_arrow_zero() {
783        assert!(arrow_matrix(0).is_err());
784    }
785
786    // ========================================================================
787    // random_spd tests
788    // ========================================================================
789
790    #[test]
791    fn test_random_spd_basic() {
792        let mat = random_spd(10, 0.3).expect("Failed to create random SPD");
793        assert_eq!(mat.nrows(), 10);
794        assert_eq!(mat.ncols(), 10);
795        assert!(mat.nnz() >= 10, "Should have at least n entries (diagonal)");
796    }
797
798    #[test]
799    fn test_random_spd_symmetry() {
800        let mat = random_spd(20, 0.2).expect("Failed to create random SPD");
801
802        for (row, col, &val) in mat.iter() {
803            let transpose_val = mat.get_or_zero(col, row);
804            assert!(
805                (val - transpose_val).abs() < 1e-10,
806                "Random SPD not symmetric at ({row}, {col}): {val} vs {transpose_val}"
807            );
808        }
809    }
810
811    #[test]
812    fn test_random_spd_positive_diagonal() {
813        let mat = random_spd(15, 0.4).expect("Failed to create random SPD");
814
815        for i in 0..mat.nrows() {
816            let diag = mat.get_or_zero(i, i);
817            assert!(
818                diag > 0.0,
819                "Random SPD diagonal at {i} should be positive, got {diag}"
820            );
821        }
822    }
823
824    #[test]
825    fn test_random_spd_zero_density() {
826        let mat = random_spd(5, 0.0).expect("Failed to create diagonal SPD");
827        assert_eq!(mat.nnz(), 5, "Zero density should yield diagonal matrix");
828    }
829
830    #[test]
831    fn test_random_spd_invalid() {
832        assert!(random_spd(0, 0.5).is_err());
833        assert!(random_spd(5, -0.1).is_err());
834        assert!(random_spd(5, 1.1).is_err());
835    }
836
837    // ========================================================================
838    // poisson_1d tests
839    // ========================================================================
840
841    #[test]
842    fn test_poisson_1d_basic() {
843        let mat = poisson_1d(5).expect("Failed to create 1D Poisson");
844        assert_eq!(mat.nrows(), 5);
845        assert_eq!(mat.ncols(), 5);
846        assert_eq!(mat.nnz(), 13); // 3*5 - 2
847    }
848
849    #[test]
850    fn test_poisson_1d_values() {
851        let mat = poisson_1d(4).expect("Failed to create 1D Poisson");
852
853        // Check diagonal = 2
854        for i in 0..4 {
855            assert!((mat.get_or_zero(i, i) - 2.0).abs() < 1e-14);
856        }
857
858        // Check off-diagonal = -1
859        for i in 0..3 {
860            assert!((mat.get_or_zero(i, i + 1) - (-1.0)).abs() < 1e-14);
861            assert!((mat.get_or_zero(i + 1, i) - (-1.0)).abs() < 1e-14);
862        }
863    }
864
865    #[test]
866    fn test_poisson_1d_symmetry() {
867        let mat = poisson_1d(10).expect("Failed to create 1D Poisson");
868
869        for (row, col, &val) in mat.iter() {
870            let transpose_val = mat.get_or_zero(col, row);
871            assert!(
872                (val - transpose_val).abs() < 1e-14,
873                "Poisson 1D not symmetric at ({row}, {col})"
874            );
875        }
876    }
877
878    #[test]
879    fn test_poisson_1d_spd() {
880        // Check diagonal dominance (sufficient for SPD)
881        let mat = poisson_1d(10).expect("Failed to create 1D Poisson");
882
883        for i in 0..mat.nrows() {
884            let diag = mat.get_or_zero(i, i);
885            let mut off_diag_sum = 0.0;
886            for (col, val) in mat.row_iter(i) {
887                if col != i {
888                    off_diag_sum += val.abs();
889                }
890            }
891            assert!(
892                diag >= off_diag_sum,
893                "Poisson 1D row {i}: diagonal {diag} < off-diagonal sum {off_diag_sum}"
894            );
895        }
896    }
897
898    #[test]
899    fn test_poisson_1d_zero() {
900        assert!(poisson_1d(0).is_err());
901    }
902
903    // ========================================================================
904    // Cross-generator property tests
905    // ========================================================================
906
907    #[test]
908    fn test_poisson_1d_equals_tridiagonal() {
909        let poisson = poisson_1d(10).expect("poisson_1d");
910        let tri = tridiagonal(10, -1.0, 2.0, -1.0).expect("tridiagonal");
911
912        assert_eq!(poisson.nnz(), tri.nnz());
913        assert_eq!(poisson.nrows(), tri.nrows());
914
915        for i in 0..10 {
916            for j in 0..10 {
917                let pval = poisson.get_or_zero(i, j);
918                let tval = tri.get_or_zero(i, j);
919                assert!(
920                    (pval - tval).abs() < 1e-14,
921                    "Poisson != tridiag at ({i}, {j}): {pval} vs {tval}"
922                );
923            }
924        }
925    }
926
927    #[test]
928    fn test_laplacian_2d_1d_consistency() {
929        // laplacian_2d(n, 1) should produce a similar structure to a 1D Laplacian
930        // (tridiagonal with diagonal=4, off-diagonal=-1) since there's only 1 row in y
931        let mat = laplacian_2d(5, 1).expect("laplacian_2d(5,1)");
932        assert_eq!(mat.nrows(), 5);
933        assert_eq!(mat.ncols(), 5);
934        // Should be tridiagonal-like with diagonal=4
935        for i in 0..5 {
936            assert!((mat.get_or_zero(i, i) - 4.0).abs() < 1e-14);
937        }
938    }
939
940    #[test]
941    fn test_generators_produce_valid_csr() {
942        // All generators should produce valid CSR matrices
943        let matrices: Vec<CsrMatrix<f64>> = vec![
944            laplacian_2d(4, 4).expect("lap2d"),
945            laplacian_3d(3, 3, 3).expect("lap3d"),
946            tridiagonal(10, -1.0, 2.0, -1.0).expect("tridiag"),
947            diagonal(10, 5.0).expect("diag"),
948            arrow_matrix(10).expect("arrow"),
949            random_spd(10, 0.3).expect("rspd"),
950            poisson_1d(10).expect("poisson"),
951        ];
952
953        for mat in &matrices {
954            // Row pointers should be monotonically increasing
955            let row_ptrs = mat.row_ptrs();
956            for i in 1..row_ptrs.len() {
957                assert!(row_ptrs[i] >= row_ptrs[i - 1], "Row pointers not monotonic");
958            }
959
960            // Column indices should be in bounds
961            for &col in mat.col_indices() {
962                assert!(col < mat.ncols(), "Column index out of bounds");
963            }
964
965            // nnz should match row_ptrs
966            assert_eq!(
967                mat.nnz(),
968                row_ptrs[mat.nrows()],
969                "nnz mismatch with row_ptrs"
970            );
971        }
972    }
973}