1#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
/// Largest primary dimension `r` accepted by `BmsFlexRowKernelInputs::validate`.
/// The row kernel's `__shared__` scratch (`F_u`, `F_au`, `F_uv`, `bar_e_*`) and
/// its thread-0 local arrays are statically sized for `r <= 32`, so this must
/// not be raised without resizing the kernel.
pub(crate) const MAX_R: usize = 32;

/// Threads per block for `bms_flex_row_kernel`. The kernel's `reduce_a` /
/// `reduce_b` shared buffers hold 32 entries and its tree reduction halves the
/// stride each step, so this must be a power of two no larger than 32.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_THREADS: u32 = 32;

/// Coefficients per per-cell block (`cell_a`, `cell_aa`, `cell_r`, …); the
/// kernel indexes these blocks with a hard-coded stride of 4.
pub(crate) const COEFF4: usize = 4;

/// Moments stored per cell in the moment table (`m_0..m_9`); the kernel
/// indexes `cell_moments` with a hard-coded stride of 10.
pub(crate) const MOMENT_STRIDE: usize = 10;
113
114pub(crate) enum CellMomentsSource<'a> {
119 Host(&'a [f64]),
122 #[cfg(target_os = "linux")]
127 Device(&'a CudaSlice<f64>),
128}
129
130impl<'a> CellMomentsSource<'a> {
131 pub(crate) fn len(&self) -> usize {
133 match self {
134 CellMomentsSource::Host(slice) => slice.len(),
135 #[cfg(target_os = "linux")]
136 CellMomentsSource::Device(d) => d.len(),
137 }
138 }
139}
140
// Generates both row-kernel input structs from a single field list, so the
// borrowed and owned forms cannot drift apart:
//
// * `BmsFlexRowKernelInputs<'a>` borrows every array and is what the launcher
//   consumes.
// * `BmsFlexRowKernelInputsOwned` owns the same arrays as `Vec`s (plus, on
//   Linux, an optional device-resident copy of the moment table) and lends
//   itself out through `as_borrowed`.
//
// Arguments:
// * `f64_fields` — `f64` arrays (`&'a [f64]` borrowed / `Vec<f64>` owned).
// * `u32_fields` — `u32` arrays (`&'a [u32]` borrowed / `Vec<u32>` owned).
// * `moments_field` — name of the per-cell moment table: a
//   `CellMomentsSource` when borrowed, a host `Vec<f64>` when owned.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        /// Borrowed inputs for one `bms_flex_row_kernel` launch.
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of rows; the kernel runs one CUDA block per row.
            pub n_rows: usize,
            /// Primary dimension per row; `validate` requires
            /// `1 <= r <= MAX_R` and `r == 2 + p_h + p_w`.
            pub r: usize,
            /// Width of the `h` (score-warp) block of primary coordinates.
            pub p_h: usize,
            /// Width of the `w` (link-wiggle) block of primary coordinates.
            pub p_w: usize,
            /// `S_f` scale; must be finite and positive. The kernel receives it
            /// but does not use it (the host bakes it into the cubic coefficients).
            pub s_f: f64,
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            /// Per-cell moment table (`MOMENT_STRIDE` values per cell), either on
            /// the host or already on the device.
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owning counterpart of `BmsFlexRowKernelInputs`.
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            /// Host copy of the per-cell moment table.
            pub $moments_field: Vec<f64>,
            /// Optional device-resident copy of the moment table. When present,
            /// `as_borrowed` exposes it instead of the host vector, so the
            /// launcher skips the moment upload.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrows every array as a `BmsFlexRowKernelInputs`. On Linux the
            /// moment table is the device copy when one is held, otherwise the
            /// host vector.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                // Refer to the moment table through `$moments_field` (not a
                // hard-coded `cell_moments`) so the generated code agrees with
                // the field name the invocation declared.
                #[cfg(target_os = "linux")]
                let cell_moments = match self.cell_moments_device.as_ref() {
                    Some(d) => CellMomentsSource::Device(d),
                    None => CellMomentsSource::Host(&self.$moments_field),
                };
                #[cfg(not(target_os = "linux"))]
                let cell_moments = CellMomentsSource::Host(&self.$moments_field);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: cell_moments,
                }
            }
        }
    };
}
218
// Instantiates `BmsFlexRowKernelInputs` / `BmsFlexRowKernelInputsOwned`.
// Expected lengths (enforced by `BmsFlexRowKernelInputs::validate`): per-row
// arrays have `n_rows` entries, `rho_u`/`tau_u` have `n_rows * r`, `r_uv` has
// `n_rows * r * r`, and the `cell_*` arrays are sized from
// `total_cells = cell_offsets[n_rows]`.
define_bms_flex_row_kernel_input_types! {
    f64_fields: [
        // Per-row values (length `n_rows`).
        q,
        b,
        mu_1,
        mu_2,
        z_obs,
        y,
        w,
        // Observed predictor VALUE per row (length `n_rows`).
        e_obs,
        // Per-cell cubic predictor coefficients C0..C3 (length `total_cells`).
        cell_c0,
        cell_c1,
        cell_c2,
        cell_c3,
        // Per-cell 4-coefficient blocks (`COEFF4` values per entry); `cell_r` /
        // `cell_ar` hold `r - 1` entries per cell, `cell_sbh` `p_h`, `cell_sbw` `p_w`.
        cell_a,
        cell_aa,
        cell_r,
        cell_ar,
        cell_sbb,
        cell_sbh,
        cell_sbw,
        // Per-row observed-point values (length `n_rows`).
        chi_obs,
        xi_obs,
        // Per-row primary-coordinate arrays: `n_rows * r`, `n_rows * r`, `n_rows * r * r`.
        rho_u,
        tau_u,
        r_uv,
    ],
    // Row -> cell offsets (length `n_rows + 1`, non-decreasing).
    u32_fields: [cell_offsets],
    // Per-cell moment table, `MOMENT_STRIDE` values per cell.
    moments_field: cell_moments,
}
249
/// Per-row results downloaded from `bms_flex_row_kernel`.
///
/// Rows with degenerate inputs (non-finite or non-positive `F_a`) are
/// NaN-filled by the kernel so the host can fall back to the CPU path. That is
/// also why only `Clone`/`Debug` are derived: a `PartialEq` would report two
/// NaN-filled outputs as different.
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// Weighted negative log-likelihood per row (`-w * log Φ(s·η)`), length `n_rows`.
    pub neglog: Vec<f64>,
    /// Row-major gradient over the `r` primary coordinates, length `n_rows * r`.
    pub grad: Vec<f64>,
    /// Row-major symmetric `r × r` Hessian per row, length `n_rows * r * r`.
    pub hess: Vec<f64>,
}
261
262impl<'a> BmsFlexRowKernelInputs<'a> {
263 pub(crate) fn validate(&self) -> Result<(), GpuError> {
266 if self.r == 0 {
267 return Err(GpuError::DriverCallFailed {
268 reason: "bms_flex_row inputs: r must be > 0".to_string(),
269 });
270 }
271 if self.r > MAX_R {
272 return Err(GpuError::DriverCallFailed {
273 reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
274 });
275 }
276 if self.r != 2 + self.p_h + self.p_w {
277 return Err(GpuError::DriverCallFailed {
278 reason: format!(
279 "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
280 self.r,
281 self.p_h,
282 self.p_w,
283 2 + self.p_h + self.p_w
284 ),
285 });
286 }
287 let n = self.n_rows;
288 let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
289 if have != want {
290 return Err(GpuError::DriverCallFailed {
291 reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
292 });
293 }
294 Ok(())
295 };
296 check_len("q", self.q.len(), n)?;
297 check_len("b", self.b.len(), n)?;
298 check_len("mu_1", self.mu_1.len(), n)?;
299 check_len("mu_2", self.mu_2.len(), n)?;
300 check_len("z_obs", self.z_obs.len(), n)?;
301 check_len("y", self.y.len(), n)?;
302 check_len("w", self.w.len(), n)?;
303 check_len("e_obs", self.e_obs.len(), n)?;
304 check_len("chi_obs", self.chi_obs.len(), n)?;
305 check_len("xi_obs", self.xi_obs.len(), n)?;
306 check_len("rho_u", self.rho_u.len(), n * self.r)?;
307 check_len("tau_u", self.tau_u.len(), n * self.r)?;
308 check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
309 check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
310 let total_cells_u32 = self.cell_offsets[n];
311 let total_cells = total_cells_u32 as usize;
312 check_len("cell_c0", self.cell_c0.len(), total_cells)?;
313 check_len("cell_c1", self.cell_c1.len(), total_cells)?;
314 check_len("cell_c2", self.cell_c2.len(), total_cells)?;
315 check_len("cell_c3", self.cell_c3.len(), total_cells)?;
316 check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
317 check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
318 check_len(
319 "cell_r",
320 self.cell_r.len(),
321 total_cells * self.r.saturating_sub(1) * COEFF4,
322 )?;
323 check_len(
324 "cell_ar",
325 self.cell_ar.len(),
326 total_cells * self.r.saturating_sub(1) * COEFF4,
327 )?;
328 check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
329 check_len(
330 "cell_sbh",
331 self.cell_sbh.len(),
332 total_cells * self.p_h * COEFF4,
333 )?;
334 check_len(
335 "cell_sbw",
336 self.cell_sbw.len(),
337 total_cells * self.p_w * COEFF4,
338 )?;
339 check_len(
340 "cell_moments",
341 self.cell_moments.len(),
342 total_cells * MOMENT_STRIDE,
343 )?;
344 for i in 0..n {
350 if self.cell_offsets[i] > self.cell_offsets[i + 1] {
351 return Err(GpuError::DriverCallFailed {
352 reason: format!(
353 "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
354 i,
355 self.cell_offsets[i],
356 i + 1,
357 self.cell_offsets[i + 1]
358 ),
359 });
360 }
361 }
362 Ok(())
363 }
364}
365
#[cfg(target_os = "linux")]
/// CUDA C source of `bms_flex_row_kernel`, the per-row Bernoulli
/// marginal-slope flex likelihood kernel.
///
/// Not self-contained: `RowKernelBackend::probe` prepends
/// `gam_gpu::numerics_device::PROBIT_NUMERICS_CU`, which supplies
/// `log_ndtr_and_mills`. Launch one block per row with `ROW_KERNEL_THREADS`
/// threads. The shared reduction buffers hold 32 entries and the tree
/// reduction assumes a power-of-two block. The shared scratch is sized for
/// `r <= MAX_R`.
///
/// Per-row output offsets (`row * r`, `row * r * r`) are formed in `size_t`.
/// In 32-bit `int` they overflow once `n_rows * r * r > INT_MAX` (about 2.1M
/// rows at `r = 32`), which would write out of bounds.
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
// ::compute_row_analytic_flex_from_parts_into.

#define INV_TWO_PI 0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    // Row offsets in size_t: row * r * r overflows int for large n_rows.
    size_t grad_base = (size_t)row * (size_t)r;
    size_t rr = (size_t)r * (size_t)r;
    size_t hess_base = (size_t)row * rr;
    out_neglog[row] = nan_value;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + u] = nan_value;
    }
    for (size_t idx = 0; idx < rr; ++idx) {
        out_hess[hess_base + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int n_rows,
    int r,
    int p_h,
    int p_w,
    double s_f,                                  // currently unused on device:
                                                 // host has already baked S_f
                                                 // into the cubic coefficients.
                                                 // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a,          // [n_cells, 4]
    const double * __restrict__ cell_aa,         // [n_cells, 4]
    const double * __restrict__ cell_r,          // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar,         // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb,        // [n_cells, 4]
    const double * __restrict__ cell_sbh,        // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw,        // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments,    // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho,         // [n_rows, r]
    const double * __restrict__ row_tau,         // [n_rows, r]
    const double * __restrict__ row_ruv,         // [n_rows, r*r]
    const double * __restrict__ row_e_obs,       // [n_rows] observed predictor VALUE
    double * __restrict__ out_neglog,
    double * __restrict__ out_grad,
    double * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    //   F_u          [r]
    //   F_au         [r]
    //   F_uv         [r*r]
    //   bar_e_u      [r]
    //   bar_e_uv     [r*r]
    //   reduce_a     [blockDim.x]
    //   reduce_b     [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u] = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        // `cell_second_derivative_from_moments` after folding the
        // cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S) \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1] \
            + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2] \
            + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3] \
            + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4] \
            + (R[2]*S[3] + R[3]*S[2])*T[5] \
            + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c = cell_a + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        //                                        = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R = D_OF(R_u);
            double d_AR = D_OF(AR_u);
            double q_AR = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in score-warp block, or symmetric.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in link-wiggle block, or symmetric.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror: u in (h or w) block, v == 1 cannot happen
                // because we iterate v >= u; skip.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared.
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    //   F_q = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0] = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    //   a_u = -F_u · inv_Fa      (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    //   a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // The q-row contributions (u==0 or v==0) collapse to a_uv = mu_2 · inv_Fa
    // when both are 0 and to (F_au_v) · inv_Fa-style mixed shape otherwise.
    // We compute it uniformly using the populated F_uv / F_au with the
    // q-overrides above.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                        + F_au[v] * a_u[u]
                        + F_au[u] * a_u[v]
                        + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    //   bar_e_u  = chi · a_u + rho_u.
    //   bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                       + xi * a_u[u] * a_u[v]
                       + tau[u] * a_u[v]
                       + a_u[u] * tau[v]
                       + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y = row_y[row];
    double w = row_w[row];
    double s = 2.0 * y - 1.0;
    // The "observed predictor" e_obs is the VALUE (degree-0 term) of the
    // observed jet η(a(θ), θ; z_obs) — NOT `bar_e_u[0]`, which is the u=0
    // FIRST-derivative jet (`chi·a_0 + rho_0 = dη_obs/dq`). The host packs
    // the observed value directly in `row_e_obs[row]` (see
    // `pack_bms_flex_row_kernel_inputs`, `eta_val = eval_coeff4_at(obs.coeff,
    // z_obs)`), matching the CPU family `compute_row_analytic_flex_from_parts_into`
    // which forms `signed_margin = s_y · eta_val`. #415 parity lock.
    double e_obs = row_e_obs[row];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i = w * lambda * (m_arg + lambda);

    // Output offsets in size_t: row * r * r overflows int once
    // n_rows * r * r > INT_MAX (about 2.1M rows at r = 32).
    size_t grad_base = (size_t)row * (size_t)r;
    size_t hess_base = grad_base * (size_t)r;
    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            out_hess[hess_base + u * r + v] = val;
            if (u != v) {
                out_hess[hess_base + v * r + u] = val;
            }
        }
    }
}
"#;
726
727#[inline]
734pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
735 inputs.s_f.is_finite() && inputs.s_f > 0.0
736}
737
738#[cfg(target_os = "linux")]
739pub(crate) struct RowKernelBackend {
740 pub(crate) stream: Arc<CudaStream>,
741 pub(crate) module: Arc<CudaModule>,
742}
743
744#[cfg(target_os = "linux")]
745impl RowKernelBackend {
746 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
747 static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
748 BACKEND
749 .get_or_init(|| {
750 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
751 let row_kernel_source = [
752 gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
753 ROW_KERNEL_BODY,
754 ]
755 .concat();
756 let ptx = gam_gpu::device_cache::compile_ptx_arch(&row_kernel_source).map_err(
767 |err| GpuError::DriverCallFailed {
768 reason: format!("bms_flex_row NVRTC compile failed: {err}"),
769 },
770 )?;
771 let module =
772 parts
773 .ctx
774 .load_module(ptx)
775 .map_err(|err| GpuError::DriverCallFailed {
776 reason: format!("bms_flex_row module load failed: {err}"),
777 })?;
778 Ok(RowKernelBackend {
779 stream: parts.stream.clone(),
780 module,
781 })
782 })
783 })
784 .as_ref()
785 .map_err(GpuError::clone)
786 }
787}
788
789pub(crate) fn launch_bms_flex_row_kernel(
794 inputs: BmsFlexRowKernelInputs<'_>,
795) -> Result<BmsFlexRowKernelOutputs, GpuError> {
796 inputs.validate()?;
797 if !s_f_diagnostic_finite(&inputs) {
798 return Err(GpuError::DriverCallFailed {
799 reason: format!(
800 "bms_flex_row inputs: s_f must be positive and finite, got {}",
801 inputs.s_f
802 ),
803 });
804 }
805
806 #[cfg(target_os = "linux")]
807 {
808 launch_linux(inputs)
809 }
810 #[cfg(not(target_os = "linux"))]
811 {
812 Err(GpuError::DriverLibraryUnavailable {
813 reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
814 })
815 }
816}
817
818#[cfg(target_os = "linux")]
819pub(crate) fn launch_linux(
820 inputs: BmsFlexRowKernelInputs<'_>,
821) -> Result<BmsFlexRowKernelOutputs, GpuError> {
822 let backend = RowKernelBackend::probe()?;
823 let stream = &backend.stream;
824
825 let upload_f64 = |slice: &[f64], label: &str| {
826 stream
827 .clone_htod(slice)
828 .map_err(|err| GpuError::DriverCallFailed {
829 reason: format!("bms_flex_row upload {label}: {err}"),
830 })
831 };
832 let upload_u32 = |slice: &[u32], label: &str| {
833 stream
834 .clone_htod(slice)
835 .map_err(|err| GpuError::DriverCallFailed {
836 reason: format!("bms_flex_row upload {label}: {err}"),
837 })
838 };
839
840 let d_q = upload_f64(inputs.q, "q")?;
841 let d_b = upload_f64(inputs.b, "b")?;
842 let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
843 let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
844 let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
845 let d_y = upload_f64(inputs.y, "y")?;
846 let d_w = upload_f64(inputs.w, "w")?;
847 let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
848 let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
849 let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
850 let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
851 let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
852 let d_a = upload_f64(inputs.cell_a, "cell_a")?;
853 let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
854 let d_r = upload_f64(inputs.cell_r, "cell_r")?;
855 let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
856 let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
857 let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
858 let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
859 let owned_host_moments: CudaSlice<f64>;
863 let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
864 CellMomentsSource::Host(slice) => {
865 owned_host_moments = upload_f64(slice, "cell_moments")?;
866 &owned_host_moments
867 }
868 CellMomentsSource::Device(d) => *d,
869 };
870 let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
871 let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
872 let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
873 let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
874 let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
875 let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;
876
877 let n = inputs.n_rows;
878 let r = inputs.r;
879 let mut d_neglog = stream
880 .alloc_zeros::<f64>(n)
881 .map_err(|err| GpuError::DriverCallFailed {
882 reason: format!("bms_flex_row alloc neglog: {err}"),
883 })?;
884 let mut d_grad =
885 stream
886 .alloc_zeros::<f64>(n * r)
887 .map_err(|err| GpuError::DriverCallFailed {
888 reason: format!("bms_flex_row alloc grad: {err}"),
889 })?;
890 let mut d_hess =
891 stream
892 .alloc_zeros::<f64>(n * r * r)
893 .map_err(|err| GpuError::DriverCallFailed {
894 reason: format!("bms_flex_row alloc hess: {err}"),
895 })?;
896
897 let func = backend
898 .module
899 .load_function("bms_flex_row_kernel")
900 .map_err(|err| GpuError::DriverCallFailed {
901 reason: format!("bms_flex_row load_function: {err}"),
902 })?;
903
904 let cfg = LaunchConfig {
905 grid_dim: (n as u32, 1, 1),
906 block_dim: (ROW_KERNEL_THREADS, 1, 1),
907 shared_mem_bytes: 0,
908 };
909 let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
910 reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
911 })?;
912 let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
913 reason: format!("bms_flex_row: r={r} exceeds i32 range"),
914 })?;
915 let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
916 reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
917 })?;
918 let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
919 reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
920 })?;
921 let s_f = inputs.s_f;
922
923 let mut builder = stream.launch_builder(&func);
924 builder
925 .arg(&n_i32)
926 .arg(&r_i32)
927 .arg(&p_h_i32)
928 .arg(&p_w_i32)
929 .arg(&s_f)
930 .arg(&d_q)
931 .arg(&d_b)
932 .arg(&d_mu1)
933 .arg(&d_mu2)
934 .arg(&d_zobs)
935 .arg(&d_y)
936 .arg(&d_w)
937 .arg(&d_offsets)
938 .arg(&d_c0)
939 .arg(&d_c1)
940 .arg(&d_c2)
941 .arg(&d_c3)
942 .arg(&d_a)
943 .arg(&d_aa)
944 .arg(&d_r)
945 .arg(&d_ar)
946 .arg(&d_sbb)
947 .arg(&d_sbh)
948 .arg(&d_sbw)
949 .arg(d_moments_ref)
950 .arg(&d_chi)
951 .arg(&d_xi)
952 .arg(&d_rho)
953 .arg(&d_tau)
954 .arg(&d_ruv)
955 .arg(&d_e_obs)
956 .arg(&mut d_neglog)
957 .arg(&mut d_grad)
958 .arg(&mut d_hess);
959
960 unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
967 reason: format!("bms_flex_row launch: {err}"),
968 })?;
969 stream
970 .synchronize()
971 .map_err(|err| GpuError::DriverCallFailed {
972 reason: format!("bms_flex_row synchronize: {err}"),
973 })?;
974
975 let neglog = stream
976 .clone_dtoh(&d_neglog)
977 .map_err(|err| GpuError::DriverCallFailed {
978 reason: format!("bms_flex_row download neglog: {err}"),
979 })?;
980 let grad = stream
981 .clone_dtoh(&d_grad)
982 .map_err(|err| GpuError::DriverCallFailed {
983 reason: format!("bms_flex_row download grad: {err}"),
984 })?;
985 let hess = stream
986 .clone_dtoh(&d_hess)
987 .map_err(|err| GpuError::DriverCallFailed {
988 reason: format!("bms_flex_row download hess: {err}"),
989 })?;
990
991 Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
992}
993
#[cfg(target_os = "linux")]
/// Layout of the joint coefficient vector β (length `p_total`).
///
/// NOTE(review): presumably this feeds the HVP kernels' `p_m`, `p_g`,
/// `h_block_*`, `w_block_*` and `p_total` arguments — confirm at the call site.
/// In `bms_flex_row_hvp_partial`, `v[0..p_m]` pairs with the marginal design,
/// `v[p_m..p_m + p_g]` pairs with the log-slope design, and the `h`/`w` blocks
/// are copied coefficient-for-coefficient into the row's primary direction.
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Number of marginal-design coefficients.
    pub p_m: usize,
    /// Number of log-slope-design coefficients.
    pub p_g: usize,
    /// β range of the `h` block; presumably `None` when there is no `h` block — confirm.
    pub h: Option<std::ops::Range<usize>>,
    /// β range of the `w` block; presumably `None` when there is no `w` block — confirm.
    pub w: Option<std::ops::Range<usize>>,
    /// Total length of β.
    pub p_total: usize,
}
1054
#[cfg(target_os = "linux")]
/// Layout of the per-row primary coordinates (length `r`).
///
/// In `bms_flex_row_hvp_partial`, primary coordinate 0 is driven by the
/// marginal design and coordinate 1 by the log-slope design; the `h`/`w`
/// ranges here locate those blocks among the remaining coordinates.
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexPrimaryLayout {
    /// Primary-coordinate range of the `h` block; presumably `None` when absent — confirm.
    pub h: Option<std::ops::Range<usize>>,
    /// Primary-coordinate range of the `w` block; presumably `None` when absent — confirm.
    pub w: Option<std::ops::Range<usize>>,
    /// Primary dimension per row.
    pub r: usize,
}
1064
#[cfg(target_os = "linux")]
/// Rows per chunk (CTA) for the HVP partial kernels; `num_hvp_chunks`
/// divides the row count by this. Presumably passed as `rows_per_cta` — confirm.
pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;

#[cfg(target_os = "linux")]
/// Threads per block for the HVP partial kernels (presumably — confirm at the
/// launch site). Their `dot_reduce` shared buffer holds 128 entries and is
/// tree-reduced by halving, so this must be a power of two no larger than 128.
pub(crate) const HVP_THREADS: u32 = 128;

#[cfg(target_os = "linux")]
/// NOTE(review): presumably the block size for `bms_flex_row_hvp_reduce` — confirm.
pub(crate) const REDUCTION_THREADS: u32 = 256;

#[cfg(target_os = "linux")]
/// Maximum right-hand sides per multi-RHS HVP launch. Must match the kernel's
/// `#define MAX_MULTI_RHS 8`, which sizes its `row_dir` / `action` shared arrays.
pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1090
#[cfg(target_os = "linux")]
/// Row Hessians and the two row designs held on the device, together with the
/// layouts needed to interpret them. Presumably kept resident so repeated
/// Hessian-vector products avoid re-uploading — confirm against callers.
pub struct DeviceResidentRowHess {
    /// Per-row `r × r` Hessians; presumably the `[n, r*r]` `row_hessians`
    /// buffer read by the HVP kernels — confirm.
    pub(crate) hess: CudaSlice<f64>,
    /// Marginal design; presumably `[n, p_m]` row-major as the HVP kernels expect — confirm.
    pub(crate) marginal_design: CudaSlice<f64>,
    /// Log-slope design; presumably `[n, p_g]` row-major as the HVP kernels expect — confirm.
    pub(crate) logslope_design: CudaSlice<f64>,
    /// Number of rows.
    pub(crate) n: usize,
    /// Primary dimension per row.
    pub(crate) r: usize,
    /// Joint-coefficient (β) layout.
    pub(crate) block: BmsFlexBlockLayout,
    /// Per-row primary-coordinate layout.
    pub(crate) primary: BmsFlexPrimaryLayout,
    /// NOTE(review): presumably the total device bytes held here — confirm.
    pub(crate) bytes: u64,
}
1126
1127#[cfg(target_os = "linux")]
1128impl std::fmt::Debug for DeviceResidentRowHess {
1129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1130 f.debug_struct("DeviceResidentRowHess")
1131 .field("n", &self.n)
1132 .field("r", &self.r)
1133 .field("p_total", &self.block.p_total)
1134 .field("bytes", &self.bytes)
1135 .finish()
1136 }
1137}
1138
1139#[cfg(target_os = "linux")]
1142pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1143 n.div_ceil(HVP_ROWS_PER_CTA as usize)
1144}
1145
1146#[cfg(target_os = "linux")]
1149pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1150// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1151// in this module.
1152
1153#define MAX_MULTI_RHS 8
1154
1155extern "C" __global__ void bms_flex_row_hvp_partial(
1156 int n_rows,
1157 int r,
1158 int p_m,
1159 int p_g,
1160 int p_total,
1161 int h_block_start,
1162 int h_block_len,
1163 int w_block_start,
1164 int w_block_len,
1165 int h_primary_start,
1166 int w_primary_start,
1167 int rows_per_cta,
1168 const double * __restrict__ row_hessians, // [n, r*r]
1169 const double * __restrict__ marginal_design, // [n, p_m] row-major
1170 const double * __restrict__ logslope_design, // [n, p_g] row-major
1171 const double * __restrict__ v, // [p_total]
1172 double * __restrict__ partial) // [num_chunks, p_total]
1173{
1174 int chunk = blockIdx.x;
1175 int tid = threadIdx.x;
1176 int row_lo = chunk * rows_per_cta;
1177 int row_hi = row_lo + rows_per_cta;
1178 if (row_hi > n_rows) row_hi = n_rows;
1179
1180 // Zero this chunk's partial slice cooperatively.
1181 double *out = partial + (size_t)chunk * (size_t)p_total;
1182 for (int j = tid; j < p_total; j += blockDim.x) {
1183 out[j] = 0.0;
1184 }
1185 __syncthreads();
1186
1187 // Each thread serially processes a stride-of-blockDim set of rows so
1188 // every write to `out[..]` happens from one thread → no atomics within
1189 // the chunk. To keep writes race-free across threads of the same chunk,
1190 // we serialize the cross-row accumulation through a per-row barrier:
1191 // thread 0 of the block processes all rows in the chunk. The per-row
1192 // work is dominated by the dot/axpy over `p_m + p_g`, which is large.
1193 // For Stage 3 we ship the simple, correct path (thread 0 sequential
1194 // per row, blockDim.x threads parallel within a row's dot/axpy).
1195 __shared__ double row_dir[32];
1196 __shared__ double action[32];
1197 __shared__ double dot_reduce[128];
1198
1199 for (int row = row_lo; row < row_hi; ++row) {
1200 const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1201 const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1202 const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1203
1204 // row_dir[0] = mrow · v[0..p_m]
1205 double local = 0.0;
1206 for (int j = tid; j < p_m; j += blockDim.x) {
1207 local += mrow[j] * v[j];
1208 }
1209 dot_reduce[tid] = local;
1210 __syncthreads();
1211 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1212 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1213 __syncthreads();
1214 }
1215 if (tid == 0) row_dir[0] = dot_reduce[0];
1216
1217 // row_dir[1] = grow · v[p_m..p_m+p_g]
1218 local = 0.0;
1219 for (int j = tid; j < p_g; j += blockDim.x) {
1220 local += grow[j] * v[p_m + j];
1221 }
1222 dot_reduce[tid] = local;
1223 __syncthreads();
1224 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1225 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1226 __syncthreads();
1227 }
1228 if (tid == 0) row_dir[1] = dot_reduce[0];
1229
1230 // h/w blocks: direct copy.
1231 if (tid == 0) {
1232 for (int k = 0; k < h_block_len; ++k) {
1233 row_dir[h_primary_start + k] = v[h_block_start + k];
1234 }
1235 for (int k = 0; k < w_block_len; ++k) {
1236 row_dir[w_primary_start + k] = v[w_block_start + k];
1237 }
1238 }
1239 __syncthreads();
1240
1241 // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
1242 if (tid < r) {
1243 double acc = 0.0;
1244 for (int vv = 0; vv < r; ++vv) {
1245 acc += Hrow[tid * r + vv] * row_dir[vv];
1246 }
1247 action[tid] = acc;
1248 }
1249 __syncthreads();
1250
1251 // Pull back into joint β slot.
1252 // marginal: out[j] += action[0] · mrow[j] (parallel j)
1253 double a0 = action[0];
1254 for (int j = tid; j < p_m; j += blockDim.x) {
1255 out[j] += a0 * mrow[j];
1256 }
1257 double a1 = action[1];
1258 for (int j = tid; j < p_g; j += blockDim.x) {
1259 out[p_m + j] += a1 * grow[j];
1260 }
1261 if (tid == 0) {
1262 for (int k = 0; k < h_block_len; ++k) {
1263 out[h_block_start + k] += action[h_primary_start + k];
1264 }
1265 for (int k = 0; k < w_block_len; ++k) {
1266 out[w_block_start + k] += action[w_primary_start + k];
1267 }
1268 }
1269 __syncthreads();
1270 }
1271}
1272
// Fold the per-chunk partials written by a bms_flex_row_*_partial kernel
// (HVP or diagonal) into the final [p_total] vector. One thread owns one
// output coordinate and visits the chunks in ascending chunk order, so the
// summation order does not depend on the launch geometry.
extern "C" __global__ void bms_flex_row_hvp_reduce(
    int num_chunks,
    int p_total,
    const double * __restrict__ partial, // [num_chunks, p_total]
    double * __restrict__ out) // [p_total]
{
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col < p_total) {
        // Walk down column `col` of the [num_chunks, p_total] partial matrix.
        const double *cursor = partial + (size_t)col;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk, cursor += p_total) {
            total += *cursor;
        }
        out[col] = total;
    }
}
1287
// Multi-RHS variant of the chunked row HVP. CTA `chunk` handles rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)) and, for every
// right-hand side k, accumulates the row pullback P_i^T H_i P_i v_k into its
// own slot partial[k, chunk, 0..p_total). bms_flex_row_hvp_multi_reduce sums
// the chunk slots afterwards.
//
// Per row and RHS: row_dir = P_i v_k (two dot products plus direct copies of
// the h/w coefficients), action = H_i row_dir, out += P_i^T action.
//
// NOTE(review): shared-memory assumptions not checked in the kernel —
// rhs_count <= MAX_MULTI_RHS and r <= 32 (each RHS owns a 32-wide stripe of
// row_dir/action), blockDim.x <= 128 (dot_reduce size) and a power of two
// (the tree reduction halves blockDim.x). Confirm the launcher enforces them.
extern "C" __global__ void bms_flex_row_hvp_multi_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    int rhs_count,
    const double * __restrict__ row_hessians, // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    const double * __restrict__ v_rhs, // [rhs_count, p_total]
    double * __restrict__ partial) // [rhs_count, num_chunks, p_total]
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // Recompute the grid size to locate this CTA's slot for each RHS, then
    // zero those slots (rhs_count * p_total doubles) before accumulating.
    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
        int rhs = idx / p_total;
        int j = idx - rhs * p_total;
        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
    }
    __syncthreads();

    // row_dir / action: one 32-wide stripe per RHS, indexed by primary slot.
    __shared__ double row_dir[MAX_MULTI_RHS * 32];
    __shared__ double action[MAX_MULTI_RHS * 32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;

        // Phase A: build row_dir = P_i v for every RHS.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;

            // Slot 0: X_i · v[0..p_m), block-wide tree reduction.
            double local = 0.0;
            for (int j = tid; j < p_m; j += blockDim.x) {
                local += mrow[j] * v[j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            // Only thread 0 reads dot_reduce[0]; other threads only overwrite
            // their own dot_reduce[tid] below, so no barrier is needed here.
            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];

            // Slot 1: G_i · v[p_m..p_m+p_g).
            local = 0.0;
            for (int j = tid; j < p_g; j += blockDim.x) {
                local += grow[j] * v[p_m + j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) {
                row_dir[rhs * 32 + 1] = dot_reduce[0];
                // h/w slots are unit vectors in joint space: copy coefficients.
                // NOTE(review): primary slots outside {0, 1, h range, w range}
                // are never written; presumably r == 2 + h_block_len + w_block_len.
                for (int k = 0; k < h_block_len; ++k) {
                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
                }
            }
            __syncthreads();
        }

        // Phase B: action = H_i row_dir for all (rhs, u) pairs, one pair per
        // thread iteration.
        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
            int rhs = idx / r;
            int u = idx - rhs * r;
            double acc = 0.0;
            const double *dir = row_dir + rhs * 32;
            for (int vv = 0; vv < r; ++vv) {
                acc += Hrow[u * r + vv] * dir[vv];
            }
            action[rhs * 32 + u] = acc;
        }
        __syncthreads();

        // Phase C: pull back out += P_i^T action into each RHS slot. Threads
        // stripe the dense marginal/logslope ranges; thread 0 handles h/w.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
            double a0 = action[rhs * 32 + 0];
            for (int j = tid; j < p_m; j += blockDim.x) {
                out[j] += a0 * mrow[j];
            }
            double a1 = action[rhs * 32 + 1];
            for (int j = tid; j < p_g; j += blockDim.x) {
                out[p_m + j] += a1 * grow[j];
            }
            if (tid == 0) {
                for (int k = 0; k < h_block_len; ++k) {
                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
                }
            }
            __syncthreads();
        }
    }
}
1402
// Multi-RHS counterpart of bms_flex_row_hvp_reduce. One thread owns one
// (rhs, coordinate) pair and sums that pair's per-chunk partials in
// ascending chunk order into out[rhs, coordinate].
extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
    int num_chunks,
    int p_total,
    int rhs_count,
    const double * __restrict__ partial, // [rhs_count, num_chunks, p_total]
    double * __restrict__ out) // [rhs_count, p_total]
{
    const int flat = blockIdx.x * blockDim.x + threadIdx.x;
    if (flat < rhs_count * p_total) {
        const int rhs = flat / p_total;
        const int col = flat % p_total;
        // Start of this RHS's [num_chunks, p_total] slab, column `col`.
        const double *cursor =
            partial + (size_t)rhs * (size_t)num_chunks * (size_t)p_total + (size_t)col;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk, cursor += p_total) {
            total += *cursor;
        }
        out[(size_t)rhs * (size_t)p_total + (size_t)col] = total;
    }
}
1421
// Per-chunk diagonal of sum_i P_i^T H_i P_i, dense row-major H_i storage.
// For the rows owned by this CTA it accumulates into partial[chunk, :]:
//   marginal coordinate j : H_i[0,0] * X_ij^2
//   logslope coordinate j : H_i[1,1] * G_ij^2
//   h / w coordinate k    : H_i[d,d] with d the matching primary index.
// Threads stripe the dense marginal/logslope ranges; thread 0 alone updates
// the h/w unit slots.
extern "C" __global__ void bms_flex_row_diag_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    const int lane = threadIdx.x;
    const int step = blockDim.x;
    const int first_row = blockIdx.x * rows_per_cta;
    const int end_row = (first_row + rows_per_cta > n_rows) ? n_rows : first_row + rows_per_cta;

    double *slot = partial + (size_t)blockIdx.x * (size_t)p_total;
    double *slot_g = slot + p_m;
    for (int k = lane; k < p_total; k += step) {
        slot[k] = 0.0;
    }
    __syncthreads();

    for (int row = first_row; row < end_row; ++row) {
        const double *hess = row_hessians + (size_t)row * (size_t)r * (size_t)r;
        const double *xrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *gvec = logslope_design + (size_t)row * (size_t)p_g;
        const double h_mm = hess[0];
        const double h_gg = hess[r + 1];

        for (int k = lane; k < p_m; k += step) {
            const double xk = xrow[k];
            slot[k] += h_mm * xk * xk;
        }
        for (int k = lane; k < p_g; k += step) {
            const double gk = gvec[k];
            slot_g[k] += h_gg * gk * gk;
        }
        if (lane == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                const int d = h_primary_start + k;
                slot[h_block_start + k] += hess[d * r + d];
            }
            for (int k = 0; k < w_block_len; ++k) {
                const int d = w_primary_start + k;
                slot[w_block_start + k] += hess[d * r + d];
            }
        }
        __syncthreads();
    }
}
1479
1480// ────────────────────────────────────────────────────────────────────────
1481// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1482// row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1483// indexed as
1484// packed[(u*(2*r - u - 1))/2 + (v - u)] for u <= v
1485// with symmetric mirror for v < u.
1486// ────────────────────────────────────────────────────────────────────────
1487
// Helper: packed-upper index for (u, v) within a single row of r*(r+1)/2
// doubles. Caller must pre-swap so that u <= v.
//
// Packed row u starts after rows 0..u-1, which hold
// r + (r-1) + ... + (r-u+1) = u*(2r - u + 1)/2 entries; (u, v) then sits
// (v - u) entries into that row. This is the layout bms_flex_row_pack_upper
// writes (row starts 0, r, 2r-1, ...). The previous formula
// u*(2r - u - 1)/2 + (v - u) was off by u: it mapped (1,1) onto (0,r-1)
// and never reached the last entries of the triangle.
// u*(2r - u + 1) is always even, so the integer division is exact.
__device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
    return (u * (2 * r - u + 1)) / 2 + (v - u);
}
1494
// Pack one row of the full row-major r×r Hessian into packed-upper layout.
// Launched as one CTA per row (gridDim.x = n_rows, blockDim.x configurable).
// Bit-equal copy: each upper-triangle entry is read once from the dense
// source and written once to the packed destination.
extern "C" __global__ void bms_flex_row_pack_upper(
    int n_rows,
    int r,
    const double * __restrict__ src_full, // [n, r*r]
    double * __restrict__ dst_packed) // [n, r*(r+1)/2]
{
    int row = blockIdx.x;
    // Surplus CTAs (grid larger than n_rows) exit immediately.
    if (row >= n_rows) return;
    int tid = threadIdx.x;
    int per_row = r * (r + 1) / 2;
    const double *src = src_full + (size_t)row * (size_t)r * (size_t)r;
    double *dst = dst_packed + (size_t)row * (size_t)per_row;
    // Linear scan over packed positions; map each back to (u, v).
    for (int pos = tid; pos < per_row; pos += blockDim.x) {
        // Invert pos -> (u, v): packed row u holds (r - u) entries and
        // occupies [u_start, u_start + (r - u)) with
        // u_start = u*(2r - u + 1)/2 (the sum r + (r-1) + ... + (r-u+1)).
        // Walk the row starts until pos falls inside row u; O(r) per
        // position with r <= 32.
        int u = 0;
        int u_start = 0;
        while (u < r) {
            int next = u_start + (r - u);
            if (pos < next) break;
            u_start = next;
            ++u;
        }
        // Offset within packed row u gives the column; v >= u by construction.
        int v = u + (pos - u_start);
        dst[pos] = src[(size_t)u * (size_t)r + (size_t)v];
    }
}
1529
// Chunked row HVP over packed-upper per-row Hessians. CTA `chunk` handles
// rows [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)) and writes
// sum_i P_i^T H_i P_i v for those rows into partial[chunk, 0..p_total).
// H_i[u, w] is read through bms_flex_packed_idx, so correctness depends on
// that helper matching the layout produced by bms_flex_row_pack_upper.
//
// NOTE(review): assumes r <= 32 (row_dir/action size), blockDim.x >= r
// (action is computed by threads tid < r), and blockDim.x a power of two
// no larger than 128 (dot_reduce size, halving tree reduction). Confirm the
// launcher enforces these.
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed, // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    const double * __restrict__ v,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // Zero this CTA's [p_total] slot before accumulating.
    int per_row = r * (r + 1) / 2;
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // row_dir[0] = mrow · v[0..p_m]
        double local = 0.0;
        for (int j = tid; j < p_m; j += blockDim.x) {
            local += mrow[j] * v[j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        // Only thread 0 reads dot_reduce[0]; the other threads only overwrite
        // their own dot_reduce[tid] next, so no extra barrier is needed.
        if (tid == 0) row_dir[0] = dot_reduce[0];

        // row_dir[1] = grow · v[p_m..p_m+p_g]
        local = 0.0;
        for (int j = tid; j < p_g; j += blockDim.x) {
            local += grow[j] * v[p_m + j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[1] = dot_reduce[0];

        // h/w primary slots are unit vectors in joint space: copy coefficients.
        // NOTE(review): primary slots outside {0, 1, h range, w range} are
        // never written; presumably r == 2 + h_block_len + w_block_len.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        __syncthreads();

        // action[u] = Σ_w H[u, w] · row_dir[w], where H[u, w] reads from
        // packed-upper with (uu, vv) = (min(u, w), max(u, w)).
        if (tid < r) {
            double acc = 0.0;
            int u = tid;
            for (int w = 0; w < r; ++w) {
                int uu = u < w ? u : w;
                int vv = u < w ? w : u;
                acc += Hrow[bms_flex_packed_idx(uu, vv, r)] * row_dir[w];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Pull back out += P_i^T action: threads stripe the dense ranges,
        // thread 0 handles the h/w unit slots.
        double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        __syncthreads();
    }
}
1640
1641// ────────────────────────────────────────────────────────────────────────
1642// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
1643// route. Materialises the full `[p_total, p_total]` row-major joint H
1644// from the per-row r×r Hessian via the P_i pullback. NOT the default
1645// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
1646// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
1647// caller genuinely needs the dense matrix on the device.
1648//
1649// Per-CTA partial: each CTA owns a contiguous chunk of rows
1650// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
1651// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
1652// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
1653// chunk-major-fixed-order into a single `[p_total, p_total]` output.
1654//
1655// Math: for primary index u ∈ [0, r):
1656// * u = 0: phi_u = (X_i in slot 0..p_m, 0 elsewhere)
1657// * u = 1: phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
1658// * u = 2+j: phi_u = e_{h_block_start + j} (j ∈ 0..h_block_len)
1659// * u = 2+h+l: phi_u = e_{w_block_start + l} (l ∈ 0..w_block_len)
1660// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
1661//
1662// Shared-memory budget: at large-scale shape p_total = 44, a [44, 44] f64
1663// partial is 44*44*8 = 15.5 KiB — well below the V100 48 KiB/SM cap.
1664// At p_total ≤ 80 the kernel still fits (80*80*8 = 50 KiB → just over
1665// V100 cap; caller must enforce p_total ≤ DENSE_BLOCK_MAX_P). The
1666// launcher rejects oversize p_total cleanly.
1667
// Resolve primary index `u` to the joint-β support of phi_u for one row.
// On success writes the support offset, its length, the dense source row
// (NULL for a unit vector) and an indicator flag (true = unit vector), and
// returns true. Returns false when `u` is covered neither by the marginal /
// logslope slots (0, 1) nor by the h / w primary ranges; phi_u is then zero.
//
// The h/w mapping uses h_primary_start / w_primary_start the same way the
// HVP and diagonal kernels do (primary h_primary_start + k <-> joint
// h_block_start + k), so the dense matrix agrees with those kernels.
__device__ __forceinline__ bool bms_flex_dense_phi(
    int u,
    int p_m,
    int p_g,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    const double *mrow,
    const double *grow,
    int *off,
    int *len,
    const double **src,
    bool *indicator)
{
    if (u == 0) {
        *off = 0; *len = p_m; *src = mrow; *indicator = false;
        return true;
    }
    if (u == 1) {
        *off = p_m; *len = p_g; *src = grow; *indicator = false;
        return true;
    }
    int hk = u - h_primary_start;
    if (hk >= 0 && hk < h_block_len) {
        *off = h_block_start + hk; *len = 1; *src = NULL; *indicator = true;
        return true;
    }
    int wk = u - w_primary_start;
    if (wk >= 0 && wk < w_block_len) {
        *off = w_block_start + wk; *len = 1; *src = NULL; *indicator = true;
        return true;
    }
    return false;
}

// Per-chunk dense pullback sum_i P_i^T H_i P_i for rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)), written to
// partial[chunk] as a row-major [p_total, p_total] matrix.
// Requires at least p_total*p_total doubles of dynamic shared memory.
extern "C" __global__ void bms_flex_row_dense_block_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians, // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    double * __restrict__ partial) // [num_chunks, p_total, p_total]
{
    extern __shared__ double shmem[];
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int pp = p_total * p_total;
    double *acc = shmem; // CTA-private accumulator [p_total, p_total]
    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
    __syncthreads();

    // Per-row work performed by thread 0 to avoid cross-thread RW
    // contention on `acc[]`. Per-row cost is
    // O((p_m + p_g)^2 + r * (p_m + p_g) + r^2): the (0|1, 0|1) pairs touch
    // dense outer products, the remaining pairs at most one row/column.
    // Tighter parallel implementations are possible (warp-stripe the
    // 4-way nested u-v-m-n loop) but Phase 6 is a debug-only path and
    // the simple version is easier to audit for correctness against
    // the host-side P_i pullback oracle.
    if (tid == 0) {
        for (int row = row_lo; row < row_hi; ++row) {
            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
            for (int u = 0; u < r; ++u) {
                // Support of phi_u; unmapped primary indices contribute
                // nothing (and must not be routed into arbitrary acc slots).
                int m_off, m_len; const double *m_src; bool m_indicator;
                if (!bms_flex_dense_phi(u, p_m, p_g,
                                        h_block_start, h_block_len,
                                        w_block_start, w_block_len,
                                        h_primary_start, w_primary_start,
                                        mrow, grow,
                                        &m_off, &m_len, &m_src, &m_indicator)) {
                    continue;
                }
                for (int v = 0; v < r; ++v) {
                    double huv = Hrow[u * r + v];
                    if (huv == 0.0) continue;
                    int n_off, n_len; const double *n_src; bool n_indicator;
                    if (!bms_flex_dense_phi(v, p_m, p_g,
                                            h_block_start, h_block_len,
                                            w_block_start, w_block_len,
                                            h_primary_start, w_primary_start,
                                            mrow, grow,
                                            &n_off, &n_len, &n_src, &n_indicator)) {
                        continue;
                    }
                    // accumulate huv * phi_u[m] * phi_v[n] into acc[m, n]
                    for (int mi = 0; mi < m_len; ++mi) {
                        double pm = m_indicator ? 1.0 : m_src[mi];
                        if (pm == 0.0) continue;
                        double scaled = huv * pm;
                        int m_idx = m_off + mi;
                        for (int ni = 0; ni < n_len; ++ni) {
                            double pn = n_indicator ? 1.0 : n_src[ni];
                            int n_idx = n_off + ni;
                            acc[m_idx * p_total + n_idx] += scaled * pn;
                        }
                    }
                }
            }
        }
    }
    __syncthreads();

    // Write CTA accumulator out to global memory at its chunk slot.
    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
    for (int j = tid; j < pp; j += blockDim.x) {
        out_chunk[j] = acc[j];
    }
}
1762
// Sum the per-chunk [p_total, p_total] dense partials into `out`. One
// thread owns one matrix entry and visits the chunks in ascending order.
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int num_chunks,
    int p_total,
    const double * __restrict__ partial,
    double * __restrict__ out)
{
    const int cell = blockIdx.x * blockDim.x + threadIdx.x;
    const int entries = p_total * p_total;
    if (cell < entries) {
        // Step through the same entry of each consecutive chunk matrix.
        const double *cursor = partial + (size_t)cell;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk, cursor += entries) {
            total += *cursor;
        }
        out[cell] = total;
    }
}
1778
// Packed-upper counterpart of bms_flex_row_diag_partial. For the rows owned
// by this CTA it accumulates into partial[chunk, :]:
//   marginal coordinate j : H_i[0,0] * X_ij^2
//   logslope coordinate j : H_i[1,1] * G_ij^2
//   h / w coordinate k    : H_i[d,d] with d the matching primary index,
// reading each diagonal entry H_i[d,d] via bms_flex_packed_idx(d, d, r).
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    const int lane = threadIdx.x;
    const int step = blockDim.x;
    const int first_row = blockIdx.x * rows_per_cta;
    const int end_row = (first_row + rows_per_cta > n_rows) ? n_rows : first_row + rows_per_cta;
    const int packed_len = r * (r + 1) / 2;

    double *slot = partial + (size_t)blockIdx.x * (size_t)p_total;
    double *slot_g = slot + p_m;
    for (int k = lane; k < p_total; k += step) {
        slot[k] = 0.0;
    }
    __syncthreads();

    for (int row = first_row; row < end_row; ++row) {
        const double *hess = row_hessians_packed + (size_t)row * (size_t)packed_len;
        const double *xrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *gvec = logslope_design + (size_t)row * (size_t)p_g;
        const double h_mm = hess[bms_flex_packed_idx(0, 0, r)];
        const double h_gg = hess[bms_flex_packed_idx(1, 1, r)];

        for (int k = lane; k < p_m; k += step) {
            const double xk = xrow[k];
            slot[k] += h_mm * xk * xk;
        }
        for (int k = lane; k < p_g; k += step) {
            const double gk = gvec[k];
            slot_g[k] += h_gg * gk * gk;
        }
        if (lane == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                const int d = h_primary_start + k;
                slot[h_block_start + k] += hess[bms_flex_packed_idx(d, d, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                const int d = w_primary_start + k;
                slot[w_block_start + k] += hess[bms_flex_packed_idx(d, d, r)];
            }
        }
        __syncthreads();
    }
}
1838"#;
1839
1840#[cfg(target_os = "linux")]
1841pub(crate) struct HvpKernelBackend {
1842 pub(crate) stream: Arc<CudaStream>,
1843 pub(crate) module: Arc<CudaModule>,
1844}
1845
1846#[cfg(target_os = "linux")]
1847impl HvpKernelBackend {
1848 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1849 static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1850 BACKEND
1851 .get_or_init(|| {
1852 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1853 let ptx = gam_gpu::device_cache::compile_ptx_arch(HVP_KERNEL_SOURCE).map_err(
1857 |err| GpuError::DriverCallFailed {
1858 reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1859 },
1860 )?;
1861 let module =
1862 parts
1863 .ctx
1864 .load_module(ptx)
1865 .map_err(|err| GpuError::DriverCallFailed {
1866 reason: format!("bms_flex_row hvp module load failed: {err}"),
1867 })?;
1868 Ok(HvpKernelBackend {
1869 stream: parts.stream.clone(),
1870 module,
1871 })
1872 })
1873 })
1874 .as_ref()
1875 .map_err(GpuError::clone)
1876 }
1877}
1878
1879#[cfg(target_os = "linux")]
1905pub(crate) fn launch_bms_flex_row_kernel_device_resident(
1906 inputs: BmsFlexRowKernelInputs<'_>,
1907 marginal_design_row_major: &[f64],
1908 logslope_design_row_major: &[f64],
1909 block: BmsFlexBlockLayout,
1910 primary: BmsFlexPrimaryLayout,
1911) -> Result<DeviceResidentRowHess, GpuError> {
1912 inputs.validate()?;
1913 if !s_f_diagnostic_finite(&inputs) {
1914 return Err(GpuError::DriverCallFailed {
1915 reason: format!(
1916 "bms_flex_row device-resident: s_f must be positive and finite, got {}",
1917 inputs.s_f
1918 ),
1919 });
1920 }
1921 let n = inputs.n_rows;
1922 let r = inputs.r;
1923 if marginal_design_row_major.len() != n * block.p_m {
1924 return Err(GpuError::DriverCallFailed {
1925 reason: format!(
1926 "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
1927 marginal_design_row_major.len(),
1928 n * block.p_m
1929 ),
1930 });
1931 }
1932 if logslope_design_row_major.len() != n * block.p_g {
1933 return Err(GpuError::DriverCallFailed {
1934 reason: format!(
1935 "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
1936 logslope_design_row_major.len(),
1937 n * block.p_g
1938 ),
1939 });
1940 }
1941 if primary.r != r {
1942 return Err(GpuError::DriverCallFailed {
1943 reason: format!(
1944 "bms_flex_row device-resident: primary.r={} != inputs.r={}",
1945 primary.r, r
1946 ),
1947 });
1948 }
1949
1950 let backend = RowKernelBackend::probe()?;
1953 HvpKernelBackend::probe()?;
1954 let stream = backend.stream.clone();
1955
1956 let upload_f64 = |slice: &[f64], label: &str| {
1957 stream
1958 .clone_htod(slice)
1959 .map_err(|err| GpuError::DriverCallFailed {
1960 reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1961 })
1962 };
1963 let upload_u32 = |slice: &[u32], label: &str| {
1964 stream
1965 .clone_htod(slice)
1966 .map_err(|err| GpuError::DriverCallFailed {
1967 reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1968 })
1969 };
1970
1971 let d_q = upload_f64(inputs.q, "q")?;
1972 let d_b = upload_f64(inputs.b, "b")?;
1973 let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
1974 let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
1975 let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
1976 let d_y = upload_f64(inputs.y, "y")?;
1977 let d_w = upload_f64(inputs.w, "w")?;
1978 let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
1979 let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
1980 let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
1981 let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
1982 let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
1983 let d_a = upload_f64(inputs.cell_a, "cell_a")?;
1984 let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
1985 let d_r = upload_f64(inputs.cell_r, "cell_r")?;
1986 let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
1987 let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
1988 let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
1989 let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
1990 let owned_host_moments: CudaSlice<f64>;
1992 let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
1993 CellMomentsSource::Host(slice) => {
1994 owned_host_moments = upload_f64(slice, "cell_moments")?;
1995 &owned_host_moments
1996 }
1997 CellMomentsSource::Device(d) => *d,
1998 };
1999 let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
2000 let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
2001 let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
2002 let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
2003 let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
2004 let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;
2005
2006 let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
2007 let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;
2008
2009 let mut d_neglog = stream
2010 .alloc_zeros::<f64>(n)
2011 .map_err(|err| GpuError::DriverCallFailed {
2012 reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
2013 })?;
2014 let mut d_grad =
2015 stream
2016 .alloc_zeros::<f64>(n * r)
2017 .map_err(|err| GpuError::DriverCallFailed {
2018 reason: format!("bms_flex_row device-resident alloc grad: {err}"),
2019 })?;
2020 let mut d_hess =
2021 stream
2022 .alloc_zeros::<f64>(n * r * r)
2023 .map_err(|err| GpuError::DriverCallFailed {
2024 reason: format!("bms_flex_row device-resident alloc hess: {err}"),
2025 })?;
2026
2027 let func = backend
2028 .module
2029 .load_function("bms_flex_row_kernel")
2030 .map_err(|err| GpuError::DriverCallFailed {
2031 reason: format!("bms_flex_row device-resident load_function: {err}"),
2032 })?;
2033
2034 let cfg = LaunchConfig {
2035 grid_dim: (n as u32, 1, 1),
2036 block_dim: (ROW_KERNEL_THREADS, 1, 1),
2037 shared_mem_bytes: 0,
2038 };
2039 let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
2040 reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
2041 })?;
2042 let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
2043 reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
2044 })?;
2045 let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
2046 reason: format!(
2047 "bms_flex_row device-resident: p_h={} exceeds i32 range",
2048 inputs.p_h
2049 ),
2050 })?;
2051 let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
2052 reason: format!(
2053 "bms_flex_row device-resident: p_w={} exceeds i32 range",
2054 inputs.p_w
2055 ),
2056 })?;
2057 let s_f_val = inputs.s_f;
2058
2059 let mut builder = stream.launch_builder(&func);
2060 builder
2061 .arg(&n_i32)
2062 .arg(&r_i32)
2063 .arg(&p_h_i32)
2064 .arg(&p_w_i32)
2065 .arg(&s_f_val)
2066 .arg(&d_q)
2067 .arg(&d_b)
2068 .arg(&d_mu1)
2069 .arg(&d_mu2)
2070 .arg(&d_zobs)
2071 .arg(&d_y)
2072 .arg(&d_w)
2073 .arg(&d_offsets)
2074 .arg(&d_c0)
2075 .arg(&d_c1)
2076 .arg(&d_c2)
2077 .arg(&d_c3)
2078 .arg(&d_a)
2079 .arg(&d_aa)
2080 .arg(&d_r)
2081 .arg(&d_ar)
2082 .arg(&d_sbb)
2083 .arg(&d_sbh)
2084 .arg(&d_sbw)
2085 .arg(d_moments_ref)
2086 .arg(&d_chi)
2087 .arg(&d_xi)
2088 .arg(&d_rho)
2089 .arg(&d_tau)
2090 .arg(&d_ruv)
2091 .arg(&d_e_obs)
2092 .arg(&mut d_neglog)
2093 .arg(&mut d_grad)
2094 .arg(&mut d_hess);
2095 unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
2100 reason: format!("bms_flex_row device-resident launch: {err}"),
2101 })?;
2102 stream
2103 .synchronize()
2104 .map_err(|err| GpuError::DriverCallFailed {
2105 reason: format!("bms_flex_row device-resident synchronize: {err}"),
2106 })?;
2107
2108 drop(d_neglog);
2116 drop(d_grad);
2117 drop(d_q);
2119 drop(d_b);
2120 drop(d_mu1);
2121 drop(d_mu2);
2122 drop(d_zobs);
2123 drop(d_y);
2124 drop(d_w);
2125 drop(d_offsets);
2126 drop(d_c0);
2127 drop(d_c1);
2128 drop(d_c2);
2129 drop(d_c3);
2130 drop(d_a);
2131 drop(d_aa);
2132 drop(d_r);
2133 drop(d_ar);
2134 drop(d_sbb);
2135 drop(d_sbh);
2136 drop(d_sbw);
2137 drop(d_chi);
2141 drop(d_xi);
2142 drop(d_rho);
2143 drop(d_tau);
2144 drop(d_ruv);
2145
2146 let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
2147 * std::mem::size_of::<f64>()) as u64;
2148 Ok(DeviceResidentRowHess {
2149 hess: d_hess,
2150 marginal_design: d_marginal,
2151 logslope_design: d_logslope,
2152 n,
2153 r,
2154 block,
2155 primary,
2156 bytes,
2157 })
2158}
2159
/// Selects which per-chunk partial kernel `run_bms_flex_row_partial_reduce`
/// launches over the device-resident row Hessians.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Hessian-vector product via `bms_flex_row_hvp_partial`; the launch
    /// carries a direction vector.
    HvpDeviceOut,
    /// Diagonal of the joint Hessian via `bms_flex_row_diag_partial`; the
    /// launch carries no direction vector.
    DiagonalHostOut,
}

#[cfg(target_os = "linux")]
impl BmsFlexRowLaunchMode {
    /// NVRTC entry-point name of this mode's per-chunk partial kernel.
    pub(crate) fn partial_kernel_name(self) -> &'static str {
        match self {
            Self::HvpDeviceOut => "bms_flex_row_hvp_partial",
            Self::DiagonalHostOut => "bms_flex_row_diag_partial",
        }
    }
}
2183
2184#[cfg(target_os = "linux")]
2190pub(crate) struct PreparedBmsFlexRowLaunchArgs {
2191 pub(crate) n_i32: i32,
2192 pub(crate) r_i32: i32,
2193 pub(crate) p_m_i32: i32,
2194 pub(crate) p_g_i32: i32,
2195 pub(crate) p_total_i32: i32,
2196 pub(crate) h_block_start: i32,
2197 pub(crate) h_block_len: i32,
2198 pub(crate) w_block_start: i32,
2199 pub(crate) w_block_len: i32,
2200 pub(crate) h_primary_start: i32,
2201 pub(crate) w_primary_start: i32,
2202 pub(crate) rows_per_cta: i32,
2203 pub(crate) num_chunks: usize,
2204}
2205
2206#[cfg(target_os = "linux")]
2207impl PreparedBmsFlexRowLaunchArgs {
2208 pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2209 let p_total = storage.block.p_total;
2210 let num_chunks = num_hvp_chunks(storage.n);
2211 PreparedBmsFlexRowLaunchArgs {
2212 n_i32: storage.n as i32,
2213 r_i32: storage.r as i32,
2214 p_m_i32: storage.block.p_m as i32,
2215 p_g_i32: storage.block.p_g as i32,
2216 p_total_i32: p_total as i32,
2217 h_block_start: storage
2218 .block
2219 .h
2220 .as_ref()
2221 .map(|r| r.start as i32)
2222 .unwrap_or(0),
2223 h_block_len: storage
2224 .block
2225 .h
2226 .as_ref()
2227 .map(|r| r.len() as i32)
2228 .unwrap_or(0),
2229 w_block_start: storage
2230 .block
2231 .w
2232 .as_ref()
2233 .map(|r| r.start as i32)
2234 .unwrap_or(0),
2235 w_block_len: storage
2236 .block
2237 .w
2238 .as_ref()
2239 .map(|r| r.len() as i32)
2240 .unwrap_or(0),
2241 h_primary_start: storage
2242 .primary
2243 .h
2244 .as_ref()
2245 .map(|r| r.start as i32)
2246 .unwrap_or(0),
2247 w_primary_start: storage
2248 .primary
2249 .w
2250 .as_ref()
2251 .map(|r| r.start as i32)
2252 .unwrap_or(0),
2253 rows_per_cta: HVP_ROWS_PER_CTA as i32,
2254 num_chunks,
2255 }
2256 }
2257}
2258
2259#[cfg(target_os = "linux")]
2273pub(crate) fn run_bms_flex_row_partial_reduce(
2274 storage: &DeviceResidentRowHess,
2275 mode: BmsFlexRowLaunchMode,
2276 d_v: Option<&CudaSlice<f64>>,
2277 d_out: &mut CudaSlice<f64>,
2278 ctx: &str,
2279) -> Result<(), GpuError> {
2280 let backend = HvpKernelBackend::probe()?;
2281 let stream = backend.stream.clone();
2282 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2283 let p_total = storage.block.p_total;
2284
2285 let mut d_partial = stream
2286 .alloc_zeros::<f64>(args.num_chunks * p_total)
2287 .map_err(|err| GpuError::DriverCallFailed {
2288 reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2289 })?;
2290
2291 let partial_kernel_name = mode.partial_kernel_name();
2292 let part_func = backend
2293 .module
2294 .load_function(partial_kernel_name)
2295 .map_err(|err| GpuError::DriverCallFailed {
2296 reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2297 })?;
2298 let red_func = backend
2299 .module
2300 .load_function("bms_flex_row_hvp_reduce")
2301 .map_err(|err| GpuError::DriverCallFailed {
2302 reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2303 })?;
2304
2305 let cfg_part = LaunchConfig {
2306 grid_dim: (args.num_chunks as u32, 1, 1),
2307 block_dim: (HVP_THREADS, 1, 1),
2308 shared_mem_bytes: 0,
2309 };
2310 let mut builder = stream.launch_builder(&part_func);
2311 builder
2312 .arg(&args.n_i32)
2313 .arg(&args.r_i32)
2314 .arg(&args.p_m_i32)
2315 .arg(&args.p_g_i32)
2316 .arg(&args.p_total_i32)
2317 .arg(&args.h_block_start)
2318 .arg(&args.h_block_len)
2319 .arg(&args.w_block_start)
2320 .arg(&args.w_block_len)
2321 .arg(&args.h_primary_start)
2322 .arg(&args.w_primary_start)
2323 .arg(&args.rows_per_cta)
2324 .arg(&storage.hess)
2325 .arg(&storage.marginal_design)
2326 .arg(&storage.logslope_design);
2327 if let Some(d_v) = d_v {
2328 builder.arg(d_v);
2329 }
2330 builder.arg(&mut d_partial);
2331 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2339 reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2340 })?;
2341
2342 let red_threads: u32 = REDUCTION_THREADS;
2343 let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2344 let cfg_red = LaunchConfig {
2345 grid_dim: (red_blocks, 1, 1),
2346 block_dim: (red_threads, 1, 1),
2347 shared_mem_bytes: 0,
2348 };
2349 let num_chunks_i32 = args.num_chunks as i32;
2350 let mut builder = stream.launch_builder(&red_func);
2351 builder
2352 .arg(&num_chunks_i32)
2353 .arg(&args.p_total_i32)
2354 .arg(&d_partial)
2355 .arg(d_out);
2356 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2360 reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2361 })?;
2362 drop(d_partial);
2365 Ok(())
2366}
2367
2368#[cfg(target_os = "linux")]
2378pub(crate) fn launch_bms_flex_row_host(
2379 storage: &DeviceResidentRowHess,
2380 mode: BmsFlexRowLaunchMode,
2381 v: Option<&[f64]>,
2382 ctx: &str,
2383) -> Result<Vec<f64>, GpuError> {
2384 let p_total = storage.block.p_total;
2385 if let Some(v) = v {
2386 if v.len() != p_total {
2387 return Err(GpuError::DriverCallFailed {
2388 reason: format!(
2389 "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2390 v.len()
2391 ),
2392 });
2393 }
2394 }
2395
2396 let backend = HvpKernelBackend::probe()?;
2397 let stream = backend.stream.clone();
2398
2399 let d_v = match v {
2400 Some(v) => Some(
2401 stream
2402 .clone_htod(v)
2403 .map_err(|err| GpuError::DriverCallFailed {
2404 reason: format!("bms_flex_row {ctx} upload v: {err}"),
2405 })?,
2406 ),
2407 None => None,
2408 };
2409 let mut d_out =
2410 stream
2411 .alloc_zeros::<f64>(p_total)
2412 .map_err(|err| GpuError::DriverCallFailed {
2413 reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2414 })?;
2415
2416 run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2417
2418 stream
2419 .synchronize()
2420 .map_err(|err| GpuError::DriverCallFailed {
2421 reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2422 })?;
2423 stream
2424 .clone_dtoh(&d_out)
2425 .map_err(|err| GpuError::DriverCallFailed {
2426 reason: format!("bms_flex_row {ctx} download out: {err}"),
2427 })
2428}
2429
2430#[cfg(target_os = "linux")]
2431pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2432 storage: &DeviceResidentRowHess,
2433 rhs_count: usize,
2434 v_rhs_len: usize,
2435 out_len: Option<usize>,
2436 ctx: &str,
2437) -> Result<usize, GpuError> {
2438 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2439 return Err(GpuError::DriverCallFailed {
2440 reason: format!(
2441 "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2442 ),
2443 });
2444 }
2445 let p_total = storage.block.p_total;
2446 let rhs_elems = rhs_count
2447 .checked_mul(p_total)
2448 .ok_or_else(|| GpuError::DriverCallFailed {
2449 reason: format!(
2450 "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2451 ),
2452 })?;
2453 if v_rhs_len != rhs_elems {
2454 return Err(GpuError::DriverCallFailed {
2455 reason: format!(
2456 "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2457 ),
2458 });
2459 }
2460 if let Some(out_len) = out_len
2461 && out_len != rhs_elems
2462 {
2463 return Err(GpuError::DriverCallFailed {
2464 reason: format!(
2465 "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2466 ),
2467 });
2468 }
2469 Ok(rhs_elems)
2470}
2471
2472#[cfg(target_os = "linux")]
2476pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2477 n: usize,
2478 p_total: usize,
2479 rhs_count: usize,
2480) -> Result<u64, GpuError> {
2481 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2482 return Err(GpuError::DriverCallFailed {
2483 reason: format!(
2484 "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2485 ),
2486 });
2487 }
2488 let num_chunks = num_hvp_chunks(n);
2489 let partial = rhs_count
2490 .checked_mul(num_chunks)
2491 .and_then(|v| v.checked_mul(p_total))
2492 .ok_or_else(|| GpuError::DriverCallFailed {
2493 reason: format!(
2494 "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2495 ),
2496 })?;
2497 let rhs_vectors = rhs_count
2498 .checked_mul(p_total)
2499 .and_then(|v| v.checked_mul(2))
2500 .ok_or_else(|| GpuError::DriverCallFailed {
2501 reason: format!(
2502 "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2503 ),
2504 })?;
2505 let elems = partial
2506 .checked_add(rhs_vectors)
2507 .ok_or_else(|| GpuError::DriverCallFailed {
2508 reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2509 })?;
2510 Ok((elems * std::mem::size_of::<f64>()) as u64)
2511}
2512
2513#[cfg(target_os = "linux")]
2514pub(crate) fn run_bms_flex_row_multi_partial_reduce(
2515 storage: &DeviceResidentRowHess,
2516 rhs_count: usize,
2517 d_v_rhs: &CudaSlice<f64>,
2518 d_out: &mut CudaSlice<f64>,
2519 ctx: &str,
2520) -> Result<(), GpuError> {
2521 let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
2522 storage,
2523 rhs_count,
2524 d_v_rhs.len(),
2525 Some(d_out.len()),
2526 ctx,
2527 )?;
2528 let backend = HvpKernelBackend::probe()?;
2529 let stream = backend.stream.clone();
2530 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2531 let p_total = storage.block.p_total;
2532 let partial_len = rhs_count
2533 .checked_mul(args.num_chunks)
2534 .and_then(|v| v.checked_mul(p_total))
2535 .ok_or_else(|| GpuError::DriverCallFailed {
2536 reason: format!(
2537 "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
2538 args.num_chunks
2539 ),
2540 })?;
2541
2542 let mut d_partial =
2543 stream
2544 .alloc_zeros::<f64>(partial_len)
2545 .map_err(|err| GpuError::DriverCallFailed {
2546 reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
2547 })?;
2548 let part_func = backend
2549 .module
2550 .load_function("bms_flex_row_hvp_multi_partial")
2551 .map_err(|err| GpuError::DriverCallFailed {
2552 reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
2553 })?;
2554 let red_func = backend
2555 .module
2556 .load_function("bms_flex_row_hvp_multi_reduce")
2557 .map_err(|err| GpuError::DriverCallFailed {
2558 reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
2559 })?;
2560
2561 let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
2562 reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
2563 })?;
2564 let cfg_part = LaunchConfig {
2565 grid_dim: (args.num_chunks as u32, 1, 1),
2566 block_dim: (HVP_THREADS, 1, 1),
2567 shared_mem_bytes: 0,
2568 };
2569 let mut builder = stream.launch_builder(&part_func);
2570 builder
2571 .arg(&args.n_i32)
2572 .arg(&args.r_i32)
2573 .arg(&args.p_m_i32)
2574 .arg(&args.p_g_i32)
2575 .arg(&args.p_total_i32)
2576 .arg(&args.h_block_start)
2577 .arg(&args.h_block_len)
2578 .arg(&args.w_block_start)
2579 .arg(&args.w_block_len)
2580 .arg(&args.h_primary_start)
2581 .arg(&args.w_primary_start)
2582 .arg(&args.rows_per_cta)
2583 .arg(&rhs_count_i32)
2584 .arg(&storage.hess)
2585 .arg(&storage.marginal_design)
2586 .arg(&storage.logslope_design)
2587 .arg(d_v_rhs)
2588 .arg(&mut d_partial);
2589 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2594 reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
2595 })?;
2596
2597 let red_threads: u32 = REDUCTION_THREADS;
2598 let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
2599 let cfg_red = LaunchConfig {
2600 grid_dim: (red_blocks, 1, 1),
2601 block_dim: (red_threads, 1, 1),
2602 shared_mem_bytes: 0,
2603 };
2604 let num_chunks_i32 = args.num_chunks as i32;
2605 let mut builder = stream.launch_builder(&red_func);
2606 builder
2607 .arg(&num_chunks_i32)
2608 .arg(&args.p_total_i32)
2609 .arg(&rhs_count_i32)
2610 .arg(&d_partial)
2611 .arg(d_out);
2612 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2615 reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
2616 })?;
2617 drop(d_partial);
2618 Ok(())
2619}
2620
2621#[cfg(target_os = "linux")]
2624pub(crate) fn launch_bms_flex_row_hvp_multi(
2625 storage: &DeviceResidentRowHess,
2626 v_rhs: &[f64],
2627 rhs_count: usize,
2628) -> Result<Vec<f64>, GpuError> {
2629 let rhs_elems =
2630 validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2631 let backend = HvpKernelBackend::probe()?;
2632 let stream = backend.stream.clone();
2633 let d_v_rhs = stream
2634 .clone_htod(v_rhs)
2635 .map_err(|err| GpuError::DriverCallFailed {
2636 reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2637 })?;
2638 let mut d_out =
2639 stream
2640 .alloc_zeros::<f64>(rhs_elems)
2641 .map_err(|err| GpuError::DriverCallFailed {
2642 reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2643 })?;
2644 run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2645 stream
2646 .synchronize()
2647 .map_err(|err| GpuError::DriverCallFailed {
2648 reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2649 })?;
2650 stream
2651 .clone_dtoh(&d_out)
2652 .map_err(|err| GpuError::DriverCallFailed {
2653 reason: format!("bms_flex_row hvp_multi download out: {err}"),
2654 })
2655}
2656
2657#[cfg(target_os = "linux")]
2668pub(crate) fn launch_bms_flex_row_hvp_into_device(
2669 storage: &DeviceResidentRowHess,
2670 d_v: &CudaSlice<f64>,
2671 d_out: &mut CudaSlice<f64>,
2672) -> Result<(), GpuError> {
2673 let p_total = storage.block.p_total;
2674 if d_v.len() != p_total {
2675 return Err(GpuError::DriverCallFailed {
2676 reason: format!(
2677 "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2678 d_v.len(),
2679 p_total
2680 ),
2681 });
2682 }
2683 if d_out.len() != p_total {
2684 return Err(GpuError::DriverCallFailed {
2685 reason: format!(
2686 "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2687 d_out.len(),
2688 p_total
2689 ),
2690 });
2691 }
2692 run_bms_flex_row_partial_reduce(
2696 storage,
2697 BmsFlexRowLaunchMode::HvpDeviceOut,
2698 Some(d_v),
2699 d_out,
2700 "hvp_into_device",
2701 )
2702}
2703
2704#[cfg(target_os = "linux")]
2707pub(crate) fn launch_bms_flex_row_hvp(
2708 storage: &DeviceResidentRowHess,
2709 v: &[f64],
2710) -> Result<Vec<f64>, GpuError> {
2711 launch_bms_flex_row_hvp_multi(storage, v, 1)
2712}
2713
2714#[cfg(target_os = "linux")]
2717pub(crate) fn launch_bms_flex_row_diagonal(
2718 storage: &DeviceResidentRowHess,
2719) -> Result<Vec<f64>, GpuError> {
2720 launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2721}
2722
2723#[cfg(target_os = "linux")]
/// Largest `p_total` accepted by `launch_bms_flex_row_dense_block`.
///
/// That launcher requests `p_total² * size_of::<f64>()` bytes of dynamic shared
/// memory per partial CTA. At 72 this is 41,472 bytes, under the 48 KiB
/// per-block limit cited in its rejection message.
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;
2730
2731#[cfg(target_os = "linux")]
/// Row-chunk size for `bms_flex_row_dense_block_partial`. The partial grid has
/// `n.div_ceil(DENSE_BLOCK_ROWS_PER_CTA)` CTAs, and this value is also passed to
/// the kernel as its rows-per-CTA argument.
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2738
2739#[cfg(target_os = "linux")]
2756pub fn launch_bms_flex_row_dense_block(
2757 storage: &DeviceResidentRowHess,
2758) -> Result<Vec<f64>, GpuError> {
2759 let p_total = storage.block.p_total;
2760 if p_total == 0 {
2761 return Err(GpuError::DriverCallFailed {
2762 reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2763 });
2764 }
2765 if p_total > DENSE_BLOCK_MAX_P {
2766 return Err(GpuError::DriverCallFailed {
2767 reason: format!(
2768 "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2769 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2770 ),
2771 });
2772 }
2773 let backend = HvpKernelBackend::probe()?;
2774 let stream = backend.stream.clone();
2775 let n = storage.n;
2776 let r = storage.r;
2777 let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2778 let num_chunks = n.div_ceil(rows_per_cta);
2779 let pp = p_total * p_total;
2780
2781 let mut d_partial =
2782 stream
2783 .alloc_zeros::<f64>(num_chunks * pp)
2784 .map_err(|err| GpuError::DriverCallFailed {
2785 reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2786 })?;
2787 let mut d_out = stream
2788 .alloc_zeros::<f64>(pp)
2789 .map_err(|err| GpuError::DriverCallFailed {
2790 reason: format!("bms_flex_row dense_block alloc out: {err}"),
2791 })?;
2792
2793 let part_func = backend
2794 .module
2795 .load_function("bms_flex_row_dense_block_partial")
2796 .map_err(|err| GpuError::DriverCallFailed {
2797 reason: format!("bms_flex_row dense_block load partial: {err}"),
2798 })?;
2799 let red_func = backend
2800 .module
2801 .load_function("bms_flex_row_dense_block_reduce")
2802 .map_err(|err| GpuError::DriverCallFailed {
2803 reason: format!("bms_flex_row dense_block load reduce: {err}"),
2804 })?;
2805
2806 let n_i32 = n as i32;
2807 let r_i32 = r as i32;
2808 let p_m_i32 = storage.block.p_m as i32;
2809 let p_g_i32 = storage.block.p_g as i32;
2810 let p_total_i32 = p_total as i32;
2811 let h_block_start = storage
2812 .block
2813 .h
2814 .as_ref()
2815 .map(|r| r.start as i32)
2816 .unwrap_or(0);
2817 let h_block_len = storage
2818 .block
2819 .h
2820 .as_ref()
2821 .map(|r| r.len() as i32)
2822 .unwrap_or(0);
2823 let w_block_start = storage
2824 .block
2825 .w
2826 .as_ref()
2827 .map(|r| r.start as i32)
2828 .unwrap_or(0);
2829 let w_block_len = storage
2830 .block
2831 .w
2832 .as_ref()
2833 .map(|r| r.len() as i32)
2834 .unwrap_or(0);
2835 let h_primary_start = storage
2836 .primary
2837 .h
2838 .as_ref()
2839 .map(|r| r.start as i32)
2840 .unwrap_or(0);
2841 let w_primary_start = storage
2842 .primary
2843 .w
2844 .as_ref()
2845 .map(|r| r.start as i32)
2846 .unwrap_or(0);
2847 let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2848 let num_chunks_u32 = num_chunks as u32;
2849
2850 let shmem_bytes: u32 =
2852 u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2853 reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2854 })?;
2855
2856 let cfg_part = LaunchConfig {
2857 grid_dim: (num_chunks_u32, 1, 1),
2858 block_dim: (HVP_THREADS, 1, 1),
2859 shared_mem_bytes: shmem_bytes,
2860 };
2861 let mut builder = stream.launch_builder(&part_func);
2862 builder
2863 .arg(&n_i32)
2864 .arg(&r_i32)
2865 .arg(&p_m_i32)
2866 .arg(&p_g_i32)
2867 .arg(&p_total_i32)
2868 .arg(&h_block_start)
2869 .arg(&h_block_len)
2870 .arg(&w_block_start)
2871 .arg(&w_block_len)
2872 .arg(&h_primary_start)
2873 .arg(&w_primary_start)
2874 .arg(&rows_per_cta_i32)
2875 .arg(&storage.hess)
2876 .arg(&storage.marginal_design)
2877 .arg(&storage.logslope_design)
2878 .arg(&mut d_partial);
2879 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2883 reason: format!("bms_flex_row dense_block partial launch: {err}"),
2884 })?;
2885
2886 let red_threads: u32 = REDUCTION_THREADS;
2887 let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2888 let cfg_red = LaunchConfig {
2889 grid_dim: (red_blocks, 1, 1),
2890 block_dim: (red_threads, 1, 1),
2891 shared_mem_bytes: 0,
2892 };
2893 let num_chunks_i32 = num_chunks as i32;
2894 let mut builder = stream.launch_builder(&red_func);
2895 builder
2896 .arg(&num_chunks_i32)
2897 .arg(&p_total_i32)
2898 .arg(&d_partial)
2899 .arg(&mut d_out);
2900 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2902 reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2903 })?;
2904 stream
2905 .synchronize()
2906 .map_err(|err| GpuError::DriverCallFailed {
2907 reason: format!("bms_flex_row dense_block sync: {err}"),
2908 })?;
2909 stream
2910 .clone_dtoh(&d_out)
2911 .map_err(|err| GpuError::DriverCallFailed {
2912 reason: format!("bms_flex_row dense_block download: {err}"),
2913 })
2914}
2915
2916#[cfg(test)]
2928mod oracle_parity_tests {
2929 use super::*;
2930
    /// `1 / (2π)`: the normaliser applied to the cell-moment contractions in
    /// `cpu_oracle_outputs`.
    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
    /// `√2`: maps a standard-normal argument `x` onto `erfc(-x / √2)`.
    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// `1 / √(2π)`: the standard-normal density normaliser.
    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
2948
2949 pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
2950 if !x.is_finite() {
2951 return if x > 0.0 { 0.0 } else { f64::INFINITY };
2952 }
2953 if x <= 0.0 {
2954 return 1.0;
2955 }
2956 if x < 26.0 {
2957 let mut xx = x * x;
2958 if xx > 700.0 {
2959 xx = 700.0;
2960 }
2961 return xx.exp() * gam_gpu::numerics_host::erfc(x);
2962 }
2963 let inv = 1.0 / x;
2964 let inv2 = inv * inv;
2965 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
2966 + 6.5625 * inv2 * inv2 * inv2 * inv2;
2967 let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
2968 inv * poly * inv_sqrt_pi
2969 }
2970
2971 pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
2972 if x == f64::INFINITY {
2973 return (0.0, 0.0);
2974 }
2975 if x == f64::NEG_INFINITY {
2976 return (f64::NEG_INFINITY, f64::INFINITY);
2977 }
2978 if x.is_nan() {
2979 return (x, x);
2980 }
2981 const ORACLE_LEFT_TAIL_X: f64 = -37.0;
2994 if x >= ORACLE_LEFT_TAIL_X {
2995 let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
2996 if cdf < 1e-300 {
2997 cdf = 1e-300;
2998 }
2999 if cdf > 1.0 {
3000 cdf = 1.0;
3001 }
3002 let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
3003 (cdf.ln(), pdf / cdf)
3004 } else {
3005 let u = -x / ORACLE_SQRT_2;
3006 let mut ex = oracle_erfcx_nonnegative(u);
3007 if ex < 1e-300 {
3008 ex = 1e-300;
3009 }
3010 let log_cdf = -u * u + (0.5 * ex).ln();
3011 let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
3012 (log_cdf, sqrt_2_over_pi / ex)
3013 }
3014 }
3015
    /// Host-only reference implementation of the flex row kernel.
    ///
    /// For each of `inputs.n_rows` rows it produces the negative log-likelihood
    /// (`neglog[row]`), the `r`-length gradient (`grad[row * r + u]`) and the
    /// symmetric `r × r` Hessian (`hess[row * r * r + u * r + v]`). The r
    /// coordinates are laid out as `0`, `1`, then `p_h` h-entries starting at 2,
    /// then `p_w` w-entries starting at `2 + p_h`; this matches the `cell_sbh`
    /// and `cell_sbw` indexing below.
    ///
    /// A row whose accumulated `f_a` is non-finite or `<= 0` is emitted as NaN in
    /// all three outputs.
    ///
    /// # Panics
    /// Panics if `inputs.cell_moments` is device-resident, or if any input
    /// buffer is shorter than the offsets and strides used here.
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // Per-row accumulators summed over the row's cells: f_u / f_au are
            // r-vectors, f_uv is r×r (only its upper triangle is filled), and
            // f_a / f_aa are scalars.
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            // Row `row` owns cells cell_offsets[row]..cell_offsets[row + 1].
            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                // MOMENT_STRIDE (= 10) moments per cell, indices 0..=9.
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // t[k] = (1/2π) Σ_e c_e · m[e + k] for k = 0..6; the largest moment
                // index read is 3 + 6 = 9.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // d_of(ρ) = (1/2π) Σ_{k<4} ρ_k m_k — linear contraction of a
                // 4-coefficient vector against the first four moments.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // q_of(ρ, σ) = Σ_{i,j<4} ρ_i σ_j t[i + j] — the product polynomial's
                // coefficients contracted against t, grouped by total degree i + j.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                // cell_a / cell_aa: one 4-coefficient group per cell.
                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // cell_r / cell_ar: (r - 1) groups of 4 per cell, for coordinates
                // u = 1..r (coordinate 0 is handled after the cell loop).
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                // Upper triangle of f_uv for u, v >= 1. A direct second-derivative
                // term exists only in row u == 1: (1,1) from cell_sbb, (1,h) from
                // cell_sbh (p_h groups per cell), (1,w) from cell_sbw (p_w groups).
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // Coordinate 0 does not come from the cells: its first/second
            // derivatives are the per-row mu_1 / mu_2, and all mixed (0, v)
            // entries are zero.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // f_a is used as a divisor below; a non-positive or non-finite value
            // poisons the whole row with NaN instead of dividing by it.
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            let inv_fa = 1.0 / f_a;

            // a_u = -f_u / f_a (a_u[0] = mu_1 / f_a is the same expression, since
            // f_u[0] = -mu_1). Second derivatives:
            //   a_uv = -(f_uv + f_au[v]·a_u + f_au[u]·a_v + f_aa·a_u·a_v) / f_a,
            // which has the form of implicit differentiation of F(a, θ) = 0 with
            // ∂F/∂a = f_a.
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Chain a_u / a_uv into the first and second derivatives of the
            // observed margin:
            //   bar_e_u  = chi·a_u + rho_u
            //   bar_e_uv = chi·a_uv + xi·a_u·a_v + tau_u·a_v + a_u·tau_v + r_uv
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Weighted probit term -w·ln Φ(s·e_obs) with s = 2y - 1 (assumes
            // y ∈ {0, 1} so s = ±1 — confirm against the packer).
            // a_i / b_i are its first / second derivatives w.r.t. e_obs.
            let y = inputs.y[row];
            let w = inputs.w[row];
            let s = 2.0 * y - 1.0;
            let e_obs = inputs.e_obs[row];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3230
3231 mod parity_415 {
3241 use super::cpu_oracle_outputs;
3258 use crate::bms::family::*;
3259 use crate::bms::hessian_paths::*;
3260 use crate::bms::{DeviationBlockConfig, LatentMeasureKind, exact_kernel};
3261 use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix};
3262 use gam_problem::{InverseLink, ParameterBlockState, StandardLink};
3263 use ndarray::{Array1, Array2};
3264 use std::sync::{Arc, Mutex};
3265
    /// Deterministic full-flex fixture for the #415 parity test.
    ///
    /// Builds a `BernoulliMarginalSlopeFamily` with `n` rows, a probit base link,
    /// a standard-normal latent measure with frailty sd 0.15, and both
    /// deviation blocks (score warp and link, each with 3 internal knots).
    /// Returns it together with four parameter-block states in the order
    /// marginal, log-slope, h (score warp), w (link deviation).
    fn make_flex_parity_family(
        n: usize,
    ) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
        // Knot seeds use at least 6 points even for very small n.
        let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
        let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
        let cfg = DeviationBlockConfig {
            num_internal_knots: 3,
            ..DeviationBlockConfig::default()
        };
        let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
            .expect("build score warp block");
        let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
            &link_seed, &link_seed, &cfg,
        )
        .expect("build link deviation block");

        // Pseudo-random but reproducible labels: y = 1 when (17i + 3) mod 7 >= 4.
        let y: Array1<f64> =
            Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
        // Weights cycle through 0.75, 0.80, …, 0.95.
        let weights: Array1<f64> =
            Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
        // z: evenly spaced cell midpoints across [-1.7, 1.7].
        let z: Array1<f64> =
            Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
        // Both designs are n×2: an intercept column plus one permuted covariate.
        let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                1.0
            } else {
                -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
            }
        });
        let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                1.0
            } else {
                0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
            }
        });

        let family = BernoulliMarginalSlopeFamily {
            y: Arc::new(y),
            weights: Arc::new(weights),
            z: Arc::new(z.clone()),
            latent_measure: LatentMeasureKind::StandardNormal,
            gaussian_frailty_sd: Some(0.15),
            base_link: InverseLink::Standard(StandardLink::Probit),
            marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
            logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
            score_warp: Some(score_prepared.runtime.clone()),
            link_dev: Some(link_prepared.runtime.clone()),
            policy: gam_runtime::resource::ResourcePolicy::default_library(),
            cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
            cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
            intercept_warm_starts: None,
            auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            auto_subsample_last_rho: Arc::new(Mutex::new(None)),
        };

        // Small non-zero coefficients so every block contributes. The h/w blocks
        // get a linear ramp sized to each deviation basis, and their eta is zero.
        let beta_m = Array1::from_vec(vec![0.12, -0.04]);
        let beta_g = Array1::from_vec(vec![0.35, 0.03]);
        let beta_h = Array1::from_iter(
            (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
        );
        let beta_w = Array1::from_iter(
            (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
        );
        let states = vec![
            ParameterBlockState {
                eta: marginal_x.dot(&beta_m),
                beta: beta_m,
            },
            ParameterBlockState {
                eta: logslope_x.dot(&beta_g),
                beta: beta_g,
            },
            ParameterBlockState {
                beta: beta_h,
                eta: Array1::zeros(z.len()),
            },
            ParameterBlockState {
                beta: beta_w,
                eta: Array1::zeros(z.len()),
            },
        ];
        (family, states)
    }
3356
    /// Parity lock for issue #415: the host oracle (`cpu_oracle_outputs`, fed by
    /// the packed kernel inputs) must match the CPU family's
    /// `compute_row_analytic_flex_into_with_moments` on every row.
    ///
    /// It compares the value, gradient and full Hessian within
    /// `1e-9 + 1e-10·|oracle|`. It also checks that the fixture actually
    /// separates `e_obs` from the q-derivative and that it covers both labels.
    #[test]
    fn cpu_oracle_matches_cpu_family_row_analytic_flex_415() {
        let n = 12usize;
        let (family, states) = make_flex_parity_family(n);
        let cache = family
            .build_exact_eval_cache(&states)
            .expect("flex exact eval cache");

        assert!(
            cache.row_cell_moments.is_some(),
            "#415 fixture must materialise the row-cell-moments bundle; the pack \
             and both compared paths read it"
        );
        // Primary layout must be full-flex: r = 2 + p_h + p_w with both blocks present.
        let primary = &cache.primary;
        let r = primary.total;
        let p_h = primary.h.as_ref().map(|range| range.len()).unwrap_or(0);
        let p_w = primary.w.as_ref().map(|range| range.len()).unwrap_or(0);
        assert!(
            p_h > 0 && p_w > 0,
            "#415 fixture must be full-flex: p_h={p_h} p_w={p_w}"
        );
        assert_eq!(r, 2 + p_h + p_w, "#415 fixture primary layout");

        let owned = family
            .pack_bms_flex_row_kernel_inputs(&states, &cache)
            .expect("pack must not error")
            .expect("pack must succeed for the StandardNormal full-flex fixture");
        let inputs = owned.as_borrowed();
        let oracle = cpu_oracle_outputs(&inputs);
        assert_eq!(oracle.neglog.len(), n);
        assert_eq!(oracle.grad.len(), n * r);
        assert_eq!(oracle.hess.len(), n * r * r);

        // Guard against a degenerate fixture. Recover bar_e_u[0] from
        // grad[0] = a_i · bar_e_u[0] and require at least one row where it differs
        // from e_obs; otherwise confusing the two would go unnoticed.
        let mut separates_observed_value_from_q_derivative = false;
        for row in 0..n {
            let y = inputs.y[row];
            let w = inputs.w[row];
            let s = 2.0 * y - 1.0;
            let e_obs = inputs.e_obs[row];
            let (_, lambda) = super::oracle_log_ndtr_and_mills(s * e_obs);
            let a_i = -w * s * lambda;
            if a_i.abs() > 1e-12 {
                let recovered_bar_e_q = oracle.grad[row * r] / a_i;
                if (recovered_bar_e_q - e_obs).abs() > 1e-8 {
                    separates_observed_value_from_q_derivative = true;
                    break;
                }
            }
        }
        assert!(
            separates_observed_value_from_q_derivative,
            "#415 fixture must distinguish e_obs from bar_e_u[0]; otherwise \
             the observed-value Mills-margin regression is not exercised"
        );

        // Mixed tolerance: absolute floor plus a relative term on the oracle value.
        let tol_abs = 1e-9_f64;
        let tol_rel = 1e-10_f64;

        let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(r);
        let mut max_rel = 0.0_f64;
        let mut checked_labels = [false, false];

        for row in 0..n {
            let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
            let row_moments = cache
                .row_cell_moments
                .as_ref()
                .and_then(|bundle| bundle.row(row, 9));
            assert!(
                row_moments.is_some(),
                "row {row} must carry degree-9 cell moments (the oracle reads them)"
            );
            let label = family.y[row] as usize;
            if label < 2 {
                checked_labels[label] = true;
            }

            // The family writes grad/hess into `scratch` and returns the value.
            let value = family
                .compute_row_analytic_flex_into_with_moments(
                    row,
                    &states,
                    primary,
                    row_ctx,
                    row_moments,
                    cache.cell_family_forest.as_ref(),
                    true,
                    &mut scratch,
                )
                .expect("cpu family row analytic flex");

            // NaN rows must be NaN on both sides; their grad/hess are not compared.
            let o_val = oracle.neglog[row];
            if o_val.is_nan() || value.is_nan() {
                assert!(
                    o_val.is_nan() && value.is_nan(),
                    "row {row}: NaN parity broke — oracle={o_val} family={value}"
                );
                continue;
            }
            let vd = (o_val - value).abs();
            let vtol = tol_abs + tol_rel * o_val.abs();
            max_rel = max_rel.max(vd / o_val.abs().max(1.0));
            assert!(
                vd <= vtol,
                "row {row} value drift: oracle={o_val:.17e} family={value:.17e} \
                 |Δ|={vd:.3e} > tol={vtol:.3e}"
            );

            for u in 0..r {
                let o_g = oracle.grad[row * r + u];
                let f_g = scratch.grad[u];
                let gd = (o_g - f_g).abs();
                let gtol = tol_abs + tol_rel * o_g.abs();
                max_rel = max_rel.max(gd / o_g.abs().max(1.0));
                assert!(
                    gd <= gtol,
                    "row {row} grad[{u}] drift: oracle={o_g:.17e} family={f_g:.17e} \
                     |Δ|={gd:.3e} > tol={gtol:.3e}"
                );
            }

            // Full r×r comparison (both triangles).
            for u in 0..r {
                for v in 0..r {
                    let o_h = oracle.hess[row * r * r + u * r + v];
                    let f_h = scratch.hess[[u, v]];
                    let hd = (o_h - f_h).abs();
                    let htol = tol_abs + tol_rel * o_h.abs();
                    max_rel = max_rel.max(hd / o_h.abs().max(1.0));
                    assert!(
                        hd <= htol,
                        "row {row} hess[{u},{v}] drift: oracle={o_h:.17e} \
                         family={f_h:.17e} |Δ|={hd:.3e} > tol={htol:.3e}"
                    );
                }
            }
        }

        assert!(
            checked_labels[0] && checked_labels[1],
            "#415 fixture must exercise both y=0 and y=1 rows: {checked_labels:?}"
        );
        eprintln!(
            "#415 parity lock: n={n} r={r} p_h={p_h} p_w={p_w} max_rel(oracle−family)={max_rel:.3e}"
        );
    }
3526 }
3527}
3528
3529#[cfg(all(test, target_os = "linux"))]
3530mod tests {
3531 use super::oracle_parity_tests::*;
3532 use super::*;
3533
3534 pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3535 BmsFlexRowKernelInputs {
3536 n_rows: 1,
3537 r: 4,
3538 p_h: 1,
3539 p_w: 1,
3540 q: &buffers.q,
3541 b: &buffers.b,
3542 mu_1: &buffers.mu_1,
3543 mu_2: &buffers.mu_2,
3544 z_obs: &buffers.z_obs,
3545 y: &buffers.y,
3546 w: &buffers.w,
3547 e_obs: &buffers.e_obs,
3548 s_f: 1.0,
3549 cell_offsets: &buffers.cell_offsets,
3550 cell_c0: &buffers.cell_c0,
3551 cell_c1: &buffers.cell_c1,
3552 cell_c2: &buffers.cell_c2,
3553 cell_c3: &buffers.cell_c3,
3554 cell_a: &buffers.cell_a,
3555 cell_aa: &buffers.cell_aa,
3556 cell_r: &buffers.cell_r,
3557 cell_ar: &buffers.cell_ar,
3558 cell_sbb: &buffers.cell_sbb,
3559 cell_sbh: &buffers.cell_sbh,
3560 cell_sbw: &buffers.cell_sbw,
3561 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3562 chi_obs: &buffers.chi_obs,
3563 xi_obs: &buffers.xi_obs,
3564 rho_u: &buffers.rho_u,
3565 tau_u: &buffers.tau_u,
3566 r_uv: &buffers.r_uv,
3567 }
3568 }
3569
3570 pub(crate) struct TestBuffers {
3571 pub(crate) q: Vec<f64>,
3572 pub(crate) b: Vec<f64>,
3573 pub(crate) mu_1: Vec<f64>,
3574 pub(crate) mu_2: Vec<f64>,
3575 pub(crate) z_obs: Vec<f64>,
3576 pub(crate) y: Vec<f64>,
3577 pub(crate) w: Vec<f64>,
3578 pub(crate) e_obs: Vec<f64>,
3579 pub(crate) cell_offsets: Vec<u32>,
3580 pub(crate) cell_c0: Vec<f64>,
3581 pub(crate) cell_c1: Vec<f64>,
3582 pub(crate) cell_c2: Vec<f64>,
3583 pub(crate) cell_c3: Vec<f64>,
3584 pub(crate) cell_a: Vec<f64>,
3585 pub(crate) cell_aa: Vec<f64>,
3586 pub(crate) cell_r: Vec<f64>,
3587 pub(crate) cell_ar: Vec<f64>,
3588 pub(crate) cell_sbb: Vec<f64>,
3589 pub(crate) cell_sbh: Vec<f64>,
3590 pub(crate) cell_sbw: Vec<f64>,
3591 pub(crate) cell_moments: Vec<f64>,
3592 pub(crate) chi_obs: Vec<f64>,
3593 pub(crate) xi_obs: Vec<f64>,
3594 pub(crate) rho_u: Vec<f64>,
3595 pub(crate) tau_u: Vec<f64>,
3596 pub(crate) r_uv: Vec<f64>,
3597 }
3598
3599 pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
3600 let cells = n_cells as usize;
3601 TestBuffers {
3602 q: vec![0.1; 1],
3603 b: vec![0.5; 1],
3604 mu_1: vec![0.3; 1],
3605 mu_2: vec![0.07; 1],
3606 z_obs: vec![0.0; 1],
3607 y: vec![1.0; 1],
3608 w: vec![1.0; 1],
3609 e_obs: vec![0.15; 1],
3610 cell_offsets: vec![0, n_cells],
3611 cell_c0: vec![0.2; cells],
3612 cell_c1: vec![-0.1; cells],
3613 cell_c2: vec![0.05; cells],
3614 cell_c3: vec![-0.02; cells],
3615 cell_a: vec![0.1; cells * 4],
3616 cell_aa: vec![0.0; cells * 4],
3617 cell_r: vec![0.05; cells * (r - 1) * 4],
3618 cell_ar: vec![0.0; cells * (r - 1) * 4],
3619 cell_sbb: vec![0.0; cells * 4],
3620 cell_sbh: vec![0.0; cells * p_h * 4],
3621 cell_sbw: vec![0.0; cells * p_w * 4],
3622 cell_moments: vec![1.0; cells * MOMENT_STRIDE],
3623 chi_obs: vec![1.0; 1],
3624 xi_obs: vec![0.0; 1],
3625 rho_u: vec![0.0; r],
3626 tau_u: vec![0.0; r],
3627 r_uv: vec![0.0; r * r],
3628 }
3629 }
3630
3631 #[test]
3632 pub(crate) fn validate_accepts_minimal_inputs() {
3633 let buffers = make_buffers(2, 4, 1, 1);
3634 let inputs = minimal_inputs(&buffers);
3635 assert!(inputs.validate().is_ok());
3636 }
3637
3638 #[test]
3639 pub(crate) fn validate_rejects_r_above_max() {
3640 let r = MAX_R + 1;
3641 let p_h = (r - 2) / 2;
3642 let p_w = (r - 2) - p_h;
3643 let buffers = make_buffers(1, r, p_h, p_w);
3644 let bad_inputs = BmsFlexRowKernelInputs {
3645 r,
3646 p_h,
3647 p_w,
3648 rho_u: &buffers.rho_u, tau_u: &buffers.tau_u,
3650 r_uv: &buffers.r_uv,
3651 cell_r: &buffers.cell_r,
3652 cell_ar: &buffers.cell_ar,
3653 cell_sbh: &buffers.cell_sbh,
3654 cell_sbw: &buffers.cell_sbw,
3655 ..minimal_inputs(&buffers)
3656 };
3657 let err = bad_inputs.validate().expect_err("r > MAX_R must fail");
3658 let msg = err.to_string();
3659 assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
3660 }
3661
3662 #[test]
3663 pub(crate) fn validate_rejects_mismatched_r_decomposition() {
3664 let buffers = make_buffers(1, 4, 1, 1);
3665 let bad_inputs = BmsFlexRowKernelInputs {
3666 r: 4,
3667 p_h: 1,
3668 p_w: 2, ..minimal_inputs(&buffers)
3670 };
3671 let err = bad_inputs
3672 .validate()
3673 .expect_err("inconsistent r vs p_h+p_w must fail");
3674 let msg = err.to_string();
3675 assert!(msg.contains("p_h"), "got: {msg}");
3676 assert!(msg.contains("p_w"), "got: {msg}");
3677 }
3678
3679 #[test]
3680 pub(crate) fn validate_rejects_non_monotone_offsets() {
3681 let mut buffers = make_buffers(2, 4, 1, 1);
3688 buffers.cell_offsets = vec![5, 2];
3689 let inputs = minimal_inputs(&buffers);
3690 let err = inputs
3691 .validate()
3692 .expect_err("non-monotone offsets must fail");
3693 let msg = err.to_string();
3694 assert!(msg.contains("monotone"), "got: {msg}");
3695 }
3696
3697 #[test]
3698 pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
3699 let mut buffers = make_buffers(2, 4, 1, 1);
3700 buffers.cell_moments.pop(); let inputs = minimal_inputs(&buffers);
3702 let err = inputs.validate().expect_err("short cell_moments must fail");
3703 let msg = err.to_string();
3704 assert!(msg.contains("cell_moments"), "got: {msg}");
3705 }
3706
3707 #[test]
3708 pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
3709 #[cfg(target_os = "linux")]
3713 {
3714 let buffers = make_buffers(1, 4, 1, 1);
3721 let inputs = minimal_inputs(&buffers);
3722 match launch_bms_flex_row_kernel(inputs) {
3723 Ok(_) => { }
3724 Err(GpuError::DriverLibraryUnavailable { .. })
3725 | Err(GpuError::DriverCallFailed { .. })
3726 | Err(GpuError::DriverSymbolMissing { .. })
3727 | Err(GpuError::NoDeviceKernel { .. }) => { }
3728 Err(other) => panic!("unexpected GpuError variant: {other:?}"),
3729 }
3730 }
3731 #[cfg(not(target_os = "linux"))]
3732 {
3733 let buffers = make_buffers(1, 4, 1, 1);
3734 let inputs = minimal_inputs(&buffers);
3735 match launch_bms_flex_row_kernel(inputs) {
3736 Err(GpuError::DriverLibraryUnavailable { reason }) => {
3737 assert!(
3738 reason.contains("Linux-only"),
3739 "expected Linux-only hint, got: {reason}"
3740 );
3741 }
3742 other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
3743 }
3744 }
3745 }
3746
3747 #[test]
3748 pub(crate) fn s_f_must_be_positive_and_finite() {
3749 let buffers = make_buffers(1, 4, 1, 1);
3750 let mut inputs = minimal_inputs(&buffers);
3751 inputs.s_f = 0.0;
3752 match launch_bms_flex_row_kernel(inputs) {
3753 Err(GpuError::DriverCallFailed { reason }) => {
3754 assert!(reason.contains("s_f"), "got: {reason}");
3755 }
3756 other => panic!("expected DriverCallFailed for s_f=0, got {other:?}"),
3757 }
3758 }
3759
    /// Build the deterministic 4-row, `r = 5` (`p_h = 2`, `p_w = 1`) fixture shared by the
    /// CPU-oracle and CUDA parity tests.
    ///
    /// Rows own 2, 3, 4 and 2 cells, so `cell_offsets` is `[0, 2, 5, 9, 11]` and there
    /// are 11 cells in total. Labels are `y = [1, 0, 1, 0]`, so both outcomes appear.
    /// Per-row scalars are affine in the row index. Cell and r-indexed arrays are
    /// drawn from the hash generator `f` (values in `[0.1, 0.5)`) with a distinct seed
    /// offset per array, then scaled. The exact values feed bit-sensitive parity
    /// checks, so keep the generation order and expressions unchanged.
    pub(crate) fn make_parity_buffers() -> TestBuffers {
        let n = 4_usize;
        let r = 5_usize;
        let p_h = 2_usize;
        let p_w = 1_usize;
        let row_cells: [u32; 4] = [2, 3, 4, 2];
        // Prefix-sum the per-row cell counts into row offsets.
        let mut cell_offsets = vec![0_u32; n + 1];
        for i in 0..n {
            cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
        }
        let total_cells = cell_offsets[n] as usize;

        // Deterministic pseudo-random draw: multiply by 2_654_435_761 (Knuth's
        // multiplicative-hash constant), keep the low 16 bits, rescale to [0, 1), and
        // map affinely into [0.1, 0.5).
        let f = |seed: usize| -> f64 {
            let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
            0.1 + 0.4 * x
        };

        // Per-row scalars, affine in the row index.
        let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
        let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
        let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
        let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
        let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
        // Alternating labels so both y = 1 and y = 0 rows are exercised.
        let y = [1.0, 0.0, 1.0, 0.0].to_vec();
        let w = vec![1.0; n];
        let e_obs = (0..n).map(|i| -0.3 + 0.2 * (i as f64)).collect::<Vec<_>>();

        // One value per cell for the c0..c3 coefficients.
        let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
        let cell_c1 = (0..total_cells)
            .map(|c| -f(c + 2002) * 0.5)
            .collect::<Vec<_>>();
        let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
        let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();

        // Four entries per cell, times (r - 1), p_h or p_w where the array is
        // additionally indexed by those dimensions (same sizing as `make_buffers`).
        let cell_a = (0..total_cells * 4)
            .map(|i| f(i + 5005) * 0.3)
            .collect::<Vec<_>>();
        let cell_aa = (0..total_cells * 4)
            .map(|i| f(i + 6006) * 0.1)
            .collect::<Vec<_>>();
        let cell_r = (0..total_cells * (r - 1) * 4)
            .map(|i| f(i + 7007) * 0.2)
            .collect::<Vec<_>>();
        let cell_ar = (0..total_cells * (r - 1) * 4)
            .map(|i| f(i + 8008) * 0.05)
            .collect::<Vec<_>>();
        let cell_sbb = (0..total_cells * 4)
            .map(|i| f(i + 9009) * 0.08)
            .collect::<Vec<_>>();
        let cell_sbh = (0..total_cells * p_h * 4)
            .map(|i| f(i + 10_010) * 0.07)
            .collect::<Vec<_>>();
        let cell_sbw = (0..total_cells * p_w * 4)
            .map(|i| f(i + 11_011) * 0.06)
            .collect::<Vec<_>>();
        // MOMENT_STRIDE moments per cell, all in [0.41, 0.45).
        let cell_moments = (0..total_cells * MOMENT_STRIDE)
            .map(|i| 0.4 + 0.1 * f(i + 12_012))
            .collect::<Vec<_>>();

        let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
        let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
        // `rho_u`/`tau_u` carry n * r entries and `r_uv` carries n * r * r entries.
        let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
        let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
        let r_uv = (0..n * r * r)
            .map(|i| 0.04 * f(i + 15_015))
            .collect::<Vec<_>>();

        TestBuffers {
            q,
            b,
            mu_1,
            mu_2,
            z_obs,
            y,
            w,
            e_obs,
            cell_offsets,
            cell_c0,
            cell_c1,
            cell_c2,
            cell_c3,
            cell_a,
            cell_aa,
            cell_r,
            cell_ar,
            cell_sbb,
            cell_sbh,
            cell_sbw,
            cell_moments,
            chi_obs,
            xi_obs,
            rho_u,
            tau_u,
            r_uv,
        }
    }
3860
3861 pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3862 BmsFlexRowKernelInputs {
3863 n_rows: 4,
3864 r: 5,
3865 p_h: 2,
3866 p_w: 1,
3867 q: &buffers.q,
3868 b: &buffers.b,
3869 mu_1: &buffers.mu_1,
3870 mu_2: &buffers.mu_2,
3871 z_obs: &buffers.z_obs,
3872 y: &buffers.y,
3873 w: &buffers.w,
3874 e_obs: &buffers.e_obs,
3875 s_f: 1.0,
3876 cell_offsets: &buffers.cell_offsets,
3877 cell_c0: &buffers.cell_c0,
3878 cell_c1: &buffers.cell_c1,
3879 cell_c2: &buffers.cell_c2,
3880 cell_c3: &buffers.cell_c3,
3881 cell_a: &buffers.cell_a,
3882 cell_aa: &buffers.cell_aa,
3883 cell_r: &buffers.cell_r,
3884 cell_ar: &buffers.cell_ar,
3885 cell_sbb: &buffers.cell_sbb,
3886 cell_sbh: &buffers.cell_sbh,
3887 cell_sbw: &buffers.cell_sbw,
3888 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3889 chi_obs: &buffers.chi_obs,
3890 xi_obs: &buffers.xi_obs,
3891 rho_u: &buffers.rho_u,
3892 tau_u: &buffers.tau_u,
3893 r_uv: &buffers.r_uv,
3894 }
3895 }
3896
3897 #[test]
3901 pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
3902 let buffers = make_parity_buffers();
3903 let inputs = parity_inputs(&buffers);
3904 inputs
3905 .validate()
3906 .expect("parity fixture must satisfy validate()");
3907 let out = cpu_oracle_outputs(&inputs);
3908 let n = inputs.n_rows;
3909 let r = inputs.r;
3910 assert_eq!(out.neglog.len(), n);
3911 assert_eq!(out.grad.len(), n * r);
3912 assert_eq!(out.hess.len(), n * r * r);
3913 for row in 0..n {
3914 assert!(
3915 out.neglog[row].is_finite(),
3916 "row {row}: neglog must be finite, got {}",
3917 out.neglog[row]
3918 );
3919 for u in 0..r {
3920 let g = out.grad[row * r + u];
3921 assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
3922 for v in 0..r {
3923 let huv = out.hess[row * r * r + u * r + v];
3924 let hvu = out.hess[row * r * r + v * r + u];
3925 assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
3926 assert_eq!(
3927 huv.to_bits(),
3928 hvu.to_bits(),
3929 "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
3930 );
3931 }
3932 }
3933 }
3934 }
3935
3936 #[test]
3966 pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
3967 let neglog_of = |e: f64, y: f64, w: f64| -> f64 {
3970 let s = 2.0 * y - 1.0;
3971 let (log_cdf, _) = oracle_log_ndtr_and_mills(s * e);
3972 -w * log_cdf
3973 };
3974 let ab_of = |e: f64, y: f64, w: f64| -> (f64, f64) {
3977 let s = 2.0 * y - 1.0;
3978 let m_arg = s * e;
3979 let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
3980 let a_i = -w * s * lambda;
3981 let b_i = w * lambda * (m_arg + lambda);
3982 (a_i, b_i)
3983 };
3984
3985 let cases: [(f64, f64, f64); 12] = [
3990 (-1.6, 1.0, 1.0),
3991 (-0.7, 1.0, 1.0),
3992 (0.0, 1.0, 1.0),
3993 (0.9, 1.0, 1.0),
3994 (1.8, 1.0, 1.0),
3995 (-1.4, 0.0, 1.0),
3996 (-0.3, 0.0, 1.0),
3997 (0.0, 0.0, 1.0),
3998 (0.6, 0.0, 1.0),
3999 (1.5, 0.0, 1.0),
4000 (0.4, 1.0, 0.75),
4001 (-0.8, 0.0, 1.3),
4002 ];
4003 let h = 1e-3_f64;
4006 for (e, y, w) in cases {
4007 let (a_ana, b_ana) = ab_of(e, y, w);
4008
4009 let fp2 = neglog_of(e + 2.0 * h, y, w);
4010 let fp1 = neglog_of(e + h, y, w);
4011 let f0 = neglog_of(e, y, w);
4012 let fm1 = neglog_of(e - h, y, w);
4013 let fm2 = neglog_of(e - 2.0 * h, y, w);
4014
4015 let d1_fd = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * h);
4017 let d2_fd = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * h * h);
4019
4020 let a_abs = (a_ana - d1_fd).abs();
4021 let a_rel = a_abs / a_ana.abs().max(1.0);
4022 assert!(
4023 a_abs <= 5e-8 || a_rel <= 5e-8,
4024 "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
4025 analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
4026 );
4027
4028 let b_abs = (b_ana - d2_fd).abs();
4029 let b_rel = b_abs / b_ana.abs().max(1.0);
4030 assert!(
4031 b_abs <= 5e-6 || b_rel <= 5e-6,
4032 "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
4033 analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
4034 );
4035 }
4036 }
4037
4038 #[test]
4047 pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
4048 #[cfg(not(target_os = "linux"))]
4049 {
4050 eprintln!(
4051 "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
4052 (CPU oracle exercised by sibling test)"
4053 );
4054 return;
4055 }
4056 #[cfg(target_os = "linux")]
4057 {
4058 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4059 eprintln!(
4060 "[bms_flex_row parity] no CUDA runtime — skipping device \
4061 parity (CPU oracle exercised by sibling test)"
4062 );
4063 return;
4064 };
4065 let buffers = make_parity_buffers();
4066 let inputs_cpu = parity_inputs(&buffers);
4067 inputs_cpu
4068 .validate()
4069 .expect("parity fixture must satisfy validate()");
4070 let cpu_out = cpu_oracle_outputs(&inputs_cpu);
4071
4072 let inputs_gpu = parity_inputs(&buffers);
4074 let gpu_out = match launch_bms_flex_row_kernel(inputs_gpu) {
4075 Ok(out) => out,
4076 Err(err) => panic!(
4077 "[bms_flex_row parity] launch failed on CUDA-selected host; \
4078 device/oracle parity must fail loudly on GPU CI: {err}"
4079 ),
4080 };
4081
4082 let n = inputs_cpu.n_rows;
4083 let r = inputs_cpu.r;
4084 let tol_abs = 1e-8_f64;
4085 let tol_rel = 1e-8_f64;
4086 let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
4087 if cpu.is_nan() || gpu.is_nan() {
4088 assert!(
4089 cpu.is_nan() && gpu.is_nan(),
4090 "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
4091 );
4092 return;
4093 }
4094 let diff = (cpu - gpu).abs();
4095 let tol = tol_abs + tol_rel * cpu.abs();
4096 assert!(
4097 diff <= tol,
4098 "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
4099 cpu={cpu:.17e}, gpu={gpu:.17e}"
4100 );
4101 };
4102 assert_eq!(cpu_out.neglog.len(), gpu_out.neglog.len());
4103 assert_eq!(cpu_out.grad.len(), gpu_out.grad.len());
4104 assert_eq!(cpu_out.hess.len(), gpu_out.hess.len());
4105 for (i, (&c, &g)) in cpu_out.neglog.iter().zip(gpu_out.neglog.iter()).enumerate() {
4106 check_close("neglog", i, c, g);
4107 }
4108 for (i, (&c, &g)) in cpu_out.grad.iter().zip(gpu_out.grad.iter()).enumerate() {
4109 check_close("grad", i, c, g);
4110 }
4111 for (i, (&c, &g)) in cpu_out.hess.iter().zip(gpu_out.hess.iter()).enumerate() {
4112 check_close("hess", i, c, g);
4113 }
4114 for row in 0..n {
4116 for u in 0..r {
4117 for v in 0..r {
4118 let a = gpu_out.hess[row * r * r + u * r + v];
4119 let bb = gpu_out.hess[row * r * r + v * r + u];
4120 assert_eq!(
4121 a.to_bits(),
4122 bb.to_bits(),
4123 "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
4124 );
4125 }
4126 }
4127 }
4128 }
4129 }
4130
4131 #[test]
4132 pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
4133 #[cfg(target_os = "linux")]
4138 assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
4139 #[cfg(target_os = "linux")]
4140 assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
4141 }
4142
4143 pub(crate) fn cpu_oracle_bms_flex_row_hvp(
4148 row_hessians: &[f64],
4149 marginal_design: &[f64],
4150 logslope_design: &[f64],
4151 block: &BmsFlexBlockLayout,
4152 primary: &BmsFlexPrimaryLayout,
4153 n: usize,
4154 v: &[f64],
4155 ) -> Vec<f64> {
4156 let r = primary.r;
4157 let p_m = block.p_m;
4158 let p_g = block.p_g;
4159 assert_eq!(v.len(), block.p_total);
4160 assert_eq!(row_hessians.len(), n * r * r);
4161 assert_eq!(marginal_design.len(), n * p_m);
4162 assert_eq!(logslope_design.len(), n * p_g);
4163 let mut out = vec![0.0_f64; block.p_total];
4164 let mut row_dir = vec![0.0_f64; r];
4165 let mut action = vec![0.0_f64; r];
4166 for row in 0..n {
4167 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4168 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4169 let mut acc_q = 0.0_f64;
4170 for j in 0..p_m {
4171 acc_q += mrow[j] * v[j];
4172 }
4173 let mut acc_g = 0.0_f64;
4174 for j in 0..p_g {
4175 acc_g += grow[j] * v[p_m + j];
4176 }
4177 row_dir[0] = acc_q;
4178 row_dir[1] = acc_g;
4179 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4180 for (k, ii) in prange.clone().enumerate() {
4181 row_dir[ii] = v[brange.start + k];
4182 }
4183 }
4184 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4185 for (k, ii) in prange.clone().enumerate() {
4186 row_dir[ii] = v[brange.start + k];
4187 }
4188 }
4189 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4190 for u in 0..r {
4191 let mut acc = 0.0_f64;
4192 for v_idx in 0..r {
4193 acc += h_slice[u * r + v_idx] * row_dir[v_idx];
4194 }
4195 action[u] = acc;
4196 }
4197 let a0 = action[0];
4198 for j in 0..p_m {
4199 out[j] += a0 * mrow[j];
4200 }
4201 let a1 = action[1];
4202 for j in 0..p_g {
4203 out[p_m + j] += a1 * grow[j];
4204 }
4205 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4206 for (k, ii) in prange.clone().enumerate() {
4207 out[brange.start + k] += action[ii];
4208 }
4209 }
4210 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4211 for (k, ii) in prange.clone().enumerate() {
4212 out[brange.start + k] += action[ii];
4213 }
4214 }
4215 }
4216 out
4217 }
4218
4219 pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
4220 row_hessians: &[f64],
4221 marginal_design: &[f64],
4222 logslope_design: &[f64],
4223 block: &BmsFlexBlockLayout,
4224 primary: &BmsFlexPrimaryLayout,
4225 n: usize,
4226 ) -> Vec<f64> {
4227 let r = primary.r;
4228 let p_m = block.p_m;
4229 let p_g = block.p_g;
4230 let mut out = vec![0.0_f64; block.p_total];
4231 for row in 0..n {
4232 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4233 let h00 = h_slice[0];
4234 let h11 = h_slice[r + 1];
4235 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4236 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4237 for j in 0..p_m {
4238 out[j] += h00 * mrow[j] * mrow[j];
4239 }
4240 for j in 0..p_g {
4241 out[p_m + j] += h11 * grow[j] * grow[j];
4242 }
4243 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4244 for (k, ii) in prange.clone().enumerate() {
4245 out[brange.start + k] += h_slice[ii * r + ii];
4246 }
4247 }
4248 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4249 for (k, ii) in prange.clone().enumerate() {
4250 out[brange.start + k] += h_slice[ii * r + ii];
4251 }
4252 }
4253 }
4254 out
4255 }
4256
4257 #[test]
4261 pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
4262 let n = 4_usize;
4263 let r = 4_usize; let p_m = 2_usize;
4265 let p_g = 2_usize;
4266 let p_h_dim = 1_usize;
4267 let p_w_dim = 1_usize;
4268 let p_total = p_m + p_g + p_h_dim + p_w_dim;
4269 let block = BmsFlexBlockLayout {
4270 p_m,
4271 p_g,
4272 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4273 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4274 p_total,
4275 };
4276 let primary = BmsFlexPrimaryLayout {
4277 h: Some(2..3),
4278 w: Some(3..4),
4279 r,
4280 };
4281 let mut row_hessians = vec![0.0_f64; n * r * r];
4283 for row in 0..n {
4284 for u in 0..r {
4285 for v in u..r {
4286 let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4287 row_hessians[row * r * r + u * r + v] = val;
4288 row_hessians[row * r * r + v * r + u] = val;
4289 }
4290 }
4291 }
4292 let mut marginal = vec![0.0_f64; n * p_m];
4293 for row in 0..n {
4294 for j in 0..p_m {
4295 marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4296 }
4297 }
4298 let mut logslope = vec![0.0_f64; n * p_g];
4299 for row in 0..n {
4300 for j in 0..p_g {
4301 logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4302 }
4303 }
4304 let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4305 let out = cpu_oracle_bms_flex_row_hvp(
4306 &row_hessians,
4307 &marginal,
4308 &logslope,
4309 &block,
4310 &primary,
4311 n,
4312 &v,
4313 );
4314 let mut expect_out_0 = 0.0_f64;
4316 for row in 0..n {
4317 let mrow = &marginal[row * p_m..(row + 1) * p_m];
4318 let grow = &logslope[row * p_g..(row + 1) * p_g];
4319 let mut row_dir = vec![0.0_f64; r];
4320 row_dir[0] = mrow[0] * v[0] + mrow[1] * v[1];
4321 row_dir[1] = grow[0] * v[p_m] + grow[1] * v[p_m + 1];
4322 row_dir[2] = v[p_m + p_g];
4323 row_dir[3] = v[p_m + p_g + p_h_dim];
4324 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4325 let mut action0 = 0.0_f64;
4326 for vv in 0..r {
4330 action0 += h_slice[vv] * row_dir[vv];
4331 }
4332 expect_out_0 += action0 * mrow[0];
4333 }
4334 assert!(
4335 (out[0] - expect_out_0).abs() < 1e-12,
4336 "cpu oracle HVP out[0] mismatch: {} vs hand-check {}",
4337 out[0],
4338 expect_out_0
4339 );
4340 assert!(out.iter().all(|x| x.is_finite()));
4341 assert_eq!(out.len(), p_total);
4342 }
4343
4344 #[test]
4346 pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
4347 let n = 3_usize;
4348 let r = 4_usize;
4349 let p_m = 2_usize;
4350 let p_g = 2_usize;
4351 let p_h_dim = 1_usize;
4352 let p_w_dim = 1_usize;
4353 let p_total = p_m + p_g + p_h_dim + p_w_dim;
4354 let block = BmsFlexBlockLayout {
4355 p_m,
4356 p_g,
4357 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4358 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4359 p_total,
4360 };
4361 let primary = BmsFlexPrimaryLayout {
4362 h: Some(2..3),
4363 w: Some(3..4),
4364 r,
4365 };
4366 let mut row_hessians = vec![0.0_f64; n * r * r];
4367 for row in 0..n {
4368 for u in 0..r {
4369 row_hessians[row * r * r + u * r + u] = 1.0 + (row as f64) + (u as f64) * 0.5;
4370 }
4371 }
4372 let mut marginal = vec![0.0_f64; n * p_m];
4373 let mut logslope = vec![0.0_f64; n * p_g];
4374 for row in 0..n {
4375 for j in 0..p_m {
4376 marginal[row * p_m + j] = 0.2 + (row as f64) * 0.3 + (j as f64) * 0.1;
4377 }
4378 for j in 0..p_g {
4379 logslope[row * p_g + j] = -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2;
4380 }
4381 }
4382 let out = cpu_oracle_bms_flex_row_diagonal(
4383 &row_hessians,
4384 &marginal,
4385 &logslope,
4386 &block,
4387 &primary,
4388 n,
4389 );
4390 let mut expect = 0.0_f64;
4392 for row in 0..n {
4393 let h00 = row_hessians[row * r * r];
4394 expect += h00 * marginal[row * p_m].powi(2);
4395 }
4396 assert!(
4397 (out[0] - expect).abs() < 1e-12,
4398 "out[0] {} vs {}",
4399 out[0],
4400 expect
4401 );
4402 let mut expect_h = 0.0_f64;
4404 for row in 0..n {
4405 expect_h += row_hessians[row * r * r + 2 * r + 2];
4406 }
4407 let h_slot = p_m + p_g;
4408 assert!(
4409 (out[h_slot] - expect_h).abs() < 1e-12,
4410 "h slot {} vs {}",
4411 out[h_slot],
4412 expect_h
4413 );
4414 }
4415
4416 #[test]
4421 pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
4422 #[cfg(not(target_os = "linux"))]
4423 {
4424 eprintln!(
4425 "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
4426 (CPU oracle exercised by sibling tests)"
4427 );
4428 }
4429 #[cfg(target_os = "linux")]
4430 {
4431 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4432 eprintln!(
4433 "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
4434 parity"
4435 );
4436 return;
4437 };
4438 let n = 4_usize;
4439 let r = 4_usize;
4440 let p_m = 2_usize;
4441 let p_g = 2_usize;
4442 let p_h_dim = 1_usize;
4443 let p_w_dim = 1_usize;
4444 let p_total = p_m + p_g + p_h_dim + p_w_dim;
4445 let block = BmsFlexBlockLayout {
4446 p_m,
4447 p_g,
4448 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4449 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4450 p_total,
4451 };
4452 let primary = BmsFlexPrimaryLayout {
4453 h: Some(2..3),
4454 w: Some(3..4),
4455 r,
4456 };
4457 let mut row_hessians = vec![0.0_f64; n * r * r];
4458 for row in 0..n {
4459 for u in 0..r {
4460 for v in u..r {
4461 let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4462 row_hessians[row * r * r + u * r + v] = val;
4463 row_hessians[row * r * r + v * r + u] = val;
4464 }
4465 }
4466 }
4467 let mut marginal = vec![0.0_f64; n * p_m];
4468 for row in 0..n {
4469 for j in 0..p_m {
4470 marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4471 }
4472 }
4473 let mut logslope = vec![0.0_f64; n * p_g];
4474 for row in 0..n {
4475 for j in 0..p_g {
4476 logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4477 }
4478 }
4479 let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4480 let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4481 &row_hessians,
4482 &marginal,
4483 &logslope,
4484 &block,
4485 &primary,
4486 n,
4487 &v,
4488 );
4489 let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4490 &row_hessians,
4491 &marginal,
4492 &logslope,
4493 &block,
4494 &primary,
4495 n,
4496 );
4497
4498 let backend = HvpKernelBackend::probe()
4505 .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
4506 let stream = backend.stream.clone();
4507 let d_h = stream
4508 .clone_htod(&row_hessians)
4509 .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
4510 let d_m = stream
4511 .clone_htod(&marginal)
4512 .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
4513 let d_g = stream
4514 .clone_htod(&logslope)
4515 .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
4516 let storage = DeviceResidentRowHess {
4517 hess: d_h,
4518 marginal_design: d_m,
4519 logslope_design: d_g,
4520 n,
4521 r,
4522 block: block.clone(),
4523 primary: primary.clone(),
4524
4525 bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4526 };
4527 let gpu_hvp =
4528 launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
4529 let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4530 .expect("diagonal kernel must launch on CUDA host");
4531 assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4532 assert_eq!(gpu_diag.len(), cpu_diag.len());
4533 for i in 0..p_total {
4534 let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4535 assert!(
4536 diff <= 1e-10,
4537 "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4538 cpu_hvp[i],
4539 gpu_hvp[i]
4540 );
4541 let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4542 assert!(
4543 ddiff <= 1e-10,
4544 "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4545 cpu_diag[i],
4546 gpu_diag[i]
4547 );
4548 }
4549 }
4550 }
4551
4552 #[test]
4553 pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
4554 let n = 195_000_usize;
4555 let r = 20_usize;
4556 let p_total = 44_usize;
4557 let rhs_count = 4_usize;
4558 let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4559 .expect("large-scale multi-RHS scratch budget");
4560 let per_rhs_full_row_cache =
4561 (n * r * r * std::mem::size_of::<f64>()) as u64 * rhs_count as u64;
4562 assert!(
4563 scratch < per_rhs_full_row_cache / 100,
4564 "multi-RHS scratch must tile by row chunks instead of materializing \
4565 a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
4566 );
4567 assert!(
4568 bms_flex_row_hvp_multi_scratch_bytes_for_shape(
4569 n,
4570 p_total,
4571 BMS_FLEX_ROW_HVP_MAX_RHS + 1
4572 )
4573 .is_err(),
4574 "multi-RHS launch must reject unbounded RHS counts"
4575 );
4576 }
4577
4578 #[test]
4579 pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
4580 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4581 eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
4582 return;
4583 };
4584 let n = 5_usize;
4585 let r = 4_usize;
4586 let p_m = 2_usize;
4587 let p_g = 2_usize;
4588 let p_h_dim = 1_usize;
4589 let p_w_dim = 1_usize;
4590 let p_total = p_m + p_g + p_h_dim + p_w_dim;
4591 let rhs_count = 3_usize;
4592 let block = BmsFlexBlockLayout {
4593 p_m,
4594 p_g,
4595 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4596 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4597 p_total,
4598 };
4599 let primary = BmsFlexPrimaryLayout {
4600 h: Some(2..3),
4601 w: Some(3..4),
4602 r,
4603 };
4604 let mut row_hessians = vec![0.0_f64; n * r * r];
4605 for row in 0..n {
4606 for u in 0..r {
4607 for v in u..r {
4608 let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4609 row_hessians[row * r * r + u * r + v] = val;
4610 row_hessians[row * r * r + v * r + u] = val;
4611 }
4612 }
4613 }
4614 let mut marginal = vec![0.0_f64; n * p_m];
4615 let mut logslope = vec![0.0_f64; n * p_g];
4616 for row in 0..n {
4617 for j in 0..p_m {
4618 marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4619 }
4620 for j in 0..p_g {
4621 logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4622 }
4623 }
4624 let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
4625 for rhs in 0..rhs_count {
4626 for j in 0..p_total {
4627 let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
4628 v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
4629 }
4630 }
4631
4632 let backend = HvpKernelBackend::probe()
4636 .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
4637 let stream = backend.stream.clone();
4638 let d_h = stream
4639 .clone_htod(&row_hessians)
4640 .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
4641 let d_m = stream
4642 .clone_htod(&marginal)
4643 .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
4644 let d_g = stream
4645 .clone_htod(&logslope)
4646 .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
4647 let storage = DeviceResidentRowHess {
4648 hess: d_h,
4649 marginal_design: d_m,
4650 logslope_design: d_g,
4651 n,
4652 r,
4653 block: block.clone(),
4654 primary: primary.clone(),
4655
4656 bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4657 };
4658 let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
4659 .expect("storage scratch budget");
4660 assert!(
4661 scratch < storage.bytes,
4662 "multi-RHS scratch should stay below resident cache bytes"
4663 );
4664 let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
4665 .expect("multi-RHS HVP kernel must launch on CUDA host");
4666 assert_eq!(gpu.len(), rhs_count * p_total);
4667 for rhs in 0..rhs_count {
4668 let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
4669 let cpu = cpu_oracle_bms_flex_row_hvp(
4670 &row_hessians,
4671 &marginal,
4672 &logslope,
4673 &block,
4674 &primary,
4675 n,
4676 v,
4677 );
4678 let single = launch_bms_flex_row_hvp(&storage, v)
4679 .expect("single-RHS HVP kernel must launch on CUDA host");
4680 for j in 0..p_total {
4681 let got = gpu[rhs * p_total + j];
4682 let diff = (cpu[j] - got).abs();
4683 assert!(
4684 diff <= 1e-10,
4685 "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
4686 cpu[j],
4687 got
4688 );
4689 assert_eq!(
4690 got, single[j],
4691 "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
4692 );
4693 }
4694 }
4695 }
4696
/// Device-output HVP parity on a tiny problem (n=4 rows, r=4 primaries,
/// p=6 coefficients).
///
/// The result that `launch_bms_flex_row_hvp_into_device` writes into a
/// caller-owned device buffer must agree with the CPU oracle to within 1e-10.
/// It must also equal the host-output launch `launch_bms_flex_row_hvp`
/// exactly (zero difference).
///
/// Skips without failing on non-Linux hosts and when no CUDA runtime is
/// available. Once a runtime is present, probe/upload/launch failures panic.
#[test]
pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
             CUDA parity (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
                 skipping device parity"
            );
            return;
        };

        // Coefficient layout: [marginal | logslope | h | w].
        let (n, r) = (4_usize, 4_usize);
        let (p_m, p_g, p_h_dim, p_w_dim) = (2_usize, 2_usize, 1_usize, 1_usize);
        let h_begin = p_m + p_g;
        let w_begin = h_begin + p_h_dim;
        let p_total = w_begin + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_begin..w_begin),
            w: Some(w_begin..p_total),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };

        // Symmetric row blocks:
        // H_row[a][b] = (row + 1) * (1 + min(a, b) + 2 * max(a, b)).
        let row_hessians: Vec<f64> = (0..n)
            .flat_map(|row| {
                (0..r * r).map(move |idx| {
                    let (a, b) = (idx / r, idx % r);
                    let (lo, hi) = (a.min(b), a.max(b));
                    ((row + 1) as f64) * (1.0 + (lo as f64) + 2.0 * (hi as f64))
                })
            })
            .collect();
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|idx| {
                let (row, j) = (idx / p_m, idx % p_m);
                0.5 + (row as f64) * 0.1 - (j as f64) * 0.2
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|idx| {
                let (row, j) = (idx / p_g, idx % p_g);
                -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15
            })
            .collect();
        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();

        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );

        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        // Struct-literal fields are evaluated in source order, so the three
        // uploads run h -> marginal -> logslope.
        let storage = DeviceResidentRowHess {
            hess: stream.clone_htod(&row_hessians).expect(
                "[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host",
            ),
            marginal_design: stream.clone_htod(&marginal).expect(
                "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
            ),
            logslope_design: stream.clone_htod(&logslope).expect(
                "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
            ),
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };

        // Reference: the host-output launch.
        let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
            .expect("host-out HVP kernel must launch on CUDA host");

        // Device-output launch: direction and result both live on the device.
        let d_v = stream
            .clone_htod(&v)
            .expect("upload direction for device-out HVP");
        let mut d_out = stream
            .alloc_zeros::<f64>(p_total)
            .expect("alloc device-out HVP output");
        launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
            .expect("device-out HVP kernel must launch on CUDA host");
        stream
            .synchronize()
            .expect("synchronize after device-out HVP");
        let device_out_hvp = stream
            .clone_dtoh(&d_out)
            .expect("download device-out HVP output");

        assert_eq!(device_out_hvp.len(), cpu_hvp.len());
        assert_eq!(device_out_hvp.len(), host_out_hvp.len());
        for i in 0..p_total {
            let cpu = cpu_hvp[i];
            let dev = device_out_hvp[i];
            let diff = (cpu - dev).abs();
            assert!(
                diff <= 1e-10,
                "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
                cpu,
                dev
            );
            let host = host_out_hvp[i];
            let host_diff = (host - dev).abs();
            assert!(
                host_diff == 0.0,
                "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
                host,
                dev
            );
        }
    }
}
4847
/// HVP and diagonal kernel parity against the CPU oracles at n=64 rows,
/// r=20 primaries and p=44 coefficients (p_m=14, p_g=12, p_h=10, p_w=8).
///
/// Row Hessians are symmetrised pseudo-random blocks with `r` added on the
/// diagonal. Both `launch_bms_flex_row_hvp` and `launch_bms_flex_row_diagonal`
/// must match `cpu_oracle_bms_flex_row_hvp` / `cpu_oracle_bms_flex_row_diagonal`
/// to within 1e-8 absolute per coefficient.
///
/// Skips on non-Linux hosts and when no CUDA runtime is available. Once a
/// runtime is present, backend-probe and upload failures fail the test. The
/// uploads here total roughly 218 KB, so such a failure is a genuine fault,
/// not a capacity limit, and must not be reported as a pass.
#[test]
pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
             skipping CUDA parity"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
                 skipping device parity"
            );
            return;
        };
        let n = 64_usize;
        let p_m = 14_usize;
        let p_g = 12_usize;
        let p_h_dim = 10_usize;
        let p_w_dim = 8_usize;
        let r = 2 + p_h_dim + p_w_dim;
        assert_eq!(r, 20);
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        assert_eq!(p_total, 44);
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        // Pseudo-random row blocks, symmetrised, then shifted by r on the
        // diagonal.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            let base = row * r * r;
            for u in 0..r {
                for v in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    row_hessians[base + u * r + v] = a;
                }
            }
            for u in 0..r {
                for v in (u + 1)..r {
                    let upper = row_hessians[base + u * r + v];
                    let lower = row_hessians[base + v * r + u];
                    let sym = 0.5 * (upper + lower);
                    row_hessians[base + u * r + v] = sym;
                    row_hessians[base + v * r + u] = sym;
                }
                row_hessians[base + u * r + u] += r as f64;
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }
        let v: Vec<f64> = (0..p_total)
            .map(|i| {
                let seed = (i as f64) * 0.157 + 0.6;
                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
            })
            .collect();

        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );
        let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
        );

        // A runtime is present, so a probe/upload failure is a real defect.
        // Fail loudly, like the sibling parity tests, instead of returning
        // early and reporting a silent pass.
        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row hvp parity n64_r20_p44] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        let d_h = stream.clone_htod(&row_hessians).expect(
            "[bms_flex_row hvp parity n64_r20_p44] upload h must succeed on CUDA host",
        );
        let d_m = stream.clone_htod(&marginal).expect(
            "[bms_flex_row hvp parity n64_r20_p44] upload marg must succeed on CUDA host",
        );
        let d_g = stream.clone_htod(&logslope).expect(
            "[bms_flex_row hvp parity n64_r20_p44] upload logslope must succeed on CUDA host",
        );
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
            .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
        let gpu_diag = launch_bms_flex_row_diagonal(&storage)
            .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
        assert_eq!(gpu_hvp.len(), cpu_hvp.len());
        assert_eq!(gpu_diag.len(), cpu_diag.len());
        for i in 0..p_total {
            let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
            assert!(
                diff <= 1e-8,
                "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
                cpu_hvp[i],
                gpu_hvp[i]
            );
            let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
            assert!(
                ddiff <= 1e-8,
                "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
                cpu_diag[i],
                gpu_diag[i]
            );
        }
    }
}
5041
/// Dense-block parity: `launch_bms_flex_row_dense_block` must match a CPU
/// reference that accumulates Σ_rows Φᵀ · H_row · Φ into a p×p matrix.
///
/// Φ (r×p) maps primaries to coefficients:
/// - row 0 holds the row's marginal design;
/// - row 1 holds the row's logslope design (offset by p_m);
/// - the h/w primaries select their coefficient blocks with unit entries.
///
/// Per-entry tolerance is 1e-9 · max(|cpu|, |gpu|, 1). Skips on non-Linux
/// hosts and when no CUDA runtime is available.
#[test]
pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
            return;
        };
        let (n, p_m, p_g, p_h_dim, p_w_dim) = (24_usize, 4_usize, 4_usize, 3_usize, 3_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let h_begin = p_m + p_g;
        let w_begin = h_begin + p_h_dim;
        let p_total = w_begin + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_begin..w_begin),
            w: Some(w_begin..p_total),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..r),
            r,
        };

        // Per row: fill pseudo-randomly, symmetrise, then add r on the diagonal.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, hrow) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            for (idx, slot) in hrow.iter_mut().enumerate() {
                let (u, v) = (idx / r, idx % r);
                let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (v as f64) * 0.47;
                *slot = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
            }
            for u in 0..r {
                for v in (u + 1)..r {
                    let sym = 0.5 * (hrow[u * r + v] + hrow[v * r + u]);
                    hrow[u * r + v] = sym;
                    hrow[v * r + u] = sym;
                }
                hrow[u * r + u] += r as f64;
            }
        }
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|idx| {
                let seed = ((idx / p_m) as f64) * 0.083 + ((idx % p_m) as f64) * 0.171 + 0.31;
                seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|idx| {
                let seed = ((idx / p_g) as f64) * 0.097 + ((idx % p_g) as f64) * 0.143 - 0.15;
                seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2
            })
            .collect();

        // Offsets of the h/w coefficient blocks and of their primaries.
        // A missing block contributes offset 0 and length 0.
        let h_block_start = block.h.as_ref().map_or(0, |range| range.start);
        let h_block_len = block.h.as_ref().map_or(0, |range| range.len());
        let w_block_start = block.w.as_ref().map_or(0, |range| range.start);
        let w_block_len = block.w.as_ref().map_or(0, |range| range.len());
        let h_primary_start = primary.h.as_ref().map_or(0, |range| range.start);
        let w_primary_start = primary.w.as_ref().map_or(0, |range| range.start);

        // Φ is stored flat and row-major: r rows of p_total coefficients each.
        let mut phi = vec![0.0_f64; r * p_total];
        let mut h_cpu = vec![0.0_f64; p_total * p_total];
        for ((mrow, grow), hrow) in marginal
            .chunks_exact(p_m)
            .zip(logslope.chunks_exact(p_g))
            .zip(row_hessians.chunks_exact(r * r))
        {
            phi.fill(0.0);
            phi[..p_m].copy_from_slice(mrow);
            phi[p_total + p_m..p_total + p_m + p_g].copy_from_slice(grow);
            for k in 0..h_block_len {
                phi[(h_primary_start + k) * p_total + h_block_start + k] = 1.0;
            }
            for k in 0..w_block_len {
                phi[(w_primary_start + k) * p_total + w_block_start + k] = 1.0;
            }
            for u in 0..r {
                let phi_u = &phi[u * p_total..(u + 1) * p_total];
                for v in 0..r {
                    let huv = hrow[u * r + v];
                    if huv == 0.0 {
                        continue;
                    }
                    let phi_v = &phi[v * p_total..(v + 1) * p_total];
                    for (m, &pm) in phi_u.iter().enumerate() {
                        if pm == 0.0 {
                            continue;
                        }
                        let scaled = huv * pm;
                        let out_row = &mut h_cpu[m * p_total..(m + 1) * p_total];
                        for (dst, &pv) in out_row.iter_mut().zip(phi_v) {
                            *dst += scaled * pv;
                        }
                    }
                }
            }
        }

        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        // Struct-literal fields are evaluated in source order, so the three
        // uploads run h -> marginal -> logslope.
        let storage = DeviceResidentRowHess {
            hess: stream
                .clone_htod(&row_hessians)
                .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host"),
            marginal_design: stream
                .clone_htod(&marginal)
                .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host"),
            logslope_design: stream.clone_htod(&logslope).expect(
                "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
            ),
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let h_gpu = launch_bms_flex_row_dense_block(&storage)
            .expect("dense_block kernel must launch on CUDA host");
        assert_eq!(h_gpu.len(), p_total * p_total);

        let mut max_abs = 0.0_f64;
        for (idx, (&a, &b)) in h_cpu.iter().zip(h_gpu.iter()).enumerate() {
            let (i, j) = (idx / p_total, idx % p_total);
            let diff = (a - b).abs();
            if diff > max_abs {
                max_abs = diff;
            }
            assert!(
                diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
                "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
            );
        }
        eprintln!(
            "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
        );
    }
}
5218
/// V100 perf gate for the HVP kernel at n=195,000 rows, r=20, p=44.
///
/// The median time of `launch_bms_flex_row_hvp` (15 timed launches after 3
/// warm-ups) must be at least 5× faster than the median of a rayon-parallel
/// CPU oracle. The CPU oracle sums per-chunk (4096-row) results.
///
/// The GPU output is also compared against the CPU output. The absolute error
/// must be at most 1e-8 times the largest CPU magnitude (floored at 1), so a
/// fast-but-wrong kernel cannot pass the gate.
///
/// Skips on non-Linux hosts, when no CUDA runtime is available, and when the
/// backend probe or any upload fails (the row-Hessian upload is ~624 MB).
#[test]
pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate"
            );
            return;
        };
        let n = 195_000_usize;
        let p_m = 14_usize;
        let p_g = 12_usize;
        let p_h_dim = 10_usize;
        let p_w_dim = 8_usize;
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        // Pseudo-random row blocks, symmetrised, then shifted by r on the
        // diagonal.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            let base = row * r * r;
            for u in 0..r {
                for vv in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    row_hessians[base + u * r + vv] = a;
                }
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let upper = row_hessians[base + u * r + vv];
                    let lower = row_hessians[base + vv * r + u];
                    let sym = 0.5 * (upper + lower);
                    row_hessians[base + u * r + vv] = sym;
                    row_hessians[base + vv * r + u] = sym;
                }
                row_hessians[base + u * r + u] += r as f64;
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }
        let v: Vec<f64> = (0..p_total)
            .map(|i| {
                let seed = (i as f64) * 0.157 + 0.6;
                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
            })
            .collect();

        // At this scale probe/upload failures are treated as capacity limits
        // (e.g. device OOM) and skip the gate rather than fail it.
        let backend = match HvpKernelBackend::probe() {
            Ok(b) => b,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
                return;
            }
        };
        let stream = backend.stream.clone();
        let d_h = match stream.clone_htod(&row_hessians) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
                return;
            }
        };
        let d_m = match stream.clone_htod(&marginal) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
                return;
            }
        };
        let d_g = match stream.clone_htod(&logslope) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
                return;
            }
        };
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let warmup: usize = 3;
        let iters: usize = 15;
        // Keep the last warm-up output for the parity check below, so the
        // correctness guard costs no extra launch.
        let mut gpu_reference = None;
        for _ in 0..warmup {
            let out =
                launch_bms_flex_row_hvp(&storage, &v).expect("warmup GPU HVP must launch");
            assert_eq!(out.len(), p_total);
            gpu_reference = Some(out);
        }
        let gpu_reference = gpu_reference.expect("warmup count must be non-zero");
        let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
        for _ in 0..iters {
            let t0 = std::time::Instant::now();
            let out = launch_bms_flex_row_hvp(&storage, &v).expect("GPU HVP must launch");
            gpu_us.push(t0.elapsed().as_micros());
            assert_eq!(out.len(), p_total);
        }
        gpu_us.sort_unstable();
        let gpu_median = gpu_us[iters / 2];

        // CPU baseline: oracle on 4096-row chunks in parallel, partial
        // results summed.
        const CHUNK_ROWS: usize = 4096;
        let cpu_hvp_parallel = || -> Vec<f64> {
            let nchunks = n.div_ceil(CHUNK_ROWS);
            (0..nchunks)
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total],
                    |mut acc, ci| {
                        let lo = ci * CHUNK_ROWS;
                        let hi = (lo + CHUNK_ROWS).min(n);
                        let m = hi - lo;
                        let partial = cpu_oracle_bms_flex_row_hvp(
                            &row_hessians[lo * r * r..hi * r * r],
                            &marginal[lo * p_m..hi * p_m],
                            &logslope[lo * p_g..hi * p_g],
                            &block,
                            &primary,
                            m,
                            &v,
                        );
                        for (a, &p) in acc.iter_mut().zip(partial.iter()) {
                            *a += p;
                        }
                        acc
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total],
                    |mut a, b| {
                        for (ax, bx) in a.iter_mut().zip(b.iter()) {
                            *ax += *bx;
                        }
                        a
                    },
                )
        };
        let warm = cpu_hvp_parallel();
        assert_eq!(warm.len(), p_total);

        // Correctness guard for the perf gate. The chunked CPU reduction and
        // the GPU kernel sum in different orders, so the tolerance is scaled
        // by the largest CPU magnitude rather than applied per entry.
        let tol = 1e-8 * warm.iter().fold(1.0_f64, |acc, x| acc.max(x.abs()));
        for (j, (&cpu, &gpu)) in warm.iter().zip(gpu_reference.iter()).enumerate() {
            let diff = (cpu - gpu).abs();
            assert!(
                diff <= tol,
                "large-scale HVP parity[{j}]: cpu={cpu} gpu={gpu} |Δ|={diff:.3e} > tol={tol:.3e}"
            );
        }

        let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
        for _ in 0..iters {
            let t0 = std::time::Instant::now();
            let out = cpu_hvp_parallel();
            cpu_us.push(t0.elapsed().as_micros());
            assert_eq!(out.len(), p_total);
        }
        cpu_us.sort_unstable();
        let cpu_median = cpu_us[iters / 2];

        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 5×)"
        );
        assert!(
            speedup >= 5.0,
            "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the kernel until met or \
             prove the kernel is at hardware roofline."
        );
    }
}
5444
/// V100 perf gate for the dense-block kernel at n=195,000 rows, r=20, p=44.
///
/// The median time of `launch_bms_flex_row_dense_block` (5 timed launches
/// after 2 warm-ups) must be at least 10× faster than the median of a
/// rayon-parallel CPU build of Σ_rows Φᵀ · H_row · Φ over 2048-row chunks.
///
/// The GPU matrix is also compared entry-wise against the CPU matrix. The
/// absolute error must be at most 1e-8 times the largest CPU magnitude
/// (floored at 1).
///
/// Skips in these cases:
/// - non-Linux hosts;
/// - no CUDA runtime available;
/// - `p_total` exceeds `DENSE_BLOCK_MAX_P` (checked before any data is
///   generated);
/// - the backend probe or an upload fails.
#[test]
pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate"
        );
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
            );
            return;
        };
        let n = 195_000_usize;
        let p_m = 14_usize;
        let p_g = 12_usize;
        let p_h_dim = 10_usize;
        let p_w_dim = 8_usize;
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;

        // Decide on the skip before generating ~624 MB of row Hessians that
        // would otherwise be thrown away.
        if p_total > DENSE_BLOCK_MAX_P {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
            );
            return;
        }

        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        // Pseudo-random row blocks, symmetrised, then shifted by r on the
        // diagonal.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            let base = row * r * r;
            for u in 0..r {
                for vv in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    row_hessians[base + u * r + vv] = a;
                }
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let upper = row_hessians[base + u * r + vv];
                    let lower = row_hessians[base + vv * r + u];
                    let sym = 0.5 * (upper + lower);
                    row_hessians[base + u * r + vv] = sym;
                    row_hessians[base + vv * r + u] = sym;
                }
                row_hessians[base + u * r + u] += r as f64;
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }

        // At this scale probe/upload failures skip the gate rather than fail it.
        let backend = match HvpKernelBackend::probe() {
            Ok(b) => b,
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
                return;
            }
        };
        let stream = backend.stream.clone();
        let d_h = match stream.clone_htod(&row_hessians) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
                return;
            }
        };
        let d_m = match stream.clone_htod(&marginal) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
                return;
            }
        };
        let d_g = match stream.clone_htod(&logslope) {
            Ok(s) => s,
            Err(err) => {
                eprintln!(
                    "[bms_flex_row dense_block hill-climb] upload logslope failed: {err}"
                );
                return;
            }
        };
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let warmup: usize = 2;
        let iters: usize = 5;
        // Keep the last warm-up output for the parity check below.
        let mut gpu_reference = None;
        for _ in 0..warmup {
            let out = launch_bms_flex_row_dense_block(&storage)
                .expect("warmup GPU dense_block must launch");
            assert_eq!(out.len(), p_total * p_total);
            gpu_reference = Some(out);
        }
        let gpu_reference = gpu_reference.expect("warmup count must be non-zero");
        let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
        for _ in 0..iters {
            let t0 = std::time::Instant::now();
            let out =
                launch_bms_flex_row_dense_block(&storage).expect("GPU dense_block must launch");
            gpu_us.push(t0.elapsed().as_micros());
            assert_eq!(out.len(), p_total * p_total);
        }
        gpu_us.sort_unstable();
        let gpu_median = gpu_us[iters / 2];

        // CPU baseline: per-chunk Σ Φᵀ H Φ accumulation, partial matrices
        // summed. Offsets of absent blocks default to 0 / length 0.
        const CHUNK_ROWS: usize = 2048;
        let h_block_start = block.h.as_ref().map(|range| range.start).unwrap_or(0);
        let h_block_len = block.h.as_ref().map(|range| range.len()).unwrap_or(0);
        let w_block_start = block.w.as_ref().map(|range| range.start).unwrap_or(0);
        let w_block_len = block.w.as_ref().map(|range| range.len()).unwrap_or(0);
        let h_primary_start = primary.h.as_ref().map(|range| range.start).unwrap_or(0);
        let w_primary_start = primary.w.as_ref().map(|range| range.start).unwrap_or(0);
        let cpu_build_parallel = || -> Vec<f64> {
            let nchunks = n.div_ceil(CHUNK_ROWS);
            (0..nchunks)
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total * p_total],
                    |mut acc, ci| {
                        let lo = ci * CHUNK_ROWS;
                        let hi = (lo + CHUNK_ROWS).min(n);
                        let mut phi: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
                        for row in lo..hi {
                            for col in phi.iter_mut() {
                                col.iter_mut().for_each(|x| *x = 0.0);
                            }
                            let mrow = &marginal[row * p_m..(row + 1) * p_m];
                            let grow = &logslope[row * p_g..(row + 1) * p_g];
                            for k in 0..p_m {
                                phi[0][k] = mrow[k];
                            }
                            for k in 0..p_g {
                                phi[1][p_m + k] = grow[k];
                            }
                            for k in 0..h_block_len {
                                phi[h_primary_start + k][h_block_start + k] = 1.0;
                            }
                            for k in 0..w_block_len {
                                phi[w_primary_start + k][w_block_start + k] = 1.0;
                            }
                            let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
                            for u in 0..r {
                                for v_idx in 0..r {
                                    let huv = hrow[u * r + v_idx];
                                    if huv == 0.0 {
                                        continue;
                                    }
                                    for m in 0..p_total {
                                        let pm = phi[u][m];
                                        if pm == 0.0 {
                                            continue;
                                        }
                                        let scaled = huv * pm;
                                        for nn in 0..p_total {
                                            acc[m * p_total + nn] += scaled * phi[v_idx][nn];
                                        }
                                    }
                                }
                            }
                        }
                        acc
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total * p_total],
                    |mut a, b| {
                        for (ax, bx) in a.iter_mut().zip(b.iter()) {
                            *ax += *bx;
                        }
                        a
                    },
                )
        };
        let warm_cpu = cpu_build_parallel();
        assert_eq!(warm_cpu.len(), p_total * p_total);

        // Correctness guard for the perf gate. The tolerance is scaled by the
        // largest CPU entry because the CPU and GPU summation orders differ.
        let tol = 1e-8 * warm_cpu.iter().fold(1.0_f64, |acc, x| acc.max(x.abs()));
        for (idx, (&cpu, &gpu)) in warm_cpu.iter().zip(gpu_reference.iter()).enumerate() {
            let diff = (cpu - gpu).abs();
            assert!(
                diff <= tol,
                "large-scale dense_block parity[{},{}]: cpu={cpu} gpu={gpu} |Δ|={diff:.3e} > tol={tol:.3e}",
                idx / p_total,
                idx % p_total
            );
        }

        let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
        for _ in 0..iters {
            let t0 = std::time::Instant::now();
            let out = cpu_build_parallel();
            cpu_us.push(t0.elapsed().as_micros());
            assert_eq!(out.len(), p_total * p_total);
        }
        cpu_us.sort_unstable();
        let cpu_median = cpu_us[iters / 2];

        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 10×)"
        );
        assert!(
            speedup >= 10.0,
            "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
             (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
             or prove the kernel is at hardware roofline."
        );
    }
}
5691}