p_test/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{
6    parenthesized,
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10    Expr, Ident, ItemFn, LitBool, LitStr, Result, Token,
11};
12
13/// Represents the name of a test case, which can be either an identifier or a string literal.
14/// If the name is not provided, it will be `None`.
15#[derive(PartialEq)]
16enum Name {
17    Some(Ident),
18    None,
19}
20
21impl Parse for Name {
22    fn parse(input: ParseStream) -> Result<Self> {
23        if input.peek(Ident) {
24            let name = Name::Some(input.parse()?);
25            let _ = input.parse::<Token![,]>()?;
26            Ok(name)
27        } else if input.peek(LitStr) {
28            let name = input.parse::<LitStr>()?;
29            let _ = input.parse::<Token![,]>()?;
30            if name.value().is_empty() {
31                Ok(Name::None)
32            } else {
33                Ok(Name::Some(Ident::new(&slugify(&name.value()), name.span())))
34            }
35        } else {
36            Ok(Name::None)
37        }
38    }
39}
40
41/// Input for `p_test` attribute, consists of test name (optional),
42/// and a list of test cases. The test name will be used as a module name
43/// for the test. When the name is omitted, the test function name
44/// will be used instead.
45struct Input {
46    use_args_for_case_name: bool,
47    test_cases: Vec<TestCase>,
48}
49
50impl Parse for Input {
51    fn parse(input: ParseStream) -> Result<Self> {
52        let use_args_for_case_name =
53            if input.peek(Ident) && input.peek2(Token![=]) && input.peek3(LitBool) {
54                let option = input.parse::<Ident>()?;
55                if option != "use_args_for_case_name" {
56                    return Err(syn::Error::new(
57                        option.span(),
58                        "Expected 'use_args_for_case_name' option",
59                    ));
60                }
61                let _ = input.parse::<Token![=]>()?;
62                let use_args_for_case_name = input.parse::<LitBool>()?.value;
63                let _ = input.parse::<Token![,]>()?;
64                use_args_for_case_name
65            } else {
66                false
67            };
68        let test_cases = Punctuated::<TestCase, Token![,]>::parse_terminated(input)?
69            .into_iter()
70            .collect();
71        Ok(Input {
72            use_args_for_case_name,
73            test_cases,
74        })
75    }
76}
77
78/// Represent test case, consists of case name,
79/// and a list of arguments for the test function, (case_name, args...).
80/// One of the args can be used as an expected value.
81/// If the case name is omitted, the case name will be generated.
82struct TestCase {
83    name: Name,
84    args: Vec<Expr>,
85}
86
87impl Parse for TestCase {
88    fn parse(input: ParseStream) -> Result<Self> {
89        let name = input.parse::<Name>()?;
90
91        let content;
92        let _ = parenthesized!(content in input);
93
94        let args: Vec<Expr> = Punctuated::<Expr, Token![,]>::parse_terminated(&content)?
95            .into_iter()
96            .collect();
97        Ok(TestCase { name, args })
98    }
99}
100
101fn case_name_with_counter(name: Name, counter: i32, n_all: usize) -> Ident {
102    match name {
103        Name::Some(name) => name,
104        Name::None => {
105            let name = if n_all < 10 {
106                format!("case_{counter}")
107            } else if n_all < 100 {
108                format!("case_{counter:02}")
109            } else if n_all < 1000 {
110                format!("case_{counter:03}")
111            } else {
112                format!("case_{counter}")
113            };
114            Ident::new(&name, proc_macro::Span::call_site().into())
115        }
116    }
117}
118
119fn case_name_with_args(args: &[Expr]) -> Ident {
120    let name = args
121        .iter()
122        .map(|e| slugify(&e.to_token_stream().to_string()))
123        .collect::<Vec<_>>()
124        .join("_");
125
126    if name.is_empty() {
127        Ident::new("case", proc_macro::Span::call_site().into())
128    } else {
129        Ident::new(&name, proc_macro::Span::call_site().into())
130    }
131}
132
/// Converts arbitrary text into a fragment usable inside a Rust identifier.
///
/// ASCII letters are lower-cased, ASCII digits are kept, and every other
/// character (punctuation, whitespace and all non-ASCII characters) becomes
/// `_`. A leading digit gets an `_` prefix so the result can start an
/// identifier. For example, `"Add 2 + 2"` becomes `"add_2___2"` and
/// `"1st case"` becomes `"_1st_case"`.
///
/// Restricting the output to ASCII matters for two reasons. Non-ASCII
/// "alphanumerics" such as `²` or `½` are not valid identifier characters and
/// would make `Ident::new` panic. Non-ASCII uppercase letters (e.g. `É`) would
/// survive `to_ascii_lowercase` and trigger `non_snake_case` warnings on the
/// generated test functions.
fn slugify(name: &str) -> String {
    let mut slug: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '_'
            }
        })
        .collect();

    // Identifiers cannot start with a digit.
    if slug.starts_with(|c: char| c.is_ascii_digit()) {
        slug.insert(0, '_');
    }

    slug
}
146
147/// The attribute that annotates a function with arguments for parameterized test.
148#[proc_macro_attribute]
149pub fn p_test(attr: TokenStream, item: TokenStream) -> TokenStream {
150    let attr_input = parse_macro_input!(attr as Input);
151
152    let item = parse_macro_input!(item as ItemFn);
153    let p_test_fn_sig = &item.sig;
154    let p_test_fn_name = &item.sig.ident;
155    let p_test_fn_block = &item.block;
156
157    let mut output = quote! {
158        #p_test_fn_sig {
159            #p_test_fn_block
160        }
161    };
162
163    let mut test_functions = quote! {};
164
165    let mut counter = 0;
166    let n_all = attr_input.test_cases.len();
167    for TestCase { name, args } in attr_input.test_cases {
168        counter += 1;
169        let name = if name == Name::None && attr_input.use_args_for_case_name && !args.is_empty() {
170            case_name_with_args(&args)
171        } else {
172            case_name_with_counter(name, counter, n_all)
173        };
174
175        let mut arg_list = quote! {};
176        for e in args {
177            arg_list.extend(quote! { #e, });
178        }
179        test_functions.extend(quote! {
180            #[test]
181            fn #name() {
182                #p_test_fn_name(#arg_list);
183            }
184        })
185    }
186
187    output.extend(quote! {
188        #[cfg(test)]
189        mod #p_test_fn_name {
190            use super::*;
191            #test_functions
192        }
193    });
194
195    output.into()
196}