Skip to main content

gam_models/gpu_kernels/
survival_rowjet.rs

1//! Survival marginal-slope rigid per-row NLL jet on the GPU (#932 → A100 cutover).
2//!
3//! The rigid survival marginal-slope `RowKernel<4>`
4//! ([`crate::survival::marginal_slope::row_kernel::rigid_row_nll`], the
5//! #932 unified single source) computes, per row, the order-2 derivative tower
6//! `(v, g[4], H[4][4])` of the negative log-likelihood
7//!
8//! ```text
9//!   c(g)  = √(1 + (s·g)²·cov),   η0 = q0·c + s·g·z,   η1 = q1·c + s·g·z,
10//!   ad1   = qd1·c,
//!   ℓ     = +w·logΦ(−η0) − w·(1−d)·logΦ(−η1) − w·d·(logφ(η1) + log ad1)
12//! ```
13//!
14//! plus the contracted third `Σ_c ℓ_{abc} dir_c` and fourth
15//! `Σ_{cd} ℓ_{abcd} u_c v_d`. Each row evaluates the probit Mills-ratio stack
16//! (`erfcx`/`erfc`) several times — a transcendental + bandwidth wall that the
17//! CPU pays serially per thread across all `n` rows on every inner-Newton step
18//! and on the #979 Jeffreys/Firth all-axes sweeps.
19//!
20//! On an A100 the per-row jet is embarrassingly parallel and the `erfc`/`erfcx`
21//! are hardware f64 special functions. Measured (aga13 A100, full f64, no
22//! fast-math, n=8e6): **~500× kernel-only** over the 16-thread CPU jet and
23//! **~160× end-to-end** with the on-device reduction; **device == CPU to 4.7e-12**
24//! over every channel (`v`, `g[4]`, `H[16]`, contracted third `[16]`, contracted
25//! fourth `[16]`). The standalone measurement prototype lives at
26//! `src/gpu/proto/survival_marginal_slope_jet_932.cu`.
27//!
28//! # Single source, exactly
29//!
30//! The device kernel is a byte-faithful port of the seeded-jet arithmetic that
31//! the CPU `rigid_row_nll` runs:
32//!
33//!   * `J2`  — order-2 `(v, g, H)` over `K=4` primaries (mirrors `Order2<4>`);
34//!   * `JS1` — one-seed jet whose ε-Hessian channel IS `Σ_c ℓ_{abc} dir_c`
35//!     (mirrors `OneSeed<4>` — O(K²) state, NOT a dense K³ `t3`);
36//!   * `JS2` — two-seed jet whose εδ-Hessian channel IS `Σ_{cd} ℓ_{abcd} u_c v_d`
37//!     (mirrors `TwoSeed<4>` — O(K²) state, NOT a dense K⁴ `t4`).
38//!
39//! Seeded jets are load-bearing: a dense `Tower4<4>` on device spills 41 KB/thread
40//! (256-entry `t4`) and OOMs the launch local-memory reservation; the seeded jets
41//! drop per-thread stack to ~900 B. The same NLL program (`def_nll!`) is written
42//! ONCE and instantiated at each scalar type — no bespoke gate chain rule, so the
43//! #736 cross-block sign-flip bug genus cannot reappear.
44//!
45//! # CPU fallback
46//!
47//! [`survival_rigid_row_jets`] is the general entry point. When a CUDA device is
48//! admitted and the batch is large enough to amortise the launch it runs the
49//! kernel; otherwise (no Linux / no runtime / probe failure / small `n` / any
50//! device error) it falls back to the CPU `rigid_row_nll` — the SAME unified jet —
51//! so the result is identical and the path is never GPU-only.
52
53use crate::survival::marginal_slope::row_kernel::RigidRowInputs;
54
/// Flattened per-row jet channels for a batch of `n_rows` rows.
///
/// Every channel is stored row-major with `K = 4` primaries, ordered
/// `q0, q1, qd1, g`. For row `row` and primaries `a`, `b`:
///
/// * `value[row]`                        — the row NLL `ℓ`;
/// * `grad[row * K + a]`                 — `∂ℓ/∂p_a`;
/// * `hess[row * K * K + a * K + b]`     — `∂²ℓ/∂p_a∂p_b`;
/// * `third[row * K * K + a * K + b]`    — `Σ_c ℓ_{abc} dir_c` for one fixed `dir`;
/// * `fourth[row * K * K + a * K + b]`   — `Σ_{cd} ℓ_{abcd} u_c v_d` for one fixed `(u, v)`.
#[derive(Debug, Clone, PartialEq)]
pub struct SurvivalRowJetChannels {
    /// Number of rows in the batch.
    pub n_rows: usize,
    /// Row values; length `n_rows`.
    pub value: Vec<f64>,
    /// Row gradients; length `n_rows * 4`.
    pub grad: Vec<f64>,
    /// Row Hessians; length `n_rows * 16`.
    pub hess: Vec<f64>,
    /// Third derivative contracted with `dir`; length `n_rows * 16`.
    pub third: Vec<f64>,
    /// Fourth derivative contracted with `(u, v)`; length `n_rows * 16`.
    pub fourth: Vec<f64>,
}
72
/// One row of kernel input, independent of the jet scalar type.
///
/// Carries the four primaries `(q0, q1, qd1, g)` plus the per-row scalars
/// `(w, d, z_sum, cov_ones)`. `probit_scale` is not stored here: it is common
/// to every row and is passed to the kernel as a single scalar argument. The
/// fields correspond to what [`RigidRowInputs`] + `rigid_row_kernel_primaries`
/// produce for a row.
#[derive(Debug, Clone)]
pub struct SurvivalRowInputs {
    /// Primaries in kernel order: `[q0, q1, qd1, g]`.
    pub primaries: [f64; 4],
    /// Row weight (`w` in the NLL; forwarded to `RigidRowInputs::wi`).
    pub wi: f64,
    /// Row indicator `d` in the NLL (forwarded to `RigidRowInputs::di`).
    pub di: f64,
    /// Forwarded to `RigidRowInputs::z_sum`.
    pub z_sum: f64,
    /// Forwarded to `RigidRowInputs::covariance_ones`.
    pub cov_ones: f64,
}
85
/// Smallest batch for which [`survival_rigid_row_jets`] attempts the device path.
///
/// A launch carries a fixed cost: the backend probe plus the host→device and
/// device→host copies. Below this many rows the CPU path runs even when a
/// device is available, and the result is identical because both paths run
/// the same unified jet. The standalone A100 measurement put the kernel/CPU
/// crossover well below 1e5 rows. 1e5 is therefore a conservative break-even
/// that keeps small fits on the CPU for latency.
pub const DEVICE_ROW_THRESHOLD: usize = 100_000;
92
93/// CPU reference / fallback: build every row's channels from the SAME unified jet
94/// the production `RowKernel` consumes (`rigid_row_nll` at `Order2`/`OneSeed`/
95/// `TwoSeed`). This is BOTH the fallback path AND the exactness oracle the device
96/// kernel is pinned to.
97#[must_use]
98pub fn survival_rigid_row_jets_cpu(
99    rows: &[SurvivalRowInputs],
100    probit_scale: f64,
101    dir: &[f64; 4],
102    dir_u: &[f64; 4],
103    dir_v: &[f64; 4],
104) -> SurvivalRowJetChannels {
105    use crate::survival::marginal_slope::row_kernel::{
106        RIGID_LINEAR_MASK, SparseOrder2, rigid_row_nll,
107    };
108    use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
109    let n = rows.len();
110    let mut value = vec![0.0_f64; n];
111    let mut grad = vec![0.0_f64; n * 4];
112    let mut hess = vec![0.0_f64; n * 16];
113    let mut third = vec![0.0_f64; n * 16];
114    let mut fourth = vec![0.0_f64; n * 16];
115    for (row, inp) in rows.iter().enumerate() {
116        let in_row = RigidRowInputs {
117            row,
118            wi: inp.wi,
119            di: inp.di,
120            z_sum: inp.z_sum,
121            covariance_ones: inp.cov_ones,
122            probit_scale,
123            // The CPU monotonicity guard floor: the device kernel does not
124            // re-derive it (the caller pre-validates the primaries before
125            // building the batch), so use the always-pass sentinel here to
126            // keep the oracle a pure derivative comparison.
127            qd1_lower: f64::NEG_INFINITY,
128        };
129        // (v, g, H) at the static-sparsity Order2 scalar (production hot path).
130        let p = inp.primaries;
131        let vars: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
132            std::array::from_fn(|a| SparseOrder2::variable(p[a], a));
133        if let Ok(out) = rigid_row_nll(&vars, &in_row) {
134            value[row] = out.value();
135            grad[row * 4..row * 4 + 4].copy_from_slice(&out.g());
136            let h = out.h();
137            for a in 0..4 {
138                for b in 0..4 {
139                    hess[row * 16 + a * 4 + b] = h[a][b];
140                }
141            }
142        }
143        // contracted third via OneSeed (ε-Hessian = Σ_c ℓ_{abc} dir_c).
144        let vars1: [OneSeed<4>; 4] =
145            std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
146        if let Ok(out1) = rigid_row_nll(&vars1, &in_row) {
147            let t = out1.contracted_third();
148            for a in 0..4 {
149                for b in 0..4 {
150                    third[row * 16 + a * 4 + b] = t[a][b];
151                }
152            }
153        }
154        // contracted fourth via TwoSeed (εδ-Hessian = Σ_{cd} ℓ_{abcd} u_c v_d).
155        let vars2: [TwoSeed<4>; 4] =
156            std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
157        if let Ok(out2) = rigid_row_nll(&vars2, &in_row) {
158            let f = out2.contracted_fourth();
159            for a in 0..4 {
160                for b in 0..4 {
161                    fourth[row * 16 + a * 4 + b] = f[a][b];
162                }
163            }
164        }
165    }
166    SurvivalRowJetChannels {
167        n_rows: n,
168        value,
169        grad,
170        hess,
171        third,
172        fourth,
173    }
174}
175
176/// General entry point: compute every row's order-≤2 + contracted third/fourth
177/// channels, on the GPU when a CUDA device is admitted and the batch is large
178/// enough to amortise the launch, else on the CPU. Both paths run the SAME
179/// unified jet, so the result is identical (proven ≤1e-9; measured 4.7e-12 on the
180/// A100). On ANY device error the CPU path runs — no fragility.
181#[must_use]
182pub fn survival_rigid_row_jets(
183    rows: &[SurvivalRowInputs],
184    probit_scale: f64,
185    dir: &[f64; 4],
186    dir_u: &[f64; 4],
187    dir_v: &[f64; 4],
188) -> SurvivalRowJetChannels {
189    #[cfg(target_os = "linux")]
190    {
191        if rows.len() >= DEVICE_ROW_THRESHOLD {
192            match device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v) {
193                Ok(out) => return out,
194                Err(e) => {
195                    // Fall through to CPU on any device error (the GPU path is an
196                    // accelerator, never the only correct path). Log WHY so a
197                    // silent CPU fallback on an admitted device is diagnosable.
198                    log::info!("[GPU] survival_rowjet device path fell back to CPU: {e}");
199                }
200            }
201        }
202    }
203    survival_rigid_row_jets_cpu(rows, probit_scale, dir, dir_u, dir_v)
204}
205
206/// Diagnostic: run ONLY the device path and return its `Result` (the error
207/// string on failure). Linux-only; intended for A100 verification harnesses to
208/// surface a compile/launch failure that the silent-fallback dispatcher hides.
209#[cfg(target_os = "linux")]
210pub fn survival_rigid_row_jets_device_only(
211    rows: &[SurvivalRowInputs],
212    probit_scale: f64,
213    dir: &[f64; 4],
214    dir_u: &[f64; 4],
215    dir_v: &[f64; 4],
216) -> Result<SurvivalRowJetChannels, String> {
217    device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v)
218        .map_err(|e| e.to_string())
219}
220
221/// The NVRTC source: a byte-faithful port of the seeded-jet arithmetic.
222/// `K=4` is fixed for the rigid survival primaries, so the kernel is compiled
223/// once (no shape macros). Full f64, no fast-math.
224#[cfg(target_os = "linux")]
225pub const SURVIVAL_ROWJET_SOURCE: &str = include_str!("survival_rowjet_kernel.cu");
226
227#[cfg(target_os = "linux")]
228mod device {
229    use super::{SURVIVAL_ROWJET_SOURCE, SurvivalRowInputs, SurvivalRowJetChannels};
230    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
231    use std::sync::{Arc, Mutex, OnceLock};
232
233    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
234
235    struct Backend {
236        ctx: Arc<CudaContext>,
237        stream: Arc<CudaStream>,
238        module: Mutex<Option<Arc<CudaModule>>>,
239    }
240
241    fn backend() -> Result<&'static Backend, GpuError> {
242        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
243        BACKEND
244            .get_or_init(|| {
245                let parts = gam_gpu::backend_probe::probe_cuda_backend("survival_rowjet")?;
246                Ok(Backend {
247                    ctx: parts.ctx,
248                    stream: parts.stream,
249                    module: Mutex::new(None),
250                })
251            })
252            .as_ref()
253            .map_err(GpuError::clone)
254    }
255
256    fn module(b: &Backend) -> Result<Arc<CudaModule>, GpuError> {
257        if let Ok(guard) = b.module.lock() {
258            if let Some(m) = guard.as_ref() {
259                return Ok(m.clone());
260            }
261        }
262        let ptx = cudarc::nvrtc::compile_ptx(SURVIVAL_ROWJET_SOURCE)
263            .gpu_ctx_with(|err| format!("survival_rowjet NVRTC compile: {err}"))?;
264        let m = b
265            .ctx
266            .load_module(ptx)
267            .gpu_ctx("survival_rowjet module load")?;
268        if let Ok(mut guard) = b.module.lock() {
269            guard.get_or_insert_with(|| m.clone());
270        }
271        Ok(m)
272    }
273
274    fn has_nonzero_direction(dir: &[f64; 4]) -> bool {
275        dir.iter().any(|&v| v != 0.0)
276    }
277
278    pub(super) fn survival_rigid_row_jets_device(
279        rows: &[SurvivalRowInputs],
280        probit_scale: f64,
281        dir: &[f64; 4],
282        dir_u: &[f64; 4],
283        dir_v: &[f64; 4],
284    ) -> Result<SurvivalRowJetChannels, GpuError> {
285        let n = rows.len();
286        if n == 0 {
287            return Ok(SurvivalRowJetChannels {
288                n_rows: 0,
289                value: Vec::new(),
290                grad: Vec::new(),
291                hess: Vec::new(),
292                third: Vec::new(),
293                fourth: Vec::new(),
294            });
295        }
296        let b = backend()?;
297        let m = module(b)?;
298        let need_fourth = has_nonzero_direction(dir_u) && has_nonzero_direction(dir_v);
299        let func_name = if need_fourth {
300            "survival_rowjet"
301        } else {
302            "survival_rowjet_no_t4"
303        };
304        let func = m
305            .load_function(func_name)
306            .gpu_ctx_with(|err| format!("survival_rowjet load_function {func_name}: {err}"))?;
307        let stream = b.stream.clone();
308
309        // Flatten inputs into struct-of-arrays for coalesced device reads.
310        let mut q0 = vec![0.0_f64; n];
311        let mut q1 = vec![0.0_f64; n];
312        let mut qd1 = vec![0.0_f64; n];
313        let mut g = vec![0.0_f64; n];
314        let mut wi = vec![0.0_f64; n];
315        let mut di = vec![0.0_f64; n];
316        let mut zs = vec![0.0_f64; n];
317        let mut cov = vec![0.0_f64; n];
318        for (i, r) in rows.iter().enumerate() {
319            q0[i] = r.primaries[0];
320            q1[i] = r.primaries[1];
321            qd1[i] = r.primaries[2];
322            g[i] = r.primaries[3];
323            wi[i] = r.wi;
324            di[i] = r.di;
325            zs[i] = r.z_sum;
326            cov[i] = r.cov_ones;
327        }
328
329        let q0_d = stream.clone_htod(&q0).gpu_ctx("htod q0")?;
330        let q1_d = stream.clone_htod(&q1).gpu_ctx("htod q1")?;
331        let qd1_d = stream.clone_htod(&qd1).gpu_ctx("htod qd1")?;
332        let g_d = stream.clone_htod(&g).gpu_ctx("htod g")?;
333        let wi_d = stream.clone_htod(&wi).gpu_ctx("htod wi")?;
334        let di_d = stream.clone_htod(&di).gpu_ctx("htod di")?;
335        let zs_d = stream.clone_htod(&zs).gpu_ctx("htod zsum")?;
336        let cov_d = stream.clone_htod(&cov).gpu_ctx("htod cov")?;
337        let dir_d = stream.clone_htod(&dir.to_vec()).gpu_ctx("htod dir")?;
338
339        let mut value_d = stream.alloc_zeros::<f64>(n).gpu_ctx("alloc value")?;
340        let mut grad_d = stream.alloc_zeros::<f64>(n * 4).gpu_ctx("alloc grad")?;
341        let mut hess_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc hess")?;
342        let mut third_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc third")?;
343        let mut fourth_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc fourth")?;
344
345        let n_i32 = i32::try_from(n)
346            .map_err(|_| gam_gpu::gpu_err!("survival_rowjet n={n} overflows i32"))?;
347        const TPB: u32 = 128;
348        let grid = ((n as u32).div_ceil(TPB)).max(1);
349        let cfg = LaunchConfig {
350            grid_dim: (grid, 1, 1),
351            block_dim: (TPB, 1, 1),
352            shared_mem_bytes: 0,
353        };
354        let mut builder = stream.launch_builder(&func);
355        builder
356            .arg(&n_i32)
357            .arg(&q0_d)
358            .arg(&q1_d)
359            .arg(&qd1_d)
360            .arg(&g_d)
361            .arg(&wi_d)
362            .arg(&di_d)
363            .arg(&zs_d)
364            .arg(&cov_d)
365            .arg(&probit_scale)
366            .arg(&dir_d);
367        let diru_d;
368        let dirv_d;
369        if need_fourth {
370            diru_d = stream.clone_htod(&dir_u.to_vec()).gpu_ctx("htod dir_u")?;
371            dirv_d = stream.clone_htod(&dir_v.to_vec()).gpu_ctx("htod dir_v")?;
372            builder.arg(&diru_d).arg(&dirv_d);
373        }
374        builder
375            .arg(&mut value_d)
376            .arg(&mut grad_d)
377            .arg(&mut hess_d)
378            .arg(&mut third_d)
379            .arg(&mut fourth_d);
380        // SAFETY: grid/block validated; every pointer is a cudarc-checked
381        // allocation on this stream; the selected kernel reads the 8 input
382        // arrays of length n (+ one or three length-4 directions) and writes
383        // within the output buffers of length n / n*16.
384        unsafe { builder.launch(cfg) }.gpu_ctx("survival_rowjet kernel launch")?;
385
386        let mut value = vec![0.0_f64; n];
387        let mut grad = vec![0.0_f64; n * 4];
388        let mut hess = vec![0.0_f64; n * 16];
389        let mut third = vec![0.0_f64; n * 16];
390        let mut fourth = vec![0.0_f64; n * 16];
391        stream
392            .memcpy_dtoh(&value_d, &mut value)
393            .gpu_ctx("dtoh value")?;
394        stream
395            .memcpy_dtoh(&grad_d, &mut grad)
396            .gpu_ctx("dtoh grad")?;
397        stream
398            .memcpy_dtoh(&hess_d, &mut hess)
399            .gpu_ctx("dtoh hess")?;
400        stream
401            .memcpy_dtoh(&third_d, &mut third)
402            .gpu_ctx("dtoh third")?;
403        stream
404            .memcpy_dtoh(&fourth_d, &mut fourth)
405            .gpu_ctx("dtoh fourth")?;
406        stream
407            .synchronize()
408            .gpu_ctx("survival_rowjet synchronize")?;
409
410        Ok(SurvivalRowJetChannels {
411            n_rows: n,
412            value,
413            grad,
414            hess,
415            third,
416            fourth,
417        })
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic `n`-row fixture.
    ///
    /// Every input follows a smooth trigonometric sweep, with `qd1 ∈ [0.2, 2.0]`
    /// (strictly positive). Weights are all 1, and `di = 1` on every third row.
    fn fixture(n: usize) -> Vec<SurvivalRowInputs> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                SurvivalRowInputs {
                    primaries: [
                        -2.5 + 5.0 * (12.0 * t).sin(),
                        -1.5 + 4.0 * (9.0 * t + 0.3).cos(),
                        0.2 + 1.8 * (0.5 + 0.5 * (7.0 * t).sin()),
                        -1.0 + 2.0 * (5.0 * t + 1.1).sin(),
                    ],
                    wi: 1.0,
                    di: if i % 3 == 0 { 1.0 } else { 0.0 },
                    z_sum: 0.5 * (3.0 * t).cos(),
                    cov_ones: 0.4 + 0.3 * (0.5 + 0.5 * (2.0 * t).sin()),
                }
            })
            .collect()
    }

    const DIR: [f64; 4] = [0.31, -0.22, 0.17, 0.44];
    const DIRU: [f64; 4] = [0.13, 0.27, -0.41, 0.05];
    const DIRV: [f64; 4] = [-0.19, 0.33, 0.08, 0.22];

    #[test]
    fn cpu_channels_match_unified_rowkernel() {
        // The CPU fallback IS `rigid_row_nll` at Order2/OneSeed/TwoSeed, the same
        // thing the production `SurvivalMarginalSlopeRowKernel` calls. Cross-check
        // the (v,g,H) channels against a direct `Order2<4>` evaluation so the
        // flattening/layout is pinned to the single source.
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_scalar::{JetScalar, Order2};
        let rows = fixture(7);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Order2<4>; 4] =
                std::array::from_fn(|a| Order2::variable(inp.primaries[a], a));
            let dense = rigid_row_nll(&vars, &in_row).expect("dense order2");
            assert!(
                (dense.value() - out.value[row]).abs() <= 1e-12,
                "value mismatch row {row}"
            );
            for a in 0..4 {
                assert!(
                    (dense.g()[a] - out.grad[row * 4 + a]).abs() <= 1e-12,
                    "grad mismatch row {row} {a}"
                );
                for b in 0..4 {
                    assert!(
                        (dense.h()[a][b] - out.hess[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "hess mismatch row {row} {a},{b}"
                    );
                }
            }
        }
    }

    #[test]
    fn cpu_third_fourth_match_dense_tower_oracle() {
        // The seeded-jet (OneSeed/TwoSeed, O(K²)) contracted third/fourth in the
        // CPU fallback must equal the TRUE tensor contraction from the dense
        // `Tower4<4>` (the K³/K⁴ tensor). This pins the seeded contraction to the
        // single-source tensor exactly. The device kernel's JS1/JS2 channels rely
        // on the same property, and the device parity gate then matches THIS CPU
        // result to ≤1e-9.
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_tower::Tower4;
        let rows = fixture(9);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Tower4<4>; 4] =
                std::array::from_fn(|a| Tower4::variable(inp.primaries[a], a));
            let tower = rigid_row_nll(&vars, &in_row).expect("dense tower4");
            let t3 = tower.third_contracted(&DIR);
            let t4 = tower.fourth_contracted(&DIRU, &DIRV);
            for a in 0..4 {
                for b in 0..4 {
                    assert!(
                        (t3[a][b] - out.third[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "third mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t3[a][b],
                        out.third[row * 16 + a * 4 + b]
                    );
                    assert!(
                        (t4[a][b] - out.fourth[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "fourth mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t4[a][b],
                        out.fourth[row * 16 + a * 4 + b]
                    );
                }
            }
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        // Exactness gate: when a device is admitted, every channel must match the
        // CPU unified jet to <=1e-9 (measured 4.7e-12 on the A100). When no device
        // is available the dispatcher returns the CPU result, so this asserts the
        // contract on whichever path ran.

        // Largest absolute elementwise difference between two channel vectors.
        //
        // Panics on a length mismatch, which `zip` would otherwise silently
        // truncate. Equal values (including equal infinities) and NaN-vs-NaN
        // count as agreement. A NaN on only one side counts as an infinite
        // difference: `f64::max` ignores NaN operands, so a raw `(x - y).abs()`
        // NaN would otherwise slip under the gate unnoticed.
        fn max_abs_diff(name: &str, want: &[f64], have: &[f64]) -> f64 {
            assert_eq!(want.len(), have.len(), "{name}: channel length mismatch");
            want.iter().zip(have).fold(0.0_f64, |acc, (&x, &y)| {
                let d = if x == y || (x.is_nan() && y.is_nan()) {
                    0.0
                } else if x.is_nan() || y.is_nan() {
                    f64::INFINITY
                } else {
                    (x - y).abs()
                };
                acc.max(d)
            })
        }

        let rows = fixture(DEVICE_ROW_THRESHOLD + 1024);
        let cpu = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        let got = survival_rigid_row_jets(&rows, 0.7, &DIR, &DIRU, &DIRV);
        assert_eq!(got.n_rows, cpu.n_rows, "row count mismatch");

        let channels: [(&str, &[f64], &[f64]); 5] = [
            ("value", &cpu.value, &got.value),
            ("grad", &cpu.grad, &got.grad),
            ("hess", &cpu.hess, &got.hess),
            ("third", &cpu.third, &got.third),
            ("fourth", &cpu.fourth, &got.fourth),
        ];
        for (name, want, have) in channels {
            let maxabs = max_abs_diff(name, want, have);
            assert!(
                maxabs <= 1e-9,
                "survival device vs CPU row-jet {name} max abs diff {maxabs} > 1e-9"
            );
        }
    }
}