p_test/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    parenthesized,
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10    Expr, Ident, ItemFn, Result, Token,
11};
12
13/// Input for `p_test` attribute, consists of test name (optional),
14/// and a list of test cases. The test name will be used as a module name
15/// for the test. When the name is omitted, the test function name
16/// will be used instead.
17struct Input {
18    test_name: Option<Ident>,
19    test_cases: Vec<TestCase>,
20}
21
22impl Parse for Input {
23    fn parse(input: ParseStream) -> Result<Self> {
24        let test_name = if input.peek(Ident) {
25            let test_name = input.parse::<Ident>()?;
26            let _ = input.parse::<Token![,]>()?;
27            Some(test_name)
28        } else {
29            None
30        };
31        let test_cases = Punctuated::<TestCase, Token![,]>::parse_terminated(input)?
32            .into_iter()
33            .collect();
34        Ok(Input {
35            test_name,
36            test_cases,
37        })
38    }
39}
40
41/// Represent test case, consists of case name (optional), 
42/// and a list of arguments for the test function, (case_name, args...)
43/// One of the args can be used as an expected value.
44/// If the case name is omitted, the case name will be generated 
45/// in `case_{n}` format, where `n` is the case number.
46struct TestCase {
47    name: Option<Ident>,
48    args: Vec<Expr>,
49}
50
51impl Parse for TestCase {
52    fn parse(input: ParseStream) -> Result<Self> {
53        let content;
54        let _ = parenthesized!(content in input);
55        let name = if content.peek(Ident) {
56            let name = content.parse()?;
57            let _ = content.parse::<Token![,]>()?;
58            Some(name)
59        } else {
60            None
61        };
62            let args: Vec<Expr> = Punctuated::<Expr, Token![,]>::parse_terminated(&content)?
63                .into_iter()
64                .collect();
65            Ok(TestCase { name, args })
66    }
67}
68
69fn test_case_name(name: Option<Ident>, counter: i32, n_all: usize) -> Ident {
70    if let Some(name) = name {
71        name
72    } else {
73        let name = if n_all < 10 {
74            &format!("case_{counter}")
75        } else if n_all < 100 {
76            &format!("case_{counter:02}")
77        } else if n_all < 1000 {
78            &format!("case_{counter:03}")
79        } else {
80            &format!("case_{counter}")
81        };
82        Ident::new(name, proc_macro::Span::call_site().into())
83    }
84}
85
86/// The attribute that annotates a function with arguments for parameterized test.
87#[proc_macro_attribute]
88pub fn p_test(attr: TokenStream, item: TokenStream) -> TokenStream {
89    let attr_input = parse_macro_input!(attr as Input);
90
91    let item = parse_macro_input!(item as ItemFn);
92    let p_test_fn_sig = &item.sig;
93    let p_test_fn_name = &item.sig.ident;
94    let p_test_fn_block = &item.block;
95
96    let mut output = quote! {
97        #p_test_fn_sig {
98            #p_test_fn_block
99        }
100    };
101
102    let mut test_functions = quote! {};
103
104    let mut counter = 0;
105    let n_all = attr_input.test_cases.len();
106    for TestCase { name, args} in attr_input.test_cases {
107        counter += 1;
108        let name = test_case_name(name, counter, n_all);
109        let mut arg_list = quote! {};
110        for e in args {
111            arg_list.extend(quote! { #e, });
112        }
113        test_functions.extend(quote! {
114            #[test]
115            fn #name() {
116                #p_test_fn_name(#arg_list);
117            }
118        })
119    }
120
121    let test_name = attr_input.test_name.unwrap_or(p_test_fn_name.clone());
122    output.extend(quote! {
123        #[cfg(test)]
124        mod #test_name {
125            use super::*;
126            #test_functions
127        }
128    });
129
130    output.into()
131}