datatest_derive/
lib.rs

1#![recursion_limit = "128"]
2#![deny(unused_must_use)]
3extern crate proc_macro;
4
5use proc_macro2::{Span, TokenStream};
6use quote::quote;
7use std::collections::HashMap;
8use syn::parse::{Parse, ParseStream, Result as ParseResult};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::token::Comma;
12use syn::{braced, parse_macro_input, FnArg, Ident, ItemFn, Pat, PatIdent, PatType, Token, Type};
13
14type Error = syn::parse::Error;
15
16struct TemplateArg {
17    ident: syn::Ident,
18    is_pattern: bool,
19    ignore_fn: Option<syn::Path>,
20    value: syn::LitStr,
21}
22
23impl Parse for TemplateArg {
24    fn parse(input: ParseStream) -> ParseResult<Self> {
25        let mut ignore_fn = None;
26        let ident = input.parse::<syn::Ident>()?;
27
28        let is_pattern = if input.peek(syn::token::In) {
29            let _in = input.parse::<syn::token::In>()?;
30            true
31        } else {
32            let _eq = input.parse::<syn::token::Eq>()?;
33            false
34        };
35        let value = input.parse::<syn::LitStr>()?;
36        if is_pattern && input.peek(syn::token::If) {
37            let _if = input.parse::<syn::token::If>()?;
38            let _not = input.parse::<syn::token::Not>()?;
39            ignore_fn = Some(input.parse::<syn::Path>()?);
40        }
41        Ok(Self {
42            ident,
43            is_pattern,
44            ignore_fn,
45            value,
46        })
47    }
48}
49
50/// Parse `#[file_test(...)]` attribute arguments
51/// The syntax is the following:
52///
53/// ```ignore
54/// #[files("<root>", {
55///   <arg_name> in "<regexp>",
56///   <arg_name> in "<template>",
57/// }]
58/// ```
59struct FilesTestArgs {
60    root: String,
61    args: HashMap<Ident, TemplateArg>,
62}
63
64/// See `syn` crate documentation / sources for more examples.
65impl Parse for FilesTestArgs {
66    fn parse(input: ParseStream) -> ParseResult<Self> {
67        let root = input.parse::<syn::LitStr>()?;
68        let _comma = input.parse::<syn::token::Comma>()?;
69        let content;
70        let _brace_token = braced!(content in input);
71
72        let args: Punctuated<TemplateArg, Comma> =
73            content.parse_terminated(TemplateArg::parse, Token![,])?;
74        let args = args
75            .into_pairs()
76            .map(|p| {
77                let value = p.into_value();
78                (value.ident.clone(), value)
79            })
80            .collect();
81
82        Ok(Self {
83            root: root.value(),
84            args,
85        })
86    }
87}
88
/// Strategy used to make generated test descriptors visible to the test runner.
enum Registration {
    /// Emit `#[test_case]` on the descriptor (needs the unstable `custom_test_frameworks`
    /// feature).
    Nightly,
    /// Register the descriptor from a "global" constructor (https://crates.io/crates/ctor).
    Ctor,
}
96
97/// Wrapper that turns on behavior that works on stable Rust.
98#[proc_macro_attribute]
99pub fn files_ctor_registration(
100    args: proc_macro::TokenStream,
101    func: proc_macro::TokenStream,
102) -> proc_macro::TokenStream {
103    guarded_test_attribute(
104        args,
105        func,
106        Ident::new("files_ctor_internal", Span::call_site()),
107    )
108}
109
110/// Wrapper that turns on behavior that works only on nightly Rust.
111#[proc_macro_attribute]
112pub fn files_test_case_registration(
113    args: proc_macro::TokenStream,
114    func: proc_macro::TokenStream,
115) -> proc_macro::TokenStream {
116    guarded_test_attribute(
117        args,
118        func,
119        Ident::new("files_test_case_internal", Span::call_site()),
120    )
121}
122
123#[proc_macro_attribute]
124pub fn files_ctor_internal(
125    args: proc_macro::TokenStream,
126    func: proc_macro::TokenStream,
127) -> proc_macro::TokenStream {
128    files_internal(args, func, Registration::Ctor)
129}
130
131#[proc_macro_attribute]
132pub fn files_test_case_internal(
133    args: proc_macro::TokenStream,
134    func: proc_macro::TokenStream,
135) -> proc_macro::TokenStream {
136    files_internal(args, func, Registration::Nightly)
137}
138
139/// Proc macro handling `#[files(...)]` syntax. This attribute defines rules for deriving
140/// test function arguments from file paths. There are two types of rules:
141/// 1. Pattern rule, `<arg_name> in "<regexp>"`
142/// 2. Template rule, `<arg_name> = "regexp"`
143///
144/// There must be only one pattern rule defined in the attribute. It defines a regular expression
145/// to run against all files found in the test directory.
146///
147/// Template rule defines rules how the name of the matched file is transformed to get related files.
148///
149/// This macro is responsible for generating a test descriptor (`datatest::FilesTestDesc`) based on the
150/// `#[files(..)]` attribute attached to the test function.
151///
152/// There are four fields specific for these type of tests we need to fill in:
153///
154/// 1. `root`, which is the root directory to scan for the tests (relative to the root of the crate
155/// with tests)
156/// 2. `params`, slice of strings, each string is either a template or pattern assigned to the
157/// function argument
158/// 3. `pattern`, an index of the "pattern" argument (since exactly one is required, it is just an
159/// index in the `params` array).
160/// 4. `testfn`, test function trampoline.
161///
162/// Few words about trampoline function. Each test function could have a unique signature, depending
163/// on which types it needs and which files it requires as an input. However, our test framework
164/// should be capable of running these test functions via some standardized interface. This interface
165/// is `fn(&[PathBuf])`. Each slice element matches test function argument (so length of this slice
166/// is the same as amount of arguments test function has).
167///
168/// In addition to that, this trampoline function is also responsible for mapping `&PathBuf`
169/// references into argument types. There is some trait magic involved to make code work for both
170/// cases when function takes argument as a slice (`&str`, `&[u8]`) and for cases when function takes
171/// argument as owned (`String`, `Vec<u8>`).
172///
173/// The difficulty here is that for owned arguments we can create value and just pass it down to the
174/// function. However, for arguments taking slices, we need to store value somewhere on the stack
175/// and pass a reference.
176///
177/// I could have made this proc macro to handle these cases explicitly and generate a different
178/// code, but I decided to not add a complexity of type analysis to the proc macro and use traits
179/// instead. See `datatest::TakeArg` and `datatest::DeriveArg` to see how this mechanism works.
180fn files_internal(
181    args: proc_macro::TokenStream,
182    func: proc_macro::TokenStream,
183    channel: Registration,
184) -> proc_macro::TokenStream {
185    let mut func_item: ItemFn = parse_macro_input!(func as ItemFn);
186    let args: FilesTestArgs = parse_macro_input!(args as FilesTestArgs);
187    let info = handle_common_attrs(&mut func_item, false);
188    let func_ident = &func_item.sig.ident;
189    let func_name_str = func_ident.to_string();
190    let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
191    let trampoline_func_ident = Ident::new(
192        &format!("__TEST_TRAMPOLINE_{}", func_ident),
193        func_ident.span(),
194    );
195    let ignore = info.ignore;
196    let root = args.root;
197    let mut pattern_idx = None;
198    let mut params: Vec<String> = Vec::new();
199    let mut invoke_args: Vec<TokenStream> = Vec::new();
200    let mut ignore_fn = None;
201
202    // Match function arguments with our parsed list of mappings
203    // We do the following in this loop:
204    // 1. For each argument we collect the corresponding template defined for that argument
205    // 2. For each argument we collect piece of code to create argument from the `&[PathBuf]` slice
206    // given to us by the test runner.
207    // 3. Capture the index of the argument corresponding to the "pattern" mapping
208    for (mut idx, arg) in func_item.sig.inputs.iter().enumerate() {
209        match match_arg(arg) {
210            Some((pat_ident, ty)) => {
211                if info.bench {
212                    if idx == 0 {
213                        // FIXME: verify is Bencher!
214                        invoke_args.push(quote!(#pat_ident));
215                        continue;
216                    } else {
217                        idx -= 1;
218                    }
219                }
220
221                if let Some(arg) = args.args.get(&pat_ident.ident) {
222                    if arg.is_pattern {
223                        if pattern_idx.is_some() {
224                            return Error::new(arg.ident.span(), "two patterns are not allowed!")
225                                .to_compile_error()
226                                .into();
227                        }
228                        pattern_idx = Some(idx);
229                        ignore_fn = arg.ignore_fn.clone();
230                    }
231
232                    params.push(arg.value.value());
233                    invoke_args.push(quote! {
234                        ::datatest::__internal::TakeArg::take(&mut <#ty as ::datatest::__internal::DeriveArg>::derive(&paths_arg[#idx]))
235                    })
236                } else {
237                    return Error::new(pat_ident.span(), "mapping is not defined for the argument")
238                        .to_compile_error()
239                        .into();
240                }
241            }
242            None => {
243                return Error::new(
244                    arg.span(),
245                    "unexpected argument; only simple argument types are allowed (`&str`, `String`, `&[u8]`, `Vec<u8>`, `&Path`, etc)",
246                ).to_compile_error().into();
247            }
248        }
249    }
250
251    let ignore_func_ref = if let Some(ignore_fn) = ignore_fn {
252        quote!(Some(#ignore_fn))
253    } else {
254        quote!(None)
255    };
256
257    if pattern_idx.is_none() {
258        return Error::new(
259            Span::call_site(),
260            "must have exactly one pattern mapping defined via `pattern in r#\"<regular expression>\"`",
261        )
262            .to_compile_error()
263            .into();
264    }
265
266    let (kind, bencher_param) = if info.bench {
267        (
268            quote!(BenchFn),
269            quote!(bencher: &mut ::datatest::__internal::Bencher,),
270        )
271    } else {
272        (quote!(TestFn), quote!())
273    };
274
275    let registration = test_registration(channel, &desc_ident);
276    let output = quote! {
277        #registration
278        #[automatically_derived]
279        #[allow(non_upper_case_globals)]
280        static #desc_ident: ::datatest::__internal::FilesTestDesc = ::datatest::__internal::FilesTestDesc {
281            name: concat!(module_path!(), "::", #func_name_str),
282            ignore: #ignore,
283            root: #root,
284            params: &[#(#params),*],
285            pattern: #pattern_idx,
286            ignorefn: #ignore_func_ref,
287            testfn: ::datatest::__internal::FilesTestFn::#kind(#trampoline_func_ident),
288            source_file: file!(),
289        };
290
291        #[automatically_derived]
292        #[allow(non_snake_case)]
293        fn #trampoline_func_ident(#bencher_param paths_arg: &[::std::path::PathBuf]) {
294            let result = #func_ident(#(#invoke_args),*);
295            ::datatest::__internal::assert_test_result(result);
296        }
297
298        #func_item
299    };
300    output.into()
301}
302
303fn match_arg(arg: &FnArg) -> Option<(&PatIdent, &Type)> {
304    if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
305        if let Pat::Ident(pat_ident) = pat.as_ref() {
306            return Some((pat_ident, ty));
307        }
308    }
309    None
310}
311
/// Panic expectation of a regular test, taken from its `#[should_panic]` attribute.
enum ShouldPanic {
    /// The test has no `#[should_panic]` attribute.
    No,
    /// The test must panic; no message is checked.
    Yes,
    /// The test must panic with the given expected message.
    YesWithMessage(String),
}
317
318struct FuncInfo {
319    ignore: bool,
320    bench: bool,
321    should_panic: ShouldPanic,
322}
323
324/// Only allows certain attributes (`#[should_panic]`, for example) when used against a "regular"
325/// test `#[test]`.
326fn handle_common_attrs(func: &mut ItemFn, regular_test: bool) -> FuncInfo {
327    // Remove #[test] attribute as we don't want standard test framework to handle it!
328    // We allow #[test] to be used to improve IDE experience (namely, IntelliJ Rust), which would
329    // only allow you to run test if it is marked with `#[test]`
330    let test_pos = func
331        .attrs
332        .iter()
333        .position(|attr| attr.path().is_ident("test"));
334    if let Some(pos) = test_pos {
335        func.attrs.remove(pos);
336    }
337
338    // Same for #[bench]
339    let bench_pos = func
340        .attrs
341        .iter()
342        .position(|attr| attr.path().is_ident("bench"));
343    if let Some(pos) = bench_pos {
344        func.attrs.remove(pos);
345    }
346
347    // Allow tests to be marked as `#[ignore]`.
348    let ignore_pos = func
349        .attrs
350        .iter()
351        .position(|attr| attr.path().is_ident("ignore"));
352    if let Some(pos) = ignore_pos {
353        func.attrs.remove(pos);
354    }
355
356    let mut should_panic = ShouldPanic::No;
357    if regular_test {
358        // Regular tests support (on stable channel): allow `#[should_panic]`
359        let should_panic_pos = func
360            .attrs
361            .iter()
362            .position(|attr| attr.path().is_ident("should_panic"));
363        if let Some(pos) = should_panic_pos {
364            let attr = &func.attrs[pos];
365            should_panic = parse_should_panic(attr);
366            func.attrs.remove(pos);
367        }
368    }
369
370    FuncInfo {
371        ignore: ignore_pos.is_some(),
372        bench: bench_pos.is_some(),
373        should_panic,
374    }
375}
376
377#[allow(clippy::collapsible_match)]
378fn parse_should_panic(attr: &syn::Attribute) -> ShouldPanic {
379    let mut message: Option<String> = None;
380    _ = attr.parse_nested_meta(|meta| {
381        if meta.path.is_ident("expected") {
382            if let Ok(v) = meta.value() {
383                let mut value = v.to_string();
384                if value.starts_with("\"") {
385                    value.remove(0);
386                }
387                if value.ends_with("\"") {
388                    value.pop();
389                }
390                message = Some(value);
391            }
392        }
393        Ok(())
394    });
395    match message {
396        Some(message) => ShouldPanic::YesWithMessage(message),
397        None => ShouldPanic::Yes,
398    }
399}
400
401/// Parse `#[data(...)]` attribute arguments. It's either a function returning
402/// `Vec<datatest::DataTestCaseDesc<T>>` (where `T` is a test case type) or string literal, which
403/// is interpreted as `datatest::yaml("<path>")`
404#[allow(clippy::large_enum_variant)]
405enum DataTestArgs {
406    Literal(syn::LitStr),
407    Expression(syn::Expr),
408}
409
410/// See `syn` crate documentation / sources for more examples.
411impl Parse for DataTestArgs {
412    fn parse(input: ParseStream) -> ParseResult<Self> {
413        let lookahead = input.lookahead1();
414        if lookahead.peek(syn::LitStr) {
415            input.parse::<syn::LitStr>().map(DataTestArgs::Literal)
416        } else {
417            input.parse::<syn::Expr>().map(DataTestArgs::Expression)
418        }
419    }
420}
421
422/// Wrapper that turns on behavior that works on stable Rust.
423#[proc_macro_attribute]
424pub fn data_ctor_registration(
425    args: proc_macro::TokenStream,
426    func: proc_macro::TokenStream,
427) -> proc_macro::TokenStream {
428    guarded_test_attribute(
429        args,
430        func,
431        Ident::new("data_ctor_internal", Span::call_site()),
432    )
433}
434
435/// Wrapper that turns on behavior that works only on nightly Rust.
436#[proc_macro_attribute]
437pub fn data_test_case_registration(
438    args: proc_macro::TokenStream,
439    func: proc_macro::TokenStream,
440) -> proc_macro::TokenStream {
441    guarded_test_attribute(
442        args,
443        func,
444        Ident::new("data_test_case_internal", Span::call_site()),
445    )
446}
447
448#[proc_macro_attribute]
449pub fn data_ctor_internal(
450    args: proc_macro::TokenStream,
451    func: proc_macro::TokenStream,
452) -> proc_macro::TokenStream {
453    data_internal(args, func, Registration::Ctor)
454}
455
456#[proc_macro_attribute]
457pub fn data_test_case_internal(
458    args: proc_macro::TokenStream,
459    func: proc_macro::TokenStream,
460) -> proc_macro::TokenStream {
461    data_internal(args, func, Registration::Nightly)
462}
463
464fn data_internal(
465    args: proc_macro::TokenStream,
466    func: proc_macro::TokenStream,
467    channel: Registration,
468) -> proc_macro::TokenStream {
469    let mut func_item = parse_macro_input!(func as ItemFn);
470    let cases: DataTestArgs = parse_macro_input!(args as DataTestArgs);
471    let info = handle_common_attrs(&mut func_item, false);
472    let cases = match cases {
473        DataTestArgs::Literal(path) => quote!(datatest::yaml(#path)),
474        DataTestArgs::Expression(expr) => quote!(#expr),
475    };
476    let func_ident = &func_item.sig.ident;
477
478    let func_name_str = func_ident.to_string();
479    let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
480    let describe_func_ident = Ident::new(
481        &format!("__TEST_DESCRIBE_{}", func_ident),
482        func_ident.span(),
483    );
484    let trampoline_func_ident = Ident::new(
485        &format!("__TEST_TRAMPOLINE_{}", func_ident),
486        func_ident.span(),
487    );
488
489    let ignore = info.ignore;
490    // FIXME: check file exists!
491    let mut args = func_item.sig.inputs.iter();
492
493    if info.bench {
494        // Skip Bencher argument
495        // FIXME: verify it is &mut Bencher
496        args.next();
497    }
498
499    let arg = args.next();
500    let ty = match arg {
501        Some(FnArg::Typed(PatType { ty, .. })) => Some(ty.as_ref()),
502        _ => None,
503    };
504    let (ref_token, ty) = match ty {
505        Some(syn::Type::Reference(type_ref)) => (quote!(&), Some(type_ref.elem.as_ref())),
506        _ => (TokenStream::new(), ty),
507    };
508
509    let (case_ctor, bencher_param, bencher_arg) = if info.bench {
510        (
511            quote!(::datatest::__internal::DataTestFn::BenchFn(Box::new(::datatest::__internal::DataBenchFn(#trampoline_func_ident, case)))),
512            quote!(bencher: &mut ::datatest::__internal::Bencher,),
513            quote!(bencher,),
514        )
515    } else {
516        (
517            quote!(::datatest::__internal::DataTestFn::TestFn(Box::new(move || #trampoline_func_ident(case)))),
518            quote!(),
519            quote!(),
520        )
521    };
522
523    let registration = test_registration(channel, &desc_ident);
524    let output = quote! {
525        #registration
526        #[automatically_derived]
527        #[allow(non_upper_case_globals)]
528        static #desc_ident: ::datatest::__internal::DataTestDesc = ::datatest::__internal::DataTestDesc {
529            name: concat!(module_path!(), "::", #func_name_str),
530            ignore: #ignore,
531            describefn: #describe_func_ident,
532            source_file: file!(),
533        };
534
535        #[automatically_derived]
536        #[allow(non_snake_case)]
537        fn #trampoline_func_ident(#bencher_param arg: #ty) {
538            let result = #func_ident(#bencher_arg #ref_token arg);
539            ::datatest::__internal::assert_test_result(result);
540        }
541
542        #[automatically_derived]
543        #[allow(non_snake_case)]
544        fn #describe_func_ident() -> Vec<::datatest::DataTestCaseDesc<::datatest::__internal::DataTestFn>> {
545            let result = #cases
546                .into_iter()
547                .map(|input| {
548                    let case = input.case;
549                    ::datatest::DataTestCaseDesc {
550                        case: #case_ctor,
551                        name: input.name,
552                        location: input.location,
553                    }
554                })
555                .collect::<Vec<_>>();
556            assert!(!result.is_empty(), "no test cases were found!");
557            result
558        }
559
560        #func_item
561    };
562    output.into()
563}
564
565fn test_registration(channel: Registration, desc_ident: &syn::Ident) -> TokenStream {
566    match channel {
567        // On nightly, we rely on `custom_test_frameworks` feature
568        Registration::Nightly => quote!(#[test_case]),
569        // On stable, we use `ctor` crate to build a registry of all our tests
570        Registration::Ctor => {
571            let registration_fn =
572                syn::Ident::new(&format!("{}__REGISTRATION", desc_ident), desc_ident.span());
573            let check_fn = syn::Ident::new(&format!("{}__CHECK", desc_ident), desc_ident.span());
574            let tokens = quote! {
575                #[automatically_derived]
576                #[allow(non_snake_case)]
577                #[datatest::__internal::ctor]
578                fn #registration_fn() {
579                    use ::datatest::__internal::RegistrationNode;
580                    static mut REGISTRATION: RegistrationNode = RegistrationNode {
581                        descriptor: &#desc_ident,
582                        next: None,
583                    };
584                    // This runs only once during initialization, so should be safe
585                    ::datatest::__internal::register(unsafe { &mut REGISTRATION });
586                }
587
588                // Make sure we our registry was actually scanned!
589                // This would detect scenario where none of the ways are used to plug datatest
590                // test runner (either by replacing the whole harness or by overriding test runner).
591                // So, for every test we have registered, we make sure this test actually gets
592                // executed.
593                #[automatically_derived]
594                #[allow(non_snake_case)]
595                mod #check_fn {
596                    #[datatest::__internal::dtor]
597                    fn check_fn() {
598                        ::datatest::__internal::check_test_runner();
599                    }
600                }
601            };
602            tokens
603        }
604    }
605}
606
607/// Replacement for the `#[test]` attribute that uses ctor-based test registration so it can be
608/// used when the whole test harness is replaced.
609#[proc_macro_attribute]
610pub fn test_ctor_registration(
611    _args: proc_macro::TokenStream,
612    func: proc_macro::TokenStream,
613) -> proc_macro::TokenStream {
614    let mut func_item = parse_macro_input!(func as ItemFn);
615    let info = handle_common_attrs(&mut func_item, true);
616    let func_ident = &func_item.sig.ident;
617    let func_name_str = func_ident.to_string();
618    let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
619
620    let ignore = info.ignore;
621    let should_panic = match info.should_panic {
622        ShouldPanic::No => quote!(::datatest::__internal::RegularShouldPanic::No),
623        ShouldPanic::Yes => quote!(::datatest::__internal::RegularShouldPanic::Yes),
624        ShouldPanic::YesWithMessage(v) => {
625            quote!(::datatest::__internal::RegularShouldPanic::YesWithMessage(#v))
626        }
627    };
628    let registration = test_registration(Registration::Ctor, &desc_ident);
629    let output = quote! {
630        #registration
631        #[automatically_derived]
632        #[allow(non_upper_case_globals)]
633        static #desc_ident: ::datatest::__internal::RegularTestDesc = ::datatest::__internal::RegularTestDesc {
634            name: concat!(module_path!(), "::", #func_name_str),
635            ignore: #ignore,
636            testfn: || {
637                let result = #func_ident();
638                ::datatest::__internal::assert_test_result(result);
639            },
640            should_panic: #should_panic,
641            source_file: file!(),
642        };
643
644        #func_item
645    };
646
647    output.into()
648}
649
650fn guarded_test_attribute(
651    args: proc_macro::TokenStream,
652    item: proc_macro::TokenStream,
653    implementation: Ident,
654) -> proc_macro::TokenStream {
655    let args: TokenStream = args.into();
656    let header = quote! {
657        #[cfg(test)]
658        #[::datatest::__internal::#implementation(#args)]
659    };
660    let mut out: proc_macro::TokenStream = header.into();
661    out.extend(item);
662    out
663}