use std::collections::BTreeMap;
use proc_macro::TokenStream;
use syn::{Item, LitStr, Token, parse::Parser, punctuated::Punctuated};
/// Turns a comma-separated list of string literals (optional trailing comma) into a
/// slice of those same literals.
///
/// Only compiled for tests, where the raw path lists are compared against the files on disk.
#[cfg(test)]
macro_rules! as_string_slice {
    ($($literal:literal),* $(,)?) => (
        &[$($literal),*]
    );
}
/// Turns a comma-separated list of path literals (optional trailing comma) into a slice
/// whose elements are the contents of those files, each embedded via `include_str!`.
macro_rules! as_include_str_slice {
    ($($file:literal),* $(,)?) => (
        &[$(include_str!($file)),*]
    );
}
/// Invokes the callback macro `$callback` with the relative paths of the per-architecture
/// feature snapshot files, so the same list can be turned into path strings or file contents.
macro_rules! feature_source_paths {
    ($callback:ident) => (
        $callback!("sources/x86_64/features.rs", "sources/aarch64/features.rs",)
    );
}
/// Invokes the callback macro `$callback` with the relative paths of every family snapshot
/// file. The order is significant: `family_map` reads these in sequence, and a later
/// declaration of the same family name replaces an earlier one.
macro_rules! family_source_paths {
    ($callback:ident) => (
        $callback!(
            // x86_64 families
            "sources/x86_64/families/generic.rs",
            "sources/x86_64/families/amd.rs",
            "sources/x86_64/families/intel.rs",
            // aarch64 families
            "sources/aarch64/families/ampere.rs",
            "sources/aarch64/families/apple.rs",
            "sources/aarch64/families/arm.rs",
            "sources/aarch64/families/fujitsu.rs",
            "sources/aarch64/families/generic.rs",
            "sources/aarch64/families/hi_silicon.rs",
            "sources/aarch64/families/marvell.rs",
            "sources/aarch64/families/nvidia.rs",
            "sources/aarch64/families/qualcomm.rs",
            "sources/aarch64/families/samsung.rs",
        )
    );
}
/// Relative paths of the feature snapshot files (test-only); compared against the
/// `.rs` files found under `src/sources` to make sure none is left unlisted.
#[cfg(test)]
const FEATURE_SOURCE_PATHS: &[&str] = feature_source_paths!(as_string_slice);
/// Relative paths of the family snapshot files (test-only), used for the same check.
#[cfg(test)]
const FAMILY_SOURCE_PATHS: &[&str] = family_source_paths!(as_string_slice);
/// Contents of the feature snapshot files, embedded at compile time with `include_str!`.
const FEATURE_SOURCES: &[&str] = feature_source_paths!(as_include_str_slice);
/// Contents of the family snapshot files, embedded at compile time with `include_str!`,
/// in the same order as `family_source_paths!` lists them.
const FAMILY_SOURCES: &[&str] = family_source_paths!(as_include_str_slice);
fn parse_string_literals(tokens: proc_macro2::TokenStream) -> syn::Result<Vec<String>> {
let parser = Punctuated::<LitStr, Token![,]>::parse_terminated;
parser
.parse2(tokens)
.map(|items| items.into_iter().map(|literal| literal.value()).collect())
}
fn collect_macro_literals(source: &str, macro_name: &str) -> syn::Result<Vec<Vec<String>>> {
let file = syn::parse_file(source)?;
let mut invocations = Vec::new();
for item in file.items {
if let Item::Macro(item_macro) = item
&& item_macro.mac.path.is_ident(macro_name)
{
invocations.push(parse_string_literals(item_macro.mac.tokens)?);
}
}
Ok(invocations)
}
/// Builds a map from the name used in family declarations to the canonical target-feature
/// name, from the `x86_64_feature!`/`aarch64_feature!` invocations in the feature sources.
///
/// A one-literal invocation `("f")` maps `f` to itself; a two-literal invocation
/// `("canonical", "suffix")` maps `suffix` to `canonical`. Invocations with any other
/// number of literals are ignored. If the same key appears more than once, the last
/// entry visited wins.
///
/// # Errors
///
/// Propagates any parse error from [`collect_macro_literals`].
fn feature_aliases() -> syn::Result<BTreeMap<String, String>> {
    let macro_names = ["x86_64_feature", "aarch64_feature"];
    let mut by_alias = BTreeMap::new();
    // Visit every (source, macro) pair in the same order as a nested loop would.
    let pairs = FEATURE_SOURCES
        .iter()
        .flat_map(|source| macro_names.into_iter().map(move |name| (*source, name)));
    for (source, macro_name) in pairs {
        for literals in collect_macro_literals(source, macro_name)? {
            let (alias, canonical) = match literals.as_slice() {
                [canonical] => (canonical, canonical),
                [canonical, alias] => (alias, canonical),
                _ => continue,
            };
            by_alias.insert(alias.clone(), canonical.clone());
        }
    }
    Ok(by_alias)
}
/// Builds a map from each family name declared with `declare_is_compatible!` in the family
/// sources to its list of target features.
///
/// The first literal of each invocation is the family name and the rest are feature names.
/// Each feature name is translated through [`feature_aliases`]; names without an alias
/// are kept as written. Empty invocations are skipped. A family declared more than once
/// keeps the declaration visited last.
///
/// # Errors
///
/// Propagates any parse error from [`feature_aliases`] or [`collect_macro_literals`].
fn family_map() -> syn::Result<BTreeMap<String, Vec<String>>> {
    let aliases = feature_aliases()?;
    let resolve = |feature: &String| aliases.get(feature).unwrap_or(feature).clone();
    let mut families = BTreeMap::new();
    for source in FAMILY_SOURCES {
        for literals in collect_macro_literals(source, "declare_is_compatible")? {
            if let Some((family, feature_names)) = literals.split_first() {
                families.insert(family.clone(), feature_names.iter().map(&resolve).collect());
            }
        }
    }
    Ok(families)
}
/// Parsed argument of `#[target_family(...)]`: the family name as a string literal.
///
/// Accepted forms are `"family"`, `name = "family"`, and `= "family"`.
struct TargetFamilyAttr {
    family: LitStr,
}
impl syn::parse::Parse for TargetFamilyAttr {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        // The three peeks look at different token kinds, so at most one branch can match.
        let family = if input.peek(LitStr) {
            input.parse()?
        } else if input.peek(Token![=]) {
            let _: Token![=] = input.parse()?;
            input.parse()?
        } else if input.peek(syn::Ident) {
            let key: syn::Ident = input.parse()?;
            if key != "name" {
                return Err(syn::Error::new_spanned(
                    key,
                    "expected `name = \"family\"` or a string literal",
                ));
            }
            let _: Token![=] = input.parse()?;
            input.parse()?
        } else {
            return Err(input.error("expected `\"family\"` or `name = \"family\"`"));
        };
        Ok(Self { family })
    }
}
/// Attribute macro that enables a named CPU family's target features on an `unsafe fn`.
///
/// Usage: `#[target_family("family")]` or `#[target_family(name = "family")]`. The function
/// is re-emitted with `#[target_feature(enable = "...")]` listing that family's features.
/// Any error is turned into a `compile_error!` at the reported span.
#[proc_macro_attribute]
pub fn target_family(attr: TokenStream, item: TokenStream) -> TokenStream {
    let expansion = match expand_target_family(attr.into(), item.into()) {
        Ok(tokens) => tokens,
        Err(error) => error.into_compile_error(),
    };
    expansion.into()
}
/// Expands `#[target_family(...)]` on `item` into `#[target_feature(enable = "...")]`.
///
/// The family named in `attr` is looked up in [`family_map`]. Its canonical feature names
/// are joined with `,` into a single `enable` list. If the family adds no features, the
/// function is emitted unchanged, because rustc rejects `enable = ""`.
///
/// # Errors
///
/// Returns an error if:
/// - `attr` is not an accepted family argument;
/// - `item` is not a function;
/// - the embedded snapshot sources fail to parse;
/// - the family is unknown (the message lists the known families);
/// - the function is not declared `unsafe`.
fn expand_target_family(
    attr: proc_macro2::TokenStream,
    item: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
    let TargetFamilyAttr { family } = syn::parse2(attr)?;
    let item_fn: syn::ItemFn = syn::parse2(item)?;
    let family_name = family.value();
    let families = family_map()?;
    let Some(features) = families.get(&family_name) else {
        // Listing the valid names turns a typo into a self-explanatory diagnostic.
        let known = families
            .keys()
            .map(String::as_str)
            .collect::<Vec<_>>()
            .join(", ");
        return Err(syn::Error::new(
            family.span(),
            format!("unknown target family `{family_name}`; known families: {known}"),
        ));
    };
    if item_fn.sig.unsafety.is_none() {
        return Err(syn::Error::new_spanned(
            item_fn.sig.fn_token,
            "target_family expands to #[target_feature], which must be applied to an unsafe fn",
        ));
    }
    // `"".split(',')` yields one empty feature name, which rustc reports as
    // "the feature named `` is not valid for this target". A family that adds nothing on
    // top of the baseline therefore leaves the function untouched.
    if features.is_empty() {
        return Ok(quote::quote!(#item_fn));
    }
    let feature_list = LitStr::new(&features.join(","), proc_macro2::Span::call_site());
    Ok(quote::quote! {
        #[target_feature(enable = #feature_list)]
        #item_fn
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::BTreeSet,
        fs,
        path::{Path, PathBuf},
    };

    /// Recursively gathers every `.rs` file below `root` and returns each path relative
    /// to `root`, with `\` normalised to `/` so results compare equal on every platform.
    fn collect_rs_files(root: &Path) -> std::io::Result<BTreeSet<String>> {
        let mut found = BTreeSet::new();
        let mut stack: Vec<PathBuf> = vec![root.to_path_buf()];
        while let Some(current) = stack.pop() {
            for entry in fs::read_dir(&current)? {
                let entry = entry?;
                let entry_path = entry.path();
                if entry.file_type()?.is_dir() {
                    stack.push(entry_path);
                } else if entry_path.extension().is_some_and(|ext| ext == "rs") {
                    let relative = entry_path.strip_prefix(root).expect("path is under root");
                    found.insert(relative.to_string_lossy().replace('\\', "/"));
                }
            }
        }
        Ok(found)
    }

    /// Token stream for the `unsafe fn accelerated() {}` fixture shared by several tests.
    fn unsafe_accelerated() -> proc_macro2::TokenStream {
        quote::quote!(
            unsafe fn accelerated() {}
        )
    }

    #[test]
    fn resolves_x86_64_v3_to_canonical_target_features() {
        let families = family_map().expect("family map parses");
        let features = families.get("x86_64_v3").expect("x86_64_v3 exists");
        let has = |name: &str| features.iter().any(|feature| feature == name);
        for canonical in ["avx", "avx2", "bmi1", "bmi2", "sse4.1", "sse4.2"] {
            assert!(has(canonical));
        }
        for alias in ["sse4_1", "sse4_2"] {
            assert!(!has(alias));
        }
    }

    #[test]
    fn unknown_family_is_absent() {
        let families = family_map().expect("family map parses");
        assert!(!families.contains_key("not_a_real_family"));
    }

    #[test]
    fn expands_requested_family_to_target_feature_attribute() {
        let expanded = expand_target_family(quote::quote!("x86_64_v3"), unsafe_accelerated())
            .expect("expansion succeeds")
            .to_string();
        for needle in ["target_feature", "enable", "avx,avx2", "sse4.1"] {
            assert!(expanded.contains(needle));
        }
    }

    #[test]
    fn rejects_safe_functions_with_clear_error() {
        let safe_fn = quote::quote!(
            fn accelerated() {}
        );
        let err = expand_target_family(quote::quote!("x86_64_v3"), safe_fn)
            .expect_err("safe functions are rejected");
        assert!(err.to_string().contains("unsafe fn"));
    }

    #[test]
    fn rejects_unknown_families_with_clear_error() {
        let err = expand_target_family(quote::quote!("not_a_real_family"), unsafe_accelerated())
            .expect_err("unknown families are rejected");
        assert!(err.to_string().contains("unknown target family"));
    }

    #[test]
    fn rejects_trailing_attribute_tokens() {
        let attr = quote::quote!("x86_64_v3", "x86_64_v4");
        let err = expand_target_family(attr, unsafe_accelerated())
            .expect_err("trailing tokens are rejected");
        assert!(err.to_string().contains("unexpected token"));
    }

    #[test]
    fn all_snapshot_files_are_listed_in_source_constants() {
        let snapshot_root = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/sources");
        let mut listed_paths = BTreeSet::new();
        for path in FEATURE_SOURCE_PATHS.iter().chain(FAMILY_SOURCE_PATHS) {
            let relative = path
                .strip_prefix("sources/")
                .expect("all source constants stay under sources/");
            listed_paths.insert(relative.to_owned());
        }
        let on_disk = collect_rs_files(&snapshot_root).expect("snapshot tree is readable");
        assert_eq!(listed_paths, on_disk);
    }
}