Skip to main content

oxiphysics_gpu/compute/
cuda_backend.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! CUDA compute backend for the OxiPhysics GPU acceleration layer.
5//!
6//! This module provides [`CudaBackend`] which implements the compute-backend
7//! interface using NVIDIA CUDA via the [`cudarc`](https://crates.io/crates/cudarc)
8//! crate for type-safe PTX / CUDA kernel management.
9//!
10//! ## Feature flag
11//!
12//! This backend is gated behind the `cuda-backend` Cargo feature:
13//!
14//! ```toml
15//! [dependencies]
16//! oxiphysics-gpu = { features = ["cuda-backend"] }
17//! ```
18//!
//! When the feature is disabled, [`CudaBackend::try_new`] always returns
//! [`CudaInitError::FeatureNotEnabled`]; [`CudaBackend::new_stub`] provides a
//! CPU-shadow stand-in so callers and tests still compile and run.
21//!
22//! When the feature is enabled, cudarc uses dynamic-loading (`libloading`) so
23//! the crate compiles on any platform; the CUDA driver is opened at runtime and
24//! an error is returned if it is absent (e.g. macOS, headless Linux without an
25//! NVIDIA driver).
26//!
27//! ## Architecture
28//!
29//! ```text
30//!  CudaBackend
31//!   ├── cudarc::CudaContext                  ← CUDA device context (Arc)
32//!   ├── cudarc::CudaStream                   ← Default stream for kernel dispatch
33//!   ├── cudarc::CudaSlice<u8>                ← Device-resident buffer slices
34//!   ├── Vec<CudaBufferEntry>                 ← Registered buffer metadata
//!   └── HashMap<String, CudaFunction>        ← Loaded PTX / NVRTC kernels by name
//!
//!  Compute pipeline:
//!    write_buffer [host→device memcpy via stream]
//!    → launch / launch_with_scalars(name, buffers, scalars, grid, block)
//!    → read_buffer  [device→host memcpy via stream]
41//! ```
42//!
43//! ## Kernels shipped with this backend
44//!
45//! | Source constant | Description |
46//! |---|---|
//! | [`PTX_SPH_DENSITY`] | SPH density summation (poly6 kernel), 256 threads/block — reference source + stub PTX |
//! | [`PTX_PARALLEL_SCAN`] | Blelloch exclusive prefix scan (per-block, shared memory) — reference source + stub PTX |
//! | [`PTX_CONSTRAINT_PGS`] | Block-PGS constraint solver, 64 threads/block — reference source + stub PTX |
50//! | [`CUDA_SPH_DENSITY_SRC`] | CUDA C SPH density kernel (compiled at runtime via NVRTC) |
51//!
52//! ## Example (when `cuda-backend` feature enabled)
53//!
54//! ```ignore
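//! use oxiphysics_gpu::compute::cuda_backend::{CudaBackend, CUDA_SPH_DENSITY_SRC};
//!
//! let mut backend = CudaBackend::try_new(0)?;                    // device 0
//! backend.compile_and_register("sph_density_kernel", CUDA_SPH_DENSITY_SRC)?;
//!
//! let n = 1024;
//! let positions = backend.create_buffer(3 * n);                  // xyz per particle
//! let densities = backend.create_buffer(n);                      // one f64 per particle
//! backend.write_buffer(positions, &vec![0.0_f64; 3 * n]);
//! // The kernel takes scalars, so use launch_with_scalars
//! // (buffers, then i32 scalars, then f64 scalars):
//! backend.launch_with_scalars(
//!     "sph_density_kernel", &[positions, densities], &[n as i32], &[0.1, 1.0], 4, 256,
//! );                                                             // 4 blocks × 256 threads
//! let rho = backend.read_buffer(densities);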
62//! ```
63
64#![allow(dead_code)]
65
66// ── CudaBufferHandle ──────────────────────────────────────────────────────────
67
/// Copyable, index-based handle naming a buffer created by [`CudaBackend`].
///
/// The wrapped value is the buffer's position in the backend's allocation
/// order, so a handle is only meaningful for the backend that issued it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct CudaBufferHandle(pub usize);
71
72// ── CudaDeviceInfo ────────────────────────────────────────────────────────────
73
/// Information about the selected CUDA device.
///
/// All fields are plain data. [`Default`] yields an "empty" description: zero
/// ordinal, empty strings, zero sizes and `false` flags. The CPU stub reports
/// exactly that, except for its name.
#[derive(Debug, Clone, Default)]
pub struct CudaDeviceInfo {
    /// CUDA device ordinal (0-indexed).
    pub ordinal: u32,
    /// Device name as reported by `cuDeviceGetName`.
    pub name: String,
    /// Total global memory in bytes (`cuDeviceTotalMem`).
    pub total_mem_bytes: u64,
    /// Compute capability as `(major, minor)`.
    pub compute_capability: (u32, u32),
    /// Number of streaming multiprocessors (SMs).
    pub multiprocessor_count: u32,
    /// Maximum threads per block.
    pub max_threads_per_block: u32,
    /// Warp size (32 on all NVIDIA hardware to date).
    pub warp_size: u32,
    /// Whether the device supports unified (managed) memory
    /// (`CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY`; requires compute capability ≥ 3.0).
    pub supports_unified_memory: bool,
    /// Whether the device supports native FP64 arithmetic.
    ///
    /// The driver API has no dedicated "double support" attribute. Derive this
    /// from [`Self::compute_capability`]; FP64 is available from compute
    /// capability 1.3 onwards.
    pub supports_f64: bool,
    /// CUDA driver version string.
    pub driver_version: String,
}
98
99// ── CudaInitError ─────────────────────────────────────────────────────────────
100
/// Failure modes reported by [`CudaBackend::try_new`] and runtime kernel compilation.
#[derive(Clone, Debug)]
pub enum CudaInitError {
    /// Neither the CUDA driver nor the runtime could be found on this machine.
    NotAvailable,
    /// This build was compiled without the `cuda-backend` Cargo feature.
    FeatureNotEnabled,
    /// The driver reported no CUDA-capable device.
    NoDevice,
    /// The requested device ordinal does not name an existing device.
    DeviceOrdinalOutOfRange(u32),
    /// A cudarc / driver call failed; the payload is the formatted driver error.
    DeviceError(String),
    /// NVRTC rejected a kernel source; the payload is the formatted compiler error.
    CompilationError(String),
}

impl std::fmt::Display for CudaInitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CudaInitError::NotAvailable => f.write_str("CUDA is not available on this system"),
            CudaInitError::FeatureNotEnabled => {
                f.write_str("`cuda-backend` feature is not enabled")
            }
            CudaInitError::NoDevice => f.write_str("no CUDA-capable device found"),
            CudaInitError::DeviceOrdinalOutOfRange(ordinal) => {
                write!(f, "device ordinal {ordinal} is out of range")
            }
            CudaInitError::DeviceError(detail) => write!(f, "CUDA device error: {detail}"),
            CudaInitError::CompilationError(detail) => {
                write!(f, "NVRTC compile error: {detail}")
            }
        }
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` and similar.
impl std::error::Error for CudaInitError {}
132
133// ── PTX kernel sources (stub PTX for introspection / documentation) ────────────
134
/// Reference CUDA C source plus stub PTX for SPH density summation (poly6 kernel).
///
/// The commented CUDA C computes, for each particle `i`,
/// ρᵢ = Σⱼ mⱼ · 315/(64π h³) · (1 − r²ᵢⱼ/h²)³ over every particle `j` with
/// rᵢⱼ < h. This is the poly6 smoothing kernel. The mass mⱼ is read from `pos[j].w`.
///
/// Launch shape: ⌈N/256⌉ blocks × 256 threads, one particle per thread.
/// Neighbour positions are staged through a 256-entry `float4` shared-memory
/// tile (4 KiB per block). Out-of-range tile slots are padded with a far-away,
/// zero-mass sentinel so they never contribute.
///
/// NOTE(review): the reference source returns early for `i >= n` before the
/// tiled loop, which contains `__syncthreads()`. In a partially filled last
/// block that is a divergent barrier (undefined behaviour). A real
/// implementation must keep out-of-range threads loading tiles and reaching the
/// barriers, and skip only the final store. The `h_inv` parameter is unused.
///
/// The PTX section is a header-only stub with no `.entry`, so loading it cannot
/// yield a callable `sph_density` function. When the `cuda-backend` feature is
/// active, compile [`CUDA_SPH_DENSITY_SRC`] via NVRTC instead. This constant is
/// kept as reference documentation and for `register_kernel` calls in the stub
/// path.
pub const PTX_SPH_DENSITY: &str = r#"
// CUDA C source (compiled to PTX via nvcc -arch=sm_70 -ptx)
// extern "C" __global__ void sph_density(
//     const float4* __restrict__ pos,   // positions + mass in .w
//     float*        __restrict__ rho,   // output density
//     int                        n,     // particle count
//     float                      h,     // smoothing length
//     float                      h_inv  // 1/h
// ) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i >= n) return;
//
//     float xi = pos[i].x, yi = pos[i].y, zi = pos[i].z;
//     float density = 0.0f;
//     const float coeff = (315.0f / 64.0f) * __fdividef(1.0f, 3.14159265f * h*h*h);
//
//     // tile-based neighbour loop (shared memory)
//     __shared__ float4 tile[256];
//     for (int t = 0; t < (n + 255) / 256; t++) {
//         int j = t * 256 + threadIdx.x;
//         tile[threadIdx.x] = (j < n) ? pos[j] : make_float4(1e30f, 1e30f, 1e30f, 0.0f);
//         __syncthreads();
//         for (int k = 0; k < 256; k++) {
//             float dx = xi - tile[k].x, dy = yi - tile[k].y, dz = zi - tile[k].z;
//             float r2 = dx*dx + dy*dy + dz*dz;
//             float h2 = h * h;
//             if (r2 < h2) {
//                 float q = 1.0f - r2 * __fdividef(1.0f, h2);
//                 density += tile[k].w * coeff * q * q * q;
//             }
//         }
//         __syncthreads();
//     }
//     rho[i] = density;
// }
// --- actual PTX would be here ---
.version 7.0
.target sm_70
.address_size 64
// (stub — replace with actual nvcc-compiled PTX)
"#;
187
/// Reference CUDA C source plus stub PTX for a Blelloch exclusive prefix scan.
///
/// The commented CUDA C scans each thread block's slice independently. Every
/// thread loads one element (`gid = blockIdx.x * blockDim.x + tid`), so a block
/// of B threads scans B consecutive elements. The up-sweep and down-sweep run
/// entirely in shared memory; no warp-shuffle intrinsics are used.
///
/// Requirements and limitations of the reference source:
/// - `blockDim.x` must be a power of two, because the sweeps assume a complete
///   binary tree.
/// - `shmem` is `extern __shared__`, so the launch must request
///   `blockDim.x * sizeof(double)` bytes of dynamic shared memory.
/// - Results are per-block only. Scanning more than one block's worth of data
///   needs a second pass that adds each block's total to the blocks after it.
///
/// The PTX section is a header-only stub with no `.entry`.
pub const PTX_PARALLEL_SCAN: &str = r#"
// CUDA C source:
// extern "C" __global__ void exclusive_scan(
//     const double* __restrict__ in,
//     double*       __restrict__ out,
//     int                        n
// ) {
//     extern __shared__ double shmem[];
//     int tid = threadIdx.x;
//     int gid = blockIdx.x * blockDim.x + tid;
//
//     // Load into shared memory
//     shmem[tid] = (gid < n) ? in[gid] : 0.0;
//     __syncthreads();
//
//     // Blelloch up-sweep
//     for (int stride = 1; stride < blockDim.x; stride <<= 1) {
//         int idx = (tid + 1) * stride * 2 - 1;
//         if (idx < blockDim.x)
//             shmem[idx] += shmem[idx - stride];
//         __syncthreads();
//     }
//
//     // Set root to zero
//     if (tid == blockDim.x - 1) shmem[tid] = 0.0;
//     __syncthreads();
//
//     // Blelloch down-sweep
//     for (int stride = blockDim.x / 2; stride >= 1; stride >>= 1) {
//         int idx = (tid + 1) * stride * 2 - 1;
//         if (idx < blockDim.x) {
//             double t    = shmem[idx - stride];
//             shmem[idx - stride] = shmem[idx];
//             shmem[idx]  = shmem[idx] + t;
//         }
//         __syncthreads();
//     }
//
//     if (gid < n) out[gid] = shmem[tid];
// }
.version 7.0
.target sm_70
.address_size 64
// (stub — replace with actual nvcc-compiled PTX)
"#;
238
/// Reference CUDA C source plus stub PTX for one block-PGS constraint-solver iteration.
///
/// Launch shape: ⌈N/64⌉ blocks × 64 threads. In the commented CUDA C only
/// thread 0 of each block does any work: it walks the block's constraints (up
/// to `blockDim.x` of them) sequentially. Updates are therefore Gauss-Seidel
/// *within* a block and Jacobi-like *across* blocks. Each constraint's
/// accumulated impulse `lambda[ci]` is clamped to `[lambda_lo, lambda_hi]`
/// via `__saturatef`.
///
/// NOTE(review): the reference source is incomplete. The body-velocity
/// updates are placeholders, and `GpuConstraint` is not defined here. If
/// implemented with plain stores, blocks touching the same body would race on
/// `vel_lin` / `vel_ang`, so convergence is not guaranteed across blocks.
///
/// The PTX section is a header-only stub with no `.entry`.
pub const PTX_CONSTRAINT_PGS: &str = r#"
// CUDA C source:
// extern "C" __global__ void constraint_pgs_iter(
//     const GpuConstraint* __restrict__ constraints,
//     float* __restrict__              lambda,
//     float4* __restrict__             vel_lin,   // xyz=vel, w=inv_mass
//     float4* __restrict__             vel_ang,
//     int                              n,
//     float                            omega
// ) {
//     int base = blockIdx.x * blockDim.x;
//     if (threadIdx.x != 0) return;
//
//     for (int ci = base; ci < min(base + (int)blockDim.x, n); ci++) {
//         GpuConstraint c = constraints[ci];
//         float3 vla = make_float3(0), wla = make_float3(0); float inv_ma = 0;
//         float3 vlb = make_float3(0), wlb = make_float3(0); float inv_mb = 0;
//
//         if (c.body_a != 0xFFFFFFFF) {
//             float4 vl = vel_lin[c.body_a], vw = vel_ang[c.body_a];
//             vla = make_float3(vl); wla = make_float3(vw); inv_ma = vl.w;
//         }
//         if (c.body_b != 0xFFFFFFFF) {
//             float4 vl = vel_lin[c.body_b], vw = vel_ang[c.body_b];
//             vlb = make_float3(vl); wlb = make_float3(vw); inv_mb = vl.w;
//         }
//
//         float3 n3 = make_float3(c.nx, c.ny, c.nz);
//         float3 va  = vla + cross(wla, make_float3(c.rax, c.ray, c.raz));
//         float3 vb  = vlb + cross(wlb, make_float3(c.rbx, c.rby, c.rbz));
//         float  rv  = dot(n3, va - vb);
//         float  d   = -(rv + c.bias) * c.em * omega;
//         float  old = lambda[ci];
//         float  neo = __saturatef((old + d - c.lambda_lo) / (c.lambda_hi - c.lambda_lo))
//                      * (c.lambda_hi - c.lambda_lo) + c.lambda_lo;
//         lambda[ci] = neo;
//         float  dl  = neo - old;
//
//         float3 imp = n3 * dl;
//         if (c.body_a != 0xFFFFFFFF) { /* update vel_lin/ang[body_a] */ }
//         if (c.body_b != 0xFFFFFFFF) { /* update vel_lin/ang[body_b] */ }
//     }
// }
.version 7.0
.target sm_70
.address_size 64
// (stub — replace with actual nvcc-compiled PTX)
"#;
291
/// CUDA C source for the SPH density kernel `sph_density_kernel`, compiled at
/// runtime via NVRTC.
///
/// For each particle `i` it computes
/// ρᵢ = m · 315/(64π h³) · Σⱼ (1 − r²ᵢⱼ/h²)³ over all particles `j` with
/// rᵢⱼ < h. This is the poly6 smoothing kernel with a single uniform
/// `particle_mass` m. The sum includes the particle's own contribution (j = i).
///
/// One thread handles one particle, and each thread makes a brute-force pass
/// over all `n_particles`, so one launch costs O(N²). Dispatch
/// ⌈n_particles / 256⌉ blocks × 256 threads. Threads with
/// `i >= n_particles` exit immediately, so the grid may overshoot.
///
/// Parameter order is `positions`, `densities`, `n_particles`,
/// `smoothing_length`, `particle_mass`. This matches the
/// buffers → `i32` scalars → `f64` scalars convention of `launch_with_scalars`.
///
/// Positions are a flat interleaved `f64` array: `positions[3*i]` = x,
/// `positions[3*i+1]` = y, `positions[3*i+2]` = z. `densities` receives one
/// `f64` per particle.
pub const CUDA_SPH_DENSITY_SRC: &str = r#"
extern "C" __global__ void sph_density_kernel(
    const double* __restrict__ positions,
    double* __restrict__ densities,
    int n_particles,
    double smoothing_length,
    double particle_mass
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n_particles) return;
    double px = positions[3*i], py = positions[3*i+1], pz = positions[3*i+2];
    double rho = 0.0;
    double h2 = smoothing_length * smoothing_length;
    double coeff = 315.0 / (64.0 * 3.14159265358979 * smoothing_length
                            * smoothing_length * smoothing_length);
    for (int j = 0; j < n_particles; j++) {
        double dx = px - positions[3*j];
        double dy = py - positions[3*j+1];
        double dz = pz - positions[3*j+2];
        double r2 = dx*dx + dy*dy + dz*dz;
        if (r2 < h2) {
            double q = 1.0 - r2 / h2;
            rho += q * q * q;
        }
    }
    densities[i] = particle_mass * coeff * rho;
}
"#;
331
332// ── Internal buffer entry ──────────────────────────────────────────────────────
333
/// Bookkeeping record for one buffer owned by [`CudaBackend`].
#[derive(Clone, Debug)]
struct CudaBufferEntry {
    /// Capacity of the buffer, counted in `f64` elements.
    len: usize,
    /// Host-side copy of the contents. It is populated on the CPU stub path
    /// and left empty for buffers allocated on a real device.
    shadow: Vec<f64>,
    /// `true` when the buffer was requested as unified (managed) memory.
    unified: bool,
}
344
345// ── Real CUDA context (feature-gated) ─────────────────────────────────────────
346
#[cfg(feature = "cuda-backend")]
mod real_ctx {
    use super::CudaInitError;
    use std::collections::HashMap;
    use std::sync::Arc;

    use cudarc::driver::{CudaContext, CudaFunction, CudaModule, CudaSlice, CudaStream};

    /// Holds live cudarc objects for the active CUDA device context.
    pub(super) struct CudaRealContext {
        /// The CUDA device context.
        pub ctx: Arc<CudaContext>,
        /// Default stream used for all memory operations and kernel launches.
        pub stream: Arc<CudaStream>,
        /// Device-resident byte buffers. Index `i` must correspond to
        /// `CudaBackend::buffers[i]`, because buffer handles are plain indices.
        pub real_buffers: Vec<CudaSlice<u8>>,
        /// Loaded modules keyed by kernel name.
        pub modules: HashMap<String, Arc<CudaModule>>,
        /// Kernel functions keyed by name.
        pub functions: HashMap<String, CudaFunction>,
    }

    impl CudaRealContext {
        /// Initialise a CUDA device context for the given ordinal.
        ///
        /// `default_stream` is infallible in cudarc 0.19 — it simply wraps the
        /// null-pointer stream which always exists.
        pub fn new(ordinal: u32) -> Result<Self, CudaInitError> {
            let ctx = CudaContext::new(ordinal as usize)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
            // default_stream() returns Arc<CudaStream> directly (not Result).
            let stream = ctx.default_stream();
            Ok(Self {
                ctx,
                stream,
                real_buffers: Vec::new(),
                modules: HashMap::new(),
                functions: HashMap::new(),
            })
        }

        /// Allocate `len` zeroed bytes on the device and return the new buffer's index.
        pub fn alloc_bytes(&mut self, len: usize) -> Result<usize, CudaInitError> {
            let slice: CudaSlice<u8> = self
                .stream
                .alloc_zeros::<u8>(len)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
            let idx = self.real_buffers.len();
            self.real_buffers.push(slice);
            Ok(idx)
        }

        /// Upload `data` to buffer `idx`, copying the host's native-endian
        /// in-memory bytes.
        ///
        /// At most the buffer's capacity is copied; excess input is ignored.
        /// Shorter input leaves the rest of the buffer unchanged.
        pub fn write_f64_slice(&mut self, idx: usize, data: &[f64]) -> Result<(), CudaInitError> {
            // Reinterpret the f64 slice as bytes for the memcpy.
            let byte_len = std::mem::size_of_val(data);
            let byte_slice: &[u8] =
                // SAFETY: `u8` has alignment 1 and `byte_len` is exactly the size
                // of `data`'s storage, so the view stays in bounds. `f64` has no
                // padding or invalid bit patterns, the view borrows `data` (so it
                // cannot outlive it), and nothing is written through it.
                unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), byte_len) };

            let dst = self
                .real_buffers
                .get_mut(idx)
                .ok_or_else(|| CudaInitError::DeviceError("invalid buffer index".to_owned()))?;

            // Only copy as many bytes as fit in the allocated slice.
            let copy_len = byte_len.min(dst.len());
            if copy_len == 0 {
                return Ok(());
            }
            let src_trimmed = &byte_slice[..copy_len];

            // memcpy_htod requires dst.len() >= src.len(), so use a sub-view.
            let mut dst_view = dst
                .try_slice_mut(..copy_len)
                .ok_or_else(|| CudaInitError::DeviceError("slice view failed".to_owned()))?;

            self.stream
                .memcpy_htod(src_trimmed, &mut dst_view)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))
        }

        /// Download the full contents of buffer `idx` as `f64` values.
        pub fn read_f64_vec(&self, idx: usize) -> Result<Vec<f64>, CudaInitError> {
            let src = self
                .real_buffers
                .get(idx)
                .ok_or_else(|| CudaInitError::DeviceError("invalid buffer index".to_owned()))?;
            let bytes: Vec<u8> = self
                .stream
                .clone_dtoh(src)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
            Ok(bytes_to_f64_vec(bytes))
        }

        /// Load a raw `.ptx` string via `Ptx::from_src` and register the function `name`.
        pub fn register_ptx(&mut self, name: &str, ptx_src: &str) -> Result<(), CudaInitError> {
            use cudarc::nvrtc::Ptx;
            let ptx = Ptx::from_src(ptx_src);
            let module: Arc<CudaModule> = self
                .ctx
                .load_module(ptx)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
            let func: CudaFunction = module
                .load_function(name)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
            self.modules.insert(name.to_owned(), module);
            self.functions.insert(name.to_owned(), func);
            Ok(())
        }

        /// Compile CUDA C source via NVRTC and register the named kernel.
        pub fn compile_and_register(
            &mut self,
            name: &str,
            cuda_c_src: &str,
        ) -> Result<(), CudaInitError> {
            use cudarc::nvrtc::compile_ptx;
            let ptx = compile_ptx(cuda_c_src)
                .map_err(|e| CudaInitError::CompilationError(format!("{e:?}")))?;
            let module: Arc<CudaModule> = self
                .ctx
                .load_module(ptx)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
            let func: CudaFunction = module
                .load_function(name)
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
            self.modules.insert(name.to_owned(), module);
            self.functions.insert(name.to_owned(), func);
            Ok(())
        }

        /// Synchronise the default stream (block until all work completes).
        pub fn synchronize(&self) -> Result<(), CudaInitError> {
            self.stream
                .synchronize()
                .map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))
        }
    }

    /// Reassemble `f64` values from the raw bytes of a device buffer.
    ///
    /// The bytes are decoded in **native** byte order, matching
    /// `CudaRealContext::write_f64_slice`, which uploads the host's in-memory
    /// representation. Returns an empty vector if the length is not a multiple
    /// of 8.
    pub(super) fn bytes_to_f64_vec(bytes: Vec<u8>) -> Vec<f64> {
        const F64_BYTES: usize = std::mem::size_of::<f64>();
        if !bytes.len().is_multiple_of(F64_BYTES) {
            return Vec::new();
        }
        bytes
            .chunks_exact(F64_BYTES)
            .map(|chunk| {
                let mut raw = [0u8; F64_BYTES];
                raw.copy_from_slice(chunk);
                f64::from_ne_bytes(raw)
            })
            .collect()
    }
}
498
499// ── CudaBackend ───────────────────────────────────────────────────────────────
500
501/// CUDA compute backend.
502///
503/// **Without** `cuda-backend` feature: no-op stub; [`Self::try_new`] always returns
504/// [`CudaInitError::FeatureNotEnabled`].  All buffer and kernel methods operate
505/// on CPU shadows so unit tests compile and run on any platform.
506///
507/// **With** `cuda-backend` feature: real cudarc device context; [`Self::try_new`]
508/// calls `CudaContext::new(ordinal)` and returns an error if the CUDA driver is
509/// absent (e.g. on macOS or a Linux machine without an NVIDIA driver).  Buffer
510/// methods perform actual host↔device memcpy via the default stream.
511pub struct CudaBackend {
512    /// Device information (filled from driver attributes when a real context is active).
513    pub device_info: CudaDeviceInfo,
514    /// Whether a real CUDA device context is active.
515    available: bool,
516    /// CPU-side buffer shadows (used by the stub path; metadata only in real path).
517    buffers: Vec<CudaBufferEntry>,
518    /// Registered kernel names (stub path) or names of compiled functions (real path).
519    kernels: Vec<String>,
520    /// Live cudarc context — present only when `cuda-backend` feature is enabled
521    /// **and** device initialisation succeeded.
522    #[cfg(feature = "cuda-backend")]
523    real: Option<real_ctx::CudaRealContext>,
524}
525
526// ── Common constructor helpers ────────────────────────────────────────────────
527
528impl CudaBackend {
529    /// Attempt to create a CUDA backend on device `ordinal`.
530    ///
531    /// - **Without** `cuda-backend` feature: always returns
532    ///   `Err(CudaInitError::FeatureNotEnabled)`.
533    /// - **With** `cuda-backend` feature: calls `CudaContext::new(ordinal)`.
534    ///   Returns `Err(CudaInitError::DeviceError(...))` if the CUDA driver is
535    ///   absent or the ordinal is invalid.
536    pub fn try_new(ordinal: u32) -> Result<Self, CudaInitError> {
537        #[cfg(feature = "cuda-backend")]
538        {
539            Self::try_new_real(ordinal)
540        }
541        #[cfg(not(feature = "cuda-backend"))]
542        {
543            let _ = ordinal;
544            Err(CudaInitError::FeatureNotEnabled)
545        }
546    }
547
548    /// Create a CPU-fallback stub (useful for unit testing without a GPU).
549    pub fn new_stub() -> Self {
550        Self {
551            device_info: CudaDeviceInfo {
552                name: "CPU stub".into(),
553                ..Default::default()
554            },
555            available: false,
556            buffers: Vec::new(),
557            kernels: Vec::new(),
558            #[cfg(feature = "cuda-backend")]
559            real: None,
560        }
561    }
562
563    /// True if a real CUDA device context is active.
564    pub fn is_available(&self) -> bool {
565        self.available
566    }
567
568    /// Device information.
569    pub fn device_info(&self) -> &CudaDeviceInfo {
570        &self.device_info
571    }
572
573    // ── Buffer management ────────────────────────────────────────────────────
574
575    /// Allocate a device buffer that can hold `len` `f64` values.
576    ///
577    /// Real path: calls `CudaStream::alloc_zeros::<u8>(len * 8)` and stores
578    /// the returned `CudaSlice<u8>`.  Falls back to a CPU-shadow buffer when
579    /// no real context is active.
580    pub fn create_buffer(&mut self, len: usize) -> CudaBufferHandle {
581        let handle = CudaBufferHandle(self.buffers.len());
582
583        #[cfg(feature = "cuda-backend")]
584        if let Some(ctx) = self.real.as_mut() {
585            let byte_len = len * std::mem::size_of::<f64>();
586            // If real allocation fails, degrade gracefully to CPU shadow.
587            if ctx.alloc_bytes(byte_len).is_ok() {
588                self.buffers.push(CudaBufferEntry {
589                    len,
590                    shadow: Vec::new(), // no CPU shadow in real path
591                    unified: false,
592                });
593                return handle;
594            }
595        }
596
597        self.buffers.push(CudaBufferEntry {
598            len,
599            shadow: vec![0.0; len],
600            unified: false,
601        });
602        handle
603    }
604
605    /// Allocate a **unified memory** buffer (accessible from both CPU and GPU).
606    ///
607    /// In the current implementation unified memory is backed by the same
608    /// `CudaSlice<u8>` path as a regular buffer; true UM page migration would
609    /// require `UnifiedSlice` from cudarc which is gated on additional CUDA
610    /// driver capabilities.  Falls back to a CPU-shadow buffer in the stub.
611    pub fn alloc_unified(&mut self, len: usize) -> CudaBufferHandle {
612        let handle = CudaBufferHandle(self.buffers.len());
613
614        #[cfg(feature = "cuda-backend")]
615        if let Some(ctx) = self.real.as_mut() {
616            let byte_len = len * std::mem::size_of::<f64>();
617            if ctx.alloc_bytes(byte_len).is_ok() {
618                self.buffers.push(CudaBufferEntry {
619                    len,
620                    shadow: Vec::new(),
621                    unified: true,
622                });
623                return handle;
624            }
625        }
626
627        self.buffers.push(CudaBufferEntry {
628            len,
629            shadow: vec![0.0; len],
630            unified: true,
631        });
632        handle
633    }
634
635    /// Upload `data` to the device buffer at `handle`.
636    ///
637    /// Real path: `CudaStream::memcpy_htod` — synchronous on the default stream.
638    /// Stub path: copies into the CPU shadow.
639    pub fn write_buffer(&mut self, handle: CudaBufferHandle, data: &[f64]) {
640        #[cfg(feature = "cuda-backend")]
641        if let Some(ctx) = self.real.as_mut() {
642            // Attempt real memcpy; silently degrade on error.
643            let _ = ctx.write_f64_slice(handle.0, data);
644            return;
645        }
646
647        if let Some(entry) = self.buffers.get_mut(handle.0) {
648            let len = data.len().min(entry.len);
649            if entry.shadow.len() < len {
650                entry.shadow.resize(entry.len, 0.0);
651            }
652            entry.shadow[..len].copy_from_slice(&data[..len]);
653        }
654    }
655
656    /// Download data from the device buffer at `handle`.
657    ///
658    /// Real path: `CudaStream::clone_dtoh` — synchronous copy to a new `Vec<f64>`.
659    /// Stub path: returns a clone of the CPU shadow.
660    pub fn read_buffer(&self, handle: CudaBufferHandle) -> Vec<f64> {
661        #[cfg(feature = "cuda-backend")]
662        if let Some(ctx) = self.real.as_ref() {
663            return ctx.read_f64_vec(handle.0).unwrap_or_default();
664        }
665
666        self.buffers
667            .get(handle.0)
668            .map(|e| e.shadow.clone())
669            .unwrap_or_default()
670    }
671
672    // ── Kernel management ────────────────────────────────────────────────────
673
674    /// Register a PTX kernel source and associate it with `name`.
675    ///
676    /// Real path: loads the module via `CudaContext::load_module` and retrieves
677    /// the named function.  Stub path: records the name only.
678    pub fn register_kernel(&mut self, name: &str, ptx_source: &str) {
679        #[cfg(feature = "cuda-backend")]
680        if let Some(ctx) = self.real.as_mut() {
681            let _ = ctx.register_ptx(name, ptx_source);
682        }
683        // In stub path the ptx_source is intentionally not used (no NVRTC).
684        #[cfg(not(feature = "cuda-backend"))]
685        let _ = ptx_source;
686
687        if !self.kernels.contains(&name.to_owned()) {
688            self.kernels.push(name.to_string());
689        }
690    }
691
692    /// Compile a CUDA C kernel at runtime via NVRTC and register it.
693    ///
694    /// Real path: calls `cudarc::nvrtc::compile_ptx` then loads the module.
695    /// Stub path: records the name and returns `Ok(())`.
696    pub fn compile_and_register(
697        &mut self,
698        name: &str,
699        cuda_c_source: &str,
700    ) -> Result<(), CudaInitError> {
701        #[cfg(feature = "cuda-backend")]
702        if let Some(ctx) = self.real.as_mut() {
703            ctx.compile_and_register(name, cuda_c_source)?;
704            if !self.kernels.contains(&name.to_owned()) {
705                self.kernels.push(name.to_string());
706            }
707            return Ok(());
708        }
709
710        // Stub path: record name, suppress unused-var warnings
711        let _ = cuda_c_source;
712        if !self.kernels.contains(&name.to_owned()) {
713            self.kernels.push(name.to_string());
714        }
715        Ok(())
716    }
717
718    // ── Kernel launch ────────────────────────────────────────────────────────
719
720    /// Launch a registered kernel with buffer arguments only.
721    ///
722    /// # Parameters
723    ///
724    /// - `name` — kernel name as passed to [`Self::register_kernel`] or
725    ///   [`Self::compile_and_register`]
726    /// - `buffers` — buffer handles bound as kernel arguments (in order)
727    /// - `grid_x` — number of thread blocks in X dimension
728    /// - `block_x` — number of threads per block in X dimension
729    ///
730    /// For kernels that take scalar arguments (e.g. an integer particle count
731    /// or floating-point smoothing length), use [`Self::launch_with_scalars`]
732    /// instead — calling `launch` against a kernel whose signature includes
733    /// scalar parameters will pass uninitialised registers to those slots and
734    /// is undefined behaviour.
735    ///
736    /// Real path: retrieves the stored `CudaFunction` and dispatches via
737    /// `CudaStream::launch_builder`.  Currently up to two buffer arguments
738    /// are forwarded; extend as needed for higher-arity kernels.
739    ///
740    /// Stub path: no-op.
741    pub fn launch(&mut self, name: &str, buffers: &[CudaBufferHandle], grid_x: u32, block_x: u32) {
742        self.launch_with_scalars(name, buffers, &[], &[], grid_x, block_x);
743    }
744
745    /// Launch a registered kernel passing buffer **and** scalar arguments.
746    ///
747    /// Scalars are appended to the kernel argument list after the buffer
748    /// arguments in the order `i32` scalars then `f64` scalars; the kernel
749    /// signature must match that ordering exactly.
750    ///
751    /// # Parameters
752    ///
753    /// - `name` — kernel name as passed to [`Self::register_kernel`] or
754    ///   [`Self::compile_and_register`]
755    /// - `buffers` — buffer handles bound as the leading kernel arguments
756    /// - `scalars_i32` — `i32` scalars appended after the buffers
757    /// - `scalars_f64` — `f64` scalars appended after the `i32` scalars
758    /// - `grid_x` — number of thread blocks in X dimension
759    /// - `block_x` — number of threads per block in X dimension
760    ///
761    /// Stub path: no-op.
762    pub fn launch_with_scalars(
763        &mut self,
764        name: &str,
765        buffers: &[CudaBufferHandle],
766        scalars_i32: &[i32],
767        scalars_f64: &[f64],
768        grid_x: u32,
769        block_x: u32,
770    ) {
771        #[cfg(feature = "cuda-backend")]
772        if let Some(ctx) = self.real.as_mut() {
773            use cudarc::driver::{LaunchConfig, PushKernelArg};
774            let cfg = LaunchConfig {
775                grid_dim: (grid_x, 1, 1),
776                block_dim: (block_x, 1, 1),
777                shared_mem_bytes: 0,
778            };
779            let Some(func) = ctx.functions.get(name).cloned() else {
780                return;
781            };
782
783            // Current support: up to two buffer arguments.  Validate indices
784            // are in range and pairwise distinct (aliasing breaks the unsafe
785            // split below).
786            if buffers.len() > 2 {
787                return;
788            }
789            for (i, h) in buffers.iter().enumerate() {
790                if h.0 >= ctx.real_buffers.len() {
791                    return;
792                }
793                for h2 in &buffers[i + 1..] {
794                    if h.0 == h2.0 {
795                        return;
796                    }
797                }
798            }
799
800            // Materialise the buffer references first (raw-pointer split),
801            // then borrow ctx.stream immutably to build the launch.  The
802            // raw-pointer derived references and ctx.stream live in disjoint
803            // fields of ctx; the borrow checker cannot see this through the
804            // pointer cast, so we rely on the manual validation above.
805            let real_ptr = ctx.real_buffers.as_mut_ptr();
806            // SAFETY: indices validated above; lifetimes do not outlive this
807            // function and we do not call any &mut ctx.real_buffers method
808            // between here and `.launch(cfg)`.
809            let buf0 = buffers.first().map(|h| unsafe { &mut *real_ptr.add(h.0) });
810            let buf1 = buffers.get(1).map(|h| unsafe { &mut *real_ptr.add(h.0) });
811
812            let mut builder = ctx.stream.launch_builder(&func);
813            if let Some(b) = buf0 {
814                builder.arg(b);
815            }
816            if let Some(b) = buf1 {
817                builder.arg(b);
818            }
819            for v in scalars_i32 {
820                builder.arg(v);
821            }
822            for v in scalars_f64 {
823                builder.arg(v);
824            }
825            let _ = unsafe { builder.launch(cfg) };
826            return;
827        }
828        // Stub: no-op
829        let _ = (name, buffers, scalars_i32, scalars_f64, grid_x, block_x);
830    }
831
832    /// Synchronise the device (blocks until all submitted work completes).
833    ///
834    /// Real path: `CudaStream::synchronize()`.
835    /// Stub path: immediate return.
836    pub fn synchronize(&mut self) {
837        #[cfg(feature = "cuda-backend")]
838        if let Some(ctx) = self.real.as_ref() {
839            let _ = ctx.synchronize();
840        }
841    }
842
843    // ── Device query ─────────────────────────────────────────────────────────
844
845    /// Return the number of CUDA devices available on this system.
846    ///
847    /// Real path: calls `cudarc::driver::result::device::get_count()`.
848    /// Stub path: always returns `0`.
849    pub fn device_count() -> u32 {
850        #[cfg(feature = "cuda-backend")]
851        {
852            // cudarc panics on dlopen failure with dynamic-loading; catch it.
853            let count = std::panic::catch_unwind(|| {
854                cudarc::driver::result::init()
855                    .ok()
856                    .and_then(|()| cudarc::driver::result::device::get_count().ok())
857                    .map(|n| n as u32)
858                    .unwrap_or(0)
859            });
860            count.unwrap_or(0)
861        }
862        #[cfg(not(feature = "cuda-backend"))]
863        {
864            0
865        }
866    }
867
868    /// Query device attributes for device `ordinal` without creating a backend.
869    ///
870    /// Stub path: always returns `Err(CudaInitError::NotAvailable)`.
871    /// Real path: returns basic info derived from the driver (name, total mem, CC).
872    pub fn query_device_info(ordinal: u32) -> Result<CudaDeviceInfo, CudaInitError> {
873        #[cfg(feature = "cuda-backend")]
874        {
875            use cudarc::driver::result;
876            result::init().map_err(|e| CudaInitError::DeviceError(format!("{e:?}")))?;
877            let dev = result::device::get(ordinal as i32)
878                .map_err(|_| CudaInitError::DeviceOrdinalOutOfRange(ordinal))?;
879            let name = result::device::get_name(dev).unwrap_or_else(|_| "unknown".to_owned());
880            let total_mem = unsafe { result::device::total_mem(dev) }.unwrap_or(0);
881            Ok(CudaDeviceInfo {
882                ordinal,
883                name,
884                total_mem_bytes: total_mem as u64,
885                ..Default::default()
886            })
887        }
888        #[cfg(not(feature = "cuda-backend"))]
889        {
890            let _ = ordinal;
891            Err(CudaInitError::NotAvailable)
892        }
893    }
894}
895
896// ── Real-path constructor (feature-gated) ─────────────────────────────────────
897
#[cfg(feature = "cuda-backend")]
impl CudaBackend {
    /// Initialise a real CUDA backend on device `ordinal` using cudarc 0.19.
    ///
    /// Called by [`Self::try_new`] when the `cuda-backend` feature is active.
    /// The device name and total memory are queried before the context is
    /// created.  A failed name query yields `"unknown"` and a failed memory
    /// query yields `0`.  All other [`CudaDeviceInfo`] fields keep their
    /// `Default` values.
    ///
    /// # Errors
    ///
    /// - [`CudaInitError::NotAvailable`] if cudarc panics while loading the
    ///   CUDA driver library.
    /// - [`CudaInitError::DeviceError`] if driver initialisation fails.
    /// - [`CudaInitError::DeviceOrdinalOutOfRange`] if `ordinal` does not name
    ///   a device, including ordinals too large for the driver's `i32`.
    /// - Any error returned by `real_ctx::CudaRealContext::new`.
    fn try_new_real(ordinal: u32) -> Result<Self, CudaInitError> {
        use cudarc::driver::result;

        // cudarc with `dynamic-loading` **panics** at the dlopen stage when no
        // CUDA shared library is found on the system (e.g. on macOS or a
        // machine without an NVIDIA driver).  Catch that panic and convert it
        // into a clean `Err(...)` so callers can handle it without
        // unwinding the test process.
        match std::panic::catch_unwind(result::init) {
            Ok(Ok(())) => {}
            Ok(Err(e)) => return Err(CudaInitError::DeviceError(format!("{e:?}"))),
            // cudarc panicked during dlopen — CUDA driver not present.
            Err(_payload) => return Err(CudaInitError::NotAvailable),
        }

        // `result::device::get` takes an `i32` ordinal; a `u32` beyond that
        // range cannot name a device, so report it like any missing device.
        let dev = i32::try_from(ordinal)
            .ok()
            .and_then(|ord| result::device::get(ord).ok())
            .ok_or(CudaInitError::DeviceOrdinalOutOfRange(ordinal))?;

        // Query basic device info before acquiring the context.
        let name = result::device::get_name(dev).unwrap_or_else(|_| "unknown".to_owned());
        // SAFETY: `dev` was returned by `result::device::get`, fulfilling the contract.
        let total_mem = unsafe { result::device::total_mem(dev) }.unwrap_or(0);

        let real = real_ctx::CudaRealContext::new(ordinal)?;

        Ok(Self {
            device_info: CudaDeviceInfo {
                ordinal,
                name,
                total_mem_bytes: total_mem as u64,
                ..Default::default()
            },
            available: true,
            buffers: Vec::new(),
            kernels: Vec::new(),
            real: Some(real),
        })
    }
}
947
948// ── Debug impl ────────────────────────────────────────────────────────────────
949
950impl std::fmt::Debug for CudaBackend {
951    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
952        f.debug_struct("CudaBackend")
953            .field("device", &self.device_info.name)
954            .field("available", &self.available)
955            .field("buffers", &self.buffers.len())
956            .field("kernels", &self.kernels.len())
957            .finish()
958    }
959}
960
961// ── tests ─────────────────────────────────────────────────────────────────────
962
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `try_new(0)` must not panic, and its outcome must match the build and
    /// the host:
    ///
    /// * feature disabled → `Err(CudaInitError::FeatureNotEnabled)`;
    /// * feature enabled, no CUDA driver / device → any `Err` is acceptable;
    /// * feature enabled, working device → `Ok`, and the backend reports
    ///   itself as available.
    #[test]
    fn test_try_new_behaviour() {
        let outcome = CudaBackend::try_new(0);
        #[cfg(not(feature = "cuda-backend"))]
        {
            assert!(matches!(outcome, Err(CudaInitError::FeatureNotEnabled)));
        }
        #[cfg(feature = "cuda-backend")]
        {
            if let Ok(backend) = outcome {
                assert!(backend.is_available());
            }
        }
    }

    /// Data written to a stub buffer reads back unchanged.
    #[test]
    fn test_stub_backend_buffer_roundtrip() {
        let expected: Vec<f64> = (1..=8).map(|k| k as f64).collect();
        let mut backend = CudaBackend::new_stub();
        let handle = backend.create_buffer(8);
        backend.write_buffer(handle, &expected);
        assert_eq!(backend.read_buffer(handle), expected);
    }

    /// Registering one kernel on the stub records exactly that name.
    #[test]
    fn test_stub_kernel_registration() {
        let mut backend = CudaBackend::new_stub();
        backend.register_kernel("sph_density", PTX_SPH_DENSITY);
        assert_eq!(backend.kernels.len(), 1);
        assert_eq!(backend.kernels[0], "sph_density");
    }

    /// Unified allocations round-trip data and are flagged as unified.
    #[test]
    fn test_stub_unified_alloc() {
        let mut backend = CudaBackend::new_stub();
        let handle = backend.alloc_unified(16);
        backend.write_buffer(handle, &[PI; 16]);
        let first = backend.read_buffer(handle)[0];
        assert!((first - PI).abs() < 1e-10);
        // The buffer entry must be marked as unified memory.
        assert!(backend.buffers[handle.0].unified);
    }

    /// `device_count` never panics.  Without the feature it is always 0; with
    /// the feature it reflects the host and any value is acceptable.
    #[test]
    fn test_device_count_environment_consistent() {
        let count = CudaBackend::device_count();
        if cfg!(not(feature = "cuda-backend")) {
            assert_eq!(count, 0);
        }
    }

    /// Compiling and registering on the stub succeeds and records the name.
    #[test]
    fn test_compile_and_register() {
        let mut backend = CudaBackend::new_stub();
        let compiled = backend.compile_and_register("scan", PTX_PARALLEL_SCAN);
        assert!(compiled.is_ok());
        assert_eq!(backend.kernels[0], "scan");
    }

    /// The `Display` text of a compilation error includes its message.
    #[test]
    fn test_error_display() {
        let err = CudaInitError::CompilationError("undefined symbol 'foo'".into());
        let rendered = err.to_string();
        assert!(rendered.contains("foo"));
    }

    /// The embedded CUDA source is present and defines the density kernel.
    #[test]
    fn test_cuda_sph_density_src_not_empty() {
        let src = CUDA_SPH_DENSITY_SRC;
        assert!(!src.is_empty());
        assert!(src.contains("sph_density_kernel"));
    }

    /// `try_new(0)` must not panic under any feature set or hardware.
    #[test]
    fn test_try_new_no_panic() {
        drop(CudaBackend::try_new(0));
    }
}