bem/core/assembly/
sparse.rs

1//! Sparse matrix structures (CSR format)
2//!
3//! This module provides Compressed Sparse Row (CSR) format for efficient
4//! storage and matrix-vector operations with sparse matrices.
5//!
6//! CSR format stores:
7//! - `values`: Non-zero entries in row-major order
8//! - `col_indices`: Column index for each value
9//! - `row_ptrs`: Index into values/col_indices where each row starts
10//!
11//! For BEM, the near-field matrix is sparse (only nearby element interactions),
12//! while far-field is handled via FMM factorization.
13
14use ndarray::Array1;
15use num_complex::Complex64;
16use std::ops::Range;
17
18/// Compressed Sparse Row (CSR) matrix format
19///
20/// Memory-efficient storage for sparse matrices with O(nnz) space complexity.
21/// Matrix-vector products are O(nnz) instead of O(n²) for dense matrices.
22#[derive(Debug, Clone)]
23pub struct CsrMatrix {
24    /// Number of rows
25    pub num_rows: usize,
26    /// Number of columns
27    pub num_cols: usize,
28    /// Non-zero values in row-major order
29    pub values: Vec<Complex64>,
30    /// Column indices for each value
31    pub col_indices: Vec<usize>,
32    /// Row pointers: row_ptrs[i] is the start index in values/col_indices for row i
33    /// row_ptrs[num_rows] = nnz (total number of non-zeros)
34    pub row_ptrs: Vec<usize>,
35}
36
37impl CsrMatrix {
38    /// Create a new empty CSR matrix
39    pub fn new(num_rows: usize, num_cols: usize) -> Self {
40        Self {
41            num_rows,
42            num_cols,
43            values: Vec::new(),
44            col_indices: Vec::new(),
45            row_ptrs: vec![0; num_rows + 1],
46        }
47    }
48
49    /// Create a CSR matrix with pre-allocated capacity
50    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
51        Self {
52            num_rows,
53            num_cols,
54            values: Vec::with_capacity(nnz_estimate),
55            col_indices: Vec::with_capacity(nnz_estimate),
56            row_ptrs: vec![0; num_rows + 1],
57        }
58    }
59
60    /// Create a CSR matrix from a dense matrix
61    ///
62    /// Only stores entries with magnitude > threshold
63    pub fn from_dense(
64        dense: &ndarray::Array2<Complex64>,
65        threshold: f64,
66    ) -> Self {
67        let num_rows = dense.nrows();
68        let num_cols = dense.ncols();
69
70        let mut values = Vec::new();
71        let mut col_indices = Vec::new();
72        let mut row_ptrs = vec![0usize; num_rows + 1];
73
74        for i in 0..num_rows {
75            for j in 0..num_cols {
76                let val = dense[[i, j]];
77                if val.norm() > threshold {
78                    values.push(val);
79                    col_indices.push(j);
80                }
81            }
82            row_ptrs[i + 1] = values.len();
83        }
84
85        Self {
86            num_rows,
87            num_cols,
88            values,
89            col_indices,
90            row_ptrs,
91        }
92    }
93
94    /// Create a CSR matrix from COO (Coordinate) format triplets
95    ///
96    /// Triplets are (row, col, value). Duplicate entries are summed.
97    pub fn from_triplets(
98        num_rows: usize,
99        num_cols: usize,
100        mut triplets: Vec<(usize, usize, Complex64)>,
101    ) -> Self {
102        if triplets.is_empty() {
103            return Self::new(num_rows, num_cols);
104        }
105
106        // Sort by row, then by column
107        triplets.sort_by(|a, b| {
108            if a.0 != b.0 {
109                a.0.cmp(&b.0)
110            } else {
111                a.1.cmp(&b.1)
112            }
113        });
114
115        let mut values = Vec::with_capacity(triplets.len());
116        let mut col_indices = Vec::with_capacity(triplets.len());
117        let mut row_ptrs = vec![0usize; num_rows + 1];
118
119        let mut prev_row = usize::MAX;
120        let mut prev_col = usize::MAX;
121
122        for (row, col, val) in triplets {
123            if row == prev_row && col == prev_col {
124                // Same entry, accumulate
125                if let Some(last) = values.last_mut() {
126                    *last += val;
127                }
128            } else {
129                // New entry - push it
130                values.push(val);
131                col_indices.push(col);
132
133                // Update row pointers for any rows we skipped
134                if row != prev_row {
135                    let start = if prev_row == usize::MAX { 0 } else { prev_row + 1 };
136                    for r in start..=row {
137                        row_ptrs[r] = values.len() - 1;
138                    }
139                }
140
141                prev_row = row;
142                prev_col = col;
143            }
144        }
145
146        // Fill remaining row pointers
147        let last_row = if prev_row == usize::MAX { 0 } else { prev_row + 1 };
148        for r in last_row..=num_rows {
149            row_ptrs[r] = values.len();
150        }
151
152        Self {
153            num_rows,
154            num_cols,
155            values,
156            col_indices,
157            row_ptrs,
158        }
159    }
160
161    /// Number of non-zero entries
162    pub fn nnz(&self) -> usize {
163        self.values.len()
164    }
165
166    /// Sparsity ratio (fraction of non-zero entries)
167    pub fn sparsity(&self) -> f64 {
168        let total = self.num_rows * self.num_cols;
169        if total == 0 {
170            0.0
171        } else {
172            self.nnz() as f64 / total as f64
173        }
174    }
175
176    /// Get the range of indices in values/col_indices for a given row
177    pub fn row_range(&self, row: usize) -> Range<usize> {
178        self.row_ptrs[row]..self.row_ptrs[row + 1]
179    }
180
181    /// Get the (col, value) pairs for a row
182    pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, Complex64)> + '_ {
183        let range = self.row_range(row);
184        self.col_indices[range.clone()]
185            .iter()
186            .copied()
187            .zip(self.values[range].iter().copied())
188    }
189
190    /// Matrix-vector product: y = A * x
191    pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
192        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
193
194        let mut y = Array1::zeros(self.num_rows);
195
196        for i in 0..self.num_rows {
197            let mut sum = Complex64::new(0.0, 0.0);
198            for idx in self.row_range(i) {
199                let j = self.col_indices[idx];
200                sum += self.values[idx] * x[j];
201            }
202            y[i] = sum;
203        }
204
205        y
206    }
207
208    /// Matrix-vector product with accumulation: y += A * x
209    pub fn matvec_add(&self, x: &Array1<Complex64>, y: &mut Array1<Complex64>) {
210        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
211        assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
212
213        for i in 0..self.num_rows {
214            for idx in self.row_range(i) {
215                let j = self.col_indices[idx];
216                y[i] += self.values[idx] * x[j];
217            }
218        }
219    }
220
221    /// Transpose matrix-vector product: y = A^T * x
222    pub fn matvec_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
223        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
224
225        let mut y = Array1::zeros(self.num_cols);
226
227        for i in 0..self.num_rows {
228            for idx in self.row_range(i) {
229                let j = self.col_indices[idx];
230                y[j] += self.values[idx] * x[i];
231            }
232        }
233
234        y
235    }
236
237    /// Hermitian (conjugate transpose) matrix-vector product: y = A^H * x
238    pub fn matvec_hermitian(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
239        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
240
241        let mut y = Array1::zeros(self.num_cols);
242
243        for i in 0..self.num_rows {
244            for idx in self.row_range(i) {
245                let j = self.col_indices[idx];
246                y[j] += self.values[idx].conj() * x[i];
247            }
248        }
249
250        y
251    }
252
253    /// Get element at (i, j), returns 0 if not stored
254    pub fn get(&self, i: usize, j: usize) -> Complex64 {
255        for idx in self.row_range(i) {
256            if self.col_indices[idx] == j {
257                return self.values[idx];
258            }
259        }
260        Complex64::new(0.0, 0.0)
261    }
262
263    /// Extract diagonal elements
264    pub fn diagonal(&self) -> Array1<Complex64> {
265        let n = self.num_rows.min(self.num_cols);
266        let mut diag = Array1::zeros(n);
267
268        for i in 0..n {
269            diag[i] = self.get(i, i);
270        }
271
272        diag
273    }
274
275    /// Scale all values by a scalar
276    pub fn scale(&mut self, scalar: Complex64) {
277        for val in &mut self.values {
278            *val *= scalar;
279        }
280    }
281
282    /// Add a scalar to the diagonal
283    pub fn add_diagonal(&mut self, scalar: Complex64) {
284        let n = self.num_rows.min(self.num_cols);
285
286        for i in 0..n {
287            for idx in self.row_range(i) {
288                if self.col_indices[idx] == i {
289                    self.values[idx] += scalar;
290                    break;
291                }
292            }
293        }
294    }
295
296    /// Create identity matrix in CSR format
297    pub fn identity(n: usize) -> Self {
298        Self {
299            num_rows: n,
300            num_cols: n,
301            values: vec![Complex64::new(1.0, 0.0); n],
302            col_indices: (0..n).collect(),
303            row_ptrs: (0..=n).collect(),
304        }
305    }
306
307    /// Create diagonal matrix from vector
308    pub fn from_diagonal(diag: &Array1<Complex64>) -> Self {
309        let n = diag.len();
310        Self {
311            num_rows: n,
312            num_cols: n,
313            values: diag.to_vec(),
314            col_indices: (0..n).collect(),
315            row_ptrs: (0..=n).collect(),
316        }
317    }
318
319    /// Convert to dense matrix (for debugging/small matrices)
320    pub fn to_dense(&self) -> ndarray::Array2<Complex64> {
321        let mut dense = ndarray::Array2::zeros((self.num_rows, self.num_cols));
322
323        for i in 0..self.num_rows {
324            for idx in self.row_range(i) {
325                let j = self.col_indices[idx];
326                dense[[i, j]] = self.values[idx];
327            }
328        }
329
330        dense
331    }
332}
333
334/// Builder for constructing CSR matrices row by row
335pub struct CsrBuilder {
336    num_rows: usize,
337    num_cols: usize,
338    values: Vec<Complex64>,
339    col_indices: Vec<usize>,
340    row_ptrs: Vec<usize>,
341    current_row: usize,
342}
343
344impl CsrBuilder {
345    /// Create a new CSR builder
346    pub fn new(num_rows: usize, num_cols: usize) -> Self {
347        Self {
348            num_rows,
349            num_cols,
350            values: Vec::new(),
351            col_indices: Vec::new(),
352            row_ptrs: vec![0],
353            current_row: 0,
354        }
355    }
356
357    /// Create a new CSR builder with estimated non-zeros
358    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
359        Self {
360            num_rows,
361            num_cols,
362            values: Vec::with_capacity(nnz_estimate),
363            col_indices: Vec::with_capacity(nnz_estimate),
364            row_ptrs: Vec::with_capacity(num_rows + 1),
365            current_row: 0,
366        }
367    }
368
369    /// Add entries for the current row (must be added in column order)
370    pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, Complex64)>) {
371        for (col, val) in entries {
372            if val.norm() > 0.0 {
373                self.values.push(val);
374                self.col_indices.push(col);
375            }
376        }
377        self.row_ptrs.push(self.values.len());
378        self.current_row += 1;
379    }
380
381    /// Finish building and return the CSR matrix
382    pub fn finish(mut self) -> CsrMatrix {
383        // Fill remaining rows if not all rows were added
384        while self.current_row < self.num_rows {
385            self.row_ptrs.push(self.values.len());
386            self.current_row += 1;
387        }
388
389        CsrMatrix {
390            num_rows: self.num_rows,
391            num_cols: self.num_cols,
392            values: self.values,
393            col_indices: self.col_indices,
394            row_ptrs: self.row_ptrs,
395        }
396    }
397}
398
399/// Blocked CSR format for hierarchical matrices
400///
401/// Stores the matrix as a collection of dense blocks at the leaf level
402/// of a hierarchical decomposition.
403#[derive(Debug, Clone)]
404pub struct BlockedCsr {
405    /// Number of rows
406    pub num_rows: usize,
407    /// Number of columns
408    pub num_cols: usize,
409    /// Block size (rows and columns per block)
410    pub block_size: usize,
411    /// Number of block rows
412    pub num_block_rows: usize,
413    /// Number of block columns
414    pub num_block_cols: usize,
415    /// Dense blocks stored in CSR-like format
416    /// Each block is a dense matrix
417    pub blocks: Vec<ndarray::Array2<Complex64>>,
418    /// Block column indices
419    pub block_col_indices: Vec<usize>,
420    /// Block row pointers
421    pub block_row_ptrs: Vec<usize>,
422}
423
424impl BlockedCsr {
425    /// Create a new blocked CSR matrix
426    pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
427        let num_block_rows = (num_rows + block_size - 1) / block_size;
428        let num_block_cols = (num_cols + block_size - 1) / block_size;
429
430        Self {
431            num_rows,
432            num_cols,
433            block_size,
434            num_block_rows,
435            num_block_cols,
436            blocks: Vec::new(),
437            block_col_indices: Vec::new(),
438            block_row_ptrs: vec![0; num_block_rows + 1],
439        }
440    }
441
442    /// Matrix-vector product using blocked structure
443    pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
444        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
445
446        let mut y = Array1::zeros(self.num_rows);
447
448        for block_i in 0..self.num_block_rows {
449            let row_start = block_i * self.block_size;
450            let row_end = (row_start + self.block_size).min(self.num_rows);
451            let local_rows = row_end - row_start;
452
453            for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
454                let block_j = self.block_col_indices[idx];
455                let block = &self.blocks[idx];
456
457                let col_start = block_j * self.block_size;
458                let col_end = (col_start + self.block_size).min(self.num_cols);
459                let local_cols = col_end - col_start;
460
461                // Extract local x
462                let x_local: Array1<Complex64> =
463                    Array1::from_iter((col_start..col_end).map(|j| x[j]));
464
465                // Apply block
466                for i in 0..local_rows {
467                    let mut sum = Complex64::new(0.0, 0.0);
468                    for j in 0..local_cols {
469                        sum += block[[i, j]] * x_local[j];
470                    }
471                    y[row_start + i] += sum;
472                }
473            }
474        }
475
476        y
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Build a complex scalar from its real and imaginary parts.
    fn cx(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Assert that `actual` lies within `1e-10` of `expected`.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!((actual - expected).norm() < 1e-10);
    }

    #[test]
    fn test_csr_from_dense() {
        let zero = cx(0.0, 0.0);
        let dense = array![
            [cx(1.0, 0.0), zero, cx(2.0, 0.0)],
            [zero, cx(3.0, 0.0), zero],
            [cx(4.0, 0.0), zero, cx(5.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!((csr.num_rows, csr.num_cols, csr.nnz()), (3, 3, 5));

        // Every stored entry reads back unchanged...
        let stored = [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 0, 4.0), (2, 2, 5.0)];
        for &(i, j, re) in stored.iter() {
            assert_eq!(csr.get(i, j), cx(re, 0.0));
        }

        // ...and positions that were zero read back as zero.
        for &(i, j) in [(0, 1), (1, 0)].iter() {
            assert_eq!(csr.get(i, j), zero);
        }
    }

    #[test]
    fn test_csr_matvec() {
        let a = array![[cx(1.0, 0.0), cx(2.0, 0.0)], [cx(3.0, 0.0), cx(4.0, 0.0)]];
        let x = array![cx(1.0, 0.0), cx(2.0, 0.0)];

        let y = CsrMatrix::from_dense(&a, 1e-15).matvec(&x);

        // [1 2] [1]   [ 5]
        // [3 4] [2] = [11]
        assert_close(y[0], cx(5.0, 0.0));
        assert_close(y[1], cx(11.0, 0.0));
    }

    #[test]
    fn test_csr_from_triplets() {
        let entries: [(usize, usize, f64); 5] =
            [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 0, 4.0), (2, 2, 5.0)];
        let triplets: Vec<(usize, usize, Complex64)> = entries
            .iter()
            .map(|&(i, j, re)| (i, j, cx(re, 0.0)))
            .collect();

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), cx(1.0, 0.0));
        assert_eq!(csr.get(1, 1), cx(3.0, 0.0));
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        let csr = CsrMatrix::from_triplets(
            2,
            2,
            vec![
                (0, 0, cx(1.0, 0.0)),
                (0, 0, cx(2.0, 0.0)),
                (1, 1, cx(3.0, 0.0)),
            ],
        );

        // Both (0, 0) contributions must be summed: 1 + 2 = 3.
        assert_eq!(csr.get(0, 0), cx(3.0, 0.0));
    }

    #[test]
    fn test_csr_identity() {
        let eye = CsrMatrix::identity(3);

        assert_eq!(eye.nnz(), 3);
        for k in 0..3 {
            assert_eq!(eye.get(k, k), cx(1.0, 0.0));
        }
        assert_eq!(eye.get(0, 1), cx(0.0, 0.0));
    }

    #[test]
    fn test_csr_builder() {
        // Row 0: cols 0 and 2; row 1: col 1; row 2: cols 0 and 2.
        let rows: [&[(usize, f64)]; 3] = [&[(0, 1.0), (2, 2.0)], &[(1, 3.0)], &[(0, 4.0), (2, 5.0)]];

        let mut builder = CsrBuilder::new(3, 3);
        for row in rows.iter() {
            builder.add_row_entries(row.iter().map(|&(col, re)| (col, cx(re, 0.0))));
        }
        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), cx(1.0, 0.0));
        assert_eq!(csr.get(1, 1), cx(3.0, 0.0));
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![
            [cx(1.0, 0.5), cx(0.0, 0.0)],
            [cx(2.0, -1.0), cx(3.0, 0.0)],
        ];

        let recovered = CsrMatrix::from_dense(&original, 1e-15).to_dense();

        assert_eq!(recovered.dim(), original.dim());
        for (expected, actual) in original.iter().zip(recovered.iter()) {
            assert_close(*actual, *expected);
        }
    }

    #[test]
    fn test_csr_transpose_matvec() {
        let a = array![
            [cx(1.0, 0.0), cx(2.0, 0.0)],
            [cx(3.0, 0.0), cx(4.0, 0.0)],
            [cx(5.0, 0.0), cx(6.0, 0.0)],
        ];
        let x = array![cx(1.0, 0.0), cx(2.0, 0.0), cx(3.0, 0.0)];

        let y = CsrMatrix::from_dense(&a, 1e-15).matvec_transpose(&x);

        // A^T x = [1 3 5; 2 4 6] * [1 2 3]^T = [22 28]^T
        assert_close(y[0], cx(22.0, 0.0));
        assert_close(y[1], cx(28.0, 0.0));
    }
}