dir_bench_macros/
lib.rs

1use quote::quote;
2use std::collections::HashSet;
3use std::ffi::OsStr;
4use std::path::{Path, PathBuf};
5
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use syn::Result;
9use syn::{Error, Token, parse::Parse};
10
11#[proc_macro_attribute]
12pub fn dir_bench(args: TokenStream, item: TokenStream) -> TokenStream {
13    let input = syn::parse_macro_input!(item as syn::ItemFn);
14    let args = syn::parse_macro_input!(args as DirBenchArgs);
15
16    match BenchBuilder::new(args, input).build() {
17        Ok((benchs, func)) => quote! {
18            #func
19            #benchs
20        }
21        .into(),
22        Err(e) => e.to_compile_error().into(),
23    }
24}
25
26struct BenchBuilder {
27    args: DirBenchArgs,
28    func: syn::ItemFn,
29    bench_attrs: Vec<syn::Attribute>,
30}
31
32impl BenchBuilder {
33    fn new(args: DirBenchArgs, func: syn::ItemFn) -> Self {
34        Self {
35            args,
36            func,
37            bench_attrs: vec![],
38        }
39    }
40
41    fn build(mut self) -> Result<(TokenStream2, syn::ItemFn)> {
42        self.extract_bench_args()?;
43
44        let mut pattern = self.args.resolve_dir()?;
45
46        pattern.push(
47            self.args
48                .glob
49                .clone()
50                .map_or_else(|| "*".to_owned(), |v| v.value()),
51        );
52
53        let paths = glob::glob(&pattern.to_string_lossy()).map_err(|e| {
54            Error::new_spanned(
55                self.args.glob.clone().unwrap(),
56                format!("failed to resolve glob pattern {e}"),
57            )
58        })?;
59
60        let bound = paths.size_hint();
61        let mut tests = Vec::with_capacity(bound.1.unwrap_or(bound.0));
62
63        for entry in paths.filter_map(|p| p.ok()) {
64            if !entry.is_file() {
65                continue;
66            }
67
68            tests.push(self.build_bench(&entry)?)
69        }
70
71        Ok((
72            quote! {
73                #(#tests)*
74            },
75            self.func,
76        ))
77    }
78
79    fn build_bench(&self, path: &Path) -> Result<TokenStream2> {
80        let bench_ident = &self.func.sig.ident;
81        let bench_name = self.bench_name(bench_ident.to_string(), path)?;
82        let bench_attrs = &self.bench_attrs;
83        let path = path.to_string_lossy();
84
85        let loader = match self.args.loader {
86            Some(ref loader) => quote! {#loader},
87            None => quote! { ::core::include_str! },
88        };
89
90        Ok(quote! {
91            #(#bench_attrs)*
92            #[bench]
93            fn #bench_name(b: &mut test::Bencher) {
94                #bench_ident(b,::dir_bench::Fixture::new(#loader(#path), #path));
95            }
96        })
97    }
98
99    fn bench_name(&self, test_func_name: String, fixture_path: &Path) -> Result<syn::Ident> {
100        assert!(fixture_path.is_file());
101
102        let dir_path = self.args.resolve_dir()?;
103        let rel_path = fixture_path.strip_prefix(dir_path).unwrap();
104
105        assert!(rel_path.is_relative());
106
107        let mut bench_name = test_func_name;
108        bench_name.push_str("__");
109
110        let components: Vec<_> = rel_path.iter().collect();
111
112        for component in &components[0..components.len() - 1] {
113            let component = component
114                .to_string_lossy()
115                .replace(|c: char| c.is_ascii_punctuation(), "_");
116            bench_name.push_str(&component);
117            bench_name.push('_');
118        }
119
120        bench_name.push_str(
121            &rel_path
122                .file_stem()
123                .unwrap()
124                .to_string_lossy()
125                .replace(|c: char| c.is_ascii_punctuation(), "_"),
126        );
127
128        if let Some(postfix) = &self.args.postfix {
129            bench_name.push('_');
130            bench_name.push_str(&postfix.value());
131        }
132
133        Ok(make_ident(&bench_name))
134    }
135
136    fn extract_bench_args(&mut self) -> Result<()> {
137        let mut err = Ok(());
138
139        self.func.attrs.retain(|attr| {
140            if attr.path().is_ident("dir_bench_attr") {
141                err = err
142                    .clone()
143                    .and(attr.parse_args_with(|input: syn::parse::ParseStream| {
144                        self.bench_attrs
145                            .extend(input.call(syn::Attribute::parse_outer)?);
146
147                        if !input.is_empty() {
148                            Err(Error::new(
149                                input.span(),
150                                "unexpected token after `dir_bench_attr`",
151                            ))
152                        } else {
153                            Ok(())
154                        }
155                    }));
156
157                false
158            } else {
159                true
160            }
161        });
162
163        err
164    }
165}
166
167#[derive(Default)]
168struct DirBenchArgs {
169    pub dir: Option<syn::LitStr>,
170    pub glob: Option<syn::LitStr>,
171    pub postfix: Option<syn::LitStr>,
172    pub loader: Option<syn::Path>,
173}
174
175impl DirBenchArgs {
176    fn resolve_dir(&self) -> Result<PathBuf> {
177        let Some(dir) = &self.dir else {
178            return Err(Error::new(Span::call_site(), "`dir` is required"));
179        };
180
181        let resolved = self.resolve_path(Path::new(&dir.value()))?;
182
183        if !resolved.is_absolute() {
184            return Err(Error::new_spanned(
185                dir.clone(),
186                format!("`{}` is not an absolute path", resolved.display()),
187            ));
188        } else if !resolved.exists() {
189            return Err(Error::new_spanned(
190                dir.clone(),
191                format!("`{}` does not exist", resolved.display()),
192            ));
193        } else if !resolved.is_dir() {
194            return Err(Error::new_spanned(
195                dir.clone(),
196                format!("`{}` is not a directory", resolved.display()),
197            ));
198        }
199
200        Ok(resolved)
201    }
202
203    fn resolve_path(&self, path: &Path) -> Result<PathBuf> {
204        let mut resolved = PathBuf::new();
205        for component in path {
206            resolved.push(self.resolve_component(component)?);
207        }
208        Ok(resolved)
209    }
210
211    fn resolve_component(&self, component: &OsStr) -> Result<PathBuf> {
212        if component.to_string_lossy().starts_with('$') {
213            let env_var = &component.to_string_lossy()[1..];
214            let env_var_value = std::env::var(env_var).map_err(|e| {
215                Error::new_spanned(
216                    self.dir.clone().unwrap(),
217                    format!("failed to resolve env var `{env_var}`: {e}"),
218                )
219            })?;
220            let resolved = self.resolve_path(Path::new(&env_var_value))?;
221            Ok(resolved)
222        } else {
223            Ok(Path::new(&component).into())
224        }
225    }
226}
227
228impl Parse for DirBenchArgs {
229    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
230        let mut args = DirBenchArgs::default();
231        let mut visited_args = HashSet::<String>::new();
232
233        while !input.is_empty() {
234            let arg = input.parse::<syn::Ident>()?;
235            if visited_args.contains(&arg.to_string()) {
236                return Err(Error::new_spanned(
237                    arg.clone(),
238                    format!("duplicated arg `{arg}`"),
239                ));
240            }
241
242            input.parse::<Token![:]>()?;
243
244            match arg.to_string().as_str() {
245                "dir" => {
246                    args.dir = Some(input.parse()?);
247                }
248                "glob" => {
249                    args.glob = Some(input.parse()?);
250                }
251                "postfix" => {
252                    args.postfix = Some(input.parse()?);
253                }
254                "loader" => {
255                    args.loader = Some(input.parse()?);
256                }
257                _ => {
258                    return Err(Error::new_spanned(
259                        arg.clone(),
260                        format!("unknown arg `{arg}`"),
261                    ));
262                }
263            }
264
265            visited_args.insert(arg.to_string());
266            input.parse::<syn::Token![,]>().ok();
267        }
268
269        Ok(args)
270    }
271}
272
/// Returns `true` when `name` is spelled exactly like one of the Rust
/// keywords listed below (strict keywords plus reserved ones).
///
/// The comparison is exact and case-sensitive.
fn is_keyword(name: &str) -> bool {
    matches!(
        name,
        "as" | "break"
            | "const"
            | "continue"
            | "crate"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            | "async"
            | "await"
            | "dyn"
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "macro"
            | "override"
            | "priv"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
            | "try"
    )
}
328
329fn make_ident(name: &str) -> syn::Ident {
330    if is_keyword(name) {
331        syn::Ident::new_raw(name, Span::call_site())
332    } else {
333        syn::Ident::new(name, Span::call_site())
334    }
335}