1use std::{
2 collections::HashSet,
3 ffi::OsStr,
4 path::{Path, PathBuf},
5};
6
7use proc_macro2::Span;
8use quote::quote;
9use syn::Token;
10
11type Error = syn::Error;
12type Result<T> = syn::Result<T>;
13
14#[proc_macro_attribute]
15pub fn dir_test(
16 attrs: proc_macro::TokenStream,
17 item: proc_macro::TokenStream,
18) -> proc_macro::TokenStream {
19 let attr = syn::parse_macro_input!(attrs as DirTestArg);
20 let func = syn::parse_macro_input!(item as syn::ItemFn);
21
22 match TestBuilder::new(attr, func).build() {
23 Ok((tests, func)) => quote! {
24 #func
25 #tests
26 }
27 .into(),
28 Err(e) => e.to_compile_error().into(),
29 }
30}
31
32struct TestBuilder {
33 dir_test_arg: DirTestArg,
34 func: syn::ItemFn,
35 test_attrs: Vec<syn::Attribute>,
36}
37
38impl TestBuilder {
39 fn new(dir_test_arg: DirTestArg, func: syn::ItemFn) -> Self {
40 Self {
41 dir_test_arg,
42 func,
43 test_attrs: vec![],
44 }
45 }
46
47 fn build(mut self) -> Result<(proc_macro2::TokenStream, syn::ItemFn)> {
48 self.extract_test_attrs()?;
49
50 let mut pattern = self.dir_test_arg.resolve_dir()?;
51
52 pattern.push(
53 self.dir_test_arg
54 .glob
55 .clone()
56 .map_or_else(|| "*".to_string(), |g| g.value()),
57 );
58
59 let paths = glob::glob(&pattern.to_string_lossy()).map_err(|e| {
60 Error::new_spanned(
61 self.dir_test_arg.glob.clone().unwrap(),
62 format!("failed to resolve glob pattern: {e}"),
63 )
64 })?;
65
66 let mut tests = vec![];
67 for entry in paths.filter_map(|p| p.ok()) {
68 if !entry.is_file() {
69 continue;
70 }
71
72 tests.push(self.build_test(&entry)?);
73 }
74
75 Ok((
76 quote! {
77 #(#tests)*
78 },
79 self.func,
80 ))
81 }
82
83 fn build_test(&self, file_path: &Path) -> Result<proc_macro2::TokenStream> {
84 let test_func = &self.func.sig.ident;
85 let test_name = self.test_name(test_func.to_string(), file_path)?;
86 let file_path_str = file_path.to_string_lossy();
87 let return_ty = &self.func.sig.output;
88 let test_attrs = &self.test_attrs;
89
90 let loader = match self.dir_test_arg.loader {
91 Some(ref loader) => quote! {#loader},
92 None => quote! {::core::include_str!},
93 };
94
95 Ok(quote! {
96 #(#test_attrs)*
97 #[test]
98 fn #test_name() #return_ty {
99 #test_func(::dir_test::Fixture::new(#loader(#file_path_str), #file_path_str))
100 }
101 })
102 }
103
104 fn test_name(&self, test_func_name: String, fixture_path: &Path) -> Result<syn::Ident> {
105 assert!(fixture_path.is_file());
106
107 let dir_path = self.dir_test_arg.resolve_dir()?;
108 let rel_path = fixture_path.strip_prefix(dir_path).unwrap();
109 assert!(rel_path.is_relative());
110
111 let mut test_name = test_func_name;
112 test_name.push_str("__");
113
114 let components: Vec<_> = rel_path.iter().collect();
115
116 for component in &components[0..components.len() - 1] {
117 let component = component
118 .to_string_lossy()
119 .replace(|c: char| c.is_ascii_punctuation(), "_");
120 test_name.push_str(&component);
121 test_name.push('_');
122 }
123
124 test_name.push_str(
125 &rel_path
126 .file_stem()
127 .unwrap()
128 .to_string_lossy()
129 .replace(|c: char| c.is_ascii_punctuation(), "_"),
130 );
131
132 if let Some(postfix) = &self.dir_test_arg.postfix {
133 test_name.push('_');
134 test_name.push_str(&postfix.value());
135 }
136
137 Ok(make_ident(&test_name))
138 }
139
140 fn extract_test_attrs(&mut self) -> Result<()> {
142 let mut err = Ok(());
143 self.func.attrs.retain(|attr| {
144 if attr.path().is_ident("dir_test_attr") {
145 err = err
146 .clone()
147 .and(attr.parse_args_with(|input: syn::parse::ParseStream| {
148 self.test_attrs
149 .extend(input.call(syn::Attribute::parse_outer)?);
150 if !input.is_empty() {
151 Err(Error::new(
152 input.span(),
153 "unexpected token after `dir_test_attr`",
154 ))
155 } else {
156 Ok(())
157 }
158 }));
159
160 false
161 } else {
162 true
163 }
164 });
165
166 err
167 }
168}
169
170#[derive(Default)]
171struct DirTestArg {
172 dir: Option<syn::LitStr>,
173 glob: Option<syn::LitStr>,
174 postfix: Option<syn::LitStr>,
175 loader: Option<syn::Path>,
176}
177
178impl DirTestArg {
179 fn resolve_dir(&self) -> Result<PathBuf> {
180 let Some(dir) = &self.dir else {
181 return Err(Error::new(Span::call_site(), "`dir` is required"));
182 };
183
184 let resolved = self.resolve_path(Path::new(&dir.value()))?;
185
186 if !resolved.is_absolute() {
187 return Err(Error::new_spanned(
188 dir.clone(),
189 format!("`{}` is not an absolute path", resolved.display()),
190 ));
191 } else if !resolved.exists() {
192 return Err(Error::new_spanned(
193 dir.clone(),
194 format!("`{}` does not exist", resolved.display()),
195 ));
196 } else if !resolved.is_dir() {
197 return Err(Error::new_spanned(
198 dir.clone(),
199 format!("`{}` is not a directory", resolved.display()),
200 ));
201 }
202
203 Ok(resolved)
204 }
205
206 fn resolve_path(&self, path: &Path) -> Result<PathBuf> {
207 let mut resolved = PathBuf::new();
208 for component in path {
209 resolved.push(self.resolve_component(component)?);
210 }
211 Ok(resolved)
212 }
213
214 fn resolve_component(&self, component: &OsStr) -> Result<PathBuf> {
215 if component.to_string_lossy().starts_with('$') {
216 let env_var = &component.to_string_lossy()[1..];
217 let env_var_value = std::env::var(env_var).map_err(|e| {
218 Error::new_spanned(
219 self.dir.clone().unwrap(),
220 format!("failed to resolve env var `{env_var}`: {e}"),
221 )
222 })?;
223 let resolved = self.resolve_path(Path::new(&env_var_value))?;
224 Ok(resolved)
225 } else {
226 Ok(Path::new(&component).into())
227 }
228 }
229}
230
231impl syn::parse::Parse for DirTestArg {
232 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
233 let mut dir_test_attr = DirTestArg::default();
234 let mut visited_args: HashSet<String> = HashSet::new();
235
236 while !input.is_empty() {
237 let arg = input.parse::<syn::Ident>()?;
238 if visited_args.contains(&arg.to_string()) {
239 return Err(Error::new_spanned(
240 arg.clone(),
241 format!("duplicated arg `{arg}`"),
242 ));
243 }
244
245 match arg.to_string().as_str() {
246 "dir" => {
247 input.parse::<Token![:]>()?;
248 dir_test_attr.dir = Some(input.parse()?);
249 }
250
251 "glob" => {
252 input.parse::<Token![:]>()?;
253 dir_test_attr.glob = Some(input.parse()?);
254 }
255
256 "postfix" => {
257 input.parse::<Token![:]>()?;
258 dir_test_attr.postfix = Some(input.parse()?);
259 }
260
261 "loader" => {
262 input.parse::<Token![:]>()?;
263 dir_test_attr.loader = Some(input.parse()?);
264 }
265
266 _ => {
267 return Err(Error::new_spanned(
268 arg.clone(),
269 format!("unknown arg `{arg}`"),
270 ))
271 }
272 };
273
274 visited_args.insert(arg.to_string());
275 input.parse::<syn::Token![,]>().ok();
276 }
277
278 Ok(dir_test_attr)
279 }
280}
281
282fn make_ident(name: &str) -> syn::Ident {
283 if is_keyword(name) {
284 syn::Ident::new_raw(name, Span::call_site())
285 } else {
286 syn::Ident::new(name, Span::call_site())
287 }
288}
289
/// Returns `true` if `name` is a Rust strict or reserved keyword.
///
/// Matching is exact and case-sensitive (`Self` and `self` are distinct keywords).
fn is_keyword(name: &str) -> bool {
    matches!(
        name,
        // Strict keywords.
        "as" | "break"
            | "const"
            | "continue"
            | "crate"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            // Keywords introduced by the 2018 edition.
            | "async"
            | "await"
            | "dyn"
            // Reserved keywords.
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "macro"
            | "override"
            | "priv"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
            | "try"
    )
}