// ringkernel_cuda/lib.rs

1//! CUDA Backend for RingKernel
2//!
3//! This crate provides NVIDIA CUDA GPU support for RingKernel using cudarc.
4//!
5//! # Features
6//!
7//! - Persistent kernel execution (cooperative groups)
8//! - Lock-free message queues in GPU global memory
9//! - PTX compilation via NVRTC
10//! - Multi-GPU support
11//!
12//! # Requirements
13//!
14//! - NVIDIA GPU with Compute Capability 7.0+
15//! - CUDA Toolkit 11.0+
16//! - Native Linux (persistent kernels) or WSL2 (event-driven fallback)
17//!
18//! # Example
19//!
20//! ```ignore
21//! use ringkernel_cuda::CudaRuntime;
22//! use ringkernel_core::runtime::RingKernelRuntime;
23//!
24//! #[tokio::main]
25//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
26//!     let runtime = CudaRuntime::new().await?;
27//!     let kernel = runtime.launch("vector_add", Default::default()).await?;
28//!     kernel.activate().await?;
29//!     Ok(())
30//! }
31//! ```
32
33#![warn(missing_docs)]
34
35#[cfg(feature = "cooperative")]
36pub mod cooperative;
37#[cfg(feature = "cuda")]
38mod device;
39#[cfg(feature = "cuda")]
40pub mod driver_api;
41#[cfg(feature = "cuda")]
42pub mod k2k_gpu;
43#[cfg(feature = "cuda")]
44mod kernel;
45#[cfg(feature = "cuda")]
46mod memory;
47#[cfg(feature = "cuda")]
48pub mod persistent;
49#[cfg(feature = "cuda")]
50pub mod phases;
51#[cfg(feature = "profiling")]
52pub mod profiling;
53#[cfg(feature = "cuda")]
54pub mod reduction;
55#[cfg(feature = "cuda")]
56mod runtime;
57#[cfg(feature = "cuda")]
58mod stencil;
59
60#[cfg(feature = "cuda")]
61pub use device::CudaDevice;
62#[cfg(feature = "cuda")]
63pub use kernel::CudaKernel;
64#[cfg(feature = "cuda")]
65pub use memory::{CudaBuffer, CudaControlBlock, CudaMemoryPool, CudaMessageQueue};
66#[cfg(feature = "cuda")]
67pub use persistent::CudaMappedBuffer;
68#[cfg(feature = "cuda")]
69pub use phases::{
70    InterPhaseReduction, KernelPhase, MultiPhaseConfig, MultiPhaseExecutor, PhaseExecutionStats,
71    SyncMode,
72};
73#[cfg(feature = "cuda")]
74pub use reduction::{
75    generate_block_reduce_code, generate_grid_reduce_code, generate_reduce_and_broadcast_code,
76    CacheKey, CacheStats, CachedReductionBuffer, ReductionBuffer, ReductionBufferBuilder,
77    ReductionBufferCache,
78};
79#[cfg(feature = "cuda")]
80pub use runtime::CudaRuntime;
81#[cfg(feature = "cuda")]
82pub use stencil::{CompiledStencilKernel, LaunchConfig, StencilKernelLoader};
83
84// Profiling re-exports
85#[cfg(feature = "profiling")]
86pub use profiling::{
87    CudaEvent, CudaEventFlags, CudaMemoryKind, CudaMemoryTracker, CudaNvtxProfiler,
88    GpuChromeTraceBuilder, GpuEventArgs, GpuTimer, GpuTimerPool, GpuTraceEvent, KernelMetrics,
89    ProfilingSession, TrackedAllocation, TransferDirection, TransferMetrics,
90};
91
/// Re-export memory module for advanced usage.
///
/// These are the same types as the crate-level re-exports; this namespaced
/// path is kept for callers that prefer an explicit `memory_exports::` prefix.
#[cfg(feature = "cuda")]
pub mod memory_exports {
    pub use super::memory::{CudaBuffer, CudaControlBlock, CudaMemoryPool, CudaMessageQueue};
}
97
98// Placeholder implementations when CUDA is not available
99#[cfg(not(feature = "cuda"))]
100mod stub {
101    use async_trait::async_trait;
102    use ringkernel_core::error::{Result, RingKernelError};
103    use ringkernel_core::runtime::{
104        Backend, KernelHandle, KernelId, LaunchOptions, RingKernelRuntime, RuntimeMetrics,
105    };
106
107    /// Stub CUDA runtime when CUDA feature is disabled.
108    pub struct CudaRuntime;
109
110    impl CudaRuntime {
111        /// Create fails when CUDA is not available.
112        pub async fn new() -> Result<Self> {
113            Err(RingKernelError::BackendUnavailable(
114                "CUDA feature not enabled".to_string(),
115            ))
116        }
117    }
118
119    #[async_trait]
120    impl RingKernelRuntime for CudaRuntime {
121        fn backend(&self) -> Backend {
122            Backend::Cuda
123        }
124
125        fn is_backend_available(&self, _backend: Backend) -> bool {
126            false
127        }
128
129        async fn launch(&self, _kernel_id: &str, _options: LaunchOptions) -> Result<KernelHandle> {
130            Err(RingKernelError::BackendUnavailable("CUDA".to_string()))
131        }
132
133        fn get_kernel(&self, _kernel_id: &KernelId) -> Option<KernelHandle> {
134            None
135        }
136
137        fn list_kernels(&self) -> Vec<KernelId> {
138            vec![]
139        }
140
141        fn metrics(&self) -> RuntimeMetrics {
142            RuntimeMetrics::default()
143        }
144
145        async fn shutdown(&self) -> Result<()> {
146            Ok(())
147        }
148    }
149}
150
151#[cfg(not(feature = "cuda"))]
152pub use stub::CudaRuntime;
153
/// Check if CUDA is available at runtime.
///
/// Returns `false` when any of the following hold:
/// - the `cuda` feature is not enabled at compile time,
/// - the CUDA libraries are not installed on the system,
/// - no CUDA devices are present.
///
/// cudarc can panic when the CUDA driver libraries are missing, so the
/// probe runs inside `catch_unwind` and treats a panic as "unavailable".
pub fn is_cuda_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        // Probe the driver for at least one device; any error counts as "no".
        let probe = || matches!(cudarc::driver::CudaContext::device_count(), Ok(n) if n > 0);
        std::panic::catch_unwind(probe).unwrap_or(false)
    }
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
178
/// Get CUDA device count.
///
/// Returns 0 if the `cuda` feature is disabled, the CUDA libraries are
/// not installed, or the driver reports an error.
///
/// cudarc can panic when the CUDA driver libraries are missing, so the
/// query runs inside `catch_unwind` and treats a panic as zero devices.
pub fn cuda_device_count() -> usize {
    #[cfg(feature = "cuda")]
    {
        let count_devices = || cudarc::driver::CudaContext::device_count().unwrap_or(0) as usize;
        std::panic::catch_unwind(count_devices).unwrap_or(0)
    }
    #[cfg(not(feature = "cuda"))]
    {
        0
    }
}
196
/// Compile CUDA C source code to PTX using NVRTC.
///
/// This wraps `cudarc::nvrtc::compile_ptx` to provide PTX compilation
/// without requiring downstream crates to depend on cudarc directly.
///
/// # Arguments
///
/// * `cuda_source` - CUDA C source code string
///
/// # Returns
///
/// PTX assembly as a string.
///
/// # Errors
///
/// Returns `RingKernelError::CompilationError` if NVRTC rejects the source.
///
/// # Example
///
/// ```ignore
/// use ringkernel_cuda::compile_ptx;
///
/// let cuda_source = r#"
///     extern "C" __global__ void add(float* a, float* b, float* c, int n) {
///         int i = blockIdx.x * blockDim.x + threadIdx.x;
///         if (i < n) c[i] = a[i] + b[i];
///     }
/// "#;
///
/// let ptx = compile_ptx(cuda_source)?;
/// ```
#[cfg(feature = "cuda")]
pub fn compile_ptx(cuda_source: &str) -> ringkernel_core::error::Result<String> {
    use ringkernel_core::error::RingKernelError;

    // Translate the NVRTC error into the crate's error type; on success,
    // render the compiled module back out as a PTX source string.
    match cudarc::nvrtc::compile_ptx(cuda_source) {
        Ok(ptx) => Ok(ptx.to_src().to_string()),
        Err(e) => Err(RingKernelError::CompilationError(format!(
            "NVRTC compilation failed: {}",
            e
        ))),
    }
}
234
235/// Stub compile_ptx when CUDA is not available.
236#[cfg(not(feature = "cuda"))]
237pub fn compile_ptx(_cuda_source: &str) -> ringkernel_core::error::Result<String> {
238    Err(ringkernel_core::error::RingKernelError::BackendUnavailable(
239        "CUDA feature not enabled".to_string(),
240    ))
241}
242
/// PTX kernel source template for persistent ring kernel.
///
/// This is a minimal kernel that immediately marks itself as terminated:
/// it loads the control-block pointer parameter, stores `1` into the 32-bit
/// field at byte offset 8 of the control block (per the string's own comment,
/// the termination flag), and returns. Useful as a smoke-test payload for the
/// launch/teardown path; the queue and shared-state parameters are accepted
/// but never read.
///
/// Uses PTX 8.0 / sm_89 for Ada Lovelace GPU compatibility (RTX 40xx series).
/// NOTE(review): targeting sm_89 means pre-Ada GPUs will reject this module,
/// while the crate docs state Compute Capability 7.0+ support — confirm the
/// intended minimum target.
pub const RING_KERNEL_PTX_TEMPLATE: &str = r#"
.version 8.0
.target sm_89
.address_size 64

.visible .entry ring_kernel_main(
    .param .u64 control_block_ptr,
    .param .u64 input_queue_ptr,
    .param .u64 output_queue_ptr,
    .param .u64 shared_state_ptr
) {
    .reg .u64 %cb_ptr;
    .reg .u32 %one;

    // Load control block pointer
    ld.param.u64 %cb_ptr, [control_block_ptr];

    // Mark as terminated immediately (offset 8)
    mov.u32 %one, 1;
    st.global.u32 [%cb_ptr + 8], %one;

    ret;
}
"#;