1#[cfg(target_os = "linux")]
85use std::sync::OnceLock;
86
87use gam_gpu::gpu_error::GpuError;
88
89#[cfg(target_os = "linux")]
90use std::sync::Arc;
91
92#[cfg(target_os = "linux")]
93use cudarc::driver::{CudaModule, CudaSlice, CudaStream, LaunchConfig, PushKernelArg};
94
/// Largest supported primary-coordinate count `r`.
///
/// `ROW_KERNEL_BODY` hard-codes its shared scratch (`F_u`, `F_au`, `F_uv`,
/// `bar_e_u`, `bar_e_uv`) and thread-0 locals (`a_u`, `a_uv`) for 32
/// coordinates, so this must never exceed 32 unless those buffers are resized.
pub(crate) const MAX_R: usize = 32;

// Compile-time guard tying MAX_R to the 32-wide buffers in the CUDA source.
const _: () = assert!(
    MAX_R <= 32,
    "MAX_R exceeds the 32-wide shared buffers hard-coded in ROW_KERNEL_BODY"
);

/// Threads per block for `bms_flex_row_kernel` (the kernel runs one block per row).
///
/// Each thread deposits one partial sum in `reduce_a[32]` / `reduce_b[32]`,
/// which are folded with a halving tree reduction. This value must therefore
/// be a power of two no larger than 32.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_THREADS: u32 = 32;

// Compile-time guard for the block-size requirements described above.
#[cfg(target_os = "linux")]
const _: () = assert!(
    ROW_KERNEL_THREADS.is_power_of_two() && ROW_KERNEL_THREADS <= 32,
    "ROW_KERNEL_THREADS must be a power of two <= 32 (see reduce_a/reduce_b in ROW_KERNEL_BODY)"
);

/// Length of each per-cell coefficient vector: 4 coefficients, matching the
/// kernel's `C0..C3` cubic and its `[n_cells, 4]` coefficient tables.
pub(crate) const COEFF4: usize = 4;

/// Number of `f64` moments stored per cell (`m_0..m_9`) in the cell-moment table.
pub(crate) const MOMENT_STRIDE: usize = 10;
114pub(crate) enum CellMomentsSource<'a> {
119 Host(&'a [f64]),
122 #[cfg(target_os = "linux")]
127 Device(&'a CudaSlice<f64>),
128}
129
130impl<'a> CellMomentsSource<'a> {
131 pub(crate) fn len(&self) -> usize {
133 match self {
134 CellMomentsSource::Host(slice) => slice.len(),
135 #[cfg(target_os = "linux")]
136 CellMomentsSource::Device(d) => d.len(),
137 }
138 }
139}
140
/// Generates the borrowed (`BmsFlexRowKernelInputs`) and owned
/// (`BmsFlexRowKernelInputsOwned`) row-kernel input structs from one field
/// list, so the two definitions cannot drift apart.
///
/// * `f64_fields` become `&'a [f64]` (borrowed) / `Vec<f64>` (owned).
/// * `u32_fields` become `&'a [u32]` / `Vec<u32>`.
/// * `moments_field` becomes a `CellMomentsSource` (borrowed) / `Vec<f64>`
///   (owned).
///
/// On Linux the owned struct also carries an optional device-resident copy in
/// `cell_moments_device`, which `as_borrowed` prefers over the host vector.
macro_rules! define_bms_flex_row_kernel_input_types {
    (
        f64_fields: [$($f64_field:ident),+ $(,)?],
        u32_fields: [$($u32_field:ident),+ $(,)?],
        moments_field: $moments_field:ident $(,)?
    ) => {
        /// Borrowed view of one batch of row-kernel inputs.
        ///
        /// Call `validate` before launching; it checks every length invariant
        /// the kernel relies on.
        pub(crate) struct BmsFlexRowKernelInputs<'a> {
            /// Number of rows. The kernel launches one block per row.
            pub n_rows: usize,
            /// Primary-coordinate count; must satisfy `1 <= r <= MAX_R` and
            /// `r == 2 + p_h + p_w`.
            pub r: usize,
            /// Length of the `h` coordinate block (coordinates `2..2 + p_h`).
            pub p_h: usize,
            /// Length of the `w` coordinate block (coordinates `2 + p_h..r`).
            pub p_w: usize,
            /// Scale factor. It must be positive and finite, but the device
            /// does not read it; the kernel notes it is already baked into the
            /// cubic coefficients.
            pub s_f: f64,
            $(pub $f64_field: &'a [f64],)+
            $(pub $u32_field: &'a [u32],)+
            /// Per-cell moment table, host- or device-resident.
            pub $moments_field: CellMomentsSource<'a>,
        }

        /// Owned counterpart of `BmsFlexRowKernelInputs`. Use `as_borrowed`
        /// to obtain the view the launcher consumes.
        pub(crate) struct BmsFlexRowKernelInputsOwned {
            pub n_rows: usize,
            pub r: usize,
            pub p_h: usize,
            pub p_w: usize,
            pub s_f: f64,
            $(pub $f64_field: Vec<f64>,)+
            $(pub $u32_field: Vec<u32>,)+
            /// Host copy of the per-cell moment table.
            pub $moments_field: Vec<f64>,
            /// Optional device-resident moment table. When present,
            /// `as_borrowed` hands it to the kernel instead of the host vector,
            /// which avoids a per-launch upload.
            #[cfg(target_os = "linux")]
            pub cell_moments_device: Option<CudaSlice<f64>>,
        }

        impl BmsFlexRowKernelInputsOwned {
            /// Borrows every field into a `BmsFlexRowKernelInputs`.
            ///
            /// The moment source is `Device` when a device copy is attached and
            /// `Host` otherwise.
            pub(crate) fn as_borrowed(&self) -> BmsFlexRowKernelInputs<'_> {
                // Read the moments through the macro's field name rather than a
                // hard-coded identifier, so any `moments_field` name works.
                #[cfg(target_os = "linux")]
                let moments = match self.cell_moments_device.as_ref() {
                    Some(d) => CellMomentsSource::Device(d),
                    None => CellMomentsSource::Host(&self.$moments_field),
                };
                #[cfg(not(target_os = "linux"))]
                let moments = CellMomentsSource::Host(&self.$moments_field);
                BmsFlexRowKernelInputs {
                    n_rows: self.n_rows,
                    r: self.r,
                    p_h: self.p_h,
                    p_w: self.p_w,
                    s_f: self.s_f,
                    $($f64_field: &self.$f64_field,)+
                    $($u32_field: &self.$u32_field,)+
                    $moments_field: moments,
                }
            }
        }
    };
}
218
// Instantiates the row-kernel input structs. The length of each field group
// below is the one enforced by `BmsFlexRowKernelInputs::validate`, where
// `total_cells = cell_offsets[n_rows]`.
define_bms_flex_row_kernel_input_types! {
    f64_fields: [
        // Per-row scalars, `n_rows` entries each.
        q,
        b,
        mu_1,
        mu_2,
        z_obs,
        y,
        w,
        e_obs,
        // Per-cell cubic coefficients C0..C3, `total_cells` entries each.
        cell_c0,
        cell_c1,
        cell_c2,
        cell_c3,
        // Per-cell coefficient vectors, `total_cells * COEFF4` each.
        cell_a,
        cell_aa,
        // Per-cell, per-coordinate (u = 1..r) vectors,
        // `total_cells * (r - 1) * COEFF4` each.
        cell_r,
        cell_ar,
        // `total_cells * COEFF4`.
        cell_sbb,
        // `total_cells * p_h * COEFF4` and `total_cells * p_w * COEFF4`.
        cell_sbh,
        cell_sbw,
        // Per-row observed-point scalars, `n_rows` entries each.
        chi_obs,
        xi_obs,
        // Per-row r-vectors (`n_rows * r`) and r x r blocks (`n_rows * r * r`).
        rho_u,
        tau_u,
        r_uv,
    ],
    // Row i owns cells `cell_offsets[i]..cell_offsets[i + 1]`;
    // `n_rows + 1` entries, non-decreasing.
    u32_fields: [cell_offsets],
    // `total_cells * MOMENT_STRIDE` moments, host- or device-resident.
    moments_field: cell_moments,
}
249
/// Per-row results downloaded from `bms_flex_row_kernel`.
///
/// Rows whose inputs are degenerate (non-finite or non-positive `F_a`) are
/// NaN-filled by the kernel, so the host can detect them and fall back to the
/// CPU path for those rows.
#[derive(Debug)]
pub(crate) struct BmsFlexRowKernelOutputs {
    /// Per-row `-w * log_cdf` (weighted negative log-likelihood); length `n_rows`.
    pub neglog: Vec<f64>,
    /// Row-major gradients with respect to the `r` primary coordinates;
    /// length `n_rows * r`.
    pub grad: Vec<f64>,
    /// Row-major `r x r` Hessians with both triangles written (symmetric);
    /// length `n_rows * r * r`.
    pub hess: Vec<f64>,
}
261
262impl<'a> BmsFlexRowKernelInputs<'a> {
263 pub(crate) fn validate(&self) -> Result<(), GpuError> {
266 if self.r == 0 {
267 return Err(GpuError::DriverCallFailed {
268 reason: "bms_flex_row inputs: r must be > 0".to_string(),
269 });
270 }
271 if self.r > MAX_R {
272 return Err(GpuError::DriverCallFailed {
273 reason: format!("bms_flex_row inputs: r={} exceeds MAX_R={MAX_R}", self.r),
274 });
275 }
276 if self.r != 2 + self.p_h + self.p_w {
277 return Err(GpuError::DriverCallFailed {
278 reason: format!(
279 "bms_flex_row inputs: r={} must equal 2 + p_h({}) + p_w({}) = {}",
280 self.r,
281 self.p_h,
282 self.p_w,
283 2 + self.p_h + self.p_w
284 ),
285 });
286 }
287 let n = self.n_rows;
288 let check_len = |name: &str, have: usize, want: usize| -> Result<(), GpuError> {
289 if have != want {
290 return Err(GpuError::DriverCallFailed {
291 reason: format!("bms_flex_row inputs: {name}.len()={have} != {want}"),
292 });
293 }
294 Ok(())
295 };
296 check_len("q", self.q.len(), n)?;
297 check_len("b", self.b.len(), n)?;
298 check_len("mu_1", self.mu_1.len(), n)?;
299 check_len("mu_2", self.mu_2.len(), n)?;
300 check_len("z_obs", self.z_obs.len(), n)?;
301 check_len("y", self.y.len(), n)?;
302 check_len("w", self.w.len(), n)?;
303 check_len("e_obs", self.e_obs.len(), n)?;
304 check_len("chi_obs", self.chi_obs.len(), n)?;
305 check_len("xi_obs", self.xi_obs.len(), n)?;
306 check_len("rho_u", self.rho_u.len(), n * self.r)?;
307 check_len("tau_u", self.tau_u.len(), n * self.r)?;
308 check_len("r_uv", self.r_uv.len(), n * self.r * self.r)?;
309 check_len("cell_offsets", self.cell_offsets.len(), n + 1)?;
310 let total_cells_u32 = self.cell_offsets[n];
311 let total_cells = total_cells_u32 as usize;
312 check_len("cell_c0", self.cell_c0.len(), total_cells)?;
313 check_len("cell_c1", self.cell_c1.len(), total_cells)?;
314 check_len("cell_c2", self.cell_c2.len(), total_cells)?;
315 check_len("cell_c3", self.cell_c3.len(), total_cells)?;
316 check_len("cell_a", self.cell_a.len(), total_cells * COEFF4)?;
317 check_len("cell_aa", self.cell_aa.len(), total_cells * COEFF4)?;
318 check_len(
319 "cell_r",
320 self.cell_r.len(),
321 total_cells * self.r.saturating_sub(1) * COEFF4,
322 )?;
323 check_len(
324 "cell_ar",
325 self.cell_ar.len(),
326 total_cells * self.r.saturating_sub(1) * COEFF4,
327 )?;
328 check_len("cell_sbb", self.cell_sbb.len(), total_cells * COEFF4)?;
329 check_len(
330 "cell_sbh",
331 self.cell_sbh.len(),
332 total_cells * self.p_h * COEFF4,
333 )?;
334 check_len(
335 "cell_sbw",
336 self.cell_sbw.len(),
337 total_cells * self.p_w * COEFF4,
338 )?;
339 check_len(
340 "cell_moments",
341 self.cell_moments.len(),
342 total_cells * MOMENT_STRIDE,
343 )?;
344 for i in 0..n {
350 if self.cell_offsets[i] > self.cell_offsets[i + 1] {
351 return Err(GpuError::DriverCallFailed {
352 reason: format!(
353 "bms_flex_row inputs: cell_offsets must be monotone (offset[{}]={} > offset[{}]={})",
354 i,
355 self.cell_offsets[i],
356 i + 1,
357 self.cell_offsets[i + 1]
358 ),
359 });
360 }
361 }
362 Ok(())
363 }
364}
365
/// CUDA source for the per-row flexible BMS kernel, `bms_flex_row_kernel`.
///
/// `RowKernelBackend::probe` compiles this string with NVRTC after prepending
/// `gam_gpu::numerics_device::PROBIT_NUMERICS_CU`. That prelude must define
/// `log_ndtr_and_mills`, which is called here but not defined in this string.
///
/// Launch contract (as used by `launch_linux`):
/// * one block per row (`gridDim.x = n_rows`) and
///   `blockDim.x = ROW_KERNEL_THREADS`;
/// * the shared scratch is sized for `r <= MAX_R = 32`, and the
///   `reduce_a`/`reduce_b` halving reduction is sized for 32 threads, so
///   `blockDim.x` must be a power of two no larger than 32;
/// * arguments are passed positionally, so the parameter order below must
///   match the `launch_builder` argument order exactly.
///
/// Output per row:
/// * `out_neglog[row] = -w * log_cdf`, where `log_cdf` comes from
///   `log_ndtr_and_mills(s * e_obs)`;
/// * the gradient and the symmetric Hessian with respect to the `r` primary
///   coordinates.
///
/// Rows with non-finite or non-positive `F_a` are NaN-filled so the host can
/// fall back to the CPU path.
///
/// Per-cell contributions to `F_u`, `F_au` and `F_uv` are accumulated with
/// shared-memory CAS atomics from several threads. The floating-point
/// summation order, and therefore the low bits of the result, may vary
/// between runs.
#[cfg(target_os = "linux")]
pub(crate) const ROW_KERNEL_BODY: &str = r#"
// One block per row. blockDim.x = 32; threadIdx.x parallises per-cell sums.
// CPU parity reference: src/families/bernoulli_marginal_slope.rs
// ::compute_row_analytic_flex_from_parts_into.

#define INV_TWO_PI 0.15915494309189535

extern "C" __device__ __forceinline__ double atomic_add_f64(double *addr, double value) {
    unsigned long long int *addr_as_ull = (unsigned long long int *)addr;
    unsigned long long int old = *addr_as_ull;
    unsigned long long int assumed;
    do {
        assumed = old;
        double next = __longlong_as_double((long long int)assumed) + value;
        old = atomicCAS(addr_as_ull, assumed, (unsigned long long int)__double_as_longlong(next));
    } while (assumed != old);
    return __longlong_as_double((long long int)old);
}

// `nan_fill_outputs`: thread-0-only path used when row inputs are degenerate
// (`F_a` non-finite or non-positive). Writes NaNs to neglog/grad/hess so the
// host falls back to CPU for that row.
extern "C" __device__ __forceinline__ void
nan_fill_outputs(int r,
                 int row,
                 double *out_neglog,
                 double *out_grad,
                 double *out_hess) {
    double nan_value = __longlong_as_double(0x7ff8000000000000ULL);
    out_neglog[row] = nan_value;
    for (int u = 0; u < r; ++u) {
        out_grad[row * r + u] = nan_value;
    }
    int rr = r * r;
    for (int idx = 0; idx < rr; ++idx) {
        out_hess[row * rr + idx] = nan_value;
    }
}

extern "C" __global__ void bms_flex_row_kernel(
    int n_rows,
    int r,
    int p_h,
    int p_w,
    double s_f, // currently unused on device:
                // host has already baked S_f
                // into the cubic coefficients.
                // Kept for diagnostic parity.
    const double * __restrict__ row_q,
    const double * __restrict__ row_b,
    const double * __restrict__ row_mu1,
    const double * __restrict__ row_mu2,
    const double * __restrict__ row_zobs,
    const double * __restrict__ row_y,
    const double * __restrict__ row_w,
    const unsigned int * __restrict__ cell_offsets,
    const double * __restrict__ cell_c0,
    const double * __restrict__ cell_c1,
    const double * __restrict__ cell_c2,
    const double * __restrict__ cell_c3,
    const double * __restrict__ cell_a, // [n_cells, 4]
    const double * __restrict__ cell_aa, // [n_cells, 4]
    const double * __restrict__ cell_r, // [n_cells, r-1, 4]
    const double * __restrict__ cell_ar, // [n_cells, r-1, 4]
    const double * __restrict__ cell_sbb, // [n_cells, 4]
    const double * __restrict__ cell_sbh, // [n_cells, p_h, 4]
    const double * __restrict__ cell_sbw, // [n_cells, p_w, 4]
    const double * __restrict__ cell_moments, // [n_cells, 10]
    const double * __restrict__ row_chi,
    const double * __restrict__ row_xi,
    const double * __restrict__ row_rho, // [n_rows, r]
    const double * __restrict__ row_tau, // [n_rows, r]
    const double * __restrict__ row_ruv, // [n_rows, r*r]
    const double * __restrict__ row_e_obs, // [n_rows] observed predictor VALUE
    double * __restrict__ out_neglog,
    double * __restrict__ out_grad,
    double * __restrict__ out_hess)
{
    int row = blockIdx.x;
    if (row >= n_rows) return;
    int tid = threadIdx.x;

    // ── shared scratch (sized to MAX_R = 32) ──────────────────────────────
    // Layout (doubles):
    // F_u [r]
    // F_au [r]
    // F_uv [r*r]
    // bar_e_u [r]
    // bar_e_uv [r*r]
    // reduce_a [blockDim.x]
    // reduce_b [blockDim.x]
    // Sized for the worst case (r = MAX_R = 32).
    __shared__ double F_u[32];
    __shared__ double F_au[32];
    __shared__ double F_uv[32 * 32];
    __shared__ double bar_e_u[32];
    __shared__ double bar_e_uv[32 * 32];
    __shared__ double reduce_a[32];
    __shared__ double reduce_b[32];
    __shared__ double F_a_shared;
    __shared__ double F_aa_shared;

    // Zero scratch.
    if (tid == 0) { F_a_shared = 0.0; F_aa_shared = 0.0; }
    for (int u = tid; u < r; u += blockDim.x) {
        F_u[u] = 0.0;
        F_au[u] = 0.0;
    }
    for (int uv = tid; uv < r * r; uv += blockDim.x) {
        F_uv[uv] = 0.0;
    }
    __syncthreads();

    // ── per-cell sweep ───────────────────────────────────────────────────
    unsigned int cell_lo = cell_offsets[row];
    unsigned int cell_hi = cell_offsets[row + 1];
    int n_cells = (int)(cell_hi - cell_lo);

    double local_Fa = 0.0;
    double local_Faa = 0.0;

    for (int local_c = tid; local_c < n_cells; local_c += blockDim.x) {
        unsigned int c = cell_lo + (unsigned int)local_c;

        // Load cubic predictor coeffs C0..C3.
        double C[4];
        C[0] = cell_c0[c]; C[1] = cell_c1[c];
        C[2] = cell_c2[c]; C[3] = cell_c3[c];

        // Load m_0..m_9.
        const double *m = cell_moments + (size_t)c * 10;

        // T_n = κ · Σ_e C_e · m_{e+n}, n = 0..6.
        // CPU parity: equivalent to the `eta_rs ⊗ moments` contraction in
        // `cell_second_derivative_from_moments` after folding the
        // cubic predictor.
        double T[7];
        #pragma unroll
        for (int n = 0; n < 7; ++n) {
            double acc = 0.0;
            #pragma unroll
            for (int e = 0; e < 4; ++e) {
                acc = fma(C[e], m[e + n], acc);
            }
            T[n] = acc * INV_TWO_PI;
        }

        // D(R) = κ · Σ_k R_k · m_k.
        // CPU parity: `cell_first_derivative_from_moments`.
        #define D_OF(R) (INV_TWO_PI * (R[0]*m[0] + R[1]*m[1] + R[2]*m[2] + R[3]*m[3]))

        // Q(R, S) = Σ_{p,q} R_p · S_q · T_{p+q}.
        // CPU parity: the `eta_rs` folded dot in
        // `cell_second_derivative_from_moments`.
        #define Q_OF(R, S) \
            ((R[0]*S[0])*T[0] + (R[0]*S[1] + R[1]*S[0])*T[1] \
            + (R[0]*S[2] + R[1]*S[1] + R[2]*S[0])*T[2] \
            + (R[0]*S[3] + R[1]*S[2] + R[2]*S[1] + R[3]*S[0])*T[3] \
            + (R[1]*S[3] + R[2]*S[2] + R[3]*S[1])*T[4] \
            + (R[2]*S[3] + R[3]*S[2])*T[5] \
            + (R[3]*S[3])*T[6])

        // F_a += D(A_c) ; F_aa += H(A_c, A_c, AA_c) = D(AA_c) − Q(A_c, A_c).
        const double *A_c = cell_a + (size_t)c * 4;
        const double *AA_c = cell_aa + (size_t)c * 4;
        local_Fa += D_OF(A_c);
        local_Faa += D_OF(AA_c) - Q_OF(A_c, A_c);

        // For each u > 0: F_u += D(R_{c,u}) ; F_au += H(A_c, R_{c,u}, AR_{c,u})
        // = D(AR_{c,u}) − Q(A_c, R_{c,u}).
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            const double *AR_u = cell_ar + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            double d_R = D_OF(R_u);
            double d_AR = D_OF(AR_u);
            double q_AR = Q_OF(A_c, R_u);
            atomic_add_f64(&F_u[u], d_R);
            atomic_add_f64(&F_au[u], d_AR - q_AR);
        }

        // F_uv: only b·b, b·h_j, b·w_ℓ have a material `S_{c,uv}`; every other
        // (u, v) pair just contributes −Q(R_u, R_v).
        // CPU parity: `SparsePrimaryCoeffJetView::pair_from_b_family` with
        // `COEFF_SUPPORT_BHW` — every cross pair outside the b-row is zero.
        for (int u = 1; u < r; ++u) {
            const double *R_u = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(u - 1)) * 4;
            for (int v = u; v < r; ++v) {
                const double *R_v = cell_r + ((size_t)c * (size_t)(r - 1) + (size_t)(v - 1)) * 4;
                double q_uv = Q_OF(R_u, R_v);
                double d_s = 0.0;
                // S_{bb}: u == v == 1 (b coordinate).
                if (u == 1 && v == 1) {
                    const double *S_bb = cell_sbb + (size_t)c * 4;
                    d_s = D_OF(S_bb);
                }
                // S_{b·h_j}: u == 1, v in score-warp block, or symmetric.
                else if (u == 1 && v >= 2 && v < 2 + p_h) {
                    int j = v - 2;
                    const double *S_bh = cell_sbh + ((size_t)c * (size_t)p_h + (size_t)j) * 4;
                    d_s = D_OF(S_bh);
                }
                // S_{b·w_ℓ}: u == 1, v in link-wiggle block, or symmetric.
                else if (u == 1 && v >= 2 + p_h && v < r) {
                    int l = v - (2 + p_h);
                    const double *S_bw = cell_sbw + ((size_t)c * (size_t)p_w + (size_t)l) * 4;
                    d_s = D_OF(S_bw);
                }
                // Symmetric mirror: u in (h or w) block, v == 1 cannot happen
                // because we iterate v >= u; skip.
                double val = d_s - q_uv;
                atomic_add_f64(&F_uv[u * r + v], val);
            }
        }

        #undef D_OF
        #undef Q_OF
    }

    // Block reduction of local_Fa, local_Faa into shared.
    reduce_a[tid] = local_Fa;
    reduce_b[tid] = local_Faa;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            reduce_a[tid] += reduce_a[tid + stride];
            reduce_b[tid] += reduce_b[tid + stride];
        }
        __syncthreads();
    }
    if (tid == 0) {
        F_a_shared = reduce_a[0];
        F_aa_shared = reduce_b[0];
    }
    __syncthreads();

    // ── thread-0 finalisation: IFT + observed-point + Mills + writes ──────
    if (tid != 0) return;

    double F_a = F_a_shared;
    double F_aa = F_aa_shared;
    double mu_1 = row_mu1[row];
    double mu_2 = row_mu2[row];

    // q-row overrides.
    // F_q = -mu_1 ; F_qq = -mu_2 ; F_qv = 0 (v > 0) ; F_aq = 0.
    F_u[0] = -mu_1;
    F_au[0] = 0.0;
    // Zero the q-cross row/column of F_uv (u == 0 or v == 0), then plant -mu_2 at (0,0).
    for (int v = 0; v < r; ++v) {
        F_uv[0 * r + v] = 0.0;
        F_uv[v * r + 0] = 0.0;
    }
    F_uv[0 * r + 0] = -mu_2;

    // Guard: degenerate F_a ⇒ NaN-fill this row's outputs.
    if (!isfinite(F_a) || F_a <= 0.0) {
        nan_fill_outputs(r, row, out_neglog, out_grad, out_hess);
        return;
    }
    double inv_Fa = 1.0 / F_a;

    // IFT, first order.
    // a_u = -F_u · inv_Fa (q-override: a_q = mu_1 · inv_Fa).
    double a_u[32];
    a_u[0] = mu_1 * inv_Fa;
    for (int u = 1; u < r; ++u) {
        a_u[u] = -F_u[u] * inv_Fa;
    }

    // IFT, second order.
    // a_uv = -(F_uv + F_au · a_v + F_av · a_u + F_aa · a_u · a_v) · inv_Fa.
    // The q-row contributions (u==0 or v==0) collapse to a_uv = mu_2 · inv_Fa
    // when both are 0 and to (F_au_v) · inv_Fa-style mixed shape otherwise.
    // We compute it uniformly using the populated F_uv / F_au with the
    // q-overrides above.
    double a_uv[32 * 32];
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double term = F_uv[u * r + v]
                + F_au[v] * a_u[u]
                + F_au[u] * a_u[v]
                + F_aa * a_u[u] * a_u[v];
            double val = -term * inv_Fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }

    // Observed predictor jets at z_obs.
    // bar_e_u = chi · a_u + rho_u.
    // bar_e_uv = chi · a_uv + xi · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv.
    double chi = row_chi[row];
    double xi = row_xi[row];
    const double *rho = row_rho + (size_t)row * r;
    const double *tau = row_tau + (size_t)row * r;
    const double *ruv = row_ruv + (size_t)row * r * r;

    for (int u = 0; u < r; ++u) {
        bar_e_u[u] = chi * a_u[u] + rho[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = chi * a_uv[u * r + v]
                + xi * a_u[u] * a_u[v]
                + tau[u] * a_u[v]
                + a_u[u] * tau[v]
                + ruv[u * r + v];
            bar_e_uv[u * r + v] = val;
            if (u != v) {
                bar_e_uv[v * r + u] = val;
            }
        }
    }

    // Probit Mills.
    double y = row_y[row];
    double w = row_w[row];
    double s = 2.0 * y - 1.0;
    // The "observed predictor" e_obs is the VALUE (degree-0 term) of the
    // observed jet η(a(θ), θ; z_obs) — NOT `bar_e_u[0]`, which is the u=0
    // FIRST-derivative jet (`chi·a_0 + rho_0 = dη_obs/dq`). The host packs
    // the observed value directly in `row_e_obs[row]` (see
    // `pack_bms_flex_row_kernel_inputs`, `eta_val = eval_coeff4_at(obs.coeff,
    // z_obs)`), matching the CPU family `compute_row_analytic_flex_from_parts_into`
    // which forms `signed_margin = s_y · eta_val`. #415 parity lock.
    double e_obs = row_e_obs[row];
    double m_arg = s * e_obs;
    double log_cdf, lambda;
    log_ndtr_and_mills(m_arg, &log_cdf, &lambda);
    double A_i = -w * s * lambda;
    double B_i = w * lambda * (m_arg + lambda);

    out_neglog[row] = -w * log_cdf;
    for (int u = 0; u < r; ++u) {
        out_grad[row * r + u] = A_i * bar_e_u[u];
    }
    for (int u = 0; u < r; ++u) {
        for (int v = u; v < r; ++v) {
            double val = B_i * bar_e_u[u] * bar_e_u[v] + A_i * bar_e_uv[u * r + v];
            out_hess[row * r * r + u * r + v] = val;
            if (u != v) {
                out_hess[row * r * r + v * r + u] = val;
            }
        }
    }
}
"#;
726
727#[inline]
734pub(crate) fn s_f_diagnostic_finite(inputs: &BmsFlexRowKernelInputs<'_>) -> bool {
735 inputs.s_f.is_finite() && inputs.s_f > 0.0
736}
737
738#[cfg(target_os = "linux")]
739pub(crate) struct RowKernelBackend {
740 pub(crate) stream: Arc<CudaStream>,
741 pub(crate) module: Arc<CudaModule>,
742}
743
744#[cfg(target_os = "linux")]
745impl RowKernelBackend {
746 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
747 static BACKEND: OnceLock<Result<RowKernelBackend, GpuError>> = OnceLock::new();
748 BACKEND
749 .get_or_init(|| {
750 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row", |parts| {
751 let row_kernel_source = [
752 gam_gpu::numerics_device::PROBIT_NUMERICS_CU,
753 ROW_KERNEL_BODY,
754 ]
755 .concat();
756 let ptx = gam_gpu::device_cache::compile_ptx_arch(&row_kernel_source)
767 .map_err(|err| GpuError::DriverCallFailed {
768 reason: format!("bms_flex_row NVRTC compile failed: {err}"),
769 })?;
770 let module =
771 parts
772 .ctx
773 .load_module(ptx)
774 .map_err(|err| GpuError::DriverCallFailed {
775 reason: format!("bms_flex_row module load failed: {err}"),
776 })?;
777 Ok(RowKernelBackend {
778 stream: parts.stream.clone(),
779 module,
780 })
781 })
782 })
783 .as_ref()
784 .map_err(GpuError::clone)
785 }
786}
787
788pub(crate) fn launch_bms_flex_row_kernel(
793 inputs: BmsFlexRowKernelInputs<'_>,
794) -> Result<BmsFlexRowKernelOutputs, GpuError> {
795 inputs.validate()?;
796 if !s_f_diagnostic_finite(&inputs) {
797 return Err(GpuError::DriverCallFailed {
798 reason: format!(
799 "bms_flex_row inputs: s_f must be positive and finite, got {}",
800 inputs.s_f
801 ),
802 });
803 }
804
805 #[cfg(target_os = "linux")]
806 {
807 launch_linux(inputs)
808 }
809 #[cfg(not(target_os = "linux"))]
810 {
811 Err(GpuError::DriverLibraryUnavailable {
812 reason: "bms_flex_row GPU kernel is Linux-only".to_string(),
813 })
814 }
815}
816
817#[cfg(target_os = "linux")]
818pub(crate) fn launch_linux(
819 inputs: BmsFlexRowKernelInputs<'_>,
820) -> Result<BmsFlexRowKernelOutputs, GpuError> {
821 let backend = RowKernelBackend::probe()?;
822 let stream = &backend.stream;
823
824 let upload_f64 = |slice: &[f64], label: &str| {
825 stream
826 .clone_htod(slice)
827 .map_err(|err| GpuError::DriverCallFailed {
828 reason: format!("bms_flex_row upload {label}: {err}"),
829 })
830 };
831 let upload_u32 = |slice: &[u32], label: &str| {
832 stream
833 .clone_htod(slice)
834 .map_err(|err| GpuError::DriverCallFailed {
835 reason: format!("bms_flex_row upload {label}: {err}"),
836 })
837 };
838
839 let d_q = upload_f64(inputs.q, "q")?;
840 let d_b = upload_f64(inputs.b, "b")?;
841 let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
842 let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
843 let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
844 let d_y = upload_f64(inputs.y, "y")?;
845 let d_w = upload_f64(inputs.w, "w")?;
846 let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
847 let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
848 let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
849 let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
850 let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
851 let d_a = upload_f64(inputs.cell_a, "cell_a")?;
852 let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
853 let d_r = upload_f64(inputs.cell_r, "cell_r")?;
854 let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
855 let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
856 let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
857 let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
858 let owned_host_moments: CudaSlice<f64>;
862 let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
863 CellMomentsSource::Host(slice) => {
864 owned_host_moments = upload_f64(slice, "cell_moments")?;
865 &owned_host_moments
866 }
867 CellMomentsSource::Device(d) => *d,
868 };
869 let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
870 let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
871 let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
872 let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
873 let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
874 let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;
875
876 let n = inputs.n_rows;
877 let r = inputs.r;
878 let mut d_neglog = stream
879 .alloc_zeros::<f64>(n)
880 .map_err(|err| GpuError::DriverCallFailed {
881 reason: format!("bms_flex_row alloc neglog: {err}"),
882 })?;
883 let mut d_grad =
884 stream
885 .alloc_zeros::<f64>(n * r)
886 .map_err(|err| GpuError::DriverCallFailed {
887 reason: format!("bms_flex_row alloc grad: {err}"),
888 })?;
889 let mut d_hess =
890 stream
891 .alloc_zeros::<f64>(n * r * r)
892 .map_err(|err| GpuError::DriverCallFailed {
893 reason: format!("bms_flex_row alloc hess: {err}"),
894 })?;
895
896 let func = backend
897 .module
898 .load_function("bms_flex_row_kernel")
899 .map_err(|err| GpuError::DriverCallFailed {
900 reason: format!("bms_flex_row load_function: {err}"),
901 })?;
902
903 let cfg = LaunchConfig {
904 grid_dim: (n as u32, 1, 1),
905 block_dim: (ROW_KERNEL_THREADS, 1, 1),
906 shared_mem_bytes: 0,
907 };
908 let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
909 reason: format!("bms_flex_row: n_rows={n} exceeds i32 range"),
910 })?;
911 let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
912 reason: format!("bms_flex_row: r={r} exceeds i32 range"),
913 })?;
914 let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
915 reason: format!("bms_flex_row: p_h={} exceeds i32 range", inputs.p_h),
916 })?;
917 let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
918 reason: format!("bms_flex_row: p_w={} exceeds i32 range", inputs.p_w),
919 })?;
920 let s_f = inputs.s_f;
921
922 let mut builder = stream.launch_builder(&func);
923 builder
924 .arg(&n_i32)
925 .arg(&r_i32)
926 .arg(&p_h_i32)
927 .arg(&p_w_i32)
928 .arg(&s_f)
929 .arg(&d_q)
930 .arg(&d_b)
931 .arg(&d_mu1)
932 .arg(&d_mu2)
933 .arg(&d_zobs)
934 .arg(&d_y)
935 .arg(&d_w)
936 .arg(&d_offsets)
937 .arg(&d_c0)
938 .arg(&d_c1)
939 .arg(&d_c2)
940 .arg(&d_c3)
941 .arg(&d_a)
942 .arg(&d_aa)
943 .arg(&d_r)
944 .arg(&d_ar)
945 .arg(&d_sbb)
946 .arg(&d_sbh)
947 .arg(&d_sbw)
948 .arg(d_moments_ref)
949 .arg(&d_chi)
950 .arg(&d_xi)
951 .arg(&d_rho)
952 .arg(&d_tau)
953 .arg(&d_ruv)
954 .arg(&d_e_obs)
955 .arg(&mut d_neglog)
956 .arg(&mut d_grad)
957 .arg(&mut d_hess);
958
959 unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
966 reason: format!("bms_flex_row launch: {err}"),
967 })?;
968 stream
969 .synchronize()
970 .map_err(|err| GpuError::DriverCallFailed {
971 reason: format!("bms_flex_row synchronize: {err}"),
972 })?;
973
974 let neglog = stream
975 .clone_dtoh(&d_neglog)
976 .map_err(|err| GpuError::DriverCallFailed {
977 reason: format!("bms_flex_row download neglog: {err}"),
978 })?;
979 let grad = stream
980 .clone_dtoh(&d_grad)
981 .map_err(|err| GpuError::DriverCallFailed {
982 reason: format!("bms_flex_row download grad: {err}"),
983 })?;
984 let hess = stream
985 .clone_dtoh(&d_hess)
986 .map_err(|err| GpuError::DriverCallFailed {
987 reason: format!("bms_flex_row download hess: {err}"),
988 })?;
989
990 Ok(BmsFlexRowKernelOutputs { neglog, grad, hess })
991}
992
/// Column layout of the joint coefficient vector `β` (length `p_total`).
///
/// NOTE(review): the names match the HVP kernels' `p_m`, `p_g`, `p_total`,
/// `h_block_*` and `w_block_*` arguments. There, `β[0..p_m]` pairs with the
/// marginal design row, `β[p_m..p_m + p_g]` pairs with the log-slope design
/// row, and the `h`/`w` blocks map one-to-one onto primary coordinates.
/// Presumably these fields feed those arguments; confirm at the launch site.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexBlockLayout {
    /// Width of the marginal design block of `β`.
    pub p_m: usize,
    /// Width of the log-slope design block of `β`.
    pub p_g: usize,
    /// Range of `β` holding the `h` block, if present.
    pub h: Option<std::ops::Range<usize>>,
    /// Range of `β` holding the `w` block, if present.
    pub w: Option<std::ops::Range<usize>>,
    /// Total length of `β`.
    pub p_total: usize,
}
1053
/// Layout of the `r` per-row primary coordinates.
///
/// The row kernel treats coordinate 0 as `q` and coordinate 1 as `b`, followed
/// by the `h` block (`2..2 + p_h`) and the `w` block (`2 + p_h..r`).
/// NOTE(review): `h`/`w` here presumably record those two ranges; confirm
/// where this struct is constructed.
#[cfg(target_os = "linux")]
#[derive(Clone, Debug)]
pub(crate) struct BmsFlexPrimaryLayout {
    /// Range of primary coordinates holding the `h` block, if present.
    pub h: Option<std::ops::Range<usize>>,
    /// Range of primary coordinates holding the `w` block, if present.
    pub w: Option<std::ops::Range<usize>>,
    /// Total number of primary coordinates per row.
    pub r: usize,
}
1063
/// Rows handled by one CTA of the HVP partial kernels.
///
/// `num_hvp_chunks` divides `n` by this value, rounding up. Presumably this is
/// also passed as the kernels' `rows_per_cta` argument; confirm.
#[cfg(target_os = "linux")]
pub(crate) const HVP_ROWS_PER_CTA: u32 = 256;

/// Threads per block intended for the HVP partial kernels.
///
/// NOTE(review): those kernels reduce dot products through a shared
/// `dot_reduce[128]` with a halving tree, and compute `action[u]` on thread
/// `u` for `u < r`. If this is their `blockDim.x`, it must be a power of two
/// in `32..=128`. Confirm at the launch site.
#[cfg(target_os = "linux")]
pub(crate) const HVP_THREADS: u32 = 128;

/// Threads per block for a reduction launch. Presumably this is
/// `bms_flex_row_hvp_reduce`, which sums the per-chunk partials column by
/// column; confirm.
#[cfg(target_os = "linux")]
pub(crate) const REDUCTION_THREADS: u32 = 256;

/// Maximum number of right-hand sides per multi-RHS HVP launch.
///
/// Keep this in sync with `#define MAX_MULTI_RHS 8` in `HVP_KERNEL_SOURCE`,
/// which sizes the shared `row_dir`/`action` buffers as `MAX_MULTI_RHS * 32`.
#[cfg(target_os = "linux")]
pub(crate) const BMS_FLEX_ROW_HVP_MAX_RHS: usize = 8;
1089
/// Per-row Hessians and design matrices held in GPU memory (`CudaSlice` fields).
///
/// NOTE(review): the field names line up with the HVP kernels' inputs:
/// `row_hessians` is `[n, r*r]`, `marginal_design` is `[n, p_m]` and
/// `logslope_design` is `[n, p_g]`. This suggests the buffers are fed to
/// `bms_flex_row_hvp_*` for Hessian-vector products; confirm against the
/// launch site.
#[cfg(target_os = "linux")]
pub struct DeviceResidentRowHess {
    /// Row Hessians. Presumably `n * r * r` row-major doubles; TODO confirm.
    pub(crate) hess: CudaSlice<f64>,
    /// Marginal design. Presumably `[n, block.p_m]` row-major.
    pub(crate) marginal_design: CudaSlice<f64>,
    /// Log-slope design. Presumably `[n, block.p_g]` row-major.
    pub(crate) logslope_design: CudaSlice<f64>,
    /// Number of rows.
    pub(crate) n: usize,
    /// Primary-coordinate count per row.
    pub(crate) r: usize,
    /// Layout of the joint coefficient vector.
    pub(crate) block: BmsFlexBlockLayout,
    /// Layout of the per-row primary coordinates.
    pub(crate) primary: BmsFlexPrimaryLayout,
    /// Byte count reported by `Debug`. Presumably the device footprint of the
    /// buffers above; confirm.
    pub(crate) bytes: u64,
}
1125
1126#[cfg(target_os = "linux")]
1127impl std::fmt::Debug for DeviceResidentRowHess {
1128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1129 f.debug_struct("DeviceResidentRowHess")
1130 .field("n", &self.n)
1131 .field("r", &self.r)
1132 .field("p_total", &self.block.p_total)
1133 .field("bytes", &self.bytes)
1134 .finish()
1135 }
1136}
1137
1138#[cfg(target_os = "linux")]
1141pub(crate) fn num_hvp_chunks(n: usize) -> usize {
1142 n.div_ceil(HVP_ROWS_PER_CTA as usize)
1143}
1144
1145#[cfg(target_os = "linux")]
1148pub(crate) const HVP_KERNEL_SOURCE: &str = r#"
1149// CPU parity reference: cpu_oracle_bms_flex_row_hvp / cpu_oracle_bms_flex_row_diagonal
1150// in this module.
1151
1152#define MAX_MULTI_RHS 8
1153
1154extern "C" __global__ void bms_flex_row_hvp_partial(
1155 int n_rows,
1156 int r,
1157 int p_m,
1158 int p_g,
1159 int p_total,
1160 int h_block_start,
1161 int h_block_len,
1162 int w_block_start,
1163 int w_block_len,
1164 int h_primary_start,
1165 int w_primary_start,
1166 int rows_per_cta,
1167 const double * __restrict__ row_hessians, // [n, r*r]
1168 const double * __restrict__ marginal_design, // [n, p_m] row-major
1169 const double * __restrict__ logslope_design, // [n, p_g] row-major
1170 const double * __restrict__ v, // [p_total]
1171 double * __restrict__ partial) // [num_chunks, p_total]
1172{
1173 int chunk = blockIdx.x;
1174 int tid = threadIdx.x;
1175 int row_lo = chunk * rows_per_cta;
1176 int row_hi = row_lo + rows_per_cta;
1177 if (row_hi > n_rows) row_hi = n_rows;
1178
1179 // Zero this chunk's partial slice cooperatively.
1180 double *out = partial + (size_t)chunk * (size_t)p_total;
1181 for (int j = tid; j < p_total; j += blockDim.x) {
1182 out[j] = 0.0;
1183 }
1184 __syncthreads();
1185
1186 // Each thread serially processes a stride-of-blockDim set of rows so
1187 // every write to `out[..]` happens from one thread → no atomics within
1188 // the chunk. To keep writes race-free across threads of the same chunk,
1189 // we serialize the cross-row accumulation through a per-row barrier:
1190 // thread 0 of the block processes all rows in the chunk. The per-row
1191 // work is dominated by the dot/axpy over `p_m + p_g`, which is large.
1192 // For Stage 3 we ship the simple, correct path (thread 0 sequential
1193 // per row, blockDim.x threads parallel within a row's dot/axpy).
1194 __shared__ double row_dir[32];
1195 __shared__ double action[32];
1196 __shared__ double dot_reduce[128];
1197
1198 for (int row = row_lo; row < row_hi; ++row) {
1199 const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
1200 const double *grow = logslope_design + (size_t)row * (size_t)p_g;
1201 const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
1202
1203 // row_dir[0] = mrow · v[0..p_m]
1204 double local = 0.0;
1205 for (int j = tid; j < p_m; j += blockDim.x) {
1206 local += mrow[j] * v[j];
1207 }
1208 dot_reduce[tid] = local;
1209 __syncthreads();
1210 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1211 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1212 __syncthreads();
1213 }
1214 if (tid == 0) row_dir[0] = dot_reduce[0];
1215
1216 // row_dir[1] = grow · v[p_m..p_m+p_g]
1217 local = 0.0;
1218 for (int j = tid; j < p_g; j += blockDim.x) {
1219 local += grow[j] * v[p_m + j];
1220 }
1221 dot_reduce[tid] = local;
1222 __syncthreads();
1223 for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
1224 if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
1225 __syncthreads();
1226 }
1227 if (tid == 0) row_dir[1] = dot_reduce[0];
1228
1229 // h/w blocks: direct copy.
1230 if (tid == 0) {
1231 for (int k = 0; k < h_block_len; ++k) {
1232 row_dir[h_primary_start + k] = v[h_block_start + k];
1233 }
1234 for (int k = 0; k < w_block_len; ++k) {
1235 row_dir[w_primary_start + k] = v[w_block_start + k];
1236 }
1237 }
1238 __syncthreads();
1239
1240 // action[u] = Σ_v Hrow[u*r+v] · row_dir[v], computed by thread u (u < r).
1241 if (tid < r) {
1242 double acc = 0.0;
1243 for (int vv = 0; vv < r; ++vv) {
1244 acc += Hrow[tid * r + vv] * row_dir[vv];
1245 }
1246 action[tid] = acc;
1247 }
1248 __syncthreads();
1249
1250 // Pull back into joint β slot.
1251 // marginal: out[j] += action[0] · mrow[j] (parallel j)
1252 double a0 = action[0];
1253 for (int j = tid; j < p_m; j += blockDim.x) {
1254 out[j] += a0 * mrow[j];
1255 }
1256 double a1 = action[1];
1257 for (int j = tid; j < p_g; j += blockDim.x) {
1258 out[p_m + j] += a1 * grow[j];
1259 }
1260 if (tid == 0) {
1261 for (int k = 0; k < h_block_len; ++k) {
1262 out[h_block_start + k] += action[h_primary_start + k];
1263 }
1264 for (int k = 0; k < w_block_len; ++k) {
1265 out[w_block_start + k] += action[w_primary_start + k];
1266 }
1267 }
1268 __syncthreads();
1269 }
1270}
1271
1272extern "C" __global__ void bms_flex_row_hvp_reduce(
1273 int num_chunks,
1274 int p_total,
1275 const double * __restrict__ partial, // [num_chunks, p_total]
1276 double * __restrict__ out) // [p_total]
1277{
1278 int j = blockIdx.x * blockDim.x + threadIdx.x;
1279 if (j >= p_total) return;
1280 double acc = 0.0;
1281 for (int c = 0; c < num_chunks; ++c) {
1282 acc += partial[(size_t)c * (size_t)p_total + (size_t)j];
1283 }
1284 out[j] = acc;
1285}
1286
// Multi-RHS per-chunk HVP partial. For each direction vector v_k (row k of
// v_rhs) this accumulates, over the CTA's rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)), the pullback
// P_i^T H_i P_i v_k into partial[k, chunk, 0..p_total).
//
// Per row:
//   1. row_dir[k] = P_i v_k: two block-wide dot products give primaries 0
//      (marginal row . v) and 1 (logslope row . v); the h/w primaries are
//      copied straight from v.
//   2. action[k] = H_i row_dir[k], one thread per (rhs, u) pair.
//   3. Scatter P_i^T action[k] back into this chunk's partial slices.
//
// NOTE(review): row_dir/action reserve 32 doubles per rhs and dot_reduce
// holds 128 entries, so this assumes rhs_count <= MAX_MULTI_RHS, r <= 32
// and blockDim.x a power of two <= 128 (the tree reduction halves
// blockDim.x). None of these is checked on the device; confirm the launcher
// enforces them. Primaries outside {0, 1} and the h/w ranges are never
// written into row_dir, which presumably relies on
// r == 2 + h_block_len + w_block_len.
extern "C" __global__ void bms_flex_row_hvp_multi_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    int rhs_count,
    const double * __restrict__ row_hessians,      // [n, r*r]
    const double * __restrict__ marginal_design,   // [n, p_m]
    const double * __restrict__ logslope_design,   // [n, p_g]
    const double * __restrict__ v_rhs,             // [rhs_count, p_total]
    double * __restrict__ partial)                 // [rhs_count, num_chunks, p_total]
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    // Recomputed here to address partial[rhs, chunk, :]; must agree with the
    // grid size and partial allocation chosen by the launcher.
    int num_chunks = (n_rows + rows_per_cta - 1) / rows_per_cta;
    // Zero this CTA's slice of every rhs (idx walks [rhs_count, p_total]).
    for (int idx = tid; idx < rhs_count * p_total; idx += blockDim.x) {
        int rhs = idx / p_total;
        int j = idx - rhs * p_total;
        partial[((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total + (size_t)j] = 0.0;
    }
    __syncthreads();

    // Each rhs uses a fixed 32-entry stride inside row_dir / action.
    __shared__ double row_dir[MAX_MULTI_RHS * 32];
    __shared__ double action[MAX_MULTI_RHS * 32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;

        // Step 1: build row_dir for every rhs.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            const double *v = v_rhs + (size_t)rhs * (size_t)p_total;

            // Primary 0: block-wide dot product mrow . v[0..p_m).
            double local = 0.0;
            for (int j = tid; j < p_m; j += blockDim.x) {
                local += mrow[j] * v[j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) row_dir[rhs * 32 + 0] = dot_reduce[0];

            // Primary 1: block-wide dot product grow . v[p_m..p_m+p_g).
            // Only thread 0 touches dot_reduce[0], and it has already read the
            // previous sum above, so refilling dot_reduce here is safe.
            local = 0.0;
            for (int j = tid; j < p_g; j += blockDim.x) {
                local += grow[j] * v[p_m + j];
            }
            dot_reduce[tid] = local;
            __syncthreads();
            for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
                if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
                __syncthreads();
            }
            if (tid == 0) {
                row_dir[rhs * 32 + 1] = dot_reduce[0];
                // h/w primaries map one-to-one onto joint coordinates.
                for (int k = 0; k < h_block_len; ++k) {
                    row_dir[rhs * 32 + h_primary_start + k] = v[h_block_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    row_dir[rhs * 32 + w_primary_start + k] = v[w_block_start + k];
                }
            }
            __syncthreads();
        }

        // Step 2: action[rhs][u] = sum_vv H[u, vv] * row_dir[rhs][vv], one
        // thread per (rhs, u) pair.
        for (int idx = tid; idx < rhs_count * r; idx += blockDim.x) {
            int rhs = idx / r;
            int u = idx - rhs * r;
            double acc = 0.0;
            const double *dir = row_dir + rhs * 32;
            for (int vv = 0; vv < r; ++vv) {
                acc += Hrow[u * r + vv] * dir[vv];
            }
            action[rhs * 32 + u] = acc;
        }
        __syncthreads();

        // Step 3: pull each action back into its rhs's joint-beta slice.
        for (int rhs = 0; rhs < rhs_count; ++rhs) {
            double *out = partial + ((size_t)rhs * (size_t)num_chunks + (size_t)chunk) * (size_t)p_total;
            double a0 = action[rhs * 32 + 0];
            for (int j = tid; j < p_m; j += blockDim.x) {
                out[j] += a0 * mrow[j];
            }
            double a1 = action[rhs * 32 + 1];
            for (int j = tid; j < p_g; j += blockDim.x) {
                out[p_m + j] += a1 * grow[j];
            }
            if (tid == 0) {
                for (int k = 0; k < h_block_len; ++k) {
                    out[h_block_start + k] += action[rhs * 32 + h_primary_start + k];
                }
                for (int k = 0; k < w_block_len; ++k) {
                    out[w_block_start + k] += action[rhs * 32 + w_primary_start + k];
                }
            }
            __syncthreads();
        }
    }
}
1401
// Multi-RHS fixed-order reduction: out[k, j] = sum over c of
// partial[k, c, j], summed in ascending chunk order. One thread per
// flattened (rhs, coordinate) pair; threads past rhs_count * p_total do
// nothing.
extern "C" __global__ void bms_flex_row_hvp_multi_reduce(
    int num_chunks,
    int p_total,
    int rhs_count,
    const double * __restrict__ partial, // [rhs_count, num_chunks, p_total]
    double * __restrict__ out)           // [rhs_count, p_total]
{
    const int flat = (int)(blockIdx.x * blockDim.x + threadIdx.x);
    if (flat < rhs_count * p_total) {
        const int rhs = flat / p_total;
        const int col = flat % p_total;
        const size_t stride = (size_t)p_total;
        // Start of this rhs's [num_chunks, p_total] slab, offset to column col.
        const double *column = partial + (size_t)rhs * (size_t)num_chunks * stride + (size_t)col;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk) {
            total += column[(size_t)chunk * stride];
        }
        out[(size_t)rhs * stride + (size_t)col] = total;
    }
}
1420
// Per-chunk partial of the joint-Hessian diagonal diag(sum_i P_i^T H_i P_i)
// over rows [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)),
// written to partial[chunk, 0..p_total) after the CTA zeroes that slice.
// Per row it adds:
//   H_i[0,0] * X_ij^2        to marginal slot j,
//   H_i[1,1] * G_ij^2        to logslope slot p_m + j,
//   H_i[ii,ii]               to each h/w coordinate, ii its primary index.
// Off-diagonal H_i terms do not appear. NOTE(review): that is exact only if
// the marginal, logslope, h and w slots of the joint vector are pairwise
// disjoint; confirm against the block layout.
extern "C" __global__ void bms_flex_row_diag_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
        // Dense row-major r x r: (u, u) lives at u*r + u.
        double h00 = Hrow[0];
        double h11 = Hrow[1 * r + 1];
        // Design-backed slots are strided across the CTA's threads.
        for (int j = tid; j < p_m; j += blockDim.x) {
            double v = mrow[j];
            out[j] += h00 * v * v;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            double v = grow[j];
            out[p_m + j] += h11 * v * v;
        }
        // Unit-vector (h/w) slots: a handful of entries, handled by thread 0.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                int ii = h_primary_start + k;
                out[h_block_start + k] += Hrow[ii * r + ii];
            }
            for (int k = 0; k < w_block_len; ++k) {
                int ii = w_primary_start + k;
                out[w_block_start + k] += Hrow[ii * r + ii];
            }
        }
        __syncthreads();
    }
}
1478
1479// ────────────────────────────────────────────────────────────────────────
1480// Phase 4 — SymmetricPackedUpper variants. Per-row storage is
1481// row_hessians_packed + (size_t)row * (size_t)(r*(r+1)/2)
1482// indexed as
1483// packed[(u*(2*r - u - 1))/2 + (v - u)] for u <= v
1484// with symmetric mirror for v < u.
1485// ────────────────────────────────────────────────────────────────────────
1486
// Helper: packed-upper index for (u, v) within a single row of r*(r+1)/2
// doubles. The layout is the one written by bms_flex_row_pack_upper: the
// upper triangle stored row by row, primary row u holding columns u..r-1.
// Row u therefore starts at sum_{k<u} (r - k) = u*(2r - u + 1)/2, so
//   idx(u, v) = u*(2r - u + 1)/2 + (v - u) = u*(2r - u - 1)/2 + v.
// u*(2r - u - 1) is always even (one of u, 2r - u - 1 is even), so the
// integer division is exact. Example r = 3:
//   (0,0)=0 (0,1)=1 (0,2)=2 (1,1)=3 (1,2)=4 (2,2)=5.
// Caller must pre-swap so that u <= v.
__device__ __forceinline__ int bms_flex_packed_idx(int u, int v, int r) {
    return (u * (2 * r - u - 1)) / 2 + v;
}
1493
// Pack one row of the full row-major r x r Hessian into packed-upper layout:
// the upper triangle is written row by row, so primary row u occupies
// packed positions [u*(2r - u + 1)/2, u*(2r - u + 1)/2 + (r - u)) and holds
// columns u..r-1 in order.
// Expects one CTA per row: blockIdx.x selects the row, and CTAs with
// blockIdx.x >= n_rows exit immediately. The CTA's threads stride over the
// r*(r+1)/2 packed positions. Bit-equal copy: each upper-triangle entry is
// read once from the dense source and written once to the packed
// destination.
extern "C" __global__ void bms_flex_row_pack_upper(
    int n_rows,
    int r,
    const double * __restrict__ src_full,   // [n, r*r]
    double * __restrict__ dst_packed)       // [n, r*(r+1)/2]
{
    const int row = blockIdx.x;
    if (row < n_rows) {
        const int packed_len = r * (r + 1) / 2;
        const double *dense = src_full + (size_t)row * (size_t)r * (size_t)r;
        double *packed = dst_packed + (size_t)row * (size_t)packed_len;
        for (int pos = threadIdx.x; pos < packed_len; pos += blockDim.x) {
            // Recover the primary row u owning pos by walking row extents
            // (row u has r - u entries); O(r) per position with r <= 32.
            int u = 0;
            int row_begin = 0;
            for (; u < r; ++u) {
                const int row_end = row_begin + (r - u);
                if (pos < row_end) break;
                row_begin = row_end;
            }
            const int v = u + (pos - row_begin);
            packed[pos] = dense[(size_t)u * (size_t)r + (size_t)v];
        }
    }
}
1528
// Packed-upper per-chunk HVP partial: accumulates sum over the CTA's rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)) of
// P_i^T H_i P_i v into partial[chunk, 0..p_total), where each row's
// symmetric r x r Hessian is stored as r*(r+1)/2 packed-upper doubles.
//
// Per row: row_dir = P_i v (two block-wide dot products plus direct h/w
// copies), action = H_i row_dir (thread u < r computes entry u), then
// P_i^T action is scattered back into the chunk's partial.
//
// NOTE(review): row_dir/action hold 32 doubles and dot_reduce 128, so this
// assumes r <= 32 and blockDim.x a power of two <= 128 (the tree reduction
// halves blockDim.x). Neither is checked on the device.
extern "C" __global__ void bms_flex_row_hvp_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed, // [n, r*(r+1)/2]
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    const double * __restrict__ v,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int per_row = r * (r + 1) / 2;
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    __shared__ double row_dir[32];
    __shared__ double action[32];
    __shared__ double dot_reduce[128];

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;

        // row_dir[0] = mrow · v[0..p_m]
        double local = 0.0;
        for (int j = tid; j < p_m; j += blockDim.x) {
            local += mrow[j] * v[j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[0] = dot_reduce[0];

        // row_dir[1] = grow · v[p_m..p_m+p_g]
        // Only thread 0 touches dot_reduce[0], and it has already read the
        // previous sum, so refilling dot_reduce here is safe.
        local = 0.0;
        for (int j = tid; j < p_g; j += blockDim.x) {
            local += grow[j] * v[p_m + j];
        }
        dot_reduce[tid] = local;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
            if (tid < stride) dot_reduce[tid] += dot_reduce[tid + stride];
            __syncthreads();
        }
        if (tid == 0) row_dir[1] = dot_reduce[0];

        // h/w primaries map one-to-one onto joint coordinates.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                row_dir[h_primary_start + k] = v[h_block_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                row_dir[w_primary_start + k] = v[w_block_start + k];
            }
        }
        __syncthreads();

        // action[u] = Σ_w H[u, w] · row_dir[w], where H[u, w] reads from
        // packed-upper with (uu, vv) = (min(u, w), max(u, w)).
        if (tid < r) {
            double acc = 0.0;
            int u = tid;
            for (int w = 0; w < r; ++w) {
                int uu = u < w ? u : w;
                int vv = u < w ? w : u;
                acc += Hrow[bms_flex_packed_idx(uu, vv, r)] * row_dir[w];
            }
            action[tid] = acc;
        }
        __syncthreads();

        // Pull back into the joint-beta slice: design-backed slots in
        // parallel, unit-vector (h/w) slots by thread 0.
        double a0 = action[0];
        for (int j = tid; j < p_m; j += blockDim.x) {
            out[j] += a0 * mrow[j];
        }
        double a1 = action[1];
        for (int j = tid; j < p_g; j += blockDim.x) {
            out[p_m + j] += a1 * grow[j];
        }
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                out[h_block_start + k] += action[h_primary_start + k];
            }
            for (int k = 0; k < w_block_len; ++k) {
                out[w_block_start + k] += action[w_primary_start + k];
            }
        }
        __syncthreads();
    }
}
1639
1640// ────────────────────────────────────────────────────────────────────────
1641// Phase 6 — dense joint-Hessian block kernel for the debug / exact-REML
1642// route. Materialises the full `[p_total, p_total]` row-major joint H
1643// from the per-row r×r Hessian via the P_i pullback. NOT the default
1644// Newton path: production Newton uses HVP (Phase 2/3); this kernel exists
1645// for exact-REML logdet / dense-H comparisons / diagnostic dumps where the
1646// caller genuinely needs the dense matrix on the device.
1647//
1648// Per-CTA partial: each CTA owns a contiguous chunk of rows
1649// `[chunk*rows_per_cta, (chunk+1)*rows_per_cta)`. Inside the CTA the
1650// per-row pullback computes `(P_i^T H_i P_i)[m, n]` and adds it to the
1651// CTA's shared-mem `[p_total, p_total]` partial. The reduce kernel sums
1652// chunk-major-fixed-order into a single `[p_total, p_total]` output.
1653//
1654// Math: for primary index u ∈ [0, r):
1655// * u = 0: phi_u = (X_i in slot 0..p_m, 0 elsewhere)
1656// * u = 1: phi_u = (0, G_i in slot p_m..p_m+p_g, 0 elsewhere)
1657// * u = 2+j: phi_u = e_{h_block_start + j} (j ∈ 0..h_block_len)
1658// * u = 2+h+l: phi_u = e_{w_block_start + l} (l ∈ 0..w_block_len)
1659// Then `H_full[m, n] += sum_{u,v} H_i[u,v] * phi_u[m] * phi_v[n]`.
1660//
// Shared-memory budget: the CTA accumulator needs p_total*p_total*8 bytes of
// dynamic shared memory. At the large-scale shape p_total = 44 that is
// 44*44*8 = 15,488 bytes (~15.1 KiB), well below the 48 KiB per-block
// default limit on V100. At p_total = 80 it would be 80*80*8 = 51,200 bytes
// (50 KiB), just over that limit, so the caller must enforce
// p_total ≤ DENSE_BLOCK_MAX_P. The launcher rejects oversize p_total cleanly.
1666
// Resolves the non-zero support of the pullback vector phi_u for primary
// index u, using the same primary-to-joint mapping as the HVP and diagonal
// kernels in this file:
//   u == 0                                  -> marginal row X_i, slot [0, p_m)
//   u == 1                                  -> logslope row G_i, slot [p_m, p_m + p_g)
//   u - h_primary_start in [0, h_block_len) -> e_{h_block_start + (u - h_primary_start)}
//   u - w_primary_start in [0, w_block_len) -> e_{w_block_start + (u - w_primary_start)}
// On success writes (*off, *len, *src); *src == NULL marks a unit-vector slot
// whose single coefficient is 1.0. Returns false when u maps to none of the
// above, so the caller skips it instead of indexing outside acc[].
__device__ __forceinline__ bool bms_flex_dense_phi_support(
    int u,
    int p_m,
    int p_g,
    int h_block_start,
    int h_block_len,
    int h_primary_start,
    int w_block_start,
    int w_block_len,
    int w_primary_start,
    const double *mrow,
    const double *grow,
    int *off,
    int *len,
    const double **src)
{
    if (u == 0) {
        *off = 0; *len = p_m; *src = mrow;
        return true;
    }
    if (u == 1) {
        *off = p_m; *len = p_g; *src = grow;
        return true;
    }
    int hk = u - h_primary_start;
    if (hk >= 0 && hk < h_block_len) {
        *off = h_block_start + hk; *len = 1; *src = NULL;
        return true;
    }
    int wk = u - w_primary_start;
    if (wk >= 0 && wk < w_block_len) {
        *off = w_block_start + wk; *len = 1; *src = NULL;
        return true;
    }
    return false;
}

// Per-CTA dense partial of J = sum_i P_i^T H_i P_i over rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)). The partial is
// accumulated in dynamic shared memory (p_total*p_total doubles, sized by
// the launcher) and written row-major to partial[chunk, :, :].
// h/w primaries are located through h_primary_start / w_primary_start, as
// in the HVP kernels, rather than assumed to start at primary index 2.
// Primaries that map to no joint coordinate contribute nothing.
extern "C" __global__ void bms_flex_row_dense_block_partial(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians,      // [n, r*r]
    const double * __restrict__ marginal_design,   // [n, p_m]
    const double * __restrict__ logslope_design,   // [n, p_g]
    double * __restrict__ partial)                 // [num_chunks, p_total, p_total]
{
    extern __shared__ double shmem[];
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int pp = p_total * p_total;
    double *acc = shmem; // CTA-private accumulator [p_total, p_total]
    for (int j = tid; j < pp; j += blockDim.x) acc[j] = 0.0;
    __syncthreads();

    // Per-row work is performed by thread 0 alone to avoid cross-thread RW
    // contention on acc[]. Per-row cost is O(r * p_m + r * p_g + r^2), which
    // is tractable for r <= 32 and modest p_m + p_g. This is a debug-only
    // path, kept simple so it is easy to audit against the host-side P_i
    // pullback oracle.
    if (tid == 0) {
        for (int row = row_lo; row < row_hi; ++row) {
            const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
            const double *grow = logslope_design + (size_t)row * (size_t)p_g;
            const double *Hrow = row_hessians + (size_t)row * (size_t)r * (size_t)r;
            for (int u = 0; u < r; ++u) {
                // phi_u's support does not depend on v: resolve it once.
                int m_off, m_len;
                const double *m_src;
                if (!bms_flex_dense_phi_support(
                        u, p_m, p_g,
                        h_block_start, h_block_len, h_primary_start,
                        w_block_start, w_block_len, w_primary_start,
                        mrow, grow, &m_off, &m_len, &m_src)) {
                    continue;
                }
                for (int v = 0; v < r; ++v) {
                    double huv = Hrow[u * r + v];
                    if (huv == 0.0) continue;
                    int n_off, n_len;
                    const double *n_src;
                    if (!bms_flex_dense_phi_support(
                            v, p_m, p_g,
                            h_block_start, h_block_len, h_primary_start,
                            w_block_start, w_block_len, w_primary_start,
                            mrow, grow, &n_off, &n_len, &n_src)) {
                        continue;
                    }
                    // acc[m, n] += huv * phi_u[m] * phi_v[n] over the supports.
                    for (int mi = 0; mi < m_len; ++mi) {
                        double pm = (m_src == NULL) ? 1.0 : m_src[mi];
                        if (pm == 0.0) continue;
                        double scaled = huv * pm;
                        int m_idx = m_off + mi;
                        for (int ni = 0; ni < n_len; ++ni) {
                            double pn = (n_src == NULL) ? 1.0 : n_src[ni];
                            int n_idx = n_off + ni;
                            acc[m_idx * p_total + n_idx] += scaled * pn;
                        }
                    }
                }
            }
        }
    }
    __syncthreads();

    // Write the CTA accumulator out to global memory at its chunk slot.
    double *out_chunk = partial + (size_t)chunk * (size_t)pp;
    for (int j = tid; j < pp; j += blockDim.x) {
        out_chunk[j] = acc[j];
    }
}
1761
// Fixed-order reduction of per-chunk dense partials:
// out[e] = sum over c of partial[c, e] for every flattened entry e of the
// [p_total, p_total] matrix, summed in ascending chunk order. One thread per
// entry; threads past p_total*p_total do nothing.
extern "C" __global__ void bms_flex_row_dense_block_reduce(
    int num_chunks,
    int p_total,
    const double * __restrict__ partial, // [num_chunks, p_total, p_total]
    double * __restrict__ out)           // [p_total, p_total]
{
    const int entries = p_total * p_total;
    const int flat = (int)(blockIdx.x * blockDim.x + threadIdx.x);
    if (flat < entries) {
        const double *column = partial + (size_t)flat;
        double total = 0.0;
        for (int chunk = 0; chunk < num_chunks; ++chunk) {
            total += column[(size_t)chunk * (size_t)entries];
        }
        out[flat] = total;
    }
}
1777
// Packed-upper counterpart of bms_flex_row_diag_partial: accumulates the
// joint-Hessian diagonal over the CTA's rows
// [chunk*rows_per_cta, min((chunk+1)*rows_per_cta, n_rows)) into
// partial[chunk, 0..p_total), reading each row's Hessian from r*(r+1)/2
// packed-upper doubles. Diagonal entries are looked up with
// bms_flex_packed_idx(u, u, r). As in the dense variant, off-diagonal H_i
// terms do not contribute. NOTE(review): that assumes the marginal,
// logslope, h and w slots are pairwise disjoint.
extern "C" __global__ void bms_flex_row_diag_partial_packed(
    int n_rows,
    int r,
    int p_m,
    int p_g,
    int p_total,
    int h_block_start,
    int h_block_len,
    int w_block_start,
    int w_block_len,
    int h_primary_start,
    int w_primary_start,
    int rows_per_cta,
    const double * __restrict__ row_hessians_packed,
    const double * __restrict__ marginal_design,
    const double * __restrict__ logslope_design,
    double * __restrict__ partial)
{
    int chunk = blockIdx.x;
    int tid = threadIdx.x;
    int row_lo = chunk * rows_per_cta;
    int row_hi = row_lo + rows_per_cta;
    if (row_hi > n_rows) row_hi = n_rows;

    int per_row = r * (r + 1) / 2;
    double *out = partial + (size_t)chunk * (size_t)p_total;
    for (int j = tid; j < p_total; j += blockDim.x) {
        out[j] = 0.0;
    }
    __syncthreads();

    for (int row = row_lo; row < row_hi; ++row) {
        const double *mrow = marginal_design + (size_t)row * (size_t)p_m;
        const double *grow = logslope_design + (size_t)row * (size_t)p_g;
        const double *Hrow = row_hessians_packed + (size_t)row * (size_t)per_row;
        // Diagonal entry for (u, u) sits at packed_idx(u, u, r).
        double h00 = Hrow[bms_flex_packed_idx(0, 0, r)];
        double h11 = Hrow[bms_flex_packed_idx(1, 1, r)];
        // Design-backed slots are strided across the CTA's threads.
        for (int j = tid; j < p_m; j += blockDim.x) {
            double v = mrow[j];
            out[j] += h00 * v * v;
        }
        for (int j = tid; j < p_g; j += blockDim.x) {
            double v = grow[j];
            out[p_m + j] += h11 * v * v;
        }
        // Unit-vector (h/w) slots are handled by thread 0.
        if (tid == 0) {
            for (int k = 0; k < h_block_len; ++k) {
                int ii = h_primary_start + k;
                out[h_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
            for (int k = 0; k < w_block_len; ++k) {
                int ii = w_primary_start + k;
                out[w_block_start + k] += Hrow[bms_flex_packed_idx(ii, ii, r)];
            }
        }
        __syncthreads();
    }
}
1837"#;
1838
#[cfg(target_os = "linux")]
/// Cached CUDA handles for the kernels compiled from `HVP_KERNEL_SOURCE`
/// (HVP, diagonal, dense-block and their reductions); obtain it through
/// [`HvpKernelBackend::probe`].
pub(crate) struct HvpKernelBackend {
    /// Stream the kernels from this module are launched on.
    pub(crate) stream: Arc<CudaStream>,
    /// Loaded module holding the compiled kernels.
    pub(crate) module: Arc<CudaModule>,
}
1844
1845#[cfg(target_os = "linux")]
1846impl HvpKernelBackend {
1847 pub(crate) fn probe() -> Result<&'static Self, GpuError> {
1848 static BACKEND: OnceLock<Result<HvpKernelBackend, GpuError>> = OnceLock::new();
1849 BACKEND
1850 .get_or_init(|| {
1851 gam_gpu::backend_probe::probe_backend_with_compile("bms_flex_row hvp", |parts| {
1852 let ptx = gam_gpu::device_cache::compile_ptx_arch(HVP_KERNEL_SOURCE)
1856 .map_err(|err| GpuError::DriverCallFailed {
1857 reason: format!("bms_flex_row hvp NVRTC compile failed: {err}"),
1858 })?;
1859 let module =
1860 parts
1861 .ctx
1862 .load_module(ptx)
1863 .map_err(|err| GpuError::DriverCallFailed {
1864 reason: format!("bms_flex_row hvp module load failed: {err}"),
1865 })?;
1866 Ok(HvpKernelBackend {
1867 stream: parts.stream.clone(),
1868 module,
1869 })
1870 })
1871 })
1872 .as_ref()
1873 .map_err(GpuError::clone)
1874 }
1875}
1876
#[cfg(target_os = "linux")]
/// Runs `bms_flex_row_kernel` once over all `n` rows and keeps the per-row
/// `r × r` Hessians resident on the device. Uploaded copies of both design
/// matrices are kept alongside them, ready for later HVP / diagonal launches.
///
/// The function:
/// - uploads every per-row and per-cell input, reusing `cell_moments` as-is
///   when it is already a device slice;
/// - launches one CTA per row;
/// - synchronizes the stream and frees the temporaries.
///
/// The per-row negative log-likelihood and gradient buffers the kernel also
/// fills are discarded.
///
/// # Errors
/// `GpuError::DriverCallFailed` in any of these cases:
/// - `s_f` is rejected by `s_f_diagnostic_finite`;
/// - a design matrix length differs from `n * p_m` / `n * p_g`;
/// - `primary.r != inputs.r`;
/// - `n`, `r`, `p_h` or `p_w` does not fit in `i32`;
/// - any upload, allocation, function load, launch or synchronize fails.
///
/// Errors from `inputs.validate()` and the backend probes are propagated
/// unchanged.
pub(crate) fn launch_bms_flex_row_kernel_device_resident(
    inputs: BmsFlexRowKernelInputs<'_>,
    marginal_design_row_major: &[f64],
    logslope_design_row_major: &[f64],
    block: BmsFlexBlockLayout,
    primary: BmsFlexPrimaryLayout,
) -> Result<DeviceResidentRowHess, GpuError> {
    inputs.validate()?;
    if !s_f_diagnostic_finite(&inputs) {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: s_f must be positive and finite, got {}",
                inputs.s_f
            ),
        });
    }
    let n = inputs.n_rows;
    let r = inputs.r;
    // Both designs are row-major [n, p] and are kept on the device, so their
    // shapes must match the block layout exactly.
    if marginal_design_row_major.len() != n * block.p_m {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: marginal_design len={} != n*p_m={}",
                marginal_design_row_major.len(),
                n * block.p_m
            ),
        });
    }
    if logslope_design_row_major.len() != n * block.p_g {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: logslope_design len={} != n*p_g={}",
                logslope_design_row_major.len(),
                n * block.p_g
            ),
        });
    }
    if primary.r != r {
        return Err(GpuError::DriverCallFailed {
            reason: format!(
                "bms_flex_row device-resident: primary.r={} != inputs.r={}",
                primary.r, r
            ),
        });
    }

    let backend = RowKernelBackend::probe()?;
    // The HVP backend is probed too, but only its error matters (the handle
    // is discarded). An unavailable HVP module therefore fails this call up
    // front.
    HvpKernelBackend::probe()?;
    // All uploads, allocations and the launch below use the row-kernel stream.
    let stream = backend.stream.clone();

    let upload_f64 = |slice: &[f64], label: &str| {
        stream
            .clone_htod(slice)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
            })
    };
    let upload_u32 = |slice: &[u32], label: &str| {
        stream
            .clone_htod(slice)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident upload {label}: {err}"),
            })
    };

    let d_q = upload_f64(inputs.q, "q")?;
    let d_b = upload_f64(inputs.b, "b")?;
    let d_mu1 = upload_f64(inputs.mu_1, "mu_1")?;
    let d_mu2 = upload_f64(inputs.mu_2, "mu_2")?;
    let d_zobs = upload_f64(inputs.z_obs, "z_obs")?;
    let d_y = upload_f64(inputs.y, "y")?;
    let d_w = upload_f64(inputs.w, "w")?;
    let d_offsets = upload_u32(inputs.cell_offsets, "cell_offsets")?;
    let d_c0 = upload_f64(inputs.cell_c0, "cell_c0")?;
    let d_c1 = upload_f64(inputs.cell_c1, "cell_c1")?;
    let d_c2 = upload_f64(inputs.cell_c2, "cell_c2")?;
    let d_c3 = upload_f64(inputs.cell_c3, "cell_c3")?;
    let d_a = upload_f64(inputs.cell_a, "cell_a")?;
    let d_aa = upload_f64(inputs.cell_aa, "cell_aa")?;
    let d_r = upload_f64(inputs.cell_r, "cell_r")?;
    let d_ar = upload_f64(inputs.cell_ar, "cell_ar")?;
    let d_sbb = upload_f64(inputs.cell_sbb, "cell_sbb")?;
    let d_sbh = upload_f64(inputs.cell_sbh, "cell_sbh")?;
    let d_sbw = upload_f64(inputs.cell_sbw, "cell_sbw")?;
    // Declared before the match (and assigned only on the Host path) so the
    // borrowed `d_moments_ref` can outlive the match arm.
    let owned_host_moments: CudaSlice<f64>;
    let d_moments_ref: &CudaSlice<f64> = match &inputs.cell_moments {
        CellMomentsSource::Host(slice) => {
            owned_host_moments = upload_f64(slice, "cell_moments")?;
            &owned_host_moments
        }
        // Already device-resident: borrow it without another upload.
        CellMomentsSource::Device(d) => *d,
    };
    let d_chi = upload_f64(inputs.chi_obs, "chi_obs")?;
    let d_xi = upload_f64(inputs.xi_obs, "xi_obs")?;
    let d_rho = upload_f64(inputs.rho_u, "rho_u")?;
    let d_tau = upload_f64(inputs.tau_u, "tau_u")?;
    let d_ruv = upload_f64(inputs.r_uv, "r_uv")?;
    let d_e_obs = upload_f64(inputs.e_obs, "e_obs")?;

    // These two uploads are moved into the returned storage.
    let d_marginal = upload_f64(marginal_design_row_major, "marginal_design")?;
    let d_logslope = upload_f64(logslope_design_row_major, "logslope_design")?;

    // Kernel outputs: [n] neglog, [n, r] gradient, [n, r*r] Hessian.
    let mut d_neglog = stream
        .alloc_zeros::<f64>(n)
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident alloc neglog: {err}"),
        })?;
    let mut d_grad =
        stream
            .alloc_zeros::<f64>(n * r)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident alloc grad: {err}"),
            })?;
    let mut d_hess =
        stream
            .alloc_zeros::<f64>(n * r * r)
            .map_err(|err| GpuError::DriverCallFailed {
                reason: format!("bms_flex_row device-resident alloc hess: {err}"),
            })?;

    let func = backend
        .module
        .load_function("bms_flex_row_kernel")
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident load_function: {err}"),
        })?;

    // One CTA per row. `n as u32` cannot truncate in practice: n is
    // range-checked against i32 just below, before the launch.
    let cfg = LaunchConfig {
        grid_dim: (n as u32, 1, 1),
        block_dim: (ROW_KERNEL_THREADS, 1, 1),
        shared_mem_bytes: 0,
    };
    let n_i32 = i32::try_from(n).map_err(|_| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident: n_rows={n} exceeds i32 range"),
    })?;
    let r_i32 = i32::try_from(r).map_err(|_| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident: r={r} exceeds i32 range"),
    })?;
    let p_h_i32 = i32::try_from(inputs.p_h).map_err(|_| GpuError::DriverCallFailed {
        reason: format!(
            "bms_flex_row device-resident: p_h={} exceeds i32 range",
            inputs.p_h
        ),
    })?;
    let p_w_i32 = i32::try_from(inputs.p_w).map_err(|_| GpuError::DriverCallFailed {
        reason: format!(
            "bms_flex_row device-resident: p_w={} exceeds i32 range",
            inputs.p_w
        ),
    })?;
    let s_f_val = inputs.s_f;

    // Positional argument list: the order below is the kernel's parameter
    // order. NOTE(review): the `bms_flex_row_kernel` signature lives in the
    // row-kernel source, not shown here; keep the two in sync.
    let mut builder = stream.launch_builder(&func);
    builder
        .arg(&n_i32)
        .arg(&r_i32)
        .arg(&p_h_i32)
        .arg(&p_w_i32)
        .arg(&s_f_val)
        .arg(&d_q)
        .arg(&d_b)
        .arg(&d_mu1)
        .arg(&d_mu2)
        .arg(&d_zobs)
        .arg(&d_y)
        .arg(&d_w)
        .arg(&d_offsets)
        .arg(&d_c0)
        .arg(&d_c1)
        .arg(&d_c2)
        .arg(&d_c3)
        .arg(&d_a)
        .arg(&d_aa)
        .arg(&d_r)
        .arg(&d_ar)
        .arg(&d_sbb)
        .arg(&d_sbh)
        .arg(&d_sbw)
        .arg(d_moments_ref)
        .arg(&d_chi)
        .arg(&d_xi)
        .arg(&d_rho)
        .arg(&d_tau)
        .arg(&d_ruv)
        .arg(&d_e_obs)
        .arg(&mut d_neglog)
        .arg(&mut d_grad)
        .arg(&mut d_hess);
    // SAFETY: the pushed arguments must match the kernel's C signature in
    // order and type (see the note above). Every buffer is a local that stays
    // alive until after the `synchronize()` below.
    unsafe { builder.launch(cfg) }.map_err(|err| GpuError::DriverCallFailed {
        reason: format!("bms_flex_row device-resident launch: {err}"),
    })?;
    // Wait for the kernel before any input buffer is released.
    stream
        .synchronize()
        .map_err(|err| GpuError::DriverCallFailed {
            reason: format!("bms_flex_row device-resident synchronize: {err}"),
        })?;

    // Release the outputs this path does not keep and the uploaded inputs.
    // `d_e_obs` and any uploaded moments buffer are freed at scope exit.
    drop(d_neglog);
    drop(d_grad);
    drop(d_q);
    drop(d_b);
    drop(d_mu1);
    drop(d_mu2);
    drop(d_zobs);
    drop(d_y);
    drop(d_w);
    drop(d_offsets);
    drop(d_c0);
    drop(d_c1);
    drop(d_c2);
    drop(d_c3);
    drop(d_a);
    drop(d_aa);
    drop(d_r);
    drop(d_ar);
    drop(d_sbb);
    drop(d_sbh);
    drop(d_sbw);
    drop(d_chi);
    drop(d_xi);
    drop(d_rho);
    drop(d_tau);
    drop(d_ruv);

    // Device bytes held by the returned storage: Hessians plus both designs.
    let bytes = ((n * r * r + marginal_design_row_major.len() + logslope_design_row_major.len())
        * std::mem::size_of::<f64>()) as u64;
    Ok(DeviceResidentRowHess {
        hess: d_hess,
        marginal_design: d_marginal,
        logslope_design: d_logslope,
        n,
        r,
        block,
        primary,
        bytes,
    })
}
2157
#[cfg(target_os = "linux")]
/// Which per-chunk partial kernel a `run_bms_flex_row_partial_reduce` launch
/// uses.
///
/// Small and `Copy`, so it also derives `Debug`/`PartialEq`/`Eq`/`Hash` for
/// logging, comparisons and use as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) enum BmsFlexRowLaunchMode {
    /// Hessian-vector product: the partial kernel takes a direction vector
    /// and the result stays on the device.
    HvpDeviceOut,
    /// Diagonal of the joint Hessian: no direction vector; the caller reads
    /// the result back to the host.
    DiagonalHostOut,
}
2170
2171#[cfg(target_os = "linux")]
2172impl BmsFlexRowLaunchMode {
2173 pub(crate) fn partial_kernel_name(self) -> &'static str {
2175 match self {
2176 BmsFlexRowLaunchMode::HvpDeviceOut => "bms_flex_row_hvp_partial",
2177 BmsFlexRowLaunchMode::DiagonalHostOut => "bms_flex_row_diag_partial",
2178 }
2179 }
2180}
2181
#[cfg(target_os = "linux")]
/// Scalar launch arguments shared by the `bms_flex_row_*_partial` kernels,
/// pre-converted to the `int` parameter types of their CUDA signatures.
pub(crate) struct PreparedBmsFlexRowLaunchArgs {
    /// Number of rows (kernel `n_rows`).
    pub(crate) n_i32: i32,
    /// Per-row primary dimension `r`.
    pub(crate) r_i32: i32,
    /// Width of the marginal design block (`p_m`).
    pub(crate) p_m_i32: i32,
    /// Width of the logslope design block (`p_g`).
    pub(crate) p_g_i32: i32,
    /// Length of the joint coefficient vector.
    pub(crate) p_total_i32: i32,
    /// First joint-vector index of the h block (0 when there is no h block).
    pub(crate) h_block_start: i32,
    /// Length of the h block (0 when there is no h block).
    pub(crate) h_block_len: i32,
    /// First joint-vector index of the w block (0 when there is no w block).
    pub(crate) w_block_start: i32,
    /// Length of the w block (0 when there is no w block).
    pub(crate) w_block_len: i32,
    /// Primary index of the first h entry (0 when absent).
    pub(crate) h_primary_start: i32,
    /// Primary index of the first w entry (0 when absent).
    pub(crate) w_primary_start: i32,
    /// Rows handled by each CTA of a partial kernel.
    pub(crate) rows_per_cta: i32,
    /// Number of CTAs, and of partial slices, in the partial launch.
    pub(crate) num_chunks: usize,
}
2203
2204#[cfg(target_os = "linux")]
2205impl PreparedBmsFlexRowLaunchArgs {
2206 pub(crate) fn from_storage(storage: &DeviceResidentRowHess) -> Self {
2207 let p_total = storage.block.p_total;
2208 let num_chunks = num_hvp_chunks(storage.n);
2209 PreparedBmsFlexRowLaunchArgs {
2210 n_i32: storage.n as i32,
2211 r_i32: storage.r as i32,
2212 p_m_i32: storage.block.p_m as i32,
2213 p_g_i32: storage.block.p_g as i32,
2214 p_total_i32: p_total as i32,
2215 h_block_start: storage
2216 .block
2217 .h
2218 .as_ref()
2219 .map(|r| r.start as i32)
2220 .unwrap_or(0),
2221 h_block_len: storage
2222 .block
2223 .h
2224 .as_ref()
2225 .map(|r| r.len() as i32)
2226 .unwrap_or(0),
2227 w_block_start: storage
2228 .block
2229 .w
2230 .as_ref()
2231 .map(|r| r.start as i32)
2232 .unwrap_or(0),
2233 w_block_len: storage
2234 .block
2235 .w
2236 .as_ref()
2237 .map(|r| r.len() as i32)
2238 .unwrap_or(0),
2239 h_primary_start: storage
2240 .primary
2241 .h
2242 .as_ref()
2243 .map(|r| r.start as i32)
2244 .unwrap_or(0),
2245 w_primary_start: storage
2246 .primary
2247 .w
2248 .as_ref()
2249 .map(|r| r.start as i32)
2250 .unwrap_or(0),
2251 rows_per_cta: HVP_ROWS_PER_CTA as i32,
2252 num_chunks,
2253 }
2254 }
2255}
2256
2257#[cfg(target_os = "linux")]
2271pub(crate) fn run_bms_flex_row_partial_reduce(
2272 storage: &DeviceResidentRowHess,
2273 mode: BmsFlexRowLaunchMode,
2274 d_v: Option<&CudaSlice<f64>>,
2275 d_out: &mut CudaSlice<f64>,
2276 ctx: &str,
2277) -> Result<(), GpuError> {
2278 let backend = HvpKernelBackend::probe()?;
2279 let stream = backend.stream.clone();
2280 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2281 let p_total = storage.block.p_total;
2282
2283 let mut d_partial = stream
2284 .alloc_zeros::<f64>(args.num_chunks * p_total)
2285 .map_err(|err| GpuError::DriverCallFailed {
2286 reason: format!("bms_flex_row {ctx} alloc partial: {err}"),
2287 })?;
2288
2289 let partial_kernel_name = mode.partial_kernel_name();
2290 let part_func = backend
2291 .module
2292 .load_function(partial_kernel_name)
2293 .map_err(|err| GpuError::DriverCallFailed {
2294 reason: format!("bms_flex_row {ctx} load {partial_kernel_name}: {err}"),
2295 })?;
2296 let red_func = backend
2297 .module
2298 .load_function("bms_flex_row_hvp_reduce")
2299 .map_err(|err| GpuError::DriverCallFailed {
2300 reason: format!("bms_flex_row {ctx} load reduce: {err}"),
2301 })?;
2302
2303 let cfg_part = LaunchConfig {
2304 grid_dim: (args.num_chunks as u32, 1, 1),
2305 block_dim: (HVP_THREADS, 1, 1),
2306 shared_mem_bytes: 0,
2307 };
2308 let mut builder = stream.launch_builder(&part_func);
2309 builder
2310 .arg(&args.n_i32)
2311 .arg(&args.r_i32)
2312 .arg(&args.p_m_i32)
2313 .arg(&args.p_g_i32)
2314 .arg(&args.p_total_i32)
2315 .arg(&args.h_block_start)
2316 .arg(&args.h_block_len)
2317 .arg(&args.w_block_start)
2318 .arg(&args.w_block_len)
2319 .arg(&args.h_primary_start)
2320 .arg(&args.w_primary_start)
2321 .arg(&args.rows_per_cta)
2322 .arg(&storage.hess)
2323 .arg(&storage.marginal_design)
2324 .arg(&storage.logslope_design);
2325 if let Some(d_v) = d_v {
2326 builder.arg(d_v);
2327 }
2328 builder.arg(&mut d_partial);
2329 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2337 reason: format!("bms_flex_row {ctx} partial launch: {err}"),
2338 })?;
2339
2340 let red_threads: u32 = REDUCTION_THREADS;
2341 let red_blocks: u32 = ((p_total as u32) + red_threads - 1) / red_threads;
2342 let cfg_red = LaunchConfig {
2343 grid_dim: (red_blocks, 1, 1),
2344 block_dim: (red_threads, 1, 1),
2345 shared_mem_bytes: 0,
2346 };
2347 let num_chunks_i32 = args.num_chunks as i32;
2348 let mut builder = stream.launch_builder(&red_func);
2349 builder
2350 .arg(&num_chunks_i32)
2351 .arg(&args.p_total_i32)
2352 .arg(&d_partial)
2353 .arg(d_out);
2354 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2358 reason: format!("bms_flex_row {ctx} reduce launch: {err}"),
2359 })?;
2360 drop(d_partial);
2363 Ok(())
2364}
2365
2366#[cfg(target_os = "linux")]
2376pub(crate) fn launch_bms_flex_row_host(
2377 storage: &DeviceResidentRowHess,
2378 mode: BmsFlexRowLaunchMode,
2379 v: Option<&[f64]>,
2380 ctx: &str,
2381) -> Result<Vec<f64>, GpuError> {
2382 let p_total = storage.block.p_total;
2383 if let Some(v) = v {
2384 if v.len() != p_total {
2385 return Err(GpuError::DriverCallFailed {
2386 reason: format!(
2387 "bms_flex_row {ctx}: v.len()={} != p_total={p_total}",
2388 v.len()
2389 ),
2390 });
2391 }
2392 }
2393
2394 let backend = HvpKernelBackend::probe()?;
2395 let stream = backend.stream.clone();
2396
2397 let d_v = match v {
2398 Some(v) => Some(
2399 stream
2400 .clone_htod(v)
2401 .map_err(|err| GpuError::DriverCallFailed {
2402 reason: format!("bms_flex_row {ctx} upload v: {err}"),
2403 })?,
2404 ),
2405 None => None,
2406 };
2407 let mut d_out =
2408 stream
2409 .alloc_zeros::<f64>(p_total)
2410 .map_err(|err| GpuError::DriverCallFailed {
2411 reason: format!("bms_flex_row {ctx} alloc out: {err}"),
2412 })?;
2413
2414 run_bms_flex_row_partial_reduce(storage, mode, d_v.as_ref(), &mut d_out, ctx)?;
2415
2416 stream
2417 .synchronize()
2418 .map_err(|err| GpuError::DriverCallFailed {
2419 reason: format!("bms_flex_row {ctx} synchronize: {err}"),
2420 })?;
2421 stream
2422 .clone_dtoh(&d_out)
2423 .map_err(|err| GpuError::DriverCallFailed {
2424 reason: format!("bms_flex_row {ctx} download out: {err}"),
2425 })
2426}
2427
2428#[cfg(target_os = "linux")]
2429pub(crate) fn validate_bms_flex_row_hvp_multi_shape(
2430 storage: &DeviceResidentRowHess,
2431 rhs_count: usize,
2432 v_rhs_len: usize,
2433 out_len: Option<usize>,
2434 ctx: &str,
2435) -> Result<usize, GpuError> {
2436 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2437 return Err(GpuError::DriverCallFailed {
2438 reason: format!(
2439 "bms_flex_row {ctx}: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2440 ),
2441 });
2442 }
2443 let p_total = storage.block.p_total;
2444 let rhs_elems = rhs_count
2445 .checked_mul(p_total)
2446 .ok_or_else(|| GpuError::DriverCallFailed {
2447 reason: format!(
2448 "bms_flex_row {ctx}: rhs_count({rhs_count})*p_total({p_total}) overflow"
2449 ),
2450 })?;
2451 if v_rhs_len != rhs_elems {
2452 return Err(GpuError::DriverCallFailed {
2453 reason: format!(
2454 "bms_flex_row {ctx}: v_rhs.len()={v_rhs_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2455 ),
2456 });
2457 }
2458 if let Some(out_len) = out_len
2459 && out_len != rhs_elems
2460 {
2461 return Err(GpuError::DriverCallFailed {
2462 reason: format!(
2463 "bms_flex_row {ctx}: out.len()={out_len} != rhs_count({rhs_count})*p_total({p_total})={rhs_elems}"
2464 ),
2465 });
2466 }
2467 Ok(rhs_elems)
2468}
2469
2470#[cfg(target_os = "linux")]
2474pub fn bms_flex_row_hvp_multi_scratch_bytes_for_shape(
2475 n: usize,
2476 p_total: usize,
2477 rhs_count: usize,
2478) -> Result<u64, GpuError> {
2479 if rhs_count == 0 || rhs_count > BMS_FLEX_ROW_HVP_MAX_RHS {
2480 return Err(GpuError::DriverCallFailed {
2481 reason: format!(
2482 "bms_flex_row hvp_multi_scratch_bytes: rhs_count={rhs_count} outside 1..={BMS_FLEX_ROW_HVP_MAX_RHS}"
2483 ),
2484 });
2485 }
2486 let num_chunks = num_hvp_chunks(n);
2487 let partial = rhs_count
2488 .checked_mul(num_chunks)
2489 .and_then(|v| v.checked_mul(p_total))
2490 .ok_or_else(|| GpuError::DriverCallFailed {
2491 reason: format!(
2492 "bms_flex_row hvp_multi_scratch_bytes: rhs_count({rhs_count})*num_chunks({num_chunks})*p_total({p_total}) overflow"
2493 ),
2494 })?;
2495 let rhs_vectors = rhs_count
2496 .checked_mul(p_total)
2497 .and_then(|v| v.checked_mul(2))
2498 .ok_or_else(|| GpuError::DriverCallFailed {
2499 reason: format!(
2500 "bms_flex_row hvp_multi_scratch_bytes: 2*rhs_count({rhs_count})*p_total({p_total}) overflow"
2501 ),
2502 })?;
2503 let elems = partial
2504 .checked_add(rhs_vectors)
2505 .ok_or_else(|| GpuError::DriverCallFailed {
2506 reason: "bms_flex_row hvp_multi_scratch_bytes: element count overflow".to_string(),
2507 })?;
2508 Ok((elems * std::mem::size_of::<f64>()) as u64)
2509}
2510
2511#[cfg(target_os = "linux")]
2512pub(crate) fn run_bms_flex_row_multi_partial_reduce(
2513 storage: &DeviceResidentRowHess,
2514 rhs_count: usize,
2515 d_v_rhs: &CudaSlice<f64>,
2516 d_out: &mut CudaSlice<f64>,
2517 ctx: &str,
2518) -> Result<(), GpuError> {
2519 let rhs_elems = validate_bms_flex_row_hvp_multi_shape(
2520 storage,
2521 rhs_count,
2522 d_v_rhs.len(),
2523 Some(d_out.len()),
2524 ctx,
2525 )?;
2526 let backend = HvpKernelBackend::probe()?;
2527 let stream = backend.stream.clone();
2528 let args = PreparedBmsFlexRowLaunchArgs::from_storage(storage);
2529 let p_total = storage.block.p_total;
2530 let partial_len = rhs_count
2531 .checked_mul(args.num_chunks)
2532 .and_then(|v| v.checked_mul(p_total))
2533 .ok_or_else(|| GpuError::DriverCallFailed {
2534 reason: format!(
2535 "bms_flex_row {ctx}: partial length overflow for rhs_count={rhs_count}, num_chunks={}, p_total={p_total}",
2536 args.num_chunks
2537 ),
2538 })?;
2539
2540 let mut d_partial =
2541 stream
2542 .alloc_zeros::<f64>(partial_len)
2543 .map_err(|err| GpuError::DriverCallFailed {
2544 reason: format!("bms_flex_row {ctx} alloc multi partial: {err}"),
2545 })?;
2546 let part_func = backend
2547 .module
2548 .load_function("bms_flex_row_hvp_multi_partial")
2549 .map_err(|err| GpuError::DriverCallFailed {
2550 reason: format!("bms_flex_row {ctx} load multi partial: {err}"),
2551 })?;
2552 let red_func = backend
2553 .module
2554 .load_function("bms_flex_row_hvp_multi_reduce")
2555 .map_err(|err| GpuError::DriverCallFailed {
2556 reason: format!("bms_flex_row {ctx} load multi reduce: {err}"),
2557 })?;
2558
2559 let rhs_count_i32 = i32::try_from(rhs_count).map_err(|_| GpuError::DriverCallFailed {
2560 reason: format!("bms_flex_row {ctx}: rhs_count={rhs_count} exceeds i32 range"),
2561 })?;
2562 let cfg_part = LaunchConfig {
2563 grid_dim: (args.num_chunks as u32, 1, 1),
2564 block_dim: (HVP_THREADS, 1, 1),
2565 shared_mem_bytes: 0,
2566 };
2567 let mut builder = stream.launch_builder(&part_func);
2568 builder
2569 .arg(&args.n_i32)
2570 .arg(&args.r_i32)
2571 .arg(&args.p_m_i32)
2572 .arg(&args.p_g_i32)
2573 .arg(&args.p_total_i32)
2574 .arg(&args.h_block_start)
2575 .arg(&args.h_block_len)
2576 .arg(&args.w_block_start)
2577 .arg(&args.w_block_len)
2578 .arg(&args.h_primary_start)
2579 .arg(&args.w_primary_start)
2580 .arg(&args.rows_per_cta)
2581 .arg(&rhs_count_i32)
2582 .arg(&storage.hess)
2583 .arg(&storage.marginal_design)
2584 .arg(&storage.logslope_design)
2585 .arg(d_v_rhs)
2586 .arg(&mut d_partial);
2587 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2592 reason: format!("bms_flex_row {ctx} multi partial launch: {err}"),
2593 })?;
2594
2595 let red_threads: u32 = REDUCTION_THREADS;
2596 let red_blocks: u32 = ((rhs_elems as u32) + red_threads - 1) / red_threads;
2597 let cfg_red = LaunchConfig {
2598 grid_dim: (red_blocks, 1, 1),
2599 block_dim: (red_threads, 1, 1),
2600 shared_mem_bytes: 0,
2601 };
2602 let num_chunks_i32 = args.num_chunks as i32;
2603 let mut builder = stream.launch_builder(&red_func);
2604 builder
2605 .arg(&num_chunks_i32)
2606 .arg(&args.p_total_i32)
2607 .arg(&rhs_count_i32)
2608 .arg(&d_partial)
2609 .arg(d_out);
2610 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2613 reason: format!("bms_flex_row {ctx} multi reduce launch: {err}"),
2614 })?;
2615 drop(d_partial);
2616 Ok(())
2617}
2618
2619#[cfg(target_os = "linux")]
2622pub(crate) fn launch_bms_flex_row_hvp_multi(
2623 storage: &DeviceResidentRowHess,
2624 v_rhs: &[f64],
2625 rhs_count: usize,
2626) -> Result<Vec<f64>, GpuError> {
2627 let rhs_elems =
2628 validate_bms_flex_row_hvp_multi_shape(storage, rhs_count, v_rhs.len(), None, "hvp_multi")?;
2629 let backend = HvpKernelBackend::probe()?;
2630 let stream = backend.stream.clone();
2631 let d_v_rhs = stream
2632 .clone_htod(v_rhs)
2633 .map_err(|err| GpuError::DriverCallFailed {
2634 reason: format!("bms_flex_row hvp_multi upload v_rhs: {err}"),
2635 })?;
2636 let mut d_out =
2637 stream
2638 .alloc_zeros::<f64>(rhs_elems)
2639 .map_err(|err| GpuError::DriverCallFailed {
2640 reason: format!("bms_flex_row hvp_multi alloc out: {err}"),
2641 })?;
2642 run_bms_flex_row_multi_partial_reduce(storage, rhs_count, &d_v_rhs, &mut d_out, "hvp_multi")?;
2643 stream
2644 .synchronize()
2645 .map_err(|err| GpuError::DriverCallFailed {
2646 reason: format!("bms_flex_row hvp_multi synchronize: {err}"),
2647 })?;
2648 stream
2649 .clone_dtoh(&d_out)
2650 .map_err(|err| GpuError::DriverCallFailed {
2651 reason: format!("bms_flex_row hvp_multi download out: {err}"),
2652 })
2653}
2654
2655#[cfg(target_os = "linux")]
2666pub(crate) fn launch_bms_flex_row_hvp_into_device(
2667 storage: &DeviceResidentRowHess,
2668 d_v: &CudaSlice<f64>,
2669 d_out: &mut CudaSlice<f64>,
2670) -> Result<(), GpuError> {
2671 let p_total = storage.block.p_total;
2672 if d_v.len() != p_total {
2673 return Err(GpuError::DriverCallFailed {
2674 reason: format!(
2675 "bms_flex_row hvp_into_device: d_v.len()={} != p_total={}",
2676 d_v.len(),
2677 p_total
2678 ),
2679 });
2680 }
2681 if d_out.len() != p_total {
2682 return Err(GpuError::DriverCallFailed {
2683 reason: format!(
2684 "bms_flex_row hvp_into_device: d_out.len()={} != p_total={}",
2685 d_out.len(),
2686 p_total
2687 ),
2688 });
2689 }
2690 run_bms_flex_row_partial_reduce(
2694 storage,
2695 BmsFlexRowLaunchMode::HvpDeviceOut,
2696 Some(d_v),
2697 d_out,
2698 "hvp_into_device",
2699 )
2700}
2701
2702#[cfg(target_os = "linux")]
2705pub(crate) fn launch_bms_flex_row_hvp(
2706 storage: &DeviceResidentRowHess,
2707 v: &[f64],
2708) -> Result<Vec<f64>, GpuError> {
2709 launch_bms_flex_row_hvp_multi(storage, v, 1)
2710}
2711
2712#[cfg(target_os = "linux")]
2715pub(crate) fn launch_bms_flex_row_diagonal(
2716 storage: &DeviceResidentRowHess,
2717) -> Result<Vec<f64>, GpuError> {
2718 launch_bms_flex_row_host(storage, BmsFlexRowLaunchMode::DiagonalHostOut, None, "diag")
2719}
2720
/// Largest `p_total` accepted by `launch_bms_flex_row_dense_block`.
///
/// The dense-block partial launch requests `p_total² * 8` bytes of dynamic
/// shared memory per CTA, one f64 accumulator entry per `(u, v)` pair. At 72
/// that is 5184 * 8 = 41 472 bytes. The launcher's error message justifies the
/// cap against a 48 KiB per-block limit.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_MAX_P: usize = 72;
2728
/// Rows assigned to each CTA of the dense-block partial kernel.
///
/// The launcher uses `n.div_ceil(DENSE_BLOCK_ROWS_PER_CTA)` CTAs and passes
/// this value to the kernel as its `rows_per_cta` argument.
#[cfg(target_os = "linux")]
pub(crate) const DENSE_BLOCK_ROWS_PER_CTA: u32 = 32;
2736
2737#[cfg(target_os = "linux")]
2754pub fn launch_bms_flex_row_dense_block(
2755 storage: &DeviceResidentRowHess,
2756) -> Result<Vec<f64>, GpuError> {
2757 let p_total = storage.block.p_total;
2758 if p_total == 0 {
2759 return Err(GpuError::DriverCallFailed {
2760 reason: "bms_flex_row dense_block: p_total must be > 0".to_string(),
2761 });
2762 }
2763 if p_total > DENSE_BLOCK_MAX_P {
2764 return Err(GpuError::DriverCallFailed {
2765 reason: format!(
2766 "bms_flex_row dense_block: p_total={p_total} exceeds DENSE_BLOCK_MAX_P={DENSE_BLOCK_MAX_P} \
2767 (per-CTA shmem accumulator p²*8 bytes would exceed V100's 48 KiB/block)"
2768 ),
2769 });
2770 }
2771 let backend = HvpKernelBackend::probe()?;
2772 let stream = backend.stream.clone();
2773 let n = storage.n;
2774 let r = storage.r;
2775 let rows_per_cta = DENSE_BLOCK_ROWS_PER_CTA as usize;
2776 let num_chunks = n.div_ceil(rows_per_cta);
2777 let pp = p_total * p_total;
2778
2779 let mut d_partial =
2780 stream
2781 .alloc_zeros::<f64>(num_chunks * pp)
2782 .map_err(|err| GpuError::DriverCallFailed {
2783 reason: format!("bms_flex_row dense_block alloc partial: {err}"),
2784 })?;
2785 let mut d_out = stream
2786 .alloc_zeros::<f64>(pp)
2787 .map_err(|err| GpuError::DriverCallFailed {
2788 reason: format!("bms_flex_row dense_block alloc out: {err}"),
2789 })?;
2790
2791 let part_func = backend
2792 .module
2793 .load_function("bms_flex_row_dense_block_partial")
2794 .map_err(|err| GpuError::DriverCallFailed {
2795 reason: format!("bms_flex_row dense_block load partial: {err}"),
2796 })?;
2797 let red_func = backend
2798 .module
2799 .load_function("bms_flex_row_dense_block_reduce")
2800 .map_err(|err| GpuError::DriverCallFailed {
2801 reason: format!("bms_flex_row dense_block load reduce: {err}"),
2802 })?;
2803
2804 let n_i32 = n as i32;
2805 let r_i32 = r as i32;
2806 let p_m_i32 = storage.block.p_m as i32;
2807 let p_g_i32 = storage.block.p_g as i32;
2808 let p_total_i32 = p_total as i32;
2809 let h_block_start = storage
2810 .block
2811 .h
2812 .as_ref()
2813 .map(|r| r.start as i32)
2814 .unwrap_or(0);
2815 let h_block_len = storage
2816 .block
2817 .h
2818 .as_ref()
2819 .map(|r| r.len() as i32)
2820 .unwrap_or(0);
2821 let w_block_start = storage
2822 .block
2823 .w
2824 .as_ref()
2825 .map(|r| r.start as i32)
2826 .unwrap_or(0);
2827 let w_block_len = storage
2828 .block
2829 .w
2830 .as_ref()
2831 .map(|r| r.len() as i32)
2832 .unwrap_or(0);
2833 let h_primary_start = storage
2834 .primary
2835 .h
2836 .as_ref()
2837 .map(|r| r.start as i32)
2838 .unwrap_or(0);
2839 let w_primary_start = storage
2840 .primary
2841 .w
2842 .as_ref()
2843 .map(|r| r.start as i32)
2844 .unwrap_or(0);
2845 let rows_per_cta_i32 = DENSE_BLOCK_ROWS_PER_CTA as i32;
2846 let num_chunks_u32 = num_chunks as u32;
2847
2848 let shmem_bytes: u32 =
2850 u32::try_from(pp * std::mem::size_of::<f64>()).map_err(|_| GpuError::DriverCallFailed {
2851 reason: format!("dense_block shmem bytes overflow u32 for p_total={p_total}"),
2852 })?;
2853
2854 let cfg_part = LaunchConfig {
2855 grid_dim: (num_chunks_u32, 1, 1),
2856 block_dim: (HVP_THREADS, 1, 1),
2857 shared_mem_bytes: shmem_bytes,
2858 };
2859 let mut builder = stream.launch_builder(&part_func);
2860 builder
2861 .arg(&n_i32)
2862 .arg(&r_i32)
2863 .arg(&p_m_i32)
2864 .arg(&p_g_i32)
2865 .arg(&p_total_i32)
2866 .arg(&h_block_start)
2867 .arg(&h_block_len)
2868 .arg(&w_block_start)
2869 .arg(&w_block_len)
2870 .arg(&h_primary_start)
2871 .arg(&w_primary_start)
2872 .arg(&rows_per_cta_i32)
2873 .arg(&storage.hess)
2874 .arg(&storage.marginal_design)
2875 .arg(&storage.logslope_design)
2876 .arg(&mut d_partial);
2877 unsafe { builder.launch(cfg_part) }.map_err(|err| GpuError::DriverCallFailed {
2881 reason: format!("bms_flex_row dense_block partial launch: {err}"),
2882 })?;
2883
2884 let red_threads: u32 = REDUCTION_THREADS;
2885 let red_blocks: u32 = ((pp as u32) + red_threads - 1) / red_threads;
2886 let cfg_red = LaunchConfig {
2887 grid_dim: (red_blocks, 1, 1),
2888 block_dim: (red_threads, 1, 1),
2889 shared_mem_bytes: 0,
2890 };
2891 let num_chunks_i32 = num_chunks as i32;
2892 let mut builder = stream.launch_builder(&red_func);
2893 builder
2894 .arg(&num_chunks_i32)
2895 .arg(&p_total_i32)
2896 .arg(&d_partial)
2897 .arg(&mut d_out);
2898 unsafe { builder.launch(cfg_red) }.map_err(|err| GpuError::DriverCallFailed {
2900 reason: format!("bms_flex_row dense_block reduce launch: {err}"),
2901 })?;
2902 stream
2903 .synchronize()
2904 .map_err(|err| GpuError::DriverCallFailed {
2905 reason: format!("bms_flex_row dense_block sync: {err}"),
2906 })?;
2907 stream
2908 .clone_dtoh(&d_out)
2909 .map_err(|err| GpuError::DriverCallFailed {
2910 reason: format!("bms_flex_row dense_block download: {err}"),
2911 })
2912}
2913
2914#[cfg(test)]
2926mod oracle_parity_tests {
2927 use super::*;
2928
    /// `1 / (2π)`: normalisation applied to every moment contraction in
    /// `cpu_oracle_outputs` (the `t` weights and the `d_of` functional).
    pub(crate) const ORACLE_INV_TWO_PI: f64 = 1.0 / std::f64::consts::TAU;
    /// `√2`: maps a probit argument `x` to the `erfc` argument `x / √2`.
    pub(crate) const ORACLE_SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// `1 / √(2π)`: standard-normal density prefactor, used for the Mills ratio.
    pub(crate) const ORACLE_INV_SQRT_2PI: f64 = 0.398_942_280_401_432_7;
2946
    /// Host reference for the scaled complementary error function
    /// `erfcx(x) = exp(x²) · erfc(x)` on the non-negative half-line.
    ///
    /// * `+∞` returns `0.0`.
    /// * `-∞` and NaN return `+∞`, because NaN fails the `x > 0.0` test.
    /// * `x <= 0` returns `1.0`. That is exact at 0; negative inputs are clamped
    ///   rather than evaluated.
    /// * `0 < x < 26` returns `exp(x²) · erfc(x)` directly.
    /// * `x >= 26` uses the five-term asymptotic series
    ///   `1/(x√π) · (1 − 1/(2x²) + 3/(4x⁴) − 15/(8x⁶) + 105/(16x⁸))`.
    ///
    /// NOTE(review): this presumably mirrors the device-side erfcx so the
    /// oracle reproduces kernel numerics. Confirm against the CUDA source
    /// before changing any branch, constant or evaluation order.
    pub(crate) fn oracle_erfcx_nonnegative(x: f64) -> f64 {
        if !x.is_finite() {
            return if x > 0.0 { 0.0 } else { f64::INFINITY };
        }
        if x <= 0.0 {
            return 1.0;
        }
        if x < 26.0 {
            let mut xx = x * x;
            // Guards exp() against overflow (ln(f64::MAX) ≈ 709.8). On this
            // branch x < 26, so x² < 676 and the clamp never triggers; it is
            // purely defensive.
            if xx > 700.0 {
                xx = 700.0;
            }
            return xx.exp() * gam_gpu::numerics_host::erfc(x);
        }
        let inv = 1.0 / x;
        let inv2 = inv * inv;
        // Alternating coefficients (2k−1)!!/2^k, k = 0..4: 1, 1/2, 3/4, 15/8, 105/16.
        let poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 * inv2 * inv2
            + 6.5625 * inv2 * inv2 * inv2 * inv2;
        // 1/√π.
        let inv_sqrt_pi: f64 = 0.564_189_583_547_756_3;
        inv * poly * inv_sqrt_pi
    }
2968
    /// Returns `(ln Φ(x), φ(x) / Φ(x))`: the log standard-normal CDF and the
    /// inverse Mills ratio `λ(x)`, as used by the probit likelihood in
    /// `cpu_oracle_outputs`.
    ///
    /// * `x = +∞` returns `(0, 0)`.
    /// * `x = −∞` returns `(−∞, +∞)`.
    /// * NaN propagates to both components.
    /// * `x >= −37`: direct evaluation of `Φ(x) = ½ erfc(−x/√2)`. The value is
    ///   clamped to `[1e-300, 1]` before taking the log and forming `φ/Φ`.
    /// * `x < −37`: let `u = −x/√2`, so `Φ(x) = ½ erfcx(u) e^{−u²}`. Then
    ///   `ln Φ = −u² + ln(½ erfcx(u))` and `λ = √(2/π) / erfcx(u)`. Both stay
    ///   finite where the direct CDF would underflow.
    pub(crate) fn oracle_log_ndtr_and_mills(x: f64) -> (f64, f64) {
        if x == f64::INFINITY {
            return (0.0, 0.0);
        }
        if x == f64::NEG_INFINITY {
            return (f64::NEG_INFINITY, f64::INFINITY);
        }
        if x.is_nan() {
            return (x, x);
        }
        // Left-tail switch point. Φ(−37) ≈ 6e-300, just above the 1e-300 floor
        // that the direct branch clamps to.
        const ORACLE_LEFT_TAIL_X: f64 = -37.0;
        if x >= ORACLE_LEFT_TAIL_X {
            let mut cdf = 0.5 * gam_gpu::numerics_host::erfc(-x / ORACLE_SQRT_2);
            // Keep ln() finite and the ratio bounded.
            if cdf < 1e-300 {
                cdf = 1e-300;
            }
            if cdf > 1.0 {
                cdf = 1.0;
            }
            let pdf = ORACLE_INV_SQRT_2PI * (-0.5 * x * x).exp();
            (cdf.ln(), pdf / cdf)
        } else {
            let u = -x / ORACLE_SQRT_2;
            let mut ex = oracle_erfcx_nonnegative(u);
            if ex < 1e-300 {
                ex = 1e-300;
            }
            let log_cdf = -u * u + (0.5 * ex).ln();
            // √(2/π) = 2/√(2π): the φ/Φ ratio once the e^{−u²} factors cancel.
            let sqrt_2_over_pi: f64 = 0.797_884_560_802_865_4;
            (log_cdf, sqrt_2_over_pi / ex)
        }
    }
3013
    /// Host-only CPU reference for the bms_flex_row row kernel.
    ///
    /// For each of `inputs.n_rows` rows, it rebuilds three quantities from the
    /// packed per-cell cubic coefficients and moments:
    /// * the negative log-likelihood;
    /// * its gradient (`r` entries);
    /// * its full symmetric Hessian (`r × r`).
    ///
    /// They are returned flattened row-major as `neglog[n]`, `grad[n * r]`
    /// and `hess[n * r * r]`.
    ///
    /// Rows whose accumulated `f_a` is non-finite or `<= 0` are emitted as NaN
    /// in all three outputs.
    ///
    /// # Panics
    /// * If `inputs.cell_moments` is device-resident; the oracle reads host
    ///   memory only.
    /// * On out-of-range indexing, if the packed buffers are shorter than the
    ///   layout implied by `n_rows`, `r`, `p_h`, `p_w` and `cell_offsets`.
    pub(crate) fn cpu_oracle_outputs(
        inputs: &BmsFlexRowKernelInputs<'_>,
    ) -> BmsFlexRowKernelOutputs {
        let n = inputs.n_rows;
        let r = inputs.r;
        let p_h = inputs.p_h;
        let p_w = inputs.p_w;
        let mut neglog = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * r];
        let mut hess = vec![0.0_f64; n * r * r];
        let cell_moments_host = match &inputs.cell_moments {
            CellMomentsSource::Host(slice) => *slice,
            #[cfg(target_os = "linux")]
            CellMomentsSource::Device(_) => panic!(
                "cpu_oracle_outputs: cell_moments is device-resident; oracle \
                 is a host-only sanity checker"
            ),
        };

        for row in 0..n {
            // Per-row accumulators over the row's cells:
            // * f_a and f_aa are scalars;
            // * f_u and f_au are indexed by primary coordinate u;
            // * f_uv fills only the upper triangle (v >= u) of an r×r matrix.
            let mut f_u = vec![0.0_f64; r];
            let mut f_au = vec![0.0_f64; r];
            let mut f_uv = vec![0.0_f64; r * r];
            let mut f_a = 0.0_f64;
            let mut f_aa = 0.0_f64;

            // CSR-style offsets: this row owns cells cell_lo..cell_hi.
            let cell_lo = inputs.cell_offsets[row] as usize;
            let cell_hi = inputs.cell_offsets[row + 1] as usize;
            for c in cell_lo..cell_hi {
                let c_arr = [
                    inputs.cell_c0[c],
                    inputs.cell_c1[c],
                    inputs.cell_c2[c],
                    inputs.cell_c3[c],
                ];
                let m = &cell_moments_host[c * MOMENT_STRIDE..(c + 1) * MOMENT_STRIDE];

                // t[k] = (1/2π) Σ_e c_e · m[e + k] for k = 0..6: the cell's moments
                // weighted by its cubic c0..c3. Indices reach m[9], inside
                // MOMENT_STRIDE.
                let mut t = [0.0_f64; 7];
                for (n_idx, t_slot) in t.iter_mut().enumerate() {
                    let mut acc = 0.0_f64;
                    for (e, c_e) in c_arr.iter().enumerate() {
                        acc = c_e.mul_add(m[e + n_idx], acc);
                    }
                    *t_slot = acc * ORACLE_INV_TWO_PI;
                }

                // d_of(ρ) = (1/2π) Σ_{k<4} ρ_k m_k: a linear functional of a
                // 4-coefficient vector.
                let d_of = |r_arr: &[f64]| -> f64 {
                    ORACLE_INV_TWO_PI
                        * (r_arr[0] * m[0] + r_arr[1] * m[1] + r_arr[2] * m[2] + r_arr[3] * m[3])
                };
                // q_of(ρ, σ) = Σ_k (ρ ⋆ σ)_k · t_k. The coefficients of the
                // product of two cubics (degree <= 6) are paired with t.
                let q_of = |r_arr: &[f64], s_arr: &[f64]| -> f64 {
                    (r_arr[0] * s_arr[0]) * t[0]
                        + (r_arr[0] * s_arr[1] + r_arr[1] * s_arr[0]) * t[1]
                        + (r_arr[0] * s_arr[2] + r_arr[1] * s_arr[1] + r_arr[2] * s_arr[0]) * t[2]
                        + (r_arr[0] * s_arr[3]
                            + r_arr[1] * s_arr[2]
                            + r_arr[2] * s_arr[1]
                            + r_arr[3] * s_arr[0])
                            * t[3]
                        + (r_arr[1] * s_arr[3] + r_arr[2] * s_arr[2] + r_arr[3] * s_arr[1]) * t[4]
                        + (r_arr[2] * s_arr[3] + r_arr[3] * s_arr[2]) * t[5]
                        + (r_arr[3] * s_arr[3]) * t[6]
                };

                // cell_a / cell_aa hold 4 coefficients per cell.
                let a_c = &inputs.cell_a[c * 4..(c + 1) * 4];
                let aa_c = &inputs.cell_aa[c * 4..(c + 1) * 4];
                f_a += d_of(a_c);
                f_aa += d_of(aa_c) - q_of(a_c, a_c);

                // cell_r / cell_ar hold 4 coefficients per (cell, u) for
                // u = 1..r-1, at offset (c * (r - 1) + (u - 1)) * 4.
                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    let ar_u = &inputs.cell_ar[r_u_off..r_u_off + 4];
                    f_u[u] += d_of(r_u);
                    f_au[u] += d_of(ar_u) - q_of(a_c, r_u);
                }

                for u in 1..r {
                    let r_u_off = (c * (r - 1) + (u - 1)) * 4;
                    let r_u = &inputs.cell_r[r_u_off..r_u_off + 4];
                    for v in u..r {
                        let r_v_off = (c * (r - 1) + (v - 1)) * 4;
                        let r_v = &inputs.cell_r[r_v_off..r_v_off + 4];
                        let q_uv = q_of(r_u, r_v);
                        // Only pairs with u == 1 carry an extra second-order
                        // source term:
                        // * (1, 1) from cell_sbb;
                        // * (1, 2..2+p_h) from cell_sbh;
                        // * (1, 2+p_h..r) from cell_sbw.
                        // Every other pair contributes 0 here.
                        let d_s = if u == 1 && v == 1 {
                            let s_bb = &inputs.cell_sbb[c * 4..(c + 1) * 4];
                            d_of(s_bb)
                        } else if u == 1 && v >= 2 && v < 2 + p_h {
                            let j = v - 2;
                            let off = (c * p_h + j) * 4;
                            let s_bh = &inputs.cell_sbh[off..off + 4];
                            d_of(s_bh)
                        } else if u == 1 && v >= 2 + p_h && v < r {
                            let l = v - (2 + p_h);
                            let off = (c * p_w + l) * 4;
                            let s_bw = &inputs.cell_sbw[off..off + 4];
                            d_of(s_bw)
                        } else {
                            0.0
                        };
                        f_uv[u * r + v] += d_s - q_uv;
                    }
                }
            }

            // Coordinate 0 is not accumulated from cells:
            // * f_u[0] = -mu_1 and f_au[0] = 0;
            // * row and column 0 of f_uv are zeroed, except f_uv[0][0] = -mu_2.
            let mu_1 = inputs.mu_1[row];
            let mu_2 = inputs.mu_2[row];
            f_u[0] = -mu_1;
            f_au[0] = 0.0;
            for v in 0..r {
                f_uv[v] = 0.0;
                f_uv[v * r] = 0.0;
            }
            f_uv[0] = -mu_2;

            // A non-finite or non-positive f_a poisons the whole row with NaN.
            if !f_a.is_finite() || f_a <= 0.0 {
                neglog[row] = f64::NAN;
                for slot in grad[row * r..(row + 1) * r].iter_mut() {
                    *slot = f64::NAN;
                }
                for slot in hess[row * r * r..(row + 1) * r * r].iter_mut() {
                    *slot = f64::NAN;
                }
                continue;
            }
            let inv_fa = 1.0 / f_a;

            // a_u = -f_u / f_a. For u = 0 this is mu_1 / f_a, because
            // f_u[0] = -mu_1.
            let mut a_u = vec![0.0_f64; r];
            a_u[0] = mu_1 * inv_fa;
            for u in 1..r {
                a_u[u] = -f_u[u] * inv_fa;
            }
            // a_uv = -(f_uv + f_au[v] a_u + f_au[u] a_v + f_aa a_u a_v) / f_a,
            // stored symmetrically.
            // NOTE(review): this reads as second-order implicit
            // differentiation of the cell constraint in `a` — confirm.
            let mut a_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let term = f_uv[u * r + v]
                        + f_au[v] * a_u[u]
                        + f_au[u] * a_u[v]
                        + f_aa * a_u[u] * a_u[v];
                    let val = -term * inv_fa;
                    a_uv[u * r + v] = val;
                    a_uv[v * r + u] = val;
                }
            }

            // Chain rule into the observed linear predictor e:
            //   ē_u  = χ a_u + ρ_u
            //   ē_uv = χ a_uv + ξ a_u a_v + τ_u a_v + a_u τ_v + R_uv   (symmetric)
            let chi = inputs.chi_obs[row];
            let xi = inputs.xi_obs[row];
            let rho = &inputs.rho_u[row * r..(row + 1) * r];
            let tau = &inputs.tau_u[row * r..(row + 1) * r];
            let ruv = &inputs.r_uv[row * r * r..(row + 1) * r * r];
            let mut bar_e_u = vec![0.0_f64; r];
            for u in 0..r {
                bar_e_u[u] = chi * a_u[u] + rho[u];
            }
            let mut bar_e_uv = vec![0.0_f64; r * r];
            for u in 0..r {
                for v in u..r {
                    let val = chi * a_uv[u * r + v]
                        + xi * a_u[u] * a_u[v]
                        + tau[u] * a_u[v]
                        + a_u[u] * tau[v]
                        + ruv[u * r + v];
                    bar_e_uv[u * r + v] = val;
                    if u != v {
                        bar_e_uv[v * r + u] = val;
                    }
                }
            }

            // Weighted probit likelihood. s = 2y − 1 maps y ∈ {0, 1} to ∓1, and
            // m = s · e_obs.
            //   neglog = −w ln Φ(m)
            //   grad   = a_i ē_u,                  a_i = −w s λ(m)
            //   hess   = b_i ē_u ē_v + a_i ē_uv,    b_i = w λ(m)(m + λ(m))
            // These use d/dm[−ln Φ] = −λ and d²/dm²[−ln Φ] = λ(m + λ).
            let y = inputs.y[row];
            let w = inputs.w[row];
            let s = 2.0 * y - 1.0;
            let e_obs = inputs.e_obs[row];
            let m_arg = s * e_obs;
            let (log_cdf, lambda) = oracle_log_ndtr_and_mills(m_arg);
            let a_i = -w * s * lambda;
            let b_i = w * lambda * (m_arg + lambda);
            neglog[row] = -w * log_cdf;
            for u in 0..r {
                grad[row * r + u] = a_i * bar_e_u[u];
            }
            for u in 0..r {
                for v in u..r {
                    let val = b_i * bar_e_u[u] * bar_e_u[v] + a_i * bar_e_uv[u * r + v];
                    hess[row * r * r + u * r + v] = val;
                    if u != v {
                        hess[row * r * r + v * r + u] = val;
                    }
                }
            }
        }

        BmsFlexRowKernelOutputs { neglog, grad, hess }
    }
3228
3229 mod parity_415 {
3239 use super::cpu_oracle_outputs;
3256 use crate::bms::family::*;
3257 use crate::bms::hessian_paths::*;
3258 use crate::bms::{exact_kernel, DeviationBlockConfig, LatentMeasureKind};
3259 use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix};
3260 use gam_problem::{InverseLink, ParameterBlockState, StandardLink};
3261 use ndarray::{Array1, Array2};
3262 use std::sync::{Arc, Mutex};
3263
        /// Builds the deterministic full-flex fixture for the #415 parity lock.
        ///
        /// The family has `n` rows and uses:
        /// * a probit base link and a `StandardNormal` latent measure;
        /// * Gaussian frailty with sd 0.15;
        /// * two-column (intercept + covariate) marginal and logslope designs;
        /// * score-warp and link-deviation blocks, each with 3 internal knots.
        ///
        /// Labels, weights, `z` and design columns come from fixed modular
        /// formulas, so the fixture is reproducible.
        ///
        /// Returns the family plus block states in the order
        /// `[marginal, logslope, score-warp h, link-dev w]`. The `h` and `w`
        /// states have small fixed betas and zero `eta`.
        fn make_flex_parity_family(
            n: usize,
        ) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
            // Knot seeds use at least 6 points even for tiny n.
            let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
            let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
            let cfg = DeviationBlockConfig {
                num_internal_knots: 3,
                ..DeviationBlockConfig::default()
            };
            let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
                .expect("build score warp block");
            let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
                &link_seed, &link_seed, &cfg,
            )
            .expect("build link deviation block");

            // Deterministic 0/1 labels and weights in [0.75, 0.95].
            let y: Array1<f64> =
                Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
            let weights: Array1<f64> =
                Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
            // Midpoint grid on (-1.7, 1.7).
            let z: Array1<f64> =
                Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
            // Column 0 is the intercept; column 1 is a permuted covariate.
            let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
                if j == 0 {
                    1.0
                } else {
                    -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
                }
            });
            let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
                if j == 0 {
                    1.0
                } else {
                    0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
                }
            });

            let family = BernoulliMarginalSlopeFamily {
                y: Arc::new(y),
                weights: Arc::new(weights),
                z: Arc::new(z.clone()),
                latent_measure: LatentMeasureKind::StandardNormal,
                gaussian_frailty_sd: Some(0.15),
                base_link: InverseLink::Standard(StandardLink::Probit),
                marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
                logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
                score_warp: Some(score_prepared.runtime.clone()),
                link_dev: Some(link_prepared.runtime.clone()),
                policy: gam_runtime::resource::ResourcePolicy::default_library(),
                cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
                cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
                intercept_warm_starts: None,
                auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
                auto_subsample_last_rho: Arc::new(Mutex::new(None)),
            };

            let beta_m = Array1::from_vec(vec![0.12, -0.04]);
            let beta_g = Array1::from_vec(vec![0.35, 0.03]);
            let beta_h = Array1::from_iter(
                (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
            );
            let beta_w = Array1::from_iter(
                (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
            );
            let states = vec![
                ParameterBlockState {
                    eta: marginal_x.dot(&beta_m),
                    beta: beta_m,
                },
                ParameterBlockState {
                    eta: logslope_x.dot(&beta_g),
                    beta: beta_g,
                },
                ParameterBlockState {
                    beta: beta_h,
                    eta: Array1::zeros(z.len()),
                },
                ParameterBlockState {
                    beta: beta_w,
                    eta: Array1::zeros(z.len()),
                },
            ];
            (family, states)
        }
3354
        /// #415 parity lock: the host oracle `cpu_oracle_outputs`, fed the packed
        /// kernel inputs, must agree with the CPU family's
        /// `compute_row_analytic_flex_into_with_moments`.
        ///
        /// The comparison runs on every row of a 12-row full-flex fixture, for
        /// the value, the gradient and the full Hessian. The per-entry tolerance
        /// is `1e-9 + 1e-10·|oracle|`. NaN rows must be NaN in both paths. Both
        /// y = 0 and y = 1 rows must be exercised.
        #[test]
        fn cpu_oracle_matches_cpu_family_row_analytic_flex_415() {
            let n = 12usize;
            let (family, states) = make_flex_parity_family(n);
            let cache = family
                .build_exact_eval_cache(&states)
                .expect("flex exact eval cache");

            assert!(
                cache.row_cell_moments.is_some(),
                "#415 fixture must materialise the row-cell-moments bundle; the pack \
                 and both compared paths read it"
            );
            // Primary layout: [q, b, h-block (p_h), w-block (p_w)].
            let primary = &cache.primary;
            let r = primary.total;
            let p_h = primary.h.as_ref().map(|range| range.len()).unwrap_or(0);
            let p_w = primary.w.as_ref().map(|range| range.len()).unwrap_or(0);
            assert!(p_h > 0 && p_w > 0, "#415 fixture must be full-flex: p_h={p_h} p_w={p_w}");
            assert_eq!(r, 2 + p_h + p_w, "#415 fixture primary layout");

            let owned = family
                .pack_bms_flex_row_kernel_inputs(&states, &cache)
                .expect("pack must not error")
                .expect("pack must succeed for the StandardNormal full-flex fixture");
            let inputs = owned.as_borrowed();
            let oracle = cpu_oracle_outputs(&inputs);
            assert_eq!(oracle.neglog.len(), n);
            assert_eq!(oracle.grad.len(), n * r);
            assert_eq!(oracle.hess.len(), n * r * r);

            let tol_abs = 1e-9_f64;
            let tol_rel = 1e-10_f64;

            let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(r);
            // Largest |Δ| / max(|oracle|, 1) seen; reported, not asserted.
            let mut max_rel = 0.0_f64;
            let mut checked_labels = [false, false];

            for row in 0..n {
                let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
                let row_moments = cache
                    .row_cell_moments
                    .as_ref()
                    .and_then(|bundle| bundle.row(row, 9));
                assert!(
                    row_moments.is_some(),
                    "row {row} must carry degree-9 cell moments (the oracle reads them)"
                );
                let label = family.y[row] as usize;
                if label < 2 {
                    checked_labels[label] = true;
                }

                // Fills scratch.grad / scratch.hess for this row and returns the value.
                let value = family
                    .compute_row_analytic_flex_into_with_moments(
                        row,
                        &states,
                        primary,
                        row_ctx,
                        row_moments,
                        cache.cell_family_forest.as_ref(),
                        true,
                        &mut scratch,
                    )
                    .expect("cpu family row analytic flex");

                let o_val = oracle.neglog[row];
                if o_val.is_nan() || value.is_nan() {
                    assert!(
                        o_val.is_nan() && value.is_nan(),
                        "row {row}: NaN parity broke — oracle={o_val} family={value}"
                    );
                    continue;
                }
                let vd = (o_val - value).abs();
                let vtol = tol_abs + tol_rel * o_val.abs();
                max_rel = max_rel.max(vd / o_val.abs().max(1.0));
                assert!(
                    vd <= vtol,
                    "row {row} value drift: oracle={o_val:.17e} family={value:.17e} \
                     |Δ|={vd:.3e} > tol={vtol:.3e}"
                );

                for u in 0..r {
                    let o_g = oracle.grad[row * r + u];
                    let f_g = scratch.grad[u];
                    let gd = (o_g - f_g).abs();
                    let gtol = tol_abs + tol_rel * o_g.abs();
                    max_rel = max_rel.max(gd / o_g.abs().max(1.0));
                    assert!(
                        gd <= gtol,
                        "row {row} grad[{u}] drift: oracle={o_g:.17e} family={f_g:.17e} \
                         |Δ|={gd:.3e} > tol={gtol:.3e}"
                    );
                }

                // Full r×r comparison, both triangles.
                for u in 0..r {
                    for v in 0..r {
                        let o_h = oracle.hess[row * r * r + u * r + v];
                        let f_h = scratch.hess[[u, v]];
                        let hd = (o_h - f_h).abs();
                        let htol = tol_abs + tol_rel * o_h.abs();
                        max_rel = max_rel.max(hd / o_h.abs().max(1.0));
                        assert!(
                            hd <= htol,
                            "row {row} hess[{u},{v}] drift: oracle={o_h:.17e} \
                             family={f_h:.17e} |Δ|={hd:.3e} > tol={htol:.3e}"
                        );
                    }
                }
            }

            assert!(
                checked_labels[0] && checked_labels[1],
                "#415 fixture must exercise both y=0 and y=1 rows: {checked_labels:?}"
            );
            eprintln!(
                "#415 parity lock: n={n} r={r} p_h={p_h} p_w={p_w} max_rel(oracle−family)={max_rel:.3e}"
            );
        }
3490 }
3491}
3492
3493#[cfg(all(test, target_os = "linux"))]
3494mod tests {
3495 use super::*;
3496 use super::oracle_parity_tests::*;
3497
3498 pub(crate) fn minimal_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3499 BmsFlexRowKernelInputs {
3500 n_rows: 1,
3501 r: 4,
3502 p_h: 1,
3503 p_w: 1,
3504 q: &buffers.q,
3505 b: &buffers.b,
3506 mu_1: &buffers.mu_1,
3507 mu_2: &buffers.mu_2,
3508 z_obs: &buffers.z_obs,
3509 y: &buffers.y,
3510 w: &buffers.w,
3511 e_obs: &buffers.e_obs,
3512 s_f: 1.0,
3513 cell_offsets: &buffers.cell_offsets,
3514 cell_c0: &buffers.cell_c0,
3515 cell_c1: &buffers.cell_c1,
3516 cell_c2: &buffers.cell_c2,
3517 cell_c3: &buffers.cell_c3,
3518 cell_a: &buffers.cell_a,
3519 cell_aa: &buffers.cell_aa,
3520 cell_r: &buffers.cell_r,
3521 cell_ar: &buffers.cell_ar,
3522 cell_sbb: &buffers.cell_sbb,
3523 cell_sbh: &buffers.cell_sbh,
3524 cell_sbw: &buffers.cell_sbw,
3525 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3526 chi_obs: &buffers.chi_obs,
3527 xi_obs: &buffers.xi_obs,
3528 rho_u: &buffers.rho_u,
3529 tau_u: &buffers.tau_u,
3530 r_uv: &buffers.r_uv,
3531 }
3532 }
3533
/// Owned backing storage for a `BmsFlexRowKernelInputs` fixture; the kernel
/// input struct borrows each field as a slice.
///
/// The per-group comments give the lengths the fixture builders in this module
/// allocate, for `n` rows, `r` primary slots, `p_h`/`p_w` auxiliary slots and
/// `n_cells` cells in total.
pub(crate) struct TestBuffers {
    // Per-row values (length `n`).
    pub(crate) q: Vec<f64>,
    pub(crate) b: Vec<f64>,
    pub(crate) mu_1: Vec<f64>,
    pub(crate) mu_2: Vec<f64>,
    pub(crate) z_obs: Vec<f64>,
    pub(crate) y: Vec<f64>,
    pub(crate) w: Vec<f64>,
    pub(crate) e_obs: Vec<f64>,
    // `n + 1` cumulative cell offsets starting at 0; row `i` owns cells
    // `cell_offsets[i]..cell_offsets[i + 1]`.
    pub(crate) cell_offsets: Vec<u32>,
    // One value per cell (length `n_cells`).
    pub(crate) cell_c0: Vec<f64>,
    pub(crate) cell_c1: Vec<f64>,
    pub(crate) cell_c2: Vec<f64>,
    pub(crate) cell_c3: Vec<f64>,
    // Four values per cell (length `n_cells * 4`).
    pub(crate) cell_a: Vec<f64>,
    pub(crate) cell_aa: Vec<f64>,
    // `(r - 1) * 4` values per cell.
    pub(crate) cell_r: Vec<f64>,
    pub(crate) cell_ar: Vec<f64>,
    // Four values per cell.
    pub(crate) cell_sbb: Vec<f64>,
    // `p_h * 4` values per cell.
    pub(crate) cell_sbh: Vec<f64>,
    // `p_w * 4` values per cell.
    pub(crate) cell_sbw: Vec<f64>,
    // `MOMENT_STRIDE` values per cell.
    pub(crate) cell_moments: Vec<f64>,
    // Per-row values (length `n`).
    pub(crate) chi_obs: Vec<f64>,
    pub(crate) xi_obs: Vec<f64>,
    // `r` values per row (length `n * r`).
    pub(crate) rho_u: Vec<f64>,
    pub(crate) tau_u: Vec<f64>,
    // `r * r` values per row (length `n * r * r`).
    pub(crate) r_uv: Vec<f64>,
}
3562
3563 pub(crate) fn make_buffers(n_cells: u32, r: usize, p_h: usize, p_w: usize) -> TestBuffers {
3564 let cells = n_cells as usize;
3565 TestBuffers {
3566 q: vec![0.1; 1],
3567 b: vec![0.5; 1],
3568 mu_1: vec![0.3; 1],
3569 mu_2: vec![0.07; 1],
3570 z_obs: vec![0.0; 1],
3571 y: vec![1.0; 1],
3572 w: vec![1.0; 1],
3573 e_obs: vec![0.15; 1],
3574 cell_offsets: vec![0, n_cells],
3575 cell_c0: vec![0.2; cells],
3576 cell_c1: vec![-0.1; cells],
3577 cell_c2: vec![0.05; cells],
3578 cell_c3: vec![-0.02; cells],
3579 cell_a: vec![0.1; cells * 4],
3580 cell_aa: vec![0.0; cells * 4],
3581 cell_r: vec![0.05; cells * (r - 1) * 4],
3582 cell_ar: vec![0.0; cells * (r - 1) * 4],
3583 cell_sbb: vec![0.0; cells * 4],
3584 cell_sbh: vec![0.0; cells * p_h * 4],
3585 cell_sbw: vec![0.0; cells * p_w * 4],
3586 cell_moments: vec![1.0; cells * MOMENT_STRIDE],
3587 chi_obs: vec![1.0; 1],
3588 xi_obs: vec![0.0; 1],
3589 rho_u: vec![0.0; r],
3590 tau_u: vec![0.0; r],
3591 r_uv: vec![0.0; r * r],
3592 }
3593 }
3594
/// The baseline one-row fixture must pass `validate()`.
///
/// On failure the test reports the validation error itself. A bare
/// `assert!(… .is_ok())` would discard it and leave only "assertion failed".
#[test]
pub(crate) fn validate_accepts_minimal_inputs() {
    let buffers = make_buffers(2, 4, 1, 1);
    let inputs = minimal_inputs(&buffers);
    if let Err(err) = inputs.validate() {
        panic!("minimal inputs must satisfy validate(): {err}");
    }
}
3601
/// `validate()` must reject `r = MAX_R + 1` and mention `MAX_R` in the error.
///
/// The buffers and `p_h`/`p_w` split are sized consistently for the oversized
/// `r`, so the cap is the only thing being violated.
#[test]
pub(crate) fn validate_rejects_r_above_max() {
    let r = MAX_R + 1;
    // Slots beyond the two leading ones are split as evenly as possible.
    let free_slots = r - 2;
    let p_h = free_slots / 2;
    let p_w = free_slots - p_h;
    let buffers = make_buffers(1, r, p_h, p_w);
    let msg = BmsFlexRowKernelInputs {
        r,
        p_h,
        p_w,
        cell_r: &buffers.cell_r,
        cell_ar: &buffers.cell_ar,
        cell_sbh: &buffers.cell_sbh,
        cell_sbw: &buffers.cell_sbw,
        rho_u: &buffers.rho_u,
        tau_u: &buffers.tau_u,
        r_uv: &buffers.r_uv,
        ..minimal_inputs(&buffers)
    }
    .validate()
    .expect_err("r > MAX_R must fail")
    .to_string();
    assert!(msg.contains("MAX_R"), "expected MAX_R hint, got: {msg}");
}
3625
/// `validate()` must reject `r` values that disagree with `p_h`/`p_w`.
///
/// Here `r = 4` but `p_h = 1, p_w = 2`. The error must name both auxiliary
/// slot counts.
#[test]
pub(crate) fn validate_rejects_mismatched_r_decomposition() {
    let buffers = make_buffers(1, 4, 1, 1);
    let err = BmsFlexRowKernelInputs { r: 4, p_h: 1, p_w: 2, ..minimal_inputs(&buffers) }
        .validate()
        .expect_err("inconsistent r vs p_h+p_w must fail");
    let msg = err.to_string();
    for needle in ["p_h", "p_w"] {
        assert!(msg.contains(needle), "got: {msg}");
    }
}
3642
/// `validate()` must reject decreasing `cell_offsets` and say so ("monotone").
#[test]
pub(crate) fn validate_rejects_non_monotone_offsets() {
    let buffers = TestBuffers {
        // Offsets go down from 5 to 2.
        cell_offsets: vec![5, 2],
        ..make_buffers(2, 4, 1, 1)
    };
    let msg = minimal_inputs(&buffers)
        .validate()
        .expect_err("non-monotone offsets must fail")
        .to_string();
    assert!(msg.contains("monotone"), "got: {msg}");
}
3660
/// `validate()` must reject a `cell_moments` buffer one element shorter than
/// the `cells * MOMENT_STRIDE` that `make_buffers` allocates, naming the field.
#[test]
pub(crate) fn validate_rejects_mismatched_cell_moments_length() {
    let mut buffers = make_buffers(2, 4, 1, 1);
    let short_len = buffers.cell_moments.len() - 1;
    buffers.cell_moments.truncate(short_len);
    let msg = minimal_inputs(&buffers)
        .validate()
        .expect_err("short cell_moments must fail")
        .to_string();
    assert!(msg.contains("cell_moments"), "got: {msg}");
}
3670
/// Smoke test for `launch_bms_flex_row_kernel` on the minimal fixture.
///
/// On Linux the launch may succeed, or fail with one of the driver/kernel
/// availability errors matched below (presumably depending on the host's CUDA
/// setup). Any other `GpuError` variant fails the test.
///
/// NOTE(review): this lives in a module gated on
/// `#[cfg(all(test, target_os = "linux"))]`, so the `not(target_os = "linux")`
/// arm is never compiled. The "Linux-only" error path named by this test is
/// therefore not exercised here; covering it would need a test outside this
/// Linux-only module.
#[test]
pub(crate) fn launch_on_non_linux_reports_driver_library_unavailable() {
    #[cfg(target_os = "linux")]
    {
        let buffers = make_buffers(1, 4, 1, 1);
        let inputs = minimal_inputs(&buffers);
        match launch_bms_flex_row_kernel(inputs) {
            Ok(_) => { }
            // Availability failures are tolerated: the host may lack a usable
            // driver, symbol, or device kernel.
            Err(GpuError::DriverLibraryUnavailable { .. })
            | Err(GpuError::DriverCallFailed { .. })
            | Err(GpuError::DriverSymbolMissing { .. })
            | Err(GpuError::NoDeviceKernel { .. }) => { }
            Err(other) => panic!("unexpected GpuError variant: {other:?}"),
        }
    }
    // Never compiled in this module (see NOTE above).
    #[cfg(not(target_os = "linux"))]
    {
        let buffers = make_buffers(1, 4, 1, 1);
        let inputs = minimal_inputs(&buffers);
        match launch_bms_flex_row_kernel(inputs) {
            Err(GpuError::DriverLibraryUnavailable { reason }) => {
                assert!(
                    reason.contains("Linux-only"),
                    "expected Linux-only hint, got: {reason}"
                );
            }
            other => panic!("expected DriverLibraryUnavailable on non-Linux, got {other:?}"),
        }
    }
}
3710
/// `launch_bms_flex_row_kernel` must reject every non-positive or non-finite
/// `s_f` with a `DriverCallFailed` whose reason names `s_f`.
///
/// The test name promises both conditions, so both are exercised: zero and a
/// negative value for "positive", and NaN and ±∞ for "finite". Previously only
/// `0.0` was checked.
#[test]
pub(crate) fn s_f_must_be_positive_and_finite() {
    let buffers = make_buffers(1, 4, 1, 1);
    for bad in [0.0, -1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
        // The launcher consumes its inputs, so build a fresh borrow per case.
        let inputs = BmsFlexRowKernelInputs { s_f: bad, ..minimal_inputs(&buffers) };
        match launch_bms_flex_row_kernel(inputs) {
            Err(GpuError::DriverCallFailed { reason }) => {
                assert!(reason.contains("s_f"), "s_f={bad}: got: {reason}");
            }
            other => panic!("expected DriverCallFailed for s_f={bad}, got {other:?}"),
        }
    }
}
3723
3724
3725 pub(crate) fn make_parity_buffers() -> TestBuffers {
3729 let n = 4_usize;
3730 let r = 5_usize;
3731 let p_h = 2_usize;
3732 let p_w = 1_usize;
3733 let row_cells: [u32; 4] = [2, 3, 4, 2];
3735 let mut cell_offsets = vec![0_u32; n + 1];
3736 for i in 0..n {
3737 cell_offsets[i + 1] = cell_offsets[i] + row_cells[i];
3738 }
3739 let total_cells = cell_offsets[n] as usize;
3740
3741 let f = |seed: usize| -> f64 {
3743 let x = ((seed.wrapping_mul(2_654_435_761)) & 0xFFFF) as f64 / 65_536.0;
3744 0.1 + 0.4 * x
3745 };
3746
3747 let q = (0..n).map(|i| 0.05 + 0.1 * (i as f64)).collect::<Vec<_>>();
3748 let b = (0..n).map(|i| 0.6 + 0.05 * (i as f64)).collect::<Vec<_>>();
3749 let mu_1 = (0..n).map(|i| 0.7 + 0.02 * (i as f64)).collect::<Vec<_>>();
3750 let mu_2 = (0..n).map(|i| 0.15 + 0.01 * (i as f64)).collect::<Vec<_>>();
3751 let z_obs = (0..n).map(|i| -0.2 + 0.1 * (i as f64)).collect::<Vec<_>>();
3752 let y = [1.0, 0.0, 1.0, 0.0].to_vec();
3753 let w = vec![1.0; n];
3754 let e_obs = (0..n).map(|i| -0.3 + 0.2 * (i as f64)).collect::<Vec<_>>();
3755
3756 let cell_c0 = (0..total_cells).map(|c| f(c + 1001)).collect::<Vec<_>>();
3757 let cell_c1 = (0..total_cells)
3758 .map(|c| -f(c + 2002) * 0.5)
3759 .collect::<Vec<_>>();
3760 let cell_c2 = (0..total_cells).map(|c| f(c + 3003) * 0.2).collect();
3761 let cell_c3 = (0..total_cells).map(|c| -f(c + 4004) * 0.1).collect();
3762
3763 let cell_a = (0..total_cells * 4)
3764 .map(|i| f(i + 5005) * 0.3)
3765 .collect::<Vec<_>>();
3766 let cell_aa = (0..total_cells * 4)
3767 .map(|i| f(i + 6006) * 0.1)
3768 .collect::<Vec<_>>();
3769 let cell_r = (0..total_cells * (r - 1) * 4)
3770 .map(|i| f(i + 7007) * 0.2)
3771 .collect::<Vec<_>>();
3772 let cell_ar = (0..total_cells * (r - 1) * 4)
3773 .map(|i| f(i + 8008) * 0.05)
3774 .collect::<Vec<_>>();
3775 let cell_sbb = (0..total_cells * 4)
3776 .map(|i| f(i + 9009) * 0.08)
3777 .collect::<Vec<_>>();
3778 let cell_sbh = (0..total_cells * p_h * 4)
3779 .map(|i| f(i + 10_010) * 0.07)
3780 .collect::<Vec<_>>();
3781 let cell_sbw = (0..total_cells * p_w * 4)
3782 .map(|i| f(i + 11_011) * 0.06)
3783 .collect::<Vec<_>>();
3784 let cell_moments = (0..total_cells * MOMENT_STRIDE)
3785 .map(|i| 0.4 + 0.1 * f(i + 12_012))
3786 .collect::<Vec<_>>();
3787
3788 let chi_obs = (0..n).map(|i| 0.9 + 0.01 * (i as f64)).collect::<Vec<_>>();
3789 let xi_obs = (0..n).map(|i| 0.2 + 0.01 * (i as f64)).collect::<Vec<_>>();
3790 let rho_u = (0..n * r).map(|i| 0.03 * f(i + 13_013)).collect::<Vec<_>>();
3791 let tau_u = (0..n * r).map(|i| 0.02 * f(i + 14_014)).collect::<Vec<_>>();
3792 let r_uv = (0..n * r * r)
3793 .map(|i| 0.04 * f(i + 15_015))
3794 .collect::<Vec<_>>();
3795
3796 TestBuffers {
3797 q,
3798 b,
3799 mu_1,
3800 mu_2,
3801 z_obs,
3802 y,
3803 w,
3804 e_obs,
3805 cell_offsets,
3806 cell_c0,
3807 cell_c1,
3808 cell_c2,
3809 cell_c3,
3810 cell_a,
3811 cell_aa,
3812 cell_r,
3813 cell_ar,
3814 cell_sbb,
3815 cell_sbh,
3816 cell_sbw,
3817 cell_moments,
3818 chi_obs,
3819 xi_obs,
3820 rho_u,
3821 tau_u,
3822 r_uv,
3823 }
3824 }
3825
3826 pub(crate) fn parity_inputs<'a>(buffers: &'a TestBuffers) -> BmsFlexRowKernelInputs<'a> {
3827 BmsFlexRowKernelInputs {
3828 n_rows: 4,
3829 r: 5,
3830 p_h: 2,
3831 p_w: 1,
3832 q: &buffers.q,
3833 b: &buffers.b,
3834 mu_1: &buffers.mu_1,
3835 mu_2: &buffers.mu_2,
3836 z_obs: &buffers.z_obs,
3837 y: &buffers.y,
3838 w: &buffers.w,
3839 e_obs: &buffers.e_obs,
3840 s_f: 1.0,
3841 cell_offsets: &buffers.cell_offsets,
3842 cell_c0: &buffers.cell_c0,
3843 cell_c1: &buffers.cell_c1,
3844 cell_c2: &buffers.cell_c2,
3845 cell_c3: &buffers.cell_c3,
3846 cell_a: &buffers.cell_a,
3847 cell_aa: &buffers.cell_aa,
3848 cell_r: &buffers.cell_r,
3849 cell_ar: &buffers.cell_ar,
3850 cell_sbb: &buffers.cell_sbb,
3851 cell_sbh: &buffers.cell_sbh,
3852 cell_sbw: &buffers.cell_sbw,
3853 cell_moments: CellMomentsSource::Host(&buffers.cell_moments),
3854 chi_obs: &buffers.chi_obs,
3855 xi_obs: &buffers.xi_obs,
3856 rho_u: &buffers.rho_u,
3857 tau_u: &buffers.tau_u,
3858 r_uv: &buffers.r_uv,
3859 }
3860 }
3861
/// The CPU oracle on the parity fixture must return correctly sized outputs.
/// Every `neglog` and gradient entry must be finite, and every row Hessian
/// must be finite and bit-for-bit symmetric.
#[test]
pub(crate) fn cpu_oracle_produces_finite_symmetric_hessian() {
    let buffers = make_parity_buffers();
    let inputs = parity_inputs(&buffers);
    inputs
        .validate()
        .expect("parity fixture must satisfy validate()");
    let out = cpu_oracle_outputs(&inputs);
    let (n, r) = (inputs.n_rows, inputs.r);

    assert_eq!(out.neglog.len(), n);
    assert_eq!(out.grad.len(), n * r);
    assert_eq!(out.hess.len(), n * r * r);

    for (row, &nl) in out.neglog.iter().enumerate() {
        assert!(nl.is_finite(), "row {row}: neglog must be finite, got {}", nl);
        let grad_row = &out.grad[row * r..(row + 1) * r];
        let hess_row = &out.hess[row * r * r..(row + 1) * r * r];
        for (u, &g) in grad_row.iter().enumerate() {
            assert!(g.is_finite(), "row {row}: grad[{u}] = {g}");
            for v in 0..r {
                let huv = hess_row[u * r + v];
                let hvu = hess_row[v * r + u];
                assert!(huv.is_finite(), "row {row}: H[{u},{v}] = {huv}");
                assert_eq!(
                    huv.to_bits(),
                    hvu.to_bits(),
                    "row {row}: H[{u},{v}] and H[{v},{u}] must be bit-identical"
                );
            }
        }
    }
}
3900
/// Checks the per-row Mills-ratio derivatives against 5-point central finite
/// differences of the per-row negative log-likelihood.
///
/// With `s = 2y − 1` and `m = s·e`, the objective is `−w·log_cdf(m)`, and the
/// analytic `e`-derivatives are `A = −w·s·λ(m)` and `B = w·λ(m)·(m + λ(m))`.
/// `log_cdf` and `λ` are the two values returned by `oracle_log_ndtr_and_mills`.
/// `A` must match to 5e-8 and `B` to 5e-6, each either absolute or relative
/// (relative to `max(|analytic|, 1)`).
#[test]
pub(crate) fn cpu_oracle_mills_layer_matches_finite_differences() {
    // (e, y, w): both labels across negative, zero and positive margins, plus
    // two non-unit weights.
    let cases: [(f64, f64, f64); 12] = [
        (-1.6, 1.0, 1.0),
        (-0.7, 1.0, 1.0),
        (0.0, 1.0, 1.0),
        (0.9, 1.0, 1.0),
        (1.8, 1.0, 1.0),
        (-1.4, 0.0, 1.0),
        (-0.3, 0.0, 1.0),
        (0.0, 0.0, 1.0),
        (0.6, 0.0, 1.0),
        (1.5, 0.0, 1.0),
        (0.4, 1.0, 0.75),
        (-0.8, 0.0, 1.3),
    ];
    let h = 1e-3_f64;

    for (e, y, w) in cases {
        let s = 2.0 * y - 1.0;
        // Objective as a function of the margin input only (y, w fixed).
        let neglog_at = |x: f64| -> f64 {
            let (log_cdf, _) = oracle_log_ndtr_and_mills(s * x);
            -w * log_cdf
        };

        let m_arg = s * e;
        let (_, lambda) = oracle_log_ndtr_and_mills(m_arg);
        let a_ana = -w * s * lambda;
        let b_ana = w * lambda * (m_arg + lambda);

        let fp2 = neglog_at(e + 2.0 * h);
        let fp1 = neglog_at(e + h);
        let f0 = neglog_at(e);
        let fm1 = neglog_at(e - h);
        let fm2 = neglog_at(e - 2.0 * h);

        // Fourth-order central stencils for the first and second derivative.
        let d1_fd = (-fp2 + 8.0 * fp1 - 8.0 * fm1 + fm2) / (12.0 * h);
        let d2_fd = (-fp2 + 16.0 * fp1 - 30.0 * f0 + 16.0 * fm1 - fm2) / (12.0 * h * h);

        let a_abs = (a_ana - d1_fd).abs();
        let a_rel = a_abs / a_ana.abs().max(1.0);
        assert!(
            a_abs <= 5e-8 || a_rel <= 5e-8,
            "Mills A (∂neglog/∂e) drift at e={e} y={y} w={w}: \
             analytic={a_ana:.17e} fd={d1_fd:.17e} abs={a_abs:.3e} rel={a_rel:.3e}"
        );

        let b_abs = (b_ana - d2_fd).abs();
        let b_rel = b_abs / b_ana.abs().max(1.0);
        assert!(
            b_abs <= 5e-6 || b_rel <= 5e-6,
            "Mills B (∂²neglog/∂e²) drift at e={e} y={y} w={w}: \
             analytic={b_ana:.17e} fd={d2_fd:.17e} abs={b_abs:.3e} rel={b_rel:.3e}"
        );
    }
}
4001
/// Device/CPU parity for the flex row kernel on the `make_parity_buffers`
/// fixture.
///
/// Returns early (passing) when `GpuRuntime::global()` is `None`. Once a
/// runtime is present, the test requires three things:
/// - a launch error is a hard failure;
/// - every `neglog`, `grad` and `hess` entry agrees with `cpu_oracle_outputs`
///   within `1e-8 + 1e-8·|cpu|`, and a NaN must appear on both sides or neither;
/// - every device row Hessian is symmetric bit-for-bit.
#[test]
pub(crate) fn bms_flex_row_kernel_matches_cpu_oracle_when_cuda_available() {
    // NOTE(review): the enclosing `tests` module is gated on
    // `target_os = "linux"`, so this arm is never compiled.
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling test)"
        );
        return;
    }
    #[cfg(target_os = "linux")]
    {
        // No CUDA runtime on this host: skip the device half of the check.
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row parity] no CUDA runtime — skipping device \
                 parity (CPU oracle exercised by sibling test)"
            );
            return;
        };
        let buffers = make_parity_buffers();
        let inputs_cpu = parity_inputs(&buffers);
        inputs_cpu
            .validate()
            .expect("parity fixture must satisfy validate()");
        let cpu_out = cpu_oracle_outputs(&inputs_cpu);

        // The launcher takes its inputs by value while `inputs_cpu` is still
        // needed below, so borrow the same buffers a second time.
        let inputs_gpu = parity_inputs(&buffers);
        let gpu_out = match launch_bms_flex_row_kernel(inputs_gpu) {
            Ok(out) => out,
            Err(err) => panic!(
                "[bms_flex_row parity] launch failed on CUDA-selected host; \
                 device/oracle parity must fail loudly on GPU CI: {err}"
            ),
        };

        let n = inputs_cpu.n_rows;
        let r = inputs_cpu.r;
        // Mixed tolerance: absolute floor plus a term relative to the CPU value.
        let tol_abs = 1e-8_f64;
        let tol_rel = 1e-8_f64;
        let check_close = |label: &str, idx: usize, cpu: f64, gpu: f64| {
            // NaN never satisfies `<=`, so require NaN on both sides instead.
            if cpu.is_nan() || gpu.is_nan() {
                assert!(
                    cpu.is_nan() && gpu.is_nan(),
                    "{label}[{idx}]: NaN parity broke — cpu={cpu}, gpu={gpu}"
                );
                return;
            }
            let diff = (cpu - gpu).abs();
            let tol = tol_abs + tol_rel * cpu.abs();
            assert!(
                diff <= tol,
                "{label}[{idx}]: |cpu − gpu| = {diff:.3e} > tol = {tol:.3e}; \
                 cpu={cpu:.17e}, gpu={gpu:.17e}"
            );
        };
        assert_eq!(cpu_out.neglog.len(), gpu_out.neglog.len());
        assert_eq!(cpu_out.grad.len(), gpu_out.grad.len());
        assert_eq!(cpu_out.hess.len(), gpu_out.hess.len());
        for (i, (&c, &g)) in cpu_out.neglog.iter().zip(gpu_out.neglog.iter()).enumerate() {
            check_close("neglog", i, c, g);
        }
        for (i, (&c, &g)) in cpu_out.grad.iter().zip(gpu_out.grad.iter()).enumerate() {
            check_close("grad", i, c, g);
        }
        for (i, (&c, &g)) in cpu_out.hess.iter().zip(gpu_out.hess.iter()).enumerate() {
            check_close("hess", i, c, g);
        }
        // Device row Hessians (row-major r×r blocks) must be exactly symmetric.
        for row in 0..n {
            for u in 0..r {
                for v in 0..r {
                    let a = gpu_out.hess[row * r * r + u * r + v];
                    let bb = gpu_out.hess[row * r * r + v * r + u];
                    assert_eq!(
                        a.to_bits(),
                        bb.to_bits(),
                        "GPU row {row}: H[{u},{v}] ≠ H[{v},{u}] bit-for-bit"
                    );
                }
            }
        }
    }
}
4094
/// The row-kernel source text must mention the two CPU routines named below,
/// which are its parity references on the host side.
#[test]
pub(crate) fn kernel_source_mentions_cpu_parity_reference() {
    #[cfg(target_os = "linux")]
    {
        assert!(ROW_KERNEL_BODY.contains("compute_row_analytic_flex_from_parts_into"));
        assert!(ROW_KERNEL_BODY.contains("cell_first_derivative_from_moments"));
    }
}
4106
4107 pub(crate) fn cpu_oracle_bms_flex_row_hvp(
4112 row_hessians: &[f64],
4113 marginal_design: &[f64],
4114 logslope_design: &[f64],
4115 block: &BmsFlexBlockLayout,
4116 primary: &BmsFlexPrimaryLayout,
4117 n: usize,
4118 v: &[f64],
4119 ) -> Vec<f64> {
4120 let r = primary.r;
4121 let p_m = block.p_m;
4122 let p_g = block.p_g;
4123 assert_eq!(v.len(), block.p_total);
4124 assert_eq!(row_hessians.len(), n * r * r);
4125 assert_eq!(marginal_design.len(), n * p_m);
4126 assert_eq!(logslope_design.len(), n * p_g);
4127 let mut out = vec![0.0_f64; block.p_total];
4128 let mut row_dir = vec![0.0_f64; r];
4129 let mut action = vec![0.0_f64; r];
4130 for row in 0..n {
4131 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4132 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4133 let mut acc_q = 0.0_f64;
4134 for j in 0..p_m {
4135 acc_q += mrow[j] * v[j];
4136 }
4137 let mut acc_g = 0.0_f64;
4138 for j in 0..p_g {
4139 acc_g += grow[j] * v[p_m + j];
4140 }
4141 row_dir[0] = acc_q;
4142 row_dir[1] = acc_g;
4143 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4144 for (k, ii) in prange.clone().enumerate() {
4145 row_dir[ii] = v[brange.start + k];
4146 }
4147 }
4148 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4149 for (k, ii) in prange.clone().enumerate() {
4150 row_dir[ii] = v[brange.start + k];
4151 }
4152 }
4153 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4154 for u in 0..r {
4155 let mut acc = 0.0_f64;
4156 for v_idx in 0..r {
4157 acc += h_slice[u * r + v_idx] * row_dir[v_idx];
4158 }
4159 action[u] = acc;
4160 }
4161 let a0 = action[0];
4162 for j in 0..p_m {
4163 out[j] += a0 * mrow[j];
4164 }
4165 let a1 = action[1];
4166 for j in 0..p_g {
4167 out[p_m + j] += a1 * grow[j];
4168 }
4169 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4170 for (k, ii) in prange.clone().enumerate() {
4171 out[brange.start + k] += action[ii];
4172 }
4173 }
4174 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4175 for (k, ii) in prange.clone().enumerate() {
4176 out[brange.start + k] += action[ii];
4177 }
4178 }
4179 }
4180 out
4181 }
4182
4183 pub(crate) fn cpu_oracle_bms_flex_row_diagonal(
4184 row_hessians: &[f64],
4185 marginal_design: &[f64],
4186 logslope_design: &[f64],
4187 block: &BmsFlexBlockLayout,
4188 primary: &BmsFlexPrimaryLayout,
4189 n: usize,
4190 ) -> Vec<f64> {
4191 let r = primary.r;
4192 let p_m = block.p_m;
4193 let p_g = block.p_g;
4194 let mut out = vec![0.0_f64; block.p_total];
4195 for row in 0..n {
4196 let h_slice = &row_hessians[row * r * r..(row + 1) * r * r];
4197 let h00 = h_slice[0];
4198 let h11 = h_slice[r + 1];
4199 let mrow = &marginal_design[row * p_m..(row + 1) * p_m];
4200 let grow = &logslope_design[row * p_g..(row + 1) * p_g];
4201 for j in 0..p_m {
4202 out[j] += h00 * mrow[j] * mrow[j];
4203 }
4204 for j in 0..p_g {
4205 out[p_m + j] += h11 * grow[j] * grow[j];
4206 }
4207 if let (Some(prange), Some(brange)) = (primary.h.as_ref(), block.h.as_ref()) {
4208 for (k, ii) in prange.clone().enumerate() {
4209 out[brange.start + k] += h_slice[ii * r + ii];
4210 }
4211 }
4212 if let (Some(prange), Some(brange)) = (primary.w.as_ref(), block.w.as_ref()) {
4213 for (k, ii) in prange.clone().enumerate() {
4214 out[brange.start + k] += h_slice[ii * r + ii];
4215 }
4216 }
4217 }
4218 out
4219 }
4220
/// Hand-checks `out[0]` of the HVP oracle on a 4-row, `r = 4` fixture with one
/// `h` and one `w` slot.
///
/// The expected value is `out[0] = Σ_rows (H_row[0, :] · row_dir) · m_row[0]`.
/// The whole output must be finite and `p_total` long.
#[test]
pub(crate) fn cpu_oracle_hvp_matches_hand_computation_no_hw() {
    let (n, r) = (4_usize, 4_usize);
    let (p_m, p_g) = (2_usize, 2_usize);
    let (p_h_dim, p_w_dim) = (1_usize, 1_usize);
    let p_total = p_m + p_g + p_h_dim + p_w_dim;
    let h_start = p_m + p_g;
    let w_start = h_start + p_h_dim;
    let block = BmsFlexBlockLayout {
        p_m,
        p_g,
        h: Some(h_start..w_start),
        w: Some(w_start..w_start + p_w_dim),
        p_total,
    };
    let primary = BmsFlexPrimaryLayout {
        h: Some(2..3),
        w: Some(3..4),
        r,
    };
    // Symmetric row Hessians: H[a, b] = (row + 1)·(1 + min(a, b) + 2·max(a, b)).
    let row_hessians: Vec<f64> = (0..n * r * r)
        .map(|idx| {
            let row = idx / (r * r);
            let (a, b) = ((idx / r) % r, idx % r);
            let (lo, hi) = (a.min(b), a.max(b));
            ((row + 1) as f64) * (1.0 + (lo as f64) + 2.0 * (hi as f64))
        })
        .collect();
    let marginal: Vec<f64> = (0..n)
        .flat_map(|row| (0..p_m).map(move |j| 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2))
        .collect();
    let logslope: Vec<f64> = (0..n)
        .flat_map(|row| (0..p_g).map(move |j| -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15))
        .collect();
    let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();

    let out = cpu_oracle_bms_flex_row_hvp(
        &row_hessians,
        &marginal,
        &logslope,
        &block,
        &primary,
        n,
        &v,
    );

    let expect_out_0 = (0..n)
        .map(|row| {
            let mrow = &marginal[row * p_m..(row + 1) * p_m];
            let grow = &logslope[row * p_g..(row + 1) * p_g];
            let row_dir = [
                mrow[0] * v[0] + mrow[1] * v[1],
                grow[0] * v[p_m] + grow[1] * v[p_m + 1],
                v[p_m + p_g],
                v[p_m + p_g + p_h_dim],
            ];
            // First row of this row's Hessian (primary slot 0).
            let h_first = &row_hessians[row * r * r..row * r * r + r];
            let action0 = h_first
                .iter()
                .zip(&row_dir)
                .fold(0.0_f64, |acc, (h, d)| acc + h * d);
            action0 * mrow[0]
        })
        .fold(0.0_f64, |acc, term| acc + term);

    assert!(
        (out[0] - expect_out_0).abs() < 1e-12,
        "cpu oracle HVP out[0] mismatch: {} vs hand-check {}",
        out[0],
        expect_out_0
    );
    assert!(out.iter().all(|x| x.is_finite()));
    assert_eq!(out.len(), p_total);
}
4307
/// Hand-checks two entries of the diagonal oracle on a 3-row fixture whose
/// row Hessians are purely diagonal:
/// - the first marginal coefficient equals `Σ H00·m0²`;
/// - the `h` block coefficient equals `Σ H22`.
#[test]
pub(crate) fn cpu_oracle_diagonal_matches_hand_computation() {
    let (n, r) = (3_usize, 4_usize);
    let (p_m, p_g) = (2_usize, 2_usize);
    let (p_h_dim, p_w_dim) = (1_usize, 1_usize);
    let p_total = p_m + p_g + p_h_dim + p_w_dim;
    let h_slot = p_m + p_g;
    let w_slot = h_slot + p_h_dim;
    let block = BmsFlexBlockLayout {
        p_m,
        p_g,
        h: Some(h_slot..w_slot),
        w: Some(w_slot..w_slot + p_w_dim),
        p_total,
    };
    let primary = BmsFlexPrimaryLayout {
        h: Some(2..3),
        w: Some(3..4),
        r,
    };
    // Diagonal-only row Hessians: H[u, u] = 1 + row + 0.5·u.
    let mut row_hessians = vec![0.0_f64; n * r * r];
    for (row, hess) in row_hessians.chunks_exact_mut(r * r).enumerate() {
        for u in 0..r {
            hess[u * r + u] = 1.0 + (row as f64) + (u as f64) * 0.5;
        }
    }
    let marginal: Vec<f64> = (0..n)
        .flat_map(|row| (0..p_m).map(move |j| 0.2 + (row as f64) * 0.3 + (j as f64) * 0.1))
        .collect();
    let logslope: Vec<f64> = (0..n)
        .flat_map(|row| (0..p_g).map(move |j| -0.4 + (row as f64) * 0.1 + (j as f64) * 0.2))
        .collect();

    let out = cpu_oracle_bms_flex_row_diagonal(
        &row_hessians,
        &marginal,
        &logslope,
        &block,
        &primary,
        n,
    );

    let expect = (0..n).fold(0.0_f64, |acc, row| {
        acc + row_hessians[row * r * r] * marginal[row * p_m].powi(2)
    });
    assert!(
        (out[0] - expect).abs() < 1e-12,
        "out[0] {} vs {}",
        out[0],
        expect
    );

    let expect_h = (0..n).fold(0.0_f64, |acc, row| acc + row_hessians[row * r * r + 2 * r + 2]);
    assert!(
        (out[h_slot] - expect_h).abs() < 1e-12,
        "h slot {} vs {}",
        out[h_slot],
        expect_h
    );
}
4379
/// Device/CPU parity for the single-RHS HVP and diagonal kernels on a small
/// hand-built row-Hessian cache (4 rows, `r = 4`, one `h` and one `w` slot).
///
/// Returns early when `GpuRuntime::global()` is `None`. On a CUDA host, the
/// backend probe, the three uploads and both launches must succeed. Every HVP
/// and diagonal entry must then match `cpu_oracle_bms_flex_row_hvp` /
/// `cpu_oracle_bms_flex_row_diagonal` within an absolute `1e-10`.
#[test]
pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_when_cuda_available() {
    // NOTE(review): the enclosing `tests` module is Linux-only, so this arm is
    // never compiled.
    #[cfg(not(target_os = "linux"))]
    {
        eprintln!(
            "[bms_flex_row hvp parity] non-Linux host — skipping CUDA parity \
             (CPU oracle exercised by sibling tests)"
        );
    }
    #[cfg(target_os = "linux")]
    {
        let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
            eprintln!(
                "[bms_flex_row hvp parity] no CUDA runtime — skipping device \
                 parity"
            );
            return;
        };
        let n = 4_usize;
        let r = 4_usize;
        let p_m = 2_usize;
        let p_g = 2_usize;
        let p_h_dim = 1_usize;
        let p_w_dim = 1_usize;
        let p_total = p_m + p_g + p_h_dim + p_w_dim;
        // Coefficient layout: [marginal | logslope | h | w].
        let block = BmsFlexBlockLayout {
            p_m,
            p_g,
            h: Some(p_m + p_g..p_m + p_g + p_h_dim),
            w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
            p_total,
        };
        // Primary slots 2 and 3 carry the h and w coefficients.
        let primary = BmsFlexPrimaryLayout {
            h: Some(2..3),
            w: Some(3..4),
            r,
        };
        // Symmetric row Hessians: upper triangle filled, then mirrored.
        let mut row_hessians = vec![0.0_f64; n * r * r];
        for row in 0..n {
            for u in 0..r {
                for v in u..r {
                    let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
                    row_hessians[row * r * r + u * r + v] = val;
                    row_hessians[row * r * r + v * r + u] = val;
                }
            }
        }
        let mut marginal = vec![0.0_f64; n * p_m];
        for row in 0..n {
            for j in 0..p_m {
                marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
            }
        }
        let mut logslope = vec![0.0_f64; n * p_g];
        for row in 0..n {
            for j in 0..p_g {
                logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
            }
        }
        let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
        let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            &v,
        );
        let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
        );

        // With a runtime present, every device step below must succeed.
        let backend = HvpKernelBackend::probe()
            .expect("[bms_flex_row hvp parity] backend probe must succeed on CUDA host");
        let stream = backend.stream.clone();
        let d_h = stream
            .clone_htod(&row_hessians)
            .expect("[bms_flex_row hvp parity] upload h must succeed on CUDA host");
        let d_m = stream
            .clone_htod(&marginal)
            .expect("[bms_flex_row hvp parity] upload marg must succeed on CUDA host");
        let d_g = stream
            .clone_htod(&logslope)
            .expect("[bms_flex_row hvp parity] upload logslope must succeed on CUDA host");
        let storage = DeviceResidentRowHess {
            hess: d_h,
            marginal_design: d_m,
            logslope_design: d_g,
            n,
            r,
            block: block.clone(),
            primary: primary.clone(),

            // Size of the three uploaded f64 buffers.
            bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
        };
        let gpu_hvp =
            launch_bms_flex_row_hvp(&storage, &v).expect("HVP kernel must launch on CUDA host");
        let gpu_diag = launch_bms_flex_row_diagonal(&storage)
            .expect("diagonal kernel must launch on CUDA host");
        assert_eq!(gpu_hvp.len(), cpu_hvp.len());
        assert_eq!(gpu_diag.len(), cpu_diag.len());
        for i in 0..p_total {
            let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
            assert!(
                diff <= 1e-10,
                "HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
                cpu_hvp[i],
                gpu_hvp[i]
            );
            let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
            assert!(
                ddiff <= 1e-10,
                "diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
                cpu_diag[i],
                gpu_diag[i]
            );
        }
    }
}
4515
/// Scratch-budget check for the multi-RHS HVP at a large shape: 195 000 rows,
/// `r = 20`, 44 coefficients and 4 RHS.
///
/// The budget must stay below one hundredth of what per-RHS copies of the full
/// row-Hessian cache would take. Shapes with more than
/// `BMS_FLEX_ROW_HVP_MAX_RHS` right-hand sides must be rejected.
#[test]
pub(crate) fn bms_flex_row_hvp_multi_scratch_is_bounded_at_large_scale_shape() {
    const N: usize = 195_000;
    const R: usize = 20;
    const P_TOTAL: usize = 44;
    const RHS_COUNT: usize = 4;

    let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(N, P_TOTAL, RHS_COUNT)
        .expect("large-scale multi-RHS scratch budget");
    let row_cache_bytes = (N * R * R * std::mem::size_of::<f64>()) as u64;
    let per_rhs_full_row_cache = row_cache_bytes * RHS_COUNT as u64;
    assert!(
        scratch < per_rhs_full_row_cache / 100,
        "multi-RHS scratch must tile by row chunks instead of materializing \
         a row-Hessian copy per RHS: scratch={scratch} full_per_rhs={per_rhs_full_row_cache}"
    );

    let too_many = BMS_FLEX_ROW_HVP_MAX_RHS + 1;
    assert!(
        bms_flex_row_hvp_multi_scratch_bytes_for_shape(N, P_TOTAL, too_many).is_err(),
        "multi-RHS launch must reject unbounded RHS counts"
    );
}
4541
/// Device parity for the multi-RHS HVP kernel on a 5-row, `r = 4` cache with
/// three right-hand sides.
///
/// Returns early when `GpuRuntime::global()` is `None`. On a CUDA host, the
/// test checks four things:
/// - the scratch budget for this shape stays below the resident cache's
///   `bytes`;
/// - the multi-RHS output is `rhs_count * p_total` long;
/// - each RHS slice matches `cpu_oracle_bms_flex_row_hvp` within `1e-10`;
/// - each RHS slice equals a single-RHS launch exactly (`==`, not a tolerance).
#[test]
pub(crate) fn bms_flex_row_hvp_multi_kernel_matches_cpu_oracle_when_cuda_available() {
    // No cfg guard needed: the enclosing `tests` module is already Linux-only.
    let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
        eprintln!("[bms_flex_row hvp_multi parity] no CUDA runtime — skipping device parity");
        return;
    };
    let n = 5_usize;
    let r = 4_usize;
    let p_m = 2_usize;
    let p_g = 2_usize;
    let p_h_dim = 1_usize;
    let p_w_dim = 1_usize;
    let p_total = p_m + p_g + p_h_dim + p_w_dim;
    let rhs_count = 3_usize;
    // Coefficient layout: [marginal | logslope | h | w].
    let block = BmsFlexBlockLayout {
        p_m,
        p_g,
        h: Some(p_m + p_g..p_m + p_g + p_h_dim),
        w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
        p_total,
    };
    // Primary slots 2 and 3 carry the h and w coefficients.
    let primary = BmsFlexPrimaryLayout {
        h: Some(2..3),
        w: Some(3..4),
        r,
    };
    // Symmetric row Hessians: upper triangle filled, then mirrored.
    let mut row_hessians = vec![0.0_f64; n * r * r];
    for row in 0..n {
        for u in 0..r {
            for v in u..r {
                let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
                row_hessians[row * r * r + u * r + v] = val;
                row_hessians[row * r * r + v * r + u] = val;
            }
        }
    }
    let mut marginal = vec![0.0_f64; n * p_m];
    let mut logslope = vec![0.0_f64; n * p_g];
    for row in 0..n {
        for j in 0..p_m {
            marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
        }
        for j in 0..p_g {
            logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
        }
    }
    // RHS-major layout: RHS `k` occupies `v_rhs[k * p_total..(k + 1) * p_total]`.
    let mut v_rhs = vec![0.0_f64; rhs_count * p_total];
    for rhs in 0..rhs_count {
        for j in 0..p_total {
            let seed = (rhs as f64) * 0.37 + (j as f64) * 0.19 + 0.4;
            v_rhs[rhs * p_total + j] = seed.sin() * 0.4 + seed.cos() * 0.2;
        }
    }

    // With a runtime present, every device step below must succeed.
    let backend = HvpKernelBackend::probe()
        .expect("[bms_flex_row hvp_multi parity] backend probe must succeed on CUDA host");
    let stream = backend.stream.clone();
    let d_h = stream
        .clone_htod(&row_hessians)
        .expect("[bms_flex_row hvp_multi parity] upload h must succeed on CUDA host");
    let d_m = stream
        .clone_htod(&marginal)
        .expect("[bms_flex_row hvp_multi parity] upload marg must succeed on CUDA host");
    let d_g = stream
        .clone_htod(&logslope)
        .expect("[bms_flex_row hvp_multi parity] upload logslope must succeed on CUDA host");
    let storage = DeviceResidentRowHess {
        hess: d_h,
        marginal_design: d_m,
        logslope_design: d_g,
        n,
        r,
        block: block.clone(),
        primary: primary.clone(),

        // Size of the three uploaded f64 buffers.
        bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
    };
    let scratch = bms_flex_row_hvp_multi_scratch_bytes_for_shape(n, p_total, rhs_count)
        .expect("storage scratch budget");
    assert!(
        scratch < storage.bytes,
        "multi-RHS scratch should stay below resident cache bytes"
    );
    let gpu = launch_bms_flex_row_hvp_multi(&storage, &v_rhs, rhs_count)
        .expect("multi-RHS HVP kernel must launch on CUDA host");
    assert_eq!(gpu.len(), rhs_count * p_total);
    for rhs in 0..rhs_count {
        let v = &v_rhs[rhs * p_total..(rhs + 1) * p_total];
        let cpu = cpu_oracle_bms_flex_row_hvp(
            &row_hessians,
            &marginal,
            &logslope,
            &block,
            &primary,
            n,
            v,
        );
        let single = launch_bms_flex_row_hvp(&storage, v)
            .expect("single-RHS HVP kernel must launch on CUDA host");
        for j in 0..p_total {
            let got = gpu[rhs * p_total + j];
            let diff = (cpu[j] - got).abs();
            assert!(
                diff <= 1e-10,
                "multi-RHS HVP rhs={rhs} j={j}: cpu={} gpu={} |diff|={diff:.3e}",
                cpu[j],
                got
            );
            // Multi- and single-RHS device paths must agree exactly.
            assert_eq!(
                got, single[j],
                "multi-RHS and single-RHS host launch diverged at rhs={rhs} j={j}"
            );
        }
    }
}
4660
4661 #[test]
4672 pub(crate) fn bms_flex_row_hvp_into_device_matches_cpu_oracle_and_host_out() {
4673 #[cfg(not(target_os = "linux"))]
4674 {
4675 eprintln!(
4676 "[bms_flex_row hvp_into_device parity] non-Linux host — skipping \
4677 CUDA parity (CPU oracle exercised by sibling tests)"
4678 );
4679 }
4680 #[cfg(target_os = "linux")]
4681 {
4682 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4683 eprintln!(
4684 "[bms_flex_row hvp_into_device parity] no CUDA runtime — \
4685 skipping device parity"
4686 );
4687 return;
4688 };
4689 let n = 4_usize;
4690 let r = 4_usize;
4691 let p_m = 2_usize;
4692 let p_g = 2_usize;
4693 let p_h_dim = 1_usize;
4694 let p_w_dim = 1_usize;
4695 let p_total = p_m + p_g + p_h_dim + p_w_dim;
4696 let block = BmsFlexBlockLayout {
4697 p_m,
4698 p_g,
4699 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4700 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4701 p_total,
4702 };
4703 let primary = BmsFlexPrimaryLayout {
4704 h: Some(2..3),
4705 w: Some(3..4),
4706 r,
4707 };
4708 let mut row_hessians = vec![0.0_f64; n * r * r];
4709 for row in 0..n {
4710 for u in 0..r {
4711 for v in u..r {
4712 let val = ((row + 1) as f64) * (1.0 + (u as f64) + 2.0 * (v as f64));
4713 row_hessians[row * r * r + u * r + v] = val;
4714 row_hessians[row * r * r + v * r + u] = val;
4715 }
4716 }
4717 }
4718 let mut marginal = vec![0.0_f64; n * p_m];
4719 for row in 0..n {
4720 for j in 0..p_m {
4721 marginal[row * p_m + j] = 0.5 + (row as f64) * 0.1 - (j as f64) * 0.2;
4722 }
4723 }
4724 let mut logslope = vec![0.0_f64; n * p_g];
4725 for row in 0..n {
4726 for j in 0..p_g {
4727 logslope[row * p_g + j] = -0.3 + (row as f64) * 0.05 + (j as f64) * 0.15;
4728 }
4729 }
4730 let v: Vec<f64> = (0..p_total).map(|i| 0.1 + (i as f64) * 0.25).collect();
4731 let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4732 &row_hessians,
4733 &marginal,
4734 &logslope,
4735 &block,
4736 &primary,
4737 n,
4738 &v,
4739 );
4740
4741 let backend = HvpKernelBackend::probe().expect(
4744 "[bms_flex_row hvp_into_device parity] backend probe must succeed on CUDA host",
4745 );
4746 let stream = backend.stream.clone();
4747 let d_h = stream
4748 .clone_htod(&row_hessians)
4749 .expect("[bms_flex_row hvp_into_device parity] upload h must succeed on CUDA host");
4750 let d_m = stream.clone_htod(&marginal).expect(
4751 "[bms_flex_row hvp_into_device parity] upload marg must succeed on CUDA host",
4752 );
4753 let d_g = stream.clone_htod(&logslope).expect(
4754 "[bms_flex_row hvp_into_device parity] upload logslope must succeed on CUDA host",
4755 );
4756 let storage = DeviceResidentRowHess {
4757 hess: d_h,
4758 marginal_design: d_m,
4759 logslope_design: d_g,
4760 n,
4761 r,
4762 block: block.clone(),
4763 primary: primary.clone(),
4764
4765 bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4766 };
4767
4768 let host_out_hvp = launch_bms_flex_row_hvp(&storage, &v)
4770 .expect("host-out HVP kernel must launch on CUDA host");
4771
4772 let d_v = stream
4775 .clone_htod(&v)
4776 .expect("upload direction for device-out HVP");
4777 let mut d_out = stream
4778 .alloc_zeros::<f64>(p_total)
4779 .expect("alloc device-out HVP output");
4780 launch_bms_flex_row_hvp_into_device(&storage, &d_v, &mut d_out)
4781 .expect("device-out HVP kernel must launch on CUDA host");
4782 stream
4783 .synchronize()
4784 .expect("synchronize after device-out HVP");
4785 let device_out_hvp = stream
4786 .clone_dtoh(&d_out)
4787 .expect("download device-out HVP output");
4788
4789 assert_eq!(device_out_hvp.len(), cpu_hvp.len());
4790 assert_eq!(device_out_hvp.len(), host_out_hvp.len());
4791 for i in 0..p_total {
4792 let diff = (cpu_hvp[i] - device_out_hvp[i]).abs();
4793 assert!(
4794 diff <= 1e-10,
4795 "device-out HVP[{i}] vs CPU: cpu={} gpu={} |Δ|={diff:.3e}",
4796 cpu_hvp[i],
4797 device_out_hvp[i]
4798 );
4799 let host_diff = (host_out_hvp[i] - device_out_hvp[i]).abs();
4802 assert!(
4803 host_diff == 0.0,
4804 "device-out vs host-out HVP[{i}]: host={} device={} |Δ|={host_diff:.3e}",
4805 host_out_hvp[i],
4806 device_out_hvp[i]
4807 );
4808 }
4809 }
4810 }
4811
4812 #[test]
4825 pub(crate) fn bms_flex_row_hvp_kernel_matches_cpu_oracle_at_n64_r20_p44() {
4826 #[cfg(not(target_os = "linux"))]
4827 {
4828 eprintln!(
4829 "[bms_flex_row hvp parity n64_r20_p44] non-Linux host — \
4830 skipping CUDA parity"
4831 );
4832 }
4833 #[cfg(target_os = "linux")]
4834 {
4835 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
4836 eprintln!(
4837 "[bms_flex_row hvp parity n64_r20_p44] no CUDA runtime — \
4838 skipping device parity"
4839 );
4840 return;
4841 };
4842 let n = 64_usize;
4843 let p_m = 14_usize;
4844 let p_g = 12_usize;
4845 let p_h_dim = 10_usize;
4846 let p_w_dim = 8_usize;
4847 let r = 2 + p_h_dim + p_w_dim;
4848 assert_eq!(r, 20);
4849 let p_total = p_m + p_g + p_h_dim + p_w_dim;
4850 assert_eq!(p_total, 44);
4851 let block = BmsFlexBlockLayout {
4852 p_m,
4853 p_g,
4854 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
4855 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
4856 p_total,
4857 };
4858 let primary = BmsFlexPrimaryLayout {
4859 h: Some(2..2 + p_h_dim),
4860 w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
4861 r,
4862 };
4863
4864 let mut row_hessians = vec![0.0_f64; n * r * r];
4870 for row in 0..n {
4871 let base = row * r * r;
4872 for u in 0..r {
4873 for v in 0..r {
4874 let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (v as f64) * 0.317;
4875 let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
4876 row_hessians[base + u * r + v] = a;
4877 }
4878 }
4879 for u in 0..r {
4880 for v in (u + 1)..r {
4881 let upper = row_hessians[base + u * r + v];
4882 let lower = row_hessians[base + v * r + u];
4883 let sym = 0.5 * (upper + lower);
4884 row_hessians[base + u * r + v] = sym;
4885 row_hessians[base + v * r + u] = sym;
4886 }
4887 row_hessians[base + u * r + u] += r as f64;
4888 }
4889 }
4890 let mut marginal = vec![0.0_f64; n * p_m];
4891 for row in 0..n {
4892 for j in 0..p_m {
4893 let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
4894 marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
4895 }
4896 }
4897 let mut logslope = vec![0.0_f64; n * p_g];
4898 for row in 0..n {
4899 for j in 0..p_g {
4900 let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
4901 logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
4902 }
4903 }
4904 let v: Vec<f64> = (0..p_total)
4905 .map(|i| {
4906 let seed = (i as f64) * 0.157 + 0.6;
4907 seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
4908 })
4909 .collect();
4910
4911 let cpu_hvp = cpu_oracle_bms_flex_row_hvp(
4912 &row_hessians,
4913 &marginal,
4914 &logslope,
4915 &block,
4916 &primary,
4917 n,
4918 &v,
4919 );
4920 let cpu_diag = cpu_oracle_bms_flex_row_diagonal(
4921 &row_hessians,
4922 &marginal,
4923 &logslope,
4924 &block,
4925 &primary,
4926 n,
4927 );
4928
4929 let backend = match HvpKernelBackend::probe() {
4930 Ok(b) => b,
4931 Err(err) => {
4932 eprintln!(
4933 "[bms_flex_row hvp parity n64_r20_p44] backend probe \
4934 failed: {err}"
4935 );
4936 return;
4937 }
4938 };
4939 let stream = backend.stream.clone();
4940 let d_h = match stream.clone_htod(&row_hessians) {
4941 Ok(s) => s,
4942 Err(err) => {
4943 eprintln!(
4944 "[bms_flex_row hvp parity n64_r20_p44] upload h \
4945 failed: {err}"
4946 );
4947 return;
4948 }
4949 };
4950 let d_m = match stream.clone_htod(&marginal) {
4951 Ok(s) => s,
4952 Err(err) => {
4953 eprintln!(
4954 "[bms_flex_row hvp parity n64_r20_p44] upload marg \
4955 failed: {err}"
4956 );
4957 return;
4958 }
4959 };
4960 let d_g = match stream.clone_htod(&logslope) {
4961 Ok(s) => s,
4962 Err(err) => {
4963 eprintln!(
4964 "[bms_flex_row hvp parity n64_r20_p44] upload logslope \
4965 failed: {err}"
4966 );
4967 return;
4968 }
4969 };
4970 let storage = DeviceResidentRowHess {
4971 hess: d_h,
4972 marginal_design: d_m,
4973 logslope_design: d_g,
4974 n,
4975 r,
4976 block: block.clone(),
4977 primary: primary.clone(),
4978
4979 bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
4980 };
4981 let gpu_hvp = launch_bms_flex_row_hvp(&storage, &v)
4982 .expect("HVP kernel must launch on CUDA host at n64/r20/p44");
4983 let gpu_diag = launch_bms_flex_row_diagonal(&storage)
4984 .expect("diagonal kernel must launch on CUDA host at n64/r20/p44");
4985 assert_eq!(gpu_hvp.len(), cpu_hvp.len());
4986 assert_eq!(gpu_diag.len(), cpu_diag.len());
4987 for i in 0..p_total {
4988 let diff = (cpu_hvp[i] - gpu_hvp[i]).abs();
4989 assert!(
4990 diff <= 1e-8,
4991 "n64_r20_p44 HVP[{i}]: cpu={} gpu={} |Δ|={diff:.3e}",
4992 cpu_hvp[i],
4993 gpu_hvp[i]
4994 );
4995 let ddiff = (cpu_diag[i] - gpu_diag[i]).abs();
4996 assert!(
4997 ddiff <= 1e-8,
4998 "n64_r20_p44 diag[{i}]: cpu={} gpu={} |Δ|={ddiff:.3e}",
4999 cpu_diag[i],
5000 gpu_diag[i]
5001 );
5002 }
5003 }
5004 }
5005
5006 #[test]
5012 pub(crate) fn bms_flex_row_dense_block_kernel_matches_cpu_pullback() {
5013 #[cfg(not(target_os = "linux"))]
5014 {
5015 eprintln!("[bms_flex_row dense_block parity] non-Linux host — skipping CUDA parity");
5016 }
5017 #[cfg(target_os = "linux")]
5018 {
5019 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5020 eprintln!("[bms_flex_row dense_block parity] no CUDA runtime — skipping");
5021 return;
5022 };
5023 let n = 24_usize;
5027 let p_m = 4_usize;
5028 let p_g = 4_usize;
5029 let p_h_dim = 3_usize;
5030 let p_w_dim = 3_usize;
5031 let r = 2 + p_h_dim + p_w_dim;
5032 let p_total = p_m + p_g + p_h_dim + p_w_dim;
5033 let block = BmsFlexBlockLayout {
5034 p_m,
5035 p_g,
5036 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5037 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5038 p_total,
5039 };
5040 let primary = BmsFlexPrimaryLayout {
5041 h: Some(2..2 + p_h_dim),
5042 w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5043 r,
5044 };
5045
5046 let mut row_hessians = vec![0.0_f64; n * r * r];
5047 for row in 0..n {
5048 let base = row * r * r;
5049 for u in 0..r {
5050 for v in 0..r {
5051 let seed = (row as f64) * 0.21 + (u as f64) * 1.13 + (v as f64) * 0.47;
5052 let a = (seed.sin() * 1.4 + (seed * 0.6).cos() * 0.7) * 0.5;
5053 row_hessians[base + u * r + v] = a;
5054 }
5055 }
5056 for u in 0..r {
5057 for v in (u + 1)..r {
5058 let upper = row_hessians[base + u * r + v];
5059 let lower = row_hessians[base + v * r + u];
5060 let sym = 0.5 * (upper + lower);
5061 row_hessians[base + u * r + v] = sym;
5062 row_hessians[base + v * r + u] = sym;
5063 }
5064 row_hessians[base + u * r + u] += r as f64;
5065 }
5066 }
5067 let mut marginal = vec![0.0_f64; n * p_m];
5068 for row in 0..n {
5069 for j in 0..p_m {
5070 let seed = (row as f64) * 0.083 + (j as f64) * 0.171 + 0.31;
5071 marginal[row * p_m + j] = seed.sin() * 0.7 - (seed * 0.5).cos() * 0.25;
5072 }
5073 }
5074 let mut logslope = vec![0.0_f64; n * p_g];
5075 for row in 0..n {
5076 for j in 0..p_g {
5077 let seed = (row as f64) * 0.097 + (j as f64) * 0.143 - 0.15;
5078 logslope[row * p_g + j] = seed.cos() * 0.65 + (seed * 0.4).sin() * 0.2;
5079 }
5080 }
5081
5082 let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
5084 let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
5085 let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
5086 let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
5087 let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
5088 let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
5089 let mut h_cpu = vec![0.0_f64; p_total * p_total];
5090 for row in 0..n {
5091 let mrow = &marginal[row * p_m..(row + 1) * p_m];
5092 let grow = &logslope[row * p_g..(row + 1) * p_g];
5093 let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
5094 let mut phi = vec![vec![0.0_f64; p_total]; r];
5096 for k in 0..p_m {
5097 phi[0][k] = mrow[k];
5098 }
5099 for k in 0..p_g {
5100 phi[1][p_m + k] = grow[k];
5101 }
5102 for k in 0..h_block_len {
5103 phi[h_primary_start + k][h_block_start + k] = 1.0;
5104 }
5105 for k in 0..w_block_len {
5106 phi[w_primary_start + k][w_block_start + k] = 1.0;
5107 }
5108 for u in 0..r {
5109 for v in 0..r {
5110 let huv = hrow[u * r + v];
5111 if huv == 0.0 {
5112 continue;
5113 }
5114 for m in 0..p_total {
5115 let pm = phi[u][m];
5116 if pm == 0.0 {
5117 continue;
5118 }
5119 let scaled = huv * pm;
5120 for nn in 0..p_total {
5121 h_cpu[m * p_total + nn] += scaled * phi[v][nn];
5122 }
5123 }
5124 }
5125 }
5126 }
5127
5128 let backend = HvpKernelBackend::probe().expect(
5133 "[bms_flex_row dense_block parity] backend probe must succeed on CUDA host",
5134 );
5135 let stream = backend.stream.clone();
5136 let d_h = stream
5137 .clone_htod(&row_hessians)
5138 .expect("[bms_flex_row dense_block parity] upload h must succeed on CUDA host");
5139 let d_m = stream
5140 .clone_htod(&marginal)
5141 .expect("[bms_flex_row dense_block parity] upload marg must succeed on CUDA host");
5142 let d_g = stream.clone_htod(&logslope).expect(
5143 "[bms_flex_row dense_block parity] upload logslope must succeed on CUDA host",
5144 );
5145 let storage = DeviceResidentRowHess {
5146 hess: d_h,
5147 marginal_design: d_m,
5148 logslope_design: d_g,
5149 n,
5150 r,
5151 block: block.clone(),
5152 primary: primary.clone(),
5153
5154 bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5155 };
5156 let h_gpu = launch_bms_flex_row_dense_block(&storage)
5157 .expect("dense_block kernel must launch on CUDA host");
5158 assert_eq!(h_gpu.len(), p_total * p_total);
5159
5160 let mut max_abs = 0.0_f64;
5163 for i in 0..p_total {
5164 for j in 0..p_total {
5165 let a = h_cpu[i * p_total + j];
5166 let b = h_gpu[i * p_total + j];
5167 let diff = (a - b).abs();
5168 if diff > max_abs {
5169 max_abs = diff;
5170 }
5171 assert!(
5172 diff <= 1e-9 * a.abs().max(b.abs()).max(1.0),
5173 "dense_block[{i},{j}]: cpu={a} gpu={b} |Δ|={diff:.3e}"
5174 );
5175 }
5176 }
5177 eprintln!(
5178 "[bms_flex_row dense_block parity] n={n} r={r} p={p_total}: max|Δ|={max_abs:.3e}"
5179 );
5180 }
5181 }
5182
5183 #[test]
5203 pub(crate) fn bms_flex_row_hvp_v100_hill_climb_5x_vs_cpu_at_large_scale() {
5204 #[cfg(not(target_os = "linux"))]
5205 {
5206 eprintln!("[bms_flex_row hvp hill-climb] non-Linux host — skipping V100 perf gate");
5207 }
5208 #[cfg(target_os = "linux")]
5209 {
5210 use rayon::prelude::*;
5211
5212 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5213 eprintln!(
5214 "[bms_flex_row hvp hill-climb] no CUDA runtime — skipping V100 perf gate"
5215 );
5216 return;
5217 };
5218 let n = 195_000_usize;
5219 let p_m = 14_usize;
5220 let p_g = 12_usize;
5221 let p_h_dim = 10_usize;
5222 let p_w_dim = 8_usize;
5223 let r = 2 + p_h_dim + p_w_dim;
5224 let p_total = p_m + p_g + p_h_dim + p_w_dim;
5225 let block = BmsFlexBlockLayout {
5226 p_m,
5227 p_g,
5228 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5229 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5230 p_total,
5231 };
5232 let primary = BmsFlexPrimaryLayout {
5233 h: Some(2..2 + p_h_dim),
5234 w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5235 r,
5236 };
5237
5238 let mut row_hessians = vec![0.0_f64; n * r * r];
5240 for row in 0..n {
5241 let base = row * r * r;
5242 for u in 0..r {
5243 for vv in 0..r {
5244 let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
5245 let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
5246 row_hessians[base + u * r + vv] = a;
5247 }
5248 }
5249 for u in 0..r {
5250 for vv in (u + 1)..r {
5251 let upper = row_hessians[base + u * r + vv];
5252 let lower = row_hessians[base + vv * r + u];
5253 let sym = 0.5 * (upper + lower);
5254 row_hessians[base + u * r + vv] = sym;
5255 row_hessians[base + vv * r + u] = sym;
5256 }
5257 row_hessians[base + u * r + u] += r as f64;
5258 }
5259 }
5260 let mut marginal = vec![0.0_f64; n * p_m];
5261 for row in 0..n {
5262 for j in 0..p_m {
5263 let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
5264 marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
5265 }
5266 }
5267 let mut logslope = vec![0.0_f64; n * p_g];
5268 for row in 0..n {
5269 for j in 0..p_g {
5270 let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
5271 logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
5272 }
5273 }
5274 let v: Vec<f64> = (0..p_total)
5275 .map(|i| {
5276 let seed = (i as f64) * 0.157 + 0.6;
5277 seed.sin() * 0.55 + (seed * 0.4).cos() * 0.35
5278 })
5279 .collect();
5280
5281 let backend = match HvpKernelBackend::probe() {
5283 Ok(b) => b,
5284 Err(err) => {
5285 eprintln!("[bms_flex_row hvp hill-climb] backend probe failed: {err}");
5286 return;
5287 }
5288 };
5289 let stream = backend.stream.clone();
5290 let d_h = match stream.clone_htod(&row_hessians) {
5291 Ok(s) => s,
5292 Err(err) => {
5293 eprintln!("[bms_flex_row hvp hill-climb] upload h failed (likely OOM): {err}");
5294 return;
5295 }
5296 };
5297 let d_m = match stream.clone_htod(&marginal) {
5298 Ok(s) => s,
5299 Err(err) => {
5300 eprintln!("[bms_flex_row hvp hill-climb] upload marg failed: {err}");
5301 return;
5302 }
5303 };
5304 let d_g = match stream.clone_htod(&logslope) {
5305 Ok(s) => s,
5306 Err(err) => {
5307 eprintln!("[bms_flex_row hvp hill-climb] upload logslope failed: {err}");
5308 return;
5309 }
5310 };
5311 let storage = DeviceResidentRowHess {
5312 hess: d_h,
5313 marginal_design: d_m,
5314 logslope_design: d_g,
5315 n,
5316 r,
5317 block: block.clone(),
5318 primary: primary.clone(),
5319
5320 bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5321 };
5322 let warmup: usize = 3;
5323 let iters: usize = 15;
5324 for _ in 0..warmup {
5325 let out =
5326 launch_bms_flex_row_hvp(&storage, &v).expect("warmup GPU HVP must launch");
5327 assert_eq!(out.len(), p_total);
5328 }
5329 let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5330 for _ in 0..iters {
5331 let t0 = std::time::Instant::now();
5332 let out = launch_bms_flex_row_hvp(&storage, &v).expect("GPU HVP must launch");
5333 gpu_us.push(t0.elapsed().as_micros());
5334 assert_eq!(out.len(), p_total);
5335 }
5336 gpu_us.sort_unstable();
5337 let gpu_median = gpu_us[iters / 2];
5338
5339 const CHUNK_ROWS: usize = 4096;
5345 let cpu_hvp_parallel = || -> Vec<f64> {
5346 let nchunks = n.div_ceil(CHUNK_ROWS);
5347 (0..nchunks)
5348 .into_par_iter()
5349 .fold(
5350 || vec![0.0_f64; p_total],
5351 |mut acc, ci| {
5352 let lo = ci * CHUNK_ROWS;
5353 let hi = (lo + CHUNK_ROWS).min(n);
5354 let m = hi - lo;
5355 let partial = cpu_oracle_bms_flex_row_hvp(
5356 &row_hessians[lo * r * r..hi * r * r],
5357 &marginal[lo * p_m..hi * p_m],
5358 &logslope[lo * p_g..hi * p_g],
5359 &block,
5360 &primary,
5361 m,
5362 &v,
5363 );
5364 for (a, &p) in acc.iter_mut().zip(partial.iter()) {
5365 *a += p;
5366 }
5367 acc
5368 },
5369 )
5370 .reduce(
5371 || vec![0.0_f64; p_total],
5372 |mut a, b| {
5373 for (ax, bx) in a.iter_mut().zip(b.iter()) {
5374 *ax += *bx;
5375 }
5376 a
5377 },
5378 )
5379 };
5380 let warm = cpu_hvp_parallel();
5382 assert_eq!(warm.len(), p_total);
5383 let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5384 for _ in 0..iters {
5385 let t0 = std::time::Instant::now();
5386 let out = cpu_hvp_parallel();
5387 cpu_us.push(t0.elapsed().as_micros());
5388 assert_eq!(out.len(), p_total);
5389 }
5390 cpu_us.sort_unstable();
5391 let cpu_median = cpu_us[iters / 2];
5392
5393 let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5394 eprintln!(
5395 "[bms_flex_row hvp hill-climb] large-scale n={n} r={r} p={p_total}: \
5396 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5397 speedup={speedup:.2}× (charter target ≥ 5×)"
5398 );
5399 assert!(
5400 speedup >= 5.0,
5401 "large-scale HVP perf gate: GPU only {speedup:.2}× faster than CPU; \
5402 need ≥ 5× per Block 9 charter (cpu_median={cpu_median}us, \
5403 gpu_median={gpu_median}us). Hill-climb the kernel until met or \
5404 prove the kernel is at hardware roofline."
5405 );
5406 }
5407 }
5408
5409 #[test]
5414 pub(crate) fn bms_flex_row_dense_block_v100_hill_climb_10x_vs_cpu_at_large_scale() {
5415 #[cfg(not(target_os = "linux"))]
5416 {
5417 eprintln!(
5418 "[bms_flex_row dense_block hill-climb] non-Linux host — skipping V100 perf gate"
5419 );
5420 }
5421 #[cfg(target_os = "linux")]
5422 {
5423 use rayon::prelude::*;
5424
5425 let Some(_runtime) = gam_gpu::device_runtime::GpuRuntime::global() else {
5426 eprintln!(
5427 "[bms_flex_row dense_block hill-climb] no CUDA runtime — skipping V100 perf gate"
5428 );
5429 return;
5430 };
5431 let n = 195_000_usize;
5432 let p_m = 14_usize;
5433 let p_g = 12_usize;
5434 let p_h_dim = 10_usize;
5435 let p_w_dim = 8_usize;
5436 let r = 2 + p_h_dim + p_w_dim;
5437 let p_total = p_m + p_g + p_h_dim + p_w_dim;
5438 let block = BmsFlexBlockLayout {
5439 p_m,
5440 p_g,
5441 h: Some(p_m + p_g..p_m + p_g + p_h_dim),
5442 w: Some(p_m + p_g + p_h_dim..p_m + p_g + p_h_dim + p_w_dim),
5443 p_total,
5444 };
5445 let primary = BmsFlexPrimaryLayout {
5446 h: Some(2..2 + p_h_dim),
5447 w: Some(2 + p_h_dim..2 + p_h_dim + p_w_dim),
5448 r,
5449 };
5450
5451 let mut row_hessians = vec![0.0_f64; n * r * r];
5453 for row in 0..n {
5454 let base = row * r * r;
5455 for u in 0..r {
5456 for vv in 0..r {
5457 let seed = (row as f64) * 0.137 + (u as f64) * 1.901 + (vv as f64) * 0.317;
5458 let a = (seed.sin() * 1.7 + (seed * 0.5).cos() * 0.9) * 0.5;
5459 row_hessians[base + u * r + vv] = a;
5460 }
5461 }
5462 for u in 0..r {
5463 for vv in (u + 1)..r {
5464 let upper = row_hessians[base + u * r + vv];
5465 let lower = row_hessians[base + vv * r + u];
5466 let sym = 0.5 * (upper + lower);
5467 row_hessians[base + u * r + vv] = sym;
5468 row_hessians[base + vv * r + u] = sym;
5469 }
5470 row_hessians[base + u * r + u] += r as f64;
5471 }
5472 }
5473 let mut marginal = vec![0.0_f64; n * p_m];
5474 for row in 0..n {
5475 for j in 0..p_m {
5476 let seed = (row as f64) * 0.073 + (j as f64) * 0.211 + 0.4;
5477 marginal[row * p_m + j] = seed.sin() * 0.8 - (seed * 0.7).cos() * 0.3;
5478 }
5479 }
5480 let mut logslope = vec![0.0_f64; n * p_g];
5481 for row in 0..n {
5482 for j in 0..p_g {
5483 let seed = (row as f64) * 0.091 + (j as f64) * 0.179 - 0.2;
5484 logslope[row * p_g + j] = seed.cos() * 0.7 + (seed * 0.3).sin() * 0.25;
5485 }
5486 }
5487
5488 if p_total > DENSE_BLOCK_MAX_P {
5491 eprintln!(
5492 "[bms_flex_row dense_block hill-climb] p_total={p_total} > MAX={DENSE_BLOCK_MAX_P}, skipping"
5493 );
5494 return;
5495 }
5496 let backend = match HvpKernelBackend::probe() {
5497 Ok(b) => b,
5498 Err(err) => {
5499 eprintln!("[bms_flex_row dense_block hill-climb] backend probe failed: {err}");
5500 return;
5501 }
5502 };
5503 let stream = backend.stream.clone();
5504 let d_h = match stream.clone_htod(&row_hessians) {
5505 Ok(s) => s,
5506 Err(err) => {
5507 eprintln!("[bms_flex_row dense_block hill-climb] upload h failed: {err}");
5508 return;
5509 }
5510 };
5511 let d_m = match stream.clone_htod(&marginal) {
5512 Ok(s) => s,
5513 Err(err) => {
5514 eprintln!("[bms_flex_row dense_block hill-climb] upload marg failed: {err}");
5515 return;
5516 }
5517 };
5518 let d_g = match stream.clone_htod(&logslope) {
5519 Ok(s) => s,
5520 Err(err) => {
5521 eprintln!(
5522 "[bms_flex_row dense_block hill-climb] upload logslope failed: {err}"
5523 );
5524 return;
5525 }
5526 };
5527 let storage = DeviceResidentRowHess {
5528 hess: d_h,
5529 marginal_design: d_m,
5530 logslope_design: d_g,
5531 n,
5532 r,
5533 block: block.clone(),
5534 primary: primary.clone(),
5535
5536 bytes: ((n * r * r + n * p_m + n * p_g) * std::mem::size_of::<f64>()) as u64,
5537 };
5538 let warmup: usize = 2;
5540 let iters: usize = 5;
5541 for _ in 0..warmup {
5542 let out = launch_bms_flex_row_dense_block(&storage)
5543 .expect("warmup GPU dense_block must launch");
5544 assert_eq!(out.len(), p_total * p_total);
5545 }
5546 let mut gpu_us: Vec<u128> = Vec::with_capacity(iters);
5547 for _ in 0..iters {
5548 let t0 = std::time::Instant::now();
5549 let out =
5550 launch_bms_flex_row_dense_block(&storage).expect("GPU dense_block must launch");
5551 gpu_us.push(t0.elapsed().as_micros());
5552 assert_eq!(out.len(), p_total * p_total);
5553 }
5554 gpu_us.sort_unstable();
5555 let gpu_median = gpu_us[iters / 2];
5556
5557 const CHUNK_ROWS: usize = 2048;
5560 let h_block_start = block.h.as_ref().map(|r| r.start).unwrap_or(0);
5561 let h_block_len = block.h.as_ref().map(|r| r.len()).unwrap_or(0);
5562 let w_block_start = block.w.as_ref().map(|r| r.start).unwrap_or(0);
5563 let w_block_len = block.w.as_ref().map(|r| r.len()).unwrap_or(0);
5564 let h_primary_start = primary.h.as_ref().map(|r| r.start).unwrap_or(0);
5565 let w_primary_start = primary.w.as_ref().map(|r| r.start).unwrap_or(0);
5566 let cpu_build_parallel = || -> Vec<f64> {
5567 let nchunks = n.div_ceil(CHUNK_ROWS);
5568 (0..nchunks)
5569 .into_par_iter()
5570 .fold(
5571 || vec![0.0_f64; p_total * p_total],
5572 |mut acc, ci| {
5573 let lo = ci * CHUNK_ROWS;
5574 let hi = (lo + CHUNK_ROWS).min(n);
5575 let mut phi: Vec<Vec<f64>> = vec![vec![0.0_f64; p_total]; r];
5576 for row in lo..hi {
5577 for col in phi.iter_mut() {
5578 col.iter_mut().for_each(|v| *v = 0.0);
5579 }
5580 let mrow = &marginal[row * p_m..(row + 1) * p_m];
5581 let grow = &logslope[row * p_g..(row + 1) * p_g];
5582 for k in 0..p_m {
5583 phi[0][k] = mrow[k];
5584 }
5585 for k in 0..p_g {
5586 phi[1][p_m + k] = grow[k];
5587 }
5588 for k in 0..h_block_len {
5589 phi[h_primary_start + k][h_block_start + k] = 1.0;
5590 }
5591 for k in 0..w_block_len {
5592 phi[w_primary_start + k][w_block_start + k] = 1.0;
5593 }
5594 let hrow = &row_hessians[row * r * r..(row + 1) * r * r];
5595 for u in 0..r {
5596 for v_idx in 0..r {
5597 let huv = hrow[u * r + v_idx];
5598 if huv == 0.0 {
5599 continue;
5600 }
5601 for m in 0..p_total {
5602 let pm = phi[u][m];
5603 if pm == 0.0 {
5604 continue;
5605 }
5606 let scaled = huv * pm;
5607 for nn in 0..p_total {
5608 acc[m * p_total + nn] += scaled * phi[v_idx][nn];
5609 }
5610 }
5611 }
5612 }
5613 }
5614 acc
5615 },
5616 )
5617 .reduce(
5618 || vec![0.0_f64; p_total * p_total],
5619 |mut a, b| {
5620 for (ax, bx) in a.iter_mut().zip(b.iter()) {
5621 *ax += *bx;
5622 }
5623 a
5624 },
5625 )
5626 };
5627 let warm_cpu = cpu_build_parallel();
5628 assert_eq!(warm_cpu.len(), p_total * p_total);
5629 let mut cpu_us: Vec<u128> = Vec::with_capacity(iters);
5630 for _ in 0..iters {
5631 let t0 = std::time::Instant::now();
5632 let out = cpu_build_parallel();
5633 cpu_us.push(t0.elapsed().as_micros());
5634 assert_eq!(out.len(), p_total * p_total);
5635 }
5636 cpu_us.sort_unstable();
5637 let cpu_median = cpu_us[iters / 2];
5638
5639 let speedup = (cpu_median as f64) / (gpu_median.max(1) as f64);
5640 eprintln!(
5641 "[bms_flex_row dense_block hill-climb] large-scale n={n} r={r} p={p_total}: \
5642 cpu_median={cpu_median}us gpu_median={gpu_median}us \
5643 speedup={speedup:.2}× (charter target ≥ 10×)"
5644 );
5645 assert!(
5646 speedup >= 10.0,
5647 "large-scale dense-H perf gate: GPU only {speedup:.2}× faster than CPU; \
5648 need ≥ 10× per Block 9 charter (cpu_median={cpu_median}us, \
5649 gpu_median={gpu_median}us). Hill-climb the dense_block kernel \
5650 (warp-stripe the u-v-m-n loop, vectorise loads, etc.) until met \
5651 or prove the kernel is at hardware roofline."
5652 );
5653 }
5654 }
5655}
5656
5657#[cfg(test)]
5667mod parity_415_tests {
5668 use crate::bms::family::*;
5684 use crate::bms::hessian_paths::*;
5685 use crate::bms::{DeviationBlockConfig, LatentMeasureKind, exact_kernel};
5686 use gam_linalg::matrix::{DenseDesignMatrix, DesignMatrix};
5687 use gam_problem::{InverseLink, ParameterBlockState, StandardLink};
5688 use ndarray::{Array1, Array2};
5689 use std::sync::{Arc, Mutex};
5690
5691 fn make_flex_parity_family(n: usize) -> (BernoulliMarginalSlopeFamily, Vec<ParameterBlockState>) {
5697 let score_seed = Array1::linspace(-2.0, 2.0, n.max(6));
5698 let link_seed = Array1::linspace(-1.8, 1.8, n.max(6));
5699 let cfg = DeviationBlockConfig {
5700 num_internal_knots: 3,
5701 ..DeviationBlockConfig::default()
5702 };
5703 let score_prepared = build_score_warp_deviation_block_from_seed(&score_seed, &cfg)
5704 .expect("build score warp block");
5705 let link_prepared = build_link_deviation_block_from_knots_design_seed_and_weights(
5706 &link_seed, &link_seed, &cfg,
5707 )
5708 .expect("build link deviation block");
5709
5710 let y: Array1<f64> =
5712 Array1::from_iter((0..n).map(|i| if (i * 17 + 3) % 7 >= 4 { 1.0 } else { 0.0 }));
5713 let weights: Array1<f64> =
5714 Array1::from_iter((0..n).map(|i| 0.75 + ((i * 11 + 5) % 5) as f64 * 0.05));
5715 let z: Array1<f64> =
5716 Array1::from_iter((0..n).map(|i| -1.7 + 3.4 * (i as f64 + 0.5) / n as f64));
5717 let marginal_x = Array2::from_shape_fn((n, 2), |(i, j)| {
5718 if j == 0 {
5719 1.0
5720 } else {
5721 -0.4 + 0.8 * ((i * 19 + 7) % n) as f64 / n as f64
5722 }
5723 });
5724 let logslope_x = Array2::from_shape_fn((n, 2), |(i, j)| {
5725 if j == 0 {
5726 1.0
5727 } else {
5728 0.3 - 0.6 * ((i * 23 + 11) % n) as f64 / n as f64
5729 }
5730 });
5731
5732 let family = BernoulliMarginalSlopeFamily {
5733 y: Arc::new(y),
5734 weights: Arc::new(weights),
5735 z: Arc::new(z.clone()),
5736 latent_measure: LatentMeasureKind::StandardNormal,
5737 gaussian_frailty_sd: Some(0.15),
5738 base_link: InverseLink::Standard(StandardLink::Probit),
5739 marginal_design: DesignMatrix::Dense(DenseDesignMatrix::from(marginal_x.clone())),
5740 logslope_design: DesignMatrix::Dense(DenseDesignMatrix::from(logslope_x.clone())),
5741 score_warp: Some(score_prepared.runtime.clone()),
5742 link_dev: Some(link_prepared.runtime.clone()),
5743 policy: gam_runtime::resource::ResourcePolicy::default_library(),
5744 cell_moment_lru: Arc::new(exact_kernel::CellMomentLruCache::new(1024)),
5745 cell_moment_cache_stats: Arc::new(exact_kernel::CellMomentCacheStats::default()),
5746 intercept_warm_starts: None,
5747 auto_subsample_phase_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
5748 auto_subsample_last_rho: Arc::new(Mutex::new(None)),
5749 };
5750
5751 let beta_m = Array1::from_vec(vec![0.12, -0.04]);
5752 let beta_g = Array1::from_vec(vec![0.35, 0.03]);
5753 let beta_h = Array1::from_iter(
5754 (0..score_prepared.runtime.basis_dim()).map(|idx| 0.0015 * (idx as f64 + 1.0)),
5755 );
5756 let beta_w = Array1::from_iter(
5757 (0..link_prepared.runtime.basis_dim()).map(|idx| -0.001 * (idx as f64 + 1.0)),
5758 );
5759 let states = vec![
5760 ParameterBlockState {
5761 eta: marginal_x.dot(&beta_m),
5762 beta: beta_m,
5763 },
5764 ParameterBlockState {
5765 eta: logslope_x.dot(&beta_g),
5766 beta: beta_g,
5767 },
5768 ParameterBlockState {
5769 beta: beta_h,
5770 eta: Array1::zeros(z.len()),
5771 },
5772 ParameterBlockState {
5773 beta: beta_w,
5774 eta: Array1::zeros(z.len()),
5775 },
5776 ];
5777 (family, states)
5778 }
5779
#[test]
/// #415 parity lock: the standalone CPU oracle
/// (`oracle_parity_tests::cpu_oracle_outputs`), fed by the packed flex
/// row-kernel inputs, must agree with the family's own
/// `compute_row_analytic_flex_into_with_moments` on every row's negative
/// log-likelihood value, gradient and Hessian, for both y=0 and y=1 rows.
fn cpu_oracle_matches_cpu_family_row_analytic_flex_415() {
    // Mixed tolerance: a pair passes when |oracle - family| <= TOL_ABS + TOL_REL * |oracle|.
    const TOL_ABS: f64 = 1e-9;
    const TOL_REL: f64 = 1e-10;

    /// Asserts that one oracle/family scalar pair lies within the mixed
    /// tolerance, panicking with a `"{what} drift: ..."` message otherwise.
    ///
    /// Returns the scaled discrepancy `|Δ| / max(|oracle|, 1)` so the caller
    /// can report the worst case once every row has been checked.
    fn assert_within_tol(what: std::fmt::Arguments<'_>, reference: f64, candidate: f64) -> f64 {
        let delta = (reference - candidate).abs();
        let tol = TOL_ABS + TOL_REL * reference.abs();
        let scaled = delta / reference.abs().max(1.0);
        assert!(
            delta <= tol,
            "{what} drift: oracle={reference:.17e} family={candidate:.17e} \
             |Δ|={delta:.3e} > tol={tol:.3e}"
        );
        scaled
    }

    let n = 12usize;
    let (family, states) = make_flex_parity_family(n);
    let cache = family
        .build_exact_eval_cache(&states)
        .expect("flex exact eval cache");

    // Both the packer and the two compared evaluation paths consume the
    // row-cell-moments bundle, so the fixture is meaningless without it.
    assert!(
        cache.row_cell_moments.is_some(),
        "#415 fixture must materialise the row-cell-moments bundle; the pack \
         and both compared paths read it"
    );

    let primary = &cache.primary;
    let r = primary.total;
    let p_h = primary.h.as_ref().map_or(0, |range| range.len());
    let p_w = primary.w.as_ref().map_or(0, |range| range.len());
    assert!(p_h > 0 && p_w > 0, "#415 fixture must be full-flex: p_h={p_h} p_w={p_w}");
    assert_eq!(r, 2 + p_h + p_w, "#415 fixture primary layout");

    // Oracle side: pack the kernel inputs once and evaluate every row.
    let owned = family
        .pack_bms_flex_row_kernel_inputs(&states, &cache)
        .expect("pack must not error")
        .expect("pack must succeed for the StandardNormal full-flex fixture");
    let inputs = owned.as_borrowed();
    let oracle = super::oracle_parity_tests::cpu_oracle_outputs(&inputs);
    assert_eq!(oracle.neglog.len(), n);
    assert_eq!(oracle.grad.len(), n * r);
    assert_eq!(oracle.hess.len(), n * r * r);

    let mut scratch = BernoulliMarginalSlopeFlexRowScratch::new(r);
    let mut worst = 0.0_f64;
    let mut seen_labels = [false, false];

    for row in 0..n {
        let row_ctx = BernoulliMarginalSlopeFamily::row_ctx(&cache, row);
        let row_moments = cache
            .row_cell_moments
            .as_ref()
            .and_then(|bundle| bundle.row(row, 9));
        assert!(
            row_moments.is_some(),
            "row {row} must carry degree-9 cell moments (the oracle reads them)"
        );

        // Record which binary labels the fixture actually exercises.
        if let Some(seen) = seen_labels.get_mut(family.y[row] as usize) {
            *seen = true;
        }

        // Family side: evaluate the same row, filling `scratch.grad` / `scratch.hess`.
        let value = family
            .compute_row_analytic_flex_into_with_moments(
                row,
                &states,
                primary,
                row_ctx,
                row_moments,
                cache.cell_family_forest.as_ref(),
                true,
                &mut scratch,
            )
            .expect("cpu family row analytic flex");

        // A NaN on either side must be matched by a NaN on the other; such
        // rows are not compared further.
        let oracle_value = oracle.neglog[row];
        if oracle_value.is_nan() || value.is_nan() {
            assert!(
                oracle_value.is_nan() && value.is_nan(),
                "row {row}: NaN parity broke — oracle={oracle_value} family={value}"
            );
            continue;
        }
        worst = worst.max(assert_within_tol(
            format_args!("row {row} value"),
            oracle_value,
            value,
        ));

        // Oracle gradient is row-major with stride r.
        for u in 0..r {
            worst = worst.max(assert_within_tol(
                format_args!("row {row} grad[{u}]"),
                oracle.grad[row * r + u],
                scratch.grad[u],
            ));
        }

        // Oracle Hessian is row-major with a dense r x r block per row.
        for u in 0..r {
            for v in 0..r {
                worst = worst.max(assert_within_tol(
                    format_args!("row {row} hess[{u},{v}]"),
                    oracle.hess[(row * r + u) * r + v],
                    scratch.hess[[u, v]],
                ));
            }
        }
    }

    assert!(
        seen_labels.iter().all(|&seen| seen),
        "#415 fixture must exercise both y=0 and y=1 rows: {seen_labels:?}"
    );
    eprintln!(
        "#415 parity lock: n={n} r={r} p_h={p_h} p_w={p_w} max_rel(oracle−family)={worst:.3e}"
    );
}
5914}