#[cfg(target_os = "linux")]
use crate::gpu::cubic_cell::{
CubicCellDerivativeMomentHostView, CubicCellDerivativeMomentOutput, CubicCellMomentStatus,
GpuCellBranchTag, branch::classify_cell_for_gpu,
};
#[cfg(target_os = "linux")]
use crate::gpu::error::GpuError;
#[cfg(target_os = "linux")]
use crate::gpu::error::GpuResultExt;
#[cfg(target_os = "linux")]
use crate::gpu_err;
#[cfg(target_os = "linux")]
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(target_os = "linux")]
use cudarc::driver::{CudaContext, CudaModule, CudaStream};
/// Computes cubic-cell derivative moments on the GPU, leaving the moment table on the device.
///
/// Returns `Ok(None)` when there is no GPU work to hand back:
/// * `view.cells` is empty (no CUDA probing or launches are performed), or
/// * the CUDA driver library is unavailable (`GpuError::DriverLibraryUnavailable`).
///
/// # Errors
/// Any other backend probe failure, or any failure raised while dispatching the kernel.
#[cfg(target_os = "linux")]
pub(crate) fn try_device_moments_resident(
    view: &CubicCellDerivativeMomentHostView<'_>,
) -> Result<Option<CubicCellDerivativeMomentOutput>, GpuError> {
    // Empty views have nothing to compute; the dispatcher rejects them, so short-circuit
    // here with the same "no output" signal the outer builder uses for empty input.
    if view.cells.is_empty() {
        return Ok(None);
    }
    let backend = match CubicCellGpuBackend::probe() {
        Ok(backend) => backend,
        Err(GpuError::DriverLibraryUnavailable { .. }) => return Ok(None),
        Err(other) => return Err(other),
    };
    backend.dispatch_device_resident(view).map(Some)
}
/// Linux-only CUDA state owned by the cubic-cell GPU backend.
#[cfg(target_os = "linux")]
struct CubicCellGpuContextLinux {
    /// CUDA context into which compiled PTX modules are loaded.
    ctx: Arc<CudaContext>,
    /// Stream on which copies, allocations and kernel launches are issued.
    stream: Arc<CudaStream>,
    /// Compiled modules keyed by `max_degree`; each degree has its own generated kernel source.
    modules: Mutex<std::collections::HashMap<usize, Arc<CudaModule>>>,
}

/// Handle to the CUDA backend that evaluates cubic-cell derivative moments.
///
/// Instances are obtained through `CubicCellGpuBackend::probe`, which hands out a
/// process-wide `&'static` backend.
#[cfg(target_os = "linux")]
#[must_use]
pub(crate) struct CubicCellGpuBackend {
    inner: CubicCellGpuContextLinux,
}
#[cfg(target_os = "linux")]
impl CubicCellGpuBackend {
    /// Returns the process-wide cubic-cell CUDA backend, probing it on first use.
    ///
    /// The probe outcome, whether success *or* failure, is cached in a `OnceLock`. A failed
    /// probe is therefore never retried, and every later call receives a clone of the cached
    /// error.
    ///
    /// # Errors
    /// The error reported by `backend_probe::probe_cuda_backend` on the first call.
    pub(crate) fn probe() -> Result<&'static Self, GpuError> {
        static BACKEND: OnceLock<Result<CubicCellGpuBackend, GpuError>> = OnceLock::new();
        BACKEND
            .get_or_init(Self::probe_linux)
            .as_ref()
            .map_err(GpuError::clone)
    }

    /// Builds a backend from the probed CUDA context and stream, with an empty module cache.
    fn probe_linux() -> Result<Self, GpuError> {
        let parts = crate::gpu::backend_probe::probe_cuda_backend("cubic_cell")?;
        Ok(CubicCellGpuBackend {
            inner: CubicCellGpuContextLinux {
                ctx: parts.ctx,
                stream: parts.stream,
                modules: Mutex::new(std::collections::HashMap::new()),
            },
        })
    }

    /// Returns the CUDA module specialised for `max_degree`, compiling and caching it on first use.
    ///
    /// The cache lock is *not* held across NVRTC compilation or module loading. Two threads
    /// requesting the same uncached degree may both compile; whichever re-acquires the lock
    /// first inserts its module, and both callers return that cached entry.
    ///
    /// # Errors
    /// Cache-mutex poisoning, NVRTC compilation failure, or module-load failure.
    fn module_for_degree(&self, max_degree: usize) -> Result<Arc<CudaModule>, GpuError> {
        {
            let guard = self
                .inner
                .modules
                .lock()
                .gpu_ctx("cubic_cell module cache mutex poisoned")?;
            if let Some(module) = guard.get(&max_degree) {
                return Ok(Arc::clone(module));
            }
        }
        let source =
            crate::gpu::cubic_cell::kernel_src::build_cubic_deriv_moments_kernel_source(max_degree);
        let ptx = cudarc::nvrtc::compile_ptx(&source).gpu_ctx_with(|err| {
            format!("cubic_cell NVRTC compile (degree={max_degree}) failed: {err}")
        })?;
        let module = self.inner.ctx.load_module(ptx).gpu_ctx_with(|err| {
            format!("cubic_cell module load (degree={max_degree}) failed: {err}")
        })?;
        let mut guard = self
            .inner
            .modules
            .lock()
            .gpu_ctx("cubic_cell module cache mutex poisoned")?;
        // Keep an entry a racing thread may have inserted meanwhile; drop ours in that case.
        Ok(Arc::clone(guard.entry(max_degree).or_insert(module)))
    }

    /// Uploads `view`, launches `cubic_deriv_moments_d{max_degree}`, and returns the moment
    /// table left on the device (`n_cells * stride` values, `stride = max_degree + 1`) together
    /// with per-cell status codes on the host.
    ///
    /// Host pre-classification with `classify_cell_for_gpu`:
    /// * a cell whose host tag differs from `view.branches[i]` is marked `InvalidInterval`;
    /// * a cell the classifier rejects keeps the classifier's status code.
    ///
    /// Both kinds of rejected cell are uploaded with branch code `255`. Only cells still `Ok`
    /// after this pass take the status code written by the kernel.
    ///
    /// # Errors
    /// An empty view, a `branches`/`cells` length mismatch, a cell count that overflows `u32`,
    /// or any CUDA compile/load/copy/alloc/launch/sync failure.
    fn dispatch_device_resident(
        &self,
        view: &CubicCellDerivativeMomentHostView<'_>,
    ) -> Result<CubicCellDerivativeMomentOutput, GpuError> {
        use cudarc::driver::{LaunchConfig, PushKernelArg};

        let n_cells = view.cells.len();
        // A zero-length upload/launch is never meaningful; report it instead of panicking.
        if n_cells == 0 {
            return Err(gpu_err!(
                "cubic_cell device-resident dispatch requires a non-empty view"
            ));
        }
        // `view.branches[i]` is read for every cell below; reject a mismatched view up front
        // rather than panicking on an out-of-bounds index or silently ignoring extra tags.
        let n_branches = view.branches.len();
        if n_branches != n_cells {
            return Err(gpu_err!(
                "cubic_cell device-resident branches.len()={n_branches} != cells.len()={n_cells}"
            ));
        }
        // The kernel receives the cell count as `u32`; fail before any device allocation.
        let n_u32: u32 = u32::try_from(n_cells)
            .map_err(|_| gpu_err!("cubic_cell n_cells={n_cells} overflows u32"))?;
        let max_degree = view.max_degree;
        let stride = max_degree + 1;

        // Struct-of-arrays host columns for upload, plus per-cell status and branch code.
        // Branch code 255 marks a cell the host has already rejected.
        let mut status_host = vec![CubicCellMomentStatus::Ok as u8; n_cells];
        let mut branch_code = vec![255_u8; n_cells];
        let mut left = vec![0.0_f64; n_cells];
        let mut right = vec![0.0_f64; n_cells];
        let mut c0 = vec![0.0_f64; n_cells];
        let mut c1 = vec![0.0_f64; n_cells];
        let mut c2 = vec![0.0_f64; n_cells];
        let mut c3 = vec![0.0_f64; n_cells];
        for (i, &gpu_cell) in view.cells.iter().enumerate() {
            left[i] = gpu_cell.left;
            right[i] = gpu_cell.right;
            c0[i] = gpu_cell.c0;
            c1[i] = gpu_cell.c1;
            c2[i] = gpu_cell.c2;
            c3[i] = gpu_cell.c3;
            match classify_cell_for_gpu(gpu_cell) {
                Ok(host_tag) => {
                    if host_tag != view.branches[i] {
                        // NOTE(review): a caller/host branch disagreement is reported as
                        // `InvalidInterval`; confirm no more specific status is intended.
                        status_host[i] = CubicCellMomentStatus::InvalidInterval as u8;
                        continue;
                    }
                    branch_code[i] = match host_tag {
                        GpuCellBranchTag::Affine => 0,
                        GpuCellBranchTag::NonAffineFinite => 1,
                        GpuCellBranchTag::AffineTail => 2,
                    };
                }
                Err(code) => status_host[i] = code as u8,
            }
        }

        let module = self.module_for_degree(max_degree)?;
        let kernel_name = format!("cubic_deriv_moments_d{max_degree}");
        let func = module
            .load_function(&kernel_name)
            .gpu_ctx_with(|err| format!("cubic_cell load_function {kernel_name}: {err}"))?;

        let stream = &self.inner.stream;
        let d_left = stream
            .clone_htod(&left)
            .gpu_ctx("cubic_cell device-resident memcpy left")?;
        let d_right = stream
            .clone_htod(&right)
            .gpu_ctx("cubic_cell device-resident memcpy right")?;
        let d_c0 = stream
            .clone_htod(&c0)
            .gpu_ctx("cubic_cell device-resident memcpy c0")?;
        let d_c1 = stream
            .clone_htod(&c1)
            .gpu_ctx("cubic_cell device-resident memcpy c1")?;
        let d_c2 = stream
            .clone_htod(&c2)
            .gpu_ctx("cubic_cell device-resident memcpy c2")?;
        let d_c3 = stream
            .clone_htod(&c3)
            .gpu_ctx("cubic_cell device-resident memcpy c3")?;
        let d_branch = stream
            .clone_htod(&branch_code)
            .gpu_ctx("cubic_cell device-resident memcpy branch")?;
        let mut d_moments = stream
            .alloc_zeros::<f64>(n_cells * stride)
            .map_err(|err| gpu_err!("cubic_cell device-resident alloc moments: {err}"))?;
        let mut d_status = stream
            .alloc_zeros::<u8>(n_cells)
            .gpu_ctx("cubic_cell device-resident alloc status")?;

        // Launch geometry: `warps_per_block` 32-thread warps per block and enough blocks that
        // grid * warps_per_block >= n_cells, i.e. at least one warp per cell.
        // NOTE(review): assumes the kernel maps one warp to one cell; confirm in kernel_src.
        let warps_per_block: u32 = 4;
        let block: u32 = 32 * warps_per_block;
        let grid: u32 = n_u32.div_ceil(warps_per_block);
        let cfg = LaunchConfig {
            grid_dim: (grid, 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&d_left)
            .arg(&d_right)
            .arg(&d_c0)
            .arg(&d_c1)
            .arg(&d_c2)
            .arg(&d_c3)
            .arg(&d_branch)
            .arg(&mut d_moments)
            .arg(&mut d_status)
            .arg(&n_u32);
        // SAFETY: the pushed arguments must match, in order and type, the parameter list of
        // `cubic_deriv_moments_d{max_degree}` emitted by `build_cubic_deriv_moments_kernel_source`
        // (NOTE(review): that signature is defined outside this file; keep the two in sync).
        // Each input buffer holds `n_cells` elements, `d_moments` holds `n_cells * stride`, and
        // all of them stay alive until after `stream.synchronize()` below.
        unsafe { builder.launch(cfg) }.gpu_ctx("cubic_cell device-resident kernel launch")?;

        let kernel_status = stream
            .clone_dtoh(&d_status)
            .gpu_ctx("cubic_cell device-resident DtoH status")?;
        stream
            .synchronize()
            .gpu_ctx("cubic_cell device-resident sync after kernel")?;

        // Host verdicts take precedence: kernel statuses only fill cells the host left `Ok`.
        for (status, &from_kernel) in status_host.iter_mut().zip(&kernel_status) {
            if *status == CubicCellMomentStatus::Ok as u8 {
                *status = from_kernel;
            }
        }

        Ok(CubicCellDerivativeMomentOutput::Device {
            d_moments,
            status: status_host,
            stride,
            n_cells,
        })
    }
}
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::*;
    use crate::gpu::cubic_cell::{
        CubicCellDerivativeMomentHostView, CubicCellDerivativeMomentOutput,
        CubicCellMomentResidency, CubicCellMomentStatus, GpuCellBranchTag, GpuDenestedCubicCell,
        try_build_cubic_cell_derivative_moments,
    };
    use crate::gpu::error::GpuError;
    use crate::gpu::error::GpuResultExt;
    use crate::gpu::runtime::GpuRuntime;

    /// Copies a device-resident moment table to the host on the backend's stream, then
    /// synchronizes that stream before returning the host copy.
    fn download_moments(
        backend: &CubicCellGpuBackend,
        d_moments: &cudarc::driver::CudaSlice<f64>,
    ) -> Result<Vec<f64>, GpuError> {
        let stream = &backend.inner.stream;
        let host = stream
            .clone_dtoh(d_moments)
            .gpu_ctx("cubic_cell tests::download_moments DtoH")?;
        stream
            .synchronize()
            .gpu_ctx("cubic_cell tests::download_moments sync")?;
        Ok(host)
    }

    /// Device-resident dispatch must match the CPU reference for affine, non-affine finite and
    /// infinite-tail cells at several maximum degrees. Skips when no CUDA runtime is present.
    #[test]
    fn cubic_cell_device_residency_matches_cpu_all_branches() {
        use crate::families::cubic_cell_kernel::{
            DenestedCubicCell, evaluate_cell_derivative_moments_uncached,
        };
        if GpuRuntime::global().is_none() {
            eprintln!("[cubic_cell device-residency parity] no CUDA runtime — skipping");
            return;
        }
        let cpu_cells = vec![
            DenestedCubicCell {
                left: -1.0,
                right: 1.0,
                c0: 0.2,
                c1: 0.7,
                c2: 0.0,
                c3: 0.0,
            },
            DenestedCubicCell {
                left: -1.25,
                right: -0.2,
                c0: -0.35,
                c1: 0.85,
                c2: 0.4,
                c3: 0.0,
            },
            DenestedCubicCell {
                left: -0.5,
                right: 1.7,
                c0: 0.2,
                c1: -0.6,
                c2: 0.25,
                c3: 0.18,
            },
            DenestedCubicCell {
                left: f64::NEG_INFINITY,
                right: -0.7,
                c0: 0.1,
                c1: 0.5,
                c2: 0.0,
                c3: 0.0,
            },
            DenestedCubicCell {
                left: 1.2,
                right: f64::INFINITY,
                c0: -0.05,
                c1: 0.3,
                c2: 0.0,
                c3: 0.0,
            },
            DenestedCubicCell {
                left: f64::NEG_INFINITY,
                right: f64::INFINITY,
                c0: 0.0,
                c1: 0.0,
                c2: 0.0,
                c3: 0.0,
            },
        ];
        let cells_gpu: Vec<GpuDenestedCubicCell> = cpu_cells
            .iter()
            .map(|c| GpuDenestedCubicCell {
                left: c.left,
                right: c.right,
                c0: c.c0,
                c1: c.c1,
                c2: c.c2,
                c3: c.c3,
            })
            .collect();
        let branches: Vec<GpuCellBranchTag> = cpu_cells
            .iter()
            .map(|c| {
                if !c.left.is_finite() || !c.right.is_finite() {
                    GpuCellBranchTag::AffineTail
                } else if c.c2 == 0.0 && c.c3 == 0.0 {
                    GpuCellBranchTag::Affine
                } else {
                    GpuCellBranchTag::NonAffineFinite
                }
            })
            .collect();
        for &max_degree in &[9_usize, 15, 21] {
            let view = CubicCellDerivativeMomentHostView {
                cells: &cells_gpu,
                branches: &branches,
                max_degree,
                residency: CubicCellMomentResidency::Device,
            };
            let out = try_build_cubic_cell_derivative_moments(view)
                .expect("device-residency dispatch must succeed with CUDA")
                .expect("non-empty input must yield output");
            let (d_moments, status, stride, n_cells) = match out {
                CubicCellDerivativeMomentOutput::Device {
                    d_moments,
                    status,
                    stride,
                    n_cells,
                } => (d_moments, status, stride, n_cells),
                CubicCellDerivativeMomentOutput::Host { .. } => panic!(
                    "device residency must produce CubicCellDerivativeMomentOutput::Device on a CUDA host"
                ),
            };
            assert_eq!(stride, max_degree + 1);
            assert_eq!(n_cells, cpu_cells.len());
            assert_eq!(status.len(), cpu_cells.len());
            let backend = CubicCellGpuBackend::probe().expect("backend probe");
            let host_moments =
                download_moments(backend, &d_moments).expect("DtoH download for parity check");
            // Guard the row slicing below with a readable message instead of a slice panic.
            assert_eq!(
                host_moments.len(),
                n_cells * stride,
                "device moment table must hold n_cells * stride values"
            );
            for (i, &cpu_cell) in cpu_cells.iter().enumerate() {
                assert_eq!(
                    status[i],
                    CubicCellMomentStatus::Ok as u8,
                    "cell {i} must classify Ok (status={})",
                    status[i]
                );
                let row = &host_moments[i * stride..(i + 1) * stride];
                let cpu_state = evaluate_cell_derivative_moments_uncached(cpu_cell, max_degree)
                    .expect("cpu reference");
                // `zip` stops at the shorter side; make sure every GPU moment is compared.
                assert!(
                    cpu_state.moments.len() >= stride,
                    "cpu reference for cell {i} has {} moments, need {stride}",
                    cpu_state.moments.len()
                );
                for (k, (&got, &want)) in row.iter().zip(cpu_state.moments.iter()).enumerate() {
                    let abs = (got - want).abs();
                    let denom = want.abs().max(1.0);
                    let rel = abs / denom;
                    assert!(
                        abs <= 1e-12 || rel <= 1e-11,
                        "device parity drift at degree={max_degree} cell={i} k={k} \
                         gpu={got:.17e} cpu={want:.17e} abs={abs:.3e} rel={rel:.3e}"
                    );
                }
            }
        }
    }
}