gam-models 0.3.139

Model families (GAMLSS, survival location-scale, BMS) for the gam penalized-likelihood engine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
//! Survival marginal-slope rigid per-row NLL jet on the GPU (#932 → A100 cutover).
//!
//! The rigid survival marginal-slope `RowKernel<4>`
//! ([`crate::survival::marginal_slope::row_kernel::rigid_row_nll`], the
//! #932 unified single source) computes, per row, the order-2 derivative tower
//! `(v, g[4], H[4][4])` of the negative log-likelihood
//!
//! ```text
//!   c(g)  = √(1 + (s·g)²·cov),   η0 = q0·c + s·g·z,   η1 = q1·c + s·g·z,
//!   ad1   = qd1·c,
//!   ℓ     = +w·logΦ(−η0) + w·(1−d)·logΦ(−η1) − w·d·(logφ(η1) + log ad1)
//! ```
//!
//! plus the contracted third `Σ_c ℓ_{abc} dir_c` and fourth
//! `Σ_{cd} ℓ_{abcd} u_c v_d`. Each row evaluates the probit Mills-ratio stack
//! (`erfcx`/`erfc`) several times — a transcendental + bandwidth wall that the
//! CPU pays serially per thread across all `n` rows on every inner-Newton step
//! and on the #979 Jeffreys/Firth all-axes sweeps.
//!
//! On an A100 the per-row jet is embarrassingly parallel and the `erfc`/`erfcx`
//! are hardware f64 special functions. Measured (aga13 A100, full f64, no
//! fast-math, n=8e6): **~500× kernel-only** over the 16-thread CPU jet and
//! **~160× end-to-end** with the on-device reduction. The standalone
//! measurement prototype lives at
//! `src/gpu/proto/survival_marginal_slope_jet_932.cu`.
//!
//! # CPU↔device parity (#415 / #1175)
//!
//! The device kernel runs the SAME seeded-jet arithmetic as the CPU jet (pinned
//! line-for-line by the host-oracle `*_tests` module on every box), so the
//! CPU↔device residual is NOT an algebra mismatch. After #1686 disabled NVRTC
//! FMA contraction (`--fmad=false`, applied here because this kernel now
//! compiles through `device_cache::compile_ptx_arch`, the shared arch+fmad
//! options), TWO distinct floors remain, with very different magnitudes:
//!
//!   * **Low-order channels (value/grad/hess)** — FMA contraction WAS the
//!     dominant source here, so `--fmad=false` tightened them sharply. Measured
//!     on a **Tesla V100 (sm_70)**: value 1.5e-10, grad 8.2e-10, hess 8.8e-9
//!     absolute (≤1.1e-1 normalized to channel magnitude).
//!   * **High-order channels (third/fourth)** — dominated by *transcendental*
//!     drift, NOT FMA: CUDA's `erfc`/`erfcx`/`exp`/`sqrt` differ from the host
//!     libm at the ULP level, and that ε is amplified ~5e8× through the order-4
//!     seeded-jet chain. `--fmad=false` leaves these essentially unchanged
//!     (third 5.09e-8, fourth 4.54e-8 absolute — bit-identical to the
//!     pre-#1686 measurement to 4 sig figs), confirming FMA was never their
//!     root cause. Normalized to channel magnitude they are ≤1.2e-9 (third) and
//!     bounded by the magnitude-scaled band below (fourth).
//!
//! The parity gate (`tests::device_matches_cpu_when_available`, and the
//! fail-loud device-only sweep) is therefore a per-channel
//! `atol + rtol·channel_scale` band, NOT a flat absolute tolerance — see
//! `tests::PARITY_RTOL` for why a flat `1e-9` absolute bound was wrong (it
//! ignored both derivative-order amplification AND the transcendental floor
//! that #1686's FMA fix cannot reach) and why the magnitude-scaled band still
//! catches any real algebra bug with comfortable headroom. This band is
//! *complementary* to #1686, not redundant: #1686 removes the FMA component,
//! the band absorbs the irreducible transcendental component.
//!
//! # Single source, exactly
//!
//! The device kernel is a byte-faithful port of the seeded-jet arithmetic that
//! the CPU `rigid_row_nll` runs:
//!
//!   * `J2`  — order-2 `(v, g, H)` over `K=4` primaries (mirrors `Order2<4>`);
//!   * `JS1` — one-seed jet whose ε-Hessian channel IS `Σ_c ℓ_{abc} dir_c`
//!     (mirrors `OneSeed<4>` — O(K²) state, NOT a dense K³ `t3`);
//!   * `JS2` — two-seed jet whose εδ-Hessian channel IS `Σ_{cd} ℓ_{abcd} u_c v_d`
//!     (mirrors `TwoSeed<4>` — O(K²) state, NOT a dense K⁴ `t4`).
//!
//! Seeded jets are load-bearing: a dense `Tower4<4>` on device spills 41 KB/thread
//! (256-entry `t4`) and OOMs the launch local-memory reservation; the seeded jets
//! drop per-thread stack to ~900 B. The same NLL program (`def_nll!`) is written
//! ONCE and instantiated at each scalar type — no bespoke gate chain rule, so the
//! #736 cross-block sign-flip bug genus cannot reappear.
//!
//! # CPU fallback
//!
//! [`survival_rigid_row_jets`] is the general entry point. When a CUDA device is
//! admitted and the batch is large enough to amortise the launch it runs the
//! kernel; otherwise (no Linux / no runtime / probe failure / small `n` / any
//! device error) it falls back to the CPU `rigid_row_nll` — the SAME unified jet —
//! so the result is identical and the path is never GPU-only.

use crate::survival::marginal_slope::row_kernel::RigidRowInputs;

// #415 parity-lock: a host transcription of the device `.cu` seeded-jet
// arithmetic, pinned to the production CPU jet on every box. Declared bare
// (the whole file is `#![cfg(test)]`) with a `*_tests` name so the build.rs
// ban-scanner exempts the test-only substrate — see `bms::test_support`.
mod survival_rowjet_host_oracle_tests;

/// Per-row order-≤2 + contracted third/fourth channels for a batch of rows,
/// flattened row-major. `K = 4` (the rigid survival primaries `q0,q1,qd1,g`).
///
/// * `value[row]`            — `ℓ`
/// * `grad[row*K + a]`       — `∂ℓ/∂p_a`
/// * `hess[row*K*K + a*K+b]` — `∂²ℓ/∂p_a∂p_b`
/// * `third[row*K*K + a*K+b]`  — `Σ_c ℓ_{abc} dir_c`        (one fixed `dir`)
/// * `fourth[row*K*K + a*K+b]` — `Σ_{cd} ℓ_{abcd} u_c v_d`  (one fixed `(u,v)`)
#[derive(Debug, Clone, PartialEq)]
pub struct SurvivalRowJetChannels {
    pub n_rows: usize,
    pub value: Vec<f64>,
    pub grad: Vec<f64>,
    pub hess: Vec<f64>,
    pub third: Vec<f64>,
    pub fourth: Vec<f64>,
}

/// The scalar-independent per-row inputs the kernel consumes: the four primaries
/// `(q0,q1,qd1,g)` and the row scalars `(w,d,z_sum,cov_ones)`. `probit_scale` is
/// shared across all rows (a scalar kernel argument). These are exactly the
/// values [`RigidRowInputs`] + `rigid_row_kernel_primaries` produce per row.
#[derive(Debug, Clone)]
pub struct SurvivalRowInputs {
    pub primaries: [f64; 4],
    pub wi: f64,
    pub di: f64,
    pub z_sum: f64,
    pub cov_ones: f64,
}

/// Minimum row count below which the device launch is not worth its fixed cost
/// (probe + H2D + D2H). Below this the CPU path is used even when a device is
/// available; the result is identical (same unified jet). The standalone A100
/// measurement put the kernel/CPU crossover well under 1e5 rows; 1e5 is a
/// conservative break-even that keeps small-fit latency on the CPU.
pub const DEVICE_ROW_THRESHOLD: usize = 100_000;

/// CPU reference / fallback: build every row's channels from the SAME unified jet
/// the production `RowKernel` consumes (`rigid_row_nll` at `Order2`/`OneSeed`/
/// `TwoSeed`). This is BOTH the fallback path AND the exactness oracle the device
/// kernel is pinned to.
#[must_use]
pub fn survival_rigid_row_jets_cpu(
    rows: &[SurvivalRowInputs],
    probit_scale: f64,
    dir: &[f64; 4],
    dir_u: &[f64; 4],
    dir_v: &[f64; 4],
) -> SurvivalRowJetChannels {
    use crate::survival::marginal_slope::row_kernel::{
        RIGID_LINEAR_MASK, SparseOrder2, rigid_row_nll,
    };
    use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
    let n = rows.len();
    let mut value = vec![0.0_f64; n];
    let mut grad = vec![0.0_f64; n * 4];
    let mut hess = vec![0.0_f64; n * 16];
    let mut third = vec![0.0_f64; n * 16];
    let mut fourth = vec![0.0_f64; n * 16];
    for (row, inp) in rows.iter().enumerate() {
        let in_row = RigidRowInputs {
            row,
            wi: inp.wi,
            di: inp.di,
            z_sum: inp.z_sum,
            covariance_ones: inp.cov_ones,
            probit_scale,
            // The CPU monotonicity guard floor: the device kernel does not
            // re-derive it (the caller pre-validates the primaries before
            // building the batch), so use the always-pass sentinel here to
            // keep the oracle a pure derivative comparison.
            qd1_lower: f64::NEG_INFINITY,
        };
        // (v, g, H) at the static-sparsity Order2 scalar (production hot path).
        let p = inp.primaries;
        let vars: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
            std::array::from_fn(|a| SparseOrder2::variable(p[a], a));
        if let Ok(out) = rigid_row_nll(&vars, &in_row) {
            value[row] = out.value();
            grad[row * 4..row * 4 + 4].copy_from_slice(&out.g());
            let h = out.h();
            for a in 0..4 {
                for b in 0..4 {
                    hess[row * 16 + a * 4 + b] = h[a][b];
                }
            }
        }
        // contracted third via OneSeed (ε-Hessian = Σ_c ℓ_{abc} dir_c).
        let vars1: [OneSeed<4>; 4] =
            std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
        if let Ok(out1) = rigid_row_nll(&vars1, &in_row) {
            let t = out1.contracted_third();
            for a in 0..4 {
                for b in 0..4 {
                    third[row * 16 + a * 4 + b] = t[a][b];
                }
            }
        }
        // contracted fourth via TwoSeed (εδ-Hessian = Σ_{cd} ℓ_{abcd} u_c v_d).
        let vars2: [TwoSeed<4>; 4] =
            std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
        if let Ok(out2) = rigid_row_nll(&vars2, &in_row) {
            let f = out2.contracted_fourth();
            for a in 0..4 {
                for b in 0..4 {
                    fourth[row * 16 + a * 4 + b] = f[a][b];
                }
            }
        }
    }
    SurvivalRowJetChannels {
        n_rows: n,
        value,
        grad,
        hess,
        third,
        fourth,
    }
}

/// General entry point: compute every row's order-≤2 + contracted third/fourth
/// channels, on the GPU when a CUDA device is admitted and the batch is large
/// enough to amortise the launch, else on the CPU. Both paths run the SAME
/// unified jet, so the result agrees within the per-channel magnitude-scaled
/// parity band (irreducible transcendental drift only — see the module docs and
/// `tests::PARITY_RTOL`; worst measured ≤1.2e-9 relative on a V100). On ANY
/// device error the CPU path runs — no fragility.
#[must_use]
pub fn survival_rigid_row_jets(
    rows: &[SurvivalRowInputs],
    probit_scale: f64,
    dir: &[f64; 4],
    dir_u: &[f64; 4],
    dir_v: &[f64; 4],
) -> SurvivalRowJetChannels {
    #[cfg(target_os = "linux")]
    {
        if rows.len() >= DEVICE_ROW_THRESHOLD {
            match device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v) {
                Ok(out) => return out,
                Err(e) => {
                    // Fall through to CPU on any device error (the GPU path is an
                    // accelerator, never the only correct path). Log WHY so a
                    // silent CPU fallback on an admitted device is diagnosable.
                    log::info!("[GPU] survival_rowjet device path fell back to CPU: {e}");
                }
            }
        }
    }
    survival_rigid_row_jets_cpu(rows, probit_scale, dir, dir_u, dir_v)
}

/// Diagnostic: run ONLY the device path and return its `Result` (the error
/// string on failure). Linux-only; intended for A100 verification harnesses to
/// surface a compile/launch failure that the silent-fallback dispatcher hides.
#[cfg(target_os = "linux")]
pub fn survival_rigid_row_jets_device_only(
    rows: &[SurvivalRowInputs],
    probit_scale: f64,
    dir: &[f64; 4],
    dir_u: &[f64; 4],
    dir_v: &[f64; 4],
) -> Result<SurvivalRowJetChannels, String> {
    device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v)
        .map_err(|e| e.to_string())
}

/// The NVRTC source: a byte-faithful port of the seeded-jet arithmetic.
/// `K=4` is fixed for the rigid survival primaries, so the kernel is compiled
/// once (no shape macros). Full f64, no fast-math.
#[cfg(target_os = "linux")]
pub const SURVIVAL_ROWJET_SOURCE: &str = include_str!("survival_rowjet_kernel.cu");

#[cfg(target_os = "linux")]
mod device {
    use super::{SURVIVAL_ROWJET_SOURCE, SurvivalRowInputs, SurvivalRowJetChannels};
    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
    use std::sync::{Arc, Mutex, OnceLock};

    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};

    struct Backend {
        ctx: Arc<CudaContext>,
        stream: Arc<CudaStream>,
        module: Mutex<Option<Arc<CudaModule>>>,
    }

    fn backend() -> Result<&'static Backend, GpuError> {
        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
        BACKEND
            .get_or_init(|| {
                let parts = gam_gpu::backend_probe::probe_cuda_backend("survival_rowjet")?;
                Ok(Backend {
                    ctx: parts.ctx,
                    stream: parts.stream,
                    module: Mutex::new(None),
                })
            })
            .as_ref()
            .map_err(GpuError::clone)
    }

    fn module(b: &Backend) -> Result<Arc<CudaModule>, GpuError> {
        if let Ok(guard) = b.module.lock() {
            if let Some(m) = guard.as_ref() {
                return Ok(m.clone());
            }
        }
        // Compile through the shared arch+fmad options (NOT bare `compile_ptx`,
        // which leaves NVRTC at `--fmad=true` and no `--gpu-architecture` pin).
        // FMA contraction must be off so the deep seeded-jet tower is
        // bit-comparable to the separately-rounded CPU oracle — bare
        // `compile_ptx` made this kernel miss the 1e-9 parity gate by ~5e-8 on
        // a V100. The arch pin keeps the kernel keyed to the device's real
        // compute capability rather than NVRTC's default.
        let ptx = gam_gpu::device_cache::compile_ptx_arch(SURVIVAL_ROWJET_SOURCE)
            .gpu_ctx_with(|err| format!("survival_rowjet NVRTC compile: {err}"))?;
        let m = b
            .ctx
            .load_module(ptx)
            .gpu_ctx("survival_rowjet module load")?;
        if let Ok(mut guard) = b.module.lock() {
            guard.get_or_insert_with(|| m.clone());
        }
        Ok(m)
    }

    fn has_nonzero_direction(dir: &[f64; 4]) -> bool {
        dir.iter().any(|&v| v != 0.0)
    }

    pub(super) fn survival_rigid_row_jets_device(
        rows: &[SurvivalRowInputs],
        probit_scale: f64,
        dir: &[f64; 4],
        dir_u: &[f64; 4],
        dir_v: &[f64; 4],
    ) -> Result<SurvivalRowJetChannels, GpuError> {
        let n = rows.len();
        if n == 0 {
            return Ok(SurvivalRowJetChannels {
                n_rows: 0,
                value: Vec::new(),
                grad: Vec::new(),
                hess: Vec::new(),
                third: Vec::new(),
                fourth: Vec::new(),
            });
        }
        let b = backend()?;
        let m = module(b)?;
        let need_fourth = has_nonzero_direction(dir_u) && has_nonzero_direction(dir_v);
        let func_name = if need_fourth {
            "survival_rowjet"
        } else {
            "survival_rowjet_no_t4"
        };
        let func = m
            .load_function(func_name)
            .gpu_ctx_with(|err| format!("survival_rowjet load_function {func_name}: {err}"))?;
        let stream = b.stream.clone();

        // Flatten inputs into struct-of-arrays for coalesced device reads.
        let mut q0 = vec![0.0_f64; n];
        let mut q1 = vec![0.0_f64; n];
        let mut qd1 = vec![0.0_f64; n];
        let mut g = vec![0.0_f64; n];
        let mut wi = vec![0.0_f64; n];
        let mut di = vec![0.0_f64; n];
        let mut zs = vec![0.0_f64; n];
        let mut cov = vec![0.0_f64; n];
        for (i, r) in rows.iter().enumerate() {
            q0[i] = r.primaries[0];
            q1[i] = r.primaries[1];
            qd1[i] = r.primaries[2];
            g[i] = r.primaries[3];
            wi[i] = r.wi;
            di[i] = r.di;
            zs[i] = r.z_sum;
            cov[i] = r.cov_ones;
        }

        let q0_d = stream.clone_htod(&q0).gpu_ctx("htod q0")?;
        let q1_d = stream.clone_htod(&q1).gpu_ctx("htod q1")?;
        let qd1_d = stream.clone_htod(&qd1).gpu_ctx("htod qd1")?;
        let g_d = stream.clone_htod(&g).gpu_ctx("htod g")?;
        let wi_d = stream.clone_htod(&wi).gpu_ctx("htod wi")?;
        let di_d = stream.clone_htod(&di).gpu_ctx("htod di")?;
        let zs_d = stream.clone_htod(&zs).gpu_ctx("htod zsum")?;
        let cov_d = stream.clone_htod(&cov).gpu_ctx("htod cov")?;
        let dir_d = stream.clone_htod(&dir.to_vec()).gpu_ctx("htod dir")?;

        let mut value_d = stream.alloc_zeros::<f64>(n).gpu_ctx("alloc value")?;
        let mut grad_d = stream.alloc_zeros::<f64>(n * 4).gpu_ctx("alloc grad")?;
        let mut hess_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc hess")?;
        let mut third_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc third")?;
        let mut fourth_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc fourth")?;

        let n_i32 = i32::try_from(n)
            .map_err(|_| gam_gpu::gpu_err!("survival_rowjet n={n} overflows i32"))?;
        const TPB: u32 = 128;
        let grid = ((n as u32).div_ceil(TPB)).max(1);
        let cfg = LaunchConfig {
            grid_dim: (grid, 1, 1),
            block_dim: (TPB, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&n_i32)
            .arg(&q0_d)
            .arg(&q1_d)
            .arg(&qd1_d)
            .arg(&g_d)
            .arg(&wi_d)
            .arg(&di_d)
            .arg(&zs_d)
            .arg(&cov_d)
            .arg(&probit_scale)
            .arg(&dir_d);
        let diru_d;
        let dirv_d;
        if need_fourth {
            diru_d = stream.clone_htod(&dir_u.to_vec()).gpu_ctx("htod dir_u")?;
            dirv_d = stream.clone_htod(&dir_v.to_vec()).gpu_ctx("htod dir_v")?;
            builder.arg(&diru_d).arg(&dirv_d);
        }
        builder
            .arg(&mut value_d)
            .arg(&mut grad_d)
            .arg(&mut hess_d)
            .arg(&mut third_d)
            .arg(&mut fourth_d);
        // SAFETY: grid/block validated; every pointer is a cudarc-checked
        // allocation on this stream; the selected kernel reads the 8 input
        // arrays of length n (+ one or three length-4 directions) and writes
        // within the output buffers of length n / n*16.
        unsafe { builder.launch(cfg) }.gpu_ctx("survival_rowjet kernel launch")?;

        let mut value = vec![0.0_f64; n];
        let mut grad = vec![0.0_f64; n * 4];
        let mut hess = vec![0.0_f64; n * 16];
        let mut third = vec![0.0_f64; n * 16];
        let mut fourth = vec![0.0_f64; n * 16];
        stream
            .memcpy_dtoh(&value_d, &mut value)
            .gpu_ctx("dtoh value")?;
        stream
            .memcpy_dtoh(&grad_d, &mut grad)
            .gpu_ctx("dtoh grad")?;
        stream
            .memcpy_dtoh(&hess_d, &mut hess)
            .gpu_ctx("dtoh hess")?;
        stream
            .memcpy_dtoh(&third_d, &mut third)
            .gpu_ctx("dtoh third")?;
        stream
            .memcpy_dtoh(&fourth_d, &mut fourth)
            .gpu_ctx("dtoh fourth")?;
        stream
            .synchronize()
            .gpu_ctx("survival_rowjet synchronize")?;

        Ok(SurvivalRowJetChannels {
            n_rows: n,
            value,
            grad,
            hess,
            third,
            fourth,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn fixture(n: usize) -> Vec<SurvivalRowInputs> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                SurvivalRowInputs {
                    primaries: [
                        -2.5 + 5.0 * (12.0 * t).sin(),
                        -1.5 + 4.0 * (9.0 * t + 0.3).cos(),
                        0.2 + 1.8 * (0.5 + 0.5 * (7.0 * t).sin()),
                        -1.0 + 2.0 * (5.0 * t + 1.1).sin(),
                    ],
                    wi: 1.0,
                    di: if i % 3 == 0 { 1.0 } else { 0.0 },
                    z_sum: 0.5 * (3.0 * t).cos(),
                    cov_ones: 0.4 + 0.3 * (0.5 + 0.5 * (2.0 * t).sin()),
                }
            })
            .collect()
    }

    const DIR: [f64; 4] = [0.31, -0.22, 0.17, 0.44];
    const DIRU: [f64; 4] = [0.13, 0.27, -0.41, 0.05];
    const DIRV: [f64; 4] = [-0.19, 0.33, 0.08, 0.22];

    #[test]
    fn cpu_channels_match_unified_rowkernel() {
        // The CPU fallback IS `rigid_row_nll` at Order2/OneSeed/TwoSeed, the same
        // thing the production `SurvivalMarginalSlopeRowKernel` calls. Cross-check
        // the (v,g,H) channels against a direct `Order2<4>` evaluation so the
        // flattening/layout is pinned to the single source.
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_scalar::{JetScalar, Order2};
        let rows = fixture(7);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Order2<4>; 4] =
                std::array::from_fn(|a| Order2::variable(inp.primaries[a], a));
            let dense = rigid_row_nll(&vars, &in_row).expect("dense order2");
            assert!((dense.value() - out.value[row]).abs() <= 1e-12);
            for a in 0..4 {
                assert!((dense.g()[a] - out.grad[row * 4 + a]).abs() <= 1e-12);
                for b in 0..4 {
                    assert!(
                        (dense.h()[a][b] - out.hess[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "hess mismatch row {row} {a},{b}"
                    );
                }
            }
        }
    }

    #[test]
    fn cpu_third_fourth_match_dense_tower_oracle() {
        // The seeded-jet (OneSeed/TwoSeed, O(K²)) contracted third/fourth in the
        // CPU fallback must equal the TRUE tensor contraction from the dense
        // `Tower4<4>` (the K³/K⁴ tensor). This pins the seeded contraction to the
        // single-source tensor exactly — the same property the device kernel's
        // JS1/JS2 channels rely on (and the device parity gate then matches THIS
        // CPU result to ≤1e-9).
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_tower::Tower4;
        let rows = fixture(9);
        let out = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = RigidRowInputs {
                row,
                wi: inp.wi,
                di: inp.di,
                z_sum: inp.z_sum,
                covariance_ones: inp.cov_ones,
                probit_scale: 0.7,
                qd1_lower: f64::NEG_INFINITY,
            };
            let vars: [Tower4<4>; 4] =
                std::array::from_fn(|a| Tower4::variable(inp.primaries[a], a));
            let tower = rigid_row_nll(&vars, &in_row).expect("dense tower4");
            let t3 = tower.third_contracted(&DIR);
            let t4 = tower.fourth_contracted(&DIRU, &DIRV);
            for a in 0..4 {
                for b in 0..4 {
                    assert!(
                        (t3[a][b] - out.third[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "third mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t3[a][b],
                        out.third[row * 16 + a * 4 + b]
                    );
                    assert!(
                        (t4[a][b] - out.fourth[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "fourth mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t4[a][b],
                        out.fourth[row * 16 + a * 4 + b]
                    );
                }
            }
        }
    }

    /// Per-channel CPU↔device parity tolerance (#415 / #1175).
    ///
    /// The device kernel runs the SAME seeded-jet arithmetic as the CPU jet
    /// (pinned line-for-line by the host-oracle `*_tests` module on every box),
    /// so the residual is NOT an algebra mismatch. With NVRTC FMA contraction
    /// now disabled (#1686, `--fmad=false`), the residual splits into a tight
    /// low-order floor (FMA was its dominant source, so the fix shrank it) and
    /// an irreducible transcendental floor in the high-order channels: CUDA's
    /// `erfc`/`erfcx`/`exp`/`sqrt` differ from the host libm at the ULP level,
    /// and that ε is amplified through the order-4 jet chain (`logΦ`, the Mills
    /// `k1..k4` polynomial, the `c=√(1+(s·g)²cov)` composition) into the
    /// third/fourth channels — which `--fmad=false` leaves unchanged (5.09e-8 /
    /// 4.54e-8, bit-identical to the pre-#1686 measurement). Measured on a
    /// Tesla V100 (sm_70), the drift, **normalized to each channel's
    /// magnitude**, is:
    ///
    /// ```text
    ///   channel  worst |Δ|     channel max|cpu|   |Δ|/scale
    ///   value    1.48e-10      2.22e1             6.7e-12
    ///   grad     8.18e-10      1.14e1             7.2e-11
    ///   hess     8.79e-9       2.50e1             3.5e-10
    ///   third    5.09e-8       4.25e1             1.2e-9
    ///   fourth   4.54e-8       1.23e2             3.7e-10
    /// ```
    ///
    /// (The old gate compared a flat `|Δ| <= 1e-9` ACROSS ALL channels — it
    /// ignored both derivative-order amplification and the transcendental
    /// floor, so the third channel's 5.09e-8 failed it even though that is a
    /// 1.2e-9 relative drift. Per-element *relative* error is also wrong here:
    /// the high-order channels cross zero, so at a cancellation point |cpu| is
    /// ~1e-7 while the channel scale is ~1e2 and the relative error spuriously
    /// reads 2.0.) The principled scale is the channel magnitude. A real
    /// algebra bug (a sign flip / dropped Leibniz term, the #736 genus) makes
    /// an error of order the channel magnitude itself — normalized residual
    /// ~O(1), seven orders above this floor — so the gate below catches every
    /// real defect with ~80× headroom over the transcendental noise.
    const PARITY_ATOL: f64 = 1e-9;
    const PARITY_RTOL: f64 = 1e-7;

    /// Assert every element of `dev` matches `cpu` within
    /// `PARITY_ATOL + PARITY_RTOL * channel_scale`, where `channel_scale` is the
    /// channel's max |cpu| (the magnitude a real bug would perturb). Returns the
    /// worst normalized residual for reporting.
    fn assert_channel_parity(name: &str, cpu: &[f64], dev: &[f64]) -> f64 {
        let scale = cpu.iter().fold(0.0_f64, |m, x| m.max(x.abs()));
        let tol = PARITY_ATOL + PARITY_RTOL * scale;
        let mut worst = 0.0_f64;
        let mut worst_i = 0usize;
        for (i, (x, y)) in cpu.iter().zip(dev).enumerate() {
            let d = (x - y).abs();
            if d > worst {
                worst = d;
                worst_i = i;
            }
        }
        assert!(
            worst <= tol,
            "survival device vs CPU `{name}` channel: worst |Δ|={worst:.3e} at idx {worst_i} \
             (cpu={:.6e} dev={:.6e}) exceeds tol={tol:.3e} (atol={PARITY_ATOL:.0e} + \
             rtol={PARITY_RTOL:.0e}·scale {scale:.3e}). A residual this large is an algebra \
             mismatch, not transcendental drift — check the .cu JS1/JS2 recurrences.",
            cpu[worst_i],
            dev[worst_i]
        );
        worst / tol
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        // Exactness gate: when a device is admitted, every channel must match the
        // CPU unified jet within the principled per-channel magnitude-scaled band
        // (see PARITY_ATOL/PARITY_RTOL). When no device is available the dispatcher
        // returns the CPU result, so this asserts CPU==CPU (trivially within band).
        let rows = fixture(DEVICE_ROW_THRESHOLD + 1024);
        let cpu = survival_rigid_row_jets_cpu(&rows, 0.7, &DIR, &DIRU, &DIRV);
        let got = survival_rigid_row_jets(&rows, 0.7, &DIR, &DIRU, &DIRV);
        assert_channel_parity("value", &cpu.value, &got.value);
        assert_channel_parity("grad", &cpu.grad, &got.grad);
        assert_channel_parity("hess", &cpu.hess, &got.hess);
        assert_channel_parity("third", &cpu.third, &got.third);
        assert_channel_parity("fourth", &cpu.fourth, &got.fourth);

        // Anti-false-green: if a CUDA runtime is present the dispatcher MUST have
        // exercised the device kernel above (n > DEVICE_ROW_THRESHOLD), not the
        // silent CPU fallback. Prove the device path itself runs and matches —
        // otherwise this gate would pass on CPU==CPU even with a dead kernel.
        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
            let dev = survival_rigid_row_jets_device_only(&rows, 0.7, &DIR, &DIRU, &DIRV)
                .expect("CUDA runtime present but survival_rowjet device path could not run");
            assert_channel_parity("device value", &cpu.value, &dev.value);
            assert_channel_parity("device grad", &cpu.grad, &dev.grad);
            assert_channel_parity("device hess", &cpu.hess, &dev.hess);
            assert_channel_parity("device third", &cpu.third, &dev.third);
            assert_channel_parity("device fourth", &cpu.fourth, &dev.fourth);
        }
    }

    /// Edge-regime fixture: rows deliberately placed in the hard corners of the
    /// probit Mills-ratio stack, where erfc/erfcx differ most between host libm
    /// and CUDA and the seeded-jet amplification is largest. Covers
    /// censored/event × entry-present, deep negative tails (logΦ underflow
    /// regime), tiny and large covariance, near-zero slope, large scale, zero
    /// weight (the early-out branch), and the erfcx asymptotic cutover (|η|>26).
    fn edge_fixture() -> Vec<SurvivalRowInputs> {
        let mut rows = Vec::new();
        let push = |rows: &mut Vec<SurvivalRowInputs>, p: [f64; 4], w, d, z, c| {
            rows.push(SurvivalRowInputs {
                primaries: p,
                wi: w,
                di: d,
                z_sum: z,
                cov_ones: c,
            });
        };
        // interior, event & censored
        push(&mut rows, [-0.4, 0.6, 0.9, 0.3], 1.0, 1.0, 0.2, 0.5);
        push(&mut rows, [-0.4, 0.6, 0.9, 0.3], 1.0, 0.0, 0.2, 0.5);
        // deep negative probit tail (logΦ(−η)→ asymptotic / Mills tail)
        push(&mut rows, [8.0, 9.0, 1.2, 2.5], 1.0, 0.0, -3.0, 1.0);
        push(&mut rows, [-8.0, -9.0, 1.2, -2.5], 1.0, 1.0, 3.0, 1.0);
        // erfcx asymptotic cutover region (argument near/above 26)
        push(&mut rows, [40.0, 41.0, 0.7, 3.0], 1.0, 0.0, 0.0, 2.0);
        // tiny covariance (c ≈ 1, derivative of √ near flat)
        push(&mut rows, [-0.3, 0.5, 0.8, 1.5], 1.0, 1.0, 0.4, 1e-10);
        // large covariance + large scale (c large, strong coupling)
        push(&mut rows, [-0.2, 0.4, 1.1, 4.0], 1.0, 1.0, 0.1, 50.0);
        // near-zero slope (og→0, opb2→1)
        push(&mut rows, [-0.5, 0.3, 0.6, 1e-9], 1.0, 0.0, 0.7, 0.9);
        // zero weight (the w==0 early-out: every channel 0)
        push(&mut rows, [-0.5, 0.3, 0.6, 0.4], 0.0, 1.0, 0.7, 0.9);
        // small positive qd1 (log(ad1) near its valid edge)
        push(&mut rows, [-0.5, 0.3, 1e-3, 0.4], 1.0, 1.0, 0.2, 0.6);
        rows
    }

    /// #415 core deliverable — **fail loud, never silently degrade.** On a GPU
    /// box the device path MUST run; this calls `survival_rigid_row_jets_device_only`
    /// (which never falls back) and asserts it both (a) succeeds — no silent
    /// NVRTC-declined / wrong-arch / launch-failure swallowed by the dispatcher —
    /// and (b) matches the CPU oracle within the principled per-channel band, for
    /// BOTH the t4 and the no-t4 kernel variants and across the edge-regime sweep.
    ///
    /// When no CUDA device is present the device-only path returns `Err`, which
    /// is the legitimate state on a CPU-only box — so the test SKIPS with a clear
    /// log there. Set `GAM_REQUIRE_GPU=1` (CI on the GPU runner) to turn that skip
    /// into a HARD failure: a box that is supposed to have a GPU but can't run the
    /// kernel must break the build, not pass on the CPU.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_only_path_runs_and_matches_cpu_fail_loud() {
        // Fail loud only when a CUDA device is actually present (a real runtime
        // check, not an env-var read — `env::var` is banned crate-wide): on a GPU
        // box the device path MUST run, while a CI runner with no device skips
        // gracefully.
        let require_gpu = gam_gpu::device_runtime::GpuRuntime::global().is_some();

        // Two batches: enough rows to amortise the launch, in both the interior
        // (smooth) and edge (transcendental-stress) regimes. The edge batch is
        // padded by tiling so it crosses DEVICE_ROW_THRESHOLD.
        let interior = fixture(DEVICE_ROW_THRESHOLD + 777);
        let edge_unit = edge_fixture();
        let reps = (DEVICE_ROW_THRESHOLD + 999).div_ceil(edge_unit.len());
        let edge: Vec<_> = edge_unit
            .iter()
            .cloned()
            .cycle()
            .take(reps * edge_unit.len())
            .collect();

        // Variant matrix: (label, dir_u, dir_v). All-zero (u,v) selects the
        // `survival_rowjet_no_t4` kernel (fourth channel ≡ 0); nonzero selects
        // the full `survival_rowjet`. Cover both so neither entry point rots.
        let zero = [0.0_f64; 4];
        let variants: [(&str, &[f64; 4], &[f64; 4]); 2] =
            [("t4", &DIRU, &DIRV), ("no_t4", &zero, &zero)];

        let mut ran_on_device = false;
        for (regime, rows) in [("interior", &interior), ("edge", &edge)] {
            for (vlabel, du, dv) in variants {
                let dev = match survival_rigid_row_jets_device_only(rows, 0.7, &DIR, du, dv) {
                    Ok(d) => d,
                    Err(e) => {
                        if require_gpu {
                            panic!(
                                "GAM_REQUIRE_GPU set but survival_rowjet device path \
                                 ({regime}/{vlabel}) could not run: {e}"
                            );
                        }
                        eprintln!(
                            "[#415] no CUDA device ({regime}/{vlabel}) — skipping device-only \
                             parity (set GAM_REQUIRE_GPU=1 to make this a hard failure): {e}"
                        );
                        continue;
                    }
                };
                ran_on_device = true;
                let cpu = survival_rigid_row_jets_cpu(rows, 0.7, &DIR, du, dv);
                assert_channel_parity(&format!("{regime}/{vlabel}/value"), &cpu.value, &dev.value);
                assert_channel_parity(&format!("{regime}/{vlabel}/grad"), &cpu.grad, &dev.grad);
                assert_channel_parity(&format!("{regime}/{vlabel}/hess"), &cpu.hess, &dev.hess);
                assert_channel_parity(&format!("{regime}/{vlabel}/third"), &cpu.third, &dev.third);
                assert_channel_parity(
                    &format!("{regime}/{vlabel}/fourth"),
                    &cpu.fourth,
                    &dev.fourth,
                );
                // The no_t4 variant must yield an exactly-zero fourth channel
                // (the kernel writes 0.0), and the CPU oracle agrees because
                // (u,v)=0 contracts the fourth tensor to zero.
                if vlabel == "no_t4" {
                    assert!(
                        dev.fourth.iter().all(|&x| x == 0.0),
                        "no_t4 kernel must write an all-zero fourth channel"
                    );
                }
            }
        }
        if ran_on_device {
            eprintln!("[#415] device-only parity PASSED on GPU for all regimes × variants");
        }
    }

}