Skip to main content

oxicuda_launch/
macros.rs

1//! Convenience macros for kernel launching.
2//!
3//! The [`launch!`](crate::launch) macro provides a concise syntax for launching GPU kernels
4//! without manually constructing [`LaunchParams`](crate::LaunchParams).
5//!
6//! The [`named_args!`](macro@crate::named_args) macro provides a zero-overhead way to pass named
7//! kernel arguments — names are stripped at compile time, producing the same
8//! tuple as positional arguments with no runtime overhead.
9//!
10//! The [`launch_named!`](crate::launch_named) macro combines named-argument syntax with the
11//! `launch!` convenience macro.
12
/// Launch a GPU kernel with a concise syntax.
///
/// This macro constructs [`LaunchParams`](crate::LaunchParams) from the
/// provided grid and block dimensions, then calls [`Kernel::launch`](crate::Kernel::launch).
///
/// # Syntax
///
/// ```text
/// launch!(kernel, grid(G), block(B), shared(S), stream, args)?;
/// launch!(kernel, grid(G), block(B), stream, args)?;  // shared_mem = 0
/// ```
///
/// Both forms accept one optional trailing comma after `args`, so an
/// invocation can be split over several lines like a function call.
///
/// Where:
/// - `kernel` — a [`Kernel`](crate::Kernel) instance.
/// - `G` — grid dimensions (anything convertible to [`Dim3`](crate::Dim3)).
/// - `B` — block dimensions (anything convertible to [`Dim3`](crate::Dim3)).
/// - `S` — dynamic shared memory in bytes (`u32`).
/// - `stream` — a reference to a [`Stream`](oxicuda_driver::Stream).
/// - `args` — a reference to a tuple implementing [`KernelArgs`](crate::KernelArgs).
///
/// The launch parameters are built before the kernel expression is used, so
/// `G`, `B` (and `S`, when present) are evaluated before `kernel`.
///
/// # Returns
///
/// `CudaResult<()>` — use `?` to propagate errors.
///
/// # Examples
///
/// ```rust,no_run
/// # use oxicuda_launch::*;
/// # fn main() -> oxicuda_driver::CudaResult<()> {
/// # let kernel: Kernel = todo!();
/// # let stream: oxicuda_driver::Stream = todo!();
/// let n: u32 = 1024;
/// let a_ptr: u64 = 0;
/// let b_ptr: u64 = 0;
/// let c_ptr: u64 = 0;
///
/// // With explicit shared memory
/// launch!(kernel, grid(4u32), block(256u32), shared(0), &stream, &(a_ptr, b_ptr, c_ptr, n))?;
///
/// // Without shared memory (defaults to 0)
/// launch!(kernel, grid(4u32), block(256u32), &stream, &(a_ptr, b_ptr, c_ptr, n))?;
///
/// // Multi-line form with a trailing comma
/// launch!(
///     kernel,
///     grid(4u32),
///     block(256u32),
///     &stream,
///     &(a_ptr, b_ptr, c_ptr, n),
/// )?;
/// # Ok(())
/// # }
/// ```
#[macro_export]
macro_rules! launch {
    // Explicit dynamic shared-memory size.
    (
        $kernel:expr, grid($g:expr), block($b:expr), shared($s:expr),
        $stream:expr, $args:expr $(,)?
    ) => {{
        // `params` is a hygienic local: it cannot clash with caller bindings.
        let params = $crate::LaunchParams::new($g, $b).with_shared_mem($s);
        $kernel.launch(&params, $stream, $args)
    }};

    // No `shared(..)` clause: the launch parameters are used exactly as
    // `LaunchParams::new` returns them.
    (
        $kernel:expr, grid($g:expr), block($b:expr),
        $stream:expr, $args:expr $(,)?
    ) => {{
        let params = $crate::LaunchParams::new($g, $b);
        $kernel.launch(&params, $stream, $args)
    }};
}
68
/// Turn a list of `label: value` pairs into a plain tuple of the values.
///
/// The labels exist purely for readability at the call site: the expansion
/// contains only the value expressions, in the order they were written, so
/// the result is the very same tuple a positional `(value1, value2, ...)`
/// would produce — no wrapper type, no extra fields.
///
/// # Syntax
///
/// ```text
/// named_args!(name1: value1, name2: value2, ...)
/// ```
///
/// A trailing comma after the last pair is allowed. Invoking the macro with
/// no arguments yields the unit value `()`.
///
/// # Returns
///
/// The tuple `(value1, value2, ...)`; a single pair yields a 1-tuple
/// `(value1,)`.
///
/// # Examples
///
/// ```rust
/// use oxicuda_launch::named_args;
///
/// let n: u32 = 1024;
/// let alpha: f32 = 2.0;
///
/// // Labelled form — reads like a struct literal.
/// let named = named_args!(n: n, alpha: alpha);
/// // Positional form — what the macro expands to.
/// let positional = (n, alpha);
///
/// assert_eq!(named, positional);
/// ```
#[macro_export]
macro_rules! named_args {
    // At least one `label: value` pair, optionally followed by a trailing
    // comma: drop every label and emit the values, in order, as a tuple.
    // The comma closing the expansion keeps a single value a 1-tuple rather
    // than a parenthesized expression.
    ($($label:ident : $value:expr),+ $(,)?) => {
        ($($value),+ ,)
    };

    // No arguments at all: expand to the unit value.
    () => {
        ()
    };
}
109
/// Launch a GPU kernel with named argument syntax.
///
/// This macro strips the argument names at compile time and delegates to
/// [`launch!`](crate::launch), so there is zero runtime overhead versus the positional form.
///
/// The contents of the `{ ... }` block are forwarded verbatim to
/// [`named_args!`](macro@crate::named_args), which is the single place that
/// defines the `name: value` grammar. Consequently the block accepts exactly
/// what `named_args!` accepts: one or more `name: value` pairs with an
/// optional trailing comma, or nothing at all (which produces the unit tuple
/// `()` — usable only if `()` implements [`KernelArgs`](crate::KernelArgs)).
///
/// # Syntax
///
/// ```text
/// launch_named!(kernel, grid(G), block(B), shared(S), stream, {
///     name1: value1,
///     name2: value2,
/// })?;
///
/// // Without explicit shared memory (defaults to 0):
/// launch_named!(kernel, grid(G), block(B), stream, {
///     name1: value1,
///     name2: value2,
/// })?;
/// ```
///
/// One optional trailing comma is also accepted after the closing `}`.
///
/// # Examples
///
/// ```rust,no_run
/// # use oxicuda_launch::*;
/// # fn main() -> oxicuda_driver::CudaResult<()> {
/// # let kernel: Kernel = todo!();
/// # let stream: oxicuda_driver::Stream = todo!();
/// let n: u32 = 1024;
/// let a_ptr: u64 = 0;
/// let b_ptr: u64 = 0;
/// let c_ptr: u64 = 0;
///
/// launch_named!(kernel, grid(4u32), block(256u32), &stream, {
///     n: n,
///     a: a_ptr,
///     b: b_ptr,
///     c: c_ptr,
/// })?;
/// # Ok(())
/// # }
/// ```
#[macro_export]
macro_rules! launch_named {
    // With explicit shared memory.
    ($kernel:expr, grid($g:expr), block($b:expr), shared($s:expr), $stream:expr, {
        $($fields:tt)*
    } $(,)?) => {{
        // Bind the tuple to a local so `launch!` can borrow it for the call.
        let args = $crate::named_args!($($fields)*);
        $crate::launch!($kernel, grid($g), block($b), shared($s), $stream, &args)
    }};

    // Without shared memory (defaults to 0).
    ($kernel:expr, grid($g:expr), block($b:expr), $stream:expr, {
        $($fields:tt)*
    } $(,)?) => {{
        let args = $crate::named_args!($($fields)*);
        $crate::launch!($kernel, grid($g), block($b), $stream, &args)
    }};
}
169
// ---------------------------------------------------------------------------
// Tests for the named_args! macro (no GPU required)
// ---------------------------------------------------------------------------
173
#[cfg(test)]
mod tests {
    use std::mem::{size_of, size_of_val};

    /// Two named fields expand to the same value as the positional tuple.
    #[test]
    fn test_named_args_produces_correct_tuple_two_fields() {
        let n = 1024u32;
        let alpha = 2.0f32;

        let pos = (n, alpha);
        let named = named_args!(n: n, alpha: alpha);

        assert_eq!(pos, named);
    }

    /// A single named field still yields a tuple (a 1-tuple), not a bare value.
    #[test]
    fn test_named_args_single_field() {
        let x = 42u64;
        let named = named_args!(x: x);
        // A single-element named_args! produces a 1-tuple: (x,)
        assert_eq!(named.0, 42u64);
    }

    /// Values appear in the tuple in the order they were written.
    #[test]
    fn test_named_args_three_fields_order_preserved() {
        let a = 1u32;
        let b = 2u64;
        let c = 3.0f32;

        let named = named_args!(a: a, b: b, c: c);
        assert_eq!(named.0, 1u32);
        assert_eq!(named.1, 2u64);
        assert!((named.2 - 3.0f32).abs() < f32::EPSILON);
    }

    /// The macro output occupies exactly as much memory as the positional tuple.
    #[test]
    fn test_named_args_no_extra_size_vs_positional() {
        let n = 1024u32;
        let alpha = 2.0f32;
        let named = named_args!(n: n, alpha: alpha);
        let positional = (n, alpha);

        // Measure the value the macro actually produced rather than comparing
        // a type's size with itself.
        assert_eq!(
            size_of_val(&named),
            size_of_val(&positional),
            "named_args! tuple must be the same size as positional tuple"
        );
        assert_eq!(size_of_val(&named), size_of::<(u32, f32)>());

        // Verify via type inference: named_args! produces the same type as
        // its positional equivalent. If this compiles, the types are identical.
        let _: (u32, f32) = named;
    }

    /// A trailing comma after the last pair is accepted.
    #[test]
    fn test_named_args_trailing_comma_allowed() {
        let x = 7u32;
        let y = 8u64;
        // Trailing comma must be accepted without compile error.
        let named = named_args!(x: x, y: y,);
        assert_eq!(named.0, 7u32);
        assert_eq!(named.1, 8u64);
    }

    /// Values may be arbitrary expressions, not just identifiers.
    #[test]
    fn test_named_args_expressions_evaluated() {
        // Values in named_args! can be arbitrary expressions.
        let named = named_args!(result: 2u32 + 3u32, factor: 1.5f32 * 2.0f32);
        assert_eq!(named.0, 5u32);
        assert!((named.1 - 3.0f32).abs() < f32::EPSILON);
    }

    /// Four fields of mixed width keep their positions and values.
    #[test]
    fn test_named_args_four_fields() {
        let n = 1024u32;
        let a: u64 = 0x1000;
        let b: u64 = 0x2000;
        let c: u64 = 0x3000;

        let named = named_args!(n: n, a: a, b: b, c: c);
        assert_eq!(named.0, 1024u32);
        assert_eq!(named.1, 0x1000u64);
        assert_eq!(named.2, 0x2000u64);
        assert_eq!(named.3, 0x3000u64);
    }

    /// The empty invocation expands to the unit value.
    #[test]
    fn test_named_args_empty_is_unit() {
        // Irrefutable unit pattern: compiles only if the expansion is `()`.
        let () = named_args!();
    }
}