Skip to main content

oxicuda_launch/
kernel.rs

1//! Type-safe GPU kernel management and argument passing.
2//!
3//! This module provides the [`Kernel`] struct for launching GPU kernels
4//! and the [`KernelArgs`] trait for type-safe argument passing to CUDA
5//! kernel functions.
6//!
7//! # Architecture
8//!
9//! A [`Kernel`] wraps a [`Function`] handle and holds an `Arc<Module>`
10//! to ensure the PTX module remains loaded for the kernel's lifetime.
11//! Arguments are passed via the [`KernelArgs`] trait, which converts
12//! typed Rust values into the `*mut c_void` array that `cuLaunchKernel`
13//! expects.
14//!
15//! # Tuple arguments
16//!
17//! The [`KernelArgs`] trait is implemented for tuples of `Copy` types
18//! up to 24 elements. Each element must be `Copy` because kernel
19//! arguments are passed by value to the GPU.
20//!
21//! # Example
22//!
23//! ```rust,no_run
24//! # use std::sync::Arc;
25//! # use oxicuda_driver::{Module, Stream, Context, Device};
26//! # use oxicuda_launch::{Kernel, LaunchParams, Dim3};
27//! # fn main() -> oxicuda_driver::CudaResult<()> {
28//! # oxicuda_driver::init()?;
29//! # let dev = Device::get(0)?;
30//! # let ctx = Arc::new(Context::new(&dev)?);
31//! # let ptx = "";
32//! let module = Arc::new(Module::from_ptx(ptx)?);
33//! let kernel = Kernel::from_module(module, "vector_add")?;
34//!
35//! let stream = Stream::new(&ctx)?;
36//! let params = LaunchParams::new(4u32, 256u32);
37//!
38//! // Launch with typed arguments: (a_ptr, b_ptr, c_ptr, n)
39//! let args = (0u64, 0u64, 0u64, 1024u32);
40//! kernel.launch(&params, &stream, &args)?;
41//! # Ok(())
42//! # }
43//! ```
44
45use std::ffi::c_void;
46use std::sync::Arc;
47
48use oxicuda_driver::error::CudaResult;
49use oxicuda_driver::loader::try_driver;
50use oxicuda_driver::module::{Function, Module};
51use oxicuda_driver::stream::Stream;
52
53use crate::params::LaunchParams;
54use crate::trace::KernelSpanGuard;
55
56// ---------------------------------------------------------------------------
57// KernelArgs trait
58// ---------------------------------------------------------------------------
59
/// Conversion of a Rust value into the `kernelParams` array expected by
/// `cuLaunchKernel`.
///
/// `cuLaunchKernel` takes one host pointer per kernel parameter. The driver
/// reads each parameter's value through its pointer when the launch is
/// issued, so the pointed-to data only has to outlive the launch call
/// itself, not the kernel's execution.
///
/// # Safety
///
/// An implementation must guarantee that:
/// - every pointer returned by `as_param_ptrs` is valid for reads and points
///   at the corresponding argument value;
/// - those values stay alive and unmoved until `cuLaunchKernel` returns;
/// - the number, order, types, and sizes of the arguments agree with the
///   kernel's declared parameter list.
pub unsafe trait KernelArgs {
    /// Produces the parameter pointer array for `cuLaunchKernel`.
    ///
    /// Element `i` of the returned `Vec` is the host address of kernel
    /// argument `i`. The pointers are typed `*mut c_void` only because that
    /// is what the driver signature takes; the driver reads the argument
    /// values through them. See the `# Safety` section of this trait for the
    /// validity requirements.
    fn as_param_ptrs(&self) -> Vec<*mut c_void>;
}
82
83// ---------------------------------------------------------------------------
84// KernelArgs — unit type (no arguments)
85// ---------------------------------------------------------------------------
86
87/// Implementation for kernels that take no arguments.
88///
89/// # Safety
90///
91/// Returns an empty pointer array, which is valid for zero-argument kernels.
92unsafe impl KernelArgs for () {
93    #[inline]
94    fn as_param_ptrs(&self) -> Vec<*mut c_void> {
95        Vec::new()
96    }
97}
98
99// ---------------------------------------------------------------------------
100// KernelArgs — tuple implementations via macro
101// ---------------------------------------------------------------------------
102
/// Emits an `unsafe impl KernelArgs` for a single tuple arity.
///
/// Invoke with `index: TypeParam` pairs, e.g. `impl_kernel_args_tuple!(0: A, 1: B)`.
/// Every type parameter is bounded by `Copy`. The generated `as_param_ptrs`
/// yields the address of each tuple field, in field order, erased to
/// `*mut c_void`.
macro_rules! impl_kernel_args_tuple {
    ($($field:tt: $Ty:ident),+) => {
        /// # Safety
        ///
        /// Each returned pointer is the address of a field of `self`; the
        /// tuple must stay alive and in place until `cuLaunchKernel` returns.
        unsafe impl<$($Ty),+> KernelArgs for ($($Ty,)+)
        where
            $($Ty: Copy,)+
        {
            #[inline]
            fn as_param_ptrs(&self) -> Vec<*mut c_void> {
                // `addr_of!` takes the field address directly, without first
                // materialising an intermediate `&T`.
                vec![$(std::ptr::addr_of!(self.$field) as *mut c_void),+]
            }
        }
    };
}
123
124impl_kernel_args_tuple!(0: A);
125impl_kernel_args_tuple!(0: A, 1: B);
126impl_kernel_args_tuple!(0: A, 1: B, 2: C);
127impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D);
128impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E);
129impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
130impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G);
131impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H);
132impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I);
133impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J);
134impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K);
135impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L);
136impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M);
137impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N);
138impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O);
139impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P);
140impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q);
141impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R);
142impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S);
143impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T);
144impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U);
145impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V);
146impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W);
147impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X);
148
149// ---------------------------------------------------------------------------
150// Kernel struct
151// ---------------------------------------------------------------------------
152
153/// A launchable GPU kernel with module lifetime management.
154///
155/// Holds an `Arc<Module>` to ensure the PTX module remains loaded
156/// as long as any `Kernel` references it. This is important because
157/// [`Function`] handles become invalid once their parent module is
158/// unloaded.
159///
160/// # Creating a kernel
161///
162/// ```rust,no_run
163/// # use std::sync::Arc;
164/// # use oxicuda_driver::Module;
165/// # use oxicuda_launch::Kernel;
166/// # fn main() -> oxicuda_driver::CudaResult<()> {
167/// # let ptx = "";
168/// let module = Arc::new(Module::from_ptx(ptx)?);
169/// let kernel = Kernel::from_module(module, "my_kernel")?;
170/// println!("loaded kernel: {}", kernel.name());
171/// # Ok(())
172/// # }
173/// ```
174///
175/// # Launching
176///
177/// ```rust,no_run
178/// # use std::sync::Arc;
179/// # use oxicuda_driver::{Module, Stream, Context, Device};
180/// # use oxicuda_launch::{Kernel, LaunchParams};
181/// # fn main() -> oxicuda_driver::CudaResult<()> {
182/// # oxicuda_driver::init()?;
183/// # let dev = Device::get(0)?;
184/// # let ctx = Arc::new(Context::new(&dev)?);
185/// # let ptx = "";
186/// # let module = Arc::new(Module::from_ptx(ptx)?);
187/// # let kernel = Kernel::from_module(module, "my_kernel")?;
188/// let stream = Stream::new(&ctx)?;
189/// let params = LaunchParams::new(4u32, 256u32);
190/// kernel.launch(&params, &stream, &(42u32, 1024u32))?;
191/// # Ok(())
192/// # }
193/// ```
194pub struct Kernel {
195    /// The underlying CUDA function handle.
196    function: Function,
197    /// Keeps the parent module alive as long as this kernel exists.
198    _module: Arc<Module>,
199    /// The kernel function name (for debugging and diagnostics).
200    name: String,
201}
202
203impl Kernel {
204    /// Creates a new `Kernel` from a module and function name.
205    ///
206    /// Looks up the named function in the module. The `Arc<Module>` ensures
207    /// the module is not unloaded while this kernel exists.
208    ///
209    /// # Errors
210    ///
211    /// Returns [`CudaError::NotFound`](oxicuda_driver::CudaError::NotFound) if no
212    /// function with the given name exists in the module, or another
213    /// [`CudaError`](oxicuda_driver::CudaError) on driver failure.
214    pub fn from_module(module: Arc<Module>, name: &str) -> CudaResult<Self> {
215        let function = module.get_function(name)?;
216        Ok(Self {
217            function,
218            _module: module,
219            name: name.to_owned(),
220        })
221    }
222
223    /// Launches the kernel with the given parameters and arguments on a stream.
224    ///
225    /// This is the primary entry point for kernel execution. It calls
226    /// `cuLaunchKernel` with the specified grid/block dimensions, shared
227    /// memory, stream, and kernel arguments.
228    ///
229    /// The launch is asynchronous — it returns immediately and the kernel
230    /// executes on the GPU. Use [`Stream::synchronize`] to wait for completion.
231    ///
232    /// # Type safety
233    ///
234    /// The `args` parameter accepts any type implementing [`KernelArgs`],
235    /// including tuples of `Copy` types up to 24 elements. The caller is
236    /// responsible for ensuring the argument types match the kernel signature.
237    ///
238    /// # Errors
239    ///
240    /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the launch fails
241    /// (e.g., invalid dimensions, insufficient resources, driver error).
242    pub fn launch<A: KernelArgs>(
243        &self,
244        params: &LaunchParams,
245        stream: &Stream,
246        args: &A,
247    ) -> CudaResult<()> {
248        // Emit a tracing span for this kernel launch (no-op when the
249        // `tracing` feature is disabled).
250        let _span = KernelSpanGuard::enter(
251            &self.name,
252            (params.grid.x, params.grid.y, params.grid.z),
253            (params.block.x, params.block.y, params.block.z),
254        );
255
256        let driver = try_driver()?;
257        let mut param_ptrs = args.as_param_ptrs();
258        oxicuda_driver::error::check(unsafe {
259            (driver.cu_launch_kernel)(
260                self.function.raw(),
261                params.grid.x,
262                params.grid.y,
263                params.grid.z,
264                params.block.x,
265                params.block.y,
266                params.block.z,
267                params.shared_mem_bytes,
268                stream.raw(),
269                param_ptrs.as_mut_ptr(),
270                std::ptr::null_mut(),
271            )
272        })
273    }
274
275    /// Returns the kernel function name.
276    #[inline]
277    pub fn name(&self) -> &str {
278        &self.name
279    }
280
281    /// Returns a reference to the underlying [`Function`] handle.
282    ///
283    /// This can be used for occupancy queries and other function-level
284    /// operations provided by `oxicuda-driver`.
285    #[inline]
286    pub fn function(&self) -> &Function {
287        &self.function
288    }
289
290    /// Returns the maximum number of active blocks per streaming multiprocessor
291    /// for a given block size and dynamic shared memory.
292    ///
293    /// Delegates to [`Function::max_active_blocks_per_sm`].
294    ///
295    /// # Parameters
296    ///
297    /// * `block_size` — number of threads per block.
298    /// * `dynamic_smem` — dynamic shared memory per block in bytes.
299    ///
300    /// # Errors
301    ///
302    /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
303    pub fn max_active_blocks_per_sm(
304        &self,
305        block_size: i32,
306        dynamic_smem: usize,
307    ) -> CudaResult<i32> {
308        self.function
309            .max_active_blocks_per_sm(block_size, dynamic_smem)
310    }
311
312    /// Returns the optimal block size for this kernel and the minimum
313    /// grid size to achieve maximum occupancy.
314    ///
315    /// Delegates to [`Function::optimal_block_size`].
316    ///
317    /// Returns `(min_grid_size, optimal_block_size)`.
318    ///
319    /// # Parameters
320    ///
321    /// * `dynamic_smem` — dynamic shared memory per block in bytes.
322    ///
323    /// # Errors
324    ///
325    /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
326    pub fn optimal_block_size(&self, dynamic_smem: usize) -> CudaResult<(i32, i32)> {
327        self.function.optimal_block_size(dynamic_smem)
328    }
329}
330
331impl std::fmt::Debug for Kernel {
332    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        f.debug_struct("Kernel")
334            .field("name", &self.name)
335            .field("function", &self.function)
336            .finish_non_exhaustive()
337    }
338}
339
340impl std::fmt::Display for Kernel {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        write!(f, "Kernel({})", self.name)
343    }
344}
345
346// ---------------------------------------------------------------------------
347// Tests
348// ---------------------------------------------------------------------------
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unit_args_empty() {
        // Call directly on `()` rather than binding a unit value
        // (avoids clippy::let_unit_value).
        let ptrs = ().as_param_ptrs();
        assert!(ptrs.is_empty());
    }

    #[test]
    fn single_arg_ptr_valid() {
        let args = (42u32,);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 1);
        // Verify the pointer actually points to the value.
        let val_ptr = ptrs[0] as *const u32;
        // SAFETY: `val_ptr` points at `args.0`, a live `u32`.
        assert_eq!(unsafe { *val_ptr }, 42u32);
    }

    #[test]
    fn two_args_ptr_valid() {
        let args = (10u32, 20u64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 2);
        // SAFETY: each pointer targets the matching live field of `args`.
        assert_eq!(unsafe { *(ptrs[0] as *const u32) }, 10u32);
        assert_eq!(unsafe { *(ptrs[1] as *const u64) }, 20u64);
    }

    #[test]
    fn four_args_ptr_valid() {
        let args = (1u32, 2u64, 3.0f32, 4.0f64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 4);
        // SAFETY: each pointer targets the matching live field of `args`.
        assert_eq!(unsafe { *(ptrs[0] as *const u32) }, 1u32);
        assert_eq!(unsafe { *(ptrs[1] as *const u64) }, 2u64);
        assert!((unsafe { *(ptrs[2] as *const f32) } - 3.0f32).abs() < f32::EPSILON);
        assert!((unsafe { *(ptrs[3] as *const f64) } - 4.0f64).abs() < f64::EPSILON);
    }

    #[test]
    fn twelve_args_count() {
        let args = (
            1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 12);
        for (i, ptr) in ptrs.iter().enumerate() {
            // SAFETY: every element of `args` is a live `u32`.
            let val = unsafe { *(*ptr as *const u32) };
            assert_eq!(val, (i as u32) + 1);
        }
    }

    #[test]
    fn twenty_four_args_max_arity() {
        // 24 is the largest tuple arity the module docs promise support for.
        let args = (
            1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 13u32,
            14u32, 15u32, 16u32, 17u32, 18u32, 19u32, 20u32, 21u32, 22u32, 23u32, 24u32,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 24);
        for (expected, ptr) in (1u32..).zip(&ptrs) {
            // SAFETY: every element of `args` is a live `u32`.
            let val = unsafe { *(*ptr as *const u32) };
            assert_eq!(val, expected);
        }
    }

    #[test]
    fn ptrs_alias_tuple_fields() {
        // The driver reads through these pointers, so they must address the
        // caller's tuple fields themselves, not temporary copies.
        let args = (7u32, 8u64, 9.5f32);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 3);
        assert_eq!(ptrs[0] as *const u32, &args.0 as *const u32);
        assert_eq!(ptrs[1] as *const u64, &args.1 as *const u64);
        assert_eq!(ptrs[2] as *const f32, &args.2 as *const f32);
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only, E2E PTX chain parameter verification)
    // ---------------------------------------------------------------------------

    #[test]
    fn launch_params_grid_calculation_e2e() {
        // Given n = 1_048_576 (1M elements) and block_size = 256,
        // grid_size_for must return exactly 4096 (ceiling division).
        let n: u32 = 1_048_576;
        let block_size: u32 = 256;
        let grid = crate::grid::grid_size_for(n, block_size);
        assert_eq!(
            grid, 4096,
            "grid_size_for(1M, 256) must be 4096, got {grid}"
        );
        // Also verify via arithmetic: 1_048_576 / 256 == 4096 exactly
        assert_eq!(
            n % block_size,
            0,
            "n must be exactly divisible by block_size"
        );
    }

    #[test]
    fn launch_params_stores_grid_and_block() {
        // LaunchParams::new(4096, 256) must record grid==4096 and block==256.
        let params = LaunchParams::new(4096u32, 256u32);
        assert_eq!(
            params.grid.x, 4096,
            "grid.x must be 4096, got {}",
            params.grid.x
        );
        assert_eq!(
            params.block.x, 256,
            "block.x must be 256, got {}",
            params.block.x
        );
        assert_eq!(params.shared_mem_bytes, 0);
        // Total threads: 4096 * 256 = 1_048_576
        assert_eq!(params.total_threads(), 1_048_576);
    }

    #[test]
    fn named_args_builder_chain() {
        // ArgBuilder::new().add("a", &1u32).add("b", &2.0f32).build() must have length 2.
        use crate::named_args::ArgBuilder;
        let a: u32 = 1;
        let b: f32 = 2.0;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b);
        assert_eq!(
            builder.arg_count(),
            2,
            "ArgBuilder with 2 pushes must have length 2"
        );
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 2, "build() must return 2 pointers");
    }
}