dir_test_macros/
lib.rs

1use std::{
2    collections::HashSet,
3    ffi::OsStr,
4    path::{Path, PathBuf},
5};
6
7use proc_macro2::Span;
8use quote::quote;
9use syn::Token;
10
11type Error = syn::Error;
12type Result<T> = syn::Result<T>;
13
14#[proc_macro_attribute]
15pub fn dir_test(
16    attrs: proc_macro::TokenStream,
17    item: proc_macro::TokenStream,
18) -> proc_macro::TokenStream {
19    let attr = syn::parse_macro_input!(attrs as DirTestArg);
20    let func = syn::parse_macro_input!(item as syn::ItemFn);
21
22    match TestBuilder::new(attr, func).build() {
23        Ok((tests, func)) => quote! {
24            #func
25            #tests
26        }
27        .into(),
28        Err(e) => e.to_compile_error().into(),
29    }
30}
31
32struct TestBuilder {
33    dir_test_arg: DirTestArg,
34    func: syn::ItemFn,
35    test_attrs: Vec<syn::Attribute>,
36}
37
38impl TestBuilder {
39    fn new(dir_test_arg: DirTestArg, func: syn::ItemFn) -> Self {
40        Self {
41            dir_test_arg,
42            func,
43            test_attrs: vec![],
44        }
45    }
46
47    fn build(mut self) -> Result<(proc_macro2::TokenStream, syn::ItemFn)> {
48        self.extract_test_attrs()?;
49
50        let mut pattern = self.dir_test_arg.resolve_dir()?;
51
52        pattern.push(
53            self.dir_test_arg
54                .glob
55                .clone()
56                .map_or_else(|| "*".to_string(), |g| g.value()),
57        );
58
59        let paths = glob::glob(&pattern.to_string_lossy()).map_err(|e| {
60            Error::new_spanned(
61                self.dir_test_arg.glob.clone().unwrap(),
62                format!("failed to resolve glob pattern: {e}"),
63            )
64        })?;
65
66        let mut tests = vec![];
67        for entry in paths.filter_map(|p| p.ok()) {
68            if !entry.is_file() {
69                continue;
70            }
71
72            tests.push(self.build_test(&entry)?);
73        }
74
75        Ok((
76            quote! {
77                #(#tests)*
78            },
79            self.func,
80        ))
81    }
82
83    fn build_test(&self, file_path: &Path) -> Result<proc_macro2::TokenStream> {
84        let test_func = &self.func.sig.ident;
85        let test_name = self.test_name(test_func.to_string(), file_path)?;
86        let file_path_str = file_path.to_string_lossy();
87        let return_ty = &self.func.sig.output;
88        let test_attrs = &self.test_attrs;
89
90        let loader = match self.dir_test_arg.loader {
91            Some(ref loader) => quote! {#loader},
92            None => quote! {::core::include_str!},
93        };
94
95        Ok(quote! {
96            #(#test_attrs)*
97            #[test]
98            fn #test_name() #return_ty {
99                #test_func(::dir_test::Fixture::new(#loader(#file_path_str), #file_path_str))
100            }
101        })
102    }
103
104    fn test_name(&self, test_func_name: String, fixture_path: &Path) -> Result<syn::Ident> {
105        assert!(fixture_path.is_file());
106
107        let dir_path = self.dir_test_arg.resolve_dir()?;
108        let rel_path = fixture_path.strip_prefix(dir_path).unwrap();
109        assert!(rel_path.is_relative());
110
111        let mut test_name = test_func_name;
112        test_name.push_str("__");
113
114        let components: Vec<_> = rel_path.iter().collect();
115
116        for component in &components[0..components.len() - 1] {
117            let component = component
118                .to_string_lossy()
119                .replace(|c: char| c.is_ascii_punctuation(), "_");
120            test_name.push_str(&component);
121            test_name.push('_');
122        }
123
124        test_name.push_str(
125            &rel_path
126                .file_stem()
127                .unwrap()
128                .to_string_lossy()
129                .replace(|c: char| c.is_ascii_punctuation(), "_"),
130        );
131
132        if let Some(postfix) = &self.dir_test_arg.postfix {
133            test_name.push('_');
134            test_name.push_str(&postfix.value());
135        }
136
137        Ok(make_ident(&test_name))
138    }
139
140    /// Extracts `#[dir_test_attr(...)]` from function attributes.
141    fn extract_test_attrs(&mut self) -> Result<()> {
142        let mut err = Ok(());
143        self.func.attrs.retain(|attr| {
144            if attr.path().is_ident("dir_test_attr") {
145                err = err
146                    .clone()
147                    .and(attr.parse_args_with(|input: syn::parse::ParseStream| {
148                        self.test_attrs
149                            .extend(input.call(syn::Attribute::parse_outer)?);
150                        if !input.is_empty() {
151                            Err(Error::new(
152                                input.span(),
153                                "unexpected token after `dir_test_attr`",
154                            ))
155                        } else {
156                            Ok(())
157                        }
158                    }));
159
160                false
161            } else {
162                true
163            }
164        });
165
166        err
167    }
168}
169
170#[derive(Default)]
171struct DirTestArg {
172    dir: Option<syn::LitStr>,
173    glob: Option<syn::LitStr>,
174    postfix: Option<syn::LitStr>,
175    loader: Option<syn::Path>,
176}
177
178impl DirTestArg {
179    fn resolve_dir(&self) -> Result<PathBuf> {
180        let Some(dir) = &self.dir else {
181            return Err(Error::new(Span::call_site(), "`dir` is required"));
182        };
183
184        let resolved = self.resolve_path(Path::new(&dir.value()))?;
185
186        if !resolved.is_absolute() {
187            return Err(Error::new_spanned(
188                dir.clone(),
189                format!("`{}` is not an absolute path", resolved.display()),
190            ));
191        } else if !resolved.exists() {
192            return Err(Error::new_spanned(
193                dir.clone(),
194                format!("`{}` does not exist", resolved.display()),
195            ));
196        } else if !resolved.is_dir() {
197            return Err(Error::new_spanned(
198                dir.clone(),
199                format!("`{}` is not a directory", resolved.display()),
200            ));
201        }
202
203        Ok(resolved)
204    }
205
206    fn resolve_path(&self, path: &Path) -> Result<PathBuf> {
207        let mut resolved = PathBuf::new();
208        for component in path {
209            resolved.push(self.resolve_component(component)?);
210        }
211        Ok(resolved)
212    }
213
214    fn resolve_component(&self, component: &OsStr) -> Result<PathBuf> {
215        if component.to_string_lossy().starts_with('$') {
216            let env_var = &component.to_string_lossy()[1..];
217            let env_var_value = std::env::var(env_var).map_err(|e| {
218                Error::new_spanned(
219                    self.dir.clone().unwrap(),
220                    format!("failed to resolve env var `{env_var}`: {e}"),
221                )
222            })?;
223            let resolved = self.resolve_path(Path::new(&env_var_value))?;
224            Ok(resolved)
225        } else {
226            Ok(Path::new(&component).into())
227        }
228    }
229}
230
231impl syn::parse::Parse for DirTestArg {
232    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
233        let mut dir_test_attr = DirTestArg::default();
234        let mut visited_args: HashSet<String> = HashSet::new();
235
236        while !input.is_empty() {
237            let arg = input.parse::<syn::Ident>()?;
238            if visited_args.contains(&arg.to_string()) {
239                return Err(Error::new_spanned(
240                    arg.clone(),
241                    format!("duplicated arg `{arg}`"),
242                ));
243            }
244
245            match arg.to_string().as_str() {
246                "dir" => {
247                    input.parse::<Token![:]>()?;
248                    dir_test_attr.dir = Some(input.parse()?);
249                }
250
251                "glob" => {
252                    input.parse::<Token![:]>()?;
253                    dir_test_attr.glob = Some(input.parse()?);
254                }
255
256                "postfix" => {
257                    input.parse::<Token![:]>()?;
258                    dir_test_attr.postfix = Some(input.parse()?);
259                }
260
261                "loader" => {
262                    input.parse::<Token![:]>()?;
263                    dir_test_attr.loader = Some(input.parse()?);
264                }
265
266                _ => {
267                    return Err(Error::new_spanned(
268                        arg.clone(),
269                        format!("unknown arg `{arg}`"),
270                    ))
271                }
272            };
273
274            visited_args.insert(arg.to_string());
275            input.parse::<syn::Token![,]>().ok();
276        }
277
278        Ok(dir_test_attr)
279    }
280}
281
282fn make_ident(name: &str) -> syn::Ident {
283    if is_keyword(name) {
284        syn::Ident::new_raw(name, Span::call_site())
285    } else {
286        syn::Ident::new(name, Span::call_site())
287    }
288}
289
/// Returns `true` when `name` is one of the Rust keywords listed below (strict
/// keywords, the 2018-edition additions, and reserved keywords).
///
/// The comparison is exact: surrounding whitespace or a different case makes the
/// name a non-keyword.
fn is_keyword(name: &str) -> bool {
    matches!(
        name,
        "as" | "break"
            | "const"
            | "continue"
            | "crate"
            | "else"
            // Was `"enum "` (trailing space), which could never match the keyword.
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            | "async"
            | "await"
            | "dyn"
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "macro"
            | "override"
            | "priv"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
            | "try"
    )
}