Skip to main content

oxicuda_launch/
kernel.rs

1//! Type-safe GPU kernel management and argument passing.
2//!
3//! This module provides the [`Kernel`] struct for launching GPU kernels
4//! and the [`KernelArgs`] trait for type-safe argument passing to CUDA
5//! kernel functions.
6//!
7//! # Architecture
8//!
9//! A [`Kernel`] wraps a [`Function`] handle and holds an `Arc<Module>`
10//! to ensure the PTX module remains loaded for the kernel's lifetime.
11//! Arguments are passed via the [`KernelArgs`] trait, which converts
12//! typed Rust values into the `*mut c_void` array that `cuLaunchKernel`
13//! expects.
14//!
15//! # Tuple arguments
16//!
//! The [`KernelArgs`] trait is implemented for tuples of `Copy` types
//! with 1 to 28 elements, and for `()` for kernels that take no
//! arguments. Each element must be `Copy` because kernel arguments are
//! passed by value to the GPU.
20//!
21//! # Example
22//!
23//! ```rust,no_run
24//! # use std::sync::Arc;
25//! # use oxicuda_driver::{Module, Stream, Context, Device};
26//! # use oxicuda_launch::{Kernel, LaunchParams, Dim3};
27//! # fn main() -> oxicuda_driver::CudaResult<()> {
28//! # oxicuda_driver::init()?;
29//! # let dev = Device::get(0)?;
30//! # let ctx = Arc::new(Context::new(&dev)?);
31//! # let ptx = "";
32//! let module = Arc::new(Module::from_ptx(ptx)?);
33//! let kernel = Kernel::from_module(module, "vector_add")?;
34//!
35//! let stream = Stream::new(&ctx)?;
36//! let params = LaunchParams::new(4u32, 256u32);
37//!
38//! // Launch with typed arguments: (a_ptr, b_ptr, c_ptr, n)
39//! let args = (0u64, 0u64, 0u64, 1024u32);
40//! kernel.launch(&params, &stream, &args)?;
41//! # Ok(())
42//! # }
43//! ```
44
45use std::ffi::c_void;
46use std::sync::Arc;
47
48use oxicuda_driver::error::CudaResult;
49use oxicuda_driver::loader::try_driver;
50use oxicuda_driver::module::{Function, Module};
51use oxicuda_driver::stream::Stream;
52
53use crate::params::LaunchParams;
54use crate::trace::KernelSpanGuard;
55
56// ---------------------------------------------------------------------------
57// KernelArgs trait
58// ---------------------------------------------------------------------------
59
/// Argument bundles that can be handed to `cuLaunchKernel`.
///
/// `cuLaunchKernel` takes its kernel parameters as an array of untyped
/// pointers, one per argument. Every pointer refers to the argument's value
/// in host memory, and the CUDA driver copies those values to the GPU before
/// the kernel runs. This trait is the bridge from typed Rust values to that
/// pointer array.
///
/// # Safety
///
/// An implementation must guarantee all of the following:
/// - every pointer returned by `as_param_ptrs` refers to a valid argument
///   value;
/// - those values stay valid for the whole launch call, i.e. until
///   `cuLaunchKernel` has returned;
/// - the argument types and sizes are what the kernel expects.
pub unsafe trait KernelArgs {
    /// Builds the pointer array passed to `cuLaunchKernel`.
    ///
    /// The returned `Vec` holds one entry per kernel argument; the CUDA
    /// driver reads the value behind each entry and copies it to the GPU.
    fn as_param_ptrs(&self) -> Vec<*mut c_void>;
}
82
83// ---------------------------------------------------------------------------
84// KernelArgs — unit type (no arguments)
85// ---------------------------------------------------------------------------
86
87/// Implementation for kernels that take no arguments.
88///
89/// # Safety
90///
91/// Returns an empty pointer array, which is valid for zero-argument kernels.
92unsafe impl KernelArgs for () {
93    #[inline]
94    fn as_param_ptrs(&self) -> Vec<*mut c_void> {
95        Vec::new()
96    }
97}
98
99// ---------------------------------------------------------------------------
100// KernelArgs — tuple implementations via macro
101// ---------------------------------------------------------------------------
102
/// Implements [`KernelArgs`] for one tuple shape whose elements are all `Copy`.
///
/// The macro is invoked with `index: TypeParam` pairs, one per tuple field.
/// For every field it takes a shared reference to `self.<index>`, turns it
/// into a raw `*const T`, and erases the type to the `*mut c_void` that
/// `cuLaunchKernel` expects.
macro_rules! impl_kernel_args_tuple {
    ($($idx:tt: $T:ident),+) => {
        /// # Safety
        ///
        /// Every returned pointer refers to a field of `self`, so `self`
        /// must stay in place (neither moved nor dropped) until
        /// `cuLaunchKernel` returns.
        unsafe impl<$($T),+> KernelArgs for ($($T,)+)
        where
            $($T: Copy,)+
        {
            #[inline]
            fn as_param_ptrs(&self) -> Vec<*mut c_void> {
                Vec::from([
                    $((&self.$idx as *const $T).cast::<c_void>() as *mut c_void,)+
                ])
            }
        }
    };
}
123
124impl_kernel_args_tuple!(0: A);
125impl_kernel_args_tuple!(0: A, 1: B);
126impl_kernel_args_tuple!(0: A, 1: B, 2: C);
127impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D);
128impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E);
129impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
130impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G);
131impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H);
132impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I);
133impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J);
134impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K);
135impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L);
136impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M);
137impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N);
138impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O);
139impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P);
140impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q);
141impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R);
142impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S);
143impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T);
144impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U);
145impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V);
146impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W);
147impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X);
148impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y);
149impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y, 25: Z);
150impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y, 25: Z, 26: AA);
151impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y, 25: Z, 26: AA, 27: BB);
152
153// ---------------------------------------------------------------------------
154// Kernel struct
155// ---------------------------------------------------------------------------
156
157/// A launchable GPU kernel with module lifetime management.
158///
159/// Holds an `Arc<Module>` to ensure the PTX module remains loaded
160/// as long as any `Kernel` references it. This is important because
161/// [`Function`] handles become invalid once their parent module is
162/// unloaded.
163///
164/// # Creating a kernel
165///
166/// ```rust,no_run
167/// # use std::sync::Arc;
168/// # use oxicuda_driver::Module;
169/// # use oxicuda_launch::Kernel;
170/// # fn main() -> oxicuda_driver::CudaResult<()> {
171/// # let ptx = "";
172/// let module = Arc::new(Module::from_ptx(ptx)?);
173/// let kernel = Kernel::from_module(module, "my_kernel")?;
174/// println!("loaded kernel: {}", kernel.name());
175/// # Ok(())
176/// # }
177/// ```
178///
179/// # Launching
180///
181/// ```rust,no_run
182/// # use std::sync::Arc;
183/// # use oxicuda_driver::{Module, Stream, Context, Device};
184/// # use oxicuda_launch::{Kernel, LaunchParams};
185/// # fn main() -> oxicuda_driver::CudaResult<()> {
186/// # oxicuda_driver::init()?;
187/// # let dev = Device::get(0)?;
188/// # let ctx = Arc::new(Context::new(&dev)?);
189/// # let ptx = "";
190/// # let module = Arc::new(Module::from_ptx(ptx)?);
191/// # let kernel = Kernel::from_module(module, "my_kernel")?;
192/// let stream = Stream::new(&ctx)?;
193/// let params = LaunchParams::new(4u32, 256u32);
194/// kernel.launch(&params, &stream, &(42u32, 1024u32))?;
195/// # Ok(())
196/// # }
197/// ```
198pub struct Kernel {
199    /// The underlying CUDA function handle.
200    function: Function,
201    /// Keeps the parent module alive as long as this kernel exists.
202    _module: Arc<Module>,
203    /// The kernel function name (for debugging and diagnostics).
204    name: String,
205}
206
207impl Kernel {
208    /// Creates a new `Kernel` from a module and function name.
209    ///
210    /// Looks up the named function in the module. The `Arc<Module>` ensures
211    /// the module is not unloaded while this kernel exists.
212    ///
213    /// # Errors
214    ///
215    /// Returns [`CudaError::NotFound`](oxicuda_driver::CudaError::NotFound) if no
216    /// function with the given name exists in the module, or another
217    /// [`CudaError`](oxicuda_driver::CudaError) on driver failure.
218    pub fn from_module(module: Arc<Module>, name: &str) -> CudaResult<Self> {
219        let function = module.get_function(name)?;
220        Ok(Self {
221            function,
222            _module: module,
223            name: name.to_owned(),
224        })
225    }
226
227    /// Launches the kernel with the given parameters and arguments on a stream.
228    ///
229    /// This is the primary entry point for kernel execution. It calls
230    /// `cuLaunchKernel` with the specified grid/block dimensions, shared
231    /// memory, stream, and kernel arguments.
232    ///
233    /// The launch is asynchronous — it returns immediately and the kernel
234    /// executes on the GPU. Use [`Stream::synchronize`] to wait for completion.
235    ///
236    /// # Type safety
237    ///
238    /// The `args` parameter accepts any type implementing [`KernelArgs`],
239    /// including tuples of `Copy` types up to 24 elements. The caller is
240    /// responsible for ensuring the argument types match the kernel signature.
241    ///
242    /// # Errors
243    ///
244    /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the launch fails
245    /// (e.g., invalid dimensions, insufficient resources, driver error).
246    pub fn launch<A: KernelArgs>(
247        &self,
248        params: &LaunchParams,
249        stream: &Stream,
250        args: &A,
251    ) -> CudaResult<()> {
252        // Emit a tracing span for this kernel launch (no-op when the
253        // `tracing` feature is disabled).
254        let _span = KernelSpanGuard::enter(
255            &self.name,
256            (params.grid.x, params.grid.y, params.grid.z),
257            (params.block.x, params.block.y, params.block.z),
258        );
259
260        let driver = try_driver()?;
261        let mut param_ptrs = args.as_param_ptrs();
262        oxicuda_driver::error::check(unsafe {
263            (driver.cu_launch_kernel)(
264                self.function.raw(),
265                params.grid.x,
266                params.grid.y,
267                params.grid.z,
268                params.block.x,
269                params.block.y,
270                params.block.z,
271                params.shared_mem_bytes,
272                stream.raw(),
273                param_ptrs.as_mut_ptr(),
274                std::ptr::null_mut(),
275            )
276        })
277    }
278
279    /// Returns the kernel function name.
280    #[inline]
281    pub fn name(&self) -> &str {
282        &self.name
283    }
284
285    /// Returns a reference to the underlying [`Function`] handle.
286    ///
287    /// This can be used for occupancy queries and other function-level
288    /// operations provided by `oxicuda-driver`.
289    #[inline]
290    pub fn function(&self) -> &Function {
291        &self.function
292    }
293
294    /// Returns the maximum number of active blocks per streaming multiprocessor
295    /// for a given block size and dynamic shared memory.
296    ///
297    /// Delegates to [`Function::max_active_blocks_per_sm`].
298    ///
299    /// # Parameters
300    ///
301    /// * `block_size` — number of threads per block.
302    /// * `dynamic_smem` — dynamic shared memory per block in bytes.
303    ///
304    /// # Errors
305    ///
306    /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
307    pub fn max_active_blocks_per_sm(
308        &self,
309        block_size: i32,
310        dynamic_smem: usize,
311    ) -> CudaResult<i32> {
312        self.function
313            .max_active_blocks_per_sm(block_size, dynamic_smem)
314    }
315
316    /// Returns the optimal block size for this kernel and the minimum
317    /// grid size to achieve maximum occupancy.
318    ///
319    /// Delegates to [`Function::optimal_block_size`].
320    ///
321    /// Returns `(min_grid_size, optimal_block_size)`.
322    ///
323    /// # Parameters
324    ///
325    /// * `dynamic_smem` — dynamic shared memory per block in bytes.
326    ///
327    /// # Errors
328    ///
329    /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
330    pub fn optimal_block_size(&self, dynamic_smem: usize) -> CudaResult<(i32, i32)> {
331        self.function.optimal_block_size(dynamic_smem)
332    }
333}
334
335impl std::fmt::Debug for Kernel {
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        f.debug_struct("Kernel")
338            .field("name", &self.name)
339            .field("function", &self.function)
340            .finish_non_exhaustive()
341    }
342}
343
344impl std::fmt::Display for Kernel {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        write!(f, "Kernel({})", self.name)
347    }
348}
349
350// ---------------------------------------------------------------------------
351// Tests
352// ---------------------------------------------------------------------------
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unit_args_empty() {
        // A zero-argument kernel gets an empty pointer array.
        let ptrs = ().as_param_ptrs();
        assert!(ptrs.is_empty());
    }

    #[test]
    fn single_arg_ptr_valid() {
        let args = (42u32,);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 1);
        // SAFETY: `ptrs[0]` points at `args.0`, a live `u32`.
        let value = unsafe { ptrs[0].cast::<u32>().read() };
        assert_eq!(value, 42u32);
    }

    #[test]
    fn two_args_ptr_valid() {
        let args = (10u32, 20u64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 2);
        // SAFETY: each pointer refers to the live tuple field of the type
        // it is read back as.
        assert_eq!(unsafe { ptrs[0].cast::<u32>().read() }, 10u32);
        assert_eq!(unsafe { ptrs[1].cast::<u64>().read() }, 20u64);
    }

    #[test]
    fn four_args_ptr_valid() {
        let args = (1u32, 2u64, 3.0f32, 4.0f64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 4);
        // SAFETY: each pointer refers to the live tuple field of the type
        // it is read back as.
        let (first, second, third, fourth) = unsafe {
            (
                ptrs[0].cast::<u32>().read(),
                ptrs[1].cast::<u64>().read(),
                ptrs[2].cast::<f32>().read(),
                ptrs[3].cast::<f64>().read(),
            )
        };
        assert_eq!(first, 1u32);
        assert_eq!(second, 2u64);
        assert!((third - 3.0f32).abs() < f32::EPSILON);
        assert!((fourth - 4.0f64).abs() < f64::EPSILON);
    }

    #[test]
    fn twelve_args_count() {
        let args = (
            1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 12);
        for (expected, ptr) in (1u32..).zip(&ptrs) {
            // SAFETY: every pointer refers to a live `u32` field of `args`.
            let val = unsafe { ptr.cast::<u32>().read() };
            assert_eq!(val, expected);
        }
    }

    #[test]
    fn max_arity_args_ptr_valid() {
        // 28 elements is the largest tuple `KernelArgs` is implemented for;
        // every field must come back with its own value.
        let args = (
            1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32,
            11u32, 12u32, 13u32, 14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32,
            21u32, 22u32, 23u32, 24u32, 25u32, 26u32, 27u32, 28u32,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 28);
        for (expected, ptr) in (1u32..).zip(&ptrs) {
            // SAFETY: every pointer refers to a live `u32` field of `args`.
            let val = unsafe { ptr.cast::<u32>().read() };
            assert_eq!(val, expected);
        }
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only, E2E PTX chain parameter verification)
    // ---------------------------------------------------------------------------

    #[test]
    fn launch_params_grid_calculation_e2e() {
        // Given n = 1_048_576 (1M elements) and block_size = 256,
        // grid_size_for must return exactly 4096 (ceiling division).
        let n: u32 = 1_048_576;
        let block_size: u32 = 256;
        let grid = crate::grid::grid_size_for(n, block_size);
        assert_eq!(
            grid, 4096,
            "grid_size_for(1M, 256) must be 4096, got {grid}"
        );
        // Also verify via arithmetic: 1_048_576 / 256 == 4096 exactly
        assert_eq!(
            n % block_size,
            0,
            "n must be exactly divisible by block_size"
        );
    }

    #[test]
    fn launch_params_stores_grid_and_block() {
        // LaunchParams::new(4096, 256) must record grid==4096 and block==256.
        let params = LaunchParams::new(4096u32, 256u32);
        assert_eq!(
            params.grid.x, 4096,
            "grid.x must be 4096, got {}",
            params.grid.x
        );
        assert_eq!(
            params.block.x, 256,
            "block.x must be 256, got {}",
            params.block.x
        );
        assert_eq!(params.shared_mem_bytes, 0);
        // Total threads: 4096 * 256 = 1_048_576
        assert_eq!(params.total_threads(), 1_048_576);
    }

    #[test]
    fn named_args_builder_chain() {
        // Adding "a" and "b" to an ArgBuilder must yield two arguments, both
        // in `arg_count()` and in the pointer list returned by `build()`.
        use crate::named_args::ArgBuilder;
        let a: u32 = 1;
        let b: f32 = 2.0;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b);
        assert_eq!(
            builder.arg_count(),
            2,
            "ArgBuilder with 2 pushes must have length 2"
        );
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 2, "build() must return 2 pointers");
    }
}