Skip to main content

gam_linalg/
sparse_exact.rs

1use crate::LinalgError;
2use crate::faer_ndarray::{FaerArrayView, FaerColView};
3use faer::Side;
4use faer::linalg::solvers::Solve;
5use faer::sparse::linalg::solvers::Llt as SparseLlt;
6use faer::sparse::{SparseColMat, SymbolicSparseColMat, Triplet};
7use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Data, Ix1, Ix2};
8use rayon::prelude::*;
9use std::collections::BTreeMap;
10use std::sync::{Arc, Mutex};
11
/// Magnitude at or below which an entry is treated as a structural zero when
/// canonicalizing sparse symmetric input (entries are kept only if `|v| > ZERO_TOL`).
const ZERO_TOL: f64 = 1.0e-12;
/// Column-range width at or below which the dense-to-CSC fill helpers run
/// serially instead of splitting the range in half with `rayon::join`.
const PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD: usize = 64;
14
/// Return early from the enclosing function with
/// `Err(LinalgError::InvalidInput(msg))`, where `msg` is produced by passing
/// all macro arguments straight to `format!`.
macro_rules! bail_invalid_linalg {
    ($($fmt:tt)*) => {
        return ::std::result::Result::Err($crate::LinalgError::InvalidInput(
            ::std::format!($($fmt)*),
        ))
    };
}
20
21#[derive(Clone)]
22pub struct SparseExactFactor {
23    factor: SparseLlt<usize, f64>,
24    simplicial: Arc<SimplicialFactor>,
25    n: usize,
26    logdet: f64,
27}
28
29impl crate::matrix::FactorizedSystem for SparseExactFactor {
30    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
31        solve_sparse_spd(self, rhs).map_err(|e| e.to_string())
32    }
33
34    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
35        solve_sparse_spdmulti(self, rhs).map_err(|e| e.to_string())
36    }
37
38    fn logdet(&self) -> f64 {
39        self.logdet
40    }
41}
42
43pub fn dense_to_sparse(
44    matrix: &Array2<f64>,
45    tol: f64,
46) -> Result<SparseColMat<usize, f64>, LinalgError> {
47    let nrows = matrix.nrows();
48    let ncols = matrix.ncols();
49    // Direct column-major CSC construction.  Three-pass: count nnz per
50    // column in parallel, perform the prefix sum serially, then fill each
51    // deterministic column slice in parallel.  Columns are still traversed
52    // in order and rows are written in ascending order within each column,
53    // preserving the same canonical CSC ordering as the previous serial
54    // implementation without requiring a triplet sort/dedup pass.
55    let counts: Vec<usize> = (0..ncols)
56        .into_par_iter()
57        .map(|col| {
58            let mut count = 0usize;
59            for row in 0..nrows {
60                if matrix[[row, col]].abs() > tol {
61                    count += 1;
62                }
63            }
64            count
65        })
66        .collect();
67    let col_ptr = prefix_sum_counts(&counts);
68    let nnz = col_ptr[ncols];
69    let mut row_idx = vec![0usize; nnz];
70    let mut values = vec![0.0; nnz];
71    fill_dense_to_sparse_columns(matrix, tol, 0, ncols, &col_ptr, &mut row_idx, &mut values);
72    let symbolic = SymbolicSparseColMat::<usize>::new_checked(nrows, ncols, col_ptr, None, row_idx);
73    Ok(SparseColMat::<usize, f64>::new(symbolic, values))
74}
75
76/// Convert a dense symmetric matrix to sparse CSC storing only the upper triangle.
77///
78/// This encoding is required by sparse SPD routines in this module that interpret
79/// entries as symmetric-upper storage and mirror off-diagonals when reconstructing
80/// dense diagnostics.
81pub fn dense_to_sparse_symmetric_upper(
82    matrix: &Array2<f64>,
83    tol: f64,
84) -> Result<SparseColMat<usize, f64>, LinalgError> {
85    let nrows = matrix.nrows();
86    let ncols = matrix.ncols();
87    // Direct CSC build over the upper triangle.  Counts and fills are
88    // parallelized by column, with a serial prefix sum between them so every
89    // column writes to a deterministic, non-overlapping slice.  Iterating rows
90    // from low to high within each column keeps CSC row indices sorted exactly
91    // as in the previous serial implementation.
92    let row_limit = nrows.min(ncols);
93    let counts: Vec<usize> = (0..ncols)
94        .into_par_iter()
95        .map(|col| {
96            let mut count = 0usize;
97            let row_end = (col + 1).min(row_limit);
98            for row in 0..row_end {
99                if matrix[[row, col]].abs() > tol {
100                    count += 1;
101                }
102            }
103            count
104        })
105        .collect();
106    let col_ptr = prefix_sum_counts(&counts);
107    let nnz = col_ptr[ncols];
108    let mut row_idx = vec![0usize; nnz];
109    let mut values = vec![0.0; nnz];
110    fill_dense_symmetric_upper_columns(
111        matrix,
112        tol,
113        row_limit,
114        0,
115        ncols,
116        &col_ptr,
117        &mut row_idx,
118        &mut values,
119    );
120    let symbolic = SymbolicSparseColMat::<usize>::new_checked(nrows, ncols, col_ptr, None, row_idx);
121    Ok(SparseColMat::<usize, f64>::new(symbolic, values))
122}
123
/// Turn per-column nonzero counts into CSC column pointers.
///
/// Returns a vector of length `counts.len() + 1`. Its first element is `0`,
/// and element `k + 1` is `counts[0] + … + counts[k]`.
fn prefix_sum_counts(counts: &[usize]) -> Vec<usize> {
    std::iter::once(0)
        .chain(counts.iter().scan(0usize, |total, &count| {
            *total += count;
            Some(*total)
        }))
        .collect()
}
134
135fn fill_dense_to_sparse_columns(
136    matrix: &Array2<f64>,
137    tol: f64,
138    col_start: usize,
139    col_end: usize,
140    col_ptr: &[usize],
141    row_idx: &mut [usize],
142    values: &mut [f64],
143) {
144    if col_end - col_start <= PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD {
145        let base = col_ptr[col_start];
146        for col in col_start..col_end {
147            let mut write = col_ptr[col] - base;
148            for row in 0..matrix.nrows() {
149                let value = matrix[[row, col]];
150                if value.abs() > tol {
151                    row_idx[write] = row;
152                    values[write] = value;
153                    write += 1;
154                }
155            }
156        }
157        return;
158    }
159
160    let mid = col_start + (col_end - col_start) / 2;
161    let split = col_ptr[mid] - col_ptr[col_start];
162    let (left_rows, right_rows) = row_idx.split_at_mut(split);
163    let (left_values, right_values) = values.split_at_mut(split);
164    rayon::join(
165        || {
166            fill_dense_to_sparse_columns(
167                matrix,
168                tol,
169                col_start,
170                mid,
171                col_ptr,
172                left_rows,
173                left_values,
174            );
175        },
176        || {
177            fill_dense_to_sparse_columns(
178                matrix,
179                tol,
180                mid,
181                col_end,
182                col_ptr,
183                right_rows,
184                right_values,
185            );
186        },
187    );
188}
189
190fn fill_dense_symmetric_upper_columns(
191    matrix: &Array2<f64>,
192    tol: f64,
193    row_limit: usize,
194    col_start: usize,
195    col_end: usize,
196    col_ptr: &[usize],
197    row_idx: &mut [usize],
198    values: &mut [f64],
199) {
200    if col_end - col_start <= PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD {
201        let base = col_ptr[col_start];
202        for col in col_start..col_end {
203            let row_end = (col + 1).min(row_limit);
204            let mut write = col_ptr[col] - base;
205            for row in 0..row_end {
206                let value = matrix[[row, col]];
207                if value.abs() > tol {
208                    row_idx[write] = row;
209                    values[write] = value;
210                    write += 1;
211                }
212            }
213        }
214        return;
215    }
216
217    let mid = col_start + (col_end - col_start) / 2;
218    let split = col_ptr[mid] - col_ptr[col_start];
219    let (left_rows, right_rows) = row_idx.split_at_mut(split);
220    let (left_values, right_values) = values.split_at_mut(split);
221    rayon::join(
222        || {
223            fill_dense_symmetric_upper_columns(
224                matrix,
225                tol,
226                row_limit,
227                col_start,
228                mid,
229                col_ptr,
230                left_rows,
231                left_values,
232            );
233        },
234        || {
235            fill_dense_symmetric_upper_columns(
236                matrix,
237                tol,
238                row_limit,
239                mid,
240                col_end,
241                col_ptr,
242                right_rows,
243                right_values,
244            );
245        },
246    );
247}
248
249pub fn sparse_symmetric_upper_matvec_public<S: Data<Elem = f64>>(
250    matrix: &SparseColMat<usize, f64>,
251    vector: &ArrayBase<S, Ix1>,
252) -> Array1<f64> {
253    let mut out = Array1::<f64>::zeros(matrix.nrows());
254    let (symbolic, values) = matrix.parts();
255    let col_ptr = symbolic.col_ptr();
256    let row_idx = symbolic.row_idx();
257    for col in 0..matrix.ncols() {
258        let x_col = vector[col];
259        for idx in col_ptr[col]..col_ptr[col + 1] {
260            let row = row_idx[idx];
261            let value = values[idx];
262            out[row] += value * x_col;
263            if row != col {
264                out[col] += value * vector[row];
265            }
266        }
267    }
268    out
269}
270
271pub fn factorize_sparse_spd(
272    h: &SparseColMat<usize, f64>,
273) -> Result<SparseExactFactor, LinalgError> {
274    // Canonicalize to symmetric-upper storage before factorization.
275    //
276    // Math contract:
277    // - If callers pass upper-only storage, values are preserved.
278    // - If callers pass full symmetric storage, paired (i,j)/(j,i) entries are averaged.
279    // - If callers pass lower-only storage, it is mirrored into upper.
280    //
281    // This prevents off-diagonal double counting in paths that interpret input as
282    // symmetric-upper and makes the sparse factor path robust to caller encoding.
283    let t_start = std::time::Instant::now();
284    let n_input = h.ncols();
285    let h_upper = canonicalize_sparse_symmetric_upper(h, ZERO_TOL)?;
286    let factor = h_upper.as_ref().sp_cholesky(Side::Upper).map_err(|_| {
287        LinalgError::ModelIsIllConditioned {
288            condition_number: f64::INFINITY,
289        }
290    })?;
291    // Keep an explicit simplicial LLᵀ factor in addition to faer's solver
292    // object. The raw L is needed by callers that must reconstruct H in a
293    // changed basis, such as active-constraint tangent projection.
294    let simplicial = factorize_simplicial_canonical_upper(&h_upper)?;
295    let logdet = simplicial.logdet;
296    let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
297    if elapsed_ms > 100.0 {
298        log::info!(
299            "[sparse-chol] factorize_sparse_spd | n={} | {:.1}ms",
300            n_input,
301            elapsed_ms
302        );
303    }
304    Ok(SparseExactFactor {
305        factor,
306        simplicial: Arc::new(simplicial),
307        n: h_upper.ncols(),
308        logdet,
309    })
310}
311
312fn canonicalize_sparse_symmetric_upper(
313    matrix: &SparseColMat<usize, f64>,
314    tol: f64,
315) -> Result<SparseColMat<usize, f64>, LinalgError> {
316    if matrix.nrows() != matrix.ncols() {
317        bail_invalid_linalg!(
318            "sparse SPD factorization requires square matrix, got {}x{}",
319            matrix.nrows(),
320            matrix.ncols()
321        );
322    }
323
324    #[derive(Default, Clone, Copy)]
325    struct PairAccum {
326        upper_sum: f64,
327        upper_count: usize,
328        lower_sum: f64,
329        lower_count: usize,
330    }
331
332    let mut accum: BTreeMap<(usize, usize), PairAccum> = BTreeMap::new();
333    let (symbolic, values) = matrix.parts();
334    let col_ptr = symbolic.col_ptr();
335    let row_idx = symbolic.row_idx();
336
337    for col in 0..matrix.ncols() {
338        let start = col_ptr[col];
339        let end = col_ptr[col + 1];
340        for idx in start..end {
341            let row = row_idx[idx];
342            let value = values[idx];
343            let (r, c, is_upper) = if row <= col {
344                (row, col, true)
345            } else {
346                (col, row, false)
347            };
348            let slot = accum.entry((r, c)).or_default();
349            if is_upper {
350                slot.upper_sum += value;
351                slot.upper_count += 1;
352            } else {
353                slot.lower_sum += value;
354                slot.lower_count += 1;
355            }
356        }
357    }
358
359    let mut triplets = Vec::<Triplet<usize, usize, f64>>::new();
360    for ((row, col), slot) in accum {
361        let value = if row == col {
362            let count = slot.upper_count + slot.lower_count;
363            if count == 0 {
364                0.0
365            } else {
366                (slot.upper_sum + slot.lower_sum) / (count as f64)
367            }
368        } else {
369            let upper_avg = if slot.upper_count > 0 {
370                Some(slot.upper_sum / (slot.upper_count as f64))
371            } else {
372                None
373            };
374            let lower_avg = if slot.lower_count > 0 {
375                Some(slot.lower_sum / (slot.lower_count as f64))
376            } else {
377                None
378            };
379            match (upper_avg, lower_avg) {
380                (Some(u), Some(l)) => 0.5 * (u + l),
381                (Some(u), None) => u,
382                (None, Some(l)) => l,
383                (None, None) => 0.0,
384            }
385        };
386
387        if value.abs() > tol {
388            triplets.push(Triplet::new(row, col, value));
389        }
390    }
391
392    SparseColMat::try_new_from_triplets(matrix.nrows(), matrix.ncols(), &triplets).map_err(|_| {
393        LinalgError::InvalidInput(
394            "failed to canonicalize sparse matrix to symmetric-upper CSC".to_string(),
395        )
396    })
397}
398
399fn solve_view<R, I, F>(
400    factor: &SparseExactFactor,
401    rhs: ArrayView2<'_, f64>,
402    indices: I,
403    mut result: R,
404    non_finite_message: &'static str,
405    mut consume: F,
406) -> Result<R, LinalgError>
407where
408    I: IntoIterator<Item = (usize, usize)>,
409    F: FnMut(&mut R, usize, usize, f64),
410{
411    let rhsview = FaerArrayView::new(&rhs);
412    let solved = factor.factor.solve(rhsview.as_ref());
413    for (row, col) in indices {
414        let value = solved[(row, col)];
415        if !value.is_finite() {
416            bail_invalid_linalg!("{}", non_finite_message.to_string());
417        }
418        consume(&mut result, row, col, value);
419    }
420    Ok(result)
421}
422
423pub fn solve_sparse_spd<S>(
424    factor: &SparseExactFactor,
425    rhs: &ArrayBase<S, Ix1>,
426) -> Result<Array1<f64>, LinalgError>
427where
428    S: Data<Elem = f64>,
429{
430    if rhs.len() != factor.n {
431        bail_invalid_linalg!(
432            "sparse SPD solve dimension mismatch: rhs has {}, factor has {}",
433            rhs.len(),
434            factor.n
435        );
436    }
437    let mut result = Array1::<f64>::zeros(rhs.len());
438    solve_sparse_spd_into(factor, rhs, &mut result)?;
439    Ok(result)
440}
441
442/// In-place variant of [`solve_sparse_spd`]. Writes the solution directly into
443/// `out`, avoiding the intermediate `Array1` allocation on the hot PIRLS path.
444/// `out` must already be sized to match `factor.n` (typically the reused
445/// Newton-direction buffer).
446pub fn solve_sparse_spd_into<S>(
447    factor: &SparseExactFactor,
448    rhs: &ArrayBase<S, Ix1>,
449    out: &mut Array1<f64>,
450) -> Result<(), LinalgError>
451where
452    S: Data<Elem = f64>,
453{
454    if rhs.len() != factor.n {
455        bail_invalid_linalg!(
456            "sparse SPD solve dimension mismatch: rhs has {}, factor has {}",
457            rhs.len(),
458            factor.n
459        );
460    }
461    if out.len() != factor.n {
462        bail_invalid_linalg!(
463            "sparse SPD solve output dimension mismatch: out has {}, factor has {}",
464            out.len(),
465            factor.n
466        );
467    }
468    let rhsview = FaerColView::new(rhs);
469    let solved = factor.factor.solve(rhsview.as_ref());
470    for i in 0..factor.n {
471        let value = solved[(i, 0)];
472        if !value.is_finite() {
473            bail_invalid_linalg!("sparse SPD solve produced non-finite values");
474        }
475        out[i] = value;
476    }
477    Ok(())
478}
479
480pub fn solve_sparse_spdmulti<S>(
481    factor: &SparseExactFactor,
482    rhs: &ArrayBase<S, Ix2>,
483) -> Result<Array2<f64>, LinalgError>
484where
485    S: Data<Elem = f64>,
486{
487    if rhs.nrows() != factor.n {
488        bail_invalid_linalg!(
489            "sparse SPD multi-solve row mismatch: rhs has {}, factor has {}",
490            rhs.nrows(),
491            factor.n
492        );
493    }
494    let indices = (0..rhs.nrows()).flat_map(|i| (0..rhs.ncols()).map(move |j| (i, j)));
495    solve_view(
496        factor,
497        rhs.view(),
498        indices,
499        Array2::<f64>::zeros(rhs.raw_dim()),
500        "sparse SPD multi-solve produced non-finite values",
501        |result, row, col, value| {
502            result[[row, col]] = value;
503        },
504    )
505}
506
507pub fn solve_sparse_spdmulti_rows<S>(
508    factor: &SparseExactFactor,
509    rhs: &ArrayBase<S, Ix2>,
510    row_start: usize,
511    row_end: usize,
512) -> Result<Array2<f64>, LinalgError>
513where
514    S: Data<Elem = f64>,
515{
516    if rhs.nrows() != factor.n {
517        bail_invalid_linalg!(
518            "sparse SPD multi-solve row mismatch: rhs has {}, factor has {}",
519            rhs.nrows(),
520            factor.n
521        );
522    }
523    if row_start > row_end || row_end > factor.n {
524        bail_invalid_linalg!(
525            "sparse SPD selected rows out of bounds: row_start={}, row_end={}, factor={}",
526            row_start,
527            row_end,
528            factor.n
529        );
530    }
531    let indices = (row_start..row_end).flat_map(|i| (0..rhs.ncols()).map(move |j| (i, j)));
532    solve_view(
533        factor,
534        rhs.view(),
535        indices,
536        Array2::<f64>::zeros((row_end - row_start, rhs.ncols())),
537        "sparse SPD selected-row solve produced non-finite values",
538        |result, row, col, value| {
539            result[[row - row_start, col]] = value;
540        },
541    )
542}
543
544pub fn solve_sparse_spdmulti_diagonal_sum<S>(
545    factor: &SparseExactFactor,
546    rhs: &ArrayBase<S, Ix2>,
547    row_start: usize,
548) -> Result<f64, LinalgError>
549where
550    S: Data<Elem = f64>,
551{
552    if row_start.saturating_add(rhs.ncols()) > rhs.nrows() {
553        bail_invalid_linalg!(
554            "sparse SPD selected diagonal out of bounds: row_start={}, rows={}, cols={}",
555            row_start,
556            rhs.nrows(),
557            rhs.ncols()
558        );
559    }
560    let indices = (0..rhs.ncols()).map(|col| (row_start + col, col));
561    solve_view(
562        factor,
563        rhs.view(),
564        indices,
565        0.0,
566        "sparse SPD selected diagonal solve produced non-finite values",
567        |sum, _, _, value| {
568            *sum += value;
569        },
570    )
571}
572
573pub fn logdet_from_factor(factor: &SparseExactFactor) -> Result<f64, LinalgError> {
574    Ok(factor.logdet)
575}
576
577pub fn assemble_sparse_factor_h_dense(
578    factor: &SparseExactFactor,
579) -> Result<Array2<f64>, LinalgError> {
580    factor.simplicial.assemble_h_dense_original_order()
581}
582
583// ---------------------------------------------------------------------------
584// Takahashi selected inversion via simplicial Cholesky
585// ---------------------------------------------------------------------------
586
587use faer::dyn_stack::{MemBuffer, MemStack, StackReq};
588use faer::linalg::cholesky::llt::factor::LltRegularization;
589use faer::sparse::linalg::amd;
590use faer::sparse::linalg::cholesky::simplicial;
591
/// A simplicial Cholesky factorization `P H Pᵀ = L Lᵀ` with raw access to the
/// CSC pattern and values of `L`, plus the AMD permutation.
///
/// It is built with faer's low-level simplicial API so that `L`'s sparse
/// structure is directly available for Takahashi selected inversion.
pub struct SimplicialFactor {
    /// Dimension of the factored matrix.
    n: usize,
    /// `log|H| = 2 * Σ_j log(L_jj)`.
    pub logdet: f64,
    /// Inverse permutation returned by faer (`perm_inv[original] = permuted`),
    /// used to map original coordinates into the permuted factor basis.
    perm_inv: Vec<usize>,
    /// Column pointers of `L` (lower triangular, CSC), length `n + 1`.
    l_col_ptr: Vec<usize>,
    /// Row indices of `L` (lower triangular, CSC), length `nnz(L)`.
    l_row_idx: Vec<usize>,
    /// Numeric values of `L`, length `nnz(L)`.
    l_values: Vec<f64>,
}
611
612/// Build a [`SimplicialFactor`] from a symmetric CSC matrix (upper, lower, or
613/// full storage – it is canonicalized to symmetric-upper internally).
614///
615/// The factorization uses AMD fill-reducing ordering and faer's simplicial
616/// LLᵀ numeric factorization.
617pub fn factorize_simplicial(h: &SparseColMat<usize, f64>) -> Result<SimplicialFactor, LinalgError> {
618    let h_upper = canonicalize_sparse_symmetric_upper(h, ZERO_TOL)?;
619    factorize_simplicial_canonical_upper(&h_upper)
620}
621
622fn factorize_simplicial_canonical_upper(
623    h_upper: &SparseColMat<usize, f64>,
624) -> Result<SimplicialFactor, LinalgError> {
625    let n = h_upper.ncols();
626    if n == 0 {
627        return Ok(SimplicialFactor {
628            l_col_ptr: vec![0],
629            l_row_idx: Vec::new(),
630            l_values: Vec::new(),
631            perm_inv: Vec::new(),
632            n: 0,
633            logdet: 0.0,
634        });
635    }
636
637    let a_nnz = h_upper.compute_nnz();
638
639    // 1. AMD ordering
640    let mut perm_fwd = vec![0usize; n];
641    let mut perm_inv = vec![0usize; n];
642    {
643        let mut mem = MemBuffer::new(amd::order_scratch::<usize>(n, a_nnz));
644        amd::order(
645            &mut perm_fwd,
646            &mut perm_inv,
647            h_upper.symbolic(),
648            amd::Control::default(),
649            MemStack::new(&mut mem),
650        )
651        .map_err(|_| LinalgError::ModelIsIllConditioned {
652            condition_number: f64::INFINITY,
653        })?;
654    }
655
656    // perm_fwd and perm_inv have length n and were just populated by
657    // amd::order above for a valid symmetric n×n CSC matrix. On success,
658    // amd::order writes a valid permutation of 0..n into perm_fwd and its
659    // exact inverse into perm_inv.
660    // SAFETY: those are exactly the invariants required by PermRef::new_unchecked.
661    let perm = unsafe { faer::perm::PermRef::new_unchecked(&perm_fwd, &perm_inv, n) };
662
663    // 2. Permute to P A Pᵀ (upper-triangular, unsorted)
664    let a_perm_upper = {
665        let mut col_ptrs = vec![0usize; n + 1];
666        let mut row_indices = vec![0usize; a_nnz];
667        let mut values = vec![0.0f64; a_nnz];
668        let mut mem = MemBuffer::new(faer::sparse::utils::permute_self_adjoint_scratch::<usize>(
669            n,
670        ));
671        faer::sparse::utils::permute_self_adjoint_to_unsorted(
672            &mut values,
673            &mut col_ptrs,
674            &mut row_indices,
675            h_upper.as_ref(),
676            perm,
677            Side::Upper,
678            Side::Upper,
679            MemStack::new(&mut mem),
680        );
681        SparseColMat::<usize, f64>::new(
682            // col_ptrs and row_indices were just produced into preallocated
683            // buffers by permute_self_adjoint_to_unsorted from a valid n×n
684            // symbolic CSC and a valid permutation. That routine writes an
685            // unsorted CSC with col_ptrs length n + 1, monotone column ranges
686            // within row_indices, and every row index in 0..n.
687            // SAFETY: those are the hard SymbolicSparseColMat invariants; the
688            // following faer symbolic Cholesky routines accept this unsorted
689            // self-adjoint permutation.
690            unsafe { SymbolicSparseColMat::new_unchecked(n, n, col_ptrs, None, row_indices) },
691            values,
692        )
693    };
694
695    // 3. Symbolic analysis
696    let symbolic = {
697        let mut mem = MemBuffer::new(StackReq::any_of(&[
698            simplicial::prefactorize_symbolic_cholesky_scratch::<usize>(n, a_nnz),
699            simplicial::factorize_simplicial_symbolic_cholesky_scratch::<usize>(n),
700        ]));
701        let stack = MemStack::new(&mut mem);
702        let mut etree = vec![0isize; n];
703        let mut col_counts = vec![0usize; n];
704        let etree_ref = simplicial::prefactorize_symbolic_cholesky(
705            &mut etree,
706            &mut col_counts,
707            a_perm_upper.symbolic(),
708            stack,
709        );
710        simplicial::factorize_simplicial_symbolic_cholesky(
711            a_perm_upper.symbolic(),
712            etree_ref,
713            &col_counts,
714            stack,
715        )
716        .map_err(|_| LinalgError::ModelIsIllConditioned {
717            condition_number: f64::INFINITY,
718        })?
719    };
720
721    // 4. Numeric LLᵀ factorization
722    let mut l_values = vec![0.0f64; symbolic.len_val()];
723    {
724        let mut mem = MemBuffer::new(simplicial::factorize_simplicial_numeric_llt_scratch::<
725            usize,
726            f64,
727        >(n));
728        simplicial::factorize_simplicial_numeric_llt::<usize, f64>(
729            &mut l_values,
730            a_perm_upper.as_ref(),
731            LltRegularization::default(),
732            &symbolic,
733            MemStack::new(&mut mem),
734        )
735        .map_err(|_| LinalgError::HessianNotPositiveDefinite {
736            min_eigenvalue: f64::NAN,
737        })?;
738    }
739
740    // 5. Extract col_ptr, row_idx from the symbolic structure
741    let l_col_ptr: Vec<usize> = symbolic.col_ptr().to_vec();
742    let l_row_idx: Vec<usize> = symbolic.row_idx().to_vec();
743
744    // 6. Compute logdet from L diagonal: L[j,j] = l_values[l_col_ptr[j]]
745    let mut logdet = 0.0f64;
746    for j in 0..n {
747        let diag = l_values[l_col_ptr[j]];
748        if diag <= 0.0 {
749            return Err(LinalgError::HessianNotPositiveDefinite {
750                min_eigenvalue: f64::NAN,
751            });
752        }
753        logdet += diag.ln();
754    }
755    logdet *= 2.0;
756
757    Ok(SimplicialFactor {
758        l_col_ptr,
759        l_row_idx,
760        l_values,
761        perm_inv,
762        n,
763        logdet,
764    })
765}
766
767impl SimplicialFactor {
768    /// Reconstruct the original-order dense SPD matrix represented by this
769    /// permuted sparse Cholesky factor.
770    ///
771    /// The simplicial factor stores `L` for `P H Pᵀ = L Lᵀ`, with
772    /// `perm_inv[original] = permuted`. We first assemble the dense permuted
773    /// product and then map rows/columns back to the caller's coordinate order.
774    fn assemble_h_dense_original_order(&self) -> Result<Array2<f64>, LinalgError> {
775        if self.perm_inv.len() != self.n {
776            bail_invalid_linalg!(
777                "simplicial factor permutation length {} does not match dimension {}",
778                self.perm_inv.len(),
779                self.n
780            );
781        }
782        let mut h_permuted = Array2::<f64>::zeros((self.n, self.n));
783        for col in 0..self.n {
784            let start = self.l_col_ptr[col];
785            let end = self.l_col_ptr[col + 1];
786            for left_idx in start..end {
787                let left_row = self.l_row_idx[left_idx];
788                let left_value = self.l_values[left_idx];
789                if !left_value.is_finite() {
790                    bail_invalid_linalg!(
791                        "simplicial factor has non-finite L entry at value index {left_idx}"
792                    );
793                }
794                for right_idx in start..end {
795                    let right_row = self.l_row_idx[right_idx];
796                    let right_value = self.l_values[right_idx];
797                    h_permuted[[left_row, right_row]] += left_value * right_value;
798                }
799            }
800        }
801
802        let mut h_original = Array2::<f64>::zeros((self.n, self.n));
803        for i in 0..self.n {
804            let pi = self.perm_inv[i];
805            if pi >= self.n {
806                bail_invalid_linalg!(
807                    "simplicial factor permutation maps row {i} to out-of-bounds index {pi}"
808                );
809            }
810            for j in 0..self.n {
811                let pj = self.perm_inv[j];
812                if pj >= self.n {
813                    bail_invalid_linalg!(
814                        "simplicial factor permutation maps column {j} to out-of-bounds index {pj}"
815                    );
816                }
817                let value = h_permuted[[pi, pj]];
818                if !value.is_finite() {
819                    bail_invalid_linalg!(
820                        "dense reconstruction from sparse Cholesky produced non-finite values"
821                    );
822                }
823                h_original[[i, j]] = value;
824            }
825        }
826        Ok(h_original)
827    }
828}
829
/// Result of the Takahashi selected inversion.
///
/// `z_values` holds entries of `H⁻¹` at the positions of the filled sparsity
/// pattern of the Cholesky factor `L`, stored in `L`'s lower-triangular CSC
/// layout (i.e. in `L`'s permuted basis). Entries outside that pattern are
/// recovered exactly on demand by column solves against the same factor. Those
/// columns are cached in `exact_columns`.
pub struct TakahashiInverse {
    /// Dimension.
    n: usize,
    /// Inverse permutation returned by faer, mapping original coordinates into
    /// the permuted basis.
    perm_inv: Vec<usize>,
    /// Column pointers (owned copy from `L`).
    col_ptr: Vec<usize>,
    /// Row indices (owned copy from `L`).
    row_idx: Vec<usize>,
    /// Selected-inverse values, in the same CSC pattern as `L`.
    z_values: Vec<f64>,
    /// Numeric values of the simplicial Cholesky factor `L`.
    l_values: Vec<f64>,
    /// Row-oriented access to `L` for forward solves in the permuted basis.
    rows_lower: Arc<Vec<Vec<(usize, f64)>>>,
    /// Exact inverse columns solved on demand for entries outside the selected
    /// inverse pattern. Keys are permuted-basis column indices.
    exact_columns: Mutex<BTreeMap<usize, Arc<Vec<f64>>>>,
}
854
855impl TakahashiInverse {
856    /// Binary search for entry (row, col) in lower-triangular CSC.
857    /// Returns the value-array index if the entry exists.
858    fn find_entry(col_ptr: &[usize], row_idx: &[usize], row: usize, col: usize) -> Option<usize> {
859        let start = col_ptr[col];
860        let end = col_ptr[col + 1];
861        let slice = &row_idx[start..end];
862        slice.binary_search(&row).ok().map(|pos| start + pos)
863    }
864
865    fn solve_permuted_column_from_cholesky(
866        n: usize,
867        col_ptr: &[usize],
868        row_idx: &[usize],
869        l_values: &[f64],
870        rows_lower: &[Vec<(usize, f64)>],
871        rhs_col: usize,
872    ) -> Vec<f64> {
873        let mut rhs = vec![0.0f64; n];
874        rhs[rhs_col] = 1.0;
875        let mut forward = vec![0.0f64; n];
876        let mut solution = vec![0.0f64; n];
877
878        for row in 0..n {
879            let mut sum = rhs[row];
880            let mut diag = None;
881            for &(col, value) in &rows_lower[row] {
882                if col < row {
883                    sum -= value * forward[col];
884                } else if col == row {
885                    diag = Some(value);
886                }
887            }
888            let l_rr = diag.expect("simplicial factor row should contain its diagonal");
889            forward[row] = sum / l_rr;
890        }
891
892        for row in (0..n).rev() {
893            let col_start = col_ptr[row];
894            let col_end = col_ptr[row + 1];
895            let mut sum = forward[row];
896            let l_rr = l_values[col_start];
897            for idx in (col_start + 1)..col_end {
898                let lower_row = row_idx[idx];
899                sum -= l_values[idx] * solution[lower_row];
900            }
901            solution[row] = sum / l_rr;
902        }
903
904        solution
905    }
906
907    fn exact_permuted_column(&self, col: usize) -> Arc<Vec<f64>> {
908        {
909            let cache = self
910                .exact_columns
911                .lock()
912                .expect("exact Takahashi column cache mutex poisoned");
913            if let Some(solution) = cache.get(&col) {
914                return solution.clone();
915            }
916        }
917
918        let solution = Arc::new(Self::solve_permuted_column_from_cholesky(
919            self.n,
920            &self.col_ptr,
921            &self.row_idx,
922            &self.l_values,
923            self.rows_lower.as_ref(),
924            col,
925        ));
926
927        let mut cache = self
928            .exact_columns
929            .lock()
930            .expect("exact Takahashi column cache mutex poisoned");
931        cache.entry(col).or_insert_with(|| solution.clone()).clone()
932    }
933
934    fn selected_value(
935        z_values: &[f64],
936        col_ptr: &[usize],
937        row_idx: &[usize],
938        row: usize,
939        col: usize,
940    ) -> Result<f64, LinalgError> {
941        let (lower_row, lower_col) = if row >= col { (row, col) } else { (col, row) };
942        Self::find_entry(col_ptr, row_idx, lower_row, lower_col)
943            .map(|idx| z_values[idx])
944            .ok_or_else(|| {
945                LinalgError::InvalidInput(format!(
946                    "simplicial selected-inverse pattern is missing entry ({lower_row},{lower_col})"
947                ))
948            })
949    }
950
951    /// Compute the selected inverse from a simplicial Cholesky factor.
952    ///
953    /// Given H = LLᵀ in the permuted basis, this applies the Takahashi
954    /// recurrence on the filled Cholesky pattern. Off-pattern exact entries are
955    /// recovered later by cached column solves from the same simplicial factor.
956    pub fn compute(factor: &SimplicialFactor) -> Result<Self, LinalgError> {
957        let n = factor.n;
958        let col_ptr = factor.l_col_ptr.clone();
959        let row_idx = factor.l_row_idx.clone();
960        let nnz = factor.l_values.len();
961        let mut z_values = vec![0.0f64; nnz];
962
963        // Build row access for forward solves in the permuted basis.
964        let mut rows_lower: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
965        for col in 0..n {
966            for idx in col_ptr[col]..col_ptr[col + 1] {
967                let row = row_idx[idx];
968                rows_lower[row].push((col, factor.l_values[idx]));
969            }
970        }
971
972        for j in (0..n).rev() {
973            let diag_idx = col_ptr[j];
974            let col_end = col_ptr[j + 1];
975            let diag = factor.l_values[diag_idx];
976            if !(diag.is_finite() && diag > 0.0) {
977                return Err(LinalgError::HessianNotPositiveDefinite {
978                    min_eigenvalue: f64::NAN,
979                });
980            }
981            for idx in (diag_idx + 1)..col_end {
982                let i = row_idx[idx];
983                let mut correction = 0.0;
984                for off_idx in (diag_idx + 1)..col_end {
985                    let k = row_idx[off_idx];
986                    let l_kj = factor.l_values[off_idx];
987                    let z_ik = Self::selected_value(&z_values, &col_ptr, &row_idx, i, k)?;
988                    correction += l_kj * z_ik;
989                }
990                let value = -correction / diag;
991                if !value.is_finite() {
992                    bail_invalid_linalg!(
993                        "Takahashi selected inverse produced non-finite entry ({i},{j})"
994                    );
995                }
996                z_values[idx] = value;
997            }
998            let mut correction = 0.0;
999            for off_idx in (diag_idx + 1)..col_end {
1000                correction += factor.l_values[off_idx] * z_values[off_idx];
1001            }
1002            let value = (1.0 / diag - correction) / diag;
1003            if !value.is_finite() {
1004                bail_invalid_linalg!(
1005                    "Takahashi selected inverse produced non-finite diagonal entry ({j},{j})"
1006                );
1007            }
1008            z_values[diag_idx] = value;
1009        }
1010
1011        Ok(TakahashiInverse {
1012            z_values,
1013            col_ptr,
1014            row_idx,
1015            l_values: factor.l_values.clone(),
1016            rows_lower: Arc::new(rows_lower),
1017            exact_columns: Mutex::new(BTreeMap::new()),
1018            perm_inv: factor.perm_inv.clone(),
1019            n,
1020        })
1021    }
1022
1023    /// Get H⁻¹[i,j] in ORIGINAL (unpermuted) coordinates.
1024    pub fn get(&self, i: usize, j: usize) -> f64 {
1025        let pi = self.perm_inv[i];
1026        let pj = self.perm_inv[j];
1027        self.get_permuted(pi, pj)
1028    }
1029
1030    /// Get Z[pi,pj] in permuted coordinates.
1031    fn get_permuted(&self, pi: usize, pj: usize) -> f64 {
1032        // Z is symmetric and stored as lower-triangular CSC.
1033        // Ensure row >= col for lookup.
1034        let (row, col) = if pi >= pj { (pi, pj) } else { (pj, pi) };
1035        if let Some(pos) = Self::find_entry(&self.col_ptr, &self.row_idx, row, col) {
1036            self.z_values[pos]
1037        } else {
1038            self.exact_permuted_column(col)[row]
1039        }
1040    }
1041
1042    /// Diagonal of H⁻¹ in original ordering.
1043    pub fn diagonal(&self) -> Array1<f64> {
1044        Array1::from_iter((0..self.n).map(|i| self.get(i, i)))
1045    }
1046
1047    /// H⁻¹[start..end, start..end] block in original ordering.
1048    pub fn block(&self, start: usize, end: usize) -> Array2<f64> {
1049        let dim = end - start;
1050        let mut out = Array2::zeros((dim, dim));
1051        for j_local in 0..dim {
1052            let j = start + j_local;
1053            for i_local in 0..dim {
1054                let i = start + i_local;
1055                out[[i_local, j_local]] = self.get(i, j);
1056            }
1057        }
1058        out
1059    }
1060
1061    /// tr(H⁻¹ S) where S is given as sparse CSC, symmetric in either upper-
1062    /// triangle-only or full (both triangles stored) format.
1063    ///
1064    /// The algorithm iterates over the upper triangle of S (entries with
1065    /// row ≤ col), doubles off-diagonals, and skips lower-triangle entries.
1066    /// This is correct for both storage conventions:
1067    ///
1068    /// - **Upper-triangle-only** (for example, solver-owned sparse penalty blocks):
1069    ///   every off-diagonal pair has exactly one stored entry with row < col,
1070    ///   which we double.
1071    ///
1072    /// - **Full symmetric** (from `dense_to_sparse`): each off-diagonal pair
1073    ///   has entries at both (i,j) and (j,i).  We process only the row < col
1074    ///   entry and double it; the row > col mirror is skipped.  The diagonal
1075    ///   is stored once and counted once.
1076    ///
1077    /// In both cases: tr(Z S) = Σ_diag Z[i,i] S[i,i] + 2 Σ_{i<j} Z[i,j] S[i,j].
1078    pub fn trace_product_sparse(&self, s: &SparseColMat<usize, f64>) -> f64 {
1079        let (symbolic, values) = s.parts();
1080        let s_col_ptr = symbolic.col_ptr();
1081        let s_row_idx = symbolic.row_idx();
1082        let mut trace = 0.0;
1083        for col in 0..s.ncols() {
1084            let col_start = s_col_ptr[col];
1085            let col_end = s_col_ptr[col + 1];
1086            for idx in col_start..col_end {
1087                let row = s_row_idx[idx];
1088                if row > col {
1089                    continue; // skip lower triangle (handled via its mirror)
1090                }
1091                let val = values[idx];
1092                let z_ij = self.get(row, col);
1093                if row == col {
1094                    trace += z_ij * val;
1095                } else {
1096                    trace += 2.0 * z_ij * val;
1097                }
1098            }
1099        }
1100        trace
1101    }
1102}
1103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::faer_ndarray::FaerCholesky;
    use ndarray::{array, Array1, Array2};

    fn approx_eq(a: f64, b: f64, tol: f64) {
        assert!(
            (a - b).abs() <= tol,
            "values differ: left={a:.12e}, right={b:.12e}, |diff|={:.12e}, tol={tol:.12e}",
            (a - b).abs()
        );
    }

    /// Dense reference inverse of a small SPD matrix, assembled column by
    /// column from dense Cholesky solves against the unit vectors.
    fn dense_inverse(h: &Array2<f64>) -> Array2<f64> {
        let n = h.nrows();
        let chol = h.cholesky(Side::Lower).unwrap();
        let mut h_inv = Array2::<f64>::zeros((n, n));
        for j in 0..n {
            let mut rhs = Array1::<f64>::zeros(n);
            rhs[j] = 1.0;
            let col = chol.solvevec(&rhs);
            for i in 0..n {
                h_inv[[i, j]] = col[i];
            }
        }
        h_inv
    }

    /// Tridiagonal SPD test matrix whose inverse is dense, so some inverse
    /// entries fall outside the Cholesky pattern.
    fn tridiagonal_spd() -> Array2<f64> {
        array![
            [4.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 1.0, 0.0],
            [0.0, 1.0, 2.5, 1.0],
            [0.0, 0.0, 1.0, 2.0]
        ]
    }

    // ── dense_to_sparse ───────────────────────────────────────────────────

    #[test]
    fn dense_to_sparse_preserves_all_nonzero_entries() {
        // 3x3 matrix with a zero at (1,0) and all others nonzero.
        let m = array![[1.0, 2.0, 3.0], [0.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let s = dense_to_sparse(&m, ZERO_TOL).unwrap();
        assert_eq!(s.nrows(), 3);
        assert_eq!(s.ncols(), 3);
        // 8 entries should be stored (one zero excluded).
        assert_eq!(s.compute_nnz(), 8);
    }

    #[test]
    fn dense_to_sparse_round_trips_via_matvec_identity() {
        // Verify that (sparse A) * e_j == column j of A for each column.
        let m = array![[4.0, 1.0, 0.5], [1.0, 3.0, 2.0], [0.5, 2.0, 6.0]];
        let s = dense_to_sparse(&m, ZERO_TOL).unwrap();
        for j in 0..3 {
            let mut ej = Array1::<f64>::zeros(3);
            ej[j] = 1.0;
            // Multiply via the raw CSC arrays.
            let result = {
                let mut out = Array1::<f64>::zeros(3);
                let (sym, vals) = s.parts();
                let col_ptr = sym.col_ptr();
                let row_idx = sym.row_idx();
                for col in 0..3 {
                    for idx in col_ptr[col]..col_ptr[col + 1] {
                        let row = row_idx[idx];
                        out[row] += vals[idx] * ej[col];
                    }
                }
                out
            };
            for i in 0..3 {
                approx_eq(result[i], m[[i, j]], 1e-14);
            }
        }
    }

    #[test]
    fn dense_to_sparse_filters_entries_below_tolerance() {
        let tol = 0.1;
        let m = array![[1.0, 0.05], [0.05, 2.0]];
        let s = dense_to_sparse(&m, tol).unwrap();
        // Only the two diagonal entries exceed tol.
        assert_eq!(s.compute_nnz(), 2, "off-diagonal entries below tol must be dropped");
    }

    // ── dense_to_sparse_symmetric_upper ───────────────────────────────────

    #[test]
    fn dense_to_sparse_symmetric_upper_stores_only_upper_triangle() {
        // Full symmetric 3x3 matrix — only upper triangle (i<=j) should be stored.
        let m = array![[4.0, 1.0, 2.0], [1.0, 5.0, 3.0], [2.0, 3.0, 6.0]];
        let s = dense_to_sparse_symmetric_upper(&m, ZERO_TOL).unwrap();
        // Upper triangle has 3 diagonal + 3 off-diagonal = 6 entries.
        assert_eq!(s.compute_nnz(), 6);
    }

    // ── sparse_symmetric_upper_matvec_public ──────────────────────────────

    #[test]
    fn sparse_symmetric_upper_matvec_matches_dense_matvec() {
        // Symmetric matrix A; upper-sparse encodes only the upper triangle.
        // A * v must equal the result of the symmetric matvec.
        let a = array![[4.0, 2.0, 0.0], [2.0, 5.0, 3.0], [0.0, 3.0, 6.0]];
        let v = array![1.0, 2.0, 3.0];
        let expected = a.dot(&v); // dense reference
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let got = sparse_symmetric_upper_matvec_public(&a_sparse, &v);
        for i in 0..3 {
            approx_eq(got[i], expected[i], 1e-13);
        }
    }

    #[test]
    fn sparse_symmetric_upper_matvec_diagonal_only() {
        // Pure diagonal matrix: matvec should scale each component.
        let a = array![[3.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 7.0]];
        let v = array![2.0, 4.0, 6.0];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let got = sparse_symmetric_upper_matvec_public(&a_sparse, &v);
        approx_eq(got[0], 6.0, 1e-14);
        approx_eq(got[1], 20.0, 1e-14);
        approx_eq(got[2], 42.0, 1e-14);
    }

    // ── solve_sparse_spd / logdet_from_factor ─────────────────────────────

    #[test]
    fn solve_sparse_spd_recovers_known_solution() {
        // A = [[4,2],[2,5]]; A^{-1} b = [0.5, 2.0] for b = [6, 11].
        let a = array![[4.0, 2.0], [2.0, 5.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        let rhs = array![6.0, 11.0];
        let x = solve_sparse_spd(&factor, &rhs).unwrap();
        // A^-1 = (1/16)*[[5,-2],[-2,4]]; x = (1/16)*[5*6-2*11, -2*6+4*11] = [0.5, 2.0]
        approx_eq(x[0], 0.5, 1e-12);
        approx_eq(x[1], 2.0, 1e-12);
    }

    #[test]
    fn solve_sparse_spd_3x3_round_trip() {
        let a: Array2<f64> = array![
            [9.0, 3.0, 1.0],
            [3.0, 8.0, 2.0],
            [1.0, 2.0, 7.0]
        ];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        for j in 0..3 {
            let mut ej = Array1::<f64>::zeros(3);
            ej[j] = 1.0;
            let col_j = solve_sparse_spd(&factor, &ej).unwrap();
            // A * x should equal ej.
            let ax = a.dot(&col_j);
            for i in 0..3 {
                approx_eq(ax[i], ej[i], 1e-12);
            }
        }
    }

    #[test]
    fn logdet_from_factor_matches_dense_logdet_diagonal() {
        // Diagonal matrix diag(4,9,16): log-det = log(4)+log(9)+log(16)
        let a: Array2<f64> =
            array![[4.0, 0.0, 0.0], [0.0, 9.0, 0.0], [0.0, 0.0, 16.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        let logdet = logdet_from_factor(&factor).unwrap();
        let expected = 4.0_f64.ln() + 9.0_f64.ln() + 16.0_f64.ln();
        approx_eq(logdet, expected, 1e-12);
    }

    #[test]
    fn logdet_from_factor_matches_2x2_formula() {
        // A = [[4,2],[2,5]]; det(A) = 20-4 = 16; log-det = log(16)
        let a = array![[4.0, 2.0], [2.0, 5.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        let logdet = logdet_from_factor(&factor).unwrap();
        approx_eq(logdet, 16.0_f64.ln(), 1e-12);
    }

    #[test]
    fn solve_sparse_spd_dimension_mismatch_returns_error() {
        let a = array![[4.0, 2.0], [2.0, 5.0]];
        let a_sparse = dense_to_sparse_symmetric_upper(&a, ZERO_TOL).unwrap();
        let factor = factorize_sparse_spd(&a_sparse).unwrap();
        let rhs = array![1.0, 2.0, 3.0]; // wrong length
        assert!(solve_sparse_spd(&factor, &rhs).is_err());
    }

    // ── TakahashiInverse ──────────────────────────────────────────────────

    #[test]
    fn takahashi_diagonal_matches_dense_inverse() {
        // 4x4 SPD matrix
        let h = array![
            [4.0, 0.2, 0.0, 0.0],
            [0.2, 3.0, 0.1, 0.0],
            [0.0, 0.1, 2.5, 0.3],
            [0.0, 0.0, 0.3, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        let h_inv = dense_inverse(&h);

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();
        let diag = taka.diagonal();

        // Diagonal of selected inverse should match dense inverse diagonal
        for i in 0..4 {
            approx_eq(diag[i], h_inv[[i, i]], 1e-10);
        }
    }

    #[test]
    fn takahashi_logdet_matches_dense() {
        let h = array![
            [4.0, 0.2, 0.0, 0.0],
            [0.2, 3.0, 0.1, 0.0],
            [0.0, 0.1, 2.5, 0.3],
            [0.0, 0.0, 0.3, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();

        // Reference logdet from the faer sparse LLᵀ factor.
        let existing = factorize_sparse_spd(&h_sparse).unwrap();
        let logdet_dense = existing.logdet;

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        approx_eq(sfactor.logdet, logdet_dense, 1e-10);
    }

    #[test]
    fn takahashi_get_and_block_recover_off_pattern_inverse_entries() {
        let h = tridiagonal_spd();
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        let h_inv = dense_inverse(&h);

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();

        assert!(
            h_inv[[0, 2]].abs() > 1e-8,
            "reference off-pattern inverse entry should be nonzero"
        );
        approx_eq(taka.get(0, 2), h_inv[[0, 2]], 1e-10);

        let block = taka.block(0, 3);
        approx_eq(block[[0, 2]], h_inv[[0, 2]], 1e-10);
        approx_eq(block[[2, 0]], h_inv[[2, 0]], 1e-10);
    }

    #[test]
    fn takahashi_block_matches_dense_inverse_everywhere() {
        let h = tridiagonal_spd();
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        let h_inv = dense_inverse(&h);

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();

        let full = taka.block(0, 4);
        for i in 0..4 {
            for j in 0..4 {
                approx_eq(full[[i, j]], h_inv[[i, j]], 1e-10);
            }
        }

        // Interior block: local indices are offset by `start`.
        let sub = taka.block(1, 3);
        assert_eq!(sub.dim(), (2, 2));
        for i in 0..2 {
            for j in 0..2 {
                approx_eq(sub[[i, j]], h_inv[[1 + i, 1 + j]], 1e-10);
            }
        }

        assert_eq!(taka.block(2, 2).dim(), (0, 0));
    }

    #[test]
    fn takahashi_trace_product_sparse_matches_dense_trace() {
        let h = tridiagonal_spd();
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        let h_inv = dense_inverse(&h);

        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();

        // Symmetric S with an off-diagonal (1,3) coupling outside H's pattern.
        let s = array![
            [2.0, 0.5, 0.0, 0.0],
            [0.5, 1.0, 0.0, 0.3],
            [0.0, 0.0, 3.0, 0.0],
            [0.0, 0.3, 0.0, 1.5]
        ];
        let expected = h_inv.dot(&s).diag().sum();

        // Both storage conventions documented on `trace_product_sparse`.
        let s_upper = dense_to_sparse_symmetric_upper(&s, ZERO_TOL).unwrap();
        let s_full = dense_to_sparse(&s, ZERO_TOL).unwrap();
        approx_eq(taka.trace_product_sparse(&s_upper), expected, 1e-10);
        approx_eq(taka.trace_product_sparse(&s_full), expected, 1e-10);
    }
}