Skip to main content

gam_solve/gpu_kernels/
reml_trace.rs

1//! GPU Hutchinson stochastic trace estimator for the REML/LAML logdet
2//! gradient, per math team block 2 (sections 12–18 of the V100 design).
3//!
4//! Public entry point: [`evidence_derivatives_hutchinson_gpu`]. For each
5//! derivative Hessian `H_j` (`j = 1..D`) and a single penalized Hessian `H`
6//! held resident on device, returns the unbiased Hutchinson estimate of
7//!
8//! ```text
9//! t_j = tr(H^{-1} H_j)
10//! ```
11//!
//! plus the per-probe sample standard deviation behind each estimate (divide
//! by `sqrt(K)` for its standard error), computed from `K`
13//! Rademacher probe vectors `z_k ∈ {±1}^p` whose entries are drawn from a
14//! **stateless SplitMix64 counter hash** (no cuRAND state). The math
15//! identity used on device is
16//!
17//! ```text
18//! z^T H^{-1} H_j z  =  z^T H_j w   where   H w = z
19//! ```
20//!
21//! so we factor `H` **once** with `cusolverDnDpotrf`, batch-solve `H W = Z`
22//! with **one** `cusolverDnDpotrs` of `nrhs = K`, and then evaluate the
23//! quadratic forms with a custom NVRTC reduction kernel. The REML logdet
24//! gradient is `g_j = (1/2) · mean_k(q_{j,k})`.
25//!
26//! Two assembly variants for `H_j` are supported:
27//!
28//! * **Dense** — caller passes `H_j` as a `p × p` device or host matrix.
29//!   GEMM forms `Y_j = H_j W`, then a custom reduction sums
30//!   `z_k^T y_{j,k}` per (j, k). Cost: `D` GEMMs of size `p × p × K`.
31//! * **Weighted-Gram structural** — caller provides the design `X`
32//!   (`n × p`), weight vectors `A_j` (`n`, one per derivative — the
33//!   diagonal of the design's row weights that `H_j` adds), and the
//!   optional per-derivative `p × p` penalty contribution `P_j` (`penalty_extra`). The kernel
35//!   forms `R_Z = X Z` and `R_W = X W` **once** via GEMM and then sums
36//!   `sum_i a_j[i] · R_Z[i,k] · R_W[i,k]` per (j, k) without ever
37//!   materialising the `p × p` `H_j` matrix. Cost: 2 GEMMs of size
38//!   `n × p × K` shared across all `D` derivatives.
39//!
40//! The structural path is the high-value route for large-scale models
41//! where `p` is hundreds and there are many derivatives.
42//!
43//! # Stateless probe RNG
44//!
45//! The probe entries are produced on device by a SplitMix64 finalizer over
46//! `(seed, probe_index k, coordinate i)`. This has three consequences:
47//!
48//! 1. No cuRAND state — the kernel is fully stateless, threads write into
49//!    `Z[i + k·p]` independently.
50//! 2. **Common random numbers**: the first `K1` probes of a run with
51//!    `K2 > K1` are bit-identical to a `K = K1` run with the same seed.
52//!    This is the property that lets the adaptive `K` schedule build on
53//!    earlier probes without re-running them, and lets CPU and GPU
54//!    implementations of Hutchinson compare estimator-by-estimator (the
55//!    same probes produce the same `q_{j,k}` to round-off).
56//! 3. Reproducibility — a probe at `(seed, k, i)` is the same call after
57//!    call regardless of how the grid was scheduled.
58//!
59//! # Gating
60//!
61//! The companion helper [`should_use_gpu_hutchinson`] mirrors the CPU
62//! gate (`prefers_stochastic_trace_estimation` + matching kernel +
63//! plain-SPD logdet path) and adds the GPU-specific minima from the math
64//! team's section 18:
65//!
66//! * `p ≥ 512`
67//! * `K ∈ [8, 128]`
68//! * Hessian and design held resident or about to be uploaded
69//! * The projected penalty-subspace trace is **inactive** (otherwise the
70//!   CPU path projects through the IFT kernel — that route is required
71//!   for marginal-slope ρ-saturated rows)
72
73use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
74
75use gam_gpu::gpu_error::GpuError;
76use gam_linalg::pcg::{DotReduction, pcg_core};
77
78// ────────────────────────────────────────────────────────────────────────
79// Public types
80// ────────────────────────────────────────────────────────────────────────
81
/// Seed for the stateless SplitMix64 Rademacher probe generator.
///
/// Every probe entry is a pure function of `(seed, probe index,
/// coordinate)`, so one seed value pins down the whole probe matrix.
#[derive(Clone, Copy, Debug)]
pub struct ProbeSeed(pub u64);

impl Default for ProbeSeed {
    /// Returns `ProbeSeed(0xCAFE_BABE)`.
    fn default() -> Self {
        // Same constant as the CPU-side `StochasticTraceConfig::default()`,
        // so cross-implementation parity tests can share one seed.
        const SHARED_DEFAULT_SEED: u64 = 0xCAFE_BABE;
        ProbeSeed(SHARED_DEFAULT_SEED)
    }
}
93
94/// Description of one derivative-Hessian contribution `H_j`.
95///
96/// The estimator needs `H_j` only via the quadratic form `z^T H_j w`, so we
97/// describe `H_j` *structurally* rather than as a dense matrix. The dense
98/// case is recovered by the [`DerivativeHessian::Dense`] variant.
99#[derive(Clone, Debug)]
100pub enum DerivativeHessian<'a> {
101    /// `H_j` is a `p × p` symmetric matrix. The reducer forms `Y = H_j W`
102    /// via GEMM and then sums `z_k^T y_k`.
103    Dense(ArrayView2<'a, f64>),
104    /// `H_j = X^T diag(a_j) X + P_j`, where `a_j` is an `n`-vector of row
105    /// weights and `P_j` is an optional `p × p` direct penalty contribution
106    /// that is *added* to the structural part. The reducer evaluates
107    /// `z^T H_j w  =  sum_i a_j[i] · (X z)[i] · (X w)[i]  +  z^T P_j w`
108    /// without materialising the `p × p` `H_j`.
109    WeightedGram {
110        row_weights: ArrayView1<'a, f64>,
111        penalty_extra: Option<ArrayView2<'a, f64>>,
112    },
113}
114
115impl DerivativeHessian<'_> {
116    fn dim_p(&self, expected_p: usize, expected_n: usize) -> Result<(), GpuError> {
117        match self {
118            DerivativeHessian::Dense(matrix) => {
119                if matrix.nrows() != expected_p || matrix.ncols() != expected_p {
120                    gam_gpu::gpu_bail!(
121                        "reml_trace dense H_j: shape {:?} != ({expected_p}, {expected_p})",
122                        matrix.dim()
123                    );
124                }
125            }
126            DerivativeHessian::WeightedGram {
127                row_weights,
128                penalty_extra,
129            } => {
130                if row_weights.len() != expected_n {
131                    gam_gpu::gpu_bail!(
132                        "reml_trace structural H_j: row_weights.len()={} != n={expected_n}",
133                        row_weights.len()
134                    );
135                }
136                if let Some(p_extra) = penalty_extra
137                    && (p_extra.nrows() != expected_p || p_extra.ncols() != expected_p)
138                {
139                    gam_gpu::gpu_bail!(
140                        "reml_trace structural H_j penalty_extra: shape {:?} != ({expected_p}, {expected_p})",
141                        p_extra.dim()
142                    );
143                }
144            }
145        }
146        Ok(())
147    }
148}
149
150/// Inputs to [`evidence_derivatives_hutchinson_gpu`].
151#[derive(Clone, Debug)]
152pub struct RemlTraceHutchinsonInput<'a> {
153    /// Penalized Hessian `H` (`p × p`, SPD).
154    pub penalized_hessian: ArrayView2<'a, f64>,
155    /// Per-derivative descriptors `H_j`. `D = derivatives.len()`.
156    pub derivatives: Vec<DerivativeHessian<'a>>,
157    /// Design matrix `X` (`n × p`). Required iff any `H_j` is structural;
158    /// `None` is acceptable when **all** derivatives are dense.
159    pub design: Option<ArrayView2<'a, f64>>,
160    /// Number of probe vectors. Must be ≥ 2 (so a sample SE is defined).
161    pub probe_count: usize,
162    /// Stateless RNG seed.
163    pub seed: ProbeSeed,
164}
165
166/// Output of [`evidence_derivatives_hutchinson_gpu`].
167#[derive(Clone, Debug)]
168pub struct RemlTraceHutchinsonEvidence {
169    /// `log |H|` from the cached Cholesky factor (same value the exact GPU
170    /// path returns; reusing the factor amortises this).
171    pub logdet_hessian: f64,
172    /// REML logdet gradient `g_j = (1/2) · mean_k(q_{j,k})`, length `D`.
173    pub gradient_rho_logdet: Array1<f64>,
174    /// Per-probe sample standard deviation of the half-scaled gradient term
175    /// `(1/2)·q_{j,·}` across the `K` probes (i.e. `0.5·sd`, NOT divided by
176    /// `sqrt(K)`), length `D`. To obtain the standard error of the running
177    /// mean `g_j`, divide by `sqrt(K)` (= `sqrt(probe_count)`).
178    pub gradient_rho_stderr: Array1<f64>,
179    /// `K` probes actually used (matches `input.probe_count`).
180    pub probe_count: usize,
181}
182
183// ────────────────────────────────────────────────────────────────────────
184// Gating
185// ────────────────────────────────────────────────────────────────────────
186
/// Smallest joint dimension `p` for which the GPU Hutchinson path is enabled.
pub const HUTCHINSON_GPU_MIN_P: usize = 512;
/// Inclusive lower bound on the probe count `K` the GPU path accepts
/// (math section 18).
pub const HUTCHINSON_GPU_MIN_K: usize = 8;
/// Inclusive upper bound on the probe count `K` the GPU path accepts
/// (math section 18).
pub const HUTCHINSON_GPU_MAX_K: usize = 128;
192
193/// True when the GPU Hutchinson path is eligible at the current shape and
194/// configuration. Caller still has to satisfy the CPU-side gate
195/// (`prefers_stochastic_trace_estimation`, matching kernel, plain-SPD
196/// logdet, projected penalty subspace **inactive**) — the parameters
197/// `prefers_stochastic`, `kernel_matches_hinv`, `plain_spd_logdet`, and
198/// `projected_penalty_subspace_active` carry those CPU-side gate booleans
199/// into the dispatch decision.
200#[must_use]
201pub fn should_use_gpu_hutchinson(
202    p: usize,
203    probe_count: usize,
204    prefers_stochastic: bool,
205    kernel_matches_hinv: bool,
206    plain_spd_logdet: bool,
207    projected_penalty_subspace_active: bool,
208) -> bool {
209    p >= HUTCHINSON_GPU_MIN_P
210        && (HUTCHINSON_GPU_MIN_K..=HUTCHINSON_GPU_MAX_K).contains(&probe_count)
211        && prefers_stochastic
212        && kernel_matches_hinv
213        && plain_spd_logdet
214        && !projected_penalty_subspace_active
215}
216
217// ────────────────────────────────────────────────────────────────────────
218// Stateless SplitMix64 Rademacher RNG (host reference; mirrors the NVRTC
219// kernel byte-for-byte so CPU and GPU produce identical probes for the
220// same `(seed, k, i)`).
221// ────────────────────────────────────────────────────────────────────────
222
223/// SplitMix64 finalizer (Sebastiano Vigna, 2015). Thin wrapper over the
224/// canonical implementation in [`gam_linalg::utils::splitmix64_hash`].
225#[inline]
226pub fn splitmix64_mix(z: u64) -> u64 {
227    gam_linalg::utils::splitmix64_hash(z)
228}
229
230/// Stateless Rademacher entry at probe index `k` (0-based), coordinate
231/// `i` (0-based), seed `s`. Returns `+1.0` or `-1.0`.
232///
233/// The mix is `splitmix64(s ⊕ k·ζ ⊕ i·γ)` for two large odd constants
234/// `ζ`, `γ`; the sign bit (bit 63 of the hash) selects the sign. The two
235/// constants are *different* from the SplitMix increment so the row and
236/// column hashes don't collide on small `(k, i)`.
237#[inline]
238pub fn rademacher_entry(seed: u64, k: u64, i: u64) -> f64 {
239    const ZETA: u64 = 0xD1B5_4A32_D192_ED03;
240    const GAMMA: u64 = 0x8CB9_2BA7_2F9D_E81F;
241    let composite = seed ^ k.wrapping_mul(ZETA) ^ i.wrapping_mul(GAMMA);
242    let h = splitmix64_mix(composite);
243    if (h >> 63) == 0 { 1.0 } else { -1.0 }
244}
245
246/// Host-side reference: fill a column-major `(p, K)` Rademacher matrix.
247/// Used by tests to verify the GPU kernel produces the same bits.
248pub fn fill_rademacher_host(seed: ProbeSeed, p: usize, k: usize, out: &mut [f64]) {
249    assert_eq!(
250        out.len(),
251        p * k,
252        "fill_rademacher_host: out buffer length {} != p*K = {}*{}",
253        out.len(),
254        p,
255        k
256    );
257    for col in 0..k {
258        for row in 0..p {
259            out[col * p + row] = rademacher_entry(seed.0, col as u64, row as u64);
260        }
261    }
262}
263
264// ────────────────────────────────────────────────────────────────────────
265// CPU reference implementation of the Hutchinson estimator
266// ────────────────────────────────────────────────────────────────────────
267//
268// This path is what runs in CPU-only builds and is also what the V100
269// parity tests check the device implementation against. It uses the same
270// stateless SplitMix probes as the kernel.
271
272/// Run the Hutchinson estimator on CPU using the exact same probe bits
273/// the device kernel uses. Returns the same evidence struct.
274pub fn evidence_derivatives_hutchinson_cpu(
275    input: &RemlTraceHutchinsonInput<'_>,
276) -> Result<RemlTraceHutchinsonEvidence, String> {
277    validate_inputs(input)?;
278    let p = input.penalized_hessian.nrows();
279    let d = input.derivatives.len();
280    let k = input.probe_count;
281
282    // Cholesky factor of H (lower).
283    let h = input.penalized_hessian.to_owned();
284    let factor = cholesky_lower(&h)?;
285    let logdet_hessian = 2.0 * (0..p).map(|i| factor[[i, i]].ln()).sum::<f64>();
286
287    // Build Z (p, k) column-major in a flat vector.
288    let mut z = vec![0.0_f64; p * k];
289    fill_rademacher_host(input.seed, p, k, &mut z);
290
291    // Solve H W = Z column by column on CPU (matches what the device
292    // does in one batched potrs call). The K columns are independent — each
293    // `solve_cholesky` reads the shared (immutable) factor and writes only its
294    // own column of `w` — so they parallelize bit-for-bit (no reduction is
295    // reordered; each w-column is produced by exactly one task with identical
296    // arithmetic). The probes are embarrassingly parallel by construction; the
297    // CRN contract lives in the stateless SplitMix fill above, untouched.
298    use rayon::prelude::*;
299    let mut w = vec![0.0_f64; p * k];
300    w.par_chunks_mut(p)
301        .zip(z.par_chunks(p))
302        .for_each(|(w_col, z_col)| {
303            let solved = solve_cholesky(&factor, z_col);
304            w_col.copy_from_slice(&solved);
305        });
306
307    // Per-derivative quadratic forms. Each `q[j*k + col]` is an independent
308    // scalar function of probe column `col` only, so we parallelize over the
309    // probe columns. This is bit-identical to the serial fill: a given q entry
310    // is computed by one task with the same per-entry arithmetic, and the
311    // downstream `reduce_mean_stderr` indexes fixed (j, col) positions — no
312    // sum is reordered across threads.
313    let mut q = vec![0.0_f64; d * k]; // row-major (d, k): q[j*k + m]
314    for (j, derivative) in input.derivatives.iter().enumerate() {
315        let q_row = &mut q[j * k..(j + 1) * k];
316        match derivative {
317            DerivativeHessian::Dense(matrix) => {
318                q_row
319                    .par_iter_mut()
320                    .zip(z.par_chunks(p).zip(w.par_chunks(p)))
321                    .for_each(|(q_jk, (z_col, w_col))| {
322                        // y = H_j w
323                        let mut y = vec![0.0_f64; p];
324                        for r in 0..p {
325                            let mut acc = 0.0_f64;
326                            for c in 0..p {
327                                acc += matrix[[r, c]] * w_col[c];
328                            }
329                            y[r] = acc;
330                        }
331                        let mut zy = 0.0_f64;
332                        for i in 0..p {
333                            zy += z_col[i] * y[i];
334                        }
335                        *q_jk = zy;
336                    });
337            }
338            DerivativeHessian::WeightedGram {
339                row_weights,
340                penalty_extra,
341            } => {
342                let design = input.design.as_ref().expect("design validated");
343                let n = design.nrows();
344                q_row
345                    .par_iter_mut()
346                    .zip(z.par_chunks(p).zip(w.par_chunks(p)))
347                    .for_each(|(q_jk, (z_col, w_col))| {
348                        // r_z = X z (length n), r_w = X w (length n)
349                        let mut acc = 0.0_f64;
350                        for row in 0..n {
351                            let mut rz = 0.0_f64;
352                            let mut rw = 0.0_f64;
353                            for col_idx in 0..p {
354                                rz += design[[row, col_idx]] * z_col[col_idx];
355                                rw += design[[row, col_idx]] * w_col[col_idx];
356                            }
357                            acc += row_weights[row] * rz * rw;
358                        }
359                        if let Some(pen) = penalty_extra {
360                            for r in 0..p {
361                                let mut row_acc = 0.0_f64;
362                                for c in 0..p {
363                                    row_acc += pen[[r, c]] * w_col[c];
364                                }
365                                acc += z_col[r] * row_acc;
366                            }
367                        }
368                        *q_jk = acc;
369                    });
370            }
371        }
372    }
373
374    let (means, stderrs) = reduce_mean_stderr(&q, d, k);
375    let mut gradient_rho_logdet = Array1::<f64>::zeros(d);
376    let mut gradient_rho_stderr = Array1::<f64>::zeros(d);
377    for j in 0..d {
378        gradient_rho_logdet[j] = 0.5 * means[j];
379        gradient_rho_stderr[j] = 0.5 * stderrs[j];
380    }
381
382    Ok(RemlTraceHutchinsonEvidence {
383        logdet_hessian,
384        gradient_rho_logdet,
385        gradient_rho_stderr,
386        probe_count: k,
387    })
388}
389
390// ────────────────────────────────────────────────────────────────────────
391// Public dispatch entry point
392// ────────────────────────────────────────────────────────────────────────
393
394/// Compute `log |H|` and the Hutchinson estimate of `(1/2) tr(H^{-1} H_j)`
395/// for every derivative. Dispatches to the device-resident path when the
396/// CUDA runtime is up and probes the GPU successfully; otherwise runs the
397/// CPU reference. Either way the probe bits are identical (stateless
398/// SplitMix), so callers see the same estimator value to round-off.
399pub fn evidence_derivatives_hutchinson_gpu(
400    input: RemlTraceHutchinsonInput<'_>,
401) -> Result<RemlTraceHutchinsonEvidence, String> {
402    validate_inputs(&input)?;
403
404    #[cfg(target_os = "linux")]
405    {
406        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
407            match linux_cuda::evidence_derivatives(&input) {
408                Ok(evidence) => return Ok(evidence),
409                Err(GpuError::NoDeviceKernel { .. }) => {
410                    // No device kernel for this path on this build: fall
411                    // through to the CPU reference.
412                }
413                Err(other) => return Err(String::from(other)),
414            }
415        }
416    }
417
418    evidence_derivatives_hutchinson_cpu(&input)
419}
420
421// ────────────────────────────────────────────────────────────────────────
422// Adaptive K (Block 2.5)
423// ────────────────────────────────────────────────────────────────────────
424
/// Default relative-error target `ε` for the adaptive-K stopping rule.
/// Kept equal to `StochasticTraceConfig::default().relative_tol`.
pub const HUTCHINSON_ADAPTIVE_REL_TOL: f64 = 1e-2;
/// Default floor `τ` on `|t_j|` in the stopping rule's denominator. It
/// protects near-zero traces. Kept equal to
/// `StochasticTraceConfig::default().tau_rel`.
pub const HUTCHINSON_ADAPTIVE_TAU_REL: f64 = 1.0e-8;
431
432/// Adaptive-K Hutchinson trace schedule with common random numbers (CRN).
433///
434/// Repeatedly invokes [`evidence_derivatives_hutchinson_gpu`] with probe
435/// counts `K = 16, 32, 64, 128`, stopping at the first `K` that satisfies
436/// the per-coordinate relative-SE criterion
437///
438/// ```text
439/// max_j  s_j / (sqrt(K) · max(|t_j|, τ))  ≤  ε
440/// ```
441///
442/// where `s_j` is the sample standard deviation across the `K` probes (the
443/// raw quadratic-form sample, *without* the `(1/2)` REML logdet scaling)
444/// and `t_j` is the running mean. Because the SplitMix probe RNG is
445/// stateless (`(seed, k_index, i) → ±1`), the first `K_prev` probes of a
446/// `K = 2·K_prev` re-run are bit-identical to the previous batch, so each
447/// step extends the prior estimate rather than starting fresh in
448/// expectation. The implementation re-runs from scratch at each `K` for
449/// simplicity; CRN is preserved by the stateless RNG seed.
450///
451/// Returns the **raw traces** `t_j = tr(H⁻¹ H_j) = mean_k q_{j,k}`
452/// (length `D`), the `log|H|` from the cached Cholesky, and the final
453/// probe count `K` actually used. The raw traces (not the `(1/2)` REML
454/// logdet gradient) are what the outer evaluator wants — it applies the
455/// logdet-gradient half-factor itself.
456pub struct AdaptiveTraceEvidence {
457    pub logdet_hessian: f64,
458    pub traces: Array1<f64>,
459    /// Per-probe sample standard deviation of the raw trace estimator
460    /// `q_{j,·}` (NOT divided by `sqrt(K)`); divide by `sqrt(probe_count)`
461    /// to obtain the standard error of the running mean `traces[j]`.
462    pub stderrs: Array1<f64>,
463    pub probe_count: usize,
464    pub converged: bool,
465}
466
467pub fn evidence_traces_adaptive<'a>(
468    penalized_hessian: ArrayView2<'a, f64>,
469    derivatives: Vec<DerivativeHessian<'a>>,
470    design: Option<ArrayView2<'a, f64>>,
471    seed: ProbeSeed,
472    rel_tol: f64,
473    tau_rel: f64,
474) -> Result<AdaptiveTraceEvidence, String> {
475    // Adaptive schedule per math team block 2 §16: K = 16, 32, 64, 128.
476    const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
477
478    let d = derivatives.len();
479    if d == 0 {
480        return Err("evidence_traces_adaptive: derivatives is empty".to_string());
481    }
482    if !(rel_tol > 0.0) {
483        return Err(format!(
484            "evidence_traces_adaptive: rel_tol must be > 0 (got {rel_tol})"
485        ));
486    }
487    if !(tau_rel > 0.0) {
488        return Err(format!(
489            "evidence_traces_adaptive: tau_rel must be > 0 (got {tau_rel})"
490        ));
491    }
492
493    let mut last_logdet = 0.0_f64;
494    let mut last_traces = Array1::<f64>::zeros(d);
495    let mut last_stderrs = Array1::<f64>::zeros(d);
496    let mut last_k = 0_usize;
497    let mut converged = false;
498
499    for &k in &SCHEDULE {
500        let input = RemlTraceHutchinsonInput {
501            penalized_hessian,
502            derivatives: derivatives.clone(),
503            design,
504            probe_count: k,
505            seed,
506        };
507        let evidence = evidence_derivatives_hutchinson_gpu(input)?;
508        last_logdet = evidence.logdet_hessian;
509        last_k = k;
510
511        // The dispatch entry returns the **(1/2)·mean** REML logdet
512        // gradient and **(1/2)·per-probe sample SD**. Undo the half to
513        // recover the raw `t_j = mean_k q_{j,k}` and the per-probe sample
514        // standard deviation `s_j` the stopping rule wants.
515        for j in 0..d {
516            last_traces[j] = 2.0 * evidence.gradient_rho_logdet[j];
517            last_stderrs[j] = 2.0 * evidence.gradient_rho_stderr[j];
518        }
519
520        // Stopping rule (math block 2 §16):
521        //   max_j  s_j / (sqrt(K) · max(|t_j|, τ))  ≤  ε
522        // `s_j` here is the per-probe sample standard deviation across the K
523        // probes (`reduce_mean_stderr` returns the raw SD, NOT the SE-of-mean);
524        // dividing by sqrt(K) once converts it to the standard error of the
525        // running mean. (The earlier double-sqrt(K) division — once here and
526        // once already inside `reduce_mean_stderr` — made this test sqrt(K)×
527        // too lax, stopping the schedule at too-small K; see #829.)
528        let sqrt_k = (k as f64).sqrt();
529        let mut worst = 0.0_f64;
530        for j in 0..d {
531            let denom = sqrt_k * last_traces[j].abs().max(tau_rel);
532            let r = last_stderrs[j] / denom;
533            if r > worst {
534                worst = r;
535            }
536        }
537        if worst <= rel_tol {
538            converged = true;
539            break;
540        }
541    }
542
543    Ok(AdaptiveTraceEvidence {
544        logdet_hessian: last_logdet,
545        traces: last_traces,
546        stderrs: last_stderrs,
547        probe_count: last_k,
548        converged,
549    })
550}
551
552// ────────────────────────────────────────────────────────────────────────
553// Block 2.7: batched-PCG HVP variant of adaptive Hutchinson
554// ────────────────────────────────────────────────────────────────────────
555
/// Relative-residual target for each per-probe CG solve `H w = z`. The
/// solve stops once `‖r‖ ≤ PCG_HVP_REL_TOL · ‖z‖`. The outer adaptive-K
/// loop only drives the Hutchinson relative SE to about 1%, so 1e-6 keeps
/// the solve error far below that without chasing machine precision.
pub const PCG_HVP_REL_TOL: f64 = 1.0e-6;

/// Cap on CG iterations per probe. On reaching it, the partial solve is
/// accepted. This bounds the time one REML step can spend on a poorly
/// conditioned `H`, at the price of an inexact `w_k` for that probe.
pub const PCG_HVP_MAX_ITERS: usize = 200;
568
569/// Adaptive Hutchinson variant that consumes `H` as a matrix-free HVP
570/// closure rather than a dense `ArrayView2`. Used by call sites where the
571/// penalized Hessian is implicit (operator-only) and forming it densely
572/// would blow the memory budget — e.g. the device-resident PCG path in
573/// `gpu/bms_flex_row.rs` or the large-scale BMS Schur operator.
574///
575/// `hvp` must compute `out ← H · v` for an SPD `H`. The closure is called
576/// once per CG iteration per probe (so `K · iters_per_probe` times in
577/// total for each schedule step). It is responsible for any necessary
578/// pre-conditioning state, threading, or device residency — the routine
579/// itself is pure CPU.
580///
581/// `derivatives` are still passed as dense or `WeightedGram`; the
582/// adaptive trace `t_j = mean_k z_k^T H_j w_k` only needs `H_j` to be
583/// available as a matvec, and the dense / weighted-Gram variants of
584/// `DerivativeHessian::quadratic_form` already provide that.
585///
586/// CRN is preserved exactly as in [`evidence_traces_adaptive`]: the
587/// SplitMix probe RNG is stateless in `(seed, k_index, i)`, so the
588/// `K=16, 32, 64, 128` schedule extends the prior estimate rather than
589/// restarting it. Each schedule step re-runs all `K` solves; the
590/// implementation is intentionally simple, the asymptotic cost is
591/// dominated by the largest `K`.
592///
593/// Returns the same [`AdaptiveTraceEvidence`] shape as the dense path,
594/// with one exception: `logdet_hessian` is **NaN** because no Cholesky
595/// is performed. Callers needing both `tr(H⁻¹ H_j)` and `log|H|` from
596/// the matrix-free path should obtain `log|H|` separately (e.g. via
597/// stochastic Lanczos or by routing through the dense path when `H`
598/// fits in memory).
599pub fn evidence_traces_adaptive_hvp<F>(
600    p: usize,
601    mut hvp: F,
602    derivatives: Vec<DerivativeHessian<'_>>,
603    design: Option<ArrayView2<'_, f64>>,
604    seed: ProbeSeed,
605    rel_tol: f64,
606    tau_rel: f64,
607) -> Result<AdaptiveTraceEvidence, String>
608where
609    F: FnMut(&[f64], &mut [f64]),
610{
611    const SCHEDULE: [usize; 4] = [16, 32, 64, 128];
612
613    let d = derivatives.len();
614    if d == 0 {
615        return Err("evidence_traces_adaptive_hvp: derivatives is empty".to_string());
616    }
617    if p == 0 {
618        return Err("evidence_traces_adaptive_hvp: p must be > 0".to_string());
619    }
620    if !(rel_tol > 0.0) {
621        return Err(format!(
622            "evidence_traces_adaptive_hvp: rel_tol must be > 0 (got {rel_tol})"
623        ));
624    }
625    if !(tau_rel > 0.0) {
626        return Err(format!(
627            "evidence_traces_adaptive_hvp: tau_rel must be > 0 (got {tau_rel})"
628        ));
629    }
630
631    let mut last_traces = Array1::<f64>::zeros(d);
632    let mut last_stderrs = Array1::<f64>::zeros(d);
633    let mut last_k = 0_usize;
634    let mut converged = false;
635
636    let mut z = vec![0.0_f64; p];
637    let mut w = vec![0.0_f64; p];
638
639    // Per-derivative Welford accumulators (running mean and sum-of-squared
640    // deviations M2) for a numerically stable online mean / sample variance.
641    // The naive one-pass form E[q²] − E[q]² catastrophically cancels when the
642    // per-probe q cluster far from zero with small spread — exactly the
643    // near-converged regime the stopping rule cares about — so we track M2
644    // directly to match the two-pass `reduce_mean_stderr` without that loss.
645    let mut q_means = vec![0.0_f64; d];
646    let mut q_m2 = vec![0.0_f64; d];
647
648    for &k_target in &SCHEDULE {
649        // Re-run from scratch at each schedule step — CRN guarantees the
650        // first min(K_prev, K_target) probes are bit-identical, so the
651        // estimator is monotone in expectation across schedule extensions.
652        for s in q_means.iter_mut() {
653            *s = 0.0;
654        }
655        for s in q_m2.iter_mut() {
656            *s = 0.0;
657        }
658
659        for k_idx in 0..k_target {
660            // Fill z_k from the stateless SplitMix RNG.
661            for i in 0..p {
662                z[i] = rademacher_entry(seed.0, k_idx as u64, i as u64);
663            }
664            // Solve H w = z by unpreconditioned CG.
665            cg_solve(&mut hvp, &z, &mut w, PCG_HVP_REL_TOL, PCG_HVP_MAX_ITERS);
666
667            // Reduce q_{j,k} = z^T H_j w for each derivative. Mirrors the
668            // dense reference in `evidence_derivatives_hutchinson_cpu`.
669            for j in 0..d {
670                let q = match &derivatives[j] {
671                    DerivativeHessian::Dense(matrix) => {
672                        let mut y = 0.0_f64;
673                        for r in 0..p {
674                            let mut hr_w = 0.0_f64;
675                            for c in 0..p {
676                                hr_w += matrix[[r, c]] * w[c];
677                            }
678                            y += z[r] * hr_w;
679                        }
680                        y
681                    }
682                    DerivativeHessian::WeightedGram {
683                        row_weights,
684                        penalty_extra,
685                    } => {
686                        let design_view = design.as_ref().ok_or_else(|| {
687                            "evidence_traces_adaptive_hvp: WeightedGram derivative requires \
688                             design matrix"
689                                .to_string()
690                        })?;
691                        let n = design_view.nrows();
692                        let mut acc = 0.0_f64;
693                        for row in 0..n {
694                            let mut rz = 0.0_f64;
695                            let mut rw = 0.0_f64;
696                            for ci in 0..p {
697                                rz += design_view[[row, ci]] * z[ci];
698                                rw += design_view[[row, ci]] * w[ci];
699                            }
700                            acc += row_weights[row] * rz * rw;
701                        }
702                        if let Some(pen) = penalty_extra {
703                            for r in 0..p {
704                                let mut row_acc = 0.0_f64;
705                                for c in 0..p {
706                                    row_acc += pen[[r, c]] * w[c];
707                                }
708                                acc += z[r] * row_acc;
709                            }
710                        }
711                        acc
712                    }
713                };
714                // Welford update with the 1-based probe count (k_idx + 1).
715                let count = (k_idx + 1) as f64;
716                let delta = q - q_means[j];
717                q_means[j] += delta / count;
718                let delta2 = q - q_means[j];
719                q_m2[j] += delta * delta2;
720            }
721        }
722
723        let n = k_target as f64;
724        let mut worst_ratio = 0.0_f64;
725        for j in 0..d {
726            let mean = q_means[j];
727            // Sample variance M2 / (K−1) — Bessel's correction, matching the
728            // two-pass `reduce_mean_stderr` exactly (no one-pass cancellation).
729            // For K = 1 there is no spread to estimate, so the variance is 0.
730            let var = if n > 1.0 { q_m2[j] / (n - 1.0) } else { 0.0 };
731            let s = var.sqrt();
732            last_traces[j] = mean;
733            last_stderrs[j] = s;
734            let denom = n.sqrt() * mean.abs().max(tau_rel);
735            let r = s / denom;
736            if r > worst_ratio {
737                worst_ratio = r;
738            }
739        }
740        last_k = k_target;
741        if worst_ratio <= rel_tol {
742            converged = true;
743            break;
744        }
745    }
746
747    Ok(AdaptiveTraceEvidence {
748        logdet_hessian: f64::NAN,
749        traces: last_traces,
750        stderrs: last_stderrs,
751        probe_count: last_k,
752        converged,
753    })
754}
755
756/// Unpreconditioned conjugate gradients for `H w = b` with `H` accessed
757/// only through `hvp(v, out) → out ← H v`. SPD `H` is required.
758/// Initial guess is `w = 0`; stops when `‖r‖ ≤ rel_tol · ‖b‖` or after
759/// `max_iters` iterations.
760///
761/// Thin wrapper over the shared [`pcg_core`] (`linalg::pcg`): unpreconditioned
762/// (all-ones Jacobi diagonal), no residual refresh (`refresh_period = 0`), and
763/// no diagnostics. On a breakdown (lost SPD near convergence, non-finite
764/// scalar) the core stops and leaves the last valid iterate in `w`, which is
765/// the historical "accept current w" behavior.
766///
767/// Reduction: [`DotReduction::Reordered`]. This is the stochastic Hutchinson
768/// trace probe, NOT the main solve. The per-probe CG residual (`rel_tol`
769/// ≈ 1e-6) sits orders of magnitude below the estimator's own sampling SE, and
770/// the adaptive-K stopping rule budgets against that SE — so reordering the
771/// inner-product accumulation (ILP/SIMD reduction) only perturbs bits already
772/// dominated by Monte-Carlo noise. The CRN reproducibility that matters here is
773/// in the SplitMix probe RNG (`rademacher_entry`), which is untouched; we do
774/// NOT need the cross-thread bit-identity that the main SPD solve contracts
775/// for, so we trade it for add-side ILP on the hot per-iteration folds.
776fn cg_solve<F>(hvp: &mut F, b: &[f64], w: &mut [f64], rel_tol: f64, max_iters: usize)
777where
778    F: FnMut(&[f64], &mut [f64]),
779{
780    let n = b.len();
781    assert!(w.len() == n);
782
783    let rhs = ArrayView1::from(b);
784    let precond = Array1::<f64>::ones(n);
785    let mut solution = ArrayViewMut1::from(w);
786
787    pcg_core(
788        |v: &Array1<f64>, out: &mut Array1<f64>| {
789            // The core hands contiguous vectors; `hvp` speaks raw slices.
790            let v_slice = v.as_slice().expect("contiguous CG direction view");
791            let out_slice = out.as_slice_mut().expect("contiguous CG matvec view");
792            hvp(v_slice, out_slice);
793        },
794        &rhs,
795        &precond.view(),
796        rel_tol,
797        max_iters,
798        0,
799        false,
800        DotReduction::Reordered,
801        &mut solution,
802    );
803}
804
805// ────────────────────────────────────────────────────────────────────────
806// Outer logdet-gradient dispatch gate (Block 2.5)
807// ────────────────────────────────────────────────────────────────────────
808
809/// Composite gate predicate for the outer REML logdet-gradient bypass:
810/// when this returns `true`, the unified evaluator should replace its
811/// CPU stochastic-trace call with [`evidence_traces_adaptive`].
812///
813/// All five conditions must hold simultaneously:
814/// * `p ≥ 512` and `K_initial..=K_max` is `[16, 128]`
815/// * `H` is resident as a dense SPD operator (caller passes
816///   `dense_spd_h_resident = true` when `hop.as_exact_dense_spectral()`
817///   is `Some` AND the Cholesky succeeds — the latter is checked
818///   indirectly by `plain_spd_logdet`).
819/// * `plain_spd_logdet`: the operator's logdet kernel is `H⁻¹` exactly
820///   (i.e. `hop.logdet_traces_match_hinv_kernel() && hop.is_dense()`),
821///   so smooth-spectral and SCOP-warped paths are excluded.
822/// * `prefers_stochastic`: `hop.prefers_stochastic_trace_estimation()`.
823/// * `!projected_penalty_subspace_active`: the rank-deficient LAML
824///   projected kernel `U_S H_proj⁻¹ U_Sᵀ` is **not** installed.
825#[must_use]
826pub fn should_bypass_cpu_with_gpu_adaptive(
827    p: usize,
828    dense_spd_h_resident: bool,
829    plain_spd_logdet: bool,
830    prefers_stochastic: bool,
831    projected_penalty_subspace_active: bool,
832) -> bool {
833    p >= HUTCHINSON_GPU_MIN_P
834        && dense_spd_h_resident
835        && plain_spd_logdet
836        && prefers_stochastic
837        && !projected_penalty_subspace_active
838}
839
840// ────────────────────────────────────────────────────────────────────────
841// Linux/CUDA implementation
842// ────────────────────────────────────────────────────────────────────────
843
844#[cfg(target_os = "linux")]
845mod linux_cuda {
846    use super::{
847        DerivativeHessian, ProbeSeed, RemlTraceHutchinsonEvidence, RemlTraceHutchinsonInput,
848        reduce_mean_stderr,
849    };
850    use gam_gpu::driver::to_col_major;
851    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
852    use gam_gpu::solver::{
853        cholesky_logdet_from_col_major, context_and_stream, pinned_htod, potrf_in_place,
854        potrs_in_place,
855    };
856    use cudarc::cublas::sys::cublasOperation_t;
857    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
858    use cudarc::cusolver::DnHandle;
859    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
860    use std::sync::Arc;
861
862    /// NVRTC source for the three custom kernels used by this path. All
863    /// arithmetic is in `double` and the layouts are column-major to match
864    /// cuBLAS/cuSOLVER conventions.
865    ///
866    /// * `fill_rademacher_splitmix(seed, p, K, Z)` — stateless ±1 fill.
867    /// * `reduce_q_dense(p, K, D, Z, Y_stack, Q)` — `Q[j,k] = z_k^T Y_j[:,k]`
868    ///   with `Y_j[:,k] = (H_j W)[:,k]`. `Y_stack` is column-major shape
869    ///   `(p, K·D)` with derivative `j` occupying columns `[j·K, (j+1)·K)`.
870    /// * `reduce_q_weighted_gram(n, K, D, RZ_stride, RZ, RW, A_stack, Q)`
871    ///   — `Q[j,k] = sum_i A[i,j] · RZ[i,k] · RW[i,k]`. Used by the
872    ///   structural path. `A_stack` is column-major `(n, D)`.
873    ///
874    /// The reductions use a per-block warp-shuffle pattern with one block
875    /// per `(j, k)` output cell and `THREADS_PER_BLOCK` threads per block.
876    pub(super) const PTX_SOURCE: &str = r#"
877extern "C" __device__ unsigned long long splitmix64_mix(unsigned long long z) {
878    z += 0x9E3779B97F4A7C15ULL;
879    unsigned long long x = z;
880    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ULL;
881    x = (x ^ (x >> 27)) * 0x94D049BB133111EBULL;
882    return x ^ (x >> 31);
883}
884
885extern "C" __global__ void fill_rademacher_splitmix(
886    unsigned long long seed,
887    unsigned int p,
888    unsigned int K,
889    double* __restrict__ Z)
890{
891    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
892    unsigned int k = blockIdx.y;
893    if (i >= p || k >= K) return;
894    const unsigned long long ZETA  = 0xD1B54A32D192ED03ULL;
895    const unsigned long long GAMMA = 0x8CB92BA72F9DE81FULL;
896    unsigned long long composite =
897        seed
898        ^ (((unsigned long long)k) * ZETA)
899        ^ (((unsigned long long)i) * GAMMA);
900    unsigned long long h = splitmix64_mix(composite);
901    double v = (h >> 63) == 0 ? 1.0 : -1.0;
902    Z[(size_t)k * (size_t)p + (size_t)i] = v;
903}
904
905extern "C" __device__ double block_reduce_sum(double v) {
906    __shared__ double smem[32];
907    int lane = threadIdx.x & 31;
908    int wid  = threadIdx.x >> 5;
909    for (int off = 16; off > 0; off >>= 1) {
910        v += __shfl_down_sync(0xffffffff, v, off);
911    }
912    if (lane == 0) smem[wid] = v;
913    __syncthreads();
914    double total = 0.0;
915    int n_warps = (blockDim.x + 31) >> 5;
916    if (threadIdx.x < (unsigned)n_warps) total = smem[threadIdx.x];
917    if (wid == 0) {
918        for (int off = 16; off > 0; off >>= 1) {
919            total += __shfl_down_sync(0xffffffff, total, off);
920        }
921    }
922    return total;
923}
924
925extern "C" __global__ void reduce_q_dense(
926    unsigned int p,
927    unsigned int K,
928    unsigned int D,
929    const double* __restrict__ Z,
930    const double* __restrict__ Y_stack,
931    double* __restrict__ Q)
932{
933    unsigned int k = blockIdx.x;
934    unsigned int j = blockIdx.y;
935    if (k >= K || j >= D) return;
936    const double* z_col = Z + (size_t)k * (size_t)p;
937    const double* y_col = Y_stack + ((size_t)j * (size_t)K + (size_t)k) * (size_t)p;
938    double partial = 0.0;
939    for (unsigned int i = threadIdx.x; i < p; i += blockDim.x) {
940        partial += z_col[i] * y_col[i];
941    }
942    double total = block_reduce_sum(partial);
943    if (threadIdx.x == 0) {
944        Q[(size_t)j * (size_t)K + (size_t)k] = total;
945    }
946}
947
948extern "C" __global__ void reduce_q_weighted_gram(
949    unsigned int n,
950    unsigned int K,
951    unsigned int D,
952    const double* __restrict__ RZ,
953    const double* __restrict__ RW,
954    const double* __restrict__ A_stack,
955    double* __restrict__ Q)
956{
957    unsigned int k = blockIdx.x;
958    unsigned int j = blockIdx.y;
959    if (k >= K || j >= D) return;
960    const double* rz_col = RZ + (size_t)k * (size_t)n;
961    const double* rw_col = RW + (size_t)k * (size_t)n;
962    const double* a_col  = A_stack + (size_t)j * (size_t)n;
963    double partial = 0.0;
964    for (unsigned int i = threadIdx.x; i < n; i += blockDim.x) {
965        partial += a_col[i] * rz_col[i] * rw_col[i];
966    }
967    double total = block_reduce_sum(partial);
968    if (threadIdx.x == 0) {
969        Q[(size_t)j * (size_t)K + (size_t)k] = total;
970    }
971}
972"#;
973
974    const THREADS_PER_BLOCK: u32 = 256;
975
976    fn module(ctx: &Arc<CudaContext>) -> Result<&'static Arc<CudaModule>, GpuError> {
977        static CACHE: gam_gpu::device_cache::PtxModuleCache =
978            gam_gpu::device_cache::PtxModuleCache::new();
979        CACHE.get_or_compile(ctx, "reml_trace", PTX_SOURCE)
980    }
981
982    pub(super) fn evidence_derivatives(
983        input: &RemlTraceHutchinsonInput<'_>,
984    ) -> Result<RemlTraceHutchinsonEvidence, GpuError> {
985        let p = input.penalized_hessian.nrows();
986        let d = input.derivatives.len();
987        let k = input.probe_count;
988        let (ctx, stream) =
989            context_and_stream().map_err(|reason| GpuError::DriverCallFailed { reason })?;
990        let solver = DnHandle::new(stream.clone()).gpu_ctx("reml_trace cusolver init")?;
991        let blas = CudaBlas::new(stream.clone()).gpu_ctx("reml_trace cublas init")?;
992        let compiled = module(&ctx)?;
993        let module_handle: &Arc<CudaModule> = compiled;
994
995        // ── 1. Upload H, factor once.
996        let h_col = to_col_major(&input.penalized_hessian);
997        let mut h_dev =
998            pinned_htod(&stream, &h_col).map_err(|reason| GpuError::DriverCallFailed { reason })?;
999        potrf_in_place(&solver, &stream, p, &mut h_dev)
1000            .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1001        let factor_col = stream
1002            .clone_dtoh(&h_dev)
1003            .gpu_ctx("reml_trace download factor")?;
1004        let logdet_hessian = cholesky_logdet_from_col_major(&factor_col, p);
1005
1006        // ── 2. Allocate Z (p, K) and fill with Rademacher entries on device.
1007        let total_z = p
1008            .checked_mul(k)
1009            .ok_or_else(|| gam_gpu::gpu_err!("reml_trace Z size overflow: p={p}, K={k}"))?;
1010        let mut z_dev = stream
1011            .alloc_zeros::<f64>(total_z)
1012            .gpu_ctx("reml_trace alloc Z")?;
1013        launch_fill_rademacher(&stream, module_handle, input.seed, p, k, &mut z_dev)?;
1014
1015        // ── 3. Solve H W = Z in a single batched potrs call (nrhs = K).
1016        //     Copy Z into a fresh buffer first; potrs is in-place.
1017        let mut w_dev = stream
1018            .alloc_zeros::<f64>(total_z)
1019            .gpu_ctx("reml_trace alloc W")?;
1020        copy_device_slice(&stream, &z_dev, &mut w_dev)?;
1021        potrs_in_place(&solver, &stream, p, k, &h_dev, &mut w_dev)
1022            .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1023
1024        // ── 4. Partition derivatives by kind.
1025        let mut dense_indices: Vec<usize> = Vec::new();
1026        let mut gram_indices: Vec<usize> = Vec::new();
1027        for (j, deriv) in input.derivatives.iter().enumerate() {
1028            match deriv {
1029                DerivativeHessian::Dense(_) => dense_indices.push(j),
1030                DerivativeHessian::WeightedGram { .. } => gram_indices.push(j),
1031            }
1032        }
1033
1034        let mut q_host = vec![0.0_f64; d * k];
1035
1036        // ── 5a. Dense path: for each dense H_j run a p×p × p×K GEMM and
1037        //       reduce. We loop over j rather than stacking the H_j's
1038        //       (would explode memory at large-scale-p), but the GEMMs share
1039        //       the resident W buffer.
1040        if !dense_indices.is_empty() {
1041            for &j in &dense_indices {
1042                let DerivativeHessian::Dense(matrix) = &input.derivatives[j] else {
1043                    // SAFETY: dense_indices was populated in the partition loop above
1044                    // with exactly the indices whose variant is DerivativeHessian::Dense.
1045                    // input.derivatives is immutably borrowed for the whole function so
1046                    // the slot at index j cannot have been rewritten between partition and
1047                    // this read; reaching this branch can only mean a future refactor split
1048                    // the partition from its consumer. The panic names the offending index.
1049                    panic!(
1050                        "reml_trace dense path: derivative index {j} is in dense_indices but \
1051                         input.derivatives[{j}] is not DerivativeHessian::Dense — \
1052                         dense_indices partition invariant violated"
1053                    );
1054                };
1055                let hj_col = to_col_major(matrix);
1056                let hj_dev = pinned_htod(&stream, &hj_col)
1057                    .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1058                let mut y_dev = stream
1059                    .alloc_zeros::<f64>(total_z)
1060                    .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Y_j (j={j}): {err}"))?;
1061                gemm_nn(
1062                    &blas,
1063                    GemmShape {
1064                        m: p,
1065                        n: k,
1066                        k_inner: p,
1067                        lda: p,
1068                        ldb: p,
1069                        ldc: p,
1070                    },
1071                    &hj_dev,
1072                    &w_dev,
1073                    &mut y_dev,
1074                )?;
1075                let mut q_j_dev = stream
1076                    .alloc_zeros::<f64>(k)
1077                    .gpu_ctx_with(|err| format!("reml_trace alloc Q_j (j={j}): {err}"))?;
1078                launch_reduce_q_dense(
1079                    &stream,
1080                    module_handle,
1081                    p,
1082                    k,
1083                    1,
1084                    &z_dev,
1085                    &y_dev,
1086                    &mut q_j_dev,
1087                )?;
1088                let q_host_j = stream
1089                    .clone_dtoh(&q_j_dev)
1090                    .gpu_ctx_with(|err| format!("reml_trace download Q_j (j={j}): {err}"))?;
1091                q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_j);
1092            }
1093        }
1094
1095        // ── 5b. Structural path: form R_Z = X Z and R_W = X W **once**,
1096        //       then run reduce_q_weighted_gram for each derivative.
1097        if !gram_indices.is_empty() {
1098            let design = input
1099                .design
1100                .as_ref()
1101                .ok_or_else(|| GpuError::DriverCallFailed {
1102                    reason: "reml_trace: structural derivative present but design=None".to_string(),
1103                })?;
1104            let n = design.nrows();
1105            let design_col = to_col_major(design);
1106            let x_dev = pinned_htod(&stream, &design_col)
1107                .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1108            let mut rz_dev = stream
1109                .alloc_zeros::<f64>(
1110                    n.checked_mul(k)
1111                        .ok_or_else(|| gam_gpu::gpu_err!("reml_trace RZ overflow: n={n}, K={k}"))?,
1112                )
1113                .gpu_ctx("reml_trace alloc RZ")?;
1114            let mut rw_dev = stream
1115                .alloc_zeros::<f64>(n * k)
1116                .gpu_ctx("reml_trace alloc RW")?;
1117            // R_Z = X Z   (n × p) · (p × K) -> (n × K)
1118            gemm_nn(
1119                &blas,
1120                GemmShape {
1121                    m: n,
1122                    n: k,
1123                    k_inner: p,
1124                    lda: n,
1125                    ldb: p,
1126                    ldc: n,
1127                },
1128                &x_dev,
1129                &z_dev,
1130                &mut rz_dev,
1131            )?;
1132            // R_W = X W
1133            gemm_nn(
1134                &blas,
1135                GemmShape {
1136                    m: n,
1137                    n: k,
1138                    k_inner: p,
1139                    lda: n,
1140                    ldb: p,
1141                    ldc: n,
1142                },
1143                &x_dev,
1144                &w_dev,
1145                &mut rw_dev,
1146            )?;
1147
1148            // Stack the row-weight vectors into A_stack column-major (n × D_gram).
1149            let d_gram = gram_indices.len();
1150            let mut a_stack = Vec::<f64>::with_capacity(n * d_gram);
1151            for &j in &gram_indices {
1152                let DerivativeHessian::WeightedGram { row_weights, .. } = &input.derivatives[j]
1153                else {
1154                    // SAFETY: gram_indices was populated in the partition loop above with
1155                    // exactly the indices whose variant is DerivativeHessian::WeightedGram.
1156                    // input.derivatives is immutably borrowed for the whole function so the
1157                    // slot at j cannot have been rewritten between partition and read; a
1158                    // failure here is a future-refactor bug, not a runtime input issue.
1159                    panic!(
1160                        "reml_trace structural path: derivative index {j} is in gram_indices \
1161                         but input.derivatives[{j}] is not DerivativeHessian::WeightedGram — \
1162                         gram_indices partition invariant violated"
1163                    );
1164                };
1165                let slice = row_weights.as_slice().ok_or_else(|| {
1166                    gam_gpu::gpu_err!("reml_trace structural H_j={j} row_weights not contiguous")
1167                })?;
1168                a_stack.extend_from_slice(slice);
1169            }
1170            let a_dev = pinned_htod(&stream, &a_stack)
1171                .map_err(|reason| GpuError::DriverCallFailed { reason })?;
1172            let mut q_dev = stream
1173                .alloc_zeros::<f64>(d_gram * k)
1174                .map_err(|err| gam_gpu::gpu_err!("reml_trace alloc Q_gram: {err}"))?;
1175            launch_reduce_q_weighted_gram(
1176                &stream,
1177                module_handle,
1178                n,
1179                k,
1180                d_gram,
1181                &rz_dev,
1182                &rw_dev,
1183                &a_dev,
1184                &mut q_dev,
1185            )?;
1186            let q_host_gram = stream
1187                .clone_dtoh(&q_dev)
1188                .gpu_ctx("reml_trace download Q_gram")?;
1189            for (slot, &j) in gram_indices.iter().enumerate() {
1190                q_host[j * k..(j + 1) * k].copy_from_slice(&q_host_gram[slot * k..(slot + 1) * k]);
1191            }
1192            // penalty_extra contributions (uncommon, dense p×p) — handled on
1193            // host to keep the kernel surface small; total cost p² · K per
1194            // derivative that has one.
1195            for &j in &gram_indices {
1196                let DerivativeHessian::WeightedGram { penalty_extra, .. } = &input.derivatives[j]
1197                else {
1198                    // SAFETY: gram_indices was populated by the partition loop above with
1199                    // exactly the WeightedGram-variant indices; the same indices are
1200                    // re-walked here to pick up the optional penalty_extra field.
1201                    // input.derivatives has been immutably borrowed since partitioning, so
1202                    // the variant at index j cannot have changed. A let-else failure here
1203                    // would mean a future refactor split partition from consumer loops.
1204                    panic!(
1205                        "reml_trace structural penalty_extra: derivative index {j} is in \
1206                         gram_indices but input.derivatives[{j}] is not \
1207                         DerivativeHessian::WeightedGram — gram_indices partition invariant \
1208                         violated"
1209                    );
1210                };
1211                if let Some(pen) = penalty_extra {
1212                    let z_host = stream
1213                        .clone_dtoh(&z_dev)
1214                        .gpu_ctx("reml_trace download Z for penalty_extra")?;
1215                    let w_host = stream
1216                        .clone_dtoh(&w_dev)
1217                        .gpu_ctx("reml_trace download W for penalty_extra")?;
1218                    for col in 0..k {
1219                        let z_col = &z_host[col * p..(col + 1) * p];
1220                        let w_col = &w_host[col * p..(col + 1) * p];
1221                        let mut acc = 0.0_f64;
1222                        for r in 0..p {
1223                            let mut row_acc = 0.0_f64;
1224                            for c in 0..p {
1225                                row_acc += pen[[r, c]] * w_col[c];
1226                            }
1227                            acc += z_col[r] * row_acc;
1228                        }
1229                        q_host[j * k + col] += acc;
1230                    }
1231                }
1232            }
1233        }
1234
1235        let (means, stderrs) = reduce_mean_stderr(&q_host, d, k);
1236        let mut gradient_rho_logdet = ndarray::Array1::<f64>::zeros(d);
1237        let mut gradient_rho_stderr = ndarray::Array1::<f64>::zeros(d);
1238        for j in 0..d {
1239            gradient_rho_logdet[j] = 0.5 * means[j];
1240            gradient_rho_stderr[j] = 0.5 * stderrs[j];
1241        }
1242
1243        Ok(RemlTraceHutchinsonEvidence {
1244            logdet_hessian,
1245            gradient_rho_logdet,
1246            gradient_rho_stderr,
1247            probe_count: k,
1248        })
1249    }
1250
1251    // ───── kernel launch wrappers ────────────────────────────────────────
1252
1253    fn launch_fill_rademacher(
1254        stream: &Arc<CudaStream>,
1255        module: &Arc<CudaModule>,
1256        seed: ProbeSeed,
1257        p: usize,
1258        k: usize,
1259        z: &mut cudarc::driver::CudaSlice<f64>,
1260    ) -> Result<(), GpuError> {
1261        let func = module
1262            .load_function("fill_rademacher_splitmix")
1263            .gpu_ctx("reml_trace load fill_rademacher")?;
1264        let grid_x = ((p as u32) + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
1265        let cfg = LaunchConfig {
1266            grid_dim: (grid_x, k as u32, 1),
1267            block_dim: (THREADS_PER_BLOCK, 1, 1),
1268            shared_mem_bytes: 0,
1269        };
1270        let seed_arg: u64 = seed.0;
1271        let p_arg: u32 = p as u32;
1272        let k_arg: u32 = k as u32;
1273        // SAFETY: kernel signature matches arg types; Z is a live device
1274        // buffer sized p*k.
1275        unsafe {
1276            stream
1277                .launch_builder(&func)
1278                .arg(&seed_arg)
1279                .arg(&p_arg)
1280                .arg(&k_arg)
1281                .arg(z)
1282                .launch(cfg)
1283        }
1284        .map(|_| ())
1285        .gpu_ctx("reml_trace launch fill_rademacher")
1286    }
1287
1288    fn launch_reduce_q_dense(
1289        stream: &Arc<CudaStream>,
1290        module: &Arc<CudaModule>,
1291        p: usize,
1292        k: usize,
1293        d: usize,
1294        z: &cudarc::driver::CudaSlice<f64>,
1295        y_stack: &cudarc::driver::CudaSlice<f64>,
1296        q: &mut cudarc::driver::CudaSlice<f64>,
1297    ) -> Result<(), GpuError> {
1298        let func = module
1299            .load_function("reduce_q_dense")
1300            .gpu_ctx("reml_trace load reduce_q_dense")?;
1301        let cfg = LaunchConfig {
1302            grid_dim: (k as u32, d as u32, 1),
1303            block_dim: (THREADS_PER_BLOCK, 1, 1),
1304            shared_mem_bytes: 0,
1305        };
1306        let p_arg: u32 = p as u32;
1307        let k_arg: u32 = k as u32;
1308        let d_arg: u32 = d as u32;
1309        // SAFETY: kernel signature matches; Z is (p,K), Y_stack is (p,K*D),
1310        // Q is (D,K) row-major as documented.
1311        unsafe {
1312            stream
1313                .launch_builder(&func)
1314                .arg(&p_arg)
1315                .arg(&k_arg)
1316                .arg(&d_arg)
1317                .arg(z)
1318                .arg(y_stack)
1319                .arg(q)
1320                .launch(cfg)
1321        }
1322        .map(|_| ())
1323        .gpu_ctx("reml_trace launch reduce_q_dense")
1324    }
1325
1326    fn launch_reduce_q_weighted_gram(
1327        stream: &Arc<CudaStream>,
1328        module: &Arc<CudaModule>,
1329        n: usize,
1330        k: usize,
1331        d: usize,
1332        rz: &cudarc::driver::CudaSlice<f64>,
1333        rw: &cudarc::driver::CudaSlice<f64>,
1334        a_stack: &cudarc::driver::CudaSlice<f64>,
1335        q: &mut cudarc::driver::CudaSlice<f64>,
1336    ) -> Result<(), GpuError> {
1337        let func = module
1338            .load_function("reduce_q_weighted_gram")
1339            .gpu_ctx("reml_trace load reduce_q_weighted_gram")?;
1340        let cfg = LaunchConfig {
1341            grid_dim: (k as u32, d as u32, 1),
1342            block_dim: (THREADS_PER_BLOCK, 1, 1),
1343            shared_mem_bytes: 0,
1344        };
1345        let n_arg: u32 = n as u32;
1346        let k_arg: u32 = k as u32;
1347        let d_arg: u32 = d as u32;
1348        // SAFETY: kernel signature matches; RZ, RW are (n,K), A_stack is (n,D).
1349        unsafe {
1350            stream
1351                .launch_builder(&func)
1352                .arg(&n_arg)
1353                .arg(&k_arg)
1354                .arg(&d_arg)
1355                .arg(rz)
1356                .arg(rw)
1357                .arg(a_stack)
1358                .arg(q)
1359                .launch(cfg)
1360        }
1361        .map(|_| ())
1362        .gpu_ctx("reml_trace launch reduce_q_weighted_gram")
1363    }
1364
1365    fn copy_device_slice(
1366        stream: &Arc<CudaStream>,
1367        src: &cudarc::driver::CudaSlice<f64>,
1368        dst: &mut cudarc::driver::CudaSlice<f64>,
1369    ) -> Result<(), GpuError> {
1370        stream.memcpy_dtod(src, dst).gpu_ctx("reml_trace dtod copy")
1371    }
1372
1373    struct GemmShape {
1374        m: usize,
1375        n: usize,
1376        k_inner: usize,
1377        lda: usize,
1378        ldb: usize,
1379        ldc: usize,
1380    }
1381
1382    fn gemm_nn(
1383        blas: &CudaBlas,
1384        shape: GemmShape,
1385        a: &cudarc::driver::CudaSlice<f64>,
1386        b: &cudarc::driver::CudaSlice<f64>,
1387        c: &mut cudarc::driver::CudaSlice<f64>,
1388    ) -> Result<(), GpuError> {
1389        let GemmShape {
1390            m,
1391            n,
1392            k_inner,
1393            lda,
1394            ldb,
1395            ldc,
1396        } = shape;
1397        let cfg = GemmConfig::<f64> {
1398            transa: cublasOperation_t::CUBLAS_OP_N,
1399            transb: cublasOperation_t::CUBLAS_OP_N,
1400            m: m as i32,
1401            n: n as i32,
1402            k: k_inner as i32,
1403            alpha: 1.0,
1404            lda: lda as i32,
1405            ldb: ldb as i32,
1406            beta: 0.0,
1407            ldc: ldc as i32,
1408        };
1409        // SAFETY: dgemm with column-major leading dims documented above;
1410        // buffers a, b, c sized lda*k_inner, ldb*n, ldc*n.
1411        unsafe { blas.gemm(cfg, a, b, c) }.gpu_ctx("reml_trace cublas dgemm")
1412    }
1413}
1414
1415// ────────────────────────────────────────────────────────────────────────
1416// Shared validation + linear algebra helpers
1417// ────────────────────────────────────────────────────────────────────────
1418
1419fn validate_inputs(input: &RemlTraceHutchinsonInput<'_>) -> Result<(), String> {
1420    let (p, p2) = input.penalized_hessian.dim();
1421    if p == 0 || p != p2 {
1422        return Err(format!("reml_trace input H must be square, got {p}x{p2}"));
1423    }
1424    if input.probe_count < 2 {
1425        return Err(format!(
1426            "reml_trace requires probe_count >= 2 for a sample SE, got {}",
1427            input.probe_count
1428        ));
1429    }
1430    let needs_design = input
1431        .derivatives
1432        .iter()
1433        .any(|d| matches!(d, DerivativeHessian::WeightedGram { .. }));
1434    if needs_design && input.design.is_none() {
1435        return Err("reml_trace: structural derivative present but design=None".to_string());
1436    }
1437    let n = input.design.as_ref().map(|x| x.nrows()).unwrap_or(0);
1438    if let Some(x) = input.design.as_ref()
1439        && x.ncols() != p
1440    {
1441        return Err(format!(
1442            "reml_trace design has {} columns, expected p={p}",
1443            x.ncols()
1444        ));
1445    }
1446    for (j, derivative) in input.derivatives.iter().enumerate() {
1447        derivative
1448            .dim_p(p, n)
1449            .map_err(String::from)
1450            .map_err(|e| format!("reml_trace derivative {j}: {e}"))?;
1451    }
1452    Ok(())
1453}
1454
1455/// Compute the per-derivative sample mean and **per-probe sample standard
1456/// deviation** from the flat row-major (D, K) Q matrix. The SD uses Bessel's
1457/// correction (K-1) for the variance; it is NOT divided by `sqrt(K)`. Callers
1458/// that want the standard error of the running mean must divide by `sqrt(K)`
1459/// themselves (the stopping rule and the tests do). Returning the raw sample
1460/// SD here keeps a single convention shared with the HVP adaptive path
1461/// (which folds `var.sqrt()` directly into its `stderrs`), and avoids the
1462/// historical double `1/sqrt(K)` division that under-counted the variance
1463/// band by a factor of `K`.
1464fn reduce_mean_stderr(q: &[f64], d: usize, k: usize) -> (Vec<f64>, Vec<f64>) {
1465    assert_eq!(
1466        q.len(),
1467        d * k,
1468        "reduce_mean_stderr: q buffer length {} != D*K = {}*{}",
1469        q.len(),
1470        d,
1471        k
1472    );
1473    let mut means = vec![0.0_f64; d];
1474    let mut stderrs = vec![0.0_f64; d];
1475    let inv_k = 1.0 / (k as f64);
1476    for j in 0..d {
1477        let row = &q[j * k..(j + 1) * k];
1478        let mean = row.iter().copied().sum::<f64>() * inv_k;
1479        means[j] = mean;
1480        if k >= 2 {
1481            let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / ((k - 1) as f64);
1482            stderrs[j] = var.sqrt();
1483        }
1484    }
1485    (means, stderrs)
1486}
1487
1488// ── Cholesky helpers (CPU reference only) ──────────────────────────────
1489
1490fn cholesky_lower(matrix: &Array2<f64>) -> Result<Array2<f64>, String> {
1491    let n = matrix.nrows();
1492    let mut l = Array2::<f64>::zeros((n, n));
1493    for i in 0..n {
1494        for j in 0..=i {
1495            let mut sum = matrix[[i, j]];
1496            for k in 0..j {
1497                sum -= l[[i, k]] * l[[j, k]];
1498            }
1499            if i == j {
1500                if sum <= 0.0 {
1501                    return Err(format!(
1502                        "reml_trace CPU Cholesky: non-SPD diagonal {sum} at row {i}"
1503                    ));
1504                }
1505                l[[i, j]] = sum.sqrt();
1506            } else {
1507                l[[i, j]] = sum / l[[j, j]];
1508            }
1509        }
1510    }
1511    Ok(l)
1512}
1513
1514fn solve_cholesky(l: &Array2<f64>, rhs: &[f64]) -> Vec<f64> {
1515    let n = l.nrows();
1516    let mut y = vec![0.0_f64; n];
1517    for i in 0..n {
1518        let mut sum = rhs[i];
1519        for k in 0..i {
1520            sum -= l[[i, k]] * y[k];
1521        }
1522        y[i] = sum / l[[i, i]];
1523    }
1524    let mut x = vec![0.0_f64; n];
1525    for i in (0..n).rev() {
1526        let mut sum = y[i];
1527        for k in (i + 1)..n {
1528            sum -= l[[k, i]] * x[k];
1529        }
1530        x[i] = sum / l[[i, i]];
1531    }
1532    x
1533}
1534
1535// ────────────────────────────────────────────────────────────────────────
1536// Tests
1537// ────────────────────────────────────────────────────────────────────────
1538
1539#[cfg(test)]
1540mod tests {
1541    use super::*;
1542    use ndarray::{Array2, ArrayView2};
1543
1544    fn make_spd(p: usize, jitter: f64) -> Array2<f64> {
1545        let mut h = Array2::<f64>::zeros((p, p));
1546        for i in 0..p {
1547            for j in 0..p {
1548                h[[i, j]] = if i == j {
1549                    p as f64 + jitter
1550                } else {
1551                    1.0 / (1.0 + (i as f64 - j as f64).abs())
1552                };
1553            }
1554        }
1555        h
1556    }
1557
1558    fn random_dense_sym(p: usize, seed: u64) -> Array2<f64> {
1559        let mut a = Array2::<f64>::zeros((p, p));
1560        let mut s = seed;
1561        for i in 0..p {
1562            for j in i..p {
1563                s = splitmix64_mix(s.wrapping_add(1));
1564                let v = ((s >> 11) as f64) / ((1u64 << 53) as f64) - 0.5;
1565                a[[i, j]] = v;
1566                a[[j, i]] = v;
1567            }
1568        }
1569        a
1570    }
1571
1572    fn exact_trace_hinv_a(h: ArrayView2<f64>, a: ArrayView2<f64>) -> f64 {
1573        let p = h.nrows();
1574        let factor = cholesky_lower(&h.to_owned()).expect("SPD");
1575        let mut trace = 0.0;
1576        for col in 0..p {
1577            let mut e = vec![0.0_f64; p];
1578            e[col] = 1.0;
1579            let w = solve_cholesky(&factor, &e);
1580            // (H^{-1} A) diag entry [col, col] = sum_i A[col, i] * w[i]
1581            let mut diag = 0.0;
1582            for i in 0..p {
1583                diag += a[[col, i]] * w[i];
1584            }
1585            trace += diag;
1586        }
1587        trace
1588    }
1589
1590    #[test]
1591    fn splitmix_is_deterministic_and_disperses() {
1592        // Self-consistency: same input → same output, and a few near-by
1593        // inputs land in distinct buckets (no trivial collisions).
1594        assert_eq!(splitmix64_mix(42), splitmix64_mix(42));
1595        let mut bits_seen = 0u64;
1596        for x in 0u64..64 {
1597            bits_seen |= splitmix64_mix(x);
1598        }
1599        assert_eq!(
1600            bits_seen,
1601            u64::MAX,
1602            "splitmix should cover every bit position across 64 inputs"
1603        );
1604    }
1605
1606    #[test]
1607    fn rademacher_entries_are_pm_one_and_stateless() {
1608        let seed = ProbeSeed(0xCAFE_BABE);
1609        for k in 0..16u64 {
1610            for i in 0..32u64 {
1611                let v = rademacher_entry(seed.0, k, i);
1612                assert!(
1613                    v == 1.0 || v == -1.0,
1614                    "non-pm1 entry at (k={k}, i={i}): {v}"
1615                );
1616                let v2 = rademacher_entry(seed.0, k, i);
1617                assert_eq!(v, v2, "same (k,i) must hash to same value");
1618            }
1619        }
1620    }
1621
1622    #[test]
1623    fn rademacher_common_random_numbers_match_for_prefix() {
1624        // First 16 probes of a K=16 run must equal first 16 probes of K=32.
1625        let p = 50;
1626        let mut z16 = vec![0.0_f64; p * 16];
1627        let mut z32 = vec![0.0_f64; p * 32];
1628        fill_rademacher_host(ProbeSeed(7), p, 16, &mut z16);
1629        fill_rademacher_host(ProbeSeed(7), p, 32, &mut z32);
1630        for col in 0..16 {
1631            for row in 0..p {
1632                assert_eq!(
1633                    z16[col * p + row],
1634                    z32[col * p + row],
1635                    "CRN broken at (col={col}, row={row})"
1636                );
1637            }
1638        }
1639    }
1640
1641    #[test]
1642    fn cpu_hutchinson_unbiased_against_exact_small_spd() {
1643        let p = 16;
1644        let h = make_spd(p, 0.5);
1645        let a1 = random_dense_sym(p, 0x1234);
1646        let a2 = random_dense_sym(p, 0x5678);
1647        let exact1 = exact_trace_hinv_a(h.view(), a1.view());
1648        let exact2 = exact_trace_hinv_a(h.view(), a2.view());
1649        let input = RemlTraceHutchinsonInput {
1650            penalized_hessian: h.view(),
1651            derivatives: vec![
1652                DerivativeHessian::Dense(a1.view()),
1653                DerivativeHessian::Dense(a2.view()),
1654            ],
1655            design: None,
1656            probe_count: 4096,
1657            seed: ProbeSeed(0xCAFE_BABE),
1658        };
1659        let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
1660        // gradient = 0.5 * trace, so multiply estimate by 2 for the trace.
1661        let est1 = 2.0 * evidence.gradient_rho_logdet[0];
1662        let est2 = 2.0 * evidence.gradient_rho_logdet[1];
1663        // `gradient_rho_stderr` is the per-probe sample SD of the half-scaled
1664        // gradient; the trace SE of the K-probe mean is `2·sd/sqrt(K)`.
1665        let sqrt_k = (evidence.probe_count as f64).sqrt();
1666        let se1 = 2.0 * evidence.gradient_rho_stderr[0] / sqrt_k;
1667        let se2 = 2.0 * evidence.gradient_rho_stderr[1] / sqrt_k;
1668        let tol1 = 6.0 * se1.max(1e-8);
1669        let tol2 = 6.0 * se2.max(1e-8);
1670        assert!(
1671            (est1 - exact1).abs() <= tol1,
1672            "Hutchinson est {est1} too far from exact {exact1} (tol={tol1}, se={})",
1673            evidence.gradient_rho_stderr[0]
1674        );
1675        assert!(
1676            (est2 - exact2).abs() <= tol2,
1677            "Hutchinson est {est2} too far from exact {exact2} (tol={tol2})"
1678        );
1679    }
1680
1681    #[test]
1682    fn structural_path_matches_dense_for_xtwx() {
1683        // Build H_j = X^T diag(a) X exactly; both the dense and the
1684        // structural descriptor must produce the same q value per probe.
1685        let n = 40;
1686        let p = 8;
1687        let mut x = Array2::<f64>::zeros((n, p));
1688        let mut s = 11u64;
1689        for r in 0..n {
1690            for c in 0..p {
1691                s = splitmix64_mix(s.wrapping_add(1));
1692                x[[r, c]] = ((s >> 11) as f64) / ((1u64 << 53) as f64) - 0.5;
1693            }
1694        }
1695        let a: Vec<f64> = (0..n).map(|i| 0.5 + 0.01 * (i as f64)).collect();
1696        let a_arr = ndarray::Array1::from(a);
1697        // H_j dense
1698        let mut hj_dense = Array2::<f64>::zeros((p, p));
1699        for r in 0..p {
1700            for c in 0..p {
1701                let mut acc = 0.0;
1702                for i in 0..n {
1703                    acc += x[[i, r]] * a_arr[i] * x[[i, c]];
1704                }
1705                hj_dense[[r, c]] = acc;
1706            }
1707        }
1708        // SPD H so the solve is well posed.
1709        let mut h = make_spd(p, 1.0);
1710        for i in 0..p {
1711            h[[i, i]] += 1.0;
1712        }
1713        let input_dense = RemlTraceHutchinsonInput {
1714            penalized_hessian: h.view(),
1715            derivatives: vec![DerivativeHessian::Dense(hj_dense.view())],
1716            design: None,
1717            probe_count: 32,
1718            seed: ProbeSeed(123),
1719        };
1720        let input_struct = RemlTraceHutchinsonInput {
1721            penalized_hessian: h.view(),
1722            derivatives: vec![DerivativeHessian::WeightedGram {
1723                row_weights: a_arr.view(),
1724                penalty_extra: None,
1725            }],
1726            design: Some(x.view()),
1727            probe_count: 32,
1728            seed: ProbeSeed(123),
1729        };
1730        let e_dense = evidence_derivatives_hutchinson_cpu(&input_dense).expect("ok");
1731        let e_struct = evidence_derivatives_hutchinson_cpu(&input_struct).expect("ok");
1732        // Same probes, same H_j ⇒ identical estimator (modulo round-off).
1733        assert!(
1734            (e_dense.gradient_rho_logdet[0] - e_struct.gradient_rho_logdet[0]).abs() < 1e-9,
1735            "dense vs structural mismatch: dense={}, struct={}",
1736            e_dense.gradient_rho_logdet[0],
1737            e_struct.gradient_rho_logdet[0]
1738        );
1739    }
1740
1741    #[test]
1742    fn finite_difference_check_against_logdet() {
1743        // For H(rho) = H0 + rho * A, d/d(rho) log|H| = tr(H^{-1} A).
1744        let p = 10;
1745        let h0 = make_spd(p, 0.2);
1746        let a = random_dense_sym(p, 0xABCD);
1747        let eps = 1e-4;
1748        let mut hp = h0.clone();
1749        let mut hm = h0.clone();
1750        for i in 0..p {
1751            for j in 0..p {
1752                hp[[i, j]] += eps * a[[i, j]];
1753                hm[[i, j]] -= eps * a[[i, j]];
1754            }
1755        }
1756        let ld = |m: &Array2<f64>| -> f64 {
1757            let l = cholesky_lower(m).unwrap();
1758            2.0 * (0..p).map(|i| l[[i, i]].ln()).sum::<f64>()
1759        };
1760        let fd = (ld(&hp) - ld(&hm)) / (2.0 * eps);
1761        let exact = exact_trace_hinv_a(h0.view(), a.view());
1762        assert!(
1763            (fd - exact).abs() / exact.abs().max(1e-12) < 1e-6,
1764            "FD logdet derivative {fd} != exact trace {exact}"
1765        );
1766        // And Hutchinson should land near 0.5 * exact (the gradient form).
1767        let input = RemlTraceHutchinsonInput {
1768            penalized_hessian: h0.view(),
1769            derivatives: vec![DerivativeHessian::Dense(a.view())],
1770            design: None,
1771            probe_count: 4096,
1772            seed: ProbeSeed(0xAA55),
1773        };
1774        let evidence = evidence_derivatives_hutchinson_cpu(&input).expect("ok");
1775        // SE of the half-scaled gradient mean = (per-probe sample SD) / sqrt(K).
1776        let se = evidence.gradient_rho_stderr[0] / (evidence.probe_count as f64).sqrt();
1777        let tol = 8.0 * se.max(1e-8);
1778        assert!(
1779            (evidence.gradient_rho_logdet[0] - 0.5 * exact).abs() < tol,
1780            "Hutchinson gradient {} not within 8·SE of 0.5·exact={}",
1781            evidence.gradient_rho_logdet[0],
1782            0.5 * exact
1783        );
1784    }
1785
1786    #[test]
1787    fn gate_rejects_below_min_p() {
1788        assert!(!should_use_gpu_hutchinson(64, 16, true, true, true, false));
1789    }
1790
1791    #[test]
1792    fn gate_rejects_k_out_of_range() {
1793        assert!(!should_use_gpu_hutchinson(2000, 4, true, true, true, false));
1794        assert!(!should_use_gpu_hutchinson(
1795            2000, 200, true, true, true, false
1796        ));
1797    }
1798
1799    #[test]
1800    fn gate_rejects_when_subspace_active() {
1801        assert!(!should_use_gpu_hutchinson(2000, 16, true, true, true, true));
1802    }
1803
1804    #[test]
1805    fn gate_accepts_canonical_case() {
1806        assert!(should_use_gpu_hutchinson(2000, 16, true, true, true, false));
1807    }
1808
1809    // ────────────────────────────────────────────────────────────────
1810    // Block 2.6: adaptive-K validation tests.
1811    //
1812    // All five run on CPU hosts (where `evidence_derivatives_hutchinson_gpu`
1813    // falls back to the SplitMix CPU reference) and on V100 hosts (where the
1814    // CUDA path takes over). Probe-level CRN is preserved across both paths.
1815    // ────────────────────────────────────────────────────────────────
1816
1817    #[test]
1818    fn block_2_6_adaptive_unbiased_against_exact_p512() {
1819        // (1) Adaptive Hutchinson with the default ε must land near the
1820        // exact `tr(H⁻¹ A)` within its reported stopping tolerance.
1821        let p = 64;
1822        let h = make_spd(p, 0.5);
1823        let a = random_dense_sym(p, 0xBADC0DE);
1824        let exact = exact_trace_hinv_a(h.view(), a.view());
1825        let evidence = evidence_traces_adaptive(
1826            h.view(),
1827            vec![DerivativeHessian::Dense(a.view())],
1828            None,
1829            ProbeSeed(0xA5A5A5),
1830            HUTCHINSON_ADAPTIVE_REL_TOL,
1831            HUTCHINSON_ADAPTIVE_TAU_REL,
1832        )
1833        .expect("adaptive run ok");
1834        let est = evidence.traces[0];
1835        let se = evidence.stderrs[0] / (evidence.probe_count as f64).sqrt();
1836        let tol = (8.0 * se).max(0.05 * exact.abs());
1837        assert!(
1838            (est - exact).abs() <= tol,
1839            "adaptive est {est} far from exact {exact} (tol={tol}, se={se}, K={})",
1840            evidence.probe_count
1841        );
1842    }
1843
1844    #[test]
1845    fn block_2_6_same_probes_cpu_vs_dispatch() {
1846        // (2) The dispatch entry (`_gpu`) and the explicit CPU reference
1847        // must produce identical estimates when given the same probes.
1848        // The dispatcher falls back to the CPU reference on non-CUDA hosts,
1849        // so this is a tautology on CPU; on V100 it asserts bit-identical
1850        // q-values across paths (the `q_{j,k}=z_k^T H_j w_k` reduction is
1851        // deterministic to machine precision once probes match).
1852        let p = 32;
1853        let h = make_spd(p, 0.3);
1854        let a = random_dense_sym(p, 0x1357);
1855        let input = RemlTraceHutchinsonInput {
1856            penalized_hessian: h.view(),
1857            derivatives: vec![DerivativeHessian::Dense(a.view())],
1858            design: None,
1859            probe_count: 16,
1860            seed: ProbeSeed(0xBEEF),
1861        };
1862        let cpu = evidence_derivatives_hutchinson_cpu(&input).expect("cpu");
1863        let dispatch = evidence_derivatives_hutchinson_gpu(input).expect("dispatch");
1864        let diff = (cpu.gradient_rho_logdet[0] - dispatch.gradient_rho_logdet[0]).abs();
1865        assert!(
1866            diff < 1e-9,
1867            "same-probes CPU vs GPU dispatch differ: cpu={}, dispatch={}, diff={diff}",
1868            cpu.gradient_rho_logdet[0],
1869            dispatch.gradient_rho_logdet[0]
1870        );
1871    }
1872
1873    #[test]
1874    fn block_2_6_fd_logdet_matches_adaptive() {
1875        // (3) Adaptive estimate of `tr(H⁻¹ A)` should agree with the
1876        // central-difference derivative `d/dρ log|H + ρA|` at ρ=0.
1877        let p = 24;
1878        let h = make_spd(p, 0.4);
1879        let a = random_dense_sym(p, 0x2468);
1880        let eps = 1e-4;
1881        let mut hp = h.clone();
1882        let mut hm = h.clone();
1883        for i in 0..p {
1884            for j in 0..p {
1885                hp[[i, j]] += eps * a[[i, j]];
1886                hm[[i, j]] -= eps * a[[i, j]];
1887            }
1888        }
1889        let ld = |m: &Array2<f64>| -> f64 {
1890            let l = cholesky_lower(m).expect("SPD");
1891            2.0 * (0..p).map(|i| l[[i, i]].ln()).sum::<f64>()
1892        };
1893        let fd = (ld(&hp) - ld(&hm)) / (2.0 * eps);
1894        let evidence = evidence_traces_adaptive(
1895            h.view(),
1896            vec![DerivativeHessian::Dense(a.view())],
1897            None,
1898            ProbeSeed(0x9999),
1899            HUTCHINSON_ADAPTIVE_REL_TOL,
1900            HUTCHINSON_ADAPTIVE_TAU_REL,
1901        )
1902        .expect("adaptive ok");
1903        let est = evidence.traces[0];
1904        let se = evidence.stderrs[0] / (evidence.probe_count as f64).sqrt();
1905        let tol = (8.0 * se).max(0.05 * fd.abs());
1906        assert!(
1907            (est - fd).abs() <= tol,
1908            "adaptive trace {est} disagrees with FD logdet derivative {fd} (tol={tol})"
1909        );
1910    }
1911
1912    #[test]
1913    fn block_2_6_k_4096_matches_exact_tightly() {
1914        // (4) A large fixed K (4096 probes) — well past the adaptive
1915        // schedule's max — must drive the Hutchinson estimator to within
1916        // a few SE of exact. Bounds the residual variance and confirms
1917        // the estimator is consistent (not merely unbiased at small K).
1918        let p = 40;
1919        let h = make_spd(p, 0.6);
1920        let a = random_dense_sym(p, 0xDEAD);
1921        let exact = exact_trace_hinv_a(h.view(), a.view());
1922        let input = RemlTraceHutchinsonInput {
1923            penalized_hessian: h.view(),
1924            derivatives: vec![DerivativeHessian::Dense(a.view())],
1925            design: None,
1926            probe_count: 4096,
1927            seed: ProbeSeed(0xC0FFEE),
1928        };
1929        let evidence = evidence_derivatives_hutchinson_gpu(input).expect("ok");
1930        let est = 2.0 * evidence.gradient_rho_logdet[0];
1931        let se = 2.0 * evidence.gradient_rho_stderr[0] / (4096_f64).sqrt();
1932        let tol = (6.0 * se).max(1e-3 * exact.abs());
1933        assert!(
1934            (est - exact).abs() <= tol,
1935            "K=4096 Hutchinson {est} not within 6·SE of exact {exact} (tol={tol}, se={se})"
1936        );
1937    }
1938
1939    #[test]
1940    fn block_2_6_crn_prefix_match_across_schedule() {
1941        // (5) Common-random-numbers: the first 16 probes of a K=32 (and
1942        // K=64) draw must be bit-identical to a K=16 draw with the same
1943        // seed. The SplitMix probe RNG is stateless in (seed, k, i), so
1944        // this is what guarantees the adaptive schedule's variance
1945        // monotonically *decreases* rather than oscillating.
1946        let p = 50;
1947        let seed = ProbeSeed(0x4242_4242);
1948        let mut z16 = vec![0.0_f64; p * 16];
1949        let mut z32 = vec![0.0_f64; p * 32];
1950        let mut z64 = vec![0.0_f64; p * 64];
1951        fill_rademacher_host(seed, p, 16, &mut z16);
1952        fill_rademacher_host(seed, p, 32, &mut z32);
1953        fill_rademacher_host(seed, p, 64, &mut z64);
1954        for col in 0..16 {
1955            for row in 0..p {
1956                assert_eq!(z16[col * p + row], z32[col * p + row]);
1957                assert_eq!(z16[col * p + row], z64[col * p + row]);
1958            }
1959        }
1960        for col in 0..32 {
1961            for row in 0..p {
1962                assert_eq!(z32[col * p + row], z64[col * p + row]);
1963            }
1964        }
1965    }
1966
1967    // ────────────────────────────────────────────────────────────────
1968    // Block 2.7: batched-PCG HVP variant tests.
1969    // ────────────────────────────────────────────────────────────────
1970
1971    #[test]
1972    fn block_2_7_hvp_path_matches_dense_adaptive() {
1973        // HVP closure that multiplies a stored dense H matches the
1974        // dense `evidence_traces_adaptive` exactly (same CRN probes,
1975        // same derivative). CG round-off bounded by PCG_HVP_REL_TOL.
1976        let p = 40;
1977        let h = make_spd(p, 0.7);
1978        let a = random_dense_sym(p, 0xABBA);
1979        let seed = ProbeSeed(0x707);
1980
1981        let dense = evidence_traces_adaptive(
1982            h.view(),
1983            vec![DerivativeHessian::Dense(a.view())],
1984            None,
1985            seed,
1986            HUTCHINSON_ADAPTIVE_REL_TOL,
1987            HUTCHINSON_ADAPTIVE_TAU_REL,
1988        )
1989        .expect("dense ok");
1990
1991        let h_clone = h.clone();
1992        let hvp_evidence = evidence_traces_adaptive_hvp(
1993            p,
1994            |v: &[f64], out: &mut [f64]| {
1995                for r in 0..p {
1996                    let mut acc = 0.0_f64;
1997                    for c in 0..p {
1998                        acc += h_clone[[r, c]] * v[c];
1999                    }
2000                    out[r] = acc;
2001                }
2002            },
2003            vec![DerivativeHessian::Dense(a.view())],
2004            None,
2005            seed,
2006            HUTCHINSON_ADAPTIVE_REL_TOL,
2007            HUTCHINSON_ADAPTIVE_TAU_REL,
2008        )
2009        .expect("hvp ok");
2010
2011        // Adaptive may stop at different K if SE crosses the threshold
2012        // at a different step due to CG round-off; compare both
2013        // estimates against exact rather than to each other.
2014        let exact = exact_trace_hinv_a(h.view(), a.view());
2015        let se_dense = dense.stderrs[0] / (dense.probe_count as f64).sqrt();
2016        let se_hvp = hvp_evidence.stderrs[0] / (hvp_evidence.probe_count as f64).sqrt();
2017        let tol_dense = (8.0 * se_dense).max(0.05 * exact.abs());
2018        let tol_hvp = (8.0 * se_hvp).max(0.05 * exact.abs());
2019        assert!(
2020            (dense.traces[0] - exact).abs() <= tol_dense,
2021            "dense adaptive {} not near exact {} (tol {})",
2022            dense.traces[0],
2023            exact,
2024            tol_dense
2025        );
2026        assert!(
2027            (hvp_evidence.traces[0] - exact).abs() <= tol_hvp,
2028            "hvp adaptive {} not near exact {} (tol {})",
2029            hvp_evidence.traces[0],
2030            exact,
2031            tol_hvp
2032        );
2033        // logdet is intentionally NaN on the HVP path.
2034        assert!(hvp_evidence.logdet_hessian.is_nan());
2035    }
2036
2037    #[test]
2038    fn block_2_7_hvp_stderr_matches_dense_reduce_mean_stderr() {
2039        // The HVP path's `stderrs` must use the SAME estimator convention as
2040        // the dense path's `reduce_mean_stderr`: the Bessel-corrected (K−1)
2041        // sample standard deviation of the per-probe q values. We force both
2042        // paths to run the full K=128 schedule (rel_tol below any achievable
2043        // ratio) so the comparison is at identical probe counts on identical
2044        // CRN probes. The only residual difference is the inner solve (exact
2045        // Cholesky vs CG@1e-6), which keeps the q values — and hence the SDs —
2046        // agreeing to a tight relative tolerance.
2047        let p = 36;
2048        let h = make_spd(p, 0.6);
2049        let a = random_dense_sym(p, 0x5151);
2050        let seed = ProbeSeed(0xBEEF);
2051        let force_full_schedule = 1e-12_f64;
2052
2053        let dense = evidence_traces_adaptive(
2054            h.view(),
2055            vec![DerivativeHessian::Dense(a.view())],
2056            None,
2057            seed,
2058            force_full_schedule,
2059            HUTCHINSON_ADAPTIVE_TAU_REL,
2060        )
2061        .expect("dense ok");
2062
2063        let h_clone = h.clone();
2064        let hvp = evidence_traces_adaptive_hvp(
2065            p,
2066            |v: &[f64], out: &mut [f64]| {
2067                for r in 0..p {
2068                    let mut acc = 0.0_f64;
2069                    for c in 0..p {
2070                        acc += h_clone[[r, c]] * v[c];
2071                    }
2072                    out[r] = acc;
2073                }
2074            },
2075            vec![DerivativeHessian::Dense(a.view())],
2076            None,
2077            seed,
2078            force_full_schedule,
2079            HUTCHINSON_ADAPTIVE_TAU_REL,
2080        )
2081        .expect("hvp ok");
2082
2083        // Both ran the full schedule, so probe counts match exactly.
2084        assert_eq!(dense.probe_count, 128);
2085        assert_eq!(hvp.probe_count, dense.probe_count);
2086
2087        let sd_dense = dense.stderrs[0];
2088        let sd_hvp = hvp.stderrs[0];
2089        assert!(
2090            sd_dense > 0.0,
2091            "dense SD should be positive, got {sd_dense}"
2092        );
2093        let rel = (sd_hvp - sd_dense).abs() / sd_dense;
2094        assert!(
2095            rel <= 1e-3,
2096            "HVP SD {sd_hvp} disagrees with dense reduce_mean_stderr SD {sd_dense} \
2097             (rel {rel}); the two paths must share the Bessel-corrected (K−1) convention"
2098        );
2099    }
2100
2101    #[test]
2102    fn block_2_7_cg_solves_diagonal_in_one_iteration() {
2103        // For diagonal H, CG converges in one step (Krylov subspace
2104        // contains the exact solution). Verifies the CG residual
2105        // logic and SPD bailout.
2106        let p = 8;
2107        let diag: Vec<f64> = (0..p).map(|i| 1.0 + i as f64).collect();
2108        let b: Vec<f64> = (0..p).map(|i| (i as f64) + 0.5).collect();
2109        let mut w = vec![0.0_f64; p];
2110        let diag_clone = diag.clone();
2111        cg_solve(
2112            &mut |v: &[f64], out: &mut [f64]| {
2113                for i in 0..p {
2114                    out[i] = diag_clone[i] * v[i];
2115                }
2116            },
2117            &b,
2118            &mut w,
2119            1e-12,
2120            PCG_HVP_MAX_ITERS,
2121        );
2122        for i in 0..p {
2123            let expected = b[i] / diag[i];
2124            assert!(
2125                (w[i] - expected).abs() < 1e-10,
2126                "diagonal CG: w[{i}]={} expected {expected}",
2127                w[i]
2128            );
2129        }
2130    }
2131
2132    // ────────────────────────────────────────────────────────────────
2133    // Block 2.8: V100 hill-climb (10× vs exact GPU at p=2000, d_ρ=8).
2134    //
2135    // The assertion only fires when a CUDA runtime is detected;
2136    // on CPU-only hosts the test still runs the timing comparison but
2137    // skips the speedup assertion (exact dense Cholesky is competitive
2138    // with adaptive Hutchinson on a single core, so the 10× lower bound
2139    // is V100-specific). On V100, the adaptive path batches K=16-128
2140    // probes through one potrs while the exact path repeats `d` full
2141    // solves; the bound is therefore comfortable.
2142    // ────────────────────────────────────────────────────────────────
2143
2144    #[test]
2145    fn block_2_8_hill_climb_adaptive_vs_exact_at_p2000_d8() {
2146        // Smaller dimensions on CPU CI to keep the test under a minute;
2147        // V100 runs the full p=2000, d=8 specified in the charter.
2148        let on_v100 =
2149            cfg!(target_os = "linux") && gam_gpu::device_runtime::GpuRuntime::global().is_some();
2150        let (p, d): (usize, usize) = if on_v100 { (2000, 8) } else { (256, 4) };
2151
2152        let mut h = Array2::<f64>::zeros((p, p));
2153        for i in 0..p {
2154            for j in 0..p {
2155                h[[i, j]] = if i == j {
2156                    p as f64 + 1.0
2157                } else {
2158                    1.0 / (1.0 + (i as f64 - j as f64).abs())
2159                };
2160            }
2161        }
2162        let derivs_owned: Vec<Array2<f64>> = (0..d)
2163            .map(|k| random_dense_sym(p, 0x1000 + k as u64))
2164            .collect();
2165        let derivs: Vec<DerivativeHessian<'_>> = derivs_owned
2166            .iter()
2167            .map(|a| DerivativeHessian::Dense(a.view()))
2168            .collect();
2169
2170        // Exact path: factor H once, then `tr(H⁻¹ A_j) = Σᵢ (H⁻¹ A_j)[i,i]`
2171        // by solving H X = A_j column-by-column. This is the cost the
2172        // CPU/exact-spectral path pays per REML outer step.
2173        let t_exact_start = std::time::Instant::now();
2174        let factor = cholesky_lower(&h).expect("SPD");
2175        let mut exact_traces = vec![0.0_f64; d];
2176        for (j, a) in derivs_owned.iter().enumerate() {
2177            let mut acc = 0.0_f64;
2178            for col in 0..p {
2179                let mut rhs = vec![0.0_f64; p];
2180                for r in 0..p {
2181                    rhs[r] = a[[r, col]];
2182                }
2183                let w = solve_cholesky(&factor, &rhs);
2184                acc += w[col];
2185            }
2186            exact_traces[j] = acc;
2187        }
2188        let t_exact = t_exact_start.elapsed();
2189
2190        // Adaptive Hutchinson path.
2191        let t_adaptive_start = std::time::Instant::now();
2192        let evidence = evidence_traces_adaptive(
2193            h.view(),
2194            derivs,
2195            None,
2196            ProbeSeed(0xB10C),
2197            HUTCHINSON_ADAPTIVE_REL_TOL,
2198            HUTCHINSON_ADAPTIVE_TAU_REL,
2199        )
2200        .expect("adaptive ok");
2201        let t_adaptive = t_adaptive_start.elapsed();
2202
2203        // Sanity: every adaptive trace must agree with exact within its
2204        // reported SE. This guards against a fast-but-wrong perf path.
2205        for j in 0..d {
2206            let se = evidence.stderrs[j] / (evidence.probe_count as f64).sqrt();
2207            let tol = (10.0 * se).max(0.05 * exact_traces[j].abs());
2208            let diff = (evidence.traces[j] - exact_traces[j]).abs();
2209            assert!(
2210                diff <= tol,
2211                "block_2_8: derivative {j} adaptive {} disagrees with exact {} (tol {tol}, diff {diff})",
2212                evidence.traces[j],
2213                exact_traces[j]
2214            );
2215        }
2216
2217        let speedup = t_exact.as_secs_f64() / t_adaptive.as_secs_f64().max(1e-9);
2218        eprintln!(
2219            "block_2_8 hill-climb [p={p}, d={d}, V100={on_v100}]: \
2220             exact={:?}, adaptive={:?}, speedup={:.2}× (K={}, converged={})",
2221            t_exact, t_adaptive, speedup, evidence.probe_count, evidence.converged
2222        );
2223        if on_v100 {
2224            assert!(
2225                speedup >= 10.0,
2226                "block_2_8 V100 speedup {speedup:.2}× below the 10× target \
2227                 (exact {:?}, adaptive {:?})",
2228                t_exact,
2229                t_adaptive,
2230            );
2231        }
2232    }
2233}