// solvers/sparse/csr.rs

//! Compressed Sparse Row (CSR) matrix format
//!
//! CSR format stores:
//! - `values`: Stored (typically non-zero) entries in row-major order
//! - `col_indices`: Column index for each value
//! - `row_ptrs`: Index into `values`/`col_indices` where each row starts, followed by a
//!   final entry equal to the total number of stored values
7
8use crate::traits::{ComplexField, LinearOperator};
9use ndarray::{Array1, Array2};
10use num_traits::Zero;
11use std::ops::Range;
12
13#[cfg(feature = "rayon")]
14use rayon::prelude::*;
15
/// Compressed Sparse Row (CSR) matrix format
///
/// Memory-efficient storage for sparse matrices with O(nnz) space complexity.
/// Matrix-vector products are O(nnz) instead of O(n²) for dense matrices.
///
/// Row `i` owns the half-open range `row_ptrs[i]..row_ptrs[i + 1]` of `values` and
/// `col_indices`. The fields are public and are not re-checked after construction:
/// code that edits them directly must keep `row_ptrs` of length `num_rows + 1`,
/// non-decreasing, ending at `values.len()`, and every column index `< num_cols`.
/// Otherwise the matrix-vector products may panic or give wrong results.
/// [`CsrMatrix::from_raw_parts`] is the constructor that checks raw components.
#[derive(Debug, Clone)]
pub struct CsrMatrix<T: ComplexField> {
    /// Number of rows
    pub num_rows: usize,
    /// Number of columns
    pub num_cols: usize,
    /// Stored values in row-major order (explicit zeros may be stored)
    pub values: Vec<T>,
    /// Column index for each entry of `values` (same length as `values`)
    pub col_indices: Vec<usize>,
    /// Row pointers: row_ptrs[i] is the start index in values/col_indices for row i
    /// row_ptrs[num_rows] = nnz (total number of stored entries)
    pub row_ptrs: Vec<usize>,
}
34
35impl<T: ComplexField> CsrMatrix<T> {
36    /// Create a new empty CSR matrix
37    pub fn new(num_rows: usize, num_cols: usize) -> Self {
38        Self {
39            num_rows,
40            num_cols,
41            values: Vec::new(),
42            col_indices: Vec::new(),
43            row_ptrs: vec![0; num_rows + 1],
44        }
45    }
46
47    /// Create a CSR matrix with pre-allocated capacity
48    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
49        Self {
50            num_rows,
51            num_cols,
52            values: Vec::with_capacity(nnz_estimate),
53            col_indices: Vec::with_capacity(nnz_estimate),
54            row_ptrs: vec![0; num_rows + 1],
55        }
56    }
57
58    /// Create a CSR matrix from raw components
59    ///
60    /// This is useful for converting between different CSR matrix representations
61    /// that share the same internal structure.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the input arrays are inconsistent:
66    /// - `row_ptrs` must have length `num_rows + 1`
67    /// - `col_indices` and `values` must have the same length
68    /// - `row_ptrs[num_rows]` must equal `values.len()`
69    pub fn from_raw_parts(
70        num_rows: usize,
71        num_cols: usize,
72        row_ptrs: Vec<usize>,
73        col_indices: Vec<usize>,
74        values: Vec<T>,
75    ) -> Self {
76        assert_eq!(
77            row_ptrs.len(),
78            num_rows + 1,
79            "row_ptrs must have num_rows + 1 elements"
80        );
81        assert_eq!(
82            col_indices.len(),
83            values.len(),
84            "col_indices and values must have the same length"
85        );
86        assert_eq!(
87            row_ptrs[num_rows],
88            values.len(),
89            "row_ptrs[num_rows] must equal nnz"
90        );
91
92        Self {
93            num_rows,
94            num_cols,
95            row_ptrs,
96            col_indices,
97            values,
98        }
99    }
100
101    /// Create a CSR matrix from a dense matrix
102    ///
103    /// Only stores entries with magnitude > threshold
104    pub fn from_dense(dense: &Array2<T>, threshold: T::Real) -> Self {
105        let num_rows = dense.nrows();
106        let num_cols = dense.ncols();
107
108        let mut values = Vec::new();
109        let mut col_indices = Vec::new();
110        let mut row_ptrs = vec![0usize; num_rows + 1];
111
112        for i in 0..num_rows {
113            for j in 0..num_cols {
114                let val = dense[[i, j]];
115                if val.norm() > threshold {
116                    values.push(val);
117                    col_indices.push(j);
118                }
119            }
120            row_ptrs[i + 1] = values.len();
121        }
122
123        Self {
124            num_rows,
125            num_cols,
126            values,
127            col_indices,
128            row_ptrs,
129        }
130    }
131
132    /// Create a CSR matrix from COO (Coordinate) format triplets
133    ///
134    /// Triplets are (row, col, value). Duplicate entries are summed.
135    pub fn from_triplets(
136        num_rows: usize,
137        num_cols: usize,
138        mut triplets: Vec<(usize, usize, T)>,
139    ) -> Self {
140        if triplets.is_empty() {
141            return Self::new(num_rows, num_cols);
142        }
143
144        // Sort by row, then by column
145        triplets.sort_by(|a, b| {
146            if a.0 != b.0 {
147                a.0.cmp(&b.0)
148            } else {
149                a.1.cmp(&b.1)
150            }
151        });
152
153        let mut values = Vec::with_capacity(triplets.len());
154        let mut col_indices = Vec::with_capacity(triplets.len());
155        let mut row_ptrs = vec![0usize; num_rows + 1];
156
157        let mut prev_row = usize::MAX;
158        let mut prev_col = usize::MAX;
159
160        for (row, col, val) in triplets {
161            if row == prev_row && col == prev_col {
162                // Same entry, accumulate
163                if let Some(last) = values.last_mut() {
164                    *last += val;
165                }
166            } else {
167                // New entry - push it
168                values.push(val);
169                col_indices.push(col);
170
171                // Update row pointers for any rows we skipped
172                if row != prev_row {
173                    let start = if prev_row == usize::MAX {
174                        0
175                    } else {
176                        prev_row + 1
177                    };
178                    for item in row_ptrs.iter_mut().take(row + 1).skip(start) {
179                        *item = values.len() - 1;
180                    }
181                }
182
183                prev_row = row;
184                prev_col = col;
185            }
186        }
187
188        // Fill remaining row pointers
189        let last_row = if prev_row == usize::MAX {
190            0
191        } else {
192            prev_row + 1
193        };
194        for item in row_ptrs.iter_mut().take(num_rows + 1).skip(last_row) {
195            *item = values.len();
196        }
197
198        Self {
199            num_rows,
200            num_cols,
201            values,
202            col_indices,
203            row_ptrs,
204        }
205    }
206
207    /// Number of non-zero entries
208    pub fn nnz(&self) -> usize {
209        self.values.len()
210    }
211
212    /// Sparsity ratio (fraction of non-zero entries)
213    pub fn sparsity(&self) -> f64 {
214        let total = self.num_rows * self.num_cols;
215        if total == 0 {
216            0.0
217        } else {
218            self.nnz() as f64 / total as f64
219        }
220    }
221
222    /// Get the range of indices in values/col_indices for a given row
223    pub fn row_range(&self, row: usize) -> Range<usize> {
224        self.row_ptrs[row]..self.row_ptrs[row + 1]
225    }
226
227    /// Get the (col, value) pairs for a row
228    pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, T)> + '_ {
229        let range = self.row_range(row);
230        self.col_indices[range.clone()]
231            .iter()
232            .copied()
233            .zip(self.values[range].iter().copied())
234    }
235
236    /// Matrix-vector product: y = A * x
237    ///
238    /// Uses parallel processing when the `rayon` feature is enabled and the
239    /// matrix is large enough to benefit from parallelization.
240    pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
241        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
242
243        // Use parallel version for large matrices when rayon is available
244        #[cfg(feature = "rayon")]
245        {
246            // Only parallelize if we have enough rows to benefit
247            if self.num_rows >= 246 {
248                return self.matvec_parallel(x);
249            }
250        }
251
252        self.matvec_sequential(x)
253    }
254
255    /// Sequential matrix-vector product
256    fn matvec_sequential(&self, x: &Array1<T>) -> Array1<T> {
257        let mut y = Array1::from_elem(self.num_rows, T::zero());
258
259        for i in 0..self.num_rows {
260            let mut sum = T::zero();
261            for idx in self.row_range(i) {
262                let j = self.col_indices[idx];
263                sum += self.values[idx] * x[j];
264            }
265            y[i] = sum;
266        }
267
268        y
269    }
270
271    /// Parallel matrix-vector product using rayon
272    #[cfg(feature = "rayon")]
273    fn matvec_parallel(&self, x: &Array1<T>) -> Array1<T>
274    where
275        T: Send + Sync,
276    {
277        let x_slice = x.as_slice().expect("Array should be contiguous");
278
279        let results: Vec<T> = (0..self.num_rows)
280            .into_par_iter()
281            .map(|i| {
282                let mut sum = T::zero();
283                for idx in self.row_range(i) {
284                    let j = self.col_indices[idx];
285                    sum += self.values[idx] * x_slice[j];
286                }
287                sum
288            })
289            .collect();
290
291        Array1::from_vec(results)
292    }
293
294    /// Matrix-vector product with accumulation: y += A * x
295    pub fn matvec_add(&self, x: &Array1<T>, y: &mut Array1<T>) {
296        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
297        assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
298
299        for i in 0..self.num_rows {
300            for idx in self.row_range(i) {
301                let j = self.col_indices[idx];
302                y[i] += self.values[idx] * x[j];
303            }
304        }
305    }
306
307    /// Transpose matrix-vector product: y = A^T * x
308    pub fn matvec_transpose(&self, x: &Array1<T>) -> Array1<T> {
309        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
310
311        let mut y = Array1::from_elem(self.num_cols, T::zero());
312
313        for i in 0..self.num_rows {
314            for idx in self.row_range(i) {
315                let j = self.col_indices[idx];
316                y[j] += self.values[idx] * x[i];
317            }
318        }
319
320        y
321    }
322
323    /// Hermitian (conjugate transpose) matrix-vector product: y = A^H * x
324    pub fn matvec_hermitian(&self, x: &Array1<T>) -> Array1<T> {
325        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
326
327        let mut y = Array1::from_elem(self.num_cols, T::zero());
328
329        for i in 0..self.num_rows {
330            for idx in self.row_range(i) {
331                let j = self.col_indices[idx];
332                y[j] += self.values[idx].conj() * x[i];
333            }
334        }
335
336        y
337    }
338
339    /// Get element at (i, j), returns 0 if not stored
340    pub fn get(&self, i: usize, j: usize) -> T {
341        for idx in self.row_range(i) {
342            if self.col_indices[idx] == j {
343                return self.values[idx];
344            }
345        }
346        T::zero()
347    }
348
349    /// Extract diagonal elements
350    pub fn diagonal(&self) -> Array1<T> {
351        let n = self.num_rows.min(self.num_cols);
352        let mut diag = Array1::from_elem(n, T::zero());
353
354        for i in 0..n {
355            diag[i] = self.get(i, i);
356        }
357
358        diag
359    }
360
361    /// Scale all values by a scalar
362    pub fn scale(&mut self, scalar: T) {
363        for val in &mut self.values {
364            *val *= scalar;
365        }
366    }
367
368    /// Add a scalar to the diagonal
369    pub fn add_diagonal(&mut self, scalar: T) {
370        let n = self.num_rows.min(self.num_cols);
371
372        for i in 0..n {
373            for idx in self.row_range(i) {
374                if self.col_indices[idx] == i {
375                    self.values[idx] += scalar;
376                    break;
377                }
378            }
379        }
380    }
381
382    /// Create identity matrix in CSR format
383    pub fn identity(n: usize) -> Self {
384        Self {
385            num_rows: n,
386            num_cols: n,
387            values: vec![T::one(); n],
388            col_indices: (0..n).collect(),
389            row_ptrs: (0..=n).collect(),
390        }
391    }
392
393    /// Create diagonal matrix from vector
394    pub fn from_diagonal(diag: &Array1<T>) -> Self {
395        let n = diag.len();
396        Self {
397            num_rows: n,
398            num_cols: n,
399            values: diag.to_vec(),
400            col_indices: (0..n).collect(),
401            row_ptrs: (0..=n).collect(),
402        }
403    }
404
405    /// Convert to dense matrix (for debugging/small matrices)
406    pub fn to_dense(&self) -> Array2<T> {
407        let mut dense = Array2::from_elem((self.num_rows, self.num_cols), T::zero());
408
409        for i in 0..self.num_rows {
410            for idx in self.row_range(i) {
411                let j = self.col_indices[idx];
412                dense[[i, j]] = self.values[idx];
413            }
414        }
415
416        dense
417    }
418}
419
420impl<T: ComplexField> LinearOperator<T> for CsrMatrix<T> {
421    fn num_rows(&self) -> usize {
422        self.num_rows
423    }
424
425    fn num_cols(&self) -> usize {
426        self.num_cols
427    }
428
429    fn apply(&self, x: &Array1<T>) -> Array1<T> {
430        self.matvec(x)
431    }
432
433    fn apply_transpose(&self, x: &Array1<T>) -> Array1<T> {
434        self.matvec_transpose(x)
435    }
436
437    fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
438        self.matvec_hermitian(x)
439    }
440}
441
442/// Builder for constructing CSR matrices row by row
443pub struct CsrBuilder<T: ComplexField> {
444    num_rows: usize,
445    num_cols: usize,
446    values: Vec<T>,
447    col_indices: Vec<usize>,
448    row_ptrs: Vec<usize>,
449    current_row: usize,
450}
451
452impl<T: ComplexField> CsrBuilder<T> {
453    /// Create a new CSR builder
454    pub fn new(num_rows: usize, num_cols: usize) -> Self {
455        Self {
456            num_rows,
457            num_cols,
458            values: Vec::new(),
459            col_indices: Vec::new(),
460            row_ptrs: vec![0],
461            current_row: 0,
462        }
463    }
464
465    /// Create a new CSR builder with estimated non-zeros
466    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
467        Self {
468            num_rows,
469            num_cols,
470            values: Vec::with_capacity(nnz_estimate),
471            col_indices: Vec::with_capacity(nnz_estimate),
472            row_ptrs: Vec::with_capacity(num_rows + 1),
473            current_row: 0,
474        }
475    }
476
477    /// Add entries for the current row (must be added in column order)
478    pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, T)>) {
479        for (col, val) in entries {
480            if val.norm() > T::Real::zero() {
481                self.values.push(val);
482                self.col_indices.push(col);
483            }
484        }
485        self.row_ptrs.push(self.values.len());
486        self.current_row += 1;
487    }
488
489    /// Finish building and return the CSR matrix
490    pub fn finish(mut self) -> CsrMatrix<T> {
491        // Fill remaining rows if not all rows were added
492        while self.current_row < self.num_rows {
493            self.row_ptrs.push(self.values.len());
494            self.current_row += 1;
495        }
496
497        CsrMatrix {
498            num_rows: self.num_rows,
499            num_cols: self.num_cols,
500            values: self.values,
501            col_indices: self.col_indices,
502            row_ptrs: self.row_ptrs,
503        }
504    }
505}
506
/// Blocked CSR format for hierarchical matrices
///
/// Stores the matrix as a collection of dense blocks at the leaf level
/// of a hierarchical decomposition.
///
/// The block layout mirrors scalar CSR. Block row `b` owns the stored blocks
/// `block_row_ptrs[b]..block_row_ptrs[b + 1]`, and each of those indices selects
/// a dense block in `blocks` and its block column in `block_col_indices`. In this
/// module only `new` constructs a value, and it creates no blocks, so the fields
/// must be filled in directly.
#[derive(Debug, Clone)]
pub struct BlockedCsr<T: ComplexField> {
    /// Number of rows
    pub num_rows: usize,
    /// Number of columns
    pub num_cols: usize,
    /// Block size (rows and columns per block); blocks on the last block row/column
    /// cover only the remaining rows/columns
    pub block_size: usize,
    /// Number of block rows (`num_rows` divided by `block_size`, rounded up)
    pub num_block_rows: usize,
    /// Number of block columns (`num_cols` divided by `block_size`, rounded up)
    pub num_block_cols: usize,
    /// Dense blocks stored in CSR-like format, parallel to `block_col_indices`;
    /// each must be at least as large as the (possibly clipped) tile it covers
    pub blocks: Vec<Array2<T>>,
    /// Block column indices
    pub block_col_indices: Vec<usize>,
    /// Block row pointers (length `num_block_rows + 1`)
    pub block_row_ptrs: Vec<usize>,
}
530
531impl<T: ComplexField> BlockedCsr<T> {
532    /// Create a new blocked CSR matrix
533    pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
534        let num_block_rows = num_rows.div_ceil(block_size);
535        let num_block_cols = num_cols.div_ceil(block_size);
536
537        Self {
538            num_rows,
539            num_cols,
540            block_size,
541            num_block_rows,
542            num_block_cols,
543            blocks: Vec::new(),
544            block_col_indices: Vec::new(),
545            block_row_ptrs: vec![0; num_block_rows + 1],
546        }
547    }
548
549    /// Matrix-vector product using blocked structure
550    pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
551        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
552
553        let mut y = Array1::from_elem(self.num_rows, T::zero());
554
555        for block_i in 0..self.num_block_rows {
556            let row_start = block_i * self.block_size;
557            let row_end = (row_start + self.block_size).min(self.num_rows);
558            let local_rows = row_end - row_start;
559
560            for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
561                let block_j = self.block_col_indices[idx];
562                let block = &self.blocks[idx];
563
564                let col_start = block_j * self.block_size;
565                let col_end = (col_start + self.block_size).min(self.num_cols);
566                let local_cols = col_end - col_start;
567
568                // Extract local x
569                let x_local: Array1<T> = Array1::from_iter((col_start..col_end).map(|j| x[j]));
570
571                // Apply block
572                for i in 0..local_rows {
573                    let mut sum = T::zero();
574                    for j in 0..local_cols {
575                        sum += block[[i, j]] * x_local[j];
576                    }
577                    y[row_start + i] += sum;
578                }
579            }
580        }
581
582        y
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0)
            ],
            [
                Complex64::new(0.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 0.0)
            ],
            [
                Complex64::new(4.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(5.0, 0.0)
            ],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(0, 2).re, 2.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
        assert_relative_eq!(csr.get(2, 0).re, 4.0);
        assert_relative_eq!(csr.get(2, 2).re, 5.0);
    }

    #[test]
    fn test_csr_matvec() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let y = csr.matvec(&x);

        // [1 2] * [1]   [5]
        // [3 4]   [2] = [11]
        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 2, Complex64::new(2.0, 0.0)),
            (1, 1, Complex64::new(3.0, 0.0)),
            (2, 0, Complex64::new(4.0, 0.0)),
            (2, 2, Complex64::new(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 0, Complex64::new(2.0, 0.0)), // Duplicate!
            (1, 1, Complex64::new(3.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        assert_relative_eq!(csr.get(0, 0).re, 3.0); // 1 + 2 = 3
    }

    #[test]
    fn test_csr_triplets_unsorted_with_empty_rows() {
        // Rows 1 and 3 have no entries; input is deliberately out of order.
        let triplets = vec![
            (2, 0, Complex64::new(4.0, 0.0)),
            (0, 1, Complex64::new(1.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(4, 3, triplets);

        assert_eq!(csr.row_ptrs, vec![0, 1, 1, 2, 2]);
        assert_eq!(csr.col_indices, vec![1, 0]);
        assert!(csr.row_range(1).is_empty());
        assert!(csr.row_range(3).is_empty());
        assert_relative_eq!(csr.get(0, 1).re, 1.0);
        assert_relative_eq!(csr.get(2, 0).re, 4.0);
    }

    #[test]
    fn test_csr_triplets_empty() {
        let csr: CsrMatrix<Complex64> = CsrMatrix::from_triplets(2, 3, Vec::new());

        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.row_ptrs, vec![0, 0, 0]);
    }

    #[test]
    fn test_csr_identity() {
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        assert_relative_eq!(id.get(0, 0).re, 1.0);
        assert_relative_eq!(id.get(1, 1).re, 1.0);
        assert_relative_eq!(id.get(2, 2).re, 1.0);
        assert_relative_eq!(id.get(0, 1).norm(), 0.0);
    }

    #[test]
    fn test_csr_builder() {
        let mut builder: CsrBuilder<Complex64> = CsrBuilder::new(3, 3);

        builder.add_row_entries(
            [(0, Complex64::new(1.0, 0.0)), (2, Complex64::new(2.0, 0.0))].into_iter(),
        );
        builder.add_row_entries([(1, Complex64::new(3.0, 0.0))].into_iter());
        builder.add_row_entries(
            [(0, Complex64::new(4.0, 0.0)), (2, Complex64::new(5.0, 0.0))].into_iter(),
        );

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![
            [Complex64::new(1.0, 0.5), Complex64::new(0.0, 0.0)],
            [Complex64::new(2.0, -1.0), Complex64::new(3.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&original, 1e-15);
        let recovered = csr.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(
                    (original[[i, j]] - recovered[[i, j]]).norm(),
                    0.0,
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_csr_transpose_and_hermitian() {
        // A = [[1+i, 0], [2, 3i]]
        let dense = array![
            [Complex64::new(1.0, 1.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(2.0, 0.0), Complex64::new(0.0, 3.0)],
        ];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)];

        // A^T x = [3+i, 3i]
        let yt = csr.matvec_transpose(&x);
        assert_relative_eq!(yt[0].re, 3.0, epsilon = 1e-10);
        assert_relative_eq!(yt[0].im, 1.0, epsilon = 1e-10);
        assert_relative_eq!(yt[1].re, 0.0, epsilon = 1e-10);
        assert_relative_eq!(yt[1].im, 3.0, epsilon = 1e-10);

        // A^H x = [3-i, -3i]
        let yh = csr.matvec_hermitian(&x);
        assert_relative_eq!(yh[0].re, 3.0, epsilon = 1e-10);
        assert_relative_eq!(yh[0].im, -1.0, epsilon = 1e-10);
        assert_relative_eq!(yh[1].re, 0.0, epsilon = 1e-10);
        assert_relative_eq!(yh[1].im, -3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_csr_matvec_add_accumulates() {
        let dense = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![1.0_f64, 2.0];
        let mut y = array![1.0_f64, 1.0];

        csr.matvec_add(&x, &mut y);

        assert_relative_eq!(y[0], 6.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 12.0, epsilon = 1e-10);
    }

    #[test]
    fn test_csr_from_raw_parts() {
        // [[1, 2], [0, 4]]
        let csr = CsrMatrix::from_raw_parts(2, 2, vec![0, 2, 3], vec![0, 1, 1], vec![1.0_f64, 2.0, 4.0]);
        let y = csr.matvec(&array![1.0_f64, 1.0]);

        assert_eq!(csr.nnz(), 3);
        assert_relative_eq!(y[0], 3.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_csr_add_diagonal_and_sparsity() {
        let mut id = CsrMatrix::<f64>::identity(4);
        assert_relative_eq!(id.sparsity(), 0.25, epsilon = 1e-12);

        id.add_diagonal(2.0);
        for d in id.diagonal().iter() {
            assert_relative_eq!(*d, 3.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_linear_operator_impl() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        // Test via LinearOperator trait
        let y = csr.apply(&x);
        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);

        assert!(csr.is_square());
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_cols(), 2);
    }

    #[test]
    fn test_f64_csr() {
        let dense = array![[1.0_f64, 2.0], [3.0, 4.0],];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![1.0_f64, 2.0];

        let y = csr.matvec(&x);
        assert_relative_eq!(y[0], 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 11.0, epsilon = 1e-10);
    }
}