use faer::Side;
use ndarray::{Array1, Array2, s};
use rayon::prelude::*;
use crate::faer_ndarray::FaerEigh;
/// True when every pair of penalties is "block factored": their column
/// ranges are either completely disjoint or exactly identical. Partial
/// overlap forces the dense (fully assembled) code path.
fn are_penalties_block_factored(penalties: &[crate::construction::CanonicalPenalty]) -> bool {
    penalties.iter().enumerate().all(|(i, a)| {
        penalties[i + 1..].iter().all(|b| {
            // Half-open interval intersection test.
            let overlaps =
                a.col_range.start < b.col_range.end && b.col_range.start < a.col_range.end;
            // Identical spans are allowed (they get merged into one block).
            let identical = a.col_range == b.col_range;
            !overlaps || identical
        })
    })
}
/// Infer the rank of a single penalty's local block.
///
/// Fast path: when the cached spectrum is self-consistent
/// (`positive_eigenvalues.len() + nullity == block_dim`) the rank is read
/// directly from the cache. Otherwise the local matrix is eigendecomposed
/// and eigenvalues above a data-driven threshold are counted.
///
/// # Errors
/// Returns a message string if the eigendecomposition fails.
fn infer_penalty_rank(penalty: &crate::construction::CanonicalPenalty) -> Result<usize, String> {
    let block_dim = penalty.block_dim();
    // Trust the cached metadata only when it accounts for every dimension.
    if penalty.positive_eigenvalues.len() + penalty.nullity == block_dim {
        return Ok(penalty.positive_eigenvalues.len());
    }
    if block_dim == 0 {
        return Ok(0);
    }
    let (evals, _) = penalty
        .local
        .eigh(Side::Lower)
        .map_err(|e| format!("Penalty component eigendecomposition failed: {e}"))?;
    // NOTE(review): `as_slice().unwrap()` assumes the eigenvalue vector is
    // contiguous in standard layout, which holds for a freshly computed Array1.
    let threshold = super::unified::positive_eigenvalue_threshold(evals.as_slice().unwrap());
    Ok(evals.iter().filter(|&&e| e > threshold).count())
}
/// Exact dimension of the intersection of all penalty nullspaces, embedded
/// in the full `p_total`-dimensional coefficient space.
///
/// Returns `Ok(None)` when there are no penalties (the caller then falls
/// back to threshold-based rank detection).
fn structural_nullity_from_penalties(
    penalties: &[crate::construction::CanonicalPenalty],
    p_total: usize,
) -> Result<Option<usize>, String> {
    if penalties.is_empty() {
        return Ok(None);
    }
    let mut component_matrices = Vec::with_capacity(penalties.len());
    let mut component_nullities = Vec::with_capacity(penalties.len());
    for penalty in penalties {
        let rank = infer_penalty_rank(penalty)?;
        // Embed the block-local penalty into the full coefficient space.
        let mut component = Array2::<f64>::zeros((p_total, p_total));
        penalty.accumulate_weighted(&mut component, 1.0);
        component_matrices.push(component);
        // Embedded nullity = coefficients outside the block plus the
        // block-local nullspace, i.e. p_total minus the local rank.
        component_nullities.push(p_total.saturating_sub(rank));
    }
    Ok(Some(super::unified::exact_intersection_nullity(
        &component_matrices,
        &component_nullities,
    )))
}
/// Bookkeeping for one independent penalty block inside the combined factor.
#[derive(Clone, Debug)]
struct PenaltyBlockSpan {
    // Coefficient (row) range this block covers in the full space.
    start: usize,
    end: usize,
    // Column range of `w_factor` holding this block's range-space basis.
    rank_start: usize,
    rank_end: usize,
}
/// Pseudo-log-determinant of the total penalty S(λ) = Σ λ_k·S_k (+ ridge·I),
/// plus the spectral factors needed to form its ρ- and τ-derivatives.
#[derive(Clone)]
pub struct PenaltyPseudologdet {
    // p×rank factor whose columns are eigenvectors scaled by 1/√e, so that
    // W·Wᵀ equals the pseudo-inverse S⁺.
    w_factor: Array2<f64>,
    // Orthonormal basis of the nullspace; `None` when S is full rank.
    u_null: Option<Array2<f64>>,
    // 1/e² per retained eigenvalue (used by the moving-nullspace correction).
    inv_evals_sq: Array1<f64>,
    // Number of strictly positive eigenvalues retained.
    rank: usize,
    // log|S|₊: sum of logs of the positive eigenvalues.
    value: f64,
    // Block layout when built via the block-factored path; empty otherwise.
    block_spans: Vec<PenaltyBlockSpan>,
}
impl PenaltyPseudologdet {
#[inline]
/// tr(A·B) without materializing the product: Σ_i Σ_k A[i,k]·B[k,i].
/// Dimension minima make the helper tolerant of non-square operands.
fn trace_dense_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
    let diag_len = a.nrows().min(b.ncols());
    let inner_len = a.ncols().min(b.nrows());
    // Sequential fold keeps the floating-point accumulation order identical
    // to a plain nested loop.
    (0..diag_len)
        .flat_map(|i| (0..inner_len).map(move |k| (i, k)))
        .fold(0.0, |acc, (i, k)| acc + a[[i, k]] * b[[k, i]])
}
/// Dense pseudo-inverse S⁺ = W·Wᵀ (columns of W are u_i/√e_i).
fn pseudo_inverse_dense(&self) -> Array2<f64> {
    self.w_factor.dot(&self.w_factor.t())
}
/// Build the pseudo-logdet from canonical penalties.
///
/// Dispatches to the block-factored fast path when every pair of penalties
/// is either column-disjoint or shares an identical column range; otherwise
/// the full `p_total`×`p_total` matrix is assembled and eigendecomposed,
/// with the exact structural nullity fixing the rank split.
///
/// Entries of `lambdas` beyond `penalties.len()` are ignored; missing
/// entries are treated as absent (weight 0).
pub fn from_penalties(
    penalties: &[crate::construction::CanonicalPenalty],
    lambdas: &[f64],
    ridge: f64,
    p_total: usize,
) -> Result<Self, String> {
    if penalties.is_empty() {
        // No penalties: empty factor, logdet 0.
        return Ok(Self {
            w_factor: Array2::zeros((0, 0)),
            u_null: None,
            inv_evals_sq: Array1::zeros(0),
            rank: 0,
            value: 0.0,
            block_spans: Vec::new(),
        });
    }
    let disjoint = are_penalties_block_factored(penalties);
    if disjoint {
        Self::from_penalties_block_factored(penalties, lambdas, ridge, p_total)
    } else {
        // Overlapping penalties: assemble S(λ) densely. Compute the exact
        // structural nullity first so a positive ridge does not promote
        // nullspace modes into the logdet.
        let structural_nullity = structural_nullity_from_penalties(penalties, p_total)?;
        let mut s_total = Array2::<f64>::zeros((p_total, p_total));
        for (k, cp) in penalties.iter().enumerate() {
            if k < lambdas.len() {
                cp.accumulate_weighted(&mut s_total, lambdas[k]);
            }
        }
        if ridge > 0.0 {
            for i in 0..p_total {
                s_total[[i, i]] += ridge;
            }
        }
        Self::from_assembled_with_nullity(s_total, structural_nullity)
    }
}
/// Fast path for block-factored penalties: each column-range block is
/// eigendecomposed independently and the results are concatenated, avoiding
/// a full `p_total`×`p_total` decomposition.
///
/// Penalties sharing the identical column range are merged into a single
/// block; coefficients covered by no penalty contribute either a pure
/// nullspace axis (ridge = 0) or a rank-1 ridge block (ridge > 0).
fn from_penalties_block_factored(
    penalties: &[crate::construction::CanonicalPenalty],
    lambdas: &[f64],
    ridge: f64,
    p_total: usize,
) -> Result<Self, String> {
    use ndarray::s;
    // Accumulated per-block input prior to decomposition.
    struct BlockData {
        start: usize,
        end: usize,
        // λ-weighted sum of the local penalties sharing this range.
        local: Array2<f64>,
        // Unweighted components, kept for the exact intersection nullity.
        component_matrices: Vec<Array2<f64>>,
        component_nullities: Vec<usize>,
    }
    let mut blocks: Vec<BlockData> = Vec::new();
    for (k, cp) in penalties.iter().enumerate() {
        // Missing λ entries are treated as zero weight.
        let lambda = if k < lambdas.len() { lambdas[k] } else { 0.0 };
        let r = &cp.col_range;
        let local_rank = infer_penalty_rank(cp)?;
        let local_nullity = cp.block_dim().saturating_sub(local_rank);
        // Merge penalties that occupy the identical column range.
        if let Some(bd) = blocks
            .iter_mut()
            .find(|bd| bd.start == r.start && bd.end == r.end)
        {
            bd.local.scaled_add(lambda, &cp.local);
            bd.component_matrices.push(cp.local.clone());
            bd.component_nullities.push(local_nullity);
        } else {
            let bd = cp.block_dim();
            let mut local = Array2::<f64>::zeros((bd, bd));
            local.scaled_add(lambda, &cp.local);
            blocks.push(BlockData {
                start: r.start,
                end: r.end,
                local,
                component_matrices: vec![cp.local.clone()],
                component_nullities: vec![local_nullity],
            });
        }
    }
    // Apply the ridge inside each block; uncovered coefficients are ridged
    // separately below.
    if ridge > 0.0 {
        for bd in &mut blocks {
            let bs = bd.end - bd.start;
            for i in 0..bs {
                bd.local[[i, i]] += ridge;
            }
        }
    }
    // Track which coefficients belong to some penalty block.
    let mut covered = vec![false; p_total];
    for bd in &blocks {
        for i in bd.start..bd.end {
            covered[i] = true;
        }
    }
    // Per-block decomposition output, stitched into the combined factor below.
    struct BlockResult {
        start: usize,
        end: usize,
        w_local: Array2<f64>,
        u_null_local: Array2<f64>,
        inv_evals_sq: Vec<f64>,
        value: f64,
        rank: usize,
        nullity: usize,
    }
    // If already running on a rayon worker thread, stay sequential to avoid
    // nested-pool stalls; otherwise decompose the blocks in parallel.
    let mut block_results: Vec<BlockResult> = if rayon::current_thread_index().is_some() {
        blocks
            .iter()
            .map(|bd| {
                // Exact nullity of the merged components pins this block's
                // rank split, independent of the ridge shift.
                let structural_nullity = Some(super::unified::exact_intersection_nullity(
                    &bd.component_matrices,
                    &bd.component_nullities,
                ));
                let block_pld =
                    Self::from_assembled_with_nullity(bd.local.clone(), structural_nullity)?;
                let nullity = block_pld.u_null.as_ref().map_or(0, Array2::ncols);
                Ok(BlockResult {
                    start: bd.start,
                    end: bd.end,
                    w_local: block_pld.w_factor,
                    u_null_local: block_pld
                        .u_null
                        .unwrap_or_else(|| Array2::<f64>::zeros((bd.end - bd.start, 0))),
                    inv_evals_sq: block_pld.inv_evals_sq.to_vec(),
                    value: block_pld.value,
                    rank: block_pld.rank,
                    nullity,
                })
            })
            .collect::<Result<Vec<_>, String>>()?
    } else {
        blocks
            .par_iter()
            .map(|bd| {
                let structural_nullity = Some(super::unified::exact_intersection_nullity(
                    &bd.component_matrices,
                    &bd.component_nullities,
                ));
                let block_pld =
                    Self::from_assembled_with_nullity(bd.local.clone(), structural_nullity)?;
                let nullity = block_pld.u_null.as_ref().map_or(0, Array2::ncols);
                Ok(BlockResult {
                    start: bd.start,
                    end: bd.end,
                    w_local: block_pld.w_factor,
                    u_null_local: block_pld
                        .u_null
                        .unwrap_or_else(|| Array2::<f64>::zeros((bd.end - bd.start, 0))),
                    inv_evals_sq: block_pld.inv_evals_sq.to_vec(),
                    value: block_pld.value,
                    rank: block_pld.rank,
                    nullity,
                })
            })
            .collect::<Result<Vec<_>, String>>()?
    };
    // With a positive ridge, each uncovered coefficient becomes a 1×1 block
    // with eigenvalue `ridge` (logdet contribution ln(ridge), W entry 1/√ridge).
    if ridge > 0.0 {
        let inv_ridge_sq = 1.0 / (ridge * ridge);
        let scale = 1.0 / ridge.sqrt();
        for (idx, &c) in covered.iter().enumerate() {
            if !c {
                let mut w_col = Array2::<f64>::zeros((1, 1));
                w_col[[0, 0]] = scale;
                block_results.push(BlockResult {
                    start: idx,
                    end: idx + 1,
                    w_local: w_col,
                    u_null_local: Array2::<f64>::zeros((1, 0)),
                    inv_evals_sq: vec![inv_ridge_sq],
                    value: ridge.ln(),
                    rank: 1,
                    nullity: 0,
                });
            }
        }
    }
    // Block-diagonal structure makes both the rank and log|S|₊ additive.
    let total_rank: usize = block_results.iter().map(|br| br.rank).sum();
    let total_value: f64 = block_results.iter().map(|br| br.value).sum();
    let mut w_factor_combined = Array2::<f64>::zeros((p_total, total_rank));
    let mut inv_evals_sq_combined = Array1::<f64>::zeros(total_rank);
    let mut block_spans = Vec::with_capacity(block_results.len());
    let mut col_offset = 0;
    for br in &block_results {
        if br.rank > 0 {
            // Place this block's factor into its row range and the next
            // `rank` columns of the combined factor.
            w_factor_combined
                .slice_mut(s![br.start..br.end, col_offset..col_offset + br.rank])
                .assign(&br.w_local);
            for (i, &v) in br.inv_evals_sq.iter().enumerate() {
                inv_evals_sq_combined[col_offset + i] = v;
            }
            block_spans.push(PenaltyBlockSpan {
                start: br.start,
                end: br.end,
                rank_start: col_offset,
                rank_end: col_offset + br.rank,
            });
            col_offset += br.rank;
        }
    }
    let block_nullity: usize = block_results.iter().map(|br| br.nullity).sum();
    // Without a ridge, uncovered coefficients are pure nullspace directions.
    let uncovered_nullity = if ridge > 0.0 {
        0
    } else {
        covered.iter().filter(|&&c| !c).count()
    };
    let total_nullity = block_nullity + uncovered_nullity;
    let u_null = if total_nullity > 0 {
        let mut u0 = Array2::<f64>::zeros((p_total, total_nullity));
        let mut null_col = 0;
        // Block-local nullspace bases, embedded into their row ranges.
        for br in &block_results {
            if br.nullity > 0 {
                u0.slice_mut(s![br.start..br.end, null_col..null_col + br.nullity])
                    .assign(&br.u_null_local);
                null_col += br.nullity;
            }
        }
        // One standard-basis column per uncovered, unridged coefficient.
        for (idx, &c) in covered.iter().enumerate() {
            if !c && ridge <= 0.0 {
                u0[[idx, null_col]] = 1.0;
                null_col += 1;
            }
        }
        debug_assert_eq!(
            null_col, total_nullity,
            "block-factored pseudo-logdet nullspace assembly mismatch"
        );
        Some(u0)
    } else {
        None
    };
    Ok(Self {
        w_factor: w_factor_combined,
        u_null,
        inv_evals_sq: inv_evals_sq_combined,
        rank: total_rank,
        value: total_value,
        block_spans,
    })
}
/// Build the pseudo-logdet from dense component matrices, assembling
/// S = Σ λ_k·S_k (+ ridge·I) and letting an eigenvalue threshold decide
/// the rank split.
///
/// # Panics
/// If the component matrices are not all square with identical dimension.
///
/// # Errors
/// Propagates eigendecomposition failures as a message string.
pub fn from_components(
    s_k_matrices: &[Array2<f64>],
    lambdas: &[f64],
    ridge: f64,
) -> Result<Self, String> {
    // Delegate to the nullity-aware constructor: passing `None` makes the
    // eigenvalue threshold determine the rank, which is exactly what the
    // previous hand-rolled assembly did via `from_assembled`. This removes
    // a verbatim duplicate of the assembly loop.
    Self::from_components_with_nullity(s_k_matrices, lambdas, ridge, None)
}
/// Build the pseudo-logdet from dense components with a caller-supplied
/// structural nullity, assembling S = Σ λ_k·S_k (+ ridge·I).
///
/// `lambdas` may be shorter than `s_k_matrices`; unmatched components get
/// weight 0, mirroring the guarded indexing in `from_penalties`. (The
/// previous `lambdas[k]` indexing panicked on a short slice.)
///
/// # Panics
/// If the component matrices are not all square with identical dimension.
///
/// # Errors
/// Propagates eigendecomposition failures as a message string.
pub fn from_components_with_nullity(
    s_k_matrices: &[Array2<f64>],
    lambdas: &[f64],
    ridge: f64,
    structural_nullity: Option<usize>,
) -> Result<Self, String> {
    if s_k_matrices.is_empty() {
        return Ok(Self {
            w_factor: Array2::zeros((0, 0)),
            u_null: None,
            inv_evals_sq: Array1::zeros(0),
            rank: 0,
            value: 0.0,
            block_spans: Vec::new(),
        });
    }
    let p_dim = s_k_matrices[0].nrows();
    assert!(
        s_k_matrices
            .iter()
            .all(|m| m.nrows() == p_dim && m.ncols() == p_dim)
    );
    let mut s_total = Array2::<f64>::zeros((p_dim, p_dim));
    // `zip` truncates to the shorter slice, so extra components are simply
    // left unweighted instead of causing an out-of-bounds panic.
    for (s_k, &lambda) in s_k_matrices.iter().zip(lambdas) {
        s_total.scaled_add(lambda, s_k);
    }
    if ridge > 0.0 {
        for i in 0..p_dim {
            s_total[[i, i]] += ridge;
        }
    }
    Self::from_assembled_with_nullity(s_total, structural_nullity)
}
/// Build from a fully assembled S; the rank split is decided by an
/// eigenvalue threshold (no known structural nullity).
pub fn from_assembled(s_total: Array2<f64>) -> Result<Self, String> {
    Self::from_assembled_inner(s_total, None)
}
/// Build from a fully assembled S with an optional known structural
/// nullity; `Some(m0)` forces exactly `m0` nullspace modes regardless of
/// eigenvalue magnitude (important when S carries a small ridge shift).
pub fn from_assembled_with_nullity(
    s_total: Array2<f64>,
    structural_nullity: Option<usize>,
) -> Result<Self, String> {
    Self::from_assembled_inner(s_total, structural_nullity)
}
/// Shared implementation: eigendecompose S and split the spectrum into a
/// nullspace part and a range-space part.
///
/// When `structural_nullity` is provided it overrides the data-driven
/// threshold, so ridged near-zero modes stay in the nullspace instead of
/// leaking ln(ridge) terms into log|S|₊.
fn from_assembled_inner(
    s_total: Array2<f64>,
    structural_nullity: Option<usize>,
) -> Result<Self, String> {
    let p_dim = s_total.nrows();
    if p_dim == 0 {
        // Degenerate zero-dimensional penalty: empty factor, logdet 0.
        return Ok(Self {
            w_factor: Array2::zeros((0, 0)),
            u_null: None,
            inv_evals_sq: Array1::zeros(0),
            rank: 0,
            value: 0.0,
            block_spans: Vec::new(),
        });
    }
    // NOTE(review): the indexing below assumes `eigh` returns eigenvalues in
    // ascending order (nullspace first) — confirm against the FaerEigh impl.
    let (evals, evecs) = s_total
        .eigh(Side::Lower)
        .map_err(|e| format!("PenaltyPseudologdet eigendecomposition failed: {e}"))?;
    let (rank, nullity) = if let Some(m0) = structural_nullity {
        // Clamp so a mis-specified nullity can never underflow the rank.
        let m0 = m0.min(p_dim);
        (p_dim - m0, m0)
    } else {
        let threshold =
            super::unified::positive_eigenvalue_threshold(evals.as_slice().unwrap());
        let rank = evals.iter().filter(|&&e| e > threshold).count();
        (rank, p_dim - rank)
    };
    // log|S|₊ over the `rank` largest eigenvalues; the 1e-300 floor guards
    // ln(0) if a structural nullity over-counts the rank.
    let value: f64 = evals
        .iter()
        .rev()
        .take(rank)
        .map(|&e| e.max(1e-300).ln())
        .sum();
    // Columns of W are eigenvectors scaled by 1/√e, so W·Wᵀ = S⁺.
    let mut w_factor = Array2::<f64>::zeros((p_dim, rank));
    let mut inv_evals_sq = Array1::<f64>::zeros(rank);
    for col in 0..rank {
        let idx = nullity + col;
        let ev = evals[idx];
        let scale = 1.0 / ev.sqrt();
        inv_evals_sq[col] = 1.0 / (ev * ev);
        for row in 0..p_dim {
            w_factor[[row, col]] = evecs[[row, idx]] * scale;
        }
    }
    // The first `nullity` eigenvectors span the nullspace.
    let u_null = if nullity > 0 {
        let mut u0 = Array2::<f64>::zeros((p_dim, nullity));
        for col in 0..nullity {
            for row in 0..p_dim {
                u0[[row, col]] = evecs[[row, col]];
            }
        }
        Some(u0)
    } else {
        None
    };
    Ok(Self {
        w_factor,
        u_null,
        inv_evals_sq,
        rank,
        value,
        block_spans: Vec::new(),
    })
}
/// log|S|₊ — the pseudo-log-determinant value.
pub fn value(&self) -> f64 {
    self.value
}
/// Number of strictly positive eigenvalues of S.
pub fn rank(&self) -> usize {
    self.rank
}
/// Project `m` into the range-space basis: Wᵀ·M·W (rank×rank).
/// Its trace equals tr(S⁺·M).
fn reduced(&self, m: &Array2<f64>) -> Array2<f64> {
    let wt_m = self.w_factor.t().dot(m);
    wt_m.dot(&self.w_factor)
}
/// Coupling of `m` between range space and nullspace: Wᵀ·M·U₀.
/// Returns `None` when S is full rank (no nullspace).
fn leakage(&self, m: &Array2<f64>) -> Option<Array2<f64>> {
    let u_null = self.u_null.as_ref()?;
    let wt_m = self.w_factor.t().dot(m);
    Some(wt_m.dot(u_null))
}
/// Second-order correction 2·Σ_r (1/e_r)·Σ_m (WᵀS_iU₀)[r,m]·(WᵀS_jU₀)[r,m],
/// accounting for the nullspace rotating as the penalty parameters move.
fn moving_nullspace_correction(&self, wt_si_u0: &Array2<f64>, wt_sj_u0: &Array2<f64>) -> f64 {
    let nullity = wt_si_u0.ncols();
    let mut correction = 0.0_f64;
    for r in 0..self.rank {
        // 1/e_r recovered from the stored 1/e_r² entries.
        let inv_eval = self.inv_evals_sq[r].sqrt();
        let mut row_inner = 0.0_f64;
        for m in 0..nullity {
            row_inner += wt_si_u0[[r, m]] * wt_sj_u0[[r, m]];
        }
        correction += inv_eval * row_inner;
    }
    2.0 * correction
}
/// First and second derivatives of log|S(ρ)|₊ with respect to ρ_k = log λ_k,
/// given dense component matrices.
///
/// det1[k] = λ_k·tr(S⁺S_k);
/// det2[k,l] = δ_{kl}·det1[k] − λ_k·λ_l·tr(S⁺S_k·S⁺S_l).
pub fn rho_derivatives(
    &self,
    s_k_matrices: &[Array2<f64>],
    lambdas: &[f64],
) -> (Array1<f64>, Array2<f64>) {
    let k = s_k_matrices.len();
    if k == 0 || self.rank == 0 {
        return (Array1::zeros(k), Array2::zeros((k, k)));
    }
    // Y_k = Wᵀ·S_k·W, so tr(Y_k) = tr(S⁺S_k). Sequential when already on a
    // rayon worker thread (avoids nested-pool stalls), parallel otherwise.
    let y_k: Vec<Array2<f64>> = if rayon::current_thread_index().is_some() {
        s_k_matrices.iter().map(|s| self.reduced(s)).collect()
    } else {
        s_k_matrices.par_iter().map(|s| self.reduced(s)).collect()
    };
    let first_vals: Vec<f64> = y_k
        .iter()
        .enumerate()
        .map(|(idx, y)| lambdas[idx] * (0..self.rank).map(|i| y[[i, i]]).sum::<f64>())
        .collect();
    let mut det1 = Array1::<f64>::zeros(k);
    for (idx, value) in first_vals.into_iter().enumerate() {
        det1[idx] = value;
    }
    // Lower triangle only; the Hessian is symmetric and mirrored below.
    let pairs = (0..k).flat_map(|ki| (0..=ki).map(move |li| (ki, li)));
    let pair_vals: Vec<(usize, usize, f64)> = if rayon::current_thread_index().is_some() {
        pairs
            .map(|(ki, li)| {
                let tr_ab = Self::trace_dense_product(&y_k[ki], &y_k[li]);
                let mut val = -lambdas[ki] * lambdas[li] * tr_ab;
                // Diagonal picks up the first-derivative term (chain rule
                // through λ = exp(ρ)).
                if ki == li {
                    val += det1[ki];
                }
                (ki, li, val)
            })
            .collect()
    } else {
        pairs
            .par_bridge()
            .map(|(ki, li)| {
                let tr_ab = Self::trace_dense_product(&y_k[ki], &y_k[li]);
                let mut val = -lambdas[ki] * lambdas[li] * tr_ab;
                if ki == li {
                    val += det1[ki];
                }
                (ki, li, val)
            })
            .collect()
    };
    let mut det2 = Array2::<f64>::zeros((k, k));
    for (ki, li, val) in pair_vals {
        det2[[ki, li]] = val;
        det2[[li, ki]] = val;
    }
    (det1, det2)
}
/// Same derivatives as `rho_derivatives`, but computed blockwise from the
/// canonical penalties: each penalty is projected only onto the columns of
/// `w_factor` belonging to its block span, and cross-block Hessian terms
/// are skipped (their traces are exactly zero under the block structure).
pub fn rho_derivatives_from_penalties(
    &self,
    penalties: &[crate::construction::CanonicalPenalty],
    lambdas: &[f64],
) -> (Array1<f64>, Array2<f64>) {
    let k = penalties.len();
    if k == 0 || self.rank == 0 {
        return (Array1::zeros(k), Array2::zeros((k, k)));
    }
    // Reduced penalty plus the index of the block span it lives in
    // (`None` = projected against the full-width factor fallback).
    struct ReducedPenalty {
        span: Option<usize>,
        y: Array2<f64>,
    }
    let project = |penalty: &crate::construction::CanonicalPenalty| {
        let start = penalty.col_range.start;
        let end = penalty.col_range.end;
        if let Some((span_idx, span)) = self
            .block_spans
            .iter()
            .enumerate()
            .find(|(_, span)| span.start <= start && end <= span.end)
        {
            // Restrict to the block's rows and rank columns: the penalty is
            // supported only there, so the trace is unchanged.
            let local_start = start - span.start;
            let local_end = local_start + (end - start);
            let w_block = self
                .w_factor
                .slice(s![start..end, span.rank_start..span.rank_end]);
            let local_w = penalty.local.dot(&w_block);
            let y = self
                .w_factor
                .slice(s![start..end, span.rank_start..span.rank_end])
                .t()
                .dot(&local_w);
            debug_assert_eq!(local_end - local_start, penalty.local.nrows());
            ReducedPenalty {
                span: Some(span_idx),
                y,
            }
        } else {
            // No containing span (e.g. non-block-factored build): project
            // against all rank columns.
            let w_block = self.w_factor.slice(s![start..end, ..]);
            let local_w = penalty.local.dot(&w_block);
            ReducedPenalty {
                span: None,
                y: w_block.t().dot(&local_w),
            }
        }
    };
    // Sequential inside a rayon worker, parallel otherwise.
    let y_k: Vec<ReducedPenalty> = if rayon::current_thread_index().is_some() {
        penalties.iter().map(project).collect()
    } else {
        penalties.par_iter().map(project).collect()
    };
    // det1[k] = λ_k·tr(S⁺S_k) via the reduced trace.
    let mut det1 = Array1::<f64>::zeros(k);
    for (idx, reduced) in y_k.iter().enumerate() {
        let tr: f64 = (0..reduced.y.nrows()).map(|i| reduced.y[[i, i]]).sum();
        det1[idx] = lambdas[idx] * tr;
    }
    // Lower triangle only; symmetric mirror below.
    let pairs = (0..k).flat_map(|ki| (0..=ki).map(move |li| (ki, li)));
    let pair_vals: Vec<(usize, usize, f64)> = if rayon::current_thread_index().is_some() {
        pairs
            .map(|(ki, li)| {
                // Penalties in different spans act on orthogonal column
                // blocks, so their cross trace is identically zero.
                let same_span = match (y_k[ki].span, y_k[li].span) {
                    (Some(a), Some(b)) => a == b,
                    _ => true,
                };
                let tr_ab = if same_span {
                    Self::trace_dense_product(&y_k[ki].y, &y_k[li].y)
                } else {
                    0.0
                };
                let mut val = -lambdas[ki] * lambdas[li] * tr_ab;
                if ki == li {
                    val += det1[ki];
                }
                (ki, li, val)
            })
            .collect()
    } else {
        pairs
            .par_bridge()
            .map(|(ki, li)| {
                let same_span = match (y_k[ki].span, y_k[li].span) {
                    (Some(a), Some(b)) => a == b,
                    _ => true,
                };
                let tr_ab = if same_span {
                    Self::trace_dense_product(&y_k[ki].y, &y_k[li].y)
                } else {
                    0.0
                };
                let mut val = -lambdas[ki] * lambdas[li] * tr_ab;
                if ki == li {
                    val += det1[ki];
                }
                (ki, li, val)
            })
            .collect()
    };
    let mut det2 = Array2::<f64>::zeros((k, k));
    for (ki, li, val) in pair_vals {
        det2[[ki, li]] = val;
        det2[[li, ki]] = val;
    }
    (det1, det2)
}
/// ∂/∂τ_i log|S|₊ = tr(S⁺·∂S/∂τ_i), evaluated as the trace of the reduced
/// matrix Wᵀ·(∂S/∂τ_i)·W.
pub fn tau_gradient_component(&self, s_tau_i: &Array2<f64>) -> f64 {
    if self.rank == 0 {
        return 0.0;
    }
    let reduced = self.reduced(s_tau_i);
    let mut trace = 0.0_f64;
    for i in 0..self.rank {
        trace += reduced[[i, i]];
    }
    trace
}
/// ∂²/∂τ_i∂τ_j log|S|₊:
/// tr(S⁺·∂²S) − tr(S⁺·∂S_i·S⁺·∂S_j) + nullspace-rotation correction.
/// `s_tau_ij` may be `None` when the second-derivative matrix vanishes.
pub fn tau_hessian_component(
    &self,
    s_tau_i: &Array2<f64>,
    s_tau_j: &Array2<f64>,
    s_tau_ij: Option<&Array2<f64>>,
) -> f64 {
    if self.rank == 0 {
        return 0.0;
    }
    let s_pinv = self.pseudo_inverse_dense();
    // Linear term: zero when no second-derivative matrix is supplied.
    let linear = s_tau_ij.map_or(0.0, |s_ij| Self::trace_dense_product(&s_pinv, s_ij));
    let quad = Self::trace_dense_product(&s_pinv.dot(s_tau_i).dot(&s_pinv), s_tau_j);
    // `leakage` yields `None` exactly when there is no nullspace, so the
    // match collapses both "full rank" and "missing leakage" to zero.
    let nullspace_correction = match (self.leakage(s_tau_i), self.leakage(s_tau_j)) {
        (Some(ref wt_i_u0), Some(ref wt_j_u0)) => {
            self.moving_nullspace_correction(wt_i_u0, wt_j_u0)
        }
        _ => 0.0,
    };
    linear - quad + nullspace_correction
}
/// Mixed ∂²/∂ρ_k∂τ_i log|S|₊:
/// λ_k·(tr(S⁺·∂S_k/∂τ_i) − tr(S⁺·S_k·S⁺·∂S/∂τ_i)).
/// `ds_k_dtau_i` may be `None` when S_k does not depend on τ_i.
pub fn rho_tau_hessian_component(
    &self,
    s_k: &Array2<f64>,
    lambda_k: f64,
    s_tau_i: &Array2<f64>,
    ds_k_dtau_i: Option<&Array2<f64>>,
) -> f64 {
    if self.rank == 0 {
        return 0.0;
    }
    let s_pinv = self.pseudo_inverse_dense();
    let linear = ds_k_dtau_i.map_or(0.0, |dsk| Self::trace_dense_product(&s_pinv, dsk));
    let quad = Self::trace_dense_product(&s_pinv.dot(s_k).dot(&s_pinv), s_tau_i);
    lambda_k * (linear - quad)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // 1×1 penalty: value is exactly ρ, det1 = rank = 1, det2 = 0.
    #[test]
    fn test_scalar_penalty_logdet() {
        let rho = 1.5_f64;
        let lambda = rho.exp();
        let s_k = array![[1.0]];
        let pld = PenaltyPseudologdet::from_components(&[s_k.clone()], &[lambda], 0.0).unwrap();
        assert!((pld.value() - rho).abs() < 1e-12, "value should be ρ");
        let (det1, det2) = pld.rho_derivatives(&[s_k], &[lambda]);
        assert!(
            (det1[0] - 1.0).abs() < 1e-12,
            "det1 = {}, expected 1.0",
            det1[0]
        );
        assert!(
            det2[[0, 0]].abs() < 1e-12,
            "det2 = {}, expected 0.0",
            det2[[0, 0]]
        );
    }

    // Two disjoint rank-1 penalties: value is ρ₁+ρ₂ and det2 vanishes.
    #[test]
    fn test_two_penalty_logdet() {
        let rho = [1.0_f64, -0.5];
        let lambdas: Vec<f64> = rho.iter().map(|&r| r.exp()).collect();
        let s1 = array![[1.0, 0.0], [0.0, 0.0]];
        let s2 = array![[0.0, 0.0], [0.0, 1.0]];
        let pld =
            PenaltyPseudologdet::from_components(&[s1.clone(), s2.clone()], &lambdas, 0.0).unwrap();
        assert!(
            (pld.value() - 0.5).abs() < 1e-12,
            "value = {}, expected 0.5",
            pld.value()
        );
        let (det1, det2) = pld.rho_derivatives(&[s1, s2], &lambdas);
        assert!((det1[0] - 1.0).abs() < 1e-12);
        assert!((det1[1] - 1.0).abs() < 1e-12);
        assert!(det2[[0, 0]].abs() < 1e-12);
        assert!(det2[[1, 1]].abs() < 1e-12);
        assert!(det2[[0, 1]].abs() < 1e-12);
    }

    // τ derivatives of log det for a full-rank 2×2 S checked against the
    // closed-form determinant det(S) = 2τ + 1.75.
    #[test]
    fn test_tau_derivative_fd() {
        let tau0 = 0.3_f64;
        let det = 2.0 * tau0 + 1.75;
        let s0 = array![[1.0 + tau0, 0.5], [0.5, 2.0]];
        let s_tau = array![[1.0, 0.0], [0.0, 0.0]];
        let s_tau_tau = Array2::<f64>::zeros((2, 2));
        let pld = PenaltyPseudologdet::from_assembled(s0).unwrap();
        let exact_grad = 2.0 / det;
        let grad = pld.tau_gradient_component(&s_tau);
        assert!(
            (grad - exact_grad).abs() < 1e-12,
            "τ gradient: analytic={grad}, exact={exact_grad}"
        );
        let exact_hess = -4.0 / (det * det);
        let hess = pld.tau_hessian_component(&s_tau, &s_tau, Some(&s_tau_tau));
        assert!(
            (hess - exact_hess).abs() < 1e-12,
            "τ hessian: analytic={hess}, exact={exact_hess}"
        );
    }

    // Full-rank S must carry no nullspace basis.
    #[test]
    fn test_no_nullspace_correction_full_rank() {
        let s = array![[3.0, 1.0], [1.0, 2.0]];
        let pld = PenaltyPseudologdet::from_assembled(s).unwrap();
        assert_eq!(pld.rank(), 2);
        assert!(pld.u_null.is_none());
    }

    // Rank-1 S = vvᵀ with v = (2,1): single eigenvalue ‖v‖² = 5.
    #[test]
    fn test_rank_deficient_value() {
        let s = array![[4.0, 2.0], [2.0, 1.0]];
        let pld = PenaltyPseudologdet::from_assembled(s).unwrap();
        assert_eq!(pld.rank(), 1);
        assert!((pld.value() - 5.0_f64.ln()).abs() < 1e-12);
    }

    // Rotating S within its own nullspace leaves log|S|₊ unchanged, so the
    // directional gradient must vanish.
    #[test]
    fn test_nullspace_rotation_gradient_zero() {
        let s1 = 3.0_f64;
        let s2 = 1.0_f64;
        let psi = 0.5_f64;
        let c = psi.cos();
        let s = psi.sin();
        let r = array![[c, 0.0, -s], [0.0, 1.0, 0.0], [s, 0.0, c]];
        let d = array![[s1, 0.0, 0.0], [0.0, s2, 0.0], [0.0, 0.0, 0.0]];
        let s_mat = r.dot(&d).dot(&r.t());
        // dR/dψ and the induced dS/dψ for the rotation parameter.
        let r_psi = array![[-s, 0.0, -c], [0.0, 0.0, 0.0], [c, 0.0, -s]];
        let s_psi = r_psi.dot(&d).dot(&r.t()) + r.dot(&d).dot(&r_psi.t());
        let pld = PenaltyPseudologdet::from_assembled(s_mat).unwrap();
        assert_eq!(pld.rank(), 2);
        let grad = pld.tau_gradient_component(&s_psi);
        assert!(
            grad.abs() < 1e-10,
            "nullspace-rotation gradient should be zero, got {grad}"
        );
    }

    // The block-factored path must retain block-internal nullspace columns:
    // the τ hessian of a pure nullspace rotation must match the assembled
    // reference (both zero).
    #[test]
    fn test_block_factored_tau_hessian_preserves_internal_nullspace() {
        let s1 = 3.0_f64;
        let s2 = 1.0_f64;
        let psi = 0.5_f64;
        let c = psi.cos();
        let s = psi.sin();
        let r = array![[c, 0.0, -s], [0.0, 1.0, 0.0], [s, 0.0, c]];
        let d = array![[s1, 0.0, 0.0], [0.0, s2, 0.0], [0.0, 0.0, 0.0]];
        let s_mat = r.dot(&d).dot(&r.t());
        let r_psi = array![[-s, 0.0, -c], [0.0, 0.0, 0.0], [c, 0.0, -s]];
        let s_psi = r_psi.dot(&d).dot(&r.t()) + r.dot(&d).dot(&r_psi.t());
        // Second derivative of S(ψ) via the product rule.
        let r_psi_psi = array![[-c, 0.0, s], [0.0, 0.0, 0.0], [-s, 0.0, -c]];
        let s_psi_psi = r_psi_psi.dot(&d).dot(&r.t())
            + 2.0 * r_psi.dot(&d).dot(&r_psi.t())
            + r.dot(&d).dot(&r_psi_psi.t());
        let root = crate::estimate::reml::unified::penalty_matrix_root(&s_mat).unwrap();
        let penalty = crate::construction::CanonicalPenalty::from_dense_root(root, 3);
        let block_factored = PenaltyPseudologdet::from_penalties(&[penalty], &[1.0], 0.0, 3)
            .expect("block-factored pseudo-logdet");
        let assembled =
            PenaltyPseudologdet::from_assembled(s_mat).expect("assembled pseudo-logdet");
        let block_hess = block_factored.tau_hessian_component(&s_psi, &s_psi, Some(&s_psi_psi));
        let assembled_hess = assembled.tau_hessian_component(&s_psi, &s_psi, Some(&s_psi_psi));
        assert!(
            assembled_hess.abs() < 1e-10,
            "assembled reference should see zero curvature for a pure nullspace rotation, got {assembled_hess}"
        );
        assert!(
            (block_hess - assembled_hess).abs() < 1e-10,
            "block-factored tau hessian lost internal nullspace columns: block={block_hess}, assembled={assembled_hess}"
        );
    }

    // With a small ridge, the rank-deficient block must not leak ln(ridge)
    // terms into the logdet: the structural nullity pins the rank split.
    #[test]
    fn test_block_factored_ridge_preserves_structural_nullspace_value() {
        let s = array![[4.0, 2.0], [2.0, 1.0]];
        let ridge = 1e-4_f64;
        let root = crate::estimate::reml::unified::penalty_matrix_root(&s).unwrap();
        let penalty = crate::construction::CanonicalPenalty::from_dense_root(root, 2);
        let block_factored = PenaltyPseudologdet::from_penalties(&[penalty], &[1.0], ridge, 2)
            .expect("block-factored pseudo-logdet");
        let mut s_ridged = s.clone();
        for i in 0..2 {
            s_ridged[[i, i]] += ridge;
        }
        let assembled = PenaltyPseudologdet::from_assembled_with_nullity(s_ridged, Some(1))
            .expect("assembled pseudo-logdet with structural nullity");
        assert_eq!(block_factored.rank(), assembled.rank());
        assert!(
            (block_factored.value() - assembled.value()).abs() < 1e-12,
            "block-factored ridge path leaked structural nullspace logdet: block={}, assembled={}",
            block_factored.value(),
            assembled.value()
        );
    }

    // Blockwise ρ derivatives must match the dense reference, and cross-block
    // Hessian entries must be exactly zero (no cross-block work performed).
    #[test]
    fn test_block_factored_rho_derivatives_match_dense_without_cross_block_work() {
        let p_total = 6;
        let lambdas = [1.7_f64, 0.4_f64, 2.3_f64];
        // Two penalties share the range 0..3 (merged block); the third
        // occupies the disjoint range 3..6.
        let penalties = vec![
            crate::construction::CanonicalPenalty {
                root: array![[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
                col_range: 0..3,
                total_dim: p_total,
                nullity: 1,
                local: array![[1.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 0.0]],
                positive_eigenvalues: vec![1.0, 4.0],
                op: None,
            },
            crate::construction::CanonicalPenalty {
                root: array![[0.0, 0.0, 3.0]],
                col_range: 0..3,
                total_dim: p_total,
                nullity: 2,
                local: array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 9.0]],
                positive_eigenvalues: vec![9.0],
                op: None,
            },
            crate::construction::CanonicalPenalty {
                root: array![[1.5, 0.0, 0.0], [0.0, 0.0, 0.5]],
                col_range: 3..6,
                total_dim: p_total,
                nullity: 1,
                local: array![[2.25, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.25]],
                positive_eigenvalues: vec![2.25, 0.25],
                op: None,
            },
        ];
        let block_factored =
            PenaltyPseudologdet::from_penalties(&penalties, &lambdas, 0.0, p_total).unwrap();
        assert_eq!(block_factored.block_spans.len(), 2);
        let mut dense_components = Vec::new();
        for penalty in &penalties {
            let mut full = Array2::<f64>::zeros((p_total, p_total));
            penalty.accumulate_weighted(&mut full, 1.0);
            dense_components.push(full);
        }
        let dense = PenaltyPseudologdet::from_components(&dense_components, &lambdas, 0.0).unwrap();
        let (block_first, block_second) =
            block_factored.rho_derivatives_from_penalties(&penalties, &lambdas);
        let (dense_first, dense_second) = dense.rho_derivatives(&dense_components, &lambdas);
        for k in 0..lambdas.len() {
            assert!((block_first[k] - dense_first[k]).abs() < 1e-11);
            for l in 0..lambdas.len() {
                assert!((block_second[[k, l]] - dense_second[[k, l]]).abs() < 1e-10);
            }
        }
        // Penalties 0/1 (block 0..3) vs penalty 2 (block 3..6): exact zeros.
        assert!(block_second[[0, 2]].abs() < 1e-12);
        assert!(block_second[[1, 2]].abs() < 1e-12);
    }

    // Overlapping penalties take the dense path; with a ridge and known
    // structural nullity the value must match the hand-built reference.
    #[test]
    fn test_overlapping_penalties_ridge_preserve_structural_nullspace_value() {
        let ridge = 1e-4_f64;
        let lambdas = [2.0_f64, 3.0_f64];
        let penalties = [
            crate::construction::CanonicalPenalty::from_dense_root(array![[1.0, 0.0, 0.0]], 3),
            crate::construction::CanonicalPenalty::from_dense_root(array![[0.0, 1.0, 0.0]], 3),
        ];
        let overlapping = PenaltyPseudologdet::from_penalties(&penalties, &lambdas, ridge, 3)
            .expect("overlapping pseudo-logdet");
        let s_ridged = array![
            [lambdas[0] + ridge, 0.0, 0.0],
            [0.0, lambdas[1] + ridge, 0.0],
            [0.0, 0.0, ridge]
        ];
        let assembled = PenaltyPseudologdet::from_assembled_with_nullity(s_ridged, Some(1))
            .expect("assembled pseudo-logdet with structural nullity");
        assert_eq!(overlapping.rank(), assembled.rank());
        assert!(
            (overlapping.value() - assembled.value()).abs() < 1e-12,
            "assembled ridge path leaked structural nullspace logdet: overlap={}, assembled={}",
            overlapping.value(),
            assembled.value()
        );
    }
}