Skip to main content

oxicuda_launch/
named_args.rs

1//! Named kernel arguments for enhanced debuggability and type safety.
2//!
3//! This module provides [`NamedKernelArgs`], a trait that extends
4//! [`KernelArgs`] with human-readable argument names, and [`ArgBuilder`],
5//! a builder that constructs argument pointer arrays with associated
6//! names for logging and debugging.
7//!
8//! # Motivation
9//!
10//! Standard kernel arguments are positional tuples with no names. When
11//! debugging kernel launches, it is helpful to know which argument
12//! corresponds to which kernel parameter. `NamedKernelArgs` bridges
13//! this gap by associating names with arguments.
14//!
15//! # Example
16//!
17//! ```rust
18//! use oxicuda_launch::named_args::ArgBuilder;
19//!
20//! let a_ptr: u64 = 0x1000;
21//! let n: u32 = 1024;
22//! let mut builder = ArgBuilder::new();
23//! builder.add("a_ptr", &a_ptr).add("n", &n);
24//! assert_eq!(builder.names(), &["a_ptr", "n"]);
25//! let ptrs = builder.build();
26//! assert_eq!(ptrs.len(), 2);
27//! ```
28
29use std::ffi::c_void;
30
31use crate::kernel::KernelArgs;
32
33// ---------------------------------------------------------------------------
34// NamedKernelArgs trait
35// ---------------------------------------------------------------------------
36
37/// Extension of [`KernelArgs`] that provides argument metadata.
38///
39/// Types implementing this trait can report the names and count of
40/// their kernel arguments, which is useful for debugging, logging,
41/// and validation.
42///
43/// # Safety
44///
45/// Implementors must uphold the same invariants as [`KernelArgs`].
46/// The names returned by `arg_names` must correspond one-to-one with
47/// the pointers returned by `as_param_ptrs`.
48pub unsafe trait NamedKernelArgs: KernelArgs {
49    /// Returns the names of all kernel arguments in order.
50    fn arg_names() -> &'static [&'static str];
51
52    /// Returns the number of kernel arguments.
53    fn arg_count() -> usize {
54        Self::arg_names().len()
55    }
56}
57
58// ---------------------------------------------------------------------------
59// ArgEntry
60// ---------------------------------------------------------------------------
61
/// One argument recorded by [`ArgBuilder`]: the type-erased address of the
/// value together with the name it was registered under.
#[derive(Debug)]
struct ArgEntry {
    /// Type-erased address of the argument value.
    ptr: *mut c_void,
    /// Name used when reporting this argument.
    name: String,
}

// ---------------------------------------------------------------------------
// ArgBuilder
// ---------------------------------------------------------------------------

/// Incrementally assembles a kernel parameter array in which every argument
/// carries a name.
///
/// Each call to [`add`](Self::add) records the address of a typed value and
/// its name. [`build`](Self::build) then yields the `Vec<*mut c_void>` that
/// `cuLaunchKernel` expects.
///
/// # Example
///
/// ```rust
/// use oxicuda_launch::named_args::ArgBuilder;
///
/// let x: f32 = 3.14;
/// let n: u32 = 512;
/// let mut builder = ArgBuilder::new();
/// builder.add("x", &x).add("n", &n);
/// assert_eq!(builder.arg_count(), 2);
/// let ptrs = builder.build();
/// assert_eq!(ptrs.len(), 2);
/// ```
pub struct ArgBuilder {
    /// Recorded entries, in insertion order.
    args: Vec<ArgEntry>,
}

impl ArgBuilder {
    /// Returns a builder that holds no arguments yet.
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Returns an empty builder with room reserved for `capacity` arguments.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let args = Vec::with_capacity(capacity);
        Self { args }
    }

    /// Records `val` under `name` and returns `self` so calls can be chained.
    ///
    /// Only the address of `val` is kept. The caller must keep `val` alive
    /// and in place (not moved or dropped) until the kernel launch that uses
    /// the built pointer array has completed. The builder itself does not tie
    /// the borrow's lifetime to the pointers it hands out.
    pub fn add<T: Copy>(&mut self, name: &str, val: &T) -> &mut Self {
        let addr: *const T = val;
        let entry = ArgEntry {
            ptr: addr as *mut c_void,
            name: String::from(name),
        };
        self.args.push(entry);
        self
    }

    /// Consumes the builder and returns the parameter pointers in insertion
    /// order, ready to be handed to `cuLaunchKernel`.
    ///
    /// The names are discarded. Call [`names`](Self::names) first if they are
    /// still needed.
    pub fn build(self) -> Vec<*mut c_void> {
        self.args
            .into_iter()
            .map(|ArgEntry { ptr, .. }| ptr)
            .collect()
    }

    /// Lists the argument names in the order they were added.
    pub fn names(&self) -> Vec<&str> {
        self.args
            .iter()
            .map(|ArgEntry { name, .. }| name.as_str())
            .collect()
    }

    /// Number of arguments recorded so far.
    #[inline]
    pub fn arg_count(&self) -> usize {
        self.args.len()
    }

    /// `true` when nothing has been added yet.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.arg_count() == 0
    }

    /// Renders the arguments as `ArgBuilder[name=0x..., ...]`.
    pub fn summary(&self) -> String {
        use std::fmt::Write as _;

        let mut out = String::from("ArgBuilder[");
        for (idx, entry) in self.args.iter().enumerate() {
            if idx > 0 {
                out.push_str(", ");
            }
            // Formatting into a `String` cannot fail.
            let _ = write!(out, "{}={:p}", entry.name, entry.ptr);
        }
        out.push(']');
        out
    }
}

impl Default for ArgBuilder {
    #[inline]
    fn default() -> Self {
        ArgBuilder::new()
    }
}

impl std::fmt::Debug for ArgBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.arg_count();
        let names = self.names();
        f.debug_struct("ArgBuilder")
            .field("count", &count)
            .field("names", &names)
            .finish()
    }
}
179
180// Implement NamedKernelArgs for () (no arguments).
181//
182// SAFETY: Returns an empty name array, consistent with the empty
183// pointer array from KernelArgs for ().
184unsafe impl NamedKernelArgs for () {
185    fn arg_names() -> &'static [&'static str] {
186        &[]
187    }
188}
189
190// ---------------------------------------------------------------------------
191// FixedNamedArgs — stack-allocated named argument array
192// ---------------------------------------------------------------------------
193
/// A kernel argument paired with its name.
///
/// The entry holds a `'static` label and a type-erased const pointer to a
/// value the caller already owns (a value on the caller's stack).
///
/// Because only a `*const c_void` is stored, no allocation takes place.
/// Keeping the pointed-to value alive until the kernel launch has finished
/// is the caller's responsibility.
#[derive(Clone, Copy, Debug)]
pub struct NamedArgEntry {
    /// Human-readable name of the kernel argument.
    pub name: &'static str,
    /// Type-erased const pointer to the argument value.
    pub ptr: *const c_void,
}

// SAFETY: An entry is only a name plus an address. This module never
// dereferences `ptr`, and the caller guarantees the pointee outlives the
// kernel launch that consumes the entry.
unsafe impl Send for NamedArgEntry {}
// SAFETY: Same reasoning as for `Send`: shared access to an entry only ever
// reads the name and copies the address, never the pointee.
unsafe impl Sync for NamedArgEntry {}
213
214/// A const-generic, stack-allocated array of named kernel arguments.
215///
216/// `FixedNamedArgs<N>` stores exactly `N` [`NamedArgEntry`] values on
217/// the stack — zero heap allocation, zero indirection overhead compared
218/// to a plain positional tuple.
219///
220/// # Invariants
221///
222/// Every pointed-to value must remain valid (not moved or dropped) until
223/// after the kernel launch that consumes this struct completes.
224///
225/// # Example
226///
227/// ```rust
228/// use oxicuda_launch::named_args::{FixedNamedArgs, NamedArgEntry};
229///
230/// let n: u32 = 1024;
231/// let alpha: f32 = 2.0;
232///
233/// let args = FixedNamedArgs::new([
234///     NamedArgEntry { name: "n",     ptr: &n     as *const u32 as *const std::ffi::c_void },
235///     NamedArgEntry { name: "alpha", ptr: &alpha as *const f32  as *const std::ffi::c_void },
236/// ]);
237///
238/// assert_eq!(args.len(), 2);
239/// assert_eq!(args.names()[0], "n");
240/// assert_eq!(args.names()[1], "alpha");
241/// ```
242pub struct FixedNamedArgs<const N: usize> {
243    /// The argument entries, stored inline on the stack.
244    entries: [NamedArgEntry; N],
245}
246
247impl<const N: usize> FixedNamedArgs<N> {
248    /// Creates a new `FixedNamedArgs` from an array of [`NamedArgEntry`] values.
249    #[inline]
250    pub const fn new(entries: [NamedArgEntry; N]) -> Self {
251        Self { entries }
252    }
253
254    /// Returns the number of arguments.
255    #[inline]
256    pub const fn len(&self) -> usize {
257        N
258    }
259
260    /// Returns `true` if there are no arguments.
261    #[inline]
262    pub const fn is_empty(&self) -> bool {
263        N == 0
264    }
265
266    /// Returns the argument names in declaration order.
267    pub fn names(&self) -> [&'static str; N] {
268        let mut out = [""; N];
269        for (i, entry) in self.entries.iter().enumerate() {
270            out[i] = entry.name;
271        }
272        out
273    }
274
275    /// Returns a mutable array of `*mut c_void` pointers suitable for
276    /// passing directly to `cuLaunchKernel` as the `kernelParams` array.
277    ///
278    /// # Safety
279    ///
280    /// The returned pointers are valid only as long as the original values
281    /// that were passed when constructing the entries remain in scope.
282    pub fn as_ptr_array(&self) -> [*mut c_void; N] {
283        let mut out = [std::ptr::null_mut::<c_void>(); N];
284        for (i, entry) in self.entries.iter().enumerate() {
285            // Cast const → mut: cuLaunchKernel's ABI takes `void**` but
286            // never mutates the argument values.
287            out[i] = entry.ptr as *mut c_void;
288        }
289        out
290    }
291
292    /// Returns an immutable slice over the entries for inspection.
293    #[inline]
294    pub fn entries(&self) -> &[NamedArgEntry; N] {
295        &self.entries
296    }
297}
298
299impl<const N: usize> std::fmt::Debug for FixedNamedArgs<N> {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        let mut ds = f.debug_struct("FixedNamedArgs");
302        ds.field("len", &N);
303        for entry in &self.entries {
304            ds.field(entry.name, &entry.ptr);
305        }
306        ds.finish()
307    }
308}
309
310// ---------------------------------------------------------------------------
311// Tests
312// ---------------------------------------------------------------------------
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn arg_builder_new_empty() {
        let builder = ArgBuilder::new();
        assert!(builder.is_empty());
        assert_eq!(builder.arg_count(), 0);
    }

    #[test]
    fn arg_builder_add_and_names() {
        let x: u32 = 42;
        let y: f64 = 2.78;
        let mut builder = ArgBuilder::new();
        builder.add("x", &x).add("y", &y);
        assert_eq!(builder.arg_count(), 2);
        assert_eq!(builder.names(), vec!["x", "y"]);
    }

    #[test]
    fn arg_builder_build_pointer_count() {
        let a: u64 = 0x1000;
        let b: u32 = 512;
        let c: f32 = 1.0;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b).add("c", &c);
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 3);
    }

    #[test]
    fn arg_builder_build_pointers_valid() {
        let val: u32 = 99;
        let mut builder = ArgBuilder::new();
        builder.add("val", &val);
        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 1);
        let read_back = unsafe { *(ptrs[0] as *const u32) };
        assert_eq!(read_back, 99);
    }

    #[test]
    fn arg_builder_summary() {
        let x: u32 = 10;
        let mut builder = ArgBuilder::new();
        builder.add("x", &x);
        let s = builder.summary();
        assert!(s.starts_with("ArgBuilder["));
        assert!(s.contains("x="));
    }

    #[test]
    fn arg_builder_debug() {
        let builder = ArgBuilder::new();
        let dbg = format!("{builder:?}");
        assert!(dbg.contains("ArgBuilder"));
        assert!(dbg.contains("count"));
    }

    #[test]
    fn arg_builder_default() {
        let builder = ArgBuilder::default();
        assert!(builder.is_empty());
    }

    #[test]
    fn arg_builder_with_capacity() {
        let builder = ArgBuilder::with_capacity(8);
        assert!(builder.is_empty());
        assert_eq!(builder.arg_count(), 0);
    }

    #[test]
    fn named_kernel_args_trait_exists() {
        // Exercise the trait through a generic bound. `()` implements it in
        // this module, so the generic function can actually be called.
        fn names_and_count<T: NamedKernelArgs>() -> (&'static [&'static str], usize) {
            (T::arg_names(), T::arg_count())
        }
        let (names, count) = names_and_count::<()>();
        assert!(names.is_empty());
        assert_eq!(count, 0);
    }

    // -----------------------------------------------------------------------
    // FixedNamedArgs tests
    // -----------------------------------------------------------------------

    #[test]
    fn fixed_named_args_zero_size() {
        let args: FixedNamedArgs<0> = FixedNamedArgs::new([]);
        assert!(args.is_empty());
        assert_eq!(args.len(), 0);
        let ptrs = args.as_ptr_array();
        assert_eq!(ptrs.len(), 0);
    }

    #[test]
    fn fixed_named_args_single_u32() {
        let n: u32 = 1024;
        let args = FixedNamedArgs::new([NamedArgEntry {
            name: "n",
            ptr: &n as *const u32 as *const c_void,
        }]);
        assert_eq!(args.len(), 1);
        assert!(!args.is_empty());
        assert_eq!(args.names(), ["n"]);

        let ptrs = args.as_ptr_array();
        assert_eq!(ptrs.len(), 1);
        // Verify the pointer round-trips to the original value.
        let read_back = unsafe { *(ptrs[0] as *const u32) };
        assert_eq!(read_back, 1024);
    }

    #[test]
    fn fixed_named_args_two_entries_order_preserved() {
        let n: u32 = 512;
        let alpha: f32 = std::f32::consts::PI;
        let args = FixedNamedArgs::new([
            NamedArgEntry {
                name: "n",
                ptr: &n as *const u32 as *const c_void,
            },
            NamedArgEntry {
                name: "alpha",
                ptr: &alpha as *const f32 as *const c_void,
            },
        ]);

        assert_eq!(args.len(), 2);
        let names = args.names();
        assert_eq!(names[0], "n");
        assert_eq!(names[1], "alpha");

        let ptrs = args.as_ptr_array();
        let n_back = unsafe { *(ptrs[0] as *const u32) };
        let a_back = unsafe { *(ptrs[1] as *const f32) };
        assert_eq!(n_back, 512);
        assert!((a_back - std::f32::consts::PI).abs() < f32::EPSILON);
    }

    #[test]
    fn fixed_named_args_no_size_overhead_vs_entry_array() {
        use std::mem::size_of;
        // FixedNamedArgs<N> must have the same size as [NamedArgEntry; N].
        assert_eq!(
            size_of::<FixedNamedArgs<4>>(),
            size_of::<[NamedArgEntry; 4]>(),
            "FixedNamedArgs<4> must not add any size overhead"
        );
        assert_eq!(
            size_of::<FixedNamedArgs<1>>(),
            size_of::<[NamedArgEntry; 1]>(),
            "FixedNamedArgs<1> must not add any size overhead"
        );
    }

    #[test]
    fn fixed_named_args_ptr_array_length_matches_n() {
        let a: u64 = 0x1000_0000;
        let b: u32 = 256;
        let c: f64 = 1.0;
        let args = FixedNamedArgs::new([
            NamedArgEntry {
                name: "a",
                ptr: &a as *const u64 as *const c_void,
            },
            NamedArgEntry {
                name: "b",
                ptr: &b as *const u32 as *const c_void,
            },
            NamedArgEntry {
                name: "c",
                ptr: &c as *const f64 as *const c_void,
            },
        ]);
        let ptrs = args.as_ptr_array();
        assert_eq!(ptrs.len(), 3);
    }

    #[test]
    fn fixed_named_args_debug_contains_len() {
        let n: u32 = 42;
        let args = FixedNamedArgs::new([NamedArgEntry {
            name: "n",
            ptr: &n as *const u32 as *const c_void,
        }]);
        let dbg = format!("{args:?}");
        assert!(dbg.contains("FixedNamedArgs"), "Debug output: {dbg}");
        // The `len` field must carry the actual count...
        assert!(dbg.contains("len: 1"), "Debug output: {dbg}");
        // ...and each entry must appear as its own named field.
        assert!(dbg.contains(", n: "), "Debug output: {dbg}");
    }

    #[test]
    fn fixed_named_args_entries_accessor() {
        let x: f32 = 7.0;
        let args = FixedNamedArgs::new([NamedArgEntry {
            name: "x",
            ptr: &x as *const f32 as *const c_void,
        }]);
        let entries = args.entries();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].name, "x");
    }

    // ---------------------------------------------------------------------------
    // Quality gate tests (CPU-only)
    // ---------------------------------------------------------------------------

    #[test]
    fn named_kernel_args_empty() {
        // NamedKernelArgs for () has no args: arg_names is empty, arg_count is 0.
        let names = <() as NamedKernelArgs>::arg_names();
        assert!(names.is_empty(), "() must have no arg names");
        let count = <() as NamedKernelArgs>::arg_count();
        assert_eq!(count, 0, "() must have arg_count == 0");
    }

    #[test]
    fn named_kernel_args_add_and_count() {
        // After adding 3 named args to ArgBuilder, arg_count() == 3.
        let a: u32 = 1;
        let b: f64 = 2.0;
        let c: u64 = 3;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b).add("c", &c);
        assert_eq!(
            builder.arg_count(),
            3,
            "ArgBuilder with 3 args must report arg_count == 3"
        );
        // Names must be in insertion order
        let names = builder.names();
        assert_eq!(names[0], "a");
        assert_eq!(names[1], "b");
        assert_eq!(names[2], "c");
    }
}