1use crate::cubic_cell_kernel::{
43 DenestedCubicCell, DenestedPartitionCell, LocalSpanCubic,
44};
45use gam_gpu::gpu_error::GpuError;
46
/// CUDA C sources for the two baseline survival-flex prep kernels.
///
/// Each constant holds one complete translation unit with a single
/// `extern "C"` `__global__` entry point. The Linux `device_dispatch` module
/// hands these strings to its PTX module cache and looks the entry points up
/// by name, so the kernel names inside the strings must not change.
pub mod kernel_src {
    /// Source of `denested_partition_cells_kernel`.
    ///
    /// One thread per row. For row `i` the kernel writes 18 doubles at
    /// `out_cells_flat[18 * i ..]`:
    ///
    /// | offset  | contents                                             |
    /// |---------|------------------------------------------------------|
    /// | 0..=1   | cell bounds `-inf`, `+inf`                           |
    /// | 2..=5   | cell coefficients `a * scale`, `b * scale`, `0`, `0` |
    /// | 6..=11  | score span: `left = 0`, `right = 1`, zero cubic      |
    /// | 12..=17 | link span: `left = 0`, `right = 1`, zero cubic       |
    ///
    /// It also writes `out_row_offsets[i] = i` (the thread for the last row
    /// additionally writes the terminating `n_rows` entry) and
    /// `out_status[i] = 0`.
    pub const DENESTED_PARTITION_CELLS_KERNEL_SRC: &str = r#"
// f64 throughout (no --use_fast_math).

extern "C" {

__device__ __forceinline__ double pos_inf_f64() {
    // IEEE-754 +inf bit pattern: 0x7ff0000000000000.
    return __longlong_as_double((long long)0x7ff0000000000000LL);
}
__device__ __forceinline__ double neg_inf_f64() {
    // IEEE-754 -inf bit pattern: 0xfff0000000000000.
    return __longlong_as_double((long long)0xfff0000000000000LL);
}

__global__ void denested_partition_cells_kernel(
    int n_rows,
    double scale,
    const double *a_per_row,
    const double *b_per_row,
    double *out_cells_flat,         // 18 doubles per row (single cell)
    unsigned int *out_row_offsets,  // length n_rows + 1
    unsigned char *out_status       // length n_rows
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n_rows) return;
    double a = a_per_row[i];
    double b = b_per_row[i];
    double *cell = out_cells_flat + (long long)i * 18;
    // ── cell: (-inf, +inf, c0=a*scale, c1=b*scale, c2=0, c3=0) ──
    cell[0] = neg_inf_f64();
    cell[1] = pos_inf_f64();
    cell[2] = a * scale;
    cell[3] = b * scale;
    cell[4] = 0.0;
    cell[5] = 0.0;
    // ── score_span (zero cubic, left=0,right=1) ──
    cell[6] = 0.0; cell[7] = 1.0;
    cell[8] = 0.0; cell[9] = 0.0; cell[10] = 0.0; cell[11] = 0.0;
    // ── link_span (zero cubic, left=0,right=1) ──
    cell[12] = 0.0; cell[13] = 1.0;
    cell[14] = 0.0; cell[15] = 0.0; cell[16] = 0.0; cell[17] = 0.0;
    // ── row offset: one cell per row ──
    out_row_offsets[i] = (unsigned int)i;
    if (i == n_rows - 1) {
        out_row_offsets[n_rows] = (unsigned int)n_rows;
    }
    out_status[i] = 0;
}

} // extern "C"
"#;

    /// Source of `denested_cell_primary_fixed_partials_kernel`.
    ///
    /// One thread per cell. Each cell owns `12 + 40 * r` doubles of
    /// `out_partials_flat`. The kernel zeroes that run, stores `scale` at
    /// offset `0`, and stores `scale` at offset `12 + 4 * g_slot + 1`.
    ///
    /// The second store is only in range when `g_slot < r`. For any other
    /// layout the kernel writes nothing to `out_partials_flat` and reports
    /// `out_status[cell] = 1`; on success it reports `0`.
    pub const DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC: &str = r#"
// f64 throughout (no --use_fast_math).

extern "C" {

__global__ void denested_cell_primary_fixed_partials_kernel(
    int n_cells_total,
    unsigned int r,
    unsigned int g_slot,
    double scale,
    double *out_partials_flat,  // (12 + 40·r) doubles per cell
    unsigned char *out_status
) {
    int cell = blockIdx.x * blockDim.x + threadIdx.x;
    if (cell >= n_cells_total) return;
    // coeff_u holds r slots of 4 doubles starting at offset 12, so the g-slot
    // write below is only in range when g_slot < r. Otherwise it would land in
    // a different coefficient family or, for r == 0, past the end of the
    // buffer. Flag the cell and leave the output untouched.
    if (g_slot >= r) {
        out_status[cell] = 1;
        return;
    }
    unsigned int per_cell = 12u + 40u * r;
    double *base = out_partials_flat + (long long)cell * (long long)per_cell;
    // Zero the whole block (cheap; r is small).
    for (unsigned int s = 0; s < per_cell; ++s) {
        base[s] = 0.0;
    }
    // dc_da = [1, 0, 0, 0] · scale
    base[0] = scale;
    // dc_daa, dc_daaa already zero.
    // g-slot fills (offset = 12 + 4·g_slot within each per-cell run).
    //   coeff_u   [g] = dc_db   = [0, 1, 0, 0] · scale
    //   coeff_au  [g] = dc_dab  = [0, 0, 0, 0]
    //   coeff_bu  [g] = dc_dbb  = [0, 0, 0, 0]
    //   coeff_aau [g] = dc_daab = [0, 0, 0, 0]
    //   coeff_abu [g] = dc_dabb = [0, 0, 0, 0]
    //   coeff_bbu [g] = dc_dbbb = [0, 0, 0, 0]
    // (third partials all zero in the no-runtime case)
    unsigned int g_off = 12u + 4u * g_slot;
    base[g_off + 1] = scale;  // coeff_u[g][1] = scale
    out_status[cell] = 0;
}

} // extern "C"
"#;
}
166
/// Per-row inputs for [`try_device_partition_cells`].
#[derive(Clone, Copy, Debug)]
pub struct PartitionCellsRowInputs<'a> {
    /// Multiplied by the dispatch scale to give the row's cell coefficient `c0`.
    pub a: f64,
    /// Multiplied by the dispatch scale to give the row's cell coefficient `c1`.
    pub b: f64,
    /// Optional coefficient slice. If any row sets it, the device path declines.
    /// NOTE(review): the meaning of the values is not visible in this file.
    pub beta_h: Option<&'a [f64]>,
    /// Optional coefficient slice. If any row sets it, the device path declines.
    /// NOTE(review): the meaning of the values is not visible in this file.
    pub beta_w: Option<&'a [f64]>,
}
175
/// Partition cells grouped by input row: `output[row]` is that row's cell list.
///
/// The device baseline path in this file produces exactly one cell per row.
pub type PartitionCellsOutput = Vec<Vec<DenestedPartitionCell>>;
179
180pub fn try_device_partition_cells(
191 rows: &[PartitionCellsRowInputs<'_>],
192) -> Result<Option<PartitionCellsOutput>, GpuError> {
193 if rows.is_empty() {
194 return Ok(Some(Vec::new()));
195 }
196 let trivial = rows
200 .iter()
201 .all(|r| r.beta_h.is_none() && r.beta_w.is_none());
202 if !trivial {
203 return Ok(None);
204 }
205 device_dispatch::partition_cells_baseline(rows, 1.0)
206}
207
/// Per-cell inputs for [`try_device_cell_primary_fixed_partials`].
///
/// The device path only accepts cells whose two spans have all four cubic
/// coefficients (`c0`..`c3`) exactly equal to zero.
#[derive(Clone, Copy, Debug)]
pub struct CellPrimaryFixedPartialsCellInputs {
    /// Score span of the cell.
    pub score_span: LocalSpanCubic,
    /// Link span of the cell.
    pub link_span: LocalSpanCubic,
}
214
/// Per-row inputs for [`try_device_cell_primary_fixed_partials`].
#[derive(Clone, Copy, Debug)]
pub struct CellPrimaryFixedPartialsRowInputs<'a> {
    /// Cells belonging to this row; may be empty.
    pub cells: &'a [CellPrimaryFixedPartialsCellInputs],
    /// Layout of this row's per-cell partials block. Every row in one call
    /// must use the same `r` and `g_slot`, otherwise the device path declines.
    pub layout: FlexPrimaryLayout,
}
224
/// Result of [`try_device_cell_primary_fixed_partials`].
#[derive(Clone, Debug, Default)]
pub struct CellPrimaryFixedPartialsOutput {
    /// `partials[row][cell]` is that cell's flat block of `12 + 40 * r`
    /// doubles, in the order the device kernel wrote them.
    pub partials: Vec<Vec<Vec<f64>>>,
}
234
/// Shape parameters of one cell's flat partials block.
#[derive(Clone, Copy, Debug)]
pub struct FlexPrimaryLayout {
    /// Sets the per-cell block length to `12 + 40 * r` doubles.
    pub r: u32,
    /// Selects the slot that receives the scale: the kernel writes it at
    /// offset `12 + 4 * g_slot + 1` of the block.
    pub g_slot: u32,
}
245
246pub fn try_device_cell_primary_fixed_partials(
257 rows: &[CellPrimaryFixedPartialsRowInputs<'_>],
258) -> Result<Option<CellPrimaryFixedPartialsOutput>, GpuError> {
259 if rows.is_empty() {
260 return Ok(Some(CellPrimaryFixedPartialsOutput::default()));
261 }
262 let trivial_spans = rows.iter().all(|row| {
266 row.cells
267 .iter()
268 .all(|cell| span_is_zero(cell.score_span) && span_is_zero(cell.link_span))
269 });
270 if !trivial_spans {
271 return Ok(None);
272 }
273 let layout0 = rows[0].layout;
277 if !rows
278 .iter()
279 .all(|r| r.layout.r == layout0.r && r.layout.g_slot == layout0.g_slot)
280 {
281 return Ok(None);
282 }
283 let mut row_cell_counts: Vec<usize> = rows.iter().map(|r| r.cells.len()).collect();
286 let total_cells: usize = row_cell_counts.iter().copied().sum();
287 if total_cells == 0 {
288 let mut partials: Vec<Vec<Vec<f64>>> = Vec::with_capacity(rows.len());
289 for _ in 0..rows.len() {
290 partials.push(Vec::new());
291 }
292 return Ok(Some(CellPrimaryFixedPartialsOutput { partials }));
293 }
294 let flat = match device_dispatch::cell_primary_fixed_partials_baseline(layout0, total_cells) {
295 Ok(flat) => flat,
296 Err(_) => return Ok(None),
297 };
298 let per_cell = 12usize + 40usize * (layout0.r as usize);
299 let mut partials: Vec<Vec<Vec<f64>>> = Vec::with_capacity(rows.len());
300 let mut cursor = 0usize;
301 for n_cells in row_cell_counts.drain(..) {
302 let mut row_cells: Vec<Vec<f64>> = Vec::with_capacity(n_cells);
303 for _ in 0..n_cells {
304 row_cells.push(flat[cursor..cursor + per_cell].to_vec());
305 cursor += per_cell;
306 }
307 partials.push(row_cells);
308 }
309 assert_eq!(cursor, flat.len());
310 Ok(Some(CellPrimaryFixedPartialsOutput { partials }))
311}
312
313#[inline]
314fn span_is_zero(span: LocalSpanCubic) -> bool {
315 span.c0 == 0.0 && span.c1 == 0.0 && span.c2 == 0.0 && span.c3 == 0.0
316}
317
318pub fn trivial_partition_cell(a: f64, b: f64, scale: f64) -> DenestedPartitionCell {
322 DenestedPartitionCell {
323 cell: DenestedCubicCell {
324 left: f64::NEG_INFINITY,
325 right: f64::INFINITY,
326 c0: a * scale,
327 c1: b * scale,
328 c2: 0.0,
329 c3: 0.0,
330 },
331 score_span: LocalSpanCubic {
332 left: 0.0,
333 right: 1.0,
334 c0: 0.0,
335 c1: 0.0,
336 c2: 0.0,
337 c3: 0.0,
338 },
339 link_span: LocalSpanCubic {
340 left: 0.0,
341 right: 1.0,
342 c0: 0.0,
343 c1: 0.0,
344 c2: 0.0,
345 c3: 0.0,
346 },
347 left_edge: crate::cubic_cell_kernel::PartitionEdge::Fixed(f64::NEG_INFINITY),
348 right_edge: crate::cubic_cell_kernel::PartitionEdge::Fixed(f64::INFINITY),
349 }
350}
351
352#[cfg(target_os = "linux")]
353mod device_dispatch {
354 use super::kernel_src::DENESTED_PARTITION_CELLS_KERNEL_SRC;
355 use super::{PartitionCellsOutput, PartitionCellsRowInputs, trivial_partition_cell};
356 use gam_gpu::device_cache::PtxModuleCache;
357 use gam_gpu::gpu_err as gam_gpu_err;
358 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
359 use gam_gpu::solver::context_and_stream;
360 use cudarc::driver::{LaunchConfig, PushKernelArg};
361
362 static PARTITION_PTX_CACHE: PtxModuleCache = PtxModuleCache::new();
363
364 const THREADS_PER_BLOCK: u32 = 128;
365
366 pub(super) fn partition_cells_baseline(
368 rows: &[PartitionCellsRowInputs<'_>],
369 scale: f64,
370 ) -> Result<Option<PartitionCellsOutput>, GpuError> {
371 let n = rows.len();
372 let n_u32 = u32::try_from(n)
373 .map_err(|_| gam_gpu_err!("partition_cells_baseline: n_rows={n} exceeds u32"))?;
374 let n_i32 = i32::try_from(n)
375 .map_err(|_| gam_gpu_err!("partition_cells_baseline: n_rows={n} exceeds i32"))?;
376 let (ctx, stream) = match context_and_stream() {
377 Ok(pair) => pair,
378 Err(_) => return Ok(None),
379 };
380 let module = PARTITION_PTX_CACHE.get_or_compile(
381 &ctx,
382 "survival_flex_prep::partition_cells",
383 DENESTED_PARTITION_CELLS_KERNEL_SRC,
384 )?;
385 let func = module
386 .load_function("denested_partition_cells_kernel")
387 .gpu_ctx("survival_flex_prep: load_function partition_cells")?;
388
389 let a_host: Vec<f64> = rows.iter().map(|r| r.a).collect();
390 let b_host: Vec<f64> = rows.iter().map(|r| r.b).collect();
391 let a_dev = stream
392 .clone_htod(&a_host)
393 .gpu_ctx("survival_flex_prep: upload a_per_row")?;
394 let b_dev = stream
395 .clone_htod(&b_host)
396 .gpu_ctx("survival_flex_prep: upload b_per_row")?;
397 let mut cells_dev = stream
398 .alloc_zeros::<f64>(n * 18)
399 .gpu_ctx("survival_flex_prep: alloc cells_flat")?;
400 let mut offsets_dev = stream
401 .alloc_zeros::<u32>(n + 1)
402 .gpu_ctx("survival_flex_prep: alloc row_offsets")?;
403 let mut status_dev = stream
404 .alloc_zeros::<u8>(n)
405 .gpu_ctx("survival_flex_prep: alloc status")?;
406
407 let cfg = LaunchConfig {
408 grid_dim: (n_u32.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
409 block_dim: (THREADS_PER_BLOCK, 1, 1),
410 shared_mem_bytes: 0,
411 };
412 unsafe {
417 let mut builder = stream.launch_builder(&func);
418 builder.arg(&n_i32);
419 builder.arg(&scale);
420 builder.arg(&a_dev);
421 builder.arg(&b_dev);
422 builder.arg(&mut cells_dev);
423 builder.arg(&mut offsets_dev);
424 builder.arg(&mut status_dev);
425 builder.launch(cfg)
426 }
427 .map(|_event_pair| ())
428 .gpu_ctx("survival_flex_prep: launch partition_cells")?;
429
430 let cells_host = stream
431 .clone_dtoh(&cells_dev)
432 .gpu_ctx("survival_flex_prep: download cells_flat")?;
433 let status_host = stream
434 .clone_dtoh(&status_dev)
435 .gpu_ctx("survival_flex_prep: download status")?;
436 for (i, st) in status_host.iter().enumerate() {
437 if *st != 0 {
438 return Err(gam_gpu_err!(
439 "survival_flex_prep: row {i} status={st} from device kernel"
440 ));
441 }
442 }
443 assert_eq!(cells_host.len(), n * 18);
444 let mut out: PartitionCellsOutput = Vec::with_capacity(n);
451 for i in 0..n {
452 let base = i * 18;
453 let c0 = cells_host[base + 2];
454 let c1 = cells_host[base + 3];
455 let mut cell = trivial_partition_cell(rows[i].a, rows[i].b, scale);
456 cell.cell.c0 = c0;
459 cell.cell.c1 = c1;
460 out.push(vec![cell]);
461 }
462 Ok(Some(out))
463 }
464
465 pub(super) fn cell_primary_fixed_partials_baseline(
474 layout: super::FlexPrimaryLayout,
475 n_cells_total: usize,
476 ) -> Result<Vec<f64>, GpuError> {
477 use super::kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC;
478 static FP_PTX_CACHE: PtxModuleCache = PtxModuleCache::new();
479
480 let n_i32 = i32::try_from(n_cells_total).map_err(|_| {
481 gam_gpu_err!(
482 "cell_primary_fixed_partials_baseline: n_cells={n_cells_total} exceeds i32"
483 )
484 })?;
485 let n_u32 = u32::try_from(n_cells_total).map_err(|_| {
486 gam_gpu_err!(
487 "cell_primary_fixed_partials_baseline: n_cells={n_cells_total} exceeds u32"
488 )
489 })?;
490 let (ctx, stream) = context_and_stream()
491 .map_err(|reason| gam_gpu::gpu_error::GpuError::DriverCallFailed { reason })?;
492 let module = FP_PTX_CACHE.get_or_compile(
493 &ctx,
494 "survival_flex_prep::cell_primary_fixed_partials",
495 DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC,
496 )?;
497 let func = module
498 .load_function("denested_cell_primary_fixed_partials_kernel")
499 .gpu_ctx("survival_flex_prep: load_function fixed_partials")?;
500
501 let per_cell = 12usize + 40usize * (layout.r as usize);
502 let scale = 1.0f64;
503 let mut out_dev = stream
504 .alloc_zeros::<f64>(n_cells_total * per_cell)
505 .gpu_ctx("survival_flex_prep: alloc fixed_partials")?;
506 let mut status_dev = stream
507 .alloc_zeros::<u8>(n_cells_total)
508 .gpu_ctx("survival_flex_prep: alloc fixed_partials status")?;
509 let cfg = LaunchConfig {
510 grid_dim: (n_u32.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
511 block_dim: (THREADS_PER_BLOCK, 1, 1),
512 shared_mem_bytes: 0,
513 };
514 unsafe {
517 let mut builder = stream.launch_builder(&func);
518 builder.arg(&n_i32);
519 builder.arg(&layout.r);
520 builder.arg(&layout.g_slot);
521 builder.arg(&scale);
522 builder.arg(&mut out_dev);
523 builder.arg(&mut status_dev);
524 builder.launch(cfg)
525 }
526 .map(|_event_pair| ())
527 .gpu_ctx("survival_flex_prep: launch fixed_partials")?;
528 let out_host = stream
529 .clone_dtoh(&out_dev)
530 .gpu_ctx("survival_flex_prep: download fixed_partials")?;
531 let status_host = stream
532 .clone_dtoh(&status_dev)
533 .gpu_ctx("survival_flex_prep: download fixed_partials status")?;
534 for (i, st) in status_host.iter().enumerate() {
535 if *st != 0 {
536 return Err(gam_gpu_err!(
537 "survival_flex_prep: fixed_partials cell {i} status={st}"
538 ));
539 }
540 }
541 Ok(out_host)
542 }
543}
544
/// Host-only stand-in for the CUDA dispatch on non-Linux targets: nothing is
/// ever launched.
#[cfg(not(target_os = "linux"))]
mod device_dispatch {
    use super::{PartitionCellsOutput, PartitionCellsRowInputs};
    use gam_gpu::gpu_err as gam_gpu_err;
    use gam_gpu::gpu_error::GpuError;

    /// Always declines with `Ok(None)` after tracing what was requested.
    pub(super) fn partition_cells_baseline(
        rows: &[PartitionCellsRowInputs<'_>],
        scale: f64,
    ) -> Result<Option<PartitionCellsOutput>, GpuError> {
        // Only the first row's (a, b) pair is included in the trace line.
        let first = match rows {
            [head, ..] => Some((head.a, head.b)),
            [] => None,
        };
        log::trace!(
            "survival_flex_prep::partition_cells_baseline declined on non-linux \
             (n_rows={}, scale={scale}, first_ab={first:?})",
            rows.len(),
        );
        Ok(None)
    }

    /// Always fails with an error describing the launch that was requested.
    pub(super) fn cell_primary_fixed_partials_baseline(
        layout: super::FlexPrimaryLayout,
        n_cells_total: usize,
    ) -> Result<Vec<f64>, GpuError> {
        let super::FlexPrimaryLayout { r, g_slot } = layout;
        Err(gam_gpu_err!(
            "survival_flex_prep::cell_primary_fixed_partials_baseline: CUDA only supported on linux \
             (would have launched n_cells={n_cells_total}, r={}, g_slot={})",
            r,
            g_slot
        ))
    }
}
579
/// Host-side tests: every case here is answered before any device is needed.
#[cfg(test)]
mod tests {
    use super::*;

    /// No rows: the partition entry point answers with an empty output.
    #[test]
    fn empty_partition_inputs_short_circuit() {
        let cells = try_device_partition_cells(&[]).expect("ok");
        assert!(matches!(cells, Some(ref rows) if rows.is_empty()));
    }

    /// A row carrying a beta slice makes the device path decline.
    #[test]
    fn nonempty_partition_with_betas_declines() {
        let beta = [0.0_f64];
        let row = PartitionCellsRowInputs {
            a: 0.0,
            b: 1.0,
            beta_h: Some(&beta),
            beta_w: None,
        };
        let outcome = try_device_partition_cells(&[row]).expect("ok");
        assert!(outcome.is_none());
    }

    /// No rows: the fixed-partials entry point answers with empty partials.
    #[test]
    fn empty_fixed_partials_inputs_short_circuit() {
        let outcome = try_device_cell_primary_fixed_partials(&[]).expect("ok");
        assert!(matches!(outcome, Some(ref out) if out.partials.is_empty()));
    }

    /// Rows without cells yield one empty cell list per row, with no launch.
    #[test]
    fn empty_cells_per_row_returns_empty_partials() {
        let layout = FlexPrimaryLayout { r: 4, g_slot: 3 };
        let row = CellPrimaryFixedPartialsRowInputs { cells: &[], layout };
        let outcome = try_device_cell_primary_fixed_partials(&[row]).expect("ok");
        let some = outcome.expect("Some when all rows have zero cells");
        assert_eq!(some.partials.len(), 1);
        assert!(some.partials[0].is_empty());
    }

    /// Both embedded kernel sources are present.
    #[test]
    fn kernel_src_strings_are_nonempty() {
        let sources = [
            kernel_src::DENESTED_PARTITION_CELLS_KERNEL_SRC,
            kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC,
        ];
        for src in sources.iter() {
            assert!(!src.is_empty());
        }
    }

    /// With unit scale the trivial cell keeps `a`/`b` and spans the real line.
    #[test]
    fn trivial_partition_cell_matches_cpu_empty_split_branch() {
        let built = trivial_partition_cell(2.5, -1.25, 1.0);
        let inner = built.cell;
        assert_eq!(inner.c0, 2.5);
        assert_eq!(inner.c1, -1.25);
        assert_eq!(inner.c2, 0.0);
        assert_eq!(inner.c3, 0.0);
        assert_eq!(inner.left, f64::NEG_INFINITY);
        assert_eq!(inner.right, f64::INFINITY);
    }
}