oxicuda_launch/named_args.rs
1//! Named kernel arguments for enhanced debuggability and type safety.
2//!
3//! This module provides [`NamedKernelArgs`], a trait that extends
4//! [`KernelArgs`] with human-readable argument names, and [`ArgBuilder`],
5//! a builder that constructs argument pointer arrays with associated
6//! names for logging and debugging.
7//!
8//! # Motivation
9//!
10//! Standard kernel arguments are positional tuples with no names. When
11//! debugging kernel launches, it is helpful to know which argument
12//! corresponds to which kernel parameter. `NamedKernelArgs` bridges
13//! this gap by associating names with arguments.
14//!
15//! # Example
16//!
17//! ```rust
18//! use oxicuda_launch::named_args::ArgBuilder;
19//!
20//! let a_ptr: u64 = 0x1000;
21//! let n: u32 = 1024;
22//! let mut builder = ArgBuilder::new();
23//! builder.add("a_ptr", &a_ptr).add("n", &n);
24//! assert_eq!(builder.names(), &["a_ptr", "n"]);
25//! let ptrs = builder.build();
26//! assert_eq!(ptrs.len(), 2);
27//! ```
28
29use std::ffi::c_void;
30
31use crate::kernel::KernelArgs;
32
33// ---------------------------------------------------------------------------
34// NamedKernelArgs trait
35// ---------------------------------------------------------------------------
36
37/// Extension of [`KernelArgs`] that provides argument metadata.
38///
39/// Types implementing this trait can report the names and count of
40/// their kernel arguments, which is useful for debugging, logging,
41/// and validation.
42///
43/// # Safety
44///
45/// Implementors must uphold the same invariants as [`KernelArgs`].
46/// The names returned by `arg_names` must correspond one-to-one with
47/// the pointers returned by `as_param_ptrs`.
48pub unsafe trait NamedKernelArgs: KernelArgs {
49 /// Returns the names of all kernel arguments in order.
50 fn arg_names() -> &'static [&'static str];
51
52 /// Returns the number of kernel arguments.
53 fn arg_count() -> usize {
54 Self::arg_names().len()
55 }
56}
57
58// ---------------------------------------------------------------------------
59// ArgEntry
60// ---------------------------------------------------------------------------
61
/// An entry in the argument builder, holding a pointer and its name.
#[derive(Debug)]
struct ArgEntry {
    /// Pointer to the argument value.
    ptr: *mut c_void,
    /// Human-readable name for the argument.
    name: String,
}

// ---------------------------------------------------------------------------
// ArgBuilder
// ---------------------------------------------------------------------------

/// A builder for constructing named kernel argument arrays.
///
/// Collects typed argument values along with their names, then
/// produces the `Vec<*mut c_void>` array needed by `cuLaunchKernel`.
///
/// Only the *address* of each value is recorded — the builder neither copies
/// nor owns the values. Every value passed to [`add`](Self::add) must
/// therefore stay alive and in place until the kernel launch that uses the
/// built pointer array has completed.
///
/// # Example
///
/// ```rust
/// use oxicuda_launch::named_args::ArgBuilder;
///
/// let x: f32 = 2.5;
/// let n: u32 = 512;
/// let mut builder = ArgBuilder::new();
/// builder.add("x", &x).add("n", &n);
/// assert_eq!(builder.arg_count(), 2);
/// let ptrs = builder.build();
/// assert_eq!(ptrs.len(), 2);
/// ```
pub struct ArgBuilder {
    /// Collected argument entries, in insertion order.
    args: Vec<ArgEntry>,
}

impl ArgBuilder {
    /// Creates a new empty argument builder.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self { args: Vec::new() }
    }

    /// Creates a new argument builder with room for `capacity` arguments
    /// before reallocating.
    #[inline]
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            args: Vec::with_capacity(capacity),
        }
    }

    /// Adds a named argument to the builder.
    ///
    /// Only the address of `val` is stored. The caller must ensure that `val`
    /// remains valid (not moved or dropped) until the kernel launch
    /// using the built pointer array completes.
    ///
    /// Returns `&mut Self` for method chaining.
    pub fn add<T: Copy>(&mut self, name: &str, val: &T) -> &mut Self {
        self.args.push(ArgEntry {
            // The const → mut cast only matches the `void**` shape of the
            // launch parameter array; this builder never writes through it.
            ptr: val as *const T as *mut c_void,
            name: name.to_owned(),
        });
        self
    }

    /// Builds the argument pointer array for `cuLaunchKernel`.
    ///
    /// Returns the raw pointers in insertion order. The builder (and its
    /// names) is consumed; call [`names`](Self::names) before `build` if you
    /// need them.
    #[must_use]
    pub fn build(self) -> Vec<*mut c_void> {
        self.args.into_iter().map(|entry| entry.ptr).collect()
    }

    /// Returns the names of all added arguments in insertion order.
    #[must_use]
    pub fn names(&self) -> Vec<&str> {
        self.args.iter().map(|entry| entry.name.as_str()).collect()
    }

    /// Returns the number of arguments added so far.
    #[inline]
    #[must_use]
    pub fn arg_count(&self) -> usize {
        self.args.len()
    }

    /// Returns `true` if no arguments have been added.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.args.is_empty()
    }

    /// Returns a human-readable summary of the arguments, formatted as
    /// `ArgBuilder[name=0x..., name=0x...]` (`ArgBuilder[]` when empty).
    #[must_use]
    pub fn summary(&self) -> String {
        use std::fmt::Write as _;

        // Build the summary in a single buffer instead of allocating one
        // `String` per entry and joining them afterwards.
        let mut out = String::from("ArgBuilder[");
        for (index, entry) in self.args.iter().enumerate() {
            if index > 0 {
                out.push_str(", ");
            }
            // Formatting into a `String` cannot fail.
            let _ = write!(out, "{}={:p}", entry.name, entry.ptr);
        }
        out.push(']');
        out
    }
}

impl Default for ArgBuilder {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for ArgBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ArgBuilder")
            .field("count", &self.args.len())
            .field("names", &self.names())
            .finish()
    }
}
179
180// Implement NamedKernelArgs for () (no arguments).
181//
182// SAFETY: Returns an empty name array, consistent with the empty
183// pointer array from KernelArgs for ().
184unsafe impl NamedKernelArgs for () {
185 fn arg_names() -> &'static [&'static str] {
186 &[]
187 }
188}
189
190// ---------------------------------------------------------------------------
191// FixedNamedArgs — stack-allocated named argument array
192// ---------------------------------------------------------------------------
193
/// A single named argument entry: a static string name paired with a raw
/// const pointer to an argument value owned by the caller.
///
/// The pointer is `*const c_void` so no allocation occurs; it is the
/// caller's responsibility to ensure the pointed-to value outlives the
/// kernel launch.
#[derive(Debug, Clone, Copy)]
pub struct NamedArgEntry {
    /// The human-readable name of this kernel argument.
    pub name: &'static str,
    /// Raw const pointer to the argument value.
    pub ptr: *const c_void,
}

// SAFETY: NamedArgEntry only holds a raw pointer to a value that lives
// at least as long as the kernel launch (caller guarantees this).
// The pointer is never dereferenced inside this module.
unsafe impl Send for NamedArgEntry {}
unsafe impl Sync for NamedArgEntry {}

/// A const-generic, stack-allocated array of named kernel arguments.
///
/// `FixedNamedArgs<N>` stores exactly `N` [`NamedArgEntry`] values inline,
/// with no heap allocation. It has the same size as `[NamedArgEntry; N]`.
///
/// # Invariants
///
/// Every pointed-to value must remain valid (not moved or dropped) until
/// after the kernel launch that consumes this struct completes.
///
/// # Example
///
/// ```rust
/// use oxicuda_launch::named_args::{FixedNamedArgs, NamedArgEntry};
///
/// let n: u32 = 1024;
/// let alpha: f32 = 2.0;
///
/// let args = FixedNamedArgs::new([
///     NamedArgEntry { name: "n", ptr: &n as *const u32 as *const std::ffi::c_void },
///     NamedArgEntry { name: "alpha", ptr: &alpha as *const f32 as *const std::ffi::c_void },
/// ]);
///
/// assert_eq!(args.len(), 2);
/// assert_eq!(args.names()[0], "n");
/// assert_eq!(args.names()[1], "alpha");
/// ```
pub struct FixedNamedArgs<const N: usize> {
    /// The argument entries, stored inline.
    entries: [NamedArgEntry; N],
}

impl<const N: usize> FixedNamedArgs<N> {
    /// Creates a new `FixedNamedArgs` from an array of [`NamedArgEntry`] values.
    #[inline]
    #[must_use]
    pub const fn new(entries: [NamedArgEntry; N]) -> Self {
        Self { entries }
    }

    /// Returns the number of arguments (always `N`).
    #[inline]
    #[must_use]
    pub const fn len(&self) -> usize {
        N
    }

    /// Returns `true` if there are no arguments (`N == 0`).
    #[inline]
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        N == 0
    }

    /// Returns the argument names in declaration order.
    #[must_use]
    pub fn names(&self) -> [&'static str; N] {
        // `NamedArgEntry` is `Copy`, so mapping over a copy of the array is
        // cheap and avoids an index-based fill loop.
        self.entries.map(|entry| entry.name)
    }

    /// Returns an array of `*mut c_void` pointers suitable for
    /// passing directly to `cuLaunchKernel` as the `kernelParams` array.
    ///
    /// # Pointer validity
    ///
    /// Calling this function is safe; it only copies addresses. The returned
    /// pointers are valid to dereference only as long as the original values
    /// used to construct the entries are still alive and have not moved.
    #[must_use]
    pub fn as_ptr_array(&self) -> [*mut c_void; N] {
        // Cast const → mut: cuLaunchKernel's ABI takes `void**` but
        // never mutates the argument values.
        self.entries.map(|entry| entry.ptr as *mut c_void)
    }

    /// Returns a reference to the underlying fixed-size entry array, for
    /// inspection.
    #[inline]
    #[must_use]
    pub fn entries(&self) -> &[NamedArgEntry; N] {
        &self.entries
    }
}

impl<const N: usize> std::fmt::Debug for FixedNamedArgs<N> {
    /// Prints `len` followed by one field per entry, keyed by the entry name.
    /// Entry names are used verbatim as field names, so duplicate names (or an
    /// entry literally named `len`) appear more than once.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut ds = f.debug_struct("FixedNamedArgs");
        ds.field("len", &N);
        for entry in &self.entries {
            ds.field(entry.name, &entry.ptr);
        }
        ds.finish()
    }
}
309
310// ---------------------------------------------------------------------------
311// Tests
312// ---------------------------------------------------------------------------
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn arg_builder_new_empty() {
        let builder = ArgBuilder::new();
        assert!(builder.is_empty());
        assert_eq!(builder.arg_count(), 0);
    }

    #[test]
    fn arg_builder_add_and_names() {
        let x: u32 = 42;
        let y: f64 = 2.78;
        let mut builder = ArgBuilder::new();
        builder.add("x", &x).add("y", &y);
        assert_eq!(builder.arg_count(), 2);
        assert_eq!(builder.names(), vec!["x", "y"]);
    }

    #[test]
    fn arg_builder_build_pointer_count() {
        let a: u64 = 0x1000;
        let b: u32 = 512;
        let c: f32 = 1.0;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b).add("c", &c);
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 3);
    }

    #[test]
    fn arg_builder_build_pointers_valid() {
        let val: u32 = 99;
        let mut builder = ArgBuilder::new();
        builder.add("val", &val);
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 1);
        // SAFETY: `ptrs[0]` is the address of `val`, a live `u32`.
        let read_back = unsafe { *(ptrs[0] as *const u32) };
        assert_eq!(read_back, 99);
    }

    #[test]
    fn arg_builder_summary() {
        let x: u32 = 10;
        let mut builder = ArgBuilder::new();
        builder.add("x", &x);
        let s = builder.summary();
        assert!(s.starts_with("ArgBuilder["), "summary: {s}");
        assert!(s.contains("x="), "summary: {s}");
        assert!(s.ends_with(']'), "summary: {s}");
    }

    #[test]
    fn arg_builder_summary_empty() {
        assert_eq!(ArgBuilder::new().summary(), "ArgBuilder[]");
    }

    #[test]
    fn arg_builder_debug() {
        let builder = ArgBuilder::new();
        let dbg = format!("{builder:?}");
        assert!(dbg.contains("ArgBuilder"));
        assert!(dbg.contains("count"));
    }

    #[test]
    fn arg_builder_default() {
        let builder = ArgBuilder::default();
        assert!(builder.is_empty());
    }

    #[test]
    fn arg_builder_with_capacity() {
        let builder = ArgBuilder::with_capacity(8);
        assert!(builder.is_empty());
        assert_eq!(builder.arg_count(), 0);
    }

    #[test]
    fn named_kernel_args_trait_exists() {
        // Exercise the trait through a generic bound; `()` is an implementor.
        fn assert_named<T: NamedKernelArgs>() -> (usize, usize) {
            (T::arg_names().len(), T::arg_count())
        }
        assert_eq!(assert_named::<()>(), (0, 0));
    }

    // -----------------------------------------------------------------------
    // FixedNamedArgs tests
    // -----------------------------------------------------------------------

    #[test]
    fn fixed_named_args_zero_size() {
        let args: FixedNamedArgs<0> = FixedNamedArgs::new([]);
        assert!(args.is_empty());
        assert_eq!(args.len(), 0);
        let ptrs = args.as_ptr_array();
        assert_eq!(ptrs.len(), 0);
    }

    #[test]
    fn fixed_named_args_single_u32() {
        let n: u32 = 1024;
        let args = FixedNamedArgs::new([NamedArgEntry {
            name: "n",
            ptr: &n as *const u32 as *const c_void,
        }]);
        assert_eq!(args.len(), 1);
        assert!(!args.is_empty());
        assert_eq!(args.names(), ["n"]);

        let ptrs = args.as_ptr_array();
        assert_eq!(ptrs.len(), 1);
        // SAFETY: `ptrs[0]` is the address of `n`, a live `u32`.
        let read_back = unsafe { *(ptrs[0] as *const u32) };
        assert_eq!(read_back, 1024);
    }

    #[test]
    fn fixed_named_args_two_entries_order_preserved() {
        let n: u32 = 512;
        let alpha: f32 = std::f32::consts::PI;
        let args = FixedNamedArgs::new([
            NamedArgEntry {
                name: "n",
                ptr: &n as *const u32 as *const c_void,
            },
            NamedArgEntry {
                name: "alpha",
                ptr: &alpha as *const f32 as *const c_void,
            },
        ]);

        assert_eq!(args.len(), 2);
        let names = args.names();
        assert_eq!(names[0], "n");
        assert_eq!(names[1], "alpha");

        let ptrs = args.as_ptr_array();
        // SAFETY: the pointers address the live locals `n` (u32) and `alpha` (f32).
        let n_back = unsafe { *(ptrs[0] as *const u32) };
        let a_back = unsafe { *(ptrs[1] as *const f32) };
        assert_eq!(n_back, 512);
        assert!((a_back - std::f32::consts::PI).abs() < f32::EPSILON);
    }

    #[test]
    fn fixed_named_args_no_size_overhead_vs_entry_array() {
        use std::mem::size_of;
        // FixedNamedArgs<N> must have the same size as [NamedArgEntry; N].
        assert_eq!(
            size_of::<FixedNamedArgs<4>>(),
            size_of::<[NamedArgEntry; 4]>(),
            "FixedNamedArgs<4> must not add any size overhead"
        );
        assert_eq!(
            size_of::<FixedNamedArgs<1>>(),
            size_of::<[NamedArgEntry; 1]>(),
            "FixedNamedArgs<1> must not add any size overhead"
        );
    }

    #[test]
    fn fixed_named_args_ptr_array_length_matches_n() {
        let a: u64 = 0x1000_0000;
        let b: u32 = 256;
        let c: f64 = 1.0;
        let args = FixedNamedArgs::new([
            NamedArgEntry {
                name: "a",
                ptr: &a as *const u64 as *const c_void,
            },
            NamedArgEntry {
                name: "b",
                ptr: &b as *const u32 as *const c_void,
            },
            NamedArgEntry {
                name: "c",
                ptr: &c as *const f64 as *const c_void,
            },
        ]);
        let ptrs = args.as_ptr_array();
        assert_eq!(ptrs.len(), 3);
    }

    #[test]
    fn fixed_named_args_debug_contains_len() {
        let n: u32 = 42;
        let args = FixedNamedArgs::new([NamedArgEntry {
            name: "n",
            ptr: &n as *const u32 as *const c_void,
        }]);
        let dbg = format!("{args:?}");
        assert!(dbg.contains("FixedNamedArgs"), "Debug output: {dbg}");
        // The length field and each named entry must both be present.
        assert!(dbg.contains("len: 1"), "Debug output: {dbg}");
        assert!(dbg.contains("n: 0x"), "Debug output: {dbg}");
    }

    #[test]
    fn fixed_named_args_entries_accessor() {
        let x: f32 = 7.0;
        let args = FixedNamedArgs::new([NamedArgEntry {
            name: "x",
            ptr: &x as *const f32 as *const c_void,
        }]);
        let entries = args.entries();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].name, "x");
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only)
    // ---------------------------------------------------------------------------

    #[test]
    fn named_kernel_args_empty() {
        // NamedKernelArgs for () has no args: arg_names is empty, arg_count is 0.
        let names = <() as NamedKernelArgs>::arg_names();
        assert!(names.is_empty(), "() must have no arg names");
        let count = <() as NamedKernelArgs>::arg_count();
        assert_eq!(count, 0, "() must have arg_count == 0");
    }

    #[test]
    fn named_kernel_args_add_and_count() {
        // After adding 3 named args to ArgBuilder, arg_count() == 3.
        let a: u32 = 1;
        let b: f64 = 2.0;
        let c: u64 = 3;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b).add("c", &c);
        assert_eq!(
            builder.arg_count(),
            3,
            "ArgBuilder with 3 args must report arg_count == 3"
        );
        // Names must be in insertion order.
        let names = builder.names();
        assert_eq!(names[0], "a");
        assert_eq!(names[1], "b");
        assert_eq!(names[2], "c");
    }
}