Stage 2 of the BMS FLEX row kernel — per-row math that turns per-cell
derivative moments (built by Stage 1 in src/gpu/cubic_cell/mod.rs) into a
row gradient and row-primary r × r Hessian.
Math (mirrors the CPU reference
BernoulliMarginalSlope::compute_row_analytic_flex_from_parts_into in
src/families/bernoulli_marginal_slope.rs):
For each row i, with per-cell cubic predictor coefficients
C_c = (C0, C1, C2, C3) and derivative moments m_0..m_9, build
κ = 1 / (2π)
T_n = κ · Σ_{e=0..3} C_e · m_{e+n} (n = 0..6)
D(R) = κ · Σ_{k=0..3} R_k · m_k
Q(R, S) = Σ_{p,q=0..3} R_p · S_q · T_{p+q}
H(R, S, U) = D(U) − Q(R, S)
Per cell c, accumulate into row scratch:
F_a += D(A_c)
F_aa += H(A_c, A_c, AA_c)
F_u += D(R_{c,u}) (u > 0)
F_au += H(A_c, R_{c,u}, AR_{c,u}) (u > 0)
F_uv += H(R_{c,u}, R_{c,v}, S_{c,uv}) (0 < u ≤ v)
After the cell sum, the q-row is overridden:
F_q = −mu_1
F_qq = −mu_2
F_qv = 0 (v > 0)
F_aq = 0
Implicit function theorem (single 1/F_a):
inv_Fa = 1 / F_a
a_u = −F_u · inv_Fa (q-row override: mu_1 · inv_Fa)
a_uv = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) · inv_Fa
Observed predictor at z_obs (host supplies pre-evaluated chi, xi, rho, tau,
r_uv per row and coordinate):
bar_e_u = chi_obs · a_u + rho_u
bar_e_uv = chi_obs · a_uv + xi_obs · a_u · a_v + tau_u · a_v + a_u · tau_v + r_uv
Probit Mills (stable; uses log_ndtr_and_mills from numerics_device::PROBIT_NUMERICS_CU):
s = 2y − 1 ; m = s · e_obs
[log_cdf, λ] = log_ndtr_and_mills(m)
A = −w · s · λ
B = w · λ · (m + λ)
Final outputs:
neglog = −w · log_cdf
g_u = A · bar_e_u
H_{uv} = B · bar_e_u · bar_e_v + A · bar_e_uv (symmetric)
Implementation choice (Stage 2): one CUDA block per row, with
blockDim.x = 32 threads. The block’s F_u, F_au, F_uv, bar_e_u,
bar_e_uv live in shared memory; threads in the block parallelise the
per-cell sums, then a single thread of the block (threadIdx.x == 0) does
the IFT solve, the observed-point assembly, the Mills evaluation, and the
final gradient + Hessian write-out. With the r ≤ MAX_R cap (32) the
shared-memory footprint per block is r + r + r*r + r + r*r doubles
= 2r² + 3r ≤ 2144 doubles (at r = 32) ≈ 17 KB, well below the V100 48 KB per-block
limit. This keeps the implementation simple and avoids per-thread global
scratch (a per-thread r*r scratch arena would be ~2 GB at n=195k, r=20).
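The single-thread tail and the shared-memory arithmetic above can be sketched on the host in Rust. This is a minimal illustration, not the kernel itself: the names RowOut, row_tail and shared_doubles are hypothetical, the row scratch is assumed to be fully accumulated with the q-row override already applied, F_uv and r_uv are assumed to be stored as full symmetric row-major r × r arrays, and log_cdf and lambda are passed in as stand-ins for the output of log_ndtr_and_mills(m).

```rust
/// Per-row outputs: negative log-likelihood, gradient (length r),
/// and row-major r x r Hessian. Hypothetical type for illustration.
pub struct RowOut {
    pub neglog: f64,
    pub g: Vec<f64>,
    pub h: Vec<f64>,
}

/// Doubles of shared memory per block: F_u, F_au, bar_e_u (r each)
/// plus F_uv, bar_e_uv (r * r each), i.e. 2r^2 + 3r.
pub const fn shared_doubles(r: usize) -> usize {
    2 * r * r + 3 * r
}

/// Hypothetical host-side mirror of the thread-0 tail: IFT solve,
/// observed-point assembly, and final gradient + Hessian.
/// Assumes the q-row override is already folded into the scratch inputs.
pub fn row_tail(
    r: usize,
    f_a: f64,
    f_aa: f64,
    f_u: &[f64],
    f_au: &[f64],
    f_uv: &[f64],
    chi: f64,
    xi: f64,
    rho: &[f64],
    tau: &[f64],
    r_uv: &[f64],
    w: f64,
    y: f64,
    e_obs: f64,
    log_cdf: f64,
    lambda: f64,
) -> RowOut {
    let inv_fa = 1.0 / f_a;
    // First-order IFT: a_u = -F_u / F_a.
    let a: Vec<f64> = f_u.iter().map(|&fu| -fu * inv_fa).collect();
    // Probit coefficients from the Mills ratio.
    let s = 2.0 * y - 1.0;
    let m = s * e_obs;
    let coef_a = -w * s * lambda;
    let coef_b = w * lambda * (m + lambda);
    // Observed-point first derivatives: bar_e_u = chi * a_u + rho_u.
    let bar: Vec<f64> = (0..r).map(|u| chi * a[u] + rho[u]).collect();
    let mut h = vec![0.0; r * r];
    for u in 0..r {
        for v in 0..r {
            // Second-order IFT.
            let a_uv = -(f_uv[u * r + v]
                + f_au[u] * a[v]
                + f_au[v] * a[u]
                + f_aa * a[u] * a[v])
                * inv_fa;
            // Observed-point second derivative.
            let bar_uv = chi * a_uv
                + xi * a[u] * a[v]
                + tau[u] * a[v]
                + a[u] * tau[v]
                + r_uv[u * r + v];
            h[u * r + v] = coef_b * bar[u] * bar[v] + coef_a * bar_uv;
        }
    }
    let g: Vec<f64> = bar.iter().map(|&b| coef_a * b).collect();
    RowOut { neglog: -w * log_cdf, g, h }
}
```

With r = 32, shared_doubles returns 2144, which at 8 bytes per double is 17152 bytes and matches the figure quoted above. Because the sketch fills the whole r × r Hessian, symmetric inputs give a symmetric output; a device version would typically compute only u ≤ v and mirror.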
Structs
- DeviceResidentRowHess - Device-resident state produced by [launch_bms_flex_row_kernel_device_resident] and consumed by [launch_bms_flex_row_hvp] / [launch_bms_flex_row_diagonal].
Functions
- bms_flex_row_hvp_multi_scratch_bytes_for_shape - Transient device bytes for a multi-RHS HVP launch, excluding persistent row-Hessian/design storage. Scratch scales with rhs_count * num_chunks * p_total, not rhs_count * n * r * r.
- launch_bms_flex_row_dense_block - Launch the Phase-6 dense joint-Hessian block kernel. Returns the host-side [p_total, p_total] row-major joint H as a Vec<f64> (length p_total²).