use super::*;
use super::tests::{fixed_state_logdet, gamma_fd_tiny_fixture};
/// Compares selected entries of the `logdet_theta_adjoint` result against a
/// central finite difference of `fixed_state_logdet` on the tiny fixture.
///
/// Each probe is `(row, local_pos, var)`: the row to perturb, the position of
/// the variable inside that row's slice of `gamma.t` (located through
/// `cache.row_offsets`), and which logit or coordinate entry is nudged.
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_on_tiny_fixture() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    rho.log_lambda_sparse = 0.5;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let arrow_solver = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &arrow_solver)
        .expect("Gamma");

    // Half-width of the central difference stencil.
    let step = 1.0e-5;
    for (row, local_pos, var) in [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (3usize, 1usize, SaeLocalRowVar::Coord { atom: 0, axis: 0 }),
    ] {
        // Two independent copies of the term, nudged in opposite directions.
        let mut term_up = term.clone();
        let mut term_down = term.clone();
        match var {
            SaeLocalRowVar::Logit { atom } => {
                term_up.assignment.logits[[row, atom]] += step;
                term_down.assignment.logits[[row, atom]] -= step;
            }
            SaeLocalRowVar::Coord { atom, axis } => {
                // Coordinates are edited through their flat representation:
                // copy it out, shift one entry, write it back.
                let flat_index = row * term_up.assignment.coords[atom].latent_dim() + axis;

                let mut shifted_up = term_up.assignment.coords[atom].as_flat().clone();
                shifted_up[flat_index] += step;
                term_up.assignment.coords[atom].set_flat(shifted_up.view());

                let mut shifted_down = term_down.assignment.coords[atom].as_flat().clone();
                shifted_down[flat_index] -= step;
                term_down.assignment.coords[atom].set_flat(shifted_down.view());
            }
        }

        let logdet_up = fixed_state_logdet(term_up, &target, &rho);
        let logdet_down = fixed_state_logdet(term_down, &target, &rho);
        let fd = (logdet_up - logdet_down) / (2.0 * step);

        let slot = cache.row_offsets[row] + local_pos;
        let analytic = gamma.t[slot];

        // Mixed absolute/relative tolerance scaled by the larger magnitude.
        let magnitude = fd.abs().max(analytic.abs());
        let allowed = 2.0e-3 * (1.0 + magnitude);
        let gap = (fd - analytic).abs();
        assert!(
            gap <= allowed,
            "Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Same finite-difference check of `logdet_theta_adjoint` as the tiny-fixture
/// test, but with the assignment mode switched to
/// `AssignmentMode::ibp_map(0.7, 0.9, false)` and more probes (three logits and
/// two coordinates).
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_ibp_map() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -1.0;
    // NOTE(review): the fourth argument is 5 here, while several other tests
    // in this file pass 200 — confirm the smaller value is intentional.
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let arrow_solver = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &arrow_solver)
        .expect("Gamma");

    // Half-width of the central difference stencil.
    let step = 1.0e-5;
    for (row, local_pos, var) in [
        (0usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (4usize, 1usize, SaeLocalRowVar::Logit { atom: 1 }),
        (7usize, 0usize, SaeLocalRowVar::Logit { atom: 0 }),
        (1usize, 2usize, SaeLocalRowVar::Coord { atom: 0, axis: 0 }),
        (6usize, 3usize, SaeLocalRowVar::Coord { atom: 1, axis: 0 }),
    ] {
        // Two independent copies of the term, nudged in opposite directions.
        let mut term_up = term.clone();
        let mut term_down = term.clone();
        match var {
            SaeLocalRowVar::Logit { atom } => {
                term_up.assignment.logits[[row, atom]] += step;
                term_down.assignment.logits[[row, atom]] -= step;
            }
            SaeLocalRowVar::Coord { atom, axis } => {
                // Coordinates are edited through their flat representation:
                // copy it out, shift one entry, write it back.
                let flat_index = row * term_up.assignment.coords[atom].latent_dim() + axis;

                let mut shifted_up = term_up.assignment.coords[atom].as_flat().clone();
                shifted_up[flat_index] += step;
                term_up.assignment.coords[atom].set_flat(shifted_up.view());

                let mut shifted_down = term_down.assignment.coords[atom].as_flat().clone();
                shifted_down[flat_index] -= step;
                term_down.assignment.coords[atom].set_flat(shifted_down.view());
            }
        }

        let logdet_up = fixed_state_logdet(term_up, &target, &rho);
        let logdet_down = fixed_state_logdet(term_down, &target, &rho);
        let fd = (logdet_up - logdet_down) / (2.0 * step);

        let slot = cache.row_offsets[row] + local_pos;
        let analytic = gamma.t[slot];

        // Mixed absolute/relative tolerance scaled by the larger magnitude.
        let magnitude = fd.abs().max(analytic.abs());
        let allowed = 3.0e-3 * (1.0 + magnitude);
        let gap = (fd - analytic).abs();
        assert!(
            gap <= allowed,
            "IBP Gamma row={row} local_pos={local_pos}: fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Checks `assignment_log_strength_hessian_trace` against half of a central
/// finite difference of `fixed_state_logdet` taken with respect to
/// `rho.log_lambda_sparse`, with the assignment mode set to
/// `AssignmentMode::ibp_map(0.7, 0.9, false)`.
#[test]
pub(crate) fn ibp_rho_sparse_logdet_trace_matches_dense_fd_1416() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -0.8;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let arrow_solver = DeflatedArrowSolver::plain(&cache);
    let analytic = term
        .assignment_log_strength_hessian_trace(&rho, &cache, &arrow_solver)
        .expect("rho_sparse logdet trace");

    // Half-width of the central difference stencil on `log_lambda_sparse`.
    let step = 1.0e-5;
    let rho_up = {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse += step;
        shifted
    };
    let rho_down = {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse -= step;
        shifted
    };

    // The term itself is held fixed; only the hyperparameter moves.
    let logdet_up = fixed_state_logdet(term.clone(), &target, &rho_up);
    let logdet_down = fixed_state_logdet(term.clone(), &target, &rho_down);
    let fd_half = 0.5 * (logdet_up - logdet_down) / (2.0 * step);

    // Mixed absolute/relative tolerance scaled by the larger magnitude.
    let magnitude = fd_half.abs().max(analytic.abs());
    let allowed = 3.0e-3 * (1.0 + magnitude);
    let gap = (fd_half - analytic).abs();
    assert!(
        gap <= allowed,
        "IBP ρ_sparse logdet trace: fd(½∂log|H|/∂ρ)={fd_half:.8e}, \
         analytic={analytic:.8e}"
    );
}
/// With the assignment mode set to `AssignmentMode::ibp_map(0.7, 0.9, true)`,
/// checks that the sum of `assignment_log_strength_hessian_trace` and
/// `learnable_ibp_data_logdet_alpha_trace` matches half of a central finite
/// difference of `fixed_state_logdet` in `rho.log_lambda_sparse`.
///
/// Also requires the data-side trace to be non-negligible, so the comparison
/// cannot pass with that contribution silently absent.
#[test]
pub(crate) fn learnable_ibp_alpha_logdet_trace_matches_dense_fd_1417() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, true);
    rho.log_lambda_sparse = 0.5;
    // NOTE(review): the fourth argument is 5 here, while several other tests
    // in this file pass 200 — confirm the smaller value is intentional.
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 5, 0.4, 1.0e-6, 1.0e-6)
        .expect("converged cache");
    let arrow_solver = DeflatedArrowSolver::plain(&cache);

    // The analytic value is assembled from two separately computed pieces.
    let prior_trace = term
        .assignment_log_strength_hessian_trace(&rho, &cache, &arrow_solver)
        .expect("prior-Hessian alpha trace");
    let data_trace = term
        .learnable_ibp_data_logdet_alpha_trace(&rho, &cache, &arrow_solver)
        .expect("data-Hessian alpha trace");
    let analytic = prior_trace + data_trace;

    // Half-width of the central difference stencil on `log_lambda_sparse`.
    let step = 1.0e-5;
    let rho_up = {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse += step;
        shifted
    };
    let rho_down = {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse -= step;
        shifted
    };

    // The term itself is held fixed; only the hyperparameter moves.
    let logdet_up = fixed_state_logdet(term.clone(), &target, &rho_up);
    let logdet_down = fixed_state_logdet(term.clone(), &target, &rho_down);
    let fd_half = 0.5 * (logdet_up - logdet_down) / (2.0 * step);

    // Mixed absolute/relative tolerance scaled by the larger magnitude.
    let magnitude = fd_half.abs().max(analytic.abs());
    let allowed = 3.0e-3 * (1.0 + magnitude);
    let gap = (fd_half - analytic).abs();
    assert!(
        gap <= allowed,
        "learnable-α logdet trace: fd(½∂log|H|/∂logα)={fd_half:.8e}, \
         analytic(prior+data)={analytic:.8e} (prior={prior_trace:.6e}, \
         data={data_trace:.6e})"
    );

    // Guard against the data-side contribution being trivially zero.
    let data_magnitude = data_trace.abs();
    assert!(
        data_magnitude > 1.0e-9,
        "the #1417 data-Hessian alpha trace must be a live nonzero term; got \
         {data_trace:.3e}"
    );
}
/// Finite-difference check of `logdet_theta_adjoint` with the assignment mode
/// set to `AssignmentMode::ibp_map(0.7, 0.9, true)`.
///
/// Only logit entries are probed; each probe is `(row, local_pos, atom)`, where
/// `local_pos` is the position inside that row's slice of `gamma.t` and `atom`
/// selects the logit column to perturb.
#[test]
pub(crate) fn sae_logdet_theta_adjoint_matches_dense_fd_ibp_map_learnable_alpha_1625() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, true);
    rho.log_lambda_sparse = 0.6;
    let (_value, _loss, cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-8, 1.0e-8)
        .expect("converged learnable-α cache");
    let arrow_solver = DeflatedArrowSolver::plain(&cache);
    let gamma = term
        .logdet_theta_adjoint(&rho, &cache, &arrow_solver)
        .expect("Gamma");

    // Half-width of the central difference stencil.
    let step = 1.0e-5;
    for (row, local_pos, atom) in [
        (0usize, 0usize, 0usize),
        (4usize, 1usize, 1usize),
        (7usize, 0usize, 0usize),
    ] {
        // Two independent copies of the term, nudged in opposite directions.
        let mut term_up = term.clone();
        term_up.assignment.logits[[row, atom]] += step;
        let mut term_down = term.clone();
        term_down.assignment.logits[[row, atom]] -= step;

        let logdet_up = fixed_state_logdet(term_up, &target, &rho);
        let logdet_down = fixed_state_logdet(term_down, &target, &rho);
        let fd = (logdet_up - logdet_down) / (2.0 * step);

        let slot = cache.row_offsets[row] + local_pos;
        let analytic = gamma.t[slot];

        // Mixed absolute/relative tolerance scaled by the larger magnitude.
        let magnitude = fd.abs().max(analytic.abs());
        let allowed = 3.0e-3 * (1.0 + magnitude);
        let gap = (fd - analytic).abs();
        assert!(
            gap <= allowed,
            "learnable-α IBP Gamma row={row} local_pos={local_pos}: \
             fd={fd:.8e}, analytic={analytic:.8e}"
        );
    }
}
/// Computes `assignment_log_strength_hessian_trace` twice — once from the cache
/// returned by `reml_criterion_with_cache`, once from a cache factored out of
/// an explicitly supplied `SaeRowLayout` in which every row lists atoms 0 and
/// 1 as active — and requires:
///
/// 1. the two analytic traces to agree to a tight tolerance, and
/// 2. the layout-based trace to match half of a central finite difference of
///    `fixed_state_logdet` in `rho.log_lambda_sparse`.
#[test]
pub(crate) fn ibp_rho_sparse_logdet_trace_compact_layout_matches_dense_1416() {
    let (mut term, target, mut rho) = gamma_fd_tiny_fixture();
    term.assignment.mode = AssignmentMode::ibp_map(0.7, 0.9, false);
    rho.log_lambda_sparse = -0.8;

    // Reference trace from the cache produced by the criterion evaluation.
    let (_value, _loss, dense_cache) = term
        .reml_criterion_with_cache(target.view(), &rho, None, 200, 0.4, 1.0e-6, 1.0e-6)
        .expect("dense converged cache");
    let dense_solver = DeflatedArrowSolver::plain(&dense_cache);
    let analytic_dense = term
        .assignment_log_strength_hessian_trace(&rho, &dense_cache, &dense_solver)
        .expect("dense rho_sparse trace");

    // Explicit row layout: every row carries both atoms, one coordinate each.
    let row_count = target.nrows();
    let active_per_row: Vec<Vec<usize>> = vec![vec![0usize, 1usize]; row_count];
    let dims_per_atom = vec![1usize; 2];
    let offsets = term.assignment.coord_offsets();
    let layout = SaeRowLayout::from_active_atoms(active_per_row, dims_per_atom, offsets);

    // Assemble with that layout and factor it to obtain the second cache.
    let probe_dim = SAE_DENSE_BETA_PENALTY_PROBE_MAX_DIM;
    let system = term
        .assemble_arrow_schur_inner(target.view(), &rho, None, 1.0, probe_dim, Some(Some(layout)))
        .expect("full-support compact assembly");
    let solve_options = ArrowSolveOptions::direct().with_ill_conditioning_tolerated();
    let (_dt, _db, compact_cache) =
        solve_arrow_newton_step_with_options(&system, 0.0, 0.0, &solve_options)
            .expect("compact factor");
    let compact_solver = DeflatedArrowSolver::plain(&compact_cache);
    let analytic_compact = term
        .assignment_log_strength_hessian_trace(&rho, &compact_cache, &compact_solver)
        .expect("compact rho_sparse trace");

    // First requirement: the two analytic traces agree.
    let structural_allowed = 1.0e-7 * (1.0 + analytic_dense.abs());
    let structural_gap = (analytic_dense - analytic_compact).abs();
    assert!(
        structural_gap <= structural_allowed,
        "compact-layout IBP ρ_sparse logdet trace must equal the dense trace on \
         full support: dense={analytic_dense:.10e}, compact={analytic_compact:.10e}"
    );

    // Second requirement: agreement with the finite difference.
    let step = 1.0e-5;
    let rho_up = {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse += step;
        shifted
    };
    let rho_down = {
        let mut shifted = rho.clone();
        shifted.log_lambda_sparse -= step;
        shifted
    };
    let logdet_up = fixed_state_logdet(term.clone(), &target, &rho_up);
    let logdet_down = fixed_state_logdet(term.clone(), &target, &rho_down);
    let fd_half = 0.5 * (logdet_up - logdet_down) / (2.0 * step);

    let magnitude = fd_half.abs().max(analytic_compact.abs());
    let fd_allowed = 3.0e-3 * (1.0 + magnitude);
    let fd_gap = (fd_half - analytic_compact).abs();
    assert!(
        fd_gap <= fd_allowed,
        "compact-layout IBP ρ_sparse logdet trace vs dense FD: \
         fd(½∂log|H|/∂ρ)={fd_half:.8e}, compact analytic={analytic_compact:.8e}"
    );
}