Skip to main content

gam_models/bms/gpu/
row.rs

1//! Stage 2 of the BMS FLEX row kernel — per-row math that turns per-cell
2//! derivative moments (built by Stage 1 in `src/gpu/cubic_cell/mod.rs`) into a
3//! row gradient and row-primary `r × r` Hessian.
4//!
5//! Math (mirrors the CPU reference
6//! `BernoulliMarginalSlope::compute_row_analytic_flex_from_parts_into` in
7//! `src/families/bernoulli_marginal_slope.rs`):
8//!
9//! For each row `i`, with per-cell cubic predictor coefficients
10//! `C_c = (C0, C1, C2, C3)` and derivative moments `m_0..m_9`, build
11//!
12//! ```text
13//!     κ        = 1 / (2π)
14//!     T_n      = κ · Σ_{e=0..3} C_e · m_{e+n}     (n = 0..6)
15//!     D(R)     = κ · Σ_{k=0..3} R_k · m_k
16//!     Q(R, S)  = Σ_{p,q=0..3} R_p · S_q · T_{p+q}
17//!     H(R, S, U) = D(U) − Q(R, S)
18//! ```
19//!
20//! Per cell `c`, accumulate into row scratch:
21//!
22//! ```text
23//!     F_a   += D(A_c)
24//!     F_aa  += H(A_c, A_c, AA_c)
25//!     F_u   += D(R_{c,u})                         u > 0
26//!     F_au  += H(A_c, R_{c,u}, AR_{c,u})          u > 0
27//!     F_uv  += H(R_{c,u}, R_{c,v}, S_{c,uv})      0 < u ≤ v
28//! ```
29//!
30//! After the cell sum, the `q`-row is overridden:
31//!
32//! ```text
33//!     F_q  = −mu_1
34//!     F_qq = −mu_2
35//!     F_qv = 0   (v > 0)
36//!     F_aq = 0
37//! ```
38//!
39//! Implicit function theorem (single `1/F_a`):
40//!
41//! ```text
42//!     inv_Fa = 1 / F_a
43//!     a_u    = −F_u · inv_Fa                       (q-row override: mu_1 · inv_Fa)
44//!     a_uv   = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) · inv_Fa
45//! ```
46//!
47//! Observed predictor at `z_obs` (host supplies pre-evaluated chi, xi, rho, tau,
48//! r_uv per row and coordinate):
49//!
50//! ```text
51//!     bar_e_u  = chi_obs · a_u + rho_u
52//!     bar_e_uv = chi_obs · a_uv + xi_obs · a_u · a_v + tau_u · a_v
53//!                + a_u · tau_v + r_uv
54//! ```
55//!
56//! Probit Mills (stable; uses `log_ndtr_and_mills` from `numerics_device::PROBIT_NUMERICS_CU`):
57//!
58//! ```text
59//!     s = 2y − 1 ;  m = s · e_obs
60//!     [log_cdf, λ] = log_ndtr_and_mills(m)
61//!     A = −w · s · λ
62//!     B =  w · λ · (m + λ)
63//! ```
64//!
65//! Final outputs:
66//!
67//! ```text
68//!     neglog   = −w · log_cdf
69//!     g_u      = A · bar_e_u
70//!     H_{uv}   = B · bar_e_u · bar_e_v + A · bar_e_uv     (symmetric)
71//! ```
72//!
73//! Implementation choice (Stage 2): **one CUDA block per row**, with
74//! `blockDim.x = 32` threads. The block's `F_u`, `F_au`, `F_uv`, `bar_e_u`,
75//! `bar_e_uv` live in shared memory; threads in the block parallelise the
76//! per-cell sums, then a single thread of the block (`threadIdx.x == 0`) does
77//! the IFT solve, the observed-point assembly, the Mills evaluation, and the
//! final gradient + Hessian write-out. The shared scratch is statically sized
//! for the `r ≤ MAX_R` cap (32): `F_u`, `F_au`, `F_uv`, `bar_e_u`, `bar_e_uv`
//! take `2·32² + 3·32 = 2 144` doubles, and the two 32-wide reduction buffers
//! plus the `F_a` / `F_aa` scalars bring the total to 2 210 doubles ≈ 17.3 KB
//! per block, well below the V100 48 KB per-block limit. This keeps the
//! implementation simple and avoids per-thread global scratch (a per-thread
//! `r*r` scratch arena would be ~2 GB at n=195k, r=20).
83
84#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
95/// Hard ceiling on `r` (= 2 + p_h + p_w). Matches the shared-memory budget
96/// argument in the module docstring: with `MAX_R = 32` the per-block shared
97/// footprint is at most `2·32² + 3·32 = 2 144` doubles = 17 KB.
98pub(crate) const MAX_R: usize = 32;
99
100/// `blockDim.x` for the row kernel. Threads of a row-block parallelise the
101/// per-cell loop; thread 0 of the block finalises the IFT solve. Linux-only
102/// because the kernel launcher that consumes it is Linux-only.
103#[cfg(target_os = "linux")]
104pub(crate) const ROW_KERNEL_THREADS: u32 = 32;
105
106/// Number of cubic predictor coefficients per cell (`C0..C3`) and the matching
107/// support length of `A_c`, `R_{c,u}`, `AA_c`, `AR_{c,u}`, `S_{c,uv}`.
108pub(crate) const COEFF4: usize = 4;
109
110/// Highest moment index touched per cell: `T_n` uses `m_{n+e}` for `e = 0..3`
111/// and `n = 0..6`, so the maximum index is `9`. `MOMENT_STRIDE = 10`.
112pub(crate) const MOMENT_STRIDE: usize = 10;
113
114/// Source of the per-cell derivative moments fed into the row kernel.
115/// Phase-4 wiring: the substrate at `src/gpu/cubic_cell/mod.rs` can produce
116/// these on the GPU; this enum lets the launcher consume them directly
117/// without a DtoH+HtoD round-trip.
118pub(crate) enum CellMomentsSource<'a> {
119    /// Host-resident `[total_cells, MOMENT_STRIDE = 10]` row-major buffer.
120    /// The launcher will HtoD-upload this on every launch.
121    Host(&'a [f64]),
122    /// Device-resident moments already living on the row-kernel backend's
123    /// default stream (which is the same `cuda_context_for(ordinal).default_stream()`
124    /// the cubic-cell substrate uses, so no cross-context copy is needed).
125    /// Length on the device must be `total_cells * MOMENT_STRIDE`. Linux-only.
126    #[cfg(target_os = "linux")]
127    Device(&'a CudaSlice<f64>),
128}
129
130impl<'a> CellMomentsSource<'a> {
131    /// Logical element count of the moments source, used by [`BmsFlexRowKernelInputs::validate`].
132    pub(crate) fn len(&self) -> usize {
133        match self {
134            CellMomentsSource::Host(slice) => slice.len(),
135            #[cfg(target_os = "linux")]
136            CellMomentsSource::Device(d) => d.len(),
137        }
138    }
139}
140
/// Declares [`BmsFlexRowKernelInputs`] (the borrowed launch ABI) and
/// [`BmsFlexRowKernelInputsOwned`] (its owned twin) from one field schema, so
/// the two structs cannot drift apart.
///
/// A doc comment placed on a `macro_rules!` item documents the macro, not the
/// items it expands to. The struct-level documentation therefore lives inside
/// the expansion below.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        /// Per-row input bundle for [`launch_bms_flex_row_kernel`].
        ///
        /// Coordinate ordering convention (`r = 2 + p_h + p_w` coordinates):
        /// - `u = 0` is `q`. After the cell sum the kernel overrides this
        ///   row: `F_q = −mu_1`, `F_qq = −mu_2`, `F_qv = F_aq = 0`.
        /// - `u = 1` is `b` (slope), the index used by the sparse
        ///   `S_{b·b}` / `S_{b·h}` / `S_{b·w}` payloads.
        /// - `u = 2..2+p_h` is the score-warp `β_h` block.
        /// - `u = 2+p_h..r` is the link-wiggle `β_w` block.
        ///
        /// The latent intercept `a` is not one of the `r` coordinates. It is
        /// eliminated by the implicit-function-theorem solve, and only its
        /// derivatives (`F_a`, `F_aa`, `F_au`, `a_u`, `a_uv`) appear in the
        /// kernel.
        ///
        /// Buffer lengths are checked by [`BmsFlexRowKernelInputs::validate`].
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of observation rows.
            pub n_rows: usize,
            /// Total primary local dimension. `r = 2 + p_h + p_w`.
            pub r: usize,
            /// Number of score-warp basis coordinates.
            pub p_h: usize,
            /// Number of link-wiggle basis coordinates.
            pub p_w: usize,
            /// Probit frailty scale `S_f` (scalar shared across rows; matches
            /// `BernoulliMarginalSlope::probit_frailty_scale`).
            pub s_f: f64,
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            /// Per-cell derivative moments, host- or device-resident.
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owned twin of [`BmsFlexRowKernelInputs`] — every borrowed slice is
        /// replaced by an owned `Vec`. The buffer fields are declared from the
        /// same schema as the borrowed launch ABI and converted by
        /// [`BmsFlexRowKernelInputsOwned::as_borrowed`].
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            /// Host copy of the per-cell moments. On Linux,
            /// [`BmsFlexRowKernelInputsOwned::as_borrowed`] prefers
            /// `cell_moments_device` whenever that field is `Some(_)`.
            pub $moments_field: Vec<f64>,
            /// Phase-4 device-resident moments. When `Some(_)`, the launcher
            /// skips the host upload and consumes the buffer directly.
            /// Linux-only field.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrowed view over `self` suitable for
            /// [`launch_bms_flex_row_kernel`]. The returned struct holds
            /// references into `self`, so the owned bundle must outlive the
            /// launch.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                // Prefer the device-resident moments when present so the
                // launcher can skip the HtoD upload.
                #[cfg(target_os = "linux")]
                let moments = match self.cell_moments_device.as_ref() {
                    Some(device) => CellMomentsSource::Device(device),
                    None => CellMomentsSource::Host(&self.$moments_field),
                };
                #[cfg(not(target_os = "linux"))]
                let moments = CellMomentsSource::Host(&self.$moments_field);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: moments,
                }
            }
        }
    };
}
218
219define_bms_flex_row_kernel_input_types! {
220    f64_fields: [
221        q,
222        b,
223        mu_1,
224        mu_2,
225        z_obs,
226        y,
227        w,
228        cell_c0,
229        cell_c1,
230        cell_c2,
231        cell_c3,
232        cell_a,
233        cell_aa,
234        cell_r,
235        cell_ar,
236        cell_sbb,
237        cell_sbh,
238        cell_sbw,
239        chi_obs,
240        xi_obs,
241        rho_u,
242        tau_u,
243        r_uv,
244    ],
245    u32_fields: [cell_offsets],
246    moments_field: cell_moments,
247}
248
/// Per-row outputs produced by [`launch_bms_flex_row_kernel`].
///
/// Rows whose `F_a` is non-finite or non-positive are NaN-filled by the
/// kernel (neglog, gradient and Hessian), so the host can fall back to the CPU
/// path for those rows. Note that `PartialEq` follows IEEE semantics, so
/// NaN-filled rows never compare equal.
#[derive(Debug, Clone, PartialEq, Default)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// Per-row negative log-likelihood. Length `n_rows`.
    pub neglog: Vec<f64>,
    /// Per-row gradient, row-major `[n_rows, r]`.
    pub grad: Vec<f64>,
    /// Per-row Hessian, row-major `[n_rows, r*r]`. The kernel writes the full
    /// symmetric matrix.
    pub hess: Vec<f64>,
}
260
261impl<'a> BmsFlexRowKernelInputs<'a> {
262    /// Sanity-check every shape the kernel relies on. This is the only place
263    /// length errors are surfaced — the device kernel assumes valid layout.
264    pub(crate) fn validate(&self) -> Result<(), GpuError> {
265        if self.r == 0 {
266            return Err(GpuError::DriverCallFailed {
267                reason: "bms_flex_row inputs: r must be > 0".to_string(),
268            });
269        }
270        if self.r > MAX_R {
271            return Err(GpuError::DriverCallFailed {
272                reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
273            });
274        }
275        if self.r != 2 + self.p_h + self.p_w {
276            return Err(GpuError::DriverCallFailed {
277                reason: format!(
278                    "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
279                    self.r,
280                    self.p_h,
281                    self.p_w,
282                    2 + self.p_h + self.p_w
283                ),
284            });
285        }
286        let n = self.n_rows;
287        let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
288            if have != want {
289                return Err(GpuError::DriverCallFailed {
290                    reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
291                });
292            }
293            Ok(())
294        };
295        check_len("q", self.q.len(), n)?;
296        check_len("b", self.b.len(), n)?;
297        check_len("mu_1", self.mu_1.len(), n)?;
298        check_len("mu_2", self.mu_2.len(), n)?;
299        check_len("z_obs", self.z_obs.len(), n)?;
300        check_len("y", self.y.len(), n)?;
301        check_len("w", self.w.len(), n)?;
302        check_len("chi_obs", self.chi_obs.len(), n)?;
303        check_len("xi_obs", self.xi_obs.len(), n)?;
304        check_len("rho_u", self.rho_u.len(), n * self.r)?;
305        check_len("tau_u", self.tau_u.len(), n * self.r)?;
306        check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
307        check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
308        let total_cells_u32 = self.cell_offsets[n];
309        let total_cells = total_cells_u32 as usize;
310        check_len("cell_c0", self.cell_c0.len(), total_cells)?;
311        check_len("cell_c1", self.cell_c1.len(), total_cells)?;
312        check_len("cell_c2", self.cell_c2.len(), total_cells)?;
313        check_len("cell_c3", self.cell_c3.len(), total_cells)?;
314        check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
315        check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
316        check_len(
317            "cell_r",
318            self.cell_r.len(),
319            total_cells * self.r.saturating_sub(1) * COEFF4,
320        )?;
321        check_len(
322            "cell_ar",
323            self.cell_ar.len(),
324            total_cells * self.r.saturating_sub(1) * COEFF4,
325        )?;
326        check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
327        check_len(
328            "cell_sbh",
329            self.cell_sbh.len(),
330            total_cells * self.p_h * COEFF4,
331        )?;
332        check_len(
333            "cell_sbw",
334            self.cell_sbw.len(),
335            total_cells * self.p_w * COEFF4,
336        )?;
337        check_len(
338            "cell_moments",
339            self.cell_moments.len(),
340            total_cells * MOMENT_STRIDE,
341        )?;
342        // Bonus: when the moments came from `CellMomentsSource::Device`, the
343        // launcher needs to know the source is from a device buffer; nothing
344        // to validate beyond length above. The Host variant length check is
345        // also already covered above.
346        // Monotone cell_offsets check.
347        for i in 0..n {
348            if self.cell_offsets[i] > self.cell_offsets[i + 1] {
349                return Err(GpuError::DriverCallFailed {
350                    reason: format!(
351                        "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
352                        i,
353                        self.cell_offsets[i],
354                        i + 1,
355                        self.cell_offsets[i + 1]
356                    ),
357                });
358            }
359        }
360        Ok(())
361    }
362}
363
/// NVRTC kernel source body. One CUDA block per row; 32 threads per block
/// parallelise the per-cell sums into shared-memory scratch; thread 0 of the
/// block finishes the IFT + observed-point + Mills + Hessian write-out.
///
/// Shared probit numerics (`erfcx_nonnegative`, `log_ndtr`,
/// `log_ndtr_and_mills`) are provided by
/// `numerics_device::PROBIT_NUMERICS_CU`, which is prepended before
/// passing to `cudarc::nvrtc::compile_ptx`.
///
/// Output offsets (`row * r` for the gradient, `row * r * r` for the Hessian)
/// are formed in `size_t`. In `int` arithmetic `row * r * r` overflows once
/// `n_rows · r² ≥ 2³¹` (≈ 2.1 M rows at `r = 32`) and the writes would land
/// out of bounds.
///
/// **CPU parity reference**: the body mirrors
/// `compute_row_analytic_flex_from_parts_into` in
/// `src/families/bernoulli_marginal_slope.rs`.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallelises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
//                      ::compute_row_analytic_flex_from_parts_into.
//
// Output offsets are computed in size_t: `row * r * r` in int arithmetic
// overflows once n_rows * r * r >= 2^31.

#define INV_TWO_PI     0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    size_t grad_base = (size_t)row * (size_t)r;
    size_t rr        = (size_t)r * (size_t)r;
    size_t hess_base = (size_t)row * rr;
    out_neglog[row] = nan_value;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + u] = nan_value;
    }
    for (size_t idx = 0; idx < rr; ++idx) {
        out_hess[hess_base + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int                  n_rows,
    int                  r,
    int                  p_h,
    int                  p_w,
    double               s_f,                // currently unused on device:
                                             // host has already baked S_f
                                             // into the cubic coefficients.
                                             // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a,       // [n_cells, 4]
    const double * __restrict__ cell_aa,      // [n_cells, 4]
    const double * __restrict__ cell_r,       // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar,      // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb,     // [n_cells, 4]
    const double * __restrict__ cell_sbh,     // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw,     // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments, // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho,      // [n_rows, r]
    const double * __restrict__ row_tau,      // [n_rows, r]
    const double * __restrict__ row_ruv,      // [n_rows, r*r]
    double       * __restrict__ out_neglog,
    double       * __restrict__ out_grad,
    double       * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    //   F_u      [r]
    //   F_au     [r]
    //   F_uv     [r*r]
    //   bar_e_u  [r]
    //   bar_e_uv [r*r]
    //   reduce_a [blockDim.x]
    //   reduce_b [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u]  = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa  = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        //             `cell_second_derivative_from_moments` after folding the
        //             cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S)                                                                 \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1]                               \
             + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2]                                    \
             + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3]                        \
             + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4]                                    \
             + (R[2]*S[3] + R[3]*S[2])*T[5]                                                \
             + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c  = cell_a  + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa  += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        //                                   = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R   = D_OF(R_u);
            double d_AR  = D_OF(AR_u);
            double q_AR  = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s  = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in the score-warp block.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in the link-wiggle block.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror (u in the h or w block, v == 1) cannot
                // happen because v >= u; only the upper triangle is stored.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared. The stride-halving
    // tree assumes blockDim.x is a power of two <= 32 (reduce_* are 32 wide).
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared  = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a  = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    //   F_q  = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0]  = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    //   a_u = -F_u · inv_Fa     (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    //   a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // With the q-overrides above (F_aq = 0, F_qv = 0, F_qq = -mu_2) the
    // q entries reduce to
    //   a_qq = (mu_2 - F_aa · a_q²) · inv_Fa
    //   a_qv = -(F_av · a_q + F_aa · a_q · a_v) · inv_Fa      (v > 0),
    // so the uniform formula below handles every (u, v) pair.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                        + F_au[v] * a_u[u]
                        + F_au[u] * a_u[v]
                        + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    //   bar_e_u  = chi · a_u + rho_u.
    //   bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi  = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                       + xi  * a_u[u] * a_u[v]
                       + tau[u] * a_u[v]
                       + a_u[u] * tau[v]
                       + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y    = row_y[row];
    double w    = row_w[row];
    double s    = 2.0 * y - 1.0;
    // The observed predictor value e_obs is read from the bar_e_u[0] slot
    // (chi · a_0 + rho_0). NOTE(review): bar_e_u is the first-derivative jet,
    // so this slot only carries the observed value if the host packs it that
    // way. That wiring lands in the dispatcher that bridges this kernel to
    // the row evaluator in bernoulli_marginal_slope.rs — confirm parity there.
    double e_obs = bar_e_u[0];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i =  w * lambda * (m_arg + lambda);

    size_t grad_base = (size_t)row * (size_t)r;
    size_t hess_base = (size_t)row * (size_t)r * (size_t)r;
    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            out_hess[hess_base + u * r + v] = val;
            if (u != v) {
                out_hess[hess_base + v * r + u] = val;
            }
        }
    }
}
"#;
725
726// Force `s_f` to be considered used at the Rust level even though Stage 2 of
727// the kernel doesn't consume it on-device (the host has already baked the
728// probit frailty scale into the per-cell cubic coefficients). The dispatcher
729// wave that ports the rigid-branch fallback may want to apply `s_f` device-side
730// for log diagnostics; leaving the field on the input struct + reading it here
731// avoids a `let _` silencer the build.rs scanner would reject.
732#[inline]
733pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
734    inputs.s_f.is_finite() && inputs.s_f > 0.0
735}
736
737#[cfg(target_os = "linux")]
738pub(crate) struct RowKernelBackend {
739    pub(crate) stream: Arc<CudaStream>,
740    pub(crate) module: Arc<CudaModule>,
741}
742
743#[cfg(target_os = "linux")]
744impl RowKernelBackend {
745    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
746        static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
747        BACKEND
748            .get_or_init(|| {
749                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
750                    let row_kernel_source = [
751                        gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
752                        ROW_KERNEL_BODY,
753                    ]
754                    .concat();
755                    let ptx = cudarc::nvrtc::compile_ptx(row_kernel_source).map_err(|err| {
756                        GpuError::DriverCallFailed {
757                            reason: format!("bms_flex_row NVRTC compile failed: {err}"),
758                        }
759                    })?;
760                    let module =
761                        parts
762                            .ctx
763                            .load_module(ptx)
764                            .map_err(|err| GpuError::DriverCallFailed {
765                                reason: format!("bms_flex_row module load failed: {err}"),
766                            })?;
767                    Ok(RowKernelBackend {
768                        stream: parts.stream.clone(),
769                        module,
770                    })
771                })
772            })
773            .as_ref()
774            .map_err(GpuError::clone)
775    }
776}
777
778/// Launch Stage-2 BMS FLEX row kernel. On non-Linux returns
779/// [`GpuError::DriverLibraryUnavailable`]; on Linux NVRTC-compiles the kernel
780/// (cached for the process lifetime), uploads the per-row + per-cell buffers,
781/// and dispatches one block per row.
782pub(crate) fn launch_bms_flex_row_kernel(
783    inputs: BmsFlexRowKernelInputs<'_>,
784) -> Result<BmsFlexRowKernelOutputs, GpuError> {
785    inputs.validate()?;
786    if !s_f_diagnostic_finite(&inputs) {
787        return Err(GpuError::DriverCallFailed {
788            reason: format!(
789                "bms_flex_row inputs: s_f must be positive and finite, got {}",
790                inputs.s_f
791            ),
792        });
793    }
794
795    #[cfg(target_os = "linux")]
796    {
797        launch_linux(inputs)
798    }
799    #[cfg(not(target_os = "linux"))]
800    {
801        Err(GpuError::DriverLibraryUnavailable {
802            reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
803        })
804    }
805}
806
807#[cfg(target_os = "linux")]
808pub(crate) fn launch_linux(
809    inputs: BmsFlexRowKernelInputs<'_>,
810) -> Result<BmsFlexRowKernelOutputs, GpuError> {
811    let backend = RowKernelBackend::probe()?;
812    let stream = &backend.stream;
813
814    let upload_f64 = |slice: &[f64], label: &str| {
815        stream
816            .clone_htod(slice)
817            .map_err(|err| GpuError::DriverCallFailed {
818                reason: format!("bms_flex_row upload {label}: {err}"),
819            })
820    };
821    let upload_u32 = |slice: &[u32], label: &str| {
822        stream
823            .clone_htod(slice)
824            .map_err(|err| GpuError::DriverCallFailed {
825                reason: format!("bms_flex_row upload {label}: {err}"),
826            })
827    };
828
829    let d_q = upload_f64(inputs.q, "q")?;
830    let d_b = upload_f64(inputs.b, "b")?;
831    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
832    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
833    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
834    let d_y = upload_f64(inputs.y, "y")?;
835    let d_w = upload_f64(inputs.w, "w")?;
836    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
837    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
838    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
839    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
840    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
841    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
842    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
843    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
844    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
845    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
846    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
847    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
848    // Phase-4: optionally consume device-resident moments (no host upload).
849    // Both branches end up holding a `&CudaSlice<f64>` named `d_moments_ref`
850    // we can pass to the launch builder uniformly.
851    let owned_host_moments: CudaSlice<f64>;
852    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
853        CellMomentsSource::Host(slice) => {
854            owned_host_moments = upload_f64(slice, "cell_moments")?;
855            &owned_host_moments
856        }
857        CellMomentsSource::Device(d) => *d,
858    };
859    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
860    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
861    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
862    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
863    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
864
865    let n = inputs.n_rows;
866    let r = inputs.r;
867    let mut d_neglog = stream
868        .alloc_zeros::<f64>(n)
869        .map_err(|err| GpuError::DriverCallFailed {
870            reason: format!("bms_flex_row alloc neglog: {err}"),
871        })?;
872    let mut d_grad =
873        stream
874            .alloc_zeros::<f64>(n * r)
875            .map_err(|err| GpuError::DriverCallFailed {
876                reason: format!("bms_flex_row alloc grad: {err}"),
877            })?;
878    let mut d_hess =
879        stream
880            .alloc_zeros::<f64>(n * r * r)
881            .map_err(|err| GpuError::DriverCallFailed {
882                reason: format!("bms_flex_row alloc hess: {err}"),
883            })?;
884
885    let func = backend
886        .module
887        .load_function("bms_flex_row_kernel")
888        .map_err(|err| GpuError::DriverCallFailed {
889            reason: format!("bms_flex_row load_function: {err}"),
890        })?;
891
892    let cfg = LaunchConfig {
893        grid_dim: (n as u32, 1, 1),
894        block_dim: (ROW_KERNEL_THREADS, 1, 1),
895        shared_mem_bytes: 0,
896    };
897    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
898        reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
899    })?;
900    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
901        reason: format!("bms_flex_row: r={r} exceeds i32 range"),
902    })?;
903    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
904        reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
905    })?;
906    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
907        reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
908    })?;
909    let s_f = inputs.s_f;
910
911    let mut builder = stream.launch_builder(&func);
912    builder
913        .arg(&n_i32)
914        .arg(&r_i32)
915        .arg(&p_h_i32)
916        .arg(&p_w_i32)
917        .arg(&s_f)
918        .arg(&d_q)
919        .arg(&d_b)
920        .arg(&d_mu1)
921        .arg(&d_mu2)
922        .arg(&d_zobs)
923        .arg(&d_y)
924        .arg(&d_w)
925        .arg(&d_offsets)
926        .arg(&d_c0)
927        .arg(&d_c1)
928        .arg(&d_c2)
929        .arg(&d_c3)
930        .arg(&d_a)
931        .arg(&d_aa)
932        .arg(&d_r)
933        .arg(&d_ar)
934        .arg(&d_sbb)
935        .arg(&d_sbh)
936        .arg(&d_sbw)
937        .arg(d_moments_ref)
938        .arg(&d_chi)
939        .arg(&d_xi)
940        .arg(&d_rho)
941        .arg(&d_tau)
942        .arg(&d_ruv)
943        .arg(&mut d_neglog)
944        .arg(&mut d_grad)
945        .arg(&mut d_hess);
946
947    // SAFETY: every kernel parameter above is either a primitive `i32` /
948    // `f64` (passed by value), a const device pointer to a buffer whose
949    // length the host validated against the input struct, or an output
950    // buffer pre-allocated to `n_rows`, `n_rows*r`, `n_rows*r*r`
951    // doubles. The kernel's shared-memory arrays are sized to MAX_R = 32
952    // and validate() rejects r > MAX_R.
953    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
954        reason: format!("bms_flex_row launch: {err}"),
955    })?;
956    stream
957        .synchronize()
958        .map_err(|err| GpuError::DriverCallFailed {
959            reason: format!("bms_flex_row synchronize: {err}"),
960        })?;
961
962    let neglog = stream
963        .clone_dtoh(&d_neglog)
964        .map_err(|err| GpuError::DriverCallFailed {
965            reason: format!("bms_flex_row download neglog: {err}"),
966        })?;
967    let grad = stream
968        .clone_dtoh(&d_grad)
969        .map_err(|err| GpuError::DriverCallFailed {
970            reason: format!("bms_flex_row download grad: {err}"),
971        })?;
972    let hess = stream
973        .clone_dtoh(&d_hess)
974        .map_err(|err| GpuError::DriverCallFailed {
975            reason: format!("bms_flex_row download hess: {err}"),
976        })?;
977
978    Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
979}
980
981// ─────────────────────────────────────────────────────────────────────────────
982// Phase 3: device-resident row Hessian + HVP / diagonal kernels.
983//
984// Math (mirrors the CPU oracle in
985// `src/families/bernoulli_marginal_slope.rs::exact_newton_joint_hessian_*_from_cache`):
986//
987//   Block layout (joint β):
988//     marginal = [0..p_m), logslope = [p_m..p_m+p_g),
989//     h        = [h_start..h_end), w = [w_start..w_end), total = p_total.
990//
991//   Primary layout (per-row r-vector):
992//     q = 0, logslope = 1,
993//     h = [h_primary_start..h_primary_end),
994//     w = [w_primary_start..w_primary_end), total = r.
995//
996//   row_dir[u] for u in primary layout:
997//     row_dir[0]   = Σ_j marginal_design[row, j] · v[j]
998//     row_dir[1]   = Σ_j logslope_design[row, j] · v[p_m + j]
999//     row_dir[h_k] = v[h_block_start + (h_k - h_primary_start)]
1000//     row_dir[w_k] = v[w_block_start + (w_k - w_primary_start)]
1001//
1002//   action[u]    = Σ_v row_hessians[row, u*r + v] · row_dir[v]
1003//
1004//   block_partial[marginal_j] += action[0] · marginal_design[row, j]
1005//   block_partial[logslope_j] += action[1] · logslope_design[row, j]
1006//   block_partial[h_block_start + (h_k - h_primary_start)] += action[h_k]
1007//   block_partial[w_block_start + (w_k - w_primary_start)] += action[w_k]
1008//
1009// Diagonal:
1010//   diag[marginal_j] += row_hess[row, 0*r + 0] · marginal_design[row, j]²
1011//   diag[logslope_j] += row_hess[row, 1*r + 1] · logslope_design[row, j]²
1012//   diag[h_block_start + k] += row_hess[row, ii*r + ii]   (ii = h_primary_start + k)
1013//   diag[w_block_start + k] += row_hess[row, ii*r + ii]   (ii = w_primary_start + k)
1014//
// Determinism: CTA `c` owns the contiguous row range
// `[c * rows_per_cta, min((c + 1) * rows_per_cta, n_rows))`. It writes its full
// `p_total`-length partial (one per right-hand side in the multi-RHS kernel)
// into its own non-overlapping slice of the global partial buffer. The reduce
// kernels then sum those partials in fixed chunk order (c = 0, 1, ...), so the
// floating-point summation order never depends on scheduling. No atomics.
1019
/// Joint-β block layout shared with the host (mirrors `BlockSlices` in
/// `bernoulli_marginal_slope.rs`).
///
/// Column ranges in the joint coefficient vector β are:
/// - marginal = `[0, p_m)`;
/// - logslope = `[p_m, p_m + p_g)`;
/// - then the optional `h` and `w` blocks.
///
/// β has `p_total` columns in all.
///
/// Gating: Linux-only. The lone production constructor lives in
/// `bernoulli_marginal_slope.rs` behind `#[cfg(target_os = "linux")]`.
/// The device-resident row-Hessian path is the only producer (see
/// `launch_bms_flex_row_kernel_device_resident`), and the joint-β
/// consumers `launch_bms_flex_row_hvp` / `_diagonal` / `_dense_block`
/// are also Linux-only. Any test referencing this type must itself be
/// gated with `#[cfg(target_os = "linux")]`. The build.rs ban scanner
/// explicitly rejects `#[cfg(any(..., test))]` on items as a dead-code
/// escape hatch.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Width of the marginal block, which occupies β columns `[0, p_m)`.
    pub p_m: usize,
    /// Width of the logslope block, which occupies β columns `[p_m, p_m + p_g)`.
    pub p_g: usize,
    /// β column range of the `h` block (`None` when there is no `h` block).
    pub h: Option<std::ops::Range<usize>>,
    /// β column range of the `w` block (`None` when there is no `w` block).
    pub w: Option<std::ops::Range<usize>>,
    /// Total number of β columns.
    pub p_total: usize,
}
1041
/// Primary-r layout shared with the host (mirrors `PrimarySlices`).
///
/// Positions within each row's `r`-vector are:
/// - `q` = 0;
/// - logslope = 1;
/// - then the optional `h` and `w` ranges.
///
/// Gating rationale identical to [`BmsFlexBlockLayout`].
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexPrimaryLayout {
    /// Positions of the `h` coordinates in the per-row primary vector, if any.
    pub h: Option<std::ops::Range<usize>>,
    /// Positions of the `w` coordinates in the per-row primary vector, if any.
    pub w: Option<std::ops::Range<usize>>,
    /// Length of the per-row primary vector; each row Hessian is `r × r`.
    pub r: usize,
}
1051
1052// ── Linux-only: device-resident row-Hessian state + kernels ─────────────────
1053
/// Number of rows each HVP / diagonal CTA processes.
///
/// Each CTA writes one `p_total`-length partial into the global partial buffer
/// (one per right-hand side in the multi-RHS kernel), with no atomics. The
/// reduce kernel then sums the partials in fixed chunk-major order.
#[cfg(target_os = "linux")]
pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;

/// `blockDim.x` for the HVP / diagonal partial kernels.
///
/// This must be a power of two no larger than 128. The partial kernels stage
/// per-row dot products in a fixed `__shared__ double dot_reduce[128]` array
/// and reduce it with a halving tree (`stride = blockDim.x / 2; stride >>= 1`).
/// A larger block would overflow that array, and a non-power-of-two block would
/// drop elements from the reduction. Enforced at compile time below.
#[cfg(target_os = "linux")]
pub(crate) const HVP_THREADS: u32 = 128;

// Compile-time guards for the CUDA-side assumptions documented above.
#[cfg(target_os = "linux")]
const _: () = assert!(
    HVP_THREADS.is_power_of_two() && HVP_THREADS <= 128,
    "HVP_THREADS must be a power of two <= 128 (size of the kernels' dot_reduce shared array)"
);
#[cfg(target_os = "linux")]
const _: () = assert!(
    HVP_ROWS_PER_CTA > 0,
    "HVP_ROWS_PER_CTA must be non-zero (it is the chunk-count divisor)"
);

/// `blockDim.x` for the partial-sum reduction kernels.
///
/// Each thread produces exactly one output element: thread `j` sums column `j`
/// over all chunks and returns early when `j` is past the end. The kernels are
/// *not* grid-strided, so the launch must supply at least `p_total` (or
/// `rhs_count * p_total` for the multi-RHS reduce) threads in total. The value
/// is a whole number of warps.
#[cfg(target_os = "linux")]
pub(crate) const REDUCTION_THREADS: u32 = 256;

/// Maximum RHS columns fused into one row-primary HVP launch.
///
/// The matching CUDA source uses fixed shared arrays sized as
/// `MAX_MULTI_RHS * 32` (i.e. `BMS_FLEX_ROW_HVP_MAX_RHS * MAX_R`). Increasing
/// this requires updating the `MAX_MULTI_RHS` define in `HVP_KERNEL_SOURCE`.
#[cfg(target_os = "linux")]
pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1077
/// Device-resident state produced by
/// [`launch_bms_flex_row_kernel_device_resident`] and consumed by
/// [`launch_bms_flex_row_hvp`] / [`launch_bms_flex_row_diagonal`].
///
/// Owns the per-row Hessian cache and the two design matrices on the device.
/// This lets the host issue many HVPs against the same β snapshot without
/// round-tripping the `n × r × r` Hessian cache (`n * r * r` doubles) through
/// host RAM. Dropping this value drops the owned `CudaSlice` buffers, which
/// releases their device memory.
///
/// Storage layout: this struct holds only the dense row-major layout (see
/// `hess`). `HVP_KERNEL_SOURCE` also contains a symmetric packed-upper
/// variant (`bms_flex_row_pack_upper` plus the `*_packed` kernels):
/// - It stores `r*(r+1)/2` doubles per row instead of `r²`. At `r = 20` that
///   is 210 vs 400 doubles per row, about 47.5% smaller.
/// - Packing is lossless because the row kernel writes each `H_i`
///   symmetrically: every `(u, v)` value is mirrored to `(v, u)`.
///
/// NOTE(review): confirm whether the packed path is wired to this struct
/// anywhere; the field docs below say dense is the only supported layout.
#[cfg(target_os = "linux")]
pub struct DeviceResidentRowHess {
    /// Per-row dense `[n, r, r]` row-major Hessian. Element `(u, v)` of row
    /// `i` is `hess[i*r*r + u*r + v]`. This is the only on-device storage
    /// layout supported by the current HVP / diag kernels.
    pub(crate) hess: CudaSlice<f64>,
    /// Marginal design; presumably row-major `[n, p_m]`, as the HVP kernels
    /// expect. TODO confirm against the producer.
    pub(crate) marginal_design: CudaSlice<f64>,
    /// Logslope design; presumably row-major `[n, p_g]`, as the HVP kernels
    /// expect. TODO confirm against the producer.
    pub(crate) logslope_design: CudaSlice<f64>,
    /// Number of rows.
    pub(crate) n: usize,
    /// Per-row primary dimension (each row Hessian is `r × r`).
    pub(crate) r: usize,
    /// Joint-β block layout used when pulling row actions back into β.
    pub(crate) block: BmsFlexBlockLayout,
    /// Per-row primary layout (positions of q, logslope, h, w).
    pub(crate) primary: BmsFlexPrimaryLayout,
    /// Estimated bytes resident on device (for accounting).
    pub(crate) bytes: u64,
}
1113
1114#[cfg(target_os = "linux")]
1115impl std::fmt::Debug for DeviceResidentRowHess {
1116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1117        f.debug_struct("DeviceResidentRowHess")
1118            .field("n", &self.n)
1119            .field("r", &self.r)
1120            .field("p_total", &self.block.p_total)
1121            .field("bytes", &self.bytes)
1122            .finish()
1123    }
1124}
1125
1126/// Sized-to-fit-once CTA mapping. Rows `[c * HVP_ROWS_PER_CTA, (c+1) * HVP_ROWS_PER_CTA)`
1127/// belong to chunk `c`.
1128#[cfg(target_os = "linux")]
1129pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1130    n.div_ceil(HVP_ROWS_PER_CTA as usize)
1131}
1132
1133/// NVRTC source: HVP-partial kernel + HVP-reduce kernel + diag-partial +
1134/// diag-reduce. All kernels mirror the CPU oracle in this file.
1135#[cfg(target_os = "linux")]
1136pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1137// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1138// in this module.
1139
1140#define MAX_MULTI_RHS 8
1141
1142extern "C" __global__ void bms_flex_row_hvp_partial(
1143    int                  n_rows,
1144    int                  r,
1145    int                  p_m,
1146    int                  p_g,
1147    int                  p_total,
1148    int                  h_block_start,
1149    int                  h_block_len,
1150    int                  w_block_start,
1151    int                  w_block_len,
1152    int                  h_primary_start,
1153    int                  w_primary_start,
1154    int                  rows_per_cta,
1155    const double * __restrict__ row_hessians,    // [n, r*r]
1156    const double * __restrict__ marginal_design, // [n, p_m] row-major
1157    const double * __restrict__ logslope_design, // [n, p_g] row-major
1158    const double * __restrict__ v,               // [p_total]
1159    double       * __restrict__ partial)         // [num_chunks, p_total]
1160{
1161    int chunk = blockIdx.x;
1162    int tid   = threadIdx.x;
1163    int row_lo = chunk * rows_per_cta;
1164    int row_hi = row_lo + rows_per_cta;
1165    if (row_hi > n_rows) row_hi = n_rows;
1166
1167    // Zero this chunk's partial slice cooperatively.
1168    double *out = partial + (size_t)chunk * (size_t)p_total;
1169    for (int j = tid; j < p_total; j += blockDim.x) {
1170        out[j] = 0.0;
1171    }
1172    __syncthreads();
1173
1174    // Each thread serially processes a stride-of-blockDim set of rows so
1175    // every write to `out[..]` happens from one thread → no atomics within
1176    // the chunk. To keep writes race-free across threads of the same chunk,
1177    // we serialize the cross-row accumulation through a per-row barrier:
1178    // thread 0 of the block processes all rows in the chunk. The per-row
1179    // work is dominated by the dot/axpy over `p_m + p_g`, which is large.
1180    // For Stage 3 we ship the simple, correct path (thread 0 sequential
1181    // per row, blockDim.x threads parallel within a row's dot/axpy).
1182    __shared__ double row_dir[32];
1183    __shared__ double action[32];
1184    __shared__ double dot_reduce[128];
1185
1186    for (int row = row_lo; row < row_hi; ++row) {
1187        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1188        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1189        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1190
1191        // row_dir[0] = mrow · v[0..p_m]
1192        double local = 0.0;
1193        for (int j = tid; j < p_m; j += blockDim.x) {
1194            local += mrow[j] * v[j];
1195        }
1196        dot_reduce[tid] = local;
1197        __syncthreads();
1198        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1199            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1200            __syncthreads();
1201        }
1202        if (tid == 0) row_dir[0] = dot_reduce[0];
1203
1204        // row_dir[1] = grow · v[p_m..p_m+p_g]
1205        local = 0.0;
1206        for (int j = tid; j < p_g; j += blockDim.x) {
1207            local += grow[j] * v[p_m + j];
1208        }
1209        dot_reduce[tid] = local;
1210        __syncthreads();
1211        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1212            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1213            __syncthreads();
1214        }
1215        if (tid == 0) row_dir[1] = dot_reduce[0];
1216
1217        // h/w blocks: direct copy.
1218        if (tid == 0) {
1219            for (int k = 0; k < h_block_len; ++k) {
1220                row_dir[h_primary_start + k] = v[h_block_start + k];
1221            }
1222            for (int k = 0; k < w_block_len; ++k) {
1223                row_dir[w_primary_start + k] = v[w_block_start + k];
1224            }
1225        }
1226        __syncthreads();
1227
1228        // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
1229        if (tid < r) {
1230            double acc = 0.0;
1231            for (int vv = 0; vv < r; ++vv) {
1232                acc += Hrow[tid * r + vv] * row_dir[vv];
1233            }
1234            action[tid] = acc;
1235        }
1236        __syncthreads();
1237
1238        // Pull back into joint β slot.
1239        //   marginal: out[j] += action[0] · mrow[j]   (parallel j)
1240        double a0 = action[0];
1241        for (int j = tid; j < p_m; j += blockDim.x) {
1242            out[j] += a0 * mrow[j];
1243        }
1244        double a1 = action[1];
1245        for (int j = tid; j < p_g; j += blockDim.x) {
1246            out[p_m + j] += a1 * grow[j];
1247        }
1248        if (tid == 0) {
1249            for (int k = 0; k < h_block_len; ++k) {
1250                out[h_block_start + k] += action[h_primary_start + k];
1251            }
1252            for (int k = 0; k < w_block_len; ++k) {
1253                out[w_block_start + k] += action[w_primary_start + k];
1254            }
1255        }
1256        __syncthreads();
1257    }
1258}
1259
1260extern "C" __global__ void bms_flex_row_hvp_reduce(
1261    int                  num_chunks,
1262    int                  p_total,
1263    const double * __restrict__ partial,   // [num_chunks, p_total]
1264    double       * __restrict__ out)        // [p_total]
1265{
1266    int j = blockIdx.x * blockDim.x + threadIdx.x;
1267    if (j >= p_total) return;
1268    double acc = 0.0;
1269    for (int c = 0; c < num_chunks; ++c) {
1270        acc += partial[(size_t)c * (size_t)p_total + (size_t)j];
1271    }
1272    out[j] = acc;
1273}
1274
1275extern "C" __global__ void bms_flex_row_hvp_multi_partial(
1276    int                  n_rows,
1277    int                  r,
1278    int                  p_m,
1279    int                  p_g,
1280    int                  p_total,
1281    int                  h_block_start,
1282    int                  h_block_len,
1283    int                  w_block_start,
1284    int                  w_block_len,
1285    int                  h_primary_start,
1286    int                  w_primary_start,
1287    int                  rows_per_cta,
1288    int                  rhs_count,
1289    const double * __restrict__ row_hessians,    // [n, r*r]
1290    const double * __restrict__ marginal_design, // [n, p_m]
1291    const double * __restrict__ logslope_design, // [n, p_g]
1292    const double * __restrict__ v_rhs,           // [rhs_count, p_total]
1293    double       * __restrict__ partial)         // [rhs_count, num_chunks, p_total]
1294{
1295    int chunk = blockIdx.x;
1296    int tid   = threadIdx.x;
1297    int row_lo = chunk * rows_per_cta;
1298    int row_hi = row_lo + rows_per_cta;
1299    if (row_hi > n_rows) row_hi = n_rows;
1300
1301    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
1302    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
1303        int rhs = idx / p_total;
1304        int j = idx - rhs * p_total;
1305        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
1306    }
1307    __syncthreads();
1308
1309    __shared__ double row_dir[MAX_MULTI_RHS * 32];
1310    __shared__ double action[MAX_MULTI_RHS * 32];
1311    __shared__ double dot_reduce[128];
1312
1313    for (int row = row_lo; row < row_hi; ++row) {
1314        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1315        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1316        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1317
1318        for (int rhs = 0; rhs < rhs_count; ++rhs) {
1319            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;
1320
1321            double local = 0.0;
1322            for (int j = tid; j < p_m; j += blockDim.x) {
1323                local += mrow[j] * v[j];
1324            }
1325            dot_reduce[tid] = local;
1326            __syncthreads();
1327            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1328                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1329                __syncthreads();
1330            }
1331            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];
1332
1333            local = 0.0;
1334            for (int j = tid; j < p_g; j += blockDim.x) {
1335                local += grow[j] * v[p_m + j];
1336            }
1337            dot_reduce[tid] = local;
1338            __syncthreads();
1339            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1340                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1341                __syncthreads();
1342            }
1343            if (tid == 0) {
1344                row_dir[rhs * 32 + 1] = dot_reduce[0];
1345                for (int k = 0; k < h_block_len; ++k) {
1346                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
1347                }
1348                for (int k = 0; k < w_block_len; ++k) {
1349                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
1350                }
1351            }
1352            __syncthreads();
1353        }
1354
1355        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
1356            int rhs = idx / r;
1357            int u = idx - rhs * r;
1358            double acc = 0.0;
1359            const double *dir = row_dir + rhs * 32;
1360            for (int vv = 0; vv < r; ++vv) {
1361                acc += Hrow[u * r + vv] * dir[vv];
1362            }
1363            action[rhs * 32 + u] = acc;
1364        }
1365        __syncthreads();
1366
1367        for (int rhs = 0; rhs < rhs_count; ++rhs) {
1368            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
1369            double a0 = action[rhs * 32 + 0];
1370            for (int j = tid; j < p_m; j += blockDim.x) {
1371                out[j] += a0 * mrow[j];
1372            }
1373            double a1 = action[rhs * 32 + 1];
1374            for (int j = tid; j < p_g; j += blockDim.x) {
1375                out[p_m + j] += a1 * grow[j];
1376            }
1377            if (tid == 0) {
1378                for (int k = 0; k < h_block_len; ++k) {
1379                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
1380                }
1381                for (int k = 0; k < w_block_len; ++k) {
1382                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
1383                }
1384            }
1385            __syncthreads();
1386        }
1387    }
1388}
1389
1390extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
1391    int                  num_chunks,
1392    int                  p_total,
1393    int                  rhs_count,
1394    const double * __restrict__ partial,   // [rhs_count, num_chunks, p_total]
1395    double       * __restrict__ out)        // [rhs_count, p_total]
1396{
1397    int idx = blockIdx.x * blockDim.x + threadIdx.x;
1398    int total = rhs_count * p_total;
1399    if (idx >= total) return;
1400    int rhs = idx / p_total;
1401    int j = idx - rhs * p_total;
1402    double acc = 0.0;
1403    for (int c = 0; c < num_chunks; ++c) {
1404        acc += partial[((size_t)rhs * (size_t)num_chunks + (size_t)c) * (size_t)p_total + (size_t)j];
1405    }
1406    out[(size_t)rhs * (size_t)p_total + (size_t)j] = acc;
1407}
1408
1409extern "C" __global__ void bms_flex_row_diag_partial(
1410    int                  n_rows,
1411    int                  r,
1412    int                  p_m,
1413    int                  p_g,
1414    int                  p_total,
1415    int                  h_block_start,
1416    int                  h_block_len,
1417    int                  w_block_start,
1418    int                  w_block_len,
1419    int                  h_primary_start,
1420    int                  w_primary_start,
1421    int                  rows_per_cta,
1422    const double * __restrict__ row_hessians,
1423    const double * __restrict__ marginal_design,
1424    const double * __restrict__ logslope_design,
1425    double       * __restrict__ partial)
1426{
1427    int chunk = blockIdx.x;
1428    int tid   = threadIdx.x;
1429    int row_lo = chunk * rows_per_cta;
1430    int row_hi = row_lo + rows_per_cta;
1431    if (row_hi > n_rows) row_hi = n_rows;
1432
1433    double *out = partial + (size_t)chunk * (size_t)p_total;
1434    for (int j = tid; j < p_total; j += blockDim.x) {
1435        out[j] = 0.0;
1436    }
1437    __syncthreads();
1438
1439    for (int row = row_lo; row < row_hi; ++row) {
1440        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1441        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1442        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1443        double h00 = Hrow[0];
1444        double h11 = Hrow[1 * r + 1];
1445        for (int j = tid; j < p_m; j += blockDim.x) {
1446            double v = mrow[j];
1447            out[j] += h00 * v * v;
1448        }
1449        for (int j = tid; j < p_g; j += blockDim.x) {
1450            double v = grow[j];
1451            out[p_m + j] += h11 * v * v;
1452        }
1453        if (tid == 0) {
1454            for (int k = 0; k < h_block_len; ++k) {
1455                int ii = h_primary_start + k;
1456                out[h_block_start + k] += Hrow[ii * r + ii];
1457            }
1458            for (int k = 0; k < w_block_len; ++k) {
1459                int ii = w_primary_start + k;
1460                out[w_block_start + k] += Hrow[ii * r + ii];
1461            }
1462        }
1463        __syncthreads();
1464    }
1465}
1466
1467// ────────────────────────────────────────────────────────────────────────
1468// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1469//   row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1470// indexed as
1471//   packed[(u*(2*r - u - 1))/2 + (v - u)]   for u <= v
1472// with symmetric mirror for v < u.
1473// ────────────────────────────────────────────────────────────────────────
1474
1475// Helper: packed-upper index for (u, v) within a single row of r*(r+1)/2
1476// doubles. Caller must pre-swap so that u <= v.
1477__device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
1478    // u*(2r - u - 1)/2 + (v - u)
1479    return (u * (2 * r - u - 1)) / 2 + (v - u);
1480}
1481
// Pack one row of the full row-major r×r Hessian into packed-upper layout:
// the upper triangle stored row-major, so packed row u holds the r - u
// entries (u, u), (u, u+1), ..., (u, r-1) and starts at
//   u_start(u) = r + (r - 1) + ... + (r - u + 1) = u*(2r - u + 1)/2,
// for r*(r+1)/2 doubles per row in total.
// Launched as one CTA per row (gridDim.x = n_rows, blockDim.x configurable).
// Bit-equal copy: each upper-triangle entry is read once from the dense
// source and written once to the packed destination.
extern "C" __global__ void bms_flex_row_pack_upper(
    int                  n_rows,
    int                  r,
    const double * __restrict__ src_full,    // [n, r*r]
    double       * __restrict__ dst_packed)  // [n, r*(r+1)/2]
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;
    int per_row = r * (r + 1) / 2;
    const double *src = src_full + (size_t)row * (size_t)r * (size_t)r;
    double       *dst = dst_packed + (size_t)row * (size_t)per_row;
    // Each thread handles packed positions tid, tid + blockDim.x, ...
    for (int pos = tid; pos < per_row; pos += blockDim.x) {
        // Invert pos -> (u, v): walk the packed rows, advancing u_start by
        // the current row length (r - u), until pos falls inside row u.
        // This is an O(r) scan per entry, which is cheap with r <= 32.
        int u = 0;
        int u_start = 0;
        while (u < r) {
            int next = u_start + (r - u);
            if (pos < next) break;
            u_start = next;
            ++u;
        }
        // Within row u, pos = u_start + (v - u).
        int v = u + (pos - u_start);
        dst[pos] = src[(size_t)u * (size_t)r + (size_t)v];
    }
}
1516
// Per-CTA partial of the joint-β Hessian-vector product that reads the
// SymmetricPackedUpper row Hessians. CTA `chunk` owns rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)) and its
// [p_total] slot partial[chunk*p_total ...]. For each row it:
//   1. forms the primary-space direction row_dir = P_i v
//      (row_dir[0] = X_i · v_m, row_dir[1] = G_i · v_g, and the h / w
//      entries copied straight from v),
//   2. applies action = H_i row_dir using the symmetric packed storage,
//   3. scatters P_i^T action back into the CTA's partial.
//
// NOTE(review): the following launch-shape assumptions are not checked in
// the kernel; presumably the host launcher enforces them. Confirm there.
//   * blockDim.x <= 128 and a power of two (dot_reduce[128] and the halving
//     tree reduction below);
//   * blockDim.x >= r, since only threads tid < r compute action[tid];
//   * r <= 32 (row_dir[32] / action[32]);
//   * every primary index in [0, r) is 0, 1, or inside the h / w primary
//     ranges. Other row_dir entries are never written before being read.
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed, // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    const double * __restrict__ v,
    double       * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid   = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int per_row = r * (r + 1) / 2;
    // Zero this CTA's partial slot before accumulating rows into it.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    // Shared scratch: row_dir = P_i v and action = H_i row_dir (both in
    // primary space), plus the per-thread dot-product buffer.
    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // row_dir[0] = mrow · v[0..p_m]
        double local = 0.0;
        for (int j = tid; j < p_m; j += blockDim.x) {
            local += mrow[j] * v[j];
        }
        // Block-wide halving tree reduction of the per-thread partial sums.
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[0] = dot_reduce[0];

        // row_dir[1] = grow · v[p_m..p_m+p_g]
        local = 0.0;
        for (int j = tid; j < p_g; j += blockDim.x) {
            local += grow[j] * v[p_m + j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[1] = dot_reduce[0];

        // The h / w primary coordinates map one-to-one onto joint
        // coordinates, so their direction entries are plain copies of v.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        __syncthreads();

        // action[u] = Σ_w H[u, w] · row_dir[w], where H[u, w] reads from
        // packed-upper with (uu, vv) = (min(u, w), max(u, w)).
        if (tid < r) {
            double acc = 0.0;
            int u = tid;
            for (int w = 0; w < r; ++w) {
                int uu = u < w ? u : w;
                int vv = u < w ? w : u;
                acc += Hrow[bms_flex_packed_idx(uu, vv, r)] * row_dir[w];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Scatter P_i^T action: design rows for primaries 0 and 1 (strided
        // across threads), unit vectors for the h / w blocks (thread 0).
        double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        // Barrier before the next row overwrites row_dir / action / dot_reduce.
        __syncthreads();
    }
}
1627
1628// ────────────────────────────────────────────────────────────────────────
1629// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
1630// route. Materialises the full `[p_total, p_total]` row-major joint H
1631// from the per-row r×r Hessian via the P_i pullback. NOT the default
1632// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
1633// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
1634// caller genuinely needs the dense matrix on the device.
1635//
1636// Per-CTA partial: each CTA owns a contiguous chunk of rows
1637// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
1638// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
1639// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
1640// chunk-major-fixed-order into a single `[p_total, p_total]` output.
1641//
1642// Math: for primary index u ∈ [0, r):
1643//   * u = 0:        phi_u = (X_i in slot 0..p_m, 0 elsewhere)
1644//   * u = 1:        phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
1645//   * u = 2+j:      phi_u = e_{h_block_start + j}  (j ∈ 0..h_block_len)
1646//   * u = 2+h+l:    phi_u = e_{w_block_start + l}  (l ∈ 0..w_block_len)
1647// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
1648//
1649// Shared-memory budget: at large-scale shape p_total = 44, a [44, 44] f64
1650// partial is 44*44*8 = 15.5 KiB — well below the V100 48 KiB/SM cap.
1651// At p_total ≤ 80 the kernel still fits (80*80*8 = 50 KiB → just over
1652// V100 cap; caller must enforce p_total ≤ DENSE_BLOCK_MAX_P). The
1653// launcher rejects oversize p_total cleanly.
1654
1655extern "C" __global__ void bms_flex_row_dense_block_partial(
1656    int                  n_rows,
1657    int                  r,
1658    int                  p_m,
1659    int                  p_g,
1660    int                  p_total,
1661    int                  h_block_start,
1662    int                  h_block_len,
1663    int                  w_block_start,
1664    int                  w_block_len,
1665    int                  h_primary_start,
1666    int                  w_primary_start,
1667    int                  rows_per_cta,
1668    const double * __restrict__ row_hessians,    // [n, r*r]
1669    const double * __restrict__ marginal_design, // [n, p_m]
1670    const double * __restrict__ logslope_design, // [n, p_g]
1671    double       * __restrict__ partial)         // [num_chunks, p_total, p_total]
1672{
1673    extern __shared__ double shmem[];
1674    int chunk = blockIdx.x;
1675    int tid   = threadIdx.x;
1676    int row_lo = chunk * rows_per_cta;
1677    int row_hi = row_lo + rows_per_cta;
1678    if (row_hi > n_rows) row_hi = n_rows;
1679
1680    int pp = p_total * p_total;
1681    double *acc = shmem; // CTA-private accumulator [p_total, p_total]
1682    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
1683    __syncthreads();
1684
1685    // Per-row work performed by thread 0 to avoid cross-thread RW
1686    // contention on `acc[]`. Per-row complexity is O(r * p_m + r * p_g
1687    // + r²): tractable because r ≤ 32 and p_m + p_g typically ≤ 64.
1688    // Tighter parallel implementations are possible (warp-stripe the
1689    // 4-way nested u-v-m-n loop) but Phase 6 is a debug-only path and
1690    // the simple version is easier to audit for correctness against
1691    // the host-side P_i pullback oracle.
1692    if (tid == 0) {
1693        for (int row = row_lo; row < row_hi; ++row) {
1694            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1695            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1696            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1697            for (int u = 0; u < r; ++u) {
1698                for (int v = 0; v < r; ++v) {
1699                    double huv = Hrow[u * r + v];
1700                    if (huv == 0.0) continue;
1701                    // For each (u, v), iterate (m, n) over the non-zero
1702                    // outer-product support of phi_u and phi_v.
1703                    // Build a small (offset, len, src_ptr) descriptor for
1704                    // each operand block as we go.
1705                    int m_off, m_len; const double *m_src; bool m_indicator;
1706                    int n_off, n_len; const double *n_src; bool n_indicator;
1707                    if (u == 0)      { m_off = 0;   m_len = p_m; m_src = mrow; m_indicator = false; }
1708                    else if (u == 1) { m_off = p_m; m_len = p_g; m_src = grow; m_indicator = false; }
1709                    else if (u - 2 < h_block_len) {
1710                                       m_off = h_block_start + (u - 2);
1711                                       m_len = 1;   m_src = NULL; m_indicator = true;
1712                    } else {
1713                                       m_off = w_block_start + (u - 2 - h_block_len);
1714                                       m_len = 1;   m_src = NULL; m_indicator = true;
1715                    }
1716                    if (v == 0)      { n_off = 0;   n_len = p_m; n_src = mrow; n_indicator = false; }
1717                    else if (v == 1) { n_off = p_m; n_len = p_g; n_src = grow; n_indicator = false; }
1718                    else if (v - 2 < h_block_len) {
1719                                       n_off = h_block_start + (v - 2);
1720                                       n_len = 1;   n_src = NULL; n_indicator = true;
1721                    } else {
1722                                       n_off = w_block_start + (v - 2 - h_block_len);
1723                                       n_len = 1;   n_src = NULL; n_indicator = true;
1724                    }
1725                    // accumulate huv * phi_u[m] * phi_v[n] into acc[m, n]
1726                    for (int mi = 0; mi < m_len; ++mi) {
1727                        double pm = m_indicator ? 1.0 : m_src[mi];
1728                        if (pm == 0.0) continue;
1729                        double scaled = huv * pm;
1730                        int m_idx = m_off + mi;
1731                        for (int ni = 0; ni < n_len; ++ni) {
1732                            double pn = n_indicator ? 1.0 : n_src[ni];
1733                            int n_idx = n_off + ni;
1734                            acc[m_idx * p_total + n_idx] += scaled * pn;
1735                        }
1736                    }
1737                }
1738            }
1739        }
1740    }
1741    __syncthreads();
1742
1743    // Write CTA accumulator out to global memory at its chunk slot.
1744    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
1745    for (int j = tid; j < pp; j += blockDim.x) {
1746        out_chunk[j] = acc[j];
1747    }
1748}
1749
// Fixed-order reduction of the per-CTA dense partials. One thread per entry
// of the [p_total, p_total] output sums that entry across all `num_chunks`
// chunk slots in ascending chunk order, so the result is deterministic for a
// given chunking. Threads past the last entry do nothing.
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int                  num_chunks,
    int                  p_total,
    const double * __restrict__ partial,
    double       * __restrict__ out)
{
    const int entry = blockIdx.x * blockDim.x + threadIdx.x;
    const int entries = p_total * p_total;
    if (entry < entries) {
        const size_t chunk_stride = (size_t)entries;
        size_t offset = (size_t)entry;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk, offset += chunk_stride) {
            total += partial[offset];
        }
        out[entry] = total;
    }
}
1765
// Per-CTA partial of diag(P_i^T H_i P_i) summed over rows, reading the
// SymmetricPackedUpper row Hessians. CTA blockIdx.x covers rows
// [blockIdx.x*rows_per_cta, min((blockIdx.x+1)*rows_per_cta, n_rows)) and
// owns the [p_total] slot partial[blockIdx.x*p_total ...]:
//   * marginal column j:   H_i[0,0] * X_i[j]^2
//   * logslope column j:   H_i[1,1] * G_i[j]^2
//   * h / w block entries: the matching primary diagonal H_i[u,u]
//     (accumulated serially by thread 0).
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int                  n_rows,
    int                  r,
    int                  p_m,
    int                  p_g,
    int                  p_total,
    int                  h_block_start,
    int                  h_block_len,
    int                  w_block_start,
    int                  w_block_len,
    int                  h_primary_start,
    int                  w_primary_start,
    int                  rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double       * __restrict__ partial)
{
    const int chunk  = blockIdx.x;
    const int tid    = threadIdx.x;
    const int stride = blockDim.x;
    const int first_row = chunk * rows_per_cta;
    int end_row = first_row + rows_per_cta;
    if (end_row > n_rows) end_row = n_rows;

    const int packed_len = r * (r + 1) / 2;
    double *slot = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += stride) {
        slot[j] = 0.0;
    }
    __syncthreads();

    for (int row = first_row; row < end_row; ++row) {
        const double *x_row  = marginal_design + (size_t)row * (size_t)p_m;
        const double *g_row  = logslope_design + (size_t)row * (size_t)p_g;
        const double *h_pack = row_hessians_packed + (size_t)row * (size_t)packed_len;

        // Diagonal entry (u, u) sits at bms_flex_packed_idx(u, u, r).
        const double h_marg  = h_pack[bms_flex_packed_idx(0, 0, r)];
        const double h_slope = h_pack[bms_flex_packed_idx(1, 1, r)];

        for (int j = tid; j < p_m; j += stride) {
            const double x = x_row[j];
            slot[j] += h_marg * x * x;
        }
        double *slope_slot = slot + p_m;
        for (int j = tid; j < p_g; j += stride) {
            const double g = g_row[j];
            slope_slot[j] += h_slope * g * g;
        }

        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                const int d = h_primary_start + k;
                slot[h_block_start + k] += h_pack[bms_flex_packed_idx(d, d, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                const int d = w_primary_start + k;
                slot[w_block_start + k] += h_pack[bms_flex_packed_idx(d, d, r)];
            }
        }
        // Keep the CTA in lock-step row by row.
        __syncthreads();
    }
}
1825"#;
1826
1827#[cfg(target_os = "linux")]
1828pub(crate) struct HvpKernelBackend {
1829    pub(crate) stream: Arc<CudaStream>,
1830    pub(crate) module: Arc<CudaModule>,
1831}
1832
1833#[cfg(target_os = "linux")]
1834impl HvpKernelBackend {
1835    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1836        static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1837        BACKEND
1838            .get_or_init(|| {
1839                gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1840                    let ptx = cudarc::nvrtc::compile_ptx(HVP_KERNEL_SOURCE).map_err(|err| {
1841                        GpuError::DriverCallFailed {
1842                            reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1843                        }
1844                    })?;
1845                    let module =
1846                        parts
1847                            .ctx
1848                            .load_module(ptx)
1849                            .map_err(|err| GpuError::DriverCallFailed {
1850                                reason: format!("bms_flex_row hvp module load failed: {err}"),
1851                            })?;
1852                    Ok(HvpKernelBackend {
1853                        stream: parts.stream.clone(),
1854                        module,
1855                    })
1856                })
1857            })
1858            .as_ref()
1859            .map_err(GpuError::clone)
1860    }
1861}
1862
/// Build a device-resident row-Hessian cache by launching the row kernel and
/// keeping the resulting `n × r²` slice resident on the device. Also uploads
/// the dense marginal + logslope design matrices so subsequent HVPs do not
/// re-upload them at every direction.
///
/// `marginal_design_row_major` and `logslope_design_row_major` must be
/// row-major `[n, p_m]` and `[n, p_g]` contiguous slices.
///
/// #461 absorber (additive Stage-1 influence block): the orthogonalization is
/// realized as **A2 — the marginal design widened to `[M | Z̃_infl]`** (see
/// `src/families/bms/block_specs.rs::widen_marginal_dense_with_influence`), NOT
/// a dedicated 5th primary coordinate. The absorber `+Z̃_infl·γ` is plain
/// additive into the marginal index `α(x)`, so γ lives inside `β_m` of the
/// widened block and `p_m` already counts the `p₁` influence columns. The row
/// kernel reads the marginal index from `block_states[0].eta` (which carries
/// `Z̃_infl·γ`) and pulls back through this same widened `marginal_design`, so
/// the absorber rides the existing primary-coordinate `u = 0` chain with **no
/// kernel-source change**: η, gradient, and Hessian match the CPU kernel
/// bit-for-bit precisely because `marginal_design` and `β_m` are the matched
/// (design, coefficient) pair the CPU path uses. The validation below pins
/// `marginal_design.len() == n·p_m` (with `p_m` widened), so a stale narrow
/// design against a widened `block.p_m` is rejected cleanly (CPU fallback)
/// rather than silently computing the wrong η. The absorber is dropped at
/// predict, where the marginal design is rebuilt without the influence columns,
/// so the predict-time `p_m` is narrow and this path is correct there too.
///
/// The stream is synchronized before returning, so the returned cache is
/// fully populated when the caller receives it.
///
/// # Errors
///
/// Returns `GpuError::DriverCallFailed` in any of these cases:
/// - `s_f_diagnostic_finite` rejects `s_f`;
/// - either design slice length disagrees with `n * p_m` / `n * p_g`;
/// - `primary.r != inputs.r`;
/// - `n`, `r`, `p_h` or `p_w` does not fit in `i32`;
/// - any upload, allocation, function load, launch, or stream synchronize
///   fails.
///
/// Errors from `inputs.validate()` and from the backend probes are
/// propagated unchanged.
#[cfg(target_os = "linux")]
pub(crate) fn launch_bms_flex_row_kernel_device_resident(
    inputs: BmsFlexRowKernelInputs<'_>,
    marginal_design_row_major: &[f64],
    logslope_design_row_major: &[f64],
    block: BmsFlexBlockLayout,
    primary: BmsFlexPrimaryLayout,
) -> Result<DeviceResidentRowHess, GpuError> {
    inputs.validate()?;
    if !s_f_diagnostic_finite(&inputs) {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: s_f must be positive and finite, got {}",
                inputs.s_f
            ),
        });
    }
    let n = inputs.n_rows;
    let r = inputs.r;
    // The designs stay resident for later HVPs, so their shapes must match
    // the block layout exactly (see the #461 note above).
    if marginal_design_row_major.len() != n * block.p_m {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
                marginal_design_row_major.len(),
                n * block.p_m
            ),
        });
    }
    if logslope_design_row_major.len() != n * block.p_g {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
                logslope_design_row_major.len(),
                n * block.p_g
            ),
        });
    }
    if primary.r != r {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: primary.r={} != inputs.r={}",
                primary.r, r
            ),
        });
    }

    // Ensure the row kernel backend is compiled & loaded (this also compiles
    // the HVP backend on first use so the caller surfaces failures here).
    let backend = RowKernelBackend::probe()?;
    HvpKernelBackend::probe()?;
    let stream = backend.stream.clone();

    // Copy a host slice into a fresh device buffer on `stream`, tagging any
    // failure with `label`.
    let upload_f64 = |slice: &[f64], label: &str| {
        stream
            .clone_htod(slice)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
            })
    };
    let upload_u32 = |slice: &[u32], label: &str| {
        stream
            .clone_htod(slice)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
            })
    };

    let d_q = upload_f64(inputs.q, "q")?;
    let d_b = upload_f64(inputs.b, "b")?;
    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
    let d_y = upload_f64(inputs.y, "y")?;
    let d_w = upload_f64(inputs.w, "w")?;
    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
    // Phase-4: optionally consume device-resident moments (no host upload).
    // `owned_host_moments` is only initialised on the Host branch; it keeps
    // the uploaded buffer alive for as long as `d_moments_ref` borrows it.
    let owned_host_moments: CudaSlice<f64>;
    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
        CellMomentsSource::Host(slice) => {
            owned_host_moments = upload_f64(slice, "cell_moments")?;
            &owned_host_moments
        }
        CellMomentsSource::Device(d) => *d,
    };
    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;

    // Unlike the per-cell inputs above, these two are kept in the returned
    // cache.
    let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
    let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;

    // Kernel outputs: n neglog values, n*r gradient entries, n*r*r Hessian
    // entries.
    let mut d_neglog = stream
        .alloc_zeros::<f64>(n)
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
        })?;
    let mut d_grad =
        stream
            .alloc_zeros::<f64>(n * r)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident alloc grad: {err}"),
            })?;
    let mut d_hess =
        stream
            .alloc_zeros::<f64>(n * r * r)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident alloc hess: {err}"),
            })?;

    let func = backend
        .module
        .load_function("bms_flex_row_kernel")
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident load_function: {err}"),
        })?;

    // One CTA per row, no dynamic shared memory. `n as u32` could truncate
    // for enormous n, but the `n_i32` conversion below rejects n > i32::MAX
    // before the launch happens.
    let cfg = LaunchConfig {
        grid_dim: (n as u32, 1, 1),
        block_dim: (ROW_KERNEL_THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
    })?;
    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
    })?;
    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
        reason: format!(
            "bms_flex_row device-resident: p_h={} exceeds i32 range",
            inputs.p_h
        ),
    })?;
    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
        reason: format!(
            "bms_flex_row device-resident: p_w={} exceeds i32 range",
            inputs.p_w
        ),
    })?;
    let s_f_val = inputs.s_f;

    // Arguments are bound positionally: this order must match the
    // `bms_flex_row_kernel` parameter list exactly.
    let mut builder = stream.launch_builder(&func);
    builder
        .arg(&n_i32)
        .arg(&r_i32)
        .arg(&p_h_i32)
        .arg(&p_w_i32)
        .arg(&s_f_val)
        .arg(&d_q)
        .arg(&d_b)
        .arg(&d_mu1)
        .arg(&d_mu2)
        .arg(&d_zobs)
        .arg(&d_y)
        .arg(&d_w)
        .arg(&d_offsets)
        .arg(&d_c0)
        .arg(&d_c1)
        .arg(&d_c2)
        .arg(&d_c3)
        .arg(&d_a)
        .arg(&d_aa)
        .arg(&d_r)
        .arg(&d_ar)
        .arg(&d_sbb)
        .arg(&d_sbh)
        .arg(&d_sbw)
        .arg(d_moments_ref)
        .arg(&d_chi)
        .arg(&d_xi)
        .arg(&d_rho)
        .arg(&d_tau)
        .arg(&d_ruv)
        .arg(&mut d_neglog)
        .arg(&mut d_grad)
        .arg(&mut d_hess);
    // SAFETY: same shape contract as `launch_linux`: every kernel parameter is
    // either a primitive scalar by-value, a const device pointer whose
    // capacity was validated by `inputs.validate()`, or one of the three
    // output buffers we just allocated with the expected element count.
    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident launch: {err}"),
    })?;
    // Wait for the kernel before any input buffer is dropped below.
    stream
        .synchronize()
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident synchronize: {err}"),
        })?;

    // The kernel writes neglog + grad alongside the row Hessian, but the
    // device-resident cache path keeps neither on the host: the fused
    // CPU gradient pass (the only consumer of host-side neglog/grad) is
    // dispatched only as a fallback when the GPU dense-block kernel does
    // not apply, and in that fallback the row kernel runs again locally.
    // Drop the device buffers so the allocation pool reclaims them
    // immediately rather than tying them to the cache's lifetime.
    drop(d_neglog);
    drop(d_grad);
    // Drop the per-cell uploads; keep d_hess + designs.
    drop(d_q);
    drop(d_b);
    drop(d_mu1);
    drop(d_mu2);
    drop(d_zobs);
    drop(d_y);
    drop(d_w);
    drop(d_offsets);
    drop(d_c0);
    drop(d_c1);
    drop(d_c2);
    drop(d_c3);
    drop(d_a);
    drop(d_aa);
    drop(d_r);
    drop(d_ar);
    drop(d_sbb);
    drop(d_sbh);
    drop(d_sbw);
    // `owned_host_moments` (if any) and the borrowed `d_moments_ref` both
    // go out of scope at the end of the function; the device-resident
    // moments owned by the caller stay alive.
    drop(d_chi);
    drop(d_xi);
    drop(d_rho);
    drop(d_tau);
    drop(d_ruv);

    // Device footprint, in bytes, of the buffers the cache retains: the row
    // Hessians plus both design matrices.
    let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
        * std::mem::size_of::<f64>()) as u64;
    Ok(DeviceResidentRowHess {
        hess: d_hess,
        marginal_design: d_marginal,
        logslope_design: d_logslope,
        n,
        r,
        block,
        primary,
        bytes,
    })
}
2141
/// Selects the joint-β partial kernel that a launch drives and how its
/// reduced `[1, p_total]` image is consumed. It determines whether the kernel
/// takes a direction vector `d_v`, and whether the result stays on the stream
/// or is downloaded to the host. Encoding these choices in one enum lets the
/// public entry points stay thin wrappers over a single launch helper.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Runs `bms_flex_row_hvp_partial` (`H · v` per row); the reduced result
    /// is left on-stream.
    HvpDeviceOut,
    /// Runs `bms_flex_row_diag_partial` (`diag(H)` per row); the reduced
    /// result is downloaded to the host.
    DiagonalHostOut,
}

#[cfg(target_os = "linux")]
impl BmsFlexRowLaunchMode {
    /// Symbol of this mode's partial kernel inside the HVP module.
    pub(crate) fn partial_kernel_name(self) -> &'static str {
        match self {
            Self::HvpDeviceOut => "bms_flex_row_hvp_partial",
            Self::DiagonalHostOut => "bms_flex_row_diag_partial",
        }
    }
}
2165
2166/// All scalar launch arguments for the joint-β partial kernel, derived once
2167/// from a [`DeviceResidentRowHess`]. The HVP and diagonal partial kernels take
2168/// the identical leading block-layout argument list (only the trailing
2169/// `d_v` / output pointers differ), so this captures the long, easy-to-
2170/// desynchronize prefix in a single place.
2171#[cfg(target_os = "linux")]
2172pub(crate) struct PreparedBmsFlexRowLaunchArgs {
2173    pub(crate) n_i32: i32,
2174    pub(crate) r_i32: i32,
2175    pub(crate) p_m_i32: i32,
2176    pub(crate) p_g_i32: i32,
2177    pub(crate) p_total_i32: i32,
2178    pub(crate) h_block_start: i32,
2179    pub(crate) h_block_len: i32,
2180    pub(crate) w_block_start: i32,
2181    pub(crate) w_block_len: i32,
2182    pub(crate) h_primary_start: i32,
2183    pub(crate) w_primary_start: i32,
2184    pub(crate) rows_per_cta: i32,
2185    pub(crate) num_chunks: usize,
2186}
2187
2188#[cfg(target_os = "linux")]
2189impl PreparedBmsFlexRowLaunchArgs {
2190    pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2191        let p_total = storage.block.p_total;
2192        let num_chunks = num_hvp_chunks(storage.n);
2193        PreparedBmsFlexRowLaunchArgs {
2194            n_i32: storage.n as i32,
2195            r_i32: storage.r as i32,
2196            p_m_i32: storage.block.p_m as i32,
2197            p_g_i32: storage.block.p_g as i32,
2198            p_total_i32: p_total as i32,
2199            h_block_start: storage
2200                .block
2201                .h
2202                .as_ref()
2203                .map(|r| r.start as i32)
2204                .unwrap_or(0),
2205            h_block_len: storage
2206                .block
2207                .h
2208                .as_ref()
2209                .map(|r| r.len() as i32)
2210                .unwrap_or(0),
2211            w_block_start: storage
2212                .block
2213                .w
2214                .as_ref()
2215                .map(|r| r.start as i32)
2216                .unwrap_or(0),
2217            w_block_len: storage
2218                .block
2219                .w
2220                .as_ref()
2221                .map(|r| r.len() as i32)
2222                .unwrap_or(0),
2223            h_primary_start: storage
2224                .primary
2225                .h
2226                .as_ref()
2227                .map(|r| r.start as i32)
2228                .unwrap_or(0),
2229            w_primary_start: storage
2230                .primary
2231                .w
2232                .as_ref()
2233                .map(|r| r.start as i32)
2234                .unwrap_or(0),
2235            rows_per_cta: HVP_ROWS_PER_CTA as i32,
2236            num_chunks,
2237        }
2238    }
2239}
2240
2241/// Shared partial+reduce engine behind every joint-β launcher.
2242///
2243/// Allocates the `[num_chunks, p_total]` partial buffer, loads the mode's
2244/// partial kernel plus the common `bms_flex_row_hvp_reduce`, builds both
2245/// launch configs from a single [`PreparedBmsFlexRowLaunchArgs`], launches the
2246/// partial kernel (binding `d_v` only for the HVP modes), and launches the
2247/// reduction into caller-supplied `d_out`.
2248///
2249/// **No** `synchronize()` or DtoH is performed here — the surrounding helper
2250/// decides whether to keep the result on-stream (device-resident PCG hot path)
2251/// or sync + download it to the host. `ctx` is a short error-context tag woven
2252/// into every `DriverCallFailed` reason so failures stay attributable to the
2253/// originating entry point.
2254#[cfg(target_os = "linux")]
2255pub(crate) fn run_bms_flex_row_partial_reduce(
2256    storage: &DeviceResidentRowHess,
2257    mode: BmsFlexRowLaunchMode,
2258    d_v: Option<&CudaSlice<f64>>,
2259    d_out: &mut CudaSlice<f64>,
2260    ctx: &str,
2261) -> Result<(), GpuError> {
2262    let backend = HvpKernelBackend::probe()?;
2263    let stream = backend.stream.clone();
2264    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2265    let p_total = storage.block.p_total;
2266
2267    let mut d_partial = stream
2268        .alloc_zeros::<f64>(args.num_chunks * p_total)
2269        .map_err(|err| GpuError::DriverCallFailed {
2270            reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2271        })?;
2272
2273    let partial_kernel_name = mode.partial_kernel_name();
2274    let part_func = backend
2275        .module
2276        .load_function(partial_kernel_name)
2277        .map_err(|err| GpuError::DriverCallFailed {
2278            reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2279        })?;
2280    let red_func = backend
2281        .module
2282        .load_function("bms_flex_row_hvp_reduce")
2283        .map_err(|err| GpuError::DriverCallFailed {
2284            reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2285        })?;
2286
2287    let cfg_part = LaunchConfig {
2288        grid_dim: (args.num_chunks as u32, 1, 1),
2289        block_dim: (HVP_THREADS, 1, 1),
2290        shared_mem_bytes: 0,
2291    };
2292    let mut builder = stream.launch_builder(&part_func);
2293    builder
2294        .arg(&args.n_i32)
2295        .arg(&args.r_i32)
2296        .arg(&args.p_m_i32)
2297        .arg(&args.p_g_i32)
2298        .arg(&args.p_total_i32)
2299        .arg(&args.h_block_start)
2300        .arg(&args.h_block_len)
2301        .arg(&args.w_block_start)
2302        .arg(&args.w_block_len)
2303        .arg(&args.h_primary_start)
2304        .arg(&args.w_primary_start)
2305        .arg(&args.rows_per_cta)
2306        .arg(&storage.hess)
2307        .arg(&storage.marginal_design)
2308        .arg(&storage.logslope_design);
2309    if let Some(d_v) = d_v {
2310        builder.arg(d_v);
2311    }
2312    builder.arg(&mut d_partial);
2313    // SAFETY: every device pointer above either comes from `storage` (whose
2314    // capacities were established by
2315    // `launch_bms_flex_row_kernel_device_resident`) or was just allocated here
2316    // (`d_partial` = num_chunks * p_total). `d_v`, when bound, is length-checked
2317    // by the calling adapter against `p_total`. The diagonal partial kernel
2318    // takes no direction argument, matching `d_v == None`. Scalar args are i32
2319    // by-value.
2320    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2321        reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2322    })?;
2323
2324    let red_threads: u32 = REDUCTION_THREADS;
2325    let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2326    let cfg_red = LaunchConfig {
2327        grid_dim: (red_blocks, 1, 1),
2328        block_dim: (red_threads, 1, 1),
2329        shared_mem_bytes: 0,
2330    };
2331    let num_chunks_i32 = args.num_chunks as i32;
2332    let mut builder = stream.launch_builder(&red_func);
2333    builder
2334        .arg(&num_chunks_i32)
2335        .arg(&args.p_total_i32)
2336        .arg(&d_partial)
2337        .arg(d_out);
2338    // SAFETY: `d_partial` was just populated by the partial kernel above;
2339    // `d_out` is `p_total` doubles (length-checked / allocated by the calling
2340    // adapter); both scalar args fit i32.
2341    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2342        reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2343    })?;
2344    // `d_partial` drops at end of fn; cudarc keeps the alloc alive until the
2345    // stream is done with it, so the reduce kernel completes safely.
2346    drop(d_partial);
2347    Ok(())
2348}
2349
2350/// Host-returning joint-β launcher shared by [`launch_bms_flex_row_hvp`]
2351/// ([`BmsFlexRowLaunchMode::HvpHostOut`]) and [`launch_bms_flex_row_diagonal`]
2352/// ([`BmsFlexRowLaunchMode::DiagonalHostOut`]).
2353///
2354/// Probes the backend, allocates the `p_total`-double output on its stream,
2355/// optionally uploads a host direction `v` (HVP modes; `None` for the diagonal
2356/// mode, which takes no direction), runs the shared partial+reduce engine, then
2357/// synchronizes and downloads the reduced image to a host `Vec<f64>`. `ctx`
2358/// tags every `DriverCallFailed` reason with the originating entry point.
2359#[cfg(target_os = "linux")]
2360pub(crate) fn launch_bms_flex_row_host(
2361    storage: &DeviceResidentRowHess,
2362    mode: BmsFlexRowLaunchMode,
2363    v: Option<&[f64]>,
2364    ctx: &str,
2365) -> Result<Vec<f64>, GpuError> {
2366    let p_total = storage.block.p_total;
2367    if let Some(v) = v {
2368        if v.len() != p_total {
2369            return Err(GpuError::DriverCallFailed {
2370                reason: format!(
2371                    "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2372                    v.len()
2373                ),
2374            });
2375        }
2376    }
2377
2378    let backend = HvpKernelBackend::probe()?;
2379    let stream = backend.stream.clone();
2380
2381    let d_v = match v {
2382        Some(v) => Some(
2383            stream
2384                .clone_htod(v)
2385                .map_err(|err| GpuError::DriverCallFailed {
2386                    reason: format!("bms_flex_row {ctx} upload v: {err}"),
2387                })?,
2388        ),
2389        None => None,
2390    };
2391    let mut d_out =
2392        stream
2393            .alloc_zeros::<f64>(p_total)
2394            .map_err(|err| GpuError::DriverCallFailed {
2395                reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2396            })?;
2397
2398    run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2399
2400    stream
2401        .synchronize()
2402        .map_err(|err| GpuError::DriverCallFailed {
2403            reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2404        })?;
2405    stream
2406        .clone_dtoh(&d_out)
2407        .map_err(|err| GpuError::DriverCallFailed {
2408            reason: format!("bms_flex_row {ctx} download out: {err}"),
2409        })
2410}
2411
2412#[cfg(target_os = "linux")]
2413pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2414    storage: &DeviceResidentRowHess,
2415    rhs_count: usize,
2416    v_rhs_len: usize,
2417    out_len: Option<usize>,
2418    ctx: &str,
2419) -> Result<usize, GpuError> {
2420    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2421        return Err(GpuError::DriverCallFailed {
2422            reason: format!(
2423                "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2424            ),
2425        });
2426    }
2427    let p_total = storage.block.p_total;
2428    let rhs_elems = rhs_count
2429        .checked_mul(p_total)
2430        .ok_or_else(|| GpuError::DriverCallFailed {
2431            reason: format!(
2432                "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2433            ),
2434        })?;
2435    if v_rhs_len != rhs_elems {
2436        return Err(GpuError::DriverCallFailed {
2437            reason: format!(
2438                "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2439            ),
2440        });
2441    }
2442    if let Some(out_len) = out_len
2443        && out_len != rhs_elems
2444    {
2445        return Err(GpuError::DriverCallFailed {
2446            reason: format!(
2447                "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2448            ),
2449        });
2450    }
2451    Ok(rhs_elems)
2452}
2453
2454/// Transient device bytes for a multi-RHS HVP launch, excluding persistent
2455/// row-Hessian/design storage. Scratch scales with
2456/// `rhs_count * num_chunks * p_total`, not `rhs_count * n * r * r`.
2457#[cfg(target_os = "linux")]
2458pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2459    n: usize,
2460    p_total: usize,
2461    rhs_count: usize,
2462) -> Result<u64, GpuError> {
2463    if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2464        return Err(GpuError::DriverCallFailed {
2465            reason: format!(
2466                "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2467            ),
2468        });
2469    }
2470    let num_chunks = num_hvp_chunks(n);
2471    let partial = rhs_count
2472        .checked_mul(num_chunks)
2473        .and_then(|v| v.checked_mul(p_total))
2474        .ok_or_else(|| GpuError::DriverCallFailed {
2475            reason: format!(
2476                "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2477            ),
2478        })?;
2479    let rhs_vectors = rhs_count
2480        .checked_mul(p_total)
2481        .and_then(|v| v.checked_mul(2))
2482        .ok_or_else(|| GpuError::DriverCallFailed {
2483            reason: format!(
2484                "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2485            ),
2486        })?;
2487    let elems = partial
2488        .checked_add(rhs_vectors)
2489        .ok_or_else(|| GpuError::DriverCallFailed {
2490            reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2491        })?;
2492    Ok((elems * std::mem::size_of::<f64>()) as u64)
2493}
2494
2495#[cfg(target_os = "linux")]
2496pub(crate) fn run_bms_flex_row_multi_partial_reduce(
2497    storage: &DeviceResidentRowHess,
2498    rhs_count: usize,
2499    d_v_rhs: &CudaSlice<f64>,
2500    d_out: &mut CudaSlice<f64>,
2501    ctx: &str,
2502) -> Result<(), GpuError> {
2503    let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
2504        storage,
2505        rhs_count,
2506        d_v_rhs.len(),
2507        Some(d_out.len()),
2508        ctx,
2509    )?;
2510    let backend = HvpKernelBackend::probe()?;
2511    let stream = backend.stream.clone();
2512    let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2513    let p_total = storage.block.p_total;
2514    let partial_len = rhs_count
2515        .checked_mul(args.num_chunks)
2516        .and_then(|v| v.checked_mul(p_total))
2517        .ok_or_else(|| GpuError::DriverCallFailed {
2518            reason: format!(
2519                "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
2520                args.num_chunks
2521            ),
2522        })?;
2523
2524    let mut d_partial =
2525        stream
2526            .alloc_zeros::<f64>(partial_len)
2527            .map_err(|err| GpuError::DriverCallFailed {
2528                reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
2529            })?;
2530    let part_func = backend
2531        .module
2532        .load_function("bms_flex_row_hvp_multi_partial")
2533        .map_err(|err| GpuError::DriverCallFailed {
2534            reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
2535        })?;
2536    let red_func = backend
2537        .module
2538        .load_function("bms_flex_row_hvp_multi_reduce")
2539        .map_err(|err| GpuError::DriverCallFailed {
2540            reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
2541        })?;
2542
2543    let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
2544        reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
2545    })?;
2546    let cfg_part = LaunchConfig {
2547        grid_dim: (args.num_chunks as u32, 1, 1),
2548        block_dim: (HVP_THREADS, 1, 1),
2549        shared_mem_bytes: 0,
2550    };
2551    let mut builder = stream.launch_builder(&part_func);
2552    builder
2553        .arg(&args.n_i32)
2554        .arg(&args.r_i32)
2555        .arg(&args.p_m_i32)
2556        .arg(&args.p_g_i32)
2557        .arg(&args.p_total_i32)
2558        .arg(&args.h_block_start)
2559        .arg(&args.h_block_len)
2560        .arg(&args.w_block_start)
2561        .arg(&args.w_block_len)
2562        .arg(&args.h_primary_start)
2563        .arg(&args.w_primary_start)
2564        .arg(&args.rows_per_cta)
2565        .arg(&rhs_count_i32)
2566        .arg(&storage.hess)
2567        .arg(&storage.marginal_design)
2568        .arg(&storage.logslope_design)
2569        .arg(d_v_rhs)
2570        .arg(&mut d_partial);
2571    // SAFETY: storage buffers were validated at construction; `d_v_rhs` and
2572    // `d_out` have rhs_count*p_total elements, `d_partial` has
2573    // rhs_count*num_chunks*p_total, and rhs_count is bounded by fixed shared
2574    // array sizes in the CUDA source.
2575    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2576        reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
2577    })?;
2578
2579    let red_threads: u32 = REDUCTION_THREADS;
2580    let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
2581    let cfg_red = LaunchConfig {
2582        grid_dim: (red_blocks, 1, 1),
2583        block_dim: (red_threads, 1, 1),
2584        shared_mem_bytes: 0,
2585    };
2586    let num_chunks_i32 = args.num_chunks as i32;
2587    let mut builder = stream.launch_builder(&red_func);
2588    builder
2589        .arg(&num_chunks_i32)
2590        .arg(&args.p_total_i32)
2591        .arg(&rhs_count_i32)
2592        .arg(&d_partial)
2593        .arg(d_out);
2594    // SAFETY: the reduce kernel reads the just-populated partial buffer and
2595    // writes exactly rhs_count*p_total output entries.
2596    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2597        reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
2598    })?;
2599    drop(d_partial);
2600    Ok(())
2601}
2602
2603/// Device-resident multi-RHS HVP. `v_rhs` is row-major
2604/// `[rhs_count, p_total]`; the returned vector has the same layout.
2605#[cfg(target_os = "linux")]
2606pub(crate) fn launch_bms_flex_row_hvp_multi(
2607    storage: &DeviceResidentRowHess,
2608    v_rhs: &[f64],
2609    rhs_count: usize,
2610) -> Result<Vec<f64>, GpuError> {
2611    let rhs_elems =
2612        validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2613    let backend = HvpKernelBackend::probe()?;
2614    let stream = backend.stream.clone();
2615    let d_v_rhs = stream
2616        .clone_htod(v_rhs)
2617        .map_err(|err| GpuError::DriverCallFailed {
2618            reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2619        })?;
2620    let mut d_out =
2621        stream
2622            .alloc_zeros::<f64>(rhs_elems)
2623            .map_err(|err| GpuError::DriverCallFailed {
2624                reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2625            })?;
2626    run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2627    stream
2628        .synchronize()
2629        .map_err(|err| GpuError::DriverCallFailed {
2630            reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2631        })?;
2632    stream
2633        .clone_dtoh(&d_out)
2634        .map_err(|err| GpuError::DriverCallFailed {
2635            reason: format!("bms_flex_row hvp_multi download out: {err}"),
2636        })
2637}
2638
2639/// Device-output HVP. Runs `bms_flex_row_hvp_partial(_packed)` +
2640/// `bms_flex_row_hvp_reduce` on the storage's stream against caller-supplied
2641/// device-resident `d_v` (length `p_total` doubles), writing the result into
2642/// caller-supplied `d_out` (also `p_total` doubles). **No** `synchronize()`
2643/// or DtoH is performed — the caller is responsible for stream ordering
2644/// against any consumer that reads `d_out`.
2645///
2646/// This is the device-resident PCG hot path (Block 9 Phase 5): keeping the
2647/// HVP output on the stream lets the outer PCG loop chain axpy / dot /
2648/// preconditioner kernels back-to-back without a per-iter device sync.
2649#[cfg(target_os = "linux")]
2650pub(crate) fn launch_bms_flex_row_hvp_into_device(
2651    storage: &DeviceResidentRowHess,
2652    d_v: &CudaSlice<f64>,
2653    d_out: &mut CudaSlice<f64>,
2654) -> Result<(), GpuError> {
2655    let p_total = storage.block.p_total;
2656    if d_v.len() != p_total {
2657        return Err(GpuError::DriverCallFailed {
2658            reason: format!(
2659                "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2660                d_v.len(),
2661                p_total
2662            ),
2663        });
2664    }
2665    if d_out.len() != p_total {
2666        return Err(GpuError::DriverCallFailed {
2667            reason: format!(
2668                "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2669                d_out.len(),
2670                p_total
2671            ),
2672        });
2673    }
2674    // On-stream output: the shared engine launches partial+reduce into the
2675    // caller's `d_out` and returns without sync/DtoH, so the outer PCG loop can
2676    // chain device kernels against the result.
2677    run_bms_flex_row_partial_reduce(
2678        storage,
2679        BmsFlexRowLaunchMode::HvpDeviceOut,
2680        Some(d_v),
2681        d_out,
2682        "hvp_into_device",
2683    )
2684}
2685
2686/// Launch the device-resident HVP kernel. Returns the host-side joint β image
2687/// of length `block.p_total`.
2688#[cfg(target_os = "linux")]
2689pub(crate) fn launch_bms_flex_row_hvp(
2690    storage: &DeviceResidentRowHess,
2691    v: &[f64],
2692) -> Result<Vec<f64>, GpuError> {
2693    launch_bms_flex_row_hvp_multi(storage, v, 1)
2694}
2695
2696/// Launch the device-resident diagonal kernel. Returns the host-side joint
2697/// β diagonal of length `block.p_total`.
2698#[cfg(target_os = "linux")]
2699pub(crate) fn launch_bms_flex_row_diagonal(
2700    storage: &DeviceResidentRowHess,
2701) -> Result<Vec<f64>, GpuError> {
2702    launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2703}
2704
/// Block 9 Phase 6 — hard cap on `p_total` for the dense joint-Hessian
/// device kernel.
///
/// The per-CTA shared-memory accumulator is `p_total² * 8` bytes. The V100
/// default per-block shared cap is 48 KiB, so the largest safe `p_total` is
/// `floor(sqrt(48 KiB / 8)) = floor(sqrt(6144)) = 78`. We round down to the
/// nearest multiple of 8, giving 72 (a 72² * 8 = 41 472-byte accumulator),
/// for predictable launch geometry.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;
2712
/// Number of rows each dense-block CTA processes. The dense launcher uses
/// `num_chunks = ceil(n / DENSE_BLOCK_ROWS_PER_CTA)` CTAs.
///
/// This is deliberately smaller than the HVP row tile `HVP_ROWS_PER_CTA`
/// (previously documented here as 256; NOTE(review): confirm against its
/// definition). The dense kernel's per-row inner loop is described as
/// `O(r² * (p_m + p_g + h_block_len + w_block_len))` rather than `O(r²)`.
/// Fewer rows per CTA keeps the per-CTA wall time short and lets grid
/// occupancy scale with `n`.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2720
2721/// Launch the Phase-6 dense joint-Hessian block kernel. Returns the
2722/// host-side `[p_total, p_total]` row-major joint H as a `Vec<f64>`
2723/// (length `p_total²`).
2724///
2725/// **Not the default Newton path.** Production Newton uses HVP (Phase 2)
2726/// and never materialises the full dense Hessian. This entry exists for:
2727///   * exact-REML logdet (`log|H|`) when the unified evaluator wants to
2728///     factor H directly instead of going through the matrix-free path;
2729///   * diagnostic dumps that compare the GPU dense build against the CPU
2730///     `BernoulliMarginalSlopeFamily::fused_gradient_dense` reference;
2731///   * small-`p` debug routes where it is cheaper to factor + solve dense
2732///     than to run a PCG.
2733///
2734/// The kernel rejects `p_total > DENSE_BLOCK_MAX_P` cleanly because the
2735/// per-CTA shared-memory accumulator (`p_total² * 8` bytes) would exceed
2736/// the V100 48 KiB/block cap above that threshold.
2737#[cfg(target_os = "linux")]
2738pub fn launch_bms_flex_row_dense_block(
2739    storage: &DeviceResidentRowHess,
2740) -> Result<Vec<f64>, GpuError> {
2741    let p_total = storage.block.p_total;
2742    if p_total == 0 {
2743        return Err(GpuError::DriverCallFailed {
2744            reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2745        });
2746    }
2747    if p_total > DENSE_BLOCK_MAX_P {
2748        return Err(GpuError::DriverCallFailed {
2749            reason: format!(
2750                "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2751                 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2752            ),
2753        });
2754    }
2755    let backend = HvpKernelBackend::probe()?;
2756    let stream = backend.stream.clone();
2757    let n = storage.n;
2758    let r = storage.r;
2759    let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2760    let num_chunks = n.div_ceil(rows_per_cta);
2761    let pp = p_total * p_total;
2762
2763    let mut d_partial =
2764        stream
2765            .alloc_zeros::<f64>(num_chunks * pp)
2766            .map_err(|err| GpuError::DriverCallFailed {
2767                reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2768            })?;
2769    let mut d_out = stream
2770        .alloc_zeros::<f64>(pp)
2771        .map_err(|err| GpuError::DriverCallFailed {
2772            reason: format!("bms_flex_row dense_block alloc out: {err}"),
2773        })?;
2774
2775    let part_func = backend
2776        .module
2777        .load_function("bms_flex_row_dense_block_partial")
2778        .map_err(|err| GpuError::DriverCallFailed {
2779            reason: format!("bms_flex_row dense_block load partial: {err}"),
2780        })?;
2781    let red_func = backend
2782        .module
2783        .load_function("bms_flex_row_dense_block_reduce")
2784        .map_err(|err| GpuError::DriverCallFailed {
2785            reason: format!("bms_flex_row dense_block load reduce: {err}"),
2786        })?;
2787
2788    let n_i32 = n as i32;
2789    let r_i32 = r as i32;
2790    let p_m_i32 = storage.block.p_m as i32;
2791    let p_g_i32 = storage.block.p_g as i32;
2792    let p_total_i32 = p_total as i32;
2793    let h_block_start = storage
2794        .block
2795        .h
2796        .as_ref()
2797        .map(|r| r.start as i32)
2798        .unwrap_or(0);
2799    let h_block_len = storage
2800        .block
2801        .h
2802        .as_ref()
2803        .map(|r| r.len() as i32)
2804        .unwrap_or(0);
2805    let w_block_start = storage
2806        .block
2807        .w
2808        .as_ref()
2809        .map(|r| r.start as i32)
2810        .unwrap_or(0);
2811    let w_block_len = storage
2812        .block
2813        .w
2814        .as_ref()
2815        .map(|r| r.len() as i32)
2816        .unwrap_or(0);
2817    let h_primary_start = storage
2818        .primary
2819        .h
2820        .as_ref()
2821        .map(|r| r.start as i32)
2822        .unwrap_or(0);
2823    let w_primary_start = storage
2824        .primary
2825        .w
2826        .as_ref()
2827        .map(|r| r.start as i32)
2828        .unwrap_or(0);
2829    let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2830    let num_chunks_u32 = num_chunks as u32;
2831
2832    // Per-CTA shmem accumulator: p_total² doubles.
2833    let shmem_bytes: u32 =
2834        u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2835            reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2836        })?;
2837
2838    let cfg_part = LaunchConfig {
2839        grid_dim: (num_chunks_u32, 1, 1),
2840        block_dim: (HVP_THREADS, 1, 1),
2841        shared_mem_bytes: shmem_bytes,
2842    };
2843    let mut builder = stream.launch_builder(&part_func);
2844    builder
2845        .arg(&n_i32)
2846        .arg(&r_i32)
2847        .arg(&p_m_i32)
2848        .arg(&p_g_i32)
2849        .arg(&p_total_i32)
2850        .arg(&h_block_start)
2851        .arg(&h_block_len)
2852        .arg(&w_block_start)
2853        .arg(&w_block_len)
2854        .arg(&h_primary_start)
2855        .arg(&w_primary_start)
2856        .arg(&rows_per_cta_i32)
2857        .arg(&storage.hess)
2858        .arg(&storage.marginal_design)
2859        .arg(&storage.logslope_design)
2860        .arg(&mut d_partial);
2861    // SAFETY: storage pointers have validated capacities; d_partial sized
2862    // num_chunks * pp doubles; dynamic shmem matches the kernel's `extern
2863    // __shared__` accumulator length.
2864    unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2865        reason: format!("bms_flex_row dense_block partial launch: {err}"),
2866    })?;
2867
2868    let red_threads: u32 = REDUCTION_THREADS;
2869    let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2870    let cfg_red = LaunchConfig {
2871        grid_dim: (red_blocks, 1, 1),
2872        block_dim: (red_threads, 1, 1),
2873        shared_mem_bytes: 0,
2874    };
2875    let num_chunks_i32 = num_chunks as i32;
2876    let mut builder = stream.launch_builder(&red_func);
2877    builder
2878        .arg(&num_chunks_i32)
2879        .arg(&p_total_i32)
2880        .arg(&d_partial)
2881        .arg(&mut d_out);
2882    // SAFETY: d_partial just populated, d_out is pp doubles.
2883    unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2884        reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2885    })?;
2886    stream
2887        .synchronize()
2888        .map_err(|err| GpuError::DriverCallFailed {
2889            reason: format!("bms_flex_row dense_block sync: {err}"),
2890        })?;
2891    stream
2892        .clone_dtoh(&d_out)
2893        .map_err(|err| GpuError::DriverCallFailed {
2894            reason: format!("bms_flex_row dense_block download: {err}"),
2895        })
2896}
2897
2898// Block 9 / V100-build unblock (2026-05-27): every test below either
2899// constructs `BmsFlexBlockLayout` / `BmsFlexPrimaryLayout` (Linux-only
2900// types) or drives a CUDA-dependent fixture, so gate the whole module
2901// `#[cfg(all(test, target_os = "linux"))]`. On macOS the structs are
2902// absent and these tests do not compile — the build.rs ban scanner
2903// explicitly rejects `#[cfg(any(..., test))]` on the struct definitions
2904// themselves as a dead-code escape hatch.
2905#[cfg(all(test, target_os = "linux"))]
2906mod tests {
2907    use super::*;
2908
2909    pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
2910        BmsFlexRowKernelInputs {
2911            n_rows: 1,
2912            r: 4,
2913            p_h: 1,
2914            p_w: 1,
2915            q: &buffers.q,
2916            b: &buffers.b,
2917            mu_1: &buffers.mu_1,
2918            mu_2: &buffers.mu_2,
2919            z_obs: &buffers.z_obs,
2920            y: &buffers.y,
2921            w: &buffers.w,
2922            s_f: 1.0,
2923            cell_offsets: &buffers.cell_offsets,
2924            cell_c0: &buffers.cell_c0,
2925            cell_c1: &buffers.cell_c1,
2926            cell_c2: &buffers.cell_c2,
2927            cell_c3: &buffers.cell_c3,
2928            cell_a: &buffers.cell_a,
2929            cell_aa: &buffers.cell_aa,
2930            cell_r: &buffers.cell_r,
2931            cell_ar: &buffers.cell_ar,
2932            cell_sbb: &buffers.cell_sbb,
2933            cell_sbh: &buffers.cell_sbh,
2934            cell_sbw: &buffers.cell_sbw,
2935            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
2936            chi_obs: &buffers.chi_obs,
2937            xi_obs: &buffers.xi_obs,
2938            rho_u: &buffers.rho_u,
2939            tau_u: &buffers.tau_u,
2940            r_uv: &buffers.r_uv,
2941        }
2942    }
2943
    /// Owned backing storage borrowed by `minimal_inputs`. `make_buffers`
    /// sizes it for a single row with `n_cells` cells.
    pub(crate) struct TestBuffers {
        // Per-row scalars; `make_buffers` gives each length 1 (one row).
        pub(crate) q: Vec<f64>,
        pub(crate) b: Vec<f64>,
        pub(crate) mu_1: Vec<f64>,
        pub(crate) mu_2: Vec<f64>,
        pub(crate) z_obs: Vec<f64>,
        pub(crate) y: Vec<f64>,
        pub(crate) w: Vec<f64>,
        // CSR-style row pointer into the per-cell arrays (length n_rows + 1;
        // `make_buffers` writes `[0, n_cells]`).
        pub(crate) cell_offsets: Vec<u32>,
        // Per-cell cubic coefficients C0..C3, one value per cell.
        pub(crate) cell_c0: Vec<f64>,
        pub(crate) cell_c1: Vec<f64>,
        pub(crate) cell_c2: Vec<f64>,
        pub(crate) cell_c3: Vec<f64>,
        // Four values per cell (`cells * 4`).
        pub(crate) cell_a: Vec<f64>,
        pub(crate) cell_aa: Vec<f64>,
        // Four values per cell per coordinate; sized `cells * (r - 1) * 4`
        // (presumably the u > 0 coordinates).
        pub(crate) cell_r: Vec<f64>,
        pub(crate) cell_ar: Vec<f64>,
        // Four values per cell (`cells * 4`).
        pub(crate) cell_sbb: Vec<f64>,
        // Four values per cell per h / w coordinate (`cells * p_h * 4` and
        // `cells * p_w * 4`).
        pub(crate) cell_sbh: Vec<f64>,
        pub(crate) cell_sbw: Vec<f64>,
        // `MOMENT_STRIDE` derivative moments per cell.
        pub(crate) cell_moments: Vec<f64>,
        // Observed-point quantities: chi/xi per row; rho/tau of length r;
        // r_uv of length r * r.
        pub(crate) chi_obs: Vec<f64>,
        pub(crate) xi_obs: Vec<f64>,
        pub(crate) rho_u: Vec<f64>,
        pub(crate) tau_u: Vec<f64>,
        pub(crate) r_uv: Vec<f64>,
    }
2971
2972    pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
2973        let cells = n_cells as usize;
2974        TestBuffers {
2975            q: vec![0.1; 1],
2976            b: vec![0.5; 1],
2977            mu_1: vec![0.3; 1],
2978            mu_2: vec![0.07; 1],
2979            z_obs: vec![0.0; 1],
2980            y: vec![1.0; 1],
2981            w: vec![1.0; 1],
2982            cell_offsets: vec![0, n_cells],
2983            cell_c0: vec![0.2; cells],
2984            cell_c1: vec![-0.1; cells],
2985            cell_c2: vec![0.05; cells],
2986            cell_c3: vec![-0.02; cells],
2987            cell_a: vec![0.1; cells * 4],
2988            cell_aa: vec![0.0; cells * 4],
2989            cell_r: vec![0.05; cells * (r - 1) * 4],
2990            cell_ar: vec![0.0; cells * (r - 1) * 4],
2991            cell_sbb: vec![0.0; cells * 4],
2992            cell_sbh: vec![0.0; cells * p_h * 4],
2993            cell_sbw: vec![0.0; cells * p_w * 4],
2994            cell_moments: vec![1.0; cells * MOMENT_STRIDE],
2995            chi_obs: vec![1.0; 1],
2996            xi_obs: vec![0.0; 1],
2997            rho_u: vec![0.0; r],
2998            tau_u: vec![0.0; r],
2999            r_uv: vec![0.0; r * r],
3000        }
3001    }
3002
3003    #[test]
3004    pub(crate) fn validate_accepts_minimal_inputs() {
3005        let buffers = make_buffers(2, 4, 1, 1);
3006        let inputs = minimal_inputs(&buffers);
3007        assert!(inputs.validate().is_ok());
3008    }
3009
3010    #[test]
3011    pub(crate) fn validate_rejects_r_above_max() {
3012        let r = MAX_R + 1;
3013        let p_h = (r - 2) / 2;
3014        let p_w = (r - 2) - p_h;
3015        let buffers = make_buffers(1, r, p_h, p_w);
3016        let bad_inputs = BmsFlexRowKernelInputs {
3017            r,
3018            p_h,
3019            p_w,
3020            rho_u: &buffers.rho_u, // length matches `r` we wrote
3021            tau_u: &buffers.tau_u,
3022            r_uv: &buffers.r_uv,
3023            cell_r: &buffers.cell_r,
3024            cell_ar: &buffers.cell_ar,
3025            cell_sbh: &buffers.cell_sbh,
3026            cell_sbw: &buffers.cell_sbw,
3027            ..minimal_inputs(&buffers)
3028        };
3029        let err = bad_inputs.validate().expect_err("r > MAX_R must fail");
3030        let msg = err.to_string();
3031        assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
3032    }
3033
3034    #[test]
3035    pub(crate) fn validate_rejects_mismatched_r_decomposition() {
3036        let buffers = make_buffers(1, 4, 1, 1);
3037        let bad_inputs = BmsFlexRowKernelInputs {
3038            r: 4,
3039            p_h: 1,
3040            p_w: 2, // inconsistent with r = 4
3041            ..minimal_inputs(&buffers)
3042        };
3043        let err = bad_inputs
3044            .validate()
3045            .expect_err("inconsistent r vs p_h+p_w must fail");
3046        let msg = err.to_string();
3047        assert!(msg.contains("p_h"), "got: {msg}");
3048        assert!(msg.contains("p_w"), "got: {msg}");
3049    }
3050
3051    #[test]
3052    pub(crate) fn validate_rejects_non_monotone_offsets() {
3053        // `minimal_inputs` hard-codes `n_rows = 1`, so the CSR-style row
3054        // pointer length is `n + 1 = 2`. Pin both `offsets[1] = total_cells`
3055        // and `cell_c0.len() = total_cells = 2` from `make_buffers(2, …)`,
3056        // then violate monotonicity by setting `offsets[0] > offsets[1]`;
3057        // every length / per-cell-count check is satisfied so the only
3058        // failure mode left is the monotonicity guard.
3059        let mut buffers = make_buffers(2, 4, 1, 1);
3060        buffers.cell_offsets = vec![5, 2];
3061        let inputs = minimal_inputs(&buffers);
3062        let err = inputs
3063            .validate()
3064            .expect_err("non-monotone offsets must fail");
3065        let msg = err.to_string();
3066        assert!(msg.contains("monotone"), "got: {msg}");
3067    }
3068
3069    #[test]
3070    pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
3071        let mut buffers = make_buffers(2, 4, 1, 1);
3072        buffers.cell_moments.pop(); // length now 2*10 - 1
3073        let inputs = minimal_inputs(&buffers);
3074        let err = inputs.validate().expect_err("short cell_moments must fail");
3075        let msg = err.to_string();
3076        assert!(msg.contains("cell_moments"), "got: {msg}");
3077    }
3078
3079    #[test]
3080    pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
3081        // Mac/Windows builds must surface a typed `DriverLibraryUnavailable`
3082        // rather than panicking or returning Ok. On Linux this test is
3083        // skipped because the kernel actually launches.
3084        #[cfg(target_os = "linux")]
3085        {
3086            // Linux builds may or may not have a device; the dispatcher
3087            // contract is that without a runtime, probe() returns
3088            // DriverLibraryUnavailable. Either outcome (NoDeviceKernel,
3089            // DriverLibraryUnavailable, or DriverCallFailed) is acceptable
3090            // here; success would mean the kernel actually ran which is a
3091            // V100-only outcome we don't gate the unit test on.
3092            let buffers = make_buffers(1, 4, 1, 1);
3093            let inputs = minimal_inputs(&buffers);
3094            match launch_bms_flex_row_kernel(inputs) {
3095                Ok(_) => { /* V100 host: real launch */ }
3096                Err(GpuError::DriverLibraryUnavailable { .. })
3097                | Err(GpuError::DriverCallFailed { .. })
3098                | Err(GpuError::DriverSymbolMissing { .. })
3099                | Err(GpuError::NoDeviceKernel { .. }) => { /* expected on CPU-only */ }
3100                Err(other) => panic!("unexpected GpuError variant: {other:?}"),
3101            }
3102        }
3103        #[cfg(not(target_os = "linux"))]
3104        {
3105            let buffers = make_buffers(1, 4, 1, 1);
3106            let inputs = minimal_inputs(&buffers);
3107            match launch_bms_flex_row_kernel(inputs) {
3108                Err(GpuError::DriverLibraryUnavailable { reason }) => {
3109                    assert!(
3110                        reason.contains("Linux-only"),
3111                        "expected Linux-only hint, got: {reason}"
3112                    );
3113                }
3114                other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
3115            }
3116        }
3117    }
3118
3119    #[test]
3120    pub(crate) fn s_f_must_be_positive_and_finite() {
3121        let buffers = make_buffers(1, 4, 1, 1);
3122        let mut inputs = minimal_inputs(&buffers);
3123        inputs.s_f = 0.0;
3124        match launch_bms_flex_row_kernel(inputs) {
3125            Err(GpuError::DriverCallFailed { reason }) => {
3126                assert!(reason.contains("s_f"), "got: {reason}");
3127            }
3128            other => panic!("expected DriverCallFailed for s_f=0, got {other:?}"),
3129        }
3130    }
3131
3132    // ── CPU oracle that mirrors ROW_KERNEL_BODY bit-for-bit ──────────────────
3133    //
3134    // `cpu_oracle_outputs` implements the same algebra as
3135    // `bms_flex_row_kernel` in ROW_KERNEL_BODY: per-cell `T_n` / `D` / `Q`
3136    // contractions, q-row override, IFT to `a_u` / `a_uv`, observed-point
3137    // assembly to `bar_e_u` / `bar_e_uv`, probit Mills, and the final
3138    // `out_grad` / `out_hess` writes. It takes the same
3139    // `BmsFlexRowKernelInputs` struct so a CUDA-equipped host can run both
3140    // paths off one bundle and check element-wise parity.
3141    //
    // Used by the GPU↔CPU parity test below and by the host-only oracle
    // tests (symmetry/finiteness and the Mills finite-difference lock). The
    // parity test skips on non-Linux hosts via cfg. The oracle itself is
    // platform-independent, so every host (macOS included) compiles and
    // exercises it.
3145
    /// `κ = 1 / (2π)`: the prefactor of the `T_n` and `D(R)` moment
    /// contractions.
    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
    /// `√2`: maps a probit argument `x` to the erfc argument `x / √2`.
    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// `1 / √(2π)`: the standard-normal density prefactor in `φ(x)`.
    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
3149
    /// Scaled complementary error function `erfcx(x) = e^{x²} · erfc(x)`,
    /// meant for non-negative arguments. In this module it is called from the
    /// deep-left-tail branch of `oracle_log_ndtr_and_mills`.
    ///
    /// Piecewise:
    /// - non-finite `x`: `+∞ → 0.0`; `−∞` (and NaN, which fails `x > 0.0`)
    ///   → `+∞`;
    /// - `x <= 0`: returns `1.0` (exact only at `x = 0`; negative arguments
    ///   are outside this helper's intended domain);
    /// - `0 < x < 26`: direct product `exp(x²) · erfc(x)`;
    /// - `x >= 26`: truncated large-`x` asymptotic series
    ///   `(1 / (x·√π)) · (1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸))`.
    ///
    /// The floating-point expressions are kept verbatim, operand order
    /// included, so the oracle stays numerically aligned with the device code
    /// it mirrors.
    pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
        if !x.is_finite() {
            return if x > 0.0 { 0.0 } else { f64::INFINITY };
        }
        if x <= 0.0 {
            return 1.0;
        }
        if x < 26.0 {
            let mut xx = x * x;
            // NOTE(review): x < 26 ⇒ x² < 676, so this 700 clamp never fires
            // on this path; presumably kept for parity with the device code.
            if xx > 700.0 {
                xx = 700.0;
            }
            return xx.exp() * gam_gpu::numerics_host::erfc(x);
        }
        // Asymptotic series. The k-th coefficient is (2k − 1)!! / 2^k, for k = 1..4.
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        // 1 / √π.
        let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
        inv * poly * inv_sqrt_pi
    }
3171
    /// Host mirror of the device `log_ndtr_and_mills`. It returns
    /// `(log Φ(x), φ(x) / Φ(x))`, where `Φ` is the standard normal CDF and
    /// `φ` its density. The second value is the `lambda` consumed by the
    /// probit layer of the oracle.
    ///
    /// - `x = +∞` → `(0, 0)`; `x = −∞` → `(−∞, +∞)`; NaN → `(NaN, NaN)`.
    /// - `x >= −37`: both values come from one `erfc(−x/√2)` evaluation. The
    ///   CDF is clamped to `[1e-300, 1]` before the log and the division.
    /// - `x < −37`: erfcx form with `u = −x/√2`, namely
    ///   `log Φ = −u² + ln(erfcx(u) / 2)` and `φ/Φ = √(2/π) / erfcx(u)`.
    ///   These stay finite where `erfc(u)` itself underflows.
    pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
        if x == f64::INFINITY {
            return (0.0, 0.0);
        }
        if x == f64::NEG_INFINITY {
            return (f64::NEG_INFINITY, f64::INFINITY);
        }
        if x.is_nan() {
            return (x, x);
        }
        // Single-algorithm region around and above 0. Both `log Φ(x)` and the
        // Mills ratio `φ/Φ` are computed from the SAME `erfc(-x/√2)` call, so
        // the oracle is C¹ across the x=0 seam (the prior split — erfcx-based
        // `-u²+ln(0.5·e^{u²}·erfc u)` for x<0 vs direct `ln(0.5·erfc(-x))` for
        // x≥0 — used two distinct float algorithms whose ~1e-7 disagreement at
        // x=0 corrupted a finite-difference reference straddling the seam, #838).
        //
        // The erfcx form is mathematically identical (e^{u²}·erfc(u)=erfcx(u),
        // so −u²+ln(0.5·erfcx u)=ln(0.5·erfc(−x/√2))); it is only needed deep in
        // the left tail (x ≲ −38), where `erfc(-x/√2)` underflows to 0 and the
        // exp/ln cancellation is the *only* way to keep `log Φ` finite. We move
        // the branch there, far from any region the kernel or its FD lock visits.
        const ORACLE_LEFT_TAIL_X: f64 = -37.0;
        if x >= ORACLE_LEFT_TAIL_X {
            // Φ(x) = ½·erfc(−x/√2), clamped into [1e-300, 1] so both the log
            // and the φ/Φ division stay finite.
            let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
            if cdf < 1e-300 {
                cdf = 1e-300;
            }
            if cdf > 1.0 {
                cdf = 1.0;
            }
            let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
            (cdf.ln(), pdf / cdf)
        } else {
            let u = -x / ORACLE_SQRT_2;
            let mut ex = oracle_erfcx_nonnegative(u);
            if ex < 1e-300 {
                ex = 1e-300;
            }
            let log_cdf = -u * u + (0.5 * ex).ln();
            // With φ = e^{−u²}/√(2π) and Φ = ½·e^{−u²}·erfcx(u), the ratio
            // φ/Φ reduces to √(2/π) / erfcx(u). The e^{−u²} factors cancel.
            let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
            (log_cdf, sqrt_2_over_pi / ex)
        }
    }
3216
    /// Host reference for `bms_flex_row_kernel`. It computes the same
    /// `(neglog, grad, hess)` outputs the device kernel writes, one row at a
    /// time.
    ///
    /// Output layout: `neglog` has `n_rows` entries, `grad` is row-major
    /// `n × r`, and `hess` is row-major `n × r × r`. Each row's Hessian block
    /// is written symmetrically: the upper triangle is computed and mirrored
    /// below the diagonal.
    ///
    /// Coordinate layout as read here:
    /// - coordinate 0 is the q-row, whose `F` entries are overridden from
    ///   `mu_1` / `mu_2`;
    /// - coordinate 1 pairs with the `cell_sbb` block;
    /// - coordinates `2..2 + p_h` pair with `cell_sbh`;
    /// - coordinates `2 + p_h..r` pair with `cell_sbw`.
    ///
    /// If a row's accumulated `F_a` is non-finite or `<= 0`, that row is
    /// NaN-filled in all three outputs instead of being divided by `F_a`.
    ///
    /// # Panics
    ///
    /// Panics if `cell_moments` is device-resident, which is only possible on
    /// Linux. It also panics through ordinary slice indexing if any buffer is
    /// shorter than `cell_offsets` / `r` / `p_h` / `p_w` imply. The tests in
    /// this module call `validate()` before invoking it.
    ///
    /// Intended to mirror the kernel statement-for-statement, including
    /// floating-point operation order, so kernel and oracle diverge only if
    /// one side breaks parity.
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        // The oracle only reads host memory. A device-resident moment buffer
        // means the fixture was wired incorrectly.
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                // NOTE: no `unsafe` is involved here. This CPU oracle is a
                // host-only sanity checker invoked from the tests in this
                // module. The kernel-launch path uses
                // `CellMomentsSource::Device(...)`, and the oracle must never
                // see that variant. Reaching this arm means a test mis-wired
                // its fixture, so fail loudly at the call site.
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // ── per-cell sweep: accumulate F_u, F_au, F_uv, F_a, F_aa.
            // Index 0 of f_u / f_au and row/column 0 of f_uv belong to the
            // q-row and are overwritten after the sweep.
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            // CSR-style row pointer: this row owns cells
            // `cell_offsets[row]..cell_offsets[row + 1]`.
            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                // This cell's MOMENT_STRIDE-long slice of derivative moments.
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // D(R) = κ · Σ_k R_k · m_k over the first four moments.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}, grouped by p + q.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                // Per-cell cubic blocks are 4 coefficients wide.
                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                // F_aa += H(A, A, AA) = D(AA) − Q(A, A).
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // `cell_r` / `cell_ar` hold one 4-wide cubic per non-q
                // coordinate, laid out as [cell][u − 1][4].
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    // F_au += H(A, R_u, AR_u).
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                // F_uv += H(R_u, R_v, S_uv) on the upper triangle 1 ≤ u ≤ v.
                // S blocks are read only on the b-row (u == 1): S_bb, then the
                // p_h S_bh blocks, then the p_w S_bw blocks. Elsewhere D(S) = 0.
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // q-row overrides (mirror kernel lines 691–700). They set
            // F_q = −mu_1, F_aq = 0 and F_qq = −mu_2, and zero every other
            // q-row / q-column entry of F_uv.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // Degenerate F_a ⇒ NaN-fill (mirror kernel lines 703–706).
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            let inv_fa = 1.0 / f_a;

            // IFT first/second order. a_u = −F_u / F_a, and the q-entry is
            // +mu_1 / F_a after the override. a_uv is computed on the upper
            // triangle and mirrored.
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Observed predictor jets, following the bar_e_u / bar_e_uv
            // formulas in the module header. Only the upper triangle of r_uv
            // is read.
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Probit Mills + final writes. s = 2y − 1 maps the label to ±1, so
            // both labels evaluate log Φ(s · e_obs). a_i and b_i are the first
            // and second derivatives of −w·log Φ(s·e) with respect to e, and
            // the chain rule through bar_e gives grad and hess.
            let y = inputs.y[row];
            let w = inputs.w[row];
            let s = 2.0 * y - 1.0;
            let e_obs = bar_e_u[0];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3427
3428    /// Build a non-trivial fixture: `n = 4` rows, `r = 5` (p_h = 2, p_w = 1),
3429    /// 2–4 cells per row, distinct values so a structural bug in either path
3430    /// can't be masked by accidental cancellation.
3431    pub(crate) fn make_parity_buffers() -> TestBuffers {
3432        let n = 4_usize;
3433        let r = 5_usize;
3434        let p_h = 2_usize;
3435        let p_w = 1_usize;
3436        // Per-row cell counts: 2, 3, 4, 2 → total 11 cells.
3437        let row_cells: [u32; 4] = [2, 3, 4, 2];
3438        let mut cell_offsets = vec![0_u32; n + 1];
3439        for i in 0..n {
3440            cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
3441        }
3442        let total_cells = cell_offsets[n] as usize;
3443
3444        // Deterministic but varied generators (LCG-ish so each slot is distinct).
3445        let f = |seed: usize| -> f64 {
3446            let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
3447            0.1 + 0.4 * x
3448        };
3449
3450        let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
3451        let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
3452        let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
3453        let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
3454        let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
3455        let y = [1.0, 0.0, 1.0, 0.0].to_vec();
3456        let w = vec![1.0; n];
3457
3458        let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
3459        let cell_c1 = (0..total_cells)
3460            .map(|c| -f(c + 2002) * 0.5)
3461            .collect::<Vec<_>>();
3462        let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
3463        let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();
3464
3465        let cell_a = (0..total_cells * 4)
3466            .map(|i| f(i + 5005) * 0.3)
3467            .collect::<Vec<_>>();
3468        let cell_aa = (0..total_cells * 4)
3469            .map(|i| f(i + 6006) * 0.1)
3470            .collect::<Vec<_>>();
3471        let cell_r = (0..total_cells * (r - 1) * 4)
3472            .map(|i| f(i + 7007) * 0.2)
3473            .collect::<Vec<_>>();
3474        let cell_ar = (0..total_cells * (r - 1) * 4)
3475            .map(|i| f(i + 8008) * 0.05)
3476            .collect::<Vec<_>>();
3477        let cell_sbb = (0..total_cells * 4)
3478            .map(|i| f(i + 9009) * 0.08)
3479            .collect::<Vec<_>>();
3480        let cell_sbh = (0..total_cells * p_h * 4)
3481            .map(|i| f(i + 10_010) * 0.07)
3482            .collect::<Vec<_>>();
3483        let cell_sbw = (0..total_cells * p_w * 4)
3484            .map(|i| f(i + 11_011) * 0.06)
3485            .collect::<Vec<_>>();
3486        let cell_moments = (0..total_cells * MOMENT_STRIDE)
3487            .map(|i| 0.4 + 0.1 * f(i + 12_012))
3488            .collect::<Vec<_>>();
3489
3490        let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
3491        let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
3492        let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
3493        let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
3494        let r_uv = (0..n * r * r)
3495            .map(|i| 0.04 * f(i + 15_015))
3496            .collect::<Vec<_>>();
3497
3498        TestBuffers {
3499            q,
3500            b,
3501            mu_1,
3502            mu_2,
3503            z_obs,
3504            y,
3505            w,
3506            cell_offsets,
3507            cell_c0,
3508            cell_c1,
3509            cell_c2,
3510            cell_c3,
3511            cell_a,
3512            cell_aa,
3513            cell_r,
3514            cell_ar,
3515            cell_sbb,
3516            cell_sbh,
3517            cell_sbw,
3518            cell_moments,
3519            chi_obs,
3520            xi_obs,
3521            rho_u,
3522            tau_u,
3523            r_uv,
3524        }
3525    }
3526
3527    pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3528        BmsFlexRowKernelInputs {
3529            n_rows: 4,
3530            r: 5,
3531            p_h: 2,
3532            p_w: 1,
3533            q: &buffers.q,
3534            b: &buffers.b,
3535            mu_1: &buffers.mu_1,
3536            mu_2: &buffers.mu_2,
3537            z_obs: &buffers.z_obs,
3538            y: &buffers.y,
3539            w: &buffers.w,
3540            s_f: 1.0,
3541            cell_offsets: &buffers.cell_offsets,
3542            cell_c0: &buffers.cell_c0,
3543            cell_c1: &buffers.cell_c1,
3544            cell_c2: &buffers.cell_c2,
3545            cell_c3: &buffers.cell_c3,
3546            cell_a: &buffers.cell_a,
3547            cell_aa: &buffers.cell_aa,
3548            cell_r: &buffers.cell_r,
3549            cell_ar: &buffers.cell_ar,
3550            cell_sbb: &buffers.cell_sbb,
3551            cell_sbh: &buffers.cell_sbh,
3552            cell_sbw: &buffers.cell_sbw,
3553            cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3554            chi_obs: &buffers.chi_obs,
3555            xi_obs: &buffers.xi_obs,
3556            rho_u: &buffers.rho_u,
3557            tau_u: &buffers.tau_u,
3558            r_uv: &buffers.r_uv,
3559        }
3560    }
3561
3562    /// Symmetry + finiteness of the CPU oracle. Runs on every host (Linux,
3563    /// macOS, CPU CI) since the oracle is platform-independent. Guarantees the
3564    /// reference path used by the GPU parity test is itself well-formed.
3565    #[test]
3566    pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
3567        let buffers = make_parity_buffers();
3568        let inputs = parity_inputs(&buffers);
3569        inputs
3570            .validate()
3571            .expect("parity fixture must satisfy validate()");
3572        let out = cpu_oracle_outputs(&inputs);
3573        let n = inputs.n_rows;
3574        let r = inputs.r;
3575        assert_eq!(out.neglog.len(), n);
3576        assert_eq!(out.grad.len(), n * r);
3577        assert_eq!(out.hess.len(), n * r * r);
3578        for row in 0..n {
3579            assert!(
3580                out.neglog[row].is_finite(),
3581                "row {row}: neglog must be finite, got {}",
3582                out.neglog[row]
3583            );
3584            for u in 0..r {
3585                let g = out.grad[row * r + u];
3586                assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
3587                for v in 0..r {
3588                    let huv = out.hess[row * r * r + u * r + v];
3589                    let hvu = out.hess[row * r * r + v * r + u];
3590                    assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
3591                    assert_eq!(
3592                        huv.to_bits(),
3593                        hvu.to_bits(),
3594                        "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
3595                    );
3596                }
3597            }
3598        }
3599    }
3600
3601    /// Independent finite-difference correctness lock on the oracle's probit
3602    /// Mills layer — the most optimizer-sensitive, drift-prone term in the
3603    /// whole row kernel (issue #415: "third/fourth-order derivative
3604    /// contractions drift silently … formulas are complex and
3605    /// optimizer-sensitive"). The device kernel's `bms_flex_row_kernel` and
3606    /// the host oracle's `cpu_oracle_outputs` both close out with the same
3607    /// Mills algebra:
3608    ///
3609    /// ```text
3610    ///     m       = s · e_obs ;  s = 2y − 1
3611    ///     A       = −w · s · λ(m)
3612    ///     B       =  w · λ(m) · (m + λ(m))
3613    ///     neglog  = −w · log Φ(s · e_obs)
3614    ///     g_u     = A · bar_e_u
3615    ///     H_uv    = B · bar_e_u · bar_e_v + A · bar_e_uv
3616    /// ```
3617    ///
3618    /// Holding the observed jets `bar_e` fixed, the row neglog is a function
3619    /// of the single scalar `e := bar_e_u[0] = e_obs`, and by the assembled
3620    /// formula `∂neglog/∂e = A` and `∂²neglog/∂e² = B`. This test reconstructs
3621    /// `A`, `B`, and `neglog` exactly as the oracle does (same
3622    /// `oracle_log_ndtr_and_mills`, same sign convention), then verifies the
3623    /// analytic `A`/`B` against high-order central differences of
3624    /// `e ↦ −w · log Φ(s·e)`. A drift in the kernel's Mills derivatives —
3625    /// which the device-parity test cannot catch because it checks the kernel
3626    /// *against the same (possibly-wrong) oracle* — fails here on every host,
3627    /// CUDA or not. Bounds are the genuine fifth-order central-difference
3628    /// truncation floor; they are not weakened to pass.
3629    #[test]
3630    pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
3631        // Probit neglog as the oracle assembles it, as a function of the
3632        // observed scalar predictor `e` with weight `w` and label `y`.
3633        let neglog_of = |e: f64, y: f64, w: f64| -> f64 {
3634            let s = 2.0 * y - 1.0;
3635            let (log_cdf, _) = oracle_log_ndtr_and_mills(s * e);
3636            -w * log_cdf
3637        };
3638        // Analytic first/second derivatives wrt `e` — the exact `A`/`B` the
3639        // kernel writes into `grad`/`hess`.
3640        let ab_of = |e: f64, y: f64, w: f64| -> (f64, f64) {
3641            let s = 2.0 * y - 1.0;
3642            let m_arg = s * e;
3643            let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
3644            let a_i = -w * s * lambda;
3645            let b_i = w * lambda * (m_arg + lambda);
3646            (a_i, b_i)
3647        };
3648
3649        // Sweep both labels (s = ±1), both tails of the predictor, and a
3650        // non-unit weight so every sign/scale path of the Mills algebra is
3651        // exercised. Points stay clear of the deep-tail asymptote where a
3652        // central-difference reference loses its own accuracy.
3653        let cases: [(f64, f64, f64); 12] = [
3654            (-1.6, 1.0, 1.0),
3655            (-0.7, 1.0, 1.0),
3656            (0.0, 1.0, 1.0),
3657            (0.9, 1.0, 1.0),
3658            (1.8, 1.0, 1.0),
3659            (-1.4, 0.0, 1.0),
3660            (-0.3, 0.0, 1.0),
3661            (0.0, 0.0, 1.0),
3662            (0.6, 0.0, 1.0),
3663            (1.5, 0.0, 1.0),
3664            (0.4, 1.0, 0.75),
3665            (-0.8, 0.0, 1.3),
3666        ];
3667        // Fifth-order central stencils; `h` chosen near the f64 sweet spot for
3668        // first/second derivatives of a smooth O(1) function.
3669        let h = 1e-3_f64;
3670        for (e, y, w) in cases {
3671            let (a_ana, b_ana) = ab_of(e, y, w);
3672
3673            let fp2 = neglog_of(e + 2.0 * h, y, w);
3674            let fp1 = neglog_of(e + h, y, w);
3675            let f0 = neglog_of(e, y, w);
3676            let fm1 = neglog_of(e - h, y, w);
3677            let fm2 = neglog_of(e - 2.0 * h, y, w);
3678
3679            // 5-point central first derivative: O(h⁴).
3680            let d1_fd = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * h);
3681            // 5-point central second derivative: O(h⁴).
3682            let d2_fd = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * h * h);
3683
3684            let a_abs = (a_ana - d1_fd).abs();
3685            let a_rel = a_abs / a_ana.abs().max(1.0);
3686            assert!(
3687                a_abs <= 5e-8 || a_rel <= 5e-8,
3688                "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
3689                 analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
3690            );
3691
3692            let b_abs = (b_ana - d2_fd).abs();
3693            let b_rel = b_abs / b_ana.abs().max(1.0);
3694            assert!(
3695                b_abs <= 5e-6 || b_rel <= 5e-6,
3696                "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
3697                 analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
3698            );
3699        }
3700    }
3701
3702    /// CPU↔GPU parity. Only runs end-to-end on a Linux host with a CUDA
3703    /// runtime; skips with a clear `eprintln!` on every other host so the
3704    /// always-on test suite stays green on the macOS dev box and CPU CI.
3705    ///
3706    /// On a CUDA host: drives the kernel through `launch_bms_flex_row_kernel`
3707    /// and the same `BmsFlexRowKernelInputs` through `cpu_oracle_outputs`,
3708    /// then asserts every element of `neglog`, `grad`, and `hess` agrees
3709    /// within `|Δ| <= 1e-8 + 1e-8·|cpu|` (absolute-or-relative).
3710    #[test]
3711    pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
3712        #[cfg(not(target_os = "linux"))]
3713        {
3714            eprintln!(
3715                "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
3716                 (CPU oracle exercised by sibling test)"
3717            );
3718            return;
3719        }
3720        #[cfg(target_os = "linux")]
3721        {
3722            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
3723                eprintln!(
3724                    "[bms_flex_row parity] no CUDA runtime — skipping device \
3725                     parity (CPU oracle exercised by sibling test)"
3726                );
3727                return;
3728            };
3729            let buffers = make_parity_buffers();
3730            let inputs_cpu = parity_inputs(&buffers);
3731            inputs_cpu
3732                .validate()
3733                .expect("parity fixture must satisfy validate()");
3734            let cpu_out = cpu_oracle_outputs(&inputs_cpu);
3735
3736            // Launch the device kernel against the same inputs.
3737            let inputs_gpu = parity_inputs(&buffers);
3738            let gpu_out = match launch_bms_flex_row_kernel(inputs_gpu) {
3739                Ok(out) => out,
3740                Err(err) => panic!(
3741                    "[bms_flex_row parity] launch failed on CUDA-selected host; \
3742                     device/oracle parity must fail loudly on GPU CI: {err}"
3743                ),
3744            };
3745
3746            let n = inputs_cpu.n_rows;
3747            let r = inputs_cpu.r;
3748            let tol_abs = 1e-8_f64;
3749            let tol_rel = 1e-8_f64;
3750            let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
3751                if cpu.is_nan() || gpu.is_nan() {
3752                    assert!(
3753                        cpu.is_nan() && gpu.is_nan(),
3754                        "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
3755                    );
3756                    return;
3757                }
3758                let diff = (cpu - gpu).abs();
3759                let tol = tol_abs + tol_rel * cpu.abs();
3760                assert!(
3761                    diff <= tol,
3762                    "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
3763                     cpu={cpu:.17e}, gpu={gpu:.17e}"
3764                );
3765            };
3766            assert_eq!(cpu_out.neglog.len(), gpu_out.neglog.len());
3767            assert_eq!(cpu_out.grad.len(), gpu_out.grad.len());
3768            assert_eq!(cpu_out.hess.len(), gpu_out.hess.len());
3769            for (i, (&c, &g)) in cpu_out.neglog.iter().zip(gpu_out.neglog.iter()).enumerate() {
3770                check_close("neglog", i, c, g);
3771            }
3772            for (i, (&c, &g)) in cpu_out.grad.iter().zip(gpu_out.grad.iter()).enumerate() {
3773                check_close("grad", i, c, g);
3774            }
3775            for (i, (&c, &g)) in cpu_out.hess.iter().zip(gpu_out.hess.iter()).enumerate() {
3776                check_close("hess", i, c, g);
3777            }
3778            // Spot-check exact symmetry on the GPU Hessian too.
3779            for row in 0..n {
3780                for u in 0..r {
3781                    for v in 0..r {
3782                        let a = gpu_out.hess[row * r * r + u * r + v];
3783                        let bb = gpu_out.hess[row * r * r + v * r + u];
3784                        assert_eq!(
3785                            a.to_bits(),
3786                            bb.to_bits(),
3787                            "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
3788                        );
3789                    }
3790                }
3791            }
3792        }
3793    }
3794
3795    #[test]
3796    pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
3797        // Guarantee the maintainer-facing parity reference comment survives
3798        // refactors of the NVRTC kernel source — the dispatcher wave that
3799        // wires this to bms_flex.rs cross-checks parity against the CPU
3800        // function named here.
3801        #[cfg(target_os = "linux")]
3802        assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
3803        #[cfg(target_os = "linux")]
3804        assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
3805    }
3806
3807    // ── Phase-3 HVP / diagonal CPU oracles + GPU parity tests ────────────────
3808
3809    /// CPU oracle for [`launch_bms_flex_row_hvp`]. Mirrors the device kernel
3810    /// element-for-element so the GPU parity test runs against the same algebra.
3811    pub(crate) fn cpu_oracle_bms_flex_row_hvp(
3812        row_hessians: &[f64],
3813        marginal_design: &[f64],
3814        logslope_design: &[f64],
3815        block: &BmsFlexBlockLayout,
3816        primary: &BmsFlexPrimaryLayout,
3817        n: usize,
3818        v: &[f64],
3819    ) -> Vec<f64> {
3820        let r = primary.r;
3821        let p_m = block.p_m;
3822        let p_g = block.p_g;
3823        assert_eq!(v.len(), block.p_total);
3824        assert_eq!(row_hessians.len(), n * r * r);
3825        assert_eq!(marginal_design.len(), n * p_m);
3826        assert_eq!(logslope_design.len(), n * p_g);
3827        let mut out = vec![0.0_f64; block.p_total];
3828        let mut row_dir = vec![0.0_f64; r];
3829        let mut action = vec![0.0_f64; r];
3830        for row in 0..n {
3831            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3832            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3833            let mut acc_q = 0.0_f64;
3834            for j in 0..p_m {
3835                acc_q += mrow[j] * v[j];
3836            }
3837            let mut acc_g = 0.0_f64;
3838            for j in 0..p_g {
3839                acc_g += grow[j] * v[p_m + j];
3840            }
3841            row_dir[0] = acc_q;
3842            row_dir[1] = acc_g;
3843            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3844                for (k, ii) in prange.clone().enumerate() {
3845                    row_dir[ii] = v[brange.start + k];
3846                }
3847            }
3848            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3849                for (k, ii) in prange.clone().enumerate() {
3850                    row_dir[ii] = v[brange.start + k];
3851                }
3852            }
3853            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3854            for u in 0..r {
3855                let mut acc = 0.0_f64;
3856                for v_idx in 0..r {
3857                    acc += h_slice[u * r + v_idx] * row_dir[v_idx];
3858                }
3859                action[u] = acc;
3860            }
3861            let a0 = action[0];
3862            for j in 0..p_m {
3863                out[j] += a0 * mrow[j];
3864            }
3865            let a1 = action[1];
3866            for j in 0..p_g {
3867                out[p_m + j] += a1 * grow[j];
3868            }
3869            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3870                for (k, ii) in prange.clone().enumerate() {
3871                    out[brange.start + k] += action[ii];
3872                }
3873            }
3874            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3875                for (k, ii) in prange.clone().enumerate() {
3876                    out[brange.start + k] += action[ii];
3877                }
3878            }
3879        }
3880        out
3881    }
3882
3883    pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
3884        row_hessians: &[f64],
3885        marginal_design: &[f64],
3886        logslope_design: &[f64],
3887        block: &BmsFlexBlockLayout,
3888        primary: &BmsFlexPrimaryLayout,
3889        n: usize,
3890    ) -> Vec<f64> {
3891        let r = primary.r;
3892        let p_m = block.p_m;
3893        let p_g = block.p_g;
3894        let mut out = vec![0.0_f64; block.p_total];
3895        for row in 0..n {
3896            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3897            let h00 = h_slice[0];
3898            let h11 = h_slice[r + 1];
3899            let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3900            let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3901            for j in 0..p_m {
3902                out[j] += h00 * mrow[j] * mrow[j];
3903            }
3904            for j in 0..p_g {
3905                out[p_m + j] += h11 * grow[j] * grow[j];
3906            }
3907            if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3908                for (k, ii) in prange.clone().enumerate() {
3909                    out[brange.start + k] += h_slice[ii * r + ii];
3910                }
3911            }
3912            if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3913                for (k, ii) in prange.clone().enumerate() {
3914                    out[brange.start + k] += h_slice[ii * r + ii];
3915                }
3916            }
3917        }
3918        out
3919    }
3920
3921    /// Hand-construct a small symmetric per-row Hessian + small designs and
3922    /// verify the CPU oracle satisfies the expected algebra. Platform-
3923    /// independent (runs on macOS / Linux without CUDA).
3924    #[test]
3925    pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
3926        let n = 4_usize;
3927        let r = 4_usize; // q, logslope, h(1), w(1)
3928        let p_m = 2_usize;
3929        let p_g = 2_usize;
3930        let p_h_dim = 1_usize;
3931        let p_w_dim = 1_usize;
3932        let p_total = p_m + p_g + p_h_dim + p_w_dim;
3933        let block = BmsFlexBlockLayout {
3934            p_m,
3935            p_g,
3936            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
3937            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
3938            p_total,
3939        };
3940        let primary = BmsFlexPrimaryLayout {
3941            h: Some(2..3),
3942            w: Some(3..4),
3943            r,
3944        };
3945        // Symmetric per-row Hessian: H_row[u,v] = (row + 1) * (1 + u + 2v) symmetrised.
3946        let mut row_hessians = vec![0.0_f64; n * r * r];
3947        for row in 0..n {
3948            for u in 0..r {
3949                for v in u..r {
3950                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
3951                    row_hessians[row * r * r + u * r + v] = val;
3952                    row_hessians[row * r * r + v * r + u] = val;
3953                }
3954            }
3955        }
3956        let mut marginal = vec![0.0_f64; n * p_m];
3957        for row in 0..n {
3958            for j in 0..p_m {
3959                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
3960            }
3961        }
3962        let mut logslope = vec![0.0_f64; n * p_g];
3963        for row in 0..n {
3964            for j in 0..p_g {
3965                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
3966            }
3967        }
3968        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
3969        let out = cpu_oracle_bms_flex_row_hvp(
3970            &row_hessians,
3971            &marginal,
3972            &logslope,
3973            &block,
3974            &primary,
3975            n,
3976            &v,
3977        );
3978        // Hand check the first marginal slot: out[0] = Σ_row action[0]·mrow[0].
3979        let mut expect_out_0 = 0.0_f64;
3980        for row in 0..n {
3981            let mrow = &marginal[row * p_m..(row + 1) * p_m];
3982            let grow = &logslope[row * p_g..(row + 1) * p_g];
3983            let mut row_dir = vec![0.0_f64; r];
3984            row_dir[0] = mrow[0] * v[0] + mrow[1] * v[1];
3985            row_dir[1] = grow[0] * v[p_m] + grow[1] * v[p_m + 1];
3986            row_dir[2] = v[p_m + p_g];
3987            row_dir[3] = v[p_m + p_g + p_h_dim];
3988            let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3989            let mut action0 = 0.0_f64;
3990            // h_slice is the row-major r×r Hessian for this row; we want
3991            // row 0, i.e. entries (0, vv) for vv in 0..r, which lives at
3992            // `vv` in the flat layout.
3993            for vv in 0..r {
3994                action0 += h_slice[vv] * row_dir[vv];
3995            }
3996            expect_out_0 += action0 * mrow[0];
3997        }
3998        assert!(
3999            (out[0] - expect_out_0).abs() < 1e-12,
4000            "cpu oracle HVP out[0] mismatch: {} vs hand-check {}",
4001            out[0],
4002            expect_out_0
4003        );
4004        assert!(out.iter().all(|x| x.is_finite()));
4005        assert_eq!(out.len(), p_total);
4006    }
4007
4008    /// Diagonal oracle equals the explicit per-row design² accumulator.
4009    #[test]
4010    pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
4011        let n = 3_usize;
4012        let r = 4_usize;
4013        let p_m = 2_usize;
4014        let p_g = 2_usize;
4015        let p_h_dim = 1_usize;
4016        let p_w_dim = 1_usize;
4017        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4018        let block = BmsFlexBlockLayout {
4019            p_m,
4020            p_g,
4021            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4022            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4023            p_total,
4024        };
4025        let primary = BmsFlexPrimaryLayout {
4026            h: Some(2..3),
4027            w: Some(3..4),
4028            r,
4029        };
4030        let mut row_hessians = vec![0.0_f64; n * r * r];
4031        for row in 0..n {
4032            for u in 0..r {
4033                row_hessians[row * r * r + u * r + u] = 1.0 + (row as f64) + (u as f64) * 0.5;
4034            }
4035        }
4036        let mut marginal = vec![0.0_f64; n * p_m];
4037        let mut logslope = vec![0.0_f64; n * p_g];
4038        for row in 0..n {
4039            for j in 0..p_m {
4040                marginal[row * p_m + j] = 0.2 + (row as f64) * 0.3 + (j as f64) * 0.1;
4041            }
4042            for j in 0..p_g {
4043                logslope[row * p_g + j] = -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2;
4044            }
4045        }
4046        let out = cpu_oracle_bms_flex_row_diagonal(
4047            &row_hessians,
4048            &marginal,
4049            &logslope,
4050            &block,
4051            &primary,
4052            n,
4053        );
4054        // Hand check: out[0] = Σ_row H[row,0,0] · marginal[row,0]^2.
4055        let mut expect = 0.0_f64;
4056        for row in 0..n {
4057            let h00 = row_hessians[row * r * r];
4058            expect += h00 * marginal[row * p_m].powi(2);
4059        }
4060        assert!(
4061            (out[0] - expect).abs() < 1e-12,
4062            "out[0] {} vs {}",
4063            out[0],
4064            expect
4065        );
4066        // h slot = sum of H[row, 2, 2] across rows.
4067        let mut expect_h = 0.0_f64;
4068        for row in 0..n {
4069            expect_h += row_hessians[row * r * r + 2 * r + 2];
4070        }
4071        let h_slot = p_m + p_g;
4072        assert!(
4073            (out[h_slot] - expect_h).abs() < 1e-12,
4074            "h slot {} vs {}",
4075            out[h_slot],
4076            expect_h
4077        );
4078    }
4079
4080    /// GPU↔CPU parity for the HVP and diagonal kernels. Skips on non-Linux /
4081    /// no-CUDA hosts. Hand-constructs a small `DeviceResidentRowHess` by
4082    /// allocating the device slices directly, uploading the same arrays the
4083    /// CPU oracle consumes, then dispatching the device kernels.
4084    #[test]
4085    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
4086        #[cfg(not(target_os = "linux"))]
4087        {
4088            eprintln!(
4089                "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
4090                 (CPU oracle exercised by sibling tests)"
4091            );
4092        }
4093        #[cfg(target_os = "linux")]
4094        {
4095            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4096                eprintln!(
4097                    "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
4098                     parity"
4099                );
4100                return;
4101            };
4102            let n = 4_usize;
4103            let r = 4_usize;
4104            let p_m = 2_usize;
4105            let p_g = 2_usize;
4106            let p_h_dim = 1_usize;
4107            let p_w_dim = 1_usize;
4108            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4109            let block = BmsFlexBlockLayout {
4110                p_m,
4111                p_g,
4112                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4113                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4114                p_total,
4115            };
4116            let primary = BmsFlexPrimaryLayout {
4117                h: Some(2..3),
4118                w: Some(3..4),
4119                r,
4120            };
4121            let mut row_hessians = vec![0.0_f64; n * r * r];
4122            for row in 0..n {
4123                for u in 0..r {
4124                    for v in u..r {
4125                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4126                        row_hessians[row * r * r + u * r + v] = val;
4127                        row_hessians[row * r * r + v * r + u] = val;
4128                    }
4129                }
4130            }
4131            let mut marginal = vec![0.0_f64; n * p_m];
4132            for row in 0..n {
4133                for j in 0..p_m {
4134                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4135                }
4136            }
4137            let mut logslope = vec![0.0_f64; n * p_g];
4138            for row in 0..n {
4139                for j in 0..p_g {
4140                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4141                }
4142            }
4143            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4144            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4145                &row_hessians,
4146                &marginal,
4147                &logslope,
4148                &block,
4149                &primary,
4150                n,
4151                &v,
4152            );
4153            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4154                &row_hessians,
4155                &marginal,
4156                &logslope,
4157                &block,
4158                &primary,
4159                n,
4160            );
4161
4162            // Allocate a DeviceResidentRowHess by hand using the HVP backend's
4163            // stream + module so we don't need to drive the full BMS row kernel.
4164            // Past the GpuRuntime::global() Some-gate above: a probe/upload failure
4165            // here is a real device fault on a CUDA host, not a no-CUDA skip. Fail
4166            // loud (the device-PCG skip-pass class, eee12f6b2) — the old arms
4167            // returned and the test passed while exercising nothing.
4168            let backend = HvpKernelBackend::probe()
4169                .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
4170            let stream = backend.stream.clone();
4171            let d_h = stream
4172                .clone_htod(&row_hessians)
4173                .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
4174            let d_m = stream
4175                .clone_htod(&marginal)
4176                .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
4177            let d_g = stream
4178                .clone_htod(&logslope)
4179                .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
4180            let storage = DeviceResidentRowHess {
4181                hess: d_h,
4182                marginal_design: d_m,
4183                logslope_design: d_g,
4184                n,
4185                r,
4186                block: block.clone(),
4187                primary: primary.clone(),
4188
4189                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4190            };
4191            let gpu_hvp =
4192                launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
4193            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4194                .expect("diagonal kernel must launch on CUDA host");
4195            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4196            assert_eq!(gpu_diag.len(), cpu_diag.len());
4197            for i in 0..p_total {
4198                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4199                assert!(
4200                    diff <= 1e-10,
4201                    "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4202                    cpu_hvp[i],
4203                    gpu_hvp[i]
4204                );
4205                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4206                assert!(
4207                    ddiff <= 1e-10,
4208                    "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4209                    cpu_diag[i],
4210                    gpu_diag[i]
4211                );
4212            }
4213        }
4214    }
4215
4216    #[test]
4217    pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
4218        let n = 195_000_usize;
4219        let r = 20_usize;
4220        let p_total = 44_usize;
4221        let rhs_count = 4_usize;
4222        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4223            .expect("large-scale multi-RHS scratch budget");
4224        let per_rhs_full_row_cache =
4225            (n * r * r * std::mem::size_of::<f64>()) as u64 * rhs_count as u64;
4226        assert!(
4227            scratch < per_rhs_full_row_cache / 100,
4228            "multi-RHS scratch must tile by row chunks instead of materializing \
4229             a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
4230        );
4231        assert!(
4232            bms_flex_row_hvp_multi_scratch_bytes_for_shape(
4233                n,
4234                p_total,
4235                BMS_FLEX_ROW_HVP_MAX_RHS + 1
4236            )
4237            .is_err(),
4238            "multi-RHS launch must reject unbounded RHS counts"
4239        );
4240    }
4241
4242    #[test]
4243    pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
4244        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4245            eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
4246            return;
4247        };
4248        let n = 5_usize;
4249        let r = 4_usize;
4250        let p_m = 2_usize;
4251        let p_g = 2_usize;
4252        let p_h_dim = 1_usize;
4253        let p_w_dim = 1_usize;
4254        let p_total = p_m + p_g + p_h_dim + p_w_dim;
4255        let rhs_count = 3_usize;
4256        let block = BmsFlexBlockLayout {
4257            p_m,
4258            p_g,
4259            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4260            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4261            p_total,
4262        };
4263        let primary = BmsFlexPrimaryLayout {
4264            h: Some(2..3),
4265            w: Some(3..4),
4266            r,
4267        };
4268        let mut row_hessians = vec![0.0_f64; n * r * r];
4269        for row in 0..n {
4270            for u in 0..r {
4271                for v in u..r {
4272                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4273                    row_hessians[row * r * r + u * r + v] = val;
4274                    row_hessians[row * r * r + v * r + u] = val;
4275                }
4276            }
4277        }
4278        let mut marginal = vec![0.0_f64; n * p_m];
4279        let mut logslope = vec![0.0_f64; n * p_g];
4280        for row in 0..n {
4281            for j in 0..p_m {
4282                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4283            }
4284            for j in 0..p_g {
4285                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4286            }
4287        }
4288        let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
4289        for rhs in 0..rhs_count {
4290            for j in 0..p_total {
4291                let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
4292                v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
4293            }
4294        }
4295
4296        // Past the GpuRuntime::global() Some-gate: a probe/upload failure here is a
4297        // real device fault on a CUDA host, not a no-CUDA skip — fail loud
4298        // (device-PCG skip-pass class, eee12f6b2).
4299        let backend = HvpKernelBackend::probe()
4300            .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
4301        let stream = backend.stream.clone();
4302        let d_h = stream
4303            .clone_htod(&row_hessians)
4304            .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
4305        let d_m = stream
4306            .clone_htod(&marginal)
4307            .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
4308        let d_g = stream
4309            .clone_htod(&logslope)
4310            .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
4311        let storage = DeviceResidentRowHess {
4312            hess: d_h,
4313            marginal_design: d_m,
4314            logslope_design: d_g,
4315            n,
4316            r,
4317            block: block.clone(),
4318            primary: primary.clone(),
4319
4320            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4321        };
4322        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4323            .expect("storage scratch budget");
4324        assert!(
4325            scratch < storage.bytes,
4326            "multi-RHS scratch should stay below resident cache bytes"
4327        );
4328        let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
4329            .expect("multi-RHS HVP kernel must launch on CUDA host");
4330        assert_eq!(gpu.len(), rhs_count * p_total);
4331        for rhs in 0..rhs_count {
4332            let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
4333            let cpu = cpu_oracle_bms_flex_row_hvp(
4334                &row_hessians,
4335                &marginal,
4336                &logslope,
4337                &block,
4338                &primary,
4339                n,
4340                v,
4341            );
4342            let single = launch_bms_flex_row_hvp(&storage, v)
4343                .expect("single-RHS HVP kernel must launch on CUDA host");
4344            for j in 0..p_total {
4345                let got = gpu[rhs * p_total + j];
4346                let diff = (cpu[j] - got).abs();
4347                assert!(
4348                    diff <= 1e-10,
4349                    "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
4350                    cpu[j],
4351                    got
4352                );
4353                assert_eq!(
4354                    got, single[j],
4355                    "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
4356                );
4357            }
4358        }
4359    }
4360
4361    /// Parity for the third launch mode — device-output HVP
4362    /// ([`launch_bms_flex_row_hvp_into_device`]) — which the
4363    /// `run_bms_flex_row_partial_reduce` unification routes through the same
4364    /// partial+reduce engine as the host-returning `_hvp` / `_diagonal`
4365    /// adapters. Confirms that keeping the result on-stream (no internal sync /
4366    /// DtoH) reaches bit-identical output to both the CPU oracle and the
4367    /// host-out adapter, so the engine's mode/output split is faithful.
4368    ///
4369    /// Skips cleanly on non-Linux / no-CUDA hosts using the convention shared
4370    /// with the sibling parity tests.
4371    #[test]
4372    pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
4373        #[cfg(not(target_os = "linux"))]
4374        {
4375            eprintln!(
4376                "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
4377                 CUDA parity (CPU oracle exercised by sibling tests)"
4378            );
4379        }
4380        #[cfg(target_os = "linux")]
4381        {
4382            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4383                eprintln!(
4384                    "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
4385                     skipping device parity"
4386                );
4387                return;
4388            };
4389            let n = 4_usize;
4390            let r = 4_usize;
4391            let p_m = 2_usize;
4392            let p_g = 2_usize;
4393            let p_h_dim = 1_usize;
4394            let p_w_dim = 1_usize;
4395            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4396            let block = BmsFlexBlockLayout {
4397                p_m,
4398                p_g,
4399                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4400                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4401                p_total,
4402            };
4403            let primary = BmsFlexPrimaryLayout {
4404                h: Some(2..3),
4405                w: Some(3..4),
4406                r,
4407            };
4408            let mut row_hessians = vec![0.0_f64; n * r * r];
4409            for row in 0..n {
4410                for u in 0..r {
4411                    for v in u..r {
4412                        let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4413                        row_hessians[row * r * r + u * r + v] = val;
4414                        row_hessians[row * r * r + v * r + u] = val;
4415                    }
4416                }
4417            }
4418            let mut marginal = vec![0.0_f64; n * p_m];
4419            for row in 0..n {
4420                for j in 0..p_m {
4421                    marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4422                }
4423            }
4424            let mut logslope = vec![0.0_f64; n * p_g];
4425            for row in 0..n {
4426                for j in 0..p_g {
4427                    logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4428                }
4429            }
4430            let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4431            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4432                &row_hessians,
4433                &marginal,
4434                &logslope,
4435                &block,
4436                &primary,
4437                n,
4438                &v,
4439            );
4440
4441            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
4442            // real device faults on a CUDA host — fail loud (device-PCG class).
4443            let backend = HvpKernelBackend::probe().expect(
4444                "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
4445            );
4446            let stream = backend.stream.clone();
4447            let d_h = stream
4448                .clone_htod(&row_hessians)
4449                .expect("[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host");
4450            let d_m = stream.clone_htod(&marginal).expect(
4451                "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
4452            );
4453            let d_g = stream.clone_htod(&logslope).expect(
4454                "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
4455            );
4456            let storage = DeviceResidentRowHess {
4457                hess: d_h,
4458                marginal_design: d_m,
4459                logslope_design: d_g,
4460                n,
4461                r,
4462                block: block.clone(),
4463                primary: primary.clone(),
4464
4465                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4466            };
4467
4468            // Host-out adapter (allocates its own d_out, syncs + downloads).
4469            let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
4470                .expect("host-out HVP kernel must launch on CUDA host");
4471
4472            // Device-out adapter: caller owns d_v + d_out; the engine performs
4473            // no sync / DtoH, so we synchronize + download here.
4474            let d_v = stream
4475                .clone_htod(&v)
4476                .expect("upload direction for device-out HVP");
4477            let mut d_out = stream
4478                .alloc_zeros::<f64>(p_total)
4479                .expect("alloc device-out HVP output");
4480            launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
4481                .expect("device-out HVP kernel must launch on CUDA host");
4482            stream
4483                .synchronize()
4484                .expect("synchronize after device-out HVP");
4485            let device_out_hvp = stream
4486                .clone_dtoh(&d_out)
4487                .expect("download device-out HVP output");
4488
4489            assert_eq!(device_out_hvp.len(), cpu_hvp.len());
4490            assert_eq!(device_out_hvp.len(), host_out_hvp.len());
4491            for i in 0..p_total {
4492                let diff = (cpu_hvp[i] - device_out_hvp[i]).abs();
4493                assert!(
4494                    diff <= 1e-10,
4495                    "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
4496                    cpu_hvp[i],
4497                    device_out_hvp[i]
4498                );
4499                // Both adapters share the engine; the only difference is the
4500                // copy-back path, so they must be bit-identical.
4501                let host_diff = (host_out_hvp[i] - device_out_hvp[i]).abs();
4502                assert!(
4503                    host_diff == 0.0,
4504                    "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
4505                    host_out_hvp[i],
4506                    device_out_hvp[i]
4507                );
4508            }
4509        }
4510    }
4511
4512    /// Block 9 Phase 2 parity gate at the shape specified by the
4513    /// charter task: `n = 64`, `r = 20`, `p_total = 44`. Splits
4514    /// `p_total` as `p_m = 14`, `p_g = 12`, `p_h = 10`, `p_w = 8` so
4515    /// `r = 2 + p_h + p_w = 20` and every primary block participates
4516    /// in both the device pullback and the reduce pass. Tolerance is
4517    /// `|Δ| ≤ 1e-8` per the task description (looser than the 1e-10
4518    /// hand-fixture parity, since accumulation order across HVP CTAs
4519    /// differs from the CPU oracle's row-major sum even with the
4520    /// deterministic reduction policy).
4521    ///
4522    /// Skips cleanly on non-Linux and no-CUDA hosts using the same
4523    /// convention as the hand-fixture parity above.
4524    #[test]
4525    pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
4526        #[cfg(not(target_os = "linux"))]
4527        {
4528            eprintln!(
4529                "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
4530                 skipping CUDA parity"
4531            );
4532        }
4533        #[cfg(target_os = "linux")]
4534        {
4535            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4536                eprintln!(
4537                    "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
4538                     skipping device parity"
4539                );
4540                return;
4541            };
4542            let n = 64_usize;
4543            let p_m = 14_usize;
4544            let p_g = 12_usize;
4545            let p_h_dim = 10_usize;
4546            let p_w_dim = 8_usize;
4547            let r = 2 + p_h_dim + p_w_dim;
4548            assert_eq!(r, 20);
4549            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4550            assert_eq!(p_total, 44);
4551            let block = BmsFlexBlockLayout {
4552                p_m,
4553                p_g,
4554                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4555                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4556                p_total,
4557            };
4558            let primary = BmsFlexPrimaryLayout {
4559                h: Some(2..2 + p_h_dim),
4560                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4561                r,
4562            };
4563
4564            // Deterministic symmetric per-row Hessians + designs +
4565            // direction. Same scrambling family as
4566            // `row_hessian_ops::tests::make_fixture` so any regression
4567            // surfaces consistently across the host-pinned and
4568            // device-resident parity tests.
4569            let mut row_hessians = vec![0.0_f64; n * r * r];
4570            for row in 0..n {
4571                let base = row * r * r;
4572                for u in 0..r {
4573                    for v in 0..r {
4574                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
4575                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
4576                        row_hessians[base + u * r + v] = a;
4577                    }
4578                }
4579                for u in 0..r {
4580                    for v in (u + 1)..r {
4581                        let upper = row_hessians[base + u * r + v];
4582                        let lower = row_hessians[base + v * r + u];
4583                        let sym = 0.5 * (upper + lower);
4584                        row_hessians[base + u * r + v] = sym;
4585                        row_hessians[base + v * r + u] = sym;
4586                    }
4587                    row_hessians[base + u * r + u] += r as f64;
4588                }
4589            }
4590            let mut marginal = vec![0.0_f64; n * p_m];
4591            for row in 0..n {
4592                for j in 0..p_m {
4593                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
4594                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
4595                }
4596            }
4597            let mut logslope = vec![0.0_f64; n * p_g];
4598            for row in 0..n {
4599                for j in 0..p_g {
4600                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
4601                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
4602                }
4603            }
4604            let v: Vec<f64> = (0..p_total)
4605                .map(|i| {
4606                    let seed = (i as f64) * 0.157 + 0.6;
4607                    seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
4608                })
4609                .collect();
4610
4611            let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4612                &row_hessians,
4613                &marginal,
4614                &logslope,
4615                &block,
4616                &primary,
4617                n,
4618                &v,
4619            );
4620            let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4621                &row_hessians,
4622                &marginal,
4623                &logslope,
4624                &block,
4625                &primary,
4626                n,
4627            );
4628
4629            let backend = match HvpKernelBackend::probe() {
4630                Ok(b) => b,
4631                Err(err) => {
4632                    eprintln!(
4633                        "[bms_flex_row hvp parity n64_r20_p44] backend probe \
4634                         failed: {err}"
4635                    );
4636                    return;
4637                }
4638            };
4639            let stream = backend.stream.clone();
4640            let d_h = match stream.clone_htod(&row_hessians) {
4641                Ok(s) => s,
4642                Err(err) => {
4643                    eprintln!(
4644                        "[bms_flex_row hvp parity n64_r20_p44] upload h \
4645                         failed: {err}"
4646                    );
4647                    return;
4648                }
4649            };
4650            let d_m = match stream.clone_htod(&marginal) {
4651                Ok(s) => s,
4652                Err(err) => {
4653                    eprintln!(
4654                        "[bms_flex_row hvp parity n64_r20_p44] upload marg \
4655                         failed: {err}"
4656                    );
4657                    return;
4658                }
4659            };
4660            let d_g = match stream.clone_htod(&logslope) {
4661                Ok(s) => s,
4662                Err(err) => {
4663                    eprintln!(
4664                        "[bms_flex_row hvp parity n64_r20_p44] upload logslope \
4665                         failed: {err}"
4666                    );
4667                    return;
4668                }
4669            };
4670            let storage = DeviceResidentRowHess {
4671                hess: d_h,
4672                marginal_design: d_m,
4673                logslope_design: d_g,
4674                n,
4675                r,
4676                block: block.clone(),
4677                primary: primary.clone(),
4678
4679                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4680            };
4681            let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
4682                .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
4683            let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4684                .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
4685            assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4686            assert_eq!(gpu_diag.len(), cpu_diag.len());
4687            for i in 0..p_total {
4688                let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4689                assert!(
4690                    diff <= 1e-8,
4691                    "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4692                    cpu_hvp[i],
4693                    gpu_hvp[i]
4694                );
4695                let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4696                assert!(
4697                    ddiff <= 1e-8,
4698                    "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4699                    cpu_diag[i],
4700                    gpu_diag[i]
4701                );
4702            }
4703        }
4704    }
4705
4706    /// Block 9 Phase 6 — small-fixture parity for the dense-block kernel
4707    /// against the host-side P_i pullback oracle.
4708    /// Verifies bit-equality (modulo reduction-order f.p. noise) between
4709    /// the device-resident dense build and the host accumulator over the
4710    /// same per-row Hessian + designs + P_i pullback.
4711    #[test]
4712    pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
4713        #[cfg(not(target_os = "linux"))]
4714        {
4715            eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
4716        }
4717        #[cfg(target_os = "linux")]
4718        {
4719            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4720                eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
4721                return;
4722            };
4723            // Small fixture: n=24, r=8 (2 + 3 + 3), p_total=18 (4+4+3+3).
4724            // Keeps the CPU pullback fast while still exercising every
4725            // primary slot (q, g, h, w).
4726            let n = 24_usize;
4727            let p_m = 4_usize;
4728            let p_g = 4_usize;
4729            let p_h_dim = 3_usize;
4730            let p_w_dim = 3_usize;
4731            let r = 2 + p_h_dim + p_w_dim;
4732            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4733            let block = BmsFlexBlockLayout {
4734                p_m,
4735                p_g,
4736                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4737                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4738                p_total,
4739            };
4740            let primary = BmsFlexPrimaryLayout {
4741                h: Some(2..2 + p_h_dim),
4742                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4743                r,
4744            };
4745
4746            let mut row_hessians = vec![0.0_f64; n * r * r];
4747            for row in 0..n {
4748                let base = row * r * r;
4749                for u in 0..r {
4750                    for v in 0..r {
4751                        let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (v as f64) * 0.47;
4752                        let a = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
4753                        row_hessians[base + u * r + v] = a;
4754                    }
4755                }
4756                for u in 0..r {
4757                    for v in (u + 1)..r {
4758                        let upper = row_hessians[base + u * r + v];
4759                        let lower = row_hessians[base + v * r + u];
4760                        let sym = 0.5 * (upper + lower);
4761                        row_hessians[base + u * r + v] = sym;
4762                        row_hessians[base + v * r + u] = sym;
4763                    }
4764                    row_hessians[base + u * r + u] += r as f64;
4765                }
4766            }
4767            let mut marginal = vec![0.0_f64; n * p_m];
4768            for row in 0..n {
4769                for j in 0..p_m {
4770                    let seed = (row as f64) * 0.083 + (j as f64) * 0.171 + 0.31;
4771                    marginal[row * p_m + j] = seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25;
4772                }
4773            }
4774            let mut logslope = vec![0.0_f64; n * p_g];
4775            for row in 0..n {
4776                for j in 0..p_g {
4777                    let seed = (row as f64) * 0.097 + (j as f64) * 0.143 - 0.15;
4778                    logslope[row * p_g + j] = seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2;
4779                }
4780            }
4781
4782            // CPU oracle — same pullback math the device kernel mirrors.
4783            let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
4784            let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
4785            let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
4786            let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
4787            let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
4788            let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
4789            let mut h_cpu = vec![0.0_f64; p_total * p_total];
4790            for row in 0..n {
4791                let mrow = &marginal[row * p_m..(row + 1) * p_m];
4792                let grow = &logslope[row * p_g..(row + 1) * p_g];
4793                let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
4794                // Build per-row phi (r length-p_total vectors).
4795                let mut phi = vec![vec![0.0_f64; p_total]; r];
4796                for k in 0..p_m {
4797                    phi[0][k] = mrow[k];
4798                }
4799                for k in 0..p_g {
4800                    phi[1][p_m + k] = grow[k];
4801                }
4802                for k in 0..h_block_len {
4803                    phi[h_primary_start + k][h_block_start + k] = 1.0;
4804                }
4805                for k in 0..w_block_len {
4806                    phi[w_primary_start + k][w_block_start + k] = 1.0;
4807                }
4808                for u in 0..r {
4809                    for v in 0..r {
4810                        let huv = hrow[u * r + v];
4811                        if huv == 0.0 {
4812                            continue;
4813                        }
4814                        for m in 0..p_total {
4815                            let pm = phi[u][m];
4816                            if pm == 0.0 {
4817                                continue;
4818                            }
4819                            let scaled = huv * pm;
4820                            for nn in 0..p_total {
4821                                h_cpu[m * p_total + nn] += scaled * phi[v][nn];
4822                            }
4823                        }
4824                    }
4825                }
4826            }
4827
4828            // Build a transient device-resident storage and launch the
4829            // dense-block kernel.
4830            // Past the GpuRuntime::global() Some-gate: probe/upload failures are
4831            // real device faults on a CUDA host — fail loud (device-PCG class).
4832            let backend = HvpKernelBackend::probe().expect(
4833                "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
4834            );
4835            let stream = backend.stream.clone();
4836            let d_h = stream
4837                .clone_htod(&row_hessians)
4838                .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host");
4839            let d_m = stream
4840                .clone_htod(&marginal)
4841                .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host");
4842            let d_g = stream.clone_htod(&logslope).expect(
4843                "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
4844            );
4845            let storage = DeviceResidentRowHess {
4846                hess: d_h,
4847                marginal_design: d_m,
4848                logslope_design: d_g,
4849                n,
4850                r,
4851                block: block.clone(),
4852                primary: primary.clone(),
4853
4854                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4855            };
4856            let h_gpu = launch_bms_flex_row_dense_block(&storage)
4857                .expect("dense_block kernel must launch on CUDA host");
4858            assert_eq!(h_gpu.len(), p_total * p_total);
4859
4860            // Compare entry-by-entry with a tolerance that absorbs
4861            // reduction-order f.p. noise from the CTA chunk sum.
4862            let mut max_abs = 0.0_f64;
4863            for i in 0..p_total {
4864                for j in 0..p_total {
4865                    let a = h_cpu[i * p_total + j];
4866                    let b = h_gpu[i * p_total + j];
4867                    let diff = (a - b).abs();
4868                    if diff > max_abs {
4869                        max_abs = diff;
4870                    }
4871                    assert!(
4872                        diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
4873                        "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
4874                    );
4875                }
4876            }
4877            eprintln!(
4878                "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
4879            );
4880        }
4881    }
4882
4883    /// Block 9 final hill-climb gate — GPU HVP must be at least 5× faster
4884    /// than a Rayon-parallel CPU HVP at large-scale shape (n=195_000, r=20,
4885    /// p_total=44). This is the charter pass/fail metric for whether the
4886    /// device-resident row-Hessian path is a real perf win for the
4887    /// production marginal-slope fit.
4888    ///
4889    /// Methodology:
4890    ///   * Build the same deterministic fixture as the parity tests.
4891    ///   * GPU: median of `iters` `launch_bms_flex_row_hvp` wall-times
4892    ///     after `warmup` warm-up launches (kernel compile + L2 prime).
4893    ///   * CPU: median of `iters` `cpu_oracle_bms_flex_row_hvp` wall-times,
4894    ///     parallelised over rows via Rayon — this mirrors the actual
4895    ///     production CPU path in
4896    ///     `exact_newton_joint_hessian_matvec_from_cache` (which uses
4897    ///     `ROW_CHUNK_SIZE` chunked `into_par_iter()` for the same
4898    ///     contraction).
4899    ///   * Ratio = cpu_median / gpu_median; assert ratio >= 5.
4900    ///
4901    /// Skips on non-Linux / no-CUDA hosts.
4902    #[test]
4903    pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
4904        #[cfg(not(target_os = "linux"))]
4905        {
4906            eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
4907        }
4908        #[cfg(target_os = "linux")]
4909        {
4910            use rayon::prelude::*;
4911
4912            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4913                eprintln!(
4914                    "[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate"
4915                );
4916                return;
4917            };
4918            let n = 195_000_usize;
4919            let p_m = 14_usize;
4920            let p_g = 12_usize;
4921            let p_h_dim = 10_usize;
4922            let p_w_dim = 8_usize;
4923            let r = 2 + p_h_dim + p_w_dim;
4924            let p_total = p_m + p_g + p_h_dim + p_w_dim;
4925            let block = BmsFlexBlockLayout {
4926                p_m,
4927                p_g,
4928                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4929                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4930                p_total,
4931            };
4932            let primary = BmsFlexPrimaryLayout {
4933                h: Some(2..2 + p_h_dim),
4934                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4935                r,
4936            };
4937
4938            // Same deterministic fixture as the Phase 4 large-scale benchmark.
4939            let mut row_hessians = vec![0.0_f64; n * r * r];
4940            for row in 0..n {
4941                let base = row * r * r;
4942                for u in 0..r {
4943                    for vv in 0..r {
4944                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
4945                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
4946                        row_hessians[base + u * r + vv] = a;
4947                    }
4948                }
4949                for u in 0..r {
4950                    for vv in (u + 1)..r {
4951                        let upper = row_hessians[base + u * r + vv];
4952                        let lower = row_hessians[base + vv * r + u];
4953                        let sym = 0.5 * (upper + lower);
4954                        row_hessians[base + u * r + vv] = sym;
4955                        row_hessians[base + vv * r + u] = sym;
4956                    }
4957                    row_hessians[base + u * r + u] += r as f64;
4958                }
4959            }
4960            let mut marginal = vec![0.0_f64; n * p_m];
4961            for row in 0..n {
4962                for j in 0..p_m {
4963                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
4964                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
4965                }
4966            }
4967            let mut logslope = vec![0.0_f64; n * p_g];
4968            for row in 0..n {
4969                for j in 0..p_g {
4970                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
4971                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
4972                }
4973            }
4974            let v: Vec<f64> = (0..p_total)
4975                .map(|i| {
4976                    let seed = (i as f64) * 0.157 + 0.6;
4977                    seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
4978                })
4979                .collect();
4980
4981            // ── GPU side: upload once, time HVP launches ─────────────
4982            let backend = match HvpKernelBackend::probe() {
4983                Ok(b) => b,
4984                Err(err) => {
4985                    eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
4986                    return;
4987                }
4988            };
4989            let stream = backend.stream.clone();
4990            let d_h = match stream.clone_htod(&row_hessians) {
4991                Ok(s) => s,
4992                Err(err) => {
4993                    eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
4994                    return;
4995                }
4996            };
4997            let d_m = match stream.clone_htod(&marginal) {
4998                Ok(s) => s,
4999                Err(err) => {
5000                    eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
5001                    return;
5002                }
5003            };
5004            let d_g = match stream.clone_htod(&logslope) {
5005                Ok(s) => s,
5006                Err(err) => {
5007                    eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
5008                    return;
5009                }
5010            };
5011            let storage = DeviceResidentRowHess {
5012                hess: d_h,
5013                marginal_design: d_m,
5014                logslope_design: d_g,
5015                n,
5016                r,
5017                block: block.clone(),
5018                primary: primary.clone(),
5019
5020                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5021            };
5022            let warmup: usize = 3;
5023            let iters: usize = 15;
5024            for _ in 0..warmup {
5025                let out =
5026                    launch_bms_flex_row_hvp(&storage, &v).expect("warmup GPU HVP must launch");
5027                assert_eq!(out.len(), p_total);
5028            }
5029            let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5030            for _ in 0..iters {
5031                let t0 = std::time::Instant::now();
5032                let out = launch_bms_flex_row_hvp(&storage, &v).expect("GPU HVP must launch");
5033                gpu_us.push(t0.elapsed().as_micros());
5034                assert_eq!(out.len(), p_total);
5035            }
5036            gpu_us.sort_unstable();
5037            let gpu_median = gpu_us[iters / 2];
5038
5039            // ── CPU side: chunked Rayon HVP over rows, mirroring the
5040            //    production `exact_newton_joint_hessian_matvec_from_cache`
5041            //    parallelisation pattern (ROW_CHUNK_SIZE-row chunks,
5042            //    try_fold + try_reduce). The per-chunk worker calls the
5043            //    single-threaded oracle on its row slice.
5044            const CHUNK_ROWS: usize = 4096;
5045            let cpu_hvp_parallel = || -> Vec<f64> {
5046                let nchunks = n.div_ceil(CHUNK_ROWS);
5047                (0..nchunks)
5048                    .into_par_iter()
5049                    .fold(
5050                        || vec![0.0_f64; p_total],
5051                        |mut acc, ci| {
5052                            let lo = ci * CHUNK_ROWS;
5053                            let hi = (lo + CHUNK_ROWS).min(n);
5054                            let m = hi - lo;
5055                            let partial = cpu_oracle_bms_flex_row_hvp(
5056                                &row_hessians[lo * r * r..hi * r * r],
5057                                &marginal[lo * p_m..hi * p_m],
5058                                &logslope[lo * p_g..hi * p_g],
5059                                &block,
5060                                &primary,
5061                                m,
5062                                &v,
5063                            );
5064                            for (a, &p) in acc.iter_mut().zip(partial.iter()) {
5065                                *a += p;
5066                            }
5067                            acc
5068                        },
5069                    )
5070                    .reduce(
5071                        || vec![0.0_f64; p_total],
5072                        |mut a, b| {
5073                            for (ax, bx) in a.iter_mut().zip(b.iter()) {
5074                                *ax += *bx;
5075                            }
5076                            a
5077                        },
5078                    )
5079            };
5080            // Warmup once to populate L3 / steady-state Rayon thread pool.
5081            let warm = cpu_hvp_parallel();
5082            assert_eq!(warm.len(), p_total);
5083            let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5084            for _ in 0..iters {
5085                let t0 = std::time::Instant::now();
5086                let out = cpu_hvp_parallel();
5087                cpu_us.push(t0.elapsed().as_micros());
5088                assert_eq!(out.len(), p_total);
5089            }
5090            cpu_us.sort_unstable();
5091            let cpu_median = cpu_us[iters / 2];
5092
5093            let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5094            eprintln!(
5095                "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
5096                 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5097                 speedup={speedup:.2}× (charter target ≥ 5×)"
5098            );
5099            assert!(
5100                speedup >= 5.0,
5101                "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
5102                 need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
5103                 gpu_median={gpu_median}us). Hill-climb the kernel until met or \
5104                 prove the kernel is at hardware roofline."
5105            );
5106        }
5107    }
5108
5109    /// Companion to the HVP hill-climb: GPU dense-block build must be at
5110    /// least 10× faster than a Rayon-parallel CPU dense build at large-scale
5111    /// shape. The dense build is `O(n * r² * p_total)` work for both
5112    /// paths so the ratio is well-defined.
5113    #[test]
5114    pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
5115        #[cfg(not(target_os = "linux"))]
5116        {
5117            eprintln!(
5118                "[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate"
5119            );
5120        }
5121        #[cfg(target_os = "linux")]
5122        {
5123            use rayon::prelude::*;
5124
5125            let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5126                eprintln!(
5127                    "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
5128                );
5129                return;
5130            };
5131            let n = 195_000_usize;
5132            let p_m = 14_usize;
5133            let p_g = 12_usize;
5134            let p_h_dim = 10_usize;
5135            let p_w_dim = 8_usize;
5136            let r = 2 + p_h_dim + p_w_dim;
5137            let p_total = p_m + p_g + p_h_dim + p_w_dim;
5138            let block = BmsFlexBlockLayout {
5139                p_m,
5140                p_g,
5141                h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5142                w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5143                p_total,
5144            };
5145            let primary = BmsFlexPrimaryLayout {
5146                h: Some(2..2 + p_h_dim),
5147                w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5148                r,
5149            };
5150
5151            // Reuse the same large-scale fixture recipe.
5152            let mut row_hessians = vec![0.0_f64; n * r * r];
5153            for row in 0..n {
5154                let base = row * r * r;
5155                for u in 0..r {
5156                    for vv in 0..r {
5157                        let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
5158                        let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
5159                        row_hessians[base + u * r + vv] = a;
5160                    }
5161                }
5162                for u in 0..r {
5163                    for vv in (u + 1)..r {
5164                        let upper = row_hessians[base + u * r + vv];
5165                        let lower = row_hessians[base + vv * r + u];
5166                        let sym = 0.5 * (upper + lower);
5167                        row_hessians[base + u * r + vv] = sym;
5168                        row_hessians[base + vv * r + u] = sym;
5169                    }
5170                    row_hessians[base + u * r + u] += r as f64;
5171                }
5172            }
5173            let mut marginal = vec![0.0_f64; n * p_m];
5174            for row in 0..n {
5175                for j in 0..p_m {
5176                    let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
5177                    marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
5178                }
5179            }
5180            let mut logslope = vec![0.0_f64; n * p_g];
5181            for row in 0..n {
5182                for j in 0..p_g {
5183                    let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
5184                    logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
5185                }
5186            }
5187
5188            // GPU dense_block kernel rejects p_total > DENSE_BLOCK_MAX_P
5189            // (72 at V100 48 KiB/block). LargeScale's p_total = 44 fits.
5190            if p_total > DENSE_BLOCK_MAX_P {
5191                eprintln!(
5192                    "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
5193                );
5194                return;
5195            }
5196            let backend = match HvpKernelBackend::probe() {
5197                Ok(b) => b,
5198                Err(err) => {
5199                    eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
5200                    return;
5201                }
5202            };
5203            let stream = backend.stream.clone();
5204            let d_h = match stream.clone_htod(&row_hessians) {
5205                Ok(s) => s,
5206                Err(err) => {
5207                    eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
5208                    return;
5209                }
5210            };
5211            let d_m = match stream.clone_htod(&marginal) {
5212                Ok(s) => s,
5213                Err(err) => {
5214                    eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
5215                    return;
5216                }
5217            };
5218            let d_g = match stream.clone_htod(&logslope) {
5219                Ok(s) => s,
5220                Err(err) => {
5221                    eprintln!(
5222                        "[bms_flex_row dense_block hill-climb] upload logslope failed: {err}"
5223                    );
5224                    return;
5225                }
5226            };
5227            let storage = DeviceResidentRowHess {
5228                hess: d_h,
5229                marginal_design: d_m,
5230                logslope_design: d_g,
5231                n,
5232                r,
5233                block: block.clone(),
5234                primary: primary.clone(),
5235
5236                bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5237            };
5238            // Warmup + 5-iter median (dense build is heavier than HVP).
5239            let warmup: usize = 2;
5240            let iters: usize = 5;
5241            for _ in 0..warmup {
5242                let out = launch_bms_flex_row_dense_block(&storage)
5243                    .expect("warmup GPU dense_block must launch");
5244                assert_eq!(out.len(), p_total * p_total);
5245            }
5246            let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5247            for _ in 0..iters {
5248                let t0 = std::time::Instant::now();
5249                let out =
5250                    launch_bms_flex_row_dense_block(&storage).expect("GPU dense_block must launch");
5251                gpu_us.push(t0.elapsed().as_micros());
5252                assert_eq!(out.len(), p_total * p_total);
5253            }
5254            gpu_us.sort_unstable();
5255            let gpu_median = gpu_us[iters / 2];
5256
5257            // CPU side: chunked Rayon dense build over rows. Each chunk
5258            // builds a `[p_total, p_total]` partial then we reduce-add.
5259            const CHUNK_ROWS: usize = 2048;
5260            let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
5261            let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
5262            let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
5263            let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
5264            let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
5265            let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
5266            let cpu_build_parallel = || -> Vec<f64> {
5267                let nchunks = n.div_ceil(CHUNK_ROWS);
5268                (0..nchunks)
5269                    .into_par_iter()
5270                    .fold(
5271                        || vec![0.0_f64; p_total * p_total],
5272                        |mut acc, ci| {
5273                            let lo = ci * CHUNK_ROWS;
5274                            let hi = (lo + CHUNK_ROWS).min(n);
5275                            let mut phi: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
5276                            for row in lo..hi {
5277                                for col in phi.iter_mut() {
5278                                    col.iter_mut().for_each(|v| *v = 0.0);
5279                                }
5280                                let mrow = &marginal[row * p_m..(row + 1) * p_m];
5281                                let grow = &logslope[row * p_g..(row + 1) * p_g];
5282                                for k in 0..p_m {
5283                                    phi[0][k] = mrow[k];
5284                                }
5285                                for k in 0..p_g {
5286                                    phi[1][p_m + k] = grow[k];
5287                                }
5288                                for k in 0..h_block_len {
5289                                    phi[h_primary_start + k][h_block_start + k] = 1.0;
5290                                }
5291                                for k in 0..w_block_len {
5292                                    phi[w_primary_start + k][w_block_start + k] = 1.0;
5293                                }
5294                                let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
5295                                for u in 0..r {
5296                                    for v_idx in 0..r {
5297                                        let huv = hrow[u * r + v_idx];
5298                                        if huv == 0.0 {
5299                                            continue;
5300                                        }
5301                                        for m in 0..p_total {
5302                                            let pm = phi[u][m];
5303                                            if pm == 0.0 {
5304                                                continue;
5305                                            }
5306                                            let scaled = huv * pm;
5307                                            for nn in 0..p_total {
5308                                                acc[m * p_total + nn] += scaled * phi[v_idx][nn];
5309                                            }
5310                                        }
5311                                    }
5312                                }
5313                            }
5314                            acc
5315                        },
5316                    )
5317                    .reduce(
5318                        || vec![0.0_f64; p_total * p_total],
5319                        |mut a, b| {
5320                            for (ax, bx) in a.iter_mut().zip(b.iter()) {
5321                                *ax += *bx;
5322                            }
5323                            a
5324                        },
5325                    )
5326            };
5327            let warm_cpu = cpu_build_parallel();
5328            assert_eq!(warm_cpu.len(), p_total * p_total);
5329            let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5330            for _ in 0..iters {
5331                let t0 = std::time::Instant::now();
5332                let out = cpu_build_parallel();
5333                cpu_us.push(t0.elapsed().as_micros());
5334                assert_eq!(out.len(), p_total * p_total);
5335            }
5336            cpu_us.sort_unstable();
5337            let cpu_median = cpu_us[iters / 2];
5338
5339            let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5340            eprintln!(
5341                "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
5342                 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5343                 speedup={speedup:.2}× (charter target ≥ 10×)"
5344            );
5345            assert!(
5346                speedup >= 10.0,
5347                "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
5348                 need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
5349                 gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
5350                 (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
5351                 or prove the kernel is at hardware roofline."
5352            );
5353        }
5354    }
5355}