gam_models/gpu_kernels/survival_rowjet.rs
1//! Survival marginal-slope rigid per-row NLL jet on the GPU (#932 → A100 cutover).
2//!
3//! The rigid survival marginal-slope `RowKernel<4>`
4//! ([`crate::survival::marginal_slope::row_kernel::rigid_row_nll`], the
5//! #932 unified single source) computes, per row, the order-2 derivative tower
6//! `(v, g[4], H[4][4])` of the negative log-likelihood
7//!
8//! ```text
9//! c(g) = √(1 + (s·g)²·cov), η0 = q0·c + s·g·z, η1 = q1·c + s·g·z,
10//! ad1 = qd1·c,
11//! ℓ = +w·logΦ(−η0) + w·(1−d)·logΦ(−η1) − w·d·(logφ(η1) + log ad1)
12//! ```
13//!
14//! plus the contracted third `Σ_c ℓ_{abc} dir_c` and fourth
15//! `Σ_{cd} ℓ_{abcd} u_c v_d`. Each row evaluates the probit Mills-ratio stack
16//! (`erfcx`/`erfc`) several times — a transcendental + bandwidth wall that the
17//! CPU pays serially per thread across all `n` rows on every inner-Newton step
18//! and on the #979 Jeffreys/Firth all-axes sweeps.
19//!
20//! On an A100 the per-row jet is embarrassingly parallel and the `erfc`/`erfcx`
21//! are hardware f64 special functions. Measured (aga13 A100, full f64, no
22//! fast-math, n=8e6): **~500× kernel-only** over the 16-thread CPU jet and
23//! **~160× end-to-end** with the on-device reduction. The standalone
24//! measurement prototype lives at
25//! `src/gpu/proto/survival_marginal_slope_jet_932.cu`.
26//!
27//! # CPU↔device parity (#415 / #1175)
28//!
29//! The device kernel runs the SAME seeded-jet arithmetic as the CPU jet (pinned
30//! line-for-line by the host-oracle `*_tests` module on every box), so the
31//! CPU↔device residual is NOT an algebra mismatch. After #1686 disabled NVRTC
32//! FMA contraction (`--fmad=false`, applied here because this kernel now
33//! compiles through `device_cache::compile_ptx_arch`, the shared arch+fmad
34//! options), TWO distinct floors remain, with very different magnitudes:
35//!
36//! * **Low-order channels (value/grad/hess)** — FMA contraction WAS the
37//! dominant source here, so `--fmad=false` tightened them sharply. Measured
38//! on a **Tesla V100 (sm_70)**: value 1.5e-10, grad 8.2e-10, hess 8.8e-9
//!   absolute (≤3.5e-10 normalized to channel magnitude).
40//! * **High-order channels (third/fourth)** — dominated by *transcendental*
41//! drift, NOT FMA: CUDA's `erfc`/`erfcx`/`exp`/`sqrt` differ from the host
42//! libm at the ULP level, and that ε is amplified ~5e8× through the order-4
43//! seeded-jet chain. `--fmad=false` leaves these essentially unchanged
44//! (third 5.09e-8, fourth 4.54e-8 absolute — bit-identical to the
45//! pre-#1686 measurement to 4 sig figs), confirming FMA was never their
//!   root cause. Normalized to channel magnitude they are ≤1.2e-9 (third) and
//!   ≤3.7e-10 (fourth), both well inside the magnitude-scaled band below.
48//!
49//! The parity gate (`tests::device_matches_cpu_when_available`, and the
50//! fail-loud device-only sweep) is therefore a per-channel
51//! `atol + rtol·channel_scale` band, NOT a flat absolute tolerance — see
52//! `tests::PARITY_RTOL` for why a flat `1e-9` absolute bound was wrong (it
53//! ignored both derivative-order amplification AND the transcendental floor
54//! that #1686's FMA fix cannot reach) and why the magnitude-scaled band still
55//! catches any real algebra bug with comfortable headroom. This band is
56//! *complementary* to #1686, not redundant: #1686 removes the FMA component,
57//! the band absorbs the irreducible transcendental component.
58//!
59//! # Single source, exactly
60//!
61//! The device kernel is a byte-faithful port of the seeded-jet arithmetic that
62//! the CPU `rigid_row_nll` runs:
63//!
64//! * `J2` — order-2 `(v, g, H)` over `K=4` primaries (mirrors `Order2<4>`);
65//! * `JS1` — one-seed jet whose ε-Hessian channel IS `Σ_c ℓ_{abc} dir_c`
66//! (mirrors `OneSeed<4>` — O(K²) state, NOT a dense K³ `t3`);
67//! * `JS2` — two-seed jet whose εδ-Hessian channel IS `Σ_{cd} ℓ_{abcd} u_c v_d`
68//! (mirrors `TwoSeed<4>` — O(K²) state, NOT a dense K⁴ `t4`).
69//!
70//! Seeded jets are load-bearing: a dense `Tower4<4>` on device spills 41 KB/thread
71//! (256-entry `t4`) and OOMs the launch local-memory reservation; the seeded jets
72//! drop per-thread stack to ~900 B. The same NLL program (`def_nll!`) is written
73//! ONCE and instantiated at each scalar type — no bespoke gate chain rule, so the
74//! #736 cross-block sign-flip bug genus cannot reappear.
75//!
76//! # CPU fallback
77//!
78//! [`survival_rigid_row_jets`] is the general entry point. When a CUDA device is
79//! admitted and the batch is large enough to amortise the launch it runs the
80//! kernel; otherwise (no Linux / no runtime / probe failure / small `n` / any
//! device error) it falls back to the CPU `rigid_row_nll` — the SAME unified jet —
//! so the result agrees with the device path within the per-channel parity band
//! described above, and the path is never GPU-only.
83
84use crate::survival::marginal_slope::row_kernel::RigidRowInputs;
85
// #415 parity-lock: a host transcription of the device `.cu` seeded-jet
// arithmetic, pinned to the production CPU jet on every box. It is declared
// here without a `#[cfg(test)]` attribute because the module's own file
// carries `#![cfg(test)]`. The `*_tests` name lets the build.rs ban-scanner
// exempt this test-only substrate — see `bms::test_support`.
mod survival_rowjet_host_oracle_tests;
91
/// Batched per-row derivative channels of the rigid survival NLL.
///
/// Holds the order-≤2 tower `(ℓ, ∇ℓ, ∇²ℓ)` plus the contracted third and
/// fourth channels. Every field is flattened row-major over `K = 4` primaries
/// `(q0, q1, qd1, g)`. The `K×K` channels share one layout:
/// `[row*K*K + a*K + b]`.
#[derive(Clone, Debug, PartialEq)]
pub struct SurvivalRowJetChannels {
    /// Number of rows described by the flattened channels below.
    pub n_rows: usize,
    /// `value[row]` — the row NLL `ℓ`. Length `n_rows`.
    pub value: Vec<f64>,
    /// `grad[row*K + a]` — `∂ℓ/∂p_a`. Length `n_rows * K`.
    pub grad: Vec<f64>,
    /// `hess[row*K*K + a*K + b]` — `∂²ℓ/∂p_a∂p_b`. Length `n_rows * K * K`.
    pub hess: Vec<f64>,
    /// `third[row*K*K + a*K + b]` — `Σ_c ℓ_{abc} dir_c` for the one fixed
    /// direction `dir`. Length `n_rows * K * K`.
    pub third: Vec<f64>,
    /// `fourth[row*K*K + a*K + b]` — `Σ_{cd} ℓ_{abcd} u_c v_d` for the one fixed
    /// direction pair `(u, v)`. Length `n_rows * K * K`.
    pub fourth: Vec<f64>,
}
109
/// The scalar-independent per-row inputs the kernel consumes.
///
/// Each row carries the four primaries `(q0, q1, qd1, g)` and the row scalars
/// `(w, d, z_sum, cov_ones)`. `probit_scale` is shared across all rows and is
/// passed separately as a scalar kernel argument. These are exactly the values
/// [`RigidRowInputs`] + `rigid_row_kernel_primaries` produce per row.
///
/// `PartialEq` is derived for consistency with [`SurvivalRowJetChannels`], so
/// input batches can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct SurvivalRowInputs {
    /// The primaries `[q0, q1, qd1, g]`, in that order.
    pub primaries: [f64; 4],
    /// Row weight (forwarded as `RigidRowInputs::wi`; the `w` of the module-level NLL).
    pub wi: f64,
    /// Row event indicator (forwarded as `RigidRowInputs::di`; the `d` of the
    /// module-level NLL — presumably 0/1, confirm against callers).
    pub di: f64,
    /// Row `z` sum (forwarded as `RigidRowInputs::z_sum`).
    pub z_sum: f64,
    /// Row covariance term (forwarded as `RigidRowInputs::covariance_ones`).
    pub cov_ones: f64,
}
122
/// Row count at or above which [`survival_rigid_row_jets`] attempts the device.
///
/// Smaller batches stay on the CPU even when a device is admitted, because the
/// launch's fixed cost (probe + H2D + D2H) would not be amortised. The
/// standalone A100 measurement placed the kernel/CPU crossover well below 1e5
/// rows, so 1e5 is a deliberately conservative break-even that keeps
/// small-fit latency on the CPU.
///
/// Both paths run the same unified jet. The choice therefore affects speed;
/// the results agree within the CPU↔device per-channel parity band (see the
/// module docs).
pub const DEVICE_ROW_THRESHOLD: usize = 100 * 1_000;
129
130/// CPU reference / fallback: build every row's channels from the SAME unified jet
131/// the production `RowKernel` consumes (`rigid_row_nll` at `Order2`/`OneSeed`/
132/// `TwoSeed`). This is BOTH the fallback path AND the exactness oracle the device
133/// kernel is pinned to.
134#[must_use]
135pub fn survival_rigid_row_jets_cpu(
136 rows: &[SurvivalRowInputs],
137 probit_scale: f64,
138 dir: &[f64; 4],
139 dir_u: &[f64; 4],
140 dir_v: &[f64; 4],
141) -> SurvivalRowJetChannels {
142 use crate::survival::marginal_slope::row_kernel::{
143 RIGID_LINEAR_MASK, SparseOrder2, rigid_row_nll,
144 };
145 use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
146 let n = rows.len();
147 let mut value = vec![0.0_f64; n];
148 let mut grad = vec![0.0_f64; n * 4];
149 let mut hess = vec![0.0_f64; n * 16];
150 let mut third = vec![0.0_f64; n * 16];
151 let mut fourth = vec![0.0_f64; n * 16];
152 for (row, inp) in rows.iter().enumerate() {
153 let in_row = RigidRowInputs {
154 row,
155 wi: inp.wi,
156 di: inp.di,
157 z_sum: inp.z_sum,
158 covariance_ones: inp.cov_ones,
159 probit_scale,
160 // The CPU monotonicity guard floor: the device kernel does not
161 // re-derive it (the caller pre-validates the primaries before
162 // building the batch), so use the always-pass sentinel here to
163 // keep the oracle a pure derivative comparison.
164 qd1_lower: f64::NEG_INFINITY,
165 };
166 // (v, g, H) at the static-sparsity Order2 scalar (production hot path).
167 let p = inp.primaries;
168 let vars: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
169 std::array::from_fn(|a| SparseOrder2::variable(p[a], a));
170 if let Ok(out) = rigid_row_nll(&vars, &in_row) {
171 value[row] = out.value();
172 grad[row * 4..row * 4 + 4].copy_from_slice(&out.g());
173 let h = out.h();
174 for a in 0..4 {
175 for b in 0..4 {
176 hess[row * 16 + a * 4 + b] = h[a][b];
177 }
178 }
179 }
180 // contracted third via OneSeed (ε-Hessian = Σ_c ℓ_{abc} dir_c).
181 let vars1: [OneSeed<4>; 4] =
182 std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
183 if let Ok(out1) = rigid_row_nll(&vars1, &in_row) {
184 let t = out1.contracted_third();
185 for a in 0..4 {
186 for b in 0..4 {
187 third[row * 16 + a * 4 + b] = t[a][b];
188 }
189 }
190 }
191 // contracted fourth via TwoSeed (εδ-Hessian = Σ_{cd} ℓ_{abcd} u_c v_d).
192 let vars2: [TwoSeed<4>; 4] =
193 std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
194 if let Ok(out2) = rigid_row_nll(&vars2, &in_row) {
195 let f = out2.contracted_fourth();
196 for a in 0..4 {
197 for b in 0..4 {
198 fourth[row * 16 + a * 4 + b] = f[a][b];
199 }
200 }
201 }
202 }
203 SurvivalRowJetChannels {
204 n_rows: n,
205 value,
206 grad,
207 hess,
208 third,
209 fourth,
210 }
211}
212
213/// General entry point: compute every row's order-≤2 + contracted third/fourth
214/// channels, on the GPU when a CUDA device is admitted and the batch is large
215/// enough to amortise the launch, else on the CPU. Both paths run the SAME
216/// unified jet, so the result agrees within the per-channel magnitude-scaled
217/// parity band (irreducible transcendental drift only — see the module docs and
218/// `tests::PARITY_RTOL`; worst measured ≤1.2e-9 relative on a V100). On ANY
219/// device error the CPU path runs — no fragility.
220#[must_use]
221pub fn survival_rigid_row_jets(
222 rows: &[SurvivalRowInputs],
223 probit_scale: f64,
224 dir: &[f64; 4],
225 dir_u: &[f64; 4],
226 dir_v: &[f64; 4],
227) -> SurvivalRowJetChannels {
228 #[cfg(target_os = "linux")]
229 {
230 if rows.len() >= DEVICE_ROW_THRESHOLD {
231 match device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v) {
232 Ok(out) => return out,
233 Err(e) => {
234 // Fall through to CPU on any device error (the GPU path is an
235 // accelerator, never the only correct path). Log WHY so a
236 // silent CPU fallback on an admitted device is diagnosable.
237 log::info!("[GPU] survival_rowjet device path fell back to CPU: {e}");
238 }
239 }
240 }
241 }
242 survival_rigid_row_jets_cpu(rows, probit_scale, dir, dir_u, dir_v)
243}
244
245/// Diagnostic: run ONLY the device path and return its `Result` (the error
246/// string on failure). Linux-only; intended for A100 verification harnesses to
247/// surface a compile/launch failure that the silent-fallback dispatcher hides.
248#[cfg(target_os = "linux")]
249pub fn survival_rigid_row_jets_device_only(
250 rows: &[SurvivalRowInputs],
251 probit_scale: f64,
252 dir: &[f64; 4],
253 dir_u: &[f64; 4],
254 dir_v: &[f64; 4],
255) -> Result<SurvivalRowJetChannels, String> {
256 device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v)
257 .map_err(|e| e.to_string())
258}
259
/// The NVRTC source: a byte-faithful port of the seeded-jet arithmetic.
/// `K=4` is fixed for the rigid survival primaries, so the kernel is compiled
/// once (no shape macros). Full f64, no fast-math.
///
/// The source is embedded at build time from `survival_rowjet_kernel.cu`;
/// `include_str!` resolves that path relative to this file. At runtime,
/// `device::module` compiles it through `gam_gpu::device_cache::compile_ptx_arch`
/// and caches the loaded module. It provides two entry points:
/// `survival_rowjet` (with the fourth channel) and `survival_rowjet_no_t4`.
#[cfg(target_os = "linux")]
pub const SURVIVAL_ROWJET_SOURCE: &str = include_str!("survival_rowjet_kernel.cu");
265
266#[cfg(target_os = "linux")]
267mod device {
268 use super::{SURVIVAL_ROWJET_SOURCE, SurvivalRowInputs, SurvivalRowJetChannels};
269 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
270 use std::sync::{Arc, Mutex, OnceLock};
271
272 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
273
274 struct Backend {
275 ctx: Arc<CudaContext>,
276 stream: Arc<CudaStream>,
277 module: Mutex<Option<Arc<CudaModule>>>,
278 }
279
280 fn backend() -> Result<&'static Backend, GpuError> {
281 static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
282 BACKEND
283 .get_or_init(|| {
284 let parts = gam_gpu::backend_probe::probe_cuda_backend("survival_rowjet")?;
285 Ok(Backend {
286 ctx: parts.ctx,
287 stream: parts.stream,
288 module: Mutex::new(None),
289 })
290 })
291 .as_ref()
292 .map_err(GpuError::clone)
293 }
294
295 fn module(b: &Backend) -> Result<Arc<CudaModule>, GpuError> {
296 if let Ok(guard) = b.module.lock() {
297 if let Some(m) = guard.as_ref() {
298 return Ok(m.clone());
299 }
300 }
301 // Compile through the shared arch+fmad options (NOT bare `compile_ptx`,
302 // which leaves NVRTC at `--fmad=true` and no `--gpu-architecture` pin).
303 // FMA contraction must be off so the deep seeded-jet tower is
304 // bit-comparable to the separately-rounded CPU oracle — bare
305 // `compile_ptx` made this kernel miss the 1e-9 parity gate by ~5e-8 on
306 // a V100. The arch pin keeps the kernel keyed to the device's real
307 // compute capability rather than NVRTC's default.
308 let ptx = gam_gpu::device_cache::compile_ptx_arch(SURVIVAL_ROWJET_SOURCE)
309 .gpu_ctx_with(|err| format!("survival_rowjet NVRTC compile: {err}"))?;
310 let m = b
311 .ctx
312 .load_module(ptx)
313 .gpu_ctx("survival_rowjet module load")?;
314 if let Ok(mut guard) = b.module.lock() {
315 guard.get_or_insert_with(|| m.clone());
316 }
317 Ok(m)
318 }
319
320 fn has_nonzero_direction(dir: &[f64; 4]) -> bool {
321 dir.iter().any(|&v| v != 0.0)
322 }
323
324 pub(super) fn survival_rigid_row_jets_device(
325 rows: &[SurvivalRowInputs],
326 probit_scale: f64,
327 dir: &[f64; 4],
328 dir_u: &[f64; 4],
329 dir_v: &[f64; 4],
330 ) -> Result<SurvivalRowJetChannels, GpuError> {
331 let n = rows.len();
332 if n == 0 {
333 return Ok(SurvivalRowJetChannels {
334 n_rows: 0,
335 value: Vec::new(),
336 grad: Vec::new(),
337 hess: Vec::new(),
338 third: Vec::new(),
339 fourth: Vec::new(),
340 });
341 }
342 let b = backend()?;
343 let m = module(b)?;
344 let need_fourth = has_nonzero_direction(dir_u) && has_nonzero_direction(dir_v);
345 let func_name = if need_fourth {
346 "survival_rowjet"
347 } else {
348 "survival_rowjet_no_t4"
349 };
350 let func = m
351 .load_function(func_name)
352 .gpu_ctx_with(|err| format!("survival_rowjet load_function {func_name}: {err}"))?;
353 let stream = b.stream.clone();
354
355 // Flatten inputs into struct-of-arrays for coalesced device reads.
356 let mut q0 = vec![0.0_f64; n];
357 let mut q1 = vec![0.0_f64; n];
358 let mut qd1 = vec![0.0_f64; n];
359 let mut g = vec![0.0_f64; n];
360 let mut wi = vec![0.0_f64; n];
361 let mut di = vec![0.0_f64; n];
362 let mut zs = vec![0.0_f64; n];
363 let mut cov = vec![0.0_f64; n];
364 for (i, r) in rows.iter().enumerate() {
365 q0[i] = r.primaries[0];
366 q1[i] = r.primaries[1];
367 qd1[i] = r.primaries[2];
368 g[i] = r.primaries[3];
369 wi[i] = r.wi;
370 di[i] = r.di;
371 zs[i] = r.z_sum;
372 cov[i] = r.cov_ones;
373 }
374
375 let q0_d = stream.clone_htod(&q0).gpu_ctx("htod q0")?;
376 let q1_d = stream.clone_htod(&q1).gpu_ctx("htod q1")?;
377 let qd1_d = stream.clone_htod(&qd1).gpu_ctx("htod qd1")?;
378 let g_d = stream.clone_htod(&g).gpu_ctx("htod g")?;
379 let wi_d = stream.clone_htod(&wi).gpu_ctx("htod wi")?;
380 let di_d = stream.clone_htod(&di).gpu_ctx("htod di")?;
381 let zs_d = stream.clone_htod(&zs).gpu_ctx("htod zsum")?;
382 let cov_d = stream.clone_htod(&cov).gpu_ctx("htod cov")?;
383 let dir_d = stream.clone_htod(&dir.to_vec()).gpu_ctx("htod dir")?;
384
385 let mut value_d = stream.alloc_zeros::<f64>(n).gpu_ctx("alloc value")?;
386 let mut grad_d = stream.alloc_zeros::<f64>(n * 4).gpu_ctx("alloc grad")?;
387 let mut hess_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc hess")?;
388 let mut third_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc third")?;
389 let mut fourth_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc fourth")?;
390
391 let n_i32 = i32::try_from(n)
392 .map_err(|_| gam_gpu::gpu_err!("survival_rowjet n={n} overflows i32"))?;
393 const TPB: u32 = 128;
394 let grid = ((n as u32).div_ceil(TPB)).max(1);
395 let cfg = LaunchConfig {
396 grid_dim: (grid, 1, 1),
397 block_dim: (TPB, 1, 1),
398 shared_mem_bytes: 0,
399 };
400 let mut builder = stream.launch_builder(&func);
401 builder
402 .arg(&n_i32)
403 .arg(&q0_d)
404 .arg(&q1_d)
405 .arg(&qd1_d)
406 .arg(&g_d)
407 .arg(&wi_d)
408 .arg(&di_d)
409 .arg(&zs_d)
410 .arg(&cov_d)
411 .arg(&probit_scale)
412 .arg(&dir_d);
413 let diru_d;
414 let dirv_d;
415 if need_fourth {
416 diru_d = stream.clone_htod(&dir_u.to_vec()).gpu_ctx("htod dir_u")?;
417 dirv_d = stream.clone_htod(&dir_v.to_vec()).gpu_ctx("htod dir_v")?;
418 builder.arg(&diru_d).arg(&dirv_d);
419 }
420 builder
421 .arg(&mut value_d)
422 .arg(&mut grad_d)
423 .arg(&mut hess_d)
424 .arg(&mut third_d)
425 .arg(&mut fourth_d);
426 // SAFETY: grid/block validated; every pointer is a cudarc-checked
427 // allocation on this stream; the selected kernel reads the 8 input
428 // arrays of length n (+ one or three length-4 directions) and writes
429 // within the output buffers of length n / n*16.
430 unsafe { builder.launch(cfg) }.gpu_ctx("survival_rowjet kernel launch")?;
431
432 let mut value = vec![0.0_f64; n];
433 let mut grad = vec![0.0_f64; n * 4];
434 let mut hess = vec![0.0_f64; n * 16];
435 let mut third = vec![0.0_f64; n * 16];
436 let mut fourth = vec![0.0_f64; n * 16];
437 stream
438 .memcpy_dtoh(&value_d, &mut value)
439 .gpu_ctx("dtoh value")?;
440 stream
441 .memcpy_dtoh(&grad_d, &mut grad)
442 .gpu_ctx("dtoh grad")?;
443 stream
444 .memcpy_dtoh(&hess_d, &mut hess)
445 .gpu_ctx("dtoh hess")?;
446 stream
447 .memcpy_dtoh(&third_d, &mut third)
448 .gpu_ctx("dtoh third")?;
449 stream
450 .memcpy_dtoh(&fourth_d, &mut fourth)
451 .gpu_ctx("dtoh fourth")?;
452 stream
453 .synchronize()
454 .gpu_ctx("survival_rowjet synchronize")?;
455
456 Ok(SurvivalRowJetChannels {
457 n_rows: n,
458 value,
459 grad,
460 hess,
461 third,
462 fourth,
463 })
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Smooth deterministic fixture of `n` rows.
    ///
    /// Sweeps the primaries and row scalars through a range of probit
    /// arguments. Every third row is an event (`di = 1`); all weights are 1.
    fn fixture(n: usize) -> Vec<SurvivalRowInputs> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                SurvivalRowInputs {
                    primaries: [
                        -2.5 + 5.0 * (12.0 * t).sin(),
                        -1.5 + 4.0 * (9.0 * t + 0.3).cos(),
                        0.2 + 1.8 * (0.5 + 0.5 * (7.0 * t).sin()),
                        -1.0 + 2.0 * (5.0 * t + 1.1).sin(),
                    ],
                    wi: 1.0,
                    di: if i % 3 == 0 { 1.0 } else { 0.0 },
                    z_sum: 0.5 * (3.0 * t).cos(),
                    cov_ones: 0.4 + 0.3 * (0.5 + 0.5 * (2.0 * t).sin()),
                }
            })
            .collect()
    }

    const DIR: [f64; 4] = [0.31, -0.22, 0.17, 0.44];
    const DIRU: [f64; 4] = [0.13, 0.27, -0.41, 0.05];
    const DIRV: [f64; 4] = [-0.19, 0.33, 0.08, 0.22];

    #[test]
    fn cpu_channels_match_unified_rowkernel() {
        // The CPU fallback IS `rigid_row_nll` at Order2/OneSeed/TwoSeed — the
        // same thing the production `SurvivalMarginalSlopeRowKernel` calls.
        // Cross-check the (v,g,H) channels against a direct `Order2<4>`
        // evaluation, so the flattening/layout is pinned to the single source.
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_scalar::{JetScalar, Order2};
        let rows = fixture(7);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Order2<4>; 4] =
                std::array::from_fn(|a| Order2::variable(inp.primaries[a], a));
            let dense = rigid_row_nll(&vars, &in_row).expect("dense order2");
            assert!((dense.value() - out.value[row]).abs() <= 1e-12);
            for a in 0..4 {
                assert!((dense.g()[a] - out.grad[row * 4 + a]).abs() <= 1e-12);
                for b in 0..4 {
                    assert!(
                        (dense.h()[a][b] - out.hess[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "hess mismatch row {row} {a},{b}"
                    );
                }
            }
        }
    }

    #[test]
    fn cpu_third_fourth_match_dense_tower_oracle() {
        // The CPU fallback computes the contracted third/fourth with seeded
        // jets (OneSeed/TwoSeed, O(K²)). These must equal the TRUE tensor
        // contraction from the dense `Tower4<4>` (the K³/K⁴ tensor). That pins
        // the seeded contraction to the single-source tensor exactly.
        //
        // The device kernel's JS1/JS2 channels rely on the same property. The
        // device parity gate then matches THIS CPU result within the
        // per-channel `PARITY_ATOL + PARITY_RTOL·scale` band.
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_tower::Tower4;
        let rows = fixture(9);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Tower4<4>; 4] =
                std::array::from_fn(|a| Tower4::variable(inp.primaries[a], a));
            let tower = rigid_row_nll(&vars, &in_row).expect("dense tower4");
            let t3 = tower.third_contracted(&DIR);
            let t4 = tower.fourth_contracted(&DIRU, &DIRV);
            for a in 0..4 {
                for b in 0..4 {
                    assert!(
                        (t3[a][b] - out.third[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "third mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t3[a][b],
                        out.third[row * 16 + a * 4 + b]
                    );
                    assert!(
                        (t4[a][b] - out.fourth[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "fourth mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t4[a][b],
                        out.fourth[row * 16 + a * 4 + b]
                    );
                }
            }
        }
    }

    /// Per-channel CPU↔device parity tolerance (#415 / #1175).
    ///
    /// The device kernel runs the SAME seeded-jet arithmetic as the CPU jet,
    /// pinned line-for-line by the host-oracle `*_tests` module on every box.
    /// The residual is therefore NOT an algebra mismatch.
    ///
    /// With NVRTC FMA contraction now disabled (#1686, `--fmad=false`), the
    /// residual splits into two floors:
    ///
    /// * a tight low-order floor — FMA was its dominant source, so the fix
    ///   shrank it;
    /// * an irreducible transcendental floor in the high-order channels.
    ///   CUDA's `erfc`/`erfcx`/`exp`/`sqrt` differ from the host libm at the
    ///   ULP level. That ε is amplified through the order-4 jet chain (`logΦ`,
    ///   the Mills `k1..k4` polynomial, the `c=√(1+(s·g)²cov)` composition)
    ///   into the third/fourth channels. `--fmad=false` leaves these
    ///   unchanged (5.09e-8 / 4.54e-8, bit-identical to the pre-#1686
    ///   measurement).
    ///
    /// Measured on a Tesla V100 (sm_70), the drift **normalized to each
    /// channel's magnitude** is:
    ///
    /// ```text
    /// channel  worst |Δ|   channel max|cpu|   |Δ|/scale
    /// value    1.48e-10    2.22e1             6.7e-12
    /// grad     8.18e-10    1.14e1             7.2e-11
    /// hess     8.79e-9     2.50e1             3.5e-10
    /// third    5.09e-8     4.25e1             1.2e-9
    /// fourth   4.54e-8     1.23e2             3.7e-10
    /// ```
    ///
    /// The old gate compared a flat `|Δ| <= 1e-9` ACROSS ALL channels. It
    /// ignored both derivative-order amplification and the transcendental
    /// floor, so the third channel's 5.09e-8 failed it even though that is
    /// only a 1.2e-9 relative drift.
    ///
    /// Per-element *relative* error is also wrong here. The high-order
    /// channels cross zero, so at a cancellation point |cpu| is ~1e-7 while
    /// the channel scale is ~1e2, and the relative error spuriously reads 2.0.
    ///
    /// The principled scale is the channel magnitude. A real algebra bug (a
    /// sign flip or dropped Leibniz term, the #736 genus) makes an error of
    /// order the channel magnitude itself. Its normalized residual is ~O(1),
    /// about seven orders above the `PARITY_RTOL` band. The gate below
    /// therefore catches every real defect while leaving ~80× headroom over
    /// the measured transcendental noise.
    const PARITY_ATOL: f64 = 1e-9;
    const PARITY_RTOL: f64 = 1e-7;

    /// Assert that `dev` matches `cpu` element-wise within
    /// `PARITY_ATOL + PARITY_RTOL * channel_scale`.
    ///
    /// `channel_scale` is the largest finite |cpu|, i.e. the magnitude a real
    /// bug would perturb. Returns the worst residual divided by the tolerance,
    /// for reporting.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// * the slices differ in length;
    /// * any element is outside the band;
    /// * exactly one side of an element is NaN.
    fn assert_channel_parity(name: &str, cpu: &[f64], dev: &[f64]) -> f64 {
        // `zip` truncates to the shorter slice. Without this check, a device
        // result missing rows would pass vacuously.
        assert_eq!(
            cpu.len(),
            dev.len(),
            "survival device vs CPU `{name}` channel: length mismatch"
        );
        // Use finite oracle entries only. A single non-finite CPU value must
        // not inflate the tolerance to infinity and wave every element through.
        let scale = cpu
            .iter()
            .filter(|x| x.is_finite())
            .fold(0.0_f64, |m, x| m.max(x.abs()));
        let tol = PARITY_ATOL + PARITY_RTOL * scale;
        let mut worst = 0.0_f64;
        let mut worst_i = 0usize;
        for (i, (&x, &y)) in cpu.iter().zip(dev).enumerate() {
            let d = if x == y || (x.is_nan() && y.is_nan()) {
                // Identical values. This covers matching infinities, where
                // `inf - inf` is NaN, and a device that reproduces an oracle NaN.
                0.0
            } else {
                let d = (x - y).abs();
                // A one-sided NaN is a failure, not a skip. `NaN > worst` is
                // false, so a raw NaN would otherwise never be recorded.
                if d.is_nan() { f64::INFINITY } else { d }
            };
            if d > worst {
                worst = d;
                worst_i = i;
            }
        }
        assert!(
            worst <= tol,
            "survival device vs CPU `{name}` channel: worst |Δ|={worst:.3e} at idx {worst_i} \
             (cpu={:.6e} dev={:.6e}) exceeds tol={tol:.3e} (atol={PARITY_ATOL:.0e} + \
             rtol={PARITY_RTOL:.0e}·scale {scale:.3e}). A residual this large is an algebra \
             mismatch, not transcendental drift — check the .cu JS1/JS2 recurrences.",
            cpu[worst_i],
            dev[worst_i]
        );
        worst / tol
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        // Exactness gate: when a device is admitted, every channel must match
        // the CPU unified jet within the principled per-channel magnitude-scaled
        // band (see PARITY_ATOL/PARITY_RTOL). When no device is available the
        // dispatcher returns the CPU result, so this asserts CPU==CPU, which is
        // trivially within the band.
        let rows = fixture(DEVICE_ROW_THRESHOLD + 1024);
        let cpu = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        let got = survival_rigid_row_jets(&rows, 0.7, &DIR, &DIRU, &DIRV);
        assert_channel_parity("value", &cpu.value, &got.value);
        assert_channel_parity("grad", &cpu.grad, &got.grad);
        assert_channel_parity("hess", &cpu.hess, &got.hess);
        assert_channel_parity("third", &cpu.third, &got.third);
        assert_channel_parity("fourth", &cpu.fourth, &got.fourth);

        // Anti-false-green: if a CUDA runtime is present, the dispatcher MUST
        // have exercised the device kernel above (n > DEVICE_ROW_THRESHOLD),
        // not the silent CPU fallback. Prove that the device path itself runs
        // and matches. Otherwise this gate would pass on CPU==CPU even with a
        // dead kernel.
        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
            let dev = survival_rigid_row_jets_device_only(&rows, 0.7, &DIR, &DIRU, &DIRV)
                .expect("CUDA runtime present but survival_rowjet device path could not run");
            assert_channel_parity("device value", &cpu.value, &dev.value);
            assert_channel_parity("device grad", &cpu.grad, &dev.grad);
            assert_channel_parity("device hess", &cpu.hess, &dev.hess);
            assert_channel_parity("device third", &cpu.third, &dev.third);
            assert_channel_parity("device fourth", &cpu.fourth, &dev.fourth);
        }
    }

    /// Edge-regime fixture: rows deliberately placed in the hard corners of
    /// the probit Mills-ratio stack.
    ///
    /// These are the regions where erfc/erfcx differ most between host libm
    /// and CUDA, and where the seeded-jet amplification is largest. Coverage:
    /// * censored/event × entry-present;
    /// * deep negative tails (the logΦ underflow regime);
    /// * tiny and large covariance;
    /// * near-zero slope and large scale;
    /// * zero weight (the early-out branch);
    /// * the erfcx asymptotic cutover (|η|>26).
    fn edge_fixture() -> Vec<SurvivalRowInputs> {
        let mut rows = Vec::new();
        let push = |rows: &mut Vec<SurvivalRowInputs>, p: [f64; 4], w, d, z, c| {
            rows.push(SurvivalRowInputs {
                primaries: p,
                wi: w,
                di: d,
                z_sum: z,
                cov_ones: c,
            });
        };
        // interior, event & censored
        push(&mut rows, [-0.4, 0.6, 0.9, 0.3], 1.0, 1.0, 0.2, 0.5);
        push(&mut rows, [-0.4, 0.6, 0.9, 0.3], 1.0, 0.0, 0.2, 0.5);
        // deep negative probit tail (logΦ(−η)→ asymptotic / Mills tail)
        push(&mut rows, [8.0, 9.0, 1.2, 2.5], 1.0, 0.0, -3.0, 1.0);
        push(&mut rows, [-8.0, -9.0, 1.2, -2.5], 1.0, 1.0, 3.0, 1.0);
        // erfcx asymptotic cutover region (argument near/above 26)
        push(&mut rows, [40.0, 41.0, 0.7, 3.0], 1.0, 0.0, 0.0, 2.0);
        // tiny covariance (c ≈ 1, derivative of √ near flat)
        push(&mut rows, [-0.3, 0.5, 0.8, 1.5], 1.0, 1.0, 0.4, 1e-10);
        // large covariance + large scale (c large, strong coupling)
        push(&mut rows, [-0.2, 0.4, 1.1, 4.0], 1.0, 1.0, 0.1, 50.0);
        // near-zero slope (og→0, opb2→1)
        push(&mut rows, [-0.5, 0.3, 0.6, 1e-9], 1.0, 0.0, 0.7, 0.9);
        // zero weight (the w==0 early-out: every channel 0)
        push(&mut rows, [-0.5, 0.3, 0.6, 0.4], 0.0, 1.0, 0.7, 0.9);
        // small positive qd1 (log(ad1) near its valid edge)
        push(&mut rows, [-0.5, 0.3, 1e-3, 0.4], 1.0, 1.0, 0.2, 0.6);
        rows
    }

    /// #415 core deliverable — **fail loud, never silently degrade.**
    ///
    /// On a GPU box the device path MUST run. This test calls
    /// `survival_rigid_row_jets_device_only`, which never falls back, and
    /// asserts two things:
    /// * (a) it succeeds — no silent NVRTC-declined / wrong-arch /
    ///   launch-failure swallowed by the dispatcher;
    /// * (b) it matches the CPU oracle within the principled per-channel band.
    ///
    /// Both hold for the t4 and the no-t4 kernel variants, across the
    /// edge-regime sweep.
    ///
    /// A runtime probe (`GpuRuntime::global().is_some()`) decides whether a
    /// device failure is fatal, not an env var (`env::var` is banned
    /// crate-wide):
    /// * With no CUDA runtime, the device-only path returns `Err`. That is the
    ///   legitimate state on a CPU-only box, so the test SKIPS with a clear
    ///   log.
    /// * With a runtime present, the same `Err` is a HARD failure. A box that
    ///   has a GPU but can't run the kernel must break the build, not pass on
    ///   the CPU.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_only_path_runs_and_matches_cpu_fail_loud() {
        let gpu_present = gam_gpu::device_runtime::GpuRuntime::global().is_some();

        // Two batches, each with enough rows to amortise the launch: one in
        // the interior (smooth) regime and one in the edge
        // (transcendental-stress) regime. The edge batch is padded by tiling
        // so it crosses DEVICE_ROW_THRESHOLD.
        let interior = fixture(DEVICE_ROW_THRESHOLD + 777);
        let edge_unit = edge_fixture();
        let reps = (DEVICE_ROW_THRESHOLD + 999).div_ceil(edge_unit.len());
        let edge: Vec<_> = edge_unit
            .iter()
            .cloned()
            .cycle()
            .take(reps * edge_unit.len())
            .collect();

        // Variant matrix: (label, dir_u, dir_v). All-zero (u,v) selects the
        // `survival_rowjet_no_t4` kernel (fourth channel ≡ 0); nonzero selects
        // the full `survival_rowjet`. Cover both so neither entry point rots.
        let zero = [0.0_f64; 4];
        let variants: [(&str, &[f64; 4], &[f64; 4]); 2] =
            [("t4", &DIRU, &DIRV), ("no_t4", &zero, &zero)];

        let mut ran_on_device = false;
        for (regime, rows) in [("interior", &interior), ("edge", &edge)] {
            for (vlabel, du, dv) in variants {
                let dev = match survival_rigid_row_jets_device_only(rows, 0.7, &DIR, du, dv) {
                    Ok(d) => d,
                    Err(e) => {
                        if gpu_present {
                            panic!(
                                "CUDA runtime present but survival_rowjet device path \
                                 ({regime}/{vlabel}) could not run: {e}"
                            );
                        }
                        eprintln!(
                            "[#415] no CUDA device ({regime}/{vlabel}) — skipping device-only \
                             parity: {e}"
                        );
                        continue;
                    }
                };
                ran_on_device = true;
                let cpu = survival_rigid_row_jets_cpu(rows, 0.7, &DIR, du, dv);
                assert_channel_parity(&format!("{regime}/{vlabel}/value"), &cpu.value, &dev.value);
                assert_channel_parity(&format!("{regime}/{vlabel}/grad"), &cpu.grad, &dev.grad);
                assert_channel_parity(&format!("{regime}/{vlabel}/hess"), &cpu.hess, &dev.hess);
                assert_channel_parity(&format!("{regime}/{vlabel}/third"), &cpu.third, &dev.third);
                assert_channel_parity(
                    &format!("{regime}/{vlabel}/fourth"),
                    &cpu.fourth,
                    &dev.fourth,
                );
                // The no_t4 variant must yield an exactly-zero fourth channel
                // (its output buffer is zero-initialised). The CPU oracle
                // agrees because (u,v)=0 contracts the fourth tensor to zero.
                if vlabel == "no_t4" {
                    assert!(
                        dev.fourth.iter().all(|&x| x == 0.0),
                        "no_t4 kernel must write an all-zero fourth channel"
                    );
                }
            }
        }
        if ran_on_device {
            eprintln!("[#415] device-only parity PASSED on GPU for all regimes × variants");
        }
    }
}