1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_error::{abort, abort_call_site, proc_macro_error};
4use quote::{format_ident, quote, ToTokens};
5use std::{borrow::Cow, path::PathBuf};
6use syn::{
7 parse_macro_input, AttributeArgs, Ident, Item, ItemMod, Lit, Meta, MetaList, MetaNameValue,
8 NestedMeta, Path, PathArguments, PathSegment,
9};
10use walkdir::WalkDir;
11
12struct Args {
13 parser_path: Path,
14 rule_path: Path,
15 rule_ident: Ident,
16 skip_rules: Vec<Ident>,
17 dir: PathBuf,
18 subdir: Option<PathBuf>,
19 ext: String,
20 recursive: bool,
21 strict: bool,
22 no_eoi: bool,
23 lazy_static: bool,
24}
25
26impl Args {
27 fn from(attr_args: Vec<NestedMeta>) -> Self {
28 let mut attr_args_iter = attr_args.into_iter();
29
30 let parser_path = match attr_args_iter.next() {
32 Some(NestedMeta::Meta(Meta::Path(path))) => path,
33 Some(other) => abort!(other, "Invalid parser type"),
34 None => abort_call_site!("Missing required argument <parser>"),
35 };
36 let rule_path = match attr_args_iter.next() {
37 Some(NestedMeta::Meta(Meta::Path(path))) => path,
38 Some(other) => abort!(other, "Invalid rule"),
39 None => abort_call_site!("Missing required argument <rule>"),
40 };
41 let rule_ident = match attr_args_iter.next() {
42 Some(NestedMeta::Lit(Lit::Str(s))) => Ident::new(s.value().as_ref(), Span::call_site()),
43 Some(other) => abort!(other, "Invalid rule name"),
44 None => abort_call_site!("Missing required argument <root rule name>"),
45 };
46
47 let mut args = Args {
48 parser_path,
49 rule_path,
50 rule_ident,
51 skip_rules: Vec::new(),
52 dir: pest_test::default_test_dir(),
53 subdir: None,
54 ext: String::from("txt"),
55 recursive: false,
56 strict: true,
57 no_eoi: false,
58 lazy_static: false,
59 };
60
61 for arg in attr_args_iter {
63 match arg {
64 NestedMeta::Meta(Meta::NameValue(MetaNameValue {
65 path,
66 eq_token: _,
67 lit,
68 })) => {
69 let attr_name = path
70 .get_ident()
71 .unwrap_or_else(|| abort!(path, "Invalid argument to pest_test_gen macro"))
72 .to_string();
73 match attr_name.as_str() {
74 "dir" => {
75 let mut path = match lit {
76 Lit::Str(s) => PathBuf::from(s.value()),
77 _ => abort!(lit, "Invalid argument to 'dir' attribute"),
78 };
79 if path.is_relative() {
80 path = pest_test::cargo_manifest_dir().join(path)
81 }
82 args.dir = path
83 }
84 "subdir" => {
85 args.subdir = match lit {
86 Lit::Str(s) => Some(PathBuf::from(s.value())),
87 _ => abort!(lit, "Invalid argument to 'subdir' attribute"),
88 }
89 }
90 "ext" => {
91 args.ext = match lit {
92 Lit::Str(s) => s.value(),
93 _ => abort!(lit, "Invalid argument to 'ext' attribute"),
94 }
95 }
96 "recursive" => {
97 args.recursive = match lit {
98 Lit::Bool(b) => b.value,
99 _ => abort!(lit, "Invalid argument to 'recursive' attribute"),
100 }
101 }
102 "strict" => {
103 args.strict = match lit {
104 Lit::Bool(b) => b.value,
105 _ => abort!(lit, "Invalid argument to 'strict' attribute"),
106 }
107 }
108 "no_eoi" => {
109 args.no_eoi = match lit {
110 Lit::Bool(b) => b.value,
111 _ => abort!(lit, "Invalid argument to 'no_eoi' attribute"),
112 }
113 }
114 "lazy_static" => {
115 args.lazy_static = match lit {
116 Lit::Bool(b) => b.value,
117 _ => abort!(lit, "Invalid argument to 'lazy_static' attribute"),
118 }
119 }
120 _ => abort!(path, "Invalid argument to pest_test_gen macro"),
121 }
122 }
123 NestedMeta::Meta(Meta::List(MetaList {
124 path,
125 paren_token: _,
126 nested,
127 })) => {
128 let attr_name = path
129 .get_ident()
130 .unwrap_or_else(|| abort!(path, "Invalid argument to pest_test_gen macro"))
131 .to_string();
132 if attr_name == "skip_rule" {
133 for rule_meta in nested {
134 match rule_meta {
135 NestedMeta::Lit(Lit::Str(s)) => {
136 let rule_name = s.value();
137 args.skip_rules
138 .push(Ident::new(rule_name.as_ref(), Span::call_site()));
139 if rule_name == "EOI" {
141 args.no_eoi = true;
142 }
143 }
144 _ => abort!(rule_meta, "Invalid skip_rule item"),
145 }
146 }
147 } else {
148 abort!(path, "Invalid argument to pest_test_gen macro");
149 }
150 }
151 _ => abort!(arg, "Invalid argument to pest_test_gen macro"),
152 }
153 }
154
155 args
156 }
157
158 fn iter_tests(&self) -> impl Iterator<Item = String> + '_ {
159 let dir = self
160 .subdir
161 .as_ref()
162 .map(|subdir| Cow::Owned(self.dir.join(subdir)))
163 .unwrap_or(Cow::Borrowed(&self.dir));
164 let mut walker = WalkDir::new(dir.as_ref());
165 if !self.recursive {
166 walker = walker.max_depth(1);
167 }
168 walker
169 .into_iter()
170 .filter_map(|entry| entry.ok())
171 .filter(|entry| {
172 let path = entry.path();
173 if path.is_dir() {
174 false
175 } else if self.ext.is_empty() {
176 path.extension().is_none()
177 } else {
178 entry.path().extension() == Some(self.ext.as_ref())
179 }
180 })
181 .map(move |entry| {
182 entry
183 .path()
184 .strip_prefix(dir.as_ref())
185 .expect("Error getting relative path of {:?}")
186 .with_extension("")
187 .as_os_str()
188 .to_str()
189 .unwrap()
190 .to_owned()
191 })
192 }
193}
194
195fn rule_variant(rule_path: &Path, variant_ident: Ident) -> Path {
196 let mut path = rule_path.clone();
197 path.segments.push(PathSegment {
198 ident: variant_ident,
199 arguments: PathArguments::None,
200 });
201 path
202}
203
204fn add_tests(module: &mut ItemMod, args: &Args) {
205 let (_, content) = module.content.get_or_insert_with(Default::default);
206
207 let test_dir = args.dir.as_os_str().to_str().unwrap().to_owned();
208 let test_ext = args.ext.clone();
209 let parser_path = &args.parser_path;
210 let rule_path = &args.rule_path;
211 let rule_ident = &args.rule_ident;
212 let mut skip_rules: Vec<Path> = args
213 .skip_rules
214 .iter()
215 .map(|ident| rule_variant(rule_path, ident.clone()))
216 .collect();
217 if !args.no_eoi {
218 skip_rules.push(rule_variant(
219 rule_path,
220 Ident::new("EOI", Span::call_site()),
221 ));
222 }
223
224 if args.lazy_static {
225 let lazy_static_tokens = quote! {
226 lazy_static::lazy_static! {
227 static ref COLORIZE: bool = {
228 option_env!("CARGO_TERM_COLOR").unwrap_or("always") != "never"
229 };
230 static ref TESTER: pest_test::PestTester<#rule_path, #parser_path> =
231 pest_test::PestTester::new(
232 #test_dir,
233 #test_ext,
234 #rule_path::#rule_ident,
235 std::collections::HashSet::from([#(#skip_rules),*])
236 );
237 }
238 };
239 let item: Item = match syn::parse2(lazy_static_tokens) {
240 Ok(item) => item,
241 Err(err) => abort_call_site!(format!("Error generating lazy_static block: {:?}", err)),
242 };
243 content.push(item);
244 }
245
246 for test_name in args.iter_tests() {
247 let fn_name = format_ident!("test_{}", test_name);
248 let fn_tokens = if args.lazy_static {
249 quote! {
250 #[test]
251 fn #fn_name() -> Result<(), pest_test::TestError<#rule_path>> {
252 let res = (*TESTER).evaluate_strict(#test_name);
253 if let Err(pest_test::TestError::Diff { ref diff }) = res {
254 diff.print_test_result(*COLORIZE).unwrap();
255 }
256 res
257 }
258 }
259 } else {
260 quote! {
261 #[test]
262 fn #fn_name() -> Result<(), pest_test::TestError<#rule_path>> {
263 let tester: pest_test::PestTester<#rule_path, #parser_path> = pest_test::PestTester::new(
264 #test_dir,
265 #test_ext,
266 #rule_path::#rule_ident,
267 std::collections::HashSet::from([#(#skip_rules),*])
268 );
269 let res = tester.evaluate_strict(#test_name);
270 if let Err(pest_test::TestError::Diff { ref diff }) = res {
271 let colorize = option_env!("CARGO_TERM_COLOR").unwrap_or("always") != "never";
272 diff.print_test_result(colorize).unwrap();
273 }
274 res
275 }
276 }
277 };
278 let item: Item = match syn::parse2(fn_tokens) {
279 Ok(item) => item,
280 Err(err) => {
281 abort_call_site!(format!("Error generating test fn {}: {:?}", test_name, err))
282 }
283 };
284 content.push(item);
285 }
286}
287
288#[proc_macro_attribute]
332#[proc_macro_error]
333pub fn pest_tests(attr: TokenStream, item: TokenStream) -> TokenStream {
334 let args = Args::from(parse_macro_input!(attr as AttributeArgs));
335 let mut module = match parse_macro_input!(item as Item) {
336 Item::Mod(module) => module,
337 other => abort!(
338 other,
339 "The pest_test_gen macro may only be used as an attribute on a module"
340 ),
341 };
342 add_tests(&mut module, &args);
343 module.to_token_stream().into()
344}