Module reml_trace


GPU Hutchinson stochastic trace estimator for the REML/LAML logdet gradient, per math team block 2 (sections 12–18 of the V100 design).

Public entry point: evidence_derivatives_hutchinson_gpu. For each derivative Hessian H_j (j = 1..D) and a single penalized Hessian H held resident on device, returns the unbiased Hutchinson estimate of

t_j = tr(H^{-1} H_j)

plus the sample standard error of each estimate, computed from K Rademacher probe vectors z_k ∈ {±1}^p whose entries are drawn from a stateless SplitMix64 counter hash (no cuRAND state). The math identity used on device is

z^T H^{-1} H_j z  =  z^T H_j w   where   H w = z

so we factor H once with cusolverDnDpotrf, batch-solve H W = Z with a single cusolverDnDpotrs call (nrhs = K), and then evaluate the quadratic forms q_{j,k} = z_k^T H_j w_k with a custom NVRTC reduction kernel. The trace estimate is the probe mean, t_j ≈ mean_k(q_{j,k}), and the REML logdet gradient is g_j = (1/2) · mean_k(q_{j,k}).
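The factor, solve, and quadratic-form steps can be sketched on the CPU with small dense matrices. This is a minimal illustration, not the module's implementation: the function names (cholesky, chol_solve, quad_form, hutchinson_trace) and the row-major layout are assumptions made for the example, and a hand-written Cholesky stands in for cusolverDnDpotrf / cusolverDnDpotrs. The sketch assumes H and H_j are symmetric, which is what makes z^T H^{-1} H_j z equal to z^T H_j w.

```rust
// Sketch only: dense CPU analogue of the factor / solve / quadratic-form pipeline.
// All names are hypothetical and are not part of the reml_trace API.

/// Lower Cholesky factor L (row-major, p x p) of an SPD matrix H = L L^T.
fn cholesky(h: &[f64], p: usize) -> Vec<f64> {
    let mut l = vec![0.0; p * p];
    for i in 0..p {
        for j in 0..=i {
            let mut s = h[i * p + j];
            for m in 0..j {
                s -= l[i * p + m] * l[j * p + m];
            }
            l[i * p + j] = if i == j { s.sqrt() } else { s / l[j * p + j] };
        }
    }
    l
}

/// Solve L L^T w = z by forward then backward substitution.
fn chol_solve(l: &[f64], p: usize, z: &[f64]) -> Vec<f64> {
    let mut w = z.to_vec();
    for i in 0..p {
        for m in 0..i {
            w[i] -= l[i * p + m] * w[m];
        }
        w[i] /= l[i * p + i];
    }
    for i in (0..p).rev() {
        for m in (i + 1)..p {
            w[i] -= l[m * p + i] * w[m];
        }
        w[i] /= l[i * p + i];
    }
    w
}

/// Quadratic form q = z^T H_j w for a row-major p x p matrix H_j.
fn quad_form(z: &[f64], hj: &[f64], w: &[f64], p: usize) -> f64 {
    let mut q = 0.0;
    for r in 0..p {
        for c in 0..p {
            q += z[r] * hj[r * p + c] * w[c];
        }
    }
    q
}

/// Returns (probe mean, sample standard error) of q_k = z_k^T H_j w_k, H w_k = z_k.
/// Requires at least two probes for the standard error to be defined.
fn hutchinson_trace(h: &[f64], hj: &[f64], p: usize, probes: &[Vec<f64>]) -> (f64, f64) {
    let l = cholesky(h, p); // factor once, reuse for every probe
    let q: Vec<f64> = probes
        .iter()
        .map(|z| {
            let w = chol_solve(&l, p, z);
            quad_form(z, hj, &w, p)
        })
        .collect();
    let k = q.len() as f64;
    let mean = q.iter().sum::<f64>() / k;
    let var = q.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (k - 1.0);
    (mean, (var / k).sqrt())
}
```

A convenient sanity check for a sketch like this is to enumerate all 2^p sign vectors for a tiny p: the off-diagonal terms cancel exactly, so the probe mean equals tr(H^{-1} H_j) to round-off.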

Two assembly variants for H_j are supported:

  • Dense: the caller passes H_j as a p × p device or host matrix. A GEMM forms Y_j = H_j W, then a custom reduction sums z_k^T y_{j,k} per (j, k). Cost: D GEMMs of size p × p × K.
  • Weighted-Gram structural: the caller provides the design X (n × p), weight vectors a_j (length n, one per derivative, each holding the diagonal of the design’s row weights that H_j adds), and the per-derivative penalty contribution Q_pen[j,k], if any. The kernel forms R_Z = X Z and R_W = X W once via GEMM and then accumulates sum_i a_j[i] · R_Z[i,k] · R_W[i,k] per (j, k) without ever materialising the p × p matrix H_j. Cost: 2 GEMMs of size n × p × K shared across all D derivatives.
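The equivalence behind the structural variant can be checked with a small host-side sketch. It assumes H_j = X^T diag(a_j) X with no penalty term, a row-major X, and a single probe column; the function names (matvec, structural_quad, dense_quad) are hypothetical and chosen only for this illustration.

```rust
// Sketch only: z^T (X^T diag(a) X) w computed two ways, for one probe column.
// Names and the row-major layout are assumptions, not the module's API.

/// y = X v for a row-major n x p matrix X.
fn matvec(x: &[f64], n: usize, p: usize, v: &[f64]) -> Vec<f64> {
    (0..n)
        .map(|i| (0..p).map(|c| x[i * p + c] * v[c]).sum())
        .collect()
}

/// Structural route: sum_i a[i] * (X z)[i] * (X w)[i], never forming H_j.
fn structural_quad(x: &[f64], n: usize, p: usize, a: &[f64], z: &[f64], w: &[f64]) -> f64 {
    let rz = matvec(x, n, p, z);
    let rw = matvec(x, n, p, w);
    (0..n).map(|i| a[i] * rz[i] * rw[i]).sum()
}

/// Dense route: materialise H_j = X^T diag(a) X, then evaluate z^T H_j w.
fn dense_quad(x: &[f64], n: usize, p: usize, a: &[f64], z: &[f64], w: &[f64]) -> f64 {
    let mut hj = vec![0.0; p * p];
    for i in 0..n {
        for r in 0..p {
            for c in 0..p {
                hj[r * p + c] += a[i] * x[i * p + r] * x[i * p + c];
            }
        }
    }
    let mut q = 0.0;
    for r in 0..p {
        for c in 0..p {
            q += z[r] * hj[r * p + c] * w[c];
        }
    }
    q
}
```

In the structural route the products X z and X w do not depend on j, which is why the two GEMMs can be shared across all D derivatives and only the weighted reduction is repeated per derivative.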

The structural path is the high-value route for large-scale models where p is in the hundreds or larger and there are many derivatives.

§Stateless probe RNG

The probe entries are produced on device by a SplitMix64 finalizer over (seed, probe_index k, coordinate i). This has three consequences:

  1. No cuRAND state: the kernel is fully stateless, and threads write into Z[i + k·p] independently.
  2. Common random numbers: the first K1 probes of a run with K2 > K1 are bit-identical to a K = K1 run with the same seed. This is the property that lets the adaptive K schedule build on earlier probes without re-running them, and lets CPU and GPU implementations of Hutchinson compare estimator-by-estimator (the same probes produce the same q_{j,k} to round-off).
  3. Reproducibility — a probe at (seed, k, i) is the same call after call regardless of how the grid was scheduled.
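A host-side sketch of a counter-hash probe generator illustrates the three properties. The mixing function is the standard SplitMix64 finalizer; how (seed, k, i) are combined before mixing is an assumption made here for illustration and may differ from what rademacher_entry actually does. The names mix64, probe_entry, and fill_probes are hypothetical.

```rust
// Sketch only: stateless Rademacher probes from a SplitMix64-style counter hash.
// The way (seed, k, i) are combined is an assumption, not the module's scheme.

/// Standard SplitMix64 finalizer.
fn mix64(mut x: u64) -> u64 {
    x ^= x >> 30;
    x = x.wrapping_mul(0xBF58476D1CE4E5B9);
    x ^= x >> 27;
    x = x.wrapping_mul(0x94D049BB133111EB);
    x ^ (x >> 31)
}

/// Entry of probe k at coordinate i: +1.0 or -1.0, a pure function of (seed, k, i).
fn probe_entry(seed: u64, k: u64, i: u64) -> f64 {
    // Hypothetical combination: hash the probe index into the seed, then add i.
    let stream = mix64(seed ^ mix64(k.wrapping_add(0x9E3779B97F4A7C15)));
    let h = mix64(stream.wrapping_add(i));
    if h >> 63 == 0 { 1.0 } else { -1.0 }
}

/// Column-major (p, K) probe matrix: entry (i, k) lives at z[i + k * p].
fn fill_probes(seed: u64, p: usize, k_count: usize) -> Vec<f64> {
    let mut z = vec![0.0; p * k_count];
    for k in 0..k_count {
        for i in 0..p {
            z[i + k * p] = probe_entry(seed, k as u64, i as u64);
        }
    }
    z
}
```

Because each entry depends only on (seed, k, i), a K = 8 fill starts with exactly the columns of a K = 4 fill for the same seed, which is the common-random-numbers property described in item 2.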

§Gating

The companion helper should_use_gpu_hutchinson mirrors the CPU gate (prefers_stochastic_trace_estimation + matching kernel + plain-SPD logdet path) and adds the GPU-specific limits from the math team’s section 18:

  • p ≥ 512
  • K ∈ [8, 128]
  • Hessian and design held resident or about to be uploaded
  • The projected penalty-subspace trace is inactive (otherwise the CPU path projects through the IFT kernel — that route is required for marginal-slope ρ-saturated rows)
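The conditions above amount to a conjunction of booleans and range checks, sketched below. The numeric limits come from the list; everything else is an assumption for illustration: the struct GateInputs, the function gpu_hutchinson_eligible, and the field data_resident_or_pending_upload are hypothetical, and the real should_use_gpu_hutchinson signature may differ. The other field names follow the parameter names given in the function's description.

```rust
// Sketch only: the gate as a plain predicate. Struct, function, and the
// residency field are hypothetical; limits are taken from the list above.

const MIN_P: usize = 512;
const MIN_K: usize = 8;
const MAX_K: usize = 128;

struct GateInputs {
    p: usize,
    k: usize,
    prefers_stochastic: bool,
    kernel_matches_hinv: bool,
    plain_spd_logdet: bool,
    projected_penalty_subspace_active: bool,
    data_resident_or_pending_upload: bool, // assumed name for the residency condition
}

fn gpu_hutchinson_eligible(g: &GateInputs) -> bool {
    // CPU-side gate booleans first, then the GPU-specific shape limits.
    g.prefers_stochastic
        && g.kernel_matches_hinv
        && g.plain_spd_logdet
        && !g.projected_penalty_subspace_active
        && g.data_resident_or_pending_upload
        && g.p >= MIN_P
        && (MIN_K..=MAX_K).contains(&g.k)
}
```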

Structs§

AdaptiveTraceEvidence
Adaptive-K Hutchinson trace schedule with common random numbers (CRN).
ProbeSeed
Stateless seed for the SplitMix64 Rademacher probe RNG.
RemlTraceHutchinsonEvidence
Output of evidence_derivatives_hutchinson_gpu.
RemlTraceHutchinsonInput
Inputs to evidence_derivatives_hutchinson_gpu.

Enums§

DerivativeHessian
Description of one derivative-Hessian contribution H_j.

Constants§

HUTCHINSON_ADAPTIVE_REL_TOL
Default relative-error target for the adaptive-K stopping rule. Matches StochasticTraceConfig::default().relative_tol.
HUTCHINSON_ADAPTIVE_TAU_REL
Default near-zero-trace protection floor. Matches StochasticTraceConfig::default().tau_rel.
HUTCHINSON_GPU_MAX_K
HUTCHINSON_GPU_MIN_K
Minimum and maximum probe counts the GPU path accepts (math section 18).
HUTCHINSON_GPU_MIN_P
Minimum joint-dimension at which the GPU Hutchinson path is enabled.
PCG_HVP_MAX_ITERS
Maximum CG iterations per probe before we stop and accept the partial solve. Capped so a poorly conditioned H cannot make a single REML step pay unbounded time — the Hutchinson estimator is statistically robust to a few stale w_k values (it inflates SE, which the adaptive stopping rule then catches by extending the schedule).
PCG_HVP_REL_TOL
CG convergence tolerance for the per-probe solve H w = z. The outer adaptive-K loop already drives Hutchinson variance to ~1%; a per-probe relative residual of 1e-6 keeps the CG solve error well below the stochastic SE without paying for convergence to double-precision machine epsilon.
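The interplay of the iteration cap and the relative-residual tolerance can be sketched with a plain conjugate-gradient loop over a matrix-free HVP closure. This is an unpreconditioned illustration, not the module's PCG: the function names cg_solve and dot are hypothetical, the stopping test ||r|| ≤ rel_tol · ||z|| is an assumed form of the relative-residual criterion, and the iteration cap passed by the caller is a placeholder because the value of PCG_HVP_MAX_ITERS is not stated here.

```rust
// Sketch only: unpreconditioned CG for H w = z with an iteration cap and a
// relative-residual stop. Names and the exact stopping test are assumptions.

fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Returns the (possibly partial) solution and the number of iterations used.
/// `hvp` applies the SPD operator H to a vector without forming H densely.
fn cg_solve<F>(hvp: F, z: &[f64], rel_tol: f64, max_iters: usize) -> (Vec<f64>, usize)
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let mut w = vec![0.0; z.len()];
    let mut r = z.to_vec(); // residual z - H w, with w = 0 initially
    let mut d = r.clone(); // search direction
    let z_norm = dot(z, z).sqrt();
    let mut rr = dot(&r, &r);
    let mut iters = 0;
    while iters < max_iters && rr.sqrt() > rel_tol * z_norm {
        let hd = hvp(&d);
        let alpha = rr / dot(&d, &hd);
        for i in 0..w.len() {
            w[i] += alpha * d[i];
            r[i] -= alpha * hd[i];
        }
        let rr_new = dot(&r, &r);
        let beta = rr_new / rr;
        for i in 0..d.len() {
            d[i] = r[i] + beta * d[i];
        }
        rr = rr_new;
        iters += 1;
    }
    (w, iters) // when the cap is hit, w is the accepted partial solve
}
```

When the cap is reached before the tolerance, the loop simply returns the current iterate, which corresponds to the "accept the partial solve" behaviour described for PCG_HVP_MAX_ITERS.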

Functions§

evidence_derivatives_hutchinson_cpu
Run the Hutchinson estimator on CPU using the exact same probe bits the device kernel uses. Returns the same evidence struct.
evidence_derivatives_hutchinson_gpu
Compute log |H| and the Hutchinson estimate of (1/2) tr(H^{-1} H_j) for every derivative. Dispatches to the device-resident path when the CUDA runtime is up and the GPU probe succeeds; otherwise runs the CPU reference. Either way the probe bits are identical (stateless SplitMix), so callers see the same estimator value to round-off.
evidence_traces_adaptive
evidence_traces_adaptive_hvp
Adaptive Hutchinson variant that consumes H as a matrix-free HVP closure rather than a dense ArrayView2. Used by call sites where the penalized Hessian is implicit (operator-only) and forming it densely would blow the memory budget — e.g. the device-resident PCG path in gpu/bms_flex_row.rs or the large-scale BMS Schur operator.
fill_rademacher_host
Host-side reference: fill a column-major (p, K) Rademacher matrix. Used by tests to verify the GPU kernel produces the same bits.
rademacher_entry
Stateless Rademacher entry at probe index k (0-based), coordinate i (0-based), seed s. Returns +1.0 or -1.0.
should_bypass_cpu_with_gpu_adaptive
Composite gate predicate for the outer REML logdet-gradient bypass: when this returns true, the unified evaluator should replace its CPU stochastic-trace call with evidence_traces_adaptive.
should_use_gpu_hutchinson
True when the GPU Hutchinson path is eligible at the current shape and configuration. Caller still has to satisfy the CPU-side gate (prefers_stochastic_trace_estimation, matching kernel, plain-SPD logdet, projected penalty subspace inactive) — the parameters prefers_stochastic, kernel_matches_hinv, plain_spd_logdet, and projected_penalty_subspace_active carry those CPU-side gate booleans into the dispatch decision.
splitmix64_mix
SplitMix64 finalizer (Sebastiano Vigna, 2015). Thin wrapper over the canonical implementation in gam_linalg::utils::splitmix64_hash.