use super::*;
use super::tests::{fixed_state_logdet, gamma_fd_tiny_fixture};
/// Builds the fixture shared by the `*_1416` exact-oracle tests.
///
/// The term holds a single atom named `"ibp1416"` (periodic basis kind, all-zero
/// arrays plus an identity matrix, with a `PeriodicHarmonicEvaluator` attached as
/// its basis second jet) over two observations. The assignment uses logits
/// `[0.2, -0.4]`, coordinates `[0.15, 0.65]` on a circle of period `1.0`, and
/// `AssignmentMode::ibp_map(0.8, 1.8, false)`.
///
/// Returns the term together with a `SaeManifoldRho` built from zeros.
///
/// # Panics
///
/// Panics if any of the underlying constructors rejects its arguments.
fn ibp_1416_oracle_term() -> (SaeManifoldTerm, SaeManifoldRho) {
    // Sizes used for the array shapes below.
    const ROWS: usize = 2;
    const WIDTH: usize = 1;
    const BASIS: usize = 3;

    let second_jet = std::sync::Arc::new(PeriodicHarmonicEvaluator::new(BASIS).unwrap());
    let bare_atom = SaeManifoldAtom::new(
        "ibp1416",
        SaeAtomBasisKind::Periodic,
        1,
        Array2::<f64>::zeros((ROWS, BASIS)),
        Array3::<f64>::zeros((ROWS, BASIS, 1)),
        Array2::<f64>::zeros((BASIS, WIDTH)),
        Array2::<f64>::eye(BASIS),
    )
    .unwrap();
    let atom = bare_atom.with_basis_second_jet(second_jet);

    let row_coords = Array2::from_shape_vec((ROWS, 1), vec![0.15_f64, 0.65_f64]).unwrap();
    let row_logits = Array2::from_shape_vec((ROWS, 1), vec![0.2_f64, -0.4_f64]).unwrap();
    let manifolds = vec![LatentManifold::Circle { period: 1.0 }];
    let mode = AssignmentMode::ibp_map(0.8, 1.8, false);
    let assignment = SaeAssignment::from_blocks_with_mode_and_manifolds(
        row_logits,
        vec![row_coords],
        manifolds,
        mode,
    )
    .unwrap();

    (
        SaeManifoldTerm::new(vec![atom], assignment).unwrap(),
        SaeManifoldRho::new(0.0, 0.0, vec![Array1::from_vec(vec![0.0])]),
    )
}
/// Factors a hand-assembled arrow system for the `1416` oracle fixture.
///
/// The system is created with `ArrowSchurSystem::new(n, 1, 0)` where `n` is
/// `term.n_obs()`. Only `htt[[0, 0]]` of each row is written: the constant `1.2`
/// plus that row's value from `assignment_prior_log_strength_hdiag`. The IBP
/// cross-row source uses `r: 1`, a clone of the `cross_row_d` channel, and one
/// `(i, 0, z_jac[i])` entry per observation.
///
/// Returns the factor cache produced by a direct solve with ill-conditioning
/// tolerated; the step vectors from that solve are discarded.
///
/// # Panics
///
/// Panics if a helper returns an error, if the assignment yields no cross-row
/// channels, or if the factorization fails.
fn ibp_1416_oracle_cache(term: &SaeManifoldTerm, rho: &SaeManifoldRho) -> ArrowFactorCache {
    const DATA_CURVATURE: f64 = 1.2;

    let n_rows = term.n_obs();
    let third = ibp_assignment_third_channels(&term.assignment, rho)
        .expect("channels")
        .expect("IBP mode must yield cross-row channels");
    let prior_diag = assignment_prior_log_strength_hdiag(&term.assignment, rho).expect("hdiag");

    let mut system = ArrowSchurSystem::new(n_rows, 1, 0);
    for i in 0..n_rows {
        system.rows[i].htt[[0, 0]] = DATA_CURVATURE + prior_diag[i];
    }

    let mut entries: Vec<(usize, usize, f64)> = Vec::with_capacity(n_rows);
    for i in 0..n_rows {
        entries.push((i, 0usize, third.z_jac[i]));
    }
    system.set_ibp_cross_row_source(IbpCrossRowSource {
        r: 1,
        d: third.cross_row_d.clone(),
        entries,
    });

    let solve_options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let (_step_t, _step_b, factor) =
        solve_arrow_newton_step_with_options(&system, 0.0, 0.0, &solve_options)
            .expect("factor H");
    factor
}
/// Variant of [`ibp_1416_oracle_cache`] with two local slots per row and a border.
///
/// The system is created with `ArrowSchurSystem::new(n, 2, border_dim)` where
/// `border_dim` is `term.factored_border_dim()`. The `hbb` diagonal is set to
/// `1.0`. In every row, `htt[[0, 0]]` is `1.2` plus that row's value from
/// `assignment_prior_log_strength_hdiag`, and `htt[[1, 1]]` is `1.0`. Cross-row
/// entries are `(2 * i, 0, z_jac[i])`, i.e. they index the first slot of each
/// two-slot row.
///
/// Returns the factor cache produced by a direct solve with ill-conditioning
/// tolerated.
///
/// # Panics
///
/// Panics if a helper returns an error, if the assignment yields no cross-row
/// channels, or if the factorization fails.
fn ibp_1416_oracle_cache_with_coord(
    term: &SaeManifoldTerm,
    rho: &SaeManifoldRho,
) -> ArrowFactorCache {
    const DATA_CURVATURE: f64 = 1.2;

    let n_rows = term.n_obs();
    let third = ibp_assignment_third_channels(&term.assignment, rho)
        .expect("channels")
        .expect("IBP mode must yield cross-row channels");
    let prior_diag = assignment_prior_log_strength_hdiag(&term.assignment, rho).expect("hdiag");
    let border = term.factored_border_dim();

    let mut system = ArrowSchurSystem::new(n_rows, 2, border);
    for k in 0..border {
        system.hbb[[k, k]] = 1.0;
    }
    for i in 0..n_rows {
        system.rows[i].htt[[0, 0]] = DATA_CURVATURE + prior_diag[i];
        system.rows[i].htt[[1, 1]] = 1.0;
    }

    let mut entries: Vec<(usize, usize, f64)> = Vec::with_capacity(n_rows);
    for i in 0..n_rows {
        entries.push((2 * i, 0usize, third.z_jac[i]));
    }
    system.set_ibp_cross_row_source(IbpCrossRowSource {
        r: 1,
        d: third.cross_row_d.clone(),
        entries,
    });

    let solve_options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let (_step_t, _step_b, factor) =
        solve_arrow_newton_step_with_options(&system, 0.0, 0.0, &solve_options)
            .expect("factor H");
    factor
}
/// Pins `assignment_log_strength_hessian_trace` on the oracle fixture to a
/// hard-coded reference value, within an absolute tolerance of `1e-7`.
///
/// The failure message also records the value the pre-#1416 diagonal-only
/// computation produced, so a regression to that behavior is recognizable.
#[test]
pub(crate) fn ibp_rho_trace_matches_exact_numerical_oracle_1416() {
    const ORACLE: f64 = -0.1609707929;

    let (term, rho) = ibp_1416_oracle_term();
    let factor = ibp_1416_oracle_cache(&term, &rho);
    let plain = DeflatedArrowSolver::plain(&factor);
    let analytic = term
        .assignment_log_strength_hessian_trace(&rho, &factor, &plain)
        .expect("rho-trace");

    let gap = (analytic - ORACLE).abs();
    assert!(
        gap <= 1.0e-7,
        "IBP ρ-trace exact oracle: analytic={analytic:.10e}, oracle={ORACLE:.10e} \
         (diagonal-only pre-#1416 bug returns -0.1436656628)"
    );
}
/// Checks the logit entry of `logdet_theta_adjoint` for observation index 1 on
/// the oracle fixture in two ways.
///
/// 1. Against a hard-coded reference value (absolute tolerance `1e-7`).
/// 2. Against a central finite difference (step `1e-6`, absolute tolerance
///    `1e-5`) of the log-determinant reported by `arrow_log_det` when
///    `logits[[1, 0]]` is shifted and the oracle system is rebuilt and refactored.
#[test]
pub(crate) fn ibp_logit_adjoint_matches_exact_numerical_oracle_1416() {
    const ORACLE: f64 = -0.0498935387;

    let (term, rho) = ibp_1416_oracle_term();
    let factor = ibp_1416_oracle_cache_with_coord(&term, &rho);
    let plain = DeflatedArrowSolver::plain(&factor);
    let adjoint = term
        .logdet_theta_adjoint(&rho, &factor, &plain)
        .expect("theta-adjoint");

    // First local slot of the row at index 1.
    let analytic = adjoint.t[factor.row_offsets[1]];
    let oracle_gap = (analytic - ORACLE).abs();
    assert!(
        oracle_gap <= 1.0e-7,
        "IBP logit adjoint exact oracle ∂/∂ℓ_2 log|H|: analytic={analytic:.10e}, \
         oracle={ORACLE:.10e} (diagonal-only pre-#1416 bug returns -0.0355527958)"
    );

    // Log-determinant of the oracle system rebuilt after shifting logits[[1, 0]].
    let logdet_at_shift = |shift: f64| -> f64 {
        let mut shifted = term.clone();
        shifted.assignment.logits[[1, 0]] += shift;
        let shifted_factor = ibp_1416_oracle_cache_with_coord(&shifted, &rho);
        let (tt_part, beta_part) = shifted_factor.arrow_log_det();
        tt_part + beta_part.unwrap_or(0.0)
    };

    let h = 1.0e-6;
    let upper = logdet_at_shift(h);
    let lower = logdet_at_shift(-h);
    let fd = (upper - lower) / (2.0 * h);
    let fd_gap = (fd - analytic).abs();
    assert!(
        fd_gap <= 1.0e-5,
        "IBP logit adjoint vs FD of log|H|: fd={fd:.8e}, analytic={analytic:.8e}"
    );
}
/// Compares `logdet_theta_adjoint` on the tiny fixture (with
/// `log_lambda_sparse = 0.5`) against central finite differences of
/// `fixed_state_logdet`.
///
/// One logit probe and one coordinate probe are checked. Each probe is a
/// `(row, local_pos, variable)` triple; the analytic value is read from
/// `gamma.t` at `row_offsets[row] + local_pos`. The tolerance is relative:
/// `2e-3 * (1 + max(|fd|, |analytic|))`.
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_on_tiny_fixture() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    rho.log_lambda_sparse = 0.5;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let plain = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &plain)
        .expect("Gamma");

    let h = 1.0e-5;
    let probes = [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (3usize, 1usize, SaeLocalRowVar::Coord { atom: 0, axis: 0 }),
    ];
    for (row, local_pos, var) in probes {
        let mut plus = term.clone();
        let mut minus = term.clone();
        match var {
            SaeLocalRowVar::Logit { atom } => {
                for (perturbed, delta) in [(&mut plus, h), (&mut minus, -h)] {
                    perturbed.assignment.logits[[row, atom]] += delta;
                }
            }
            SaeLocalRowVar::Coord { atom, axis } => {
                // Position of (row, axis) inside the flattened coordinate buffer.
                let flat_index = row * plus.assignment.coords[atom].latent_dim() + axis;
                for (perturbed, delta) in [(&mut plus, h), (&mut minus, -h)] {
                    let mut flat = perturbed.assignment.coords[atom].as_flat().clone();
                    flat[flat_index] += delta;
                    perturbed.assignment.coords[atom].set_flat(flat.view());
                }
            }
        }

        let upper = fixed_state_logdet(plus, &target, &rho);
        let lower = fixed_state_logdet(minus, &target, &rho);
        let fd = (upper - lower) / (2.0 * h);
        let analytic = gamma.t[cache.row_offsets[row] + local_pos];

        let magnitude = fd.abs().max(analytic.abs());
        let tol = 2.0e-3 * (1.0 + magnitude);
        let gap = (fd - analytic).abs();
        assert!(
            gap <= tol,
            "Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Compares `logdet_theta_adjoint` against central finite differences of
/// `fixed_state_logdet` with the tiny fixture switched to
/// `AssignmentMode::ibp_map(0.7, 0.9, false)` and `log_lambda_sparse = -1.0`.
///
/// Three logit probes and two coordinate probes are checked. Each probe is a
/// `(row, local_pos, variable)` triple; the analytic value is read from
/// `gamma.t` at `row_offsets[row] + local_pos`. The tolerance is relative:
/// `3e-3 * (1 + max(|fd|, |analytic|))`.
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_ibp_map() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -1.0;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let plain = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &plain)
        .expect("Gamma");

    let h = 1.0e-5;
    let probes = [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (4usize, 1usize, SaeLocalRowVar::Logit { atom: 1 }),
        (7usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (1usize, 2usize, SaeLocalRowVar::Coord { atom: 0, axis: 0 }),
        (6usize, 3usize, SaeLocalRowVar::Coord { atom: 1, axis: 0 }),
    ];
    for (row, local_pos, var) in probes {
        let mut plus = term.clone();
        let mut minus = term.clone();
        match var {
            SaeLocalRowVar::Logit { atom } => {
                for (perturbed, delta) in [(&mut plus, h), (&mut minus, -h)] {
                    perturbed.assignment.logits[[row, atom]] += delta;
                }
            }
            SaeLocalRowVar::Coord { atom, axis } => {
                // Position of (row, axis) inside the flattened coordinate buffer.
                let flat_index = row * plus.assignment.coords[atom].latent_dim() + axis;
                for (perturbed, delta) in [(&mut plus, h), (&mut minus, -h)] {
                    let mut flat = perturbed.assignment.coords[atom].as_flat().clone();
                    flat[flat_index] += delta;
                    perturbed.assignment.coords[atom].set_flat(flat.view());
                }
            }
        }

        let upper = fixed_state_logdet(plus, &target, &rho);
        let lower = fixed_state_logdet(minus, &target, &rho);
        let fd = (upper - lower) / (2.0 * h);
        let analytic = gamma.t[cache.row_offsets[row] + local_pos];

        let magnitude = fd.abs().max(analytic.abs());
        let tol = 3.0e-3 * (1.0 + magnitude);
        let gap = (fd - analytic).abs();
        assert!(
            gap <= tol,
            "IBP Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Compares `assignment_log_strength_hessian_trace` in
/// `AssignmentMode::ibp_map(0.7, 0.9, false)` against half of the central finite
/// difference of `fixed_state_logdet` with respect to `log_lambda_sparse`.
///
/// The base point is `log_lambda_sparse = -0.8`, the step is `1e-5`, and the
/// tolerance is relative: `3e-3 * (1 + max(|fd_half|, |analytic|))`.
#[test]
pub(crate) fn ibp_rho_sparse_logdet_trace_matches_dense_fd_1416() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -0.8;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let plain = DeflatedArrowSolver::plain(&cache);
    let analytic = term
        .assignment_log_strength_hessian_trace(&rho, &cache, &plain)
        .expect("rho_sparse logdet trace");

    let h = 1.0e-5;
    // Copy of `rho` with `log_lambda_sparse` shifted by `delta`.
    let rho_shifted_by = |delta: f64| {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse += delta;
        shifted
    };
    let rho_plus = rho_shifted_by(h);
    let rho_minus = rho_shifted_by(-h);

    let upper = fixed_state_logdet(term.clone(), &target, &rho_plus);
    let lower = fixed_state_logdet(term.clone(), &target, &rho_minus);
    let fd_half = 0.5 * (upper - lower) / (2.0 * h);

    let magnitude = fd_half.abs().max(analytic.abs());
    let tol = 3.0e-3 * (1.0 + magnitude);
    let gap = (fd_half - analytic).abs();
    assert!(
        gap <= tol,
        "IBP ρ_sparse logdet trace: fd(½∂log|H|/∂ρ)={fd_half:.8e}, \
         analytic={analytic:.8e}"
    );
}
/// With `AssignmentMode::ibp_map(0.7, 0.9, true)`, checks that the sum of
/// `assignment_log_strength_hessian_trace` and
/// `learnable_ibp_data_logdet_alpha_trace` matches half of the central finite
/// difference of `fixed_state_logdet` with respect to `log_lambda_sparse`.
///
/// The base point is `log_lambda_sparse = 0.5`, the step is `1e-5`, and the
/// tolerance is relative: `3e-3 * (1 + max(|fd_half|, |analytic|))`. A second
/// assertion requires the data trace to exceed `1e-9` in magnitude, so the test
/// cannot pass with that term silently zero.
#[test]
pub(crate) fn learnable_ibp_alpha_logdet_trace_matches_dense_fd_1417() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, true);
    rho.log_lambda_sparse = 0.5;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let plain = DeflatedArrowSolver::plain(&cache);

    let prior_trace = term
        .assignment_log_strength_hessian_trace(&rho, &cache, &plain)
        .expect("prior-Hessian alpha trace");
    let data_trace = term
        .learnable_ibp_data_logdet_alpha_trace(&rho, &cache, &plain)
        .expect("data-Hessian alpha trace");
    let analytic = prior_trace + data_trace;

    let h = 1.0e-5;
    // Copy of `rho` with `log_lambda_sparse` shifted by `delta`.
    let rho_shifted_by = |delta: f64| {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse += delta;
        shifted
    };
    let rho_plus = rho_shifted_by(h);
    let rho_minus = rho_shifted_by(-h);

    let upper = fixed_state_logdet(term.clone(), &target, &rho_plus);
    let lower = fixed_state_logdet(term.clone(), &target, &rho_minus);
    let fd_half = 0.5 * (upper - lower) / (2.0 * h);

    let magnitude = fd_half.abs().max(analytic.abs());
    let tol = 3.0e-3 * (1.0 + magnitude);
    let gap = (fd_half - analytic).abs();
    assert!(
        gap <= tol,
        "learnable-α logdet trace: fd(½∂log|H|/∂logα)={fd_half:.8e}, \
         analytic(prior+data)={analytic:.8e} (prior={prior_trace:.6e}, \
         data={data_trace:.6e})"
    );

    let data_magnitude = data_trace.abs();
    assert!(
        data_magnitude > 1.0e-9,
        "the #1417 data-Hessian alpha trace must be a live nonzero term; got \
         {data_trace:.3e}"
    );
}
/// With `AssignmentMode::ibp_map(0.7, 0.9, true)` and `log_lambda_sparse = 0.6`,
/// compares the logit entries of `logdet_theta_adjoint` against central finite
/// differences of `fixed_state_logdet`.
///
/// Each probe is a `(row, local_pos, atom)` triple: `logits[[row, atom]]` is
/// perturbed by `±1e-5` and the analytic value is read from `gamma.t` at
/// `row_offsets[row] + local_pos`. The tolerance is relative:
/// `3e-3 * (1 + max(|fd|, |analytic|))`.
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_ibp_map_learnable_alpha_1625() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, true);
    rho.log_lambda_sparse = 0.6;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-8, 1.0e-8)
        .expect("converged learnable-α cache");
    let plain = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &plain)
        .expect("Gamma");

    let h = 1.0e-5;
    let probes = [
        (0usize, 0usize, 0usize),
        (4usize, 1usize, 1usize),
        (7usize, 0usize, 0usize),
    ];
    for (row, local_pos, atom) in probes {
        let mut plus = term.clone();
        let mut minus = term.clone();
        for (perturbed, delta) in [(&mut plus, h), (&mut minus, -h)] {
            perturbed.assignment.logits[[row, atom]] += delta;
        }

        let upper = fixed_state_logdet(plus, &target, &rho);
        let lower = fixed_state_logdet(minus, &target, &rho);
        let fd = (upper - lower) / (2.0 * h);
        let analytic = gamma.t[cache.row_offsets[row] + local_pos];

        let magnitude = fd.abs().max(analytic.abs());
        let tol = 3.0e-3 * (1.0 + magnitude);
        let gap = (fd - analytic).abs();
        assert!(
            gap <= tol,
            "learnable-α IBP Gamma row={row} local_pos={local_pos}: \
             fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Checks that `assignment_log_strength_hessian_trace` gives the same value for
/// a cache assembled through an explicit `SaeRowLayout` listing atoms `0` and `1`
/// as active in every row as it does for the cache returned by
/// `reml_criterion_with_cache`.
///
/// Two assertions are made:
/// 1. The two analytic traces agree within `1e-7 * (1 + |dense|)`.
/// 2. The layout-based trace matches half of the central finite difference of
///    `fixed_state_logdet` with respect to `log_lambda_sparse` (step `1e-5`)
///    within `3e-3 * (1 + max(|fd_half|, |compact|))`.
#[test]
pub(crate) fn ibp_rho_sparse_logdet_trace_compact_layout_matches_dense_1416() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -0.8;

    // Reference trace from the cache produced by the criterion evaluation.
    let (_value, _loss, dense_cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-6, 1.0e-6)
        .expect("dense converged cache");
    let dense_solver = DeflatedArrowSolver::plain(&dense_cache);
    let analytic_dense = term
        .assignment_log_strength_hessian_trace(&rho, &dense_cache, &dense_solver)
        .expect("dense rho_sparse trace");

    // Same system assembled through a layout in which every row lists both atoms.
    let n_rows = target.nrows();
    let every_row_active: Vec<Vec<usize>> = vec![vec![0usize, 1usize]; n_rows];
    let layout = SaeRowLayout::from_active_atoms(
        every_row_active,
        vec![1usize; 2],
        term.assignment.coord_offsets(),
    );
    let system = term
        .assemble_arrow_schur_inner(
            target.view(),
            &rho,
            None,
            1.0,
            SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM,
            Some(Some(layout)),
        )
        .expect("full-support compact assembly");
    let solve_options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let (_step_t, _step_b, compact_cache) =
        solve_arrow_newton_step_with_options(&system, 0.0, 0.0, &solve_options)
            .expect("compact factor");
    let compact_solver = DeflatedArrowSolver::plain(&compact_cache);
    let analytic_compact = term
        .assignment_log_strength_hessian_trace(&rho, &compact_cache, &compact_solver)
        .expect("compact rho_sparse trace");

    let struct_tol = 1.0e-7 * (1.0 + analytic_dense.abs());
    let struct_gap = (analytic_dense - analytic_compact).abs();
    assert!(
        struct_gap <= struct_tol,
        "compact-layout IBP ρ_sparse logdet trace must equal the dense trace on \
         full support: dense={analytic_dense:.10e}, compact={analytic_compact:.10e}"
    );

    let h = 1.0e-5;
    // Copy of `rho` with `log_lambda_sparse` shifted by `delta`.
    let rho_shifted_by = |delta: f64| {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse += delta;
        shifted
    };
    let rho_plus = rho_shifted_by(h);
    let rho_minus = rho_shifted_by(-h);

    let upper = fixed_state_logdet(term.clone(), &target, &rho_plus);
    let lower = fixed_state_logdet(term.clone(), &target, &rho_minus);
    let fd_half = 0.5 * (upper - lower) / (2.0 * h);

    let magnitude = fd_half.abs().max(analytic_compact.abs());
    let fd_tol = 3.0e-3 * (1.0 + magnitude);
    let fd_gap = (fd_half - analytic_compact).abs();
    assert!(
        fd_gap <= fd_tol,
        "compact-layout IBP ρ_sparse logdet trace vs dense FD: \
         fd(½∂log|H|/∂ρ)={fd_half:.8e}, compact analytic={analytic_compact:.8e}"
    );
}