oxicuda_launch/kernel.rs
1//! Type-safe GPU kernel management and argument passing.
2//!
3//! This module provides the [`Kernel`] struct for launching GPU kernels
4//! and the [`KernelArgs`] trait for type-safe argument passing to CUDA
5//! kernel functions.
6//!
7//! # Architecture
8//!
9//! A [`Kernel`] wraps a [`Function`] handle and holds an `Arc<Module>`
10//! to ensure the PTX module remains loaded for the kernel's lifetime.
11//! Arguments are passed via the [`KernelArgs`] trait, which converts
12//! typed Rust values into the `*mut c_void` array that `cuLaunchKernel`
13//! expects.
14//!
15//! # Tuple arguments
16//!
17//! The [`KernelArgs`] trait is implemented for tuples of `Copy` types
18//! up to 24 elements. Each element must be `Copy` because kernel
19//! arguments are passed by value to the GPU.
20//!
21//! # Example
22//!
23//! ```rust,no_run
24//! # use std::sync::Arc;
25//! # use oxicuda_driver::{Module, Stream, Context, Device};
26//! # use oxicuda_launch::{Kernel, LaunchParams, Dim3};
27//! # fn main() -> oxicuda_driver::CudaResult<()> {
28//! # oxicuda_driver::init()?;
29//! # let dev = Device::get(0)?;
30//! # let ctx = Arc::new(Context::new(&dev)?);
31//! # let ptx = "";
32//! let module = Arc::new(Module::from_ptx(ptx)?);
33//! let kernel = Kernel::from_module(module, "vector_add")?;
34//!
35//! let stream = Stream::new(&ctx)?;
36//! let params = LaunchParams::new(4u32, 256u32);
37//!
38//! // Launch with typed arguments: (a_ptr, b_ptr, c_ptr, n)
39//! let args = (0u64, 0u64, 0u64, 1024u32);
//! kernel.launch(&params, &stream, &args)?;
41//! # Ok(())
42//! # }
43//! ```
44
45use std::ffi::c_void;
46use std::sync::Arc;
47
48use oxicuda_driver::error::CudaResult;
49use oxicuda_driver::loader::try_driver;
50use oxicuda_driver::module::{Function, Module};
51use oxicuda_driver::stream::Stream;
52
53use crate::params::LaunchParams;
54use crate::trace::KernelSpanGuard;
55
56// ---------------------------------------------------------------------------
57// KernelArgs trait
58// ---------------------------------------------------------------------------
59
/// A value that can supply the parameter list for a kernel launch.
///
/// `cuLaunchKernel` receives kernel parameters as an array of untyped
/// pointers, one per parameter. Each pointer refers to the argument value
/// in host memory; the CUDA driver copies the values to the GPU before the
/// kernel executes.
///
/// # Safety
///
/// An implementation must guarantee all of the following:
/// - every pointer produced by `as_param_ptrs` refers to a valid argument
///   value;
/// - those values stay valid for the duration of the kernel launch, that is,
///   until `cuLaunchKernel` returns;
/// - the argument types and sizes match what the kernel expects.
pub unsafe trait KernelArgs {
    /// Builds the pointer array handed to `cuLaunchKernel`.
    ///
    /// The returned `Vec` holds one entry per kernel argument, each entry
    /// pointing at that argument's value. The CUDA driver reads the value
    /// through the pointer and copies it to the GPU.
    fn as_param_ptrs(&self) -> Vec<*mut c_void>;
}
82
83// ---------------------------------------------------------------------------
84// KernelArgs — unit type (no arguments)
85// ---------------------------------------------------------------------------
86
87/// Implementation for kernels that take no arguments.
88///
89/// # Safety
90///
91/// Returns an empty pointer array, which is valid for zero-argument kernels.
92unsafe impl KernelArgs for () {
93 #[inline]
94 fn as_param_ptrs(&self) -> Vec<*mut c_void> {
95 Vec::new()
96 }
97}
98
99// ---------------------------------------------------------------------------
100// KernelArgs — tuple implementations via macro
101// ---------------------------------------------------------------------------
102
103/// Generates [`KernelArgs`] implementations for tuples of `Copy` types.
104///
105/// Each tuple element is converted to a `*mut c_void` by taking
106/// a reference to the element and casting through `*const T`.
107macro_rules! impl_kernel_args_tuple {
108 ($($idx:tt: $T:ident),+) => {
109 /// # Safety
110 ///
111 /// The pointers returned point into `self`, which must remain
112 /// valid (i.e., not moved or dropped) until `cuLaunchKernel` returns.
113 unsafe impl<$($T: Copy),+> KernelArgs for ($($T,)+) {
114 #[inline]
115 fn as_param_ptrs(&self) -> Vec<*mut c_void> {
116 vec![
117 $(&self.$idx as *const $T as *mut c_void,)+
118 ]
119 }
120 }
121 };
122}
123
124impl_kernel_args_tuple!(0: A);
125impl_kernel_args_tuple!(0: A, 1: B);
126impl_kernel_args_tuple!(0: A, 1: B, 2: C);
127impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D);
128impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E);
129impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
130impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G);
131impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H);
132impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I);
133impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J);
134impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K);
135impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L);
136impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M);
137impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N);
138impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O);
139impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P);
140impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q);
141impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R);
142impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S);
143impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T);
144impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U);
145impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V);
146impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W);
147impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X);
148
149// ---------------------------------------------------------------------------
150// Kernel struct
151// ---------------------------------------------------------------------------
152
153/// A launchable GPU kernel with module lifetime management.
154///
155/// Holds an `Arc<Module>` to ensure the PTX module remains loaded
156/// as long as any `Kernel` references it. This is important because
157/// [`Function`] handles become invalid once their parent module is
158/// unloaded.
159///
160/// # Creating a kernel
161///
162/// ```rust,no_run
163/// # use std::sync::Arc;
164/// # use oxicuda_driver::Module;
165/// # use oxicuda_launch::Kernel;
166/// # fn main() -> oxicuda_driver::CudaResult<()> {
167/// # let ptx = "";
168/// let module = Arc::new(Module::from_ptx(ptx)?);
169/// let kernel = Kernel::from_module(module, "my_kernel")?;
170/// println!("loaded kernel: {}", kernel.name());
171/// # Ok(())
172/// # }
173/// ```
174///
175/// # Launching
176///
177/// ```rust,no_run
178/// # use std::sync::Arc;
179/// # use oxicuda_driver::{Module, Stream, Context, Device};
180/// # use oxicuda_launch::{Kernel, LaunchParams};
181/// # fn main() -> oxicuda_driver::CudaResult<()> {
182/// # oxicuda_driver::init()?;
183/// # let dev = Device::get(0)?;
184/// # let ctx = Arc::new(Context::new(&dev)?);
185/// # let ptx = "";
186/// # let module = Arc::new(Module::from_ptx(ptx)?);
187/// # let kernel = Kernel::from_module(module, "my_kernel")?;
188/// let stream = Stream::new(&ctx)?;
189/// let params = LaunchParams::new(4u32, 256u32);
190/// kernel.launch(¶ms, &stream, &(42u32, 1024u32))?;
191/// # Ok(())
192/// # }
193/// ```
194pub struct Kernel {
195 /// The underlying CUDA function handle.
196 function: Function,
197 /// Keeps the parent module alive as long as this kernel exists.
198 _module: Arc<Module>,
199 /// The kernel function name (for debugging and diagnostics).
200 name: String,
201}
202
203impl Kernel {
204 /// Creates a new `Kernel` from a module and function name.
205 ///
206 /// Looks up the named function in the module. The `Arc<Module>` ensures
207 /// the module is not unloaded while this kernel exists.
208 ///
209 /// # Errors
210 ///
211 /// Returns [`CudaError::NotFound`](oxicuda_driver::CudaError::NotFound) if no
212 /// function with the given name exists in the module, or another
213 /// [`CudaError`](oxicuda_driver::CudaError) on driver failure.
214 pub fn from_module(module: Arc<Module>, name: &str) -> CudaResult<Self> {
215 let function = module.get_function(name)?;
216 Ok(Self {
217 function,
218 _module: module,
219 name: name.to_owned(),
220 })
221 }
222
223 /// Launches the kernel with the given parameters and arguments on a stream.
224 ///
225 /// This is the primary entry point for kernel execution. It calls
226 /// `cuLaunchKernel` with the specified grid/block dimensions, shared
227 /// memory, stream, and kernel arguments.
228 ///
229 /// The launch is asynchronous — it returns immediately and the kernel
230 /// executes on the GPU. Use [`Stream::synchronize`] to wait for completion.
231 ///
232 /// # Type safety
233 ///
234 /// The `args` parameter accepts any type implementing [`KernelArgs`],
235 /// including tuples of `Copy` types up to 24 elements. The caller is
236 /// responsible for ensuring the argument types match the kernel signature.
237 ///
238 /// # Errors
239 ///
240 /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the launch fails
241 /// (e.g., invalid dimensions, insufficient resources, driver error).
242 pub fn launch<A: KernelArgs>(
243 &self,
244 params: &LaunchParams,
245 stream: &Stream,
246 args: &A,
247 ) -> CudaResult<()> {
248 // Emit a tracing span for this kernel launch (no-op when the
249 // `tracing` feature is disabled).
250 let _span = KernelSpanGuard::enter(
251 &self.name,
252 (params.grid.x, params.grid.y, params.grid.z),
253 (params.block.x, params.block.y, params.block.z),
254 );
255
256 let driver = try_driver()?;
257 let mut param_ptrs = args.as_param_ptrs();
258 oxicuda_driver::error::check(unsafe {
259 (driver.cu_launch_kernel)(
260 self.function.raw(),
261 params.grid.x,
262 params.grid.y,
263 params.grid.z,
264 params.block.x,
265 params.block.y,
266 params.block.z,
267 params.shared_mem_bytes,
268 stream.raw(),
269 param_ptrs.as_mut_ptr(),
270 std::ptr::null_mut(),
271 )
272 })
273 }
274
275 /// Returns the kernel function name.
276 #[inline]
277 pub fn name(&self) -> &str {
278 &self.name
279 }
280
281 /// Returns a reference to the underlying [`Function`] handle.
282 ///
283 /// This can be used for occupancy queries and other function-level
284 /// operations provided by `oxicuda-driver`.
285 #[inline]
286 pub fn function(&self) -> &Function {
287 &self.function
288 }
289
290 /// Returns the maximum number of active blocks per streaming multiprocessor
291 /// for a given block size and dynamic shared memory.
292 ///
293 /// Delegates to [`Function::max_active_blocks_per_sm`].
294 ///
295 /// # Parameters
296 ///
297 /// * `block_size` — number of threads per block.
298 /// * `dynamic_smem` — dynamic shared memory per block in bytes.
299 ///
300 /// # Errors
301 ///
302 /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
303 pub fn max_active_blocks_per_sm(
304 &self,
305 block_size: i32,
306 dynamic_smem: usize,
307 ) -> CudaResult<i32> {
308 self.function
309 .max_active_blocks_per_sm(block_size, dynamic_smem)
310 }
311
312 /// Returns the optimal block size for this kernel and the minimum
313 /// grid size to achieve maximum occupancy.
314 ///
315 /// Delegates to [`Function::optimal_block_size`].
316 ///
317 /// Returns `(min_grid_size, optimal_block_size)`.
318 ///
319 /// # Parameters
320 ///
321 /// * `dynamic_smem` — dynamic shared memory per block in bytes.
322 ///
323 /// # Errors
324 ///
325 /// Returns a [`CudaError`](oxicuda_driver::CudaError) if the query fails.
326 pub fn optimal_block_size(&self, dynamic_smem: usize) -> CudaResult<(i32, i32)> {
327 self.function.optimal_block_size(dynamic_smem)
328 }
329}
330
331impl std::fmt::Debug for Kernel {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 f.debug_struct("Kernel")
334 .field("name", &self.name)
335 .field("function", &self.function)
336 .finish_non_exhaustive()
337 }
338}
339
340impl std::fmt::Display for Kernel {
341 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 write!(f, "Kernel({})", self.name)
343 }
344}
345
346// ---------------------------------------------------------------------------
347// Tests
348// ---------------------------------------------------------------------------
349
#[cfg(test)]
mod tests {
    use super::*;

    /// The unit type describes a kernel without parameters.
    #[test]
    fn unit_args_empty() {
        assert!(().as_param_ptrs().is_empty());
    }

    /// A one-element tuple yields one pointer that reads back the value.
    #[test]
    fn single_arg_ptr_valid() {
        let args = (42u32,);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 1);
        // SAFETY: the pointer refers to `args.0`, a live `u32`.
        let value = unsafe { ptrs[0].cast::<u32>().read() };
        assert_eq!(value, 42u32);
    }

    /// Pointers for a mixed-width pair read back in field order.
    #[test]
    fn two_args_ptr_valid() {
        let args = (10u32, 20u64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 2);
        // SAFETY: each pointer refers to the tuple field of the stated type.
        let (first, second) = unsafe { (ptrs[0].cast::<u32>().read(), ptrs[1].cast::<u64>().read()) };
        assert_eq!(first, 10u32);
        assert_eq!(second, 20u64);
    }

    /// Integer and floating-point fields are all reachable via their pointers.
    #[test]
    fn four_args_ptr_valid() {
        let args = (1u32, 2u64, 3.0f32, 4.0f64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 4);
        // SAFETY: each pointer refers to the tuple field of the stated type.
        let (first, second, third, fourth) = unsafe {
            (
                ptrs[0].cast::<u32>().read(),
                ptrs[1].cast::<u64>().read(),
                ptrs[2].cast::<f32>().read(),
                ptrs[3].cast::<f64>().read(),
            )
        };
        assert_eq!(first, 1u32);
        assert_eq!(second, 2u64);
        assert!((third - 3.0f32).abs() < f32::EPSILON);
        assert!((fourth - 4.0f64).abs() < f64::EPSILON);
    }

    /// A twelve-element tuple yields twelve pointers, in field order.
    #[test]
    fn twelve_args_count() {
        let args = (
            1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 12);
        for (expected, ptr) in (1u32..).zip(ptrs.iter()) {
            // SAFETY: every field of `args` is a live `u32`.
            let val = unsafe { (*ptr).cast::<u32>().read() };
            assert_eq!(val, expected);
        }
    }

    // -----------------------------------------------------------------------
    // Quality gate tests (CPU-only, E2E PTX chain parameter verification)
    // -----------------------------------------------------------------------

    /// 1M elements at 256 threads per block needs exactly 4096 blocks.
    #[test]
    fn launch_params_grid_calculation_e2e() {
        let block_size: u32 = 256;
        let n: u32 = 1_048_576;
        let grid = crate::grid::grid_size_for(n, block_size);
        assert_eq!(
            grid, 4096,
            "grid_size_for(1M, 256) must be 4096, got {grid}"
        );
        // Cross-check the premise: the division has no remainder, so the
        // ceiling division above is exact.
        assert_eq!(
            n % block_size,
            0,
            "n must be exactly divisible by block_size"
        );
    }

    /// `LaunchParams::new(4096, 256)` records the grid and block it was given.
    #[test]
    fn launch_params_stores_grid_and_block() {
        let params = LaunchParams::new(4096u32, 256u32);
        let (grid_x, block_x) = (params.grid.x, params.block.x);
        assert_eq!(grid_x, 4096, "grid.x must be 4096, got {}", params.grid.x);
        assert_eq!(block_x, 256, "block.x must be 256, got {}", params.block.x);
        assert_eq!(params.shared_mem_bytes, 0);
        // 4096 blocks * 256 threads = 1_048_576 threads in total.
        assert_eq!(params.total_threads(), 1_048_576);
    }

    /// Two chained `add` calls on an `ArgBuilder` produce two pointers.
    #[test]
    fn named_args_builder_chain() {
        use crate::named_args::ArgBuilder;

        let a: u32 = 1;
        let b: f32 = 2.0;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b);
        assert_eq!(
            builder.arg_count(),
            2,
            "ArgBuilder with 2 pushes must have length 2"
        );
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 2, "build() must return 2 pointers");
    }
}