Skip to main content

math_audio_solvers/sparse/
csr.rs

1//! Compressed Sparse Row (CSR) matrix format
2//!
3//! CSR format stores:
4//! - `values`: Non-zero entries in row-major order
5//! - `col_indices`: Column index for each value
6//! - `row_ptrs`: Index into values/col_indices where each row starts
7
8use crate::traits::{ComplexField, LinearOperator};
9use ndarray::{Array1, Array2};
10use num_traits::{FromPrimitive, Zero};
11use std::ops::Range;
12
13#[cfg(feature = "rayon")]
14use rayon::prelude::*;
15
16/// Compressed Sparse Row (CSR) matrix format
17///
18/// Memory-efficient storage for sparse matrices with O(nnz) space complexity.
19/// Matrix-vector products are O(nnz) instead of O(n²) for dense matrices.
20#[derive(Debug, Clone)]
21pub struct CsrMatrix<T: ComplexField> {
22    /// Number of rows
23    pub num_rows: usize,
24    /// Number of columns
25    pub num_cols: usize,
26    /// Non-zero values in row-major order
27    pub values: Vec<T>,
28    /// Column indices for each value
29    pub col_indices: Vec<usize>,
30    /// Row pointers: `row_ptrs[i]` is the start index in values/col_indices for row `i`
31    /// `row_ptrs[num_rows]` = nnz (total number of non-zeros)
32    pub row_ptrs: Vec<usize>,
33}
34
35impl<T: ComplexField> CsrMatrix<T> {
36    /// Create a new empty CSR matrix
37    pub fn new(num_rows: usize, num_cols: usize) -> Self {
38        Self {
39            num_rows,
40            num_cols,
41            values: Vec::new(),
42            col_indices: Vec::new(),
43            row_ptrs: vec![0; num_rows + 1],
44        }
45    }
46
47    /// Create a CSR matrix with pre-allocated capacity
48    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
49        Self {
50            num_rows,
51            num_cols,
52            values: Vec::with_capacity(nnz_estimate),
53            col_indices: Vec::with_capacity(nnz_estimate),
54            row_ptrs: vec![0; num_rows + 1],
55        }
56    }
57
58    /// Create a CSR matrix from raw components
59    ///
60    /// This is useful for converting between different CSR matrix representations
61    /// that share the same internal structure.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the input arrays are inconsistent:
66    /// - `row_ptrs` must have length `num_rows + 1`
67    /// - `col_indices` and `values` must have the same length
68    /// - `row_ptrs[num_rows]` must equal `values.len()`
69    pub fn from_raw_parts(
70        num_rows: usize,
71        num_cols: usize,
72        row_ptrs: Vec<usize>,
73        col_indices: Vec<usize>,
74        values: Vec<T>,
75    ) -> Self {
76        assert_eq!(
77            row_ptrs.len(),
78            num_rows + 1,
79            "row_ptrs must have num_rows + 1 elements"
80        );
81        assert_eq!(
82            col_indices.len(),
83            values.len(),
84            "col_indices and values must have the same length"
85        );
86        assert_eq!(
87            row_ptrs[num_rows],
88            values.len(),
89            "row_ptrs[num_rows] must equal nnz"
90        );
91
92        Self {
93            num_rows,
94            num_cols,
95            row_ptrs,
96            col_indices,
97            values,
98        }
99    }
100
101    /// Create a CSR matrix from a dense matrix
102    ///
103    /// Only stores entries with magnitude > threshold
104    pub fn from_dense(dense: &Array2<T>, threshold: T::Real) -> Self {
105        let num_rows = dense.nrows();
106        let num_cols = dense.ncols();
107
108        let mut values = Vec::new();
109        let mut col_indices = Vec::new();
110        let mut row_ptrs = vec![0usize; num_rows + 1];
111
112        for i in 0..num_rows {
113            for j in 0..num_cols {
114                let val = dense[[i, j]];
115                if val.norm() > threshold {
116                    values.push(val);
117                    col_indices.push(j);
118                }
119            }
120            row_ptrs[i + 1] = values.len();
121        }
122
123        Self {
124            num_rows,
125            num_cols,
126            values,
127            col_indices,
128            row_ptrs,
129        }
130    }
131
132    /// Create a CSR matrix from COO (Coordinate) format triplets
133    ///
134    /// Triplets are (row, col, value). Duplicate entries are summed.
135    pub fn from_triplets(
136        num_rows: usize,
137        num_cols: usize,
138        mut triplets: Vec<(usize, usize, T)>,
139    ) -> Self {
140        if triplets.is_empty() {
141            return Self::new(num_rows, num_cols);
142        }
143
144        // Sort by row, then by column
145        triplets.sort_by(|a, b| {
146            if a.0 != b.0 {
147                a.0.cmp(&b.0)
148            } else {
149                a.1.cmp(&b.1)
150            }
151        });
152
153        let mut values = Vec::with_capacity(triplets.len());
154        let mut col_indices = Vec::with_capacity(triplets.len());
155        let mut row_ptrs = vec![0usize; num_rows + 1];
156
157        let mut prev_row = usize::MAX;
158        let mut prev_col = usize::MAX;
159
160        for (row, col, val) in triplets {
161            if row == prev_row && col == prev_col {
162                // Same entry, accumulate
163                if let Some(last) = values.last_mut() {
164                    *last += val;
165                }
166            } else {
167                // Update row pointers for any rows we skipped
168                if row != prev_row {
169                    let start = if prev_row == usize::MAX {
170                        0
171                    } else {
172                        prev_row + 1
173                    };
174                    for item in row_ptrs.iter_mut().take(row + 1).skip(start) {
175                        *item = values.len();
176                    }
177                }
178
179                // New entry - push it
180                values.push(val);
181                col_indices.push(col);
182
183                prev_row = row;
184                prev_col = col;
185            }
186        }
187
188        // Fill remaining row pointers
189        let last_row = if prev_row == usize::MAX {
190            0
191        } else {
192            prev_row + 1
193        };
194        for item in row_ptrs.iter_mut().take(num_rows + 1).skip(last_row) {
195            *item = values.len();
196        }
197
198        Self {
199            num_rows,
200            num_cols,
201            values,
202            col_indices,
203            row_ptrs,
204        }
205    }
206
207    /// Number of non-zero entries
208    pub fn nnz(&self) -> usize {
209        self.values.len()
210    }
211
212    /// Sparsity ratio (fraction of non-zero entries)
213    pub fn sparsity(&self) -> f64 {
214        let total = self.num_rows * self.num_cols;
215        if total == 0 {
216            0.0
217        } else {
218            self.nnz() as f64 / total as f64
219        }
220    }
221
222    /// Get the range of indices in values/col_indices for a given row
223    pub fn row_range(&self, row: usize) -> Range<usize> {
224        self.row_ptrs[row]..self.row_ptrs[row + 1]
225    }
226
227    /// Get the (col, value) pairs for a row
228    pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, T)> + '_ {
229        let range = self.row_range(row);
230        self.col_indices[range.clone()]
231            .iter()
232            .copied()
233            .zip(self.values[range].iter().copied())
234    }
235
236    /// Matrix-vector product: y = A * x
237    ///
238    /// Uses parallel processing when the `rayon` feature is enabled and the
239    /// matrix is large enough to benefit from parallelization.
240    pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
241        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
242
243        // Use parallel version for large matrices when rayon is available
244        #[cfg(feature = "rayon")]
245        {
246            // Only parallelize if we have enough rows to benefit
247            if self.num_rows >= 246 {
248                return self.matvec_parallel(x);
249            }
250        }
251
252        self.matvec_sequential(x)
253    }
254
255    /// Sequential matrix-vector product
256    fn matvec_sequential(&self, x: &Array1<T>) -> Array1<T> {
257        let mut y = Array1::from_elem(self.num_rows, T::zero());
258
259        for i in 0..self.num_rows {
260            let mut sum = T::zero();
261            for idx in self.row_range(i) {
262                let j = self.col_indices[idx];
263                sum += self.values[idx] * x[j];
264            }
265            y[i] = sum;
266        }
267
268        y
269    }
270
271    /// Parallel matrix-vector product using rayon
272    #[cfg(feature = "rayon")]
273    fn matvec_parallel(&self, x: &Array1<T>) -> Array1<T>
274    where
275        T: Send + Sync,
276    {
277        let x_slice = x.as_slice().expect("Array should be contiguous");
278
279        let results: Vec<T> = (0..self.num_rows)
280            .into_par_iter()
281            .map(|i| {
282                let mut sum = T::zero();
283                for idx in self.row_range(i) {
284                    let j = self.col_indices[idx];
285                    sum += self.values[idx] * x_slice[j];
286                }
287                sum
288            })
289            .collect();
290
291        Array1::from_vec(results)
292    }
293
294    /// Matrix-vector product with accumulation: y += A * x
295    pub fn matvec_add(&self, x: &Array1<T>, y: &mut Array1<T>) {
296        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
297        assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
298
299        for i in 0..self.num_rows {
300            for idx in self.row_range(i) {
301                let j = self.col_indices[idx];
302                y[i] += self.values[idx] * x[j];
303            }
304        }
305    }
306
307    /// Transpose matrix-vector product: y = A^T * x
308    pub fn matvec_transpose(&self, x: &Array1<T>) -> Array1<T> {
309        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
310
311        let mut y = Array1::from_elem(self.num_cols, T::zero());
312
313        for i in 0..self.num_rows {
314            for idx in self.row_range(i) {
315                let j = self.col_indices[idx];
316                y[j] += self.values[idx] * x[i];
317            }
318        }
319
320        y
321    }
322
323    /// Hermitian (conjugate transpose) matrix-vector product: y = A^H * x
324    pub fn matvec_hermitian(&self, x: &Array1<T>) -> Array1<T> {
325        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
326
327        let mut y = Array1::from_elem(self.num_cols, T::zero());
328
329        for i in 0..self.num_rows {
330            for idx in self.row_range(i) {
331                let j = self.col_indices[idx];
332                y[j] += self.values[idx].conj() * x[i];
333            }
334        }
335
336        y
337    }
338
339    /// Get element at (i, j), returns 0 if not stored
340    pub fn get(&self, i: usize, j: usize) -> T {
341        for idx in self.row_range(i) {
342            if self.col_indices[idx] == j {
343                return self.values[idx];
344            }
345        }
346        T::zero()
347    }
348
349    /// Extract diagonal elements
350    pub fn diagonal(&self) -> Array1<T> {
351        let n = self.num_rows.min(self.num_cols);
352        let mut diag = Array1::from_elem(n, T::zero());
353
354        for i in 0..n {
355            for idx in self.row_range(i) {
356                if self.col_indices[idx] == i {
357                    diag[i] = self.values[idx];
358                    break;
359                }
360            }
361        }
362
363        diag
364    }
365
366    /// Scale all values by a scalar
367    pub fn scale(&mut self, scalar: T) {
368        for val in &mut self.values {
369            *val *= scalar;
370        }
371    }
372
373    /// Add a scalar to the diagonal
374    pub fn add_diagonal(&mut self, scalar: T) {
375        let n = self.num_rows.min(self.num_cols);
376
377        for i in 0..n {
378            for idx in self.row_range(i) {
379                if self.col_indices[idx] == i {
380                    self.values[idx] += scalar;
381                    break;
382                }
383            }
384        }
385    }
386
387    /// Create identity matrix in CSR format
388    pub fn identity(n: usize) -> Self {
389        Self {
390            num_rows: n,
391            num_cols: n,
392            values: vec![T::one(); n],
393            col_indices: (0..n).collect(),
394            row_ptrs: (0..=n).collect(),
395        }
396    }
397
398    /// Create diagonal matrix from vector
399    pub fn from_diagonal(diag: &Array1<T>) -> Self {
400        let n = diag.len();
401        Self {
402            num_rows: n,
403            num_cols: n,
404            values: diag.to_vec(),
405            col_indices: (0..n).collect(),
406            row_ptrs: (0..=n).collect(),
407        }
408    }
409
410    /// Convert to dense matrix (for debugging/small matrices)
411    pub fn to_dense(&self) -> Array2<T> {
412        let mut dense = Array2::from_elem((self.num_rows, self.num_cols), T::zero());
413
414        for i in 0..self.num_rows {
415            for idx in self.row_range(i) {
416                let j = self.col_indices[idx];
417                dense[[i, j]] = self.values[idx];
418            }
419        }
420
421        dense
422    }
423}
424
425impl<T: ComplexField> LinearOperator<T> for CsrMatrix<T> {
426    fn num_rows(&self) -> usize {
427        self.num_rows
428    }
429
430    fn num_cols(&self) -> usize {
431        self.num_cols
432    }
433
434    fn apply(&self, x: &Array1<T>) -> Array1<T> {
435        self.matvec(x)
436    }
437
438    fn apply_transpose(&self, x: &Array1<T>) -> Array1<T> {
439        self.matvec_transpose(x)
440    }
441
442    fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
443        self.matvec_hermitian(x)
444    }
445}
446
447/// Builder for constructing CSR matrices row by row
448pub struct CsrBuilder<T: ComplexField> {
449    num_rows: usize,
450    num_cols: usize,
451    values: Vec<T>,
452    col_indices: Vec<usize>,
453    row_ptrs: Vec<usize>,
454    current_row: usize,
455}
456
457impl<T: ComplexField> CsrBuilder<T> {
458    /// Create a new CSR builder
459    pub fn new(num_rows: usize, num_cols: usize) -> Self {
460        Self {
461            num_rows,
462            num_cols,
463            values: Vec::new(),
464            col_indices: Vec::new(),
465            row_ptrs: vec![0],
466            current_row: 0,
467        }
468    }
469
470    /// Create a new CSR builder with estimated non-zeros
471    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
472        Self {
473            num_rows,
474            num_cols,
475            values: Vec::with_capacity(nnz_estimate),
476            col_indices: Vec::with_capacity(nnz_estimate),
477            row_ptrs: Vec::with_capacity(num_rows + 1),
478            current_row: 0,
479        }
480    }
481
482    /// Add entries for the current row (must be added in column order)
483    pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, T)>) {
484        for (col, val) in entries {
485            if val.norm() > T::Real::zero() {
486                self.values.push(val);
487                self.col_indices.push(col);
488            }
489        }
490        self.row_ptrs.push(self.values.len());
491        self.current_row += 1;
492    }
493
494    /// Finish building and return the CSR matrix
495    pub fn finish(mut self) -> CsrMatrix<T> {
496        // Fill remaining rows if not all rows were added
497        while self.current_row < self.num_rows {
498            self.row_ptrs.push(self.values.len());
499            self.current_row += 1;
500        }
501
502        CsrMatrix {
503            num_rows: self.num_rows,
504            num_cols: self.num_cols,
505            values: self.values,
506            col_indices: self.col_indices,
507            row_ptrs: self.row_ptrs,
508        }
509    }
510}
511
512/// Blocked CSR format for hierarchical matrices
513///
514/// Stores the matrix as a collection of dense blocks at the leaf level
515/// of a hierarchical decomposition.
516#[derive(Debug, Clone)]
517pub struct BlockedCsr<T: ComplexField> {
518    /// Number of rows
519    pub num_rows: usize,
520    /// Number of columns
521    pub num_cols: usize,
522    /// Block size (rows and columns per block)
523    pub block_size: usize,
524    /// Number of block rows
525    pub num_block_rows: usize,
526    /// Number of block columns
527    pub num_block_cols: usize,
528    /// Dense blocks stored in CSR-like format
529    pub blocks: Vec<Array2<T>>,
530    /// Block column indices
531    pub block_col_indices: Vec<usize>,
532    /// Block row pointers
533    pub block_row_ptrs: Vec<usize>,
534}
535
536impl<T: ComplexField> BlockedCsr<T> {
537    /// Create a new blocked CSR matrix
538    pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
539        let num_block_rows = num_rows.div_ceil(block_size);
540        let num_block_cols = num_cols.div_ceil(block_size);
541
542        Self {
543            num_rows,
544            num_cols,
545            block_size,
546            num_block_rows,
547            num_block_cols,
548            blocks: Vec::new(),
549            block_col_indices: Vec::new(),
550            block_row_ptrs: vec![0; num_block_rows + 1],
551        }
552    }
553
554    /// Matrix-vector product using blocked structure
555    pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
556        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
557
558        let mut y = Array1::from_elem(self.num_rows, T::zero());
559
560        for block_i in 0..self.num_block_rows {
561            let row_start = block_i * self.block_size;
562            let row_end = (row_start + self.block_size).min(self.num_rows);
563            let local_rows = row_end - row_start;
564
565            for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
566                let block_j = self.block_col_indices[idx];
567                let block = &self.blocks[idx];
568
569                let col_start = block_j * self.block_size;
570                let col_end = (col_start + self.block_size).min(self.num_cols);
571                let local_cols = col_end - col_start;
572
573                // Extract local x
574                let x_local: Array1<T> = Array1::from_iter((col_start..col_end).map(|j| x[j]));
575
576                // Apply block
577                for i in 0..local_rows {
578                    let mut sum = T::zero();
579                    for j in 0..local_cols {
580                        sum += block[[i, j]] * x_local[j];
581                    }
582                    y[row_start + i] += sum;
583                }
584            }
585        }
586
587        y
588    }
589}
590
591/// Optimized sparse matrix-matrix multiplication: C = A * B
592///
593/// Uses a sorted accumulation approach instead of HashMap for better cache locality,
594/// providing 2-4x speedup for the AMG Galerkin product.
595///
596/// For CSR matrices A (m×k) and B (k×n), computes C (m×n).
597impl<T: ComplexField> CsrMatrix<T> {
598    /// Compute C = A * B using optimized approach (Gustavson's algorithm)
599    pub fn matmul(&self, other: &CsrMatrix<T>) -> CsrMatrix<T> {
600        assert_eq!(
601            self.num_cols, other.num_rows,
602            "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
603            self.num_cols, other.num_rows
604        );
605
606        let m = self.num_rows;
607        let n = other.num_cols;
608
609        if m == 0 || n == 0 || self.nnz() == 0 || other.nnz() == 0 {
610            return CsrMatrix::new(m, n);
611        }
612
613        let tol = T::Real::from_f64(1e-15).unwrap();
614
615        let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(self.nnz() * 2);
616
617        // Workspace for Gustavson's algorithm
618        let mut sparse_accumulator = vec![T::zero(); n];
619        let mut active_indices = Vec::with_capacity(n);
620        let mut occupied = vec![false; n];
621
622        for i in 0..m {
623            for (k, a_ik) in self.row_entries(i) {
624                for (j, b_kj) in other.row_entries(k) {
625                    if !occupied[j] {
626                        occupied[j] = true;
627                        active_indices.push(j);
628                    }
629                    sparse_accumulator[j] += a_ik * b_kj;
630                }
631            }
632
633            if active_indices.is_empty() {
634                continue;
635            }
636
637            // To keep C's rows sorted by column index, we sort active_indices
638            active_indices.sort_unstable();
639
640            for &j in &active_indices {
641                let val = sparse_accumulator[j];
642                if val.norm() > tol {
643                    triplets.push((i, j, val));
644                }
645                // Reset for next row
646                sparse_accumulator[j] = T::zero();
647                occupied[j] = false;
648            }
649            active_indices.clear();
650        }
651
652        CsrMatrix::from_triplets(m, n, triplets)
653    }
654}
655
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0)
            ],
            [
                Complex64::new(0.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 0.0)
            ],
            [
                Complex64::new(4.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(5.0, 0.0)
            ],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(0, 2).re, 2.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
        assert_relative_eq!(csr.get(2, 0).re, 4.0);
        assert_relative_eq!(csr.get(2, 2).re, 5.0);
    }

    #[test]
    fn test_csr_matvec() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let y = csr.matvec(&x);

        // [1 2] * [1]   [5]
        // [3 4]   [2] = [11]
        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 2, Complex64::new(2.0, 0.0)),
            (1, 1, Complex64::new(3.0, 0.0)),
            (2, 0, Complex64::new(4.0, 0.0)),
            (2, 2, Complex64::new(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 0, Complex64::new(2.0, 0.0)), // Duplicate!
            (1, 1, Complex64::new(3.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        assert_relative_eq!(csr.get(0, 0).re, 3.0); // 1 + 2 = 3
    }

    #[test]
    fn test_csr_triplets_empty_rows_and_unsorted_input() {
        let triplets = vec![
            (2, 1, Complex64::new(7.0, 0.0)),
            (0, 1, Complex64::new(2.0, 0.0)),
            (0, 0, Complex64::new(1.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 2, triplets);

        // Row 1 has no entries, so its range is empty.
        assert_eq!(csr.row_ptrs, vec![0, 2, 2, 3]);
        assert_eq!(csr.col_indices, vec![0, 1, 1]);
        assert!(csr.row_range(1).is_empty());
        assert_relative_eq!(csr.get(0, 1).re, 2.0);
        assert_relative_eq!(csr.get(2, 1).re, 7.0);
    }

    #[test]
    fn test_csr_identity() {
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        assert_relative_eq!(id.get(0, 0).re, 1.0);
        assert_relative_eq!(id.get(1, 1).re, 1.0);
        assert_relative_eq!(id.get(2, 2).re, 1.0);
        assert_relative_eq!(id.get(0, 1).norm(), 0.0);
    }

    #[test]
    fn test_csr_builder() {
        let mut builder: CsrBuilder<Complex64> = CsrBuilder::new(3, 3);

        builder.add_row_entries(
            [(0, Complex64::new(1.0, 0.0)), (2, Complex64::new(2.0, 0.0))].into_iter(),
        );
        builder.add_row_entries([(1, Complex64::new(3.0, 0.0))].into_iter());
        builder.add_row_entries(
            [(0, Complex64::new(4.0, 0.0)), (2, Complex64::new(5.0, 0.0))].into_iter(),
        );

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![
            [Complex64::new(1.0, 0.5), Complex64::new(0.0, 0.0)],
            [Complex64::new(2.0, -1.0), Complex64::new(3.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&original, 1e-15);
        let recovered = csr.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(
                    (original[[i, j]] - recovered[[i, j]]).norm(),
                    0.0,
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_transpose_and_hermitian_matvec() {
        let dense = array![
            [Complex64::new(1.0, 1.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(0.0, 3.0)],
        ];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)];

        // A^T x = [1+i, 2+3i]
        let yt = csr.matvec_transpose(&x);
        assert_relative_eq!((yt[0] - Complex64::new(1.0, 1.0)).norm(), 0.0, epsilon = 1e-12);
        assert_relative_eq!((yt[1] - Complex64::new(2.0, 3.0)).norm(), 0.0, epsilon = 1e-12);

        // A^H x = [1-i, 2-3i]
        let yh = csr.matvec_hermitian(&x);
        assert_relative_eq!((yh[0] - Complex64::new(1.0, -1.0)).norm(), 0.0, epsilon = 1e-12);
        assert_relative_eq!((yh[1] - Complex64::new(2.0, -3.0)).norm(), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_matmul() {
        let a = CsrMatrix::from_dense(&array![[1.0_f64, 2.0], [0.0, 3.0]], 1e-15);
        let b = CsrMatrix::from_dense(&array![[4.0_f64, 0.0], [1.0, 5.0]], 1e-15);

        let c = a.matmul(&b);
        assert_eq!(c.num_rows, 2);
        assert_eq!(c.num_cols, 2);

        let c = c.to_dense();
        assert_relative_eq!(c[[0, 0]], 6.0, epsilon = 1e-12);
        assert_relative_eq!(c[[0, 1]], 10.0, epsilon = 1e-12);
        assert_relative_eq!(c[[1, 0]], 3.0, epsilon = 1e-12);
        assert_relative_eq!(c[[1, 1]], 15.0, epsilon = 1e-12);
    }

    #[test]
    fn test_matmul_drops_exact_cancellation() {
        let a = CsrMatrix::from_dense(&array![[1.0_f64, 1.0]], 1e-15);
        let b = CsrMatrix::from_dense(&array![[1.0_f64], [-1.0]], 1e-15);

        let c = a.matmul(&b);
        assert_eq!(c.num_rows, 1);
        assert_eq!(c.num_cols, 1);
        assert_eq!(c.nnz(), 0);
    }

    #[test]
    fn test_add_diagonal_stored_entries() {
        let mut id: CsrMatrix<f64> = CsrMatrix::identity(3);
        id.add_diagonal(2.0);

        assert_eq!(id.nnz(), 3);
        for i in 0..3 {
            assert_relative_eq!(id.get(i, i), 3.0);
        }
    }

    #[test]
    fn test_blocked_csr_matvec() {
        // 3x3 matrix tiled with block_size 2: a 2x2 tile at block (0, 0) and a
        // 1x1 edge tile at block (1, 1).
        let mut blocked: BlockedCsr<f64> = BlockedCsr::new(3, 3, 2);
        blocked.blocks = vec![array![[1.0, 2.0], [3.0, 4.0]], array![[5.0]]];
        blocked.block_col_indices = vec![0, 1];
        blocked.block_row_ptrs = vec![0, 1, 2];

        let y = blocked.matvec(&array![1.0, 1.0, 2.0]);

        assert_relative_eq!(y[0], 3.0, epsilon = 1e-12);
        assert_relative_eq!(y[1], 7.0, epsilon = 1e-12);
        assert_relative_eq!(y[2], 10.0, epsilon = 1e-12);
    }

    #[test]
    fn test_linear_operator_impl() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        // Test via LinearOperator trait
        let y = csr.apply(&x);
        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);

        assert!(csr.is_square());
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_cols(), 2);
    }

    #[test]
    fn test_f64_csr() {
        let dense = array![[1.0_f64, 2.0], [3.0, 4.0],];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![1.0_f64, 2.0];

        let y = csr.matvec(&x);
        assert_relative_eq!(y[0], 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 11.0, epsilon = 1e-10);
    }
}