1use crate::LinalgError;
2use crate::faer_ndarray::{FaerArrayView, FaerColView};
3use faer::Side;
4use faer::linalg::solvers::Solve;
5use faer::sparse::linalg::solvers::Llt as SparseLlt;
6use faer::sparse::{SparseColMat, SymbolicSparseColMat, Triplet};
7use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Data, Ix1, Ix2};
8use rayon::prelude::*;
9use std::collections::BTreeMap;
10use std::sync::{Arc, Mutex};
11
/// Absolute magnitude at or below which entries are dropped as structural zeros when the
/// sparse factorization entry points (`factorize_sparse_spd`, `factorize_simplicial`)
/// canonicalize their input.
const ZERO_TOL: f64 = 1e-12;
/// Column-range width at or below which the dense-to-sparse fill helpers stop splitting
/// the work with `rayon::join` and fill their columns serially.
const PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD: usize = 64;
14
/// Returns early from the enclosing function with `Err(LinalgError::InvalidInput(..))`,
/// building the message with `format!`-style arguments.
///
/// The expansion is a bare `return` expression. The enclosing function must therefore
/// return `Result<_, LinalgError>`, and the macro may be used in statement or expression
/// position.
macro_rules! bail_invalid_linalg {
    ($($arg:tt)*) => {
        return Err(LinalgError::InvalidInput(format!($($arg)*)))
    };
}
20
/// Sparse Cholesky factorization of a symmetric positive-definite matrix.
///
/// Two factorizations of the same canonicalized upper-triangular input are held:
/// - faer's sparse `Llt`, which the `solve_*` functions in this module use;
/// - a simplicial factor, shared via `Arc` so clones do not copy it, which supplies the
///   log-determinant and the dense reconstruction.
#[derive(Clone)]
pub struct SparseExactFactor {
    /// faer sparse LLᵀ solver used for right-hand-side solves.
    factor: SparseLlt<usize, f64>,
    /// Simplicial `L` factor together with its fill-reducing permutation.
    simplicial: Arc<SimplicialFactor>,
    /// Dimension of the (square) factored matrix.
    n: usize,
    /// Log-determinant of the factored matrix, copied from the simplicial factor.
    logdet: f64,
}
28
29impl crate::matrix::FactorizedSystem for SparseExactFactor {
30 fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
31 solve_sparse_spd(self, rhs).map_err(|e| e.to_string())
32 }
33
34 fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
35 solve_sparse_spdmulti(self, rhs).map_err(|e| e.to_string())
36 }
37
38 fn logdet(&self) -> f64 {
39 self.logdet
40 }
41}
42
43pub fn dense_to_sparse(
44 matrix: &Array2<f64>,
45 tol: f64,
46) -> Result<SparseColMat<usize, f64>, LinalgError> {
47 let nrows = matrix.nrows();
48 let ncols = matrix.ncols();
49 let counts: Vec<usize> = (0..ncols)
56 .into_par_iter()
57 .map(|col| {
58 let mut count = 0usize;
59 for row in 0..nrows {
60 if matrix[[row, col]].abs() > tol {
61 count += 1;
62 }
63 }
64 count
65 })
66 .collect();
67 let col_ptr = prefix_sum_counts(&counts);
68 let nnz = col_ptr[ncols];
69 let mut row_idx = vec![0usize; nnz];
70 let mut values = vec![0.0; nnz];
71 fill_dense_to_sparse_columns(matrix, tol, 0, ncols, &col_ptr, &mut row_idx, &mut values);
72 let symbolic = SymbolicSparseColMat::<usize>::new_checked(nrows, ncols, col_ptr, None, row_idx);
73 Ok(SparseColMat::<usize, f64>::new(symbolic, values))
74}
75
76pub fn dense_to_sparse_symmetric_upper(
82 matrix: &Array2<f64>,
83 tol: f64,
84) -> Result<SparseColMat<usize, f64>, LinalgError> {
85 let nrows = matrix.nrows();
86 let ncols = matrix.ncols();
87 let row_limit = nrows.min(ncols);
93 let counts: Vec<usize> = (0..ncols)
94 .into_par_iter()
95 .map(|col| {
96 let mut count = 0usize;
97 let row_end = (col + 1).min(row_limit);
98 for row in 0..row_end {
99 if matrix[[row, col]].abs() > tol {
100 count += 1;
101 }
102 }
103 count
104 })
105 .collect();
106 let col_ptr = prefix_sum_counts(&counts);
107 let nnz = col_ptr[ncols];
108 let mut row_idx = vec![0usize; nnz];
109 let mut values = vec![0.0; nnz];
110 fill_dense_symmetric_upper_columns(
111 matrix,
112 tol,
113 row_limit,
114 0,
115 ncols,
116 &col_ptr,
117 &mut row_idx,
118 &mut values,
119 );
120 let symbolic = SymbolicSparseColMat::<usize>::new_checked(nrows, ncols, col_ptr, None, row_idx);
121 Ok(SparseColMat::<usize, f64>::new(symbolic, values))
122}
123
/// Builds CSC-style column pointers from per-column entry counts.
///
/// The result has `counts.len() + 1` elements: a leading `0` followed by the running
/// totals. The entries of column `c` therefore occupy `ptr[c]..ptr[c + 1]`.
fn prefix_sum_counts(counts: &[usize]) -> Vec<usize> {
    std::iter::once(0)
        .chain(counts.iter().scan(0usize, |total, &count| {
            *total += count;
            Some(*total)
        }))
        .collect()
}
134
135fn fill_dense_to_sparse_columns(
136 matrix: &Array2<f64>,
137 tol: f64,
138 col_start: usize,
139 col_end: usize,
140 col_ptr: &[usize],
141 row_idx: &mut [usize],
142 values: &mut [f64],
143) {
144 if col_end - col_start <= PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD {
145 let base = col_ptr[col_start];
146 for col in col_start..col_end {
147 let mut write = col_ptr[col] - base;
148 for row in 0..matrix.nrows() {
149 let value = matrix[[row, col]];
150 if value.abs() > tol {
151 row_idx[write] = row;
152 values[write] = value;
153 write += 1;
154 }
155 }
156 }
157 return;
158 }
159
160 let mid = col_start + (col_end - col_start) / 2;
161 let split = col_ptr[mid] - col_ptr[col_start];
162 let (left_rows, right_rows) = row_idx.split_at_mut(split);
163 let (left_values, right_values) = values.split_at_mut(split);
164 rayon::join(
165 || {
166 fill_dense_to_sparse_columns(
167 matrix,
168 tol,
169 col_start,
170 mid,
171 col_ptr,
172 left_rows,
173 left_values,
174 );
175 },
176 || {
177 fill_dense_to_sparse_columns(
178 matrix,
179 tol,
180 mid,
181 col_end,
182 col_ptr,
183 right_rows,
184 right_values,
185 );
186 },
187 );
188}
189
190fn fill_dense_symmetric_upper_columns(
191 matrix: &Array2<f64>,
192 tol: f64,
193 row_limit: usize,
194 col_start: usize,
195 col_end: usize,
196 col_ptr: &[usize],
197 row_idx: &mut [usize],
198 values: &mut [f64],
199) {
200 if col_end - col_start <= PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD {
201 let base = col_ptr[col_start];
202 for col in col_start..col_end {
203 let row_end = (col + 1).min(row_limit);
204 let mut write = col_ptr[col] - base;
205 for row in 0..row_end {
206 let value = matrix[[row, col]];
207 if value.abs() > tol {
208 row_idx[write] = row;
209 values[write] = value;
210 write += 1;
211 }
212 }
213 }
214 return;
215 }
216
217 let mid = col_start + (col_end - col_start) / 2;
218 let split = col_ptr[mid] - col_ptr[col_start];
219 let (left_rows, right_rows) = row_idx.split_at_mut(split);
220 let (left_values, right_values) = values.split_at_mut(split);
221 rayon::join(
222 || {
223 fill_dense_symmetric_upper_columns(
224 matrix,
225 tol,
226 row_limit,
227 col_start,
228 mid,
229 col_ptr,
230 left_rows,
231 left_values,
232 );
233 },
234 || {
235 fill_dense_symmetric_upper_columns(
236 matrix,
237 tol,
238 row_limit,
239 mid,
240 col_end,
241 col_ptr,
242 right_rows,
243 right_values,
244 );
245 },
246 );
247}
248
249pub fn sparse_symmetric_upper_matvec_public<S: Data<Elem = f64>>(
250 matrix: &SparseColMat<usize, f64>,
251 vector: &ArrayBase<S, Ix1>,
252) -> Array1<f64> {
253 let mut out = Array1::<f64>::zeros(matrix.nrows());
254 let (symbolic, values) = matrix.parts();
255 let col_ptr = symbolic.col_ptr();
256 let row_idx = symbolic.row_idx();
257 for col in 0..matrix.ncols() {
258 let x_col = vector[col];
259 for idx in col_ptr[col]..col_ptr[col + 1] {
260 let row = row_idx[idx];
261 let value = values[idx];
262 out[row] += value * x_col;
263 if row != col {
264 out[col] += value * vector[row];
265 }
266 }
267 }
268 out
269}
270
271pub fn factorize_sparse_spd(
272 h: &SparseColMat<usize, f64>,
273) -> Result<SparseExactFactor, LinalgError> {
274 let t_start = std::time::Instant::now();
284 let n_input = h.ncols();
285 let h_upper = canonicalize_sparse_symmetric_upper(h, ZERO_TOL)?;
286 let factor = h_upper.as_ref().sp_cholesky(Side::Upper).map_err(|_| {
287 LinalgError::ModelIsIllConditioned {
288 condition_number: f64::INFINITY,
289 }
290 })?;
291 let simplicial = factorize_simplicial_canonical_upper(&h_upper)?;
295 let logdet = simplicial.logdet;
296 let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
297 if elapsed_ms > 100.0 {
298 log::info!(
299 "[sparse-chol] factorize_sparse_spd | n={} | {:.1}ms",
300 n_input,
301 elapsed_ms
302 );
303 }
304 Ok(SparseExactFactor {
305 factor,
306 simplicial: Arc::new(simplicial),
307 n: h_upper.ncols(),
308 logdet,
309 })
310}
311
312fn canonicalize_sparse_symmetric_upper(
313 matrix: &SparseColMat<usize, f64>,
314 tol: f64,
315) -> Result<SparseColMat<usize, f64>, LinalgError> {
316 if matrix.nrows() != matrix.ncols() {
317 bail_invalid_linalg!(
318 "sparse SPD factorization requires square matrix, got {}x{}",
319 matrix.nrows(),
320 matrix.ncols()
321 );
322 }
323
324 #[derive(Default, Clone, Copy)]
325 struct PairAccum {
326 upper_sum: f64,
327 upper_count: usize,
328 lower_sum: f64,
329 lower_count: usize,
330 }
331
332 let mut accum: BTreeMap<(usize, usize), PairAccum> = BTreeMap::new();
333 let (symbolic, values) = matrix.parts();
334 let col_ptr = symbolic.col_ptr();
335 let row_idx = symbolic.row_idx();
336
337 for col in 0..matrix.ncols() {
338 let start = col_ptr[col];
339 let end = col_ptr[col + 1];
340 for idx in start..end {
341 let row = row_idx[idx];
342 let value = values[idx];
343 let (r, c, is_upper) = if row <= col {
344 (row, col, true)
345 } else {
346 (col, row, false)
347 };
348 let slot = accum.entry((r, c)).or_default();
349 if is_upper {
350 slot.upper_sum += value;
351 slot.upper_count += 1;
352 } else {
353 slot.lower_sum += value;
354 slot.lower_count += 1;
355 }
356 }
357 }
358
359 let mut triplets = Vec::<Triplet<usize, usize, f64>>::new();
360 for ((row, col), slot) in accum {
361 let value = if row == col {
362 let count = slot.upper_count + slot.lower_count;
363 if count == 0 {
364 0.0
365 } else {
366 (slot.upper_sum + slot.lower_sum) / (count as f64)
367 }
368 } else {
369 let upper_avg = if slot.upper_count > 0 {
370 Some(slot.upper_sum / (slot.upper_count as f64))
371 } else {
372 None
373 };
374 let lower_avg = if slot.lower_count > 0 {
375 Some(slot.lower_sum / (slot.lower_count as f64))
376 } else {
377 None
378 };
379 match (upper_avg, lower_avg) {
380 (Some(u), Some(l)) => 0.5 * (u + l),
381 (Some(u), None) => u,
382 (None, Some(l)) => l,
383 (None, None) => 0.0,
384 }
385 };
386
387 if value.abs() > tol {
388 triplets.push(Triplet::new(row, col, value));
389 }
390 }
391
392 SparseColMat::try_new_from_triplets(matrix.nrows(), matrix.ncols(), &triplets).map_err(|_| {
393 LinalgError::InvalidInput(
394 "failed to canonicalize sparse matrix to symmetric-upper CSC".to_string(),
395 )
396 })
397}
398
399fn solve_view<R, I, F>(
400 factor: &SparseExactFactor,
401 rhs: ArrayView2<'_, f64>,
402 indices: I,
403 mut result: R,
404 non_finite_message: &'static str,
405 mut consume: F,
406) -> Result<R, LinalgError>
407where
408 I: IntoIterator<Item = (usize, usize)>,
409 F: FnMut(&mut R, usize, usize, f64),
410{
411 let rhsview = FaerArrayView::new(&rhs);
412 let solved = factor.factor.solve(rhsview.as_ref());
413 for (row, col) in indices {
414 let value = solved[(row, col)];
415 if !value.is_finite() {
416 bail_invalid_linalg!("{}", non_finite_message.to_string());
417 }
418 consume(&mut result, row, col, value);
419 }
420 Ok(result)
421}
422
423pub fn solve_sparse_spd<S>(
424 factor: &SparseExactFactor,
425 rhs: &ArrayBase<S, Ix1>,
426) -> Result<Array1<f64>, LinalgError>
427where
428 S: Data<Elem = f64>,
429{
430 if rhs.len() != factor.n {
431 bail_invalid_linalg!(
432 "sparse SPD solve dimension mismatch: rhs has {}, factor has {}",
433 rhs.len(),
434 factor.n
435 );
436 }
437 let mut result = Array1::<f64>::zeros(rhs.len());
438 solve_sparse_spd_into(factor, rhs, &mut result)?;
439 Ok(result)
440}
441
442pub fn solve_sparse_spd_into<S>(
447 factor: &SparseExactFactor,
448 rhs: &ArrayBase<S, Ix1>,
449 out: &mut Array1<f64>,
450) -> Result<(), LinalgError>
451where
452 S: Data<Elem = f64>,
453{
454 if rhs.len() != factor.n {
455 bail_invalid_linalg!(
456 "sparse SPD solve dimension mismatch: rhs has {}, factor has {}",
457 rhs.len(),
458 factor.n
459 );
460 }
461 if out.len() != factor.n {
462 bail_invalid_linalg!(
463 "sparse SPD solve output dimension mismatch: out has {}, factor has {}",
464 out.len(),
465 factor.n
466 );
467 }
468 let rhsview = FaerColView::new(rhs);
469 let solved = factor.factor.solve(rhsview.as_ref());
470 for i in 0..factor.n {
471 let value = solved[(i, 0)];
472 if !value.is_finite() {
473 bail_invalid_linalg!("sparse SPD solve produced non-finite values");
474 }
475 out[i] = value;
476 }
477 Ok(())
478}
479
480pub fn solve_sparse_spdmulti<S>(
481 factor: &SparseExactFactor,
482 rhs: &ArrayBase<S, Ix2>,
483) -> Result<Array2<f64>, LinalgError>
484where
485 S: Data<Elem = f64>,
486{
487 if rhs.nrows() != factor.n {
488 bail_invalid_linalg!(
489 "sparse SPD multi-solve row mismatch: rhs has {}, factor has {}",
490 rhs.nrows(),
491 factor.n
492 );
493 }
494 let indices = (0..rhs.nrows()).flat_map(|i| (0..rhs.ncols()).map(move |j| (i, j)));
495 solve_view(
496 factor,
497 rhs.view(),
498 indices,
499 Array2::<f64>::zeros(rhs.raw_dim()),
500 "sparse SPD multi-solve produced non-finite values",
501 |result, row, col, value| {
502 result[[row, col]] = value;
503 },
504 )
505}
506
507pub fn solve_sparse_spdmulti_rows<S>(
508 factor: &SparseExactFactor,
509 rhs: &ArrayBase<S, Ix2>,
510 row_start: usize,
511 row_end: usize,
512) -> Result<Array2<f64>, LinalgError>
513where
514 S: Data<Elem = f64>,
515{
516 if rhs.nrows() != factor.n {
517 bail_invalid_linalg!(
518 "sparse SPD multi-solve row mismatch: rhs has {}, factor has {}",
519 rhs.nrows(),
520 factor.n
521 );
522 }
523 if row_start > row_end || row_end > factor.n {
524 bail_invalid_linalg!(
525 "sparse SPD selected rows out of bounds: row_start={}, row_end={}, factor={}",
526 row_start,
527 row_end,
528 factor.n
529 );
530 }
531 let indices = (row_start..row_end).flat_map(|i| (0..rhs.ncols()).map(move |j| (i, j)));
532 solve_view(
533 factor,
534 rhs.view(),
535 indices,
536 Array2::<f64>::zeros((row_end - row_start, rhs.ncols())),
537 "sparse SPD selected-row solve produced non-finite values",
538 |result, row, col, value| {
539 result[[row - row_start, col]] = value;
540 },
541 )
542}
543
544pub fn solve_sparse_spdmulti_diagonal_sum<S>(
545 factor: &SparseExactFactor,
546 rhs: &ArrayBase<S, Ix2>,
547 row_start: usize,
548) -> Result<f64, LinalgError>
549where
550 S: Data<Elem = f64>,
551{
552 if row_start.saturating_add(rhs.ncols()) > rhs.nrows() {
553 bail_invalid_linalg!(
554 "sparse SPD selected diagonal out of bounds: row_start={}, rows={}, cols={}",
555 row_start,
556 rhs.nrows(),
557 rhs.ncols()
558 );
559 }
560 let indices = (0..rhs.ncols()).map(|col| (row_start + col, col));
561 solve_view(
562 factor,
563 rhs.view(),
564 indices,
565 0.0,
566 "sparse SPD selected diagonal solve produced non-finite values",
567 |sum, _, _, value| {
568 *sum += value;
569 },
570 )
571}
572
573pub fn logdet_from_factor(factor: &SparseExactFactor) -> Result<f64, LinalgError> {
574 Ok(factor.logdet)
575}
576
577pub fn assemble_sparse_factor_h_dense(
578 factor: &SparseExactFactor,
579) -> Result<Array2<f64>, LinalgError> {
580 factor.simplicial.assemble_h_dense_original_order()
581}
582
583use faer::dyn_stack::{MemBuffer, MemStack, StackReq};
588use faer::linalg::cholesky::llt::factor::LltRegularization;
589use faer::sparse::linalg::amd;
590use faer::sparse::linalg::cholesky::simplicial;
591
/// Simplicial Cholesky factor `L`, in CSC storage, of a symmetric matrix after an AMD
/// fill-reducing permutation. The permutation needed to map back to the original
/// ordering is stored alongside it.
pub struct SimplicialFactor {
    /// Column pointers of `L`; length `n + 1`.
    l_col_ptr: Vec<usize>,
    /// Row indices of `L`. The code in this module reads the first entry of each column as
    /// that column's diagonal.
    l_row_idx: Vec<usize>,
    /// Numeric values of `L`, aligned with `l_row_idx`.
    l_values: Vec<f64>,
    /// Maps an original index to its position in the permuted ordering.
    perm_inv: Vec<usize>,
    /// Matrix dimension.
    n: usize,
    /// Log-determinant of the factored matrix, `2 · Σ ln L_jj`.
    pub logdet: f64,
}
611
612pub fn factorize_simplicial(h: &SparseColMat<usize, f64>) -> Result<SimplicialFactor, LinalgError> {
618 let h_upper = canonicalize_sparse_symmetric_upper(h, ZERO_TOL)?;
619 factorize_simplicial_canonical_upper(&h_upper)
620}
621
622fn factorize_simplicial_canonical_upper(
623 h_upper: &SparseColMat<usize, f64>,
624) -> Result<SimplicialFactor, LinalgError> {
625 let n = h_upper.ncols();
626 if n == 0 {
627 return Ok(SimplicialFactor {
628 l_col_ptr: vec![0],
629 l_row_idx: Vec::new(),
630 l_values: Vec::new(),
631 perm_inv: Vec::new(),
632 n: 0,
633 logdet: 0.0,
634 });
635 }
636
637 let a_nnz = h_upper.compute_nnz();
638
639 let mut perm_fwd = vec![0usize; n];
641 let mut perm_inv = vec![0usize; n];
642 {
643 let mut mem = MemBuffer::new(amd::order_scratch::<usize>(n, a_nnz));
644 amd::order(
645 &mut perm_fwd,
646 &mut perm_inv,
647 h_upper.symbolic(),
648 amd::Control::default(),
649 MemStack::new(&mut mem),
650 )
651 .map_err(|_| LinalgError::ModelIsIllConditioned {
652 condition_number: f64::INFINITY,
653 })?;
654 }
655
656 let perm = unsafe { faer::perm::PermRef::new_unchecked(&perm_fwd, &perm_inv, n) };
662
663 let a_perm_upper = {
665 let mut col_ptrs = vec![0usize; n + 1];
666 let mut row_indices = vec![0usize; a_nnz];
667 let mut values = vec![0.0f64; a_nnz];
668 let mut mem = MemBuffer::new(faer::sparse::utils::permute_self_adjoint_scratch::<usize>(
669 n,
670 ));
671 faer::sparse::utils::permute_self_adjoint_to_unsorted(
672 &mut values,
673 &mut col_ptrs,
674 &mut row_indices,
675 h_upper.as_ref(),
676 perm,
677 Side::Upper,
678 Side::Upper,
679 MemStack::new(&mut mem),
680 );
681 SparseColMat::<usize, f64>::new(
682 unsafe { SymbolicSparseColMat::new_unchecked(n, n, col_ptrs, None, row_indices) },
691 values,
692 )
693 };
694
695 let symbolic = {
697 let mut mem = MemBuffer::new(StackReq::any_of(&[
698 simplicial::prefactorize_symbolic_cholesky_scratch::<usize>(n, a_nnz),
699 simplicial::factorize_simplicial_symbolic_cholesky_scratch::<usize>(n),
700 ]));
701 let stack = MemStack::new(&mut mem);
702 let mut etree = vec![0isize; n];
703 let mut col_counts = vec![0usize; n];
704 let etree_ref = simplicial::prefactorize_symbolic_cholesky(
705 &mut etree,
706 &mut col_counts,
707 a_perm_upper.symbolic(),
708 stack,
709 );
710 simplicial::factorize_simplicial_symbolic_cholesky(
711 a_perm_upper.symbolic(),
712 etree_ref,
713 &col_counts,
714 stack,
715 )
716 .map_err(|_| LinalgError::ModelIsIllConditioned {
717 condition_number: f64::INFINITY,
718 })?
719 };
720
721 let mut l_values = vec![0.0f64; symbolic.len_val()];
723 {
724 let mut mem = MemBuffer::new(simplicial::factorize_simplicial_numeric_llt_scratch::<
725 usize,
726 f64,
727 >(n));
728 simplicial::factorize_simplicial_numeric_llt::<usize, f64>(
729 &mut l_values,
730 a_perm_upper.as_ref(),
731 LltRegularization::default(),
732 &symbolic,
733 MemStack::new(&mut mem),
734 )
735 .map_err(|_| LinalgError::HessianNotPositiveDefinite {
736 min_eigenvalue: f64::NAN,
737 })?;
738 }
739
740 let l_col_ptr: Vec<usize> = symbolic.col_ptr().to_vec();
742 let l_row_idx: Vec<usize> = symbolic.row_idx().to_vec();
743
744 let mut logdet = 0.0f64;
746 for j in 0..n {
747 let diag = l_values[l_col_ptr[j]];
748 if diag <= 0.0 {
749 return Err(LinalgError::HessianNotPositiveDefinite {
750 min_eigenvalue: f64::NAN,
751 });
752 }
753 logdet += diag.ln();
754 }
755 logdet *= 2.0;
756
757 Ok(SimplicialFactor {
758 l_col_ptr,
759 l_row_idx,
760 l_values,
761 perm_inv,
762 n,
763 logdet,
764 })
765}
766
767impl SimplicialFactor {
768 fn assemble_h_dense_original_order(&self) -> Result<Array2<f64>, LinalgError> {
775 if self.perm_inv.len() != self.n {
776 bail_invalid_linalg!(
777 "simplicial factor permutation length {} does not match dimension {}",
778 self.perm_inv.len(),
779 self.n
780 );
781 }
782 let mut h_permuted = Array2::<f64>::zeros((self.n, self.n));
783 for col in 0..self.n {
784 let start = self.l_col_ptr[col];
785 let end = self.l_col_ptr[col + 1];
786 for left_idx in start..end {
787 let left_row = self.l_row_idx[left_idx];
788 let left_value = self.l_values[left_idx];
789 if !left_value.is_finite() {
790 bail_invalid_linalg!(
791 "simplicial factor has non-finite L entry at value index {left_idx}"
792 );
793 }
794 for right_idx in start..end {
795 let right_row = self.l_row_idx[right_idx];
796 let right_value = self.l_values[right_idx];
797 h_permuted[[left_row, right_row]] += left_value * right_value;
798 }
799 }
800 }
801
802 let mut h_original = Array2::<f64>::zeros((self.n, self.n));
803 for i in 0..self.n {
804 let pi = self.perm_inv[i];
805 if pi >= self.n {
806 bail_invalid_linalg!(
807 "simplicial factor permutation maps row {i} to out-of-bounds index {pi}"
808 );
809 }
810 for j in 0..self.n {
811 let pj = self.perm_inv[j];
812 if pj >= self.n {
813 bail_invalid_linalg!(
814 "simplicial factor permutation maps column {j} to out-of-bounds index {pj}"
815 );
816 }
817 let value = h_permuted[[pi, pj]];
818 if !value.is_finite() {
819 bail_invalid_linalg!(
820 "dense reconstruction from sparse Cholesky produced non-finite values"
821 );
822 }
823 h_original[[i, j]] = value;
824 }
825 }
826 Ok(h_original)
827 }
828}
829
/// Selected inverse `Z = A⁻¹`, computed from a simplicial Cholesky factor with the
/// Takahashi recurrences.
///
/// Entries of `Z` on the sparsity pattern of `L` are precomputed. Entries outside that
/// pattern are obtained on demand by solving for a full column of `Z`; the solved column
/// is memoized.
pub struct TakahashiInverse {
    /// Selected-inverse values, aligned with `row_idx` (same pattern as `L`).
    z_values: Vec<f64>,
    /// Column pointers of the shared `L` / `Z` pattern, in permuted ordering.
    col_ptr: Vec<usize>,
    /// Row indices of the shared pattern; each column is searched with binary search.
    row_idx: Vec<usize>,
    /// Numeric values of `L`, kept for the on-demand column solves.
    l_values: Vec<f64>,
    /// Row-wise view of `L`: for each row, its `(col, value)` entries.
    rows_lower: Arc<Vec<Vec<(usize, f64)>>>,
    /// Fully solved columns of `Z`, keyed by permuted column index. Used for lookups
    /// outside the stored pattern.
    exact_columns: Mutex<BTreeMap<usize, Arc<Vec<f64>>>>,
    /// Maps an original index to its position in the permuted ordering.
    perm_inv: Vec<usize>,
    /// Matrix dimension.
    n: usize,
}
854
impl TakahashiInverse {
    /// Returns the position of entry `(row, col)` in the row-index / value arrays, or
    /// `None` if it is not stored.
    ///
    /// Uses `binary_search`, so each column's row indices must be sorted ascending.
    fn find_entry(col_ptr: &[usize], row_idx: &[usize], row: usize, col: usize) -> Option<usize> {
        let start = col_ptr[col];
        let end = col_ptr[col + 1];
        let slice = &row_idx[start..end];
        slice.binary_search(&row).ok().map(|pos| start + pos)
    }

    /// Solves `L Lᵀ x = e_{rhs_col}` in permuted ordering and returns `x`, i.e. column
    /// `rhs_col` of the permuted inverse.
    ///
    /// The forward substitution `L y = e` walks `rows_lower`. The backward substitution
    /// `Lᵀ x = y` walks the columns of `L`, reading each column's first entry as its
    /// diagonal.
    ///
    /// # Panics
    /// Panics if some row of `rows_lower` has no diagonal entry.
    fn solve_permuted_column_from_cholesky(
        n: usize,
        col_ptr: &[usize],
        row_idx: &[usize],
        l_values: &[f64],
        rows_lower: &[Vec<(usize, f64)>],
        rhs_col: usize,
    ) -> Vec<f64> {
        let mut rhs = vec![0.0f64; n];
        rhs[rhs_col] = 1.0;
        let mut forward = vec![0.0f64; n];
        let mut solution = vec![0.0f64; n];

        // Forward substitution: L y = e_{rhs_col}.
        for row in 0..n {
            let mut sum = rhs[row];
            let mut diag = None;
            for &(col, value) in &rows_lower[row] {
                if col < row {
                    sum -= value * forward[col];
                } else if col == row {
                    diag = Some(value);
                }
            }
            let l_rr = diag.expect("simplicial factor row should contain its diagonal");
            forward[row] = sum / l_rr;
        }

        // Backward substitution: Lᵀ x = y, using column `row` of L as row `row` of Lᵀ.
        for row in (0..n).rev() {
            let col_start = col_ptr[row];
            let col_end = col_ptr[row + 1];
            let mut sum = forward[row];
            let l_rr = l_values[col_start];
            for idx in (col_start + 1)..col_end {
                let lower_row = row_idx[idx];
                sum -= l_values[idx] * solution[lower_row];
            }
            solution[row] = sum / l_rr;
        }

        solution
    }

    /// Returns column `col` of the permuted inverse. The column is solved on first use and
    /// memoized in `exact_columns`.
    ///
    /// The lock is released while solving, so two callers may solve the same column
    /// concurrently. The first result inserted is kept and returned to both.
    ///
    /// # Panics
    /// Panics if the cache mutex is poisoned.
    fn exact_permuted_column(&self, col: usize) -> Arc<Vec<f64>> {
        {
            let cache = self
                .exact_columns
                .lock()
                .expect("exact Takahashi column cache mutex poisoned");
            if let Some(solution) = cache.get(&col) {
                return solution.clone();
            }
        }

        let solution = Arc::new(Self::solve_permuted_column_from_cholesky(
            self.n,
            &self.col_ptr,
            &self.row_idx,
            &self.l_values,
            self.rows_lower.as_ref(),
            col,
        ));

        let mut cache = self
            .exact_columns
            .lock()
            .expect("exact Takahashi column cache mutex poisoned");
        cache.entry(col).or_insert_with(|| solution.clone()).clone()
    }

    /// Reads `Z[row, col]` from the stored lower pattern. By symmetry, `(row, col)` is
    /// mapped to `(max, min)` first.
    ///
    /// # Errors
    /// `InvalidInput` if the mirrored entry is not in the pattern.
    fn selected_value(
        z_values: &[f64],
        col_ptr: &[usize],
        row_idx: &[usize],
        row: usize,
        col: usize,
    ) -> Result<f64, LinalgError> {
        let (lower_row, lower_col) = if row >= col { (row, col) } else { (col, row) };
        Self::find_entry(col_ptr, row_idx, lower_row, lower_col)
            .map(|idx| z_values[idx])
            .ok_or_else(|| {
                LinalgError::InvalidInput(format!(
                    "simplicial selected-inverse pattern is missing entry ({lower_row},{lower_col})"
                ))
            })
    }

    /// Computes the selected inverse from a simplicial factor with the Takahashi
    /// recurrences, processing columns from last to first:
    ///
    /// * `Z[i, j] = -(Σ_{k>j} L[k, j] · Z[i, k]) / L[j, j]` for each off-diagonal `i` in
    ///   column `j`'s pattern;
    /// * `Z[j, j] = (1 / L[j, j] - Σ_{k>j} L[k, j] · Z[k, j]) / L[j, j]`.
    ///
    /// # Errors
    /// * `HessianNotPositiveDefinite` if a diagonal entry of `L` is not finite and
    ///   positive.
    /// * `InvalidInput` if a required entry is missing from the pattern or a computed entry
    ///   is non-finite.
    pub fn compute(factor: &SimplicialFactor) -> Result<Self, LinalgError> {
        let n = factor.n;
        let col_ptr = factor.l_col_ptr.clone();
        let row_idx = factor.l_row_idx.clone();
        let nnz = factor.l_values.len();
        let mut z_values = vec![0.0f64; nnz];

        // Row-wise copy of L for the forward substitution in on-demand column solves.
        let mut rows_lower: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
        for col in 0..n {
            for idx in col_ptr[col]..col_ptr[col + 1] {
                let row = row_idx[idx];
                rows_lower[row].push((col, factor.l_values[idx]));
            }
        }

        for j in (0..n).rev() {
            let diag_idx = col_ptr[j];
            let col_end = col_ptr[j + 1];
            let diag = factor.l_values[diag_idx];
            if !(diag.is_finite() && diag > 0.0) {
                return Err(LinalgError::HessianNotPositiveDefinite {
                    min_eigenvalue: f64::NAN,
                });
            }
            // Off-diagonal entries of column j. They only read columns k > j, which are
            // already final.
            for idx in (diag_idx + 1)..col_end {
                let i = row_idx[idx];
                let mut correction = 0.0;
                for off_idx in (diag_idx + 1)..col_end {
                    let k = row_idx[off_idx];
                    let l_kj = factor.l_values[off_idx];
                    let z_ik = Self::selected_value(&z_values, &col_ptr, &row_idx, i, k)?;
                    correction += l_kj * z_ik;
                }
                let value = -correction / diag;
                if !value.is_finite() {
                    bail_invalid_linalg!(
                        "Takahashi selected inverse produced non-finite entry ({i},{j})"
                    );
                }
                z_values[idx] = value;
            }
            // Diagonal entry, using the off-diagonal Z[k, j] just computed above.
            let mut correction = 0.0;
            for off_idx in (diag_idx + 1)..col_end {
                correction += factor.l_values[off_idx] * z_values[off_idx];
            }
            let value = (1.0 / diag - correction) / diag;
            if !value.is_finite() {
                bail_invalid_linalg!(
                    "Takahashi selected inverse produced non-finite diagonal entry ({j},{j})"
                );
            }
            z_values[diag_idx] = value;
        }

        Ok(TakahashiInverse {
            z_values,
            col_ptr,
            row_idx,
            l_values: factor.l_values.clone(),
            rows_lower: Arc::new(rows_lower),
            exact_columns: Mutex::new(BTreeMap::new()),
            perm_inv: factor.perm_inv.clone(),
            n,
        })
    }

    /// Returns `(A⁻¹)[i, j]` with `i` and `j` in the original index order.
    ///
    /// # Panics
    /// Panics if `i` or `j` is out of range. For entries off the stored pattern, it also
    /// panics if the column cache mutex is poisoned.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        let pi = self.perm_inv[i];
        let pj = self.perm_inv[j];
        self.get_permuted(pi, pj)
    }

    /// Returns `Z[pi, pj]` in permuted indices. The value comes from the stored pattern
    /// when present, otherwise from an on-demand exact column solve.
    fn get_permuted(&self, pi: usize, pj: usize) -> f64 {
        let (row, col) = if pi >= pj { (pi, pj) } else { (pj, pi) };
        if let Some(pos) = Self::find_entry(&self.col_ptr, &self.row_idx, row, col) {
            self.z_values[pos]
        } else {
            self.exact_permuted_column(col)[row]
        }
    }

    /// Returns the diagonal of `A⁻¹` in the original index order.
    pub fn diagonal(&self) -> Array1<f64> {
        Array1::from_iter((0..self.n).map(|i| self.get(i, i)))
    }

    /// Returns the dense square block `A⁻¹[start..end, start..end]` (original indices).
    ///
    /// Expects `start <= end <= n`. Other inputs panic, except that `end < start` in a
    /// release build wraps on the `usize` subtraction instead.
    pub fn block(&self, start: usize, end: usize) -> Array2<f64> {
        let dim = end - start;
        let mut out = Array2::zeros((dim, dim));
        for j_local in 0..dim {
            let j = start + j_local;
            for i_local in 0..dim {
                let i = start + i_local;
                out[[i_local, j_local]] = self.get(i, j);
            }
        }
        out
    }

    /// Returns `tr(A⁻¹ S)` for a symmetric sparse matrix `S`.
    ///
    /// Only entries of `s` with `row <= col` are read, and off-diagonal ones are counted
    /// twice to account for their mirror. `s` may therefore store either its upper
    /// triangle or both triangles; entries stored only below the diagonal are ignored.
    pub fn trace_product_sparse(&self, s: &SparseColMat<usize, f64>) -> f64 {
        let (symbolic, values) = s.parts();
        let s_col_ptr = symbolic.col_ptr();
        let s_row_idx = symbolic.row_idx();
        let mut trace = 0.0;
        for col in 0..s.ncols() {
            let col_start = s_col_ptr[col];
            let col_end = s_col_ptr[col + 1];
            for idx in col_start..col_end {
                let row = s_row_idx[idx];
                if row > col {
                    continue;
                }
                let val = values[idx];
                let z_ij = self.get(row, col);
                if row == col {
                    trace += z_ij * val;
                } else {
                    trace += 2.0 * z_ij * val;
                }
            }
        }
        trace
    }
}
1103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::faer_ndarray::FaerCholesky;
    use ndarray::{array, Array1, Array2};

    /// Asserts `|a - b| <= tol`, printing both values and their difference on failure.
    fn approx_eq(a: f64, b: f64, tol: f64) {
        assert!(
            (a - b).abs() <= tol,
            "values differ: left={a:.12e}, right={b:.12e}, |diff|={:.12e}, tol={tol:.12e}",
            (a - b).abs()
        );
    }

    // The only zero in this 3x3 matrix is at (1, 0), so 8 of the 9 entries must be stored.
    #[test]
    fn dense_to_sparse_preserves_all_nonzero_entries() {
        let m = array![[1.0, 2.0, 3.0], [0.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let s = dense_to_sparse(&m, ZERO_TOL).unwrap();
        assert_eq!(s.nrows(), 3);
        assert_eq!(s.ncols(), 3);
        assert_eq!(s.compute_nnz(), 8);
    }

    // Multiplying the sparse copy by each unit vector e_j must reproduce column j of `m`.
    #[test]
    fn dense_to_sparse_round_trips_via_matvec_identity() {
        let m = array![[4.0, 1.0, 0.5], [1.0, 3.0, 2.0], [0.5, 2.0, 6.0]];
        let s = dense_to_sparse(&m, ZERO_TOL).unwrap();
        for j in 0..3 {
            let mut ej = Array1::<f64>::zeros(3);
            ej[j] = 1.0;
            // Plain (non-symmetric) CSC mat-vec over every stored entry.
            let result = {
                let mut out = Array1::<f64>::zeros(3);
                let (sym, vals) = s.parts();
                let col_ptr = sym.col_ptr();
                let row_idx = sym.row_idx();
                for col in 0..3 {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let row = row_idx[idx];
                        out[row] += vals[idx] * ej[col];
                    }
                }
                out
            };
            for i in 0..3 {
                approx_eq(result[i], m[[i, j]], 1e-14);
            }
        }
    }

    #[test]
    fn dense_to_sparse_filters_entries_below_tolerance() {
        let tol = 0.1;
        let m = array![[1.0, 0.05], [0.05, 2.0]];
        let s = dense_to_sparse(&m, tol).unwrap();
        assert_eq!(s.compute_nnz(), 2, "off-diagonal entries below tol must be dropped");
    }

    // A dense 3x3 matrix has 6 entries on or above the diagonal.
    #[test]
    fn dense_to_sparse_symmetric_upper_stores_only_upper_triangle() {
        let m = array![[4.0, 1.0, 2.0], [1.0, 5.0, 3.0], [2.0, 3.0, 6.0]];
        let s = dense_to_sparse_symmetric_upper(&m, ZERO_TOL).unwrap();
        assert_eq!(s.compute_nnz(), 6);
    }

    #[test]
    fn sparse_symmetric_upper_matvec_matches_dense_matvec() {
        let a = array![[4.0, 2.0, 0.0], [2.0, 5.0, 3.0], [0.0, 3.0, 6.0]];
        let v = array![1.0, 2.0, 3.0];
        // Reference product from the full dense matrix; the sparse copy holds only the
        // upper triangle.
        let expected = a.dot(&v);
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let got = sparse_symmetric_upper_matvec_public(&a_sparse, &v);
        for i in 0..3 {
            approx_eq(got[i], expected[i], 1e-13);
        }
    }

    // Diagonal matrix: the product is element-wise, [3*2, 5*4, 7*6].
    #[test]
    fn sparse_symmetric_upper_matvec_diagonal_only() {
        let a = array![[3.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 7.0]];
        let v = array![2.0, 4.0, 6.0];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let got = sparse_symmetric_upper_matvec_public(&a_sparse, &v);
        approx_eq(got[0], 6.0, 1e-14);
        approx_eq(got[1], 20.0, 1e-14);
        approx_eq(got[2], 42.0, 1e-14);
    }

    // [[4, 2], [2, 5]] · [0.5, 2] = [6, 11].
    #[test]
    fn solve_sparse_spd_recovers_known_solution() {
        let a = array![[4.0, 2.0], [2.0, 5.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        let rhs = array![6.0, 11.0];
        let x = solve_sparse_spd(&factor, &rhs).unwrap();
        approx_eq(x[0], 0.5, 1e-12);
        approx_eq(x[1], 2.0, 1e-12);
    }

    // Each solved column x_j must satisfy A x_j = e_j.
    #[test]
    fn solve_sparse_spd_3x3_round_trip() {
        let a: Array2<f64> = array![
            [9.0, 3.0, 1.0],
            [3.0, 8.0, 2.0],
            [1.0, 2.0, 7.0]
        ];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        for j in 0..3 {
            let mut ej = Array1::<f64>::zeros(3);
            ej[j] = 1.0;
            let col_j = solve_sparse_spd(&factor, &ej).unwrap();
            let ax = a.dot(&col_j);
            for i in 0..3 {
                approx_eq(ax[i], ej[i], 1e-12);
            }
        }
    }

    // For a diagonal matrix, logdet is the sum of the logs of the diagonal.
    #[test]
    fn logdet_from_factor_matches_dense_logdet_diagonal() {
        let a: Array2<f64> =
            array![[4.0, 0.0, 0.0], [0.0, 9.0, 0.0], [0.0, 0.0, 16.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        let logdet = logdet_from_factor(&factor).unwrap();
        let expected = 4.0_f64.ln() + 9.0_f64.ln() + 16.0_f64.ln();
        approx_eq(logdet, expected, 1e-12);
    }

    // det([[4, 2], [2, 5]]) = 4*5 - 2*2 = 16.
    #[test]
    fn logdet_from_factor_matches_2x2_formula() {
        let a = array![[4.0, 2.0], [2.0, 5.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        let logdet = logdet_from_factor(&factor).unwrap();
        approx_eq(logdet, 16.0_f64.ln(), 1e-12);
    }

    #[test]
    fn solve_sparse_spd_dimension_mismatch_returns_error() {
        let a = array![[4.0, 2.0], [2.0, 5.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        // Length-3 right-hand side against a 2x2 factor.
        let rhs = array![1.0, 2.0, 3.0];
        assert!(solve_sparse_spd(&factor, &rhs).is_err());
    }

    // Compares the Takahashi diagonal against a dense inverse built column by column.
    #[test]
    fn takahashi_diagonal_matches_dense_inverse() {
        let h = array![
            [4.0, 0.2, 0.0, 0.0],
            [0.2, 3.0, 0.1, 0.0],
            [0.0, 0.1, 2.5, 0.3],
            [0.0, 0.0, 0.3, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();

        let chol = h.cholesky(Side::Lower).unwrap();
        let mut h_inv = Array2::<f64>::zeros((4, 4));
        for j in 0..4 {
            let mut rhs = Array1::<f64>::zeros(4);
            rhs[j] = 1.0;
            let col = chol.solvevec(&rhs);
            for i in 0..4 {
                h_inv[[i, j]] = col[i];
            }
        }

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();
        let diag = taka.diagonal();

        for i in 0..4 {
            approx_eq(diag[i], h_inv[[i, i]], 1e-10);
        }
    }

    // The standalone simplicial factor must agree with the logdet stored by
    // `factorize_sparse_spd`.
    #[test]
    fn takahashi_logdet_matches_dense() {
        let h = array![
            [4.0, 0.2, 0.0, 0.0],
            [0.2, 3.0, 0.1, 0.0],
            [0.0, 0.1, 2.5, 0.3],
            [0.0, 0.0, 0.3, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();

        let existing = factorize_sparse_spd(&h_sparse).unwrap();
        let logdet_dense = existing.logdet;

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        approx_eq(sfactor.logdet, logdet_dense, 1e-10);
    }

    // (0, 2) is zero in the tridiagonal `h` but nonzero in its inverse. This exercises the
    // exact-column fallback in `get` and `block`.
    #[test]
    fn takahashi_get_and_block_recover_off_pattern_inverse_entries() {
        let h = array![
            [4.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 1.0, 0.0],
            [0.0, 1.0, 2.5, 1.0],
            [0.0, 0.0, 1.0, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();

        let chol = h.cholesky(Side::Lower).unwrap();
        let mut h_inv = Array2::<f64>::zeros((4, 4));
        for j in 0..4 {
            let mut rhs = Array1::<f64>::zeros(4);
            rhs[j] = 1.0;
            let col = chol.solvevec(&rhs);
            for i in 0..4 {
                h_inv[[i, j]] = col[i];
            }
        }

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();

        assert!(
            h_inv[[0, 2]].abs() > 1e-8,
            "reference off-pattern inverse entry should be nonzero"
        );
        approx_eq(taka.get(0, 2), h_inv[[0, 2]], 1e-10);

        let block = taka.block(0, 3);
        approx_eq(block[[0, 2]], h_inv[[0, 2]], 1e-10);
        approx_eq(block[[2, 0]], h_inv[[2, 0]], 1e-10);
    }
}