1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{
6 parenthesized,
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10 Expr, Ident, ItemFn, LitBool, LitStr, Result, Token,
11};
12
13#[derive(PartialEq)]
16enum Name {
17 Some(Ident),
18 None,
19}
20
21impl Parse for Name {
22 fn parse(input: ParseStream) -> Result<Self> {
23 if input.peek(Ident) {
24 let name = Name::Some(input.parse()?);
25 let _ = input.parse::<Token![,]>()?;
26 Ok(name)
27 } else if input.peek(LitStr) {
28 let name = input.parse::<LitStr>()?;
29 let _ = input.parse::<Token![,]>()?;
30 if name.value().is_empty() {
31 Ok(Name::None)
32 } else {
33 Ok(Name::Some(Ident::new(&slugify(&name.value()), name.span())))
34 }
35 } else {
36 Ok(Name::None)
37 }
38 }
39}
40
41struct Input {
46 use_args_for_case_name: bool,
47 test_cases: Vec<TestCase>,
48}
49
50impl Parse for Input {
51 fn parse(input: ParseStream) -> Result<Self> {
52 let use_args_for_case_name =
53 if input.peek(Ident) && input.peek2(Token![=]) && input.peek3(LitBool) {
54 let option = input.parse::<Ident>()?;
55 if option != "use_args_for_case_name" {
56 return Err(syn::Error::new(
57 option.span(),
58 "Expected 'use_args_for_case_name' option",
59 ));
60 }
61 let _ = input.parse::<Token![=]>()?;
62 let use_args_for_case_name = input.parse::<LitBool>()?.value;
63 let _ = input.parse::<Token![,]>()?;
64 use_args_for_case_name
65 } else {
66 false
67 };
68 let test_cases = Punctuated::<TestCase, Token![,]>::parse_terminated(input)?
69 .into_iter()
70 .collect();
71 Ok(Input {
72 use_args_for_case_name,
73 test_cases,
74 })
75 }
76}
77
78struct TestCase {
83 name: Name,
84 args: Vec<Expr>,
85}
86
87impl Parse for TestCase {
88 fn parse(input: ParseStream) -> Result<Self> {
89 let name = input.parse::<Name>()?;
90
91 let content;
92 let _ = parenthesized!(content in input);
93
94 let args: Vec<Expr> = Punctuated::<Expr, Token![,]>::parse_terminated(&content)?
95 .into_iter()
96 .collect();
97 Ok(TestCase { name, args })
98 }
99}
100
101fn case_name_with_counter(name: Name, counter: i32, n_all: usize) -> Ident {
102 match name {
103 Name::Some(name) => name,
104 Name::None => {
105 let name = if n_all < 10 {
106 format!("case_{counter}")
107 } else if n_all < 100 {
108 format!("case_{counter:02}")
109 } else if n_all < 1000 {
110 format!("case_{counter:03}")
111 } else {
112 format!("case_{counter}")
113 };
114 Ident::new(&name, proc_macro::Span::call_site().into())
115 }
116 }
117}
118
119fn case_name_with_args(args: &[Expr]) -> Ident {
120 let name = args
121 .iter()
122 .map(|e| slugify(&e.to_token_stream().to_string()))
123 .collect::<Vec<_>>()
124 .join("_");
125
126 if name.is_empty() {
127 Ident::new("case", proc_macro::Span::call_site().into())
128 } else {
129 Ident::new(&name, proc_macro::Span::call_site().into())
130 }
131}
132
/// Converts arbitrary text into identifier-shaped text.
///
/// ASCII letters are lowercased and every non-alphanumeric character becomes `_`. If the
/// result starts with a numeric character, an extra `_` is prepended so it does not begin
/// with a digit.
fn slugify(name: &str) -> String {
    let mut slug = String::with_capacity(name.len() + 1);
    for ch in name.chars() {
        let lowered = ch.to_ascii_lowercase();
        slug.push(if lowered.is_alphanumeric() { lowered } else { '_' });
    }

    match slug.chars().next() {
        Some(first) if first.is_numeric() => slug.insert(0, '_'),
        _ => {}
    }
    slug
}
146
147#[proc_macro_attribute]
149pub fn p_test(attr: TokenStream, item: TokenStream) -> TokenStream {
150 let attr_input = parse_macro_input!(attr as Input);
151
152 let item = parse_macro_input!(item as ItemFn);
153 let p_test_fn_sig = &item.sig;
154 let p_test_fn_name = &item.sig.ident;
155 let p_test_fn_block = &item.block;
156
157 let mut output = quote! {
158 #p_test_fn_sig {
159 #p_test_fn_block
160 }
161 };
162
163 let mut test_functions = quote! {};
164
165 let mut counter = 0;
166 let n_all = attr_input.test_cases.len();
167 for TestCase { name, args } in attr_input.test_cases {
168 counter += 1;
169 let name = if name == Name::None && attr_input.use_args_for_case_name && !args.is_empty() {
170 case_name_with_args(&args)
171 } else {
172 case_name_with_counter(name, counter, n_all)
173 };
174
175 let mut arg_list = quote! {};
176 for e in args {
177 arg_list.extend(quote! { #e, });
178 }
179 test_functions.extend(quote! {
180 #[test]
181 fn #name() {
182 #p_test_fn_name(#arg_list);
183 }
184 })
185 }
186
187 output.extend(quote! {
188 #[cfg(test)]
189 mod #p_test_fn_name {
190 use super::*;
191 #test_functions
192 }
193 });
194
195 output.into()
196}