use super::*;
use super::tests::{fixed_state_logdet, gamma_fd_tiny_fixture};
/// PD-region deflation regression for the learnable IBP strength trace.
///
/// Switches the tiny γ fixture to `AssignmentMode::ibp_map(0.7, 0.9, true)` at
/// `log_lambda_sparse = 0.5` and requires that the converged cache reports at least one
/// gauge-deflated direction, with some entry of `deflated_row_directions` non-empty.
/// It then compares the analytic trace `prior_trace + data_trace` against a central
/// finite difference of `½ · fixed_state_logdet` in `log_lambda_sparse`, which the
/// failure message labels `logα`. The check uses a relative tolerance of `1e-6`.
#[test]
pub(crate) fn learnable_ibp_alpha_logdet_trace_matches_dense_fd_pd_region_deflation() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, true);
    rho.log_lambda_sparse = 0.5;

    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");

    // Without an active deflation this test would not exercise the regression at all.
    let deflated_count = cache.gauge_deflated_directions;
    assert!(
        deflated_count > 0,
        "the PD-region deflation regression requires a deflated direction; got \
         {} (fixture no longer deflates — re-pick ρ)",
        deflated_count
    );
    let surfaced = cache
        .deflated_row_directions
        .iter()
        .any(|dirs| !dirs.is_empty());
    assert!(surfaced, "deflated directions were not surfaced into the cache");

    // Analytic side: prior-Hessian and data-Hessian contributions, summed.
    let solver = DeflatedArrowSolver::plain(&cache);
    let prior_trace = term
        .assignment_log_strength_hessian_trace(&rho, &cache, &solver)
        .expect("prior-Hessian alpha trace");
    let data_trace = term
        .learnable_ibp_data_logdet_alpha_trace(&rho, &cache, &solver)
        .expect("data-Hessian alpha trace");
    let analytic = prior_trace + data_trace;

    // Numeric side: symmetric ±step perturbation of the log-strength, with the logdet
    // evaluated through `fixed_state_logdet` on a fresh clone of the term each time.
    let step = 1.0e-5;
    let mut rho_up = rho.clone();
    rho_up.log_lambda_sparse += step;
    let mut rho_down = rho.clone();
    rho_down.log_lambda_sparse -= step;
    let logdet_up = fixed_state_logdet(term.clone(), &target, &rho_up);
    let logdet_down = fixed_state_logdet(term.clone(), &target, &rho_down);
    let fd_half = 0.5 * (logdet_up - logdet_down) / (2.0 * step);

    let gap = (fd_half - analytic).abs();
    let tol = 1.0e-6 * (1.0 + fd_half.abs().max(analytic.abs()));
    assert!(
        gap <= tol,
        "PD-region deflation logdet trace: fd(½∂log|H|/∂logα)={fd_half:.8e}, \
         analytic(prior+data)={analytic:.8e} (prior={prior_trace:.6e}, \
         data={data_trace:.6e}), gap={:.6e} > tol={tol:.6e}",
        gap
    );
}
/// ARD counterpart of the PD-region deflation regression.
///
/// Uses the same fixture setup as the learnable-IBP variant: `ibp_map(0.7, 0.9, true)`
/// at `log_lambda_sparse = 0.5`, with at least one gauge-deflated direction required.
/// Every entry `analytic[atom][axis]` of `ard_log_precision_hessian_trace` must then match
/// a central finite difference of `½ · fixed_state_logdet` in
/// `rho.log_ard[atom][axis]`, within a relative tolerance of `5e-3`. The test fails if
/// there are no ARD coordinates to check.
#[test]
pub(crate) fn ard_log_precision_trace_matches_dense_fd_pd_region_deflation() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, true);
    rho.log_lambda_sparse = 0.5;

    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let deflated_count = cache.gauge_deflated_directions;
    assert!(
        deflated_count > 0,
        "ARD deflation regression requires a deflated direction; got {}",
        deflated_count
    );

    let solver = DeflatedArrowSolver::plain(&cache);
    let analytic = term
        .ard_log_precision_hessian_trace(&rho, &cache, &solver)
        .expect("ARD log-precision trace");

    // Enumerate every (atom, axis) slot of the ARD log-precisions up front.
    let coordinates: Vec<(usize, usize)> = (0..rho.log_ard.len())
        .flat_map(|atom| (0..rho.log_ard[atom].len()).map(move |axis| (atom, axis)))
        .collect();

    let step = 1.0e-5;
    for &(atom, axis) in &coordinates {
        let mut rho_up = rho.clone();
        rho_up.log_ard[atom][axis] += step;
        let mut rho_down = rho.clone();
        rho_down.log_ard[atom][axis] -= step;

        let logdet_up = fixed_state_logdet(term.clone(), &target, &rho_up);
        let logdet_down = fixed_state_logdet(term.clone(), &target, &rho_down);
        let fd_half = 0.5 * (logdet_up - logdet_down) / (2.0 * step);

        let a = analytic[atom][axis];
        let gap = (fd_half - a).abs();
        let tol = 5.0e-3 * (1.0 + fd_half.abs().max(a.abs()));
        assert!(
            gap <= tol,
            "ARD trace atom={atom} axis={axis}: fd={fd_half:.8e} analytic={a:.8e} \
             gap={:.6e} tol={tol:.6e}",
            gap
        );
    }
    assert!(!coordinates.is_empty(), "no ARD axes were checked");
}
/// Regression test for the ungated-atom learnable-α trace (suffix `1026`, presumably an
/// issue number).
///
/// Marks atom 1 as ungated via `with_ungated(vec![false, true])` in
/// `ibp_map(0.7, 0.9, true)` mode. It then scans a fixed list of `log_lambda_sparse`
/// candidates for the first one at which `reml_criterion_with_cache` succeeds, and
/// requires the prior-Hessian, data-Hessian and summed alpha traces at that point to be
/// finite.
///
/// NOTE(review): despite the `zeroes_ungated_atom` name, only finiteness is asserted here.
#[test]
pub(crate) fn learnable_ibp_data_logdet_trace_zeroes_ungated_atom_1026() {
    // Candidate log-strengths, tried in this order; the first successful solve wins.
    const CANDIDATES: [f64; 8] = [1.0, 1.5, 2.0, 2.5, 3.0, 0.5, 0.0, -0.5];

    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, true);
    term.assignment = term
        .assignment
        .clone()
        .with_ungated(vec![false, true])
        .unwrap();

    // Each probe runs on throwaway clones so `term`/`rho` stay untouched by the search.
    let chosen = CANDIDATES
        .iter()
        .copied()
        .find(|&candidate| {
            let mut probe = term.clone();
            let mut trial_rho = rho.clone();
            trial_rho.log_lambda_sparse = candidate;
            probe
                .reml_criterion_with_cache(
                    target.view(),
                    &trial_rho,
                    None,
                    5,
                    0.4,
                    1.0e-6,
                    1.0e-6,
                )
                .is_ok()
        })
        .expect("no PD-region ρ found for the ungated learnable-α fixture");
    rho.log_lambda_sparse = chosen;

    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache at the PD ρ");
    let solver = DeflatedArrowSolver::plain(&cache);

    let prior_trace = term
        .assignment_log_strength_hessian_trace(&rho, &cache, &solver)
        .expect("prior-Hessian alpha trace");
    let data_trace = term
        .learnable_ibp_data_logdet_alpha_trace(&rho, &cache, &solver)
        .expect("data-Hessian alpha trace");
    let analytic = prior_trace + data_trace;

    let all_finite = prior_trace.is_finite() && data_trace.is_finite() && analytic.is_finite();
    assert!(
        all_finite,
        "ungated learnable-α traces must be finite: prior={prior_trace}, \
         data={data_trace}, analytic={analytic}"
    );
}