1use crate::error::SolverError;
13
/// Elimination tree of a sparse symmetric matrix, built from the strictly
/// lower-triangular part of a CSR pattern.
///
/// All per-node vectors have length `n`.
#[derive(Debug, Clone)]
pub struct EliminationTree {
    /// `parent[v]` is the tree parent of node `v`, or `None` when `v` is a root.
    parent: Vec<Option<usize>>,
    /// `children[v]` lists the nodes whose parent is `v`, in increasing order.
    children: Vec<Vec<usize>>,
    /// All nodes ordered so that every node appears after its descendants.
    postorder: Vec<usize>,
    /// Number of nodes (matrix dimension).
    n: usize,
}

impl EliminationTree {
    /// Builds the elimination tree from a CSR sparsity pattern.
    ///
    /// Only entries with `column < row` are used; diagonal and upper entries
    /// are ignored. Row offsets or column positions that fall outside the
    /// provided slices are treated as absent rather than causing a panic.
    pub fn from_csr(row_offsets: &[usize], col_indices: &[usize], n: usize) -> Self {
        let mut parent: Vec<Option<usize>> = vec![None; n];
        // `ancestor[v] == v` means `v` has no ancestor shortcut yet, i.e. it is
        // currently the top of its partially built subtree.
        let mut ancestor: Vec<usize> = (0..n).collect();

        for row in 0..n {
            let begin = row_offsets.get(row).copied().unwrap_or(0);
            let finish = row_offsets.get(row + 1).copied().unwrap_or(begin);

            let below_diagonal = (begin..finish)
                .filter_map(|pos| col_indices.get(pos).copied())
                .filter(|&col| col < row);

            for col in below_diagonal {
                // Climb to the top of `col`'s subtree, redirecting every
                // shortcut on the way to `row` (path compression).
                let mut top = col;
                loop {
                    let up = ancestor[top];
                    if up == top {
                        break;
                    }
                    ancestor[top] = row;
                    top = up;
                }
                // Reaching `row` itself means this subtree is already attached.
                if top != row && parent[top].is_none() {
                    parent[top] = Some(row);
                    ancestor[top] = row;
                }
            }
        }

        let mut children: Vec<Vec<usize>> = vec![Vec::new(); n];
        for (node, link) in parent.iter().enumerate() {
            if let Some(p) = *link {
                children[p].push(node);
            }
        }

        let postorder = Self::compute_postorder(&parent, &children, n);

        Self {
            parent,
            children,
            postorder,
            n,
        }
    }

    /// Produces a depth-first postorder of the forest described by
    /// `parent`/`children`, visiting roots and children in increasing order.
    fn compute_postorder(
        parent: &[Option<usize>],
        children: &[Vec<usize>],
        n: usize,
    ) -> Vec<usize> {
        let mut order = Vec::with_capacity(n);
        let mut emitted = vec![false; n];

        for root in (0..n).filter(|&v| parent[v].is_none()) {
            // Each frame is (node, index of the next child to descend into).
            let mut path: Vec<(usize, usize)> = vec![(root, 0)];
            while let Some(frame) = path.last_mut() {
                let (node, cursor) = *frame;
                match children[node].get(cursor) {
                    Some(&kid) => {
                        frame.1 += 1;
                        if !emitted[kid] {
                            path.push((kid, 0));
                        }
                    }
                    None => {
                        // All children handled: the node itself comes last.
                        order.push(node);
                        emitted[node] = true;
                        path.pop();
                    }
                }
            }
        }

        order
    }

    /// Returns the nodes in postorder (descendants before ancestors).
    pub fn postorder_traversal(&self) -> &[usize] {
        self.postorder.as_slice()
    }

    /// Returns the number of nodes in the subtree rooted at `node`, counting
    /// `node` itself, or `0` when `node` is out of range.
    pub fn subtree_size(&self, node: usize) -> usize {
        if node >= self.n {
            return 0;
        }
        let mut pending = vec![node];
        let mut total = 0usize;
        while let Some(current) = pending.pop() {
            total += 1;
            pending.extend_from_slice(self.children_of(current));
        }
        total
    }

    /// Returns the number of nodes in the tree.
    pub fn size(&self) -> usize {
        self.n
    }

    /// Returns the parent of `node`, or `None` for roots and out-of-range nodes.
    pub fn parent_of(&self, node: usize) -> Option<usize> {
        self.parent.get(node).and_then(|link| *link)
    }

    /// Returns the children of `node`; empty for leaves and out-of-range nodes.
    pub fn children_of(&self, node: usize) -> &[usize] {
        match self.children.get(node) {
            Some(kids) => kids.as_slice(),
            None => &[],
        }
    }
}
151
152pub fn column_counts(
161 row_offsets: &[usize],
162 col_indices: &[usize],
163 etree: &EliminationTree,
164) -> Vec<usize> {
165 let n = etree.size();
166 let mut counts = vec![1usize; n];
167
168 let mut col_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
170 for i in 0..n {
171 let rs = row_offsets.get(i).copied().unwrap_or(0);
172 let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
173 for idx in rs..re {
174 if let Some(&j) = col_indices.get(idx) {
175 if j < i {
176 col_rows[j].push(i);
177 }
178 }
179 }
180 }
181
182 let mut l_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
184 for &node in etree.postorder_traversal() {
185 let mut rows: Vec<usize> = col_rows[node].clone();
186
187 for &child in etree.children_of(node) {
188 for &r in &l_rows[child] {
189 if r > node {
190 rows.push(r);
191 }
192 }
193 }
194
195 rows.sort_unstable();
196 rows.dedup();
197 counts[node] = 1 + rows.len();
198 l_rows[node] = rows;
199 }
200
201 counts
202}
203
/// A group of consecutive factor columns `start..end` stored together as one
/// dense block.
#[derive(Debug, Clone)]
pub struct Supernode {
    /// First global column of the supernode (inclusive).
    pub start: usize,
    /// One past the last global column of the supernode (exclusive).
    pub end: usize,
    /// Global row index of each row of `dense_block`, in ascending order.
    /// Despite the name these are row indices: as built by
    /// `SupernodalStructure::from_etree`, the first `width()` entries are
    /// `start..end` and the rest are the off-diagonal rows below the block.
    pub columns: Vec<usize>,
    /// Column-major `nrows() x width()` storage: entry (local row `r`, local
    /// column `c`) lives at `r + c * nrows()`.
    pub dense_block: Vec<f64>,
}

impl Supernode {
    /// Number of columns in the supernode.
    ///
    /// The fields are public, so an inverted range (`end < start`) can be
    /// constructed by callers; it is reported as width `0` instead of
    /// overflowing (panic in debug builds, wrap-around in release builds).
    pub fn width(&self) -> usize {
        self.end.saturating_sub(self.start)
    }

    /// Number of rows in the dense block (length of `columns`).
    pub fn nrows(&self) -> usize {
        self.columns.len()
    }
}
232
233#[derive(Debug, Clone)]
239pub struct SupernodalStructure {
240 pub supernodes: Vec<Supernode>,
242 pub membership: Vec<usize>,
244}
245
246impl SupernodalStructure {
247 pub fn from_etree(
249 etree: &EliminationTree,
250 row_offsets: &[usize],
251 col_indices: &[usize],
252 ) -> Self {
253 let n = etree.size();
254 let col_cnts = column_counts(row_offsets, col_indices, etree);
255
256 let mut is_start = vec![true; n];
258 for j in 0..n.saturating_sub(1) {
259 if etree.parent_of(j) == Some(j + 1)
260 && col_cnts[j + 1] + 1 == col_cnts[j]
261 && etree.children_of(j + 1).len() <= 1
262 {
263 is_start[j + 1] = false;
264 }
265 }
266
267 let mut col_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
269 for i in 0..n {
270 let rs = row_offsets.get(i).copied().unwrap_or(0);
271 let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
272 for idx in rs..re {
273 if let Some(&j) = col_indices.get(idx) {
274 if j < i {
275 col_rows[j].push(i);
276 }
277 }
278 }
279 }
280
281 let mut l_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
283 for &node in etree.postorder_traversal() {
284 let mut rows: Vec<usize> = col_rows[node].clone();
285 for &child in etree.children_of(node) {
286 for &r in &l_rows[child] {
287 if r > node {
288 rows.push(r);
289 }
290 }
291 }
292 rows.sort_unstable();
293 rows.dedup();
294 l_rows[node] = rows;
295 }
296
297 let mut supernodes = Vec::new();
299 let mut membership = vec![0usize; n];
300
301 let mut i = 0;
302 while i < n {
303 let start = i;
304 let mut end = i + 1;
305 while end < n && !is_start[end] {
306 end += 1;
307 }
308
309 let mut rows: Vec<usize> = (start..end).collect();
311 for &r in &l_rows[start] {
314 if r >= end {
315 rows.push(r);
316 }
317 }
318 for l_row_set in l_rows.iter().take(end).skip(start + 1) {
320 for &r in l_row_set {
321 if r >= end && !rows.contains(&r) {
322 rows.push(r);
323 }
324 }
325 }
326 rows.sort_unstable();
327 rows.dedup();
328
329 let nrows = rows.len();
330 let ncols = end - start;
331
332 let sn_idx = supernodes.len();
333 for m in membership.iter_mut().take(end).skip(start) {
334 *m = sn_idx;
335 }
336
337 supernodes.push(Supernode {
338 start,
339 end,
340 columns: rows,
341 dense_block: vec![0.0; nrows * ncols],
342 });
343
344 i = end;
345 }
346
347 Self {
348 supernodes,
349 membership,
350 }
351 }
352}
353
354#[derive(Debug, Clone)]
361pub struct SymbolicFactorization {
362 pub etree: EliminationTree,
364 pub structure: SupernodalStructure,
366 pub nnz_l: usize,
368 pub nnz_u: usize,
370}
371
372impl SymbolicFactorization {
373 pub fn compute(
375 row_offsets: &[usize],
376 col_indices: &[usize],
377 n: usize,
378 ) -> Result<Self, SolverError> {
379 if row_offsets.len() != n + 1 {
380 return Err(SolverError::DimensionMismatch(format!(
381 "row_offsets length {} != n+1 = {}",
382 row_offsets.len(),
383 n + 1
384 )));
385 }
386
387 let etree = EliminationTree::from_csr(row_offsets, col_indices, n);
388 let structure = SupernodalStructure::from_etree(&etree, row_offsets, col_indices);
389
390 let nnz_l: usize = structure
391 .supernodes
392 .iter()
393 .map(|sn| sn.nrows() * sn.width())
394 .sum();
395
396 Ok(Self {
397 etree,
398 structure,
399 nnz_l,
400 nnz_u: nnz_l,
401 })
402 }
403}
404
405#[derive(Debug, Clone)]
414pub struct SupernodalCholeskySolver {
415 structure: SupernodalStructure,
417 factored: bool,
419 etree: EliminationTree,
421 n: usize,
423}
424
425impl SupernodalCholeskySolver {
426 pub fn symbolic(
428 row_offsets: &[usize],
429 col_indices: &[usize],
430 n: usize,
431 ) -> Result<Self, SolverError> {
432 if row_offsets.len() != n + 1 {
433 return Err(SolverError::DimensionMismatch(format!(
434 "row_offsets length {} != n+1 = {}",
435 row_offsets.len(),
436 n + 1
437 )));
438 }
439
440 let etree = EliminationTree::from_csr(row_offsets, col_indices, n);
441 let structure = SupernodalStructure::from_etree(&etree, row_offsets, col_indices);
442
443 Ok(Self {
444 structure,
445 factored: false,
446 etree,
447 n,
448 })
449 }
450
451 pub fn numeric(
456 &mut self,
457 row_offsets: &[usize],
458 col_indices: &[usize],
459 values: &[f64],
460 ) -> Result<(), SolverError> {
461 let n = self.n;
462
463 let mut dense = vec![0.0f64; n * n];
465 for i in 0..n {
466 let rs = row_offsets.get(i).copied().unwrap_or(0);
467 let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
468 for idx in rs..re {
469 let j = match col_indices.get(idx) {
470 Some(&c) => c,
471 None => continue,
472 };
473 let val = values.get(idx).copied().unwrap_or(0.0);
474 if i < n && j < n {
475 dense[i + j * n] = val;
476 dense[j + i * n] = val; }
478 }
479 }
480
481 for sn in &mut self.structure.supernodes {
483 for v in &mut sn.dense_block {
484 *v = 0.0;
485 }
486 }
487
488 for sn in &mut self.structure.supernodes {
490 let ncols = sn.width();
491 let nrows = sn.nrows();
492 for lc in 0..ncols {
493 let gc = sn.start + lc;
494 for (lr, &gr) in sn.columns.iter().enumerate() {
495 if gr < n && gc < n {
496 sn.dense_block[lr + lc * nrows] = dense[gr + gc * n];
497 }
498 }
499 }
500 }
501
502 let postorder: Vec<usize> = self.etree.postorder_traversal().to_vec();
504 let num_supernodes = self.structure.supernodes.len();
505 let mut processed = vec![false; num_supernodes];
506
507 for &node in &postorder {
508 let sn_idx = match self.structure.membership.get(node) {
509 Some(&idx) if idx < num_supernodes => idx,
510 _ => continue,
511 };
512
513 if processed[sn_idx] {
514 continue;
515 }
516 processed[sn_idx] = true;
517
518 self.factor_supernode(sn_idx)?;
519 }
520
521 self.factored = true;
522 Ok(())
523 }
524
525 fn factor_supernode(&mut self, sn_idx: usize) -> Result<(), SolverError> {
526 let sn = match self.structure.supernodes.get(sn_idx) {
527 Some(s) => s,
528 None => {
529 return Err(SolverError::InternalError(
530 "invalid supernode index".to_string(),
531 ));
532 }
533 };
534
535 let ncols = sn.width();
536 let nrows = sn.nrows();
537
538 if ncols == 0 || nrows == 0 {
539 return Ok(());
540 }
541
542 let mut block = self.structure.supernodes[sn_idx].dense_block.clone();
543
544 for k in 0..ncols {
546 let diag_idx = k + k * nrows;
547 let diag_val = match block.get(diag_idx) {
548 Some(&v) => v,
549 None => {
550 return Err(SolverError::InternalError(
551 "dense block index out of bounds".to_string(),
552 ));
553 }
554 };
555
556 if diag_val <= 0.0 {
557 return Err(SolverError::NotPositiveDefinite);
558 }
559 let l_kk = diag_val.sqrt();
560 block[diag_idx] = l_kk;
561 let l_kk_inv = 1.0 / l_kk;
562
563 for i in (k + 1)..nrows {
565 block[i + k * nrows] *= l_kk_inv;
566 }
567
568 for j in (k + 1)..ncols {
570 let l_jk = block[j + k * nrows];
571 for i in j..nrows {
572 block[i + j * nrows] -= block[i + k * nrows] * l_jk;
573 }
574 }
575 }
576
577 if nrows > ncols {
579 let off_rows: Vec<usize> = self.structure.supernodes[sn_idx].columns[ncols..].to_vec();
580 let off_nrows = nrows - ncols;
581
582 let mut update = vec![0.0f64; off_nrows * off_nrows];
584 for k in 0..ncols {
585 for i in 0..off_nrows {
586 let l_ik = block[(ncols + i) + k * nrows];
587 for j in 0..=i {
588 let l_jk = block[(ncols + j) + k * nrows];
589 update[i + j * off_nrows] += l_ik * l_jk;
590 }
591 }
592 }
593
594 for i in 0..off_nrows {
596 for j in 0..=i {
597 let row_i = off_rows[i];
598 let row_j = off_rows[j];
599 let target_sn_idx = match self.structure.membership.get(row_j) {
600 Some(&idx) => idx,
601 None => continue,
602 };
603 let target = match self.structure.supernodes.get_mut(target_sn_idx) {
604 Some(s) => s,
605 None => continue,
606 };
607 let local_col = row_j - target.start;
608 if local_col >= target.width() {
609 continue;
610 }
611 if let Some(local_row) = target.columns.iter().position(|&r| r == row_i) {
612 let tnrows = target.nrows();
613 if let Some(entry) =
614 target.dense_block.get_mut(local_row + local_col * tnrows)
615 {
616 *entry -= update[i + j * off_nrows];
617 }
618 }
619 if i != j {
621 let target2 = match self
622 .structure
623 .supernodes
624 .get_mut(*self.structure.membership.get(row_i).unwrap_or(&0))
625 {
626 Some(s) => s,
627 None => continue,
628 };
629 let local_col2 = row_i - target2.start;
630 if local_col2 >= target2.width() {
631 continue;
632 }
633 if let Some(local_row2) = target2.columns.iter().position(|&r| r == row_j) {
634 let tnrows2 = target2.nrows();
635 if let Some(entry2) = target2
636 .dense_block
637 .get_mut(local_row2 + local_col2 * tnrows2)
638 {
639 *entry2 -= update[i + j * off_nrows];
640 }
641 }
642 }
643 }
644 }
645 }
646
647 self.structure.supernodes[sn_idx].dense_block = block;
648 Ok(())
649 }
650
651 pub fn solve(&self, rhs: &[f64]) -> Result<Vec<f64>, SolverError> {
655 if !self.factored {
656 return Err(SolverError::InternalError(
657 "numeric factorization not performed".to_string(),
658 ));
659 }
660 if rhs.len() != self.n {
661 return Err(SolverError::DimensionMismatch(format!(
662 "rhs length {} != n = {}",
663 rhs.len(),
664 self.n
665 )));
666 }
667
668 let mut x = rhs.to_vec();
669
670 for sn in &self.structure.supernodes {
672 let ncols = sn.width();
673 let nrows = sn.nrows();
674
675 for k in 0..ncols {
676 let l_kk = sn.dense_block[k + k * nrows];
677 if l_kk.abs() < 1e-300 {
678 return Err(SolverError::SingularMatrix);
679 }
680 let global_k = sn.columns[k];
681 x[global_k] /= l_kk;
682
683 let x_k = x[global_k];
684 for i in (k + 1)..nrows {
685 let global_i = sn.columns[i];
686 x[global_i] -= sn.dense_block[i + k * nrows] * x_k;
687 }
688 }
689 }
690
691 for sn in self.structure.supernodes.iter().rev() {
693 let ncols = sn.width();
694 let nrows = sn.nrows();
695
696 for k in (0..ncols).rev() {
697 let global_k = sn.columns[k];
698 for i in (k + 1)..nrows {
699 let global_i = sn.columns[i];
700 x[global_k] -= sn.dense_block[i + k * nrows] * x[global_i];
701 }
702
703 let l_kk = sn.dense_block[k + k * nrows];
704 if l_kk.abs() < 1e-300 {
705 return Err(SolverError::SingularMatrix);
706 }
707 x[global_k] /= l_kk;
708 }
709 }
710
711 Ok(x)
712 }
713
714 pub fn nnz_factor(&self) -> usize {
716 self.structure
717 .supernodes
718 .iter()
719 .map(|sn| {
720 let ncols = sn.width();
721 let nrows = sn.nrows();
722 let diag_nnz = ncols * (ncols + 1) / 2;
723 let offdiag_nnz = (nrows - ncols) * ncols;
724 diag_nnz + offdiag_nnz
725 })
726 .sum()
727 }
728}
729
730#[derive(Debug, Clone)]
739pub struct MultifrontalLUSolver {
740 l_factor: Vec<f64>,
742 u_factor: Vec<f64>,
744 perm: Vec<usize>,
746 factored: bool,
748 #[allow(dead_code)]
750 structure: SupernodalStructure,
751 #[allow(dead_code)]
753 etree: EliminationTree,
754 n: usize,
756}
757
758impl MultifrontalLUSolver {
759 pub fn symbolic(
761 row_offsets: &[usize],
762 col_indices: &[usize],
763 n: usize,
764 ) -> Result<Self, SolverError> {
765 if row_offsets.len() != n + 1 {
766 return Err(SolverError::DimensionMismatch(format!(
767 "row_offsets length {} != n+1 = {}",
768 row_offsets.len(),
769 n + 1
770 )));
771 }
772
773 let etree = EliminationTree::from_csr(row_offsets, col_indices, n);
774 let structure = SupernodalStructure::from_etree(&etree, row_offsets, col_indices);
775
776 Ok(Self {
777 l_factor: Vec::new(),
778 u_factor: Vec::new(),
779 perm: Vec::new(),
780 factored: false,
781 structure,
782 etree,
783 n,
784 })
785 }
786
787 pub fn numeric(
792 &mut self,
793 row_offsets: &[usize],
794 col_indices: &[usize],
795 values: &[f64],
796 ) -> Result<(), SolverError> {
797 let n = self.n;
798
799 let mut a = vec![0.0f64; n * n];
801 for i in 0..n {
802 let rs = row_offsets.get(i).copied().unwrap_or(0);
803 let re = row_offsets.get(i + 1).copied().unwrap_or(rs);
804 for idx in rs..re {
805 let j = match col_indices.get(idx) {
806 Some(&c) => c,
807 None => continue,
808 };
809 let val = values.get(idx).copied().unwrap_or(0.0);
810 if i < n && j < n {
811 a[i + j * n] = val;
812 }
813 }
814 }
815
816 let mut perm: Vec<usize> = (0..n).collect();
818
819 for k in 0..n {
820 let mut max_val = 0.0f64;
822 let mut max_row = k;
823 for i in k..n {
824 let val = a[i + k * n].abs();
825 if val > max_val {
826 max_val = val;
827 max_row = i;
828 }
829 }
830
831 if max_row != k {
833 perm.swap(k, max_row);
834 for j in 0..n {
835 a.swap(k + j * n, max_row + j * n);
836 }
837 }
838
839 let pivot = a[k + k * n];
840 if pivot.abs() < 1e-300 {
841 continue; }
843
844 for i in (k + 1)..n {
846 a[i + k * n] /= pivot;
847 }
848
849 for j in (k + 1)..n {
851 let u_kj = a[k + j * n];
852 for i in (k + 1)..n {
853 a[i + j * n] -= a[i + k * n] * u_kj;
854 }
855 }
856 }
857
858 let mut l = vec![0.0f64; n * n];
860 let mut u = vec![0.0f64; n * n];
861 for j in 0..n {
862 for i in 0..n {
863 if i > j {
864 l[i + j * n] = a[i + j * n];
865 } else if i == j {
866 l[i + j * n] = 1.0;
867 u[i + j * n] = a[i + j * n];
868 } else {
869 u[i + j * n] = a[i + j * n];
870 }
871 }
872 }
873
874 self.l_factor = l;
875 self.u_factor = u;
876 self.perm = perm;
877 self.factored = true;
878 Ok(())
879 }
880
881 pub fn solve(&self, rhs: &[f64]) -> Result<Vec<f64>, SolverError> {
885 if !self.factored {
886 return Err(SolverError::InternalError(
887 "numeric factorization not performed".to_string(),
888 ));
889 }
890 let n = self.n;
891 if rhs.len() != n {
892 return Err(SolverError::DimensionMismatch(format!(
893 "rhs length {} != n = {}",
894 rhs.len(),
895 n
896 )));
897 }
898
899 let mut pb = vec![0.0f64; n];
916 for k in 0..n {
917 pb[k] = rhs[self.perm[k]];
918 }
919
920 let mut x = pb;
922 for k in 0..n {
923 for i in (k + 1)..n {
924 x[i] -= self.l_factor[i + k * n] * x[k];
925 }
926 }
927
928 for k in (0..n).rev() {
930 let u_kk = self.u_factor[k + k * n];
931 if u_kk.abs() < 1e-300 {
932 return Err(SolverError::SingularMatrix);
933 }
934 x[k] /= u_kk;
935 for i in 0..k {
936 x[i] -= self.u_factor[i + k * n] * x[k];
937 }
938 }
939
940 Ok(x)
941 }
942}
943
944pub fn sparse_cholesky_solve(
953 row_offsets: &[usize],
954 col_indices: &[usize],
955 values: &[f64],
956 n: usize,
957 rhs: &[f64],
958) -> Result<Vec<f64>, SolverError> {
959 let mut solver = SupernodalCholeskySolver::symbolic(row_offsets, col_indices, n)?;
960 solver.numeric(row_offsets, col_indices, values)?;
961 solver.solve(rhs)
962}
963
964pub fn sparse_lu_solve(
969 row_offsets: &[usize],
970 col_indices: &[usize],
971 values: &[f64],
972 n: usize,
973 rhs: &[f64],
974) -> Result<Vec<f64>, SolverError> {
975 let mut solver = MultifrontalLUSolver::symbolic(row_offsets, col_indices, n)?;
976 solver.numeric(row_offsets, col_indices, values)?;
977 solver.solve(rhs)
978}
979
#[cfg(test)]
mod tests {
    use super::*;

    /// Lower triangle (CSR) of the `n x n` tridiagonal matrix with 4 on the
    /// diagonal and 1 on the sub-diagonal.
    fn tridiag_lower(n: usize) -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        let mut row_offsets = vec![0];
        let mut col_indices = Vec::new();
        let mut values = Vec::new();
        for row in 0..n {
            if row > 0 {
                col_indices.push(row - 1);
                values.push(1.0);
            }
            col_indices.push(row);
            values.push(4.0);
            row_offsets.push(col_indices.len());
        }
        (row_offsets, col_indices, values, n)
    }

    /// 3x3 SPD tridiagonal fixture (lower triangle).
    fn spd_3x3_lower() -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        tridiag_lower(3)
    }

    /// 5x5 SPD tridiagonal fixture (lower triangle).
    fn spd_5x5_tridiag_lower() -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        tridiag_lower(5)
    }

    /// `n x n` identity matrix in CSR form.
    fn identity_lower(n: usize) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
        ((0..=n).collect(), (0..n).collect(), vec![1.0; n])
    }

    /// Full (non-symmetric) 3x3 CSR fixture used by the LU tests.
    fn lu_3x3() -> (Vec<usize>, Vec<usize>, Vec<f64>, usize) {
        (
            vec![0, 2, 5, 7],
            vec![0, 1, 0, 1, 2, 1, 2],
            vec![2.0, 1.0, 1.0, 3.0, 1.0, 1.0, 2.0],
            3,
        )
    }

    /// Euclidean norm of `A x - b` where only the lower triangle of the
    /// symmetric `A` is stored.
    fn residual_norm_symmetric(
        row_offsets: &[usize],
        col_indices: &[usize],
        values: &[f64],
        n: usize,
        x: &[f64],
        b: &[f64],
    ) -> f64 {
        let mut ax = vec![0.0; n];
        for row in 0..n {
            for pos in row_offsets[row]..row_offsets[row + 1] {
                let (col, a) = (col_indices[pos], values[pos]);
                ax[row] += a * x[col];
                if col != row {
                    // Mirror the stored entry into the upper triangle.
                    ax[col] += a * x[row];
                }
            }
        }
        ax.iter()
            .zip(b.iter())
            .map(|(lhs, rhs)| {
                let diff = lhs - rhs;
                diff * diff
            })
            .sum::<f64>()
            .sqrt()
    }

    /// Euclidean norm of `A x - b` for a fully stored CSR matrix.
    fn residual_norm_general(
        row_offsets: &[usize],
        col_indices: &[usize],
        values: &[f64],
        n: usize,
        x: &[f64],
        b: &[f64],
    ) -> f64 {
        let mut ax = vec![0.0; n];
        for row in 0..n {
            for pos in row_offsets[row]..row_offsets[row + 1] {
                ax[row] += values[pos] * x[col_indices[pos]];
            }
        }
        ax.iter()
            .zip(b.iter())
            .map(|(lhs, rhs)| (lhs - rhs).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    #[test]
    fn test_elimination_tree_simple() {
        let (offsets, cols, _, dim) = spd_3x3_lower();
        let tree = EliminationTree::from_csr(&offsets, &cols, dim);

        assert_eq!(tree.size(), 3);
        let parents: Vec<Option<usize>> = (0..3).map(|v| tree.parent_of(v)).collect();
        assert_eq!(parents, vec![Some(1), Some(2), None]);
    }

    #[test]
    fn test_postorder_traversal() {
        let (offsets, cols, _, dim) = spd_3x3_lower();
        let tree = EliminationTree::from_csr(&offsets, &cols, dim);

        let order = tree.postorder_traversal();
        assert_eq!(order.len(), 3);
        assert_eq!(order, &[0, 1, 2]);
    }

    #[test]
    fn test_subtree_size() {
        let (offsets, cols, _, dim) = spd_3x3_lower();
        let tree = EliminationTree::from_csr(&offsets, &cols, dim);

        for (node, expected) in [(2, 3), (1, 2), (0, 1)] {
            assert_eq!(tree.subtree_size(node), expected);
        }
    }

    #[test]
    fn test_supernode_detection_diagonal() {
        let dim = 4;
        let (offsets, cols, _) = identity_lower(dim);
        let tree = EliminationTree::from_csr(&offsets, &cols, dim);
        let partition = SupernodalStructure::from_etree(&tree, &offsets, &cols);

        assert_eq!(partition.supernodes.len(), dim);
        for sn in &partition.supernodes {
            assert_eq!(sn.width(), 1);
        }
    }

    #[test]
    fn test_supernodal_cholesky_3x3() {
        let (offsets, cols, vals, dim) = spd_3x3_lower();

        let mut cholesky = SupernodalCholeskySolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        cholesky
            .numeric(&offsets, &cols, &vals)
            .expect("numeric should succeed");

        assert!(cholesky.factored);
    }

    #[test]
    fn test_supernodal_cholesky_5x5_tridiag() {
        let (offsets, cols, vals, dim) = spd_5x5_tridiag_lower();

        let mut cholesky = SupernodalCholeskySolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        cholesky
            .numeric(&offsets, &cols, &vals)
            .expect("numeric should succeed");

        assert!(cholesky.factored);
    }

    #[test]
    fn test_cholesky_solve_accuracy() {
        let (offsets, cols, vals, dim) = spd_3x3_lower();
        let b = vec![5.0, 6.0, 5.0];

        let mut cholesky = SupernodalCholeskySolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        cholesky
            .numeric(&offsets, &cols, &vals)
            .expect("numeric should succeed");
        let solution = cholesky.solve(&b).expect("solve should succeed");

        let residual = residual_norm_symmetric(&offsets, &cols, &vals, dim, &solution, &b);
        assert!(
            residual < 1e-10,
            "residual {residual:.3e} exceeds tolerance 1e-10"
        );
    }

    #[test]
    fn test_lu_factorization_3x3() {
        let (offsets, cols, vals, dim) = lu_3x3();

        let mut lu = MultifrontalLUSolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        lu.numeric(&offsets, &cols, &vals)
            .expect("numeric should succeed");

        assert!(lu.factored);
    }

    #[test]
    fn test_lu_solve_accuracy() {
        let (offsets, cols, vals, dim) = lu_3x3();
        let b = vec![3.0, 5.0, 3.0];

        let mut lu = MultifrontalLUSolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        lu.numeric(&offsets, &cols, &vals)
            .expect("numeric should succeed");
        let solution = lu.solve(&b).expect("solve should succeed");

        let residual = residual_norm_general(&offsets, &cols, &vals, dim, &solution, &b);
        assert!(
            residual < 1e-10,
            "LU solve residual {residual:.3e} exceeds tolerance"
        );
    }

    #[test]
    fn test_symbolic_factorization_reuse() {
        let (offsets, cols, _, dim) = spd_3x3_lower();

        let analysis = SymbolicFactorization::compute(&offsets, &cols, dim)
            .expect("symbolic should succeed");

        assert!(analysis.nnz_l > 0);
        assert_eq!(analysis.nnz_l, analysis.nnz_u);
        assert_eq!(analysis.etree.size(), dim);
        assert!(!analysis.structure.supernodes.is_empty());
    }

    #[test]
    fn test_column_counts() {
        let (offsets, cols, _, dim) = spd_3x3_lower();
        let tree = EliminationTree::from_csr(&offsets, &cols, dim);
        let counts = column_counts(&offsets, &cols, &tree);

        assert_eq!(counts.len(), 3);
        // Columns 0 and 1 hold the diagonal plus one sub-diagonal entry;
        // the last column holds only its diagonal.
        assert_eq!(counts[0], 2);
        assert_eq!(counts[1], 2);
        assert_eq!(counts[2], 1);
    }

    #[test]
    fn test_sparse_cholesky_solve_convenience() {
        let (offsets, cols, vals, dim) = spd_3x3_lower();
        let b = vec![5.0, 6.0, 5.0];

        let solution = sparse_cholesky_solve(&offsets, &cols, &vals, dim, &b)
            .expect("convenience solve should succeed");

        let residual = residual_norm_symmetric(&offsets, &cols, &vals, dim, &solution, &b);
        assert!(
            residual < 1e-10,
            "convenience solve residual {residual:.3e} too large"
        );
    }

    #[test]
    fn test_sparse_lu_solve_convenience() {
        let (offsets, cols, vals, dim) = lu_3x3();
        let b = vec![3.0, 5.0, 3.0];

        let solution = sparse_lu_solve(&offsets, &cols, &vals, dim, &b)
            .expect("LU convenience solve should succeed");

        let residual = residual_norm_general(&offsets, &cols, &vals, dim, &solution, &b);
        assert!(
            residual < 1e-10,
            "LU convenience solve residual {residual:.3e} too large"
        );
    }

    #[test]
    fn test_non_spd_cholesky_failure() {
        // Same pattern as the SPD fixture, but with a negative leading diagonal.
        let (offsets, cols, mut vals, dim) = spd_3x3_lower();
        vals[0] = -4.0;

        let mut cholesky = SupernodalCholeskySolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        let outcome = cholesky.numeric(&offsets, &cols, &vals);

        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(SolverError::NotPositiveDefinite)));
    }

    #[test]
    fn test_singular_matrix_lu() {
        // Rows 0 and 1 are identical, so the matrix is singular.
        let offsets = vec![0, 2, 4, 5];
        let cols = vec![0, 1, 0, 1, 2];
        let vals = vec![1.0, 2.0, 1.0, 2.0, 1.0];
        let dim = 3;
        let b = vec![1.0, 1.0, 1.0];

        let mut lu = MultifrontalLUSolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        lu.numeric(&offsets, &cols, &vals)
            .expect("numeric may succeed with zero pivot stored");

        assert!(lu.solve(&b).is_err());
    }

    #[test]
    fn test_identity_factorization() {
        let dim = 4;
        let (offsets, cols, vals) = identity_lower(dim);

        let mut cholesky = SupernodalCholeskySolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        cholesky
            .numeric(&offsets, &cols, &vals)
            .expect("numeric should succeed on identity");

        let rhs = vec![1.0, 2.0, 3.0, 4.0];
        let x = cholesky.solve(&rhs).expect("solve should succeed");

        for i in 0..dim {
            assert!(
                (x[i] - rhs[i]).abs() < 1e-14,
                "identity solve failed at index {i}: got {} expected {}",
                x[i],
                rhs[i]
            );
        }
    }

    #[test]
    fn test_nnz_factor_count() {
        let (offsets, cols, vals, dim) = spd_3x3_lower();

        let mut cholesky = SupernodalCholeskySolver::symbolic(&offsets, &cols, dim)
            .expect("symbolic should succeed");
        cholesky
            .numeric(&offsets, &cols, &vals)
            .expect("numeric should succeed");

        let nnz = cholesky.nnz_factor();
        assert!(nnz >= 5, "nnz_factor = {nnz}, expected >= 5");
    }
}