use crate::linalg::faer_ndarray::FaerEigh;
use crate::linalg::lanczos::{SymmetricLanczosOptions, symmetric_lanczos_eigenpairs};
use faer::Side;
use ndarray::{Array1, Array2, ArrayView2};
use std::collections::HashMap;
use std::sync::Arc;
/// Euclidean (L2) norm of a slice: the square root of the plain sum of squares.
///
/// No rescaling is applied, so entries beyond roughly `1e154` in magnitude overflow.
#[inline]
pub(crate) fn norm2_slice(a: &[f64]) -> f64 {
    let sum_of_squares: f64 = a.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
/// Relative part of the reduced-information eigenvalue floor: the floor is
/// `max(REDUCED_INFO_RELATIVE_FLOOR * lambda_max, REDUCED_INFO_ABSOLUTE_FLOOR)`.
pub(crate) const REDUCED_INFO_RELATIVE_FLOOR: f64 = 1e-10;
/// Absolute lower bound on the reduced-information eigenvalue floor. Also used as the
/// gap below which two eigenvalues are treated as equal (divided differences, rotations).
pub(crate) const REDUCED_INFO_ABSOLUTE_FLOOR: f64 = 1e-12;
/// Eigenvalue at which the floored-inverse family switches to its saturating
/// `cap / λ²` tail.
///
/// The cap is the absolute conditioning "clear" level, raised to `floor` whenever the floor
/// is larger. It therefore never sits below the floor.
#[inline]
pub(crate) fn jeffreys_cap(floor: f64) -> f64 {
    f64::max(CONDITIONING_GATE_ABSOLUTE_CLEAR, floor)
}
/// Regularised stand-in for `1 / λ` used by the Jeffreys term.
///
/// With `cap = jeffreys_cap(floor)` the function is piecewise:
/// * `λ ≥ cap`: `cap / λ²` (saturating tail);
/// * `floor ≤ λ < cap`: `1 / λ`;
/// * `0 ≤ λ < floor`: the constant `1 / floor`;
/// * `λ < 0`: `floor / (floor − λ)²`, which decays to 0 as `λ → −∞`.
///
/// The pieces agree in value at every breakpoint. Each piece is the `λ`-derivative of the
/// matching piece of `jeffreys_antiderivative`. A NaN `λ` falls through to the last piece
/// and yields NaN.
#[inline]
pub(crate) fn floored_inverse(lam: f64, floor: f64) -> f64 {
    let cap = jeffreys_cap(floor);
    match lam {
        x if x >= cap => cap / (x * x),
        x if x >= floor => 1.0 / x,
        x if x >= 0.0 => 1.0 / floor,
        x => {
            let gap = floor - x;
            floor / (gap * gap)
        }
    }
}
/// Antiderivative in `λ` of `floored_inverse`: a floored, saturating replacement for `ln λ`.
///
/// With `cap = jeffreys_cap(floor)`:
/// * `λ ≥ cap`: `ln(cap) + 1 − cap / λ`, which saturates toward `ln(cap) + 1`;
/// * `floor ≤ λ < cap`: exactly `ln λ`;
/// * `0 ≤ λ < floor`: the tangent line `λ / floor + ln(floor) − 1`;
/// * `λ < 0`: `ln(floor) − 1 + λ / (floor − λ)`, approaching `ln(floor) − 2`.
///
/// All pieces are continuous at the breakpoints.
#[inline]
pub(crate) fn jeffreys_antiderivative(lam: f64, floor: f64) -> f64 {
    let cap = jeffreys_cap(floor);
    match lam {
        x if x >= cap => cap.ln() + 1.0 - cap / x,
        x if x >= floor => x.ln(),
        x if x >= 0.0 => x / floor + floor.ln() - 1.0,
        x => floor.ln() - 1.0 + x / (floor - x),
    }
}
/// Partial derivative of `jeffreys_antiderivative(λ, floor)` with respect to `floor`
/// (at fixed `λ`).
///
/// The saturating branch depends on the floor only when the floor itself is the cap, i.e.
/// `floor > CONDITIONING_GATE_ABSOLUTE_CLEAR`. The `ln λ` band is floor-independent.
#[inline]
pub(crate) fn jeffreys_antiderivative_floor_sensitivity(lam: f64, floor: f64) -> f64 {
    let cap = jeffreys_cap(floor);
    let floor_is_cap = cap > CONDITIONING_GATE_ABSOLUTE_CLEAR;
    match lam {
        x if x >= cap => {
            if floor_is_cap {
                1.0 / floor - 1.0 / x
            } else {
                0.0
            }
        }
        x if x >= floor => 0.0,
        x if x >= 0.0 => 1.0 / floor - x / (floor * floor),
        x => {
            let gap = floor - x;
            1.0 / floor - x / (gap * gap)
        }
    }
}
/// Derivative in `λ` of `floored_inverse(λ, floor)` (piecewise, same breakpoints).
///
/// The floored band `[0, floor)` is flat, so its derivative is exactly 0.
#[inline]
pub(crate) fn floored_inverse_prime(lam: f64, floor: f64) -> f64 {
    let cap = jeffreys_cap(floor);
    match lam {
        x if x >= cap => -2.0 * cap / (x * x * x),
        x if x >= floor => -1.0 / (x * x),
        x if x >= 0.0 => 0.0,
        x => {
            let gap = floor - x;
            2.0 * floor / (gap * gap * gap)
        }
    }
}
/// Second derivative in `λ` of `floored_inverse(λ, floor)` (piecewise, same breakpoints).
#[inline]
pub(crate) fn floored_inverse_second(lam: f64, floor: f64) -> f64 {
    let cap = jeffreys_cap(floor);
    match lam {
        x if x >= cap => 6.0 * cap / (x * x * x * x),
        x if x >= floor => 2.0 / (x * x * x),
        x if x >= 0.0 => 0.0,
        x => {
            let gap = floor - x;
            6.0 * floor / (gap * gap * gap * gap)
        }
    }
}
/// Partial derivative of `floored_inverse(λ, floor)` with respect to `floor` (fixed `λ`).
///
/// The saturating tail only moves with the floor when the floor is the cap
/// (`floor > CONDITIONING_GATE_ABSOLUTE_CLEAR`). The `1/λ` band is floor-independent.
#[inline]
pub(crate) fn floored_inverse_floor_sensitivity(lam: f64, floor: f64) -> f64 {
    let cap = jeffreys_cap(floor);
    let floor_is_cap = cap > CONDITIONING_GATE_ABSOLUTE_CLEAR;
    match lam {
        x if x >= cap => {
            if floor_is_cap {
                1.0 / (x * x)
            } else {
                0.0
            }
        }
        x if x >= floor => 0.0,
        x if x >= 0.0 => -1.0 / (floor * floor),
        x => {
            let gap = floor - x;
            -(floor + x) / (gap * gap * gap)
        }
    }
}
/// Partial derivative of `floored_inverse_prime(λ, floor)` with respect to `floor`.
///
/// Both the `−1/λ²` band and the flat `[0, floor)` band are floor-independent, so every
/// non-negative `λ` below the cap gives 0.
#[inline]
pub(crate) fn floored_inverse_prime_floor_sensitivity(lam: f64, floor: f64) -> f64 {
    let cap = jeffreys_cap(floor);
    let floor_is_cap = cap > CONDITIONING_GATE_ABSOLUTE_CLEAR;
    match lam {
        x if x >= cap => {
            if floor_is_cap {
                -2.0 / (x * x * x)
            } else {
                0.0
            }
        }
        x if x >= 0.0 => 0.0,
        x => {
            let gap = floor - x;
            -2.0 * (2.0 * floor + x) / (gap * gap * gap * gap)
        }
    }
}
/// First divided differences of `floored_inverse` over the eigenvalues `evals`:
/// `ψ[i, j] = (f(λ_i) − f(λ_j)) / (λ_i − λ_j)`.
///
/// When `|λ_i − λ_j| ≤ REDUCED_INFO_ABSOLUTE_FLOOR` (always the case on the diagonal), the
/// confluent limit `f'(λ_i)` is used instead.
pub(crate) fn floored_inverse_divided_differences(evals: &Array1<f64>, floor: f64) -> Array2<f64> {
    let n = evals.len();
    Array2::from_shape_fn((n, n), |(row, col)| {
        let (lam_r, lam_c) = (evals[row], evals[col]);
        let gap = lam_r - lam_c;
        if gap.abs() <= REDUCED_INFO_ABSOLUTE_FLOOR {
            floored_inverse_prime(lam_r, floor)
        } else {
            (floored_inverse(lam_r, floor) - floored_inverse(lam_c, floor)) / gap
        }
    })
}
/// Relative conditioning level `λ_min / λ_max` at or below which the relative gate ramp is
/// fully on (weight 1). The ramp is evaluated in `log10` space.
pub(crate) const CONDITIONING_GATE_RELATIVE: f64 = 1e-8;
/// Absolute `λ_min` at or below which the absolute gate ramp is fully on (weight 1).
pub(crate) const CONDITIONING_GATE_ABSOLUTE: f64 = 1.0;
/// Absolute `λ_min` at or above which the absolute gate ramp reaches 0. Also the lower
/// bound of `jeffreys_cap`.
pub(crate) const CONDITIONING_GATE_ABSOLUTE_CLEAR: f64 = 16.0;
/// Relative conditioning level at or above which the relative gate ramp reaches 0.
pub(crate) const CONDITIONING_GATE_RELATIVE_CLEAR: f64 = 1e-6;
/// Smooth weight in `[0, 1]` that scales the Jeffreys term.
///
/// Two cubic smoothstep ramps that fall from 1 to 0 are combined with `max`:
/// * absolute: on `λ_min`, from `CONDITIONING_GATE_ABSOLUTE` to
///   `CONDITIONING_GATE_ABSOLUTE_CLEAR`;
/// * relative: on `log10(λ_min / λ_max)` (ratio clamped to `f64::MIN_POSITIVE`), from
///   `log10(CONDITIONING_GATE_RELATIVE)` to `log10(CONDITIONING_GATE_RELATIVE_CLEAR)`.
///
/// A weight of exactly 0 means the system is well conditioned on both criteria.
/// Degenerate inputs (`λ_max ≤ 0` or a non-finite `λ_min`) pin the weight at 1.
#[inline]
pub(crate) fn conditioning_gate_weight(lambda_min: f64, lambda_max: f64) -> f64 {
    if lambda_max <= 0.0 || !lambda_min.is_finite() {
        return 1.0;
    }
    // 1 at or below `lo`, 0 at or above `hi`, cubic smoothstep in between.
    let smooth_fall = |x: f64, lo: f64, hi: f64| -> f64 {
        if x <= lo {
            1.0
        } else if x >= hi {
            0.0
        } else {
            let t = (x - lo) / (hi - lo);
            1.0 - t * t * (3.0 - 2.0 * t)
        }
    };
    let log_ratio = (lambda_min / lambda_max).max(f64::MIN_POSITIVE).log10();
    let absolute = smooth_fall(
        lambda_min,
        CONDITIONING_GATE_ABSOLUTE,
        CONDITIONING_GATE_ABSOLUTE_CLEAR,
    );
    let relative = smooth_fall(
        log_ratio,
        CONDITIONING_GATE_RELATIVE.log10(),
        CONDITIONING_GATE_RELATIVE_CLEAR.log10(),
    );
    absolute.max(relative)
}
/// Partial derivatives `(∂w/∂λ_min, ∂w/∂λ_max)` of `conditioning_gate_weight`.
///
/// The weight is `max(w_abs, w_rel)`, so the gradient follows the active ramp:
/// * the absolute ramp (which also wins ties) contributes only to `∂/∂λ_min`;
/// * the relative ramp acts on `log10(λ_min / λ_max)`, giving `w' / (λ_min ln 10)` and
///   `−w' / (λ_max ln 10)`.
///
/// Inputs for which the weight is pinned at 1 (`λ_max ≤ 0` or a non-finite `λ_min`)
/// return `(0, 0)`.
pub(crate) fn conditioning_gate_weight_grad(lambda_min: f64, lambda_max: f64) -> (f64, f64) {
    if lambda_max <= 0.0 || !lambda_min.is_finite() {
        return (0.0, 0.0);
    }
    // Value and slope of a cubic smoothstep falling from 1 (at `lo`) to 0 (at `hi`).
    let smooth_fall = |x: f64, lo: f64, hi: f64| -> (f64, f64) {
        if x <= lo {
            return (1.0, 0.0);
        }
        if x >= hi {
            return (0.0, 0.0);
        }
        let width = hi - lo;
        let t = (x - lo) / width;
        (1.0 - t * t * (3.0 - 2.0 * t), -6.0 * t * (1.0 - t) / width)
    };
    let (abs_weight, abs_slope) = smooth_fall(
        lambda_min,
        CONDITIONING_GATE_ABSOLUTE,
        CONDITIONING_GATE_ABSOLUTE_CLEAR,
    );
    let log_ratio = (lambda_min / lambda_max).max(f64::MIN_POSITIVE).log10();
    let (rel_weight, rel_slope) = smooth_fall(
        log_ratio,
        CONDITIONING_GATE_RELATIVE.log10(),
        CONDITIONING_GATE_RELATIVE_CLEAR.log10(),
    );
    if abs_weight >= rel_weight {
        return (abs_slope, 0.0);
    }
    // Chain rule through log10(λ_min / λ_max) = (ln λ_min − ln λ_max) / ln 10.
    let ln10 = std::f64::consts::LN_10;
    (
        rel_slope / (lambda_min * ln10),
        -rel_slope / (lambda_max * ln10),
    )
}
/// Smallest parameter dimension for which `jeffreys_term_skippable_via_matvec` attempts the
/// Lanczos precheck. Below it the function always answers "not skippable".
pub const CHEAP_CONDITIONING_PRECHECK_MIN_DIM: usize = 128;
/// Multiplicative margin applied to both gate "clear" thresholds in the cheap precheck.
pub(crate) const CHEAP_PRECHECK_SAFETY_MARGIN: f64 = 8.0;
/// Maximum number of Lanczos steps in `cheap_conditioning_bounds` (further capped at `p`).
pub(crate) const CHEAP_PRECHECK_LANCZOS_STEPS: usize = 12;
/// Relative tolerance, scaled by `max(|θ_max|, 1)`, that the extreme Ritz residuals must meet.
pub(crate) const CHEAP_PRECHECK_RITZ_REL_TOL: f64 = 1e-3;
/// Estimates `(λ_min_lb, λ_max_ub)` for a symmetric operator known only through the
/// matrix-vector product `hv`.
///
/// The estimate comes from at most `CHEAP_PRECHECK_LANCZOS_STEPS` steps of fully
/// reorthogonalised Lanczos, started from a deterministic golden-ratio vector. The extreme
/// Ritz values are then widened by their residual estimates (final residual norm times the
/// last component of the Ritz vector).
///
/// Returns:
/// * `Ok(Some(bounds))` when both extreme Ritz pairs meet `CHEAP_PRECHECK_RITZ_REL_TOL`;
/// * `Ok(None)` when `p == 0`, when the start vector is degenerate, when Lanczos fails for
///   any reason other than `hv` itself (including non-finite or ill-sized products), when
///   the Ritz values are non-finite, or when the residuals are too large;
/// * `Err(e)` carrying `hv`'s own error when the callback fails.
///
/// NOTE(review): with this few steps the smallest Ritz value need not be near the true
/// smallest eigenvalue, so the bounds are heuristic. The caller in this file adds a safety
/// margin.
pub(crate) fn cheap_conditioning_bounds<HvFn>(
    mut hv: HvFn,
    p: usize,
) -> Result<Option<(f64, f64)>, String>
where
    HvFn: FnMut(&Array1<f64>) -> Result<Array1<f64>, String>,
{
    if p == 0 {
        return Ok(None);
    }
    // Multiples of the golden-ratio conjugate taken mod 1 give a well-spread, deterministic
    // start vector. Subtracting 0.5 centres it on zero.
    const GOLDEN_CONJUGATE: f64 = 0.618_033_988_749_894_8;
    let start: Vec<f64> = (1..=p)
        .map(|n| (n as f64 * GOLDEN_CONJUGATE).fract() - 0.5)
        .collect();
    let start_norm = norm2_slice(&start);
    if !(start_norm.is_finite() && start_norm > 0.0) {
        return Ok(None);
    }
    let options = SymmetricLanczosOptions {
        max_steps: CHEAP_PRECHECK_LANCZOS_STEPS.min(p),
        residual_tol: f64::EPSILON,
        local_reorthogonalize: false,
        full_reorthogonalize: true,
    };
    // Holds the caller's own HVP error, so it can be returned instead of the generic
    // Lanczos failure.
    let mut callback_error: Option<String> = None;
    let outcome = symmetric_lanczos_eigenpairs(p, &start, options, |q, out| {
        let product = match hv(&Array1::from_vec(q.to_vec())) {
            Ok(product) => product,
            Err(e) => {
                callback_error = Some(e);
                return Err("cheap_conditioning_bounds: HVP failed".to_string());
            }
        };
        if product.len() != p || product.iter().any(|v| !v.is_finite()) {
            return Err(
                "cheap_conditioning_bounds: HVP produced non-finite/ill-sized output".to_string(),
            );
        }
        let contiguous = product.as_slice().ok_or_else(|| {
            "cheap_conditioning_bounds: HVP output not contiguous".to_string()
        })?;
        out.copy_from_slice(contiguous);
        Ok(())
    });
    let eigen = match outcome {
        Ok(eigen) => eigen,
        Err(_) => {
            return match callback_error {
                Some(e) => Err(e),
                None => Ok(None),
            };
        }
    };
    let ritz = eigen.eigenvalues;
    let ritz_vecs = eigen.eigenvectors;
    let beta_last = eigen.residual_norm;
    let k = ritz.len();
    if k == 0 {
        return Ok(None);
    }
    // Indices of the first smallest and the first largest Ritz value.
    let (lo, hi) = (1..k).fold((0usize, 0usize), |(lo, hi), i| {
        let lo = if ritz[i] < ritz[lo] { i } else { lo };
        let hi = if ritz[i] > ritz[hi] { i } else { hi };
        (lo, hi)
    });
    let (theta_min, theta_max) = (ritz[lo], ritz[hi]);
    if !(theta_min.is_finite() && theta_max.is_finite()) {
        return Ok(None);
    }
    // Residual estimate of a Ritz pair: final residual norm times the magnitude of the last
    // component of its Ritz vector.
    let last = k - 1;
    let res_min = beta_last * ritz_vecs[[last, lo]].abs();
    let res_max = beta_last * ritz_vecs[[last, hi]].abs();
    let converged_tol = CHEAP_PRECHECK_RITZ_REL_TOL * theta_max.abs().max(1.0);
    if res_min > converged_tol || res_max > converged_tol {
        return Ok(None);
    }
    Ok(Some((theta_min - res_min, theta_max + res_max)))
}
/// Cheaply decides, using only Hessian-vector products, whether the Jeffreys term can be
/// skipped outright.
///
/// Returns `Ok(true)` only when all of the following hold:
/// * `p ≥ CHEAP_CONDITIONING_PRECHECK_MIN_DIM`;
/// * `cheap_conditioning_bounds` yields finite, strictly positive bounds;
/// * both the lower bound `λ_min_lb` and the ratio `λ_min_lb / λ_max_ub` clear their gate
///   thresholds by `CHEAP_PRECHECK_SAFETY_MARGIN`.
///
/// Every inconclusive case returns `Ok(false)`. Errors from `hv` are propagated.
pub fn jeffreys_term_skippable_via_matvec<HvFn>(hv: HvFn, p: usize) -> Result<bool, String>
where
    HvFn: FnMut(&Array1<f64>) -> Result<Array1<f64>, String>,
{
    if p < CHEAP_CONDITIONING_PRECHECK_MIN_DIM {
        return Ok(false);
    }
    let Some((lower, upper)) = cheap_conditioning_bounds(hv, p)? else {
        return Ok(false);
    };
    let usable = lower.is_finite() && upper.is_finite() && lower > 0.0 && upper > 0.0;
    if !usable {
        return Ok(false);
    }
    let absolute_threshold = CHEAP_PRECHECK_SAFETY_MARGIN * CONDITIONING_GATE_ABSOLUTE_CLEAR;
    let relative_threshold = CHEAP_PRECHECK_SAFETY_MARGIN * CONDITIONING_GATE_RELATIVE_CLEAR;
    Ok(lower >= absolute_threshold && lower / upper >= relative_threshold)
}
/// Basis of the coordinate subspace on which the Jeffreys term is evaluated.
///
/// Each column of `columns` is one basis vector.
#[derive(Debug, Clone)]
pub struct JeffreysSubspace {
    pub columns: Array2<f64>,
}

impl JeffreysSubspace {
    /// Dimension of the spanned subspace, i.e. the number of basis columns.
    #[inline]
    pub fn span_dim(&self) -> usize {
        let (_rows, cols) = self.columns.dim();
        cols
    }
}
/// Builds the Jeffreys subspace for a block with the given aggregate penalty.
///
/// Only the penalty's shape is validated; its entries are not inspected. The subspace is
/// always the whole block: the `p × p` identity, which is the empty `0 × 0` matrix for
/// `p == 0`.
///
/// # Errors
/// Returns an error if `aggregate_penalty` is not square.
pub fn jeffreys_subspace_from_penalty(
    aggregate_penalty: ArrayView2<'_, f64>,
) -> Result<JeffreysSubspace, String> {
    let (rows, cols) = aggregate_penalty.dim();
    if rows != cols {
        return Err(format!(
            "jeffreys_subspace: aggregate penalty must be square, got {}x{}",
            rows, cols
        ));
    }
    // `eye(0)` is already the empty 0x0 matrix, so an empty block needs no special case.
    Ok(JeffreysSubspace {
        columns: Array2::eye(rows),
    })
}
/// Joint Jeffreys term `Φ` on the subspace spanned by `z_j`, together with its gradient and
/// curvature matrix `H_Φ`.
///
/// The reduced information `sym(Z_Jᵀ H Z_J)` is eigendecomposed as `V diag(λ) Vᵀ`. Two
/// quantities are derived from it:
/// * the floor `max(REDUCED_INFO_RELATIVE_FLOOR · λ_max, REDUCED_INFO_ABSOLUTE_FLOOR)`;
/// * the conditioning-gate weight.
///
/// These parameterise a `JeffreysLogdetAtom`, which supplies `Φ`. For every coordinate
/// axis `e_k`, `hessian_dir(e_k)` supplies `∂H/∂θ_k`. Its rotated reduced drift
/// `Vᵀ sym(Z_Jᵀ Ḣ_k Z_J) V` feeds the atom's gradient (`frozen_d1`) and its
/// `second_order_curvature`. When the floor tracks `λ_max` (relative regime), the induced
/// floor drift `REDUCED_INFO_RELATIVE_FLOOR · ∂λ_max` feeds them as well. The drift is
/// symmetrized the same way `H` is and the same way `JeffreysHphiDriftBase` treats axis
/// derivatives. That type's perturbation derivative is validated against this `H_Φ`.
///
/// The per-axis callbacks and their projections run in parallel.
///
/// Returns `(Φ, ∇Φ, H_Φ)`, with these short-circuits:
/// * `z_j` has no columns, or the gate weight is exactly 0 → `(0, 0, 0)`;
/// * `hessian_dir` returns `None` for any axis → `(Φ, 0, 0)`.
///
/// # Errors
/// Non-square `H`, a `Z_J` row count that does not match, a wrongly shaped `Ḣ_k`,
/// eigendecomposition failure, an error from `hessian_dir`, or an error from the curvature
/// atom. When several axes fail, the lowest axis index decides the outcome.
pub fn joint_jeffreys_term<DirFn>(
    h_joint: ArrayView2<'_, f64>,
    z_j: ArrayView2<'_, f64>,
    hessian_dir: DirFn,
) -> Result<(f64, Array1<f64>, Array2<f64>), String>
where
    DirFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
{
    let p = h_joint.nrows();
    if h_joint.ncols() != p {
        return Err(format!(
            "joint_jeffreys_term: H must be square, got {}x{}",
            h_joint.nrows(),
            h_joint.ncols()
        ));
    }
    if z_j.nrows() != p {
        return Err(format!(
            "joint_jeffreys_term: Z_J has {} rows, expected {} to match H",
            z_j.nrows(),
            p
        ));
    }
    let m = z_j.ncols();
    if m == 0 {
        return Ok((0.0, Array1::zeros(p), Array2::zeros((p, p))));
    }
    // ½(M + Mᵀ) for an m × m reduced matrix. This removes the asymmetric rounding that the
    // projections Z_Jᵀ · X · Z_J introduce.
    let symmetrize = |mat: &Array2<f64>| -> Array2<f64> {
        Array2::from_shape_fn((m, m), |(i, j)| 0.5 * (mat[[i, j]] + mat[[j, i]]))
    };
    let h_id = z_j.t().dot(&h_joint.dot(&z_j));
    let h_id_sym = symmetrize(&h_id);
    let (evals, evecs) = h_id_sym.eigh(Side::Lower).map_err(|e| {
        format!("joint_jeffreys_term: reduced-information eigendecomposition failed: {e}")
    })?;
    let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max);
    let lambda_min = evals.iter().cloned().fold(f64::INFINITY, f64::min);
    let gate_weight = conditioning_gate_weight(lambda_min, lambda_max);
    if gate_weight == 0.0 {
        return Ok((0.0, Array1::zeros(p), Array2::zeros((p, p))));
    }
    let floor = (REDUCED_INFO_RELATIVE_FLOOR * lambda_max).max(REDUCED_INFO_ABSOLUTE_FLOOR);
    let floor_in_relative_regime =
        lambda_max > 0.0 && REDUCED_INFO_RELATIVE_FLOOR * lambda_max >= REDUCED_INFO_ABSOLUTE_FLOOR;
    // Index of the first largest eigenvalue. It is needed only when the floor follows λ_max.
    let lambda_max_idx: Option<usize> = floor_in_relative_regime.then(|| {
        (1..m).fold(0usize, |best, i| if evals[i] > evals[best] { i } else { best })
    });
    let value_atom = super::atoms::JeffreysLogdetAtom {
        eigvals: evals.clone(),
        floor,
        gate_weight,
        reduced_drift: HashMap::new(),
        floor_drift: HashMap::new(),
        stratum: super::atoms::StratumFingerprint {
            kept_rank: m,
            min_relative_eigengap: 0.0,
        },
    };
    let phi = super::atoms::CriterionAtom::value(&value_atom);
    // For each axis k, evaluate Ḣ_k and reduce it to A_k = Vᵀ sym(Z_Jᵀ Ḣ_k Z_J) V, in
    // parallel. Only the m × m results are kept, not the p × p directional derivatives.
    let per_axis: Vec<Result<Option<Array2<f64>>, String>> = {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        (0..p)
            .into_par_iter()
            .map(|k| {
                let mut axis = Array1::<f64>::zeros(p);
                axis[k] = 1.0;
                crate::linalg::faer_ndarray::with_nested_parallel(
                    || -> Result<Option<Array2<f64>>, String> {
                        let Some(hdot) = hessian_dir(&axis)? else {
                            return Ok(None);
                        };
                        if hdot.nrows() != p || hdot.ncols() != p {
                            return Err(format!(
                                "joint_jeffreys_term: Hdot shape {}x{} != {p}x{p}",
                                hdot.nrows(),
                                hdot.ncols()
                            ));
                        }
                        let d_k = symmetrize(&z_j.t().dot(&hdot.dot(&z_j)));
                        Ok(Some(evecs.t().dot(&d_k).dot(&evecs)))
                    },
                )
            })
            .collect()
    };
    let mut reduced_drift: HashMap<usize, Arc<Array2<f64>>> = HashMap::with_capacity(p);
    let mut floor_drift: HashMap<usize, f64> = HashMap::new();
    for (k, result) in per_axis.into_iter().enumerate() {
        let Some(a_k) = result? else {
            // A missing directional derivative degrades to a value-only result.
            return Ok((phi, Array1::zeros(p), Array2::zeros((p, p))));
        };
        if let Some(idx_max) = lambda_max_idx {
            // ∂floor/∂θ_k = REL · ∂λ_max/∂θ_k, where ∂λ_max/∂θ_k is A_k's diagonal entry.
            floor_drift.insert(k, REDUCED_INFO_RELATIVE_FLOOR * a_k[[idx_max, idx_max]]);
        }
        reduced_drift.insert(k, Arc::new(a_k));
    }
    let gradient_atom = super::atoms::JeffreysLogdetAtom {
        eigvals: evals,
        floor,
        gate_weight,
        reduced_drift,
        floor_drift,
        stratum: super::atoms::StratumFingerprint {
            kept_rank: m,
            min_relative_eigengap: 0.0,
        },
    };
    let grad = Array1::from_shape_fn(p, |k| {
        let dir = super::atoms::ThetaDirection {
            index: Some(k),
            beta_dot: None,
            h_dot_total: None,
        };
        super::atoms::CriterionAtom::frozen_d1(&gradient_atom, &dir)
    });
    let hphi = gradient_atom.second_order_curvature(p)?;
    Ok((phi, grad, hphi))
}
/// Second-order completion of the Jeffreys curvature `H_Φ`.
///
/// For every axis pair `a ≤ b` this evaluates `−½ · gate · tr(K · Z_Jᵀ H''[e_a, e_b] Z_J)`
/// and stores the value at both `(a, b)` and `(b, a)`. Here
/// `K = V · diag(floored_inverse(λ, floor)) · Vᵀ` is built from the eigendecomposition of
/// `sym(Z_Jᵀ H Z_J)`, with the same floor and gate as `joint_jeffreys_term`.
///
/// Each second directional derivative is reduced to its trace inside the parallel map.
/// Only one scalar per pair is kept, so the `p(p+1)/2` dense `p × p` matrices are never
/// all held in memory at once.
///
/// Returns:
/// * `Ok(Some(0))` when `z_j` has no columns or the gate weight is exactly 0;
/// * `Ok(None)` when `hessian_second_dir` returns `None` for any pair.
///
/// # Errors
/// Non-square `H`, a `Z_J` row count that does not match, a wrongly shaped `H''[a, b]`,
/// eigendecomposition failure, or an error from `hessian_second_dir`. When several pairs
/// fail, the first pair in `(a, b)` order decides the outcome.
pub fn joint_jeffreys_second_order_completion<Dir2Fn>(
    h_joint: ArrayView2<'_, f64>,
    z_j: ArrayView2<'_, f64>,
    hessian_second_dir: Dir2Fn,
) -> Result<Option<Array2<f64>>, String>
where
    Dir2Fn: Fn(&Array1<f64>, &Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
{
    let p = h_joint.nrows();
    if h_joint.ncols() != p {
        return Err(format!(
            "joint_jeffreys_second_order_completion: H must be square, got {}x{}",
            h_joint.nrows(),
            h_joint.ncols()
        ));
    }
    if z_j.nrows() != p {
        return Err(format!(
            "joint_jeffreys_second_order_completion: Z_J has {} rows, expected {p}",
            z_j.nrows()
        ));
    }
    let m = z_j.ncols();
    if m == 0 {
        return Ok(Some(Array2::zeros((p, p))));
    }
    let h_id = z_j.t().dot(&h_joint.dot(&z_j));
    let h_id_sym =
        Array2::from_shape_fn((m, m), |(i, j)| 0.5 * (h_id[[i, j]] + h_id[[j, i]]));
    let (evals, evecs) = h_id_sym.eigh(Side::Lower).map_err(|e| {
        format!("joint_jeffreys_second_order_completion: reduced-information eigendecomposition failed: {e}")
    })?;
    let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max);
    let lambda_min = evals.iter().cloned().fold(f64::INFINITY, f64::min);
    let gate_weight = conditioning_gate_weight(lambda_min, lambda_max);
    if gate_weight == 0.0 {
        return Ok(Some(Array2::zeros((p, p))));
    }
    let floor = (REDUCED_INFO_RELATIVE_FLOOR * lambda_max).max(REDUCED_INFO_ABSOLUTE_FLOOR);
    // K = V diag(f(λ)) Vᵀ: scale eigenvector column e by f(λ_e), then multiply by Vᵀ.
    let inv_diag = evals.mapv(|lam| floored_inverse(lam, floor));
    let k_reduced = (&evecs * &inv_diag).dot(&evecs.t());
    let pairs: Vec<(usize, usize)> = (0..p).flat_map(|a| (a..p).map(move |b| (a, b))).collect();
    let traces: Vec<Result<Option<f64>, String>> = {
        use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
        pairs
            .par_iter()
            .map(|&(a, b)| {
                let mut axis_a = Array1::<f64>::zeros(p);
                axis_a[a] = 1.0;
                let mut axis_b = Array1::<f64>::zeros(p);
                axis_b[b] = 1.0;
                crate::linalg::faer_ndarray::with_nested_parallel(
                    || -> Result<Option<f64>, String> {
                        let Some(h2) = hessian_second_dir(&axis_a, &axis_b)? else {
                            return Ok(None);
                        };
                        if h2.dim() != (p, p) {
                            return Err(format!(
                                "joint_jeffreys_second_order_completion: H''[{a},{b}] shape {:?} != ({p}, {p})",
                                h2.dim()
                            ));
                        }
                        let d_ab = z_j.t().dot(&h2.dot(&z_j));
                        // tr(K · D_ab) = Σ_ij K[i, j] · D_ab[j, i]
                        Ok(Some((&k_reduced * &d_ab.t()).sum()))
                    },
                )
            })
            .collect()
    };
    let mut out = Array2::<f64>::zeros((p, p));
    for (&(a, b), trace) in pairs.iter().zip(traces) {
        let Some(trace) = trace? else {
            return Ok(None);
        };
        let value = -0.5 * gate_weight * trace;
        out[[a, b]] = value;
        out[[b, a]] = value;
    }
    Ok(Some(out))
}
/// Derivative of the Jeffreys curvature `H_Φ` with respect to an explicit parameter `s`.
///
/// The parameter `s` moves `H` by `pert_h = ∂H/∂s`. It moves each axis derivative
/// `base_hessian_dir(e_a)` by `pert_hessian_dir(e_a) = ∂(∂H/∂θ_a)/∂s`.
///
/// This is a public entry point with the argument order `(…, pert_h, base, pert)`. It
/// forwards to `joint_jeffreys_hphi_perturbation_derivative`, which takes the base callback
/// before `pert_h`.
pub fn joint_jeffreys_hphi_explicit_param_derivative<
    BaseFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
    PertFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
>(
    h_joint: ArrayView2<'_, f64>,
    z_j: ArrayView2<'_, f64>,
    pert_h: &Array2<f64>,
    base_hessian_dir: BaseFn,
    pert_hessian_dir: PertFn,
) -> Result<Array2<f64>, String> {
    joint_jeffreys_hphi_perturbation_derivative(
        h_joint,
        z_j,
        base_hessian_dir,
        pert_h,
        pert_hessian_dir,
    )
}
/// Cached base-point state for differentiating the Jeffreys curvature `H_Φ`.
///
/// The cache is built once by `prepare` / `prepare_with_axes`. It is then reused by the
/// `perturbation_derivative*` methods, which differentiate `H_Φ` along perturbations of `H`
/// and of its per-axis derivatives.
pub(crate) struct JeffreysHphiDriftBase {
/// Parameter dimension (rows/columns of `H`).
p: usize,
/// Subspace dimension (columns of `Z_J`).
m: usize,
/// Owned copy of the subspace basis `Z_J` (`p × m`).
z_j: Array2<f64>,
/// Eigenvalues of the symmetrized reduced information `sym(Z_Jᵀ H Z_J)`.
evals: Array1<f64>,
/// Matching eigenvectors, stored as columns (`m × m`).
evecs: Array2<f64>,
/// Eigenvalue floor `max(REDUCED_INFO_RELATIVE_FLOOR · λ_max, REDUCED_INFO_ABSOLUTE_FLOOR)`.
floor: f64,
/// Conditioning-gate weight at the base point. It is never 0: construction returns `None`
/// in that case.
gate_weight: f64,
/// Divided differences of `floored_inverse` over `evals` (`m × m`).
psi: Array2<f64>,
/// Whether the floor follows `λ_max`, so that it drifts under perturbations.
floor_in_relative_regime: bool,
/// Index of the first smallest eigenvalue.
idx_min: usize,
/// Index of the first largest eigenvalue.
idx_max: usize,
/// Row `a` holds `A_a = Vᵀ sym(Z_Jᵀ Ḣ[e_a] Z_J) V`, flattened row-major (`p × m²`).
a_rows: Array2<f64>,
/// Row `a` holds `ψ ∘ A_a`, flattened the same way (`p × m²`).
aw_rows: Array2<f64>,
}
impl JeffreysHphiDriftBase {
    /// Builds the cache by evaluating `base_hessian_dir(e_a)` for every coordinate axis `e_a`
    /// in parallel, then delegating to `from_axis_derivatives`.
    ///
    /// Returns `Ok(None)` when `z_j` has no columns, when `p == 0`, when any axis derivative
    /// is unavailable (`None`), or when the conditioning-gate weight at `h_joint` is exactly 0.
    ///
    /// # Errors
    /// Non-square `h_joint`, a `z_j` row count that does not match it, wrongly shaped axis
    /// derivatives, eigendecomposition failure, or an error from `base_hessian_dir`.
    pub(crate) fn prepare<BaseFn>(
        h_joint: ArrayView2<'_, f64>,
        z_j: ArrayView2<'_, f64>,
        base_hessian_dir: BaseFn,
    ) -> Result<Option<JeffreysHphiDriftBase>, String>
    where
        BaseFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
    {
        let p = h_joint.nrows();
        if h_joint.ncols() != p {
            return Err(format!(
                "JeffreysHphiDriftBase::prepare: H must be square, got {}x{}",
                h_joint.nrows(),
                h_joint.ncols()
            ));
        }
        if z_j.nrows() != p {
            return Err(format!(
                "JeffreysHphiDriftBase::prepare: Z_J has {} rows, expected {p}",
                z_j.nrows()
            ));
        }
        let m = z_j.ncols();
        if m == 0 || p == 0 {
            return Ok(None);
        }
        let results: Vec<Result<Option<Array2<f64>>, String>> = {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            (0..p)
                .into_par_iter()
                .map(|a| {
                    let mut axis = Array1::<f64>::zeros(p);
                    axis[a] = 1.0;
                    crate::linalg::faer_ndarray::with_nested_parallel(|| base_hessian_dir(&axis))
                })
                .collect()
        };
        let mut hdots = Vec::with_capacity(p);
        for result in results {
            // The first error or missing derivative, in axis order, decides the outcome.
            match result? {
                Some(hdot) => hdots.push(hdot),
                None => return Ok(None),
            }
        }
        Self::from_axis_derivatives(h_joint, z_j, hdots)
    }

    /// Like `prepare`, but the axis derivatives come pre-evaluated (for example from a
    /// batched caller). `hdots[a]` must be `∂H/∂θ_a`.
    ///
    /// Returns `Ok(None)` when `z_j` has no columns, when `p == 0`, or when the gate weight
    /// is exactly 0.
    ///
    /// # Errors
    /// Non-square `h_joint`, a `z_j` row count that does not match, `hdots.len() != p`,
    /// wrongly shaped entries, or eigendecomposition failure.
    pub(crate) fn prepare_with_axes(
        h_joint: ArrayView2<'_, f64>,
        z_j: ArrayView2<'_, f64>,
        hdots: Vec<Array2<f64>>,
    ) -> Result<Option<JeffreysHphiDriftBase>, String> {
        let p = h_joint.nrows();
        if h_joint.ncols() != p {
            return Err(format!(
                "JeffreysHphiDriftBase::prepare_with_axes: H must be square, got {}x{}",
                h_joint.nrows(),
                h_joint.ncols()
            ));
        }
        if z_j.nrows() != p {
            return Err(format!(
                "JeffreysHphiDriftBase::prepare_with_axes: Z_J has {} rows, expected {p}",
                z_j.nrows()
            ));
        }
        if hdots.len() != p {
            return Err(format!(
                "JeffreysHphiDriftBase::prepare_with_axes: got {} axis derivatives, expected {p}",
                hdots.len()
            ));
        }
        let m = z_j.ncols();
        if m == 0 || p == 0 {
            return Ok(None);
        }
        Self::from_axis_derivatives(h_joint, z_j, hdots)
    }

    /// Shared constructor. It does three things:
    /// * validates the axis derivatives;
    /// * eigendecomposes the symmetrized reduced information;
    /// * precomputes, in parallel, the per-axis rotated drifts `A_a` and their weighted
    ///   versions `ψ ∘ A_a`.
    ///
    /// Callers have already checked that `h_joint` is square, that `z_j` has matching rows,
    /// and that `m > 0` and `p > 0`.
    fn from_axis_derivatives(
        h_joint: ArrayView2<'_, f64>,
        z_j: ArrayView2<'_, f64>,
        hdots: Vec<Array2<f64>>,
    ) -> Result<Option<JeffreysHphiDriftBase>, String> {
        let p = h_joint.nrows();
        let m = z_j.ncols();
        for hdot in &hdots {
            if hdot.nrows() != p || hdot.ncols() != p {
                return Err(format!(
                    "JeffreysHphiDriftBase: Hdot[e_a] shape {}x{} != {p}x{p}",
                    hdot.nrows(),
                    hdot.ncols()
                ));
            }
        }
        let h_id_sym = Self::symmetrized(&z_j.t().dot(&h_joint.dot(&z_j)));
        let (evals, evecs) = h_id_sym.eigh(Side::Lower).map_err(|e| {
            format!("JeffreysHphiDriftBase::prepare: eigendecomposition failed: {e}")
        })?;
        let lambda_max = evals.iter().cloned().fold(0.0_f64, f64::max);
        let lambda_min = evals.iter().cloned().fold(f64::INFINITY, f64::min);
        let gate_weight = conditioning_gate_weight(lambda_min, lambda_max);
        if gate_weight == 0.0 {
            return Ok(None);
        }
        let floor = (REDUCED_INFO_RELATIVE_FLOOR * lambda_max).max(REDUCED_INFO_ABSOLUTE_FLOOR);
        let psi = floored_inverse_divided_differences(&evals, floor);
        let floor_in_relative_regime = lambda_max > 0.0
            && REDUCED_INFO_RELATIVE_FLOOR * lambda_max >= REDUCED_INFO_ABSOLUTE_FLOOR;
        // First occurrences of the largest and the smallest eigenvalue.
        let mut idx_max = 0usize;
        let mut idx_min = 0usize;
        for i in 1..m {
            if evals[i] > evals[idx_max] {
                idx_max = i;
            }
            if evals[i] < evals[idx_min] {
                idx_min = i;
            }
        }
        let mut a_rows = Array2::<f64>::zeros((p, m * m));
        let mut aw_rows = Array2::<f64>::zeros((p, m * m));
        {
            use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
            a_rows
                .axis_iter_mut(ndarray::Axis(0))
                .into_par_iter()
                .zip(aw_rows.axis_iter_mut(ndarray::Axis(0)).into_par_iter())
                .zip(hdots.into_par_iter())
                .for_each(|((mut a_row, mut aw_row), hdot_a)| {
                    crate::linalg::faer_ndarray::with_nested_parallel(|| {
                        let d_a = Self::symmetrized(&z_j.t().dot(&hdot_a.dot(&z_j)));
                        let a_a = evecs.t().dot(&d_a).dot(&evecs);
                        // Row-major flattening: column i * m + j holds entry (i, j).
                        for i in 0..m {
                            for j in 0..m {
                                let col = i * m + j;
                                a_row[col] = a_a[[i, j]];
                                aw_row[col] = psi[[i, j]] * a_a[[i, j]];
                            }
                        }
                    });
                });
        }
        Ok(Some(JeffreysHphiDriftBase {
            p,
            m,
            z_j: z_j.to_owned(),
            evals,
            evecs,
            floor,
            gate_weight,
            psi,
            floor_in_relative_regime,
            idx_min,
            idx_max,
            a_rows,
            aw_rows,
        }))
    }

    /// Derivative of `H_Φ` along a perturbation of `H` (`pert_h`) and of every axis
    /// derivative (`pert_hessian_dir(e_a)`, evaluated per axis in parallel).
    ///
    /// Returns a `p × p` zero matrix if `pert_hessian_dir` yields `None` for any axis.
    ///
    /// # Errors
    /// Errors from `pert_hessian_dir`, and shape mismatches in `pert_h` or the axis matrices.
    pub(crate) fn perturbation_derivative<PertFn>(
        &self,
        pert_h: &Array2<f64>,
        pert_hessian_dir: PertFn,
    ) -> Result<Array2<f64>, String>
    where
        PertFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
    {
        let p = self.p;
        let results: Vec<Result<Option<Array2<f64>>, String>> = {
            use rayon::iter::{IntoParallelIterator, ParallelIterator};
            (0..p)
                .into_par_iter()
                .map(|a| {
                    let mut axis = Array1::<f64>::zeros(p);
                    axis[a] = 1.0;
                    crate::linalg::faer_ndarray::with_nested_parallel(|| pert_hessian_dir(&axis))
                })
                .collect()
        };
        let mut pert_hdots = Vec::with_capacity(p);
        for result in results {
            match result? {
                Some(h2) => pert_hdots.push(h2),
                None => return Ok(Array2::zeros((p, p))),
            }
        }
        self.perturbation_derivative_from_axis_matrices(pert_h, pert_hdots)
    }

    /// Batched variant of `perturbation_derivative`. `pert_axis_matrices[a]` holds the
    /// perturbed derivative for axis `a`; passing `None` yields a `p × p` zero matrix.
    ///
    /// # Errors
    /// Wrong number of axis matrices, or shape mismatches.
    pub(crate) fn perturbation_derivative_batched_axes(
        &self,
        pert_h: &Array2<f64>,
        pert_axis_matrices: Option<Vec<Array2<f64>>>,
    ) -> Result<Array2<f64>, String> {
        let p = self.p;
        let Some(pert_hdots) = pert_axis_matrices else {
            return Ok(Array2::zeros((p, p)));
        };
        if pert_hdots.len() != p {
            return Err(format!(
                "JeffreysHphiDriftBase::perturbation_derivative_batched_axes: got {} axis \
matrices, expected {p}",
                pert_hdots.len()
            ));
        }
        self.perturbation_derivative_from_axis_matrices(pert_h, pert_hdots)
    }

    /// Core of the perturbation derivative.
    ///
    /// Let `B = Vᵀ sym(Z_Jᵀ pert_h Z_J) V`. Then:
    /// * the eigenvalue drift is `diag(B)`;
    /// * the eigenvector rotation is `Ω[i, j] = B[i, j] / (λ_j − λ_i)` (zero for
    ///   near-degenerate pairs);
    /// * in the relative regime, the floor drifts by `REDUCED_INFO_RELATIVE_FLOOR · B[max, max]`.
    ///
    /// From these come `dψ` and `dA_a = Vᵀ sym(Z_Jᵀ pert_hdots[a] Z_J) V + A_a Ω − Ω A_a`. The
    /// result is
    /// `gate · (−½ [(dψ∘A_a + ψ∘dA_a)·A_b + (ψ∘A_a)·dA_b]) + dgate · (−½ (ψ∘A_a)·A_b)`.
    /// Here `·` is the Frobenius inner product, and `dgate` is the gate's response to the
    /// extreme-eigenvalue drifts.
    fn perturbation_derivative_from_axis_matrices(
        &self,
        pert_h: &Array2<f64>,
        pert_hdots: Vec<Array2<f64>>,
    ) -> Result<Array2<f64>, String> {
        let p = self.p;
        if pert_h.nrows() != p || pert_h.ncols() != p {
            return Err(format!(
                "JeffreysHphiDriftBase::perturbation_derivative: pert_h shape {}x{} != {p}x{p}",
                pert_h.nrows(),
                pert_h.ncols()
            ));
        }
        for (a, pert_hdot_a) in pert_hdots.iter().enumerate() {
            if pert_hdot_a.nrows() != p || pert_hdot_a.ncols() != p {
                return Err(format!(
                    "JeffreysHphiDriftBase::perturbation_derivative: ∂Hdot[e_{a}] shape {}x{} != {p}x{p}",
                    pert_hdot_a.nrows(),
                    pert_hdot_a.ncols()
                ));
            }
        }
        let m = self.m;
        let z_j = self.z_j.view();
        let evals = &self.evals;
        let evecs = &self.evecs;
        let floor = self.floor;
        let gate_weight = self.gate_weight;
        let psi = &self.psi;
        let idx_min = self.idx_min;
        let idx_max = self.idx_max;
        let lambda_min = evals[idx_min];
        let lambda_max = evals[idx_max];
        // B = Vᵀ sym(Z_Jᵀ Ḣ Z_J) V. Its diagonal is the first-order eigenvalue drift λ̇.
        let dbar = Self::symmetrized(&z_j.t().dot(&pert_h.dot(&z_j)));
        let dbar_red = evecs.t().dot(&dbar).dot(evecs);
        let dfloor = if self.floor_in_relative_regime {
            REDUCED_INFO_RELATIVE_FLOOR * dbar_red[[idx_max, idx_max]]
        } else {
            0.0
        };
        // Eigenvector rotation generator Ω (V̇ = V Ω) and the drift dψ of the divided
        // differences. Near-degenerate pairs use the confluent (derivative) limit.
        let mut rotation = Array2::<f64>::zeros((m, m));
        let mut dpsi = Array2::<f64>::zeros((m, m));
        for i in 0..m {
            for j in 0..m {
                let denom = evals[j] - evals[i];
                if denom.abs() > REDUCED_INFO_ABSOLUTE_FLOOR {
                    rotation[[i, j]] = dbar_red[[i, j]] / denom;
                }
                let gap = evals[i] - evals[j];
                if gap.abs() > REDUCED_INFO_ABSOLUTE_FLOOR {
                    let dp_i = floored_inverse_prime(evals[i], floor);
                    let dp_j = floored_inverse_prime(evals[j], floor);
                    let lam_dot_i = dbar_red[[i, i]];
                    let lam_dot_j = dbar_red[[j, j]];
                    dpsi[[i, j]] =
                        ((dp_i - psi[[i, j]]) * lam_dot_i + (psi[[i, j]] - dp_j) * lam_dot_j) / gap;
                    if dfloor != 0.0 {
                        dpsi[[i, j]] += (floored_inverse_floor_sensitivity(evals[i], floor)
                            - floored_inverse_floor_sensitivity(evals[j], floor))
                            / gap
                            * dfloor;
                    }
                } else {
                    dpsi[[i, j]] = floored_inverse_second(evals[i], floor)
                        * 0.5
                        * (dbar_red[[i, i]] + dbar_red[[j, j]]);
                    if dfloor != 0.0 {
                        dpsi[[i, j]] +=
                            floored_inverse_prime_floor_sensitivity(evals[i], floor) * dfloor;
                    }
                }
            }
        }
        let a_rows = &self.a_rows;
        let aw_rows = &self.aw_rows;
        let mut da_rows = Array2::<f64>::zeros((p, m * m));
        let mut dw_rows = Array2::<f64>::zeros((p, m * m));
        {
            use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
            da_rows
                .axis_iter_mut(ndarray::Axis(0))
                .into_par_iter()
                .zip(dw_rows.axis_iter_mut(ndarray::Axis(0)).into_par_iter())
                .zip(a_rows.axis_iter(ndarray::Axis(0)).into_par_iter())
                .zip(pert_hdots.into_par_iter())
                .for_each(|(((mut da_row, mut dw_row), a_flat), pert_hdot_a)| {
                    crate::linalg::faer_ndarray::with_nested_parallel(|| {
                        let a_a = Array2::from_shape_fn((m, m), |(i, j)| a_flat[i * m + j]);
                        let d_a_pert = Self::symmetrized(&z_j.t().dot(&pert_hdot_a.dot(&z_j)));
                        // dA_a = Vᵀ Ḋ_a V + A_a Ω − Ω A_a  (Ω is antisymmetric).
                        let da_a = evecs.t().dot(&d_a_pert).dot(evecs) + &a_a.dot(&rotation)
                            - &rotation.dot(&a_a);
                        for i in 0..m {
                            for j in 0..m {
                                let col = i * m + j;
                                da_row[col] = da_a[[i, j]];
                                dw_row[col] =
                                    dpsi[[i, j]] * a_a[[i, j]] + psi[[i, j]] * da_a[[i, j]];
                            }
                        }
                    });
                });
        }
        // All (a, b) entries at once via matrix products:
        // −½ [dW · Aᵀ + W · dAᵀ] with rows indexed by axis. The upper triangle is then
        // mirrored so the result is exactly symmetric.
        let mut out = dw_rows.dot(&a_rows.t()) + &aw_rows.dot(&da_rows.t());
        out.mapv_inplace(|x| -0.5 * x);
        for a in 0..p {
            for b in (a + 1)..p {
                out[[b, a]] = out[[a, b]];
            }
        }
        let mut result = out * gate_weight;
        let (g_dlmin, g_dlmax) = conditioning_gate_weight_grad(lambda_min, lambda_max);
        if g_dlmin != 0.0 || g_dlmax != 0.0 {
            // λ̇ for an extreme eigenvalue is the matching diagonal entry of B.
            let d_gate = g_dlmin * dbar_red[[idx_min, idx_min]]
                + g_dlmax * dbar_red[[idx_max, idx_max]];
            if d_gate != 0.0 {
                let hphi_raw = aw_rows.dot(&a_rows.t()).mapv(|x| -0.5 * x);
                result.scaled_add(d_gate, &hphi_raw);
            }
        }
        Ok(result)
    }

    /// Symmetric part `½ (M + Mᵀ)` of a square matrix.
    fn symmetrized(mat: &Array2<f64>) -> Array2<f64> {
        let n = mat.nrows();
        Array2::from_shape_fn((n, n), |(i, j)| 0.5 * (mat[[i, j]] + mat[[j, i]]))
    }
}
/// One-shot perturbation derivative of the Jeffreys curvature `H_Φ`.
///
/// This prepares a `JeffreysHphiDriftBase` from `base_hessian_dir` and immediately
/// evaluates its `perturbation_derivative` along `pert_h` / `pert_hessian_dir`. When
/// preparation yields `None` (empty subspace, missing axis derivative, or a zero gate), the
/// result is a `p × p` zero matrix.
pub(crate) fn joint_jeffreys_hphi_perturbation_derivative<BaseFn, PertFn>(
    h_joint: ArrayView2<'_, f64>,
    z_j: ArrayView2<'_, f64>,
    base_hessian_dir: BaseFn,
    pert_h: &Array2<f64>,
    pert_hessian_dir: PertFn,
) -> Result<Array2<f64>, String>
where
    BaseFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
    PertFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
{
    let dim = h_joint.nrows();
    let prepared = JeffreysHphiDriftBase::prepare(h_joint, z_j, base_hessian_dir)?;
    if let Some(base) = prepared {
        base.perturbation_derivative(pert_h, pert_hessian_dir)
    } else {
        Ok(Array2::zeros((dim, dim)))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
/// Test helper: directional derivative of `H_Φ` along a parameter direction `δ`.
///
/// The perturbation of `H` is `hessian_dir(δ)`. The perturbation of each axis derivative
/// is `hessian_second_dir(δ, e_a)`. A `None` from `hessian_dir(δ)` yields a `p × p` zero
/// matrix.
fn joint_jeffreys_hphi_directional_derivative<DirFn, Dir2Fn>(
    h_joint: ArrayView2<'_, f64>,
    z_j: ArrayView2<'_, f64>,
    delta: &Array1<f64>,
    hessian_dir: DirFn,
    hessian_second_dir: Dir2Fn,
) -> Result<Array2<f64>, String>
where
    DirFn: Fn(&Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
    Dir2Fn: Fn(&Array1<f64>, &Array1<f64>) -> Result<Option<Array2<f64>>, String> + Sync,
{
    let p = h_joint.nrows();
    if delta.len() != p {
        return Err(format!(
            "joint_jeffreys_hphi_directional_derivative: delta has {} entries, expected {p}",
            delta.len()
        ));
    }
    let Some(pert_h) = hessian_dir(delta)? else {
        return Ok(Array2::zeros((p, p)));
    };
    let (rows, cols) = pert_h.dim();
    if (rows, cols) != (p, p) {
        return Err(format!(
            "joint_jeffreys_hphi_directional_derivative: Hdot[δ] shape {}x{} != {p}x{p}",
            rows, cols
        ));
    }
    joint_jeffreys_hphi_perturbation_derivative(
        h_joint,
        z_j,
        |axis| hessian_dir(axis),
        &pert_h,
        |axis| hessian_second_dir(delta, axis),
    )
}
// Checks `joint_jeffreys_hphi_explicit_param_derivative` against a central finite
// difference in s of the H_Φ returned by `joint_jeffreys_term`. The setup is
// H(s) = h0 + s·pmat and ∂H/∂θ_a(s) = g[a] + s·q[a].
#[test]
pub(crate) fn explicit_param_derivative_matches_finite_difference() {
let p = 4usize;
let z = Array2::<f64>::eye(p);
// Symmetric, diagonally dominant base information and its s-direction.
let h0 = array![
[30.0, 1.0, 0.5, 0.2],
[1.0, 12.0, 0.3, 0.1],
[0.5, 0.3, 5.0, 0.4],
[0.2, 0.1, 0.4, 1.5],
];
let pmat = array![
[2.0, 0.3, 0.1, 0.05],
[0.3, 1.5, 0.2, 0.1],
[0.1, 0.2, 1.0, 0.15],
[0.05, 0.1, 0.15, 0.7],
];
// Deterministic symmetric matrices generated from trig functions of `seed`.
let make_sym = |seed: f64| -> Array2<f64> {
let mut a = Array2::<f64>::zeros((p, p));
for i in 0..p {
for j in 0..p {
a[[i, j]] = (seed + 0.37 * (i as f64) - 0.19 * (j as f64)).sin()
+ 0.5 * ((i + j) as f64 * seed).cos();
}
}
let at = a.t().to_owned();
(&a + &at).mapv(|v| 0.5 * v)
};
// g[a]: axis derivatives at s = 0; q[a]: their s-derivatives.
let g: Vec<Array2<f64>> = (0..p).map(|a| make_sym(1.0 + a as f64)).collect();
let q: Vec<Array2<f64>> = (0..p).map(|a| make_sym(7.0 + 2.0 * a as f64)).collect();
// Recovers the axis index from a one-hot direction vector.
let axis_index = |axis: &Array1<f64>| -> usize {
axis.iter().position(|&x| x != 0.0).expect("one-hot axis")
};
// H_Φ at parameter value s, via the value path.
let hphi_at = |s: f64| -> Array2<f64> {
let h = &h0 + &pmat.mapv(|v| s * v);
joint_jeffreys_term(h.view(), z.view(), |axis: &Array1<f64>| {
let a = axis_index(axis);
Ok(Some(&g[a] + &q[a].mapv(|v| s * v)))
})
.expect("value-path H_Φ")
.2
};
// Central finite difference with step hh around s0.
let s0 = 0.0_f64;
let hh = 1e-5;
let fd = (&hphi_at(s0 + hh) - &hphi_at(s0 - hh)).mapv(|v| v / (2.0 * hh));
let h_s0 = &h0 + &pmat.mapv(|v| s0 * v);
let analytic = joint_jeffreys_hphi_explicit_param_derivative(
h_s0.view(),
z.view(),
&pmat,
|axis: &Array1<f64>| {
let a = axis_index(axis);
Ok(Some(&g[a] + &q[a].mapv(|v| s0 * v)))
},
|axis: &Array1<f64>| {
let a = axis_index(axis);
Ok(Some(q[a].clone()))
},
)
.expect("explicit ∂_s H_Φ");
// Largest absolute entry-wise deviation between the analytic and FD results.
let mut max_err = 0.0_f64;
for i in 0..p {
for j in 0..p {
max_err = max_err.max((analytic[[i, j]] - fd[[i, j]]).abs());
}
}
assert!(
max_err < 1e-5,
"explicit ∂_s H_Φ mismatch vs FD: max_err={max_err}"
);
}
// Checks the mode-response drift of H_Φ along δ = e_0 against a central finite difference.
// The setup is H(t) = h0 + t·A_0 with constant axis derivatives A_a, so H'' = 0.
// h0's smallest eigenvalue (about 1e-4 before coupling) sits far below the relative floor
// REDUCED_INFO_RELATIVE_FLOOR · λ_max (about 0.05 for λ_max ≈ 5e8). This exercises the
// below-floor branches.
#[test]
pub(crate) fn perturbation_derivative_matches_finite_difference_below_floor() {
let p = 3usize;
let z = Array2::<f64>::eye(p);
let h0 = array![
[5.0e8, 2.0e3, 1.0e2],
[2.0e3, 4.0e8, 5.0e1],
[1.0e2, 5.0e1, 1.0e-4],
];
// Deterministic symmetric matrices generated from trig functions of `seed`.
let make_sym = |seed: f64| -> Array2<f64> {
let mut a = Array2::<f64>::zeros((p, p));
for i in 0..p {
for j in 0..p {
a[[i, j]] = (seed + 0.41 * (i as f64) - 0.23 * (j as f64)).sin()
+ 0.6 * ((i + j) as f64 * seed).cos();
}
}
let at = a.t().to_owned();
(&a + &at).mapv(|v| 0.5 * v)
};
// a_mats[a] = ∂H/∂θ_a, independent of θ.
let a_mats: Vec<Array2<f64>> = (0..p).map(|a| make_sym(2.3 + 1.7 * a as f64)).collect();
let axis_index = |axis: &Array1<f64>| -> usize {
axis.iter().position(|&x| x != 0.0).expect("one-hot axis")
};
// H_Φ after moving t along axis 0.
let hphi_at = |t: f64| -> Array2<f64> {
let h = &h0 + &a_mats[0].mapv(|v| t * v);
joint_jeffreys_term(h.view(), z.view(), |axis: &Array1<f64>| {
Ok(Some(a_mats[axis_index(axis)].clone()))
})
.expect("value-path H_Φ")
.2
};
let hh = 1e-5;
let fd = (&hphi_at(hh) - &hphi_at(-hh)).mapv(|v| v / (2.0 * hh));
let mut delta = Array1::<f64>::zeros(p);
delta[0] = 1.0;
// Hdot[d] = Σ_a d[a]·A_a (linear in d); H''[u, v] = 0.
let analytic = joint_jeffreys_hphi_directional_derivative(
h0.view(),
z.view(),
&delta,
|d: &Array1<f64>| {
let mut acc = Array2::<f64>::zeros((p, p));
for a in 0..p {
if d[a] != 0.0 {
acc.scaled_add(d[a], &a_mats[a]);
}
}
Ok(Some(acc))
},
|_u: &Array1<f64>, _v: &Array1<f64>| Ok(Some(Array2::<f64>::zeros((p, p)))),
)
.expect("mode-response drift D_β H_Φ[δ]");
// Relative error per entry, with scale floored at 1 so tiny entries compare absolutely.
let mut max_rel = 0.0_f64;
for i in 0..p {
for j in 0..p {
let scale = fd[[i, j]].abs().max(analytic[[i, j]].abs()).max(1.0);
max_rel = max_rel.max((analytic[[i, j]] - fd[[i, j]]).abs() / scale);
}
}
assert!(
max_rel < 1e-4,
"mode-response drift D_β H_Φ[δ] mismatch vs FD (below-floor): max_rel={max_rel}"
);
}
/// Whether the conditioning gate switches the Jeffreys term off entirely (weight exactly 0).
pub(crate) fn conditioning_gate_skips(lambda_min: f64, lambda_max: f64) -> bool {
    let weight = conditioning_gate_weight(lambda_min, lambda_max);
    weight == 0.0
}
/// The Jeffreys subspace must be the full identity basis whatever the penalty contains.
#[test]
pub(crate) fn full_span_is_identity_regardless_of_penalty() {
    let mut single_entry = Array2::<f64>::zeros((3, 3));
    single_entry[[2, 2]] = 5.0;
    let penalties = [
        Array2::<f64>::zeros((3, 3)),
        single_entry,
        Array2::<f64>::eye(4) * 2.0,
    ];
    for penalty in penalties {
        let p = penalty.nrows();
        let subspace = jeffreys_subspace_from_penalty(penalty.view()).unwrap();
        assert_eq!(subspace.span_dim(), p, "full span must equal the block dimension");
        assert_eq!(subspace.columns.nrows(), p);
        // The basis must be orthonormal: its Gram matrix is the identity.
        let gram = subspace.columns.t().dot(&subspace.columns);
        for ((i, j), &entry) in gram.indexed_iter() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert!((entry - expected).abs() < 1e-12);
        }
    }
}
/// A 0×0 penalty block has no coefficients, so its Jeffreys span is empty.
#[test]
pub(crate) fn empty_block_yields_empty_span() {
    let empty_penalty = Array2::<f64>::zeros((0, 0));
    let span_dim = jeffreys_subspace_from_penalty(empty_penalty.view())
        .unwrap()
        .span_dim();
    assert_eq!(span_dim, 0);
}
/// Checks the Jeffreys value, gradient and curvature on a diagonal family.
///
/// The family is `H(β) = diag(e^{β0}, ill·(1 + β1²))` with a tiny second
/// eigenvalue (`ill = 1e-9`).
/// - The value must equal `½·ln det H`.
/// - The gradient must match central differences of that value.
/// - `H_Φ` must be symmetric positive semidefinite.
#[test]
pub(crate) fn joint_jeffreys_term_matches_finite_difference_gradient() {
    let p = 2usize;
    let ill = 1e-9_f64;
    let z = Array2::<f64>::eye(p);
    let h_at = |b: &Array1<f64>| -> Array2<f64> {
        let mut h = Array2::<f64>::zeros((p, p));
        h[[0, 0]] = b[0].exp();
        h[[1, 1]] = ill * (1.0 + b[1] * b[1]);
        h
    };
    let beta: Array1<f64> = array![0.3, -0.4];
    // Exact directional derivative D_β H[d] of `h_at` at `beta`.
    let hdir = |d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
        let mut hd = Array2::<f64>::zeros((p, p));
        hd[[0, 0]] = beta[0].exp() * d[0];
        hd[[1, 1]] = ill * 2.0 * beta[1] * d[1];
        Ok(Some(hd))
    };
    let h = h_at(&beta);
    let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();

    let expected_phi = 0.5 * (beta[0].exp() * ill * (1.0 + beta[1] * beta[1])).ln();
    assert!(
        (phi - expected_phi).abs() < 1e-6,
        "phi {phi} vs {expected_phi}"
    );

    // ½·ln det of a diagonal 2×2 matrix.
    let half_log_det = |m: &Array2<f64>| -> f64 { 0.5 * (m[[0, 0]] * m[[1, 1]]).ln() };
    let eps = 1e-6;
    for k in 0..p {
        let mut up = beta.clone();
        up[k] += eps;
        let mut down = beta.clone();
        down[k] -= eps;
        let fd = (half_log_det(&h_at(&up)) - half_log_det(&h_at(&down))) / (2.0 * eps);
        assert!(
            (grad[k] - fd).abs() < 1e-5,
            "grad[{k}] {} vs fd {fd}",
            grad[k]
        );
    }

    for (row, col) in (0..p).flat_map(|r| (0..p).map(move |c| (r, c))) {
        assert!((hphi[[row, col]] - hphi[[col, row]]).abs() < 1e-12);
    }

    let (eigenvalues, _) = hphi.eigh(Side::Lower).unwrap();
    for &e in eigenvalues.iter() {
        assert!(e >= -1e-10, "H_Phi must be PSD, got eigenvalue {e}");
    }
}
/// Checks value/gradient consistency when one eigenvalue of the information
/// falls below the eigenvalue floor.
///
/// The family is `H(β) = diag(e^{β0}, ill·(1 + β1²))` with `ill = 1e-12`, so
/// the second eigenvalue lies below the relative floor. That exercises the
/// linear continuation of the Jeffreys antiderivative. The analytic gradient
/// from `joint_jeffreys_term` is compared against central differences of an
/// independent reference value.
#[test]
pub(crate) fn joint_jeffreys_term_value_gradient_consistent_below_floor() {
    let p = 2usize;
    let ill = 1e-12_f64;
    let z = Array2::<f64>::eye(p);
    let h_at = |b: &Array1<f64>| -> Array2<f64> {
        let mut h = Array2::<f64>::zeros((p, p));
        h[[0, 0]] = b[0].exp();
        h[[1, 1]] = ill * (1.0 + b[1] * b[1]);
        h
    };
    let beta: Array1<f64> = array![0.3, -0.4];
    // Exact directional derivative D_β H[d] of `h_at` at `beta`.
    let hdir = |d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
        let mut hd = Array2::<f64>::zeros((p, p));
        hd[[0, 0]] = beta[0].exp() * d[0];
        hd[[1, 1]] = ill * 2.0 * beta[1] * d[1];
        Ok(Some(hd))
    };
    let h = h_at(&beta);
    let (_phi, grad, _hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();

    // Floor rule max(REL·λ_max, ABS), built from the module's named constants.
    // They have the same values (1e-10 / 1e-12) as the literals this reference
    // previously duplicated. NOTE(review): this assumes the production floor
    // is derived from these constants; confirm in `joint_jeffreys_term`.
    let floor_for = |lambda_max: f64| -> f64 {
        (REDUCED_INFO_RELATIVE_FLOOR * lambda_max).max(REDUCED_INFO_ABSOLUTE_FLOOR)
    };
    // Independent reference value: ln λ at or above the floor, and a linear
    // continuation below it. The upper (cap) branch of the antiderivative is
    // not modelled, so this is only valid while λ0 stays below `jeffreys_cap`.
    // The guard assertion below makes that explicit.
    let value_at = |b: &Array1<f64>| -> f64 {
        let hh = h_at(b);
        let lam0 = hh[[0, 0]];
        let lam1 = hh[[1, 1]];
        let floor = floor_for(lam0.max(lam1));
        let g = |lam: f64| -> f64 {
            if lam >= floor {
                lam.ln()
            } else {
                lam / floor + floor.ln() - 1.0
            }
        };
        0.5 * (g(lam0) + g(lam1))
    };
    assert!(
        h[[0, 0]] < jeffreys_cap(floor_for(h[[0, 0]].max(h[[1, 1]]))),
        "reference value omits the cap branch; λ0 must stay below jeffreys_cap"
    );

    let eps = 1e-7;
    for k in 0..p {
        let mut bp = beta.clone();
        let mut bm = beta.clone();
        bp[k] += eps;
        bm[k] -= eps;
        let fd = (value_at(&bp) - value_at(&bm)) / (2.0 * eps);
        assert!(
            (grad[k] - fd).abs() <= 1e-5 * (1.0 + fd.abs()),
            "below-floor grad[{k}] {} vs fd {fd}; the Jeffreys value must be the \
             exact antiderivative of the floored-inverse gradient",
            grad[k]
        );
    }
}
/// Checks value/gradient consistency on an indefinite information matrix.
///
/// The family is `H(β) = diag(e^{β0}, -(1 + β1²))`: one positive eigenvalue
/// and one moderately negative one. The negative direction goes through the
/// saturating branch of the Jeffreys antiderivative. That branch must
/// contribute essentially no gradient, and value and gradient must agree
/// under central differences.
#[test]
pub(crate) fn joint_jeffreys_term_indefinite_value_gradient_consistent() {
    let p = 2usize;
    let z = Array2::<f64>::eye(p);
    let h_at = |b: &Array1<f64>| -> Array2<f64> {
        let mut h = Array2::<f64>::zeros((p, p));
        h[[0, 0]] = b[0].exp();
        h[[1, 1]] = -(1.0 + b[1] * b[1]);
        h
    };
    let beta: Array1<f64> = array![0.3, -0.4];
    // Exact directional derivative D_β H[d] of `h_at` at `beta`.
    let hdir = |d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
        let mut hd = Array2::<f64>::zeros((p, p));
        hd[[0, 0]] = beta[0].exp() * d[0];
        hd[[1, 1]] = -2.0 * beta[1] * d[1];
        Ok(Some(hd))
    };
    let h = h_at(&beta);
    let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();
    assert!(
        grad.iter().all(|g| g.abs() < 1e3),
        "indefinite direction must carry no phantom Firth score; grad={grad:?}"
    );
    assert!(
        grad[1].abs() < 1e-6,
        "saturating branch must be flat on a moderate negative eigenvalue; grad[1]={}",
        grad[1]
    );

    // One floor rule, max(REL·max(λ_max, 0), ABS), shared by the expected
    // value and the finite-difference reference. Previously the expected
    // value used `1e-10 * λ0` without the absolute floor, while the reference
    // applied it. The constants carry the same 1e-10 / 1e-12 values.
    let floor_for = |lambda_max: f64| -> f64 {
        (REDUCED_INFO_RELATIVE_FLOOR * lambda_max.max(0.0)).max(REDUCED_INFO_ABSOLUTE_FLOOR)
    };
    // Saturating Jeffreys antiderivative without the upper cap branch:
    // log / linear / saturating for λ ≥ floor, 0 ≤ λ < floor, and λ < 0.
    let g_sat = |lam: f64, floor: f64| -> f64 {
        if lam >= floor {
            lam.ln()
        } else if lam >= 0.0 {
            lam / floor + floor.ln() - 1.0
        } else {
            floor.ln() - 1.0 + lam / (floor - lam)
        }
    };

    let lam0 = beta[0].exp();
    let lam1 = -(1.0 + beta[1] * beta[1]);
    let floor = floor_for(lam0);
    assert!(
        lam0 < jeffreys_cap(floor),
        "reference g_sat omits the cap branch; λ0 must stay below jeffreys_cap"
    );
    let expected_phi = 0.5 * (lam0.ln() + g_sat(lam1, floor));
    assert!(
        (phi - expected_phi).abs() < 1e-9,
        "phi {phi} vs {expected_phi}"
    );

    let value_at = |b: &Array1<f64>| -> f64 {
        let hh = h_at(b);
        let fl = floor_for(hh[[0, 0]]);
        0.5 * (g_sat(hh[[0, 0]], fl) + g_sat(hh[[1, 1]], fl))
    };
    let eps = 1e-7;
    for k in 0..p {
        let mut bp = beta.clone();
        let mut bm = beta.clone();
        bp[k] += eps;
        bm[k] -= eps;
        let fd = (value_at(&bp) - value_at(&bm)) / (2.0 * eps);
        assert!(
            (grad[k] - fd).abs() <= 1e-5 * (1.0 + fd.abs()),
            "indefinite grad[{k}] {} vs fd {fd}; value/gradient must share the saturating g",
            grad[k]
        );
    }
    for a in 0..p {
        for b in 0..p {
            assert!((hphi[[a, b]] - hphi[[b, a]]).abs() < 1e-12);
        }
    }
}
/// A well-conditioned information matrix, diag(200, 100), must not engage
/// the Jeffreys term. Value, gradient and curvature are all exactly zero.
#[test]
pub(crate) fn conditioning_gate_skips_well_conditioned_information() {
    let p = 2usize;
    let z = Array2::<f64>::eye(p);
    let mut h = Array2::<f64>::zeros((p, p));
    h[[0, 0]] = 200.0;
    h[[1, 1]] = 100.0;
    let hdir = |d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
        let mut hd = Array2::<f64>::zeros((p, p));
        hd[[0, 0]] = 3.0 * d[0];
        hd[[1, 1]] = 5.0 * d[1];
        Ok(Some(hd))
    };
    let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();
    assert_eq!(phi, 0.0, "well-conditioned ⇒ no Jeffreys value");
    // `!any(v != 0)` is the De Morgan form of `all(v == 0)`; NaN fails both.
    assert!(
        !grad.iter().any(|v| *v != 0.0),
        "well-conditioned ⇒ zero grad"
    );
    assert!(
        !hphi.iter().any(|v| *v != 0.0),
        "well-conditioned ⇒ zero curvature"
    );
}
/// The gate must stay shut for a well-conditioned diag(100, 50). It must
/// open in two cases:
/// - the eigenvalue ratio falls below `CONDITIONING_GATE_RELATIVE`;
/// - λ_min is small in absolute terms, here 0.05.
#[test]
pub(crate) fn conditioning_gate_fires_only_below_threshold() {
    let p = 2usize;
    let z = Array2::<f64>::eye(p);
    // D H[d] = diag(d).
    let hdir = |d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
        let mut hd = Array2::<f64>::zeros((p, p));
        hd[[0, 0]] = d[0];
        hd[[1, 1]] = d[1];
        Ok(Some(hd))
    };
    // Information matrix diag(top, bottom).
    let diag2 = |top: f64, bottom: f64| -> Array2<f64> {
        let mut h = Array2::<f64>::zeros((p, p));
        h[[0, 0]] = top;
        h[[1, 1]] = bottom;
        h
    };

    let above = diag2(100.0, 50.0);
    let (phi_a, grad_a, _) = joint_jeffreys_term(above.view(), z.view(), hdir).unwrap();
    assert_eq!(phi_a, 0.0);
    assert!(grad_a.iter().all(|v| *v == 0.0));

    let below_rel = diag2(1.0, CONDITIONING_GATE_RELATIVE * 0.1);
    let (phi_r, _g, hphi_r) = joint_jeffreys_term(below_rel.view(), z.view(), hdir).unwrap();
    assert!(phi_r != 0.0, "relatively near-separating must fire");
    assert!(hphi_r.iter().any(|v| v.abs() > 0.0));

    let below_abs = diag2(1.0, 0.05);
    let (phi_b, _grad_b, hphi_b) =
        joint_jeffreys_term(below_abs.view(), z.view(), hdir).unwrap();
    assert!(
        phi_b != 0.0,
        "absolutely near-separating (small-n) must fire even though the relative ratio clears the gate",
    );
    assert!(
        hphi_b.iter().any(|v| v.abs() > 0.0),
        "absolute-gate firing must produce nonzero bounding curvature",
    );
}
/// Table-driven check of `conditioning_gate_skips` across the relative gate,
/// the absolute gate, the fully-clear boundary, and degenerate (zero / NaN)
/// inputs.
#[test]
pub(crate) fn conditioning_gate_predicate_relative_and_absolute() {
    // (λ_min, λ_max, expected "skip")
    let cases: [(f64, f64, bool); 8] = [
        (50.0, 100.0, true),
        (CONDITIONING_GATE_RELATIVE * 0.1, 1.0, false),
        (0.05, 1.0, false),
        (CONDITIONING_GATE_ABSOLUTE, CONDITIONING_GATE_ABSOLUTE, false),
        (4.0, 100.0, false),
        (
            CONDITIONING_GATE_ABSOLUTE_CLEAR,
            CONDITIONING_GATE_ABSOLUTE_CLEAR,
            true,
        ),
        (0.0, 0.0, false),
        (f64::NAN, 100.0, false),
    ];
    for (lambda_min, lambda_max, expect_skip) in cases {
        assert_eq!(
            conditioning_gate_skips(lambda_min, lambda_max),
            expect_skip,
            "gate skip predicate at (λ_min={lambda_min}, λ_max={lambda_max})"
        );
    }
}
/// Checks the conditioning-gate weight at λ_max = 1e6.
///
/// The weight must be 1 at or below `CONDITIONING_GATE_ABSOLUTE` and 0 at or
/// above `CONDITIONING_GATE_ABSOLUTE_CLEAR`. Across the band between them it
/// must stay in [0, 1], never increase, and never jump by 0.1 or more
/// between neighbouring points of a 200-step grid.
#[test]
pub(crate) fn conditioning_gate_weight_is_continuous_and_monotone() {
    let lambda_max = 1.0e6;
    let weight = |lambda_min: f64| conditioning_gate_weight(lambda_min, lambda_max);

    assert_eq!(weight(CONDITIONING_GATE_ABSOLUTE), 1.0);
    assert_eq!(weight(0.1), 1.0);
    assert_eq!(weight(CONDITIONING_GATE_ABSOLUTE_CLEAR), 0.0);
    assert_eq!(weight(100.0), 0.0);

    let steps = 200usize;
    let band = CONDITIONING_GATE_ABSOLUTE_CLEAR - CONDITIONING_GATE_ABSOLUTE;
    let mut previous = 1.0;
    for k in 0..=steps {
        let lmin = CONDITIONING_GATE_ABSOLUTE + band * (k as f64 / steps as f64);
        let current = weight(lmin);
        assert!((0.0..=1.0).contains(&current));
        assert!(current <= previous + 1e-12, "weight must be non-increasing");
        assert!(
            (previous - current).abs() < 0.1,
            "no large jumps across the smooth band (continuity)"
        );
        previous = current;
    }
}
/// `conditioning_gate_weight_grad` must match central finite differences of
/// `conditioning_gate_weight` in both λ_min and λ_max at each probe point.
#[test]
pub(crate) fn conditioning_gate_weight_grad_matches_finite_difference() {
    let configs: [(f64, f64); 6] = [
        (8.0, 1.0e9),
        (4.0, 1.0e9),
        (12.0, 1.0e9),
        (100.0, 100.0 / 1.0e-7),
        (0.05, 1.0e9),
        (1.0e3, 1.0e3 / 1.0e-9),
    ];
    // Relative step size, bounded below so tiny arguments still get a usable step.
    let step_for = |x: f64| 1e-7 * x.abs().max(1e-3);
    for &(lmin, lmax) in &configs {
        let (g_dlmin, g_dlmax) = conditioning_gate_weight_grad(lmin, lmax);

        let hmin = step_for(lmin);
        let fd_dlmin = (conditioning_gate_weight(lmin + hmin, lmax)
            - conditioning_gate_weight(lmin - hmin, lmax))
            / (2.0 * hmin);
        assert!(
            (fd_dlmin - g_dlmin).abs() <= 1e-4 * g_dlmin.abs().max(1.0),
            "∂G/∂λ_min desync at (λ_min={lmin}, λ_max={lmax}): fd={fd_dlmin} analytic={g_dlmin}"
        );

        let hmax = step_for(lmax);
        let fd_dlmax = (conditioning_gate_weight(lmin, lmax + hmax)
            - conditioning_gate_weight(lmin, lmax - hmax))
            / (2.0 * hmax);
        assert!(
            (fd_dlmax - g_dlmax).abs() <= 1e-4 * g_dlmax.abs().max(1.0),
            "∂G/∂λ_max desync at (λ_min={lmin}, λ_max={lmax}): fd={fd_dlmax} analytic={g_dlmax}"
        );
    }
}
/// With an empty Jeffreys span (a 3×0 basis), the value, gradient and
/// curvature returned by `joint_jeffreys_term` must all be exactly zero.
#[test]
pub(crate) fn empty_span_yields_zero_term() {
    let h = Array2::<f64>::eye(3);
    let z = Array2::<f64>::zeros((3, 0));
    let hdir = |_d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
        Ok(Some(Array2::<f64>::zeros((3, 3))))
    };
    let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();
    let is_zero = |v: &f64| *v == 0.0;
    assert_eq!(phi, 0.0);
    assert!(grad.iter().all(is_zero));
    assert!(hphi.iter().all(is_zero));
}
/// Builds a matrix-free Hessian-vector product for the diagonal matrix
/// `diag(diag)`, used to drive matvec-based pre-checks in tests.
///
/// The returned closure never yields `Err`. Its output has the input's
/// length, and entry `i` is `diag[i] * v[i]`. Entries past `diag.len()` stay
/// zero. An input shorter than `diag` panics on out-of-bounds indexing.
pub(crate) fn diag_hv(
    diag: Vec<f64>,
) -> impl FnMut(&Array1<f64>) -> Result<Array1<f64>, String> {
    move |v: &Array1<f64>| {
        let mut product = Array1::<f64>::zeros(v.len());
        diag.iter()
            .enumerate()
            .for_each(|(idx, &scale)| product[idx] = scale * v[idx]);
        Ok(product)
    }
}
/// A 200-dimensional diagonal spectrum confined to [200, 250] must be
/// reported skippable by the cheap matvec pre-check.
#[test]
pub(crate) fn cheap_precheck_skips_clearly_well_conditioned_large_p() {
    let p = 200usize;
    let diag: Vec<f64> = (0..p)
        .map(|i| match i {
            0 => 200.0,
            1 => 250.0,
            _ => 220.0,
        })
        .collect();
    let skippable = jeffreys_term_skippable_via_matvec(diag_hv(diag), p).unwrap();
    assert!(
        skippable,
        "clearly well-conditioned wide fit must be skippable"
    );
}
/// One near-zero eigenvalue (1e-3 among 50s) must stop the cheap pre-check
/// from skipping the Jeffreys term.
#[test]
pub(crate) fn cheap_precheck_does_not_skip_near_separating() {
    let p = 200usize;
    let diag: Vec<f64> = (0..p)
        .map(|i| if i == 7 { 1e-3 } else { 50.0 })
        .collect();
    let skippable = jeffreys_term_skippable_via_matvec(diag_hv(diag), p).unwrap();
    assert!(
        !skippable,
        "a near-separating direction must NOT be skipped (term is needed)"
    );
}
/// Below `CHEAP_CONDITIONING_PRECHECK_MIN_DIM` the pre-check must never skip,
/// even for a perfectly conditioned spectrum.
#[test]
pub(crate) fn cheap_precheck_does_not_skip_below_size_threshold() {
    let p = CHEAP_CONDITIONING_PRECHECK_MIN_DIM - 1;
    let diag: Vec<f64> = (0..p).map(|_| 100.0).collect();
    let skippable = jeffreys_term_skippable_via_matvec(diag_hv(diag), p).unwrap();
    assert!(
        !skippable,
        "below the size threshold the pre-check never skips"
    );
}
/// A smallest eigenvalue of 2.0 sits inside the pre-check's conservative
/// absolute margin, so the cheap check must fall through rather than skip.
#[test]
pub(crate) fn cheap_precheck_does_not_skip_marginal_absolute() {
    let p = 200usize;
    let diag: Vec<f64> = (0..p)
        .map(|i| if i == 3 { 2.0 } else { 50.0 })
        .collect();
    let skippable = jeffreys_term_skippable_via_matvec(diag_hv(diag), p).unwrap();
    assert!(
        !skippable,
        "λ_min within the 8× absolute margin must conservatively fall through"
    );
}
/// Soundness of the cheap pre-check: whenever it reports "skippable", the
/// exact `joint_jeffreys_term` must return identically zero value, gradient
/// and curvature for the same diagonal information matrix.
#[test]
pub(crate) fn cheap_precheck_skip_implies_exact_gate_skips() {
    let p = 150usize;
    let z = Array2::<f64>::eye(p);
    let hdir = |_d: &Array1<f64>| -> Result<Option<Array2<f64>>, String> {
        Ok(Some(Array2::<f64>::zeros((p, p))))
    };
    for lmin in [10.0_f64, 25.0, 80.0, 200.0] {
        // Spectrum: one eigenvalue at lmin, the rest at 4·lmin.
        let mut diag = vec![lmin * 4.0; p];
        diag[0] = lmin;
        if !jeffreys_term_skippable_via_matvec(diag_hv(diag.clone()), p).unwrap() {
            continue;
        }
        let h = Array2::from_diag(&Array1::from(diag));
        let (phi, grad, hphi) = joint_jeffreys_term(h.view(), z.view(), hdir).unwrap();
        assert_eq!(
            phi, 0.0,
            "cheap-skip ⇒ exact phi must be zero (byte-identical)"
        );
        assert!(grad.iter().all(|v| *v == 0.0));
        assert!(hphi.iter().all(|v| *v == 0.0));
    }
}
/// A Hessian-vector product that returns all-NaN output must never be
/// treated as evidence that the Jeffreys term can be skipped.
#[test]
pub(crate) fn cheap_precheck_bails_on_nonfinite_matvec() {
    let p = 200usize;
    let nan_hv = |v: &Array1<f64>| -> Result<Array1<f64>, String> {
        Ok(Array1::from_elem(v.len(), f64::NAN))
    };
    let skippable = jeffreys_term_skippable_via_matvec(nan_hv, p).unwrap();
    assert!(!skippable);
}
/// Checks that `jeffreys_antiderivative`, `floored_inverse` and
/// `jeffreys_antiderivative_floor_sensitivity` form a consistent triple.
///
/// - The slope in λ must equal `floored_inverse`.
/// - The slope in the floor must equal the floor sensitivity.
///
/// Both are checked by central differences, with one probe in each branch.
/// An extra case covers a floor large enough that `jeffreys_cap` is bound to
/// the floor.
#[test]
pub(crate) fn jeffreys_antiderivative_is_consistent_value_slope_floor_triple() {
    let floor = 1e-3_f64;
    let cap = jeffreys_cap(floor);
    // Central difference of `f` at `x` with half-step `h`.
    let central =
        |f: &dyn Fn(f64) -> f64, x: f64, h: f64| -> f64 { (f(x + h) - f(x - h)) / (2.0 * h) };

    // Probes: above the cap, between floor and cap, in [0, floor), and negative.
    for &lam in &[cap * 4.0, (floor + cap) * 0.5, floor * 0.5, -0.7_f64] {
        let hl = 1e-7 * lam.abs().max(1e-3);
        let fd_lam = central(&|l: f64| jeffreys_antiderivative(l, floor), lam, hl);
        let dl = floored_inverse(lam, floor);
        assert!(
            (fd_lam - dl).abs() <= 1e-4 * dl.abs().max(1.0),
            "∂g/∂λ desync at λ={lam}: fd={fd_lam} analytic={dl}"
        );

        let hf = 1e-7 * floor;
        let fd_floor = central(&|f: f64| jeffreys_antiderivative(lam, f), floor, hf);
        let df = jeffreys_antiderivative_floor_sensitivity(lam, floor);
        assert!(
            (fd_floor - df).abs() <= 1e-4 * df.abs().max(1.0),
            "∂g/∂floor desync at λ={lam}: fd={fd_floor} analytic={df}"
        );
    }

    // Floor above CONDITIONING_GATE_ABSOLUTE_CLEAR: the cap equals the floor,
    // so the top branch has a nonzero floor sensitivity.
    let big_floor = CONDITIONING_GATE_ABSOLUTE_CLEAR * 10.0;
    assert!((jeffreys_cap(big_floor) - big_floor).abs() < 1e-12);
    let lam_top = big_floor * 3.0;
    let hf = 1e-7 * big_floor;
    let fd_floor = central(&|f: f64| jeffreys_antiderivative(lam_top, f), big_floor, hf);
    let df = jeffreys_antiderivative_floor_sensitivity(lam_top, big_floor);
    assert!(
        df != 0.0 && (fd_floor - df).abs() <= 1e-4 * df.abs().max(1.0),
        "floor-bound-cap ∂g/∂floor desync: fd={fd_floor} analytic={df}"
    );
}
}