Skip to main content

oxicuda_solver/sparse/
direct_factorization.rs

1//! Sparse direct solvers: supernodal Cholesky and multifrontal LU.
2//!
3//! Provides direct factorization methods for sparse linear systems, complementing
4//! the iterative solvers (CG, GMRES, etc.) with exact methods for when iterative
5//! convergence is slow or reliability is paramount.
6//!
7//! - [`SupernodalCholeskySolver`] — supernodal Cholesky for symmetric positive definite systems
8//! - [`MultifrontalLUSolver`] — multifrontal LU with partial pivoting for general sparse systems
9//!
10//! (C) 2026 COOLJAPAN OU (Team KitaSan)
11
12use crate::error::SolverError;
13
14// ---------------------------------------------------------------------------
15// Elimination Tree
16// ---------------------------------------------------------------------------
17
/// Elimination tree of a sparse matrix.
///
/// Each node is a column. `parent_of(j)` is the column that column `j` is
/// eliminated into, i.e. the smallest row index `> j` holding a non-zero in
/// column `j` of the factor, or `None` when `j` is a root. Supernodal and
/// multifrontal factorizations are scheduled over this tree.
#[derive(Debug, Clone)]
pub struct EliminationTree {
    /// `parent[v]` is the parent of node `v`, or `None` for a root.
    parent: Vec<Option<usize>>,
    /// `children[v]` lists the children of node `v` in increasing order.
    children: Vec<Vec<usize>>,
    /// Every node exactly once, each child before its parent (leaves first).
    postorder: Vec<usize>,
    /// Number of nodes (the matrix dimension).
    n: usize,
}

impl EliminationTree {
    /// Builds the elimination tree from the CSR pattern of the lower triangle.
    ///
    /// Only strictly-lower entries (column < row) are used. Diagonal and upper
    /// entries, positions past the end of `col_indices`, and rows past the end
    /// of `row_offsets` are ignored. Uses Liu's ancestor-array algorithm with
    /// path compression (no union by rank), which is O(nnz · log n) in the
    /// worst case.
    pub fn from_csr(row_offsets: &[usize], col_indices: &[usize], n: usize) -> Self {
        let mut parent: Vec<Option<usize>> = vec![None; n];
        // `ancestor[v] == v` marks v as the current root of its partial subtree;
        // otherwise it points at some (path-compressed) ancestor of v.
        let mut ancestor: Vec<usize> = (0..n).collect();

        for row in 0..n {
            let begin = row_offsets.get(row).copied().unwrap_or(0);
            let end = row_offsets.get(row + 1).copied().unwrap_or(begin);
            let lower_cols = (begin..end)
                .filter_map(|k| col_indices.get(k).copied())
                .filter(|&col| col < row);

            for col in lower_cols {
                // Climb to the current root, re-pointing every visited node at `row`.
                let mut node = col;
                while ancestor[node] != node {
                    node = std::mem::replace(&mut ancestor[node], row);
                }
                // A root other than `row` itself is adopted by `row`.
                if node != row && parent[node].is_none() {
                    parent[node] = Some(row);
                    ancestor[node] = row;
                }
            }
        }

        let mut children: Vec<Vec<usize>> = vec![Vec::new(); n];
        let edges = parent
            .iter()
            .enumerate()
            .filter_map(|(v, &maybe)| maybe.map(|p| (v, p)));
        for (child, p) in edges {
            children[p].push(child);
        }

        let postorder = Self::compute_postorder(&parent, &children, n);

        Self {
            parent,
            children,
            postorder,
            n,
        }
    }

    /// Iterative depth-first postorder: roots in increasing order, children of
    /// each node in increasing order, every node emitted after its subtree.
    fn compute_postorder(
        parent: &[Option<usize>],
        children: &[Vec<usize>],
        n: usize,
    ) -> Vec<usize> {
        let mut order = Vec::with_capacity(n);
        // Each frame is (node, index of the next child to descend into).
        let mut frames: Vec<(usize, usize)> = Vec::new();

        for root in (0..n).filter(|&v| parent[v].is_none()) {
            frames.push((root, 0));
            while let Some(frame) = frames.last_mut() {
                let (node, cursor) = *frame;
                match children[node].get(cursor) {
                    Some(&child) => {
                        frame.1 += 1;
                        frames.push((child, 0));
                    }
                    None => {
                        frames.pop();
                        order.push(node);
                    }
                }
            }
        }

        order
    }

    /// Returns nodes in postorder (leaves before parents).
    pub fn postorder_traversal(&self) -> &[usize] {
        &self.postorder
    }

    /// Number of nodes in the subtree rooted at `node`, the node itself included.
    ///
    /// Returns `0` for an out-of-range `node`.
    pub fn subtree_size(&self, node: usize) -> usize {
        if node >= self.n {
            return 0;
        }
        let mut pending = vec![node];
        let mut count = 0usize;
        while let Some(cur) = pending.pop() {
            count += 1;
            pending.extend_from_slice(self.children_of(cur));
        }
        count
    }

    /// Number of nodes.
    pub fn size(&self) -> usize {
        self.n
    }

    /// Parent of `node`, or `None` for a root or an out-of-range node.
    pub fn parent_of(&self, node: usize) -> Option<usize> {
        match self.parent.get(node) {
            Some(&p) => p,
            None => None,
        }
    }

    /// Children of `node` (empty for a leaf or an out-of-range node).
    pub fn children_of(&self, node: usize) -> &[usize] {
        match self.children.get(node) {
            Some(kids) => kids.as_slice(),
            None => &[],
        }
    }
}
151
152// ---------------------------------------------------------------------------
153// Column counts
154// ---------------------------------------------------------------------------
155
156/// Compute the number of non-zeros in each column of L.
157///
158/// Uses the elimination tree to propagate fill-in counts from leaves to root.
159/// Input is CSR lower triangle.
160pub fn column_counts(
161    row_offsets: &[usize],
162    col_indices: &[usize],
163    etree: &EliminationTree,
164) -> Vec<usize> {
165    let n = etree.size();
166    let mut counts = vec![1usize; n];
167
168    // Build column-to-rows map from CSR lower triangle
169    let mut col_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
170    for i in 0..n {
171        let rs = row_offsets.get(i).copied().unwrap_or(0);
172        let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
173        for idx in rs..re {
174            if let Some(&j) = col_indices.get(idx) {
175                if j < i {
176                    col_rows[j].push(i);
177                }
178            }
179        }
180    }
181
182    // Compute row indices of L for each column via etree
183    let mut l_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
184    for &node in etree.postorder_traversal() {
185        let mut rows: Vec<usize> = col_rows[node].clone();
186
187        for &child in etree.children_of(node) {
188            for &r in &l_rows[child] {
189                if r > node {
190                    rows.push(r);
191                }
192            }
193        }
194
195        rows.sort_unstable();
196        rows.dedup();
197        counts[node] = 1 + rows.len();
198        l_rows[node] = rows;
199    }
200
201    counts
202}
203
204// ---------------------------------------------------------------------------
205// Supernode
206// ---------------------------------------------------------------------------
207
/// A supernode: a contiguous set of columns with identical sparsity below the diagonal.
#[derive(Debug, Clone)]
pub struct Supernode {
    /// First column index (inclusive).
    pub start: usize,
    /// One past the last column index (exclusive).
    pub end: usize,
    /// Global ROW indices covered by this supernode, despite the field name:
    /// the diagonal rows `start..end` followed by the off-diagonal rows.
    /// `SupernodalStructure::from_etree` builds this list sorted and duplicate-free.
    pub columns: Vec<usize>,
    /// Dense block stored column-major, `nrows() x width()`: entry
    /// `(local_row, local_col)` lives at `local_row + local_col * nrows()`.
    pub dense_block: Vec<f64>,
}

impl Supernode {
    /// Number of columns spanned, `end - start`.
    pub fn width(&self) -> usize {
        let Supernode { start, end, .. } = self;
        end - start
    }

    /// Number of rows (diagonal rows included), i.e. the length of `columns`.
    pub fn nrows(&self) -> usize {
        let rows = &self.columns;
        rows.len()
    }
}
232
233// ---------------------------------------------------------------------------
234// Supernodal Structure
235// ---------------------------------------------------------------------------
236
237/// Supernodal partition of the matrix.
238#[derive(Debug, Clone)]
239pub struct SupernodalStructure {
240    /// The supernodes.
241    pub supernodes: Vec<Supernode>,
242    /// Maps each column to its supernode index.
243    pub membership: Vec<usize>,
244}
245
246impl SupernodalStructure {
247    /// Detect fundamental supernodes from the elimination tree and column counts.
248    pub fn from_etree(
249        etree: &EliminationTree,
250        row_offsets: &[usize],
251        col_indices: &[usize],
252    ) -> Self {
253        let n = etree.size();
254        let col_cnts = column_counts(row_offsets, col_indices, etree);
255
256        // Detect supernode boundaries
257        let mut is_start = vec![true; n];
258        for j in 0..n.saturating_sub(1) {
259            if etree.parent_of(j) == Some(j + 1)
260                && col_cnts[j + 1] + 1 == col_cnts[j]
261                && etree.children_of(j + 1).len() <= 1
262            {
263                is_start[j + 1] = false;
264            }
265        }
266
267        // Build column-to-rows map (L sparsity pattern) for row membership
268        let mut col_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
269        for i in 0..n {
270            let rs = row_offsets.get(i).copied().unwrap_or(0);
271            let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
272            for idx in rs..re {
273                if let Some(&j) = col_indices.get(idx) {
274                    if j < i {
275                        col_rows[j].push(i);
276                    }
277                }
278            }
279        }
280
281        // Compute L sparsity (row indices per column including fill)
282        let mut l_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
283        for &node in etree.postorder_traversal() {
284            let mut rows: Vec<usize> = col_rows[node].clone();
285            for &child in etree.children_of(node) {
286                for &r in &l_rows[child] {
287                    if r > node {
288                        rows.push(r);
289                    }
290                }
291            }
292            rows.sort_unstable();
293            rows.dedup();
294            l_rows[node] = rows;
295        }
296
297        // Build supernodes
298        let mut supernodes = Vec::new();
299        let mut membership = vec![0usize; n];
300
301        let mut i = 0;
302        while i < n {
303            let start = i;
304            let mut end = i + 1;
305            while end < n && !is_start[end] {
306                end += 1;
307            }
308
309            // Row indices: diagonal rows [start..end) plus sub-diagonal from L pattern
310            let mut rows: Vec<usize> = (start..end).collect();
311            // Use L sparsity of the first column in the supernode
312            // (fundamental supernodes have nested sparsity)
313            for &r in &l_rows[start] {
314                if r >= end {
315                    rows.push(r);
316                }
317            }
318            // Also include rows from other columns in the supernode
319            for l_row_set in l_rows.iter().take(end).skip(start + 1) {
320                for &r in l_row_set {
321                    if r >= end && !rows.contains(&r) {
322                        rows.push(r);
323                    }
324                }
325            }
326            rows.sort_unstable();
327            rows.dedup();
328
329            let nrows = rows.len();
330            let ncols = end - start;
331
332            let sn_idx = supernodes.len();
333            for m in membership.iter_mut().take(end).skip(start) {
334                *m = sn_idx;
335            }
336
337            supernodes.push(Supernode {
338                start,
339                end,
340                columns: rows,
341                dense_block: vec![0.0; nrows * ncols],
342            });
343
344            i = end;
345        }
346
347        Self {
348            supernodes,
349            membership,
350        }
351    }
352}
353
354// ---------------------------------------------------------------------------
355// Symbolic Factorization (reusable)
356// ---------------------------------------------------------------------------
357
358/// Reusable symbolic factorization for repeated numeric factorizations
359/// with the same sparsity pattern.
360#[derive(Debug, Clone)]
361pub struct SymbolicFactorization {
362    /// The elimination tree.
363    pub etree: EliminationTree,
364    /// The supernodal structure.
365    pub structure: SupernodalStructure,
366    /// Estimated non-zeros in L.
367    pub nnz_l: usize,
368    /// Estimated non-zeros in U (same as nnz_l for Cholesky).
369    pub nnz_u: usize,
370}
371
372impl SymbolicFactorization {
373    /// Perform symbolic factorization from CSR structure (lower triangle).
374    pub fn compute(
375        row_offsets: &[usize],
376        col_indices: &[usize],
377        n: usize,
378    ) -> Result<Self, SolverError> {
379        if row_offsets.len() != n + 1 {
380            return Err(SolverError::DimensionMismatch(format!(
381                "row_offsets length {} != n+1 = {}",
382                row_offsets.len(),
383                n + 1
384            )));
385        }
386
387        let etree = EliminationTree::from_csr(row_offsets, col_indices, n);
388        let structure = SupernodalStructure::from_etree(&etree, row_offsets, col_indices);
389
390        let nnz_l: usize = structure
391            .supernodes
392            .iter()
393            .map(|sn| sn.nrows() * sn.width())
394            .sum();
395
396        Ok(Self {
397            etree,
398            structure,
399            nnz_l,
400            nnz_u: nnz_l,
401        })
402    }
403}
404
405// ---------------------------------------------------------------------------
406// Supernodal Cholesky Solver
407// ---------------------------------------------------------------------------
408
409/// Supernodal Cholesky solver for symmetric positive definite sparse systems.
410///
411/// Performs `A = L * L^T` factorization using supernodal dense operations
412/// within each supernode for BLAS-like efficiency.
413#[derive(Debug, Clone)]
414pub struct SupernodalCholeskySolver {
415    /// The supernodal structure (holds L factor in dense blocks).
416    structure: SupernodalStructure,
417    /// Whether numeric factorization has been performed.
418    factored: bool,
419    /// The elimination tree.
420    etree: EliminationTree,
421    /// Matrix dimension.
422    n: usize,
423}
424
425impl SupernodalCholeskySolver {
426    /// Symbolic factorization: compute elimination tree, supernodes, column counts.
427    pub fn symbolic(
428        row_offsets: &[usize],
429        col_indices: &[usize],
430        n: usize,
431    ) -> Result<Self, SolverError> {
432        if row_offsets.len() != n + 1 {
433            return Err(SolverError::DimensionMismatch(format!(
434                "row_offsets length {} != n+1 = {}",
435                row_offsets.len(),
436                n + 1
437            )));
438        }
439
440        let etree = EliminationTree::from_csr(row_offsets, col_indices, n);
441        let structure = SupernodalStructure::from_etree(&etree, row_offsets, col_indices);
442
443        Ok(Self {
444            structure,
445            factored: false,
446            etree,
447            n,
448        })
449    }
450
451    /// Numeric factorization: fill supernode dense blocks with L values.
452    ///
453    /// Builds the full symmetric matrix from CSR lower triangle, then
454    /// assembles and factors supernodes in postorder.
455    pub fn numeric(
456        &mut self,
457        row_offsets: &[usize],
458        col_indices: &[usize],
459        values: &[f64],
460    ) -> Result<(), SolverError> {
461        let n = self.n;
462
463        // Build full symmetric dense matrix from CSR lower triangle
464        let mut dense = vec![0.0f64; n * n];
465        for i in 0..n {
466            let rs = row_offsets.get(i).copied().unwrap_or(0);
467            let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
468            for idx in rs..re {
469                let j = match col_indices.get(idx) {
470                    Some(&c) => c,
471                    None => continue,
472                };
473                let val = values.get(idx).copied().unwrap_or(0.0);
474                if i < n && j < n {
475                    dense[i + j * n] = val;
476                    dense[j + i * n] = val; // symmetrize
477                }
478            }
479        }
480
481        // Reset dense blocks
482        for sn in &mut self.structure.supernodes {
483            for v in &mut sn.dense_block {
484                *v = 0.0;
485            }
486        }
487
488        // Assemble entries into supernode dense blocks from the full matrix
489        for sn in &mut self.structure.supernodes {
490            let ncols = sn.width();
491            let nrows = sn.nrows();
492            for lc in 0..ncols {
493                let gc = sn.start + lc;
494                for (lr, &gr) in sn.columns.iter().enumerate() {
495                    if gr < n && gc < n {
496                        sn.dense_block[lr + lc * nrows] = dense[gr + gc * n];
497                    }
498                }
499            }
500        }
501
502        // Process supernodes in postorder
503        let postorder: Vec<usize> = self.etree.postorder_traversal().to_vec();
504        let num_supernodes = self.structure.supernodes.len();
505        let mut processed = vec![false; num_supernodes];
506
507        for &node in &postorder {
508            let sn_idx = match self.structure.membership.get(node) {
509                Some(&idx) if idx < num_supernodes => idx,
510                _ => continue,
511            };
512
513            if processed[sn_idx] {
514                continue;
515            }
516            processed[sn_idx] = true;
517
518            self.factor_supernode(sn_idx)?;
519        }
520
521        self.factored = true;
522        Ok(())
523    }
524
525    fn factor_supernode(&mut self, sn_idx: usize) -> Result<(), SolverError> {
526        let sn = match self.structure.supernodes.get(sn_idx) {
527            Some(s) => s,
528            None => {
529                return Err(SolverError::InternalError(
530                    "invalid supernode index".to_string(),
531                ));
532            }
533        };
534
535        let ncols = sn.width();
536        let nrows = sn.nrows();
537
538        if ncols == 0 || nrows == 0 {
539            return Ok(());
540        }
541
542        let mut block = self.structure.supernodes[sn_idx].dense_block.clone();
543
544        // Dense Cholesky on the diagonal block (ncols x ncols, top-left)
545        for k in 0..ncols {
546            let diag_idx = k + k * nrows;
547            let diag_val = match block.get(diag_idx) {
548                Some(&v) => v,
549                None => {
550                    return Err(SolverError::InternalError(
551                        "dense block index out of bounds".to_string(),
552                    ));
553                }
554            };
555
556            if diag_val <= 0.0 {
557                return Err(SolverError::NotPositiveDefinite);
558            }
559            let l_kk = diag_val.sqrt();
560            block[diag_idx] = l_kk;
561            let l_kk_inv = 1.0 / l_kk;
562
563            // Scale column k below diagonal
564            for i in (k + 1)..nrows {
565                block[i + k * nrows] *= l_kk_inv;
566            }
567
568            // Update trailing submatrix
569            for j in (k + 1)..ncols {
570                let l_jk = block[j + k * nrows];
571                for i in j..nrows {
572                    block[i + j * nrows] -= block[i + k * nrows] * l_jk;
573                }
574            }
575        }
576
577        // Update ancestor supernodes with off-diagonal contribution
578        if nrows > ncols {
579            let off_rows: Vec<usize> = self.structure.supernodes[sn_idx].columns[ncols..].to_vec();
580            let off_nrows = nrows - ncols;
581
582            // Compute update matrix: L_21 * L_21^T
583            let mut update = vec![0.0f64; off_nrows * off_nrows];
584            for k in 0..ncols {
585                for i in 0..off_nrows {
586                    let l_ik = block[(ncols + i) + k * nrows];
587                    for j in 0..=i {
588                        let l_jk = block[(ncols + j) + k * nrows];
589                        update[i + j * off_nrows] += l_ik * l_jk;
590                    }
591                }
592            }
593
594            // Scatter update into ancestor supernodes
595            for i in 0..off_nrows {
596                for j in 0..=i {
597                    let row_i = off_rows[i];
598                    let row_j = off_rows[j];
599                    let target_sn_idx = match self.structure.membership.get(row_j) {
600                        Some(&idx) => idx,
601                        None => continue,
602                    };
603                    let target = match self.structure.supernodes.get_mut(target_sn_idx) {
604                        Some(s) => s,
605                        None => continue,
606                    };
607                    let local_col = row_j - target.start;
608                    if local_col >= target.width() {
609                        continue;
610                    }
611                    if let Some(local_row) = target.columns.iter().position(|&r| r == row_i) {
612                        let tnrows = target.nrows();
613                        if let Some(entry) =
614                            target.dense_block.get_mut(local_row + local_col * tnrows)
615                        {
616                            *entry -= update[i + j * off_nrows];
617                        }
618                    }
619                    // Also scatter the symmetric entry if i != j
620                    if i != j {
621                        let target2 = match self
622                            .structure
623                            .supernodes
624                            .get_mut(*self.structure.membership.get(row_i).unwrap_or(&0))
625                        {
626                            Some(s) => s,
627                            None => continue,
628                        };
629                        let local_col2 = row_i - target2.start;
630                        if local_col2 >= target2.width() {
631                            continue;
632                        }
633                        if let Some(local_row2) = target2.columns.iter().position(|&r| r == row_j) {
634                            let tnrows2 = target2.nrows();
635                            if let Some(entry2) = target2
636                                .dense_block
637                                .get_mut(local_row2 + local_col2 * tnrows2)
638                            {
639                                *entry2 -= update[i + j * off_nrows];
640                            }
641                        }
642                    }
643                }
644            }
645        }
646
647        self.structure.supernodes[sn_idx].dense_block = block;
648        Ok(())
649    }
650
651    /// Solve `A * x = b` using the supernodal Cholesky factorization.
652    ///
653    /// Performs forward solve `L * y = b` then backward solve `L^T * x = y`.
654    pub fn solve(&self, rhs: &[f64]) -> Result<Vec<f64>, SolverError> {
655        if !self.factored {
656            return Err(SolverError::InternalError(
657                "numeric factorization not performed".to_string(),
658            ));
659        }
660        if rhs.len() != self.n {
661            return Err(SolverError::DimensionMismatch(format!(
662                "rhs length {} != n = {}",
663                rhs.len(),
664                self.n
665            )));
666        }
667
668        let mut x = rhs.to_vec();
669
670        // Forward solve: L * y = b
671        for sn in &self.structure.supernodes {
672            let ncols = sn.width();
673            let nrows = sn.nrows();
674
675            for k in 0..ncols {
676                let l_kk = sn.dense_block[k + k * nrows];
677                if l_kk.abs() < 1e-300 {
678                    return Err(SolverError::SingularMatrix);
679                }
680                let global_k = sn.columns[k];
681                x[global_k] /= l_kk;
682
683                let x_k = x[global_k];
684                for i in (k + 1)..nrows {
685                    let global_i = sn.columns[i];
686                    x[global_i] -= sn.dense_block[i + k * nrows] * x_k;
687                }
688            }
689        }
690
691        // Backward solve: L^T * x = y
692        for sn in self.structure.supernodes.iter().rev() {
693            let ncols = sn.width();
694            let nrows = sn.nrows();
695
696            for k in (0..ncols).rev() {
697                let global_k = sn.columns[k];
698                for i in (k + 1)..nrows {
699                    let global_i = sn.columns[i];
700                    x[global_k] -= sn.dense_block[i + k * nrows] * x[global_i];
701                }
702
703                let l_kk = sn.dense_block[k + k * nrows];
704                if l_kk.abs() < 1e-300 {
705                    return Err(SolverError::SingularMatrix);
706                }
707                x[global_k] /= l_kk;
708            }
709        }
710
711        Ok(x)
712    }
713
714    /// Number of non-zeros in the L factor.
715    pub fn nnz_factor(&self) -> usize {
716        self.structure
717            .supernodes
718            .iter()
719            .map(|sn| {
720                let ncols = sn.width();
721                let nrows = sn.nrows();
722                let diag_nnz = ncols * (ncols + 1) / 2;
723                let offdiag_nnz = (nrows - ncols) * ncols;
724                diag_nnz + offdiag_nnz
725            })
726            .sum()
727    }
728}
729
730// ---------------------------------------------------------------------------
731// Multifrontal LU Solver
732// ---------------------------------------------------------------------------
733
734/// Multifrontal LU solver for general (non-symmetric) sparse systems.
735///
736/// Performs `P * A = L * U` factorization with partial pivoting within each
737/// frontal matrix, using the supernodal structure for efficient dense operations.
738#[derive(Debug, Clone)]
739pub struct MultifrontalLUSolver {
740    /// Global L factor (column-major dense, n x n).
741    l_factor: Vec<f64>,
742    /// Global U factor (column-major dense, n x n).
743    u_factor: Vec<f64>,
744    /// Global pivot permutation.
745    perm: Vec<usize>,
746    /// Whether numeric factorization has been performed.
747    factored: bool,
748    /// The supernodal structure (for symbolic analysis).
749    #[allow(dead_code)]
750    structure: SupernodalStructure,
751    /// The elimination tree.
752    #[allow(dead_code)]
753    etree: EliminationTree,
754    /// Matrix dimension.
755    n: usize,
756}
757
758impl MultifrontalLUSolver {
759    /// Symbolic factorization.
760    pub fn symbolic(
761        row_offsets: &[usize],
762        col_indices: &[usize],
763        n: usize,
764    ) -> Result<Self, SolverError> {
765        if row_offsets.len() != n + 1 {
766            return Err(SolverError::DimensionMismatch(format!(
767                "row_offsets length {} != n+1 = {}",
768                row_offsets.len(),
769                n + 1
770            )));
771        }
772
773        let etree = EliminationTree::from_csr(row_offsets, col_indices, n);
774        let structure = SupernodalStructure::from_etree(&etree, row_offsets, col_indices);
775
776        Ok(Self {
777            l_factor: Vec::new(),
778            u_factor: Vec::new(),
779            perm: Vec::new(),
780            factored: false,
781            structure,
782            etree,
783            n,
784        })
785    }
786
787    /// Numeric factorization with partial pivoting.
788    ///
789    /// Assembles the full matrix and performs LU decomposition with
790    /// partial pivoting (GETRF-like).
791    pub fn numeric(
792        &mut self,
793        row_offsets: &[usize],
794        col_indices: &[usize],
795        values: &[f64],
796    ) -> Result<(), SolverError> {
797        let n = self.n;
798
799        // Build full dense matrix from CSR
800        let mut a = vec![0.0f64; n * n];
801        for i in 0..n {
802            let rs = row_offsets.get(i).copied().unwrap_or(0);
803            let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
804            for idx in rs..re {
805                let j = match col_indices.get(idx) {
806                    Some(&c) => c,
807                    None => continue,
808                };
809                let val = values.get(idx).copied().unwrap_or(0.0);
810                if i < n && j < n {
811                    a[i + j * n] = val;
812                }
813            }
814        }
815
816        // LU factorization with partial pivoting (column-major)
817        let mut perm: Vec<usize> = (0..n).collect();
818
819        for k in 0..n {
820            // Find pivot
821            let mut max_val = 0.0f64;
822            let mut max_row = k;
823            for i in k..n {
824                let val = a[i + k * n].abs();
825                if val > max_val {
826                    max_val = val;
827                    max_row = i;
828                }
829            }
830
831            // Swap rows
832            if max_row != k {
833                perm.swap(k, max_row);
834                for j in 0..n {
835                    a.swap(k + j * n, max_row + j * n);
836                }
837            }
838
839            let pivot = a[k + k * n];
840            if pivot.abs() < 1e-300 {
841                continue; // zero pivot, skip
842            }
843
844            // Compute L column
845            for i in (k + 1)..n {
846                a[i + k * n] /= pivot;
847            }
848
849            // Update trailing submatrix
850            for j in (k + 1)..n {
851                let u_kj = a[k + j * n];
852                for i in (k + 1)..n {
853                    a[i + j * n] -= a[i + k * n] * u_kj;
854                }
855            }
856        }
857
858        // Extract L and U from the combined a matrix
859        let mut l = vec![0.0f64; n * n];
860        let mut u = vec![0.0f64; n * n];
861        for j in 0..n {
862            for i in 0..n {
863                if i > j {
864                    l[i + j * n] = a[i + j * n];
865                } else if i == j {
866                    l[i + j * n] = 1.0;
867                    u[i + j * n] = a[i + j * n];
868                } else {
869                    u[i + j * n] = a[i + j * n];
870                }
871            }
872        }
873
874        self.l_factor = l;
875        self.u_factor = u;
876        self.perm = perm;
877        self.factored = true;
878        Ok(())
879    }
880
881    /// Solve `A * x = b` using the LU factorization.
882    ///
883    /// Applies `P * b`, then forward solve `L * y = P * b`, then backward solve `U * x = y`.
884    pub fn solve(&self, rhs: &[f64]) -> Result<Vec<f64>, SolverError> {
885        if !self.factored {
886            return Err(SolverError::InternalError(
887                "numeric factorization not performed".to_string(),
888            ));
889        }
890        let n = self.n;
891        if rhs.len() != n {
892            return Err(SolverError::DimensionMismatch(format!(
893                "rhs length {} != n = {}",
894                rhs.len(),
895                n
896            )));
897        }
898
899        // Apply permutation: pb[k] = rhs[perm[k]]
900        // But perm records the row swaps applied during factorization.
901        // We need to apply the same swaps to rhs.
902        // Apply permutation swaps in the same order as factorization
903        // perm[k] tells us that row k was swapped with row perm[k]
904        // but we stored the cumulative result. We need to replay swaps.
905        // Actually, the perm array after the loop records the final
906        // position of each row. We need a different approach.
907        // Let's rebuild: the factorization swapped rows k and max_row
908        // at each step. perm tracks the original row indices.
909        // To apply P*b, we just need: pb[k] = b[original_row_of_k]
910        // But perm was computed as successive swaps, so we need to replay.
911
912        // Simpler: during factorization we did perm.swap(k, max_row).
913        // perm[k] after all swaps = the original row that ended up at row k.
914        // So P*b means: pb[k] = b[perm[k]]
915        let mut pb = vec![0.0f64; n];
916        for k in 0..n {
917            pb[k] = rhs[self.perm[k]];
918        }
919
920        // Forward solve: L * y = pb
921        let mut x = pb;
922        for k in 0..n {
923            for i in (k + 1)..n {
924                x[i] -= self.l_factor[i + k * n] * x[k];
925            }
926        }
927
928        // Backward solve: U * z = y
929        for k in (0..n).rev() {
930            let u_kk = self.u_factor[k + k * n];
931            if u_kk.abs() < 1e-300 {
932                return Err(SolverError::SingularMatrix);
933            }
934            x[k] /= u_kk;
935            for i in 0..k {
936                x[i] -= self.u_factor[i + k * n] * x[k];
937            }
938        }
939
940        Ok(x)
941    }
942}
943
944// ---------------------------------------------------------------------------
945// Convenience functions
946// ---------------------------------------------------------------------------
947
948/// Solve a symmetric positive definite sparse system `A * x = b` via supernodal Cholesky.
949///
950/// Takes CSR format (lower triangle only). Convenience wrapper that performs
951/// symbolic + numeric factorization + solve in one call.
952pub fn sparse_cholesky_solve(
953    row_offsets: &[usize],
954    col_indices: &[usize],
955    values: &[f64],
956    n: usize,
957    rhs: &[f64],
958) -> Result<Vec<f64>, SolverError> {
959    let mut solver = SupernodalCholeskySolver::symbolic(row_offsets, col_indices, n)?;
960    solver.numeric(row_offsets, col_indices, values)?;
961    solver.solve(rhs)
962}
963
964/// Solve a general sparse system `A * x = b` via multifrontal LU.
965///
966/// Takes CSR format (full matrix, both triangles). Convenience wrapper that
967/// performs symbolic + numeric factorization + solve in one call.
968pub fn sparse_lu_solve(
969    row_offsets: &[usize],
970    col_indices: &[usize],
971    values: &[f64],
972    n: usize,
973    rhs: &[f64],
974) -> Result<Vec<f64>, SolverError> {
975    let mut solver = MultifrontalLUSolver::symbolic(row_offsets, col_indices, n)?;
976    solver.numeric(row_offsets, col_indices, values)?;
977    solver.solve(rhs)
978}
979
980// ---------------------------------------------------------------------------
981// Tests
982// ---------------------------------------------------------------------------
983
#[cfg(test)]
mod tests {
    use super::*;

    /// Lower-triangle CSR for the `n x n` SPD tridiagonal matrix with 4 on the
    /// diagonal and 1 on the off-diagonals. Returns `(row_offsets, col_indices,
    /// values, n)`.
    fn spd_tridiag_lower(n: usize) -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        let mut row_offsets = Vec::with_capacity(n + 1);
        let mut col_indices = Vec::new();
        let mut values = Vec::new();
        row_offsets.push(0);
        for row in 0..n {
            // Sub-diagonal entry (absent for the first row).
            if let Some(prev) = row.checked_sub(1) {
                col_indices.push(prev);
                values.push(1.0);
            }
            col_indices.push(row);
            values.push(4.0);
            row_offsets.push(col_indices.len());
        }
        (row_offsets, col_indices, values, n)
    }

    /// Lower triangle of `[[4, 1, 0], [1, 4, 1], [0, 1, 4]]`.
    fn spd_3x3_lower() -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        spd_tridiag_lower(3)
    }

    /// Lower triangle of the 5x5 tridiagonal SPD matrix (4 / 1 pattern).
    fn spd_5x5_tridiag_lower() -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        spd_tridiag_lower(5)
    }

    /// Lower-triangle CSR of the `n x n` identity matrix.
    fn identity_lower(n: usize) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        ((0..=n).collect(), (0..n).collect(), vec![1.0; n])
    }

    /// Full (both triangles) CSR of the general 3x3 matrix
    /// `[[2, 1, 0], [1, 3, 1], [0, 1, 2]]`, for which `A * [1, 1, 1] = [3, 5, 3]`.
    fn general_3x3_full() -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        (
            vec![0, 2, 5, 7],
            vec![0, 1, 0, 1, 2, 1, 2],
            vec![2.0, 1.0, 1.0, 3.0, 1.0, 1.0, 2.0],
            3,
        )
    }

    /// Euclidean distance `||a - b||_2`.
    fn distance(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(p, q)| (p - q).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    /// `||A x - b||` where only the lower triangle of symmetric `A` is stored;
    /// each off-diagonal entry is applied to both of its mirrored positions.
    fn residual_norm_symmetric(
        row_offsets: &[usize],
        col_indices: &[usize],
        values: &[f64],
        n: usize,
        x: &[f64],
        b: &[f64],
    ) -> f64 {
        let mut ax = vec![0.0; n];
        for row in 0..n {
            let span = row_offsets[row]..row_offsets[row + 1];
            for (&col, &v) in col_indices[span.clone()].iter().zip(&values[span]) {
                ax[row] += v * x[col];
                if col != row {
                    ax[col] += v * x[row];
                }
            }
        }
        distance(&ax, b)
    }

    /// `||A x - b||` where `A` is stored in full CSR (both triangles).
    fn residual_norm_full(
        row_offsets: &[usize],
        col_indices: &[usize],
        values: &[f64],
        x: &[f64],
        b: &[f64],
    ) -> f64 {
        let ax: Vec<f64> = row_offsets
            .windows(2)
            .map(|bounds| {
                (bounds[0]..bounds[1])
                    .map(|idx| values[idx] * x[col_indices[idx]])
                    .sum::<f64>()
            })
            .collect();
        distance(&ax, b)
    }

    #[test]
    fn test_elimination_tree_simple() {
        let (row_offsets, col_indices, _, n) = spd_3x3_lower();
        let etree = EliminationTree::from_csr(&row_offsets, &col_indices, n);

        assert_eq!(etree.size(), 3);
        // Tridiagonal structure yields a chain 0 -> 1 -> 2.
        assert_eq!(etree.parent_of(0), Some(1));
        assert_eq!(etree.parent_of(1), Some(2));
        assert_eq!(etree.parent_of(2), None);
    }

    #[test]
    fn test_postorder_traversal() {
        let (row_offsets, col_indices, _, n) = spd_3x3_lower();
        let postorder = EliminationTree::from_csr(&row_offsets, &col_indices, n)
            .postorder_traversal()
            .to_vec();

        assert_eq!(postorder.len(), 3);
        assert_eq!(postorder, &[0, 1, 2]);
    }

    #[test]
    fn test_subtree_size() {
        let (row_offsets, col_indices, _, n) = spd_3x3_lower();
        let etree = EliminationTree::from_csr(&row_offsets, &col_indices, n);

        // Root covers everything; each step down the chain drops one node.
        assert_eq!(etree.subtree_size(2), 3);
        assert_eq!(etree.subtree_size(1), 2);
        assert_eq!(etree.subtree_size(0), 1);
    }

    #[test]
    fn test_supernode_detection_diagonal() {
        let n = 4;
        let (row_offsets, col_indices, _) = identity_lower(n);
        let etree = EliminationTree::from_csr(&row_offsets, &col_indices, n);
        let structure = SupernodalStructure::from_etree(&etree, &row_offsets, &col_indices);

        // A diagonal matrix has no fill, so every column is its own supernode.
        assert_eq!(structure.supernodes.len(), n);
        assert!(structure.supernodes.iter().all(|sn| sn.width() == 1));
    }

    #[test]
    fn test_supernodal_cholesky_3x3() {
        let (row_offsets, col_indices, values, n) = spd_3x3_lower();

        let mut solver = SupernodalCholeskySolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric should succeed");

        assert!(solver.factored);
    }

    #[test]
    fn test_supernodal_cholesky_5x5_tridiag() {
        let (row_offsets, col_indices, values, n) = spd_5x5_tridiag_lower();

        let mut solver = SupernodalCholeskySolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric should succeed");

        assert!(solver.factored);
    }

    #[test]
    fn test_cholesky_solve_accuracy() {
        let (row_offsets, col_indices, values, n) = spd_3x3_lower();
        // A * [1, 1, 1] = [5, 6, 5]
        let rhs = vec![5.0, 6.0, 5.0];

        let mut solver = SupernodalCholeskySolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric should succeed");
        let x = solver.solve(&rhs).expect("solve should succeed");

        let residual = residual_norm_symmetric(&row_offsets, &col_indices, &values, n, &x, &rhs);
        assert!(
            residual < 1e-10,
            "residual {residual:.3e} exceeds tolerance 1e-10"
        );
    }

    #[test]
    fn test_lu_factorization_3x3() {
        let (row_offsets, col_indices, values, n) = general_3x3_full();

        let mut solver = MultifrontalLUSolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric should succeed");

        assert!(solver.factored);
    }

    #[test]
    fn test_lu_solve_accuracy() {
        let (row_offsets, col_indices, values, n) = general_3x3_full();
        let rhs = vec![3.0, 5.0, 3.0];

        let mut solver = MultifrontalLUSolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric should succeed");
        let x = solver.solve(&rhs).expect("solve should succeed");

        let residual = residual_norm_full(&row_offsets, &col_indices, &values, &x, &rhs);
        assert!(
            residual < 1e-10,
            "LU solve residual {residual:.3e} exceeds tolerance"
        );
    }

    #[test]
    fn test_symbolic_factorization_reuse() {
        let (row_offsets, col_indices, _, n) = spd_3x3_lower();

        let sym = SymbolicFactorization::compute(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");

        assert!(sym.nnz_l > 0);
        assert_eq!(sym.nnz_l, sym.nnz_u);
        assert_eq!(sym.etree.size(), n);
        assert!(!sym.structure.supernodes.is_empty());
    }

    #[test]
    fn test_column_counts() {
        let (row_offsets, col_indices, _, n) = spd_3x3_lower();
        let etree = EliminationTree::from_csr(&row_offsets, &col_indices, n);
        let counts = column_counts(&row_offsets, &col_indices, &etree);

        assert_eq!(counts.len(), 3);
        // Columns 0 and 1: diagonal plus one sub-diagonal entry each.
        assert_eq!(counts[0], 2);
        assert_eq!(counts[1], 2);
        // Column 2: diagonal only.
        assert_eq!(counts[2], 1);
    }

    #[test]
    fn test_sparse_cholesky_solve_convenience() {
        let (row_offsets, col_indices, values, n) = spd_3x3_lower();
        let rhs = vec![5.0, 6.0, 5.0];

        let x = sparse_cholesky_solve(&row_offsets, &col_indices, &values, n, &rhs)
            .expect("convenience solve should succeed");

        let residual = residual_norm_symmetric(&row_offsets, &col_indices, &values, n, &x, &rhs);
        assert!(
            residual < 1e-10,
            "convenience solve residual {residual:.3e} too large"
        );
    }

    #[test]
    fn test_sparse_lu_solve_convenience() {
        let (row_offsets, col_indices, values, n) = general_3x3_full();
        let rhs = vec![3.0, 5.0, 3.0];

        let x = sparse_lu_solve(&row_offsets, &col_indices, &values, n, &rhs)
            .expect("LU convenience solve should succeed");

        let residual = residual_norm_full(&row_offsets, &col_indices, &values, &x, &rhs);
        assert!(
            residual < 1e-10,
            "LU convenience solve residual {residual:.3e} too large"
        );
    }

    #[test]
    fn test_non_spd_cholesky_failure() {
        let (row_offsets, col_indices, mut values, n) = spd_3x3_lower();
        // Flip A[0,0] to -4 so the matrix is no longer positive definite.
        values[0] = -4.0;

        let mut solver = SupernodalCholeskySolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        let result = solver.numeric(&row_offsets, &col_indices, &values);

        assert!(matches!(result, Err(SolverError::NotPositiveDefinite)));
    }

    #[test]
    fn test_singular_matrix_lu() {
        // Rows 0 and 1 are identical, so A is singular.
        let row_offsets = vec![0, 2, 4, 5];
        let col_indices = vec![0, 1, 0, 1, 2];
        let values = vec![1.0, 2.0, 1.0, 2.0, 1.0];
        let n = 3;
        let rhs = vec![1.0, 1.0, 1.0];

        let mut solver = MultifrontalLUSolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric may succeed with zero pivot stored");

        assert!(solver.solve(&rhs).is_err());
    }

    #[test]
    fn test_identity_factorization() {
        let n = 4;
        let (row_offsets, col_indices, values) = identity_lower(n);

        let mut solver = SupernodalCholeskySolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric should succeed on identity");

        let rhs = vec![1.0, 2.0, 3.0, 4.0];
        let x = solver.solve(&rhs).expect("solve should succeed");

        // Solving with the identity must reproduce the right-hand side.
        for (i, &want) in rhs.iter().enumerate() {
            let got = x[i];
            assert!(
                (got - want).abs() < 1e-14,
                "identity solve failed at index {i}: got {} expected {}",
                got,
                want
            );
        }
    }

    #[test]
    fn test_nnz_factor_count() {
        let (row_offsets, col_indices, values, n) = spd_3x3_lower();

        let mut solver = SupernodalCholeskySolver::symbolic(&row_offsets, &col_indices, n)
            .expect("symbolic should succeed");
        solver
            .numeric(&row_offsets, &col_indices, &values)
            .expect("numeric should succeed");

        // For the 3x3 tridiagonal matrix, L has entries at
        // (0,0), (1,0), (1,1), (2,1), (2,2) = 5 entries.
        let nnz = solver.nnz_factor();
        assert!(nnz >= 5, "nnz_factor = {nnz}, expected >= 5");
    }
}