Skip to main content

gam_solve/gpu_kernels/
reml_trace.rs

1//! GPU Hutchinson stochastic trace estimator for the REML/LAML logdet
2//! gradient, per math team block 2 (sections 12–18 of the V100 design).
3//!
4//! Public entry point: [`evidence_derivatives_hutchinson_gpu`]. For each
5//! derivative Hessian `H_j` (`j = 1..D`) and a single penalized Hessian `H`
6//! held resident on device, returns the unbiased Hutchinson estimate of
7//!
8//! ```text
9//! t_j = tr(H^{-1} H_j)
10//! ```
11//!
12//! plus the sample standard error of each estimate, computed from `K`
13//! Rademacher probe vectors `z_k ∈ {±1}^p` whose entries are drawn from a
14//! **stateless SplitMix64 counter hash** (no cuRAND state). The math
15//! identity used on device is
16//!
17//! ```text
18//! z^T H^{-1} H_j z  =  z^T H_j w   where   H w = z
19//! ```
20//!
21//! so we factor `H` **once** with `cusolverDnDpotrf`, batch-solve `H W = Z`
22//! with **one** `cusolverDnDpotrs` of `nrhs = K`, and then evaluate the
23//! quadratic forms with a custom NVRTC reduction kernel. The REML logdet
24//! gradient is `g_j = (1/2) · mean_k(q_{j,k})`.
25//!
26//! Two assembly variants for `H_j` are supported:
27//!
28//! * **Dense** — caller passes `H_j` as a `p × p` device or host matrix.
29//!   GEMM forms `Y_j = H_j W`, then a custom reduction sums
30//!   `z_k^T y_{j,k}` per (j, k). Cost: `D` GEMMs of size `p × p × K`.
//! * **Weighted-Gram structural** — caller provides the design `X`
//!   (`n × p`), one row-weight vector `a_j` of length `n` per derivative
//!   (the diagonal weights `H_j` places on the rows of `X`), and an optional
//!   `p × p` direct penalty contribution `P_j`, so that
//!   `H_j = X^T diag(a_j) X + P_j`. The kernel forms `R_Z = X Z` and
//!   `R_W = X W` **once** via GEMM and then sums
//!   `sum_i a_j[i] · R_Z[i,k] · R_W[i,k]` (plus `z_k^T P_j w_k` when present)
//!   per (j, k) without ever materialising the `p × p` `H_j` matrix.
//!   Cost: 2 GEMMs of size `n × p × K` shared across all `D` derivatives.
39//!
40//! The structural path is the high-value route for large-scale models
41//! where `p` is hundreds and there are many derivatives.
42//!
43//! # Stateless probe RNG
44//!
45//! The probe entries are produced on device by a SplitMix64 finalizer over
46//! `(seed, probe_index k, coordinate i)`. This has three consequences:
47//!
48//! 1. No cuRAND state — the kernel is fully stateless, threads write into
49//!    `Z[i + k·p]` independently.
50//! 2. **Common random numbers**: the first `K1` probes of a run with
51//!    `K2 > K1` are bit-identical to a `K = K1` run with the same seed.
52//!    This is the property that lets the adaptive `K` schedule build on
53//!    earlier probes without re-running them, and lets CPU and GPU
54//!    implementations of Hutchinson compare estimator-by-estimator (the
55//!    same probes produce the same `q_{j,k}` to round-off).
56//! 3. Reproducibility — a probe at `(seed, k, i)` is the same call after
57//!    call regardless of how the grid was scheduled.
58//!
59//! # Gating
60//!
61//! The companion helper [`should_use_gpu_hutchinson`] mirrors the CPU
62//! gate (`prefers_stochastic_trace_estimation` + matching kernel +
63//! plain-SPD logdet path) and adds the GPU-specific minima from the math
64//! team's section 18:
65//!
66//! * `p ≥ 512`
67//! * `K ∈ [8, 128]`
68//! * Hessian and design held resident or about to be uploaded
69//! * The projected penalty-subspace trace is **inactive** (otherwise the
70//!   CPU path projects through the IFT kernel — that route is required
71//!   for marginal-slope ρ-saturated rows)
72
73use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
74
75use gam_gpu::gpu_error::GpuError;
76use gam_linalg::pcg::{DotReduction, pcg_core};
77
78// ────────────────────────────────────────────────────────────────────────
79// Public types
80// ────────────────────────────────────────────────────────────────────────
81
/// Seed for the stateless SplitMix64 Rademacher probe generator.
///
/// The seed is the only RNG "state": every probe entry is a pure function of
/// `(seed, probe index, coordinate)`.
#[derive(Debug, Clone, Copy)]
pub struct ProbeSeed(pub u64);

impl Default for ProbeSeed {
    /// Returns `ProbeSeed(0xCAFE_BABE)`, the same constant the CPU default
    /// (`StochasticTraceConfig::default()`) uses, so cross-implementation
    /// parity tests can share a single seed.
    fn default() -> Self {
        ProbeSeed(0x0000_0000_CAFE_BABE)
    }
}
93
94/// Description of one derivative-Hessian contribution `H_j`.
95///
96/// The estimator needs `H_j` only via the quadratic form `z^T H_j w`, so we
97/// describe `H_j` *structurally* rather than as a dense matrix. The dense
98/// case is recovered by the [`DerivativeHessian::Dense`] variant.
99#[derive(Clone, Debug)]
100pub enum DerivativeHessian<'a> {
101    /// `H_j` is a `p × p` symmetric matrix. The reducer forms `Y = H_j W`
102    /// via GEMM and then sums `z_k^T y_k`.
103    Dense(ArrayView2<'a, f64>),
104    /// `H_j = X^T diag(a_j) X + P_j`, where `a_j` is an `n`-vector of row
105    /// weights and `P_j` is an optional `p × p` direct penalty contribution
106    /// that is *added* to the structural part. The reducer evaluates
107    /// `z^T H_j w  =  sum_i a_j[i] · (X z)[i] · (X w)[i]  +  z^T P_j w`
108    /// without materialising the `p × p` `H_j`.
109    WeightedGram {
110        row_weights: ArrayView1<'a, f64>,
111        penalty_extra: Option<ArrayView2<'a, f64>>,
112    },
113}
114
115impl DerivativeHessian<'_> {
116    fn dim_p(&self, expected_p: usize, expected_n: usize) -> Result<(), GpuError> {
117        match self {
118            DerivativeHessian::Dense(matrix) => {
119                if matrix.nrows() != expected_p || matrix.ncols() != expected_p {
120                    gam_gpu::gpu_bail!(
121                        "reml_trace dense H_j: shape {:?} != ({expected_p}, {expected_p})",
122                        matrix.dim()
123                    );
124                }
125            }
126            DerivativeHessian::WeightedGram {
127                row_weights,
128                penalty_extra,
129            } => {
130                if row_weights.len() != expected_n {
131                    gam_gpu::gpu_bail!(
132                        "reml_trace structural H_j: row_weights.len()={} != n={expected_n}",
133                        row_weights.len()
134                    );
135                }
136                if let Some(p_extra) = penalty_extra
137                    && (p_extra.nrows() != expected_p || p_extra.ncols() != expected_p)
138                {
139                    gam_gpu::gpu_bail!(
140                        "reml_trace structural H_j penalty_extra: shape {:?} != ({expected_p}, {expected_p})",
141                        p_extra.dim()
142                    );
143                }
144            }
145        }
146        Ok(())
147    }
148}
149
150/// Inputs to [`evidence_derivatives_hutchinson_gpu`].
151#[derive(Clone, Debug)]
152pub struct RemlTraceHutchinsonInput<'a> {
153    /// Penalized Hessian `H` (`p × p`, SPD).
154    pub penalized_hessian: ArrayView2<'a, f64>,
155    /// Per-derivative descriptors `H_j`. `D = derivatives.len()`.
156    pub derivatives: Vec<DerivativeHessian<'a>>,
157    /// Design matrix `X` (`n × p`). Required iff any `H_j` is structural;
158    /// `None` is acceptable when **all** derivatives are dense.
159    pub design: Option<ArrayView2<'a, f64>>,
160    /// Number of probe vectors. Must be ≥ 2 (so a sample SE is defined).
161    pub probe_count: usize,
162    /// Stateless RNG seed.
163    pub seed: ProbeSeed,
164}
165
166/// Output of [`evidence_derivatives_hutchinson_gpu`].
167#[derive(Clone, Debug)]
168pub struct RemlTraceHutchinsonEvidence {
169    /// `log |H|` from the cached Cholesky factor (same value the exact GPU
170    /// path returns; reusing the factor amortises this).
171    pub logdet_hessian: f64,
172    /// REML logdet gradient `g_j = (1/2) · mean_k(q_{j,k})`, length `D`.
173    pub gradient_rho_logdet: Array1<f64>,
174    /// Standard error of the half-scaled gradient estimator
175    /// `(1/2)·mean_k(q_{j,k})`, length `D`. This is the Bessel-corrected
176    /// sample standard deviation across probes divided by `sqrt(K)`, with the
177    /// same `(1/2)` REML logdet scaling as [`Self::gradient_rho_logdet`].
178    pub gradient_rho_stderr: Array1<f64>,
179    /// `K` probes actually used (matches `input.probe_count`).
180    pub probe_count: usize,
181}
182
183// ────────────────────────────────────────────────────────────────────────
184// Gating
185// ────────────────────────────────────────────────────────────────────────
186
/// Smallest joint dimension `p` for which the GPU Hutchinson path is used.
pub const HUTCHINSON_GPU_MIN_P: usize = 512;
/// Inclusive lower bound on the probe count the GPU path accepts (math section 18).
pub const HUTCHINSON_GPU_MIN_K: usize = 8;
/// Inclusive upper bound on the probe count the GPU path accepts (math section 18).
pub const HUTCHINSON_GPU_MAX_K: usize = 128;

/// Decide whether the GPU Hutchinson path may run for this shape and
/// configuration.
///
/// Two gates must both be open:
///
/// * the **CPU-side gate**, passed in by the caller as booleans. The caller
///   must prefer stochastic trace estimation, the kernel must match `H⁻¹`,
///   the logdet must be the plain-SPD one, and the projected
///   penalty-subspace trace must be **inactive**;
/// * the **GPU shape gate**: `p ≥ HUTCHINSON_GPU_MIN_P` and
///   `HUTCHINSON_GPU_MIN_K ≤ probe_count ≤ HUTCHINSON_GPU_MAX_K`.
#[must_use]
pub fn should_use_gpu_hutchinson(
    p: usize,
    probe_count: usize,
    prefers_stochastic: bool,
    kernel_matches_hinv: bool,
    plain_spd_logdet: bool,
    projected_penalty_subspace_active: bool,
) -> bool {
    let cpu_gate_open = prefers_stochastic
        && kernel_matches_hinv
        && plain_spd_logdet
        && !projected_penalty_subspace_active;
    let large_enough = p >= HUTCHINSON_GPU_MIN_P;
    let probe_count_in_range =
        matches!(probe_count, HUTCHINSON_GPU_MIN_K..=HUTCHINSON_GPU_MAX_K);
    large_enough && probe_count_in_range && cpu_gate_open
}
216
217// ────────────────────────────────────────────────────────────────────────
218// Stateless SplitMix64 Rademacher RNG (host reference; mirrors the NVRTC
219// kernel byte-for-byte so CPU and GPU produce identical probes for the
220// same `(seed, k, i)`).
221// ────────────────────────────────────────────────────────────────────────
222
223/// SplitMix64 finalizer (Sebastiano Vigna, 2015). Thin wrapper over the
224/// canonical implementation in [`gam_linalg::utils::splitmix64_hash`].
225#[inline]
226pub fn splitmix64_mix(z: u64) -> u64 {
227    gam_linalg::utils::splitmix64_hash(z)
228}
229
230/// Stateless Rademacher entry at probe index `k` (0-based), coordinate
231/// `i` (0-based), seed `s`. Returns `+1.0` or `-1.0`.
232///
233/// The mix is `splitmix64(s ⊕ k·ζ ⊕ i·γ)` for two large odd constants
234/// `ζ`, `γ`; the sign bit (bit 63 of the hash) selects the sign. The two
235/// constants are *different* from the SplitMix increment so the row and
236/// column hashes don't collide on small `(k, i)`.
237#[inline]
238pub fn rademacher_entry(seed: u64, k: u64, i: u64) -> f64 {
239    const ZETA: u64 = 0xD1B5_4A32_D192_ED03;
240    const GAMMA: u64 = 0x8CB9_2BA7_2F9D_E81F;
241    let composite = seed ^ k.wrapping_mul(ZETA) ^ i.wrapping_mul(GAMMA);
242    let h = splitmix64_mix(composite);
243    if (h >> 63) == 0 { 1.0 } else { -1.0 }
244}
245
246/// Host-side reference: fill a column-major `(p, K)` Rademacher matrix.
247/// Used by tests to verify the GPU kernel produces the same bits.
248pub fn fill_rademacher_host(seed: ProbeSeed, p: usize, k: usize, out: &mut [f64]) {
249    assert_eq!(
250        out.len(),
251        p * k,
252        "fill_rademacher_host: out buffer length {} != p*K = {}*{}",
253        out.len(),
254        p,
255        k
256    );
257    for col in 0..k {
258        for row in 0..p {
259            out[col * p + row] = rademacher_entry(seed.0, col as u64, row as u64);
260        }
261    }
262}
263
264// ────────────────────────────────────────────────────────────────────────
265// CPU reference implementation of the Hutchinson estimator
266// ────────────────────────────────────────────────────────────────────────
267//
268// This path is what runs in CPU-only builds and is also what the V100
269// parity tests check the device implementation against. It uses the same
270// stateless SplitMix probes as the kernel.
271
272/// Run the Hutchinson estimator on CPU using the exact same probe bits
273/// the device kernel uses. Returns the same evidence struct.
274pub fn evidence_derivatives_hutchinson_cpu(
275    input: &RemlTraceHutchinsonInput<'_>,
276) -> Result<RemlTraceHutchinsonEvidence, String> {
277    validate_inputs(input)?;
278    let p = input.penalized_hessian.nrows();
279    let d = input.derivatives.len();
280    let k = input.probe_count;
281
282    // Cholesky factor of H (lower).
283    let h = input.penalized_hessian.to_owned();
284    let factor = cholesky_lower(&h)?;
285    let logdet_hessian = 2.0 * (0..p).map(|i| factor[[i, i]].ln()).sum::<f64>();
286
287    // Build Z (p, k) column-major in a flat vector.
288    let mut z = vec![0.0_f64; p * k];
289    fill_rademacher_host(input.seed, p, k, &mut z);
290
291    // Solve H W = Z column by column on CPU (matches what the device
292    // does in one batched potrs call). The K columns are independent — each
293    // `solve_cholesky` reads the shared (immutable) factor and writes only its
294    // own column of `w` — so they parallelize bit-for-bit (no reduction is
295    // reordered; each w-column is produced by exactly one task with identical
296    // arithmetic). The probes are embarrassingly parallel by construction; the
297    // CRN contract lives in the stateless SplitMix fill above, untouched.
298    use rayon::prelude::*;
299    let mut w = vec![0.0_f64; p * k];
300    w.par_chunks_mut(p)
301        .zip(z.par_chunks(p))
302        .for_each(|(w_col, z_col)| {
303            let solved = solve_cholesky(&factor, z_col);
304            w_col.copy_from_slice(&solved);
305        });
306
307    // Per-derivative quadratic forms. Each `q[j*k + col]` is an independent
308    // scalar function of probe column `col` only, so we parallelize over the
309    // probe columns. This is bit-identical to the serial fill: a given q entry
310    // is computed by one task with the same per-entry arithmetic, and the
311    // downstream `reduce_mean_stderr` indexes fixed (j, col) positions — no
312    // sum is reordered across threads.
313    let mut q = vec![0.0_f64; d * k]; // row-major (d, k): q[j*k + m]
314    for (j, derivative) in input.derivatives.iter().enumerate() {
315        let q_row = &mut q[j * k..(j + 1) * k];
316        match derivative {
317            DerivativeHessian::Dense(matrix) => {
318                q_row
319                    .par_iter_mut()
320                    .zip(z.par_chunks(p).zip(w.par_chunks(p)))
321                    .for_each(|(q_jk, (z_col, w_col))| {
322                        // y = H_j w
323                        let mut y = vec![0.0_f64; p];
324                        for r in 0..p {
325                            let mut acc = 0.0_f64;
326                            for c in 0..p {
327                                acc += matrix[[r, c]] * w_col[c];
328                            }
329                            y[r] = acc;
330                        }
331                        let mut zy = 0.0_f64;
332                        for i in 0..p {
333                            zy += z_col[i] * y[i];
334                        }
335                        *q_jk = zy;
336                    });
337            }
338            DerivativeHessian::WeightedGram {
339                row_weights,
340                penalty_extra,
341            } => {
342                let design = input.design.as_ref().expect("design validated");
343                let n = design.nrows();
344                q_row
345                    .par_iter_mut()
346                    .zip(z.par_chunks(p).zip(w.par_chunks(p)))
347                    .for_each(|(q_jk, (z_col, w_col))| {
348                        // r_z = X z (length n), r_w = X w (length n)
349                        let mut acc = 0.0_f64;
350                        for row in 0..n {
351                            let mut rz = 0.0_f64;
352                            let mut rw = 0.0_f64;
353                            for col_idx in 0..p {
354                                rz += design[[row, col_idx]] * z_col[col_idx];
355                                rw += design[[row, col_idx]] * w_col[col_idx];
356                            }
357                            acc += row_weights[row] * rz * rw;
358                        }
359                        if let Some(pen) = penalty_extra {
360                            for r in 0..p {
361                                let mut row_acc = 0.0_f64;
362                                for c in 0..p {
363                                    row_acc += pen[[r, c]] * w_col[c];
364                                }
365                                acc += z_col[r] * row_acc;
366                            }
367                        }
368                        *q_jk = acc;
369                    });
370            }
371        }
372    }
373
374    let (means, stderrs) = reduce_mean_stderr(&q, d, k);
375    let mut gradient_rho_logdet = Array1::<f64>::zeros(d);
376    let mut gradient_rho_stderr = Array1::<f64>::zeros(d);
377    for j in 0..d {
378        gradient_rho_logdet[j] = 0.5 * means[j];
379        gradient_rho_stderr[j] = 0.5 * stderrs[j];
380    }
381
382    Ok(RemlTraceHutchinsonEvidence {
383        logdet_hessian,
384        gradient_rho_logdet,
385        gradient_rho_stderr,
386        probe_count: k,
387    })
388}
389
390// ────────────────────────────────────────────────────────────────────────
391// Public dispatch entry point
392// ────────────────────────────────────────────────────────────────────────
393
394/// Compute `log |H|` and the Hutchinson estimate of `(1/2) tr(H^{-1} H_j)`
395/// for every derivative. Dispatches to the device-resident path when the
396/// CUDA runtime is up and probes the GPU successfully; otherwise runs the
397/// CPU reference. Either way the probe bits are identical (stateless
398/// SplitMix), so callers see the same estimator value to round-off.
399pub fn evidence_derivatives_hutchinson_gpu(
400    input: RemlTraceHutchinsonInput<'_>,
401) -> Result<RemlTraceHutchinsonEvidence, String> {
402    validate_inputs(&input)?;
403
404    #[cfg(target_os = "linux")]
405    {
406        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
407            match linux_cuda::evidence_derivatives(&input) {
408                Ok(evidence) => return Ok(evidence),
409                Err(GpuError::NoDeviceKernel { .. }) => {
410                    // No device kernel for this path on this build: fall
411                    // through to the CPU reference.
412                }
413                Err(other) => return Err(String::from(other)),
414            }
415        }
416    }
417
418    evidence_derivatives_hutchinson_cpu(&input)
419}
420
421// ────────────────────────────────────────────────────────────────────────
422// Adaptive K (Block 2.5)
423// ────────────────────────────────────────────────────────────────────────
424
/// Default relative-error target `ε` for the adaptive-K stopping rule;
/// mirrors `StochasticTraceConfig::default().relative_tol`.
pub const HUTCHINSON_ADAPTIVE_REL_TOL: f64 = 1.0e-2;
/// Default near-zero-trace floor `τ` used in the stopping-rule denominator
/// `max(|t_j|, τ)`; mirrors `StochasticTraceConfig::default().tau_rel`.
pub const HUTCHINSON_ADAPTIVE_TAU_REL: f64 = 1.0e-8;
431
432/// Adaptive-K Hutchinson trace schedule with common random numbers (CRN).
433///
434/// Repeatedly invokes [`evidence_derivatives_hutchinson_gpu`] with probe
435/// counts `K = 16, 32, 64, 128`, stopping at the first `K` that satisfies
436/// the per-coordinate relative-SE criterion
437///
438/// ```text
439/// max_j  SE(t_j) / max(|t_j|, τ)  ≤  ε
440/// ```
441///
442/// where `SE(t_j)` is the standard error of the raw quadratic-form running
443/// mean (without the `(1/2)` REML logdet scaling) and `t_j` is the running mean. Because the SplitMix probe RNG is
444/// stateless (`(seed, k_index, i) → ±1`), the first `K_prev` probes of a
445/// `K = 2·K_prev` re-run are bit-identical to the previous batch, so each
446/// step extends the prior estimate rather than starting fresh in
447/// expectation. The implementation re-runs from scratch at each `K` for
448/// simplicity; CRN is preserved by the stateless RNG seed.
449///
450/// Returns the **raw traces** `t_j = tr(H⁻¹ H_j) = mean_k q_{j,k}`
451/// (length `D`), the `log|H|` from the cached Cholesky, and the final
452/// probe count `K` actually used. The raw traces (not the `(1/2)` REML
453/// logdet gradient) are what the outer evaluator wants — it applies the
454/// logdet-gradient half-factor itself.
455pub struct AdaptiveTraceEvidence {
456    pub logdet_hessian: f64,
457    pub traces: Array1<f64>,
458    /// Standard error of the raw trace estimator `mean_k(q_{j,k})`, i.e. the
459    /// Bessel-corrected sample standard deviation divided by `sqrt(K)`.
460    pub stderrs: Array1<f64>,
461    pub probe_count: usize,
462    pub converged: bool,
463}
464
465pub fn evidence_traces_adaptive<'a>(
466    penalized_hessian: ArrayView2<'a, f64>,
467    derivatives: Vec<DerivativeHessian<'a>>,
468    design: Option<ArrayView2<'a, f64>>,
469    seed: ProbeSeed,
470    rel_tol: f64,
471    tau_rel: f64,
472) -> Result<AdaptiveTraceEvidence, String> {
473    // Adaptive schedule per math team block 2 §16: K = 16, 32, 64, 128.
474    const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
475
476    let d = derivatives.len();
477    if d == 0 {
478        return Err("evidence_traces_adaptive: derivatives is empty".to_string());
479    }
480    if !(rel_tol > 0.0) {
481        return Err(format!(
482            "evidence_traces_adaptive: rel_tol must be > 0 (got {rel_tol})"
483        ));
484    }
485    if !(tau_rel > 0.0) {
486        return Err(format!(
487            "evidence_traces_adaptive: tau_rel must be > 0 (got {tau_rel})"
488        ));
489    }
490
491    let mut last_logdet = 0.0_f64;
492    let mut last_traces = Array1::<f64>::zeros(d);
493    let mut last_stderrs = Array1::<f64>::zeros(d);
494    let mut last_k = 0_usize;
495    let mut converged = false;
496
497    for &k in &SCHEDULE {
498        let input = RemlTraceHutchinsonInput {
499            penalized_hessian,
500            derivatives: derivatives.clone(),
501            design,
502            probe_count: k,
503            seed,
504        };
505        let evidence = evidence_derivatives_hutchinson_gpu(input)?;
506        last_logdet = evidence.logdet_hessian;
507        last_k = k;
508
509        // The dispatch entry returns the **(1/2)·mean** REML logdet
510        // gradient and **(1/2)·SE**. Undo the half to recover the raw
511        // `t_j = mean_k q_{j,k}` and the standard error of the raw mean.
512        for j in 0..d {
513            last_traces[j] = 2.0 * evidence.gradient_rho_logdet[j];
514            last_stderrs[j] = 2.0 * evidence.gradient_rho_stderr[j];
515        }
516
517        // Stopping rule (math block 2 §16):
518        //   max_j  SE(t_j) / max(|t_j|, τ)  ≤  ε
519        // where `last_stderrs[j]` is already the standard error of the
520        // running mean.
521        let mut worst = 0.0_f64;
522        for j in 0..d {
523            let denom = last_traces[j].abs().max(tau_rel);
524            let r = last_stderrs[j] / denom;
525            if r > worst {
526                worst = r;
527            }
528        }
529        if worst <= rel_tol {
530            converged = true;
531            break;
532        }
533    }
534
535    Ok(AdaptiveTraceEvidence {
536        logdet_hessian: last_logdet,
537        traces: last_traces,
538        stderrs: last_stderrs,
539        probe_count: last_k,
540        converged,
541    })
542}
543
544// ────────────────────────────────────────────────────────────────────────
545// Block 2.7: batched-PCG HVP variant of adaptive Hutchinson
546// ────────────────────────────────────────────────────────────────────────
547
/// Relative-residual target for each per-probe CG solve `H w = z`.
///
/// The outer adaptive-K loop only drives the Hutchinson variance down to
/// about 1%. A residual of `1e-6` keeps CG round-off far below that
/// stochastic SE without paying for convergence to machine precision.
pub const PCG_HVP_REL_TOL: f64 = 1.0e-6;

/// Iteration cap for each per-probe CG solve, after which the partial solve
/// is accepted.
///
/// The cap stops a poorly conditioned `H` from making one REML step take
/// unbounded time. A few stale `w_k` only inflate the estimator's SE, which
/// the adaptive stopping rule then catches by extending the schedule.
pub const PCG_HVP_MAX_ITERS: usize = 200;
560
561/// Adaptive Hutchinson variant that consumes `H` as a matrix-free HVP
562/// closure rather than a dense `ArrayView2`. Used by call sites where the
563/// penalized Hessian is implicit (operator-only) and forming it densely
564/// would blow the memory budget — e.g. the device-resident PCG path in
565/// `gpu/bms_flex_row.rs` or the large-scale BMS Schur operator.
566///
567/// `hvp` must compute `out ← H · v` for an SPD `H`. The closure is called
568/// once per CG iteration per probe (so `K · iters_per_probe` times in
569/// total for each schedule step). It is responsible for any necessary
570/// pre-conditioning state, threading, or device residency — the routine
571/// itself is pure CPU.
572///
573/// `derivatives` are still passed as dense or `WeightedGram`; the
574/// adaptive trace `t_j = mean_k z_k^T H_j w_k` only needs `H_j` to be
575/// available as a matvec, and the dense / weighted-Gram variants of
576/// `DerivativeHessian::quadratic_form` already provide that.
577///
578/// CRN is preserved exactly as in [`evidence_traces_adaptive`]: the
579/// SplitMix probe RNG is stateless in `(seed, k_index, i)`, so the
580/// `K=16, 32, 64, 128` schedule extends the prior estimate rather than
581/// restarting it. Each schedule step re-runs all `K` solves; the
582/// implementation is intentionally simple, the asymptotic cost is
583/// dominated by the largest `K`.
584///
585/// Returns the same [`AdaptiveTraceEvidence`] shape as the dense path,
586/// with one exception: `logdet_hessian` is **NaN** because no Cholesky
587/// is performed. Callers needing both `tr(H⁻¹ H_j)` and `log|H|` from
588/// the matrix-free path should obtain `log|H|` separately (e.g. via
589/// stochastic Lanczos or by routing through the dense path when `H`
590/// fits in memory).
591pub fn evidence_traces_adaptive_hvp<F>(
592    p: usize,
593    mut hvp: F,
594    derivatives: Vec<DerivativeHessian<'_>>,
595    design: Option<ArrayView2<'_, f64>>,
596    seed: ProbeSeed,
597    rel_tol: f64,
598    tau_rel: f64,
599) -> Result<AdaptiveTraceEvidence, String>
600where
601    F: FnMut(&[f64], &mut [f64]),
602{
603    const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
604
605    let d = derivatives.len();
606    if d == 0 {
607        return Err("evidence_traces_adaptive_hvp: derivatives is empty".to_string());
608    }
609    if p == 0 {
610        return Err("evidence_traces_adaptive_hvp: p must be > 0".to_string());
611    }
612    if !(rel_tol > 0.0) {
613        return Err(format!(
614            "evidence_traces_adaptive_hvp: rel_tol must be > 0 (got {rel_tol})"
615        ));
616    }
617    if !(tau_rel > 0.0) {
618        return Err(format!(
619            "evidence_traces_adaptive_hvp: tau_rel must be > 0 (got {tau_rel})"
620        ));
621    }
622
623    let mut last_traces = Array1::<f64>::zeros(d);
624    let mut last_stderrs = Array1::<f64>::zeros(d);
625    let mut last_k = 0_usize;
626    let mut converged = false;
627
628    let mut z = vec![0.0_f64; p];
629    let mut w = vec![0.0_f64; p];
630
631    // Per-derivative Welford accumulators (running mean and sum-of-squared
632    // deviations M2) for a numerically stable online mean / sample variance.
633    // The naive one-pass form E[q²] − E[q]² catastrophically cancels when the
634    // per-probe q cluster far from zero with small spread — exactly the
635    // near-converged regime the stopping rule cares about — so we track M2
636    // directly to match the two-pass `reduce_mean_stderr` without that loss.
637    let mut q_means = vec![0.0_f64; d];
638    let mut q_m2 = vec![0.0_f64; d];
639
640    for &k_target in &SCHEDULE {
641        // Re-run from scratch at each schedule step — CRN guarantees the
642        // first min(K_prev, K_target) probes are bit-identical, so the
643        // estimator is monotone in expectation across schedule extensions.
644        for s in q_means.iter_mut() {
645            *s = 0.0;
646        }
647        for s in q_m2.iter_mut() {
648            *s = 0.0;
649        }
650
651        for k_idx in 0..k_target {
652            // Fill z_k from the stateless SplitMix RNG.
653            for i in 0..p {
654                z[i] = rademacher_entry(seed.0, k_idx as u64, i as u64);
655            }
656            // Solve H w = z by unpreconditioned CG.
657            cg_solve(&mut hvp, &z, &mut w, PCG_HVP_REL_TOL, PCG_HVP_MAX_ITERS);
658
659            // Reduce q_{j,k} = z^T H_j w for each derivative. Mirrors the
660            // dense reference in `evidence_derivatives_hutchinson_cpu`.
661            for j in 0..d {
662                let q = match &derivatives[j] {
663                    DerivativeHessian::Dense(matrix) => {
664                        let mut y = 0.0_f64;
665                        for r in 0..p {
666                            let mut hr_w = 0.0_f64;
667                            for c in 0..p {
668                                hr_w += matrix[[r, c]] * w[c];
669                            }
670                            y += z[r] * hr_w;
671                        }
672                        y
673                    }
674                    DerivativeHessian::WeightedGram {
675                        row_weights,
676                        penalty_extra,
677                    } => {
678                        let design_view = design.as_ref().ok_or_else(|| {
679                            "evidence_traces_adaptive_hvp: WeightedGram derivative requires \
680                             design matrix"
681                                .to_string()
682                        })?;
683                        let n = design_view.nrows();
684                        let mut acc = 0.0_f64;
685                        for row in 0..n {
686                            let mut rz = 0.0_f64;
687                            let mut rw = 0.0_f64;
688                            for ci in 0..p {
689                                rz += design_view[[row, ci]] * z[ci];
690                                rw += design_view[[row, ci]] * w[ci];
691                            }
692                            acc += row_weights[row] * rz * rw;
693                        }
694                        if let Some(pen) = penalty_extra {
695                            for r in 0..p {
696                                let mut row_acc = 0.0_f64;
697                                for c in 0..p {
698                                    row_acc += pen[[r, c]] * w[c];
699                                }
700                                acc += z[r] * row_acc;
701                            }
702                        }
703                        acc
704                    }
705                };
706                // Welford update with the 1-based probe count (k_idx + 1).
707                let count = (k_idx + 1) as f64;
708                let delta = q - q_means[j];
709                q_means[j] += delta / count;
710                let delta2 = q - q_means[j];
711                q_m2[j] += delta * delta2;
712            }
713        }
714
715        let n = k_target as f64;
716        let mut worst_ratio = 0.0_f64;
717        for j in 0..d {
718            let mean = q_means[j];
719            // Sample variance M2 / (K−1) — Bessel's correction, matching the
720            // two-pass `reduce_mean_stderr` exactly (no one-pass cancellation).
721            // For K = 1 there is no spread to estimate, so the variance is 0.
722            let var = if n > 1.0 { q_m2[j] / (n - 1.0) } else { 0.0 };
723            let se = var.sqrt() / n.sqrt();
724            last_traces[j] = mean;
725            last_stderrs[j] = se;
726            let denom = mean.abs().max(tau_rel);
727            let r = se / denom;
728            if r > worst_ratio {
729                worst_ratio = r;
730            }
731        }
732        last_k = k_target;
733        if worst_ratio <= rel_tol {
734            converged = true;
735            break;
736        }
737    }
738
739    Ok(AdaptiveTraceEvidence {
740        logdet_hessian: f64::NAN,
741        traces: last_traces,
742        stderrs: last_stderrs,
743        probe_count: last_k,
744        converged,
745    })
746}
747
748/// Unpreconditioned conjugate gradients for `H w = b` with `H` accessed
749/// only through `hvp(v, out) → out ← H v`. SPD `H` is required.
750/// Initial guess is `w = 0`; stops when `‖r‖ ≤ rel_tol · ‖b‖` or after
751/// `max_iters` iterations.
752///
753/// Thin wrapper over the shared [`pcg_core`] (`linalg::pcg`): unpreconditioned
754/// (all-ones Jacobi diagonal), no residual refresh (`refresh_period = 0`), and
755/// no diagnostics. On a breakdown (lost SPD near convergence, non-finite
756/// scalar) the core stops and leaves the last valid iterate in `w`, which is
757/// the historical "accept current w" behavior.
758///
759/// Reduction: [`DotReduction::Reordered`]. This is the stochastic Hutchinson
760/// trace probe, NOT the main solve. The per-probe CG residual (`rel_tol`
761/// ≈ 1e-6) sits orders of magnitude below the estimator's own sampling SE, and
762/// the adaptive-K stopping rule budgets against that SE — so reordering the
763/// inner-product accumulation (ILP/SIMD reduction) only perturbs bits already
764/// dominated by Monte-Carlo noise. The CRN reproducibility that matters here is
765/// in the SplitMix probe RNG (`rademacher_entry`), which is untouched; we do
766/// NOT need the cross-thread bit-identity that the main SPD solve contracts
767/// for, so we trade it for add-side ILP on the hot per-iteration folds.
768fn cg_solve<F>(hvp: &mut F, b: &[f64], w: &mut [f64], rel_tol: f64, max_iters: usize)
769where
770    F: FnMut(&[f64], &mut [f64]),
771{
772    let n = b.len();
773    assert!(w.len() == n);
774
775    let rhs = ArrayView1::from(b);
776    let precond = Array1::<f64>::ones(n);
777    let mut solution = ArrayViewMut1::from(w);
778
779    pcg_core(
780        |v: &Array1<f64>, out: &mut Array1<f64>| {
781            // The core hands contiguous vectors; `hvp` speaks raw slices.
782            let v_slice = v.as_slice().expect("contiguous CG direction view");
783            let out_slice = out.as_slice_mut().expect("contiguous CG matvec view");
784            hvp(v_slice, out_slice);
785        },
786        &rhs,
787        &precond.view(),
788        rel_tol,
789        max_iters,
790        0,
791        false,
792        DotReduction::Reordered,
793        &mut solution,
794    );
795}
796
797// ────────────────────────────────────────────────────────────────────────
798// Outer logdet-gradient dispatch gate (Block 2.5)
799// ────────────────────────────────────────────────────────────────────────
800
801/// Composite gate predicate for the outer REML logdet-gradient bypass:
802/// when this returns `true`, the unified evaluator should replace its
803/// CPU stochastic-trace call with [`evidence_traces_adaptive`].
804///
805/// All five conditions must hold simultaneously:
806/// * `p ≥ 512` and `K_initial..=K_max` is `[16, 128]`
807/// * `H` is resident as a dense SPD operator (caller passes
808///   `dense_spd_h_resident = true` when `hop.as_exact_dense_spectral()`
809///   is `Some` AND the Cholesky succeeds — the latter is checked
810///   indirectly by `plain_spd_logdet`).
811/// * `plain_spd_logdet`: the operator's logdet kernel is `H⁻¹` exactly
812///   (i.e. `hop.logdet_traces_match_hinv_kernel() && hop.is_dense()`),
813///   so smooth-spectral and SCOP-warped paths are excluded.
814/// * `prefers_stochastic`: `hop.prefers_stochastic_trace_estimation()`.
815/// * `!projected_penalty_subspace_active`: the rank-deficient LAML
816///   projected kernel `U_S H_proj⁻¹ U_Sᵀ` is **not** installed.
817#[must_use]
818pub fn should_bypass_cpu_with_gpu_adaptive(
819    p: usize,
820    dense_spd_h_resident: bool,
821    plain_spd_logdet: bool,
822    prefers_stochastic: bool,
823    projected_penalty_subspace_active: bool,
824) -> bool {
825    p >= HUTCHINSON_GPU_MIN_P
826        && dense_spd_h_resident
827        && plain_spd_logdet
828        && prefers_stochastic
829        && !projected_penalty_subspace_active
830}
831
832// ────────────────────────────────────────────────────────────────────────
833// Linux/CUDA implementation
834// ────────────────────────────────────────────────────────────────────────
835
836#[cfg(target_os = "linux")]
837mod linux_cuda {
838    use super::{
839        DerivativeHessian, ProbeSeed, RemlTraceHutchinsonEvidence, RemlTraceHutchinsonInput,
840        reduce_mean_stderr,
841    };
842    use cudarc::cublas::sys::cublasOperation_t;
843    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
844    use cudarc::cusolver::DnHandle;
845    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
846    use gam_gpu::driver::to_col_major;
847    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
848    use gam_gpu::solver::{
849        cholesky_logdet_from_col_major, context_and_stream, pinned_htod, potrf_in_place,
850        potrs_in_place,
851    };
852    use std::sync::Arc;
853
854    /// NVRTC source for the three custom kernels used by this path. All
855    /// arithmetic is in `double` and the layouts are column-major to match
856    /// cuBLAS/cuSOLVER conventions.
857    ///
858    /// * `fill_rademacher_splitmix(seed, p, K, Z)` — stateless ±1 fill.
859    /// * `reduce_q_dense(p, K, D, Z, Y_stack, Q)` — `Q[j,k] = z_k^T Y_j[:,k]`
860    ///   with `Y_j[:,k] = (H_j W)[:,k]`. `Y_stack` is column-major shape
861    ///   `(p, K·D)` with derivative `j` occupying columns `[j·K, (j+1)·K)`.
862    /// * `reduce_q_weighted_gram(n, K, D, RZ_stride, RZ, RW, A_stack, Q)`
863    ///   — `Q[j,k] = sum_i A[i,j] · RZ[i,k] · RW[i,k]`. Used by the
864    ///   structural path. `A_stack` is column-major `(n, D)`.
865    ///
866    /// The reductions use a per-block warp-shuffle pattern with one block
867    /// per `(j, k)` output cell and `THREADS_PER_BLOCK` threads per block.
868    pub(super) const PTX_SOURCE: &str = r#"
869extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
870    z += 0x9E3779B97F4A7C15ULL;
871    unsigned long long x = z;
872    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
873    x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
874    return x ^ (x >> 31);
875}
876
877extern "C" __global__ void fill_rademacher_splitmix(
878    unsigned long long seed,
879    unsigned int p,
880    unsigned int K,
881    double* __restrict__ Z)
882{
883    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
884    unsigned int k = blockIdx.y;
885    if (i >= p || k >= K) return;
886    const unsigned long long ZETA  = 0xD1B54A32D192ED03ULL;
887    const unsigned long long GAMMA = 0x8CB92BA72F9DE81FULL;
888    unsigned long long composite =
889        seed
890        ^ (((unsigned long long)k) * ZETA)
891        ^ (((unsigned long long)i) * GAMMA);
892    unsigned long long h = splitmix64_mix(composite);
893    double v = (h >> 63) == 0 ? 1.0 : -1.0;
894    Z[(size_t)k * (size_t)p + (size_t)i] = v;
895}
896
897extern "C" __device__ double block_reduce_sum(double v) {
898    __shared__ double smem[32];
899    int lane = threadIdx.x & 31;
900    int wid  = threadIdx.x >> 5;
901    for (int off = 16; off > 0; off >>= 1) {
902        v += __shfl_down_sync(0xffffffff, v, off);
903    }
904    if (lane == 0) smem[wid] = v;
905    __syncthreads();
906    double total = 0.0;
907    int n_warps = (blockDim.x + 31) >> 5;
908    if (threadIdx.x < (unsigned)n_warps) total = smem[threadIdx.x];
909    if (wid == 0) {
910        for (int off = 16; off > 0; off >>= 1) {
911            total += __shfl_down_sync(0xffffffff, total, off);
912        }
913    }
914    return total;
915}
916
917extern "C" __global__ void reduce_q_dense(
918    unsigned int p,
919    unsigned int K,
920    unsigned int D,
921    const double* __restrict__ Z,
922    const double* __restrict__ Y_stack,
923    double* __restrict__ Q)
924{
925    unsigned int k = blockIdx.x;
926    unsigned int j = blockIdx.y;
927    if (k >= K || j >= D) return;
928    const double* z_col = Z + (size_t)k * (size_t)p;
929    const double* y_col = Y_stack + ((size_t)j * (size_t)K + (size_t)k) * (size_t)p;
930    double partial = 0.0;
931    for (unsigned int i = threadIdx.x; i < p; i += blockDim.x) {
932        partial += z_col[i] * y_col[i];
933    }
934    double total = block_reduce_sum(partial);
935    if (threadIdx.x == 0) {
936        Q[(size_t)j * (size_t)K + (size_t)k] = total;
937    }
938}
939
940extern "C" __global__ void reduce_q_weighted_gram(
941    unsigned int n,
942    unsigned int K,
943    unsigned int D,
944    const double* __restrict__ RZ,
945    const double* __restrict__ RW,
946    const double* __restrict__ A_stack,
947    double* __restrict__ Q)
948{
949    unsigned int k = blockIdx.x;
950    unsigned int j = blockIdx.y;
951    if (k >= K || j >= D) return;
952    const double* rz_col = RZ + (size_t)k * (size_t)n;
953    const double* rw_col = RW + (size_t)k * (size_t)n;
954    const double* a_col  = A_stack + (size_t)j * (size_t)n;
955    double partial = 0.0;
956    for (unsigned int i = threadIdx.x; i < n; i += blockDim.x) {
957        partial += a_col[i] * rz_col[i] * rw_col[i];
958    }
959    double total = block_reduce_sum(partial);
960    if (threadIdx.x == 0) {
961        Q[(size_t)j * (size_t)K + (size_t)k] = total;
962    }
963}
964"#;
965
966    const THREADS_PER_BLOCK: u32 = 256;
967
968    fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
969        static CACHE: gam_gpu::device_cache::PtxModuleCache =
970            gam_gpu::device_cache::PtxModuleCache::new();
971        CACHE.get_or_compile(ctx, "reml_trace", PTX_SOURCE)
972    }
973
974    pub(super) fn evidence_derivatives(
975        input: &RemlTraceHutchinsonInput<'_>,
976    ) -> Result<RemlTraceHutchinsonEvidence, GpuError> {
977        let p = input.penalized_hessian.nrows();
978        let d = input.derivatives.len();
979        let k = input.probe_count;
980        let (ctx, stream) =
981            context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
982        let solver = DnHandle::new(stream.clone()).gpu_ctx("reml_trace cusolver init")?;
983        let blas = CudaBlas::new(stream.clone()).gpu_ctx("reml_trace cublas init")?;
984        let compiled = module(&ctx)?;
985        let module_handle: &Arc<CudaModule> = compiled;
986
987        // ── 1. Upload H, factor once.
988        let h_col = to_col_major(&input.penalized_hessian);
989        let mut h_dev =
990            pinned_htod(&stream, &h_col).map_err(|reason| GpuError::DriverCallFailed { reason })?;
991        potrf_in_place(&solver, &stream, p, &mut h_dev)
992            .map_err(|reason| GpuError::DriverCallFailed { reason })?;
993        let factor_col = stream
994            .clone_dtoh(&h_dev)
995            .gpu_ctx("reml_trace download factor")?;
996        let logdet_hessian = cholesky_logdet_from_col_major(&factor_col, p);
997
998        // ── 2. Allocate Z (p, K) and fill with Rademacher entries on device.
999        let total_z = p
1000            .checked_mul(k)
1001            .ok_or_else(|| gam_gpu::gpu_err!("reml_trace Z size overflow: p={p}, K={k}"))?;
1002        let mut z_dev = stream
1003            .alloc_zeros::<f64>(total_z)
1004            .gpu_ctx("reml_trace alloc Z")?;
1005        launch_fill_rademacher(&stream, module_handle, input.seed, p, k, &mut z_dev)?;
1006
1007        // ── 3. Solve H W = Z in a single batched potrs call (nrhs = K).
1008        //     Copy Z into a fresh buffer first; potrs is in-place.
1009        let mut w_dev = stream
1010            .alloc_zeros::<f64>(total_z)
1011            .gpu_ctx("reml_trace alloc W")?;
1012        copy_device_slice(&stream, &z_dev, &mut w_dev)?;
1013        potrs_in_place(&solver, &stream, p, k, &h_dev, &mut w_dev)
1014            .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1015
1016        // ── 4. Partition derivatives by kind.
1017        let mut dense_indices: Vec<usize> = Vec::new();
1018        let mut gram_indices: Vec<usize> = Vec::new();
1019        for (j, deriv) in input.derivatives.iter().enumerate() {
1020            match deriv {
1021                DerivativeHessian::Dense(_) => dense_indices.push(j),
1022                DerivativeHessian::WeightedGram { .. } => gram_indices.push(j),
1023            }
1024        }
1025
1026        let mut q_host = vec![0.0_f64; d * k];
1027
1028        // ── 5a. Dense path: for each dense H_j run a p×p × p×K GEMM and
1029        //       reduce. We loop over j rather than stacking the H_j's
1030        //       (would explode memory at large-scale-p), but the GEMMs share
1031        //       the resident W buffer.
1032        if !dense_indices.is_empty() {
1033            for &j in &dense_indices {
1034                let DerivativeHessian::Dense(matrix) = &input.derivatives[j] else {
1035                    // SAFETY: dense_indices was populated in the partition loop above
1036                    // with exactly the indices whose variant is DerivativeHessian::Dense.
1037                    // input.derivatives is immutably borrowed for the whole function so
1038                    // the slot at index j cannot have been rewritten between partition and
1039                    // this read; reaching this branch can only mean a future refactor split
1040                    // the partition from its consumer. The panic names the offending index.
1041                    panic!(
1042                        "reml_trace dense path: derivative index {j} is in dense_indices but \
1043                         input.derivatives[{j}] is not DerivativeHessian::Dense — \
1044                         dense_indices partition invariant violated"
1045                    );
1046                };
1047                let hj_col = to_col_major(matrix);
1048                let hj_dev = pinned_htod(&stream, &hj_col)
1049                    .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1050                let mut y_dev = stream
1051                    .alloc_zeros::<f64>(total_z)
1052                    .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Y_j (j={j}): {err}"))?;
1053                gemm_nn(
1054                    &blas,
1055                    GemmShape {
1056                        m: p,
1057                        n: k,
1058                        k_inner: p,
1059                        lda: p,
1060                        ldb: p,
1061                        ldc: p,
1062                    },
1063                    &hj_dev,
1064                    &w_dev,
1065                    &mut y_dev,
1066                )?;
1067                let mut q_j_dev = stream
1068                    .alloc_zeros::<f64>(k)
1069                    .gpu_ctx_with(|err| format!("reml_trace alloc Q_j (j={j}): {err}"))?;
1070                launch_reduce_q_dense(
1071                    &stream,
1072                    module_handle,
1073                    p,
1074                    k,
1075                    1,
1076                    &z_dev,
1077                    &y_dev,
1078                    &mut q_j_dev,
1079                )?;
1080                let q_host_j = stream
1081                    .clone_dtoh(&q_j_dev)
1082                    .gpu_ctx_with(|err| format!("reml_trace download Q_j (j={j}): {err}"))?;
1083                q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_j);
1084            }
1085        }
1086
1087        // ── 5b. Structural path: form R_Z = X Z and R_W = X W **once**,
1088        //       then run reduce_q_weighted_gram for each derivative.
1089        if !gram_indices.is_empty() {
1090            let design = input
1091                .design
1092                .as_ref()
1093                .ok_or_else(|| GpuError::DriverCallFailed {
1094                    reason: "reml_trace: structural derivative present but design=None".to_string(),
1095                })?;
1096            let n = design.nrows();
1097            let design_col = to_col_major(design);
1098            let x_dev = pinned_htod(&stream, &design_col)
1099                .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1100            let mut rz_dev = stream
1101                .alloc_zeros::<f64>(
1102                    n.checked_mul(k)
1103                        .ok_or_else(|| gam_gpu::gpu_err!("reml_trace RZ overflow: n={n}, K={k}"))?,
1104                )
1105                .gpu_ctx("reml_trace alloc RZ")?;
1106            let mut rw_dev = stream
1107                .alloc_zeros::<f64>(n * k)
1108                .gpu_ctx("reml_trace alloc RW")?;
1109            // R_Z = X Z   (n × p) · (p × K) -> (n × K)
1110            gemm_nn(
1111                &blas,
1112                GemmShape {
1113                    m: n,
1114                    n: k,
1115                    k_inner: p,
1116                    lda: n,
1117                    ldb: p,
1118                    ldc: n,
1119                },
1120                &x_dev,
1121                &z_dev,
1122                &mut rz_dev,
1123            )?;
1124            // R_W = X W
1125            gemm_nn(
1126                &blas,
1127                GemmShape {
1128                    m: n,
1129                    n: k,
1130                    k_inner: p,
1131                    lda: n,
1132                    ldb: p,
1133                    ldc: n,
1134                },
1135                &x_dev,
1136                &w_dev,
1137                &mut rw_dev,
1138            )?;
1139
1140            // Stack the row-weight vectors into A_stack column-major (n × D_gram).
1141            let d_gram = gram_indices.len();
1142            let mut a_stack = Vec::<f64>::with_capacity(n * d_gram);
1143            for &j in &gram_indices {
1144                let DerivativeHessian::WeightedGram { row_weights, .. } = &input.derivatives[j]
1145                else {
1146                    // SAFETY: gram_indices was populated in the partition loop above with
1147                    // exactly the indices whose variant is DerivativeHessian::WeightedGram.
1148                    // input.derivatives is immutably borrowed for the whole function so the
1149                    // slot at j cannot have been rewritten between partition and read; a
1150                    // failure here is a future-refactor bug, not a runtime input issue.
1151                    panic!(
1152                        "reml_trace structural path: derivative index {j} is in gram_indices \
1153                         but input.derivatives[{j}] is not DerivativeHessian::WeightedGram — \
1154                         gram_indices partition invariant violated"
1155                    );
1156                };
1157                let slice = row_weights.as_slice().ok_or_else(|| {
1158                    gam_gpu::gpu_err!("reml_trace structural H_j={j} row_weights not contiguous")
1159                })?;
1160                a_stack.extend_from_slice(slice);
1161            }
1162            let a_dev = pinned_htod(&stream, &a_stack)
1163                .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1164            let mut q_dev = stream
1165                .alloc_zeros::<f64>(d_gram * k)
1166                .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Q_gram: {err}"))?;
1167            launch_reduce_q_weighted_gram(
1168                &stream,
1169                module_handle,
1170                n,
1171                k,
1172                d_gram,
1173                &rz_dev,
1174                &rw_dev,
1175                &a_dev,
1176                &mut q_dev,
1177            )?;
1178            let q_host_gram = stream
1179                .clone_dtoh(&q_dev)
1180                .gpu_ctx("reml_trace download Q_gram")?;
1181            for (slot, &j) in gram_indices.iter().enumerate() {
1182                q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_gram[slot * k..(slot + 1) * k]);
1183            }
1184            // penalty_extra contributions (uncommon, dense p×p) — handled on
1185            // host to keep the kernel surface small; total cost p² · K per
1186            // derivative that has one.
1187            for &j in &gram_indices {
1188                let DerivativeHessian::WeightedGram { penalty_extra, .. } = &input.derivatives[j]
1189                else {
1190                    // SAFETY: gram_indices was populated by the partition loop above with
1191                    // exactly the WeightedGram-variant indices; the same indices are
1192                    // re-walked here to pick up the optional penalty_extra field.
1193                    // input.derivatives has been immutably borrowed since partitioning, so
1194                    // the variant at index j cannot have changed. A let-else failure here
1195                    // would mean a future refactor split partition from consumer loops.
1196                    panic!(
1197                        "reml_trace structural penalty_extra: derivative index {j} is in \
1198                         gram_indices but input.derivatives[{j}] is not \
1199                         DerivativeHessian::WeightedGram — gram_indices partition invariant \
1200                         violated"
1201                    );
1202                };
1203                if let Some(pen) = penalty_extra {
1204                    let z_host = stream
1205                        .clone_dtoh(&z_dev)
1206                        .gpu_ctx("reml_trace download Z for penalty_extra")?;
1207                    let w_host = stream
1208                        .clone_dtoh(&w_dev)
1209                        .gpu_ctx("reml_trace download W for penalty_extra")?;
1210                    for col in 0..k {
1211                        let z_col = &z_host[col * p..(col + 1) * p];
1212                        let w_col = &w_host[col * p..(col + 1) * p];
1213                        let mut acc = 0.0_f64;
1214                        for r in 0..p {
1215                            let mut row_acc = 0.0_f64;
1216                            for c in 0..p {
1217                                row_acc += pen[[r, c]] * w_col[c];
1218                            }
1219                            acc += z_col[r] * row_acc;
1220                        }
1221                        q_host[j * k + col] += acc;
1222                    }
1223                }
1224            }
1225        }
1226
1227        let (means, stderrs) = reduce_mean_stderr(&q_host, d, k);
1228        let mut gradient_rho_logdet = ndarray::Array1::<f64>::zeros(d);
1229        let mut gradient_rho_stderr = ndarray::Array1::<f64>::zeros(d);
1230        for j in 0..d {
1231            gradient_rho_logdet[j] = 0.5 * means[j];
1232            gradient_rho_stderr[j] = 0.5 * stderrs[j];
1233        }
1234
1235        Ok(RemlTraceHutchinsonEvidence {
1236            logdet_hessian,
1237            gradient_rho_logdet,
1238            gradient_rho_stderr,
1239            probe_count: k,
1240        })
1241    }
1242
1243    // ───── kernel launch wrappers ────────────────────────────────────────
1244
1245    fn launch_fill_rademacher(
1246        stream: &Arc<CudaStream>,
1247        module: &Arc<CudaModule>,
1248        seed: ProbeSeed,
1249        p: usize,
1250        k: usize,
1251        z: &mut cudarc::driver::CudaSlice<f64>,
1252    ) -> Result<(), GpuError> {
1253        let func = module
1254            .load_function("fill_rademacher_splitmix")
1255            .gpu_ctx("reml_trace load fill_rademacher")?;
1256        let grid_x = ((p as u32) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1257        let cfg = LaunchConfig {
1258            grid_dim: (grid_x, k as u32, 1),
1259            block_dim: (THREADS_PER_BLOCK, 1, 1),
1260            shared_mem_bytes: 0,
1261        };
1262        let seed_arg: u64 = seed.0;
1263        let p_arg: u32 = p as u32;
1264        let k_arg: u32 = k as u32;
1265        // SAFETY: kernel signature matches arg types; Z is a live device
1266        // buffer sized p*k.
1267        unsafe {
1268            stream
1269                .launch_builder(&func)
1270                .arg(&seed_arg)
1271                .arg(&p_arg)
1272                .arg(&k_arg)
1273                .arg(z)
1274                .launch(cfg)
1275        }
1276        .map(|_| ())
1277        .gpu_ctx("reml_trace launch fill_rademacher")
1278    }
1279
1280    fn launch_reduce_q_dense(
1281        stream: &Arc<CudaStream>,
1282        module: &Arc<CudaModule>,
1283        p: usize,
1284        k: usize,
1285        d: usize,
1286        z: &cudarc::driver::CudaSlice<f64>,
1287        y_stack: &cudarc::driver::CudaSlice<f64>,
1288        q: &mut cudarc::driver::CudaSlice<f64>,
1289    ) -> Result<(), GpuError> {
1290        let func = module
1291            .load_function("reduce_q_dense")
1292            .gpu_ctx("reml_trace load reduce_q_dense")?;
1293        let cfg = LaunchConfig {
1294            grid_dim: (k as u32, d as u32, 1),
1295            block_dim: (THREADS_PER_BLOCK, 1, 1),
1296            shared_mem_bytes: 0,
1297        };
1298        let p_arg: u32 = p as u32;
1299        let k_arg: u32 = k as u32;
1300        let d_arg: u32 = d as u32;
1301        // SAFETY: kernel signature matches; Z is (p,K), Y_stack is (p,K*D),
1302        // Q is (D,K) row-major as documented.
1303        unsafe {
1304            stream
1305                .launch_builder(&func)
1306                .arg(&p_arg)
1307                .arg(&k_arg)
1308                .arg(&d_arg)
1309                .arg(z)
1310                .arg(y_stack)
1311                .arg(q)
1312                .launch(cfg)
1313        }
1314        .map(|_| ())
1315        .gpu_ctx("reml_trace launch reduce_q_dense")
1316    }
1317
1318    fn launch_reduce_q_weighted_gram(
1319        stream: &Arc<CudaStream>,
1320        module: &Arc<CudaModule>,
1321        n: usize,
1322        k: usize,
1323        d: usize,
1324        rz: &cudarc::driver::CudaSlice<f64>,
1325        rw: &cudarc::driver::CudaSlice<f64>,
1326        a_stack: &cudarc::driver::CudaSlice<f64>,
1327        q: &mut cudarc::driver::CudaSlice<f64>,
1328    ) -> Result<(), GpuError> {
1329        let func = module
1330            .load_function("reduce_q_weighted_gram")
1331            .gpu_ctx("reml_trace load reduce_q_weighted_gram")?;
1332        let cfg = LaunchConfig {
1333            grid_dim: (k as u32, d as u32, 1),
1334            block_dim: (THREADS_PER_BLOCK, 1, 1),
1335            shared_mem_bytes: 0,
1336        };
1337        let n_arg: u32 = n as u32;
1338        let k_arg: u32 = k as u32;
1339        let d_arg: u32 = d as u32;
1340        // SAFETY: kernel signature matches; RZ, RW are (n,K), A_stack is (n,D).
1341        unsafe {
1342            stream
1343                .launch_builder(&func)
1344                .arg(&n_arg)
1345                .arg(&k_arg)
1346                .arg(&d_arg)
1347                .arg(rz)
1348                .arg(rw)
1349                .arg(a_stack)
1350                .arg(q)
1351                .launch(cfg)
1352        }
1353        .map(|_| ())
1354        .gpu_ctx("reml_trace launch reduce_q_weighted_gram")
1355    }
1356
1357    fn copy_device_slice(
1358        stream: &Arc<CudaStream>,
1359        src: &cudarc::driver::CudaSlice<f64>,
1360        dst: &mut cudarc::driver::CudaSlice<f64>,
1361    ) -> Result<(), GpuError> {
1362        stream.memcpy_dtod(src, dst).gpu_ctx("reml_trace dtod copy")
1363    }
1364
1365    struct GemmShape {
1366        m: usize,
1367        n: usize,
1368        k_inner: usize,
1369        lda: usize,
1370        ldb: usize,
1371        ldc: usize,
1372    }
1373
1374    fn gemm_nn(
1375        blas: &CudaBlas,
1376        shape: GemmShape,
1377        a: &cudarc::driver::CudaSlice<f64>,
1378        b: &cudarc::driver::CudaSlice<f64>,
1379        c: &mut cudarc::driver::CudaSlice<f64>,
1380    ) -> Result<(), GpuError> {
1381        let GemmShape {
1382            m,
1383            n,
1384            k_inner,
1385            lda,
1386            ldb,
1387            ldc,
1388        } = shape;
1389        let cfg = GemmConfig::<f64> {
1390            transa: cublasOperation_t::CUBLAS_OP_N,
1391            transb: cublasOperation_t::CUBLAS_OP_N,
1392            m: m as i32,
1393            n: n as i32,
1394            k: k_inner as i32,
1395            alpha: 1.0,
1396            lda: lda as i32,
1397            ldb: ldb as i32,
1398            beta: 0.0,
1399            ldc: ldc as i32,
1400        };
1401        // SAFETY: dgemm with column-major leading dims documented above;
1402        // buffers a, b, c sized lda*k_inner, ldb*n, ldc*n.
1403        unsafe { blas.gemm(cfg, a, b, c) }.gpu_ctx("reml_trace cublas dgemm")
1404    }
1405}
1406
1407// ────────────────────────────────────────────────────────────────────────
1408// Shared validation + linear algebra helpers
1409// ────────────────────────────────────────────────────────────────────────
1410
1411fn validate_inputs(input: &RemlTraceHutchinsonInput<'_>) -> Result<(), String> {
1412    let (p, p2) = input.penalized_hessian.dim();
1413    if p == 0 || p != p2 {
1414        return Err(format!("reml_trace input H must be square, got {p}x{p2}"));
1415    }
1416    if input.probe_count < 2 {
1417        return Err(format!(
1418            "reml_trace requires probe_count >= 2 for a sample SE, got {}",
1419            input.probe_count
1420        ));
1421    }
1422    let needs_design = input
1423        .derivatives
1424        .iter()
1425        .any(|d| matches!(d, DerivativeHessian::WeightedGram { .. }));
1426    if needs_design && input.design.is_none() {
1427        return Err("reml_trace: structural derivative present but design=None".to_string());
1428    }
1429    let n = input.design.as_ref().map(|x| x.nrows()).unwrap_or(0);
1430    if let Some(x) = input.design.as_ref()
1431        && x.ncols() != p
1432    {
1433        return Err(format!(
1434            "reml_trace design has {} columns, expected p={p}",
1435            x.ncols()
1436        ));
1437    }
1438    for (j, derivative) in input.derivatives.iter().enumerate() {
1439        derivative
1440            .dim_p(p, n)
1441            .map_err(String::from)
1442            .map_err(|e| format!("reml_trace derivative {j}: {e}"))?;
1443    }
1444    Ok(())
1445}
1446
1447/// Compute the per-derivative sample mean and **standard error of that mean**
1448/// from the flat row-major (D, K) Q matrix. The variance uses Bessel's
1449/// correction (K-1), then divides by `K` to report the uncertainty of
1450/// `mean_k(q_{j,k})` rather than the per-probe spread.
1451fn reduce_mean_stderr(q: &[f64], d: usize, k: usize) -> (Vec<f64>, Vec<f64>) {
1452    assert_eq!(
1453        q.len(),
1454        d * k,
1455        "reduce_mean_stderr: q buffer length {} != D*K = {}*{}",
1456        q.len(),
1457        d,
1458        k
1459    );
1460    let mut means = vec![0.0_f64; d];
1461    let mut stderrs = vec![0.0_f64; d];
1462    let inv_k = 1.0 / (k as f64);
1463    for j in 0..d {
1464        let row = &q[j * k..(j + 1) * k];
1465        let mean = row.iter().copied().sum::<f64>() * inv_k;
1466        means[j] = mean;
1467        if k >= 2 {
1468            let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / ((k - 1) as f64);
1469            stderrs[j] = (var / (k as f64)).sqrt();
1470        }
1471    }
1472    (means, stderrs)
1473}
1474
1475// ── Cholesky helpers (CPU reference only) ──────────────────────────────
1476
1477fn cholesky_lower(matrix: &Array2<f64>) -> Result<Array2<f64>, String> {
1478    let n = matrix.nrows();
1479    let mut l = Array2::<f64>::zeros((n, n));
1480    for i in 0..n {
1481        for j in 0..=i {
1482            let mut sum = matrix[[i, j]];
1483            for k in 0..j {
1484                sum -= l[[i, k]] * l[[j, k]];
1485            }
1486            if i == j {
1487                if sum <= 0.0 {
1488                    return Err(format!(
1489                        "reml_trace CPU Cholesky: non-SPD diagonal {sum} at row {i}"
1490                    ));
1491                }
1492                l[[i, j]] = sum.sqrt();
1493            } else {
1494                l[[i, j]] = sum / l[[j, j]];
1495            }
1496        }
1497    }
1498    Ok(l)
1499}
1500
1501fn solve_cholesky(l: &Array2<f64>, rhs: &[f64]) -> Vec<f64> {
1502    let n = l.nrows();
1503    let mut y = vec![0.0_f64; n];
1504    for i in 0..n {
1505        let mut sum = rhs[i];
1506        for k in 0..i {
1507            sum -= l[[i, k]] * y[k];
1508        }
1509        y[i] = sum / l[[i, i]];
1510    }
1511    let mut x = vec![0.0_f64; n];
1512    for i in (0..n).rev() {
1513        let mut sum = y[i];
1514        for k in (i + 1)..n {
1515            sum -= l[[k, i]] * x[k];
1516        }
1517        x[i] = sum / l[[i, i]];
1518    }
1519    x
1520}
1521
1522// ────────────────────────────────────────────────────────────────────────
1523// Tests
1524// ────────────────────────────────────────────────────────────────────────
1525
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array2, ArrayView2};

    // ── Shared fixtures ────────────────────────────────────────────────

    /// Symmetric test matrix with `p + jitter` on the diagonal and
    /// `1 / (1 + |i - j|)` off it. Each off-diagonal row sum is at most
    /// `2·(H_p − 1) < p` (harmonic number `H_p`), so for `jitter >= 0` the
    /// matrix is strictly diagonally dominant and therefore SPD.
    fn make_spd(p: usize, jitter: f64) -> Array2<f64> {
        let mut h = Array2::<f64>::zeros((p, p));
        for i in 0..p {
            for j in 0..p {
                h[[i, j]] = if i == j {
                    p as f64 + jitter
                } else {
                    1.0 / (1.0 + (i as f64 - j as f64).abs())
                };
            }
        }
        h
    }

    /// Advance a SplitMix64 counter state in place and return the next
    /// uniform draw on `[-0.5, 0.5)`, built from the top 53 output bits.
    fn next_centered_uniform(state: &mut u64) -> f64 {
        *state = splitmix64_mix(state.wrapping_add(1));
        ((*state >> 11) as f64) / ((1u64 << 53) as f64) - 0.5
    }

    /// Dense symmetric `p × p` matrix with entries uniform on `[-0.5, 0.5)`,
    /// deterministic in `seed`. Not SPD in general.
    fn random_dense_sym(p: usize, seed: u64) -> Array2<f64> {
        let mut a = Array2::<f64>::zeros((p, p));
        let mut state = seed;
        for i in 0..p {
            for j in i..p {
                let v = next_centered_uniform(&mut state);
                a[[i, j]] = v;
                a[[j, i]] = v;
            }
        }
        a
    }

    /// `out = h · v` for a dense square `h`. Used as the HVP closure body
    /// in the Block 2.7 tests.
    fn dense_matvec(h: &Array2<f64>, v: &[f64], out: &mut [f64]) {
        let p = h.nrows();
        for r in 0..p {
            let mut acc = 0.0_f64;
            for c in 0..p {
                acc += h[[r, c]] * v[c];
            }
            out[r] = acc;
        }
    }

    /// `log|m|` for SPD `m`, computed as `2 · Σ log L_ii` from the CPU
    /// Cholesky factor.
    fn logdet_spd(m: &Array2<f64>) -> f64 {
        let l = cholesky_lower(m).expect("SPD");
        2.0 * (0..m.nrows()).map(|i| l[[i, i]].ln()).sum::<f64>()
    }

    /// `(h + eps·a, h − eps·a)` for a central finite difference in ρ.
    fn central_perturbations(
        h: &Array2<f64>,
        a: &Array2<f64>,
        eps: f64,
    ) -> (Array2<f64>, Array2<f64>) {
        let mut hp = h.clone();
        let mut hm = h.clone();
        for i in 0..h.nrows() {
            for j in 0..h.ncols() {
                hp[[i, j]] += eps * a[[i, j]];
                hm[[i, j]] -= eps * a[[i, j]];
            }
        }
        (hp, hm)
    }

    /// Exact `tr(H⁻¹ A)` by solving `H w = e_col` for every unit vector.
    fn exact_trace_hinv_a(h: ArrayView2<f64>, a: ArrayView2<f64>) -> f64 {
        let p = h.nrows();
        let factor = cholesky_lower(&h.to_owned()).expect("SPD");
        let mut trace = 0.0;
        for col in 0..p {
            let mut e = vec![0.0_f64; p];
            e[col] = 1.0;
            // w = H⁻¹ e_col is column `col` of H⁻¹.
            let w = solve_cholesky(&factor, &e);
            // Σ_i A[col, i] · H⁻¹[i, col] is the diagonal entry (A H⁻¹)[col, col].
            // Its sum over `col` is tr(A H⁻¹) = tr(H⁻¹ A) (cyclic trace).
            let mut diag = 0.0;
            for i in 0..p {
                diag += a[[col, i]] * w[i];
            }
            trace += diag;
        }
        trace
    }

    #[test]
    fn splitmix_is_deterministic_and_disperses() {
        // Self-consistency: same input → same output. Across 64 nearby
        // inputs every output bit position must be set at least once.
        assert_eq!(splitmix64_mix(42), splitmix64_mix(42));
        let mut bits_seen = 0u64;
        for x in 0u64..64 {
            bits_seen |= splitmix64_mix(x);
        }
        assert_eq!(
            bits_seen,
            u64::MAX,
            "splitmix should cover every bit position across 64 inputs"
        );
    }

    #[test]
    fn rademacher_entries_are_pm_one_and_stateless() {
        let seed = ProbeSeed(0xCAFE_BABE);
        for k in 0..16u64 {
            for i in 0..32u64 {
                let v = rademacher_entry(seed.0, k, i);
                assert!(
                    v == 1.0 || v == -1.0,
                    "non-pm1 entry at (k={k}, i={i}): {v}"
                );
                let v2 = rademacher_entry(seed.0, k, i);
                assert_eq!(v, v2, "same (k,i) must hash to same value");
            }
        }
    }

    #[test]
    fn rademacher_common_random_numbers_match_for_prefix() {
        // First 16 probes of a K=16 run must equal first 16 probes of K=32.
        let p = 50;
        let mut z16 = vec![0.0_f64; p * 16];
        let mut z32 = vec![0.0_f64; p * 32];
        fill_rademacher_host(ProbeSeed(7), p, 16, &mut z16);
        fill_rademacher_host(ProbeSeed(7), p, 32, &mut z32);
        for col in 0..16 {
            for row in 0..p {
                assert_eq!(
                    z16[col * p + row],
                    z32[col * p + row],
                    "CRN broken at (col={col}, row={row})"
                );
            }
        }
    }

    #[test]
    fn cpu_hutchinson_unbiased_against_exact_small_spd() {
        let p = 16;
        let h = make_spd(p, 0.5);
        let a1 = random_dense_sym(p, 0x1234);
        let a2 = random_dense_sym(p, 0x5678);
        let exact1 = exact_trace_hinv_a(h.view(), a1.view());
        let exact2 = exact_trace_hinv_a(h.view(), a2.view());
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![
                DerivativeHessian::Dense(a1.view()),
                DerivativeHessian::Dense(a2.view()),
            ],
            design: None,
            probe_count: 4096,
            seed: ProbeSeed(0xCAFE_BABE),
        };
        let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
        // gradient = 0.5 * trace, so multiply estimate by 2 for the trace.
        let est1 = 2.0 * evidence.gradient_rho_logdet[0];
        let est2 = 2.0 * evidence.gradient_rho_logdet[1];
        // `gradient_rho_stderr` is already the SE of the half-scaled
        // gradient; multiply by 2 for the raw trace SE.
        let se1 = 2.0 * evidence.gradient_rho_stderr[0];
        let se2 = 2.0 * evidence.gradient_rho_stderr[1];
        let tol1 = 6.0 * se1.max(1e-8);
        let tol2 = 6.0 * se2.max(1e-8);
        assert!(
            (est1 - exact1).abs() <= tol1,
            "Hutchinson est {est1} too far from exact {exact1} (tol={tol1}, se={})",
            evidence.gradient_rho_stderr[0]
        );
        assert!(
            (est2 - exact2).abs() <= tol2,
            "Hutchinson est {est2} too far from exact {exact2} (tol={tol2}, se={})",
            evidence.gradient_rho_stderr[1]
        );
    }

    #[test]
    fn structural_path_matches_dense_for_xtwx() {
        // Build H_j = X^T diag(a) X exactly; both the dense and the
        // structural descriptor must produce the same q value per probe.
        let n = 40;
        let p = 8;
        let mut x = Array2::<f64>::zeros((n, p));
        let mut state = 11u64;
        for r in 0..n {
            for c in 0..p {
                x[[r, c]] = next_centered_uniform(&mut state);
            }
        }
        let a: Vec<f64> = (0..n).map(|i| 0.5 + 0.01 * (i as f64)).collect();
        let a_arr = ndarray::Array1::from(a);
        // H_j dense
        let mut hj_dense = Array2::<f64>::zeros((p, p));
        for r in 0..p {
            for c in 0..p {
                let mut acc = 0.0;
                for i in 0..n {
                    acc += x[[i, r]] * a_arr[i] * x[[i, c]];
                }
                hj_dense[[r, c]] = acc;
            }
        }
        // SPD H so the solve is well posed (diagonal = p + 2).
        let h = make_spd(p, 2.0);
        let input_dense = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::Dense(hj_dense.view())],
            design: None,
            probe_count: 32,
            seed: ProbeSeed(123),
        };
        let input_struct = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::WeightedGram {
                row_weights: a_arr.view(),
                penalty_extra: None,
            }],
            design: Some(x.view()),
            probe_count: 32,
            seed: ProbeSeed(123),
        };
        let e_dense = evidence_derivatives_hutchinson_cpu(&input_dense).expect("ok");
        let e_struct = evidence_derivatives_hutchinson_cpu(&input_struct).expect("ok");
        // Same probes, same H_j ⇒ identical estimator (modulo round-off).
        assert!(
            (e_dense.gradient_rho_logdet[0] - e_struct.gradient_rho_logdet[0]).abs() < 1e-9,
            "dense vs structural mismatch: dense={}, struct={}",
            e_dense.gradient_rho_logdet[0],
            e_struct.gradient_rho_logdet[0]
        );
    }

    #[test]
    fn finite_difference_check_against_logdet() {
        // For H(rho) = H0 + rho * A, d/d(rho) log|H| = tr(H^{-1} A).
        let p = 10;
        let h0 = make_spd(p, 0.2);
        let a = random_dense_sym(p, 0xABCD);
        let eps = 1e-4;
        let (hp, hm) = central_perturbations(&h0, &a, eps);
        let fd = (logdet_spd(&hp) - logdet_spd(&hm)) / (2.0 * eps);
        let exact = exact_trace_hinv_a(h0.view(), a.view());
        assert!(
            (fd - exact).abs() / exact.abs().max(1e-12) < 1e-6,
            "FD logdet derivative {fd} != exact trace {exact}"
        );
        // And Hutchinson should land near 0.5 * exact (the gradient form).
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h0.view(),
            derivatives: vec![DerivativeHessian::Dense(a.view())],
            design: None,
            probe_count: 4096,
            seed: ProbeSeed(0xAA55),
        };
        let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
        // SE of the half-scaled gradient mean.
        let se = evidence.gradient_rho_stderr[0];
        let tol = 8.0 * se.max(1e-8);
        assert!(
            (evidence.gradient_rho_logdet[0] - 0.5 * exact).abs() < tol,
            "Hutchinson gradient {} not within 8·SE of 0.5·exact={}",
            evidence.gradient_rho_logdet[0],
            0.5 * exact
        );
    }

    #[test]
    fn gate_rejects_below_min_p() {
        assert!(!should_use_gpu_hutchinson(64, 16, true, true, true, false));
    }

    #[test]
    fn gate_rejects_k_out_of_range() {
        assert!(!should_use_gpu_hutchinson(2000, 4, true, true, true, false));
        assert!(!should_use_gpu_hutchinson(
            2000, 200, true, true, true, false
        ));
    }

    #[test]
    fn gate_rejects_when_subspace_active() {
        assert!(!should_use_gpu_hutchinson(2000, 16, true, true, true, true));
    }

    #[test]
    fn gate_accepts_canonical_case() {
        assert!(should_use_gpu_hutchinson(2000, 16, true, true, true, false));
    }

    // ────────────────────────────────────────────────────────────────
    // Block 2.6: adaptive-K validation tests.
    //
    // All five run on CPU-only and CUDA hosts alike. The tests that call
    // `evidence_derivatives_hutchinson_gpu` exercise the dispatcher, which
    // falls back to the SplitMix CPU reference when no CUDA runtime is
    // present. The others call the adaptive driver or the host probe
    // generator directly. Probe-level CRN is shared across both paths.
    // ────────────────────────────────────────────────────────────────

    #[test]
    fn block_2_6_adaptive_unbiased_against_exact_p64() {
        // (1) Adaptive Hutchinson with the default ε must land near the
        // exact `tr(H⁻¹ A)` within its reported stopping tolerance.
        let p = 64;
        let h = make_spd(p, 0.5);
        let a = random_dense_sym(p, 0xBADC0DE);
        let exact = exact_trace_hinv_a(h.view(), a.view());
        let evidence = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            ProbeSeed(0xA5A5A5),
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("adaptive run ok");
        let est = evidence.traces[0];
        let se = evidence.stderrs[0];
        let tol = (8.0 * se).max(0.05 * exact.abs());
        assert!(
            (est - exact).abs() <= tol,
            "adaptive est {est} far from exact {exact} (tol={tol}, se={se}, K={})",
            evidence.probe_count
        );
    }

    #[test]
    fn block_2_6_same_probes_cpu_vs_dispatch() {
        // (2) The dispatch entry (`_gpu`) and the explicit CPU reference
        // must agree when given the same probes. On non-CUDA hosts the
        // dispatcher falls back to the CPU reference, so this is a
        // tautology there. On CUDA hosts it checks that the device
        // `q_{j,k} = z_k^T H_j w_k` reduction matches the CPU one to 1e-9:
        // the probes match exactly, but GEMM/potrs summation order can
        // differ, so bit-identity is not required.
        let p = 32;
        let h = make_spd(p, 0.3);
        let a = random_dense_sym(p, 0x1357);
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::Dense(a.view())],
            design: None,
            probe_count: 16,
            seed: ProbeSeed(0xBEEF),
        };
        let cpu = evidence_derivatives_hutchinson_cpu(&input).expect("cpu");
        let dispatch = evidence_derivatives_hutchinson_gpu(input).expect("dispatch");
        let diff = (cpu.gradient_rho_logdet[0] - dispatch.gradient_rho_logdet[0]).abs();
        assert!(
            diff < 1e-9,
            "same-probes CPU vs GPU dispatch differ: cpu={}, dispatch={}, diff={diff}",
            cpu.gradient_rho_logdet[0],
            dispatch.gradient_rho_logdet[0]
        );
    }

    #[test]
    fn block_2_6_fd_logdet_matches_adaptive() {
        // (3) Adaptive estimate of `tr(H⁻¹ A)` should agree with the
        // central-difference derivative `d/dρ log|H + ρA|` at ρ=0.
        let p = 24;
        let h = make_spd(p, 0.4);
        let a = random_dense_sym(p, 0x2468);
        let eps = 1e-4;
        let (hp, hm) = central_perturbations(&h, &a, eps);
        let fd = (logdet_spd(&hp) - logdet_spd(&hm)) / (2.0 * eps);
        let evidence = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            ProbeSeed(0x9999),
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("adaptive ok");
        let est = evidence.traces[0];
        let se = evidence.stderrs[0];
        let tol = (8.0 * se).max(0.05 * fd.abs());
        assert!(
            (est - fd).abs() <= tol,
            "adaptive trace {est} disagrees with FD logdet derivative {fd} (tol={tol})"
        );
    }

    #[test]
    fn block_2_6_k_4096_matches_exact_tightly() {
        // (4) A large fixed K (4096 probes), well past the adaptive
        // schedule's max, must drive the Hutchinson estimator to within
        // a few SE of exact. This bounds the residual variance and
        // confirms the estimator is consistent, not merely unbiased at
        // small K.
        let p = 40;
        let h = make_spd(p, 0.6);
        let a = random_dense_sym(p, 0xDEAD);
        let exact = exact_trace_hinv_a(h.view(), a.view());
        let input = RemlTraceHutchinsonInput {
            penalized_hessian: h.view(),
            derivatives: vec![DerivativeHessian::Dense(a.view())],
            design: None,
            probe_count: 4096,
            seed: ProbeSeed(0xC0FFEE),
        };
        let evidence = evidence_derivatives_hutchinson_gpu(input).expect("ok");
        let est = 2.0 * evidence.gradient_rho_logdet[0];
        let se = 2.0 * evidence.gradient_rho_stderr[0];
        let tol = (6.0 * se).max(1e-3 * exact.abs());
        assert!(
            (est - exact).abs() <= tol,
            "K=4096 Hutchinson {est} not within 6·SE of exact {exact} (tol={tol}, se={se})"
        );
    }

    #[test]
    fn block_2_6_crn_prefix_match_across_schedule() {
        // (5) Common random numbers: the first 16 probes of a K=32 (and
        // K=64) draw must be bit-identical to a K=16 draw with the same
        // seed. The SplitMix probe RNG is stateless in (seed, k, i). This
        // is what guarantees the adaptive schedule's variance decreases
        // monotonically rather than oscillating.
        let p = 50;
        let seed = ProbeSeed(0x4242_4242);
        let mut z16 = vec![0.0_f64; p * 16];
        let mut z32 = vec![0.0_f64; p * 32];
        let mut z64 = vec![0.0_f64; p * 64];
        fill_rademacher_host(seed, p, 16, &mut z16);
        fill_rademacher_host(seed, p, 32, &mut z32);
        fill_rademacher_host(seed, p, 64, &mut z64);
        for col in 0..16 {
            for row in 0..p {
                assert_eq!(z16[col * p + row], z32[col * p + row]);
                assert_eq!(z16[col * p + row], z64[col * p + row]);
            }
        }
        for col in 0..32 {
            for row in 0..p {
                assert_eq!(z32[col * p + row], z64[col * p + row]);
            }
        }
    }

    // ────────────────────────────────────────────────────────────────
    // Block 2.7: batched-PCG HVP variant tests.
    // ────────────────────────────────────────────────────────────────

    #[test]
    fn block_2_7_hvp_path_matches_dense_adaptive() {
        // The HVP path (CG solves against a closure that multiplies a
        // stored dense H) and the dense `evidence_traces_adaptive` path
        // use the same CRN probes and the same derivative. CG round-off
        // can move the adaptive stopping K, so each estimate is checked
        // against the exact trace rather than against the other.
        let p = 40;
        let h = make_spd(p, 0.7);
        let a = random_dense_sym(p, 0xABBA);
        let seed = ProbeSeed(0x707);

        let dense = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("dense ok");

        let hvp_evidence = evidence_traces_adaptive_hvp(
            p,
            |v: &[f64], out: &mut [f64]| dense_matvec(&h, v, out),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("hvp ok");

        let exact = exact_trace_hinv_a(h.view(), a.view());
        let se_dense = dense.stderrs[0];
        let se_hvp = hvp_evidence.stderrs[0];
        let tol_dense = (8.0 * se_dense).max(0.05 * exact.abs());
        let tol_hvp = (8.0 * se_hvp).max(0.05 * exact.abs());
        assert!(
            (dense.traces[0] - exact).abs() <= tol_dense,
            "dense adaptive {} not near exact {} (tol {})",
            dense.traces[0],
            exact,
            tol_dense
        );
        assert!(
            (hvp_evidence.traces[0] - exact).abs() <= tol_hvp,
            "hvp adaptive {} not near exact {} (tol {})",
            hvp_evidence.traces[0],
            exact,
            tol_hvp
        );
        // logdet is intentionally NaN on the HVP path.
        assert!(hvp_evidence.logdet_hessian.is_nan());
    }

    #[test]
    fn block_2_7_hvp_stderr_matches_dense_reduce_mean_stderr() {
        // The HVP path's `stderrs` must use the SAME estimator convention
        // as the dense path's `reduce_mean_stderr`: the Bessel-corrected
        // (K−1) standard error of the per-probe q running mean.
        //
        // Both paths are forced to run the full K=128 schedule (rel_tol
        // below any achievable ratio), so they are compared at identical
        // probe counts on identical CRN probes. The only residual
        // difference is the inner solve (exact Cholesky vs CG@1e-6), which
        // keeps the q values, and hence the SEs, within a tight relative
        // tolerance.
        let p = 36;
        let h = make_spd(p, 0.6);
        let a = random_dense_sym(p, 0x5151);
        let seed = ProbeSeed(0xBEEF);
        let force_full_schedule = 1e-12_f64;

        let dense = evidence_traces_adaptive(
            h.view(),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            force_full_schedule,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("dense ok");

        let hvp = evidence_traces_adaptive_hvp(
            p,
            |v: &[f64], out: &mut [f64]| dense_matvec(&h, v, out),
            vec![DerivativeHessian::Dense(a.view())],
            None,
            seed,
            force_full_schedule,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("hvp ok");

        // Both ran the full schedule, so probe counts match exactly.
        assert_eq!(dense.probe_count, 128);
        assert_eq!(hvp.probe_count, dense.probe_count);

        let sd_dense = dense.stderrs[0];
        let sd_hvp = hvp.stderrs[0];
        assert!(
            sd_dense > 0.0,
            "dense SE should be positive, got {sd_dense}"
        );
        let rel = (sd_hvp - sd_dense).abs() / sd_dense;
        assert!(
            rel <= 1e-3,
            "HVP SE {sd_hvp} disagrees with dense reduce_mean_stderr SE {sd_dense} \
             (rel {rel}); the two paths must share the Bessel-corrected (K−1) convention"
        );
    }

    #[test]
    fn block_2_7_cg_solves_diagonal_system() {
        // CG on a well-conditioned diagonal SPD system must reach the
        // exact solution `b_i / d_i` within `PCG_HVP_MAX_ITERS`.
        //
        // Unpreconditioned CG needs, in exact arithmetic, up to as many
        // iterations as H has distinct eigenvalues (here p = 8), not one.
        // This test checks only the final solution, not the iteration
        // count or any non-SPD bailout path.
        let p = 8;
        let diag: Vec<f64> = (0..p).map(|i| 1.0 + i as f64).collect();
        let b: Vec<f64> = (0..p).map(|i| (i as f64) + 0.5).collect();
        let mut w = vec![0.0_f64; p];
        cg_solve(
            &mut |v: &[f64], out: &mut [f64]| {
                for i in 0..p {
                    out[i] = diag[i] * v[i];
                }
            },
            &b,
            &mut w,
            1e-12,
            PCG_HVP_MAX_ITERS,
        );
        for i in 0..p {
            let expected = b[i] / diag[i];
            assert!(
                (w[i] - expected).abs() < 1e-10,
                "diagonal CG: w[{i}]={} expected {expected}",
                w[i]
            );
        }
    }

    // ────────────────────────────────────────────────────────────────
    // Block 2.8: hill-climb target (10× adaptive Hutchinson vs the exact
    // per-column trace at p=2000, d_ρ=8).
    //
    // The exact baseline below is the CPU dense Cholesky plus one
    // triangular solve pair per column of each `A_j`. The speedup
    // assertion only fires when a CUDA runtime is detected. On CPU-only
    // hosts the test still runs the timing comparison but skips the
    // speedup assertion: exact dense Cholesky is competitive with adaptive
    // Hutchinson on a single core, so the 10× lower bound is GPU-specific.
    // On the GPU, the adaptive path batches K=16-128 probes through one
    // potrs, while the exact path performs `d · p` solves.
    // ────────────────────────────────────────────────────────────────

    #[test]
    fn block_2_8_hill_climb_adaptive_vs_exact_at_p2000_d8() {
        // Smaller dimensions on CPU CI keep the test under a minute; CUDA
        // hosts run the full p=2000, d=8 specified in the charter.
        //
        // NOTE(review): `on_v100` is true on any Linux host where the
        // global GPU runtime initialises, not specifically on a V100. The
        // 10× bound may need revisiting on other devices.
        let on_v100 =
            cfg!(target_os = "linux") && gam_gpu::device_runtime::GpuRuntime::global().is_some();
        let (p, d): (usize, usize) = if on_v100 { (2000, 8) } else { (256, 4) };

        // Diagonal p + 1, off-diagonal 1 / (1 + |i - j|).
        let h = make_spd(p, 1.0);
        let derivs_owned: Vec<Array2<f64>> = (0..d)
            .map(|k| random_dense_sym(p, 0x1000 + k as u64))
            .collect();
        let derivs: Vec<DerivativeHessian<'_>> = derivs_owned
            .iter()
            .map(|a| DerivativeHessian::Dense(a.view()))
            .collect();

        // Exact path: factor H once, then `tr(H⁻¹ A_j) = Σᵢ (H⁻¹ A_j)[i,i]`
        // by solving H X = A_j column by column. This is the cost the
        // CPU/exact-spectral path pays per REML outer step.
        let t_exact_start = std::time::Instant::now();
        let factor = cholesky_lower(&h).expect("SPD");
        let mut exact_traces = vec![0.0_f64; d];
        for (j, a) in derivs_owned.iter().enumerate() {
            let mut acc = 0.0_f64;
            for col in 0..p {
                let mut rhs = vec![0.0_f64; p];
                for r in 0..p {
                    rhs[r] = a[[r, col]];
                }
                let w = solve_cholesky(&factor, &rhs);
                acc += w[col];
            }
            exact_traces[j] = acc;
        }
        let t_exact = t_exact_start.elapsed();

        // Adaptive Hutchinson path.
        let t_adaptive_start = std::time::Instant::now();
        let evidence = evidence_traces_adaptive(
            h.view(),
            derivs,
            None,
            ProbeSeed(0xB10C),
            HUTCHINSON_ADAPTIVE_REL_TOL,
            HUTCHINSON_ADAPTIVE_TAU_REL,
        )
        .expect("adaptive ok");
        let t_adaptive = t_adaptive_start.elapsed();

        // Sanity: every adaptive trace must agree with exact within its
        // reported SE. This guards against a fast-but-wrong perf path.
        for j in 0..d {
            let se = evidence.stderrs[j];
            let tol = (10.0 * se).max(0.05 * exact_traces[j].abs());
            let diff = (evidence.traces[j] - exact_traces[j]).abs();
            assert!(
                diff <= tol,
                "block_2_8: derivative {j} adaptive {} disagrees with exact {} (tol {tol}, diff {diff})",
                evidence.traces[j],
                exact_traces[j]
            );
        }

        let speedup = t_exact.as_secs_f64() / t_adaptive.as_secs_f64().max(1e-9);
        eprintln!(
            "block_2_8 hill-climb [p={p}, d={d}, V100={on_v100}]: \
             exact={:?}, adaptive={:?}, speedup={:.2}× (K={}, converged={})",
            t_exact, t_adaptive, speedup, evidence.probe_count, evidence.converged
        );
        if on_v100 {
            assert!(
                speedup >= 10.0,
                "block_2_8 V100 speedup {speedup:.2}× below the 10× target \
                 (exact {:?}, adaptive {:?})",
                t_exact,
                t_adaptive,
            );
        }
    }
}