array_parameterized_test/
lib.rs

1use proc_macro2::{Group, TokenStream};
2use quote::{ToTokens, format_ident, quote};
3use syn::{Expr, Ident, ItemConst, ItemFn, Token, parse, parse_macro_input};
4
5/// Use as an attribute to annotate a parameterized test function that accepts a single parameter. The attribute must be given the name of a const array that is annotated with the `test_parameter` attribute. That const array will be used to supply parameter values to the test function.
6#[proc_macro_attribute]
7pub fn parameterized_test(
8    input: proc_macro::TokenStream,
9    annotated_item: proc_macro::TokenStream,
10) -> proc_macro::TokenStream {
11    let macro_name = parse_macro_input!(input as Ident);
12    let test = parse_macro_input!(annotated_item as ItemFn);
13
14    let inner_func_name = test.sig.ident.clone();
15
16    quote! {
17        mod #inner_func_name {
18            use super::*;
19
20            #test
21
22            #macro_name!{#inner_func_name}
23        }
24    }
25    .into()
26}
27
28/// Use as an attribute to annotate a const array that will be used to supply the input values for a parameterized test.
29///
30/// # Panics
31///
32/// Panics if the annotated item is not a const array expression. Will also panic if the array contains a path expression that is somehow empty.
33#[expect(clippy::expect_used, clippy::panic)]
34#[proc_macro_attribute]
35pub fn test_parameter(
36    _input: proc_macro::TokenStream,
37    annotated_item: proc_macro::TokenStream,
38) -> proc_macro::TokenStream {
39    let cloned_item_tokens = annotated_item.clone();
40    let item = parse_macro_input!(cloned_item_tokens as ItemConst);
41
42    let Expr::Array(array) = item.expr.as_ref() else {
43        panic!("Expected expression to be an array");
44    };
45
46    let values: Vec<_> = array
47        .elems
48        .iter()
49        .map(|n| {
50            if let Expr::Path(path) = n {
51                path.path
52                    .segments
53                    .last()
54                    .expect("path expressions in the const array to have at least one segment")
55                    .ident
56                    .to_token_stream()
57            } else {
58                n.to_token_stream()
59            }
60        })
61        .collect();
62
63    let annotated_item = TokenStream::from(annotated_item);
64
65    let const_item_name = item.ident;
66    let macro_name = format_ident!("{}_macro", &const_item_name);
67
68    let macro_output = quote! {
69        macro_rules! #macro_name {
70            ( $inner_test_name:ident ) => {
71                array_parameterized_test::generate_tests!{
72                    $inner_test_name,
73                    #const_item_name,
74                    [#(#values),*]
75                }
76            };
77        }
78
79        #[allow(unused_imports)]
80        pub(crate) use #macro_name as #const_item_name;
81
82        #annotated_item
83    };
84
85    macro_output.into()
86}
87
88struct GenerateTestsInput {
89    inner_test_name: Ident,
90    const_item_name: Ident,
91    const_item_values: Group,
92}
93
94impl syn::parse::Parse for GenerateTestsInput {
95    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
96        let inner_test_name = input.parse()?;
97        let _: Token![,] = input.parse()?;
98        let const_item_name = input.parse()?;
99        let _: Token![,] = input.parse()?;
100        let const_item_values = input.parse()?;
101
102        Ok(GenerateTestsInput {
103            inner_test_name,
104            const_item_name,
105            const_item_values,
106        })
107    }
108}
109
110/// A macro used to generate multiple test functions given a parameterized test function name, a const array identifier and the array's values.
111#[proc_macro]
112pub fn generate_tests(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
113    let GenerateTestsInput {
114        inner_test_name,
115        const_item_name,
116        const_item_values,
117    } = parse_macro_input!(input as GenerateTestsInput);
118
119    let tokens: proc_macro2::TokenStream = const_item_values
120        .stream()
121        .into_iter()
122        .step_by(2)
123        .enumerate()
124        .flat_map(|(i, value)| {
125            let suffix = value
126                .to_string()
127                .escape_default()
128                .collect::<String>()
129                .replace(|c: char| !c.is_ascii_alphanumeric(), "_");
130
131            let test_name = format_ident!("_{suffix}");
132
133            quote! {
134                #[test]
135                #[allow(non_snake_case)]
136                fn #test_name() {
137                    #inner_test_name(#const_item_name[#i]);
138                }
139            }
140        })
141        .collect();
142
143    tokens.into()
144}