Skip to main content

gam_models/gpu_kernels/
survival_rowjet.rs

//! Survival marginal-slope rigid per-row NLL jet on the GPU (#932 → A100 cutover).
//!
//! The rigid survival marginal-slope `RowKernel<4>`
//! ([`crate::survival::marginal_slope::row_kernel::rigid_row_nll`], the
//! #932 unified single source) computes, per row, the order-2 derivative tower
//! `(v, g[4], H[4][4])` of the negative log-likelihood
//!
//! ```text
//!   c(g)  = √(1 + (s·g)²·cov),   η0 = q0·c + s·g·z,   η1 = q1·c + s·g·z,
//!   ad1   = qd1·c,
//!   ℓ     = +w·logΦ(−η0) + w·(1−d)·logΦ(−η1) − w·d·(logφ(η1) + log ad1)
//! ```
//!
//! plus the contracted third `Σ_c ℓ_{abc} dir_c` and fourth
//! `Σ_{cd} ℓ_{abcd} u_c v_d`. Each row evaluates the probit Mills-ratio stack
//! (`erfcx`/`erfc`) several times — a transcendental + bandwidth wall that the
//! CPU pays serially per thread across all `n` rows on every inner-Newton step
//! and on the #979 Jeffreys/Firth all-axes sweeps.
//!
//! On an A100 the per-row jet is embarrassingly parallel, and `erfc`/`erfcx`
//! are available as f64 device functions in the CUDA math library. Measured
//! (aga13 A100, full f64, no fast-math, n=8e6): **~500× kernel-only** over the
//! 16-thread CPU jet and **~160× end-to-end** with the on-device reduction;
//! **device == CPU to 4.7e-12** over every channel (`v`, `g[4]`, `H[16]`,
//! contracted third `[16]`, contracted fourth `[16]`). The standalone
//! measurement prototype lives at `src/gpu/proto/survival_marginal_slope_jet_932.cu`.
//!
//! # Single source, exactly
//!
//! The device kernel is a byte-faithful port of the seeded-jet arithmetic that
//! the CPU `rigid_row_nll` runs:
//!
//!   * `J2`  — order-2 `(v, g, H)` over `K=4` primaries (mirrors `Order2<4>`);
//!   * `JS1` — one-seed jet whose ε-Hessian channel IS `Σ_c ℓ_{abc} dir_c`
//!     (mirrors `OneSeed<4>` — O(K²) state, NOT a dense K³ `t3`);
//!   * `JS2` — two-seed jet whose εδ-Hessian channel IS `Σ_{cd} ℓ_{abcd} u_c v_d`
//!     (mirrors `TwoSeed<4>` — O(K²) state, NOT a dense K⁴ `t4`).
//!
//! Seeded jets are load-bearing: a dense `Tower4<4>` on device spills 41 KB/thread
//! (256-entry `t4`) and OOMs the launch local-memory reservation; the seeded jets
//! drop per-thread stack to ~900 B. The same NLL program (`def_nll!`) is written
//! ONCE and instantiated at each scalar type — there is no bespoke gate chain
//! rule, so the #736 cross-block sign-flip class of bug cannot reappear.
//!
//! # CPU fallback
//!
//! [`survival_rigid_row_jets`] is the general entry point. When a CUDA device is
//! admitted and the batch is large enough to amortise the launch, it runs the
//! kernel; otherwise (not Linux / no runtime / probe failure / small `n` / any
//! device error) it falls back to the CPU `rigid_row_nll` — the SAME unified jet —
//! so both paths agree to the parity tolerance (≤1e-9) and the path is never
//! GPU-only.
52
53use crate::survival::marginal_slope::row_kernel::RigidRowInputs;
54
55// #415 parity-lock: a host transcription of the device `.cu` seeded-jet
56// arithmetic, pinned to the production CPU jet on every box. Declared bare
57// (the whole file is `#![cfg(test)]`) with a `*_tests` name so the build.rs
58// ban-scanner exempts the test-only substrate — see `bms::test_support`.
59mod survival_rowjet_host_oracle_tests;
60
/// Batch output of the rigid survival row jet: the order-≤2 channels plus the
/// contracted third/fourth channels, each flattened row-major with `K = 4`
/// (the rigid survival primaries `q0, q1, qd1, g`).
///
/// For row `r` and primaries `a, b`:
///
/// * `value[r]`                 — `ℓ`
/// * `grad[r*K + a]`            — `∂ℓ/∂p_a`
/// * `hess[r*K*K + a*K + b]`    — `∂²ℓ/∂p_a∂p_b`
/// * `third[r*K*K + a*K + b]`   — `Σ_c ℓ_{abc} dir_c`       (one fixed `dir`)
/// * `fourth[r*K*K + a*K + b]`  — `Σ_{cd} ℓ_{abcd} u_c v_d` (one fixed `(u, v)`)
#[derive(Debug, Clone, PartialEq)]
pub struct SurvivalRowJetChannels {
    /// Number of rows described by the buffers below.
    pub n_rows: usize,
    /// `ℓ` per row (`n_rows` entries).
    pub value: Vec<f64>,
    /// Gradient, `K` entries per row (`n_rows * 4` entries).
    pub grad: Vec<f64>,
    /// Hessian, `K×K` entries per row (`n_rows * 16` entries).
    pub hess: Vec<f64>,
    /// Third derivative contracted with `dir`, `K×K` per row.
    pub third: Vec<f64>,
    /// Fourth derivative contracted with `(u, v)`, `K×K` per row.
    pub fourth: Vec<f64>,
}
78
/// One row's inputs to the rigid survival kernel, independent of the jet
/// scalar type.
///
/// * `primaries` — the four differentiated primaries `[q0, q1, qd1, g]`;
/// * `wi`, `di`, `z_sum`, `cov_ones` — the per-row scalars, forwarded unchanged
///   as the `wi`, `di`, `z_sum` and `covariance_ones` fields of
///   [`RigidRowInputs`].
///
/// `probit_scale` is deliberately absent: every row shares it, so it travels as
/// a single scalar kernel argument. These fields are exactly the per-row values
/// that [`RigidRowInputs`] + `rigid_row_kernel_primaries` produce.
#[derive(Debug, Clone)]
pub struct SurvivalRowInputs {
    /// `[q0, q1, qd1, g]`.
    pub primaries: [f64; 4],
    /// Row weight scalar (`w` in the NLL).
    pub wi: f64,
    /// Row `d` scalar in the NLL.
    pub di: f64,
    /// Per-row `z_sum` scalar.
    pub z_sum: f64,
    /// Per-row `cov_ones` scalar (forwarded as `covariance_ones`).
    pub cov_ones: f64,
}
91
/// Smallest batch for which a device launch is attempted.
///
/// Below this row count the fixed cost of the GPU path (backend probe,
/// host→device and device→host copies) is not repaid, so the CPU runs even when
/// a device is available — the answer comes from the same unified jet either
/// way. The standalone A100 measurement placed the kernel/CPU crossover well
/// under 1e5 rows; 1e5 is a deliberately conservative break-even that keeps
/// small fits on the low-latency CPU path.
pub const DEVICE_ROW_THRESHOLD: usize = 100_000;
98
99/// CPU reference / fallback: build every row's channels from the SAME unified jet
100/// the production `RowKernel` consumes (`rigid_row_nll` at `Order2`/`OneSeed`/
101/// `TwoSeed`). This is BOTH the fallback path AND the exactness oracle the device
102/// kernel is pinned to.
103#[must_use]
104pub fn survival_rigid_row_jets_cpu(
105    rows: &[SurvivalRowInputs],
106    probit_scale: f64,
107    dir: &[f64; 4],
108    dir_u: &[f64; 4],
109    dir_v: &[f64; 4],
110) -> SurvivalRowJetChannels {
111    use crate::survival::marginal_slope::row_kernel::{
112        RIGID_LINEAR_MASK, SparseOrder2, rigid_row_nll,
113    };
114    use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
115    let n = rows.len();
116    let mut value = vec![0.0_f64; n];
117    let mut grad = vec![0.0_f64; n * 4];
118    let mut hess = vec![0.0_f64; n * 16];
119    let mut third = vec![0.0_f64; n * 16];
120    let mut fourth = vec![0.0_f64; n * 16];
121    for (row, inp) in rows.iter().enumerate() {
122        let in_row = RigidRowInputs {
123            row,
124            wi: inp.wi,
125            di: inp.di,
126            z_sum: inp.z_sum,
127            covariance_ones: inp.cov_ones,
128            probit_scale,
129            // The CPU monotonicity guard floor: the device kernel does not
130            // re-derive it (the caller pre-validates the primaries before
131            // building the batch), so use the always-pass sentinel here to
132            // keep the oracle a pure derivative comparison.
133            qd1_lower: f64::NEG_INFINITY,
134        };
135        // (v, g, H) at the static-sparsity Order2 scalar (production hot path).
136        let p = inp.primaries;
137        let vars: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
138            std::array::from_fn(|a| SparseOrder2::variable(p[a], a));
139        if let Ok(out) = rigid_row_nll(&vars, &in_row) {
140            value[row] = out.value();
141            grad[row * 4..row * 4 + 4].copy_from_slice(&out.g());
142            let h = out.h();
143            for a in 0..4 {
144                for b in 0..4 {
145                    hess[row * 16 + a * 4 + b] = h[a][b];
146                }
147            }
148        }
149        // contracted third via OneSeed (ε-Hessian = Σ_c ℓ_{abc} dir_c).
150        let vars1: [OneSeed<4>; 4] =
151            std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
152        if let Ok(out1) = rigid_row_nll(&vars1, &in_row) {
153            let t = out1.contracted_third();
154            for a in 0..4 {
155                for b in 0..4 {
156                    third[row * 16 + a * 4 + b] = t[a][b];
157                }
158            }
159        }
160        // contracted fourth via TwoSeed (εδ-Hessian = Σ_{cd} ℓ_{abcd} u_c v_d).
161        let vars2: [TwoSeed<4>; 4] =
162            std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
163        if let Ok(out2) = rigid_row_nll(&vars2, &in_row) {
164            let f = out2.contracted_fourth();
165            for a in 0..4 {
166                for b in 0..4 {
167                    fourth[row * 16 + a * 4 + b] = f[a][b];
168                }
169            }
170        }
171    }
172    SurvivalRowJetChannels {
173        n_rows: n,
174        value,
175        grad,
176        hess,
177        third,
178        fourth,
179    }
180}
181
182/// General entry point: compute every row's order-≤2 + contracted third/fourth
183/// channels, on the GPU when a CUDA device is admitted and the batch is large
184/// enough to amortise the launch, else on the CPU. Both paths run the SAME
185/// unified jet, so the result is identical (proven ≤1e-9; measured 4.7e-12 on the
186/// A100). On ANY device error the CPU path runs — no fragility.
187#[must_use]
188pub fn survival_rigid_row_jets(
189    rows: &[SurvivalRowInputs],
190    probit_scale: f64,
191    dir: &[f64; 4],
192    dir_u: &[f64; 4],
193    dir_v: &[f64; 4],
194) -> SurvivalRowJetChannels {
195    #[cfg(target_os = "linux")]
196    {
197        if rows.len() >= DEVICE_ROW_THRESHOLD {
198            match device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v) {
199                Ok(out) => return out,
200                Err(e) => {
201                    // Fall through to CPU on any device error (the GPU path is an
202                    // accelerator, never the only correct path). Log WHY so a
203                    // silent CPU fallback on an admitted device is diagnosable.
204                    log::info!("[GPU] survival_rowjet device path fell back to CPU: {e}");
205                }
206            }
207        }
208    }
209    survival_rigid_row_jets_cpu(rows, probit_scale, dir, dir_u, dir_v)
210}
211
212/// Diagnostic: run ONLY the device path and return its `Result` (the error
213/// string on failure). Linux-only; intended for A100 verification harnesses to
214/// surface a compile/launch failure that the silent-fallback dispatcher hides.
215#[cfg(target_os = "linux")]
216pub fn survival_rigid_row_jets_device_only(
217    rows: &[SurvivalRowInputs],
218    probit_scale: f64,
219    dir: &[f64; 4],
220    dir_u: &[f64; 4],
221    dir_v: &[f64; 4],
222) -> Result<SurvivalRowJetChannels, String> {
223    device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v)
224        .map_err(|e| e.to_string())
225}
226
227/// The NVRTC source: a byte-faithful port of the seeded-jet arithmetic.
228/// `K=4` is fixed for the rigid survival primaries, so the kernel is compiled
229/// once (no shape macros). Full f64, no fast-math.
230#[cfg(target_os = "linux")]
231pub const SURVIVAL_ROWJET_SOURCE: &str = include_str!("survival_rowjet_kernel.cu");
232
233#[cfg(target_os = "linux")]
234mod device {
235    use super::{SURVIVAL_ROWJET_SOURCE, SurvivalRowInputs, SurvivalRowJetChannels};
236    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
237    use std::sync::{Arc, Mutex, OnceLock};
238
239    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
240
241    struct Backend {
242        ctx: Arc<CudaContext>,
243        stream: Arc<CudaStream>,
244        module: Mutex<Option<Arc<CudaModule>>>,
245    }
246
247    fn backend() -> Result<&'static Backend, GpuError> {
248        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
249        BACKEND
250            .get_or_init(|| {
251                let parts = gam_gpu::backend_probe::probe_cuda_backend("survival_rowjet")?;
252                Ok(Backend {
253                    ctx: parts.ctx,
254                    stream: parts.stream,
255                    module: Mutex::new(None),
256                })
257            })
258            .as_ref()
259            .map_err(GpuError::clone)
260    }
261
262    fn module(b: &Backend) -> Result<Arc<CudaModule>, GpuError> {
263        if let Ok(guard) = b.module.lock() {
264            if let Some(m) = guard.as_ref() {
265                return Ok(m.clone());
266            }
267        }
268        // Compile through the shared arch+fmad options (NOT bare `compile_ptx`,
269        // which leaves NVRTC at `--fmad=true` and no `--gpu-architecture` pin).
270        // FMA contraction must be off so the deep seeded-jet tower is
271        // bit-comparable to the separately-rounded CPU oracle — bare
272        // `compile_ptx` made this kernel miss the 1e-9 parity gate by ~5e-8 on
273        // a V100. The arch pin keeps the kernel keyed to the device's real
274        // compute capability rather than NVRTC's default.
275        let ptx = gam_gpu::device_cache::compile_ptx_arch(SURVIVAL_ROWJET_SOURCE)
276            .gpu_ctx_with(|err| format!("survival_rowjet NVRTC compile: {err}"))?;
277        let m = b
278            .ctx
279            .load_module(ptx)
280            .gpu_ctx("survival_rowjet module load")?;
281        if let Ok(mut guard) = b.module.lock() {
282            guard.get_or_insert_with(|| m.clone());
283        }
284        Ok(m)
285    }
286
287    fn has_nonzero_direction(dir: &[f64; 4]) -> bool {
288        dir.iter().any(|&v| v != 0.0)
289    }
290
291    pub(super) fn survival_rigid_row_jets_device(
292        rows: &[SurvivalRowInputs],
293        probit_scale: f64,
294        dir: &[f64; 4],
295        dir_u: &[f64; 4],
296        dir_v: &[f64; 4],
297    ) -> Result<SurvivalRowJetChannels, GpuError> {
298        let n = rows.len();
299        if n == 0 {
300            return Ok(SurvivalRowJetChannels {
301                n_rows: 0,
302                value: Vec::new(),
303                grad: Vec::new(),
304                hess: Vec::new(),
305                third: Vec::new(),
306                fourth: Vec::new(),
307            });
308        }
309        let b = backend()?;
310        let m = module(b)?;
311        let need_fourth = has_nonzero_direction(dir_u) && has_nonzero_direction(dir_v);
312        let func_name = if need_fourth {
313            "survival_rowjet"
314        } else {
315            "survival_rowjet_no_t4"
316        };
317        let func = m
318            .load_function(func_name)
319            .gpu_ctx_with(|err| format!("survival_rowjet load_function {func_name}: {err}"))?;
320        let stream = b.stream.clone();
321
322        // Flatten inputs into struct-of-arrays for coalesced device reads.
323        let mut q0 = vec![0.0_f64; n];
324        let mut q1 = vec![0.0_f64; n];
325        let mut qd1 = vec![0.0_f64; n];
326        let mut g = vec![0.0_f64; n];
327        let mut wi = vec![0.0_f64; n];
328        let mut di = vec![0.0_f64; n];
329        let mut zs = vec![0.0_f64; n];
330        let mut cov = vec![0.0_f64; n];
331        for (i, r) in rows.iter().enumerate() {
332            q0[i] = r.primaries[0];
333            q1[i] = r.primaries[1];
334            qd1[i] = r.primaries[2];
335            g[i] = r.primaries[3];
336            wi[i] = r.wi;
337            di[i] = r.di;
338            zs[i] = r.z_sum;
339            cov[i] = r.cov_ones;
340        }
341
342        let q0_d = stream.clone_htod(&q0).gpu_ctx("htod q0")?;
343        let q1_d = stream.clone_htod(&q1).gpu_ctx("htod q1")?;
344        let qd1_d = stream.clone_htod(&qd1).gpu_ctx("htod qd1")?;
345        let g_d = stream.clone_htod(&g).gpu_ctx("htod g")?;
346        let wi_d = stream.clone_htod(&wi).gpu_ctx("htod wi")?;
347        let di_d = stream.clone_htod(&di).gpu_ctx("htod di")?;
348        let zs_d = stream.clone_htod(&zs).gpu_ctx("htod zsum")?;
349        let cov_d = stream.clone_htod(&cov).gpu_ctx("htod cov")?;
350        let dir_d = stream.clone_htod(&dir.to_vec()).gpu_ctx("htod dir")?;
351
352        let mut value_d = stream.alloc_zeros::<f64>(n).gpu_ctx("alloc value")?;
353        let mut grad_d = stream.alloc_zeros::<f64>(n * 4).gpu_ctx("alloc grad")?;
354        let mut hess_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc hess")?;
355        let mut third_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc third")?;
356        let mut fourth_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc fourth")?;
357
358        let n_i32 = i32::try_from(n)
359            .map_err(|_| gam_gpu::gpu_err!("survival_rowjet n={n} overflows i32"))?;
360        const TPB: u32 = 128;
361        let grid = ((n as u32).div_ceil(TPB)).max(1);
362        let cfg = LaunchConfig {
363            grid_dim: (grid, 1, 1),
364            block_dim: (TPB, 1, 1),
365            shared_mem_bytes: 0,
366        };
367        let mut builder = stream.launch_builder(&func);
368        builder
369            .arg(&n_i32)
370            .arg(&q0_d)
371            .arg(&q1_d)
372            .arg(&qd1_d)
373            .arg(&g_d)
374            .arg(&wi_d)
375            .arg(&di_d)
376            .arg(&zs_d)
377            .arg(&cov_d)
378            .arg(&probit_scale)
379            .arg(&dir_d);
380        let diru_d;
381        let dirv_d;
382        if need_fourth {
383            diru_d = stream.clone_htod(&dir_u.to_vec()).gpu_ctx("htod dir_u")?;
384            dirv_d = stream.clone_htod(&dir_v.to_vec()).gpu_ctx("htod dir_v")?;
385            builder.arg(&diru_d).arg(&dirv_d);
386        }
387        builder
388            .arg(&mut value_d)
389            .arg(&mut grad_d)
390            .arg(&mut hess_d)
391            .arg(&mut third_d)
392            .arg(&mut fourth_d);
393        // SAFETY: grid/block validated; every pointer is a cudarc-checked
394        // allocation on this stream; the selected kernel reads the 8 input
395        // arrays of length n (+ one or three length-4 directions) and writes
396        // within the output buffers of length n / n*16.
397        unsafe { builder.launch(cfg) }.gpu_ctx("survival_rowjet kernel launch")?;
398
399        let mut value = vec![0.0_f64; n];
400        let mut grad = vec![0.0_f64; n * 4];
401        let mut hess = vec![0.0_f64; n * 16];
402        let mut third = vec![0.0_f64; n * 16];
403        let mut fourth = vec![0.0_f64; n * 16];
404        stream
405            .memcpy_dtoh(&value_d, &mut value)
406            .gpu_ctx("dtoh value")?;
407        stream
408            .memcpy_dtoh(&grad_d, &mut grad)
409            .gpu_ctx("dtoh grad")?;
410        stream
411            .memcpy_dtoh(&hess_d, &mut hess)
412            .gpu_ctx("dtoh hess")?;
413        stream
414            .memcpy_dtoh(&third_d, &mut third)
415            .gpu_ctx("dtoh third")?;
416        stream
417            .memcpy_dtoh(&fourth_d, &mut fourth)
418            .gpu_ctx("dtoh fourth")?;
419        stream
420            .synchronize()
421            .gpu_ctx("survival_rowjet synchronize")?;
422
423        Ok(SurvivalRowJetChannels {
424            n_rows: n,
425            value,
426            grad,
427            hess,
428            third,
429            fourth,
430        })
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic `n`-row batch with smooth, well-spread primaries, unit
    /// weights, `di = 1` on every third row, and positive `cov_ones`.
    /// `n` must be non-zero (rows are parameterised by `t = i / n`).
    fn fixture(n: usize) -> Vec<SurvivalRowInputs> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                SurvivalRowInputs {
                    primaries: [
                        -2.5 + 5.0 * (12.0 * t).sin(),
                        -1.5 + 4.0 * (9.0 * t + 0.3).cos(),
                        0.2 + 1.8 * (0.5 + 0.5 * (7.0 * t).sin()),
                        -1.0 + 2.0 * (5.0 * t + 1.1).sin(),
                    ],
                    wi: 1.0,
                    di: if i % 3 == 0 { 1.0 } else { 0.0 },
                    z_sum: 0.5 * (3.0 * t).cos(),
                    cov_ones: 0.4 + 0.3 * (0.5 + 0.5 * (2.0 * t).sin()),
                }
            })
            .collect()
    }

    const DIR: [f64; 4] = [0.31, -0.22, 0.17, 0.44];
    const DIRU: [f64; 4] = [0.13, 0.27, -0.41, 0.05];
    const DIRV: [f64; 4] = [-0.19, 0.33, 0.08, 0.22];

    #[test]
    fn cpu_channels_match_unified_rowkernel() {
        // The CPU fallback IS `rigid_row_nll` at Order2/OneSeed/TwoSeed, the same
        // thing the production `SurvivalMarginalSlopeRowKernel` calls. Cross-check
        // the (v,g,H) channels against a direct `Order2<4>` evaluation so the
        // flattening/layout is pinned to the single source.
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_scalar::{JetScalar, Order2};
        let rows = fixture(7);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        assert_eq!(out.n_rows, rows.len());
        assert_eq!(out.value.len(), rows.len());
        assert_eq!(out.grad.len(), rows.len() * 4);
        assert_eq!(out.hess.len(), rows.len() * 16);
        assert_eq!(out.third.len(), rows.len() * 16);
        assert_eq!(out.fourth.len(), rows.len() * 16);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Order2<4>; 4] =
                std::array::from_fn(|a| Order2::variable(inp.primaries[a], a));
            let dense = rigid_row_nll(&vars, &in_row).expect("dense order2");
            assert!(
                (dense.value() - out.value[row]).abs() <= 1e-12,
                "value mismatch row {row}"
            );
            for a in 0..4 {
                assert!(
                    (dense.g()[a] - out.grad[row * 4 + a]).abs() <= 1e-12,
                    "grad mismatch row {row} {a}"
                );
                for b in 0..4 {
                    assert!(
                        (dense.h()[a][b] - out.hess[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "hess mismatch row {row} {a},{b}"
                    );
                }
            }
        }
    }

    #[test]
    fn cpu_third_fourth_match_dense_tower_oracle() {
        // The seeded-jet (OneSeed/TwoSeed, O(K²)) contracted third/fourth in the
        // CPU fallback must equal the TRUE tensor contraction from the dense
        // `Tower4<4>` (the K³/K⁴ tensor). This pins the seeded contraction to the
        // single-source tensor exactly — the same property the device kernel's
        // JS1/JS2 channels rely on (and the device parity gate then matches THIS
        // CPU result to ≤1e-9).
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_tower::Tower4;
        let rows = fixture(9);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Tower4<4>; 4] =
                std::array::from_fn(|a| Tower4::variable(inp.primaries[a], a));
            let tower = rigid_row_nll(&vars, &in_row).expect("dense tower4");
            let t3 = tower.third_contracted(&DIR);
            let t4 = tower.fourth_contracted(&DIRU, &DIRV);
            for a in 0..4 {
                for b in 0..4 {
                    assert!(
                        (t3[a][b] - out.third[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "third mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t3[a][b],
                        out.third[row * 16 + a * 4 + b]
                    );
                    assert!(
                        (t4[a][b] - out.fourth[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "fourth mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t4[a][b],
                        out.fourth[row * 16 + a * 4 + b]
                    );
                }
            }
        }
    }

    #[test]
    fn cpu_fourth_vanishes_when_a_direction_is_zero() {
        // `Σ_{cd} ℓ_{abcd} u_c v_d` is bilinear in (u, v), so a zero `u` zeroes
        // the whole channel. This is the identity that lets the device launch the
        // `_no_t4` kernel when either direction is zero. The other channels do
        // not depend on (u, v) at all.
        let rows = fixture(5);
        let zero_u = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &[0.0; 4], &DIRV);
        let base = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        assert!(
            zero_u.fourth.iter().all(|x| x.abs() <= 1e-12),
            "fourth channel must vanish for u = 0"
        );
        assert_eq!(zero_u.value, base.value);
        assert_eq!(zero_u.grad, base.grad);
        assert_eq!(zero_u.hess, base.hess);
        assert_eq!(zero_u.third, base.third);
    }

    #[test]
    fn empty_batch_yields_empty_channels() {
        let out = survival_rigid_row_jets(&[], 0.7, &DIR, &DIRU, &DIRV);
        assert_eq!(out.n_rows, 0);
        assert!(out.value.is_empty());
        assert!(out.grad.is_empty());
        assert!(out.hess.is_empty());
        assert!(out.third.is_empty());
        assert!(out.fourth.is_empty());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        // Exactness gate: when a device is admitted, every channel must match the
        // CPU unified jet to <=1e-9 (measured 4.7e-12 on the A100). When no device
        // is available the dispatcher returns the CPU result, so this asserts the
        // contract on whichever path ran.

        /// Largest elementwise |a - b| between two channels. Lengths are checked
        /// first, because a bare `zip` would silently accept a truncated buffer.
        /// A NaN on only one side counts as an infinite difference, because
        /// `f64::max` would otherwise drop it.
        fn max_abs_diff(name: &str, a: &[f64], b: &[f64]) -> f64 {
            assert_eq!(a.len(), b.len(), "{name} channel length mismatch");
            a.iter()
                .zip(b)
                .map(|(&x, &y)| {
                    if x == y || (x.is_nan() && y.is_nan()) {
                        0.0
                    } else {
                        let d = (x - y).abs();
                        if d.is_nan() { f64::INFINITY } else { d }
                    }
                })
                .fold(0.0_f64, f64::max)
        }

        let rows = fixture(DEVICE_ROW_THRESHOLD + 1024);
        let cpu = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        let got = survival_rigid_row_jets(&rows, 0.7, &DIR, &DIRU, &DIRV);
        assert_eq!(cpu.n_rows, got.n_rows, "row count mismatch");
        let maxabs = [
            max_abs_diff("value", &cpu.value, &got.value),
            max_abs_diff("grad", &cpu.grad, &got.grad),
            max_abs_diff("hess", &cpu.hess, &got.hess),
            max_abs_diff("third", &cpu.third, &got.third),
            max_abs_diff("fourth", &cpu.fourth, &got.fourth),
        ]
        .iter()
        .copied()
        .fold(0.0_f64, f64::max);
        assert!(
            maxabs <= 1e-9,
            "survival device vs CPU row-jet max abs diff {maxabs} > 1e-9"
        );
    }
}