1use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
74
75use gam_gpu::gpu_error::GpuError;
76use gam_linalg::pcg::{DotReduction, pcg_core};
77
/// Seed for the stateless Rademacher probe generator.
///
/// The wrapped value is combined with the probe and row indices by
/// [`rademacher_entry`], so equal seeds reproduce identical probe entries.
#[derive(Clone, Copy, Debug)]
pub struct ProbeSeed(pub u64);

impl Default for ProbeSeed {
    /// Returns the fixed default seed `0xCAFE_BABE`.
    fn default() -> Self {
        ProbeSeed(0xCAFE_BABE)
    }
}
93
/// One derivative matrix `H_j` whose quadratic form `z^T H_j w` is sampled by
/// the trace estimators in this file.
#[derive(Clone, Debug)]
pub enum DerivativeHessian<'a> {
    /// Explicit matrix; `dim_p` requires it to be `p x p`.
    Dense(ArrayView2<'a, f64>),
    /// Structured form. The estimators evaluate it against the design rows
    /// `x_i` as `sum_i row_weights[i] * (x_i . z) * (x_i . w)`, plus
    /// `z^T penalty_extra w` when `penalty_extra` is present.
    WeightedGram {
        /// One weight per design row (`dim_p` checks the length against `n`).
        row_weights: ArrayView1<'a, f64>,
        /// Optional additional `p x p` matrix added to the weighted-Gram term.
        penalty_extra: Option<ArrayView2<'a, f64>>,
    },
}
114
115impl DerivativeHessian<'_> {
116 fn dim_p(&self, expected_p: usize, expected_n: usize) -> Result<(), GpuError> {
117 match self {
118 DerivativeHessian::Dense(matrix) => {
119 if matrix.nrows() != expected_p || matrix.ncols() != expected_p {
120 gam_gpu::gpu_bail!(
121 "reml_trace dense H_j: shape {:?} != ({expected_p}, {expected_p})",
122 matrix.dim()
123 );
124 }
125 }
126 DerivativeHessian::WeightedGram {
127 row_weights,
128 penalty_extra,
129 } => {
130 if row_weights.len() != expected_n {
131 gam_gpu::gpu_bail!(
132 "reml_trace structural H_j: row_weights.len()={} != n={expected_n}",
133 row_weights.len()
134 );
135 }
136 if let Some(p_extra) = penalty_extra
137 && (p_extra.nrows() != expected_p || p_extra.ncols() != expected_p)
138 {
139 gam_gpu::gpu_bail!(
140 "reml_trace structural H_j penalty_extra: shape {:?} != ({expected_p}, {expected_p})",
141 p_extra.dim()
142 );
143 }
144 }
145 }
146 Ok(())
147 }
148}
149
/// Inputs for one fixed-probe-count Hutchinson evaluation.
#[derive(Clone, Debug)]
pub struct RemlTraceHutchinsonInput<'a> {
    /// Penalized Hessian `H`; must be square and non-empty, and is
    /// Cholesky-factorized by both the CPU and GPU paths.
    pub penalized_hessian: ArrayView2<'a, f64>,
    /// Derivative matrices `H_j`, one estimate is produced per entry.
    pub derivatives: Vec<DerivativeHessian<'a>>,
    /// Design matrix; required when any derivative is `WeightedGram` and
    /// must then have `p` columns.
    pub design: Option<ArrayView2<'a, f64>>,
    /// Number of Rademacher probes `K`; validation requires at least 2.
    pub probe_count: usize,
    /// Seed that determines the probe entries.
    pub seed: ProbeSeed,
}
165
/// Result of one fixed-probe-count Hutchinson evaluation.
#[derive(Clone, Debug)]
pub struct RemlTraceHutchinsonEvidence {
    /// `log det H`, computed as twice the sum of the log-diagonal of the
    /// Cholesky factor.
    pub logdet_hessian: f64,
    /// Per derivative: half the probe mean of `z^T H_j H^{-1} z`.
    pub gradient_rho_logdet: Array1<f64>,
    /// Per derivative: half the standard error of that probe mean.
    pub gradient_rho_stderr: Array1<f64>,
    /// Number of probes that produced the estimates.
    pub probe_count: usize,
}
182
/// Minimum coefficient dimension `p` for which the GPU estimator is selected.
pub const HUTCHINSON_GPU_MIN_P: usize = 512;
/// Smallest probe count accepted by [`should_use_gpu_hutchinson`].
pub const HUTCHINSON_GPU_MIN_K: usize = 8;
/// Largest probe count accepted by [`should_use_gpu_hutchinson`].
pub const HUTCHINSON_GPU_MAX_K: usize = 128;

/// Decides whether the GPU Hutchinson estimator should be used.
///
/// Returns `true` only when all of the following hold:
/// * `p >= HUTCHINSON_GPU_MIN_P`,
/// * `probe_count` lies in `HUTCHINSON_GPU_MIN_K..=HUTCHINSON_GPU_MAX_K`,
/// * `prefers_stochastic`, `kernel_matches_hinv` and `plain_spd_logdet` are set,
/// * `projected_penalty_subspace_active` is not set.
#[must_use]
pub fn should_use_gpu_hutchinson(
    p: usize,
    probe_count: usize,
    prefers_stochastic: bool,
    kernel_matches_hinv: bool,
    plain_spd_logdet: bool,
    projected_penalty_subspace_active: bool,
) -> bool {
    if p < HUTCHINSON_GPU_MIN_P {
        return false;
    }
    let accepted_probe_counts = HUTCHINSON_GPU_MIN_K..=HUTCHINSON_GPU_MAX_K;
    if !accepted_probe_counts.contains(&probe_count) {
        return false;
    }
    if projected_penalty_subspace_active {
        return false;
    }
    prefers_stochastic && kernel_matches_hinv && plain_spd_logdet
}
216
/// Hashes `z` by delegating to `gam_linalg::utils::splitmix64_hash`.
///
/// NOTE(review): the device kernel in `linux_cuda::PTX_SOURCE` carries its own
/// `splitmix64_mix`; the two presumably have to agree bit-for-bit so host and
/// device probes match — confirm against `splitmix64_hash`.
#[inline]
pub fn splitmix64_mix(z: u64) -> u64 {
    gam_linalg::utils::splitmix64_hash(z)
}
229
230#[inline]
238pub fn rademacher_entry(seed: u64, k: u64, i: u64) -> f64 {
239 const ZETA: u64 = 0xD1B5_4A32_D192_ED03;
240 const GAMMA: u64 = 0x8CB9_2BA7_2F9D_E81F;
241 let composite = seed ^ k.wrapping_mul(ZETA) ^ i.wrapping_mul(GAMMA);
242 let h = splitmix64_mix(composite);
243 if (h >> 63) == 0 { 1.0 } else { -1.0 }
244}
245
246pub fn fill_rademacher_host(seed: ProbeSeed, p: usize, k: usize, out: &mut [f64]) {
249 assert_eq!(
250 out.len(),
251 p * k,
252 "fill_rademacher_host: out buffer length {} != p*K = {}*{}",
253 out.len(),
254 p,
255 k
256 );
257 for col in 0..k {
258 for row in 0..p {
259 out[col * p + row] = rademacher_entry(seed.0, col as u64, row as u64);
260 }
261 }
262}
263
264pub fn evidence_derivatives_hutchinson_cpu(
275 input: &RemlTraceHutchinsonInput<'_>,
276) -> Result<RemlTraceHutchinsonEvidence, String> {
277 validate_inputs(input)?;
278 let p = input.penalized_hessian.nrows();
279 let d = input.derivatives.len();
280 let k = input.probe_count;
281
282 let h = input.penalized_hessian.to_owned();
284 let factor = cholesky_lower(&h)?;
285 let logdet_hessian = 2.0 * (0..p).map(|i| factor[[i, i]].ln()).sum::<f64>();
286
287 let mut z = vec![0.0_f64; p * k];
289 fill_rademacher_host(input.seed, p, k, &mut z);
290
291 use rayon::prelude::*;
299 let mut w = vec![0.0_f64; p * k];
300 w.par_chunks_mut(p)
301 .zip(z.par_chunks(p))
302 .for_each(|(w_col, z_col)| {
303 let solved = solve_cholesky(&factor, z_col);
304 w_col.copy_from_slice(&solved);
305 });
306
307 let mut q = vec![0.0_f64; d * k]; for (j, derivative) in input.derivatives.iter().enumerate() {
315 let q_row = &mut q[j * k..(j + 1) * k];
316 match derivative {
317 DerivativeHessian::Dense(matrix) => {
318 q_row
319 .par_iter_mut()
320 .zip(z.par_chunks(p).zip(w.par_chunks(p)))
321 .for_each(|(q_jk, (z_col, w_col))| {
322 let mut y = vec![0.0_f64; p];
324 for r in 0..p {
325 let mut acc = 0.0_f64;
326 for c in 0..p {
327 acc += matrix[[r, c]] * w_col[c];
328 }
329 y[r] = acc;
330 }
331 let mut zy = 0.0_f64;
332 for i in 0..p {
333 zy += z_col[i] * y[i];
334 }
335 *q_jk = zy;
336 });
337 }
338 DerivativeHessian::WeightedGram {
339 row_weights,
340 penalty_extra,
341 } => {
342 let design = input.design.as_ref().expect("design validated");
343 let n = design.nrows();
344 q_row
345 .par_iter_mut()
346 .zip(z.par_chunks(p).zip(w.par_chunks(p)))
347 .for_each(|(q_jk, (z_col, w_col))| {
348 let mut acc = 0.0_f64;
350 for row in 0..n {
351 let mut rz = 0.0_f64;
352 let mut rw = 0.0_f64;
353 for col_idx in 0..p {
354 rz += design[[row, col_idx]] * z_col[col_idx];
355 rw += design[[row, col_idx]] * w_col[col_idx];
356 }
357 acc += row_weights[row] * rz * rw;
358 }
359 if let Some(pen) = penalty_extra {
360 for r in 0..p {
361 let mut row_acc = 0.0_f64;
362 for c in 0..p {
363 row_acc += pen[[r, c]] * w_col[c];
364 }
365 acc += z_col[r] * row_acc;
366 }
367 }
368 *q_jk = acc;
369 });
370 }
371 }
372 }
373
374 let (means, stderrs) = reduce_mean_stderr(&q, d, k);
375 let mut gradient_rho_logdet = Array1::<f64>::zeros(d);
376 let mut gradient_rho_stderr = Array1::<f64>::zeros(d);
377 for j in 0..d {
378 gradient_rho_logdet[j] = 0.5 * means[j];
379 gradient_rho_stderr[j] = 0.5 * stderrs[j];
380 }
381
382 Ok(RemlTraceHutchinsonEvidence {
383 logdet_hessian,
384 gradient_rho_logdet,
385 gradient_rho_stderr,
386 probe_count: k,
387 })
388}
389
390pub fn evidence_derivatives_hutchinson_gpu(
400 input: RemlTraceHutchinsonInput<'_>,
401) -> Result<RemlTraceHutchinsonEvidence, String> {
402 validate_inputs(&input)?;
403
404 #[cfg(target_os = "linux")]
405 {
406 if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
407 match linux_cuda::evidence_derivatives(&input) {
408 Ok(evidence) => return Ok(evidence),
409 Err(GpuError::NoDeviceKernel { .. }) => {
410 }
413 Err(other) => return Err(String::from(other)),
414 }
415 }
416 }
417
418 evidence_derivatives_hutchinson_cpu(&input)
419}
420
/// Relative tolerance constant (1%).
///
/// NOTE(review): not referenced in the visible code; presumably the intended
/// default for the `rel_tol` argument of the adaptive estimators — confirm
/// against callers.
pub const HUTCHINSON_ADAPTIVE_REL_TOL: f64 = 0.01;
/// Floor constant for the relative-error denominator.
///
/// NOTE(review): not referenced in the visible code; presumably the intended
/// default for the `tau_rel` argument of the adaptive estimators — confirm
/// against callers.
pub const HUTCHINSON_ADAPTIVE_TAU_REL: f64 = 1e-8;
431
/// Result of an adaptive (escalating probe count) trace estimation.
pub struct AdaptiveTraceEvidence {
    /// `log det H` from the last evaluation; the HVP-based estimator sets
    /// this to `f64::NAN` because it never factorizes `H`.
    pub logdet_hessian: f64,
    /// Per-derivative trace estimates from the last schedule step.
    pub traces: Array1<f64>,
    /// Standard errors matching `traces`.
    pub stderrs: Array1<f64>,
    /// Probe count of the last schedule step that ran.
    pub probe_count: usize,
    /// Whether the relative-error criterion was met before the schedule ended.
    pub converged: bool,
}
464
465pub fn evidence_traces_adaptive<'a>(
466 penalized_hessian: ArrayView2<'a, f64>,
467 derivatives: Vec<DerivativeHessian<'a>>,
468 design: Option<ArrayView2<'a, f64>>,
469 seed: ProbeSeed,
470 rel_tol: f64,
471 tau_rel: f64,
472) -> Result<AdaptiveTraceEvidence, String> {
473 const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
475
476 let d = derivatives.len();
477 if d == 0 {
478 return Err("evidence_traces_adaptive: derivatives is empty".to_string());
479 }
480 if !(rel_tol > 0.0) {
481 return Err(format!(
482 "evidence_traces_adaptive: rel_tol must be > 0 (got {rel_tol})"
483 ));
484 }
485 if !(tau_rel > 0.0) {
486 return Err(format!(
487 "evidence_traces_adaptive: tau_rel must be > 0 (got {tau_rel})"
488 ));
489 }
490
491 let mut last_logdet = 0.0_f64;
492 let mut last_traces = Array1::<f64>::zeros(d);
493 let mut last_stderrs = Array1::<f64>::zeros(d);
494 let mut last_k = 0_usize;
495 let mut converged = false;
496
497 for &k in &SCHEDULE {
498 let input = RemlTraceHutchinsonInput {
499 penalized_hessian,
500 derivatives: derivatives.clone(),
501 design,
502 probe_count: k,
503 seed,
504 };
505 let evidence = evidence_derivatives_hutchinson_gpu(input)?;
506 last_logdet = evidence.logdet_hessian;
507 last_k = k;
508
509 for j in 0..d {
513 last_traces[j] = 2.0 * evidence.gradient_rho_logdet[j];
514 last_stderrs[j] = 2.0 * evidence.gradient_rho_stderr[j];
515 }
516
517 let mut worst = 0.0_f64;
522 for j in 0..d {
523 let denom = last_traces[j].abs().max(tau_rel);
524 let r = last_stderrs[j] / denom;
525 if r > worst {
526 worst = r;
527 }
528 }
529 if worst <= rel_tol {
530 converged = true;
531 break;
532 }
533 }
534
535 Ok(AdaptiveTraceEvidence {
536 logdet_hessian: last_logdet,
537 traces: last_traces,
538 stderrs: last_stderrs,
539 probe_count: last_k,
540 converged,
541 })
542}
543
/// Relative tolerance passed to the conjugate-gradient solve in
/// `evidence_traces_adaptive_hvp`.
pub const PCG_HVP_REL_TOL: f64 = 1e-6;

/// Iteration cap passed to the conjugate-gradient solve in
/// `evidence_traces_adaptive_hvp`.
pub const PCG_HVP_MAX_ITERS: usize = 200;
560
561pub fn evidence_traces_adaptive_hvp<F>(
592 p: usize,
593 mut hvp: F,
594 derivatives: Vec<DerivativeHessian<'_>>,
595 design: Option<ArrayView2<'_, f64>>,
596 seed: ProbeSeed,
597 rel_tol: f64,
598 tau_rel: f64,
599) -> Result<AdaptiveTraceEvidence, String>
600where
601 F: FnMut(&[f64], &mut [f64]),
602{
603 const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
604
605 let d = derivatives.len();
606 if d == 0 {
607 return Err("evidence_traces_adaptive_hvp: derivatives is empty".to_string());
608 }
609 if p == 0 {
610 return Err("evidence_traces_adaptive_hvp: p must be > 0".to_string());
611 }
612 if !(rel_tol > 0.0) {
613 return Err(format!(
614 "evidence_traces_adaptive_hvp: rel_tol must be > 0 (got {rel_tol})"
615 ));
616 }
617 if !(tau_rel > 0.0) {
618 return Err(format!(
619 "evidence_traces_adaptive_hvp: tau_rel must be > 0 (got {tau_rel})"
620 ));
621 }
622
623 let mut last_traces = Array1::<f64>::zeros(d);
624 let mut last_stderrs = Array1::<f64>::zeros(d);
625 let mut last_k = 0_usize;
626 let mut converged = false;
627
628 let mut z = vec![0.0_f64; p];
629 let mut w = vec![0.0_f64; p];
630
631 let mut q_means = vec![0.0_f64; d];
638 let mut q_m2 = vec![0.0_f64; d];
639
640 for &k_target in &SCHEDULE {
641 for s in q_means.iter_mut() {
645 *s = 0.0;
646 }
647 for s in q_m2.iter_mut() {
648 *s = 0.0;
649 }
650
651 for k_idx in 0..k_target {
652 for i in 0..p {
654 z[i] = rademacher_entry(seed.0, k_idx as u64, i as u64);
655 }
656 cg_solve(&mut hvp, &z, &mut w, PCG_HVP_REL_TOL, PCG_HVP_MAX_ITERS);
658
659 for j in 0..d {
662 let q = match &derivatives[j] {
663 DerivativeHessian::Dense(matrix) => {
664 let mut y = 0.0_f64;
665 for r in 0..p {
666 let mut hr_w = 0.0_f64;
667 for c in 0..p {
668 hr_w += matrix[[r, c]] * w[c];
669 }
670 y += z[r] * hr_w;
671 }
672 y
673 }
674 DerivativeHessian::WeightedGram {
675 row_weights,
676 penalty_extra,
677 } => {
678 let design_view = design.as_ref().ok_or_else(|| {
679 "evidence_traces_adaptive_hvp: WeightedGram derivative requires \
680 design matrix"
681 .to_string()
682 })?;
683 let n = design_view.nrows();
684 let mut acc = 0.0_f64;
685 for row in 0..n {
686 let mut rz = 0.0_f64;
687 let mut rw = 0.0_f64;
688 for ci in 0..p {
689 rz += design_view[[row, ci]] * z[ci];
690 rw += design_view[[row, ci]] * w[ci];
691 }
692 acc += row_weights[row] * rz * rw;
693 }
694 if let Some(pen) = penalty_extra {
695 for r in 0..p {
696 let mut row_acc = 0.0_f64;
697 for c in 0..p {
698 row_acc += pen[[r, c]] * w[c];
699 }
700 acc += z[r] * row_acc;
701 }
702 }
703 acc
704 }
705 };
706 let count = (k_idx + 1) as f64;
708 let delta = q - q_means[j];
709 q_means[j] += delta / count;
710 let delta2 = q - q_means[j];
711 q_m2[j] += delta * delta2;
712 }
713 }
714
715 let n = k_target as f64;
716 let mut worst_ratio = 0.0_f64;
717 for j in 0..d {
718 let mean = q_means[j];
719 let var = if n > 1.0 { q_m2[j] / (n - 1.0) } else { 0.0 };
723 let se = var.sqrt() / n.sqrt();
724 last_traces[j] = mean;
725 last_stderrs[j] = se;
726 let denom = mean.abs().max(tau_rel);
727 let r = se / denom;
728 if r > worst_ratio {
729 worst_ratio = r;
730 }
731 }
732 last_k = k_target;
733 if worst_ratio <= rel_tol {
734 converged = true;
735 break;
736 }
737 }
738
739 Ok(AdaptiveTraceEvidence {
740 logdet_hessian: f64::NAN,
741 traces: last_traces,
742 stderrs: last_stderrs,
743 probe_count: last_k,
744 converged,
745 })
746}
747
/// Solves `H w = b` with `pcg_core`, using `hvp` as the matrix-vector product.
///
/// `w` is handed to `pcg_core` as the solution buffer and the preconditioner
/// is a vector of ones. The value returned by `pcg_core`, if any, is not
/// inspected here.
///
/// # Panics
///
/// Panics when `w.len() != b.len()`, or when `pcg_core` passes
/// non-contiguous arrays to the matvec closure.
fn cg_solve<F>(hvp: &mut F, b: &[f64], w: &mut [f64], rel_tol: f64, max_iters: usize)
where
    F: FnMut(&[f64], &mut [f64]),
{
    let n = b.len();
    assert!(w.len() == n);

    let rhs = ArrayView1::from(b);
    // All-ones preconditioner vector.
    let precond = Array1::<f64>::ones(n);
    let mut solution = ArrayViewMut1::from(w);

    pcg_core(
        // Adapter: expose the ndarray buffers as slices for the caller's HVP.
        |v: &Array1<f64>, out: &mut Array1<f64>| {
            let v_slice = v.as_slice().expect("contiguous CG direction view");
            let out_slice = out.as_slice_mut().expect("contiguous CG matvec view");
            hvp(v_slice, out_slice);
        },
        &rhs,
        &precond.view(),
        rel_tol,
        max_iters,
        // NOTE(review): the meaning of the positional `0` and `false`
        // arguments is defined by `pcg_core` and not visible here — confirm.
        0,
        false,
        DotReduction::Reordered,
        &mut solution,
    );
}
796
797#[must_use]
818pub fn should_bypass_cpu_with_gpu_adaptive(
819 p: usize,
820 dense_spd_h_resident: bool,
821 plain_spd_logdet: bool,
822 prefers_stochastic: bool,
823 projected_penalty_subspace_active: bool,
824) -> bool {
825 p >= HUTCHINSON_GPU_MIN_P
826 && dense_spd_h_resident
827 && plain_spd_logdet
828 && prefers_stochastic
829 && !projected_penalty_subspace_active
830}
831
832#[cfg(target_os = "linux")]
837mod linux_cuda {
838 use super::{
839 DerivativeHessian, ProbeSeed, RemlTraceHutchinsonEvidence, RemlTraceHutchinsonInput,
840 reduce_mean_stderr,
841 };
842 use cudarc::cublas::sys::cublasOperation_t;
843 use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
844 use cudarc::cusolver::DnHandle;
845 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
846 use gam_gpu::driver::to_col_major;
847 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
848 use gam_gpu::solver::{
849 cholesky_logdet_from_col_major, context_and_stream, pinned_htod, potrf_in_place,
850 potrs_in_place,
851 };
852 use std::sync::Arc;
853
/// CUDA C source for the device kernels used by this module; `module` hands
/// it to the PTX module cache for compilation under the name `reml_trace`.
///
/// Kernels:
/// * `fill_rademacher_splitmix` — writes the `+1/-1` probe entry for
///   `(probe k, row i)` to `Z[k * p + i]`, using the same seed/index mixing
///   constants as the host `rademacher_entry`.
/// * `reduce_q_dense` — one block per `(probe k, derivative j)`; stores the
///   dot product of probe column `k` of `Z` with column `(j, k)` of `Y_stack`
///   in `Q[j * K + k]`.
/// * `reduce_q_weighted_gram` — one block per `(probe k, derivative j)`;
///   stores `sum_i A_stack[j * n + i] * RZ[k * n + i] * RW[k * n + i]` in
///   `Q[j * K + k]`.
///
/// `block_reduce_sum` combines per-thread partial sums with warp shuffles and
/// a 32-slot shared buffer; both reduce kernels read its result on thread 0.
pub(super) const PTX_SOURCE: &str = r#"
extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
    z += 0x9E3779B97F4A7C15ULL;
    unsigned long long x = z;
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
    return x ^ (x >> 31);
}

extern "C" __global__ void fill_rademacher_splitmix(
    unsigned long long seed,
    unsigned int p,
    unsigned int K,
    double* __restrict__ Z)
{
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int k = blockIdx.y;
    if (i >= p || k >= K) return;
    const unsigned long long ZETA = 0xD1B54A32D192ED03ULL;
    const unsigned long long GAMMA = 0x8CB92BA72F9DE81FULL;
    unsigned long long composite =
        seed
        ^ (((unsigned long long)k) * ZETA)
        ^ (((unsigned long long)i) * GAMMA);
    unsigned long long h = splitmix64_mix(composite);
    double v = (h >> 63) == 0 ? 1.0 : -1.0;
    Z[(size_t)k * (size_t)p + (size_t)i] = v;
}

extern "C" __device__ double block_reduce_sum(double v) {
    __shared__ double smem[32];
    int lane = threadIdx.x & 31;
    int wid = threadIdx.x >> 5;
    for (int off = 16; off > 0; off >>= 1) {
        v += __shfl_down_sync(0xffffffff, v, off);
    }
    if (lane == 0) smem[wid] = v;
    __syncthreads();
    double total = 0.0;
    int n_warps = (blockDim.x + 31) >> 5;
    if (threadIdx.x < (unsigned)n_warps) total = smem[threadIdx.x];
    if (wid == 0) {
        for (int off = 16; off > 0; off >>= 1) {
            total += __shfl_down_sync(0xffffffff, total, off);
        }
    }
    return total;
}

extern "C" __global__ void reduce_q_dense(
    unsigned int p,
    unsigned int K,
    unsigned int D,
    const double* __restrict__ Z,
    const double* __restrict__ Y_stack,
    double* __restrict__ Q)
{
    unsigned int k = blockIdx.x;
    unsigned int j = blockIdx.y;
    if (k >= K || j >= D) return;
    const double* z_col = Z + (size_t)k * (size_t)p;
    const double* y_col = Y_stack + ((size_t)j * (size_t)K + (size_t)k) * (size_t)p;
    double partial = 0.0;
    for (unsigned int i = threadIdx.x; i < p; i += blockDim.x) {
        partial += z_col[i] * y_col[i];
    }
    double total = block_reduce_sum(partial);
    if (threadIdx.x == 0) {
        Q[(size_t)j * (size_t)K + (size_t)k] = total;
    }
}

extern "C" __global__ void reduce_q_weighted_gram(
    unsigned int n,
    unsigned int K,
    unsigned int D,
    const double* __restrict__ RZ,
    const double* __restrict__ RW,
    const double* __restrict__ A_stack,
    double* __restrict__ Q)
{
    unsigned int k = blockIdx.x;
    unsigned int j = blockIdx.y;
    if (k >= K || j >= D) return;
    const double* rz_col = RZ + (size_t)k * (size_t)n;
    const double* rw_col = RW + (size_t)k * (size_t)n;
    const double* a_col = A_stack + (size_t)j * (size_t)n;
    double partial = 0.0;
    for (unsigned int i = threadIdx.x; i < n; i += blockDim.x) {
        partial += a_col[i] * rz_col[i] * rw_col[i];
    }
    double total = block_reduce_sum(partial);
    if (threadIdx.x == 0) {
        Q[(size_t)j * (size_t)K + (size_t)k] = total;
    }
}
"#;
965
/// Threads per block used for every kernel launch in this module.
const THREADS_PER_BLOCK: u32 = 256;
967
/// Returns the compiled `reml_trace` module for `ctx`, going through a
/// function-local static `PtxModuleCache` so `PTX_SOURCE` is passed to
/// `get_or_compile` under a fixed name.
fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
    static CACHE: gam_gpu::device_cache::PtxModuleCache =
        gam_gpu::device_cache::PtxModuleCache::new();
    CACHE.get_or_compile(ctx, "reml_trace", PTX_SOURCE)
}
973
974 pub(super) fn evidence_derivatives(
975 input: &RemlTraceHutchinsonInput<'_>,
976 ) -> Result<RemlTraceHutchinsonEvidence, GpuError> {
977 let p = input.penalized_hessian.nrows();
978 let d = input.derivatives.len();
979 let k = input.probe_count;
980 let (ctx, stream) =
981 context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
982 let solver = DnHandle::new(stream.clone()).gpu_ctx("reml_trace cusolver init")?;
983 let blas = CudaBlas::new(stream.clone()).gpu_ctx("reml_trace cublas init")?;
984 let compiled = module(&ctx)?;
985 let module_handle: &Arc<CudaModule> = compiled;
986
987 let h_col = to_col_major(&input.penalized_hessian);
989 let mut h_dev =
990 pinned_htod(&stream, &h_col).map_err(|reason| GpuError::DriverCallFailed { reason })?;
991 potrf_in_place(&solver, &stream, p, &mut h_dev)
992 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
993 let factor_col = stream
994 .clone_dtoh(&h_dev)
995 .gpu_ctx("reml_trace download factor")?;
996 let logdet_hessian = cholesky_logdet_from_col_major(&factor_col, p);
997
998 let total_z = p
1000 .checked_mul(k)
1001 .ok_or_else(|| gam_gpu::gpu_err!("reml_trace Z size overflow: p={p}, K={k}"))?;
1002 let mut z_dev = stream
1003 .alloc_zeros::<f64>(total_z)
1004 .gpu_ctx("reml_trace alloc Z")?;
1005 launch_fill_rademacher(&stream, module_handle, input.seed, p, k, &mut z_dev)?;
1006
1007 let mut w_dev = stream
1010 .alloc_zeros::<f64>(total_z)
1011 .gpu_ctx("reml_trace alloc W")?;
1012 copy_device_slice(&stream, &z_dev, &mut w_dev)?;
1013 potrs_in_place(&solver, &stream, p, k, &h_dev, &mut w_dev)
1014 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1015
1016 let mut dense_indices: Vec<usize> = Vec::new();
1018 let mut gram_indices: Vec<usize> = Vec::new();
1019 for (j, deriv) in input.derivatives.iter().enumerate() {
1020 match deriv {
1021 DerivativeHessian::Dense(_) => dense_indices.push(j),
1022 DerivativeHessian::WeightedGram { .. } => gram_indices.push(j),
1023 }
1024 }
1025
1026 let mut q_host = vec![0.0_f64; d * k];
1027
1028 if !dense_indices.is_empty() {
1033 for &j in &dense_indices {
1034 let DerivativeHessian::Dense(matrix) = &input.derivatives[j] else {
1035 panic!(
1042 "reml_trace dense path: derivative index {j} is in dense_indices but \
1043 input.derivatives[{j}] is not DerivativeHessian::Dense — \
1044 dense_indices partition invariant violated"
1045 );
1046 };
1047 let hj_col = to_col_major(matrix);
1048 let hj_dev = pinned_htod(&stream, &hj_col)
1049 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1050 let mut y_dev = stream
1051 .alloc_zeros::<f64>(total_z)
1052 .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Y_j (j={j}): {err}"))?;
1053 gemm_nn(
1054 &blas,
1055 GemmShape {
1056 m: p,
1057 n: k,
1058 k_inner: p,
1059 lda: p,
1060 ldb: p,
1061 ldc: p,
1062 },
1063 &hj_dev,
1064 &w_dev,
1065 &mut y_dev,
1066 )?;
1067 let mut q_j_dev = stream
1068 .alloc_zeros::<f64>(k)
1069 .gpu_ctx_with(|err| format!("reml_trace alloc Q_j (j={j}): {err}"))?;
1070 launch_reduce_q_dense(
1071 &stream,
1072 module_handle,
1073 p,
1074 k,
1075 1,
1076 &z_dev,
1077 &y_dev,
1078 &mut q_j_dev,
1079 )?;
1080 let q_host_j = stream
1081 .clone_dtoh(&q_j_dev)
1082 .gpu_ctx_with(|err| format!("reml_trace download Q_j (j={j}): {err}"))?;
1083 q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_j);
1084 }
1085 }
1086
1087 if !gram_indices.is_empty() {
1090 let design = input
1091 .design
1092 .as_ref()
1093 .ok_or_else(|| GpuError::DriverCallFailed {
1094 reason: "reml_trace: structural derivative present but design=None".to_string(),
1095 })?;
1096 let n = design.nrows();
1097 let design_col = to_col_major(design);
1098 let x_dev = pinned_htod(&stream, &design_col)
1099 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1100 let mut rz_dev = stream
1101 .alloc_zeros::<f64>(
1102 n.checked_mul(k)
1103 .ok_or_else(|| gam_gpu::gpu_err!("reml_trace RZ overflow: n={n}, K={k}"))?,
1104 )
1105 .gpu_ctx("reml_trace alloc RZ")?;
1106 let mut rw_dev = stream
1107 .alloc_zeros::<f64>(n * k)
1108 .gpu_ctx("reml_trace alloc RW")?;
1109 gemm_nn(
1111 &blas,
1112 GemmShape {
1113 m: n,
1114 n: k,
1115 k_inner: p,
1116 lda: n,
1117 ldb: p,
1118 ldc: n,
1119 },
1120 &x_dev,
1121 &z_dev,
1122 &mut rz_dev,
1123 )?;
1124 gemm_nn(
1126 &blas,
1127 GemmShape {
1128 m: n,
1129 n: k,
1130 k_inner: p,
1131 lda: n,
1132 ldb: p,
1133 ldc: n,
1134 },
1135 &x_dev,
1136 &w_dev,
1137 &mut rw_dev,
1138 )?;
1139
1140 let d_gram = gram_indices.len();
1142 let mut a_stack = Vec::<f64>::with_capacity(n * d_gram);
1143 for &j in &gram_indices {
1144 let DerivativeHessian::WeightedGram { row_weights, .. } = &input.derivatives[j]
1145 else {
1146 panic!(
1152 "reml_trace structural path: derivative index {j} is in gram_indices \
1153 but input.derivatives[{j}] is not DerivativeHessian::WeightedGram — \
1154 gram_indices partition invariant violated"
1155 );
1156 };
1157 let slice = row_weights.as_slice().ok_or_else(|| {
1158 gam_gpu::gpu_err!("reml_trace structural H_j={j} row_weights not contiguous")
1159 })?;
1160 a_stack.extend_from_slice(slice);
1161 }
1162 let a_dev = pinned_htod(&stream, &a_stack)
1163 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1164 let mut q_dev = stream
1165 .alloc_zeros::<f64>(d_gram * k)
1166 .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Q_gram: {err}"))?;
1167 launch_reduce_q_weighted_gram(
1168 &stream,
1169 module_handle,
1170 n,
1171 k,
1172 d_gram,
1173 &rz_dev,
1174 &rw_dev,
1175 &a_dev,
1176 &mut q_dev,
1177 )?;
1178 let q_host_gram = stream
1179 .clone_dtoh(&q_dev)
1180 .gpu_ctx("reml_trace download Q_gram")?;
1181 for (slot, &j) in gram_indices.iter().enumerate() {
1182 q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_gram[slot * k..(slot + 1) * k]);
1183 }
1184 for &j in &gram_indices {
1188 let DerivativeHessian::WeightedGram { penalty_extra, .. } = &input.derivatives[j]
1189 else {
1190 panic!(
1197 "reml_trace structural penalty_extra: derivative index {j} is in \
1198 gram_indices but input.derivatives[{j}] is not \
1199 DerivativeHessian::WeightedGram — gram_indices partition invariant \
1200 violated"
1201 );
1202 };
1203 if let Some(pen) = penalty_extra {
1204 let z_host = stream
1205 .clone_dtoh(&z_dev)
1206 .gpu_ctx("reml_trace download Z for penalty_extra")?;
1207 let w_host = stream
1208 .clone_dtoh(&w_dev)
1209 .gpu_ctx("reml_trace download W for penalty_extra")?;
1210 for col in 0..k {
1211 let z_col = &z_host[col * p..(col + 1) * p];
1212 let w_col = &w_host[col * p..(col + 1) * p];
1213 let mut acc = 0.0_f64;
1214 for r in 0..p {
1215 let mut row_acc = 0.0_f64;
1216 for c in 0..p {
1217 row_acc += pen[[r, c]] * w_col[c];
1218 }
1219 acc += z_col[r] * row_acc;
1220 }
1221 q_host[j * k + col] += acc;
1222 }
1223 }
1224 }
1225 }
1226
1227 let (means, stderrs) = reduce_mean_stderr(&q_host, d, k);
1228 let mut gradient_rho_logdet = ndarray::Array1::<f64>::zeros(d);
1229 let mut gradient_rho_stderr = ndarray::Array1::<f64>::zeros(d);
1230 for j in 0..d {
1231 gradient_rho_logdet[j] = 0.5 * means[j];
1232 gradient_rho_stderr[j] = 0.5 * stderrs[j];
1233 }
1234
1235 Ok(RemlTraceHutchinsonEvidence {
1236 logdet_hessian,
1237 gradient_rho_logdet,
1238 gradient_rho_stderr,
1239 probe_count: k,
1240 })
1241 }
1242
1243 fn launch_fill_rademacher(
1246 stream: &Arc<CudaStream>,
1247 module: &Arc<CudaModule>,
1248 seed: ProbeSeed,
1249 p: usize,
1250 k: usize,
1251 z: &mut cudarc::driver::CudaSlice<f64>,
1252 ) -> Result<(), GpuError> {
1253 let func = module
1254 .load_function("fill_rademacher_splitmix")
1255 .gpu_ctx("reml_trace load fill_rademacher")?;
1256 let grid_x = ((p as u32) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1257 let cfg = LaunchConfig {
1258 grid_dim: (grid_x, k as u32, 1),
1259 block_dim: (THREADS_PER_BLOCK, 1, 1),
1260 shared_mem_bytes: 0,
1261 };
1262 let seed_arg: u64 = seed.0;
1263 let p_arg: u32 = p as u32;
1264 let k_arg: u32 = k as u32;
1265 unsafe {
1268 stream
1269 .launch_builder(&func)
1270 .arg(&seed_arg)
1271 .arg(&p_arg)
1272 .arg(&k_arg)
1273 .arg(z)
1274 .launch(cfg)
1275 }
1276 .map(|_| ())
1277 .gpu_ctx("reml_trace launch fill_rademacher")
1278 }
1279
1280 fn launch_reduce_q_dense(
1281 stream: &Arc<CudaStream>,
1282 module: &Arc<CudaModule>,
1283 p: usize,
1284 k: usize,
1285 d: usize,
1286 z: &cudarc::driver::CudaSlice<f64>,
1287 y_stack: &cudarc::driver::CudaSlice<f64>,
1288 q: &mut cudarc::driver::CudaSlice<f64>,
1289 ) -> Result<(), GpuError> {
1290 let func = module
1291 .load_function("reduce_q_dense")
1292 .gpu_ctx("reml_trace load reduce_q_dense")?;
1293 let cfg = LaunchConfig {
1294 grid_dim: (k as u32, d as u32, 1),
1295 block_dim: (THREADS_PER_BLOCK, 1, 1),
1296 shared_mem_bytes: 0,
1297 };
1298 let p_arg: u32 = p as u32;
1299 let k_arg: u32 = k as u32;
1300 let d_arg: u32 = d as u32;
1301 unsafe {
1304 stream
1305 .launch_builder(&func)
1306 .arg(&p_arg)
1307 .arg(&k_arg)
1308 .arg(&d_arg)
1309 .arg(z)
1310 .arg(y_stack)
1311 .arg(q)
1312 .launch(cfg)
1313 }
1314 .map(|_| ())
1315 .gpu_ctx("reml_trace launch reduce_q_dense")
1316 }
1317
1318 fn launch_reduce_q_weighted_gram(
1319 stream: &Arc<CudaStream>,
1320 module: &Arc<CudaModule>,
1321 n: usize,
1322 k: usize,
1323 d: usize,
1324 rz: &cudarc::driver::CudaSlice<f64>,
1325 rw: &cudarc::driver::CudaSlice<f64>,
1326 a_stack: &cudarc::driver::CudaSlice<f64>,
1327 q: &mut cudarc::driver::CudaSlice<f64>,
1328 ) -> Result<(), GpuError> {
1329 let func = module
1330 .load_function("reduce_q_weighted_gram")
1331 .gpu_ctx("reml_trace load reduce_q_weighted_gram")?;
1332 let cfg = LaunchConfig {
1333 grid_dim: (k as u32, d as u32, 1),
1334 block_dim: (THREADS_PER_BLOCK, 1, 1),
1335 shared_mem_bytes: 0,
1336 };
1337 let n_arg: u32 = n as u32;
1338 let k_arg: u32 = k as u32;
1339 let d_arg: u32 = d as u32;
1340 unsafe {
1342 stream
1343 .launch_builder(&func)
1344 .arg(&n_arg)
1345 .arg(&k_arg)
1346 .arg(&d_arg)
1347 .arg(rz)
1348 .arg(rw)
1349 .arg(a_stack)
1350 .arg(q)
1351 .launch(cfg)
1352 }
1353 .map(|_| ())
1354 .gpu_ctx("reml_trace launch reduce_q_weighted_gram")
1355 }
1356
/// Copies `src` into `dst` with a device-to-device memcpy on `stream`,
/// tagging any failure with the `reml_trace dtod copy` context.
fn copy_device_slice(
    stream: &Arc<CudaStream>,
    src: &cudarc::driver::CudaSlice<f64>,
    dst: &mut cudarc::driver::CudaSlice<f64>,
) -> Result<(), GpuError> {
    stream.memcpy_dtod(src, dst).gpu_ctx("reml_trace dtod copy")
}
1364
/// Dimensions and leading dimensions for one `gemm_nn` call computing
/// `C = A * B` without transposition.
struct GemmShape {
    /// Rows of `A` and `C`.
    m: usize,
    /// Columns of `B` and `C`.
    n: usize,
    /// Shared inner dimension (columns of `A`, rows of `B`).
    k_inner: usize,
    /// Leading dimension of `A`.
    lda: usize,
    /// Leading dimension of `B`.
    ldb: usize,
    /// Leading dimension of `C`.
    ldc: usize,
}
1373
1374 fn gemm_nn(
1375 blas: &CudaBlas,
1376 shape: GemmShape,
1377 a: &cudarc::driver::CudaSlice<f64>,
1378 b: &cudarc::driver::CudaSlice<f64>,
1379 c: &mut cudarc::driver::CudaSlice<f64>,
1380 ) -> Result<(), GpuError> {
1381 let GemmShape {
1382 m,
1383 n,
1384 k_inner,
1385 lda,
1386 ldb,
1387 ldc,
1388 } = shape;
1389 let cfg = GemmConfig::<f64> {
1390 transa: cublasOperation_t::CUBLAS_OP_N,
1391 transb: cublasOperation_t::CUBLAS_OP_N,
1392 m: m as i32,
1393 n: n as i32,
1394 k: k_inner as i32,
1395 alpha: 1.0,
1396 lda: lda as i32,
1397 ldb: ldb as i32,
1398 beta: 0.0,
1399 ldc: ldc as i32,
1400 };
1401 unsafe { blas.gemm(cfg, a, b, c) }.gpu_ctx("reml_trace cublas dgemm")
1404 }
1405}
1406
1407fn validate_inputs(input: &RemlTraceHutchinsonInput<'_>) -> Result<(), String> {
1412 let (p, p2) = input.penalized_hessian.dim();
1413 if p == 0 || p != p2 {
1414 return Err(format!("reml_trace input H must be square, got {p}x{p2}"));
1415 }
1416 if input.probe_count < 2 {
1417 return Err(format!(
1418 "reml_trace requires probe_count >= 2 for a sample SE, got {}",
1419 input.probe_count
1420 ));
1421 }
1422 let needs_design = input
1423 .derivatives
1424 .iter()
1425 .any(|d| matches!(d, DerivativeHessian::WeightedGram { .. }));
1426 if needs_design && input.design.is_none() {
1427 return Err("reml_trace: structural derivative present but design=None".to_string());
1428 }
1429 let n = input.design.as_ref().map(|x| x.nrows()).unwrap_or(0);
1430 if let Some(x) = input.design.as_ref()
1431 && x.ncols() != p
1432 {
1433 return Err(format!(
1434 "reml_trace design has {} columns, expected p={p}",
1435 x.ncols()
1436 ));
1437 }
1438 for (j, derivative) in input.derivatives.iter().enumerate() {
1439 derivative
1440 .dim_p(p, n)
1441 .map_err(String::from)
1442 .map_err(|e| format!("reml_trace derivative {j}: {e}"))?;
1443 }
1444 Ok(())
1445}
1446
/// Reduces a `d x k` row-major sample table to per-row mean and standard
/// error of the mean.
///
/// Row `j` occupies `q[j * k..(j + 1) * k]`. The standard error uses the
/// unbiased sample variance (`k - 1` denominator) and is `0.0` when `k < 2`.
///
/// # Panics
///
/// Panics when `q.len() != d * k`.
fn reduce_mean_stderr(q: &[f64], d: usize, k: usize) -> (Vec<f64>, Vec<f64>) {
    assert_eq!(
        q.len(),
        d * k,
        "reduce_mean_stderr: q buffer length {} != D*K = {}*{}",
        q.len(),
        d,
        k
    );
    let sample_count = k as f64;
    let inv_k = 1.0 / sample_count;
    let mut means = Vec::with_capacity(d);
    let mut stderrs = Vec::with_capacity(d);
    for j in 0..d {
        let row = &q[j * k..(j + 1) * k];
        let mean = row.iter().copied().sum::<f64>() * inv_k;
        let stderr = if k >= 2 {
            let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / ((k - 1) as f64);
            (var / sample_count).sqrt()
        } else {
            0.0
        };
        means.push(mean);
        stderrs.push(stderr);
    }
    (means, stderrs)
}
1474
1475fn cholesky_lower(matrix: &Array2<f64>) -> Result<Array2<f64>, String> {
1478 let n = matrix.nrows();
1479 let mut l = Array2::<f64>::zeros((n, n));
1480 for i in 0..n {
1481 for j in 0..=i {
1482 let mut sum = matrix[[i, j]];
1483 for k in 0..j {
1484 sum -= l[[i, k]] * l[[j, k]];
1485 }
1486 if i == j {
1487 if sum <= 0.0 {
1488 return Err(format!(
1489 "reml_trace CPU Cholesky: non-SPD diagonal {sum} at row {i}"
1490 ));
1491 }
1492 l[[i, j]] = sum.sqrt();
1493 } else {
1494 l[[i, j]] = sum / l[[j, j]];
1495 }
1496 }
1497 }
1498 Ok(l)
1499}
1500
1501fn solve_cholesky(l: &Array2<f64>, rhs: &[f64]) -> Vec<f64> {
1502 let n = l.nrows();
1503 let mut y = vec![0.0_f64; n];
1504 for i in 0..n {
1505 let mut sum = rhs[i];
1506 for k in 0..i {
1507 sum -= l[[i, k]] * y[k];
1508 }
1509 y[i] = sum / l[[i, i]];
1510 }
1511 let mut x = vec![0.0_f64; n];
1512 for i in (0..n).rev() {
1513 let mut sum = y[i];
1514 for k in (i + 1)..n {
1515 sum -= l[[k, i]] * x[k];
1516 }
1517 x[i] = sum / l[[i, i]];
1518 }
1519 x
1520}
1521
1522#[cfg(test)]
1527mod tests {
1528 use super::*;
1529 use ndarray::{Array2, ArrayView2};
1530
1531 fn make_spd(p: usize, jitter: f64) -> Array2<f64> {
1532 let mut h = Array2::<f64>::zeros((p, p));
1533 for i in 0..p {
1534 for j in 0..p {
1535 h[[i, j]] = if i == j {
1536 p as f64 + jitter
1537 } else {
1538 1.0 / (1.0 + (i as f64 - j as f64).abs())
1539 };
1540 }
1541 }
1542 h
1543 }
1544
1545 fn random_dense_sym(p: usize, seed: u64) -> Array2<f64> {
1546 let mut a = Array2::<f64>::zeros((p, p));
1547 let mut s = seed;
1548 for i in 0..p {
1549 for j in i..p {
1550 s = splitmix64_mix(s.wrapping_add(1));
1551 let v = ((s >> 11) as f64) / ((1u64 << 53) as f64) - 0.5;
1552 a[[i, j]] = v;
1553 a[[j, i]] = v;
1554 }
1555 }
1556 a
1557 }
1558
1559 fn exact_trace_hinv_a(h: ArrayView2<f64>, a: ArrayView2<f64>) -> f64 {
1560 let p = h.nrows();
1561 let factor = cholesky_lower(&h.to_owned()).expect("SPD");
1562 let mut trace = 0.0;
1563 for col in 0..p {
1564 let mut e = vec![0.0_f64; p];
1565 e[col] = 1.0;
1566 let w = solve_cholesky(&factor, &e);
1567 let mut diag = 0.0;
1569 for i in 0..p {
1570 diag += a[[col, i]] * w[i];
1571 }
1572 trace += diag;
1573 }
1574 trace
1575 }
1576
/// `splitmix64_mix` returns the same value for the same input, and the OR of
/// its outputs over the inputs `0..64` sets every one of the 64 bits.
#[test]
fn splitmix_is_deterministic_and_disperses() {
    assert_eq!(splitmix64_mix(42), splitmix64_mix(42));
    let bits_seen = (0u64..64).fold(0u64, |seen, input| seen | splitmix64_mix(input));
    assert_eq!(
        bits_seen,
        u64::MAX,
        "splitmix should cover every bit position across 64 inputs"
    );
}
1592
/// Every `rademacher_entry(seed, k, i)` on a 16 × 32 grid is exactly `±1`, and
/// calling it a second time with the same arguments gives the same value.
#[test]
fn rademacher_entries_are_pm_one_and_stateless() {
    let ProbeSeed(raw_seed) = ProbeSeed(0xCAFE_BABE);
    for probe in 0..16u64 {
        for row in 0..32u64 {
            let first = rademacher_entry(raw_seed, probe, row);
            assert!(
                [1.0, -1.0].contains(&first),
                "non-pm1 entry at (k={k}, i={i}): {v}",
                k = probe,
                i = row,
                v = first
            );
            let second = rademacher_entry(raw_seed, probe, row);
            assert_eq!(first, second, "same (k,i) must hash to same value");
        }
    }
}
1608
/// Common-random-numbers check: the buffer filled for 16 probes equals the
/// first `16 · p` entries of the buffer filled for 32 probes with the same
/// seed.
#[test]
fn rademacher_common_random_numbers_match_for_prefix() {
    let p = 50;
    let mut z16 = vec![0.0_f64; p * 16];
    let mut z32 = vec![0.0_f64; p * 32];
    fill_rademacher_host(ProbeSeed(7), p, 16, &mut z16);
    fill_rademacher_host(ProbeSeed(7), p, 32, &mut z32);
    // The test addresses entries as `col * p + row`, so walking the flat
    // index visits (col, row) in the same order as a nested col/row loop.
    for (flat, (short, long)) in z16.iter().zip(&z32).enumerate() {
        let (col, row) = (flat / p, flat % p);
        assert_eq!(short, long, "CRN broken at (col={col}, row={row})");
    }
}
1627
/// With 4096 probes on a 16 × 16 problem, twice the CPU Hutchinson gradient
/// entry must lie within six (doubled) standard errors of the exact
/// `tr(H⁻¹ A)` for each of two derivative matrices.
#[test]
fn cpu_hutchinson_unbiased_against_exact_small_spd() {
    let p = 16;
    let hessian = make_spd(p, 0.5);
    let deriv_a = random_dense_sym(p, 0x1234);
    let deriv_b = random_dense_sym(p, 0x5678);
    let exact1 = exact_trace_hinv_a(hessian.view(), deriv_a.view());
    let exact2 = exact_trace_hinv_a(hessian.view(), deriv_b.view());

    let input = RemlTraceHutchinsonInput {
        penalized_hessian: hessian.view(),
        derivatives: vec![
            DerivativeHessian::Dense(deriv_a.view()),
            DerivativeHessian::Dense(deriv_b.view()),
        ],
        design: None,
        probe_count: 4096,
        seed: ProbeSeed(0xCAFE_BABE),
    };
    let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");

    // The test compares 2 × (gradient entry) with the exact trace, so both
    // the estimate and its standard error are doubled.
    let doubled = |idx: usize| {
        (
            2.0 * evidence.gradient_rho_logdet[idx],
            2.0 * evidence.gradient_rho_stderr[idx],
        )
    };
    let (est1, se1) = doubled(0);
    let (est2, se2) = doubled(1);
    let tol1 = 6.0 * se1.max(1e-8);
    let tol2 = 6.0 * se2.max(1e-8);

    let miss1 = (est1 - exact1).abs();
    assert!(
        miss1 <= tol1,
        "Hutchinson est {est1} too far from exact {exact1} (tol={tol1}, se={})",
        evidence.gradient_rho_stderr[0]
    );
    let miss2 = (est2 - exact2).abs();
    assert!(
        miss2 <= tol2,
        "Hutchinson est {est2} too far from exact {exact2} (tol={tol2})"
    );
}
1666
/// The structural `WeightedGram` derivative (row weights plus a design
/// matrix) must give the same first gradient entry as the explicitly
/// assembled dense `Xᵀ diag(a) X`, to within `1e-9`, when both runs use the
/// same seed and probe count.
#[test]
fn structural_path_matches_dense_for_xtwx() {
    /// Runs the CPU estimator for a single derivative (32 probes, seed 123)
    /// and returns the first gradient entry.
    fn first_gradient<'a>(
        penalized_hessian: ArrayView2<'a, f64>,
        derivative: DerivativeHessian<'a>,
        design: Option<ArrayView2<'a, f64>>,
    ) -> f64 {
        let input = RemlTraceHutchinsonInput {
            penalized_hessian,
            derivatives: vec![derivative],
            design,
            probe_count: 32,
            seed: ProbeSeed(123),
        };
        evidence_derivatives_hutchinson_cpu(&input)
            .expect("ok")
            .gradient_rho_logdet[0]
    }

    let n = 40;
    let p = 8;

    // Design matrix filled row by row from a splitmix chain; entries in [-0.5, 0.5).
    let mut state = 11u64;
    let mut draw = || {
        state = splitmix64_mix(state.wrapping_add(1));
        ((state >> 11) as f64) / ((1u64 << 53) as f64) - 0.5
    };
    let mut x = Array2::<f64>::zeros((n, p));
    for row in 0..n {
        for col in 0..p {
            x[[row, col]] = draw();
        }
    }

    // Row weights a_i = 0.5 + 0.01 · i.
    let a_arr = ndarray::Array1::from(
        (0..n)
            .map(|i| 0.5 + 0.01 * (i as f64))
            .collect::<Vec<f64>>(),
    );

    // Dense reference H_j = Xᵀ diag(a) X, accumulated over rows in order.
    let hj_dense = Array2::from_shape_fn((p, p), |(r, c)| {
        (0..n).fold(0.0_f64, |acc, i| acc + x[[i, r]] * a_arr[i] * x[[i, c]])
    });

    let mut h = make_spd(p, 1.0);
    for k in 0..p {
        h[[k, k]] += 1.0;
    }

    let grad_dense = first_gradient(h.view(), DerivativeHessian::Dense(hj_dense.view()), None);
    let grad_struct = first_gradient(
        h.view(),
        DerivativeHessian::WeightedGram {
            row_weights: a_arr.view(),
            penalty_extra: None,
        },
        Some(x.view()),
    );
    assert!(
        (grad_dense - grad_struct).abs() < 1e-9,
        "dense vs structural mismatch: dense={}, struct={}",
        grad_dense,
        grad_struct
    );
}
1726
/// Two-stage check on a 10 × 10 problem:
/// 1. the central finite difference of `log det(H ± εA)` agrees with the
///    exact `tr(H⁻¹ A)` to a relative `1e-6`;
/// 2. the CPU Hutchinson gradient entry (4096 probes) lies within eight
///    standard errors of `0.5 · tr(H⁻¹ A)`.
#[test]
fn finite_difference_check_against_logdet() {
    /// `log det M` from the Cholesky diagonal: `2 Σ ln L_ii`.
    fn logdet(m: &Array2<f64>) -> f64 {
        let l = cholesky_lower(m).unwrap();
        2.0 * (0..m.nrows()).map(|i| l[[i, i]].ln()).sum::<f64>()
    }

    let p = 10;
    let h0 = make_spd(p, 0.2);
    let a = random_dense_sym(p, 0xABCD);
    let eps = 1e-4;
    let h_plus = Array2::from_shape_fn((p, p), |(i, j)| h0[[i, j]] + eps * a[[i, j]]);
    let h_minus = Array2::from_shape_fn((p, p), |(i, j)| h0[[i, j]] - eps * a[[i, j]]);

    let fd = (logdet(&h_plus) - logdet(&h_minus)) / (2.0 * eps);
    let exact = exact_trace_hinv_a(h0.view(), a.view());
    let rel_err = (fd - exact).abs() / exact.abs().max(1e-12);
    assert!(
        rel_err < 1e-6,
        "FD logdet derivative {fd} != exact trace {exact}"
    );

    let input = RemlTraceHutchinsonInput {
        penalized_hessian: h0.view(),
        derivatives: vec![DerivativeHessian::Dense(a.view())],
        design: None,
        probe_count: 4096,
        seed: ProbeSeed(0xAA55),
    };
    let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
    let tol = 8.0 * evidence.gradient_rho_stderr[0].max(1e-8);
    let target = 0.5 * exact;
    assert!(
        (evidence.gradient_rho_logdet[0] - target).abs() < tol,
        "Hutchinson gradient {} not within 8·SE of 0.5·exact={}",
        evidence.gradient_rho_logdet[0],
        target
    );
}
1771
/// The GPU gate must decline when the first argument is 64.
// NOTE(review): presumably the first argument is the dimension p and 64 is
// below `HUTCHINSON_GPU_MIN_P` (512) — confirm against the gate's signature.
#[test]
fn gate_rejects_below_min_p() {
    let use_gpu = should_use_gpu_hutchinson(64, 16, true, true, true, false);
    assert!(!use_gpu);
}
1776
/// The GPU gate must decline for a second argument of 4 and of 200, with all
/// other arguments equal to the accepted canonical case.
// NOTE(review): presumably the second argument is the probe count K — confirm.
#[test]
fn gate_rejects_k_out_of_range() {
    for k in [4, 200] {
        assert!(!should_use_gpu_hutchinson(2000, k, true, true, true, false));
    }
}
1784
/// The GPU gate must decline when the last flag is `true`, everything else
/// matching the accepted canonical case.
// NOTE(review): presumably the last flag means "subspace active" — confirm.
#[test]
fn gate_rejects_when_subspace_active() {
    let use_gpu = should_use_gpu_hutchinson(2000, 16, true, true, true, true);
    assert!(!use_gpu);
}
1789
/// The GPU gate must accept the call `(2000, 16, true, true, true, false)`.
#[test]
fn gate_accepts_canonical_case() {
    let use_gpu = should_use_gpu_hutchinson(2000, 16, true, true, true, false);
    assert!(use_gpu);
}
1794
/// The adaptive estimator's first trace must agree with the exact
/// `tr(H⁻¹ A)` to within `max(8·SE, 5 % of |exact|)`.
// NOTE(review): the name says p512 but the test runs with p = 64.
#[test]
fn block_2_6_adaptive_unbiased_against_exact_p512() {
    let p = 64;
    let hessian = make_spd(p, 0.5);
    let derivative = random_dense_sym(p, 0xBADC0DE);
    let exact = exact_trace_hinv_a(hessian.view(), derivative.view());

    let evidence = evidence_traces_adaptive(
        hessian.view(),
        vec![DerivativeHessian::Dense(derivative.view())],
        None,
        ProbeSeed(0xA5A5A5),
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("adaptive run ok");

    let (est, se) = (evidence.traces[0], evidence.stderrs[0]);
    let tol = (8.0 * se).max(0.05 * exact.abs());
    let miss = (est - exact).abs();
    assert!(
        miss <= tol,
        "adaptive est {est} far from exact {exact} (tol={tol}, se={se}, K={})",
        evidence.probe_count
    );
}
1829
/// For the same input (p = 32, 16 probes, seed 0xBEEF) the CPU entry point
/// and `evidence_derivatives_hutchinson_gpu` must return first gradient
/// entries that differ by less than `1e-9`.
#[test]
fn block_2_6_same_probes_cpu_vs_dispatch() {
    let p = 32;
    let hessian = make_spd(p, 0.3);
    let derivative = random_dense_sym(p, 0x1357);
    let input = RemlTraceHutchinsonInput {
        penalized_hessian: hessian.view(),
        derivatives: vec![DerivativeHessian::Dense(derivative.view())],
        design: None,
        probe_count: 16,
        seed: ProbeSeed(0xBEEF),
    };

    // The CPU path borrows the input; the dispatch path consumes it afterwards.
    let cpu_grad = evidence_derivatives_hutchinson_cpu(&input)
        .expect("cpu")
        .gradient_rho_logdet[0];
    let dispatch_grad = evidence_derivatives_hutchinson_gpu(input)
        .expect("dispatch")
        .gradient_rho_logdet[0];

    let diff = (cpu_grad - dispatch_grad).abs();
    assert!(
        diff < 1e-9,
        "same-probes CPU vs GPU dispatch differ: cpu={}, dispatch={}, diff={diff}",
        cpu_grad,
        dispatch_grad
    );
}
1858
/// The adaptive estimator's first trace must agree with the central finite
/// difference of `log det(H ± εA)` to within `max(8·SE, 5 % of |fd|)`.
#[test]
fn block_2_6_fd_logdet_matches_adaptive() {
    /// `log det M` from the Cholesky diagonal: `2 Σ ln L_ii`.
    fn logdet(m: &Array2<f64>) -> f64 {
        let l = cholesky_lower(m).expect("SPD");
        2.0 * (0..m.nrows()).map(|i| l[[i, i]].ln()).sum::<f64>()
    }

    let p = 24;
    let h = make_spd(p, 0.4);
    let a = random_dense_sym(p, 0x2468);
    let eps = 1e-4;
    let h_plus = Array2::from_shape_fn((p, p), |(i, j)| h[[i, j]] + eps * a[[i, j]]);
    let h_minus = Array2::from_shape_fn((p, p), |(i, j)| h[[i, j]] - eps * a[[i, j]]);
    let fd = (logdet(&h_plus) - logdet(&h_minus)) / (2.0 * eps);

    let evidence = evidence_traces_adaptive(
        h.view(),
        vec![DerivativeHessian::Dense(a.view())],
        None,
        ProbeSeed(0x9999),
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("adaptive ok");

    let (est, se) = (evidence.traces[0], evidence.stderrs[0]);
    let tol = (8.0 * se).max(0.05 * fd.abs());
    let miss = (est - fd).abs();
    assert!(
        miss <= tol,
        "adaptive trace {est} disagrees with FD logdet derivative {fd} (tol={tol})"
    );
}
1897
/// With 4096 probes on a 40 × 40 problem, twice the gradient entry returned
/// by `evidence_derivatives_hutchinson_gpu` must lie within
/// `max(6·SE, 0.1 % of |exact|)` of the exact `tr(H⁻¹ A)`.
#[test]
fn block_2_6_k_4096_matches_exact_tightly() {
    let p = 40;
    let hessian = make_spd(p, 0.6);
    let derivative = random_dense_sym(p, 0xDEAD);
    let exact = exact_trace_hinv_a(hessian.view(), derivative.view());

    let evidence = evidence_derivatives_hutchinson_gpu(RemlTraceHutchinsonInput {
        penalized_hessian: hessian.view(),
        derivatives: vec![DerivativeHessian::Dense(derivative.view())],
        design: None,
        probe_count: 4096,
        seed: ProbeSeed(0xC0FFEE),
    })
    .expect("ok");

    // Estimate and standard error are doubled before comparing with the trace.
    let est = 2.0 * evidence.gradient_rho_logdet[0];
    let se = 2.0 * evidence.gradient_rho_stderr[0];
    let tol = (6.0 * se).max(1e-3 * exact.abs());
    let miss = (est - exact).abs();
    assert!(
        miss <= tol,
        "K=4096 Hutchinson {est} not within 6·SE of exact {exact} (tol={tol}, se={se})"
    );
}
1924
/// Common-random-numbers check across three probe counts: the 16-probe
/// buffer is a prefix of both the 32- and 64-probe buffers, and the 32-probe
/// buffer is a prefix of the 64-probe buffer (same seed throughout).
#[test]
fn block_2_6_crn_prefix_match_across_schedule() {
    let p = 50;
    let seed = ProbeSeed(0x4242_4242);
    let mut z16 = vec![0.0_f64; p * 16];
    let mut z32 = vec![0.0_f64; p * 32];
    let mut z64 = vec![0.0_f64; p * 64];
    fill_rademacher_host(seed, p, 16, &mut z16);
    fill_rademacher_host(seed, p, 32, &mut z32);
    fill_rademacher_host(seed, p, 64, &mut z64);

    // Entries are addressed as `col * p + row`, so a flat walk over the first
    // `K · p` entries visits the first K columns in order.
    for flat in 0..16 * p {
        assert_eq!(z16[flat], z32[flat]);
        assert_eq!(z16[flat], z64[flat]);
    }
    for flat in 0..32 * p {
        assert_eq!(z32[flat], z64[flat]);
    }
}
1952
/// Runs the adaptive estimator once with the dense Hessian and once through
/// the Hessian-vector-product entry point (the closure multiplies by a copy
/// of the same matrix). Both first traces must lie within
/// `max(8·SE, 5 % of |exact|)` of the exact `tr(H⁻¹ A)`, and the HVP result
/// must report `logdet_hessian` as NaN.
#[test]
fn block_2_7_hvp_path_matches_dense_adaptive() {
    let p = 40;
    let h = make_spd(p, 0.7);
    let a = random_dense_sym(p, 0xABBA);
    let seed = ProbeSeed(0x707);

    let dense = evidence_traces_adaptive(
        h.view(),
        vec![DerivativeHessian::Dense(a.view())],
        None,
        seed,
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("dense ok");

    let h_for_hvp = h.clone();
    // out = H · v, each row accumulated left to right.
    let apply_h = |v: &[f64], out: &mut [f64]| {
        for r in 0..p {
            out[r] = (0..p).fold(0.0_f64, |acc, c| acc + h_for_hvp[[r, c]] * v[c]);
        }
    };
    let hvp_evidence = evidence_traces_adaptive_hvp(
        p,
        apply_h,
        vec![DerivativeHessian::Dense(a.view())],
        None,
        seed,
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("hvp ok");

    let exact = exact_trace_hinv_a(h.view(), a.view());
    let floor = 0.05 * exact.abs();
    let tol_dense = (8.0 * dense.stderrs[0]).max(floor);
    let tol_hvp = (8.0 * hvp_evidence.stderrs[0]).max(floor);

    let dense_trace = dense.traces[0];
    assert!(
        (dense_trace - exact).abs() <= tol_dense,
        "dense adaptive {} not near exact {} (tol {})",
        dense_trace,
        exact,
        tol_dense
    );
    let hvp_trace = hvp_evidence.traces[0];
    assert!(
        (hvp_trace - exact).abs() <= tol_hvp,
        "hvp adaptive {} not near exact {} (tol {})",
        hvp_trace,
        exact,
        tol_hvp
    );
    assert!(hvp_evidence.logdet_hessian.is_nan());
}
2022
/// With a relative tolerance of `1e-12` both adaptive paths are expected to
/// run to 128 probes (asserted below); their reported standard errors for the
/// first derivative must then agree to a relative `1e-3`.
#[test]
fn block_2_7_hvp_stderr_matches_dense_reduce_mean_stderr() {
    let p = 36;
    let h = make_spd(p, 0.6);
    let a = random_dense_sym(p, 0x5151);
    let seed = ProbeSeed(0xBEEF);
    // Tolerance tight enough that the test expects the probe schedule not to stop early.
    let force_full_schedule = 1e-12_f64;

    let dense = evidence_traces_adaptive(
        h.view(),
        vec![DerivativeHessian::Dense(a.view())],
        None,
        seed,
        force_full_schedule,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("dense ok");

    let h_for_hvp = h.clone();
    // out = H · v, each row accumulated left to right.
    let apply_h = |v: &[f64], out: &mut [f64]| {
        for r in 0..p {
            out[r] = (0..p).fold(0.0_f64, |acc, c| acc + h_for_hvp[[r, c]] * v[c]);
        }
    };
    let hvp = evidence_traces_adaptive_hvp(
        p,
        apply_h,
        vec![DerivativeHessian::Dense(a.view())],
        None,
        seed,
        force_full_schedule,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("hvp ok");

    assert_eq!(dense.probe_count, 128);
    assert_eq!(hvp.probe_count, dense.probe_count);

    let (sd_dense, sd_hvp) = (dense.stderrs[0], hvp.stderrs[0]);
    assert!(
        sd_dense > 0.0,
        "dense SE should be positive, got {sd_dense}"
    );
    let rel = (sd_hvp - sd_dense).abs() / sd_dense;
    assert!(
        rel <= 1e-3,
        "HVP SE {sd_hvp} disagrees with dense reduce_mean_stderr SE {sd_dense} \
         (rel {rel}); the two paths must share the Bessel-corrected (K−1) convention"
    );
}
2086
/// `cg_solve` on the diagonal operator `diag(1, 2, …, 8)` must return
/// `w[i] = b[i] / diag[i]` to within `1e-10`.
// NOTE(review): despite the name, only the solution values are checked here,
// not the number of iterations taken.
#[test]
fn block_2_7_cg_solves_diagonal_in_one_iteration() {
    let p = 8;
    let diag: Vec<f64> = (0..p).map(|i| 1.0 + i as f64).collect();
    let b: Vec<f64> = (0..p).map(|i| (i as f64) + 0.5).collect();
    let mut w = vec![0.0_f64; p];

    // The operator closure works on its own copy of the diagonal.
    let operator_diag = diag.clone();
    let mut apply_diag = |v: &[f64], out: &mut [f64]| {
        for i in 0..p {
            out[i] = operator_diag[i] * v[i];
        }
    };
    cg_solve(&mut apply_diag, &b, &mut w, 1e-12, PCG_HVP_MAX_ITERS);

    for (i, ((got, rhs), scale)) in w.iter().zip(&b).zip(&diag).enumerate() {
        let expected = rhs / scale;
        assert!(
            (got - expected).abs() < 1e-10,
            "diagonal CG: w[{i}]={} expected {expected}",
            got
        );
    }
}
2117
/// Accuracy-plus-timing comparison of the adaptive estimator against exact
/// Cholesky traces.
///
/// The problem size is `(p, d) = (2000, 8)` when running on Linux with a GPU
/// runtime available and `(256, 4)` otherwise. Every adaptive trace must lie
/// within `max(10·SE, 5 % of |exact|)` of the exact value; the measured
/// speed-up is always printed and, in the GPU configuration only, must be at
/// least 10×.
#[test]
fn block_2_8_hill_climb_adaptive_vs_exact_at_p2000_d8() {
    let on_v100 =
        cfg!(target_os = "linux") && gam_gpu::device_runtime::GpuRuntime::global().is_some();
    let (p, d): (usize, usize) = if on_v100 { (2000, 8) } else { (256, 4) };

    // Same construction as the other tests' SPD stand-in: `p + 1` on the
    // diagonal, `1 / (1 + |i - j|)` elsewhere.
    let h = Array2::<f64>::from_shape_fn((p, p), |(i, j)| {
        if i == j {
            p as f64 + 1.0
        } else {
            1.0 / (1.0 + (i as f64 - j as f64).abs())
        }
    });
    let mut derivs_owned: Vec<Array2<f64>> = Vec::with_capacity(d);
    for k in 0..d {
        derivs_owned.push(random_dense_sym(p, 0x1000 + k as u64));
    }
    let derivs: Vec<DerivativeHessian<'_>> = derivs_owned
        .iter()
        .map(|a| DerivativeHessian::Dense(a.view()))
        .collect();

    // Exact reference: tr(H⁻¹ A) = Σ_col (H⁻¹ A[:, col])[col], one solve per column.
    let exact_clock = std::time::Instant::now();
    let factor = cholesky_lower(&h).expect("SPD");
    let exact_traces: Vec<f64> = derivs_owned
        .iter()
        .map(|a| {
            (0..p).fold(0.0_f64, |acc, col| {
                let rhs: Vec<f64> = (0..p).map(|r| a[[r, col]]).collect();
                acc + solve_cholesky(&factor, &rhs)[col]
            })
        })
        .collect();
    let t_exact = exact_clock.elapsed();

    let adaptive_clock = std::time::Instant::now();
    let evidence = evidence_traces_adaptive(
        h.view(),
        derivs,
        None,
        ProbeSeed(0xB10C),
        HUTCHINSON_ADAPTIVE_REL_TOL,
        HUTCHINSON_ADAPTIVE_TAU_REL,
    )
    .expect("adaptive ok");
    let t_adaptive = adaptive_clock.elapsed();

    for (j, &exact) in exact_traces.iter().enumerate() {
        let tol = (10.0 * evidence.stderrs[j]).max(0.05 * exact.abs());
        let diff = (evidence.traces[j] - exact).abs();
        assert!(
            diff <= tol,
            "block_2_8: derivative {j} adaptive {} disagrees with exact {} (tol {tol}, diff {diff})",
            evidence.traces[j],
            exact
        );
    }

    // Guard the denominator so a near-zero adaptive time cannot divide by zero.
    let speedup = t_exact.as_secs_f64() / t_adaptive.as_secs_f64().max(1e-9);
    eprintln!(
        "block_2_8 hill-climb [p={p}, d={d}, V100={on_v100}]: \
         exact={:?}, adaptive={:?}, speedup={:.2}× (K={}, converged={})",
        t_exact, t_adaptive, speedup, evidence.probe_count, evidence.converged
    );
    if on_v100 {
        assert!(
            speedup >= 10.0,
            "block_2_8 V100 speedup {speedup:.2}× below the 10× target \
             (exact {:?}, adaptive {:?})",
            t_exact,
            t_adaptive,
        );
    }
}
2219}