//! cpudetect-macros 0.2.0
//!
//! Procedural macros for cpudetect.
use std::collections::BTreeMap;

use proc_macro::TokenStream;
use syn::{Item, LitStr, Token, parse::Parser, punctuated::Punctuated};

// Turns a comma-separated list of string literals into a `&[&str]` slice of
// those literals, in order. Only compiled for tests, which compare the listed
// source paths against the files that actually exist on disk.
#[cfg(test)]
macro_rules! as_string_slice {
    ($($literal:literal),* $(,)?) => {
        &[$($literal),*]
    };
}

// Turns a comma-separated list of path literals into a `&[&str]` slice holding
// each file's contents. The files are read at compile time by `include_str!`,
// which resolves every path relative to this source file.
macro_rules! as_include_str_slice {
    ($($file:literal),* $(,)?) => {
        &[$(include_str!($file)),*]
    };
}

// Invokes `$consumer!` with the path of every feature-definition snapshot.
// The list is written once here and can be expanded either as plain strings
// (`as_string_slice`) or as embedded file contents (`as_include_str_slice`).
macro_rules! feature_source_paths {
    ($consumer:ident) => {
        $consumer!(
            "sources/x86_64/features.rs",
            "sources/aarch64/features.rs",
        )
    };
}

// Invokes `$consumer!` with the path of every CPU-family snapshot.
// The order matters: the snapshots are processed in this order, and a family
// name declared twice keeps the declaration that comes last.
macro_rules! family_source_paths {
    ($consumer:ident) => {
        $consumer!(
            // x86_64
            "sources/x86_64/families/generic.rs",
            "sources/x86_64/families/amd.rs",
            "sources/x86_64/families/intel.rs",
            // aarch64
            "sources/aarch64/families/ampere.rs",
            "sources/aarch64/families/apple.rs",
            "sources/aarch64/families/arm.rs",
            "sources/aarch64/families/fujitsu.rs",
            "sources/aarch64/families/generic.rs",
            "sources/aarch64/families/hi_silicon.rs",
            "sources/aarch64/families/marvell.rs",
            "sources/aarch64/families/nvidia.rs",
            "sources/aarch64/families/qualcomm.rs",
            "sources/aarch64/families/samsung.rs",
        )
    };
}

// Contents of every feature-definition snapshot. `include_str!` embeds them
// when this crate is compiled, so expansion never reads the filesystem.
const FEATURE_SOURCES: &[&str] = feature_source_paths!(as_include_str_slice);
// Contents of every CPU-family snapshot, embedded the same way.
const FAMILY_SOURCES: &[&str] = family_source_paths!(as_include_str_slice);

// The same two path lists as plain strings. Tests compare them with the
// `.rs` files actually present under `src/sources`.
#[cfg(test)]
const FEATURE_SOURCE_PATHS: &[&str] = feature_source_paths!(as_string_slice);
#[cfg(test)]
const FAMILY_SOURCE_PATHS: &[&str] = family_source_paths!(as_string_slice);

fn parse_string_literals(tokens: proc_macro2::TokenStream) -> syn::Result<Vec<String>> {
    let parser = Punctuated::<LitStr, Token![,]>::parse_terminated;
    parser
        .parse2(tokens)
        .map(|items| items.into_iter().map(|literal| literal.value()).collect())
}

/// Collects the string-literal arguments of every `macro_name!(...)`
/// invocation at the top level of `source`.
///
/// Only the file's root items are examined. Invocations nested inside modules,
/// functions or other items are not visited. Results come back in source
/// order, one `Vec<String>` per invocation.
///
/// # Errors
///
/// Fails if `source` does not parse as a Rust file. Also fails at the first
/// matching invocation whose arguments are not comma-separated string literals.
fn collect_macro_literals(source: &str, macro_name: &str) -> syn::Result<Vec<Vec<String>>> {
    syn::parse_file(source)?
        .items
        .into_iter()
        .filter_map(|item| match item {
            Item::Macro(invocation) if invocation.mac.path.is_ident(macro_name) => {
                Some(parse_string_literals(invocation.mac.tokens))
            }
            _ => None,
        })
        // Collecting into `Result` stops at the first parse error.
        .collect()
}

/// Builds the table that translates the feature names used in family
/// declarations into canonical `#[target_feature]` names.
///
/// Every `x86_64_feature!` / `aarch64_feature!` invocation in the feature
/// snapshots contributes at most one entry:
///
/// * `("name")` maps `name` to itself;
/// * `("target_feature", "suffix")` maps `suffix` to `target_feature`. An
///   example is a function-suffix spelling such as `sse4_1` for `sse4.1`.
///
/// Invocations with any other number of literals are ignored. Both
/// architectures share one table, and a later entry overwrites an earlier one
/// with the same key.
///
/// # Errors
///
/// Propagates any parse error from the feature snapshots.
fn feature_aliases() -> syn::Result<BTreeMap<String, String>> {
    let mut table = BTreeMap::new();

    for &source in FEATURE_SOURCES {
        for macro_name in ["x86_64_feature", "aarch64_feature"] {
            for literals in collect_macro_literals(source, macro_name)? {
                let (key, canonical) = match literals.as_slice() {
                    [feature] => (feature, feature),
                    [feature, suffix] => (suffix, feature),
                    _ => continue,
                };
                table.insert(key.clone(), canonical.clone());
            }
        }
    }

    Ok(table)
}

/// Maps every declared CPU family name to its list of canonical target
/// features.
///
/// Each `declare_is_compatible!("family", "feature", ...)` invocation in the
/// family snapshots yields one entry. Every feature name is translated through
/// [`feature_aliases`], and names with no alias are kept verbatim.
/// Invocations without any literal are skipped. A family declared more than
/// once keeps its last declaration.
///
/// # Errors
///
/// Propagates any parse error from the feature or family snapshots.
fn family_map() -> syn::Result<BTreeMap<String, Vec<String>>> {
    let aliases = feature_aliases()?;
    let canonical = |name: &String| aliases.get(name).unwrap_or(name).clone();

    let mut families = BTreeMap::new();
    for &source in FAMILY_SOURCES {
        for declaration in collect_macro_literals(source, "declare_is_compatible")? {
            if let Some((family, feature_names)) = declaration.split_first() {
                let features: Vec<String> = feature_names.iter().map(canonical).collect();
                families.insert(family.clone(), features);
            }
        }
    }
    Ok(families)
}

/// Parsed arguments of `#[target_family(...)]`, which is the family-name literal.
///
/// Accepted spellings are `"family"`, `name = "family"` and `= "family"`.
/// This parser does not consume anything after the literal. `syn::parse2`
/// reports leftover tokens as unexpected.
struct TargetFamilyAttr {
    family: LitStr,
}

impl syn::parse::Parse for TargetFamilyAttr {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        // Bare literal: `"x86_64_v3"`.
        if input.peek(LitStr) {
            let family = input.parse()?;
            return Ok(Self { family });
        }

        // Leading `=`: `= "x86_64_v3"`.
        if input.peek(Token![=]) {
            input.parse::<Token![=]>()?;
            let family = input.parse()?;
            return Ok(Self { family });
        }

        // Key/value form: `name = "x86_64_v3"`. Any other key is rejected.
        if !input.peek(syn::Ident) {
            return Err(input.error("expected `\"family\"` or `name = \"family\"`"));
        }
        let key: syn::Ident = input.parse()?;
        if key != "name" {
            return Err(syn::Error::new_spanned(
                key,
                "expected `name = \"family\"` or a string literal",
            ));
        }
        input.parse::<Token![=]>()?;
        let family = input.parse()?;
        Ok(Self { family })
    }
}

#[proc_macro_attribute]
pub fn target_family(attr: TokenStream, item: TokenStream) -> TokenStream {
    expand_target_family(attr.into(), item.into())
        .unwrap_or_else(syn::Error::into_compile_error)
        .into()
}

/// Expands `#[target_family(...)]` on `item` into a `#[target_feature]`
/// attribute that enables every feature of the requested family.
///
/// If the family declares no features, the function is returned unchanged,
/// because there is nothing to enable beyond the baseline.
///
/// # Errors
///
/// * the attribute arguments do not parse as [`TargetFamilyAttr`];
/// * `item` is not a function;
/// * the feature or family snapshots fail to parse;
/// * the family name is unknown (the error spans the family literal);
/// * the function is not `unsafe` (the error spans the `fn` keyword).
fn expand_target_family(
    attr: proc_macro2::TokenStream,
    item: proc_macro2::TokenStream,
) -> syn::Result<proc_macro2::TokenStream> {
    let TargetFamilyAttr { family } = syn::parse2(attr)?;
    let item_fn: syn::ItemFn = syn::parse2(item)?;
    let family_name = family.value();
    let families = family_map()?;
    let Some(features) = families.get(&family_name) else {
        return Err(syn::Error::new(
            family.span(),
            format!("unknown target family `{family_name}`"),
        ));
    };

    // NOTE(review): since Rust 1.86 `#[target_feature]` is also accepted on
    // safe functions. Requiring `unsafe fn` here is therefore this crate's
    // own policy, not a language rule. Confirm the intent before relaxing it
    // or rewording the message.
    if item_fn.sig.unsafety.is_none() {
        return Err(syn::Error::new_spanned(
            item_fn.sig.fn_token,
            "target_family expands to #[target_feature], which must be applied to an unsafe fn",
        ));
    }

    // A family with no features needs nothing beyond the baseline. Emitting
    // `#[target_feature(enable = "")]` would be rejected by rustc as an
    // invalid (empty) feature name, so leave the function untouched.
    if features.is_empty() {
        return Ok(quote::quote!(#item_fn));
    }

    // Give the generated literal the span of the user's family name. A rustc
    // diagnostic about an unsupported feature then points at the attribute
    // argument instead of an opaque macro call site.
    let feature_list = LitStr::new(&features.join(","), family.span());

    Ok(quote::quote! {
        #[target_feature(enable = #feature_list)]
        #item_fn
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::BTreeSet,
        fs,
        path::{Path, PathBuf},
    };

    /// Recursively lists every `.rs` file under `root`. Paths are relative to
    /// `root` and use `/` as the separator on every platform.
    fn collect_rs_files(root: &Path) -> std::io::Result<BTreeSet<String>> {
        let mut found = BTreeSet::new();
        let mut stack = vec![root.to_path_buf()];

        while let Some(directory) = stack.pop() {
            for entry in fs::read_dir(&directory)? {
                let entry = entry?;
                let path = entry.path();

                if entry.file_type()?.is_dir() {
                    stack.push(path);
                } else if path.extension().is_some_and(|ext| ext == "rs") {
                    let relative = path.strip_prefix(root).expect("path is under root");
                    found.insert(relative.to_string_lossy().replace('\\', "/"));
                }
            }
        }

        Ok(found)
    }

    /// Returns whether `features` contains exactly `name`.
    fn has_feature(features: &[String], name: &str) -> bool {
        features.iter().any(|feature| feature == name)
    }

    /// Applies `#[target_family(<attr>)]` to an empty `unsafe fn accelerated`.
    fn expand_on_unsafe_fn(
        attr: proc_macro2::TokenStream,
    ) -> syn::Result<proc_macro2::TokenStream> {
        expand_target_family(attr, quote::quote!(unsafe fn accelerated() {}))
    }

    #[test]
    fn resolves_x86_64_v3_to_canonical_target_features() {
        let families = family_map().expect("family map parses");
        let features = families.get("x86_64_v3").expect("x86_64_v3 exists");

        for expected in ["avx", "avx2", "bmi1", "bmi2", "sse4.1", "sse4.2"] {
            assert!(has_feature(features, expected), "missing `{expected}`");
        }
        for alias in ["sse4_1", "sse4_2"] {
            assert!(!has_feature(features, alias), "alias `{alias}` was not canonicalized");
        }
    }

    #[test]
    fn unknown_family_is_absent() {
        let families = family_map().expect("family map parses");
        assert!(!families.contains_key("not_a_real_family"));
    }

    #[test]
    fn expands_requested_family_to_target_feature_attribute() {
        let expanded = expand_on_unsafe_fn(quote::quote!("x86_64_v3"))
            .expect("expansion succeeds")
            .to_string();

        for fragment in ["target_feature", "enable", "avx,avx2", "sse4.1"] {
            assert!(expanded.contains(fragment), "`{fragment}` not in {expanded}");
        }
    }

    #[test]
    fn rejects_safe_functions_with_clear_error() {
        let err = expand_target_family(
            quote::quote!("x86_64_v3"),
            quote::quote!(fn accelerated() {}),
        )
        .expect_err("safe functions are rejected");

        assert!(err.to_string().contains("unsafe fn"));
    }

    #[test]
    fn rejects_unknown_families_with_clear_error() {
        let err = expand_on_unsafe_fn(quote::quote!("not_a_real_family"))
            .expect_err("unknown families are rejected");

        assert!(err.to_string().contains("unknown target family"));
    }

    #[test]
    fn rejects_trailing_attribute_tokens() {
        let err = expand_on_unsafe_fn(quote::quote!("x86_64_v3", "x86_64_v4"))
            .expect_err("trailing tokens are rejected");

        assert!(err.to_string().contains("unexpected token"));
    }

    #[test]
    fn all_snapshot_files_are_listed_in_source_constants() {
        let snapshot_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("src/sources");
        let listed_paths: BTreeSet<String> = FEATURE_SOURCE_PATHS
            .iter()
            .chain(FAMILY_SOURCE_PATHS)
            .map(|path| {
                path.strip_prefix("sources/")
                    .expect("all source constants stay under sources/")
                    .to_owned()
            })
            .collect();
        let on_disk = collect_rs_files(&snapshot_root).expect("snapshot tree is readable");

        assert_eq!(listed_paths, on_disk);
    }
}