Module row

Stage 2 of the BMS FLEX row kernel — per-row math that turns per-cell derivative moments (built by Stage 1 in src/gpu/cubic_cell/mod.rs) into a row gradient and row-primary r × r Hessian.

Math (mirrors the CPU reference BernoulliMarginalSlope::compute_row_analytic_flex_from_parts_into in src/families/bernoulli_marginal_slope.rs):

For each row i, with per-cell cubic predictor coefficients C_c = (C_0, C_1, C_2, C_3) and derivative moments m_0..m_9, build

    κ        = 1 / (2π)
    T_n      = κ · Σ_{e=0..3} C_e · m_{e+n}     (n = 0..6)
    D(R)     = κ · Σ_{k=0..3} R_k · m_k
    Q(R, S)  = Σ_{p,q=0..3} R_p · S_q · T_{p+q}
    H(R, S, U) = D(U) − Q(R, S)
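The per-cell building blocks translate almost line for line into code. The Rust sketch below assumes the cubic coefficients and moments arrive as fixed-size `[f64; 4]` and `[f64; 10]` arrays. The helper names `t_moments`, `d_op`, `q_op` and `h_op` are hypothetical and not part of this module's API.

```rust
use std::f64::consts::PI;

/// κ = 1 / (2π).
const KAPPA: f64 = 1.0 / (2.0 * PI);

/// T_n = κ · Σ_{e=0..3} C_e · m_{e+n} for n = 0..6 (reads moments m_0..m_9).
fn t_moments(c: &[f64; 4], m: &[f64; 10]) -> [f64; 7] {
    let mut t = [0.0; 7];
    for (n, t_n) in t.iter_mut().enumerate() {
        *t_n = KAPPA * (0..4).map(|e| c[e] * m[e + n]).sum::<f64>();
    }
    t
}

/// D(R) = κ · Σ_{k=0..3} R_k · m_k.
fn d_op(r: &[f64; 4], m: &[f64; 10]) -> f64 {
    KAPPA * (0..4).map(|k| r[k] * m[k]).sum::<f64>()
}

/// Q(R, S) = Σ_{p,q=0..3} R_p · S_q · T_{p+q}; p + q never exceeds 6.
fn q_op(r: &[f64; 4], s: &[f64; 4], t: &[f64; 7]) -> f64 {
    let mut acc = 0.0;
    for p in 0..4 {
        for q in 0..4 {
            acc += r[p] * s[q] * t[p + q];
        }
    }
    acc
}

/// H(R, S, U) = D(U) − Q(R, S).
fn h_op(r: &[f64; 4], s: &[f64; 4], u: &[f64; 4], m: &[f64; 10], t: &[f64; 7]) -> f64 {
    d_op(u, m) - q_op(r, s, t)
}
```

Q only ever sees T through the sum p + q. It is therefore symmetric in its two arguments, which is why F_uv needs only 0 < u ≤ v.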

Per cell c, accumulate into row scratch:

    F_a   += D(A_c)
    F_aa  += H(A_c, A_c, AA_c)
    F_u   += D(R_{c,u})                         u > 0
    F_au  += H(A_c, R_{c,u}, AR_{c,u})          u > 0
    F_uv  += H(R_{c,u}, R_{c,v}, S_{c,uv})      0 < u ≤ v
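The loop structure of this accumulation is easy to get wrong: u > 0 everywhere, the upper triangle for F_uv, and a mirror write. The sketch below makes that structure concrete. It assumes that coordinate 0 is the slot skipped by the cell sum, as the u > 0 ranges suggest, and that F_uv is stored densely and row-major. `CellInputs`, `RowScratch` and `add_cell` are hypothetical names. D and H are passed in as closures so that the sketch does not depend on how they are evaluated.

```rust
/// Hypothetical per-cell inputs. Index 0 of `r`, `ar` and `s` is never read,
/// because the cell sum only covers u > 0.
struct CellInputs {
    a: [f64; 4],           // A_c
    aa: [f64; 4],          // AA_c
    r: Vec<[f64; 4]>,      // R_{c,u}
    ar: Vec<[f64; 4]>,     // AR_{c,u}
    s: Vec<Vec<[f64; 4]>>, // S_{c,uv}; only 0 < u <= v is read
}

/// Row scratch; `f_uv` is a dense row-major r x r array filled symmetrically.
struct RowScratch {
    f_a: f64,
    f_aa: f64,
    f_u: Vec<f64>,
    f_au: Vec<f64>,
    f_uv: Vec<f64>,
}

impl RowScratch {
    fn zeros(r: usize) -> Self {
        RowScratch { f_a: 0.0, f_aa: 0.0, f_u: vec![0.0; r], f_au: vec![0.0; r], f_uv: vec![0.0; r * r] }
    }

    /// Adds one cell's contribution; `d` evaluates D(.) and `h` evaluates
    /// H(., ., .) for this cell (they would close over its moments and T_n).
    fn add_cell<D, H>(&mut self, cell: &CellInputs, d: D, h: H)
    where
        D: Fn(&[f64; 4]) -> f64,
        H: Fn(&[f64; 4], &[f64; 4], &[f64; 4]) -> f64,
    {
        let r = self.f_u.len();
        self.f_a += d(&cell.a);
        self.f_aa += h(&cell.a, &cell.a, &cell.aa);
        for u in 1..r {
            self.f_u[u] += d(&cell.r[u]);
            self.f_au[u] += h(&cell.a, &cell.r[u], &cell.ar[u]);
            for v in u..r {
                let val = h(&cell.r[u], &cell.r[v], &cell.s[u][v]);
                self.f_uv[u * r + v] += val;
                if v != u {
                    self.f_uv[v * r + u] += val; // mirror into the lower triangle
                }
            }
        }
    }
}
```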

After the cell sum, the q-row is overridden:

    F_q  = −mu_1
    F_qq = −mu_2
    F_qv = 0   (v > 0)
    F_aq = 0

Implicit function theorem (single 1/F_a):

    inv_Fa = 1 / F_a
    a_u    = −F_u · inv_Fa                       (q-row override: mu_1 · inv_Fa)
    a_uv   = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) · inv_Fa
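Here is a minimal sketch of the IFT step, assuming the row scratch is held in dense row-major slices. `ift_solve` is a hypothetical name. F_a is inverted once and the reciprocal is reused. Once the q-row carries F_q = −mu_1, the general formula already gives a_q = mu_1 · inv_Fa, so the sketch needs no special case for it.

```rust
/// Implicit-function-theorem step with a single reciprocal of F_a.
/// `f_uv` is a dense, symmetric, row-major r x r array; only its upper
/// triangle is read. Returns (a_u, a_uv) with a_uv filled symmetrically.
fn ift_solve(f_a: f64, f_aa: f64, f_u: &[f64], f_au: &[f64], f_uv: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let r = f_u.len();
    let inv_fa = 1.0 / f_a;
    // a_u = −F_u · inv_Fa; with F_q = −mu_1 this yields a_q = mu_1 · inv_Fa.
    let a_u: Vec<f64> = f_u.iter().map(|&f| -f * inv_fa).collect();
    let mut a_uv = vec![0.0; r * r];
    for u in 0..r {
        for v in u..r {
            // a_uv = −(F_uv + F_au·a_v + F_av·a_u + F_aa·a_u·a_v) · inv_Fa
            let val = -(f_uv[u * r + v]
                + f_au[u] * a_u[v]
                + f_au[v] * a_u[u]
                + f_aa * a_u[u] * a_u[v])
                * inv_fa;
            a_uv[u * r + v] = val;
            a_uv[v * r + u] = val;
        }
    }
    (a_u, a_uv)
}
```

As a check, take F(a, θ) = a² − θ_0 − θ_1 at θ = (1, 3). Its root is a = √(θ_0 + θ_1) = 2, so a_u = 1/4 and a_uv = −1/32 for every pair.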

Observed predictor at z_obs (host supplies pre-evaluated chi, xi, rho, tau, r_uv per row and coordinate):

    bar_e_u  = chi_obs · a_u + rho_u
    bar_e_uv = chi_obs · a_uv + xi_obs · a_u · a_v + tau_u · a_v
               + a_u · tau_v + r_uv
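Once a_u and a_uv are available, the observed-point assembly is a direct transcription. The sketch below assumes the host-supplied chi_obs and xi_obs are scalars, rho_u and tau_u are length-r slices, and r_uv is dense row-major. `observed_predictor` is a hypothetical name.

```rust
/// bar_e_u  = chi·a_u + rho_u
/// bar_e_uv = chi·a_uv + xi·a_u·a_v + tau_u·a_v + a_u·tau_v + r_uv
/// `r_uv`, `a_uv` and the returned `bar_e_uv` are dense row-major r x r arrays.
fn observed_predictor(
    chi: f64,
    xi: f64,
    rho: &[f64],
    tau: &[f64],
    r_uv: &[f64],
    a_u: &[f64],
    a_uv: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    let r = a_u.len();
    let bar_e_u: Vec<f64> = (0..r).map(|u| chi * a_u[u] + rho[u]).collect();
    let mut bar_e_uv = vec![0.0; r * r];
    for u in 0..r {
        for v in 0..r {
            bar_e_uv[u * r + v] = chi * a_uv[u * r + v]
                + xi * a_u[u] * a_u[v]
                + tau[u] * a_u[v]
                + a_u[u] * tau[v]
                + r_uv[u * r + v];
        }
    }
    (bar_e_u, bar_e_uv)
}
```

The two tau terms enter as a symmetric pair. As a result, bar_e_uv is symmetric whenever a_uv and r_uv are.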

Probit Mills (stable; uses log_ndtr_and_mills from numerics_device::PROBIT_NUMERICS_CU):

    s = 2y − 1 ;  m = s · e_obs
    [log_cdf, λ] = log_ndtr_and_mills(m)
    A = −w · s · λ
    B =  w · λ · (m + λ)
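The device routine `log_ndtr_and_mills` is not reproduced here. To illustrate why a combined log-CDF and Mills-ratio evaluation stays stable in the lower tail, the sketch below swaps in a different technique: the Numerical Recipes `erfcc` fit, whose fractional error is below about 1.2e-7. Because that fit has the form erfc(z) ≈ t·exp(−z² + P(t)), it yields log Φ and λ without underflow. `log_ndtr_and_mills_sketch` and `probit_row_weights` are hypothetical names. This is an illustration, not the module's numerics.

```rust
use std::f64::consts::{LN_2, PI, SQRT_2};

/// Numerical Recipes `erfcc` fit: for z >= 0, erfc(z) ≈ t·exp(−z² + P(t)),
/// t = 1/(1 + z/2), fractional error below about 1.2e-7. Returns (ln t, P(t)).
fn erfcc_log_parts(z: f64) -> (f64, f64) {
    let t = 1.0 / (1.0 + 0.5 * z);
    let p = -1.26551223
        + t * (1.00002368
        + t * (0.37409196
        + t * (0.09678418
        + t * (-0.18628806
        + t * (0.27886807
        + t * (-1.13520398
        + t * (1.48851587
        + t * (-0.82215223 + t * 0.17087277))))))));
    (t.ln(), p)
}

/// Stand-in for `log_ndtr_and_mills`: returns (log Φ(m), λ = φ(m)/Φ(m)).
fn log_ndtr_and_mills_sketch(m: f64) -> (f64, f64) {
    let inv_sqrt_2pi = 1.0 / (2.0 * PI).sqrt();
    if m < 0.0 {
        // Φ(m) = erfc(z)/2 with z = −m/√2; stay in logs to avoid underflow.
        let z = -m / SQRT_2;
        let (ln_t, p) = erfcc_log_parts(z);
        let log_cdf = ln_t - z * z + p - LN_2;
        // φ(m) carries exp(−m²/2) = exp(−z²), which cancels erfc's exp(−z²).
        let lambda = 2.0 * inv_sqrt_2pi * (-ln_t - p).exp();
        (log_cdf, lambda)
    } else {
        // Φ(m) = 1 − erfc(z)/2 with z = m/√2; ln_1p keeps log Φ accurate near 0.
        let z = m / SQRT_2;
        let (ln_t, p) = erfcc_log_parts(z);
        let tail = 0.5 * (ln_t - z * z + p).exp();
        let lambda = inv_sqrt_2pi * (-0.5 * m * m).exp() / (1.0 - tail);
        ((-tail).ln_1p(), lambda)
    }
}

/// Probit weights for one row: returns (neglog, A, B) as defined above.
fn probit_row_weights(y: f64, w: f64, e_obs: f64) -> (f64, f64, f64) {
    let s = 2.0 * y - 1.0;
    let m = s * e_obs;
    let (log_cdf, lambda) = log_ndtr_and_mills_sketch(m);
    let a = -w * s * lambda;
    let b = w * lambda * (m + lambda);
    (-w * log_cdf, a, b)
}
```

A and B are the first and second derivatives of −w · log Φ(s · e_obs) with respect to e_obs. They rely on dλ/dm = −λ(m + λ), which a finite-difference check on the sketch reproduces.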

Final outputs:

    neglog   = −w · log_cdf
    g_u      = A · bar_e_u
    H_{uv}   = B · bar_e_u · bar_e_v + A · bar_e_uv     (symmetric)
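The final assembly can be sketched as follows, again assuming dense row-major storage. `row_outputs` is a hypothetical name. Only the upper triangle is computed, and it is then mirrored, since H_{uv} is symmetric.

```rust
/// neglog = −w·log_cdf, g_u = A·bar_e_u, and the symmetric
/// H_uv = B·bar_e_u·bar_e_v + A·bar_e_uv (dense row-major r x r).
fn row_outputs(
    w: f64,
    log_cdf: f64,
    a: f64,
    b: f64,
    bar_e_u: &[f64],
    bar_e_uv: &[f64],
) -> (f64, Vec<f64>, Vec<f64>) {
    let r = bar_e_u.len();
    let grad: Vec<f64> = bar_e_u.iter().map(|&e| a * e).collect();
    let mut hess = vec![0.0; r * r];
    for u in 0..r {
        for v in u..r {
            let val = b * bar_e_u[u] * bar_e_u[v] + a * bar_e_uv[u * r + v];
            hess[u * r + v] = val;
            hess[v * r + u] = val; // mirror: only the upper triangle is computed
        }
    }
    (-w * log_cdf, grad, hess)
}
```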

Implementation choice (Stage 2): one CUDA block per row, with blockDim.x = 32 threads. The block's F_u, F_au, F_uv, bar_e_u and bar_e_uv live in shared memory. The threads of the block parallelise the per-cell sums. Then a single thread (threadIdx.x == 0) performs the IFT solve, the observed-point assembly, the Mills evaluation, and the final gradient and Hessian write-out.

With r capped at MAX_R = 32, the shared-memory footprint per block is r (F_u) + r (F_au) + r² (F_uv) + r (bar_e_u) + r² (bar_e_uv) = 2r² + 3r doubles. That is at most 2144 doubles (17,152 bytes, about 17 KB), well below the V100's 48 KB per-block limit. This keeps the implementation simple and avoids per-thread global scratch (a per-thread r*r scratch arena would be ~2 GB at n=195k, r=20).
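The shared-memory budget can be checked at compile time. The sketch below mirrors MAX_R locally as 32 and assumes 48 KiB as the per-block static limit; both constants are stand-ins, not values imported from the module.

```rust
/// Doubles per block in shared memory: F_u, F_au and bar_e_u (r each)
/// plus F_uv and bar_e_uv (r x r each).
const fn shared_doubles(r: usize) -> usize {
    3 * r + 2 * r * r
}

const fn shared_bytes(r: usize) -> usize {
    shared_doubles(r) * std::mem::size_of::<f64>()
}

/// Assumed per-block static shared-memory limit (48 KiB) and local copy of the r cap.
const PER_BLOCK_LIMIT_BYTES: usize = 48 * 1024;
const MAX_R: usize = 32;

// Fails to compile if the worst-case footprint ever exceeds the limit.
const _: () = assert!(shared_bytes(MAX_R) <= PER_BLOCK_LIMIT_BYTES);
```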

Structs

DeviceResidentRowHess
Device-resident state produced by [launch_bms_flex_row_kernel_device_resident] and consumed by [launch_bms_flex_row_hvp] / [launch_bms_flex_row_diagonal].

Functions

bms_flex_row_hvp_multi_scratch_bytes_for_shape
Transient device bytes for a multi-RHS HVP launch, excluding persistent row-Hessian/design storage. Scratch scales with rhs_count * num_chunks * p_total, not rhs_count * n * r * r.
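The scaling claim can be made concrete as element counts. The sketch below compares the two expressions from the description. Both helper names are hypothetical. The real byte count may add terms not modelled here, and the num_chunks and p_total values in the usage line are purely illustrative.

```rust
/// Element count for scratch sized as rhs_count · num_chunks · p_total.
fn chunked_scratch_elems(rhs_count: usize, num_chunks: usize, p_total: usize) -> usize {
    rhs_count * num_chunks * p_total
}

/// Element count if every row's r x r Hessian were materialised per RHS.
fn per_row_scratch_elems(rhs_count: usize, n: usize, r: usize) -> usize {
    rhs_count * n * r * r
}
```

For example, with rhs_count = 4, n = 195 000, r = 20 and illustrative values num_chunks = 64, p_total = 1000, the chunked form needs 256 000 elements against 312 000 000 for the per-row form.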
launch_bms_flex_row_dense_block
Launch the Phase-6 dense joint-Hessian block kernel. Returns the host-side [p_total, p_total] row-major joint H as a Vec<f64> (length p_total²).
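The returned `Vec<f64>` is a flat row-major [p_total, p_total] array, so entry (i, j) sits at index i * p_total + j. The hypothetical helpers below show that indexing, plus a cheap symmetry check, since a Hessian should be symmetric up to rounding.

```rust
/// Entry (i, j) of a row-major p_total x p_total matrix stored as a flat slice.
fn joint_h_at(h: &[f64], p_total: usize, i: usize, j: usize) -> f64 {
    debug_assert_eq!(h.len(), p_total * p_total);
    h[i * p_total + j]
}

/// Largest |H_ij − H_ji| over all off-diagonal pairs.
fn max_asymmetry(h: &[f64], p_total: usize) -> f64 {
    let mut worst = 0.0_f64;
    for i in 0..p_total {
        for j in (i + 1)..p_total {
            let diff = (joint_h_at(h, p_total, i, j) - joint_h_at(h, p_total, j, i)).abs();
            worst = worst.max(diff);
        }
    }
    worst
}
```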