use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use super::error::GpuError;
/// Seed for the stateless Rademacher probe hash (see `rademacher_entry`).
///
/// Probes are a pure function of `(seed, probe index, row)`, so equal seeds
/// always reproduce the same probe block.
#[derive(Clone, Copy, Debug)]
pub struct ProbeSeed(pub u64);

impl Default for ProbeSeed {
    /// The fixed seed `0xCAFE_BABE`, making default runs reproducible.
    fn default() -> Self {
        const DEFAULT_SEED: u64 = 0xCAFE_BABE;
        ProbeSeed(DEFAULT_SEED)
    }
}
/// One derivative matrix `H_j` of the penalized Hessian (presumably `∂H/∂ρ_j`
/// for smoothing parameter `ρ_j` — confirm against callers), in one of two
/// storage forms. Every consumer in this file only needs the bilinear form
/// `zᵀ H_j w` for probe vectors `z`, `w`.
#[derive(Clone, Debug)]
pub enum DerivativeHessian<'a> {
/// Explicit dense `p × p` matrix (shape checked by `dim_p`).
Dense(ArrayView2<'a, f64>),
/// Structured form `Xᵀ diag(row_weights) X (+ penalty_extra)`, where `X` is
/// the `n × p` design matrix supplied separately (it must be present when any
/// derivative uses this variant).
WeightedGram {
/// One weight per design row (`len == n`). The CUDA path additionally
/// requires this view to be contiguous.
row_weights: ArrayView1<'a, f64>,
/// Optional dense `p × p` term added to the weighted Gram product.
penalty_extra: Option<ArrayView2<'a, f64>>,
},
}
impl DerivativeHessian<'_> {
fn dim_p(&self, expected_p: usize, expected_n: usize) -> Result<(), GpuError> {
match self {
DerivativeHessian::Dense(matrix) => {
if matrix.nrows() != expected_p || matrix.ncols() != expected_p {
crate::gpu_bail!(
"reml_trace dense H_j: shape {:?} != ({expected_p}, {expected_p})",
matrix.dim()
);
}
}
DerivativeHessian::WeightedGram {
row_weights,
penalty_extra,
} => {
if row_weights.len() != expected_n {
crate::gpu_bail!(
"reml_trace structural H_j: row_weights.len()={} != n={expected_n}",
row_weights.len()
);
}
if let Some(p_extra) = penalty_extra
&& (p_extra.nrows() != expected_p || p_extra.ncols() != expected_p)
{
crate::gpu_bail!(
"reml_trace structural H_j penalty_extra: shape {:?} != ({expected_p}, {expected_p})",
p_extra.dim()
);
}
}
}
Ok(())
}
}
/// Inputs to the Hutchinson REML trace estimator (CPU and GPU paths).
#[derive(Clone, Debug)]
pub struct RemlTraceHutchinsonInput<'a> {
/// Penalized Hessian `H` (`p × p`); must be square with `p > 0`, and SPD for
/// the Cholesky factorisation to succeed.
pub penalized_hessian: ArrayView2<'a, f64>,
/// Derivative matrices `H_j`, one per output entry.
pub derivatives: Vec<DerivativeHessian<'a>>,
/// Design matrix `X` (`n × p`); required when any derivative is `WeightedGram`.
pub design: Option<ArrayView2<'a, f64>>,
/// Number of Rademacher probes `K`; `validate_inputs` requires `K >= 2`.
pub probe_count: usize,
/// Seed for the stateless probe hash.
pub seed: ProbeSeed,
}
/// Output of the Hutchinson estimator.
#[derive(Clone, Debug)]
pub struct RemlTraceHutchinsonEvidence {
/// `log|H|` from the Cholesky factor (the CPU path computes `2 Σ ln L_ii`).
pub logdet_hessian: f64,
/// `0.5 · mean_k z_kᵀ H_j H⁻¹ z_k`, an estimate of `0.5 · tr(H⁻¹ H_j)`.
pub gradient_rho_logdet: Array1<f64>,
/// `0.5 ·` the standard error of that mean (sample variance, `K − 1`
/// denominator, divided by `K` before the square root).
pub gradient_rho_stderr: Array1<f64>,
/// Number of probes `K` used.
pub probe_count: usize,
}
/// Smallest Hessian dimension `p` for which the GPU Hutchinson path is used.
pub const HUTCHINSON_GPU_MIN_P: usize = 512;
/// Smallest probe count accepted by the GPU gate.
pub const HUTCHINSON_GPU_MIN_K: usize = 8;
/// Largest probe count accepted by the GPU gate.
pub const HUTCHINSON_GPU_MAX_K: usize = 128;
/// Initial probe count for callers growing K incrementally (not referenced in
/// this section — confirm usage elsewhere).
pub const HUTCHINSON_GPU_K_INITIAL: usize = 16;
/// Probe-count increment for callers growing K incrementally (not referenced
/// in this section — confirm usage elsewhere).
pub const HUTCHINSON_GPU_K_STEP: usize = 8;

/// Decides whether the GPU Hutchinson estimator should replace the exact path.
///
/// All of these must hold:
/// * `p >= HUTCHINSON_GPU_MIN_P`;
/// * `probe_count` lies in `HUTCHINSON_GPU_MIN_K..=HUTCHINSON_GPU_MAX_K`;
/// * the caller prefers a stochastic estimate;
/// * the kernel matches `H⁻¹`;
/// * the log-determinant is a plain SPD one;
/// * no projected penalty subspace is active.
#[must_use]
pub fn should_use_gpu_hutchinson(
    p: usize,
    probe_count: usize,
    prefers_stochastic: bool,
    kernel_matches_hinv: bool,
    plain_spd_logdet: bool,
    projected_penalty_subspace_active: bool,
) -> bool {
    let large_enough = p >= HUTCHINSON_GPU_MIN_P;
    let probes_in_range = matches!(probe_count, HUTCHINSON_GPU_MIN_K..=HUTCHINSON_GPU_MAX_K);
    let structurally_eligible = prefers_stochastic && kernel_matches_hinv && plain_spd_logdet;
    large_enough && probes_in_range && structurally_eligible && !projected_penalty_subspace_active
}
/// SplitMix64 output function: advances `z` by the golden-ratio increment and
/// applies the two xor-shift/multiply rounds plus a final xor-shift.
///
/// Each step is invertible, so the map is a bijection on `u64`.
#[inline]
pub fn splitmix64_mix(z: u64) -> u64 {
    let mut x = z.wrapping_add(0x9E37_79B9_7F4A_7C15);
    x ^= x >> 30;
    x = x.wrapping_mul(0xBF58_476D_1CE4_E5B9);
    x ^= x >> 27;
    x = x.wrapping_mul(0x94D0_49BB_1331_11EB);
    x ^ (x >> 31)
}

/// Stateless Rademacher draw (`±1.0`) for probe column `k`, row `i`.
///
/// `(seed, k, i)` are combined with two odd multipliers and hashed with
/// [`splitmix64_mix`]. Because there is no generator state, any entry can be
/// computed independently (the CUDA kernel uses the identical formula), and a
/// block of K probes is a prefix of a block of K' > K probes.
#[inline]
pub fn rademacher_entry(seed: u64, k: u64, i: u64) -> f64 {
    const ZETA: u64 = 0xD1B5_4A32_D192_ED03;
    const GAMMA: u64 = 0x8CB9_2BA7_2F9D_E81F;
    let hashed = splitmix64_mix(seed ^ k.wrapping_mul(ZETA) ^ i.wrapping_mul(GAMMA));
    // The top bit of the hash picks the sign: set → −1, clear → +1.
    if (hashed as i64) < 0 { -1.0 } else { 1.0 }
}

/// Fills `out` with a `p × k` column-major Rademacher probe block:
/// `out[col * p + row] = rademacher_entry(seed.0, col, row)`.
///
/// # Panics
/// If `out.len() != p * k`.
pub fn fill_rademacher_host(seed: ProbeSeed, p: usize, k: usize, out: &mut [f64]) {
    assert_eq!(
        out.len(),
        p * k,
        "fill_rademacher_host: out buffer length {} != p*K = {}*{}",
        out.len(),
        p,
        k
    );
    // When p == 0 the buffer is empty, so the division below never runs.
    for (flat, slot) in out.iter_mut().enumerate() {
        let (col, row) = (flat / p, flat % p);
        *slot = rademacher_entry(seed.0, col as u64, row as u64);
    }
}
pub fn evidence_derivatives_hutchinson_cpu(
input: &RemlTraceHutchinsonInput<'_>,
) -> Result<RemlTraceHutchinsonEvidence, String> {
validate_inputs(input)?;
let p = input.penalized_hessian.nrows();
let d = input.derivatives.len();
let k = input.probe_count;
let h = input.penalized_hessian.to_owned();
let factor = cholesky_lower(&h)?;
let logdet_hessian = 2.0 * (0..p).map(|i| factor[[i, i]].ln()).sum::<f64>();
let mut z = vec![0.0_f64; p * k];
fill_rademacher_host(input.seed, p, k, &mut z);
let mut w = vec![0.0_f64; p * k];
for col in 0..k {
let mut rhs = vec![0.0_f64; p];
rhs.copy_from_slice(&z[col * p..(col + 1) * p]);
let solved = solve_cholesky(&factor, &rhs);
w[col * p..(col + 1) * p].copy_from_slice(&solved);
}
let mut q = vec![0.0_f64; d * k]; for (j, derivative) in input.derivatives.iter().enumerate() {
match derivative {
DerivativeHessian::Dense(matrix) => {
for col in 0..k {
let z_col = &z[col * p..(col + 1) * p];
let w_col = &w[col * p..(col + 1) * p];
let mut y = vec![0.0_f64; p];
for r in 0..p {
let mut acc = 0.0_f64;
for c in 0..p {
acc += matrix[[r, c]] * w_col[c];
}
y[r] = acc;
}
let mut zy = 0.0_f64;
for i in 0..p {
zy += z_col[i] * y[i];
}
q[j * k + col] = zy;
}
}
DerivativeHessian::WeightedGram {
row_weights,
penalty_extra,
} => {
let design = input.design.as_ref().expect("design validated");
let n = design.nrows();
for col in 0..k {
let z_col = &z[col * p..(col + 1) * p];
let w_col = &w[col * p..(col + 1) * p];
let mut acc = 0.0_f64;
for row in 0..n {
let mut rz = 0.0_f64;
let mut rw = 0.0_f64;
for col_idx in 0..p {
rz += design[[row, col_idx]] * z_col[col_idx];
rw += design[[row, col_idx]] * w_col[col_idx];
}
acc += row_weights[row] * rz * rw;
}
if let Some(pen) = penalty_extra {
for r in 0..p {
let mut row_acc = 0.0_f64;
for c in 0..p {
row_acc += pen[[r, c]] * w_col[c];
}
acc += z_col[r] * row_acc;
}
}
q[j * k + col] = acc;
}
}
}
}
let (means, stderrs) = reduce_mean_stderr(&q, d, k);
let mut gradient_rho_logdet = Array1::<f64>::zeros(d);
let mut gradient_rho_stderr = Array1::<f64>::zeros(d);
for j in 0..d {
gradient_rho_logdet[j] = 0.5 * means[j];
gradient_rho_stderr[j] = 0.5 * stderrs[j];
}
Ok(RemlTraceHutchinsonEvidence {
logdet_hessian,
gradient_rho_logdet,
gradient_rho_stderr,
probe_count: k,
})
}
/// Hutchinson REML evidence derivatives, preferring the CUDA backend.
///
/// Inputs are validated first. On Linux with an available global GPU runtime
/// the CUDA implementation is tried:
/// * a result is returned as-is;
/// * `GpuError::NotYetImplemented` falls back to
///   [`evidence_derivatives_hutchinson_cpu`];
/// * any other GPU error is returned as a string.
///
/// On other targets, or without a runtime, the CPU path runs directly.
///
/// # Errors
/// Validation failures, non-recoverable GPU errors, or CPU-path errors.
pub fn evidence_derivatives_hutchinson_gpu(
    input: RemlTraceHutchinsonInput<'_>,
) -> Result<RemlTraceHutchinsonEvidence, String> {
    validate_inputs(&input)?;
    #[cfg(target_os = "linux")]
    {
        let runtime_ready = super::runtime::GpuRuntime::global().is_some();
        if runtime_ready {
            let outcome = linux_cuda::evidence_derivatives(&input);
            // Only an unimplemented configuration falls through to the host path.
            if !matches!(outcome, Err(GpuError::NotYetImplemented { .. })) {
                return outcome.map_err(String::from);
            }
        }
    }
    evidence_derivatives_hutchinson_cpu(&input)
}
/// Default relative-error target for the adaptive trace estimators (1%).
pub const HUTCHINSON_ADAPTIVE_REL_TOL: f64 = 0.01;
/// Default floor applied to `|trace|` in the relative-error denominator, which
/// avoids dividing by a vanishing trace.
pub const HUTCHINSON_ADAPTIVE_TAU_REL: f64 = 1e-8;
/// Result of the adaptive-K trace estimators.
pub struct AdaptiveTraceEvidence {
/// `log|H|` from the last stage; always NaN from the matrix-free
/// `evidence_traces_adaptive_hvp`.
pub logdet_hessian: f64,
/// Estimates of `tr(H⁻¹ H_j)`, one per derivative.
pub traces: Array1<f64>,
/// Per-derivative uncertainty. NOTE(review): the two producers disagree.
/// `evidence_traces_adaptive` stores the standard error of the mean.
/// `evidence_traces_adaptive_hvp` stores the population standard deviation
/// of the per-probe samples; divide by √probe_count to get an SE.
/// Confirm the intended meaning before relying on it.
pub stderrs: Array1<f64>,
/// Probe count of the last stage that ran.
pub probe_count: usize,
/// Whether the relative-error target was met before the schedule ran out.
pub converged: bool,
}
/// Adaptive-K Hutchinson estimate of `tr(H⁻¹ H_j)` for every derivative `H_j`.
///
/// Runs [`evidence_derivatives_hutchinson_gpu`] (CUDA when available, otherwise
/// the CPU path) with K = 16, 32, 64, 128 probes. It stops at the first K for
/// which every derivative satisfies
///
/// `stderr_j / max(|trace_j|, tau_rel) <= rel_tol` and `trace_j` is finite,
///
/// where `stderr_j` is the standard error of the mean of the K probe samples.
/// Probes are common random numbers: the first K probes of one stage are the
/// probes of the previous stage. Each stage nonetheless reruns the whole
/// estimator.
///
/// Returns, for the last stage that ran:
/// * `log|H|`;
/// * the traces (twice the halved `gradient_rho_logdet`);
/// * their standard errors;
/// * the probe count;
/// * whether the target was met.
///
/// # Errors
/// * Empty `derivatives`.
/// * `rel_tol` or `tau_rel` not strictly positive (NaN included).
/// * Any error from the underlying estimator.
pub fn evidence_traces_adaptive<'a>(
    penalized_hessian: ArrayView2<'a, f64>,
    derivatives: Vec<DerivativeHessian<'a>>,
    design: Option<ArrayView2<'a, f64>>,
    seed: ProbeSeed,
    rel_tol: f64,
    tau_rel: f64,
) -> Result<AdaptiveTraceEvidence, String> {
    const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
    let d = derivatives.len();
    if d == 0 {
        return Err("evidence_traces_adaptive: derivatives is empty".to_string());
    }
    if !(rel_tol > 0.0) {
        return Err(format!(
            "evidence_traces_adaptive: rel_tol must be > 0 (got {rel_tol})"
        ));
    }
    if !(tau_rel > 0.0) {
        return Err(format!(
            "evidence_traces_adaptive: tau_rel must be > 0 (got {tau_rel})"
        ));
    }
    let mut last_logdet = 0.0_f64;
    let mut last_traces = Array1::<f64>::zeros(d);
    let mut last_stderrs = Array1::<f64>::zeros(d);
    let mut last_k = 0_usize;
    let mut converged = false;
    for &k in &SCHEDULE {
        let input = RemlTraceHutchinsonInput {
            penalized_hessian,
            derivatives: derivatives.clone(),
            design,
            probe_count: k,
            seed,
        };
        let evidence = evidence_derivatives_hutchinson_gpu(input)?;
        last_logdet = evidence.logdet_hessian;
        last_k = k;
        // The estimator reports 0.5·trace and 0.5·SE; undo the halving.
        last_traces = evidence.gradient_rho_logdet.mapv(|g| 2.0 * g);
        last_stderrs = evidence.gradient_rho_stderr.mapv(|s| 2.0 * s);
        // `gradient_rho_stderr` is already sqrt(var / K), a standard error, so it
        // is compared with |trace| directly. Dividing by √K again would loosen the
        // test by a factor √K (4× at K = 16). `<=` is false for NaN, and a
        // non-finite trace is never accepted as converged.
        let within_tol = last_traces
            .iter()
            .zip(last_stderrs.iter())
            .all(|(&trace, &se)| trace.is_finite() && se / trace.abs().max(tau_rel) <= rel_tol);
        if within_tol {
            converged = true;
            break;
        }
    }
    Ok(AdaptiveTraceEvidence {
        logdet_hessian: last_logdet,
        traces: last_traces,
        stderrs: last_stderrs,
        probe_count: last_k,
        converged,
    })
}
pub const PCG_HVP_REL_TOL: f64 = 1e-6;
pub const PCG_HVP_MAX_ITERS: usize = 200;
pub fn evidence_traces_adaptive_hvp<F>(
p: usize,
mut hvp: F,
derivatives: Vec<DerivativeHessian<'_>>,
design: Option<ArrayView2<'_, f64>>,
seed: ProbeSeed,
rel_tol: f64,
tau_rel: f64,
) -> Result<AdaptiveTraceEvidence, String>
where
F: FnMut(&[f64], &mut [f64]),
{
const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
let d = derivatives.len();
if d == 0 {
return Err("evidence_traces_adaptive_hvp: derivatives is empty".to_string());
}
if p == 0 {
return Err("evidence_traces_adaptive_hvp: p must be > 0".to_string());
}
if !(rel_tol > 0.0) {
return Err(format!(
"evidence_traces_adaptive_hvp: rel_tol must be > 0 (got {rel_tol})"
));
}
if !(tau_rel > 0.0) {
return Err(format!(
"evidence_traces_adaptive_hvp: tau_rel must be > 0 (got {tau_rel})"
));
}
let mut last_traces = Array1::<f64>::zeros(d);
let mut last_stderrs = Array1::<f64>::zeros(d);
let mut last_k = 0_usize;
let mut converged = false;
let mut z = vec![0.0_f64; p];
let mut w = vec![0.0_f64; p];
let mut q_sums = vec![0.0_f64; d];
let mut q_sq_sums = vec![0.0_f64; d];
for &k_target in &SCHEDULE {
for s in q_sums.iter_mut() {
*s = 0.0;
}
for s in q_sq_sums.iter_mut() {
*s = 0.0;
}
for k_idx in 0..k_target {
for i in 0..p {
z[i] = rademacher_entry(seed.0, k_idx as u64, i as u64);
}
cg_solve(&mut hvp, &z, &mut w, rel_tol, PCG_HVP_MAX_ITERS);
for j in 0..d {
let q = match &derivatives[j] {
DerivativeHessian::Dense(matrix) => {
let mut y = 0.0_f64;
for r in 0..p {
let mut hr_w = 0.0_f64;
for c in 0..p {
hr_w += matrix[[r, c]] * w[c];
}
y += z[r] * hr_w;
}
y
}
DerivativeHessian::WeightedGram {
row_weights,
penalty_extra,
} => {
let design_view = design.as_ref().ok_or_else(|| {
"evidence_traces_adaptive_hvp: WeightedGram derivative requires \
design matrix"
.to_string()
})?;
let n = design_view.nrows();
let mut acc = 0.0_f64;
for row in 0..n {
let mut rz = 0.0_f64;
let mut rw = 0.0_f64;
for ci in 0..p {
rz += design_view[[row, ci]] * z[ci];
rw += design_view[[row, ci]] * w[ci];
}
acc += row_weights[row] * rz * rw;
}
if let Some(pen) = penalty_extra {
for r in 0..p {
let mut row_acc = 0.0_f64;
for c in 0..p {
row_acc += pen[[r, c]] * w[c];
}
acc += z[r] * row_acc;
}
}
acc
}
};
q_sums[j] += q;
q_sq_sums[j] += q * q;
}
}
let n = k_target as f64;
let mut worst_ratio = 0.0_f64;
for j in 0..d {
let mean = q_sums[j] / n;
let var = (q_sq_sums[j] / n - mean * mean).max(0.0);
let s = var.sqrt();
last_traces[j] = mean;
last_stderrs[j] = s;
let denom = n.sqrt() * mean.abs().max(tau_rel);
let r = s / denom;
if r > worst_ratio {
worst_ratio = r;
}
}
last_k = k_target;
if worst_ratio <= rel_tol {
converged = true;
break;
}
}
Ok(AdaptiveTraceEvidence {
logdet_hessian: f64::NAN,
traces: last_traces,
stderrs: last_stderrs,
probe_count: last_k,
converged,
})
}
/// Unpreconditioned conjugate-gradient solve of `H w = b`, where `H` is applied
/// only through `hvp(x, y)` (`y ← H x`).
///
/// `w` is overwritten and the starting guess is zero. Iteration stops when any
/// of the following happens:
/// * `‖r‖² ≤ rel_tol² · ‖b‖²`;
/// * `max_iters` products have been used;
/// * a search direction shows non-positive (or NaN) curvature `pᵀ H p`,
///   meaning `H` is not SPD along it.
///
/// The current iterate is kept in every case. A zero right-hand side returns
/// `w = 0` without calling `hvp`.
///
/// # Panics
/// If `w.len() != b.len()`.
fn cg_solve<F>(hvp: &mut F, b: &[f64], w: &mut [f64], rel_tol: f64, max_iters: usize)
where
    F: FnMut(&[f64], &mut [f64]),
{
    let n = b.len();
    assert!(w.len() == n);
    w.fill(0.0);

    let rhs_norm_sq: f64 = b.iter().map(|x| x * x).sum();
    if rhs_norm_sq == 0.0 {
        return;
    }
    let stop_sq = rel_tol * rel_tol * rhs_norm_sq;

    let mut resid = b.to_vec();
    let mut dir = resid.clone();
    let mut h_dir = vec![0.0_f64; n];
    let mut resid_norm_sq: f64 = resid.iter().map(|x| x * x).sum();

    // `!(a <= b)` keeps iterating on a NaN residual norm rather than stopping.
    let mut budget = max_iters;
    while budget > 0 && !(resid_norm_sq <= stop_sq) {
        budget -= 1;
        hvp(dir.as_slice(), h_dir.as_mut_slice());
        let curvature: f64 = dir.iter().zip(h_dir.iter()).map(|(x, y)| x * y).sum();
        if !(curvature > 0.0) {
            break;
        }
        let alpha = resid_norm_sq / curvature;
        for ((w_i, r_i), (d_i, hd_i)) in w
            .iter_mut()
            .zip(resid.iter_mut())
            .zip(dir.iter().zip(h_dir.iter()))
        {
            *w_i += alpha * d_i;
            *r_i -= alpha * hd_i;
        }
        let next_norm_sq: f64 = resid.iter().map(|x| x * x).sum();
        let beta = next_norm_sq / resid_norm_sq;
        for (d_i, r_i) in dir.iter_mut().zip(resid.iter()) {
            *d_i = r_i + beta * *d_i;
        }
        resid_norm_sq = next_norm_sq;
    }
}
/// Decides whether the adaptive GPU trace estimator should bypass the CPU path.
///
/// Requires, all at once:
/// * `p >= HUTCHINSON_GPU_MIN_P`;
/// * a dense SPD `H` resident on the device;
/// * a plain SPD log-determinant;
/// * a caller preference for stochastic estimates;
/// * no active projected penalty subspace.
#[must_use]
pub fn should_bypass_cpu_with_gpu_adaptive(
    p: usize,
    dense_spd_h_resident: bool,
    plain_spd_logdet: bool,
    prefers_stochastic: bool,
    projected_penalty_subspace_active: bool,
) -> bool {
    if projected_penalty_subspace_active || p < HUTCHINSON_GPU_MIN_P {
        return false;
    }
    dense_spd_h_resident && plain_spd_logdet && prefers_stochastic
}
/// CUDA backend for the Hutchinson REML trace estimator (compiled on Linux only).
///
/// Mirrors `evidence_derivatives_hutchinson_cpu`:
/// 1. Cholesky-factor `H` on the device.
/// 2. Generate the Rademacher probes `Z` on the device, using the same
///    SplitMix64 hash as the host.
/// 3. Solve `W = H⁻¹ Z`.
/// 4. Reduce the per-probe quadratic forms `q_{j,k} = z_kᵀ H_j w_k` and apply
///    the same `0.5` scaling.
#[cfg(target_os = "linux")]
mod linux_cuda {
use super::{
DerivativeHessian, ProbeSeed, RemlTraceHutchinsonEvidence, RemlTraceHutchinsonInput,
reduce_mean_stderr,
};
use crate::gpu::driver::to_col_major;
use crate::gpu::error::{GpuError, GpuResultExt};
use crate::gpu::solver::{
cholesky_logdet_from_col_major, context_and_stream, pinned_htod, potrf_in_place,
potrs_in_place,
};
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
use cudarc::cusolver::DnHandle;
use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
use std::sync::Arc;
// CUDA C source for the device kernels, compiled on first use through
// `PtxModuleCache::get_or_compile` (see `module`). The hash in
// `fill_rademacher_splitmix` must stay bit-identical to the host
// `splitmix64_mix` / `rademacher_entry` so both paths draw the same probes.
//  - fill_rademacher_splitmix: Z[k*p + i] = ±1; grid (ceil(p/blockDim.x), K).
//  - block_reduce_sum: warp-shuffle + shared-memory block sum. The full total
//    is only guaranteed in thread 0 (the only thread callers read), and the
//    32-entry `smem` limits blockDim.x to 1024.
//  - reduce_q_dense: one block per (k, j); Q[j*K + k] = z_k · Y[(j*K + k)*p ..].
//  - reduce_q_weighted_gram: one block per (k, j);
//    Q[j*K + k] = Σ_i A[j*n + i] · RZ[k*n + i] · RW[k*n + i].
pub(super) const PTX_SOURCE: &str = r#"
extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
z += 0x9E3779B97F4A7C15ULL;
unsigned long long x = z;
x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
return x ^ (x >> 31);
}
extern "C" __global__ void fill_rademacher_splitmix(
unsigned long long seed,
unsigned int p,
unsigned int K,
double* __restrict__ Z)
{
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int k = blockIdx.y;
if (i >= p || k >= K) return;
const unsigned long long ZETA = 0xD1B54A32D192ED03ULL;
const unsigned long long GAMMA = 0x8CB92BA72F9DE81FULL;
unsigned long long composite =
seed
^ (((unsigned long long)k) * ZETA)
^ (((unsigned long long)i) * GAMMA);
unsigned long long h = splitmix64_mix(composite);
double v = (h >> 63) == 0 ? 1.0 : -1.0;
Z[(size_t)k * (size_t)p + (size_t)i] = v;
}
extern "C" __device__ double block_reduce_sum(double v) {
__shared__ double smem[32];
int lane = threadIdx.x & 31;
int wid = threadIdx.x >> 5;
for (int off = 16; off > 0; off >>= 1) {
v += __shfl_down_sync(0xffffffff, v, off);
}
if (lane == 0) smem[wid] = v;
__syncthreads();
double total = 0.0;
int n_warps = (blockDim.x + 31) >> 5;
if (threadIdx.x < (unsigned)n_warps) total = smem[threadIdx.x];
if (wid == 0) {
for (int off = 16; off > 0; off >>= 1) {
total += __shfl_down_sync(0xffffffff, total, off);
}
}
return total;
}
extern "C" __global__ void reduce_q_dense(
unsigned int p,
unsigned int K,
unsigned int D,
const double* __restrict__ Z,
const double* __restrict__ Y_stack,
double* __restrict__ Q)
{
unsigned int k = blockIdx.x;
unsigned int j = blockIdx.y;
if (k >= K || j >= D) return;
const double* z_col = Z + (size_t)k * (size_t)p;
const double* y_col = Y_stack + ((size_t)j * (size_t)K + (size_t)k) * (size_t)p;
double partial = 0.0;
for (unsigned int i = threadIdx.x; i < p; i += blockDim.x) {
partial += z_col[i] * y_col[i];
}
double total = block_reduce_sum(partial);
if (threadIdx.x == 0) {
Q[(size_t)j * (size_t)K + (size_t)k] = total;
}
}
extern "C" __global__ void reduce_q_weighted_gram(
unsigned int n,
unsigned int K,
unsigned int D,
const double* __restrict__ RZ,
const double* __restrict__ RW,
const double* __restrict__ A_stack,
double* __restrict__ Q)
{
unsigned int k = blockIdx.x;
unsigned int j = blockIdx.y;
if (k >= K || j >= D) return;
const double* rz_col = RZ + (size_t)k * (size_t)n;
const double* rw_col = RW + (size_t)k * (size_t)n;
const double* a_col = A_stack + (size_t)j * (size_t)n;
double partial = 0.0;
for (unsigned int i = threadIdx.x; i < n; i += blockDim.x) {
partial += a_col[i] * rz_col[i] * rw_col[i];
}
double total = block_reduce_sum(partial);
if (threadIdx.x == 0) {
Q[(size_t)j * (size_t)K + (size_t)k] = total;
}
}
"#;
// Block size for every kernel launch. It must be a multiple of 32, because the
// kernels use full-mask `__shfl_down_sync`, and at most 1024 (`block_reduce_sum`).
const THREADS_PER_BLOCK: u32 = 256;
/// Returns the compiled kernel module for `ctx`, compiling `PTX_SOURCE` on
/// first use and caching it in a function-local static.
fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
static CACHE: crate::gpu::common::PtxModuleCache =
crate::gpu::common::PtxModuleCache::new();
CACHE.get_or_compile(ctx, "reml_trace", PTX_SOURCE)
}
/// Device implementation of the Hutchinson evidence derivatives.
///
/// Inputs are expected to be validated already; the public entry point
/// `evidence_derivatives_hutchinson_gpu` runs `validate_inputs` before calling this.
///
/// # Errors
/// Any CUDA / cuBLAS / cuSOLVER failure, an allocation-size overflow, or
/// non-contiguous `row_weights`, as a `GpuError`.
///
/// # Panics
/// Only if the dense/structural index partition built below is internally
/// inconsistent (an invariant violation).
pub(super) fn evidence_derivatives(
input: &RemlTraceHutchinsonInput<'_>,
) -> Result<RemlTraceHutchinsonEvidence, GpuError> {
let p = input.penalized_hessian.nrows();
let d = input.derivatives.len();
let k = input.probe_count;
// CUDA context/stream plus cuSOLVER and cuBLAS handles bound to that stream.
let (ctx, stream) =
context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
let solver = DnHandle::new(stream.clone()).gpu_ctx("reml_trace cusolver init")?;
let blas = CudaBlas::new(stream.clone()).gpu_ctx("reml_trace cublas init")?;
let compiled = module(&ctx)?;
let module_handle: &Arc<CudaModule> = compiled;
// 1. Upload H (column-major) and Cholesky-factor it in place. The factor is
//    downloaded only to compute log|H| on the host.
let h_col = to_col_major(&input.penalized_hessian);
let mut h_dev =
pinned_htod(&stream, &h_col).map_err(|reason| GpuError::DriverCallFailed { reason })?;
potrf_in_place(&solver, &stream, p, &mut h_dev)
.map_err(|reason| GpuError::DriverCallFailed { reason })?;
let factor_col = stream
.clone_dtoh(&h_dev)
.gpu_ctx("reml_trace download factor")?;
let logdet_hessian = cholesky_logdet_from_col_major(&factor_col, p);
// 2. Probe block Z: p × K, column-major, filled on the device.
let total_z = p
.checked_mul(k)
.ok_or_else(|| gpu_err!("reml_trace Z size overflow: p={p}, K={k}"))?;
let mut z_dev = stream
.alloc_zeros::<f64>(total_z)
.gpu_ctx("reml_trace alloc Z")?;
launch_fill_rademacher(&stream, module_handle, input.seed, p, k, &mut z_dev)?;
// 3. W = H⁻¹ Z: copy Z, then solve in place with the Cholesky factor.
let mut w_dev = stream
.alloc_zeros::<f64>(total_z)
.gpu_ctx("reml_trace alloc W")?;
copy_device_slice(&stream, &z_dev, &mut w_dev)?;
potrs_in_place(&solver, &stream, p, k, &h_dev, &mut w_dev)
.map_err(|reason| GpuError::DriverCallFailed { reason })?;
// Partition derivatives by storage form; each form has its own reduction path.
let mut dense_indices: Vec<usize> = Vec::new();
let mut gram_indices: Vec<usize> = Vec::new();
for (j, deriv) in input.derivatives.iter().enumerate() {
match deriv {
DerivativeHessian::Dense(_) => dense_indices.push(j),
DerivativeHessian::WeightedGram { .. } => gram_indices.push(j),
}
}
// Host samples, laid out as q_host[j*K + k] = z_kᵀ H_j w_k (same as the CPU path).
let mut q_host = vec![0.0_f64; d * k];
if !dense_indices.is_empty() {
// Dense H_j: upload, Y_j = H_j W (p × K) by GEMM, then one reduction
// launch with D = 1, processing one derivative at a time.
for &j in &dense_indices {
let DerivativeHessian::Dense(matrix) = &input.derivatives[j] else {
panic!(
"reml_trace dense path: derivative index {j} is in dense_indices but \
input.derivatives[{j}] is not DerivativeHessian::Dense — \
dense_indices partition invariant violated"
);
};
let hj_col = to_col_major(matrix);
let hj_dev = pinned_htod(&stream, &hj_col)
.map_err(|reason| GpuError::DriverCallFailed { reason })?;
let mut y_dev = stream
.alloc_zeros::<f64>(total_z)
.map_err(|err| gpu_err!("reml_trace alloc Y_j (j={j}): {err}"))?;
gemm_nn(
&blas,
GemmShape {
m: p,
n: k,
k_inner: p,
lda: p,
ldb: p,
ldc: p,
},
&hj_dev,
&w_dev,
&mut y_dev,
)?;
let mut q_j_dev = stream
.alloc_zeros::<f64>(k)
.gpu_ctx_with(|err| format!("reml_trace alloc Q_j (j={j}): {err}"))?;
launch_reduce_q_dense(
&stream,
module_handle,
p,
k,
1,
&z_dev,
&y_dev,
&mut q_j_dev,
)?;
let q_host_j = stream
.clone_dtoh(&q_j_dev)
.gpu_ctx_with(|err| format!("reml_trace download Q_j (j={j}): {err}"))?;
q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_j);
}
}
if !gram_indices.is_empty() {
// Structural H_j = Xᵀ diag(a_j) X (+ penalty_extra). RZ = X Z and RW = X W
// (both n × K) are formed once; all structural j are then reduced in a single
// launch over the stacked weights A (n entries per j).
let design = input
.design
.as_ref()
.ok_or_else(|| GpuError::DriverCallFailed {
reason: "reml_trace: structural derivative present but design=None".to_string(),
})?;
let n = design.nrows();
let design_col = to_col_major(design);
let x_dev = pinned_htod(&stream, &design_col)
.map_err(|reason| GpuError::DriverCallFailed { reason })?;
let mut rz_dev = stream
.alloc_zeros::<f64>(
n.checked_mul(k)
.ok_or_else(|| gpu_err!("reml_trace RZ overflow: n={n}, K={k}"))?,
)
.gpu_ctx("reml_trace alloc RZ")?;
// Same size as RZ, whose product was overflow-checked just above.
let mut rw_dev = stream
.alloc_zeros::<f64>(n * k)
.gpu_ctx("reml_trace alloc RW")?;
gemm_nn(
&blas,
GemmShape {
m: n,
n: k,
k_inner: p,
lda: n,
ldb: p,
ldc: n,
},
&x_dev,
&z_dev,
&mut rz_dev,
)?;
gemm_nn(
&blas,
GemmShape {
m: n,
n: k,
k_inner: p,
lda: n,
ldb: p,
ldc: n,
},
&x_dev,
&w_dev,
&mut rw_dev,
)?;
let d_gram = gram_indices.len();
let mut a_stack = Vec::<f64>::with_capacity(n * d_gram);
for &j in &gram_indices {
let DerivativeHessian::WeightedGram { row_weights, .. } = &input.derivatives[j]
else {
panic!(
"reml_trace structural path: derivative index {j} is in gram_indices \
but input.derivatives[{j}] is not DerivativeHessian::WeightedGram — \
gram_indices partition invariant violated"
);
};
// The weights are uploaded as a flat slice, so the view must be contiguous.
let slice = row_weights.as_slice().ok_or_else(|| {
gpu_err!("reml_trace structural H_j={j} row_weights not contiguous")
})?;
a_stack.extend_from_slice(slice);
}
let a_dev = pinned_htod(&stream, &a_stack)
.map_err(|reason| GpuError::DriverCallFailed { reason })?;
let mut q_dev = stream
.alloc_zeros::<f64>(d_gram * k)
.map_err(|err| gpu_err!("reml_trace alloc Q_gram: {err}"))?;
launch_reduce_q_weighted_gram(
&stream,
module_handle,
n,
k,
d_gram,
&rz_dev,
&rw_dev,
&a_dev,
&mut q_dev,
)?;
let q_host_gram = stream
.clone_dtoh(&q_dev)
.gpu_ctx("reml_trace download Q_gram")?;
// Scatter the compacted structural rows back to their derivative indices.
for (slot, &j) in gram_indices.iter().enumerate() {
q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_gram[slot * k..(slot + 1) * k]);
}
// penalty_extra terms z_kᵀ P_j w_k are added on the host.
// NOTE(review): Z and W are re-downloaded for every derivative that has a
// penalty_extra; hoisting the two downloads out of the loop would avoid
// repeated transfers.
for &j in &gram_indices {
let DerivativeHessian::WeightedGram { penalty_extra, .. } = &input.derivatives[j]
else {
panic!(
"reml_trace structural penalty_extra: derivative index {j} is in \
gram_indices but input.derivatives[{j}] is not \
DerivativeHessian::WeightedGram — gram_indices partition invariant \
violated"
);
};
if let Some(pen) = penalty_extra {
let z_host = stream
.clone_dtoh(&z_dev)
.gpu_ctx("reml_trace download Z for penalty_extra")?;
let w_host = stream
.clone_dtoh(&w_dev)
.gpu_ctx("reml_trace download W for penalty_extra")?;
for col in 0..k {
let z_col = &z_host[col * p..(col + 1) * p];
let w_col = &w_host[col * p..(col + 1) * p];
let mut acc = 0.0_f64;
for r in 0..p {
let mut row_acc = 0.0_f64;
for c in 0..p {
row_acc += pen[[r, c]] * w_col[c];
}
acc += z_col[r] * row_acc;
}
q_host[j * k + col] += acc;
}
}
}
}
// Same mean/SE reduction and 0.5 scaling as the CPU path.
let (means, stderrs) = reduce_mean_stderr(&q_host, d, k);
let mut gradient_rho_logdet = ndarray::Array1::<f64>::zeros(d);
let mut gradient_rho_stderr = ndarray::Array1::<f64>::zeros(d);
for j in 0..d {
gradient_rho_logdet[j] = 0.5 * means[j];
gradient_rho_stderr[j] = 0.5 * stderrs[j];
}
Ok(RemlTraceHutchinsonEvidence {
logdet_hessian,
gradient_rho_logdet,
gradient_rho_stderr,
probe_count: k,
})
}
/// Launches `fill_rademacher_splitmix` over a `(ceil(p / THREADS_PER_BLOCK), k)`
/// grid, writing the `p × k` column-major probe block into `z`.
///
/// NOTE(review): `p` and `k` are cast to `u32` without range checks, and `k`
/// becomes gridDim.y. Very large probe counts may exceed the device grid limit
/// or truncate — confirm callers bound K.
fn launch_fill_rademacher(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
seed: ProbeSeed,
p: usize,
k: usize,
z: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
let func = module
.load_function("fill_rademacher_splitmix")
.gpu_ctx("reml_trace load fill_rademacher")?;
let grid_x = ((p as u32) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let cfg = LaunchConfig {
grid_dim: (grid_x, k as u32, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let seed_arg: u64 = seed.0;
let p_arg: u32 = p as u32;
let k_arg: u32 = k as u32;
// SAFETY: argument count, order and types (u64, u32, u32, double*) match the
// `fill_rademacher_splitmix` signature in PTX_SOURCE. The kernel bounds-checks
// i < p and k < K, and the caller allocates `z` with p*k elements.
unsafe {
stream
.launch_builder(&func)
.arg(&seed_arg)
.arg(&p_arg)
.arg(&k_arg)
.arg(z)
.launch(cfg)
}
.map(|_| ())
.gpu_ctx("reml_trace launch fill_rademacher")
}
/// Launches `reduce_q_dense` with one block per (probe k, derivative j):
/// `q[j*k_total + k] = z_k · y_stack[(j*k_total + k)*p ..]`.
fn launch_reduce_q_dense(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
p: usize,
k: usize,
d: usize,
z: &cudarc::driver::CudaSlice<f64>,
y_stack: &cudarc::driver::CudaSlice<f64>,
q: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
let func = module
.load_function("reduce_q_dense")
.gpu_ctx("reml_trace load reduce_q_dense")?;
let cfg = LaunchConfig {
grid_dim: (k as u32, d as u32, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let p_arg: u32 = p as u32;
let k_arg: u32 = k as u32;
let d_arg: u32 = d as u32;
// SAFETY: argument order/types (u32, u32, u32, const double*, const double*,
// double*) match `reduce_q_dense` in PTX_SOURCE. The caller sizes z as p*k,
// y_stack as p*k*d and q as k*d (d = 1 at the only call site).
unsafe {
stream
.launch_builder(&func)
.arg(&p_arg)
.arg(&k_arg)
.arg(&d_arg)
.arg(z)
.arg(y_stack)
.arg(q)
.launch(cfg)
}
.map(|_| ())
.gpu_ctx("reml_trace launch reduce_q_dense")
}
/// Launches `reduce_q_weighted_gram` with one block per (probe k, structural j):
/// `q[j*k_total + k] = Σ_i a_stack[j*n + i] · rz[k*n + i] · rw[k*n + i]`.
fn launch_reduce_q_weighted_gram(
stream: &Arc<CudaStream>,
module: &Arc<CudaModule>,
n: usize,
k: usize,
d: usize,
rz: &cudarc::driver::CudaSlice<f64>,
rw: &cudarc::driver::CudaSlice<f64>,
a_stack: &cudarc::driver::CudaSlice<f64>,
q: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
let func = module
.load_function("reduce_q_weighted_gram")
.gpu_ctx("reml_trace load reduce_q_weighted_gram")?;
let cfg = LaunchConfig {
grid_dim: (k as u32, d as u32, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
let n_arg: u32 = n as u32;
let k_arg: u32 = k as u32;
let d_arg: u32 = d as u32;
// SAFETY: argument order/types (u32 ×3, const double* ×3, double*) match
// `reduce_q_weighted_gram` in PTX_SOURCE. The caller sizes rz/rw as n*k,
// a_stack as n*d and q as d*k.
unsafe {
stream
.launch_builder(&func)
.arg(&n_arg)
.arg(&k_arg)
.arg(&d_arg)
.arg(rz)
.arg(rw)
.arg(a_stack)
.arg(q)
.launch(cfg)
}
.map(|_| ())
.gpu_ctx("reml_trace launch reduce_q_weighted_gram")
}
/// Device-to-device copy of `src` into `dst` on `stream`.
fn copy_device_slice(
stream: &Arc<CudaStream>,
src: &cudarc::driver::CudaSlice<f64>,
dst: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
stream.memcpy_dtod(src, dst).gpu_ctx("reml_trace dtod copy")
}
/// Dimensions and leading dimensions for a column-major `C (m×n) = A (m×k) · B (k×n)`.
struct GemmShape {
m: usize,
n: usize,
k_inner: usize,
lda: usize,
ldb: usize,
ldc: usize,
}
/// `C = A · B` via cuBLAS DGEMM with no transposes (alpha = 1, beta = 0).
/// All operands are column-major with the leading dimensions in `shape`.
fn gemm_nn(
blas: &CudaBlas,
shape: GemmShape,
a: &cudarc::driver::CudaSlice<f64>,
b: &cudarc::driver::CudaSlice<f64>,
c: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
let GemmShape {
m,
n,
k_inner,
lda,
ldb,
ldc,
} = shape;
let cfg = GemmConfig::<f64> {
transa: cublasOperation_t::CUBLAS_OP_N,
transb: cublasOperation_t::CUBLAS_OP_N,
m: m as i32,
n: n as i32,
k: k_inner as i32,
alpha: 1.0,
lda: lda as i32,
ldb: ldb as i32,
beta: 0.0,
ldc: ldc as i32,
};
// SAFETY: relies on callers passing shapes and leading dimensions that are
// consistent with the buffer lengths of a, b and c (see the call sites in
// `evidence_derivatives`).
unsafe { blas.gemm(cfg, a, b, c) }.gpu_ctx("reml_trace cublas dgemm")
}
}
/// Checks a [`RemlTraceHutchinsonInput`] before any estimator runs.
///
/// Requirements:
/// * `H` is square and non-empty;
/// * at least two probes (a sample SE needs `K >= 2`);
/// * a design matrix is present whenever a `WeightedGram` derivative is;
/// * the design has `p` columns;
/// * every derivative's shapes pass `dim_p`.
///
/// # Errors
/// A message describing the first failed check.
fn validate_inputs(input: &RemlTraceHutchinsonInput<'_>) -> Result<(), String> {
    let (p, p2) = input.penalized_hessian.dim();
    if p == 0 || p != p2 {
        return Err(format!("reml_trace input H must be square, got {p}x{p2}"));
    }
    let probes = input.probe_count;
    if probes < 2 {
        return Err(format!(
            "reml_trace requires probe_count >= 2 for a sample SE, got {}",
            probes
        ));
    }
    let has_structural = input
        .derivatives
        .iter()
        .any(|dh| matches!(dh, DerivativeHessian::WeightedGram { .. }));
    // Row count of the design, or 0 when none was supplied (only Dense derivatives).
    let n = match input.design.as_ref() {
        None if has_structural => {
            return Err("reml_trace: structural derivative present but design=None".to_string());
        }
        None => 0,
        Some(x) if x.ncols() != p => {
            return Err(format!(
                "reml_trace design has {} columns, expected p={p}",
                x.ncols()
            ));
        }
        Some(x) => x.nrows(),
    };
    input
        .derivatives
        .iter()
        .enumerate()
        .try_for_each(|(j, dh)| {
            dh.dim_p(p, n)
                .map_err(|e| format!("reml_trace derivative {j}: {}", String::from(e)))
        })
}
/// Per-row mean and standard error of a derivative-major sample buffer.
///
/// `q[j * k + col]` is sample `col` of row `j`. For every row this returns the
/// mean (sum times `1/k`). It also returns the standard error of that mean,
/// `sqrt(s² / k)` with the unbiased sample variance `s²` (denominator `k − 1`);
/// the standard error is 0 when `k < 2`.
///
/// # Panics
/// If `q.len() != d * k`.
fn reduce_mean_stderr(q: &[f64], d: usize, k: usize) -> (Vec<f64>, Vec<f64>) {
    assert_eq!(
        q.len(),
        d * k,
        "reduce_mean_stderr: q buffer length {} != D*K = {}*{}",
        q.len(),
        d,
        k
    );
    let inv_k = 1.0 / (k as f64);
    (0..d)
        .map(|j| {
            let samples = &q[j * k..(j + 1) * k];
            let mean = samples.iter().copied().sum::<f64>() * inv_k;
            let stderr = if k < 2 {
                0.0
            } else {
                let sample_var =
                    samples.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / ((k - 1) as f64);
                (sample_var / (k as f64)).sqrt()
            };
            (mean, stderr)
        })
        .unzip()
}
/// Dense Cholesky factorisation `A = L Lᵀ`, returning the lower-triangular `L`.
///
/// Only the lower triangle of `matrix` (`j <= i`) is read. The matrix is
/// assumed square; callers validate this first.
///
/// # Errors
/// Returns an error when a pivot is not a finite, strictly positive number,
/// i.e. the matrix is not (numerically) SPD. NaN and infinite pivots are
/// rejected here as well. A plain `sum <= 0.0` test lets them through, which
/// would silently yield a NaN/inf factor and log-determinant.
fn cholesky_lower(matrix: &Array2<f64>) -> Result<Array2<f64>, String> {
    let n = matrix.nrows();
    let mut l = Array2::<f64>::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            let mut sum = matrix[[i, j]];
            for k in 0..j {
                sum -= l[[i, k]] * l[[j, k]];
            }
            if i == j {
                // `!(sum > 0.0)` is true for NaN; non-finite entries elsewhere in
                // the row propagate into this pivot and are caught here too.
                if !(sum > 0.0) || !sum.is_finite() {
                    return Err(format!(
                        "reml_trace CPU Cholesky: non-SPD diagonal {sum} at row {i}"
                    ));
                }
                l[[i, j]] = sum.sqrt();
            } else {
                l[[i, j]] = sum / l[[j, j]];
            }
        }
    }
    Ok(l)
}
/// Solves `L Lᵀ x = rhs` given the lower Cholesky factor `L`.
///
/// Forward substitution gives `L y = rhs`, then back substitution through
/// `Lᵀ` (reading `L` column-wise) gives `Lᵀ x = y`.
///
/// # Panics
/// If `rhs` is shorter than `L`'s dimension.
fn solve_cholesky(l: &Array2<f64>, rhs: &[f64]) -> Vec<f64> {
    let n = l.nrows();
    let mut y: Vec<f64> = Vec::with_capacity(n);
    for i in 0..n {
        let reduced = (0..i).fold(rhs[i], |acc, k| acc - l[[i, k]] * y[k]);
        y.push(reduced / l[[i, i]]);
    }
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let reduced = ((i + 1)..n).fold(y[i], |acc, k| acc - l[[k, i]] * x[k]);
        x[i] = reduced / l[[i, i]];
    }
    x
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array2, ArrayView2};
fn make_spd(p: usize, jitter: f64) -> Array2<f64> {
let mut h = Array2::<f64>::zeros((p, p));
for i in 0..p {
for j in 0..p {
h[[i, j]] = if i == j {
p as f64 + jitter
} else {
1.0 / (1.0 + (i as f64 - j as f64).abs())
};
}
}
h
}
fn random_dense_sym(p: usize, seed: u64) -> Array2<f64> {
let mut a = Array2::<f64>::zeros((p, p));
let mut s = seed;
for i in 0..p {
for j in i..p {
s = splitmix64_mix(s.wrapping_add(1));
let v = ((s >> 11) as f64) / ((1u64 << 53) as f64) - 0.5;
a[[i, j]] = v;
a[[j, i]] = v;
}
}
a
}
fn exact_trace_hinv_a(h: ArrayView2<f64>, a: ArrayView2<f64>) -> f64 {
let p = h.nrows();
let factor = cholesky_lower(&h.to_owned()).expect("SPD");
let mut trace = 0.0;
for col in 0..p {
let mut e = vec![0.0_f64; p];
e[col] = 1.0;
let w = solve_cholesky(&factor, &e);
let mut diag = 0.0;
for i in 0..p {
diag += a[[col, i]] * w[i];
}
trace += diag;
}
trace
}
#[test]
fn splitmix_is_deterministic_and_disperses() {
assert_eq!(splitmix64_mix(42), splitmix64_mix(42));
let mut bits_seen = 0u64;
for x in 0u64..64 {
bits_seen |= splitmix64_mix(x);
}
assert_eq!(
bits_seen,
u64::MAX,
"splitmix should cover every bit position across 64 inputs"
);
}
#[test]
fn rademacher_entries_are_pm_one_and_stateless() {
let seed = ProbeSeed(0xCAFE_BABE);
for k in 0..16u64 {
for i in 0..32u64 {
let v = rademacher_entry(seed.0, k, i);
assert!(
v == 1.0 || v == -1.0,
"non-pm1 entry at (k={k}, i={i}): {v}"
);
let v2 = rademacher_entry(seed.0, k, i);
assert_eq!(v, v2, "same (k,i) must hash to same value");
}
}
}
#[test]
fn rademacher_common_random_numbers_match_for_prefix() {
let p = 50;
let mut z16 = vec![0.0_f64; p * 16];
let mut z32 = vec![0.0_f64; p * 32];
fill_rademacher_host(ProbeSeed(7), p, 16, &mut z16);
fill_rademacher_host(ProbeSeed(7), p, 32, &mut z32);
for col in 0..16 {
for row in 0..p {
assert_eq!(
z16[col * p + row],
z32[col * p + row],
"CRN broken at (col={col}, row={row})"
);
}
}
}
#[test]
fn cpu_hutchinson_unbiased_against_exact_small_spd() {
let p = 16;
let h = make_spd(p, 0.5);
let a1 = random_dense_sym(p, 0x1234);
let a2 = random_dense_sym(p, 0x5678);
let exact1 = exact_trace_hinv_a(h.view(), a1.view());
let exact2 = exact_trace_hinv_a(h.view(), a2.view());
let input = RemlTraceHutchinsonInput {
penalized_hessian: h.view(),
derivatives: vec![
DerivativeHessian::Dense(a1.view()),
DerivativeHessian::Dense(a2.view()),
],
design: None,
probe_count: 4096,
seed: ProbeSeed(0xCAFE_BABE),
};
let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
let est1 = 2.0 * evidence.gradient_rho_logdet[0];
let est2 = 2.0 * evidence.gradient_rho_logdet[1];
let tol1 = 6.0 * evidence.gradient_rho_stderr[0].max(1e-8) * 2.0;
let tol2 = 6.0 * evidence.gradient_rho_stderr[1].max(1e-8) * 2.0;
assert!(
(est1 - exact1).abs() <= tol1,
"Hutchinson est {est1} too far from exact {exact1} (tol={tol1}, se={})",
evidence.gradient_rho_stderr[0]
);
assert!(
(est2 - exact2).abs() <= tol2,
"Hutchinson est {est2} too far from exact {exact2} (tol={tol2})"
);
}
#[test]
fn structural_path_matches_dense_for_xtwx() {
let n = 40;
let p = 8;
let mut x = Array2::<f64>::zeros((n, p));
let mut s = 11u64;
for r in 0..n {
for c in 0..p {
s = splitmix64_mix(s.wrapping_add(1));
x[[r, c]] = ((s >> 11) as f64) / ((1u64 << 53) as f64) - 0.5;
}
}
let a: Vec<f64> = (0..n).map(|i| 0.5 + 0.01 * (i as f64)).collect();
let a_arr = ndarray::Array1::from(a);
let mut hj_dense = Array2::<f64>::zeros((p, p));
for r in 0..p {
for c in 0..p {
let mut acc = 0.0;
for i in 0..n {
acc += x[[i, r]] * a_arr[i] * x[[i, c]];
}
hj_dense[[r, c]] = acc;
}
}
let mut h = make_spd(p, 1.0);
for i in 0..p {
h[[i, i]] += 1.0;
}
let input_dense = RemlTraceHutchinsonInput {
penalized_hessian: h.view(),
derivatives: vec![DerivativeHessian::Dense(hj_dense.view())],
design: None,
probe_count: 32,
seed: ProbeSeed(123),
};
let input_struct = RemlTraceHutchinsonInput {
penalized_hessian: h.view(),
derivatives: vec![DerivativeHessian::WeightedGram {
row_weights: a_arr.view(),
penalty_extra: None,
}],
design: Some(x.view()),
probe_count: 32,
seed: ProbeSeed(123),
};
let e_dense = evidence_derivatives_hutchinson_cpu(&input_dense).expect("ok");
let e_struct = evidence_derivatives_hutchinson_cpu(&input_struct).expect("ok");
assert!(
(e_dense.gradient_rho_logdet[0] - e_struct.gradient_rho_logdet[0]).abs() < 1e-9,
"dense vs structural mismatch: dense={}, struct={}",
e_dense.gradient_rho_logdet[0],
e_struct.gradient_rho_logdet[0]
);
}
/// Cross-checks three views of `d/dt log|H0 + t·A|` at `t = 0`:
/// a central finite difference of the Cholesky log-determinant, the exact trace
/// `tr(H0⁻¹ A)` from `exact_trace_hinv_a`, and the CPU Hutchinson estimator with
/// K = 4096 probes. The test expects `gradient_rho_logdet` to equal half the trace
/// (hence the `0.5 * exact` comparison), within 8 reported standard errors.
#[test]
fn finite_difference_check_against_logdet() {
    let p = 10;
    let h0 = make_spd(p, 0.2);
    let a = random_dense_sym(p, 0xABCD);
    let eps = 1e-4;
    // H0 ± eps·A. `scaled_add` performs the same element-wise `x += alpha * y`
    // as an explicit double loop.
    let mut hp = h0.clone();
    hp.scaled_add(eps, &a);
    let mut hm = h0.clone();
    hm.scaled_add(-eps, &a);
    // log|M| = 2 Σ ln L_ii for the lower Cholesky factor L of M. A small eps keeps
    // both perturbed matrices SPD, so a failed factorization is a test bug.
    let ld = |m: &Array2<f64>| -> f64 {
        let l = cholesky_lower(m).expect("SPD");
        2.0 * (0..p).map(|i| l[[i, i]].ln()).sum::<f64>()
    };
    let fd = (ld(&hp) - ld(&hm)) / (2.0 * eps);
    let exact = exact_trace_hinv_a(h0.view(), a.view());
    assert!(
        (fd - exact).abs() / exact.abs().max(1e-12) < 1e-6,
        "FD logdet derivative {fd} != exact trace {exact}"
    );
    let input = RemlTraceHutchinsonInput {
        penalized_hessian: h0.view(),
        derivatives: vec![DerivativeHessian::Dense(a.view())],
        design: None,
        probe_count: 4096,
        seed: ProbeSeed(0xAA55),
    };
    let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
    // Floor the stderr so a degenerate (zero-variance) run still gets a usable tolerance.
    let tol = 8.0 * evidence.gradient_rho_stderr[0].max(1e-8);
    assert!(
        (evidence.gradient_rho_logdet[0] - 0.5 * exact).abs() < tol,
        "Hutchinson gradient {} not within 8·SE of 0.5·exact={}",
        evidence.gradient_rho_logdet[0],
        0.5 * exact
    );
}
/// A problem size of p = 64 sits below `HUTCHINSON_GPU_MIN_P` (512), so the GPU
/// gate must decline even though every other argument favours the GPU path.
#[test]
fn gate_rejects_below_min_p() {
    let use_gpu = should_use_gpu_hutchinson(64, 16, true, true, true, false);
    assert!(!use_gpu);
}
/// Probe counts outside `[HUTCHINSON_GPU_MIN_K, HUTCHINSON_GPU_MAX_K]` = [8, 128]
/// must be rejected by the gate: K = 4 is too small and K = 200 too large.
#[test]
fn gate_rejects_k_out_of_range() {
    for k in [4, 200] {
        assert!(!should_use_gpu_hutchinson(2000, k, true, true, true, false));
    }
}
/// Same arguments as the canonical accepted case, except the final flag (the
/// "subspace active" condition named by this test) is set, which must veto the GPU path.
#[test]
fn gate_rejects_when_subspace_active() {
    let gated_on = should_use_gpu_hutchinson(2000, 16, true, true, true, true);
    assert!(!gated_on);
}
/// p = 2000 and K = 16 satisfy the size/probe bounds, with `prefers_stochastic`,
/// `kernel_matches_hinv` and `plain_spd_logdet` all true and the final flag cleared:
/// the gate must select the GPU path.
#[test]
fn gate_accepts_canonical_case() {
    let use_gpu = should_use_gpu_hutchinson(2000, 16, true, true, true, false);
    assert!(use_gpu);
}
/// The adaptive Hutchinson estimator should land near the exact `tr(H⁻¹ A)` for a
/// dense derivative: within 8 standard errors, or 5 % of the exact value, whichever
/// is looser.
///
/// NOTE(review): despite the `_p512` suffix, this test runs at p = 64.
#[test]
fn block_2_6_adaptive_unbiased_against_exact_p512() {
    let p = 64;
    let h = make_spd(p, 0.5);
    let a = random_dense_sym(p, 0xBADC0DE);
    let exact = exact_trace_hinv_a(h.view(), a.view());
    let derivatives = vec![DerivativeHessian::Dense(a.view())];
    let evidence = evidence_traces_adaptive(
        h.view(),
        derivatives,
        None,
        ProbeSeed(0xA5A5A5),
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("adaptive run ok");
    let k = evidence.probe_count;
    let est = evidence.traces[0];
    // Assumes `stderrs` is a per-probe spread, so dividing by √K gives the standard
    // error of the mean. TODO confirm against `evidence_traces_adaptive`.
    let se = evidence.stderrs[0] / (k as f64).sqrt();
    let tol = (8.0 * se).max(0.05 * exact.abs());
    assert!(
        (est - exact).abs() <= tol,
        "adaptive est {est} far from exact {exact} (tol={tol}, se={se}, K={})",
        k
    );
}
/// With identical inputs and seed, the CPU reference and the GPU dispatcher must
/// agree on the derivative estimate to within 1e-9. The two paths draw the same probes.
#[test]
fn block_2_6_same_probes_cpu_vs_dispatch() {
    let p = 32;
    let h = make_spd(p, 0.3);
    let a = random_dense_sym(p, 0x1357);
    let input = RemlTraceHutchinsonInput {
        penalized_hessian: h.view(),
        derivatives: vec![DerivativeHessian::Dense(a.view())],
        design: None,
        probe_count: 16,
        seed: ProbeSeed(0xBEEF),
    };
    // The CPU path borrows `input`; the dispatcher consumes it, so it must run second.
    let cpu_grad = evidence_derivatives_hutchinson_cpu(&input)
        .expect("cpu")
        .gradient_rho_logdet[0];
    let dispatch_grad = evidence_derivatives_hutchinson_gpu(input)
        .expect("dispatch")
        .gradient_rho_logdet[0];
    let diff = (cpu_grad - dispatch_grad).abs();
    assert!(
        diff < 1e-9,
        "same-probes CPU vs GPU dispatch differ: cpu={}, dispatch={}, diff={diff}",
        cpu_grad,
        dispatch_grad
    );
}
/// The adaptive trace estimate of `tr(H⁻¹ A)` should match a central finite
/// difference of `log|H + t·A|` at `t = 0`. The tolerance is the looser of
/// 8 standard errors or 5 % of the FD value.
#[test]
fn block_2_6_fd_logdet_matches_adaptive() {
    let p = 24;
    let h = make_spd(p, 0.4);
    let a = random_dense_sym(p, 0x2468);
    let eps = 1e-4;
    // H ± eps·A via the element-wise `x += alpha * y` of `scaled_add`.
    let mut hp = h.clone();
    hp.scaled_add(eps, &a);
    let mut hm = h.clone();
    hm.scaled_add(-eps, &a);
    // log-determinant from the diagonal of the lower Cholesky factor.
    let ld = |m: &Array2<f64>| -> f64 {
        let l = cholesky_lower(m).expect("SPD");
        2.0 * (0..p).map(|i| l[[i, i]].ln()).sum::<f64>()
    };
    let fd = (ld(&hp) - ld(&hm)) / (2.0 * eps);
    let evidence = evidence_traces_adaptive(
        h.view(),
        vec![DerivativeHessian::Dense(a.view())],
        None,
        ProbeSeed(0x9999),
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("adaptive ok");
    let root_k = (evidence.probe_count as f64).sqrt();
    let est = evidence.traces[0];
    let se = evidence.stderrs[0] / root_k;
    let tol = (8.0 * se).max(0.05 * fd.abs());
    assert!(
        (est - fd).abs() <= tol,
        "adaptive trace {est} disagrees with FD logdet derivative {fd} (tol={tol})"
    );
}
/// With K = 4096 probes through the GPU entry point, twice the reported
/// `gradient_rho_logdet` should match the exact `tr(H⁻¹ A)`. The tolerance is the
/// looser of 6 standard errors or 0.1 % of the exact value.
///
/// NOTE(review): this test divides `gradient_rho_stderr` by √K, whereas
/// `finite_difference_check_against_logdet` uses the same field directly as an
/// SE. One of the two interpretations is probably wrong. TODO confirm which.
#[test]
fn block_2_6_k_4096_matches_exact_tightly() {
    const K: usize = 4096;
    let p = 40;
    let h = make_spd(p, 0.6);
    let a = random_dense_sym(p, 0xDEAD);
    let exact = exact_trace_hinv_a(h.view(), a.view());
    let evidence = evidence_derivatives_hutchinson_gpu(RemlTraceHutchinsonInput {
        penalized_hessian: h.view(),
        derivatives: vec![DerivativeHessian::Dense(a.view())],
        design: None,
        probe_count: K,
        seed: ProbeSeed(0xC0FFEE),
    })
    .expect("ok");
    // The gradient carries a factor of 1/2 relative to the trace; undo it for both
    // the estimate and its spread.
    let est = 2.0 * evidence.gradient_rho_logdet[0];
    let se = 2.0 * evidence.gradient_rho_stderr[0] / (K as f64).sqrt();
    let tol = (6.0 * se).max(1e-3 * exact.abs());
    assert!(
        (est - exact).abs() <= tol,
        "K=4096 Hutchinson {est} not within 6·SE of exact {exact} (tol={tol}, se={se})"
    );
}
/// Common random numbers across the adaptive K schedule: for a fixed seed, the
/// Rademacher block drawn for K probes must be a prefix of the block drawn for any
/// larger K. Buffers are laid out column-major (`col * p + row`), so the first k
/// probe columns are exactly the first `k * p` entries.
#[test]
fn block_2_6_crn_prefix_match_across_schedule() {
    let p = 50;
    let seed = ProbeSeed(0x4242_4242);
    let draw = |k: usize| {
        let mut z = vec![0.0_f64; p * k];
        fill_rademacher_host(seed, p, k, &mut z);
        z
    };
    let (z16, z32, z64) = (draw(16), draw(32), draw(64));
    assert_eq!(z16[..], z32[..16 * p]);
    assert_eq!(z16[..], z64[..16 * p]);
    assert_eq!(z32[..], z64[..32 * p]);
}
/// The matrix-free (HVP) adaptive estimator and the dense adaptive estimator, run
/// with the same seed, must each land near the exact `tr(H⁻¹ A)`. Each uses the
/// looser of 8 standard errors or 5 % of the exact value. The test also expects the
/// HVP path to report `logdet_hessian` as NaN.
#[test]
fn block_2_7_hvp_path_matches_dense_adaptive() {
    let p = 40;
    let h = make_spd(p, 0.7);
    let a = random_dense_sym(p, 0xABBA);
    let seed = ProbeSeed(0x707);
    let dense = evidence_traces_adaptive(
        h.view(),
        vec![DerivativeHessian::Dense(a.view())],
        None,
        seed,
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("dense ok");
    // The (non-`move`) closure only borrows `h`, so no copy of the p×p matrix is needed.
    let hvp_evidence = evidence_traces_adaptive_hvp(
        p,
        |v: &[f64], out: &mut [f64]| {
            // Dense matvec out = H·v, standing in for a matrix-free Hessian product.
            for r in 0..p {
                let mut acc = 0.0_f64;
                for c in 0..p {
                    acc += h[[r, c]] * v[c];
                }
                out[r] = acc;
            }
        },
        vec![DerivativeHessian::Dense(a.view())],
        None,
        seed,
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("hvp ok");
    let exact = exact_trace_hinv_a(h.view(), a.view());
    let se_dense = dense.stderrs[0] / (dense.probe_count as f64).sqrt();
    let se_hvp = hvp_evidence.stderrs[0] / (hvp_evidence.probe_count as f64).sqrt();
    let tol_dense = (8.0 * se_dense).max(0.05 * exact.abs());
    let tol_hvp = (8.0 * se_hvp).max(0.05 * exact.abs());
    assert!(
        (dense.traces[0] - exact).abs() <= tol_dense,
        "dense adaptive {} not near exact {} (tol {})",
        dense.traces[0],
        exact,
        tol_dense
    );
    assert!(
        (hvp_evidence.traces[0] - exact).abs() <= tol_hvp,
        "hvp adaptive {} not near exact {} (tol {})",
        hvp_evidence.traces[0],
        exact,
        tol_hvp
    );
    assert!(hvp_evidence.logdet_hessian.is_nan());
}
/// `cg_solve` on a diagonal operator `diag(1, 2, …, 8)` must reproduce
/// `w_i = b_i / d_i` to within 1e-10.
///
/// NOTE(review): despite the name, the iteration count is not asserted; the solver
/// is allowed up to `PCG_HVP_MAX_ITERS`. Unpreconditioned CG needs up to as many
/// iterations as there are distinct eigenvalues (8 here). Only a Jacobi-preconditioned
/// solve would finish in one.
#[test]
fn block_2_7_cg_solves_diagonal_in_one_iteration() {
    let p = 8;
    let diag: Vec<f64> = (0..p).map(|i| 1.0 + i as f64).collect();
    let b: Vec<f64> = (0..p).map(|i| (i as f64) + 0.5).collect();
    let mut w = vec![0.0_f64; p];
    // The operator closure only needs a shared borrow of `diag`, so no clone is required.
    cg_solve(
        &mut |v: &[f64], out: &mut [f64]| {
            for i in 0..p {
                out[i] = diag[i] * v[i];
            }
        },
        &b,
        &mut w,
        1e-12,
        PCG_HVP_MAX_ITERS,
    );
    for i in 0..p {
        let expected = b[i] / diag[i];
        assert!(
            (w[i] - expected).abs() < 1e-10,
            "diagonal CG: w[{i}]={} expected {expected}",
            w[i]
        );
    }
}
/// Compares adaptive Hutchinson traces against exact `tr(H⁻¹ A_j)` for `d` random
/// symmetric derivatives and times both paths. `H` is the symmetric Toeplitz matrix
/// with `p + 1` on the diagonal and `1 / (1 + |i - j|)` off it. It is strictly
/// diagonally dominant and therefore SPD.
///
/// The large configuration (p = 2000, d = 8) and the ≥ 10× speedup assertion apply
/// only when `on_v100` holds. NOTE(review): `on_v100` is true whenever the target OS
/// is Linux and a global GPU runtime exists; it does not check the device model.
#[test]
fn block_2_8_hill_climb_adaptive_vs_exact_at_p2000_d8() {
    let on_v100 =
        cfg!(target_os = "linux") && super::super::runtime::GpuRuntime::global().is_some();
    let (p, d): (usize, usize) = if on_v100 { (2000, 8) } else { (256, 4) };
    let h = Array2::<f64>::from_shape_fn((p, p), |(i, j)| {
        if i == j {
            p as f64 + 1.0
        } else {
            1.0 / (1.0 + (i as f64 - j as f64).abs())
        }
    });
    let derivs_owned: Vec<Array2<f64>> = (0..d)
        .map(|k| random_dense_sym(p, 0x1000 + k as u64))
        .collect();
    let derivs: Vec<DerivativeHessian<'_>> = derivs_owned
        .iter()
        .map(|a| DerivativeHessian::Dense(a.view()))
        .collect();

    // Exact baseline: one Cholesky factorization, then one full solve per column of
    // each A_j, keeping only the diagonal entry of H⁻¹ A_j.
    let t_exact_start = std::time::Instant::now();
    let factor = cholesky_lower(&h).expect("SPD");
    let exact_traces: Vec<f64> = derivs_owned
        .iter()
        .map(|a| {
            (0..p).fold(0.0_f64, |acc, col| {
                let rhs: Vec<f64> = a.column(col).to_vec();
                acc + solve_cholesky(&factor, &rhs)[col]
            })
        })
        .collect();
    let t_exact = t_exact_start.elapsed();

    let t_adaptive_start = std::time::Instant::now();
    let evidence = evidence_traces_adaptive(
        h.view(),
        derivs,
        None,
        ProbeSeed(0xB10C),
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("adaptive ok");
    let t_adaptive = t_adaptive_start.elapsed();

    let root_k = (evidence.probe_count as f64).sqrt();
    for (j, &exact) in exact_traces.iter().enumerate() {
        let se = evidence.stderrs[j] / root_k;
        let tol = (10.0 * se).max(0.05 * exact.abs());
        let diff = (evidence.traces[j] - exact).abs();
        assert!(
            diff <= tol,
            "block_2_8: derivative {j} adaptive {} disagrees with exact {} (tol {tol}, diff {diff})",
            evidence.traces[j],
            exact
        );
    }

    // Guard against a zero-duration adaptive run when forming the ratio.
    let speedup = t_exact.as_secs_f64() / t_adaptive.as_secs_f64().max(1e-9);
    eprintln!(
        "block_2_8 hill-climb [p={p}, d={d}, V100={on_v100}]: \
        exact={:?}, adaptive={:?}, speedup={:.2}× (K={}, converged={})",
        t_exact, t_adaptive, speedup, evidence.probe_count, evidence.converged
    );
    if on_v100 {
        assert!(
            speedup >= 10.0,
            "block_2_8 V100 speedup {speedup:.2}× below the 10× target \
            (exact {:?}, adaptive {:?})",
            t_exact,
            t_adaptive,
        );
    }
}
}