1use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
74
75use gam_gpu::gpu_error::GpuError;
76use gam_linalg::pcg::{DotReduction, pcg_core};
77
/// Seed from which every Rademacher probe entry is derived statelessly.
///
/// Two runs with the same seed produce identical probe matrices, on CPU and GPU alike.
#[derive(Clone, Copy, Debug)]
pub struct ProbeSeed(pub u64);

impl Default for ProbeSeed {
    /// Returns the fixed seed `0xCAFE_BABE`.
    fn default() -> Self {
        const DEFAULT_PROBE_SEED: u64 = 0xCAFE_BABE;
        ProbeSeed(DEFAULT_PROBE_SEED)
    }
}
93
94#[derive(Clone, Debug)]
100pub enum DerivativeHessian<'a> {
101 Dense(ArrayView2<'a, f64>),
104 WeightedGram {
110 row_weights: ArrayView1<'a, f64>,
111 penalty_extra: Option<ArrayView2<'a, f64>>,
112 },
113}
114
115impl DerivativeHessian<'_> {
116 fn dim_p(&self, expected_p: usize, expected_n: usize) -> Result<(), GpuError> {
117 match self {
118 DerivativeHessian::Dense(matrix) => {
119 if matrix.nrows() != expected_p || matrix.ncols() != expected_p {
120 gam_gpu::gpu_bail!(
121 "reml_trace dense H_j: shape {:?} != ({expected_p}, {expected_p})",
122 matrix.dim()
123 );
124 }
125 }
126 DerivativeHessian::WeightedGram {
127 row_weights,
128 penalty_extra,
129 } => {
130 if row_weights.len() != expected_n {
131 gam_gpu::gpu_bail!(
132 "reml_trace structural H_j: row_weights.len()={} != n={expected_n}",
133 row_weights.len()
134 );
135 }
136 if let Some(p_extra) = penalty_extra
137 && (p_extra.nrows() != expected_p || p_extra.ncols() != expected_p)
138 {
139 gam_gpu::gpu_bail!(
140 "reml_trace structural H_j penalty_extra: shape {:?} != ({expected_p}, {expected_p})",
141 p_extra.dim()
142 );
143 }
144 }
145 }
146 Ok(())
147 }
148}
149
/// Inputs for the Hutchinson estimate of `tr(H^{-1} H_j)` for every derivative `H_j`.
///
/// Shapes are checked by `validate_inputs` before any computation.
#[derive(Clone, Debug)]
pub struct RemlTraceHutchinsonInput<'a> {
    /// Penalized Hessian `H`; must be square, non-empty, and Cholesky-factorizable.
    pub penalized_hessian: ArrayView2<'a, f64>,
    /// Derivative matrices `H_j` (presumably `∂H/∂ρ_j` — confirm with callers).
    pub derivatives: Vec<DerivativeHessian<'a>>,
    /// Design `X` (`n × p`); required when any derivative is `WeightedGram`.
    pub design: Option<ArrayView2<'a, f64>>,
    /// Number of Rademacher probes `K`; must be >= 2 so a sample spread exists.
    pub probe_count: usize,
    /// Seed for the stateless probe generator.
    pub seed: ProbeSeed,
}

/// Result of one fixed-`K` Hutchinson run.
#[derive(Clone, Debug)]
pub struct RemlTraceHutchinsonEvidence {
    /// `log|H|`, computed exactly from the Cholesky factor as `2 Σ ln L_ii`.
    pub logdet_hessian: f64,
    /// `0.5 ×` the probe-mean estimate of `tr(H^{-1} H_j)`, one entry per derivative.
    pub gradient_rho_logdet: Array1<f64>,
    /// `0.5 ×` the per-probe sample standard deviation of that estimate (NOT divided by
    /// `sqrt(K)`; divide by `sqrt(probe_count)` for the standard error of the mean).
    pub gradient_rho_stderr: Array1<f64>,
    /// Number of probes actually used.
    pub probe_count: usize,
}
182
/// Smallest Hessian dimension `p` for which the GPU Hutchinson path is considered.
pub const HUTCHINSON_GPU_MIN_P: usize = 512;
/// Smallest probe count accepted by the GPU gate.
pub const HUTCHINSON_GPU_MIN_K: usize = 8;
/// Largest probe count accepted by the GPU gate.
pub const HUTCHINSON_GPU_MAX_K: usize = 128;

/// Decides whether a fixed-`K` Hutchinson trace estimate should be routed to the GPU.
///
/// All of the following must hold: `p >= HUTCHINSON_GPU_MIN_P`, the probe count lies in
/// `HUTCHINSON_GPU_MIN_K..=HUTCHINSON_GPU_MAX_K`, the caller prefers a stochastic estimate,
/// the kernel matches `H^{-1}`, the log-determinant is a plain SPD one, and no projected
/// penalty subspace is active.
#[must_use]
pub fn should_use_gpu_hutchinson(
    p: usize,
    probe_count: usize,
    prefers_stochastic: bool,
    kernel_matches_hinv: bool,
    plain_spd_logdet: bool,
    projected_penalty_subspace_active: bool,
) -> bool {
    let large_enough = p >= HUTCHINSON_GPU_MIN_P;
    let probes_in_range =
        probe_count >= HUTCHINSON_GPU_MIN_K && probe_count <= HUTCHINSON_GPU_MAX_K;
    let caller_allows = prefers_stochastic && kernel_matches_hinv && plain_spd_logdet;
    large_enough && probes_in_range && caller_allows && !projected_penalty_subspace_active
}
216
217#[inline]
226pub fn splitmix64_mix(z: u64) -> u64 {
227 gam_linalg::utils::splitmix64_hash(z)
228}
229
230#[inline]
238pub fn rademacher_entry(seed: u64, k: u64, i: u64) -> f64 {
239 const ZETA: u64 = 0xD1B5_4A32_D192_ED03;
240 const GAMMA: u64 = 0x8CB9_2BA7_2F9D_E81F;
241 let composite = seed ^ k.wrapping_mul(ZETA) ^ i.wrapping_mul(GAMMA);
242 let h = splitmix64_mix(composite);
243 if (h >> 63) == 0 { 1.0 } else { -1.0 }
244}
245
246pub fn fill_rademacher_host(seed: ProbeSeed, p: usize, k: usize, out: &mut [f64]) {
249 assert_eq!(
250 out.len(),
251 p * k,
252 "fill_rademacher_host: out buffer length {} != p*K = {}*{}",
253 out.len(),
254 p,
255 k
256 );
257 for col in 0..k {
258 for row in 0..p {
259 out[col * p + row] = rademacher_entry(seed.0, col as u64, row as u64);
260 }
261 }
262}
263
264pub fn evidence_derivatives_hutchinson_cpu(
275 input: &RemlTraceHutchinsonInput<'_>,
276) -> Result<RemlTraceHutchinsonEvidence, String> {
277 validate_inputs(input)?;
278 let p = input.penalized_hessian.nrows();
279 let d = input.derivatives.len();
280 let k = input.probe_count;
281
282 let h = input.penalized_hessian.to_owned();
284 let factor = cholesky_lower(&h)?;
285 let logdet_hessian = 2.0 * (0..p).map(|i| factor[[i, i]].ln()).sum::<f64>();
286
287 let mut z = vec![0.0_f64; p * k];
289 fill_rademacher_host(input.seed, p, k, &mut z);
290
291 use rayon::prelude::*;
299 let mut w = vec![0.0_f64; p * k];
300 w.par_chunks_mut(p)
301 .zip(z.par_chunks(p))
302 .for_each(|(w_col, z_col)| {
303 let solved = solve_cholesky(&factor, z_col);
304 w_col.copy_from_slice(&solved);
305 });
306
307 let mut q = vec![0.0_f64; d * k]; for (j, derivative) in input.derivatives.iter().enumerate() {
315 let q_row = &mut q[j * k..(j + 1) * k];
316 match derivative {
317 DerivativeHessian::Dense(matrix) => {
318 q_row
319 .par_iter_mut()
320 .zip(z.par_chunks(p).zip(w.par_chunks(p)))
321 .for_each(|(q_jk, (z_col, w_col))| {
322 let mut y = vec![0.0_f64; p];
324 for r in 0..p {
325 let mut acc = 0.0_f64;
326 for c in 0..p {
327 acc += matrix[[r, c]] * w_col[c];
328 }
329 y[r] = acc;
330 }
331 let mut zy = 0.0_f64;
332 for i in 0..p {
333 zy += z_col[i] * y[i];
334 }
335 *q_jk = zy;
336 });
337 }
338 DerivativeHessian::WeightedGram {
339 row_weights,
340 penalty_extra,
341 } => {
342 let design = input.design.as_ref().expect("design validated");
343 let n = design.nrows();
344 q_row
345 .par_iter_mut()
346 .zip(z.par_chunks(p).zip(w.par_chunks(p)))
347 .for_each(|(q_jk, (z_col, w_col))| {
348 let mut acc = 0.0_f64;
350 for row in 0..n {
351 let mut rz = 0.0_f64;
352 let mut rw = 0.0_f64;
353 for col_idx in 0..p {
354 rz += design[[row, col_idx]] * z_col[col_idx];
355 rw += design[[row, col_idx]] * w_col[col_idx];
356 }
357 acc += row_weights[row] * rz * rw;
358 }
359 if let Some(pen) = penalty_extra {
360 for r in 0..p {
361 let mut row_acc = 0.0_f64;
362 for c in 0..p {
363 row_acc += pen[[r, c]] * w_col[c];
364 }
365 acc += z_col[r] * row_acc;
366 }
367 }
368 *q_jk = acc;
369 });
370 }
371 }
372 }
373
374 let (means, stderrs) = reduce_mean_stderr(&q, d, k);
375 let mut gradient_rho_logdet = Array1::<f64>::zeros(d);
376 let mut gradient_rho_stderr = Array1::<f64>::zeros(d);
377 for j in 0..d {
378 gradient_rho_logdet[j] = 0.5 * means[j];
379 gradient_rho_stderr[j] = 0.5 * stderrs[j];
380 }
381
382 Ok(RemlTraceHutchinsonEvidence {
383 logdet_hessian,
384 gradient_rho_logdet,
385 gradient_rho_stderr,
386 probe_count: k,
387 })
388}
389
390pub fn evidence_derivatives_hutchinson_gpu(
400 input: RemlTraceHutchinsonInput<'_>,
401) -> Result<RemlTraceHutchinsonEvidence, String> {
402 validate_inputs(&input)?;
403
404 #[cfg(target_os = "linux")]
405 {
406 if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
407 match linux_cuda::evidence_derivatives(&input) {
408 Ok(evidence) => return Ok(evidence),
409 Err(GpuError::NoDeviceKernel { .. }) => {
410 }
413 Err(other) => return Err(String::from(other)),
414 }
415 }
416 }
417
418 evidence_derivatives_hutchinson_cpu(&input)
419}
420
/// Relative standard-error target for the adaptive probe schedule (presumably the default
/// `rel_tol` argument — confirm with callers).
pub const HUTCHINSON_ADAPTIVE_REL_TOL: f64 = 0.01;
/// Floor applied to `|trace|` in the convergence denominator, so near-zero traces do not
/// blow up the ratio (presumably the default `tau_rel` argument — confirm with callers).
pub const HUTCHINSON_ADAPTIVE_TAU_REL: f64 = 1e-8;

/// Result of an adaptive Hutchinson run.
pub struct AdaptiveTraceEvidence {
    /// `log|H|`; `f64::NAN` from the matrix-free `evidence_traces_adaptive_hvp`, which never
    /// factors `H`.
    pub logdet_hessian: f64,
    /// Probe-mean estimates of `tr(H^{-1} H_j)` (no 0.5 scaling).
    pub traces: Array1<f64>,
    /// Per-probe sample standard deviations of those estimates (NOT divided by `sqrt(K)`).
    pub stderrs: Array1<f64>,
    /// Probe count of the last schedule step evaluated.
    pub probe_count: usize,
    /// Whether the relative-error target was met before the schedule ran out.
    pub converged: bool,
}
466
467pub fn evidence_traces_adaptive<'a>(
468 penalized_hessian: ArrayView2<'a, f64>,
469 derivatives: Vec<DerivativeHessian<'a>>,
470 design: Option<ArrayView2<'a, f64>>,
471 seed: ProbeSeed,
472 rel_tol: f64,
473 tau_rel: f64,
474) -> Result<AdaptiveTraceEvidence, String> {
475 const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
477
478 let d = derivatives.len();
479 if d == 0 {
480 return Err("evidence_traces_adaptive: derivatives is empty".to_string());
481 }
482 if !(rel_tol > 0.0) {
483 return Err(format!(
484 "evidence_traces_adaptive: rel_tol must be > 0 (got {rel_tol})"
485 ));
486 }
487 if !(tau_rel > 0.0) {
488 return Err(format!(
489 "evidence_traces_adaptive: tau_rel must be > 0 (got {tau_rel})"
490 ));
491 }
492
493 let mut last_logdet = 0.0_f64;
494 let mut last_traces = Array1::<f64>::zeros(d);
495 let mut last_stderrs = Array1::<f64>::zeros(d);
496 let mut last_k = 0_usize;
497 let mut converged = false;
498
499 for &k in &SCHEDULE {
500 let input = RemlTraceHutchinsonInput {
501 penalized_hessian,
502 derivatives: derivatives.clone(),
503 design,
504 probe_count: k,
505 seed,
506 };
507 let evidence = evidence_derivatives_hutchinson_gpu(input)?;
508 last_logdet = evidence.logdet_hessian;
509 last_k = k;
510
511 for j in 0..d {
516 last_traces[j] = 2.0 * evidence.gradient_rho_logdet[j];
517 last_stderrs[j] = 2.0 * evidence.gradient_rho_stderr[j];
518 }
519
520 let sqrt_k = (k as f64).sqrt();
529 let mut worst = 0.0_f64;
530 for j in 0..d {
531 let denom = sqrt_k * last_traces[j].abs().max(tau_rel);
532 let r = last_stderrs[j] / denom;
533 if r > worst {
534 worst = r;
535 }
536 }
537 if worst <= rel_tol {
538 converged = true;
539 break;
540 }
541 }
542
543 Ok(AdaptiveTraceEvidence {
544 logdet_hessian: last_logdet,
545 traces: last_traces,
546 stderrs: last_stderrs,
547 probe_count: last_k,
548 converged,
549 })
550}
551
/// Relative residual tolerance passed to `cg_solve` for each probe solve
/// `H w = z` in `evidence_traces_adaptive_hvp`.
pub const PCG_HVP_REL_TOL: f64 = 1e-6;

/// Iteration cap passed to `cg_solve` for each probe solve in
/// `evidence_traces_adaptive_hvp`.
pub const PCG_HVP_MAX_ITERS: usize = 200;
568
569pub fn evidence_traces_adaptive_hvp<F>(
600 p: usize,
601 mut hvp: F,
602 derivatives: Vec<DerivativeHessian<'_>>,
603 design: Option<ArrayView2<'_, f64>>,
604 seed: ProbeSeed,
605 rel_tol: f64,
606 tau_rel: f64,
607) -> Result<AdaptiveTraceEvidence, String>
608where
609 F: FnMut(&[f64], &mut [f64]),
610{
611 const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
612
613 let d = derivatives.len();
614 if d == 0 {
615 return Err("evidence_traces_adaptive_hvp: derivatives is empty".to_string());
616 }
617 if p == 0 {
618 return Err("evidence_traces_adaptive_hvp: p must be > 0".to_string());
619 }
620 if !(rel_tol > 0.0) {
621 return Err(format!(
622 "evidence_traces_adaptive_hvp: rel_tol must be > 0 (got {rel_tol})"
623 ));
624 }
625 if !(tau_rel > 0.0) {
626 return Err(format!(
627 "evidence_traces_adaptive_hvp: tau_rel must be > 0 (got {tau_rel})"
628 ));
629 }
630
631 let mut last_traces = Array1::<f64>::zeros(d);
632 let mut last_stderrs = Array1::<f64>::zeros(d);
633 let mut last_k = 0_usize;
634 let mut converged = false;
635
636 let mut z = vec![0.0_f64; p];
637 let mut w = vec![0.0_f64; p];
638
639 let mut q_means = vec![0.0_f64; d];
646 let mut q_m2 = vec![0.0_f64; d];
647
648 for &k_target in &SCHEDULE {
649 for s in q_means.iter_mut() {
653 *s = 0.0;
654 }
655 for s in q_m2.iter_mut() {
656 *s = 0.0;
657 }
658
659 for k_idx in 0..k_target {
660 for i in 0..p {
662 z[i] = rademacher_entry(seed.0, k_idx as u64, i as u64);
663 }
664 cg_solve(&mut hvp, &z, &mut w, PCG_HVP_REL_TOL, PCG_HVP_MAX_ITERS);
666
667 for j in 0..d {
670 let q = match &derivatives[j] {
671 DerivativeHessian::Dense(matrix) => {
672 let mut y = 0.0_f64;
673 for r in 0..p {
674 let mut hr_w = 0.0_f64;
675 for c in 0..p {
676 hr_w += matrix[[r, c]] * w[c];
677 }
678 y += z[r] * hr_w;
679 }
680 y
681 }
682 DerivativeHessian::WeightedGram {
683 row_weights,
684 penalty_extra,
685 } => {
686 let design_view = design.as_ref().ok_or_else(|| {
687 "evidence_traces_adaptive_hvp: WeightedGram derivative requires \
688 design matrix"
689 .to_string()
690 })?;
691 let n = design_view.nrows();
692 let mut acc = 0.0_f64;
693 for row in 0..n {
694 let mut rz = 0.0_f64;
695 let mut rw = 0.0_f64;
696 for ci in 0..p {
697 rz += design_view[[row, ci]] * z[ci];
698 rw += design_view[[row, ci]] * w[ci];
699 }
700 acc += row_weights[row] * rz * rw;
701 }
702 if let Some(pen) = penalty_extra {
703 for r in 0..p {
704 let mut row_acc = 0.0_f64;
705 for c in 0..p {
706 row_acc += pen[[r, c]] * w[c];
707 }
708 acc += z[r] * row_acc;
709 }
710 }
711 acc
712 }
713 };
714 let count = (k_idx + 1) as f64;
716 let delta = q - q_means[j];
717 q_means[j] += delta / count;
718 let delta2 = q - q_means[j];
719 q_m2[j] += delta * delta2;
720 }
721 }
722
723 let n = k_target as f64;
724 let mut worst_ratio = 0.0_f64;
725 for j in 0..d {
726 let mean = q_means[j];
727 let var = if n > 1.0 { q_m2[j] / (n - 1.0) } else { 0.0 };
731 let s = var.sqrt();
732 last_traces[j] = mean;
733 last_stderrs[j] = s;
734 let denom = n.sqrt() * mean.abs().max(tau_rel);
735 let r = s / denom;
736 if r > worst_ratio {
737 worst_ratio = r;
738 }
739 }
740 last_k = k_target;
741 if worst_ratio <= rel_tol {
742 converged = true;
743 break;
744 }
745 }
746
747 Ok(AdaptiveTraceEvidence {
748 logdet_hessian: f64::NAN,
749 traces: last_traces,
750 stderrs: last_stderrs,
751 probe_count: last_k,
752 converged,
753 })
754}
755
756fn cg_solve<F>(hvp: &mut F, b: &[f64], w: &mut [f64], rel_tol: f64, max_iters: usize)
777where
778 F: FnMut(&[f64], &mut [f64]),
779{
780 let n = b.len();
781 assert!(w.len() == n);
782
783 let rhs = ArrayView1::from(b);
784 let precond = Array1::<f64>::ones(n);
785 let mut solution = ArrayViewMut1::from(w);
786
787 pcg_core(
788 |v: &Array1<f64>, out: &mut Array1<f64>| {
789 let v_slice = v.as_slice().expect("contiguous CG direction view");
791 let out_slice = out.as_slice_mut().expect("contiguous CG matvec view");
792 hvp(v_slice, out_slice);
793 },
794 &rhs,
795 &precond.view(),
796 rel_tol,
797 max_iters,
798 0,
799 false,
800 DotReduction::Reordered,
801 &mut solution,
802 );
803}
804
805#[must_use]
826pub fn should_bypass_cpu_with_gpu_adaptive(
827 p: usize,
828 dense_spd_h_resident: bool,
829 plain_spd_logdet: bool,
830 prefers_stochastic: bool,
831 projected_penalty_subspace_active: bool,
832) -> bool {
833 p >= HUTCHINSON_GPU_MIN_P
834 && dense_spd_h_resident
835 && plain_spd_logdet
836 && prefers_stochastic
837 && !projected_penalty_subspace_active
838}
839
840#[cfg(target_os = "linux")]
845mod linux_cuda {
846 use super::{
847 DerivativeHessian, ProbeSeed, RemlTraceHutchinsonEvidence, RemlTraceHutchinsonInput,
848 reduce_mean_stderr,
849 };
850 use gam_gpu::driver::to_col_major;
851 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
852 use gam_gpu::solver::{
853 cholesky_logdet_from_col_major, context_and_stream, pinned_htod, potrf_in_place,
854 potrs_in_place,
855 };
856 use cudarc::cublas::sys::cublasOperation_t;
857 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
858 use cudarc::cusolver::DnHandle;
859 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
860 use std::sync::Arc;
861
862 pub(super) const PTX_SOURCE: &str = r#"
877extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
878 z += 0x9E3779B97F4A7C15ULL;
879 unsigned long long x = z;
880 x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
881 x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
882 return x ^ (x >> 31);
883}
884
885extern "C" __global__ void fill_rademacher_splitmix(
886 unsigned long long seed,
887 unsigned int p,
888 unsigned int K,
889 double* __restrict__ Z)
890{
891 unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
892 unsigned int k = blockIdx.y;
893 if (i >= p || k >= K) return;
894 const unsigned long long ZETA = 0xD1B54A32D192ED03ULL;
895 const unsigned long long GAMMA = 0x8CB92BA72F9DE81FULL;
896 unsigned long long composite =
897 seed
898 ^ (((unsigned long long)k) * ZETA)
899 ^ (((unsigned long long)i) * GAMMA);
900 unsigned long long h = splitmix64_mix(composite);
901 double v = (h >> 63) == 0 ? 1.0 : -1.0;
902 Z[(size_t)k * (size_t)p + (size_t)i] = v;
903}
904
905extern "C" __device__ double block_reduce_sum(double v) {
906 __shared__ double smem[32];
907 int lane = threadIdx.x & 31;
908 int wid = threadIdx.x >> 5;
909 for (int off = 16; off > 0; off >>= 1) {
910 v += __shfl_down_sync(0xffffffff, v, off);
911 }
912 if (lane == 0) smem[wid] = v;
913 __syncthreads();
914 double total = 0.0;
915 int n_warps = (blockDim.x + 31) >> 5;
916 if (threadIdx.x < (unsigned)n_warps) total = smem[threadIdx.x];
917 if (wid == 0) {
918 for (int off = 16; off > 0; off >>= 1) {
919 total += __shfl_down_sync(0xffffffff, total, off);
920 }
921 }
922 return total;
923}
924
925extern "C" __global__ void reduce_q_dense(
926 unsigned int p,
927 unsigned int K,
928 unsigned int D,
929 const double* __restrict__ Z,
930 const double* __restrict__ Y_stack,
931 double* __restrict__ Q)
932{
933 unsigned int k = blockIdx.x;
934 unsigned int j = blockIdx.y;
935 if (k >= K || j >= D) return;
936 const double* z_col = Z + (size_t)k * (size_t)p;
937 const double* y_col = Y_stack + ((size_t)j * (size_t)K + (size_t)k) * (size_t)p;
938 double partial = 0.0;
939 for (unsigned int i = threadIdx.x; i < p; i += blockDim.x) {
940 partial += z_col[i] * y_col[i];
941 }
942 double total = block_reduce_sum(partial);
943 if (threadIdx.x == 0) {
944 Q[(size_t)j * (size_t)K + (size_t)k] = total;
945 }
946}
947
948extern "C" __global__ void reduce_q_weighted_gram(
949 unsigned int n,
950 unsigned int K,
951 unsigned int D,
952 const double* __restrict__ RZ,
953 const double* __restrict__ RW,
954 const double* __restrict__ A_stack,
955 double* __restrict__ Q)
956{
957 unsigned int k = blockIdx.x;
958 unsigned int j = blockIdx.y;
959 if (k >= K || j >= D) return;
960 const double* rz_col = RZ + (size_t)k * (size_t)n;
961 const double* rw_col = RW + (size_t)k * (size_t)n;
962 const double* a_col = A_stack + (size_t)j * (size_t)n;
963 double partial = 0.0;
964 for (unsigned int i = threadIdx.x; i < n; i += blockDim.x) {
965 partial += a_col[i] * rz_col[i] * rw_col[i];
966 }
967 double total = block_reduce_sum(partial);
968 if (threadIdx.x == 0) {
969 Q[(size_t)j * (size_t)K + (size_t)k] = total;
970 }
971}
972"#;
973
974 const THREADS_PER_BLOCK: u32 = 256;
975
976 fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
977 static CACHE: gam_gpu::device_cache::PtxModuleCache =
978 gam_gpu::device_cache::PtxModuleCache::new();
979 CACHE.get_or_compile(ctx, "reml_trace", PTX_SOURCE)
980 }
981
982 pub(super) fn evidence_derivatives(
983 input: &RemlTraceHutchinsonInput<'_>,
984 ) -> Result<RemlTraceHutchinsonEvidence, GpuError> {
985 let p = input.penalized_hessian.nrows();
986 let d = input.derivatives.len();
987 let k = input.probe_count;
988 let (ctx, stream) =
989 context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
990 let solver = DnHandle::new(stream.clone()).gpu_ctx("reml_trace cusolver init")?;
991 let blas = CudaBlas::new(stream.clone()).gpu_ctx("reml_trace cublas init")?;
992 let compiled = module(&ctx)?;
993 let module_handle: &Arc<CudaModule> = compiled;
994
995 let h_col = to_col_major(&input.penalized_hessian);
997 let mut h_dev =
998 pinned_htod(&stream, &h_col).map_err(|reason| GpuError::DriverCallFailed { reason })?;
999 potrf_in_place(&solver, &stream, p, &mut h_dev)
1000 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1001 let factor_col = stream
1002 .clone_dtoh(&h_dev)
1003 .gpu_ctx("reml_trace download factor")?;
1004 let logdet_hessian = cholesky_logdet_from_col_major(&factor_col, p);
1005
1006 let total_z = p
1008 .checked_mul(k)
1009 .ok_or_else(|| gam_gpu::gpu_err!("reml_trace Z size overflow: p={p}, K={k}"))?;
1010 let mut z_dev = stream
1011 .alloc_zeros::<f64>(total_z)
1012 .gpu_ctx("reml_trace alloc Z")?;
1013 launch_fill_rademacher(&stream, module_handle, input.seed, p, k, &mut z_dev)?;
1014
1015 let mut w_dev = stream
1018 .alloc_zeros::<f64>(total_z)
1019 .gpu_ctx("reml_trace alloc W")?;
1020 copy_device_slice(&stream, &z_dev, &mut w_dev)?;
1021 potrs_in_place(&solver, &stream, p, k, &h_dev, &mut w_dev)
1022 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1023
1024 let mut dense_indices: Vec<usize> = Vec::new();
1026 let mut gram_indices: Vec<usize> = Vec::new();
1027 for (j, deriv) in input.derivatives.iter().enumerate() {
1028 match deriv {
1029 DerivativeHessian::Dense(_) => dense_indices.push(j),
1030 DerivativeHessian::WeightedGram { .. } => gram_indices.push(j),
1031 }
1032 }
1033
1034 let mut q_host = vec![0.0_f64; d * k];
1035
1036 if !dense_indices.is_empty() {
1041 for &j in &dense_indices {
1042 let DerivativeHessian::Dense(matrix) = &input.derivatives[j] else {
1043 panic!(
1050 "reml_trace dense path: derivative index {j} is in dense_indices but \
1051 input.derivatives[{j}] is not DerivativeHessian::Dense — \
1052 dense_indices partition invariant violated"
1053 );
1054 };
1055 let hj_col = to_col_major(matrix);
1056 let hj_dev = pinned_htod(&stream, &hj_col)
1057 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1058 let mut y_dev = stream
1059 .alloc_zeros::<f64>(total_z)
1060 .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Y_j (j={j}): {err}"))?;
1061 gemm_nn(
1062 &blas,
1063 GemmShape {
1064 m: p,
1065 n: k,
1066 k_inner: p,
1067 lda: p,
1068 ldb: p,
1069 ldc: p,
1070 },
1071 &hj_dev,
1072 &w_dev,
1073 &mut y_dev,
1074 )?;
1075 let mut q_j_dev = stream
1076 .alloc_zeros::<f64>(k)
1077 .gpu_ctx_with(|err| format!("reml_trace alloc Q_j (j={j}): {err}"))?;
1078 launch_reduce_q_dense(
1079 &stream,
1080 module_handle,
1081 p,
1082 k,
1083 1,
1084 &z_dev,
1085 &y_dev,
1086 &mut q_j_dev,
1087 )?;
1088 let q_host_j = stream
1089 .clone_dtoh(&q_j_dev)
1090 .gpu_ctx_with(|err| format!("reml_trace download Q_j (j={j}): {err}"))?;
1091 q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_j);
1092 }
1093 }
1094
1095 if !gram_indices.is_empty() {
1098 let design = input
1099 .design
1100 .as_ref()
1101 .ok_or_else(|| GpuError::DriverCallFailed {
1102 reason: "reml_trace: structural derivative present but design=None".to_string(),
1103 })?;
1104 let n = design.nrows();
1105 let design_col = to_col_major(design);
1106 let x_dev = pinned_htod(&stream, &design_col)
1107 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1108 let mut rz_dev = stream
1109 .alloc_zeros::<f64>(
1110 n.checked_mul(k)
1111 .ok_or_else(|| gam_gpu::gpu_err!("reml_trace RZ overflow: n={n}, K={k}"))?,
1112 )
1113 .gpu_ctx("reml_trace alloc RZ")?;
1114 let mut rw_dev = stream
1115 .alloc_zeros::<f64>(n * k)
1116 .gpu_ctx("reml_trace alloc RW")?;
1117 gemm_nn(
1119 &blas,
1120 GemmShape {
1121 m: n,
1122 n: k,
1123 k_inner: p,
1124 lda: n,
1125 ldb: p,
1126 ldc: n,
1127 },
1128 &x_dev,
1129 &z_dev,
1130 &mut rz_dev,
1131 )?;
1132 gemm_nn(
1134 &blas,
1135 GemmShape {
1136 m: n,
1137 n: k,
1138 k_inner: p,
1139 lda: n,
1140 ldb: p,
1141 ldc: n,
1142 },
1143 &x_dev,
1144 &w_dev,
1145 &mut rw_dev,
1146 )?;
1147
1148 let d_gram = gram_indices.len();
1150 let mut a_stack = Vec::<f64>::with_capacity(n * d_gram);
1151 for &j in &gram_indices {
1152 let DerivativeHessian::WeightedGram { row_weights, .. } = &input.derivatives[j]
1153 else {
1154 panic!(
1160 "reml_trace structural path: derivative index {j} is in gram_indices \
1161 but input.derivatives[{j}] is not DerivativeHessian::WeightedGram — \
1162 gram_indices partition invariant violated"
1163 );
1164 };
1165 let slice = row_weights.as_slice().ok_or_else(|| {
1166 gam_gpu::gpu_err!("reml_trace structural H_j={j} row_weights not contiguous")
1167 })?;
1168 a_stack.extend_from_slice(slice);
1169 }
1170 let a_dev = pinned_htod(&stream, &a_stack)
1171 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1172 let mut q_dev = stream
1173 .alloc_zeros::<f64>(d_gram * k)
1174 .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Q_gram: {err}"))?;
1175 launch_reduce_q_weighted_gram(
1176 &stream,
1177 module_handle,
1178 n,
1179 k,
1180 d_gram,
1181 &rz_dev,
1182 &rw_dev,
1183 &a_dev,
1184 &mut q_dev,
1185 )?;
1186 let q_host_gram = stream
1187 .clone_dtoh(&q_dev)
1188 .gpu_ctx("reml_trace download Q_gram")?;
1189 for (slot, &j) in gram_indices.iter().enumerate() {
1190 q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_gram[slot * k..(slot + 1) * k]);
1191 }
1192 for &j in &gram_indices {
1196 let DerivativeHessian::WeightedGram { penalty_extra, .. } = &input.derivatives[j]
1197 else {
1198 panic!(
1205 "reml_trace structural penalty_extra: derivative index {j} is in \
1206 gram_indices but input.derivatives[{j}] is not \
1207 DerivativeHessian::WeightedGram — gram_indices partition invariant \
1208 violated"
1209 );
1210 };
1211 if let Some(pen) = penalty_extra {
1212 let z_host = stream
1213 .clone_dtoh(&z_dev)
1214 .gpu_ctx("reml_trace download Z for penalty_extra")?;
1215 let w_host = stream
1216 .clone_dtoh(&w_dev)
1217 .gpu_ctx("reml_trace download W for penalty_extra")?;
1218 for col in 0..k {
1219 let z_col = &z_host[col * p..(col + 1) * p];
1220 let w_col = &w_host[col * p..(col + 1) * p];
1221 let mut acc = 0.0_f64;
1222 for r in 0..p {
1223 let mut row_acc = 0.0_f64;
1224 for c in 0..p {
1225 row_acc += pen[[r, c]] * w_col[c];
1226 }
1227 acc += z_col[r] * row_acc;
1228 }
1229 q_host[j * k + col] += acc;
1230 }
1231 }
1232 }
1233 }
1234
1235 let (means, stderrs) = reduce_mean_stderr(&q_host, d, k);
1236 let mut gradient_rho_logdet = ndarray::Array1::<f64>::zeros(d);
1237 let mut gradient_rho_stderr = ndarray::Array1::<f64>::zeros(d);
1238 for j in 0..d {
1239 gradient_rho_logdet[j] = 0.5 * means[j];
1240 gradient_rho_stderr[j] = 0.5 * stderrs[j];
1241 }
1242
1243 Ok(RemlTraceHutchinsonEvidence {
1244 logdet_hessian,
1245 gradient_rho_logdet,
1246 gradient_rho_stderr,
1247 probe_count: k,
1248 })
1249 }
1250
1251 fn launch_fill_rademacher(
1254 stream: &Arc<CudaStream>,
1255 module: &Arc<CudaModule>,
1256 seed: ProbeSeed,
1257 p: usize,
1258 k: usize,
1259 z: &mut cudarc::driver::CudaSlice<f64>,
1260 ) -> Result<(), GpuError> {
1261 let func = module
1262 .load_function("fill_rademacher_splitmix")
1263 .gpu_ctx("reml_trace load fill_rademacher")?;
1264 let grid_x = ((p as u32) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1265 let cfg = LaunchConfig {
1266 grid_dim: (grid_x, k as u32, 1),
1267 block_dim: (THREADS_PER_BLOCK, 1, 1),
1268 shared_mem_bytes: 0,
1269 };
1270 let seed_arg: u64 = seed.0;
1271 let p_arg: u32 = p as u32;
1272 let k_arg: u32 = k as u32;
1273 unsafe {
1276 stream
1277 .launch_builder(&func)
1278 .arg(&seed_arg)
1279 .arg(&p_arg)
1280 .arg(&k_arg)
1281 .arg(z)
1282 .launch(cfg)
1283 }
1284 .map(|_| ())
1285 .gpu_ctx("reml_trace launch fill_rademacher")
1286 }
1287
1288 fn launch_reduce_q_dense(
1289 stream: &Arc<CudaStream>,
1290 module: &Arc<CudaModule>,
1291 p: usize,
1292 k: usize,
1293 d: usize,
1294 z: &cudarc::driver::CudaSlice<f64>,
1295 y_stack: &cudarc::driver::CudaSlice<f64>,
1296 q: &mut cudarc::driver::CudaSlice<f64>,
1297 ) -> Result<(), GpuError> {
1298 let func = module
1299 .load_function("reduce_q_dense")
1300 .gpu_ctx("reml_trace load reduce_q_dense")?;
1301 let cfg = LaunchConfig {
1302 grid_dim: (k as u32, d as u32, 1),
1303 block_dim: (THREADS_PER_BLOCK, 1, 1),
1304 shared_mem_bytes: 0,
1305 };
1306 let p_arg: u32 = p as u32;
1307 let k_arg: u32 = k as u32;
1308 let d_arg: u32 = d as u32;
1309 unsafe {
1312 stream
1313 .launch_builder(&func)
1314 .arg(&p_arg)
1315 .arg(&k_arg)
1316 .arg(&d_arg)
1317 .arg(z)
1318 .arg(y_stack)
1319 .arg(q)
1320 .launch(cfg)
1321 }
1322 .map(|_| ())
1323 .gpu_ctx("reml_trace launch reduce_q_dense")
1324 }
1325
1326 fn launch_reduce_q_weighted_gram(
1327 stream: &Arc<CudaStream>,
1328 module: &Arc<CudaModule>,
1329 n: usize,
1330 k: usize,
1331 d: usize,
1332 rz: &cudarc::driver::CudaSlice<f64>,
1333 rw: &cudarc::driver::CudaSlice<f64>,
1334 a_stack: &cudarc::driver::CudaSlice<f64>,
1335 q: &mut cudarc::driver::CudaSlice<f64>,
1336 ) -> Result<(), GpuError> {
1337 let func = module
1338 .load_function("reduce_q_weighted_gram")
1339 .gpu_ctx("reml_trace load reduce_q_weighted_gram")?;
1340 let cfg = LaunchConfig {
1341 grid_dim: (k as u32, d as u32, 1),
1342 block_dim: (THREADS_PER_BLOCK, 1, 1),
1343 shared_mem_bytes: 0,
1344 };
1345 let n_arg: u32 = n as u32;
1346 let k_arg: u32 = k as u32;
1347 let d_arg: u32 = d as u32;
1348 unsafe {
1350 stream
1351 .launch_builder(&func)
1352 .arg(&n_arg)
1353 .arg(&k_arg)
1354 .arg(&d_arg)
1355 .arg(rz)
1356 .arg(rw)
1357 .arg(a_stack)
1358 .arg(q)
1359 .launch(cfg)
1360 }
1361 .map(|_| ())
1362 .gpu_ctx("reml_trace launch reduce_q_weighted_gram")
1363 }
1364
1365 fn copy_device_slice(
1366 stream: &Arc<CudaStream>,
1367 src: &cudarc::driver::CudaSlice<f64>,
1368 dst: &mut cudarc::driver::CudaSlice<f64>,
1369 ) -> Result<(), GpuError> {
1370 stream.memcpy_dtod(src, dst).gpu_ctx("reml_trace dtod copy")
1371 }
1372
1373 struct GemmShape {
1374 m: usize,
1375 n: usize,
1376 k_inner: usize,
1377 lda: usize,
1378 ldb: usize,
1379 ldc: usize,
1380 }
1381
1382 fn gemm_nn(
1383 blas: &CudaBlas,
1384 shape: GemmShape,
1385 a: &cudarc::driver::CudaSlice<f64>,
1386 b: &cudarc::driver::CudaSlice<f64>,
1387 c: &mut cudarc::driver::CudaSlice<f64>,
1388 ) -> Result<(), GpuError> {
1389 let GemmShape {
1390 m,
1391 n,
1392 k_inner,
1393 lda,
1394 ldb,
1395 ldc,
1396 } = shape;
1397 let cfg = GemmConfig::<f64> {
1398 transa: cublasOperation_t::CUBLAS_OP_N,
1399 transb: cublasOperation_t::CUBLAS_OP_N,
1400 m: m as i32,
1401 n: n as i32,
1402 k: k_inner as i32,
1403 alpha: 1.0,
1404 lda: lda as i32,
1405 ldb: ldb as i32,
1406 beta: 0.0,
1407 ldc: ldc as i32,
1408 };
1409 unsafe { blas.gemm(cfg, a, b, c) }.gpu_ctx("reml_trace cublas dgemm")
1412 }
1413}
1414
1415fn validate_inputs(input: &RemlTraceHutchinsonInput<'_>) -> Result<(), String> {
1420 let (p, p2) = input.penalized_hessian.dim();
1421 if p == 0 || p != p2 {
1422 return Err(format!("reml_trace input H must be square, got {p}x{p2}"));
1423 }
1424 if input.probe_count < 2 {
1425 return Err(format!(
1426 "reml_trace requires probe_count >= 2 for a sample SE, got {}",
1427 input.probe_count
1428 ));
1429 }
1430 let needs_design = input
1431 .derivatives
1432 .iter()
1433 .any(|d| matches!(d, DerivativeHessian::WeightedGram { .. }));
1434 if needs_design && input.design.is_none() {
1435 return Err("reml_trace: structural derivative present but design=None".to_string());
1436 }
1437 let n = input.design.as_ref().map(|x| x.nrows()).unwrap_or(0);
1438 if let Some(x) = input.design.as_ref()
1439 && x.ncols() != p
1440 {
1441 return Err(format!(
1442 "reml_trace design has {} columns, expected p={p}",
1443 x.ncols()
1444 ));
1445 }
1446 for (j, derivative) in input.derivatives.iter().enumerate() {
1447 derivative
1448 .dim_p(p, n)
1449 .map_err(String::from)
1450 .map_err(|e| format!("reml_trace derivative {j}: {e}"))?;
1451 }
1452 Ok(())
1453}
1454
/// Per-row mean and sample spread of a row-major `D × K` buffer of probe values.
///
/// Returns `(means, spreads)`. Here `spreads[j] = sqrt(Σ (q - mean)² / (K - 1))` is the
/// sample standard deviation of row `j`, not the standard error of the mean (divide by
/// `sqrt(K)` for that). The spread is 0 when `K < 2`.
///
/// # Panics
/// Panics when `q.len() != d * k`.
fn reduce_mean_stderr(q: &[f64], d: usize, k: usize) -> (Vec<f64>, Vec<f64>) {
    assert_eq!(
        q.len(),
        d * k,
        "reduce_mean_stderr: q buffer length {} != D*K = {}*{}",
        q.len(),
        d,
        k
    );
    let inv_k = 1.0 / (k as f64);
    (0..d)
        .map(|j| {
            let row = &q[j * k..(j + 1) * k];
            let mean = row.iter().copied().sum::<f64>() * inv_k;
            let spread = if k >= 2 {
                let sum_sq = row.iter().map(|v| (v - mean).powi(2)).sum::<f64>();
                (sum_sq / ((k - 1) as f64)).sqrt()
            } else {
                0.0
            };
            (mean, spread)
        })
        .unzip()
}
1487
1488fn cholesky_lower(matrix: &Array2<f64>) -> Result<Array2<f64>, String> {
1491 let n = matrix.nrows();
1492 let mut l = Array2::<f64>::zeros((n, n));
1493 for i in 0..n {
1494 for j in 0..=i {
1495 let mut sum = matrix[[i, j]];
1496 for k in 0..j {
1497 sum -= l[[i, k]] * l[[j, k]];
1498 }
1499 if i == j {
1500 if sum <= 0.0 {
1501 return Err(format!(
1502 "reml_trace CPU Cholesky: non-SPD diagonal {sum} at row {i}"
1503 ));
1504 }
1505 l[[i, j]] = sum.sqrt();
1506 } else {
1507 l[[i, j]] = sum / l[[j, j]];
1508 }
1509 }
1510 }
1511 Ok(l)
1512}
1513
1514fn solve_cholesky(l: &Array2<f64>, rhs: &[f64]) -> Vec<f64> {
1515 let n = l.nrows();
1516 let mut y = vec![0.0_f64; n];
1517 for i in 0..n {
1518 let mut sum = rhs[i];
1519 for k in 0..i {
1520 sum -= l[[i, k]] * y[k];
1521 }
1522 y[i] = sum / l[[i, i]];
1523 }
1524 let mut x = vec![0.0_f64; n];
1525 for i in (0..n).rev() {
1526 let mut sum = y[i];
1527 for k in (i + 1)..n {
1528 sum -= l[[k, i]] * x[k];
1529 }
1530 x[i] = sum / l[[i, i]];
1531 }
1532 x
1533}
1534
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, ArrayView2};

    /// Symmetric, strictly diagonally dominant test Hessian.
    ///
    /// The diagonal is `p + jitter` and the off-diagonal entries are
    /// `1 / (1 + |i - j|)`. Every off-diagonal row sum is at most
    /// `2·(H_p − 1) < p`, so the matrix is SPD for any `jitter >= 0`.
    fn make_spd(p: usize, jitter: f64) -> Array2<f64> {
        Array2::from_shape_fn((p, p), |(i, j)| {
            if i == j {
                p as f64 + jitter
            } else {
                1.0 / (1.0 + (i as f64 - j as f64).abs())
            }
        })
    }

    /// Advances a splitmix64 `state` in place and returns a uniform draw in
    /// `[-0.5, 0.5)`, built from the top 53 bits of the mixed output.
    fn next_centered_uniform(state: &mut u64) -> f64 {
        *state = splitmix64_mix(state.wrapping_add(1));
        ((*state >> 11) as f64) / ((1u64 << 53) as f64) - 0.5
    }

    /// Deterministic symmetric `p × p` matrix with entries in `[-0.5, 0.5)`.
    ///
    /// The upper triangle is filled row by row and mirrored below the diagonal.
    fn random_dense_sym(p: usize, seed: u64) -> Array2<f64> {
        let mut a = Array2::<f64>::zeros((p, p));
        let mut state = seed;
        for i in 0..p {
            for j in i..p {
                let v = next_centered_uniform(&mut state);
                a[[i, j]] = v;
                a[[j, i]] = v;
            }
        }
        a
    }

    /// Exact `tr(H⁻¹ A)` computed from a dense Cholesky factor of `h`.
    ///
    /// Column `col` of `H⁻¹` is `solve(e_col)`. Its dot product with row `col`
    /// of `a` is entry `(col, col)` of `A H⁻¹`, whose trace equals `tr(H⁻¹ A)`.
    fn exact_trace_hinv_a(h: ArrayView2<f64>, a: ArrayView2<f64>) -> f64 {
        let p = h.nrows();
        let factor = cholesky_lower(&h.to_owned()).expect("SPD");
        let mut unit = vec![0.0_f64; p];
        let mut trace = 0.0;
        for col in 0..p {
            unit[col] = 1.0;
            let w = solve_cholesky(&factor, &unit);
            unit[col] = 0.0;
            trace += (0..p).map(|i| a[[col, i]] * w[i]).sum::<f64>();
        }
        trace
    }

    /// Central finite difference `[log det(H + εA) − log det(H − εA)] / 2ε`.
    ///
    /// This approximates `d/dε log det(H + εA)` at `ε = 0`, which equals `tr(H⁻¹ A)`.
    fn fd_logdet_derivative(h: &Array2<f64>, a: &Array2<f64>, eps: f64) -> f64 {
        let logdet = |m: &Array2<f64>| -> f64 {
            let l = cholesky_lower(m).expect("SPD");
            2.0 * (0..m.nrows()).map(|i| l[[i, i]].ln()).sum::<f64>()
        };
        let step = a * eps;
        let plus = h + &step;
        let minus = h - &step;
        (logdet(&plus) - logdet(&minus)) / (2.0 * eps)
    }

    /// Dense Hessian-vector product `out = h · v`, used as the HVP-path callback.
    fn dense_matvec(h: &Array2<f64>, v: &[f64], out: &mut [f64]) {
        for r in 0..h.nrows() {
            out[r] = h
                .row(r)
                .iter()
                .zip(v)
                .fold(0.0_f64, |acc, (hv, vv)| acc + hv * vv);
        }
    }

    #[test]
    fn splitmix_is_deterministic_and_disperses() {
        assert_eq!(splitmix64_mix(42), splitmix64_mix(42));
        let mut bits_seen = 0u64;
        for x in 0u64..64 {
            bits_seen |= splitmix64_mix(x);
        }
        assert_eq!(
            bits_seen,
            u64::MAX,
            "splitmix should cover every bit position across 64 inputs"
        );
    }

    #[test]
    fn rademacher_entries_are_pm_one_and_stateless() {
        let seed = ProbeSeed(0xCAFE_BABE);
        for k in 0..16u64 {
            for i in 0..32u64 {
                let v = rademacher_entry(seed.0, k, i);
                assert!(
                    v == 1.0 || v == -1.0,
                    "non-pm1 entry at (k={k}, i={i}): {v}"
                );
                let v2 = rademacher_entry(seed.0, k, i);
                assert_eq!(v, v2, "same (k,i) must hash to same value");
            }
        }
    }

    /// Common random numbers: the first 16 probes must not depend on how many
    /// probes were requested in total.
    #[test]
    fn rademacher_common_random_numbers_match_for_prefix() {
        let p = 50;
        let mut z16 = vec![0.0_f64; p * 16];
        let mut z32 = vec![0.0_f64; p * 32];
        fill_rademacher_host(ProbeSeed(7), p, 16, &mut z16);
        fill_rademacher_host(ProbeSeed(7), p, 32, &mut z32);
        for col in 0..16 {
            for row in 0..p {
                assert_eq!(
                    z16[col * p + row],
                    z32[col * p + row],
                    "CRN broken at (col={col}, row={row})"
                );
            }
        }
    }

    #[test]
    fn cpu_hutchinson_unbiased_against_exact_small_spd() {
        let p = 16;
        let h = make_spd(p, 0.5);
        let a1 = random_dense_sym(p, 0x1234);
        let a2 = random_dense_sym(p, 0x5678);
        let exact1 = exact_trace_hinv_a(h.view(), a1.view());
        let exact2 = exact_trace_hinv_a(h.view(), a2.view());
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![
                DerivativeHessian::Dense(a1.view()),
                DerivativeHessian::Dense(a2.view()),
            ],
            design: None,
            probe_count: 4096,
            seed: ProbeSeed(0xCAFE_BABE),
        };
        let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
        // `gradient_rho_logdet[j]` is compared against ½·tr(H⁻¹ H_j), so it is
        // doubled to get onto the trace scale. `gradient_rho_stderr[j]` is a
        // per-probe sample SD; it is doubled the same way and divided by √K to
        // give the standard error of the mean.
        let est1 = 2.0 * evidence.gradient_rho_logdet[0];
        let est2 = 2.0 * evidence.gradient_rho_logdet[1];
        let sqrt_k = (evidence.probe_count as f64).sqrt();
        let se1 = 2.0 * evidence.gradient_rho_stderr[0] / sqrt_k;
        let se2 = 2.0 * evidence.gradient_rho_stderr[1] / sqrt_k;
        let tol1 = 6.0 * se1.max(1e-8);
        let tol2 = 6.0 * se2.max(1e-8);
        assert!(
            (est1 - exact1).abs() <= tol1,
            "Hutchinson est {est1} too far from exact {exact1} (tol={tol1}, se={})",
            evidence.gradient_rho_stderr[0]
        );
        assert!(
            (est2 - exact2).abs() <= tol2,
            "Hutchinson est {est2} too far from exact {exact2} (tol={tol2}, se={})",
            evidence.gradient_rho_stderr[1]
        );
    }

    /// `WeightedGram { row_weights: a }` together with `design = X` must give the
    /// same estimate as the explicitly materialised `H_j = Xᵀ diag(a) X`, using
    /// identical probes.
    #[test]
    fn structural_path_matches_dense_for_xtwx() {
        let n = 40;
        let p = 8;
        let mut state = 11u64;
        let mut x = Array2::<f64>::zeros((n, p));
        // `iter_mut` visits elements in logical (row-major) order, so the draw
        // sequence is deterministic.
        for v in x.iter_mut() {
            *v = next_centered_uniform(&mut state);
        }
        let a: Vec<f64> = (0..n).map(|i| 0.5 + 0.01 * (i as f64)).collect();
        let a_arr = ndarray::Array1::from(a);
        let mut hj_dense = Array2::<f64>::zeros((p, p));
        for r in 0..p {
            for c in 0..p {
                let mut acc = 0.0;
                for i in 0..n {
                    acc += x[[i, r]] * a_arr[i] * x[[i, c]];
                }
                hj_dense[[r, c]] = acc;
            }
        }
        let mut h = make_spd(p, 1.0);
        for i in 0..p {
            h[[i, i]] += 1.0;
        }
        let input_dense = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::Dense(hj_dense.view())],
            design: None,
            probe_count: 32,
            seed: ProbeSeed(123),
        };
        let input_struct = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::WeightedGram {
                row_weights: a_arr.view(),
                penalty_extra: None,
            }],
            design: Some(x.view()),
            probe_count: 32,
            seed: ProbeSeed(123),
        };
        let e_dense = evidence_derivatives_hutchinson_cpu(&input_dense).expect("ok");
        let e_struct = evidence_derivatives_hutchinson_cpu(&input_struct).expect("ok");
        assert!(
            (e_dense.gradient_rho_logdet[0] - e_struct.gradient_rho_logdet[0]).abs() < 1e-9,
            "dense vs structural mismatch: dense={}, struct={}",
            e_dense.gradient_rho_logdet[0],
            e_struct.gradient_rho_logdet[0]
        );
    }

    /// Cross-checks the exact trace against a finite-difference log-determinant
    /// derivative, then checks the Hutchinson gradient (on the ½·trace scale).
    #[test]
    fn finite_difference_check_against_logdet() {
        let p = 10;
        let h0 = make_spd(p, 0.2);
        let a = random_dense_sym(p, 0xABCD);
        let fd = fd_logdet_derivative(&h0, &a, 1e-4);
        let exact = exact_trace_hinv_a(h0.view(), a.view());
        assert!(
            (fd - exact).abs() / exact.abs().max(1e-12) < 1e-6,
            "FD logdet derivative {fd} != exact trace {exact}"
        );
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h0.view(),
            derivatives: vec![DerivativeHessian::Dense(a.view())],
            design: None,
            probe_count: 4096,
            seed: ProbeSeed(0xAA55),
        };
        let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
        let se = evidence.gradient_rho_stderr[0] / (evidence.probe_count as f64).sqrt();
        let tol = 8.0 * se.max(1e-8);
        assert!(
            (evidence.gradient_rho_logdet[0] - 0.5 * exact).abs() < tol,
            "Hutchinson gradient {} not within 8·SE of 0.5·exact={}",
            evidence.gradient_rho_logdet[0],
            0.5 * exact
        );
    }

    /// p = 64 is below `HUTCHINSON_GPU_MIN_P` (512).
    #[test]
    fn gate_rejects_below_min_p() {
        assert!(!should_use_gpu_hutchinson(64, 16, true, true, true, false));
    }

    #[test]
    fn gate_rejects_k_out_of_range() {
        assert!(!should_use_gpu_hutchinson(2000, 4, true, true, true, false));
        assert!(!should_use_gpu_hutchinson(
            2000, 200, true, true, true, false
        ));
    }

    #[test]
    fn gate_rejects_when_subspace_active() {
        assert!(!should_use_gpu_hutchinson(2000, 16, true, true, true, true));
    }

    #[test]
    fn gate_accepts_canonical_case() {
        assert!(should_use_gpu_hutchinson(2000, 16, true, true, true, false));
    }

    // NOTE(review): despite the `_p512` suffix this runs at p = 64. The name is
    // kept so existing test filters still match.
    #[test]
    fn block_2_6_adaptive_unbiased_against_exact_p512() {
        let p = 64;
        let h = make_spd(p, 0.5);
        let a = random_dense_sym(p, 0xBADC0DE);
        let exact = exact_trace_hinv_a(h.view(), a.view());
        let evidence = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            ProbeSeed(0xA5A5A5),
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("adaptive run ok");
        let est = evidence.traces[0];
        let se = evidence.stderrs[0] / (evidence.probe_count as f64).sqrt();
        let tol = (8.0 * se).max(0.05 * exact.abs());
        assert!(
            (est - exact).abs() <= tol,
            "adaptive est {est} far from exact {exact} (tol={tol}, se={se}, K={})",
            evidence.probe_count
        );
    }

    /// With identical seed and probe count, the CPU estimator and the dispatch
    /// entry point must agree to round-off.
    #[test]
    fn block_2_6_same_probes_cpu_vs_dispatch() {
        let p = 32;
        let h = make_spd(p, 0.3);
        let a = random_dense_sym(p, 0x1357);
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::Dense(a.view())],
            design: None,
            probe_count: 16,
            seed: ProbeSeed(0xBEEF),
        };
        let cpu = evidence_derivatives_hutchinson_cpu(&input).expect("cpu");
        let dispatch = evidence_derivatives_hutchinson_gpu(input).expect("dispatch");
        let diff = (cpu.gradient_rho_logdet[0] - dispatch.gradient_rho_logdet[0]).abs();
        assert!(
            diff < 1e-9,
            "same-probes CPU vs GPU dispatch differ: cpu={}, dispatch={}, diff={diff}",
            cpu.gradient_rho_logdet[0],
            dispatch.gradient_rho_logdet[0]
        );
    }

    #[test]
    fn block_2_6_fd_logdet_matches_adaptive() {
        let p = 24;
        let h = make_spd(p, 0.4);
        let a = random_dense_sym(p, 0x2468);
        let fd = fd_logdet_derivative(&h, &a, 1e-4);
        let evidence = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            ProbeSeed(0x9999),
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("adaptive ok");
        let est = evidence.traces[0];
        let se = evidence.stderrs[0] / (evidence.probe_count as f64).sqrt();
        let tol = (8.0 * se).max(0.05 * fd.abs());
        assert!(
            (est - fd).abs() <= tol,
            "adaptive trace {est} disagrees with FD logdet derivative {fd} (tol={tol})"
        );
    }

    #[test]
    fn block_2_6_k_4096_matches_exact_tightly() {
        let p = 40;
        let h = make_spd(p, 0.6);
        let a = random_dense_sym(p, 0xDEAD);
        let exact = exact_trace_hinv_a(h.view(), a.view());
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::Dense(a.view())],
            design: None,
            probe_count: 4096,
            seed: ProbeSeed(0xC0FFEE),
        };
        let evidence = evidence_derivatives_hutchinson_gpu(input).expect("ok");
        let est = 2.0 * evidence.gradient_rho_logdet[0];
        // Use the probe count actually reported rather than the requested 4096,
        // so the standard error stays correct if the dispatcher adjusts K.
        let se = 2.0 * evidence.gradient_rho_stderr[0] / (evidence.probe_count as f64).sqrt();
        let tol = (6.0 * se).max(1e-3 * exact.abs());
        assert!(
            (est - exact).abs() <= tol,
            "K=4096 Hutchinson {est} not within 6·SE of exact {exact} (tol={tol}, se={se})"
        );
    }

    /// Probe prefixes must be shared across the 16 → 32 → 64 schedule.
    #[test]
    fn block_2_6_crn_prefix_match_across_schedule() {
        let p = 50;
        let seed = ProbeSeed(0x4242_4242);
        let mut z16 = vec![0.0_f64; p * 16];
        let mut z32 = vec![0.0_f64; p * 32];
        let mut z64 = vec![0.0_f64; p * 64];
        fill_rademacher_host(seed, p, 16, &mut z16);
        fill_rademacher_host(seed, p, 32, &mut z32);
        fill_rademacher_host(seed, p, 64, &mut z64);
        for col in 0..16 {
            for row in 0..p {
                assert_eq!(z16[col * p + row], z32[col * p + row]);
                assert_eq!(z16[col * p + row], z64[col * p + row]);
            }
        }
        for col in 0..32 {
            for row in 0..p {
                assert_eq!(z32[col * p + row], z64[col * p + row]);
            }
        }
    }

    /// The matrix-free HVP path and the dense path must both land near the exact trace.
    #[test]
    fn block_2_7_hvp_path_matches_dense_adaptive() {
        let p = 40;
        let h = make_spd(p, 0.7);
        let a = random_dense_sym(p, 0xABBA);
        let seed = ProbeSeed(0x707);

        let dense = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("dense ok");

        let hvp_evidence = evidence_traces_adaptive_hvp(
            p,
            |v: &[f64], out: &mut [f64]| dense_matvec(&h, v, out),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("hvp ok");

        let exact = exact_trace_hinv_a(h.view(), a.view());
        let se_dense = dense.stderrs[0] / (dense.probe_count as f64).sqrt();
        let se_hvp = hvp_evidence.stderrs[0] / (hvp_evidence.probe_count as f64).sqrt();
        let tol_dense = (8.0 * se_dense).max(0.05 * exact.abs());
        let tol_hvp = (8.0 * se_hvp).max(0.05 * exact.abs());
        assert!(
            (dense.traces[0] - exact).abs() <= tol_dense,
            "dense adaptive {} not near exact {} (tol {})",
            dense.traces[0],
            exact,
            tol_dense
        );
        assert!(
            (hvp_evidence.traces[0] - exact).abs() <= tol_hvp,
            "hvp adaptive {} not near exact {} (tol {})",
            hvp_evidence.traces[0],
            exact,
            tol_hvp
        );
        // The HVP path is expected to report NaN for the log-determinant
        // (presumably because it never factorises H).
        assert!(hvp_evidence.logdet_hessian.is_nan());
    }

    /// With a relative tolerance too tight to trigger early stopping, both paths
    /// run the full schedule. Their per-probe SDs must then agree, since both use
    /// the Bessel-corrected (K−1) convention.
    #[test]
    fn block_2_7_hvp_stderr_matches_dense_reduce_mean_stderr() {
        let p = 36;
        let h = make_spd(p, 0.6);
        let a = random_dense_sym(p, 0x5151);
        let seed = ProbeSeed(0xBEEF);
        let force_full_schedule = 1e-12_f64;

        let dense = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            force_full_schedule,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("dense ok");

        let hvp = evidence_traces_adaptive_hvp(
            p,
            |v: &[f64], out: &mut [f64]| dense_matvec(&h, v, out),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            force_full_schedule,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("hvp ok");

        assert_eq!(dense.probe_count, 128);
        assert_eq!(hvp.probe_count, dense.probe_count);

        let sd_dense = dense.stderrs[0];
        let sd_hvp = hvp.stderrs[0];
        assert!(
            sd_dense > 0.0,
            "dense SD should be positive, got {sd_dense}"
        );
        let rel = (sd_hvp - sd_dense).abs() / sd_dense;
        assert!(
            rel <= 1e-3,
            "HVP SD {sd_hvp} disagrees with dense reduce_mean_stderr SD {sd_dense} \
             (rel {rel}); the two paths must share the Bessel-corrected (K−1) convention"
        );
    }

    // NOTE(review): only the converged solution is checked, not the iteration
    // count. Unpreconditioned CG on 8 distinct eigenvalues generally needs up to
    // 8 iterations. Confirm whether `cg_solve` applies a diagonal preconditioner
    // before relying on the "one iteration" in the name.
    #[test]
    fn block_2_7_cg_solves_diagonal_in_one_iteration() {
        let p = 8;
        let diag: Vec<f64> = (0..p).map(|i| 1.0 + i as f64).collect();
        let b: Vec<f64> = (0..p).map(|i| (i as f64) + 0.5).collect();
        let mut w = vec![0.0_f64; p];
        cg_solve(
            &mut |v: &[f64], out: &mut [f64]| {
                for i in 0..p {
                    out[i] = diag[i] * v[i];
                }
            },
            &b,
            &mut w,
            1e-12,
            PCG_HVP_MAX_ITERS,
        );
        for i in 0..p {
            let expected = b[i] / diag[i];
            assert!(
                (w[i] - expected).abs() < 1e-10,
                "diagonal CG: w[{i}]={} expected {expected}",
                w[i]
            );
        }
    }

    /// Accuracy of the adaptive estimator against exact traces at a larger size.
    ///
    /// The speedup target is asserted only when the GPU runtime is present. The
    /// exact reference reuses one Cholesky factor across all `d` derivatives, so
    /// the timing comparison is fair to the exact path.
    #[test]
    fn block_2_8_hill_climb_adaptive_vs_exact_at_p2000_d8() {
        // NOTE(review): this is true for any available GPU runtime on Linux,
        // not only a V100.
        let on_v100 =
            cfg!(target_os = "linux") && gam_gpu::device_runtime::GpuRuntime::global().is_some();
        let (p, d): (usize, usize) = if on_v100 { (2000, 8) } else { (256, 4) };

        let h = make_spd(p, 1.0);
        let derivs_owned: Vec<Array2<f64>> = (0..d)
            .map(|k| random_dense_sym(p, 0x1000 + k as u64))
            .collect();
        let derivs: Vec<DerivativeHessian<'_>> = derivs_owned
            .iter()
            .map(|a| DerivativeHessian::Dense(a.view()))
            .collect();

        let t_exact_start = std::time::Instant::now();
        let factor = cholesky_lower(&h).expect("SPD");
        let mut exact_traces = vec![0.0_f64; d];
        for (j, a) in derivs_owned.iter().enumerate() {
            let mut acc = 0.0_f64;
            for col in 0..p {
                let rhs: Vec<f64> = a.column(col).to_vec();
                let w = solve_cholesky(&factor, &rhs);
                acc += w[col];
            }
            exact_traces[j] = acc;
        }
        let t_exact = t_exact_start.elapsed();

        let t_adaptive_start = std::time::Instant::now();
        let evidence = evidence_traces_adaptive(
            h.view(),
            derivs,
            None,
            ProbeSeed(0xB10C),
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("adaptive ok");
        let t_adaptive = t_adaptive_start.elapsed();

        for j in 0..d {
            let se = evidence.stderrs[j] / (evidence.probe_count as f64).sqrt();
            let tol = (10.0 * se).max(0.05 * exact_traces[j].abs());
            let diff = (evidence.traces[j] - exact_traces[j]).abs();
            assert!(
                diff <= tol,
                "block_2_8: derivative {j} adaptive {} disagrees with exact {} (tol {tol}, diff {diff})",
                evidence.traces[j],
                exact_traces[j]
            );
        }

        let speedup = t_exact.as_secs_f64() / t_adaptive.as_secs_f64().max(1e-9);
        eprintln!(
            "block_2_8 hill-climb [p={p}, d={d}, V100={on_v100}]: \
             exact={:?}, adaptive={:?}, speedup={:.2}× (K={}, converged={})",
            t_exact, t_adaptive, speedup, evidence.probe_count, evidence.converged
        );
        if on_v100 {
            assert!(
                speedup >= 10.0,
                "block_2_8 V100 speedup {speedup:.2}× below the 10× target \
                 (exact {:?}, adaptive {:?})",
                t_exact,
                t_adaptive,
            );
        }
    }
}