pest_test_gen/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_error::{abort, abort_call_site, proc_macro_error};
4use quote::{format_ident, quote, ToTokens};
5use std::{borrow::Cow, path::PathBuf};
6use syn::{
7    parse_macro_input, AttributeArgs, Ident, Item, ItemMod, Lit, Meta, MetaList, MetaNameValue,
8    NestedMeta, Path, PathArguments, PathSegment,
9};
10use walkdir::WalkDir;
11
12struct Args {
13    parser_path: Path,
14    rule_path: Path,
15    rule_ident: Ident,
16    skip_rules: Vec<Ident>,
17    dir: PathBuf,
18    subdir: Option<PathBuf>,
19    ext: String,
20    recursive: bool,
21    strict: bool,
22    no_eoi: bool,
23    lazy_static: bool,
24}
25
26impl Args {
27    fn from(attr_args: Vec<NestedMeta>) -> Self {
28        let mut attr_args_iter = attr_args.into_iter();
29
30        // process required attrs
31        let parser_path = match attr_args_iter.next() {
32            Some(NestedMeta::Meta(Meta::Path(path))) => path,
33            Some(other) => abort!(other, "Invalid parser type"),
34            None => abort_call_site!("Missing required argument <parser>"),
35        };
36        let rule_path = match attr_args_iter.next() {
37            Some(NestedMeta::Meta(Meta::Path(path))) => path,
38            Some(other) => abort!(other, "Invalid rule"),
39            None => abort_call_site!("Missing required argument <rule>"),
40        };
41        let rule_ident = match attr_args_iter.next() {
42            Some(NestedMeta::Lit(Lit::Str(s))) => Ident::new(s.value().as_ref(), Span::call_site()),
43            Some(other) => abort!(other, "Invalid rule name"),
44            None => abort_call_site!("Missing required argument <root rule name>"),
45        };
46
47        let mut args = Args {
48            parser_path,
49            rule_path,
50            rule_ident,
51            skip_rules: Vec::new(),
52            dir: pest_test::default_test_dir(),
53            subdir: None,
54            ext: String::from("txt"),
55            recursive: false,
56            strict: true,
57            no_eoi: false,
58            lazy_static: false,
59        };
60
61        // process optional attrs
62        for arg in attr_args_iter {
63            match arg {
64                NestedMeta::Meta(Meta::NameValue(MetaNameValue {
65                    path,
66                    eq_token: _,
67                    lit,
68                })) => {
69                    let attr_name = path
70                        .get_ident()
71                        .unwrap_or_else(|| abort!(path, "Invalid argument to pest_test_gen macro"))
72                        .to_string();
73                    match attr_name.as_str() {
74                        "dir" => {
75                            let mut path = match lit {
76                                Lit::Str(s) => PathBuf::from(s.value()),
77                                _ => abort!(lit, "Invalid argument to 'dir' attribute"),
78                            };
79                            if path.is_relative() {
80                                path = pest_test::cargo_manifest_dir().join(path)
81                            }
82                            args.dir = path
83                        }
84                        "subdir" => {
85                            args.subdir = match lit {
86                                Lit::Str(s) => Some(PathBuf::from(s.value())),
87                                _ => abort!(lit, "Invalid argument to 'subdir' attribute"),
88                            }
89                        }
90                        "ext" => {
91                            args.ext = match lit {
92                                Lit::Str(s) => s.value(),
93                                _ => abort!(lit, "Invalid argument to 'ext' attribute"),
94                            }
95                        }
96                        "recursive" => {
97                            args.recursive = match lit {
98                                Lit::Bool(b) => b.value,
99                                _ => abort!(lit, "Invalid argument to 'recursive' attribute"),
100                            }
101                        }
102                        "strict" => {
103                            args.strict = match lit {
104                                Lit::Bool(b) => b.value,
105                                _ => abort!(lit, "Invalid argument to 'strict' attribute"),
106                            }
107                        }
108                        "no_eoi" => {
109                            args.no_eoi = match lit {
110                                Lit::Bool(b) => b.value,
111                                _ => abort!(lit, "Invalid argument to 'no_eoi' attribute"),
112                            }
113                        }
114                        "lazy_static" => {
115                            args.lazy_static = match lit {
116                                Lit::Bool(b) => b.value,
117                                _ => abort!(lit, "Invalid argument to 'lazy_static' attribute"),
118                            }
119                        }
120                        _ => abort!(path, "Invalid argument to pest_test_gen macro"),
121                    }
122                }
123                NestedMeta::Meta(Meta::List(MetaList {
124                    path,
125                    paren_token: _,
126                    nested,
127                })) => {
128                    let attr_name = path
129                        .get_ident()
130                        .unwrap_or_else(|| abort!(path, "Invalid argument to pest_test_gen macro"))
131                        .to_string();
132                    if attr_name == "skip_rule" {
133                        for rule_meta in nested {
134                            match rule_meta {
135                                NestedMeta::Lit(Lit::Str(s)) => {
136                                    let rule_name = s.value();
137                                    args.skip_rules
138                                        .push(Ident::new(rule_name.as_ref(), Span::call_site()));
139                                    // if EOI is added manually, don't add it again automatically
140                                    if rule_name == "EOI" {
141                                        args.no_eoi = true;
142                                    }
143                                }
144                                _ => abort!(rule_meta, "Invalid skip_rule item"),
145                            }
146                        }
147                    } else {
148                        abort!(path, "Invalid argument to pest_test_gen macro");
149                    }
150                }
151                _ => abort!(arg, "Invalid argument to pest_test_gen macro"),
152            }
153        }
154
155        args
156    }
157
158    fn iter_tests(&self) -> impl Iterator<Item = String> + '_ {
159        let dir = self
160            .subdir
161            .as_ref()
162            .map(|subdir| Cow::Owned(self.dir.join(subdir)))
163            .unwrap_or(Cow::Borrowed(&self.dir));
164        let mut walker = WalkDir::new(dir.as_ref());
165        if !self.recursive {
166            walker = walker.max_depth(1);
167        }
168        walker
169            .into_iter()
170            .filter_map(|entry| entry.ok())
171            .filter(|entry| {
172                let path = entry.path();
173                if path.is_dir() {
174                    false
175                } else if self.ext.is_empty() {
176                    path.extension().is_none()
177                } else {
178                    entry.path().extension() == Some(self.ext.as_ref())
179                }
180            })
181            .map(move |entry| {
182                entry
183                    .path()
184                    .strip_prefix(dir.as_ref())
185                    .expect("Error getting relative path of {:?}")
186                    .with_extension("")
187                    .as_os_str()
188                    .to_str()
189                    .unwrap()
190                    .to_owned()
191            })
192    }
193}
194
195fn rule_variant(rule_path: &Path, variant_ident: Ident) -> Path {
196    let mut path = rule_path.clone();
197    path.segments.push(PathSegment {
198        ident: variant_ident,
199        arguments: PathArguments::None,
200    });
201    path
202}
203
204fn add_tests(module: &mut ItemMod, args: &Args) {
205    let (_, content) = module.content.get_or_insert_with(Default::default);
206
207    let test_dir = args.dir.as_os_str().to_str().unwrap().to_owned();
208    let test_ext = args.ext.clone();
209    let parser_path = &args.parser_path;
210    let rule_path = &args.rule_path;
211    let rule_ident = &args.rule_ident;
212    let mut skip_rules: Vec<Path> = args
213        .skip_rules
214        .iter()
215        .map(|ident| rule_variant(rule_path, ident.clone()))
216        .collect();
217    if !args.no_eoi {
218        skip_rules.push(rule_variant(
219            rule_path,
220            Ident::new("EOI", Span::call_site()),
221        ));
222    }
223
224    if args.lazy_static {
225        let lazy_static_tokens = quote! {
226            lazy_static::lazy_static! {
227                static ref COLORIZE: bool = {
228                    option_env!("CARGO_TERM_COLOR").unwrap_or("always") != "never"
229                };
230                static ref TESTER: pest_test::PestTester<#rule_path, #parser_path> =
231                    pest_test::PestTester::new(
232                        #test_dir,
233                        #test_ext,
234                        #rule_path::#rule_ident,
235                        std::collections::HashSet::from([#(#skip_rules),*])
236                    );
237            }
238        };
239        let item: Item = match syn::parse2(lazy_static_tokens) {
240            Ok(item) => item,
241            Err(err) => abort_call_site!(format!("Error generating lazy_static block: {:?}", err)),
242        };
243        content.push(item);
244    }
245
246    for test_name in args.iter_tests() {
247        let fn_name = format_ident!("test_{}", test_name);
248        let fn_tokens = if args.lazy_static {
249            quote! {
250                #[test]
251                fn #fn_name() -> Result<(), pest_test::TestError<#rule_path>> {
252                    let res = (*TESTER).evaluate_strict(#test_name);
253                    if let Err(pest_test::TestError::Diff { ref diff }) = res {
254                        diff.print_test_result(*COLORIZE).unwrap();
255                    }
256                    res
257                }
258            }
259        } else {
260            quote! {
261                #[test]
262                fn #fn_name() -> Result<(), pest_test::TestError<#rule_path>> {
263                    let tester: pest_test::PestTester<#rule_path, #parser_path> = pest_test::PestTester::new(
264                        #test_dir,
265                        #test_ext,
266                        #rule_path::#rule_ident,
267                        std::collections::HashSet::from([#(#skip_rules),*])
268                    );
269                    let res = tester.evaluate_strict(#test_name);
270                    if let Err(pest_test::TestError::Diff { ref diff }) = res {
271                        let colorize = option_env!("CARGO_TERM_COLOR").unwrap_or("always") != "never";
272                        diff.print_test_result(colorize).unwrap();
273                    }
274                    res
275                }
276            }
277        };
278        let item: Item = match syn::parse2(fn_tokens) {
279            Ok(item) => item,
280            Err(err) => {
281                abort_call_site!(format!("Error generating test fn {}: {:?}", test_name, err))
282            }
283        };
284        content.push(item);
285    }
286}
287
288/// When added to a test module, adds test functions for pest-test test cases. Must come before
289/// the `#[cfg(test)]` attribute. If you specify `lazy_static = true` then a singleton `PestTester`
290/// is created and used by all the generated test functions (dependency on `lazy_static` is
291/// required), otherwise a separate instance is created for each test.
292///
293/// # Arguments:
294/// * **parser_type**: (required) the full path to the struct you defined that derives `pest::Parser`,
295///   e.g. `mycrate::parser::MyParser`.
296/// * **rule_type**: (required) the full path to the `Rule` enum, e.g. `mycrate::parser::Rule`.
297/// * **rule_name**: (required) the name of the `Rule` variant from which to start parsing.
298/// * skip_rules: (optional) a list of names of rules to skip when parsing; by default `Rule::EOI` is
299///   skipped unless `no_eoi = true`.
300/// * no_eoi: (optional) there is no `Rule::EOI` - don't automatically add it to `skip_rules`.
301/// * dir: (optional) the root directory where pest test cases are found; defaults to 'tests/pest'.
302/// * subdir: (optional) the subdirectory under `tests/pest` in which to look for test cases;
303///   defaults to "".
304/// * ext: (optional) the file extension of pest test cases; defaults to "txt".
305/// * recursive: (optional) whether to search for tests cases recursively under `{dir}/{subdir}`;
306///   defaults to `false`.
307/// * strict: (optional) whether to enforce that terminal node values must match between the
308///   expected and actual parse trees; defaults to `true`.
309/// * lazy_static: (optional) whether to create a singleton `PestTester` - requires dependency on
310///   `lazy_static`; defaults to `false`.
311///
312/// # Example:
313/// ```
314///
315/// use pest_test_gen;
316///
317/// #[pest_tests(
318///     mycrate::parser::MyParser,
319///     mycrate::parser::Rule,
320///     "root_rule",
321///     skip_rules("comment"),
322///     subdir = "foo",
323///     recursive = true,
324///     lazy_static = true
325/// )]
326/// #[cfg(test)]
327/// mod parser_tests {}
328///
329/// ```
330
331#[proc_macro_attribute]
332#[proc_macro_error]
333pub fn pest_tests(attr: TokenStream, item: TokenStream) -> TokenStream {
334    let args = Args::from(parse_macro_input!(attr as AttributeArgs));
335    let mut module = match parse_macro_input!(item as Item) {
336        Item::Mod(module) => module,
337        other => abort!(
338            other,
339            "The pest_test_gen macro may only be used as an attribute on a module"
340        ),
341    };
342    add_tests(&mut module, &args);
343    module.to_token_stream().into()
344}