use crate::linalg::faer_ndarray::{FaerCholesky, FaerEigh};
use faer::Side;
use ndarray::{Array1, Array2, ArrayView2};
/// Relative threshold (scaled by the largest penalty eigenvalue) at or below
/// which a penalty eigenvalue is treated as zero when the null-space dimension
/// has to be detected numerically.
const PENALTY_NULL_RELATIVE_TOL: f64 = 1.0e-9;
/// Basis `Z_J` of the under-identified span of a coefficient block, i.e. the
/// directions on which the Jeffreys term is evaluated.
#[derive(Debug, Clone)]
pub struct JeffreysSubspace {
    /// Basis matrix with one row per coefficient and one column per spanned
    /// direction; it may have zero columns when nothing is under-identified.
    pub columns: Array2<f64>,
}

impl JeffreysSubspace {
    /// Number of basis directions, i.e. the column count of `columns`.
    #[inline]
    pub fn span_dim(&self) -> usize {
        let (_rows, cols) = self.columns.dim();
        cols
    }
}
/// Builds the basis `Z_J` of the under-identified span of a coefficient block:
/// the eigenvectors of the aggregate penalty `S` that belong to its null space.
///
/// * A `0 x 0` penalty yields an empty `0 x 0` basis.
/// * If every entry of `S` is exactly zero (an unpenalized block), the whole
///   space is under-identified and the `p x p` identity is returned,
///   regardless of `structural_nullity`.
/// * Otherwise `S` is eigendecomposed and the eigenvectors of the `m0`
///   smallest eigenvalues are returned, where `m0` is `structural_nullity`
///   (clamped to `p`) when supplied, or else the number of eigenvalues at or
///   below `PENALTY_NULL_RELATIVE_TOL * λ_max`.
///
/// # Errors
/// Returns `Err` when `S` is not square, contains a non-finite entry, or the
/// eigendecomposition fails.
pub fn jeffreys_subspace_from_penalty(
    aggregate_penalty: ArrayView2<'_, f64>,
    structural_nullity: Option<usize>,
) -> Result<JeffreysSubspace, String> {
    let p = aggregate_penalty.nrows();
    if aggregate_penalty.ncols() != p {
        return Err(format!(
            "jeffreys_subspace: aggregate penalty must be square, got {}x{}",
            aggregate_penalty.nrows(),
            aggregate_penalty.ncols()
        ));
    }
    if p == 0 {
        return Ok(JeffreysSubspace {
            columns: Array2::zeros((0, 0)),
        });
    }
    // NaN/inf would otherwise reach the eigensolver and yield garbage or an
    // opaque failure; report it explicitly instead.
    if let Some(bad) = aggregate_penalty.iter().find(|v| !v.is_finite()) {
        return Err(format!(
            "jeffreys_subspace: aggregate penalty contains non-finite entry {bad}"
        ));
    }
    // Exact all-zero test. A Frobenius norm would square the entries, which
    // underflows to 0 for tiny-but-nonzero penalties and misclassifies them.
    if aggregate_penalty.iter().all(|&v| v == 0.0) {
        return Ok(JeffreysSubspace {
            columns: Array2::eye(p),
        });
    }
    let owned = aggregate_penalty.to_owned();
    let (evals, evecs) = owned
        .eigh(Side::Lower)
        .map_err(|e| format!("jeffreys_subspace: penalty eigendecomposition failed: {e}"))?;
    // Order eigenpairs by ascending eigenvalue explicitly so the selection
    // below does not rely on the solver's output ordering. The sort is stable,
    // so an already-ascending result keeps its original column order.
    let mut order: Vec<usize> = (0..evals.len()).collect();
    order.sort_by(|&a, &b| evals[a].total_cmp(&evals[b]));
    let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max);
    let m0 = match structural_nullity {
        Some(m0) => m0.min(p),
        None => {
            let cutoff = PENALTY_NULL_RELATIVE_TOL * lambda_max.max(f64::MIN_POSITIVE);
            evals.iter().filter(|&&e| e <= cutoff).count()
        }
    };
    // Copy the eigenvectors of the m0 smallest eigenvalues (p x 0 when m0 == 0).
    let mut columns = Array2::<f64>::zeros((p, m0));
    for (dst, &src) in order.iter().take(m0).enumerate() {
        columns.column_mut(dst).assign(&evecs.column(src));
    }
    Ok(JeffreysSubspace { columns })
}
/// Evaluates the Jeffreys term `Φ(β) = ½ log det(Z_Jᵀ H(β) Z_J)` on the
/// under-identified span `Z_J`, together with its gradient and a PSD curvature.
///
/// # Arguments
/// * `h_joint` – square `p x p` matrix `H` at the current coefficients.
/// * `z_j` – `p x m` basis of the under-identified span (rows must match `H`).
/// * `hessian_dir` – called once per coordinate with the unit vector `e_k`
///   (`k` in `0..p`). It returns the `p x p` directional derivative
///   `Ḣ_k = ∂H/∂β_k`, or `Ok(None)` when no derivative information is available.
///
/// # Returns
/// `(phi, grad, hphi)` with `H_id = sym(Z_Jᵀ H Z_J)`,
/// `D_k = sym(Z_Jᵀ Ḣ_k Z_J)` and `M_k = H_id⁻¹ D_k`:
/// * `phi  = ½ log det H_id` (from the Cholesky diagonal),
/// * `grad[k] = ½ tr(M_k)`,
/// * `hphi[a][b] = ½ tr(M_a M_b) = ½ tr(H_id⁻¹ D_a H_id⁻¹ D_b)`.
///
/// `hphi` is the positive semi-definite part of `-∇²Φ`. It omits the
/// `½ tr(H_id⁻¹ ∂²H_id)` term, so it equals `-∇²Φ` exactly when `H` is affine
/// in β. Like `grad`, it is invariant to the choice of basis for the span.
///
/// When `m == 0`, all three outputs are zero. When `hessian_dir` yields
/// `Ok(None)`, `phi` is returned with a zero gradient and zero curvature.
///
/// # Errors
/// Returns `Err` on shape mismatches, when `H_id` is not SPD (Cholesky
/// failure), or when `hessian_dir` itself returns an error.
pub fn joint_jeffreys_term<DirFn>(
    h_joint: ArrayView2<'_, f64>,
    z_j: ArrayView2<'_, f64>,
    mut hessian_dir: DirFn,
) -> Result<(f64, Array1<f64>, Array2<f64>), String>
where
    DirFn: FnMut(&Array1<f64>) -> Result<Option<Array2<f64>>, String>,
{
    let p = h_joint.nrows();
    if h_joint.ncols() != p {
        return Err(format!(
            "joint_jeffreys_term: H must be square, got {}x{}",
            h_joint.nrows(),
            h_joint.ncols()
        ));
    }
    if z_j.nrows() != p {
        return Err(format!(
            "joint_jeffreys_term: Z_J has {} rows, expected {} to match H",
            z_j.nrows(),
            p
        ));
    }
    let m = z_j.ncols();
    if m == 0 {
        return Ok((0.0, Array1::zeros(p), Array2::zeros((p, p))));
    }
    // Symmetric part of a square matrix; removes round-off asymmetry from
    // the reduced products before they are factorized or traced.
    let symmetrize = |a: &Array2<f64>| -> Array2<f64> {
        let n = a.nrows();
        Array2::from_shape_fn((n, n), |(i, j)| 0.5 * (a[[i, j]] + a[[j, i]]))
    };
    let h_id = symmetrize(&z_j.t().dot(&h_joint.dot(&z_j)));
    let chol = h_id.cholesky(Side::Lower).map_err(|e| {
        format!("joint_jeffreys_term: reduced Fisher information not SPD on under-identified span ({e}); orthogonalization should remove any structural confound before Jeffreys is applied")
    })?;
    // ½ log det H_id = Σ ln L_ii for H_id = L Lᵀ.
    let phi = chol.diag().iter().map(|d| d.abs().ln()).sum::<f64>();
    let eye = Array2::<f64>::eye(m);
    let h_id_inv = chol.solve_mat(&eye);
    let mut grad = Array1::<f64>::zeros(p);
    // Row k of `sens` holds vec(M_k) and row k of `sens_t` holds vec(M_kᵀ),
    // so that ⟨sens[a], sens_t[b]⟩ = Σ_ij M_a[i,j] M_b[j,i] = tr(M_a M_b).
    let mut sens = Array2::<f64>::zeros((p, m * m));
    let mut sens_t = Array2::<f64>::zeros((p, m * m));
    let mut axis = Array1::<f64>::zeros(p);
    for k in 0..p {
        axis.fill(0.0);
        axis[k] = 1.0;
        let hdot = match hessian_dir(&axis)? {
            Some(hdot) => hdot,
            None => {
                return Ok((phi, Array1::zeros(p), Array2::zeros((p, p))));
            }
        };
        if hdot.nrows() != p || hdot.ncols() != p {
            return Err(format!(
                "joint_jeffreys_term: Hdot shape {}x{} != {p}x{p}",
                hdot.nrows(),
                hdot.ncols()
            ));
        }
        let d_k = symmetrize(&z_j.t().dot(&hdot.dot(&z_j)));
        let m_k = h_id_inv.dot(&d_k);
        grad[k] = 0.5 * m_k.diag().sum();
        for i in 0..m {
            for j in 0..m {
                sens[[k, i * m + j]] = m_k[[i, j]];
                sens_t[[k, i * m + j]] = m_k[[j, i]];
            }
        }
    }
    // gram[a][b] = tr(M_a M_b). This is symmetric in exact arithmetic, so
    // average the two triangles to make the result exactly symmetric.
    let gram = sens.dot(&sens_t.t());
    let hphi = Array2::from_shape_fn((p, p), |(a, b)| 0.25 * (gram[[a, b]] + gram[[b, a]]));
    Ok((phi, grad, hphi))
}

#[cfg(test)]
mod joint_jeffreys_curvature_tests {
    use super::*;
    use ndarray::array;

    // H(β) = [[2 + β0, 1], [1, 2 + β1]] is affine in β. At β = 0 the curvature
    // must therefore equal -∇²Φ = [[2/9, 1/18], [1/18, 2/9]] and the gradient
    // must be [1/3, 1/3]. Neither may depend on the basis chosen for the span;
    // only Φ shifts, by ln|det Z|.
    #[test]
    fn curvature_matches_affine_hessian_and_is_basis_invariant() {
        let h = array![[2.0, 1.0], [1.0, 2.0]];
        let hdir = |d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
            Ok(Some(array![[d[0], 0.0], [0.0, d[1]]]))
        };
        let expected_grad = array![1.0 / 3.0, 1.0 / 3.0];
        let expected_hphi = array![[2.0 / 9.0, 1.0 / 18.0], [1.0 / 18.0, 2.0 / 9.0]];
        let cases = [
            (Array2::<f64>::eye(2), 1.0_f64),
            (array![[1.0, 1.0], [0.0, 2.0]], 2.0_f64),
        ];
        for (z, det_z) in cases.iter() {
            let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();
            let expected_phi = 0.5 * 3.0_f64.ln() + det_z.ln();
            assert!((phi - expected_phi).abs() < 1e-12, "phi {phi} vs {expected_phi}");
            for (g, e) in grad.iter().zip(expected_grad.iter()) {
                assert!((g - e).abs() < 1e-12, "grad {g} vs {e}");
            }
            for (g, e) in hphi.iter().zip(expected_hphi.iter()) {
                assert!((g - e).abs() < 1e-12, "hphi {g} vs {e}");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// With no penalty at all, the whole space is under-identified and the
    /// returned basis must satisfy `Z_Jᵀ Z_J = I`.
    #[test]
    fn parametric_block_under_identified_span_is_identity() {
        let zero_penalty = Array2::<f64>::zeros((3, 3));
        let subspace = jeffreys_subspace_from_penalty(zero_penalty.view(), None).unwrap();
        assert_eq!(subspace.span_dim(), 3);
        assert_eq!(subspace.columns.nrows(), 3);
        let gram = subspace.columns.t().dot(&subspace.columns);
        let identity = Array2::<f64>::eye(3);
        for (got, want) in gram.iter().zip(identity.iter()) {
            assert!((got - want).abs() < 1e-12);
        }
    }

    /// Only the unpenalized directions may appear in the basis.
    #[test]
    fn rank_deficient_penalty_selects_null_space_only() {
        let mut penalty = Array2::<f64>::zeros((3, 3));
        penalty[[2, 2]] = 5.0;
        let subspace = jeffreys_subspace_from_penalty(penalty.view(), Some(2)).unwrap();
        assert_eq!(subspace.span_dim(), 2);
        let penalized_dir = array![0.0, 0.0, 1.0];
        let overlap = subspace.columns.t().dot(&penalized_dir);
        assert!(
            overlap.iter().all(|v| v.abs() < 1e-10),
            "Z_J should be orthogonal to penalized dir"
        );
    }

    /// A full-rank penalty identifies every direction, leaving an empty span.
    #[test]
    fn full_rank_penalty_has_empty_under_identified_span() {
        let penalty = Array2::<f64>::eye(4) * 2.0;
        let subspace = jeffreys_subspace_from_penalty(penalty.view(), Some(0)).unwrap();
        assert_eq!(subspace.span_dim(), 0);
        assert_eq!(subspace.columns.ncols(), 0);
    }

    /// Checks `phi` in closed form, `grad` against central differences, and
    /// that `hphi` is symmetric positive semi-definite.
    #[test]
    fn joint_jeffreys_term_matches_finite_difference_gradient() {
        let p = 2usize;
        let z = Array2::<f64>::eye(p);
        let hessian_at = |b: &Array1<f64>| -> Array2<f64> {
            let mut h = Array2::<f64>::zeros((p, p));
            h[[0, 0]] = b[0].exp();
            h[[1, 1]] = 1.0 + b[1] * b[1];
            h
        };
        // Φ for a diagonal 2x2 H with Z = I.
        let log_det_half = |h: &Array2<f64>| -> f64 { 0.5 * (h[[0, 0]] * h[[1, 1]]).ln() };
        let beta: Array1<f64> = array![0.3, -0.4];
        let hdir = |d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
            let mut hd = Array2::<f64>::zeros((p, p));
            hd[[0, 0]] = beta[0].exp() * d[0];
            hd[[1, 1]] = 2.0 * beta[1] * d[1];
            Ok(Some(hd))
        };
        let h = hessian_at(&beta);
        let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();

        let expected_phi = log_det_half(&h);
        assert!((phi - expected_phi).abs() < 1e-10, "phi {phi} vs {expected_phi}");

        let eps = 1e-6;
        for (k, &g) in grad.iter().enumerate() {
            let mut up = beta.clone();
            let mut down = beta.clone();
            up[k] += eps;
            down[k] -= eps;
            let fd = (log_det_half(&hessian_at(&up)) - log_det_half(&hessian_at(&down)))
                / (2.0 * eps);
            assert!((g - fd).abs() < 1e-5, "grad[{k}] {} vs fd {fd}", g);
        }

        assert!(hphi
            .iter()
            .zip(hphi.t().iter())
            .all(|(upper, lower)| (upper - lower).abs() < 1e-12));

        let (eigenvalues, _) = hphi.eigh(Side::Lower).unwrap();
        for e in eigenvalues.iter() {
            assert!(*e >= -1e-10, "H_Phi must be PSD, got eigenvalue {e}");
        }
    }

    /// With no under-identified directions the term vanishes entirely.
    #[test]
    fn empty_span_yields_zero_term() {
        let h = Array2::<f64>::eye(3);
        let z = Array2::<f64>::zeros((3, 0));
        let hdir = |_d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
            Ok(Some(Array2::<f64>::zeros((3, 3))))
        };
        let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();
        assert_eq!(phi, 0.0);
        assert!(grad.iter().chain(hphi.iter()).all(|v| *v == 0.0));
    }
}