Skip to main content

gam_models/bms/gpu/
row.rs

1//! Stage 2 of the BMS FLEX row kernel — per-row math that turns per-cell
2//! derivative moments (built by Stage 1 in `src/gpu/cubic_cell/mod.rs`) into a
3//! row gradient and row-primary `r × r` Hessian.
4//!
5//! Math (mirrors the CPU reference
6//! `BernoulliMarginalSlope::compute_row_analytic_flex_from_parts_into` in
7//! `src/families/bernoulli_marginal_slope.rs`):
8//!
9//! For each row `i`, with per-cell cubic predictor coefficients
10//! `C_c = (C0, C1, C2, C3)` and derivative moments `m_0..m_9`, build
11//!
12//! ```text
13//!     κ        = 1 / (2π)
14//!     T_n      = κ · Σ_{e=0..3} C_e · m_{e+n}     (n = 0..6)
15//!     D(R)     = κ · Σ_{k=0..3} R_k · m_k
16//!     Q(R, S)  = Σ_{p,q=0..3} R_p · S_q · T_{p+q}
17//!     H(R, S, U) = D(U) − Q(R, S)
18//! ```
19//!
20//! Per cell `c`, accumulate into row scratch:
21//!
22//! ```text
23//!     F_a   += D(A_c)
24//!     F_aa  += H(A_c, A_c, AA_c)
25//!     F_u   += D(R_{c,u})                         u > 0
26//!     F_au  += H(A_c, R_{c,u}, AR_{c,u})          u > 0
27//!     F_uv  += H(R_{c,u}, R_{c,v}, S_{c,uv})      0 < u ≤ v
28//! ```
29//!
30//! After the cell sum, the `q`-row is overridden:
31//!
32//! ```text
33//!     F_q  = −mu_1
34//!     F_qq = −mu_2
35//!     F_qv = 0   (v > 0)
36//!     F_aq = 0
37//! ```
38//!
39//! Implicit function theorem (single `1/F_a`):
40//!
41//! ```text
42//!     inv_Fa = 1 / F_a
43//!     a_u    = −F_u · inv_Fa                       (q-row override: mu_1 · inv_Fa)
44//!     a_uv   = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) · inv_Fa
45//! ```
46//!
47//! Observed predictor at `z_obs` (host supplies pre-evaluated chi, xi, rho, tau,
48//! r_uv per row and coordinate):
49//!
50//! ```text
51//!     bar_e_u  = chi_obs · a_u + rho_u
52//!     bar_e_uv = chi_obs · a_uv + xi_obs · a_u · a_v + tau_u · a_v
53//!                + a_u · tau_v + r_uv
54//! ```
55//!
56//! Probit Mills (stable; uses `log_ndtr_and_mills` from `numerics_device::PROBIT_NUMERICS_CU`):
57//!
58//! ```text
59//!     s = 2y − 1 ;  m = s · e_obs
60//!     [log_cdf, λ] = log_ndtr_and_mills(m)
61//!     A = −w · s · λ
62//!     B =  w · λ · (m + λ)
63//! ```
64//!
65//! Final outputs:
66//!
67//! ```text
68//!     neglog   = −w · log_cdf
69//!     g_u      = A · bar_e_u
70//!     H_{uv}   = B · bar_e_u · bar_e_v + A · bar_e_uv     (symmetric)
71//! ```
72//!
//! Implementation choice (Stage 2): **one CUDA block per row**, with
//! `blockDim.x = 32` threads. The block's `F_u`, `F_au`, `F_uv`, `bar_e_u`,
//! `bar_e_uv` live in shared memory; threads in the block parallelise the
//! per-cell sums, then a single thread of the block (`threadIdx.x == 0`) does
//! the IFT solve, the observed-point assembly, the Mills evaluation, and the
//! final gradient + Hessian write-out. With the `r ≤ MAX_R` cap (32) the
//! logical shared-memory footprint per block is `r + r + r*r + r + r*r` doubles
//! = `2r² + 3r` ≤ 2 144 doubles ≈ 17 KB. The kernel declares these arrays
//! statically at the `r = 32` worst case and adds two 32-entry reduction
//! buffers plus two scalars (2 210 doubles ≈ 17.3 KB in total), still well
//! below the V100 48 KB per-block limit. This keeps the implementation simple
//! and avoids per-thread global scratch (a per-thread `r*r` scratch arena
//! would be ~2 GB at n=195k, r=20).
83
84#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
/// Hard ceiling on `r` (= 2 + p_h + p_w). Matches the shared-memory budget
/// argument in the module docstring: with `MAX_R = 32` the per-block shared
/// footprint is at most `2·32² + 3·32 = 2 144` doubles = 17 KB.
///
/// The CUDA text in `ROW_KERNEL_BODY` sizes its `__shared__` scratch arrays
/// and the thread-0 local arrays (`a_u`, `a_uv`) with the literal `32`, so
/// this constant and the kernel source must be changed together.
pub(crate) const MAX_R: usize = 32;

/// `blockDim.x` for the row kernel. Threads of a row-block parallelise the
/// per-cell loop; thread 0 of the block finalises the IFT solve. Linux-only
/// because the kernel launcher that consumes it is Linux-only.
///
/// The kernel declares its `reduce_a` / `reduce_b` buffers with 32 entries
/// (indexed by `threadIdx.x`) and reduces them by repeatedly halving a stride
/// that starts at `blockDim.x / 2`, so this value must stay a power of two no
/// larger than 32 unless the kernel text is updated as well.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_THREADS: u32 = 32;

/// Number of cubic predictor coefficients per cell (`C0..C3`) and the matching
/// support length of `A_c`, `R_{c,u}`, `AA_c`, `AR_{c,u}`, `S_{c,uv}`.
///
/// The kernel indexes those per-cell arrays with a literal stride of `4`.
pub(crate) const COEFF4: usize = 4;

/// Highest moment index touched per cell: `T_n` uses `m_{n+e}` for `e = 0..3`
/// and `n = 0..6`, so the maximum index is `9`. `MOMENT_STRIDE = 10`.
///
/// The kernel locates a cell's moments with the literal stride `10`
/// (`cell_moments + c * 10`).
pub(crate) const MOMENT_STRIDE: usize = 10;
113
/// Source of the per-cell derivative moments fed into the row kernel.
/// Phase-4 wiring: the substrate at `src/gpu/cubic_cell/mod.rs` can produce
/// these on the GPU; this enum lets the launcher consume them directly
/// without a DtoH+HtoD round-trip.
///
/// Either variant must expose `total_cells * MOMENT_STRIDE` elements;
/// [`BmsFlexRowKernelInputs::validate`] checks this through
/// [`CellMomentsSource::len`].
pub(crate) enum CellMomentsSource<'a> {
    /// Host-resident `[total_cells, MOMENT_STRIDE = 10]` row-major buffer.
    /// The launcher will HtoD-upload this on every launch.
    Host(&'a [f64]),
    /// Device-resident moments already living on the row-kernel backend's
    /// default stream (which is the same `cuda_context_for(ordinal).default_stream()`
    /// the cubic-cell substrate uses, so no cross-context copy is needed).
    /// Length on the device must be `total_cells * MOMENT_STRIDE`. Linux-only.
    #[cfg(target_os = "linux")]
    Device(&'a CudaSlice<f64>),
}
129
130impl<'a> CellMomentsSource<'a> {
131    /// Logical element count of the moments source, used by [`BmsFlexRowKernelInputs::validate`].
132    pub(crate) fn len(&self) -> usize {
133        match self {
134            CellMomentsSource::Host(slice) => slice.len(),
135            #[cfg(target_os = "linux")]
136            CellMomentsSource::Device(d) => d.len(),
137        }
138    }
139}
140
/// Per-row input bundle for [`launch_bms_flex_row_kernel`].
///
/// Coordinate ordering convention: `u = 0` is `a` (the latent intercept and
/// the variable IFT eliminates); `u = 1` is `b` (slope); `u = 2..2+p_h` is the
/// score-warp `β_h` block; `u = 2+p_h..2+p_h+p_w` is the link-wiggle `β_w`
/// block. So `r = 2 + p_h + p_w` and `u = 1` is the `b` (slope) index used by
/// the sparse `S_{b·h}` / `S_{b·w}` payloads.
///
/// NOTE(review): the device kernel in `ROW_KERNEL_BODY` applies its `q`-row
/// override (`F_q = −mu_1`, `F_qq = −mu_2`) at index 0 and only reads
/// `cell_r` / `cell_ar` for `u = 1..r`; confirm whether "`u = 0` is `a`"
/// above should read "`u = 0` is `q`".
///
/// This macro expands one field schema into two structs so they cannot drift
/// apart: the borrowed launch view (`BmsFlexRowKernelInputs`) and its owned
/// twin (`BmsFlexRowKernelInputsOwned`), plus the `as_borrowed` conversion.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        // Borrowed view: every schema field is a slice; the moments field is a
        // `CellMomentsSource` so it can point at host or device memory.
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of observation rows.
            pub n_rows: usize,
            /// Total primary local dimension. `r = 2 + p_h + p_w`.
            pub r: usize,
            /// Number of score-warp basis coordinates.
            pub p_h: usize,
            /// Number of link-wiggle basis coordinates.
            pub p_w: usize,
            /// Probit frailty scale `S_f` (scalar shared across rows; matches
            /// `BernoulliMarginalSlope::probit_frailty_scale`).
            pub s_f: f64,
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owned twin of [`BmsFlexRowKernelInputs`] — every borrowed slice is
        /// replaced by an owned `Vec`. The buffer fields are declared from the
        /// same schema as the borrowed launch ABI and converted by
        /// [`BmsFlexRowKernelInputsOwned::as_borrowed`].
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            pub $moments_field: Vec<f64>,
            /// Phase-4 device-resident moments. When `Some(_)`, the launcher
            /// skips the host upload and consumes the buffer directly.
            /// Linux-only field.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrowed view over `self` suitable for
            /// [`launch_bms_flex_row_kernel`]. The returned struct holds
            /// references into `self` so the owned bundle must outlive the
            /// launch.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                // On Linux a populated device buffer takes precedence over the
                // host `Vec`; everywhere else only the host path exists.
                // Note that this body names `self.cell_moments` literally, so
                // the macro must be invoked with `moments_field: cell_moments`.
                #[cfg(target_os = "linux")]
                let cell_moments = match self.cell_moments_device.as_ref() {
                    Some(d) => CellMomentsSource::Device(d),
                    None => CellMomentsSource::Host(&self.cell_moments),
                };
                #[cfg(not(target_os = "linux"))]
                let cell_moments = CellMomentsSource::Host(&self.cell_moments);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: cell_moments,
                }
            }
        }
    };
}
218
// Field schema shared by `BmsFlexRowKernelInputs` and
// `BmsFlexRowKernelInputsOwned`. The lengths quoted below are the ones
// enforced by `BmsFlexRowKernelInputs::validate`, with
// `total_cells = cell_offsets[n_rows]`.
define_bms_flex_row_kernel_input_types! {
    f64_fields: [
        // Per-row scalars: length `n_rows` each.
        q,
        b,
        mu_1,
        mu_2,
        z_obs,
        y,
        w,
        e_obs,
        // Per-cell cubic predictor coefficients: length `total_cells` each.
        cell_c0,
        cell_c1,
        cell_c2,
        cell_c3,
        // Length `total_cells * COEFF4` each.
        cell_a,
        cell_aa,
        // Length `total_cells * (r - 1) * COEFF4` each.
        cell_r,
        cell_ar,
        // Length `total_cells * COEFF4`.
        cell_sbb,
        // Length `total_cells * p_h * COEFF4`.
        cell_sbh,
        // Length `total_cells * p_w * COEFF4`.
        cell_sbw,
        // Per-row scalars: length `n_rows` each.
        chi_obs,
        xi_obs,
        // Length `n_rows * r` each.
        rho_u,
        tau_u,
        // Length `n_rows * r * r`.
        r_uv,
    ],
    // Length `n_rows + 1`, non-decreasing; the kernel reads the cells of row
    // `i` from `cell_offsets[i]` up to (excluding) `cell_offsets[i + 1]`.
    u32_fields: [cell_offsets],
    // Length `total_cells * MOMENT_STRIDE`.
    moments_field: cell_moments,
}
249
/// Per-row outputs produced by [`launch_bms_flex_row_kernel`].
///
/// For a row whose accumulated `F_a` is non-finite or non-positive the device
/// kernel writes NaN into that row's `neglog`, `grad` and `hess` entries
/// instead of a result (the kernel source describes this as the signal for
/// the host to fall back to the CPU path for that row).
#[derive(Debug)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// Per-row negative log-likelihood. Length `n_rows`.
    pub neglog: Vec<f64>,
    /// Per-row gradient, row-major `[n_rows, r]`.
    pub grad: Vec<f64>,
    /// Per-row Hessian, row-major `[n_rows, r*r]`. The kernel writes the full
    /// symmetric matrix.
    pub hess: Vec<f64>,
}
261
262impl<'a> BmsFlexRowKernelInputs<'a> {
263    /// Sanity-check every shape the kernel relies on. This is the only place
264    /// length errors are surfaced — the device kernel assumes valid layout.
265    pub(crate) fn validate(&self) -> Result<(), GpuError> {
266        if self.r == 0 {
267            return Err(GpuError::DriverCallFailed {
268                reason: "bms_flex_row inputs: r must be > 0".to_string(),
269            });
270        }
271        if self.r > MAX_R {
272            return Err(GpuError::DriverCallFailed {
273                reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
274            });
275        }
276        if self.r != 2 + self.p_h + self.p_w {
277            return Err(GpuError::DriverCallFailed {
278                reason: format!(
279                    "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
280                    self.r,
281                    self.p_h,
282                    self.p_w,
283                    2 + self.p_h + self.p_w
284                ),
285            });
286        }
287        let n = self.n_rows;
288        let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
289            if have != want {
290                return Err(GpuError::DriverCallFailed {
291                    reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
292                });
293            }
294            Ok(())
295        };
296        check_len("q", self.q.len(), n)?;
297        check_len("b", self.b.len(), n)?;
298        check_len("mu_1", self.mu_1.len(), n)?;
299        check_len("mu_2", self.mu_2.len(), n)?;
300        check_len("z_obs", self.z_obs.len(), n)?;
301        check_len("y", self.y.len(), n)?;
302        check_len("w", self.w.len(), n)?;
303        check_len("e_obs", self.e_obs.len(), n)?;
304        check_len("chi_obs", self.chi_obs.len(), n)?;
305        check_len("xi_obs", self.xi_obs.len(), n)?;
306        check_len("rho_u", self.rho_u.len(), n * self.r)?;
307        check_len("tau_u", self.tau_u.len(), n * self.r)?;
308        check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
309        check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
310        let total_cells_u32 = self.cell_offsets[n];
311        let total_cells = total_cells_u32 as usize;
312        check_len("cell_c0", self.cell_c0.len(), total_cells)?;
313        check_len("cell_c1", self.cell_c1.len(), total_cells)?;
314        check_len("cell_c2", self.cell_c2.len(), total_cells)?;
315        check_len("cell_c3", self.cell_c3.len(), total_cells)?;
316        check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
317        check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
318        check_len(
319            "cell_r",
320            self.cell_r.len(),
321            total_cells * self.r.saturating_sub(1) * COEFF4,
322        )?;
323        check_len(
324            "cell_ar",
325            self.cell_ar.len(),
326            total_cells * self.r.saturating_sub(1) * COEFF4,
327        )?;
328        check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
329        check_len(
330            "cell_sbh",
331            self.cell_sbh.len(),
332            total_cells * self.p_h * COEFF4,
333        )?;
334        check_len(
335            "cell_sbw",
336            self.cell_sbw.len(),
337            total_cells * self.p_w * COEFF4,
338        )?;
339        check_len(
340            "cell_moments",
341            self.cell_moments.len(),
342            total_cells * MOMENT_STRIDE,
343        )?;
344        // Bonus: when the moments came from `CellMomentsSource::Device`, the
345        // launcher needs to know the source is from a device buffer; nothing
346        // to validate beyond length above. The Host variant length check is
347        // also already covered above.
348        // Monotone cell_offsets check.
349        for i in 0..n {
350            if self.cell_offsets[i] > self.cell_offsets[i + 1] {
351                return Err(GpuError::DriverCallFailed {
352                    reason: format!(
353                        "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
354                        i,
355                        self.cell_offsets[i],
356                        i + 1,
357                        self.cell_offsets[i + 1]
358                    ),
359                });
360            }
361        }
362        Ok(())
363    }
364}
365
/// NVRTC kernel source body. One CUDA block per row; 32 threads per block
/// parallelise the per-cell sums into shared-memory scratch; thread 0 of the
/// block finishes the IFT + observed-point + Mills + Hessian write-out.
///
/// Shared probit numerics (`erfcx_nonnegative`, `log_ndtr`,
/// `log_ndtr_and_mills`) are provided by
/// `numerics_device::PROBIT_NUMERICS_CU`, which is prepended before
/// passing to `cudarc::nvrtc::compile_ptx`.
///
/// **CPU parity reference**: the body mirrors
/// `compute_row_analytic_flex_from_parts_into` in
/// `src/families/bernoulli_marginal_slope.rs`.
///
/// Literals baked into the kernel text that must track the Rust constants:
/// `32` (shared/local array extents, i.e. `MAX_R` and the block size),
/// `4` (`COEFF4`) and `10` (`MOMENT_STRIDE`).
///
/// Output addressing: the per-row base offsets into `out_grad`
/// (`row * r`) and `out_hess` (`row * r * r`) are computed in `size_t`,
/// matching the input-side offsets. In 32-bit `int` arithmetic
/// `row * r * r` overflows once `row ≥ 2^21` at `r = 32`, which would be
/// undefined behaviour and an out-of-bounds write.
///
/// Degenerate rows (accumulated `F_a` non-finite or `≤ 0`) get NaN written
/// to all of their outputs instead of a result.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
//                      ::compute_row_analytic_flex_from_parts_into.

#define INV_TWO_PI     0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row.
// Row base offsets are formed in size_t so `row * r * r` cannot overflow int.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    out_neglog[row] = nan_value;
    size_t grad_base = (size_t)row * (size_t)r;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + (size_t)u] = nan_value;
    }
    size_t rr = (size_t)r * (size_t)r;
    size_t hess_base = (size_t)row * rr;
    for (size_t idx = 0; idx < rr; ++idx) {
        out_hess[hess_base + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int                  n_rows,
    int                  r,
    int                  p_h,
    int                  p_w,
    double               s_f,                // currently unused on device:
                                             // host has already baked S_f
                                             // into the cubic coefficients.
                                             // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a,       // [n_cells, 4]
    const double * __restrict__ cell_aa,      // [n_cells, 4]
    const double * __restrict__ cell_r,       // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar,      // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb,     // [n_cells, 4]
    const double * __restrict__ cell_sbh,     // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw,     // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments, // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho,      // [n_rows, r]
    const double * __restrict__ row_tau,      // [n_rows, r]
    const double * __restrict__ row_ruv,      // [n_rows, r*r]
    const double * __restrict__ row_e_obs,    // [n_rows] observed predictor VALUE
    double       * __restrict__ out_neglog,
    double       * __restrict__ out_grad,
    double       * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    //   F_u      [r]
    //   F_au     [r]
    //   F_uv     [r*r]
    //   bar_e_u  [r]
    //   bar_e_uv [r*r]
    //   reduce_a [blockDim.x]
    //   reduce_b [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u]  = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa  = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        //             `cell_second_derivative_from_moments` after folding the
        //             cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S)                                                                 \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1]                               \
             + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2]                                    \
             + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3]                        \
             + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4]                                    \
             + (R[2]*S[3] + R[3]*S[2])*T[5]                                                \
             + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c  = cell_a  + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa  += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        //                                   = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R   = D_OF(R_u);
            double d_AR  = D_OF(AR_u);
            double q_AR  = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s  = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in score-warp block, or symmetric.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in link-wiggle block, or symmetric.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror: u in (h or w) block, v == 1 cannot happen
                // because we iterate v >= u; skip.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared.
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared  = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a  = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    //   F_q  = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0]  = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    //   a_u = -F_u · inv_Fa     (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    //   a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // The q-row contributions (u==0 or v==0) collapse to a_uv = mu_2 · inv_Fa
    // when both are 0 and to (F_au_v) · inv_Fa-style mixed shape otherwise.
    // We compute it uniformly using the populated F_uv / F_au with the
    // q-overrides above.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                        + F_au[v] * a_u[u]
                        + F_au[u] * a_u[v]
                        + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    //   bar_e_u  = chi · a_u + rho_u.
    //   bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi  = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                       + xi  * a_u[u] * a_u[v]
                       + tau[u] * a_u[v]
                       + a_u[u] * tau[v]
                       + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y    = row_y[row];
    double w    = row_w[row];
    double s    = 2.0 * y - 1.0;
    // The "observed predictor" e_obs is the VALUE (degree-0 term) of the
    // observed jet η(a(θ), θ; z_obs) — NOT `bar_e_u[0]`, which is the u=0
    // FIRST-derivative jet (`chi·a_0 + rho_0 = dη_obs/dq`). The host packs
    // the observed value directly in `row_e_obs[row]` (see
    // `pack_bms_flex_row_kernel_inputs`, `eta_val = eval_coeff4_at(obs.coeff,
    // z_obs)`), matching the CPU family `compute_row_analytic_flex_from_parts_into`
    // which forms `signed_margin = s_y · eta_val`. #415 parity lock.
    double e_obs = row_e_obs[row];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i =  w * lambda * (m_arg + lambda);

    // Row base offsets in size_t: `row * r * r` overflows a 32-bit int for
    // large row counts (row >= 2^21 at r = 32), so never form it in int.
    size_t grad_base = (size_t)row * (size_t)r;
    size_t hess_base = grad_base * (size_t)r;

    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + (size_t)u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            out_hess[hess_base + (size_t)(u * r + v)] = val;
            if (u != v) {
                out_hess[hess_base + (size_t)(v * r + u)] = val;
            }
        }
    }
}
"#;
726
727// Force `s_f` to be considered used at the Rust level even though Stage 2 of
728// the kernel doesn't consume it on-device (the host has already baked the
729// probit frailty scale into the per-cell cubic coefficients). The dispatcher
730// wave that ports the rigid-branch fallback may want to apply `s_f` device-side
731// for log diagnostics; leaving the field on the input struct + reading it here
732// avoids a `let _` silencer the build.rs scanner would reject.
733#[inline]
734pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
735    inputs.s_f.is_finite() && inputs.s_f > 0.0
736}
737
/// Loaded row-kernel backend: the CUDA module holding the compiled
/// `ROW_KERNEL_BODY` source together with the stream handed over by the
/// backend probe. Built once per process by [`RowKernelBackend::probe`].
#[cfg(target_os = "linux")]
pub(crate) struct RowKernelBackend {
    /// Stream cloned from the probe's `parts.stream`.
    pub(crate) stream: Arc<CudaStream>,
    /// Module returned by `load_module` for the compiled PTX.
    pub(crate) module: Arc<CudaModule>,
}
743
744#[cfg(target_os = "linux")]
745impl RowKernelBackend {
746    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
747        static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
748        BACKEND
749            .get_or_init(|| {
750                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
751                    let row_kernel_source = [
752                        gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
753                        ROW_KERNEL_BODY,
754                    ]
755                    .concat();
756                    // #1551: route through the project's single arch-aware NVRTC
757                    // entry point instead of bare `cudarc::nvrtc::compile_ptx`.
758                    // `compile_ptx_arch` pins `--gpu-architecture` to the selected
759                    // device's compute capability and supplies the standard CUDA
760                    // include paths; bare `compile_ptx` uses NVRTC's default
761                    // virtual arch with no includes. The row kernel's 64-bit
762                    // `atomic_add_f64` (atomicCAS emulation) compiles best against
763                    // the real device arch, and this keeps every BMS-flex compile
764                    // site consistent with the SAE arrow/Schur kernels that do
765                    // require the sm_60 pin for native `atomicAdd(double*,double)`.
766                    let ptx = gam_gpu::device_cache::compile_ptx_arch(&row_kernel_source)
767                        .map_err(|err| GpuError::DriverCallFailed {
768                            reason: format!("bms_flex_row NVRTC compile failed: {err}"),
769                        })?;
770                    let module =
771                        parts
772                            .ctx
773                            .load_module(ptx)
774                            .map_err(|err| GpuError::DriverCallFailed {
775                                reason: format!("bms_flex_row module load failed: {err}"),
776                            })?;
777                    Ok(RowKernelBackend {
778                        stream: parts.stream.clone(),
779                        module,
780                    })
781                })
782            })
783            .as_ref()
784            .map_err(GpuError::clone)
785    }
786}
787
788/// Launch Stage-2 BMS FLEX row kernel. On non-Linux returns
789/// [`GpuError::DriverLibraryUnavailable`]; on Linux NVRTC-compiles the kernel
790/// (cached for the process lifetime), uploads the per-row + per-cell buffers,
791/// and dispatches one block per row.
792pub(crate) fn launch_bms_flex_row_kernel(
793    inputs: BmsFlexRowKernelInputs<'_>,
794) -> Result<BmsFlexRowKernelOutputs, GpuError> {
795    inputs.validate()?;
796    if !s_f_diagnostic_finite(&inputs) {
797        return Err(GpuError::DriverCallFailed {
798            reason: format!(
799                "bms_flex_row inputs: s_f must be positive and finite, got {}",
800                inputs.s_f
801            ),
802        });
803    }
804
805    #[cfg(target_os = "linux")]
806    {
807        launch_linux(inputs)
808    }
809    #[cfg(not(target_os = "linux"))]
810    {
811        Err(GpuError::DriverLibraryUnavailable {
812            reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
813        })
814    }
815}
816
817#[cfg(target_os = "linux")]
818pub(crate) fn launch_linux(
819    inputs: BmsFlexRowKernelInputs<'_>,
820) -> Result<BmsFlexRowKernelOutputs, GpuError> {
821    let backend = RowKernelBackend::probe()?;
822    let stream = &backend.stream;
823
824    let upload_f64 = |slice: &[f64], label: &str| {
825        stream
826            .clone_htod(slice)
827            .map_err(|err| GpuError::DriverCallFailed {
828                reason: format!("bms_flex_row upload {label}: {err}"),
829            })
830    };
831    let upload_u32 = |slice: &[u32], label: &str| {
832        stream
833            .clone_htod(slice)
834            .map_err(|err| GpuError::DriverCallFailed {
835                reason: format!("bms_flex_row upload {label}: {err}"),
836            })
837    };
838
839    let d_q = upload_f64(inputs.q, "q")?;
840    let d_b = upload_f64(inputs.b, "b")?;
841    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
842    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
843    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
844    let d_y = upload_f64(inputs.y, "y")?;
845    let d_w = upload_f64(inputs.w, "w")?;
846    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
847    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
848    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
849    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
850    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
851    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
852    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
853    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
854    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
855    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
856    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
857    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
858    // Phase-4: optionally consume device-resident moments (no host upload).
859    // Both branches end up holding a `&CudaSlice<f64>` named `d_moments_ref`
860    // we can pass to the launch builder uniformly.
861    let owned_host_moments: CudaSlice<f64>;
862    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
863        CellMomentsSource::Host(slice) => {
864            owned_host_moments = upload_f64(slice, "cell_moments")?;
865            &owned_host_moments
866        }
867        CellMomentsSource::Device(d) => *d,
868    };
869    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
870    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
871    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
872    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
873    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
874    let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;
875
876    let n = inputs.n_rows;
877    let r = inputs.r;
878    let mut d_neglog = stream
879        .alloc_zeros::<f64>(n)
880        .map_err(|err| GpuError::DriverCallFailed {
881            reason: format!("bms_flex_row alloc neglog: {err}"),
882        })?;
883    let mut d_grad =
884        stream
885            .alloc_zeros::<f64>(n * r)
886            .map_err(|err| GpuError::DriverCallFailed {
887                reason: format!("bms_flex_row alloc grad: {err}"),
888            })?;
889    let mut d_hess =
890        stream
891            .alloc_zeros::<f64>(n * r * r)
892            .map_err(|err| GpuError::DriverCallFailed {
893                reason: format!("bms_flex_row alloc hess: {err}"),
894            })?;
895
896    let func = backend
897        .module
898        .load_function("bms_flex_row_kernel")
899        .map_err(|err| GpuError::DriverCallFailed {
900            reason: format!("bms_flex_row load_function: {err}"),
901        })?;
902
903    let cfg = LaunchConfig {
904        grid_dim: (n as u32, 1, 1),
905        block_dim: (ROW_KERNEL_THREADS, 1, 1),
906        shared_mem_bytes: 0,
907    };
908    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
909        reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
910    })?;
911    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
912        reason: format!("bms_flex_row: r={r} exceeds i32 range"),
913    })?;
914    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
915        reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
916    })?;
917    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
918        reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
919    })?;
920    let s_f = inputs.s_f;
921
922    let mut builder = stream.launch_builder(&func);
923    builder
924        .arg(&n_i32)
925        .arg(&r_i32)
926        .arg(&p_h_i32)
927        .arg(&p_w_i32)
928        .arg(&s_f)
929        .arg(&d_q)
930        .arg(&d_b)
931        .arg(&d_mu1)
932        .arg(&d_mu2)
933        .arg(&d_zobs)
934        .arg(&d_y)
935        .arg(&d_w)
936        .arg(&d_offsets)
937        .arg(&d_c0)
938        .arg(&d_c1)
939        .arg(&d_c2)
940        .arg(&d_c3)
941        .arg(&d_a)
942        .arg(&d_aa)
943        .arg(&d_r)
944        .arg(&d_ar)
945        .arg(&d_sbb)
946        .arg(&d_sbh)
947        .arg(&d_sbw)
948        .arg(d_moments_ref)
949        .arg(&d_chi)
950        .arg(&d_xi)
951        .arg(&d_rho)
952        .arg(&d_tau)
953        .arg(&d_ruv)
954        .arg(&d_e_obs)
955        .arg(&mut d_neglog)
956        .arg(&mut d_grad)
957        .arg(&mut d_hess);
958
959    // SAFETY: every kernel parameter above is either a primitive `i32` /
960    // `f64` (passed by value), a const device pointer to a buffer whose
961    // length the host validated against the input struct, or an output
962    // buffer pre-allocated to `n_rows`, `n_rows*r`, `n_rows*r*r`
963    // doubles. The kernel's shared-memory arrays are sized to MAX_R = 32
964    // and validate() rejects r > MAX_R.
965    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
966        reason: format!("bms_flex_row launch: {err}"),
967    })?;
968    stream
969        .synchronize()
970        .map_err(|err| GpuError::DriverCallFailed {
971            reason: format!("bms_flex_row synchronize: {err}"),
972        })?;
973
974    let neglog = stream
975        .clone_dtoh(&d_neglog)
976        .map_err(|err| GpuError::DriverCallFailed {
977            reason: format!("bms_flex_row download neglog: {err}"),
978        })?;
979    let grad = stream
980        .clone_dtoh(&d_grad)
981        .map_err(|err| GpuError::DriverCallFailed {
982            reason: format!("bms_flex_row download grad: {err}"),
983        })?;
984    let hess = stream
985        .clone_dtoh(&d_hess)
986        .map_err(|err| GpuError::DriverCallFailed {
987            reason: format!("bms_flex_row download hess: {err}"),
988        })?;
989
990    Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
991}
992
993// ─────────────────────────────────────────────────────────────────────────────
994// Phase 3: device-resident row Hessian + HVP / diagonal kernels.
995//
996// Math (mirrors the CPU oracle in
997// `src/families/bernoulli_marginal_slope.rs::exact_newton_joint_hessian_*_from_cache`):
998//
999//   Block layout (joint β):
1000//     marginal = [0..p_m), logslope = [p_m..p_m+p_g),
1001//     h        = [h_start..h_end), w = [w_start..w_end), total = p_total.
1002//
1003//   Primary layout (per-row r-vector):
1004//     q = 0, logslope = 1,
1005//     h = [h_primary_start..h_primary_end),
1006//     w = [w_primary_start..w_primary_end), total = r.
1007//
1008//   row_dir[u] for u in primary layout:
1009//     row_dir[0]   = Σ_j marginal_design[row, j] · v[j]
1010//     row_dir[1]   = Σ_j logslope_design[row, j] · v[p_m + j]
1011//     row_dir[h_k] = v[h_block_start + (h_k - h_primary_start)]
1012//     row_dir[w_k] = v[w_block_start + (w_k - w_primary_start)]
1013//
1014//   action[u]    = Σ_v row_hessians[row, u*r + v] · row_dir[v]
1015//
1016//   block_partial[marginal_j] += action[0] · marginal_design[row, j]
1017//   block_partial[logslope_j] += action[1] · logslope_design[row, j]
1018//   block_partial[h_block_start + (h_k - h_primary_start)] += action[h_k]
1019//   block_partial[w_block_start + (w_k - w_primary_start)] += action[w_k]
1020//
1021// Diagonal:
1022//   diag[marginal_j] += row_hess[row, 0*r + 0] · marginal_design[row, j]²
1023//   diag[logslope_j] += row_hess[row, 1*r + 1] · logslope_design[row, j]²
1024//   diag[h_block_start + k] += row_hess[row, ii*r + ii]   (ii = h_primary_start + k)
1025//   diag[w_block_start + k] += row_hess[row, ii*r + ii]   (ii = w_primary_start + k)
1026//
1027// Determinism: each CTA owns a contiguous slice of `[chunk_start..chunk_end)`
1028// rows and writes its full per-chunk `p_total` partial into a non-overlapping
1029// region of the global partial buffer. The reduce kernel then sums those
1030// partials in fixed chunk-major order. No atomics.
1031
/// Joint-β block layout shared with the host (mirrors `BlockSlices` in
/// `bernoulli_marginal_slope.rs`).
///
/// Describes where each coefficient block sits inside the joint β vector of
/// length `p_total`: marginal first, then logslope, then the optional `h`
/// and `w` blocks.
///
/// Gating: Linux-only. The lone production constructor lives in
/// `bernoulli_marginal_slope.rs:9189` behind `#[cfg(target_os = "linux")]`
/// — the device-resident row-Hessian path is the only producer (see
/// `launch_bms_flex_row_kernel_device_resident`), and the joint-β
/// consumers `launch_bms_flex_row_hvp` / `_diagonal` / `_dense_block`
/// are also Linux-only. Any non-Linux test referencing this type must
/// guard itself with `#[cfg(target_os = "linux")]` too — the build.rs
/// ban scanner explicitly rejects `#[cfg(any(..., test))]` on items as
/// a dead-code escape hatch.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Width of the marginal block; it occupies `[0..p_m)` of joint β.
    pub p_m: usize,
    /// Width of the logslope block; it occupies `[p_m..p_m + p_g)`.
    pub p_g: usize,
    /// Joint-β index range of the `h` block (presumably `None` when the
    /// model has no `h` block — confirm against the constructor).
    pub h: Option<std::ops::Range<usize>>,
    /// Joint-β index range of the `w` block (presumably `None` when the
    /// model has no `w` block — confirm against the constructor).
    pub w: Option<std::ops::Range<usize>>,
    /// Total length of the joint β vector.
    pub p_total: usize,
}
1053
/// Primary-r layout shared with the host (mirrors `PrimarySlices`).
/// Gating rationale identical to [`BmsFlexBlockLayout`].
///
/// Describes the per-row primary vector of length `r`: slot `0` is `q`,
/// slot `1` is logslope, and the optional `h` / `w` coordinates occupy the
/// ranges below.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexPrimaryLayout {
    /// Primary index range of the `h` coordinates (presumably `None` when
    /// absent — confirm against the constructor).
    pub h: Option<std::ops::Range<usize>>,
    /// Primary index range of the `w` coordinates (presumably `None` when
    /// absent — confirm against the constructor).
    pub w: Option<std::ops::Range<usize>>,
    /// Total number of primary coordinates per row.
    pub r: usize,
}
1063
1064// ── Linux-only: device-resident row-Hessian state + kernels ─────────────────
1065
/// Number of rows each HVP / diagonal CTA processes. Each CTA writes a single
/// `[1, p_total]` partial row into the global partial buffer (no atomics);
/// the reduce kernel then sums partials in chunk-major fixed order.
///
/// [`num_hvp_chunks`] divides the row count by this value (rounding up) to
/// get the number of chunks, i.e. the number of partial rows.
#[cfg(target_os = "linux")]
pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;
1071
/// `blockDim.x` for the HVP / diagonal partial kernels.
///
/// Coupled to `HVP_KERNEL_SOURCE`: the HVP partial kernels stage per-thread
/// dot-product terms in a fixed `__shared__ double dot_reduce[128]` indexed by
/// `threadIdx.x`, and combine them with a halving loop that starts at
/// `blockDim.x / 2`. This value must therefore stay a power of two no larger
/// than 128 unless the kernel source is changed with it.
#[cfg(target_os = "linux")]
pub(crate) const HVP_THREADS: u32 = 128;
1075
/// `blockDim.x` for the partial-sum reduction kernels, which handle one output
/// element per thread over the `p_total`/`rhs_elems` partial buffer. The
/// reduce kernels in `HVP_KERNEL_SOURCE` compute their element index as
/// `blockIdx.x * blockDim.x + threadIdx.x` and return when it is past the
/// end; there is no stride loop, so the launch grid has to cover every output
/// element. A full warp multiple that keeps the reduce launch occupancy-bound
/// rather than tail-bound for the typical large-scale `p_total`.
#[cfg(target_os = "linux")]
pub(crate) const REDUCTION_THREADS: u32 = 256;
1082
/// Maximum RHS columns fused into one row-primary HVP launch. The matching
/// CUDA source uses fixed shared arrays sized as
/// `BMS_FLEX_ROW_HVP_MAX_RHS * MAX_R`; increasing this requires updating the
/// `MAX_MULTI_RHS` define in `HVP_KERNEL_SOURCE`.
///
/// The two values are not linked by the compiler: `HVP_KERNEL_SOURCE`
/// hard-codes `#define MAX_MULTI_RHS 8` and sizes its `row_dir` / `action`
/// shared arrays as `MAX_MULTI_RHS * 32`, so they must be kept equal by hand.
#[cfg(target_os = "linux")]
pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1089
/// Device-resident state produced by
/// [`launch_bms_flex_row_kernel_device_resident`] and consumed by
/// [`launch_bms_flex_row_hvp`] / [`launch_bms_flex_row_diagonal`].
///
/// Owns the row-Hessian + design slices on-device so the host can issue
/// many HVPs against the same β snapshot without round-tripping
/// 626 MB through host RAM. Drop releases the device memory back to
/// the CUDA runtime.
///
/// # Storage layout
///
/// NOTE(review): the paragraphs below read as if they were written for a
/// storage-layout selector (dense vs. packed-upper). This struct only carries
/// the dense `hess` field, whose own documentation says dense is the only
/// layout the current HVP / diag kernels support — confirm whether the
/// packed-upper description still applies here.
///
/// Per-row Hessian storage layout on the device. The build path is free to
/// emit either, and the Hv / diag kernels read whichever the storage says.
///
/// Charter (Block 9 Phase 4): packed-upper halves the DRAM footprint of the
/// `n × r²` cache (per-row `r*(r+1)/2` doubles instead of `r²`), at the cost
/// of a single per-entry index conversion in the kernel. The benchmark
/// decides whether the packed path becomes the default for large-scale
/// fits (`r = 20` → 210 vs 400 doubles per row, ~47.5% smaller). The
/// numerics are bit-equal because each `H_i` is symmetric by construction
/// (see the symmetric scratch-write loop in `bms_flex_row_kernel`'s
/// shared-memory finaliser).
#[cfg(target_os = "linux")]
pub struct DeviceResidentRowHess {
    /// Per-row dense `[n, r, r]` row-major Hessian. Element `(u, v)` of row
    /// `i` is `hess[i*r*r + u*r + v]`. This is the only on-device storage
    /// layout supported by the current HVP / diag kernels.
    pub(crate) hess: CudaSlice<f64>,
    /// Marginal design matrix on device; the HVP / diag kernels read it as
    /// `[n, p_m]` row-major.
    pub(crate) marginal_design: CudaSlice<f64>,
    /// Logslope design matrix on device; the HVP / diag kernels read it as
    /// `[n, p_g]` row-major.
    pub(crate) logslope_design: CudaSlice<f64>,
    /// Number of rows.
    pub(crate) n: usize,
    /// Number of primary coordinates per row (side length of each row
    /// Hessian).
    pub(crate) r: usize,
    /// Joint-β block layout used to scatter per-row results into β slots.
    pub(crate) block: BmsFlexBlockLayout,
    /// Primary-coordinate layout of each row's `r`-vector.
    pub(crate) primary: BmsFlexPrimaryLayout,
    /// Estimated bytes resident on device (for accounting).
    pub(crate) bytes: u64,
}
1125
1126#[cfg(target_os = "linux")]
1127impl std::fmt::Debug for DeviceResidentRowHess {
1128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1129        f.debug_struct("DeviceResidentRowHess")
1130            .field("n", &self.n)
1131            .field("r", &self.r)
1132            .field("p_total", &self.block.p_total)
1133            .field("bytes", &self.bytes)
1134            .finish()
1135    }
1136}
1137
1138/// Sized-to-fit-once CTA mapping. Rows `[c * HVP_ROWS_PER_CTA, (c+1) * HVP_ROWS_PER_CTA)`
1139/// belong to chunk `c`.
1140#[cfg(target_os = "linux")]
1141pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1142    n.div_ceil(HVP_ROWS_PER_CTA as usize)
1143}
1144
1145/// NVRTC source: HVP-partial kernel + HVP-reduce kernel + diag-partial +
1146/// diag-reduce. All kernels mirror the CPU oracle in this file.
1147#[cfg(target_os = "linux")]
1148pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1149// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1150// in this module.
1151
1152#define MAX_MULTI_RHS 8
1153
1154extern "C" __global__ void bms_flex_row_hvp_partial(
1155    int                  n_rows,
1156    int                  r,
1157    int                  p_m,
1158    int                  p_g,
1159    int                  p_total,
1160    int                  h_block_start,
1161    int                  h_block_len,
1162    int                  w_block_start,
1163    int                  w_block_len,
1164    int                  h_primary_start,
1165    int                  w_primary_start,
1166    int                  rows_per_cta,
1167    const double * __restrict__ row_hessians,    // [n, r*r]
1168    const double * __restrict__ marginal_design, // [n, p_m] row-major
1169    const double * __restrict__ logslope_design, // [n, p_g] row-major
1170    const double * __restrict__ v,               // [p_total]
1171    double       * __restrict__ partial)         // [num_chunks, p_total]
1172{
1173    int chunk = blockIdx.x;
1174    int tid   = threadIdx.x;
1175    int row_lo = chunk * rows_per_cta;
1176    int row_hi = row_lo + rows_per_cta;
1177    if (row_hi > n_rows) row_hi = n_rows;
1178
1179    // Zero this chunk's partial slice cooperatively.
1180    double *out = partial + (size_t)chunk * (size_t)p_total;
1181    for (int j = tid; j < p_total; j += blockDim.x) {
1182        out[j] = 0.0;
1183    }
1184    __syncthreads();
1185
1186    // Each thread serially processes a stride-of-blockDim set of rows so
1187    // every write to `out[..]` happens from one thread → no atomics within
1188    // the chunk. To keep writes race-free across threads of the same chunk,
1189    // we serialize the cross-row accumulation through a per-row barrier:
1190    // thread 0 of the block processes all rows in the chunk. The per-row
1191    // work is dominated by the dot/axpy over `p_m + p_g`, which is large.
1192    // For Stage 3 we ship the simple, correct path (thread 0 sequential
1193    // per row, blockDim.x threads parallel within a row's dot/axpy).
1194    __shared__ double row_dir[32];
1195    __shared__ double action[32];
1196    __shared__ double dot_reduce[128];
1197
1198    for (int row = row_lo; row < row_hi; ++row) {
1199        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1200        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1201        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1202
1203        // row_dir[0] = mrow · v[0..p_m]
1204        double local = 0.0;
1205        for (int j = tid; j < p_m; j += blockDim.x) {
1206            local += mrow[j] * v[j];
1207        }
1208        dot_reduce[tid] = local;
1209        __syncthreads();
1210        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1211            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1212            __syncthreads();
1213        }
1214        if (tid == 0) row_dir[0] = dot_reduce[0];
1215
1216        // row_dir[1] = grow · v[p_m..p_m+p_g]
1217        local = 0.0;
1218        for (int j = tid; j < p_g; j += blockDim.x) {
1219            local += grow[j] * v[p_m + j];
1220        }
1221        dot_reduce[tid] = local;
1222        __syncthreads();
1223        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1224            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1225            __syncthreads();
1226        }
1227        if (tid == 0) row_dir[1] = dot_reduce[0];
1228
1229        // h/w blocks: direct copy.
1230        if (tid == 0) {
1231            for (int k = 0; k < h_block_len; ++k) {
1232                row_dir[h_primary_start + k] = v[h_block_start + k];
1233            }
1234            for (int k = 0; k < w_block_len; ++k) {
1235                row_dir[w_primary_start + k] = v[w_block_start + k];
1236            }
1237        }
1238        __syncthreads();
1239
1240        // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
1241        if (tid < r) {
1242            double acc = 0.0;
1243            for (int vv = 0; vv < r; ++vv) {
1244                acc += Hrow[tid * r + vv] * row_dir[vv];
1245            }
1246            action[tid] = acc;
1247        }
1248        __syncthreads();
1249
1250        // Pull back into joint β slot.
1251        //   marginal: out[j] += action[0] · mrow[j]   (parallel j)
1252        double a0 = action[0];
1253        for (int j = tid; j < p_m; j += blockDim.x) {
1254            out[j] += a0 * mrow[j];
1255        }
1256        double a1 = action[1];
1257        for (int j = tid; j < p_g; j += blockDim.x) {
1258            out[p_m + j] += a1 * grow[j];
1259        }
1260        if (tid == 0) {
1261            for (int k = 0; k < h_block_len; ++k) {
1262                out[h_block_start + k] += action[h_primary_start + k];
1263            }
1264            for (int k = 0; k < w_block_len; ++k) {
1265                out[w_block_start + k] += action[w_primary_start + k];
1266            }
1267        }
1268        __syncthreads();
1269    }
1270}
1271
1272extern "C" __global__ void bms_flex_row_hvp_reduce(
1273    int                  num_chunks,
1274    int                  p_total,
1275    const double * __restrict__ partial,   // [num_chunks, p_total]
1276    double       * __restrict__ out)        // [p_total]
1277{
1278    int j = blockIdx.x * blockDim.x + threadIdx.x;
1279    if (j >= p_total) return;
1280    double acc = 0.0;
1281    for (int c = 0; c < num_chunks; ++c) {
1282        acc += partial[(size_t)c * (size_t)p_total + (size_t)j];
1283    }
1284    out[j] = acc;
1285}
1286
1287extern "C" __global__ void bms_flex_row_hvp_multi_partial(
1288    int                  n_rows,
1289    int                  r,
1290    int                  p_m,
1291    int                  p_g,
1292    int                  p_total,
1293    int                  h_block_start,
1294    int                  h_block_len,
1295    int                  w_block_start,
1296    int                  w_block_len,
1297    int                  h_primary_start,
1298    int                  w_primary_start,
1299    int                  rows_per_cta,
1300    int                  rhs_count,
1301    const double * __restrict__ row_hessians,    // [n, r*r]
1302    const double * __restrict__ marginal_design, // [n, p_m]
1303    const double * __restrict__ logslope_design, // [n, p_g]
1304    const double * __restrict__ v_rhs,           // [rhs_count, p_total]
1305    double       * __restrict__ partial)         // [rhs_count, num_chunks, p_total]
1306{
1307    int chunk = blockIdx.x;
1308    int tid   = threadIdx.x;
1309    int row_lo = chunk * rows_per_cta;
1310    int row_hi = row_lo + rows_per_cta;
1311    if (row_hi > n_rows) row_hi = n_rows;
1312
1313    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
1314    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
1315        int rhs = idx / p_total;
1316        int j = idx - rhs * p_total;
1317        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
1318    }
1319    __syncthreads();
1320
1321    __shared__ double row_dir[MAX_MULTI_RHS * 32];
1322    __shared__ double action[MAX_MULTI_RHS * 32];
1323    __shared__ double dot_reduce[128];
1324
1325    for (int row = row_lo; row < row_hi; ++row) {
1326        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1327        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1328        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1329
1330        for (int rhs = 0; rhs < rhs_count; ++rhs) {
1331            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;
1332
1333            double local = 0.0;
1334            for (int j = tid; j < p_m; j += blockDim.x) {
1335                local += mrow[j] * v[j];
1336            }
1337            dot_reduce[tid] = local;
1338            __syncthreads();
1339            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1340                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1341                __syncthreads();
1342            }
1343            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];
1344
1345            local = 0.0;
1346            for (int j = tid; j < p_g; j += blockDim.x) {
1347                local += grow[j] * v[p_m + j];
1348            }
1349            dot_reduce[tid] = local;
1350            __syncthreads();
1351            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1352                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1353                __syncthreads();
1354            }
1355            if (tid == 0) {
1356                row_dir[rhs * 32 + 1] = dot_reduce[0];
1357                for (int k = 0; k < h_block_len; ++k) {
1358                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
1359                }
1360                for (int k = 0; k < w_block_len; ++k) {
1361                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
1362                }
1363            }
1364            __syncthreads();
1365        }
1366
1367        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
1368            int rhs = idx / r;
1369            int u = idx - rhs * r;
1370            double acc = 0.0;
1371            const double *dir = row_dir + rhs * 32;
1372            for (int vv = 0; vv < r; ++vv) {
1373                acc += Hrow[u * r + vv] * dir[vv];
1374            }
1375            action[rhs * 32 + u] = acc;
1376        }
1377        __syncthreads();
1378
1379        for (int rhs = 0; rhs < rhs_count; ++rhs) {
1380            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
1381            double a0 = action[rhs * 32 + 0];
1382            for (int j = tid; j < p_m; j += blockDim.x) {
1383                out[j] += a0 * mrow[j];
1384            }
1385            double a1 = action[rhs * 32 + 1];
1386            for (int j = tid; j < p_g; j += blockDim.x) {
1387                out[p_m + j] += a1 * grow[j];
1388            }
1389            if (tid == 0) {
1390                for (int k = 0; k < h_block_len; ++k) {
1391                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
1392                }
1393                for (int k = 0; k < w_block_len; ++k) {
1394                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
1395                }
1396            }
1397            __syncthreads();
1398        }
1399    }
1400}
1401
1402extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
1403    int                  num_chunks,
1404    int                  p_total,
1405    int                  rhs_count,
1406    const double * __restrict__ partial,   // [rhs_count, num_chunks, p_total]
1407    double       * __restrict__ out)        // [rhs_count, p_total]
1408{
1409    int idx = blockIdx.x * blockDim.x + threadIdx.x;
1410    int total = rhs_count * p_total;
1411    if (idx >= total) return;
1412    int rhs = idx / p_total;
1413    int j = idx - rhs * p_total;
1414    double acc = 0.0;
1415    for (int c = 0; c < num_chunks; ++c) {
1416        acc += partial[((size_t)rhs * (size_t)num_chunks + (size_t)c) * (size_t)p_total + (size_t)j];
1417    }
1418    out[(size_t)rhs * (size_t)p_total + (size_t)j] = acc;
1419}
1420
1421extern "C" __global__ void bms_flex_row_diag_partial(
1422    int                  n_rows,
1423    int                  r,
1424    int                  p_m,
1425    int                  p_g,
1426    int                  p_total,
1427    int                  h_block_start,
1428    int                  h_block_len,
1429    int                  w_block_start,
1430    int                  w_block_len,
1431    int                  h_primary_start,
1432    int                  w_primary_start,
1433    int                  rows_per_cta,
1434    const double * __restrict__ row_hessians,
1435    const double * __restrict__ marginal_design,
1436    const double * __restrict__ logslope_design,
1437    double       * __restrict__ partial)
1438{
1439    int chunk = blockIdx.x;
1440    int tid   = threadIdx.x;
1441    int row_lo = chunk * rows_per_cta;
1442    int row_hi = row_lo + rows_per_cta;
1443    if (row_hi > n_rows) row_hi = n_rows;
1444
1445    double *out = partial + (size_t)chunk * (size_t)p_total;
1446    for (int j = tid; j < p_total; j += blockDim.x) {
1447        out[j] = 0.0;
1448    }
1449    __syncthreads();
1450
1451    for (int row = row_lo; row < row_hi; ++row) {
1452        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1453        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1454        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1455        double h00 = Hrow[0];
1456        double h11 = Hrow[1 * r + 1];
1457        for (int j = tid; j < p_m; j += blockDim.x) {
1458            double v = mrow[j];
1459            out[j] += h00 * v * v;
1460        }
1461        for (int j = tid; j < p_g; j += blockDim.x) {
1462            double v = grow[j];
1463            out[p_m + j] += h11 * v * v;
1464        }
1465        if (tid == 0) {
1466            for (int k = 0; k < h_block_len; ++k) {
1467                int ii = h_primary_start + k;
1468                out[h_block_start + k] += Hrow[ii * r + ii];
1469            }
1470            for (int k = 0; k < w_block_len; ++k) {
1471                int ii = w_primary_start + k;
1472                out[w_block_start + k] += Hrow[ii * r + ii];
1473            }
1474        }
1475        __syncthreads();
1476    }
1477}
1478
1479// ────────────────────────────────────────────────────────────────────────
1480// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1481//   row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1482// indexed as
1483//   packed[(u*(2*r - u - 1))/2 + (v - u)]   for u <= v
1484// with symmetric mirror for v < u.
1485// ────────────────────────────────────────────────────────────────────────
1486
1487// Helper: packed-upper index for (u, v) within a single row of r*(r+1)/2
1488// doubles. Caller must pre-swap so that u <= v.
1489__device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
1490    // u*(2r - u - 1)/2 + (v - u)
1491    return (u * (2 * r - u - 1)) / 2 + (v - u);
1492}
1493
// Copy the upper triangle of each dense row-major r×r row Hessian into the
// packed-upper layout (r*(r+1)/2 doubles per row).
//
// Launch shape: one CTA per row (gridDim.x = n_rows); threads of the CTA
// stride over the packed positions of that row. Every packed slot is
// written exactly once from exactly one dense entry, so the copy is
// bit-exact.
extern "C" __global__ void bms_flex_row_pack_upper(
    int                  n_rows,
    int                  r,
    const double * __restrict__ src_full,    // [n, r*r]
    double       * __restrict__ dst_packed)  // [n, r*(r+1)/2]
{
    const int row = blockIdx.x;
    if (row >= n_rows) return;

    const int packed_len = r * (r + 1) / 2;
    const double *dense  = src_full + (size_t)row * (size_t)r * (size_t)r;
    double       *packed = dst_packed + (size_t)row * (size_t)packed_len;

    for (int pos = threadIdx.x; pos < packed_len; pos += blockDim.x) {
        // Recover (u, v) from the packed position. Triangle row u occupies
        // the half-open range [row_begin, row_begin + (r - u)), where
        // row_begin is the running sum of the lengths of the rows above
        // (= u*(2r - u + 1)/2). Walk the rows until `pos` falls inside.
        int u = 0;
        int row_begin = 0;
        for (; u < r; ++u) {
            const int row_end = row_begin + (r - u);
            if (pos < row_end) break;
            row_begin = row_end;
        }
        const int v = u + (pos - row_begin);
        packed[pos] = dense[(size_t)u * (size_t)r + (size_t)v];
    }
}
1528
// Packed-storage Hessian-vector product, one CTA per contiguous row chunk.
//
// For every row in [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows))
// the CTA
//   1. builds the r-vector row_dir: entry 0 is mrow · v[0..p_m], entry 1 is
//      grow · v[p_m..p_m+p_g], and the h / w primary entries are copied
//      from v at their block offsets;
//   2. applies the symmetric per-row Hessian read from packed-upper
//      storage: action = H_row · row_dir;
//   3. scatters action back: action[0]·mrow and action[1]·grow into the
//      two design slots, action[h / w primary] into the h / w block slots.
// The sum over the chunk's rows is left in partial[chunk, 0..p_total].
//
// The static shared buffers and the tree reduction below imply the launch
// contract r <= 32, r <= blockDim.x <= 128, blockDim.x a power of two.
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed, // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    const double * __restrict__ v,
    double       * __restrict__ partial)
{
    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    const int chunk      = blockIdx.x;
    const int tid        = threadIdx.x;
    const int first_row  = chunk * rows_per_cta;
    const int past_chunk = first_row + rows_per_cta;
    const int last_row   = past_chunk > n_rows ? n_rows : past_chunk;
    const int packed_len = r * (r + 1) / 2;

    // Zero this chunk's output slot before accumulating rows into it.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = first_row; row < last_row; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)packed_len;

        // Step 1a: the two design dot products. blk 0 is the marginal
        // block (row_dir[0]), blk 1 the logslope block (row_dir[1]).
        for (int blk = 0; blk < 2; ++blk) {
            const double *design = blk == 0 ? mrow : grow;
            const double *dir    = blk == 0 ? v    : v + p_m;
            const int     width  = blk == 0 ? p_m  : p_g;

            double lane_sum = 0.0;
            for (int j = tid; j < width; j += blockDim.x) {
                lane_sum += design[j] * dir[j];
            }
            dot_reduce[tid] = lane_sum;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            // Only thread 0 consumes the reduced value, and it is also the
            // only writer of dot_reduce[0] in the next pass, so no extra
            // barrier is needed between the passes.
            if (tid == 0) row_dir[blk] = dot_reduce[0];
        }

        // Step 1b: h / w primary coordinates are identity pull-backs of v.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        __syncthreads();

        // Step 2: action[u] = Σ_w H[u, w] · row_dir[w]. Packed storage only
        // holds u <= w, so order the pair before looking it up.
        if (tid < r) {
            double sum = 0.0;
            for (int w = 0; w < r; ++w) {
                const int lo = tid < w ? tid : w;
                const int hi = tid < w ? w : tid;
                sum += Hrow[bms_flex_packed_idx(lo, hi, r)] * row_dir[w];
            }
            action[tid] = sum;
        }
        __syncthreads();

        // Step 3: scatter through the transposed pull-back.
        for (int blk = 0; blk < 2; ++blk) {
            const double *design = blk == 0 ? mrow : grow;
            const int     width  = blk == 0 ? p_m  : p_g;
            double       *dst    = blk == 0 ? out  : out + p_m;
            const double  scale  = action[blk];
            for (int j = tid; j < width; j += blockDim.x) {
                dst[j] += scale * design[j];
            }
        }
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        // Keep row_dir / action stable until every thread is done with them.
        __syncthreads();
    }
}
1639
// ────────────────────────────────────────────────────────────────────────
// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
// route. Materialises the full `[p_total, p_total]` row-major joint H
// from the per-row r×r Hessian via the P_i pullback. NOT the default
// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
// caller genuinely needs the dense matrix on the device.
//
// Per-CTA partial: each CTA owns a contiguous chunk of rows
// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
// chunk-major-fixed-order into a single `[p_total, p_total]` output.
//
// Math: for primary index u ∈ [0, r):
//   * u = 0:                    phi_u = (X_i in slot 0..p_m, 0 elsewhere)
//   * u = 1:                    phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
//   * u = h_primary_start + j:  phi_u = e_{h_block_start + j}  (j ∈ 0..h_block_len)
//   * u = w_primary_start + l:  phi_u = e_{w_block_start + l}  (l ∈ 0..w_block_len)
// This is the same primary ↔ block correspondence the HVP and diagonal
// kernels use. A primary index that falls in none of these ranges has
// phi_u = 0 and contributes nothing.
// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
//
// Shared-memory budget: at large-scale shape p_total = 44, a [44, 44] f64
// partial is 44*44*8 = 15488 B ≈ 15.1 KiB — well below the V100 48 KiB
// cap. At p_total = 80 the partial no longer fits (80*80*8 = 50 KiB, just
// over that cap), so the caller must enforce p_total ≤ DENSE_BLOCK_MAX_P.
// The launcher rejects oversize p_total cleanly.

// Resolve primary coordinate `u` to the support of its pullback column
// phi_u inside the joint [0, p_total) index space. On success `*off` and
// `*len` give the contiguous joint-index range and `*src` the per-row
// values over that range (NULL means phi_u is an indicator: every value is
// 1.0). Returns false when `u` belongs to no block, i.e. phi_u = 0.
__device__ __forceinline__ bool bms_flex_dense_phi_support(
    int u,
    int p_m,
    int p_g,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    const double *mrow,
    const double *grow,
    int *off,
    int *len,
    const double **src)
{
    if (u == 0) { *off = 0;   *len = p_m; *src = mrow; return true; }
    if (u == 1) { *off = p_m; *len = p_g; *src = grow; return true; }
    if (u >= h_primary_start && u - h_primary_start < h_block_len) {
        *off = h_block_start + (u - h_primary_start);
        *len = 1;
        *src = NULL;
        return true;
    }
    if (u >= w_primary_start && u - w_primary_start < w_block_len) {
        *off = w_block_start + (u - w_primary_start);
        *len = 1;
        *src = NULL;
        return true;
    }
    return false;
}

extern "C" __global__ void bms_flex_row_dense_block_partial(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians,    // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    double       * __restrict__ partial)         // [num_chunks, p_total, p_total]
{
    extern __shared__ double shmem[];
    int chunk = blockIdx.x;
    int tid   = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int pp = p_total * p_total;
    double *acc = shmem; // CTA-private accumulator [p_total, p_total]
    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
    __syncthreads();

    // Per-row work performed by thread 0 to avoid cross-thread RW
    // contention on `acc[]`. Per-row cost is O((p_m + p_g + r)²): the
    // (u, v) ∈ {0, 1}² pairs sweep full design-by-design outer products,
    // every other pair touches a single row, column, or entry. Tighter
    // parallel implementations are possible (warp-stripe the 4-way nested
    // u-v-m-n loop) but Phase 6 is a debug-only path and the simple
    // version is easier to audit for correctness against the host-side
    // P_i pullback oracle.
    if (tid == 0) {
        for (int row = row_lo; row < row_hi; ++row) {
            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
            for (int u = 0; u < r; ++u) {
                // Support of phi_u; skip primaries that map to no block
                // instead of writing outside their block (or `acc`).
                int m_off, m_len; const double *m_src;
                if (!bms_flex_dense_phi_support(
                        u, p_m, p_g,
                        h_block_start, h_block_len, w_block_start, w_block_len,
                        h_primary_start, w_primary_start,
                        mrow, grow, &m_off, &m_len, &m_src)) {
                    continue;
                }
                for (int v = 0; v < r; ++v) {
                    double huv = Hrow[u * r + v];
                    if (huv == 0.0) continue;
                    int n_off, n_len; const double *n_src;
                    if (!bms_flex_dense_phi_support(
                            v, p_m, p_g,
                            h_block_start, h_block_len, w_block_start, w_block_len,
                            h_primary_start, w_primary_start,
                            mrow, grow, &n_off, &n_len, &n_src)) {
                        continue;
                    }
                    // accumulate huv * phi_u[m] * phi_v[n] into acc[m, n]
                    for (int mi = 0; mi < m_len; ++mi) {
                        double pm = m_src ? m_src[mi] : 1.0;
                        if (pm == 0.0) continue;
                        double scaled = huv * pm;
                        int m_idx = m_off + mi;
                        for (int ni = 0; ni < n_len; ++ni) {
                            double pn = n_src ? n_src[ni] : 1.0;
                            int n_idx = n_off + ni;
                            acc[m_idx * p_total + n_idx] += scaled * pn;
                        }
                    }
                }
            }
        }
    }
    __syncthreads();

    // Write CTA accumulator out to global memory at its chunk slot.
    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
    for (int j = tid; j < pp; j += blockDim.x) {
        out_chunk[j] = acc[j];
    }
}
1761
// Sum the per-chunk dense partials into the final [p_total, p_total]
// matrix. One thread per output entry; each thread walks the chunks in
// ascending order, so the summation order is fixed for a given launch.
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int                  num_chunks,
    int                  p_total,
    const double * __restrict__ partial,
    double       * __restrict__ out)
{
    const int n_entries = p_total * p_total;
    const int entry     = blockIdx.x * blockDim.x + threadIdx.x;
    if (entry < n_entries) {
        // partial is laid out [num_chunks, p_total * p_total].
        const double *column = partial + (size_t)entry;
        double total = 0.0;
        for (int c = 0; c < num_chunks; ++c) {
            total += column[(size_t)c * (size_t)n_entries];
        }
        out[entry] = total;
    }
}
1777
// Diagonal of the pulled-back joint Hessian from packed-upper row storage,
// one CTA per contiguous row chunk. For each row the CTA adds
//   H[0,0] · mrow[j]²  to out[j]              (j < p_m)
//   H[1,1] · grow[j]²  to out[p_m + j]        (j < p_g)
//   H[ii,ii]           to the h / w block slot matching primary index ii
// and leaves the chunk sum in partial[chunk, 0..p_total].
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double       * __restrict__ partial)
{
    const int chunk      = blockIdx.x;
    const int tid        = threadIdx.x;
    const int first_row  = chunk * rows_per_cta;
    const int past_chunk = first_row + rows_per_cta;
    const int last_row   = past_chunk > n_rows ? n_rows : past_chunk;
    const int packed_len = r * (r + 1) / 2;

    // Zero this chunk's output slot before accumulating rows into it.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    // Packed positions of the two design-block diagonal entries; they
    // depend only on r, so look them up once.
    const int diag0 = bms_flex_packed_idx(0, 0, r);
    const int diag1 = bms_flex_packed_idx(1, 1, r);

    for (int row = first_row; row < last_row; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)packed_len;

        const double h_mm = Hrow[diag0];
        const double h_gg = Hrow[diag1];
        for (int j = tid; j < p_m; j += blockDim.x) {
            const double x = mrow[j];
            out[j] += h_mm * x * x;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            const double g = grow[j];
            out[p_m + j] += h_gg * g * g;
        }
        // The h / w coordinates pull back through the identity, so their
        // diagonal entries are copied straight across by one thread.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                const int primary = h_primary_start + k;
                out[h_block_start + k] += Hrow[bms_flex_packed_idx(primary, primary, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                const int primary = w_primary_start + k;
                out[w_block_start + k] += Hrow[bms_flex_packed_idx(primary, primary, r)];
            }
        }
        __syncthreads();
    }
}
1837"#;
1838
1839#[cfg(target_os = "linux")]
1840pub(crate) struct HvpKernelBackend {
1841    pub(crate) stream: Arc<CudaStream>,
1842    pub(crate) module: Arc<CudaModule>,
1843}
1844
1845#[cfg(target_os = "linux")]
1846impl HvpKernelBackend {
1847    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1848        static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1849        BACKEND
1850            .get_or_init(|| {
1851                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1852                    // #1551: arch-aware compile (see launch_bms_flex_row_kernel) —
1853                    // pin `--gpu-architecture` to the device capability and supply
1854                    // the standard include paths via the shared NVRTC entry point.
1855                    let ptx = gam_gpu::device_cache::compile_ptx_arch(HVP_KERNEL_SOURCE)
1856                        .map_err(|err| GpuError::DriverCallFailed {
1857                            reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1858                        })?;
1859                    let module =
1860                        parts
1861                            .ctx
1862                            .load_module(ptx)
1863                            .map_err(|err| GpuError::DriverCallFailed {
1864                                reason: format!("bms_flex_row hvp module load failed: {err}"),
1865                            })?;
1866                    Ok(HvpKernelBackend {
1867                        stream: parts.stream.clone(),
1868                        module,
1869                    })
1870                })
1871            })
1872            .as_ref()
1873            .map_err(GpuError::clone)
1874    }
1875}
1876
1877/// Build a device-resident row-Hessian cache by launching the row kernel and
1878/// keeping the resulting `n × r²` slice resident on the device. Also uploads
1879/// the dense marginal + logslope design matrices so subsequent HVPs do not
1880/// re-upload them at every direction.
1881///
1882/// `marginal_design_row_major` and `logslope_design_row_major` must be
1883/// row-major `[n, p_m]` and `[n, p_g]` contiguous slices.
1884///
1885/// #461 absorber (additive Stage-1 influence block): the orthogonalization is
1886/// realized as **A2 — the marginal design widened to `[M | Z̃_infl]`** (see
1887/// `src/families/bms/block_specs.rs::widen_marginal_dense_with_influence`), NOT
1888/// a dedicated 5th primary coordinate. The absorber `+Z̃_infl·γ` is plain
1889/// additive into the marginal index `α(x)`, so γ lives inside `β_m` of the
1890/// widened block and `p_m` already counts the `p₁` influence columns. The row
1891/// kernel reads the marginal index from `block_states[0].eta` (which carries
1892/// `Z̃_infl·γ`) and pulls back through this same widened `marginal_design`, so
1893/// the absorber rides the existing primary-coordinate `u = 0` chain with **no
1894/// kernel-source change**: η, gradient, and Hessian match the CPU kernel
1895/// bit-for-bit precisely because `marginal_design` and `β_m` are the matched
1896/// (design, coefficient) pair the CPU path uses. The validation below pins
1897/// `marginal_design.len() == n·p_m` (with `p_m` widened), so a stale narrow
1898/// design against a widened `block.p_m` is rejected cleanly (CPU fallback)
1899/// rather than silently computing the wrong η. The absorber is dropped at
1900/// predict, where the marginal design is rebuilt without the influence columns,
1901/// so the predict-time `p_m` is narrow and this path is correct there too.
1902#[cfg(target_os = "linux")]
1903pub(crate) fn launch_bms_flex_row_kernel_device_resident(
1904    inputs: BmsFlexRowKernelInputs<'_>,
1905    marginal_design_row_major: &[f64],
1906    logslope_design_row_major: &[f64],
1907    block: BmsFlexBlockLayout,
1908    primary: BmsFlexPrimaryLayout,
1909) -> Result<DeviceResidentRowHess, GpuError> {
1910    inputs.validate()?;
1911    if !s_f_diagnostic_finite(&inputs) {
1912        return Err(GpuError::DriverCallFailed {
1913            reason: format!(
1914                "bms_flex_row device-resident: s_f must be positive and finite, got {}",
1915                inputs.s_f
1916            ),
1917        });
1918    }
1919    let n = inputs.n_rows;
1920    let r = inputs.r;
1921    if marginal_design_row_major.len() != n * block.p_m {
1922        return Err(GpuError::DriverCallFailed {
1923            reason: format!(
1924                "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
1925                marginal_design_row_major.len(),
1926                n * block.p_m
1927            ),
1928        });
1929    }
1930    if logslope_design_row_major.len() != n * block.p_g {
1931        return Err(GpuError::DriverCallFailed {
1932            reason: format!(
1933                "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
1934                logslope_design_row_major.len(),
1935                n * block.p_g
1936            ),
1937        });
1938    }
1939    if primary.r != r {
1940        return Err(GpuError::DriverCallFailed {
1941            reason: format!(
1942                "bms_flex_row device-resident: primary.r={} != inputs.r={}",
1943                primary.r, r
1944            ),
1945        });
1946    }
1947
1948    // Ensure the row kernel backend is compiled & loaded (this also compiles
1949    // the HVP backend on first use so the caller surfaces failures here).
1950    let backend = RowKernelBackend::probe()?;
1951    HvpKernelBackend::probe()?;
1952    let stream = backend.stream.clone();
1953
1954    let upload_f64 = |slice: &[f64], label: &str| {
1955        stream
1956            .clone_htod(slice)
1957            .map_err(|err| GpuError::DriverCallFailed {
1958                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1959            })
1960    };
1961    let upload_u32 = |slice: &[u32], label: &str| {
1962        stream
1963            .clone_htod(slice)
1964            .map_err(|err| GpuError::DriverCallFailed {
1965                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1966            })
1967    };
1968
1969    let d_q = upload_f64(inputs.q, "q")?;
1970    let d_b = upload_f64(inputs.b, "b")?;
1971    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
1972    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
1973    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
1974    let d_y = upload_f64(inputs.y, "y")?;
1975    let d_w = upload_f64(inputs.w, "w")?;
1976    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
1977    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
1978    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
1979    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
1980    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
1981    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
1982    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
1983    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
1984    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
1985    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
1986    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
1987    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
1988    // Phase-4: optionally consume device-resident moments (no host upload).
1989    let owned_host_moments: CudaSlice<f64>;
1990    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
1991        CellMomentsSource::Host(slice) => {
1992            owned_host_moments = upload_f64(slice, "cell_moments")?;
1993            &owned_host_moments
1994        }
1995        CellMomentsSource::Device(d) => *d,
1996    };
1997    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
1998    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
1999    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
2000    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
2001    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
2002    let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;
2003
2004    let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
2005    let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;
2006
2007    let mut d_neglog = stream
2008        .alloc_zeros::<f64>(n)
2009        .map_err(|err| GpuError::DriverCallFailed {
2010            reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
2011        })?;
2012    let mut d_grad =
2013        stream
2014            .alloc_zeros::<f64>(n * r)
2015            .map_err(|err| GpuError::DriverCallFailed {
2016                reason: format!("bms_flex_row device-resident alloc grad: {err}"),
2017            })?;
2018    let mut d_hess =
2019        stream
2020            .alloc_zeros::<f64>(n * r * r)
2021            .map_err(|err| GpuError::DriverCallFailed {
2022                reason: format!("bms_flex_row device-resident alloc hess: {err}"),
2023            })?;
2024
2025    let func = backend
2026        .module
2027        .load_function("bms_flex_row_kernel")
2028        .map_err(|err| GpuError::DriverCallFailed {
2029            reason: format!("bms_flex_row device-resident load_function: {err}"),
2030        })?;
2031
2032    let cfg = LaunchConfig {
2033        grid_dim: (n as u32, 1, 1),
2034        block_dim: (ROW_KERNEL_THREADS, 1, 1),
2035        shared_mem_bytes: 0,
2036    };
2037    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
2038        reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
2039    })?;
2040    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
2041        reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
2042    })?;
2043    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
2044        reason: format!(
2045            "bms_flex_row device-resident: p_h={} exceeds i32 range",
2046            inputs.p_h
2047        ),
2048    })?;
2049    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
2050        reason: format!(
2051            "bms_flex_row device-resident: p_w={} exceeds i32 range",
2052            inputs.p_w
2053        ),
2054    })?;
2055    let s_f_val = inputs.s_f;
2056
2057    let mut builder = stream.launch_builder(&func);
2058    builder
2059        .arg(&n_i32)
2060        .arg(&r_i32)
2061        .arg(&p_h_i32)
2062        .arg(&p_w_i32)
2063        .arg(&s_f_val)
2064        .arg(&d_q)
2065        .arg(&d_b)
2066        .arg(&d_mu1)
2067        .arg(&d_mu2)
2068        .arg(&d_zobs)
2069        .arg(&d_y)
2070        .arg(&d_w)
2071        .arg(&d_offsets)
2072        .arg(&d_c0)
2073        .arg(&d_c1)
2074        .arg(&d_c2)
2075        .arg(&d_c3)
2076        .arg(&d_a)
2077        .arg(&d_aa)
2078        .arg(&d_r)
2079        .arg(&d_ar)
2080        .arg(&d_sbb)
2081        .arg(&d_sbh)
2082        .arg(&d_sbw)
2083        .arg(d_moments_ref)
2084        .arg(&d_chi)
2085        .arg(&d_xi)
2086        .arg(&d_rho)
2087        .arg(&d_tau)
2088        .arg(&d_ruv)
2089        .arg(&d_e_obs)
2090        .arg(&mut d_neglog)
2091        .arg(&mut d_grad)
2092        .arg(&mut d_hess);
2093    // SAFETY: same shape contract as `launch_linux`: every kernel parameter is
2094    // either a primitive scalar by-value, a const device pointer whose
2095    // capacity was validated by `inputs.validate()`, or one of the three
2096    // output buffers we just allocated with the expected element count.
2097    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
2098        reason: format!("bms_flex_row device-resident launch: {err}"),
2099    })?;
2100    stream
2101        .synchronize()
2102        .map_err(|err| GpuError::DriverCallFailed {
2103            reason: format!("bms_flex_row device-resident synchronize: {err}"),
2104        })?;
2105
2106    // The kernel writes neglog + grad alongside the row Hessian, but the
2107    // device-resident cache path keeps neither on the host: the fused
2108    // CPU gradient pass (the only consumer of host-side neglog/grad) is
2109    // dispatched only as a fallback when the GPU dense-block kernel does
2110    // not apply, and in that fallback the row kernel runs again locally.
2111    // Drop the device buffers so the allocation pool reclaims them
2112    // immediately rather than tying them to the cache's lifetime.
2113    drop(d_neglog);
2114    drop(d_grad);
2115    // Drop the per-cell uploads; keep d_hess + designs.
2116    drop(d_q);
2117    drop(d_b);
2118    drop(d_mu1);
2119    drop(d_mu2);
2120    drop(d_zobs);
2121    drop(d_y);
2122    drop(d_w);
2123    drop(d_offsets);
2124    drop(d_c0);
2125    drop(d_c1);
2126    drop(d_c2);
2127    drop(d_c3);
2128    drop(d_a);
2129    drop(d_aa);
2130    drop(d_r);
2131    drop(d_ar);
2132    drop(d_sbb);
2133    drop(d_sbh);
2134    drop(d_sbw);
2135    // `owned_host_moments` (if any) and the borrowed `d_moments_ref` both
2136    // go out of scope at the end of the function; the device-resident
2137    // moments owned by the caller stay alive.
2138    drop(d_chi);
2139    drop(d_xi);
2140    drop(d_rho);
2141    drop(d_tau);
2142    drop(d_ruv);
2143
2144    let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
2145        * std::mem::size_of::<f64>()) as u64;
2146    Ok(DeviceResidentRowHess {
2147        hess: d_hess,
2148        marginal_design: d_marginal,
2149        logslope_design: d_logslope,
2150        n,
2151        r,
2152        block,
2153        primary,
2154        bytes,
2155    })
2156}
2157
/// Selects the partial kernel the joint-β engine launches.
///
/// The variant fixes three things at once — which kernel runs, whether that
/// kernel takes a direction vector `d_v`, and where the reduced
/// `[1, p_total]` image ends up — so the public entry points can remain thin
/// wrappers around a single launch helper.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Per-row `H · v` via `bms_flex_row_hvp_partial`; the result stays
    /// on-stream.
    HvpDeviceOut,
    /// Per-row `diag(H)` via `bms_flex_row_diag_partial`; the result is
    /// downloaded to the host.
    DiagonalHostOut,
}

#[cfg(target_os = "linux")]
impl BmsFlexRowLaunchMode {
    /// Symbol of the partial kernel to load from the HVP module for this mode.
    pub(crate) fn partial_kernel_name(self) -> &'static str {
        match self {
            Self::HvpDeviceOut => "bms_flex_row_hvp_partial",
            Self::DiagonalHostOut => "bms_flex_row_diag_partial",
        }
    }
}
2181
2182/// All scalar launch arguments for the joint-β partial kernel, derived once
2183/// from a [`DeviceResidentRowHess`]. The HVP and diagonal partial kernels take
2184/// the identical leading block-layout argument list (only the trailing
2185/// `d_v` / output pointers differ), so this captures the long, easy-to-
2186/// desynchronize prefix in a single place.
2187#[cfg(target_os = "linux")]
2188pub(crate) struct PreparedBmsFlexRowLaunchArgs {
2189    pub(crate) n_i32: i32,
2190    pub(crate) r_i32: i32,
2191    pub(crate) p_m_i32: i32,
2192    pub(crate) p_g_i32: i32,
2193    pub(crate) p_total_i32: i32,
2194    pub(crate) h_block_start: i32,
2195    pub(crate) h_block_len: i32,
2196    pub(crate) w_block_start: i32,
2197    pub(crate) w_block_len: i32,
2198    pub(crate) h_primary_start: i32,
2199    pub(crate) w_primary_start: i32,
2200    pub(crate) rows_per_cta: i32,
2201    pub(crate) num_chunks: usize,
2202}
2203
2204#[cfg(target_os = "linux")]
2205impl PreparedBmsFlexRowLaunchArgs {
2206    pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2207        let p_total = storage.block.p_total;
2208        let num_chunks = num_hvp_chunks(storage.n);
2209        PreparedBmsFlexRowLaunchArgs {
2210            n_i32: storage.n as i32,
2211            r_i32: storage.r as i32,
2212            p_m_i32: storage.block.p_m as i32,
2213            p_g_i32: storage.block.p_g as i32,
2214            p_total_i32: p_total as i32,
2215            h_block_start: storage
2216                .block
2217                .h
2218                .as_ref()
2219                .map(|r| r.start as i32)
2220                .unwrap_or(0),
2221            h_block_len: storage
2222                .block
2223                .h
2224                .as_ref()
2225                .map(|r| r.len() as i32)
2226                .unwrap_or(0),
2227            w_block_start: storage
2228                .block
2229                .w
2230                .as_ref()
2231                .map(|r| r.start as i32)
2232                .unwrap_or(0),
2233            w_block_len: storage
2234                .block
2235                .w
2236                .as_ref()
2237                .map(|r| r.len() as i32)
2238                .unwrap_or(0),
2239            h_primary_start: storage
2240                .primary
2241                .h
2242                .as_ref()
2243                .map(|r| r.start as i32)
2244                .unwrap_or(0),
2245            w_primary_start: storage
2246                .primary
2247                .w
2248                .as_ref()
2249                .map(|r| r.start as i32)
2250                .unwrap_or(0),
2251            rows_per_cta: HVP_ROWS_PER_CTA as i32,
2252            num_chunks,
2253        }
2254    }
2255}
2256
2257/// Shared partial+reduce engine behind every joint-β launcher.
2258///
2259/// Allocates the `[num_chunks, p_total]` partial buffer, loads the mode's
2260/// partial kernel plus the common `bms_flex_row_hvp_reduce`, builds both
2261/// launch configs from a single [`PreparedBmsFlexRowLaunchArgs`], launches the
2262/// partial kernel (binding `d_v` only for the HVP modes), and launches the
2263/// reduction into caller-supplied `d_out`.
2264///
2265/// **No** `synchronize()` or DtoH is performed here — the surrounding helper
2266/// decides whether to keep the result on-stream (device-resident PCG hot path)
2267/// or sync + download it to the host. `ctx` is a short error-context tag woven
2268/// into every `DriverCallFailed` reason so failures stay attributable to the
2269/// originating entry point.
2270#[cfg(target_os = "linux")]
2271pub(crate) fn run_bms_flex_row_partial_reduce(
2272    storage: &DeviceResidentRowHess,
2273    mode: BmsFlexRowLaunchMode,
2274    d_v: Option<&CudaSlice<f64>>,
2275    d_out: &mut CudaSlice<f64>,
2276    ctx: &str,
2277) -> Result<(), GpuError> {
2278    let backend = HvpKernelBackend::probe()?;
2279    let stream = backend.stream.clone();
2280    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2281    let p_total = storage.block.p_total;
2282
2283    let mut d_partial = stream
2284        .alloc_zeros::<f64>(args.num_chunks * p_total)
2285        .map_err(|err| GpuError::DriverCallFailed {
2286            reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2287        })?;
2288
2289    let partial_kernel_name = mode.partial_kernel_name();
2290    let part_func = backend
2291        .module
2292        .load_function(partial_kernel_name)
2293        .map_err(|err| GpuError::DriverCallFailed {
2294            reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2295        })?;
2296    let red_func = backend
2297        .module
2298        .load_function("bms_flex_row_hvp_reduce")
2299        .map_err(|err| GpuError::DriverCallFailed {
2300            reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2301        })?;
2302
2303    let cfg_part = LaunchConfig {
2304        grid_dim: (args.num_chunks as u32, 1, 1),
2305        block_dim: (HVP_THREADS, 1, 1),
2306        shared_mem_bytes: 0,
2307    };
2308    let mut builder = stream.launch_builder(&part_func);
2309    builder
2310        .arg(&args.n_i32)
2311        .arg(&args.r_i32)
2312        .arg(&args.p_m_i32)
2313        .arg(&args.p_g_i32)
2314        .arg(&args.p_total_i32)
2315        .arg(&args.h_block_start)
2316        .arg(&args.h_block_len)
2317        .arg(&args.w_block_start)
2318        .arg(&args.w_block_len)
2319        .arg(&args.h_primary_start)
2320        .arg(&args.w_primary_start)
2321        .arg(&args.rows_per_cta)
2322        .arg(&storage.hess)
2323        .arg(&storage.marginal_design)
2324        .arg(&storage.logslope_design);
2325    if let Some(d_v) = d_v {
2326        builder.arg(d_v);
2327    }
2328    builder.arg(&mut d_partial);
2329    // SAFETY: every device pointer above either comes from `storage` (whose
2330    // capacities were established by
2331    // `launch_bms_flex_row_kernel_device_resident`) or was just allocated here
2332    // (`d_partial` = num_chunks * p_total). `d_v`, when bound, is length-checked
2333    // by the calling adapter against `p_total`. The diagonal partial kernel
2334    // takes no direction argument, matching `d_v == None`. Scalar args are i32
2335    // by-value.
2336    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2337        reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2338    })?;
2339
2340    let red_threads: u32 = REDUCTION_THREADS;
2341    let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2342    let cfg_red = LaunchConfig {
2343        grid_dim: (red_blocks, 1, 1),
2344        block_dim: (red_threads, 1, 1),
2345        shared_mem_bytes: 0,
2346    };
2347    let num_chunks_i32 = args.num_chunks as i32;
2348    let mut builder = stream.launch_builder(&red_func);
2349    builder
2350        .arg(&num_chunks_i32)
2351        .arg(&args.p_total_i32)
2352        .arg(&d_partial)
2353        .arg(d_out);
2354    // SAFETY: `d_partial` was just populated by the partial kernel above;
2355    // `d_out` is `p_total` doubles (length-checked / allocated by the calling
2356    // adapter); both scalar args fit i32.
2357    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2358        reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2359    })?;
2360    // `d_partial` drops at end of fn; cudarc keeps the alloc alive until the
2361    // stream is done with it, so the reduce kernel completes safely.
2362    drop(d_partial);
2363    Ok(())
2364}
2365
2366/// Host-returning joint-β launcher shared by [`launch_bms_flex_row_hvp`]
2367/// ([`BmsFlexRowLaunchMode::HvpHostOut`]) and [`launch_bms_flex_row_diagonal`]
2368/// ([`BmsFlexRowLaunchMode::DiagonalHostOut`]).
2369///
2370/// Probes the backend, allocates the `p_total`-double output on its stream,
2371/// optionally uploads a host direction `v` (HVP modes; `None` for the diagonal
2372/// mode, which takes no direction), runs the shared partial+reduce engine, then
2373/// synchronizes and downloads the reduced image to a host `Vec<f64>`. `ctx`
2374/// tags every `DriverCallFailed` reason with the originating entry point.
2375#[cfg(target_os = "linux")]
2376pub(crate) fn launch_bms_flex_row_host(
2377    storage: &DeviceResidentRowHess,
2378    mode: BmsFlexRowLaunchMode,
2379    v: Option<&[f64]>,
2380    ctx: &str,
2381) -> Result<Vec<f64>, GpuError> {
2382    let p_total = storage.block.p_total;
2383    if let Some(v) = v {
2384        if v.len() != p_total {
2385            return Err(GpuError::DriverCallFailed {
2386                reason: format!(
2387                    "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2388                    v.len()
2389                ),
2390            });
2391        }
2392    }
2393
2394    let backend = HvpKernelBackend::probe()?;
2395    let stream = backend.stream.clone();
2396
2397    let d_v = match v {
2398        Some(v) => Some(
2399            stream
2400                .clone_htod(v)
2401                .map_err(|err| GpuError::DriverCallFailed {
2402                    reason: format!("bms_flex_row {ctx} upload v: {err}"),
2403                })?,
2404        ),
2405        None => None,
2406    };
2407    let mut d_out =
2408        stream
2409            .alloc_zeros::<f64>(p_total)
2410            .map_err(|err| GpuError::DriverCallFailed {
2411                reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2412            })?;
2413
2414    run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2415
2416    stream
2417        .synchronize()
2418        .map_err(|err| GpuError::DriverCallFailed {
2419            reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2420        })?;
2421    stream
2422        .clone_dtoh(&d_out)
2423        .map_err(|err| GpuError::DriverCallFailed {
2424            reason: format!("bms_flex_row {ctx} download out: {err}"),
2425        })
2426}
2427
2428#[cfg(target_os = "linux")]
2429pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2430    storage: &DeviceResidentRowHess,
2431    rhs_count: usize,
2432    v_rhs_len: usize,
2433    out_len: Option<usize>,
2434    ctx: &str,
2435) -> Result<usize, GpuError> {
2436    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2437        return Err(GpuError::DriverCallFailed {
2438            reason: format!(
2439                "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2440            ),
2441        });
2442    }
2443    let p_total = storage.block.p_total;
2444    let rhs_elems = rhs_count
2445        .checked_mul(p_total)
2446        .ok_or_else(|| GpuError::DriverCallFailed {
2447            reason: format!(
2448                "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2449            ),
2450        })?;
2451    if v_rhs_len != rhs_elems {
2452        return Err(GpuError::DriverCallFailed {
2453            reason: format!(
2454                "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2455            ),
2456        });
2457    }
2458    if let Some(out_len) = out_len
2459        && out_len != rhs_elems
2460    {
2461        return Err(GpuError::DriverCallFailed {
2462            reason: format!(
2463                "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2464            ),
2465        });
2466    }
2467    Ok(rhs_elems)
2468}
2469
2470/// Transient device bytes for a multi-RHS HVP launch, excluding persistent
2471/// row-Hessian/design storage. Scratch scales with
2472/// `rhs_count * num_chunks * p_total`, not `rhs_count * n * r * r`.
2473#[cfg(target_os = "linux")]
2474pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2475    n: usize,
2476    p_total: usize,
2477    rhs_count: usize,
2478) -> Result<u64, GpuError> {
2479    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2480        return Err(GpuError::DriverCallFailed {
2481            reason: format!(
2482                "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2483            ),
2484        });
2485    }
2486    let num_chunks = num_hvp_chunks(n);
2487    let partial = rhs_count
2488        .checked_mul(num_chunks)
2489        .and_then(|v| v.checked_mul(p_total))
2490        .ok_or_else(|| GpuError::DriverCallFailed {
2491            reason: format!(
2492                "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2493            ),
2494        })?;
2495    let rhs_vectors = rhs_count
2496        .checked_mul(p_total)
2497        .and_then(|v| v.checked_mul(2))
2498        .ok_or_else(|| GpuError::DriverCallFailed {
2499            reason: format!(
2500                "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2501            ),
2502        })?;
2503    let elems = partial
2504        .checked_add(rhs_vectors)
2505        .ok_or_else(|| GpuError::DriverCallFailed {
2506            reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2507        })?;
2508    Ok((elems * std::mem::size_of::<f64>()) as u64)
2509}
2510
2511#[cfg(target_os = "linux")]
2512pub(crate) fn run_bms_flex_row_multi_partial_reduce(
2513    storage: &DeviceResidentRowHess,
2514    rhs_count: usize,
2515    d_v_rhs: &CudaSlice<f64>,
2516    d_out: &mut CudaSlice<f64>,
2517    ctx: &str,
2518) -> Result<(), GpuError> {
2519    let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
2520        storage,
2521        rhs_count,
2522        d_v_rhs.len(),
2523        Some(d_out.len()),
2524        ctx,
2525    )?;
2526    let backend = HvpKernelBackend::probe()?;
2527    let stream = backend.stream.clone();
2528    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2529    let p_total = storage.block.p_total;
2530    let partial_len = rhs_count
2531        .checked_mul(args.num_chunks)
2532        .and_then(|v| v.checked_mul(p_total))
2533        .ok_or_else(|| GpuError::DriverCallFailed {
2534            reason: format!(
2535                "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
2536                args.num_chunks
2537            ),
2538        })?;
2539
2540    let mut d_partial =
2541        stream
2542            .alloc_zeros::<f64>(partial_len)
2543            .map_err(|err| GpuError::DriverCallFailed {
2544                reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
2545            })?;
2546    let part_func = backend
2547        .module
2548        .load_function("bms_flex_row_hvp_multi_partial")
2549        .map_err(|err| GpuError::DriverCallFailed {
2550            reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
2551        })?;
2552    let red_func = backend
2553        .module
2554        .load_function("bms_flex_row_hvp_multi_reduce")
2555        .map_err(|err| GpuError::DriverCallFailed {
2556            reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
2557        })?;
2558
2559    let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
2560        reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
2561    })?;
2562    let cfg_part = LaunchConfig {
2563        grid_dim: (args.num_chunks as u32, 1, 1),
2564        block_dim: (HVP_THREADS, 1, 1),
2565        shared_mem_bytes: 0,
2566    };
2567    let mut builder = stream.launch_builder(&part_func);
2568    builder
2569        .arg(&args.n_i32)
2570        .arg(&args.r_i32)
2571        .arg(&args.p_m_i32)
2572        .arg(&args.p_g_i32)
2573        .arg(&args.p_total_i32)
2574        .arg(&args.h_block_start)
2575        .arg(&args.h_block_len)
2576        .arg(&args.w_block_start)
2577        .arg(&args.w_block_len)
2578        .arg(&args.h_primary_start)
2579        .arg(&args.w_primary_start)
2580        .arg(&args.rows_per_cta)
2581        .arg(&rhs_count_i32)
2582        .arg(&storage.hess)
2583        .arg(&storage.marginal_design)
2584        .arg(&storage.logslope_design)
2585        .arg(d_v_rhs)
2586        .arg(&mut d_partial);
2587    // SAFETY: storage buffers were validated at construction; `d_v_rhs` and
2588    // `d_out` have rhs_count*p_total elements, `d_partial` has
2589    // rhs_count*num_chunks*p_total, and rhs_count is bounded by fixed shared
2590    // array sizes in the CUDA source.
2591    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2592        reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
2593    })?;
2594
2595    let red_threads: u32 = REDUCTION_THREADS;
2596    let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
2597    let cfg_red = LaunchConfig {
2598        grid_dim: (red_blocks, 1, 1),
2599        block_dim: (red_threads, 1, 1),
2600        shared_mem_bytes: 0,
2601    };
2602    let num_chunks_i32 = args.num_chunks as i32;
2603    let mut builder = stream.launch_builder(&red_func);
2604    builder
2605        .arg(&num_chunks_i32)
2606        .arg(&args.p_total_i32)
2607        .arg(&rhs_count_i32)
2608        .arg(&d_partial)
2609        .arg(d_out);
2610    // SAFETY: the reduce kernel reads the just-populated partial buffer and
2611    // writes exactly rhs_count*p_total output entries.
2612    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2613        reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
2614    })?;
2615    drop(d_partial);
2616    Ok(())
2617}
2618
2619/// Device-resident multi-RHS HVP. `v_rhs` is row-major
2620/// `[rhs_count, p_total]`; the returned vector has the same layout.
2621#[cfg(target_os = "linux")]
2622pub(crate) fn launch_bms_flex_row_hvp_multi(
2623    storage: &DeviceResidentRowHess,
2624    v_rhs: &[f64],
2625    rhs_count: usize,
2626) -> Result<Vec<f64>, GpuError> {
2627    let rhs_elems =
2628        validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2629    let backend = HvpKernelBackend::probe()?;
2630    let stream = backend.stream.clone();
2631    let d_v_rhs = stream
2632        .clone_htod(v_rhs)
2633        .map_err(|err| GpuError::DriverCallFailed {
2634            reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2635        })?;
2636    let mut d_out =
2637        stream
2638            .alloc_zeros::<f64>(rhs_elems)
2639            .map_err(|err| GpuError::DriverCallFailed {
2640                reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2641            })?;
2642    run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2643    stream
2644        .synchronize()
2645        .map_err(|err| GpuError::DriverCallFailed {
2646            reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2647        })?;
2648    stream
2649        .clone_dtoh(&d_out)
2650        .map_err(|err| GpuError::DriverCallFailed {
2651            reason: format!("bms_flex_row hvp_multi download out: {err}"),
2652        })
2653}
2654
2655/// Device-output HVP. Runs `bms_flex_row_hvp_partial(_packed)` +
2656/// `bms_flex_row_hvp_reduce` on the storage's stream against caller-supplied
2657/// device-resident `d_v` (length `p_total` doubles), writing the result into
2658/// caller-supplied `d_out` (also `p_total` doubles). **No** `synchronize()`
2659/// or DtoH is performed — the caller is responsible for stream ordering
2660/// against any consumer that reads `d_out`.
2661///
2662/// This is the device-resident PCG hot path (Block 9 Phase 5): keeping the
2663/// HVP output on the stream lets the outer PCG loop chain axpy / dot /
2664/// preconditioner kernels back-to-back without a per-iter device sync.
2665#[cfg(target_os = "linux")]
2666pub(crate) fn launch_bms_flex_row_hvp_into_device(
2667    storage: &DeviceResidentRowHess,
2668    d_v: &CudaSlice<f64>,
2669    d_out: &mut CudaSlice<f64>,
2670) -> Result<(), GpuError> {
2671    let p_total = storage.block.p_total;
2672    if d_v.len() != p_total {
2673        return Err(GpuError::DriverCallFailed {
2674            reason: format!(
2675                "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2676                d_v.len(),
2677                p_total
2678            ),
2679        });
2680    }
2681    if d_out.len() != p_total {
2682        return Err(GpuError::DriverCallFailed {
2683            reason: format!(
2684                "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2685                d_out.len(),
2686                p_total
2687            ),
2688        });
2689    }
2690    // On-stream output: the shared engine launches partial+reduce into the
2691    // caller's `d_out` and returns without sync/DtoH, so the outer PCG loop can
2692    // chain device kernels against the result.
2693    run_bms_flex_row_partial_reduce(
2694        storage,
2695        BmsFlexRowLaunchMode::HvpDeviceOut,
2696        Some(d_v),
2697        d_out,
2698        "hvp_into_device",
2699    )
2700}
2701
2702/// Launch the device-resident HVP kernel. Returns the host-side joint β image
2703/// of length `block.p_total`.
2704#[cfg(target_os = "linux")]
2705pub(crate) fn launch_bms_flex_row_hvp(
2706    storage: &DeviceResidentRowHess,
2707    v: &[f64],
2708) -> Result<Vec<f64>, GpuError> {
2709    launch_bms_flex_row_hvp_multi(storage, v, 1)
2710}
2711
2712/// Launch the device-resident diagonal kernel. Returns the host-side joint
2713/// β diagonal of length `block.p_total`.
2714#[cfg(target_os = "linux")]
2715pub(crate) fn launch_bms_flex_row_diagonal(
2716    storage: &DeviceResidentRowHess,
2717) -> Result<Vec<f64>, GpuError> {
2718    launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2719}
2720
/// Block 9 Phase 6 — hard cap on `p_total` for the dense joint-Hessian
/// device kernel. Per-CTA shared-memory accumulator is `p_total² * 8`
/// bytes. V100 default per-block shared cap is 48 KiB, so the largest
/// safe `p_total` here is `floor(sqrt(48 KiB / 8)) = floor(sqrt(6144)) = 78`.
/// We round down to 72, a multiple of 8 (`72² * 8 = 41 472` bytes), for
/// predictable launch geometry.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;
2728
/// Number of rows each dense-block CTA processes. Smaller than the HVP
/// `HVP_ROWS_PER_CTA = 256` because the per-row inner loop is `O(r² *
/// (p_m + p_g + h_block_len + w_block_len))` rather than `O(r²)` — fewer
/// rows per CTA keeps the per-CTA wall time short and lets us scale grid
/// occupancy with `num_chunks = ceil(n / DENSE_BLOCK_ROWS_PER_CTA)`.
///
/// NOTE(review): the `256` quoted for `HVP_ROWS_PER_CTA` is not visible from
/// this definition — confirm it still matches that constant.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2736
2737/// Launch the Phase-6 dense joint-Hessian block kernel. Returns the
2738/// host-side `[p_total, p_total]` row-major joint H as a `Vec<f64>`
2739/// (length `p_total²`).
2740///
2741/// **Not the default Newton path.** Production Newton uses HVP (Phase 2)
2742/// and never materialises the full dense Hessian. This entry exists for:
2743///   * exact-REML logdet (`log|H|`) when the unified evaluator wants to
2744///     factor H directly instead of going through the matrix-free path;
2745///   * diagnostic dumps that compare the GPU dense build against the CPU
2746///     `BernoulliMarginalSlopeFamily::fused_gradient_dense` reference;
2747///   * small-`p` debug routes where it is cheaper to factor + solve dense
2748///     than to run a PCG.
2749///
2750/// The kernel rejects `p_total > DENSE_BLOCK_MAX_P` cleanly because the
2751/// per-CTA shared-memory accumulator (`p_total² * 8` bytes) would exceed
2752/// the V100 48 KiB/block cap above that threshold.
2753#[cfg(target_os = "linux")]
2754pub fn launch_bms_flex_row_dense_block(
2755    storage: &DeviceResidentRowHess,
2756) -> Result<Vec<f64>, GpuError> {
2757    let p_total = storage.block.p_total;
2758    if p_total == 0 {
2759        return Err(GpuError::DriverCallFailed {
2760            reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2761        });
2762    }
2763    if p_total > DENSE_BLOCK_MAX_P {
2764        return Err(GpuError::DriverCallFailed {
2765            reason: format!(
2766                "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2767                 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2768            ),
2769        });
2770    }
2771    let backend = HvpKernelBackend::probe()?;
2772    let stream = backend.stream.clone();
2773    let n = storage.n;
2774    let r = storage.r;
2775    let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2776    let num_chunks = n.div_ceil(rows_per_cta);
2777    let pp = p_total * p_total;
2778
2779    let mut d_partial =
2780        stream
2781            .alloc_zeros::<f64>(num_chunks * pp)
2782            .map_err(|err| GpuError::DriverCallFailed {
2783                reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2784            })?;
2785    let mut d_out = stream
2786        .alloc_zeros::<f64>(pp)
2787        .map_err(|err| GpuError::DriverCallFailed {
2788            reason: format!("bms_flex_row dense_block alloc out: {err}"),
2789        })?;
2790
2791    let part_func = backend
2792        .module
2793        .load_function("bms_flex_row_dense_block_partial")
2794        .map_err(|err| GpuError::DriverCallFailed {
2795            reason: format!("bms_flex_row dense_block load partial: {err}"),
2796        })?;
2797    let red_func = backend
2798        .module
2799        .load_function("bms_flex_row_dense_block_reduce")
2800        .map_err(|err| GpuError::DriverCallFailed {
2801            reason: format!("bms_flex_row dense_block load reduce: {err}"),
2802        })?;
2803
2804    let n_i32 = n as i32;
2805    let r_i32 = r as i32;
2806    let p_m_i32 = storage.block.p_m as i32;
2807    let p_g_i32 = storage.block.p_g as i32;
2808    let p_total_i32 = p_total as i32;
2809    let h_block_start = storage
2810        .block
2811        .h
2812        .as_ref()
2813        .map(|r| r.start as i32)
2814        .unwrap_or(0);
2815    let h_block_len = storage
2816        .block
2817        .h
2818        .as_ref()
2819        .map(|r| r.len() as i32)
2820        .unwrap_or(0);
2821    let w_block_start = storage
2822        .block
2823        .w
2824        .as_ref()
2825        .map(|r| r.start as i32)
2826        .unwrap_or(0);
2827    let w_block_len = storage
2828        .block
2829        .w
2830        .as_ref()
2831        .map(|r| r.len() as i32)
2832        .unwrap_or(0);
2833    let h_primary_start = storage
2834        .primary
2835        .h
2836        .as_ref()
2837        .map(|r| r.start as i32)
2838        .unwrap_or(0);
2839    let w_primary_start = storage
2840        .primary
2841        .w
2842        .as_ref()
2843        .map(|r| r.start as i32)
2844        .unwrap_or(0);
2845    let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2846    let num_chunks_u32 = num_chunks as u32;
2847
2848    // Per-CTA shmem accumulator: p_total² doubles.
2849    let shmem_bytes: u32 =
2850        u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2851            reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2852        })?;
2853
2854    let cfg_part = LaunchConfig {
2855        grid_dim: (num_chunks_u32, 1, 1),
2856        block_dim: (HVP_THREADS, 1, 1),
2857        shared_mem_bytes: shmem_bytes,
2858    };
2859    let mut builder = stream.launch_builder(&part_func);
2860    builder
2861        .arg(&n_i32)
2862        .arg(&r_i32)
2863        .arg(&p_m_i32)
2864        .arg(&p_g_i32)
2865        .arg(&p_total_i32)
2866        .arg(&h_block_start)
2867        .arg(&h_block_len)
2868        .arg(&w_block_start)
2869        .arg(&w_block_len)
2870        .arg(&h_primary_start)
2871        .arg(&w_primary_start)
2872        .arg(&rows_per_cta_i32)
2873        .arg(&storage.hess)
2874        .arg(&storage.marginal_design)
2875        .arg(&storage.logslope_design)
2876        .arg(&mut d_partial);
2877    // SAFETY: storage pointers have validated capacities; d_partial sized
2878    // num_chunks * pp doubles; dynamic shmem matches the kernel's `extern
2879    // __shared__` accumulator length.
2880    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2881        reason: format!("bms_flex_row dense_block partial launch: {err}"),
2882    })?;
2883
2884    let red_threads: u32 = REDUCTION_THREADS;
2885    let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2886    let cfg_red = LaunchConfig {
2887        grid_dim: (red_blocks, 1, 1),
2888        block_dim: (red_threads, 1, 1),
2889        shared_mem_bytes: 0,
2890    };
2891    let num_chunks_i32 = num_chunks as i32;
2892    let mut builder = stream.launch_builder(&red_func);
2893    builder
2894        .arg(&num_chunks_i32)
2895        .arg(&p_total_i32)
2896        .arg(&d_partial)
2897        .arg(&mut d_out);
2898    // SAFETY: d_partial just populated, d_out is pp doubles.
2899    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2900        reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2901    })?;
2902    stream
2903        .synchronize()
2904        .map_err(|err| GpuError::DriverCallFailed {
2905            reason: format!("bms_flex_row dense_block sync: {err}"),
2906        })?;
2907    stream
2908        .clone_dtoh(&d_out)
2909        .map_err(|err| GpuError::DriverCallFailed {
2910            reason: format!("bms_flex_row dense_block download: {err}"),
2911        })
2912}
2913
2914// Block 9 / V100-build unblock (2026-05-27): every test below either
2915// constructs `BmsFlexBlockLayout` / `BmsFlexPrimaryLayout` (Linux-only
2916// types) or drives a CUDA-dependent fixture, so gate the whole module
2917// `#[cfg(all(test, target_os = "linux"))]`. On macOS the structs are
2918// absent and these tests do not compile — the build.rs ban scanner
2919// explicitly rejects `#[cfg(any(..., test))]` on the struct definitions
2920// themselves as a dead-code escape hatch.
2921// #415 parity lock: the CPU host oracle for the BMS-FLEX row kernel lives in a
2922// non-linux-gated `#[cfg(test)]` module so it can be exercised on the macOS dev
2923// box + CPU CI (the sibling `mod tests` is linux-gated because it also builds
2924// CUDA-only fixture types). `cpu_oracle_outputs` itself is platform-independent.
2925#[cfg(test)]
2926mod oracle_parity_tests {
2927    use super::*;
2928
2929    // ── CPU oracle that mirrors ROW_KERNEL_BODY bit-for-bit ──────────────────
2930    //
2931    // `cpu_oracle_outputs` implements the same algebra as
2932    // `bms_flex_row_kernel` in ROW_KERNEL_BODY: per-cell `T_n` / `D` / `Q`
2933    // contractions, q-row override, IFT to `a_u` / `a_uv`, observed-point
2934    // assembly to `bar_e_u` / `bar_e_uv`, probit Mills, and the final
2935    // `out_grad` / `out_hess` writes. It takes the same
2936    // `BmsFlexRowKernelInputs` struct so a CUDA-equipped host can run both
2937    // paths off one bundle and check element-wise parity.
2938    //
2939    // Used by the GPU↔CPU parity test below; the test skips on non-Linux
2940    // hosts via cfg, but the oracle itself is platform-independent so the
2941    // macOS lib build can still type-check it.
2942
    /// `1 / (2π)`, computed as the reciprocal of `τ = 2π`.
    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
    /// `√2`; `oracle_log_ndtr_and_mills` divides by it to turn a standard
    /// normal argument into an `erfc` / `erfcx` argument.
    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// `1 / √(2π)`, the prefactor of the standard normal density.
    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
2946
2947    pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
2948        if !x.is_finite() {
2949            return if x > 0.0 { 0.0 } else { f64::INFINITY };
2950        }
2951        if x <= 0.0 {
2952            return 1.0;
2953        }
2954        if x < 26.0 {
2955            let mut xx = x * x;
2956            if xx > 700.0 {
2957                xx = 700.0;
2958            }
2959            return xx.exp() * gam_gpu::numerics_host::erfc(x);
2960        }
2961        let inv = 1.0 / x;
2962        let inv2 = inv * inv;
2963        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
2964            + 6.5625 * inv2 * inv2 * inv2 * inv2;
2965        let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
2966        inv * poly * inv_sqrt_pi
2967    }
2968
2969    pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
2970        if x == f64::INFINITY {
2971            return (0.0, 0.0);
2972        }
2973        if x == f64::NEG_INFINITY {
2974            return (f64::NEG_INFINITY, f64::INFINITY);
2975        }
2976        if x.is_nan() {
2977            return (x, x);
2978        }
2979        // Single-algorithm region around and above 0. Both `log Φ(x)` and the
2980        // Mills ratio `φ/Φ` are computed from the SAME `erfc(-x/√2)` call, so
2981        // the oracle is C¹ across the x=0 seam (the prior split — erfcx-based
2982        // `-u²+ln(0.5·e^{u²}·erfc u)` for x<0 vs direct `ln(0.5·erfc(-x))` for
2983        // x≥0 — used two distinct float algorithms whose ~1e-7 disagreement at
2984        // x=0 corrupted a finite-difference reference straddling the seam, #838).
2985        //
2986        // The erfcx form is mathematically identical (e^{u²}·erfc(u)=erfcx(u),
2987        // so −u²+ln(0.5·erfcx u)=ln(0.5·erfc(−x/√2))); it is only needed deep in
2988        // the left tail (x ≲ −38), where `erfc(-x/√2)` underflows to 0 and the
2989        // exp/ln cancellation is the *only* way to keep `log Φ` finite. We move
2990        // the branch there, far from any region the kernel or its FD lock visits.
2991        const ORACLE_LEFT_TAIL_X: f64 = -37.0;
2992        if x >= ORACLE_LEFT_TAIL_X {
2993            let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
2994            if cdf < 1e-300 {
2995                cdf = 1e-300;
2996            }
2997            if cdf > 1.0 {
2998                cdf = 1.0;
2999            }
3000            let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
3001            (cdf.ln(), pdf / cdf)
3002        } else {
3003            let u = -x / ORACLE_SQRT_2;
3004            let mut ex = oracle_erfcx_nonnegative(u);
3005            if ex < 1e-300 {
3006                ex = 1e-300;
3007            }
3008            let log_cdf = -u * u + (0.5 * ex).ln();
3009            let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
3010            (log_cdf, sqrt_2_over_pi / ex)
3011        }
3012    }
3013
    /// Host reference for the device row kernel.
    ///
    /// Same outputs the device kernel writes: `(neglog, grad, hess)` per row.
    /// `grad` is row-major `n × r`, `hess` is row-major `n × r × r`.
    /// Mirrors `bms_flex_row_kernel` line-for-line so kernel + oracle diverge
    /// only if one side breaks parity. For that reason the statement order and
    /// the order of floating-point operations below are deliberate; do not
    /// "tidy" them without changing the kernel in lockstep.
    ///
    /// Per row the function:
    /// 1. sweeps the row's cells (`cell_offsets[row]..cell_offsets[row + 1]`)
    ///    and accumulates `F_a`, `F_aa`, `F_u`, `F_au`, `F_uv`;
    /// 2. overrides the `q` (index 0) row/column with `−mu_1` / `−mu_2` / zeros;
    /// 3. NaN-fills the row's outputs when `F_a` is non-finite or `≤ 0`;
    /// 4. applies the implicit-function-theorem formulas for `a_u`, `a_uv`;
    /// 5. builds the observed-predictor jets `bar_e_u`, `bar_e_uv`;
    /// 6. applies the probit Mills layer to produce value, gradient, Hessian.
    ///
    /// # Panics
    ///
    /// * On Linux builds, if `inputs.cell_moments` is the `Device` variant.
    /// * On out-of-bounds indexing: this function does not call `validate()`,
    ///   so slices shorter than the layout implies fail the ordinary bounds
    ///   checks.
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                // Invariant: this CPU oracle is a host-only sanity checker invoked
                // exclusively from `#[cfg(test)] mod tests`. The kernel-launch
                // path uses `CellMomentsSource::Device(...)`; the oracle must
                // never see that variant. Reaching this arm means a test
                // mis-wired its fixture — surface it loudly at the call site.
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // ── per-cell sweep: accumulate F_u, F_au, F_uv, F_a, F_aa.
            // Index 0 of `f_u` / `f_au` and row/column 0 of `f_uv` belong to
            // the `q` coordinate; the sweep never writes them (loops start at
            // 1) and they are set by the override further down.
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            // `cell_offsets` is a row pointer: row `i` owns cells
            // `cell_offsets[i]..cell_offsets[i + 1]`.
            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                // Cubic predictor coefficients (C0, C1, C2, C3) of this cell.
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                // This cell's derivative moments, `MOMENT_STRIDE` per cell.
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        // Fused multiply-add; part of the bit-level contract.
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // D(R) = κ · Σ_k R_k · m_k over a 4-coefficient pack.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}, grouped by p + q.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                // F_a += D(A_c);  F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // F_u and F_au for u > 0. `cell_r` / `cell_ar` store `r − 1`
                // packs of 4 per cell, indexed by `u − 1`.
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                // F_uv for 0 < u ≤ v (upper triangle only).
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        // The D(S_uv) term is only read from a buffer when
                        // u == 1: v == 1 → `cell_sbb`; v in [2, 2 + p_h) →
                        // `cell_sbh`; v in [2 + p_h, r) → `cell_sbw`. Every
                        // other (u, v) pair contributes 0 here.
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // q-row overrides (mirror kernel lines 691–700).
            // F_q = −mu_1, F_aq = 0, F_qv = 0 for v > 0, F_qq = −mu_2. The loop
            // zeroes all of row 0 and column 0; the corner is then set last.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // Degenerate F_a ⇒ NaN-fill (mirror kernel lines 703–706).
            // The row is skipped entirely so no division by a bad F_a occurs.
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            let inv_fa = 1.0 / f_a;

            // IFT first/second order.
            // a_u = −F_u / F_a; for u = 0 this is written directly as
            // mu_1 / F_a (F_q = −mu_1).
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            // a_uv = −(F_uv + F_av·a_u + F_au·a_v + F_aa·a_u·a_v) / F_a,
            // computed on the upper triangle and mirrored.
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Observed predictor jets.
            // bar_e_u  = chi·a_u + rho_u
            // bar_e_uv = chi·a_uv + xi·a_u·a_v + tau_u·a_v + a_u·tau_v + r_uv
            // (`r_uv` is read at the upper-triangle index only.)
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Probit Mills + final writes.
            let y = inputs.y[row];
            let w = inputs.w[row];
            // Label sign: y = 1 → s = +1, y = 0 → s = −1.
            let s = 2.0 * y - 1.0;
            // #415 parity: the observed predictor VALUE is packed directly
            // (`inputs.e_obs`), matching the CPU family's `signed_margin =
            // s_y * eta_val`. `bar_e_u[0]` is the u=0 first-derivative jet and
            // is used only for the gradient/Hessian, never as the Mills margin.
            let e_obs = inputs.e_obs[row];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            // First and second derivatives of −w·log Φ(s·e) with respect to e.
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            // Chain rule: hess_uv = b_i·bar_e_u·bar_e_v + a_i·bar_e_uv,
            // written symmetrically.
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3228
3229    // #415 parity lock. This test lives HERE (a descendant of `bms::gpu::row`)
3230    // rather than in `bms::row_primary_hessian` because the host oracle
3231    // `cpu_oracle_outputs` must live in a PRIVATE `#[cfg(test)]` mod (the
3232    // build.rs ban-scanner forbids `#[cfg(test)]` on a non-private mod), so a
3233    // sibling module cannot reach it. Nested here, the test sees the private
3234    // oracle directly while the packer/CPU-family methods it drives are
3235    // `pub(in crate::bms)` and stay reachable. The nested module carries no
3236    // `#[cfg(test)]` attribute of its own (it inherits the parent's), so it is
3237    // ban-scanner-clean.
3238    mod parity_415 {
3239        //! #415 parity lock: the GPU-host oracle `cpu_oracle_outputs` (which
3240        //! GATES the device row kernel via
3241        //! `bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available`) must
3242        //! reproduce the CPU family reference
3243        //! `compute_row_analytic_flex_from_parts_into` element-for-element, from
3244        //! ONE fitted `(family, block_states, cache)`.
3245        //!
3246        //! Before this test the only ties between the two were (1) an FD lock on
3247        //! the outer scalar Mills layer and (2) a string-contains comment guard —
3248        //! neither pins the cell-contraction algebra (`F_a`, `F_aa`, `F_au`,
3249        //! `F_uv` → value/grad/Hessian). This closes that gap: the SAME fitted
3250        //! state is packed into `BmsFlexRowKernelInputs` and run through
3251        //! `cpu_oracle_outputs` for all rows, and independently run through the
3252        //! CPU family per row; every row value, full gradient, and full r×r
        //! Hessian must agree within `1e-9 + 1e-10·|oracle|` (absolute + relative slack).
3254
3255        use super::cpu_oracle_outputs;
3256        use crate::bms::family::*;
3257        use crate::bms::hessian_paths::*;
3258        use crate::bms::{exact_kernel, DeviationBlockConfig, LatentMeasureKind};
3259        use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix};
3260        use gam_problem::{InverseLink, ParameterBlockState, StandardLink};
3261        use ndarray::{Array1, Array2};
3262        use std::sync::{Arc, Mutex};
3263
3264        /// Build a small but REAL flex BMS family in the `StandardNormal`
3265        /// latent-measure branch with BOTH a score-warp (`p_h > 0`) and a
3266        /// link-deviation (`p_w > 0`) block active, plus mixed labels y ∈ {0,1}.
3267        /// Ported from the `gradient_paths` flex oracle fixture so the cache is
3268        /// populated by the production cell-moment assembly (never hand-faked).
3269        fn make_flex_parity_family(
3270            n: usize,
3271        ) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
3272            let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
3273            let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
3274            let cfg = DeviationBlockConfig {
3275                num_internal_knots: 3,
3276                ..DeviationBlockConfig::default()
3277            };
3278            let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
3279                .expect("build score warp block");
3280            let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
3281                &link_seed, &link_seed, &cfg,
3282            )
3283            .expect("build link deviation block");
3284
3285            // Mixed labels y ∈ {0,1} so both s_y = ±1 Mills branches are exercised.
3286            let y: Array1<f64> =
3287                Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
3288            let weights: Array1<f64> =
3289                Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
3290            let z: Array1<f64> =
3291                Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
3292            let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
3293                if j == 0 {
3294                    1.0
3295                } else {
3296                    -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
3297                }
3298            });
3299            let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
3300                if j == 0 {
3301                    1.0
3302                } else {
3303                    0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
3304                }
3305            });
3306
3307            let family = BernoulliMarginalSlopeFamily {
3308                y: Arc::new(y),
3309                weights: Arc::new(weights),
3310                z: Arc::new(z.clone()),
3311                latent_measure: LatentMeasureKind::StandardNormal,
3312                gaussian_frailty_sd: Some(0.15),
3313                base_link: InverseLink::Standard(StandardLink::Probit),
3314                marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
3315                logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
3316                score_warp: Some(score_prepared.runtime.clone()),
3317                link_dev: Some(link_prepared.runtime.clone()),
3318                policy: gam_runtime::resource::ResourcePolicy::default_library(),
3319                cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
3320                cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
3321                intercept_warm_starts: None,
3322                auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
3323                auto_subsample_last_rho: Arc::new(Mutex::new(None)),
3324            };
3325
3326            let beta_m = Array1::from_vec(vec![0.12, -0.04]);
3327            let beta_g = Array1::from_vec(vec![0.35, 0.03]);
3328            let beta_h = Array1::from_iter(
3329                (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
3330            );
3331            let beta_w = Array1::from_iter(
3332                (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
3333            );
3334            let states = vec![
3335                ParameterBlockState {
3336                    eta: marginal_x.dot(&beta_m),
3337                    beta: beta_m,
3338                },
3339                ParameterBlockState {
3340                    eta: logslope_x.dot(&beta_g),
3341                    beta: beta_g,
3342                },
3343                ParameterBlockState {
3344                    beta: beta_h,
3345                    eta: Array1::zeros(z.len()),
3346                },
3347                ParameterBlockState {
3348                    beta: beta_w,
3349                    eta: Array1::zeros(z.len()),
3350                },
3351            ];
3352            (family, states)
3353        }
3354
3355        /// The non-vacuous #415 lock: pack once, run the host oracle for all
3356        /// rows, run the CPU family per row, and assert value/gradient/Hessian
3357        /// parity.
3358        #[test]
3359        fn cpu_oracle_matches_cpu_family_row_analytic_flex_415() {
3360            let n = 12usize;
3361            let (family, states) = make_flex_parity_family(n);
3362            let cache = family
3363                .build_exact_eval_cache(&states)
3364                .expect("flex exact eval cache");
3365
3366            // Preconditions that make the lock non-vacuous: the row-cell-moments
3367            // bundle (which the oracle consumes) must actually be materialised,
3368            // and the deviation blocks must both be present so p_h > 0 AND p_w > 0.
3369            assert!(
3370                cache.row_cell_moments.is_some(),
3371                "#415 fixture must materialise the row-cell-moments bundle; the pack \
3372                 and both compared paths read it"
3373            );
3374            let primary = &cache.primary;
3375            let r = primary.total;
3376            let p_h = primary.h.as_ref().map(|range| range.len()).unwrap_or(0);
3377            let p_w = primary.w.as_ref().map(|range| range.len()).unwrap_or(0);
3378            assert!(p_h > 0 && p_w > 0, "#415 fixture must be full-flex: p_h={p_h} p_w={p_w}");
3379            assert_eq!(r, 2 + p_h + p_w, "#415 fixture primary layout");
3380
3381            // Pack the SAME fitted state the CPU family will consume, then run the
3382            // GPU host oracle over every row.
3383            let owned = family
3384                .pack_bms_flex_row_kernel_inputs(&states, &cache)
3385                .expect("pack must not error")
3386                .expect("pack must succeed for the StandardNormal full-flex fixture");
3387            let inputs = owned.as_borrowed();
3388            let oracle = cpu_oracle_outputs(&inputs);
3389            assert_eq!(oracle.neglog.len(), n);
3390            assert_eq!(oracle.grad.len(), n * r);
3391            assert_eq!(oracle.hess.len(), n * r * r);
3392
3393            // Both sides are exact f64 CPU math over the SAME cached moments, so
3394            // the only slack is FP summation ordering. Anything looser than this
3395            // would hide a real algebraic drift.
3396            let tol_abs = 1e-9_f64;
3397            let tol_rel = 1e-10_f64;
3398
3399            let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(r);
3400            let mut max_rel = 0.0_f64;
3401            let mut checked_labels = [false, false];
3402
3403            for row in 0..n {
3404                let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
3405                let row_moments = cache
3406                    .row_cell_moments
3407                    .as_ref()
3408                    .and_then(|bundle| bundle.row(row, 9));
3409                assert!(
3410                    row_moments.is_some(),
3411                    "row {row} must carry degree-9 cell moments (the oracle reads them)"
3412                );
3413                let label = family.y[row] as usize;
3414                if label < 2 {
3415                    checked_labels[label] = true;
3416                }
3417
3418                let value = family
3419                    .compute_row_analytic_flex_into_with_moments(
3420                        row,
3421                        &states,
3422                        primary,
3423                        row_ctx,
3424                        row_moments,
3425                        cache.cell_family_forest.as_ref(),
3426                        true,
3427                        &mut scratch,
3428                    )
3429                    .expect("cpu family row analytic flex");
3430
3431                // ── value ────────────────────────────────────────────────────
3432                let o_val = oracle.neglog[row];
3433                if o_val.is_nan() || value.is_nan() {
3434                    assert!(
3435                        o_val.is_nan() && value.is_nan(),
3436                        "row {row}: NaN parity broke — oracle={o_val} family={value}"
3437                    );
3438                    continue;
3439                }
3440                let vd = (o_val - value).abs();
3441                let vtol = tol_abs + tol_rel * o_val.abs();
3442                max_rel = max_rel.max(vd / o_val.abs().max(1.0));
3443                assert!(
3444                    vd <= vtol,
3445                    "row {row} value drift: oracle={o_val:.17e} family={value:.17e} \
3446                     |Δ|={vd:.3e} > tol={vtol:.3e}"
3447                );
3448
3449                // ── gradient ─────────────────────────────────────────────────
3450                for u in 0..r {
3451                    let o_g = oracle.grad[row * r + u];
3452                    let f_g = scratch.grad[u];
3453                    let gd = (o_g - f_g).abs();
3454                    let gtol = tol_abs + tol_rel * o_g.abs();
3455                    max_rel = max_rel.max(gd / o_g.abs().max(1.0));
3456                    assert!(
3457                        gd <= gtol,
3458                        "row {row} grad[{u}] drift: oracle={o_g:.17e} family={f_g:.17e} \
3459                         |Δ|={gd:.3e} > tol={gtol:.3e}"
3460                    );
3461                }
3462
3463                // ── full r×r Hessian ─────────────────────────────────────────
3464                for u in 0..r {
3465                    for v in 0..r {
3466                        let o_h = oracle.hess[row * r * r + u * r + v];
3467                        let f_h = scratch.hess[[u, v]];
3468                        let hd = (o_h - f_h).abs();
3469                        let htol = tol_abs + tol_rel * o_h.abs();
3470                        max_rel = max_rel.max(hd / o_h.abs().max(1.0));
3471                        assert!(
3472                            hd <= htol,
3473                            "row {row} hess[{u},{v}] drift: oracle={o_h:.17e} \
3474                             family={f_h:.17e} |Δ|={hd:.3e} > tol={htol:.3e}"
3475                        );
3476                    }
3477                }
3478            }
3479
3480            // Edge coverage: both label branches must have been exercised (the
3481            // q-row overrides F_q=-mu_1 / F_qq=-mu_2 and both Mills sign branches).
3482            assert!(
3483                checked_labels[0] && checked_labels[1],
3484                "#415 fixture must exercise both y=0 and y=1 rows: {checked_labels:?}"
3485            );
3486            eprintln!(
3487                "#415 parity lock: n={n} r={r} p_h={p_h} p_w={p_w} max_rel(oracle−family)={max_rel:.3e}"
3488            );
3489        }
3490    }
3491}
3492
3493#[cfg(all(test, target_os = "linux"))]
3494mod tests {
3495    use super::*;
3496    use super::oracle_parity_tests::*;
3497
3498    pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3499        BmsFlexRowKernelInputs {
3500            n_rows: 1,
3501            r: 4,
3502            p_h: 1,
3503            p_w: 1,
3504            q: &buffers.q,
3505            b: &buffers.b,
3506            mu_1: &buffers.mu_1,
3507            mu_2: &buffers.mu_2,
3508            z_obs: &buffers.z_obs,
3509            y: &buffers.y,
3510            w: &buffers.w,
3511            e_obs: &buffers.e_obs,
3512            s_f: 1.0,
3513            cell_offsets: &buffers.cell_offsets,
3514            cell_c0: &buffers.cell_c0,
3515            cell_c1: &buffers.cell_c1,
3516            cell_c2: &buffers.cell_c2,
3517            cell_c3: &buffers.cell_c3,
3518            cell_a: &buffers.cell_a,
3519            cell_aa: &buffers.cell_aa,
3520            cell_r: &buffers.cell_r,
3521            cell_ar: &buffers.cell_ar,
3522            cell_sbb: &buffers.cell_sbb,
3523            cell_sbh: &buffers.cell_sbh,
3524            cell_sbw: &buffers.cell_sbw,
3525            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3526            chi_obs: &buffers.chi_obs,
3527            xi_obs: &buffers.xi_obs,
3528            rho_u: &buffers.rho_u,
3529            tau_u: &buffers.tau_u,
3530            r_uv: &buffers.r_uv,
3531        }
3532    }
3533
3534    pub(crate) struct TestBuffers {
3535        pub(crate) q: Vec<f64>,
3536        pub(crate) b: Vec<f64>,
3537        pub(crate) mu_1: Vec<f64>,
3538        pub(crate) mu_2: Vec<f64>,
3539        pub(crate) z_obs: Vec<f64>,
3540        pub(crate) y: Vec<f64>,
3541        pub(crate) w: Vec<f64>,
3542        pub(crate) e_obs: Vec<f64>,
3543        pub(crate) cell_offsets: Vec<u32>,
3544        pub(crate) cell_c0: Vec<f64>,
3545        pub(crate) cell_c1: Vec<f64>,
3546        pub(crate) cell_c2: Vec<f64>,
3547        pub(crate) cell_c3: Vec<f64>,
3548        pub(crate) cell_a: Vec<f64>,
3549        pub(crate) cell_aa: Vec<f64>,
3550        pub(crate) cell_r: Vec<f64>,
3551        pub(crate) cell_ar: Vec<f64>,
3552        pub(crate) cell_sbb: Vec<f64>,
3553        pub(crate) cell_sbh: Vec<f64>,
3554        pub(crate) cell_sbw: Vec<f64>,
3555        pub(crate) cell_moments: Vec<f64>,
3556        pub(crate) chi_obs: Vec<f64>,
3557        pub(crate) xi_obs: Vec<f64>,
3558        pub(crate) rho_u: Vec<f64>,
3559        pub(crate) tau_u: Vec<f64>,
3560        pub(crate) r_uv: Vec<f64>,
3561    }
3562
3563    pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
3564        let cells = n_cells as usize;
3565        TestBuffers {
3566            q: vec![0.1; 1],
3567            b: vec![0.5; 1],
3568            mu_1: vec![0.3; 1],
3569            mu_2: vec![0.07; 1],
3570            z_obs: vec![0.0; 1],
3571            y: vec![1.0; 1],
3572            w: vec![1.0; 1],
3573            e_obs: vec![0.15; 1],
3574            cell_offsets: vec![0, n_cells],
3575            cell_c0: vec![0.2; cells],
3576            cell_c1: vec![-0.1; cells],
3577            cell_c2: vec![0.05; cells],
3578            cell_c3: vec![-0.02; cells],
3579            cell_a: vec![0.1; cells * 4],
3580            cell_aa: vec![0.0; cells * 4],
3581            cell_r: vec![0.05; cells * (r - 1) * 4],
3582            cell_ar: vec![0.0; cells * (r - 1) * 4],
3583            cell_sbb: vec![0.0; cells * 4],
3584            cell_sbh: vec![0.0; cells * p_h * 4],
3585            cell_sbw: vec![0.0; cells * p_w * 4],
3586            cell_moments: vec![1.0; cells * MOMENT_STRIDE],
3587            chi_obs: vec![1.0; 1],
3588            xi_obs: vec![0.0; 1],
3589            rho_u: vec![0.0; r],
3590            tau_u: vec![0.0; r],
3591            r_uv: vec![0.0; r * r],
3592        }
3593    }
3594
3595    #[test]
3596    pub(crate) fn validate_accepts_minimal_inputs() {
3597        let buffers = make_buffers(2, 4, 1, 1);
3598        let inputs = minimal_inputs(&buffers);
3599        assert!(inputs.validate().is_ok());
3600    }
3601
3602    #[test]
3603    pub(crate) fn validate_rejects_r_above_max() {
3604        let r = MAX_R + 1;
3605        let p_h = (r - 2) / 2;
3606        let p_w = (r - 2) - p_h;
3607        let buffers = make_buffers(1, r, p_h, p_w);
3608        let bad_inputs = BmsFlexRowKernelInputs {
3609            r,
3610            p_h,
3611            p_w,
3612            rho_u: &buffers.rho_u, // length matches `r` we wrote
3613            tau_u: &buffers.tau_u,
3614            r_uv: &buffers.r_uv,
3615            cell_r: &buffers.cell_r,
3616            cell_ar: &buffers.cell_ar,
3617            cell_sbh: &buffers.cell_sbh,
3618            cell_sbw: &buffers.cell_sbw,
3619            ..minimal_inputs(&buffers)
3620        };
3621        let err = bad_inputs.validate().expect_err("r > MAX_R must fail");
3622        let msg = err.to_string();
3623        assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
3624    }
3625
3626    #[test]
3627    pub(crate) fn validate_rejects_mismatched_r_decomposition() {
3628        let buffers = make_buffers(1, 4, 1, 1);
3629        let bad_inputs = BmsFlexRowKernelInputs {
3630            r: 4,
3631            p_h: 1,
3632            p_w: 2, // inconsistent with r = 4
3633            ..minimal_inputs(&buffers)
3634        };
3635        let err = bad_inputs
3636            .validate()
3637            .expect_err("inconsistent r vs p_h+p_w must fail");
3638        let msg = err.to_string();
3639        assert!(msg.contains("p_h"), "got: {msg}");
3640        assert!(msg.contains("p_w"), "got: {msg}");
3641    }
3642
3643    #[test]
3644    pub(crate) fn validate_rejects_non_monotone_offsets() {
3645        // `minimal_inputs` hard-codes `n_rows = 1`, so the CSR-style row
3646        // pointer length is `n + 1 = 2`. Pin both `offsets[1] = total_cells`
3647        // and `cell_c0.len() = total_cells = 2` from `make_buffers(2, …)`,
3648        // then violate monotonicity by setting `offsets[0] > offsets[1]`;
3649        // every length / per-cell-count check is satisfied so the only
3650        // failure mode left is the monotonicity guard.
3651        let mut buffers = make_buffers(2, 4, 1, 1);
3652        buffers.cell_offsets = vec![5, 2];
3653        let inputs = minimal_inputs(&buffers);
3654        let err = inputs
3655            .validate()
3656            .expect_err("non-monotone offsets must fail");
3657        let msg = err.to_string();
3658        assert!(msg.contains("monotone"), "got: {msg}");
3659    }
3660
3661    #[test]
3662    pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
3663        let mut buffers = make_buffers(2, 4, 1, 1);
3664        buffers.cell_moments.pop(); // length now 2*10 - 1
3665        let inputs = minimal_inputs(&buffers);
3666        let err = inputs.validate().expect_err("short cell_moments must fail");
3667        let msg = err.to_string();
3668        assert!(msg.contains("cell_moments"), "got: {msg}");
3669    }
3670
3671    #[test]
3672    pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
3673        // Mac/Windows builds must surface a typed `DriverLibraryUnavailable`
3674        // rather than panicking or returning Ok. On Linux this test is
3675        // skipped because the kernel actually launches.
3676        #[cfg(target_os = "linux")]
3677        {
3678            // Linux builds may or may not have a device; the dispatcher
3679            // contract is that without a runtime, probe() returns
3680            // DriverLibraryUnavailable. Either outcome (NoDeviceKernel,
3681            // DriverLibraryUnavailable, or DriverCallFailed) is acceptable
3682            // here; success would mean the kernel actually ran which is a
3683            // V100-only outcome we don't gate the unit test on.
3684            let buffers = make_buffers(1, 4, 1, 1);
3685            let inputs = minimal_inputs(&buffers);
3686            match launch_bms_flex_row_kernel(inputs) {
3687                Ok(_) => { /* V100 host: real launch */ }
3688                Err(GpuError::DriverLibraryUnavailable { .. })
3689                | Err(GpuError::DriverCallFailed { .. })
3690                | Err(GpuError::DriverSymbolMissing { .. })
3691                | Err(GpuError::NoDeviceKernel { .. }) => { /* expected on CPU-only */ }
3692                Err(other) => panic!("unexpected GpuError variant: {other:?}"),
3693            }
3694        }
3695        #[cfg(not(target_os = "linux"))]
3696        {
3697            let buffers = make_buffers(1, 4, 1, 1);
3698            let inputs = minimal_inputs(&buffers);
3699            match launch_bms_flex_row_kernel(inputs) {
3700                Err(GpuError::DriverLibraryUnavailable { reason }) => {
3701                    assert!(
3702                        reason.contains("Linux-only"),
3703                        "expected Linux-only hint, got: {reason}"
3704                    );
3705                }
3706                other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
3707            }
3708        }
3709    }
3710
3711    #[test]
3712    pub(crate) fn s_f_must_be_positive_and_finite() {
3713        let buffers = make_buffers(1, 4, 1, 1);
3714        let mut inputs = minimal_inputs(&buffers);
3715        inputs.s_f = 0.0;
3716        match launch_bms_flex_row_kernel(inputs) {
3717            Err(GpuError::DriverCallFailed { reason }) => {
3718                assert!(reason.contains("s_f"), "got: {reason}");
3719            }
3720            other => panic!("expected DriverCallFailed for s_f=0, got {other:?}"),
3721        }
3722    }
3723
3724
3725    /// Build a non-trivial fixture: `n = 4` rows, `r = 5` (p_h = 2, p_w = 1),
3726    /// 2–4 cells per row, distinct values so a structural bug in either path
3727    /// can't be masked by accidental cancellation.
3728    pub(crate) fn make_parity_buffers() -> TestBuffers {
3729        let n = 4_usize;
3730        let r = 5_usize;
3731        let p_h = 2_usize;
3732        let p_w = 1_usize;
3733        // Per-row cell counts: 2, 3, 4, 2 → total 11 cells.
3734        let row_cells: [u32; 4] = [2, 3, 4, 2];
3735        let mut cell_offsets = vec![0_u32; n + 1];
3736        for i in 0..n {
3737            cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
3738        }
3739        let total_cells = cell_offsets[n] as usize;
3740
3741        // Deterministic but varied generators (LCG-ish so each slot is distinct).
3742        let f = |seed: usize| -> f64 {
3743            let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
3744            0.1 + 0.4 * x
3745        };
3746
3747        let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
3748        let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
3749        let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
3750        let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
3751        let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
3752        let y = [1.0, 0.0, 1.0, 0.0].to_vec();
3753        let w = vec![1.0; n];
3754        let e_obs = (0..n).map(|i| -0.3 + 0.2 * (i as f64)).collect::<Vec<_>>();
3755
3756        let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
3757        let cell_c1 = (0..total_cells)
3758            .map(|c| -f(c + 2002) * 0.5)
3759            .collect::<Vec<_>>();
3760        let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
3761        let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();
3762
3763        let cell_a = (0..total_cells * 4)
3764            .map(|i| f(i + 5005) * 0.3)
3765            .collect::<Vec<_>>();
3766        let cell_aa = (0..total_cells * 4)
3767            .map(|i| f(i + 6006) * 0.1)
3768            .collect::<Vec<_>>();
3769        let cell_r = (0..total_cells * (r - 1) * 4)
3770            .map(|i| f(i + 7007) * 0.2)
3771            .collect::<Vec<_>>();
3772        let cell_ar = (0..total_cells * (r - 1) * 4)
3773            .map(|i| f(i + 8008) * 0.05)
3774            .collect::<Vec<_>>();
3775        let cell_sbb = (0..total_cells * 4)
3776            .map(|i| f(i + 9009) * 0.08)
3777            .collect::<Vec<_>>();
3778        let cell_sbh = (0..total_cells * p_h * 4)
3779            .map(|i| f(i + 10_010) * 0.07)
3780            .collect::<Vec<_>>();
3781        let cell_sbw = (0..total_cells * p_w * 4)
3782            .map(|i| f(i + 11_011) * 0.06)
3783            .collect::<Vec<_>>();
3784        let cell_moments = (0..total_cells * MOMENT_STRIDE)
3785            .map(|i| 0.4 + 0.1 * f(i + 12_012))
3786            .collect::<Vec<_>>();
3787
3788        let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
3789        let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
3790        let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
3791        let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
3792        let r_uv = (0..n * r * r)
3793            .map(|i| 0.04 * f(i + 15_015))
3794            .collect::<Vec<_>>();
3795
3796        TestBuffers {
3797            q,
3798            b,
3799            mu_1,
3800            mu_2,
3801            z_obs,
3802            y,
3803            w,
3804            e_obs,
3805            cell_offsets,
3806            cell_c0,
3807            cell_c1,
3808            cell_c2,
3809            cell_c3,
3810            cell_a,
3811            cell_aa,
3812            cell_r,
3813            cell_ar,
3814            cell_sbb,
3815            cell_sbh,
3816            cell_sbw,
3817            cell_moments,
3818            chi_obs,
3819            xi_obs,
3820            rho_u,
3821            tau_u,
3822            r_uv,
3823        }
3824    }
3825
3826    pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3827        BmsFlexRowKernelInputs {
3828            n_rows: 4,
3829            r: 5,
3830            p_h: 2,
3831            p_w: 1,
3832            q: &buffers.q,
3833            b: &buffers.b,
3834            mu_1: &buffers.mu_1,
3835            mu_2: &buffers.mu_2,
3836            z_obs: &buffers.z_obs,
3837            y: &buffers.y,
3838            w: &buffers.w,
3839            e_obs: &buffers.e_obs,
3840            s_f: 1.0,
3841            cell_offsets: &buffers.cell_offsets,
3842            cell_c0: &buffers.cell_c0,
3843            cell_c1: &buffers.cell_c1,
3844            cell_c2: &buffers.cell_c2,
3845            cell_c3: &buffers.cell_c3,
3846            cell_a: &buffers.cell_a,
3847            cell_aa: &buffers.cell_aa,
3848            cell_r: &buffers.cell_r,
3849            cell_ar: &buffers.cell_ar,
3850            cell_sbb: &buffers.cell_sbb,
3851            cell_sbh: &buffers.cell_sbh,
3852            cell_sbw: &buffers.cell_sbw,
3853            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3854            chi_obs: &buffers.chi_obs,
3855            xi_obs: &buffers.xi_obs,
3856            rho_u: &buffers.rho_u,
3857            tau_u: &buffers.tau_u,
3858            r_uv: &buffers.r_uv,
3859        }
3860    }
3861
3862    /// Symmetry + finiteness of the CPU oracle. Runs on every host (Linux,
3863    /// macOS, CPU CI) since the oracle is platform-independent. Guarantees the
3864    /// reference path used by the GPU parity test is itself well-formed.
3865    #[test]
3866    pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
3867        let buffers = make_parity_buffers();
3868        let inputs = parity_inputs(&buffers);
3869        inputs
3870            .validate()
3871            .expect("parity fixture must satisfy validate()");
3872        let out = cpu_oracle_outputs(&inputs);
3873        let n = inputs.n_rows;
3874        let r = inputs.r;
3875        assert_eq!(out.neglog.len(), n);
3876        assert_eq!(out.grad.len(), n * r);
3877        assert_eq!(out.hess.len(), n * r * r);
3878        for row in 0..n {
3879            assert!(
3880                out.neglog[row].is_finite(),
3881                "row {row}: neglog must be finite, got {}",
3882                out.neglog[row]
3883            );
3884            for u in 0..r {
3885                let g = out.grad[row * r + u];
3886                assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
3887                for v in 0..r {
3888                    let huv = out.hess[row * r * r + u * r + v];
3889                    let hvu = out.hess[row * r * r + v * r + u];
3890                    assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
3891                    assert_eq!(
3892                        huv.to_bits(),
3893                        hvu.to_bits(),
3894                        "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
3895                    );
3896                }
3897            }
3898        }
3899    }
3900
3901    /// Independent finite-difference correctness lock on the oracle's probit
3902    /// Mills layer — the most optimizer-sensitive, drift-prone term in the
3903    /// whole row kernel (issue #415: "third/fourth-order derivative
3904    /// contractions drift silently … formulas are complex and
3905    /// optimizer-sensitive"). The device kernel's `bms_flex_row_kernel` and
3906    /// the host oracle's `cpu_oracle_outputs` both close out with the same
3907    /// Mills algebra:
3908    ///
3909    /// ```text
3910    ///     m       = s · e_obs ;  s = 2y − 1
3911    ///     A       = −w · s · λ(m)
3912    ///     B       =  w · λ(m) · (m + λ(m))
3913    ///     neglog  = −w · log Φ(s · e_obs)
3914    ///     g_u     = A · bar_e_u
3915    ///     H_uv    = B · bar_e_u · bar_e_v + A · bar_e_uv
3916    /// ```
3917    ///
3918    /// Holding the observed jets `bar_e` fixed, the row neglog is a function
3919    /// of the single scalar `e := bar_e_u[0] = e_obs`, and by the assembled
3920    /// formula `∂neglog/∂e = A` and `∂²neglog/∂e² = B`. This test reconstructs
3921    /// `A`, `B`, and `neglog` exactly as the oracle does (same
3922    /// `oracle_log_ndtr_and_mills`, same sign convention), then verifies the
3923    /// analytic `A`/`B` against high-order central differences of
3924    /// `e ↦ −w · log Φ(s·e)`. A drift in the kernel's Mills derivatives —
3925    /// which the device-parity test cannot catch because it checks the kernel
3926    /// *against the same (possibly-wrong) oracle* — fails here on every host,
3927    /// CUDA or not. Bounds are the genuine fifth-order central-difference
3928    /// truncation floor; they are not weakened to pass.
3929    #[test]
3930    pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
3931        // Probit neglog as the oracle assembles it, as a function of the
3932        // observed scalar predictor `e` with weight `w` and label `y`.
3933        let neglog_of = |e: f64, y: f64, w: f64| -> f64 {
3934            let s = 2.0 * y - 1.0;
3935            let (log_cdf, _) = oracle_log_ndtr_and_mills(s * e);
3936            -w * log_cdf
3937        };
3938        // Analytic first/second derivatives wrt `e` — the exact `A`/`B` the
3939        // kernel writes into `grad`/`hess`.
3940        let ab_of = |e: f64, y: f64, w: f64| -> (f64, f64) {
3941            let s = 2.0 * y - 1.0;
3942            let m_arg = s * e;
3943            let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
3944            let a_i = -w * s * lambda;
3945            let b_i = w * lambda * (m_arg + lambda);
3946            (a_i, b_i)
3947        };
3948
3949        // Sweep both labels (s = ±1), both tails of the predictor, and a
3950        // non-unit weight so every sign/scale path of the Mills algebra is
3951        // exercised. Points stay clear of the deep-tail asymptote where a
3952        // central-difference reference loses its own accuracy.
3953        let cases: [(f64, f64, f64); 12] = [
3954            (-1.6, 1.0, 1.0),
3955            (-0.7, 1.0, 1.0),
3956            (0.0, 1.0, 1.0),
3957            (0.9, 1.0, 1.0),
3958            (1.8, 1.0, 1.0),
3959            (-1.4, 0.0, 1.0),
3960            (-0.3, 0.0, 1.0),
3961            (0.0, 0.0, 1.0),
3962            (0.6, 0.0, 1.0),
3963            (1.5, 0.0, 1.0),
3964            (0.4, 1.0, 0.75),
3965            (-0.8, 0.0, 1.3),
3966        ];
3967        // Fifth-order central stencils; `h` chosen near the f64 sweet spot for
3968        // first/second derivatives of a smooth O(1) function.
3969        let h = 1e-3_f64;
3970        for (e, y, w) in cases {
3971            let (a_ana, b_ana) = ab_of(e, y, w);
3972
3973            let fp2 = neglog_of(e + 2.0 * h, y, w);
3974            let fp1 = neglog_of(e + h, y, w);
3975            let f0 = neglog_of(e, y, w);
3976            let fm1 = neglog_of(e - h, y, w);
3977            let fm2 = neglog_of(e - 2.0 * h, y, w);
3978
3979            // 5-point central first derivative: O(h⁴).
3980            let d1_fd = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * h);
3981            // 5-point central second derivative: O(h⁴).
3982            let d2_fd = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * h * h);
3983
3984            let a_abs = (a_ana - d1_fd).abs();
3985            let a_rel = a_abs / a_ana.abs().max(1.0);
3986            assert!(
3987                a_abs <= 5e-8 || a_rel <= 5e-8,
3988                "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
3989                 analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
3990            );
3991
3992            let b_abs = (b_ana - d2_fd).abs();
3993            let b_rel = b_abs / b_ana.abs().max(1.0);
3994            assert!(
3995                b_abs <= 5e-6 || b_rel <= 5e-6,
3996                "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
3997                 analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
3998            );
3999        }
4000    }
4001
4002    /// CPU↔GPU parity. Only runs end-to-end on a Linux host with a CUDA
4003    /// runtime; skips with a clear `eprintln!` on every other host so the
4004    /// always-on test suite stays green on the macOS dev box and CPU CI.
4005    ///
4006    /// On a CUDA host: drives the kernel through `launch_bms_flex_row_kernel`
4007    /// and the same `BmsFlexRowKernelInputs` through `cpu_oracle_outputs`,
4008    /// then asserts every element of `neglog`, `grad`, and `hess` agrees
4009    /// within `|Δ| <= 1e-8 + 1e-8·|cpu|` (absolute-or-relative).
4010    #[test]
4011    pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
4012        #[cfg(not(target_os = "linux"))]
4013        {
4014            eprintln!(
4015                "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
4016                 (CPU oracle exercised by sibling test)"
4017            );
4018            return;
4019        }
4020        #[cfg(target_os = "linux")]
4021        {
4022            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4023                eprintln!(
4024                    "[bms_flex_row parity] no CUDA runtime — skipping device \
4025                     parity (CPU oracle exercised by sibling test)"
4026                );
4027                return;
4028            };
4029            let buffers = make_parity_buffers();
4030            let inputs_cpu = parity_inputs(&buffers);
4031            inputs_cpu
4032                .validate()
4033                .expect("parity fixture must satisfy validate()");
4034            let cpu_out = cpu_oracle_outputs(&inputs_cpu);
4035
4036            // Launch the device kernel against the same inputs.
4037            let inputs_gpu = parity_inputs(&buffers);
4038            let gpu_out = match launch_bms_flex_row_kernel(inputs_gpu) {
4039                Ok(out) => out,
4040                Err(err) => panic!(
4041                    "[bms_flex_row parity] launch failed on CUDA-selected host; \
4042                     device/oracle parity must fail loudly on GPU CI: {err}"
4043                ),
4044            };
4045
4046            let n = inputs_cpu.n_rows;
4047            let r = inputs_cpu.r;
4048            let tol_abs = 1e-8_f64;
4049            let tol_rel = 1e-8_f64;
4050            let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
4051                if cpu.is_nan() || gpu.is_nan() {
4052                    assert!(
4053                        cpu.is_nan() && gpu.is_nan(),
4054                        "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
4055                    );
4056                    return;
4057                }
4058                let diff = (cpu - gpu).abs();
4059                let tol = tol_abs + tol_rel * cpu.abs();
4060                assert!(
4061                    diff <= tol,
4062                    "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
4063                     cpu={cpu:.17e}, gpu={gpu:.17e}"
4064                );
4065            };
4066            assert_eq!(cpu_out.neglog.len(), gpu_out.neglog.len());
4067            assert_eq!(cpu_out.grad.len(), gpu_out.grad.len());
4068            assert_eq!(cpu_out.hess.len(), gpu_out.hess.len());
4069            for (i, (&c, &g)) in cpu_out.neglog.iter().zip(gpu_out.neglog.iter()).enumerate() {
4070                check_close("neglog", i, c, g);
4071            }
4072            for (i, (&c, &g)) in cpu_out.grad.iter().zip(gpu_out.grad.iter()).enumerate() {
4073                check_close("grad", i, c, g);
4074            }
4075            for (i, (&c, &g)) in cpu_out.hess.iter().zip(gpu_out.hess.iter()).enumerate() {
4076                check_close("hess", i, c, g);
4077            }
4078            // Spot-check exact symmetry on the GPU Hessian too.
4079            for row in 0..n {
4080                for u in 0..r {
4081                    for v in 0..r {
4082                        let a = gpu_out.hess[row * r * r + u * r + v];
4083                        let bb = gpu_out.hess[row * r * r + v * r + u];
4084                        assert_eq!(
4085                            a.to_bits(),
4086                            bb.to_bits(),
4087                            "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
4088                        );
4089                    }
4090                }
4091            }
4092        }
4093    }
4094
4095    #[test]
4096    pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
4097        // Guarantee the maintainer-facing parity reference comment survives
4098        // refactors of the NVRTC kernel source — the dispatcher wave that
4099        // wires this to bms_flex.rs cross-checks parity against the CPU
4100        // function named here.
4101        #[cfg(target_os = "linux")]
4102        assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
4103        #[cfg(target_os = "linux")]
4104        assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
4105    }
4106
4107    // ── Phase-3 HVP / diagonal CPU oracles + GPU parity tests ────────────────
4108
4109    /// CPU oracle for [`launch_bms_flex_row_hvp`]. Mirrors the device kernel
4110    /// element-for-element so the GPU parity test runs against the same algebra.
4111    pub(crate) fn cpu_oracle_bms_flex_row_hvp(
4112        row_hessians: &[f64],
4113        marginal_design: &[f64],
4114        logslope_design: &[f64],
4115        block: &BmsFlexBlockLayout,
4116        primary: &BmsFlexPrimaryLayout,
4117        n: usize,
4118        v: &[f64],
4119    ) -> Vec<f64> {
4120        let r = primary.r;
4121        let p_m = block.p_m;
4122        let p_g = block.p_g;
4123        assert_eq!(v.len(), block.p_total);
4124        assert_eq!(row_hessians.len(), n * r * r);
4125        assert_eq!(marginal_design.len(), n * p_m);
4126        assert_eq!(logslope_design.len(), n * p_g);
4127        let mut out = vec![0.0_f64; block.p_total];
4128        let mut row_dir = vec![0.0_f64; r];
4129        let mut action = vec![0.0_f64; r];
4130        for row in 0..n {
4131            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4132            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4133            let mut acc_q = 0.0_f64;
4134            for j in 0..p_m {
4135                acc_q += mrow[j] * v[j];
4136            }
4137            let mut acc_g = 0.0_f64;
4138            for j in 0..p_g {
4139                acc_g += grow[j] * v[p_m + j];
4140            }
4141            row_dir[0] = acc_q;
4142            row_dir[1] = acc_g;
4143            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4144                for (k, ii) in prange.clone().enumerate() {
4145                    row_dir[ii] = v[brange.start + k];
4146                }
4147            }
4148            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4149                for (k, ii) in prange.clone().enumerate() {
4150                    row_dir[ii] = v[brange.start + k];
4151                }
4152            }
4153            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4154            for u in 0..r {
4155                let mut acc = 0.0_f64;
4156                for v_idx in 0..r {
4157                    acc += h_slice[u * r + v_idx] * row_dir[v_idx];
4158                }
4159                action[u] = acc;
4160            }
4161            let a0 = action[0];
4162            for j in 0..p_m {
4163                out[j] += a0 * mrow[j];
4164            }
4165            let a1 = action[1];
4166            for j in 0..p_g {
4167                out[p_m + j] += a1 * grow[j];
4168            }
4169            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4170                for (k, ii) in prange.clone().enumerate() {
4171                    out[brange.start + k] += action[ii];
4172                }
4173            }
4174            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4175                for (k, ii) in prange.clone().enumerate() {
4176                    out[brange.start + k] += action[ii];
4177                }
4178            }
4179        }
4180        out
4181    }
4182
4183    pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
4184        row_hessians: &[f64],
4185        marginal_design: &[f64],
4186        logslope_design: &[f64],
4187        block: &BmsFlexBlockLayout,
4188        primary: &BmsFlexPrimaryLayout,
4189        n: usize,
4190    ) -> Vec<f64> {
4191        let r = primary.r;
4192        let p_m = block.p_m;
4193        let p_g = block.p_g;
4194        let mut out = vec![0.0_f64; block.p_total];
4195        for row in 0..n {
4196            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4197            let h00 = h_slice[0];
4198            let h11 = h_slice[r + 1];
4199            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4200            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4201            for j in 0..p_m {
4202                out[j] += h00 * mrow[j] * mrow[j];
4203            }
4204            for j in 0..p_g {
4205                out[p_m + j] += h11 * grow[j] * grow[j];
4206            }
4207            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4208                for (k, ii) in prange.clone().enumerate() {
4209                    out[brange.start + k] += h_slice[ii * r + ii];
4210                }
4211            }
4212            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4213                for (k, ii) in prange.clone().enumerate() {
4214                    out[brange.start + k] += h_slice[ii * r + ii];
4215                }
4216            }
4217        }
4218        out
4219    }
4220
4221    /// Hand-construct a small symmetric per-row Hessian + small designs and
4222    /// verify the CPU oracle satisfies the expected algebra. Platform-
4223    /// independent (runs on macOS / Linux without CUDA).
4224    #[test]
4225    pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
4226        let n = 4_usize;
4227        let r = 4_usize; // q, logslope, h(1), w(1)
4228        let p_m = 2_usize;
4229        let p_g = 2_usize;
4230        let p_h_dim = 1_usize;
4231        let p_w_dim = 1_usize;
4232        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4233        let block = BmsFlexBlockLayout {
4234            p_m,
4235            p_g,
4236            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4237            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4238            p_total,
4239        };
4240        let primary = BmsFlexPrimaryLayout {
4241            h: Some(2..3),
4242            w: Some(3..4),
4243            r,
4244        };
4245        // Symmetric per-row Hessian: H_row[u,v] = (row + 1) * (1 + u + 2v) symmetrised.
4246        let mut row_hessians = vec![0.0_f64; n * r * r];
4247        for row in 0..n {
4248            for u in 0..r {
4249                for v in u..r {
4250                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4251                    row_hessians[row * r * r + u * r + v] = val;
4252                    row_hessians[row * r * r + v * r + u] = val;
4253                }
4254            }
4255        }
4256        let mut marginal = vec![0.0_f64; n * p_m];
4257        for row in 0..n {
4258            for j in 0..p_m {
4259                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4260            }
4261        }
4262        let mut logslope = vec![0.0_f64; n * p_g];
4263        for row in 0..n {
4264            for j in 0..p_g {
4265                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4266            }
4267        }
4268        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4269        let out = cpu_oracle_bms_flex_row_hvp(
4270            &row_hessians,
4271            &marginal,
4272            &logslope,
4273            &block,
4274            &primary,
4275            n,
4276            &v,
4277        );
4278        // Hand check the first marginal slot: out[0] = Σ_row action[0]·mrow[0].
4279        let mut expect_out_0 = 0.0_f64;
4280        for row in 0..n {
4281            let mrow = &marginal[row * p_m..(row + 1) * p_m];
4282            let grow = &logslope[row * p_g..(row + 1) * p_g];
4283            let mut row_dir = vec![0.0_f64; r];
4284            row_dir[0] = mrow[0] * v[0] + mrow[1] * v[1];
4285            row_dir[1] = grow[0] * v[p_m] + grow[1] * v[p_m + 1];
4286            row_dir[2] = v[p_m + p_g];
4287            row_dir[3] = v[p_m + p_g + p_h_dim];
4288            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4289            let mut action0 = 0.0_f64;
4290            // h_slice is the row-major r×r Hessian for this row; we want
4291            // row 0, i.e. entries (0, vv) for vv in 0..r, which lives at
4292            // `vv` in the flat layout.
4293            for vv in 0..r {
4294                action0 += h_slice[vv] * row_dir[vv];
4295            }
4296            expect_out_0 += action0 * mrow[0];
4297        }
4298        assert!(
4299            (out[0] - expect_out_0).abs() < 1e-12,
4300            "cpu oracle HVP out[0] mismatch: {} vs hand-check {}",
4301            out[0],
4302            expect_out_0
4303        );
4304        assert!(out.iter().all(|x| x.is_finite()));
4305        assert_eq!(out.len(), p_total);
4306    }
4307
4308    /// Diagonal oracle equals the explicit per-row design² accumulator.
4309    #[test]
4310    pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
4311        let n = 3_usize;
4312        let r = 4_usize;
4313        let p_m = 2_usize;
4314        let p_g = 2_usize;
4315        let p_h_dim = 1_usize;
4316        let p_w_dim = 1_usize;
4317        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4318        let block = BmsFlexBlockLayout {
4319            p_m,
4320            p_g,
4321            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4322            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4323            p_total,
4324        };
4325        let primary = BmsFlexPrimaryLayout {
4326            h: Some(2..3),
4327            w: Some(3..4),
4328            r,
4329        };
4330        let mut row_hessians = vec![0.0_f64; n * r * r];
4331        for row in 0..n {
4332            for u in 0..r {
4333                row_hessians[row * r * r + u * r + u] = 1.0 + (row as f64) + (u as f64) * 0.5;
4334            }
4335        }
4336        let mut marginal = vec![0.0_f64; n * p_m];
4337        let mut logslope = vec![0.0_f64; n * p_g];
4338        for row in 0..n {
4339            for j in 0..p_m {
4340                marginal[row * p_m + j] = 0.2 + (row as f64) * 0.3 + (j as f64) * 0.1;
4341            }
4342            for j in 0..p_g {
4343                logslope[row * p_g + j] = -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2;
4344            }
4345        }
4346        let out = cpu_oracle_bms_flex_row_diagonal(
4347            &row_hessians,
4348            &marginal,
4349            &logslope,
4350            &block,
4351            &primary,
4352            n,
4353        );
4354        // Hand check: out[0] = Σ_row H[row,0,0] · marginal[row,0]^2.
4355        let mut expect = 0.0_f64;
4356        for row in 0..n {
4357            let h00 = row_hessians[row * r * r];
4358            expect += h00 * marginal[row * p_m].powi(2);
4359        }
4360        assert!(
4361            (out[0] - expect).abs() < 1e-12,
4362            "out[0] {} vs {}",
4363            out[0],
4364            expect
4365        );
4366        // h slot = sum of H[row, 2, 2] across rows.
4367        let mut expect_h = 0.0_f64;
4368        for row in 0..n {
4369            expect_h += row_hessians[row * r * r + 2 * r + 2];
4370        }
4371        let h_slot = p_m + p_g;
4372        assert!(
4373            (out[h_slot] - expect_h).abs() < 1e-12,
4374            "h slot {} vs {}",
4375            out[h_slot],
4376            expect_h
4377        );
4378    }
4379
4380    /// GPU↔CPU parity for the HVP and diagonal kernels. Skips on non-Linux /
4381    /// no-CUDA hosts. Hand-constructs a small `DeviceResidentRowHess` by
4382    /// allocating the device slices directly, uploading the same arrays the
4383    /// CPU oracle consumes, then dispatching the device kernels.
4384    #[test]
4385    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
4386        #[cfg(not(target_os = "linux"))]
4387        {
4388            eprintln!(
4389                "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
4390                 (CPU oracle exercised by sibling tests)"
4391            );
4392        }
4393        #[cfg(target_os = "linux")]
4394        {
4395            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4396                eprintln!(
4397                    "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
4398                     parity"
4399                );
4400                return;
4401            };
4402            let n = 4_usize;
4403            let r = 4_usize;
4404            let p_m = 2_usize;
4405            let p_g = 2_usize;
4406            let p_h_dim = 1_usize;
4407            let p_w_dim = 1_usize;
4408            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4409            let block = BmsFlexBlockLayout {
4410                p_m,
4411                p_g,
4412                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4413                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4414                p_total,
4415            };
4416            let primary = BmsFlexPrimaryLayout {
4417                h: Some(2..3),
4418                w: Some(3..4),
4419                r,
4420            };
4421            let mut row_hessians = vec![0.0_f64; n * r * r];
4422            for row in 0..n {
4423                for u in 0..r {
4424                    for v in u..r {
4425                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4426                        row_hessians[row * r * r + u * r + v] = val;
4427                        row_hessians[row * r * r + v * r + u] = val;
4428                    }
4429                }
4430            }
4431            let mut marginal = vec![0.0_f64; n * p_m];
4432            for row in 0..n {
4433                for j in 0..p_m {
4434                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4435                }
4436            }
4437            let mut logslope = vec![0.0_f64; n * p_g];
4438            for row in 0..n {
4439                for j in 0..p_g {
4440                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4441                }
4442            }
4443            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4444            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4445                &row_hessians,
4446                &marginal,
4447                &logslope,
4448                &block,
4449                &primary,
4450                n,
4451                &v,
4452            );
4453            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4454                &row_hessians,
4455                &marginal,
4456                &logslope,
4457                &block,
4458                &primary,
4459                n,
4460            );
4461
4462            // Allocate a DeviceResidentRowHess by hand using the HVP backend's
4463            // stream + module so we don't need to drive the full BMS row kernel.
4464            // Past the GpuRuntime::global() Some-gate above: a probe/upload failure
4465            // here is a real device fault on a CUDA host, not a no-CUDA skip. Fail
4466            // loud (the device-PCG skip-pass class, eee12f6b2) — the old arms
4467            // returned and the test passed while exercising nothing.
4468            let backend = HvpKernelBackend::probe()
4469                .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
4470            let stream = backend.stream.clone();
4471            let d_h = stream
4472                .clone_htod(&row_hessians)
4473                .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
4474            let d_m = stream
4475                .clone_htod(&marginal)
4476                .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
4477            let d_g = stream
4478                .clone_htod(&logslope)
4479                .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
4480            let storage = DeviceResidentRowHess {
4481                hess: d_h,
4482                marginal_design: d_m,
4483                logslope_design: d_g,
4484                n,
4485                r,
4486                block: block.clone(),
4487                primary: primary.clone(),
4488
4489                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4490            };
4491            let gpu_hvp =
4492                launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
4493            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4494                .expect("diagonal kernel must launch on CUDA host");
4495            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4496            assert_eq!(gpu_diag.len(), cpu_diag.len());
4497            for i in 0..p_total {
4498                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4499                assert!(
4500                    diff <= 1e-10,
4501                    "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4502                    cpu_hvp[i],
4503                    gpu_hvp[i]
4504                );
4505                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4506                assert!(
4507                    ddiff <= 1e-10,
4508                    "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4509                    cpu_diag[i],
4510                    gpu_diag[i]
4511                );
4512            }
4513        }
4514    }
4515
4516    #[test]
4517    pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
4518        let n = 195_000_usize;
4519        let r = 20_usize;
4520        let p_total = 44_usize;
4521        let rhs_count = 4_usize;
4522        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4523            .expect("large-scale multi-RHS scratch budget");
4524        let per_rhs_full_row_cache =
4525            (n * r * r * std::mem::size_of::<f64>()) as u64 * rhs_count as u64;
4526        assert!(
4527            scratch < per_rhs_full_row_cache / 100,
4528            "multi-RHS scratch must tile by row chunks instead of materializing \
4529             a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
4530        );
4531        assert!(
4532            bms_flex_row_hvp_multi_scratch_bytes_for_shape(
4533                n,
4534                p_total,
4535                BMS_FLEX_ROW_HVP_MAX_RHS + 1
4536            )
4537            .is_err(),
4538            "multi-RHS launch must reject unbounded RHS counts"
4539        );
4540    }
4541
4542    #[test]
4543    pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
4544        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4545            eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
4546            return;
4547        };
4548        let n = 5_usize;
4549        let r = 4_usize;
4550        let p_m = 2_usize;
4551        let p_g = 2_usize;
4552        let p_h_dim = 1_usize;
4553        let p_w_dim = 1_usize;
4554        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4555        let rhs_count = 3_usize;
4556        let block = BmsFlexBlockLayout {
4557            p_m,
4558            p_g,
4559            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4560            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4561            p_total,
4562        };
4563        let primary = BmsFlexPrimaryLayout {
4564            h: Some(2..3),
4565            w: Some(3..4),
4566            r,
4567        };
4568        let mut row_hessians = vec![0.0_f64; n * r * r];
4569        for row in 0..n {
4570            for u in 0..r {
4571                for v in u..r {
4572                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4573                    row_hessians[row * r * r + u * r + v] = val;
4574                    row_hessians[row * r * r + v * r + u] = val;
4575                }
4576            }
4577        }
4578        let mut marginal = vec![0.0_f64; n * p_m];
4579        let mut logslope = vec![0.0_f64; n * p_g];
4580        for row in 0..n {
4581            for j in 0..p_m {
4582                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4583            }
4584            for j in 0..p_g {
4585                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4586            }
4587        }
4588        let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
4589        for rhs in 0..rhs_count {
4590            for j in 0..p_total {
4591                let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
4592                v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
4593            }
4594        }
4595
4596        // Past the GpuRuntime::global() Some-gate: a probe/upload failure here is a
4597        // real device fault on a CUDA host, not a no-CUDA skip — fail loud
4598        // (device-PCG skip-pass class, eee12f6b2).
4599        let backend = HvpKernelBackend::probe()
4600            .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
4601        let stream = backend.stream.clone();
4602        let d_h = stream
4603            .clone_htod(&row_hessians)
4604            .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
4605        let d_m = stream
4606            .clone_htod(&marginal)
4607            .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
4608        let d_g = stream
4609            .clone_htod(&logslope)
4610            .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
4611        let storage = DeviceResidentRowHess {
4612            hess: d_h,
4613            marginal_design: d_m,
4614            logslope_design: d_g,
4615            n,
4616            r,
4617            block: block.clone(),
4618            primary: primary.clone(),
4619
4620            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4621        };
4622        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4623            .expect("storage scratch budget");
4624        assert!(
4625            scratch < storage.bytes,
4626            "multi-RHS scratch should stay below resident cache bytes"
4627        );
4628        let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
4629            .expect("multi-RHS HVP kernel must launch on CUDA host");
4630        assert_eq!(gpu.len(), rhs_count * p_total);
4631        for rhs in 0..rhs_count {
4632            let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
4633            let cpu = cpu_oracle_bms_flex_row_hvp(
4634                &row_hessians,
4635                &marginal,
4636                &logslope,
4637                &block,
4638                &primary,
4639                n,
4640                v,
4641            );
4642            let single = launch_bms_flex_row_hvp(&storage, v)
4643                .expect("single-RHS HVP kernel must launch on CUDA host");
4644            for j in 0..p_total {
4645                let got = gpu[rhs * p_total + j];
4646                let diff = (cpu[j] - got).abs();
4647                assert!(
4648                    diff <= 1e-10,
4649                    "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
4650                    cpu[j],
4651                    got
4652                );
4653                assert_eq!(
4654                    got, single[j],
4655                    "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
4656                );
4657            }
4658        }
4659    }
4660
4661    /// Parity for the third launch mode — device-output HVP
4662    /// ([`launch_bms_flex_row_hvp_into_device`]) — which the
4663    /// `run_bms_flex_row_partial_reduce` unification routes through the same
4664    /// partial+reduce engine as the host-returning `_hvp` / `_diagonal`
4665    /// adapters. Confirms that keeping the result on-stream (no internal sync /
4666    /// DtoH) reaches bit-identical output to both the CPU oracle and the
4667    /// host-out adapter, so the engine's mode/output split is faithful.
4668    ///
4669    /// Skips cleanly on non-Linux / no-CUDA hosts using the convention shared
4670    /// with the sibling parity tests.
4671    #[test]
4672    pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
4673        #[cfg(not(target_os = "linux"))]
4674        {
4675            eprintln!(
4676                "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
4677                 CUDA parity (CPU oracle exercised by sibling tests)"
4678            );
4679        }
4680        #[cfg(target_os = "linux")]
4681        {
4682            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4683                eprintln!(
4684                    "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
4685                     skipping device parity"
4686                );
4687                return;
4688            };
4689            let n = 4_usize;
4690            let r = 4_usize;
4691            let p_m = 2_usize;
4692            let p_g = 2_usize;
4693            let p_h_dim = 1_usize;
4694            let p_w_dim = 1_usize;
4695            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4696            let block = BmsFlexBlockLayout {
4697                p_m,
4698                p_g,
4699                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4700                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4701                p_total,
4702            };
4703            let primary = BmsFlexPrimaryLayout {
4704                h: Some(2..3),
4705                w: Some(3..4),
4706                r,
4707            };
4708            let mut row_hessians = vec![0.0_f64; n * r * r];
4709            for row in 0..n {
4710                for u in 0..r {
4711                    for v in u..r {
4712                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4713                        row_hessians[row * r * r + u * r + v] = val;
4714                        row_hessians[row * r * r + v * r + u] = val;
4715                    }
4716                }
4717            }
4718            let mut marginal = vec![0.0_f64; n * p_m];
4719            for row in 0..n {
4720                for j in 0..p_m {
4721                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4722                }
4723            }
4724            let mut logslope = vec![0.0_f64; n * p_g];
4725            for row in 0..n {
4726                for j in 0..p_g {
4727                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4728                }
4729            }
4730            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4731            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4732                &row_hessians,
4733                &marginal,
4734                &logslope,
4735                &block,
4736                &primary,
4737                n,
4738                &v,
4739            );
4740
4741            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
4742            // real device faults on a CUDA host — fail loud (device-PCG class).
4743            let backend = HvpKernelBackend::probe().expect(
4744                "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
4745            );
4746            let stream = backend.stream.clone();
4747            let d_h = stream
4748                .clone_htod(&row_hessians)
4749                .expect("[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host");
4750            let d_m = stream.clone_htod(&marginal).expect(
4751                "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
4752            );
4753            let d_g = stream.clone_htod(&logslope).expect(
4754                "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
4755            );
4756            let storage = DeviceResidentRowHess {
4757                hess: d_h,
4758                marginal_design: d_m,
4759                logslope_design: d_g,
4760                n,
4761                r,
4762                block: block.clone(),
4763                primary: primary.clone(),
4764
4765                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4766            };
4767
4768            // Host-out adapter (allocates its own d_out, syncs + downloads).
4769            let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
4770                .expect("host-out HVP kernel must launch on CUDA host");
4771
4772            // Device-out adapter: caller owns d_v + d_out; the engine performs
4773            // no sync / DtoH, so we synchronize + download here.
4774            let d_v = stream
4775                .clone_htod(&v)
4776                .expect("upload direction for device-out HVP");
4777            let mut d_out = stream
4778                .alloc_zeros::<f64>(p_total)
4779                .expect("alloc device-out HVP output");
4780            launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
4781                .expect("device-out HVP kernel must launch on CUDA host");
4782            stream
4783                .synchronize()
4784                .expect("synchronize after device-out HVP");
4785            let device_out_hvp = stream
4786                .clone_dtoh(&d_out)
4787                .expect("download device-out HVP output");
4788
4789            assert_eq!(device_out_hvp.len(), cpu_hvp.len());
4790            assert_eq!(device_out_hvp.len(), host_out_hvp.len());
4791            for i in 0..p_total {
4792                let diff = (cpu_hvp[i] - device_out_hvp[i]).abs();
4793                assert!(
4794                    diff <= 1e-10,
4795                    "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
4796                    cpu_hvp[i],
4797                    device_out_hvp[i]
4798                );
4799                // Both adapters share the engine; the only difference is the
4800                // copy-back path, so they must be bit-identical.
4801                let host_diff = (host_out_hvp[i] - device_out_hvp[i]).abs();
4802                assert!(
4803                    host_diff == 0.0,
4804                    "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
4805                    host_out_hvp[i],
4806                    device_out_hvp[i]
4807                );
4808            }
4809        }
4810    }
4811
4812    /// Block 9 Phase 2 parity gate at the shape specified by the
4813    /// charter task: `n = 64`, `r = 20`, `p_total = 44`. Splits
4814    /// `p_total` as `p_m = 14`, `p_g = 12`, `p_h = 10`, `p_w = 8` so
4815    /// `r = 2 + p_h + p_w = 20` and every primary block participates
4816    /// in both the device pullback and the reduce pass. Tolerance is
4817    /// `|Δ| ≤ 1e-8` per the task description (looser than the 1e-10
4818    /// hand-fixture parity, since accumulation order across HVP CTAs
4819    /// differs from the CPU oracle's row-major sum even with the
4820    /// deterministic reduction policy).
4821    ///
4822    /// Skips cleanly on non-Linux and no-CUDA hosts using the same
4823    /// convention as the hand-fixture parity above.
4824    #[test]
4825    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
4826        #[cfg(not(target_os = "linux"))]
4827        {
4828            eprintln!(
4829                "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
4830                 skipping CUDA parity"
4831            );
4832        }
4833        #[cfg(target_os = "linux")]
4834        {
4835            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4836                eprintln!(
4837                    "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
4838                     skipping device parity"
4839                );
4840                return;
4841            };
4842            let n = 64_usize;
4843            let p_m = 14_usize;
4844            let p_g = 12_usize;
4845            let p_h_dim = 10_usize;
4846            let p_w_dim = 8_usize;
4847            let r = 2 + p_h_dim + p_w_dim;
4848            assert_eq!(r, 20);
4849            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4850            assert_eq!(p_total, 44);
4851            let block = BmsFlexBlockLayout {
4852                p_m,
4853                p_g,
4854                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4855                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4856                p_total,
4857            };
4858            let primary = BmsFlexPrimaryLayout {
4859                h: Some(2..2 + p_h_dim),
4860                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4861                r,
4862            };
4863
4864            // Deterministic symmetric per-row Hessians + designs +
4865            // direction. Same scrambling family as
4866            // `row_hessian_ops::tests::make_fixture` so any regression
4867            // surfaces consistently across the host-pinned and
4868            // device-resident parity tests.
4869            let mut row_hessians = vec![0.0_f64; n * r * r];
4870            for row in 0..n {
4871                let base = row * r * r;
4872                for u in 0..r {
4873                    for v in 0..r {
4874                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
4875                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
4876                        row_hessians[base + u * r + v] = a;
4877                    }
4878                }
4879                for u in 0..r {
4880                    for v in (u + 1)..r {
4881                        let upper = row_hessians[base + u * r + v];
4882                        let lower = row_hessians[base + v * r + u];
4883                        let sym = 0.5 * (upper + lower);
4884                        row_hessians[base + u * r + v] = sym;
4885                        row_hessians[base + v * r + u] = sym;
4886                    }
4887                    row_hessians[base + u * r + u] += r as f64;
4888                }
4889            }
4890            let mut marginal = vec![0.0_f64; n * p_m];
4891            for row in 0..n {
4892                for j in 0..p_m {
4893                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
4894                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
4895                }
4896            }
4897            let mut logslope = vec![0.0_f64; n * p_g];
4898            for row in 0..n {
4899                for j in 0..p_g {
4900                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
4901                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
4902                }
4903            }
4904            let v: Vec<f64> = (0..p_total)
4905                .map(|i| {
4906                    let seed = (i as f64) * 0.157 + 0.6;
4907                    seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
4908                })
4909                .collect();
4910
4911            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4912                &row_hessians,
4913                &marginal,
4914                &logslope,
4915                &block,
4916                &primary,
4917                n,
4918                &v,
4919            );
4920            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4921                &row_hessians,
4922                &marginal,
4923                &logslope,
4924                &block,
4925                &primary,
4926                n,
4927            );
4928
4929            let backend = match HvpKernelBackend::probe() {
4930                Ok(b) => b,
4931                Err(err) => {
4932                    eprintln!(
4933                        "[bms_flex_row hvp parity n64_r20_p44] backend probe \
4934                         failed: {err}"
4935                    );
4936                    return;
4937                }
4938            };
4939            let stream = backend.stream.clone();
4940            let d_h = match stream.clone_htod(&row_hessians) {
4941                Ok(s) => s,
4942                Err(err) => {
4943                    eprintln!(
4944                        "[bms_flex_row hvp parity n64_r20_p44] upload h \
4945                         failed: {err}"
4946                    );
4947                    return;
4948                }
4949            };
4950            let d_m = match stream.clone_htod(&marginal) {
4951                Ok(s) => s,
4952                Err(err) => {
4953                    eprintln!(
4954                        "[bms_flex_row hvp parity n64_r20_p44] upload marg \
4955                         failed: {err}"
4956                    );
4957                    return;
4958                }
4959            };
4960            let d_g = match stream.clone_htod(&logslope) {
4961                Ok(s) => s,
4962                Err(err) => {
4963                    eprintln!(
4964                        "[bms_flex_row hvp parity n64_r20_p44] upload logslope \
4965                         failed: {err}"
4966                    );
4967                    return;
4968                }
4969            };
4970            let storage = DeviceResidentRowHess {
4971                hess: d_h,
4972                marginal_design: d_m,
4973                logslope_design: d_g,
4974                n,
4975                r,
4976                block: block.clone(),
4977                primary: primary.clone(),
4978
4979                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4980            };
4981            let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
4982                .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
4983            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4984                .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
4985            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4986            assert_eq!(gpu_diag.len(), cpu_diag.len());
4987            for i in 0..p_total {
4988                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4989                assert!(
4990                    diff <= 1e-8,
4991                    "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4992                    cpu_hvp[i],
4993                    gpu_hvp[i]
4994                );
4995                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4996                assert!(
4997                    ddiff <= 1e-8,
4998                    "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4999                    cpu_diag[i],
5000                    gpu_diag[i]
5001                );
5002            }
5003        }
5004    }
5005
5006    /// Block 9 Phase 6 — small-fixture parity for the dense-block kernel
5007    /// against the host-side P_i pullback oracle.
5008    /// Verifies bit-equality (modulo reduction-order f.p. noise) between
5009    /// the device-resident dense build and the host accumulator over the
5010    /// same per-row Hessian + designs + P_i pullback.
5011    #[test]
5012    pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
5013        #[cfg(not(target_os = "linux"))]
5014        {
5015            eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
5016        }
5017        #[cfg(target_os = "linux")]
5018        {
5019            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5020                eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
5021                return;
5022            };
5023            // Small fixture: n=24, r=8 (2 + 3 + 3), p_total=18 (4+4+3+3).
5024            // Keeps the CPU pullback fast while still exercising every
5025            // primary slot (q, g, h, w).
5026            let n = 24_usize;
5027            let p_m = 4_usize;
5028            let p_g = 4_usize;
5029            let p_h_dim = 3_usize;
5030            let p_w_dim = 3_usize;
5031            let r = 2 + p_h_dim + p_w_dim;
5032            let p_total = p_m + p_g + p_h_dim + p_w_dim;
5033            let block = BmsFlexBlockLayout {
5034                p_m,
5035                p_g,
5036                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5037                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5038                p_total,
5039            };
5040            let primary = BmsFlexPrimaryLayout {
5041                h: Some(2..2 + p_h_dim),
5042                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5043                r,
5044            };
5045
5046            let mut row_hessians = vec![0.0_f64; n * r * r];
5047            for row in 0..n {
5048                let base = row * r * r;
5049                for u in 0..r {
5050                    for v in 0..r {
5051                        let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (v as f64) * 0.47;
5052                        let a = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
5053                        row_hessians[base + u * r + v] = a;
5054                    }
5055                }
5056                for u in 0..r {
5057                    for v in (u + 1)..r {
5058                        let upper = row_hessians[base + u * r + v];
5059                        let lower = row_hessians[base + v * r + u];
5060                        let sym = 0.5 * (upper + lower);
5061                        row_hessians[base + u * r + v] = sym;
5062                        row_hessians[base + v * r + u] = sym;
5063                    }
5064                    row_hessians[base + u * r + u] += r as f64;
5065                }
5066            }
5067            let mut marginal = vec![0.0_f64; n * p_m];
5068            for row in 0..n {
5069                for j in 0..p_m {
5070                    let seed = (row as f64) * 0.083 + (j as f64) * 0.171 + 0.31;
5071                    marginal[row * p_m + j] = seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25;
5072                }
5073            }
5074            let mut logslope = vec![0.0_f64; n * p_g];
5075            for row in 0..n {
5076                for j in 0..p_g {
5077                    let seed = (row as f64) * 0.097 + (j as f64) * 0.143 - 0.15;
5078                    logslope[row * p_g + j] = seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2;
5079                }
5080            }
5081
5082            // CPU oracle — same pullback math the device kernel mirrors.
5083            let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
5084            let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
5085            let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
5086            let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
5087            let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
5088            let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
5089            let mut h_cpu = vec![0.0_f64; p_total * p_total];
5090            for row in 0..n {
5091                let mrow = &marginal[row * p_m..(row + 1) * p_m];
5092                let grow = &logslope[row * p_g..(row + 1) * p_g];
5093                let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
5094                // Build per-row phi (r length-p_total vectors).
5095                let mut phi = vec![vec![0.0_f64; p_total]; r];
5096                for k in 0..p_m {
5097                    phi[0][k] = mrow[k];
5098                }
5099                for k in 0..p_g {
5100                    phi[1][p_m + k] = grow[k];
5101                }
5102                for k in 0..h_block_len {
5103                    phi[h_primary_start + k][h_block_start + k] = 1.0;
5104                }
5105                for k in 0..w_block_len {
5106                    phi[w_primary_start + k][w_block_start + k] = 1.0;
5107                }
5108                for u in 0..r {
5109                    for v in 0..r {
5110                        let huv = hrow[u * r + v];
5111                        if huv == 0.0 {
5112                            continue;
5113                        }
5114                        for m in 0..p_total {
5115                            let pm = phi[u][m];
5116                            if pm == 0.0 {
5117                                continue;
5118                            }
5119                            let scaled = huv * pm;
5120                            for nn in 0..p_total {
5121                                h_cpu[m * p_total + nn] += scaled * phi[v][nn];
5122                            }
5123                        }
5124                    }
5125                }
5126            }
5127
5128            // Build a transient device-resident storage and launch the
5129            // dense-block kernel.
5130            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
5131            // real device faults on a CUDA host — fail loud (device-PCG class).
5132            let backend = HvpKernelBackend::probe().expect(
5133                "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
5134            );
5135            let stream = backend.stream.clone();
5136            let d_h = stream
5137                .clone_htod(&row_hessians)
5138                .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host");
5139            let d_m = stream
5140                .clone_htod(&marginal)
5141                .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host");
5142            let d_g = stream.clone_htod(&logslope).expect(
5143                "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
5144            );
5145            let storage = DeviceResidentRowHess {
5146                hess: d_h,
5147                marginal_design: d_m,
5148                logslope_design: d_g,
5149                n,
5150                r,
5151                block: block.clone(),
5152                primary: primary.clone(),
5153
5154                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5155            };
5156            let h_gpu = launch_bms_flex_row_dense_block(&storage)
5157                .expect("dense_block kernel must launch on CUDA host");
5158            assert_eq!(h_gpu.len(), p_total * p_total);
5159
5160            // Compare entry-by-entry with a tolerance that absorbs
5161            // reduction-order f.p. noise from the CTA chunk sum.
5162            let mut max_abs = 0.0_f64;
5163            for i in 0..p_total {
5164                for j in 0..p_total {
5165                    let a = h_cpu[i * p_total + j];
5166                    let b = h_gpu[i * p_total + j];
5167                    let diff = (a - b).abs();
5168                    if diff > max_abs {
5169                        max_abs = diff;
5170                    }
5171                    assert!(
5172                        diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
5173                        "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
5174                    );
5175                }
5176            }
5177            eprintln!(
5178                "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
5179            );
5180        }
5181    }
5182
5183    /// Block 9 final hill-climb gate — GPU HVP must be at least 5× faster
5184    /// than a Rayon-parallel CPU HVP at large-scale shape (n=195_000, r=20,
5185    /// p_total=44). This is the charter pass/fail metric for whether the
5186    /// device-resident row-Hessian path is a real perf win for the
5187    /// production marginal-slope fit.
5188    ///
5189    /// Methodology:
5190    ///   * Build the same deterministic fixture as the parity tests.
5191    ///   * GPU: median of `iters` `launch_bms_flex_row_hvp` wall-times
5192    ///     after `warmup` warm-up launches (kernel compile + L2 prime).
5193    ///   * CPU: median of `iters` `cpu_oracle_bms_flex_row_hvp` wall-times,
5194    ///     parallelised over rows via Rayon — this mirrors the actual
5195    ///     production CPU path in
5196    ///     `exact_newton_joint_hessian_matvec_from_cache` (which uses
5197    ///     `ROW_CHUNK_SIZE` chunked `into_par_iter()` for the same
5198    ///     contraction).
5199    ///   * Ratio = cpu_median / gpu_median; assert ratio >= 5.
5200    ///
5201    /// Skips on non-Linux / no-CUDA hosts.
5202    #[test]
5203    pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
5204        #[cfg(not(target_os = "linux"))]
5205        {
5206            eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
5207        }
5208        #[cfg(target_os = "linux")]
5209        {
5210            use rayon::prelude::*;
5211
5212            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5213                eprintln!(
5214                    "[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate"
5215                );
5216                return;
5217            };
5218            let n = 195_000_usize;
5219            let p_m = 14_usize;
5220            let p_g = 12_usize;
5221            let p_h_dim = 10_usize;
5222            let p_w_dim = 8_usize;
5223            let r = 2 + p_h_dim + p_w_dim;
5224            let p_total = p_m + p_g + p_h_dim + p_w_dim;
5225            let block = BmsFlexBlockLayout {
5226                p_m,
5227                p_g,
5228                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5229                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5230                p_total,
5231            };
5232            let primary = BmsFlexPrimaryLayout {
5233                h: Some(2..2 + p_h_dim),
5234                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5235                r,
5236            };
5237
5238            // Same deterministic fixture as the Phase 4 large-scale benchmark.
5239            let mut row_hessians = vec![0.0_f64; n * r * r];
5240            for row in 0..n {
5241                let base = row * r * r;
5242                for u in 0..r {
5243                    for vv in 0..r {
5244                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
5245                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
5246                        row_hessians[base + u * r + vv] = a;
5247                    }
5248                }
5249                for u in 0..r {
5250                    for vv in (u + 1)..r {
5251                        let upper = row_hessians[base + u * r + vv];
5252                        let lower = row_hessians[base + vv * r + u];
5253                        let sym = 0.5 * (upper + lower);
5254                        row_hessians[base + u * r + vv] = sym;
5255                        row_hessians[base + vv * r + u] = sym;
5256                    }
5257                    row_hessians[base + u * r + u] += r as f64;
5258                }
5259            }
5260            let mut marginal = vec![0.0_f64; n * p_m];
5261            for row in 0..n {
5262                for j in 0..p_m {
5263                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
5264                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
5265                }
5266            }
5267            let mut logslope = vec![0.0_f64; n * p_g];
5268            for row in 0..n {
5269                for j in 0..p_g {
5270                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
5271                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
5272                }
5273            }
5274            let v: Vec<f64> = (0..p_total)
5275                .map(|i| {
5276                    let seed = (i as f64) * 0.157 + 0.6;
5277                    seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
5278                })
5279                .collect();
5280
5281            // ── GPU side: upload once, time HVP launches ─────────────
5282            let backend = match HvpKernelBackend::probe() {
5283                Ok(b) => b,
5284                Err(err) => {
5285                    eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
5286                    return;
5287                }
5288            };
5289            let stream = backend.stream.clone();
5290            let d_h = match stream.clone_htod(&row_hessians) {
5291                Ok(s) => s,
5292                Err(err) => {
5293                    eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
5294                    return;
5295                }
5296            };
5297            let d_m = match stream.clone_htod(&marginal) {
5298                Ok(s) => s,
5299                Err(err) => {
5300                    eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
5301                    return;
5302                }
5303            };
5304            let d_g = match stream.clone_htod(&logslope) {
5305                Ok(s) => s,
5306                Err(err) => {
5307                    eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
5308                    return;
5309                }
5310            };
5311            let storage = DeviceResidentRowHess {
5312                hess: d_h,
5313                marginal_design: d_m,
5314                logslope_design: d_g,
5315                n,
5316                r,
5317                block: block.clone(),
5318                primary: primary.clone(),
5319
5320                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5321            };
5322            let warmup: usize = 3;
5323            let iters: usize = 15;
5324            for _ in 0..warmup {
5325                let out =
5326                    launch_bms_flex_row_hvp(&storage, &v).expect("warmup GPU HVP must launch");
5327                assert_eq!(out.len(), p_total);
5328            }
5329            let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5330            for _ in 0..iters {
5331                let t0 = std::time::Instant::now();
5332                let out = launch_bms_flex_row_hvp(&storage, &v).expect("GPU HVP must launch");
5333                gpu_us.push(t0.elapsed().as_micros());
5334                assert_eq!(out.len(), p_total);
5335            }
5336            gpu_us.sort_unstable();
5337            let gpu_median = gpu_us[iters / 2];
5338
5339            // ── CPU side: chunked Rayon HVP over rows, mirroring the
5340            //    production `exact_newton_joint_hessian_matvec_from_cache`
5341            //    parallelisation pattern (ROW_CHUNK_SIZE-row chunks,
5342            //    try_fold + try_reduce). The per-chunk worker calls the
5343            //    single-threaded oracle on its row slice.
5344            const CHUNK_ROWS: usize = 4096;
5345            let cpu_hvp_parallel = || -> Vec<f64> {
5346                let nchunks = n.div_ceil(CHUNK_ROWS);
5347                (0..nchunks)
5348                    .into_par_iter()
5349                    .fold(
5350                        || vec![0.0_f64; p_total],
5351                        |mut acc, ci| {
5352                            let lo = ci * CHUNK_ROWS;
5353                            let hi = (lo + CHUNK_ROWS).min(n);
5354                            let m = hi - lo;
5355                            let partial = cpu_oracle_bms_flex_row_hvp(
5356                                &row_hessians[lo * r * r..hi * r * r],
5357                                &marginal[lo * p_m..hi * p_m],
5358                                &logslope[lo * p_g..hi * p_g],
5359                                &block,
5360                                &primary,
5361                                m,
5362                                &v,
5363                            );
5364                            for (a, &p) in acc.iter_mut().zip(partial.iter()) {
5365                                *a += p;
5366                            }
5367                            acc
5368                        },
5369                    )
5370                    .reduce(
5371                        || vec![0.0_f64; p_total],
5372                        |mut a, b| {
5373                            for (ax, bx) in a.iter_mut().zip(b.iter()) {
5374                                *ax += *bx;
5375                            }
5376                            a
5377                        },
5378                    )
5379            };
5380            // Warmup once to populate L3 / steady-state Rayon thread pool.
5381            let warm = cpu_hvp_parallel();
5382            assert_eq!(warm.len(), p_total);
5383            let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5384            for _ in 0..iters {
5385                let t0 = std::time::Instant::now();
5386                let out = cpu_hvp_parallel();
5387                cpu_us.push(t0.elapsed().as_micros());
5388                assert_eq!(out.len(), p_total);
5389            }
5390            cpu_us.sort_unstable();
5391            let cpu_median = cpu_us[iters / 2];
5392
5393            let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5394            eprintln!(
5395                "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
5396                 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5397                 speedup={speedup:.2}× (charter target ≥ 5×)"
5398            );
5399            assert!(
5400                speedup >= 5.0,
5401                "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
5402                 need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
5403                 gpu_median={gpu_median}us). Hill-climb the kernel until met or \
5404                 prove the kernel is at hardware roofline."
5405            );
5406        }
5407    }
5408
5409    /// Companion to the HVP hill-climb: GPU dense-block build must be at
5410    /// least 10× faster than a Rayon-parallel CPU dense build at large-scale
5411    /// shape. The dense build is `O(n * r² * p_total)` work for both
5412    /// paths so the ratio is well-defined.
5413    #[test]
5414    pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
5415        #[cfg(not(target_os = "linux"))]
5416        {
5417            eprintln!(
5418                "[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate"
5419            );
5420        }
5421        #[cfg(target_os = "linux")]
5422        {
5423            use rayon::prelude::*;
5424
5425            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5426                eprintln!(
5427                    "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
5428                );
5429                return;
5430            };
5431            let n = 195_000_usize;
5432            let p_m = 14_usize;
5433            let p_g = 12_usize;
5434            let p_h_dim = 10_usize;
5435            let p_w_dim = 8_usize;
5436            let r = 2 + p_h_dim + p_w_dim;
5437            let p_total = p_m + p_g + p_h_dim + p_w_dim;
5438            let block = BmsFlexBlockLayout {
5439                p_m,
5440                p_g,
5441                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5442                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5443                p_total,
5444            };
5445            let primary = BmsFlexPrimaryLayout {
5446                h: Some(2..2 + p_h_dim),
5447                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5448                r,
5449            };
5450
5451            // Reuse the same large-scale fixture recipe.
5452            let mut row_hessians = vec![0.0_f64; n * r * r];
5453            for row in 0..n {
5454                let base = row * r * r;
5455                for u in 0..r {
5456                    for vv in 0..r {
5457                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
5458                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
5459                        row_hessians[base + u * r + vv] = a;
5460                    }
5461                }
5462                for u in 0..r {
5463                    for vv in (u + 1)..r {
5464                        let upper = row_hessians[base + u * r + vv];
5465                        let lower = row_hessians[base + vv * r + u];
5466                        let sym = 0.5 * (upper + lower);
5467                        row_hessians[base + u * r + vv] = sym;
5468                        row_hessians[base + vv * r + u] = sym;
5469                    }
5470                    row_hessians[base + u * r + u] += r as f64;
5471                }
5472            }
5473            let mut marginal = vec![0.0_f64; n * p_m];
5474            for row in 0..n {
5475                for j in 0..p_m {
5476                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
5477                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
5478                }
5479            }
5480            let mut logslope = vec![0.0_f64; n * p_g];
5481            for row in 0..n {
5482                for j in 0..p_g {
5483                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
5484                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
5485                }
5486            }
5487
5488            // GPU dense_block kernel rejects p_total > DENSE_BLOCK_MAX_P
5489            // (72 at V100 48 KiB/block). LargeScale's p_total = 44 fits.
5490            if p_total > DENSE_BLOCK_MAX_P {
5491                eprintln!(
5492                    "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
5493                );
5494                return;
5495            }
5496            let backend = match HvpKernelBackend::probe() {
5497                Ok(b) => b,
5498                Err(err) => {
5499                    eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
5500                    return;
5501                }
5502            };
5503            let stream = backend.stream.clone();
5504            let d_h = match stream.clone_htod(&row_hessians) {
5505                Ok(s) => s,
5506                Err(err) => {
5507                    eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
5508                    return;
5509                }
5510            };
5511            let d_m = match stream.clone_htod(&marginal) {
5512                Ok(s) => s,
5513                Err(err) => {
5514                    eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
5515                    return;
5516                }
5517            };
5518            let d_g = match stream.clone_htod(&logslope) {
5519                Ok(s) => s,
5520                Err(err) => {
5521                    eprintln!(
5522                        "[bms_flex_row dense_block hill-climb] upload logslope failed: {err}"
5523                    );
5524                    return;
5525                }
5526            };
5527            let storage = DeviceResidentRowHess {
5528                hess: d_h,
5529                marginal_design: d_m,
5530                logslope_design: d_g,
5531                n,
5532                r,
5533                block: block.clone(),
5534                primary: primary.clone(),
5535
5536                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5537            };
5538            // Warmup + 5-iter median (dense build is heavier than HVP).
5539            let warmup: usize = 2;
5540            let iters: usize = 5;
5541            for _ in 0..warmup {
5542                let out = launch_bms_flex_row_dense_block(&storage)
5543                    .expect("warmup GPU dense_block must launch");
5544                assert_eq!(out.len(), p_total * p_total);
5545            }
5546            let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5547            for _ in 0..iters {
5548                let t0 = std::time::Instant::now();
5549                let out =
5550                    launch_bms_flex_row_dense_block(&storage).expect("GPU dense_block must launch");
5551                gpu_us.push(t0.elapsed().as_micros());
5552                assert_eq!(out.len(), p_total * p_total);
5553            }
5554            gpu_us.sort_unstable();
5555            let gpu_median = gpu_us[iters / 2];
5556
5557            // CPU side: chunked Rayon dense build over rows. Each chunk
5558            // builds a `[p_total, p_total]` partial then we reduce-add.
5559            const CHUNK_ROWS: usize = 2048;
5560            let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
5561            let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
5562            let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
5563            let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
5564            let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
5565            let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
5566            let cpu_build_parallel = || -> Vec<f64> {
5567                let nchunks = n.div_ceil(CHUNK_ROWS);
5568                (0..nchunks)
5569                    .into_par_iter()
5570                    .fold(
5571                        || vec![0.0_f64; p_total * p_total],
5572                        |mut acc, ci| {
5573                            let lo = ci * CHUNK_ROWS;
5574                            let hi = (lo + CHUNK_ROWS).min(n);
5575                            let mut phi: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
5576                            for row in lo..hi {
5577                                for col in phi.iter_mut() {
5578                                    col.iter_mut().for_each(|v| *v = 0.0);
5579                                }
5580                                let mrow = &marginal[row * p_m..(row + 1) * p_m];
5581                                let grow = &logslope[row * p_g..(row + 1) * p_g];
5582                                for k in 0..p_m {
5583                                    phi[0][k] = mrow[k];
5584                                }
5585                                for k in 0..p_g {
5586                                    phi[1][p_m + k] = grow[k];
5587                                }
5588                                for k in 0..h_block_len {
5589                                    phi[h_primary_start + k][h_block_start + k] = 1.0;
5590                                }
5591                                for k in 0..w_block_len {
5592                                    phi[w_primary_start + k][w_block_start + k] = 1.0;
5593                                }
5594                                let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
5595                                for u in 0..r {
5596                                    for v_idx in 0..r {
5597                                        let huv = hrow[u * r + v_idx];
5598                                        if huv == 0.0 {
5599                                            continue;
5600                                        }
5601                                        for m in 0..p_total {
5602                                            let pm = phi[u][m];
5603                                            if pm == 0.0 {
5604                                                continue;
5605                                            }
5606                                            let scaled = huv * pm;
5607                                            for nn in 0..p_total {
5608                                                acc[m * p_total + nn] += scaled * phi[v_idx][nn];
5609                                            }
5610                                        }
5611                                    }
5612                                }
5613                            }
5614                            acc
5615                        },
5616                    )
5617                    .reduce(
5618                        || vec![0.0_f64; p_total * p_total],
5619                        |mut a, b| {
5620                            for (ax, bx) in a.iter_mut().zip(b.iter()) {
5621                                *ax += *bx;
5622                            }
5623                            a
5624                        },
5625                    )
5626            };
5627            let warm_cpu = cpu_build_parallel();
5628            assert_eq!(warm_cpu.len(), p_total * p_total);
5629            let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5630            for _ in 0..iters {
5631                let t0 = std::time::Instant::now();
5632                let out = cpu_build_parallel();
5633                cpu_us.push(t0.elapsed().as_micros());
5634                assert_eq!(out.len(), p_total * p_total);
5635            }
5636            cpu_us.sort_unstable();
5637            let cpu_median = cpu_us[iters / 2];
5638
5639            let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5640            eprintln!(
5641                "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
5642                 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5643                 speedup={speedup:.2}× (charter target ≥ 10×)"
5644            );
5645            assert!(
5646                speedup >= 10.0,
5647                "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
5648                 need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
5649                 gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
5650                 (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
5651                 or prove the kernel is at hardware roofline."
5652            );
5653        }
5654    }
5655}