1use glob::glob;
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::quote;
11use std::{collections::HashSet, iter::FromIterator, ops::Sub};
12use syn::{
13 self,
14 parse::{Parse, ParseStream, Result},
15 parse_macro_input,
16 punctuated::Punctuated,
17 Ident, LitStr, Path, Token,
18};
19
20struct GlobPattern {
21 inverted: bool,
22 pattern: LitStr,
23}
24
25impl Parse for GlobPattern {
26 fn parse(input: ParseStream) -> Result<Self> {
27 let inverted = input.parse::<Token![!]>().is_ok();
28 let pattern = input.parse()?;
29 Ok(GlobPattern { inverted, pattern })
30 }
31}
32
33type GlobPatternList = Punctuated<GlobPattern, Token![,]>;
34
35struct FileTestsInput {
36 test_fn: Path,
37 globs: GlobPatternList,
38}
39
40impl Parse for FileTestsInput {
41 fn parse(input: ParseStream) -> Result<Self> {
42 let test_fn: Path = input.parse()?;
43 input.parse::<Token![=>]>()?;
44 let globs: GlobPatternList = input.parse_terminated(GlobPattern::parse)?;
45 Ok(FileTestsInput { test_fn, globs })
46 }
47}
48
49fn glob_all<'a>(patterns: impl Iterator<Item = &'a GlobPattern>) -> HashSet<std::path::PathBuf> {
50 patterns
51 .filter_map(|pattern| glob(pattern.pattern.value().as_str()).ok())
52 .flat_map(|paths| paths.filter_map(|path| path.ok()))
53 .collect()
54}
55
56#[proc_macro]
68pub fn file_tests(input: TokenStream) -> TokenStream {
69 let input = parse_macro_input!(input as FileTestsInput);
70
71 let glob_accepted = glob_all(input.globs.iter().filter(|pattern| !pattern.inverted));
72 let glob_rejected = glob_all(input.globs.iter().filter(|pattern| pattern.inverted));
73 let test_files = glob_accepted.sub(&glob_rejected);
74
75 let test_fn_name = input.test_fn.segments.last().unwrap().ident.to_string();
76
77 let fns_tokens = test_files.iter().enumerate().map(|(i, path)| {
78 let mut fn_name = path
79 .file_stem()
80 .map(|name| {
81 format!(
82 "test{}_{}_{}",
83 i,
84 test_fn_name,
85 name.to_str().expect("Invalid filename")
86 )
87 })
88 .expect("Invalid globbed path");
89 fn_name = fn_name
91 .chars()
92 .map(|ch| match ch {
93 'A'..='Z' | 'a'..='z' | '0'..='9' => ch,
94 _ => '_',
95 })
96 .collect();
97
98 let test_fn = &input.test_fn;
99 let abs_path = path.canonicalize().expect("Could not make absolute path");
100 let path_str = abs_path.to_str().expect("Invalid path");
101 let fn_ident = Ident::new(fn_name.as_str(), Span::call_site());
102
103 quote! {
104 #[test]
105 fn #fn_ident() {
106 let path = std::path::PathBuf::from(#path_str);
107 println!("Test file: {}", #path_str);
108 match std::fs::File::open(&path) {
109 Ok(file) => #test_fn(path, file),
110 Err(err) => panic!("Error loading test file: {}: {}", #path_str, err),
111 }
112 }
113 }
114 });
115
116 proc_macro2::TokenStream::from_iter(fns_tokens).into()
117}