1#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
/// Largest supported primary dimension `r` per row.
///
/// `BmsFlexRowKernelInputs::validate` rejects `r > MAX_R`, and the row kernel's
/// scratch (`F_u[32]`, `F_au[32]`, `F_uv[32 * 32]`, `a_u[32]`, `a_uv[32 * 32]`, ...)
/// is hard-coded to this value, so the two must change together.
pub(crate) const MAX_R: usize = 32;

/// Threads per block for `bms_flex_row_kernel` (the kernel uses one block per row).
///
/// The kernel's block reduction writes `reduce_a[threadIdx.x]` / `reduce_b[threadIdx.x]`
/// (both `__shared__ double[32]`) and halves the stride from `blockDim.x / 2`, so
/// this must be a power of two no larger than 32.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_THREADS: u32 = 32;

/// Trailing width of the per-cell coefficient arrays (`cell_a`, `cell_aa`,
/// `cell_r`, `cell_ar`, `cell_sbb`, `cell_sbh`, `cell_sbw`); the kernel reads
/// them as `ptr + c * 4` and contracts `R[0..4]` against moments.
pub(crate) const COEFF4: usize = 4;

/// Moment entries stored per cell in `cell_moments` (`m_0..m_9`). The kernel
/// reads `m[e + n]` for `e < 4`, `n < 7`, i.e. up to `m_9`.
pub(crate) const MOMENT_STRIDE: usize = 10;
114pub(crate) enum CellMomentsSource<'a> {
119 Host(&'a [f64]),
122 #[cfg(target_os = "linux")]
127 Device(&'a CudaSlice<f64>),
128}
129
130impl<'a> CellMomentsSource<'a> {
131 pub(crate) fn len(&self) -> usize {
133 match self {
134 CellMomentsSource::Host(slice) => slice.len(),
135 #[cfg(target_os = "linux")]
136 CellMomentsSource::Device(d) => d.len(),
137 }
138 }
139}
140
/// Generates the borrowed (`BmsFlexRowKernelInputs<'a>`) and owned
/// (`BmsFlexRowKernelInputsOwned`) row-kernel input structs from a single field
/// list, so the two definitions cannot drift apart.
///
/// * every `f64_fields` entry becomes `&'a [f64]` (borrowed) / `Vec<f64>` (owned);
/// * every `u32_fields` entry becomes `&'a [u32]` / `Vec<u32>`;
/// * `moments_field` becomes `CellMomentsSource<'a>` / `Vec<f64>`; on Linux the
///   owned struct also carries `cell_moments_device`, an optional device copy
///   that `as_borrowed` prefers over the host vector.
///
/// Both structs additionally get the scalars `n_rows`, `r`, `p_h`, `p_w`, `s_f`.
/// `as_borrowed` reads the moments through `$moments_field`, so the macro works
/// for any field name passed in.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        /// Borrowed view of one row-kernel batch, consumed by
        /// `launch_bms_flex_row_kernel`.
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of rows in the batch.
            pub n_rows: usize,
            /// Primary coordinates per row; `validate` requires `2 + p_h + p_w`.
            pub r: usize,
            /// Size of the score-warp (`h`) block, primary slots `2..2 + p_h`.
            pub p_h: usize,
            /// Size of the link-wiggle (`w`) block, primary slots `2 + p_h..r`.
            pub p_w: usize,
            /// Not read on the device (already baked into the cubic coefficients
            /// per the kernel's parameter comment); must be positive and finite.
            pub s_f: f64,
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owned counterpart of `BmsFlexRowKernelInputs`, with the same field meanings.
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            pub $moments_field: Vec<f64>,
            /// Optional device-resident copy of the moments. When `Some`,
            /// `as_borrowed` hands it to the kernel instead of the host vector.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrows every field into a `BmsFlexRowKernelInputs`.
            ///
            /// On Linux a populated `cell_moments_device` takes precedence over
            /// the host copy held in the moments field.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                #[cfg(target_os = "linux")]
                let borrowed_moments = match self.cell_moments_device.as_ref() {
                    Some(d) => CellMomentsSource::Device(d),
                    None => CellMomentsSource::Host(&self.$moments_field),
                };
                #[cfg(not(target_os = "linux"))]
                let borrowed_moments = CellMomentsSource::Host(&self.$moments_field);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: borrowed_moments,
                }
            }
        }
    };
}
218
// Instantiates `BmsFlexRowKernelInputs` / `BmsFlexRowKernelInputsOwned`.
// Expected lengths, as enforced by `BmsFlexRowKernelInputs::validate`:
//   n_rows:              q, b, mu_1, mu_2, z_obs, y, w, chi_obs, xi_obs
//   n_rows * r:          rho_u, tau_u
//   n_rows * r * r:      r_uv
//   n_rows + 1:          cell_offsets (non-decreasing)
// With total_cells = cell_offsets[n_rows]:
//   total_cells:                    cell_c0, cell_c1, cell_c2, cell_c3
//   total_cells * COEFF4:           cell_a, cell_aa, cell_sbb
//   total_cells * (r - 1) * COEFF4: cell_r, cell_ar
//   total_cells * p_h * COEFF4:     cell_sbh
//   total_cells * p_w * COEFF4:     cell_sbw
//   total_cells * MOMENT_STRIDE:    cell_moments
// Field order here fixes the field order of the generated structs.
define_bms_flex_row_kernel_input_types! {
    f64_fields: [
        // `q`, `b` and `z_obs` are uploaded but not read by the current
        // `bms_flex_row_kernel` body.
        q,
        b,
        mu_1,
        mu_2,
        z_obs,
        y,
        w,
        cell_c0,
        cell_c1,
        cell_c2,
        cell_c3,
        cell_a,
        cell_aa,
        cell_r,
        cell_ar,
        cell_sbb,
        cell_sbh,
        cell_sbw,
        chi_obs,
        xi_obs,
        rho_u,
        tau_u,
        r_uv,
    ],
    u32_fields: [cell_offsets],
    moments_field: cell_moments,
}
248
/// Per-row results downloaded from `bms_flex_row_kernel`.
///
/// Rows whose reduced `F_a` is non-finite or `<= 0` are NaN-filled in all three
/// buffers by the kernel's `nan_fill_outputs` path.
#[derive(Debug)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// Per-row `-w * log_cdf` as computed by the kernel; length `n_rows`.
    pub neglog: Vec<f64>,
    /// Gradient w.r.t. the `r` primary coordinates; length `n_rows * r`,
    /// entry `row * r + u`.
    pub grad: Vec<f64>,
    /// Hessian w.r.t. the primary coordinates, written symmetrically; length
    /// `n_rows * r * r`, entry `row * r * r + u * r + v`.
    pub hess: Vec<f64>,
}
260
261impl<'a> BmsFlexRowKernelInputs<'a> {
262 pub(crate) fn validate(&self) -> Result<(), GpuError> {
265 if self.r == 0 {
266 return Err(GpuError::DriverCallFailed {
267 reason: "bms_flex_row inputs: r must be > 0".to_string(),
268 });
269 }
270 if self.r > MAX_R {
271 return Err(GpuError::DriverCallFailed {
272 reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
273 });
274 }
275 if self.r != 2 + self.p_h + self.p_w {
276 return Err(GpuError::DriverCallFailed {
277 reason: format!(
278 "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
279 self.r,
280 self.p_h,
281 self.p_w,
282 2 + self.p_h + self.p_w
283 ),
284 });
285 }
286 let n = self.n_rows;
287 let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
288 if have != want {
289 return Err(GpuError::DriverCallFailed {
290 reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
291 });
292 }
293 Ok(())
294 };
295 check_len("q", self.q.len(), n)?;
296 check_len("b", self.b.len(), n)?;
297 check_len("mu_1", self.mu_1.len(), n)?;
298 check_len("mu_2", self.mu_2.len(), n)?;
299 check_len("z_obs", self.z_obs.len(), n)?;
300 check_len("y", self.y.len(), n)?;
301 check_len("w", self.w.len(), n)?;
302 check_len("chi_obs", self.chi_obs.len(), n)?;
303 check_len("xi_obs", self.xi_obs.len(), n)?;
304 check_len("rho_u", self.rho_u.len(), n * self.r)?;
305 check_len("tau_u", self.tau_u.len(), n * self.r)?;
306 check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
307 check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
308 let total_cells_u32 = self.cell_offsets[n];
309 let total_cells = total_cells_u32 as usize;
310 check_len("cell_c0", self.cell_c0.len(), total_cells)?;
311 check_len("cell_c1", self.cell_c1.len(), total_cells)?;
312 check_len("cell_c2", self.cell_c2.len(), total_cells)?;
313 check_len("cell_c3", self.cell_c3.len(), total_cells)?;
314 check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
315 check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
316 check_len(
317 "cell_r",
318 self.cell_r.len(),
319 total_cells * self.r.saturating_sub(1) * COEFF4,
320 )?;
321 check_len(
322 "cell_ar",
323 self.cell_ar.len(),
324 total_cells * self.r.saturating_sub(1) * COEFF4,
325 )?;
326 check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
327 check_len(
328 "cell_sbh",
329 self.cell_sbh.len(),
330 total_cells * self.p_h * COEFF4,
331 )?;
332 check_len(
333 "cell_sbw",
334 self.cell_sbw.len(),
335 total_cells * self.p_w * COEFF4,
336 )?;
337 check_len(
338 "cell_moments",
339 self.cell_moments.len(),
340 total_cells * MOMENT_STRIDE,
341 )?;
342 for i in 0..n {
348 if self.cell_offsets[i] > self.cell_offsets[i + 1] {
349 return Err(GpuError::DriverCallFailed {
350 reason: format!(
351 "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
352 i,
353 self.cell_offsets[i],
354 i + 1,
355 self.cell_offsets[i + 1]
356 ),
357 });
358 }
359 }
360 Ok(())
361 }
362}
363
#[cfg(target_os = "linux")]
/// CUDA C source of `bms_flex_row_kernel`, compiled at runtime by
/// `RowKernelBackend::probe` after `gam_gpu::numerics_device::PROBIT_NUMERICS_CU`,
/// which must define the `log_ndtr_and_mills` device function called below.
///
/// Launch contract (see `launch_linux`): one block per row and
/// `blockDim.x == ROW_KERNEL_THREADS`. Threads split the row's cells and
/// accumulate `F_a`/`F_aa` through a shared-memory tree reduction, which needs a
/// power-of-two `blockDim.x <= 32`. They accumulate `F_u`/`F_au`/`F_uv` through
/// the CAS-based `atomic_add_f64` on shared scratch. Thread 0 then does the IFT,
/// observed-jet and probit steps and writes this row's outputs. Scratch is sized
/// for `r <= 32` (`MAX_R`).
///
/// Maintainer notes:
/// * The order of the shared-memory atomic additions is not fixed, so results
///   may differ in the last bits between runs.
/// * `s_f`, `row_q`, `row_b` and `row_zobs` are accepted but not read by this body.
/// * Rows with non-finite or non-positive reduced `F_a` are NaN-filled; the
///   in-kernel comment says the host then falls back to CPU for that row.
/// * NOTE(review): the observed predictor value is taken from `bar_e_u[0]`
///   (`chi * a_0 + rho_0`), which the in-kernel comment itself describes as
///   provisional pending host wiring; confirm against the CPU row evaluator.
/// * All text in the raw string, including the CUDA comments, is compiled
///   source; editing it changes what NVRTC sees.
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
// ::compute_row_analytic_flex_from_parts_into.

#define INV_TWO_PI 0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    out_neglog[row] = nan_value;
    for (int u = 0; u < r; ++u) {
        out_grad[row * r + u] = nan_value;
    }
    int rr = r * r;
    for (int idx = 0; idx < rr; ++idx) {
        out_hess[row * rr + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int n_rows,
    int r,
    int p_h,
    int p_w,
    double s_f, // currently unused on device:
                // host has already baked S_f
                // into the cubic coefficients.
                // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a, // [n_cells, 4]
    const double * __restrict__ cell_aa, // [n_cells, 4]
    const double * __restrict__ cell_r, // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar, // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb, // [n_cells, 4]
    const double * __restrict__ cell_sbh, // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw, // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments, // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho, // [n_rows, r]
    const double * __restrict__ row_tau, // [n_rows, r]
    const double * __restrict__ row_ruv, // [n_rows, r*r]
    double * __restrict__ out_neglog,
    double * __restrict__ out_grad,
    double * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    // F_u [r]
    // F_au [r]
    // F_uv [r*r]
    // bar_e_u [r]
    // bar_e_uv [r*r]
    // reduce_a [blockDim.x]
    // reduce_b [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u] = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        // `cell_second_derivative_from_moments` after folding the
        // cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S) \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1] \
            + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2] \
            + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3] \
            + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4] \
            + (R[2]*S[3] + R[3]*S[2])*T[5] \
            + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c = cell_a + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        // = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R = D_OF(R_u);
            double d_AR = D_OF(AR_u);
            double q_AR = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in score-warp block, or symmetric.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in link-wiggle block, or symmetric.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror: u in (h or w) block, v == 1 cannot happen
                // because we iterate v >= u; skip.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared.
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    // F_q = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0] = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    // a_u = -F_u · inv_Fa (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    // a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // The q-row contributions (u==0 or v==0) collapse to a_uv = mu_2 · inv_Fa
    // when both are 0 and to (F_au_v) · inv_Fa-style mixed shape otherwise.
    // We compute it uniformly using the populated F_uv / F_au with the
    // q-overrides above.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                + F_au[v] * a_u[u]
                + F_au[u] * a_u[v]
                + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    // bar_e_u = chi · a_u + rho_u.
    // bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                + xi * a_u[u] * a_u[v]
                + tau[u] * a_u[v]
                + a_u[u] * tau[v]
                + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y = row_y[row];
    double w = row_w[row];
    double s = 2.0 * y - 1.0;
    // The "observed predictor" e_obs is the value (degree-0) term of the
    // observed jet — same convention as the CPU path. CPU parity:
    // `e_obs = chi · a_0 + rho_0`... well, no: `bar_e_u` is the *first*
    // derivative jet, not the value. The observed predictor value comes
    // from the host pre-evaluation as `rho_u[0]` of the value jet —
    // pre-baked into the host's `m = s · e_obs` payload. For Stage 2 we
    // expose it via the `bar_e_u[0]` slot which is `chi·a_0 + rho_0`; the
    // host wiring lands in the dispatcher wave that bridges this kernel
    // and the row evaluator in `bernoulli_marginal_slope.rs`.
    double e_obs = bar_e_u[0];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i = w * lambda * (m_arg + lambda);

    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        out_grad[row * r + u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            out_hess[row * r * r + u * r + v] = val;
            if (u != v) {
                out_hess[row * r * r + v * r + u] = val;
            }
        }
    }
}
"#;
725
726#[inline]
733pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
734 inputs.s_f.is_finite() && inputs.s_f > 0.0
735}
736
737#[cfg(target_os = "linux")]
738pub(crate) struct RowKernelBackend {
739 pub(crate) stream: Arc<CudaStream>,
740 pub(crate) module: Arc<CudaModule>,
741}
742
743#[cfg(target_os = "linux")]
744impl RowKernelBackend {
745 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
746 static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
747 BACKEND
748 .get_or_init(|| {
749 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
750 let row_kernel_source = [
751 gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
752 ROW_KERNEL_BODY,
753 ]
754 .concat();
755 let ptx = gam_gpu::device_cache::compile_ptx_arch(&row_kernel_source)
766 .map_err(|err| GpuError::DriverCallFailed {
767 reason: format!("bms_flex_row NVRTC compile failed: {err}"),
768 })?;
769 let module =
770 parts
771 .ctx
772 .load_module(ptx)
773 .map_err(|err| GpuError::DriverCallFailed {
774 reason: format!("bms_flex_row module load failed: {err}"),
775 })?;
776 Ok(RowKernelBackend {
777 stream: parts.stream.clone(),
778 module,
779 })
780 })
781 })
782 .as_ref()
783 .map_err(GpuError::clone)
784 }
785}
786
787pub(crate) fn launch_bms_flex_row_kernel(
792 inputs: BmsFlexRowKernelInputs<'_>,
793) -> Result<BmsFlexRowKernelOutputs, GpuError> {
794 inputs.validate()?;
795 if !s_f_diagnostic_finite(&inputs) {
796 return Err(GpuError::DriverCallFailed {
797 reason: format!(
798 "bms_flex_row inputs: s_f must be positive and finite, got {}",
799 inputs.s_f
800 ),
801 });
802 }
803
804 #[cfg(target_os = "linux")]
805 {
806 launch_linux(inputs)
807 }
808 #[cfg(not(target_os = "linux"))]
809 {
810 Err(GpuError::DriverLibraryUnavailable {
811 reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
812 })
813 }
814}
815
816#[cfg(target_os = "linux")]
817pub(crate) fn launch_linux(
818 inputs: BmsFlexRowKernelInputs<'_>,
819) -> Result<BmsFlexRowKernelOutputs, GpuError> {
820 let backend = RowKernelBackend::probe()?;
821 let stream = &backend.stream;
822
823 let upload_f64 = |slice: &[f64], label: &str| {
824 stream
825 .clone_htod(slice)
826 .map_err(|err| GpuError::DriverCallFailed {
827 reason: format!("bms_flex_row upload {label}: {err}"),
828 })
829 };
830 let upload_u32 = |slice: &[u32], label: &str| {
831 stream
832 .clone_htod(slice)
833 .map_err(|err| GpuError::DriverCallFailed {
834 reason: format!("bms_flex_row upload {label}: {err}"),
835 })
836 };
837
838 let d_q = upload_f64(inputs.q, "q")?;
839 let d_b = upload_f64(inputs.b, "b")?;
840 let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
841 let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
842 let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
843 let d_y = upload_f64(inputs.y, "y")?;
844 let d_w = upload_f64(inputs.w, "w")?;
845 let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
846 let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
847 let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
848 let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
849 let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
850 let d_a = upload_f64(inputs.cell_a, "cell_a")?;
851 let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
852 let d_r = upload_f64(inputs.cell_r, "cell_r")?;
853 let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
854 let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
855 let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
856 let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
857 let owned_host_moments: CudaSlice<f64>;
861 let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
862 CellMomentsSource::Host(slice) => {
863 owned_host_moments = upload_f64(slice, "cell_moments")?;
864 &owned_host_moments
865 }
866 CellMomentsSource::Device(d) => *d,
867 };
868 let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
869 let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
870 let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
871 let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
872 let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
873
874 let n = inputs.n_rows;
875 let r = inputs.r;
876 let mut d_neglog = stream
877 .alloc_zeros::<f64>(n)
878 .map_err(|err| GpuError::DriverCallFailed {
879 reason: format!("bms_flex_row alloc neglog: {err}"),
880 })?;
881 let mut d_grad =
882 stream
883 .alloc_zeros::<f64>(n * r)
884 .map_err(|err| GpuError::DriverCallFailed {
885 reason: format!("bms_flex_row alloc grad: {err}"),
886 })?;
887 let mut d_hess =
888 stream
889 .alloc_zeros::<f64>(n * r * r)
890 .map_err(|err| GpuError::DriverCallFailed {
891 reason: format!("bms_flex_row alloc hess: {err}"),
892 })?;
893
894 let func = backend
895 .module
896 .load_function("bms_flex_row_kernel")
897 .map_err(|err| GpuError::DriverCallFailed {
898 reason: format!("bms_flex_row load_function: {err}"),
899 })?;
900
901 let cfg = LaunchConfig {
902 grid_dim: (n as u32, 1, 1),
903 block_dim: (ROW_KERNEL_THREADS, 1, 1),
904 shared_mem_bytes: 0,
905 };
906 let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
907 reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
908 })?;
909 let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
910 reason: format!("bms_flex_row: r={r} exceeds i32 range"),
911 })?;
912 let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
913 reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
914 })?;
915 let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
916 reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
917 })?;
918 let s_f = inputs.s_f;
919
920 let mut builder = stream.launch_builder(&func);
921 builder
922 .arg(&n_i32)
923 .arg(&r_i32)
924 .arg(&p_h_i32)
925 .arg(&p_w_i32)
926 .arg(&s_f)
927 .arg(&d_q)
928 .arg(&d_b)
929 .arg(&d_mu1)
930 .arg(&d_mu2)
931 .arg(&d_zobs)
932 .arg(&d_y)
933 .arg(&d_w)
934 .arg(&d_offsets)
935 .arg(&d_c0)
936 .arg(&d_c1)
937 .arg(&d_c2)
938 .arg(&d_c3)
939 .arg(&d_a)
940 .arg(&d_aa)
941 .arg(&d_r)
942 .arg(&d_ar)
943 .arg(&d_sbb)
944 .arg(&d_sbh)
945 .arg(&d_sbw)
946 .arg(d_moments_ref)
947 .arg(&d_chi)
948 .arg(&d_xi)
949 .arg(&d_rho)
950 .arg(&d_tau)
951 .arg(&d_ruv)
952 .arg(&mut d_neglog)
953 .arg(&mut d_grad)
954 .arg(&mut d_hess);
955
956 unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
963 reason: format!("bms_flex_row launch: {err}"),
964 })?;
965 stream
966 .synchronize()
967 .map_err(|err| GpuError::DriverCallFailed {
968 reason: format!("bms_flex_row synchronize: {err}"),
969 })?;
970
971 let neglog = stream
972 .clone_dtoh(&d_neglog)
973 .map_err(|err| GpuError::DriverCallFailed {
974 reason: format!("bms_flex_row download neglog: {err}"),
975 })?;
976 let grad = stream
977 .clone_dtoh(&d_grad)
978 .map_err(|err| GpuError::DriverCallFailed {
979 reason: format!("bms_flex_row download grad: {err}"),
980 })?;
981 let hess = stream
982 .clone_dtoh(&d_hess)
983 .map_err(|err| GpuError::DriverCallFailed {
984 reason: format!("bms_flex_row download hess: {err}"),
985 })?;
986
987 Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
988}
989
#[cfg(target_os = "linux")]
/// Layout of the joint coefficient vector (`p_total` entries).
///
/// In `HVP_KERNEL_SOURCE` the marginal coefficients occupy `v[0..p_m]` and the
/// log-slope coefficients occupy `v[p_m..p_m + p_g]`. NOTE(review): the optional `h`/`w`
/// ranges presumably supply the kernels' `h_block_start`/`h_block_len` and
/// `w_block_start`/`w_block_len`; confirm at the launch site.
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Width of the marginal design (columns of `marginal_design`).
    pub p_m: usize,
    /// Width of the log-slope design (columns of `logslope_design`).
    pub p_g: usize,
    /// Joint-vector range of the `h` block, if present.
    pub h: Option<std::ops::Range<usize>>,
    /// Joint-vector range of the `w` block, if present.
    pub w: Option<std::ops::Range<usize>>,
    /// Total length of the joint coefficient vector.
    pub p_total: usize,
}
1050
#[cfg(target_os = "linux")]
/// Layout of one row's `r` primary coordinates.
///
/// In the HVP kernels `row_dir[0]`/`row_dir[1]` carry the marginal/log-slope
/// directions, and the `h`/`w` blocks start at `h_primary_start`/`w_primary_start`.
/// NOTE(review): the ranges here presumably supply those starts; confirm at the
/// launch site.
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexPrimaryLayout {
    /// Primary-coordinate range of the `h` block, if present.
    pub h: Option<std::ops::Range<usize>>,
    /// Primary-coordinate range of the `w` block, if present.
    pub w: Option<std::ops::Range<usize>>,
    /// Number of primary coordinates per row.
    pub r: usize,
}
1060
#[cfg(target_os = "linux")]
/// Rows per chunk for the HVP partial kernels; `num_hvp_chunks` divides by it.
/// NOTE(review): presumably passed as the kernels' `rows_per_cta` argument.
pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;

#[cfg(target_os = "linux")]
/// NOTE(review): presumably the block size for the HVP partial kernels. Those
/// kernels index `__shared__ double dot_reduce[128]` by `threadIdx.x` and
/// tree-reduce by halving `blockDim.x`, so any block size used there must be a
/// power of two no larger than 128.
pub(crate) const HVP_THREADS: u32 = 128;

#[cfg(target_os = "linux")]
/// NOTE(review): presumably the block size for `bms_flex_row_hvp_reduce`
/// (one thread per output coefficient); confirm at the launch site.
pub(crate) const REDUCTION_THREADS: u32 = 256;

#[cfg(target_os = "linux")]
/// Maximum right-hand sides per multi-RHS HVP launch. Must stay equal to
/// `#define MAX_MULTI_RHS 8` in `HVP_KERNEL_SOURCE`, which sizes the kernel's
/// shared `row_dir`/`action` scratch.
pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1086
#[cfg(target_os = "linux")]
/// Device-resident per-row Hessians together with the marginal and log-slope
/// design matrices. NOTE(review): presumably consumed by the HVP kernels in
/// `HVP_KERNEL_SOURCE`; confirm at the launch site.
pub struct DeviceResidentRowHess {
    /// Per-row Hessians; presumably `[n, r * r]` row-major, matching the HVP
    /// kernels' `row_hessians` argument.
    pub(crate) hess: CudaSlice<f64>,
    /// Presumably the `[n, p_m]` row-major marginal design.
    pub(crate) marginal_design: CudaSlice<f64>,
    /// Presumably the `[n, p_g]` row-major log-slope design.
    pub(crate) logslope_design: CudaSlice<f64>,
    /// Number of rows.
    pub(crate) n: usize,
    /// Primary coordinates per row.
    pub(crate) r: usize,
    /// Layout of the joint coefficient vector.
    pub(crate) block: BmsFlexBlockLayout,
    /// Layout of each row's primary coordinates.
    pub(crate) primary: BmsFlexPrimaryLayout,
    /// Byte count reported by `Debug`; presumably the device memory held by
    /// the buffers above. TODO confirm.
    pub(crate) bytes: u64,
}
1122
1123#[cfg(target_os = "linux")]
1124impl std::fmt::Debug for DeviceResidentRowHess {
1125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1126 f.debug_struct("DeviceResidentRowHess")
1127 .field("n", &self.n)
1128 .field("r", &self.r)
1129 .field("p_total", &self.block.p_total)
1130 .field("bytes", &self.bytes)
1131 .finish()
1132 }
1133}
1134
1135#[cfg(target_os = "linux")]
1138pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1139 n.div_ceil(HVP_ROWS_PER_CTA as usize)
1140}
1141
1142#[cfg(target_os = "linux")]
1145pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1146// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1147// in this module.
1148
1149#define MAX_MULTI_RHS 8
1150
1151extern "C" __global__ void bms_flex_row_hvp_partial(
1152 int n_rows,
1153 int r,
1154 int p_m,
1155 int p_g,
1156 int p_total,
1157 int h_block_start,
1158 int h_block_len,
1159 int w_block_start,
1160 int w_block_len,
1161 int h_primary_start,
1162 int w_primary_start,
1163 int rows_per_cta,
1164 const double * __restrict__ row_hessians, // [n, r*r]
1165 const double * __restrict__ marginal_design, // [n, p_m] row-major
1166 const double * __restrict__ logslope_design, // [n, p_g] row-major
1167 const double * __restrict__ v, // [p_total]
1168 double * __restrict__ partial) // [num_chunks, p_total]
1169{
1170 int chunk = blockIdx.x;
1171 int tid = threadIdx.x;
1172 int row_lo = chunk * rows_per_cta;
1173 int row_hi = row_lo + rows_per_cta;
1174 if (row_hi > n_rows) row_hi = n_rows;
1175
1176 // Zero this chunk's partial slice cooperatively.
1177 double *out = partial + (size_t)chunk * (size_t)p_total;
1178 for (int j = tid; j < p_total; j += blockDim.x) {
1179 out[j] = 0.0;
1180 }
1181 __syncthreads();
1182
1183 // Each thread serially processes a stride-of-blockDim set of rows so
1184 // every write to `out[..]` happens from one thread → no atomics within
1185 // the chunk. To keep writes race-free across threads of the same chunk,
1186 // we serialize the cross-row accumulation through a per-row barrier:
1187 // thread 0 of the block processes all rows in the chunk. The per-row
1188 // work is dominated by the dot/axpy over `p_m + p_g`, which is large.
1189 // For Stage 3 we ship the simple, correct path (thread 0 sequential
1190 // per row, blockDim.x threads parallel within a row's dot/axpy).
1191 __shared__ double row_dir[32];
1192 __shared__ double action[32];
1193 __shared__ double dot_reduce[128];
1194
1195 for (int row = row_lo; row < row_hi; ++row) {
1196 const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1197 const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1198 const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1199
1200 // row_dir[0] = mrow · v[0..p_m]
1201 double local = 0.0;
1202 for (int j = tid; j < p_m; j += blockDim.x) {
1203 local += mrow[j] * v[j];
1204 }
1205 dot_reduce[tid] = local;
1206 __syncthreads();
1207 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1208 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1209 __syncthreads();
1210 }
1211 if (tid == 0) row_dir[0] = dot_reduce[0];
1212
1213 // row_dir[1] = grow · v[p_m..p_m+p_g]
1214 local = 0.0;
1215 for (int j = tid; j < p_g; j += blockDim.x) {
1216 local += grow[j] * v[p_m + j];
1217 }
1218 dot_reduce[tid] = local;
1219 __syncthreads();
1220 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1221 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1222 __syncthreads();
1223 }
1224 if (tid == 0) row_dir[1] = dot_reduce[0];
1225
1226 // h/w blocks: direct copy.
1227 if (tid == 0) {
1228 for (int k = 0; k < h_block_len; ++k) {
1229 row_dir[h_primary_start + k] = v[h_block_start + k];
1230 }
1231 for (int k = 0; k < w_block_len; ++k) {
1232 row_dir[w_primary_start + k] = v[w_block_start + k];
1233 }
1234 }
1235 __syncthreads();
1236
1237 // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
1238 if (tid < r) {
1239 double acc = 0.0;
1240 for (int vv = 0; vv < r; ++vv) {
1241 acc += Hrow[tid * r + vv] * row_dir[vv];
1242 }
1243 action[tid] = acc;
1244 }
1245 __syncthreads();
1246
1247 // Pull back into joint β slot.
1248 // marginal: out[j] += action[0] · mrow[j] (parallel j)
1249 double a0 = action[0];
1250 for (int j = tid; j < p_m; j += blockDim.x) {
1251 out[j] += a0 * mrow[j];
1252 }
1253 double a1 = action[1];
1254 for (int j = tid; j < p_g; j += blockDim.x) {
1255 out[p_m + j] += a1 * grow[j];
1256 }
1257 if (tid == 0) {
1258 for (int k = 0; k < h_block_len; ++k) {
1259 out[h_block_start + k] += action[h_primary_start + k];
1260 }
1261 for (int k = 0; k < w_block_len; ++k) {
1262 out[w_block_start + k] += action[w_primary_start + k];
1263 }
1264 }
1265 __syncthreads();
1266 }
1267}
1268
1269extern "C" __global__ void bms_flex_row_hvp_reduce(
1270 int num_chunks,
1271 int p_total,
1272 const double * __restrict__ partial, // [num_chunks, p_total]
1273 double * __restrict__ out) // [p_total]
1274{
1275 int j = blockIdx.x * blockDim.x + threadIdx.x;
1276 if (j >= p_total) return;
1277 double acc = 0.0;
1278 for (int c = 0; c < num_chunks; ++c) {
1279 acc += partial[(size_t)c * (size_t)p_total + (size_t)j];
1280 }
1281 out[j] = acc;
1282}
1283
// Multi-RHS variant of the chunked row HVP partial kernel.
//
// Each CTA owns rows [chunk*rows_per_cta, min((chunk+1)*rows_per_cta,
// n_rows)) and, for every right-hand side rhs in [0, rhs_count), adds
//   sum_i P_i^T H_i P_i v_rhs
// over those rows into partial[rhs, chunk, :]. The chunk axis is summed
// afterwards, in fixed order, by bms_flex_row_hvp_multi_reduce.
//
// P_i v projects a joint vector (length p_total) onto the r row primaries:
//   primary 0                   = mrow . v[0 .. p_m)
//   primary 1                   = grow . v[p_m .. p_m + p_g)
//   primary h_primary_start + k = v[h_block_start + k]   (k < h_block_len)
//   primary w_primary_start + k = v[w_block_start + k]   (k < w_block_len)
// and P_i^T scatters a primary-space vector back the same way.
//
// Not checked here (presumably enforced by the launcher -- confirm):
//   * r <= 32, since each rhs gets a 32-wide slot in row_dir / action;
//   * rhs_count <= MAX_MULTI_RHS;
//   * blockDim.x is a power of two <= 128 (dot_reduce size, halving tree);
//   * primaries {0, 1} U h U w cover 0..r-1 -- any other row_dir slot is
//     never written and would feed stale shared memory into action.
extern "C" __global__ void bms_flex_row_hvp_multi_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    int rhs_count,
    const double * __restrict__ row_hessians,      // [n, r*r]
    const double * __restrict__ marginal_design,   // [n, p_m]
    const double * __restrict__ logslope_design,   // [n, p_g]
    const double * __restrict__ v_rhs,             // [rhs_count, p_total]
    double * __restrict__ partial)                 // [rhs_count, num_chunks, p_total]
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // ceil(n_rows / rows_per_cta): must equal the chunk count the host used
    // to size `partial` -- confirm against the launcher.
    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
    // Zero this CTA's slice of `partial` for every rhs before accumulating.
    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
        int rhs = idx / p_total;
        int j = idx - rhs * p_total;
        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
    }
    __syncthreads();

    // Per-rhs primary-space vectors live at slot rhs*32 + u.
    __shared__ double row_dir[MAX_MULTI_RHS * 32];   // P_i v_rhs
    __shared__ double action[MAX_MULTI_RHS * 32];    // H_i P_i v_rhs
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;

        // Stage 1: row_dir[rhs] = P_i v_rhs for every rhs.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;

            // Block-wide tree reduction of mrow . v[0 .. p_m).
            double local = 0.0;
            for (int j = tid; j < p_m; j += blockDim.x) {
                local += mrow[j] * v[j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];

            // Same reduction for grow . v[p_m .. p_m + p_g). Reusing
            // dot_reduce is race-free: only thread 0 reads dot_reduce[0]
            // above, and it does so before its own write of dot_reduce[0].
            local = 0.0;
            for (int j = tid; j < p_g; j += blockDim.x) {
                local += grow[j] * v[p_m + j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) {
                row_dir[rhs * 32 + 1] = dot_reduce[0];
                // h / w primaries are direct copies of the matching joint
                // coefficients.
                for (int k = 0; k < h_block_len; ++k) {
                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
                }
            }
            __syncthreads();
        }

        // Stage 2: action[rhs] = H_i row_dir[rhs]; H_i is dense row-major
        // r x r. The (rhs, u) pairs are strided across the block.
        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
            int rhs = idx / r;
            int u = idx - rhs * r;
            double acc = 0.0;
            const double *dir = row_dir + rhs * 32;
            for (int vv = 0; vv < r; ++vv) {
                acc += Hrow[u * r + vv] * dir[vv];
            }
            action[rhs * 32 + u] = acc;
        }
        __syncthreads();

        // Stage 3: partial[rhs, chunk, :] += P_i^T action[rhs].
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
            double a0 = action[rhs * 32 + 0];
            for (int j = tid; j < p_m; j += blockDim.x) {
                out[j] += a0 * mrow[j];
            }
            double a1 = action[rhs * 32 + 1];
            for (int j = tid; j < p_g; j += blockDim.x) {
                out[p_m + j] += a1 * grow[j];
            }
            // Scalar h / w scatter done by a single thread.
            if (tid == 0) {
                for (int k = 0; k < h_block_len; ++k) {
                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
                }
            }
            // Block-wide barrier after each rhs's scatter.
            __syncthreads();
        }
    }
}
1398
// Reduction stage for the multi-RHS HVP: for each (rhs, col) pair,
// out[rhs, col] = sum over c = 0 .. num_chunks-1 of partial[rhs, c, col],
// accumulated in ascending chunk order from 0.0. One thread per flattened
// (rhs, col) entry; threads past rhs_count * p_total do nothing.
extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
    int num_chunks,
    int p_total,
    int rhs_count,
    const double * __restrict__ partial,   // [rhs_count, num_chunks, p_total]
    double * __restrict__ out)             // [rhs_count, p_total]
{
    const int flat = blockIdx.x * blockDim.x + threadIdx.x;
    if (flat >= rhs_count * p_total) return;

    const int k = flat / p_total;     // right-hand-side index
    const int col = flat % p_total;   // coefficient within that rhs
    const size_t stride = (size_t)p_total;

    size_t off = (size_t)k * (size_t)num_chunks * stride + (size_t)col;
    double sum = 0.0;
    for (int c = 0; c < num_chunks; ++c) {
        sum += partial[off];
        off += stride;
    }
    out[(size_t)k * stride + (size_t)col] = sum;
}
1417
// Per-chunk partial of the joint-Hessian diagonal
//   diag(sum_i P_i^T H_i P_i)
// using the same row chunking as the HVP partial kernels: CTA `chunk` owns
// rows [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)) and writes
// partial[chunk, 0 .. p_total). The chunk axis is reduced by
// bms_flex_row_hvp_reduce, launched from run_bms_flex_row_partial_reduce.
//
// Per row i (H_i dense row-major r x r, X = marginal row, G = logslope row):
//   out[j]                 += H_i[0,0] * X_ij^2   (j < p_m)
//   out[p_m + j]           += H_i[1,1] * G_ij^2   (j < p_g)
//   out[h_block_start + k] += H_i[q,q], q = h_primary_start + k
//   out[w_block_start + k] += H_i[q,q], q = w_primary_start + k
// This is exact only if the marginal, logslope, h and w coefficient ranges
// are disjoint (so no cross terms land on the diagonal) -- presumably
// guaranteed by the block layout; confirm.
extern "C" __global__ void bms_flex_row_diag_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // This CTA's [p_total] slice of partial; zeroed before accumulation.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
        double h00 = Hrow[0];
        double h11 = Hrow[1 * r + 1];
        for (int j = tid; j < p_m; j += blockDim.x) {
            double v = mrow[j];
            out[j] += h00 * v * v;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            double v = grow[j];
            out[p_m + j] += h11 * v * v;
        }
        // h / w coefficients map 1:1 to primaries: add H_i's own diagonal.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                int ii = h_primary_start + k;
                out[h_block_start + k] += Hrow[ii * r + ii];
            }
            for (int k = 0; k < w_block_len; ++k) {
                int ii = w_primary_start + k;
                out[w_block_start + k] += Hrow[ii * r + ii];
            }
        }
        __syncthreads();
    }
}
1475
1476// ────────────────────────────────────────────────────────────────────────
1477// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1478// row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1479// indexed as
//     packed[(u*(2*r - u + 1))/2 + (v - u)]   for u <= v
1481// with symmetric mirror for v < u.
1482// ────────────────────────────────────────────────────────────────────────
1483
// Helper: packed-upper index for (u, v) within a single row of r*(r+1)/2
// doubles. Caller must pre-swap so that u <= v.
//
// Rows 0 .. u-1 of the upper triangle hold r + (r-1) + ... + (r-u+1)
// = u*(2r - u + 1)/2 entries, so row u starts there and
//   idx(u, v) = u*(2r - u + 1)/2 + (v - u).
// This is the layout bms_flex_row_pack_upper writes, e.g. for r = 3:
//   (0,0)(0,1)(0,2)(1,1)(1,2)(2,2) -> 0 1 2 3 4 5.
// u*(2r - u + 1) is always even, so the integer division is exact.
// Marked __host__ __device__ so host-side code and tests can share the
// formula; device callers are unaffected.
__host__ __device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
    return (u * (2 * r - u + 1)) / 2 + (v - u);
}
1490
// Pack one row of the full row-major r x r Hessian into packed-upper layout.
// Launched as one CTA per row (gridDim.x = n_rows, blockDim.x configurable).
// Bit-equal copy: each upper-triangle entry is read once from the dense
// source and written once to the packed destination. Only entries with
// u <= v are read; the lower triangle is assumed to mirror them.
extern "C" __global__ void bms_flex_row_pack_upper(
    int n_rows,
    int r,
    const double * __restrict__ src_full,    // [n, r*r]
    double * __restrict__ dst_packed)        // [n, r*(r+1)/2]
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;
    int per_row = r * (r + 1) / 2;
    const double *src = src_full + (size_t)row * (size_t)r * (size_t)r;
    double *dst = dst_packed + (size_t)row * (size_t)per_row;
    // Linear scan over packed positions; map each back to (u, v).
    for (int pos = tid; pos < per_row; pos += blockDim.x) {
        // Invert: row u of the upper triangle occupies packed positions
        // [u_start, u_start + (r - u)), where u_start = u*(2r - u + 1)/2.
        // Walk the rows, accumulating u_start, until pos falls inside row
        // u -- an O(r) linear scan, cheap since r <= 32.
        int u = 0;
        int u_start = 0;
        while (u < r) {
            int next = u_start + (r - u);
            if (pos < next) break;
            u_start = next;
            ++u;
        }
        int v = u + (pos - u_start);
        dst[pos] = src[(size_t)u * (size_t)r + (size_t)v];
    }
}
1525
// SymmetricPackedUpper variant of the chunked row HVP partial kernel.
// Each CTA owns rows [chunk*rows_per_cta, min((chunk+1)*rows_per_cta,
// n_rows)) and accumulates sum_i P_i^T H_i P_i v over them into
// partial[chunk, :]. H_i is read from packed-upper storage (r*(r+1)/2
// doubles per row) instead of the dense r x r block.
//
// H_i[u, w] is fetched as packed[bms_flex_packed_idx(min(u,w), max(u,w), r)],
// so that helper must index exactly the layout bms_flex_row_pack_upper
// writes (row u of the upper triangle starting at u*(2r - u + 1)/2).
//
// Not checked here (presumably enforced by the launcher -- confirm):
// r <= 32 (row_dir / action size); r <= blockDim.x (action[u] is computed
// by thread u); blockDim.x a power of two <= 128 (dot_reduce size and the
// halving tree); primaries {0, 1} U h U w cover 0..r-1 (any other row_dir
// slot is never written).
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed,  // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    const double * __restrict__ v,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int per_row = r * (r + 1) / 2;
    // This CTA's [p_total] slice of partial; zeroed before accumulation.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    __shared__ double row_dir[32];     // P_i v
    __shared__ double action[32];      // H_i P_i v
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // row_dir[0] = mrow · v[0..p_m]
        double local = 0.0;
        for (int j = tid; j < p_m; j += blockDim.x) {
            local += mrow[j] * v[j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[0] = dot_reduce[0];

        // row_dir[1] = grow · v[p_m..p_m+p_g]
        // (thread 0 already read dot_reduce[0] above before overwriting it)
        local = 0.0;
        for (int j = tid; j < p_g; j += blockDim.x) {
            local += grow[j] * v[p_m + j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[1] = dot_reduce[0];

        // h / w primaries: direct copies of the matching joint coefficients.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        __syncthreads();

        // action[u] = Σ_w H[u, w] · row_dir[w], where H[u, w] reads from
        // packed-upper with (uu, vv) = (min(u, w), max(u, w)).
        if (tid < r) {
            double acc = 0.0;
            int u = tid;
            for (int w = 0; w < r; ++w) {
                int uu = u < w ? u : w;
                int vv = u < w ? w : u;
                acc += Hrow[bms_flex_packed_idx(uu, vv, r)] * row_dir[w];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Pull back: out += P_i^T action.
        double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        __syncthreads();
    }
}
1636
1637// ────────────────────────────────────────────────────────────────────────
1638// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
1639// route. Materialises the full `[p_total, p_total]` row-major joint H
1640// from the per-row r×r Hessian via the P_i pullback. NOT the default
1641// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
1642// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
1643// caller genuinely needs the dense matrix on the device.
1644//
1645// Per-CTA partial: each CTA owns a contiguous chunk of rows
1646// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
1647// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
1648// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
1649// chunk-major-fixed-order into a single `[p_total, p_total]` output.
1650//
1651// Math: for primary index u ∈ [0, r):
1652// * u = 0: phi_u = (X_i in slot 0..p_m, 0 elsewhere)
1653// * u = 1: phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
1654// * u = 2+j: phi_u = e_{h_block_start + j} (j ∈ 0..h_block_len)
1655// * u = 2+h+l: phi_u = e_{w_block_start + l} (l ∈ 0..w_block_len)
1656// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
1657//
// Shared-memory budget: at large-scale shape p_total = 44, a [44, 44] f64
// partial is 44*44*8 = 15,488 B (~15.1 KiB) — well below the 48 KiB
// default per-block dynamic shared-memory limit (V100 and later, without
// an explicit opt-in). The limit is hit just below p_total = 80: 80*80*8 =
// 50 KiB does NOT fit, and the largest size that fits is p_total = 78
// (~47.5 KiB). The caller must therefore enforce p_total ≤ DENSE_BLOCK_MAX_P;
// the launcher rejects oversize p_total cleanly.
1663
// Describes the non-zero support of phi_u (column u of P_i) for one row.
// Returns false when primary u maps to no joint coefficient (phi_u = 0),
// so the caller can skip it. Otherwise *off / *len give the contiguous
// joint-index range touched by phi_u. *src points at the design-row values
// for u = 0 (marginal row) or u = 1 (logslope row); it is NULL for the
// unit-vector h / w columns, whose single non-zero entry is 1.0.
// The h / w mapping goes through h_primary_start / w_primary_start, the
// same convention the HVP and diagonal kernels use.
__device__ __forceinline__ bool bms_flex_dense_phi_support(
    int u,
    int p_m,
    int p_g,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    const double *mrow,
    const double *grow,
    int *off,
    int *len,
    const double **src)
{
    if (u == 0) {
        *off = 0; *len = p_m; *src = mrow;
        return true;
    }
    if (u == 1) {
        *off = p_m; *len = p_g; *src = grow;
        return true;
    }
    int hk = u - h_primary_start;
    if (hk >= 0 && hk < h_block_len) {
        *off = h_block_start + hk; *len = 1; *src = NULL;
        return true;
    }
    int wk = u - w_primary_start;
    if (wk >= 0 && wk < w_block_len) {
        *off = w_block_start + wk; *len = 1; *src = NULL;
        return true;
    }
    return false;
}

// Per-chunk dense partial of sum_i P_i^T H_i P_i for the rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)), accumulated
// in dynamic shared memory and then copied to partial[chunk, :, :]
// (row-major [p_total, p_total]). The launch must provide at least
// p_total * p_total * sizeof(double) bytes of dynamic shared memory.
//
// Primaries that map to no joint coefficient contribute nothing, so a
// primary layout wider than {0, 1} U h U w can no longer index past the
// accumulator.
extern "C" __global__ void bms_flex_row_dense_block_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians,      // [n, r*r]
    const double * __restrict__ marginal_design,   // [n, p_m]
    const double * __restrict__ logslope_design,   // [n, p_g]
    double * __restrict__ partial)                 // [num_chunks, p_total, p_total]
{
    extern __shared__ double shmem[];
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int pp = p_total * p_total;
    double *acc = shmem;  // CTA-private accumulator [p_total, p_total]
    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
    __syncthreads();

    // Per-row work performed by thread 0 to avoid cross-thread RW
    // contention on `acc[]`. Per-row complexity is O(r * p_m + r * p_g
    // + r²): tractable because r ≤ 32 and p_m + p_g typically ≤ 64.
    // This is a debug-only path, so the simple serial version is kept for
    // auditability against the host-side P_i pullback oracle.
    if (tid == 0) {
        for (int row = row_lo; row < row_hi; ++row) {
            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
            for (int u = 0; u < r; ++u) {
                int m_off, m_len;
                const double *m_src;
                if (!bms_flex_dense_phi_support(
                        u, p_m, p_g, h_block_start, h_block_len,
                        w_block_start, w_block_len, h_primary_start,
                        w_primary_start, mrow, grow, &m_off, &m_len, &m_src)) {
                    continue;
                }
                for (int v = 0; v < r; ++v) {
                    double huv = Hrow[u * r + v];
                    if (huv == 0.0) continue;
                    int n_off, n_len;
                    const double *n_src;
                    if (!bms_flex_dense_phi_support(
                            v, p_m, p_g, h_block_start, h_block_len,
                            w_block_start, w_block_len, h_primary_start,
                            w_primary_start, mrow, grow, &n_off, &n_len, &n_src)) {
                        continue;
                    }
                    // acc[m, n] += huv * phi_u[m] * phi_v[n] over the
                    // outer-product support of phi_u and phi_v.
                    for (int mi = 0; mi < m_len; ++mi) {
                        double pm = m_src ? m_src[mi] : 1.0;
                        if (pm == 0.0) continue;
                        double scaled = huv * pm;
                        double *acc_row = acc + (m_off + mi) * p_total + n_off;
                        for (int ni = 0; ni < n_len; ++ni) {
                            double pn = n_src ? n_src[ni] : 1.0;
                            acc_row[ni] += scaled * pn;
                        }
                    }
                }
            }
        }
    }
    __syncthreads();

    // Write CTA accumulator out to global memory at its chunk slot.
    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
    for (int j = tid; j < pp; j += blockDim.x) {
        out_chunk[j] = acc[j];
    }
}
1758
// Reduction stage for the dense joint-Hessian block: for each cell of the
// row-major [p_total, p_total] matrix, out[cell] = sum over
// c = 0 .. num_chunks-1 of partial[c, cell], accumulated in ascending chunk
// order from 0.0. One thread per cell; threads past p_total^2 do nothing.
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int num_chunks,
    int p_total,
    const double * __restrict__ partial,   // [num_chunks, p_total, p_total]
    double * __restrict__ out)             // [p_total, p_total]
{
    const int cell = blockIdx.x * blockDim.x + threadIdx.x;
    const int cells = p_total * p_total;
    if (cell < cells) {
        const size_t stride = (size_t)cells;
        size_t off = (size_t)cell;
        double sum = 0.0;
        for (int c = 0; c < num_chunks; ++c) {
            sum += partial[off];
            off += stride;
        }
        out[cell] = sum;
    }
}
1774
// SymmetricPackedUpper variant of bms_flex_row_diag_partial: per-chunk
// partial of diag(sum_i P_i^T H_i P_i), reading only the diagonal entries
// H_i[q, q] from packed-upper storage (r*(r+1)/2 doubles per row) via
// bms_flex_packed_idx(q, q, r). That helper must index the layout
// bms_flex_row_pack_upper writes.
//
// Same chunking and output layout (partial[chunk, 0 .. p_total)) as the
// dense variant. Exact only if the marginal, logslope, h and w coefficient
// ranges are disjoint -- presumably guaranteed by the block layout; confirm.
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int per_row = r * (r + 1) / 2;
    // This CTA's [p_total] slice of partial; zeroed before accumulation.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;
        // Diagonal entry for (u, u) sits at packed_idx(u, u, r).
        double h00 = Hrow[bms_flex_packed_idx(0, 0, r)];
        double h11 = Hrow[bms_flex_packed_idx(1, 1, r)];
        for (int j = tid; j < p_m; j += blockDim.x) {
            double v = mrow[j];
            out[j] += h00 * v * v;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            double v = grow[j];
            out[p_m + j] += h11 * v * v;
        }
        // h / w coefficients map 1:1 to primaries: add H_i's own diagonal.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                int ii = h_primary_start + k;
                out[h_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                int ii = w_primary_start + k;
                out[w_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
        }
        __syncthreads();
    }
}
1834"#;
1835
1836#[cfg(target_os = "linux")]
1837pub(crate) struct HvpKernelBackend {
1838 pub(crate) stream: Arc<CudaStream>,
1839 pub(crate) module: Arc<CudaModule>,
1840}
1841
1842#[cfg(target_os = "linux")]
1843impl HvpKernelBackend {
1844 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1845 static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1846 BACKEND
1847 .get_or_init(|| {
1848 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1849 let ptx = gam_gpu::device_cache::compile_ptx_arch(HVP_KERNEL_SOURCE)
1853 .map_err(|err| GpuError::DriverCallFailed {
1854 reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1855 })?;
1856 let module =
1857 parts
1858 .ctx
1859 .load_module(ptx)
1860 .map_err(|err| GpuError::DriverCallFailed {
1861 reason: format!("bms_flex_row hvp module load failed: {err}"),
1862 })?;
1863 Ok(HvpKernelBackend {
1864 stream: parts.stream.clone(),
1865 module,
1866 })
1867 })
1868 })
1869 .as_ref()
1870 .map_err(GpuError::clone)
1871 }
1872}
1873
1874#[cfg(target_os = "linux")]
1900pub(crate) fn launch_bms_flex_row_kernel_device_resident(
1901 inputs: BmsFlexRowKernelInputs<'_>,
1902 marginal_design_row_major: &[f64],
1903 logslope_design_row_major: &[f64],
1904 block: BmsFlexBlockLayout,
1905 primary: BmsFlexPrimaryLayout,
1906) -> Result<DeviceResidentRowHess, GpuError> {
1907 inputs.validate()?;
1908 if !s_f_diagnostic_finite(&inputs) {
1909 return Err(GpuError::DriverCallFailed {
1910 reason: format!(
1911 "bms_flex_row device-resident: s_f must be positive and finite, got {}",
1912 inputs.s_f
1913 ),
1914 });
1915 }
1916 let n = inputs.n_rows;
1917 let r = inputs.r;
1918 if marginal_design_row_major.len() != n * block.p_m {
1919 return Err(GpuError::DriverCallFailed {
1920 reason: format!(
1921 "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
1922 marginal_design_row_major.len(),
1923 n * block.p_m
1924 ),
1925 });
1926 }
1927 if logslope_design_row_major.len() != n * block.p_g {
1928 return Err(GpuError::DriverCallFailed {
1929 reason: format!(
1930 "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
1931 logslope_design_row_major.len(),
1932 n * block.p_g
1933 ),
1934 });
1935 }
1936 if primary.r != r {
1937 return Err(GpuError::DriverCallFailed {
1938 reason: format!(
1939 "bms_flex_row device-resident: primary.r={} != inputs.r={}",
1940 primary.r, r
1941 ),
1942 });
1943 }
1944
1945 let backend = RowKernelBackend::probe()?;
1948 HvpKernelBackend::probe()?;
1949 let stream = backend.stream.clone();
1950
1951 let upload_f64 = |slice: &[f64], label: &str| {
1952 stream
1953 .clone_htod(slice)
1954 .map_err(|err| GpuError::DriverCallFailed {
1955 reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1956 })
1957 };
1958 let upload_u32 = |slice: &[u32], label: &str| {
1959 stream
1960 .clone_htod(slice)
1961 .map_err(|err| GpuError::DriverCallFailed {
1962 reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1963 })
1964 };
1965
1966 let d_q = upload_f64(inputs.q, "q")?;
1967 let d_b = upload_f64(inputs.b, "b")?;
1968 let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
1969 let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
1970 let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
1971 let d_y = upload_f64(inputs.y, "y")?;
1972 let d_w = upload_f64(inputs.w, "w")?;
1973 let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
1974 let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
1975 let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
1976 let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
1977 let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
1978 let d_a = upload_f64(inputs.cell_a, "cell_a")?;
1979 let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
1980 let d_r = upload_f64(inputs.cell_r, "cell_r")?;
1981 let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
1982 let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
1983 let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
1984 let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
1985 let owned_host_moments: CudaSlice<f64>;
1987 let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
1988 CellMomentsSource::Host(slice) => {
1989 owned_host_moments = upload_f64(slice, "cell_moments")?;
1990 &owned_host_moments
1991 }
1992 CellMomentsSource::Device(d) => *d,
1993 };
1994 let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
1995 let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
1996 let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
1997 let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
1998 let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
1999
2000 let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
2001 let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;
2002
2003 let mut d_neglog = stream
2004 .alloc_zeros::<f64>(n)
2005 .map_err(|err| GpuError::DriverCallFailed {
2006 reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
2007 })?;
2008 let mut d_grad =
2009 stream
2010 .alloc_zeros::<f64>(n * r)
2011 .map_err(|err| GpuError::DriverCallFailed {
2012 reason: format!("bms_flex_row device-resident alloc grad: {err}"),
2013 })?;
2014 let mut d_hess =
2015 stream
2016 .alloc_zeros::<f64>(n * r * r)
2017 .map_err(|err| GpuError::DriverCallFailed {
2018 reason: format!("bms_flex_row device-resident alloc hess: {err}"),
2019 })?;
2020
2021 let func = backend
2022 .module
2023 .load_function("bms_flex_row_kernel")
2024 .map_err(|err| GpuError::DriverCallFailed {
2025 reason: format!("bms_flex_row device-resident load_function: {err}"),
2026 })?;
2027
2028 let cfg = LaunchConfig {
2029 grid_dim: (n as u32, 1, 1),
2030 block_dim: (ROW_KERNEL_THREADS, 1, 1),
2031 shared_mem_bytes: 0,
2032 };
2033 let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
2034 reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
2035 })?;
2036 let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
2037 reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
2038 })?;
2039 let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
2040 reason: format!(
2041 "bms_flex_row device-resident: p_h={} exceeds i32 range",
2042 inputs.p_h
2043 ),
2044 })?;
2045 let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
2046 reason: format!(
2047 "bms_flex_row device-resident: p_w={} exceeds i32 range",
2048 inputs.p_w
2049 ),
2050 })?;
2051 let s_f_val = inputs.s_f;
2052
2053 let mut builder = stream.launch_builder(&func);
2054 builder
2055 .arg(&n_i32)
2056 .arg(&r_i32)
2057 .arg(&p_h_i32)
2058 .arg(&p_w_i32)
2059 .arg(&s_f_val)
2060 .arg(&d_q)
2061 .arg(&d_b)
2062 .arg(&d_mu1)
2063 .arg(&d_mu2)
2064 .arg(&d_zobs)
2065 .arg(&d_y)
2066 .arg(&d_w)
2067 .arg(&d_offsets)
2068 .arg(&d_c0)
2069 .arg(&d_c1)
2070 .arg(&d_c2)
2071 .arg(&d_c3)
2072 .arg(&d_a)
2073 .arg(&d_aa)
2074 .arg(&d_r)
2075 .arg(&d_ar)
2076 .arg(&d_sbb)
2077 .arg(&d_sbh)
2078 .arg(&d_sbw)
2079 .arg(d_moments_ref)
2080 .arg(&d_chi)
2081 .arg(&d_xi)
2082 .arg(&d_rho)
2083 .arg(&d_tau)
2084 .arg(&d_ruv)
2085 .arg(&mut d_neglog)
2086 .arg(&mut d_grad)
2087 .arg(&mut d_hess);
2088 unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
2093 reason: format!("bms_flex_row device-resident launch: {err}"),
2094 })?;
2095 stream
2096 .synchronize()
2097 .map_err(|err| GpuError::DriverCallFailed {
2098 reason: format!("bms_flex_row device-resident synchronize: {err}"),
2099 })?;
2100
2101 drop(d_neglog);
2109 drop(d_grad);
2110 drop(d_q);
2112 drop(d_b);
2113 drop(d_mu1);
2114 drop(d_mu2);
2115 drop(d_zobs);
2116 drop(d_y);
2117 drop(d_w);
2118 drop(d_offsets);
2119 drop(d_c0);
2120 drop(d_c1);
2121 drop(d_c2);
2122 drop(d_c3);
2123 drop(d_a);
2124 drop(d_aa);
2125 drop(d_r);
2126 drop(d_ar);
2127 drop(d_sbb);
2128 drop(d_sbh);
2129 drop(d_sbw);
2130 drop(d_chi);
2134 drop(d_xi);
2135 drop(d_rho);
2136 drop(d_tau);
2137 drop(d_ruv);
2138
2139 let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
2140 * std::mem::size_of::<f64>()) as u64;
2141 Ok(DeviceResidentRowHess {
2142 hess: d_hess,
2143 marginal_design: d_marginal,
2144 logslope_design: d_logslope,
2145 n,
2146 r,
2147 block,
2148 primary,
2149 bytes,
2150 })
2151}
2152
#[cfg(target_os = "linux")]
/// Which per-chunk partial kernel `run_bms_flex_row_partial_reduce` runs
/// against a device-resident row-Hessian store.
#[derive(Clone, Copy)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Hessian-vector product; the partial kernel consumes a direction vector.
    HvpDeviceOut,
    /// Joint-Hessian diagonal; no direction vector is involved.
    DiagonalHostOut,
}

#[cfg(target_os = "linux")]
impl BmsFlexRowLaunchMode {
    /// Name of the `extern "C"` partial kernel in `HVP_KERNEL_SOURCE` that
    /// implements this mode.
    pub(crate) fn partial_kernel_name(self) -> &'static str {
        let name = match self {
            Self::DiagonalHostOut => "bms_flex_row_diag_partial",
            Self::HvpDeviceOut => "bms_flex_row_hvp_partial",
        };
        name
    }
}
2176
2177#[cfg(target_os = "linux")]
2183pub(crate) struct PreparedBmsFlexRowLaunchArgs {
2184 pub(crate) n_i32: i32,
2185 pub(crate) r_i32: i32,
2186 pub(crate) p_m_i32: i32,
2187 pub(crate) p_g_i32: i32,
2188 pub(crate) p_total_i32: i32,
2189 pub(crate) h_block_start: i32,
2190 pub(crate) h_block_len: i32,
2191 pub(crate) w_block_start: i32,
2192 pub(crate) w_block_len: i32,
2193 pub(crate) h_primary_start: i32,
2194 pub(crate) w_primary_start: i32,
2195 pub(crate) rows_per_cta: i32,
2196 pub(crate) num_chunks: usize,
2197}
2198
2199#[cfg(target_os = "linux")]
2200impl PreparedBmsFlexRowLaunchArgs {
2201 pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2202 let p_total = storage.block.p_total;
2203 let num_chunks = num_hvp_chunks(storage.n);
2204 PreparedBmsFlexRowLaunchArgs {
2205 n_i32: storage.n as i32,
2206 r_i32: storage.r as i32,
2207 p_m_i32: storage.block.p_m as i32,
2208 p_g_i32: storage.block.p_g as i32,
2209 p_total_i32: p_total as i32,
2210 h_block_start: storage
2211 .block
2212 .h
2213 .as_ref()
2214 .map(|r| r.start as i32)
2215 .unwrap_or(0),
2216 h_block_len: storage
2217 .block
2218 .h
2219 .as_ref()
2220 .map(|r| r.len() as i32)
2221 .unwrap_or(0),
2222 w_block_start: storage
2223 .block
2224 .w
2225 .as_ref()
2226 .map(|r| r.start as i32)
2227 .unwrap_or(0),
2228 w_block_len: storage
2229 .block
2230 .w
2231 .as_ref()
2232 .map(|r| r.len() as i32)
2233 .unwrap_or(0),
2234 h_primary_start: storage
2235 .primary
2236 .h
2237 .as_ref()
2238 .map(|r| r.start as i32)
2239 .unwrap_or(0),
2240 w_primary_start: storage
2241 .primary
2242 .w
2243 .as_ref()
2244 .map(|r| r.start as i32)
2245 .unwrap_or(0),
2246 rows_per_cta: HVP_ROWS_PER_CTA as i32,
2247 num_chunks,
2248 }
2249 }
2250}
2251
2252#[cfg(target_os = "linux")]
2266pub(crate) fn run_bms_flex_row_partial_reduce(
2267 storage: &DeviceResidentRowHess,
2268 mode: BmsFlexRowLaunchMode,
2269 d_v: Option<&CudaSlice<f64>>,
2270 d_out: &mut CudaSlice<f64>,
2271 ctx: &str,
2272) -> Result<(), GpuError> {
2273 let backend = HvpKernelBackend::probe()?;
2274 let stream = backend.stream.clone();
2275 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2276 let p_total = storage.block.p_total;
2277
2278 let mut d_partial = stream
2279 .alloc_zeros::<f64>(args.num_chunks * p_total)
2280 .map_err(|err| GpuError::DriverCallFailed {
2281 reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2282 })?;
2283
2284 let partial_kernel_name = mode.partial_kernel_name();
2285 let part_func = backend
2286 .module
2287 .load_function(partial_kernel_name)
2288 .map_err(|err| GpuError::DriverCallFailed {
2289 reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2290 })?;
2291 let red_func = backend
2292 .module
2293 .load_function("bms_flex_row_hvp_reduce")
2294 .map_err(|err| GpuError::DriverCallFailed {
2295 reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2296 })?;
2297
2298 let cfg_part = LaunchConfig {
2299 grid_dim: (args.num_chunks as u32, 1, 1),
2300 block_dim: (HVP_THREADS, 1, 1),
2301 shared_mem_bytes: 0,
2302 };
2303 let mut builder = stream.launch_builder(&part_func);
2304 builder
2305 .arg(&args.n_i32)
2306 .arg(&args.r_i32)
2307 .arg(&args.p_m_i32)
2308 .arg(&args.p_g_i32)
2309 .arg(&args.p_total_i32)
2310 .arg(&args.h_block_start)
2311 .arg(&args.h_block_len)
2312 .arg(&args.w_block_start)
2313 .arg(&args.w_block_len)
2314 .arg(&args.h_primary_start)
2315 .arg(&args.w_primary_start)
2316 .arg(&args.rows_per_cta)
2317 .arg(&storage.hess)
2318 .arg(&storage.marginal_design)
2319 .arg(&storage.logslope_design);
2320 if let Some(d_v) = d_v {
2321 builder.arg(d_v);
2322 }
2323 builder.arg(&mut d_partial);
2324 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2332 reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2333 })?;
2334
2335 let red_threads: u32 = REDUCTION_THREADS;
2336 let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2337 let cfg_red = LaunchConfig {
2338 grid_dim: (red_blocks, 1, 1),
2339 block_dim: (red_threads, 1, 1),
2340 shared_mem_bytes: 0,
2341 };
2342 let num_chunks_i32 = args.num_chunks as i32;
2343 let mut builder = stream.launch_builder(&red_func);
2344 builder
2345 .arg(&num_chunks_i32)
2346 .arg(&args.p_total_i32)
2347 .arg(&d_partial)
2348 .arg(d_out);
2349 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2353 reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2354 })?;
2355 drop(d_partial);
2358 Ok(())
2359}
2360
2361#[cfg(target_os = "linux")]
2371pub(crate) fn launch_bms_flex_row_host(
2372 storage: &DeviceResidentRowHess,
2373 mode: BmsFlexRowLaunchMode,
2374 v: Option<&[f64]>,
2375 ctx: &str,
2376) -> Result<Vec<f64>, GpuError> {
2377 let p_total = storage.block.p_total;
2378 if let Some(v) = v {
2379 if v.len() != p_total {
2380 return Err(GpuError::DriverCallFailed {
2381 reason: format!(
2382 "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2383 v.len()
2384 ),
2385 });
2386 }
2387 }
2388
2389 let backend = HvpKernelBackend::probe()?;
2390 let stream = backend.stream.clone();
2391
2392 let d_v = match v {
2393 Some(v) => Some(
2394 stream
2395 .clone_htod(v)
2396 .map_err(|err| GpuError::DriverCallFailed {
2397 reason: format!("bms_flex_row {ctx} upload v: {err}"),
2398 })?,
2399 ),
2400 None => None,
2401 };
2402 let mut d_out =
2403 stream
2404 .alloc_zeros::<f64>(p_total)
2405 .map_err(|err| GpuError::DriverCallFailed {
2406 reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2407 })?;
2408
2409 run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2410
2411 stream
2412 .synchronize()
2413 .map_err(|err| GpuError::DriverCallFailed {
2414 reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2415 })?;
2416 stream
2417 .clone_dtoh(&d_out)
2418 .map_err(|err| GpuError::DriverCallFailed {
2419 reason: format!("bms_flex_row {ctx} download out: {err}"),
2420 })
2421}
2422
2423#[cfg(target_os = "linux")]
2424pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2425 storage: &DeviceResidentRowHess,
2426 rhs_count: usize,
2427 v_rhs_len: usize,
2428 out_len: Option<usize>,
2429 ctx: &str,
2430) -> Result<usize, GpuError> {
2431 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2432 return Err(GpuError::DriverCallFailed {
2433 reason: format!(
2434 "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2435 ),
2436 });
2437 }
2438 let p_total = storage.block.p_total;
2439 let rhs_elems = rhs_count
2440 .checked_mul(p_total)
2441 .ok_or_else(|| GpuError::DriverCallFailed {
2442 reason: format!(
2443 "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2444 ),
2445 })?;
2446 if v_rhs_len != rhs_elems {
2447 return Err(GpuError::DriverCallFailed {
2448 reason: format!(
2449 "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2450 ),
2451 });
2452 }
2453 if let Some(out_len) = out_len
2454 && out_len != rhs_elems
2455 {
2456 return Err(GpuError::DriverCallFailed {
2457 reason: format!(
2458 "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2459 ),
2460 });
2461 }
2462 Ok(rhs_elems)
2463}
2464
2465#[cfg(target_os = "linux")]
2469pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2470 n: usize,
2471 p_total: usize,
2472 rhs_count: usize,
2473) -> Result<u64, GpuError> {
2474 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2475 return Err(GpuError::DriverCallFailed {
2476 reason: format!(
2477 "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2478 ),
2479 });
2480 }
2481 let num_chunks = num_hvp_chunks(n);
2482 let partial = rhs_count
2483 .checked_mul(num_chunks)
2484 .and_then(|v| v.checked_mul(p_total))
2485 .ok_or_else(|| GpuError::DriverCallFailed {
2486 reason: format!(
2487 "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2488 ),
2489 })?;
2490 let rhs_vectors = rhs_count
2491 .checked_mul(p_total)
2492 .and_then(|v| v.checked_mul(2))
2493 .ok_or_else(|| GpuError::DriverCallFailed {
2494 reason: format!(
2495 "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2496 ),
2497 })?;
2498 let elems = partial
2499 .checked_add(rhs_vectors)
2500 .ok_or_else(|| GpuError::DriverCallFailed {
2501 reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2502 })?;
2503 Ok((elems * std::mem::size_of::<f64>()) as u64)
2504}
2505
2506#[cfg(target_os = "linux")]
2507pub(crate) fn run_bms_flex_row_multi_partial_reduce(
2508 storage: &DeviceResidentRowHess,
2509 rhs_count: usize,
2510 d_v_rhs: &CudaSlice<f64>,
2511 d_out: &mut CudaSlice<f64>,
2512 ctx: &str,
2513) -> Result<(), GpuError> {
2514 let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
2515 storage,
2516 rhs_count,
2517 d_v_rhs.len(),
2518 Some(d_out.len()),
2519 ctx,
2520 )?;
2521 let backend = HvpKernelBackend::probe()?;
2522 let stream = backend.stream.clone();
2523 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2524 let p_total = storage.block.p_total;
2525 let partial_len = rhs_count
2526 .checked_mul(args.num_chunks)
2527 .and_then(|v| v.checked_mul(p_total))
2528 .ok_or_else(|| GpuError::DriverCallFailed {
2529 reason: format!(
2530 "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
2531 args.num_chunks
2532 ),
2533 })?;
2534
2535 let mut d_partial =
2536 stream
2537 .alloc_zeros::<f64>(partial_len)
2538 .map_err(|err| GpuError::DriverCallFailed {
2539 reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
2540 })?;
2541 let part_func = backend
2542 .module
2543 .load_function("bms_flex_row_hvp_multi_partial")
2544 .map_err(|err| GpuError::DriverCallFailed {
2545 reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
2546 })?;
2547 let red_func = backend
2548 .module
2549 .load_function("bms_flex_row_hvp_multi_reduce")
2550 .map_err(|err| GpuError::DriverCallFailed {
2551 reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
2552 })?;
2553
2554 let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
2555 reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
2556 })?;
2557 let cfg_part = LaunchConfig {
2558 grid_dim: (args.num_chunks as u32, 1, 1),
2559 block_dim: (HVP_THREADS, 1, 1),
2560 shared_mem_bytes: 0,
2561 };
2562 let mut builder = stream.launch_builder(&part_func);
2563 builder
2564 .arg(&args.n_i32)
2565 .arg(&args.r_i32)
2566 .arg(&args.p_m_i32)
2567 .arg(&args.p_g_i32)
2568 .arg(&args.p_total_i32)
2569 .arg(&args.h_block_start)
2570 .arg(&args.h_block_len)
2571 .arg(&args.w_block_start)
2572 .arg(&args.w_block_len)
2573 .arg(&args.h_primary_start)
2574 .arg(&args.w_primary_start)
2575 .arg(&args.rows_per_cta)
2576 .arg(&rhs_count_i32)
2577 .arg(&storage.hess)
2578 .arg(&storage.marginal_design)
2579 .arg(&storage.logslope_design)
2580 .arg(d_v_rhs)
2581 .arg(&mut d_partial);
2582 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2587 reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
2588 })?;
2589
2590 let red_threads: u32 = REDUCTION_THREADS;
2591 let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
2592 let cfg_red = LaunchConfig {
2593 grid_dim: (red_blocks, 1, 1),
2594 block_dim: (red_threads, 1, 1),
2595 shared_mem_bytes: 0,
2596 };
2597 let num_chunks_i32 = args.num_chunks as i32;
2598 let mut builder = stream.launch_builder(&red_func);
2599 builder
2600 .arg(&num_chunks_i32)
2601 .arg(&args.p_total_i32)
2602 .arg(&rhs_count_i32)
2603 .arg(&d_partial)
2604 .arg(d_out);
2605 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2608 reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
2609 })?;
2610 drop(d_partial);
2611 Ok(())
2612}
2613
2614#[cfg(target_os = "linux")]
2617pub(crate) fn launch_bms_flex_row_hvp_multi(
2618 storage: &DeviceResidentRowHess,
2619 v_rhs: &[f64],
2620 rhs_count: usize,
2621) -> Result<Vec<f64>, GpuError> {
2622 let rhs_elems =
2623 validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2624 let backend = HvpKernelBackend::probe()?;
2625 let stream = backend.stream.clone();
2626 let d_v_rhs = stream
2627 .clone_htod(v_rhs)
2628 .map_err(|err| GpuError::DriverCallFailed {
2629 reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2630 })?;
2631 let mut d_out =
2632 stream
2633 .alloc_zeros::<f64>(rhs_elems)
2634 .map_err(|err| GpuError::DriverCallFailed {
2635 reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2636 })?;
2637 run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2638 stream
2639 .synchronize()
2640 .map_err(|err| GpuError::DriverCallFailed {
2641 reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2642 })?;
2643 stream
2644 .clone_dtoh(&d_out)
2645 .map_err(|err| GpuError::DriverCallFailed {
2646 reason: format!("bms_flex_row hvp_multi download out: {err}"),
2647 })
2648}
2649
2650#[cfg(target_os = "linux")]
2661pub(crate) fn launch_bms_flex_row_hvp_into_device(
2662 storage: &DeviceResidentRowHess,
2663 d_v: &CudaSlice<f64>,
2664 d_out: &mut CudaSlice<f64>,
2665) -> Result<(), GpuError> {
2666 let p_total = storage.block.p_total;
2667 if d_v.len() != p_total {
2668 return Err(GpuError::DriverCallFailed {
2669 reason: format!(
2670 "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2671 d_v.len(),
2672 p_total
2673 ),
2674 });
2675 }
2676 if d_out.len() != p_total {
2677 return Err(GpuError::DriverCallFailed {
2678 reason: format!(
2679 "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2680 d_out.len(),
2681 p_total
2682 ),
2683 });
2684 }
2685 run_bms_flex_row_partial_reduce(
2689 storage,
2690 BmsFlexRowLaunchMode::HvpDeviceOut,
2691 Some(d_v),
2692 d_out,
2693 "hvp_into_device",
2694 )
2695}
2696
2697#[cfg(target_os = "linux")]
2700pub(crate) fn launch_bms_flex_row_hvp(
2701 storage: &DeviceResidentRowHess,
2702 v: &[f64],
2703) -> Result<Vec<f64>, GpuError> {
2704 launch_bms_flex_row_hvp_multi(storage, v, 1)
2705}
2706
2707#[cfg(target_os = "linux")]
2710pub(crate) fn launch_bms_flex_row_diagonal(
2711 storage: &DeviceResidentRowHess,
2712) -> Result<Vec<f64>, GpuError> {
2713 launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2714}
2715
#[cfg(target_os = "linux")]
/// Largest `p_total` accepted by [`launch_bms_flex_row_dense_block`].
///
/// The dense-block partial kernel is launched with `p_total² * 8` bytes of
/// dynamic shared memory per CTA. At 72 that is 41,472 bytes, which stays
/// under the 48 KiB/block V100 limit cited in the launcher's error message.
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;

#[cfg(target_os = "linux")]
/// Number of rows each CTA of the dense-block partial kernel covers.
///
/// The launcher sizes its grid as `n.div_ceil(DENSE_BLOCK_ROWS_PER_CTA)` and
/// also passes this value to the kernel as its `rows_per_cta` argument.
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2731
2732#[cfg(target_os = "linux")]
2749pub fn launch_bms_flex_row_dense_block(
2750 storage: &DeviceResidentRowHess,
2751) -> Result<Vec<f64>, GpuError> {
2752 let p_total = storage.block.p_total;
2753 if p_total == 0 {
2754 return Err(GpuError::DriverCallFailed {
2755 reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2756 });
2757 }
2758 if p_total > DENSE_BLOCK_MAX_P {
2759 return Err(GpuError::DriverCallFailed {
2760 reason: format!(
2761 "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2762 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2763 ),
2764 });
2765 }
2766 let backend = HvpKernelBackend::probe()?;
2767 let stream = backend.stream.clone();
2768 let n = storage.n;
2769 let r = storage.r;
2770 let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2771 let num_chunks = n.div_ceil(rows_per_cta);
2772 let pp = p_total * p_total;
2773
2774 let mut d_partial =
2775 stream
2776 .alloc_zeros::<f64>(num_chunks * pp)
2777 .map_err(|err| GpuError::DriverCallFailed {
2778 reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2779 })?;
2780 let mut d_out = stream
2781 .alloc_zeros::<f64>(pp)
2782 .map_err(|err| GpuError::DriverCallFailed {
2783 reason: format!("bms_flex_row dense_block alloc out: {err}"),
2784 })?;
2785
2786 let part_func = backend
2787 .module
2788 .load_function("bms_flex_row_dense_block_partial")
2789 .map_err(|err| GpuError::DriverCallFailed {
2790 reason: format!("bms_flex_row dense_block load partial: {err}"),
2791 })?;
2792 let red_func = backend
2793 .module
2794 .load_function("bms_flex_row_dense_block_reduce")
2795 .map_err(|err| GpuError::DriverCallFailed {
2796 reason: format!("bms_flex_row dense_block load reduce: {err}"),
2797 })?;
2798
2799 let n_i32 = n as i32;
2800 let r_i32 = r as i32;
2801 let p_m_i32 = storage.block.p_m as i32;
2802 let p_g_i32 = storage.block.p_g as i32;
2803 let p_total_i32 = p_total as i32;
2804 let h_block_start = storage
2805 .block
2806 .h
2807 .as_ref()
2808 .map(|r| r.start as i32)
2809 .unwrap_or(0);
2810 let h_block_len = storage
2811 .block
2812 .h
2813 .as_ref()
2814 .map(|r| r.len() as i32)
2815 .unwrap_or(0);
2816 let w_block_start = storage
2817 .block
2818 .w
2819 .as_ref()
2820 .map(|r| r.start as i32)
2821 .unwrap_or(0);
2822 let w_block_len = storage
2823 .block
2824 .w
2825 .as_ref()
2826 .map(|r| r.len() as i32)
2827 .unwrap_or(0);
2828 let h_primary_start = storage
2829 .primary
2830 .h
2831 .as_ref()
2832 .map(|r| r.start as i32)
2833 .unwrap_or(0);
2834 let w_primary_start = storage
2835 .primary
2836 .w
2837 .as_ref()
2838 .map(|r| r.start as i32)
2839 .unwrap_or(0);
2840 let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2841 let num_chunks_u32 = num_chunks as u32;
2842
2843 let shmem_bytes: u32 =
2845 u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2846 reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2847 })?;
2848
2849 let cfg_part = LaunchConfig {
2850 grid_dim: (num_chunks_u32, 1, 1),
2851 block_dim: (HVP_THREADS, 1, 1),
2852 shared_mem_bytes: shmem_bytes,
2853 };
2854 let mut builder = stream.launch_builder(&part_func);
2855 builder
2856 .arg(&n_i32)
2857 .arg(&r_i32)
2858 .arg(&p_m_i32)
2859 .arg(&p_g_i32)
2860 .arg(&p_total_i32)
2861 .arg(&h_block_start)
2862 .arg(&h_block_len)
2863 .arg(&w_block_start)
2864 .arg(&w_block_len)
2865 .arg(&h_primary_start)
2866 .arg(&w_primary_start)
2867 .arg(&rows_per_cta_i32)
2868 .arg(&storage.hess)
2869 .arg(&storage.marginal_design)
2870 .arg(&storage.logslope_design)
2871 .arg(&mut d_partial);
2872 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2876 reason: format!("bms_flex_row dense_block partial launch: {err}"),
2877 })?;
2878
2879 let red_threads: u32 = REDUCTION_THREADS;
2880 let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2881 let cfg_red = LaunchConfig {
2882 grid_dim: (red_blocks, 1, 1),
2883 block_dim: (red_threads, 1, 1),
2884 shared_mem_bytes: 0,
2885 };
2886 let num_chunks_i32 = num_chunks as i32;
2887 let mut builder = stream.launch_builder(&red_func);
2888 builder
2889 .arg(&num_chunks_i32)
2890 .arg(&p_total_i32)
2891 .arg(&d_partial)
2892 .arg(&mut d_out);
2893 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2895 reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2896 })?;
2897 stream
2898 .synchronize()
2899 .map_err(|err| GpuError::DriverCallFailed {
2900 reason: format!("bms_flex_row dense_block sync: {err}"),
2901 })?;
2902 stream
2903 .clone_dtoh(&d_out)
2904 .map_err(|err| GpuError::DriverCallFailed {
2905 reason: format!("bms_flex_row dense_block download: {err}"),
2906 })
2907}
2908
2909#[cfg(all(test, target_os = "linux"))]
2917mod tests {
2918 use super::*;
2919
2920 pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
2921 BmsFlexRowKernelInputs {
2922 n_rows: 1,
2923 r: 4,
2924 p_h: 1,
2925 p_w: 1,
2926 q: &buffers.q,
2927 b: &buffers.b,
2928 mu_1: &buffers.mu_1,
2929 mu_2: &buffers.mu_2,
2930 z_obs: &buffers.z_obs,
2931 y: &buffers.y,
2932 w: &buffers.w,
2933 s_f: 1.0,
2934 cell_offsets: &buffers.cell_offsets,
2935 cell_c0: &buffers.cell_c0,
2936 cell_c1: &buffers.cell_c1,
2937 cell_c2: &buffers.cell_c2,
2938 cell_c3: &buffers.cell_c3,
2939 cell_a: &buffers.cell_a,
2940 cell_aa: &buffers.cell_aa,
2941 cell_r: &buffers.cell_r,
2942 cell_ar: &buffers.cell_ar,
2943 cell_sbb: &buffers.cell_sbb,
2944 cell_sbh: &buffers.cell_sbh,
2945 cell_sbw: &buffers.cell_sbw,
2946 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
2947 chi_obs: &buffers.chi_obs,
2948 xi_obs: &buffers.xi_obs,
2949 rho_u: &buffers.rho_u,
2950 tau_u: &buffers.tau_u,
2951 r_uv: &buffers.r_uv,
2952 }
2953 }
2954
    /// Owned backing storage for a test `BmsFlexRowKernelInputs`, borrowed by
    /// `minimal_inputs` / `parity_inputs`.
    ///
    /// Lengths below use `n` rows, `C` cells and rank `r`, as built by
    /// `make_buffers` / `make_parity_buffers` and indexed by
    /// `cpu_oracle_outputs`.
    pub(crate) struct TestBuffers {
        // Per-row (length n); not read by the CPU oracle.
        pub(crate) q: Vec<f64>,
        // Per-row (length n); not read by the CPU oracle.
        pub(crate) b: Vec<f64>,
        // Per-row; the oracle uses it for the index-0 first-order term.
        pub(crate) mu_1: Vec<f64>,
        // Per-row; the oracle uses it for the index-0 second-order term.
        pub(crate) mu_2: Vec<f64>,
        // Per-row (length n); not read by the CPU oracle.
        pub(crate) z_obs: Vec<f64>,
        // Per-row labels; the oracle uses the sign s = 2y - 1.
        pub(crate) y: Vec<f64>,
        // Per-row likelihood weights.
        pub(crate) w: Vec<f64>,
        // n + 1 offsets; row i owns cells cell_offsets[i]..cell_offsets[i + 1].
        pub(crate) cell_offsets: Vec<u32>,
        // Per-cell coefficients c0..c3 (length C each).
        pub(crate) cell_c0: Vec<f64>,
        pub(crate) cell_c1: Vec<f64>,
        pub(crate) cell_c2: Vec<f64>,
        pub(crate) cell_c3: Vec<f64>,
        // 4 coefficients per cell (length C * 4).
        pub(crate) cell_a: Vec<f64>,
        // 4 coefficients per cell (length C * 4).
        pub(crate) cell_aa: Vec<f64>,
        // (r - 1) blocks of 4 coefficients per cell (length C * (r - 1) * 4).
        pub(crate) cell_r: Vec<f64>,
        // Same layout as cell_r.
        pub(crate) cell_ar: Vec<f64>,
        // 4 coefficients per cell (length C * 4).
        pub(crate) cell_sbb: Vec<f64>,
        // p_h blocks of 4 coefficients per cell (length C * p_h * 4).
        pub(crate) cell_sbh: Vec<f64>,
        // p_w blocks of 4 coefficients per cell (length C * p_w * 4).
        pub(crate) cell_sbw: Vec<f64>,
        // MOMENT_STRIDE moments per cell (length C * MOMENT_STRIDE).
        pub(crate) cell_moments: Vec<f64>,
        // Per-row scalars used in the observation chain rule.
        pub(crate) chi_obs: Vec<f64>,
        pub(crate) xi_obs: Vec<f64>,
        // r values per row (length n * r).
        pub(crate) rho_u: Vec<f64>,
        // r values per row (length n * r).
        pub(crate) tau_u: Vec<f64>,
        // r × r values per row (length n * r * r).
        pub(crate) r_uv: Vec<f64>,
    }
2982
2983 pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
2984 let cells = n_cells as usize;
2985 TestBuffers {
2986 q: vec![0.1; 1],
2987 b: vec![0.5; 1],
2988 mu_1: vec![0.3; 1],
2989 mu_2: vec![0.07; 1],
2990 z_obs: vec![0.0; 1],
2991 y: vec![1.0; 1],
2992 w: vec![1.0; 1],
2993 cell_offsets: vec![0, n_cells],
2994 cell_c0: vec![0.2; cells],
2995 cell_c1: vec![-0.1; cells],
2996 cell_c2: vec![0.05; cells],
2997 cell_c3: vec![-0.02; cells],
2998 cell_a: vec![0.1; cells * 4],
2999 cell_aa: vec![0.0; cells * 4],
3000 cell_r: vec![0.05; cells * (r - 1) * 4],
3001 cell_ar: vec![0.0; cells * (r - 1) * 4],
3002 cell_sbb: vec![0.0; cells * 4],
3003 cell_sbh: vec![0.0; cells * p_h * 4],
3004 cell_sbw: vec![0.0; cells * p_w * 4],
3005 cell_moments: vec![1.0; cells * MOMENT_STRIDE],
3006 chi_obs: vec![1.0; 1],
3007 xi_obs: vec![0.0; 1],
3008 rho_u: vec![0.0; r],
3009 tau_u: vec![0.0; r],
3010 r_uv: vec![0.0; r * r],
3011 }
3012 }
3013
3014 #[test]
3015 pub(crate) fn validate_accepts_minimal_inputs() {
3016 let buffers = make_buffers(2, 4, 1, 1);
3017 let inputs = minimal_inputs(&buffers);
3018 assert!(inputs.validate().is_ok());
3019 }
3020
3021 #[test]
3022 pub(crate) fn validate_rejects_r_above_max() {
3023 let r = MAX_R + 1;
3024 let p_h = (r - 2) / 2;
3025 let p_w = (r - 2) - p_h;
3026 let buffers = make_buffers(1, r, p_h, p_w);
3027 let bad_inputs = BmsFlexRowKernelInputs {
3028 r,
3029 p_h,
3030 p_w,
3031 rho_u: &buffers.rho_u, tau_u: &buffers.tau_u,
3033 r_uv: &buffers.r_uv,
3034 cell_r: &buffers.cell_r,
3035 cell_ar: &buffers.cell_ar,
3036 cell_sbh: &buffers.cell_sbh,
3037 cell_sbw: &buffers.cell_sbw,
3038 ..minimal_inputs(&buffers)
3039 };
3040 let err = bad_inputs.validate().expect_err("r > MAX_R must fail");
3041 let msg = err.to_string();
3042 assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
3043 }
3044
3045 #[test]
3046 pub(crate) fn validate_rejects_mismatched_r_decomposition() {
3047 let buffers = make_buffers(1, 4, 1, 1);
3048 let bad_inputs = BmsFlexRowKernelInputs {
3049 r: 4,
3050 p_h: 1,
3051 p_w: 2, ..minimal_inputs(&buffers)
3053 };
3054 let err = bad_inputs
3055 .validate()
3056 .expect_err("inconsistent r vs p_h+p_w must fail");
3057 let msg = err.to_string();
3058 assert!(msg.contains("p_h"), "got: {msg}");
3059 assert!(msg.contains("p_w"), "got: {msg}");
3060 }
3061
3062 #[test]
3063 pub(crate) fn validate_rejects_non_monotone_offsets() {
3064 let mut buffers = make_buffers(2, 4, 1, 1);
3071 buffers.cell_offsets = vec![5, 2];
3072 let inputs = minimal_inputs(&buffers);
3073 let err = inputs
3074 .validate()
3075 .expect_err("non-monotone offsets must fail");
3076 let msg = err.to_string();
3077 assert!(msg.contains("monotone"), "got: {msg}");
3078 }
3079
3080 #[test]
3081 pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
3082 let mut buffers = make_buffers(2, 4, 1, 1);
3083 buffers.cell_moments.pop(); let inputs = minimal_inputs(&buffers);
3085 let err = inputs.validate().expect_err("short cell_moments must fail");
3086 let msg = err.to_string();
3087 assert!(msg.contains("cell_moments"), "got: {msg}");
3088 }
3089
    /// Smoke test for `launch_bms_flex_row_kernel` on the one-row minimal
    /// fixture.
    ///
    /// In the Linux arm, `Ok` and the listed driver/kernel-availability errors
    /// are all accepted, so the test passes with or without a usable GPU. Only
    /// an unexpected `GpuError` variant fails it.
    ///
    /// NOTE(review): the enclosing module is gated on
    /// `#[cfg(all(test, target_os = "linux"))]`, so the
    /// `#[cfg(not(target_os = "linux"))]` arm below is never compiled. The
    /// "non_linux" behaviour named by this test is therefore not exercised.
    /// Consider moving it to a module without the Linux gate.
    #[test]
    pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
        #[cfg(target_os = "linux")]
        {
            let buffers = make_buffers(1, 4, 1, 1);
            let inputs = minimal_inputs(&buffers);
            // Any success or "no usable driver/kernel" outcome is acceptable here.
            match launch_bms_flex_row_kernel(inputs) {
                Ok(_) => { }
                Err(GpuError::DriverLibraryUnavailable { .. })
                | Err(GpuError::DriverCallFailed { .. })
                | Err(GpuError::DriverSymbolMissing { .. })
                | Err(GpuError::NoDeviceKernel { .. }) => { }
                Err(other) => panic!("unexpected GpuError variant: {other:?}"),
            }
        }
        #[cfg(not(target_os = "linux"))]
        {
            let buffers = make_buffers(1, 4, 1, 1);
            let inputs = minimal_inputs(&buffers);
            match launch_bms_flex_row_kernel(inputs) {
                Err(GpuError::DriverLibraryUnavailable { reason }) => {
                    assert!(
                        reason.contains("Linux-only"),
                        "expected Linux-only hint, got: {reason}"
                    );
                }
                other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
            }
        }
    }
3129
3130 #[test]
3131 pub(crate) fn s_f_must_be_positive_and_finite() {
3132 let buffers = make_buffers(1, 4, 1, 1);
3133 let mut inputs = minimal_inputs(&buffers);
3134 inputs.s_f = 0.0;
3135 match launch_bms_flex_row_kernel(inputs) {
3136 Err(GpuError::DriverCallFailed { reason }) => {
3137 assert!(reason.contains("s_f"), "got: {reason}");
3138 }
3139 other => panic!("expected DriverCallFailed for s_f=0, got {other:?}"),
3140 }
3141 }
3142
    /// 1 / (2π); scales every moment-weighted cell integral in the CPU oracle.
    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
    /// √2; maps a standard-normal argument x onto `erfc(-x / √2)`.
    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// 1 / √(2π), the standard-normal density at zero.
    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
3160
    /// Host reference for the scaled complementary error function
    /// `erfcx(x) = exp(x²) · erfc(x)` on `x >= 0`.
    ///
    /// * `+∞ → 0`, and `-∞ → +∞`. NaN also returns `+∞`, because `NaN > 0.0`
    ///   is false. NOTE(review): `oracle_log_ndtr_and_mills` filters NaN
    ///   before calling, so this only matters for direct callers.
    /// * `x <= 0` returns `1.0 = erfcx(0)`; negative inputs are clamped, not
    ///   evaluated.
    /// * `0 < x < 26` uses the direct product.
    /// * `x >= 26` uses the asymptotic series
    ///   `1/(x√π) · (1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸))`.
    pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
        if !x.is_finite() {
            return if x > 0.0 { 0.0 } else { f64::INFINITY };
        }
        if x <= 0.0 {
            return 1.0;
        }
        if x < 26.0 {
            let mut xx = x * x;
            // Defensive only: x < 26 implies x² < 676, so this never fires.
            if xx > 700.0 {
                xx = 700.0;
            }
            return xx.exp() * gam_gpu::numerics_host::erfc(x);
        }
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        // Truncated asymptotic expansion; the operation order is kept as-is
        // so results stay bit-stable.
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        // 1 / √π
        let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
        inv * poly * inv_sqrt_pi
    }
3182
    /// Host reference returning `(ln Φ(x), φ(x) / Φ(x))` for the standard
    /// normal CDF Φ and density φ (the second value is the inverse Mills
    /// ratio λ).
    ///
    /// Special values: `+∞ → (0, 0)`, `-∞ → (-∞, +∞)`, `NaN → (NaN, NaN)`.
    pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
        if x == f64::INFINITY {
            return (0.0, 0.0);
        }
        if x == f64::NEG_INFINITY {
            return (f64::NEG_INFINITY, f64::INFINITY);
        }
        if x.is_nan() {
            return (x, x);
        }
        // Below this point the direct route is abandoned: Φ(-37) ≈ 6e-300,
        // right at the 1e-300 floor used in the direct branch.
        const ORACLE_LEFT_TAIL_X: f64 = -37.0;
        if x >= ORACLE_LEFT_TAIL_X {
            // Direct route: Φ(x) = ½·erfc(−x/√2), clamped into [1e-300, 1].
            let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
            if cdf < 1e-300 {
                cdf = 1e-300;
            }
            if cdf > 1.0 {
                cdf = 1.0;
            }
            let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
            (cdf.ln(), pdf / cdf)
        } else {
            // Left tail: with u = −x/√2 (> 26 here), Φ(x) = ½·e^{−u²}·erfcx(u),
            // so ln Φ = −u² + ln(½·erfcx(u)) and λ = √(2/π) / erfcx(u).
            let u = -x / ORACLE_SQRT_2;
            let mut ex = oracle_erfcx_nonnegative(u);
            if ex < 1e-300 {
                ex = 1e-300;
            }
            let log_cdf = -u * u + (0.5 * ex).ln();
            // √(2/π)
            let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
            (log_cdf, sqrt_2_over_pi / ex)
        }
    }
3227
    /// Host-only reference implementation of the row kernel, used to check
    /// device outputs.
    ///
    /// For every row it does the following:
    /// 1. Accumulates moment-weighted cell integrals (`f_a`, `f_aa`, `f_u`,
    ///    `f_au`, `f_uv`) over the row's cells
    ///    `cell_offsets[row]..cell_offsets[row + 1]`.
    /// 2. Replaces the index-0 terms with the row's `mu_1` / `mu_2`.
    /// 3. Derives `a_u` and `a_uv` from those integrals.
    /// 4. Chains them through `chi_obs`, `xi_obs`, `rho_u`, `tau_u` and
    ///    `r_uv` into observation derivatives `ē_u`, `ē_uv`.
    /// 5. Applies the weighted probit negative log-likelihood
    ///    `-w · ln Φ(s · ē_0)` with `s = 2y − 1`.
    ///
    /// Returns `neglog` (length `n_rows`), `grad` (`n_rows * r`) and a
    /// symmetric `hess` (`n_rows * r * r`, an r×r block per row). Rows whose
    /// accumulated `f_a` is non-finite or `<= 0` are filled with NaN.
    /// `q`, `b`, `z_obs` and `s_f` are not read.
    ///
    /// # Panics
    /// Panics if `inputs.cell_moments` is device-resident, or if any input
    /// slice is shorter than the indexing below requires.
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        // The oracle can only read host-resident moments.
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // Per-row accumulators: f_a / f_aa are scalars, f_u / f_au have
            // length r, and f_uv is r×r (only the upper triangle is filled).
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                // This cell's MOMENT_STRIDE moments.
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // t[k] = (1/2π) · Σ_{e<4} c_e · m[e + k] for k in 0..7.
                // The largest index read is 3 + 6 = 9 < MOMENT_STRIDE.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // d_of(x) = (1/2π) · Σ_{k<4} x_k · m_k: a 4-coefficient block
                // against the first four moments.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // q_of(x, y) = Σ_{k<7} (Σ_{i+j=k} x_i·y_j) · t[k]: the product
                // of two 4-coefficient blocks, weighted by t.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // Index 0 is skipped; it is overwritten from mu_1 / mu_2 below.
                for u in 1..r {
                    // cell_r / cell_ar hold (r - 1) 4-coefficient blocks per cell.
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        // Extra second-order blocks exist only in row u == 1:
                        // sbb at v == 1, sbh for v in 2..2+p_h, and sbw for
                        // v in 2+p_h..r. All other pairs get only -q_uv.
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // Index 0 comes from the row-level mu_1 / mu_2 rather than the
            // cells: row and column 0 of f_uv are cleared, then the (0, 0)
            // entry is set.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // f_a is used as a divisor; a non-finite or non-positive value
            // poisons the whole row with NaN.
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            let inv_fa = 1.0 / f_a;

            // a_u = -f_u / f_a. a_u[0] = mu_1 / f_a agrees with f_u[0] = -mu_1.
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            // a_uv = -(f_uv + f_au[v]·a_u + f_au[u]·a_v + f_aa·a_u·a_v) / f_a,
            // computed from the upper triangle of f_uv and mirrored.
            // NOTE(review): this reads as implicit differentiation of
            // F(a, θ) = 0 with F_a = f_a. The interpretation is inferred from
            // the formulas; confirm against the kernel derivation.
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Observation derivatives:
            //   ē_u  = χ·a_u + ρ_u
            //   ē_uv = χ·a_uv + ξ·a_u·a_v + τ_u·a_v + a_u·τ_v + R_uv
            // (upper triangle of R read, result mirrored).
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Weighted probit on the margin m = s·ē_0, with s = 2y − 1 and
            // λ = φ(m)/Φ(m):
            //   neglog  = −w·ln Φ(m)
            //   grad_u  = A·ē_u,                with A = −w·s·λ
            //   hess_uv = B·ē_u·ē_v + A·ē_uv,   with B = w·λ·(m + λ)
            let y = inputs.y[row];
            let w = inputs.w[row];
            let s = 2.0 * y - 1.0;
            let e_obs = bar_e_u[0];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3438
3439 pub(crate) fn make_parity_buffers() -> TestBuffers {
3443 let n = 4_usize;
3444 let r = 5_usize;
3445 let p_h = 2_usize;
3446 let p_w = 1_usize;
3447 let row_cells: [u32; 4] = [2, 3, 4, 2];
3449 let mut cell_offsets = vec![0_u32; n + 1];
3450 for i in 0..n {
3451 cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
3452 }
3453 let total_cells = cell_offsets[n] as usize;
3454
3455 let f = |seed: usize| -> f64 {
3457 let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
3458 0.1 + 0.4 * x
3459 };
3460
3461 let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
3462 let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
3463 let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
3464 let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
3465 let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
3466 let y = [1.0, 0.0, 1.0, 0.0].to_vec();
3467 let w = vec![1.0; n];
3468
3469 let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
3470 let cell_c1 = (0..total_cells)
3471 .map(|c| -f(c + 2002) * 0.5)
3472 .collect::<Vec<_>>();
3473 let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
3474 let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();
3475
3476 let cell_a = (0..total_cells * 4)
3477 .map(|i| f(i + 5005) * 0.3)
3478 .collect::<Vec<_>>();
3479 let cell_aa = (0..total_cells * 4)
3480 .map(|i| f(i + 6006) * 0.1)
3481 .collect::<Vec<_>>();
3482 let cell_r = (0..total_cells * (r - 1) * 4)
3483 .map(|i| f(i + 7007) * 0.2)
3484 .collect::<Vec<_>>();
3485 let cell_ar = (0..total_cells * (r - 1) * 4)
3486 .map(|i| f(i + 8008) * 0.05)
3487 .collect::<Vec<_>>();
3488 let cell_sbb = (0..total_cells * 4)
3489 .map(|i| f(i + 9009) * 0.08)
3490 .collect::<Vec<_>>();
3491 let cell_sbh = (0..total_cells * p_h * 4)
3492 .map(|i| f(i + 10_010) * 0.07)
3493 .collect::<Vec<_>>();
3494 let cell_sbw = (0..total_cells * p_w * 4)
3495 .map(|i| f(i + 11_011) * 0.06)
3496 .collect::<Vec<_>>();
3497 let cell_moments = (0..total_cells * MOMENT_STRIDE)
3498 .map(|i| 0.4 + 0.1 * f(i + 12_012))
3499 .collect::<Vec<_>>();
3500
3501 let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
3502 let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
3503 let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
3504 let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
3505 let r_uv = (0..n * r * r)
3506 .map(|i| 0.04 * f(i + 15_015))
3507 .collect::<Vec<_>>();
3508
3509 TestBuffers {
3510 q,
3511 b,
3512 mu_1,
3513 mu_2,
3514 z_obs,
3515 y,
3516 w,
3517 cell_offsets,
3518 cell_c0,
3519 cell_c1,
3520 cell_c2,
3521 cell_c3,
3522 cell_a,
3523 cell_aa,
3524 cell_r,
3525 cell_ar,
3526 cell_sbb,
3527 cell_sbh,
3528 cell_sbw,
3529 cell_moments,
3530 chi_obs,
3531 xi_obs,
3532 rho_u,
3533 tau_u,
3534 r_uv,
3535 }
3536 }
3537
3538 pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3539 BmsFlexRowKernelInputs {
3540 n_rows: 4,
3541 r: 5,
3542 p_h: 2,
3543 p_w: 1,
3544 q: &buffers.q,
3545 b: &buffers.b,
3546 mu_1: &buffers.mu_1,
3547 mu_2: &buffers.mu_2,
3548 z_obs: &buffers.z_obs,
3549 y: &buffers.y,
3550 w: &buffers.w,
3551 s_f: 1.0,
3552 cell_offsets: &buffers.cell_offsets,
3553 cell_c0: &buffers.cell_c0,
3554 cell_c1: &buffers.cell_c1,
3555 cell_c2: &buffers.cell_c2,
3556 cell_c3: &buffers.cell_c3,
3557 cell_a: &buffers.cell_a,
3558 cell_aa: &buffers.cell_aa,
3559 cell_r: &buffers.cell_r,
3560 cell_ar: &buffers.cell_ar,
3561 cell_sbb: &buffers.cell_sbb,
3562 cell_sbh: &buffers.cell_sbh,
3563 cell_sbw: &buffers.cell_sbw,
3564 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3565 chi_obs: &buffers.chi_obs,
3566 xi_obs: &buffers.xi_obs,
3567 rho_u: &buffers.rho_u,
3568 tau_u: &buffers.tau_u,
3569 r_uv: &buffers.r_uv,
3570 }
3571 }
3572
/// The CPU oracle must emit finite neglog / gradient / Hessian values on the
/// parity fixture, and every per-row Hessian must be bit-for-bit symmetric.
#[test]
pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
    let buffers = make_parity_buffers();
    let inputs = parity_inputs(&buffers);
    inputs
        .validate()
        .expect("parity fixture must satisfy validate()");
    let out = cpu_oracle_outputs(&inputs);
    let (n, r) = (inputs.n_rows, inputs.r);
    let block_len = r * r;

    assert_eq!(out.neglog.len(), n);
    assert_eq!(out.grad.len(), n * r);
    assert_eq!(out.hess.len(), n * block_len);

    for (row, &neglog) in out.neglog.iter().enumerate() {
        assert!(
            neglog.is_finite(),
            "row {row}: neglog must be finite, got {}",
            neglog
        );
        let grad_row = &out.grad[row * r..(row + 1) * r];
        let hess_row = &out.hess[row * block_len..(row + 1) * block_len];
        for (u, &g) in grad_row.iter().enumerate() {
            assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
            for v in 0..r {
                let huv = hess_row[u * r + v];
                let hvu = hess_row[v * r + u];
                assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
                // Exact (bitwise) symmetry, not merely within a tolerance.
                assert_eq!(
                    huv.to_bits(),
                    hvu.to_bits(),
                    "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
                );
            }
        }
    }
}
3611
/// Check the Mills-ratio layer of the CPU oracle against finite differences.
///
/// For a probit term `neglog(e) = -w · log_cdf(s·e)` with `s = 2y − 1`, the
/// closed forms under test are
/// `A = ∂neglog/∂e = −w·s·λ(m)` and `B = ∂²neglog/∂e² = w·λ(m)·(m + λ(m))`,
/// with `m = s·e` and `λ` the Mills quantity returned by
/// `oracle_log_ndtr_and_mills`. Both are compared with fourth-order central
/// differences of `neglog` taken at step `h = 1e-3`.
#[test]
pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
    // Label y ∈ {0, 1} mapped to the probit sign s ∈ {−1, +1}.
    let sign_of = |y: f64| 2.0 * y - 1.0;
    let neglog_of = |e: f64, y: f64, w: f64| -> f64 {
        let (log_cdf, _) = oracle_log_ndtr_and_mills(sign_of(y) * e);
        -w * log_cdf
    };
    let ab_of = |e: f64, y: f64, w: f64| -> (f64, f64) {
        let s = sign_of(y);
        let m_arg = s * e;
        let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
        (-w * s * lambda, w * lambda * (m_arg + lambda))
    };

    // Five-point central stencils in `e`, returning (first, second) derivative.
    let step = 1e-3_f64;
    let stencil = |e: f64, y: f64, w: f64| -> (f64, f64) {
        let at = |x: f64| neglog_of(x, y, w);
        let (fp2, fp1, f0, fm1, fm2) = (
            at(e + 2.0 * step),
            at(e + step),
            at(e),
            at(e - step),
            at(e - 2.0 * step),
        );
        let first = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * step);
        let second = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * step * step);
        (first, second)
    };

    // (e, y, w): five points for y = 1, five for y = 0, then two non-unit weights.
    let cases: [(f64, f64, f64); 12] = [
        (-1.6, 1.0, 1.0),
        (-0.7, 1.0, 1.0),
        (0.0, 1.0, 1.0),
        (0.9, 1.0, 1.0),
        (1.8, 1.0, 1.0),
        (-1.4, 0.0, 1.0),
        (-0.3, 0.0, 1.0),
        (0.0, 0.0, 1.0),
        (0.6, 0.0, 1.0),
        (1.5, 0.0, 1.0),
        (0.4, 1.0, 0.75),
        (-0.8, 0.0, 1.3),
    ];

    for (e, y, w) in cases {
        let (a_ana, b_ana) = ab_of(e, y, w);
        let (d1_fd, d2_fd) = stencil(e, y, w);

        // Accept either an absolute or a magnitude-relative match.
        let a_abs = (a_ana - d1_fd).abs();
        let a_rel = a_abs / a_ana.abs().max(1.0);
        assert!(
            a_abs <= 5e-8 || a_rel <= 5e-8,
            "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
             analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
        );

        // The second-derivative stencil loses more precision, hence the looser bound.
        let b_abs = (b_ana - d2_fd).abs();
        let b_rel = b_abs / b_ana.abs().max(1.0);
        assert!(
            b_abs <= 5e-6 || b_rel <= 5e-6,
            "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
             analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
        );
    }
}
3712
/// Compare the CUDA row kernel against the CPU oracle on the parity fixture.
///
/// Skips, with a note on stderr, on non-Linux hosts and when no CUDA runtime
/// is available. Once a runtime exists, a launch failure is a hard failure.
/// Every neglog / gradient / Hessian entry must agree within
/// `1e-8 + 1e-8·|cpu|` (a NaN must be matched by a NaN), and each GPU row
/// Hessian must be bit-for-bit symmetric.
#[test]
pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling test)"
        );
        return;
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row parity] no CUDA runtime — skipping device \
                 parity (CPU oracle exercised by sibling test)"
            );
            return;
        };
        let buffers = make_parity_buffers();
        let oracle_inputs = parity_inputs(&buffers);
        oracle_inputs
            .validate()
            .expect("parity fixture must satisfy validate()");
        let cpu_out = cpu_oracle_outputs(&oracle_inputs);

        // The launcher takes its inputs by value, so borrow the buffers afresh.
        let gpu_out = launch_bms_flex_row_kernel(parity_inputs(&buffers)).unwrap_or_else(|err| {
            panic!(
                "[bms_flex_row parity] launch failed on CUDA-selected host; \
                 device/oracle parity must fail loudly on GPU CI: {err}"
            )
        });

        let (n, r) = (oracle_inputs.n_rows, oracle_inputs.r);
        let (tol_abs, tol_rel) = (1e-8_f64, 1e-8_f64);
        let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
            if cpu.is_nan() || gpu.is_nan() {
                assert!(
                    cpu.is_nan() && gpu.is_nan(),
                    "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
                );
                return;
            }
            let diff = (cpu - gpu).abs();
            let tol = tol_abs + tol_rel * cpu.abs();
            assert!(
                diff <= tol,
                "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
                 cpu={cpu:.17e}, gpu={gpu:.17e}"
            );
        };

        let families = [
            ("neglog", &cpu_out.neglog[..], &gpu_out.neglog[..]),
            ("grad", &cpu_out.grad[..], &gpu_out.grad[..]),
            ("hess", &cpu_out.hess[..], &gpu_out.hess[..]),
        ];
        // All length checks run before any value comparison.
        for &(_, cpu_values, gpu_values) in &families {
            assert_eq!(cpu_values.len(), gpu_values.len());
        }
        for &(label, cpu_values, gpu_values) in &families {
            for (i, (&c, &g)) in cpu_values.iter().zip(gpu_values).enumerate() {
                check_close(label, i, c, g);
            }
        }

        for row in 0..n {
            let hess_row = &gpu_out.hess[row * r * r..(row + 1) * r * r];
            for u in 0..r {
                for v in 0..r {
                    let upper = hess_row[u * r + v];
                    let lower = hess_row[v * r + u];
                    assert_eq!(
                        upper.to_bits(),
                        lower.to_bits(),
                        "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
                    );
                }
            }
        }
    }
}
3805
/// The CUDA row-kernel source must keep naming the CPU routines it mirrors
/// (`compute_row_analytic_flex_from_parts_into` and
/// `cell_first_derivative_from_moments`). `ROW_KERNEL_BODY` is only referenced
/// under `cfg(target_os = "linux")`, so on other hosts this test does nothing.
#[test]
pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
    #[cfg(target_os = "linux")]
    {
        assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
        assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
    }
}
3817
3818 pub(crate) fn cpu_oracle_bms_flex_row_hvp(
3823 row_hessians: &[f64],
3824 marginal_design: &[f64],
3825 logslope_design: &[f64],
3826 block: &BmsFlexBlockLayout,
3827 primary: &BmsFlexPrimaryLayout,
3828 n: usize,
3829 v: &[f64],
3830 ) -> Vec<f64> {
3831 let r = primary.r;
3832 let p_m = block.p_m;
3833 let p_g = block.p_g;
3834 assert_eq!(v.len(), block.p_total);
3835 assert_eq!(row_hessians.len(), n * r * r);
3836 assert_eq!(marginal_design.len(), n * p_m);
3837 assert_eq!(logslope_design.len(), n * p_g);
3838 let mut out = vec![0.0_f64; block.p_total];
3839 let mut row_dir = vec![0.0_f64; r];
3840 let mut action = vec![0.0_f64; r];
3841 for row in 0..n {
3842 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3843 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3844 let mut acc_q = 0.0_f64;
3845 for j in 0..p_m {
3846 acc_q += mrow[j] * v[j];
3847 }
3848 let mut acc_g = 0.0_f64;
3849 for j in 0..p_g {
3850 acc_g += grow[j] * v[p_m + j];
3851 }
3852 row_dir[0] = acc_q;
3853 row_dir[1] = acc_g;
3854 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3855 for (k, ii) in prange.clone().enumerate() {
3856 row_dir[ii] = v[brange.start + k];
3857 }
3858 }
3859 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3860 for (k, ii) in prange.clone().enumerate() {
3861 row_dir[ii] = v[brange.start + k];
3862 }
3863 }
3864 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3865 for u in 0..r {
3866 let mut acc = 0.0_f64;
3867 for v_idx in 0..r {
3868 acc += h_slice[u * r + v_idx] * row_dir[v_idx];
3869 }
3870 action[u] = acc;
3871 }
3872 let a0 = action[0];
3873 for j in 0..p_m {
3874 out[j] += a0 * mrow[j];
3875 }
3876 let a1 = action[1];
3877 for j in 0..p_g {
3878 out[p_m + j] += a1 * grow[j];
3879 }
3880 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3881 for (k, ii) in prange.clone().enumerate() {
3882 out[brange.start + k] += action[ii];
3883 }
3884 }
3885 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3886 for (k, ii) in prange.clone().enumerate() {
3887 out[brange.start + k] += action[ii];
3888 }
3889 }
3890 }
3891 out
3892 }
3893
3894 pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
3895 row_hessians: &[f64],
3896 marginal_design: &[f64],
3897 logslope_design: &[f64],
3898 block: &BmsFlexBlockLayout,
3899 primary: &BmsFlexPrimaryLayout,
3900 n: usize,
3901 ) -> Vec<f64> {
3902 let r = primary.r;
3903 let p_m = block.p_m;
3904 let p_g = block.p_g;
3905 let mut out = vec![0.0_f64; block.p_total];
3906 for row in 0..n {
3907 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3908 let h00 = h_slice[0];
3909 let h11 = h_slice[r + 1];
3910 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3911 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3912 for j in 0..p_m {
3913 out[j] += h00 * mrow[j] * mrow[j];
3914 }
3915 for j in 0..p_g {
3916 out[p_m + j] += h11 * grow[j] * grow[j];
3917 }
3918 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3919 for (k, ii) in prange.clone().enumerate() {
3920 out[brange.start + k] += h_slice[ii * r + ii];
3921 }
3922 }
3923 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3924 for (k, ii) in prange.clone().enumerate() {
3925 out[brange.start + k] += h_slice[ii * r + ii];
3926 }
3927 }
3928 }
3929 out
3930 }
3931
/// Hand-check the CPU HVP oracle on a small dense fixture, without the GPU
/// path.
///
/// Every output coordinate (both marginal, both log-slope, the `h` and the `w`
/// coefficient) is recomputed independently from the pull-back / Hessian /
/// push-forward definition. Previously only `out[0]` was checked, which left
/// scatter bugs in the other blocks undetected.
#[test]
pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
    let n = 4_usize;
    let r = 4_usize;
    let p_m = 2_usize;
    let p_g = 2_usize;
    let p_h_dim = 1_usize;
    let p_w_dim = 1_usize;
    let p_total = p_m + p_g + p_h_dim + p_w_dim;
    // Coefficient indices of the single h and w coefficients.
    let h_coeff = p_m + p_g;
    let w_coeff = p_m + p_g + p_h_dim;
    let block = BmsFlexBlockLayout {
        p_m,
        p_g,
        h: Some(p_m + p_g..p_m + p_g + p_h_dim),
        w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
        p_total,
    };
    let primary = BmsFlexPrimaryLayout {
        h: Some(2..3),
        w: Some(3..4),
        r,
    };
    // Symmetric row Hessians: H_row[u,v] = (row + 1)·(1 + u + 2·v) for u ≤ v.
    let mut row_hessians = vec![0.0_f64; n * r * r];
    for row in 0..n {
        for u in 0..r {
            for v in u..r {
                let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
                row_hessians[row * r * r + u * r + v] = val;
                row_hessians[row * r * r + v * r + u] = val;
            }
        }
    }
    let mut marginal = vec![0.0_f64; n * p_m];
    for row in 0..n {
        for j in 0..p_m {
            marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
        }
    }
    let mut logslope = vec![0.0_f64; n * p_g];
    for row in 0..n {
        for j in 0..p_g {
            logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
        }
    }
    let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
    let out = cpu_oracle_bms_flex_row_hvp(
        &row_hessians,
        &marginal,
        &logslope,
        &block,
        &primary,
        n,
        &v,
    );
    assert_eq!(out.len(), p_total);
    assert!(out.iter().all(|x| x.is_finite()));

    // Independent hand computation of the full HVP, written out per slot.
    let mut expect = vec![0.0_f64; p_total];
    for row in 0..n {
        let mrow = &marginal[row * p_m..(row + 1) * p_m];
        let grow = &logslope[row * p_g..(row + 1) * p_g];
        // Primary-space direction: [marginal·v_m, logslope·v_g, v_h, v_w].
        let row_dir = [
            mrow[0] * v[0] + mrow[1] * v[1],
            grow[0] * v[p_m] + grow[1] * v[p_m + 1],
            v[h_coeff],
            v[w_coeff],
        ];
        let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
        let mut action = [0.0_f64; 4];
        for (u, action_u) in action.iter_mut().enumerate() {
            for (vv, &d) in row_dir.iter().enumerate() {
                *action_u += h_slice[u * r + vv] * d;
            }
        }
        expect[0] += action[0] * mrow[0];
        expect[1] += action[0] * mrow[1];
        expect[p_m] += action[1] * grow[0];
        expect[p_m + 1] += action[1] * grow[1];
        expect[h_coeff] += action[2];
        expect[w_coeff] += action[3];
    }
    for (j, (&got, &want)) in out.iter().zip(&expect).enumerate() {
        assert!(
            (got - want).abs() < 1e-12,
            "cpu oracle HVP out[{j}] mismatch: {got} vs hand-check {want}"
        );
    }
}
4018
/// Hand-check the diagonal oracle on a purely diagonal row-Hessian fixture.
///
/// The first marginal coefficient must collect `Σ_rows H[0,0]·x_0²`, and the
/// single `h` coefficient must collect `Σ_rows H[2,2]`.
#[test]
pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
    let (n, r) = (3_usize, 4_usize);
    let (p_m, p_g, p_h_dim, p_w_dim) = (2_usize, 2_usize, 1_usize, 1_usize);
    let p_total = p_m + p_g + p_h_dim + p_w_dim;
    let h_start = p_m + p_g;
    let w_start = h_start + p_h_dim;
    let block = BmsFlexBlockLayout {
        p_m,
        p_g,
        h: Some(h_start..w_start),
        w: Some(w_start..w_start + p_w_dim),
        p_total,
    };
    let primary = BmsFlexPrimaryLayout {
        h: Some(2..3),
        w: Some(3..4),
        r,
    };

    // Diagonal-only row Hessians: H_row[u,u] = 1 + row + 0.5·u.
    let mut row_hessians = vec![0.0_f64; n * r * r];
    for (row, hess) in row_hessians.chunks_exact_mut(r * r).enumerate() {
        for u in 0..r {
            hess[u * r + u] = 1.0 + (row as f64) + (u as f64) * 0.5;
        }
    }
    let marginal: Vec<f64> = (0..n * p_m)
        .map(|idx| {
            let (row, j) = (idx / p_m, idx % p_m);
            0.2 + (row as f64) * 0.3 + (j as f64) * 0.1
        })
        .collect();
    let logslope: Vec<f64> = (0..n * p_g)
        .map(|idx| {
            let (row, j) = (idx / p_g, idx % p_g);
            -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2
        })
        .collect();

    let out = cpu_oracle_bms_flex_row_diagonal(
        &row_hessians,
        &marginal,
        &logslope,
        &block,
        &primary,
        n,
    );

    let expect = (0..n).fold(0.0_f64, |acc, row| {
        acc + row_hessians[row * r * r] * marginal[row * p_m].powi(2)
    });
    assert!(
        (out[0] - expect).abs() < 1e-12,
        "out[0] {} vs {}",
        out[0],
        expect
    );

    let expect_h = (0..n).fold(0.0_f64, |acc, row| {
        acc + row_hessians[row * r * r + 2 * r + 2]
    });
    assert!(
        (out[h_start] - expect_h).abs() < 1e-12,
        "h slot {} vs {}",
        out[h_start],
        expect_h
    );
}
4090
/// CUDA parity for the single-RHS HVP and diagonal kernels on the small fixture
/// (`n = 4`, `r = 4`, `p_total = 6`).
///
/// Both device results must match the CPU oracles within an absolute `1e-10`.
/// Skips, with a note on stderr, on non-Linux hosts or when no CUDA runtime
/// is present.
#[test]
pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
                 parity"
            );
            return;
        };
        let (n, r) = (4_usize, 4_usize);
        let (p_m, p_g, p_h_dim, p_w_dim) = (2_usize, 2_usize, 1_usize, 1_usize);
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let h_start = p_m + p_g;
        let w_start = h_start + p_h_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_start..w_start),
            w: Some(w_start..w_start + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };

        // Symmetric row Hessians: H_row[u,v] = (row + 1)·(1 + u + 2·v) for u ≤ v.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, hess) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            let scale = (row + 1) as f64;
            for u in 0..r {
                for v in u..r {
                    let entry = scale * (1.0 + (u as f64) + 2.0 * (v as f64));
                    hess[u * r + v] = entry;
                    hess[v * r + u] = entry;
                }
            }
        }
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|idx| {
                let (row, j) = (idx / p_m, idx % p_m);
                0.5 + (row as f64) * 0.1 - (j as f64) * 0.2
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|idx| {
                let (row, j) = (idx / p_g, idx % p_g);
                -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15
            })
            .collect();
        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();

        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );
        let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
        );

        let backend = HvpKernelBackend::probe()
            .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
        let d_m = stream
            .clone_htod(&marginal)
            .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
        let d_g = stream
            .clone_htod(&logslope)
            .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
        // Resident footprint: row Hessians plus both design matrices, in bytes.
        let resident_bytes =
            ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64;
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: resident_bytes,
        };

        let gpu_hvp =
            launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
        let gpu_diag = launch_bms_flex_row_diagonal(&storage)
            .expect("diagonal kernel must launch on CUDA host");
        assert_eq!(gpu_hvp.len(), cpu_hvp.len());
        assert_eq!(gpu_diag.len(), cpu_diag.len());

        let paired = cpu_hvp
            .iter()
            .zip(gpu_hvp.iter())
            .zip(cpu_diag.iter().zip(gpu_diag.iter()));
        for (i, ((&hvp_cpu, &hvp_gpu), (&diag_cpu, &diag_gpu))) in paired.enumerate() {
            let diff = (hvp_cpu - hvp_gpu).abs();
            assert!(
                diff <= 1e-10,
                "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
                hvp_cpu,
                hvp_gpu
            );
            let ddiff = (diag_cpu - diag_gpu).abs();
            assert!(
                ddiff <= 1e-10,
                "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
                diag_cpu,
                diag_gpu
            );
        }
    }
}
4226
/// The multi-RHS HVP scratch budget must tile over row chunks, checked at a
/// production-like shape (195k rows, `r = 20`, `p_total = 44`, 4 RHS).
///
/// Scratch must stay below 1% of what one full row-Hessian copy per RHS would
/// need. RHS counts above `BMS_FLEX_ROW_HVP_MAX_RHS` must be rejected.
#[test]
pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
    let (n, r, p_total, rhs_count) = (195_000_usize, 20_usize, 44_usize, 4_usize);
    let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
        .expect("large-scale multi-RHS scratch budget");

    // Bytes for one full copy of every row Hessian, then one such copy per RHS.
    let full_row_cache_bytes = (n * r * r * std::mem::size_of::<f64>()) as u64;
    let per_rhs_full_row_cache = full_row_cache_bytes * rhs_count as u64;
    assert!(
        scratch < per_rhs_full_row_cache / 100,
        "multi-RHS scratch must tile by row chunks instead of materializing \
         a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
    );

    let too_many_rhs = BMS_FLEX_ROW_HVP_MAX_RHS + 1;
    let over_limit = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, too_many_rhs);
    assert!(
        over_limit.is_err(),
        "multi-RHS launch must reject unbounded RHS counts"
    );
}
4252
/// CUDA parity for the multi-RHS HVP kernel on a 5-row, `r = 4` fixture with
/// 3 right-hand sides.
///
/// Each RHS must match the CPU oracle within an absolute `1e-10` and must be
/// exactly equal to the single-RHS host launch. The multi-RHS scratch budget
/// must also stay below the resident cache size.
///
/// Unlike before, the device path is compiled only on Linux, matching the
/// sibling CUDA parity tests. `HvpKernelBackend`, `DeviceResidentRowHess` and
/// the stream upload calls are CUDA/Linux-only, so the ungated version did not
/// build on other hosts.
#[test]
pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp_multi parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
            return;
        };
        let n = 5_usize;
        let r = 4_usize;
        let p_m = 2_usize;
        let p_g = 2_usize;
        let p_h_dim = 1_usize;
        let p_w_dim = 1_usize;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let rhs_count = 3_usize;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };
        // Symmetric row Hessians: H_row[u,v] = (row + 1)·(1 + u + 2·v) for u ≤ v.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            for u in 0..r {
                for v in u..r {
                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
                    row_hessians[row * r * r + u * r + v] = val;
                    row_hessians[row * r * r + v * r + u] = val;
                }
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_m {
                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
            }
            for j in 0..p_g {
                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
            }
        }
        // RHS-major directions: v_rhs[rhs * p_total + j].
        let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
        for rhs in 0..rhs_count {
            for j in 0..p_total {
                let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
                v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
            }
        }

        let backend = HvpKernelBackend::probe()
            .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
        let d_m = stream
            .clone_htod(&marginal)
            .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
        let d_g = stream
            .clone_htod(&logslope)
            .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
            .expect("storage scratch budget");
        assert!(
            scratch < storage.bytes,
            "multi-RHS scratch should stay below resident cache bytes"
        );
        let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
            .expect("multi-RHS HVP kernel must launch on CUDA host");
        assert_eq!(gpu.len(), rhs_count * p_total);
        for rhs in 0..rhs_count {
            let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
            let cpu = cpu_oracle_bms_flex_row_hvp(
                &row_hessians,
                &marginal,
                &logslope,
                &block,
                &primary,
                n,
                v,
            );
            let single = launch_bms_flex_row_hvp(&storage, v)
                .expect("single-RHS HVP kernel must launch on CUDA host");
            for j in 0..p_total {
                let got = gpu[rhs * p_total + j];
                let diff = (cpu[j] - got).abs();
                assert!(
                    diff <= 1e-10,
                    "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
                    cpu[j],
                    got
                );
                // The multi-RHS path must reproduce the single-RHS launch exactly.
                assert_eq!(
                    got, single[j],
                    "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
                );
            }
        }
    }
}
4371
/// CUDA parity for the device-output HVP entry point
/// (`launch_bms_flex_row_hvp_into_device`).
///
/// The result downloaded from the device buffer must match the CPU oracle
/// within an absolute `1e-10`. It must also be exactly equal (zero difference)
/// to the host-output launch on the same storage. Skips, with a note on
/// stderr, on non-Linux hosts or when no CUDA runtime is present.
#[test]
pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
             CUDA parity (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
                 skipping device parity"
            );
            return;
        };
        let (n, r) = (4_usize, 4_usize);
        let (p_m, p_g, p_h_dim, p_w_dim) = (2_usize, 2_usize, 1_usize, 1_usize);
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let h_start = p_m + p_g;
        let w_start = h_start + p_h_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_start..w_start),
            w: Some(w_start..w_start + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };

        // Symmetric row Hessians: H_row[u,v] = (row + 1)·(1 + u + 2·v) for u ≤ v.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, hess) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            let scale = (row + 1) as f64;
            for u in 0..r {
                for v in u..r {
                    let entry = scale * (1.0 + (u as f64) + 2.0 * (v as f64));
                    hess[u * r + v] = entry;
                    hess[v * r + u] = entry;
                }
            }
        }
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|idx| {
                let (row, j) = (idx / p_m, idx % p_m);
                0.5 + (row as f64) * 0.1 - (j as f64) * 0.2
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|idx| {
                let (row, j) = (idx / p_g, idx % p_g);
                -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15
            })
            .collect();
        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );

        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host");
        let d_m = stream.clone_htod(&marginal).expect(
            "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
        );
        let d_g = stream.clone_htod(&logslope).expect(
            "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
        );
        // Resident footprint: row Hessians plus both design matrices, in bytes.
        let resident_bytes =
            ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64;
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: resident_bytes,
        };

        // Reference: the host-output launch on the same resident storage.
        let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
            .expect("host-out HVP kernel must launch on CUDA host");

        // Device-output launch: upload the direction, write into a zeroed device
        // buffer, synchronize, then download the result.
        let d_v = stream
            .clone_htod(&v)
            .expect("upload direction for device-out HVP");
        let mut d_out = stream
            .alloc_zeros::<f64>(p_total)
            .expect("alloc device-out HVP output");
        launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
            .expect("device-out HVP kernel must launch on CUDA host");
        stream
            .synchronize()
            .expect("synchronize after device-out HVP");
        let device_out_hvp = stream
            .clone_dtoh(&d_out)
            .expect("download device-out HVP output");

        assert_eq!(device_out_hvp.len(), cpu_hvp.len());
        assert_eq!(device_out_hvp.len(), host_out_hvp.len());
        let triples = cpu_hvp
            .iter()
            .zip(device_out_hvp.iter())
            .zip(host_out_hvp.iter());
        for (i, ((&cpu_val, &device_val), &host_val)) in triples.enumerate() {
            let diff = (cpu_val - device_val).abs();
            assert!(
                diff <= 1e-10,
                "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
                cpu_val,
                device_val
            );
            // The two launch paths must agree exactly, not merely within tolerance.
            let host_diff = (host_val - device_val).abs();
            assert!(
                host_diff == 0.0,
                "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
                host_val,
                device_val
            );
        }
    }
}
4522
/// CUDA parity for the HVP and diagonal kernels at a production-like width:
/// `n = 64`, `r = 20` (2 + 10 `h` + 8 `w` primaries), `p_total = 44`.
///
/// Row Hessians are deterministic trigonometric fills, symmetrized, with `r`
/// added to each diagonal entry. The designs and the direction are
/// deterministic trigonometric fills as well. Every HVP and diagonal entry
/// must match the CPU oracle within an absolute `1e-8`.
///
/// Skips on non-Linux hosts or when no CUDA runtime exists. Once a runtime is
/// present, a backend-probe or upload failure is a hard failure, as in the
/// sibling parity tests. Previously such failures printed a note and returned,
/// so the test silently passed on a broken CUDA host.
#[test]
pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
             skipping CUDA parity"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
                 skipping device parity"
            );
            return;
        };
        let n = 64_usize;
        let p_m = 14_usize;
        let p_g = 12_usize;
        let p_h_dim = 10_usize;
        let p_w_dim = 8_usize;
        let r = 2 + p_h_dim + p_w_dim;
        assert_eq!(r, 20);
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        assert_eq!(p_total, 44);
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            let base = row * r * r;
            // Dense deterministic fill (not yet symmetric).
            for u in 0..r {
                for v in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    row_hessians[base + u * r + v] = a;
                }
            }
            // Symmetrize by averaging mirrored entries, then shift the diagonal by r.
            for u in 0..r {
                for v in (u + 1)..r {
                    let upper = row_hessians[base + u * r + v];
                    let lower = row_hessians[base + v * r + u];
                    let sym = 0.5 * (upper + lower);
                    row_hessians[base + u * r + v] = sym;
                    row_hessians[base + v * r + u] = sym;
                }
                row_hessians[base + u * r + u] += r as f64;
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }
        let v: Vec<f64> = (0..p_total)
            .map(|i| {
                let seed = (i as f64) * 0.157 + 0.6;
                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
            })
            .collect();

        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );
        let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
        );

        // A CUDA runtime is present, so backend and upload failures are real
        // failures; do not silently pass.
        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row hvp parity n64_r20_p44] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        let d_h = stream.clone_htod(&row_hessians).expect(
            "[bms_flex_row hvp parity n64_r20_p44] upload h must succeed on CUDA host",
        );
        let d_m = stream.clone_htod(&marginal).expect(
            "[bms_flex_row hvp parity n64_r20_p44] upload marg must succeed on CUDA host",
        );
        let d_g = stream.clone_htod(&logslope).expect(
            "[bms_flex_row hvp parity n64_r20_p44] upload logslope must succeed on CUDA host",
        );
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
            .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
        let gpu_diag = launch_bms_flex_row_diagonal(&storage)
            .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
        assert_eq!(gpu_hvp.len(), cpu_hvp.len());
        assert_eq!(gpu_diag.len(), cpu_diag.len());
        for i in 0..p_total {
            let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
            assert!(
                diff <= 1e-8,
                "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
                cpu_hvp[i],
                gpu_hvp[i]
            );
            let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
            assert!(
                ddiff <= 1e-8,
                "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
                cpu_diag[i],
                gpu_diag[i]
            );
        }
    }
}
4716
#[test]
/// Parity gate: the CUDA `bms_flex_row` dense-block kernel must reproduce the
/// host-side pullback `H = Σ_row Φ_rowᵀ · H_row · Φ_row`.
///
/// Uses a small synthetic problem (n = 24, r = 8, p_total = 14). In the host
/// reference, `Φ_row` is an r × p_total map:
/// - primary 0 carries the row's marginal design;
/// - primary 1 carries the row's log-slope design;
/// - the h and w primaries map one-to-one onto their coefficient blocks.
///
/// Every entry of the GPU result must agree with the host result to within
/// `1e-9 · max(|cpu|, |gpu|, 1)`.
///
/// Prints and returns on non-Linux hosts or when no CUDA runtime is available.
/// Once a runtime has been found, a failed backend probe or upload panics.
pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
            return;
        };

        // Problem shape: r = 2 + p_h + p_w primaries, p_total coefficients.
        let n = 24_usize;
        let (p_m, p_g, p_h_dim, p_w_dim) = (4_usize, 4_usize, 3_usize, 3_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        // Deterministic per-row r×r Hessians (row-major). Each one is
        // symmetrised and then shifted by r on its diagonal.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, hess) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            for u in 0..r {
                for col in 0..r {
                    let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (col as f64) * 0.47;
                    hess[u * r + col] = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
                }
            }
            for u in 0..r {
                for col in (u + 1)..r {
                    let sym = 0.5 * (hess[u * r + col] + hess[col * r + u]);
                    hess[u * r + col] = sym;
                    hess[col * r + u] = sym;
                }
                hess[u * r + u] += r as f64;
            }
        }

        // Row-major n × p_m marginal design and n × p_g log-slope design.
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|idx| {
                let (row, j) = (idx / p_m, idx % p_m);
                let seed = (row as f64) * 0.083 + (j as f64) * 0.171 + 0.31;
                seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|idx| {
                let (row, j) = (idx / p_g, idx % p_g);
                let seed = (row as f64) * 0.097 + (j as f64) * 0.143 - 0.15;
                seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2
            })
            .collect();

        // Host reference: accumulate Φ_rowᵀ H_row Φ_row over every row.
        let h_block_start = block.h.as_ref().map_or(0, |rg| rg.start);
        let h_block_len = block.h.as_ref().map_or(0, |rg| rg.len());
        let w_block_start = block.w.as_ref().map_or(0, |rg| rg.start);
        let w_block_len = block.w.as_ref().map_or(0, |rg| rg.len());
        let h_primary_start = primary.h.as_ref().map_or(0, |rg| rg.start);
        let w_primary_start = primary.w.as_ref().map_or(0, |rg| rg.start);
        let mut h_cpu = vec![0.0_f64; p_total * p_total];
        for row in 0..n {
            let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
            let mut phi = vec![vec![0.0_f64; p_total]; r];
            phi[0][..p_m].copy_from_slice(&marginal[row * p_m..(row + 1) * p_m]);
            phi[1][p_m..p_m + p_g].copy_from_slice(&logslope[row * p_g..(row + 1) * p_g]);
            for k in 0..h_block_len {
                phi[h_primary_start + k][h_block_start + k] = 1.0;
            }
            for k in 0..w_block_len {
                phi[w_primary_start + k][w_block_start + k] = 1.0;
            }
            for u in 0..r {
                for v in 0..r {
                    let huv = hrow[u * r + v];
                    if huv == 0.0 {
                        continue;
                    }
                    for m in 0..p_total {
                        let pm = phi[u][m];
                        if pm == 0.0 {
                            continue;
                        }
                        let scaled = huv * pm;
                        let out_row = &mut h_cpu[m * p_total..(m + 1) * p_total];
                        for (dst, &pv) in out_row.iter_mut().zip(&phi[v]) {
                            *dst += scaled * pv;
                        }
                    }
                }
            }
        }

        // Device side: upload the same inputs and run the kernel.
        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host");
        let d_m = stream
            .clone_htod(&marginal)
            .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host");
        let d_g = stream.clone_htod(&logslope).expect(
            "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
        );
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),

            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let h_gpu = launch_bms_flex_row_dense_block(&storage)
            .expect("dense_block kernel must launch on CUDA host");
        assert_eq!(h_gpu.len(), p_total * p_total);

        // Entry-wise comparison in row-major order, tracking the worst deviation.
        let mut max_abs = 0.0_f64;
        for (idx, (&a, &b)) in h_cpu.iter().zip(h_gpu.iter()).enumerate() {
            let (i, j) = (idx / p_total, idx % p_total);
            let diff = (a - b).abs();
            max_abs = max_abs.max(diff);
            assert!(
                diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
                "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
            );
        }
        eprintln!(
            "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
        );
    }
}
4893
#[test]
/// Large-scale perf gate for the `bms_flex_row` Hessian-vector-product kernel.
///
/// The problem is n = 195 000 rows, r = 20, p_total = 44. The test times:
/// - 15 launches of `launch_bms_flex_row_hvp`, after 3 warm-up launches;
/// - 15 runs of a rayon-parallel CPU oracle (`cpu_oracle_bms_flex_row_hvp`
///   over row chunks), after 1 warm-up run.
///
/// It compares the two median wall-clock times and asserts the GPU is at least
/// 5× faster. Only output lengths are checked here; the numerical values are
/// not compared.
///
/// Skips (prints and returns) on non-Linux hosts, when no CUDA runtime is
/// present, or when the backend probe or any host-to-device upload fails.
pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate"
            );
            return;
        };

        // Problem shape: r = 2 + p_h + p_w primaries, p_total coefficients.
        let n = 195_000_usize;
        let (p_m, p_g, p_h_dim, p_w_dim) = (14_usize, 12_usize, 10_usize, 8_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        // Deterministic per-row r×r Hessians (row-major). Each one is
        // symmetrised and then shifted by r on its diagonal.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, hess) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            for u in 0..r {
                for col in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (col as f64) * 0.317;
                    hess[u * r + col] = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                }
            }
            for u in 0..r {
                for col in (u + 1)..r {
                    let sym = 0.5 * (hess[u * r + col] + hess[col * r + u]);
                    hess[u * r + col] = sym;
                    hess[col * r + u] = sym;
                }
                hess[u * r + u] += r as f64;
            }
        }

        // Row-major n × p_m marginal design and n × p_g log-slope design.
        let mut marginal = vec![0.0_f64; n * p_m];
        for (row, cells) in marginal.chunks_exact_mut(p_m).enumerate() {
            for (j, cell) in cells.iter_mut().enumerate() {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                *cell = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for (row, cells) in logslope.chunks_exact_mut(p_g).enumerate() {
            for (j, cell) in cells.iter_mut().enumerate() {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                *cell = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }

        // The vector the Hessian is applied to (length p_total).
        let direction: Vec<f64> = (0..p_total)
            .map(|i| {
                let seed = (i as f64) * 0.157 + 0.6;
                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
            })
            .collect();

        let backend = match HvpKernelBackend::probe() {
            Ok(b) => b,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
                return;
            }
        };
        let stream = backend.stream.clone();
        let d_h = match stream.clone_htod(&row_hessians) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
                return;
            }
        };
        let d_m = match stream.clone_htod(&marginal) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
                return;
            }
        };
        let d_g = match stream.clone_htod(&logslope) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
                return;
            }
        };
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),

            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };

        // Median of an odd-length sample set (element at index len / 2 after sorting).
        let median_us = |mut samples: Vec<u128>| -> u128 {
            samples.sort_unstable();
            samples[samples.len() / 2]
        };
        let (warmup, iters) = (3_usize, 15_usize);

        for _ in 0..warmup {
            let out = launch_bms_flex_row_hvp(&storage, &direction)
                .expect("warmup GPU HVP must launch");
            assert_eq!(out.len(), p_total);
        }
        let gpu_samples: Vec<u128> = (0..iters)
            .map(|_| {
                let t0 = std::time::Instant::now();
                let out = launch_bms_flex_row_hvp(&storage, &direction)
                    .expect("GPU HVP must launch");
                let elapsed = t0.elapsed().as_micros();
                assert_eq!(out.len(), p_total);
                elapsed
            })
            .collect();
        let gpu_median = median_us(gpu_samples);

        // CPU baseline: run the oracle on fixed-size row chunks in parallel,
        // then sum the per-chunk HVP contributions.
        const CHUNK_ROWS: usize = 4096;
        let run_cpu_hvp = || -> Vec<f64> {
            (0..n.div_ceil(CHUNK_ROWS))
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total],
                    |mut acc, chunk| {
                        let lo = chunk * CHUNK_ROWS;
                        let hi = (lo + CHUNK_ROWS).min(n);
                        let partial = cpu_oracle_bms_flex_row_hvp(
                            &row_hessians[lo * r * r..hi * r * r],
                            &marginal[lo * p_m..hi * p_m],
                            &logslope[lo * p_g..hi * p_g],
                            &block,
                            &primary,
                            hi - lo,
                            &direction,
                        );
                        acc.iter_mut().zip(&partial).for_each(|(a, &p)| *a += p);
                        acc
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total],
                    |mut lhs, rhs| {
                        lhs.iter_mut().zip(&rhs).for_each(|(l, &x)| *l += x);
                        lhs
                    },
                )
        };
        let warm = run_cpu_hvp();
        assert_eq!(warm.len(), p_total);
        let cpu_samples: Vec<u128> = (0..iters)
            .map(|_| {
                let t0 = std::time::Instant::now();
                let out = run_cpu_hvp();
                let elapsed = t0.elapsed().as_micros();
                assert_eq!(out.len(), p_total);
                elapsed
            })
            .collect();
        let cpu_median = median_us(cpu_samples);

        // Guard against a zero-microsecond GPU median when forming the ratio.
        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 5×)"
        );
        assert!(
            speedup >= 5.0,
            "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the kernel until met or \
             prove the kernel is at hardware roofline."
        );
    }
}
5119
#[test]
/// Large-scale perf gate for the `bms_flex_row` dense-block (explicit H) kernel.
///
/// The problem is n = 195 000 rows, r = 20, p_total = 44. The test times:
/// - 5 launches of `launch_bms_flex_row_dense_block`, after 2 warm-up launches;
/// - 5 runs of a rayon-parallel CPU pullback `Σ_row Φ_rowᵀ H_row Φ_row`, after
///   1 warm-up run.
///
/// It compares the two median wall-clock times and asserts the GPU is at least
/// 10× faster. Only output lengths are checked here; the numerical values are
/// not compared.
///
/// Skips (prints and returns) in these cases:
/// - on non-Linux hosts;
/// - when no CUDA runtime is present;
/// - when `p_total` exceeds `DENSE_BLOCK_MAX_P`;
/// - when the backend probe or any host-to-device upload fails.
pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate"
        );
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
            );
            return;
        };
        let n = 195_000_usize;
        let p_m = 14_usize;
        let p_g = 12_usize;
        let p_h_dim = 10_usize;
        let p_w_dim = 8_usize;
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;

        // Check kernel support before building the synthetic host data. That
        // data is n·(r² + p_m + p_g) f64s, ≈665 MB at this size, and the skip
        // path should neither pay for it nor risk an OOM first.
        if p_total > DENSE_BLOCK_MAX_P {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
            );
            return;
        }

        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        // Deterministic per-row r×r Hessians (row-major). Each one is
        // symmetrised and then shifted by r on its diagonal.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            let base = row * r * r;
            for u in 0..r {
                for vv in 0..r {
                    let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                    let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
                    row_hessians[base + u * r + vv] = a;
                }
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let upper = row_hessians[base + u * r + vv];
                    let lower = row_hessians[base + vv * r + u];
                    let sym = 0.5 * (upper + lower);
                    row_hessians[base + u * r + vv] = sym;
                    row_hessians[base + vv * r + u] = sym;
                }
                row_hessians[base + u * r + u] += r as f64;
            }
        }
        // Row-major n × p_m marginal design and n × p_g log-slope design.
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }

        let backend = match HvpKernelBackend::probe() {
            Ok(b) => b,
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
                return;
            }
        };
        let stream = backend.stream.clone();
        let d_h = match stream.clone_htod(&row_hessians) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
                return;
            }
        };
        let d_m = match stream.clone_htod(&marginal) {
            Ok(s) => s,
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
                return;
            }
        };
        let d_g = match stream.clone_htod(&logslope) {
            Ok(s) => s,
            Err(err) => {
                eprintln!(
                    "[bms_flex_row dense_block hill-climb] upload logslope failed: {err}"
                );
                return;
            }
        };
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),

            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };

        // GPU timing: median of `iters` launches after `warmup` untimed launches.
        let warmup: usize = 2;
        let iters: usize = 5;
        for _ in 0..warmup {
            let out = launch_bms_flex_row_dense_block(&storage)
                .expect("warmup GPU dense_block must launch");
            assert_eq!(out.len(), p_total * p_total);
        }
        let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
        for _ in 0..iters {
            let t0 = std::time::Instant::now();
            let out =
                launch_bms_flex_row_dense_block(&storage).expect("GPU dense_block must launch");
            gpu_us.push(t0.elapsed().as_micros());
            assert_eq!(out.len(), p_total * p_total);
        }
        gpu_us.sort_unstable();
        let gpu_median = gpu_us[iters / 2];

        // CPU baseline: per-chunk dense pullback Σ Φ_rowᵀ H_row Φ_row, summed across
        // chunks. In Φ_row (r × p_total):
        // - primary 0 carries the marginal design;
        // - primary 1 carries the log-slope design;
        // - the h/w primaries map one-to-one onto their coefficient blocks.
        const CHUNK_ROWS: usize = 2048;
        let h_block_start = block.h.as_ref().map_or(0, |rg| rg.start);
        let h_block_len = block.h.as_ref().map_or(0, |rg| rg.len());
        let w_block_start = block.w.as_ref().map_or(0, |rg| rg.start);
        let w_block_len = block.w.as_ref().map_or(0, |rg| rg.len());
        let h_primary_start = primary.h.as_ref().map_or(0, |rg| rg.start);
        let w_primary_start = primary.w.as_ref().map_or(0, |rg| rg.start);
        let cpu_build_parallel = || -> Vec<f64> {
            let nchunks = n.div_ceil(CHUNK_ROWS);
            (0..nchunks)
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total * p_total],
                    |mut acc, ci| {
                        let lo = ci * CHUNK_ROWS;
                        let hi = (lo + CHUNK_ROWS).min(n);
                        // Allocated once per chunk and cleared for every row.
                        let mut phi: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
                        for row in lo..hi {
                            for col in phi.iter_mut() {
                                col.fill(0.0);
                            }
                            let mrow = &marginal[row * p_m..(row + 1) * p_m];
                            let grow = &logslope[row * p_g..(row + 1) * p_g];
                            for k in 0..p_m {
                                phi[0][k] = mrow[k];
                            }
                            for k in 0..p_g {
                                phi[1][p_m + k] = grow[k];
                            }
                            for k in 0..h_block_len {
                                phi[h_primary_start + k][h_block_start + k] = 1.0;
                            }
                            for k in 0..w_block_len {
                                phi[w_primary_start + k][w_block_start + k] = 1.0;
                            }
                            let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
                            for u in 0..r {
                                for v_idx in 0..r {
                                    let huv = hrow[u * r + v_idx];
                                    if huv == 0.0 {
                                        continue;
                                    }
                                    for m in 0..p_total {
                                        let pm = phi[u][m];
                                        if pm == 0.0 {
                                            continue;
                                        }
                                        let scaled = huv * pm;
                                        for nn in 0..p_total {
                                            acc[m * p_total + nn] += scaled * phi[v_idx][nn];
                                        }
                                    }
                                }
                            }
                        }
                        acc
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total * p_total],
                    |mut a, b| {
                        for (ax, bx) in a.iter_mut().zip(b.iter()) {
                            *ax += *bx;
                        }
                        a
                    },
                )
        };
        let warm_cpu = cpu_build_parallel();
        assert_eq!(warm_cpu.len(), p_total * p_total);
        let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
        for _ in 0..iters {
            let t0 = std::time::Instant::now();
            let out = cpu_build_parallel();
            cpu_us.push(t0.elapsed().as_micros());
            assert_eq!(out.len(), p_total * p_total);
        }
        cpu_us.sort_unstable();
        let cpu_median = cpu_us[iters / 2];

        // Guard against a zero-microsecond GPU median when forming the ratio.
        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 10×)"
        );
        assert!(
            speedup >= 10.0,
            "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
             (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
             or prove the kernel is at hardware roofline."
        );
    }
}
5366}