use crate::estimate::EstimationError;
use crate::faer_ndarray::{FaerArrayView, FaerColView};
use crate::solver::pirls::{PirlsWorkspace, sparse_reml_penalized_hessian};
use faer::Side;
use faer::linalg::solvers::Solve;
use faer::sparse::linalg::solvers::Llt as SparseLlt;
use faer::sparse::{SparseColMat, SymbolicSparseColMat, Triplet};
use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2};
use rayon::prelude::*;
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
// Magnitudes at or below this threshold are treated as structural zeros when
// sparsifying dense matrices or filtering eigenvalues.
const ZERO_TOL: f64 = 1e-12;
// Column-count cutoff below which the recursive parallel sparse-fill routines
// switch to a sequential base case (avoids rayon overhead on small ranges).
const PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD: usize = 64;
/// Sparse LLᵀ (Cholesky) factorization of a symmetric positive-definite matrix,
/// with the matrix dimension and log-determinant cached at factorization time.
#[derive(Clone)]
pub struct SparseExactFactor {
    // faer sparse Cholesky factor used for all solves.
    factor: SparseLlt<usize, f64>,
    // Dimension of the factored (square) matrix.
    n: usize,
    // log|H|, computed once by `factorize_sparse_spd`.
    logdet: f64,
}
impl crate::matrix::FactorizedSystem for SparseExactFactor {
    /// Solve H x = rhs for a single right-hand side, stringifying any error.
    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
        solve_sparse_spd(self, rhs).map_err(|e| e.to_string())
    }
    /// Solve H X = RHS for a matrix of right-hand sides, stringifying any error.
    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
        solve_sparse_spdmulti(self, rhs).map_err(|e| e.to_string())
    }
    /// Log-determinant cached when the factor was built.
    fn logdet(&self) -> f64 {
        self.logdet
    }
}
/// Per-term penalty data laid out for fast sparse trace computations.
///
/// `p_start..p_end` is the coefficient (column) range the penalty acts on.
#[derive(Clone)]
pub struct SparsePenaltyBlock {
    /// Index of this penalty term in the original penalty list.
    pub term_index: usize,
    /// First coefficient index covered by the block.
    pub p_start: usize,
    /// One-past-last coefficient index covered by the block.
    pub p_end: usize,
    /// Eigenvalues of the dense block that exceed `ZERO_TOL`.
    pub positive_eigenvalues: Arc<Vec<f64>>,
    /// True when every nonzero of the penalty lies inside `[p_start, p_end)`,
    /// enabling the block-local trace path in `trace_hinv_sk_takahashi`.
    pub block_support_strict: bool,
    /// Full-dimension sparse representation of the penalty matrix.
    pub s_k_sparse: SparseColMat<usize, f64>,
    /// Dense copy of the penalty restricted to its block.
    pub s_k_block_dense: Arc<Array2<f64>>,
    /// Upper-triangular nonzeros of the dense block as (row, col, value),
    /// with indices local to the block.
    pub s_k_block_upper_entries: Arc<Vec<(usize, usize, f64)>>,
}
/// Assembled penalized Hessian together with its sparse Cholesky factor and
/// cached log-determinant.
#[derive(Clone)]
pub struct SparsePenalizedSystem {
    /// Penalized Hessian H in sparse CSC form.
    pub h_sparse: SparseColMat<usize, f64>,
    /// Sparse Cholesky factorization of `h_sparse`.
    pub factor: SparseExactFactor,
    /// log|H| (the same value cached on `factor`).
    pub logdet_h: f64,
}
/// Convert a dense matrix to sparse CSC form, discarding entries whose
/// magnitude does not exceed `tol`.
///
/// Per-column nonzero counts are computed in parallel, turned into column
/// pointers by an exclusive prefix sum, and the row/value buffers are then
/// filled by `fill_dense_to_sparse_columns`.
pub fn dense_to_sparse(
    matrix: &Array2<f64>,
    tol: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
    // Count surviving entries per column (parallel over columns).
    let per_col_nnz: Vec<usize> = (0..ncols)
        .into_par_iter()
        .map(|col| {
            matrix
                .column(col)
                .iter()
                .filter(|value| value.abs() > tol)
                .count()
        })
        .collect();
    let col_ptr = prefix_sum_counts(&per_col_nnz);
    let total_nnz = col_ptr[ncols];
    let mut row_idx = vec![0usize; total_nnz];
    let mut values = vec![0.0; total_nnz];
    fill_dense_to_sparse_columns(matrix, tol, 0, ncols, &col_ptr, &mut row_idx, &mut values);
    let symbolic = SymbolicSparseColMat::<usize>::new_checked(nrows, ncols, col_ptr, None, row_idx);
    Ok(SparseColMat::<usize, f64>::new(symbolic, values))
}
/// Embed a dense square block into an otherwise-zero `total_dim` x `total_dim`
/// sparse matrix, keeping only the block's upper triangle (entries with
/// `row <= col`, shifted by `offset`) and dropping magnitudes <= `tol`.
fn embed_dense_block_to_sparse_symmetric_upper(
    local: &Array2<f64>,
    offset: usize,
    total_dim: usize,
    tol: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    let block_n = local.nrows();
    if local.ncols() != block_n {
        return Err(EstimationError::InvalidInput(
            "embed_dense_block_to_sparse_symmetric_upper requires a square block".to_string(),
        ));
    }
    if offset + block_n > total_dim {
        return Err(EstimationError::InvalidInput(
            "embed_dense_block_to_sparse_symmetric_upper offset+block exceeds total_dim"
                .to_string(),
        ));
    }
    // Per-column nonzero counts over the block's upper triangle (parallel).
    let counts: Vec<usize> = (0..block_n)
        .into_par_iter()
        .map(|c| {
            let mut count = 0usize;
            for r in 0..=c {
                if local[[r, c]].abs() > tol {
                    count += 1;
                }
            }
            count
        })
        .collect();
    let block_col_ptr = prefix_sum_counts(&counts);
    // Build the global column-pointer array: columns before `offset` stay
    // empty (zeros), block columns take the block's offsets (which start at 0,
    // so no shift is needed), trailing columns are pinned at the final nnz so
    // their ranges are empty.
    let mut col_ptr = vec![0usize; total_dim + 1];
    for c in 0..block_n {
        col_ptr[offset + c + 1] = block_col_ptr[c + 1];
    }
    let nnz_in_block_end = block_col_ptr[block_n];
    for c in (offset + block_n)..total_dim {
        col_ptr[c + 1] = nnz_in_block_end;
    }
    let nnz = col_ptr[total_dim];
    let mut row_idx = vec![0usize; nnz];
    let mut values = vec![0.0; nnz];
    // Fill row indices/values; row indices are shifted by `offset` inside.
    fill_embedded_symmetric_upper_columns(
        local,
        offset,
        tol,
        0,
        block_n,
        &block_col_ptr,
        &mut row_idx,
        &mut values,
    );
    let symbolic =
        SymbolicSparseColMat::<usize>::new_checked(total_dim, total_dim, col_ptr, None, row_idx);
    Ok(SparseColMat::<usize, f64>::new(symbolic, values))
}
/// Convert a dense symmetric matrix to sparse CSC form keeping only the upper
/// triangle (`row <= col`), discarding magnitudes that do not exceed `tol`.
///
/// Counts per column are computed in parallel; the buffers are then filled by
/// `fill_dense_symmetric_upper_columns`.
pub fn dense_to_sparse_symmetric_upper(
    matrix: &Array2<f64>,
    tol: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
    // Rows are clamped so a non-square input never indexes out of bounds.
    let row_limit = nrows.min(ncols);
    let per_col_nnz: Vec<usize> = (0..ncols)
        .into_par_iter()
        .map(|col| {
            let row_end = (col + 1).min(row_limit);
            matrix
                .column(col)
                .iter()
                .take(row_end)
                .filter(|value| value.abs() > tol)
                .count()
        })
        .collect();
    let col_ptr = prefix_sum_counts(&per_col_nnz);
    let total_nnz = col_ptr[ncols];
    let mut row_idx = vec![0usize; total_nnz];
    let mut values = vec![0.0; total_nnz];
    fill_dense_symmetric_upper_columns(
        matrix,
        tol,
        row_limit,
        0,
        ncols,
        &col_ptr,
        &mut row_idx,
        &mut values,
    );
    let symbolic = SymbolicSparseColMat::<usize>::new_checked(nrows, ncols, col_ptr, None, row_idx);
    Ok(SparseColMat::<usize, f64>::new(symbolic, values))
}
/// Exclusive prefix sum over per-column counts, producing CSC column pointers:
/// `out[0] == 0`, `out[i] == counts[0] + … + counts[i-1]`, length
/// `counts.len() + 1`.
fn prefix_sum_counts(counts: &[usize]) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(counts.len() + 1);
    let mut total = 0usize;
    offsets.push(total);
    offsets.extend(counts.iter().map(|&count| {
        total += count;
        total
    }));
    offsets
}
/// Fill `row_idx`/`values` for columns `[col_start, col_end)` of a CSC layout
/// described by `col_ptr`. Large column ranges are split recursively and the
/// two disjoint halves of the output buffers are handed to `rayon::join`;
/// small ranges are filled sequentially.
fn fill_dense_to_sparse_columns(
    matrix: &Array2<f64>,
    tol: f64,
    col_start: usize,
    col_end: usize,
    col_ptr: &[usize],
    row_idx: &mut [usize],
    values: &mut [f64],
) {
    if col_end - col_start <= PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD {
        // Sequential base case. `row_idx`/`values` are sub-slices covering
        // only this column range, so write positions are relative to
        // `col_ptr[col_start]`.
        let base = col_ptr[col_start];
        for col in col_start..col_end {
            let mut write = col_ptr[col] - base;
            for row in 0..matrix.nrows() {
                let value = matrix[[row, col]];
                if value.abs() > tol {
                    row_idx[write] = row;
                    values[write] = value;
                    write += 1;
                }
            }
        }
        return;
    }
    // Split the output buffers at the boundary between the two column halves;
    // `split` is the number of nonzeros belonging to the left half.
    let mid = col_start + (col_end - col_start) / 2;
    let split = col_ptr[mid] - col_ptr[col_start];
    let (left_rows, right_rows) = row_idx.split_at_mut(split);
    let (left_values, right_values) = values.split_at_mut(split);
    rayon::join(
        || {
            fill_dense_to_sparse_columns(
                matrix,
                tol,
                col_start,
                mid,
                col_ptr,
                left_rows,
                left_values,
            );
        },
        || {
            fill_dense_to_sparse_columns(
                matrix,
                tol,
                mid,
                col_end,
                col_ptr,
                right_rows,
                right_values,
            );
        },
    );
}
/// Upper-triangular variant of `fill_dense_to_sparse_columns`: for each column
/// only rows `0..min(col + 1, row_limit)` are scanned. Same recursive
/// buffer-splitting parallelism.
fn fill_dense_symmetric_upper_columns(
    matrix: &Array2<f64>,
    tol: f64,
    row_limit: usize,
    col_start: usize,
    col_end: usize,
    col_ptr: &[usize],
    row_idx: &mut [usize],
    values: &mut [f64],
) {
    if col_end - col_start <= PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD {
        // Sequential base case; write positions are relative to the start of
        // the sub-slices handed to this call.
        let base = col_ptr[col_start];
        for col in col_start..col_end {
            let row_end = (col + 1).min(row_limit);
            let mut write = col_ptr[col] - base;
            for row in 0..row_end {
                let value = matrix[[row, col]];
                if value.abs() > tol {
                    row_idx[write] = row;
                    values[write] = value;
                    write += 1;
                }
            }
        }
        return;
    }
    // Split the output buffers between the two column halves and recurse in
    // parallel.
    let mid = col_start + (col_end - col_start) / 2;
    let split = col_ptr[mid] - col_ptr[col_start];
    let (left_rows, right_rows) = row_idx.split_at_mut(split);
    let (left_values, right_values) = values.split_at_mut(split);
    rayon::join(
        || {
            fill_dense_symmetric_upper_columns(
                matrix,
                tol,
                row_limit,
                col_start,
                mid,
                col_ptr,
                left_rows,
                left_values,
            );
        },
        || {
            fill_dense_symmetric_upper_columns(
                matrix,
                tol,
                row_limit,
                mid,
                col_end,
                col_ptr,
                right_rows,
                right_values,
            );
        },
    );
}
/// Fill routine for `embed_dense_block_to_sparse_symmetric_upper`: columns are
/// block-local (`0..block_n`, with `col_ptr` being the block's own pointers),
/// while stored row indices are shifted into the global frame by `offset`.
/// Same recursive buffer-splitting parallelism as the other fill routines.
fn fill_embedded_symmetric_upper_columns(
    local: &Array2<f64>,
    offset: usize,
    tol: f64,
    col_start: usize,
    col_end: usize,
    col_ptr: &[usize],
    row_idx: &mut [usize],
    values: &mut [f64],
) {
    if col_end - col_start <= PARALLEL_SPARSE_FILL_COLUMN_THRESHOLD {
        // Sequential base case over the block's upper triangle.
        let base = col_ptr[col_start];
        for col in col_start..col_end {
            let mut write = col_ptr[col] - base;
            for row in 0..=col {
                let value = local[[row, col]];
                if value.abs() > tol {
                    // Shift the local row index into the global matrix frame.
                    row_idx[write] = offset + row;
                    values[write] = value;
                    write += 1;
                }
            }
        }
        return;
    }
    // Split the output buffers between the two column halves and recurse in
    // parallel.
    let mid = col_start + (col_end - col_start) / 2;
    let split = col_ptr[mid] - col_ptr[col_start];
    let (left_rows, right_rows) = row_idx.split_at_mut(split);
    let (left_values, right_values) = values.split_at_mut(split);
    rayon::join(
        || {
            fill_embedded_symmetric_upper_columns(
                local,
                offset,
                tol,
                col_start,
                mid,
                col_ptr,
                left_rows,
                left_values,
            );
        },
        || {
            fill_embedded_symmetric_upper_columns(
                local,
                offset,
                tol,
                mid,
                col_end,
                col_ptr,
                right_rows,
                right_values,
            );
        },
    );
}
/// Expand a symmetric matrix stored as upper-triangular CSC into a full dense
/// matrix, mirroring every strictly-upper entry across the diagonal.
pub fn sparse_to_dense_symmetric_upper_public(matrix: &SparseColMat<usize, f64>) -> Array2<f64> {
    let (symbolic, values) = matrix.parts();
    let col_ptr = symbolic.col_ptr();
    let row_idx = symbolic.row_idx();
    let mut dense = Array2::<f64>::zeros((matrix.nrows(), matrix.ncols()));
    for col in 0..matrix.ncols() {
        for idx in col_ptr[col]..col_ptr[col + 1] {
            let (row, value) = (row_idx[idx], values[idx]);
            dense[[row, col]] += value;
            if row != col {
                // Mirror the off-diagonal entry into the lower triangle.
                dense[[col, row]] += value;
            }
        }
    }
    dense
}
/// Matrix-vector product `A * x` where `A` is symmetric and stored as
/// upper-triangular CSC: each stored entry contributes once as `A[row, col]`
/// and, off the diagonal, once more as the mirrored `A[col, row]`.
pub fn sparse_symmetric_upper_matvec_public<S: Data<Elem = f64>>(
    matrix: &SparseColMat<usize, f64>,
    vector: &ArrayBase<S, Ix1>,
) -> Array1<f64> {
    let (symbolic, values) = matrix.parts();
    let col_ptr = symbolic.col_ptr();
    let row_idx = symbolic.row_idx();
    let mut product = Array1::<f64>::zeros(matrix.nrows());
    for col in 0..matrix.ncols() {
        let x_at_col = vector[col];
        for idx in col_ptr[col]..col_ptr[col + 1] {
            let row = row_idx[idx];
            let a_entry = values[idx];
            // Contribution of the stored (upper-triangle) entry.
            product[row] += a_entry * x_at_col;
            // Contribution of its implicit mirrored counterpart.
            if row != col {
                product[col] += a_entry * vector[row];
            }
        }
    }
    product
}
/// Canonicalize `h` to symmetric-upper CSC, factor it with faer's sparse
/// Cholesky, and compute log|H| via a separate simplicial factorization.
///
/// # Errors
/// `ModelIsIllConditioned` if the Cholesky factorization fails;
/// `HessianNotPositiveDefinite` can surface from the log-determinant path.
pub fn factorize_sparse_spd(
    h: &SparseColMat<usize, f64>,
) -> Result<SparseExactFactor, EstimationError> {
    let t_start = std::time::Instant::now();
    let n_input = h.ncols();
    let h_upper = canonicalize_sparse_symmetric_upper(h, ZERO_TOL)?;
    let factor = h_upper.as_ref().sp_cholesky(Side::Upper).map_err(|_| {
        EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        }
    })?;
    let logdet = sparse_spd_logdet_via_simplicial(&h_upper)?;
    // Only log slow factorizations, keeping the log quiet in the common case.
    let elapsed_ms = t_start.elapsed().as_secs_f64() * 1000.0;
    if elapsed_ms > 100.0 {
        log::info!(
            "[sparse-chol] factorize_sparse_spd | n={} | {:.1}ms",
            n_input,
            elapsed_ms
        );
    }
    Ok(SparseExactFactor {
        factor,
        n: h_upper.ncols(),
        logdet,
    })
}
/// Rebuild an arbitrary square sparse matrix as clean symmetric-upper CSC.
///
/// Duplicate entries for the same coordinate are averaged; mirrored upper and
/// lower entries are each averaged within their own triangle first and then
/// symmetrized; magnitudes <= `tol` are dropped.
fn canonicalize_sparse_symmetric_upper(
    matrix: &SparseColMat<usize, f64>,
    tol: f64,
) -> Result<SparseColMat<usize, f64>, EstimationError> {
    if matrix.nrows() != matrix.ncols() {
        return Err(EstimationError::InvalidInput(format!(
            "sparse SPD factorization requires square matrix, got {}x{}",
            matrix.nrows(),
            matrix.ncols()
        )));
    }
    // Accumulator per unordered (row <= col) pair. Upper and lower sums are
    // kept separate so each triangle is averaged over its own duplicate count
    // before the two sides are combined.
    #[derive(Default, Clone, Copy)]
    struct PairAccum {
        upper_sum: f64,
        upper_count: usize,
        lower_sum: f64,
        lower_count: usize,
    }
    let mut accum: BTreeMap<(usize, usize), PairAccum> = BTreeMap::new();
    let (symbolic, values) = matrix.parts();
    let col_ptr = symbolic.col_ptr();
    let row_idx = symbolic.row_idx();
    for col in 0..matrix.ncols() {
        let start = col_ptr[col];
        let end = col_ptr[col + 1];
        for idx in start..end {
            let row = row_idx[idx];
            let value = values[idx];
            // Canonicalize coordinates into the upper triangle, remembering
            // which triangle the entry came from.
            let (r, c, is_upper) = if row <= col {
                (row, col, true)
            } else {
                (col, row, false)
            };
            let slot = accum.entry((r, c)).or_default();
            if is_upper {
                slot.upper_sum += value;
                slot.upper_count += 1;
            } else {
                slot.lower_sum += value;
                slot.lower_count += 1;
            }
        }
    }
    let mut triplets = Vec::<Triplet<usize, usize, f64>>::new();
    for ((row, col), slot) in accum {
        let value = if row == col {
            // Diagonal: plain average over all duplicates.
            let count = slot.upper_count + slot.lower_count;
            if count == 0 {
                0.0
            } else {
                (slot.upper_sum + slot.lower_sum) / (count as f64)
            }
        } else {
            // Off-diagonal: average within each triangle, then take the mean
            // of whichever side(s) are present.
            let upper_avg = if slot.upper_count > 0 {
                Some(slot.upper_sum / (slot.upper_count as f64))
            } else {
                None
            };
            let lower_avg = if slot.lower_count > 0 {
                Some(slot.lower_sum / (slot.lower_count as f64))
            } else {
                None
            };
            match (upper_avg, lower_avg) {
                (Some(u), Some(l)) => 0.5 * (u + l),
                (Some(u), None) => u,
                (None, Some(l)) => l,
                (None, None) => 0.0,
            }
        };
        if value.abs() > tol {
            triplets.push(Triplet::new(row, col, value));
        }
    }
    SparseColMat::try_new_from_triplets(matrix.nrows(), matrix.ncols(), &triplets).map_err(|_| {
        EstimationError::InvalidInput(
            "failed to canonicalize sparse matrix to symmetric-upper CSC".to_string(),
        )
    })
}
/// Solve H x = rhs using a precomputed sparse Cholesky factor.
///
/// # Errors
/// `InvalidInput` when the RHS length does not match the factor dimension or
/// the solution contains non-finite values.
pub fn solve_sparse_spd<S>(
    factor: &SparseExactFactor,
    rhs: &ArrayBase<S, Ix1>,
) -> Result<Array1<f64>, EstimationError>
where
    S: Data<Elem = f64>,
{
    let n = rhs.len();
    if n != factor.n {
        return Err(EstimationError::InvalidInput(format!(
            "sparse SPD solve dimension mismatch: rhs has {}, factor has {}",
            n, factor.n
        )));
    }
    let view = FaerColView::new(rhs);
    let solved = factor.factor.solve(view.as_ref());
    // Copy the solution out, rejecting NaN/inf produced by the solve.
    let mut result = Array1::<f64>::zeros(n);
    for (i, slot) in result.iter_mut().enumerate() {
        let value = solved[(i, 0)];
        if !value.is_finite() {
            return Err(EstimationError::InvalidInput(
                "sparse SPD solve produced non-finite values".to_string(),
            ));
        }
        *slot = value;
    }
    Ok(result)
}
/// Solve H X = RHS for a matrix of right-hand sides using a precomputed sparse
/// Cholesky factor.
///
/// # Errors
/// `InvalidInput` when the RHS row count does not match the factor dimension
/// or the solution contains non-finite values.
pub fn solve_sparse_spdmulti<S>(
    factor: &SparseExactFactor,
    rhs: &ArrayBase<S, Ix2>,
) -> Result<Array2<f64>, EstimationError>
where
    S: Data<Elem = f64>,
{
    if rhs.nrows() != factor.n {
        return Err(EstimationError::InvalidInput(format!(
            "sparse SPD multi-solve row mismatch: rhs has {}, factor has {}",
            rhs.nrows(),
            factor.n
        )));
    }
    let view = FaerArrayView::new(rhs);
    let solved = factor.factor.solve(view.as_ref());
    // Copy the solution out, rejecting NaN/inf produced by the solve.
    let mut result = Array2::<f64>::zeros(rhs.raw_dim());
    for ((i, j), slot) in result.indexed_iter_mut() {
        let value = solved[(i, j)];
        if !value.is_finite() {
            return Err(EstimationError::InvalidInput(
                "sparse SPD multi-solve produced non-finite values".to_string(),
            ));
        }
        *slot = value;
    }
    Ok(result)
}
/// Solve H X = RHS but return only the row slice `[row_start, row_end)` of the
/// solution.
///
/// # Errors
/// `InvalidInput` on dimension mismatch, an out-of-bounds row range, or a
/// non-finite solution value within the selected rows.
pub fn solve_sparse_spdmulti_rows<S>(
    factor: &SparseExactFactor,
    rhs: &ArrayBase<S, Ix2>,
    row_start: usize,
    row_end: usize,
) -> Result<Array2<f64>, EstimationError>
where
    S: Data<Elem = f64>,
{
    if rhs.nrows() != factor.n {
        return Err(EstimationError::InvalidInput(format!(
            "sparse SPD multi-solve row mismatch: rhs has {}, factor has {}",
            rhs.nrows(),
            factor.n
        )));
    }
    if row_start > row_end || row_end > factor.n {
        return Err(EstimationError::InvalidInput(format!(
            "sparse SPD selected rows out of bounds: row_start={}, row_end={}, factor={}",
            row_start, row_end, factor.n
        )));
    }
    let view = FaerArrayView::new(rhs);
    let solved = factor.factor.solve(view.as_ref());
    // Copy only the requested row window, rejecting NaN/inf on the way.
    let mut result = Array2::<f64>::zeros((row_end - row_start, rhs.ncols()));
    for ((local_i, j), slot) in result.indexed_iter_mut() {
        let value = solved[(row_start + local_i, j)];
        if !value.is_finite() {
            return Err(EstimationError::InvalidInput(
                "sparse SPD selected-row solve produced non-finite values".to_string(),
            ));
        }
        *slot = value;
    }
    Ok(result)
}
/// Solve H X = RHS and return the sum of the diagonal band starting at
/// `row_start`, i.e. Σ_c X[row_start + c, c] for c in 0..rhs.ncols().
///
/// # Errors
/// `InvalidInput` when the RHS row count does not match the factor dimension,
/// the selected diagonal band runs past the RHS rows, or the solve produces
/// non-finite values.
pub fn solve_sparse_spdmulti_diagonal_sum<S>(
    factor: &SparseExactFactor,
    rhs: &ArrayBase<S, Ix2>,
    row_start: usize,
) -> Result<f64, EstimationError>
where
    S: Data<Elem = f64>,
{
    // Fix: validate the factor dimension like the sibling solvers do;
    // previously a mismatched RHS reached the faer solve unchecked.
    if rhs.nrows() != factor.n {
        return Err(EstimationError::InvalidInput(format!(
            "sparse SPD selected diagonal row mismatch: rhs has {}, factor has {}",
            rhs.nrows(),
            factor.n
        )));
    }
    // saturating_add guards against overflow for pathological `row_start`.
    if row_start.saturating_add(rhs.ncols()) > rhs.nrows() {
        return Err(EstimationError::InvalidInput(format!(
            "sparse SPD selected diagonal out of bounds: row_start={}, rows={}, cols={}",
            row_start,
            rhs.nrows(),
            rhs.ncols()
        )));
    }
    let rhsview = FaerArrayView::new(rhs);
    let out = factor.factor.solve(rhsview.as_ref());
    let mut sum = 0.0;
    for col in 0..rhs.ncols() {
        let value = out[(row_start + col, col)];
        if !value.is_finite() {
            return Err(EstimationError::InvalidInput(
                "sparse SPD selected diagonal solve produced non-finite values".to_string(),
            ));
        }
        sum += value;
    }
    Ok(sum)
}
/// Return the log-determinant cached on the factor at factorization time.
/// The `Result` wrapper keeps the signature uniform with fallible callers.
pub fn logdet_from_factor(factor: &SparseExactFactor) -> Result<f64, EstimationError> {
    Ok(factor.logdet)
}
/// Assemble the penalized REML Hessian from the PIRLS workspace, factor it,
/// and bundle the matrix, factor, and log-determinant into one system.
///
/// # Errors
/// Propagates failures from Hessian assembly and the sparse factorization.
pub fn assemble_and_factor_sparse_penalized_system(
    workspace: &mut PirlsWorkspace,
    x: &SparseColMat<usize, f64>,
    weights: &Array1<f64>,
    s_lambda: &Array2<f64>,
    ridge: f64,
) -> Result<SparsePenalizedSystem, EstimationError> {
    let logdet_h_start = std::time::Instant::now();
    let h_sparse = sparse_reml_penalized_hessian(workspace, x, weights, s_lambda, ridge)?;
    let factor = factorize_sparse_spd(&h_sparse)?;
    let logdet_h = logdet_from_factor(&factor)?;
    log::info!(
        "[STAGE] logdet H (sparse Cholesky) p={} elapsed={:.3}s",
        h_sparse.nrows(),
        logdet_h_start.elapsed().as_secs_f64(),
    );
    Ok(SparsePenalizedSystem {
        h_sparse,
        factor,
        logdet_h,
    })
}
/// Test-only helper: derive `SparsePenaltyBlock`s from full-dimension penalty
/// matrices by detecting each penalty's nonzero coefficient range.
///
/// Returns `Ok(None)` when the detected ranges overlap (the block layout is
/// then not applicable). Fix applied: the block sub-matrix was previously
/// sliced and materialized twice (`s_k_block_dense` and an identical
/// `block_dense`); the second slice is removed and the first copy is reused
/// for the eigendecomposition.
#[cfg(test)]
fn build_sparse_penalty_blocks(
    s_list: &[Array2<f64>],
) -> Result<Option<Vec<SparsePenaltyBlock>>, EstimationError> {
    use crate::faer_ndarray::FaerEigh;
    // Detect the [start, end) coefficient range touched by each penalty.
    let mut ranges = Vec::with_capacity(s_list.len());
    for (term_index, s_k) in s_list.iter().enumerate() {
        let mut min_idx = usize::MAX;
        let mut max_idx = 0usize;
        for row in 0..s_k.nrows() {
            for col in 0..s_k.ncols() {
                if s_k[[row, col]].abs() > ZERO_TOL {
                    min_idx = min_idx.min(row.min(col));
                    max_idx = max_idx.max(row.max(col));
                }
            }
        }
        if min_idx == usize::MAX {
            // All-zero penalty: empty range.
            ranges.push((term_index, 0usize, 0usize));
        } else {
            ranges.push((term_index, min_idx, max_idx + 1));
        }
    }
    // Overlapping ranges make the block decomposition invalid.
    let mut sorted = ranges.clone();
    sorted.sort_by_key(|(_, start, _)| *start);
    for pair in sorted.windows(2) {
        let (_, _, end_left) = pair[0];
        let (_, start_right, _) = pair[1];
        if end_left > start_right {
            return Ok(None);
        }
    }
    let mut blocks = Vec::with_capacity(s_list.len());
    for (term_index, p_start, p_end) in ranges {
        let s_k = &s_list[term_index];
        // Strict support: every nonzero must sit inside [p_start, p_end).
        let block_support_strict = if p_end > p_start {
            let mut ok = true;
            for row in 0..s_k.nrows() {
                for col in 0..s_k.ncols() {
                    if s_k[[row, col]].abs() > ZERO_TOL
                        && (row < p_start || row >= p_end || col < p_start || col >= p_end)
                    {
                        ok = false;
                        break;
                    }
                }
                if !ok {
                    break;
                }
            }
            ok
        } else {
            true
        };
        // Materialize the block once; it is reused below for both the
        // upper-triangle entry list and the eigendecomposition.
        let s_k_block_dense = if p_end > p_start {
            s_k.slice(ndarray::s![p_start..p_end, p_start..p_end])
                .to_owned()
        } else {
            Array2::<f64>::zeros((0, 0))
        };
        let mut s_k_block_upper_entries = Vec::<(usize, usize, f64)>::new();
        for col in 0..s_k_block_dense.ncols() {
            for row in 0..=col {
                let value = s_k_block_dense[[row, col]];
                if value.abs() > ZERO_TOL {
                    s_k_block_upper_entries.push((row, col, value));
                }
            }
        }
        let s_k_sparse = dense_to_sparse(s_k, ZERO_TOL)?;
        // Eigendecompose the extracted block (clone keeps `s_k_block_dense`
        // available for the Arc below regardless of `eigh`'s receiver).
        let positive_eigenvalues = if s_k_block_dense.nrows() == 0 {
            Vec::new()
        } else {
            let (evals, _) = s_k_block_dense
                .clone()
                .eigh(Side::Lower)
                .map_err(EstimationError::EigendecompositionFailed)?;
            evals.iter().copied().filter(|v| *v > ZERO_TOL).collect()
        };
        blocks.push(SparsePenaltyBlock {
            term_index,
            p_start,
            p_end,
            positive_eigenvalues: Arc::new(positive_eigenvalues),
            block_support_strict,
            s_k_sparse,
            s_k_block_dense: Arc::new(s_k_block_dense),
            s_k_block_upper_entries: Arc::new(s_k_block_upper_entries),
        });
    }
    Ok(Some(blocks))
}
/// Build `SparsePenaltyBlock`s from canonical penalties whose column ranges
/// are already known, preserving the input order of `penalties`.
///
/// Returns `Ok(None)` when any two column ranges overlap (the block layout is
/// then not applicable). Fix applied: removed a function-local
/// `use rayon::prelude::*;` that duplicated the file-level import.
pub fn build_sparse_penalty_blocks_from_canonical(
    penalties: &[crate::construction::CanonicalPenalty],
    p: usize,
) -> Result<Option<Vec<SparsePenaltyBlock>>, EstimationError> {
    if penalties.is_empty() {
        return Ok(Some(Vec::new()));
    }
    // Check for overlap on a sorted copy; `penalties` order is kept for output.
    let mut sorted_ranges: Vec<(usize, usize, usize)> = penalties
        .iter()
        .enumerate()
        .map(|(i, cp)| (i, cp.col_range.start, cp.col_range.end))
        .collect();
    sorted_ranges.sort_by_key(|&(_, start, _)| start);
    for pair in sorted_ranges.windows(2) {
        let (_, _, end_left) = pair[0];
        let (_, start_right, _) = pair[1];
        if end_left > start_right {
            return Ok(None);
        }
    }
    // Build blocks in parallel; errors are surfaced deterministically by the
    // sequential collect below (first error in input order wins).
    let block_results: Vec<Result<SparsePenaltyBlock, EstimationError>> = penalties
        .par_iter()
        .enumerate()
        .map(|(term_index, cp)| {
            let p_start = cp.col_range.start;
            let p_end = cp.col_range.end;
            let s_k_block_dense = cp.local_penalty();
            let s_k_sparse = embed_dense_block_to_sparse_symmetric_upper(
                &s_k_block_dense,
                p_start,
                p,
                ZERO_TOL,
            )?;
            // Collect the block's upper-triangle nonzeros (block-local indices).
            let mut s_k_block_upper_entries = Vec::<(usize, usize, f64)>::new();
            for col in 0..s_k_block_dense.ncols() {
                for row in 0..=col {
                    let value = s_k_block_dense[[row, col]];
                    if value.abs() > ZERO_TOL {
                        s_k_block_upper_entries.push((row, col, value));
                    }
                }
            }
            Ok(SparsePenaltyBlock {
                term_index,
                p_start,
                p_end,
                positive_eigenvalues: Arc::new(cp.positive_eigenvalues.clone()),
                // Canonical penalties are block-local by construction.
                block_support_strict: true,
                s_k_sparse,
                s_k_block_dense: Arc::new(s_k_block_dense),
                s_k_block_upper_entries: Arc::new(s_k_block_upper_entries),
            })
        })
        .collect();
    let blocks = block_results.into_iter().collect::<Result<Vec<_>, _>>()?;
    Ok(Some(blocks))
}
use faer::dyn_stack::{MemBuffer, MemStack, StackReq};
use faer::linalg::cholesky::llt::factor::LltRegularization;
use faer::sparse::linalg::amd;
use faer::sparse::linalg::cholesky::simplicial;
/// Simplicial sparse Cholesky factor L (lower-triangular CSC) of an
/// AMD-permuted SPD matrix, plus the inverse permutation and log-determinant.
pub struct SimplicialFactor {
    // CSC column pointers of L; the diagonal is the first entry of each column.
    l_col_ptr: Vec<usize>,
    // CSC row indices of L.
    l_row_idx: Vec<usize>,
    // CSC values of L.
    l_values: Vec<f64>,
    // Inverse of the AMD fill-reducing permutation applied before factoring.
    perm_inv: Vec<usize>,
    // Dimension of the factored matrix.
    n: usize,
    // log|H| = 2 * Σ ln L[j,j].
    pub logdet: f64,
}
/// Compute log|H| for a symmetric-upper CSC SPD matrix via an AMD-permuted
/// simplicial Cholesky factorization: log|H| = 2 * Σ ln L[j,j].
fn sparse_spd_logdet_via_simplicial(
    h_upper: &SparseColMat<usize, f64>,
) -> Result<f64, EstimationError> {
    let n = h_upper.ncols();
    if n == 0 {
        // Determinant of the empty matrix is 1.
        return Ok(0.0);
    }
    let a_nnz = h_upper.compute_nnz();
    // Fill-reducing AMD ordering.
    let mut perm_fwd = vec![0usize; n];
    let mut perm_inv = vec![0usize; n];
    {
        let mut mem = MemBuffer::new(amd::order_scratch::<usize>(n, a_nnz));
        amd::order(
            &mut perm_fwd,
            &mut perm_inv,
            h_upper.symbolic(),
            amd::Control::default(),
            MemStack::new(&mut mem),
        )
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?;
    }
    // SAFETY: perm_fwd/perm_inv were just produced by `amd::order` and are
    // valid mutually-inverse length-n permutations.
    let perm = unsafe { faer::perm::PermRef::new_unchecked(&perm_fwd, &perm_inv, n) };
    // Apply the permutation: A_perm = P A Pᵀ, still stored as upper.
    let a_perm_upper = {
        let mut col_ptrs = vec![0usize; n + 1];
        let mut row_indices = vec![0usize; a_nnz];
        let mut values = vec![0.0f64; a_nnz];
        let mut mem = MemBuffer::new(faer::sparse::utils::permute_self_adjoint_scratch::<usize>(
            n,
        ));
        faer::sparse::utils::permute_self_adjoint_to_unsorted(
            &mut values,
            &mut col_ptrs,
            &mut row_indices,
            h_upper.as_ref(),
            perm,
            Side::Upper,
            Side::Upper,
            MemStack::new(&mut mem),
        );
        // SAFETY: the arrays were filled by `permute_self_adjoint_to_unsorted`,
        // which produces a structurally valid (if unsorted) CSC layout.
        SparseColMat::<usize, f64>::new(
            unsafe { SymbolicSparseColMat::new_unchecked(n, n, col_ptrs, None, row_indices) },
            values,
        )
    };
    // Symbolic analysis: elimination tree + column counts, then symbolic L.
    let symbolic = {
        let mut mem = MemBuffer::new(StackReq::any_of(&[
            simplicial::prefactorize_symbolic_cholesky_scratch::<usize>(n, a_nnz),
            simplicial::factorize_simplicial_symbolic_cholesky_scratch::<usize>(n),
        ]));
        let stack = MemStack::new(&mut mem);
        let mut etree = vec![0isize; n];
        let mut col_counts = vec![0usize; n];
        simplicial::prefactorize_symbolic_cholesky(
            &mut etree,
            &mut col_counts,
            a_perm_upper.symbolic(),
            stack,
        );
        simplicial::factorize_simplicial_symbolic_cholesky(
            a_perm_upper.symbolic(),
            // SAFETY: etree was just filled by `prefactorize_symbolic_cholesky`.
            unsafe { simplicial::EliminationTreeRef::from_inner(&etree) },
            &col_counts,
            stack,
        )
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?
    };
    // Numeric LLᵀ factorization.
    let mut l_values = vec![0.0f64; symbolic.len_val()];
    {
        let mut mem = MemBuffer::new(simplicial::factorize_simplicial_numeric_llt_scratch::<
            usize,
            f64,
        >(n));
        simplicial::factorize_simplicial_numeric_llt::<usize, f64>(
            &mut l_values,
            a_perm_upper.as_ref(),
            LltRegularization::default(),
            &symbolic,
            MemStack::new(&mut mem),
        )
        .map_err(|_| EstimationError::HessianNotPositiveDefinite {
            min_eigenvalue: f64::NAN,
        })?;
    }
    // Accumulate Σ ln L[j,j]; the diagonal is the first entry of each column.
    let l_col_ptr = symbolic.col_ptr();
    let mut logdet = 0.0f64;
    for j in 0..n {
        let diag = l_values[l_col_ptr[j]];
        if diag <= 0.0 {
            return Err(EstimationError::HessianNotPositiveDefinite {
                min_eigenvalue: f64::NAN,
            });
        }
        logdet += diag.ln();
    }
    Ok(2.0 * logdet)
}
/// Compute an AMD-permuted simplicial Cholesky factorization of `h` and
/// return it with its CSC pattern, values, inverse permutation, and
/// log-determinant, for use by `TakahashiInverse::compute`.
///
/// NOTE(review): this shares most of its body with
/// `sparse_spd_logdet_via_simplicial`; the two differ only in what they
/// return, and could potentially share a common helper.
pub fn factorize_simplicial(
    h: &SparseColMat<usize, f64>,
) -> Result<SimplicialFactor, EstimationError> {
    let h_upper = canonicalize_sparse_symmetric_upper(h, ZERO_TOL)?;
    let n = h_upper.ncols();
    if n == 0 {
        // Empty system: a trivially valid factor.
        return Ok(SimplicialFactor {
            l_col_ptr: vec![0],
            l_row_idx: Vec::new(),
            l_values: Vec::new(),
            perm_inv: Vec::new(),
            n: 0,
            logdet: 0.0,
        });
    }
    let a_nnz = h_upper.compute_nnz();
    // Fill-reducing AMD ordering.
    let mut perm_fwd = vec![0usize; n];
    let mut perm_inv = vec![0usize; n];
    {
        let mut mem = MemBuffer::new(amd::order_scratch::<usize>(n, a_nnz));
        amd::order(
            &mut perm_fwd,
            &mut perm_inv,
            h_upper.symbolic(),
            amd::Control::default(),
            MemStack::new(&mut mem),
        )
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?;
    }
    // SAFETY: perm_fwd/perm_inv were just produced by `amd::order` and are
    // valid mutually-inverse length-n permutations.
    let perm = unsafe { faer::perm::PermRef::new_unchecked(&perm_fwd, &perm_inv, n) };
    // Apply the permutation: A_perm = P A Pᵀ, still stored as upper.
    let a_perm_upper = {
        let mut col_ptrs = vec![0usize; n + 1];
        let mut row_indices = vec![0usize; a_nnz];
        let mut values = vec![0.0f64; a_nnz];
        let mut mem = MemBuffer::new(faer::sparse::utils::permute_self_adjoint_scratch::<usize>(
            n,
        ));
        faer::sparse::utils::permute_self_adjoint_to_unsorted(
            &mut values,
            &mut col_ptrs,
            &mut row_indices,
            h_upper.as_ref(),
            perm,
            Side::Upper,
            Side::Upper,
            MemStack::new(&mut mem),
        );
        // SAFETY: the arrays were filled by `permute_self_adjoint_to_unsorted`,
        // which produces a structurally valid (if unsorted) CSC layout.
        SparseColMat::<usize, f64>::new(
            unsafe { SymbolicSparseColMat::new_unchecked(n, n, col_ptrs, None, row_indices) },
            values,
        )
    };
    // Symbolic analysis: elimination tree + column counts, then symbolic L.
    let symbolic = {
        let mut mem = MemBuffer::new(StackReq::any_of(&[
            simplicial::prefactorize_symbolic_cholesky_scratch::<usize>(n, a_nnz),
            simplicial::factorize_simplicial_symbolic_cholesky_scratch::<usize>(n),
        ]));
        let stack = MemStack::new(&mut mem);
        let mut etree = vec![0isize; n];
        let mut col_counts = vec![0usize; n];
        simplicial::prefactorize_symbolic_cholesky(
            &mut etree,
            &mut col_counts,
            a_perm_upper.symbolic(),
            stack,
        );
        simplicial::factorize_simplicial_symbolic_cholesky(
            a_perm_upper.symbolic(),
            // SAFETY: etree was just filled by `prefactorize_symbolic_cholesky`.
            unsafe { simplicial::EliminationTreeRef::from_inner(&etree) },
            &col_counts,
            stack,
        )
        .map_err(|_| EstimationError::ModelIsIllConditioned {
            condition_number: f64::INFINITY,
        })?
    };
    // Numeric LLᵀ factorization.
    let mut l_values = vec![0.0f64; symbolic.len_val()];
    {
        let mut mem = MemBuffer::new(simplicial::factorize_simplicial_numeric_llt_scratch::<
            usize,
            f64,
        >(n));
        simplicial::factorize_simplicial_numeric_llt::<usize, f64>(
            &mut l_values,
            a_perm_upper.as_ref(),
            LltRegularization::default(),
            &symbolic,
            MemStack::new(&mut mem),
        )
        .map_err(|_| EstimationError::HessianNotPositiveDefinite {
            min_eigenvalue: f64::NAN,
        })?;
    }
    // Copy out the pattern and compute log|H| = 2 * Σ ln L[j,j]; the diagonal
    // is the first entry of each column.
    let l_col_ptr: Vec<usize> = symbolic.col_ptr().to_vec();
    let l_row_idx: Vec<usize> = symbolic.row_idx().to_vec();
    let mut logdet = 0.0f64;
    for j in 0..n {
        let diag = l_values[l_col_ptr[j]];
        if diag <= 0.0 {
            return Err(EstimationError::HessianNotPositiveDefinite {
                min_eigenvalue: f64::NAN,
            });
        }
        logdet += diag.ln();
    }
    logdet *= 2.0;
    Ok(SimplicialFactor {
        l_col_ptr,
        l_row_idx,
        l_values,
        perm_inv,
        n,
        logdet,
    })
}
/// Takahashi selected inverse: entries of H⁻¹ on the sparsity pattern of the
/// Cholesky factor L, with a lazy exact-column fallback (and cache) for
/// queries outside that pattern.
pub struct TakahashiInverse {
    // Selected-inverse values, stored on L's lower-triangular CSC pattern.
    z_values: Vec<f64>,
    // CSC column pointers of L's pattern (shared by z_values and l_values).
    col_ptr: Vec<usize>,
    // CSC row indices of L's pattern.
    row_idx: Vec<usize>,
    // Cholesky factor values, retained for exact column solves.
    l_values: Vec<f64>,
    // Row-oriented view of L: for each row, its (col, L[row, col]) entries.
    rows_lower: Arc<Vec<Vec<(usize, f64)>>>,
    // Cache of exactly-solved permuted inverse columns for off-pattern lookups.
    exact_columns: Mutex<BTreeMap<usize, Arc<Vec<f64>>>>,
    // Inverse fill-reducing permutation mapping original to permuted indices.
    perm_inv: Vec<usize>,
    // Matrix dimension.
    n: usize,
}
impl TakahashiInverse {
    /// Binary-search for entry (row, col) in the CSC pattern; returns the flat
    /// value index if present. Assumes row indices are sorted within a column.
    fn find_entry(col_ptr: &[usize], row_idx: &[usize], row: usize, col: usize) -> Option<usize> {
        let start = col_ptr[col];
        let end = col_ptr[col + 1];
        let slice = &row_idx[start..end];
        slice.binary_search(&row).ok().map(|pos| start + pos)
    }
    /// Solve L Lᵀ x = e_{rhs_col} in the permuted frame, yielding one exact
    /// column of the permuted inverse. Forward substitution uses the
    /// row-oriented view of L; backward substitution uses the CSC columns.
    fn solve_permuted_column_from_cholesky(
        n: usize,
        col_ptr: &[usize],
        row_idx: &[usize],
        l_values: &[f64],
        rows_lower: &[Vec<(usize, f64)>],
        rhs_col: usize,
    ) -> Vec<f64> {
        let mut rhs = vec![0.0f64; n];
        rhs[rhs_col] = 1.0;
        let mut forward = vec![0.0f64; n];
        let mut solution = vec![0.0f64; n];
        // Forward solve L y = e.
        for row in 0..n {
            let mut sum = rhs[row];
            let mut diag = None;
            for &(col, value) in &rows_lower[row] {
                if col < row {
                    sum -= value * forward[col];
                } else if col == row {
                    diag = Some(value);
                }
            }
            let l_rr = diag.expect("simplicial factor row should contain its diagonal");
            forward[row] = sum / l_rr;
        }
        // Backward solve Lᵀ x = y (the diagonal is the first entry per column).
        for row in (0..n).rev() {
            let col_start = col_ptr[row];
            let col_end = col_ptr[row + 1];
            let mut sum = forward[row];
            let l_rr = l_values[col_start];
            for idx in (col_start + 1)..col_end {
                let lower_row = row_idx[idx];
                sum -= l_values[idx] * solution[lower_row];
            }
            solution[row] = sum / l_rr;
        }
        solution
    }
    /// Return the exact permuted inverse column `col`, computing and caching
    /// it on first use.
    fn exact_permuted_column(&self, col: usize) -> Arc<Vec<f64>> {
        {
            // Fast path: cache hit under a short-lived lock.
            let cache = self
                .exact_columns
                .lock()
                .expect("exact Takahashi column cache mutex poisoned");
            if let Some(solution) = cache.get(&col) {
                return solution.clone();
            }
        }
        // Miss: solve outside the lock (two threads may race and solve the
        // same column; `or_insert_with` below keeps a single canonical copy).
        let solution = Arc::new(Self::solve_permuted_column_from_cholesky(
            self.n,
            &self.col_ptr,
            &self.row_idx,
            &self.l_values,
            self.rows_lower.as_ref(),
            col,
        ));
        let mut cache = self
            .exact_columns
            .lock()
            .expect("exact Takahashi column cache mutex poisoned");
        cache.entry(col).or_insert_with(|| solution.clone()).clone()
    }
    /// Read Z[row, col] from the stored (lower-triangular) pattern, erroring
    /// if the symmetric position is absent from the pattern.
    fn selected_value(
        z_values: &[f64],
        col_ptr: &[usize],
        row_idx: &[usize],
        row: usize,
        col: usize,
    ) -> Result<f64, EstimationError> {
        let (lower_row, lower_col) = if row >= col { (row, col) } else { (col, row) };
        Self::find_entry(col_ptr, row_idx, lower_row, lower_col)
            .map(|idx| z_values[idx])
            .ok_or_else(|| {
                EstimationError::InvalidInput(format!(
                    "simplicial selected-inverse pattern is missing entry ({lower_row},{lower_col})"
                ))
            })
    }
    /// Run the Takahashi recurrence over the factor's pattern, from the last
    /// column backwards, producing the selected inverse Z on L's pattern.
    pub fn compute(factor: &SimplicialFactor) -> Result<Self, EstimationError> {
        let n = factor.n;
        let col_ptr = factor.l_col_ptr.clone();
        let row_idx = factor.l_row_idx.clone();
        let nnz = factor.l_values.len();
        let mut z_values = vec![0.0f64; nnz];
        // Build the row-oriented view of L used by forward solves; columns are
        // visited in order, so each row's entries end up sorted by column.
        let mut rows_lower: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
        for col in 0..n {
            for idx in col_ptr[col]..col_ptr[col + 1] {
                let row = row_idx[idx];
                rows_lower[row].push((col, factor.l_values[idx]));
            }
        }
        // Takahashi recurrence, processed in reverse column order so that the
        // Z entries each step reads have already been computed.
        for j in (0..n).rev() {
            let diag_idx = col_ptr[j];
            let col_end = col_ptr[j + 1];
            let diag = factor.l_values[diag_idx];
            if !(diag.is_finite() && diag > 0.0) {
                return Err(EstimationError::HessianNotPositiveDefinite {
                    min_eigenvalue: f64::NAN,
                });
            }
            // Off-diagonal entries: Z[i,j] = -(Σ_k L[k,j] Z[i,k]) / L[j,j].
            for idx in (diag_idx + 1)..col_end {
                let i = row_idx[idx];
                let mut correction = 0.0;
                for off_idx in (diag_idx + 1)..col_end {
                    let k = row_idx[off_idx];
                    let l_kj = factor.l_values[off_idx];
                    let z_ik = Self::selected_value(&z_values, &col_ptr, &row_idx, i, k)?;
                    correction += l_kj * z_ik;
                }
                let value = -correction / diag;
                if !value.is_finite() {
                    return Err(EstimationError::InvalidInput(format!(
                        "Takahashi selected inverse produced non-finite entry ({i},{j})"
                    )));
                }
                z_values[idx] = value;
            }
            // Diagonal: Z[j,j] = (1/L[j,j] - Σ_k L[k,j] Z[k,j]) / L[j,j].
            let mut correction = 0.0;
            for off_idx in (diag_idx + 1)..col_end {
                correction += factor.l_values[off_idx] * z_values[off_idx];
            }
            let value = (1.0 / diag - correction) / diag;
            if !value.is_finite() {
                return Err(EstimationError::InvalidInput(format!(
                    "Takahashi selected inverse produced non-finite diagonal entry ({j},{j})"
                )));
            }
            z_values[diag_idx] = value;
        }
        Ok(TakahashiInverse {
            z_values,
            col_ptr,
            row_idx,
            l_values: factor.l_values.clone(),
            rows_lower: Arc::new(rows_lower),
            exact_columns: Mutex::new(BTreeMap::new()),
            perm_inv: factor.perm_inv.clone(),
            n,
        })
    }
    /// H⁻¹[i, j] in the original (unpermuted) index frame.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        let pi = self.perm_inv[i];
        let pj = self.perm_inv[j];
        self.get_permuted(pi, pj)
    }
    /// H⁻¹ entry in the permuted frame: read from the selected pattern when
    /// present, otherwise fall back to an exact (cached) column solve.
    fn get_permuted(&self, pi: usize, pj: usize) -> f64 {
        let (row, col) = if pi >= pj { (pi, pj) } else { (pj, pi) };
        if let Some(pos) = Self::find_entry(&self.col_ptr, &self.row_idx, row, col) {
            self.z_values[pos]
        } else {
            self.exact_permuted_column(col)[row]
        }
    }
    /// Diagonal of H⁻¹ in the original index frame.
    pub fn diagonal(&self) -> Array1<f64> {
        let mut diag = Array1::zeros(self.n);
        for i in 0..self.n {
            diag[i] = self.get(i, i);
        }
        diag
    }
    /// Dense copy of the H⁻¹ sub-block `[start, end) x [start, end)`.
    pub fn block(&self, start: usize, end: usize) -> Array2<f64> {
        let dim = end - start;
        let mut out = Array2::zeros((dim, dim));
        for j_local in 0..dim {
            let j = start + j_local;
            for i_local in 0..dim {
                let i = start + i_local;
                out[[i_local, j_local]] = self.get(i, j);
            }
        }
        out
    }
    /// tr(H⁻¹ S) for a symmetric S; only upper-triangle entries of `s` are
    /// used, with off-diagonal contributions doubled for symmetry.
    pub fn trace_product_sparse(&self, s: &SparseColMat<usize, f64>) -> f64 {
        let (symbolic, values) = s.parts();
        let s_col_ptr = symbolic.col_ptr();
        let s_row_idx = symbolic.row_idx();
        let mut trace = 0.0;
        for col in 0..s.ncols() {
            let col_start = s_col_ptr[col];
            let col_end = s_col_ptr[col + 1];
            for idx in col_start..col_end {
                let row = s_row_idx[idx];
                // Skip strictly-lower entries; they are covered by the 2x
                // weighting of their upper mirror.
                if row > col {
                    continue; }
                let val = values[idx];
                let z_ij = self.get(row, col);
                if row == col {
                    trace += z_ij * val;
                } else {
                    trace += 2.0 * z_ij * val;
                }
            }
        }
        trace
    }
}
/// tr(H⁻¹ S_k) via the Takahashi selected inverse. When the penalty's support
/// is strictly inside its block, only the block's upper-triangle entries are
/// visited; otherwise the general sparse trace over the full matrix is used.
pub fn trace_hinv_sk_takahashi(taka: &TakahashiInverse, penalty: &SparsePenaltyBlock) -> f64 {
    if !penalty.block_support_strict {
        // Support leaks outside the declared block: general fallback.
        return taka.trace_product_sparse(&penalty.s_k_sparse);
    }
    penalty
        .s_k_block_upper_entries
        .iter()
        .map(|&(row, col, val)| {
            let z_val = taka.get(penalty.p_start + row, penalty.p_start + col);
            // Off-diagonal entries count twice by symmetry.
            if row == col { z_val * val } else { 2.0 * z_val * val }
        })
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::faer_ndarray::FaerCholesky;
    use ndarray::array;
    // Assert |a - b| <= tol with a diagnostic message on failure.
    fn approx_eq(a: f64, b: f64, tol: f64) {
        assert!(
            (a - b).abs() <= tol,
            "values differ: left={a:.12e}, right={b:.12e}, |diff|={:.12e}, tol={tol:.12e}",
            (a - b).abs()
        );
    }
    #[test]
    fn canonical_sparse_penalty_blocks_preserve_input_order() {
        // Build a minimal CanonicalPenalty for a given column range.
        fn canonical_penalty(
            col_range: std::ops::Range<usize>,
            local: Array2<f64>,
            positive_eigenvalues: Vec<f64>,
            total_dim: usize,
        ) -> crate::construction::CanonicalPenalty {
            crate::construction::CanonicalPenalty {
                root: Array2::<f64>::zeros((0, col_range.len())),
                col_range,
                total_dim,
                nullity: 0,
                local,
                positive_eigenvalues,
                op: None,
            }
        }
        // Ranges are deliberately out of order to verify input order is kept.
        let penalties = vec![
            canonical_penalty(2..4, array![[2.0, 0.5], [0.5, 3.0]], vec![2.0, 3.0], 5),
            canonical_penalty(0..1, array![[7.0]], vec![7.0], 5),
            canonical_penalty(4..5, array![[11.0]], vec![11.0], 5),
        ];
        let blocks = build_sparse_penalty_blocks_from_canonical(&penalties, 5)
            .unwrap()
            .expect("non-overlapping canonical blocks should be sparse-block compatible");
        let observed: Vec<(usize, usize, usize)> = blocks
            .iter()
            .map(|block| (block.term_index, block.p_start, block.p_end))
            .collect();
        assert_eq!(observed, vec![(0, 2, 4), (1, 0, 1), (2, 4, 5)]);
        assert_eq!(&*blocks[0].positive_eigenvalues, &vec![2.0, 3.0]);
        assert_eq!(&*blocks[1].positive_eigenvalues, &vec![7.0]);
        assert_eq!(&*blocks[2].positive_eigenvalues, &vec![11.0]);
    }
    #[test]
    fn takahashi_diagonal_matches_dense_inverse() {
        let h = array![
            [4.0, 0.2, 0.0, 0.0],
            [0.2, 3.0, 0.1, 0.0],
            [0.0, 0.1, 2.5, 0.3],
            [0.0, 0.0, 0.3, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        // Reference inverse via dense Cholesky column solves.
        let chol = h.cholesky(Side::Lower).unwrap();
        let mut h_inv = Array2::<f64>::zeros((4, 4));
        for j in 0..4 {
            let mut rhs = Array1::<f64>::zeros(4);
            rhs[j] = 1.0;
            let col = chol.solvevec(&rhs);
            for i in 0..4 {
                h_inv[[i, j]] = col[i];
            }
        }
        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();
        let diag = taka.diagonal();
        for i in 0..4 {
            approx_eq(diag[i], h_inv[[i, i]], 1e-10);
        }
    }
    #[test]
    fn takahashi_logdet_matches_dense() {
        let h = array![
            [4.0, 0.2, 0.0, 0.0],
            [0.2, 3.0, 0.1, 0.0],
            [0.0, 0.1, 2.5, 0.3],
            [0.0, 0.0, 0.3, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        // The supernodal path's cached logdet is the reference.
        let existing = factorize_sparse_spd(&h_sparse).unwrap();
        let logdet_dense = existing.logdet;
        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        approx_eq(sfactor.logdet, logdet_dense, 1e-10);
    }
    #[test]
    fn takahashi_trace_hinv_sk_matches_column_solve() {
        let h = array![
            [4.0, 0.2, 0.0, 0.0],
            [0.2, 3.0, 0.1, 0.0],
            [0.0, 0.1, 2.5, 0.3],
            [0.0, 0.0, 0.3, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        // Diagonal penalty supported on coefficients 1..3.
        let mut s = Array2::<f64>::zeros((4, 4));
        s[[1, 1]] = 2.0;
        s[[2, 2]] = 3.0;
        let blocks = build_sparse_penalty_blocks(&[s])
            .unwrap()
            .expect("single local block expected");
        // Reference trace via dense inverse.
        let chol = h.cholesky(Side::Lower).unwrap();
        let mut h_inv = Array2::<f64>::zeros((4, 4));
        for j in 0..4 {
            let mut rhs = Array1::<f64>::zeros(4);
            rhs[j] = 1.0;
            let col = chol.solvevec(&rhs);
            for i in 0..4 {
                h_inv[[i, j]] = col[i];
            }
        }
        let reference = h_inv[[1, 1]] * 2.0 + h_inv[[2, 2]] * 3.0;
        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();
        let taka_result = trace_hinv_sk_takahashi(&taka, &blocks[0]);
        approx_eq(taka_result, reference, 1e-10);
    }
    #[test]
    fn takahashi_get_and_block_recover_off_pattern_inverse_entries() {
        // Tridiagonal H whose inverse has fill outside L's pattern, forcing
        // the exact-column fallback in `get`/`block`.
        let h = array![
            [4.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 1.0, 0.0],
            [0.0, 1.0, 2.5, 1.0],
            [0.0, 0.0, 1.0, 2.0]
        ];
        let h_sparse = dense_to_sparse_symmetric_upper(&h, ZERO_TOL).unwrap();
        let chol = h.cholesky(Side::Lower).unwrap();
        let mut h_inv = Array2::<f64>::zeros((4, 4));
        for j in 0..4 {
            let mut rhs = Array1::<f64>::zeros(4);
            rhs[j] = 1.0;
            let col = chol.solvevec(&rhs);
            for i in 0..4 {
                h_inv[[i, j]] = col[i];
            }
        }
        let sfactor = factorize_simplicial(&h_sparse).unwrap();
        let taka = TakahashiInverse::compute(&sfactor).unwrap();
        assert!(
            h_inv[[0, 2]].abs() > 1e-8,
            "reference off-pattern inverse entry should be nonzero"
        );
        approx_eq(taka.get(0, 2), h_inv[[0, 2]], 1e-10);
        let block = taka.block(0, 3);
        approx_eq(block[[0, 2]], h_inv[[0, 2]], 1e-10);
        approx_eq(block[[2, 0]], h_inv[[2, 0]], 1e-10);
    }
}