Skip to main content

scirs2_sparse/
lib.rs

1#![allow(clippy::manual_div_ceil)]
2#![allow(clippy::needless_return)]
3#![allow(clippy::manual_ok_err)]
4#![allow(clippy::needless_range_loop)]
5#![allow(clippy::while_let_loop)]
6#![allow(clippy::vec_init_then_push)]
7#![allow(clippy::should_implement_trait)]
8#![allow(clippy::only_used_in_recursion)]
9#![allow(clippy::manual_slice_fill)]
10#![allow(dead_code)]
11//! # SciRS2 Sparse - Sparse Matrix Operations
12//!
13//! **scirs2-sparse** provides comprehensive sparse matrix formats and operations modeled after SciPy's
14//! `sparse` module, offering CSR, CSC, COO, DOK, LIL, DIA, BSR formats with efficient algorithms
15//! for large-scale sparse linear algebra, eigenvalue problems, and graph operations.
16//!
17//! ## 🎯 Key Features
18//!
//! - **SciPy Compatibility**: API modeled on the `scipy.sparse` classes, which eases porting SciPy code
//! - **Multiple Formats**: CSR, CSC, COO, DOK, LIL, DIA, BSR, banded, and symmetric CSR/COO variants, with easy conversion
21//! - **Efficient Operations**: Sparse matrix-vector/matrix multiplication
22//! - **Linear Solvers**: Direct (LU, Cholesky) and iterative (CG, GMRES) solvers
23//! - **Eigenvalue Solvers**: ARPACK-based sparse eigenvalue computation
24//! - **Array API**: Modern NumPy-compatible array interface (recommended)
25//!
26//! ## 📦 Module Overview
27//!
28//! | SciRS2 Format | SciPy Equivalent | Description |
29//! |---------------|------------------|-------------|
30//! | `CsrArray` | `scipy.sparse.csr_array` | Compressed Sparse Row (efficient row slicing) |
31//! | `CscArray` | `scipy.sparse.csc_array` | Compressed Sparse Column (efficient column slicing) |
32//! | `CooArray` | `scipy.sparse.coo_array` | Coordinate format (efficient construction) |
33//! | `DokArray` | `scipy.sparse.dok_array` | Dictionary of Keys (efficient element access) |
34//! | `LilArray` | `scipy.sparse.lil_array` | List of Lists (efficient incremental construction) |
35//! | `DiaArray` | `scipy.sparse.dia_array` | Diagonal format (efficient banded matrices) |
36//! | `BsrArray` | `scipy.sparse.bsr_array` | Block Sparse Row (efficient block operations) |
37//!
38//! ## 🚀 Quick Start
39//!
40//! ```toml
41//! [dependencies]
42//! scirs2-sparse = "0.2.0"
43//! ```
44//!
45//! ```rust
46//! use scirs2_sparse::csr_array::CsrArray;
47//!
48//! // Create sparse matrix from triplets (row, col, value)
49//! let rows = vec![0, 0, 1, 2, 2];
50//! let cols = vec![0, 2, 2, 0, 1];
51//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
52//! let sparse = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
53//! ```
54//!
55//! ## 🔒 Version: 0.2.0 (February 8, 2026)
56//!
57//! ## Matrix vs. Array API
58//!
//! This crate provides both a matrix-based API and an array-based API,
60//! following SciPy's transition to a more NumPy-compatible array interface.
61//!
62//! When using the array interface (e.g., `CsrArray`), please note that:
63//!
64//! - `*` performs element-wise multiplication, not matrix multiplication
65//! - Use `dot()` method for matrix multiplication
66//! - Operations like `sum` produce arrays, not matrices
67//! - Array-style slicing operations return scalars, 1D, or 2D arrays
68//!
69//! For new code, we recommend using the array interface, which is more consistent
70//! with the rest of the numerical ecosystem.
71//!
72//! ## Examples
73//!
74//! ### Matrix API (Legacy)
75//!
76//! ```
77//! use scirs2_sparse::csr::CsrMatrix;
78//!
79//! // Create a sparse matrix in CSR format
80//! let rows = vec![0, 0, 1, 2, 2];
81//! let cols = vec![0, 2, 2, 0, 1];
82//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
83//! let shape = (3, 3);
84//!
85//! let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");
86//! ```
87//!
88//! ### Array API (Recommended)
89//!
90//! ```
91//! use scirs2_sparse::csr_array::CsrArray;
92//!
93//! // Create a sparse array in CSR format
94//! let rows = vec![0, 0, 1, 2, 2];
95//! let cols = vec![0, 2, 2, 0, 1];
96//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
97//! let shape = (3, 3);
98//!
99//! // From triplets (COO-like construction)
100//! let array = CsrArray::from_triplets(&rows, &cols, &data, shape, false).expect("Operation failed");
101//!
102//! // Or directly from CSR components
103//! // let array = CsrArray::new(...);
104//! ```
105
106// Export error types
107pub mod error;
108pub use error::{SparseError, SparseResult};
109
110// Base trait for sparse arrays
111pub mod sparray;
112pub use sparray::{is_sparse, SparseArray, SparseSum};
113
114// Trait for symmetric sparse arrays
115pub mod sym_sparray;
116pub use sym_sparray::SymSparseArray;
117
118// No spatial module in sparse
119
120// Array API (recommended)
121pub mod csr_array;
122pub use csr_array::CsrArray;
123
124pub mod csc_array;
125pub use csc_array::CscArray;
126
127pub mod coo_array;
128pub use coo_array::CooArray;
129
130pub mod dok_array;
131pub use dok_array::DokArray;
132
133pub mod lil_array;
134pub use lil_array::LilArray;
135
136pub mod dia_array;
137pub use dia_array::DiaArray;
138
139pub mod bsr_array;
140pub use bsr_array::BsrArray;
141
142pub mod banded_array;
143pub use banded_array::BandedArray;
144
145// Symmetric array formats
146pub mod sym_csr;
147pub use sym_csr::{SymCsrArray, SymCsrMatrix};
148
149pub mod sym_coo;
150pub use sym_coo::{SymCooArray, SymCooMatrix};
151
152// Legacy matrix formats
153pub mod csr;
154pub use csr::CsrMatrix;
155
156pub mod csc;
157pub use csc::CscMatrix;
158
159pub mod coo;
160pub use coo::CooMatrix;
161
162pub mod dok;
163pub use dok::DokMatrix;
164
165pub mod lil;
166pub use lil::LilMatrix;
167
168pub mod dia;
169pub use dia::DiaMatrix;
170
171pub mod bsr;
172pub use bsr::BsrMatrix;
173
174pub mod banded;
175pub use banded::BandedMatrix;
176
177// Utility functions
178pub mod utils;
179
180// Linear algebra with sparse matrices
181pub mod linalg;
182// Re-export the main functions from the reorganized linalg module
183pub use linalg::{
184    // Functions from solvers
185    add,
186    // Functions from iterative
187    bicg,
188    bicgstab,
189    cg,
190    cholesky_decomposition,
191    // Enhanced operators
192    convolution_operator,
193    diag_matrix,
194    eigs,
195    eigsh,
196    enhanced_add,
197    enhanced_diagonal,
198    enhanced_scale,
199    enhanced_subtract,
200    expm,
201    // Functions from matfuncs
202    expm_multiply,
203    eye,
204    finite_difference_operator,
205    // GCROT solver
206    gcrot,
207    gmres,
208    incomplete_cholesky,
209    incomplete_lu,
210    inv,
211    lanczos,
212    // Decomposition functions
213    lu_decomposition,
214    matmul,
215    matrix_power,
216    multiply,
217    norm,
218    onenormest,
219    // Eigenvalue functions
220    power_iteration,
221    qr_decomposition,
222    // Specialized solvers (v0.2.0)
223    solve_arrow_matrix,
224    solve_banded_system,
225    solve_block_2x2,
226    solve_kronecker_system,
227    solve_saddle_point,
228    sparse_direct_solve,
229    sparse_lstsq,
230    spsolve,
231    svd_truncated,
232    // SVD functions
233    svds,
234    // TFQMR solver
235    tfqmr,
236    ArpackOptions,
237    // Interfaces
238    AsLinearOperator,
239    // Types from iterative
240    BiCGOptions,
241    BiCGSTABOptions,
242    BiCGSTABResult,
243    // Enhanced operator types
244    BoundaryCondition,
245    CGOptions,
246    CGSOptions,
247    CGSResult,
248    CholeskyResult,
249    ConvolutionMode,
250    ConvolutionOperator,
251    // Operator types
252    DiagonalOperator,
253    EigenResult,
254    EigenvalueMethod,
255    EnhancedDiagonalOperator,
256    EnhancedDifferenceOperator,
257    EnhancedOperatorOptions,
258    EnhancedScaledOperator,
259    EnhancedSumOperator,
260    FiniteDifferenceOperator,
261    GCROTOptions,
262    GCROTResult,
263    GMRESOptions,
264    ICOptions,
265    // Preconditioners
266    ILU0Preconditioner,
267    ILUOptions,
268    IdentityOperator,
269    IterationResult,
270    JacobiPreconditioner,
271    // Decomposition types
272    LUResult,
273    LanczosOptions,
274    LinearOperator,
275    // Eigenvalue types
276    PowerIterationOptions,
277    QRResult,
278    SSORPreconditioner,
279    // SVD types
280    SVDOptions,
281    SVDResult,
282    ScaledIdentityOperator,
283    TFQMROptions,
284    TFQMRResult,
285};
286
287// Format conversions
288pub mod convert;
289
290// Construction utilities
291pub mod construct;
292pub mod construct_sym;
293
294// Combining arrays
295pub mod combine;
296pub use combine::{block_diag, bmat, hstack, kron, kronsum, tril, triu, vstack};
297
298// Index dtype handling utilities
299pub mod index_dtype;
300pub use index_dtype::{can_cast_safely, get_index_dtype, safely_cast_index_arrays};
301
302// Optimized operations for symmetric sparse formats
303pub mod sym_ops;
304
305// Tensor-based sparse operations (v0.2.0)
306pub mod tensor_sparse;
307pub use sym_ops::{
308    sym_coo_matvec, sym_csr_matvec, sym_csr_quadratic_form, sym_csr_rank1_update, sym_csr_trace,
309};
310
311// Tensor operations (v0.2.0)
312pub use tensor_sparse::{khatri_rao_product, CPDecomposition, SparseTensor, TuckerDecomposition};
313
314// GPU-accelerated operations
315pub mod gpu;
316pub mod gpu_kernel_execution;
317pub mod gpu_ops;
318pub mod gpu_spmv_implementation;
319pub use gpu_kernel_execution::{
320    calculate_adaptive_workgroup_size, execute_spmv_kernel, execute_symmetric_spmv_kernel,
321    execute_triangular_solve_kernel, GpuKernelConfig, GpuMemoryManager as GpuKernelMemoryManager,
322    GpuPerformanceProfiler, MemoryStrategy,
323};
324pub use gpu_ops::{
325    gpu_sparse_matvec, gpu_sym_sparse_matvec, AdvancedGpuOps, GpuKernelScheduler, GpuMemoryManager,
326    GpuOptions, GpuProfiler, OptimizedGpuOps,
327};
328pub use gpu_spmv_implementation::GpuSpMV;
329
330// Memory-efficient algorithms and patterns
331pub mod memory_efficient;
332pub use memory_efficient::{
333    streaming_sparse_matvec, CacheAwareOps, MemoryPool, MemoryTracker, OutOfCoreProcessor,
334};
335
336// SIMD-accelerated operations
337pub mod simd_ops;
338pub use simd_ops::{
339    simd_csr_matvec, simd_sparse_elementwise, simd_sparse_linear_combination, simd_sparse_matmul,
340    simd_sparse_norm, simd_sparse_scale, simd_sparse_transpose, ElementwiseOp, SimdOptions,
341};
342
343// Parallel vector operations for iterative solvers
344pub mod parallel_vector_ops;
345pub use parallel_vector_ops::{
346    advanced_sparse_matvec_csr, parallel_axpy, parallel_dot, parallel_linear_combination,
347    parallel_norm2, parallel_sparse_matvec_csr, parallel_vector_add, parallel_vector_copy,
348    parallel_vector_scale, parallel_vector_sub, ParallelVectorOptions,
349};
350
351// Quantum-inspired sparse matrix operations (Advanced mode)
352pub mod quantum_inspired_sparse;
353pub use quantum_inspired_sparse::{
354    QuantumProcessorStats, QuantumSparseConfig, QuantumSparseProcessor, QuantumStrategy,
355};
356
357// Neural-adaptive sparse matrix operations (Advanced mode)
358pub mod neural_adaptive_sparse;
359pub use neural_adaptive_sparse::{
360    NeuralAdaptiveConfig, NeuralAdaptiveSparseProcessor, NeuralProcessorStats, OptimizationStrategy,
361};
362
363// Quantum-Neural hybrid optimization (Advanced mode)
364pub mod quantum_neural_hybrid;
365pub use quantum_neural_hybrid::{
366    HybridStrategy, QuantumNeuralConfig, QuantumNeuralHybridProcessor, QuantumNeuralHybridStats,
367};
368
369// Adaptive memory compression for advanced-large sparse matrices (Advanced mode)
370pub mod adaptive_memory_compression;
371pub use adaptive_memory_compression::{
372    AdaptiveCompressionConfig, AdaptiveMemoryCompressor, CompressedMatrix, CompressionAlgorithm,
373    MemoryStats,
374};
375
376// Real-time performance monitoring and adaptation (Advanced mode)
377pub mod realtime_performance_monitor;
378pub use realtime_performance_monitor::{
379    Alert, AlertSeverity, PerformanceMonitorConfig, PerformanceSample, ProcessorType,
380    RealTimePerformanceMonitor,
381};
382
383// Compressed sparse graph algorithms
384pub mod csgraph;
385pub use csgraph::{
386    all_pairs_shortest_path,
387    bellman_ford_single_source,
388    // Centrality measures (v0.2.0)
389    betweenness_centrality,
390    bfs_distances,
391    // Traversal algorithms
392    breadth_first_search,
393    closeness_centrality,
394    // Community detection (v0.2.0)
395    community_detection,
396    compute_laplacianmatrix,
397    connected_components,
398    degree_matrix,
399    depth_first_search,
400    dijkstra_single_source,
401    // Max flow algorithms (v0.2.0)
402    dinic,
403    edmonds_karp,
404    eigenvector_centrality,
405    floyd_warshall,
406    ford_fulkerson,
407    has_path,
408    is_connected,
409    is_laplacian,
410    is_spanning_tree,
411    // Minimum spanning trees
412    kruskal_mst,
413    label_propagation,
414    // Laplacian matrices
415    laplacian,
416    largest_component,
417    louvain_communities,
418    min_cut,
419    minimum_spanning_tree,
420    modularity,
421    num_edges,
422    num_vertices,
423    pagerank,
424    prim_mst,
425    reachable_vertices,
426    reconstruct_path,
427    // Graph algorithms
428    shortest_path,
429    // Shortest path algorithms
430    single_source_shortest_path,
431    spanning_tree_weight,
432    strongly_connected_components,
433    to_adjacency_list,
434    topological_sort,
435    traversegraph,
436    // Connected components
437    undirected_connected_components,
438    // Graph utilities
439    validate_graph,
440    weakly_connected_components,
441    LaplacianType,
442    MSTAlgorithm,
443    // Max flow types (v0.2.0)
444    MaxFlowResult,
445    // Enums and types
446    ShortestPathMethod,
447    TraversalOrder,
448};
449
/// Marker type named after SciPy's `scipy.sparse.SparseEfficiencyWarning`.
///
/// In SciPy this warning category flags an operation that is valid but
/// inefficient for the chosen sparse format. Here it is a plain unit struct
/// that carries no data. It exists so that code ported from SciPy has a
/// matching name.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SparseEfficiencyWarning;

/// Marker type named after SciPy's `scipy.sparse.SparseWarning`, the base
/// warning category of `scipy.sparse`.
///
/// Like `SparseEfficiencyWarning`, this is a unit struct that carries no data.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SparseWarning;
453
454/// Check if an object is a sparse array
455#[allow(dead_code)]
456pub fn is_sparse_array<T>(obj: &dyn SparseArray<T>) -> bool
457where
458    T: scirs2_core::SparseElement + std::ops::Div<Output = T> + PartialOrd + 'static,
459{
460    sparray::is_sparse(obj)
461}
462
463/// Check if an object is a symmetric sparse array
464#[allow(dead_code)]
465pub fn is_sym_sparse_array<T>(obj: &dyn SymSparseArray<T>) -> bool
466where
467    T: scirs2_core::SparseElement
468        + std::ops::Div<Output = T>
469        + scirs2_core::Float
470        + PartialOrd
471        + 'static,
472{
473    obj.is_symmetric()
474}
475
476/// Check if an object is a sparse matrix (legacy API)
477#[allow(dead_code)]
478pub fn is_sparse_matrix(obj: &dyn std::any::Any) -> bool {
479    obj.is::<CsrMatrix<f64>>()
480        || obj.is::<CscMatrix<f64>>()
481        || obj.is::<CooMatrix<f64>>()
482        || obj.is::<DokMatrix<f64>>()
483        || obj.is::<LilMatrix<f64>>()
484        || obj.is::<DiaMatrix<f64>>()
485        || obj.is::<BsrMatrix<f64>>()
486        || obj.is::<SymCsrMatrix<f64>>()
487        || obj.is::<SymCooMatrix<f64>>()
488        || obj.is::<CsrMatrix<f32>>()
489        || obj.is::<CscMatrix<f32>>()
490        || obj.is::<CooMatrix<f32>>()
491        || obj.is::<DokMatrix<f32>>()
492        || obj.is::<LilMatrix<f32>>()
493        || obj.is::<DiaMatrix<f32>>()
494        || obj.is::<BsrMatrix<f32>>()
495        || obj.is::<SymCsrMatrix<f32>>()
496        || obj.is::<SymCooMatrix<f32>>()
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_csr_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let csr = CsrArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false)
            .expect("Operation failed");

        assert_eq!(csr.shape(), (3, 3));
        assert_eq!(csr.nnz(), 5);
        assert!(is_sparse_array(&csr));
    }

    #[test]
    fn test_coo_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let coo = CooArray::from_triplets(&row_idx, &col_idx, &values, (3, 3), false)
            .expect("Operation failed");

        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        assert!(is_sparse_array(&coo));
    }

    #[test]
    fn test_dok_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let dok =
            DokArray::from_triplets(&row_idx, &col_idx, &values, (3, 3)).expect("Operation failed");
        assert_eq!(dok.shape(), (3, 3));
        assert_eq!(dok.nnz(), 5);
        assert!(is_sparse_array(&dok));

        // Element-wise set/get on a fresh 2x2 array.
        let mut small = DokArray::<f64>::new((2, 2));
        small.set(0, 0, 1.0).expect("Operation failed");
        small.set(1, 1, 2.0).expect("Operation failed");
        for (i, j, expected) in [(0, 0, 1.0), (0, 1, 0.0), (1, 1, 2.0)] {
            assert_eq!(small.get(i, j), expected);
        }

        // Overwriting an entry with zero must remove it from the stored count.
        small.set(0, 0, 0.0).expect("Operation failed");
        assert_eq!(small.nnz(), 1);
    }

    #[test]
    fn test_lil_array() {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let lil =
            LilArray::from_triplets(&row_idx, &col_idx, &values, (3, 3)).expect("Operation failed");
        assert_eq!(lil.shape(), (3, 3));
        assert_eq!(lil.nnz(), 5);
        assert!(is_sparse_array(&lil));

        // Element-wise set/get on a fresh 2x2 array.
        let mut small = LilArray::<f64>::new((2, 2));
        small.set(0, 0, 1.0).expect("Operation failed");
        small.set(1, 1, 2.0).expect("Operation failed");
        for (i, j, expected) in [(0, 0, 1.0), (0, 1, 0.0), (1, 1, 2.0)] {
            assert_eq!(small.get(i, j), expected);
        }

        // After these inserts the array should report sorted indices.
        assert!(small.has_sorted_indices());

        // Overwriting an entry with zero must remove it from the stored count.
        small.set(0, 0, 0.0).expect("Operation failed");
        assert_eq!(small.nnz(), 1);
    }

    #[test]
    fn test_dia_array() {
        use scirs2_core::ndarray::Array1;

        let shape = (3, 3);

        // 3x3 matrix given as the main diagonal (offset 0) plus the first upper
        // diagonal (offset 1). The 0.0 in the upper band is not counted: nnz is 5.
        let bands = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 0.0]),
        ];
        let dia = DiaArray::new(bands, vec![0, 1], shape).expect("Operation failed");

        assert_eq!(dia.shape(), (3, 3));
        assert_eq!(dia.nnz(), 5); // 3 on the main diagonal + 2 on the upper diagonal
        assert!(is_sparse_array(&dia));

        // (row, col, value) expectations; the last entry lies outside both bands.
        let expected = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (0, 1, 4.0),
            (1, 2, 5.0),
            (0, 2, 0.0),
        ];
        for &(i, j, v) in &expected {
            assert_eq!(dia.get(i, j), v);
        }

        // Building the same matrix from triplets must give the same stored values.
        let trip_rows = vec![0, 0, 1, 1, 2];
        let trip_cols = vec![0, 1, 1, 2, 2];
        let trip_vals = vec![1.0, 4.0, 2.0, 5.0, 3.0];
        let rebuilt = DiaArray::from_triplets(&trip_rows, &trip_cols, &trip_vals, shape)
            .expect("Operation failed");
        for &(i, j, v) in &expected[..5] {
            assert_eq!(rebuilt.get(i, j), v);
        }

        // Conversion to CSR keeps the entry count and values.
        let csr = dia.to_csr().expect("Operation failed");
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), 1.0);
        assert_eq!(csr.get(0, 1), 4.0);
    }

    #[test]
    fn test_format_conversions() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (n_rows, n_cols) = (3, 3);

        let coo = CooArray::from_triplets(&rows, &cols, &vals, (n_rows, n_cols), false)
            .expect("Operation failed");
        let csr = coo.to_csr().expect("Operation failed");

        // Dense renderings of the COO source and its CSR conversion must agree.
        let coo_dense = coo.to_array();
        let csr_dense = csr.to_array();
        for i in 0..n_rows {
            for j in 0..n_cols {
                assert_relative_eq!(coo_dense[[i, j]], csr_dense[[i, j]]);
            }
        }
    }

    #[test]
    fn test_dot_product() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (n_rows, n_cols) = (3, 3);

        // The same matrix in two formats; squaring each must give equal results.
        let coo = CooArray::from_triplets(&rows, &cols, &vals, (n_rows, n_cols), false)
            .expect("Operation failed");
        let csr = CsrArray::from_triplets(&rows, &cols, &vals, (n_rows, n_cols), false)
            .expect("Operation failed");

        let coo_squared = coo.dot(&coo).expect("Operation failed");
        let csr_squared = csr.dot(&csr).expect("Operation failed");

        let coo_dense = coo_squared.to_array();
        let csr_dense = csr_squared.to_array();
        for i in 0..n_rows {
            for j in 0..n_cols {
                assert_relative_eq!(coo_dense[[i, j]], csr_dense[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_sym_csr_array() {
        // CSR arrays (data, indptr, indices) for a 3x3 symmetric matrix. Every
        // stored entry lies on or below the diagonal.
        // NOTE(review): row 2 lists column 2 twice (3.0 and 1.0) and stores an
        // explicit 0.0 at (2, 0); the expected nnz of 10 below depends on how
        // these are treated. Confirm this fixture is intentional.
        let stored = SymCsrMatrix::new(
            vec![2.0, 1.0, 2.0, 3.0, 0.0, 3.0, 1.0],
            vec![0, 1, 3, 7],
            vec![0, 0, 1, 2, 0, 1, 2],
            (3, 3),
        )
        .expect("Operation failed");
        let sym = SymCsrArray::new(stored);

        assert_eq!(sym.shape(), (3, 3));
        assert!(is_sym_sparse_array(&sym));

        // Entries stored in one triangle must be readable from both sides.
        for (i, j, v) in [(0, 0, 2.0), (0, 1, 1.0), (1, 0, 1.0), (1, 2, 3.0), (2, 1, 3.0)] {
            assert_eq!(SparseArray::get(&sym, i, j), v);
        }

        // Expanding to a general CSR array materialises both triangles.
        let full = SymSparseArray::to_csr(&sym).expect("Operation failed");
        assert_eq!(full.nnz(), 10);
    }

    #[test]
    fn test_sym_coo_array() {
        // Lower triangle of a 3x3 symmetric matrix in COO form: (data, rows, cols).
        let stored = SymCooMatrix::new(
            vec![2.0, 1.0, 2.0, 3.0, 1.0],
            vec![0, 1, 1, 2, 2],
            vec![0, 0, 1, 1, 2],
            (3, 3),
        )
        .expect("Operation failed");
        let sym = SymCooArray::new(stored);

        assert_eq!(sym.shape(), (3, 3));
        assert!(is_sym_sparse_array(&sym));

        // Entries stored in one triangle must be readable from both sides.
        for (i, j, v) in [(0, 0, 2.0), (0, 1, 1.0), (1, 0, 1.0), (1, 2, 3.0), (2, 1, 3.0)] {
            assert_eq!(SparseArray::get(&sym, i, j), v);
        }

        // Deliberately asymmetric triplets. With enforce_symmetric = true the
        // mismatched (0, 1)/(1, 0) pair is expected to be averaged.
        let rows2 = vec![0, 0, 1, 1, 2, 1, 0];
        let cols2 = vec![0, 1, 1, 2, 2, 0, 2];
        let data2 = vec![2.0, 1.5, 2.0, 3.5, 1.0, 0.5, 0.0];
        let enforced = SymCooArray::from_triplets(&rows2, &cols2, &data2, (3, 3), true)
            .expect("Operation failed");

        assert_eq!(SparseArray::get(&enforced, 0, 1), 1.0); // mean of 1.5 and 0.5
        assert_eq!(SparseArray::get(&enforced, 1, 0), 1.0); // mirrored entry
        assert_eq!(SparseArray::get(&enforced, 0, 2), 0.0); // explicit zero input
    }

    #[test]
    fn test_construct_sym_utils() {
        // Identity matrix.
        let eye = construct_sym::eye_sym_array::<f64>(3, "csr").expect("Operation failed");
        assert_eq!(eye.shape(), (3, 3));
        for (i, j, v) in [(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (0, 1, 0.0)] {
            assert_eq!(SparseArray::get(&*eye, i, j), v);
        }

        // Tridiagonal matrix, built in COO format (the CSR path had issues here).
        let main_diag = vec![2.0, 2.0, 2.0];
        let off_diag = vec![1.0, 1.0];
        let tri = construct_sym::tridiagonal_sym_array(&main_diag, &off_diag, "coo")
            .expect("Operation failed");
        assert_eq!(tri.shape(), (3, 3));
        for (i, j, v) in [
            (0, 0, 2.0), // main diagonal
            (1, 1, 2.0),
            (2, 2, 2.0),
            (0, 1, 1.0), // off-diagonal
            (1, 0, 1.0), // mirrored entry
            (1, 2, 1.0),
            (0, 2, 0.0), // outside the band
        ] {
            assert_eq!(SparseArray::get(&*tri, i, j), v);
        }

        // Banded matrix with two off-diagonals.
        let diagonals = vec![
            vec![2.0, 2.0, 2.0, 2.0, 2.0], // main diagonal
            vec![1.0, 1.0, 1.0, 1.0],      // first off-diagonal
            vec![0.5, 0.5, 0.5],           // second off-diagonal
        ];
        let band = construct_sym::banded_sym_array(&diagonals, 5, "csr").expect("Operation failed");
        assert_eq!(band.shape(), (5, 5));
        for (i, j, v) in [(0, 0, 2.0), (0, 1, 1.0), (0, 2, 0.5), (2, 0, 0.5)] {
            assert_eq!(SparseArray::get(&*band, i, j), v);
        }
    }

    #[test]
    fn test_sym_conversions() {
        // Lower triangular part only of a 3x3 symmetric matrix.
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let rows = vec![0, 1, 1, 2, 2];
        let cols = vec![0, 0, 1, 1, 2];

        let sym_coo = SymCooArray::from_triplets(&rows, &cols, &data, (3, 3), true)
            .expect("Operation failed");
        let sym_csr = sym_coo.to_sym_csr().expect("Operation failed");

        // Every (row, col) cell of the 3x3 matrix, in row-major order.
        let cells: Vec<(usize, usize)> = (0..3).flat_map(|i| (0..3).map(move |j| (i, j))).collect();

        // The symmetric COO -> symmetric CSR conversion must preserve every value.
        for &(i, j) in &cells {
            assert_eq!(
                SparseArray::get(&sym_coo, i, j),
                SparseArray::get(&sym_csr, i, j)
            );
        }

        // Expanding to general formats materialises both triangles:
        // 5 stored entries + 2 mirrored off-diagonal entries = 7.
        let csr = SymSparseArray::to_csr(&sym_coo).expect("Operation failed");
        let coo = SymSparseArray::to_coo(&sym_csr).expect("Operation failed");
        assert_eq!(csr.nnz(), 7);
        assert_eq!(coo.nnz(), 7);

        for &(i, j) in &cells {
            assert_eq!(SparseArray::get(&csr, i, j), SparseArray::get(&coo, i, j));
            assert_eq!(
                SparseArray::get(&csr, i, j),
                SparseArray::get(&sym_csr, i, j)
            );
        }
    }
}