math_audio_bem/core/assembly/
sparse.rs

1//! Sparse matrix structures (CSR format)
2//!
//! This module provides the Compressed Sparse Row (CSR) format for efficient
//! storage of sparse matrices and fast matrix-vector products.
//!
//! CSR format stores:
//! - `values`: Non-zero entries in row-major order
//! - `col_indices`: Column index for each value
//! - `row_ptrs`: Offsets into `values`/`col_indices` where each row starts
//!   (length `num_rows + 1`; the final entry equals the number of non-zeros)
10//!
11//! For BEM, the near-field matrix is sparse (only nearby element interactions),
12//! while far-field is handled via FMM factorization.
13
14use ndarray::Array1;
15use num_complex::Complex64;
16use std::ops::Range;
17
18/// Compressed Sparse Row (CSR) matrix format
19///
20/// Memory-efficient storage for sparse matrices with O(nnz) space complexity.
21/// Matrix-vector products are O(nnz) instead of O(n²) for dense matrices.
22#[derive(Debug, Clone)]
23pub struct CsrMatrix {
24    /// Number of rows
25    pub num_rows: usize,
26    /// Number of columns
27    pub num_cols: usize,
28    /// Non-zero values in row-major order
29    pub values: Vec<Complex64>,
30    /// Column indices for each value
31    pub col_indices: Vec<usize>,
32    /// Row pointers: row_ptrs[i] is the start index in values/col_indices for row i
33    /// row_ptrs[num_rows] = nnz (total number of non-zeros)
34    pub row_ptrs: Vec<usize>,
35}
36
37impl CsrMatrix {
38    /// Create a new empty CSR matrix
39    pub fn new(num_rows: usize, num_cols: usize) -> Self {
40        Self {
41            num_rows,
42            num_cols,
43            values: Vec::new(),
44            col_indices: Vec::new(),
45            row_ptrs: vec![0; num_rows + 1],
46        }
47    }
48
49    /// Create a CSR matrix with pre-allocated capacity
50    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
51        Self {
52            num_rows,
53            num_cols,
54            values: Vec::with_capacity(nnz_estimate),
55            col_indices: Vec::with_capacity(nnz_estimate),
56            row_ptrs: vec![0; num_rows + 1],
57        }
58    }
59
60    /// Create a CSR matrix from a dense matrix
61    ///
62    /// Only stores entries with magnitude > threshold
63    pub fn from_dense(dense: &ndarray::Array2<Complex64>, threshold: f64) -> Self {
64        let num_rows = dense.nrows();
65        let num_cols = dense.ncols();
66
67        let mut values = Vec::new();
68        let mut col_indices = Vec::new();
69        let mut row_ptrs = vec![0usize; num_rows + 1];
70
71        for i in 0..num_rows {
72            for j in 0..num_cols {
73                let val = dense[[i, j]];
74                if val.norm() > threshold {
75                    values.push(val);
76                    col_indices.push(j);
77                }
78            }
79            row_ptrs[i + 1] = values.len();
80        }
81
82        Self {
83            num_rows,
84            num_cols,
85            values,
86            col_indices,
87            row_ptrs,
88        }
89    }
90
91    /// Create a CSR matrix from COO (Coordinate) format triplets
92    ///
93    /// Triplets are (row, col, value). Duplicate entries are summed.
94    pub fn from_triplets(
95        num_rows: usize,
96        num_cols: usize,
97        mut triplets: Vec<(usize, usize, Complex64)>,
98    ) -> Self {
99        if triplets.is_empty() {
100            return Self::new(num_rows, num_cols);
101        }
102
103        // Sort by row, then by column
104        triplets.sort_by(|a, b| {
105            if a.0 != b.0 {
106                a.0.cmp(&b.0)
107            } else {
108                a.1.cmp(&b.1)
109            }
110        });
111
112        let mut values = Vec::with_capacity(triplets.len());
113        let mut col_indices = Vec::with_capacity(triplets.len());
114        let mut row_ptrs = vec![0usize; num_rows + 1];
115
116        let mut prev_row = usize::MAX;
117        let mut prev_col = usize::MAX;
118
119        for (row, col, val) in triplets {
120            if row == prev_row && col == prev_col {
121                // Same entry, accumulate
122                if let Some(last) = values.last_mut() {
123                    *last += val;
124                }
125            } else {
126                // New entry - push it
127                values.push(val);
128                col_indices.push(col);
129
130                // Update row pointers for any rows we skipped
131                if row != prev_row {
132                    let start = if prev_row == usize::MAX {
133                        0
134                    } else {
135                        prev_row + 1
136                    };
137                    for item in row_ptrs.iter_mut().take(row + 1).skip(start) {
138                        *item = values.len() - 1;
139                    }
140                }
141
142                prev_row = row;
143                prev_col = col;
144            }
145        }
146
147        // Fill remaining row pointers
148        let last_row = if prev_row == usize::MAX {
149            0
150        } else {
151            prev_row + 1
152        };
153        for item in row_ptrs.iter_mut().take(num_rows + 1).skip(last_row) {
154            *item = values.len();
155        }
156
157        Self {
158            num_rows,
159            num_cols,
160            values,
161            col_indices,
162            row_ptrs,
163        }
164    }
165
166    /// Number of non-zero entries
167    pub fn nnz(&self) -> usize {
168        self.values.len()
169    }
170
171    /// Sparsity ratio (fraction of non-zero entries)
172    pub fn sparsity(&self) -> f64 {
173        let total = self.num_rows * self.num_cols;
174        if total == 0 {
175            0.0
176        } else {
177            self.nnz() as f64 / total as f64
178        }
179    }
180
181    /// Get the range of indices in values/col_indices for a given row
182    pub fn row_range(&self, row: usize) -> Range<usize> {
183        self.row_ptrs[row]..self.row_ptrs[row + 1]
184    }
185
186    /// Get the (col, value) pairs for a row
187    pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, Complex64)> + '_ {
188        let range = self.row_range(row);
189        self.col_indices[range.clone()]
190            .iter()
191            .copied()
192            .zip(self.values[range].iter().copied())
193    }
194
195    /// Matrix-vector product: y = A * x
196    pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
197        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
198
199        let mut y = Array1::zeros(self.num_rows);
200
201        for i in 0..self.num_rows {
202            let mut sum = Complex64::new(0.0, 0.0);
203            for idx in self.row_range(i) {
204                let j = self.col_indices[idx];
205                sum += self.values[idx] * x[j];
206            }
207            y[i] = sum;
208        }
209
210        y
211    }
212
213    /// Matrix-vector product with accumulation: y += A * x
214    pub fn matvec_add(&self, x: &Array1<Complex64>, y: &mut Array1<Complex64>) {
215        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
216        assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
217
218        for i in 0..self.num_rows {
219            for idx in self.row_range(i) {
220                let j = self.col_indices[idx];
221                y[i] += self.values[idx] * x[j];
222            }
223        }
224    }
225
226    /// Transpose matrix-vector product: y = A^T * x
227    pub fn matvec_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
228        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
229
230        let mut y = Array1::zeros(self.num_cols);
231
232        for i in 0..self.num_rows {
233            for idx in self.row_range(i) {
234                let j = self.col_indices[idx];
235                y[j] += self.values[idx] * x[i];
236            }
237        }
238
239        y
240    }
241
242    /// Hermitian (conjugate transpose) matrix-vector product: y = A^H * x
243    pub fn matvec_hermitian(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
244        assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
245
246        let mut y = Array1::zeros(self.num_cols);
247
248        for i in 0..self.num_rows {
249            for idx in self.row_range(i) {
250                let j = self.col_indices[idx];
251                y[j] += self.values[idx].conj() * x[i];
252            }
253        }
254
255        y
256    }
257
258    /// Get element at (i, j), returns 0 if not stored
259    pub fn get(&self, i: usize, j: usize) -> Complex64 {
260        for idx in self.row_range(i) {
261            if self.col_indices[idx] == j {
262                return self.values[idx];
263            }
264        }
265        Complex64::new(0.0, 0.0)
266    }
267
268    /// Extract diagonal elements
269    pub fn diagonal(&self) -> Array1<Complex64> {
270        let n = self.num_rows.min(self.num_cols);
271        let mut diag = Array1::zeros(n);
272
273        for i in 0..n {
274            diag[i] = self.get(i, i);
275        }
276
277        diag
278    }
279
280    /// Scale all values by a scalar
281    pub fn scale(&mut self, scalar: Complex64) {
282        for val in &mut self.values {
283            *val *= scalar;
284        }
285    }
286
287    /// Add a scalar to the diagonal
288    pub fn add_diagonal(&mut self, scalar: Complex64) {
289        let n = self.num_rows.min(self.num_cols);
290
291        for i in 0..n {
292            for idx in self.row_range(i) {
293                if self.col_indices[idx] == i {
294                    self.values[idx] += scalar;
295                    break;
296                }
297            }
298        }
299    }
300
301    /// Create identity matrix in CSR format
302    pub fn identity(n: usize) -> Self {
303        Self {
304            num_rows: n,
305            num_cols: n,
306            values: vec![Complex64::new(1.0, 0.0); n],
307            col_indices: (0..n).collect(),
308            row_ptrs: (0..=n).collect(),
309        }
310    }
311
312    /// Create diagonal matrix from vector
313    pub fn from_diagonal(diag: &Array1<Complex64>) -> Self {
314        let n = diag.len();
315        Self {
316            num_rows: n,
317            num_cols: n,
318            values: diag.to_vec(),
319            col_indices: (0..n).collect(),
320            row_ptrs: (0..=n).collect(),
321        }
322    }
323
324    /// Convert to dense matrix (for debugging/small matrices)
325    pub fn to_dense(&self) -> ndarray::Array2<Complex64> {
326        let mut dense = ndarray::Array2::zeros((self.num_rows, self.num_cols));
327
328        for i in 0..self.num_rows {
329            for idx in self.row_range(i) {
330                let j = self.col_indices[idx];
331                dense[[i, j]] = self.values[idx];
332            }
333        }
334
335        dense
336    }
337}
338
339/// Builder for constructing CSR matrices row by row
340pub struct CsrBuilder {
341    num_rows: usize,
342    num_cols: usize,
343    values: Vec<Complex64>,
344    col_indices: Vec<usize>,
345    row_ptrs: Vec<usize>,
346    current_row: usize,
347}
348
349impl CsrBuilder {
350    /// Create a new CSR builder
351    pub fn new(num_rows: usize, num_cols: usize) -> Self {
352        Self {
353            num_rows,
354            num_cols,
355            values: Vec::new(),
356            col_indices: Vec::new(),
357            row_ptrs: vec![0],
358            current_row: 0,
359        }
360    }
361
362    /// Create a new CSR builder with estimated non-zeros
363    pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
364        Self {
365            num_rows,
366            num_cols,
367            values: Vec::with_capacity(nnz_estimate),
368            col_indices: Vec::with_capacity(nnz_estimate),
369            row_ptrs: Vec::with_capacity(num_rows + 1),
370            current_row: 0,
371        }
372    }
373
374    /// Add entries for the current row (must be added in column order)
375    pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, Complex64)>) {
376        for (col, val) in entries {
377            if val.norm() > 0.0 {
378                self.values.push(val);
379                self.col_indices.push(col);
380            }
381        }
382        self.row_ptrs.push(self.values.len());
383        self.current_row += 1;
384    }
385
386    /// Finish building and return the CSR matrix
387    pub fn finish(mut self) -> CsrMatrix {
388        // Fill remaining rows if not all rows were added
389        while self.current_row < self.num_rows {
390            self.row_ptrs.push(self.values.len());
391            self.current_row += 1;
392        }
393
394        CsrMatrix {
395            num_rows: self.num_rows,
396            num_cols: self.num_cols,
397            values: self.values,
398            col_indices: self.col_indices,
399            row_ptrs: self.row_ptrs,
400        }
401    }
402}
403
404/// Blocked CSR format for hierarchical matrices
405///
406/// Stores the matrix as a collection of dense blocks at the leaf level
407/// of a hierarchical decomposition.
408#[derive(Debug, Clone)]
409pub struct BlockedCsr {
410    /// Number of rows
411    pub num_rows: usize,
412    /// Number of columns
413    pub num_cols: usize,
414    /// Block size (rows and columns per block)
415    pub block_size: usize,
416    /// Number of block rows
417    pub num_block_rows: usize,
418    /// Number of block columns
419    pub num_block_cols: usize,
420    /// Dense blocks stored in CSR-like format
421    /// Each block is a dense matrix
422    pub blocks: Vec<ndarray::Array2<Complex64>>,
423    /// Block column indices
424    pub block_col_indices: Vec<usize>,
425    /// Block row pointers
426    pub block_row_ptrs: Vec<usize>,
427}
428
429impl BlockedCsr {
430    /// Create a new blocked CSR matrix
431    pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
432        let num_block_rows = num_rows.div_ceil(block_size);
433        let num_block_cols = num_cols.div_ceil(block_size);
434
435        Self {
436            num_rows,
437            num_cols,
438            block_size,
439            num_block_rows,
440            num_block_cols,
441            blocks: Vec::new(),
442            block_col_indices: Vec::new(),
443            block_row_ptrs: vec![0; num_block_rows + 1],
444        }
445    }
446
447    /// Matrix-vector product using blocked structure
448    pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
449        assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
450
451        let mut y = Array1::zeros(self.num_rows);
452
453        for block_i in 0..self.num_block_rows {
454            let row_start = block_i * self.block_size;
455            let row_end = (row_start + self.block_size).min(self.num_rows);
456            let local_rows = row_end - row_start;
457
458            for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
459                let block_j = self.block_col_indices[idx];
460                let block = &self.blocks[idx];
461
462                let col_start = block_j * self.block_size;
463                let col_end = (col_start + self.block_size).min(self.num_cols);
464                let local_cols = col_end - col_start;
465
466                // Extract local x
467                let x_local: Array1<Complex64> =
468                    Array1::from_iter((col_start..col_end).map(|j| x[j]));
469
470                // Apply block
471                for i in 0..local_rows {
472                    let mut sum = Complex64::new(0.0, 0.0);
473                    for j in 0..local_cols {
474                        sum += block[[i, j]] * x_local[j];
475                    }
476                    y[row_start + i] += sum;
477                }
478            }
479        }
480
481        y
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Shorthand for building complex test values.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0)
            ],
            [
                Complex64::new(0.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 0.0)
            ],
            [
                Complex64::new(4.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(5.0, 0.0)
            ],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        // Check values
        assert_eq!(csr.get(0, 0), Complex64::new(1.0, 0.0));
        assert_eq!(csr.get(0, 2), Complex64::new(2.0, 0.0));
        assert_eq!(csr.get(1, 1), Complex64::new(3.0, 0.0));
        assert_eq!(csr.get(2, 0), Complex64::new(4.0, 0.0));
        assert_eq!(csr.get(2, 2), Complex64::new(5.0, 0.0));

        // Check zeros
        assert_eq!(csr.get(0, 1), Complex64::new(0.0, 0.0));
        assert_eq!(csr.get(1, 0), Complex64::new(0.0, 0.0));
    }

    #[test]
    fn test_csr_matvec() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let y = csr.matvec(&x);

        // [1 2] * [1]   [5]
        // [3 4]   [2] = [11]
        assert!((y[0] - Complex64::new(5.0, 0.0)).norm() < 1e-10);
        assert!((y[1] - Complex64::new(11.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 2, Complex64::new(2.0, 0.0)),
            (1, 1, Complex64::new(3.0, 0.0)),
            (2, 0, Complex64::new(4.0, 0.0)),
            (2, 2, Complex64::new(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), Complex64::new(1.0, 0.0));
        assert_eq!(csr.get(1, 1), Complex64::new(3.0, 0.0));
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        // Test that duplicate entries are summed
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 0, Complex64::new(2.0, 0.0)), // Duplicate!
            (1, 1, Complex64::new(3.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        assert_eq!(csr.get(0, 0), Complex64::new(3.0, 0.0)); // 1 + 2 = 3
    }

    #[test]
    fn test_csr_from_triplets_unsorted_with_empty_rows() {
        // Unsorted input, an empty middle row and an empty trailing row.
        let triplets = vec![(2, 1, c(4.0, 0.0)), (0, 0, c(1.0, 0.0))];

        let csr = CsrMatrix::from_triplets(4, 3, triplets);

        assert_eq!(csr.row_ptrs, vec![0, 1, 1, 2, 2]);
        assert_eq!(csr.col_indices, vec![0, 1]);
        assert_eq!(csr.get(0, 0), c(1.0, 0.0));
        assert_eq!(csr.get(2, 1), c(4.0, 0.0));
        assert!(csr.row_range(1).is_empty());
        assert!(csr.row_range(3).is_empty());
    }

    #[test]
    fn test_csr_identity() {
        let id = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        assert_eq!(id.get(0, 0), Complex64::new(1.0, 0.0));
        assert_eq!(id.get(1, 1), Complex64::new(1.0, 0.0));
        assert_eq!(id.get(2, 2), Complex64::new(1.0, 0.0));
        assert_eq!(id.get(0, 1), Complex64::new(0.0, 0.0));
    }

    #[test]
    fn test_csr_builder() {
        let mut builder = CsrBuilder::new(3, 3);

        // Row 0: entries at columns 0 and 2
        builder.add_row_entries(
            [(0, Complex64::new(1.0, 0.0)), (2, Complex64::new(2.0, 0.0))].into_iter(),
        );

        // Row 1: entry at column 1
        builder.add_row_entries([(1, Complex64::new(3.0, 0.0))].into_iter());

        // Row 2: entries at columns 0 and 2
        builder.add_row_entries(
            [(0, Complex64::new(4.0, 0.0)), (2, Complex64::new(5.0, 0.0))].into_iter(),
        );

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), Complex64::new(1.0, 0.0));
        assert_eq!(csr.get(1, 1), Complex64::new(3.0, 0.0));
    }

    #[test]
    fn test_csr_builder_pads_missing_rows() {
        let mut builder = CsrBuilder::new(3, 2);
        builder.add_row_entries([(1, c(7.0, 0.0))].into_iter());

        let csr = builder.finish();

        assert_eq!(csr.row_ptrs, vec![0, 1, 1, 1]);
        assert_eq!(csr.get(0, 1), c(7.0, 0.0));
        assert_eq!(csr.get(2, 0), c(0.0, 0.0));
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![
            [Complex64::new(1.0, 0.5), Complex64::new(0.0, 0.0)],
            [Complex64::new(2.0, -1.0), Complex64::new(3.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&original, 1e-15);
        let recovered = csr.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert!((original[[i, j]] - recovered[[i, j]]).norm() < 1e-10);
            }
        }
    }

    #[test]
    fn test_csr_transpose_matvec() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
            [Complex64::new(5.0, 0.0), Complex64::new(6.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0)
        ];

        let y = csr.matvec_transpose(&x);

        // A^T * x = [1 3 5] * [1]   [22]
        //           [2 4 6]   [2] = [28]
        //                     [3]
        assert!((y[0] - Complex64::new(22.0, 0.0)).norm() < 1e-10);
        assert!((y[1] - Complex64::new(28.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_csr_hermitian_matvec() {
        let dense = array![[c(1.0, 1.0), c(0.0, 0.0)], [c(2.0, 0.0), c(3.0, -1.0)]];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![c(1.0, 0.0), c(0.0, 1.0)];

        let y = csr.matvec_hermitian(&x);

        // A^H = [1-i  2  ]   A^H * [1] = [1 + i ]
        //       [0    3+i]         [i]   [-1 + 3i]
        assert!((y[0] - c(1.0, 1.0)).norm() < 1e-10);
        assert!((y[1] - c(-1.0, 3.0)).norm() < 1e-10);
    }

    #[test]
    fn test_csr_matvec_add_accumulates() {
        let csr = CsrMatrix::identity(2);
        let x = array![c(1.0, 0.0), c(2.0, 0.0)];
        let mut y = array![c(10.0, 0.0), c(20.0, 0.0)];

        csr.matvec_add(&x, &mut y);

        assert!((y[0] - c(11.0, 0.0)).norm() < 1e-10);
        assert!((y[1] - c(22.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_csr_add_diagonal_existing_entries() {
        let mut csr = CsrMatrix::identity(3);

        csr.add_diagonal(c(2.0, 0.0));

        assert_eq!(csr.nnz(), 3);
        for d in csr.diagonal().iter() {
            assert!((*d - c(3.0, 0.0)).norm() < 1e-10);
        }
    }

    #[test]
    fn test_blocked_csr_matvec_partial_blocks() {
        // 3x3 matrix with block_size 2: a full 2x2 block at (0, 0) and a 1x1 edge
        // block at (1, 1).
        let mut blocked = BlockedCsr::new(3, 3, 2);
        assert_eq!(blocked.num_block_rows, 2);
        assert_eq!(blocked.num_block_cols, 2);

        blocked.blocks = vec![
            array![[c(1.0, 0.0), c(2.0, 0.0)], [c(3.0, 0.0), c(4.0, 0.0)]],
            array![[c(5.0, 0.0)]],
        ];
        blocked.block_col_indices = vec![0, 1];
        blocked.block_row_ptrs = vec![0, 1, 2];

        let x = array![c(1.0, 0.0), c(1.0, 0.0), c(1.0, 0.0)];
        let y = blocked.matvec(&x);

        assert!((y[0] - c(3.0, 0.0)).norm() < 1e-10);
        assert!((y[1] - c(7.0, 0.0)).norm() < 1e-10);
        assert!((y[2] - c(5.0, 0.0)).norm() < 1e-10);
    }
}