oxicuda_launch/kernel.rs
1//! Type-safe GPU kernel management and argument passing.
2//!
3//! This module provides the [`Kernel`] struct for launching GPU kernels
4//! and the [`KernelArgs`] trait for type-safe argument passing to CUDA
5//! kernel functions.
6//!
7//! # Architecture
8//!
9//! A [`Kernel`] wraps a [`Function`] handle and holds an `Arc<Module>`
10//! to ensure the PTX module remains loaded for the kernel's lifetime.
11//! Arguments are passed via the [`KernelArgs`] trait, which converts
12//! typed Rust values into the `*mut c_void` array that `cuLaunchKernel`
13//! expects.
14//!
15//! # Tuple arguments
16//!
//! The [`KernelArgs`] trait is implemented for the unit type `()` (kernels
//! without parameters) and for tuples of 1 to 28 `Copy` elements. Each
//! element must be `Copy` because kernel arguments are passed by value to
//! the GPU.
20//!
21//! # Example
22//!
23//! ```rust,no_run
24//! # use std::sync::Arc;
25//! # use oxicuda_driver::{Module, Stream, Context, Device};
26//! # use oxicuda_launch::{Kernel, LaunchParams, Dim3};
27//! # fn main() -> oxicuda_driver::CudaResult<()> {
28//! # oxicuda_driver::init()?;
29//! # let dev = Device::get(0)?;
30//! # let ctx = Arc::new(Context::new(&dev)?);
31//! # let ptx = "";
32//! let module = Arc::new(Module::from_ptx(ptx)?);
33//! let kernel = Kernel::from_module(module, "vector_add")?;
34//!
35//! let stream = Stream::new(&ctx)?;
36//! let params = LaunchParams::new(4u32, 256u32);
37//!
38//! // Launch with typed arguments: (a_ptr, b_ptr, c_ptr, n)
39//! let args = (0u64, 0u64, 0u64, 1024u32);
//! kernel.launch(&params, &stream, &args)?;
41//! # Ok(())
42//! # }
43//! ```
44
45use std::ffi::c_void;
46use std::sync::Arc;
47
48use oxicuda_driver::error::CudaResult;
49use oxicuda_driver::loader::try_driver;
50use oxicuda_driver::module::{Function, Module};
51use oxicuda_driver::stream::Stream;
52
53use crate::params::LaunchParams;
54use crate::trace::KernelSpanGuard;
55
// ---------------------------------------------------------------------------
// KernelArgs trait
// ---------------------------------------------------------------------------

/// Types that can be handed to `cuLaunchKernel` as a kernel argument list.
///
/// The driver's `kernelParams` argument is an array holding one `void*` per
/// kernel parameter; each entry points at the host-side value of that
/// parameter, and the CUDA driver copies the values to the GPU before the
/// kernel executes.
///
/// # Safety
///
/// An implementation must guarantee that:
/// - every pointer returned by `as_param_ptrs` addresses a valid argument
///   value;
/// - those values stay valid (not moved or dropped) for the whole launch,
///   i.e. until `cuLaunchKernel` returns;
/// - the number, order, types and sizes of the arguments match what the
///   kernel expects.
pub unsafe trait KernelArgs {
    /// Builds the `kernelParams` pointer array for `cuLaunchKernel`.
    ///
    /// Entry `i` of the returned `Vec` points at the value of the kernel's
    /// `i`-th argument; the driver reads each value through its pointer.
    fn as_param_ptrs(&self) -> Vec<*mut c_void>;
}

// ---------------------------------------------------------------------------
// KernelArgs — unit type (no arguments)
// ---------------------------------------------------------------------------

/// `()` describes a kernel that takes no arguments.
///
/// # Safety
///
/// The returned pointer array is empty, so there is nothing for the driver
/// to read through it.
unsafe impl KernelArgs for () {
    #[inline]
    fn as_param_ptrs(&self) -> Vec<*mut c_void> {
        vec![]
    }
}

// ---------------------------------------------------------------------------
// KernelArgs — tuple implementations via macro
// ---------------------------------------------------------------------------

/// Emits one [`KernelArgs`] impl for a tuple whose element types are `Copy`.
///
/// Each `index: Type` pair names one tuple field. The generated
/// `as_param_ptrs` returns the address of every field, in field order,
/// cast to `*mut c_void`.
macro_rules! impl_kernel_args_tuple {
    ($($idx:tt: $T:ident),+) => {
        /// # Safety
        ///
        /// Every returned pointer addresses a field of `self`, so the tuple
        /// must stay in place (not moved or dropped) until `cuLaunchKernel`
        /// returns.
        unsafe impl<$($T: Copy),+> KernelArgs for ($($T,)+) {
            #[inline]
            fn as_param_ptrs(&self) -> Vec<*mut c_void> {
                vec![$(std::ptr::addr_of!(self.$idx) as *mut c_void),+]
            }
        }
    };
}

/// Invokes `impl_kernel_args_tuple!` once for every non-empty prefix of the
/// given `index: Type` list, so a single list of N pairs yields impls for
/// every tuple arity from 1 to N.
///
/// The bracketed group carries the pairs already emitted; each step appends
/// the next pair to it and emits the impl for the grown prefix.
macro_rules! impl_kernel_args_prefixes {
    ([$($done:tt)*]) => {};
    ([$($done:tt)*] $idx:tt: $T:ident $(, $rest_idx:tt: $rest_T:ident)*) => {
        impl_kernel_args_tuple!($($done)* $idx: $T);
        impl_kernel_args_prefixes!([$($done)* $idx: $T,] $($rest_idx: $rest_T),*);
    };
}

// Arities 1 through 28.
impl_kernel_args_prefixes!(
    [] 0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M,
    13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y,
    25: Z, 26: AA, 27: BB
);
152
153// ---------------------------------------------------------------------------
154// Kernel struct
155// ---------------------------------------------------------------------------
156
157/// A launchable GPU kernel with module lifetime management.
158///
159/// Holds an `Arc<Module>` to ensure the PTX module remains loaded
160/// as long as any `Kernel` references it. This is important because
161/// [`Function`] handles become invalid once their parent module is
162/// unloaded.
163///
164/// # Creating a kernel
165///
166/// ```rust,no_run
167/// # use std::sync::Arc;
168/// # use oxicuda_driver::Module;
169/// # use oxicuda_launch::Kernel;
170/// # fn main() -> oxicuda_driver::CudaResult<()> {
171/// # let ptx = "";
172/// let module = Arc::new(Module::from_ptx(ptx)?);
173/// let kernel = Kernel::from_module(module, "my_kernel")?;
174/// println!("loaded kernel: {}", kernel.name());
175/// # Ok(())
176/// # }
177/// ```
178///
179/// # Launching
180///
181/// ```rust,no_run
182/// # use std::sync::Arc;
183/// # use oxicuda_driver::{Module, Stream, Context, Device};
184/// # use oxicuda_launch::{Kernel, LaunchParams};
185/// # fn main() -> oxicuda_driver::CudaResult<()> {
186/// # oxicuda_driver::init()?;
187/// # let dev = Device::get(0)?;
188/// # let ctx = Arc::new(Context::new(&dev)?);
189/// # let ptx = "";
190/// # let module = Arc::new(Module::from_ptx(ptx)?);
191/// # let kernel = Kernel::from_module(module, "my_kernel")?;
192/// let stream = Stream::new(&ctx)?;
193/// let params = LaunchParams::new(4u32, 256u32);
194/// kernel.launch(¶ms, &stream, &(42u32, 1024u32))?;
195/// # Ok(())
196/// # }
197/// ```
198pub struct Kernel {
199 /// The underlying CUDA function handle.
200 function: Function,
201 /// Keeps the parent module alive as long as this kernel exists.
202 _module: Arc<Module>,
203 /// The kernel function name (for debugging and diagnostics).
204 name: String,
205}
206
207impl Kernel {
208 /// Creates a new `Kernel` from a module and function name.
209 ///
210 /// Looks up the named function in the module. The `Arc<Module>` ensures
211 /// the module is not unloaded while this kernel exists.
212 ///
213 /// # Errors
214 ///
215 /// Returns [`CudaError::NotFound`](oxicuda_driver::CudaError::NotFound) if no
216 /// function with the given name exists in the module, or another
217 /// [`CudaError`](oxicuda_driver::CudaError) on driver failure.
218 pub fn from_module(module: Arc<Module>, name: &str) -> CudaResult<Self> {
219 let function = module.get_function(name)?;
220 Ok(Self {
221 function,
222 _module: module,
223 name: name.to_owned(),
224 })
225 }
226
227 /// Launches the kernel with the given parameters and arguments on a stream.
228 ///
229 /// This is the primary entry point for kernel execution. It calls
230 /// `cuLaunchKernel` with the specified grid/block dimensions, shared
231 /// memory, stream, and kernel arguments.
232 ///
233 /// The launch is asynchronous — it returns immediately and the kernel
234 /// executes on the GPU. Use [`Stream::synchronize`] to wait for completion.
235 ///
236 /// # Type safety
237 ///
238 /// The `args` parameter accepts any type implementing [`KernelArgs`],
239 /// including tuples of `Copy` types up to 24 elements. The caller is
240 /// responsible for ensuring the argument types match the kernel signature.
241 ///
242 /// # Errors
243 ///
244 /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the launch fails
245 /// (e.g., invalid dimensions, insufficient resources, driver error).
246 pub fn launch<A: KernelArgs>(
247 &self,
248 params: &LaunchParams,
249 stream: &Stream,
250 args: &A,
251 ) -> CudaResult<()> {
252 // Emit a tracing span for this kernel launch (no-op when the
253 // `tracing` feature is disabled).
254 let _span = KernelSpanGuard::enter(
255 &self.name,
256 (params.grid.x, params.grid.y, params.grid.z),
257 (params.block.x, params.block.y, params.block.z),
258 );
259
260 let driver = try_driver()?;
261 let mut param_ptrs = args.as_param_ptrs();
262 oxicuda_driver::error::check(unsafe {
263 (driver.cu_launch_kernel)(
264 self.function.raw(),
265 params.grid.x,
266 params.grid.y,
267 params.grid.z,
268 params.block.x,
269 params.block.y,
270 params.block.z,
271 params.shared_mem_bytes,
272 stream.raw(),
273 param_ptrs.as_mut_ptr(),
274 std::ptr::null_mut(),
275 )
276 })
277 }
278
279 /// Returns the kernel function name.
280 #[inline]
281 pub fn name(&self) -> &str {
282 &self.name
283 }
284
285 /// Returns a reference to the underlying [`Function`] handle.
286 ///
287 /// This can be used for occupancy queries and other function-level
288 /// operations provided by `oxicuda-driver`.
289 #[inline]
290 pub fn function(&self) -> &Function {
291 &self.function
292 }
293
294 /// Returns the maximum number of active blocks per streaming multiprocessor
295 /// for a given block size and dynamic shared memory.
296 ///
297 /// Delegates to [`Function::max_active_blocks_per_sm`].
298 ///
299 /// # Parameters
300 ///
301 /// * `block_size` — number of threads per block.
302 /// * `dynamic_smem` — dynamic shared memory per block in bytes.
303 ///
304 /// # Errors
305 ///
306 /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
307 pub fn max_active_blocks_per_sm(
308 &self,
309 block_size: i32,
310 dynamic_smem: usize,
311 ) -> CudaResult<i32> {
312 self.function
313 .max_active_blocks_per_sm(block_size, dynamic_smem)
314 }
315
316 /// Returns the optimal block size for this kernel and the minimum
317 /// grid size to achieve maximum occupancy.
318 ///
319 /// Delegates to [`Function::optimal_block_size`].
320 ///
321 /// Returns `(min_grid_size, optimal_block_size)`.
322 ///
323 /// # Parameters
324 ///
325 /// * `dynamic_smem` — dynamic shared memory per block in bytes.
326 ///
327 /// # Errors
328 ///
329 /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
330 pub fn optimal_block_size(&self, dynamic_smem: usize) -> CudaResult<(i32, i32)> {
331 self.function.optimal_block_size(dynamic_smem)
332 }
333}
334
335impl std::fmt::Debug for Kernel {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.debug_struct("Kernel")
338 .field("name", &self.name)
339 .field("function", &self.function)
340 .finish_non_exhaustive()
341 }
342}
343
344impl std::fmt::Display for Kernel {
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 write!(f, "Kernel({})", self.name)
347 }
348}
349
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unit_args_empty() {
        let args = ();
        let ptrs = args.as_param_ptrs();
        assert!(ptrs.is_empty());
    }

    #[test]
    fn single_arg_ptr_valid() {
        let args = (42u32,);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 1);
        // Verify the pointer actually points to the value.
        let val_ptr = ptrs[0] as *const u32;
        assert_eq!(unsafe { *val_ptr }, 42u32);
    }

    #[test]
    fn two_args_ptr_valid() {
        let args = (10u32, 20u64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 2);
        assert_eq!(unsafe { *(ptrs[0] as *const u32) }, 10u32);
        assert_eq!(unsafe { *(ptrs[1] as *const u64) }, 20u64);
    }

    #[test]
    fn four_args_ptr_valid() {
        let args = (1u32, 2u64, 3.0f32, 4.0f64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 4);
        assert_eq!(unsafe { *(ptrs[0] as *const u32) }, 1u32);
        assert_eq!(unsafe { *(ptrs[1] as *const u64) }, 2u64);
        assert!((unsafe { *(ptrs[2] as *const f32) } - 3.0f32).abs() < f32::EPSILON);
        assert!((unsafe { *(ptrs[3] as *const f64) } - 4.0f64).abs() < f64::EPSILON);
    }

    #[test]
    fn twelve_args_count() {
        let args = (
            1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 12);
        for (i, ptr) in ptrs.iter().enumerate() {
            let val = unsafe { *(*ptr as *const u32) };
            assert_eq!(val, (i as u32) + 1);
        }
    }

    #[test]
    fn max_arity_args_count() {
        // The largest tuple with a `KernelArgs` impl has 28 elements. This
        // pins the limit so the arity advertised in the docs (module docs and
        // `Kernel::launch`) cannot silently drift from the implementation.
        let args = (
            0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8, 11u8, 12u8, 13u8, 14u8,
            15u8, 16u8, 17u8, 18u8, 19u8, 20u8, 21u8, 22u8, 23u8, 24u8, 25u8, 26u8, 27u8,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 28);
        for (i, ptr) in ptrs.iter().enumerate() {
            let val = unsafe { *(*ptr as *const u8) };
            assert_eq!(val, i as u8);
        }
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only, E2E PTX chain parameter verification)
    // ---------------------------------------------------------------------------

    #[test]
    fn launch_params_grid_calculation_e2e() {
        // Given n = 1_048_576 (1M elements) and block_size = 256,
        // grid_size_for must return exactly 4096 (ceiling division).
        let n: u32 = 1_048_576;
        let block_size: u32 = 256;
        let grid = crate::grid::grid_size_for(n, block_size);
        assert_eq!(
            grid, 4096,
            "grid_size_for(1M, 256) must be 4096, got {grid}"
        );
        // Also verify via arithmetic: 1_048_576 / 256 == 4096 exactly
        assert_eq!(
            n % block_size,
            0,
            "n must be exactly divisible by block_size"
        );
    }

    #[test]
    fn launch_params_stores_grid_and_block() {
        // LaunchParams::new(4096, 256) must record grid==4096 and block==256.
        let params = LaunchParams::new(4096u32, 256u32);
        assert_eq!(
            params.grid.x, 4096,
            "grid.x must be 4096, got {}",
            params.grid.x
        );
        assert_eq!(
            params.block.x, 256,
            "block.x must be 256, got {}",
            params.block.x
        );
        assert_eq!(params.shared_mem_bytes, 0);
        // Total threads: 4096 * 256 = 1_048_576
        assert_eq!(params.total_threads(), 1_048_576);
    }

    #[test]
    fn named_args_builder_chain() {
        // ArgBuilder::new().add("a", &1u32).add("b", &2.0f32).build() must have length 2.
        use crate::named_args::ArgBuilder;
        let a: u32 = 1;
        let b: f32 = 2.0;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b);
        assert_eq!(
            builder.arg_count(),
            2,
            "ArgBuilder with 2 pushes must have length 2"
        );
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 2, "build() must return 2 pointers");
    }
}