Skip to main content

gam_solve/gpu_kernels/
pirls_row.rs

1//! Generic GPU PIRLS row-reweight primitives.
2//!
3//! Stage 1 of the device-resident PIRLS port: for every row `i` and a fixed
4//! exponential-family `(family, link)` pair, evaluate the working IRLS state
5//! on the GPU and write it back into device-resident output buffers. The
6//! Hessian/gradient assembly that consumes those buffers (Stage 2) and the
7//! full device-resident PIRLS iteration loop (Stage 3) plug in on top of this
8//! contract without touching the per-family math.
9//!
10//! ## Output contract (per row `i`)
11//!
12//! | Field        | Meaning                                                    |
13//! |--------------|------------------------------------------------------------|
14//! | `mu`         | Inverse-link mean μ_i = g⁻¹(η_i).                          |
//! | `grad_eta`   | Score wrt η: ∂ℓ/∂ηᵢ = priorweight·(yᵢ − μᵢ)·h'(ηᵢ)/V(μᵢ).  |
//! |              | Reduces to priorweight·(yᵢ − μᵢ) for the canonical links   |
//! |              | (Bernoulli-logit, Poisson-log, Gaussian-identity) and to   |
//! |              | priorweight·(yᵢ − μᵢ)/μᵢ for Gamma-log.                     |
19//! | `w_fisher`   | Fisher expected weight (priorweight · h'(η)² / V(μ)).      |
20//! |              | Used for inference (Var(β̂)).                              |
21//! | `w_hessian`  | Curvature weight for the Newton/Laplace Hessian.            |
22//! |              | == w_fisher on canonical links; observed correction on     |
23//! |              | non-canonical Bernoulli + Gamma-log (Stage 5 populates).   |
24//! | `w_solver`   | Stabilised w_hessian used for Cholesky factorisation       |
25//! |              | (floored away from 0 to keep the factor numerically PD).   |
26//! | `z_fisher`   | Working response computed against w_fisher; the legacy     |
27//! |              | "X' W (z − η) − S β" RHS uses this for the score side.     |
28//! | `z_hessian`  | Working response computed against w_hessian (matches       |
29//! |              | w_solver in Stage 5 observed-curvature mode).               |
30//! | `deviance`   | Per-row deviance contribution dᵢ; sum aggregated host-side  |
31//! |              | for line search and convergence checks.                    |
32//! | `status`     | Bitmask of per-row diagnostic flags (η clamped, μ floored, |
33//! |              | non-smooth, y validation failure). OR-reduced on host.     |
34//!
35//! Two-weight discipline is structural in the contract: the gradient is
36//! emitted **directly** (never reconstructed from `w_hessian · (z_hessian − η)`
37//! because in saturated tails that product carries catastrophic cancellation).
38//!
39//! ## Per-family source layout
40//!
//! Each `(family, link, curvature_mode)` triple has its own specialised CUDA
//! source, emitted in three kernel flavours (final-row, solve-row and
//! alpha-ladder), each compiled to a dedicated module and cached. The kernel
//! never branches on a runtime `family` enum — that pattern collapses ILP and
//! forces the compiler to keep dead paths warm. The module cache is keyed by
//! `(family, curvature_mode, kernel_mode)`, so a process compiles each kernel
//! once (barring concurrent first use) and reuses it across all fits.
47
48use std::sync::OnceLock;
49
50use gam_gpu::gpu_error::GpuError;
51#[cfg(target_os = "linux")]
52use gam_gpu::gpu_error::GpuResultExt;
53
54#[cfg(target_os = "linux")]
55use std::sync::{Arc, Mutex};
56
57#[cfg(target_os = "linux")]
58use cudarc::driver::{CudaContext, CudaModule};
59
60// ────────────────────────────────────────────────────────────────────────
61// Public selectors
62// ────────────────────────────────────────────────────────────────────────
63
/// Which built-in `(response, link)` PIRLS family the row kernel evaluates.
///
/// Each variant corresponds to one specialised CUDA source per kernel mode,
/// and therefore to one cached module. Custom families come in Stage 6 via
/// NVRTC JIT (Level A / Level B) and reuse the same host harness.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PirlsRowFamily {
    /// Bernoulli response, logit link (canonical).
    BernoulliLogit,
    /// Bernoulli response, probit link (non-canonical).
    BernoulliProbit,
    /// Bernoulli response, complementary log-log link (non-canonical).
    BernoulliCLogLog,
    /// Poisson response, log link (canonical).
    PoissonLog,
    /// Gaussian response, identity link (canonical).
    GaussianIdentity,
    /// Gamma response, log link (non-canonical; see
    /// [`PirlsRowFamily::is_canonical`]).
    GammaLog,
}

impl PirlsRowFamily {
    /// Every built-in family, in declaration order.
    pub const ALL: [Self; 6] = [
        Self::BernoulliLogit,
        Self::BernoulliProbit,
        Self::BernoulliCLogLog,
        Self::PoissonLog,
        Self::GaussianIdentity,
        Self::GammaLog,
    ];

    /// Human-readable `"<response>-<link>"` label, used in diagnostics.
    pub const fn as_str(self) -> &'static str {
        let (label, _, _, _) = self.symbol_table();
        label
    }

    /// CUDA `extern "C"` entry symbol for this family's full (final-row) kernel.
    pub const fn kernel_name(self) -> &'static str {
        let (_, final_row, _, _) = self.symbol_table();
        final_row
    }

    /// CUDA `extern "C"` entry symbol for this family's solve-row kernel
    /// (writes only `grad_eta`, `w_solver`, `deviance`, `status`).
    pub const fn solve_kernel_name(self) -> &'static str {
        let (_, _, solve_row, _) = self.symbol_table();
        solve_row
    }

    /// CUDA `extern "C"` entry symbol for this family's alpha-ladder kernel
    /// (evaluates every step size in one launch and writes `objective[]` and
    /// `status[]` per alpha slot).
    pub const fn ladder_kernel_name(self) -> &'static str {
        let (_, _, _, ladder) = self.symbol_table();
        ladder
    }

    /// True for `(response, canonical-link)` pairs. For these, the observed
    /// information equals the Fisher information exactly, so `w_hessian ==
    /// w_fisher` in both curvature modes.
    ///
    /// Gamma-LOG is **non-canonical**: the canonical Gamma link is the
    /// reciprocal 1/μ, not log. Under a log link the observed Hessian weight
    /// is `w_F · y/μ`. Its ratio to the Fisher weight, y/μ, does not depend
    /// on the shape, but it differs from 1 whenever `y ≠ μ`. As a result,
    /// `CurvatureMode::Observed` produces a different `w_hessian` for Gamma-log,
    /// and it must not be short-circuited via a canonical-family check.
    pub const fn is_canonical(self) -> bool {
        match self {
            Self::GammaLog | Self::BernoulliProbit | Self::BernoulliCLogLog => false,
            Self::BernoulliLogit | Self::PoissonLog | Self::GaussianIdentity => true,
        }
    }

    /// Single exhaustive table of every per-family name:
    /// `(label, final-row symbol, solve-row symbol, ladder symbol)`.
    ///
    /// All four names live in one `match`, so a new variant cannot be added
    /// without supplying each of them.
    const fn symbol_table(self) -> (&'static str, &'static str, &'static str, &'static str) {
        match self {
            Self::BernoulliLogit => (
                "bernoulli-logit",
                "pirls_row_bernoulli_logit",
                "pirls_solve_bernoulli_logit",
                "pirls_ladder_bernoulli_logit",
            ),
            Self::BernoulliProbit => (
                "bernoulli-probit",
                "pirls_row_bernoulli_probit",
                "pirls_solve_bernoulli_probit",
                "pirls_ladder_bernoulli_probit",
            ),
            Self::BernoulliCLogLog => (
                "bernoulli-cloglog",
                "pirls_row_bernoulli_cloglog",
                "pirls_solve_bernoulli_cloglog",
                "pirls_ladder_bernoulli_cloglog",
            ),
            Self::PoissonLog => (
                "poisson-log",
                "pirls_row_poisson_log",
                "pirls_solve_poisson_log",
                "pirls_ladder_poisson_log",
            ),
            Self::GaussianIdentity => (
                "gaussian-identity",
                "pirls_row_gaussian_identity",
                "pirls_solve_gaussian_identity",
                "pirls_ladder_gaussian_identity",
            ),
            Self::GammaLog => (
                "gamma-log",
                "pirls_row_gamma_log",
                "pirls_solve_gamma_log",
                "pirls_ladder_gamma_log",
            ),
        }
    }
}
157
/// Curvature surface used to populate `w_hessian` / `w_solver` / `z_hessian`.
///
/// `Fisher` is the default surface and matches the CPU Stage-1 path
/// bit-for-bit. `Observed` is populated by Stage 5 for non-canonical
/// Bernoulli and Gamma-log fits. In those fits the negative-log-likelihood
/// Hessian uses the observed information rather than the Fisher expected
/// information.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CurvatureMode {
    /// Expected (Fisher) information.
    Fisher,
    /// Observed information (negative log-likelihood Hessian in η).
    Observed,
}

impl CurvatureMode {
    /// Lower-case label woven into diagnostics and error messages.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Observed => "observed",
            Self::Fisher => "fisher",
        }
    }
}
178
/// Per-row diagnostic flag bits written to `status` and OR-reduced on the host.
///
/// Each flag occupies its own bit, so rows combine losslessly with `|`.
pub mod status_flags {
    /// η lay outside the symmetric clamp window and was clamped.
    pub const ETA_CLAMPED: u32 = 0x01;
    /// μ hit its family floor (or, for Bernoulli links, the mirrored ceiling).
    pub const MU_FLOORED: u32 = 0x02;
    /// Non-smooth Bernoulli diagnostic.
    pub const NONSMOOTH_BERNOULLI: u32 = 0x04;
    /// The response is non-finite or outside the family's support.
    pub const INVALID_RESPONSE: u32 = 0x08;
    /// The prior weight was ≤ 0.
    pub const ZERO_PRIOR_WEIGHT: u32 = 0x10;
}
187
188// ────────────────────────────────────────────────────────────────────────
189// Reference CPU evaluator (parity gate against the GPU kernel).
190//
191// These functions reproduce, byte-for-byte in f64, the formulas in
192// `src/solver/pirls.rs`'s `update_glmvectors` / `write_poisson_log_working_state`
193// / `write_gamma_log_working_state` / `write_identityworking_state`. Stage 1
194// parity tests compare the GPU buffers to these on the V100; mac builds
195// exercise only the CPU reference (the GPU launcher returns
196// `DriverLibraryUnavailable` without a CUDA runtime).
197// ────────────────────────────────────────────────────────────────────────
198
/// Scalar inputs for one row of the PIRLS reweight.
#[derive(Clone, Copy, Debug)]
pub struct RowInput {
    /// Linear predictor ηᵢ.
    pub eta: f64,
    /// Observed response yᵢ.
    pub y: f64,
    /// Prior (case) weight; the evaluators treat negative values as 0.
    pub prior_weight: f64,
}
206
/// Per-row outputs matching the GPU kernel contract described at module level.
#[derive(Clone, Copy, Debug, Default)]
pub struct RowOutput {
    /// Inverse-link mean μᵢ (after any family floor).
    pub mu: f64,
    /// Score with respect to η, emitted directly rather than reconstructed.
    pub grad_eta: f64,
    /// Fisher expected weight.
    pub w_fisher: f64,
    /// Curvature weight for the selected [`CurvatureMode`].
    pub w_hessian: f64,
    /// `w_hessian` floored away from zero when positive, 0 otherwise.
    pub w_solver: f64,
    /// Working response paired with `w_fisher`.
    pub z_fisher: f64,
    /// Working response paired with `w_hessian`.
    pub z_hessian: f64,
    /// Per-row deviance contribution.
    pub deviance: f64,
    /// Bitmask of `status_flags` bits.
    pub status: u32,
}
220
/// Symmetric clamp applied to η before any inverse link is evaluated.
/// Keeps `exp(η)` finite (e^700 ≈ 1.01e304 < `f64::MAX`).
const ETA_CLAMP: f64 = 700.0;
/// Lower floor on μ for Poisson-log; keeps `ln(y/μ)` and `(y − μ)/μ` finite.
const MU_FLOOR_POISSON: f64 = 1.0e-10;
/// Lower floor on μ for Gamma-log; keeps `ln(y/μ)` and `(y − μ)/μ` finite.
const MU_FLOOR_GAMMA: f64 = 1.0e-10;
/// Floor on μ and on 1 − μ for every Bernoulli link (μ ∈ [ε, 1 − ε]).
const MU_FLOOR_BERNOULLI: f64 = 1.0e-12;
/// Smallest value a strictly positive solver weight may take. Zero or
/// negative weights map to exactly 0. Poisson-log also applies this floor to
/// its positive Fisher weight.
const W_SOLVER_FLOOR: f64 = 1.0e-12;
/// Exclusive lower bound on dμ/dη for the Bernoulli working response.
/// `bernoulli_z` computes `z = η + (y − μ)/(dμ/dη)` only when dμ/dη is
/// finite and strictly greater than this value, and falls back to `z = η`
/// otherwise. This matches `bernoulli_exact_working_response`: the fallback
/// happens only when dμ/dη is non-finite or ≤ 0.
const DMU_DETA_MIN: f64 = 0.0;
230
231#[inline]
232fn clamp_eta(eta: f64) -> (f64, bool) {
233    if eta > ETA_CLAMP {
234        (ETA_CLAMP, true)
235    } else if eta < -ETA_CLAMP {
236        (-ETA_CLAMP, true)
237    } else {
238        (eta, false)
239    }
240}
241
242/// Reference CPU evaluator for one row. `mode` selects `w_hessian` curvature.
243///
244/// `gamma_shape` is the Gamma dispersion shape parameter (α > 0). It is only
245/// used when `family == GammaLog`; all other families ignore it. Pass `1.0`
246/// for non-Gamma fits.
247pub fn row_reweight_cpu(
248    family: PirlsRowFamily,
249    mode: CurvatureMode,
250    input: RowInput,
251    gamma_shape: f64,
252) -> RowOutput {
253    match family {
254        PirlsRowFamily::GaussianIdentity => row_gaussian_identity(input, mode),
255        PirlsRowFamily::PoissonLog => row_poisson_log(input, mode),
256        PirlsRowFamily::GammaLog => row_gamma_log(input, mode, gamma_shape),
257        PirlsRowFamily::BernoulliLogit => row_bernoulli_logit(input, mode),
258        PirlsRowFamily::BernoulliProbit => row_bernoulli_probit(input, mode),
259        PirlsRowFamily::BernoulliCLogLog => row_bernoulli_cloglog(input, mode),
260    }
261}
262
263/// Resolve `(w_fisher, observed_correction)` into the `w_hessian` value that
264/// matches the selected curvature surface. Stage 1 returns `w_fisher` for both
265/// modes (parity with the CPU PIRLS path that, today, uses Fisher weights
266/// even for non-canonical links); Stage 5 will switch the `Observed` arm to
267/// `w_fisher + observed_correction` and the call sites stay unchanged.
268#[inline]
269fn select_w_hessian(mode: CurvatureMode, w_fisher: f64, observed_correction: f64) -> f64 {
270    match mode {
271        CurvatureMode::Fisher => w_fisher,
272        CurvatureMode::Observed => w_fisher + observed_correction,
273    }
274}
275
276#[inline]
277fn row_gaussian_identity(input: RowInput, mode: CurvatureMode) -> RowOutput {
278    let w = input.prior_weight.max(0.0);
279    let mu = input.eta;
280    let resid = input.y - mu;
281    let dev = w * resid * resid;
282    let status = if input.prior_weight <= 0.0 {
283        status_flags::ZERO_PRIOR_WEIGHT
284    } else {
285        0
286    };
287    // Identity link: Var(Y) is constant in η, so observed == Fisher exactly.
288    let w_hessian = select_w_hessian(mode, w, 0.0);
289    RowOutput {
290        mu,
291        grad_eta: w * resid,
292        w_fisher: w,
293        w_hessian,
294        w_solver: if w_hessian > 0.0 {
295            w_hessian.max(W_SOLVER_FLOOR)
296        } else {
297            0.0
298        },
299        z_fisher: input.y,
300        z_hessian: input.y,
301        deviance: dev,
302        status,
303    }
304}
305
306#[inline]
307fn row_poisson_log(input: RowInput, mode: CurvatureMode) -> RowOutput {
308    let (eta_c, clamped) = clamp_eta(input.eta);
309    let mu_raw = eta_c.exp();
310    let mu_floored = mu_raw < MU_FLOOR_POISSON;
311    let mu = mu_raw.max(MU_FLOOR_POISSON);
312    let w_prior = input.prior_weight.max(0.0);
313    let raw_w = w_prior * mu;
314    let w_fisher = if raw_w > 0.0 {
315        raw_w.max(W_SOLVER_FLOOR)
316    } else {
317        0.0
318    };
319    let resid = input.y - mu;
320    // Saturated Poisson deviance: 2 w [y log(y/μ) − (y − μ)], with y log y ≡ 0
321    // when y = 0. The branch matches the reference CPU implementation.
322    let dev_term = if input.y > 0.0 {
323        input.y * (input.y / mu).ln() - resid
324    } else {
325        -resid
326    };
327    let dev = 2.0 * w_prior * dev_term;
328    let z = eta_c + resid / mu;
329    let mut status = 0u32;
330    if clamped {
331        status |= status_flags::ETA_CLAMPED;
332    }
333    if mu_floored {
334        status |= status_flags::MU_FLOORED;
335    }
336    if input.prior_weight <= 0.0 {
337        status |= status_flags::ZERO_PRIOR_WEIGHT;
338    }
339    if !(input.y.is_finite() && input.y >= 0.0) {
340        status |= status_flags::INVALID_RESPONSE;
341    }
342    // Canonical log link: observed == Fisher (∂²ℓ/∂η² is deterministic in η).
343    let w_hessian = select_w_hessian(mode, w_fisher, 0.0);
344    RowOutput {
345        mu,
346        grad_eta: w_prior * resid,
347        w_fisher,
348        w_hessian,
349        w_solver: w_hessian,
350        z_fisher: z,
351        z_hessian: z,
352        deviance: dev,
353        status,
354    }
355}
356
357#[inline]
358fn row_gamma_log(input: RowInput, mode: CurvatureMode, shape: f64) -> RowOutput {
359    let (eta_c, clamped) = clamp_eta(input.eta);
360    let mu_raw = eta_c.exp();
361    let mu_floored = mu_raw < MU_FLOOR_GAMMA;
362    let mu = mu_raw.max(MU_FLOOR_GAMMA);
363    let w_prior = input.prior_weight.max(0.0);
364    let w_fisher = w_prior * shape;
365    // Stage 5: observed-information weight for Gamma-log.
366    //   -∂²ℓ/∂η² = α · y/μ  (vs Fisher: α).
367    // Correction = w_F · (y/μ − 1). Falls back to Fisher when y == μ
368    // (e.g. saturated y) and when w_F == 0.
369    let obs_correction = if w_fisher > 0.0 && mu > 0.0 && input.y.is_finite() {
370        w_fisher * (input.y / mu - 1.0)
371    } else {
372        0.0
373    };
374    let w_hessian = select_w_hessian(mode, w_fisher, obs_correction);
375    let resid = input.y - mu;
376    // Saturated Gamma deviance: 2 w [−log(y/μ) + (y − μ)/μ].
377    let dev = if input.y > 0.0 {
378        2.0 * w_prior * (-((input.y / mu).ln()) + resid / mu)
379    } else {
380        // y == 0 has zero density under Gamma; carry as +inf deviance via the
381        // INVALID_RESPONSE flag rather than producing a finite spurious value.
382        f64::INFINITY
383    };
384    let z = eta_c + resid / mu;
385    let mut status = 0u32;
386    if clamped {
387        status |= status_flags::ETA_CLAMPED;
388    }
389    if mu_floored {
390        status |= status_flags::MU_FLOORED;
391    }
392    if input.prior_weight <= 0.0 {
393        status |= status_flags::ZERO_PRIOR_WEIGHT;
394    }
395    if !(input.y.is_finite() && input.y > 0.0) {
396        status |= status_flags::INVALID_RESPONSE;
397    }
398    RowOutput {
399        mu,
400        grad_eta: w_prior * resid / mu,
401        w_fisher,
402        w_hessian,
403        w_solver: if w_hessian > 0.0 {
404            w_hessian.max(W_SOLVER_FLOOR)
405        } else {
406            0.0
407        },
408        z_fisher: z,
409        z_hessian: z,
410        deviance: dev,
411        status,
412    }
413}
414
415#[inline]
416fn row_bernoulli_logit(input: RowInput, mode: CurvatureMode) -> RowOutput {
417    let (eta_c, clamped) = clamp_eta(input.eta);
418    // Numerically stable σ(η): use tanh(η/2) form to avoid catastrophic
419    // cancellation for large |η|. μ = (1 + tanh(η/2)) / 2.
420    let half = 0.5 * eta_c;
421    let mu_raw = 0.5 * (1.0 + half.tanh());
422    let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
423    let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
424    let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
425    let w_prior = input.prior_weight.max(0.0);
426    let dmu_deta = mu * (1.0 - mu); // logit canonical: h'(η) = μ(1−μ)
427    let w_fisher = w_prior * dmu_deta; // V(μ) = μ(1−μ), h'(η)² / V = h'(η)
428    let resid = input.y - mu;
429    let grad_eta = w_prior * resid; // priorweight · (y − μ) for logit (canonical)
430    // Saturated Bernoulli deviance: 2 w [y log(y/μ) + (1−y) log((1−y)/(1−μ))].
431    let dev = bernoulli_deviance(input.y, mu, w_prior);
432    let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);
433    let mut status = 0u32;
434    if clamped {
435        status |= status_flags::ETA_CLAMPED;
436    }
437    if mu_low || mu_high {
438        status |= status_flags::MU_FLOORED;
439    }
440    if input.prior_weight <= 0.0 {
441        status |= status_flags::ZERO_PRIOR_WEIGHT;
442    }
443    if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
444        status |= status_flags::INVALID_RESPONSE;
445    }
446    let w_hessian = select_w_hessian(mode, w_fisher, 0.0);
447    RowOutput {
448        mu,
449        grad_eta,
450        w_fisher,
451        w_hessian,
452        w_solver: if w_hessian > 0.0 {
453            w_hessian.max(W_SOLVER_FLOOR)
454        } else {
455            0.0
456        },
457        z_fisher: z,
458        z_hessian: z,
459        deviance: dev,
460        status,
461    }
462}
463
464#[inline]
465fn row_bernoulli_probit(input: RowInput, mode: CurvatureMode) -> RowOutput {
466    let (eta_c, clamped) = clamp_eta(input.eta);
467    let mu_raw = standard_normal_cdf(eta_c);
468    let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
469    let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
470    let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
471    let w_prior = input.prior_weight.max(0.0);
472    let dmu_deta = standard_normal_pdf(eta_c); // h'(η) = φ(η)
473    let v = mu * (1.0 - mu);
474    let fisher_per_prior = if v > 0.0 {
475        dmu_deta * dmu_deta / v
476    } else {
477        0.0
478    };
479    let w_fisher = w_prior * fisher_per_prior;
480    let resid = input.y - mu;
481    let grad_eta = if v > 0.0 {
482        w_prior * resid * dmu_deta / v
483    } else {
484        0.0
485    };
486    let dev = bernoulli_deviance(input.y, mu, w_prior);
487    let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);
488    let mut status = 0u32;
489    if clamped {
490        status |= status_flags::ETA_CLAMPED;
491    }
492    if mu_low || mu_high {
493        status |= status_flags::MU_FLOORED;
494    }
495    if input.prior_weight <= 0.0 {
496        status |= status_flags::ZERO_PRIOR_WEIGHT;
497    }
498    if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
499        status |= status_flags::INVALID_RESPONSE;
500    }
501    // Stage 5: observed-information correction for Bernoulli probit.
502    //   w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
503    // h(η) = φ(η), h'(η) = −η · φ(η); V(μ) = μ(1−μ), V'(μ) = 1 − 2μ.
504    let obs_correction = if v > 0.0 && w_prior > 0.0 {
505        let h_prime = -eta_c * dmu_deta;
506        let v_prime = 1.0 - 2.0 * mu;
507        let bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
508        w_prior * resid * bracket
509    } else {
510        0.0
511    };
512    let w_hessian_observed = select_w_hessian(mode, w_fisher, obs_correction);
513    RowOutput {
514        mu,
515        grad_eta,
516        w_fisher,
517        w_hessian: w_hessian_observed,
518        w_solver: {
519            let wh = w_hessian_observed;
520            if wh > 0.0 {
521                wh.max(W_SOLVER_FLOOR)
522            } else {
523                0.0
524            }
525        },
526        z_fisher: z,
527        z_hessian: z,
528        deviance: dev,
529        status,
530    }
531}
532
533#[inline]
534fn row_bernoulli_cloglog(input: RowInput, mode: CurvatureMode) -> RowOutput {
535    let (eta_c, clamped) = clamp_eta(input.eta);
536    // μ = 1 − exp(−exp(η)); numerically stable via expm1 to preserve precision
537    // in the deep negative tail (η ≲ -36) where `1 - exp(-exp(η))` would
538    // catastrophically cancel to 0.
539    let inner = eta_c.exp();
540    let mu_raw = -(-inner).exp_m1();
541    let mu_low = mu_raw < MU_FLOOR_BERNOULLI;
542    let mu_high = mu_raw > 1.0 - MU_FLOOR_BERNOULLI;
543    let mu = mu_raw.clamp(MU_FLOOR_BERNOULLI, 1.0 - MU_FLOOR_BERNOULLI);
544    // h'(η) = dμ/dη = exp(η − exp(η)) = inner · (1 − μ_raw).
545    // Use the unclamped form to avoid biasing the derivative on the saturated edge.
546    let dmu_deta = inner * (1.0 - mu_raw);
547    let w_prior = input.prior_weight.max(0.0);
548    let v = mu * (1.0 - mu);
549    let fisher_per_prior = if v > 0.0 {
550        dmu_deta * dmu_deta / v
551    } else {
552        0.0
553    };
554    let w_fisher = w_prior * fisher_per_prior;
555    let resid = input.y - mu;
556    let grad_eta = if v > 0.0 {
557        w_prior * resid * dmu_deta / v
558    } else {
559        0.0
560    };
561    let dev = bernoulli_deviance(input.y, mu, w_prior);
562    let z = bernoulli_z(eta_c, input.y, mu, dmu_deta);
563    let mut status = 0u32;
564    if clamped {
565        status |= status_flags::ETA_CLAMPED;
566    }
567    if mu_low || mu_high {
568        status |= status_flags::MU_FLOORED;
569    }
570    if input.prior_weight <= 0.0 {
571        status |= status_flags::ZERO_PRIOR_WEIGHT;
572    }
573    if !(input.y.is_finite() && (0.0..=1.0).contains(&input.y)) {
574        status |= status_flags::INVALID_RESPONSE;
575    }
576    // Stage 5: observed-information correction for Bernoulli cloglog.
577    //   w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
578    // h(η) = inner · (1 − μ_raw); h'(η) = h(η) · (1 − inner).
579    // V(μ) = μ(1−μ), V'(μ) = 1 − 2μ.
580    let obs_correction = if v > 0.0 && w_prior > 0.0 {
581        let h_prime = dmu_deta * (1.0 - inner);
582        let v_prime = 1.0 - 2.0 * mu;
583        let bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
584        w_prior * resid * bracket
585    } else {
586        0.0
587    };
588    let w_hessian = select_w_hessian(mode, w_fisher, obs_correction);
589    RowOutput {
590        mu,
591        grad_eta,
592        w_fisher,
593        w_hessian,
594        w_solver: if w_hessian > 0.0 {
595            w_hessian.max(W_SOLVER_FLOOR)
596        } else {
597            0.0
598        },
599        z_fisher: z,
600        z_hessian: z,
601        deviance: dev,
602        status,
603    }
604}
605
/// Saturated Bernoulli deviance `2·w·[y·ln(y/μ) + (1 − y)·ln((1 − y)/(1 − μ))]`.
///
/// Each `a·ln(a/b)` term is taken as 0 when `a` is 0 (y = 0 or y = 1), and a
/// row with zero prior weight contributes exactly 0.
#[inline]
fn bernoulli_deviance(y: f64, mu: f64, w_prior: f64) -> f64 {
    if w_prior == 0.0 {
        return 0.0;
    }
    let x_log_ratio = |a: f64, b: f64| a * (a / b).ln();
    let success = if y > 0.0 { x_log_ratio(y, mu) } else { 0.0 };
    let failure = if y < 1.0 {
        x_log_ratio(1.0 - y, 1.0 - mu)
    } else {
        0.0
    };
    2.0 * w_prior * (success + failure)
}
619
620#[inline]
621fn bernoulli_z(eta_used: f64, y: f64, mu: f64, dmu_deta: f64) -> f64 {
622    if dmu_deta.is_finite() && dmu_deta > DMU_DETA_MIN {
623        let delta = (y - mu) / dmu_deta;
624        if delta.is_finite() {
625            return eta_used + delta;
626        }
627    }
628    eta_used
629}
630
631/// Stable Φ(x) using the complementary error function with the same identity
632/// `erfc(-x/√2)/2 = Φ(x)` used by libstd. Keeps mass at the tails accurate.
633#[inline]
634fn standard_normal_cdf(x: f64) -> f64 {
635    0.5 * gam_gpu::numerics_host::erfc(-x * std::f64::consts::FRAC_1_SQRT_2)
636}
637
/// Standard normal density φ(x) = exp(−x²/2) / √(2π).
#[inline]
fn standard_normal_pdf(x: f64) -> f64 {
    /// 1/√(2π) written out as a literal.
    const INV_SQRT_TWO_PI: f64 = 0.398_942_280_401_432_7;
    let gaussian_kernel = (-0.5 * x * x).exp();
    INV_SQRT_TWO_PI * gaussian_kernel
}
643
644// ────────────────────────────────────────────────────────────────────────
645// CUDA host harness
646// ────────────────────────────────────────────────────────────────────────
647
648/// Process-wide cache of compiled per-family modules.
649#[must_use]
650pub struct PirlsRowBackend {
651    #[cfg(target_os = "linux")]
652    inner: PirlsRowBackendLinux,
653}
654
655#[cfg(target_os = "linux")]
656struct PirlsRowBackendLinux {
657    ctx: Arc<CudaContext>,
658    modules: Mutex<std::collections::HashMap<ModuleKey, Arc<CudaModule>>>,
659    /// Stage 6: separate cache for JIT-compiled custom-family modules
660    /// keyed by `(spec_id, curvature)`. Distinct JIT specs in the same
661    /// process get distinct cached modules.
662    jit_modules: Mutex<std::collections::HashMap<JitKey, Arc<CudaModule>>>,
663}
664
/// Which of the three per-family kernels a cached module holds.
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum KernelMode {
    /// Final-row kernel: writes all nine outputs (`mu`, `grad_eta`,
    /// `w_fisher`, `w_hessian`, `w_solver`, `z_fisher`, `z_hessian`,
    /// `deviance`, `status`).
    FinalRow,
    /// Solve-row kernel: writes only `grad_eta`, `w_solver`, `deviance`,
    /// `status`.
    SolveRow,
    /// Alpha-ladder kernel: writes per-alpha `objective[]` and `status[]`.
    AlphaLadder,
}
677
678#[cfg(target_os = "linux")]
679#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
680struct ModuleKey {
681    family: PirlsRowFamily,
682    curvature: CurvatureMode,
683    mode: KernelMode,
684}
685
686impl PirlsRowBackend {
687    pub const fn compiled() -> bool {
688        cfg!(target_os = "linux")
689    }
690
691    pub fn probe() -> Result<&'static Self, GpuError> {
692        static BACKEND: OnceLock<Result<PirlsRowBackend, GpuError>> = OnceLock::new();
693        BACKEND
694            .get_or_init(|| {
695                #[cfg(target_os = "linux")]
696                {
697                    Self::probe_linux()
698                }
699                #[cfg(not(target_os = "linux"))]
700                {
701                    Err(GpuError::DriverLibraryUnavailable {
702                        reason: "pirls_row GPU backend is Linux-only".to_string(),
703                    })
704                }
705            })
706            .as_ref()
707            .map_err(GpuError::clone)
708    }
709
710    #[cfg(target_os = "linux")]
711    fn probe_linux() -> Result<Self, GpuError> {
712        let parts = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")?;
713        Ok(Self {
714            inner: PirlsRowBackendLinux {
715                ctx: parts.ctx,
716                modules: Mutex::new(std::collections::HashMap::new()),
717                jit_modules: Mutex::new(std::collections::HashMap::new()),
718            },
719        })
720    }
721
722    /// Compile (or fetch from cache) the kernel module for `(family, curvature)`
723    /// in the given [`KernelMode`]. This is the single source of truth behind
724    /// [`module_for`], [`module_for_solve`], and [`module_for_ladder`]; the only
725    /// per-mode variation is which CUDA source generator is used (selected by
726    /// `mode`) and the error label `label` woven into compile/load diagnostics.
727    #[cfg(target_os = "linux")]
728    fn module_for_kind(
729        &self,
730        family: PirlsRowFamily,
731        curvature: CurvatureMode,
732        mode: KernelMode,
733        label: &str,
734    ) -> Result<Arc<CudaModule>, GpuError> {
735        let key = ModuleKey {
736            family,
737            curvature,
738            mode,
739        };
740        if let Some(existing) = self
741            .inner
742            .modules
743            .lock()
744            .gpu_ctx_with(|err| format!("pirls_row {label}module cache mutex poisoned: {err}"))?
745            .get(&key)
746        {
747            return Ok(existing.clone());
748        }
749        let source = match mode {
750            KernelMode::FinalRow => cuda_source_for(family, curvature),
751            KernelMode::SolveRow => solve_row_source_for(family, curvature),
752            KernelMode::AlphaLadder => ladder_source_for(family, curvature),
753        };
754        // #1551: route through the device-arch-pinned compile — this kernel uses
755        // `atomicAdd(double*, double)` (objective_out), which NVRTC rejects under
756        // its default sub-sm_60 arch, silently disabling the device PIRLS path.
757        let ptx = gam_gpu::device_cache::compile_ptx_arch(&source).gpu_ctx_with(|err| {
758            format!(
759                "pirls_row {label}NVRTC compile failed for {family}/{curv}: {err}",
760                family = family.as_str(),
761                curv = curvature.as_str(),
762            )
763        })?;
764        let module = self
765            .inner
766            .ctx
767            .load_module(ptx)
768            .gpu_ctx_with(|err| format!("pirls_row {label}module load failed: {err}"))?;
769        self.inner
770            .modules
771            .lock()
772            .gpu_ctx_with(|err| format!("pirls_row {label}module cache mutex poisoned: {err}"))?
773            .insert(key, module.clone());
774        Ok(module)
775    }
776
777    /// Compile (or fetch from cache) the **final-row** kernel module for
778    /// `(family, curvature)`. Writes all 9 output fields.
779    #[cfg(target_os = "linux")]
780    pub fn module_for(
781        &self,
782        family: PirlsRowFamily,
783        curvature: CurvatureMode,
784    ) -> Result<Arc<CudaModule>, GpuError> {
785        self.module_for_kind(family, curvature, KernelMode::FinalRow, "")
786    }
787
788    /// Compile (or fetch from cache) the **solve-row** kernel module for
789    /// `(family, curvature)`. Writes only `grad_eta`, `w_solver`, `deviance`,
790    /// `status` — used on every hot Newton iteration.
791    #[cfg(target_os = "linux")]
792    pub fn module_for_solve(
793        &self,
794        family: PirlsRowFamily,
795        curvature: CurvatureMode,
796    ) -> Result<Arc<CudaModule>, GpuError> {
797        self.module_for_kind(family, curvature, KernelMode::SolveRow, "solve ")
798    }
799
800    /// Compile (or fetch from cache) the **alpha-ladder** kernel module for
801    /// `(family, curvature)`. Evaluates all [`ALPHA_LADDER_LEN`] step sizes in
802    /// a single launch, accumulating `objective[]` and `status[]` per alpha slot.
803    #[cfg(target_os = "linux")]
804    pub fn module_for_ladder(
805        &self,
806        family: PirlsRowFamily,
807        curvature: CurvatureMode,
808    ) -> Result<Arc<CudaModule>, GpuError> {
809        self.module_for_kind(family, curvature, KernelMode::AlphaLadder, "ladder ")
810    }
811
812    /// Stage 6: JIT-compile and cache a custom-family row module.
813    ///
814    /// The kernel name is `pirls_row_jit_{spec.spec_id}` so multiple
815    /// distinct JIT specs in the same process get distinct cached
816    /// modules. The cache key is `(spec_id, curvature)` which mirrors
817    /// the built-in `(family, curvature)` cache and reuses the same
818    /// HashMap-of-`Arc<CudaModule>` (but with a synthetic `ModuleKey`
819    /// derived from the spec_id).
820    ///
821    /// Note: a fresh `(spec_id, curvature)` recompiles via NVRTC the
822    /// first time; subsequent fits in the same process hit the cache.
823    /// Spec changes (different body) must use a different `spec_id` so
824    /// that the cache does NOT return a stale module.
825    #[cfg(target_os = "linux")]
826    pub fn module_for_jit(
827        &self,
828        spec: &JitFamilySpec,
829        curvature: CurvatureMode,
830    ) -> Result<Arc<CudaModule>, GpuError> {
831        // Reuse the built-in ModuleKey by mapping `spec_id` to a
832        // synthetic family slot. We piggy-back on the cache by keying
833        // off a hashed family enum value won't fit cleanly; instead
834        // use a separate JIT cache HashMap.
835        let key = JitKey {
836            spec_id: spec.spec_id,
837            curvature,
838        };
839        if let Some(existing) = self
840            .inner
841            .jit_modules
842            .lock()
843            .gpu_ctx("pirls_row jit cache poisoned")?
844            .get(&key)
845        {
846            return Ok(existing.clone());
847        }
848        let source = spec.cuda_source(curvature);
849        // #1551: device-arch-pinned compile (double-atomic objective_out kernel).
850        let ptx = gam_gpu::device_cache::compile_ptx_arch(&source).gpu_ctx_with(|err| {
851            format!(
852                "pirls_row JIT NVRTC compile failed for spec_id={} curvature={}: {err}",
853                spec.spec_id,
854                curvature.as_str(),
855            )
856        })?;
857        let module = self
858            .inner
859            .ctx
860            .load_module(ptx)
861            .gpu_ctx("pirls_row JIT module load failed")?;
862        self.inner
863            .jit_modules
864            .lock()
865            .gpu_ctx("pirls_row jit cache poisoned (insert)")?
866            .insert(key, module.clone());
867        Ok(module)
868    }
869}
870
871/// Stage 6 cache key for JIT-compiled family modules.
872#[cfg(target_os = "linux")]
873#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
874struct JitKey {
875    spec_id: u64,
876    curvature: CurvatureMode,
877}
878
/// Stage 6 custom-family JIT specification.
///
/// Two levels per the charter:
/// - **Level A** (`JitFamilySpec::glm`): provide a built-in `PirlsRowFamily`
///   value (which fixes the `(family, link)` pair) and a curvature mode. The
///   generator emits the matching built-in row body, identical to the cached
///   built-in kernel. This is useful for end-to-end JIT path validation
///   against the built-in cache.
/// - **Level B** (`JitFamilySpec::raw`): provide raw CUDA source for the row
///   body. The body must declare the per-row locals the kernel shell stores:
///   `mu`, `grad_eta`, `w_fisher`, `w_hessian`, `w_solver`, `z_f`, `z_h`,
///   `dev`. It may also OR status bits into `flags`.
///
/// The shell wraps the body in the canonical
/// `extern "C" __global__ void pirls_row_jit_{spec_id}(...)` signature. That
/// 13-argument signature is the one `launch_row_reweight_jit_on_stream`
/// pushes. It has no GammaLog `shape` argument.
#[derive(Clone, Debug)]
pub struct JitFamilySpec {
    /// Process-unique identifier for this spec. The module cache uses it as a
    /// key, so callers must reuse the same `spec_id` for the same body and pick
    /// a new one whenever the body changes.
    pub spec_id: u64,
    /// CUDA body source, spliced verbatim into the kernel shell.
    ///
    /// The shell already provides these locals:
    /// - `eta_i`: the raw linear predictor.
    /// - `y_i`: the response.
    /// - `wp`: the prior weight, clamped to `>= 0`.
    /// - `flags`: the status bitmask; `0x10` is already set when the prior
    ///   weight is non-positive.
    ///
    /// The body must declare and assign `mu`, `grad_eta`, `w_fisher`,
    /// `w_hessian`, `w_solver`, `z_f`, `z_h`, `dev`. `COMMON_DEVICE_PROLOG`
    /// supplies helpers such as `clamp_eta` (the built-in bodies use it to
    /// derive their own clamped `eta_c` local).
    pub body: String,
}
906
907impl JitFamilySpec {
908    /// Level A: build a spec from a built-in `(family, curvature)`
909    /// pair. The generator reuses the same per-family body as the
910    /// built-in cached kernel — useful to validate the JIT pipeline
911    /// end-to-end against the built-in numerical reference.
912    #[cfg(target_os = "linux")]
913    pub fn glm(spec_id: u64, family: PirlsRowFamily, curvature: CurvatureMode) -> Self {
914        let body = match family {
915            PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
916            PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
917            PirlsRowFamily::GammaLog => gamma_log_body(curvature),
918            PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
919            PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
920            PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
921        };
922        Self { spec_id, body }
923    }
924
925    /// Level B: build a spec from caller-supplied body source. The
926    /// kernel shell wraps it; the body must define the required locals
927    /// listed on [`JitFamilySpec`].
928    pub fn raw(spec_id: u64, body: impl Into<String>) -> Self {
929        Self {
930            spec_id,
931            body: body.into(),
932        }
933    }
934
935    /// The `extern "C"` kernel symbol the JIT-compiled module exposes.
936    pub fn kernel_name(&self) -> String {
937        format!("pirls_row_jit_{}", self.spec_id)
938    }
939
940    /// Build the full CUDA source ready for NVRTC compilation. The
941    /// shell + prolog match the built-in `cuda_source_for` so the JIT
942    /// kernel ABI is bit-identical to the cached built-ins;
943    /// [`launch_row_reweight_on_stream`] cannot tell the difference.
944    #[cfg(target_os = "linux")]
945    pub fn cuda_source(&self, curvature: CurvatureMode) -> String {
946        let curvature_define = match curvature {
947            CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
948            CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
949        };
950        let kernel_name = self.kernel_name();
951        let body = &self.body;
952        format!(
953            r#"
954{curvature_define}
955{prolog}
956
957extern "C" __global__ void {kernel_name}(
958    int            n,
959    const double* __restrict__ eta,
960    const double* __restrict__ y,
961    const double* __restrict__ prior_w,
962    double* __restrict__ mu_out,
963    double* __restrict__ grad_eta_out,
964    double* __restrict__ w_fisher_out,
965    double* __restrict__ w_hessian_out,
966    double* __restrict__ w_solver_out,
967    double* __restrict__ z_fisher_out,
968    double* __restrict__ z_hessian_out,
969    double* __restrict__ deviance_out,
970    unsigned int* __restrict__ status_out
971) {{
972    int i = blockIdx.x * blockDim.x + threadIdx.x;
973    if (i >= n) return;
974    unsigned int flags = 0u;
975    double eta_i = eta[i];
976    double y_i = y[i];
977    double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
978    if (prior_w[i] <= 0.0) flags |= 0x10u;
979{body}
980    mu_out[i] = mu;
981    grad_eta_out[i] = grad_eta;
982    w_fisher_out[i] = w_fisher;
983    w_hessian_out[i] = w_hessian;
984    w_solver_out[i] = w_solver;
985    z_fisher_out[i] = z_f;
986    z_hessian_out[i] = z_h;
987    deviance_out[i] = dev;
988    status_out[i] = flags;
989}}
990"#,
991            prolog = COMMON_DEVICE_PROLOG,
992        )
993    }
994}
995
996/// Device-resident per-row output buffers for the GPU row-reweight kernel.
997///
998/// **final-row mode**: all nine per-[`RowOutput`] fields, length `n`. Written
999/// once at convergence by [`launch_row_reweight_on_stream`]. For the hot
1000/// inner-loop use [`SolveRowBuffers`]; for line-search use
1001/// [`AlphaLadderDevBuffers`].
1002#[cfg(target_os = "linux")]
1003pub struct RowOutputDevBuffers {
1004    pub mu: cudarc::driver::CudaSlice<f64>,
1005    pub grad_eta: cudarc::driver::CudaSlice<f64>,
1006    pub w_fisher: cudarc::driver::CudaSlice<f64>,
1007    pub w_hessian: cudarc::driver::CudaSlice<f64>,
1008    pub w_solver: cudarc::driver::CudaSlice<f64>,
1009    pub z_fisher: cudarc::driver::CudaSlice<f64>,
1010    pub z_hessian: cudarc::driver::CudaSlice<f64>,
1011    pub deviance: cudarc::driver::CudaSlice<f64>,
1012    pub status: cudarc::driver::CudaSlice<u32>,
1013    pub n: usize,
1014}
1015
1016#[cfg(target_os = "linux")]
1017impl RowOutputDevBuffers {
1018    /// Allocate all nine per-row output buffers (length `n`) on `stream`.
1019    pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>, n: usize) -> Result<Self, GpuError> {
1020        let alloc_f64 = |label: &'static str| {
1021            stream
1022                .alloc_zeros::<f64>(n)
1023                .gpu_ctx_with(|err| format!("pirls_row alloc {label}: {err}"))
1024        };
1025        let alloc_u32 = |label: &'static str| {
1026            stream
1027                .alloc_zeros::<u32>(n)
1028                .gpu_ctx_with(|err| format!("pirls_row alloc {label}: {err}"))
1029        };
1030        Ok(Self {
1031            mu: alloc_f64("mu")?,
1032            grad_eta: alloc_f64("grad_eta")?,
1033            w_fisher: alloc_f64("w_fisher")?,
1034            w_hessian: alloc_f64("w_hessian")?,
1035            w_solver: alloc_f64("w_solver")?,
1036            z_fisher: alloc_f64("z_fisher")?,
1037            z_hessian: alloc_f64("z_hessian")?,
1038            deviance: alloc_f64("deviance")?,
1039            status: alloc_u32("status")?,
1040            n,
1041        })
1042    }
1043}
1044
1045/// Device-resident per-row output buffers for the **solve-row** mode.
1046///
1047/// Allocates only the four fields the PIRLS solver reads on every Newton
1048/// iteration: `grad_eta` (score for Xᵀg RHS), `w_solver` (working weight
1049/// for XᵀWX assembly), `deviance` (per-row deviance for convergence check),
1050/// and `status` (diagnostic flags OR-reduced to a single host u32). Written
1051/// by [`launch_solve_row_on_stream`]; used instead of [`RowOutputDevBuffers`]
1052/// during the hot inner loop to reduce device memory and kernel store traffic.
1053#[cfg(target_os = "linux")]
1054pub struct SolveRowBuffers {
1055    /// ∂ℓ/∂η_i — score for Xᵀg RHS formation.
1056    pub grad_eta: cudarc::driver::CudaSlice<f64>,
1057    /// Stabilised Hessian weight — fed to XᵀWX assembly.
1058    pub w_solver: cudarc::driver::CudaSlice<f64>,
1059    /// Per-row deviance contribution — summed for convergence check.
1060    pub deviance: cudarc::driver::CudaSlice<f64>,
1061    /// Bitmask flags OR-reduced to detect numerical issues.
1062    pub status: cudarc::driver::CudaSlice<u32>,
1063    pub n: usize,
1064}
1065
1066#[cfg(target_os = "linux")]
1067impl SolveRowBuffers {
1068    /// Allocate the four solve-row output buffers (length `n`) on `stream`.
1069    pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>, n: usize) -> Result<Self, GpuError> {
1070        let alloc_f64 = |label: &'static str| {
1071            stream
1072                .alloc_zeros::<f64>(n)
1073                .gpu_ctx_with(|err| format!("pirls_row solve alloc {label}: {err}"))
1074        };
1075        let alloc_u32 = |label: &'static str| {
1076            stream
1077                .alloc_zeros::<u32>(n)
1078                .gpu_ctx_with(|err| format!("pirls_row solve alloc {label}: {err}"))
1079        };
1080        Ok(Self {
1081            grad_eta: alloc_f64("grad_eta")?,
1082            w_solver: alloc_f64("w_solver")?,
1083            deviance: alloc_f64("deviance")?,
1084            status: alloc_u32("status")?,
1085            n,
1086        })
1087    }
1088}
1089
/// How many candidate step sizes one fused alpha-ladder launch evaluates.
pub const ALPHA_LADDER_LEN: usize = 7;

/// Candidate step sizes, halving from a full Newton step: αₖ = 2⁻ᵏ for
/// `k = 0..ALPHA_LADDER_LEN`, i.e. `[1, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625]`.
/// Every entry is an exact power of two, so the division forms are exact.
pub const ALPHA_LADDER: [f64; ALPHA_LADDER_LEN] = [
    1.0,
    1.0 / 2.0,
    1.0 / 4.0,
    1.0 / 8.0,
    1.0 / 16.0,
    1.0 / 32.0,
    1.0 / 64.0,
];
1096
1097/// Device buffers for the fused alpha-ladder candidate-objective kernel.
1098///
1099/// **candidate-objective mode**: for each of the [`ALPHA_LADDER_LEN`] step
1100/// sizes α_k the kernel evaluates `η_trial_i = η_i + α_k · xδ_i`, computes
1101/// the per-row deviance, and atomically accumulates the sum into
1102/// `objective_dev[k]`. Status flags are OR-accumulated into `status_dev[k]`.
1103/// After a single `memcpy_dtoh` the host picks the first α that achieves
1104/// deviance descent — no per-α kernel launch, no full row-output write.
1105#[cfg(target_os = "linux")]
1106pub struct AlphaLadderDevBuffers {
1107    /// Device: summed deviance for each alpha step, length [`ALPHA_LADDER_LEN`].
1108    pub objective_dev: cudarc::driver::CudaSlice<f64>,
1109    /// Device: OR-reduced status flags for each alpha step, length [`ALPHA_LADDER_LEN`].
1110    pub status_dev: cudarc::driver::CudaSlice<u32>,
1111}
1112
1113#[cfg(target_os = "linux")]
1114impl AlphaLadderDevBuffers {
1115    /// Allocate the ladder device buffers on `stream`.
1116    pub fn allocate(stream: &Arc<cudarc::driver::CudaStream>) -> Result<Self, GpuError> {
1117        Ok(Self {
1118            objective_dev: stream
1119                .alloc_zeros::<f64>(ALPHA_LADDER_LEN)
1120                .gpu_ctx_with(|err| format!("pirls_row ladder alloc objective: {err}"))?,
1121            status_dev: stream
1122                .alloc_zeros::<u32>(ALPHA_LADDER_LEN)
1123                .gpu_ctx_with(|err| format!("pirls_row ladder alloc status: {err}"))?,
1124        })
1125    }
1126
1127    /// Zero all per-alpha accumulators in-place (call before each ladder launch).
1128    pub fn zero(&mut self, stream: &Arc<cudarc::driver::CudaStream>) -> Result<(), GpuError> {
1129        stream
1130            .memset_zeros(&mut self.objective_dev)
1131            .gpu_ctx_with(|err| format!("pirls_row ladder zero objective: {err}"))?;
1132        stream
1133            .memset_zeros(&mut self.status_dev)
1134            .gpu_ctx_with(|err| format!("pirls_row ladder zero status: {err}"))
1135    }
1136}
1137
1138/// Device-side row reweight launcher.
1139///
1140/// Resolves the cached per-family kernel from [`PirlsRowBackend::module_for`],
1141/// dispatches a 1D grid of `THREADS_PER_BLOCK = 256` threads across `n`
1142/// rows, and returns once the launch is enqueued on `stream`. The kernel
1143/// writes the per-row IRLS state into `out` in place; no host transfers.
1144///
1145/// The kernel's `extern "C"` signature is fixed at the top of `cuda_source_for`
1146/// (see `extern "C" __global__ void {kernel_name}(int n, …)` in this file).
1147/// `gamma_shape`: active Gamma dispersion shape (α > 0). Forwarded as a
1148/// scalar kernel argument only for `PirlsRowFamily::GammaLog`; all other
1149/// families compile a 13-argument kernel and ignore this value. Pass `1.0`
1150/// for non-Gamma fits.
1151#[cfg(target_os = "linux")]
1152pub fn launch_row_reweight_on_stream(
1153    backend: &PirlsRowBackend,
1154    family: PirlsRowFamily,
1155    curvature: CurvatureMode,
1156    gamma_shape: f64,
1157    stream: &Arc<cudarc::driver::CudaStream>,
1158    n: usize,
1159    eta_dev: &cudarc::driver::CudaSlice<f64>,
1160    y_dev: &cudarc::driver::CudaSlice<f64>,
1161    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
1162    out: &mut RowOutputDevBuffers,
1163) -> Result<(), GpuError> {
1164    use cudarc::driver::{LaunchConfig, PushKernelArg};
1165    if out.n != n {
1166        gam_gpu::gpu_bail!("row reweight buffers shape {} mismatches n={n}", out.n);
1167    }
1168    let module = backend.module_for(family, curvature)?;
1169    let func = module
1170        .load_function(family.kernel_name())
1171        .gpu_ctx_with(|err| {
1172            format!(
1173                "row reweight load_function({}): {err}",
1174                family.kernel_name()
1175            )
1176        })?;
1177    const THREADS_PER_BLOCK: u32 = 256;
1178    let n_u32 =
1179        u32::try_from(n).map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for row reweight grid sizing"))?;
1180    let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
1181    let n_i32 = i32::try_from(n)
1182        .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for row reweight kernel argument"))?;
1183    let cfg = LaunchConfig {
1184        grid_dim: (grid_x, 1, 1),
1185        block_dim: (THREADS_PER_BLOCK, 1, 1),
1186        shared_mem_bytes: 0,
1187    };
1188    let mut builder = stream.launch_builder(&func);
1189    builder.arg(&n_i32);
1190    builder.arg(eta_dev);
1191    builder.arg(y_dev);
1192    builder.arg(prior_w_dev);
1193    // GammaLog kernel has `double shape` before the output buffers.
1194    if matches!(family, PirlsRowFamily::GammaLog) {
1195        builder.arg(&gamma_shape);
1196    }
1197    builder.arg(&mut out.mu);
1198    builder.arg(&mut out.grad_eta);
1199    builder.arg(&mut out.w_fisher);
1200    builder.arg(&mut out.w_hessian);
1201    builder.arg(&mut out.w_solver);
1202    builder.arg(&mut out.z_fisher);
1203    builder.arg(&mut out.z_hessian);
1204    builder.arg(&mut out.deviance);
1205    builder.arg(&mut out.status);
1206    // SAFETY: kernel signature for non-GammaLog is (n:i32, 3×const f64*,
1207    // 8×mut f64*, 1×mut u32*). GammaLog extends this with `double shape`
1208    // after `prior_w` and before the output buffers — see `cuda_source_for`.
1209    // Arg order/types match one-for-one. Output buffers were allocated with
1210    // `n` elements each (validated above); input buffers are caller-supplied
1211    // with length n. Grid covers all n rows; threads guard `if (i >= n) return`.
1212    unsafe { builder.launch(cfg) }
1213        .map(|_event_pair| ())
1214        .gpu_ctx_with(|err| format!("row reweight launch({}): {err}", family.kernel_name()))
1215}
1216
1217/// Stage 6: device-side row reweight launcher for JIT-compiled
1218/// custom-family kernels. Same kernel ABI as the built-in path
1219/// ([`launch_row_reweight_on_stream`]) — the only differences are
1220/// (a) the kernel symbol is `spec.kernel_name()` and (b) the module
1221/// resolution goes through [`PirlsRowBackend::module_for_jit`].
1222#[cfg(target_os = "linux")]
1223pub fn launch_row_reweight_jit_on_stream(
1224    backend: &PirlsRowBackend,
1225    spec: &JitFamilySpec,
1226    curvature: CurvatureMode,
1227    stream: &Arc<cudarc::driver::CudaStream>,
1228    n: usize,
1229    eta_dev: &cudarc::driver::CudaSlice<f64>,
1230    y_dev: &cudarc::driver::CudaSlice<f64>,
1231    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
1232    out: &mut RowOutputDevBuffers,
1233) -> Result<(), GpuError> {
1234    use cudarc::driver::{LaunchConfig, PushKernelArg};
1235    if out.n != n {
1236        gam_gpu::gpu_bail!("JIT row reweight buffers shape {} mismatches n={n}", out.n);
1237    }
1238    let module = backend.module_for_jit(spec, curvature)?;
1239    let kernel_name = spec.kernel_name();
1240    let func = module
1241        .load_function(&kernel_name)
1242        .gpu_ctx_with(|err| format!("JIT row reweight load_function({kernel_name}): {err}"))?;
1243    const THREADS_PER_BLOCK: u32 = 256;
1244    let n_u32 = u32::try_from(n)
1245        .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for JIT row reweight grid sizing"))?;
1246    let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
1247    let n_i32 = i32::try_from(n)
1248        .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for JIT row reweight kernel argument"))?;
1249    let cfg = LaunchConfig {
1250        grid_dim: (grid_x, 1, 1),
1251        block_dim: (THREADS_PER_BLOCK, 1, 1),
1252        shared_mem_bytes: 0,
1253    };
1254    let mut builder = stream.launch_builder(&func);
1255    builder.arg(&n_i32);
1256    builder.arg(eta_dev);
1257    builder.arg(y_dev);
1258    builder.arg(prior_w_dev);
1259    builder.arg(&mut out.mu);
1260    builder.arg(&mut out.grad_eta);
1261    builder.arg(&mut out.w_fisher);
1262    builder.arg(&mut out.w_hessian);
1263    builder.arg(&mut out.w_solver);
1264    builder.arg(&mut out.z_fisher);
1265    builder.arg(&mut out.z_hessian);
1266    builder.arg(&mut out.deviance);
1267    builder.arg(&mut out.status);
1268    // SAFETY: JIT spec's `cuda_source` builder emits the same kernel
1269    // signature as `cuda_source_for`; arg order/types match one-for-one.
1270    unsafe { builder.launch(cfg) }
1271        .map(|_event_pair| ())
1272        .gpu_ctx_with(|err| format!("JIT row reweight launch({kernel_name}): {err}"))
1273}
1274
1275/// **solve-row** mode launcher.
1276///
1277/// Runs the per-family row math and writes only the four fields needed by the
1278/// PIRLS solver on each Newton iteration: `grad_eta`, `w_solver`, `deviance`,
1279/// `status`. The CUDA kernel is compiled from a specialised source
1280/// (`solve_row_source_for`) that skips the `mu`, `w_fisher`, `w_hessian`,
1281/// `z_fisher`, `z_hessian` stores, reducing both bandwidth and register
1282/// pressure relative to [`launch_row_reweight_on_stream`].
1283///
1284/// Call once per Newton step on the accepted η. At convergence, call
1285/// [`launch_row_reweight_on_stream`] (final-row mode) to populate the full
1286/// output surface before downloading.
1287///
1288/// `gamma_shape`: active Gamma dispersion shape (α > 0). Forwarded as a kernel
1289/// argument only for `PirlsRowFamily::GammaLog`. Pass `1.0` for non-Gamma fits.
1290#[cfg(target_os = "linux")]
1291pub fn launch_solve_row_on_stream(
1292    backend: &PirlsRowBackend,
1293    family: PirlsRowFamily,
1294    curvature: CurvatureMode,
1295    gamma_shape: f64,
1296    stream: &Arc<cudarc::driver::CudaStream>,
1297    n: usize,
1298    eta_dev: &cudarc::driver::CudaSlice<f64>,
1299    y_dev: &cudarc::driver::CudaSlice<f64>,
1300    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
1301    out: &mut SolveRowBuffers,
1302) -> Result<(), GpuError> {
1303    use cudarc::driver::{LaunchConfig, PushKernelArg};
1304    if out.n != n {
1305        gam_gpu::gpu_bail!("solve-row buffers shape {} mismatches n={n}", out.n);
1306    }
1307    let module = backend.module_for_solve(family, curvature)?;
1308    let kernel_name = family.solve_kernel_name();
1309    let func = module
1310        .load_function(kernel_name)
1311        .gpu_ctx_with(|err| format!("solve-row load_function({kernel_name}): {err}"))?;
1312    const THREADS_PER_BLOCK: u32 = 256;
1313    let n_u32 =
1314        u32::try_from(n).map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for solve-row grid sizing"))?;
1315    let grid_x = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
1316    let n_i32 = i32::try_from(n)
1317        .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for solve-row kernel argument"))?;
1318    let cfg = LaunchConfig {
1319        grid_dim: (grid_x, 1, 1),
1320        block_dim: (THREADS_PER_BLOCK, 1, 1),
1321        shared_mem_bytes: 0,
1322    };
1323    let mut builder = stream.launch_builder(&func);
1324    builder.arg(&n_i32);
1325    builder.arg(eta_dev);
1326    builder.arg(y_dev);
1327    builder.arg(prior_w_dev);
1328    // GammaLog solve kernel has `double shape` before the output buffers.
1329    if matches!(family, PirlsRowFamily::GammaLog) {
1330        builder.arg(&gamma_shape);
1331    }
1332    builder.arg(&mut out.grad_eta);
1333    builder.arg(&mut out.w_solver);
1334    builder.arg(&mut out.deviance);
1335    builder.arg(&mut out.status);
1336    // SAFETY: solve_row_source_for emits for non-GammaLog:
1337    //   (int n, const f64* eta, const f64* y, const f64* prior_w,
1338    //    f64* grad_eta_out, f64* w_solver_out, f64* deviance_out, u32* status_out)
1339    // For GammaLog the signature inserts `double shape` after `prior_w`.
1340    // 4 outputs match the 4 SolveRowBuffers fields; grid covers all n rows with
1341    // per-thread guard `if (i >= n) return`.
1342    unsafe { builder.launch(cfg) }
1343        .map(|_event_pair| ())
1344        .gpu_ctx_with(|err| format!("solve-row launch({kernel_name}): {err}"))
1345}
1346
1347/// **candidate-objective / fused alpha-ladder** launcher.
1348///
1349/// Evaluates `η_trial_i = η_i + α_k · xδ_i` for all `i ∈ [0,n)` and all
1350/// `k ∈ [0, ALPHA_LADDER_LEN)` simultaneously.  Each thread atomically
1351/// accumulates the per-row deviance into `out.objective_dev[k]` and
1352/// OR-accumulates status flags into `out.status_dev[k]`.
1353///
1354/// The grid is `(row_blocks × ALPHA_LADDER_LEN)`: block index `bx / n_blocks`
1355/// selects the alpha slot, `bx % n_blocks` selects the row tile.
1356///
1357/// Caller must call [`AlphaLadderDevBuffers::zero`] before each launch, then
1358/// issue a single `memcpy_dtoh` to read back `[f64; ALPHA_LADDER_LEN]` and
1359/// `[u32; ALPHA_LADDER_LEN]` to pick the accepted step.
1360#[cfg(target_os = "linux")]
1361pub fn launch_alpha_ladder_on_stream(
1362    backend: &PirlsRowBackend,
1363    family: PirlsRowFamily,
1364    curvature: CurvatureMode,
1365    gamma_shape: f64,
1366    stream: &Arc<cudarc::driver::CudaStream>,
1367    n: usize,
1368    eta_dev: &cudarc::driver::CudaSlice<f64>,
1369    xd_dev: &cudarc::driver::CudaSlice<f64>,
1370    y_dev: &cudarc::driver::CudaSlice<f64>,
1371    prior_w_dev: &cudarc::driver::CudaSlice<f64>,
1372    out: &mut AlphaLadderDevBuffers,
1373) -> Result<(), GpuError> {
1374    use cudarc::driver::{LaunchConfig, PushKernelArg};
1375    let module = backend.module_for_ladder(family, curvature)?;
1376    let kernel_name = family.ladder_kernel_name();
1377    let func = module
1378        .load_function(kernel_name)
1379        .gpu_ctx_with(|err| format!("alpha-ladder load_function({kernel_name}): {err}"))?;
1380    const THREADS_PER_BLOCK: u32 = 256;
1381    let n_u32 =
1382        u32::try_from(n).map_err(|_| gam_gpu::gpu_err!("n={n} exceeds u32 for alpha-ladder grid sizing"))?;
1383    let row_blocks = n_u32.div_ceil(THREADS_PER_BLOCK).max(1);
1384    let n_i32 = i32::try_from(n)
1385        .map_err(|_| gam_gpu::gpu_err!("n={n} exceeds i32 for alpha-ladder kernel argument"))?;
1386    // Grid: x = row tile index (0..row_blocks), y = alpha index (0..ALPHA_LADDER_LEN).
1387    let cfg = LaunchConfig {
1388        grid_dim: (row_blocks, ALPHA_LADDER_LEN as u32, 1),
1389        block_dim: (THREADS_PER_BLOCK, 1, 1),
1390        shared_mem_bytes: 0,
1391    };
1392    let mut builder = stream.launch_builder(&func);
1393    builder.arg(&n_i32);
1394    builder.arg(eta_dev);
1395    builder.arg(xd_dev);
1396    builder.arg(y_dev);
1397    builder.arg(prior_w_dev);
1398    // GammaLog ladder kernel has `double shape` before the output buffers.
1399    if matches!(family, PirlsRowFamily::GammaLog) {
1400        builder.arg(&gamma_shape);
1401    }
1402    builder.arg(&mut out.objective_dev);
1403    builder.arg(&mut out.status_dev);
1404    // SAFETY: for non-GammaLog, ladder_source_for emits:
1405    //   (int n, const f64* eta, const f64* xd, const f64* y, const f64* prior_w,
1406    //    f64* objective_out, u32* status_out)
1407    // For GammaLog `double shape` is inserted after `prior_w`.
1408    // Grid is (row_blocks × ALPHA_LADDER_LEN); each thread reads alphas[] via
1409    // blockIdx.y, rows via blockIdx.x * blockDim.x + threadIdx.x (guarded by n).
1410    // Atomic double-precision add to objective_out[blockIdx.y], OR to status_out[blockIdx.y].
1411    unsafe { builder.launch(cfg) }
1412        .map(|_event_pair| ())
1413        .gpu_ctx_with(|err| format!("alpha-ladder launch({kernel_name}): {err}"))
1414}
1415
1416// ────────────────────────────────────────────────────────────────────────
1417// CUDA sources (one per family / curvature pair)
1418// ────────────────────────────────────────────────────────────────────────
1419
/// Common device-side helpers, spliced verbatim into every row kernel source
/// (built-in via `cuda_source_for` and JIT via `JitFamilySpec::cuda_source`).
///
/// - `clamp_eta(eta, &flags)`: clamps η to `[-700, 700]` and sets status bit
///   `0x1` whenever it clamps. The clamp keeps `exp(η)` finite in double
///   precision.
/// - `bernoulli_deviance(y, mu, w)`: computes
///   `2·w·[y·ln(y/μ) + (1−y)·ln((1−y)/(1−μ))]`. It skips the `y·ln` term when
///   `y <= 0` and the `(1−y)·ln` term when `y >= 1`, and returns `0` when
///   `w == 0`. It does not clamp `mu` itself; callers presumably keep μ
///   strictly inside (0, 1) — TODO confirm.
/// - `bernoulli_z(eta, y, mu, dmu_deta)`: computes the working response
///   `η + (y − μ)/(dμ/dη)`. It falls back to plain `η` when `dμ/dη` is
///   non-positive or non-finite, or when the resulting step is non-finite.
/// - `std_norm_cdf` / `std_norm_pdf`: the standard-normal Φ (via `erfc`) and φ.
///
/// The `extern "C"` block declares the math functions used by these helpers
/// and the per-family bodies. This string is runtime kernel source; any edit
/// changes every compiled module.
#[cfg(target_os = "linux")]
const COMMON_DEVICE_PROLOG: &str = r#"
extern "C" {
    double exp(double);
    double log(double);
    double log1p(double);
    double tanh(double);
    double sqrt(double);
    double fabs(double);
    double erfc(double);
}

__device__ __forceinline__ double clamp_eta(double eta, unsigned int* flags) {
    const double E = 700.0;
    if (eta > E) { *flags |= 0x1u; return E; }
    if (eta < -E) { *flags |= 0x1u; return -E; }
    return eta;
}

__device__ __forceinline__ double bernoulli_deviance(double y, double mu, double w) {
    if (w == 0.0) return 0.0;
    double t1 = (y > 0.0) ? y * log(y / mu) : 0.0;
    double t2 = (y < 1.0) ? (1.0 - y) * log((1.0 - y) / (1.0 - mu)) : 0.0;
    return 2.0 * w * (t1 + t2);
}

__device__ __forceinline__ double bernoulli_z(double eta, double y, double mu, double dmu_deta) {
    if (dmu_deta > 0.0 && isfinite(dmu_deta)) {
        double delta = (y - mu) / dmu_deta;
        if (isfinite(delta)) return eta + delta;
    }
    return eta;
}

__device__ __forceinline__ double std_norm_cdf(double x) {
    return 0.5 * erfc(-x * 0.7071067811865475);
}

__device__ __forceinline__ double std_norm_pdf(double x) {
    return 0.3989422804014327 * exp(-0.5 * x * x);
}
"#;
1463
1464/// Build the per-family CUDA source. Each source defines exactly one entry
1465/// kernel (`family.kernel_name()`) reading the input arrays and writing the
1466/// output arrays defined by the [`RowOutput`] contract above.
1467///
1468/// The GammaLog kernel has an extra `double shape` parameter after `prior_w`
1469/// so the host can forward the active dispersion shape. All other families
1470/// use the standard 13-argument signature.
1471#[cfg(target_os = "linux")]
1472fn cuda_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
1473    let body = match family {
1474        PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
1475        PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
1476        PirlsRowFamily::GammaLog => gamma_log_body(curvature),
1477        PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
1478        PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
1479        PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
1480    };
1481    let kernel_name = family.kernel_name();
1482    // Inject a curvature marker into the source so each `(family, curvature)`
1483    // pair compiles a distinct PTX module (cache key) and Stage 5 can branch
1484    // on `PIRLS_CURVATURE_OBSERVED` from inside each per-family body without
1485    // changing the host harness.
1486    let curvature_define = match curvature {
1487        CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
1488        CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
1489    };
1490    // GammaLog receives the active shape as a scalar kernel argument so the
1491    // host can forward any positive shape without recompiling the PTX.
1492    let shape_param = if matches!(family, PirlsRowFamily::GammaLog) {
1493        "    double         shape,\n"
1494    } else {
1495        ""
1496    };
1497    format!(
1498        r#"
1499{curvature_define}
1500{prolog}
1501
1502extern "C" __global__ void {kernel_name}(
1503    int            n,
1504    const double* __restrict__ eta,
1505    const double* __restrict__ y,
1506    const double* __restrict__ prior_w,
1507{shape_param}    double* __restrict__ mu_out,
1508    double* __restrict__ grad_eta_out,
1509    double* __restrict__ w_fisher_out,
1510    double* __restrict__ w_hessian_out,
1511    double* __restrict__ w_solver_out,
1512    double* __restrict__ z_fisher_out,
1513    double* __restrict__ z_hessian_out,
1514    double* __restrict__ deviance_out,
1515    unsigned int* __restrict__ status_out
1516) {{
1517    int i = blockIdx.x * blockDim.x + threadIdx.x;
1518    if (i >= n) return;
1519    unsigned int flags = 0u;
1520    double eta_i = eta[i];
1521    double y_i = y[i];
1522    double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
1523    if (prior_w[i] <= 0.0) flags |= 0x10u;
1524{body}
1525    mu_out[i] = mu;
1526    grad_eta_out[i] = grad_eta;
1527    w_fisher_out[i] = w_fisher;
1528    w_hessian_out[i] = w_hessian;
1529    w_solver_out[i] = w_solver;
1530    z_fisher_out[i] = z_f;
1531    z_hessian_out[i] = z_h;
1532    deviance_out[i] = dev;
1533    status_out[i] = flags;
1534}}
1535"#,
1536        prolog = COMMON_DEVICE_PROLOG,
1537    )
1538}
1539
1540/// Emits a CUDA comment tag identifying the curvature mode the kernel was
1541/// compiled for. Each body builder prepends this so the body source actually
1542/// consumes the `curvature` argument and Stage 5 can keyed-extend the bodies
1543/// behind `#ifdef PIRLS_CURVATURE_OBSERVED`.
1544#[cfg(target_os = "linux")]
1545#[inline]
1546fn curvature_tag(curvature: CurvatureMode) -> &'static str {
1547    match curvature {
1548        CurvatureMode::Fisher => "    // curvature: fisher\n",
1549        CurvatureMode::Observed => "    // curvature: observed\n",
1550    }
1551}
1552
1553#[cfg(target_os = "linux")]
1554fn gaussian_identity_body(curvature: CurvatureMode) -> String {
1555    let tag = curvature_tag(curvature);
1556    format!(
1557        r#"{tag}    double mu = eta_i;
1558    double resid = y_i - mu;
1559    double grad_eta = wp * resid;
1560    double w_fisher = wp;
1561    double w_hessian = wp;
1562    double w_solver = (wp > 0.0) ? fmax(wp, 1e-12) : 0.0;
1563    double z_f = y_i;
1564    double z_h = y_i;
1565    double dev = wp * resid * resid;
1566"#
1567    )
1568}
1569
1570#[cfg(target_os = "linux")]
1571fn poisson_log_body(curvature: CurvatureMode) -> String {
1572    let tag = curvature_tag(curvature);
1573    format!(
1574        r#"{tag}    double eta_c = clamp_eta(eta_i, &flags);
1575    double mu_raw = exp(eta_c);
1576    if (mu_raw < 1e-10) flags |= 0x2u;
1577    double mu = (mu_raw > 1e-10) ? mu_raw : 1e-10;
1578    double raw_w = wp * mu;
1579    double w_fisher = (raw_w > 0.0) ? fmax(raw_w, 1e-12) : 0.0;
1580    double resid = y_i - mu;
1581    double grad_eta = wp * resid;
1582    double w_hessian = w_fisher;
1583    double w_solver = w_fisher;
1584    double z_f = eta_c + resid / mu;
1585    double z_h = z_f;
1586    double dev_term = (y_i > 0.0) ? (y_i * log(y_i / mu) - resid) : (-resid);
1587    double dev = 2.0 * wp * dev_term;
1588    if (!(isfinite(y_i) && y_i >= 0.0)) flags |= 0x8u;
1589"#
1590    )
1591}
1592
1593#[cfg(target_os = "linux")]
1594fn gamma_log_body(curvature: CurvatureMode) -> String {
1595    // `shape` is a kernel parameter (see `cuda_source_for`); the body reads it
1596    // directly. No local shadowing needed.
1597    let tag = curvature_tag(curvature);
1598    format!(
1599        r#"{tag}    double eta_c = clamp_eta(eta_i, &flags);
1600    double mu_raw = exp(eta_c);
1601    if (mu_raw < 1e-10) flags |= 0x2u;
1602    double mu = (mu_raw > 1e-10) ? mu_raw : 1e-10;
1603    double w_fisher = wp * shape;
1604#ifdef PIRLS_CURVATURE_OBSERVED
1605    // Stage 5: observed information for Gamma-log.
1606    //   w_obs = w_F + w_F · (y/μ − 1) = w_F · y/μ.
1607    double w_hessian = (w_fisher > 0.0 && mu > 0.0 && isfinite(y_i))
1608        ? w_fisher * (y_i / mu)
1609        : w_fisher;
1610#else
1611    double w_hessian = w_fisher;
1612#endif
1613    double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
1614    double resid = y_i - mu;
1615    double grad_eta = wp * resid / mu;
1616    double z_f = eta_c + resid / mu;
1617    double z_h = z_f;
1618    double dev = (y_i > 0.0)
1619        ? (2.0 * wp * (-log(y_i / mu) + resid / mu))
1620        : (1.0 / 0.0);
1621    if (!(isfinite(y_i) && y_i > 0.0)) flags |= 0x8u;
1622"#
1623    )
1624}
1625
1626#[cfg(target_os = "linux")]
1627fn bernoulli_logit_body(curvature: CurvatureMode) -> String {
1628    let tag = curvature_tag(curvature);
1629    format!(
1630        r#"{tag}    double eta_c = clamp_eta(eta_i, &flags);
1631    double half = 0.5 * eta_c;
1632    double mu_raw = 0.5 * (1.0 + tanh(half));
1633    if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
1634    double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
1635    double dmu_deta = mu * (1.0 - mu);
1636    double w_fisher = wp * dmu_deta;
1637    double w_hessian = w_fisher;
1638    double w_solver = (w_fisher > 0.0) ? fmax(w_fisher, 1e-12) : 0.0;
1639    double resid = y_i - mu;
1640    double grad_eta = wp * resid;
1641    double dev = bernoulli_deviance(y_i, mu, wp);
1642    double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
1643    double z_h = z_f;
1644    if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
1645"#
1646    )
1647}
1648
1649#[cfg(target_os = "linux")]
1650fn bernoulli_probit_body(curvature: CurvatureMode) -> String {
1651    let tag = curvature_tag(curvature);
1652    format!(
1653        r#"{tag}    double eta_c = clamp_eta(eta_i, &flags);
1654    double mu_raw = std_norm_cdf(eta_c);
1655    if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
1656    double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
1657    double dmu_deta = std_norm_pdf(eta_c);
1658    double v = mu * (1.0 - mu);
1659    double fpp = (v > 0.0) ? dmu_deta * dmu_deta / v : 0.0;
1660    double w_fisher = wp * fpp;
1661#ifdef PIRLS_CURVATURE_OBSERVED
1662    // Stage 5: observed information for Bernoulli probit.
1663    //   w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
1664    // h(η)=φ(η), h'(η)=−η·φ(η); V'=1−2μ.
1665    double w_hessian = w_fisher;
1666    if (v > 0.0 && wp > 0.0) {{
1667        double h_prime = -eta_c * dmu_deta;
1668        double v_prime = 1.0 - 2.0 * mu;
1669        double bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
1670        w_hessian = w_fisher + wp * (y_i - mu) * bracket;
1671    }}
1672#else
1673    double w_hessian = w_fisher;
1674#endif
1675    double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
1676    double resid = y_i - mu;
1677    double grad_eta = (v > 0.0) ? wp * resid * dmu_deta / v : 0.0;
1678    double dev = bernoulli_deviance(y_i, mu, wp);
1679    double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
1680    double z_h = z_f;
1681    if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
1682"#
1683    )
1684}
1685
1686#[cfg(target_os = "linux")]
1687fn bernoulli_cloglog_body(curvature: CurvatureMode) -> String {
1688    let tag = curvature_tag(curvature);
1689    format!(
1690        r#"{tag}    double eta_c = clamp_eta(eta_i, &flags);
1691    double inner = exp(eta_c);
1692    // μ = 1 − exp(−exp(η)); use -expm1(-inner) to avoid catastrophic
1693    // cancellation in the deep negative tail (η ≲ -36).
1694    double mu_raw = -expm1(-inner);
1695    if (mu_raw < 1e-12 || mu_raw > 1.0 - 1e-12) flags |= 0x2u;
1696    double mu = fmin(fmax(mu_raw, 1e-12), 1.0 - 1e-12);
1697    double dmu_deta = inner * (1.0 - mu_raw);
1698    double v = mu * (1.0 - mu);
1699    double fpp = (v > 0.0) ? dmu_deta * dmu_deta / v : 0.0;
1700    double w_fisher = wp * fpp;
1701#ifdef PIRLS_CURVATURE_OBSERVED
1702    // Stage 5: observed information for Bernoulli cloglog.
1703    //   w_obs = w_F + w_p · (y − μ) · [h'/V − h²·V'/V²].
1704    // h'(η) = h(η) · (1 − inner); V'=1−2μ.
1705    double w_hessian = w_fisher;
1706    if (v > 0.0 && wp > 0.0) {{
1707        double h_prime = dmu_deta * (1.0 - inner);
1708        double v_prime = 1.0 - 2.0 * mu;
1709        double bracket = h_prime / v - (dmu_deta * dmu_deta) * v_prime / (v * v);
1710        w_hessian = w_fisher + wp * (y_i - mu) * bracket;
1711    }}
1712#else
1713    double w_hessian = w_fisher;
1714#endif
1715    double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
1716    double resid = y_i - mu;
1717    double grad_eta = (v > 0.0) ? wp * resid * dmu_deta / v : 0.0;
1718    double dev = bernoulli_deviance(y_i, mu, wp);
1719    double z_f = bernoulli_z(eta_c, y_i, mu, dmu_deta);
1720    double z_h = z_f;
1721    if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
1722"#
1723    )
1724}
1725
1726// ────────────────────────────────────────────────────────────────────────
1727// solve-row CUDA source (4-output variant)
1728// ────────────────────────────────────────────────────────────────────────
1729
1730/// Build the solve-row CUDA source for `(family, curvature)`.
1731///
1732/// The kernel has a reduced signature:
1733///   `(int n, const f64* eta, const f64* y, const f64* prior_w,
1734///     f64* grad_eta_out, f64* w_solver_out, f64* deviance_out, u32* status_out)`
1735///
1736/// It executes the same per-family math as `cuda_source_for` but skips the
1737/// `mu`, `w_fisher`, `w_hessian`, `z_fisher`, `z_hessian` stores, reducing
1738/// both bandwidth and L1/register pressure in the hot Newton iteration.
1739#[cfg(target_os = "linux")]
1740fn solve_row_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
1741    let body = match family {
1742        PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
1743        PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
1744        PirlsRowFamily::GammaLog => gamma_log_body(curvature),
1745        PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
1746        PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
1747        PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
1748    };
1749    let kernel_name = family.solve_kernel_name();
1750    let curvature_define = match curvature {
1751        CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
1752        CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
1753    };
1754    // GammaLog solve kernel also takes `double shape` after `prior_w`.
1755    let shape_param = if matches!(family, PirlsRowFamily::GammaLog) {
1756        "    double         shape,\n"
1757    } else {
1758        ""
1759    };
1760    format!(
1761        r#"
1762{curvature_define}
1763{prolog}
1764
1765extern "C" __global__ void {kernel_name}(
1766    int            n,
1767    const double* __restrict__ eta,
1768    const double* __restrict__ y,
1769    const double* __restrict__ prior_w,
1770{shape_param}    double* __restrict__ grad_eta_out,
1771    double* __restrict__ w_solver_out,
1772    double* __restrict__ deviance_out,
1773    unsigned int* __restrict__ status_out
1774) {{
1775    int i = blockIdx.x * blockDim.x + threadIdx.x;
1776    if (i >= n) return;
1777    unsigned int flags = 0u;
1778    double eta_i = eta[i];
1779    double y_i = y[i];
1780    double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
1781    if (prior_w[i] <= 0.0) flags |= 0x10u;
1782{body}
1783    grad_eta_out[i] = grad_eta;
1784    w_solver_out[i] = w_solver;
1785    deviance_out[i] = dev;
1786    status_out[i] = flags;
1787}}
1788"#,
1789        prolog = COMMON_DEVICE_PROLOG,
1790    )
1791}
1792
1793// ────────────────────────────────────────────────────────────────────────
1794// alpha-ladder CUDA source (fused all-alpha candidate-objective kernel)
1795// ────────────────────────────────────────────────────────────────────────
1796
/// CUDA `__constant__` table of step lengths `2^0, 2^-1, …, 2^-6` that the
/// fused ladder kernel indexes with `blockIdx.y`. The values, their order and
/// the declared length (7) must stay in sync with [`ALPHA_LADDER`].
#[cfg(target_os = "linux")]
const ALPHA_LADDER_CUDA_ARRAY: &str = concat!(
    "__constant__ double PIRLS_ALPHAS[7] = ",
    "{1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625};"
);
1802
1803/// Build the fused alpha-ladder CUDA source for `(family, curvature)`.
1804///
1805/// Grid layout: `grid = (row_blocks, ALPHA_LADDER_LEN, 1)`.
1806///   - `blockIdx.y` selects the alpha slot `k`.
1807///   - `blockIdx.x * blockDim.x + threadIdx.x` selects row `i`.
1808///
1809/// Each thread evaluates `eta_trial = eta[i] + PIRLS_ALPHAS[k] * xd[i]`,
1810/// runs the per-family deviance math, and atomically adds to
1811/// `objective_out[k]` (double-precision atomic add) and OR-accumulates into
1812/// `status_out[k]` (atomic OR on u32).
1813///
1814/// The kernel signature is:
1815///   `(int n, const f64* eta, const f64* xd, const f64* y, const f64* prior_w,
1816///     f64* objective_out, u32* status_out)`
1817#[cfg(target_os = "linux")]
1818fn ladder_source_for(family: PirlsRowFamily, curvature: CurvatureMode) -> String {
1819    let body = match family {
1820        PirlsRowFamily::GaussianIdentity => gaussian_identity_body(curvature),
1821        PirlsRowFamily::PoissonLog => poisson_log_body(curvature),
1822        PirlsRowFamily::GammaLog => gamma_log_body(curvature),
1823        PirlsRowFamily::BernoulliLogit => bernoulli_logit_body(curvature),
1824        PirlsRowFamily::BernoulliProbit => bernoulli_probit_body(curvature),
1825        PirlsRowFamily::BernoulliCLogLog => bernoulli_cloglog_body(curvature),
1826    };
1827    let kernel_name = family.ladder_kernel_name();
1828    let curvature_define = match curvature {
1829        CurvatureMode::Fisher => "#define PIRLS_CURVATURE_FISHER 1",
1830        CurvatureMode::Observed => "#define PIRLS_CURVATURE_OBSERVED 1",
1831    };
1832    // The body uses the name `eta_i` for the (possibly clamped) linear predictor.
1833    // For the ladder we substitute `eta[i] + alpha * xd[i]` as the trial eta,
1834    // so we define `eta_i` before the body runs. The body's own local variable
1835    // of the same name overwrites it in families that clamp eta (e.g. Poisson,
1836    // Bernoulli), which is correct — the per-family body reassigns `eta_c` from
1837    // the (now trial-eta-valued) `eta_i`. For GaussianIdentity the body reads
1838    // `eta_i` directly as `mu`, which is also correct after the substitution.
1839    // GammaLog ladder kernel also takes `double shape` after `prior_w`.
1840    let shape_param = if matches!(family, PirlsRowFamily::GammaLog) {
1841        "    double         shape,\n"
1842    } else {
1843        ""
1844    };
1845    format!(
1846        r#"
1847{curvature_define}
1848{prolog}
1849{alphas}
1850
1851extern "C" __global__ void {kernel_name}(
1852    int            n,
1853    const double* __restrict__ eta,
1854    const double* __restrict__ xd,
1855    const double* __restrict__ y,
1856    const double* __restrict__ prior_w,
1857{shape_param}    double* __restrict__ objective_out,
1858    unsigned int* __restrict__ status_out
1859) {{
1860    int i = blockIdx.x * blockDim.x + threadIdx.x;
1861    int k = (int)blockIdx.y;
1862    if (i >= n) return;
1863    unsigned int flags = 0u;
1864    double alpha = PIRLS_ALPHAS[k];
1865    double eta_i = eta[i] + alpha * xd[i];
1866    double y_i = y[i];
1867    double wp = prior_w[i] > 0.0 ? prior_w[i] : 0.0;
1868    if (prior_w[i] <= 0.0) flags |= 0x10u;
1869{body}
1870    atomicAdd(&objective_out[k], dev);
1871    atomicOr(&status_out[k], flags);
1872}}
1873"#,
1874        prolog = COMMON_DEVICE_PROLOG,
1875        alphas = ALPHA_LADDER_CUDA_ARRAY,
1876    )
1877}
1878
1879// ────────────────────────────────────────────────────────────────────────
1880// Tests
1881// ────────────────────────────────────────────────────────────────────────
1882
1883#[cfg(test)]
1884mod pirls_row_gpu_tests {
1885    use super::*;
1886
1887    fn assert_close(label: &str, got: f64, expected: f64, tol: f64) {
1888        if !(got.is_finite() && expected.is_finite()) {
1889            assert_eq!(
1890                got.is_finite(),
1891                expected.is_finite(),
1892                "{label}: finiteness disagrees (got={got}, expected={expected})"
1893            );
1894            return;
1895        }
1896        let diff = (got - expected).abs();
1897        let denom = expected.abs().max(1.0);
1898        assert!(
1899            diff <= tol * denom,
1900            "{label}: |{got} - {expected}| = {diff} exceeds tol {tol} (rel denom {denom})"
1901        );
1902    }
1903
1904    fn check_family_matches_cpu_reference(family: PirlsRowFamily) {
1905        let etas = [-700.0, -3.0, -0.5, 0.0, 0.5, 3.0, 700.0];
1906        let ys = match family {
1907            PirlsRowFamily::GammaLog => vec![0.5, 1.0, 2.5],
1908            PirlsRowFamily::PoissonLog => vec![0.0, 1.0, 5.0],
1909            PirlsRowFamily::GaussianIdentity => vec![-1.5, 0.0, 2.0],
1910            _ => vec![0.0, 1.0],
1911        };
1912        let ws = [0.0, 1.0, 2.5];
1913        for &eta in &etas {
1914            for &y in &ys {
1915                for &wp in &ws {
1916                    let input = RowInput {
1917                        eta,
1918                        y,
1919                        prior_weight: wp,
1920                    };
1921                    let out = row_reweight_cpu(family, CurvatureMode::Fisher, input, 1.0);
1922                    // Structural invariants of the contract.
1923                    assert!(
1924                        out.w_fisher >= 0.0,
1925                        "{family:?}: w_fisher must be non-negative (got {})",
1926                        out.w_fisher
1927                    );
1928                    assert!(
1929                        out.w_solver >= 0.0,
1930                        "{family:?}: w_solver must be non-negative (got {})",
1931                        out.w_solver
1932                    );
1933                    if wp > 0.0 && out.w_hessian > 0.0 {
1934                        assert!(
1935                            out.w_solver >= W_SOLVER_FLOOR,
1936                            "{family:?}: w_solver must be floored away from zero when positive (got {})",
1937                            out.w_solver
1938                        );
1939                    }
1940                    // grad_eta and (z_fisher - eta) * w_fisher must agree
1941                    // when eta is unclamped and w_fisher > 0; this guards the
1942                    // "never reconstruct gradient from z" discipline.
1943                    if (out.status & status_flags::ETA_CLAMPED) != 0 {
1944                        continue;
1945                    }
1946                    if out.w_fisher > 0.0 && out.z_fisher.is_finite() {
1947                        let reconstructed = out.w_fisher * (out.z_fisher - eta);
1948                        // Allow loose tolerance — we only require these are
1949                        // in the same ballpark; the *exact* gradient is
1950                        // grad_eta. Catastrophic cancellation in the
1951                        // reconstructed form is the precise reason the
1952                        // contract exposes grad_eta directly.
1953                        if reconstructed.is_finite() {
1954                            let denom = reconstructed.abs().max(out.grad_eta.abs()).max(1.0);
1955                            let diff = (reconstructed - out.grad_eta).abs() / denom;
1956                            assert!(
1957                                diff < 1.0e-6,
1958                                "{family:?} eta={eta} y={y} wp={wp}: grad_eta {} vs w·(z−η) {} differ by rel {}",
1959                                out.grad_eta,
1960                                reconstructed,
1961                                diff
1962                            );
1963                        }
1964                    }
1965                    // deviance non-negative for valid inputs.
1966                    if out.status & status_flags::INVALID_RESPONSE == 0 && wp >= 0.0 {
1967                        assert!(
1968                            out.deviance >= 0.0 || !out.deviance.is_finite(),
1969                            "{family:?} eta={eta} y={y} wp={wp}: deviance must be non-negative for valid inputs (got {})",
1970                            out.deviance
1971                        );
1972                    }
1973                    // Final sanity: outputs are finite or carry an explicit
1974                    // INVALID_RESPONSE / ZERO_PRIOR_WEIGHT flag.
1975                    if out.status
1976                        & (status_flags::INVALID_RESPONSE | status_flags::ZERO_PRIOR_WEIGHT)
1977                        == 0
1978                    {
1979                        assert!(
1980                            out.mu.is_finite(),
1981                            "{family:?} eta={eta} y={y} wp={wp}: mu must be finite for valid inputs"
1982                        );
1983                        assert!(
1984                            out.grad_eta.is_finite(),
1985                            "{family:?} eta={eta} y={y} wp={wp}: grad_eta must be finite for valid inputs"
1986                        );
1987                    }
1988                }
1989            }
1990        }
1991        // Pull `assert_close` into the closure type-checker so this function
1992        // is the single caller of it across the parity surface — keeps the
1993        // helper exercised on every family run.
1994        assert_close("self", 0.0, 0.0, 0.0);
1995    }
1996
1997    /// Count how many rows actually carried a positive Fisher weight; tests
1998    /// assert non-zero so a silently no-op evaluator (e.g. all-NaN output)
1999    /// can't satisfy the per-row invariants vacuously.
2000    fn count_active_rows(family: PirlsRowFamily) -> usize {
2001        let mut active = 0usize;
2002        for &eta in [-700.0, -3.0, 0.0, 3.0, 700.0].iter() {
2003            for &y in [0.0, 0.5, 1.0].iter() {
2004                for &wp in [1.0, 2.5].iter() {
2005                    let out = row_reweight_cpu(
2006                        family,
2007                        CurvatureMode::Fisher,
2008                        RowInput {
2009                            eta,
2010                            y,
2011                            prior_weight: wp,
2012                        },
2013                        1.0,
2014                    );
2015                    if out.w_fisher > 0.0 {
2016                        active += 1;
2017                    }
2018                }
2019            }
2020        }
2021        active
2022    }
2023
2024    #[test]
2025    fn gaussian_identity_row_invariants() {
2026        check_family_matches_cpu_reference(PirlsRowFamily::GaussianIdentity);
2027        assert!(count_active_rows(PirlsRowFamily::GaussianIdentity) > 0);
2028    }
2029
2030    #[test]
2031    fn poisson_log_row_invariants() {
2032        check_family_matches_cpu_reference(PirlsRowFamily::PoissonLog);
2033        assert!(count_active_rows(PirlsRowFamily::PoissonLog) > 0);
2034    }
2035
2036    #[test]
2037    fn gamma_log_row_invariants() {
2038        check_family_matches_cpu_reference(PirlsRowFamily::GammaLog);
2039        assert!(count_active_rows(PirlsRowFamily::GammaLog) > 0);
2040    }
2041
2042    #[test]
2043    fn bernoulli_logit_row_invariants() {
2044        check_family_matches_cpu_reference(PirlsRowFamily::BernoulliLogit);
2045        assert!(count_active_rows(PirlsRowFamily::BernoulliLogit) > 0);
2046    }
2047
2048    #[test]
2049    fn bernoulli_probit_row_invariants() {
2050        check_family_matches_cpu_reference(PirlsRowFamily::BernoulliProbit);
2051        assert!(count_active_rows(PirlsRowFamily::BernoulliProbit) > 0);
2052    }
2053
2054    #[test]
2055    fn bernoulli_cloglog_row_invariants() {
2056        check_family_matches_cpu_reference(PirlsRowFamily::BernoulliCLogLog);
2057        assert!(count_active_rows(PirlsRowFamily::BernoulliCLogLog) > 0);
2058    }
2059
2060    /// Gaussian identity must match the trivial CPU formula for any input.
2061    #[test]
2062    fn gaussian_identity_matches_explicit_formulas() {
2063        let out = row_reweight_cpu(
2064            PirlsRowFamily::GaussianIdentity,
2065            CurvatureMode::Fisher,
2066            RowInput {
2067                eta: 0.25,
2068                y: 1.0,
2069                prior_weight: 2.0,
2070            },
2071            1.0,
2072        );
2073        assert!(out.mu.is_finite() && out.deviance.is_finite());
2074        assert_close("mu", out.mu, 0.25, 0.0);
2075        assert_close("grad_eta", out.grad_eta, 2.0 * (1.0 - 0.25), 1e-15);
2076        assert_close("w_fisher", out.w_fisher, 2.0, 0.0);
2077        assert_close(
2078            "deviance",
2079            out.deviance,
2080            2.0 * (1.0 - 0.25_f64).powi(2),
2081            1e-15,
2082        );
2083    }
2084
2085    /// Poisson log must match the analytic formula μ = exp(η) and grad_eta = w·(y−μ).
2086    #[test]
2087    fn poisson_log_matches_explicit_formulas() {
2088        let out = row_reweight_cpu(
2089            PirlsRowFamily::PoissonLog,
2090            CurvatureMode::Fisher,
2091            RowInput {
2092                eta: 1.5,
2093                y: 4.0,
2094                prior_weight: 1.0,
2095            },
2096            1.0,
2097        );
2098        let expected_mu = (1.5_f64).exp();
2099        assert!(expected_mu.is_finite() && out.mu.is_finite());
2100        assert_close("mu", out.mu, expected_mu, 1e-15);
2101        assert_close("grad_eta", out.grad_eta, 4.0 - expected_mu, 1e-15);
2102        assert_close("w_fisher", out.w_fisher, expected_mu, 1e-15);
2103    }
2104
2105    /// Bernoulli logit: closed form for canonical link.
2106    #[test]
2107    fn bernoulli_logit_matches_explicit_formulas() {
2108        let eta: f64 = 0.7;
2109        let mu = 1.0 / (1.0 + (-eta).exp());
2110        let out = row_reweight_cpu(
2111            PirlsRowFamily::BernoulliLogit,
2112            CurvatureMode::Fisher,
2113            RowInput {
2114                eta,
2115                y: 1.0,
2116                prior_weight: 3.0,
2117            },
2118            1.0,
2119        );
2120        assert!(mu > 0.0 && mu < 1.0);
2121        assert_close("mu", out.mu, mu, 1e-12);
2122        assert_close("w_fisher", out.w_fisher, 3.0 * mu * (1.0 - mu), 1e-12);
2123        assert_close("grad_eta", out.grad_eta, 3.0 * (1.0 - mu), 1e-12);
2124    }
2125
2126    /// Eta clamping flag must trip past ±700.
2127    #[test]
2128    fn eta_clamp_status_flag_trips() {
2129        let out = row_reweight_cpu(
2130            PirlsRowFamily::PoissonLog,
2131            CurvatureMode::Fisher,
2132            RowInput {
2133                eta: 1000.0,
2134                y: 0.0,
2135                prior_weight: 1.0,
2136            },
2137            1.0,
2138        );
2139        assert!(out.status & status_flags::ETA_CLAMPED != 0);
2140    }
2141
2142    /// `module_for` must lazily compile + cache one module per `(family, curvature)`.
2143    /// Skipped on hosts without a CUDA runtime (mac, CI).
2144    #[test]
2145    fn backend_compiles_one_module_per_family_when_device_present() {
2146        // The compiled-backend flag itself is independent of runtime probe
2147        // and must agree with the `cfg(target_os = "linux")` selector that
2148        // gates the rest of the module-cache code path.
2149        assert_eq!(PirlsRowBackend::compiled(), cfg!(target_os = "linux"));
2150        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2151            eprintln!("[pirls_row_gpu test] no CUDA runtime — skipping device compile test");
2152            return;
2153        }
2154        #[cfg(target_os = "linux")]
2155        {
2156            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2157            for &family in PirlsRowFamily::ALL.iter() {
2158                let m1 = backend
2159                    .module_for(family, CurvatureMode::Fisher)
2160                    .unwrap_or_else(|err| panic!("compile {family:?}: {err}"));
2161                let m2 = backend
2162                    .module_for(family, CurvatureMode::Fisher)
2163                    .unwrap_or_else(|err| panic!("re-fetch {family:?}: {err}"));
2164                assert!(
2165                    Arc::ptr_eq(&m1, &m2),
2166                    "{family:?}: module cache must return same handle on second call"
2167                );
2168            }
2169        }
2170    }
2171
2172    /// Stage 6: JIT-compiled custom-family kernel through Level A
2173    /// (built-in spec) must produce byte-identical outputs to the
2174    /// cached built-in kernel on the same inputs. Validates the
2175    /// `JitFamilySpec::glm` builder + `cuda_source` shell + JIT module
2176    /// cache + `launch_row_reweight_jit_on_stream` end to end against
2177    /// the Stage 1 cached built-in path.
2178    #[test]
2179    fn jit_glm_kernel_matches_builtin_byte_identical() {
2180        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2181            eprintln!("[stage_6_jit] no CUDA runtime — skipping");
2182            return;
2183        }
2184        #[cfg(target_os = "linux")]
2185        {
2186            let etas = [-2.0_f64, -0.5, 0.3, 1.5];
2187            let ys = [0.0_f64, 1.0, 0.0, 1.0];
2188            let priors = [1.0_f64, 1.2, 0.8, 1.5];
2189            let n = etas.len();
2190            let family = PirlsRowFamily::BernoulliLogit;
2191            let curvature = CurvatureMode::Fisher;
2192            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2193            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2194                .expect("shared backend probe")
2195                .stream;
2196
2197            let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("eta");
2198            let mut y_dev = stream.alloc_zeros::<f64>(n).expect("y");
2199            let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("prior");
2200            stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
2201            stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
2202            stream
2203                .memcpy_htod(&priors, &mut prior_dev)
2204                .expect("up prior");
2205
2206            // Built-in path.
2207            let mut out_builtin = RowOutputDevBuffers::allocate(&stream, n).expect("alloc builtin");
2208            launch_row_reweight_on_stream(
2209                backend,
2210                family,
2211                curvature,
2212                1.0,
2213                &stream,
2214                n,
2215                &eta_dev,
2216                &y_dev,
2217                &prior_dev,
2218                &mut out_builtin,
2219            )
2220            .expect("builtin launch");
2221
2222            // JIT path through Level A spec (same body, distinct kernel symbol).
2223            let spec = JitFamilySpec::glm(0x424c_4c47u64, family, curvature);
2224            let mut out_jit = RowOutputDevBuffers::allocate(&stream, n).expect("alloc jit");
2225            launch_row_reweight_jit_on_stream(
2226                backend,
2227                &spec,
2228                curvature,
2229                &stream,
2230                n,
2231                &eta_dev,
2232                &y_dev,
2233                &prior_dev,
2234                &mut out_jit,
2235            )
2236            .expect("jit launch");
2237            stream.synchronize().expect("sync");
2238
2239            // Byte-identical per-field comparison.
2240            for (label, b_dev, j_dev) in [
2241                ("mu", &out_builtin.mu, &out_jit.mu),
2242                ("grad_eta", &out_builtin.grad_eta, &out_jit.grad_eta),
2243                ("w_fisher", &out_builtin.w_fisher, &out_jit.w_fisher),
2244                ("w_hessian", &out_builtin.w_hessian, &out_jit.w_hessian),
2245                ("w_solver", &out_builtin.w_solver, &out_jit.w_solver),
2246                ("z_fisher", &out_builtin.z_fisher, &out_jit.z_fisher),
2247                ("z_hessian", &out_builtin.z_hessian, &out_jit.z_hessian),
2248                ("deviance", &out_builtin.deviance, &out_jit.deviance),
2249            ] {
2250                let b = stream.clone_dtoh(b_dev).expect("dl builtin");
2251                let j = stream.clone_dtoh(j_dev).expect("dl jit");
2252                for i in 0..n {
2253                    assert_eq!(
2254                        b[i].to_bits(),
2255                        j[i].to_bits(),
2256                        "{label}[{i}]: builtin {} ≠ jit {}",
2257                        b[i],
2258                        j[i],
2259                    );
2260                }
2261            }
2262        }
2263    }
2264
2265    /// Stage 6 Level B: caller-supplied raw CUDA body (no built-in
2266    /// family enum) must produce byte-identical outputs to the cached
2267    /// built-in `GaussianIdentity` kernel on the same fixture. The
2268    /// raw body below is written from scratch — different statement
2269    /// layout, different intermediate names — but performs the same
2270    /// floating-point operations in the same order as
2271    /// `gaussian_identity_body`, so the bit pattern of every output
2272    /// must match the built-in kernel exactly.
2273    #[test]
2274    fn jit_raw_body_kernel_matches_builtin_gaussian_byte_identical() {
2275        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2276            eprintln!("[stage_6_jit_raw] no CUDA runtime — skipping");
2277            return;
2278        }
2279        #[cfg(target_os = "linux")]
2280        {
2281            // 256-row deterministic fixture covering positive/negative
2282            // residuals, zero residuals, varying prior weights, and a
2283            // zero-weight row (exercises the `(wp > 0.0)` branch in
2284            // `w_solver`).
2285            let n: usize = 256;
2286            let mut etas = vec![0.0_f64; n];
2287            let mut ys = vec![0.0_f64; n];
2288            let mut priors = vec![0.0_f64; n];
2289            for i in 0..n {
2290                let t = (i as f64) / (n as f64 - 1.0); // 0..=1
2291                etas[i] = -3.0 + 6.0 * t;
2292                ys[i] = 5.0 * (t - 0.5);
2293                priors[i] = if i == 7 {
2294                    0.0 // zero-weight row
2295                } else {
2296                    0.25 + 1.75 * t
2297                };
2298            }
2299
2300            let family = PirlsRowFamily::GaussianIdentity;
2301            let curvature = CurvatureMode::Fisher;
2302            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2303            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2304                .expect("shared backend probe")
2305                .stream;
2306
2307            let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("eta");
2308            let mut y_dev = stream.alloc_zeros::<f64>(n).expect("y");
2309            let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("prior");
2310            stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
2311            stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
2312            stream
2313                .memcpy_htod(&priors, &mut prior_dev)
2314                .expect("up prior");
2315
2316            // Built-in Level A Gaussian-identity kernel (reference).
2317            let mut out_builtin = RowOutputDevBuffers::allocate(&stream, n).expect("alloc builtin");
2318            launch_row_reweight_on_stream(
2319                backend,
2320                family,
2321                curvature,
2322                1.0,
2323                &stream,
2324                n,
2325                &eta_dev,
2326                &y_dev,
2327                &prior_dev,
2328                &mut out_builtin,
2329            )
2330            .expect("builtin launch");
2331
2332            // Level B raw-body Gaussian-identity kernel. Source is
2333            // written by hand against the JitFamilySpec contract: read
2334            // eta_i, y_i, wp; assign mu, grad_eta, w_fisher, w_hessian,
2335            // w_solver, z_f, z_h, dev. The op sequence matches
2336            // `gaussian_identity_body` exactly so the result is
2337            // bit-identical to the built-in path.
2338            let raw_body = r#"    // level-b raw body: gaussian identity (hand-written)
2339    // identity link: mu = eta
2340    double mu = eta_i;
2341    // ordinary residual on the response scale
2342    double resid = y_i - mu;
2343    // canonical score contribution
2344    double grad_eta = wp * resid;
2345    // fisher info per row: weight itself (V(mu)=1, dmu/deta=1)
2346    double w_fisher = wp;
2347    // observed == fisher for canonical identity link
2348    double w_hessian = wp;
2349    // solver weight clamps tiny positives to avoid singularity
2350    double w_solver = (wp > 0.0) ? fmax(wp, 1e-12) : 0.0;
2351    // working response equals raw response on identity link
2352    double z_f = y_i;
2353    double z_h = y_i;
2354    // squared-error contribution to deviance
2355    double dev = wp * resid * resid;
2356"#;
2357            let spec = JitFamilySpec::raw(0x5241_575f_4741_5553u64, raw_body);
2358            let mut out_jit = RowOutputDevBuffers::allocate(&stream, n).expect("alloc jit");
2359            launch_row_reweight_jit_on_stream(
2360                backend,
2361                &spec,
2362                curvature,
2363                &stream,
2364                n,
2365                &eta_dev,
2366                &y_dev,
2367                &prior_dev,
2368                &mut out_jit,
2369            )
2370            .expect("jit raw launch");
2371            stream.synchronize().expect("sync");
2372
2373            for (label, b_dev, j_dev) in [
2374                ("mu", &out_builtin.mu, &out_jit.mu),
2375                ("grad_eta", &out_builtin.grad_eta, &out_jit.grad_eta),
2376                ("w_fisher", &out_builtin.w_fisher, &out_jit.w_fisher),
2377                ("w_hessian", &out_builtin.w_hessian, &out_jit.w_hessian),
2378                ("w_solver", &out_builtin.w_solver, &out_jit.w_solver),
2379                ("z_fisher", &out_builtin.z_fisher, &out_jit.z_fisher),
2380                ("z_hessian", &out_builtin.z_hessian, &out_jit.z_hessian),
2381                ("deviance", &out_builtin.deviance, &out_jit.deviance),
2382            ] {
2383                let b = stream.clone_dtoh(b_dev).expect("dl builtin");
2384                let j = stream.clone_dtoh(j_dev).expect("dl jit raw");
2385                for i in 0..n {
2386                    assert_eq!(
2387                        b[i].to_bits(),
2388                        j[i].to_bits(),
2389                        "{label}[{i}]: builtin {} ≠ jit-raw {}",
2390                        b[i],
2391                        j[i],
2392                    );
2393                }
2394            }
2395
2396            // Direct Level B raw-body → CPU reference parity. Closes
2397            // the chain JIT-raw → CPU without going through the
2398            // built-in GPU kernel as an intermediary. Gaussian
2399            // identity is straight-line scalar arithmetic on both
2400            // sides, so we still demand bit-equality on the eight
2401            // outputs the CPU evaluator exposes.
2402            let mu_j = stream.clone_dtoh(&out_jit.mu).expect("dl jit mu");
2403            let g_j = stream.clone_dtoh(&out_jit.grad_eta).expect("dl jit g");
2404            let wf_j = stream.clone_dtoh(&out_jit.w_fisher).expect("dl jit wf");
2405            let wh_j = stream.clone_dtoh(&out_jit.w_hessian).expect("dl jit wh");
2406            let ws_j = stream.clone_dtoh(&out_jit.w_solver).expect("dl jit ws");
2407            let zf_j = stream.clone_dtoh(&out_jit.z_fisher).expect("dl jit zf");
2408            let zh_j = stream.clone_dtoh(&out_jit.z_hessian).expect("dl jit zh");
2409            let d_j = stream.clone_dtoh(&out_jit.deviance).expect("dl jit d");
2410            for i in 0..n {
2411                let cpu = row_reweight_cpu(
2412                    PirlsRowFamily::GaussianIdentity,
2413                    curvature,
2414                    RowInput {
2415                        eta: etas[i],
2416                        y: ys[i],
2417                        prior_weight: priors[i],
2418                    },
2419                    1.0,
2420                );
2421                for (label, cpu_v, jit_v) in [
2422                    ("mu", cpu.mu, mu_j[i]),
2423                    ("grad_eta", cpu.grad_eta, g_j[i]),
2424                    ("w_fisher", cpu.w_fisher, wf_j[i]),
2425                    ("w_hessian", cpu.w_hessian, wh_j[i]),
2426                    ("w_solver", cpu.w_solver, ws_j[i]),
2427                    ("z_fisher", cpu.z_fisher, zf_j[i]),
2428                    ("z_hessian", cpu.z_hessian, zh_j[i]),
2429                    ("deviance", cpu.deviance, d_j[i]),
2430                ] {
2431                    assert_eq!(
2432                        cpu_v.to_bits(),
2433                        jit_v.to_bits(),
2434                        "{label}[{i}]: cpu {} ≠ jit-raw {}",
2435                        cpu_v,
2436                        jit_v,
2437                    );
2438                }
2439            }
2440        }
2441    }
2442
2443    /// Stage 5: observed-information curvature mode is now real math
2444    /// (no longer a Fisher alias) for Gamma-log, Bernoulli probit, and
2445    /// Bernoulli cloglog. Canonical families (Bernoulli logit, Poisson
2446    /// log, Gaussian identity) remain == Fisher because canonical
2447    /// links have observed == Fisher by construction.
2448    #[test]
2449    fn observed_curvature_matches_expected_per_family() {
2450        // Picks where the math is well-conditioned so that round-off
2451        // doesn't dominate the comparison.
2452        let probe_eta = 0.4_f64;
2453        let probe_y = 1.0_f64;
2454        let wp = 1.5_f64;
2455        let input = RowInput {
2456            eta: probe_eta,
2457            y: probe_y,
2458            prior_weight: wp,
2459        };
2460
2461        // Canonical families: observed must equal Fisher exactly.
2462        for canonical in [
2463            PirlsRowFamily::GaussianIdentity,
2464            PirlsRowFamily::PoissonLog,
2465            PirlsRowFamily::BernoulliLogit,
2466        ] {
2467            let f = row_reweight_cpu(canonical, CurvatureMode::Fisher, input, 1.0);
2468            let o = row_reweight_cpu(canonical, CurvatureMode::Observed, input, 1.0);
2469            assert_eq!(
2470                f.w_hessian, o.w_hessian,
2471                "{canonical:?}: observed must equal Fisher for canonical link"
2472            );
2473        }
2474
2475        // Gamma-log (non-canonical): observed = Fisher · (y/μ). Exercise
2476        // with shape=1 and shape=2.5 to confirm the plumbing.
2477        for &shape in &[1.0_f64, 2.5] {
2478            let gf = row_reweight_cpu(
2479                PirlsRowFamily::GammaLog,
2480                CurvatureMode::Fisher,
2481                input,
2482                shape,
2483            );
2484            let go = row_reweight_cpu(
2485                PirlsRowFamily::GammaLog,
2486                CurvatureMode::Observed,
2487                input,
2488                shape,
2489            );
2490            assert!(
2491                (go.w_hessian - gf.w_fisher * (probe_y / gf.mu)).abs() <= 1e-12,
2492                "Gamma-log observed mismatch (shape={shape}): got={} expected={} (mu={})",
2493                go.w_hessian,
2494                gf.w_fisher * (probe_y / gf.mu),
2495                gf.mu
2496            );
2497            assert_ne!(
2498                gf.w_hessian, go.w_hessian,
2499                "Gamma-log: observed must differ from Fisher when y ≠ μ (shape={shape})"
2500            );
2501        }
2502
2503        // Bernoulli probit/cloglog: observed must differ from Fisher
2504        // generically and must be ≥ 0 at well-behaved interior points
2505        // (saturation tail can push it negative; tested elsewhere).
2506        for noncanon in [
2507            PirlsRowFamily::BernoulliProbit,
2508            PirlsRowFamily::BernoulliCLogLog,
2509        ] {
2510            let f = row_reweight_cpu(noncanon, CurvatureMode::Fisher, input, 1.0);
2511            let o = row_reweight_cpu(noncanon, CurvatureMode::Observed, input, 1.0);
2512            assert!(
2513                (f.w_hessian - o.w_hessian).abs() > 0.0 || (probe_y - f.mu).abs() < 1e-15,
2514                "{noncanon:?}: observed should differ from Fisher when y ≠ μ"
2515            );
2516        }
2517    }
2518
2519    /// Gamma-log CPU reference: `w_fisher` scales linearly with shape;
2520    /// observed `w_hessian = w_fisher · y/μ` holds for any positive shape.
2521    #[test]
2522    fn gamma_log_shape_scaling() {
2523        let input = RowInput {
2524            eta: 0.5,
2525            y: 2.0,
2526            prior_weight: 1.0,
2527        };
2528        let base = row_reweight_cpu(PirlsRowFamily::GammaLog, CurvatureMode::Fisher, input, 1.0);
2529        for &shape in &[0.5_f64, 1.5, 3.0, 10.0] {
2530            let r = row_reweight_cpu(
2531                PirlsRowFamily::GammaLog,
2532                CurvatureMode::Fisher,
2533                input,
2534                shape,
2535            );
2536            assert!(
2537                (r.w_fisher - shape * base.w_fisher).abs() <= 1e-14,
2538                "w_fisher should scale with shape: got {} expected {} (shape={shape})",
2539                r.w_fisher,
2540                shape * base.w_fisher,
2541            );
2542            assert_eq!(
2543                r.mu.to_bits(),
2544                base.mu.to_bits(),
2545                "mu must not depend on shape"
2546            );
2547            let ro = row_reweight_cpu(
2548                PirlsRowFamily::GammaLog,
2549                CurvatureMode::Observed,
2550                input,
2551                shape,
2552            );
2553            let expected_obs = r.w_fisher * (input.y / r.mu);
2554            assert!(
2555                (ro.w_hessian - expected_obs).abs() <= 1e-13,
2556                "observed w_hessian mismatch (shape={shape}): got={} expected={}",
2557                ro.w_hessian,
2558                expected_obs,
2559            );
2560        }
2561    }
2562
2563    /// V100 parity for the device-side row launcher.
2564    ///
2565    /// For every built-in family the launcher's per-row outputs must match
2566    /// the CPU `row_reweight_cpu` reference to round-off. Skipped on hosts
2567    /// without a CUDA runtime (mac, CI). This is also the production
2568    /// caller that justifies the launcher + `RowOutputDevBuffers` surface
2569    /// per the dead-pub-scanner rule.
2570    #[test]
2571    fn launch_row_reweight_matches_cpu_reference_on_device() {
2572        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2573            eprintln!("[pirls_row_gpu test] no CUDA runtime — skipping launcher parity test");
2574            return;
2575        }
2576        #[cfg(target_os = "linux")]
2577        {
2578            // Use a small but non-trivial row batch with several y-values per
2579            // family. Pick a y that's valid for every family below (0/1 are
2580            // valid for Bernoulli; positive for Poisson/Gamma; arbitrary for
2581            // Gaussian). Build per-family input vectors.
2582            let etas = [-3.0_f64, -0.5, 0.0, 0.5, 3.0, 10.0, -10.0, 1.5];
2583            let n = etas.len();
2584            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2585            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2586                .expect("shared backend probe")
2587                .stream;
2588
2589            for &family in PirlsRowFamily::ALL.iter() {
2590                let ys: Vec<f64> = match family {
2591                    PirlsRowFamily::GammaLog | PirlsRowFamily::PoissonLog => {
2592                        (0..n).map(|i| 1.0 + 0.5 * (i as f64)).collect()
2593                    }
2594                    PirlsRowFamily::GaussianIdentity => {
2595                        (0..n).map(|i| -1.0 + 0.5 * (i as f64)).collect()
2596                    }
2597                    _ => (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
2598                };
2599                let priors: Vec<f64> = (0..n).map(|i| 1.0 + 0.25 * (i as f64)).collect();
2600
2601                // CPU reference.
2602                let mut cpu_out = Vec::with_capacity(n);
2603                for i in 0..n {
2604                    cpu_out.push(row_reweight_cpu(
2605                        family,
2606                        CurvatureMode::Fisher,
2607                        RowInput {
2608                            eta: etas[i],
2609                            y: ys[i],
2610                            prior_weight: priors[i],
2611                        },
2612                        1.0,
2613                    ));
2614                }
2615
2616                // Upload inputs, allocate device outputs, launch, download.
2617                let mut eta_dev = stream.alloc_zeros::<f64>(n).expect("alloc eta_dev");
2618                let mut y_dev = stream.alloc_zeros::<f64>(n).expect("alloc y_dev");
2619                let mut prior_dev = stream.alloc_zeros::<f64>(n).expect("alloc prior_dev");
2620                stream
2621                    .memcpy_htod(etas.as_slice(), &mut eta_dev)
2622                    .expect("upload eta");
2623                stream
2624                    .memcpy_htod(ys.as_slice(), &mut y_dev)
2625                    .expect("upload y");
2626                stream
2627                    .memcpy_htod(priors.as_slice(), &mut prior_dev)
2628                    .expect("upload prior");
2629                let mut out = RowOutputDevBuffers::allocate(&stream, n).expect("alloc row buffers");
2630                launch_row_reweight_on_stream(
2631                    backend,
2632                    family,
2633                    CurvatureMode::Fisher,
2634                    1.0,
2635                    &stream,
2636                    n,
2637                    &eta_dev,
2638                    &y_dev,
2639                    &prior_dev,
2640                    &mut out,
2641                )
2642                .unwrap_or_else(|err| panic!("launch {family:?}: {err}"));
2643                stream.synchronize().expect("stream sync");
2644                let mu = stream.clone_dtoh(&out.mu).expect("dl mu");
2645                let g = stream.clone_dtoh(&out.grad_eta).expect("dl grad_eta");
2646                let wf = stream.clone_dtoh(&out.w_fisher).expect("dl w_fisher");
2647                let wh = stream.clone_dtoh(&out.w_hessian).expect("dl w_hessian");
2648                let ws_v = stream.clone_dtoh(&out.w_solver).expect("dl w_solver");
2649                let zf = stream.clone_dtoh(&out.z_fisher).expect("dl z_fisher");
2650                let zh = stream.clone_dtoh(&out.z_hessian).expect("dl z_hessian");
2651                let dev = stream.clone_dtoh(&out.deviance).expect("dl deviance");
2652
2653                let tol = 1e-12;
2654                for i in 0..n {
2655                    let r = cpu_out[i];
2656                    assert_close(&format!("{family:?}/row{i}/mu"), mu[i], r.mu, tol);
2657                    assert_close(
2658                        &format!("{family:?}/row{i}/grad_eta"),
2659                        g[i],
2660                        r.grad_eta,
2661                        tol,
2662                    );
2663                    assert_close(
2664                        &format!("{family:?}/row{i}/w_fisher"),
2665                        wf[i],
2666                        r.w_fisher,
2667                        tol,
2668                    );
2669                    assert_close(
2670                        &format!("{family:?}/row{i}/w_hessian"),
2671                        wh[i],
2672                        r.w_hessian,
2673                        tol,
2674                    );
2675                    assert_close(
2676                        &format!("{family:?}/row{i}/w_solver"),
2677                        ws_v[i],
2678                        r.w_solver,
2679                        tol,
2680                    );
2681                    assert_close(
2682                        &format!("{family:?}/row{i}/z_fisher"),
2683                        zf[i],
2684                        r.z_fisher,
2685                        tol,
2686                    );
2687                    assert_close(
2688                        &format!("{family:?}/row{i}/z_hessian"),
2689                        zh[i],
2690                        r.z_hessian,
2691                        tol,
2692                    );
2693                    assert_close(
2694                        &format!("{family:?}/row{i}/deviance"),
2695                        dev[i],
2696                        r.deviance,
2697                        tol,
2698                    );
2699                }
2700            }
2701        }
2702    }
2703
2704    /// Stage 5 end-to-end V100 parity for `CurvatureMode::Observed`.
2705    ///
2706    /// 256-row synthetic fixture per family with η ∈ [-6, 6] spanning the
2707    /// saturated tails. Two assertions:
2708    ///
2709    /// 1. Noncanonical families (BernoulliProbit, BernoulliCLogLog,
2710    ///    GammaLog): device `w_hessian` under Observed mode matches the
2711    ///    CPU `row_reweight_cpu(..., Observed, ..)` reference to
2712    ///    `abs ≤ 1e-12` OR `rel ≤ 1e-11`.
2713    /// 2. Canonical families (GaussianIdentity, PoissonLog, BernoulliLogit):
2714    ///    device `w_hessian` under Observed mode equals device `w_fisher`
2715    ///    bit-for-bit (via `to_bits()`) — canonical links have observed
2716    ///    information ≡ Fisher information by construction.
2717    #[test]
2718    fn gpu_observed_parity() {
2719        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2720            eprintln!("[gpu_observed_parity] no CUDA runtime — skipping");
2721            return;
2722        }
2723        #[cfg(target_os = "linux")]
2724        {
2725            const N: usize = 256;
2726            let etas: Vec<f64> = (0..N)
2727                .map(|i| -6.0 + 12.0 * (i as f64) / ((N - 1) as f64))
2728                .collect();
2729            let priors: Vec<f64> = (0..N)
2730                .map(|i| 0.5 + 1.5 * ((i as f64) / (N as f64)))
2731                .collect();
2732
2733            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2734            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2735                .expect("shared backend probe")
2736                .stream;
2737
2738            for &family in PirlsRowFamily::ALL.iter() {
2739                let ys: Vec<f64> = match family {
2740                    PirlsRowFamily::GammaLog => (0..N).map(|i| 0.25 + 0.05 * (i as f64)).collect(),
2741                    PirlsRowFamily::PoissonLog => (0..N).map(|i| (i % 6) as f64).collect(),
2742                    PirlsRowFamily::GaussianIdentity => (0..N)
2743                        .map(|i| -2.0 + 4.0 * (i as f64) / ((N - 1) as f64))
2744                        .collect(),
2745                    _ => (0..N).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
2746                };
2747
2748                let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("alloc eta_dev");
2749                let mut y_dev = stream.alloc_zeros::<f64>(N).expect("alloc y_dev");
2750                let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("alloc prior_dev");
2751                stream
2752                    .memcpy_htod(etas.as_slice(), &mut eta_dev)
2753                    .expect("upload eta");
2754                stream
2755                    .memcpy_htod(ys.as_slice(), &mut y_dev)
2756                    .expect("upload y");
2757                stream
2758                    .memcpy_htod(priors.as_slice(), &mut prior_dev)
2759                    .expect("upload prior");
2760
2761                let mut out_obs = RowOutputDevBuffers::allocate(&stream, N).expect("alloc out_obs");
2762                launch_row_reweight_on_stream(
2763                    backend,
2764                    family,
2765                    CurvatureMode::Observed,
2766                    1.0,
2767                    &stream,
2768                    N,
2769                    &eta_dev,
2770                    &y_dev,
2771                    &prior_dev,
2772                    &mut out_obs,
2773                )
2774                .unwrap_or_else(|err| panic!("observed launch {family:?}: {err}"));
2775                stream.synchronize().expect("stream sync (observed)");
2776
2777                let wh_obs = stream
2778                    .clone_dtoh(&out_obs.w_hessian)
2779                    .expect("dl w_hessian (observed)");
2780                let wf_obs = stream
2781                    .clone_dtoh(&out_obs.w_fisher)
2782                    .expect("dl w_fisher (observed)");
2783
2784                if family.is_canonical() {
2785                    for i in 0..N {
2786                        assert_eq!(
2787                            wh_obs[i].to_bits(),
2788                            wf_obs[i].to_bits(),
2789                            "{family:?} row {i}: observed w_hessian {} must bit-equal w_fisher {} on canonical link",
2790                            wh_obs[i],
2791                            wf_obs[i],
2792                        );
2793                    }
2794                } else {
2795                    for i in 0..N {
2796                        let cpu = row_reweight_cpu(
2797                            family,
2798                            CurvatureMode::Observed,
2799                            RowInput {
2800                                eta: etas[i],
2801                                y: ys[i],
2802                                prior_weight: priors[i],
2803                            },
2804                            1.0,
2805                        );
2806                        let got = wh_obs[i];
2807                        let exp = cpu.w_hessian;
2808                        let abs_err = (got - exp).abs();
2809                        let rel_err = if exp.abs() > 0.0 {
2810                            abs_err / exp.abs()
2811                        } else {
2812                            abs_err
2813                        };
2814                        assert!(
2815                            abs_err <= 1.0e-12 || rel_err <= 1.0e-11,
2816                            "{family:?} row {i} (eta={}, y={}, wp={}): \
2817                             device w_hessian={} vs CPU observed={} (abs={}, rel={})",
2818                            etas[i],
2819                            ys[i],
2820                            priors[i],
2821                            got,
2822                            exp,
2823                            abs_err,
2824                            rel_err,
2825                        );
2826                    }
2827                }
2828            }
2829        }
2830    }
2831
    /// Task #50 — Stage 5 GPU end-to-end observed-curvature parity at
    /// n=1000 across **all six** supported families, validating BOTH
    /// `w_hessian` (Hessian diagonal) AND `grad_eta` (score) against the
    /// CPU observed-curvature oracle to abs/rel ≤ 1e-9. This is the
    /// full end-to-end companion to `gpu_observed_parity` (which only
    /// covered n=256 and only compared `w_hessian` against the CPU for
    /// noncanonical families). Gated on a live CUDA runtime: on hosts
    /// without one it logs a skip message and returns early. It is NOT
    /// `#[ignore]`d, so it runs in the default `cargo test` pass.
2840    #[test]
2841    fn gpu_observed_parity_end_to_end_n1000() {
2842        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2843            eprintln!("[gpu_observed_parity_end_to_end_n1000] no CUDA runtime — skipping");
2844            return;
2845        }
2846        #[cfg(target_os = "linux")]
2847        {
2848            const N: usize = 1000;
2849            // Deterministic η grid spanning the saturated tails plus
2850            // the near-zero regime where the observed-information
2851            // correction term is most active.
2852            let etas: Vec<f64> = (0..N)
2853                .map(|i| -8.0 + 16.0 * (i as f64) / ((N - 1) as f64))
2854                .collect();
2855            let priors: Vec<f64> = (0..N)
2856                .map(|i| 0.25 + 1.75 * ((i as f64) / (N as f64)))
2857                .collect();
2858
2859            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
2860            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
2861                .expect("shared backend probe")
2862                .stream;
2863
2864            const TOL: f64 = 1.0e-9;
2865
2866            for &family in PirlsRowFamily::ALL.iter() {
2867                // Family-specific response vectors stay in domain so the
2868                // CPU oracle is well-defined for every row.
2869                let ys: Vec<f64> = match family {
2870                    PirlsRowFamily::GammaLog => {
2871                        (0..N).map(|i| 0.10 + 0.05 * ((i % 97) as f64)).collect()
2872                    }
2873                    PirlsRowFamily::PoissonLog => (0..N).map(|i| (i % 11) as f64).collect(),
2874                    PirlsRowFamily::GaussianIdentity => (0..N)
2875                        .map(|i| -3.0 + 6.0 * (i as f64) / ((N - 1) as f64))
2876                        .collect(),
2877                    PirlsRowFamily::BernoulliLogit
2878                    | PirlsRowFamily::BernoulliProbit
2879                    | PirlsRowFamily::BernoulliCLogLog => {
2880                        (0..N).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect()
2881                    }
2882                };
2883
2884                let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("alloc eta_dev");
2885                let mut y_dev = stream.alloc_zeros::<f64>(N).expect("alloc y_dev");
2886                let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("alloc prior_dev");
2887                stream
2888                    .memcpy_htod(etas.as_slice(), &mut eta_dev)
2889                    .expect("upload eta");
2890                stream
2891                    .memcpy_htod(ys.as_slice(), &mut y_dev)
2892                    .expect("upload y");
2893                stream
2894                    .memcpy_htod(priors.as_slice(), &mut prior_dev)
2895                    .expect("upload prior");
2896
2897                let mut out_obs = RowOutputDevBuffers::allocate(&stream, N).expect("alloc out_obs");
2898                launch_row_reweight_on_stream(
2899                    backend,
2900                    family,
2901                    CurvatureMode::Observed,
2902                    1.0,
2903                    &stream,
2904                    N,
2905                    &eta_dev,
2906                    &y_dev,
2907                    &prior_dev,
2908                    &mut out_obs,
2909                )
2910                .unwrap_or_else(|err| panic!("observed launch {family:?}: {err}"));
2911                stream.synchronize().expect("stream sync (observed)");
2912
2913                let wh_obs = stream
2914                    .clone_dtoh(&out_obs.w_hessian)
2915                    .expect("dl w_hessian (observed)");
2916                let ge_obs = stream
2917                    .clone_dtoh(&out_obs.grad_eta)
2918                    .expect("dl grad_eta (observed)");
2919
2920                for i in 0..N {
2921                    let cpu = row_reweight_cpu(
2922                        family,
2923                        CurvatureMode::Observed,
2924                        RowInput {
2925                            eta: etas[i],
2926                            y: ys[i],
2927                            prior_weight: priors[i],
2928                        },
2929                        1.0,
2930                    );
2931
2932                    // H diagonal (w_hessian) parity.
2933                    let h_got = wh_obs[i];
2934                    let h_exp = cpu.w_hessian;
2935                    let h_abs = (h_got - h_exp).abs();
2936                    let h_rel = if h_exp.abs() > 0.0 {
2937                        h_abs / h_exp.abs()
2938                    } else {
2939                        h_abs
2940                    };
2941                    assert!(
2942                        h_abs <= TOL || h_rel <= TOL,
2943                        "{family:?} row {i} (eta={}, y={}, wp={}): \
2944                         observed w_hessian GPU={} vs CPU={} (abs={}, rel={})",
2945                        etas[i],
2946                        ys[i],
2947                        priors[i],
2948                        h_got,
2949                        h_exp,
2950                        h_abs,
2951                        h_rel,
2952                    );
2953
2954                    // Gradient (grad_eta) parity — the score does not
2955                    // depend on curvature mode, but the CPU oracle here
2956                    // is exercised under `Observed` for full
2957                    // end-to-end coverage.
2958                    let g_got = ge_obs[i];
2959                    let g_exp = cpu.grad_eta;
2960                    let g_abs = (g_got - g_exp).abs();
2961                    let g_rel = if g_exp.abs() > 0.0 {
2962                        g_abs / g_exp.abs()
2963                    } else {
2964                        g_abs
2965                    };
2966                    assert!(
2967                        g_abs <= TOL || g_rel <= TOL,
2968                        "{family:?} row {i} (eta={}, y={}, wp={}): \
2969                         observed grad_eta GPU={} vs CPU={} (abs={}, rel={})",
2970                        etas[i],
2971                        ys[i],
2972                        priors[i],
2973                        g_got,
2974                        g_exp,
2975                        g_abs,
2976                        g_rel,
2977                    );
2978                }
2979            }
2980        }
2981    }
2982
2983    /// Task #51 — Stage 6 Level B end-to-end NVRTC JIT parity. For
2984    /// each of the 6 supported families we hand-author a raw CUDA
2985    /// body that re-derives the family math from scratch (distinct
2986    /// variable names + restructured statement order from the
2987    /// built-in `*_body` strings), JIT-compile via `JitFamilySpec::raw`
2988    /// through the full `launch_row_reweight_jit_on_stream` pipeline,
2989    /// and assert all 8 outputs match the CPU `row_reweight_cpu`
2990    /// oracle to ≤ 1e-10 on n=1000 rows. Skipped if no CUDA runtime;
2991    /// `#[ignore]` so v100-bench-runner picks it up via `--ignored`.
2992    #[test]
2993    fn gpu_jit_level_b_raw_body_end_to_end_all_families_n1000() {
2994        if gam_gpu::device_runtime::GpuRuntime::global().is_none() {
2995            eprintln!(
2996                "[gpu_jit_level_b_raw_body_end_to_end_all_families_n1000] no CUDA runtime — skipping"
2997            );
2998            return;
2999        }
3000        #[cfg(target_os = "linux")]
3001        {
3002            const N: usize = 1000;
3003            const TOL: f64 = 1.0e-10;
3004            let curvature = CurvatureMode::Fisher;
3005
3006            let backend = PirlsRowBackend::probe().expect("backend probe on CUDA host");
3007            let stream = gam_gpu::backend_probe::probe_cuda_backend("pirls_row")
3008                .expect("shared backend probe")
3009                .stream;
3010
3011            // η grid + per-family in-domain response vectors. Mirrors
3012            // `gpu_observed_parity_end_to_end_n1000` so the two tests
3013            // exercise the same numerical regime through different code
3014            // paths (built-in module vs JIT raw body).
3015            let etas: Vec<f64> = (0..N)
3016                .map(|i| -6.0 + 12.0 * (i as f64) / ((N - 1) as f64))
3017                .collect();
3018            let priors: Vec<f64> = (0..N)
3019                .map(|i| 0.25 + 1.75 * ((i as f64) / (N as f64)))
3020                .collect();
3021
3022            // Hand-authored raw bodies — re-derived from the math, not
3023            // copy-pasted from `*_body()`. Helpers (`clamp_eta`,
3024            // `std_norm_cdf`, `std_norm_pdf`, `bernoulli_deviance`,
3025            // `bernoulli_z`) are provided by `COMMON_DEVICE_PROLOG`.
3026            let raw_gaussian = r#"    // raw-body gaussian identity (independent re-derivation)
3027    double resp = y_i;
3028    double pred = eta_i;
3029    double mu = pred;
3030    double w_p = wp;
3031    double e_resid = resp - pred;
3032    double grad_eta = w_p * e_resid;
3033    double w_fisher = w_p;
3034    double w_hessian = w_p;
3035    double w_solver = (w_p > 0.0) ? fmax(w_p, 1e-12) : 0.0;
3036    double z_f = resp;
3037    double z_h = resp;
3038    double dev = w_p * e_resid * e_resid;
3039"#;
3040
3041            let raw_poisson = r#"    // raw-body poisson log (independent re-derivation)
3042    double eta_c = clamp_eta(eta_i, &flags);
3043    double mu_pre = exp(eta_c);
3044    if (mu_pre < 1e-10) flags |= 0x2u;
3045    double mu = (mu_pre > 1e-10) ? mu_pre : 1e-10;
3046    double wrate = wp * mu;
3047    double w_fisher = (wrate > 0.0) ? fmax(wrate, 1e-12) : 0.0;
3048    double w_hessian = w_fisher;
3049    double w_solver = w_fisher;
3050    double pres = y_i - mu;
3051    double grad_eta = wp * pres;
3052    double z_lin = eta_c + pres / mu;
3053    double z_f = z_lin;
3054    double z_h = z_lin;
3055    double dterm;
3056    if (y_i > 0.0) {
3057        dterm = y_i * log(y_i / mu) - pres;
3058    } else {
3059        dterm = -pres;
3060    }
3061    double dev = 2.0 * wp * dterm;
3062    if (!(isfinite(y_i) && y_i >= 0.0)) flags |= 0x8u;
3063"#;
3064
3065            let raw_gamma = r#"    // raw-body gamma log (independent re-derivation; unit shape)
3066    double k_shape = 1.0;
3067    double eta_c = clamp_eta(eta_i, &flags);
3068    double mu_pre = exp(eta_c);
3069    if (mu_pre < 1e-10) flags |= 0x2u;
3070    double mu = (mu_pre > 1e-10) ? mu_pre : 1e-10;
3071    double w_fisher = wp * k_shape;
3072    double w_hessian = w_fisher;
3073    double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
3074    double pres = y_i - mu;
3075    double grad_eta = wp * pres / mu;
3076    double z_lin = eta_c + pres / mu;
3077    double z_f = z_lin;
3078    double z_h = z_lin;
3079    double dev;
3080    if (y_i > 0.0) {
3081        dev = 2.0 * wp * (-log(y_i / mu) + pres / mu);
3082    } else {
3083        dev = 1.0 / 0.0;
3084    }
3085    if (!(isfinite(y_i) && y_i > 0.0)) flags |= 0x8u;
3086"#;
3087
3088            let raw_logit = r#"    // raw-body bernoulli logit (independent re-derivation)
3089    double eta_c = clamp_eta(eta_i, &flags);
3090    double te = tanh(0.5 * eta_c);
3091    double mu_pre = 0.5 * (1.0 + te);
3092    if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
3093    double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
3094    double dmu_deta = mu * (1.0 - mu);
3095    double w_fisher = wp * dmu_deta;
3096    double w_hessian = w_fisher;
3097    double w_solver = (w_fisher > 0.0) ? fmax(w_fisher, 1e-12) : 0.0;
3098    double bres = y_i - mu;
3099    double grad_eta = wp * bres;
3100    double dev = bernoulli_deviance(y_i, mu, wp);
3101    double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
3102    double z_f = z_lin;
3103    double z_h = z_lin;
3104    if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
3105"#;
3106
3107            let raw_probit = r#"    // raw-body bernoulli probit (independent re-derivation; Fisher mode)
3108    double eta_c = clamp_eta(eta_i, &flags);
3109    double mu_pre = std_norm_cdf(eta_c);
3110    if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
3111    double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
3112    double phi = std_norm_pdf(eta_c);
3113    double dmu_deta = phi;
3114    double vmu = mu * (1.0 - mu);
3115    double w_pp = (vmu > 0.0) ? (phi * phi) / vmu : 0.0;
3116    double w_fisher = wp * w_pp;
3117    double w_hessian = w_fisher;
3118    double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
3119    double bres = y_i - mu;
3120    double grad_eta = (vmu > 0.0) ? wp * bres * phi / vmu : 0.0;
3121    double dev = bernoulli_deviance(y_i, mu, wp);
3122    double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
3123    double z_f = z_lin;
3124    double z_h = z_lin;
3125    if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
3126"#;
3127
3128            let raw_cloglog = r#"    // raw-body bernoulli cloglog (independent re-derivation; Fisher mode)
3129    double eta_c = clamp_eta(eta_i, &flags);
3130    double a = exp(eta_c);
3131    double mu_pre = 1.0 - exp(-a);
3132    if (mu_pre < 1e-12 || mu_pre > 1.0 - 1e-12) flags |= 0x2u;
3133    double mu = fmin(fmax(mu_pre, 1e-12), 1.0 - 1e-12);
3134    double dmu_deta = a * (1.0 - mu_pre);
3135    double vmu = mu * (1.0 - mu);
3136    double w_pp = (vmu > 0.0) ? (dmu_deta * dmu_deta) / vmu : 0.0;
3137    double w_fisher = wp * w_pp;
3138    double w_hessian = w_fisher;
3139    double w_solver = (w_hessian > 0.0) ? fmax(w_hessian, 1e-12) : 0.0;
3140    double bres = y_i - mu;
3141    double grad_eta = (vmu > 0.0) ? wp * bres * dmu_deta / vmu : 0.0;
3142    double dev = bernoulli_deviance(y_i, mu, wp);
3143    double z_lin = bernoulli_z(eta_c, y_i, mu, dmu_deta);
3144    double z_f = z_lin;
3145    double z_h = z_lin;
3146    if (!(isfinite(y_i) && y_i >= 0.0 && y_i <= 1.0)) flags |= 0x8u;
3147"#;
3148
3149            // (family, raw_body, distinct spec_id, y vector builder).
3150            // spec_ids are disjoint 64-bit tags so the JIT module cache
3151            // creates one fresh module per family.
3152            let cases: [(PirlsRowFamily, &str, u64, fn(usize) -> Vec<f64>); 6] = [
3153                (
3154                    PirlsRowFamily::GaussianIdentity,
3155                    raw_gaussian,
3156                    0x5242_3031_4741_5553u64,
3157                    |n| {
3158                        (0..n)
3159                            .map(|i| -3.0 + 6.0 * (i as f64) / ((n - 1) as f64))
3160                            .collect()
3161                    },
3162                ),
3163                (
3164                    PirlsRowFamily::PoissonLog,
3165                    raw_poisson,
3166                    0x5242_3032_504f_4953u64,
3167                    |n| (0..n).map(|i| (i % 11) as f64).collect(),
3168                ),
3169                (
3170                    PirlsRowFamily::GammaLog,
3171                    raw_gamma,
3172                    0x5242_3033_474d_414cu64,
3173                    |n| (0..n).map(|i| 0.10 + 0.05 * ((i % 97) as f64)).collect(),
3174                ),
3175                (
3176                    PirlsRowFamily::BernoulliLogit,
3177                    raw_logit,
3178                    0x5242_3034_4c47_4954u64,
3179                    |n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
3180                ),
3181                (
3182                    PirlsRowFamily::BernoulliProbit,
3183                    raw_probit,
3184                    0x5242_3035_5052_4254u64,
3185                    |n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
3186                ),
3187                (
3188                    PirlsRowFamily::BernoulliCLogLog,
3189                    raw_cloglog,
3190                    0x5242_3036_434c_4f47u64,
3191                    |n| (0..n).map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }).collect(),
3192                ),
3193            ];
3194
3195            for (family, raw_body, spec_id, build_y) in cases {
3196                let ys: Vec<f64> = build_y(N);
3197
3198                let mut eta_dev = stream.alloc_zeros::<f64>(N).expect("eta");
3199                let mut y_dev = stream.alloc_zeros::<f64>(N).expect("y");
3200                let mut prior_dev = stream.alloc_zeros::<f64>(N).expect("prior");
3201                stream.memcpy_htod(&etas, &mut eta_dev).expect("up eta");
3202                stream.memcpy_htod(&ys, &mut y_dev).expect("up y");
3203                stream
3204                    .memcpy_htod(&priors, &mut prior_dev)
3205                    .expect("up prior");
3206
3207                let spec = JitFamilySpec::raw(spec_id, raw_body);
3208                let mut out_jit = RowOutputDevBuffers::allocate(&stream, N).expect("alloc jit out");
3209                launch_row_reweight_jit_on_stream(
3210                    backend,
3211                    &spec,
3212                    curvature,
3213                    &stream,
3214                    N,
3215                    &eta_dev,
3216                    &y_dev,
3217                    &prior_dev,
3218                    &mut out_jit,
3219                )
3220                .unwrap_or_else(|err| panic!("jit raw-body launch {family:?}: {err}"));
3221                stream.synchronize().expect("sync");
3222
3223                let mu_j = stream.clone_dtoh(&out_jit.mu).expect("dl mu");
3224                let ge_j = stream.clone_dtoh(&out_jit.grad_eta).expect("dl g");
3225                let wf_j = stream.clone_dtoh(&out_jit.w_fisher).expect("dl wf");
3226                let wh_j = stream.clone_dtoh(&out_jit.w_hessian).expect("dl wh");
3227                let ws_j = stream.clone_dtoh(&out_jit.w_solver).expect("dl ws");
3228                let zf_j = stream.clone_dtoh(&out_jit.z_fisher).expect("dl zf");
3229                let zh_j = stream.clone_dtoh(&out_jit.z_hessian).expect("dl zh");
3230                let dv_j = stream.clone_dtoh(&out_jit.deviance).expect("dl dv");
3231
3232                for i in 0..N {
3233                    let cpu = row_reweight_cpu(
3234                        family,
3235                        curvature,
3236                        RowInput {
3237                            eta: etas[i],
3238                            y: ys[i],
3239                            prior_weight: priors[i],
3240                        },
3241                        1.0,
3242                    );
3243                    for (label, got, exp) in [
3244                        ("mu", mu_j[i], cpu.mu),
3245                        ("grad_eta", ge_j[i], cpu.grad_eta),
3246                        ("w_fisher", wf_j[i], cpu.w_fisher),
3247                        ("w_hessian", wh_j[i], cpu.w_hessian),
3248                        ("w_solver", ws_j[i], cpu.w_solver),
3249                        ("z_fisher", zf_j[i], cpu.z_fisher),
3250                        ("z_hessian", zh_j[i], cpu.z_hessian),
3251                        ("deviance", dv_j[i], cpu.deviance),
3252                    ] {
3253                        if !got.is_finite() && !exp.is_finite() {
3254                            // Both NaN/inf is a parity match for the
3255                            // gamma y=0 → +inf deviance branch.
3256                            continue;
3257                        }
3258                        let abs_err = (got - exp).abs();
3259                        let rel_err = if exp.abs() > 0.0 {
3260                            abs_err / exp.abs()
3261                        } else {
3262                            abs_err
3263                        };
3264                        assert!(
3265                            abs_err <= TOL || rel_err <= TOL,
3266                            "{family:?} {label}[{i}] (eta={}, y={}, wp={}): \
3267                             JIT raw-body={} vs CPU={} (abs={}, rel={})",
3268                            etas[i],
3269                            ys[i],
3270                            priors[i],
3271                            got,
3272                            exp,
3273                            abs_err,
3274                            rel_err,
3275                        );
3276                    }
3277                }
3278            }
3279        }
3280    }
3281}