Skip to main content

oxicuda_runtime/
launch.rs

//! Kernel launch API.
//!
//! Implements `cudaLaunchKernel`, `cudaFuncGetAttributes`, and
//! `cudaFuncSetAttribute` on top of the CUDA Driver API, plus helpers for
//! loading and unloading PTX modules and looking up kernels inside them.
//!
//! # Design
//!
//! In the CUDA Runtime, kernels are typically invoked via `<<<...>>>` syntax,
//! which the NVCC compiler rewrites into `cudaLaunchKernel` calls.  Since
//! OxiCUDA never uses NVCC, callers must use the driver-level module/function
//! handle pair directly.  This module therefore exposes a slightly lower-level
//! surface that accepts a [`CudaFunction`] instead of a raw symbol pointer.
13
14use std::ffi::c_void;
15
16use oxicuda_driver::ffi::{CUfunction, CUmodule};
17use oxicuda_driver::loader::try_driver;
18
19use crate::error::{CudaRtError, CudaRtResult};
20use crate::stream::CudaStream;
21
22// ─── Re-exports ───────────────────────────────────────────────────────────────
23
24/// A compiled GPU kernel function (alias for the driver's `CUfunction`).
25pub type CudaFunction = CUfunction;
26
27/// A compiled GPU module (alias for the driver's `CUmodule`).
28pub type CudaModule = CUmodule;
29
30// ─── Dim3 ────────────────────────────────────────────────────────────────────
31
/// 3-D grid / block dimensions for kernel launches.
///
/// Mirrors CUDA's `dim3` struct. The constructors for fewer than three
/// dimensions (and the corresponding `From` impls) set the unused axes to `1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Dim3 {
    /// X dimension.
    pub x: u32,
    /// Y dimension.
    pub y: u32,
    /// Z dimension.
    pub z: u32,
}

impl Dim3 {
    /// Construct a 1-D dimension (`y = z = 1`).
    #[must_use]
    pub const fn one_d(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Construct a 2-D dimension (`z = 1`).
    #[must_use]
    pub const fn two_d(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Construct a full 3-D dimension.
    #[must_use]
    pub const fn three_d(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Total number of threads / blocks (`x * y * z`), or `None` if the product
    /// does not fit in a `u64`.
    #[must_use]
    pub const fn checked_volume(self) -> Option<u64> {
        // (2^32 - 1)^2 < 2^64, so `x * y` can never overflow a u64; only the
        // final multiplication by `z` needs to be checked.
        let xy = self.x as u64 * self.y as u64;
        xy.checked_mul(self.z as u64)
    }

    /// Total number of threads / blocks (`x * y * z`).
    ///
    /// The product of three `u32` values can exceed `u64::MAX`. In that case the
    /// result saturates at `u64::MAX` instead of panicking (debug builds) or
    /// silently wrapping (release builds). Use [`Dim3::checked_volume`] to detect
    /// that case explicitly.
    #[must_use]
    pub const fn volume(self) -> u64 {
        match self.checked_volume() {
            Some(v) => v,
            None => u64::MAX,
        }
    }
}

impl From<u32> for Dim3 {
    /// Treat a bare `u32` as a 1-D extent.
    fn from(x: u32) -> Self {
        Self::one_d(x)
    }
}

impl From<(u32, u32)> for Dim3 {
    /// Treat an `(x, y)` pair as a 2-D extent.
    fn from((x, y): (u32, u32)) -> Self {
        Self::two_d(x, y)
    }
}

impl From<(u32, u32, u32)> for Dim3 {
    /// Treat an `(x, y, z)` triple as a 3-D extent.
    fn from((x, y, z): (u32, u32, u32)) -> Self {
        Self::three_d(x, y, z)
    }
}
88
89// ─── FuncAttributes ──────────────────────────────────────────────────────────
90
/// Per-kernel properties reported by the driver.
///
/// Mirrors the runtime's `cudaFuncAttributes`. Each field is filled in by
/// [`func_get_attributes`] from one `CUfunction_attribute` query.
#[derive(Debug, Clone, Copy, Default)]
pub struct FuncAttributes {
    /// Statically allocated shared memory per block, in bytes.
    pub shared_size_bytes: usize,
    /// Constant memory used by the function, in bytes.
    pub const_size_bytes: usize,
    /// Local memory used by each thread, in bytes.
    pub local_size_bytes: usize,
    /// Largest block size, in threads, the function can be launched with.
    pub max_threads_per_block: u32,
    /// Registers used by each thread.
    pub num_regs: u32,
    /// PTX virtual architecture the function was compiled for.
    pub ptx_version: u32,
    /// Binary architecture the function was compiled for
    /// (compute capability × 10, e.g. `86` for 8.6).
    pub binary_version: u32,
    /// Cache-mode flag reported for `CacheModeCa`; `true` when the driver
    /// returns a non-zero value.
    pub cache_mode_ca: bool,
    /// Upper limit on dynamic shared memory per block, in bytes.
    pub max_dynamic_shared_size_bytes: usize,
    /// Preferred shared-memory carveout, as reported by the driver.
    pub preferred_shared_memory_carveout: i32,
}
117
/// Selects which kernel attribute [`func_set_attribute`] writes.
///
/// Mirrors the runtime's `cudaFuncAttribute`. The discriminants are the raw
/// integer values passed to the driver, so they must stay in sync with it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FuncAttribute {
    /// Upper limit on dynamic shared memory per block.
    MaxDynamicSharedMemorySize = 8,
    /// Preferred split between shared memory and L1 cache, as a percentage (0–100).
    PreferredSharedMemoryCarveout = 9,
}
128
129// ─── Module / function loading ────────────────────────────────────────────────
130
131/// Load a PTX module from a null-terminated byte string.
132///
133/// Mirrors the driver's `cuModuleLoadDataEx`.
134///
135/// # Errors
136///
137/// Propagates driver errors.
138pub fn module_load_ptx(ptx: &[u8]) -> CudaRtResult<CudaModule> {
139    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
140    let mut module = CUmodule::default();
141    // Ensure null termination.
142    let mut ptx_owned;
143    let ptx_ptr = if ptx.last().copied() == Some(0) {
144        ptx.as_ptr()
145    } else {
146        ptx_owned = ptx.to_vec();
147        ptx_owned.push(0);
148        ptx_owned.as_ptr()
149    };
150    // SAFETY: FFI; ptx_ptr points to null-terminated PTX text.
151    let rc = unsafe {
152        (api.cu_module_load_data_ex)(
153            &raw mut module,
154            ptx_ptr as *const c_void,
155            0,
156            std::ptr::null_mut(),
157            std::ptr::null_mut(),
158        )
159    };
160    if rc != 0 {
161        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidPtx));
162    }
163    Ok(module)
164}
165
166/// Get a function handle by name from a loaded module.
167///
168/// Mirrors the driver's `cuModuleGetFunction`.
169///
170/// # Errors
171///
172/// Returns [`CudaRtError::SymbolNotFound`] if the function does not exist.
173pub fn module_get_function(module: CudaModule, name: &str) -> CudaRtResult<CudaFunction> {
174    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
175    let mut func = CUfunction::default();
176    let name_cstr = std::ffi::CString::new(name).map_err(|_| CudaRtError::InvalidSymbol)?;
177    // SAFETY: FFI; name_cstr is null-terminated.
178    let rc = unsafe { (api.cu_module_get_function)(&raw mut func, module, name_cstr.as_ptr()) };
179    if rc != 0 {
180        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::SymbolNotFound));
181    }
182    Ok(func)
183}
184
185/// Unload a previously loaded module.
186///
187/// Mirrors `cuModuleUnload`.
188///
189/// # Errors
190///
191/// Propagates driver errors.
192pub fn module_unload(module: CudaModule) -> CudaRtResult<()> {
193    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
194    // SAFETY: FFI; module is valid.
195    let rc = unsafe { (api.cu_module_unload)(module) };
196    if rc != 0 {
197        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidResourceHandle));
198    }
199    Ok(())
200}
201
202// ─── Kernel launch ────────────────────────────────────────────────────────────
203
204/// Launch a CUDA kernel.
205///
206/// Mirrors `cudaLaunchKernel` (with explicit function handle).
207///
208/// # Parameters
209///
210/// - `func` — compiled kernel function (from [`module_get_function`]).
211/// - `grid` — grid dimensions.
212/// - `block` — block dimensions.
213/// - `args` — mutable slice of pointers to kernel arguments; each element
214///   must point to the actual argument value, as required by `cuLaunchKernel`.
215/// - `shared_mem` — dynamic shared memory in bytes.
216/// - `stream` — CUDA stream on which to enqueue the launch.
217///
218/// # Safety
219///
220/// - `func` must be a valid kernel handle.
221/// - Each `args[i]` pointer must point to a value whose type matches the
222///   kernel's `i`-th parameter.
223/// - `shared_mem` must not exceed the device's maximum shared memory per block.
224///
225/// # Errors
226///
227/// Propagates driver errors.
228pub unsafe fn launch_kernel(
229    func: CudaFunction,
230    grid: Dim3,
231    block: Dim3,
232    args: &mut [*mut c_void],
233    shared_mem: u32,
234    stream: CudaStream,
235) -> CudaRtResult<()> {
236    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
237    // SAFETY: FFI; caller guarantees func, args, and stream are valid.
238    let rc = unsafe {
239        (api.cu_launch_kernel)(
240            func,
241            grid.x,
242            grid.y,
243            grid.z,
244            block.x,
245            block.y,
246            block.z,
247            shared_mem,
248            stream.raw(),
249            args.as_mut_ptr(),
250            std::ptr::null_mut(), // extra (unused)
251        )
252    };
253    if rc != 0 {
254        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::LaunchFailure));
255    }
256    Ok(())
257}
258
259/// Query attributes of a compiled kernel.
260///
261/// Mirrors `cudaFuncGetAttributes`.
262///
263/// # Errors
264///
265/// Propagates driver errors.
266pub fn func_get_attributes(func: CudaFunction) -> CudaRtResult<FuncAttributes> {
267    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
268
269    // cu_func_get_attribute is optional (not available on very old drivers).
270    let get_attr_fn = api.cu_func_get_attribute.ok_or(CudaRtError::NotSupported)?;
271    let attr = |a: oxicuda_driver::ffi::CUfunction_attribute| -> CudaRtResult<i32> {
272        let mut v: std::ffi::c_int = 0;
273        // SAFETY: FFI.
274        let rc = unsafe { get_attr_fn(&raw mut v, a as std::ffi::c_int, func) };
275        if rc != 0 {
276            return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDeviceFunction));
277        }
278        Ok(v)
279    };
280
281    use oxicuda_driver::ffi::CUfunction_attribute as FA;
282    Ok(FuncAttributes {
283        shared_size_bytes: attr(FA::SharedSizeBytes)? as usize,
284        const_size_bytes: attr(FA::ConstSizeBytes)? as usize,
285        local_size_bytes: attr(FA::LocalSizeBytes)? as usize,
286        max_threads_per_block: attr(FA::MaxThreadsPerBlock)? as u32,
287        num_regs: attr(FA::NumRegs)? as u32,
288        ptx_version: attr(FA::PtxVersion)? as u32,
289        binary_version: attr(FA::BinaryVersion)? as u32,
290        cache_mode_ca: attr(FA::CacheModeCa)? != 0,
291        max_dynamic_shared_size_bytes: attr(FA::MaxDynamicSharedSizeBytes)? as usize,
292        preferred_shared_memory_carveout: attr(FA::PreferredSharedMemoryCarveout)?,
293    })
294}
295
296/// Set a kernel attribute.
297///
298/// Mirrors `cudaFuncSetAttribute`.
299///
300/// # Errors
301///
302/// Propagates driver errors.
303pub fn func_set_attribute(func: CudaFunction, attr: FuncAttribute, value: i32) -> CudaRtResult<()> {
304    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
305    // cu_func_set_attribute is optional (not available on very old drivers).
306    let set_attr_fn = api.cu_func_set_attribute.ok_or(CudaRtError::NotSupported)?;
307    // SAFETY: FFI.
308    let rc = unsafe { set_attr_fn(func, attr as std::ffi::c_int, value) };
309    if rc != 0 {
310        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDeviceFunction));
311    }
312    Ok(())
313}
314
315// ─── Tests ───────────────────────────────────────────────────────────────────
316
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dim3_one_d() {
        let d = Dim3::one_d(128);
        assert_eq!(d, Dim3 { x: 128, y: 1, z: 1 });
        assert_eq!(d.volume(), 128);
    }

    #[test]
    fn dim3_from_u32() {
        let d = Dim3::from(256u32);
        assert_eq!(d.x, 256);
    }

    #[test]
    fn dim3_from_tuple() {
        assert_eq!(Dim3::from((32u32, 8u32)).volume(), 256);
        assert_eq!(Dim3::from((4u32, 4u32, 4u32)).volume(), 64);
    }

    #[test]
    fn dim3_volume() {
        let d = Dim3::three_d(2, 3, 4);
        assert_eq!(d.volume(), 2 * 3 * 4);
    }

    #[test]
    fn module_load_ptx_without_gpu_errors() {
        // Whether or not a driver/GPU is present, the call must not panic.
        let ptx: &[u8] = b"// empty\n\0";
        let _ = module_load_ptx(ptx);
    }
}