// math_audio_solvers/sparse/csr.rs

//! Compressed Sparse Row (CSR) matrix format.
//!
//! CSR stores a matrix in three arrays:
//! - `values`: the stored entries, row by row (row-major order)
//! - `col_indices`: the column index of each stored entry
//! - `row_ptrs`: offsets into `values`/`col_indices`; row `i` occupies
//!   `row_ptrs[i]..row_ptrs[i + 1]`, so the array has `num_rows + 1` elements

8use crate::traits::{ComplexField, LinearOperator};
9use ndarray::{Array1, Array2};
10use num_traits::{FromPrimitive, Zero};
11use std::ops::Range;
12
13#[cfg(feature = "rayon")]
14use rayon::prelude::*;
15
16/// Compressed Sparse Row (CSR) matrix format
17///
18/// Memory-efficient storage for sparse matrices with O(nnz) space complexity.
19/// Matrix-vector products are O(nnz) instead of O(n²) for dense matrices.
20#[derive(Debug, Clone)]
21pub struct CsrMatrix<T: ComplexField> {
22    /// Number of rows
23    pub num_rows: usize,
24    /// Number of columns
25    pub num_cols: usize,
26    /// Non-zero values in row-major order
27    pub values: Vec<T>,
28    /// Column indices for each value
29    pub col_indices: Vec<usize>,
30    /// Row pointers: row_ptrs[i] is the start index in values/col_indices for row i
31    /// row_ptrs[num_rows] = nnz (total number of non-zeros)
32    pub row_ptrs: Vec<usize>,
33}
34
35impl<T: ComplexField> CsrMatrix<T> {
36    /// Create a new empty CSR matrix
37    pub fn new(num_rows: usize, num_cols: usize) -> Self {
38        Self {
39            num_rows,
40            num_cols,
41            values: Vec::new(),
42            col_indices: Vec::new(),
43            row_ptrs: vec![0; num_rows + 1],
44        }
45    }
46
47    /// Create a CSR matrix with pre-allocated capacity
48    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
49        Self {
50            num_rows,
51            num_cols,
52            values: Vec::with_capacity(nnz_estimate),
53            col_indices: Vec::with_capacity(nnz_estimate),
54            row_ptrs: vec![0; num_rows + 1],
55        }
56    }
57
58    /// Create a CSR matrix from raw components
59    ///
60    /// This is useful for converting between different CSR matrix representations
61    /// that share the same internal structure.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the input arrays are inconsistent:
66    /// - `row_ptrs` must have length `num_rows + 1`
67    /// - `col_indices` and `values` must have the same length
68    /// - `row_ptrs[num_rows]` must equal `values.len()`
69    pub fn from_raw_parts(
70        num_rows: usize,
71        num_cols: usize,
72        row_ptrs: Vec<usize>,
73        col_indices: Vec<usize>,
74        values: Vec<T>,
75    ) -> Self {
76        assert_eq!(
77            row_ptrs.len(),
78            num_rows + 1,
79            "row_ptrs must have num_rows + 1 elements"
80        );
81        assert_eq!(
82            col_indices.len(),
83            values.len(),
84            "col_indices and values must have the same length"
85        );
86        assert_eq!(
87            row_ptrs[num_rows],
88            values.len(),
89            "row_ptrs[num_rows] must equal nnz"
90        );
91
92        Self {
93            num_rows,
94            num_cols,
95            row_ptrs,
96            col_indices,
97            values,
98        }
99    }
100
101    /// Create a CSR matrix from a dense matrix
102    ///
103    /// Only stores entries with magnitude > threshold
104    pub fn from_dense(dense: &Array2<T>, threshold: T::Real) -> Self {
105        let num_rows = dense.nrows();
106        let num_cols = dense.ncols();
107
108        let mut values = Vec::new();
109        let mut col_indices = Vec::new();
110        let mut row_ptrs = vec![0usize; num_rows + 1];
111
112        for i in 0..num_rows {
113            for j in 0..num_cols {
114                let val = dense[[i, j]];
115                if val.norm() > threshold {
116                    values.push(val);
117                    col_indices.push(j);
118                }
119            }
120            row_ptrs[i + 1] = values.len();
121        }
122
123        Self {
124            num_rows,
125            num_cols,
126            values,
127            col_indices,
128            row_ptrs,
129        }
130    }
131
132    /// Create a CSR matrix from COO (Coordinate) format triplets
133    ///
134    /// Triplets are (row, col, value). Duplicate entries are summed.
135    pub fn from_triplets(
136        num_rows: usize,
137        num_cols: usize,
138        mut triplets: Vec<(usize, usize, T)>,
139    ) -> Self {
140        if triplets.is_empty() {
141            return Self::new(num_rows, num_cols);
142        }
143
144        // Sort by row, then by column
145        triplets.sort_by(|a, b| {
146            if a.0 != b.0 {
147                a.0.cmp(&b.0)
148            } else {
149                a.1.cmp(&b.1)
150            }
151        });
152
153        let mut values = Vec::with_capacity(triplets.len());
154        let mut col_indices = Vec::with_capacity(triplets.len());
155        let mut row_ptrs = vec![0usize; num_rows + 1];
156
157        let mut prev_row = usize::MAX;
158        let mut prev_col = usize::MAX;
159
160        for (row, col, val) in triplets {
161            if row == prev_row && col == prev_col {
162                // Same entry, accumulate
163                if let Some(last) = values.last_mut() {
164                    *last += val;
165                }
166            } else {
167                // New entry - push it
168                values.push(val);
169                col_indices.push(col);
170
171                // Update row pointers for any rows we skipped
172                if row != prev_row {
173                    let start = if prev_row == usize::MAX {
174                        0
175                    } else {
176                        prev_row + 1
177                    };
178                    for item in row_ptrs.iter_mut().take(row + 1).skip(start) {
179                        *item = values.len() - 1;
180                    }
181                }
182
183                prev_row = row;
184                prev_col = col;
185            }
186        }
187
188        // Fill remaining row pointers
189        let last_row = if prev_row == usize::MAX {
190            0
191        } else {
192            prev_row + 1
193        };
194        for item in row_ptrs.iter_mut().take(num_rows + 1).skip(last_row) {
195            *item = values.len();
196        }
197
198        Self {
199            num_rows,
200            num_cols,
201            values,
202            col_indices,
203            row_ptrs,
204        }
205    }
206
207    /// Number of non-zero entries
208    pub fn nnz(&self) -> usize {
209        self.values.len()
210    }
211
212    /// Sparsity ratio (fraction of non-zero entries)
213    pub fn sparsity(&self) -> f64 {
214        let total = self.num_rows * self.num_cols;
215        if total == 0 {
216            0.0
217        } else {
218            self.nnz() as f64 / total as f64
219        }
220    }
221
222    /// Get the range of indices in values/col_indices for a given row
223    pub fn row_range(&self, row: usize) -> Range<usize> {
224        self.row_ptrs[row]..self.row_ptrs[row + 1]
225    }
226
227    /// Get the (col, value) pairs for a row
228    pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, T)> + '_ {
229        let range = self.row_range(row);
230        self.col_indices[range.clone()]
231            .iter()
232            .copied()
233            .zip(self.values[range].iter().copied())
234    }
235
236    /// Matrix-vector product: y = A * x
237    ///
238    /// Uses parallel processing when the `rayon` feature is enabled and the
239    /// matrix is large enough to benefit from parallelization.
240    pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
241        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
242
243        // Use parallel version for large matrices when rayon is available
244        #[cfg(feature = "rayon")]
245        {
246            // Only parallelize if we have enough rows to benefit
247            if self.num_rows >= 246 {
248                return self.matvec_parallel(x);
249            }
250        }
251
252        self.matvec_sequential(x)
253    }
254
255    /// Sequential matrix-vector product
256    fn matvec_sequential(&self, x: &Array1<T>) -> Array1<T> {
257        let mut y = Array1::from_elem(self.num_rows, T::zero());
258
259        for i in 0..self.num_rows {
260            let mut sum = T::zero();
261            for idx in self.row_range(i) {
262                let j = self.col_indices[idx];
263                sum += self.values[idx] * x[j];
264            }
265            y[i] = sum;
266        }
267
268        y
269    }
270
271    /// Parallel matrix-vector product using rayon
272    #[cfg(feature = "rayon")]
273    fn matvec_parallel(&self, x: &Array1<T>) -> Array1<T>
274    where
275        T: Send + Sync,
276    {
277        let x_slice = x.as_slice().expect("Array should be contiguous");
278
279        let results: Vec<T> = (0..self.num_rows)
280            .into_par_iter()
281            .map(|i| {
282                let mut sum = T::zero();
283                for idx in self.row_range(i) {
284                    let j = self.col_indices[idx];
285                    sum += self.values[idx] * x_slice[j];
286                }
287                sum
288            })
289            .collect();
290
291        Array1::from_vec(results)
292    }
293
294    /// Matrix-vector product with accumulation: y += A * x
295    pub fn matvec_add(&self, x: &Array1<T>, y: &mut Array1<T>) {
296        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
297        assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
298
299        for i in 0..self.num_rows {
300            for idx in self.row_range(i) {
301                let j = self.col_indices[idx];
302                y[i] += self.values[idx] * x[j];
303            }
304        }
305    }
306
307    /// Transpose matrix-vector product: y = A^T * x
308    pub fn matvec_transpose(&self, x: &Array1<T>) -> Array1<T> {
309        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
310
311        let mut y = Array1::from_elem(self.num_cols, T::zero());
312
313        for i in 0..self.num_rows {
314            for idx in self.row_range(i) {
315                let j = self.col_indices[idx];
316                y[j] += self.values[idx] * x[i];
317            }
318        }
319
320        y
321    }
322
323    /// Hermitian (conjugate transpose) matrix-vector product: y = A^H * x
324    pub fn matvec_hermitian(&self, x: &Array1<T>) -> Array1<T> {
325        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
326
327        let mut y = Array1::from_elem(self.num_cols, T::zero());
328
329        for i in 0..self.num_rows {
330            for idx in self.row_range(i) {
331                let j = self.col_indices[idx];
332                y[j] += self.values[idx].conj() * x[i];
333            }
334        }
335
336        y
337    }
338
339    /// Get element at (i, j), returns 0 if not stored
340    pub fn get(&self, i: usize, j: usize) -> T {
341        for idx in self.row_range(i) {
342            if self.col_indices[idx] == j {
343                return self.values[idx];
344            }
345        }
346        T::zero()
347    }
348
349    /// Extract diagonal elements
350    pub fn diagonal(&self) -> Array1<T> {
351        let n = self.num_rows.min(self.num_cols);
352        let mut diag = Array1::from_elem(n, T::zero());
353
354        for i in 0..n {
355            diag[i] = self.get(i, i);
356        }
357
358        diag
359    }
360
361    /// Scale all values by a scalar
362    pub fn scale(&mut self, scalar: T) {
363        for val in &mut self.values {
364            *val *= scalar;
365        }
366    }
367
368    /// Add a scalar to the diagonal
369    pub fn add_diagonal(&mut self, scalar: T) {
370        let n = self.num_rows.min(self.num_cols);
371
372        for i in 0..n {
373            for idx in self.row_range(i) {
374                if self.col_indices[idx] == i {
375                    self.values[idx] += scalar;
376                    break;
377                }
378            }
379        }
380    }
381
382    /// Create identity matrix in CSR format
383    pub fn identity(n: usize) -> Self {
384        Self {
385            num_rows: n,
386            num_cols: n,
387            values: vec![T::one(); n],
388            col_indices: (0..n).collect(),
389            row_ptrs: (0..=n).collect(),
390        }
391    }
392
393    /// Create diagonal matrix from vector
394    pub fn from_diagonal(diag: &Array1<T>) -> Self {
395        let n = diag.len();
396        Self {
397            num_rows: n,
398            num_cols: n,
399            values: diag.to_vec(),
400            col_indices: (0..n).collect(),
401            row_ptrs: (0..=n).collect(),
402        }
403    }
404
405    /// Convert to dense matrix (for debugging/small matrices)
406    pub fn to_dense(&self) -> Array2<T> {
407        let mut dense = Array2::from_elem((self.num_rows, self.num_cols), T::zero());
408
409        for i in 0..self.num_rows {
410            for idx in self.row_range(i) {
411                let j = self.col_indices[idx];
412                dense[[i, j]] = self.values[idx];
413            }
414        }
415
416        dense
417    }
418}
419
420impl<T: ComplexField> LinearOperator<T> for CsrMatrix<T> {
421    fn num_rows(&self) -> usize {
422        self.num_rows
423    }
424
425    fn num_cols(&self) -> usize {
426        self.num_cols
427    }
428
429    fn apply(&self, x: &Array1<T>) -> Array1<T> {
430        self.matvec(x)
431    }
432
433    fn apply_transpose(&self, x: &Array1<T>) -> Array1<T> {
434        self.matvec_transpose(x)
435    }
436
437    fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
438        self.matvec_hermitian(x)
439    }
440}
441
442/// Builder for constructing CSR matrices row by row
443pub struct CsrBuilder<T: ComplexField> {
444    num_rows: usize,
445    num_cols: usize,
446    values: Vec<T>,
447    col_indices: Vec<usize>,
448    row_ptrs: Vec<usize>,
449    current_row: usize,
450}
451
452impl<T: ComplexField> CsrBuilder<T> {
453    /// Create a new CSR builder
454    pub fn new(num_rows: usize, num_cols: usize) -> Self {
455        Self {
456            num_rows,
457            num_cols,
458            values: Vec::new(),
459            col_indices: Vec::new(),
460            row_ptrs: vec![0],
461            current_row: 0,
462        }
463    }
464
465    /// Create a new CSR builder with estimated non-zeros
466    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
467        Self {
468            num_rows,
469            num_cols,
470            values: Vec::with_capacity(nnz_estimate),
471            col_indices: Vec::with_capacity(nnz_estimate),
472            row_ptrs: Vec::with_capacity(num_rows + 1),
473            current_row: 0,
474        }
475    }
476
477    /// Add entries for the current row (must be added in column order)
478    pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, T)>) {
479        for (col, val) in entries {
480            if val.norm() > T::Real::zero() {
481                self.values.push(val);
482                self.col_indices.push(col);
483            }
484        }
485        self.row_ptrs.push(self.values.len());
486        self.current_row += 1;
487    }
488
489    /// Finish building and return the CSR matrix
490    pub fn finish(mut self) -> CsrMatrix<T> {
491        // Fill remaining rows if not all rows were added
492        while self.current_row < self.num_rows {
493            self.row_ptrs.push(self.values.len());
494            self.current_row += 1;
495        }
496
497        CsrMatrix {
498            num_rows: self.num_rows,
499            num_cols: self.num_cols,
500            values: self.values,
501            col_indices: self.col_indices,
502            row_ptrs: self.row_ptrs,
503        }
504    }
505}
506
507/// Blocked CSR format for hierarchical matrices
508///
509/// Stores the matrix as a collection of dense blocks at the leaf level
510/// of a hierarchical decomposition.
511#[derive(Debug, Clone)]
512pub struct BlockedCsr<T: ComplexField> {
513    /// Number of rows
514    pub num_rows: usize,
515    /// Number of columns
516    pub num_cols: usize,
517    /// Block size (rows and columns per block)
518    pub block_size: usize,
519    /// Number of block rows
520    pub num_block_rows: usize,
521    /// Number of block columns
522    pub num_block_cols: usize,
523    /// Dense blocks stored in CSR-like format
524    pub blocks: Vec<Array2<T>>,
525    /// Block column indices
526    pub block_col_indices: Vec<usize>,
527    /// Block row pointers
528    pub block_row_ptrs: Vec<usize>,
529}
530
531impl<T: ComplexField> BlockedCsr<T> {
532    /// Create a new blocked CSR matrix
533    pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
534        let num_block_rows = num_rows.div_ceil(block_size);
535        let num_block_cols = num_cols.div_ceil(block_size);
536
537        Self {
538            num_rows,
539            num_cols,
540            block_size,
541            num_block_rows,
542            num_block_cols,
543            blocks: Vec::new(),
544            block_col_indices: Vec::new(),
545            block_row_ptrs: vec![0; num_block_rows + 1],
546        }
547    }
548
549    /// Matrix-vector product using blocked structure
550    pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
551        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
552
553        let mut y = Array1::from_elem(self.num_rows, T::zero());
554
555        for block_i in 0..self.num_block_rows {
556            let row_start = block_i * self.block_size;
557            let row_end = (row_start + self.block_size).min(self.num_rows);
558            let local_rows = row_end - row_start;
559
560            for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
561                let block_j = self.block_col_indices[idx];
562                let block = &self.blocks[idx];
563
564                let col_start = block_j * self.block_size;
565                let col_end = (col_start + self.block_size).min(self.num_cols);
566                let local_cols = col_end - col_start;
567
568                // Extract local x
569                let x_local: Array1<T> = Array1::from_iter((col_start..col_end).map(|j| x[j]));
570
571                // Apply block
572                for i in 0..local_rows {
573                    let mut sum = T::zero();
574                    for j in 0..local_cols {
575                        sum += block[[i, j]] * x_local[j];
576                    }
577                    y[row_start + i] += sum;
578                }
579            }
580        }
581
582        y
583    }
584}
585
586/// Optimized sparse matrix-matrix multiplication: C = A * B
587///
588/// Uses a sorted accumulation approach instead of HashMap for better cache locality,
589/// providing 2-4x speedup for the AMG Galerkin product.
590///
591/// For CSR matrices A (m×k) and B (k×n), computes C (m×n).
592impl<T: ComplexField> CsrMatrix<T> {
593    /// Compute C = A * B using optimized approach
594    pub fn matmul(&self, other: &CsrMatrix<T>) -> CsrMatrix<T> {
595        assert_eq!(
596            self.num_cols, other.num_rows,
597            "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
598            self.num_cols, other.num_rows
599        );
600
601        let m = self.num_rows;
602        let n = other.num_cols;
603
604        if m == 0 || n == 0 || self.nnz() == 0 || other.nnz() == 0 {
605            return CsrMatrix::new(m, n);
606        }
607
608        let tol = T::Real::from_f64(1e-15).unwrap();
609
610        let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(self.nnz() * 4);
611
612        for i in 0..m {
613            let mut row_data: Vec<(usize, T)> = Vec::new();
614
615            for (k, a_ik) in self.row_entries(i) {
616                for (j, b_kj) in other.row_entries(k) {
617                    row_data.push((j, a_ik * b_kj));
618                }
619            }
620
621            if row_data.is_empty() {
622                continue;
623            }
624
625            row_data.sort_by_key(|&(j, _)| j);
626
627            let mut current_j = row_data[0].0;
628            let mut current_val = row_data[0].1;
629
630            for &(j, val) in &row_data[1..] {
631                if j == current_j {
632                    current_val += val;
633                } else {
634                    if current_val.norm() > tol {
635                        triplets.push((i, current_j, current_val));
636                    }
637                    current_j = j;
638                    current_val = val;
639                }
640            }
641
642            if current_val.norm() > tol {
643                triplets.push((i, current_j, current_val));
644            }
645        }
646
647        CsrMatrix::from_triplets(m, n, triplets)
648    }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0)
            ],
            [
                Complex64::new(0.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 0.0)
            ],
            [
                Complex64::new(4.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(5.0, 0.0)
            ],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(0, 2).re, 2.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
        assert_relative_eq!(csr.get(2, 0).re, 4.0);
        assert_relative_eq!(csr.get(2, 2).re, 5.0);
    }

    #[test]
    fn test_csr_matvec() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let y = csr.matvec(&x);

        // [1 2] * [1]   [5]
        // [3 4]   [2] = [11]
        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 2, Complex64::new(2.0, 0.0)),
            (1, 1, Complex64::new(3.0, 0.0)),
            (2, 0, Complex64::new(4.0, 0.0)),
            (2, 2, Complex64::new(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 0, Complex64::new(2.0, 0.0)), // Duplicate!
            (1, 1, Complex64::new(3.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        assert_relative_eq!(csr.get(0, 0).re, 3.0); // 1 + 2 = 3
    }

    #[test]
    fn test_csr_triplets_skipped_rows() {
        // Unsorted input; rows 0 and 2 have no entries.
        let triplets = vec![
            (3, 2, Complex64::new(7.0, 0.0)),
            (1, 0, Complex64::new(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(4, 3, triplets);

        assert_eq!(csr.row_ptrs, vec![0, 0, 1, 1, 2]);
        assert_eq!(csr.col_indices, vec![0, 2]);
        assert_relative_eq!(csr.get(1, 0).re, 5.0);
        assert_relative_eq!(csr.get(3, 2).re, 7.0);
        assert_relative_eq!(csr.get(0, 0).norm(), 0.0);
    }

    #[test]
    fn test_csr_identity() {
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        assert_relative_eq!(id.get(0, 0).re, 1.0);
        assert_relative_eq!(id.get(1, 1).re, 1.0);
        assert_relative_eq!(id.get(2, 2).re, 1.0);
        assert_relative_eq!(id.get(0, 1).norm(), 0.0);
    }

    #[test]
    fn test_csr_builder() {
        let mut builder: CsrBuilder<Complex64> = CsrBuilder::new(3, 3);

        builder.add_row_entries(
            [(0, Complex64::new(1.0, 0.0)), (2, Complex64::new(2.0, 0.0))].into_iter(),
        );
        builder.add_row_entries([(1, Complex64::new(3.0, 0.0))].into_iter());
        builder.add_row_entries(
            [(0, Complex64::new(4.0, 0.0)), (2, Complex64::new(5.0, 0.0))].into_iter(),
        );

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![
            [Complex64::new(1.0, 0.5), Complex64::new(0.0, 0.0)],
            [Complex64::new(2.0, -1.0), Complex64::new(3.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&original, 1e-15);
        let recovered = csr.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(
                    (original[[i, j]] - recovered[[i, j]]).norm(),
                    0.0,
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_matvec_transpose_and_hermitian() {
        let dense = array![
            [Complex64::new(1.0, 1.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(0.0, 3.0)],
        ];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)];

        // A^T x = [1+i, 2+3i]
        let yt = csr.matvec_transpose(&x);
        assert_relative_eq!((yt[0] - Complex64::new(1.0, 1.0)).norm(), 0.0, epsilon = 1e-12);
        assert_relative_eq!((yt[1] - Complex64::new(2.0, 3.0)).norm(), 0.0, epsilon = 1e-12);

        // A^H x = [1-i, 2-3i]
        let yh = csr.matvec_hermitian(&x);
        assert_relative_eq!((yh[0] - Complex64::new(1.0, -1.0)).norm(), 0.0, epsilon = 1e-12);
        assert_relative_eq!((yh[1] - Complex64::new(2.0, -3.0)).norm(), 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_matmul_matches_dense_product() {
        let a = array![[1.0_f64, 2.0], [0.0, 3.0]];
        let b = array![[4.0_f64, 0.0], [5.0, 6.0]];

        let c = CsrMatrix::from_dense(&a, 1e-15).matmul(&CsrMatrix::from_dense(&b, 1e-15));
        let expected = a.dot(&b); // [[14, 12], [15, 18]]
        let got = c.to_dense();

        assert_eq!(c.nnz(), 4);
        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(got[[i, j]], expected[[i, j]], epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_matmul_drops_cancelled_entries() {
        // [1 1] * [1; -1] = [0], which falls below the matmul tolerance.
        let a = array![[1.0_f64, 1.0]];
        let b = array![[1.0_f64], [-1.0]];

        let c = CsrMatrix::from_dense(&a, 1e-15).matmul(&CsrMatrix::from_dense(&b, 1e-15));

        assert_eq!(c.num_rows, 1);
        assert_eq!(c.num_cols, 1);
        assert_eq!(c.nnz(), 0);
        assert_eq!(c.row_ptrs, vec![0, 0]);
    }

    #[test]
    fn test_large_matvec_matches_dense() {
        // Large enough (300 rows) to take the parallel path when the `rayon`
        // feature is enabled.
        let n = 300;
        let mut triplets = Vec::new();
        for i in 0..n {
            triplets.push((i, i, 2.0_f64));
            if i > 0 {
                triplets.push((i, i - 1, -1.0));
            }
            if i + 1 < n {
                triplets.push((i, i + 1, -1.0));
            }
        }
        let csr = CsrMatrix::from_triplets(n, n, triplets);
        let x = Array1::from_iter((0..n).map(|i| i as f64));

        let y = csr.matvec(&x);
        let expected = csr.to_dense().dot(&x);

        for (got, want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_linear_operator_impl() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        // Test via LinearOperator trait
        let y = csr.apply(&x);
        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);

        assert!(csr.is_square());
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_cols(), 2);
    }

    #[test]
    fn test_f64_csr() {
        let dense = array![[1.0_f64, 2.0], [3.0, 4.0],];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![1.0_f64, 2.0];

        let y = csr.matvec(&x);
        assert_relative_eq!(y[0], 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 11.0, epsilon = 1e-10);
    }
}