1use std::collections::BTreeMap;
2
3use proc_macro::TokenStream;
4use syn::{Item, LitStr, Token, parse::Parser, punctuated::Punctuated};
5
/// Turns a comma-separated list of string literals into a `&[&str]` slice
/// expression (test-only: lets the snapshot path lists be inspected as data).
#[cfg(test)]
macro_rules! as_string_slice {
    ($($lit:literal),* $(,)?) => {
        &[$($lit,)*]
    };
}
12
/// Turns a comma-separated list of file-path literals into a `&[&str]` slice
/// whose elements are those files' contents, embedded at compile time with
/// `include_str!` (paths resolve relative to the invoking source file).
macro_rules! as_include_str_slice {
    ($($file:literal),* $(,)?) => {
        &[$(include_str!($file),)*]
    };
}
18
/// Calls the callback macro `$sink!` with the path of every feature snapshot,
/// so the single list can be expanded into path strings or file contents.
macro_rules! feature_source_paths {
    ($sink:ident) => {
        $sink!(
            "sources/x86_64/features.rs",
            "sources/aarch64/features.rs",
        )
    };
}
24
/// Calls the callback macro `$sink!` with the path of every CPU-family
/// snapshot. The order is significant: it is the order in which family
/// declarations are read, so a later file wins on duplicate family names.
macro_rules! family_source_paths {
    ($sink:ident) => {
        $sink!(
            // x86_64
            "sources/x86_64/families/generic.rs",
            "sources/x86_64/families/amd.rs",
            "sources/x86_64/families/intel.rs",
            // aarch64
            "sources/aarch64/families/ampere.rs",
            "sources/aarch64/families/apple.rs",
            "sources/aarch64/families/arm.rs",
            "sources/aarch64/families/fujitsu.rs",
            "sources/aarch64/families/generic.rs",
            "sources/aarch64/families/hi_silicon.rs",
            "sources/aarch64/families/marvell.rs",
            "sources/aarch64/families/nvidia.rs",
            "sources/aarch64/families/qualcomm.rs",
            "sources/aarch64/families/samsung.rs",
        )
    };
}
44
/// Paths of the feature snapshot files, as passed to `include_str!`.
/// Test-only: the tests compare this list against the files on disk.
#[cfg(test)]
const FEATURE_SOURCE_PATHS: &[&str] = feature_source_paths!(as_string_slice);
/// Paths of the CPU-family snapshot files, as passed to `include_str!`.
/// Test-only: the tests compare this list against the files on disk.
#[cfg(test)]
const FAMILY_SOURCE_PATHS: &[&str] = family_source_paths!(as_string_slice);

/// Contents of the feature snapshot files, embedded at compile time so the
/// macro can parse them during expansion without reading the filesystem.
const FEATURE_SOURCES: &[&str] = feature_source_paths!(as_include_str_slice);
/// Contents of the CPU-family snapshot files, embedded at compile time, in
/// the same order as the path list.
const FAMILY_SOURCES: &[&str] = family_source_paths!(as_include_str_slice);
52
53fn parse_string_literals(tokens: proc_macro2::TokenStream) -> syn::Result<Vec<String>> {
54 let parser = Punctuated::<LitStr, Token![,]>::parse_terminated;
55 parser
56 .parse2(tokens)
57 .map(|items| items.into_iter().map(|literal| literal.value()).collect())
58}
59
60fn collect_macro_literals(source: &str, macro_name: &str) -> syn::Result<Vec<Vec<String>>> {
61 let file = syn::parse_file(source)?;
62 let mut invocations = Vec::new();
63
64 for item in file.items {
65 if let Item::Macro(item_macro) = item
66 && item_macro.mac.path.is_ident(macro_name)
67 {
68 invocations.push(parse_string_literals(item_macro.mac.tokens)?);
69 }
70 }
71
72 Ok(invocations)
73}
74
75fn feature_aliases() -> syn::Result<BTreeMap<String, String>> {
76 let mut aliases = BTreeMap::new();
77
78 for source in FEATURE_SOURCES {
79 for macro_name in ["x86_64_feature", "aarch64_feature"] {
80 for literals in collect_macro_literals(source, macro_name)? {
81 match literals.as_slice() {
82 [target_feature] => {
83 aliases.insert(target_feature.clone(), target_feature.clone());
84 }
85 [target_feature, function_suffix] => {
86 aliases.insert(function_suffix.clone(), target_feature.clone());
87 }
88 _ => {}
89 }
90 }
91 }
92 }
93
94 Ok(aliases)
95}
96
97fn family_map() -> syn::Result<BTreeMap<String, Vec<String>>> {
98 let aliases = feature_aliases()?;
99 let mut families = BTreeMap::new();
100
101 for source in FAMILY_SOURCES {
102 for literals in collect_macro_literals(source, "declare_is_compatible")? {
103 let Some((family, feature_suffixes)) = literals.split_first() else {
104 continue;
105 };
106
107 let features = feature_suffixes
108 .iter()
109 .map(|feature| {
110 aliases
111 .get(feature)
112 .cloned()
113 .unwrap_or_else(|| feature.clone())
114 })
115 .collect();
116
117 families.insert(family.clone(), features);
118 }
119 }
120
121 Ok(families)
122}
123
124struct TargetFamilyAttr {
125 family: LitStr,
126}
127
128impl syn::parse::Parse for TargetFamilyAttr {
129 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
130 if input.peek(Token![=]) {
131 input.parse::<Token![=]>()?;
132
133 return Ok(Self {
134 family: input.parse()?,
135 });
136 }
137
138 if input.peek(LitStr) {
139 return Ok(Self {
140 family: input.parse()?,
141 });
142 }
143
144 if input.peek(syn::Ident) {
145 let name = input.parse::<syn::Ident>()?;
146 if name != "name" {
147 return Err(syn::Error::new_spanned(
148 name,
149 "expected `name = \"family\"` or a string literal",
150 ));
151 }
152 input.parse::<Token![=]>()?;
153
154 return Ok(Self {
155 family: input.parse()?,
156 });
157 }
158
159 Err(input.error("expected `\"family\"` or `name = \"family\"`"))
160 }
161}
162
163#[proc_macro_attribute]
164pub fn target_family(attr: TokenStream, item: TokenStream) -> TokenStream {
165 expand_target_family(attr.into(), item.into())
166 .unwrap_or_else(syn::Error::into_compile_error)
167 .into()
168}
169
170fn expand_target_family(
171 attr: proc_macro2::TokenStream,
172 item: proc_macro2::TokenStream,
173) -> syn::Result<proc_macro2::TokenStream> {
174 let TargetFamilyAttr { family } = syn::parse2(attr)?;
175 let item_fn: syn::ItemFn = syn::parse2(item)?;
176 let family_name = family.value();
177 let families = family_map()?;
178 let Some(features) = families.get(&family_name) else {
179 return Err(syn::Error::new(
180 family.span(),
181 format!("unknown target family `{family_name}`"),
182 ));
183 };
184
185 if item_fn.sig.unsafety.is_none() {
186 return Err(syn::Error::new_spanned(
187 item_fn.sig.fn_token,
188 "target_family expands to #[target_feature], which must be applied to an unsafe fn",
189 ));
190 }
191
192 let feature_list = LitStr::new(&features.join(","), proc_macro2::Span::call_site());
193
194 Ok(quote::quote! {
195 #[target_feature(enable = #feature_list)]
196 #item_fn
197 })
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        collections::BTreeSet,
        fs,
        path::{Path, PathBuf},
    };

    /// Returns every `.rs` file beneath `root` (all subdirectories included)
    /// as a `/`-separated path relative to `root`.
    fn collect_rs_files(root: &Path) -> std::io::Result<BTreeSet<String>> {
        fn visit(root: &Path, dir: &Path, found: &mut BTreeSet<String>) -> std::io::Result<()> {
            for entry in fs::read_dir(dir)? {
                let entry = entry?;
                let path = entry.path();

                if entry.file_type()?.is_dir() {
                    visit(root, &path, found)?;
                } else if path.extension().is_some_and(|ext| ext == "rs") {
                    let relative = path.strip_prefix(root).expect("path is under root");
                    found.insert(relative.to_string_lossy().replace('\\', "/"));
                }
            }
            Ok(())
        }

        let mut found = BTreeSet::new();
        visit(root, root, &mut found)?;
        Ok(found)
    }

    /// The `unsafe fn` that most expansion tests decorate.
    fn unsafe_item() -> proc_macro2::TokenStream {
        quote::quote!(
            unsafe fn accelerated() {}
        )
    }

    #[test]
    fn resolves_x86_64_v3_to_canonical_target_features() {
        let families = family_map().expect("family map parses");
        let features = families.get("x86_64_v3").expect("x86_64_v3 exists");
        let has = |name: &str| features.iter().any(|feature| feature == name);

        for canonical in ["avx", "avx2", "bmi1", "bmi2", "sse4.1", "sse4.2"] {
            assert!(has(canonical));
        }
        // Alias spellings must have been rewritten to their canonical form.
        for alias in ["sse4_1", "sse4_2"] {
            assert!(!has(alias));
        }
    }

    #[test]
    fn unknown_family_is_absent() {
        assert!(
            !family_map()
                .expect("family map parses")
                .contains_key("not_a_real_family")
        );
    }

    #[test]
    fn expands_requested_family_to_target_feature_attribute() {
        let expanded = expand_target_family(quote::quote!("x86_64_v3"), unsafe_item())
            .expect("expansion succeeds")
            .to_string();

        for fragment in ["target_feature", "enable", "avx,avx2", "sse4.1"] {
            assert!(expanded.contains(fragment));
        }
    }

    #[test]
    fn rejects_safe_functions_with_clear_error() {
        let safe_item = quote::quote!(
            fn accelerated() {}
        );
        let err = expand_target_family(quote::quote!("x86_64_v3"), safe_item)
            .expect_err("safe functions are rejected");

        assert!(err.to_string().contains("unsafe fn"));
    }

    #[test]
    fn rejects_unknown_families_with_clear_error() {
        let err = expand_target_family(quote::quote!("not_a_real_family"), unsafe_item())
            .expect_err("unknown families are rejected");

        assert!(err.to_string().contains("unknown target family"));
    }

    #[test]
    fn rejects_trailing_attribute_tokens() {
        let err = expand_target_family(quote::quote!("x86_64_v3", "x86_64_v4"), unsafe_item())
            .expect_err("trailing tokens are rejected");

        assert!(err.to_string().contains("unexpected token"));
    }

    #[test]
    fn all_snapshot_files_are_listed_in_source_constants() {
        let snapshot_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("src/sources");
        let listed_paths: BTreeSet<String> = FEATURE_SOURCE_PATHS
            .iter()
            .chain(FAMILY_SOURCE_PATHS)
            .map(|path| {
                path.strip_prefix("sources/")
                    .expect("all source constants stay under sources/")
                    .to_owned()
            })
            .collect();
        let on_disk = collect_rs_files(&snapshot_root).expect("snapshot tree is readable");

        assert_eq!(listed_paths, on_disk);
    }
}