#![deny(unsafe_code)]
#![allow(dead_code)]
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, ItemFn, LitStr, Token};
// Preset target strings. Each is `<arch>+<feature>+<feature>…`, the form that
// `multiversed_impl` passes verbatim into `multiversion::multiversion(targets(...))`.

/// `x86-64-v2` preset: SSE through SSE4.2, SSSE3, POPCNT and CMPXCHG16B.
const X86_64_V2: &str = "x86_64+sse+sse2+sse3+ssse3+sse4.1+sse4.2+popcnt+cmpxchg16b";
/// `x86-64-v3` preset: the v2 features plus AVX, AVX2, FMA, BMI1/BMI2, F16C, LZCNT and MOVBE.
const X86_64_V3: &str =
"x86_64+sse+sse2+sse3+ssse3+sse4.1+sse4.2+popcnt+cmpxchg16b+avx+avx2+fma+bmi1+bmi2+f16c+lzcnt+movbe";
/// `x86-64-v4` preset: the v3 features plus AVX-512 F/BW/CD/DQ/VL.
const X86_64_V4: &str =
"x86_64+sse+sse2+sse3+ssse3+sse4.1+sse4.2+popcnt+cmpxchg16b+avx+avx2+fma+bmi1+bmi2+f16c+lzcnt+movbe+avx512f+avx512bw+avx512cd+avx512dq+avx512vl";
/// `x86-64-v4-modern` / `x86-64-v4x` preset: the v4 features plus the newer AVX-512
/// extensions (VPOPCNTDQ, IFMA, VBMI, VBMI2, BITALG, VNNI) and VPCLMULQDQ, GFNI, VAES.
const X86_64_V4_MODERN: &str =
"x86_64+sse+sse2+sse3+ssse3+sse4.1+sse4.2+popcnt+cmpxchg16b+avx+avx2+fma+bmi1+bmi2+f16c+lzcnt+movbe+avx512f+avx512bw+avx512cd+avx512dq+avx512vl+avx512vpopcntdq+avx512ifma+avx512vbmi+avx512vbmi2+avx512bitalg+avx512vnni+vpclmulqdq+gfni+vaes";
/// `arm64` / `arm64-v2` preset: NEON, CRC, RDM, dot-product, FP16, AES and SHA2.
const ARM64_V2: &str = "aarch64+neon+crc+rdm+dotprod+fp16+aes+sha2";
/// `arm64-v3` preset: the arm64-v2 features plus FHM, FCMA, SHA3, I8MM and BF16.
const ARM64_V3: &str = "aarch64+neon+crc+rdm+dotprod+fp16+aes+sha2+fhm+fcma+sha3+i8mm+bf16";
/// Translates a user-facing target name into a `multiversion` target string.
///
/// * Named presets (`x86-64-v2`, `x86-64-v3`, `x86-64-v4`, `x86-64-v4-modern`/`x86-64-v4x`,
///   `arm64`/`arm64-v2`, `arm64-v3`) expand to the matching preset constant.
/// * Any other string containing `+` that does not start with `wasm32` is treated as a
///   custom target and returned unchanged.
/// * Everything else, including `wasm32-simd128`, yields `None`.
fn resolve_target(s: &str) -> Option<&str> {
    let preset = match s {
        "x86-64-v2" => X86_64_V2,
        "x86-64-v3" => X86_64_V3,
        "x86-64-v4" => X86_64_V4,
        "x86-64-v4-modern" | "x86-64-v4x" => X86_64_V4_MODERN,
        "arm64" | "arm64-v2" => ARM64_V2,
        "arm64-v3" => ARM64_V3,
        // Accepted as a name but contributes no multiversion target.
        "wasm32-simd128" => return None,
        custom => {
            let looks_custom = custom.contains('+') && !custom.starts_with("wasm32");
            return if looks_custom { Some(custom) } else { None };
        }
    };
    Some(preset)
}
/// Returns `true` when `s` is a target string for the x86 family, recognised by
/// one of the prefixes `x86_64+`, `x86+` or `x86-64-`.
fn is_x86_target(s: &str) -> bool {
    const X86_PREFIXES: [&str; 3] = ["x86_64+", "x86+", "x86-64-"];
    X86_PREFIXES.iter().any(|prefix| s.starts_with(*prefix))
}
/// Returns `true` when `s` is a target string for AArch64, recognised by the
/// prefix `aarch64+` or `aarch64-`.
fn is_aarch64_target(s: &str) -> bool {
    const AARCH64_PREFIXES: [&str; 2] = ["aarch64+", "aarch64-"];
    AARCH64_PREFIXES.iter().any(|prefix| s.starts_with(*prefix))
}
/// Parsed arguments of `#[multiversed(...)]`: zero or more string literals naming
/// targets, in the order written.
struct MultiversedArgs {
    /// Raw, unresolved target names exactly as the user wrote them.
    targets: Vec<String>,
}

impl Parse for MultiversedArgs {
    /// Reads string literals until the input is exhausted. A comma after each
    /// literal is optional, so `"a", "b"`, `"a" "b"` and `"a", "b",` all parse.
    /// Any non-string token is reported as a parse error.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut names: Vec<String> = Vec::new();
        while !input.is_empty() {
            let literal = input.parse::<LitStr>()?;
            names.push(literal.value());
            // `Option<Token![,]>` consumes a comma only if one is next.
            let _separator: Option<Token![,]> = input.parse()?;
        }
        Ok(Self { targets: names })
    }
}
#[allow(clippy::vec_init_then_push)]
fn default_x86_targets() -> Vec<&'static str> {
let mut targets = Vec::new();
#[cfg(any(feature = "x86-64-v4x", feature = "x86-64-v4-modern"))]
targets.push(X86_64_V4_MODERN);
#[cfg(feature = "x86-64-v4")]
targets.push(X86_64_V4);
#[cfg(feature = "x86-64-v3")]
targets.push(X86_64_V3);
#[cfg(feature = "x86-64-v2")]
targets.push(X86_64_V2);
targets
}
#[allow(clippy::vec_init_then_push)]
fn default_aarch64_targets() -> Vec<&'static str> {
let mut targets = Vec::new();
#[cfg(feature = "arm64-v3")]
targets.push(ARM64_V3);
#[cfg(feature = "arm64-v2")]
targets.push(ARM64_V2);
targets
}
#[proc_macro_attribute]
pub fn multiversed(attr: TokenStream, item: TokenStream) -> TokenStream {
let func = parse_macro_input!(item as ItemFn);
#[cfg(feature = "force-disable")]
{
let _ = attr; #[allow(clippy::needless_return)]
return quote! { #func }.into();
}
#[cfg(not(feature = "force-disable"))]
{
let args = parse_macro_input!(attr as MultiversedArgs);
multiversed_impl(args, func)
}
}
/// Emits `func` preceded by `cfg_attr`-gated `multiversion` attributes.
///
/// With no explicit targets, the Cargo-feature defaults are used. Otherwise each
/// name goes through `resolve_target`. Unresolvable names are skipped, and
/// duplicates after resolution are dropped, keeping the first occurrence (so
/// aliases such as `arm64`/`arm64-v2` collapse). The survivors are split by
/// architecture prefix. An architecture with no targets gets no attribute.
fn multiversed_impl(args: MultiversedArgs, func: ItemFn) -> TokenStream {
    let mut x86_targets: Vec<String> = Vec::new();
    let mut aarch64_targets: Vec<String> = Vec::new();

    if args.targets.is_empty() {
        x86_targets.extend(default_x86_targets().into_iter().map(String::from));
        aarch64_targets.extend(default_aarch64_targets().into_iter().map(String::from));
    } else {
        let mut unique: Vec<&str> = Vec::new();
        for target in args.targets.iter().filter_map(|name| resolve_target(name)) {
            if !unique.contains(&target) {
                unique.push(target);
            }
        }
        // The x86 and aarch64 prefixes are disjoint, so one pass classifies each
        // target; strings matching neither family are dropped.
        for target in unique {
            if is_x86_target(target) {
                x86_targets.push(target.to_owned());
            } else if is_aarch64_target(target) {
                aarch64_targets.push(target.to_owned());
            }
        }
    }

    // `None` quotes to nothing, so an empty target list produces no attribute.
    let x86_attr = (!x86_targets.is_empty()).then(|| {
        quote! {
            #[cfg_attr(
                any(target_arch = "x86", target_arch = "x86_64"),
                multiversion::multiversion(targets(#(#x86_targets),*))
            )]
        }
    });
    let aarch64_attr = (!aarch64_targets.is_empty()).then(|| {
        quote! {
            #[cfg_attr(
                target_arch = "aarch64",
                multiversion::multiversion(targets(#(#aarch64_targets),*))
            )]
        }
    });

    quote! {
        #x86_attr
        #aarch64_attr
        #func
    }
    .into()
}