Skip to main content

gam_models/bms/gpu/
row.rs

1//! Stage 2 of the BMS FLEX row kernel — per-row math that turns per-cell
2//! derivative moments (built by Stage 1 in `src/gpu/cubic_cell/mod.rs`) into a
3//! row gradient and row-primary `r × r` Hessian.
4//!
5//! Math (mirrors the CPU reference
6//! `BernoulliMarginalSlope::compute_row_analytic_flex_from_parts_into` in
7//! `src/families/bernoulli_marginal_slope.rs`):
8//!
9//! For each row `i`, with per-cell cubic predictor coefficients
10//! `C_c = (C0, C1, C2, C3)` and derivative moments `m_0..m_9`, build
11//!
12//! ```text
13//!     κ        = 1 / (2π)
14//!     T_n      = κ · Σ_{e=0..3} C_e · m_{e+n}     (n = 0..6)
15//!     D(R)     = κ · Σ_{k=0..3} R_k · m_k
16//!     Q(R, S)  = Σ_{p,q=0..3} R_p · S_q · T_{p+q}
17//!     H(R, S, U) = D(U) − Q(R, S)
18//! ```
19//!
20//! Per cell `c`, accumulate into row scratch:
21//!
22//! ```text
23//!     F_a   += D(A_c)
24//!     F_aa  += H(A_c, A_c, AA_c)
25//!     F_u   += D(R_{c,u})                         u > 0
26//!     F_au  += H(A_c, R_{c,u}, AR_{c,u})          u > 0
27//!     F_uv  += H(R_{c,u}, R_{c,v}, S_{c,uv})      0 < u ≤ v
28//! ```
29//!
30//! After the cell sum, the `q`-row is overridden:
31//!
32//! ```text
33//!     F_q  = −mu_1
34//!     F_qq = −mu_2
35//!     F_qv = 0   (v > 0)
36//!     F_aq = 0
37//! ```
38//!
39//! Implicit function theorem (single `1/F_a`):
40//!
41//! ```text
42//!     inv_Fa = 1 / F_a
43//!     a_u    = −F_u · inv_Fa                       (q-row override: mu_1 · inv_Fa)
44//!     a_uv   = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) · inv_Fa
45//! ```
46//!
47//! Observed predictor at `z_obs` (host supplies pre-evaluated chi, xi, rho, tau,
48//! r_uv per row and coordinate):
49//!
50//! ```text
51//!     bar_e_u  = chi_obs · a_u + rho_u
52//!     bar_e_uv = chi_obs · a_uv + xi_obs · a_u · a_v + tau_u · a_v
53//!                + a_u · tau_v + r_uv
54//! ```
55//!
56//! Probit Mills (stable; uses `log_ndtr_and_mills` from `numerics_device::PROBIT_NUMERICS_CU`):
57//!
58//! ```text
59//!     s = 2y − 1 ;  m = s · e_obs
60//!     [log_cdf, λ] = log_ndtr_and_mills(m)
61//!     A = −w · s · λ
62//!     B =  w · λ · (m + λ)
63//! ```
64//!
65//! Final outputs:
66//!
67//! ```text
68//!     neglog   = −w · log_cdf
69//!     g_u      = A · bar_e_u
70//!     H_{uv}   = B · bar_e_u · bar_e_v + A · bar_e_uv     (symmetric)
71//! ```
72//!
73//! Implementation choice (Stage 2): **one CUDA block per row**, with
74//! `blockDim.x = 32` threads. The block's `F_u`, `F_au`, `F_uv`, `bar_e_u`,
75//! `bar_e_uv` live in shared memory; threads in the block parallelise the
76//! per-cell sums, then a single thread of the block (`threadIdx.x == 0`) does
77//! the IFT solve, the observed-point assembly, the Mills evaluation, and the
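//! final gradient + Hessian write-out. With the `r ≤ MAX_R` cap (32) the
//! shared-memory footprint per block is `r + r + r*r + r + r*r` doubles
//! = `2r² + 3r` ≤ 2 144 doubles ≈ 17 KB, well below the V100 48 KB per-block
//! limit. (The kernel declares these arrays statically at the worst-case
//! `r = 32` and additionally keeps two 32-slot reduction arrays plus two
//! scalars, i.e. 2 210 doubles ≈ 17.3 KB regardless of the runtime `r`.)
//! This keeps the implementation simple and avoids per-thread global
//! scratch (a per-thread `r*r` scratch arena would be ~2 GB at n=195k, r=20).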
83
84#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
/// Hard ceiling on `r` (= 2 + p_h + p_w). Matches the shared-memory budget
/// argument in the module docstring: with `MAX_R = 32` the per-block shared
/// footprint is at most `2·32² + 3·32 = 2 144` doubles = 17 KB.
///
/// The CUDA text in `ROW_KERNEL_BODY` spells this bound as the literal `32`
/// in its `__shared__` and thread-local array extents, so the two must be
/// changed together.
pub(crate) const MAX_R: usize = 32;
99
/// `blockDim.x` for the row kernel. Threads of a row-block parallelise the
/// per-cell loop; thread 0 of the block finalises the IFT solve. Linux-only
/// because the kernel launcher that consumes it is Linux-only.
///
/// The kernel's `reduce_a` / `reduce_b` arrays have 32 slots and its tree
/// reduction starts at `blockDim.x / 2` and halves each step, so this value
/// must remain a power of two no larger than 32.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_THREADS: u32 = 32;
105
/// Number of cubic predictor coefficients per cell (`C0..C3`) and the matching
/// support length of `A_c`, `R_{c,u}`, `AA_c`, `AR_{c,u}`, `S_{c,uv}`.
///
/// The device kernel indexes these payloads with a literal `* 4`, so this
/// constant only governs the host-side length checks.
pub(crate) const COEFF4: usize = 4;
109
/// Number of derivative moments stored per cell (`m_0..m_9`). `T_n` uses
/// `m_{n+e}` for `e = 0..3` and `n = 0..6`, so the highest index touched is
/// `9` and the per-cell stride is `10`.
///
/// The device kernel addresses a cell's moments with a literal `* 10`.
pub(crate) const MOMENT_STRIDE: usize = 10;
113
/// Source of the per-cell derivative moments fed into the row kernel.
/// Phase-4 wiring: the substrate at `src/gpu/cubic_cell/mod.rs` can produce
/// these on the GPU; this enum lets the launcher consume them directly
/// without a DtoH+HtoD round-trip.
///
/// Both variants borrow their buffer for `'a`; the enum never owns the
/// moments.
pub(crate) enum CellMomentsSource<'a> {
    /// Host-resident `[total_cells, MOMENT_STRIDE = 10]` row-major buffer.
    /// The launcher will HtoD-upload this on every launch.
    Host(&'a [f64]),
    /// Device-resident moments already living on the row-kernel backend's
    /// default stream (which is the same `cuda_context_for(ordinal).default_stream()`
    /// the cubic-cell substrate uses, so no cross-context copy is needed).
    /// Length on the device must be `total_cells * MOMENT_STRIDE`. Linux-only.
    #[cfg(target_os = "linux")]
    Device(&'a CudaSlice<f64>),
}
129
130impl<'a> CellMomentsSource<'a> {
131    /// Logical element count of the moments source, used by [`BmsFlexRowKernelInputs::validate`].
132    pub(crate) fn len(&self) -> usize {
133        match self {
134            CellMomentsSource::Host(slice) => slice.len(),
135            #[cfg(target_os = "linux")]
136            CellMomentsSource::Device(d) => d.len(),
137        }
138    }
139}
140
// Declares, from one shared field schema, the borrowed launch bundle
// `BmsFlexRowKernelInputs<'a>`, its owned twin `BmsFlexRowKernelInputsOwned`,
// and the `as_borrowed` conversion between them.
//
// Schema:
//   * `f64_fields`    — names of the `f64` buffers (`&'a [f64]` / `Vec<f64>`).
//   * `u32_fields`    — names of the `u32` buffers (`&'a [u32]` / `Vec<u32>`).
//   * `moments_field` — name of the per-cell moments field
//                       (`CellMomentsSource<'a>` / `Vec<f64>`).
//
// The scalar fields (`n_rows`, `r`, `p_h`, `p_w`, `s_f`) and the Linux-only
// `cell_moments_device` field of the owned twin are fixed by the macro.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        /// Per-row input bundle for [`launch_bms_flex_row_kernel`].
        ///
        /// Coordinate ordering convention: `u = 0` is `a` (the latent intercept and
        /// the variable IFT eliminates); `u = 1` is `b` (slope); `u = 2..2+p_h` is the
        /// score-warp `β_h` block; `u = 2+p_h..2+p_h+p_w` is the link-wiggle `β_w`
        /// block. So `r = 2 + p_h + p_w` and `u = 1` is the `b` (slope) index used by
        /// the sparse `S_{b·h}` / `S_{b·w}` payloads.
        ///
        /// NOTE(review): the device kernel applies its `q`-row override
        /// (`F_q = −mu_1`, `F_qq = −mu_2`) at index 0, so on device index 0
        /// appears to be the `q` coordinate rather than `a` — confirm which
        /// naming is intended here.
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of observation rows.
            pub n_rows: usize,
            /// Total primary local dimension. `r = 2 + p_h + p_w`.
            pub r: usize,
            /// Number of score-warp basis coordinates.
            pub p_h: usize,
            /// Number of link-wiggle basis coordinates.
            pub p_w: usize,
            /// Probit frailty scale `S_f` (scalar shared across rows; matches
            /// `BernoulliMarginalSlope::probit_frailty_scale`).
            pub s_f: f64,
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owned twin of [`BmsFlexRowKernelInputs`] — every borrowed slice is
        /// replaced by an owned `Vec`. The buffer fields are declared from the
        /// same schema as the borrowed launch ABI and converted by
        /// [`BmsFlexRowKernelInputsOwned::as_borrowed`].
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            pub $moments_field: Vec<f64>,
            /// Phase-4 device-resident moments. When `Some(_)`, the launcher
            /// skips the host upload and consumes the buffer directly.
            /// Linux-only field.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrowed view over `self` suitable for
            /// [`launch_bms_flex_row_kernel`]. The returned struct holds
            /// references into `self` so the owned bundle must outlive the
            /// launch.
            ///
            /// On Linux a present `cell_moments_device` takes precedence over
            /// the host moments vector; otherwise (and on every other
            /// platform) the host vector is borrowed.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                // The host vector is addressed through the schema parameter
                // so the macro keeps compiling if the moments field is ever
                // given a different name at the invocation site.
                #[cfg(target_os = "linux")]
                let moments = match self.cell_moments_device.as_ref() {
                    Some(d) => CellMomentsSource::Device(d),
                    None => CellMomentsSource::Host(&self.$moments_field),
                };
                #[cfg(not(target_os = "linux"))]
                let moments = CellMomentsSource::Host(&self.$moments_field);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: moments,
                }
            }
        }
    };
}
218
// Field schema for the borrowed and owned launch bundles. The lengths noted
// below are the ones enforced by `BmsFlexRowKernelInputs::validate`, with
// `n = n_rows` and `cells = cell_offsets[n]`.
define_bms_flex_row_kernel_input_types! {
    f64_fields: [
        // Per-row scalars: length `n` each.
        q,
        b,
        mu_1,
        mu_2,
        z_obs,
        y,
        w,
        e_obs,
        // Per-cell cubic predictor coefficients: length `cells` each.
        cell_c0,
        cell_c1,
        cell_c2,
        cell_c3,
        // Length `cells * COEFF4` each.
        cell_a,
        cell_aa,
        // Length `cells * (r - 1) * COEFF4` each.
        cell_r,
        cell_ar,
        // Length `cells * COEFF4`.
        cell_sbb,
        // Length `cells * p_h * COEFF4`.
        cell_sbh,
        // Length `cells * p_w * COEFF4`.
        cell_sbw,
        // Per-row scalars: length `n` each.
        chi_obs,
        xi_obs,
        // Length `n * r` each.
        rho_u,
        tau_u,
        // Length `n * r * r`.
        r_uv,
    ],
    // Length `n + 1`, non-decreasing; row `i` owns cells
    // `cell_offsets[i]..cell_offsets[i + 1]`.
    u32_fields: [cell_offsets],
    // Length `cells * MOMENT_STRIDE`.
    moments_field: cell_moments,
}
249
/// Per-row outputs produced by [`launch_bms_flex_row_kernel`].
///
/// NOTE(review): the device kernel writes NaN into all three outputs of a row
/// whose `F_a` is non-finite or non-positive; presumably those NaNs surface
/// here unchanged — confirm against the launcher.
#[derive(Debug)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// Per-row negative log-likelihood. Length `n_rows`.
    pub neglog: Vec<f64>,
    /// Per-row gradient, row-major `[n_rows, r]`.
    pub grad: Vec<f64>,
    /// Per-row Hessian, row-major `[n_rows, r*r]`. The kernel writes the full
    /// symmetric matrix.
    pub hess: Vec<f64>,
}
261
262impl<'a> BmsFlexRowKernelInputs<'a> {
263    /// Sanity-check every shape the kernel relies on. This is the only place
264    /// length errors are surfaced — the device kernel assumes valid layout.
265    pub(crate) fn validate(&self) -> Result<(), GpuError> {
266        if self.r == 0 {
267            return Err(GpuError::DriverCallFailed {
268                reason: "bms_flex_row inputs: r must be > 0".to_string(),
269            });
270        }
271        if self.r > MAX_R {
272            return Err(GpuError::DriverCallFailed {
273                reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
274            });
275        }
276        if self.r != 2 + self.p_h + self.p_w {
277            return Err(GpuError::DriverCallFailed {
278                reason: format!(
279                    "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
280                    self.r,
281                    self.p_h,
282                    self.p_w,
283                    2 + self.p_h + self.p_w
284                ),
285            });
286        }
287        let n = self.n_rows;
288        let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
289            if have != want {
290                return Err(GpuError::DriverCallFailed {
291                    reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
292                });
293            }
294            Ok(())
295        };
296        check_len("q", self.q.len(), n)?;
297        check_len("b", self.b.len(), n)?;
298        check_len("mu_1", self.mu_1.len(), n)?;
299        check_len("mu_2", self.mu_2.len(), n)?;
300        check_len("z_obs", self.z_obs.len(), n)?;
301        check_len("y", self.y.len(), n)?;
302        check_len("w", self.w.len(), n)?;
303        check_len("e_obs", self.e_obs.len(), n)?;
304        check_len("chi_obs", self.chi_obs.len(), n)?;
305        check_len("xi_obs", self.xi_obs.len(), n)?;
306        check_len("rho_u", self.rho_u.len(), n * self.r)?;
307        check_len("tau_u", self.tau_u.len(), n * self.r)?;
308        check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
309        check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
310        let total_cells_u32 = self.cell_offsets[n];
311        let total_cells = total_cells_u32 as usize;
312        check_len("cell_c0", self.cell_c0.len(), total_cells)?;
313        check_len("cell_c1", self.cell_c1.len(), total_cells)?;
314        check_len("cell_c2", self.cell_c2.len(), total_cells)?;
315        check_len("cell_c3", self.cell_c3.len(), total_cells)?;
316        check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
317        check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
318        check_len(
319            "cell_r",
320            self.cell_r.len(),
321            total_cells * self.r.saturating_sub(1) * COEFF4,
322        )?;
323        check_len(
324            "cell_ar",
325            self.cell_ar.len(),
326            total_cells * self.r.saturating_sub(1) * COEFF4,
327        )?;
328        check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
329        check_len(
330            "cell_sbh",
331            self.cell_sbh.len(),
332            total_cells * self.p_h * COEFF4,
333        )?;
334        check_len(
335            "cell_sbw",
336            self.cell_sbw.len(),
337            total_cells * self.p_w * COEFF4,
338        )?;
339        check_len(
340            "cell_moments",
341            self.cell_moments.len(),
342            total_cells * MOMENT_STRIDE,
343        )?;
344        // Bonus: when the moments came from `CellMomentsSource::Device`, the
345        // launcher needs to know the source is from a device buffer; nothing
346        // to validate beyond length above. The Host variant length check is
347        // also already covered above.
348        // Monotone cell_offsets check.
349        for i in 0..n {
350            if self.cell_offsets[i] > self.cell_offsets[i + 1] {
351                return Err(GpuError::DriverCallFailed {
352                    reason: format!(
353                        "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
354                        i,
355                        self.cell_offsets[i],
356                        i + 1,
357                        self.cell_offsets[i + 1]
358                    ),
359                });
360            }
361        }
362        Ok(())
363    }
364}
365
/// NVRTC kernel source body. One CUDA block per row; 32 threads per block
/// parallelise the per-cell sums into shared-memory scratch; thread 0 of the
/// block finishes the IFT + observed-point + Mills + Hessian write-out.
///
/// Shared probit numerics (`erfcx_nonnegative`, `log_ndtr`,
/// `log_ndtr_and_mills`) are provided by
/// `numerics_device::PROBIT_NUMERICS_CU`, which is prepended before
/// passing to `cudarc::nvrtc::compile_ptx`.
///
/// All per-row offsets into the output buffers (`out_grad`, `out_hess`) are
/// computed in `size_t`, like the offsets into the input buffers: a 32-bit
/// `row * r * r` would overflow once `row ≥ 2^31 / r²` (about 2.1 M rows at
/// `r = 32`).
///
/// A row whose accumulated `F_a` is non-finite or non-positive gets NaN in
/// all three outputs instead of a result.
///
/// **CPU parity reference**: the body mirrors
/// `compute_row_analytic_flex_from_parts_into` in
/// `src/families/bernoulli_marginal_slope.rs`.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallelises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
//                      ::compute_row_analytic_flex_from_parts_into.

#define INV_TWO_PI     0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    // Row bases are formed in 64-bit arithmetic so that large row counts
    // cannot wrap a 32-bit product.
    size_t grad_base = (size_t)row * (size_t)r;
    size_t hess_len  = (size_t)r * (size_t)r;
    size_t hess_base = (size_t)row * hess_len;
    out_neglog[row] = nan_value;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + (size_t)u] = nan_value;
    }
    for (size_t idx = 0; idx < hess_len; ++idx) {
        out_hess[hess_base + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int                  n_rows,
    int                  r,
    int                  p_h,
    int                  p_w,
    double               s_f,                // currently unused on device:
                                             // host has already baked S_f
                                             // into the cubic coefficients.
                                             // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a,       // [n_cells, 4]
    const double * __restrict__ cell_aa,      // [n_cells, 4]
    const double * __restrict__ cell_r,       // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar,      // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb,     // [n_cells, 4]
    const double * __restrict__ cell_sbh,     // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw,     // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments, // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho,      // [n_rows, r]
    const double * __restrict__ row_tau,      // [n_rows, r]
    const double * __restrict__ row_ruv,      // [n_rows, r*r]
    const double * __restrict__ row_e_obs,    // [n_rows] observed predictor VALUE
    double       * __restrict__ out_neglog,
    double       * __restrict__ out_grad,
    double       * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    //   F_u      [r]
    //   F_au     [r]
    //   F_uv     [r*r]
    //   bar_e_u  [r]
    //   bar_e_uv [r*r]
    //   reduce_a [blockDim.x]
    //   reduce_b [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u]  = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa  = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        //             `cell_second_derivative_from_moments` after folding the
        //             cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S)                                                                 \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1]                               \
             + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2]                                    \
             + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3]                        \
             + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4]                                    \
             + (R[2]*S[3] + R[3]*S[2])*T[5]                                                \
             + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c  = cell_a  + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa  += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        //                                   = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R   = D_OF(R_u);
            double d_AR  = D_OF(AR_u);
            double q_AR  = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s  = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in score-warp block, or symmetric.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in link-wiggle block, or symmetric.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror: u in (h or w) block, v == 1 cannot happen
                // because we iterate v >= u; skip.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared.
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared  = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a  = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    //   F_q  = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0]  = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    //   a_u = -F_u · inv_Fa     (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    //   a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // The q-row contributions (u==0 or v==0) collapse to a_uv = mu_2 · inv_Fa
    // when both are 0 and to (F_au_v) · inv_Fa-style mixed shape otherwise.
    // We compute it uniformly using the populated F_uv / F_au with the
    // q-overrides above.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                        + F_au[v] * a_u[u]
                        + F_au[u] * a_u[v]
                        + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    //   bar_e_u  = chi · a_u + rho_u.
    //   bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi  = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                       + xi  * a_u[u] * a_u[v]
                       + tau[u] * a_u[v]
                       + a_u[u] * tau[v]
                       + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y    = row_y[row];
    double w    = row_w[row];
    double s    = 2.0 * y - 1.0;
    // The "observed predictor" e_obs is the VALUE (degree-0 term) of the
    // observed jet η(a(θ), θ; z_obs) — NOT `bar_e_u[0]`, which is the u=0
    // FIRST-derivative jet (`chi·a_0 + rho_0 = dη_obs/dq`). The host packs
    // the observed value directly in `row_e_obs[row]` (see
    // `pack_bms_flex_row_kernel_inputs`, `eta_val = eval_coeff4_at(obs.coeff,
    // z_obs)`), matching the CPU family `compute_row_analytic_flex_from_parts_into`
    // which forms `signed_margin = s_y · eta_val`. #415 parity lock.
    double e_obs = row_e_obs[row];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i =  w * lambda * (m_arg + lambda);

    // Per-row output bases in 64-bit arithmetic, matching the input offsets
    // above; within a row the `u * r + v` index is bounded by 32 * 32.
    double *grad_row = out_grad + (size_t)row * (size_t)r;
    double *hess_row = out_hess + (size_t)row * (size_t)r * (size_t)r;

    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        grad_row[u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            hess_row[u * r + v] = val;
            if (u != v) {
                hess_row[v * r + u] = val;
            }
        }
    }
}
"#;
726
727// Force `s_f` to be considered used at the Rust level even though Stage 2 of
728// the kernel doesn't consume it on-device (the host has already baked the
729// probit frailty scale into the per-cell cubic coefficients). The dispatcher
730// wave that ports the rigid-branch fallback may want to apply `s_f` device-side
731// for log diagnostics; leaving the field on the input struct + reading it here
732// avoids a `let _` silencer the build.rs scanner would reject.
733#[inline]
734pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
735    inputs.s_f.is_finite() && inputs.s_f > 0.0
736}
737
/// Handles needed to launch the row kernel, built once per process by
/// [`RowKernelBackend::probe`].
#[cfg(target_os = "linux")]
pub(crate) struct RowKernelBackend {
    /// Stream handle cloned from the probe's `parts.stream`.
    pub(crate) stream: Arc<CudaStream>,
    /// Module loaded from the PTX compiled out of
    /// `PROBIT_NUMERICS_CU` + `ROW_KERNEL_BODY`.
    pub(crate) module: Arc<CudaModule>,
}
743
744#[cfg(target_os = "linux")]
745impl RowKernelBackend {
746    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
747        static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
748        BACKEND
749            .get_or_init(|| {
750                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
751                    let row_kernel_source = [
752                        gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
753                        ROW_KERNEL_BODY,
754                    ]
755                    .concat();
756                    // #1551: route through the project's single arch-aware NVRTC
757                    // entry point instead of bare `cudarc::nvrtc::compile_ptx`.
758                    // `compile_ptx_arch` pins `--gpu-architecture` to the selected
759                    // device's compute capability and supplies the standard CUDA
760                    // include paths; bare `compile_ptx` uses NVRTC's default
761                    // virtual arch with no includes. The row kernel's 64-bit
762                    // `atomic_add_f64` (atomicCAS emulation) compiles best against
763                    // the real device arch, and this keeps every BMS-flex compile
764                    // site consistent with the SAE arrow/Schur kernels that do
765                    // require the sm_60 pin for native `atomicAdd(double*,double)`.
766                    let ptx = gam_gpu::device_cache::compile_ptx_arch(&row_kernel_source).map_err(
767                        |err| GpuError::DriverCallFailed {
768                            reason: format!("bms_flex_row NVRTC compile failed: {err}"),
769                        },
770                    )?;
771                    let module =
772                        parts
773                            .ctx
774                            .load_module(ptx)
775                            .map_err(|err| GpuError::DriverCallFailed {
776                                reason: format!("bms_flex_row module load failed: {err}"),
777                            })?;
778                    Ok(RowKernelBackend {
779                        stream: parts.stream.clone(),
780                        module,
781                    })
782                })
783            })
784            .as_ref()
785            .map_err(GpuError::clone)
786    }
787}
788
789/// Launch Stage-2 BMS FLEX row kernel. On non-Linux returns
790/// [`GpuError::DriverLibraryUnavailable`]; on Linux NVRTC-compiles the kernel
791/// (cached for the process lifetime), uploads the per-row + per-cell buffers,
792/// and dispatches one block per row.
793pub(crate) fn launch_bms_flex_row_kernel(
794    inputs: BmsFlexRowKernelInputs<'_>,
795) -> Result<BmsFlexRowKernelOutputs, GpuError> {
796    inputs.validate()?;
797    if !s_f_diagnostic_finite(&inputs) {
798        return Err(GpuError::DriverCallFailed {
799            reason: format!(
800                "bms_flex_row inputs: s_f must be positive and finite, got {}",
801                inputs.s_f
802            ),
803        });
804    }
805
806    #[cfg(target_os = "linux")]
807    {
808        launch_linux(inputs)
809    }
810    #[cfg(not(target_os = "linux"))]
811    {
812        Err(GpuError::DriverLibraryUnavailable {
813            reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
814        })
815    }
816}
817
818#[cfg(target_os = "linux")]
819pub(crate) fn launch_linux(
820    inputs: BmsFlexRowKernelInputs<'_>,
821) -> Result<BmsFlexRowKernelOutputs, GpuError> {
822    let backend = RowKernelBackend::probe()?;
823    let stream = &backend.stream;
824
825    let upload_f64 = |slice: &[f64], label: &str| {
826        stream
827            .clone_htod(slice)
828            .map_err(|err| GpuError::DriverCallFailed {
829                reason: format!("bms_flex_row upload {label}: {err}"),
830            })
831    };
832    let upload_u32 = |slice: &[u32], label: &str| {
833        stream
834            .clone_htod(slice)
835            .map_err(|err| GpuError::DriverCallFailed {
836                reason: format!("bms_flex_row upload {label}: {err}"),
837            })
838    };
839
840    let d_q = upload_f64(inputs.q, "q")?;
841    let d_b = upload_f64(inputs.b, "b")?;
842    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
843    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
844    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
845    let d_y = upload_f64(inputs.y, "y")?;
846    let d_w = upload_f64(inputs.w, "w")?;
847    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
848    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
849    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
850    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
851    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
852    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
853    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
854    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
855    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
856    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
857    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
858    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
859    // Phase-4: optionally consume device-resident moments (no host upload).
860    // Both branches end up holding a `&CudaSlice<f64>` named `d_moments_ref`
861    // we can pass to the launch builder uniformly.
862    let owned_host_moments: CudaSlice<f64>;
863    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
864        CellMomentsSource::Host(slice) => {
865            owned_host_moments = upload_f64(slice, "cell_moments")?;
866            &owned_host_moments
867        }
868        CellMomentsSource::Device(d) => *d,
869    };
870    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
871    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
872    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
873    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
874    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
875    let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;
876
877    let n = inputs.n_rows;
878    let r = inputs.r;
879    let mut d_neglog = stream
880        .alloc_zeros::<f64>(n)
881        .map_err(|err| GpuError::DriverCallFailed {
882            reason: format!("bms_flex_row alloc neglog: {err}"),
883        })?;
884    let mut d_grad =
885        stream
886            .alloc_zeros::<f64>(n * r)
887            .map_err(|err| GpuError::DriverCallFailed {
888                reason: format!("bms_flex_row alloc grad: {err}"),
889            })?;
890    let mut d_hess =
891        stream
892            .alloc_zeros::<f64>(n * r * r)
893            .map_err(|err| GpuError::DriverCallFailed {
894                reason: format!("bms_flex_row alloc hess: {err}"),
895            })?;
896
897    let func = backend
898        .module
899        .load_function("bms_flex_row_kernel")
900        .map_err(|err| GpuError::DriverCallFailed {
901            reason: format!("bms_flex_row load_function: {err}"),
902        })?;
903
904    let cfg = LaunchConfig {
905        grid_dim: (n as u32, 1, 1),
906        block_dim: (ROW_KERNEL_THREADS, 1, 1),
907        shared_mem_bytes: 0,
908    };
909    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
910        reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
911    })?;
912    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
913        reason: format!("bms_flex_row: r={r} exceeds i32 range"),
914    })?;
915    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
916        reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
917    })?;
918    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
919        reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
920    })?;
921    let s_f = inputs.s_f;
922
923    let mut builder = stream.launch_builder(&func);
924    builder
925        .arg(&n_i32)
926        .arg(&r_i32)
927        .arg(&p_h_i32)
928        .arg(&p_w_i32)
929        .arg(&s_f)
930        .arg(&d_q)
931        .arg(&d_b)
932        .arg(&d_mu1)
933        .arg(&d_mu2)
934        .arg(&d_zobs)
935        .arg(&d_y)
936        .arg(&d_w)
937        .arg(&d_offsets)
938        .arg(&d_c0)
939        .arg(&d_c1)
940        .arg(&d_c2)
941        .arg(&d_c3)
942        .arg(&d_a)
943        .arg(&d_aa)
944        .arg(&d_r)
945        .arg(&d_ar)
946        .arg(&d_sbb)
947        .arg(&d_sbh)
948        .arg(&d_sbw)
949        .arg(d_moments_ref)
950        .arg(&d_chi)
951        .arg(&d_xi)
952        .arg(&d_rho)
953        .arg(&d_tau)
954        .arg(&d_ruv)
955        .arg(&d_e_obs)
956        .arg(&mut d_neglog)
957        .arg(&mut d_grad)
958        .arg(&mut d_hess);
959
960    // SAFETY: every kernel parameter above is either a primitive `i32` /
961    // `f64` (passed by value), a const device pointer to a buffer whose
962    // length the host validated against the input struct, or an output
963    // buffer pre-allocated to `n_rows`, `n_rows*r`, `n_rows*r*r`
964    // doubles. The kernel's shared-memory arrays are sized to MAX_R = 32
965    // and validate() rejects r > MAX_R.
966    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
967        reason: format!("bms_flex_row launch: {err}"),
968    })?;
969    stream
970        .synchronize()
971        .map_err(|err| GpuError::DriverCallFailed {
972            reason: format!("bms_flex_row synchronize: {err}"),
973        })?;
974
975    let neglog = stream
976        .clone_dtoh(&d_neglog)
977        .map_err(|err| GpuError::DriverCallFailed {
978            reason: format!("bms_flex_row download neglog: {err}"),
979        })?;
980    let grad = stream
981        .clone_dtoh(&d_grad)
982        .map_err(|err| GpuError::DriverCallFailed {
983            reason: format!("bms_flex_row download grad: {err}"),
984        })?;
985    let hess = stream
986        .clone_dtoh(&d_hess)
987        .map_err(|err| GpuError::DriverCallFailed {
988            reason: format!("bms_flex_row download hess: {err}"),
989        })?;
990
991    Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
992}
993
994// ─────────────────────────────────────────────────────────────────────────────
995// Phase 3: device-resident row Hessian + HVP / diagonal kernels.
996//
997// Math (mirrors the CPU oracle in
998// `src/families/bernoulli_marginal_slope.rs::exact_newton_joint_hessian_*_from_cache`):
999//
1000//   Block layout (joint β):
1001//     marginal = [0..p_m), logslope = [p_m..p_m+p_g),
1002//     h        = [h_start..h_end), w = [w_start..w_end), total = p_total.
1003//
1004//   Primary layout (per-row r-vector):
1005//     q = 0, logslope = 1,
1006//     h = [h_primary_start..h_primary_end),
1007//     w = [w_primary_start..w_primary_end), total = r.
1008//
1009//   row_dir[u] for u in primary layout:
1010//     row_dir[0]   = Σ_j marginal_design[row, j] · v[j]
1011//     row_dir[1]   = Σ_j logslope_design[row, j] · v[p_m + j]
1012//     row_dir[h_k] = v[h_block_start + (h_k - h_primary_start)]
1013//     row_dir[w_k] = v[w_block_start + (w_k - w_primary_start)]
1014//
1015//   action[u]    = Σ_v row_hessians[row, u*r + v] · row_dir[v]
1016//
1017//   block_partial[marginal_j] += action[0] · marginal_design[row, j]
1018//   block_partial[logslope_j] += action[1] · logslope_design[row, j]
1019//   block_partial[h_block_start + (h_k - h_primary_start)] += action[h_k]
1020//   block_partial[w_block_start + (w_k - w_primary_start)] += action[w_k]
1021//
1022// Diagonal:
1023//   diag[marginal_j] += row_hess[row, 0*r + 0] · marginal_design[row, j]²
1024//   diag[logslope_j] += row_hess[row, 1*r + 1] · logslope_design[row, j]²
1025//   diag[h_block_start + k] += row_hess[row, ii*r + ii]   (ii = h_primary_start + k)
1026//   diag[w_block_start + k] += row_hess[row, ii*r + ii]   (ii = w_primary_start + k)
1027//
1028// Determinism: each CTA owns a contiguous slice of `[chunk_start..chunk_end)`
1029// rows and writes its full per-chunk `p_total` partial into a non-overlapping
1030// region of the global partial buffer. The reduce kernel then sums those
1031// partials in fixed chunk-major order. No atomics.
1032
/// Joint-β block layout shared with the host (mirrors `BlockSlices` in
/// `bernoulli_marginal_slope.rs`).
///
/// The joint β vector is laid out as
/// `marginal = [0..p_m)`, `logslope = [p_m..p_m+p_g)`, then the optional `h`
/// and `w` blocks, for `p_total` entries overall.
///
/// Gating: Linux-only. The lone production constructor lives in
/// `bernoulli_marginal_slope.rs:9189` behind `#[cfg(target_os = "linux")]`
/// — the device-resident row-Hessian path is the only producer (see
/// `launch_bms_flex_row_kernel_device_resident`), and the joint-β
/// consumers `launch_bms_flex_row_hvp` / `_diagonal` / `_dense_block`
/// are also Linux-only. Any non-Linux test referencing this type must
/// guard itself with `#[cfg(target_os = "linux")]` too — the build.rs
/// ban scanner explicitly rejects `#[cfg(any(..., test))]` on items as
/// a dead-code escape hatch.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Width of the marginal block, which occupies joint-β indices `[0..p_m)`.
    pub p_m: usize,
    /// Width of the logslope block, which occupies `[p_m..p_m + p_g)`.
    pub p_g: usize,
    /// Joint-β index range of the `h` block.
    /// NOTE(review): `None` presumably means the block is absent — confirm.
    pub h: Option<std::ops::Range<usize>>,
    /// Joint-β index range of the `w` block.
    /// NOTE(review): `None` presumably means the block is absent — confirm.
    pub w: Option<std::ops::Range<usize>>,
    /// Total length of the joint β vector.
    pub p_total: usize,
}
1054
/// Primary-r layout shared with the host (mirrors `PrimarySlices`).
/// Gating rationale identical to [`BmsFlexBlockLayout`].
///
/// In the per-row primary r-vector, slot 0 is `q` and slot 1 is the logslope;
/// the optional `h` and `w` ranges below address the remaining slots.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexPrimaryLayout {
    /// Primary-index range `[h_primary_start..h_primary_end)` of the `h` slots.
    /// NOTE(review): `None` presumably means the block is absent — confirm.
    pub h: Option<std::ops::Range<usize>>,
    /// Primary-index range `[w_primary_start..w_primary_end)` of the `w` slots.
    /// NOTE(review): `None` presumably means the block is absent — confirm.
    pub w: Option<std::ops::Range<usize>>,
    /// Total number of primary coordinates per row.
    pub r: usize,
}
1064
1065// ── Linux-only: device-resident row-Hessian state + kernels ─────────────────
1066
/// Number of rows each HVP / diagonal CTA processes. Each CTA writes a single
/// `[1, p_total]` partial row into the global partial buffer (no atomics);
/// the reduce kernel then sums partials in chunk-major fixed order.
#[cfg(target_os = "linux")]
pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;

/// `blockDim.x` for the HVP / diagonal partial kernels.
///
/// Coupled to `HVP_KERNEL_SOURCE`: the HVP partial kernels index a
/// `__shared__ double dot_reduce[128]` scratch array by `threadIdx.x` and fold
/// it with a stride-halving loop that starts at `blockDim.x / 2`. The value
/// must therefore stay a power of two no larger than 128. The single-RHS
/// kernel also assigns thread `u` to `action[u]`, so it must be at least `r`.
#[cfg(target_os = "linux")]
pub(crate) const HVP_THREADS: u32 = 128;

/// `blockDim.x` for the partial-sum reduction kernels (one output element per
/// thread over the `p_total`/`rhs_elems` outputs). A full warp
/// multiple that keeps the reduce launch occupancy-bound rather than tail-bound
/// for the typical large-scale `p_total`.
///
/// NOTE(review): this comment used to describe the kernels as "grid-strided",
/// but `bms_flex_row_hvp_reduce` and `bms_flex_row_hvp_multi_reduce` in
/// `HVP_KERNEL_SOURCE` return when their index is out of range instead of
/// looping, so the launch grid has to cover every output element.
#[cfg(target_os = "linux")]
pub(crate) const REDUCTION_THREADS: u32 = 256;

/// Maximum RHS columns fused into one row-primary HVP launch. The matching
/// CUDA source uses fixed shared arrays sized as
/// `BMS_FLEX_ROW_HVP_MAX_RHS * MAX_R`; increasing this requires updating the
/// `MAX_MULTI_RHS` define in `HVP_KERNEL_SOURCE`.
#[cfg(target_os = "linux")]
pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1090
/// Device-resident state produced by
/// [`launch_bms_flex_row_kernel_device_resident`] and consumed by
/// [`launch_bms_flex_row_hvp`] / [`launch_bms_flex_row_diagonal`].
///
/// Owns the row-Hessian + design slices on-device so the host can issue
/// many HVPs against the same β snapshot without round-tripping
/// 626 MB through host RAM. Drop releases the device memory back to
/// the CUDA runtime.
///
/// # Storage layout
///
/// NOTE(review): the paragraphs below describe a choice between a dense and a
/// packed-upper per-row layout, yet this struct carries a single dense buffer
/// and the `hess` field doc calls dense the only supported layout. Confirm
/// whether the packed path is wired up elsewhere or this text is stale.
///
/// Per-row Hessian storage layout on the device. The build path is free to
/// emit either, and the Hv / diag kernels read whichever the storage says.
///
/// Charter (Block 9 Phase 4): packed-upper halves the DRAM footprint of the
/// `n × r²` cache (per-row `r*(r+1)/2` doubles instead of `r²`), at the cost
/// of a single per-entry index conversion in the kernel. The benchmark
/// decides whether the packed path becomes the default for large-scale
/// fits (`r = 20` → 210 vs 400 doubles per row, ~47.5% smaller). The
/// numerics are bit-equal because each `H_i` is symmetric by construction
/// (the row kernel emits a symmetric block by construction — see the
/// symmetric scratch-write loop in `bms_flex_row_kernel`'s shared-memory
/// finaliser).
#[cfg(target_os = "linux")]
pub struct DeviceResidentRowHess {
    /// Per-row dense `[n, r, r]` row-major Hessian. Element `(u, v)` of row
    /// `i` is `hess[i*r*r + u*r + v]`. This is the only on-device storage
    /// layout supported by the current HVP / diag kernels.
    pub(crate) hess: CudaSlice<f64>,
    /// Marginal design rows; presumably `[n, p_m]` row-major, matching the
    /// kernel parameter of the same name — TODO confirm against the producer.
    pub(crate) marginal_design: CudaSlice<f64>,
    /// Logslope design rows; presumably `[n, p_g]` row-major, matching the
    /// kernel parameter of the same name — TODO confirm against the producer.
    pub(crate) logslope_design: CudaSlice<f64>,
    /// Number of rows.
    pub(crate) n: usize,
    /// Primary coordinates per row (side length of each row Hessian).
    pub(crate) r: usize,
    /// Joint-β block layout.
    pub(crate) block: BmsFlexBlockLayout,
    /// Per-row primary layout.
    pub(crate) primary: BmsFlexPrimaryLayout,
    /// Estimated bytes resident on device (for accounting).
    pub(crate) bytes: u64,
}
1126
1127#[cfg(target_os = "linux")]
1128impl std::fmt::Debug for DeviceResidentRowHess {
1129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1130        f.debug_struct("DeviceResidentRowHess")
1131            .field("n", &self.n)
1132            .field("r", &self.r)
1133            .field("p_total", &self.block.p_total)
1134            .field("bytes", &self.bytes)
1135            .finish()
1136    }
1137}
1138
1139/// Sized-to-fit-once CTA mapping. Rows `[c * HVP_ROWS_PER_CTA, (c+1) * HVP_ROWS_PER_CTA)`
1140/// belong to chunk `c`.
1141#[cfg(target_os = "linux")]
1142pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1143    n.div_ceil(HVP_ROWS_PER_CTA as usize)
1144}
1145
1146/// NVRTC source: HVP-partial kernel + HVP-reduce kernel + diag-partial +
1147/// diag-reduce. All kernels mirror the CPU oracle in this file.
1148#[cfg(target_os = "linux")]
1149pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1150// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1151// in this module.
1152
1153#define MAX_MULTI_RHS 8
1154
// Single-RHS Hessian-vector-product partial kernel.
//
// One CTA per chunk of `rows_per_cta` rows. For every row of its chunk the
// CTA forms the primary direction `row_dir`, applies the row's r×r Hessian to
// get `action`, and pulls `action` back into this chunk's `[p_total]` slice of
// `partial`. A separate reduce kernel sums the chunk slices.
//
// Launch assumptions visible in the body: `blockDim.x` must not exceed 128
// (size of `dot_reduce`), the stride-halving reduction only folds every
// element when `blockDim.x` is a power of two, `blockDim.x >= r` because
// thread `u` computes `action[u]`, and `r <= 32` (size of `row_dir` /
// `action`).
extern "C" __global__ void bms_flex_row_hvp_partial(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians,    // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m] row-major
    const double * __restrict__ logslope_design, // [n, p_g] row-major
    const double * __restrict__ v,               // [p_total]
    double       * __restrict__ partial)         // [num_chunks, p_total]
{
    int chunk = blockIdx.x;
    int tid   = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    // The last chunk may be short.
    if (row_hi > n_rows) row_hi = n_rows;

    // Zero this chunk's partial slice cooperatively.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    // The rows of the chunk are visited one after another by the whole CTA.
    // Within a row, the threads split the `p_m` / `p_g` dot products and
    // axpys with a stride of blockDim.x, thread u (u < r) computes
    // action[u], and thread 0 alone copies / accumulates the h and w slots.
    // Every `out[..]` slot is therefore only ever updated by one thread
    // (assuming the four joint-β blocks do not overlap), which is why no
    // atomics are needed; the barriers order the shared-memory hand-offs
    // (`dot_reduce`, `row_dir`, `action`) between the phases of a row and
    // between consecutive rows.
    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;

        // row_dir[0] = mrow · v[0..p_m]
        double local = 0.0;
        for (int j = tid; j < p_m; j += blockDim.x) {
            local += mrow[j] * v[j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        // Shared-memory tree reduction; the sum ends up in dot_reduce[0].
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        // No barrier is needed before dot_reduce is refilled below: only
        // thread 0 reads or rewrites slot 0.
        if (tid == 0) row_dir[0] = dot_reduce[0];

        // row_dir[1] = grow · v[p_m..p_m+p_g]
        local = 0.0;
        for (int j = tid; j < p_g; j += blockDim.x) {
            local += grow[j] * v[p_m + j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[1] = dot_reduce[0];

        // h/w blocks: direct copy.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        // Publish row_dir to every thread before the mat-vec.
        __syncthreads();

        // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
        if (tid < r) {
            double acc = 0.0;
            for (int vv = 0; vv < r; ++vv) {
                acc += Hrow[tid * r + vv] * row_dir[vv];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Pull back into joint β slot.
        //   marginal: out[j] += action[0] · mrow[j]   (parallel j)
        double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        //   logslope: out[p_m + j] += action[1] · grow[j]   (parallel j)
        double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        //   h / w: one-to-one slots, handled by thread 0.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        // Keep row_dir / action stable until every thread is done with them.
        __syncthreads();
    }
}
1272
// Sum the per-chunk partial rows into the final `[p_total]` vector.
//
// Each thread owns at most one output element and adds the chunks in
// ascending chunk order, so the summation order is fixed. Threads whose
// global index falls outside `[0, p_total)` do nothing.
extern "C" __global__ void bms_flex_row_hvp_reduce(
    int                  num_chunks,
    int                  p_total,
    const double * __restrict__ partial,   // [num_chunks, p_total]
    double       * __restrict__ out)        // [p_total]
{
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col < p_total) {
        const size_t row_stride = (size_t)p_total;
        const size_t offset     = (size_t)col;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk) {
            total += partial[(size_t)chunk * row_stride + offset];
        }
        out[col] = total;
    }
}
1287
// Multi-RHS Hessian-vector-product partial kernel.
//
// Same chunking as the single-RHS kernel (one CTA per `rows_per_cta` rows),
// but each row's Hessian is applied to `rhs_count` directions in one pass.
// `partial` is laid out `[rhs_count, num_chunks, p_total]`; this CTA owns the
// `[.., chunk, ..]` slice of every RHS.
//
// Launch assumptions visible in the body: `rhs_count <= MAX_MULTI_RHS`,
// `r <= 32` (each RHS gets a 32-wide lane of `row_dir` / `action`),
// `blockDim.x <= 128` (size of `dot_reduce`), and a power-of-two `blockDim.x`
// for the stride-halving reduction to fold every element.
extern "C" __global__ void bms_flex_row_hvp_multi_partial(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    int                  rhs_count,
    const double * __restrict__ row_hessians,    // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    const double * __restrict__ v_rhs,           // [rhs_count, p_total]
    double       * __restrict__ partial)         // [rhs_count, num_chunks, p_total]
{
    int chunk = blockIdx.x;
    int tid   = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    // The last chunk may be short.
    if (row_hi > n_rows) row_hi = n_rows;

    // Chunk count is recomputed from the row count so the RHS-major offset
    // into `partial` can be formed without an extra parameter.
    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
    // Zero this chunk's slice of every RHS cooperatively.
    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
        int rhs = idx / p_total;
        int j = idx - rhs * p_total;
        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
    }
    __syncthreads();

    // One 32-wide lane per RHS: lane `rhs` starts at offset `rhs * 32`.
    __shared__ double row_dir[MAX_MULTI_RHS * 32];
    __shared__ double action[MAX_MULTI_RHS * 32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;

        // Phase 1: build row_dir for every RHS.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;

            // row_dir[rhs][0] = mrow · v[0..p_m]
            double local = 0.0;
            for (int j = tid; j < p_m; j += blockDim.x) {
                local += mrow[j] * v[j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];

            // row_dir[rhs][1] = grow · v[p_m..p_m+p_g]
            local = 0.0;
            for (int j = tid; j < p_g; j += blockDim.x) {
                local += grow[j] * v[p_m + j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            // Thread 0 stores the second dot product and copies the h / w
            // entries of this RHS straight across.
            if (tid == 0) {
                row_dir[rhs * 32 + 1] = dot_reduce[0];
                for (int k = 0; k < h_block_len; ++k) {
                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
                }
            }
            __syncthreads();
        }

        // Phase 2: action[rhs][u] = Σ_v Hrow[u*r+v] · row_dir[rhs][v], with
        // the (rhs, u) pairs spread over the threads.
        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
            int rhs = idx / r;
            int u = idx - rhs * r;
            double acc = 0.0;
            const double *dir = row_dir + rhs * 32;
            for (int vv = 0; vv < r; ++vv) {
                acc += Hrow[u * r + vv] * dir[vv];
            }
            action[rhs * 32 + u] = acc;
        }
        __syncthreads();

        // Phase 3: pull each RHS's action back into its joint-β slice.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
            double a0 = action[rhs * 32 + 0];
            for (int j = tid; j < p_m; j += blockDim.x) {
                out[j] += a0 * mrow[j];
            }
            double a1 = action[rhs * 32 + 1];
            for (int j = tid; j < p_g; j += blockDim.x) {
                out[p_m + j] += a1 * grow[j];
            }
            if (tid == 0) {
                for (int k = 0; k < h_block_len; ++k) {
                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
                }
            }
            __syncthreads();
        }
    }
}
1402
// Sum the per-chunk partials of every RHS into `out[rhs_count, p_total]`.
//
// Each thread owns at most one `(rhs, column)` output and adds the chunks in
// ascending chunk order, so the summation order is fixed. Threads whose flat
// index falls outside `[0, rhs_count * p_total)` do nothing.
extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
    int                  num_chunks,
    int                  p_total,
    int                  rhs_count,
    const double * __restrict__ partial,   // [rhs_count, num_chunks, p_total]
    double       * __restrict__ out)        // [rhs_count, p_total]
{
    const int flat = blockIdx.x * blockDim.x + threadIdx.x;
    if (flat < rhs_count * p_total) {
        const int rhs = flat / p_total;
        const int col = flat - rhs * p_total;
        const size_t row_stride  = (size_t)p_total;
        const size_t first_chunk = (size_t)rhs * (size_t)num_chunks;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk) {
            total += partial[(first_chunk + (size_t)chunk) * row_stride + (size_t)col];
        }
        out[(size_t)rhs * row_stride + (size_t)col] = total;
    }
}
1421
// Diagonal partial kernel: accumulate, per chunk of rows, the diagonal of the
// joint-β Hessian into this chunk's `[p_total]` slice of `partial`.
//
//   marginal_j : H[0][0] · marginal_design[row, j]²
//   logslope_j : H[1][1] · logslope_design[row, j]²
//   h / w slot : the matching diagonal entry of the row Hessian
//
// The marginal / logslope columns are spread over the threads with a stride
// of blockDim.x; the h / w slots are handled by thread 0.
extern "C" __global__ void bms_flex_row_diag_partial(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double       * __restrict__ partial)
{
    const int chunk = blockIdx.x;
    const int tid   = threadIdx.x;
    const int first_row = chunk * rows_per_cta;
    const int chunk_end = first_row + rows_per_cta;
    // The last chunk may be short.
    const int last_row  = (chunk_end > n_rows) ? n_rows : chunk_end;

    // Clear this chunk's slice before accumulating into it.
    double *chunk_out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        chunk_out[j] = 0.0;
    }
    __syncthreads();

    for (int row = first_row; row < last_row; ++row) {
        const double *m_row = marginal_design + (size_t)row * (size_t)p_m;
        const double *g_row = logslope_design + (size_t)row * (size_t)p_g;
        const double *H     = row_hessians + (size_t)row * (size_t)r * (size_t)r;

        // Diagonal entries of the q (index 0) and logslope (index 1) primaries.
        const double h_qq = H[0];
        const double h_gg = H[1 * r + 1];

        for (int j = tid; j < p_m; j += blockDim.x) {
            const double x = m_row[j];
            chunk_out[j] += h_qq * x * x;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            const double x = g_row[j];
            chunk_out[p_m + j] += h_gg * x * x;
        }
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                const int d = h_primary_start + k;
                chunk_out[h_block_start + k] += H[d * r + d];
            }
            for (int k = 0; k < w_block_len; ++k) {
                const int d = w_primary_start + k;
                chunk_out[w_block_start + k] += H[d * r + d];
            }
        }
        __syncthreads();
    }
}
1479
1480// ────────────────────────────────────────────────────────────────────────
1481// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1482//   row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1483// indexed as
//   packed[(u*(2*r - u + 1))/2 + (v - u)]   for u <= v
1485// with symmetric mirror for v < u.
1486// ────────────────────────────────────────────────────────────────────────
1487
// Helper: packed-upper index for (u, v) within a single row of r*(r+1)/2
// doubles. Caller must pre-swap so that u <= v.
//
// Row u of the upper triangle holds the r - u entries (u, u) .. (u, r-1), so
// it starts at
//     sum_{k < u} (r - k) = u*(2r - u + 1)/2
// which is the layout `bms_flex_row_pack_upper` writes. The product
// u*(2r - u + 1) is always even, so the integer division is exact.
__device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
    // u*(2r - u + 1)/2 + (v - u)
    return (u * (2 * r - u + 1)) / 2 + (v - u);
}
1494
// Copy the upper triangle (diagonal included) of one dense row-major r×r
// Hessian into packed-upper storage. Launched as one CTA per row
// (gridDim.x = n_rows, blockDim.x configurable); CTAs past `n_rows` exit
// immediately. The copy is bit-equal: every upper-triangle entry is read once
// from the dense source and written once to the packed destination.
extern "C" __global__ void bms_flex_row_pack_upper(
    int                  n_rows,
    int                  r,
    const double * __restrict__ src_full,    // [n, r*r]
    double       * __restrict__ dst_packed)  // [n, r*(r+1)/2]
{
    const int row = blockIdx.x;
    if (row >= n_rows) return;

    const int packed_len = r * (r + 1) / 2;
    const double *dense  = src_full + (size_t)row * (size_t)r * (size_t)r;
    double       *packed = dst_packed + (size_t)row * (size_t)packed_len;

    // Threads stride over the packed positions and map each one back to its
    // (u, v) coordinate.
    for (int pos = threadIdx.x; pos < packed_len; pos += blockDim.x) {
        // Packed row u covers positions [row_begin, row_begin + (r - u)),
        // with row_begin = u*(2r - u + 1)/2. Walk the rows until the one
        // containing `pos` is found; O(r) with r <= 32.
        int u = 0;
        int row_begin = 0;
        for (int row_len = r; u < r && pos >= row_begin + row_len; --row_len) {
            row_begin += row_len;
            ++u;
        }
        const int v = u + (pos - row_begin);
        packed[pos] = dense[(size_t)u * (size_t)r + (size_t)v];
    }
}
1529
// Packed-storage HVP partial kernel: each CTA accumulates, over its chunk
// of rows `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)` (clamped to
// n_rows), the joint-coefficient image of `H_row · v`, reading every per-row
// Hessian from the packed-upper layout. The CTA writes only its own
// `[p_total]` slice of `partial` (`partial` is [num_chunks, p_total]).
//
// Per row:
//   1. row_dir[0] = mrow · v[0..p_m) and row_dir[1] = grow · v[p_m..p_m+p_g)
//      via a shared-memory halving reduction; the h / w primary slots are
//      copied directly from the matching segments of `v`.
//   2. action[u] = Σ_w H[u, w] · row_dir[w], one thread per primary index,
//      with H[u, w] fetched at packed (min(u, w), max(u, w)).
//   3. out += pull-back of `action` through the same per-row maps.
//
// Launch contract implied by the code below: `dot_reduce` has 128 slots and
// the halving reduction only covers every slot when blockDim.x is a power
// of two, so blockDim.x must be a power of two ≤ 128; `row_dir` / `action`
// have 32 slots and step 2 runs on threads 0..r-1, so r ≤ 32 and
// r ≤ blockDim.x.
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed, // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    const double * __restrict__ v,
    double       * __restrict__ partial)
{
    const int chunk = blockIdx.x;
    const int tid   = threadIdx.x;
    const int row_begin = chunk * rows_per_cta;
    const int row_end =
        (row_begin + rows_per_cta > n_rows) ? n_rows : row_begin + rows_per_cta;

    const int per_row = r * (r + 1) / 2;

    // Zero this CTA's output slice before any row contributes to it.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    for (int row = row_begin; row < row_end; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // Step 1a: the two design-row dot products.
        //   slot 0: mrow · v[0..p_m)
        //   slot 1: grow · v[p_m..p_m+p_g)
        for (int slot = 0; slot < 2; ++slot) {
            const double *design = (slot == 0) ? mrow : grow;
            const double *dir    = (slot == 0) ? v : v + p_m;
            const int     width  = (slot == 0) ? p_m : p_g;

            double thread_sum = 0.0;
            for (int j = tid; j < width; j += blockDim.x) {
                thread_sum += design[j] * dir[j];
            }
            dot_reduce[tid] = thread_sum;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) row_dir[slot] = dot_reduce[0];
        }

        // Step 1b: h / w primary coordinates map one-to-one onto `v`.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        __syncthreads();

        // Step 2: symmetric mat-vec against the packed upper triangle.
        if (tid < r) {
            double acc = 0.0;
            for (int w = 0; w < r; ++w) {
                const int lo = (tid < w) ? tid : w;
                const int hi = (tid < w) ? w : tid;
                acc += Hrow[bms_flex_packed_idx(lo, hi, r)] * row_dir[w];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Step 3: scatter the row action back onto the joint coefficients.
        const double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        const double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        // Keep the shared scratch stable until every thread is done with it.
        __syncthreads();
    }
}
1640
1641// ────────────────────────────────────────────────────────────────────────
1642// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
1643// route. Materialises the full `[p_total, p_total]` row-major joint H
1644// from the per-row r×r Hessian via the P_i pullback. NOT the default
1645// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
1646// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
1647// caller genuinely needs the dense matrix on the device.
1648//
1649// Per-CTA partial: each CTA owns a contiguous chunk of rows
1650// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
1651// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
1652// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
1653// chunk-major-fixed-order into a single `[p_total, p_total]` output.
1654//
1655// Math: for primary index u ∈ [0, r):
1656//   * u = 0:        phi_u = (X_i in slot 0..p_m, 0 elsewhere)
1657//   * u = 1:        phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
1658//   * u = 2+j:      phi_u = e_{h_block_start + j}  (j ∈ 0..h_block_len)
1659//   * u = 2+h+l:    phi_u = e_{w_block_start + l}  (l ∈ 0..w_block_len)
1660// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
1661//
// Shared-memory budget: at large-scale shape p_total = 44, a [44, 44] f64
// partial is 44*44*8 = 15,488 B ≈ 15.1 KiB — well below the 48 KiB V100
// shared-memory cap. The accumulator stops fitting just below p_total = 80
// (80*80*8 = 51,200 B = 50 KiB exceeds the cap; p_total = 78 is the largest
// size that fits), so the caller must enforce p_total ≤ DENSE_BLOCK_MAX_P.
// The launcher rejects oversize p_total cleanly.
1667
// Describes the support of phi_prim inside the joint coefficient vector for
// primary index `prim`:
//   prim == 0 : the p_m marginal columns at offset 0, weights from `mrow`
//   prim == 1 : the p_g logslope columns at offset p_m, weights from `grow`
//   prim >= 2 : a single unit entry; index prim - 2 selects position
//               h_block_start + k while k < h_block_len, otherwise
//               w_block_start + (k - h_block_len)
// `*indicator` is true for the unit-entry case, in which `*src` is NULL and
// the weight is 1.0.
__device__ __forceinline__ void bms_flex_dense_block_operand(
    int                  prim,
    int                  p_m,
    int                  p_g,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    const double        *mrow,
    const double        *grow,
    int                 *off,
    int                 *len,
    const double       **src,
    bool                *indicator)
{
    if (prim == 0) {
        *off = 0;   *len = p_m; *src = mrow; *indicator = false;
        return;
    }
    if (prim == 1) {
        *off = p_m; *len = p_g; *src = grow; *indicator = false;
        return;
    }
    const int k = prim - 2;
    *off = (k < h_block_len) ? h_block_start + k
                             : w_block_start + (k - h_block_len);
    *len = 1;
    *src = NULL;
    *indicator = true;
}

// Dense joint-Hessian partial: each CTA sums `P_i^T H_i P_i` over its chunk
// of rows into a `[p_total, p_total]` accumulator held in dynamic shared
// memory (the launch must provide p_total*p_total*sizeof(double) bytes),
// then copies that accumulator to its slot of `partial`.
//
// Note: `h_primary_start`, `w_primary_start` and `w_block_len` are accepted
// but not read here; the primary → block mapping is fixed as
// u = 2 + j (h block) and u = 2 + h_block_len + l (w block).
extern "C" __global__ void bms_flex_row_dense_block_partial(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians,    // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    double       * __restrict__ partial)         // [num_chunks, p_total, p_total]
{
    extern __shared__ double shmem[];
    const int chunk = blockIdx.x;
    const int tid   = threadIdx.x;
    const int row_begin = chunk * rows_per_cta;
    const int row_end =
        (row_begin + rows_per_cta > n_rows) ? n_rows : row_begin + rows_per_cta;

    const int pp = p_total * p_total;
    double *acc = shmem; // CTA-private accumulator [p_total, p_total]
    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
    __syncthreads();

    // All per-row accumulation is done by thread 0 so that no two threads
    // ever read-modify-write the same `acc[]` entry. The nested loops visit
    // every pair in the supports of phi_u and phi_v, so the per-row cost is
    // O((p_m + p_g + r)²) in the worst case (the (u, v) = (0, 0) pair alone
    // touches p_m × p_m entries). This is a debug-only path, and the serial
    // form is the easiest to audit against the host-side P_i pullback
    // oracle; a warp-striped version of the u-v-m-n nest is possible.
    if (tid == 0) {
        for (int row = row_begin; row < row_end; ++row) {
            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
            for (int u = 0; u < r; ++u) {
                int m_off, m_len; const double *m_src; bool m_unit;
                bms_flex_dense_block_operand(
                    u, p_m, p_g, h_block_start, h_block_len, w_block_start,
                    mrow, grow, &m_off, &m_len, &m_src, &m_unit);
                for (int v = 0; v < r; ++v) {
                    const double huv = Hrow[u * r + v];
                    // Exact zeros contribute nothing; skip the whole tile.
                    if (huv == 0.0) continue;
                    int n_off, n_len; const double *n_src; bool n_unit;
                    bms_flex_dense_block_operand(
                        v, p_m, p_g, h_block_start, h_block_len, w_block_start,
                        mrow, grow, &n_off, &n_len, &n_src, &n_unit);
                    // acc[m, n] += huv * phi_u[m] * phi_v[n]
                    for (int mi = 0; mi < m_len; ++mi) {
                        const double pm = m_unit ? 1.0 : m_src[mi];
                        if (pm == 0.0) continue;
                        const double scaled = huv * pm;
                        double *acc_row = acc + (m_off + mi) * p_total;
                        for (int ni = 0; ni < n_len; ++ni) {
                            const double pn = n_unit ? 1.0 : n_src[ni];
                            acc_row[n_off + ni] += scaled * pn;
                        }
                    }
                }
            }
        }
    }
    __syncthreads();

    // Publish the CTA accumulator to this chunk's slot in global memory.
    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
    for (int j = tid; j < pp; j += blockDim.x) {
        out_chunk[j] = acc[j];
    }
}
1762
// Sums the per-chunk dense partials into the final `[p_total, p_total]`
// matrix. One thread per output entry; each thread adds its entry across
// chunks in ascending chunk order, so the summation order is fixed.
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int                  num_chunks,
    int                  p_total,
    const double * __restrict__ partial,
    double       * __restrict__ out)
{
    const int entry = blockIdx.x * blockDim.x + threadIdx.x;
    const int n_entries = p_total * p_total;
    if (entry < n_entries) {
        double total = 0.0;
        for (int c = 0; c < num_chunks; ++c) {
            total += partial[(size_t)c * (size_t)n_entries + (size_t)entry];
        }
        out[entry] = total;
    }
}
1778
// Packed-storage diagonal partial kernel: each CTA accumulates, over its
// chunk of rows, the diagonal of the pulled-back per-row Hessian into its
// own `[p_total]` slice of `partial`. Only the diagonal entries H[u, u] of
// each packed row are read:
//   marginal columns  : out[j]       += H[0, 0] · mrow[j]²
//   logslope columns  : out[p_m + j] += H[1, 1] · grow[j]²
//   h / w coordinates : out[block_start + k] += H[ii, ii],
//                       ii = primary_start + k
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double       * __restrict__ partial)
{
    const int chunk = blockIdx.x;
    const int tid   = threadIdx.x;
    const int row_begin = chunk * rows_per_cta;
    const int row_end =
        (row_begin + rows_per_cta > n_rows) ? n_rows : row_begin + rows_per_cta;

    const int per_row = r * (r + 1) / 2;

    // Zero this CTA's output slice before accumulating.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = row_begin; row < row_end; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // Design-backed coordinates: squared design entry times H[u, u].
        const double diag0 = Hrow[bms_flex_packed_idx(0, 0, r)];
        const double diag1 = Hrow[bms_flex_packed_idx(1, 1, r)];
        for (int j = tid; j < p_m; j += blockDim.x) {
            const double x = mrow[j];
            out[j] += diag0 * x * x;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            const double g = grow[j];
            out[p_m + j] += diag1 * g * g;
        }

        // Unit-vector coordinates: the diagonal entry passes through as-is.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                const int ii = h_primary_start + k;
                out[h_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                const int ii = w_primary_start + k;
                out[w_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
        }
        __syncthreads();
    }
}
1838"#;
1839
1840#[cfg(target_os = "linux")]
1841pub(crate) struct HvpKernelBackend {
1842    pub(crate) stream: Arc<CudaStream>,
1843    pub(crate) module: Arc<CudaModule>,
1844}
1845
1846#[cfg(target_os = "linux")]
1847impl HvpKernelBackend {
1848    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1849        static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1850        BACKEND
1851            .get_or_init(|| {
1852                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1853                    // #1551: arch-aware compile (see launch_bms_flex_row_kernel) —
1854                    // pin `--gpu-architecture` to the device capability and supply
1855                    // the standard include paths via the shared NVRTC entry point.
1856                    let ptx = gam_gpu::device_cache::compile_ptx_arch(HVP_KERNEL_SOURCE).map_err(
1857                        |err| GpuError::DriverCallFailed {
1858                            reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1859                        },
1860                    )?;
1861                    let module =
1862                        parts
1863                            .ctx
1864                            .load_module(ptx)
1865                            .map_err(|err| GpuError::DriverCallFailed {
1866                                reason: format!("bms_flex_row hvp module load failed: {err}"),
1867                            })?;
1868                    Ok(HvpKernelBackend {
1869                        stream: parts.stream.clone(),
1870                        module,
1871                    })
1872                })
1873            })
1874            .as_ref()
1875            .map_err(GpuError::clone)
1876    }
1877}
1878
1879/// Build a device-resident row-Hessian cache by launching the row kernel and
1880/// keeping the resulting `n × r²` slice resident on the device. Also uploads
1881/// the dense marginal + logslope design matrices so subsequent HVPs do not
1882/// re-upload them at every direction.
1883///
1884/// `marginal_design_row_major` and `logslope_design_row_major` must be
1885/// row-major `[n, p_m]` and `[n, p_g]` contiguous slices.
1886///
1887/// #461 absorber (additive Stage-1 influence block): the orthogonalization is
1888/// realized as **A2 — the marginal design widened to `[M | Z̃_infl]`** (see
1889/// `src/families/bms/block_specs.rs::widen_marginal_dense_with_influence`), NOT
1890/// a dedicated 5th primary coordinate. The absorber `+Z̃_infl·γ` is plain
1891/// additive into the marginal index `α(x)`, so γ lives inside `β_m` of the
1892/// widened block and `p_m` already counts the `p₁` influence columns. The row
1893/// kernel reads the marginal index from `block_states[0].eta` (which carries
1894/// `Z̃_infl·γ`) and pulls back through this same widened `marginal_design`, so
1895/// the absorber rides the existing primary-coordinate `u = 0` chain with **no
1896/// kernel-source change**: η, gradient, and Hessian match the CPU kernel
1897/// bit-for-bit precisely because `marginal_design` and `β_m` are the matched
1898/// (design, coefficient) pair the CPU path uses. The validation below pins
1899/// `marginal_design.len() == n·p_m` (with `p_m` widened), so a stale narrow
1900/// design against a widened `block.p_m` is rejected cleanly (CPU fallback)
1901/// rather than silently computing the wrong η. The absorber is dropped at
1902/// predict, where the marginal design is rebuilt without the influence columns,
1903/// so the predict-time `p_m` is narrow and this path is correct there too.
1904#[cfg(target_os = "linux")]
1905pub(crate) fn launch_bms_flex_row_kernel_device_resident(
1906    inputs: BmsFlexRowKernelInputs<'_>,
1907    marginal_design_row_major: &[f64],
1908    logslope_design_row_major: &[f64],
1909    block: BmsFlexBlockLayout,
1910    primary: BmsFlexPrimaryLayout,
1911) -> Result<DeviceResidentRowHess, GpuError> {
1912    inputs.validate()?;
1913    if !s_f_diagnostic_finite(&inputs) {
1914        return Err(GpuError::DriverCallFailed {
1915            reason: format!(
1916                "bms_flex_row device-resident: s_f must be positive and finite, got {}",
1917                inputs.s_f
1918            ),
1919        });
1920    }
1921    let n = inputs.n_rows;
1922    let r = inputs.r;
1923    if marginal_design_row_major.len() != n * block.p_m {
1924        return Err(GpuError::DriverCallFailed {
1925            reason: format!(
1926                "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
1927                marginal_design_row_major.len(),
1928                n * block.p_m
1929            ),
1930        });
1931    }
1932    if logslope_design_row_major.len() != n * block.p_g {
1933        return Err(GpuError::DriverCallFailed {
1934            reason: format!(
1935                "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
1936                logslope_design_row_major.len(),
1937                n * block.p_g
1938            ),
1939        });
1940    }
1941    if primary.r != r {
1942        return Err(GpuError::DriverCallFailed {
1943            reason: format!(
1944                "bms_flex_row device-resident: primary.r={} != inputs.r={}",
1945                primary.r, r
1946            ),
1947        });
1948    }
1949
1950    // Ensure the row kernel backend is compiled & loaded (this also compiles
1951    // the HVP backend on first use so the caller surfaces failures here).
1952    let backend = RowKernelBackend::probe()?;
1953    HvpKernelBackend::probe()?;
1954    let stream = backend.stream.clone();
1955
1956    let upload_f64 = |slice: &[f64], label: &str| {
1957        stream
1958            .clone_htod(slice)
1959            .map_err(|err| GpuError::DriverCallFailed {
1960                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1961            })
1962    };
1963    let upload_u32 = |slice: &[u32], label: &str| {
1964        stream
1965            .clone_htod(slice)
1966            .map_err(|err| GpuError::DriverCallFailed {
1967                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1968            })
1969    };
1970
1971    let d_q = upload_f64(inputs.q, "q")?;
1972    let d_b = upload_f64(inputs.b, "b")?;
1973    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
1974    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
1975    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
1976    let d_y = upload_f64(inputs.y, "y")?;
1977    let d_w = upload_f64(inputs.w, "w")?;
1978    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
1979    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
1980    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
1981    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
1982    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
1983    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
1984    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
1985    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
1986    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
1987    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
1988    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
1989    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
1990    // Phase-4: optionally consume device-resident moments (no host upload).
1991    let owned_host_moments: CudaSlice<f64>;
1992    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
1993        CellMomentsSource::Host(slice) => {
1994            owned_host_moments = upload_f64(slice, "cell_moments")?;
1995            &owned_host_moments
1996        }
1997        CellMomentsSource::Device(d) => *d,
1998    };
1999    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
2000    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
2001    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
2002    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
2003    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
2004    let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;
2005
2006    let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
2007    let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;
2008
2009    let mut d_neglog = stream
2010        .alloc_zeros::<f64>(n)
2011        .map_err(|err| GpuError::DriverCallFailed {
2012            reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
2013        })?;
2014    let mut d_grad =
2015        stream
2016            .alloc_zeros::<f64>(n * r)
2017            .map_err(|err| GpuError::DriverCallFailed {
2018                reason: format!("bms_flex_row device-resident alloc grad: {err}"),
2019            })?;
2020    let mut d_hess =
2021        stream
2022            .alloc_zeros::<f64>(n * r * r)
2023            .map_err(|err| GpuError::DriverCallFailed {
2024                reason: format!("bms_flex_row device-resident alloc hess: {err}"),
2025            })?;
2026
2027    let func = backend
2028        .module
2029        .load_function("bms_flex_row_kernel")
2030        .map_err(|err| GpuError::DriverCallFailed {
2031            reason: format!("bms_flex_row device-resident load_function: {err}"),
2032        })?;
2033
2034    let cfg = LaunchConfig {
2035        grid_dim: (n as u32, 1, 1),
2036        block_dim: (ROW_KERNEL_THREADS, 1, 1),
2037        shared_mem_bytes: 0,
2038    };
2039    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
2040        reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
2041    })?;
2042    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
2043        reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
2044    })?;
2045    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
2046        reason: format!(
2047            "bms_flex_row device-resident: p_h={} exceeds i32 range",
2048            inputs.p_h
2049        ),
2050    })?;
2051    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
2052        reason: format!(
2053            "bms_flex_row device-resident: p_w={} exceeds i32 range",
2054            inputs.p_w
2055        ),
2056    })?;
2057    let s_f_val = inputs.s_f;
2058
2059    let mut builder = stream.launch_builder(&func);
2060    builder
2061        .arg(&n_i32)
2062        .arg(&r_i32)
2063        .arg(&p_h_i32)
2064        .arg(&p_w_i32)
2065        .arg(&s_f_val)
2066        .arg(&d_q)
2067        .arg(&d_b)
2068        .arg(&d_mu1)
2069        .arg(&d_mu2)
2070        .arg(&d_zobs)
2071        .arg(&d_y)
2072        .arg(&d_w)
2073        .arg(&d_offsets)
2074        .arg(&d_c0)
2075        .arg(&d_c1)
2076        .arg(&d_c2)
2077        .arg(&d_c3)
2078        .arg(&d_a)
2079        .arg(&d_aa)
2080        .arg(&d_r)
2081        .arg(&d_ar)
2082        .arg(&d_sbb)
2083        .arg(&d_sbh)
2084        .arg(&d_sbw)
2085        .arg(d_moments_ref)
2086        .arg(&d_chi)
2087        .arg(&d_xi)
2088        .arg(&d_rho)
2089        .arg(&d_tau)
2090        .arg(&d_ruv)
2091        .arg(&d_e_obs)
2092        .arg(&mut d_neglog)
2093        .arg(&mut d_grad)
2094        .arg(&mut d_hess);
2095    // SAFETY: same shape contract as `launch_linux`: every kernel parameter is
2096    // either a primitive scalar by-value, a const device pointer whose
2097    // capacity was validated by `inputs.validate()`, or one of the three
2098    // output buffers we just allocated with the expected element count.
2099    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
2100        reason: format!("bms_flex_row device-resident launch: {err}"),
2101    })?;
2102    stream
2103        .synchronize()
2104        .map_err(|err| GpuError::DriverCallFailed {
2105            reason: format!("bms_flex_row device-resident synchronize: {err}"),
2106        })?;
2107
2108    // The kernel writes neglog + grad alongside the row Hessian, but the
2109    // device-resident cache path keeps neither on the host: the fused
2110    // CPU gradient pass (the only consumer of host-side neglog/grad) is
2111    // dispatched only as a fallback when the GPU dense-block kernel does
2112    // not apply, and in that fallback the row kernel runs again locally.
2113    // Drop the device buffers so the allocation pool reclaims them
2114    // immediately rather than tying them to the cache's lifetime.
2115    drop(d_neglog);
2116    drop(d_grad);
2117    // Drop the per-cell uploads; keep d_hess + designs.
2118    drop(d_q);
2119    drop(d_b);
2120    drop(d_mu1);
2121    drop(d_mu2);
2122    drop(d_zobs);
2123    drop(d_y);
2124    drop(d_w);
2125    drop(d_offsets);
2126    drop(d_c0);
2127    drop(d_c1);
2128    drop(d_c2);
2129    drop(d_c3);
2130    drop(d_a);
2131    drop(d_aa);
2132    drop(d_r);
2133    drop(d_ar);
2134    drop(d_sbb);
2135    drop(d_sbh);
2136    drop(d_sbw);
2137    // `owned_host_moments` (if any) and the borrowed `d_moments_ref` both
2138    // go out of scope at the end of the function; the device-resident
2139    // moments owned by the caller stay alive.
2140    drop(d_chi);
2141    drop(d_xi);
2142    drop(d_rho);
2143    drop(d_tau);
2144    drop(d_ruv);
2145
2146    let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
2147        * std::mem::size_of::<f64>()) as u64;
2148    Ok(DeviceResidentRowHess {
2149        hess: d_hess,
2150        marginal_design: d_marginal,
2151        logslope_design: d_logslope,
2152        n,
2153        r,
2154        block,
2155        primary,
2156        bytes,
2157    })
2158}
2159
/// Selects the partial kernel the joint-β engine drives. The mode also
/// determines whether a direction vector `d_v` is consumed and where the
/// reduced `[1, p_total]` image ends up, so the public entry points can stay
/// thin wrappers over a single launch helper.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Per-row `H · v` through `bms_flex_row_hvp_partial`; the result is left
    /// on-stream.
    HvpDeviceOut,
    /// Per-row `diag(H)` through `bms_flex_row_diag_partial`; the result is
    /// downloaded to the host.
    DiagonalHostOut,
}

#[cfg(target_os = "linux")]
impl BmsFlexRowLaunchMode {
    /// Symbol name of this mode's partial kernel inside the HVP module.
    pub(crate) fn partial_kernel_name(self) -> &'static str {
        match self {
            Self::HvpDeviceOut => "bms_flex_row_hvp_partial",
            Self::DiagonalHostOut => "bms_flex_row_diag_partial",
        }
    }
}
2183
2184/// All scalar launch arguments for the joint-β partial kernel, derived once
2185/// from a [`DeviceResidentRowHess`]. The HVP and diagonal partial kernels take
2186/// the identical leading block-layout argument list (only the trailing
2187/// `d_v` / output pointers differ), so this captures the long, easy-to-
2188/// desynchronize prefix in a single place.
2189#[cfg(target_os = "linux")]
2190pub(crate) struct PreparedBmsFlexRowLaunchArgs {
2191    pub(crate) n_i32: i32,
2192    pub(crate) r_i32: i32,
2193    pub(crate) p_m_i32: i32,
2194    pub(crate) p_g_i32: i32,
2195    pub(crate) p_total_i32: i32,
2196    pub(crate) h_block_start: i32,
2197    pub(crate) h_block_len: i32,
2198    pub(crate) w_block_start: i32,
2199    pub(crate) w_block_len: i32,
2200    pub(crate) h_primary_start: i32,
2201    pub(crate) w_primary_start: i32,
2202    pub(crate) rows_per_cta: i32,
2203    pub(crate) num_chunks: usize,
2204}
2205
2206#[cfg(target_os = "linux")]
2207impl PreparedBmsFlexRowLaunchArgs {
2208    pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2209        let p_total = storage.block.p_total;
2210        let num_chunks = num_hvp_chunks(storage.n);
2211        PreparedBmsFlexRowLaunchArgs {
2212            n_i32: storage.n as i32,
2213            r_i32: storage.r as i32,
2214            p_m_i32: storage.block.p_m as i32,
2215            p_g_i32: storage.block.p_g as i32,
2216            p_total_i32: p_total as i32,
2217            h_block_start: storage
2218                .block
2219                .h
2220                .as_ref()
2221                .map(|r| r.start as i32)
2222                .unwrap_or(0),
2223            h_block_len: storage
2224                .block
2225                .h
2226                .as_ref()
2227                .map(|r| r.len() as i32)
2228                .unwrap_or(0),
2229            w_block_start: storage
2230                .block
2231                .w
2232                .as_ref()
2233                .map(|r| r.start as i32)
2234                .unwrap_or(0),
2235            w_block_len: storage
2236                .block
2237                .w
2238                .as_ref()
2239                .map(|r| r.len() as i32)
2240                .unwrap_or(0),
2241            h_primary_start: storage
2242                .primary
2243                .h
2244                .as_ref()
2245                .map(|r| r.start as i32)
2246                .unwrap_or(0),
2247            w_primary_start: storage
2248                .primary
2249                .w
2250                .as_ref()
2251                .map(|r| r.start as i32)
2252                .unwrap_or(0),
2253            rows_per_cta: HVP_ROWS_PER_CTA as i32,
2254            num_chunks,
2255        }
2256    }
2257}
2258
2259/// Shared partial+reduce engine behind every joint-β launcher.
2260///
2261/// Allocates the `[num_chunks, p_total]` partial buffer, loads the mode's
2262/// partial kernel plus the common `bms_flex_row_hvp_reduce`, builds both
2263/// launch configs from a single [`PreparedBmsFlexRowLaunchArgs`], launches the
2264/// partial kernel (binding `d_v` only for the HVP modes), and launches the
2265/// reduction into caller-supplied `d_out`.
2266///
2267/// **No** `synchronize()` or DtoH is performed here — the surrounding helper
2268/// decides whether to keep the result on-stream (device-resident PCG hot path)
2269/// or sync + download it to the host. `ctx` is a short error-context tag woven
2270/// into every `DriverCallFailed` reason so failures stay attributable to the
2271/// originating entry point.
2272#[cfg(target_os = "linux")]
2273pub(crate) fn run_bms_flex_row_partial_reduce(
2274    storage: &DeviceResidentRowHess,
2275    mode: BmsFlexRowLaunchMode,
2276    d_v: Option<&CudaSlice<f64>>,
2277    d_out: &mut CudaSlice<f64>,
2278    ctx: &str,
2279) -> Result<(), GpuError> {
2280    let backend = HvpKernelBackend::probe()?;
2281    let stream = backend.stream.clone();
2282    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2283    let p_total = storage.block.p_total;
2284
2285    let mut d_partial = stream
2286        .alloc_zeros::<f64>(args.num_chunks * p_total)
2287        .map_err(|err| GpuError::DriverCallFailed {
2288            reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2289        })?;
2290
2291    let partial_kernel_name = mode.partial_kernel_name();
2292    let part_func = backend
2293        .module
2294        .load_function(partial_kernel_name)
2295        .map_err(|err| GpuError::DriverCallFailed {
2296            reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2297        })?;
2298    let red_func = backend
2299        .module
2300        .load_function("bms_flex_row_hvp_reduce")
2301        .map_err(|err| GpuError::DriverCallFailed {
2302            reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2303        })?;
2304
2305    let cfg_part = LaunchConfig {
2306        grid_dim: (args.num_chunks as u32, 1, 1),
2307        block_dim: (HVP_THREADS, 1, 1),
2308        shared_mem_bytes: 0,
2309    };
2310    let mut builder = stream.launch_builder(&part_func);
2311    builder
2312        .arg(&args.n_i32)
2313        .arg(&args.r_i32)
2314        .arg(&args.p_m_i32)
2315        .arg(&args.p_g_i32)
2316        .arg(&args.p_total_i32)
2317        .arg(&args.h_block_start)
2318        .arg(&args.h_block_len)
2319        .arg(&args.w_block_start)
2320        .arg(&args.w_block_len)
2321        .arg(&args.h_primary_start)
2322        .arg(&args.w_primary_start)
2323        .arg(&args.rows_per_cta)
2324        .arg(&storage.hess)
2325        .arg(&storage.marginal_design)
2326        .arg(&storage.logslope_design);
2327    if let Some(d_v) = d_v {
2328        builder.arg(d_v);
2329    }
2330    builder.arg(&mut d_partial);
2331    // SAFETY: every device pointer above either comes from `storage` (whose
2332    // capacities were established by
2333    // `launch_bms_flex_row_kernel_device_resident`) or was just allocated here
2334    // (`d_partial` = num_chunks * p_total). `d_v`, when bound, is length-checked
2335    // by the calling adapter against `p_total`. The diagonal partial kernel
2336    // takes no direction argument, matching `d_v == None`. Scalar args are i32
2337    // by-value.
2338    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2339        reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2340    })?;
2341
2342    let red_threads: u32 = REDUCTION_THREADS;
2343    let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2344    let cfg_red = LaunchConfig {
2345        grid_dim: (red_blocks, 1, 1),
2346        block_dim: (red_threads, 1, 1),
2347        shared_mem_bytes: 0,
2348    };
2349    let num_chunks_i32 = args.num_chunks as i32;
2350    let mut builder = stream.launch_builder(&red_func);
2351    builder
2352        .arg(&num_chunks_i32)
2353        .arg(&args.p_total_i32)
2354        .arg(&d_partial)
2355        .arg(d_out);
2356    // SAFETY: `d_partial` was just populated by the partial kernel above;
2357    // `d_out` is `p_total` doubles (length-checked / allocated by the calling
2358    // adapter); both scalar args fit i32.
2359    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2360        reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2361    })?;
2362    // `d_partial` drops at end of fn; cudarc keeps the alloc alive until the
2363    // stream is done with it, so the reduce kernel completes safely.
2364    drop(d_partial);
2365    Ok(())
2366}
2367
2368/// Host-returning joint-β launcher shared by [`launch_bms_flex_row_hvp`]
2369/// ([`BmsFlexRowLaunchMode::HvpHostOut`]) and [`launch_bms_flex_row_diagonal`]
2370/// ([`BmsFlexRowLaunchMode::DiagonalHostOut`]).
2371///
2372/// Probes the backend, allocates the `p_total`-double output on its stream,
2373/// optionally uploads a host direction `v` (HVP modes; `None` for the diagonal
2374/// mode, which takes no direction), runs the shared partial+reduce engine, then
2375/// synchronizes and downloads the reduced image to a host `Vec<f64>`. `ctx`
2376/// tags every `DriverCallFailed` reason with the originating entry point.
2377#[cfg(target_os = "linux")]
2378pub(crate) fn launch_bms_flex_row_host(
2379    storage: &DeviceResidentRowHess,
2380    mode: BmsFlexRowLaunchMode,
2381    v: Option<&[f64]>,
2382    ctx: &str,
2383) -> Result<Vec<f64>, GpuError> {
2384    let p_total = storage.block.p_total;
2385    if let Some(v) = v {
2386        if v.len() != p_total {
2387            return Err(GpuError::DriverCallFailed {
2388                reason: format!(
2389                    "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2390                    v.len()
2391                ),
2392            });
2393        }
2394    }
2395
2396    let backend = HvpKernelBackend::probe()?;
2397    let stream = backend.stream.clone();
2398
2399    let d_v = match v {
2400        Some(v) => Some(
2401            stream
2402                .clone_htod(v)
2403                .map_err(|err| GpuError::DriverCallFailed {
2404                    reason: format!("bms_flex_row {ctx} upload v: {err}"),
2405                })?,
2406        ),
2407        None => None,
2408    };
2409    let mut d_out =
2410        stream
2411            .alloc_zeros::<f64>(p_total)
2412            .map_err(|err| GpuError::DriverCallFailed {
2413                reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2414            })?;
2415
2416    run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2417
2418    stream
2419        .synchronize()
2420        .map_err(|err| GpuError::DriverCallFailed {
2421            reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2422        })?;
2423    stream
2424        .clone_dtoh(&d_out)
2425        .map_err(|err| GpuError::DriverCallFailed {
2426            reason: format!("bms_flex_row {ctx} download out: {err}"),
2427        })
2428}
2429
2430#[cfg(target_os = "linux")]
2431pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2432    storage: &DeviceResidentRowHess,
2433    rhs_count: usize,
2434    v_rhs_len: usize,
2435    out_len: Option<usize>,
2436    ctx: &str,
2437) -> Result<usize, GpuError> {
2438    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2439        return Err(GpuError::DriverCallFailed {
2440            reason: format!(
2441                "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2442            ),
2443        });
2444    }
2445    let p_total = storage.block.p_total;
2446    let rhs_elems = rhs_count
2447        .checked_mul(p_total)
2448        .ok_or_else(|| GpuError::DriverCallFailed {
2449            reason: format!(
2450                "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2451            ),
2452        })?;
2453    if v_rhs_len != rhs_elems {
2454        return Err(GpuError::DriverCallFailed {
2455            reason: format!(
2456                "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2457            ),
2458        });
2459    }
2460    if let Some(out_len) = out_len
2461        && out_len != rhs_elems
2462    {
2463        return Err(GpuError::DriverCallFailed {
2464            reason: format!(
2465                "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2466            ),
2467        });
2468    }
2469    Ok(rhs_elems)
2470}
2471
2472/// Transient device bytes for a multi-RHS HVP launch, excluding persistent
2473/// row-Hessian/design storage. Scratch scales with
2474/// `rhs_count * num_chunks * p_total`, not `rhs_count * n * r * r`.
2475#[cfg(target_os = "linux")]
2476pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2477    n: usize,
2478    p_total: usize,
2479    rhs_count: usize,
2480) -> Result<u64, GpuError> {
2481    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2482        return Err(GpuError::DriverCallFailed {
2483            reason: format!(
2484                "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2485            ),
2486        });
2487    }
2488    let num_chunks = num_hvp_chunks(n);
2489    let partial = rhs_count
2490        .checked_mul(num_chunks)
2491        .and_then(|v| v.checked_mul(p_total))
2492        .ok_or_else(|| GpuError::DriverCallFailed {
2493            reason: format!(
2494                "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2495            ),
2496        })?;
2497    let rhs_vectors = rhs_count
2498        .checked_mul(p_total)
2499        .and_then(|v| v.checked_mul(2))
2500        .ok_or_else(|| GpuError::DriverCallFailed {
2501            reason: format!(
2502                "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2503            ),
2504        })?;
2505    let elems = partial
2506        .checked_add(rhs_vectors)
2507        .ok_or_else(|| GpuError::DriverCallFailed {
2508            reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2509        })?;
2510    Ok((elems * std::mem::size_of::<f64>()) as u64)
2511}
2512
/// Multi-RHS counterpart of the partial+reduce engine.
///
/// Validates `rhs_count`, `d_v_rhs.len()` and `d_out.len()` through
/// [`validate_bms_flex_row_hvp_multi_shape`], allocates a zeroed partial
/// buffer of `rhs_count * num_chunks * p_total` doubles, then launches
/// `bms_flex_row_hvp_multi_partial` followed by `bms_flex_row_hvp_multi_reduce`
/// into the caller-supplied `d_out`.
///
/// This function itself performs no `synchronize()` and no device-to-host
/// copy; both launches are only enqueued on the stream obtained from
/// `HvpKernelBackend::probe()`. `ctx` is woven into every error reason.
///
/// # Errors
///
/// Returns [`GpuError::DriverCallFailed`] on a shape mismatch, on overflow of
/// the partial-buffer length, when `rhs_count` does not fit `i32`, or when an
/// allocation, kernel lookup, or launch fails; propagates any error from
/// `HvpKernelBackend::probe`.
#[cfg(target_os = "linux")]
pub(crate) fn run_bms_flex_row_multi_partial_reduce(
    storage: &DeviceResidentRowHess,
    rhs_count: usize,
    d_v_rhs: &CudaSlice<f64>,
    d_out: &mut CudaSlice<f64>,
    ctx: &str,
) -> Result<(), GpuError> {
    // Shape validation comes first so nothing is allocated or launched for a
    // malformed request. `rhs_elems` is `rhs_count * p_total`.
    let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
        storage,
        rhs_count,
        d_v_rhs.len(),
        Some(d_out.len()),
        ctx,
    )?;
    let backend = HvpKernelBackend::probe()?;
    let stream = backend.stream.clone();
    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
    let p_total = storage.block.p_total;
    // Overflow-checked length of the partial buffer.
    let partial_len = rhs_count
        .checked_mul(args.num_chunks)
        .and_then(|v| v.checked_mul(p_total))
        .ok_or_else(|| GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
                args.num_chunks
            ),
        })?;

    let mut d_partial =
        stream
            .alloc_zeros::<f64>(partial_len)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
            })?;
    let part_func = backend
        .module
        .load_function("bms_flex_row_hvp_multi_partial")
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
        })?;
    let red_func = backend
        .module
        .load_function("bms_flex_row_hvp_multi_reduce")
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
        })?;

    let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
    })?;
    // Partial launch: one CTA per row chunk.
    let cfg_part = LaunchConfig {
        grid_dim: (args.num_chunks as u32, 1, 1),
        block_dim: (HVP_THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    // Argument order: the shared scalar prefix, then `rhs_count`, then the
    // storage buffers, the direction batch, and the partial output.
    let mut builder = stream.launch_builder(&part_func);
    builder
        .arg(&args.n_i32)
        .arg(&args.r_i32)
        .arg(&args.p_m_i32)
        .arg(&args.p_g_i32)
        .arg(&args.p_total_i32)
        .arg(&args.h_block_start)
        .arg(&args.h_block_len)
        .arg(&args.w_block_start)
        .arg(&args.w_block_len)
        .arg(&args.h_primary_start)
        .arg(&args.w_primary_start)
        .arg(&args.rows_per_cta)
        .arg(&rhs_count_i32)
        .arg(&storage.hess)
        .arg(&storage.marginal_design)
        .arg(&storage.logslope_design)
        .arg(d_v_rhs)
        .arg(&mut d_partial);
    // SAFETY: storage buffers were validated at construction; `d_v_rhs` and
    // `d_out` have rhs_count*p_total elements, `d_partial` has
    // rhs_count*num_chunks*p_total, and rhs_count is bounded by fixed shared
    // array sizes in the CUDA source.
    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
    })?;

    // Reduce launch: enough `REDUCTION_THREADS`-wide blocks to cover all
    // `rhs_elems` output entries (ceil division).
    let red_threads: u32 = REDUCTION_THREADS;
    let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
    let cfg_red = LaunchConfig {
        grid_dim: (red_blocks, 1, 1),
        block_dim: (red_threads, 1, 1),
        shared_mem_bytes: 0,
    };
    let num_chunks_i32 = args.num_chunks as i32;
    let mut builder = stream.launch_builder(&red_func);
    builder
        .arg(&num_chunks_i32)
        .arg(&args.p_total_i32)
        .arg(&rhs_count_i32)
        .arg(&d_partial)
        .arg(d_out);
    // SAFETY: the reduce kernel reads the just-populated partial buffer and
    // writes exactly rhs_count*p_total output entries.
    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
    })?;
    // Release the scratch buffer without a host-side synchronize.
    drop(d_partial);
    Ok(())
}
2620
2621/// Device-resident multi-RHS HVP. `v_rhs` is row-major
2622/// `[rhs_count, p_total]`; the returned vector has the same layout.
2623#[cfg(target_os = "linux")]
2624pub(crate) fn launch_bms_flex_row_hvp_multi(
2625    storage: &DeviceResidentRowHess,
2626    v_rhs: &[f64],
2627    rhs_count: usize,
2628) -> Result<Vec<f64>, GpuError> {
2629    let rhs_elems =
2630        validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2631    let backend = HvpKernelBackend::probe()?;
2632    let stream = backend.stream.clone();
2633    let d_v_rhs = stream
2634        .clone_htod(v_rhs)
2635        .map_err(|err| GpuError::DriverCallFailed {
2636            reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2637        })?;
2638    let mut d_out =
2639        stream
2640            .alloc_zeros::<f64>(rhs_elems)
2641            .map_err(|err| GpuError::DriverCallFailed {
2642                reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2643            })?;
2644    run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2645    stream
2646        .synchronize()
2647        .map_err(|err| GpuError::DriverCallFailed {
2648            reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2649        })?;
2650    stream
2651        .clone_dtoh(&d_out)
2652        .map_err(|err| GpuError::DriverCallFailed {
2653            reason: format!("bms_flex_row hvp_multi download out: {err}"),
2654        })
2655}
2656
2657/// Device-output HVP. Runs `bms_flex_row_hvp_partial(_packed)` +
2658/// `bms_flex_row_hvp_reduce` on the storage's stream against caller-supplied
2659/// device-resident `d_v` (length `p_total` doubles), writing the result into
2660/// caller-supplied `d_out` (also `p_total` doubles). **No** `synchronize()`
2661/// or DtoH is performed — the caller is responsible for stream ordering
2662/// against any consumer that reads `d_out`.
2663///
2664/// This is the device-resident PCG hot path (Block 9 Phase 5): keeping the
2665/// HVP output on the stream lets the outer PCG loop chain axpy / dot /
2666/// preconditioner kernels back-to-back without a per-iter device sync.
2667#[cfg(target_os = "linux")]
2668pub(crate) fn launch_bms_flex_row_hvp_into_device(
2669    storage: &DeviceResidentRowHess,
2670    d_v: &CudaSlice<f64>,
2671    d_out: &mut CudaSlice<f64>,
2672) -> Result<(), GpuError> {
2673    let p_total = storage.block.p_total;
2674    if d_v.len() != p_total {
2675        return Err(GpuError::DriverCallFailed {
2676            reason: format!(
2677                "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2678                d_v.len(),
2679                p_total
2680            ),
2681        });
2682    }
2683    if d_out.len() != p_total {
2684        return Err(GpuError::DriverCallFailed {
2685            reason: format!(
2686                "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2687                d_out.len(),
2688                p_total
2689            ),
2690        });
2691    }
2692    // On-stream output: the shared engine launches partial+reduce into the
2693    // caller's `d_out` and returns without sync/DtoH, so the outer PCG loop can
2694    // chain device kernels against the result.
2695    run_bms_flex_row_partial_reduce(
2696        storage,
2697        BmsFlexRowLaunchMode::HvpDeviceOut,
2698        Some(d_v),
2699        d_out,
2700        "hvp_into_device",
2701    )
2702}
2703
2704/// Launch the device-resident HVP kernel. Returns the host-side joint β image
2705/// of length `block.p_total`.
2706#[cfg(target_os = "linux")]
2707pub(crate) fn launch_bms_flex_row_hvp(
2708    storage: &DeviceResidentRowHess,
2709    v: &[f64],
2710) -> Result<Vec<f64>, GpuError> {
2711    launch_bms_flex_row_hvp_multi(storage, v, 1)
2712}
2713
2714/// Launch the device-resident diagonal kernel. Returns the host-side joint
2715/// β diagonal of length `block.p_total`.
2716#[cfg(target_os = "linux")]
2717pub(crate) fn launch_bms_flex_row_diagonal(
2718    storage: &DeviceResidentRowHess,
2719) -> Result<Vec<f64>, GpuError> {
2720    launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2721}
2722
/// Block 9 Phase 6 — hard cap on `p_total` for the dense joint-Hessian
/// device kernel. The per-CTA shared-memory accumulator is `p_total² * 8`
/// bytes. With the V100 default per-block shared-memory cap of 48 KiB
/// (49 152 bytes), the largest `p_total` that fits is
/// `floor(sqrt(48 KiB / 8)) = floor(sqrt(6144)) = 78`. We round that down to
/// 72, a multiple of 8 (`72² * 8 = 41 472` bytes), for predictable launch
/// geometry.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;
2730
/// Number of rows each dense-block CTA processes. Smaller than the HVP
/// `HVP_ROWS_PER_CTA = 256` because the per-row inner loop is `O(r² *
/// (p_m + p_g + h_block_len + w_block_len))` rather than `O(r²)` — fewer
/// rows per CTA keeps the per-CTA wall time short and lets us scale grid
/// occupancy with `num_chunks = ceil(n / DENSE_BLOCK_ROWS_PER_CTA)`.
///
/// `launch_bms_flex_row_dense_block` uses this value twice: as the kernel's
/// `rows_per_cta` argument and as the divisor that sizes the launch grid.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2738
2739/// Launch the Phase-6 dense joint-Hessian block kernel. Returns the
2740/// host-side `[p_total, p_total]` row-major joint H as a `Vec<f64>`
2741/// (length `p_total²`).
2742///
2743/// **Not the default Newton path.** Production Newton uses HVP (Phase 2)
2744/// and never materialises the full dense Hessian. This entry exists for:
2745///   * exact-REML logdet (`log|H|`) when the unified evaluator wants to
2746///     factor H directly instead of going through the matrix-free path;
2747///   * diagnostic dumps that compare the GPU dense build against the CPU
2748///     `BernoulliMarginalSlopeFamily::fused_gradient_dense` reference;
2749///   * small-`p` debug routes where it is cheaper to factor + solve dense
2750///     than to run a PCG.
2751///
2752/// The kernel rejects `p_total > DENSE_BLOCK_MAX_P` cleanly because the
2753/// per-CTA shared-memory accumulator (`p_total² * 8` bytes) would exceed
2754/// the V100 48 KiB/block cap above that threshold.
2755#[cfg(target_os = "linux")]
2756pub fn launch_bms_flex_row_dense_block(
2757    storage: &DeviceResidentRowHess,
2758) -> Result<Vec<f64>, GpuError> {
2759    let p_total = storage.block.p_total;
2760    if p_total == 0 {
2761        return Err(GpuError::DriverCallFailed {
2762            reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2763        });
2764    }
2765    if p_total > DENSE_BLOCK_MAX_P {
2766        return Err(GpuError::DriverCallFailed {
2767            reason: format!(
2768                "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2769                 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2770            ),
2771        });
2772    }
2773    let backend = HvpKernelBackend::probe()?;
2774    let stream = backend.stream.clone();
2775    let n = storage.n;
2776    let r = storage.r;
2777    let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2778    let num_chunks = n.div_ceil(rows_per_cta);
2779    let pp = p_total * p_total;
2780
2781    let mut d_partial =
2782        stream
2783            .alloc_zeros::<f64>(num_chunks * pp)
2784            .map_err(|err| GpuError::DriverCallFailed {
2785                reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2786            })?;
2787    let mut d_out = stream
2788        .alloc_zeros::<f64>(pp)
2789        .map_err(|err| GpuError::DriverCallFailed {
2790            reason: format!("bms_flex_row dense_block alloc out: {err}"),
2791        })?;
2792
2793    let part_func = backend
2794        .module
2795        .load_function("bms_flex_row_dense_block_partial")
2796        .map_err(|err| GpuError::DriverCallFailed {
2797            reason: format!("bms_flex_row dense_block load partial: {err}"),
2798        })?;
2799    let red_func = backend
2800        .module
2801        .load_function("bms_flex_row_dense_block_reduce")
2802        .map_err(|err| GpuError::DriverCallFailed {
2803            reason: format!("bms_flex_row dense_block load reduce: {err}"),
2804        })?;
2805
2806    let n_i32 = n as i32;
2807    let r_i32 = r as i32;
2808    let p_m_i32 = storage.block.p_m as i32;
2809    let p_g_i32 = storage.block.p_g as i32;
2810    let p_total_i32 = p_total as i32;
2811    let h_block_start = storage
2812        .block
2813        .h
2814        .as_ref()
2815        .map(|r| r.start as i32)
2816        .unwrap_or(0);
2817    let h_block_len = storage
2818        .block
2819        .h
2820        .as_ref()
2821        .map(|r| r.len() as i32)
2822        .unwrap_or(0);
2823    let w_block_start = storage
2824        .block
2825        .w
2826        .as_ref()
2827        .map(|r| r.start as i32)
2828        .unwrap_or(0);
2829    let w_block_len = storage
2830        .block
2831        .w
2832        .as_ref()
2833        .map(|r| r.len() as i32)
2834        .unwrap_or(0);
2835    let h_primary_start = storage
2836        .primary
2837        .h
2838        .as_ref()
2839        .map(|r| r.start as i32)
2840        .unwrap_or(0);
2841    let w_primary_start = storage
2842        .primary
2843        .w
2844        .as_ref()
2845        .map(|r| r.start as i32)
2846        .unwrap_or(0);
2847    let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2848    let num_chunks_u32 = num_chunks as u32;
2849
2850    // Per-CTA shmem accumulator: p_total² doubles.
2851    let shmem_bytes: u32 =
2852        u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2853            reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2854        })?;
2855
2856    let cfg_part = LaunchConfig {
2857        grid_dim: (num_chunks_u32, 1, 1),
2858        block_dim: (HVP_THREADS, 1, 1),
2859        shared_mem_bytes: shmem_bytes,
2860    };
2861    let mut builder = stream.launch_builder(&part_func);
2862    builder
2863        .arg(&n_i32)
2864        .arg(&r_i32)
2865        .arg(&p_m_i32)
2866        .arg(&p_g_i32)
2867        .arg(&p_total_i32)
2868        .arg(&h_block_start)
2869        .arg(&h_block_len)
2870        .arg(&w_block_start)
2871        .arg(&w_block_len)
2872        .arg(&h_primary_start)
2873        .arg(&w_primary_start)
2874        .arg(&rows_per_cta_i32)
2875        .arg(&storage.hess)
2876        .arg(&storage.marginal_design)
2877        .arg(&storage.logslope_design)
2878        .arg(&mut d_partial);
2879    // SAFETY: storage pointers have validated capacities; d_partial sized
2880    // num_chunks * pp doubles; dynamic shmem matches the kernel's `extern
2881    // __shared__` accumulator length.
2882    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2883        reason: format!("bms_flex_row dense_block partial launch: {err}"),
2884    })?;
2885
2886    let red_threads: u32 = REDUCTION_THREADS;
2887    let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2888    let cfg_red = LaunchConfig {
2889        grid_dim: (red_blocks, 1, 1),
2890        block_dim: (red_threads, 1, 1),
2891        shared_mem_bytes: 0,
2892    };
2893    let num_chunks_i32 = num_chunks as i32;
2894    let mut builder = stream.launch_builder(&red_func);
2895    builder
2896        .arg(&num_chunks_i32)
2897        .arg(&p_total_i32)
2898        .arg(&d_partial)
2899        .arg(&mut d_out);
2900    // SAFETY: d_partial just populated, d_out is pp doubles.
2901    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2902        reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2903    })?;
2904    stream
2905        .synchronize()
2906        .map_err(|err| GpuError::DriverCallFailed {
2907            reason: format!("bms_flex_row dense_block sync: {err}"),
2908        })?;
2909    stream
2910        .clone_dtoh(&d_out)
2911        .map_err(|err| GpuError::DriverCallFailed {
2912            reason: format!("bms_flex_row dense_block download: {err}"),
2913        })
2914}
2915
// Block 9 / V100-build unblock (2026-05-27): every test in the Linux-gated
// `mod tests` either constructs `BmsFlexBlockLayout` / `BmsFlexPrimaryLayout`
// (Linux-only types) or drives a CUDA-dependent fixture, so that whole module
// is gated `#[cfg(all(test, target_os = "linux"))]`. On macOS the structs are
// absent and those tests do not compile — the build.rs ban scanner
// explicitly rejects `#[cfg(any(..., test))]` on the struct definitions
// themselves as a dead-code escape hatch.
//
// #415 parity lock: the CPU host oracle for the BMS-FLEX row kernel lives in a
// non-linux-gated `#[cfg(test)]` module so it can be exercised on the macOS dev
// box + CPU CI (the sibling `mod tests` is linux-gated because it also builds
// CUDA-only fixture types). `cpu_oracle_outputs` itself is platform-independent.
2927#[cfg(test)]
2928mod oracle_parity_tests {
2929    use super::*;
2930
2931    // ── CPU oracle that mirrors ROW_KERNEL_BODY bit-for-bit ──────────────────
2932    //
2933    // `cpu_oracle_outputs` implements the same algebra as
2934    // `bms_flex_row_kernel` in ROW_KERNEL_BODY: per-cell `T_n` / `D` / `Q`
2935    // contractions, q-row override, IFT to `a_u` / `a_uv`, observed-point
2936    // assembly to `bar_e_u` / `bar_e_uv`, probit Mills, and the final
2937    // `out_grad` / `out_hess` writes. It takes the same
2938    // `BmsFlexRowKernelInputs` struct so a CUDA-equipped host can run both
2939    // paths off one bundle and check element-wise parity.
2940    //
2941    // Used by the GPU↔CPU parity test below; the test skips on non-Linux
2942    // hosts via cfg, but the oracle itself is platform-independent so the
2943    // macOS lib build can still type-check it.
2944
    /// `1 / (2π)`, written as `1 / τ`.
    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
    /// `√2`.
    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// `1 / √(2π)`, the standard normal density at `0`.
    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
2948
    /// Scaled complementary error function `erfcx(x) = exp(x²) · erfc(x)`,
    /// written for `x ≥ 0` (as the name says).
    ///
    /// Branches, in order:
    /// * non-finite `x`: `+∞` gives `0`; every other non-finite input
    ///   (`−∞` and NaN, for which `x > 0.0` is false) gives `+∞`;
    /// * `x ≤ 0`: returns `1.0` (the exact value only at `x = 0`);
    /// * `0 < x < 26`: the direct product `exp(x²) · erfc(x)`;
    /// * `x ≥ 26`: a five-term asymptotic series in `1/x²`.
    ///
    /// The statement order and the floating-point expression shapes are kept
    /// as-is because this is the CPU oracle that mirrors the device kernel.
    pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
        if !x.is_finite() {
            return if x > 0.0 { 0.0 } else { f64::INFINITY };
        }
        if x <= 0.0 {
            return 1.0;
        }
        if x < 26.0 {
            let mut xx = x * x;
            // Guard against `exp` overflow. On this branch `x < 26`, so
            // `x² ≤ 676` and the cap is not actually reached.
            if xx > 700.0 {
                xx = 700.0;
            }
            return xx.exp() * gam_gpu::numerics_host::erfc(x);
        }
        // erfcx(x) ≈ 1/(x√π) · (1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸)).
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        // 1 / √π.
        let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
        inv * poly * inv_sqrt_pi
    }
2970
    /// Returns `(log Φ(x), φ(x) / Φ(x))` for the standard normal: the log-CDF
    /// and the Mills ratio.
    ///
    /// Special inputs: `+∞` gives `(0, 0)`, `−∞` gives `(−∞, +∞)`, and NaN is
    /// returned in both slots. For `x ≥ ORACLE_LEFT_TAIL_X` both outputs come
    /// from one `erfc(−x/√2)` evaluation; below that threshold the erfcx form
    /// is used instead.
    pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
        if x == f64::INFINITY {
            return (0.0, 0.0);
        }
        if x == f64::NEG_INFINITY {
            return (f64::NEG_INFINITY, f64::INFINITY);
        }
        if x.is_nan() {
            return (x, x);
        }
        // Single-algorithm region around and above 0. Both `log Φ(x)` and the
        // Mills ratio `φ/Φ` are computed from the SAME `erfc(-x/√2)` call, so
        // the oracle is C¹ across the x=0 seam (the prior split — erfcx-based
        // `-u²+ln(0.5·e^{u²}·erfc u)` for x<0 vs direct `ln(0.5·erfc(-x))` for
        // x≥0 — used two distinct float algorithms whose ~1e-7 disagreement at
        // x=0 corrupted a finite-difference reference straddling the seam, #838).
        //
        // The erfcx form is mathematically identical (e^{u²}·erfc(u)=erfcx(u),
        // so −u²+ln(0.5·erfcx u)=ln(0.5·erfc(−x/√2))); it is only needed deep in
        // the left tail (x ≲ −38), where `erfc(-x/√2)` underflows to 0 and the
        // exp/ln cancellation is the *only* way to keep `log Φ` finite. We move
        // the branch there, far from any region the kernel or its FD lock visits.
        const ORACLE_LEFT_TAIL_X: f64 = -37.0;
        if x >= ORACLE_LEFT_TAIL_X {
            // Φ(x) = ½ · erfc(−x/√2), clamped into [1e-300, 1] so that `ln`
            // and the division below stay finite.
            let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
            if cdf < 1e-300 {
                cdf = 1e-300;
            }
            if cdf > 1.0 {
                cdf = 1.0;
            }
            // φ(x) = exp(−x²/2) / √(2π).
            let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
            (cdf.ln(), pdf / cdf)
        } else {
            // Left tail: with u = −x/√2 > 0, log Φ = −u² + ln(½ · erfcx(u)).
            let u = -x / ORACLE_SQRT_2;
            let mut ex = oracle_erfcx_nonnegative(u);
            // Floor erfcx so the `ln` and the division stay finite.
            if ex < 1e-300 {
                ex = 1e-300;
            }
            let log_cdf = -u * u + (0.5 * ex).ln();
            // φ/Φ = √(2/π) / erfcx(u), since exp(−x²/2) = exp(−u²).
            let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
            (log_cdf, sqrt_2_over_pi / ex)
        }
    }
3015
    /// Same outputs the device kernel writes: `(neglog, grad, hess)` per row.
    /// `grad` is row-major `n × r`, `hess` is row-major `n × r × r`.
    /// Mirrors `bms_flex_row_kernel` line-for-line so kernel + oracle diverge
    /// only if one side breaks parity.
    ///
    /// Per row the function (1) sweeps the row's cells
    /// `cell_offsets[row]..cell_offsets[row + 1]` to accumulate `F_a`, `F_aa`,
    /// `F_u`, `F_au`, `F_uv`, (2) applies the `q`-row overrides, (3) solves the
    /// implicit-function-theorem jets `a_u`, `a_uv`, (4) forms the observed
    /// predictor jets `bar_e_u`, `bar_e_uv`, and (5) applies the probit Mills
    /// layer to write the value, gradient and symmetric Hessian.
    ///
    /// A row whose `F_a` is non-finite or `≤ 0` has its value, gradient and
    /// Hessian filled with `NaN`.
    ///
    /// # Panics
    ///
    /// Panics if `inputs.cell_moments` is device-resident. The function does
    /// not call `validate()`; slices are indexed directly, so inconsistent
    /// lengths also panic.
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                // INVARIANT: this CPU oracle is a host-only sanity checker invoked
                // exclusively from `#[cfg(test)] mod tests`. The kernel-launch
                // path uses `CellMomentsSource::Device(...)`; the oracle must
                // never see that variant. Reaching this arm means a test
                // mis-wired its fixture — surface it loudly at the call site.
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // ── per-cell sweep: accumulate F_u, F_au, F_uv, F_a, F_aa.
            // Index 0 of `f_u` / `f_au` and row/column 0 of `f_uv` belong to the
            // q coordinate; the cell sweep never writes them (loops start at 1)
            // and they are set by the override below.
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                // Cubic predictor coefficients (C0..C3) of this cell.
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                // This cell's `MOMENT_STRIDE` derivative moments.
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // D(R) = κ · Σ_{k=0..3} R_k · m_k.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // Q(R, S) = Σ_{p,q=0..3} R_p · S_q · T_{p+q}, grouped by the
                // total degree p + q.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                // F_a += D(A_c); F_aa += D(AA_c) − Q(A_c, A_c).
                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // F_u and F_au for u > 0. `cell_r` / `cell_ar` hold 4
                // coefficients per (cell, coordinate) pair, with the q
                // coordinate omitted — hence the `r - 1` stride and `u - 1`.
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                // F_uv for 0 < u ≤ v (upper triangle only).
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        // The D(S_uv) term is only present when u is
                        // coordinate 1: paired with itself (`cell_sbb`), with
                        // one of the `p_h` coordinates starting at index 2
                        // (`cell_sbh`), or with one of the `p_w` coordinates
                        // after those (`cell_sbw`). Every other pair uses 0.
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // q-row overrides (mirror kernel lines 691–700).
            // F_q = −mu_1, F_aq = 0, F_qv = 0 for every v, then F_qq = −mu_2.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // Degenerate F_a ⇒ NaN-fill (mirror kernel lines 703–706).
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            // The single reciprocal shared by every IFT expression below.
            let inv_fa = 1.0 / f_a;

            // IFT first/second order.
            // a_u = −F_u / F_a; the q slot is written out explicitly as
            // mu_1 / F_a.
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            // a_uv = −(F_uv + F_av·a_u + F_au·a_v + F_aa·a_u·a_v) / F_a, computed
            // on the upper triangle and mirrored.
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Observed predictor jets.
            // bar_e_u  = chi·a_u + rho_u
            // bar_e_uv = chi·a_uv + xi·a_u·a_v + tau_u·a_v + a_u·tau_v + r_uv
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    // Only the upper-triangle entry `ruv[u * r + v]` is read.
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Probit Mills + final writes.
            let y = inputs.y[row];
            let w = inputs.w[row];
            // Label sign: y = 1 ↦ +1, y = 0 ↦ −1.
            let s = 2.0 * y - 1.0;
            // #415 parity: the observed predictor VALUE is packed directly
            // (`inputs.e_obs`), matching the CPU family's `signed_margin =
            // s_y * eta_val`. `bar_e_u[0]` is the u=0 first-derivative jet and
            // is used only for the gradient/Hessian, never as the Mills margin.
            let e_obs = inputs.e_obs[row];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            // Scalar weights of the outer layer: `a_i` multiplies the first
            // jets, `b_i` the outer product of first jets.
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            // Hessian: upper triangle computed, lower triangle mirrored.
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3230
3231    // #415 parity lock. This test lives HERE (a descendant of `bms::gpu::row`)
3232    // rather than in `bms::row_primary_hessian` because the host oracle
3233    // `cpu_oracle_outputs` must live in a PRIVATE `#[cfg(test)]` mod (the
3234    // build.rs ban-scanner forbids `#[cfg(test)]` on a non-private mod), so a
3235    // sibling module cannot reach it. Nested here, the test sees the private
3236    // oracle directly while the packer/CPU-family methods it drives are
3237    // `pub(in crate::bms)` and stay reachable. The nested module carries no
3238    // `#[cfg(test)]` attribute of its own (it inherits the parent's), so it is
3239    // ban-scanner-clean.
3240    mod parity_415 {
3241        //! #415 parity lock: the GPU-host oracle `cpu_oracle_outputs` (which
3242        //! GATES the device row kernel via
3243        //! `bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available`) must
3244        //! reproduce the CPU family reference
3245        //! `compute_row_analytic_flex_from_parts_into` element-for-element, from
3246        //! ONE fitted `(family, block_states, cache)`.
3247        //!
3248        //! Before this test the only ties between the two were (1) an FD lock on
3249        //! the outer scalar Mills layer and (2) a string-contains comment guard —
3250        //! neither pins the cell-contraction algebra (`F_a`, `F_aa`, `F_au`,
3251        //! `F_uv` → value/grad/Hessian). This closes that gap: the SAME fitted
3252        //! state is packed into `BmsFlexRowKernelInputs` and run through
3253        //! `cpu_oracle_outputs` for all rows, and independently run through the
3254        //! CPU family per row; every row value, full gradient, and full r×r
        //! Hessian must agree to within `1e-9` absolute plus `1e-10` relative.
3256
3257        use super::cpu_oracle_outputs;
3258        use crate::bms::family::*;
3259        use crate::bms::hessian_paths::*;
3260        use crate::bms::{DeviationBlockConfig, LatentMeasureKind, exact_kernel};
3261        use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix};
3262        use gam_problem::{InverseLink, ParameterBlockState, StandardLink};
3263        use ndarray::{Array1, Array2};
3264        use std::sync::{Arc, Mutex};
3265
3266        /// Build a small but REAL flex BMS family in the `StandardNormal`
3267        /// latent-measure branch with BOTH a score-warp (`p_h > 0`) and a
3268        /// link-deviation (`p_w > 0`) block active, plus mixed labels y ∈ {0,1}.
3269        /// Ported from the `gradient_paths` flex oracle fixture so the cache is
3270        /// populated by the production cell-moment assembly (never hand-faked).
3271        fn make_flex_parity_family(
3272            n: usize,
3273        ) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
3274            let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
3275            let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
3276            let cfg = DeviationBlockConfig {
3277                num_internal_knots: 3,
3278                ..DeviationBlockConfig::default()
3279            };
3280            let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
3281                .expect("build score warp block");
3282            let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
3283                &link_seed, &link_seed, &cfg,
3284            )
3285            .expect("build link deviation block");
3286
3287            // Mixed labels y ∈ {0,1} so both s_y = ±1 Mills branches are exercised.
3288            let y: Array1<f64> =
3289                Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
3290            let weights: Array1<f64> =
3291                Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
3292            let z: Array1<f64> =
3293                Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
3294            let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
3295                if j == 0 {
3296                    1.0
3297                } else {
3298                    -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
3299                }
3300            });
3301            let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
3302                if j == 0 {
3303                    1.0
3304                } else {
3305                    0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
3306                }
3307            });
3308
3309            let family = BernoulliMarginalSlopeFamily {
3310                y: Arc::new(y),
3311                weights: Arc::new(weights),
3312                z: Arc::new(z.clone()),
3313                latent_measure: LatentMeasureKind::StandardNormal,
3314                gaussian_frailty_sd: Some(0.15),
3315                base_link: InverseLink::Standard(StandardLink::Probit),
3316                marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
3317                logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
3318                score_warp: Some(score_prepared.runtime.clone()),
3319                link_dev: Some(link_prepared.runtime.clone()),
3320                policy: gam_runtime::resource::ResourcePolicy::default_library(),
3321                cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
3322                cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
3323                intercept_warm_starts: None,
3324                auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
3325                auto_subsample_last_rho: Arc::new(Mutex::new(None)),
3326            };
3327
3328            let beta_m = Array1::from_vec(vec![0.12, -0.04]);
3329            let beta_g = Array1::from_vec(vec![0.35, 0.03]);
3330            let beta_h = Array1::from_iter(
3331                (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
3332            );
3333            let beta_w = Array1::from_iter(
3334                (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
3335            );
3336            let states = vec![
3337                ParameterBlockState {
3338                    eta: marginal_x.dot(&beta_m),
3339                    beta: beta_m,
3340                },
3341                ParameterBlockState {
3342                    eta: logslope_x.dot(&beta_g),
3343                    beta: beta_g,
3344                },
3345                ParameterBlockState {
3346                    beta: beta_h,
3347                    eta: Array1::zeros(z.len()),
3348                },
3349                ParameterBlockState {
3350                    beta: beta_w,
3351                    eta: Array1::zeros(z.len()),
3352                },
3353            ];
3354            (family, states)
3355        }
3356
3357        /// The non-vacuous #415 lock: pack once, run the host oracle for all
3358        /// rows, run the CPU family per row, and assert value/gradient/Hessian
3359        /// parity.
3360        #[test]
3361        fn cpu_oracle_matches_cpu_family_row_analytic_flex_415() {
3362            let n = 12usize;
3363            let (family, states) = make_flex_parity_family(n);
3364            let cache = family
3365                .build_exact_eval_cache(&states)
3366                .expect("flex exact eval cache");
3367
3368            // Preconditions that make the lock non-vacuous: the row-cell-moments
3369            // bundle (which the oracle consumes) must actually be materialised,
3370            // and the deviation blocks must both be present so p_h > 0 AND p_w > 0.
3371            assert!(
3372                cache.row_cell_moments.is_some(),
3373                "#415 fixture must materialise the row-cell-moments bundle; the pack \
3374                 and both compared paths read it"
3375            );
3376            let primary = &cache.primary;
3377            let r = primary.total;
3378            let p_h = primary.h.as_ref().map(|range| range.len()).unwrap_or(0);
3379            let p_w = primary.w.as_ref().map(|range| range.len()).unwrap_or(0);
3380            assert!(
3381                p_h > 0 && p_w > 0,
3382                "#415 fixture must be full-flex: p_h={p_h} p_w={p_w}"
3383            );
3384            assert_eq!(r, 2 + p_h + p_w, "#415 fixture primary layout");
3385
3386            // Pack the SAME fitted state the CPU family will consume, then run the
3387            // GPU host oracle over every row.
3388            let owned = family
3389                .pack_bms_flex_row_kernel_inputs(&states, &cache)
3390                .expect("pack must not error")
3391                .expect("pack must succeed for the StandardNormal full-flex fixture");
3392            let inputs = owned.as_borrowed();
3393            let oracle = cpu_oracle_outputs(&inputs);
3394            assert_eq!(oracle.neglog.len(), n);
3395            assert_eq!(oracle.grad.len(), n * r);
3396            assert_eq!(oracle.hess.len(), n * r * r);
3397
3398            // Non-vacuity guard for the original #415/BMS-FLEX failure mode:
3399            // the Mills margin must be the packed observed predictor VALUE
3400            // `e_obs`, not the q-axis first derivative `bar_e_u[0]`. The oracle
3401            // does not expose `bar_e_u`, but its gradient obeys
3402            // `grad[0] = A(e_obs) · bar_e_u[0]`; recover that derivative from
3403            // the written output and prove this fixture separates it from
3404            // `e_obs` on at least one row. Otherwise a kernel/oracle that
3405            // accidentally substituted `bar_e_u[0]` for `e_obs` could pass a
3406            // vacuous fixture where both scalars coincide.
3407            let mut separates_observed_value_from_q_derivative = false;
3408            for row in 0..n {
3409                let y = inputs.y[row];
3410                let w = inputs.w[row];
3411                let s = 2.0 * y - 1.0;
3412                let e_obs = inputs.e_obs[row];
3413                let (_, lambda) = super::oracle_log_ndtr_and_mills(s * e_obs);
3414                let a_i = -w * s * lambda;
3415                if a_i.abs() > 1e-12 {
3416                    let recovered_bar_e_q = oracle.grad[row * r] / a_i;
3417                    if (recovered_bar_e_q - e_obs).abs() > 1e-8 {
3418                        separates_observed_value_from_q_derivative = true;
3419                        break;
3420                    }
3421                }
3422            }
3423            assert!(
3424                separates_observed_value_from_q_derivative,
3425                "#415 fixture must distinguish e_obs from bar_e_u[0]; otherwise \
3426                 the observed-value Mills-margin regression is not exercised"
3427            );
3428
3429            // Both sides are exact f64 CPU math over the SAME cached moments, so
3430            // the only slack is FP summation ordering. Anything looser than this
3431            // would hide a real algebraic drift.
3432            let tol_abs = 1e-9_f64;
3433            let tol_rel = 1e-10_f64;
3434
3435            let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(r);
3436            let mut max_rel = 0.0_f64;
3437            let mut checked_labels = [false, false];
3438
3439            for row in 0..n {
3440                let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
3441                let row_moments = cache
3442                    .row_cell_moments
3443                    .as_ref()
3444                    .and_then(|bundle| bundle.row(row, 9));
3445                assert!(
3446                    row_moments.is_some(),
3447                    "row {row} must carry degree-9 cell moments (the oracle reads them)"
3448                );
3449                let label = family.y[row] as usize;
3450                if label < 2 {
3451                    checked_labels[label] = true;
3452                }
3453
3454                let value = family
3455                    .compute_row_analytic_flex_into_with_moments(
3456                        row,
3457                        &states,
3458                        primary,
3459                        row_ctx,
3460                        row_moments,
3461                        cache.cell_family_forest.as_ref(),
3462                        true,
3463                        &mut scratch,
3464                    )
3465                    .expect("cpu family row analytic flex");
3466
3467                // ── value ────────────────────────────────────────────────────
3468                let o_val = oracle.neglog[row];
3469                if o_val.is_nan() || value.is_nan() {
3470                    assert!(
3471                        o_val.is_nan() && value.is_nan(),
3472                        "row {row}: NaN parity broke — oracle={o_val} family={value}"
3473                    );
3474                    continue;
3475                }
3476                let vd = (o_val - value).abs();
3477                let vtol = tol_abs + tol_rel * o_val.abs();
3478                max_rel = max_rel.max(vd / o_val.abs().max(1.0));
3479                assert!(
3480                    vd <= vtol,
3481                    "row {row} value drift: oracle={o_val:.17e} family={value:.17e} \
3482                     |Δ|={vd:.3e} > tol={vtol:.3e}"
3483                );
3484
3485                // ── gradient ─────────────────────────────────────────────────
3486                for u in 0..r {
3487                    let o_g = oracle.grad[row * r + u];
3488                    let f_g = scratch.grad[u];
3489                    let gd = (o_g - f_g).abs();
3490                    let gtol = tol_abs + tol_rel * o_g.abs();
3491                    max_rel = max_rel.max(gd / o_g.abs().max(1.0));
3492                    assert!(
3493                        gd <= gtol,
3494                        "row {row} grad[{u}] drift: oracle={o_g:.17e} family={f_g:.17e} \
3495                         |Δ|={gd:.3e} > tol={gtol:.3e}"
3496                    );
3497                }
3498
3499                // ── full r×r Hessian ─────────────────────────────────────────
3500                for u in 0..r {
3501                    for v in 0..r {
3502                        let o_h = oracle.hess[row * r * r + u * r + v];
3503                        let f_h = scratch.hess[[u, v]];
3504                        let hd = (o_h - f_h).abs();
3505                        let htol = tol_abs + tol_rel * o_h.abs();
3506                        max_rel = max_rel.max(hd / o_h.abs().max(1.0));
3507                        assert!(
3508                            hd <= htol,
3509                            "row {row} hess[{u},{v}] drift: oracle={o_h:.17e} \
3510                             family={f_h:.17e} |Δ|={hd:.3e} > tol={htol:.3e}"
3511                        );
3512                    }
3513                }
3514            }
3515
3516            // Edge coverage: both label branches must have been exercised (the
3517            // q-row overrides F_q=-mu_1 / F_qq=-mu_2 and both Mills sign branches).
3518            assert!(
3519                checked_labels[0] && checked_labels[1],
3520                "#415 fixture must exercise both y=0 and y=1 rows: {checked_labels:?}"
3521            );
3522            eprintln!(
3523                "#415 parity lock: n={n} r={r} p_h={p_h} p_w={p_w} max_rel(oracle−family)={max_rel:.3e}"
3524            );
3525        }
3526    }
3527}
3528
3529#[cfg(all(test, target_os = "linux"))]
3530mod tests {
3531    use super::oracle_parity_tests::*;
3532    use super::*;
3533
3534    pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3535        BmsFlexRowKernelInputs {
3536            n_rows: 1,
3537            r: 4,
3538            p_h: 1,
3539            p_w: 1,
3540            q: &buffers.q,
3541            b: &buffers.b,
3542            mu_1: &buffers.mu_1,
3543            mu_2: &buffers.mu_2,
3544            z_obs: &buffers.z_obs,
3545            y: &buffers.y,
3546            w: &buffers.w,
3547            e_obs: &buffers.e_obs,
3548            s_f: 1.0,
3549            cell_offsets: &buffers.cell_offsets,
3550            cell_c0: &buffers.cell_c0,
3551            cell_c1: &buffers.cell_c1,
3552            cell_c2: &buffers.cell_c2,
3553            cell_c3: &buffers.cell_c3,
3554            cell_a: &buffers.cell_a,
3555            cell_aa: &buffers.cell_aa,
3556            cell_r: &buffers.cell_r,
3557            cell_ar: &buffers.cell_ar,
3558            cell_sbb: &buffers.cell_sbb,
3559            cell_sbh: &buffers.cell_sbh,
3560            cell_sbw: &buffers.cell_sbw,
3561            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3562            chi_obs: &buffers.chi_obs,
3563            xi_obs: &buffers.xi_obs,
3564            rho_u: &buffers.rho_u,
3565            tau_u: &buffers.tau_u,
3566            r_uv: &buffers.r_uv,
3567        }
3568    }
3569
    /// Owned backing storage for the slices that `minimal_inputs` borrows into
    /// a `BmsFlexRowKernelInputs`. `make_buffers` fills it for a single row;
    /// the lengths noted below are the ones `make_buffers` produces.
    pub(crate) struct TestBuffers {
        // ── per-row values (one element each) ──
        pub(crate) q: Vec<f64>,
        pub(crate) b: Vec<f64>,
        pub(crate) mu_1: Vec<f64>,
        pub(crate) mu_2: Vec<f64>,
        pub(crate) z_obs: Vec<f64>,
        pub(crate) y: Vec<f64>,
        pub(crate) w: Vec<f64>,
        pub(crate) e_obs: Vec<f64>,
        // Row pointer into the cell arrays: `[0, n_cells]` for the single row.
        pub(crate) cell_offsets: Vec<u32>,
        // ── per-cell values (one element per cell) ──
        pub(crate) cell_c0: Vec<f64>,
        pub(crate) cell_c1: Vec<f64>,
        pub(crate) cell_c2: Vec<f64>,
        pub(crate) cell_c3: Vec<f64>,
        // ── per-cell groups of 4 ──
        pub(crate) cell_a: Vec<f64>,
        pub(crate) cell_aa: Vec<f64>,
        // `cells * (r - 1) * 4` elements each.
        pub(crate) cell_r: Vec<f64>,
        pub(crate) cell_ar: Vec<f64>,
        // `cells * 4`, `cells * p_h * 4` and `cells * p_w * 4` elements.
        pub(crate) cell_sbb: Vec<f64>,
        pub(crate) cell_sbh: Vec<f64>,
        pub(crate) cell_sbw: Vec<f64>,
        // `cells * MOMENT_STRIDE` elements.
        pub(crate) cell_moments: Vec<f64>,
        // ── observed-predictor values: one element, one element, `r`, `r`, `r * r` ──
        pub(crate) chi_obs: Vec<f64>,
        pub(crate) xi_obs: Vec<f64>,
        pub(crate) rho_u: Vec<f64>,
        pub(crate) tau_u: Vec<f64>,
        pub(crate) r_uv: Vec<f64>,
    }
3598
3599    pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
3600        let cells = n_cells as usize;
3601        TestBuffers {
3602            q: vec![0.1; 1],
3603            b: vec![0.5; 1],
3604            mu_1: vec![0.3; 1],
3605            mu_2: vec![0.07; 1],
3606            z_obs: vec![0.0; 1],
3607            y: vec![1.0; 1],
3608            w: vec![1.0; 1],
3609            e_obs: vec![0.15; 1],
3610            cell_offsets: vec![0, n_cells],
3611            cell_c0: vec![0.2; cells],
3612            cell_c1: vec![-0.1; cells],
3613            cell_c2: vec![0.05; cells],
3614            cell_c3: vec![-0.02; cells],
3615            cell_a: vec![0.1; cells * 4],
3616            cell_aa: vec![0.0; cells * 4],
3617            cell_r: vec![0.05; cells * (r - 1) * 4],
3618            cell_ar: vec![0.0; cells * (r - 1) * 4],
3619            cell_sbb: vec![0.0; cells * 4],
3620            cell_sbh: vec![0.0; cells * p_h * 4],
3621            cell_sbw: vec![0.0; cells * p_w * 4],
3622            cell_moments: vec![1.0; cells * MOMENT_STRIDE],
3623            chi_obs: vec![1.0; 1],
3624            xi_obs: vec![0.0; 1],
3625            rho_u: vec![0.0; r],
3626            tau_u: vec![0.0; r],
3627            r_uv: vec![0.0; r * r],
3628        }
3629    }
3630
3631    #[test]
3632    pub(crate) fn validate_accepts_minimal_inputs() {
3633        let buffers = make_buffers(2, 4, 1, 1);
3634        let inputs = minimal_inputs(&buffers);
3635        assert!(inputs.validate().is_ok());
3636    }
3637
3638    #[test]
3639    pub(crate) fn validate_rejects_r_above_max() {
3640        let r = MAX_R + 1;
3641        let p_h = (r - 2) / 2;
3642        let p_w = (r - 2) - p_h;
3643        let buffers = make_buffers(1, r, p_h, p_w);
3644        let bad_inputs = BmsFlexRowKernelInputs {
3645            r,
3646            p_h,
3647            p_w,
3648            rho_u: &buffers.rho_u, // length matches `r` we wrote
3649            tau_u: &buffers.tau_u,
3650            r_uv: &buffers.r_uv,
3651            cell_r: &buffers.cell_r,
3652            cell_ar: &buffers.cell_ar,
3653            cell_sbh: &buffers.cell_sbh,
3654            cell_sbw: &buffers.cell_sbw,
3655            ..minimal_inputs(&buffers)
3656        };
3657        let err = bad_inputs.validate().expect_err("r > MAX_R must fail");
3658        let msg = err.to_string();
3659        assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
3660    }
3661
3662    #[test]
3663    pub(crate) fn validate_rejects_mismatched_r_decomposition() {
3664        let buffers = make_buffers(1, 4, 1, 1);
3665        let bad_inputs = BmsFlexRowKernelInputs {
3666            r: 4,
3667            p_h: 1,
3668            p_w: 2, // inconsistent with r = 4
3669            ..minimal_inputs(&buffers)
3670        };
3671        let err = bad_inputs
3672            .validate()
3673            .expect_err("inconsistent r vs p_h+p_w must fail");
3674        let msg = err.to_string();
3675        assert!(msg.contains("p_h"), "got: {msg}");
3676        assert!(msg.contains("p_w"), "got: {msg}");
3677    }
3678
3679    #[test]
3680    pub(crate) fn validate_rejects_non_monotone_offsets() {
3681        // `minimal_inputs` hard-codes `n_rows = 1`, so the CSR-style row
3682        // pointer length is `n + 1 = 2`. Pin both `offsets[1] = total_cells`
3683        // and `cell_c0.len() = total_cells = 2` from `make_buffers(2, …)`,
3684        // then violate monotonicity by setting `offsets[0] > offsets[1]`;
3685        // every length / per-cell-count check is satisfied so the only
3686        // failure mode left is the monotonicity guard.
3687        let mut buffers = make_buffers(2, 4, 1, 1);
3688        buffers.cell_offsets = vec![5, 2];
3689        let inputs = minimal_inputs(&buffers);
3690        let err = inputs
3691            .validate()
3692            .expect_err("non-monotone offsets must fail");
3693        let msg = err.to_string();
3694        assert!(msg.contains("monotone"), "got: {msg}");
3695    }
3696
3697    #[test]
3698    pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
3699        let mut buffers = make_buffers(2, 4, 1, 1);
3700        buffers.cell_moments.pop(); // length now 2*10 - 1
3701        let inputs = minimal_inputs(&buffers);
3702        let err = inputs.validate().expect_err("short cell_moments must fail");
3703        let msg = err.to_string();
3704        assert!(msg.contains("cell_moments"), "got: {msg}");
3705    }
3706
3707    #[test]
3708    pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
3709        // Mac/Windows builds must surface a typed `DriverLibraryUnavailable`
3710        // rather than panicking or returning Ok. On Linux this test is
3711        // skipped because the kernel actually launches.
3712        #[cfg(target_os = "linux")]
3713        {
3714            // Linux builds may or may not have a device; the dispatcher
3715            // contract is that without a runtime, probe() returns
3716            // DriverLibraryUnavailable. Either outcome (NoDeviceKernel,
3717            // DriverLibraryUnavailable, or DriverCallFailed) is acceptable
3718            // here; success would mean the kernel actually ran which is a
3719            // V100-only outcome we don't gate the unit test on.
3720            let buffers = make_buffers(1, 4, 1, 1);
3721            let inputs = minimal_inputs(&buffers);
3722            match launch_bms_flex_row_kernel(inputs) {
3723                Ok(_) => { /* V100 host: real launch */ }
3724                Err(GpuError::DriverLibraryUnavailable { .. })
3725                | Err(GpuError::DriverCallFailed { .. })
3726                | Err(GpuError::DriverSymbolMissing { .. })
3727                | Err(GpuError::NoDeviceKernel { .. }) => { /* expected on CPU-only */ }
3728                Err(other) => panic!("unexpected GpuError variant: {other:?}"),
3729            }
3730        }
3731        #[cfg(not(target_os = "linux"))]
3732        {
3733            let buffers = make_buffers(1, 4, 1, 1);
3734            let inputs = minimal_inputs(&buffers);
3735            match launch_bms_flex_row_kernel(inputs) {
3736                Err(GpuError::DriverLibraryUnavailable { reason }) => {
3737                    assert!(
3738                        reason.contains("Linux-only"),
3739                        "expected Linux-only hint, got: {reason}"
3740                    );
3741                }
3742                other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
3743            }
3744        }
3745    }
3746
3747    #[test]
3748    pub(crate) fn s_f_must_be_positive_and_finite() {
3749        let buffers = make_buffers(1, 4, 1, 1);
3750        let mut inputs = minimal_inputs(&buffers);
3751        inputs.s_f = 0.0;
3752        match launch_bms_flex_row_kernel(inputs) {
3753            Err(GpuError::DriverCallFailed { reason }) => {
3754                assert!(reason.contains("s_f"), "got: {reason}");
3755            }
3756            other => panic!("expected DriverCallFailed for s_f=0, got {other:?}"),
3757        }
3758    }
3759
3760    /// Build a non-trivial fixture: `n = 4` rows, `r = 5` (p_h = 2, p_w = 1),
3761    /// 2–4 cells per row, distinct values so a structural bug in either path
3762    /// can't be masked by accidental cancellation.
3763    pub(crate) fn make_parity_buffers() -> TestBuffers {
3764        let n = 4_usize;
3765        let r = 5_usize;
3766        let p_h = 2_usize;
3767        let p_w = 1_usize;
3768        // Per-row cell counts: 2, 3, 4, 2 → total 11 cells.
3769        let row_cells: [u32; 4] = [2, 3, 4, 2];
3770        let mut cell_offsets = vec![0_u32; n + 1];
3771        for i in 0..n {
3772            cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
3773        }
3774        let total_cells = cell_offsets[n] as usize;
3775
3776        // Deterministic but varied generators (LCG-ish so each slot is distinct).
3777        let f = |seed: usize| -> f64 {
3778            let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
3779            0.1 + 0.4 * x
3780        };
3781
3782        let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
3783        let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
3784        let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
3785        let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
3786        let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
3787        let y = [1.0, 0.0, 1.0, 0.0].to_vec();
3788        let w = vec![1.0; n];
3789        let e_obs = (0..n).map(|i| -0.3 + 0.2 * (i as f64)).collect::<Vec<_>>();
3790
3791        let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
3792        let cell_c1 = (0..total_cells)
3793            .map(|c| -f(c + 2002) * 0.5)
3794            .collect::<Vec<_>>();
3795        let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
3796        let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();
3797
3798        let cell_a = (0..total_cells * 4)
3799            .map(|i| f(i + 5005) * 0.3)
3800            .collect::<Vec<_>>();
3801        let cell_aa = (0..total_cells * 4)
3802            .map(|i| f(i + 6006) * 0.1)
3803            .collect::<Vec<_>>();
3804        let cell_r = (0..total_cells * (r - 1) * 4)
3805            .map(|i| f(i + 7007) * 0.2)
3806            .collect::<Vec<_>>();
3807        let cell_ar = (0..total_cells * (r - 1) * 4)
3808            .map(|i| f(i + 8008) * 0.05)
3809            .collect::<Vec<_>>();
3810        let cell_sbb = (0..total_cells * 4)
3811            .map(|i| f(i + 9009) * 0.08)
3812            .collect::<Vec<_>>();
3813        let cell_sbh = (0..total_cells * p_h * 4)
3814            .map(|i| f(i + 10_010) * 0.07)
3815            .collect::<Vec<_>>();
3816        let cell_sbw = (0..total_cells * p_w * 4)
3817            .map(|i| f(i + 11_011) * 0.06)
3818            .collect::<Vec<_>>();
3819        let cell_moments = (0..total_cells * MOMENT_STRIDE)
3820            .map(|i| 0.4 + 0.1 * f(i + 12_012))
3821            .collect::<Vec<_>>();
3822
3823        let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
3824        let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
3825        let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
3826        let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
3827        let r_uv = (0..n * r * r)
3828            .map(|i| 0.04 * f(i + 15_015))
3829            .collect::<Vec<_>>();
3830
3831        TestBuffers {
3832            q,
3833            b,
3834            mu_1,
3835            mu_2,
3836            z_obs,
3837            y,
3838            w,
3839            e_obs,
3840            cell_offsets,
3841            cell_c0,
3842            cell_c1,
3843            cell_c2,
3844            cell_c3,
3845            cell_a,
3846            cell_aa,
3847            cell_r,
3848            cell_ar,
3849            cell_sbb,
3850            cell_sbh,
3851            cell_sbw,
3852            cell_moments,
3853            chi_obs,
3854            xi_obs,
3855            rho_u,
3856            tau_u,
3857            r_uv,
3858        }
3859    }
3860
3861    pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3862        BmsFlexRowKernelInputs {
3863            n_rows: 4,
3864            r: 5,
3865            p_h: 2,
3866            p_w: 1,
3867            q: &buffers.q,
3868            b: &buffers.b,
3869            mu_1: &buffers.mu_1,
3870            mu_2: &buffers.mu_2,
3871            z_obs: &buffers.z_obs,
3872            y: &buffers.y,
3873            w: &buffers.w,
3874            e_obs: &buffers.e_obs,
3875            s_f: 1.0,
3876            cell_offsets: &buffers.cell_offsets,
3877            cell_c0: &buffers.cell_c0,
3878            cell_c1: &buffers.cell_c1,
3879            cell_c2: &buffers.cell_c2,
3880            cell_c3: &buffers.cell_c3,
3881            cell_a: &buffers.cell_a,
3882            cell_aa: &buffers.cell_aa,
3883            cell_r: &buffers.cell_r,
3884            cell_ar: &buffers.cell_ar,
3885            cell_sbb: &buffers.cell_sbb,
3886            cell_sbh: &buffers.cell_sbh,
3887            cell_sbw: &buffers.cell_sbw,
3888            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3889            chi_obs: &buffers.chi_obs,
3890            xi_obs: &buffers.xi_obs,
3891            rho_u: &buffers.rho_u,
3892            tau_u: &buffers.tau_u,
3893            r_uv: &buffers.r_uv,
3894        }
3895    }
3896
3897    /// Symmetry + finiteness of the CPU oracle. Runs on every host (Linux,
3898    /// macOS, CPU CI) since the oracle is platform-independent. Guarantees the
3899    /// reference path used by the GPU parity test is itself well-formed.
3900    #[test]
3901    pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
3902        let buffers = make_parity_buffers();
3903        let inputs = parity_inputs(&buffers);
3904        inputs
3905            .validate()
3906            .expect("parity fixture must satisfy validate()");
3907        let out = cpu_oracle_outputs(&inputs);
3908        let n = inputs.n_rows;
3909        let r = inputs.r;
3910        assert_eq!(out.neglog.len(), n);
3911        assert_eq!(out.grad.len(), n * r);
3912        assert_eq!(out.hess.len(), n * r * r);
3913        for row in 0..n {
3914            assert!(
3915                out.neglog[row].is_finite(),
3916                "row {row}: neglog must be finite, got {}",
3917                out.neglog[row]
3918            );
3919            for u in 0..r {
3920                let g = out.grad[row * r + u];
3921                assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
3922                for v in 0..r {
3923                    let huv = out.hess[row * r * r + u * r + v];
3924                    let hvu = out.hess[row * r * r + v * r + u];
3925                    assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
3926                    assert_eq!(
3927                        huv.to_bits(),
3928                        hvu.to_bits(),
3929                        "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
3930                    );
3931                }
3932            }
3933        }
3934    }
3935
3936    /// Independent finite-difference correctness lock on the oracle's probit
3937    /// Mills layer — the most optimizer-sensitive, drift-prone term in the
3938    /// whole row kernel (issue #415: "third/fourth-order derivative
3939    /// contractions drift silently … formulas are complex and
3940    /// optimizer-sensitive"). The device kernel's `bms_flex_row_kernel` and
3941    /// the host oracle's `cpu_oracle_outputs` both close out with the same
3942    /// Mills algebra:
3943    ///
3944    /// ```text
3945    ///     m       = s · e_obs ;  s = 2y − 1
3946    ///     A       = −w · s · λ(m)
3947    ///     B       =  w · λ(m) · (m + λ(m))
3948    ///     neglog  = −w · log Φ(s · e_obs)
3949    ///     g_u     = A · bar_e_u
3950    ///     H_uv    = B · bar_e_u · bar_e_v + A · bar_e_uv
3951    /// ```
3952    ///
3953    /// Holding the observed derivative jets `bar_e_u`/`bar_e_uv` fixed, the
3954    /// row neglog is a function of the observed predictor VALUE `e := e_obs`,
3955    /// not of the q-axis first derivative `bar_e_u[0]`; by the assembled
3956    /// formula `∂neglog/∂e = A` and `∂²neglog/∂e² = B`. This test reconstructs
3957    /// `A`, `B`, and `neglog` exactly as the oracle does (same
3958    /// `oracle_log_ndtr_and_mills`, same sign convention), then verifies the
3959    /// analytic `A`/`B` against high-order central differences of
3960    /// `e ↦ −w · log Φ(s·e)`. A drift in the kernel's Mills derivatives —
3961    /// which the device-parity test cannot catch because it checks the kernel
3962    /// *against the same (possibly-wrong) oracle* — fails here on every host,
3963    /// CUDA or not. Bounds are the genuine fifth-order central-difference
3964    /// truncation floor; they are not weakened to pass.
3965    #[test]
3966    pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
3967        // Probit neglog as the oracle assembles it, as a function of the
3968        // observed scalar predictor `e` with weight `w` and label `y`.
3969        let neglog_of = |e: f64, y: f64, w: f64| -> f64 {
3970            let s = 2.0 * y - 1.0;
3971            let (log_cdf, _) = oracle_log_ndtr_and_mills(s * e);
3972            -w * log_cdf
3973        };
3974        // Analytic first/second derivatives wrt `e` — the exact `A`/`B` the
3975        // kernel writes into `grad`/`hess`.
3976        let ab_of = |e: f64, y: f64, w: f64| -> (f64, f64) {
3977            let s = 2.0 * y - 1.0;
3978            let m_arg = s * e;
3979            let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
3980            let a_i = -w * s * lambda;
3981            let b_i = w * lambda * (m_arg + lambda);
3982            (a_i, b_i)
3983        };
3984
3985        // Sweep both labels (s = ±1), both tails of the predictor, and a
3986        // non-unit weight so every sign/scale path of the Mills algebra is
3987        // exercised. Points stay clear of the deep-tail asymptote where a
3988        // central-difference reference loses its own accuracy.
3989        let cases: [(f64, f64, f64); 12] = [
3990            (-1.6, 1.0, 1.0),
3991            (-0.7, 1.0, 1.0),
3992            (0.0, 1.0, 1.0),
3993            (0.9, 1.0, 1.0),
3994            (1.8, 1.0, 1.0),
3995            (-1.4, 0.0, 1.0),
3996            (-0.3, 0.0, 1.0),
3997            (0.0, 0.0, 1.0),
3998            (0.6, 0.0, 1.0),
3999            (1.5, 0.0, 1.0),
4000            (0.4, 1.0, 0.75),
4001            (-0.8, 0.0, 1.3),
4002        ];
4003        // Fifth-order central stencils; `h` chosen near the f64 sweet spot for
4004        // first/second derivatives of a smooth O(1) function.
4005        let h = 1e-3_f64;
4006        for (e, y, w) in cases {
4007            let (a_ana, b_ana) = ab_of(e, y, w);
4008
4009            let fp2 = neglog_of(e + 2.0 * h, y, w);
4010            let fp1 = neglog_of(e + h, y, w);
4011            let f0 = neglog_of(e, y, w);
4012            let fm1 = neglog_of(e - h, y, w);
4013            let fm2 = neglog_of(e - 2.0 * h, y, w);
4014
4015            // 5-point central first derivative: O(h⁴).
4016            let d1_fd = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * h);
4017            // 5-point central second derivative: O(h⁴).
4018            let d2_fd = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * h * h);
4019
4020            let a_abs = (a_ana - d1_fd).abs();
4021            let a_rel = a_abs / a_ana.abs().max(1.0);
4022            assert!(
4023                a_abs <= 5e-8 || a_rel <= 5e-8,
4024                "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
4025                 analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
4026            );
4027
4028            let b_abs = (b_ana - d2_fd).abs();
4029            let b_rel = b_abs / b_ana.abs().max(1.0);
4030            assert!(
4031                b_abs <= 5e-6 || b_rel <= 5e-6,
4032                "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
4033                 analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
4034            );
4035        }
4036    }
4037
4038    /// CPU↔GPU parity. Only runs end-to-end on a Linux host with a CUDA
4039    /// runtime; skips with a clear `eprintln!` on every other host so the
4040    /// always-on test suite stays green on the macOS dev box and CPU CI.
4041    ///
4042    /// On a CUDA host: drives the kernel through `launch_bms_flex_row_kernel`
4043    /// and the same `BmsFlexRowKernelInputs` through `cpu_oracle_outputs`,
4044    /// then asserts every element of `neglog`, `grad`, and `hess` agrees
4045    /// within `|Δ| <= 1e-8 + 1e-8·|cpu|` (absolute-or-relative).
4046    #[test]
4047    pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
4048        #[cfg(not(target_os = "linux"))]
4049        {
4050            eprintln!(
4051                "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
4052                 (CPU oracle exercised by sibling test)"
4053            );
4054            return;
4055        }
4056        #[cfg(target_os = "linux")]
4057        {
4058            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4059                eprintln!(
4060                    "[bms_flex_row parity] no CUDA runtime — skipping device \
4061                     parity (CPU oracle exercised by sibling test)"
4062                );
4063                return;
4064            };
4065            let buffers = make_parity_buffers();
4066            let inputs_cpu = parity_inputs(&buffers);
4067            inputs_cpu
4068                .validate()
4069                .expect("parity fixture must satisfy validate()");
4070            let cpu_out = cpu_oracle_outputs(&inputs_cpu);
4071
4072            // Launch the device kernel against the same inputs.
4073            let inputs_gpu = parity_inputs(&buffers);
4074            let gpu_out = match launch_bms_flex_row_kernel(inputs_gpu) {
4075                Ok(out) => out,
4076                Err(err) => panic!(
4077                    "[bms_flex_row parity] launch failed on CUDA-selected host; \
4078                     device/oracle parity must fail loudly on GPU CI: {err}"
4079                ),
4080            };
4081
4082            let n = inputs_cpu.n_rows;
4083            let r = inputs_cpu.r;
4084            let tol_abs = 1e-8_f64;
4085            let tol_rel = 1e-8_f64;
4086            let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
4087                if cpu.is_nan() || gpu.is_nan() {
4088                    assert!(
4089                        cpu.is_nan() && gpu.is_nan(),
4090                        "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
4091                    );
4092                    return;
4093                }
4094                let diff = (cpu - gpu).abs();
4095                let tol = tol_abs + tol_rel * cpu.abs();
4096                assert!(
4097                    diff <= tol,
4098                    "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
4099                     cpu={cpu:.17e}, gpu={gpu:.17e}"
4100                );
4101            };
4102            assert_eq!(cpu_out.neglog.len(), gpu_out.neglog.len());
4103            assert_eq!(cpu_out.grad.len(), gpu_out.grad.len());
4104            assert_eq!(cpu_out.hess.len(), gpu_out.hess.len());
4105            for (i, (&c, &g)) in cpu_out.neglog.iter().zip(gpu_out.neglog.iter()).enumerate() {
4106                check_close("neglog", i, c, g);
4107            }
4108            for (i, (&c, &g)) in cpu_out.grad.iter().zip(gpu_out.grad.iter()).enumerate() {
4109                check_close("grad", i, c, g);
4110            }
4111            for (i, (&c, &g)) in cpu_out.hess.iter().zip(gpu_out.hess.iter()).enumerate() {
4112                check_close("hess", i, c, g);
4113            }
4114            // Spot-check exact symmetry on the GPU Hessian too.
4115            for row in 0..n {
4116                for u in 0..r {
4117                    for v in 0..r {
4118                        let a = gpu_out.hess[row * r * r + u * r + v];
4119                        let bb = gpu_out.hess[row * r * r + v * r + u];
4120                        assert_eq!(
4121                            a.to_bits(),
4122                            bb.to_bits(),
4123                            "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
4124                        );
4125                    }
4126                }
4127            }
4128        }
4129    }
4130
4131    #[test]
4132    pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
4133        // Guarantee the maintainer-facing parity reference comment survives
4134        // refactors of the NVRTC kernel source — the dispatcher wave that
4135        // wires this to bms_flex.rs cross-checks parity against the CPU
4136        // function named here.
4137        #[cfg(target_os = "linux")]
4138        assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
4139        #[cfg(target_os = "linux")]
4140        assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
4141    }
4142
4143    // ── Phase-3 HVP / diagonal CPU oracles + GPU parity tests ────────────────
4144
4145    /// CPU oracle for [`launch_bms_flex_row_hvp`]. Mirrors the device kernel
4146    /// element-for-element so the GPU parity test runs against the same algebra.
4147    pub(crate) fn cpu_oracle_bms_flex_row_hvp(
4148        row_hessians: &[f64],
4149        marginal_design: &[f64],
4150        logslope_design: &[f64],
4151        block: &BmsFlexBlockLayout,
4152        primary: &BmsFlexPrimaryLayout,
4153        n: usize,
4154        v: &[f64],
4155    ) -> Vec<f64> {
4156        let r = primary.r;
4157        let p_m = block.p_m;
4158        let p_g = block.p_g;
4159        assert_eq!(v.len(), block.p_total);
4160        assert_eq!(row_hessians.len(), n * r * r);
4161        assert_eq!(marginal_design.len(), n * p_m);
4162        assert_eq!(logslope_design.len(), n * p_g);
4163        let mut out = vec![0.0_f64; block.p_total];
4164        let mut row_dir = vec![0.0_f64; r];
4165        let mut action = vec![0.0_f64; r];
4166        for row in 0..n {
4167            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4168            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4169            let mut acc_q = 0.0_f64;
4170            for j in 0..p_m {
4171                acc_q += mrow[j] * v[j];
4172            }
4173            let mut acc_g = 0.0_f64;
4174            for j in 0..p_g {
4175                acc_g += grow[j] * v[p_m + j];
4176            }
4177            row_dir[0] = acc_q;
4178            row_dir[1] = acc_g;
4179            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4180                for (k, ii) in prange.clone().enumerate() {
4181                    row_dir[ii] = v[brange.start + k];
4182                }
4183            }
4184            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4185                for (k, ii) in prange.clone().enumerate() {
4186                    row_dir[ii] = v[brange.start + k];
4187                }
4188            }
4189            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4190            for u in 0..r {
4191                let mut acc = 0.0_f64;
4192                for v_idx in 0..r {
4193                    acc += h_slice[u * r + v_idx] * row_dir[v_idx];
4194                }
4195                action[u] = acc;
4196            }
4197            let a0 = action[0];
4198            for j in 0..p_m {
4199                out[j] += a0 * mrow[j];
4200            }
4201            let a1 = action[1];
4202            for j in 0..p_g {
4203                out[p_m + j] += a1 * grow[j];
4204            }
4205            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4206                for (k, ii) in prange.clone().enumerate() {
4207                    out[brange.start + k] += action[ii];
4208                }
4209            }
4210            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4211                for (k, ii) in prange.clone().enumerate() {
4212                    out[brange.start + k] += action[ii];
4213                }
4214            }
4215        }
4216        out
4217    }
4218
4219    pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
4220        row_hessians: &[f64],
4221        marginal_design: &[f64],
4222        logslope_design: &[f64],
4223        block: &BmsFlexBlockLayout,
4224        primary: &BmsFlexPrimaryLayout,
4225        n: usize,
4226    ) -> Vec<f64> {
4227        let r = primary.r;
4228        let p_m = block.p_m;
4229        let p_g = block.p_g;
4230        let mut out = vec![0.0_f64; block.p_total];
4231        for row in 0..n {
4232            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4233            let h00 = h_slice[0];
4234            let h11 = h_slice[r + 1];
4235            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4236            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4237            for j in 0..p_m {
4238                out[j] += h00 * mrow[j] * mrow[j];
4239            }
4240            for j in 0..p_g {
4241                out[p_m + j] += h11 * grow[j] * grow[j];
4242            }
4243            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4244                for (k, ii) in prange.clone().enumerate() {
4245                    out[brange.start + k] += h_slice[ii * r + ii];
4246                }
4247            }
4248            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4249                for (k, ii) in prange.clone().enumerate() {
4250                    out[brange.start + k] += h_slice[ii * r + ii];
4251                }
4252            }
4253        }
4254        out
4255    }
4256
4257    /// Hand-construct a small symmetric per-row Hessian + small designs and
4258    /// verify the CPU oracle satisfies the expected algebra. Platform-
4259    /// independent (runs on macOS / Linux without CUDA).
4260    #[test]
4261    pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
4262        let n = 4_usize;
4263        let r = 4_usize; // q, logslope, h(1), w(1)
4264        let p_m = 2_usize;
4265        let p_g = 2_usize;
4266        let p_h_dim = 1_usize;
4267        let p_w_dim = 1_usize;
4268        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4269        let block = BmsFlexBlockLayout {
4270            p_m,
4271            p_g,
4272            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4273            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4274            p_total,
4275        };
4276        let primary = BmsFlexPrimaryLayout {
4277            h: Some(2..3),
4278            w: Some(3..4),
4279            r,
4280        };
4281        // Symmetric per-row Hessian: H_row[u,v] = (row + 1) * (1 + u + 2v) symmetrised.
4282        let mut row_hessians = vec![0.0_f64; n * r * r];
4283        for row in 0..n {
4284            for u in 0..r {
4285                for v in u..r {
4286                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4287                    row_hessians[row * r * r + u * r + v] = val;
4288                    row_hessians[row * r * r + v * r + u] = val;
4289                }
4290            }
4291        }
4292        let mut marginal = vec![0.0_f64; n * p_m];
4293        for row in 0..n {
4294            for j in 0..p_m {
4295                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4296            }
4297        }
4298        let mut logslope = vec![0.0_f64; n * p_g];
4299        for row in 0..n {
4300            for j in 0..p_g {
4301                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4302            }
4303        }
4304        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4305        let out = cpu_oracle_bms_flex_row_hvp(
4306            &row_hessians,
4307            &marginal,
4308            &logslope,
4309            &block,
4310            &primary,
4311            n,
4312            &v,
4313        );
4314        // Hand check the first marginal slot: out[0] = Σ_row action[0]·mrow[0].
4315        let mut expect_out_0 = 0.0_f64;
4316        for row in 0..n {
4317            let mrow = &marginal[row * p_m..(row + 1) * p_m];
4318            let grow = &logslope[row * p_g..(row + 1) * p_g];
4319            let mut row_dir = vec![0.0_f64; r];
4320            row_dir[0] = mrow[0] * v[0] + mrow[1] * v[1];
4321            row_dir[1] = grow[0] * v[p_m] + grow[1] * v[p_m + 1];
4322            row_dir[2] = v[p_m + p_g];
4323            row_dir[3] = v[p_m + p_g + p_h_dim];
4324            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4325            let mut action0 = 0.0_f64;
4326            // h_slice is the row-major r×r Hessian for this row; we want
4327            // row 0, i.e. entries (0, vv) for vv in 0..r, which lives at
4328            // `vv` in the flat layout.
4329            for vv in 0..r {
4330                action0 += h_slice[vv] * row_dir[vv];
4331            }
4332            expect_out_0 += action0 * mrow[0];
4333        }
4334        assert!(
4335            (out[0] - expect_out_0).abs() < 1e-12,
4336            "cpu oracle HVP out[0] mismatch: {} vs hand-check {}",
4337            out[0],
4338            expect_out_0
4339        );
4340        assert!(out.iter().all(|x| x.is_finite()));
4341        assert_eq!(out.len(), p_total);
4342    }
4343
4344    /// Diagonal oracle equals the explicit per-row design² accumulator.
4345    #[test]
4346    pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
4347        let n = 3_usize;
4348        let r = 4_usize;
4349        let p_m = 2_usize;
4350        let p_g = 2_usize;
4351        let p_h_dim = 1_usize;
4352        let p_w_dim = 1_usize;
4353        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4354        let block = BmsFlexBlockLayout {
4355            p_m,
4356            p_g,
4357            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4358            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4359            p_total,
4360        };
4361        let primary = BmsFlexPrimaryLayout {
4362            h: Some(2..3),
4363            w: Some(3..4),
4364            r,
4365        };
4366        let mut row_hessians = vec![0.0_f64; n * r * r];
4367        for row in 0..n {
4368            for u in 0..r {
4369                row_hessians[row * r * r + u * r + u] = 1.0 + (row as f64) + (u as f64) * 0.5;
4370            }
4371        }
4372        let mut marginal = vec![0.0_f64; n * p_m];
4373        let mut logslope = vec![0.0_f64; n * p_g];
4374        for row in 0..n {
4375            for j in 0..p_m {
4376                marginal[row * p_m + j] = 0.2 + (row as f64) * 0.3 + (j as f64) * 0.1;
4377            }
4378            for j in 0..p_g {
4379                logslope[row * p_g + j] = -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2;
4380            }
4381        }
4382        let out = cpu_oracle_bms_flex_row_diagonal(
4383            &row_hessians,
4384            &marginal,
4385            &logslope,
4386            &block,
4387            &primary,
4388            n,
4389        );
4390        // Hand check: out[0] = Σ_row H[row,0,0] · marginal[row,0]^2.
4391        let mut expect = 0.0_f64;
4392        for row in 0..n {
4393            let h00 = row_hessians[row * r * r];
4394            expect += h00 * marginal[row * p_m].powi(2);
4395        }
4396        assert!(
4397            (out[0] - expect).abs() < 1e-12,
4398            "out[0] {} vs {}",
4399            out[0],
4400            expect
4401        );
4402        // h slot = sum of H[row, 2, 2] across rows.
4403        let mut expect_h = 0.0_f64;
4404        for row in 0..n {
4405            expect_h += row_hessians[row * r * r + 2 * r + 2];
4406        }
4407        let h_slot = p_m + p_g;
4408        assert!(
4409            (out[h_slot] - expect_h).abs() < 1e-12,
4410            "h slot {} vs {}",
4411            out[h_slot],
4412            expect_h
4413        );
4414    }
4415
4416    /// GPU↔CPU parity for the HVP and diagonal kernels. Skips on non-Linux /
4417    /// no-CUDA hosts. Hand-constructs a small `DeviceResidentRowHess` by
4418    /// allocating the device slices directly, uploading the same arrays the
4419    /// CPU oracle consumes, then dispatching the device kernels.
4420    #[test]
4421    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
4422        #[cfg(not(target_os = "linux"))]
4423        {
4424            eprintln!(
4425                "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
4426                 (CPU oracle exercised by sibling tests)"
4427            );
4428        }
4429        #[cfg(target_os = "linux")]
4430        {
4431            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4432                eprintln!(
4433                    "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
4434                     parity"
4435                );
4436                return;
4437            };
4438            let n = 4_usize;
4439            let r = 4_usize;
4440            let p_m = 2_usize;
4441            let p_g = 2_usize;
4442            let p_h_dim = 1_usize;
4443            let p_w_dim = 1_usize;
4444            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4445            let block = BmsFlexBlockLayout {
4446                p_m,
4447                p_g,
4448                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4449                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4450                p_total,
4451            };
4452            let primary = BmsFlexPrimaryLayout {
4453                h: Some(2..3),
4454                w: Some(3..4),
4455                r,
4456            };
4457            let mut row_hessians = vec![0.0_f64; n * r * r];
4458            for row in 0..n {
4459                for u in 0..r {
4460                    for v in u..r {
4461                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4462                        row_hessians[row * r * r + u * r + v] = val;
4463                        row_hessians[row * r * r + v * r + u] = val;
4464                    }
4465                }
4466            }
4467            let mut marginal = vec![0.0_f64; n * p_m];
4468            for row in 0..n {
4469                for j in 0..p_m {
4470                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4471                }
4472            }
4473            let mut logslope = vec![0.0_f64; n * p_g];
4474            for row in 0..n {
4475                for j in 0..p_g {
4476                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4477                }
4478            }
4479            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4480            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4481                &row_hessians,
4482                &marginal,
4483                &logslope,
4484                &block,
4485                &primary,
4486                n,
4487                &v,
4488            );
4489            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4490                &row_hessians,
4491                &marginal,
4492                &logslope,
4493                &block,
4494                &primary,
4495                n,
4496            );
4497
4498            // Allocate a DeviceResidentRowHess by hand using the HVP backend's
4499            // stream + module so we don't need to drive the full BMS row kernel.
4500            // Past the GpuRuntime::global() Some-gate above: a probe/upload failure
4501            // here is a real device fault on a CUDA host, not a no-CUDA skip. Fail
4502            // loud (the device-PCG skip-pass class, eee12f6b2) — the old arms
4503            // returned and the test passed while exercising nothing.
4504            let backend = HvpKernelBackend::probe()
4505                .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
4506            let stream = backend.stream.clone();
4507            let d_h = stream
4508                .clone_htod(&row_hessians)
4509                .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
4510            let d_m = stream
4511                .clone_htod(&marginal)
4512                .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
4513            let d_g = stream
4514                .clone_htod(&logslope)
4515                .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
4516            let storage = DeviceResidentRowHess {
4517                hess: d_h,
4518                marginal_design: d_m,
4519                logslope_design: d_g,
4520                n,
4521                r,
4522                block: block.clone(),
4523                primary: primary.clone(),
4524
4525                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4526            };
4527            let gpu_hvp =
4528                launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
4529            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4530                .expect("diagonal kernel must launch on CUDA host");
4531            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4532            assert_eq!(gpu_diag.len(), cpu_diag.len());
4533            for i in 0..p_total {
4534                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4535                assert!(
4536                    diff <= 1e-10,
4537                    "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4538                    cpu_hvp[i],
4539                    gpu_hvp[i]
4540                );
4541                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4542                assert!(
4543                    ddiff <= 1e-10,
4544                    "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4545                    cpu_diag[i],
4546                    gpu_diag[i]
4547                );
4548            }
4549        }
4550    }
4551
4552    #[test]
4553    pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
4554        let n = 195_000_usize;
4555        let r = 20_usize;
4556        let p_total = 44_usize;
4557        let rhs_count = 4_usize;
4558        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4559            .expect("large-scale multi-RHS scratch budget");
4560        let per_rhs_full_row_cache =
4561            (n * r * r * std::mem::size_of::<f64>()) as u64 * rhs_count as u64;
4562        assert!(
4563            scratch < per_rhs_full_row_cache / 100,
4564            "multi-RHS scratch must tile by row chunks instead of materializing \
4565             a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
4566        );
4567        assert!(
4568            bms_flex_row_hvp_multi_scratch_bytes_for_shape(
4569                n,
4570                p_total,
4571                BMS_FLEX_ROW_HVP_MAX_RHS + 1
4572            )
4573            .is_err(),
4574            "multi-RHS launch must reject unbounded RHS counts"
4575        );
4576    }
4577
4578    #[test]
4579    pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
4580        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4581            eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
4582            return;
4583        };
4584        let n = 5_usize;
4585        let r = 4_usize;
4586        let p_m = 2_usize;
4587        let p_g = 2_usize;
4588        let p_h_dim = 1_usize;
4589        let p_w_dim = 1_usize;
4590        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4591        let rhs_count = 3_usize;
4592        let block = BmsFlexBlockLayout {
4593            p_m,
4594            p_g,
4595            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4596            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4597            p_total,
4598        };
4599        let primary = BmsFlexPrimaryLayout {
4600            h: Some(2..3),
4601            w: Some(3..4),
4602            r,
4603        };
4604        let mut row_hessians = vec![0.0_f64; n * r * r];
4605        for row in 0..n {
4606            for u in 0..r {
4607                for v in u..r {
4608                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4609                    row_hessians[row * r * r + u * r + v] = val;
4610                    row_hessians[row * r * r + v * r + u] = val;
4611                }
4612            }
4613        }
4614        let mut marginal = vec![0.0_f64; n * p_m];
4615        let mut logslope = vec![0.0_f64; n * p_g];
4616        for row in 0..n {
4617            for j in 0..p_m {
4618                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4619            }
4620            for j in 0..p_g {
4621                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4622            }
4623        }
4624        let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
4625        for rhs in 0..rhs_count {
4626            for j in 0..p_total {
4627                let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
4628                v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
4629            }
4630        }
4631
4632        // Past the GpuRuntime::global() Some-gate: a probe/upload failure here is a
4633        // real device fault on a CUDA host, not a no-CUDA skip — fail loud
4634        // (device-PCG skip-pass class, eee12f6b2).
4635        let backend = HvpKernelBackend::probe()
4636            .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
4637        let stream = backend.stream.clone();
4638        let d_h = stream
4639            .clone_htod(&row_hessians)
4640            .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
4641        let d_m = stream
4642            .clone_htod(&marginal)
4643            .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
4644        let d_g = stream
4645            .clone_htod(&logslope)
4646            .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
4647        let storage = DeviceResidentRowHess {
4648            hess: d_h,
4649            marginal_design: d_m,
4650            logslope_design: d_g,
4651            n,
4652            r,
4653            block: block.clone(),
4654            primary: primary.clone(),
4655
4656            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4657        };
4658        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4659            .expect("storage scratch budget");
4660        assert!(
4661            scratch < storage.bytes,
4662            "multi-RHS scratch should stay below resident cache bytes"
4663        );
4664        let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
4665            .expect("multi-RHS HVP kernel must launch on CUDA host");
4666        assert_eq!(gpu.len(), rhs_count * p_total);
4667        for rhs in 0..rhs_count {
4668            let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
4669            let cpu = cpu_oracle_bms_flex_row_hvp(
4670                &row_hessians,
4671                &marginal,
4672                &logslope,
4673                &block,
4674                &primary,
4675                n,
4676                v,
4677            );
4678            let single = launch_bms_flex_row_hvp(&storage, v)
4679                .expect("single-RHS HVP kernel must launch on CUDA host");
4680            for j in 0..p_total {
4681                let got = gpu[rhs * p_total + j];
4682                let diff = (cpu[j] - got).abs();
4683                assert!(
4684                    diff <= 1e-10,
4685                    "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
4686                    cpu[j],
4687                    got
4688                );
4689                assert_eq!(
4690                    got, single[j],
4691                    "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
4692                );
4693            }
4694        }
4695    }
4696
4697    /// Parity for the third launch mode — device-output HVP
4698    /// ([`launch_bms_flex_row_hvp_into_device`]) — which the
4699    /// `run_bms_flex_row_partial_reduce` unification routes through the same
4700    /// partial+reduce engine as the host-returning `_hvp` / `_diagonal`
4701    /// adapters. Confirms that keeping the result on-stream (no internal sync /
4702    /// DtoH) reaches bit-identical output to both the CPU oracle and the
4703    /// host-out adapter, so the engine's mode/output split is faithful.
4704    ///
4705    /// Skips cleanly on non-Linux / no-CUDA hosts using the convention shared
4706    /// with the sibling parity tests.
4707    #[test]
4708    pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
4709        #[cfg(not(target_os = "linux"))]
4710        {
4711            eprintln!(
4712                "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
4713                 CUDA parity (CPU oracle exercised by sibling tests)"
4714            );
4715        }
4716        #[cfg(target_os = "linux")]
4717        {
4718            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4719                eprintln!(
4720                    "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
4721                     skipping device parity"
4722                );
4723                return;
4724            };
4725            let n = 4_usize;
4726            let r = 4_usize;
4727            let p_m = 2_usize;
4728            let p_g = 2_usize;
4729            let p_h_dim = 1_usize;
4730            let p_w_dim = 1_usize;
4731            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4732            let block = BmsFlexBlockLayout {
4733                p_m,
4734                p_g,
4735                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4736                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4737                p_total,
4738            };
4739            let primary = BmsFlexPrimaryLayout {
4740                h: Some(2..3),
4741                w: Some(3..4),
4742                r,
4743            };
4744            let mut row_hessians = vec![0.0_f64; n * r * r];
4745            for row in 0..n {
4746                for u in 0..r {
4747                    for v in u..r {
4748                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4749                        row_hessians[row * r * r + u * r + v] = val;
4750                        row_hessians[row * r * r + v * r + u] = val;
4751                    }
4752                }
4753            }
4754            let mut marginal = vec![0.0_f64; n * p_m];
4755            for row in 0..n {
4756                for j in 0..p_m {
4757                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4758                }
4759            }
4760            let mut logslope = vec![0.0_f64; n * p_g];
4761            for row in 0..n {
4762                for j in 0..p_g {
4763                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4764                }
4765            }
4766            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4767            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4768                &row_hessians,
4769                &marginal,
4770                &logslope,
4771                &block,
4772                &primary,
4773                n,
4774                &v,
4775            );
4776
4777            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
4778            // real device faults on a CUDA host — fail loud (device-PCG class).
4779            let backend = HvpKernelBackend::probe().expect(
4780                "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
4781            );
4782            let stream = backend.stream.clone();
4783            let d_h = stream
4784                .clone_htod(&row_hessians)
4785                .expect("[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host");
4786            let d_m = stream.clone_htod(&marginal).expect(
4787                "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
4788            );
4789            let d_g = stream.clone_htod(&logslope).expect(
4790                "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
4791            );
4792            let storage = DeviceResidentRowHess {
4793                hess: d_h,
4794                marginal_design: d_m,
4795                logslope_design: d_g,
4796                n,
4797                r,
4798                block: block.clone(),
4799                primary: primary.clone(),
4800
4801                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4802            };
4803
4804            // Host-out adapter (allocates its own d_out, syncs + downloads).
4805            let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
4806                .expect("host-out HVP kernel must launch on CUDA host");
4807
4808            // Device-out adapter: caller owns d_v + d_out; the engine performs
4809            // no sync / DtoH, so we synchronize + download here.
4810            let d_v = stream
4811                .clone_htod(&v)
4812                .expect("upload direction for device-out HVP");
4813            let mut d_out = stream
4814                .alloc_zeros::<f64>(p_total)
4815                .expect("alloc device-out HVP output");
4816            launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
4817                .expect("device-out HVP kernel must launch on CUDA host");
4818            stream
4819                .synchronize()
4820                .expect("synchronize after device-out HVP");
4821            let device_out_hvp = stream
4822                .clone_dtoh(&d_out)
4823                .expect("download device-out HVP output");
4824
4825            assert_eq!(device_out_hvp.len(), cpu_hvp.len());
4826            assert_eq!(device_out_hvp.len(), host_out_hvp.len());
4827            for i in 0..p_total {
4828                let diff = (cpu_hvp[i] - device_out_hvp[i]).abs();
4829                assert!(
4830                    diff <= 1e-10,
4831                    "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
4832                    cpu_hvp[i],
4833                    device_out_hvp[i]
4834                );
4835                // Both adapters share the engine; the only difference is the
4836                // copy-back path, so they must be bit-identical.
4837                let host_diff = (host_out_hvp[i] - device_out_hvp[i]).abs();
4838                assert!(
4839                    host_diff == 0.0,
4840                    "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
4841                    host_out_hvp[i],
4842                    device_out_hvp[i]
4843                );
4844            }
4845        }
4846    }
4847
4848    /// Block 9 Phase 2 parity gate at the shape specified by the
4849    /// charter task: `n = 64`, `r = 20`, `p_total = 44`. Splits
4850    /// `p_total` as `p_m = 14`, `p_g = 12`, `p_h = 10`, `p_w = 8` so
4851    /// `r = 2 + p_h + p_w = 20` and every primary block participates
4852    /// in both the device pullback and the reduce pass. Tolerance is
4853    /// `|Δ| ≤ 1e-8` per the task description (looser than the 1e-10
4854    /// hand-fixture parity, since accumulation order across HVP CTAs
4855    /// differs from the CPU oracle's row-major sum even with the
4856    /// deterministic reduction policy).
4857    ///
4858    /// Skips cleanly on non-Linux and no-CUDA hosts using the same
4859    /// convention as the hand-fixture parity above.
4860    #[test]
4861    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
4862        #[cfg(not(target_os = "linux"))]
4863        {
4864            eprintln!(
4865                "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
4866                 skipping CUDA parity"
4867            );
4868        }
4869        #[cfg(target_os = "linux")]
4870        {
4871            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4872                eprintln!(
4873                    "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
4874                     skipping device parity"
4875                );
4876                return;
4877            };
4878            let n = 64_usize;
4879            let p_m = 14_usize;
4880            let p_g = 12_usize;
4881            let p_h_dim = 10_usize;
4882            let p_w_dim = 8_usize;
4883            let r = 2 + p_h_dim + p_w_dim;
4884            assert_eq!(r, 20);
4885            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4886            assert_eq!(p_total, 44);
4887            let block = BmsFlexBlockLayout {
4888                p_m,
4889                p_g,
4890                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4891                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4892                p_total,
4893            };
4894            let primary = BmsFlexPrimaryLayout {
4895                h: Some(2..2 + p_h_dim),
4896                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4897                r,
4898            };
4899
4900            // Deterministic symmetric per-row Hessians + designs +
4901            // direction. Same scrambling family as
4902            // `row_hessian_ops::tests::make_fixture` so any regression
4903            // surfaces consistently across the host-pinned and
4904            // device-resident parity tests.
4905            let mut row_hessians = vec![0.0_f64; n * r * r];
4906            for row in 0..n {
4907                let base = row * r * r;
4908                for u in 0..r {
4909                    for v in 0..r {
4910                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
4911                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
4912                        row_hessians[base + u * r + v] = a;
4913                    }
4914                }
4915                for u in 0..r {
4916                    for v in (u + 1)..r {
4917                        let upper = row_hessians[base + u * r + v];
4918                        let lower = row_hessians[base + v * r + u];
4919                        let sym = 0.5 * (upper + lower);
4920                        row_hessians[base + u * r + v] = sym;
4921                        row_hessians[base + v * r + u] = sym;
4922                    }
4923                    row_hessians[base + u * r + u] += r as f64;
4924                }
4925            }
4926            let mut marginal = vec![0.0_f64; n * p_m];
4927            for row in 0..n {
4928                for j in 0..p_m {
4929                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
4930                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
4931                }
4932            }
4933            let mut logslope = vec![0.0_f64; n * p_g];
4934            for row in 0..n {
4935                for j in 0..p_g {
4936                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
4937                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
4938                }
4939            }
4940            let v: Vec<f64> = (0..p_total)
4941                .map(|i| {
4942                    let seed = (i as f64) * 0.157 + 0.6;
4943                    seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
4944                })
4945                .collect();
4946
4947            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4948                &row_hessians,
4949                &marginal,
4950                &logslope,
4951                &block,
4952                &primary,
4953                n,
4954                &v,
4955            );
4956            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4957                &row_hessians,
4958                &marginal,
4959                &logslope,
4960                &block,
4961                &primary,
4962                n,
4963            );
4964
4965            let backend = match HvpKernelBackend::probe() {
4966                Ok(b) => b,
4967                Err(err) => {
4968                    eprintln!(
4969                        "[bms_flex_row hvp parity n64_r20_p44] backend probe \
4970                         failed: {err}"
4971                    );
4972                    return;
4973                }
4974            };
4975            let stream = backend.stream.clone();
4976            let d_h = match stream.clone_htod(&row_hessians) {
4977                Ok(s) => s,
4978                Err(err) => {
4979                    eprintln!(
4980                        "[bms_flex_row hvp parity n64_r20_p44] upload h \
4981                         failed: {err}"
4982                    );
4983                    return;
4984                }
4985            };
4986            let d_m = match stream.clone_htod(&marginal) {
4987                Ok(s) => s,
4988                Err(err) => {
4989                    eprintln!(
4990                        "[bms_flex_row hvp parity n64_r20_p44] upload marg \
4991                         failed: {err}"
4992                    );
4993                    return;
4994                }
4995            };
4996            let d_g = match stream.clone_htod(&logslope) {
4997                Ok(s) => s,
4998                Err(err) => {
4999                    eprintln!(
5000                        "[bms_flex_row hvp parity n64_r20_p44] upload logslope \
5001                         failed: {err}"
5002                    );
5003                    return;
5004                }
5005            };
5006            let storage = DeviceResidentRowHess {
5007                hess: d_h,
5008                marginal_design: d_m,
5009                logslope_design: d_g,
5010                n,
5011                r,
5012                block: block.clone(),
5013                primary: primary.clone(),
5014
5015                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5016            };
5017            let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
5018                .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
5019            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
5020                .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
5021            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
5022            assert_eq!(gpu_diag.len(), cpu_diag.len());
5023            for i in 0..p_total {
5024                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
5025                assert!(
5026                    diff <= 1e-8,
5027                    "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
5028                    cpu_hvp[i],
5029                    gpu_hvp[i]
5030                );
5031                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
5032                assert!(
5033                    ddiff <= 1e-8,
5034                    "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
5035                    cpu_diag[i],
5036                    gpu_diag[i]
5037                );
5038            }
5039        }
5040    }
5041
5042    /// Block 9 Phase 6 — small-fixture parity for the dense-block kernel
5043    /// against the host-side P_i pullback oracle.
5044    /// Verifies bit-equality (modulo reduction-order f.p. noise) between
5045    /// the device-resident dense build and the host accumulator over the
5046    /// same per-row Hessian + designs + P_i pullback.
5047    #[test]
5048    pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
5049        #[cfg(not(target_os = "linux"))]
5050        {
5051            eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
5052        }
5053        #[cfg(target_os = "linux")]
5054        {
5055            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5056                eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
5057                return;
5058            };
5059            // Small fixture: n=24, r=8 (2 + 3 + 3), p_total=18 (4+4+3+3).
5060            // Keeps the CPU pullback fast while still exercising every
5061            // primary slot (q, g, h, w).
5062            let n = 24_usize;
5063            let p_m = 4_usize;
5064            let p_g = 4_usize;
5065            let p_h_dim = 3_usize;
5066            let p_w_dim = 3_usize;
5067            let r = 2 + p_h_dim + p_w_dim;
5068            let p_total = p_m + p_g + p_h_dim + p_w_dim;
5069            let block = BmsFlexBlockLayout {
5070                p_m,
5071                p_g,
5072                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5073                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5074                p_total,
5075            };
5076            let primary = BmsFlexPrimaryLayout {
5077                h: Some(2..2 + p_h_dim),
5078                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5079                r,
5080            };
5081
5082            let mut row_hessians = vec![0.0_f64; n * r * r];
5083            for row in 0..n {
5084                let base = row * r * r;
5085                for u in 0..r {
5086                    for v in 0..r {
5087                        let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (v as f64) * 0.47;
5088                        let a = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
5089                        row_hessians[base + u * r + v] = a;
5090                    }
5091                }
5092                for u in 0..r {
5093                    for v in (u + 1)..r {
5094                        let upper = row_hessians[base + u * r + v];
5095                        let lower = row_hessians[base + v * r + u];
5096                        let sym = 0.5 * (upper + lower);
5097                        row_hessians[base + u * r + v] = sym;
5098                        row_hessians[base + v * r + u] = sym;
5099                    }
5100                    row_hessians[base + u * r + u] += r as f64;
5101                }
5102            }
5103            let mut marginal = vec![0.0_f64; n * p_m];
5104            for row in 0..n {
5105                for j in 0..p_m {
5106                    let seed = (row as f64) * 0.083 + (j as f64) * 0.171 + 0.31;
5107                    marginal[row * p_m + j] = seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25;
5108                }
5109            }
5110            let mut logslope = vec![0.0_f64; n * p_g];
5111            for row in 0..n {
5112                for j in 0..p_g {
5113                    let seed = (row as f64) * 0.097 + (j as f64) * 0.143 - 0.15;
5114                    logslope[row * p_g + j] = seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2;
5115                }
5116            }
5117
5118            // CPU oracle — same pullback math the device kernel mirrors.
5119            let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
5120            let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
5121            let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
5122            let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
5123            let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
5124            let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
5125            let mut h_cpu = vec![0.0_f64; p_total * p_total];
5126            for row in 0..n {
5127                let mrow = &marginal[row * p_m..(row + 1) * p_m];
5128                let grow = &logslope[row * p_g..(row + 1) * p_g];
5129                let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
5130                // Build per-row phi (r length-p_total vectors).
5131                let mut phi = vec![vec![0.0_f64; p_total]; r];
5132                for k in 0..p_m {
5133                    phi[0][k] = mrow[k];
5134                }
5135                for k in 0..p_g {
5136                    phi[1][p_m + k] = grow[k];
5137                }
5138                for k in 0..h_block_len {
5139                    phi[h_primary_start + k][h_block_start + k] = 1.0;
5140                }
5141                for k in 0..w_block_len {
5142                    phi[w_primary_start + k][w_block_start + k] = 1.0;
5143                }
5144                for u in 0..r {
5145                    for v in 0..r {
5146                        let huv = hrow[u * r + v];
5147                        if huv == 0.0 {
5148                            continue;
5149                        }
5150                        for m in 0..p_total {
5151                            let pm = phi[u][m];
5152                            if pm == 0.0 {
5153                                continue;
5154                            }
5155                            let scaled = huv * pm;
5156                            for nn in 0..p_total {
5157                                h_cpu[m * p_total + nn] += scaled * phi[v][nn];
5158                            }
5159                        }
5160                    }
5161                }
5162            }
5163
5164            // Build a transient device-resident storage and launch the
5165            // dense-block kernel.
5166            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
5167            // real device faults on a CUDA host — fail loud (device-PCG class).
5168            let backend = HvpKernelBackend::probe().expect(
5169                "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
5170            );
5171            let stream = backend.stream.clone();
5172            let d_h = stream
5173                .clone_htod(&row_hessians)
5174                .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host");
5175            let d_m = stream
5176                .clone_htod(&marginal)
5177                .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host");
5178            let d_g = stream.clone_htod(&logslope).expect(
5179                "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
5180            );
5181            let storage = DeviceResidentRowHess {
5182                hess: d_h,
5183                marginal_design: d_m,
5184                logslope_design: d_g,
5185                n,
5186                r,
5187                block: block.clone(),
5188                primary: primary.clone(),
5189
5190                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5191            };
5192            let h_gpu = launch_bms_flex_row_dense_block(&storage)
5193                .expect("dense_block kernel must launch on CUDA host");
5194            assert_eq!(h_gpu.len(), p_total * p_total);
5195
5196            // Compare entry-by-entry with a tolerance that absorbs
5197            // reduction-order f.p. noise from the CTA chunk sum.
5198            let mut max_abs = 0.0_f64;
5199            for i in 0..p_total {
5200                for j in 0..p_total {
5201                    let a = h_cpu[i * p_total + j];
5202                    let b = h_gpu[i * p_total + j];
5203                    let diff = (a - b).abs();
5204                    if diff > max_abs {
5205                        max_abs = diff;
5206                    }
5207                    assert!(
5208                        diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
5209                        "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
5210                    );
5211                }
5212            }
5213            eprintln!(
5214                "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
5215            );
5216        }
5217    }
5218
/// Block 9 hill-climb perf gate for the device-resident row-Hessian HVP.
///
/// Passes only when the median wall-time of `launch_bms_flex_row_hvp` is at
/// least 5× smaller than the median wall-time of a Rayon-chunked CPU HVP on
/// the large-scale shape (n=195_000, r=20, p_total=44).
///
/// How the two sides are measured:
///   * Fixture: deterministic sin/cos-seeded row Hessians (symmetrised, with
///     the diagonal shifted by `r`), marginal / logslope designs and a
///     direction vector.
///   * GPU: the fixture is uploaded once; 3 untimed warm-up launches are
///     followed by 15 timed launches and the median sample is kept.
///   * CPU: rows are split into `CHUNK_ROWS`-row chunks, every chunk is
///     handed to `cpu_oracle_bms_flex_row_hvp`, and the partial vectors are
///     summed with Rayon `fold` + `reduce`. One untimed warm-up call precedes
///     15 timed calls.
///     NOTE(review): per the Block 9 charter this chunking is meant to mirror
///     the production `exact_newton_joint_hessian_matvec_from_cache` path —
///     that cannot be confirmed from this test alone.
///   * Gate: `cpu_median / gpu_median >= 5`.
///
/// The gate is skipped (the test returns without asserting) on non-Linux
/// hosts, when no CUDA runtime is available, and when the backend probe or
/// any host-to-device upload fails.
#[test]
pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!("[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate");
            return;
        };

        // Large-scale shape: r = 20 primary slots, p_total = 44 coefficients.
        let n = 195_000_usize;
        let (p_m, p_g) = (14_usize, 12_usize);
        let (p_h_dim, p_w_dim) = (10_usize, 8_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;

        // Coefficient-space layout: [marginal | logslope | h | w].
        let h_coef_start = p_m + p_g;
        let w_coef_start = h_coef_start + p_h_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_coef_start..w_coef_start),
            w: Some(w_coef_start..w_coef_start + p_w_dim),
            p_total,
        };
        // Primary-space layout: slots 0 and 1 come first, then h, then w.
        let w_slot_start = 2 + p_h_dim;
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..w_slot_start),
            w: Some(w_slot_start..w_slot_start + p_w_dim),
            r,
        };

        // Deterministic fixture. Each row owns one r × r tile: fill it from a
        // trig seed, then symmetrise and shift the diagonal by `r`.
        let tile_len = r * r;
        let mut row_hessians = vec![0.0_f64; n * tile_len];
        for (row, tile) in row_hessians.chunks_exact_mut(tile_len).enumerate() {
            for (flat, entry) in tile.iter_mut().enumerate() {
                let (u, vv) = (flat / r, flat % r);
                let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                *entry = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let mean = 0.5 * (tile[u * r + vv] + tile[vv * r + u]);
                    tile[u * r + vv] = mean;
                    tile[vv * r + u] = mean;
                }
                tile[u * r + u] += r as f64;
            }
        }
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|flat| {
                let (row, j) = (flat / p_m, flat % p_m);
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|flat| {
                let (row, j) = (flat / p_g, flat % p_g);
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25
            })
            .collect();
        let mut direction = Vec::with_capacity(p_total);
        for i in 0..p_total {
            let seed = (i as f64) * 0.157 + 0.6;
            direction.push(seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35);
        }

        // ── GPU side: upload once, then time HVP launches ────────────
        // Any probe / upload failure is reported and treated as a skip.
        let Ok(backend) = HvpKernelBackend::probe().map_err(|err| {
            eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
        }) else {
            return;
        };
        let stream = backend.stream.clone();
        let Ok(d_h) = stream.clone_htod(&row_hessians).map_err(|err| {
            eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
        }) else {
            return;
        };
        let Ok(d_m) = stream.clone_htod(&marginal).map_err(|err| {
            eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
        }) else {
            return;
        };
        let Ok(d_g) = stream.clone_htod(&logslope).map_err(|err| {
            eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
        }) else {
            return;
        };
        let resident_bytes = (n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>();
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: resident_bytes as u64,
        };

        let warmup: usize = 3;
        let iters: usize = 15;
        // Untimed launches before sampling starts.
        (0..warmup).for_each(|_| {
            let primed = launch_bms_flex_row_hvp(&storage, &direction)
                .expect("warmup GPU HVP must launch");
            assert_eq!(primed.len(), p_total);
        });
        let mut gpu_us: Vec<u128> = (0..iters)
            .map(|_| {
                let started = std::time::Instant::now();
                let product =
                    launch_bms_flex_row_hvp(&storage, &direction).expect("GPU HVP must launch");
                let micros = started.elapsed().as_micros();
                assert_eq!(product.len(), p_total);
                micros
            })
            .collect();
        gpu_us.sort_unstable();
        let gpu_median = gpu_us[iters / 2];

        // ── CPU side: Rayon over CHUNK_ROWS-row chunks. Each worker runs the
        //    single-threaded oracle on its row slice and adds the result into
        //    a per-task accumulator (`fold`); accumulators are then summed
        //    pairwise (`reduce`).
        const CHUNK_ROWS: usize = 4096;
        let cpu_hvp_parallel = || -> Vec<f64> {
            (0..n.div_ceil(CHUNK_ROWS))
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total],
                    |mut running, chunk| {
                        let first = chunk * CHUNK_ROWS;
                        let last = (first + CHUNK_ROWS).min(n);
                        let contribution = cpu_oracle_bms_flex_row_hvp(
                            &row_hessians[first * r * r..last * r * r],
                            &marginal[first * p_m..last * p_m],
                            &logslope[first * p_g..last * p_g],
                            &block,
                            &primary,
                            last - first,
                            &direction,
                        );
                        running
                            .iter_mut()
                            .zip(contribution.iter())
                            .for_each(|(slot, term)| *slot += *term);
                        running
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total],
                    |mut left, right| {
                        left.iter_mut()
                            .zip(right.iter())
                            .for_each(|(slot, term)| *slot += *term);
                        left
                    },
                )
        };
        // One untimed call before sampling starts.
        let primed_cpu = cpu_hvp_parallel();
        assert_eq!(primed_cpu.len(), p_total);
        let mut cpu_us: Vec<u128> = (0..iters)
            .map(|_| {
                let started = std::time::Instant::now();
                let product = cpu_hvp_parallel();
                let micros = started.elapsed().as_micros();
                assert_eq!(product.len(), p_total);
                micros
            })
            .collect();
        cpu_us.sort_unstable();
        let cpu_median = cpu_us[iters / 2];

        // `max(1)` keeps the ratio finite if the GPU median rounds to 0 µs.
        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 5×)"
        );
        assert!(
            speedup >= 5.0,
            "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the kernel until met or \
             prove the kernel is at hardware roofline."
        );
    }
}
5444
/// Block 9 hill-climb perf gate for the device dense-block Hessian build
/// (companion to the HVP gate).
///
/// Passes only when the median wall-time of
/// `launch_bms_flex_row_dense_block` is at least 10× smaller than the median
/// wall-time of a Rayon-chunked CPU dense build on the large-scale shape
/// (n=195_000, r=20, p_total=44).
///
/// The CPU baseline accumulates the `p_total × p_total` matrix
/// `Σ_rows Φᵀ · H_row · Φ`, where `Φ` is the per-row `r × p_total` map built
/// from the marginal / logslope design rows plus identity entries for the
/// `h` and `w` slots.
///
/// Sampling: 2 untimed GPU warm-up launches then 5 timed launches; 1 untimed
/// CPU warm-up call then 5 timed calls; medians are compared.
///
/// The gate is skipped (the test returns without asserting) on non-Linux
/// hosts, when no CUDA runtime is available, when
/// `p_total > DENSE_BLOCK_MAX_P`, and when the backend probe or any
/// host-to-device upload fails.
#[test]
pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate");
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
            );
            return;
        };

        // Large-scale shape: r = 20 primary slots, p_total = 44 coefficients.
        let n = 195_000_usize;
        let (p_m, p_g) = (14_usize, 12_usize);
        let (p_h_dim, p_w_dim) = (10_usize, 8_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;

        // Coefficient-space layout: [marginal | logslope | h | w].
        let h_coef_start = p_m + p_g;
        let w_coef_start = h_coef_start + p_h_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_coef_start..w_coef_start),
            w: Some(w_coef_start..w_coef_start + p_w_dim),
            p_total,
        };
        // Primary-space layout: slots 0 and 1 come first, then h, then w.
        let w_slot_start = 2 + p_h_dim;
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..w_slot_start),
            w: Some(w_slot_start..w_slot_start + p_w_dim),
            r,
        };

        // Deterministic fixture. Each row owns one r × r tile: fill it from a
        // trig seed, then symmetrise and shift the diagonal by `r`.
        let tile_len = r * r;
        let mut row_hessians = vec![0.0_f64; n * tile_len];
        for (row, tile) in row_hessians.chunks_exact_mut(tile_len).enumerate() {
            for (flat, entry) in tile.iter_mut().enumerate() {
                let (u, vv) = (flat / r, flat % r);
                let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                *entry = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let mean = 0.5 * (tile[u * r + vv] + tile[vv * r + u]);
                    tile[u * r + vv] = mean;
                    tile[vv * r + u] = mean;
                }
                tile[u * r + u] += r as f64;
            }
        }
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|flat| {
                let (row, j) = (flat / p_m, flat % p_m);
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|flat| {
                let (row, j) = (flat / p_g, flat % p_g);
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25
            })
            .collect();

        // The dense-block launch is only attempted when p_total is within
        // DENSE_BLOCK_MAX_P; this shape's p_total = 44.
        if p_total > DENSE_BLOCK_MAX_P {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
            );
            return;
        }

        // ── GPU side: upload once, then time dense-block launches ────
        // Any probe / upload failure is reported and treated as a skip.
        let Ok(backend) = HvpKernelBackend::probe().map_err(|err| {
            eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
        }) else {
            return;
        };
        let stream = backend.stream.clone();
        let Ok(d_h) = stream.clone_htod(&row_hessians).map_err(|err| {
            eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
        }) else {
            return;
        };
        let Ok(d_m) = stream.clone_htod(&marginal).map_err(|err| {
            eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
        }) else {
            return;
        };
        let Ok(d_g) = stream.clone_htod(&logslope).map_err(|err| {
            eprintln!("[bms_flex_row dense_block hill-climb] upload logslope failed: {err}");
        }) else {
            return;
        };
        let resident_bytes = (n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>();
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: resident_bytes as u64,
        };

        // Fewer samples than the HVP gate: 2 warm-ups, 5 timed launches.
        let warmup: usize = 2;
        let iters: usize = 5;
        (0..warmup).for_each(|_| {
            let primed = launch_bms_flex_row_dense_block(&storage)
                .expect("warmup GPU dense_block must launch");
            assert_eq!(primed.len(), p_total * p_total);
        });
        let mut gpu_us: Vec<u128> = (0..iters)
            .map(|_| {
                let started = std::time::Instant::now();
                let dense =
                    launch_bms_flex_row_dense_block(&storage).expect("GPU dense_block must launch");
                let micros = started.elapsed().as_micros();
                assert_eq!(dense.len(), p_total * p_total);
                micros
            })
            .collect();
        gpu_us.sort_unstable();
        let gpu_median = gpu_us[iters / 2];

        // ── CPU side: Rayon over CHUNK_ROWS-row chunks. Every task adds its
        //    rows into a `[p_total, p_total]` accumulator (`fold`); the
        //    accumulators are then summed pairwise (`reduce`).
        const CHUNK_ROWS: usize = 2048;
        let h_coef_base = block.h.as_ref().map_or(0, |range| range.start);
        let h_coef_count = block.h.as_ref().map_or(0, |range| range.len());
        let w_coef_base = block.w.as_ref().map_or(0, |range| range.start);
        let w_coef_count = block.w.as_ref().map_or(0, |range| range.len());
        let h_slot_base = primary.h.as_ref().map_or(0, |range| range.start);
        let w_slot_base = primary.w.as_ref().map_or(0, |range| range.start);
        // This closure is the timed CPU baseline, so its loop nest directly
        // determines the measured ratio.
        let cpu_build_parallel = || -> Vec<f64> {
            (0..n.div_ceil(CHUNK_ROWS))
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total * p_total],
                    |mut dense, chunk| {
                        let first = chunk * CHUNK_ROWS;
                        let last = (first + CHUNK_ROWS).min(n);
                        // Per-chunk scratch: r rows of length p_total.
                        let mut jac: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
                        for row in first..last {
                            // Reset the scratch, then refill it for this row.
                            for lane in jac.iter_mut() {
                                lane.iter_mut().for_each(|cell| *cell = 0.0);
                            }
                            let design_m = &marginal[row * p_m..(row + 1) * p_m];
                            let design_g = &logslope[row * p_g..(row + 1) * p_g];
                            for k in 0..p_m {
                                jac[0][k] = design_m[k];
                            }
                            for k in 0..p_g {
                                jac[1][p_m + k] = design_g[k];
                            }
                            for k in 0..h_coef_count {
                                jac[h_slot_base + k][h_coef_base + k] = 1.0;
                            }
                            for k in 0..w_coef_count {
                                jac[w_slot_base + k][w_coef_base + k] = 1.0;
                            }
                            let tile = &row_hessians[row * r * r..(row + 1) * r * r];
                            for u in 0..r {
                                for w in 0..r {
                                    let weight = tile[u * r + w];
                                    if weight == 0.0 {
                                        continue;
                                    }
                                    for i in 0..p_total {
                                        let left = jac[u][i];
                                        if left == 0.0 {
                                            continue;
                                        }
                                        let scaled = weight * left;
                                        for j in 0..p_total {
                                            dense[i * p_total + j] += scaled * jac[w][j];
                                        }
                                    }
                                }
                            }
                        }
                        dense
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total * p_total],
                    |mut left, right| {
                        for (slot, term) in left.iter_mut().zip(right.iter()) {
                            *slot += *term;
                        }
                        left
                    },
                )
        };
        // One untimed call before sampling starts.
        let primed_cpu = cpu_build_parallel();
        assert_eq!(primed_cpu.len(), p_total * p_total);
        let mut cpu_us: Vec<u128> = (0..iters)
            .map(|_| {
                let started = std::time::Instant::now();
                let dense = cpu_build_parallel();
                let micros = started.elapsed().as_micros();
                assert_eq!(dense.len(), p_total * p_total);
                micros
            })
            .collect();
        cpu_us.sort_unstable();
        let cpu_median = cpu_us[iters / 2];

        // `max(1)` keeps the ratio finite if the GPU median rounds to 0 µs.
        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 10×)"
        );
        assert!(
            speedup >= 10.0,
            "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
             (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
             or prove the kernel is at hardware roofline."
        );
    }
}
5691}