use crate::families::cubic_cell_kernel::{
DenestedCubicCell, DenestedPartitionCell, LocalSpanCubic,
};
use crate::gpu::error::GpuError;
/// CUDA C source strings for the two kernels launched by `device_dispatch`.
///
/// Each string is passed unchanged to the PTX module cache together with a
/// cache key, so editing anything inside a raw string (including its embedded
/// `//` comments) changes the text that is compiled for the device.
pub mod kernel_src {
/// Source of `denested_partition_cells_kernel`.
///
/// One thread handles one row `i < n_rows` and writes:
/// * an 18-double record at `out_cells_flat + 18 * i`:
///   `[-inf, +inf, a*scale, b*scale, 0, 0]` for the cell, followed by two
///   six-double spans `[0, 1, 0, 0, 0, 0]` (score span, then link span);
/// * `out_row_offsets[i] = i`, with the thread for the last row also writing
///   `out_row_offsets[n_rows] = n_rows`;
/// * `out_status[i] = 0`.
pub const DENESTED_PARTITION_CELLS_KERNEL_SRC: &str = r#"
// f64 throughout (no --use_fast_math).
extern "C" {
__device__ __forceinline__ double pos_inf_f64() {
// IEEE-754 +inf bit pattern: 0x7ff0000000000000.
return __longlong_as_double((long long)0x7ff0000000000000LL);
}
__device__ __forceinline__ double neg_inf_f64() {
// IEEE-754 -inf bit pattern: 0xfff0000000000000.
return __longlong_as_double((long long)0xfff0000000000000LL);
}
__global__ void denested_partition_cells_kernel(
int n_rows,
double scale,
const double *a_per_row,
const double *b_per_row,
double *out_cells_flat, // 18 doubles per row (single cell)
unsigned int *out_row_offsets, // length n_rows + 1
unsigned char *out_status // length n_rows
) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= n_rows) return;
double a = a_per_row[i];
double b = b_per_row[i];
double *cell = out_cells_flat + (long long)i * 18;
// ── cell: (-inf, +inf, c0=a*scale, c1=b*scale, c2=0, c3=0) ──
cell[0] = neg_inf_f64();
cell[1] = pos_inf_f64();
cell[2] = a * scale;
cell[3] = b * scale;
cell[4] = 0.0;
cell[5] = 0.0;
// ── score_span (zero cubic, left=0,right=1) ──
cell[6] = 0.0; cell[7] = 1.0;
cell[8] = 0.0; cell[9] = 0.0; cell[10] = 0.0; cell[11] = 0.0;
// ── link_span (zero cubic, left=0,right=1) ──
cell[12] = 0.0; cell[13] = 1.0;
cell[14] = 0.0; cell[15] = 0.0; cell[16] = 0.0; cell[17] = 0.0;
// ── row offset: one cell per row ──
out_row_offsets[i] = (unsigned int)i;
if (i == n_rows - 1) {
out_row_offsets[n_rows] = (unsigned int)n_rows;
}
out_status[i] = 0;
}
} // extern "C"
"#;
/// Source of `denested_cell_primary_fixed_partials_kernel`.
///
/// One thread handles one cell `cell < n_cells_total`. It zeroes that cell's
/// block of `12 + 40 * r` doubles in `out_partials_flat`, then writes `scale`
/// at block index `0` and at block index `12 + 4 * g_slot + 1`, and finally
/// sets `out_status[cell] = 0`.
///
/// NOTE: the kernel performs no bounds check on `g_slot`, and the block length
/// is computed in 32-bit unsigned arithmetic. The write at
/// `12 + 4 * g_slot + 1` stays inside the block only when that index is
/// smaller than `12 + 40 * r`.
pub const DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC: &str = r#"
// f64 throughout (no --use_fast_math).
extern "C" {
__global__ void denested_cell_primary_fixed_partials_kernel(
int n_cells_total,
unsigned int r,
unsigned int g_slot,
double scale,
double *out_partials_flat, // (12 + 40·r) doubles per cell
unsigned char *out_status
) {
int cell = blockIdx.x * blockDim.x + threadIdx.x;
if (cell >= n_cells_total) return;
unsigned int per_cell = 12u + 40u * r;
double *base = out_partials_flat + (long long)cell * (long long)per_cell;
// Zero the whole block (cheap; r is small).
for (unsigned int s = 0; s < per_cell; ++s) {
base[s] = 0.0;
}
// dc_da = [1, 0, 0, 0] · scale
base[0] = scale;
// dc_daa, dc_daaa already zero.
// g-slot fills (offset = 12 + 4·g_slot within each per-cell run).
// coeff_u [g] = dc_db = [0, 1, 0, 0] · scale
// coeff_au [g] = dc_dab = [0, 0, 0, 0]
// coeff_bu [g] = dc_dbb = [0, 0, 0, 0]
// coeff_aau [g] = dc_daab = [0, 0, 0, 0]
// coeff_abu [g] = dc_dabb = [0, 0, 0, 0]
// coeff_bbu [g] = dc_dbbb = [0, 0, 0, 0]
// (third partials all zero in the no-runtime case)
unsigned int g_off = 12u + 4u * g_slot;
base[g_off + 1] = scale; // coeff_u[g][1] = scale
out_status[cell] = 0;
}
} // extern "C"
"#;
}
/// Per-row inputs accepted by [`try_device_partition_cells`].
#[derive(Clone, Copy, Debug)]
pub struct PartitionCellsRowInputs<'a> {
/// Row value uploaded as `a_per_row`; the kernel stores `a * scale` as the
/// cell's `c0`.
pub a: f64,
/// Row value uploaded as `b_per_row`; the kernel stores `b * scale` as the
/// cell's `c1`.
pub b: f64,
/// Only its presence is inspected in this file: if any row has `Some`, the
/// device path declines. The slice contents are never read here.
pub beta_h: Option<&'a [f64]>,
/// Same treatment as `beta_h`: presence alone makes the device path decline.
pub beta_w: Option<&'a [f64]>,
}
/// Partition cells grouped by input row: the outer index is the row, the inner
/// vector holds that row's cells (the device path in this file emits exactly
/// one cell per row).
pub type PartitionCellsOutput = Vec<Vec<DenestedPartitionCell>>;
/// Tries to build the partition cells for `rows` on the device.
///
/// * Empty input yields `Ok(Some(vec![]))` without touching the device.
/// * If any row carries a `beta_h` or `beta_w`, the device path does not
///   apply and `Ok(None)` is returned so the caller can use another path.
/// * Otherwise the work is delegated to
///   `device_dispatch::partition_cells_baseline` with a scale of `1.0`, and
///   its result (including any error) is returned unchanged.
pub fn try_device_partition_cells(
    rows: &[PartitionCellsRowInputs<'_>],
) -> Result<Option<PartitionCellsOutput>, GpuError> {
    if rows.is_empty() {
        return Ok(Some(Vec::new()));
    }

    let any_row_has_beta = rows
        .iter()
        .any(|row| row.beta_h.is_some() || row.beta_w.is_some());

    if any_row_has_beta {
        Ok(None)
    } else {
        device_dispatch::partition_cells_baseline(rows, 1.0)
    }
}
/// Per-cell inputs for [`try_device_cell_primary_fixed_partials`].
#[derive(Clone, Copy, Debug)]
pub struct CellPrimaryFixedPartialsCellInputs {
/// Score span; the device path is only taken when its `c0..c3` are all zero.
pub score_span: LocalSpanCubic,
/// Link span; the device path is only taken when its `c0..c3` are all zero.
pub link_span: LocalSpanCubic,
/// Not read by any code visible in this file.
pub z_basis: f64,
/// Not read by any code visible in this file.
pub u_basis: f64,
}
/// Per-row inputs for [`try_device_cell_primary_fixed_partials`].
#[derive(Clone, Copy, Debug)]
pub struct CellPrimaryFixedPartialsRowInputs<'a> {
/// Not read by the device path visible in this file.
pub a: f64,
/// Not read by the device path visible in this file.
pub b: f64,
/// Cells of this row; their count determines how many partial blocks the row
/// receives in the output.
pub cells: &'a [CellPrimaryFixedPartialsCellInputs],
/// Layout of the per-cell partial block. All rows must share the same `r`
/// and `g_slot`, otherwise the device path declines.
pub layout: FlexPrimaryLayout,
}
/// Result of [`try_device_cell_primary_fixed_partials`].
#[derive(Clone, Debug, Default)]
pub struct CellPrimaryFixedPartialsOutput {
/// Indexed as `partials[row][cell]`; each innermost vector is one cell's flat
/// block of `12 + 40 * r` doubles copied from the device buffer.
pub partials: Vec<Vec<Vec<f64>>>,
}
/// Shape parameters of the flat per-cell partial block.
#[derive(Clone, Copy, Debug)]
pub struct FlexPrimaryLayout {
/// Determines the per-cell block length, which is `12 + 40 * r` doubles.
pub r: u32,
/// Selects where the kernel writes its second `scale` entry: block index
/// `12 + 4 * g_slot + 1`.
pub g_slot: u32,
}
/// Tries to compute the fixed primary partials for every cell of `rows` on
/// the device.
///
/// Outcomes:
/// * no rows: `Ok(Some(default output))`;
/// * any cell whose score or link span has a non-zero coefficient:
///   `Ok(None)`;
/// * rows that do not all share the first row's `r` and `g_slot`: `Ok(None)`;
/// * rows present but zero cells overall: `Ok(Some(..))` with one empty
///   vector per row;
/// * device dispatch fails for any reason: `Ok(None)` (the error is dropped);
/// * otherwise `Ok(Some(..))` where `partials[row][cell]` is that cell's block
///   of `12 + 40 * r` doubles, taken from the flat device buffer in row-major
///   cell order.
///
/// # Panics
///
/// Panics if the flat buffer returned by the device dispatch does not hold
/// exactly `total_cells * (12 + 40 * r)` values.
pub fn try_device_cell_primary_fixed_partials(
    rows: &[CellPrimaryFixedPartialsRowInputs<'_>],
) -> Result<Option<CellPrimaryFixedPartialsOutput>, GpuError> {
    let Some(first_row) = rows.first() else {
        return Ok(Some(CellPrimaryFixedPartialsOutput::default()));
    };

    // The device kernel only covers cells whose spans are identically zero.
    let has_nonzero_span = rows
        .iter()
        .flat_map(|row| row.cells.iter())
        .any(|cell| !(span_is_zero(cell.score_span) && span_is_zero(cell.link_span)));
    if has_nonzero_span {
        return Ok(None);
    }

    // A single launch uses one layout, so every row must agree with the first.
    let layout0 = first_row.layout;
    let layouts_differ = rows
        .iter()
        .any(|row| row.layout.r != layout0.r || row.layout.g_slot != layout0.g_slot);
    if layouts_differ {
        return Ok(None);
    }

    let total_cells: usize = rows.iter().map(|row| row.cells.len()).sum();
    if total_cells == 0 {
        let partials: Vec<Vec<Vec<f64>>> = rows.iter().map(|_| Vec::new()).collect();
        return Ok(Some(CellPrimaryFixedPartialsOutput { partials }));
    }

    let Ok(flat) = device_dispatch::cell_primary_fixed_partials_baseline(layout0, total_cells)
    else {
        return Ok(None);
    };

    // Slice the flat buffer back into per-row, per-cell blocks.
    let per_cell = 12usize + 40usize * (layout0.r as usize);
    let mut cursor = 0usize;
    let mut partials: Vec<Vec<Vec<f64>>> = Vec::with_capacity(rows.len());
    for row in rows {
        let row_cells: Vec<Vec<f64>> = (0..row.cells.len())
            .map(|_| {
                let start = cursor;
                cursor += per_cell;
                flat[start..cursor].to_vec()
            })
            .collect();
        partials.push(row_cells);
    }
    assert_eq!(cursor, flat.len());

    Ok(Some(CellPrimaryFixedPartialsOutput { partials }))
}
/// Returns `true` when all four cubic coefficients of `span` compare equal to
/// zero. The span's `left`/`right` bounds are not considered.
#[inline]
fn span_is_zero(span: LocalSpanCubic) -> bool {
    [span.c0, span.c1, span.c2, span.c3]
        .iter()
        .all(|&coeff| coeff == 0.0)
}
/// Builds the single partition cell used when no splitting is required.
///
/// The cell spans `(-inf, +inf)` with coefficients
/// `c0 = a * scale`, `c1 = b * scale`, `c2 = c3 = 0`. Both the score span and
/// the link span are the zero cubic on `[0, 1]`.
pub fn trivial_partition_cell(a: f64, b: f64, scale: f64) -> DenestedPartitionCell {
    // Zero cubic on the unit interval; the same value serves both spans.
    let zero_unit_span = LocalSpanCubic {
        left: 0.0,
        right: 1.0,
        c0: 0.0,
        c1: 0.0,
        c2: 0.0,
        c3: 0.0,
    };

    let cell = DenestedCubicCell {
        left: f64::NEG_INFINITY,
        right: f64::INFINITY,
        c0: a * scale,
        c1: b * scale,
        c2: 0.0,
        c3: 0.0,
    };

    DenestedPartitionCell {
        cell,
        score_span: zero_unit_span,
        link_span: zero_unit_span,
    }
}
#[cfg(target_os = "linux")]
mod device_dispatch {
use super::kernel_src::DENESTED_PARTITION_CELLS_KERNEL_SRC;
use super::{PartitionCellsOutput, PartitionCellsRowInputs, trivial_partition_cell};
use crate::gpu::common::PtxModuleCache;
use crate::gpu::error::{GpuError, GpuResultExt};
use crate::gpu::solver::context_and_stream;
use cudarc::driver::{LaunchConfig, PushKernelArg};
/// Module cache consulted by `partition_cells_baseline` via `get_or_compile`
/// for the partition-cells kernel source.
static PARTITION_PTX_CACHE: PtxModuleCache = PtxModuleCache::new();
/// Block size (x dimension) for both kernel launches in this module; the grid
/// is sized as `ceil(n / THREADS_PER_BLOCK)`, with a minimum of one block.
const THREADS_PER_BLOCK: u32 = 128;
/// Runs `denested_partition_cells_kernel` over `rows` and decodes the result.
///
/// Each row's `a` and `b` are uploaded, the kernel is launched with one
/// thread per row, and the flat 18-double-per-row cell buffer plus the
/// per-row status bytes are downloaded. The returned cells are built with
/// `trivial_partition_cell` and then have `c0`/`c1` replaced by the values
/// the device wrote at record offsets 2 and 3.
///
/// Returns `Ok(None)` when no CUDA context/stream can be obtained.
///
/// # Errors
///
/// Fails when the row count does not fit in `u32`/`i32`, when any
/// compile/load/upload/alloc/launch/download step fails, or when the device
/// reports a non-zero status for a row.
///
/// # Panics
///
/// Panics if the downloaded cell buffer does not hold exactly `18 * n`
/// values.
pub(super) fn partition_cells_baseline(
    rows: &[PartitionCellsRowInputs<'_>],
    scale: f64,
) -> Result<Option<PartitionCellsOutput>, GpuError> {
    // Length of one row's record in the flat cell buffer.
    const DOUBLES_PER_ROW: usize = 18;

    let n = rows.len();
    let n_u32 = u32::try_from(n)
        .map_err(|_| crate::gpu_err!("partition_cells_baseline: n_rows={n} exceeds u32"))?;
    let n_i32 = i32::try_from(n)
        .map_err(|_| crate::gpu_err!("partition_cells_baseline: n_rows={n} exceeds i32"))?;

    // No usable device is a decline, not an error.
    let Ok((ctx, stream)) = context_and_stream() else {
        return Ok(None);
    };

    let module = PARTITION_PTX_CACHE.get_or_compile(
        &ctx,
        "survival_flex_prep::partition_cells",
        DENESTED_PARTITION_CELLS_KERNEL_SRC,
    )?;
    let func = module
        .load_function("denested_partition_cells_kernel")
        .gpu_ctx("survival_flex_prep: load_function partition_cells")?;

    // Split the rows into the two per-row input columns the kernel expects.
    let (a_host, b_host): (Vec<f64>, Vec<f64>) = rows.iter().map(|row| (row.a, row.b)).unzip();
    let a_dev = stream
        .clone_htod(&a_host)
        .gpu_ctx("survival_flex_prep: upload a_per_row")?;
    let b_dev = stream
        .clone_htod(&b_host)
        .gpu_ctx("survival_flex_prep: upload b_per_row")?;

    let mut cells_dev = stream
        .alloc_zeros::<f64>(n * DOUBLES_PER_ROW)
        .gpu_ctx("survival_flex_prep: alloc cells_flat")?;
    let mut offsets_dev = stream
        .alloc_zeros::<u32>(n + 1)
        .gpu_ctx("survival_flex_prep: alloc row_offsets")?;
    let mut status_dev = stream
        .alloc_zeros::<u8>(n)
        .gpu_ctx("survival_flex_prep: alloc status")?;

    let cfg = LaunchConfig {
        grid_dim: (n_u32.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };

    // SAFETY: the arguments are pushed in the order of the kernel's parameter
    // list (n_rows, scale, a, b, cells, offsets, status). The buffers are
    // sized n, n, 18 * n, n + 1 and n, and the kernel returns early for
    // thread indices >= n_rows, so every device access stays in bounds.
    unsafe {
        let mut builder = stream.launch_builder(&func);
        builder.arg(&n_i32);
        builder.arg(&scale);
        builder.arg(&a_dev);
        builder.arg(&b_dev);
        builder.arg(&mut cells_dev);
        builder.arg(&mut offsets_dev);
        builder.arg(&mut status_dev);
        builder.launch(cfg)
    }
    .map(|_event_pair| ())
    .gpu_ctx("survival_flex_prep: launch partition_cells")?;

    let cells_host = stream
        .clone_dtoh(&cells_dev)
        .gpu_ctx("survival_flex_prep: download cells_flat")?;
    let status_host = stream
        .clone_dtoh(&status_dev)
        .gpu_ctx("survival_flex_prep: download status")?;

    // Report the first row the device flagged, if any.
    if let Some((i, st)) = status_host
        .iter()
        .enumerate()
        .find(|(_, st)| **st != 0)
    {
        return Err(crate::gpu_err!(
            "survival_flex_prep: row {i} status={st} from device kernel"
        ));
    }

    assert_eq!(cells_host.len(), n * DOUBLES_PER_ROW);

    // Only c0 (offset 2) and c1 (offset 3) are taken from the device record;
    // the rest of each cell comes from the host-side constructor.
    let out: PartitionCellsOutput = rows
        .iter()
        .zip(cells_host.chunks_exact(DOUBLES_PER_ROW))
        .map(|(row, record)| {
            let mut cell = trivial_partition_cell(row.a, row.b, scale);
            cell.cell.c0 = record[2];
            cell.cell.c1 = record[3];
            vec![cell]
        })
        .collect();

    Ok(Some(out))
}
pub(super) fn cell_primary_fixed_partials_baseline(
layout: super::FlexPrimaryLayout,
n_cells_total: usize,
) -> Result<Vec<f64>, GpuError> {
use super::kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC;
static FP_PTX_CACHE: PtxModuleCache = PtxModuleCache::new();
let n_i32 = i32::try_from(n_cells_total).map_err(|_| {
crate::gpu_err!(
"cell_primary_fixed_partials_baseline: n_cells={n_cells_total} exceeds i32"
)
})?;
let n_u32 = u32::try_from(n_cells_total).map_err(|_| {
crate::gpu_err!(
"cell_primary_fixed_partials_baseline: n_cells={n_cells_total} exceeds u32"
)
})?;
let (ctx, stream) = context_and_stream()
.map_err(|reason| crate::gpu::error::GpuError::DriverCallFailed { reason })?;
let module = FP_PTX_CACHE.get_or_compile(
&ctx,
"survival_flex_prep::cell_primary_fixed_partials",
DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC,
)?;
let func = module
.load_function("denested_cell_primary_fixed_partials_kernel")
.gpu_ctx("survival_flex_prep: load_function fixed_partials")?;
let per_cell = 12usize + 40usize * (layout.r as usize);
let scale = 1.0f64;
let mut out_dev = stream
.alloc_zeros::<f64>(n_cells_total * per_cell)
.gpu_ctx("survival_flex_prep: alloc fixed_partials")?;
let mut status_dev = stream
.alloc_zeros::<u8>(n_cells_total)
.gpu_ctx("survival_flex_prep: alloc fixed_partials status")?;
let cfg = LaunchConfig {
grid_dim: (n_u32.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
block_dim: (THREADS_PER_BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
let mut builder = stream.launch_builder(&func);
builder.arg(&n_i32);
builder.arg(&layout.r);
builder.arg(&layout.g_slot);
builder.arg(&scale);
builder.arg(&mut out_dev);
builder.arg(&mut status_dev);
builder.launch(cfg)
}
.map(|_event_pair| ())
.gpu_ctx("survival_flex_prep: launch fixed_partials")?;
let out_host = stream
.clone_dtoh(&out_dev)
.gpu_ctx("survival_flex_prep: download fixed_partials")?;
let status_host = stream
.clone_dtoh(&status_dev)
.gpu_ctx("survival_flex_prep: download fixed_partials status")?;
for (i, st) in status_host.iter().enumerate() {
if *st != 0 {
return Err(crate::gpu_err!(
"survival_flex_prep: fixed_partials cell {i} status={st}"
));
}
}
Ok(out_host)
}
}
#[cfg(not(target_os = "linux"))]
mod device_dispatch {
use super::{PartitionCellsOutput, PartitionCellsRowInputs};
use crate::gpu::error::GpuError;
/// Non-linux stand-in for the device partition-cells path: emits a trace
/// log line with the row count and scale, then declines with `Ok(None)`.
pub(super) fn partition_cells_baseline(
    rows: &[PartitionCellsRowInputs<'_>],
    scale: f64,
) -> Result<Option<PartitionCellsOutput>, GpuError> {
    let n_rows = rows.len();
    log::trace!(
        "survival_flex_prep::partition_cells_baseline declined on non-linux \
         (n_rows={}, scale={scale})",
        n_rows
    );
    Ok(None)
}
/// Non-linux stand-in for the device fixed-partials path: always returns an
/// error that names the launch it would have performed.
pub(super) fn cell_primary_fixed_partials_baseline(
    layout: super::FlexPrimaryLayout,
    n_cells_total: usize,
) -> Result<Vec<f64>, GpuError> {
    let super::FlexPrimaryLayout { r, g_slot } = layout;
    let unsupported = crate::gpu_err!(
        "survival_flex_prep::cell_primary_fixed_partials_baseline: CUDA only supported on linux \
         (would have launched n_cells={n_cells_total}, r={}, g_slot={})",
        r,
        g_slot
    );
    Err(unsupported)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// No rows: the partition path answers with an empty result.
    #[test]
    fn empty_partition_inputs_short_circuit() {
        let out = try_device_partition_cells(&[]).expect("ok");
        assert!(matches!(out.as_deref(), Some([])));
    }

    /// A row that carries a beta slice makes the partition path decline.
    #[test]
    fn nonempty_partition_with_betas_declines() {
        let beta = [0.0_f64];
        let row = PartitionCellsRowInputs {
            a: 0.0,
            b: 1.0,
            beta_h: Some(&beta),
            beta_w: None,
        };
        let out = try_device_partition_cells(&[row]).expect("ok");
        assert!(out.is_none());
    }

    /// No rows: the fixed-partials path answers with an empty result.
    #[test]
    fn empty_fixed_partials_inputs_short_circuit() {
        let out = try_device_cell_primary_fixed_partials(&[]).expect("ok");
        assert!(out.is_some_and(|result| result.partials.is_empty()));
    }

    /// Rows without cells produce one empty entry per row and no launch.
    #[test]
    fn empty_cells_per_row_returns_empty_partials() {
        let row = CellPrimaryFixedPartialsRowInputs {
            a: 0.0,
            b: 1.0,
            cells: &[],
            layout: FlexPrimaryLayout { r: 4, g_slot: 3 },
        };
        let out = try_device_cell_primary_fixed_partials(&[row]).expect("ok");
        let some = out.expect("Some when all rows have zero cells");
        assert_eq!(some.partials.len(), 1);
        assert!(some.partials[0].is_empty());
    }

    /// Both embedded kernel sources contain text.
    #[test]
    fn kernel_src_strings_are_nonempty() {
        let sources = [
            kernel_src::DENESTED_PARTITION_CELLS_KERNEL_SRC,
            kernel_src::DENESTED_CELL_PRIMARY_FIXED_PARTIALS_KERNEL_SRC,
        ];
        for src in sources {
            assert!(!src.is_empty());
        }
    }

    /// The trivial cell is the unbounded line `a + b*x` when the scale is 1.
    #[test]
    fn trivial_partition_cell_matches_cpu_empty_split_branch() {
        let cubic = trivial_partition_cell(2.5, -1.25, 1.0).cell;
        assert_eq!(
            [cubic.c0, cubic.c1, cubic.c2, cubic.c3],
            [2.5, -1.25, 0.0, 0.0]
        );
        assert_eq!(cubic.left, f64::NEG_INFINITY);
        assert_eq!(cubic.right, f64::INFINITY);
    }
}