#[cfg(target_os = "linux")]
use std::collections::HashMap;
#[cfg(target_os = "linux")]
use std::collections::hash_map::DefaultHasher;
#[cfg(target_os = "linux")]
use std::hash::{Hash, Hasher};
#[cfg(target_os = "linux")]
use std::sync::Mutex;
use std::sync::OnceLock;
use super::error::GpuError;
#[cfg(target_os = "linux")]
use super::error::GpuResultExt;
#[cfg(target_os = "linux")]
use crate::gpu_err;
#[cfg(target_os = "linux")]
use cudarc::driver::{CudaContext, CudaModule, CudaStream};
/// Polynomial degree of the B-spline basis handled by this module (cubic).
pub const DEGREE: usize = 3;

/// Number of basis functions that are non-zero on a single knot span
/// (`DEGREE + 1` for degree-`DEGREE` B-splines).
pub const ACTIVE_PER_SPAN: usize = DEGREE + 1;

/// Coefficient count of the product of two per-span basis polynomials
/// (degree `2 * DEGREE`, hence `2 * DEGREE + 1` coefficients).
pub const PROD_LEN: usize = 2 * DEGREE + 1;

/// Number of unordered pairs `(a, b)`, `a <= b`, drawn from the
/// `ACTIVE_PER_SPAN` active basis functions of a span.
pub const PAIRS_PER_SPAN: usize = ACTIVE_PER_SPAN * (ACTIVE_PER_SPAN + 1) / 2;
/// Binomial coefficient `C(n, k)` evaluated in `f64`.
///
/// Returns `0.0` when `k > n`. Uses the multiplicative formula over the
/// smaller of `k` and `n - k` so the number of factors stays small.
fn binomial(n: usize, k: usize) -> f64 {
    if k > n {
        return 0.0;
    }
    let terms = k.min(n - k);
    (0..terms).fold(1.0_f64, |prod, i| prod * (n - i) as f64 / (i + 1) as f64)
}
/// Packed slot of the unordered pair `{a, b}` of active basis functions.
///
/// The pair is normalised to `lo <= hi` and laid out in lower-triangular
/// row-major order, i.e. slot `hi * (hi + 1) / 2 + lo`. With four active
/// functions per span this maps the ten pairs onto `0..=9`.
#[inline]
pub fn active_pair_index(a: usize, b: usize) -> usize {
    let hi = a.max(b);
    let lo = a.min(b);
    lo + hi * (hi + 1) / 2
}
/// Whether a knot span of the given `width` contributes to integrals.
///
/// Zero-width (repeated-knot), negative, NaN and infinite widths are all
/// treated as inactive.
#[inline]
pub fn span_is_active(width: f64) -> bool {
    width.is_finite() && width > 0.0
}
/// Monomial coefficients of the `ACTIVE_PER_SPAN` cubic B-splines that are
/// non-zero on the knot span `[t[k], t[k + 1])`.
///
/// The Cox–de Boor recurrence is run symbolically in the local variable
/// `u = x - t[k]`. Row `j` of the result holds the coefficients of
/// `N_{k-3+j, 3}(t[k] + u)` with the constant term first, i.e. `out[j][q]`
/// multiplies `u^q`. Recurrence terms whose knot difference is not positive
/// (repeated knots) are dropped, which is the usual `0/0 := 0` convention.
///
/// # Panics
/// If `k + 4 > t.len()` or `k < 3`.
pub fn cubic_basis_local_coeffs(t: &[f64], k: usize) -> [[f64; ACTIVE_PER_SPAN]; ACTIVE_PER_SPAN] {
assert!(k + 4 <= t.len(), "knot index out of range");
assert!(k >= 3, "need 3 left-side knots for cubic B-splines");
let tk = t[k];
let zero: [f64; ACTIVE_PER_SPAN] = [0.0; ACTIVE_PER_SPAN];
// Degree 0: only N_{k,0} is non-zero on the span, and it is identically 1.
let mut level: Vec<[f64; ACTIVE_PER_SPAN]> = vec![{
let mut v = zero;
v[0] = 1.0;
v
}];
// After iteration p, `level[j]` holds N_{k-p+j, p} for j = 0..=p.
for p in 1..=DEGREE {
let mut next: Vec<[f64; ACTIVE_PER_SPAN]> = vec![zero; p + 1];
for j in 0..=p {
// `i` is the global index of N_{i,p}; its left parent N_{i,p-1} sits at
// level[j - 1] (only for j >= 1) and its right parent N_{i+1,p-1} at
// level[j] (only for j < p).
let i = k - p + j; if j >= 1 {
let denom = t[i + p] - t[i];
if denom > 0.0 {
// Left term (x - t_i) / (t_{i+p} - t_i) * N_{i,p-1}, with
// x - t_i = (t_k - t_i) + u: the constant part feeds u^q, the
// `u` part shifts into u^{q+1}.
let shift = tk - t[i];
let inv = 1.0 / denom;
let prev = &level[j - 1];
for q in 0..p {
next[j][q] += inv * shift * prev[q];
next[j][q + 1] += inv * prev[q];
}
}
}
if j < p {
let denom = t[i + p + 1] - t[i + 1];
if denom > 0.0 {
// Right term (t_{i+p+1} - x) / (t_{i+p+1} - t_{i+1}) * N_{i+1,p-1},
// with t_{i+p+1} - x = (t_{i+p+1} - t_k) - u.
let shift = t[i + p + 1] - tk;
let inv = 1.0 / denom;
let prev = &level[j];
for q in 0..p {
next[j][q] += inv * shift * prev[q];
next[j][q + 1] += -inv * prev[q];
}
}
}
}
level = next;
}
// Copy the degree-DEGREE level (ACTIVE_PER_SPAN rows) into a fixed-size array.
let mut out: [[f64; ACTIVE_PER_SPAN]; ACTIVE_PER_SPAN] =
[[0.0; ACTIVE_PER_SPAN]; ACTIVE_PER_SPAN];
for (idx, v) in level.into_iter().enumerate() {
out[idx] = v;
}
out
}
/// Coefficients of `d/du` of a per-span polynomial given in monomial form
/// (constant term first). The highest-order slot of the result is `0.0`.
pub fn differentiate_basis_coeffs(a: [f64; ACTIVE_PER_SPAN]) -> [f64; ACTIVE_PER_SPAN] {
    let mut deriv = [0.0; ACTIVE_PER_SPAN];
    for power in 1..ACTIVE_PER_SPAN {
        // d/du (a_p u^p) = p * a_p u^{p-1}
        deriv[power - 1] = power as f64 * a[power];
    }
    deriv
}
/// Product of two per-span polynomials in monomial form (constant term
/// first), returned as `PROD_LEN` coefficients.
///
/// Zero coefficients of `a` are skipped as a shortcut.
#[inline]
pub fn convolve_basis_pair(
    a: [f64; ACTIVE_PER_SPAN],
    b: [f64; ACTIVE_PER_SPAN],
) -> [f64; PROD_LEN] {
    let mut product = [0.0; PROD_LEN];
    for (i, &ai) in a.iter().enumerate().filter(|&(_, &ai)| ai != 0.0) {
        for (j, &bj) in b.iter().enumerate() {
            product[i + j] += ai * bj;
        }
    }
    product
}
/// Integral over one span of a product polynomial times a shifted power.
///
/// Evaluates `∫_0^width c(u) · (u - m_minus_left)^nu du`, where `c` holds
/// `PROD_LEN` monomial coefficients in the local variable `u` (constant term
/// first) and `m_minus_left` is the expansion point measured from the span's
/// left end. The shifted power is expanded with the binomial theorem so that
/// each term integrates in closed form. Inactive spans (see
/// `span_is_active`) yield `0.0`.
pub fn moment_1d_about(c: [f64; PROD_LEN], width: f64, nu: usize, m_minus_left: f64) -> f64 {
    if !span_is_active(width) {
        return 0.0;
    }
    let offset = -m_minus_left;
    let mut total = 0.0;
    for s in 0..=nu {
        // Closed form of ∫_0^h c(u) u^s du = Σ_q c_q h^{q+s+1} / (q+s+1).
        let mut power = width.powi((s + 1) as i32);
        let mut partial = 0.0;
        for (q, &cq) in c.iter().enumerate() {
            partial += cq * power / ((q + s + 1) as f64);
            power *= width;
        }
        total += binomial(nu, s) * offset.powi((nu - s) as i32) * partial;
    }
    total
}
/// Integral `∫_0^width c(u) · u^nu du` of a product polynomial over one span,
/// i.e. its `nu`-th moment about the span's left endpoint.
///
/// `c` holds `PROD_LEN` monomial coefficients (constant term first) in the
/// local variable `u`. Inactive spans (see `span_is_active`) yield `0.0`.
#[inline]
pub fn moment_1d_local(c: [f64; PROD_LEN], width: f64, nu: usize) -> f64 {
    if !span_is_active(width) {
        return 0.0;
    }
    let base_exp = nu + 1;
    // h^{q+nu+1}, advanced by one power of the width per coefficient.
    let mut h_pow = width.powi(base_exp as i32);
    c.iter().enumerate().fold(0.0, |acc, (q, &cq)| {
        let term = cq * h_pow / ((q + base_exp) as f64);
        h_pow *= width;
        acc + term
    })
}
/// Nodes of the 20-point Gauss–Legendre rule on `[-1, 1]`, in ascending order
/// (symmetric about 0). Used by `moment_1d_gauss_legendre` as a quadrature
/// cross-check of the closed-form span moments.
const GL20_X: [f64; 20] = [
-0.993_128_599_185_094_9,
-0.963_971_927_277_913_8,
-0.912_234_428_251_325_9,
-0.839_116_971_822_218_8,
-0.746_331_906_460_150_8,
-0.636_053_680_726_515_0,
-0.510_867_001_950_827_1,
-0.373_706_088_715_419_6,
-0.227_785_851_141_645_1,
-0.076_526_521_133_497_3,
0.076_526_521_133_497_3,
0.227_785_851_141_645_1,
0.373_706_088_715_419_6,
0.510_867_001_950_827_1,
0.636_053_680_726_515_0,
0.746_331_906_460_150_8,
0.839_116_971_822_218_8,
0.912_234_428_251_325_9,
0.963_971_927_277_913_8,
0.993_128_599_185_094_9,
];
/// Weights of the 20-point Gauss–Legendre rule on `[-1, 1]`; `GL20_W[k]`
/// pairs with `GL20_X[k]`.
const GL20_W: [f64; 20] = [
0.017_614_007_139_152_1,
0.040_601_429_800_386_9,
0.062_672_048_334_109_1,
0.083_276_741_576_704_7,
0.101_930_119_817_240_5,
0.118_194_531_961_518_4,
0.131_688_638_449_176_6,
0.142_096_109_318_382_1,
0.149_172_986_472_603_7,
0.152_753_387_130_725_8,
0.152_753_387_130_725_8,
0.149_172_986_472_603_7,
0.142_096_109_318_382_1,
0.131_688_638_449_176_6,
0.118_194_531_961_518_4,
0.101_930_119_817_240_5,
0.083_276_741_576_704_7,
0.062_672_048_334_109_1,
0.040_601_429_800_386_9,
0.017_614_007_139_152_1,
];
/// Quadrature reference for the span moment: 20-point Gauss–Legendre
/// evaluation of `∫_left^{left+width} c(x - left) · (x - m)^nu dx`.
///
/// `c` is in monomial form (constant term first) in the local coordinate
/// `x - left`. This equals `moment_1d_about(c, width, nu, m - left)` up to
/// rounding whenever the integrand degree `6 + nu` is at most 39, the
/// exactness limit of the 20-point rule. Inactive spans yield `0.0`.
pub fn moment_1d_gauss_legendre(
    c: [f64; PROD_LEN],
    left: f64,
    width: f64,
    nu: usize,
    m: f64,
) -> f64 {
    if !span_is_active(width) {
        return 0.0;
    }
    let half = 0.5 * width;
    let center = left + half;
    let mut sum = 0.0;
    for (&node, &weight) in GL20_X.iter().zip(GL20_W.iter()) {
        let x = center + half * node;
        let local = x - left;
        // Horner evaluation of c at the local coordinate, highest degree first.
        let poly = c[..PROD_LEN - 1]
            .iter()
            .rev()
            .fold(c[PROD_LEN - 1], |acc, &coef| acc * local + coef);
        sum += weight * (x - m).powi(nu as i32) * poly;
    }
    sum * half
}
/// Per-axis lookup tables of span-local basis-product polynomials.
///
/// Holds one entry per active knot span; spans failing `span_is_active` are
/// skipped. For stored span `s`, the `PROD_LEN` coefficients at
/// `prod_coeff[(s * PAIRS_PER_SPAN + pair) * PROD_LEN..]` are the monomial
/// coefficients (in `u = x - left[s]`) of the product of two differentiated
/// basis functions; see `AxisCubicMomentTables::build`.
#[derive(Clone, Debug)]
pub struct AxisCubicMomentTables {
/// Knot-span index `k` (into the knot vector) of each stored span.
pub span_indices: Vec<usize>,
/// Left endpoint `t[k]` of each stored span.
pub left: Vec<f64>,
/// Width `t[k + 1] - t[k]` of each stored span.
pub width: Vec<f64>,
/// Flattened `[n_spans][PAIRS_PER_SPAN][PROD_LEN]` product coefficients.
pub prod_coeff: Vec<f64>,
/// Derivative order applied to the first (lower local index) basis function of each pair.
pub derivative_left: u8,
/// Derivative order applied to the second (higher or equal local index) basis function.
pub derivative_right: u8,
}
/// Applies `differentiate_basis_coeffs` `d` times to the polynomial `a`.
fn derive_basis_coeffs(a: [f64; ACTIVE_PER_SPAN], d: u8) -> [f64; ACTIVE_PER_SPAN] {
    (0..d).fold(a, |coeffs, _| differentiate_basis_coeffs(coeffs))
}
impl AxisCubicMomentTables {
    /// Builds the per-span product tables for one axis.
    ///
    /// For every active span `[t[k], t[k + 1])` with `k` in
    /// `DEGREE..t.len() - DEGREE - 1`, the four local cubic basis polynomials
    /// are differentiated `derivative_left` resp. `derivative_right` times.
    /// The product `left[a] * right[b]` for every `a <= b` is stored at slot
    /// `active_pair_index(a, b)` of that span.
    ///
    /// NOTE(review): only pairs with `a <= b` are stored. When
    /// `derivative_left != derivative_right` the product is not symmetric in
    /// `(a, b)`, so the `(b, a)` orientation is not represented. Confirm that
    /// callers only request the stored orientation in that case.
    ///
    /// # Panics
    /// If `t` has fewer than `2 * DEGREE + 2` knots.
    pub fn build(t: &[f64], derivative_left: u8, derivative_right: u8) -> Self {
        assert!(
            t.len() >= 2 * DEGREE + 2,
            "knot vector too short for cubic B-splines: got {} knots, need ≥ {}",
            t.len(),
            2 * DEGREE + 2
        );
        let mut tables = Self {
            span_indices: Vec::new(),
            left: Vec::new(),
            width: Vec::new(),
            prod_coeff: Vec::new(),
            derivative_left,
            derivative_right,
        };
        let span_end = t.len() - DEGREE - 1;
        for k in DEGREE..span_end {
            let span_width = t[k + 1] - t[k];
            if !span_is_active(span_width) {
                continue;
            }
            let basis = cubic_basis_local_coeffs(t, k);
            let lhs = basis.map(|row| derive_basis_coeffs(row, derivative_left));
            let rhs = basis.map(|row| derive_basis_coeffs(row, derivative_right));
            let mut span_prod = [[0.0f64; PROD_LEN]; PAIRS_PER_SPAN];
            for a in 0..ACTIVE_PER_SPAN {
                for b in a..ACTIVE_PER_SPAN {
                    span_prod[active_pair_index(a, b)] = convolve_basis_pair(lhs[a], rhs[b]);
                }
            }
            tables.span_indices.push(k);
            tables.left.push(t[k]);
            tables.width.push(span_width);
            tables.prod_coeff.extend(span_prod.iter().flatten().copied());
        }
        tables
    }

    /// Number of stored (active) spans.
    pub fn n_spans(&self) -> usize {
        self.left.len()
    }

    /// The `PROD_LEN` product coefficients of pair slot `pair` on stored span
    /// `span`.
    ///
    /// `pair` is not range-checked against `PAIRS_PER_SPAN`.
    ///
    /// # Panics
    /// If the addressed slot lies outside `prod_coeff`.
    pub fn prod(&self, span: usize, pair: usize) -> [f64; PROD_LEN] {
        let start = (span * PAIRS_PER_SPAN + pair) * PROD_LEN;
        let mut coeffs = [0.0; PROD_LEN];
        coeffs.copy_from_slice(&self.prod_coeff[start..start + PROD_LEN]);
        coeffs
    }

    /// Moment `∫_0^h p(u) u^nu du` of pair slot `pair` on stored span `span`,
    /// where `h` is the span width (see `moment_1d_local`).
    pub fn moment_local(&self, span: usize, pair: usize, nu: usize) -> f64 {
        let coeffs = self.prod(span, pair);
        moment_1d_local(coeffs, self.width[span], nu)
    }
}
/// Memory layout of a device moment table.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MomentLayout {
/// α-major: all cells of α row 0, then all cells of α row 1, … (the hex
/// kernel writes `out[a * out_stride + cell]`).
AlphaMajor,
}
/// Description of the tensor-product cubic moments to compute.
#[derive(Clone, Debug)]
pub struct CubicMomentSpec {
/// Moment exponent rows α, one per requested moment, one entry per axis.
pub alphas: Vec<Vec<u8>>,
/// `derivative_left[i][axis]`: derivative order of the first basis function
/// of α row `i` on `axis`.
pub derivative_left: Vec<Vec<u8>>,
/// `derivative_right[i][axis]`: derivative order of the second basis
/// function of α row `i` on `axis`.
pub derivative_right: Vec<Vec<u8>>,
/// Requested output layout.
pub layout: MomentLayout,
}
impl CubicMomentSpec {
    /// Spatial dimension, read from the length of the first α row
    /// (`0` when there are no α rows).
    pub fn d(&self) -> usize {
        match self.alphas.first() {
            Some(row) => row.len(),
            None => 0,
        }
    }

    /// Number of α rows.
    pub fn n_alpha(&self) -> usize {
        self.alphas.len()
    }
}
/// Builds the CPU moment tables for every axis of `spec`.
///
/// For each axis, the distinct `(derivative_left, derivative_right)` pairs
/// used by the α rows are collected, sorted ascending and deduplicated. One
/// `AxisCubicMomentTables` is built per pair, in that order.
///
/// # Panics
/// - If `spec.d() != knots_per_axis.len()` ("axis count mismatch").
/// - If a derivative row is shorter than expected.
/// - If a knot vector is rejected by `AxisCubicMomentTables::build`.
pub fn build_axis_tables_cpu(
    spec: &CubicMomentSpec,
    knots_per_axis: &[Vec<f64>],
) -> Vec<Vec<AxisCubicMomentTables>> {
    assert_eq!(spec.d(), knots_per_axis.len(), "axis count mismatch");
    knots_per_axis
        .iter()
        .enumerate()
        .map(|(axis, knots)| {
            let mut signatures: Vec<(u8, u8)> = (0..spec.n_alpha())
                .map(|row| (spec.derivative_left[row][axis], spec.derivative_right[row][axis]))
                .collect();
            signatures.sort_unstable();
            signatures.dedup();
            signatures
                .into_iter()
                .map(|(dl, dr)| AxisCubicMomentTables::build(knots, dl, dr))
                .collect::<Vec<_>>()
        })
        .collect()
}
/// Tensor-product moment of one hex cell.
///
/// Returns the product over axes `r` of
/// `axis_tables[r].moment_local(cell_span[r], pair_per_axis[r], alpha[r])`,
/// stopping early with `0.0` once the running product is zero.
///
/// # Panics
/// If the four slices differ in length, or if a span/pair index is out of
/// range for its axis table.
pub fn tensor_hex_moment_cpu(
    axis_tables: &[&AxisCubicMomentTables],
    cell_span: &[usize],
    alpha: &[u8],
    pair_per_axis: &[usize],
) -> f64 {
    assert_eq!(axis_tables.len(), cell_span.len());
    assert_eq!(axis_tables.len(), alpha.len());
    assert_eq!(axis_tables.len(), pair_per_axis.len());
    let per_axis = axis_tables
        .iter()
        .zip(cell_span)
        .zip(alpha)
        .zip(pair_per_axis);
    let mut running = 1.0;
    for (((table, &span), &nu), &pair) in per_axis {
        running *= table.moment_local(span, pair, usize::from(nu));
        if running == 0.0 {
            return 0.0;
        }
    }
    running
}
/// Moment table produced by a device builder.
///
/// NOTE(review): `values` is not necessarily `n_alpha * n_cells` long. The hex
/// builder allocates `n_alpha * out_stride`, with `out_stride` equal to
/// `n_cells` rounded up to a multiple of 32, and that stride is not stored
/// here. Consumers of the downloaded buffer must recompute it.
#[derive(Debug)]
pub struct DeviceCubicMomentTable {
/// Number of cells with valid data in each α row.
pub n_cells: usize,
/// Number of pair tuples per cell (the hex builder sets this to 1).
pub pair_tuple_count: usize,
/// Number of α rows.
pub n_alpha: usize,
/// Layout of `values`.
pub layout: MomentLayout,
/// Device buffer holding the moments.
#[cfg(target_os = "linux")]
pub values: cudarc::driver::CudaSlice<f64>,
/// Host-side stand-in for the device buffer on non-Linux targets.
#[cfg(not(target_os = "linux"))]
pub values: Vec<f64>,
}
/// Device-resident table of per-axis values.
///
/// NOTE(review): the producer of this table is not visible here. The field
/// meanings below are inferred from names only — confirm against the builder.
#[derive(Debug)]
pub struct DeviceMarginalTable {
/// Number of axes (presumably).
pub n_axes: usize,
/// Span count per axis (presumably one entry per axis).
pub n_spans_per_axis: Vec<usize>,
/// α count per axis (presumably one entry per axis).
pub n_alpha_per_axis: Vec<usize>,
/// Device buffer holding the values.
#[cfg(target_os = "linux")]
pub values: cudarc::driver::CudaSlice<f64>,
/// Host-side stand-in for the device buffer on non-Linux targets.
#[cfg(not(target_os = "linux"))]
pub values: Vec<f64>,
}
/// Cache key for compiled hex-moment NVRTC modules (see
/// `CubicMomentBackend::module_for`).
#[cfg(target_os = "linux")]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct HexMomentModuleKey {
/// Compute capability (major) of the target device.
cc_major: i32,
/// Compute capability (minor) of the target device.
cc_minor: i32,
/// Number of axes (`D` in the kernel).
d: u32,
/// Largest α entry (`AMAX` in the kernel).
amax: u32,
/// Number of α rows (`NALPHA` in the kernel).
nalpha: u32,
/// Hash of the α table baked into the kernel's `ALPHA_TABLE`.
alpha_hash: u64,
/// Hash of the derivative tables. The generated source does not depend on
/// them, so this only makes the cache finer-grained.
deriv_hash: u64,
/// `layout_tag` of the requested output layout.
layout_tag: u8,
}
/// Hash of an α table, used in the NVRTC module cache key.
///
/// Feeds the row count, then each row's length followed by its entries, into
/// a `DefaultHasher`. Row boundaries are therefore part of the hashed stream.
#[cfg(target_os = "linux")]
fn hash_alpha_table(alphas: &[Vec<u8>]) -> u64 {
    let mut hasher = DefaultHasher::new();
    (alphas.len() as u64).hash(&mut hasher);
    alphas.iter().for_each(|row| {
        (row.len() as u64).hash(&mut hasher);
        row.iter().for_each(|entry| entry.hash(&mut hasher));
    });
    hasher.finish()
}
/// Hash of the left and right derivative tables, used in the NVRTC module
/// cache key.
///
/// Each table contributes its row count, then every row's length followed by
/// its entries. The left table is fed first, then the right.
#[cfg(target_os = "linux")]
fn hash_deriv_table(deriv_left: &[Vec<u8>], deriv_right: &[Vec<u8>]) -> u64 {
    // Appends one table (row count, per-row length and entries) to the stream.
    fn feed(table: &[Vec<u8>], hasher: &mut DefaultHasher) {
        (table.len() as u64).hash(hasher);
        for row in table {
            (row.len() as u64).hash(hasher);
            for entry in row {
                entry.hash(hasher);
            }
        }
    }
    let mut hasher = DefaultHasher::new();
    feed(deriv_left, &mut hasher);
    feed(deriv_right, &mut hasher);
    hasher.finish()
}
/// Numeric tag of a `MomentLayout`, stored in module cache keys.
#[inline]
#[cfg(target_os = "linux")]
fn layout_tag(layout: MomentLayout) -> u8 {
    match layout {
        MomentLayout::AlphaMajor => 0_u8,
    }
}
/// Generates the CUDA C source of the `cubic_hex_tensor_moments` kernel.
///
/// `d`, `amax` and `alphas.len()` become the `D`, `AMAX` and `NALPHA`
/// preprocessor constants. `alphas` is emitted as the `ALPHA_TABLE[NALPHA][D]`
/// constant-memory initializer, one row per α. Each kernel thread handles one
/// `(cell, α)` pair. Over the `D` axes it multiplies the closed-form moments
/// `Σ_q c_q h^{q+ν+1} / (q+ν+1)` of the cell's product polynomials.
///
/// Per-cell input offsets (`cell * D + r`) are formed in 64-bit arithmetic,
/// so `n_cells * D` may exceed `i32::MAX`.
///
/// NOTE(review): every row of `alphas` is assumed to have exactly `d`
/// entries. A longer row makes NVRTC reject the initializer; a shorter one is
/// zero-filled by C aggregate initialisation.
#[cfg(target_os = "linux")]
fn build_hex_tensor_kernel_source(d: usize, amax: usize, alphas: &[Vec<u8>]) -> String {
    let nalpha = alphas.len();
    let mut alpha_decl =
        String::from("__constant__ unsigned char ALPHA_TABLE[NALPHA][D] = {\n");
    for row in alphas {
        let entries: Vec<String> = row.iter().map(|v| v.to_string()).collect();
        alpha_decl.push_str(" { ");
        alpha_decl.push_str(&entries.join(", "));
        alpha_decl.push_str(" },\n");
    }
    alpha_decl.push_str("};\n");
    format!(
        r#"
#define D {d}
#define AMAX {amax}
#define NALPHA {nalpha}
#define PROD_LEN 7
{alpha_decl}
// Closed-form 1D moment about the cell's left endpoint (m = L):
// I_nu^{{ij}}(L) = sum_{{q=0..6}} c_q * h^{{q+nu+1}} / (q + nu + 1)
// h^{{nu+1}} is formed by repeated multiplication at run time; the q-loop then
// advances h_pow by one power of h per coefficient and accumulates each term
// with fma. AMAX (largest entry of ALPHA_TABLE) is not used by the arithmetic.
__device__ __forceinline__ double moment_1d_local(
const double *cprod, // [PROD_LEN], product-poly coefs on cell
double h, // cell width
unsigned int nu // moment exponent
) {{
double h_pow = 1.0;
for (unsigned int s = 0; s <= nu; ++s) {{
h_pow *= h; // after loop: h_pow == h^{{nu+1}} ... continues below
}}
double acc = 0.0;
#pragma unroll
for (int q = 0; q < PROD_LEN; ++q) {{
double denom = (double)(q + nu + 1);
acc = fma(cprod[q] / denom, h_pow, acc);
h_pow *= h;
}}
return acc;
}}
// One thread = one (cell, alpha-slot). Block (32, 8, 1): x → cell, y → alpha.
// Inputs (all device-resident):
// axis_prod_coeff_flat: f64 buffer, concatenated per-axis tables; offset
// per axis given by `axis_offset[axis]` (in elements). Each axis table
// is [n_spans_axis][PAIRS_PER_SPAN(=10)][PROD_LEN(=7)] row-major.
// axis_offset: i64[D]
// cell_span_per_axis: i32[n_cells * D] — active-span index per axis.
// cell_pair_per_axis: i32[n_cells * D] — unordered-pair slot (0..=9) per axis.
// cell_width_per_axis: f64[n_cells * D] — cell width per axis.
// Per-cell input offsets are computed in 64-bit so n_cells * D may exceed INT_MAX.
// Output:
// out: f64[NALPHA * out_stride], alpha-major; thread writes out[a*out_stride + c].
extern "C" __global__ void cubic_hex_tensor_moments(
const double *axis_prod_coeff_flat,
const long long *axis_offset,
const int *cell_span_per_axis,
const int *cell_pair_per_axis,
const double *cell_width_per_axis,
int n_cells,
long long out_stride,
double *out
) {{
const int cell = blockIdx.x * blockDim.x + threadIdx.x;
const int alpha = blockIdx.y * blockDim.y + threadIdx.y;
if (cell >= n_cells || alpha >= NALPHA) return;
double prod = 1.0;
#pragma unroll
for (int r = 0; r < D; ++r) {{
const long long slot = (long long)cell * (long long)D + (long long)r;
const int span_r = cell_span_per_axis[slot];
const int pair_r = cell_pair_per_axis[slot];
const double width_r = cell_width_per_axis[slot];
const long long base = axis_offset[r]
+ (long long)span_r * 10LL * (long long)PROD_LEN
+ (long long)pair_r * (long long)PROD_LEN;
const double *cprod = axis_prod_coeff_flat + base;
const unsigned int nu = (unsigned int)ALPHA_TABLE[alpha][r];
const double mu = moment_1d_local(cprod, width_r, nu);
prod *= mu;
}}
out[(long long)alpha * out_stride + (long long)cell] = prod;
}}
"#,
        d = d,
        amax = amax,
        nalpha = nalpha,
        alpha_decl = alpha_decl,
    )
}
/// Linux-only state behind `CubicMomentBackend`.
#[cfg(target_os = "linux")]
struct CubicMomentBackendInner {
/// CUDA context into which NVRTC-compiled modules are loaded.
ctx: std::sync::Arc<CudaContext>,
/// Stream used for uploads, kernel launches and downloads.
stream: std::sync::Arc<CudaStream>,
/// Compiled hex-moment modules keyed by kernel shape and tables.
modules: Mutex<HashMap<HexMomentModuleKey, std::sync::Arc<CudaModule>>>,
/// Compiled tetrahedral-moment modules.
tet_modules: Mutex<HashMap<TetMomentModuleKey, std::sync::Arc<CudaModule>>>,
/// Compute capability (major) of the selected device; part of the cache keys.
cc_major: i32,
/// Compute capability (minor) of the selected device; part of the cache keys.
cc_minor: i32,
}
/// Handle to the cubic B-spline moment GPU backend.
///
/// Obtain the shared instance with `CubicMomentBackend::probe`. On non-Linux
/// targets the struct has no fields and `probe` reports the backend as
/// unavailable.
#[must_use]
pub struct CubicMomentBackend {
#[cfg(target_os = "linux")]
inner: CubicMomentBackendInner,
}
impl CubicMomentBackend {
/// Whether the GPU backend is compiled into this build (`true` only on Linux).
pub const fn compiled() -> bool {
cfg!(target_os = "linux")
}
/// Copies the `values` buffer of `dev` back to the host.
///
/// On Linux the device buffer is cloned through the backend stream, and the
/// stream is synchronised before returning. On other targets the host-side
/// `values` vector is cloned. The whole buffer is returned as-is, including
/// any per-row padding the producer wrote. The hex builder pads each α row to
/// a multiple of 32 cells.
///
/// # Errors
/// Propagates device-copy or synchronisation failures.
pub fn download_alpha_major(&self, dev: &DeviceCubicMomentTable) -> Result<Vec<f64>, GpuError> {
#[cfg(target_os = "linux")]
{
let stream = &self.inner.stream;
let host = stream
.clone_dtoh(&dev.values)
.gpu_ctx("cubic_bspline_moments download_alpha_major dtov")?;
stream
.synchronize()
.gpu_ctx("cubic_bspline_moments download_alpha_major sync")?;
Ok(host)
}
#[cfg(not(target_os = "linux"))]
{
Ok(dev.values.clone())
}
}
/// Returns the process-wide backend, creating it on first use.
///
/// The first outcome, backend or error, is stored in a `OnceLock`. A failed
/// probe is therefore not retried; later calls get a clone of the same error.
/// On non-Linux targets this always yields `DriverLibraryUnavailable`.
pub fn probe() -> Result<&'static Self, GpuError> {
static BACKEND: OnceLock<Result<CubicMomentBackend, GpuError>> = OnceLock::new();
BACKEND
.get_or_init(|| {
#[cfg(target_os = "linux")]
{
Self::probe_linux()
}
#[cfg(not(target_os = "linux"))]
{
Err(GpuError::DriverLibraryUnavailable {
reason: "cubic_bspline_moments GPU backend is Linux-only".to_string(),
})
}
})
.as_ref()
.map_err(GpuError::clone)
}
/// Builds the backend for the device selected by the global GPU runtime.
///
/// Fetches the device's CUDA context via `cuda_context_for`, uses the
/// context's default stream, and records the device's compute capability
/// for module cache keys.
#[cfg(target_os = "linux")]
fn probe_linux() -> Result<Self, GpuError> {
let runtime = super::runtime::GpuRuntime::global().ok_or_else(|| {
GpuError::DriverLibraryUnavailable {
reason: "cubic_bspline_moments backend: no CUDA runtime available".to_string(),
}
})?;
let ordinal = runtime.selected_device().ordinal;
let ctx = super::runtime::cuda_context_for(ordinal).ok_or_else(|| {
gpu_err!(
"cubic_bspline_moments backend: failed to create CUDA context for device {ordinal}"
)
})?;
let stream = ctx.default_stream();
let cap = &runtime.selected_device().capability;
let cc_major = cap.compute_major;
let cc_minor = cap.compute_minor;
Ok(CubicMomentBackend {
inner: CubicMomentBackendInner {
ctx,
stream,
modules: Mutex::new(HashMap::new()),
tet_modules: Mutex::new(HashMap::new()),
cc_major,
cc_minor,
},
})
}
/// Returns the compiled tetrahedral module for `key`.
///
/// On a cache miss it compiles `src_factory()` with NVRTC and loads the
/// result. The cache lock is not held during compilation, so concurrent
/// callers may compile the same key. The first inserted module stays cached,
/// but each caller returns the module it loaded itself. A poisoned lock is
/// ignored and simply bypasses the cache.
#[cfg(target_os = "linux")]
fn tet_module_for(
&self,
key: TetMomentModuleKey,
src_factory: impl FnOnce() -> String,
) -> Result<std::sync::Arc<CudaModule>, GpuError> {
if let Ok(guard) = self.inner.tet_modules.lock() {
if let Some(existing) = guard.get(&key) {
return Ok(existing.clone());
}
}
let src = src_factory();
let ptx = cudarc::nvrtc::compile_ptx(&src).gpu_ctx_with(|err| {
format!(
"tetrahedral_moments NVRTC compile (D={}, NBETA={}, NALPHA={}): {err}",
key.d, key.nbeta, key.nalpha
)
})?;
let module = self
.inner
.ctx
.load_module(ptx)
.gpu_ctx("tetrahedral_moments module load")?;
if let Ok(mut guard) = self.inner.tet_modules.lock() {
// Keep the first module inserted for this key if another caller won the race.
guard.entry(key).or_insert_with(|| module.clone());
}
Ok(module)
}
/// Returns the compiled hex-moment module for `key`.
///
/// Uses the same caching scheme as `tet_module_for`: NVRTC compilation of
/// `src_factory()` on a miss, no lock held while compiling, and a poisoned
/// lock bypasses the cache.
#[cfg(target_os = "linux")]
fn module_for(
&self,
key: HexMomentModuleKey,
src_factory: impl FnOnce() -> String,
) -> Result<std::sync::Arc<CudaModule>, GpuError> {
if let Ok(guard) = self.inner.modules.lock() {
if let Some(existing) = guard.get(&key) {
return Ok(existing.clone());
}
}
let src = src_factory();
let ptx = cudarc::nvrtc::compile_ptx(&src).gpu_ctx_with(|err| {
format!(
"cubic_bspline_moments NVRTC compile (D={}, AMAX={}, NALPHA={}): {err}",
key.d, key.amax, key.nalpha
)
})?;
let module = self
.inner
.ctx
.load_module(ptx)
.gpu_ctx("cubic_bspline_moments module load")?;
if let Ok(mut guard) = self.inner.modules.lock() {
// Keep the first module inserted for this key if another caller won the race.
guard.entry(key).or_insert_with(|| module.clone());
}
Ok(module)
}
}
/// Per-cell, per-axis inputs of the hex tensor-moment kernel.
///
/// The arrays are flattened cell-major: entry `cell * d + axis`.
#[derive(Clone, Debug)]
pub struct HexCellTable {
/// Active-span index of each cell on each axis (`n_cells * d` entries).
pub span_per_axis: Vec<i32>,
/// Unordered basis-pair slot (`0..=9`) of each cell on each axis.
pub pair_per_axis: Vec<i32>,
/// Cell width on each axis.
pub width_per_axis: Vec<f64>,
/// Number of cells.
pub n_cells: usize,
/// Number of axes.
pub d: usize,
}
impl HexCellTable {
pub fn validate(&self) -> Result<(), GpuError> {
let want = self.n_cells * self.d;
if self.span_per_axis.len() != want
|| self.pair_per_axis.len() != want
|| self.width_per_axis.len() != want
{
crate::gpu_bail!(
"HexCellTable: expected length {want} (n_cells*d), got span={}, pair={}, width={}",
self.span_per_axis.len(),
self.pair_per_axis.len(),
self.width_per_axis.len(),
);
}
Ok(())
}
}
#[cfg(target_os = "linux")]
pub fn build_hex_tensor_moments_device(
spec: &CubicMomentSpec,
axis_tables: &[Vec<AxisCubicMomentTables>],
cells: &HexCellTable,
) -> Result<DeviceCubicMomentTable, GpuError> {
use cudarc::driver::{LaunchConfig, PushKernelArg};
cells.validate()?;
if spec.d() != cells.d {
crate::gpu_bail!(
"build_hex_tensor_moments_device: spec.d()={} != cells.d={}",
spec.d(),
cells.d
);
}
if axis_tables.len() != cells.d {
crate::gpu_bail!(
"build_hex_tensor_moments_device: axis_tables.len()={} != d={}",
axis_tables.len(),
cells.d
);
}
for (axis, banks) in axis_tables.iter().enumerate() {
if banks.len() != 1 {
return Err(GpuError::NotYetImplemented {
reason: format!(
"build_hex_tensor_moments_device: axis {axis} has {} derivative banks; \
single-bank only in Phase 2 — use the CPU path or wait on Phase 3",
banks.len()
),
});
}
}
let nalpha = spec.n_alpha();
if nalpha == 0 || cells.n_cells == 0 {
return Err(GpuError::DriverCallFailed {
reason: "build_hex_tensor_moments_device: empty spec or cell list".to_string(),
});
}
let amax = spec
.alphas
.iter()
.flat_map(|row| row.iter().copied())
.max()
.unwrap_or(0) as usize;
let backend = CubicMomentBackend::probe()?;
let key = HexMomentModuleKey {
cc_major: backend.inner.cc_major,
cc_minor: backend.inner.cc_minor,
d: cells.d as u32,
amax: amax as u32,
nalpha: nalpha as u32,
alpha_hash: hash_alpha_table(&spec.alphas),
deriv_hash: hash_deriv_table(&spec.derivative_left, &spec.derivative_right),
layout_tag: layout_tag(spec.layout),
};
let d_for_src = cells.d;
let alphas_for_src = spec.alphas.clone();
let module = backend.module_for(key, move || {
build_hex_tensor_kernel_source(d_for_src, amax, &alphas_for_src)
})?;
let func = module
.load_function("cubic_hex_tensor_moments")
.gpu_ctx("cubic_bspline_moments load_function")?;
let stream = backend.inner.stream.clone();
let mut axis_offsets: Vec<i64> = Vec::with_capacity(cells.d);
let mut flat: Vec<f64> = Vec::new();
for banks in axis_tables.iter() {
axis_offsets.push(flat.len() as i64);
flat.extend_from_slice(&banks[0].prod_coeff);
}
let axis_flat_dev = stream
.clone_htod(flat.as_slice())
.gpu_ctx("cubic_bspline_moments htod axis_flat")?;
let axis_off_dev = stream
.clone_htod(axis_offsets.as_slice())
.gpu_ctx("cubic_bspline_moments htod axis_offsets")?;
let span_dev = stream
.clone_htod(cells.span_per_axis.as_slice())
.gpu_ctx("cubic_bspline_moments htod span")?;
let pair_dev = stream
.clone_htod(cells.pair_per_axis.as_slice())
.gpu_ctx("cubic_bspline_moments htod pair")?;
let width_dev = stream
.clone_htod(cells.width_per_axis.as_slice())
.gpu_ctx("cubic_bspline_moments htod width")?;
let out_stride = ((cells.n_cells + 31) / 32) * 32;
let mut out_dev = stream
.alloc_zeros::<f64>(out_stride * nalpha)
.gpu_ctx_with(|err| {
format!("cubic_bspline_moments alloc out (stride={out_stride}, nalpha={nalpha}): {err}")
})?;
let block_x: u32 = 32;
let block_y: u32 = 8;
let grid_x: u32 = ((cells.n_cells as u32) + block_x - 1) / block_x;
let grid_y: u32 = ((nalpha as u32) + block_y - 1) / block_y;
let cfg = LaunchConfig {
grid_dim: (grid_x, grid_y, 1),
block_dim: (block_x, block_y, 1),
shared_mem_bytes: 0,
};
let n_cells_i32: i32 = i32::try_from(cells.n_cells).map_err(|_| {
gpu_err!(
"cubic_bspline_moments n_cells={} overflows i32",
cells.n_cells
)
})?;
let out_stride_i64: i64 = out_stride as i64;
let mut builder = stream.launch_builder(&func);
builder
.arg(&axis_flat_dev)
.arg(&axis_off_dev)
.arg(&span_dev)
.arg(&pair_dev)
.arg(&width_dev)
.arg(&n_cells_i32)
.arg(&out_stride_i64)
.arg(&mut out_dev);
unsafe { builder.launch(cfg) }.gpu_ctx("cubic_bspline_moments kernel launch")?;
stream
.synchronize()
.gpu_ctx("cubic_bspline_moments synchronize")?;
Ok(DeviceCubicMomentTable {
n_cells: cells.n_cells,
pair_tuple_count: 1,
n_alpha: nalpha,
layout: spec.layout,
values: out_dev,
})
}
/// Non-Linux stub of the hex GPU builder: always reports the backend as
/// unavailable. The message includes a short summary of the request.
#[cfg(not(target_os = "linux"))]
pub fn build_hex_tensor_moments_device(
    spec: &CubicMomentSpec,
    axis_tables: &[Vec<AxisCubicMomentTables>],
    cells: &HexCellTable,
) -> Result<DeviceCubicMomentTable, GpuError> {
    let reason = format!(
        "cubic_bspline_moments GPU backend is Linux-only \
         (layout={:?}, axis_tables_axes={}, n_cells={})",
        spec.layout,
        axis_tables.len(),
        cells.n_cells,
    );
    Err(GpuError::DriverLibraryUnavailable { reason })
}
/// Tetrahedral decomposition of a set of cells, stored in flat host arrays.
#[derive(Clone, Debug)]
pub struct TetrahedralCellTable {
/// Tet vertex coordinates, `n_tets * 4 * d` values (four vertices per tet,
/// `d` coordinates each).
pub vertices: Vec<f64>,
/// Owning cell of each tet, in `0..n_cells` (`n_tets` entries).
pub cell_index: Vec<i32>,
/// Per-cell expansion point, `n_cells * d` values.
pub cell_centers: Vec<f64>,
/// Number of tetrahedra.
pub n_tets: usize,
/// Number of logical cells.
pub n_cells: usize,
/// Ambient dimension.
pub d: usize,
}
impl TetrahedralCellTable {
    /// Checks the array lengths and the tet→cell mapping.
    ///
    /// Conditions are checked in this order:
    /// 1. `vertices` has `n_tets * 4 * d` entries;
    /// 2. `cell_index` has `n_tets` entries;
    /// 3. `cell_centers` has `n_cells * d` entries;
    /// 4. every `cell_index` entry lies in `0..n_cells`.
    ///
    /// # Errors
    /// Returns an error describing the first violated condition.
    pub fn validate(&self) -> Result<(), GpuError> {
        let want_v = self.n_tets * 4 * self.d;
        if self.vertices.len() != want_v {
            crate::gpu_bail!(
                "TetrahedralCellTable: expected vertices len {want_v} (n_tets*4*d), got {}",
                self.vertices.len()
            );
        }
        if self.cell_index.len() != self.n_tets {
            crate::gpu_bail!(
                "TetrahedralCellTable: cell_index len {} != n_tets {}",
                self.cell_index.len(),
                self.n_tets
            );
        }
        let want_centers = self.n_cells * self.d;
        if self.cell_centers.len() != want_centers {
            crate::gpu_bail!(
                "TetrahedralCellTable: cell_centers len {} != n_cells*d {}",
                self.cell_centers.len(),
                want_centers
            );
        }
        let first_bad = self
            .cell_index
            .iter()
            .enumerate()
            .find(|&(_, &c)| c < 0 || (c as usize) >= self.n_cells);
        if let Some((i, &c)) = first_bad {
            crate::gpu_bail!(
                "TetrahedralCellTable: cell_index[{i}] = {c} out of range [0, {})",
                self.n_cells
            );
        }
        Ok(())
    }
}
/// Description of the tetrahedral moments to compute.
#[derive(Clone, Debug)]
pub struct TetrahedralMomentSpec {
/// Geometric exponent rows β, one per geometric moment (`d` entries each).
pub geom_betas: Vec<Vec<u8>>,
/// Output exponent rows α (`d` entries each).
pub alphas: Vec<Vec<u8>>,
/// Pair tuples per cell. NOTE(review): presumably the `PAIRS` dimension of
/// the stage-2 contraction kernel; confirm against the builder.
pub pairs_per_cell: usize,
/// Requested output layout.
pub layout: MomentLayout,
}
impl TetrahedralMomentSpec {
    /// Spatial dimension.
    ///
    /// Uses the length of the first α row; if there are no α rows, the length
    /// of the first β row; if both tables are empty, `0`.
    pub fn d(&self) -> usize {
        if let Some(row) = self.alphas.first() {
            return row.len();
        }
        self.geom_betas.first().map_or(0, Vec::len)
    }

    /// Number of α rows.
    pub fn n_alpha(&self) -> usize {
        self.alphas.len()
    }

    /// Number of geometric β rows.
    pub fn n_beta(&self) -> usize {
        self.geom_betas.len()
    }
}
/// Hash of a β table, used in the tetrahedral NVRTC module cache key.
///
/// Feeds the row count, then each row's length followed by its entries, into
/// a `DefaultHasher`.
#[cfg(target_os = "linux")]
fn hash_beta_table(betas: &[Vec<u8>]) -> u64 {
    let mut state = DefaultHasher::new();
    (betas.len() as u64).hash(&mut state);
    for row in betas.iter() {
        let row_len = row.len() as u64;
        row_len.hash(&mut state);
        for value in row.iter().copied() {
            value.hash(&mut state);
        }
    }
    state.finish()
}
/// Cache key for compiled tetrahedral-moment NVRTC modules (see
/// `CubicMomentBackend::tet_module_for`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg(target_os = "linux")]
struct TetMomentModuleKey {
/// Compute capability (major) of the target device.
cc_major: i32,
/// Compute capability (minor) of the target device.
cc_minor: i32,
/// Kernel-variant discriminator. NOTE(review): the value assignments live
/// in code not visible here (presumably geometry vs. contraction stage).
kind: u8, d: u32,
/// Number of β rows (`NBETA`).
nbeta: u32,
/// Number of α rows (`NALPHA`).
nalpha: u32,
/// Pair count (`PAIRS`).
pairs: u32,
/// Hash of the β table baked into the kernel.
beta_hash: u64,
/// Hash of the α table.
alpha_hash: u64,
/// `layout_tag` of the requested output layout.
layout_tag: u8,
}
/// `n!` as `f64`, computed as the running product `2 · 3 · … · n`
/// (`0! = 1! = 1`).
#[inline]
fn fact_f64(n: u32) -> f64 {
    (2..=n).fold(1.0f64, |acc, k| acc * k as f64)
}
/// Integral of `u1^n1 · u2^n2 · u3^n3` over the reference simplex
/// `{u_i >= 0, u1 + u2 + u3 <= 1}`, i.e. `n1! n2! n3! / (n1 + n2 + n3 + 3)!`.
#[inline]
fn dirichlet_ref_simplex(n1: u32, n2: u32, n3: u32) -> f64 {
    let numerator = fact_f64(n1) * fact_f64(n2) * fact_f64(n3);
    let denominator = fact_f64(n1 + n2 + n3 + 3);
    numerator / denominator
}
/// Geometric moment `∫_T (x - c)^β dx` of one tetrahedron `T`.
///
/// Arguments:
/// - `vertices`: the four vertices `v0..v3` back to back (`4 * d` coordinates).
/// - `cell_center`: the expansion point `c`.
/// - `beta`: the per-axis exponents.
///
/// With `x = v0 + B u` and `B = [v1 - v0, v2 - v0, v3 - v0]`, the integrand
/// is `Π_r (q_r + Σ_i e_{i,r} u_i)^{β_r}` with `q = v0 - c`. This polynomial
/// in `u` is expanded densely and integrated monomial by monomial over the
/// reference simplex, using `∫ u1^i u2^j u3^k du = i! j! k! / (i+j+k+3)!`.
/// The Jacobian factor is `|det B|` for `d == 3` and `sqrt(det(BᵀB))` for
/// `d > 3`.
///
/// # Panics
/// If `vertices`, `cell_center` or `beta` do not match `d`, or if `d < 3`.
pub fn tetrahedral_geom_moment_cpu(
    vertices: &[f64],
    cell_center: &[f64],
    beta: &[u8],
    d: usize,
) -> f64 {
    assert_eq!(vertices.len(), 4 * d);
    assert_eq!(cell_center.len(), d);
    assert_eq!(beta.len(), d);
    // Checked before any work is done; the 3×3 determinant below needs d >= 3.
    assert!(d >= 3, "tetrahedral path requires D ≥ 3");

    // q = v0 - c and the three edge vectors e_i = v_{i+1} - v0, per axis.
    let mut q = vec![0.0f64; d];
    let mut e = [vec![0.0f64; d], vec![0.0f64; d], vec![0.0f64; d]];
    for r in 0..d {
        q[r] = vertices[r] - cell_center[r];
        for i in 0..3 {
            e[i][r] = vertices[(i + 1) * d + r] - vertices[r];
        }
    }

    let det_b = if d == 3 {
        let m = [
            [e[0][0], e[1][0], e[2][0]],
            [e[0][1], e[1][1], e[2][1]],
            [e[0][2], e[1][2], e[2][2]],
        ];
        (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))
            .abs()
    } else {
        // Gram matrix G = BᵀB; sqrt(det G) is the 3-volume scale factor.
        let mut g = [[0.0f64; 3]; 3];
        for i in 0..3 {
            for j in 0..3 {
                let mut acc = 0.0;
                for r in 0..d {
                    acc += e[i][r] * e[j][r];
                }
                g[i][j] = acc;
            }
        }
        let det_g = g[0][0] * (g[1][1] * g[2][2] - g[1][2] * g[2][1])
            - g[0][1] * (g[1][0] * g[2][2] - g[1][2] * g[2][0])
            + g[0][2] * (g[1][0] * g[2][1] - g[1][1] * g[2][0]);
        det_g.max(0.0).sqrt()
    };

    // Dense coefficient cube: poly[i + stride*(j + stride*k)] multiplies
    // u1^i u2^j u3^k. The total degree never exceeds the β-sum, so a per-axis
    // cap of beta_total never truncates a real term.
    let beta_total: u32 = beta.iter().map(|&v| v as u32).sum();
    let stride = beta_total as usize + 1;
    let size = stride * stride * stride;
    let idx = |i: usize, j: usize, k: usize| i + stride * (j + stride * k);
    // Two ping-pong buffers reused across all multiplications, instead of a
    // fresh allocation per factor.
    let mut poly = vec![0.0f64; size];
    let mut next = vec![0.0f64; size];
    poly[0] = 1.0;
    for r in 0..d {
        let br = beta[r];
        if br == 0 {
            continue;
        }
        let (qr, e1, e2, e3) = (q[r], e[0][r], e[1][r], e[2][r]);
        for _ in 0..br {
            // next = poly * (q_r + e1 u1 + e2 u2 + e3 u3)
            next.fill(0.0);
            for k in 0..stride {
                for j in 0..stride {
                    for i in 0..stride {
                        let v = poly[idx(i, j, k)];
                        if v == 0.0 {
                            continue;
                        }
                        next[idx(i, j, k)] += qr * v;
                        if i + 1 < stride {
                            next[idx(i + 1, j, k)] += e1 * v;
                        }
                        if j + 1 < stride {
                            next[idx(i, j + 1, k)] += e2 * v;
                        }
                        if k + 1 < stride {
                            next[idx(i, j, k + 1)] += e3 * v;
                        }
                    }
                }
            }
            std::mem::swap(&mut poly, &mut next);
        }
    }

    let mut acc = 0.0f64;
    for k in 0..stride {
        for j in 0..stride {
            for i in 0..stride {
                let coeff = poly[idx(i, j, k)];
                if coeff == 0.0 {
                    continue;
                }
                acc += coeff * dirichlet_ref_simplex(i as u32, j as u32, k as u32);
            }
        }
    }
    acc * det_b
}
/// Generates the CUDA C source of `tetrahedral_geom_moments_kernel`, stage 1
/// of the tetrahedral pipeline.
///
/// Thread `(tet, β)` computes the same quantity as
/// `tetrahedral_geom_moment_cpu`, namely `∫_tet (x - c0)^β dx` with `c0` the
/// center of the tet's cell. It writes the result to `out[β * out_stride + tet]`.
/// `betas` is baked in as `BETA_TABLE[NBETA][D]`, and `STRIDE` is one more
/// than the largest β-sum.
///
/// The factorial table `FACT_LUT` is generated with `STRIDE + 3` entries
/// (`0!` to `(STRIDE+2)!`), computed with the same running product as
/// `fact_f64` so GPU and CPU use identical factorials. The previous fixed
/// 13-entry table was read out of bounds once a β-sum exceeded 9. The
/// Dirichlet loop also stops at total degree `STRIDE - 1`, so every lookup
/// index stays inside the table.
///
/// NOTE(review): factorials overflow `f64` beyond 170!, which would need
/// β-sums above 167 — far past any feasible `POLYLEN`.
#[cfg(target_os = "linux")]
fn build_tet_geom_kernel_source(d: usize, betas: &[Vec<u8>]) -> String {
    let nbeta = betas.len();
    let beta_total_max: u32 = betas
        .iter()
        .map(|row| row.iter().map(|&v| v as u32).sum::<u32>())
        .max()
        .unwrap_or(0);
    let stride = (beta_total_max + 1) as usize;
    let poly_len = stride * stride * stride;
    // dirichlet_3 reads FACT_LUT[n1 + n2 + n3 + 3] with n1 + n2 + n3 <= STRIDE - 1.
    let fact_len = stride + 3;
    let fact_decl = {
        let mut running = 1.0f64;
        let mut entries = Vec::with_capacity(fact_len);
        for n in 0..fact_len {
            if n >= 2 {
                running *= n as f64;
            }
            // `{:e}` prints the shortest round-trip form, a valid C double literal.
            entries.push(format!("{running:e}"));
        }
        entries.join(", ")
    };
    let mut beta_decl = String::new();
    beta_decl.push_str("__constant__ unsigned char BETA_TABLE[NBETA][D] = {\n");
    for row in betas {
        beta_decl.push_str(" { ");
        for (k, v) in row.iter().enumerate() {
            if k > 0 {
                beta_decl.push_str(", ");
            }
            beta_decl.push_str(&v.to_string());
        }
        beta_decl.push_str(" },\n");
    }
    beta_decl.push_str("};\n");
    format!(
        r#"
#define D {d}
#define NBETA {nbeta}
#define STRIDE {stride}
#define POLYLEN {poly_len}
{beta_decl}
// Dense polynomial in (u_1, u_2, u_3) with per-axis degree cap STRIDE-1.
// Coefficient of u_1^i u_2^j u_3^k stored at poly[i + STRIDE*(j + STRIDE*k)].
// Factorials 0! .. (STRIDE+2)!, generated on the host from the largest β-sum
// (STRIDE-1). dirichlet_3 is only called for monomials with i + j + k < STRIDE,
// so every index it uses (at most STRIDE+2) is inside the table.
__constant__ double FACT_LUT[{fact_len}] = {{
{fact_decl}
}};
__device__ __forceinline__ double dirichlet_3(int n1, int n2, int n3) {{
int s = n1 + n2 + n3 + 3;
return FACT_LUT[n1] * FACT_LUT[n2] * FACT_LUT[n3] / FACT_LUT[s];
}}
// Inputs (all device-resident):
// vertices_flat: f64[n_tets * 4 * D] — tet vertex coordinates.
// cell_index: i32[n_tets] — logical cell per tet; selects the expansion
// point c0 from cell_centers.
// cell_centers: f64[n_cells * D] — per-cell expansion point c0.
// Output:
// out: f64[NBETA * out_stride], β-major; thread (tet, β) writes
// out[β * out_stride + tet].
extern "C" __global__ void tetrahedral_geom_moments_kernel(
const double *vertices_flat,
const int *cell_index,
const double *cell_centers,
int n_tets,
long long out_stride,
double *out
) {{
const int tet = blockIdx.x * blockDim.x + threadIdx.x;
const int bidx = blockIdx.y * blockDim.y + threadIdx.y;
if (tet >= n_tets || bidx >= NBETA) return;
const int cell = cell_index[tet];
const double *vptr = vertices_flat + (long long)tet * 4LL * (long long)D;
const double *c0ptr = cell_centers + (long long)cell * (long long)D;
// Per-axis q[r] and e[i][r] (i = 0..2 → edges v1−v0, v2−v0, v3−v0).
double q[D];
double e[3][D];
#pragma unroll
for (int r = 0; r < D; ++r) {{
const double v0r = vptr[r];
q[r] = v0r - c0ptr[r];
e[0][r] = vptr[1*D + r] - v0r;
e[1][r] = vptr[2*D + r] - v0r;
e[2][r] = vptr[3*D + r] - v0r;
}}
// |det B| (D >= 3). For D > 3 fall back to sqrt(det(BᵀB)). Compile-time
// branched so the D = 3 fast path stays a single 3×3 cofactor expansion.
double det_b;
#if D == 3
det_b = fabs(
e[0][0] * (e[1][1] * e[2][2] - e[1][2] * e[2][1])
- e[1][0] * (e[0][1] * e[2][2] - e[0][2] * e[2][1])
+ e[2][0] * (e[0][1] * e[1][2] - e[0][2] * e[1][1])
);
#else
double g[3][3];
#pragma unroll
for (int i = 0; i < 3; ++i) {{
#pragma unroll
for (int j = 0; j < 3; ++j) {{
double acc = 0.0;
#pragma unroll
for (int r = 0; r < D; ++r) acc += e[i][r] * e[j][r];
g[i][j] = acc;
}}
}}
double det_g =
g[0][0] * (g[1][1] * g[2][2] - g[1][2] * g[2][1])
- g[0][1] * (g[1][0] * g[2][2] - g[1][2] * g[2][0])
+ g[0][2] * (g[1][0] * g[2][1] - g[1][1] * g[2][0]);
det_b = sqrt(det_g > 0.0 ? det_g : 0.0);
#endif
// Dense polynomial buffer in (u_1, u_2, u_3). Starts as the constant
// polynomial 1, then iteratively multiplied by (q_r + Σ_i e_{{i,r}} u_i)
// for each axis r, β_r times. We use two ping-pong buffers in local
// memory; the compiler keeps them in registers when POLYLEN is small.
double poly_a[POLYLEN];
double poly_b[POLYLEN];
#pragma unroll
for (int t = 0; t < POLYLEN; ++t) {{ poly_a[t] = 0.0; poly_b[t] = 0.0; }}
poly_a[0] = 1.0;
bool a_is_src = true;
#pragma unroll
for (int r = 0; r < D; ++r) {{
const unsigned int br = (unsigned int)BETA_TABLE[bidx][r];
const double qr = q[r];
const double e1 = e[0][r];
const double e2 = e[1][r];
const double e3 = e[2][r];
for (unsigned int rep = 0; rep < br; ++rep) {{
double *src = a_is_src ? poly_a : poly_b;
double *dst = a_is_src ? poly_b : poly_a;
#pragma unroll
for (int t = 0; t < POLYLEN; ++t) dst[t] = 0.0;
// Shift-add: dst = qr*src + e1*shift_i(src) + e2*shift_j(src) + e3*shift_k(src)
for (int k = 0; k < STRIDE; ++k) {{
for (int j = 0; j < STRIDE; ++j) {{
for (int i = 0; i < STRIDE; ++i) {{
const int idx = i + STRIDE * (j + STRIDE * k);
const double v = src[idx];
if (v == 0.0) continue;
dst[idx] += qr * v;
if (i + 1 < STRIDE) dst[(i+1) + STRIDE*(j + STRIDE*k)] += e1 * v;
if (j + 1 < STRIDE) dst[i + STRIDE*((j+1) + STRIDE*k)] += e2 * v;
if (k + 1 < STRIDE) dst[i + STRIDE*(j + STRIDE*(k+1))] += e3 * v;
}}
}}
}}
a_is_src = !a_is_src;
}}
}}
const double *final_poly = a_is_src ? poly_a : poly_b;
double acc = 0.0;
for (int k = 0; k < STRIDE; ++k) {{
for (int j = 0; j < STRIDE; ++j) {{
for (int i = 0; i < STRIDE; ++i) {{
// Total degree never exceeds the β-sum (<= STRIDE-1); this also keeps
// every FACT_LUT index in range.
if (i + j + k >= STRIDE) break;
const double coeff = final_poly[i + STRIDE * (j + STRIDE * k)];
if (coeff == 0.0) continue;
acc = fma(coeff, dirichlet_3(i, j, k), acc);
}}
}}
}}
out[(long long)bidx * out_stride + (long long)tet] = acc * det_b;
}}
"#,
        d = d,
        nbeta = nbeta,
        stride = stride,
        poly_len = poly_len,
        beta_decl = beta_decl,
        fact_len = fact_len,
        fact_decl = fact_decl,
    )
}
/// Generates the CUDA C source of `tetrahedral_contract_kernel`, stage 2 of
/// the tetrahedral pipeline.
///
/// Thread `(cell, α, pair)` computes
/// `Σ_{tet of cell} Σ_β weights[cell, α, β, pair] · geom[β, tet]` and stores it
/// at `out[(α * PAIRS + pair) * out_stride + cell]`. The tets of a cell are
/// listed CSR-style: `tet_index_in_segment[tet_offsets[c]..tet_offsets[c+1]]`
/// holds tet ids. This is not the per-tet `cell_index` array of stage 1. The
/// embedded kernel comment previously described it as such, under a parameter
/// name (`tet_to_cell`) that the kernel does not have.
#[cfg(target_os = "linux")]
fn build_tet_contract_kernel_source(nalpha: usize, nbeta: usize, pairs: usize) -> String {
    format!(
        r#"
#define NALPHA {nalpha}
#define NBETA {nbeta}
#define PAIRS {pairs}
// Inputs:
// geom: f64[NBETA * geom_stride] β-major; geom[β*geom_stride + tet] = G_β(tet).
// weights: f64[n_cells * NALPHA * NBETA * PAIRS]
// row-major (cell, α, β, pair). Caller supplies the basis-Gram tensor
// per cell — for tensor-product cubics this is the same coefficient
// array used by the hex kernel, expressed as a coefficient on the
// geometric monomial basis. The same weights apply to every tet of the cell.
// tet_offsets: i32[n_cells + 1] — CSR-style segment offsets into
// tet_index_in_segment.
// tet_index_in_segment: i32[tet_offsets[n_cells]] — tet ids grouped by cell;
// the kernel walks all tets of cell `c` as
// tet_index_in_segment[tet_offsets[c]..tet_offsets[c+1]]
// (this is NOT the per-tet cell_index array used by stage 1).
//
// Output:
// out: f64[NALPHA * PAIRS * out_stride], α-major along outermost, then
// pair, then cell; thread (cell, α, pair) writes
// out[(α * PAIRS + pair) * out_stride + cell].
extern "C" __global__ void tetrahedral_contract_kernel(
const double *geom,
long long geom_stride,
const double *weights,
const int *tet_offsets,
const int *tet_index_in_segment,
int n_cells,
long long out_stride,
double *out
) {{
const int cell = blockIdx.x * blockDim.x + threadIdx.x;
const int alpha = blockIdx.y * blockDim.y + threadIdx.y;
const int pair = blockIdx.z * blockDim.z + threadIdx.z;
if (cell >= n_cells || alpha >= NALPHA || pair >= PAIRS) return;
const int beg = tet_offsets[cell];
const int end = tet_offsets[cell + 1];
// Per-cell weight base: weights[cell, α, β, pair] flattened.
const long long w_base = ((long long)cell * (long long)NALPHA + (long long)alpha)
* (long long)NBETA * (long long)PAIRS
+ (long long)pair;
double acc = 0.0;
for (int t = beg; t < end; ++t) {{
const int tet = tet_index_in_segment[t];
#pragma unroll
for (int b = 0; b < NBETA; ++b) {{
const double g_b = geom[(long long)b * geom_stride + (long long)tet];
const double w = weights[w_base + (long long)b * (long long)PAIRS];
acc = fma(w, g_b, acc);
}}
}}
out[((long long)alpha * (long long)PAIRS + (long long)pair) * out_stride
+ (long long)cell] = acc;
}}
"#,
        nalpha = nalpha,
        nbeta = nbeta,
        pairs = pairs,
    )
}
/// Host-side inputs for the device tetrahedral moment pipeline.
///
/// The slice lengths documented below are the ones `build_tetrahedral_moments_device`
/// checks before it uploads anything to the device.
#[derive(Clone, Debug)]
pub struct TetrahedralMomentInputs<'a> {
    /// Moment specification: α table, geometric β table, pair count and output layout.
    pub spec: &'a TetrahedralMomentSpec,
    /// Tetrahedra (vertices, owning-cell index, cell centers) plus the tet/cell counts and `d`.
    pub cells: &'a TetrahedralCellTable,
    /// CSR-style segment offsets, length `cells.n_cells + 1`; cell `c` owns entries
    /// `tet_offsets[c]..tet_offsets[c + 1]` of `tet_index_in_segment`.
    pub tet_offsets: &'a [i32],
    /// Tet ids grouped by cell (length `cells.n_tets`); each entry is a tet index into the
    /// per-tet geometric moments.
    pub tet_index_in_segment: &'a [i32],
    /// Per-cell contraction weights, length `n_cells * n_alpha * n_beta * pairs_per_cell`,
    /// row-major `(cell, α, β, pair)`.
    pub weights: &'a [f64],
}
#[cfg(target_os = "linux")]
pub fn try_device_tetrahedral_moments(
inputs: &TetrahedralMomentInputs<'_>,
) -> Result<Option<DeviceCubicMomentTable>, GpuError> {
let backend = match CubicMomentBackend::probe() {
Ok(b) => b,
Err(GpuError::DriverLibraryUnavailable { .. }) => return Ok(None),
Err(other) => return Err(other),
};
build_tetrahedral_moments_device(backend, inputs).map(Some)
}
/// Non-Linux stub: the CUDA backend is not compiled on this platform.
///
/// It only validates the inputs, then reports "no device result" with `Ok(None)`.
///
/// # Errors
/// Returns the cell-table validation error, or an error when the dimension is below 3.
#[cfg(not(target_os = "linux"))]
pub fn try_device_tetrahedral_moments(
    inputs: &TetrahedralMomentInputs<'_>,
) -> Result<Option<DeviceCubicMomentTable>, GpuError> {
    let cells = inputs.cells;
    cells.validate()?;
    let dim = cells.d;
    if dim < 3 {
        crate::gpu_bail!(
            "try_device_tetrahedral_moments: tetrahedral path requires D >= 3 (got {})",
            dim
        );
    }
    Ok(None)
}
/// Runs the two-stage tetrahedral moment pipeline on `backend`'s stream.
///
/// Stage 1 launches `tetrahedral_geom_moments_kernel`. It fills a β-major geometry buffer
/// (`nbeta` rows, row pitch = `n_tets` rounded up to a multiple of 32) from the tet vertices,
/// owning-cell indices and cell centers.
///
/// Stage 2 launches `tetrahedral_contract_kernel`. It contracts that buffer against
/// `inputs.weights` over each cell's tet segment. The result is an `(α, pair, cell)` table
/// whose cell pitch is `n_cells` rounded up to a multiple of 32.
///
/// Both modules are obtained through `backend.tet_module_for`, keyed by compute capability
/// and the baked-in table shapes/hashes. The stream is synchronized before returning.
///
/// # Errors
/// Returns an error when:
/// - the cell table fails `validate`, `spec.d() != cells.d`, or `D < 3`;
/// - the spec or cell list is empty, or a slice in `inputs` has the wrong length;
/// - `tet_offsets` is not non-decreasing within `[0, n_tets]`, or a referenced entry of
///   `tet_index_in_segment` lies outside `[0, n_tets)` (the kernel would read out of bounds);
/// - a size does not fit the kernels' `i32` / launch `u32` arguments or overflows `usize`;
/// - any driver call (module build/load, copy, allocation, launch, synchronize) fails.
#[cfg(target_os = "linux")]
fn build_tetrahedral_moments_device(
    backend: &CubicMomentBackend,
    inputs: &TetrahedralMomentInputs<'_>,
) -> Result<DeviceCubicMomentTable, GpuError> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};

    // ---- Host-side validation: everything the kernels index without bounds checks. ----
    inputs.cells.validate()?;
    let spec = inputs.spec;
    let cells = inputs.cells;
    if spec.d() != cells.d {
        crate::gpu_bail!(
            "build_tetrahedral_moments_device: spec.d()={} != cells.d={}",
            spec.d(),
            cells.d
        );
    }
    if cells.d < 3 {
        crate::gpu_bail!(
            "build_tetrahedral_moments_device: tetrahedral path requires D >= 3 (got {})",
            cells.d
        );
    }
    let nalpha = spec.n_alpha();
    let nbeta = spec.n_beta();
    let pairs = spec.pairs_per_cell;
    if nalpha == 0 || nbeta == 0 || pairs == 0 || cells.n_tets == 0 || cells.n_cells == 0 {
        crate::gpu_bail!(
            "build_tetrahedral_moments_device: empty spec or cell list \
             (nalpha={nalpha}, nbeta={nbeta}, pairs={pairs}, n_tets={}, n_cells={})",
            cells.n_tets,
            cells.n_cells
        );
    }
    let want_off = cells.n_cells + 1;
    if inputs.tet_offsets.len() != want_off {
        crate::gpu_bail!(
            "build_tetrahedral_moments_device: tet_offsets len {} != n_cells+1 {}",
            inputs.tet_offsets.len(),
            want_off
        );
    }
    if inputs.tet_index_in_segment.len() != cells.n_tets {
        crate::gpu_bail!(
            "build_tetrahedral_moments_device: tet_index_in_segment len {} != n_tets {}",
            inputs.tet_index_in_segment.len(),
            cells.n_tets
        );
    }
    let want_w = cells
        .n_cells
        .checked_mul(nalpha)
        .and_then(|n| n.checked_mul(nbeta))
        .and_then(|n| n.checked_mul(pairs))
        .ok_or_else(|| {
            gpu_err!(
                "build_tetrahedral_moments_device: n_cells={} * nalpha={} * nbeta={} * pairs={} overflows usize",
                cells.n_cells,
                nalpha,
                nbeta,
                pairs
            )
        })?;
    if inputs.weights.len() != want_w {
        crate::gpu_bail!(
            "build_tetrahedral_moments_device: weights len {} != n_cells*nalpha*nbeta*pairs {}",
            inputs.weights.len(),
            want_w
        );
    }
    validate_tet_segment_arrays(inputs.tet_offsets, inputs.tet_index_in_segment, cells.n_tets)?;
    // Range-check the counts before any device work: the kernels take them as i32, and the
    // launch-grid arithmetic below runs in u32.
    let n_tets_i32: i32 = i32::try_from(cells.n_tets)
        .map_err(|_| gpu_err!("tetrahedral_moments n_tets={} overflows i32", cells.n_tets))?;
    let n_cells_i32: i32 = i32::try_from(cells.n_cells)
        .map_err(|_| gpu_err!("tetrahedral_moments n_cells={} overflows i32", cells.n_cells))?;

    let stream = backend.inner.stream.clone();

    // ---- Stage 1: per-tet geometric moments G_β(tet). ----
    let geom_key = TetMomentModuleKey {
        cc_major: backend.inner.cc_major,
        cc_minor: backend.inner.cc_minor,
        kind: 0,
        d: cells.d as u32,
        nbeta: nbeta as u32,
        nalpha: 0,
        pairs: 0,
        beta_hash: hash_beta_table(&spec.geom_betas),
        alpha_hash: 0,
        layout_tag: layout_tag(spec.layout),
    };
    let d_for_geom = cells.d;
    let betas_for_geom = spec.geom_betas.clone();
    let geom_module = backend.tet_module_for(geom_key, move || {
        build_tet_geom_kernel_source(d_for_geom, &betas_for_geom)
    })?;
    let geom_func = geom_module
        .load_function("tetrahedral_geom_moments_kernel")
        .gpu_ctx("tetrahedral_moments load_function geom")?;
    let vertices_dev = stream
        .clone_htod(cells.vertices.as_slice())
        .gpu_ctx("tetrahedral_moments htod vertices")?;
    let cell_index_dev = stream
        .clone_htod(cells.cell_index.as_slice())
        .gpu_ctx("tetrahedral_moments htod cell_index")?;
    let cell_centers_dev = stream
        .clone_htod(cells.cell_centers.as_slice())
        .gpu_ctx("tetrahedral_moments htod cell_centers")?;
    // Row pitch padded to a multiple of 32 doubles.
    let geom_stride = ((cells.n_tets + 31) / 32) * 32;
    let mut geom_dev = stream
        .alloc_zeros::<f64>(geom_stride * nbeta)
        .gpu_ctx_with(|err| {
            format!("tetrahedral_moments alloc geom (stride={geom_stride}, nbeta={nbeta}): {err}")
        })?;
    let block_x: u32 = 32;
    let block_y: u32 = 8;
    let cfg_geom = LaunchConfig {
        grid_dim: (
            tet_launch_grid_extent(cells.n_tets, block_x, "n_tets")?,
            tet_launch_grid_extent(nbeta, block_y, "nbeta")?,
            1,
        ),
        block_dim: (block_x, block_y, 1),
        shared_mem_bytes: 0,
    };
    let geom_stride_i64: i64 = geom_stride as i64;
    let mut builder = stream.launch_builder(&geom_func);
    builder
        .arg(&vertices_dev)
        .arg(&cell_index_dev)
        .arg(&cell_centers_dev)
        .arg(&n_tets_i32)
        .arg(&geom_stride_i64)
        .arg(&mut geom_dev);
    unsafe { builder.launch(cfg_geom) }.gpu_ctx("tetrahedral_moments geom kernel launch")?;

    // ---- Stage 2: contract G_β against the per-cell weights over each cell's tets. ----
    let con_key = TetMomentModuleKey {
        cc_major: backend.inner.cc_major,
        cc_minor: backend.inner.cc_minor,
        kind: 1,
        d: cells.d as u32,
        nbeta: nbeta as u32,
        nalpha: nalpha as u32,
        pairs: pairs as u32,
        beta_hash: hash_beta_table(&spec.geom_betas),
        alpha_hash: hash_alpha_table(&spec.alphas),
        layout_tag: layout_tag(spec.layout),
    };
    let con_module = backend.tet_module_for(con_key, move || {
        build_tet_contract_kernel_source(nalpha, nbeta, pairs)
    })?;
    let con_func = con_module
        .load_function("tetrahedral_contract_kernel")
        .gpu_ctx("tetrahedral_moments load_function contract")?;
    let weights_dev = stream
        .clone_htod(inputs.weights)
        .gpu_ctx("tetrahedral_moments htod weights")?;
    let offsets_dev = stream
        .clone_htod(inputs.tet_offsets)
        .gpu_ctx("tetrahedral_moments htod offsets")?;
    let segidx_dev = stream
        .clone_htod(inputs.tet_index_in_segment)
        .gpu_ctx("tetrahedral_moments htod segidx")?;
    // Cell pitch padded to a multiple of 32 doubles.
    let out_stride = ((cells.n_cells + 31) / 32) * 32;
    let mut out_dev = stream
        .alloc_zeros::<f64>(out_stride * nalpha * pairs)
        .gpu_ctx_with(|err| {
            format!(
                "tetrahedral_moments alloc out (stride={out_stride}, nalpha={nalpha}, pairs={pairs}): {err}"
            )
        })?;
    let block_cx: u32 = 16;
    let block_cy: u32 = 4;
    let block_cz: u32 = 4;
    let cfg_con = LaunchConfig {
        grid_dim: (
            tet_launch_grid_extent(cells.n_cells, block_cx, "n_cells")?,
            tet_launch_grid_extent(nalpha, block_cy, "nalpha")?,
            tet_launch_grid_extent(pairs, block_cz, "pairs")?,
        ),
        block_dim: (block_cx, block_cy, block_cz),
        shared_mem_bytes: 0,
    };
    let out_stride_i64: i64 = out_stride as i64;
    let mut builder = stream.launch_builder(&con_func);
    builder
        .arg(&geom_dev)
        .arg(&geom_stride_i64)
        .arg(&weights_dev)
        .arg(&offsets_dev)
        .arg(&segidx_dev)
        .arg(&n_cells_i32)
        .arg(&out_stride_i64)
        .arg(&mut out_dev);
    unsafe { builder.launch(cfg_con) }.gpu_ctx("tetrahedral_moments contract kernel launch")?;
    stream
        .synchronize()
        .gpu_ctx("tetrahedral_moments synchronize")?;
    Ok(DeviceCubicMomentTable {
        n_cells: cells.n_cells,
        pair_tuple_count: pairs,
        n_alpha: nalpha,
        layout: spec.layout,
        values: out_dev,
    })
}

/// Checks that the CSR segment arrays only reference memory the contraction kernel may read.
///
/// The kernel walks `tet_index_in_segment[tet_offsets[c] .. tet_offsets[c + 1]]` for each
/// cell and uses every entry as a column of the geometry buffer, without bounds checks.
/// This requires two things:
/// - the offsets are non-decreasing and lie within `[0, tet_index_in_segment.len()]`;
/// - every entry inside the referenced range lies in `[0, n_tets)`.
#[cfg(target_os = "linux")]
fn validate_tet_segment_arrays(
    tet_offsets: &[i32],
    tet_index_in_segment: &[i32],
    n_tets: usize,
) -> Result<(), GpuError> {
    let seg_len = tet_index_in_segment.len();
    let mut lo = 0usize;
    for (i, &raw) in tet_offsets.iter().enumerate() {
        match usize::try_from(raw) {
            Ok(off) if off >= lo && off <= seg_len => lo = off,
            _ => {
                crate::gpu_bail!(
                    "build_tetrahedral_moments_device: tet_offsets[{}]={} must be non-decreasing and within [0, {}]",
                    i,
                    raw,
                    seg_len
                )
            }
        }
    }
    if let (Some(&first), Some(&last)) = (tet_offsets.first(), tet_offsets.last()) {
        // Both were validated above as non-negative, <= seg_len, and first <= last.
        let first = first as usize;
        let last = last as usize;
        for (pos, &tet) in tet_index_in_segment[first..last].iter().enumerate() {
            if !matches!(usize::try_from(tet), Ok(t) if t < n_tets) {
                crate::gpu_bail!(
                    "build_tetrahedral_moments_device: tet_index_in_segment[{}]={} outside [0, {})",
                    first + pos,
                    tet,
                    n_tets
                )
            }
        }
    }
    Ok(())
}

/// Number of `block`-wide thread blocks needed to cover `count` items, as a `u32` grid extent.
///
/// `block` must be non-zero. Errors instead of wrapping when `count` (or the rounded-up
/// block count) does not fit in `u32`.
#[cfg(target_os = "linux")]
fn tet_launch_grid_extent(count: usize, block: u32, what: &str) -> Result<u32, GpuError> {
    debug_assert!(block > 0, "block size must be non-zero");
    u32::try_from(count)
        .ok()
        .and_then(|c| c.checked_add(block - 1))
        .map(|c| c / block)
        .ok_or_else(|| {
            gpu_err!(
                "tetrahedral_moments {}={} does not fit a u32 launch grid",
                what,
                count
            )
        })
}
#[cfg(test)]
mod cubic_bspline_moments_tests {
    use super::*;

    /// Clamped (open-uniform) cubic knot vector on [0, 1] with `n_basis` basis functions.
    ///
    /// It has `DEGREE + 1` repeated knots at each end and `n_basis - DEGREE - 1` equally
    /// spaced interior knots, for a total length of `n_basis + DEGREE + 1`.
    fn open_uniform_knots(n_basis: usize) -> Vec<f64> {
        let n_int = n_basis - DEGREE;
        let mut t = Vec::with_capacity(n_basis + DEGREE + 1);
        for _ in 0..=DEGREE {
            t.push(0.0);
        }
        for i in 1..n_int {
            t.push(i as f64 / n_int as f64);
        }
        for _ in 0..=DEGREE {
            t.push(1.0);
        }
        t
    }

    /// Clamped cubic knot vector on [-2, 3] with five unevenly spaced interior knots.
    fn nonuniform_knots() -> Vec<f64> {
        let interior = [-1.7, -0.4, 0.1, 0.9, 1.55];
        let mut t = Vec::new();
        for _ in 0..=DEGREE {
            t.push(-2.0);
        }
        t.extend_from_slice(&interior);
        for _ in 0..=DEGREE {
            t.push(3.0);
        }
        t
    }

    /// True for the error variants the GPU tests treat as "no usable CUDA runtime here".
    #[cfg(target_os = "linux")]
    fn is_cuda_unavailable(err: &GpuError) -> bool {
        matches!(
            err,
            GpuError::DriverLibraryUnavailable { .. }
                | GpuError::DriverCallFailed { .. }
                | GpuError::NotYetImplemented { .. }
        )
    }

    /// Asserts `|got - expected| <= abs + rel * max(|expected|, 1)`, with both values finite.
    macro_rules! assert_close {
        ($label:expr, $got:expr, $expected:expr, $rel:expr, $abs:expr $(,)?) => {{
            let got_v: f64 = $got;
            let expected_v: f64 = $expected;
            let rel_v: f64 = $rel;
            let abs_v: f64 = $abs;
            assert!(
                got_v.is_finite() && expected_v.is_finite(),
                "{}: non-finite (got={}, expected={})",
                $label,
                got_v,
                expected_v
            );
            let diff = (got_v - expected_v).abs();
            let bound = abs_v + rel_v * expected_v.abs().max(1.0);
            assert!(
                diff <= bound,
                "{}: |{} - {}| = {} exceeds tol abs={}, rel={} (bound {})",
                $label,
                got_v,
                expected_v,
                diff,
                abs_v,
                rel_v,
                bound
            );
        }};
    }

    /// On every non-degenerate span, the four active local cubics (evaluated with Horner in
    /// the span-local coordinate) sum to 1.
    #[test]
    fn cox_de_boor_partition_of_unity_uniform() {
        let t = open_uniform_knots(8);
        for k in DEGREE..(t.len() - DEGREE - 1) {
            let width = t[k + 1] - t[k];
            if !span_is_active(width) {
                continue;
            }
            let coeffs = cubic_basis_local_coeffs(&t, k);
            for step in 0..=4 {
                let u = step as f64 * width / 4.0;
                let mut sum = 0.0;
                for a in 0..ACTIVE_PER_SPAN {
                    let c = &coeffs[a];
                    let mut p = c[3];
                    p = p * u + c[2];
                    p = p * u + c[1];
                    p = p * u + c[0];
                    sum += p;
                }
                assert_close!(
                    &format!("partition span={k} step={step}"),
                    sum,
                    1.0,
                    1e-13,
                    1e-13,
                );
            }
        }
    }

    /// Closed-form span-local moments agree with Gauss–Legendre quadrature about the span's left edge.
    #[test]
    fn one_d_closed_form_matches_gauss_legendre_nonuniform() {
        let t = nonuniform_knots();
        let tables = AxisCubicMomentTables::build(&t, 0, 0);
        for span in 0..tables.n_spans() {
            let width = tables.width[span];
            let left = tables.left[span];
            for pair in 0..PAIRS_PER_SPAN {
                let c = tables.prod(span, pair);
                for nu in 0..=4usize {
                    let closed = moment_1d_local(c, width, nu);
                    let gl = moment_1d_gauss_legendre(c, left, width, nu, left);
                    assert_close!(
                        &format!("span={span} pair={pair} nu={nu}"),
                        closed,
                        gl,
                        1e-13,
                        1e-14,
                    );
                }
            }
        }
    }

    /// Closed-form moments about shifted centers (inside and outside the span) agree with quadrature.
    #[test]
    fn one_d_closed_form_shifted_moments_match_gauss_legendre() {
        let t = nonuniform_knots();
        let tables = AxisCubicMomentTables::build(&t, 0, 0);
        for span in 0..tables.n_spans() {
            let width = tables.width[span];
            let left = tables.left[span];
            for pair in 0..PAIRS_PER_SPAN {
                let c = tables.prod(span, pair);
                for nu in 0..=3usize {
                    for &m in &[
                        left - 0.3,
                        left + 0.1,
                        left + 0.5 * width,
                        left + width + 0.2,
                    ] {
                        let closed = moment_1d_about(c, width, nu, m - left);
                        let gl = moment_1d_gauss_legendre(c, left, width, nu, m);
                        assert_close!(
                            &format!("span={span} pair={pair} nu={nu} m={m}"),
                            closed,
                            gl,
                            1e-12,
                            1e-13,
                        );
                    }
                }
            }
        }
    }

    /// Summing the zeroth moment over all ordered pairs integrates (sum of B)^2 = 1 over the span.
    #[test]
    fn partition_of_unity_zeroth_moment_equals_span_width() {
        let t = nonuniform_knots();
        let tables = AxisCubicMomentTables::build(&t, 0, 0);
        for span in 0..tables.n_spans() {
            let width = tables.width[span];
            let mut sum = 0.0;
            for a in 0..ACTIVE_PER_SPAN {
                for b in 0..ACTIVE_PER_SPAN {
                    let m = tables.moment_local(span, active_pair_index(a, b), 0);
                    sum += m;
                }
            }
            assert_close!(&format!("partition span={span}"), sum, width, 1e-13, 1e-14);
        }
    }

    /// A 2-D tensor-product moment equals the product of the two 1-D marginal moments.
    #[test]
    fn tensor_separability_2d() {
        let t = nonuniform_knots();
        let table_x = AxisCubicMomentTables::build(&t, 0, 0);
        let table_y = AxisCubicMomentTables::build(&t, 0, 0);
        let axes: Vec<&AxisCubicMomentTables> = vec![&table_x, &table_y];
        for sx in 0..table_x.n_spans() {
            for sy in 0..table_y.n_spans() {
                for pa in [0usize, 4, 9] {
                    for pb in [0usize, 3, 7] {
                        for alpha in &[[0u8, 0u8], [1, 0], [0, 1], [2, 1], [3, 3]] {
                            let m_tensor =
                                tensor_hex_moment_cpu(&axes, &[sx, sy], alpha, &[pa, pb]);
                            let m_marginal = table_x.moment_local(sx, pa, alpha[0] as usize)
                                * table_y.moment_local(sy, pb, alpha[1] as usize);
                            assert_close!(
                                &format!("tensor sx={sx} sy={sy} pa={pa} pb={pb}"),
                                m_tensor,
                                m_marginal,
                                1e-14,
                                1e-15,
                            );
                        }
                    }
                }
            }
        }
    }

    /// `active_pair_index` is unordered, so (a, b) and (b, a) read bit-identical moments.
    #[test]
    fn symmetry_pair_swap_gives_same_moment() {
        let t = nonuniform_knots();
        let tables = AxisCubicMomentTables::build(&t, 0, 0);
        for span in 0..tables.n_spans() {
            for a in 0..ACTIVE_PER_SPAN {
                for b in 0..ACTIVE_PER_SPAN {
                    for nu in 0..=3usize {
                        let m_ab = tables.moment_local(span, active_pair_index(a, b), nu);
                        let m_ba = tables.moment_local(span, active_pair_index(b, a), nu);
                        assert_eq!(
                            m_ab.to_bits(),
                            m_ba.to_bits(),
                            "span={span} pair=({a},{b}) nu={nu}: pair index must be unordered"
                        );
                    }
                }
            }
        }
    }

    /// Tables built with first derivatives on both factors match quadrature of the product of
    /// the differentiated local basis polynomials.
    #[test]
    fn derivative_moment_matches_gauss_legendre() {
        let t = nonuniform_knots();
        let tables = AxisCubicMomentTables::build(&t, 1, 1);
        for span in 0..tables.n_spans() {
            let left = tables.left[span];
            let width = tables.width[span];
            let k = tables.span_indices[span];
            let basis = cubic_basis_local_coeffs(&t, k);
            for a in 0..ACTIVE_PER_SPAN {
                for b in a..ACTIVE_PER_SPAN {
                    let pair = active_pair_index(a, b);
                    let da = differentiate_basis_coeffs(basis[a]);
                    let db = differentiate_basis_coeffs(basis[b]);
                    let prod = convolve_basis_pair(da, db);
                    for nu in 0..=2usize {
                        let closed = tables.moment_local(span, pair, nu);
                        let reference = moment_1d_gauss_legendre(prod, left, width, nu, left);
                        assert_close!(
                            &format!("d/dx span={span} pair=({a},{b}) nu={nu}"),
                            closed,
                            reference,
                            1e-13,
                            1e-14,
                        );
                    }
                }
            }
        }
    }

    /// Device hex tensor moments match the CPU reference; skipped when no CUDA runtime exists.
    #[cfg(target_os = "linux")]
    #[test]
    fn gpu_hex_tensor_moments_match_cpu_reference() {
        let t = nonuniform_knots();
        let table = AxisCubicMomentTables::build(&t, 0, 0);
        let axes_cpu: Vec<&AxisCubicMomentTables> = vec![&table, &table];
        let axes_for_build: Vec<Vec<AxisCubicMomentTables>> =
            vec![vec![table.clone()], vec![table.clone()]];
        let alphas: Vec<Vec<u8>> =
            vec![vec![0, 0], vec![1, 0], vec![0, 1], vec![2, 1], vec![3, 3]];
        let deriv = vec![vec![0u8, 0u8]; alphas.len()];
        let spec = CubicMomentSpec {
            alphas: alphas.clone(),
            derivative_left: deriv.clone(),
            derivative_right: deriv,
            layout: MomentLayout::AlphaMajor,
        };
        let pair_choices: [usize; 3] = [0, 4, 9];
        let mut span_per_axis: Vec<i32> = Vec::new();
        let mut pair_per_axis: Vec<i32> = Vec::new();
        let mut width_per_axis: Vec<f64> = Vec::new();
        let mut cell_meta: Vec<(usize, usize, usize, usize)> = Vec::new();
        for sx in 0..table.n_spans() {
            for sy in 0..table.n_spans() {
                for &pa in &pair_choices {
                    for &pb in &pair_choices {
                        span_per_axis.push(sx as i32);
                        span_per_axis.push(sy as i32);
                        pair_per_axis.push(pa as i32);
                        pair_per_axis.push(pb as i32);
                        width_per_axis.push(table.width[sx]);
                        width_per_axis.push(table.width[sy]);
                        cell_meta.push((sx, sy, pa, pb));
                    }
                }
            }
        }
        let n_cells = cell_meta.len();
        let cells = HexCellTable {
            span_per_axis,
            pair_per_axis,
            width_per_axis,
            n_cells,
            d: 2,
        };
        let dev = match super::build_hex_tensor_moments_device(&spec, &axes_for_build, &cells) {
            Ok(d) => d,
            Err(err) => {
                eprintln!("skipping GPU parity test (no CUDA runtime): {err}");
                assert!(
                    is_cuda_unavailable(&err),
                    "unexpected GPU error variant: {err:?}"
                );
                return;
            }
        };
        let stream = CubicMomentBackend::probe()
            .expect("backend probe ok after a successful build")
            .inner
            .stream
            .clone();
        let host_vals = stream
            .clone_dtoh(&dev.values)
            .expect("dtoh of device moments");
        let out_stride = host_vals.len() / spec.n_alpha();
        // Pin the layout before indexing with the inferred stride, so a layout bug reports
        // itself directly rather than as a confusing value mismatch.
        assert!(
            out_stride >= n_cells,
            "out_stride={out_stride} < n_cells={n_cells}"
        );
        assert_eq!(
            out_stride,
            ((n_cells + 31) / 32) * 32,
            "alpha-major stride must be 32-aligned n_cells"
        );
        assert_eq!(
            host_vals.len(),
            out_stride * spec.n_alpha(),
            "alpha-major total = stride * n_alpha"
        );
        for (a_idx, alpha) in alphas.iter().enumerate() {
            for (cell, &(sx, sy, pa, pb)) in cell_meta.iter().enumerate() {
                let expected = tensor_hex_moment_cpu(&axes_cpu, &[sx, sy], alpha, &[pa, pb]);
                let got = host_vals[a_idx * out_stride + cell];
                assert_close!(
                    &format!("gpu cell={cell} alpha={alpha:?}"),
                    got,
                    expected,
                    1e-12,
                    1e-13,
                );
            }
        }
    }

    /// Building twice with an identical spec reuses the cached module (no new cache entry).
    #[cfg(target_os = "linux")]
    #[test]
    fn hex_tensor_module_cache_hits_on_repeat_spec() {
        let t = nonuniform_knots();
        let table = AxisCubicMomentTables::build(&t, 0, 0);
        let axes_for_build: Vec<Vec<AxisCubicMomentTables>> =
            vec![vec![table.clone()], vec![table.clone()]];
        let alphas: Vec<Vec<u8>> = vec![vec![0, 0], vec![1, 0], vec![2, 1]];
        let deriv = vec![vec![0u8, 0u8]; alphas.len()];
        let spec = CubicMomentSpec {
            alphas,
            derivative_left: deriv.clone(),
            derivative_right: deriv,
            layout: MomentLayout::AlphaMajor,
        };
        let cells = HexCellTable {
            span_per_axis: vec![0, 0],
            pair_per_axis: vec![0, 0],
            width_per_axis: vec![table.width[0], table.width[0]],
            n_cells: 1,
            d: 2,
        };
        let first = match super::build_hex_tensor_moments_device(&spec, &axes_for_build, &cells) {
            Ok(d) => d,
            Err(err) => {
                eprintln!("skipping module-cache test (no CUDA runtime): {err}");
                assert!(
                    is_cuda_unavailable(&err),
                    "unexpected GPU error variant: {err:?}"
                );
                return;
            }
        };
        let backend = CubicMomentBackend::probe().expect("backend probe");
        let cache_len_after_first = {
            let g = backend.inner.modules.lock().expect("cache lock");
            g.len()
        };
        assert!(
            cache_len_after_first >= 1,
            "module cache must hold ≥1 entry after first build"
        );
        let second = super::build_hex_tensor_moments_device(&spec, &axes_for_build, &cells)
            .expect("second build with identical spec must succeed");
        assert_eq!(
            second.n_alpha, first.n_alpha,
            "cache hit must yield the same n_alpha as the first build"
        );
        assert_eq!(
            second.n_cells, first.n_cells,
            "cache hit must yield the same n_cells as the first build"
        );
        let cache_len_after_second = {
            let g = backend.inner.modules.lock().expect("cache lock");
            g.len()
        };
        assert_eq!(
            cache_len_after_first, cache_len_after_second,
            "identical spec must hit the cache (no new module compiled)"
        );
    }

    /// The generated hex kernel source bakes in D / AMAX / NALPHA and the α table rows.
    #[test]
    #[cfg(target_os = "linux")]
    fn hex_tensor_kernel_source_contains_required_symbols() {
        let alphas = vec![vec![0u8, 0u8], vec![1, 0], vec![0, 1], vec![2, 1]];
        let src = super::build_hex_tensor_kernel_source(2, 2, &alphas);
        assert!(
            src.contains("#define D 2"),
            "D macro missing in:\n{src}"
        );
        assert!(
            src.contains("#define AMAX 2"),
            "AMAX macro missing in:\n{src}"
        );
        assert!(
            src.contains("#define NALPHA 4"),
            "NALPHA macro missing in:\n{src}"
        );
        assert!(
            src.contains("cubic_hex_tensor_moments"),
            "kernel entry-point name missing"
        );
        assert!(
            src.contains("ALPHA_TABLE[NALPHA][D]"),
            "constant alpha table missing"
        );
        assert!(src.contains("{ 2, 1 }"), "alpha row (2,1) missing");
    }

    /// Equal α tables hash equally; changing one entry changes the hash.
    #[test]
    #[cfg(target_os = "linux")]
    fn alpha_table_hash_is_stable_and_sensitive() {
        let a = vec![vec![0u8, 0u8], vec![1, 0], vec![0, 1]];
        let b = vec![vec![0u8, 0u8], vec![1, 0], vec![0, 1]];
        let c = vec![vec![0u8, 0u8], vec![1, 0], vec![0, 2]];
        assert_eq!(super::hash_alpha_table(&a), super::hash_alpha_table(&b));
        assert_ne!(super::hash_alpha_table(&a), super::hash_alpha_table(&c));
    }

    /// On the unit reference simplex, CPU tet moments equal the closed-form Dirichlet integrals.
    #[test]
    fn tetrahedral_geom_moment_cpu_matches_dirichlet_unit_simplex() {
        let v0 = [0.0, 0.0, 0.0];
        let v1 = [1.0, 0.0, 0.0];
        let v2 = [0.0, 1.0, 0.0];
        let v3 = [0.0, 0.0, 1.0];
        let mut verts = Vec::new();
        verts.extend_from_slice(&v0);
        verts.extend_from_slice(&v1);
        verts.extend_from_slice(&v2);
        verts.extend_from_slice(&v3);
        let c0 = [0.0f64, 0.0, 0.0];
        for beta in &[
            [0u8, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [2, 0, 0],
            [1, 1, 0],
            [1, 0, 1],
            [0, 1, 1],
            [2, 1, 0],
            [1, 1, 1],
            [2, 2, 1],
        ] {
            let got = super::tetrahedral_geom_moment_cpu(&verts, &c0, beta, 3);
            let want =
                super::dirichlet_ref_simplex(beta[0] as u32, beta[1] as u32, beta[2] as u32);
            assert_close!(&format!("dirichlet β={:?}", beta), got, want, 1e-14, 1e-15,);
        }
    }

    /// On a general tet, CPU moments match an 8^3 Gauss–Legendre rule mapped through the
    /// collapsed (Duffy) coordinates, scaled by |det| of the edge matrix.
    #[test]
    fn tetrahedral_geom_moment_cpu_matches_quadrature_general_tet() {
        let v0 = [0.3f64, -0.2, 0.7];
        let v1 = [1.1, 0.4, 0.6];
        let v2 = [0.5, 0.9, 1.1];
        let v3 = [0.7, -0.1, 1.8];
        let mut verts = Vec::new();
        verts.extend_from_slice(&v0);
        verts.extend_from_slice(&v1);
        verts.extend_from_slice(&v2);
        verts.extend_from_slice(&v3);
        let c0 = [0.1f64, 0.05, 0.2];
        // 8-point Gauss–Legendre nodes and weights mapped to [0, 1].
        const GL8_X01: [f64; 8] = [
            0.019_855_071_751_231_88,
            0.101_666_761_293_186_63,
            0.237_233_795_041_835_50,
            0.408_282_678_752_175_10,
            0.591_717_321_247_824_90,
            0.762_766_204_958_164_50,
            0.898_333_238_706_813_30,
            0.980_144_928_248_768_10,
        ];
        const GL8_W01: [f64; 8] = [
            0.050_614_268_145_188_18,
            0.111_190_517_226_687_24,
            0.156_853_322_938_943_55,
            0.181_341_891_689_180_92,
            0.181_341_891_689_180_92,
            0.156_853_322_938_943_55,
            0.111_190_517_226_687_24,
            0.050_614_268_145_188_18,
        ];
        for beta in &[
            [0u8, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [2, 1, 0],
            [1, 1, 1],
            [3, 0, 0],
        ] {
            let got = super::tetrahedral_geom_moment_cpu(&verts, &c0, beta, 3);
            let mut ref_acc = 0.0f64;
            for ix in 0..8 {
                for iy in 0..8 {
                    for iz in 0..8 {
                        let xi = GL8_X01[ix];
                        let et = GL8_X01[iy];
                        let ze = GL8_X01[iz];
                        let w = GL8_W01[ix] * GL8_W01[iy] * GL8_W01[iz];
                        let u1 = xi;
                        let u2 = (1.0 - xi) * et;
                        let u3 = (1.0 - xi) * (1.0 - et) * ze;
                        let jac = (1.0 - xi) * (1.0 - xi) * (1.0 - et);
                        let mut x = [0.0f64; 3];
                        for r in 0..3 {
                            x[r] = v0[r]
                                + u1 * (v1[r] - v0[r])
                                + u2 * (v2[r] - v0[r])
                                + u3 * (v3[r] - v0[r]);
                        }
                        let mut integrand = 1.0;
                        for r in 0..3 {
                            integrand *= (x[r] - c0[r]).powi(beta[r] as i32);
                        }
                        ref_acc += w * jac * integrand;
                    }
                }
            }
            let b = [
                [v1[0] - v0[0], v2[0] - v0[0], v3[0] - v0[0]],
                [v1[1] - v0[1], v2[1] - v0[1], v3[1] - v0[1]],
                [v1[2] - v0[2], v2[2] - v0[2], v3[2] - v0[2]],
            ];
            let det = (b[0][0] * (b[1][1] * b[2][2] - b[1][2] * b[2][1])
                - b[0][1] * (b[1][0] * b[2][2] - b[1][2] * b[2][0])
                + b[0][2] * (b[1][0] * b[2][1] - b[1][1] * b[2][0]))
                .abs();
            let want = ref_acc * det;
            assert_close!(&format!("tet β={:?}", beta), got, want, 1e-12, 1e-13,);
        }
    }

    /// `TetrahedralCellTable::validate` accepts a well-formed table and rejects a short
    /// vertex array or an out-of-range cell index.
    #[test]
    fn tetrahedral_cell_table_validate_catches_misshapen_input() {
        let good = super::TetrahedralCellTable {
            vertices: vec![0.0; 4 * 3],
            cell_index: vec![0],
            cell_centers: vec![0.0, 0.0, 0.0],
            n_tets: 1,
            n_cells: 1,
            d: 3,
        };
        assert!(good.validate().is_ok(), "well-formed table validates");
        let bad_verts = super::TetrahedralCellTable {
            vertices: vec![0.0; 4 * 3 - 1],
            ..good.clone()
        };
        assert!(bad_verts.validate().is_err(), "short vertex array rejected");
        let bad_idx = super::TetrahedralCellTable {
            cell_index: vec![3],
            ..good.clone()
        };
        assert!(
            bad_idx.validate().is_err(),
            "out-of-range cell index rejected"
        );
    }

    /// Both generated tetrahedral kernel sources expose their entry points and baked-in constants.
    #[test]
    #[cfg(target_os = "linux")]
    fn tetrahedral_kernel_sources_contain_required_symbols() {
        let betas = vec![vec![0u8, 0, 0], vec![1, 0, 0]];
        let geom_src = super::build_tet_geom_kernel_source(3, &betas);
        assert!(
            geom_src.contains("tetrahedral_geom_moments_kernel"),
            "geom kernel missing entry-point symbol"
        );
        assert!(
            geom_src.contains("BETA_TABLE"),
            "geom kernel missing baked-in β table"
        );
        let con_src = super::build_tet_contract_kernel_source(4, 3, 10);
        assert!(
            con_src.contains("tetrahedral_contract_kernel"),
            "contract kernel missing entry-point symbol"
        );
        assert!(
            con_src.contains("NALPHA 4") && con_src.contains("NBETA 3"),
            "contract kernel missing NALPHA / NBETA defines"
        );
    }

    /// The backend is compiled only on Linux.
    ///
    /// On Linux, a successful probe must stay successful, and a failed probe is reported as
    /// "no CUDA". Elsewhere, the probe must always fail.
    #[test]
    fn backend_compiled_flag_matches_platform() {
        assert_eq!(CubicMomentBackend::compiled(), cfg!(target_os = "linux"));
        let probe = CubicMomentBackend::probe();
        if cfg!(target_os = "linux") {
            match probe {
                Ok(_) => assert!(
                    CubicMomentBackend::probe().is_ok(),
                    "a probe that succeeded once must keep succeeding"
                ),
                Err(err) => eprintln!("no CUDA backend on this host: {err:?}"),
            }
        } else {
            assert!(probe.is_err(), "non-Linux probe must return Err");
        }
    }
}