Skip to main content

gam_models/survival/marginal_slope/
gpu_prep.rs

1//! Survival-flex per-row **prep-step** dispatchers.
2//!
3//! These two `try_device_*` entries are the GPU-shaped seam for the per-row
4//! prep work that currently dominates large-scale survival-flex wall time:
5//!
6//! * [`try_device_partition_cells`] — batched version of
7//!   `SurvivalMarginalSlopeFamily::denested_partition_cells`
8//!   (`src/families/survival_marginal_slope.rs:5701`).
9//! * [`try_device_cell_primary_fixed_partials`] — batched version of
10//!   `SurvivalMarginalSlopeFamily::denested_cell_primary_fixed_partials`
11//!   (`src/families/survival_marginal_slope.rs:6218`).
12//!
13//! Layout of the device output:
14//!
15//! ```text
16//!   cells     : Vec<f64>  // flat 18·n_cells doubles (cell ⨁ score_span ⨁ link_span)
17//!   offsets   : Vec<u32>  // CSR-style row offsets, length n_rows + 1
18//!   status    : Vec<u8>   // 0 = ok, non-zero = host-fallback signal for that row
19//! ```
20//!
21//! and the primary-fixed-partials kernel writes a parallel
22//! `12 + 40·primary.total` doubles per cell into the same row/cell indexing.
23//!
24//! ## Supported shapes
25//!
26//! The NVRTC bodies here cover the **no-runtime baseline** path: rows where
27//! neither `beta_h` nor `beta_w` is provided.  In that regime
28//! `build_denested_partition_cells_with_tails` returns a single trivial
29//! affine cell `(c0=a·scale, c1=b·scale, c2=0, c3=0)` per row (no split
30//! points), and the fixed-partials per cell reduces to just the `g`-slot
31//! pieces (`coeff_u[g]=dc_db`, `coeff_au[g]=dc_dab`, ..., `dc_da=[1,0,0,0]·scale`).
//! Both kernels execute that closed-form path on-device with zero host
//! arithmetic; the dispatchers then copy the results device-to-host (DtoH)
//! and re-pack them into the caller's shape.
34//!
35//! Rows that need a non-trivial knot-table / B-spline runtime traversal
36//! (i.e. any row carrying a `beta_h` or `beta_w` slice) cause the
37//! dispatcher to return `Ok(None)` so the family-side path falls back to
38//! the existing CPU per-row code.  The kernel surface, runtime upload
39//! plumbing, and DtoH re-pack stay device-shaped so the eventual general
40//! body lands behind the same call boundary.
41
42use crate::cubic_cell_kernel::{
43    DenestedCubicCell, DenestedPartitionCell, LocalSpanCubic,
44};
45use gam_gpu::gpu_error::GpuError;
46
/// CUDA C++ kernel source strings for the two NVRTC kernels.
///
/// Both bodies implement only the closed-form **no-runtime** special case of
/// the CPU routines named in the module docs (rows without `beta_h` /
/// `beta_w`, hence no knot-table or B-spline traversal).  They are not
/// general translations of those routines.
pub mod kernel_src {
    /// NVRTC source for `denested_partition_cells_kernel`.
    ///
    /// One thread per row.  Trivial no-runtime case: emits a single affine
    /// cell `(c0=a·scale, c1=b·scale, c2=0, c3=0)` with zero score/link
    /// spans, mirroring the CPU `build_denested_partition_cells_with_tails`
    /// empty-split-points branch followed by the
    /// `SurvivalMarginalSlopeFamily::denested_partition_cells` post-scale.
    ///
    /// Per-row output is 18 doubles: `[left, right, c0, c1, c2, c3]` for the
    /// cell, then the same six-double shape for the score span and for the
    /// link span.  Every row's status byte is 0.
    pub const DENESTED_PARTITION_CELLS_KERNEL_SRC: &str = r#"
// f64 throughout (no --use_fast_math).

extern "C" {

__device__ __forceinline__ double pos_inf_f64() {
    // IEEE-754 +inf bit pattern: 0x7ff0000000000000.
    return __longlong_as_double((long long)0x7ff0000000000000LL);
}
__device__ __forceinline__ double neg_inf_f64() {
    // IEEE-754 -inf bit pattern: 0xfff0000000000000.
    return __longlong_as_double((long long)0xfff0000000000000LL);
}

__global__ void denested_partition_cells_kernel(
    int n_rows,
    double scale,
    const double *a_per_row,
    const double *b_per_row,
    double *out_cells_flat,        // 18 doubles per row (single cell)
    unsigned int *out_row_offsets, // length n_rows + 1
    unsigned char *out_status      // length n_rows
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n_rows) return;
    double a = a_per_row[i];
    double b = b_per_row[i];
    double *cell = out_cells_flat + (long long)i * 18;
    // ── cell: (-inf, +inf, c0=a*scale, c1=b*scale, c2=0, c3=0) ──
    cell[0]  = neg_inf_f64();
    cell[1]  = pos_inf_f64();
    cell[2]  = a * scale;
    cell[3]  = b * scale;
    cell[4]  = 0.0;
    cell[5]  = 0.0;
    // ── score_span (zero cubic, left=0,right=1) ──
    cell[6]  = 0.0; cell[7]  = 1.0;
    cell[8]  = 0.0; cell[9]  = 0.0; cell[10] = 0.0; cell[11] = 0.0;
    // ── link_span (zero cubic, left=0,right=1) ──
    cell[12] = 0.0; cell[13] = 1.0;
    cell[14] = 0.0; cell[15] = 0.0; cell[16] = 0.0; cell[17] = 0.0;
    // ── row offset: one cell per row ──
    out_row_offsets[i] = (unsigned int)i;
    if (i == n_rows - 1) {
        out_row_offsets[n_rows] = (unsigned int)n_rows;
    }
    out_status[i] = 0;
}

}  // extern "C"
"#;

    /// NVRTC source for `denested_cell_primary_fixed_partials_kernel`.
    ///
    /// One thread per cell.  Trivial no-runtime case: only the `g` slot is
    /// populated because `primary.h` and `primary.w` are empty when both
    /// runtimes are absent.  Mirrors the closed-form arithmetic the CPU
    /// `denested_cell_primary_fixed_partials` runs when `h_len == 0 &&
    /// w_len == 0`.
    ///
    /// For the trivial cell `(c0=a·scale, c1=b·scale, c2=0, c3=0)` the
    /// partials evaluate to:
    /// * `dc_da   = [1, 0, 0, 0] · scale` (written at offset 0)
    /// * `dc_daa  = [0, 0, 0, 0]`
    /// * `dc_daaa = [0, 0, 0, 0]`
    /// * `coeff_u[g] = dc_db = [0, 1, 0, 0] · scale` (written at offset
    ///   `12 + 4·g_slot + 1`); every other `g`-slot coefficient is zero.
    ///
    /// Each cell owns `12 + 40·r` doubles.  A cell whose layout has
    /// `g_slot >= r` is rejected with status `1` and nothing is written,
    /// because the `coeff_u[g]` write would fall outside the `coeff_u` block
    /// (and, for `r == 0`, outside the cell's 12-double block).
    pub const DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC: &str = r#"
// f64 throughout (no --use_fast_math).

extern "C" {

__global__ void denested_cell_primary_fixed_partials_kernel(
    int n_cells_total,
    unsigned int r,
    unsigned int g_slot,
    double scale,
    double *out_partials_flat,  // (12 + 40·r) doubles per cell
    unsigned char *out_status
) {
    int cell = blockIdx.x * blockDim.x + threadIdx.x;
    if (cell >= n_cells_total) return;
    // The g slot must be one of the r primary slots.  Otherwise the
    // coeff_u[g] write below lands in another coefficient array or past
    // this cell's block (for r == 0, into the next cell or off the end of
    // the buffer).  Flag the cell for host fallback and write nothing.
    if (g_slot >= r) {
        out_status[cell] = 1;
        return;
    }
    unsigned int per_cell = 12u + 40u * r;
    double *base = out_partials_flat + (long long)cell * (long long)per_cell;
    // Zero the whole block (cheap; r is small).
    for (unsigned int s = 0; s < per_cell; ++s) {
        base[s] = 0.0;
    }
    // dc_da = [1, 0, 0, 0] · scale
    base[0] = scale;
    // dc_daa, dc_daaa already zero.
    // g-slot fills (offset = 12 + 4·g_slot within each per-cell run).
    //   coeff_u   [g] = dc_db   = [0, 1, 0, 0] · scale
    //   coeff_au  [g] = dc_dab  = [0, 0, 0, 0]
    //   coeff_bu  [g] = dc_dbb  = [0, 0, 0, 0]
    //   coeff_aau [g] = dc_daab = [0, 0, 0, 0]
    //   coeff_abu [g] = dc_dabb = [0, 0, 0, 0]
    //   coeff_bbu [g] = dc_dbbb = [0, 0, 0, 0]
    //   (third partials all zero in the no-runtime case)
    unsigned int g_off = 12u + 4u * g_slot;
    base[g_off + 1] = scale;  // coeff_u[g][1] = scale
    out_status[cell] = 0;
}

}  // extern "C"
"#;
}
166
/// Inputs describing one row for [`try_device_partition_cells`].
///
/// `a` and `b` become the `c0` / `c1` coefficients (times `scale`) of the
/// row's single trivial cell.  The device path only handles rows where both
/// `beta_h` and `beta_w` are `None`.
#[derive(Debug, Clone, Copy)]
pub struct PartitionCellsRowInputs<'a> {
    /// Scalar mapped to the trivial cell's `c0` (before scaling).
    pub a: f64,
    /// Scalar mapped to the trivial cell's `c1` (before scaling).
    pub b: f64,
    /// Optional `beta_h` slice; any `Some` sends the batch to the CPU path.
    pub beta_h: Option<&'a [f64]>,
    /// Optional `beta_w` slice; any `Some` sends the batch to the CPU path.
    pub beta_w: Option<&'a [f64]>,
}
175
176/// Output of [`try_device_partition_cells`]: per-row partition cells in the
177/// existing `DenestedPartitionCell` shape, one inner `Vec` per row.
178pub type PartitionCellsOutput = Vec<Vec<DenestedPartitionCell>>;
179
180/// GPU-shaped seam for `SurvivalMarginalSlopeFamily::denested_partition_cells`.
181///
182/// Returns:
183///
184/// * `Ok(None)` when the GPU path is unsupported (CUDA absent, or any row
185///   carries a `beta_h`/`beta_w` slice that would require a B-spline
186///   runtime traversal — those fall through to the existing CPU per-row
187///   path).
188/// * `Ok(Some(out))` when the device-shaped output is materialized.
189/// * `Err(_)` only when the request *is* supported but the driver failed.
190pub fn try_device_partition_cells(
191    rows: &[PartitionCellsRowInputs<'_>],
192) -> Result<Option<PartitionCellsOutput>, GpuError> {
193    if rows.is_empty() {
194        return Ok(Some(Vec::new()));
195    }
196    // Only the no-runtime baseline (no β slices on any row) is implemented
197    // device-side today.  Any row carrying a beta vector needs the
198    // knot-table / B-spline traversal which falls back to CPU.
199    let trivial = rows
200        .iter()
201        .all(|r| r.beta_h.is_none() && r.beta_w.is_none());
202    if !trivial {
203        return Ok(None);
204    }
205    device_dispatch::partition_cells_baseline(rows, 1.0)
206}
207
208/// Per-cell inputs for [`try_device_cell_primary_fixed_partials`].
209#[derive(Clone, Copy, Debug)]
210pub struct CellPrimaryFixedPartialsCellInputs {
211    pub score_span: LocalSpanCubic,
212    pub link_span: LocalSpanCubic,
213}
214
215/// Per-row inputs for [`try_device_cell_primary_fixed_partials`]: shared
216/// `(a, b)` scalars, the per-cell slice from this row, and the layout of
217/// the destination `FlexPrimarySlices` (`r = primary.total`, `g_slot =
218/// primary.g`).
219#[derive(Clone, Copy, Debug)]
220pub struct CellPrimaryFixedPartialsRowInputs<'a> {
221    pub cells: &'a [CellPrimaryFixedPartialsCellInputs],
222    pub layout: FlexPrimaryLayout,
223}
224
/// Output of [`try_device_cell_primary_fixed_partials`], nested as
/// `partials[row][cell]`.
///
/// Each innermost `Vec<f64>` holds `12 + 40·r` doubles laid out as described
/// for [`kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC`].
#[derive(Default, Debug, Clone)]
pub struct CellPrimaryFixedPartialsOutput {
    /// `partials[row][cell]` holds the flat per-cell partials.
    pub partials: Vec<Vec<Vec<f64>>>,
}
234
/// Destination slot layout for the fixed-partials kernel.
///
/// Mirrors the part of the host `FlexPrimarySlices` shape that the trivial
/// kernel needs:
/// * `r` is the total slot count (`primary.total`).  It fixes the per-cell
///   stride of `12 + 40·r` doubles.
/// * `g_slot` is the index of the slot the kernel fills (`primary.g`).  It is
///   expected to satisfy `g_slot < r`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FlexPrimaryLayout {
    /// Number of primary slots (`primary.total`).
    pub r: u32,
    /// Index of the `g` slot among the primary slots (`primary.g`).
    pub g_slot: u32,
}
245
246/// GPU-shaped seam for
247/// `SurvivalMarginalSlopeFamily::denested_cell_primary_fixed_partials`.
248///
249/// Returns `Ok(None)` when the input shape is outside the supported
250/// regime (any non-zero score/link span — i.e. a runtime that needs the
251/// full B-spline basis traversal — or no cells at all).
252///
253/// When the caller passes only cells whose `score_span` and `link_span`
254/// are the no-runtime zero spans, the kernel evaluates the closed-form
255/// trivial-cell partials on-device and returns the flat-packed layout.
256pub fn try_device_cell_primary_fixed_partials(
257    rows: &[CellPrimaryFixedPartialsRowInputs<'_>],
258) -> Result<Option<CellPrimaryFixedPartialsOutput>, GpuError> {
259    if rows.is_empty() {
260        return Ok(Some(CellPrimaryFixedPartialsOutput::default()));
261    }
262    // We can only run the device kernel when every cell's spans are the
263    // zero-span (no-runtime) baseline, because the trivial kernel doesn't
264    // carry the knot tables needed for a non-trivial basis traversal.
265    let trivial_spans = rows.iter().all(|row| {
266        row.cells
267            .iter()
268            .all(|cell| span_is_zero(cell.score_span) && span_is_zero(cell.link_span))
269    });
270    if !trivial_spans {
271        return Ok(None);
272    }
273    // The trivial kernel requires every row's layout to share the same
274    // `(r, g_slot)` so a single launch can emit a uniform per-cell stride.
275    // Differing layouts → decline (CPU fallback per row).
276    let layout0 = rows[0].layout;
277    if !rows
278        .iter()
279        .all(|r| r.layout.r == layout0.r && r.layout.g_slot == layout0.g_slot)
280    {
281        return Ok(None);
282    }
283    // If no cells at all, return an empty partials shape that matches
284    // `rows.len()` so the caller can index into the result.
285    let mut row_cell_counts: Vec<usize> = rows.iter().map(|r| r.cells.len()).collect();
286    let total_cells: usize = row_cell_counts.iter().copied().sum();
287    if total_cells == 0 {
288        let mut partials: Vec<Vec<Vec<f64>>> = Vec::with_capacity(rows.len());
289        for _ in 0..rows.len() {
290            partials.push(Vec::new());
291        }
292        return Ok(Some(CellPrimaryFixedPartialsOutput { partials }));
293    }
294    let flat = match device_dispatch::cell_primary_fixed_partials_baseline(layout0, total_cells) {
295        Ok(flat) => flat,
296        Err(_) => return Ok(None),
297    };
298    let per_cell = 12usize + 40usize * (layout0.r as usize);
299    let mut partials: Vec<Vec<Vec<f64>>> = Vec::with_capacity(rows.len());
300    let mut cursor = 0usize;
301    for n_cells in row_cell_counts.drain(..) {
302        let mut row_cells: Vec<Vec<f64>> = Vec::with_capacity(n_cells);
303        for _ in 0..n_cells {
304            row_cells.push(flat[cursor..cursor + per_cell].to_vec());
305            cursor += per_cell;
306        }
307        partials.push(row_cells);
308    }
309    assert_eq!(cursor, flat.len());
310    Ok(Some(CellPrimaryFixedPartialsOutput { partials }))
311}
312
313#[inline]
314fn span_is_zero(span: LocalSpanCubic) -> bool {
315    span.c0 == 0.0 && span.c1 == 0.0 && span.c2 == 0.0 && span.c3 == 0.0
316}
317
318/// Construct the trivial no-runtime partition cell for `(a, b, scale)`.
319/// Used as the byte-equivalent host shape for the kernel's per-row output
320/// (and as the reference the kernel reproduces).
321pub fn trivial_partition_cell(a: f64, b: f64, scale: f64) -> DenestedPartitionCell {
322    DenestedPartitionCell {
323        cell: DenestedCubicCell {
324            left: f64::NEG_INFINITY,
325            right: f64::INFINITY,
326            c0: a * scale,
327            c1: b * scale,
328            c2: 0.0,
329            c3: 0.0,
330        },
331        score_span: LocalSpanCubic {
332            left: 0.0,
333            right: 1.0,
334            c0: 0.0,
335            c1: 0.0,
336            c2: 0.0,
337            c3: 0.0,
338        },
339        link_span: LocalSpanCubic {
340            left: 0.0,
341            right: 1.0,
342            c0: 0.0,
343            c1: 0.0,
344            c2: 0.0,
345            c3: 0.0,
346        },
347        left_edge: crate::cubic_cell_kernel::PartitionEdge::Fixed(f64::NEG_INFINITY),
348        right_edge: crate::cubic_cell_kernel::PartitionEdge::Fixed(f64::INFINITY),
349    }
350}
351
352#[cfg(target_os = "linux")]
353mod device_dispatch {
354    use super::kernel_src::DENESTED_PARTITION_CELLS_KERNEL_SRC;
355    use super::{PartitionCellsOutput, PartitionCellsRowInputs, trivial_partition_cell};
356    use gam_gpu::device_cache::PtxModuleCache;
357    use gam_gpu::gpu_err as gam_gpu_err;
358    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
359    use gam_gpu::solver::context_and_stream;
360    use cudarc::driver::{LaunchConfig, PushKernelArg};
361
362    static PARTITION_PTX_CACHE: PtxModuleCache = PtxModuleCache::new();
363
364    const THREADS_PER_BLOCK: u32 = 128;
365
366    /// Launch the partition-cells kernel for the no-runtime baseline.
367    pub(super) fn partition_cells_baseline(
368        rows: &[PartitionCellsRowInputs<'_>],
369        scale: f64,
370    ) -> Result<Option<PartitionCellsOutput>, GpuError> {
371        let n = rows.len();
372        let n_u32 = u32::try_from(n)
373            .map_err(|_| gam_gpu_err!("partition_cells_baseline: n_rows={n} exceeds u32"))?;
374        let n_i32 = i32::try_from(n)
375            .map_err(|_| gam_gpu_err!("partition_cells_baseline: n_rows={n} exceeds i32"))?;
376        let (ctx, stream) = match context_and_stream() {
377            Ok(pair) => pair,
378            Err(_) => return Ok(None),
379        };
380        let module = PARTITION_PTX_CACHE.get_or_compile(
381            &ctx,
382            "survival_flex_prep::partition_cells",
383            DENESTED_PARTITION_CELLS_KERNEL_SRC,
384        )?;
385        let func = module
386            .load_function("denested_partition_cells_kernel")
387            .gpu_ctx("survival_flex_prep: load_function partition_cells")?;
388
389        let a_host: Vec<f64> = rows.iter().map(|r| r.a).collect();
390        let b_host: Vec<f64> = rows.iter().map(|r| r.b).collect();
391        let a_dev = stream
392            .clone_htod(&a_host)
393            .gpu_ctx("survival_flex_prep: upload a_per_row")?;
394        let b_dev = stream
395            .clone_htod(&b_host)
396            .gpu_ctx("survival_flex_prep: upload b_per_row")?;
397        let mut cells_dev = stream
398            .alloc_zeros::<f64>(n * 18)
399            .gpu_ctx("survival_flex_prep: alloc cells_flat")?;
400        let mut offsets_dev = stream
401            .alloc_zeros::<u32>(n + 1)
402            .gpu_ctx("survival_flex_prep: alloc row_offsets")?;
403        let mut status_dev = stream
404            .alloc_zeros::<u8>(n)
405            .gpu_ctx("survival_flex_prep: alloc status")?;
406
407        let cfg = LaunchConfig {
408            grid_dim: (n_u32.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
409            block_dim: (THREADS_PER_BLOCK, 1, 1),
410            shared_mem_bytes: 0,
411        };
412        // SAFETY: kernel signature is fixed in the source string above
413        // (n:i32, scale:f64, 2 const f64*, 1 mut f64*, 1 mut u32*, 1 mut u8*).
414        // All buffers are sized to the kernel's per-row stride, and each
415        // thread guards i >= n_rows.
416        unsafe {
417            let mut builder = stream.launch_builder(&func);
418            builder.arg(&n_i32);
419            builder.arg(&scale);
420            builder.arg(&a_dev);
421            builder.arg(&b_dev);
422            builder.arg(&mut cells_dev);
423            builder.arg(&mut offsets_dev);
424            builder.arg(&mut status_dev);
425            builder.launch(cfg)
426        }
427        .map(|_event_pair| ())
428        .gpu_ctx("survival_flex_prep: launch partition_cells")?;
429
430        let cells_host = stream
431            .clone_dtoh(&cells_dev)
432            .gpu_ctx("survival_flex_prep: download cells_flat")?;
433        let status_host = stream
434            .clone_dtoh(&status_dev)
435            .gpu_ctx("survival_flex_prep: download status")?;
436        for (i, st) in status_host.iter().enumerate() {
437            if *st != 0 {
438                return Err(gam_gpu_err!(
439                    "survival_flex_prep: row {i} status={st} from device kernel"
440                ));
441            }
442        }
443        assert_eq!(cells_host.len(), n * 18);
444        // Reconstruct per-row Vec<DenestedPartitionCell>.  The kernel writes
445        // exactly one cell per row in the trivial baseline; we reproduce
446        // the host trivial cell shape (using the kernel-written numerics
447        // for c0/c1) and ignore the device-written infinity sentinels in
448        // favour of the host-typed `f64::INFINITY` constants — both encode
449        // bit-identical infinities, so this is a presentation-only step.
450        let mut out: PartitionCellsOutput = Vec::with_capacity(n);
451        for i in 0..n {
452            let base = i * 18;
453            let c0 = cells_host[base + 2];
454            let c1 = cells_host[base + 3];
455            let mut cell = trivial_partition_cell(rows[i].a, rows[i].b, scale);
456            // Use the device-computed (a*scale, b*scale) so any future
457            // scale plumbing is faithfully reflected.
458            cell.cell.c0 = c0;
459            cell.cell.c1 = c1;
460            out.push(vec![cell]);
461        }
462        Ok(Some(out))
463    }
464
465    /// Launch the fixed-partials kernel for the no-runtime baseline.
466    ///
467    /// Returns the flat-packed `(12 + 40·r) · n_cells_total` doubles per
468    /// the layout described in
469    /// `kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC`.
470    /// The caller re-packs into `CellPrimaryFixedPartialsOutput` and the
471    /// family-side consumer rebuilds `DenestedCellPrimaryFixedPartials`
472    /// via `from_flat_slice`.
473    pub(super) fn cell_primary_fixed_partials_baseline(
474        layout: super::FlexPrimaryLayout,
475        n_cells_total: usize,
476    ) -> Result<Vec<f64>, GpuError> {
477        use super::kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC;
478        static FP_PTX_CACHE: PtxModuleCache = PtxModuleCache::new();
479
480        let n_i32 = i32::try_from(n_cells_total).map_err(|_| {
481            gam_gpu_err!(
482                "cell_primary_fixed_partials_baseline: n_cells={n_cells_total} exceeds i32"
483            )
484        })?;
485        let n_u32 = u32::try_from(n_cells_total).map_err(|_| {
486            gam_gpu_err!(
487                "cell_primary_fixed_partials_baseline: n_cells={n_cells_total} exceeds u32"
488            )
489        })?;
490        let (ctx, stream) = context_and_stream()
491            .map_err(|reason| gam_gpu::gpu_error::GpuError::DriverCallFailed { reason })?;
492        let module = FP_PTX_CACHE.get_or_compile(
493            &ctx,
494            "survival_flex_prep::cell_primary_fixed_partials",
495            DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC,
496        )?;
497        let func = module
498            .load_function("denested_cell_primary_fixed_partials_kernel")
499            .gpu_ctx("survival_flex_prep: load_function fixed_partials")?;
500
501        let per_cell = 12usize + 40usize * (layout.r as usize);
502        let scale = 1.0f64;
503        let mut out_dev = stream
504            .alloc_zeros::<f64>(n_cells_total * per_cell)
505            .gpu_ctx("survival_flex_prep: alloc fixed_partials")?;
506        let mut status_dev = stream
507            .alloc_zeros::<u8>(n_cells_total)
508            .gpu_ctx("survival_flex_prep: alloc fixed_partials status")?;
509        let cfg = LaunchConfig {
510            grid_dim: (n_u32.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
511            block_dim: (THREADS_PER_BLOCK, 1, 1),
512            shared_mem_bytes: 0,
513        };
514        // SAFETY: kernel signature matches (n:i32, r:u32, g_slot:u32,
515        // scale:f64, mut f64*, mut u8*).  Buffer sized to per-cell stride.
516        unsafe {
517            let mut builder = stream.launch_builder(&func);
518            builder.arg(&n_i32);
519            builder.arg(&layout.r);
520            builder.arg(&layout.g_slot);
521            builder.arg(&scale);
522            builder.arg(&mut out_dev);
523            builder.arg(&mut status_dev);
524            builder.launch(cfg)
525        }
526        .map(|_event_pair| ())
527        .gpu_ctx("survival_flex_prep: launch fixed_partials")?;
528        let out_host = stream
529            .clone_dtoh(&out_dev)
530            .gpu_ctx("survival_flex_prep: download fixed_partials")?;
531        let status_host = stream
532            .clone_dtoh(&status_dev)
533            .gpu_ctx("survival_flex_prep: download fixed_partials status")?;
534        for (i, st) in status_host.iter().enumerate() {
535            if *st != 0 {
536                return Err(gam_gpu_err!(
537                    "survival_flex_prep: fixed_partials cell {i} status={st}"
538                ));
539            }
540        }
541        Ok(out_host)
542    }
543}
544
#[cfg(not(target_os = "linux"))]
mod device_dispatch {
    //! Non-Linux stand-ins.  CUDA is only wired up on Linux, so both entry
    //! points decline and the callers stay on the CPU path.

    use super::{PartitionCellsOutput, PartitionCellsRowInputs};
    use gam_gpu::gpu_err as gam_gpu_err;
    use gam_gpu::gpu_error::GpuError;

    /// Always returns `Ok(None)`.  Logs, at trace level, the batch shape that
    /// would have been launched.
    pub(super) fn partition_cells_baseline(
        rows: &[PartitionCellsRowInputs<'_>],
        scale: f64,
    ) -> Result<Option<PartitionCellsOutput>, GpuError> {
        let n_rows = rows.len();
        // Surface the first row's scalars in the trace for diagnostics only.
        let first = rows.first().map(|row| (row.a, row.b));
        log::trace!(
            "survival_flex_prep::partition_cells_baseline declined on non-linux \
             (n_rows={}, scale={scale}, first_ab={first:?})",
            n_rows,
        );
        Ok(None)
    }

    /// Always returns `Err`, describing the launch that would have happened.
    /// There is no device to launch on.
    pub(super) fn cell_primary_fixed_partials_baseline(
        layout: super::FlexPrimaryLayout,
        n_cells_total: usize,
    ) -> Result<Vec<f64>, GpuError> {
        let super::FlexPrimaryLayout { r, g_slot } = layout;
        Err(gam_gpu_err!(
            "survival_flex_prep::cell_primary_fixed_partials_baseline: CUDA only supported on linux \
             (would have launched n_cells={n_cells_total}, r={}, g_slot={})",
            r,
            g_slot
        ))
    }
}
579
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_partition_inputs_short_circuit() {
        let out = try_device_partition_cells(&[]).expect("ok");
        assert!(out.is_some());
        assert!(out.unwrap().is_empty());
    }

    #[test]
    fn nonempty_partition_with_betas_declines() {
        let beta = [0.0_f64];
        let inputs = [PartitionCellsRowInputs {
            a: 0.0,
            b: 1.0,
            beta_h: Some(&beta),
            beta_w: None,
        }];
        let out = try_device_partition_cells(&inputs).expect("ok");
        // Must decline because beta_h is present (B-spline runtime traversal
        // is not implemented in the trivial kernel).
        assert!(out.is_none());
    }

    #[test]
    fn empty_fixed_partials_inputs_short_circuit() {
        let out = try_device_cell_primary_fixed_partials(&[]).expect("ok");
        assert!(out.is_some());
        assert!(out.unwrap().partials.is_empty());
    }

    #[test]
    fn empty_cells_per_row_returns_empty_partials() {
        let inputs = [CellPrimaryFixedPartialsRowInputs {
            cells: &[],
            layout: FlexPrimaryLayout { r: 4, g_slot: 3 },
        }];
        let out = try_device_cell_primary_fixed_partials(&inputs).expect("ok");
        let some = out.expect("Some when all rows have zero cells");
        assert_eq!(some.partials.len(), 1);
        assert!(some.partials[0].is_empty());
    }

    #[test]
    fn nonzero_span_declines_fixed_partials() {
        // A cell whose score span carries a non-zero cubic coefficient needs
        // the runtime basis traversal, which only the CPU path implements.
        let mut bumped = trivial_partition_cell(0.0, 0.0, 1.0).score_span;
        bumped.c2 = 0.5;
        let zero = trivial_partition_cell(0.0, 0.0, 1.0).link_span;
        let cells = [CellPrimaryFixedPartialsCellInputs {
            score_span: bumped,
            link_span: zero,
        }];
        let inputs = [CellPrimaryFixedPartialsRowInputs {
            cells: &cells,
            layout: FlexPrimaryLayout { r: 4, g_slot: 3 },
        }];
        let out = try_device_cell_primary_fixed_partials(&inputs).expect("ok");
        assert!(out.is_none());
    }

    #[test]
    fn mismatched_layouts_decline_fixed_partials() {
        // A single launch needs one per-cell stride, so differing layouts
        // must decline even before any device work.
        let inputs = [
            CellPrimaryFixedPartialsRowInputs {
                cells: &[],
                layout: FlexPrimaryLayout { r: 4, g_slot: 3 },
            },
            CellPrimaryFixedPartialsRowInputs {
                cells: &[],
                layout: FlexPrimaryLayout { r: 4, g_slot: 2 },
            },
        ];
        let out = try_device_cell_primary_fixed_partials(&inputs).expect("ok");
        assert!(out.is_none());
    }

    #[test]
    fn kernel_src_strings_are_nonempty() {
        assert!(!kernel_src::DENESTED_PARTITION_CELLS_KERNEL_SRC.is_empty());
        assert!(!kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC.is_empty());
    }

    #[test]
    fn trivial_partition_cell_matches_cpu_empty_split_branch() {
        // For a=2.5, b=-1.25, scale=1.0 the empty-split-points branch of
        // build_denested_partition_cells_with_tails produces a single
        // affine cell with c0=a, c1=b (post-scale).
        let cell = trivial_partition_cell(2.5, -1.25, 1.0);
        assert_eq!(cell.cell.c0, 2.5);
        assert_eq!(cell.cell.c1, -1.25);
        assert_eq!(cell.cell.c2, 0.0);
        assert_eq!(cell.cell.c3, 0.0);
        assert!(cell.cell.left.is_infinite() && cell.cell.left.is_sign_negative());
        assert!(cell.cell.right.is_infinite() && cell.cell.right.is_sign_positive());
    }

    #[test]
    fn trivial_partition_cell_scales_and_has_zero_spans() {
        let cell = trivial_partition_cell(2.0, -3.0, 0.5);
        assert_eq!(cell.cell.c0, 1.0);
        assert_eq!(cell.cell.c1, -1.5);
        assert!(span_is_zero(cell.score_span));
        assert!(span_is_zero(cell.link_span));
        assert_eq!((cell.score_span.left, cell.score_span.right), (0.0, 1.0));
        assert_eq!((cell.link_span.left, cell.link_span.right), (0.0, 1.0));
    }
}