Skip to main content

gam_models/bms/gpu/
row.rs

1//! Stage 2 of the BMS FLEX row kernel — per-row math that turns per-cell
2//! derivative moments (built by Stage 1 in `src/gpu/cubic_cell/mod.rs`) into a
3//! row gradient and row-primary `r × r` Hessian.
4//!
5//! Math (mirrors the CPU reference
6//! `BernoulliMarginalSlope::compute_row_analytic_flex_from_parts_into` in
7//! `src/families/bernoulli_marginal_slope.rs`):
8//!
9//! For each row `i`, with per-cell cubic predictor coefficients
10//! `C_c = (C0, C1, C2, C3)` and derivative moments `m_0..m_9`, build
11//!
12//! ```text
13//!     κ        = 1 / (2π)
14//!     T_n      = κ · Σ_{e=0..3} C_e · m_{e+n}     (n = 0..6)
15//!     D(R)     = κ · Σ_{k=0..3} R_k · m_k
16//!     Q(R, S)  = Σ_{p,q=0..3} R_p · S_q · T_{p+q}
17//!     H(R, S, U) = D(U) − Q(R, S)
18//! ```
19//!
20//! Per cell `c`, accumulate into row scratch:
21//!
22//! ```text
23//!     F_a   += D(A_c)
24//!     F_aa  += H(A_c, A_c, AA_c)
25//!     F_u   += D(R_{c,u})                         u > 0
26//!     F_au  += H(A_c, R_{c,u}, AR_{c,u})          u > 0
27//!     F_uv  += H(R_{c,u}, R_{c,v}, S_{c,uv})      0 < u ≤ v
28//! ```
29//!
30//! After the cell sum, the `q`-row is overridden:
31//!
32//! ```text
33//!     F_q  = −mu_1
34//!     F_qq = −mu_2
35//!     F_qv = 0   (v > 0)
36//!     F_aq = 0
37//! ```
38//!
39//! Implicit function theorem (single `1/F_a`):
40//!
41//! ```text
42//!     inv_Fa = 1 / F_a
43//!     a_u    = −F_u · inv_Fa                       (q-row override: mu_1 · inv_Fa)
44//!     a_uv   = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) · inv_Fa
45//! ```
46//!
47//! Observed predictor at `z_obs` (host supplies pre-evaluated chi, xi, rho, tau,
48//! r_uv per row and coordinate):
49//!
50//! ```text
51//!     bar_e_u  = chi_obs · a_u + rho_u
52//!     bar_e_uv = chi_obs · a_uv + xi_obs · a_u · a_v + tau_u · a_v
53//!                + a_u · tau_v + r_uv
54//! ```
55//!
56//! Probit Mills (stable; uses `log_ndtr_and_mills` from `numerics_device::PROBIT_NUMERICS_CU`):
57//!
58//! ```text
59//!     s = 2y − 1 ;  m = s · e_obs
60//!     [log_cdf, λ] = log_ndtr_and_mills(m)
61//!     A = −w · s · λ
62//!     B =  w · λ · (m + λ)
63//! ```
64//!
65//! Final outputs:
66//!
67//! ```text
68//!     neglog   = −w · log_cdf
69//!     g_u      = A · bar_e_u
70//!     H_{uv}   = B · bar_e_u · bar_e_v + A · bar_e_uv     (symmetric)
71//! ```
72//!
//! Implementation choice (Stage 2): **one CUDA block per row**, with
//! `blockDim.x = 32` threads. The block's `F_u`, `F_au`, `F_uv`, `bar_e_u`,
//! `bar_e_uv` live in shared memory; threads in the block parallelise the
//! per-cell sums, then a single thread of the block (`threadIdx.x == 0`) does
//! the IFT solve, the observed-point assembly, the Mills evaluation, and the
//! final gradient + Hessian write-out. The shared arrays are sized for the
//! `r ≤ MAX_R = 32` cap: `3r + 2r²` doubles for `F_u`, `F_au`, `bar_e_u`,
//! `F_uv`, `bar_e_uv`, plus 64 doubles of reduction scratch and two scalars,
//! i.e. at most 2 210 doubles ≈ 17.3 KiB per block, well below the V100
//! 48 KB per-block limit. This keeps the implementation simple and avoids
//! per-thread global scratch (a per-thread `r*r` scratch arena would be
//! ~2 GB at n=195k, r=20).
83
84#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
/// Hard ceiling on `r` (= 2 + p_h + p_w).
///
/// `ROW_KERNEL_BODY` sizes its `__shared__` scratch with a hard-coded `32`, so
/// this constant must stay in sync with the kernel source. Per block the
/// kernel reserves `F_u`, `F_au`, `bar_e_u` (`3·32`), `F_uv`, `bar_e_uv`
/// (`2·32²`), the `reduce_a` / `reduce_b` reduction scratch (`2·32`) and two
/// scalars: `2 210` doubles ≈ 17.3 KiB, comfortably under the 48 KiB static
/// per-block shared-memory limit.
pub(crate) const MAX_R: usize = 32;
99
/// `blockDim.x` for the row kernel. Threads of a row-block parallelise the
/// per-cell loop; thread 0 of the block finalises the IFT solve.
///
/// The kernel's block reduction writes `reduce_a[threadIdx.x]` /
/// `reduce_b[threadIdx.x]` into fixed `__shared__ double[32]` arrays and folds
/// them with a halving-stride tree, so this value must be a power of two no
/// larger than 32. The compile-time assertion below enforces that.
/// Linux-only because the kernel launcher that consumes it is Linux-only.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_THREADS: u32 = 32;

// Guard the kernel's hard-coded reduction scratch (`reduce_a[32]`) and its
// power-of-two tree reduction against an incompatible block size.
#[cfg(target_os = "linux")]
const _: () = assert!(
    ROW_KERNEL_THREADS.is_power_of_two() && ROW_KERNEL_THREADS <= 32,
    "ROW_KERNEL_THREADS must be a power of two <= 32 (kernel reduce scratch is sized 32)"
);
105
/// Coefficients in one cubic (`C0..C3`), which is also the per-cell support
/// length of every coefficient jet the kernel reads (`A_c`, `AA_c`,
/// `R_{c,u}`, `AR_{c,u}`, `S_{c,uv}`). The kernel source hard-codes this
/// stride as `* 4`.
pub(crate) const COEFF4: usize = 4;
109
/// Per-cell stride of the derivative-moment buffer.
///
/// `T_n = κ · Σ_{e=0..3} C_e · m_{n+e}` for `n = 0..6` reaches at most
/// `m_{6+3} = m_9`, so each cell stores `m_0..=m_9` — ten values. The kernel
/// source hard-codes this stride as `* 10`.
pub(crate) const MOMENT_STRIDE: usize = 10;
113
114/// Source of the per-cell derivative moments fed into the row kernel.
115/// Phase-4 wiring: the substrate at `src/gpu/cubic_cell/mod.rs` can produce
116/// these on the GPU; this enum lets the launcher consume them directly
117/// without a DtoH+HtoD round-trip.
118pub(crate) enum CellMomentsSource<'a> {
119    /// Host-resident `[total_cells, MOMENT_STRIDE = 10]` row-major buffer.
120    /// The launcher will HtoD-upload this on every launch.
121    Host(&'a [f64]),
122    /// Device-resident moments already living on the row-kernel backend's
123    /// default stream (which is the same `cuda_context_for(ordinal).default_stream()`
124    /// the cubic-cell substrate uses, so no cross-context copy is needed).
125    /// Length on the device must be `total_cells * MOMENT_STRIDE`. Linux-only.
126    #[cfg(target_os = "linux")]
127    Device(&'a CudaSlice<f64>),
128}
129
130impl<'a> CellMomentsSource<'a> {
131    /// Logical element count of the moments source, used by [`BmsFlexRowKernelInputs::validate`].
132    pub(crate) fn len(&self) -> usize {
133        match self {
134            CellMomentsSource::Host(slice) => slice.len(),
135            #[cfg(target_os = "linux")]
136            CellMomentsSource::Device(d) => d.len(),
137        }
138    }
139}
140
/// Declares the borrowed launch ABI `BmsFlexRowKernelInputs` and its owned
/// twin `BmsFlexRowKernelInputsOwned` from a single field schema, so the two
/// structs and the `as_borrowed` conversion cannot drift apart.
///
/// - `f64_fields`: buffers stored as `&'a [f64]` (borrowed) / `Vec<f64>` (owned).
/// - `u32_fields`: buffers stored as `&'a [u32]` (borrowed) / `Vec<u32>` (owned).
/// - `moments_field`: per-cell moments, stored as a [`CellMomentsSource`] in
///   the borrowed struct and as a host `Vec<f64>` (plus, on Linux, an optional
///   device buffer `cell_moments_device`) in the owned struct.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        /// Per-row input bundle for [`launch_bms_flex_row_kernel`].
        ///
        /// Coordinate ordering convention: `u = 0` is `q`, the coordinate whose
        /// row the kernel overrides with `F_q = -mu_1`, `F_qq = -mu_2`;
        /// `u = 1` is `b` (slope); `u = 2..2+p_h` is the score-warp `β_h`
        /// block; `u = 2+p_h..2+p_h+p_w` is the link-wiggle `β_w` block. The
        /// latent intercept `a` is eliminated by the implicit-function solve
        /// (through `F_a`, `F_au`, `F_aa`) and is not itself one of the `r`
        /// coordinates. So `r = 2 + p_h + p_w`, and `u = 1` is the `b` (slope)
        /// index used by the sparse `S_{b·h}` / `S_{b·w}` payloads.
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of observation rows.
            pub n_rows: usize,
            /// Total primary local dimension. `r = 2 + p_h + p_w`.
            pub r: usize,
            /// Number of score-warp basis coordinates.
            pub p_h: usize,
            /// Number of link-wiggle basis coordinates.
            pub p_w: usize,
            /// Probit frailty scale `S_f` (scalar shared across rows; matches
            /// `BernoulliMarginalSlope::probit_frailty_scale`).
            pub s_f: f64,
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            /// Per-cell derivative moments, host- or device-resident.
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owned twin of [`BmsFlexRowKernelInputs`] — every borrowed slice is
        /// replaced by an owned `Vec`. The buffer fields are declared from the
        /// same schema as the borrowed launch ABI and converted by
        /// [`BmsFlexRowKernelInputsOwned::as_borrowed`].
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            /// Host-resident per-cell moments. Ignored by `as_borrowed` when
            /// `cell_moments_device` is `Some` (Linux only).
            pub $moments_field: Vec<f64>,
            /// Phase-4 device-resident moments. When `Some(_)`, the launcher
            /// skips the host upload and consumes the buffer directly.
            /// Linux-only field.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrowed view over `self` suitable for
            /// [`launch_bms_flex_row_kernel`]. The returned struct holds
            /// references into `self`, so the owned bundle must outlive the
            /// launch.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                // Field names are not hygienic in `macro_rules!`, so the host
                // buffer is addressed through `$moments_field` rather than a
                // hard-coded name that would only work for one invocation.
                #[cfg(target_os = "linux")]
                let moments = match self.cell_moments_device.as_ref() {
                    Some(device) => CellMomentsSource::Device(device),
                    None => CellMomentsSource::Host(&self.$moments_field),
                };
                #[cfg(not(target_os = "linux"))]
                let moments = CellMomentsSource::Host(&self.$moments_field);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: moments,
                }
            }
        }
    };
}
218
219define_bms_flex_row_kernel_input_types! {
220    f64_fields: [
221        q,
222        b,
223        mu_1,
224        mu_2,
225        z_obs,
226        y,
227        w,
228        cell_c0,
229        cell_c1,
230        cell_c2,
231        cell_c3,
232        cell_a,
233        cell_aa,
234        cell_r,
235        cell_ar,
236        cell_sbb,
237        cell_sbh,
238        cell_sbw,
239        chi_obs,
240        xi_obs,
241        rho_u,
242        tau_u,
243        r_uv,
244    ],
245    u32_fields: [cell_offsets],
246    moments_field: cell_moments,
247}
248
/// Per-row results of [`launch_bms_flex_row_kernel`], all stored row-major.
#[derive(Debug)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// Negative log-likelihood, one entry per row (`n_rows` entries).
    pub neglog: Vec<f64>,
    /// Gradient, laid out as `[n_rows, r]`.
    pub grad: Vec<f64>,
    /// Hessian, laid out as `[n_rows, r*r]`. The kernel stores both triangles
    /// of each symmetric `r × r` block.
    pub hess: Vec<f64>,
}
260
261impl<'a> BmsFlexRowKernelInputs<'a> {
262    /// Sanity-check every shape the kernel relies on. This is the only place
263    /// length errors are surfaced — the device kernel assumes valid layout.
264    pub(crate) fn validate(&self) -> Result<(), GpuError> {
265        if self.r == 0 {
266            return Err(GpuError::DriverCallFailed {
267                reason: "bms_flex_row inputs: r must be > 0".to_string(),
268            });
269        }
270        if self.r > MAX_R {
271            return Err(GpuError::DriverCallFailed {
272                reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
273            });
274        }
275        if self.r != 2 + self.p_h + self.p_w {
276            return Err(GpuError::DriverCallFailed {
277                reason: format!(
278                    "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
279                    self.r,
280                    self.p_h,
281                    self.p_w,
282                    2 + self.p_h + self.p_w
283                ),
284            });
285        }
286        let n = self.n_rows;
287        let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
288            if have != want {
289                return Err(GpuError::DriverCallFailed {
290                    reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
291                });
292            }
293            Ok(())
294        };
295        check_len("q", self.q.len(), n)?;
296        check_len("b", self.b.len(), n)?;
297        check_len("mu_1", self.mu_1.len(), n)?;
298        check_len("mu_2", self.mu_2.len(), n)?;
299        check_len("z_obs", self.z_obs.len(), n)?;
300        check_len("y", self.y.len(), n)?;
301        check_len("w", self.w.len(), n)?;
302        check_len("chi_obs", self.chi_obs.len(), n)?;
303        check_len("xi_obs", self.xi_obs.len(), n)?;
304        check_len("rho_u", self.rho_u.len(), n * self.r)?;
305        check_len("tau_u", self.tau_u.len(), n * self.r)?;
306        check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
307        check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
308        let total_cells_u32 = self.cell_offsets[n];
309        let total_cells = total_cells_u32 as usize;
310        check_len("cell_c0", self.cell_c0.len(), total_cells)?;
311        check_len("cell_c1", self.cell_c1.len(), total_cells)?;
312        check_len("cell_c2", self.cell_c2.len(), total_cells)?;
313        check_len("cell_c3", self.cell_c3.len(), total_cells)?;
314        check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
315        check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
316        check_len(
317            "cell_r",
318            self.cell_r.len(),
319            total_cells * self.r.saturating_sub(1) * COEFF4,
320        )?;
321        check_len(
322            "cell_ar",
323            self.cell_ar.len(),
324            total_cells * self.r.saturating_sub(1) * COEFF4,
325        )?;
326        check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
327        check_len(
328            "cell_sbh",
329            self.cell_sbh.len(),
330            total_cells * self.p_h * COEFF4,
331        )?;
332        check_len(
333            "cell_sbw",
334            self.cell_sbw.len(),
335            total_cells * self.p_w * COEFF4,
336        )?;
337        check_len(
338            "cell_moments",
339            self.cell_moments.len(),
340            total_cells * MOMENT_STRIDE,
341        )?;
342        // Bonus: when the moments came from `CellMomentsSource::Device`, the
343        // launcher needs to know the source is from a device buffer; nothing
344        // to validate beyond length above. The Host variant length check is
345        // also already covered above.
346        // Monotone cell_offsets check.
347        for i in 0..n {
348            if self.cell_offsets[i] > self.cell_offsets[i + 1] {
349                return Err(GpuError::DriverCallFailed {
350                    reason: format!(
351                        "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
352                        i,
353                        self.cell_offsets[i],
354                        i + 1,
355                        self.cell_offsets[i + 1]
356                    ),
357                });
358            }
359        }
360        Ok(())
361    }
362}
363
/// NVRTC kernel source body. One CUDA block per row; 32 threads per block
/// parallelise the per-cell sums into shared-memory scratch; thread 0 of the
/// block finishes the IFT + observed-point + Mills + Hessian write-out.
///
/// Shared probit numerics (`erfcx_nonnegative`, `log_ndtr`,
/// `log_ndtr_and_mills`) are provided by
/// `numerics_device::PROBIT_NUMERICS_CU`, which is prepended to this body
/// before it is compiled with `gam_gpu::device_cache::compile_ptx_arch`
/// (see [`RowKernelBackend::probe`]).
///
/// Output offsets into `out_grad` / `out_hess` are formed in `size_t`, like
/// the input-side offsets, so `row * r * r` cannot overflow a 32-bit `int`
/// when `n_rows · r²` exceeds `INT_MAX`.
///
/// **CPU parity reference**: the body mirrors
/// `compute_row_analytic_flex_from_parts_into` in
/// `src/families/bernoulli_marginal_slope.rs`.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallelises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
//                      ::compute_row_analytic_flex_from_parts_into.

#define INV_TWO_PI     0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row. Offsets are built in size_t so that
// `row * r * r` cannot overflow a 32-bit int.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    size_t rr        = (size_t)r * (size_t)r;
    size_t grad_base = (size_t)row * (size_t)r;
    size_t hess_base = (size_t)row * rr;
    out_neglog[row] = nan_value;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + (size_t)u] = nan_value;
    }
    for (size_t idx = 0; idx < rr; ++idx) {
        out_hess[hess_base + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int                  n_rows,
    int                  r,
    int                  p_h,
    int                  p_w,
    double               s_f,                // currently unused on device:
                                             // host has already baked S_f
                                             // into the cubic coefficients.
                                             // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a,       // [n_cells, 4]
    const double * __restrict__ cell_aa,      // [n_cells, 4]
    const double * __restrict__ cell_r,       // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar,      // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb,     // [n_cells, 4]
    const double * __restrict__ cell_sbh,     // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw,     // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments, // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho,      // [n_rows, r]
    const double * __restrict__ row_tau,      // [n_rows, r]
    const double * __restrict__ row_ruv,      // [n_rows, r*r]
    double       * __restrict__ out_neglog,
    double       * __restrict__ out_grad,
    double       * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    //   F_u      [r]
    //   F_au     [r]
    //   F_uv     [r*r]
    //   bar_e_u  [r]
    //   bar_e_uv [r*r]
    //   reduce_a [blockDim.x]
    //   reduce_b [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u]  = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa  = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        //             `cell_second_derivative_from_moments` after folding the
        //             cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S)                                                                 \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1]                               \
             + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2]                                    \
             + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3]                        \
             + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4]                                    \
             + (R[2]*S[3] + R[3]*S[2])*T[5]                                                \
             + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c  = cell_a  + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa  += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        //                                   = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R   = D_OF(R_u);
            double d_AR  = D_OF(AR_u);
            double q_AR  = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s  = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in score-warp block, or symmetric.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in link-wiggle block, or symmetric.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror: u in (h or w) block, v == 1 cannot happen
                // because we iterate v >= u; skip.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared.
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared  = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a  = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    //   F_q  = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0]  = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    //   a_u = -F_u · inv_Fa     (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    //   a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // With the q-overrides above (F_qq = -mu_2, F_qv = 0 for v > 0, F_aq = 0)
    // the q-row entries reduce to
    //   a_qq = (mu_2 - F_aa · a_q · a_q) · inv_Fa,
    //   a_qv = -(F_av · a_q + F_aa · a_q · a_v) · inv_Fa   (v > 0),
    // so the uniform formula below needs no special casing.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                        + F_au[v] * a_u[u]
                        + F_au[u] * a_u[v]
                        + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    //   bar_e_u  = chi · a_u + rho_u.
    //   bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi  = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                       + xi  * a_u[u] * a_u[v]
                       + tau[u] * a_u[v]
                       + a_u[u] * tau[v]
                       + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y    = row_y[row];
    double w    = row_w[row];
    double s    = 2.0 * y - 1.0;
    // NOTE(stage 2): e_obs must be the observed predictor *value*, whereas
    // bar_e_u is the first-derivative jet. Until the dispatcher wave that
    // bridges this kernel to the row evaluator in
    // `bernoulli_marginal_slope.rs` lands, the kernel reads e_obs from the
    // `bar_e_u[0]` slot (`chi·a_0 + rho_0`) and relies on the host having
    // pre-evaluated the value into that slot's `rho_u[0]` input.
    double e_obs = bar_e_u[0];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i =  w * lambda * (m_arg + lambda);

    // Output offsets in size_t: `row * r * r` overflows int for large n_rows.
    size_t grad_base = (size_t)row * (size_t)r;
    size_t hess_base = (size_t)row * (size_t)r * (size_t)r;

    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + (size_t)u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            out_hess[hess_base + (size_t)(u * r + v)] = val;
            if (u != v) {
                out_hess[hess_base + (size_t)(v * r + u)] = val;
            }
        }
    }
}
"#;
725
726// Force `s_f` to be considered used at the Rust level even though Stage 2 of
727// the kernel doesn't consume it on-device (the host has already baked the
728// probit frailty scale into the per-cell cubic coefficients). The dispatcher
729// wave that ports the rigid-branch fallback may want to apply `s_f` device-side
730// for log diagnostics; leaving the field on the input struct + reading it here
731// avoids a `let _` silencer the build.rs scanner would reject.
732#[inline]
733pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
734    inputs.s_f.is_finite() && inputs.s_f > 0.0
735}
736
737#[cfg(target_os = "linux")]
738pub(crate) struct RowKernelBackend {
739    pub(crate) stream: Arc<CudaStream>,
740    pub(crate) module: Arc<CudaModule>,
741}
742
743#[cfg(target_os = "linux")]
744impl RowKernelBackend {
745    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
746        static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
747        BACKEND
748            .get_or_init(|| {
749                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
750                    let row_kernel_source = [
751                        gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
752                        ROW_KERNEL_BODY,
753                    ]
754                    .concat();
755                    // #1551: route through the project's single arch-aware NVRTC
756                    // entry point instead of bare `cudarc::nvrtc::compile_ptx`.
757                    // `compile_ptx_arch` pins `--gpu-architecture` to the selected
758                    // device's compute capability and supplies the standard CUDA
759                    // include paths; bare `compile_ptx` uses NVRTC's default
760                    // virtual arch with no includes. The row kernel's 64-bit
761                    // `atomic_add_f64` (atomicCAS emulation) compiles best against
762                    // the real device arch, and this keeps every BMS-flex compile
763                    // site consistent with the SAE arrow/Schur kernels that do
764                    // require the sm_60 pin for native `atomicAdd(double*,double)`.
765                    let ptx = gam_gpu::device_cache::compile_ptx_arch(&row_kernel_source)
766                        .map_err(|err| GpuError::DriverCallFailed {
767                            reason: format!("bms_flex_row NVRTC compile failed: {err}"),
768                        })?;
769                    let module =
770                        parts
771                            .ctx
772                            .load_module(ptx)
773                            .map_err(|err| GpuError::DriverCallFailed {
774                                reason: format!("bms_flex_row module load failed: {err}"),
775                            })?;
776                    Ok(RowKernelBackend {
777                        stream: parts.stream.clone(),
778                        module,
779                    })
780                })
781            })
782            .as_ref()
783            .map_err(GpuError::clone)
784    }
785}
786
787/// Launch Stage-2 BMS FLEX row kernel. On non-Linux returns
788/// [`GpuError::DriverLibraryUnavailable`]; on Linux NVRTC-compiles the kernel
789/// (cached for the process lifetime), uploads the per-row + per-cell buffers,
790/// and dispatches one block per row.
791pub(crate) fn launch_bms_flex_row_kernel(
792    inputs: BmsFlexRowKernelInputs<'_>,
793) -> Result<BmsFlexRowKernelOutputs, GpuError> {
794    inputs.validate()?;
795    if !s_f_diagnostic_finite(&inputs) {
796        return Err(GpuError::DriverCallFailed {
797            reason: format!(
798                "bms_flex_row inputs: s_f must be positive and finite, got {}",
799                inputs.s_f
800            ),
801        });
802    }
803
804    #[cfg(target_os = "linux")]
805    {
806        launch_linux(inputs)
807    }
808    #[cfg(not(target_os = "linux"))]
809    {
810        Err(GpuError::DriverLibraryUnavailable {
811            reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
812        })
813    }
814}
815
816#[cfg(target_os = "linux")]
817pub(crate) fn launch_linux(
818    inputs: BmsFlexRowKernelInputs<'_>,
819) -> Result<BmsFlexRowKernelOutputs, GpuError> {
820    let backend = RowKernelBackend::probe()?;
821    let stream = &backend.stream;
822
823    let upload_f64 = |slice: &[f64], label: &str| {
824        stream
825            .clone_htod(slice)
826            .map_err(|err| GpuError::DriverCallFailed {
827                reason: format!("bms_flex_row upload {label}: {err}"),
828            })
829    };
830    let upload_u32 = |slice: &[u32], label: &str| {
831        stream
832            .clone_htod(slice)
833            .map_err(|err| GpuError::DriverCallFailed {
834                reason: format!("bms_flex_row upload {label}: {err}"),
835            })
836    };
837
838    let d_q = upload_f64(inputs.q, "q")?;
839    let d_b = upload_f64(inputs.b, "b")?;
840    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
841    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
842    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
843    let d_y = upload_f64(inputs.y, "y")?;
844    let d_w = upload_f64(inputs.w, "w")?;
845    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
846    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
847    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
848    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
849    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
850    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
851    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
852    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
853    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
854    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
855    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
856    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
857    // Phase-4: optionally consume device-resident moments (no host upload).
858    // Both branches end up holding a `&CudaSlice<f64>` named `d_moments_ref`
859    // we can pass to the launch builder uniformly.
860    let owned_host_moments: CudaSlice<f64>;
861    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
862        CellMomentsSource::Host(slice) => {
863            owned_host_moments = upload_f64(slice, "cell_moments")?;
864            &owned_host_moments
865        }
866        CellMomentsSource::Device(d) => *d,
867    };
868    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
869    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
870    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
871    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
872    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
873
874    let n = inputs.n_rows;
875    let r = inputs.r;
876    let mut d_neglog = stream
877        .alloc_zeros::<f64>(n)
878        .map_err(|err| GpuError::DriverCallFailed {
879            reason: format!("bms_flex_row alloc neglog: {err}"),
880        })?;
881    let mut d_grad =
882        stream
883            .alloc_zeros::<f64>(n * r)
884            .map_err(|err| GpuError::DriverCallFailed {
885                reason: format!("bms_flex_row alloc grad: {err}"),
886            })?;
887    let mut d_hess =
888        stream
889            .alloc_zeros::<f64>(n * r * r)
890            .map_err(|err| GpuError::DriverCallFailed {
891                reason: format!("bms_flex_row alloc hess: {err}"),
892            })?;
893
894    let func = backend
895        .module
896        .load_function("bms_flex_row_kernel")
897        .map_err(|err| GpuError::DriverCallFailed {
898            reason: format!("bms_flex_row load_function: {err}"),
899        })?;
900
901    let cfg = LaunchConfig {
902        grid_dim: (n as u32, 1, 1),
903        block_dim: (ROW_KERNEL_THREADS, 1, 1),
904        shared_mem_bytes: 0,
905    };
906    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
907        reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
908    })?;
909    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
910        reason: format!("bms_flex_row: r={r} exceeds i32 range"),
911    })?;
912    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
913        reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
914    })?;
915    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
916        reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
917    })?;
918    let s_f = inputs.s_f;
919
920    let mut builder = stream.launch_builder(&func);
921    builder
922        .arg(&n_i32)
923        .arg(&r_i32)
924        .arg(&p_h_i32)
925        .arg(&p_w_i32)
926        .arg(&s_f)
927        .arg(&d_q)
928        .arg(&d_b)
929        .arg(&d_mu1)
930        .arg(&d_mu2)
931        .arg(&d_zobs)
932        .arg(&d_y)
933        .arg(&d_w)
934        .arg(&d_offsets)
935        .arg(&d_c0)
936        .arg(&d_c1)
937        .arg(&d_c2)
938        .arg(&d_c3)
939        .arg(&d_a)
940        .arg(&d_aa)
941        .arg(&d_r)
942        .arg(&d_ar)
943        .arg(&d_sbb)
944        .arg(&d_sbh)
945        .arg(&d_sbw)
946        .arg(d_moments_ref)
947        .arg(&d_chi)
948        .arg(&d_xi)
949        .arg(&d_rho)
950        .arg(&d_tau)
951        .arg(&d_ruv)
952        .arg(&mut d_neglog)
953        .arg(&mut d_grad)
954        .arg(&mut d_hess);
955
956    // SAFETY: every kernel parameter above is either a primitive `i32` /
957    // `f64` (passed by value), a const device pointer to a buffer whose
958    // length the host validated against the input struct, or an output
959    // buffer pre-allocated to `n_rows`, `n_rows*r`, `n_rows*r*r`
960    // doubles. The kernel's shared-memory arrays are sized to MAX_R = 32
961    // and validate() rejects r > MAX_R.
962    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
963        reason: format!("bms_flex_row launch: {err}"),
964    })?;
965    stream
966        .synchronize()
967        .map_err(|err| GpuError::DriverCallFailed {
968            reason: format!("bms_flex_row synchronize: {err}"),
969        })?;
970
971    let neglog = stream
972        .clone_dtoh(&d_neglog)
973        .map_err(|err| GpuError::DriverCallFailed {
974            reason: format!("bms_flex_row download neglog: {err}"),
975        })?;
976    let grad = stream
977        .clone_dtoh(&d_grad)
978        .map_err(|err| GpuError::DriverCallFailed {
979            reason: format!("bms_flex_row download grad: {err}"),
980        })?;
981    let hess = stream
982        .clone_dtoh(&d_hess)
983        .map_err(|err| GpuError::DriverCallFailed {
984            reason: format!("bms_flex_row download hess: {err}"),
985        })?;
986
987    Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
988}
989
990// ─────────────────────────────────────────────────────────────────────────────
991// Phase 3: device-resident row Hessian + HVP / diagonal kernels.
992//
993// Math (mirrors the CPU oracle in
994// `src/families/bernoulli_marginal_slope.rs::exact_newton_joint_hessian_*_from_cache`):
995//
996//   Block layout (joint β):
997//     marginal = [0..p_m), logslope = [p_m..p_m+p_g),
998//     h        = [h_start..h_end), w = [w_start..w_end), total = p_total.
999//
1000//   Primary layout (per-row r-vector):
1001//     q = 0, logslope = 1,
1002//     h = [h_primary_start..h_primary_end),
1003//     w = [w_primary_start..w_primary_end), total = r.
1004//
1005//   row_dir[u] for u in primary layout:
1006//     row_dir[0]   = Σ_j marginal_design[row, j] · v[j]
1007//     row_dir[1]   = Σ_j logslope_design[row, j] · v[p_m + j]
1008//     row_dir[h_k] = v[h_block_start + (h_k - h_primary_start)]
1009//     row_dir[w_k] = v[w_block_start + (w_k - w_primary_start)]
1010//
1011//   action[u]    = Σ_v row_hessians[row, u*r + v] · row_dir[v]
1012//
1013//   block_partial[marginal_j] += action[0] · marginal_design[row, j]
1014//   block_partial[logslope_j] += action[1] · logslope_design[row, j]
1015//   block_partial[h_block_start + (h_k - h_primary_start)] += action[h_k]
1016//   block_partial[w_block_start + (w_k - w_primary_start)] += action[w_k]
1017//
1018// Diagonal:
1019//   diag[marginal_j] += row_hess[row, 0*r + 0] · marginal_design[row, j]²
1020//   diag[logslope_j] += row_hess[row, 1*r + 1] · logslope_design[row, j]²
1021//   diag[h_block_start + k] += row_hess[row, ii*r + ii]   (ii = h_primary_start + k)
1022//   diag[w_block_start + k] += row_hess[row, ii*r + ii]   (ii = w_primary_start + k)
1023//
1024// Determinism: each CTA owns a contiguous slice of `[chunk_start..chunk_end)`
1025// rows and writes its full per-chunk `p_total` partial into a non-overlapping
1026// region of the global partial buffer. The reduce kernel then sums those
1027// partials in fixed chunk-major order. No atomics.
1028
/// Joint-β block layout shared with the host (mirrors `BlockSlices` in
/// `bernoulli_marginal_slope.rs`).
///
/// Column order of joint β: marginal `[0, p_m)`, log-slope
/// `[p_m, p_m + p_g)`, then the optional `h` and `w` ranges. `p_total` is
/// the full length.
///
/// Gating: Linux-only. The lone production constructor lives in
/// `bernoulli_marginal_slope.rs` (line 9189 when this note was written),
/// behind `#[cfg(target_os = "linux")]`. The device-resident row-Hessian
/// path is the only producer (see
/// `launch_bms_flex_row_kernel_device_resident`). The joint-β consumers
/// `launch_bms_flex_row_hvp` / `_diagonal` / `_dense_block` are also
/// Linux-only.
///
/// Any non-Linux test referencing this type must guard itself with
/// `#[cfg(target_os = "linux")]` too. The build.rs ban scanner explicitly
/// rejects `#[cfg(any(..., test))]` on items as a dead-code escape hatch.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Width of the marginal block (joint-β columns `[0, p_m)`).
    pub p_m: usize,
    /// Width of the log-slope block (joint-β columns `[p_m, p_m + p_g)`).
    pub p_g: usize,
    /// Joint-β column range of the `h` block. Presumably `None` when the
    /// model has no `h` block; confirm at the constructor.
    pub h: Option<std::ops::Range<usize>>,
    /// Joint-β column range of the `w` block (same caveat as `h`).
    pub w: Option<std::ops::Range<usize>>,
    /// Total joint-β length, i.e. the length of HVP input/output vectors.
    pub p_total: usize,
}
1045    pub p_g: usize,
1046    pub h: Option<std::ops::Range<usize>>,
1047    pub w: Option<std::ops::Range<usize>>,
1048    pub p_total: usize,
1049}
1050
1051/// Primary-r layout shared with the host (mirrors `PrimarySlices`).
1052/// Gating rationale identical to [`BmsFlexBlockLayout`].
1053#[cfg(target_os = "linux")]
1054#[derive(Clone, Debug)]
1055pub(crate) struct BmsFlexPrimaryLayout {
1056    pub h: Option<std::ops::Range<usize>>,
1057    pub w: Option<std::ops::Range<usize>>,
1058    pub r: usize,
1059}
1060
1061// ── Linux-only: device-resident row-Hessian state + kernels ─────────────────
1062
1063/// Number of rows each HVP / diagonal CTA processes. Each CTA writes a single
1064/// `[1, p_total]` partial row into the global partial buffer (no atomics);
1065/// the reduce kernel then sums partials in chunk-major fixed order.
1066#[cfg(target_os = "linux")]
1067pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;
1068
1069/// `blockDim.x` for the HVP / diagonal partial kernels.
1070#[cfg(target_os = "linux")]
1071pub(crate) const HVP_THREADS: u32 = 128;
1072
1073/// `blockDim.x` for the partial-sum reduction kernels (one element per thread,
1074/// grid-strided over the `p_total`/`rhs_elems` partial buffer). A full warp
1075/// multiple that keeps the reduce launch occupancy-bound rather than tail-bound
1076/// for the typical large-scale `p_total`.
1077#[cfg(target_os = "linux")]
1078pub(crate) const REDUCTION_THREADS: u32 = 256;
1079
1080/// Maximum RHS columns fused into one row-primary HVP launch. The matching
1081/// CUDA source uses fixed shared arrays sized as
1082/// `BMS_FLEX_ROW_HVP_MAX_RHS * MAX_R`; increasing this requires updating the
1083/// `MAX_MULTI_RHS` define in `HVP_KERNEL_SOURCE`.
1084#[cfg(target_os = "linux")]
1085pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1086
1087/// Device-resident state produced by
1088/// [`launch_bms_flex_row_kernel_device_resident`] and consumed by
1089/// [`launch_bms_flex_row_hvp`] / [`launch_bms_flex_row_diagonal`].
1090///
1091/// Owns the row-Hessian + design slices on-device so the host can issue
1092/// many HVPs against the same β snapshot without round-tripping
1093/// 626 MB through host RAM. Drop releases the device memory back to
1094/// the CUDA runtime.
1095/// Per-row Hessian storage layout on the device. The build path is free to
1096/// emit either, and the Hv / diag kernels read whichever the storage says.
1097///
1098/// Charter (Block 9 Phase 4): packed-upper halves the DRAM footprint of the
1099/// `n × r²` cache (per-row `r*(r+1)/2` doubles instead of `r²`), at the cost
1100/// of a single per-entry index conversion in the kernel. The benchmark
1101/// decides whether the packed path becomes the default for large-scale
1102/// fits (`r = 20` → 210 vs 400 doubles per row, ~47.5% smaller). The
1103/// numerics are bit-equal because each `H_i` is symmetric by construction
1104/// (the row kernel emits a symmetric block by construction — see the
1105/// symmetric scratch-write loop in `bms_flex_row_kernel`'s shared-memory
1106/// finaliser).
1107#[cfg(target_os = "linux")]
1108pub struct DeviceResidentRowHess {
1109    /// Per-row dense `[n, r, r]` row-major Hessian. Element `(u, v)` of row
1110    /// `i` is `hess[i*r*r + u*r + v]`. This is the only on-device storage
1111    /// layout supported by the current HVP / diag kernels.
1112    pub(crate) hess: CudaSlice<f64>,
1113    pub(crate) marginal_design: CudaSlice<f64>,
1114    pub(crate) logslope_design: CudaSlice<f64>,
1115    pub(crate) n: usize,
1116    pub(crate) r: usize,
1117    pub(crate) block: BmsFlexBlockLayout,
1118    pub(crate) primary: BmsFlexPrimaryLayout,
1119    /// Estimated bytes resident on device (for accounting).
1120    pub(crate) bytes: u64,
1121}
1122
1123#[cfg(target_os = "linux")]
1124impl std::fmt::Debug for DeviceResidentRowHess {
1125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1126        f.debug_struct("DeviceResidentRowHess")
1127            .field("n", &self.n)
1128            .field("r", &self.r)
1129            .field("p_total", &self.block.p_total)
1130            .field("bytes", &self.bytes)
1131            .finish()
1132    }
1133}
1134
1135/// Sized-to-fit-once CTA mapping. Rows `[c * HVP_ROWS_PER_CTA, (c+1) * HVP_ROWS_PER_CTA)`
1136/// belong to chunk `c`.
1137#[cfg(target_os = "linux")]
1138pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1139    n.div_ceil(HVP_ROWS_PER_CTA as usize)
1140}
1141
1142/// NVRTC source: HVP-partial kernel + HVP-reduce kernel + diag-partial +
1143/// diag-reduce. All kernels mirror the CPU oracle in this file.
1144#[cfg(target_os = "linux")]
1145pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1146// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1147// in this module.
1148
1149#define MAX_MULTI_RHS 8
1150
1151extern "C" __global__ void bms_flex_row_hvp_partial(
1152    int                  n_rows,
1153    int                  r,
1154    int                  p_m,
1155    int                  p_g,
1156    int                  p_total,
1157    int                  h_block_start,
1158    int                  h_block_len,
1159    int                  w_block_start,
1160    int                  w_block_len,
1161    int                  h_primary_start,
1162    int                  w_primary_start,
1163    int                  rows_per_cta,
1164    const double * __restrict__ row_hessians,    // [n, r*r]
1165    const double * __restrict__ marginal_design, // [n, p_m] row-major
1166    const double * __restrict__ logslope_design, // [n, p_g] row-major
1167    const double * __restrict__ v,               // [p_total]
1168    double       * __restrict__ partial)         // [num_chunks, p_total]
1169{
1170    int chunk = blockIdx.x;
1171    int tid   = threadIdx.x;
1172    int row_lo = chunk * rows_per_cta;
1173    int row_hi = row_lo + rows_per_cta;
1174    if (row_hi > n_rows) row_hi = n_rows;
1175
1176    // Zero this chunk's partial slice cooperatively.
1177    double *out = partial + (size_t)chunk * (size_t)p_total;
1178    for (int j = tid; j < p_total; j += blockDim.x) {
1179        out[j] = 0.0;
1180    }
1181    __syncthreads();
1182
1183    // Each thread serially processes a stride-of-blockDim set of rows so
1184    // every write to `out[..]` happens from one thread → no atomics within
1185    // the chunk. To keep writes race-free across threads of the same chunk,
1186    // we serialize the cross-row accumulation through a per-row barrier:
1187    // thread 0 of the block processes all rows in the chunk. The per-row
1188    // work is dominated by the dot/axpy over `p_m + p_g`, which is large.
1189    // For Stage 3 we ship the simple, correct path (thread 0 sequential
1190    // per row, blockDim.x threads parallel within a row's dot/axpy).
1191    __shared__ double row_dir[32];
1192    __shared__ double action[32];
1193    __shared__ double dot_reduce[128];
1194
1195    for (int row = row_lo; row < row_hi; ++row) {
1196        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1197        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1198        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1199
1200        // row_dir[0] = mrow · v[0..p_m]
1201        double local = 0.0;
1202        for (int j = tid; j < p_m; j += blockDim.x) {
1203            local += mrow[j] * v[j];
1204        }
1205        dot_reduce[tid] = local;
1206        __syncthreads();
1207        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1208            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1209            __syncthreads();
1210        }
1211        if (tid == 0) row_dir[0] = dot_reduce[0];
1212
1213        // row_dir[1] = grow · v[p_m..p_m+p_g]
1214        local = 0.0;
1215        for (int j = tid; j < p_g; j += blockDim.x) {
1216            local += grow[j] * v[p_m + j];
1217        }
1218        dot_reduce[tid] = local;
1219        __syncthreads();
1220        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1221            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1222            __syncthreads();
1223        }
1224        if (tid == 0) row_dir[1] = dot_reduce[0];
1225
1226        // h/w blocks: direct copy.
1227        if (tid == 0) {
1228            for (int k = 0; k < h_block_len; ++k) {
1229                row_dir[h_primary_start + k] = v[h_block_start + k];
1230            }
1231            for (int k = 0; k < w_block_len; ++k) {
1232                row_dir[w_primary_start + k] = v[w_block_start + k];
1233            }
1234        }
1235        __syncthreads();
1236
1237        // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
1238        if (tid < r) {
1239            double acc = 0.0;
1240            for (int vv = 0; vv < r; ++vv) {
1241                acc += Hrow[tid * r + vv] * row_dir[vv];
1242            }
1243            action[tid] = acc;
1244        }
1245        __syncthreads();
1246
1247        // Pull back into joint β slot.
1248        //   marginal: out[j] += action[0] · mrow[j]   (parallel j)
1249        double a0 = action[0];
1250        for (int j = tid; j < p_m; j += blockDim.x) {
1251            out[j] += a0 * mrow[j];
1252        }
1253        double a1 = action[1];
1254        for (int j = tid; j < p_g; j += blockDim.x) {
1255            out[p_m + j] += a1 * grow[j];
1256        }
1257        if (tid == 0) {
1258            for (int k = 0; k < h_block_len; ++k) {
1259                out[h_block_start + k] += action[h_primary_start + k];
1260            }
1261            for (int k = 0; k < w_block_len; ++k) {
1262                out[w_block_start + k] += action[w_primary_start + k];
1263            }
1264        }
1265        __syncthreads();
1266    }
1267}
1268
1269extern "C" __global__ void bms_flex_row_hvp_reduce(
1270    int                  num_chunks,
1271    int                  p_total,
1272    const double * __restrict__ partial,   // [num_chunks, p_total]
1273    double       * __restrict__ out)        // [p_total]
1274{
1275    int j = blockIdx.x * blockDim.x + threadIdx.x;
1276    if (j >= p_total) return;
1277    double acc = 0.0;
1278    for (int c = 0; c < num_chunks; ++c) {
1279        acc += partial[(size_t)c * (size_t)p_total + (size_t)j];
1280    }
1281    out[j] = acc;
1282}
1283
1284extern "C" __global__ void bms_flex_row_hvp_multi_partial(
1285    int                  n_rows,
1286    int                  r,
1287    int                  p_m,
1288    int                  p_g,
1289    int                  p_total,
1290    int                  h_block_start,
1291    int                  h_block_len,
1292    int                  w_block_start,
1293    int                  w_block_len,
1294    int                  h_primary_start,
1295    int                  w_primary_start,
1296    int                  rows_per_cta,
1297    int                  rhs_count,
1298    const double * __restrict__ row_hessians,    // [n, r*r]
1299    const double * __restrict__ marginal_design, // [n, p_m]
1300    const double * __restrict__ logslope_design, // [n, p_g]
1301    const double * __restrict__ v_rhs,           // [rhs_count, p_total]
1302    double       * __restrict__ partial)         // [rhs_count, num_chunks, p_total]
1303{
1304    int chunk = blockIdx.x;
1305    int tid   = threadIdx.x;
1306    int row_lo = chunk * rows_per_cta;
1307    int row_hi = row_lo + rows_per_cta;
1308    if (row_hi > n_rows) row_hi = n_rows;
1309
1310    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
1311    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
1312        int rhs = idx / p_total;
1313        int j = idx - rhs * p_total;
1314        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
1315    }
1316    __syncthreads();
1317
1318    __shared__ double row_dir[MAX_MULTI_RHS * 32];
1319    __shared__ double action[MAX_MULTI_RHS * 32];
1320    __shared__ double dot_reduce[128];
1321
1322    for (int row = row_lo; row < row_hi; ++row) {
1323        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1324        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1325        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1326
1327        for (int rhs = 0; rhs < rhs_count; ++rhs) {
1328            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;
1329
1330            double local = 0.0;
1331            for (int j = tid; j < p_m; j += blockDim.x) {
1332                local += mrow[j] * v[j];
1333            }
1334            dot_reduce[tid] = local;
1335            __syncthreads();
1336            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1337                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1338                __syncthreads();
1339            }
1340            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];
1341
1342            local = 0.0;
1343            for (int j = tid; j < p_g; j += blockDim.x) {
1344                local += grow[j] * v[p_m + j];
1345            }
1346            dot_reduce[tid] = local;
1347            __syncthreads();
1348            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1349                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1350                __syncthreads();
1351            }
1352            if (tid == 0) {
1353                row_dir[rhs * 32 + 1] = dot_reduce[0];
1354                for (int k = 0; k < h_block_len; ++k) {
1355                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
1356                }
1357                for (int k = 0; k < w_block_len; ++k) {
1358                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
1359                }
1360            }
1361            __syncthreads();
1362        }
1363
1364        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
1365            int rhs = idx / r;
1366            int u = idx - rhs * r;
1367            double acc = 0.0;
1368            const double *dir = row_dir + rhs * 32;
1369            for (int vv = 0; vv < r; ++vv) {
1370                acc += Hrow[u * r + vv] * dir[vv];
1371            }
1372            action[rhs * 32 + u] = acc;
1373        }
1374        __syncthreads();
1375
1376        for (int rhs = 0; rhs < rhs_count; ++rhs) {
1377            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
1378            double a0 = action[rhs * 32 + 0];
1379            for (int j = tid; j < p_m; j += blockDim.x) {
1380                out[j] += a0 * mrow[j];
1381            }
1382            double a1 = action[rhs * 32 + 1];
1383            for (int j = tid; j < p_g; j += blockDim.x) {
1384                out[p_m + j] += a1 * grow[j];
1385            }
1386            if (tid == 0) {
1387                for (int k = 0; k < h_block_len; ++k) {
1388                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
1389                }
1390                for (int k = 0; k < w_block_len; ++k) {
1391                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
1392                }
1393            }
1394            __syncthreads();
1395        }
1396    }
1397}
1398
1399extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
1400    int                  num_chunks,
1401    int                  p_total,
1402    int                  rhs_count,
1403    const double * __restrict__ partial,   // [rhs_count, num_chunks, p_total]
1404    double       * __restrict__ out)        // [rhs_count, p_total]
1405{
1406    int idx = blockIdx.x * blockDim.x + threadIdx.x;
1407    int total = rhs_count * p_total;
1408    if (idx >= total) return;
1409    int rhs = idx / p_total;
1410    int j = idx - rhs * p_total;
1411    double acc = 0.0;
1412    for (int c = 0; c < num_chunks; ++c) {
1413        acc += partial[((size_t)rhs * (size_t)num_chunks + (size_t)c) * (size_t)p_total + (size_t)j];
1414    }
1415    out[(size_t)rhs * (size_t)p_total + (size_t)j] = acc;
1416}
1417
1418extern "C" __global__ void bms_flex_row_diag_partial(
1419    int                  n_rows,
1420    int                  r,
1421    int                  p_m,
1422    int                  p_g,
1423    int                  p_total,
1424    int                  h_block_start,
1425    int                  h_block_len,
1426    int                  w_block_start,
1427    int                  w_block_len,
1428    int                  h_primary_start,
1429    int                  w_primary_start,
1430    int                  rows_per_cta,
1431    const double * __restrict__ row_hessians,
1432    const double * __restrict__ marginal_design,
1433    const double * __restrict__ logslope_design,
1434    double       * __restrict__ partial)
1435{
1436    int chunk = blockIdx.x;
1437    int tid   = threadIdx.x;
1438    int row_lo = chunk * rows_per_cta;
1439    int row_hi = row_lo + rows_per_cta;
1440    if (row_hi > n_rows) row_hi = n_rows;
1441
1442    double *out = partial + (size_t)chunk * (size_t)p_total;
1443    for (int j = tid; j < p_total; j += blockDim.x) {
1444        out[j] = 0.0;
1445    }
1446    __syncthreads();
1447
1448    for (int row = row_lo; row < row_hi; ++row) {
1449        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1450        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1451        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1452        double h00 = Hrow[0];
1453        double h11 = Hrow[1 * r + 1];
1454        for (int j = tid; j < p_m; j += blockDim.x) {
1455            double v = mrow[j];
1456            out[j] += h00 * v * v;
1457        }
1458        for (int j = tid; j < p_g; j += blockDim.x) {
1459            double v = grow[j];
1460            out[p_m + j] += h11 * v * v;
1461        }
1462        if (tid == 0) {
1463            for (int k = 0; k < h_block_len; ++k) {
1464                int ii = h_primary_start + k;
1465                out[h_block_start + k] += Hrow[ii * r + ii];
1466            }
1467            for (int k = 0; k < w_block_len; ++k) {
1468                int ii = w_primary_start + k;
1469                out[w_block_start + k] += Hrow[ii * r + ii];
1470            }
1471        }
1472        __syncthreads();
1473    }
1474}
1475
1476// ────────────────────────────────────────────────────────────────────────
1477// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1478//   row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1479// indexed as
1480//   packed[(u*(2*r - u - 1))/2 + (v - u)]   for u <= v
1481// with symmetric mirror for v < u.
1482// ────────────────────────────────────────────────────────────────────────
1483
1484// Helper: packed-upper index for (u, v) within a single row of r*(r+1)/2
1485// doubles. Caller must pre-swap so that u <= v.
1486__device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
1487    // u*(2r - u - 1)/2 + (v - u)
1488    return (u * (2 * r - u - 1)) / 2 + (v - u);
1489}
1490
// Pack one row of the full row-major r×r Hessian into packed-upper layout.
// Launched as one CTA per row (gridDim.x = n_rows, blockDim.x configurable).
// Bit-equal copy: each upper-triangle entry is read once from the dense
// source and written once to the packed destination.
extern "C" __global__ void bms_flex_row_pack_upper(
    int                  n_rows,
    int                  r,
    const double * __restrict__ src_full,    // [n, r*r]
    double       * __restrict__ dst_packed)  // [n, r*(r+1)/2]
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;
    int per_row = r * (r + 1) / 2;
    const double *src = src_full + (size_t)row * (size_t)r * (size_t)r;
    double       *dst = dst_packed + (size_t)row * (size_t)per_row;
    // Linear scan over packed positions; map each back to (u, v).
    for (int pos = tid; pos < per_row; pos += blockDim.x) {
        // Row u of the upper triangle holds the r - u entries (u, u..r-1)
        // and occupies [u_start, u_start + (r - u)), where
        // u_start(u) = Σ_{k<u} (r - k) = u*(2r - u + 1)/2.
        // Find the u whose range contains pos by accumulating u_start in an
        // O(r) scan (r is small: the HVP kernels size their shared scratch
        // for r <= 32).
        int u = 0;
        int u_start = 0;
        while (u < r) {
            int next = u_start + (r - u);
            if (pos < next) break;
            u_start = next;
            ++u;
        }
        int v = u + (pos - u_start);
        dst[pos] = src[(size_t)u * (size_t)r + (size_t)v];
    }
}
1525
// Per-CTA partial of the joint-β Hessian-vector product, reading each row's
// r × r Hessian from packed-upper storage (see bms_flex_packed_idx).
//
// CTA `chunk` owns rows [chunk*rows_per_cta, min((chunk+1)*rows_per_cta,
// n_rows)) and writes a p_total-long partial at partial + chunk*p_total.
// For each row it:
//   1. projects the joint direction `v` onto the row's primary coordinates:
//      row_dir[0] = X_i · v[0..p_m), row_dir[1] = G_i · v[p_m..p_m+p_g),
//      and copies the h / w block entries of v into their primary slots;
//   2. forms action = H_i · row_dir, mirroring the packed upper triangle;
//   3. pulls `action` back into joint-β space and adds it to the partial.
//
// Launch requirements implied by the fixed shared scratch below:
//   * blockDim.x must be a power of two and <= 128 (tree reduction over
//     dot_reduce[128], indexed by threadIdx.x);
//   * r <= 32 (row_dir / action) and blockDim.x >= r (one thread per
//     action entry).
// NOTE(review): row_dir slots in [2, r) outside the h / w primary ranges are
// never written, so this assumes r == 2 + h_block_len + w_block_len — confirm.
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed, // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,     // [n, p_m] row-major
    const double * __restrict__ logslope_design,     // [n, p_g] row-major
    const double * __restrict__ v,                   // joint-β direction (presumably [p_total])
    double       * __restrict__ partial)             // [num_chunks, p_total]
{
    int chunk = blockIdx.x;
    int tid   = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int per_row = r * (r + 1) / 2;
    double *out = partial + (size_t)chunk * (size_t)p_total;
    // Zero this CTA's slice; it is written even if the CTA owns no rows.
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // row_dir[0] = mrow · v[0..p_m]
        double local = 0.0;
        for (int j = tid; j < p_m; j += blockDim.x) {
            local += mrow[j] * v[j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        // Halving tree reduction; only exact when blockDim.x is a power of two.
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        // Thread 0 reads dot_reduce[0] before overwriting it below; the other
        // threads only write dot_reduce[tid != 0], so no extra barrier is needed.
        if (tid == 0) row_dir[0] = dot_reduce[0];

        // row_dir[1] = grow · v[p_m..p_m+p_g]
        local = 0.0;
        for (int j = tid; j < p_g; j += blockDim.x) {
            local += grow[j] * v[p_m + j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[1] = dot_reduce[0];

        // The h / w blocks map 1:1 onto primary coordinates, so their direction
        // entries are copied straight from v.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        __syncthreads();

        // action[u] = Σ_w H[u, w] · row_dir[w], where H[u, w] reads from
        // packed-upper with (uu, vv) = (min(u, w), max(u, w)).
        if (tid < r) {
            double acc = 0.0;
            int u = tid;
            for (int w = 0; w < r; ++w) {
                int uu = u < w ? u : w;
                int vv = u < w ? w : u;
                acc += Hrow[bms_flex_packed_idx(uu, vv, r)] * row_dir[w];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Pullback: action[0] spreads over the marginal design row,
        // action[1] over the logslope design row, and the h / w slots land
        // on their single joint-β columns.
        double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        // NOTE(review): assumes the h / w block columns do not overlap the
        // strided [0, p_m + p_g) writes above — confirm against block layout.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        // Barrier before the next row reuses row_dir / action / dot_reduce.
        __syncthreads();
    }
}
1636
1637// ────────────────────────────────────────────────────────────────────────
1638// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
1639// route. Materialises the full `[p_total, p_total]` row-major joint H
1640// from the per-row r×r Hessian via the P_i pullback. NOT the default
1641// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
1642// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
1643// caller genuinely needs the dense matrix on the device.
1644//
1645// Per-CTA partial: each CTA owns a contiguous chunk of rows
1646// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
1647// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
1648// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
1649// chunk-major-fixed-order into a single `[p_total, p_total]` output.
1650//
1651// Math: for primary index u ∈ [0, r):
1652//   * u = 0:        phi_u = (X_i in slot 0..p_m, 0 elsewhere)
1653//   * u = 1:        phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
1654//   * u = 2+j:      phi_u = e_{h_block_start + j}  (j ∈ 0..h_block_len)
1655//   * u = 2+h+l:    phi_u = e_{w_block_start + l}  (l ∈ 0..w_block_len)
1656// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
1657//
// Shared-memory budget: at large-scale shape p_total = 44, a [44, 44] f64
// partial is 44*44*8 = 15,488 B ≈ 15.1 KiB — well below the 48 KiB default
// per-block dynamic shared-memory limit (V100 and later, without opt-in).
// The largest p_total that fits under that limit is 78 (78*78*8 = 48,672 B);
// p_total = 80 already needs 80*80*8 = 50 KiB, just over the cap. The caller
// must enforce p_total ≤ DENSE_BLOCK_MAX_P; the launcher rejects oversize
// p_total cleanly.
1663
// Resolve primary coordinate `u` of one row into the span of joint-β columns
// its pullback vector phi_u touches:
//   * u = 0                → columns [0, p_m), weights = marginal design row
//   * u = 1                → columns [p_m, p_m + p_g), weights = logslope row
//   * u in the h / w primary ranges → the single matching block column with
//                            unit weight (*src = NULL)
// Uses the same h_primary_start / w_primary_start mapping as the HVP and
// diagonal kernels. Returns false when `u` belongs to no block; such a
// coordinate then contributes nothing instead of indexing past the w block.
__device__ __forceinline__ bool bms_flex_dense_phi_span(
    int u, int p_m, int p_g,
    int h_block_start, int h_block_len, int h_primary_start,
    int w_block_start, int w_block_len, int w_primary_start,
    const double *mrow, const double *grow,
    int *off, int *len, const double **src)
{
    if (u == 0) { *off = 0;   *len = p_m; *src = mrow; return true; }
    if (u == 1) { *off = p_m; *len = p_g; *src = grow; return true; }
    if (u >= h_primary_start && u < h_primary_start + h_block_len) {
        *off = h_block_start + (u - h_primary_start);
        *len = 1;
        *src = NULL;
        return true;
    }
    if (u >= w_primary_start && u < w_primary_start + w_block_len) {
        *off = w_block_start + (u - w_primary_start);
        *len = 1;
        *src = NULL;
        return true;
    }
    return false;
}

extern "C" __global__ void bms_flex_row_dense_block_partial(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians,    // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    double       * __restrict__ partial)         // [num_chunks, p_total, p_total]
{
    // Dynamic shared memory must hold p_total*p_total doubles.
    extern __shared__ double shmem[];
    int chunk = blockIdx.x;
    int tid   = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int pp = p_total * p_total;
    double *acc = shmem; // CTA-private accumulator [p_total, p_total]
    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
    __syncthreads();

    // Per-row work is performed by thread 0 alone so no two threads ever
    // read-modify-write the same acc[] entry. Per-row cost is O(r²) for the
    // H_i scan plus O((p_m + p_g + h_block_len + w_block_len)²) for the
    // outer products, dominated by the (0,0)/(0,1)/(1,0)/(1,1) design pairs.
    // Tighter parallel implementations are possible (warp-stripe the
    // u-v-m-n loop nest) but Phase 6 is a debug-only path and the simple
    // version is easier to audit against the host-side P_i pullback oracle.
    if (tid == 0) {
        for (int row = row_lo; row < row_hi; ++row) {
            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
            for (int u = 0; u < r; ++u) {
                // phi_u's support depends only on u: resolve it once per u.
                int m_off, m_len; const double *m_src;
                if (!bms_flex_dense_phi_span(u, p_m, p_g,
                                             h_block_start, h_block_len, h_primary_start,
                                             w_block_start, w_block_len, w_primary_start,
                                             mrow, grow, &m_off, &m_len, &m_src)) {
                    continue;
                }
                for (int v = 0; v < r; ++v) {
                    double huv = Hrow[u * r + v];
                    if (huv == 0.0) continue;
                    int n_off, n_len; const double *n_src;
                    if (!bms_flex_dense_phi_span(v, p_m, p_g,
                                                 h_block_start, h_block_len, h_primary_start,
                                                 w_block_start, w_block_len, w_primary_start,
                                                 mrow, grow, &n_off, &n_len, &n_src)) {
                        continue;
                    }
                    // accumulate huv * phi_u[m] * phi_v[n] into acc[m, n]
                    for (int mi = 0; mi < m_len; ++mi) {
                        double pm = m_src ? m_src[mi] : 1.0;
                        if (pm == 0.0) continue;
                        double scaled = huv * pm;
                        int m_idx = m_off + mi;
                        for (int ni = 0; ni < n_len; ++ni) {
                            double pn = n_src ? n_src[ni] : 1.0;
                            int n_idx = n_off + ni;
                            acc[m_idx * p_total + n_idx] += scaled * pn;
                        }
                    }
                }
            }
        }
    }
    __syncthreads();

    // Write CTA accumulator out to global memory at its chunk slot.
    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
    for (int j = tid; j < pp; j += blockDim.x) {
        out_chunk[j] = acc[j];
    }
}
1758
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int                  num_chunks,
    int                  p_total,
    const double * __restrict__ partial,
    double       * __restrict__ out)
{
    // One thread per entry of the [p_total, p_total] output. Each thread sums
    // its entry over every chunk's partial in ascending chunk order, so the
    // floating-point result is the same on every run.
    const int entries = p_total * p_total;
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < entries) {
        const double *entry_base = partial + (size_t)idx;
        double sum = 0.0;
        for (int c = 0; c < num_chunks; ++c) {
            sum += entry_base[(size_t)c * (size_t)entries];
        }
        out[idx] = sum;
    }
}
1774
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double       * __restrict__ partial)
{
    // Per-CTA partial of the joint Hessian diagonal, reading each row's
    // Hessian from packed-upper storage. Only the (u, u) entries are needed:
    //   * marginal column j:   H_i[0,0] · X_i[j]^2
    //   * logslope column j:   H_i[1,1] · G_i[j]^2
    //   * h / w block columns: H_i[u,u] for the matching primary slot u
    const int lane   = threadIdx.x;
    const int stride = blockDim.x;
    const int first  = blockIdx.x * rows_per_cta;
    const int end    = (first + rows_per_cta > n_rows) ? n_rows : first + rows_per_cta;
    const int packed_len = r * (r + 1) / 2;

    double *acc = partial + (size_t)blockIdx.x * (size_t)p_total;
    for (int j = lane; j < p_total; j += stride) acc[j] = 0.0;
    __syncthreads();

    for (int row = first; row < end; ++row) {
        const double *x_row = marginal_design + (size_t)row * (size_t)p_m;
        const double *g_row = logslope_design + (size_t)row * (size_t)p_g;
        const double *h_row = row_hessians_packed + (size_t)row * (size_t)packed_len;

        const double h_mm = h_row[bms_flex_packed_idx(0, 0, r)];
        const double h_gg = h_row[bms_flex_packed_idx(1, 1, r)];

        for (int j = lane; j < p_m; j += stride) {
            const double x = x_row[j];
            acc[j] += h_mm * x * x;
        }
        double *acc_g = acc + p_m;
        for (int j = lane; j < p_g; j += stride) {
            const double g = g_row[j];
            acc_g[j] += h_gg * g * g;
        }

        // Block columns are one entry each; thread 0 applies them serially.
        if (lane == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                const int slot = h_primary_start + k;
                acc[h_block_start + k] += h_row[bms_flex_packed_idx(slot, slot, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                const int slot = w_primary_start + k;
                acc[w_block_start + k] += h_row[bms_flex_packed_idx(slot, slot, r)];
            }
        }
        __syncthreads();
    }
}
1834"#;
1835
1836#[cfg(target_os = "linux")]
1837pub(crate) struct HvpKernelBackend {
1838    pub(crate) stream: Arc<CudaStream>,
1839    pub(crate) module: Arc<CudaModule>,
1840}
1841
1842#[cfg(target_os = "linux")]
1843impl HvpKernelBackend {
1844    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1845        static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1846        BACKEND
1847            .get_or_init(|| {
1848                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1849                    // #1551: arch-aware compile (see launch_bms_flex_row_kernel) —
1850                    // pin `--gpu-architecture` to the device capability and supply
1851                    // the standard include paths via the shared NVRTC entry point.
1852                    let ptx = gam_gpu::device_cache::compile_ptx_arch(HVP_KERNEL_SOURCE)
1853                        .map_err(|err| GpuError::DriverCallFailed {
1854                            reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1855                        })?;
1856                    let module =
1857                        parts
1858                            .ctx
1859                            .load_module(ptx)
1860                            .map_err(|err| GpuError::DriverCallFailed {
1861                                reason: format!("bms_flex_row hvp module load failed: {err}"),
1862                            })?;
1863                    Ok(HvpKernelBackend {
1864                        stream: parts.stream.clone(),
1865                        module,
1866                    })
1867                })
1868            })
1869            .as_ref()
1870            .map_err(GpuError::clone)
1871    }
1872}
1873
/// Build a device-resident row-Hessian cache by launching the row kernel and
/// keeping the resulting `n × r²` slice resident on the device. Also uploads
/// the dense marginal + logslope design matrices so subsequent HVPs do not
/// re-upload them at every direction.
///
/// `marginal_design_row_major` and `logslope_design_row_major` must be
/// row-major `[n, p_m]` and `[n, p_g]` contiguous slices.
///
/// #461 absorber (additive Stage-1 influence block): the orthogonalization is
/// realized as **A2 — the marginal design widened to `[M | Z̃_infl]`** (see
/// `src/families/bms/block_specs.rs::widen_marginal_dense_with_influence`), NOT
/// a dedicated 5th primary coordinate. The absorber `+Z̃_infl·γ` is plain
/// additive into the marginal index `α(x)`, so γ lives inside `β_m` of the
/// widened block and `p_m` already counts the `p₁` influence columns. The row
/// kernel reads the marginal index from `block_states[0].eta` (which carries
/// `Z̃_infl·γ`) and pulls back through this same widened `marginal_design`, so
/// the absorber rides the existing primary-coordinate `u = 0` chain with **no
/// kernel-source change**: η, gradient, and Hessian match the CPU kernel
/// bit-for-bit precisely because `marginal_design` and `β_m` are the matched
/// (design, coefficient) pair the CPU path uses. The validation below pins
/// `marginal_design.len() == n·p_m` (with `p_m` widened), so a stale narrow
/// design against a widened `block.p_m` is rejected cleanly (CPU fallback)
/// rather than silently computing the wrong η. The absorber is dropped at
/// predict, where the marginal design is rebuilt without the influence columns,
/// so the predict-time `p_m` is narrow and this path is correct there too.
///
/// # Errors
///
/// Propagates any error from `inputs.validate()` and from the row / HVP
/// backend probes. Returns `GpuError::DriverCallFailed` when `s_f` fails the
/// finiteness diagnostic, a design slice has the wrong length, `primary.r`
/// disagrees with `inputs.r`, `n` / `r` / `p_h` / `p_w` do not fit in `i32`,
/// or any upload, allocation, function load, launch, or synchronize fails.
#[cfg(target_os = "linux")]
pub(crate) fn launch_bms_flex_row_kernel_device_resident(
    inputs: BmsFlexRowKernelInputs<'_>,
    marginal_design_row_major: &[f64],
    logslope_design_row_major: &[f64],
    block: BmsFlexBlockLayout,
    primary: BmsFlexPrimaryLayout,
) -> Result<DeviceResidentRowHess, GpuError> {
    inputs.validate()?;
    if !s_f_diagnostic_finite(&inputs) {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: s_f must be positive and finite, got {}",
                inputs.s_f
            ),
        });
    }
    let n = inputs.n_rows;
    let r = inputs.r;
    // Shape checks for the two designs kept resident in the cache (see the
    // #461 absorber note above for why the marginal check matters).
    if marginal_design_row_major.len() != n * block.p_m {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
                marginal_design_row_major.len(),
                n * block.p_m
            ),
        });
    }
    if logslope_design_row_major.len() != n * block.p_g {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
                logslope_design_row_major.len(),
                n * block.p_g
            ),
        });
    }
    if primary.r != r {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: primary.r={} != inputs.r={}",
                primary.r, r
            ),
        });
    }

    // Ensure the row kernel backend is compiled & loaded (this also compiles
    // the HVP backend on first use so the caller surfaces failures here).
    let backend = RowKernelBackend::probe()?;
    HvpKernelBackend::probe()?;
    let stream = backend.stream.clone();

    // Host → device copy helpers; `label` names the buffer in error messages.
    let upload_f64 = |slice: &[f64], label: &str| {
        stream
            .clone_htod(slice)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
            })
    };
    let upload_u32 = |slice: &[u32], label: &str| {
        stream
            .clone_htod(slice)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
            })
    };

    let d_q = upload_f64(inputs.q, "q")?;
    let d_b = upload_f64(inputs.b, "b")?;
    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
    let d_y = upload_f64(inputs.y, "y")?;
    let d_w = upload_f64(inputs.w, "w")?;
    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
    // Phase-4: optionally consume device-resident moments (no host upload).
    // `owned_host_moments` is declared (uninitialised) at function scope so the
    // `Host` arm's upload outlives the borrow stored in `d_moments_ref`.
    let owned_host_moments: CudaSlice<f64>;
    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
        CellMomentsSource::Host(slice) => {
            owned_host_moments = upload_f64(slice, "cell_moments")?;
            &owned_host_moments
        }
        CellMomentsSource::Device(d) => *d,
    };
    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;

    // Designs stay resident in the returned cache for later HVP launches.
    let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
    let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;

    // Zeroed kernel outputs: neglog [n], gradient [n * r], Hessian [n * r * r].
    let mut d_neglog = stream
        .alloc_zeros::<f64>(n)
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
        })?;
    let mut d_grad =
        stream
            .alloc_zeros::<f64>(n * r)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident alloc grad: {err}"),
            })?;
    let mut d_hess =
        stream
            .alloc_zeros::<f64>(n * r * r)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident alloc hess: {err}"),
            })?;

    let func = backend
        .module
        .load_function("bms_flex_row_kernel")
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident load_function: {err}"),
        })?;

    // One CTA per row. `n as u32` is evaluated before the i32 range check
    // below, but that check still runs before the launch.
    let cfg = LaunchConfig {
        grid_dim: (n as u32, 1, 1),
        block_dim: (ROW_KERNEL_THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
    })?;
    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
    })?;
    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
        reason: format!(
            "bms_flex_row device-resident: p_h={} exceeds i32 range",
            inputs.p_h
        ),
    })?;
    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
        reason: format!(
            "bms_flex_row device-resident: p_w={} exceeds i32 range",
            inputs.p_w
        ),
    })?;
    let s_f_val = inputs.s_f;

    // Argument order must match the `bms_flex_row_kernel` parameter list
    // exactly (that CUDA signature lives in the row-kernel source, not here).
    let mut builder = stream.launch_builder(&func);
    builder
        .arg(&n_i32)
        .arg(&r_i32)
        .arg(&p_h_i32)
        .arg(&p_w_i32)
        .arg(&s_f_val)
        .arg(&d_q)
        .arg(&d_b)
        .arg(&d_mu1)
        .arg(&d_mu2)
        .arg(&d_zobs)
        .arg(&d_y)
        .arg(&d_w)
        .arg(&d_offsets)
        .arg(&d_c0)
        .arg(&d_c1)
        .arg(&d_c2)
        .arg(&d_c3)
        .arg(&d_a)
        .arg(&d_aa)
        .arg(&d_r)
        .arg(&d_ar)
        .arg(&d_sbb)
        .arg(&d_sbh)
        .arg(&d_sbw)
        .arg(d_moments_ref)
        .arg(&d_chi)
        .arg(&d_xi)
        .arg(&d_rho)
        .arg(&d_tau)
        .arg(&d_ruv)
        .arg(&mut d_neglog)
        .arg(&mut d_grad)
        .arg(&mut d_hess);
    // SAFETY: same shape contract as `launch_linux`: every kernel parameter is
    // either a primitive scalar by-value, a const device pointer whose
    // capacity was validated by `inputs.validate()`, or one of the three
    // output buffers we just allocated with the expected element count.
    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident launch: {err}"),
    })?;
    // Synchronize before the input buffers below are dropped, so the kernel
    // has finished reading them.
    stream
        .synchronize()
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident synchronize: {err}"),
        })?;

    // The kernel writes neglog + grad alongside the row Hessian, but the
    // device-resident cache path keeps neither on the host: the fused
    // CPU gradient pass (the only consumer of host-side neglog/grad) is
    // dispatched only as a fallback when the GPU dense-block kernel does
    // not apply, and in that fallback the row kernel runs again locally.
    // Drop the device buffers so the allocation pool reclaims them
    // immediately rather than tying them to the cache's lifetime.
    drop(d_neglog);
    drop(d_grad);
    // Drop the per-cell uploads; keep d_hess + designs.
    drop(d_q);
    drop(d_b);
    drop(d_mu1);
    drop(d_mu2);
    drop(d_zobs);
    drop(d_y);
    drop(d_w);
    drop(d_offsets);
    drop(d_c0);
    drop(d_c1);
    drop(d_c2);
    drop(d_c3);
    drop(d_a);
    drop(d_aa);
    drop(d_r);
    drop(d_ar);
    drop(d_sbb);
    drop(d_sbh);
    drop(d_sbw);
    // `owned_host_moments` (if any) and the borrowed `d_moments_ref` both
    // go out of scope at the end of the function; the device-resident
    // moments owned by the caller stay alive.
    drop(d_chi);
    drop(d_xi);
    drop(d_rho);
    drop(d_tau);
    drop(d_ruv);

    // Device bytes retained by the cache: row Hessians plus both designs.
    let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
        * std::mem::size_of::<f64>()) as u64;
    Ok(DeviceResidentRowHess {
        hess: d_hess,
        marginal_design: d_marginal,
        logslope_design: d_logslope,
        n,
        r,
        block,
        primary,
        bytes,
    })
}
2152
/// Selects the per-row partial kernel the joint-β partial+reduce engine runs
/// and what happens to the reduced `[1, p_total]` image.
///
/// A single variant pins down three things at once: the partial kernel's
/// name, whether the launch binds a direction vector `d_v`, and whether the
/// reduced result stays on-stream or is downloaded. This keeps the public
/// entry points as thin wrappers over one launch helper.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Row-wise `H · v` via `bms_flex_row_hvp_partial`; result left on-stream.
    HvpDeviceOut,
    /// Row-wise `diag(H)` via `bms_flex_row_diag_partial`; result downloaded
    /// to the host.
    DiagonalHostOut,
}

#[cfg(target_os = "linux")]
impl BmsFlexRowLaunchMode {
    /// Symbol of the partial kernel this mode loads from the HVP module.
    pub(crate) fn partial_kernel_name(self) -> &'static str {
        match self {
            Self::DiagonalHostOut => "bms_flex_row_diag_partial",
            Self::HvpDeviceOut => "bms_flex_row_hvp_partial",
        }
    }
}
2176
/// Scalar launch-argument prefix shared by the joint-β partial kernels,
/// computed once from a [`DeviceResidentRowHess`].
///
/// The HVP and diagonal partial kernels accept the same leading list of
/// block-layout arguments and differ only in their trailing pointers (`d_v`
/// and the output). Keeping that long, order-sensitive prefix in one struct
/// stops the launch sites from drifting apart.
#[cfg(target_os = "linux")]
pub(crate) struct PreparedBmsFlexRowLaunchArgs {
    /// Number of rows `n`.
    pub(crate) n_i32: i32,
    /// Primary-coordinate count `r` (each row Hessian is `r × r`).
    pub(crate) r_i32: i32,
    /// Width of the marginal design block.
    pub(crate) p_m_i32: i32,
    /// Width of the logslope design block.
    pub(crate) p_g_i32: i32,
    /// Total joint-β width.
    pub(crate) p_total_i32: i32,
    /// First joint-β column of the `h` block (0 when absent).
    pub(crate) h_block_start: i32,
    /// Number of `h` block columns (0 when absent).
    pub(crate) h_block_len: i32,
    /// First joint-β column of the `w` block (0 when absent).
    pub(crate) w_block_start: i32,
    /// Number of `w` block columns (0 when absent).
    pub(crate) w_block_len: i32,
    /// First primary coordinate of the `h` block (0 when absent).
    pub(crate) h_primary_start: i32,
    /// First primary coordinate of the `w` block (0 when absent).
    pub(crate) w_primary_start: i32,
    /// Rows handled by each CTA of the partial kernel.
    pub(crate) rows_per_cta: i32,
    /// Number of CTAs, i.e. rows in the `[num_chunks, p_total]` partial.
    pub(crate) num_chunks: usize,
}
2198
2199#[cfg(target_os = "linux")]
2200impl PreparedBmsFlexRowLaunchArgs {
2201    pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2202        let p_total = storage.block.p_total;
2203        let num_chunks = num_hvp_chunks(storage.n);
2204        PreparedBmsFlexRowLaunchArgs {
2205            n_i32: storage.n as i32,
2206            r_i32: storage.r as i32,
2207            p_m_i32: storage.block.p_m as i32,
2208            p_g_i32: storage.block.p_g as i32,
2209            p_total_i32: p_total as i32,
2210            h_block_start: storage
2211                .block
2212                .h
2213                .as_ref()
2214                .map(|r| r.start as i32)
2215                .unwrap_or(0),
2216            h_block_len: storage
2217                .block
2218                .h
2219                .as_ref()
2220                .map(|r| r.len() as i32)
2221                .unwrap_or(0),
2222            w_block_start: storage
2223                .block
2224                .w
2225                .as_ref()
2226                .map(|r| r.start as i32)
2227                .unwrap_or(0),
2228            w_block_len: storage
2229                .block
2230                .w
2231                .as_ref()
2232                .map(|r| r.len() as i32)
2233                .unwrap_or(0),
2234            h_primary_start: storage
2235                .primary
2236                .h
2237                .as_ref()
2238                .map(|r| r.start as i32)
2239                .unwrap_or(0),
2240            w_primary_start: storage
2241                .primary
2242                .w
2243                .as_ref()
2244                .map(|r| r.start as i32)
2245                .unwrap_or(0),
2246            rows_per_cta: HVP_ROWS_PER_CTA as i32,
2247            num_chunks,
2248        }
2249    }
2250}
2251
2252/// Shared partial+reduce engine behind every joint-β launcher.
2253///
2254/// Allocates the `[num_chunks, p_total]` partial buffer, loads the mode's
2255/// partial kernel plus the common `bms_flex_row_hvp_reduce`, builds both
2256/// launch configs from a single [`PreparedBmsFlexRowLaunchArgs`], launches the
2257/// partial kernel (binding `d_v` only for the HVP modes), and launches the
2258/// reduction into caller-supplied `d_out`.
2259///
2260/// **No** `synchronize()` or DtoH is performed here — the surrounding helper
2261/// decides whether to keep the result on-stream (device-resident PCG hot path)
2262/// or sync + download it to the host. `ctx` is a short error-context tag woven
2263/// into every `DriverCallFailed` reason so failures stay attributable to the
2264/// originating entry point.
2265#[cfg(target_os = "linux")]
2266pub(crate) fn run_bms_flex_row_partial_reduce(
2267    storage: &DeviceResidentRowHess,
2268    mode: BmsFlexRowLaunchMode,
2269    d_v: Option<&CudaSlice<f64>>,
2270    d_out: &mut CudaSlice<f64>,
2271    ctx: &str,
2272) -> Result<(), GpuError> {
2273    let backend = HvpKernelBackend::probe()?;
2274    let stream = backend.stream.clone();
2275    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2276    let p_total = storage.block.p_total;
2277
2278    let mut d_partial = stream
2279        .alloc_zeros::<f64>(args.num_chunks * p_total)
2280        .map_err(|err| GpuError::DriverCallFailed {
2281            reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2282        })?;
2283
2284    let partial_kernel_name = mode.partial_kernel_name();
2285    let part_func = backend
2286        .module
2287        .load_function(partial_kernel_name)
2288        .map_err(|err| GpuError::DriverCallFailed {
2289            reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2290        })?;
2291    let red_func = backend
2292        .module
2293        .load_function("bms_flex_row_hvp_reduce")
2294        .map_err(|err| GpuError::DriverCallFailed {
2295            reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2296        })?;
2297
2298    let cfg_part = LaunchConfig {
2299        grid_dim: (args.num_chunks as u32, 1, 1),
2300        block_dim: (HVP_THREADS, 1, 1),
2301        shared_mem_bytes: 0,
2302    };
2303    let mut builder = stream.launch_builder(&part_func);
2304    builder
2305        .arg(&args.n_i32)
2306        .arg(&args.r_i32)
2307        .arg(&args.p_m_i32)
2308        .arg(&args.p_g_i32)
2309        .arg(&args.p_total_i32)
2310        .arg(&args.h_block_start)
2311        .arg(&args.h_block_len)
2312        .arg(&args.w_block_start)
2313        .arg(&args.w_block_len)
2314        .arg(&args.h_primary_start)
2315        .arg(&args.w_primary_start)
2316        .arg(&args.rows_per_cta)
2317        .arg(&storage.hess)
2318        .arg(&storage.marginal_design)
2319        .arg(&storage.logslope_design);
2320    if let Some(d_v) = d_v {
2321        builder.arg(d_v);
2322    }
2323    builder.arg(&mut d_partial);
2324    // SAFETY: every device pointer above either comes from `storage` (whose
2325    // capacities were established by
2326    // `launch_bms_flex_row_kernel_device_resident`) or was just allocated here
2327    // (`d_partial` = num_chunks * p_total). `d_v`, when bound, is length-checked
2328    // by the calling adapter against `p_total`. The diagonal partial kernel
2329    // takes no direction argument, matching `d_v == None`. Scalar args are i32
2330    // by-value.
2331    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2332        reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2333    })?;
2334
2335    let red_threads: u32 = REDUCTION_THREADS;
2336    let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2337    let cfg_red = LaunchConfig {
2338        grid_dim: (red_blocks, 1, 1),
2339        block_dim: (red_threads, 1, 1),
2340        shared_mem_bytes: 0,
2341    };
2342    let num_chunks_i32 = args.num_chunks as i32;
2343    let mut builder = stream.launch_builder(&red_func);
2344    builder
2345        .arg(&num_chunks_i32)
2346        .arg(&args.p_total_i32)
2347        .arg(&d_partial)
2348        .arg(d_out);
2349    // SAFETY: `d_partial` was just populated by the partial kernel above;
2350    // `d_out` is `p_total` doubles (length-checked / allocated by the calling
2351    // adapter); both scalar args fit i32.
2352    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2353        reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2354    })?;
2355    // `d_partial` drops at end of fn; cudarc keeps the alloc alive until the
2356    // stream is done with it, so the reduce kernel completes safely.
2357    drop(d_partial);
2358    Ok(())
2359}
2360
2361/// Host-returning joint-β launcher shared by [`launch_bms_flex_row_hvp`]
2362/// ([`BmsFlexRowLaunchMode::HvpHostOut`]) and [`launch_bms_flex_row_diagonal`]
2363/// ([`BmsFlexRowLaunchMode::DiagonalHostOut`]).
2364///
2365/// Probes the backend, allocates the `p_total`-double output on its stream,
2366/// optionally uploads a host direction `v` (HVP modes; `None` for the diagonal
2367/// mode, which takes no direction), runs the shared partial+reduce engine, then
2368/// synchronizes and downloads the reduced image to a host `Vec<f64>`. `ctx`
2369/// tags every `DriverCallFailed` reason with the originating entry point.
2370#[cfg(target_os = "linux")]
2371pub(crate) fn launch_bms_flex_row_host(
2372    storage: &DeviceResidentRowHess,
2373    mode: BmsFlexRowLaunchMode,
2374    v: Option<&[f64]>,
2375    ctx: &str,
2376) -> Result<Vec<f64>, GpuError> {
2377    let p_total = storage.block.p_total;
2378    if let Some(v) = v {
2379        if v.len() != p_total {
2380            return Err(GpuError::DriverCallFailed {
2381                reason: format!(
2382                    "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2383                    v.len()
2384                ),
2385            });
2386        }
2387    }
2388
2389    let backend = HvpKernelBackend::probe()?;
2390    let stream = backend.stream.clone();
2391
2392    let d_v = match v {
2393        Some(v) => Some(
2394            stream
2395                .clone_htod(v)
2396                .map_err(|err| GpuError::DriverCallFailed {
2397                    reason: format!("bms_flex_row {ctx} upload v: {err}"),
2398                })?,
2399        ),
2400        None => None,
2401    };
2402    let mut d_out =
2403        stream
2404            .alloc_zeros::<f64>(p_total)
2405            .map_err(|err| GpuError::DriverCallFailed {
2406                reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2407            })?;
2408
2409    run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2410
2411    stream
2412        .synchronize()
2413        .map_err(|err| GpuError::DriverCallFailed {
2414            reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2415        })?;
2416    stream
2417        .clone_dtoh(&d_out)
2418        .map_err(|err| GpuError::DriverCallFailed {
2419            reason: format!("bms_flex_row {ctx} download out: {err}"),
2420        })
2421}
2422
2423#[cfg(target_os = "linux")]
2424pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2425    storage: &DeviceResidentRowHess,
2426    rhs_count: usize,
2427    v_rhs_len: usize,
2428    out_len: Option<usize>,
2429    ctx: &str,
2430) -> Result<usize, GpuError> {
2431    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2432        return Err(GpuError::DriverCallFailed {
2433            reason: format!(
2434                "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2435            ),
2436        });
2437    }
2438    let p_total = storage.block.p_total;
2439    let rhs_elems = rhs_count
2440        .checked_mul(p_total)
2441        .ok_or_else(|| GpuError::DriverCallFailed {
2442            reason: format!(
2443                "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2444            ),
2445        })?;
2446    if v_rhs_len != rhs_elems {
2447        return Err(GpuError::DriverCallFailed {
2448            reason: format!(
2449                "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2450            ),
2451        });
2452    }
2453    if let Some(out_len) = out_len
2454        && out_len != rhs_elems
2455    {
2456        return Err(GpuError::DriverCallFailed {
2457            reason: format!(
2458                "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2459            ),
2460        });
2461    }
2462    Ok(rhs_elems)
2463}
2464
2465/// Transient device bytes for a multi-RHS HVP launch, excluding persistent
2466/// row-Hessian/design storage. Scratch scales with
2467/// `rhs_count * num_chunks * p_total`, not `rhs_count * n * r * r`.
2468#[cfg(target_os = "linux")]
2469pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2470    n: usize,
2471    p_total: usize,
2472    rhs_count: usize,
2473) -> Result<u64, GpuError> {
2474    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2475        return Err(GpuError::DriverCallFailed {
2476            reason: format!(
2477                "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2478            ),
2479        });
2480    }
2481    let num_chunks = num_hvp_chunks(n);
2482    let partial = rhs_count
2483        .checked_mul(num_chunks)
2484        .and_then(|v| v.checked_mul(p_total))
2485        .ok_or_else(|| GpuError::DriverCallFailed {
2486            reason: format!(
2487                "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2488            ),
2489        })?;
2490    let rhs_vectors = rhs_count
2491        .checked_mul(p_total)
2492        .and_then(|v| v.checked_mul(2))
2493        .ok_or_else(|| GpuError::DriverCallFailed {
2494            reason: format!(
2495                "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2496            ),
2497        })?;
2498    let elems = partial
2499        .checked_add(rhs_vectors)
2500        .ok_or_else(|| GpuError::DriverCallFailed {
2501            reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2502        })?;
2503    Ok((elems * std::mem::size_of::<f64>()) as u64)
2504}
2505
2506#[cfg(target_os = "linux")]
2507pub(crate) fn run_bms_flex_row_multi_partial_reduce(
2508    storage: &DeviceResidentRowHess,
2509    rhs_count: usize,
2510    d_v_rhs: &CudaSlice<f64>,
2511    d_out: &mut CudaSlice<f64>,
2512    ctx: &str,
2513) -> Result<(), GpuError> {
2514    let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
2515        storage,
2516        rhs_count,
2517        d_v_rhs.len(),
2518        Some(d_out.len()),
2519        ctx,
2520    )?;
2521    let backend = HvpKernelBackend::probe()?;
2522    let stream = backend.stream.clone();
2523    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2524    let p_total = storage.block.p_total;
2525    let partial_len = rhs_count
2526        .checked_mul(args.num_chunks)
2527        .and_then(|v| v.checked_mul(p_total))
2528        .ok_or_else(|| GpuError::DriverCallFailed {
2529            reason: format!(
2530                "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
2531                args.num_chunks
2532            ),
2533        })?;
2534
2535    let mut d_partial =
2536        stream
2537            .alloc_zeros::<f64>(partial_len)
2538            .map_err(|err| GpuError::DriverCallFailed {
2539                reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
2540            })?;
2541    let part_func = backend
2542        .module
2543        .load_function("bms_flex_row_hvp_multi_partial")
2544        .map_err(|err| GpuError::DriverCallFailed {
2545            reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
2546        })?;
2547    let red_func = backend
2548        .module
2549        .load_function("bms_flex_row_hvp_multi_reduce")
2550        .map_err(|err| GpuError::DriverCallFailed {
2551            reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
2552        })?;
2553
2554    let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
2555        reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
2556    })?;
2557    let cfg_part = LaunchConfig {
2558        grid_dim: (args.num_chunks as u32, 1, 1),
2559        block_dim: (HVP_THREADS, 1, 1),
2560        shared_mem_bytes: 0,
2561    };
2562    let mut builder = stream.launch_builder(&part_func);
2563    builder
2564        .arg(&args.n_i32)
2565        .arg(&args.r_i32)
2566        .arg(&args.p_m_i32)
2567        .arg(&args.p_g_i32)
2568        .arg(&args.p_total_i32)
2569        .arg(&args.h_block_start)
2570        .arg(&args.h_block_len)
2571        .arg(&args.w_block_start)
2572        .arg(&args.w_block_len)
2573        .arg(&args.h_primary_start)
2574        .arg(&args.w_primary_start)
2575        .arg(&args.rows_per_cta)
2576        .arg(&rhs_count_i32)
2577        .arg(&storage.hess)
2578        .arg(&storage.marginal_design)
2579        .arg(&storage.logslope_design)
2580        .arg(d_v_rhs)
2581        .arg(&mut d_partial);
2582    // SAFETY: storage buffers were validated at construction; `d_v_rhs` and
2583    // `d_out` have rhs_count*p_total elements, `d_partial` has
2584    // rhs_count*num_chunks*p_total, and rhs_count is bounded by fixed shared
2585    // array sizes in the CUDA source.
2586    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2587        reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
2588    })?;
2589
2590    let red_threads: u32 = REDUCTION_THREADS;
2591    let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
2592    let cfg_red = LaunchConfig {
2593        grid_dim: (red_blocks, 1, 1),
2594        block_dim: (red_threads, 1, 1),
2595        shared_mem_bytes: 0,
2596    };
2597    let num_chunks_i32 = args.num_chunks as i32;
2598    let mut builder = stream.launch_builder(&red_func);
2599    builder
2600        .arg(&num_chunks_i32)
2601        .arg(&args.p_total_i32)
2602        .arg(&rhs_count_i32)
2603        .arg(&d_partial)
2604        .arg(d_out);
2605    // SAFETY: the reduce kernel reads the just-populated partial buffer and
2606    // writes exactly rhs_count*p_total output entries.
2607    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2608        reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
2609    })?;
2610    drop(d_partial);
2611    Ok(())
2612}
2613
2614/// Device-resident multi-RHS HVP. `v_rhs` is row-major
2615/// `[rhs_count, p_total]`; the returned vector has the same layout.
2616#[cfg(target_os = "linux")]
2617pub(crate) fn launch_bms_flex_row_hvp_multi(
2618    storage: &DeviceResidentRowHess,
2619    v_rhs: &[f64],
2620    rhs_count: usize,
2621) -> Result<Vec<f64>, GpuError> {
2622    let rhs_elems =
2623        validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2624    let backend = HvpKernelBackend::probe()?;
2625    let stream = backend.stream.clone();
2626    let d_v_rhs = stream
2627        .clone_htod(v_rhs)
2628        .map_err(|err| GpuError::DriverCallFailed {
2629            reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2630        })?;
2631    let mut d_out =
2632        stream
2633            .alloc_zeros::<f64>(rhs_elems)
2634            .map_err(|err| GpuError::DriverCallFailed {
2635                reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2636            })?;
2637    run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2638    stream
2639        .synchronize()
2640        .map_err(|err| GpuError::DriverCallFailed {
2641            reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2642        })?;
2643    stream
2644        .clone_dtoh(&d_out)
2645        .map_err(|err| GpuError::DriverCallFailed {
2646            reason: format!("bms_flex_row hvp_multi download out: {err}"),
2647        })
2648}
2649
2650/// Device-output HVP. Runs `bms_flex_row_hvp_partial(_packed)` +
2651/// `bms_flex_row_hvp_reduce` on the storage's stream against caller-supplied
2652/// device-resident `d_v` (length `p_total` doubles), writing the result into
2653/// caller-supplied `d_out` (also `p_total` doubles). **No** `synchronize()`
2654/// or DtoH is performed — the caller is responsible for stream ordering
2655/// against any consumer that reads `d_out`.
2656///
2657/// This is the device-resident PCG hot path (Block 9 Phase 5): keeping the
2658/// HVP output on the stream lets the outer PCG loop chain axpy / dot /
2659/// preconditioner kernels back-to-back without a per-iter device sync.
2660#[cfg(target_os = "linux")]
2661pub(crate) fn launch_bms_flex_row_hvp_into_device(
2662    storage: &DeviceResidentRowHess,
2663    d_v: &CudaSlice<f64>,
2664    d_out: &mut CudaSlice<f64>,
2665) -> Result<(), GpuError> {
2666    let p_total = storage.block.p_total;
2667    if d_v.len() != p_total {
2668        return Err(GpuError::DriverCallFailed {
2669            reason: format!(
2670                "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2671                d_v.len(),
2672                p_total
2673            ),
2674        });
2675    }
2676    if d_out.len() != p_total {
2677        return Err(GpuError::DriverCallFailed {
2678            reason: format!(
2679                "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2680                d_out.len(),
2681                p_total
2682            ),
2683        });
2684    }
2685    // On-stream output: the shared engine launches partial+reduce into the
2686    // caller's `d_out` and returns without sync/DtoH, so the outer PCG loop can
2687    // chain device kernels against the result.
2688    run_bms_flex_row_partial_reduce(
2689        storage,
2690        BmsFlexRowLaunchMode::HvpDeviceOut,
2691        Some(d_v),
2692        d_out,
2693        "hvp_into_device",
2694    )
2695}
2696
2697/// Launch the device-resident HVP kernel. Returns the host-side joint β image
2698/// of length `block.p_total`.
2699#[cfg(target_os = "linux")]
2700pub(crate) fn launch_bms_flex_row_hvp(
2701    storage: &DeviceResidentRowHess,
2702    v: &[f64],
2703) -> Result<Vec<f64>, GpuError> {
2704    launch_bms_flex_row_hvp_multi(storage, v, 1)
2705}
2706
2707/// Launch the device-resident diagonal kernel. Returns the host-side joint
2708/// β diagonal of length `block.p_total`.
2709#[cfg(target_os = "linux")]
2710pub(crate) fn launch_bms_flex_row_diagonal(
2711    storage: &DeviceResidentRowHess,
2712) -> Result<Vec<f64>, GpuError> {
2713    launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2714}
2715
/// Block 9 Phase 6 — hard cap on `p_total` for the dense joint-Hessian
/// device kernel. The per-CTA shared-memory accumulator is `p_total² * 8`
/// bytes. The V100 default per-block shared cap is 48 KiB, so the largest
/// safe `p_total` here is `floor(sqrt(48 KiB / 8)) = floor(sqrt(6144)) = 78`.
/// We round down to the largest multiple of 8 below that (72) for predictable
/// launch geometry. At the cap the accumulator is `72² * 8 = 41 472` bytes,
/// which leaves headroom under the 49 152-byte limit.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;
2723
/// Number of rows each dense-block CTA processes. It is passed to the dense
/// partial kernel as its `rows_per_cta` argument and also sets the grid size
/// `num_chunks = ceil(n / DENSE_BLOCK_ROWS_PER_CTA)`.
///
/// This is smaller than the HVP `HVP_ROWS_PER_CTA = 256` because the per-row
/// inner loop is `O(r² * (p_m + p_g + h_block_len + w_block_len))` rather than
/// `O(r²)`. Fewer rows per CTA keeps the per-CTA wall time short and lets
/// grid occupancy scale with `num_chunks`.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2731
2732/// Launch the Phase-6 dense joint-Hessian block kernel. Returns the
2733/// host-side `[p_total, p_total]` row-major joint H as a `Vec<f64>`
2734/// (length `p_total²`).
2735///
2736/// **Not the default Newton path.** Production Newton uses HVP (Phase 2)
2737/// and never materialises the full dense Hessian. This entry exists for:
2738///   * exact-REML logdet (`log|H|`) when the unified evaluator wants to
2739///     factor H directly instead of going through the matrix-free path;
2740///   * diagnostic dumps that compare the GPU dense build against the CPU
2741///     `BernoulliMarginalSlopeFamily::fused_gradient_dense` reference;
2742///   * small-`p` debug routes where it is cheaper to factor + solve dense
2743///     than to run a PCG.
2744///
2745/// The kernel rejects `p_total > DENSE_BLOCK_MAX_P` cleanly because the
2746/// per-CTA shared-memory accumulator (`p_total² * 8` bytes) would exceed
2747/// the V100 48 KiB/block cap above that threshold.
2748#[cfg(target_os = "linux")]
2749pub fn launch_bms_flex_row_dense_block(
2750    storage: &DeviceResidentRowHess,
2751) -> Result<Vec<f64>, GpuError> {
2752    let p_total = storage.block.p_total;
2753    if p_total == 0 {
2754        return Err(GpuError::DriverCallFailed {
2755            reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2756        });
2757    }
2758    if p_total > DENSE_BLOCK_MAX_P {
2759        return Err(GpuError::DriverCallFailed {
2760            reason: format!(
2761                "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2762                 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2763            ),
2764        });
2765    }
2766    let backend = HvpKernelBackend::probe()?;
2767    let stream = backend.stream.clone();
2768    let n = storage.n;
2769    let r = storage.r;
2770    let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2771    let num_chunks = n.div_ceil(rows_per_cta);
2772    let pp = p_total * p_total;
2773
2774    let mut d_partial =
2775        stream
2776            .alloc_zeros::<f64>(num_chunks * pp)
2777            .map_err(|err| GpuError::DriverCallFailed {
2778                reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2779            })?;
2780    let mut d_out = stream
2781        .alloc_zeros::<f64>(pp)
2782        .map_err(|err| GpuError::DriverCallFailed {
2783            reason: format!("bms_flex_row dense_block alloc out: {err}"),
2784        })?;
2785
2786    let part_func = backend
2787        .module
2788        .load_function("bms_flex_row_dense_block_partial")
2789        .map_err(|err| GpuError::DriverCallFailed {
2790            reason: format!("bms_flex_row dense_block load partial: {err}"),
2791        })?;
2792    let red_func = backend
2793        .module
2794        .load_function("bms_flex_row_dense_block_reduce")
2795        .map_err(|err| GpuError::DriverCallFailed {
2796            reason: format!("bms_flex_row dense_block load reduce: {err}"),
2797        })?;
2798
2799    let n_i32 = n as i32;
2800    let r_i32 = r as i32;
2801    let p_m_i32 = storage.block.p_m as i32;
2802    let p_g_i32 = storage.block.p_g as i32;
2803    let p_total_i32 = p_total as i32;
2804    let h_block_start = storage
2805        .block
2806        .h
2807        .as_ref()
2808        .map(|r| r.start as i32)
2809        .unwrap_or(0);
2810    let h_block_len = storage
2811        .block
2812        .h
2813        .as_ref()
2814        .map(|r| r.len() as i32)
2815        .unwrap_or(0);
2816    let w_block_start = storage
2817        .block
2818        .w
2819        .as_ref()
2820        .map(|r| r.start as i32)
2821        .unwrap_or(0);
2822    let w_block_len = storage
2823        .block
2824        .w
2825        .as_ref()
2826        .map(|r| r.len() as i32)
2827        .unwrap_or(0);
2828    let h_primary_start = storage
2829        .primary
2830        .h
2831        .as_ref()
2832        .map(|r| r.start as i32)
2833        .unwrap_or(0);
2834    let w_primary_start = storage
2835        .primary
2836        .w
2837        .as_ref()
2838        .map(|r| r.start as i32)
2839        .unwrap_or(0);
2840    let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2841    let num_chunks_u32 = num_chunks as u32;
2842
2843    // Per-CTA shmem accumulator: p_total² doubles.
2844    let shmem_bytes: u32 =
2845        u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2846            reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2847        })?;
2848
2849    let cfg_part = LaunchConfig {
2850        grid_dim: (num_chunks_u32, 1, 1),
2851        block_dim: (HVP_THREADS, 1, 1),
2852        shared_mem_bytes: shmem_bytes,
2853    };
2854    let mut builder = stream.launch_builder(&part_func);
2855    builder
2856        .arg(&n_i32)
2857        .arg(&r_i32)
2858        .arg(&p_m_i32)
2859        .arg(&p_g_i32)
2860        .arg(&p_total_i32)
2861        .arg(&h_block_start)
2862        .arg(&h_block_len)
2863        .arg(&w_block_start)
2864        .arg(&w_block_len)
2865        .arg(&h_primary_start)
2866        .arg(&w_primary_start)
2867        .arg(&rows_per_cta_i32)
2868        .arg(&storage.hess)
2869        .arg(&storage.marginal_design)
2870        .arg(&storage.logslope_design)
2871        .arg(&mut d_partial);
2872    // SAFETY: storage pointers have validated capacities; d_partial sized
2873    // num_chunks * pp doubles; dynamic shmem matches the kernel's `extern
2874    // __shared__` accumulator length.
2875    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2876        reason: format!("bms_flex_row dense_block partial launch: {err}"),
2877    })?;
2878
2879    let red_threads: u32 = REDUCTION_THREADS;
2880    let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2881    let cfg_red = LaunchConfig {
2882        grid_dim: (red_blocks, 1, 1),
2883        block_dim: (red_threads, 1, 1),
2884        shared_mem_bytes: 0,
2885    };
2886    let num_chunks_i32 = num_chunks as i32;
2887    let mut builder = stream.launch_builder(&red_func);
2888    builder
2889        .arg(&num_chunks_i32)
2890        .arg(&p_total_i32)
2891        .arg(&d_partial)
2892        .arg(&mut d_out);
2893    // SAFETY: d_partial just populated, d_out is pp doubles.
2894    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2895        reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2896    })?;
2897    stream
2898        .synchronize()
2899        .map_err(|err| GpuError::DriverCallFailed {
2900            reason: format!("bms_flex_row dense_block sync: {err}"),
2901        })?;
2902    stream
2903        .clone_dtoh(&d_out)
2904        .map_err(|err| GpuError::DriverCallFailed {
2905            reason: format!("bms_flex_row dense_block download: {err}"),
2906        })
2907}
2908
2909// Block 9 / V100-build unblock (2026-05-27): every test below either
2910// constructs `BmsFlexBlockLayout` / `BmsFlexPrimaryLayout` (Linux-only
2911// types) or drives a CUDA-dependent fixture, so gate the whole module
2912// `#[cfg(all(test, target_os = "linux"))]`. On macOS the structs are
2913// absent and these tests do not compile — the build.rs ban scanner
2914// explicitly rejects `#[cfg(any(..., test))]` on the struct definitions
2915// themselves as a dead-code escape hatch.
2916#[cfg(all(test, target_os = "linux"))]
2917mod tests {
2918    use super::*;
2919
    /// Borrowed kernel-input view over `buffers` describing one row with
    /// `r = 4`, `p_h = 1`, `p_w = 1` and `s_f = 1.0`.
    ///
    /// The shape fields are hard-coded rather than read from `buffers`, so the
    /// slice lengths only agree when `buffers` was built with
    /// `make_buffers(_, 4, 1, 1)`. Tests that need other shapes override the
    /// relevant fields with struct-update syntax (`..minimal_inputs(&buffers)`).
    pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
        BmsFlexRowKernelInputs {
            n_rows: 1,
            r: 4,
            p_h: 1,
            p_w: 1,
            q: &buffers.q,
            b: &buffers.b,
            mu_1: &buffers.mu_1,
            mu_2: &buffers.mu_2,
            z_obs: &buffers.z_obs,
            y: &buffers.y,
            w: &buffers.w,
            s_f: 1.0,
            cell_offsets: &buffers.cell_offsets,
            cell_c0: &buffers.cell_c0,
            cell_c1: &buffers.cell_c1,
            cell_c2: &buffers.cell_c2,
            cell_c3: &buffers.cell_c3,
            cell_a: &buffers.cell_a,
            cell_aa: &buffers.cell_aa,
            cell_r: &buffers.cell_r,
            cell_ar: &buffers.cell_ar,
            cell_sbb: &buffers.cell_sbb,
            cell_sbh: &buffers.cell_sbh,
            cell_sbw: &buffers.cell_sbw,
            // Moments are supplied from host memory in these tests.
            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
            chi_obs: &buffers.chi_obs,
            xi_obs: &buffers.xi_obs,
            rho_u: &buffers.rho_u,
            tau_u: &buffers.tau_u,
            r_uv: &buffers.r_uv,
        }
    }
2954
    /// Owned backing storage for the slices a `BmsFlexRowKernelInputs`
    /// borrows in these tests. Each field feeds the input field of the same
    /// name (see `minimal_inputs`); `cell_offsets` is `u32`, everything else
    /// is `f64`.
    pub(crate) struct TestBuffers {
        pub(crate) q: Vec<f64>,
        pub(crate) b: Vec<f64>,
        pub(crate) mu_1: Vec<f64>,
        pub(crate) mu_2: Vec<f64>,
        pub(crate) z_obs: Vec<f64>,
        pub(crate) y: Vec<f64>,
        pub(crate) w: Vec<f64>,
        pub(crate) cell_offsets: Vec<u32>,
        pub(crate) cell_c0: Vec<f64>,
        pub(crate) cell_c1: Vec<f64>,
        pub(crate) cell_c2: Vec<f64>,
        pub(crate) cell_c3: Vec<f64>,
        pub(crate) cell_a: Vec<f64>,
        pub(crate) cell_aa: Vec<f64>,
        pub(crate) cell_r: Vec<f64>,
        pub(crate) cell_ar: Vec<f64>,
        pub(crate) cell_sbb: Vec<f64>,
        pub(crate) cell_sbh: Vec<f64>,
        pub(crate) cell_sbw: Vec<f64>,
        pub(crate) cell_moments: Vec<f64>,
        pub(crate) chi_obs: Vec<f64>,
        pub(crate) xi_obs: Vec<f64>,
        pub(crate) rho_u: Vec<f64>,
        pub(crate) tau_u: Vec<f64>,
        pub(crate) r_uv: Vec<f64>,
    }
2982
    /// Allocates constant-filled buffers for a single row that owns all
    /// `n_cells` cells (`cell_offsets = [0, n_cells]`).
    ///
    /// Per-row vectors have length 1. Per-cell arrays are sized from
    /// `n_cells`, `r`, `p_h` and `p_w`, with 4 entries per cell and coordinate
    /// and `MOMENT_STRIDE` moments per cell. `rho_u` / `tau_u` have length `r`,
    /// and `r_uv` has length `r * r`. `r` must be at least 1 because `cell_r` /
    /// `cell_ar` are sized with `r - 1`.
    pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
        let cells = n_cells as usize;
        TestBuffers {
            q: vec![0.1; 1],
            b: vec![0.5; 1],
            mu_1: vec![0.3; 1],
            mu_2: vec![0.07; 1],
            z_obs: vec![0.0; 1],
            y: vec![1.0; 1],
            w: vec![1.0; 1],
            cell_offsets: vec![0, n_cells],
            cell_c0: vec![0.2; cells],
            cell_c1: vec![-0.1; cells],
            cell_c2: vec![0.05; cells],
            cell_c3: vec![-0.02; cells],
            cell_a: vec![0.1; cells * 4],
            cell_aa: vec![0.0; cells * 4],
            cell_r: vec![0.05; cells * (r - 1) * 4],
            cell_ar: vec![0.0; cells * (r - 1) * 4],
            cell_sbb: vec![0.0; cells * 4],
            cell_sbh: vec![0.0; cells * p_h * 4],
            cell_sbw: vec![0.0; cells * p_w * 4],
            cell_moments: vec![1.0; cells * MOMENT_STRIDE],
            chi_obs: vec![1.0; 1],
            xi_obs: vec![0.0; 1],
            rho_u: vec![0.0; r],
            tau_u: vec![0.0; r],
            r_uv: vec![0.0; r * r],
        }
    }
3013
3014    #[test]
3015    pub(crate) fn validate_accepts_minimal_inputs() {
3016        let buffers = make_buffers(2, 4, 1, 1);
3017        let inputs = minimal_inputs(&buffers);
3018        assert!(inputs.validate().is_ok());
3019    }
3020
3021    #[test]
3022    pub(crate) fn validate_rejects_r_above_max() {
3023        let r = MAX_R + 1;
3024        let p_h = (r - 2) / 2;
3025        let p_w = (r - 2) - p_h;
3026        let buffers = make_buffers(1, r, p_h, p_w);
3027        let bad_inputs = BmsFlexRowKernelInputs {
3028            r,
3029            p_h,
3030            p_w,
3031            rho_u: &buffers.rho_u, // length matches `r` we wrote
3032            tau_u: &buffers.tau_u,
3033            r_uv: &buffers.r_uv,
3034            cell_r: &buffers.cell_r,
3035            cell_ar: &buffers.cell_ar,
3036            cell_sbh: &buffers.cell_sbh,
3037            cell_sbw: &buffers.cell_sbw,
3038            ..minimal_inputs(&buffers)
3039        };
3040        let err = bad_inputs.validate().expect_err("r > MAX_R must fail");
3041        let msg = err.to_string();
3042        assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
3043    }
3044
3045    #[test]
3046    pub(crate) fn validate_rejects_mismatched_r_decomposition() {
3047        let buffers = make_buffers(1, 4, 1, 1);
3048        let bad_inputs = BmsFlexRowKernelInputs {
3049            r: 4,
3050            p_h: 1,
3051            p_w: 2, // inconsistent with r = 4
3052            ..minimal_inputs(&buffers)
3053        };
3054        let err = bad_inputs
3055            .validate()
3056            .expect_err("inconsistent r vs p_h+p_w must fail");
3057        let msg = err.to_string();
3058        assert!(msg.contains("p_h"), "got: {msg}");
3059        assert!(msg.contains("p_w"), "got: {msg}");
3060    }
3061
3062    #[test]
3063    pub(crate) fn validate_rejects_non_monotone_offsets() {
3064        // `minimal_inputs` hard-codes `n_rows = 1`, so the CSR-style row
3065        // pointer length is `n + 1 = 2`. Pin both `offsets[1] = total_cells`
3066        // and `cell_c0.len() = total_cells = 2` from `make_buffers(2, …)`,
3067        // then violate monotonicity by setting `offsets[0] > offsets[1]`;
3068        // every length / per-cell-count check is satisfied so the only
3069        // failure mode left is the monotonicity guard.
3070        let mut buffers = make_buffers(2, 4, 1, 1);
3071        buffers.cell_offsets = vec![5, 2];
3072        let inputs = minimal_inputs(&buffers);
3073        let err = inputs
3074            .validate()
3075            .expect_err("non-monotone offsets must fail");
3076        let msg = err.to_string();
3077        assert!(msg.contains("monotone"), "got: {msg}");
3078    }
3079
3080    #[test]
3081    pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
3082        let mut buffers = make_buffers(2, 4, 1, 1);
3083        buffers.cell_moments.pop(); // length now 2*10 - 1
3084        let inputs = minimal_inputs(&buffers);
3085        let err = inputs.validate().expect_err("short cell_moments must fail");
3086        let msg = err.to_string();
3087        assert!(msg.contains("cell_moments"), "got: {msg}");
3088    }
3089
3090    #[test]
3091    pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
3092        // Mac/Windows builds must surface a typed `DriverLibraryUnavailable`
3093        // rather than panicking or returning Ok. On Linux this test is
3094        // skipped because the kernel actually launches.
3095        #[cfg(target_os = "linux")]
3096        {
3097            // Linux builds may or may not have a device; the dispatcher
3098            // contract is that without a runtime, probe() returns
3099            // DriverLibraryUnavailable. Either outcome (NoDeviceKernel,
3100            // DriverLibraryUnavailable, or DriverCallFailed) is acceptable
3101            // here; success would mean the kernel actually ran which is a
3102            // V100-only outcome we don't gate the unit test on.
3103            let buffers = make_buffers(1, 4, 1, 1);
3104            let inputs = minimal_inputs(&buffers);
3105            match launch_bms_flex_row_kernel(inputs) {
3106                Ok(_) => { /* V100 host: real launch */ }
3107                Err(GpuError::DriverLibraryUnavailable { .. })
3108                | Err(GpuError::DriverCallFailed { .. })
3109                | Err(GpuError::DriverSymbolMissing { .. })
3110                | Err(GpuError::NoDeviceKernel { .. }) => { /* expected on CPU-only */ }
3111                Err(other) => panic!("unexpected GpuError variant: {other:?}"),
3112            }
3113        }
3114        #[cfg(not(target_os = "linux"))]
3115        {
3116            let buffers = make_buffers(1, 4, 1, 1);
3117            let inputs = minimal_inputs(&buffers);
3118            match launch_bms_flex_row_kernel(inputs) {
3119                Err(GpuError::DriverLibraryUnavailable { reason }) => {
3120                    assert!(
3121                        reason.contains("Linux-only"),
3122                        "expected Linux-only hint, got: {reason}"
3123                    );
3124                }
3125                other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
3126            }
3127        }
3128    }
3129
3130    #[test]
3131    pub(crate) fn s_f_must_be_positive_and_finite() {
3132        let buffers = make_buffers(1, 4, 1, 1);
3133        let mut inputs = minimal_inputs(&buffers);
3134        inputs.s_f = 0.0;
3135        match launch_bms_flex_row_kernel(inputs) {
3136            Err(GpuError::DriverCallFailed { reason }) => {
3137                assert!(reason.contains("s_f"), "got: {reason}");
3138            }
3139            other => panic!("expected DriverCallFailed for s_f=0, got {other:?}"),
3140        }
3141    }
3142
3143    // ── CPU oracle that mirrors ROW_KERNEL_BODY bit-for-bit ──────────────────
3144    //
3145    // `cpu_oracle_outputs` implements the same algebra as
3146    // `bms_flex_row_kernel` in ROW_KERNEL_BODY: per-cell `T_n` / `D` / `Q`
3147    // contractions, q-row override, IFT to `a_u` / `a_uv`, observed-point
3148    // assembly to `bar_e_u` / `bar_e_uv`, probit Mills, and the final
3149    // `out_grad` / `out_hess` writes. It takes the same
3150    // `BmsFlexRowKernelInputs` struct so a CUDA-equipped host can run both
3151    // paths off one bundle and check element-wise parity.
3152    //
3153    // Used by the GPU↔CPU parity test below; the test skips on non-Linux
3154    // hosts via cfg, but the oracle itself is platform-independent so the
3155    // macOS lib build can still type-check it.
3156
3157    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
3158    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
3159    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
3160
3161    pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
3162        if !x.is_finite() {
3163            return if x > 0.0 { 0.0 } else { f64::INFINITY };
3164        }
3165        if x <= 0.0 {
3166            return 1.0;
3167        }
3168        if x < 26.0 {
3169            let mut xx = x * x;
3170            if xx > 700.0 {
3171                xx = 700.0;
3172            }
3173            return xx.exp() * gam_gpu::numerics_host::erfc(x);
3174        }
3175        let inv = 1.0 / x;
3176        let inv2 = inv * inv;
3177        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
3178            + 6.5625 * inv2 * inv2 * inv2 * inv2;
3179        let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
3180        inv * poly * inv_sqrt_pi
3181    }
3182
3183    pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
3184        if x == f64::INFINITY {
3185            return (0.0, 0.0);
3186        }
3187        if x == f64::NEG_INFINITY {
3188            return (f64::NEG_INFINITY, f64::INFINITY);
3189        }
3190        if x.is_nan() {
3191            return (x, x);
3192        }
3193        // Single-algorithm region around and above 0. Both `log Φ(x)` and the
3194        // Mills ratio `φ/Φ` are computed from the SAME `erfc(-x/√2)` call, so
3195        // the oracle is C¹ across the x=0 seam (the prior split — erfcx-based
3196        // `-u²+ln(0.5·e^{u²}·erfc u)` for x<0 vs direct `ln(0.5·erfc(-x))` for
3197        // x≥0 — used two distinct float algorithms whose ~1e-7 disagreement at
3198        // x=0 corrupted a finite-difference reference straddling the seam, #838).
3199        //
3200        // The erfcx form is mathematically identical (e^{u²}·erfc(u)=erfcx(u),
3201        // so −u²+ln(0.5·erfcx u)=ln(0.5·erfc(−x/√2))); it is only needed deep in
3202        // the left tail (x ≲ −38), where `erfc(-x/√2)` underflows to 0 and the
3203        // exp/ln cancellation is the *only* way to keep `log Φ` finite. We move
3204        // the branch there, far from any region the kernel or its FD lock visits.
3205        const ORACLE_LEFT_TAIL_X: f64 = -37.0;
3206        if x >= ORACLE_LEFT_TAIL_X {
3207            let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
3208            if cdf < 1e-300 {
3209                cdf = 1e-300;
3210            }
3211            if cdf > 1.0 {
3212                cdf = 1.0;
3213            }
3214            let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
3215            (cdf.ln(), pdf / cdf)
3216        } else {
3217            let u = -x / ORACLE_SQRT_2;
3218            let mut ex = oracle_erfcx_nonnegative(u);
3219            if ex < 1e-300 {
3220                ex = 1e-300;
3221            }
3222            let log_cdf = -u * u + (0.5 * ex).ln();
3223            let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
3224            (log_cdf, sqrt_2_over_pi / ex)
3225        }
3226    }
3227
    /// Same outputs the device kernel writes: `(neglog, grad, hess)` per row.
    /// `grad` is row-major `n × r`, `hess` is row-major `n × r × r`.
    /// Mirrors `bms_flex_row_kernel` line-for-line so kernel + oracle diverge
    /// only if one side breaks parity. The floating-point evaluation order
    /// below is deliberate; do not reassociate sums when editing.
    ///
    /// Input layout, as indexed below:
    /// - `cell_offsets`: CSR row pointer. Row `i` owns cells
    ///   `cell_offsets[i]..cell_offsets[i + 1]`.
    /// - `cell_c0..cell_c3`: one cubic coefficient per cell.
    /// - `cell_a`, `cell_aa`, `cell_sbb`: 4 coefficients per cell.
    /// - `cell_r`, `cell_ar`: `r − 1` blocks of 4 per cell, one per coordinate
    ///   `u = 1..r`.
    /// - `cell_sbh` / `cell_sbw`: `p_h` / `p_w` blocks of 4 per cell. These are
    ///   the `(1, v)` second-derivative blocks for `v ∈ [2, 2 + p_h)` and
    ///   `v ∈ [2 + p_h, r)` respectively.
    /// - `cell_moments`: `MOMENT_STRIDE` moments per cell (host-resident only).
    /// - `rho_u`, `tau_u`: row-major `n × r`.
    /// - `r_uv`: row-major `n × r × r`; only the upper triangle (`u ≤ v`) is read.
    ///
    /// Rows whose accumulated `F_a` is non-finite or `≤ 0` are NaN-filled in
    /// all three outputs.
    ///
    /// NOTE(review): `inputs.q`, `inputs.b`, `inputs.z_obs`, and `inputs.s_f`
    /// are never read here. The observed predictor enters only through
    /// `bar_e_u[0]`. Confirm the device kernel likewise ignores them.
    ///
    /// # Panics
    /// If `cell_moments` is `CellMomentsSource::Device`, or on out-of-range
    /// indexing when `inputs` would not pass `validate()`.
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                // INVARIANT (not an `unsafe` SAFETY contract): this CPU oracle is
                // a host-only sanity checker invoked exclusively from
                // `#[cfg(test)] mod tests`. The kernel-launch path uses
                // `CellMomentsSource::Device(...)`; the oracle must never see
                // that variant. Reaching this arm means a test mis-wired its
                // fixture — surface it loudly at the call site.
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // ── per-cell sweep: accumulate F_u, F_au, F_uv, F_a, F_aa.
            // Fresh per-row scratch. f_uv is only filled on its upper triangle.
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
                // Reads m[0..=9]; fused multiply-add keeps each step single-rounded.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // D(R) = κ · Σ_{k=0..3} R_k · m_k.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}, grouped by p + q.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                // F_a += D(A_c); F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // First-order coordinate terms for u = 1..r (u = 0 is overridden later).
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                // Upper-triangular F_uv = D(S_uv) − Q(R_u, R_v), 1 ≤ u ≤ v < r.
                // S_uv is non-zero only on row u = 1: (1,1) uses S_bb, then the
                // p_h S_bh blocks, then the p_w S_bw blocks.
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // q-row overrides (mirror kernel lines 691–700).
            // NOTE(review): kernel line numbers drift with edits — re-confirm.
            // F_q = −mu_1, F_aq = 0, F_qv = F_vq = 0, F_qq = −mu_2.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // Degenerate F_a ⇒ NaN-fill (mirror kernel lines 703–706).
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            let inv_fa = 1.0 / f_a;

            // IFT first/second order.
            // a_u = −F_u / F_a (q-row written directly as mu_1 / F_a, which is
            // the same value since F_q = −mu_1). a_uv is filled symmetrically.
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Observed predictor jets.
            // bar_e_u = chi·a_u + rho_u;
            // bar_e_uv = chi·a_uv + xi·a_u·a_v + tau_u·a_v + a_u·tau_v + r_uv.
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Probit Mills + final writes.
            // s = 2y − 1 ∈ {−1, +1}; A = ∂neglog/∂e, B = ∂²neglog/∂e².
            // The upper-triangle value is copied to the lower triangle, so the
            // Hessian is bit-symmetric.
            let y = inputs.y[row];
            let w = inputs.w[row];
            let s = 2.0 * y - 1.0;
            let e_obs = bar_e_u[0];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3438
3439    /// Build a non-trivial fixture: `n = 4` rows, `r = 5` (p_h = 2, p_w = 1),
3440    /// 2–4 cells per row, distinct values so a structural bug in either path
3441    /// can't be masked by accidental cancellation.
3442    pub(crate) fn make_parity_buffers() -> TestBuffers {
3443        let n = 4_usize;
3444        let r = 5_usize;
3445        let p_h = 2_usize;
3446        let p_w = 1_usize;
3447        // Per-row cell counts: 2, 3, 4, 2 → total 11 cells.
3448        let row_cells: [u32; 4] = [2, 3, 4, 2];
3449        let mut cell_offsets = vec![0_u32; n + 1];
3450        for i in 0..n {
3451            cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
3452        }
3453        let total_cells = cell_offsets[n] as usize;
3454
3455        // Deterministic but varied generators (LCG-ish so each slot is distinct).
3456        let f = |seed: usize| -> f64 {
3457            let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
3458            0.1 + 0.4 * x
3459        };
3460
3461        let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
3462        let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
3463        let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
3464        let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
3465        let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
3466        let y = [1.0, 0.0, 1.0, 0.0].to_vec();
3467        let w = vec![1.0; n];
3468
3469        let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
3470        let cell_c1 = (0..total_cells)
3471            .map(|c| -f(c + 2002) * 0.5)
3472            .collect::<Vec<_>>();
3473        let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
3474        let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();
3475
3476        let cell_a = (0..total_cells * 4)
3477            .map(|i| f(i + 5005) * 0.3)
3478            .collect::<Vec<_>>();
3479        let cell_aa = (0..total_cells * 4)
3480            .map(|i| f(i + 6006) * 0.1)
3481            .collect::<Vec<_>>();
3482        let cell_r = (0..total_cells * (r - 1) * 4)
3483            .map(|i| f(i + 7007) * 0.2)
3484            .collect::<Vec<_>>();
3485        let cell_ar = (0..total_cells * (r - 1) * 4)
3486            .map(|i| f(i + 8008) * 0.05)
3487            .collect::<Vec<_>>();
3488        let cell_sbb = (0..total_cells * 4)
3489            .map(|i| f(i + 9009) * 0.08)
3490            .collect::<Vec<_>>();
3491        let cell_sbh = (0..total_cells * p_h * 4)
3492            .map(|i| f(i + 10_010) * 0.07)
3493            .collect::<Vec<_>>();
3494        let cell_sbw = (0..total_cells * p_w * 4)
3495            .map(|i| f(i + 11_011) * 0.06)
3496            .collect::<Vec<_>>();
3497        let cell_moments = (0..total_cells * MOMENT_STRIDE)
3498            .map(|i| 0.4 + 0.1 * f(i + 12_012))
3499            .collect::<Vec<_>>();
3500
3501        let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
3502        let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
3503        let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
3504        let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
3505        let r_uv = (0..n * r * r)
3506            .map(|i| 0.04 * f(i + 15_015))
3507            .collect::<Vec<_>>();
3508
3509        TestBuffers {
3510            q,
3511            b,
3512            mu_1,
3513            mu_2,
3514            z_obs,
3515            y,
3516            w,
3517            cell_offsets,
3518            cell_c0,
3519            cell_c1,
3520            cell_c2,
3521            cell_c3,
3522            cell_a,
3523            cell_aa,
3524            cell_r,
3525            cell_ar,
3526            cell_sbb,
3527            cell_sbh,
3528            cell_sbw,
3529            cell_moments,
3530            chi_obs,
3531            xi_obs,
3532            rho_u,
3533            tau_u,
3534            r_uv,
3535        }
3536    }
3537
3538    pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3539        BmsFlexRowKernelInputs {
3540            n_rows: 4,
3541            r: 5,
3542            p_h: 2,
3543            p_w: 1,
3544            q: &buffers.q,
3545            b: &buffers.b,
3546            mu_1: &buffers.mu_1,
3547            mu_2: &buffers.mu_2,
3548            z_obs: &buffers.z_obs,
3549            y: &buffers.y,
3550            w: &buffers.w,
3551            s_f: 1.0,
3552            cell_offsets: &buffers.cell_offsets,
3553            cell_c0: &buffers.cell_c0,
3554            cell_c1: &buffers.cell_c1,
3555            cell_c2: &buffers.cell_c2,
3556            cell_c3: &buffers.cell_c3,
3557            cell_a: &buffers.cell_a,
3558            cell_aa: &buffers.cell_aa,
3559            cell_r: &buffers.cell_r,
3560            cell_ar: &buffers.cell_ar,
3561            cell_sbb: &buffers.cell_sbb,
3562            cell_sbh: &buffers.cell_sbh,
3563            cell_sbw: &buffers.cell_sbw,
3564            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3565            chi_obs: &buffers.chi_obs,
3566            xi_obs: &buffers.xi_obs,
3567            rho_u: &buffers.rho_u,
3568            tau_u: &buffers.tau_u,
3569            r_uv: &buffers.r_uv,
3570        }
3571    }
3572
3573    /// Symmetry + finiteness of the CPU oracle. Runs on every host (Linux,
3574    /// macOS, CPU CI) since the oracle is platform-independent. Guarantees the
3575    /// reference path used by the GPU parity test is itself well-formed.
3576    #[test]
3577    pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
3578        let buffers = make_parity_buffers();
3579        let inputs = parity_inputs(&buffers);
3580        inputs
3581            .validate()
3582            .expect("parity fixture must satisfy validate()");
3583        let out = cpu_oracle_outputs(&inputs);
3584        let n = inputs.n_rows;
3585        let r = inputs.r;
3586        assert_eq!(out.neglog.len(), n);
3587        assert_eq!(out.grad.len(), n * r);
3588        assert_eq!(out.hess.len(), n * r * r);
3589        for row in 0..n {
3590            assert!(
3591                out.neglog[row].is_finite(),
3592                "row {row}: neglog must be finite, got {}",
3593                out.neglog[row]
3594            );
3595            for u in 0..r {
3596                let g = out.grad[row * r + u];
3597                assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
3598                for v in 0..r {
3599                    let huv = out.hess[row * r * r + u * r + v];
3600                    let hvu = out.hess[row * r * r + v * r + u];
3601                    assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
3602                    assert_eq!(
3603                        huv.to_bits(),
3604                        hvu.to_bits(),
3605                        "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
3606                    );
3607                }
3608            }
3609        }
3610    }
3611
3612    /// Independent finite-difference correctness lock on the oracle's probit
3613    /// Mills layer — the most optimizer-sensitive, drift-prone term in the
3614    /// whole row kernel (issue #415: "third/fourth-order derivative
3615    /// contractions drift silently … formulas are complex and
3616    /// optimizer-sensitive"). The device kernel's `bms_flex_row_kernel` and
3617    /// the host oracle's `cpu_oracle_outputs` both close out with the same
3618    /// Mills algebra:
3619    ///
3620    /// ```text
3621    ///     m       = s · e_obs ;  s = 2y − 1
3622    ///     A       = −w · s · λ(m)
3623    ///     B       =  w · λ(m) · (m + λ(m))
3624    ///     neglog  = −w · log Φ(s · e_obs)
3625    ///     g_u     = A · bar_e_u
3626    ///     H_uv    = B · bar_e_u · bar_e_v + A · bar_e_uv
3627    /// ```
3628    ///
3629    /// Holding the observed jets `bar_e` fixed, the row neglog is a function
3630    /// of the single scalar `e := bar_e_u[0] = e_obs`, and by the assembled
3631    /// formula `∂neglog/∂e = A` and `∂²neglog/∂e² = B`. This test reconstructs
3632    /// `A`, `B`, and `neglog` exactly as the oracle does (same
3633    /// `oracle_log_ndtr_and_mills`, same sign convention), then verifies the
3634    /// analytic `A`/`B` against high-order central differences of
3635    /// `e ↦ −w · log Φ(s·e)`. A drift in the kernel's Mills derivatives —
3636    /// which the device-parity test cannot catch because it checks the kernel
3637    /// *against the same (possibly-wrong) oracle* — fails here on every host,
3638    /// CUDA or not. Bounds are the genuine fifth-order central-difference
3639    /// truncation floor; they are not weakened to pass.
3640    #[test]
3641    pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
3642        // Probit neglog as the oracle assembles it, as a function of the
3643        // observed scalar predictor `e` with weight `w` and label `y`.
3644        let neglog_of = |e: f64, y: f64, w: f64| -> f64 {
3645            let s = 2.0 * y - 1.0;
3646            let (log_cdf, _) = oracle_log_ndtr_and_mills(s * e);
3647            -w * log_cdf
3648        };
3649        // Analytic first/second derivatives wrt `e` — the exact `A`/`B` the
3650        // kernel writes into `grad`/`hess`.
3651        let ab_of = |e: f64, y: f64, w: f64| -> (f64, f64) {
3652            let s = 2.0 * y - 1.0;
3653            let m_arg = s * e;
3654            let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
3655            let a_i = -w * s * lambda;
3656            let b_i = w * lambda * (m_arg + lambda);
3657            (a_i, b_i)
3658        };
3659
3660        // Sweep both labels (s = ±1), both tails of the predictor, and a
3661        // non-unit weight so every sign/scale path of the Mills algebra is
3662        // exercised. Points stay clear of the deep-tail asymptote where a
3663        // central-difference reference loses its own accuracy.
3664        let cases: [(f64, f64, f64); 12] = [
3665            (-1.6, 1.0, 1.0),
3666            (-0.7, 1.0, 1.0),
3667            (0.0, 1.0, 1.0),
3668            (0.9, 1.0, 1.0),
3669            (1.8, 1.0, 1.0),
3670            (-1.4, 0.0, 1.0),
3671            (-0.3, 0.0, 1.0),
3672            (0.0, 0.0, 1.0),
3673            (0.6, 0.0, 1.0),
3674            (1.5, 0.0, 1.0),
3675            (0.4, 1.0, 0.75),
3676            (-0.8, 0.0, 1.3),
3677        ];
3678        // Fifth-order central stencils; `h` chosen near the f64 sweet spot for
3679        // first/second derivatives of a smooth O(1) function.
3680        let h = 1e-3_f64;
3681        for (e, y, w) in cases {
3682            let (a_ana, b_ana) = ab_of(e, y, w);
3683
3684            let fp2 = neglog_of(e + 2.0 * h, y, w);
3685            let fp1 = neglog_of(e + h, y, w);
3686            let f0 = neglog_of(e, y, w);
3687            let fm1 = neglog_of(e - h, y, w);
3688            let fm2 = neglog_of(e - 2.0 * h, y, w);
3689
3690            // 5-point central first derivative: O(h⁴).
3691            let d1_fd = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * h);
3692            // 5-point central second derivative: O(h⁴).
3693            let d2_fd = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * h * h);
3694
3695            let a_abs = (a_ana - d1_fd).abs();
3696            let a_rel = a_abs / a_ana.abs().max(1.0);
3697            assert!(
3698                a_abs <= 5e-8 || a_rel <= 5e-8,
3699                "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
3700                 analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
3701            );
3702
3703            let b_abs = (b_ana - d2_fd).abs();
3704            let b_rel = b_abs / b_ana.abs().max(1.0);
3705            assert!(
3706                b_abs <= 5e-6 || b_rel <= 5e-6,
3707                "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
3708                 analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
3709            );
3710        }
3711    }
3712
3713    /// CPU↔GPU parity. Only runs end-to-end on a Linux host with a CUDA
3714    /// runtime; skips with a clear `eprintln!` on every other host so the
3715    /// always-on test suite stays green on the macOS dev box and CPU CI.
3716    ///
3717    /// On a CUDA host: drives the kernel through `launch_bms_flex_row_kernel`
3718    /// and the same `BmsFlexRowKernelInputs` through `cpu_oracle_outputs`,
3719    /// then asserts every element of `neglog`, `grad`, and `hess` agrees
3720    /// within `|Δ| <= 1e-8 + 1e-8·|cpu|` (absolute-or-relative).
3721    #[test]
3722    pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
3723        #[cfg(not(target_os = "linux"))]
3724        {
3725            eprintln!(
3726                "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
3727                 (CPU oracle exercised by sibling test)"
3728            );
3729            return;
3730        }
3731        #[cfg(target_os = "linux")]
3732        {
3733            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
3734                eprintln!(
3735                    "[bms_flex_row parity] no CUDA runtime — skipping device \
3736                     parity (CPU oracle exercised by sibling test)"
3737                );
3738                return;
3739            };
3740            let buffers = make_parity_buffers();
3741            let inputs_cpu = parity_inputs(&buffers);
3742            inputs_cpu
3743                .validate()
3744                .expect("parity fixture must satisfy validate()");
3745            let cpu_out = cpu_oracle_outputs(&inputs_cpu);
3746
3747            // Launch the device kernel against the same inputs.
3748            let inputs_gpu = parity_inputs(&buffers);
3749            let gpu_out = match launch_bms_flex_row_kernel(inputs_gpu) {
3750                Ok(out) => out,
3751                Err(err) => panic!(
3752                    "[bms_flex_row parity] launch failed on CUDA-selected host; \
3753                     device/oracle parity must fail loudly on GPU CI: {err}"
3754                ),
3755            };
3756
3757            let n = inputs_cpu.n_rows;
3758            let r = inputs_cpu.r;
3759            let tol_abs = 1e-8_f64;
3760            let tol_rel = 1e-8_f64;
3761            let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
3762                if cpu.is_nan() || gpu.is_nan() {
3763                    assert!(
3764                        cpu.is_nan() && gpu.is_nan(),
3765                        "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
3766                    );
3767                    return;
3768                }
3769                let diff = (cpu - gpu).abs();
3770                let tol = tol_abs + tol_rel * cpu.abs();
3771                assert!(
3772                    diff <= tol,
3773                    "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
3774                     cpu={cpu:.17e}, gpu={gpu:.17e}"
3775                );
3776            };
3777            assert_eq!(cpu_out.neglog.len(), gpu_out.neglog.len());
3778            assert_eq!(cpu_out.grad.len(), gpu_out.grad.len());
3779            assert_eq!(cpu_out.hess.len(), gpu_out.hess.len());
3780            for (i, (&c, &g)) in cpu_out.neglog.iter().zip(gpu_out.neglog.iter()).enumerate() {
3781                check_close("neglog", i, c, g);
3782            }
3783            for (i, (&c, &g)) in cpu_out.grad.iter().zip(gpu_out.grad.iter()).enumerate() {
3784                check_close("grad", i, c, g);
3785            }
3786            for (i, (&c, &g)) in cpu_out.hess.iter().zip(gpu_out.hess.iter()).enumerate() {
3787                check_close("hess", i, c, g);
3788            }
3789            // Spot-check exact symmetry on the GPU Hessian too.
3790            for row in 0..n {
3791                for u in 0..r {
3792                    for v in 0..r {
3793                        let a = gpu_out.hess[row * r * r + u * r + v];
3794                        let bb = gpu_out.hess[row * r * r + v * r + u];
3795                        assert_eq!(
3796                            a.to_bits(),
3797                            bb.to_bits(),
3798                            "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
3799                        );
3800                    }
3801                }
3802            }
3803        }
3804    }
3805
3806    #[test]
3807    pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
3808        // Guarantee the maintainer-facing parity reference comment survives
3809        // refactors of the NVRTC kernel source — the dispatcher wave that
3810        // wires this to bms_flex.rs cross-checks parity against the CPU
3811        // function named here.
3812        #[cfg(target_os = "linux")]
3813        assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
3814        #[cfg(target_os = "linux")]
3815        assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
3816    }
3817
3818    // ── Phase-3 HVP / diagonal CPU oracles + GPU parity tests ────────────────
3819
3820    /// CPU oracle for [`launch_bms_flex_row_hvp`]. Mirrors the device kernel
3821    /// element-for-element so the GPU parity test runs against the same algebra.
3822    pub(crate) fn cpu_oracle_bms_flex_row_hvp(
3823        row_hessians: &[f64],
3824        marginal_design: &[f64],
3825        logslope_design: &[f64],
3826        block: &BmsFlexBlockLayout,
3827        primary: &BmsFlexPrimaryLayout,
3828        n: usize,
3829        v: &[f64],
3830    ) -> Vec<f64> {
3831        let r = primary.r;
3832        let p_m = block.p_m;
3833        let p_g = block.p_g;
3834        assert_eq!(v.len(), block.p_total);
3835        assert_eq!(row_hessians.len(), n * r * r);
3836        assert_eq!(marginal_design.len(), n * p_m);
3837        assert_eq!(logslope_design.len(), n * p_g);
3838        let mut out = vec![0.0_f64; block.p_total];
3839        let mut row_dir = vec![0.0_f64; r];
3840        let mut action = vec![0.0_f64; r];
3841        for row in 0..n {
3842            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3843            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3844            let mut acc_q = 0.0_f64;
3845            for j in 0..p_m {
3846                acc_q += mrow[j] * v[j];
3847            }
3848            let mut acc_g = 0.0_f64;
3849            for j in 0..p_g {
3850                acc_g += grow[j] * v[p_m + j];
3851            }
3852            row_dir[0] = acc_q;
3853            row_dir[1] = acc_g;
3854            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3855                for (k, ii) in prange.clone().enumerate() {
3856                    row_dir[ii] = v[brange.start + k];
3857                }
3858            }
3859            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3860                for (k, ii) in prange.clone().enumerate() {
3861                    row_dir[ii] = v[brange.start + k];
3862                }
3863            }
3864            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3865            for u in 0..r {
3866                let mut acc = 0.0_f64;
3867                for v_idx in 0..r {
3868                    acc += h_slice[u * r + v_idx] * row_dir[v_idx];
3869                }
3870                action[u] = acc;
3871            }
3872            let a0 = action[0];
3873            for j in 0..p_m {
3874                out[j] += a0 * mrow[j];
3875            }
3876            let a1 = action[1];
3877            for j in 0..p_g {
3878                out[p_m + j] += a1 * grow[j];
3879            }
3880            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3881                for (k, ii) in prange.clone().enumerate() {
3882                    out[brange.start + k] += action[ii];
3883                }
3884            }
3885            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3886                for (k, ii) in prange.clone().enumerate() {
3887                    out[brange.start + k] += action[ii];
3888                }
3889            }
3890        }
3891        out
3892    }
3893
3894    pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
3895        row_hessians: &[f64],
3896        marginal_design: &[f64],
3897        logslope_design: &[f64],
3898        block: &BmsFlexBlockLayout,
3899        primary: &BmsFlexPrimaryLayout,
3900        n: usize,
3901    ) -> Vec<f64> {
3902        let r = primary.r;
3903        let p_m = block.p_m;
3904        let p_g = block.p_g;
3905        let mut out = vec![0.0_f64; block.p_total];
3906        for row in 0..n {
3907            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3908            let h00 = h_slice[0];
3909            let h11 = h_slice[r + 1];
3910            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3911            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3912            for j in 0..p_m {
3913                out[j] += h00 * mrow[j] * mrow[j];
3914            }
3915            for j in 0..p_g {
3916                out[p_m + j] += h11 * grow[j] * grow[j];
3917            }
3918            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3919                for (k, ii) in prange.clone().enumerate() {
3920                    out[brange.start + k] += h_slice[ii * r + ii];
3921                }
3922            }
3923            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3924                for (k, ii) in prange.clone().enumerate() {
3925                    out[brange.start + k] += h_slice[ii * r + ii];
3926                }
3927            }
3928        }
3929        out
3930    }
3931
3932    /// Hand-construct a small symmetric per-row Hessian + small designs and
3933    /// verify the CPU oracle satisfies the expected algebra. Platform-
3934    /// independent (runs on macOS / Linux without CUDA).
3935    #[test]
3936    pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
3937        let n = 4_usize;
3938        let r = 4_usize; // q, logslope, h(1), w(1)
3939        let p_m = 2_usize;
3940        let p_g = 2_usize;
3941        let p_h_dim = 1_usize;
3942        let p_w_dim = 1_usize;
3943        let p_total = p_m + p_g + p_h_dim + p_w_dim;
3944        let block = BmsFlexBlockLayout {
3945            p_m,
3946            p_g,
3947            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
3948            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
3949            p_total,
3950        };
3951        let primary = BmsFlexPrimaryLayout {
3952            h: Some(2..3),
3953            w: Some(3..4),
3954            r,
3955        };
3956        // Symmetric per-row Hessian: H_row[u,v] = (row + 1) * (1 + u + 2v) symmetrised.
3957        let mut row_hessians = vec![0.0_f64; n * r * r];
3958        for row in 0..n {
3959            for u in 0..r {
3960                for v in u..r {
3961                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
3962                    row_hessians[row * r * r + u * r + v] = val;
3963                    row_hessians[row * r * r + v * r + u] = val;
3964                }
3965            }
3966        }
3967        let mut marginal = vec![0.0_f64; n * p_m];
3968        for row in 0..n {
3969            for j in 0..p_m {
3970                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
3971            }
3972        }
3973        let mut logslope = vec![0.0_f64; n * p_g];
3974        for row in 0..n {
3975            for j in 0..p_g {
3976                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
3977            }
3978        }
3979        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
3980        let out = cpu_oracle_bms_flex_row_hvp(
3981            &row_hessians,
3982            &marginal,
3983            &logslope,
3984            &block,
3985            &primary,
3986            n,
3987            &v,
3988        );
3989        // Hand check the first marginal slot: out[0] = Σ_row action[0]·mrow[0].
3990        let mut expect_out_0 = 0.0_f64;
3991        for row in 0..n {
3992            let mrow = &marginal[row * p_m..(row + 1) * p_m];
3993            let grow = &logslope[row * p_g..(row + 1) * p_g];
3994            let mut row_dir = vec![0.0_f64; r];
3995            row_dir[0] = mrow[0] * v[0] + mrow[1] * v[1];
3996            row_dir[1] = grow[0] * v[p_m] + grow[1] * v[p_m + 1];
3997            row_dir[2] = v[p_m + p_g];
3998            row_dir[3] = v[p_m + p_g + p_h_dim];
3999            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4000            let mut action0 = 0.0_f64;
4001            // h_slice is the row-major r×r Hessian for this row; we want
4002            // row 0, i.e. entries (0, vv) for vv in 0..r, which lives at
4003            // `vv` in the flat layout.
4004            for vv in 0..r {
4005                action0 += h_slice[vv] * row_dir[vv];
4006            }
4007            expect_out_0 += action0 * mrow[0];
4008        }
4009        assert!(
4010            (out[0] - expect_out_0).abs() < 1e-12,
4011            "cpu oracle HVP out[0] mismatch: {} vs hand-check {}",
4012            out[0],
4013            expect_out_0
4014        );
4015        assert!(out.iter().all(|x| x.is_finite()));
4016        assert_eq!(out.len(), p_total);
4017    }
4018
4019    /// Diagonal oracle equals the explicit per-row design² accumulator.
4020    #[test]
4021    pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
4022        let n = 3_usize;
4023        let r = 4_usize;
4024        let p_m = 2_usize;
4025        let p_g = 2_usize;
4026        let p_h_dim = 1_usize;
4027        let p_w_dim = 1_usize;
4028        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4029        let block = BmsFlexBlockLayout {
4030            p_m,
4031            p_g,
4032            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4033            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4034            p_total,
4035        };
4036        let primary = BmsFlexPrimaryLayout {
4037            h: Some(2..3),
4038            w: Some(3..4),
4039            r,
4040        };
4041        let mut row_hessians = vec![0.0_f64; n * r * r];
4042        for row in 0..n {
4043            for u in 0..r {
4044                row_hessians[row * r * r + u * r + u] = 1.0 + (row as f64) + (u as f64) * 0.5;
4045            }
4046        }
4047        let mut marginal = vec![0.0_f64; n * p_m];
4048        let mut logslope = vec![0.0_f64; n * p_g];
4049        for row in 0..n {
4050            for j in 0..p_m {
4051                marginal[row * p_m + j] = 0.2 + (row as f64) * 0.3 + (j as f64) * 0.1;
4052            }
4053            for j in 0..p_g {
4054                logslope[row * p_g + j] = -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2;
4055            }
4056        }
4057        let out = cpu_oracle_bms_flex_row_diagonal(
4058            &row_hessians,
4059            &marginal,
4060            &logslope,
4061            &block,
4062            &primary,
4063            n,
4064        );
4065        // Hand check: out[0] = Σ_row H[row,0,0] · marginal[row,0]^2.
4066        let mut expect = 0.0_f64;
4067        for row in 0..n {
4068            let h00 = row_hessians[row * r * r];
4069            expect += h00 * marginal[row * p_m].powi(2);
4070        }
4071        assert!(
4072            (out[0] - expect).abs() < 1e-12,
4073            "out[0] {} vs {}",
4074            out[0],
4075            expect
4076        );
4077        // h slot = sum of H[row, 2, 2] across rows.
4078        let mut expect_h = 0.0_f64;
4079        for row in 0..n {
4080            expect_h += row_hessians[row * r * r + 2 * r + 2];
4081        }
4082        let h_slot = p_m + p_g;
4083        assert!(
4084            (out[h_slot] - expect_h).abs() < 1e-12,
4085            "h slot {} vs {}",
4086            out[h_slot],
4087            expect_h
4088        );
4089    }
4090
4091    /// GPU↔CPU parity for the HVP and diagonal kernels. Skips on non-Linux /
4092    /// no-CUDA hosts. Hand-constructs a small `DeviceResidentRowHess` by
4093    /// allocating the device slices directly, uploading the same arrays the
4094    /// CPU oracle consumes, then dispatching the device kernels.
4095    #[test]
4096    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
4097        #[cfg(not(target_os = "linux"))]
4098        {
4099            eprintln!(
4100                "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
4101                 (CPU oracle exercised by sibling tests)"
4102            );
4103        }
4104        #[cfg(target_os = "linux")]
4105        {
4106            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4107                eprintln!(
4108                    "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
4109                     parity"
4110                );
4111                return;
4112            };
4113            let n = 4_usize;
4114            let r = 4_usize;
4115            let p_m = 2_usize;
4116            let p_g = 2_usize;
4117            let p_h_dim = 1_usize;
4118            let p_w_dim = 1_usize;
4119            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4120            let block = BmsFlexBlockLayout {
4121                p_m,
4122                p_g,
4123                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4124                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4125                p_total,
4126            };
4127            let primary = BmsFlexPrimaryLayout {
4128                h: Some(2..3),
4129                w: Some(3..4),
4130                r,
4131            };
4132            let mut row_hessians = vec![0.0_f64; n * r * r];
4133            for row in 0..n {
4134                for u in 0..r {
4135                    for v in u..r {
4136                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4137                        row_hessians[row * r * r + u * r + v] = val;
4138                        row_hessians[row * r * r + v * r + u] = val;
4139                    }
4140                }
4141            }
4142            let mut marginal = vec![0.0_f64; n * p_m];
4143            for row in 0..n {
4144                for j in 0..p_m {
4145                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4146                }
4147            }
4148            let mut logslope = vec![0.0_f64; n * p_g];
4149            for row in 0..n {
4150                for j in 0..p_g {
4151                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4152                }
4153            }
4154            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4155            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4156                &row_hessians,
4157                &marginal,
4158                &logslope,
4159                &block,
4160                &primary,
4161                n,
4162                &v,
4163            );
4164            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4165                &row_hessians,
4166                &marginal,
4167                &logslope,
4168                &block,
4169                &primary,
4170                n,
4171            );
4172
4173            // Allocate a DeviceResidentRowHess by hand using the HVP backend's
4174            // stream + module so we don't need to drive the full BMS row kernel.
4175            // Past the GpuRuntime::global() Some-gate above: a probe/upload failure
4176            // here is a real device fault on a CUDA host, not a no-CUDA skip. Fail
4177            // loud (the device-PCG skip-pass class, eee12f6b2) — the old arms
4178            // returned and the test passed while exercising nothing.
4179            let backend = HvpKernelBackend::probe()
4180                .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
4181            let stream = backend.stream.clone();
4182            let d_h = stream
4183                .clone_htod(&row_hessians)
4184                .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
4185            let d_m = stream
4186                .clone_htod(&marginal)
4187                .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
4188            let d_g = stream
4189                .clone_htod(&logslope)
4190                .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
4191            let storage = DeviceResidentRowHess {
4192                hess: d_h,
4193                marginal_design: d_m,
4194                logslope_design: d_g,
4195                n,
4196                r,
4197                block: block.clone(),
4198                primary: primary.clone(),
4199
4200                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4201            };
4202            let gpu_hvp =
4203                launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
4204            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4205                .expect("diagonal kernel must launch on CUDA host");
4206            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4207            assert_eq!(gpu_diag.len(), cpu_diag.len());
4208            for i in 0..p_total {
4209                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4210                assert!(
4211                    diff <= 1e-10,
4212                    "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4213                    cpu_hvp[i],
4214                    gpu_hvp[i]
4215                );
4216                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4217                assert!(
4218                    ddiff <= 1e-10,
4219                    "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4220                    cpu_diag[i],
4221                    gpu_diag[i]
4222                );
4223            }
4224        }
4225    }
4226
4227    #[test]
4228    pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
4229        let n = 195_000_usize;
4230        let r = 20_usize;
4231        let p_total = 44_usize;
4232        let rhs_count = 4_usize;
4233        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4234            .expect("large-scale multi-RHS scratch budget");
4235        let per_rhs_full_row_cache =
4236            (n * r * r * std::mem::size_of::<f64>()) as u64 * rhs_count as u64;
4237        assert!(
4238            scratch < per_rhs_full_row_cache / 100,
4239            "multi-RHS scratch must tile by row chunks instead of materializing \
4240             a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
4241        );
4242        assert!(
4243            bms_flex_row_hvp_multi_scratch_bytes_for_shape(
4244                n,
4245                p_total,
4246                BMS_FLEX_ROW_HVP_MAX_RHS + 1
4247            )
4248            .is_err(),
4249            "multi-RHS launch must reject unbounded RHS counts"
4250        );
4251    }
4252
4253    #[test]
4254    pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
4255        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4256            eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
4257            return;
4258        };
4259        let n = 5_usize;
4260        let r = 4_usize;
4261        let p_m = 2_usize;
4262        let p_g = 2_usize;
4263        let p_h_dim = 1_usize;
4264        let p_w_dim = 1_usize;
4265        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4266        let rhs_count = 3_usize;
4267        let block = BmsFlexBlockLayout {
4268            p_m,
4269            p_g,
4270            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4271            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4272            p_total,
4273        };
4274        let primary = BmsFlexPrimaryLayout {
4275            h: Some(2..3),
4276            w: Some(3..4),
4277            r,
4278        };
4279        let mut row_hessians = vec![0.0_f64; n * r * r];
4280        for row in 0..n {
4281            for u in 0..r {
4282                for v in u..r {
4283                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4284                    row_hessians[row * r * r + u * r + v] = val;
4285                    row_hessians[row * r * r + v * r + u] = val;
4286                }
4287            }
4288        }
4289        let mut marginal = vec![0.0_f64; n * p_m];
4290        let mut logslope = vec![0.0_f64; n * p_g];
4291        for row in 0..n {
4292            for j in 0..p_m {
4293                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4294            }
4295            for j in 0..p_g {
4296                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4297            }
4298        }
4299        let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
4300        for rhs in 0..rhs_count {
4301            for j in 0..p_total {
4302                let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
4303                v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
4304            }
4305        }
4306
4307        // Past the GpuRuntime::global() Some-gate: a probe/upload failure here is a
4308        // real device fault on a CUDA host, not a no-CUDA skip — fail loud
4309        // (device-PCG skip-pass class, eee12f6b2).
4310        let backend = HvpKernelBackend::probe()
4311            .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
4312        let stream = backend.stream.clone();
4313        let d_h = stream
4314            .clone_htod(&row_hessians)
4315            .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
4316        let d_m = stream
4317            .clone_htod(&marginal)
4318            .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
4319        let d_g = stream
4320            .clone_htod(&logslope)
4321            .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
4322        let storage = DeviceResidentRowHess {
4323            hess: d_h,
4324            marginal_design: d_m,
4325            logslope_design: d_g,
4326            n,
4327            r,
4328            block: block.clone(),
4329            primary: primary.clone(),
4330
4331            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4332        };
4333        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4334            .expect("storage scratch budget");
4335        assert!(
4336            scratch < storage.bytes,
4337            "multi-RHS scratch should stay below resident cache bytes"
4338        );
4339        let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
4340            .expect("multi-RHS HVP kernel must launch on CUDA host");
4341        assert_eq!(gpu.len(), rhs_count * p_total);
4342        for rhs in 0..rhs_count {
4343            let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
4344            let cpu = cpu_oracle_bms_flex_row_hvp(
4345                &row_hessians,
4346                &marginal,
4347                &logslope,
4348                &block,
4349                &primary,
4350                n,
4351                v,
4352            );
4353            let single = launch_bms_flex_row_hvp(&storage, v)
4354                .expect("single-RHS HVP kernel must launch on CUDA host");
4355            for j in 0..p_total {
4356                let got = gpu[rhs * p_total + j];
4357                let diff = (cpu[j] - got).abs();
4358                assert!(
4359                    diff <= 1e-10,
4360                    "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
4361                    cpu[j],
4362                    got
4363                );
4364                assert_eq!(
4365                    got, single[j],
4366                    "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
4367                );
4368            }
4369        }
4370    }
4371
4372    /// Parity for the third launch mode — device-output HVP
4373    /// ([`launch_bms_flex_row_hvp_into_device`]) — which the
4374    /// `run_bms_flex_row_partial_reduce` unification routes through the same
4375    /// partial+reduce engine as the host-returning `_hvp` / `_diagonal`
4376    /// adapters. Confirms that keeping the result on-stream (no internal sync /
4377    /// DtoH) reaches bit-identical output to both the CPU oracle and the
4378    /// host-out adapter, so the engine's mode/output split is faithful.
4379    ///
4380    /// Skips cleanly on non-Linux / no-CUDA hosts using the convention shared
4381    /// with the sibling parity tests.
4382    #[test]
4383    pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
4384        #[cfg(not(target_os = "linux"))]
4385        {
4386            eprintln!(
4387                "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
4388                 CUDA parity (CPU oracle exercised by sibling tests)"
4389            );
4390        }
4391        #[cfg(target_os = "linux")]
4392        {
4393            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4394                eprintln!(
4395                    "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
4396                     skipping device parity"
4397                );
4398                return;
4399            };
4400            let n = 4_usize;
4401            let r = 4_usize;
4402            let p_m = 2_usize;
4403            let p_g = 2_usize;
4404            let p_h_dim = 1_usize;
4405            let p_w_dim = 1_usize;
4406            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4407            let block = BmsFlexBlockLayout {
4408                p_m,
4409                p_g,
4410                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4411                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4412                p_total,
4413            };
4414            let primary = BmsFlexPrimaryLayout {
4415                h: Some(2..3),
4416                w: Some(3..4),
4417                r,
4418            };
4419            let mut row_hessians = vec![0.0_f64; n * r * r];
4420            for row in 0..n {
4421                for u in 0..r {
4422                    for v in u..r {
4423                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4424                        row_hessians[row * r * r + u * r + v] = val;
4425                        row_hessians[row * r * r + v * r + u] = val;
4426                    }
4427                }
4428            }
4429            let mut marginal = vec![0.0_f64; n * p_m];
4430            for row in 0..n {
4431                for j in 0..p_m {
4432                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4433                }
4434            }
4435            let mut logslope = vec![0.0_f64; n * p_g];
4436            for row in 0..n {
4437                for j in 0..p_g {
4438                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4439                }
4440            }
4441            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4442            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4443                &row_hessians,
4444                &marginal,
4445                &logslope,
4446                &block,
4447                &primary,
4448                n,
4449                &v,
4450            );
4451
4452            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
4453            // real device faults on a CUDA host — fail loud (device-PCG class).
4454            let backend = HvpKernelBackend::probe().expect(
4455                "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
4456            );
4457            let stream = backend.stream.clone();
4458            let d_h = stream
4459                .clone_htod(&row_hessians)
4460                .expect("[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host");
4461            let d_m = stream.clone_htod(&marginal).expect(
4462                "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
4463            );
4464            let d_g = stream.clone_htod(&logslope).expect(
4465                "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
4466            );
4467            let storage = DeviceResidentRowHess {
4468                hess: d_h,
4469                marginal_design: d_m,
4470                logslope_design: d_g,
4471                n,
4472                r,
4473                block: block.clone(),
4474                primary: primary.clone(),
4475
4476                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4477            };
4478
4479            // Host-out adapter (allocates its own d_out, syncs + downloads).
4480            let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
4481                .expect("host-out HVP kernel must launch on CUDA host");
4482
4483            // Device-out adapter: caller owns d_v + d_out; the engine performs
4484            // no sync / DtoH, so we synchronize + download here.
4485            let d_v = stream
4486                .clone_htod(&v)
4487                .expect("upload direction for device-out HVP");
4488            let mut d_out = stream
4489                .alloc_zeros::<f64>(p_total)
4490                .expect("alloc device-out HVP output");
4491            launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
4492                .expect("device-out HVP kernel must launch on CUDA host");
4493            stream
4494                .synchronize()
4495                .expect("synchronize after device-out HVP");
4496            let device_out_hvp = stream
4497                .clone_dtoh(&d_out)
4498                .expect("download device-out HVP output");
4499
4500            assert_eq!(device_out_hvp.len(), cpu_hvp.len());
4501            assert_eq!(device_out_hvp.len(), host_out_hvp.len());
4502            for i in 0..p_total {
4503                let diff = (cpu_hvp[i] - device_out_hvp[i]).abs();
4504                assert!(
4505                    diff <= 1e-10,
4506                    "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
4507                    cpu_hvp[i],
4508                    device_out_hvp[i]
4509                );
4510                // Both adapters share the engine; the only difference is the
4511                // copy-back path, so they must be bit-identical.
4512                let host_diff = (host_out_hvp[i] - device_out_hvp[i]).abs();
4513                assert!(
4514                    host_diff == 0.0,
4515                    "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
4516                    host_out_hvp[i],
4517                    device_out_hvp[i]
4518                );
4519            }
4520        }
4521    }
4522
4523    /// Block 9 Phase 2 parity gate at the shape specified by the
4524    /// charter task: `n = 64`, `r = 20`, `p_total = 44`. Splits
4525    /// `p_total` as `p_m = 14`, `p_g = 12`, `p_h = 10`, `p_w = 8` so
4526    /// `r = 2 + p_h + p_w = 20` and every primary block participates
4527    /// in both the device pullback and the reduce pass. Tolerance is
4528    /// `|Δ| ≤ 1e-8` per the task description (looser than the 1e-10
4529    /// hand-fixture parity, since accumulation order across HVP CTAs
4530    /// differs from the CPU oracle's row-major sum even with the
4531    /// deterministic reduction policy).
4532    ///
4533    /// Skips cleanly on non-Linux and no-CUDA hosts using the same
4534    /// convention as the hand-fixture parity above.
4535    #[test]
4536    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
4537        #[cfg(not(target_os = "linux"))]
4538        {
4539            eprintln!(
4540                "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
4541                 skipping CUDA parity"
4542            );
4543        }
4544        #[cfg(target_os = "linux")]
4545        {
4546            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4547                eprintln!(
4548                    "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
4549                     skipping device parity"
4550                );
4551                return;
4552            };
4553            let n = 64_usize;
4554            let p_m = 14_usize;
4555            let p_g = 12_usize;
4556            let p_h_dim = 10_usize;
4557            let p_w_dim = 8_usize;
4558            let r = 2 + p_h_dim + p_w_dim;
4559            assert_eq!(r, 20);
4560            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4561            assert_eq!(p_total, 44);
4562            let block = BmsFlexBlockLayout {
4563                p_m,
4564                p_g,
4565                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4566                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4567                p_total,
4568            };
4569            let primary = BmsFlexPrimaryLayout {
4570                h: Some(2..2 + p_h_dim),
4571                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4572                r,
4573            };
4574
4575            // Deterministic symmetric per-row Hessians + designs +
4576            // direction. Same scrambling family as
4577            // `row_hessian_ops::tests::make_fixture` so any regression
4578            // surfaces consistently across the host-pinned and
4579            // device-resident parity tests.
4580            let mut row_hessians = vec![0.0_f64; n * r * r];
4581            for row in 0..n {
4582                let base = row * r * r;
4583                for u in 0..r {
4584                    for v in 0..r {
4585                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
4586                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
4587                        row_hessians[base + u * r + v] = a;
4588                    }
4589                }
4590                for u in 0..r {
4591                    for v in (u + 1)..r {
4592                        let upper = row_hessians[base + u * r + v];
4593                        let lower = row_hessians[base + v * r + u];
4594                        let sym = 0.5 * (upper + lower);
4595                        row_hessians[base + u * r + v] = sym;
4596                        row_hessians[base + v * r + u] = sym;
4597                    }
4598                    row_hessians[base + u * r + u] += r as f64;
4599                }
4600            }
4601            let mut marginal = vec![0.0_f64; n * p_m];
4602            for row in 0..n {
4603                for j in 0..p_m {
4604                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
4605                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
4606                }
4607            }
4608            let mut logslope = vec![0.0_f64; n * p_g];
4609            for row in 0..n {
4610                for j in 0..p_g {
4611                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
4612                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
4613                }
4614            }
4615            let v: Vec<f64> = (0..p_total)
4616                .map(|i| {
4617                    let seed = (i as f64) * 0.157 + 0.6;
4618                    seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
4619                })
4620                .collect();
4621
4622            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4623                &row_hessians,
4624                &marginal,
4625                &logslope,
4626                &block,
4627                &primary,
4628                n,
4629                &v,
4630            );
4631            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4632                &row_hessians,
4633                &marginal,
4634                &logslope,
4635                &block,
4636                &primary,
4637                n,
4638            );
4639
4640            let backend = match HvpKernelBackend::probe() {
4641                Ok(b) => b,
4642                Err(err) => {
4643                    eprintln!(
4644                        "[bms_flex_row hvp parity n64_r20_p44] backend probe \
4645                         failed: {err}"
4646                    );
4647                    return;
4648                }
4649            };
4650            let stream = backend.stream.clone();
4651            let d_h = match stream.clone_htod(&row_hessians) {
4652                Ok(s) => s,
4653                Err(err) => {
4654                    eprintln!(
4655                        "[bms_flex_row hvp parity n64_r20_p44] upload h \
4656                         failed: {err}"
4657                    );
4658                    return;
4659                }
4660            };
4661            let d_m = match stream.clone_htod(&marginal) {
4662                Ok(s) => s,
4663                Err(err) => {
4664                    eprintln!(
4665                        "[bms_flex_row hvp parity n64_r20_p44] upload marg \
4666                         failed: {err}"
4667                    );
4668                    return;
4669                }
4670            };
4671            let d_g = match stream.clone_htod(&logslope) {
4672                Ok(s) => s,
4673                Err(err) => {
4674                    eprintln!(
4675                        "[bms_flex_row hvp parity n64_r20_p44] upload logslope \
4676                         failed: {err}"
4677                    );
4678                    return;
4679                }
4680            };
4681            let storage = DeviceResidentRowHess {
4682                hess: d_h,
4683                marginal_design: d_m,
4684                logslope_design: d_g,
4685                n,
4686                r,
4687                block: block.clone(),
4688                primary: primary.clone(),
4689
4690                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4691            };
4692            let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
4693                .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
4694            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4695                .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
4696            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4697            assert_eq!(gpu_diag.len(), cpu_diag.len());
4698            for i in 0..p_total {
4699                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4700                assert!(
4701                    diff <= 1e-8,
4702                    "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4703                    cpu_hvp[i],
4704                    gpu_hvp[i]
4705                );
4706                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4707                assert!(
4708                    ddiff <= 1e-8,
4709                    "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4710                    cpu_diag[i],
4711                    gpu_diag[i]
4712                );
4713            }
4714        }
4715    }
4716
4717    /// Block 9 Phase 6 — small-fixture parity for the dense-block kernel
4718    /// against the host-side P_i pullback oracle.
4719    /// Verifies bit-equality (modulo reduction-order f.p. noise) between
4720    /// the device-resident dense build and the host accumulator over the
4721    /// same per-row Hessian + designs + P_i pullback.
4722    #[test]
4723    pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
4724        #[cfg(not(target_os = "linux"))]
4725        {
4726            eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
4727        }
4728        #[cfg(target_os = "linux")]
4729        {
4730            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4731                eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
4732                return;
4733            };
4734            // Small fixture: n=24, r=8 (2 + 3 + 3), p_total=18 (4+4+3+3).
4735            // Keeps the CPU pullback fast while still exercising every
4736            // primary slot (q, g, h, w).
4737            let n = 24_usize;
4738            let p_m = 4_usize;
4739            let p_g = 4_usize;
4740            let p_h_dim = 3_usize;
4741            let p_w_dim = 3_usize;
4742            let r = 2 + p_h_dim + p_w_dim;
4743            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4744            let block = BmsFlexBlockLayout {
4745                p_m,
4746                p_g,
4747                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4748                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4749                p_total,
4750            };
4751            let primary = BmsFlexPrimaryLayout {
4752                h: Some(2..2 + p_h_dim),
4753                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4754                r,
4755            };
4756
4757            let mut row_hessians = vec![0.0_f64; n * r * r];
4758            for row in 0..n {
4759                let base = row * r * r;
4760                for u in 0..r {
4761                    for v in 0..r {
4762                        let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (v as f64) * 0.47;
4763                        let a = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
4764                        row_hessians[base + u * r + v] = a;
4765                    }
4766                }
4767                for u in 0..r {
4768                    for v in (u + 1)..r {
4769                        let upper = row_hessians[base + u * r + v];
4770                        let lower = row_hessians[base + v * r + u];
4771                        let sym = 0.5 * (upper + lower);
4772                        row_hessians[base + u * r + v] = sym;
4773                        row_hessians[base + v * r + u] = sym;
4774                    }
4775                    row_hessians[base + u * r + u] += r as f64;
4776                }
4777            }
4778            let mut marginal = vec![0.0_f64; n * p_m];
4779            for row in 0..n {
4780                for j in 0..p_m {
4781                    let seed = (row as f64) * 0.083 + (j as f64) * 0.171 + 0.31;
4782                    marginal[row * p_m + j] = seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25;
4783                }
4784            }
4785            let mut logslope = vec![0.0_f64; n * p_g];
4786            for row in 0..n {
4787                for j in 0..p_g {
4788                    let seed = (row as f64) * 0.097 + (j as f64) * 0.143 - 0.15;
4789                    logslope[row * p_g + j] = seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2;
4790                }
4791            }
4792
4793            // CPU oracle — same pullback math the device kernel mirrors.
4794            let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
4795            let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
4796            let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
4797            let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
4798            let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
4799            let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
4800            let mut h_cpu = vec![0.0_f64; p_total * p_total];
4801            for row in 0..n {
4802                let mrow = &marginal[row * p_m..(row + 1) * p_m];
4803                let grow = &logslope[row * p_g..(row + 1) * p_g];
4804                let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
4805                // Build per-row phi (r length-p_total vectors).
4806                let mut phi = vec![vec![0.0_f64; p_total]; r];
4807                for k in 0..p_m {
4808                    phi[0][k] = mrow[k];
4809                }
4810                for k in 0..p_g {
4811                    phi[1][p_m + k] = grow[k];
4812                }
4813                for k in 0..h_block_len {
4814                    phi[h_primary_start + k][h_block_start + k] = 1.0;
4815                }
4816                for k in 0..w_block_len {
4817                    phi[w_primary_start + k][w_block_start + k] = 1.0;
4818                }
4819                for u in 0..r {
4820                    for v in 0..r {
4821                        let huv = hrow[u * r + v];
4822                        if huv == 0.0 {
4823                            continue;
4824                        }
4825                        for m in 0..p_total {
4826                            let pm = phi[u][m];
4827                            if pm == 0.0 {
4828                                continue;
4829                            }
4830                            let scaled = huv * pm;
4831                            for nn in 0..p_total {
4832                                h_cpu[m * p_total + nn] += scaled * phi[v][nn];
4833                            }
4834                        }
4835                    }
4836                }
4837            }
4838
4839            // Build a transient device-resident storage and launch the
4840            // dense-block kernel.
4841            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
4842            // real device faults on a CUDA host — fail loud (device-PCG class).
4843            let backend = HvpKernelBackend::probe().expect(
4844                "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
4845            );
4846            let stream = backend.stream.clone();
4847            let d_h = stream
4848                .clone_htod(&row_hessians)
4849                .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host");
4850            let d_m = stream
4851                .clone_htod(&marginal)
4852                .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host");
4853            let d_g = stream.clone_htod(&logslope).expect(
4854                "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
4855            );
4856            let storage = DeviceResidentRowHess {
4857                hess: d_h,
4858                marginal_design: d_m,
4859                logslope_design: d_g,
4860                n,
4861                r,
4862                block: block.clone(),
4863                primary: primary.clone(),
4864
4865                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4866            };
4867            let h_gpu = launch_bms_flex_row_dense_block(&storage)
4868                .expect("dense_block kernel must launch on CUDA host");
4869            assert_eq!(h_gpu.len(), p_total * p_total);
4870
4871            // Compare entry-by-entry with a tolerance that absorbs
4872            // reduction-order f.p. noise from the CTA chunk sum.
4873            let mut max_abs = 0.0_f64;
4874            for i in 0..p_total {
4875                for j in 0..p_total {
4876                    let a = h_cpu[i * p_total + j];
4877                    let b = h_gpu[i * p_total + j];
4878                    let diff = (a - b).abs();
4879                    if diff > max_abs {
4880                        max_abs = diff;
4881                    }
4882                    assert!(
4883                        diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
4884                        "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
4885                    );
4886                }
4887            }
4888            eprintln!(
4889                "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
4890            );
4891        }
4892    }
4893
4894    /// Block 9 final hill-climb gate — GPU HVP must be at least 5× faster
4895    /// than a Rayon-parallel CPU HVP at large-scale shape (n=195_000, r=20,
4896    /// p_total=44). This is the charter pass/fail metric for whether the
4897    /// device-resident row-Hessian path is a real perf win for the
4898    /// production marginal-slope fit.
4899    ///
4900    /// Methodology:
4901    ///   * Build the same deterministic fixture as the parity tests.
4902    ///   * GPU: median of `iters` `launch_bms_flex_row_hvp` wall-times
4903    ///     after `warmup` warm-up launches (kernel compile + L2 prime).
4904    ///   * CPU: median of `iters` `cpu_oracle_bms_flex_row_hvp` wall-times,
4905    ///     parallelised over rows via Rayon — this mirrors the actual
4906    ///     production CPU path in
4907    ///     `exact_newton_joint_hessian_matvec_from_cache` (which uses
4908    ///     `ROW_CHUNK_SIZE` chunked `into_par_iter()` for the same
4909    ///     contraction).
4910    ///   * Ratio = cpu_median / gpu_median; assert ratio >= 5.
4911    ///
4912    /// Skips on non-Linux / no-CUDA hosts.
4913    #[test]
4914    pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
4915        #[cfg(not(target_os = "linux"))]
4916        {
4917            eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
4918        }
4919        #[cfg(target_os = "linux")]
4920        {
4921            use rayon::prelude::*;
4922
4923            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4924                eprintln!(
4925                    "[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate"
4926                );
4927                return;
4928            };
4929            let n = 195_000_usize;
4930            let p_m = 14_usize;
4931            let p_g = 12_usize;
4932            let p_h_dim = 10_usize;
4933            let p_w_dim = 8_usize;
4934            let r = 2 + p_h_dim + p_w_dim;
4935            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4936            let block = BmsFlexBlockLayout {
4937                p_m,
4938                p_g,
4939                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4940                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4941                p_total,
4942            };
4943            let primary = BmsFlexPrimaryLayout {
4944                h: Some(2..2 + p_h_dim),
4945                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4946                r,
4947            };
4948
4949            // Same deterministic fixture as the Phase 4 large-scale benchmark.
4950            let mut row_hessians = vec![0.0_f64; n * r * r];
4951            for row in 0..n {
4952                let base = row * r * r;
4953                for u in 0..r {
4954                    for vv in 0..r {
4955                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
4956                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
4957                        row_hessians[base + u * r + vv] = a;
4958                    }
4959                }
4960                for u in 0..r {
4961                    for vv in (u + 1)..r {
4962                        let upper = row_hessians[base + u * r + vv];
4963                        let lower = row_hessians[base + vv * r + u];
4964                        let sym = 0.5 * (upper + lower);
4965                        row_hessians[base + u * r + vv] = sym;
4966                        row_hessians[base + vv * r + u] = sym;
4967                    }
4968                    row_hessians[base + u * r + u] += r as f64;
4969                }
4970            }
4971            let mut marginal = vec![0.0_f64; n * p_m];
4972            for row in 0..n {
4973                for j in 0..p_m {
4974                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
4975                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
4976                }
4977            }
4978            let mut logslope = vec![0.0_f64; n * p_g];
4979            for row in 0..n {
4980                for j in 0..p_g {
4981                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
4982                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
4983                }
4984            }
4985            let v: Vec<f64> = (0..p_total)
4986                .map(|i| {
4987                    let seed = (i as f64) * 0.157 + 0.6;
4988                    seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
4989                })
4990                .collect();
4991
4992            // ── GPU side: upload once, time HVP launches ─────────────
4993            let backend = match HvpKernelBackend::probe() {
4994                Ok(b) => b,
4995                Err(err) => {
4996                    eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
4997                    return;
4998                }
4999            };
5000            let stream = backend.stream.clone();
5001            let d_h = match stream.clone_htod(&row_hessians) {
5002                Ok(s) => s,
5003                Err(err) => {
5004                    eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
5005                    return;
5006                }
5007            };
5008            let d_m = match stream.clone_htod(&marginal) {
5009                Ok(s) => s,
5010                Err(err) => {
5011                    eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
5012                    return;
5013                }
5014            };
5015            let d_g = match stream.clone_htod(&logslope) {
5016                Ok(s) => s,
5017                Err(err) => {
5018                    eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
5019                    return;
5020                }
5021            };
5022            let storage = DeviceResidentRowHess {
5023                hess: d_h,
5024                marginal_design: d_m,
5025                logslope_design: d_g,
5026                n,
5027                r,
5028                block: block.clone(),
5029                primary: primary.clone(),
5030
5031                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5032            };
5033            let warmup: usize = 3;
5034            let iters: usize = 15;
5035            for _ in 0..warmup {
5036                let out =
5037                    launch_bms_flex_row_hvp(&storage, &v).expect("warmup GPU HVP must launch");
5038                assert_eq!(out.len(), p_total);
5039            }
5040            let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5041            for _ in 0..iters {
5042                let t0 = std::time::Instant::now();
5043                let out = launch_bms_flex_row_hvp(&storage, &v).expect("GPU HVP must launch");
5044                gpu_us.push(t0.elapsed().as_micros());
5045                assert_eq!(out.len(), p_total);
5046            }
5047            gpu_us.sort_unstable();
5048            let gpu_median = gpu_us[iters / 2];
5049
5050            // ── CPU side: chunked Rayon HVP over rows, mirroring the
5051            //    production `exact_newton_joint_hessian_matvec_from_cache`
5052            //    parallelisation pattern (ROW_CHUNK_SIZE-row chunks,
5053            //    try_fold + try_reduce). The per-chunk worker calls the
5054            //    single-threaded oracle on its row slice.
5055            const CHUNK_ROWS: usize = 4096;
5056            let cpu_hvp_parallel = || -> Vec<f64> {
5057                let nchunks = n.div_ceil(CHUNK_ROWS);
5058                (0..nchunks)
5059                    .into_par_iter()
5060                    .fold(
5061                        || vec![0.0_f64; p_total],
5062                        |mut acc, ci| {
5063                            let lo = ci * CHUNK_ROWS;
5064                            let hi = (lo + CHUNK_ROWS).min(n);
5065                            let m = hi - lo;
5066                            let partial = cpu_oracle_bms_flex_row_hvp(
5067                                &row_hessians[lo * r * r..hi * r * r],
5068                                &marginal[lo * p_m..hi * p_m],
5069                                &logslope[lo * p_g..hi * p_g],
5070                                &block,
5071                                &primary,
5072                                m,
5073                                &v,
5074                            );
5075                            for (a, &p) in acc.iter_mut().zip(partial.iter()) {
5076                                *a += p;
5077                            }
5078                            acc
5079                        },
5080                    )
5081                    .reduce(
5082                        || vec![0.0_f64; p_total],
5083                        |mut a, b| {
5084                            for (ax, bx) in a.iter_mut().zip(b.iter()) {
5085                                *ax += *bx;
5086                            }
5087                            a
5088                        },
5089                    )
5090            };
5091            // Warmup once to populate L3 / steady-state Rayon thread pool.
5092            let warm = cpu_hvp_parallel();
5093            assert_eq!(warm.len(), p_total);
5094            let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5095            for _ in 0..iters {
5096                let t0 = std::time::Instant::now();
5097                let out = cpu_hvp_parallel();
5098                cpu_us.push(t0.elapsed().as_micros());
5099                assert_eq!(out.len(), p_total);
5100            }
5101            cpu_us.sort_unstable();
5102            let cpu_median = cpu_us[iters / 2];
5103
5104            let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5105            eprintln!(
5106                "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
5107                 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5108                 speedup={speedup:.2}× (charter target ≥ 5×)"
5109            );
5110            assert!(
5111                speedup >= 5.0,
5112                "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
5113                 need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
5114                 gpu_median={gpu_median}us). Hill-climb the kernel until met or \
5115                 prove the kernel is at hardware roofline."
5116            );
5117        }
5118    }
5119
5120    /// Companion to the HVP hill-climb: GPU dense-block build must be at
5121    /// least 10× faster than a Rayon-parallel CPU dense build at large-scale
5122    /// shape. The dense build is `O(n * r² * p_total)` work for both
5123    /// paths so the ratio is well-defined.
5124    #[test]
5125    pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
5126        #[cfg(not(target_os = "linux"))]
5127        {
5128            eprintln!(
5129                "[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate"
5130            );
5131        }
5132        #[cfg(target_os = "linux")]
5133        {
5134            use rayon::prelude::*;
5135
5136            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5137                eprintln!(
5138                    "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
5139                );
5140                return;
5141            };
5142            let n = 195_000_usize;
5143            let p_m = 14_usize;
5144            let p_g = 12_usize;
5145            let p_h_dim = 10_usize;
5146            let p_w_dim = 8_usize;
5147            let r = 2 + p_h_dim + p_w_dim;
5148            let p_total = p_m + p_g + p_h_dim + p_w_dim;
5149            let block = BmsFlexBlockLayout {
5150                p_m,
5151                p_g,
5152                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5153                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5154                p_total,
5155            };
5156            let primary = BmsFlexPrimaryLayout {
5157                h: Some(2..2 + p_h_dim),
5158                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5159                r,
5160            };
5161
5162            // Reuse the same large-scale fixture recipe.
5163            let mut row_hessians = vec![0.0_f64; n * r * r];
5164            for row in 0..n {
5165                let base = row * r * r;
5166                for u in 0..r {
5167                    for vv in 0..r {
5168                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
5169                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
5170                        row_hessians[base + u * r + vv] = a;
5171                    }
5172                }
5173                for u in 0..r {
5174                    for vv in (u + 1)..r {
5175                        let upper = row_hessians[base + u * r + vv];
5176                        let lower = row_hessians[base + vv * r + u];
5177                        let sym = 0.5 * (upper + lower);
5178                        row_hessians[base + u * r + vv] = sym;
5179                        row_hessians[base + vv * r + u] = sym;
5180                    }
5181                    row_hessians[base + u * r + u] += r as f64;
5182                }
5183            }
5184            let mut marginal = vec![0.0_f64; n * p_m];
5185            for row in 0..n {
5186                for j in 0..p_m {
5187                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
5188                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
5189                }
5190            }
5191            let mut logslope = vec![0.0_f64; n * p_g];
5192            for row in 0..n {
5193                for j in 0..p_g {
5194                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
5195                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
5196                }
5197            }
5198
5199            // GPU dense_block kernel rejects p_total > DENSE_BLOCK_MAX_P
5200            // (72 at V100 48 KiB/block). LargeScale's p_total = 44 fits.
5201            if p_total > DENSE_BLOCK_MAX_P {
5202                eprintln!(
5203                    "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
5204                );
5205                return;
5206            }
5207            let backend = match HvpKernelBackend::probe() {
5208                Ok(b) => b,
5209                Err(err) => {
5210                    eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
5211                    return;
5212                }
5213            };
5214            let stream = backend.stream.clone();
5215            let d_h = match stream.clone_htod(&row_hessians) {
5216                Ok(s) => s,
5217                Err(err) => {
5218                    eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
5219                    return;
5220                }
5221            };
5222            let d_m = match stream.clone_htod(&marginal) {
5223                Ok(s) => s,
5224                Err(err) => {
5225                    eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
5226                    return;
5227                }
5228            };
5229            let d_g = match stream.clone_htod(&logslope) {
5230                Ok(s) => s,
5231                Err(err) => {
5232                    eprintln!(
5233                        "[bms_flex_row dense_block hill-climb] upload logslope failed: {err}"
5234                    );
5235                    return;
5236                }
5237            };
5238            let storage = DeviceResidentRowHess {
5239                hess: d_h,
5240                marginal_design: d_m,
5241                logslope_design: d_g,
5242                n,
5243                r,
5244                block: block.clone(),
5245                primary: primary.clone(),
5246
5247                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5248            };
5249            // Warmup + 5-iter median (dense build is heavier than HVP).
5250            let warmup: usize = 2;
5251            let iters: usize = 5;
5252            for _ in 0..warmup {
5253                let out = launch_bms_flex_row_dense_block(&storage)
5254                    .expect("warmup GPU dense_block must launch");
5255                assert_eq!(out.len(), p_total * p_total);
5256            }
5257            let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5258            for _ in 0..iters {
5259                let t0 = std::time::Instant::now();
5260                let out =
5261                    launch_bms_flex_row_dense_block(&storage).expect("GPU dense_block must launch");
5262                gpu_us.push(t0.elapsed().as_micros());
5263                assert_eq!(out.len(), p_total * p_total);
5264            }
5265            gpu_us.sort_unstable();
5266            let gpu_median = gpu_us[iters / 2];
5267
5268            // CPU side: chunked Rayon dense build over rows. Each chunk
5269            // builds a `[p_total, p_total]` partial then we reduce-add.
5270            const CHUNK_ROWS: usize = 2048;
5271            let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
5272            let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
5273            let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
5274            let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
5275            let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
5276            let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
5277            let cpu_build_parallel = || -> Vec<f64> {
5278                let nchunks = n.div_ceil(CHUNK_ROWS);
5279                (0..nchunks)
5280                    .into_par_iter()
5281                    .fold(
5282                        || vec![0.0_f64; p_total * p_total],
5283                        |mut acc, ci| {
5284                            let lo = ci * CHUNK_ROWS;
5285                            let hi = (lo + CHUNK_ROWS).min(n);
5286                            let mut phi: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
5287                            for row in lo..hi {
5288                                for col in phi.iter_mut() {
5289                                    col.iter_mut().for_each(|v| *v = 0.0);
5290                                }
5291                                let mrow = &marginal[row * p_m..(row + 1) * p_m];
5292                                let grow = &logslope[row * p_g..(row + 1) * p_g];
5293                                for k in 0..p_m {
5294                                    phi[0][k] = mrow[k];
5295                                }
5296                                for k in 0..p_g {
5297                                    phi[1][p_m + k] = grow[k];
5298                                }
5299                                for k in 0..h_block_len {
5300                                    phi[h_primary_start + k][h_block_start + k] = 1.0;
5301                                }
5302                                for k in 0..w_block_len {
5303                                    phi[w_primary_start + k][w_block_start + k] = 1.0;
5304                                }
5305                                let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
5306                                for u in 0..r {
5307                                    for v_idx in 0..r {
5308                                        let huv = hrow[u * r + v_idx];
5309                                        if huv == 0.0 {
5310                                            continue;
5311                                        }
5312                                        for m in 0..p_total {
5313                                            let pm = phi[u][m];
5314                                            if pm == 0.0 {
5315                                                continue;
5316                                            }
5317                                            let scaled = huv * pm;
5318                                            for nn in 0..p_total {
5319                                                acc[m * p_total + nn] += scaled * phi[v_idx][nn];
5320                                            }
5321                                        }
5322                                    }
5323                                }
5324                            }
5325                            acc
5326                        },
5327                    )
5328                    .reduce(
5329                        || vec![0.0_f64; p_total * p_total],
5330                        |mut a, b| {
5331                            for (ax, bx) in a.iter_mut().zip(b.iter()) {
5332                                *ax += *bx;
5333                            }
5334                            a
5335                        },
5336                    )
5337            };
5338            let warm_cpu = cpu_build_parallel();
5339            assert_eq!(warm_cpu.len(), p_total * p_total);
5340            let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5341            for _ in 0..iters {
5342                let t0 = std::time::Instant::now();
5343                let out = cpu_build_parallel();
5344                cpu_us.push(t0.elapsed().as_micros());
5345                assert_eq!(out.len(), p_total * p_total);
5346            }
5347            cpu_us.sort_unstable();
5348            let cpu_median = cpu_us[iters / 2];
5349
5350            let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5351            eprintln!(
5352                "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
5353                 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5354                 speedup={speedup:.2}× (charter target ≥ 10×)"
5355            );
5356            assert!(
5357                speedup >= 10.0,
5358                "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
5359                 need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
5360                 gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
5361                 (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
5362                 or prove the kernel is at hardware roofline."
5363            );
5364        }
5365    }
5366}