1#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
/// Largest primary dimension `r` accepted by `BmsFlexRowKernelInputs::validate`.
/// The row kernel declares its fixed-size scratch arrays for this bound.
pub(crate) const MAX_R: usize = 32;

/// Threads per block (`block_dim.x`) used when the row kernel is launched.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_THREADS: u32 = 32;

/// Number of `f64` coefficients stored per entry of the flattened `cell_*`
/// coefficient arrays; `validate` multiplies entry counts by this value.
pub(crate) const COEFF4: usize = 4;

/// Number of `f64` moments stored per cell in `cell_moments`; `validate`
/// requires `cell_moments.len() == total_cells * MOMENT_STRIDE`.
pub(crate) const MOMENT_STRIDE: usize = 10;
113
/// Location of the per-cell moment table handed to the row-kernel launcher.
pub(crate) enum CellMomentsSource<'a> {
    /// Moments held in host memory; the Linux launcher uploads them before
    /// the kernel launch.
    Host(&'a [f64]),
    /// Moments already held in a device buffer; the Linux launcher passes
    /// the buffer to the kernel without another upload.
    #[cfg(target_os = "linux")]
    Device(&'a CudaSlice<f64>),
}
129
130impl<'a> CellMomentsSource<'a> {
131 pub(crate) fn len(&self) -> usize {
133 match self {
134 CellMomentsSource::Host(slice) => slice.len(),
135 #[cfg(target_os = "linux")]
136 CellMomentsSource::Device(d) => d.len(),
137 }
138 }
139}
140
/// Generates the borrowed (`BmsFlexRowKernelInputs`) and owned
/// (`BmsFlexRowKernelInputsOwned`) input bundles for the row kernel from one
/// field list, so the two structs always carry the same set of arrays.
///
/// * `f64_fields` — names of the `f64` arrays (`&[f64]` / `Vec<f64>`).
/// * `u32_fields` — names of the `u32` arrays (`&[u32]` / `Vec<u32>`).
/// * `moments_field` — name of the moment-table field; it is a
///   `CellMomentsSource` in the borrowed struct and a `Vec<f64>` in the
///   owned one.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        /// Borrowed view of every input the row kernel consumes.
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of rows.
            pub n_rows: usize,
            /// Primary dimension of each row's gradient / Hessian.
            pub r: usize,
            /// Size of the `h` block.
            pub p_h: usize,
            /// Size of the `w` block.
            pub p_w: usize,
            /// Scalar `s_f` forwarded to the kernel.
            pub s_f: f64,
            // One borrowed slice per listed array.
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            /// Moment table, either on the host or already on the device.
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owning counterpart of `BmsFlexRowKernelInputs`.
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            // One owned vector per listed array.
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            /// Host copy of the moment table.
            pub $moments_field: Vec<f64>,
            /// Optional device copy of the moment table; when present it is
            /// preferred over the host copy by `as_borrowed`.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrows every array of `self` as a `BmsFlexRowKernelInputs`.
            ///
            /// On Linux the device moment buffer is used when it is set;
            /// otherwise (and on every other platform) the host vector is
            /// borrowed.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                // The host fallback goes through `$moments_field` rather than
                // a hard-coded field name so the macro keeps compiling if the
                // moments field is ever renamed at the invocation site.
                #[cfg(target_os = "linux")]
                let moments_source = match self.cell_moments_device.as_ref() {
                    Some(device) => CellMomentsSource::Device(device),
                    None => CellMomentsSource::Host(&self.$moments_field),
                };
                #[cfg(not(target_os = "linux"))]
                let moments_source = CellMomentsSource::Host(&self.$moments_field);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: moments_source,
                }
            }
        }
    };
}
218
// Instantiates `BmsFlexRowKernelInputs` / `BmsFlexRowKernelInputsOwned`.
// The expected length of each array is enforced by
// `BmsFlexRowKernelInputs::validate`.
define_bms_flex_row_kernel_input_types! {
    f64_fields: [
        // Per-row scalars: one entry per row.
        q,
        b,
        mu_1,
        mu_2,
        z_obs,
        y,
        w,
        // Per-cell scalars: one entry per cell.
        cell_c0,
        cell_c1,
        cell_c2,
        cell_c3,
        // Per-cell coefficient groups: COEFF4 values per entry.
        cell_a,
        cell_aa,
        // (r - 1) entries per cell.
        cell_r,
        cell_ar,
        // One entry per cell, then p_h and p_w entries per cell.
        cell_sbb,
        cell_sbh,
        cell_sbw,
        // Per-row scalars.
        chi_obs,
        xi_obs,
        // Per-row vectors of length r, and a per-row r*r matrix.
        rho_u,
        tau_u,
        r_uv,
    ],
    // n_rows + 1 non-decreasing offsets delimiting each row's cells.
    u32_fields: [cell_offsets],
    moments_field: cell_moments,
}
248
/// Host-side copies of the buffers written by the row kernel.
///
/// Rows the kernel treats as degenerate are filled with NaN in all three
/// arrays.
#[derive(Debug)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// One value per row (`n_rows` entries).
    pub neglog: Vec<f64>,
    /// Row-major `[n_rows, r]` per-row gradients.
    pub grad: Vec<f64>,
    /// Row-major `[n_rows, r, r]` per-row Hessians; the kernel writes both
    /// triangles.
    pub hess: Vec<f64>,
}
260
261impl<'a> BmsFlexRowKernelInputs<'a> {
262 pub(crate) fn validate(&self) -> Result<(), GpuError> {
265 if self.r == 0 {
266 return Err(GpuError::DriverCallFailed {
267 reason: "bms_flex_row inputs: r must be > 0".to_string(),
268 });
269 }
270 if self.r > MAX_R {
271 return Err(GpuError::DriverCallFailed {
272 reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
273 });
274 }
275 if self.r != 2 + self.p_h + self.p_w {
276 return Err(GpuError::DriverCallFailed {
277 reason: format!(
278 "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
279 self.r,
280 self.p_h,
281 self.p_w,
282 2 + self.p_h + self.p_w
283 ),
284 });
285 }
286 let n = self.n_rows;
287 let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
288 if have != want {
289 return Err(GpuError::DriverCallFailed {
290 reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
291 });
292 }
293 Ok(())
294 };
295 check_len("q", self.q.len(), n)?;
296 check_len("b", self.b.len(), n)?;
297 check_len("mu_1", self.mu_1.len(), n)?;
298 check_len("mu_2", self.mu_2.len(), n)?;
299 check_len("z_obs", self.z_obs.len(), n)?;
300 check_len("y", self.y.len(), n)?;
301 check_len("w", self.w.len(), n)?;
302 check_len("chi_obs", self.chi_obs.len(), n)?;
303 check_len("xi_obs", self.xi_obs.len(), n)?;
304 check_len("rho_u", self.rho_u.len(), n * self.r)?;
305 check_len("tau_u", self.tau_u.len(), n * self.r)?;
306 check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
307 check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
308 let total_cells_u32 = self.cell_offsets[n];
309 let total_cells = total_cells_u32 as usize;
310 check_len("cell_c0", self.cell_c0.len(), total_cells)?;
311 check_len("cell_c1", self.cell_c1.len(), total_cells)?;
312 check_len("cell_c2", self.cell_c2.len(), total_cells)?;
313 check_len("cell_c3", self.cell_c3.len(), total_cells)?;
314 check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
315 check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
316 check_len(
317 "cell_r",
318 self.cell_r.len(),
319 total_cells * self.r.saturating_sub(1) * COEFF4,
320 )?;
321 check_len(
322 "cell_ar",
323 self.cell_ar.len(),
324 total_cells * self.r.saturating_sub(1) * COEFF4,
325 )?;
326 check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
327 check_len(
328 "cell_sbh",
329 self.cell_sbh.len(),
330 total_cells * self.p_h * COEFF4,
331 )?;
332 check_len(
333 "cell_sbw",
334 self.cell_sbw.len(),
335 total_cells * self.p_w * COEFF4,
336 )?;
337 check_len(
338 "cell_moments",
339 self.cell_moments.len(),
340 total_cells * MOMENT_STRIDE,
341 )?;
342 for i in 0..n {
348 if self.cell_offsets[i] > self.cell_offsets[i + 1] {
349 return Err(GpuError::DriverCallFailed {
350 reason: format!(
351 "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
352 i,
353 self.cell_offsets[i],
354 i + 1,
355 self.cell_offsets[i + 1]
356 ),
357 });
358 }
359 }
360 Ok(())
361 }
362}
363
/// CUDA C source of the per-row kernel `bms_flex_row_kernel`.
///
/// `RowKernelBackend::probe` appends this text to the shared probit numerics
/// source (which supplies `log_ndtr_and_mills`) and compiles the result with
/// NVRTC. The kernel runs one block per row: the block's threads split the
/// row's cells, partial sums are combined in shared memory, and thread 0
/// finishes the row and writes `out_neglog`, `out_grad` and `out_hess`.
///
/// Every index into the output buffers is computed in `size_t`. A 32-bit
/// `int` product such as `row * r * r` overflows once `n_rows * r * r`
/// exceeds `i32::MAX` (about 2.1 million rows at `r = 32`), which would turn
/// into an out-of-bounds write.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallelises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
//   ::compute_row_analytic_flex_from_parts_into.

#define INV_TWO_PI 0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row.
// Output offsets are computed in size_t so large n_rows * r * r cannot
// overflow 32-bit int arithmetic.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    out_neglog[row] = nan_value;
    size_t grad_base = (size_t)row * (size_t)r;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + (size_t)u] = nan_value;
    }
    size_t rr = (size_t)r * (size_t)r;
    size_t hess_base = (size_t)row * rr;
    for (size_t idx = 0; idx < rr; ++idx) {
        out_hess[hess_base + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int n_rows,
    int r,
    int p_h,
    int p_w,
    double s_f,                                  // currently unused on device:
                                                 // host has already baked S_f
                                                 // into the cubic coefficients.
                                                 // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a,          // [n_cells, 4]
    const double * __restrict__ cell_aa,         // [n_cells, 4]
    const double * __restrict__ cell_r,          // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar,         // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb,        // [n_cells, 4]
    const double * __restrict__ cell_sbh,        // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw,        // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments,    // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho,         // [n_rows, r]
    const double * __restrict__ row_tau,         // [n_rows, r]
    const double * __restrict__ row_ruv,         // [n_rows, r*r]
    double * __restrict__ out_neglog,
    double * __restrict__ out_grad,
    double * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    //   F_u       [r]
    //   F_au      [r]
    //   F_uv      [r*r]
    //   bar_e_u   [r]
    //   bar_e_uv  [r*r]
    //   reduce_a  [blockDim.x]
    //   reduce_b  [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u] = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        // `cell_second_derivative_from_moments` after folding the
        // cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S) \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1] \
             + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2] \
             + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3] \
             + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4] \
             + (R[2]*S[3] + R[3]*S[2])*T[5] \
             + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c  = cell_a  + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa  += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        //                                          = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u  = cell_r  + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R  = D_OF(R_u);
            double d_AR = D_OF(AR_u);
            double q_AR = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in score-warp block, or symmetric.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in link-wiggle block, or symmetric.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror: u in (h or w) block, v == 1 cannot happen
                // because we iterate v >= u; skip.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared.
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    //   F_q = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0] = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    //   a_u = -F_u · inv_Fa  (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    //   a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // The q-row contributions (u==0 or v==0) collapse to a_uv = mu_2 · inv_Fa
    // when both are 0 and to (F_au_v) · inv_Fa-style mixed shape otherwise.
    // We compute it uniformly using the populated F_uv / F_au with the
    // q-overrides above.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                        + F_au[v] * a_u[u]
                        + F_au[u] * a_u[v]
                        + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    //   bar_e_u  = chi · a_u + rho_u.
    //   bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi  = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                       + xi * a_u[u] * a_u[v]
                       + tau[u] * a_u[v]
                       + a_u[u] * tau[v]
                       + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y = row_y[row];
    double w = row_w[row];
    double s = 2.0 * y - 1.0;
    // The "observed predictor" e_obs is the value (degree-0) term of the
    // observed jet — same convention as the CPU path. CPU parity:
    // `e_obs = chi · a_0 + rho_0`... well, no: `bar_e_u` is the *first*
    // derivative jet, not the value. The observed predictor value comes
    // from the host pre-evaluation as `rho_u[0]` of the value jet —
    // pre-baked into the host's `m = s · e_obs` payload. For Stage 2 we
    // expose it via the `bar_e_u[0]` slot which is `chi·a_0 + rho_0`; the
    // host wiring lands in the dispatcher wave that bridges this kernel
    // and the row evaluator in `bernoulli_marginal_slope.rs`.
    double e_obs = bar_e_u[0];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i = w * lambda * (m_arg + lambda);

    // Output offsets in size_t: an int product of row, r and r overflows for
    // large n_rows, unlike the size_t-indexed inputs above.
    const size_t grad_base = (size_t)row * (size_t)r;
    const size_t hess_base = (size_t)row * (size_t)r * (size_t)r;

    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        out_grad[grad_base + (size_t)u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            out_hess[hess_base + (size_t)(u * r + v)] = val;
            if (u != v) {
                out_hess[hess_base + (size_t)(v * r + u)] = val;
            }
        }
    }
}
"#;
725
726#[inline]
733pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
734 inputs.s_f.is_finite() && inputs.s_f > 0.0
735}
736
/// Handles needed to launch the row kernel: the stream used for transfers
/// and launches, and the module loaded from the compiled kernel source.
#[cfg(target_os = "linux")]
pub(crate) struct RowKernelBackend {
    /// Stream on which uploads, the launch and downloads are issued.
    pub(crate) stream: Arc<CudaStream>,
    /// Module loaded from the NVRTC-compiled kernel source.
    pub(crate) module: Arc<CudaModule>,
}
742
743#[cfg(target_os = "linux")]
744impl RowKernelBackend {
745 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
746 static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
747 BACKEND
748 .get_or_init(|| {
749 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
750 let row_kernel_source = [
751 gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
752 ROW_KERNEL_BODY,
753 ]
754 .concat();
755 let ptx = cudarc::nvrtc::compile_ptx(row_kernel_source).map_err(|err| {
756 GpuError::DriverCallFailed {
757 reason: format!("bms_flex_row NVRTC compile failed: {err}"),
758 }
759 })?;
760 let module =
761 parts
762 .ctx
763 .load_module(ptx)
764 .map_err(|err| GpuError::DriverCallFailed {
765 reason: format!("bms_flex_row module load failed: {err}"),
766 })?;
767 Ok(RowKernelBackend {
768 stream: parts.stream.clone(),
769 module,
770 })
771 })
772 })
773 .as_ref()
774 .map_err(GpuError::clone)
775 }
776}
777
778pub(crate) fn launch_bms_flex_row_kernel(
783 inputs: BmsFlexRowKernelInputs<'_>,
784) -> Result<BmsFlexRowKernelOutputs, GpuError> {
785 inputs.validate()?;
786 if !s_f_diagnostic_finite(&inputs) {
787 return Err(GpuError::DriverCallFailed {
788 reason: format!(
789 "bms_flex_row inputs: s_f must be positive and finite, got {}",
790 inputs.s_f
791 ),
792 });
793 }
794
795 #[cfg(target_os = "linux")]
796 {
797 launch_linux(inputs)
798 }
799 #[cfg(not(target_os = "linux"))]
800 {
801 Err(GpuError::DriverLibraryUnavailable {
802 reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
803 })
804 }
805}
806
807#[cfg(target_os = "linux")]
808pub(crate) fn launch_linux(
809 inputs: BmsFlexRowKernelInputs<'_>,
810) -> Result<BmsFlexRowKernelOutputs, GpuError> {
811 let backend = RowKernelBackend::probe()?;
812 let stream = &backend.stream;
813
814 let upload_f64 = |slice: &[f64], label: &str| {
815 stream
816 .clone_htod(slice)
817 .map_err(|err| GpuError::DriverCallFailed {
818 reason: format!("bms_flex_row upload {label}: {err}"),
819 })
820 };
821 let upload_u32 = |slice: &[u32], label: &str| {
822 stream
823 .clone_htod(slice)
824 .map_err(|err| GpuError::DriverCallFailed {
825 reason: format!("bms_flex_row upload {label}: {err}"),
826 })
827 };
828
829 let d_q = upload_f64(inputs.q, "q")?;
830 let d_b = upload_f64(inputs.b, "b")?;
831 let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
832 let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
833 let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
834 let d_y = upload_f64(inputs.y, "y")?;
835 let d_w = upload_f64(inputs.w, "w")?;
836 let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
837 let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
838 let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
839 let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
840 let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
841 let d_a = upload_f64(inputs.cell_a, "cell_a")?;
842 let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
843 let d_r = upload_f64(inputs.cell_r, "cell_r")?;
844 let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
845 let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
846 let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
847 let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
848 let owned_host_moments: CudaSlice<f64>;
852 let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
853 CellMomentsSource::Host(slice) => {
854 owned_host_moments = upload_f64(slice, "cell_moments")?;
855 &owned_host_moments
856 }
857 CellMomentsSource::Device(d) => *d,
858 };
859 let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
860 let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
861 let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
862 let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
863 let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
864
865 let n = inputs.n_rows;
866 let r = inputs.r;
867 let mut d_neglog = stream
868 .alloc_zeros::<f64>(n)
869 .map_err(|err| GpuError::DriverCallFailed {
870 reason: format!("bms_flex_row alloc neglog: {err}"),
871 })?;
872 let mut d_grad =
873 stream
874 .alloc_zeros::<f64>(n * r)
875 .map_err(|err| GpuError::DriverCallFailed {
876 reason: format!("bms_flex_row alloc grad: {err}"),
877 })?;
878 let mut d_hess =
879 stream
880 .alloc_zeros::<f64>(n * r * r)
881 .map_err(|err| GpuError::DriverCallFailed {
882 reason: format!("bms_flex_row alloc hess: {err}"),
883 })?;
884
885 let func = backend
886 .module
887 .load_function("bms_flex_row_kernel")
888 .map_err(|err| GpuError::DriverCallFailed {
889 reason: format!("bms_flex_row load_function: {err}"),
890 })?;
891
892 let cfg = LaunchConfig {
893 grid_dim: (n as u32, 1, 1),
894 block_dim: (ROW_KERNEL_THREADS, 1, 1),
895 shared_mem_bytes: 0,
896 };
897 let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
898 reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
899 })?;
900 let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
901 reason: format!("bms_flex_row: r={r} exceeds i32 range"),
902 })?;
903 let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
904 reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
905 })?;
906 let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
907 reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
908 })?;
909 let s_f = inputs.s_f;
910
911 let mut builder = stream.launch_builder(&func);
912 builder
913 .arg(&n_i32)
914 .arg(&r_i32)
915 .arg(&p_h_i32)
916 .arg(&p_w_i32)
917 .arg(&s_f)
918 .arg(&d_q)
919 .arg(&d_b)
920 .arg(&d_mu1)
921 .arg(&d_mu2)
922 .arg(&d_zobs)
923 .arg(&d_y)
924 .arg(&d_w)
925 .arg(&d_offsets)
926 .arg(&d_c0)
927 .arg(&d_c1)
928 .arg(&d_c2)
929 .arg(&d_c3)
930 .arg(&d_a)
931 .arg(&d_aa)
932 .arg(&d_r)
933 .arg(&d_ar)
934 .arg(&d_sbb)
935 .arg(&d_sbh)
936 .arg(&d_sbw)
937 .arg(d_moments_ref)
938 .arg(&d_chi)
939 .arg(&d_xi)
940 .arg(&d_rho)
941 .arg(&d_tau)
942 .arg(&d_ruv)
943 .arg(&mut d_neglog)
944 .arg(&mut d_grad)
945 .arg(&mut d_hess);
946
947 unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
954 reason: format!("bms_flex_row launch: {err}"),
955 })?;
956 stream
957 .synchronize()
958 .map_err(|err| GpuError::DriverCallFailed {
959 reason: format!("bms_flex_row synchronize: {err}"),
960 })?;
961
962 let neglog = stream
963 .clone_dtoh(&d_neglog)
964 .map_err(|err| GpuError::DriverCallFailed {
965 reason: format!("bms_flex_row download neglog: {err}"),
966 })?;
967 let grad = stream
968 .clone_dtoh(&d_grad)
969 .map_err(|err| GpuError::DriverCallFailed {
970 reason: format!("bms_flex_row download grad: {err}"),
971 })?;
972 let hess = stream
973 .clone_dtoh(&d_hess)
974 .map_err(|err| GpuError::DriverCallFailed {
975 reason: format!("bms_flex_row download hess: {err}"),
976 })?;
977
978 Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
979}
980
/// Layout of the joint coefficient vector used by the HVP kernels.
///
/// NOTE(review): the code that builds and consumes this struct is outside
/// the visible part of the file; the field notes below are inferred from the
/// HVP kernel parameters of the same names — confirm against the call sites.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Presumably the column count of the marginal design (`[n, p_m]`).
    pub p_m: usize,
    /// Presumably the column count of the log-slope design (`[n, p_g]`).
    pub p_g: usize,
    /// Optional index range of the `h` block — TODO confirm which vector it indexes.
    pub h: Option<std::ops::Range<usize>>,
    /// Optional index range of the `w` block — TODO confirm which vector it indexes.
    pub w: Option<std::ops::Range<usize>>,
    /// Total length of the joint coefficient vector.
    pub p_total: usize,
}
1041
/// Layout of the per-row primary coordinates (dimension `r`).
///
/// NOTE(review): construction and use are outside the visible part of the
/// file; confirm the meaning of the ranges against the call sites.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexPrimaryLayout {
    /// Optional range of the `h` block — presumably within the `r` primary slots.
    pub h: Option<std::ops::Range<usize>>,
    /// Optional range of the `w` block — presumably within the `r` primary slots.
    pub w: Option<std::ops::Range<usize>>,
    /// Primary dimension.
    pub r: usize,
}
1051
/// Rows assigned to each HVP chunk; `num_hvp_chunks` divides the row count
/// by this value, rounding up.
#[cfg(target_os = "linux")]
pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;

/// Thread count associated with the HVP partial kernels.
/// NOTE(review): the HVP kernel source declares a 128-slot `dot_reduce`
/// scratch array indexed by `threadIdx.x`, so the launch presumably must not
/// use more than 128 threads — confirm at the launch site.
#[cfg(target_os = "linux")]
pub(crate) const HVP_THREADS: u32 = 128;

/// Thread count presumably used for the reduction kernel launch — the launch
/// site is not visible here; TODO confirm.
#[cfg(target_os = "linux")]
pub(crate) const REDUCTION_THREADS: u32 = 256;

/// Same value as `MAX_MULTI_RHS` in the HVP kernel source; presumably the
/// cap on right-hand sides per multi-RHS launch — TODO confirm at the caller.
#[cfg(target_os = "linux")]
pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1077
/// Device buffers and layout metadata kept together for repeated HVP work.
///
/// NOTE(review): the code that fills and reads this struct is outside the
/// visible part of the file; buffer shapes below are inferred from the HVP
/// kernel parameters of the same names — confirm against the call sites.
#[cfg(target_os = "linux")]
pub struct DeviceResidentRowHess {
    /// Device buffer of per-row Hessians — presumably `[n, r*r]`.
    pub(crate) hess: CudaSlice<f64>,
    /// Device copy of the marginal design — presumably row-major `[n, p_m]`.
    pub(crate) marginal_design: CudaSlice<f64>,
    /// Device copy of the log-slope design — presumably row-major `[n, p_g]`.
    pub(crate) logslope_design: CudaSlice<f64>,
    /// Number of rows.
    pub(crate) n: usize,
    /// Primary dimension.
    pub(crate) r: usize,
    /// Layout of the joint coefficient vector.
    pub(crate) block: BmsFlexBlockLayout,
    /// Layout of the per-row primary coordinates.
    pub(crate) primary: BmsFlexPrimaryLayout,
    /// Byte count recorded for this allocation set — presumably the total
    /// device footprint; TODO confirm how it is computed.
    pub(crate) bytes: u64,
}
1113
1114#[cfg(target_os = "linux")]
1115impl std::fmt::Debug for DeviceResidentRowHess {
1116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1117 f.debug_struct("DeviceResidentRowHess")
1118 .field("n", &self.n)
1119 .field("r", &self.r)
1120 .field("p_total", &self.block.p_total)
1121 .field("bytes", &self.bytes)
1122 .finish()
1123 }
1124}
1125
1126#[cfg(target_os = "linux")]
1129pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1130 n.div_ceil(HVP_ROWS_PER_CTA as usize)
1131}
1132
1133#[cfg(target_os = "linux")]
1136pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1137// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1138// in this module.
1139
1140#define MAX_MULTI_RHS 8
1141
1142extern "C" __global__ void bms_flex_row_hvp_partial(
1143 int n_rows,
1144 int r,
1145 int p_m,
1146 int p_g,
1147 int p_total,
1148 int h_block_start,
1149 int h_block_len,
1150 int w_block_start,
1151 int w_block_len,
1152 int h_primary_start,
1153 int w_primary_start,
1154 int rows_per_cta,
1155 const double * __restrict__ row_hessians, // [n, r*r]
1156 const double * __restrict__ marginal_design, // [n, p_m] row-major
1157 const double * __restrict__ logslope_design, // [n, p_g] row-major
1158 const double * __restrict__ v, // [p_total]
1159 double * __restrict__ partial) // [num_chunks, p_total]
1160{
1161 int chunk = blockIdx.x;
1162 int tid = threadIdx.x;
1163 int row_lo = chunk * rows_per_cta;
1164 int row_hi = row_lo + rows_per_cta;
1165 if (row_hi > n_rows) row_hi = n_rows;
1166
1167 // Zero this chunk's partial slice cooperatively.
1168 double *out = partial + (size_t)chunk * (size_t)p_total;
1169 for (int j = tid; j < p_total; j += blockDim.x) {
1170 out[j] = 0.0;
1171 }
1172 __syncthreads();
1173
1174 // Each thread serially processes a stride-of-blockDim set of rows so
1175 // every write to `out[..]` happens from one thread → no atomics within
1176 // the chunk. To keep writes race-free across threads of the same chunk,
1177 // we serialize the cross-row accumulation through a per-row barrier:
1178 // thread 0 of the block processes all rows in the chunk. The per-row
1179 // work is dominated by the dot/axpy over `p_m + p_g`, which is large.
1180 // For Stage 3 we ship the simple, correct path (thread 0 sequential
1181 // per row, blockDim.x threads parallel within a row's dot/axpy).
1182 __shared__ double row_dir[32];
1183 __shared__ double action[32];
1184 __shared__ double dot_reduce[128];
1185
1186 for (int row = row_lo; row < row_hi; ++row) {
1187 const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1188 const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1189 const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1190
1191 // row_dir[0] = mrow · v[0..p_m]
1192 double local = 0.0;
1193 for (int j = tid; j < p_m; j += blockDim.x) {
1194 local += mrow[j] * v[j];
1195 }
1196 dot_reduce[tid] = local;
1197 __syncthreads();
1198 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1199 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1200 __syncthreads();
1201 }
1202 if (tid == 0) row_dir[0] = dot_reduce[0];
1203
1204 // row_dir[1] = grow · v[p_m..p_m+p_g]
1205 local = 0.0;
1206 for (int j = tid; j < p_g; j += blockDim.x) {
1207 local += grow[j] * v[p_m + j];
1208 }
1209 dot_reduce[tid] = local;
1210 __syncthreads();
1211 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1212 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1213 __syncthreads();
1214 }
1215 if (tid == 0) row_dir[1] = dot_reduce[0];
1216
1217 // h/w blocks: direct copy.
1218 if (tid == 0) {
1219 for (int k = 0; k < h_block_len; ++k) {
1220 row_dir[h_primary_start + k] = v[h_block_start + k];
1221 }
1222 for (int k = 0; k < w_block_len; ++k) {
1223 row_dir[w_primary_start + k] = v[w_block_start + k];
1224 }
1225 }
1226 __syncthreads();
1227
1228 // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
1229 if (tid < r) {
1230 double acc = 0.0;
1231 for (int vv = 0; vv < r; ++vv) {
1232 acc += Hrow[tid * r + vv] * row_dir[vv];
1233 }
1234 action[tid] = acc;
1235 }
1236 __syncthreads();
1237
1238 // Pull back into joint β slot.
1239 // marginal: out[j] += action[0] · mrow[j] (parallel j)
1240 double a0 = action[0];
1241 for (int j = tid; j < p_m; j += blockDim.x) {
1242 out[j] += a0 * mrow[j];
1243 }
1244 double a1 = action[1];
1245 for (int j = tid; j < p_g; j += blockDim.x) {
1246 out[p_m + j] += a1 * grow[j];
1247 }
1248 if (tid == 0) {
1249 for (int k = 0; k < h_block_len; ++k) {
1250 out[h_block_start + k] += action[h_primary_start + k];
1251 }
1252 for (int k = 0; k < w_block_len; ++k) {
1253 out[w_block_start + k] += action[w_primary_start + k];
1254 }
1255 }
1256 __syncthreads();
1257 }
1258}
1259
1260extern "C" __global__ void bms_flex_row_hvp_reduce(
1261 int num_chunks,
1262 int p_total,
1263 const double * __restrict__ partial, // [num_chunks, p_total]
1264 double * __restrict__ out) // [p_total]
1265{
1266 int j = blockIdx.x * blockDim.x + threadIdx.x;
1267 if (j >= p_total) return;
1268 double acc = 0.0;
1269 for (int c = 0; c < num_chunks; ++c) {
1270 acc += partial[(size_t)c * (size_t)p_total + (size_t)j];
1271 }
1272 out[j] = acc;
1273}
1274
// Multi-RHS Hessian-vector-product partial kernel.
//
// Block `chunk = blockIdx.x` owns rows [chunk*rows_per_cta,
// min((chunk+1)*rows_per_cta, n_rows)). For every owned row and every
// right-hand side `rhs` it
//   1. builds the r-long primary direction row_dir[rhs]:
//        slot 0 = marginal_design[row, :] . v[0 .. p_m)
//        slot 1 = logslope_design[row, :] . v[p_m .. p_m+p_g)
//        slots h_primary_start.. / w_primary_start.. = direct copies of
//        v[h_block_start ..] / v[w_block_start ..];
//   2. computes action[rhs] = H_row * row_dir[rhs] with the dense r x r
//      row Hessian;
//   3. adds the pull-back of action[rhs] into this block's slice
//      partial[rhs, chunk, :].
//
// Constraints visible from the shared-memory declarations and indexing:
//   * dot_reduce has 128 entries and is indexed by threadIdx.x, and the
//     halving reduction only folds every entry when blockDim.x is a power
//     of two, so blockDim.x must be a power of two that is <= 128;
//   * row_dir / action use a stride of 32 doubles per RHS and hold
//     MAX_MULTI_RHS strides, so r <= 32 and rhs_count <= MAX_MULTI_RHS.
extern "C" __global__ void bms_flex_row_hvp_multi_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    int rhs_count,
    const double * __restrict__ row_hessians, // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    const double * __restrict__ v_rhs, // [rhs_count, p_total]
    double * __restrict__ partial) // [rhs_count, num_chunks, p_total]
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // Zero this block's [rhs_count, p_total] slice of `partial` before any
    // accumulation. num_chunks is recomputed from n_rows / rows_per_cta
    // rather than passed in, so it must agree with the launch grid.
    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
        int rhs = idx / p_total;
        int j = idx - rhs * p_total;
        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
    }
    __syncthreads();

    __shared__ double row_dir[MAX_MULTI_RHS * 32];
    __shared__ double action[MAX_MULTI_RHS * 32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;

        // Stage 1: fill row_dir[rhs * 32 + 0 .. r) for every right-hand side.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;

            // Slot 0: block-wide dot product mrow . v[0 .. p_m).
            double local = 0.0;
            for (int j = tid; j < p_m; j += blockDim.x) {
                local += mrow[j] * v[j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            // Thread 0 consumes dot_reduce[0] here and is also the only
            // thread that overwrites index 0 below, so no extra barrier is
            // needed before dot_reduce is reused.
            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];

            // Slot 1: block-wide dot product grow . v[p_m .. p_m + p_g).
            local = 0.0;
            for (int j = tid; j < p_g; j += blockDim.x) {
                local += grow[j] * v[p_m + j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) {
                row_dir[rhs * 32 + 1] = dot_reduce[0];
                // h / w blocks: direct copy of the matching entries of v.
                for (int k = 0; k < h_block_len; ++k) {
                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
                }
            }
            // Publishes row_dir[rhs] and keeps thread 0's read of
            // dot_reduce[0] ahead of the next rhs iteration's writes.
            __syncthreads();
        }

        // Stage 2: action[rhs, u] = sum_vv Hrow[u, vv] * row_dir[rhs, vv],
        // one (rhs, u) pair per loop iteration, strided over the block.
        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
            int rhs = idx / r;
            int u = idx - rhs * r;
            double acc = 0.0;
            const double *dir = row_dir + rhs * 32;
            for (int vv = 0; vv < r; ++vv) {
                acc += Hrow[u * r + vv] * dir[vv];
            }
            action[rhs * 32 + u] = acc;
        }
        __syncthreads();

        // Stage 3: pull action[rhs] back into partial[rhs, chunk, :].
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
            // marginal slots: out[j] += action[0] * mrow[j]
            double a0 = action[rhs * 32 + 0];
            for (int j = tid; j < p_m; j += blockDim.x) {
                out[j] += a0 * mrow[j];
            }
            // logslope slots: out[p_m + j] += action[1] * grow[j]
            double a1 = action[rhs * 32 + 1];
            for (int j = tid; j < p_g; j += blockDim.x) {
                out[p_m + j] += a1 * grow[j];
            }
            // h / w slots are accumulated by thread 0 alone.
            if (tid == 0) {
                for (int k = 0; k < h_block_len; ++k) {
                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
                }
            }
            // Ends this rhs's accumulation before shared buffers are
            // overwritten for the next row.
            __syncthreads();
        }
    }
}
1389
// Chunk reduction for the multi-RHS HVP partials.
// The launch is flattened over (rhs, column): thread `flat` handles
// rhs = flat / p_total and column = flat % p_total, sums
// partial[rhs, c, column] over c in ascending chunk order, and writes the
// total to out[rhs, column]. Threads past rhs_count * p_total exit early.
extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
    int num_chunks,
    int p_total,
    int rhs_count,
    const double * __restrict__ partial,   // [rhs_count, num_chunks, p_total]
    double * __restrict__ out)             // [rhs_count, p_total]
{
    const int flat = blockIdx.x * blockDim.x + threadIdx.x;
    if (flat >= rhs_count * p_total) return;

    const int rhs = flat / p_total;
    const int col = flat % p_total;
    const size_t width = (size_t)p_total;
    // First chunk's entry for this (rhs, col); later chunks are `width` apart.
    const double *src = partial + (size_t)rhs * (size_t)num_chunks * width + (size_t)col;

    double sum = 0.0;
    for (int c = 0; c < num_chunks; ++c) {
        sum += src[(size_t)c * width];
    }
    out[(size_t)rhs * width + (size_t)col] = sum;
}
1408
// Per-chunk partial of the joint-Hessian diagonal, dense row Hessians.
//
// Block `chunk = blockIdx.x` owns rows [chunk*rows_per_cta,
// min((chunk+1)*rows_per_cta, n_rows)) and accumulates into its own slice
// partial[chunk, 0 .. p_total):
//   out[j]               += H_row[0, 0] * marginal_design[row, j]^2
//   out[p_m + j]         += H_row[1, 1] * logslope_design[row, j]^2
//   out[h_block_start+k] += H_row[ii, ii], ii = h_primary_start + k
//   out[w_block_start+k] += H_row[ii, ii], ii = w_primary_start + k
// The slices are summed across chunks by a separate reduce kernel.
// Reads Hrow[1*r + 1] unconditionally, so r >= 2 is required.
extern "C" __global__ void bms_flex_row_diag_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // Zero this block's slice before accumulating into it.
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
        // Diagonal entries of the two design-backed primary slots.
        double h00 = Hrow[0];
        double h11 = Hrow[1 * r + 1];
        // Each thread owns the columns j == tid (mod blockDim.x), so the
        // += below never has two threads writing the same element.
        for (int j = tid; j < p_m; j += blockDim.x) {
            double v = mrow[j];
            out[j] += h00 * v * v;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            double v = grow[j];
            out[p_m + j] += h11 * v * v;
        }
        // h / w blocks map one-to-one onto primary slots, so their
        // diagonal contribution is the row Hessian's own diagonal entry.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                int ii = h_primary_start + k;
                out[h_block_start + k] += Hrow[ii * r + ii];
            }
            for (int k = 0; k < w_block_len; ++k) {
                int ii = w_primary_start + k;
                out[w_block_start + k] += Hrow[ii * r + ii];
            }
        }
        __syncthreads();
    }
}
1466
1467// ────────────────────────────────────────────────────────────────────────
1468// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1469// row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1470// indexed as
//   packed[(u*(2*r - u + 1))/2 + (v - u)]   for u <= v
1472// with symmetric mirror for v < u.
1473// ────────────────────────────────────────────────────────────────────────
1474
// Helper: offset of entry (u, v) inside one row's packed-upper block of
// r*(r+1)/2 doubles. Caller must pre-swap so that u <= v.
//
// The upper triangle is stored row by row: row u contributes the (r - u)
// entries (u, u), (u, u+1), ..., (u, r-1), so it starts at
//     sum_{i < u} (r - i) = u*(2r - u + 1)/2,
// which is exactly the layout bms_flex_row_pack_upper writes. The product
// u*(2r - u + 1) is always even (one factor is even whatever the parity of
// u), so the integer division is exact.
//
// The previous `u*(2r - u - 1)/2 + (v - u)` was not a valid index: with
// r = 3 it sent both (0, 2) and (1, 1) to offset 2 and never reached the
// last two slots.
__device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
    return (u * (2 * r - u + 1)) / 2 + (v - u);
}
1481
// Copy the upper triangle of each row's dense row-major r x r Hessian into
// packed-upper storage. One block per row (blockIdx.x is the row index; any
// block with blockIdx.x >= n_rows exits), threads stride over the packed
// positions. Every packed slot is written exactly once with the value read
// from the dense source, so the copy is bit-exact.
extern "C" __global__ void bms_flex_row_pack_upper(
    int n_rows,
    int r,
    const double * __restrict__ src_full,   // [n, r*r]
    double * __restrict__ dst_packed)       // [n, r*(r+1)/2]
{
    const int row = blockIdx.x;
    if (row >= n_rows) return;

    const int packed_len = r * (r + 1) / 2;
    const double *dense = src_full + (size_t)row * (size_t)r * (size_t)r;
    double *packed = dst_packed + (size_t)row * (size_t)packed_len;

    for (int pos = threadIdx.x; pos < packed_len; pos += blockDim.x) {
        // Recover (u, v) from the packed position. Triangle row u occupies
        // [row_begin, row_begin + (r - u)), where row_begin accumulates the
        // lengths of the earlier rows (= u*(2r - u + 1)/2). A linear walk
        // over u is used instead of a closed-form inverse.
        int u = 0;
        int row_begin = 0;
        for (; u < r; ++u) {
            const int row_end = row_begin + (r - u);
            if (pos < row_end) break;
            row_begin = row_end;
        }
        const int v = u + (pos - row_begin);
        packed[pos] = dense[(size_t)u * (size_t)r + (size_t)v];
    }
}
1516
// Single-RHS Hessian-vector-product partial kernel reading the row
// Hessians from packed-upper storage (r*(r+1)/2 doubles per row).
//
// Block `chunk = blockIdx.x` owns rows [chunk*rows_per_cta,
// min((chunk+1)*rows_per_cta, n_rows)). For each owned row it builds the
// r-long primary direction row_dir from v, forms action = H_row * row_dir
// (H_row[u, w] is fetched through bms_flex_packed_idx with the pair ordered
// so the smaller index comes first), and adds the pull-back of `action`
// into this block's slice partial[chunk, 0 .. p_total).
//
// Constraints visible from the shared-memory declarations and indexing:
//   * dot_reduce has 128 entries indexed by threadIdx.x and the halving
//     reduction only folds every entry for a power-of-two block size, so
//     blockDim.x must be a power of two that is <= 128;
//   * row_dir / action hold 32 doubles, so r <= 32; action[1] is read
//     unconditionally, so r >= 2.
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed, // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    const double * __restrict__ v,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // Zero this block's output slice before accumulating into it.
    int per_row = r * (r + 1) / 2;
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // row_dir[0] = mrow · v[0..p_m]
        double local = 0.0;
        for (int j = tid; j < p_m; j += blockDim.x) {
            local += mrow[j] * v[j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        // Thread 0 both reads dot_reduce[0] here and is the only writer of
        // that slot below, so the buffer can be reused without a barrier.
        if (tid == 0) row_dir[0] = dot_reduce[0];

        // row_dir[1] = grow · v[p_m..p_m+p_g]
        local = 0.0;
        for (int j = tid; j < p_g; j += blockDim.x) {
            local += grow[j] * v[p_m + j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[1] = dot_reduce[0];

        // h/w blocks: direct copy of the matching entries of v.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        // row_dir is complete and visible to every thread past this point.
        __syncthreads();

        // action[u] = Σ_w H[u, w] · row_dir[w], where H[u, w] reads from
        // packed-upper with (uu, vv) = (min(u, w), max(u, w)).
        if (tid < r) {
            double acc = 0.0;
            int u = tid;
            for (int w = 0; w < r; ++w) {
                int uu = u < w ? u : w;
                int vv = u < w ? w : u;
                acc += Hrow[bms_flex_packed_idx(uu, vv, r)] * row_dir[w];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Pull back into the joint slot layout.
        // marginal: out[j] += action[0] * mrow[j]
        double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        // logslope: out[p_m + j] += action[1] * grow[j]
        double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        // h / w slots are accumulated by thread 0 alone.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        // Finish this row before the shared buffers are overwritten.
        __syncthreads();
    }
}
1627
1628// ────────────────────────────────────────────────────────────────────────
1629// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
1630// route. Materialises the full `[p_total, p_total]` row-major joint H
1631// from the per-row r×r Hessian via the P_i pullback. NOT the default
1632// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
1633// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
1634// caller genuinely needs the dense matrix on the device.
1635//
1636// Per-CTA partial: each CTA owns a contiguous chunk of rows
1637// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
1638// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
1639// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
1640// chunk-major-fixed-order into a single `[p_total, p_total]` output.
1641//
1642// Math: for primary index u ∈ [0, r):
1643// * u = 0: phi_u = (X_i in slot 0..p_m, 0 elsewhere)
1644// * u = 1: phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
1645// * u = 2+j: phi_u = e_{h_block_start + j} (j ∈ 0..h_block_len)
1646// * u = 2+h+l: phi_u = e_{w_block_start + l} (l ∈ 0..w_block_len)
1647// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
1648//
// Shared-memory budget: at large-scale shape p_total = 44, a [44, 44] f64
// partial is 44*44*8 = 15488 B (≈ 15.1 KiB) — well below the V100 48 KiB
// cap. At p_total = 80 the partial is 80*80*8 = 51200 B (50 KiB), which is
// just over that cap, so the kernel does NOT fit at that size: the caller
// must enforce p_total ≤ DENSE_BLOCK_MAX_P. The launcher rejects oversize
// p_total cleanly.
1654
// Per-chunk dense joint-Hessian partial.
//
// Block `chunk = blockIdx.x` owns rows [chunk*rows_per_cta,
// min((chunk+1)*rows_per_cta, n_rows)). It accumulates
//   acc[m, n] += H_row[u, v] * phi_u[m] * phi_v[n]
// over its rows into a [p_total, p_total] accumulator held in dynamic
// shared memory (the launch must therefore supply at least
// p_total*p_total*sizeof(double) bytes of dynamic shared memory), then
// copies the accumulator to partial[chunk, :, :].
//
// NOTE(review): h_primary_start and w_primary_start are accepted but never
// read. The primary-slot mapping is hard-coded as: u = 0 marginal, u = 1
// logslope, u in [2, 2 + h_block_len) the h block, every larger u the w
// block. Confirm that callers always use that primary layout and that
// r == 2 + h_block_len + w_block_len; otherwise a w-slot offset can run
// past w_block_len.
extern "C" __global__ void bms_flex_row_dense_block_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians, // [n, r*r]
    const double * __restrict__ marginal_design, // [n, p_m]
    const double * __restrict__ logslope_design, // [n, p_g]
    double * __restrict__ partial) // [num_chunks, p_total, p_total]
{
    extern __shared__ double shmem[];
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int pp = p_total * p_total;
    double *acc = shmem; // CTA-private accumulator [p_total, p_total]
    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
    __syncthreads();

    // Per-row work performed by thread 0 to avoid cross-thread RW
    // contention on `acc[]`. Per-row cost is dominated by the four
    // (u, v) pairs drawn from {0, 1}, whose inner loops run p_m*p_m,
    // p_m*p_g, p_g*p_m and p_g*p_g iterations; every other pair costs at
    // most max(p_m, p_g) iterations. Tighter parallel implementations are
    // possible (warp-stripe the 4-way nested u-v-m-n loop) but Phase 6 is
    // a debug-only path and the simple version is easier to audit for
    // correctness against the host-side P_i pullback oracle.
    if (tid == 0) {
        for (int row = row_lo; row < row_hi; ++row) {
            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
            for (int u = 0; u < r; ++u) {
                for (int v = 0; v < r; ++v) {
                    double huv = Hrow[u * r + v];
                    // An exactly-zero Hessian entry contributes nothing.
                    if (huv == 0.0) continue;
                    // For each (u, v), iterate (m, n) over the non-zero
                    // outer-product support of phi_u and phi_v.
                    // Build a small (offset, len, src_ptr) descriptor for
                    // each operand block as we go. `*_indicator` marks a
                    // unit basis vector (single entry equal to 1.0, no
                    // source row to read).
                    int m_off, m_len; const double *m_src; bool m_indicator;
                    int n_off, n_len; const double *n_src; bool n_indicator;
                    if (u == 0) { m_off = 0; m_len = p_m; m_src = mrow; m_indicator = false; }
                    else if (u == 1) { m_off = p_m; m_len = p_g; m_src = grow; m_indicator = false; }
                    else if (u - 2 < h_block_len) {
                        m_off = h_block_start + (u - 2);
                        m_len = 1; m_src = NULL; m_indicator = true;
                    } else {
                        m_off = w_block_start + (u - 2 - h_block_len);
                        m_len = 1; m_src = NULL; m_indicator = true;
                    }
                    if (v == 0) { n_off = 0; n_len = p_m; n_src = mrow; n_indicator = false; }
                    else if (v == 1) { n_off = p_m; n_len = p_g; n_src = grow; n_indicator = false; }
                    else if (v - 2 < h_block_len) {
                        n_off = h_block_start + (v - 2);
                        n_len = 1; n_src = NULL; n_indicator = true;
                    } else {
                        n_off = w_block_start + (v - 2 - h_block_len);
                        n_len = 1; n_src = NULL; n_indicator = true;
                    }
                    // accumulate huv * phi_u[m] * phi_v[n] into acc[m, n]
                    for (int mi = 0; mi < m_len; ++mi) {
                        double pm = m_indicator ? 1.0 : m_src[mi];
                        // A zero left factor zeroes the whole acc row update.
                        if (pm == 0.0) continue;
                        double scaled = huv * pm;
                        int m_idx = m_off + mi;
                        for (int ni = 0; ni < n_len; ++ni) {
                            double pn = n_indicator ? 1.0 : n_src[ni];
                            int n_idx = n_off + ni;
                            acc[m_idx * p_total + n_idx] += scaled * pn;
                        }
                    }
                }
            }
        }
    }
    // Make thread 0's accumulator visible to the whole block.
    __syncthreads();

    // Write CTA accumulator out to global memory at its chunk slot.
    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
    for (int j = tid; j < pp; j += blockDim.x) {
        out_chunk[j] = acc[j];
    }
}
1749
// Chunk reduction for the dense joint-Hessian partials.
// One thread per matrix element: thread `elem` (flat row-major index into
// the [p_total, p_total] matrix) sums partial[c, elem] over c in ascending
// chunk order and writes the total to out[elem]. Threads with
// elem >= p_total * p_total do nothing.
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int num_chunks,
    int p_total,
    const double * __restrict__ partial,
    double * __restrict__ out)
{
    const int elem = blockIdx.x * blockDim.x + threadIdx.x;
    const int matrix_len = p_total * p_total;
    if (elem < matrix_len) {
        const size_t chunk_stride = (size_t)matrix_len;
        double sum = 0.0;
        for (int c = 0; c < num_chunks; ++c) {
            sum += partial[(size_t)c * chunk_stride + (size_t)elem];
        }
        out[elem] = sum;
    }
}
1765
// Per-chunk partial of the joint-Hessian diagonal, packed-upper row
// Hessians (r*(r+1)/2 doubles per row).
//
// Block `chunk = blockIdx.x` owns rows [chunk*rows_per_cta,
// min((chunk+1)*rows_per_cta, n_rows)) and accumulates into its own slice
// partial[chunk, 0 .. p_total):
//   out[j]               += H_row[0, 0] * marginal_design[row, j]^2
//   out[p_m + j]         += H_row[1, 1] * logslope_design[row, j]^2
//   out[h_block_start+k] += H_row[ii, ii], ii = h_primary_start + k
//   out[w_block_start+k] += H_row[ii, ii], ii = w_primary_start + k
// with every H_row[ii, ii] fetched through bms_flex_packed_idx(ii, ii, r).
// The (1, 1) entry is read unconditionally, so r >= 2 is required.
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // Zero this block's slice before accumulating into it.
    int per_row = r * (r + 1) / 2;
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;
        // Diagonal entry for (u, u) sits at packed_idx(u, u, r).
        double h00 = Hrow[bms_flex_packed_idx(0, 0, r)];
        double h11 = Hrow[bms_flex_packed_idx(1, 1, r)];
        // Each thread owns the columns j == tid (mod blockDim.x), so the
        // += below never has two threads writing the same element.
        for (int j = tid; j < p_m; j += blockDim.x) {
            double v = mrow[j];
            out[j] += h00 * v * v;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            double v = grow[j];
            out[p_m + j] += h11 * v * v;
        }
        // h / w blocks map one-to-one onto primary slots, so their
        // diagonal contribution is the row Hessian's own diagonal entry.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                int ii = h_primary_start + k;
                out[h_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                int ii = w_primary_start + k;
                out[w_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
        }
        __syncthreads();
    }
}
1825"#;
1826
1827#[cfg(target_os = "linux")]
1828pub(crate) struct HvpKernelBackend {
1829 pub(crate) stream: Arc<CudaStream>,
1830 pub(crate) module: Arc<CudaModule>,
1831}
1832
1833#[cfg(target_os = "linux")]
1834impl HvpKernelBackend {
1835 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1836 static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1837 BACKEND
1838 .get_or_init(|| {
1839 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1840 let ptx = cudarc::nvrtc::compile_ptx(HVP_KERNEL_SOURCE).map_err(|err| {
1841 GpuError::DriverCallFailed {
1842 reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1843 }
1844 })?;
1845 let module =
1846 parts
1847 .ctx
1848 .load_module(ptx)
1849 .map_err(|err| GpuError::DriverCallFailed {
1850 reason: format!("bms_flex_row hvp module load failed: {err}"),
1851 })?;
1852 Ok(HvpKernelBackend {
1853 stream: parts.stream.clone(),
1854 module,
1855 })
1856 })
1857 })
1858 .as_ref()
1859 .map_err(GpuError::clone)
1860 }
1861}
1862
1863#[cfg(target_os = "linux")]
1889pub(crate) fn launch_bms_flex_row_kernel_device_resident(
1890 inputs: BmsFlexRowKernelInputs<'_>,
1891 marginal_design_row_major: &[f64],
1892 logslope_design_row_major: &[f64],
1893 block: BmsFlexBlockLayout,
1894 primary: BmsFlexPrimaryLayout,
1895) -> Result<DeviceResidentRowHess, GpuError> {
1896 inputs.validate()?;
1897 if !s_f_diagnostic_finite(&inputs) {
1898 return Err(GpuError::DriverCallFailed {
1899 reason: format!(
1900 "bms_flex_row device-resident: s_f must be positive and finite, got {}",
1901 inputs.s_f
1902 ),
1903 });
1904 }
1905 let n = inputs.n_rows;
1906 let r = inputs.r;
1907 if marginal_design_row_major.len() != n * block.p_m {
1908 return Err(GpuError::DriverCallFailed {
1909 reason: format!(
1910 "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
1911 marginal_design_row_major.len(),
1912 n * block.p_m
1913 ),
1914 });
1915 }
1916 if logslope_design_row_major.len() != n * block.p_g {
1917 return Err(GpuError::DriverCallFailed {
1918 reason: format!(
1919 "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
1920 logslope_design_row_major.len(),
1921 n * block.p_g
1922 ),
1923 });
1924 }
1925 if primary.r != r {
1926 return Err(GpuError::DriverCallFailed {
1927 reason: format!(
1928 "bms_flex_row device-resident: primary.r={} != inputs.r={}",
1929 primary.r, r
1930 ),
1931 });
1932 }
1933
1934 let backend = RowKernelBackend::probe()?;
1937 HvpKernelBackend::probe()?;
1938 let stream = backend.stream.clone();
1939
1940 let upload_f64 = |slice: &[f64], label: &str| {
1941 stream
1942 .clone_htod(slice)
1943 .map_err(|err| GpuError::DriverCallFailed {
1944 reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1945 })
1946 };
1947 let upload_u32 = |slice: &[u32], label: &str| {
1948 stream
1949 .clone_htod(slice)
1950 .map_err(|err| GpuError::DriverCallFailed {
1951 reason: format!("bms_flex_row device-resident upload {label}: {err}"),
1952 })
1953 };
1954
1955 let d_q = upload_f64(inputs.q, "q")?;
1956 let d_b = upload_f64(inputs.b, "b")?;
1957 let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
1958 let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
1959 let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
1960 let d_y = upload_f64(inputs.y, "y")?;
1961 let d_w = upload_f64(inputs.w, "w")?;
1962 let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
1963 let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
1964 let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
1965 let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
1966 let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
1967 let d_a = upload_f64(inputs.cell_a, "cell_a")?;
1968 let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
1969 let d_r = upload_f64(inputs.cell_r, "cell_r")?;
1970 let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
1971 let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
1972 let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
1973 let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
1974 let owned_host_moments: CudaSlice<f64>;
1976 let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
1977 CellMomentsSource::Host(slice) => {
1978 owned_host_moments = upload_f64(slice, "cell_moments")?;
1979 &owned_host_moments
1980 }
1981 CellMomentsSource::Device(d) => *d,
1982 };
1983 let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
1984 let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
1985 let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
1986 let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
1987 let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
1988
1989 let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
1990 let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;
1991
1992 let mut d_neglog = stream
1993 .alloc_zeros::<f64>(n)
1994 .map_err(|err| GpuError::DriverCallFailed {
1995 reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
1996 })?;
1997 let mut d_grad =
1998 stream
1999 .alloc_zeros::<f64>(n * r)
2000 .map_err(|err| GpuError::DriverCallFailed {
2001 reason: format!("bms_flex_row device-resident alloc grad: {err}"),
2002 })?;
2003 let mut d_hess =
2004 stream
2005 .alloc_zeros::<f64>(n * r * r)
2006 .map_err(|err| GpuError::DriverCallFailed {
2007 reason: format!("bms_flex_row device-resident alloc hess: {err}"),
2008 })?;
2009
2010 let func = backend
2011 .module
2012 .load_function("bms_flex_row_kernel")
2013 .map_err(|err| GpuError::DriverCallFailed {
2014 reason: format!("bms_flex_row device-resident load_function: {err}"),
2015 })?;
2016
2017 let cfg = LaunchConfig {
2018 grid_dim: (n as u32, 1, 1),
2019 block_dim: (ROW_KERNEL_THREADS, 1, 1),
2020 shared_mem_bytes: 0,
2021 };
2022 let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
2023 reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
2024 })?;
2025 let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
2026 reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
2027 })?;
2028 let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
2029 reason: format!(
2030 "bms_flex_row device-resident: p_h={} exceeds i32 range",
2031 inputs.p_h
2032 ),
2033 })?;
2034 let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
2035 reason: format!(
2036 "bms_flex_row device-resident: p_w={} exceeds i32 range",
2037 inputs.p_w
2038 ),
2039 })?;
2040 let s_f_val = inputs.s_f;
2041
2042 let mut builder = stream.launch_builder(&func);
2043 builder
2044 .arg(&n_i32)
2045 .arg(&r_i32)
2046 .arg(&p_h_i32)
2047 .arg(&p_w_i32)
2048 .arg(&s_f_val)
2049 .arg(&d_q)
2050 .arg(&d_b)
2051 .arg(&d_mu1)
2052 .arg(&d_mu2)
2053 .arg(&d_zobs)
2054 .arg(&d_y)
2055 .arg(&d_w)
2056 .arg(&d_offsets)
2057 .arg(&d_c0)
2058 .arg(&d_c1)
2059 .arg(&d_c2)
2060 .arg(&d_c3)
2061 .arg(&d_a)
2062 .arg(&d_aa)
2063 .arg(&d_r)
2064 .arg(&d_ar)
2065 .arg(&d_sbb)
2066 .arg(&d_sbh)
2067 .arg(&d_sbw)
2068 .arg(d_moments_ref)
2069 .arg(&d_chi)
2070 .arg(&d_xi)
2071 .arg(&d_rho)
2072 .arg(&d_tau)
2073 .arg(&d_ruv)
2074 .arg(&mut d_neglog)
2075 .arg(&mut d_grad)
2076 .arg(&mut d_hess);
2077 unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
2082 reason: format!("bms_flex_row device-resident launch: {err}"),
2083 })?;
2084 stream
2085 .synchronize()
2086 .map_err(|err| GpuError::DriverCallFailed {
2087 reason: format!("bms_flex_row device-resident synchronize: {err}"),
2088 })?;
2089
2090 drop(d_neglog);
2098 drop(d_grad);
2099 drop(d_q);
2101 drop(d_b);
2102 drop(d_mu1);
2103 drop(d_mu2);
2104 drop(d_zobs);
2105 drop(d_y);
2106 drop(d_w);
2107 drop(d_offsets);
2108 drop(d_c0);
2109 drop(d_c1);
2110 drop(d_c2);
2111 drop(d_c3);
2112 drop(d_a);
2113 drop(d_aa);
2114 drop(d_r);
2115 drop(d_ar);
2116 drop(d_sbb);
2117 drop(d_sbh);
2118 drop(d_sbw);
2119 drop(d_chi);
2123 drop(d_xi);
2124 drop(d_rho);
2125 drop(d_tau);
2126 drop(d_ruv);
2127
2128 let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
2129 * std::mem::size_of::<f64>()) as u64;
2130 Ok(DeviceResidentRowHess {
2131 hess: d_hess,
2132 marginal_design: d_marginal,
2133 logslope_design: d_logslope,
2134 n,
2135 r,
2136 block,
2137 primary,
2138 bytes,
2139 })
2140}
2141
/// Selects which per-chunk "partial" kernel a partial + reduce launch runs.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Hessian-vector product; uses the `bms_flex_row_hvp_partial` kernel.
    HvpDeviceOut,
    /// Hessian diagonal; uses the `bms_flex_row_diag_partial` kernel.
    DiagonalHostOut,
}

#[cfg(target_os = "linux")]
impl BmsFlexRowLaunchMode {
    /// Name of the partial kernel to look up in the compiled module for this mode.
    pub(crate) fn partial_kernel_name(self) -> &'static str {
        use BmsFlexRowLaunchMode::{DiagonalHostOut, HvpDeviceOut};

        let kernel: &'static str = match self {
            DiagonalHostOut => "bms_flex_row_diag_partial",
            HvpDeviceOut => "bms_flex_row_hvp_partial",
        };
        kernel
    }
}
2165
2166#[cfg(target_os = "linux")]
2172pub(crate) struct PreparedBmsFlexRowLaunchArgs {
2173 pub(crate) n_i32: i32,
2174 pub(crate) r_i32: i32,
2175 pub(crate) p_m_i32: i32,
2176 pub(crate) p_g_i32: i32,
2177 pub(crate) p_total_i32: i32,
2178 pub(crate) h_block_start: i32,
2179 pub(crate) h_block_len: i32,
2180 pub(crate) w_block_start: i32,
2181 pub(crate) w_block_len: i32,
2182 pub(crate) h_primary_start: i32,
2183 pub(crate) w_primary_start: i32,
2184 pub(crate) rows_per_cta: i32,
2185 pub(crate) num_chunks: usize,
2186}
2187
2188#[cfg(target_os = "linux")]
2189impl PreparedBmsFlexRowLaunchArgs {
2190 pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2191 let p_total = storage.block.p_total;
2192 let num_chunks = num_hvp_chunks(storage.n);
2193 PreparedBmsFlexRowLaunchArgs {
2194 n_i32: storage.n as i32,
2195 r_i32: storage.r as i32,
2196 p_m_i32: storage.block.p_m as i32,
2197 p_g_i32: storage.block.p_g as i32,
2198 p_total_i32: p_total as i32,
2199 h_block_start: storage
2200 .block
2201 .h
2202 .as_ref()
2203 .map(|r| r.start as i32)
2204 .unwrap_or(0),
2205 h_block_len: storage
2206 .block
2207 .h
2208 .as_ref()
2209 .map(|r| r.len() as i32)
2210 .unwrap_or(0),
2211 w_block_start: storage
2212 .block
2213 .w
2214 .as_ref()
2215 .map(|r| r.start as i32)
2216 .unwrap_or(0),
2217 w_block_len: storage
2218 .block
2219 .w
2220 .as_ref()
2221 .map(|r| r.len() as i32)
2222 .unwrap_or(0),
2223 h_primary_start: storage
2224 .primary
2225 .h
2226 .as_ref()
2227 .map(|r| r.start as i32)
2228 .unwrap_or(0),
2229 w_primary_start: storage
2230 .primary
2231 .w
2232 .as_ref()
2233 .map(|r| r.start as i32)
2234 .unwrap_or(0),
2235 rows_per_cta: HVP_ROWS_PER_CTA as i32,
2236 num_chunks,
2237 }
2238 }
2239}
2240
2241#[cfg(target_os = "linux")]
2255pub(crate) fn run_bms_flex_row_partial_reduce(
2256 storage: &DeviceResidentRowHess,
2257 mode: BmsFlexRowLaunchMode,
2258 d_v: Option<&CudaSlice<f64>>,
2259 d_out: &mut CudaSlice<f64>,
2260 ctx: &str,
2261) -> Result<(), GpuError> {
2262 let backend = HvpKernelBackend::probe()?;
2263 let stream = backend.stream.clone();
2264 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2265 let p_total = storage.block.p_total;
2266
2267 let mut d_partial = stream
2268 .alloc_zeros::<f64>(args.num_chunks * p_total)
2269 .map_err(|err| GpuError::DriverCallFailed {
2270 reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2271 })?;
2272
2273 let partial_kernel_name = mode.partial_kernel_name();
2274 let part_func = backend
2275 .module
2276 .load_function(partial_kernel_name)
2277 .map_err(|err| GpuError::DriverCallFailed {
2278 reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2279 })?;
2280 let red_func = backend
2281 .module
2282 .load_function("bms_flex_row_hvp_reduce")
2283 .map_err(|err| GpuError::DriverCallFailed {
2284 reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2285 })?;
2286
2287 let cfg_part = LaunchConfig {
2288 grid_dim: (args.num_chunks as u32, 1, 1),
2289 block_dim: (HVP_THREADS, 1, 1),
2290 shared_mem_bytes: 0,
2291 };
2292 let mut builder = stream.launch_builder(&part_func);
2293 builder
2294 .arg(&args.n_i32)
2295 .arg(&args.r_i32)
2296 .arg(&args.p_m_i32)
2297 .arg(&args.p_g_i32)
2298 .arg(&args.p_total_i32)
2299 .arg(&args.h_block_start)
2300 .arg(&args.h_block_len)
2301 .arg(&args.w_block_start)
2302 .arg(&args.w_block_len)
2303 .arg(&args.h_primary_start)
2304 .arg(&args.w_primary_start)
2305 .arg(&args.rows_per_cta)
2306 .arg(&storage.hess)
2307 .arg(&storage.marginal_design)
2308 .arg(&storage.logslope_design);
2309 if let Some(d_v) = d_v {
2310 builder.arg(d_v);
2311 }
2312 builder.arg(&mut d_partial);
2313 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2321 reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2322 })?;
2323
2324 let red_threads: u32 = REDUCTION_THREADS;
2325 let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2326 let cfg_red = LaunchConfig {
2327 grid_dim: (red_blocks, 1, 1),
2328 block_dim: (red_threads, 1, 1),
2329 shared_mem_bytes: 0,
2330 };
2331 let num_chunks_i32 = args.num_chunks as i32;
2332 let mut builder = stream.launch_builder(&red_func);
2333 builder
2334 .arg(&num_chunks_i32)
2335 .arg(&args.p_total_i32)
2336 .arg(&d_partial)
2337 .arg(d_out);
2338 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2342 reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2343 })?;
2344 drop(d_partial);
2347 Ok(())
2348}
2349
2350#[cfg(target_os = "linux")]
2360pub(crate) fn launch_bms_flex_row_host(
2361 storage: &DeviceResidentRowHess,
2362 mode: BmsFlexRowLaunchMode,
2363 v: Option<&[f64]>,
2364 ctx: &str,
2365) -> Result<Vec<f64>, GpuError> {
2366 let p_total = storage.block.p_total;
2367 if let Some(v) = v {
2368 if v.len() != p_total {
2369 return Err(GpuError::DriverCallFailed {
2370 reason: format!(
2371 "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2372 v.len()
2373 ),
2374 });
2375 }
2376 }
2377
2378 let backend = HvpKernelBackend::probe()?;
2379 let stream = backend.stream.clone();
2380
2381 let d_v = match v {
2382 Some(v) => Some(
2383 stream
2384 .clone_htod(v)
2385 .map_err(|err| GpuError::DriverCallFailed {
2386 reason: format!("bms_flex_row {ctx} upload v: {err}"),
2387 })?,
2388 ),
2389 None => None,
2390 };
2391 let mut d_out =
2392 stream
2393 .alloc_zeros::<f64>(p_total)
2394 .map_err(|err| GpuError::DriverCallFailed {
2395 reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2396 })?;
2397
2398 run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2399
2400 stream
2401 .synchronize()
2402 .map_err(|err| GpuError::DriverCallFailed {
2403 reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2404 })?;
2405 stream
2406 .clone_dtoh(&d_out)
2407 .map_err(|err| GpuError::DriverCallFailed {
2408 reason: format!("bms_flex_row {ctx} download out: {err}"),
2409 })
2410}
2411
2412#[cfg(target_os = "linux")]
2413pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2414 storage: &DeviceResidentRowHess,
2415 rhs_count: usize,
2416 v_rhs_len: usize,
2417 out_len: Option<usize>,
2418 ctx: &str,
2419) -> Result<usize, GpuError> {
2420 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2421 return Err(GpuError::DriverCallFailed {
2422 reason: format!(
2423 "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2424 ),
2425 });
2426 }
2427 let p_total = storage.block.p_total;
2428 let rhs_elems = rhs_count
2429 .checked_mul(p_total)
2430 .ok_or_else(|| GpuError::DriverCallFailed {
2431 reason: format!(
2432 "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2433 ),
2434 })?;
2435 if v_rhs_len != rhs_elems {
2436 return Err(GpuError::DriverCallFailed {
2437 reason: format!(
2438 "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2439 ),
2440 });
2441 }
2442 if let Some(out_len) = out_len
2443 && out_len != rhs_elems
2444 {
2445 return Err(GpuError::DriverCallFailed {
2446 reason: format!(
2447 "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2448 ),
2449 });
2450 }
2451 Ok(rhs_elems)
2452}
2453
2454#[cfg(target_os = "linux")]
2458pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2459 n: usize,
2460 p_total: usize,
2461 rhs_count: usize,
2462) -> Result<u64, GpuError> {
2463 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2464 return Err(GpuError::DriverCallFailed {
2465 reason: format!(
2466 "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2467 ),
2468 });
2469 }
2470 let num_chunks = num_hvp_chunks(n);
2471 let partial = rhs_count
2472 .checked_mul(num_chunks)
2473 .and_then(|v| v.checked_mul(p_total))
2474 .ok_or_else(|| GpuError::DriverCallFailed {
2475 reason: format!(
2476 "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2477 ),
2478 })?;
2479 let rhs_vectors = rhs_count
2480 .checked_mul(p_total)
2481 .and_then(|v| v.checked_mul(2))
2482 .ok_or_else(|| GpuError::DriverCallFailed {
2483 reason: format!(
2484 "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2485 ),
2486 })?;
2487 let elems = partial
2488 .checked_add(rhs_vectors)
2489 .ok_or_else(|| GpuError::DriverCallFailed {
2490 reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2491 })?;
2492 Ok((elems * std::mem::size_of::<f64>()) as u64)
2493}
2494
2495#[cfg(target_os = "linux")]
2496pub(crate) fn run_bms_flex_row_multi_partial_reduce(
2497 storage: &DeviceResidentRowHess,
2498 rhs_count: usize,
2499 d_v_rhs: &CudaSlice<f64>,
2500 d_out: &mut CudaSlice<f64>,
2501 ctx: &str,
2502) -> Result<(), GpuError> {
2503 let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
2504 storage,
2505 rhs_count,
2506 d_v_rhs.len(),
2507 Some(d_out.len()),
2508 ctx,
2509 )?;
2510 let backend = HvpKernelBackend::probe()?;
2511 let stream = backend.stream.clone();
2512 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2513 let p_total = storage.block.p_total;
2514 let partial_len = rhs_count
2515 .checked_mul(args.num_chunks)
2516 .and_then(|v| v.checked_mul(p_total))
2517 .ok_or_else(|| GpuError::DriverCallFailed {
2518 reason: format!(
2519 "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
2520 args.num_chunks
2521 ),
2522 })?;
2523
2524 let mut d_partial =
2525 stream
2526 .alloc_zeros::<f64>(partial_len)
2527 .map_err(|err| GpuError::DriverCallFailed {
2528 reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
2529 })?;
2530 let part_func = backend
2531 .module
2532 .load_function("bms_flex_row_hvp_multi_partial")
2533 .map_err(|err| GpuError::DriverCallFailed {
2534 reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
2535 })?;
2536 let red_func = backend
2537 .module
2538 .load_function("bms_flex_row_hvp_multi_reduce")
2539 .map_err(|err| GpuError::DriverCallFailed {
2540 reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
2541 })?;
2542
2543 let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
2544 reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
2545 })?;
2546 let cfg_part = LaunchConfig {
2547 grid_dim: (args.num_chunks as u32, 1, 1),
2548 block_dim: (HVP_THREADS, 1, 1),
2549 shared_mem_bytes: 0,
2550 };
2551 let mut builder = stream.launch_builder(&part_func);
2552 builder
2553 .arg(&args.n_i32)
2554 .arg(&args.r_i32)
2555 .arg(&args.p_m_i32)
2556 .arg(&args.p_g_i32)
2557 .arg(&args.p_total_i32)
2558 .arg(&args.h_block_start)
2559 .arg(&args.h_block_len)
2560 .arg(&args.w_block_start)
2561 .arg(&args.w_block_len)
2562 .arg(&args.h_primary_start)
2563 .arg(&args.w_primary_start)
2564 .arg(&args.rows_per_cta)
2565 .arg(&rhs_count_i32)
2566 .arg(&storage.hess)
2567 .arg(&storage.marginal_design)
2568 .arg(&storage.logslope_design)
2569 .arg(d_v_rhs)
2570 .arg(&mut d_partial);
2571 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2576 reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
2577 })?;
2578
2579 let red_threads: u32 = REDUCTION_THREADS;
2580 let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
2581 let cfg_red = LaunchConfig {
2582 grid_dim: (red_blocks, 1, 1),
2583 block_dim: (red_threads, 1, 1),
2584 shared_mem_bytes: 0,
2585 };
2586 let num_chunks_i32 = args.num_chunks as i32;
2587 let mut builder = stream.launch_builder(&red_func);
2588 builder
2589 .arg(&num_chunks_i32)
2590 .arg(&args.p_total_i32)
2591 .arg(&rhs_count_i32)
2592 .arg(&d_partial)
2593 .arg(d_out);
2594 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2597 reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
2598 })?;
2599 drop(d_partial);
2600 Ok(())
2601}
2602
2603#[cfg(target_os = "linux")]
2606pub(crate) fn launch_bms_flex_row_hvp_multi(
2607 storage: &DeviceResidentRowHess,
2608 v_rhs: &[f64],
2609 rhs_count: usize,
2610) -> Result<Vec<f64>, GpuError> {
2611 let rhs_elems =
2612 validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2613 let backend = HvpKernelBackend::probe()?;
2614 let stream = backend.stream.clone();
2615 let d_v_rhs = stream
2616 .clone_htod(v_rhs)
2617 .map_err(|err| GpuError::DriverCallFailed {
2618 reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2619 })?;
2620 let mut d_out =
2621 stream
2622 .alloc_zeros::<f64>(rhs_elems)
2623 .map_err(|err| GpuError::DriverCallFailed {
2624 reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2625 })?;
2626 run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2627 stream
2628 .synchronize()
2629 .map_err(|err| GpuError::DriverCallFailed {
2630 reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2631 })?;
2632 stream
2633 .clone_dtoh(&d_out)
2634 .map_err(|err| GpuError::DriverCallFailed {
2635 reason: format!("bms_flex_row hvp_multi download out: {err}"),
2636 })
2637}
2638
2639#[cfg(target_os = "linux")]
2650pub(crate) fn launch_bms_flex_row_hvp_into_device(
2651 storage: &DeviceResidentRowHess,
2652 d_v: &CudaSlice<f64>,
2653 d_out: &mut CudaSlice<f64>,
2654) -> Result<(), GpuError> {
2655 let p_total = storage.block.p_total;
2656 if d_v.len() != p_total {
2657 return Err(GpuError::DriverCallFailed {
2658 reason: format!(
2659 "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2660 d_v.len(),
2661 p_total
2662 ),
2663 });
2664 }
2665 if d_out.len() != p_total {
2666 return Err(GpuError::DriverCallFailed {
2667 reason: format!(
2668 "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2669 d_out.len(),
2670 p_total
2671 ),
2672 });
2673 }
2674 run_bms_flex_row_partial_reduce(
2678 storage,
2679 BmsFlexRowLaunchMode::HvpDeviceOut,
2680 Some(d_v),
2681 d_out,
2682 "hvp_into_device",
2683 )
2684}
2685
2686#[cfg(target_os = "linux")]
2689pub(crate) fn launch_bms_flex_row_hvp(
2690 storage: &DeviceResidentRowHess,
2691 v: &[f64],
2692) -> Result<Vec<f64>, GpuError> {
2693 launch_bms_flex_row_hvp_multi(storage, v, 1)
2694}
2695
2696#[cfg(target_os = "linux")]
2699pub(crate) fn launch_bms_flex_row_diagonal(
2700 storage: &DeviceResidentRowHess,
2701) -> Result<Vec<f64>, GpuError> {
2702 launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2703}
2704
#[cfg(target_os = "linux")]
/// Largest `p_total` accepted by `launch_bms_flex_row_dense_block`.
///
/// That launcher requests `p_total² * 8` bytes of dynamic shared memory per
/// CTA; at the cap of 72 this is 41,472 bytes. Its error message cites a
/// 48 KiB per-block shared-memory limit (V100) as the reason for the bound.
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;
2712
#[cfg(target_os = "linux")]
/// Rows assigned to each CTA of the dense-block partial kernel.
///
/// `launch_bms_flex_row_dense_block` uses it to compute the grid size
/// (`n.div_ceil(DENSE_BLOCK_ROWS_PER_CTA)`) and also passes it to the kernel
/// as an `i32` argument.
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2720
2721#[cfg(target_os = "linux")]
2738pub fn launch_bms_flex_row_dense_block(
2739 storage: &DeviceResidentRowHess,
2740) -> Result<Vec<f64>, GpuError> {
2741 let p_total = storage.block.p_total;
2742 if p_total == 0 {
2743 return Err(GpuError::DriverCallFailed {
2744 reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2745 });
2746 }
2747 if p_total > DENSE_BLOCK_MAX_P {
2748 return Err(GpuError::DriverCallFailed {
2749 reason: format!(
2750 "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2751 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2752 ),
2753 });
2754 }
2755 let backend = HvpKernelBackend::probe()?;
2756 let stream = backend.stream.clone();
2757 let n = storage.n;
2758 let r = storage.r;
2759 let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2760 let num_chunks = n.div_ceil(rows_per_cta);
2761 let pp = p_total * p_total;
2762
2763 let mut d_partial =
2764 stream
2765 .alloc_zeros::<f64>(num_chunks * pp)
2766 .map_err(|err| GpuError::DriverCallFailed {
2767 reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2768 })?;
2769 let mut d_out = stream
2770 .alloc_zeros::<f64>(pp)
2771 .map_err(|err| GpuError::DriverCallFailed {
2772 reason: format!("bms_flex_row dense_block alloc out: {err}"),
2773 })?;
2774
2775 let part_func = backend
2776 .module
2777 .load_function("bms_flex_row_dense_block_partial")
2778 .map_err(|err| GpuError::DriverCallFailed {
2779 reason: format!("bms_flex_row dense_block load partial: {err}"),
2780 })?;
2781 let red_func = backend
2782 .module
2783 .load_function("bms_flex_row_dense_block_reduce")
2784 .map_err(|err| GpuError::DriverCallFailed {
2785 reason: format!("bms_flex_row dense_block load reduce: {err}"),
2786 })?;
2787
2788 let n_i32 = n as i32;
2789 let r_i32 = r as i32;
2790 let p_m_i32 = storage.block.p_m as i32;
2791 let p_g_i32 = storage.block.p_g as i32;
2792 let p_total_i32 = p_total as i32;
2793 let h_block_start = storage
2794 .block
2795 .h
2796 .as_ref()
2797 .map(|r| r.start as i32)
2798 .unwrap_or(0);
2799 let h_block_len = storage
2800 .block
2801 .h
2802 .as_ref()
2803 .map(|r| r.len() as i32)
2804 .unwrap_or(0);
2805 let w_block_start = storage
2806 .block
2807 .w
2808 .as_ref()
2809 .map(|r| r.start as i32)
2810 .unwrap_or(0);
2811 let w_block_len = storage
2812 .block
2813 .w
2814 .as_ref()
2815 .map(|r| r.len() as i32)
2816 .unwrap_or(0);
2817 let h_primary_start = storage
2818 .primary
2819 .h
2820 .as_ref()
2821 .map(|r| r.start as i32)
2822 .unwrap_or(0);
2823 let w_primary_start = storage
2824 .primary
2825 .w
2826 .as_ref()
2827 .map(|r| r.start as i32)
2828 .unwrap_or(0);
2829 let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2830 let num_chunks_u32 = num_chunks as u32;
2831
2832 let shmem_bytes: u32 =
2834 u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2835 reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2836 })?;
2837
2838 let cfg_part = LaunchConfig {
2839 grid_dim: (num_chunks_u32, 1, 1),
2840 block_dim: (HVP_THREADS, 1, 1),
2841 shared_mem_bytes: shmem_bytes,
2842 };
2843 let mut builder = stream.launch_builder(&part_func);
2844 builder
2845 .arg(&n_i32)
2846 .arg(&r_i32)
2847 .arg(&p_m_i32)
2848 .arg(&p_g_i32)
2849 .arg(&p_total_i32)
2850 .arg(&h_block_start)
2851 .arg(&h_block_len)
2852 .arg(&w_block_start)
2853 .arg(&w_block_len)
2854 .arg(&h_primary_start)
2855 .arg(&w_primary_start)
2856 .arg(&rows_per_cta_i32)
2857 .arg(&storage.hess)
2858 .arg(&storage.marginal_design)
2859 .arg(&storage.logslope_design)
2860 .arg(&mut d_partial);
2861 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2865 reason: format!("bms_flex_row dense_block partial launch: {err}"),
2866 })?;
2867
2868 let red_threads: u32 = REDUCTION_THREADS;
2869 let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2870 let cfg_red = LaunchConfig {
2871 grid_dim: (red_blocks, 1, 1),
2872 block_dim: (red_threads, 1, 1),
2873 shared_mem_bytes: 0,
2874 };
2875 let num_chunks_i32 = num_chunks as i32;
2876 let mut builder = stream.launch_builder(&red_func);
2877 builder
2878 .arg(&num_chunks_i32)
2879 .arg(&p_total_i32)
2880 .arg(&d_partial)
2881 .arg(&mut d_out);
2882 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2884 reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2885 })?;
2886 stream
2887 .synchronize()
2888 .map_err(|err| GpuError::DriverCallFailed {
2889 reason: format!("bms_flex_row dense_block sync: {err}"),
2890 })?;
2891 stream
2892 .clone_dtoh(&d_out)
2893 .map_err(|err| GpuError::DriverCallFailed {
2894 reason: format!("bms_flex_row dense_block download: {err}"),
2895 })
2896}
2897
2898#[cfg(all(test, target_os = "linux"))]
2906mod tests {
2907 use super::*;
2908
2909 pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
2910 BmsFlexRowKernelInputs {
2911 n_rows: 1,
2912 r: 4,
2913 p_h: 1,
2914 p_w: 1,
2915 q: &buffers.q,
2916 b: &buffers.b,
2917 mu_1: &buffers.mu_1,
2918 mu_2: &buffers.mu_2,
2919 z_obs: &buffers.z_obs,
2920 y: &buffers.y,
2921 w: &buffers.w,
2922 s_f: 1.0,
2923 cell_offsets: &buffers.cell_offsets,
2924 cell_c0: &buffers.cell_c0,
2925 cell_c1: &buffers.cell_c1,
2926 cell_c2: &buffers.cell_c2,
2927 cell_c3: &buffers.cell_c3,
2928 cell_a: &buffers.cell_a,
2929 cell_aa: &buffers.cell_aa,
2930 cell_r: &buffers.cell_r,
2931 cell_ar: &buffers.cell_ar,
2932 cell_sbb: &buffers.cell_sbb,
2933 cell_sbh: &buffers.cell_sbh,
2934 cell_sbw: &buffers.cell_sbw,
2935 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
2936 chi_obs: &buffers.chi_obs,
2937 xi_obs: &buffers.xi_obs,
2938 rho_u: &buffers.rho_u,
2939 tau_u: &buffers.tau_u,
2940 r_uv: &buffers.r_uv,
2941 }
2942 }
2943
    /// Owned backing storage for every slice a test borrows into
    /// `BmsFlexRowKernelInputs` (see `minimal_inputs` / `parity_inputs`).
    ///
    /// Field names mirror the input struct one-to-one. The lengths noted
    /// below are the ones the fixture builders in this module produce.
    pub(crate) struct TestBuffers {
        // Per-row scalars: one entry per row.
        pub(crate) q: Vec<f64>,
        pub(crate) b: Vec<f64>,
        pub(crate) mu_1: Vec<f64>,
        pub(crate) mu_2: Vec<f64>,
        pub(crate) z_obs: Vec<f64>,
        pub(crate) y: Vec<f64>,
        pub(crate) w: Vec<f64>,
        // Prefix offsets into the per-cell arrays: `n_rows + 1` entries;
        // row `i` owns cells `cell_offsets[i]..cell_offsets[i + 1]`.
        pub(crate) cell_offsets: Vec<u32>,
        // Per-cell scalar coefficients: one entry per cell.
        pub(crate) cell_c0: Vec<f64>,
        pub(crate) cell_c1: Vec<f64>,
        pub(crate) cell_c2: Vec<f64>,
        pub(crate) cell_c3: Vec<f64>,
        // Per-cell groups of 4 coefficients: `cells * 4` entries.
        pub(crate) cell_a: Vec<f64>,
        pub(crate) cell_aa: Vec<f64>,
        // `cells * (r - 1) * 4` entries.
        pub(crate) cell_r: Vec<f64>,
        pub(crate) cell_ar: Vec<f64>,
        // `cells * 4`, `cells * p_h * 4`, and `cells * p_w * 4` entries.
        pub(crate) cell_sbb: Vec<f64>,
        pub(crate) cell_sbh: Vec<f64>,
        pub(crate) cell_sbw: Vec<f64>,
        // `cells * MOMENT_STRIDE` entries.
        pub(crate) cell_moments: Vec<f64>,
        // Per-row scalars: one entry per row.
        pub(crate) chi_obs: Vec<f64>,
        pub(crate) xi_obs: Vec<f64>,
        // Per-row vectors (`n_rows * r`) and matrices (`n_rows * r * r`).
        pub(crate) rho_u: Vec<f64>,
        pub(crate) tau_u: Vec<f64>,
        pub(crate) r_uv: Vec<f64>,
    }
2971
2972 pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
2973 let cells = n_cells as usize;
2974 TestBuffers {
2975 q: vec![0.1; 1],
2976 b: vec![0.5; 1],
2977 mu_1: vec![0.3; 1],
2978 mu_2: vec![0.07; 1],
2979 z_obs: vec![0.0; 1],
2980 y: vec![1.0; 1],
2981 w: vec![1.0; 1],
2982 cell_offsets: vec![0, n_cells],
2983 cell_c0: vec![0.2; cells],
2984 cell_c1: vec![-0.1; cells],
2985 cell_c2: vec![0.05; cells],
2986 cell_c3: vec![-0.02; cells],
2987 cell_a: vec![0.1; cells * 4],
2988 cell_aa: vec![0.0; cells * 4],
2989 cell_r: vec![0.05; cells * (r - 1) * 4],
2990 cell_ar: vec![0.0; cells * (r - 1) * 4],
2991 cell_sbb: vec![0.0; cells * 4],
2992 cell_sbh: vec![0.0; cells * p_h * 4],
2993 cell_sbw: vec![0.0; cells * p_w * 4],
2994 cell_moments: vec![1.0; cells * MOMENT_STRIDE],
2995 chi_obs: vec![1.0; 1],
2996 xi_obs: vec![0.0; 1],
2997 rho_u: vec![0.0; r],
2998 tau_u: vec![0.0; r],
2999 r_uv: vec![0.0; r * r],
3000 }
3001 }
3002
3003 #[test]
3004 pub(crate) fn validate_accepts_minimal_inputs() {
3005 let buffers = make_buffers(2, 4, 1, 1);
3006 let inputs = minimal_inputs(&buffers);
3007 assert!(inputs.validate().is_ok());
3008 }
3009
3010 #[test]
3011 pub(crate) fn validate_rejects_r_above_max() {
3012 let r = MAX_R + 1;
3013 let p_h = (r - 2) / 2;
3014 let p_w = (r - 2) - p_h;
3015 let buffers = make_buffers(1, r, p_h, p_w);
3016 let bad_inputs = BmsFlexRowKernelInputs {
3017 r,
3018 p_h,
3019 p_w,
3020 rho_u: &buffers.rho_u, tau_u: &buffers.tau_u,
3022 r_uv: &buffers.r_uv,
3023 cell_r: &buffers.cell_r,
3024 cell_ar: &buffers.cell_ar,
3025 cell_sbh: &buffers.cell_sbh,
3026 cell_sbw: &buffers.cell_sbw,
3027 ..minimal_inputs(&buffers)
3028 };
3029 let err = bad_inputs.validate().expect_err("r > MAX_R must fail");
3030 let msg = err.to_string();
3031 assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
3032 }
3033
3034 #[test]
3035 pub(crate) fn validate_rejects_mismatched_r_decomposition() {
3036 let buffers = make_buffers(1, 4, 1, 1);
3037 let bad_inputs = BmsFlexRowKernelInputs {
3038 r: 4,
3039 p_h: 1,
3040 p_w: 2, ..minimal_inputs(&buffers)
3042 };
3043 let err = bad_inputs
3044 .validate()
3045 .expect_err("inconsistent r vs p_h+p_w must fail");
3046 let msg = err.to_string();
3047 assert!(msg.contains("p_h"), "got: {msg}");
3048 assert!(msg.contains("p_w"), "got: {msg}");
3049 }
3050
3051 #[test]
3052 pub(crate) fn validate_rejects_non_monotone_offsets() {
3053 let mut buffers = make_buffers(2, 4, 1, 1);
3060 buffers.cell_offsets = vec![5, 2];
3061 let inputs = minimal_inputs(&buffers);
3062 let err = inputs
3063 .validate()
3064 .expect_err("non-monotone offsets must fail");
3065 let msg = err.to_string();
3066 assert!(msg.contains("monotone"), "got: {msg}");
3067 }
3068
3069 #[test]
3070 pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
3071 let mut buffers = make_buffers(2, 4, 1, 1);
3072 buffers.cell_moments.pop(); let inputs = minimal_inputs(&buffers);
3074 let err = inputs.validate().expect_err("short cell_moments must fail");
3075 let msg = err.to_string();
3076 assert!(msg.contains("cell_moments"), "got: {msg}");
3077 }
3078
3079 #[test]
3080 pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
3081 #[cfg(target_os = "linux")]
3085 {
3086 let buffers = make_buffers(1, 4, 1, 1);
3093 let inputs = minimal_inputs(&buffers);
3094 match launch_bms_flex_row_kernel(inputs) {
3095 Ok(_) => { }
3096 Err(GpuError::DriverLibraryUnavailable { .. })
3097 | Err(GpuError::DriverCallFailed { .. })
3098 | Err(GpuError::DriverSymbolMissing { .. })
3099 | Err(GpuError::NoDeviceKernel { .. }) => { }
3100 Err(other) => panic!("unexpected GpuError variant: {other:?}"),
3101 }
3102 }
3103 #[cfg(not(target_os = "linux"))]
3104 {
3105 let buffers = make_buffers(1, 4, 1, 1);
3106 let inputs = minimal_inputs(&buffers);
3107 match launch_bms_flex_row_kernel(inputs) {
3108 Err(GpuError::DriverLibraryUnavailable { reason }) => {
3109 assert!(
3110 reason.contains("Linux-only"),
3111 "expected Linux-only hint, got: {reason}"
3112 );
3113 }
3114 other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
3115 }
3116 }
3117 }
3118
3119 #[test]
3120 pub(crate) fn s_f_must_be_positive_and_finite() {
3121 let buffers = make_buffers(1, 4, 1, 1);
3122 let mut inputs = minimal_inputs(&buffers);
3123 inputs.s_f = 0.0;
3124 match launch_bms_flex_row_kernel(inputs) {
3125 Err(GpuError::DriverCallFailed { reason }) => {
3126 assert!(reason.contains("s_f"), "got: {reason}");
3127 }
3128 other => panic!("expected DriverCallFailed for s_f=0, got {other:?}"),
3129 }
3130 }
3131
    /// `1 / (2π)`, written via `TAU` (= 2π). The CPU oracle scales its
    /// moment contractions by this factor.
    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
    /// `√2`; the oracle divides by it to turn a normal-CDF argument into an
    /// `erfc` argument.
    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// `1 / √(2π)` as a literal; multiplies `exp(-x²/2)` to form the standard
    /// normal density in `oracle_log_ndtr_and_mills`.
    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
3149
3150 pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
3151 if !x.is_finite() {
3152 return if x > 0.0 { 0.0 } else { f64::INFINITY };
3153 }
3154 if x <= 0.0 {
3155 return 1.0;
3156 }
3157 if x < 26.0 {
3158 let mut xx = x * x;
3159 if xx > 700.0 {
3160 xx = 700.0;
3161 }
3162 return xx.exp() * gam_gpu::numerics_host::erfc(x);
3163 }
3164 let inv = 1.0 / x;
3165 let inv2 = inv * inv;
3166 let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
3167 + 6.5625 * inv2 * inv2 * inv2 * inv2;
3168 let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
3169 inv * poly * inv_sqrt_pi
3170 }
3171
3172 pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
3173 if x == f64::INFINITY {
3174 return (0.0, 0.0);
3175 }
3176 if x == f64::NEG_INFINITY {
3177 return (f64::NEG_INFINITY, f64::INFINITY);
3178 }
3179 if x.is_nan() {
3180 return (x, x);
3181 }
3182 const ORACLE_LEFT_TAIL_X: f64 = -37.0;
3195 if x >= ORACLE_LEFT_TAIL_X {
3196 let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
3197 if cdf < 1e-300 {
3198 cdf = 1e-300;
3199 }
3200 if cdf > 1.0 {
3201 cdf = 1.0;
3202 }
3203 let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
3204 (cdf.ln(), pdf / cdf)
3205 } else {
3206 let u = -x / ORACLE_SQRT_2;
3207 let mut ex = oracle_erfcx_nonnegative(u);
3208 if ex < 1e-300 {
3209 ex = 1e-300;
3210 }
3211 let log_cdf = -u * u + (0.5 * ex).ln();
3212 let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
3213 (log_cdf, sqrt_2_over_pi / ex)
3214 }
3215 }
3216
    /// Host-side reference for the row kernel: computes, for every row, the
    /// negative log-likelihood, its length-`r` gradient, and its `r × r`
    /// Hessian (flattened, both triangles filled), using plain `f64`
    /// arithmetic on the CPU.
    ///
    /// Only a subset of `inputs` is read: `n_rows`, `r`, `p_h`, `p_w`, the
    /// cell arrays, `cell_moments`, `mu_1`, `mu_2`, `chi_obs`, `xi_obs`,
    /// `rho_u`, `tau_u`, `r_uv`, `y`, and `w`. `q`, `b`, `z_obs`, and `s_f`
    /// are not used here.
    ///
    /// A row whose accumulated `f_a` is non-finite or `<= 0` gets NaN in all
    /// three outputs.
    ///
    /// # Panics
    /// Panics when `inputs.cell_moments` is device-resident, and on any
    /// out-of-range slice index (the inputs are not validated here).
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        // The oracle reads moments on the host only.
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // Per-row accumulators over the row's cells. Index 0 of the
            // `r`-sized arrays is skipped by the cell loops (`u` starts at 1)
            // and filled from `mu_1` / `mu_2` afterwards.
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            // Row `row` owns cells `cell_lo..cell_hi`.
            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // t[k] = (1/2π) · Σ_e c_e · m[e + k] for k = 0..=6. The
                // largest index read is 3 + 6 = 9, so each cell needs at
                // least 10 moments.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // d_of: (1/2π) · dot product of a 4-coefficient vector with
                // the first four moments.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // q_of: the 7 coefficients of the product of two
                // 4-coefficient vectors (their convolution), contracted
                // against `t`.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // `cell_r` / `cell_ar` hold `r - 1` groups of 4 per cell, for
                // directions u = 1..r.
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                // Upper triangle (u <= v) of the second-order accumulator.
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        // Only pairs with u == 1 carry an extra `d_of` term:
                        // v == 1 reads `cell_sbb`, v in 2..2+p_h reads
                        // `cell_sbh`, and v in 2+p_h..r reads `cell_sbw`.
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // Direction 0 does not come from the cells: its first-order entry
            // is `-mu_1`, its diagonal second-order entry is `-mu_2`, and the
            // rest of row 0 / column 0 is zero.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // `f_a` is divided by below; a non-finite or non-positive value
            // poisons the whole row with NaN instead.
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            let inv_fa = 1.0 / f_a;

            // a_u = -f_u / f_a (entry 0 written out as mu_1 / f_a).
            // NOTE(review): the a_u / a_uv formulas look like first and
            // second implicit-function derivatives of a root `a` of the
            // accumulated function — confirm against the derivation.
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Combine with the per-row observation terms:
            //   bar_e_u  = chi·a_u + rho_u
            //   bar_e_uv = chi·a_uv + xi·a_u·a_v + tau_u·a_v + a_u·tau_v + r_uv
            // Only the upper triangle of `r_uv` is read.
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Likelihood in the signed argument m = s·bar_e_u[0], with
            // s = 2y - 1. `lambda` is the density-to-CDF ratio returned by
            // `oracle_log_ndtr_and_mills`.
            let y = inputs.y[row];
            let w = inputs.w[row];
            let s = 2.0 * y - 1.0;
            let e_obs = bar_e_u[0];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            // Hessian: b_i·bar_e_u·bar_e_v + a_i·bar_e_uv, mirrored into the
            // lower triangle.
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3427
3428 pub(crate) fn make_parity_buffers() -> TestBuffers {
3432 let n = 4_usize;
3433 let r = 5_usize;
3434 let p_h = 2_usize;
3435 let p_w = 1_usize;
3436 let row_cells: [u32; 4] = [2, 3, 4, 2];
3438 let mut cell_offsets = vec![0_u32; n + 1];
3439 for i in 0..n {
3440 cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
3441 }
3442 let total_cells = cell_offsets[n] as usize;
3443
3444 let f = |seed: usize| -> f64 {
3446 let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
3447 0.1 + 0.4 * x
3448 };
3449
3450 let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
3451 let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
3452 let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
3453 let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
3454 let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
3455 let y = [1.0, 0.0, 1.0, 0.0].to_vec();
3456 let w = vec![1.0; n];
3457
3458 let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
3459 let cell_c1 = (0..total_cells)
3460 .map(|c| -f(c + 2002) * 0.5)
3461 .collect::<Vec<_>>();
3462 let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
3463 let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();
3464
3465 let cell_a = (0..total_cells * 4)
3466 .map(|i| f(i + 5005) * 0.3)
3467 .collect::<Vec<_>>();
3468 let cell_aa = (0..total_cells * 4)
3469 .map(|i| f(i + 6006) * 0.1)
3470 .collect::<Vec<_>>();
3471 let cell_r = (0..total_cells * (r - 1) * 4)
3472 .map(|i| f(i + 7007) * 0.2)
3473 .collect::<Vec<_>>();
3474 let cell_ar = (0..total_cells * (r - 1) * 4)
3475 .map(|i| f(i + 8008) * 0.05)
3476 .collect::<Vec<_>>();
3477 let cell_sbb = (0..total_cells * 4)
3478 .map(|i| f(i + 9009) * 0.08)
3479 .collect::<Vec<_>>();
3480 let cell_sbh = (0..total_cells * p_h * 4)
3481 .map(|i| f(i + 10_010) * 0.07)
3482 .collect::<Vec<_>>();
3483 let cell_sbw = (0..total_cells * p_w * 4)
3484 .map(|i| f(i + 11_011) * 0.06)
3485 .collect::<Vec<_>>();
3486 let cell_moments = (0..total_cells * MOMENT_STRIDE)
3487 .map(|i| 0.4 + 0.1 * f(i + 12_012))
3488 .collect::<Vec<_>>();
3489
3490 let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
3491 let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
3492 let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
3493 let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
3494 let r_uv = (0..n * r * r)
3495 .map(|i| 0.04 * f(i + 15_015))
3496 .collect::<Vec<_>>();
3497
3498 TestBuffers {
3499 q,
3500 b,
3501 mu_1,
3502 mu_2,
3503 z_obs,
3504 y,
3505 w,
3506 cell_offsets,
3507 cell_c0,
3508 cell_c1,
3509 cell_c2,
3510 cell_c3,
3511 cell_a,
3512 cell_aa,
3513 cell_r,
3514 cell_ar,
3515 cell_sbb,
3516 cell_sbh,
3517 cell_sbw,
3518 cell_moments,
3519 chi_obs,
3520 xi_obs,
3521 rho_u,
3522 tau_u,
3523 r_uv,
3524 }
3525 }
3526
3527 pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3528 BmsFlexRowKernelInputs {
3529 n_rows: 4,
3530 r: 5,
3531 p_h: 2,
3532 p_w: 1,
3533 q: &buffers.q,
3534 b: &buffers.b,
3535 mu_1: &buffers.mu_1,
3536 mu_2: &buffers.mu_2,
3537 z_obs: &buffers.z_obs,
3538 y: &buffers.y,
3539 w: &buffers.w,
3540 s_f: 1.0,
3541 cell_offsets: &buffers.cell_offsets,
3542 cell_c0: &buffers.cell_c0,
3543 cell_c1: &buffers.cell_c1,
3544 cell_c2: &buffers.cell_c2,
3545 cell_c3: &buffers.cell_c3,
3546 cell_a: &buffers.cell_a,
3547 cell_aa: &buffers.cell_aa,
3548 cell_r: &buffers.cell_r,
3549 cell_ar: &buffers.cell_ar,
3550 cell_sbb: &buffers.cell_sbb,
3551 cell_sbh: &buffers.cell_sbh,
3552 cell_sbw: &buffers.cell_sbw,
3553 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3554 chi_obs: &buffers.chi_obs,
3555 xi_obs: &buffers.xi_obs,
3556 rho_u: &buffers.rho_u,
3557 tau_u: &buffers.tau_u,
3558 r_uv: &buffers.r_uv,
3559 }
3560 }
3561
/// Runs the CPU oracle on the parity fixture and checks that the outputs have
/// lengths `n`, `n * r` and `n * r * r`, that every value is finite, and that
/// each per-row Hessian is symmetric bit-for-bit.
#[test]
pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
    let buffers = make_parity_buffers();
    let inputs = parity_inputs(&buffers);
    inputs
        .validate()
        .expect("parity fixture must satisfy validate()");
    let out = cpu_oracle_outputs(&inputs);
    let (n, r) = (inputs.n_rows, inputs.r);

    assert_eq!(out.neglog.len(), n);
    assert_eq!(out.grad.len(), n * r);
    assert_eq!(out.hess.len(), n * r * r);

    for row in 0..n {
        let neglog = out.neglog[row];
        assert!(
            neglog.is_finite(),
            "row {row}: neglog must be finite, got {}",
            neglog
        );

        // Row-local views; the lengths were verified above.
        let grad_row = &out.grad[row * r..(row + 1) * r];
        let hess_row = &out.hess[row * r * r..(row + 1) * r * r];
        for (u, &g) in grad_row.iter().enumerate() {
            assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
            for v in 0..r {
                let (huv, hvu) = (hess_row[u * r + v], hess_row[v * r + u]);
                assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
                assert_eq!(
                    huv.to_bits(),
                    hvu.to_bits(),
                    "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
                );
            }
        }
    }
}
3600
/// Compares the analytic first and second derivatives built from
/// `oracle_log_ndtr_and_mills` against five-point central finite differences
/// of `-w * log_cdf(s * e)` with `s = 2y - 1`, over a fixed table of
/// `(e, y, w)` cases.
#[test]
pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
    /// Scalar objective: `-w * log_cdf(s * e)` with `s = 2y - 1`.
    fn neglog_of(e: f64, y: f64, w: f64) -> f64 {
        let s = 2.0 * y - 1.0;
        let (log_cdf, _) = oracle_log_ndtr_and_mills(s * e);
        -w * log_cdf
    }

    /// Analytic `(d/de, d²/de²)` of `neglog_of`, expressed through the second
    /// component returned by `oracle_log_ndtr_and_mills`.
    fn ab_of(e: f64, y: f64, w: f64) -> (f64, f64) {
        let s = 2.0 * y - 1.0;
        let m_arg = s * e;
        let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
        (-w * s * lambda, w * lambda * (m_arg + lambda))
    }

    let cases: [(f64, f64, f64); 12] = [
        (-1.6, 1.0, 1.0),
        (-0.7, 1.0, 1.0),
        (0.0, 1.0, 1.0),
        (0.9, 1.0, 1.0),
        (1.8, 1.0, 1.0),
        (-1.4, 0.0, 1.0),
        (-0.3, 0.0, 1.0),
        (0.0, 0.0, 1.0),
        (0.6, 0.0, 1.0),
        (1.5, 0.0, 1.0),
        (0.4, 1.0, 0.75),
        (-0.8, 0.0, 1.3),
    ];
    let h = 1e-3_f64;

    for &(e, y, w) in &cases {
        let (a_ana, b_ana) = ab_of(e, y, w);

        // Objective sampled on the symmetric stencil e - 2h .. e + 2h.
        let [fm2, fm1, f0, fp1, fp2] =
            [e - 2.0 * h, e - h, e, e + h, e + 2.0 * h].map(|x| neglog_of(x, y, w));

        // Five-point central differences for the first and second derivative.
        let d1_fd = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * h);
        let d2_fd = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * h * h);

        let a_abs = (a_ana - d1_fd).abs();
        let a_rel = a_abs / a_ana.abs().max(1.0);
        assert!(
            a_abs <= 5e-8 || a_rel <= 5e-8,
            "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
             analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
        );

        let b_abs = (b_ana - d2_fd).abs();
        let b_rel = b_abs / b_ana.abs().max(1.0);
        assert!(
            b_abs <= 5e-6 || b_rel <= 5e-6,
            "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
             analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
        );
    }
}
3701
/// Device-versus-oracle parity for the row kernel.
///
/// On non-Linux hosts, or when `GpuRuntime::global()` yields `None`, the test
/// prints a skip notice and returns. Otherwise it runs the CPU oracle and the
/// device launch on the same fixture, compares `neglog`, `grad` and `hess`
/// element-wise within `1e-8 + 1e-8 * |cpu|`, and requires the device Hessian
/// to be symmetric bit-for-bit. A launch error panics instead of skipping.
#[test]
pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling test)"
        );
        return;
    }
    #[cfg(target_os = "linux")]
    {
        let _runtime = match gam_gpu::device_runtime::GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!(
                    "[bms_flex_row parity] no CUDA runtime — skipping device \
                     parity (CPU oracle exercised by sibling test)"
                );
                return;
            }
        };

        let buffers = make_parity_buffers();
        let inputs_cpu = parity_inputs(&buffers);
        inputs_cpu
            .validate()
            .expect("parity fixture must satisfy validate()");
        let cpu_out = cpu_oracle_outputs(&inputs_cpu);

        // The launch consumes its inputs, so hand it a second borrowed view.
        let gpu_out = launch_bms_flex_row_kernel(parity_inputs(&buffers)).unwrap_or_else(|err| {
            panic!(
                "[bms_flex_row parity] launch failed on CUDA-selected host; \
                 device/oracle parity must fail loudly on GPU CI: {err}"
            )
        });

        let (n, r) = (inputs_cpu.n_rows, inputs_cpu.r);
        let (tol_abs, tol_rel) = (1e-8_f64, 1e-8_f64);
        // NaNs must appear on both sides or on neither; finite values are
        // compared with a mixed absolute/relative tolerance.
        let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
            if cpu.is_nan() || gpu.is_nan() {
                assert!(
                    cpu.is_nan() && gpu.is_nan(),
                    "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
                );
            } else {
                let diff = (cpu - gpu).abs();
                let tol = tol_abs + tol_rel * cpu.abs();
                assert!(
                    diff <= tol,
                    "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
                     cpu={cpu:.17e}, gpu={gpu:.17e}"
                );
            }
        };

        let compared: [(&str, &[f64], &[f64]); 3] = [
            ("neglog", &cpu_out.neglog[..], &gpu_out.neglog[..]),
            ("grad", &cpu_out.grad[..], &gpu_out.grad[..]),
            ("hess", &cpu_out.hess[..], &gpu_out.hess[..]),
        ];
        // All length checks run before any value comparison.
        for (_, cpu_vals, gpu_vals) in &compared {
            assert_eq!(cpu_vals.len(), gpu_vals.len());
        }
        for (label, cpu_vals, gpu_vals) in &compared {
            for (i, (&c, &g)) in cpu_vals.iter().zip(gpu_vals.iter()).enumerate() {
                check_close(label, i, c, g);
            }
        }

        for row in 0..n {
            let hess_row = &gpu_out.hess[row * r * r..(row + 1) * r * r];
            for u in 0..r {
                for v in 0..r {
                    let a = hess_row[u * r + v];
                    let bb = hess_row[v * r + u];
                    assert_eq!(
                        a.to_bits(),
                        bb.to_bits(),
                        "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
                    );
                }
            }
        }
    }
}
3794
/// On Linux, asserts that `ROW_KERNEL_BODY` contains the two identifiers
/// `compute_row_analytic_flex_from_parts_into` and
/// `cell_first_derivative_from_moments`. On other targets the body is empty.
#[test]
pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
    #[cfg(target_os = "linux")]
    {
        assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
        assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
    }
}
3806
3807 pub(crate) fn cpu_oracle_bms_flex_row_hvp(
3812 row_hessians: &[f64],
3813 marginal_design: &[f64],
3814 logslope_design: &[f64],
3815 block: &BmsFlexBlockLayout,
3816 primary: &BmsFlexPrimaryLayout,
3817 n: usize,
3818 v: &[f64],
3819 ) -> Vec<f64> {
3820 let r = primary.r;
3821 let p_m = block.p_m;
3822 let p_g = block.p_g;
3823 assert_eq!(v.len(), block.p_total);
3824 assert_eq!(row_hessians.len(), n * r * r);
3825 assert_eq!(marginal_design.len(), n * p_m);
3826 assert_eq!(logslope_design.len(), n * p_g);
3827 let mut out = vec![0.0_f64; block.p_total];
3828 let mut row_dir = vec![0.0_f64; r];
3829 let mut action = vec![0.0_f64; r];
3830 for row in 0..n {
3831 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3832 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3833 let mut acc_q = 0.0_f64;
3834 for j in 0..p_m {
3835 acc_q += mrow[j] * v[j];
3836 }
3837 let mut acc_g = 0.0_f64;
3838 for j in 0..p_g {
3839 acc_g += grow[j] * v[p_m + j];
3840 }
3841 row_dir[0] = acc_q;
3842 row_dir[1] = acc_g;
3843 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3844 for (k, ii) in prange.clone().enumerate() {
3845 row_dir[ii] = v[brange.start + k];
3846 }
3847 }
3848 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3849 for (k, ii) in prange.clone().enumerate() {
3850 row_dir[ii] = v[brange.start + k];
3851 }
3852 }
3853 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3854 for u in 0..r {
3855 let mut acc = 0.0_f64;
3856 for v_idx in 0..r {
3857 acc += h_slice[u * r + v_idx] * row_dir[v_idx];
3858 }
3859 action[u] = acc;
3860 }
3861 let a0 = action[0];
3862 for j in 0..p_m {
3863 out[j] += a0 * mrow[j];
3864 }
3865 let a1 = action[1];
3866 for j in 0..p_g {
3867 out[p_m + j] += a1 * grow[j];
3868 }
3869 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3870 for (k, ii) in prange.clone().enumerate() {
3871 out[brange.start + k] += action[ii];
3872 }
3873 }
3874 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3875 for (k, ii) in prange.clone().enumerate() {
3876 out[brange.start + k] += action[ii];
3877 }
3878 }
3879 }
3880 out
3881 }
3882
3883 pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
3884 row_hessians: &[f64],
3885 marginal_design: &[f64],
3886 logslope_design: &[f64],
3887 block: &BmsFlexBlockLayout,
3888 primary: &BmsFlexPrimaryLayout,
3889 n: usize,
3890 ) -> Vec<f64> {
3891 let r = primary.r;
3892 let p_m = block.p_m;
3893 let p_g = block.p_g;
3894 let mut out = vec![0.0_f64; block.p_total];
3895 for row in 0..n {
3896 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
3897 let h00 = h_slice[0];
3898 let h11 = h_slice[r + 1];
3899 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
3900 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
3901 for j in 0..p_m {
3902 out[j] += h00 * mrow[j] * mrow[j];
3903 }
3904 for j in 0..p_g {
3905 out[p_m + j] += h11 * grow[j] * grow[j];
3906 }
3907 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
3908 for (k, ii) in prange.clone().enumerate() {
3909 out[brange.start + k] += h_slice[ii * r + ii];
3910 }
3911 }
3912 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
3913 for (k, ii) in prange.clone().enumerate() {
3914 out[brange.start + k] += h_slice[ii * r + ii];
3915 }
3916 }
3917 }
3918 out
3919 }
3920
/// Checks `cpu_oracle_bms_flex_row_hvp` against an independent hand
/// computation of output entry 0 on a small fixture (`n = 4`, `r = 4`, two
/// marginal and two logslope coefficients, one `h` and one `w` slot), and
/// verifies the output is finite with length `p_total`.
#[test]
pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
    let (n, r) = (4_usize, 4_usize);
    let (p_m, p_g) = (2_usize, 2_usize);
    let (p_h_dim, p_w_dim) = (1_usize, 1_usize);
    let h_start = p_m + p_g;
    let w_start = h_start + p_h_dim;
    let p_total = w_start + p_w_dim;
    let block = BmsFlexBlockLayout {
        p_m,
        p_g,
        h: Some(h_start..w_start),
        w: Some(w_start..p_total),
        p_total,
    };
    let primary = BmsFlexPrimaryLayout {
        h: Some(2..3),
        w: Some(3..4),
        r,
    };

    // Symmetric per-row Hessians: entry (u, v) depends only on the ordered
    // pair (min, max) of its indices.
    let row_hessians: Vec<f64> = (0..n * r * r)
        .map(|flat| {
            let (row, u, v) = (flat / (r * r), (flat / r) % r, flat % r);
            let (lo, hi) = (u.min(v), u.max(v));
            ((row + 1) as f64) * (1.0 + (lo as f64) + 2.0 * (hi as f64))
        })
        .collect();
    let marginal: Vec<f64> = (0..n * p_m)
        .map(|flat| {
            let (row, j) = (flat / p_m, flat % p_m);
            0.5 + (row as f64) * 0.1 - (j as f64) * 0.2
        })
        .collect();
    let logslope: Vec<f64> = (0..n * p_g)
        .map(|flat| {
            let (row, j) = (flat / p_g, flat % p_g);
            -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15
        })
        .collect();
    let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();

    let out = cpu_oracle_bms_flex_row_hvp(
        &row_hessians,
        &marginal,
        &logslope,
        &block,
        &primary,
        n,
        &v,
    );

    // Hand computation of out[0]: sum over rows of (H_row * d)[0] * mrow[0].
    let expect_out_0 = (0..n).fold(0.0_f64, |total, row| {
        let mrow = &marginal[row * p_m..(row + 1) * p_m];
        let grow = &logslope[row * p_g..(row + 1) * p_g];
        let row_dir = [
            mrow[0] * v[0] + mrow[1] * v[1],
            grow[0] * v[p_m] + grow[1] * v[p_m + 1],
            v[p_m + p_g],
            v[p_m + p_g + p_h_dim],
        ];
        let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
        // First Hessian row dotted with the row direction.
        let action0 = h_slice[..r]
            .iter()
            .zip(row_dir.iter())
            .fold(0.0_f64, |acc, (h, d)| acc + h * d);
        total + action0 * mrow[0]
    });

    assert!(
        (out[0] - expect_out_0).abs() < 1e-12,
        "cpu oracle HVP out[0] mismatch: {} vs hand-check {}",
        out[0],
        expect_out_0
    );
    assert!(out.iter().all(|x| x.is_finite()));
    assert_eq!(out.len(), p_total);
}
4007
/// Checks `cpu_oracle_bms_flex_row_diagonal` on a fixture with purely
/// diagonal per-row Hessians (`n = 3`, `r = 4`): output entry 0 must equal
/// `Σ_row H[0,0] * marginal[row, 0]²` and the `h` slot must equal
/// `Σ_row H[2,2]`, both within `1e-12`.
#[test]
pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
    let (n, r) = (3_usize, 4_usize);
    let (p_m, p_g) = (2_usize, 2_usize);
    let (p_h_dim, p_w_dim) = (1_usize, 1_usize);
    let h_slot = p_m + p_g;
    let w_slot = h_slot + p_h_dim;
    let p_total = w_slot + p_w_dim;
    let block = BmsFlexBlockLayout {
        p_m,
        p_g,
        h: Some(h_slot..w_slot),
        w: Some(w_slot..p_total),
        p_total,
    };
    let primary = BmsFlexPrimaryLayout {
        h: Some(2..3),
        w: Some(3..4),
        r,
    };

    // Only the diagonal of each row Hessian is populated.
    let row_hessians: Vec<f64> = (0..n * r * r)
        .map(|flat| {
            let (row, u, v) = (flat / (r * r), (flat / r) % r, flat % r);
            if u == v {
                1.0 + (row as f64) + (u as f64) * 0.5
            } else {
                0.0_f64
            }
        })
        .collect();
    let marginal: Vec<f64> = (0..n * p_m)
        .map(|flat| {
            let (row, j) = (flat / p_m, flat % p_m);
            0.2 + (row as f64) * 0.3 + (j as f64) * 0.1
        })
        .collect();
    let logslope: Vec<f64> = (0..n * p_g)
        .map(|flat| {
            let (row, j) = (flat / p_g, flat % p_g);
            -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2
        })
        .collect();

    let out = cpu_oracle_bms_flex_row_diagonal(
        &row_hessians,
        &marginal,
        &logslope,
        &block,
        &primary,
        n,
    );

    // First marginal coefficient: sum of H[0,0] * m0^2 over rows.
    let expect = (0..n).fold(0.0_f64, |sum, row| {
        let h00 = row_hessians[row * r * r];
        sum + h00 * marginal[row * p_m].powi(2)
    });
    assert!(
        (out[0] - expect).abs() < 1e-12,
        "out[0] {} vs {}",
        out[0],
        expect
    );

    // The h coefficient: sum of H[2,2] over rows.
    let expect_h = (0..n).fold(0.0_f64, |sum, row| {
        sum + row_hessians[row * r * r + 2 * r + 2]
    });
    assert!(
        (out[h_slot] - expect_h).abs() < 1e-12,
        "h slot {} vs {}",
        out[h_slot],
        expect_h
    );
}
4079
/// Device-versus-oracle parity for the single-RHS HVP and diagonal launches
/// on a small fixture (`n = 4`, `r = 4`, `p_total = 6`).
///
/// On non-Linux hosts, or when `GpuRuntime::global()` yields `None`, the test
/// only prints a skip notice. On a CUDA host, backend probe, uploads and
/// launches must all succeed, and every output entry must match the CPU
/// oracle within `1e-10`.
#[test]
pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let _runtime = match gam_gpu::device_runtime::GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!(
                    "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
                     parity"
                );
                return;
            }
        };

        let (n, r) = (4_usize, 4_usize);
        let (p_m, p_g) = (2_usize, 2_usize);
        let (p_h_dim, p_w_dim) = (1_usize, 1_usize);
        let h_start = p_m + p_g;
        let w_start = h_start + p_h_dim;
        let p_total = w_start + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_start..w_start),
            w: Some(w_start..p_total),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };

        // Symmetric per-row Hessians: entry (u, v) depends only on the
        // ordered pair (min, max) of its indices.
        let row_hessians: Vec<f64> = (0..n * r * r)
            .map(|flat| {
                let (row, u, v) = (flat / (r * r), (flat / r) % r, flat % r);
                let (lo, hi) = (u.min(v), u.max(v));
                ((row + 1) as f64) * (1.0 + (lo as f64) + 2.0 * (hi as f64))
            })
            .collect();
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|flat| {
                let (row, j) = (flat / p_m, flat % p_m);
                0.5 + (row as f64) * 0.1 - (j as f64) * 0.2
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|flat| {
                let (row, j) = (flat / p_g, flat % p_g);
                -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15
            })
            .collect();
        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();

        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );
        let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
        );

        // From here on a CUDA runtime is present, so any failure is a test
        // failure rather than a skip.
        let backend = HvpKernelBackend::probe()
            .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
        let d_m = stream
            .clone_htod(&marginal)
            .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
        let d_g = stream
            .clone_htod(&logslope)
            .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
        let resident_elems = n * r * r + n * p_m + n * p_g;
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: (resident_elems * std::mem::size_of::<f64>()) as u64,
        };

        let gpu_hvp =
            launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
        let gpu_diag = launch_bms_flex_row_diagonal(&storage)
            .expect("diagonal kernel must launch on CUDA host");
        assert_eq!(gpu_hvp.len(), cpu_hvp.len());
        assert_eq!(gpu_diag.len(), cpu_diag.len());
        for i in 0..p_total {
            let (hvp_cpu, hvp_gpu) = (cpu_hvp[i], gpu_hvp[i]);
            let diff = (hvp_cpu - hvp_gpu).abs();
            assert!(
                diff <= 1e-10,
                "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
                hvp_cpu,
                hvp_gpu
            );
            let (diag_cpu, diag_gpu) = (cpu_diag[i], gpu_diag[i]);
            let ddiff = (diag_cpu - diag_gpu).abs();
            assert!(
                ddiff <= 1e-10,
                "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
                diag_cpu,
                diag_gpu
            );
        }
    }
}
4215
/// Checks the multi-RHS scratch budget at a large shape (`n = 195_000`,
/// `r = 20`, `p_total = 44`, 4 right-hand sides): the reported scratch must
/// be under 1% of what one full row-Hessian copy per RHS would occupy, and a
/// request for `BMS_FLEX_ROW_HVP_MAX_RHS + 1` right-hand sides must be
/// rejected.
#[test]
pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
    let (n, r) = (195_000_usize, 20_usize);
    let (p_total, rhs_count) = (44_usize, 4_usize);

    let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
        .expect("large-scale multi-RHS scratch budget");
    // Bytes needed if every RHS materialized its own copy of all row Hessians.
    let per_rhs_full_row_cache =
        (n * r * r * std::mem::size_of::<f64>()) as u64 * rhs_count as u64;
    assert!(
        scratch < per_rhs_full_row_cache / 100,
        "multi-RHS scratch must tile by row chunks instead of materializing \
         a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
    );

    let over_limit =
        bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, BMS_FLEX_ROW_HVP_MAX_RHS + 1);
    assert!(
        over_limit.is_err(),
        "multi-RHS launch must reject unbounded RHS counts"
    );
}
4241
/// Device-versus-oracle parity for the multi-RHS HVP launch on a small
/// fixture (`n = 5`, `r = 4`, `p_total = 6`, 3 right-hand sides).
///
/// The device path (runtime lookup, backend probe, host-to-device uploads,
/// `DeviceResidentRowHess`) is compiled only on Linux, matching the sibling
/// parity tests; other targets print a skip notice. On Linux without a CUDA
/// runtime the test also skips. Otherwise every multi-RHS output entry must
/// match the CPU oracle within `1e-10` and be exactly equal to the
/// single-RHS launch for the same direction.
#[test]
pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp_multi parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
            return;
        };
        let n = 5_usize;
        let r = 4_usize;
        let p_m = 2_usize;
        let p_g = 2_usize;
        let p_h_dim = 1_usize;
        let p_w_dim = 1_usize;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        let rhs_count = 3_usize;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };

        // Symmetric per-row Hessians: fill the upper triangle and mirror it.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            for u in 0..r {
                for v in u..r {
                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
                    row_hessians[row * r * r + u * r + v] = val;
                    row_hessians[row * r * r + v * r + u] = val;
                }
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_m {
                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
            }
            for j in 0..p_g {
                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
            }
        }
        // Right-hand sides stored back to back, `p_total` entries each.
        let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
        for rhs in 0..rhs_count {
            for j in 0..p_total {
                let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
                v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
            }
        }

        // A CUDA runtime is present, so failures below fail the test.
        let backend = HvpKernelBackend::probe()
            .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
        let d_m = stream
            .clone_htod(&marginal)
            .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
        let d_g = stream
            .clone_htod(&logslope)
            .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
            .expect("storage scratch budget");
        assert!(
            scratch < storage.bytes,
            "multi-RHS scratch should stay below resident cache bytes"
        );
        let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
            .expect("multi-RHS HVP kernel must launch on CUDA host");
        assert_eq!(gpu.len(), rhs_count * p_total);
        for rhs in 0..rhs_count {
            let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
            let cpu = cpu_oracle_bms_flex_row_hvp(
                &row_hessians,
                &marginal,
                &logslope,
                &block,
                &primary,
                n,
                v,
            );
            let single = launch_bms_flex_row_hvp(&storage, v)
                .expect("single-RHS HVP kernel must launch on CUDA host");
            for j in 0..p_total {
                let got = gpu[rhs * p_total + j];
                let diff = (cpu[j] - got).abs();
                assert!(
                    diff <= 1e-10,
                    "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
                    cpu[j],
                    got
                );
                // Exact equality: the two device launches must agree to the
                // last bit, not merely within tolerance.
                assert_eq!(
                    got, single[j],
                    "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
                );
            }
        }
    }
}
4360
/// Parity for the device-output HVP launch on a small fixture (`n = 4`,
/// `r = 4`, `p_total = 6`).
///
/// On non-Linux hosts, or when `GpuRuntime::global()` yields `None`, the test
/// only prints a skip notice. Otherwise the result written into a device
/// buffer must match the CPU oracle within `1e-10` and differ from the
/// host-output launch by exactly `0.0` in every entry.
#[test]
pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
             CUDA parity (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let _runtime = match gam_gpu::device_runtime::GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!(
                    "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
                     skipping device parity"
                );
                return;
            }
        };

        let (n, r) = (4_usize, 4_usize);
        let (p_m, p_g) = (2_usize, 2_usize);
        let (p_h_dim, p_w_dim) = (1_usize, 1_usize);
        let h_start = p_m + p_g;
        let w_start = h_start + p_h_dim;
        let p_total = w_start + p_w_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_start..w_start),
            w: Some(w_start..p_total),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };

        // Symmetric per-row Hessians: entry (u, v) depends only on the
        // ordered pair (min, max) of its indices.
        let row_hessians: Vec<f64> = (0..n * r * r)
            .map(|flat| {
                let (row, u, v) = (flat / (r * r), (flat / r) % r, flat % r);
                let (lo, hi) = (u.min(v), u.max(v));
                ((row + 1) as f64) * (1.0 + (lo as f64) + 2.0 * (hi as f64))
            })
            .collect();
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|flat| {
                let (row, j) = (flat / p_m, flat % p_m);
                0.5 + (row as f64) * 0.1 - (j as f64) * 0.2
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|flat| {
                let (row, j) = (flat / p_g, flat % p_g);
                -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15
            })
            .collect();
        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();

        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );

        // A CUDA runtime is present, so failures below fail the test.
        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host");
        let d_m = stream.clone_htod(&marginal).expect(
            "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
        );
        let d_g = stream.clone_htod(&logslope).expect(
            "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
        );
        let resident_elems = n * r * r + n * p_m + n * p_g;
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: (resident_elems * std::mem::size_of::<f64>()) as u64,
        };

        // Reference: the launch that returns its result on the host.
        let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
            .expect("host-out HVP kernel must launch on CUDA host");

        // Path under test: direction and output both live in device buffers.
        let d_v = stream
            .clone_htod(&v)
            .expect("upload direction for device-out HVP");
        let mut d_out = stream
            .alloc_zeros::<f64>(p_total)
            .expect("alloc device-out HVP output");
        launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
            .expect("device-out HVP kernel must launch on CUDA host");
        stream
            .synchronize()
            .expect("synchronize after device-out HVP");
        let device_out_hvp = stream
            .clone_dtoh(&d_out)
            .expect("download device-out HVP output");

        assert_eq!(device_out_hvp.len(), cpu_hvp.len());
        assert_eq!(device_out_hvp.len(), host_out_hvp.len());
        for i in 0..p_total {
            let from_device = device_out_hvp[i];
            let diff = (cpu_hvp[i] - from_device).abs();
            assert!(
                diff <= 1e-10,
                "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
                cpu_hvp[i],
                from_device
            );
            // The two launch flavours must agree exactly, not just within
            // tolerance.
            let host_diff = (host_out_hvp[i] - from_device).abs();
            assert!(
                host_diff == 0.0,
                "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
                host_out_hvp[i],
                from_device
            );
        }
    }
}
4511
/// Device-versus-oracle parity for the HVP and diagonal launches at a larger
/// shape: `n = 64`, `r = 20`, `p_total = 44`.
///
/// On non-Linux hosts, or when `GpuRuntime::global()` yields `None`, the test
/// only prints a skip notice. Backend-probe and upload failures also print a
/// message and return; launch failures panic. Every output entry must match
/// the CPU oracle within `1e-8`.
#[test]
pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
             skipping CUDA parity"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let _runtime = match gam_gpu::device_runtime::GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!(
                    "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
                     skipping device parity"
                );
                return;
            }
        };

        let n = 64_usize;
        let (p_m, p_g) = (14_usize, 12_usize);
        let (p_h_dim, p_w_dim) = (10_usize, 8_usize);
        let r = 2 + p_h_dim + p_w_dim;
        assert_eq!(r, 20);
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        assert_eq!(p_total, 44);
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..2 + p_h_dim),
            w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
            r,
        };

        // Unsymmetrized pseudo-random entry for (row, u, v).
        let raw_entry = |row: usize, u: usize, v: usize| -> f64 {
            let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
            (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5
        };
        // Off-diagonal entries are the average of the raw (upper, lower)
        // pair; diagonal entries are the raw value shifted by `r`.
        let row_hessians: Vec<f64> = (0..n * r * r)
            .map(|flat| {
                let (row, u, v) = (flat / (r * r), (flat / r) % r, flat % r);
                if u == v {
                    raw_entry(row, u, u) + r as f64
                } else {
                    let (lo, hi) = (u.min(v), u.max(v));
                    0.5 * (raw_entry(row, lo, hi) + raw_entry(row, hi, lo))
                }
            })
            .collect();
        let marginal: Vec<f64> = (0..n * p_m)
            .map(|flat| {
                let (row, j) = (flat / p_m, flat % p_m);
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3
            })
            .collect();
        let logslope: Vec<f64> = (0..n * p_g)
            .map(|flat| {
                let (row, j) = (flat / p_g, flat % p_g);
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25
            })
            .collect();
        let v: Vec<f64> = (0..p_total)
            .map(|i| {
                let seed = (i as f64) * 0.157 + 0.6;
                seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
            })
            .collect();

        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );
        let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
        );

        // NOTE(review): probe/upload failures are reported and then skipped
        // here, whereas sibling parity tests `expect` them on a CUDA host.
        // Confirm whether the silent skip is intentional for this shape.
        let Ok(backend) = HvpKernelBackend::probe().map_err(|err| {
            eprintln!(
                "[bms_flex_row hvp parity n64_r20_p44] backend probe \
                 failed: {err}"
            )
        }) else {
            return;
        };
        let stream = backend.stream.clone();
        let Ok(d_h) = stream.clone_htod(&row_hessians).map_err(|err| {
            eprintln!(
                "[bms_flex_row hvp parity n64_r20_p44] upload h \
                 failed: {err}"
            )
        }) else {
            return;
        };
        let Ok(d_m) = stream.clone_htod(&marginal).map_err(|err| {
            eprintln!(
                "[bms_flex_row hvp parity n64_r20_p44] upload marg \
                 failed: {err}"
            )
        }) else {
            return;
        };
        let Ok(d_g) = stream.clone_htod(&logslope).map_err(|err| {
            eprintln!(
                "[bms_flex_row hvp parity n64_r20_p44] upload logslope \
                 failed: {err}"
            )
        }) else {
            return;
        };

        let resident_elems = n * r * r + n * p_m + n * p_g;
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: (resident_elems * std::mem::size_of::<f64>()) as u64,
        };
        let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
            .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
        let gpu_diag = launch_bms_flex_row_diagonal(&storage)
            .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
        assert_eq!(gpu_hvp.len(), cpu_hvp.len());
        assert_eq!(gpu_diag.len(), cpu_diag.len());
        for i in 0..p_total {
            let (hvp_cpu, hvp_gpu) = (cpu_hvp[i], gpu_hvp[i]);
            let diff = (hvp_cpu - hvp_gpu).abs();
            assert!(
                diff <= 1e-8,
                "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
                hvp_cpu,
                hvp_gpu
            );
            let (diag_cpu, diag_gpu) = (cpu_diag[i], gpu_diag[i]);
            let ddiff = (diag_cpu - diag_gpu).abs();
            assert!(
                ddiff <= 1e-8,
                "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
                diag_cpu,
                diag_gpu
            );
        }
    }
}
4705
#[test]
/// Parity test for the dense-block CUDA kernel.
///
/// Builds a small synthetic problem (n = 24 rows, r = 8 primary directions,
/// p = 14 coefficients), assembles the dense `p × p` matrix
/// `Σ_row Φᵀ · H_row · Φ` on the host, and checks that
/// `launch_bms_flex_row_dense_block` reproduces it entry by entry within
/// `1e-9 · max(|cpu|, |gpu|, 1)`.
///
/// On non-Linux hosts, or when no CUDA runtime is available, the test only
/// prints a note. Once a runtime is present, backend probing, the uploads and
/// the kernel launch are all required to succeed.
pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_gpu_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
            return;
        };

        // Problem dimensions.
        let n = 24_usize;
        let (p_m, p_g) = (4_usize, 4_usize);
        let (p_h_dim, p_w_dim) = (3_usize, 3_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;

        // Coefficient layout: marginal columns, logslope columns, then h, then w.
        let h_coef_lo = p_m + p_g;
        let w_coef_lo = h_coef_lo + p_h_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_coef_lo..w_coef_lo),
            w: Some(w_coef_lo..w_coef_lo + p_w_dim),
            p_total,
        };
        // Primary layout: directions 0 and 1 come first, then h, then w.
        let w_dir_lo = 2 + p_h_dim;
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..w_dir_lo),
            w: Some(w_dir_lo..w_dir_lo + p_w_dim),
            r,
        };

        // Per-row r×r Hessians: deterministic trigonometric fill, then
        // symmetrised, then the diagonal is shifted by `r`.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, tile) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            for (idx, slot) in tile.iter_mut().enumerate() {
                let (u, v) = (idx / r, idx % r);
                let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (v as f64) * 0.47;
                *slot = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
            }
            for u in 0..r {
                for v in (u + 1)..r {
                    let sym = 0.5 * (tile[u * r + v] + tile[v * r + u]);
                    tile[u * r + v] = sym;
                    tile[v * r + u] = sym;
                }
                tile[u * r + u] += r as f64;
            }
        }

        // Row-major design matrices (n × p_m and n × p_g).
        let mut marginal = vec![0.0_f64; n * p_m];
        for (row, out_row) in marginal.chunks_exact_mut(p_m).enumerate() {
            for (j, slot) in out_row.iter_mut().enumerate() {
                let seed = (row as f64) * 0.083 + (j as f64) * 0.171 + 0.31;
                *slot = seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for (row, out_row) in logslope.chunks_exact_mut(p_g).enumerate() {
            for (j, slot) in out_row.iter_mut().enumerate() {
                let seed = (row as f64) * 0.097 + (j as f64) * 0.143 - 0.15;
                *slot = seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2;
            }
        }

        // Host reference: H_cpu = Σ_row Φᵀ · H_row · Φ, where Φ is r × p_total.
        let h_block_start = block.h.as_ref().map_or(0, |range| range.start);
        let h_block_len = block.h.as_ref().map_or(0, |range| range.len());
        let w_block_start = block.w.as_ref().map_or(0, |range| range.start);
        let w_block_len = block.w.as_ref().map_or(0, |range| range.len());
        let h_primary_start = primary.h.as_ref().map_or(0, |range| range.start);
        let w_primary_start = primary.w.as_ref().map_or(0, |range| range.start);

        let mut h_cpu = vec![0.0_f64; p_total * p_total];
        // Φ stored flat, row-major; rebuilt from zero for every data row.
        let mut phi = vec![0.0_f64; r * p_total];
        for (row, hrow) in row_hessians.chunks_exact(r * r).enumerate() {
            phi.fill(0.0);
            // Direction 0 carries the marginal design row, direction 1 the
            // logslope design row.
            phi[..p_m].copy_from_slice(&marginal[row * p_m..(row + 1) * p_m]);
            phi[p_total + p_m..p_total + p_m + p_g]
                .copy_from_slice(&logslope[row * p_g..(row + 1) * p_g]);
            // The h and w directions map one-to-one onto their coefficients.
            for k in 0..h_block_len {
                phi[(h_primary_start + k) * p_total + h_block_start + k] = 1.0;
            }
            for k in 0..w_block_len {
                phi[(w_primary_start + k) * p_total + w_block_start + k] = 1.0;
            }
            // Walk (u, v) in row-major order so every output entry receives
            // its contributions in the same sequence as a nested u/v loop.
            for (uv, &huv) in hrow.iter().enumerate() {
                if huv == 0.0 {
                    continue;
                }
                let (u, v) = (uv / r, uv % r);
                let phi_u = &phi[u * p_total..(u + 1) * p_total];
                let phi_v = &phi[v * p_total..(v + 1) * p_total];
                for (m, &pm) in phi_u.iter().enumerate() {
                    if pm == 0.0 {
                        continue;
                    }
                    let scaled = huv * pm;
                    let out_row = &mut h_cpu[m * p_total..(m + 1) * p_total];
                    for (acc, &pv) in out_row.iter_mut().zip(phi_v) {
                        *acc += scaled * pv;
                    }
                }
            }
        }

        // Upload the inputs and run the kernel. A runtime was found above, so
        // any failure from here on fails the test instead of skipping it.
        let backend = HvpKernelBackend::probe().expect(
            "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
        );
        let stream = backend.stream.clone();
        let resident_bytes = ((row_hessians.len() + marginal.len() + logslope.len())
            * std::mem::size_of::<f64>()) as u64;
        let storage = DeviceResidentRowHess {
            hess: stream
                .clone_htod(&row_hessians)
                .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host"),
            marginal_design: stream
                .clone_htod(&marginal)
                .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host"),
            logslope_design: stream.clone_htod(&logslope).expect(
                "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
            ),
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: resident_bytes,
        };
        let h_gpu = launch_bms_flex_row_dense_block(&storage)
            .expect("dense_block kernel must launch on CUDA host");
        assert_eq!(h_gpu.len(), p_total * p_total);

        // Entry-wise comparison; the largest deviation is tracked for the log.
        let mut max_abs = 0.0_f64;
        for (idx, &a) in h_cpu.iter().enumerate() {
            let (i, j) = (idx / p_total, idx % p_total);
            let b = h_gpu[idx];
            let diff = (a - b).abs();
            max_abs = max_abs.max(diff);
            assert!(
                diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
                "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
            );
        }
        eprintln!(
            "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
        );
    }
}
4882
#[test]
/// Performance gate for the HVP kernel at large scale.
///
/// On n = 195 000 rows with r = 20 and p = 44, times
/// `launch_bms_flex_row_hvp` (median of 15 timed launches after 3 warm-ups)
/// against a rayon-parallel CPU baseline that runs
/// `cpu_oracle_bms_flex_row_hvp` over 4096-row chunks (median of 15 runs
/// after one warm-up), and asserts that the GPU median is at least 5× faster.
///
/// The test prints a note and returns without asserting on non-Linux hosts,
/// when no CUDA runtime is available, or when probing the backend or
/// uploading an input fails.
pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        /// Runs `work` once and returns its result with the elapsed wall time in µs.
        fn timed_micros<T>(work: impl FnOnce() -> T) -> (T, u128) {
            let t0 = std::time::Instant::now();
            let out = work();
            let micros = t0.elapsed().as_micros();
            (out, micros)
        }

        /// Sorts the samples and returns the element at index `len / 2`.
        fn median_micros(mut samples: Vec<u128>) -> u128 {
            samples.sort_unstable();
            samples[samples.len() / 2]
        }

        let Some(_gpu_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate"
            );
            return;
        };

        // Problem dimensions.
        let n = 195_000_usize;
        let (p_m, p_g) = (14_usize, 12_usize);
        let (p_h_dim, p_w_dim) = (10_usize, 8_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;

        // Coefficient layout: marginal columns, logslope columns, then h, then w.
        let h_coef_lo = p_m + p_g;
        let w_coef_lo = h_coef_lo + p_h_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_coef_lo..w_coef_lo),
            w: Some(w_coef_lo..w_coef_lo + p_w_dim),
            p_total,
        };
        // Primary layout: directions 0 and 1 come first, then h, then w.
        let w_dir_lo = 2 + p_h_dim;
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..w_dir_lo),
            w: Some(w_dir_lo..w_dir_lo + p_w_dim),
            r,
        };

        // Per-row r×r Hessians: deterministic trigonometric fill, then
        // symmetrised, then the diagonal is shifted by `r`.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, tile) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            for (idx, slot) in tile.iter_mut().enumerate() {
                let (u, vv) = (idx / r, idx % r);
                let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                *slot = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let sym = 0.5 * (tile[u * r + vv] + tile[vv * r + u]);
                    tile[u * r + vv] = sym;
                    tile[vv * r + u] = sym;
                }
                tile[u * r + u] += r as f64;
            }
        }

        // Row-major design matrices (n × p_m and n × p_g).
        let mut marginal = vec![0.0_f64; n * p_m];
        for (row, out_row) in marginal.chunks_exact_mut(p_m).enumerate() {
            for (j, slot) in out_row.iter_mut().enumerate() {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                *slot = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for (row, out_row) in logslope.chunks_exact_mut(p_g).enumerate() {
            for (j, slot) in out_row.iter_mut().enumerate() {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                *slot = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }

        // Direction vector the Hessian is applied to.
        let v: Vec<f64> = (0..p_total)
            .map(|i| (i as f64) * 0.157 + 0.6)
            .map(|seed| seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35)
            .collect();

        // Device setup: any failure here skips the gate rather than failing it.
        let backend = match HvpKernelBackend::probe() {
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
                return;
            }
            Ok(probed) => probed,
        };
        let stream = backend.stream.clone();
        let d_h = match stream.clone_htod(&row_hessians) {
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
                return;
            }
            Ok(buf) => buf,
        };
        let d_m = match stream.clone_htod(&marginal) {
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
                return;
            }
            Ok(buf) => buf,
        };
        let d_g = match stream.clone_htod(&logslope) {
            Err(err) => {
                eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
                return;
            }
            Ok(buf) => buf,
        };
        let resident_bytes = ((row_hessians.len() + marginal.len() + logslope.len())
            * std::mem::size_of::<f64>()) as u64;
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: resident_bytes,
        };

        // GPU timing: untimed warm-up launches first, then the timed samples.
        let warmup: usize = 3;
        let iters: usize = 15;
        for _ in 0..warmup {
            let out =
                launch_bms_flex_row_hvp(&storage, &v).expect("warmup GPU HVP must launch");
            assert_eq!(out.len(), p_total);
        }
        let gpu_samples: Vec<u128> = (0..iters)
            .map(|_| {
                let (out, micros) = timed_micros(|| {
                    launch_bms_flex_row_hvp(&storage, &v).expect("GPU HVP must launch")
                });
                assert_eq!(out.len(), p_total);
                micros
            })
            .collect();
        let gpu_median = median_micros(gpu_samples);

        // CPU baseline: the oracle is run on fixed-size row chunks in parallel
        // and the per-chunk vectors are summed. This closure is what gets
        // timed, so its fold/reduce structure is kept unchanged.
        const CHUNK_ROWS: usize = 4096;
        let cpu_hvp_parallel = || -> Vec<f64> {
            (0..n.div_ceil(CHUNK_ROWS))
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total],
                    |mut acc, chunk_idx| {
                        let lo = chunk_idx * CHUNK_ROWS;
                        let hi = (lo + CHUNK_ROWS).min(n);
                        let partial = cpu_oracle_bms_flex_row_hvp(
                            &row_hessians[lo * r * r..hi * r * r],
                            &marginal[lo * p_m..hi * p_m],
                            &logslope[lo * p_g..hi * p_g],
                            &block,
                            &primary,
                            hi - lo,
                            &v,
                        );
                        for (a, &p) in acc.iter_mut().zip(partial.iter()) {
                            *a += p;
                        }
                        acc
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total],
                    |mut lhs, rhs| {
                        for (l, x) in lhs.iter_mut().zip(rhs.iter()) {
                            *l += *x;
                        }
                        lhs
                    },
                )
        };
        // One untimed run before sampling.
        let warm = cpu_hvp_parallel();
        assert_eq!(warm.len(), p_total);
        let cpu_samples: Vec<u128> = (0..iters)
            .map(|_| {
                let (out, micros) = timed_micros(|| cpu_hvp_parallel());
                assert_eq!(out.len(), p_total);
                micros
            })
            .collect();
        let cpu_median = median_micros(cpu_samples);

        // `max(1)` keeps the ratio finite if the GPU median rounds to 0 µs.
        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 5×)"
        );
        assert!(
            speedup >= 5.0,
            "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the kernel until met or \
             prove the kernel is at hardware roofline."
        );
    }
}
5108
#[test]
/// Performance gate for the dense-block kernel at large scale.
///
/// On n = 195 000 rows with r = 20 and p = 44, times
/// `launch_bms_flex_row_dense_block` (median of 5 timed launches after 2
/// warm-ups) against a rayon-parallel host build of `Σ_row Φᵀ · H_row · Φ`
/// over 2048-row chunks (median of 5 runs after one warm-up), and asserts
/// that the GPU median is at least 10× faster.
///
/// The test prints a note and returns without asserting on non-Linux hosts,
/// when no CUDA runtime is available, when `p_total` exceeds
/// `DENSE_BLOCK_MAX_P`, or when probing the backend or uploading an input
/// fails. The `DENSE_BLOCK_MAX_P` check depends only on constants, so it runs
/// before the fixture (about 624 MB of row Hessians plus the two design
/// matrices) is allocated and filled.
pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate"
        );
    }
    #[cfg(target_os = "linux")]
    {
        use rayon::prelude::*;

        /// Runs `work` once and returns its result with the elapsed wall time in µs.
        fn timed_micros<T>(work: impl FnOnce() -> T) -> (T, u128) {
            let t0 = std::time::Instant::now();
            let out = work();
            let micros = t0.elapsed().as_micros();
            (out, micros)
        }

        /// Sorts the samples and returns the element at index `len / 2`.
        fn median_micros(mut samples: Vec<u128>) -> u128 {
            samples.sort_unstable();
            samples[samples.len() / 2]
        }

        let Some(_gpu_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
            );
            return;
        };

        // Problem dimensions.
        let n = 195_000_usize;
        let (p_m, p_g) = (14_usize, 12_usize);
        let (p_h_dim, p_w_dim) = (10_usize, 8_usize);
        let r = 2 + p_h_dim + p_w_dim;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;

        // Decide whether to skip before building any fixture data.
        if p_total > DENSE_BLOCK_MAX_P {
            eprintln!(
                "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
            );
            return;
        }

        // Coefficient layout: marginal columns, logslope columns, then h, then w.
        let h_coef_lo = p_m + p_g;
        let w_coef_lo = h_coef_lo + p_h_dim;
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(h_coef_lo..w_coef_lo),
            w: Some(w_coef_lo..w_coef_lo + p_w_dim),
            p_total,
        };
        // Primary layout: directions 0 and 1 come first, then h, then w.
        let w_dir_lo = 2 + p_h_dim;
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..w_dir_lo),
            w: Some(w_dir_lo..w_dir_lo + p_w_dim),
            r,
        };

        // Per-row r×r Hessians: deterministic trigonometric fill, then
        // symmetrised, then the diagonal is shifted by `r`.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for (row, tile) in row_hessians.chunks_exact_mut(r * r).enumerate() {
            for (idx, slot) in tile.iter_mut().enumerate() {
                let (u, vv) = (idx / r, idx % r);
                let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
                *slot = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
            }
            for u in 0..r {
                for vv in (u + 1)..r {
                    let sym = 0.5 * (tile[u * r + vv] + tile[vv * r + u]);
                    tile[u * r + vv] = sym;
                    tile[vv * r + u] = sym;
                }
                tile[u * r + u] += r as f64;
            }
        }

        // Row-major design matrices (n × p_m and n × p_g).
        let mut marginal = vec![0.0_f64; n * p_m];
        for (row, out_row) in marginal.chunks_exact_mut(p_m).enumerate() {
            for (j, slot) in out_row.iter_mut().enumerate() {
                let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
                *slot = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for (row, out_row) in logslope.chunks_exact_mut(p_g).enumerate() {
            for (j, slot) in out_row.iter_mut().enumerate() {
                let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
                *slot = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
            }
        }

        // Device setup: any failure here skips the gate rather than failing it.
        let backend = match HvpKernelBackend::probe() {
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
                return;
            }
            Ok(probed) => probed,
        };
        let stream = backend.stream.clone();
        let d_h = match stream.clone_htod(&row_hessians) {
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
                return;
            }
            Ok(buf) => buf,
        };
        let d_m = match stream.clone_htod(&marginal) {
            Err(err) => {
                eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
                return;
            }
            Ok(buf) => buf,
        };
        let d_g = match stream.clone_htod(&logslope) {
            Err(err) => {
                eprintln!(
                    "[bms_flex_row dense_block hill-climb] upload logslope failed: {err}"
                );
                return;
            }
            Ok(buf) => buf,
        };
        let resident_bytes = ((row_hessians.len() + marginal.len() + logslope.len())
            * std::mem::size_of::<f64>()) as u64;
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),
            bytes: resident_bytes,
        };

        // GPU timing: untimed warm-up launches first, then the timed samples.
        let warmup: usize = 2;
        let iters: usize = 5;
        for _ in 0..warmup {
            let out = launch_bms_flex_row_dense_block(&storage)
                .expect("warmup GPU dense_block must launch");
            assert_eq!(out.len(), p_total * p_total);
        }
        let gpu_samples: Vec<u128> = (0..iters)
            .map(|_| {
                let (out, micros) = timed_micros(|| {
                    launch_bms_flex_row_dense_block(&storage)
                        .expect("GPU dense_block must launch")
                });
                assert_eq!(out.len(), p_total * p_total);
                micros
            })
            .collect();
        let gpu_median = median_micros(gpu_samples);

        // CPU baseline: each rayon task accumulates Φᵀ · H_row · Φ for a
        // chunk of rows into its own p×p buffer, and the buffers are summed.
        const CHUNK_ROWS: usize = 2048;
        let h_block_start = block.h.as_ref().map_or(0, |range| range.start);
        let h_block_len = block.h.as_ref().map_or(0, |range| range.len());
        let w_block_start = block.w.as_ref().map_or(0, |range| range.start);
        let w_block_len = block.w.as_ref().map_or(0, |range| range.len());
        let h_primary_start = primary.h.as_ref().map_or(0, |range| range.start);
        let w_primary_start = primary.w.as_ref().map_or(0, |range| range.start);
        // NOTE: this closure is the timed baseline. Its loop nest is left in
        // its original index-based form on purpose, so the measured CPU time
        // (and hence the gate ratio) is not shifted by a restyle.
        let cpu_build_parallel = || -> Vec<f64> {
            let nchunks = n.div_ceil(CHUNK_ROWS);
            (0..nchunks)
                .into_par_iter()
                .fold(
                    || vec![0.0_f64; p_total * p_total],
                    |mut acc, ci| {
                        let lo = ci * CHUNK_ROWS;
                        let hi = (lo + CHUNK_ROWS).min(n);
                        // Φ (r × p_total), reused across the rows of the chunk.
                        let mut phi: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
                        for row in lo..hi {
                            for col in phi.iter_mut() {
                                col.iter_mut().for_each(|v| *v = 0.0);
                            }
                            let mrow = &marginal[row * p_m..(row + 1) * p_m];
                            let grow = &logslope[row * p_g..(row + 1) * p_g];
                            // Direction 0: marginal design row.
                            for k in 0..p_m {
                                phi[0][k] = mrow[k];
                            }
                            // Direction 1: logslope design row.
                            for k in 0..p_g {
                                phi[1][p_m + k] = grow[k];
                            }
                            // h and w directions map one-to-one onto coefficients.
                            for k in 0..h_block_len {
                                phi[h_primary_start + k][h_block_start + k] = 1.0;
                            }
                            for k in 0..w_block_len {
                                phi[w_primary_start + k][w_block_start + k] = 1.0;
                            }
                            let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
                            for u in 0..r {
                                for v_idx in 0..r {
                                    let huv = hrow[u * r + v_idx];
                                    if huv == 0.0 {
                                        continue;
                                    }
                                    for m in 0..p_total {
                                        let pm = phi[u][m];
                                        if pm == 0.0 {
                                            continue;
                                        }
                                        let scaled = huv * pm;
                                        for nn in 0..p_total {
                                            acc[m * p_total + nn] += scaled * phi[v_idx][nn];
                                        }
                                    }
                                }
                            }
                        }
                        acc
                    },
                )
                .reduce(
                    || vec![0.0_f64; p_total * p_total],
                    |mut a, b| {
                        for (ax, bx) in a.iter_mut().zip(b.iter()) {
                            *ax += *bx;
                        }
                        a
                    },
                )
        };
        // One untimed run before sampling.
        let warm_cpu = cpu_build_parallel();
        assert_eq!(warm_cpu.len(), p_total * p_total);
        let cpu_samples: Vec<u128> = (0..iters)
            .map(|_| {
                let (out, micros) = timed_micros(|| cpu_build_parallel());
                assert_eq!(out.len(), p_total * p_total);
                micros
            })
            .collect();
        let cpu_median = median_micros(cpu_samples);

        // `max(1)` keeps the ratio finite if the GPU median rounds to 0 µs.
        let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
        eprintln!(
            "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
             cpu_median={cpu_median}us gpu_median={gpu_median}us \
             speedup={speedup:.2}× (charter target ≥ 10×)"
        );
        assert!(
            speedup >= 10.0,
            "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
             need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
             gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
             (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
             or prove the kernel is at hardware roofline."
        );
    }
}
5355}