1use quote::quote;
2use std::collections::HashSet;
3use std::ffi::OsStr;
4use std::path::{Path, PathBuf};
5
6use proc_macro::TokenStream;
7use proc_macro2::{Span, TokenStream as TokenStream2};
8use syn::Result;
9use syn::{Error, Token, parse::Parse};
10
11#[proc_macro_attribute]
12pub fn dir_bench(args: TokenStream, item: TokenStream) -> TokenStream {
13 let input = syn::parse_macro_input!(item as syn::ItemFn);
14 let args = syn::parse_macro_input!(args as DirBenchArgs);
15
16 match BenchBuilder::new(args, input).build() {
17 Ok((benchs, func)) => quote! {
18 #func
19 #benchs
20 }
21 .into(),
22 Err(e) => e.to_compile_error().into(),
23 }
24}
25
26struct BenchBuilder {
27 args: DirBenchArgs,
28 func: syn::ItemFn,
29 bench_attrs: Vec<syn::Attribute>,
30}
31
32impl BenchBuilder {
33 fn new(args: DirBenchArgs, func: syn::ItemFn) -> Self {
34 Self {
35 args,
36 func,
37 bench_attrs: vec![],
38 }
39 }
40
41 fn build(mut self) -> Result<(TokenStream2, syn::ItemFn)> {
42 self.extract_bench_args()?;
43
44 let mut pattern = self.args.resolve_dir()?;
45
46 pattern.push(
47 self.args
48 .glob
49 .clone()
50 .map_or_else(|| "*".to_owned(), |v| v.value()),
51 );
52
53 let paths = glob::glob(&pattern.to_string_lossy()).map_err(|e| {
54 Error::new_spanned(
55 self.args.glob.clone().unwrap(),
56 format!("failed to resolve glob pattern {e}"),
57 )
58 })?;
59
60 let bound = paths.size_hint();
61 let mut tests = Vec::with_capacity(bound.1.unwrap_or(bound.0));
62
63 for entry in paths.filter_map(|p| p.ok()) {
64 if !entry.is_file() {
65 continue;
66 }
67
68 tests.push(self.build_bench(&entry)?)
69 }
70
71 Ok((
72 quote! {
73 #(#tests)*
74 },
75 self.func,
76 ))
77 }
78
79 fn build_bench(&self, path: &Path) -> Result<TokenStream2> {
80 let bench_ident = &self.func.sig.ident;
81 let bench_name = self.bench_name(bench_ident.to_string(), path)?;
82 let bench_attrs = &self.bench_attrs;
83 let path = path.to_string_lossy();
84
85 let loader = match self.args.loader {
86 Some(ref loader) => quote! {#loader},
87 None => quote! { ::core::include_str! },
88 };
89
90 Ok(quote! {
91 #(#bench_attrs)*
92 #[bench]
93 fn #bench_name(b: &mut test::Bencher) {
94 #bench_ident(b,::dir_bench::Fixture::new(#loader(#path), #path));
95 }
96 })
97 }
98
99 fn bench_name(&self, test_func_name: String, fixture_path: &Path) -> Result<syn::Ident> {
100 assert!(fixture_path.is_file());
101
102 let dir_path = self.args.resolve_dir()?;
103 let rel_path = fixture_path.strip_prefix(dir_path).unwrap();
104
105 assert!(rel_path.is_relative());
106
107 let mut bench_name = test_func_name;
108 bench_name.push_str("__");
109
110 let components: Vec<_> = rel_path.iter().collect();
111
112 for component in &components[0..components.len() - 1] {
113 let component = component
114 .to_string_lossy()
115 .replace(|c: char| c.is_ascii_punctuation(), "_");
116 bench_name.push_str(&component);
117 bench_name.push('_');
118 }
119
120 bench_name.push_str(
121 &rel_path
122 .file_stem()
123 .unwrap()
124 .to_string_lossy()
125 .replace(|c: char| c.is_ascii_punctuation(), "_"),
126 );
127
128 if let Some(postfix) = &self.args.postfix {
129 bench_name.push('_');
130 bench_name.push_str(&postfix.value());
131 }
132
133 Ok(make_ident(&bench_name))
134 }
135
136 fn extract_bench_args(&mut self) -> Result<()> {
137 let mut err = Ok(());
138
139 self.func.attrs.retain(|attr| {
140 if attr.path().is_ident("dir_bench_attr") {
141 err = err
142 .clone()
143 .and(attr.parse_args_with(|input: syn::parse::ParseStream| {
144 self.bench_attrs
145 .extend(input.call(syn::Attribute::parse_outer)?);
146
147 if !input.is_empty() {
148 Err(Error::new(
149 input.span(),
150 "unexpected token after `dir_bench_attr`",
151 ))
152 } else {
153 Ok(())
154 }
155 }));
156
157 false
158 } else {
159 true
160 }
161 });
162
163 err
164 }
165}
166
167#[derive(Default)]
168struct DirBenchArgs {
169 pub dir: Option<syn::LitStr>,
170 pub glob: Option<syn::LitStr>,
171 pub postfix: Option<syn::LitStr>,
172 pub loader: Option<syn::Path>,
173}
174
175impl DirBenchArgs {
176 fn resolve_dir(&self) -> Result<PathBuf> {
177 let Some(dir) = &self.dir else {
178 return Err(Error::new(Span::call_site(), "`dir` is required"));
179 };
180
181 let resolved = self.resolve_path(Path::new(&dir.value()))?;
182
183 if !resolved.is_absolute() {
184 return Err(Error::new_spanned(
185 dir.clone(),
186 format!("`{}` is not an absolute path", resolved.display()),
187 ));
188 } else if !resolved.exists() {
189 return Err(Error::new_spanned(
190 dir.clone(),
191 format!("`{}` does not exist", resolved.display()),
192 ));
193 } else if !resolved.is_dir() {
194 return Err(Error::new_spanned(
195 dir.clone(),
196 format!("`{}` is not a directory", resolved.display()),
197 ));
198 }
199
200 Ok(resolved)
201 }
202
203 fn resolve_path(&self, path: &Path) -> Result<PathBuf> {
204 let mut resolved = PathBuf::new();
205 for component in path {
206 resolved.push(self.resolve_component(component)?);
207 }
208 Ok(resolved)
209 }
210
211 fn resolve_component(&self, component: &OsStr) -> Result<PathBuf> {
212 if component.to_string_lossy().starts_with('$') {
213 let env_var = &component.to_string_lossy()[1..];
214 let env_var_value = std::env::var(env_var).map_err(|e| {
215 Error::new_spanned(
216 self.dir.clone().unwrap(),
217 format!("failed to resolve env var `{env_var}`: {e}"),
218 )
219 })?;
220 let resolved = self.resolve_path(Path::new(&env_var_value))?;
221 Ok(resolved)
222 } else {
223 Ok(Path::new(&component).into())
224 }
225 }
226}
227
228impl Parse for DirBenchArgs {
229 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
230 let mut args = DirBenchArgs::default();
231 let mut visited_args = HashSet::<String>::new();
232
233 while !input.is_empty() {
234 let arg = input.parse::<syn::Ident>()?;
235 if visited_args.contains(&arg.to_string()) {
236 return Err(Error::new_spanned(
237 arg.clone(),
238 format!("duplicated arg `{arg}`"),
239 ));
240 }
241
242 input.parse::<Token![:]>()?;
243
244 match arg.to_string().as_str() {
245 "dir" => {
246 args.dir = Some(input.parse()?);
247 }
248 "glob" => {
249 args.glob = Some(input.parse()?);
250 }
251 "postfix" => {
252 args.postfix = Some(input.parse()?);
253 }
254 "loader" => {
255 args.loader = Some(input.parse()?);
256 }
257 _ => {
258 return Err(Error::new_spanned(
259 arg.clone(),
260 format!("unknown arg `{arg}`"),
261 ));
262 }
263 }
264
265 visited_args.insert(arg.to_string());
266 input.parse::<syn::Token![,]>().ok();
267 }
268
269 Ok(args)
270 }
271}
272
/// Reports whether `name` is a Rust strict or reserved keyword, in any edition up to 2024.
///
/// Such names can only be used as identifiers in raw form (`r#name`). Weak/contextual
/// keywords such as `union` or `macro_rules` are not included, since they are valid plain
/// identifiers.
fn is_keyword(name: &str) -> bool {
    matches!(
        name,
        // Strict keywords (all editions).
        "as" | "break"
            | "const"
            | "continue"
            | "crate"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            // Strict keywords (2018+).
            | "async"
            | "await"
            | "dyn"
            // Reserved keywords (all editions).
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "macro"
            | "override"
            | "priv"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
            // Reserved keywords (2018+ / 2024+).
            | "try"
            | "gen"
    )
}
328
329fn make_ident(name: &str) -> syn::Ident {
330 if is_keyword(name) {
331 syn::Ident::new_raw(name, Span::call_site())
332 } else {
333 syn::Ident::new(name, Span::call_site())
334 }
335}