1extern crate proc_macro;
18
19use proc_macro::TokenStream;
20use quote::{quote, ToTokens};
21use syn::{
22 parse::Parse, parse_macro_input, punctuated::Punctuated, Expr, ExprCall, ExprLit, ExprPath,
23 ItemFn, Lit, LitStr, ReturnType, Signature, Token, Type,
24};
25
26struct Input {
28 args: Punctuated<Expr, Token![,]>,
30 name: Option<LitStr>,
32}
33
34impl Parse for Input {
35 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
36 if input.is_empty() {
37 Ok(Input {
38 args: Default::default(),
39 name: None,
40 })
41 } else {
42 let args: Punctuated<Expr, Token![,]> =
43 Punctuated::parse_separated_nonempty_with(input, Expr::parse)?;
44
45 let name = if input.parse::<Token![;]>().is_ok() {
46 input.parse::<LitStr>().ok()
47 } else {
48 None
49 };
50
51 Ok(Input { args, name })
52 }
53 }
54}
55
56fn generate_test_name(org_func_name: &str, input: &Input) -> String {
59 let func_name = org_func_name.to_string();
60
61 if input.args.is_empty() {
62 return func_name.to_string();
63 }
64
65 let stringified_args = match &input.name {
66 Some(name_argument) => name_argument.value(),
67 _ => input
68 .args
69 .iter()
70 .filter_map(|expr| match expr {
71 Expr::Lit(ExprLit { lit, .. }) => match lit {
72 Lit::Str(lit_str) => Some(lit_str.value()),
73 other_literal => Some(quote!(#other_literal).to_string()),
74 },
75 expr @ Expr::Path(_) | expr @ Expr::Call(_) => extract_and_stringify_option(expr),
76 other_expr => Some(quote!(#other_expr).to_string()),
77 })
78 .map(|s| {
79 s.replace("+=", "_add_")
80 .replace("+", "_add_")
81 .replace("-=", "_sub_")
82 .replace("-", "_sub_")
83 .replace("/=", "_div_")
84 .replace("/", "_div_")
85 .replace("*=", "_mul_")
86 .replace("*", "_mul_")
87 .replace("%=", "_mod_")
88 .replace("%", "_mod_")
89 .replace("==", "_eq_")
90 .replace("!=", "_nq_")
91 .replace("&&", "_and_")
92 .replace("||", "_or_")
93 .replace("!", "not_")
94 .replace("&=", "_and_")
95 .replace("&", "_and_")
96 .replace("|=", "_or_")
97 .replace("|", "_or_")
98 .replace("^=", "_xor_")
99 .replace("^", "_xor_")
100 .replace("<<=", "_lshift_")
101 .replace("<<", "_lshift_")
102 .replace("<=", "_le_")
103 .replace("<", "_lt_")
104 .replace(">>=", "_rshift_")
105 .replace(">>", "_rshift_")
106 .replace(">=", "_ge_")
107 .replace(">", "_gt_")
108 .replace("&mut ", "")
109 .replace("*mut ", "")
110 .replace("&", "")
111 .replace("*", "")
112 .replace(" :: ", "_")
113 .replace("\\", "")
114 .replace("/", "")
115 .replace("\"", "")
116 .replace("(", "")
117 .replace(")", "")
118 .replace("{", "")
119 .replace("}", "")
120 .replace("[", "")
121 .replace("]", "")
122 .replace(" ", "")
123 .replace(",", "_")
124 .replace(".", "_")
125 .to_lowercase()
126 })
127 .collect::<Vec<_>>()
128 .join("_"),
129 };
130
131 format!("{func_name}::{stringified_args}")
132}
133
/// Error crate behind a test function's return type, as classified by
/// `inspect_error_crate`.
#[derive(PartialEq, Eq, Debug)]
enum ErrorCrate {
    /// An `eyre::Result`; the `test` macro returns the function's result as-is.
    Eyre,
    /// Any other result type; the `test` macro converts its error via `eyre!`.
    AnythingElse,
}
139
140fn inspect_error_crate(sig: &Signature) -> ErrorCrate {
153 match &sig.output {
154 ReturnType::Default => panic!("return type needs to be other than ()"),
155 ReturnType::Type(_, ty) => {
156 let Type::Path(type_path) = ty.as_ref() else {
157 panic!("failed to get return type path");
158 };
159
160 let path = &type_path.path;
161 match (path.segments.first(), path.segments.last()) {
162 (Some(first), Some(last)) => {
163 if first.ident == "eyre" && last.ident == "Result" {
164 ErrorCrate::Eyre
165 } else {
166 ErrorCrate::AnythingElse
167 }
168 }
169 _ => {
170 panic!("unexpected return type");
171 }
172 }
173 }
174 }
175}
176
177#[allow(dead_code)]
178fn get_expr_variant_name(expr: &Expr) -> &'static str {
180 match expr {
181 Expr::Array(_) => "Array",
182 Expr::Assign(_) => "Assign",
183 Expr::Async(_) => "Async",
184 Expr::Await(_) => "Await",
185 Expr::Binary(_) => "Binary",
186 Expr::Block(_) => "Block",
187 Expr::Break(_) => "Break",
188 Expr::Call(_) => "Call",
189 Expr::Cast(_) => "Cast",
190 Expr::Closure(_) => "Closure",
191 Expr::Continue(_) => "Continue",
192 Expr::Field(_) => "Field",
193 Expr::ForLoop(_) => "ForLoop",
194 Expr::Group(_) => "Group",
195 Expr::If(_) => "If",
196 Expr::Index(_) => "Index",
197 Expr::Let(_) => "Let",
198 Expr::Lit(_) => "Lit",
199 Expr::Loop(_) => "Loop",
200 Expr::Macro(_) => "Macro",
201 Expr::Match(_) => "Match",
202 Expr::MethodCall(_) => "MethodCall",
203 Expr::Paren(_) => "Paren",
204 Expr::Path(_) => "Path",
205 Expr::Range(_) => "Range",
206 Expr::Reference(_) => "Reference",
207 Expr::Repeat(_) => "Repeat",
208 Expr::Return(_) => "Return",
209 Expr::Struct(_) => "Struct",
210 Expr::Try(_) => "Try",
211 Expr::TryBlock(_) => "TryBlock",
212 Expr::Tuple(_) => "Tuple",
213 Expr::Unary(_) => "Unary",
214 Expr::Unsafe(_) => "Unsafe",
215 Expr::Verbatim(_) => "Verbatim",
216 Expr::While(_) => "While",
217 Expr::Yield(_) => "Yield",
218 _ => "Unknown",
219 }
220}
221
222fn extract_and_stringify_option(expr: &Expr) -> Option<String> {
223 match expr {
224 Expr::Call(ExprCall { func, args, .. }) => {
225 if let Expr::Path(ExprPath { path, .. }) = &**func {
226 let segment = path.segments.last()?;
227 if segment.ident == "Some" {
228 match args.first()? {
229 Expr::Lit(ExprLit { lit, .. }) => match lit {
230 Lit::Str(lit_str) => {
231 return Some(lit_str.value());
232 }
233 other_type_of_literal => {
234 return Some(other_type_of_literal.to_token_stream().to_string());
235 }
236 },
237 first_arg => {
238 return Some(quote!(#first_arg).to_string());
239 }
240 }
241 }
242 }
243 }
244 Expr::Path(ExprPath { path, .. }) => {
245 if path.get_ident()? == "None" {
246 return Some("None".into());
247 }
248 }
249 _ => {}
250 }
251
252 None
253}
254
255#[proc_macro_attribute]
295pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
296 let input_args = parse_macro_input!(args as Input);
297 let input_fn = parse_macro_input!(input as ItemFn);
298
299 let func_name_inner = &input_fn.sig.ident;
300 let test_name_str = generate_test_name(&func_name_inner.to_string(), &input_args);
301
302 let args = input_args.args.to_token_stream();
303
304 let error_crate = inspect_error_crate(&input_fn.sig);
314 let output = if error_crate == ErrorCrate::Eyre {
315 quote! {
316 #input_fn
317
318 ::tanu::inventory::submit! {
320 ::tanu::TestRegistration {
321 module: module_path!(),
322 name: #test_name_str,
323 test_fn: || {
324 Box::pin(async move {
325 #func_name_inner(#args).await
326 })
327 },
328 }
329 }
330 }
331 } else {
332 quote! {
333 #input_fn
334
335 ::tanu::inventory::submit! {
337 ::tanu::TestRegistration {
338 module: module_path!(),
339 name: #test_name_str,
340 test_fn: || {
341 Box::pin(async move {
342 #func_name_inner(#args).await.map_err(|e| ::tanu::eyre::eyre!(Box::new(e)))
343 })
344 },
345 }
346 }
347 }
348 };
349
350 output.into()
351}
352
353#[proc_macro_attribute]
391pub fn main(_args: TokenStream, input: TokenStream) -> TokenStream {
392 let main_fn = parse_macro_input!(input as ItemFn);
393
394 let output = quote! {
395 fn run() -> tanu::Runner {
396 let mut runner = tanu::Runner::new();
397
398 for test in ::tanu::inventory::iter::<::tanu::TestRegistration> {
400 runner.add_test(
401 test.name,
402 test.module,
403 std::sync::Arc::new(test.test_fn)
404 );
405 }
406
407 runner
408 }
409
410 #main_fn
411 };
412
413 output.into()
414}
415
// Unit tests for the parsing and naming helpers behind the attribute macros.
#[cfg(test)]
mod test {
    use crate::Input;

    use super::{ErrorCrate, Expr};
    use test_case::test_case;

    // Return-type paths are classified as eyre or "anything else".
    #[test_case("fn foo() -> eyre::Result" => ErrorCrate::Eyre; "eyre")]
    #[test_case("fn foo() -> anyhow::Result" => ErrorCrate::AnythingElse; "anyhow")]
    #[test_case("fn foo() -> miette::Result" => ErrorCrate::AnythingElse; "miette")]
    #[test_case("fn foo() -> Result" => ErrorCrate::AnythingElse; "std_result")]
    fn inspect_error_crate(s: &str) -> ErrorCrate {
        let sig: syn::Signature = syn::parse_str(s).expect("failed to parse function signature");
        super::inspect_error_crate(&sig)
    }

    // `Some(..)` unwraps to its stringified payload (string literals lose their
    // quotes); a bare `None` maps to "None".
    #[test_case("Some(1)" => Some("1".into()); "Some with int")]
    #[test_case("Some(\"test\")" => Some("test".into()); "Some with string")]
    #[test_case("Some(true)" => Some("true".into()); "Some with boolean")]
    #[test_case("Some(1.0)" => Some("1.0".into()); "Some with float")]
    #[test_case("Some(StatusCode::OK)" => Some("StatusCode :: OK".into()); "Some third party type")]
    #[test_case("Some(\"foo\".to_string())" => Some("\"foo\" . to_string ()".into()); "Some expression")]
    #[test_case("None" => Some("None".into()); "None")]
    fn extract_and_stringify_option(s: &str) -> Option<String> {
        let expr: Expr = syn::parse_str(s).expect("failed to parse expression");
        super::extract_and_stringify_option(&expr)
    }

    // Each case parses attribute arguments and checks the name generated for a
    // function called `foo`.
    // NOTE(review): no case below obviously triggers `clippy::erasing_op`;
    // confirm whether this allow is still needed.
    #[allow(clippy::erasing_op)]
    #[test_case("a, b; \"test_name\"" => "foo::test_name"; "with test name")]
    #[test_case("1+1" => "foo::1_add_1"; "with add expression")]
    #[test_case("1+=1" => "foo::1_add_1"; "with add assignment expression")]
    #[test_case("1-1" => "foo::1_sub_1"; "with sub expression")]
    #[test_case("1-=1" => "foo::1_sub_1"; "with sub assignment expression")]
    #[test_case("1/1" => "foo::1_div_1"; "with div expression")]
    #[test_case("1/=1" => "foo::1_div_1"; "with div assignment expression")]
    #[test_case("1*1" => "foo::1_mul_1"; "with mul expression")]
    #[test_case("1*=1" => "foo::1_mul_1"; "with mul assignment expression")]
    #[test_case("1%1" => "foo::1_mod_1"; "with mod expression")]
    #[test_case("1%=1" => "foo::1_mod_1"; "with mod assignment expression")]
    #[test_case("1==1" => "foo::1_eq_1"; "with eq expression")]
    #[test_case("1!=1" => "foo::1_nq_1"; "with neq expression")]
    #[test_case("1<1" => "foo::1_lt_1"; "with lt expression")]
    #[test_case("1>1" => "foo::1_gt_1"; "with gt expression")]
    #[test_case("1<=1" => "foo::1_le_1"; "with le expression")]
    #[test_case("1>=1" => "foo::1_ge_1"; "with ge expression")]
    #[test_case("true&&false" => "foo::true_and_false"; "with and expression")]
    #[test_case("true||false" => "foo::true_or_false"; "with or expression")]
    #[test_case("!true" => "foo::not_true"; "with not expression")]
    #[test_case("1&1" => "foo::1_and_1"; "with bitwise and expression")]
    #[test_case("1&=1" => "foo::1_and_1"; "with bitwise and assignment expression")]
    #[test_case("1|1" => "foo::1_or_1"; "with bitwise or expression")]
    #[test_case("1|=1" => "foo::1_or_1"; "with bitwise or assignment expression")]
    #[test_case("1^1" => "foo::1_xor_1"; "with xor expression")]
    #[test_case("1^=1" => "foo::1_xor_1"; "with xor assignment expression")]
    #[test_case("1<<1" => "foo::1_lshift_1"; "with left shift expression")]
    #[test_case("1<<=1" => "foo::1_lshift_1"; "with left shift assignment expression")]
    #[test_case("1>>1" => "foo::1_rshift_1"; "with right shift expression")]
    #[test_case("1>>=1" => "foo::1_rshift_1"; "with right shift assignment expression")]
    #[test_case("\"bar\".to_string()" => "foo::bar_to_string"; "to_string")]
    #[test_case("1+1*2" => "foo::1_add_1_mul_2"; "with add and mul expression")]
    #[test_case("1*(2+3)" => "foo::1_mul_2_add_3"; "with mul and add expression")]
    #[test_case("1+2-3" => "foo::1_add_2_sub_3"; "with add and sub expression")]
    #[test_case("1/2*3" => "foo::1_div_2_mul_3"; "with div and mul expression")]
    #[test_case("1%2+3" => "foo::1_mod_2_add_3"; "with mod and add expression")]
    #[test_case("1==2&&3!=4" => "foo::1_eq_2_and_3_nq_4"; "with eq and and expression")]
    #[test_case("true||false&&true" => "foo::true_or_false_and_true"; "with or and and expression")]
    #[test_case("!(1+2)" => "foo::not_1_add_2"; "with not and add expression")]
    #[test_case("1&2|3^4" => "foo::1_and_2_or_3_xor_4"; "with bitwise and, or, xor expression")]
    #[test_case("1<<2>>3" => "foo::1_lshift_2_rshift_3"; "with left shift and right shift expression")]
    #[test_case("Some(1+2)" => "foo::1_add_2"; "with Some and add expression")]
    #[test_case("None" => "foo::none"; "with None")]
    #[test_case("[1, 2]" => "foo::1_2"; "with array")]
    #[test_case("vec![1, 2]" => "foo::vecnot_1_2"; "with macro")]
    #[test_case("\"foo\".to_string().len()" => "foo::foo_to_string_len"; "with function call chain")]
    #[test_case("0.5+0.3" => "foo::0_5_add_0_3"; "with floating point add")]
    #[test_case("-10" => "foo::_sub_10"; "with negative number")]
    #[test_case("1.0e10" => "foo::1_0e10"; "with scientific notation")]
    #[test_case("0xff" => "foo::0xff"; "with hex literal")]
    #[test_case("0o777" => "foo::0o777"; "with octal literal")]
    #[test_case("0b1010" => "foo::0b1010"; "with binary literal")]
    #[test_case("\"hello\" + \"world\"" => "foo::hello_add_world"; "with string concatenation")]
    #[test_case("format!(\"{}{}\", 1, 2)" => "foo::formatnot__1_2"; "with format macro")]
    #[test_case("r#\"raw string\"#" => "foo::rawstring"; "with raw string")]
    #[test_case("(1, \"hello\", true)" => "foo::1_hello_true"; "with mixed tuple")]
    #[test_case("vec![1..5]" => "foo::vecnot_1__5"; "with range in macro")]
    #[test_case("x.map(|v| v+1)" => "foo::x_map_or_v_or_v_add_1"; "with closure")]
    #[test_case("a.into()" => "foo::a_into"; "with method call no args")]
    #[test_case("a.parse::<i32>().unwrap()" => "foo::a_parse__lt_i32_gt__unwrap"; "with turbofish syntax")]
    #[test_case("1..10" => "foo::1__10"; "with range")]
    fn generate_test_name(args: &str) -> String {
        let input_args: Input = syn::parse_str(args).expect("failed to parse input args");
        super::generate_test_name("foo", &input_args)
    }
}