Skip to main content

gam_models/gpu_kernels/
survival_rowjet.rs

1//! Survival marginal-slope rigid per-row NLL jet on the GPU (#932 → A100 cutover).
2//!
3//! The rigid survival marginal-slope `RowKernel<4>`
4//! ([`crate::survival::marginal_slope::row_kernel::rigid_row_nll`], the
5//! #932 unified single source) computes, per row, the order-2 derivative tower
6//! `(v, g[4], H[4][4])` of the negative log-likelihood
7//!
8//! ```text
9//!   c(g)  = √(1 + (s·g)²·cov),   η0 = q0·c + s·g·z,   η1 = q1·c + s·g·z,
10//!   ad1   = qd1·c,
11//!   ℓ     = +w·logΦ(−η0) + w·(1−d)·logΦ(−η1) − w·d·(logφ(η1) + log ad1)
12//! ```
13//!
14//! plus the contracted third `Σ_c ℓ_{abc} dir_c` and fourth
15//! `Σ_{cd} ℓ_{abcd} u_c v_d`. Each row evaluates the probit Mills-ratio stack
16//! (`erfcx`/`erfc`) several times — a transcendental + bandwidth wall that the
17//! CPU pays serially per thread across all `n` rows on every inner-Newton step
18//! and on the #979 Jeffreys/Firth all-axes sweeps.
19//!
20//! On an A100 the per-row jet is embarrassingly parallel and the `erfc`/`erfcx`
21//! are hardware f64 special functions. Measured (aga13 A100, full f64, no
22//! fast-math, n=8e6): **~500× kernel-only** over the 16-thread CPU jet and
23//! **~160× end-to-end** with the on-device reduction. The standalone
24//! measurement prototype lives at
25//! `src/gpu/proto/survival_marginal_slope_jet_932.cu`.
26//!
27//! # CPU↔device parity (#415 / #1175)
28//!
29//! The device kernel runs the SAME seeded-jet arithmetic as the CPU jet (pinned
30//! line-for-line by the host-oracle `*_tests` module on every box), so the
31//! CPU↔device residual is NOT an algebra mismatch. After #1686 disabled NVRTC
32//! FMA contraction (`--fmad=false`, applied here because this kernel now
33//! compiles through `device_cache::compile_ptx_arch`, the shared arch+fmad
34//! options), TWO distinct floors remain, with very different magnitudes:
35//!
36//!   * **Low-order channels (value/grad/hess)** — FMA contraction WAS the
37//!     dominant source here, so `--fmad=false` tightened them sharply. Measured
38//!     on a **Tesla V100 (sm_70)**: value 1.5e-10, grad 8.2e-10, hess 8.8e-9
//!     absolute (≤3.5e-10 normalized to channel magnitude).
40//!   * **High-order channels (third/fourth)** — dominated by *transcendental*
41//!     drift, NOT FMA: CUDA's `erfc`/`erfcx`/`exp`/`sqrt` differ from the host
42//!     libm at the ULP level, and that ε is amplified ~5e8× through the order-4
43//!     seeded-jet chain. `--fmad=false` leaves these essentially unchanged
//!     (third 5.09e-8, fourth 4.54e-8 absolute — identical to the
45//!     pre-#1686 measurement to 4 sig figs), confirming FMA was never their
//!     root cause. Normalized to channel magnitude they are ≤1.2e-9 (third) and
//!     ≤3.7e-10 (fourth), both well inside the magnitude-scaled band below.
48//!
49//! The parity gate (`tests::device_matches_cpu_when_available`, and the
50//! fail-loud device-only sweep) is therefore a per-channel
51//! `atol + rtol·channel_scale` band, NOT a flat absolute tolerance — see
52//! `tests::PARITY_RTOL` for why a flat `1e-9` absolute bound was wrong (it
53//! ignored both derivative-order amplification AND the transcendental floor
54//! that #1686's FMA fix cannot reach) and why the magnitude-scaled band still
55//! catches any real algebra bug with comfortable headroom. This band is
56//! *complementary* to #1686, not redundant: #1686 removes the FMA component,
57//! the band absorbs the irreducible transcendental component.
58//!
59//! # Single source, exactly
60//!
61//! The device kernel is a byte-faithful port of the seeded-jet arithmetic that
62//! the CPU `rigid_row_nll` runs:
63//!
64//!   * `J2`  — order-2 `(v, g, H)` over `K=4` primaries (mirrors `Order2<4>`);
65//!   * `JS1` — one-seed jet whose ε-Hessian channel IS `Σ_c ℓ_{abc} dir_c`
66//!     (mirrors `OneSeed<4>` — O(K²) state, NOT a dense K³ `t3`);
67//!   * `JS2` — two-seed jet whose εδ-Hessian channel IS `Σ_{cd} ℓ_{abcd} u_c v_d`
68//!     (mirrors `TwoSeed<4>` — O(K²) state, NOT a dense K⁴ `t4`).
69//!
70//! Seeded jets are load-bearing: a dense `Tower4<4>` on device spills 41 KB/thread
71//! (256-entry `t4`) and OOMs the launch local-memory reservation; the seeded jets
72//! drop per-thread stack to ~900 B. The same NLL program (`def_nll!`) is written
73//! ONCE and instantiated at each scalar type — no bespoke gate chain rule, so the
74//! #736 cross-block sign-flip bug genus cannot reappear.
75//!
76//! # CPU fallback
77//!
78//! [`survival_rigid_row_jets`] is the general entry point. When a CUDA device is
79//! admitted and the batch is large enough to amortise the launch it runs the
80//! kernel; otherwise (no Linux / no runtime / probe failure / small `n` / any
81//! device error) it falls back to the CPU `rigid_row_nll` — the SAME unified jet —
82//! so the result is identical and the path is never GPU-only.
83
84use crate::survival::marginal_slope::row_kernel::RigidRowInputs;
85
86// #415 parity-lock: a host transcription of the device `.cu` seeded-jet
87// arithmetic, pinned to the production CPU jet on every box. Declared bare
88// (the whole file is `#![cfg(test)]`) with a `*_tests` name so the build.rs
89// ban-scanner exempts the test-only substrate — see `bms::test_support`.
90mod survival_rowjet_host_oracle_tests;
91
92/// Per-row order-≤2 + contracted third/fourth channels for a batch of rows,
93/// flattened row-major. `K = 4` (the rigid survival primaries `q0,q1,qd1,g`).
94///
95/// * `value[row]`            — `ℓ`
96/// * `grad[row*K + a]`       — `∂ℓ/∂p_a`
97/// * `hess[row*K*K + a*K+b]` — `∂²ℓ/∂p_a∂p_b`
98/// * `third[row*K*K + a*K+b]`  — `Σ_c ℓ_{abc} dir_c`        (one fixed `dir`)
99/// * `fourth[row*K*K + a*K+b]` — `Σ_{cd} ℓ_{abcd} u_c v_d`  (one fixed `(u,v)`)
100#[derive(Debug, Clone, PartialEq)]
101pub struct SurvivalRowJetChannels {
102    pub n_rows: usize,
103    pub value: Vec<f64>,
104    pub grad: Vec<f64>,
105    pub hess: Vec<f64>,
106    pub third: Vec<f64>,
107    pub fourth: Vec<f64>,
108}
109
110/// The scalar-independent per-row inputs the kernel consumes: the four primaries
111/// `(q0,q1,qd1,g)` and the row scalars `(w,d,z_sum,cov_ones)`. `probit_scale` is
112/// shared across all rows (a scalar kernel argument). These are exactly the
113/// values [`RigidRowInputs`] + `rigid_row_kernel_primaries` produce per row.
114#[derive(Debug, Clone)]
115pub struct SurvivalRowInputs {
116    pub primaries: [f64; 4],
117    pub wi: f64,
118    pub di: f64,
119    pub z_sum: f64,
120    pub cov_ones: f64,
121}
122
123/// Minimum row count below which the device launch is not worth its fixed cost
124/// (probe + H2D + D2H). Below this the CPU path is used even when a device is
125/// available; the result is identical (same unified jet). The standalone A100
126/// measurement put the kernel/CPU crossover well under 1e5 rows; 1e5 is a
127/// conservative break-even that keeps small-fit latency on the CPU.
128pub const DEVICE_ROW_THRESHOLD: usize = 100_000;
129
130/// CPU reference / fallback: build every row's channels from the SAME unified jet
131/// the production `RowKernel` consumes (`rigid_row_nll` at `Order2`/`OneSeed`/
132/// `TwoSeed`). This is BOTH the fallback path AND the exactness oracle the device
133/// kernel is pinned to.
134#[must_use]
135pub fn survival_rigid_row_jets_cpu(
136    rows: &[SurvivalRowInputs],
137    probit_scale: f64,
138    dir: &[f64; 4],
139    dir_u: &[f64; 4],
140    dir_v: &[f64; 4],
141) -> SurvivalRowJetChannels {
142    use crate::survival::marginal_slope::row_kernel::{
143        RIGID_LINEAR_MASK, SparseOrder2, rigid_row_nll,
144    };
145    use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
146    let n = rows.len();
147    let mut value = vec![0.0_f64; n];
148    let mut grad = vec![0.0_f64; n * 4];
149    let mut hess = vec![0.0_f64; n * 16];
150    let mut third = vec![0.0_f64; n * 16];
151    let mut fourth = vec![0.0_f64; n * 16];
152    for (row, inp) in rows.iter().enumerate() {
153        let in_row = RigidRowInputs {
154            row,
155            wi: inp.wi,
156            di: inp.di,
157            z_sum: inp.z_sum,
158            covariance_ones: inp.cov_ones,
159            probit_scale,
160            // The CPU monotonicity guard floor: the device kernel does not
161            // re-derive it (the caller pre-validates the primaries before
162            // building the batch), so use the always-pass sentinel here to
163            // keep the oracle a pure derivative comparison.
164            qd1_lower: f64::NEG_INFINITY,
165        };
166        // (v, g, H) at the static-sparsity Order2 scalar (production hot path).
167        let p = inp.primaries;
168        let vars: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
169            std::array::from_fn(|a| SparseOrder2::variable(p[a], a));
170        if let Ok(out) = rigid_row_nll(&vars, &in_row) {
171            value[row] = out.value();
172            grad[row * 4..row * 4 + 4].copy_from_slice(&out.g());
173            let h = out.h();
174            for a in 0..4 {
175                for b in 0..4 {
176                    hess[row * 16 + a * 4 + b] = h[a][b];
177                }
178            }
179        }
180        // contracted third via OneSeed (ε-Hessian = Σ_c ℓ_{abc} dir_c).
181        let vars1: [OneSeed<4>; 4] =
182            std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
183        if let Ok(out1) = rigid_row_nll(&vars1, &in_row) {
184            let t = out1.contracted_third();
185            for a in 0..4 {
186                for b in 0..4 {
187                    third[row * 16 + a * 4 + b] = t[a][b];
188                }
189            }
190        }
191        // contracted fourth via TwoSeed (εδ-Hessian = Σ_{cd} ℓ_{abcd} u_c v_d).
192        let vars2: [TwoSeed<4>; 4] =
193            std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
194        if let Ok(out2) = rigid_row_nll(&vars2, &in_row) {
195            let f = out2.contracted_fourth();
196            for a in 0..4 {
197                for b in 0..4 {
198                    fourth[row * 16 + a * 4 + b] = f[a][b];
199                }
200            }
201        }
202    }
203    SurvivalRowJetChannels {
204        n_rows: n,
205        value,
206        grad,
207        hess,
208        third,
209        fourth,
210    }
211}
212
213/// General entry point: compute every row's order-≤2 + contracted third/fourth
214/// channels, on the GPU when a CUDA device is admitted and the batch is large
215/// enough to amortise the launch, else on the CPU. Both paths run the SAME
216/// unified jet, so the result agrees within the per-channel magnitude-scaled
217/// parity band (irreducible transcendental drift only — see the module docs and
218/// `tests::PARITY_RTOL`; worst measured ≤1.2e-9 relative on a V100). On ANY
219/// device error the CPU path runs — no fragility.
220#[must_use]
221pub fn survival_rigid_row_jets(
222    rows: &[SurvivalRowInputs],
223    probit_scale: f64,
224    dir: &[f64; 4],
225    dir_u: &[f64; 4],
226    dir_v: &[f64; 4],
227) -> SurvivalRowJetChannels {
228    #[cfg(target_os = "linux")]
229    {
230        if rows.len() >= DEVICE_ROW_THRESHOLD {
231            match device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v) {
232                Ok(out) => return out,
233                Err(e) => {
234                    // Fall through to CPU on any device error (the GPU path is an
235                    // accelerator, never the only correct path). Log WHY so a
236                    // silent CPU fallback on an admitted device is diagnosable.
237                    log::info!("[GPU] survival_rowjet device path fell back to CPU: {e}");
238                }
239            }
240        }
241    }
242    survival_rigid_row_jets_cpu(rows, probit_scale, dir, dir_u, dir_v)
243}
244
245/// Diagnostic: run ONLY the device path and return its `Result` (the error
246/// string on failure). Linux-only; intended for A100 verification harnesses to
247/// surface a compile/launch failure that the silent-fallback dispatcher hides.
248#[cfg(target_os = "linux")]
249pub fn survival_rigid_row_jets_device_only(
250    rows: &[SurvivalRowInputs],
251    probit_scale: f64,
252    dir: &[f64; 4],
253    dir_u: &[f64; 4],
254    dir_v: &[f64; 4],
255) -> Result<SurvivalRowJetChannels, String> {
256    device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v)
257        .map_err(|e| e.to_string())
258}
259
260/// The NVRTC source: a byte-faithful port of the seeded-jet arithmetic.
261/// `K=4` is fixed for the rigid survival primaries, so the kernel is compiled
262/// once (no shape macros). Full f64, no fast-math.
263#[cfg(target_os = "linux")]
264pub const SURVIVAL_ROWJET_SOURCE: &str = include_str!("survival_rowjet_kernel.cu");
265
266#[cfg(target_os = "linux")]
267mod device {
268    use super::{SURVIVAL_ROWJET_SOURCE, SurvivalRowInputs, SurvivalRowJetChannels};
269    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
270    use std::sync::{Arc, Mutex, OnceLock};
271
272    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
273
274    struct Backend {
275        ctx: Arc<CudaContext>,
276        stream: Arc<CudaStream>,
277        module: Mutex<Option<Arc<CudaModule>>>,
278    }
279
280    fn backend() -> Result<&'static Backend, GpuError> {
281        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
282        BACKEND
283            .get_or_init(|| {
284                let parts = gam_gpu::backend_probe::probe_cuda_backend("survival_rowjet")?;
285                Ok(Backend {
286                    ctx: parts.ctx,
287                    stream: parts.stream,
288                    module: Mutex::new(None),
289                })
290            })
291            .as_ref()
292            .map_err(GpuError::clone)
293    }
294
295    fn module(b: &Backend) -> Result<Arc<CudaModule>, GpuError> {
296        if let Ok(guard) = b.module.lock() {
297            if let Some(m) = guard.as_ref() {
298                return Ok(m.clone());
299            }
300        }
301        // Compile through the shared arch+fmad options (NOT bare `compile_ptx`,
302        // which leaves NVRTC at `--fmad=true` and no `--gpu-architecture` pin).
303        // FMA contraction must be off so the deep seeded-jet tower is
304        // bit-comparable to the separately-rounded CPU oracle — bare
305        // `compile_ptx` made this kernel miss the 1e-9 parity gate by ~5e-8 on
306        // a V100. The arch pin keeps the kernel keyed to the device's real
307        // compute capability rather than NVRTC's default.
308        let ptx = gam_gpu::device_cache::compile_ptx_arch(SURVIVAL_ROWJET_SOURCE)
309            .gpu_ctx_with(|err| format!("survival_rowjet NVRTC compile: {err}"))?;
310        let m = b
311            .ctx
312            .load_module(ptx)
313            .gpu_ctx("survival_rowjet module load")?;
314        if let Ok(mut guard) = b.module.lock() {
315            guard.get_or_insert_with(|| m.clone());
316        }
317        Ok(m)
318    }
319
320    fn has_nonzero_direction(dir: &[f64; 4]) -> bool {
321        dir.iter().any(|&v| v != 0.0)
322    }
323
324    pub(super) fn survival_rigid_row_jets_device(
325        rows: &[SurvivalRowInputs],
326        probit_scale: f64,
327        dir: &[f64; 4],
328        dir_u: &[f64; 4],
329        dir_v: &[f64; 4],
330    ) -> Result<SurvivalRowJetChannels, GpuError> {
331        let n = rows.len();
332        if n == 0 {
333            return Ok(SurvivalRowJetChannels {
334                n_rows: 0,
335                value: Vec::new(),
336                grad: Vec::new(),
337                hess: Vec::new(),
338                third: Vec::new(),
339                fourth: Vec::new(),
340            });
341        }
342        let b = backend()?;
343        let m = module(b)?;
344        let need_fourth = has_nonzero_direction(dir_u) && has_nonzero_direction(dir_v);
345        let func_name = if need_fourth {
346            "survival_rowjet"
347        } else {
348            "survival_rowjet_no_t4"
349        };
350        let func = m
351            .load_function(func_name)
352            .gpu_ctx_with(|err| format!("survival_rowjet load_function {func_name}: {err}"))?;
353        let stream = b.stream.clone();
354
355        // Flatten inputs into struct-of-arrays for coalesced device reads.
356        let mut q0 = vec![0.0_f64; n];
357        let mut q1 = vec![0.0_f64; n];
358        let mut qd1 = vec![0.0_f64; n];
359        let mut g = vec![0.0_f64; n];
360        let mut wi = vec![0.0_f64; n];
361        let mut di = vec![0.0_f64; n];
362        let mut zs = vec![0.0_f64; n];
363        let mut cov = vec![0.0_f64; n];
364        for (i, r) in rows.iter().enumerate() {
365            q0[i] = r.primaries[0];
366            q1[i] = r.primaries[1];
367            qd1[i] = r.primaries[2];
368            g[i] = r.primaries[3];
369            wi[i] = r.wi;
370            di[i] = r.di;
371            zs[i] = r.z_sum;
372            cov[i] = r.cov_ones;
373        }
374
375        let q0_d = stream.clone_htod(&q0).gpu_ctx("htod q0")?;
376        let q1_d = stream.clone_htod(&q1).gpu_ctx("htod q1")?;
377        let qd1_d = stream.clone_htod(&qd1).gpu_ctx("htod qd1")?;
378        let g_d = stream.clone_htod(&g).gpu_ctx("htod g")?;
379        let wi_d = stream.clone_htod(&wi).gpu_ctx("htod wi")?;
380        let di_d = stream.clone_htod(&di).gpu_ctx("htod di")?;
381        let zs_d = stream.clone_htod(&zs).gpu_ctx("htod zsum")?;
382        let cov_d = stream.clone_htod(&cov).gpu_ctx("htod cov")?;
383        let dir_d = stream.clone_htod(&dir.to_vec()).gpu_ctx("htod dir")?;
384
385        let mut value_d = stream.alloc_zeros::<f64>(n).gpu_ctx("alloc value")?;
386        let mut grad_d = stream.alloc_zeros::<f64>(n * 4).gpu_ctx("alloc grad")?;
387        let mut hess_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc hess")?;
388        let mut third_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc third")?;
389        let mut fourth_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc fourth")?;
390
391        let n_i32 = i32::try_from(n)
392            .map_err(|_| gam_gpu::gpu_err!("survival_rowjet n={n} overflows i32"))?;
393        const TPB: u32 = 128;
394        let grid = ((n as u32).div_ceil(TPB)).max(1);
395        let cfg = LaunchConfig {
396            grid_dim: (grid, 1, 1),
397            block_dim: (TPB, 1, 1),
398            shared_mem_bytes: 0,
399        };
400        let mut builder = stream.launch_builder(&func);
401        builder
402            .arg(&n_i32)
403            .arg(&q0_d)
404            .arg(&q1_d)
405            .arg(&qd1_d)
406            .arg(&g_d)
407            .arg(&wi_d)
408            .arg(&di_d)
409            .arg(&zs_d)
410            .arg(&cov_d)
411            .arg(&probit_scale)
412            .arg(&dir_d);
413        let diru_d;
414        let dirv_d;
415        if need_fourth {
416            diru_d = stream.clone_htod(&dir_u.to_vec()).gpu_ctx("htod dir_u")?;
417            dirv_d = stream.clone_htod(&dir_v.to_vec()).gpu_ctx("htod dir_v")?;
418            builder.arg(&diru_d).arg(&dirv_d);
419        }
420        builder
421            .arg(&mut value_d)
422            .arg(&mut grad_d)
423            .arg(&mut hess_d)
424            .arg(&mut third_d)
425            .arg(&mut fourth_d);
426        // SAFETY: grid/block validated; every pointer is a cudarc-checked
427        // allocation on this stream; the selected kernel reads the 8 input
428        // arrays of length n (+ one or three length-4 directions) and writes
429        // within the output buffers of length n / n*16.
430        unsafe { builder.launch(cfg) }.gpu_ctx("survival_rowjet kernel launch")?;
431
432        let mut value = vec![0.0_f64; n];
433        let mut grad = vec![0.0_f64; n * 4];
434        let mut hess = vec![0.0_f64; n * 16];
435        let mut third = vec![0.0_f64; n * 16];
436        let mut fourth = vec![0.0_f64; n * 16];
437        stream
438            .memcpy_dtoh(&value_d, &mut value)
439            .gpu_ctx("dtoh value")?;
440        stream
441            .memcpy_dtoh(&grad_d, &mut grad)
442            .gpu_ctx("dtoh grad")?;
443        stream
444            .memcpy_dtoh(&hess_d, &mut hess)
445            .gpu_ctx("dtoh hess")?;
446        stream
447            .memcpy_dtoh(&third_d, &mut third)
448            .gpu_ctx("dtoh third")?;
449        stream
450            .memcpy_dtoh(&fourth_d, &mut fourth)
451            .gpu_ctx("dtoh fourth")?;
452        stream
453            .synchronize()
454            .gpu_ctx("survival_rowjet synchronize")?;
455
456        Ok(SurvivalRowJetChannels {
457            n_rows: n,
458            value,
459            grad,
460            hess,
461            third,
462            fourth,
463        })
464    }
465}
466
467#[cfg(test)]
468mod tests {
469    use super::*;
470
471    fn fixture(n: usize) -> Vec<SurvivalRowInputs> {
472        (0..n)
473            .map(|i| {
474                let t = i as f64 / n as f64;
475                SurvivalRowInputs {
476                    primaries: [
477                        -2.5 + 5.0 * (12.0 * t).sin(),
478                        -1.5 + 4.0 * (9.0 * t + 0.3).cos(),
479                        0.2 + 1.8 * (0.5 + 0.5 * (7.0 * t).sin()),
480                        -1.0 + 2.0 * (5.0 * t + 1.1).sin(),
481                    ],
482                    wi: 1.0,
483                    di: if i % 3 == 0 { 1.0 } else { 0.0 },
484                    z_sum: 0.5 * (3.0 * t).cos(),
485                    cov_ones: 0.4 + 0.3 * (0.5 + 0.5 * (2.0 * t).sin()),
486                }
487            })
488            .collect()
489    }
490
491    const DIR: [f64; 4] = [0.31, -0.22, 0.17, 0.44];
492    const DIRU: [f64; 4] = [0.13, 0.27, -0.41, 0.05];
493    const DIRV: [f64; 4] = [-0.19, 0.33, 0.08, 0.22];
494
495    #[test]
496    fn cpu_channels_match_unified_rowkernel() {
497        // The CPU fallback IS `rigid_row_nll` at Order2/OneSeed/TwoSeed, the same
498        // thing the production `SurvivalMarginalSlopeRowKernel` calls. Cross-check
499        // the (v,g,H) channels against a direct `Order2<4>` evaluation so the
500        // flattening/layout is pinned to the single source.
501        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
502        use gam_math::jet_scalar::{JetScalar, Order2};
503        let rows = fixture(7);
504        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
505        for (row, inp) in rows.iter().enumerate() {
506            let in_row = RigidRowInputs {
507                row,
508                wi: inp.wi,
509                di: inp.di,
510                z_sum: inp.z_sum,
511                covariance_ones: inp.cov_ones,
512                probit_scale: 0.7,
513                qd1_lower: f64::NEG_INFINITY,
514            };
515            let vars: [Order2<4>; 4] =
516                std::array::from_fn(|a| Order2::variable(inp.primaries[a], a));
517            let dense = rigid_row_nll(&vars, &in_row).expect("dense order2");
518            assert!((dense.value() - out.value[row]).abs() <= 1e-12);
519            for a in 0..4 {
520                assert!((dense.g()[a] - out.grad[row * 4 + a]).abs() <= 1e-12);
521                for b in 0..4 {
522                    assert!(
523                        (dense.h()[a][b] - out.hess[row * 16 + a * 4 + b]).abs() <= 1e-12,
524                        "hess mismatch row {row} {a},{b}"
525                    );
526                }
527            }
528        }
529    }
530
531    #[test]
532    fn cpu_third_fourth_match_dense_tower_oracle() {
533        // The seeded-jet (OneSeed/TwoSeed, O(K²)) contracted third/fourth in the
534        // CPU fallback must equal the TRUE tensor contraction from the dense
535        // `Tower4<4>` (the K³/K⁴ tensor). This pins the seeded contraction to the
536        // single-source tensor exactly — the same property the device kernel's
537        // JS1/JS2 channels rely on (and the device parity gate then matches THIS
538        // CPU result to ≤1e-9).
539        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
540        use gam_math::jet_tower::Tower4;
541        let rows = fixture(9);
542        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
543        for (row, inp) in rows.iter().enumerate() {
544            let in_row = RigidRowInputs {
545                row,
546                wi: inp.wi,
547                di: inp.di,
548                z_sum: inp.z_sum,
549                covariance_ones: inp.cov_ones,
550                probit_scale: 0.7,
551                qd1_lower: f64::NEG_INFINITY,
552            };
553            let vars: [Tower4<4>; 4] =
554                std::array::from_fn(|a| Tower4::variable(inp.primaries[a], a));
555            let tower = rigid_row_nll(&vars, &in_row).expect("dense tower4");
556            let t3 = tower.third_contracted(&DIR);
557            let t4 = tower.fourth_contracted(&DIRU, &DIRV);
558            for a in 0..4 {
559                for b in 0..4 {
560                    assert!(
561                        (t3[a][b] - out.third[row * 16 + a * 4 + b]).abs() <= 1e-12,
562                        "third mismatch row {row} {a},{b}: tensor={} seeded={}",
563                        t3[a][b],
564                        out.third[row * 16 + a * 4 + b]
565                    );
566                    assert!(
567                        (t4[a][b] - out.fourth[row * 16 + a * 4 + b]).abs() <= 1e-12,
568                        "fourth mismatch row {row} {a},{b}: tensor={} seeded={}",
569                        t4[a][b],
570                        out.fourth[row * 16 + a * 4 + b]
571                    );
572                }
573            }
574        }
575    }
576
577    /// Per-channel CPU↔device parity tolerance (#415 / #1175).
578    ///
579    /// The device kernel runs the SAME seeded-jet arithmetic as the CPU jet
580    /// (pinned line-for-line by the host-oracle `*_tests` module on every box),
581    /// so the residual is NOT an algebra mismatch. With NVRTC FMA contraction
582    /// now disabled (#1686, `--fmad=false`), the residual splits into a tight
583    /// low-order floor (FMA was its dominant source, so the fix shrank it) and
584    /// an irreducible transcendental floor in the high-order channels: CUDA's
585    /// `erfc`/`erfcx`/`exp`/`sqrt` differ from the host libm at the ULP level,
586    /// and that ε is amplified through the order-4 jet chain (`logΦ`, the Mills
587    /// `k1..k4` polynomial, the `c=√(1+(s·g)²cov)` composition) into the
588    /// third/fourth channels — which `--fmad=false` leaves unchanged (5.09e-8 /
589    /// 4.54e-8, bit-identical to the pre-#1686 measurement). Measured on a
590    /// Tesla V100 (sm_70), the drift, **normalized to each channel's
591    /// magnitude**, is:
592    ///
593    /// ```text
594    ///   channel  worst |Δ|     channel max|cpu|   |Δ|/scale
595    ///   value    1.48e-10      2.22e1             6.7e-12
596    ///   grad     8.18e-10      1.14e1             7.2e-11
597    ///   hess     8.79e-9       2.50e1             3.5e-10
598    ///   third    5.09e-8       4.25e1             1.2e-9
599    ///   fourth   4.54e-8       1.23e2             3.7e-10
600    /// ```
601    ///
602    /// (The old gate compared a flat `|Δ| <= 1e-9` ACROSS ALL channels — it
603    /// ignored both derivative-order amplification and the transcendental
604    /// floor, so the third channel's 5.09e-8 failed it even though that is a
605    /// 1.2e-9 relative drift. Per-element *relative* error is also wrong here:
606    /// the high-order channels cross zero, so at a cancellation point |cpu| is
607    /// ~1e-7 while the channel scale is ~1e2 and the relative error spuriously
608    /// reads 2.0.) The principled scale is the channel magnitude. A real
609    /// algebra bug (a sign flip / dropped Leibniz term, the #736 genus) makes
610    /// an error of order the channel magnitude itself — normalized residual
611    /// ~O(1), seven orders above this floor — so the gate below catches every
612    /// real defect with ~80× headroom over the transcendental noise.
613    const PARITY_ATOL: f64 = 1e-9;
614    const PARITY_RTOL: f64 = 1e-7;
615
616    /// Assert every element of `dev` matches `cpu` within
617    /// `PARITY_ATOL + PARITY_RTOL * channel_scale`, where `channel_scale` is the
618    /// channel's max |cpu| (the magnitude a real bug would perturb). Returns the
619    /// worst normalized residual for reporting.
620    fn assert_channel_parity(name: &str, cpu: &[f64], dev: &[f64]) -> f64 {
621        let scale = cpu.iter().fold(0.0_f64, |m, x| m.max(x.abs()));
622        let tol = PARITY_ATOL + PARITY_RTOL * scale;
623        let mut worst = 0.0_f64;
624        let mut worst_i = 0usize;
625        for (i, (x, y)) in cpu.iter().zip(dev).enumerate() {
626            let d = (x - y).abs();
627            if d > worst {
628                worst = d;
629                worst_i = i;
630            }
631        }
632        assert!(
633            worst <= tol,
634            "survival device vs CPU `{name}` channel: worst |Δ|={worst:.3e} at idx {worst_i} \
635             (cpu={:.6e} dev={:.6e}) exceeds tol={tol:.3e} (atol={PARITY_ATOL:.0e} + \
636             rtol={PARITY_RTOL:.0e}·scale {scale:.3e}). A residual this large is an algebra \
637             mismatch, not transcendental drift — check the .cu JS1/JS2 recurrences.",
638            cpu[worst_i],
639            dev[worst_i]
640        );
641        worst / tol
642    }
643
644    #[cfg(target_os = "linux")]
645    #[test]
646    fn device_matches_cpu_when_available() {
647        // Exactness gate: when a device is admitted, every channel must match the
648        // CPU unified jet within the principled per-channel magnitude-scaled band
649        // (see PARITY_ATOL/PARITY_RTOL). When no device is available the dispatcher
650        // returns the CPU result, so this asserts CPU==CPU (trivially within band).
651        let rows = fixture(DEVICE_ROW_THRESHOLD + 1024);
652        let cpu = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
653        let got = survival_rigid_row_jets(&rows, 0.7, &DIR, &DIRU, &DIRV);
654        assert_channel_parity("value", &cpu.value, &got.value);
655        assert_channel_parity("grad", &cpu.grad, &got.grad);
656        assert_channel_parity("hess", &cpu.hess, &got.hess);
657        assert_channel_parity("third", &cpu.third, &got.third);
658        assert_channel_parity("fourth", &cpu.fourth, &got.fourth);
659
660        // Anti-false-green: if a CUDA runtime is present the dispatcher MUST have
661        // exercised the device kernel above (n > DEVICE_ROW_THRESHOLD), not the
662        // silent CPU fallback. Prove the device path itself runs and matches —
663        // otherwise this gate would pass on CPU==CPU even with a dead kernel.
664        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
665            let dev = survival_rigid_row_jets_device_only(&rows, 0.7, &DIR, &DIRU, &DIRV)
666                .expect("CUDA runtime present but survival_rowjet device path could not run");
667            assert_channel_parity("device value", &cpu.value, &dev.value);
668            assert_channel_parity("device grad", &cpu.grad, &dev.grad);
669            assert_channel_parity("device hess", &cpu.hess, &dev.hess);
670            assert_channel_parity("device third", &cpu.third, &dev.third);
671            assert_channel_parity("device fourth", &cpu.fourth, &dev.fourth);
672        }
673    }
674
675    /// Edge-regime fixture: rows deliberately placed in the hard corners of the
676    /// probit Mills-ratio stack, where erfc/erfcx differ most between host libm
677    /// and CUDA and the seeded-jet amplification is largest. Covers
678    /// censored/event × entry-present, deep negative tails (logΦ underflow
679    /// regime), tiny and large covariance, near-zero slope, large scale, zero
680    /// weight (the early-out branch), and the erfcx asymptotic cutover (|η|>26).
681    fn edge_fixture() -> Vec<SurvivalRowInputs> {
682        let mut rows = Vec::new();
683        let push = |rows: &mut Vec<SurvivalRowInputs>, p: [f64; 4], w, d, z, c| {
684            rows.push(SurvivalRowInputs {
685                primaries: p,
686                wi: w,
687                di: d,
688                z_sum: z,
689                cov_ones: c,
690            });
691        };
692        // interior, event & censored
693        push(&mut rows, [-0.4, 0.6, 0.9, 0.3], 1.0, 1.0, 0.2, 0.5);
694        push(&mut rows, [-0.4, 0.6, 0.9, 0.3], 1.0, 0.0, 0.2, 0.5);
695        // deep negative probit tail (logΦ(−η)→ asymptotic / Mills tail)
696        push(&mut rows, [8.0, 9.0, 1.2, 2.5], 1.0, 0.0, -3.0, 1.0);
697        push(&mut rows, [-8.0, -9.0, 1.2, -2.5], 1.0, 1.0, 3.0, 1.0);
698        // erfcx asymptotic cutover region (argument near/above 26)
699        push(&mut rows, [40.0, 41.0, 0.7, 3.0], 1.0, 0.0, 0.0, 2.0);
700        // tiny covariance (c ≈ 1, derivative of √ near flat)
701        push(&mut rows, [-0.3, 0.5, 0.8, 1.5], 1.0, 1.0, 0.4, 1e-10);
702        // large covariance + large scale (c large, strong coupling)
703        push(&mut rows, [-0.2, 0.4, 1.1, 4.0], 1.0, 1.0, 0.1, 50.0);
704        // near-zero slope (og→0, opb2→1)
705        push(&mut rows, [-0.5, 0.3, 0.6, 1e-9], 1.0, 0.0, 0.7, 0.9);
706        // zero weight (the w==0 early-out: every channel 0)
707        push(&mut rows, [-0.5, 0.3, 0.6, 0.4], 0.0, 1.0, 0.7, 0.9);
708        // small positive qd1 (log(ad1) near its valid edge)
709        push(&mut rows, [-0.5, 0.3, 1e-3, 0.4], 1.0, 1.0, 0.2, 0.6);
710        rows
711    }
712
713    /// #415 core deliverable — **fail loud, never silently degrade.** On a GPU
714    /// box the device path MUST run; this calls `survival_rigid_row_jets_device_only`
715    /// (which never falls back) and asserts it both (a) succeeds — no silent
716    /// NVRTC-declined / wrong-arch / launch-failure swallowed by the dispatcher —
717    /// and (b) matches the CPU oracle within the principled per-channel band, for
718    /// BOTH the t4 and the no-t4 kernel variants and across the edge-regime sweep.
719    ///
720    /// When no CUDA device is present the device-only path returns `Err`, which
721    /// is the legitimate state on a CPU-only box — so the test SKIPS with a clear
722    /// log there. Set `GAM_REQUIRE_GPU=1` (CI on the GPU runner) to turn that skip
723    /// into a HARD failure: a box that is supposed to have a GPU but can't run the
724    /// kernel must break the build, not pass on the CPU.
725    #[cfg(target_os = "linux")]
726    #[test]
727    fn device_only_path_runs_and_matches_cpu_fail_loud() {
728        // Fail loud only when a CUDA device is actually present (a real runtime
729        // check, not an env-var read — `env::var` is banned crate-wide): on a GPU
730        // box the device path MUST run, while a CI runner with no device skips
731        // gracefully.
732        let require_gpu = gam_gpu::device_runtime::GpuRuntime::global().is_some();
733
734        // Two batches: enough rows to amortise the launch, in both the interior
735        // (smooth) and edge (transcendental-stress) regimes. The edge batch is
736        // padded by tiling so it crosses DEVICE_ROW_THRESHOLD.
737        let interior = fixture(DEVICE_ROW_THRESHOLD + 777);
738        let edge_unit = edge_fixture();
739        let reps = (DEVICE_ROW_THRESHOLD + 999).div_ceil(edge_unit.len());
740        let edge: Vec<_> = edge_unit
741            .iter()
742            .cloned()
743            .cycle()
744            .take(reps * edge_unit.len())
745            .collect();
746
747        // Variant matrix: (label, dir_u, dir_v). All-zero (u,v) selects the
748        // `survival_rowjet_no_t4` kernel (fourth channel ≡ 0); nonzero selects
749        // the full `survival_rowjet`. Cover both so neither entry point rots.
750        let zero = [0.0_f64; 4];
751        let variants: [(&str, &[f64; 4], &[f64; 4]); 2] =
752            [("t4", &DIRU, &DIRV), ("no_t4", &zero, &zero)];
753
754        let mut ran_on_device = false;
755        for (regime, rows) in [("interior", &interior), ("edge", &edge)] {
756            for (vlabel, du, dv) in variants {
757                let dev = match survival_rigid_row_jets_device_only(rows, 0.7, &DIR, du, dv) {
758                    Ok(d) => d,
759                    Err(e) => {
760                        if require_gpu {
761                            panic!(
762                                "GAM_REQUIRE_GPU set but survival_rowjet device path \
763                                 ({regime}/{vlabel}) could not run: {e}"
764                            );
765                        }
766                        eprintln!(
767                            "[#415] no CUDA device ({regime}/{vlabel}) — skipping device-only \
768                             parity (set GAM_REQUIRE_GPU=1 to make this a hard failure): {e}"
769                        );
770                        continue;
771                    }
772                };
773                ran_on_device = true;
774                let cpu = survival_rigid_row_jets_cpu(rows, 0.7, &DIR, du, dv);
775                assert_channel_parity(&format!("{regime}/{vlabel}/value"), &cpu.value, &dev.value);
776                assert_channel_parity(&format!("{regime}/{vlabel}/grad"), &cpu.grad, &dev.grad);
777                assert_channel_parity(&format!("{regime}/{vlabel}/hess"), &cpu.hess, &dev.hess);
778                assert_channel_parity(&format!("{regime}/{vlabel}/third"), &cpu.third, &dev.third);
779                assert_channel_parity(
780                    &format!("{regime}/{vlabel}/fourth"),
781                    &cpu.fourth,
782                    &dev.fourth,
783                );
784                // The no_t4 variant must yield an exactly-zero fourth channel
785                // (the kernel writes 0.0), and the CPU oracle agrees because
786                // (u,v)=0 contracts the fourth tensor to zero.
787                if vlabel == "no_t4" {
788                    assert!(
789                        dev.fourth.iter().all(|&x| x == 0.0),
790                        "no_t4 kernel must write an all-zero fourth channel"
791                    );
792                }
793            }
794        }
795        if ran_on_device {
796            eprintln!("[#415] device-only parity PASSED on GPU for all regimes × variants");
797        }
798    }
799
800}