pub(crate) mod branch;
pub(crate) mod device;
pub(crate) mod host_substrate;
pub(crate) mod kernel_src;
use crate::gpu::error::GpuError;
pub(crate) use host_substrate::{HostMomentBatch, build_host_moments};
/// Largest `max_degree` accepted by [`try_build_cubic_cell_derivative_moments`];
/// requests above this bound are rejected with an error before any work is done.
pub(crate) const MAX_SUPPORTED_DEGREE: usize = 24;
/// One cubic cell handed to the moment builder: two bounds plus four scalar
/// coefficients, all `f64`.
///
/// NOTE(review): `left`/`right` look like the interval endpoints of the cell and
/// `c0..c3` like cubic coefficients in ascending power order — confirm against
/// `host_substrate` / `kernel_src` before relying on that.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct GpuDenestedCubicCell {
    /// Lower bound of the cell (presumably the left interval endpoint).
    pub left: f64,
    /// Upper bound of the cell (presumably the right interval endpoint).
    pub right: f64,
    /// First coefficient.
    pub c0: f64,
    /// Second coefficient.
    pub c1: f64,
    /// Third coefficient.
    pub c2: f64,
    /// Fourth coefficient.
    pub c3: f64,
}
/// Per-cell branch tag. The builder requires exactly one tag per cell (the
/// `branches` slice must have the same length as the `cells` slice).
///
/// NOTE(review): the precise meaning of each variant is defined by the
/// host/device implementations, not by this file — the hints below are inferred
/// from the names only.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum GpuCellBranchTag {
    /// Presumably a cell treated as affine.
    Affine,
    /// Presumably a non-affine cell on a finite interval.
    NonAffineFinite,
    /// Presumably an affine cell on a tail interval — TODO confirm.
    AffineTail,
}
/// Where the caller wants the computed moments delivered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CubicCellMomentResidency {
    /// Deliver the result as [`CubicCellDerivativeMomentOutput::Host`]. On Linux
    /// the builder still tries the device path first and converts its batch to
    /// host output.
    Host,
    /// Linux only: ask for a device-resident result. If the device path declines
    /// (returns `None`), the builder falls back to host output instead.
    #[cfg(target_os = "linux")]
    Device,
}
/// Status code for a cell's moment computation.
///
/// The enum is `repr(u8)` with explicit discriminants; the output status buffers
/// are `Vec<u8>` and are compared against `Variant as u8`. The numeric values are
/// pinned by the `status_codes_match_kernel_abi` test, so do not renumber them.
///
/// NOTE(review): the variant descriptions are inferred from the names — confirm
/// the exact conditions against the host/kernel implementation.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CubicCellMomentStatus {
    /// Code `0`: success.
    Ok = 0,
    /// Code `1`.
    InvalidInterval = 1,
    /// Code `2`.
    NonAffineInfiniteInterval = 2,
    /// Code `3`.
    NonFiniteCoefficient = 3,
    /// Code `4`.
    NonFiniteEvaluation = 4,
}
/// Borrowed request describing a batch of cells to compute moments for.
///
/// Consumed by [`try_build_cubic_cell_derivative_moments`].
pub(crate) struct CubicCellDerivativeMomentHostView<'a> {
    /// Cells to process.
    pub cells: &'a [GpuDenestedCubicCell],
    /// One branch tag per cell; must have the same length as `cells`.
    pub branches: &'a [GpuCellBranchTag],
    /// Requested maximum degree; must not exceed [`MAX_SUPPORTED_DEGREE`].
    pub max_degree: usize,
    /// Where the caller wants the result to live.
    pub residency: CubicCellMomentResidency,
}
/// Result of a moment build, either fully in host memory or (Linux only) with
/// the moment buffer held in a `cudarc` device slice.
#[derive(Debug)]
pub(crate) enum CubicCellDerivativeMomentOutput {
    /// Moments and statuses in host vectors.
    Host {
        /// Flat moment buffer.
        moments: Vec<f64>,
        /// Status bytes; values correspond to [`CubicCellMomentStatus`] cast to `u8`.
        status: Vec<u8>,
        /// Stride of `moments`. The unit test observes `10` for `max_degree = 9`;
        /// presumably `max_degree + 1` values per cell — TODO confirm.
        stride: usize,
    },
    /// Moments held in a CUDA device slice; statuses remain in a host vector.
    #[cfg(target_os = "linux")]
    Device {
        /// Device buffer holding the moments.
        d_moments: cudarc::driver::CudaSlice<f64>,
        /// Status bytes, as in the `Host` variant.
        status: Vec<u8>,
        /// Stride of the moment buffer (see the `Host` variant).
        stride: usize,
        /// Number of cells (presumably the number of rows in `d_moments`).
        n_cells: usize,
    },
}
/// Validates a moment request and dispatches it to the device or host builder.
///
/// # Returns
///
/// * `Ok(None)` when `input.cells` is empty (after validation has passed).
/// * `Ok(Some(output))` otherwise. A `Host` request always yields host output;
///   on Linux a `Device` request yields whatever the resident device path
///   returns, or host output if that path declines.
///
/// # Errors
///
/// * `cells` and `branches` differ in length.
/// * `max_degree` exceeds [`MAX_SUPPORTED_DEGREE`].
/// * A device call fails, or the host builder fails (reported as
///   `GpuError::DriverCallFailed`).
pub(crate) fn try_build_cubic_cell_derivative_moments(
    input: CubicCellDerivativeMomentHostView<'_>,
) -> Result<Option<CubicCellDerivativeMomentOutput>, GpuError> {
    // Validation runs before the empty-input shortcut, so an empty request with
    // an unsupported degree is still rejected.
    let (n_cells, n_branches) = (input.cells.len(), input.branches.len());
    if n_cells != n_branches {
        crate::gpu_bail!(
            "gpu cubic-cell substrate: cells.len()={} != branches.len()={}",
            n_cells,
            n_branches
        );
    }
    let requested_degree = input.max_degree;
    if requested_degree > MAX_SUPPORTED_DEGREE {
        crate::gpu_bail!(
            "gpu cubic-cell substrate: max_degree={} exceeds MAX_SUPPORTED_DEGREE={}",
            requested_degree,
            MAX_SUPPORTED_DEGREE
        );
    }
    if n_cells == 0 {
        return Ok(None);
    }

    // Shared host path, used whenever the device path is unavailable or declines.
    let host_fallback = || {
        build_host_moments(&input)
            .map(|batch| Some(into_host_output(batch)))
            .map_err(|reason| GpuError::DriverCallFailed { reason })
    };

    match input.residency {
        CubicCellMomentResidency::Host => {
            // On Linux, try the device first and convert its batch to host output.
            #[cfg(target_os = "linux")]
            {
                if let Some(batch) = device::try_device_moments(&input)? {
                    return Ok(Some(into_host_output(batch)));
                }
            }
            host_fallback()
        }
        #[cfg(target_os = "linux")]
        CubicCellMomentResidency::Device => {
            match device::try_device_moments_resident(&input)? {
                Some(resident) => Ok(Some(resident)),
                None => host_fallback(),
            }
        }
    }
}
#[inline]
fn into_host_output(batch: HostMomentBatch) -> CubicCellDerivativeMomentOutput {
CubicCellDerivativeMomentOutput::Host {
moments: batch.moments,
status: batch.status,
stride: batch.stride,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A cell on `[-1, 1]` whose four coefficients are all zero.
    fn affine_cell() -> GpuDenestedCubicCell {
        let (left, right) = (-1.0, 1.0);
        GpuDenestedCubicCell {
            left,
            right,
            c0: 0.0,
            c1: 0.0,
            c2: 0.0,
            c3: 0.0,
        }
    }

    /// Wraps the given slices in a request that asks for host residency.
    fn host_view<'a>(
        cells: &'a [GpuDenestedCubicCell],
        branches: &'a [GpuCellBranchTag],
        max_degree: usize,
    ) -> CubicCellDerivativeMomentHostView<'a> {
        let residency = CubicCellMomentResidency::Host;
        CubicCellDerivativeMomentHostView {
            cells,
            branches,
            max_degree,
            residency,
        }
    }

    /// A host request on one valid cell yields host output with real numbers in it.
    #[test]
    fn host_residency_returns_real_moments() {
        let cells = [affine_cell()];
        let branches = [GpuCellBranchTag::Affine];
        let request = host_view(&cells, &branches, 9);
        let built = try_build_cubic_cell_derivative_moments(request)
            .expect("host substrate succeeds on a valid cell")
            .expect("non-empty input produces output");
        match built {
            CubicCellDerivativeMomentOutput::Host {
                moments,
                status,
                stride,
            } => {
                assert_eq!(stride, 10);
                assert_eq!(status, vec![CubicCellMomentStatus::Ok as u8]);
                let deviation = (moments[0] - 1.7112488348667447).abs();
                assert!(deviation < 1e-12);
            }
            #[cfg(target_os = "linux")]
            CubicCellDerivativeMomentOutput::Device { .. } => {
                panic!("host residency request must not yield Device output")
            }
        }
    }

    /// An empty batch is not an error; it simply produces no output.
    #[test]
    fn empty_input_returns_ok_none() {
        let built = try_build_cubic_cell_derivative_moments(host_view(&[], &[], 9));
        assert!(built.expect("ok").is_none());
    }

    /// A cell slice without a matching branch slice is rejected, naming both lengths.
    #[test]
    fn rejects_mismatched_lengths() {
        let cells = [affine_cell()];
        let branches: [GpuCellBranchTag; 0] = [];
        let request = host_view(&cells, &branches, 9);
        let msg = try_build_cubic_cell_derivative_moments(request)
            .unwrap_err()
            .to_string();
        for needle in ["cells.len()", "branches.len()"] {
            assert!(msg.contains(needle), "got: {msg}");
        }
    }

    /// One past the supported degree is rejected with a message naming the limit.
    #[test]
    fn rejects_degree_above_supported_max() {
        let cells = [affine_cell()];
        let branches = [GpuCellBranchTag::Affine];
        let too_high = MAX_SUPPORTED_DEGREE + 1;
        let err = try_build_cubic_cell_derivative_moments(host_view(&cells, &branches, too_high))
            .unwrap_err();
        assert!(err.to_string().contains("MAX_SUPPORTED_DEGREE"));
    }

    /// The numeric status codes are part of the kernel ABI and must not drift.
    #[test]
    fn status_codes_match_kernel_abi() {
        let expected_codes = [
            (CubicCellMomentStatus::Ok, 0u8),
            (CubicCellMomentStatus::InvalidInterval, 1),
            (CubicCellMomentStatus::NonAffineInfiniteInterval, 2),
            (CubicCellMomentStatus::NonFiniteCoefficient, 3),
            (CubicCellMomentStatus::NonFiniteEvaluation, 4),
        ];
        for (status, code) in expected_codes {
            assert_eq!(status as u8, code);
        }
    }
}