Skip to main content

cpudetect_macros/
lib.rs

1use std::collections::BTreeMap;
2
3use proc_macro::TokenStream;
4use syn::{Item, LitStr, Token, parse::Parser, punctuated::Punctuated};
5
/// Turns a comma-separated list of string literals into a `&[&str]` of those
/// same literals. Compiled only for tests, where the raw path lists are needed.
#[cfg(test)]
macro_rules! as_string_slice {
    ($($literal:literal),* $(,)?) => {
        &[$($literal),*]
    };
}
12
/// Turns a comma-separated list of file-path literals into a `&[&str]` whose
/// elements are the contents of those files, embedded via `include_str!`.
macro_rules! as_include_str_slice {
    ($($file:literal),* $(,)?) => {
        &[$(include_str!($file)),*]
    };
}
18
/// Invokes `$sink!` with the snapshot files that declare CPU features.
///
/// The list lives in one place and `$sink` decides what to build from it,
/// e.g. the literal paths or the embedded file contents.
macro_rules! feature_source_paths {
    ($sink:ident) => {
        $sink!(
            "sources/x86_64/features.rs",
            "sources/aarch64/features.rs",
        )
    };
}
24
/// Invokes `$sink!` with the snapshot files that declare CPU families.
///
/// Order is significant: when these files are processed in sequence, a family
/// declared again in a later file replaces the earlier declaration.
macro_rules! family_source_paths {
    ($sink:ident) => {
        $sink!(
            // x86_64
            "sources/x86_64/families/generic.rs",
            "sources/x86_64/families/amd.rs",
            "sources/x86_64/families/intel.rs",
            // aarch64
            "sources/aarch64/families/ampere.rs",
            "sources/aarch64/families/apple.rs",
            "sources/aarch64/families/arm.rs",
            "sources/aarch64/families/fujitsu.rs",
            "sources/aarch64/families/generic.rs",
            "sources/aarch64/families/hi_silicon.rs",
            "sources/aarch64/families/marvell.rs",
            "sources/aarch64/families/nvidia.rs",
            "sources/aarch64/families/qualcomm.rs",
            "sources/aarch64/families/samsung.rs",
        )
    };
}
44
45#[cfg(test)]
46const FEATURE_SOURCE_PATHS: &[&str] = feature_source_paths!(as_string_slice);
47#[cfg(test)]
48const FAMILY_SOURCE_PATHS: &[&str] = family_source_paths!(as_string_slice);
49
50const FEATURE_SOURCES: &[&str] = feature_source_paths!(as_include_str_slice);
51const FAMILY_SOURCES: &[&str] = family_source_paths!(as_include_str_slice);
52
53fn parse_string_literals(tokens: proc_macro2::TokenStream) -> syn::Result<Vec<String>> {
54    let parser = Punctuated::<LitStr, Token![,]>::parse_terminated;
55    parser
56        .parse2(tokens)
57        .map(|items| items.into_iter().map(|literal| literal.value()).collect())
58}
59
60fn collect_macro_literals(source: &str, macro_name: &str) -> syn::Result<Vec<Vec<String>>> {
61    let file = syn::parse_file(source)?;
62    let mut invocations = Vec::new();
63
64    for item in file.items {
65        if let Item::Macro(item_macro) = item
66            && item_macro.mac.path.is_ident(macro_name)
67        {
68            invocations.push(parse_string_literals(item_macro.mac.tokens)?);
69        }
70    }
71
72    Ok(invocations)
73}
74
75fn feature_aliases() -> syn::Result<BTreeMap<String, String>> {
76    let mut aliases = BTreeMap::new();
77
78    for source in FEATURE_SOURCES {
79        for macro_name in ["x86_64_feature", "aarch64_feature"] {
80            for literals in collect_macro_literals(source, macro_name)? {
81                match literals.as_slice() {
82                    [target_feature] => {
83                        aliases.insert(target_feature.clone(), target_feature.clone());
84                    }
85                    [target_feature, function_suffix] => {
86                        aliases.insert(function_suffix.clone(), target_feature.clone());
87                    }
88                    _ => {}
89                }
90            }
91        }
92    }
93
94    Ok(aliases)
95}
96
97fn family_map() -> syn::Result<BTreeMap<String, Vec<String>>> {
98    let aliases = feature_aliases()?;
99    let mut families = BTreeMap::new();
100
101    for source in FAMILY_SOURCES {
102        for literals in collect_macro_literals(source, "declare_is_compatible")? {
103            let Some((family, feature_suffixes)) = literals.split_first() else {
104                continue;
105            };
106
107            let features = feature_suffixes
108                .iter()
109                .map(|feature| {
110                    aliases
111                        .get(feature)
112                        .cloned()
113                        .unwrap_or_else(|| feature.clone())
114                })
115                .collect();
116
117            families.insert(family.clone(), features);
118        }
119    }
120
121    Ok(families)
122}
123
124struct TargetFamilyAttr {
125    family: LitStr,
126}
127
128impl syn::parse::Parse for TargetFamilyAttr {
129    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
130        if input.peek(Token![=]) {
131            input.parse::<Token![=]>()?;
132
133            return Ok(Self {
134                family: input.parse()?,
135            });
136        }
137
138        if input.peek(LitStr) {
139            return Ok(Self {
140                family: input.parse()?,
141            });
142        }
143
144        if input.peek(syn::Ident) {
145            let name = input.parse::<syn::Ident>()?;
146            if name != "name" {
147                return Err(syn::Error::new_spanned(
148                    name,
149                    "expected `name = \"family\"` or a string literal",
150                ));
151            }
152            input.parse::<Token![=]>()?;
153
154            return Ok(Self {
155                family: input.parse()?,
156            });
157        }
158
159        Err(input.error("expected `\"family\"` or `name = \"family\"`"))
160    }
161}
162
163#[proc_macro_attribute]
164pub fn target_family(attr: TokenStream, item: TokenStream) -> TokenStream {
165    expand_target_family(attr.into(), item.into())
166        .unwrap_or_else(syn::Error::into_compile_error)
167        .into()
168}
169
170fn expand_target_family(
171    attr: proc_macro2::TokenStream,
172    item: proc_macro2::TokenStream,
173) -> syn::Result<proc_macro2::TokenStream> {
174    let TargetFamilyAttr { family } = syn::parse2(attr)?;
175    let item_fn: syn::ItemFn = syn::parse2(item)?;
176    let family_name = family.value();
177    let families = family_map()?;
178    let Some(features) = families.get(&family_name) else {
179        return Err(syn::Error::new(
180            family.span(),
181            format!("unknown target family `{family_name}`"),
182        ));
183    };
184
185    if item_fn.sig.unsafety.is_none() {
186        return Err(syn::Error::new_spanned(
187            item_fn.sig.fn_token,
188            "target_family expands to #[target_feature], which must be applied to an unsafe fn",
189        ));
190    }
191
192    let feature_list = LitStr::new(&features.join(","), proc_macro2::Span::call_site());
193
194    Ok(quote::quote! {
195        #[target_feature(enable = #feature_list)]
196        #item_fn
197    })
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeSet;
    use std::fs;
    use std::path::Path;

    /// Walks `root` depth-first and returns every `.rs` file beneath it as a
    /// `/`-separated path relative to `root`.
    fn collect_rs_files(root: &Path) -> std::io::Result<BTreeSet<String>> {
        let mut found = BTreeSet::new();
        let mut to_visit = vec![root.to_path_buf()];

        while let Some(directory) = to_visit.pop() {
            for dir_entry in fs::read_dir(&directory)? {
                let dir_entry = dir_entry?;
                let entry_path = dir_entry.path();

                if dir_entry.file_type()?.is_dir() {
                    to_visit.push(entry_path);
                } else if entry_path.extension().is_some_and(|ext| ext == "rs") {
                    let relative = entry_path
                        .strip_prefix(root)
                        .expect("path is under root");
                    // Normalise Windows separators so entries compare equal to
                    // the `/`-separated source constants.
                    found.insert(relative.to_string_lossy().replace('\\', "/"));
                }
            }
        }

        Ok(found)
    }

    /// Runs the expansion, requires it to fail, and returns the error text.
    fn expansion_error(
        attr: proc_macro2::TokenStream,
        item: proc_macro2::TokenStream,
        reason: &str,
    ) -> String {
        expand_target_family(attr, item).expect_err(reason).to_string()
    }

    #[test]
    fn resolves_x86_64_v3_to_canonical_target_features() {
        let families = family_map().expect("family map parses");
        let features = families.get("x86_64_v3").expect("x86_64_v3 exists");
        let enables = |name: &str| features.iter().any(|feature| feature == name);

        // Canonical rustc spellings must be present...
        for canonical in ["avx", "avx2", "bmi1", "bmi2", "sse4.1", "sse4.2"] {
            assert!(enables(canonical));
        }
        // ...and the underscore function-suffix aliases must be translated away.
        for alias in ["sse4_1", "sse4_2"] {
            assert!(!enables(alias));
        }
    }

    #[test]
    fn unknown_family_is_absent() {
        let families = family_map().expect("family map parses");
        assert!(!families.contains_key("not_a_real_family"));
    }

    #[test]
    fn expands_requested_family_to_target_feature_attribute() {
        let rendered = expand_target_family(
            quote::quote!("x86_64_v3"),
            quote::quote!(unsafe fn accelerated() {}),
        )
        .expect("expansion succeeds")
        .to_string();

        for fragment in ["target_feature", "enable", "avx,avx2", "sse4.1"] {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn rejects_safe_functions_with_clear_error() {
        let message = expansion_error(
            quote::quote!("x86_64_v3"),
            quote::quote!(fn accelerated() {}),
            "safe functions are rejected",
        );
        assert!(message.contains("unsafe fn"));
    }

    #[test]
    fn rejects_unknown_families_with_clear_error() {
        let message = expansion_error(
            quote::quote!("not_a_real_family"),
            quote::quote!(unsafe fn accelerated() {}),
            "unknown families are rejected",
        );
        assert!(message.contains("unknown target family"));
    }

    #[test]
    fn rejects_trailing_attribute_tokens() {
        let message = expansion_error(
            quote::quote!("x86_64_v3", "x86_64_v4"),
            quote::quote!(unsafe fn accelerated() {}),
            "trailing tokens are rejected",
        );
        assert!(message.contains("unexpected token"));
    }

    #[test]
    fn all_snapshot_files_are_listed_in_source_constants() {
        let snapshot_root = Path::new(env!("CARGO_MANIFEST_DIR")).join("src/sources");

        let mut listed_paths = BTreeSet::new();
        for path in FEATURE_SOURCE_PATHS.iter().chain(FAMILY_SOURCE_PATHS) {
            let relative = path
                .strip_prefix("sources/")
                .expect("all source constants stay under sources/");
            listed_paths.insert(relative.to_owned());
        }

        let on_disk = collect_rs_files(&snapshot_root).expect("snapshot tree is readable");
        assert_eq!(listed_paths, on_disk);
    }
}