/// Launches `$kernel` on `$stream` with the given grid and block dimensions.
///
/// Expands to a block that first builds a `$crate::LaunchParams` with
/// `LaunchParams::new(grid, block)`. When a `shared(bytes)` clause is given,
/// it also chains `.with_shared_mem(bytes)`. It then calls
/// `$kernel.launch(&params, $stream, $args)` and yields whatever that method
/// returns.
///
/// # Forms
///
/// ```ignore
/// launch!(kernel, grid(g), block(b), shared(bytes), stream, args)
/// launch!(kernel, grid(g), block(b), stream, args)
/// ```
///
/// The grid, block and shared expressions are evaluated (while building the
/// parameters) before `$kernel`, `$stream` and `$args`.
#[macro_export]
macro_rules! launch {
    ($kernel:expr, grid($g:expr), block($b:expr), shared($s:expr), $stream:expr, $args:expr) => {{
        let params = $crate::LaunchParams::new($g, $b).with_shared_mem($s);
        // `launch` receives the parameters by shared reference.
        $kernel.launch(&params, $stream, $args)
    }};
    ($kernel:expr, grid($g:expr), block($b:expr), $stream:expr, $args:expr) => {{
        let params = $crate::LaunchParams::new($g, $b);
        $kernel.launch(&params, $stream, $args)
    }};
}
/// Packs arguments written as `name: value` pairs into a positional tuple.
///
/// The names only make the call site easier to read. The macro matches them
/// and then discards them, so the result is the tuple `(value0, value1, ...)`
/// in the order written. Names are not checked for uniqueness or against
/// anything else.
///
/// - `named_args!()` expands to `()`.
/// - A single pair expands to a one-element tuple: `named_args!(x: v)` is `(v,)`.
/// - A trailing comma is accepted after at least one pair.
///
/// ```ignore
/// let packed = named_args!(n: 1024u32, alpha: 2.0f32); // == (1024u32, 2.0f32)
/// ```
#[macro_export]
macro_rules! named_args {
    ($($name:ident : $value:expr),+ $(,)?) => {
        // The trailing comma keeps a single value a 1-tuple rather than a
        // parenthesized expression.
        ( $( $value ),+ , )
    };
    () => {
        ()
    };
}
/// Launches `$kernel` with its arguments given as `name: value` pairs.
///
/// This wraps [`launch!`]. The pairs are packed with [`named_args!`] into a
/// positional tuple (names discarded, order preserved), which is bound to a
/// local. A reference to that local is then passed as `launch!`'s argument
/// expression. As with `launch!`, an optional `shared(bytes)` clause selects
/// the arm that sets shared memory.
///
/// At least one `name: value` pair is required, and a trailing comma is allowed.
///
/// ```ignore
/// launch_named!(kernel, grid(g), block(b), stream, { n: n, alpha: alpha })
/// ```
#[macro_export]
macro_rules! launch_named {
    (
        $kernel:expr,
        grid($g:expr),
        block($b:expr),
        shared($s:expr),
        $stream:expr,
        { $($name:ident : $val:expr),+ $(,)? }
    ) => {{
        // Bound first, so the argument expressions run before `launch!`
        // evaluates its grid/block/shared/stream expressions.
        let packed_args = $crate::named_args!($($name : $val),+);
        $crate::launch!($kernel, grid($g), block($b), shared($s), $stream, &packed_args)
    }};
    (
        $kernel:expr,
        grid($g:expr),
        block($b:expr),
        $stream:expr,
        { $($name:ident : $val:expr),+ $(,)? }
    ) => {{
        let packed_args = $crate::named_args!($($name : $val),+);
        $crate::launch!($kernel, grid($g), block($b), $stream, &packed_args)
    }};
}
// Unit tests for `named_args!`. `launch!` and `launch_named!` are not
// exercised here; they expand to `$crate::LaunchParams` and a `.launch(..)`
// method call on the kernel expression.
#[cfg(test)]
mod tests {
    use std::mem::{size_of, size_of_val};

    #[test]
    fn test_named_args_produces_correct_tuple_two_fields() {
        let n = 1024u32;
        let alpha = 2.0f32;
        let pos = (n, alpha);
        let named = named_args!(n: n, alpha: alpha);
        assert_eq!(pos, named);
    }

    #[test]
    fn test_named_args_single_field() {
        let x = 42u64;
        // The ascription pins the result to a 1-tuple, not a bare `u64`.
        let named: (u64,) = named_args!(x: x);
        assert_eq!(named.0, 42u64);
    }

    #[test]
    fn test_named_args_empty_is_unit() {
        // Only compiles if the empty invocation expands to `()`.
        let () = named_args!();
    }

    #[test]
    fn test_named_args_three_fields_order_preserved() {
        let a = 1u32;
        let b = 2u64;
        let c = 3.0f32;
        let named = named_args!(a: a, b: b, c: c);
        assert_eq!(named.0, 1u32);
        assert_eq!(named.1, 2u64);
        assert!((named.2 - 3.0f32).abs() < f32::EPSILON);
    }

    #[test]
    fn test_named_args_no_extra_size_vs_positional() {
        let n = 1024u32;
        let alpha = 2.0f32;
        let named = named_args!(n: n, alpha: alpha);
        // Measure the macro's actual output. The previous version compared
        // `size_of::<(u32, f32)>()` with itself, which could never fail.
        assert_eq!(
            size_of_val(&named),
            size_of::<(u32, f32)>(),
            "named_args! tuple must be the same size as positional tuple"
        );
        // The ascription compiles only if the macro yields exactly `(u32, f32)`.
        let _: (u32, f32) = named;
    }

    #[test]
    fn test_named_args_trailing_comma_allowed() {
        let x = 7u32;
        let y = 8u64;
        let named = named_args!(x: x, y: y,);
        assert_eq!(named.0, 7u32);
        assert_eq!(named.1, 8u64);
    }

    #[test]
    fn test_named_args_expressions_evaluated() {
        let named = named_args!(result: 2u32 + 3u32, factor: 1.5f32 * 2.0f32);
        assert_eq!(named.0, 5u32);
        assert!((named.1 - 3.0f32).abs() < f32::EPSILON);
    }

    #[test]
    fn test_named_args_evaluates_left_to_right_once() {
        let calls = std::cell::Cell::new(0u32);
        let tick = || {
            let v = calls.get();
            calls.set(v + 1);
            v
        };
        let named = named_args!(first: tick(), second: tick(), third: tick());
        assert_eq!(named, (0, 1, 2));
        assert_eq!(calls.get(), 3);
    }

    #[test]
    fn test_named_args_four_fields() {
        let n = 1024u32;
        let a: u64 = 0x1000;
        let b: u64 = 0x2000;
        let c: u64 = 0x3000;
        let named = named_args!(n: n, a: a, b: b, c: c);
        assert_eq!(named.0, 1024u32);
        assert_eq!(named.1, 0x1000u64);
        assert_eq!(named.2, 0x2000u64);
        assert_eq!(named.3, 0x3000u64);
    }
}