// oxiphysics_fem/sparse/types.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use std::collections::HashMap;

7/// Incomplete Cholesky (ICC) preconditioner for symmetric positive definite
8/// sparse matrices.
9///
10/// Computes a lower-triangular factor `L` such that `L L^T ≈ A`, maintaining
11/// the sparsity pattern of the lower triangle of `A`.  This is the sparse
12/// counterpart of the dense incomplete Cholesky; suitable for use as a
13/// preconditioner in PCG.
14#[derive(Debug, Clone)]
15pub struct IccPreconditioner {
16    /// Lower-triangular Cholesky factor stored in CSR format.
17    pub(super) l_values: Vec<f64>,
18    pub(super) row_ptr: Vec<usize>,
19    pub(super) col_indices: Vec<usize>,
20    pub(super) n: usize,
21}
22impl IccPreconditioner {
23    /// Compute the ICC factorization of a symmetric CSR matrix `a`.
24    ///
25    /// Only the lower-triangular pattern (including diagonal) of `a` is used.
26    /// Fill-in is dropped (ICC(0) strategy).
27    pub fn new(a: &CsrMatrix) -> Self {
28        assert_eq!(a.nrows, a.ncols, "ICC requires a square matrix");
29        let n = a.nrows;
30        let mut triplets: Vec<(usize, usize, f64)> = Vec::new();
31        for row in 0..n {
32            let start = a.row_ptr[row];
33            let end = a.row_ptr[row + 1];
34            for idx in start..end {
35                let col = a.col_indices[idx];
36                if col <= row {
37                    triplets.push((row, col, a.values[idx]));
38                }
39            }
40        }
41        let lt = CsrMatrix::from_triplets(n, n, &triplets);
42        let mut l_values = lt.values.clone();
43        let row_ptr = lt.row_ptr.clone();
44        let col_indices = lt.col_indices.clone();
45        for j in 0..n {
46            let j_start = row_ptr[j];
47            let j_end = row_ptr[j + 1];
48            let diag_pos = col_indices[j_start..j_end]
49                .iter()
50                .position(|&c| c == j)
51                .map(|off| j_start + off);
52            if let Some(dp) = diag_pos {
53                let sum_sq: f64 = l_values[j_start..dp].iter().map(|&v| v * v).sum();
54                let diag_val = l_values[dp] - sum_sq;
55                if diag_val <= 0.0 {
56                    l_values[dp] = 1e-30_f64.sqrt();
57                } else {
58                    l_values[dp] = diag_val.sqrt();
59                }
60                let l_jj = l_values[dp];
61                for row_i in (j + 1)..n {
62                    let i_start = row_ptr[row_i];
63                    let i_end = row_ptr[row_i + 1];
64                    let pos_ij = col_indices[i_start..i_end]
65                        .iter()
66                        .position(|&c| c == j)
67                        .map(|off| i_start + off);
68                    if let Some(pij) = pos_ij {
69                        let mut dot = 0.0f64;
70                        let mut pi = i_start;
71                        let mut pj = j_start;
72                        while pi < pij && pj < dp {
73                            let ci = col_indices[pi];
74                            let cj = col_indices[pj];
75                            if ci == cj {
76                                dot += l_values[pi] * l_values[pj];
77                                pi += 1;
78                                pj += 1;
79                            } else if ci < cj {
80                                pi += 1;
81                            } else {
82                                pj += 1;
83                            }
84                        }
85                        if l_jj.abs() > 1e-60 {
86                            l_values[pij] = (l_values[pij] - dot) / l_jj;
87                        }
88                    }
89                }
90            }
91        }
92        IccPreconditioner {
93            l_values,
94            row_ptr,
95            col_indices,
96            n,
97        }
98    }
99    /// Apply the ICC preconditioner: solve `(L L^T) z = r`.
100    ///
101    /// Performs two triangular solves: `L y = r` (forward), then `L^T z = y`
102    /// (backward).
103    pub fn solve(&self, rhs: &[f64]) -> Vec<f64> {
104        let n = self.n;
105        assert_eq!(rhs.len(), n);
106        let mut y = rhs.to_vec();
107        for i in 0..n {
108            let start = self.row_ptr[i];
109            let end = self.row_ptr[i + 1];
110            let mut diag = 1.0;
111            for p in start..end {
112                let j = self.col_indices[p];
113                if j == i {
114                    diag = self.l_values[p];
115                } else if j < i {
116                    y[i] -= self.l_values[p] * y[j];
117                }
118            }
119            if diag.abs() > 1e-60 {
120                y[i] /= diag;
121            }
122        }
123        let mut z = y.clone();
124        for i in (0..n).rev() {
125            let start = self.row_ptr[i];
126            let end = self.row_ptr[i + 1];
127            let mut diag = 1.0;
128            for p in start..end {
129                let j = self.col_indices[p];
130                if j == i {
131                    diag = self.l_values[p];
132                }
133            }
134            if diag.abs() > 1e-60 {
135                z[i] /= diag;
136            }
137            for p in start..end {
138                let j = self.col_indices[p];
139                if j < i {
140                    z[j] -= self.l_values[p] * z[i];
141                }
142            }
143        }
144        z
145    }
146    /// Return the number of non-zero entries in the factor L.
147    pub fn nnz(&self) -> usize {
148        self.l_values.len()
149    }
150}
/// A node in a 2-D quadtree used for adaptive mesh refinement.
///
/// Each node represents a rectangular cell in the mesh.  Leaf nodes
/// correspond to actual mesh cells; internal nodes have been refined.
#[derive(Debug, Clone)]
pub struct QuadTreeNode {
    /// Lower-left x coordinate.
    pub x0: f64,
    /// Lower-left y coordinate.
    pub y0: f64,
    /// Cell width.
    pub width: f64,
    /// Cell height.
    pub height: f64,
    /// Refinement level (0 = root).
    pub level: u32,
    /// Per-element error indicator (set by the caller before refinement).
    pub error_indicator: f64,
    /// Children (SW, SE, NW, NE) if this node has been refined.
    pub children: Option<Box<[QuadTreeNode; 4]>>,
    /// Unique cell index (assigned during leaf enumeration).
    pub cell_id: usize,
}
impl QuadTreeNode {
    /// Create a root node (level 0) covering
    /// `[x0, x0+width] × [y0, y0+height]`.
    pub fn new_root(x0: f64, y0: f64, width: f64, height: f64) -> Self {
        Self::new_child(x0, y0, width, height, 0)
    }
    /// Create an unrefined node at the given refinement `level` with a zero
    /// error indicator and cell id 0.
    fn new_child(x0: f64, y0: f64, width: f64, height: f64, level: u32) -> Self {
        Self {
            x0,
            y0,
            width,
            height,
            level,
            error_indicator: 0.0,
            children: None,
            cell_id: 0,
        }
    }
    /// Return the cell width (same as `self.width`).
    pub fn cell_width(&self) -> f64 {
        self.width
    }
    /// Return the cell height (same as `self.height`).
    pub fn cell_height(&self) -> f64 {
        self.height
    }
    /// Return `true` if this node has no children (is a leaf).
    pub fn is_leaf(&self) -> bool {
        self.children.is_none()
    }
    /// Refine this node into four equally sized child cells, stored in the
    /// order SW, SE, NW, NE, each one level deeper than `self`.
    ///
    /// Does nothing if the node is already refined.
    pub fn refine(&mut self) {
        if self.children.is_some() {
            return;
        }
        let half_w = self.width * 0.5;
        let half_h = self.height * 0.5;
        let (west, south) = (self.x0, self.y0);
        let (east, north) = (west + half_w, south + half_h);
        let level = self.level + 1;
        let child = |x: f64, y: f64| Self::new_child(x, y, half_w, half_h, level);
        self.children = Some(Box::new([
            child(west, south),
            child(east, south),
            child(west, north),
            child(east, north),
        ]));
    }
    /// Count the total number of leaf nodes in this subtree.
    pub fn leaf_count(&self) -> usize {
        match &self.children {
            None => 1,
            Some(kids) => kids.iter().map(Self::leaf_count).sum(),
        }
    }
    /// Enumerate leaf nodes depth-first (children in SW, SE, NW, NE order)
    /// and assign them sequential cell IDs beginning at `start_id`.
    ///
    /// Returns the next unused id, i.e. `start_id` plus the number of leaves
    /// in this subtree.  This equals the leaf count only when `start_id` is 0.
    pub fn enumerate_leaves(&mut self, start_id: usize) -> usize {
        match self.children.as_mut() {
            None => {
                self.cell_id = start_id;
                start_id + 1
            }
            Some(kids) => kids
                .iter_mut()
                .fold(start_id, |next_id, kid| kid.enumerate_leaves(next_id)),
        }
    }
    /// Append references to all leaf nodes of this subtree to `leaves`, in
    /// the same depth-first order used by [`enumerate_leaves`](Self::enumerate_leaves).
    pub fn collect_leaves<'a>(&'a self, leaves: &mut Vec<&'a QuadTreeNode>) {
        match &self.children {
            None => leaves.push(self),
            Some(kids) => {
                for kid in kids.iter() {
                    kid.collect_leaves(leaves);
                }
            }
        }
    }
    /// Centroid x-coordinate.
    pub fn cx(&self) -> f64 {
        self.x0 + 0.5 * self.width
    }
    /// Centroid y-coordinate.
    pub fn cy(&self) -> f64 {
        self.y0 + 0.5 * self.height
    }
}
274/// Incomplete LU factorization with level-of-fill `k` (ILU(k)).
275///
276/// Level k = 0 reproduces ILU(0) (no fill).  Level k = 1 allows one level
277/// of fill beyond the original pattern, and so on.  Higher levels produce
278/// better preconditioners at the cost of more memory.
279///
280/// The factorization is stored in CSR format with an extended sparsity pattern
281/// computed via symbolic analysis.
282#[derive(Debug, Clone)]
283pub struct IlukPreconditioner {
284    /// L and U factors in combined CSR (L strictly lower, U includes diagonal).
285    pub(super) lu_values: Vec<f64>,
286    pub(super) row_ptr: Vec<usize>,
287    pub(super) col_indices: Vec<usize>,
288    /// Level-of-fill for each non-zero.
289    pub(super) fill_levels: Vec<u32>,
290    pub(super) n: usize,
291    /// Fill level parameter k.
292    pub(super) k: u32,
293}
294impl IlukPreconditioner {
295    /// Compute the ILU(k) factorization of matrix `a` with fill level `k`.
296    ///
297    /// For k=0 this matches ILU(0).  For k≥1 additional fill-in entries
298    /// are allowed in the sparsity pattern.
299    pub fn new(a: &CsrMatrix, k: u32) -> Self {
300        assert_eq!(a.nrows, a.ncols, "matrix must be square");
301        let n = a.nrows;
302        let mut pattern: Vec<Vec<(usize, u32)>> = (0..n)
303            .map(|row| {
304                let start = a.row_ptr[row];
305                let end = a.row_ptr[row + 1];
306                let mut row_pat: Vec<(usize, u32)> =
307                    (start..end).map(|idx| (a.col_indices[idx], 0)).collect();
308                row_pat.sort_by_key(|&(c, _)| c);
309                row_pat
310            })
311            .collect();
312        for i in 1..n {
313            let mut j = 0;
314            while j < pattern[i].len() {
315                let (col, lev_ij) = pattern[i][j];
316                if col >= i {
317                    break;
318                }
319                let row_p_copy: Vec<(usize, u32)> = pattern[col].clone();
320                for &(q, lev_pq) in &row_p_copy {
321                    if q <= col {
322                        continue;
323                    }
324                    let new_lev = lev_ij + lev_pq + 1;
325                    if new_lev <= k {
326                        match pattern[i].binary_search_by_key(&q, |&(c, _)| c) {
327                            Ok(pos) => {
328                                if pattern[i][pos].1 > new_lev {
329                                    pattern[i][pos].1 = new_lev;
330                                }
331                            }
332                            Err(pos) => {
333                                pattern[i].insert(pos, (q, new_lev));
334                            }
335                        }
336                    }
337                }
338                j += 1;
339            }
340        }
341        let mut row_ptr = vec![0usize; n + 1];
342        for i in 0..n {
343            row_ptr[i + 1] = row_ptr[i] + pattern[i].len();
344        }
345        let nnz = row_ptr[n];
346        let mut col_indices = vec![0usize; nnz];
347        let mut fill_levels = vec![0u32; nnz];
348        let mut lu_values = vec![0.0f64; nnz];
349        for i in 0..n {
350            let start = row_ptr[i];
351            for (k_idx, &(col, lev)) in pattern[i].iter().enumerate() {
352                col_indices[start + k_idx] = col;
353                fill_levels[start + k_idx] = lev;
354            }
355        }
356        for row in 0..n {
357            let a_start = a.row_ptr[row];
358            let a_end = a.row_ptr[row + 1];
359            for a_idx in a_start..a_end {
360                let col = a.col_indices[a_idx];
361                let lu_start = row_ptr[row];
362                let lu_end = row_ptr[row + 1];
363                if let Ok(offset) = col_indices[lu_start..lu_end].binary_search(&col) {
364                    lu_values[lu_start + offset] = a.values[a_idx];
365                }
366            }
367        }
368        for i in 1..n {
369            let row_start = row_ptr[i];
370            let row_end = row_ptr[i + 1];
371            let lower_cols: Vec<usize> = (row_start..row_end)
372                .map(|p| col_indices[p])
373                .take_while(|&c| c < i)
374                .collect();
375            for &kk in &lower_cols {
376                let k_start = row_ptr[kk];
377                let k_end = row_ptr[kk + 1];
378                let diag_k = match col_indices[k_start..k_end].binary_search(&kk) {
379                    Ok(off) => lu_values[k_start + off],
380                    Err(_) => 0.0,
381                };
382                if diag_k.abs() < 1e-60 {
383                    continue;
384                }
385                let p_ik = col_indices[row_start..row_end]
386                    .binary_search(&kk)
387                    .map(|off| row_start + off)
388                    .unwrap_or(usize::MAX);
389                if p_ik == usize::MAX {
390                    continue;
391                }
392                lu_values[p_ik] /= diag_k;
393                let factor = lu_values[p_ik];
394                for k_idx in k_start..k_end {
395                    let j = col_indices[k_idx];
396                    if j <= kk {
397                        continue;
398                    }
399                    if let Ok(off) = col_indices[row_start..row_end].binary_search(&j) {
400                        lu_values[row_start + off] -= factor * lu_values[k_idx];
401                    }
402                }
403            }
404        }
405        IlukPreconditioner {
406            lu_values,
407            row_ptr,
408            col_indices,
409            fill_levels,
410            n,
411            k,
412        }
413    }
414    /// Apply the ILU(k) preconditioner: solve (LU) z = r.
415    pub fn solve(&self, rhs: &[f64]) -> Vec<f64> {
416        assert_eq!(rhs.len(), self.n);
417        let n = self.n;
418        let mut y = rhs.to_vec();
419        for i in 0..n {
420            let start = self.row_ptr[i];
421            let end = self.row_ptr[i + 1];
422            for p in start..end {
423                let j = self.col_indices[p];
424                if j >= i {
425                    break;
426                }
427                y[i] -= self.lu_values[p] * y[j];
428            }
429        }
430        for i in (0..n).rev() {
431            let start = self.row_ptr[i];
432            let end = self.row_ptr[i + 1];
433            let mut diag = 1.0;
434            for p in start..end {
435                let j = self.col_indices[p];
436                if j == i {
437                    diag = self.lu_values[p];
438                } else if j > i {
439                    y[i] -= self.lu_values[p] * y[j];
440                }
441            }
442            if diag.abs() > 1e-60 {
443                y[i] /= diag;
444            }
445        }
446        y
447    }
448    /// Return the fill level parameter used for this factorization.
449    pub fn fill_level(&self) -> u32 {
450        self.k
451    }
452    /// Return the number of non-zeros in the extended sparsity pattern.
453    pub fn nnz(&self) -> usize {
454        self.lu_values.len()
455    }
456    /// Return the fill-in count: entries added beyond the original pattern.
457    pub fn fill_in_count(&self) -> usize {
458        self.fill_levels.iter().filter(|&&lev| lev > 0).count()
459    }
460}
461/// Block sparse matrix where each scalar entry is replaced by a 3×3 dense
462/// block.  Suitable for 3-D elasticity problems where DOFs come in groups of
463/// three per node.
464///
465/// The block structure is stored in CSR format on the block level.  Each
466/// block `(i, j)` holds a flat 9-element row-major 3×3 sub-matrix.
467#[derive(Debug, Clone)]
468pub struct BlockCsrMatrix3 {
469    /// Row pointer array (block-level, length = n_block_rows + 1).
470    pub row_ptr: Vec<usize>,
471    /// Block column indices (block-level).
472    pub col_indices: Vec<usize>,
473    /// Block values: each entry is a 9-element \[f64; 9\] (row-major 3×3 block).
474    pub blocks: Vec<[f64; 9]>,
475    /// Number of block rows.
476    pub n_block_rows: usize,
477    /// Number of block columns.
478    pub n_block_cols: usize,
479}
480impl BlockCsrMatrix3 {
481    /// Create an empty block CSR matrix.
482    pub fn new(n_block_rows: usize, n_block_cols: usize) -> Self {
483        Self {
484            row_ptr: vec![0; n_block_rows + 1],
485            col_indices: Vec::new(),
486            blocks: Vec::new(),
487            n_block_rows,
488            n_block_cols,
489        }
490    }
491    /// Build from a list of block-level triplets `(block_row, block_col, block_3x3)`.
492    ///
493    /// Duplicate block entries at the same `(block_row, block_col)` are summed
494    /// element-wise.
495    pub fn from_block_triplets(
496        n_block_rows: usize,
497        n_block_cols: usize,
498        triplets: &[(usize, usize, [f64; 9])],
499    ) -> Self {
500        let mut map: std::collections::HashMap<(usize, usize), [f64; 9]> =
501            std::collections::HashMap::new();
502        for &(r, c, ref blk) in triplets {
503            let entry = map.entry((r, c)).or_insert([0.0; 9]);
504            for k in 0..9 {
505                entry[k] += blk[k];
506            }
507        }
508        let mut entries: Vec<((usize, usize), [f64; 9])> = map.into_iter().collect();
509        entries.sort_by_key(|&((r, c), _)| (r, c));
510        let mut row_ptr = vec![0usize; n_block_rows + 1];
511        let mut col_indices = Vec::new();
512        let mut blocks = Vec::new();
513        for &((r, c), ref blk) in &entries {
514            row_ptr[r + 1] += 1;
515            col_indices.push(c);
516            blocks.push(*blk);
517        }
518        for i in 1..=n_block_rows {
519            row_ptr[i] += row_ptr[i - 1];
520        }
521        Self {
522            row_ptr,
523            col_indices,
524            blocks,
525            n_block_rows,
526            n_block_cols,
527        }
528    }
529    /// Multiply by a dense vector of length `3 * n_block_cols`.
530    ///
531    /// Returns a dense vector of length `3 * n_block_rows`.
532    pub fn mul_vec(&self, x: &[f64]) -> Vec<f64> {
533        assert_eq!(x.len(), self.n_block_cols * 3);
534        let mut y = vec![0.0f64; self.n_block_rows * 3];
535        for br in 0..self.n_block_rows {
536            let start = self.row_ptr[br];
537            let end = self.row_ptr[br + 1];
538            for bidx in start..end {
539                let bc = self.col_indices[bidx];
540                let blk = &self.blocks[bidx];
541                for i in 0..3 {
542                    for j in 0..3 {
543                        y[br * 3 + i] += blk[i * 3 + j] * x[bc * 3 + j];
544                    }
545                }
546            }
547        }
548        y
549    }
550    /// Return the number of stored blocks.
551    pub fn n_blocks(&self) -> usize {
552        self.blocks.len()
553    }
554    /// Get a block at `(block_row, block_col)`.  Returns a zero block if not
555    /// stored.
556    pub fn get_block(&self, block_row: usize, block_col: usize) -> [f64; 9] {
557        let start = self.row_ptr[block_row];
558        let end = self.row_ptr[block_row + 1];
559        for idx in start..end {
560            if self.col_indices[idx] == block_col {
561                return self.blocks[idx];
562            }
563        }
564        [0.0; 9]
565    }
566    /// Frobenius norm of the entire block matrix.
567    pub fn frobenius_norm(&self) -> f64 {
568        self.blocks
569            .iter()
570            .flat_map(|b| b.iter())
571            .map(|v| v * v)
572            .sum::<f64>()
573            .sqrt()
574    }
575    /// Convert to a scalar CSR matrix (expand each 3×3 block into scalar entries).
576    pub fn to_scalar_csr(&self) -> CsrMatrix {
577        let n_scalar_rows = self.n_block_rows * 3;
578        let n_scalar_cols = self.n_block_cols * 3;
579        let mut triplets = Vec::with_capacity(self.blocks.len() * 9);
580        for br in 0..self.n_block_rows {
581            let start = self.row_ptr[br];
582            let end = self.row_ptr[br + 1];
583            for bidx in start..end {
584                let bc = self.col_indices[bidx];
585                let blk = &self.blocks[bidx];
586                for i in 0..3 {
587                    for j in 0..3 {
588                        let v = blk[i * 3 + j];
589                        if v.abs() > 1e-30 {
590                            triplets.push((br * 3 + i, bc * 3 + j, v));
591                        }
592                    }
593                }
594            }
595        }
596        CsrMatrix::from_triplets(n_scalar_rows, n_scalar_cols, &triplets)
597    }
598}
599/// Compressed Sparse Row (CSR) matrix.
600///
601/// Stores a sparse matrix in CSR format with row pointers, column indices,
602/// and values arrays. Efficient for matrix-vector multiplication and row access.
603#[derive(Debug, Clone)]
604pub struct CsrMatrix {
605    /// Row pointer array (length = nrows + 1).
606    pub row_ptr: Vec<usize>,
607    /// Column indices for non-zero entries.
608    pub col_indices: Vec<usize>,
609    /// Non-zero values.
610    pub values: Vec<f64>,
611    /// Number of rows.
612    pub nrows: usize,
613    /// Number of columns.
614    pub ncols: usize,
615}
616impl CsrMatrix {
617    /// Create a new empty CSR matrix with the given dimensions.
618    pub fn new(nrows: usize, ncols: usize) -> Self {
619        Self {
620            row_ptr: vec![0; nrows + 1],
621            col_indices: Vec::new(),
622            values: Vec::new(),
623            nrows,
624            ncols,
625        }
626    }
627    /// Build a CSR matrix from coordinate (triplet) format.
628    ///
629    /// Duplicate entries at the same `(row, col)` are summed together.
630    ///
631    /// # Panics
632    ///
633    /// Panics if any index is out of bounds.
634    pub fn from_triplets(nrows: usize, ncols: usize, triplets: &[(usize, usize, f64)]) -> Self {
635        let mut map: HashMap<(usize, usize), f64> = HashMap::new();
636        for &(r, c, v) in triplets {
637            assert!(r < nrows, "row index {r} out of bounds for {nrows} rows");
638            assert!(c < ncols, "col index {c} out of bounds for {ncols} cols");
639            *map.entry((r, c)).or_insert(0.0) += v;
640        }
641        let mut entries: Vec<((usize, usize), f64)> = map.into_iter().collect();
642        entries.sort_by_key(|&((r, c), _)| (r, c));
643        let mut row_ptr = vec![0usize; nrows + 1];
644        let mut col_indices = Vec::with_capacity(entries.len());
645        let mut values = Vec::with_capacity(entries.len());
646        for &((r, c), v) in &entries {
647            row_ptr[r + 1] += 1;
648            col_indices.push(c);
649            values.push(v);
650        }
651        for i in 1..=nrows {
652            row_ptr[i] += row_ptr[i - 1];
653        }
654        Self {
655            row_ptr,
656            col_indices,
657            values,
658            nrows,
659            ncols,
660        }
661    }
662    /// Get the value at `(row, col)`. Returns 0.0 if the entry is not stored.
663    ///
664    /// # Panics
665    ///
666    /// Panics if indices are out of bounds.
667    pub fn get(&self, row: usize, col: usize) -> f64 {
668        assert!(row < self.nrows);
669        assert!(col < self.ncols);
670        let start = self.row_ptr[row];
671        let end = self.row_ptr[row + 1];
672        for idx in start..end {
673            if self.col_indices[idx] == col {
674                return self.values[idx];
675            }
676        }
677        0.0
678    }
679    /// Set the value at `(row, col)`. If the entry exists, update it;
680    /// otherwise insert a new entry.
681    ///
682    /// **Note:** Inserting new entries is O(nnz) in the worst case because the
683    /// arrays must be shifted. Prefer [`from_triplets`](Self::from_triplets) for
684    /// bulk construction.
685    ///
686    /// # Panics
687    ///
688    /// Panics if indices are out of bounds.
689    pub fn set(&mut self, row: usize, col: usize, value: f64) {
690        assert!(row < self.nrows);
691        assert!(col < self.ncols);
692        let start = self.row_ptr[row];
693        let end = self.row_ptr[row + 1];
694        for idx in start..end {
695            if self.col_indices[idx] == col {
696                self.values[idx] = value;
697                return;
698            }
699        }
700        let mut insert_pos = start;
701        while insert_pos < end && self.col_indices[insert_pos] < col {
702            insert_pos += 1;
703        }
704        self.col_indices.insert(insert_pos, col);
705        self.values.insert(insert_pos, value);
706        for r in (row + 1)..=self.nrows {
707            self.row_ptr[r] += 1;
708        }
709    }
710    /// Add `value` to the entry at `(row, col)`. If the entry does not exist,
711    /// it is created with the given value. This is the primary method used
712    /// during stiffness matrix assembly.
713    ///
714    /// # Panics
715    ///
716    /// Panics if indices are out of bounds.
717    pub fn add_to(&mut self, row: usize, col: usize, value: f64) {
718        assert!(row < self.nrows);
719        assert!(col < self.ncols);
720        let start = self.row_ptr[row];
721        let end = self.row_ptr[row + 1];
722        for idx in start..end {
723            if self.col_indices[idx] == col {
724                self.values[idx] += value;
725                return;
726            }
727        }
728        let mut insert_pos = start;
729        while insert_pos < end && self.col_indices[insert_pos] < col {
730            insert_pos += 1;
731        }
732        self.col_indices.insert(insert_pos, col);
733        self.values.insert(insert_pos, value);
734        for r in (row + 1)..=self.nrows {
735            self.row_ptr[r] += 1;
736        }
737    }
738    /// Multiply this matrix by a dense vector: `y = A * x`.
739    ///
740    /// # Panics
741    ///
742    /// Panics if `x.len() != self.ncols`.
743    pub fn mul_vec(&self, x: &[f64]) -> Vec<f64> {
744        assert_eq!(x.len(), self.ncols, "vector length must equal ncols");
745        let mut y = vec![0.0; self.nrows];
746        for (row, y_row) in y.iter_mut().enumerate().take(self.nrows) {
747            let start = self.row_ptr[row];
748            let end = self.row_ptr[row + 1];
749            let mut sum = 0.0;
750            for idx in start..end {
751                sum += self.values[idx] * x[self.col_indices[idx]];
752            }
753            *y_row = sum;
754        }
755        y
756    }
757    /// Return the number of stored non-zero entries.
758    pub fn nnz(&self) -> usize {
759        self.values.len()
760    }
761    /// Get the diagonal element at `(i, i)`.
762    pub fn diagonal(&self, i: usize) -> f64 {
763        self.get(i, i)
764    }
765    /// Transpose this CSR matrix, returning a new CSR matrix.
766    pub fn transpose(&self) -> CsrMatrix {
767        let mut triplets = Vec::with_capacity(self.nnz());
768        for row in 0..self.nrows {
769            let start = self.row_ptr[row];
770            let end = self.row_ptr[row + 1];
771            for idx in start..end {
772                triplets.push((self.col_indices[idx], row, self.values[idx]));
773            }
774        }
775        CsrMatrix::from_triplets(self.ncols, self.nrows, &triplets)
776    }
777    /// Add two CSR matrices: C = A + B.
778    ///
779    /// Both matrices must have the same dimensions.
780    pub fn add(&self, other: &CsrMatrix) -> CsrMatrix {
781        assert_eq!(self.nrows, other.nrows, "row dimensions must match");
782        assert_eq!(self.ncols, other.ncols, "col dimensions must match");
783        let mut triplets = Vec::with_capacity(self.nnz() + other.nnz());
784        for row in 0..self.nrows {
785            let start = self.row_ptr[row];
786            let end = self.row_ptr[row + 1];
787            for idx in start..end {
788                triplets.push((row, self.col_indices[idx], self.values[idx]));
789            }
790        }
791        for row in 0..other.nrows {
792            let start = other.row_ptr[row];
793            let end = other.row_ptr[row + 1];
794            for idx in start..end {
795                triplets.push((row, other.col_indices[idx], other.values[idx]));
796            }
797        }
798        CsrMatrix::from_triplets(self.nrows, self.ncols, &triplets)
799    }
800    /// Scale all values by a scalar: A *= alpha.
801    pub fn scale(&mut self, alpha: f64) {
802        for v in self.values.iter_mut() {
803            *v *= alpha;
804        }
805    }
806    /// Return a scaled copy: B = alpha * A.
807    pub fn scaled(&self, alpha: f64) -> CsrMatrix {
808        let mut result = self.clone();
809        result.scale(alpha);
810        result
811    }
812    /// Optimized sparse matrix-vector multiply using row-based access.
813    ///
814    /// Same as `mul_vec` but with explicit prefetch-friendly ordering.
815    /// y = alpha * A * x + beta * y
816    pub fn mul_vec_axpby(&self, x: &[f64], y: &mut [f64], alpha: f64, beta: f64) {
817        assert_eq!(x.len(), self.ncols);
818        assert_eq!(y.len(), self.nrows);
819        for (row, y_row) in y.iter_mut().enumerate() {
820            let start = self.row_ptr[row];
821            let end = self.row_ptr[row + 1];
822            let mut sum = 0.0;
823            for idx in start..end {
824                sum += self.values[idx] * x[self.col_indices[idx]];
825            }
826            *y_row = alpha * sum + beta * *y_row;
827        }
828    }
829    /// Symmetric sparse matrix-vector multiply (upper triangle only).
830    ///
831    /// Assumes only the upper triangle is stored. Computes y = A * x
832    /// using symmetry: a_ij contributes to both y\[i\] and y\[j\].
833    pub fn symmetric_mul_vec(&self, x: &[f64]) -> Vec<f64> {
834        assert_eq!(x.len(), self.ncols);
835        assert_eq!(
836            self.nrows, self.ncols,
837            "matrix must be square for symmetric multiply"
838        );
839        let n = self.nrows;
840        let mut y = vec![0.0; n];
841        for row in 0..n {
842            let start = self.row_ptr[row];
843            let end = self.row_ptr[row + 1];
844            for idx in start..end {
845                let col = self.col_indices[idx];
846                let val = self.values[idx];
847                y[row] += val * x[col];
848                if col != row {
849                    y[col] += val * x[row];
850                }
851            }
852        }
853        y
854    }
855    /// Extract diagonal as a vector.
856    pub fn diagonal_vec(&self) -> Vec<f64> {
857        let n = self.nrows.min(self.ncols);
858        (0..n).map(|i| self.get(i, i)).collect()
859    }
860    /// Convert this CSR matrix to CSC format.
861    pub fn to_csc(&self) -> CscMatrix {
862        let mut triplets = Vec::with_capacity(self.nnz());
863        for row in 0..self.nrows {
864            let start = self.row_ptr[row];
865            let end = self.row_ptr[row + 1];
866            for idx in start..end {
867                triplets.push((row, self.col_indices[idx], self.values[idx]));
868            }
869        }
870        CscMatrix::from_triplets(self.nrows, self.ncols, &triplets)
871    }
872    /// Frobenius norm: ||A||_F = sqrt(sum of a_ij^2).
873    pub fn frobenius_norm(&self) -> f64 {
874        self.values.iter().map(|v| v * v).sum::<f64>().sqrt()
875    }
876}
/// Thin wrapper around `Vec<f64>` for sparse/dense vector operations.
///
/// Despite the name, storage is dense: every entry is held explicitly in
/// `data`.
#[derive(Debug, Clone)]
pub struct SparseVector {
    /// The underlying dense data.
    pub data: Vec<f64>,
}
impl SparseVector {
    /// Create a new zero vector of the given length.
    pub fn new(len: usize) -> Self {
        Self {
            data: vec![0.0; len],
        }
    }
    /// Create a vector that takes ownership of existing data.
    pub fn from_vec(data: Vec<f64>) -> Self {
        Self { data }
    }
    /// Return the number of entries in the vector.
    pub fn len(&self) -> usize {
        self.data.len()
    }
    /// Check whether the vector has no entries.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
    /// Compute the dot product with another vector.
    ///
    /// # Panics
    ///
    /// Panics if the two vectors differ in length. Without this check `zip`
    /// would silently truncate to the shorter vector and return a
    /// meaningless value; `axpby` enforces the same precondition.
    pub fn dot(&self, other: &Self) -> f64 {
        assert_eq!(
            self.data.len(),
            other.data.len(),
            "dot product requires vectors of equal length"
        );
        self.data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a * b)
            .sum()
    }
    /// Compute the Euclidean (L2) norm.
    pub fn norm(&self) -> f64 {
        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
    }
    /// AXPBY update in place: `self = alpha * self + beta * other`.
    ///
    /// # Panics
    ///
    /// Panics if the two vectors differ in length.
    pub fn axpby(&mut self, alpha: f64, other: &SparseVector, beta: f64) {
        assert_eq!(self.data.len(), other.data.len());
        for (a, b) in self.data.iter_mut().zip(other.data.iter()) {
            *a = alpha * *a + beta * *b;
        }
    }
    /// Scale all entries by a scalar in place.
    pub fn scale(&mut self, alpha: f64) {
        for v in self.data.iter_mut() {
            *v *= alpha;
        }
    }
}
928/// Incomplete LU factorization with zero fill-in (ILU(0)).
929///
930/// The factorization maintains the same sparsity pattern as the original matrix.
931/// Used as a preconditioner for iterative solvers like CG or GMRES.
932#[derive(Debug, Clone)]
933pub struct Ilu0Preconditioner {
934    /// Combined L and U factors stored in CSR format.
935    /// L has unit diagonal (not stored), U diagonal is stored.
936    /// The storage reuses the sparsity pattern of A.
937    pub(super) lu_values: Vec<f64>,
938    pub(super) row_ptr: Vec<usize>,
939    pub(super) col_indices: Vec<usize>,
940    pub(super) n: usize,
941}
942impl Ilu0Preconditioner {
943    /// Compute the ILU(0) factorization of matrix A.
944    ///
945    /// The factorization uses the same sparsity pattern as A. Elements that
946    /// would be fill-in are dropped.
947    pub fn new(a: &CsrMatrix) -> Self {
948        assert_eq!(a.nrows, a.ncols, "matrix must be square");
949        let n = a.nrows;
950        let mut lu_values = a.values.clone();
951        let row_ptr = a.row_ptr.clone();
952        let col_indices = a.col_indices.clone();
953        for i in 1..n {
954            let row_start = row_ptr[i];
955            let row_end = row_ptr[i + 1];
956            for p in row_start..row_end {
957                let k = col_indices[p];
958                if k >= i {
959                    break;
960                }
961                let k_start = row_ptr[k];
962                let k_end = row_ptr[k + 1];
963                let mut diag_k = 0.0;
964                for q in k_start..k_end {
965                    if col_indices[q] == k {
966                        diag_k = lu_values[q];
967                        break;
968                    }
969                }
970                if diag_k.abs() < 1e-60 {
971                    continue;
972                }
973                lu_values[p] /= diag_k;
974                let factor = lu_values[p];
975                for q in k_start..k_end {
976                    let j = col_indices[q];
977                    if j <= k {
978                        continue;
979                    }
980                    for s in row_start..row_end {
981                        if col_indices[s] == j {
982                            lu_values[s] -= factor * lu_values[q];
983                            break;
984                        }
985                    }
986                }
987            }
988        }
989        Ilu0Preconditioner {
990            lu_values,
991            row_ptr,
992            col_indices,
993            n,
994        }
995    }
996    /// Apply the ILU(0) preconditioner: solve (L U) z = r.
997    ///
998    /// First forward-substitutes L y = r (L has unit diagonal),
999    /// then back-substitutes U z = y.
1000    pub fn solve(&self, rhs: &[f64]) -> Vec<f64> {
1001        assert_eq!(rhs.len(), self.n);
1002        let n = self.n;
1003        let mut y = rhs.to_vec();
1004        for i in 0..n {
1005            let start = self.row_ptr[i];
1006            let end = self.row_ptr[i + 1];
1007            for p in start..end {
1008                let j = self.col_indices[p];
1009                if j >= i {
1010                    break;
1011                }
1012                y[i] -= self.lu_values[p] * y[j];
1013            }
1014        }
1015        for i in (0..n).rev() {
1016            let start = self.row_ptr[i];
1017            let end = self.row_ptr[i + 1];
1018            let mut diag = 1.0;
1019            for p in start..end {
1020                let j = self.col_indices[p];
1021                if j == i {
1022                    diag = self.lu_values[p];
1023                } else if j > i {
1024                    y[i] -= self.lu_values[p] * y[j];
1025                }
1026            }
1027            if diag.abs() > 1e-60 {
1028                y[i] /= diag;
1029            }
1030        }
1031        y
1032    }
1033}
1034/// Compressed Sparse Column (CSC) matrix.
1035///
1036/// Stores a sparse matrix in CSC format with column pointers, row indices,
1037/// and values arrays. Efficient for column access and certain direct solvers.
1038#[derive(Debug, Clone)]
1039pub struct CscMatrix {
1040    /// Column pointer array (length = ncols + 1).
1041    pub col_ptr: Vec<usize>,
1042    /// Row indices for non-zero entries.
1043    pub row_indices: Vec<usize>,
1044    /// Non-zero values.
1045    pub values: Vec<f64>,
1046    /// Number of rows.
1047    pub nrows: usize,
1048    /// Number of columns.
1049    pub ncols: usize,
1050}
1051impl CscMatrix {
1052    /// Build a CSC matrix from coordinate (triplet) format.
1053    ///
1054    /// Duplicate entries are summed.
1055    pub fn from_triplets(nrows: usize, ncols: usize, triplets: &[(usize, usize, f64)]) -> Self {
1056        let mut map: HashMap<(usize, usize), f64> = HashMap::new();
1057        for &(r, c, v) in triplets {
1058            assert!(r < nrows, "row index {r} out of bounds for {nrows} rows");
1059            assert!(c < ncols, "col index {c} out of bounds for {ncols} cols");
1060            *map.entry((r, c)).or_insert(0.0) += v;
1061        }
1062        let mut entries: Vec<((usize, usize), f64)> = map.into_iter().collect();
1063        entries.sort_by_key(|&((r, c), _)| (c, r));
1064        let mut col_ptr = vec![0usize; ncols + 1];
1065        let mut row_indices = Vec::with_capacity(entries.len());
1066        let mut values = Vec::with_capacity(entries.len());
1067        for &((r, c), v) in &entries {
1068            col_ptr[c + 1] += 1;
1069            row_indices.push(r);
1070            values.push(v);
1071        }
1072        for i in 1..=ncols {
1073            col_ptr[i] += col_ptr[i - 1];
1074        }
1075        Self {
1076            col_ptr,
1077            row_indices,
1078            values,
1079            nrows,
1080            ncols,
1081        }
1082    }
1083    /// Get the value at (row, col).
1084    pub fn get(&self, row: usize, col: usize) -> f64 {
1085        assert!(row < self.nrows);
1086        assert!(col < self.ncols);
1087        let start = self.col_ptr[col];
1088        let end = self.col_ptr[col + 1];
1089        for idx in start..end {
1090            if self.row_indices[idx] == row {
1091                return self.values[idx];
1092            }
1093        }
1094        0.0
1095    }
1096    /// Number of stored non-zero entries.
1097    pub fn nnz(&self) -> usize {
1098        self.values.len()
1099    }
1100    /// Multiply by a dense vector: y = A * x.
1101    pub fn mul_vec(&self, x: &[f64]) -> Vec<f64> {
1102        assert_eq!(x.len(), self.ncols, "vector length must equal ncols");
1103        let mut y = vec![0.0; self.nrows];
1104        for (col, &x_col) in x.iter().enumerate().take(self.ncols) {
1105            let start = self.col_ptr[col];
1106            let end = self.col_ptr[col + 1];
1107            for idx in start..end {
1108                y[self.row_indices[idx]] += self.values[idx] * x_col;
1109            }
1110        }
1111        y
1112    }
1113    /// Transpose this CSC matrix, returning a new CSC matrix.
1114    pub fn transpose(&self) -> CscMatrix {
1115        let mut triplets = Vec::with_capacity(self.nnz());
1116        for col in 0..self.ncols {
1117            let start = self.col_ptr[col];
1118            let end = self.col_ptr[col + 1];
1119            for idx in start..end {
1120                triplets.push((col, self.row_indices[idx], self.values[idx]));
1121            }
1122        }
1123        CscMatrix::from_triplets(self.ncols, self.nrows, &triplets)
1124    }
1125    /// Convert to CSR format.
1126    pub fn to_csr(&self) -> CsrMatrix {
1127        let mut triplets = Vec::with_capacity(self.nnz());
1128        for col in 0..self.ncols {
1129            let start = self.col_ptr[col];
1130            let end = self.col_ptr[col + 1];
1131            for idx in start..end {
1132                triplets.push((self.row_indices[idx], col, self.values[idx]));
1133            }
1134        }
1135        CsrMatrix::from_triplets(self.nrows, self.ncols, &triplets)
1136    }
1137}