/// Defines a function whose body is chosen at compile time from per-target
/// expressions.
///
/// The invocation looks like an ordinary `fn` item (attributes, visibility,
/// name, `ident: Type` parameters and an explicit return type) whose body is a
/// list of labelled expressions. `wasm_simd`, `avx2` and `neon` are each
/// optional but must appear in that order; `fallback` is mandatory and comes
/// last. A trailing comma after the parameter list and after the fallback
/// expression is accepted.
///
/// | label       | compiled when                                              |
/// |-------------|------------------------------------------------------------|
/// | `wasm_simd` | `target_arch = "wasm32"` and `target_feature = "simd128"`  |
/// | `avx2`      | `target_arch = "x86_64"` and `target_feature = "avx2"`     |
/// | `neon`      | `target_arch = "aarch64"` and `target_feature = "neon"`    |
/// | `fallback`  | none of the *supplied* arms above applies                  |
///
/// Selection uses `#[cfg]`, so it reflects the features enabled for the build
/// (no runtime CPU detection). Expressions for arms that are not selected are
/// parsed but removed before name resolution and type checking.
///
/// The expansion relies only on built-in `#[cfg]` attributes, so callers do
/// not need any helper crate in scope.
///
/// # Examples
///
/// ```ignore
/// simd_dispatch! {
///     #[inline]
///     pub fn dot(a: &[f32], b: &[f32]) -> f32 {
///         avx2: unsafe { dot_avx2(a, b) },
///         neon: unsafe { dot_neon(a, b) },
///         fallback: dot_scalar(a, b),
///     }
/// }
/// ```
#[macro_export]
macro_rules! simd_dispatch {
    // Internal rule (not part of the public grammar): emits the function.
    //
    // Each `(cfg-predicate) => expr,` entry becomes a `#[cfg(predicate)]`
    // block. The three predicates name different `target_arch` values, so at
    // most one of them can be active. The fallback block is compiled only when
    // none of the listed predicates holds; with an empty list this is
    // `not(any())`, which is always true.
    (
        @emit
        $(#[$meta:meta])*
        $vis:vis fn $name:ident($($arg:ident: $type:ty),*) -> $ret:ty;
        [$(($($cfg:tt)+) => $body:expr,)*]
        $fallback:expr
    ) => {
        $(#[$meta])*
        $vis fn $name($($arg: $type),*) -> $ret {
            $(
                #[cfg($($cfg)+)]
                { $body }
            )*
            #[cfg(not(any($($($cfg)+),*)))]
            { $fallback }
        }
    };

    // Public entry point: pairs every supplied arm with its cfg predicate and
    // hands the result to `@emit`. `$crate::` keeps the recursive call working
    // when the macro is invoked by path from another crate.
    (
        $(#[$meta:meta])*
        $vis:vis fn $name:ident($($arg:ident: $type:ty),* $(,)?) -> $ret:ty {
            $(wasm_simd: $wasm:expr,)?
            $(avx2: $avx2:expr,)?
            $(neon: $neon:expr,)?
            fallback: $fallback:expr $(,)?
        }
    ) => {
        $crate::simd_dispatch! {
            @emit
            $(#[$meta])*
            $vis fn $name($($arg: $type),*) -> $ret;
            [
                $((all(target_arch = "wasm32", target_feature = "simd128")) => $wasm,)?
                $((all(target_arch = "x86_64", target_feature = "avx2")) => $avx2,)?
                $((all(target_arch = "aarch64", target_feature = "neon")) => $neon,)?
            ]
            $fallback
        }
    };
}
#[cfg(test)]
mod tests {
    //! Checks for `simd_dispatch!`. Every invocation here supplies only a
    //! `fallback` arm, so the same expression is evaluated on all targets.

    /// Absolute tolerance used when comparing `f32` results.
    const EPS: f32 = 1e-6;

    // ---- scalar reference implementations used as fallbacks ----

    /// Integer addition.
    fn scalar_add(a: u32, b: u32) -> u32 {
        a + b
    }

    /// Floating-point multiplication.
    fn scalar_mul(a: f32, b: f32) -> f32 {
        a * b
    }

    /// Sum of all elements of `slice`.
    fn scalar_sum(slice: &[f32]) -> f32 {
        slice.iter().sum()
    }

    /// Euclidean distance between `a` and `b`, pairing elements by position.
    fn scalar_distance(a: &[f32], b: &[f32]) -> f32 {
        let squared: f32 = a
            .iter()
            .zip(b)
            .map(|(p, q)| {
                let delta = p - q;
                delta * delta
            })
            .sum();
        squared.sqrt()
    }

    // ---- functions generated by the macro ----

    simd_dispatch! {
        #[inline]
        fn test_fallback_only(a: u32, b: u32) -> u32 {
            fallback: scalar_add(a, b),
        }
    }

    simd_dispatch! {
        #[inline]
        #[must_use]
        pub fn test_with_attrs(x: f32, y: f32) -> f32 {
            fallback: scalar_mul(x, y),
        }
    }

    simd_dispatch! {
        fn test_multi_arg(a: u32, b: u32, c: u32) -> u32 {
            fallback: a + b + c,
        }
    }

    simd_dispatch! {
        fn test_slice_arg(data: &[f32]) -> f32 {
            fallback: scalar_sum(data),
        }
    }

    simd_dispatch! {
        #[inline]
        fn test_distance(a: &[f32], b: &[f32]) -> f32 {
            fallback: scalar_distance(a, b),
        }
    }

    // ---- test cases ----

    /// A fallback-only definition evaluates its fallback expression.
    #[test]
    fn test_fallback_only_works() {
        let sum = test_fallback_only(2, 3);
        assert_eq!(sum, 5);
    }

    /// A definition carrying several attributes and `pub` still expands to a
    /// callable function.
    #[test]
    fn test_attributes_preserved() {
        let product = test_with_attrs(2.0, 3.0);
        assert!((product - 6.0).abs() < EPS);
    }

    /// Three parameters are all visible to the fallback expression.
    #[test]
    fn test_multiple_arguments() {
        let sum = test_multi_arg(1, 2, 3);
        assert_eq!(sum, 6);
    }

    /// Reference-typed parameters are accepted.
    #[test]
    fn test_slice_argument() {
        let data: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let total = test_slice_arg(&data);
        assert!((total - 10.0).abs() < EPS);
    }

    /// Two-slice distance kernel shape: a 3-4-5 triangle gives 5.
    #[test]
    fn test_distance_pattern() {
        let origin: [f32; 3] = [0.0, 0.0, 0.0];
        let point: [f32; 3] = [3.0, 4.0, 0.0];
        let dist = test_distance(&origin, &point);
        assert!((dist - 5.0).abs() < EPS);
    }
}