1extern crate proc_macro;
18
19use proc_macro::TokenStream;
20use quote::{quote, ToTokens};
21use syn::{
22 parse::Parse, parse_macro_input, punctuated::Punctuated, spanned::Spanned, Expr, ExprCall,
23 ExprLit, ExprPath, Item, ItemFn, ItemMod, Lit, LitStr, ReturnType, Signature, Token, Type,
24};
25
26struct Input {
28 args: Punctuated<Expr, Token![,]>,
30 name: Option<LitStr>,
32 serial_group: Option<String>,
34 ordered: bool,
36}
37
38impl Parse for Input {
39 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
40 if input.is_empty() {
41 return Ok(Input {
42 args: Default::default(),
43 name: None,
44 serial_group: None,
45 ordered: false,
46 });
47 }
48
49 let mut serial_group: Option<String> = None;
50 let mut ordered = false;
51 let mut test_args: Punctuated<Expr, Token![,]> = Punctuated::new();
52
53 loop {
55 if input.peek(Token![;]) || input.is_empty() {
56 break;
57 }
58
59 if input.peek(syn::Ident) {
61 let fork = input.fork();
62 if let Ok(ident) = fork.parse::<syn::Ident>() {
63 if ident == "serial" {
64 input.parse::<syn::Ident>()?;
66
67 let group = if input.peek(Token![=]) {
69 input.parse::<Token![=]>()?;
70 let lit: LitStr = input.parse()?;
71 Some(lit.value())
72 } else {
73 Some(String::new()) };
75
76 serial_group = group;
77
78 if input.peek(Token![,]) {
80 input.parse::<Token![,]>()?;
81 }
82 continue;
83 } else if ident == "ordered" {
84 input.parse::<syn::Ident>()?;
86 ordered = true;
87
88 if input.peek(Token![,]) {
90 input.parse::<Token![,]>()?;
91 }
92 continue;
93 }
94 }
95 }
96
97 let expr = input.parse::<Expr>()?;
99 test_args.push(expr);
100
101 if input.peek(Token![,]) && !input.peek2(Token![;]) {
103 input.parse::<Token![,]>()?;
104 } else if !input.peek(Token![;]) && !input.is_empty() {
105 break;
106 }
107 }
108
109 let name = if input.parse::<Token![;]>().is_ok() {
111 input.parse::<LitStr>().ok()
112 } else {
113 None
114 };
115
116 Ok(Input {
117 args: test_args,
118 name,
119 serial_group,
120 ordered,
121 })
122 }
123}
124
125fn generate_test_name(org_func_name: &str, input: &Input) -> String {
128 let func_name = org_func_name.to_string();
129
130 if input.args.is_empty() {
131 return func_name.to_string();
132 }
133
134 let stringified_args = match &input.name {
135 Some(name_argument) => name_argument.value(),
136 _ => input
137 .args
138 .iter()
139 .filter_map(|expr| match expr {
140 Expr::Lit(ExprLit { lit, .. }) => match lit {
141 Lit::Str(lit_str) => Some(lit_str.value()),
142 other_literal => Some(quote!(#other_literal).to_string()),
143 },
144 expr @ Expr::Path(_) | expr @ Expr::Call(_) => extract_and_stringify_option(expr),
145 other_expr => Some(quote!(#other_expr).to_string()),
146 })
147 .map(|s| {
148 s.replace("+=", "_add_")
149 .replace("+", "_add_")
150 .replace("-=", "_sub_")
151 .replace("-", "_sub_")
152 .replace("/=", "_div_")
153 .replace("/", "_div_")
154 .replace("*=", "_mul_")
155 .replace("*", "_mul_")
156 .replace("%=", "_mod_")
157 .replace("%", "_mod_")
158 .replace("==", "_eq_")
159 .replace("!=", "_nq_")
160 .replace("&&", "_and_")
161 .replace("||", "_or_")
162 .replace("!", "not_")
163 .replace("&=", "_and_")
164 .replace("&", "_and_")
165 .replace("|=", "_or_")
166 .replace("|", "_or_")
167 .replace("^=", "_xor_")
168 .replace("^", "_xor_")
169 .replace("<<=", "_lshift_")
170 .replace("<<", "_lshift_")
171 .replace("<=", "_le_")
172 .replace("<", "_lt_")
173 .replace(">>=", "_rshift_")
174 .replace(">>", "_rshift_")
175 .replace(">=", "_ge_")
176 .replace(">", "_gt_")
177 .replace("&mut ", "")
178 .replace("*mut ", "")
179 .replace("&", "")
180 .replace("*", "")
181 .replace(" :: ", "_")
182 .replace("\\", "")
183 .replace("/", "")
184 .replace("\"", "")
185 .replace("(", "")
186 .replace(")", "")
187 .replace("{", "")
188 .replace("}", "")
189 .replace("[", "")
190 .replace("]", "")
191 .replace(" ", "")
192 .replace(",", "_")
193 .replace(".", "_")
194 .to_lowercase()
195 })
196 .collect::<Vec<_>>()
197 .join("_"),
198 };
199
200 format!("{func_name}::{stringified_args}")
201}
202
/// How the return type of a test function was classified by `inspect_error_crate`.
///
/// The `test` macro uses this to decide whether the awaited result can be
/// registered as-is, or whether its error must first be converted with
/// `::tanu::eyre::eyre!`.
#[derive(PartialEq, Eq, Debug)]
enum ErrorCrate {
    /// Recognised as an eyre `Result`; no error conversion is generated.
    Eyre,
    /// Any other return type; the error is wrapped with `eyre!`.
    AnythingElse,
}
208
209fn inspect_error_crate(sig: &Signature) -> ErrorCrate {
222 match &sig.output {
223 ReturnType::Default => panic!("return type needs to be other than ()"),
224 ReturnType::Type(_, ty) => {
225 let Type::Path(type_path) = ty.as_ref() else {
226 panic!("failed to get return type path");
227 };
228
229 let path = &type_path.path;
230 match (path.segments.first(), path.segments.last()) {
231 (Some(first), Some(last)) => {
232 if first.ident == "eyre" && last.ident == "Result" {
233 ErrorCrate::Eyre
234 } else {
235 ErrorCrate::AnythingElse
236 }
237 }
238 _ => {
239 panic!("unexpected return type");
240 }
241 }
242 }
243 }
244}
245
246#[allow(dead_code)]
247fn get_expr_variant_name(expr: &Expr) -> &'static str {
249 match expr {
250 Expr::Array(_) => "Array",
251 Expr::Assign(_) => "Assign",
252 Expr::Async(_) => "Async",
253 Expr::Await(_) => "Await",
254 Expr::Binary(_) => "Binary",
255 Expr::Block(_) => "Block",
256 Expr::Break(_) => "Break",
257 Expr::Call(_) => "Call",
258 Expr::Cast(_) => "Cast",
259 Expr::Closure(_) => "Closure",
260 Expr::Continue(_) => "Continue",
261 Expr::Field(_) => "Field",
262 Expr::ForLoop(_) => "ForLoop",
263 Expr::Group(_) => "Group",
264 Expr::If(_) => "If",
265 Expr::Index(_) => "Index",
266 Expr::Let(_) => "Let",
267 Expr::Lit(_) => "Lit",
268 Expr::Loop(_) => "Loop",
269 Expr::Macro(_) => "Macro",
270 Expr::Match(_) => "Match",
271 Expr::MethodCall(_) => "MethodCall",
272 Expr::Paren(_) => "Paren",
273 Expr::Path(_) => "Path",
274 Expr::Range(_) => "Range",
275 Expr::Reference(_) => "Reference",
276 Expr::Repeat(_) => "Repeat",
277 Expr::Return(_) => "Return",
278 Expr::Struct(_) => "Struct",
279 Expr::Try(_) => "Try",
280 Expr::TryBlock(_) => "TryBlock",
281 Expr::Tuple(_) => "Tuple",
282 Expr::Unary(_) => "Unary",
283 Expr::Unsafe(_) => "Unsafe",
284 Expr::Verbatim(_) => "Verbatim",
285 Expr::While(_) => "While",
286 Expr::Yield(_) => "Yield",
287 _ => "Unknown",
288 }
289}
290
291fn extract_and_stringify_option(expr: &Expr) -> Option<String> {
292 match expr {
293 Expr::Call(ExprCall { func, args, .. }) => {
294 if let Expr::Path(ExprPath { path, .. }) = &**func {
295 let segment = path.segments.last()?;
296 if segment.ident == "Some" {
297 match args.first()? {
298 Expr::Lit(ExprLit { lit, .. }) => match lit {
299 Lit::Str(lit_str) => {
300 return Some(lit_str.value());
301 }
302 other_type_of_literal => {
303 return Some(other_type_of_literal.to_token_stream().to_string());
304 }
305 },
306 first_arg => {
307 return Some(quote!(#first_arg).to_string());
308 }
309 }
310 }
311 }
312 }
313 Expr::Path(ExprPath { path, .. }) if path.get_ident()? == "None" => {
314 return Some("None".into());
315 }
316 _ => {}
317 }
318
319 None
320}
321
322fn handle_ordered_module(mut module: ItemMod) -> TokenStream {
325 if let Some((_, items)) = &mut module.content {
327 for item in items.iter_mut() {
328 if let Item::Fn(func) = item {
329 let has_tanu_test = func.attrs.iter().any(|attr| {
331 if let Some(segment) = attr.path().segments.first() {
332 segment.ident == "tanu"
333 } else {
334 false
335 }
336 });
337
338 if has_tanu_test {
339 for attr in func.attrs.iter_mut() {
341 if let Some(segment) = attr.path().segments.first() {
342 if segment.ident == "tanu" {
343 let attr_span = attr.span();
345 let tokens = attr.meta.require_list().ok().map(|list| {
347 let tokens = &list.tokens;
348 tokens.clone()
349 });
350
351 let new_tokens = if let Some(existing) = tokens {
353 quote::quote_spanned! { attr_span => ordered, #existing }
354 } else {
355 quote::quote_spanned! { attr_span => ordered }
356 };
357
358 *attr = syn::parse_quote_spanned! { attr_span =>
360 #[tanu::test(#new_tokens)]
361 };
362 }
363 }
364 }
365 }
366 }
367 }
368 }
369
370 quote! { #module }.into()
371}
372
373#[proc_macro_attribute]
413pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
414 let input_args = parse_macro_input!(args as Input);
415
416 if let Ok(module) = syn::parse::<ItemMod>(input.clone()) {
418 if input_args.ordered {
419 return handle_ordered_module(module);
420 }
421 return syn::Error::new_spanned(
423 module,
424 "#[tanu::test] on modules requires 'ordered' parameter. Use #[tanu::test(ordered)]",
425 )
426 .to_compile_error()
427 .into();
428 }
429
430 let input_fn = parse_macro_input!(input as ItemFn);
432
433 let func_name_inner = &input_fn.sig.ident;
434 let test_name_str = generate_test_name(&func_name_inner.to_string(), &input_args);
435
436 let args = input_args.args.to_token_stream();
437
438 let serial_group_tokens = if input_args.ordered {
441 quote! { Some(module_path!()) }
442 } else {
443 match &input_args.serial_group {
444 None => quote! { None },
445 Some(s) if s.is_empty() => quote! { Some("") },
446 Some(s) => quote! { Some(#s) },
447 }
448 };
449
450 let ordered = input_args.ordered;
451
452 let error_crate = inspect_error_crate(&input_fn.sig);
462 let output = if error_crate == ErrorCrate::Eyre {
463 quote! {
464 #input_fn
465
466 ::tanu::inventory::submit! {
468 ::tanu::TestRegistration {
469 module: module_path!(),
470 name: #test_name_str,
471 serial_group: #serial_group_tokens,
472 line: line!(),
473 ordered: #ordered,
474 test_fn: || {
475 Box::pin(async move {
476 #func_name_inner(#args).await
477 })
478 },
479 }
480 }
481 }
482 } else {
483 quote! {
484 #input_fn
485
486 ::tanu::inventory::submit! {
488 ::tanu::TestRegistration {
489 module: module_path!(),
490 name: #test_name_str,
491 serial_group: #serial_group_tokens,
492 line: line!(),
493 ordered: #ordered,
494 test_fn: || {
495 Box::pin(async move {
496 #func_name_inner(#args).await.map_err(|e| ::tanu::eyre::eyre!(Box::new(e)))
497 })
498 },
499 }
500 }
501 }
502 };
503
504 output.into()
505}
506
507#[proc_macro_attribute]
545pub fn main(_args: TokenStream, input: TokenStream) -> TokenStream {
546 let main_fn = parse_macro_input!(input as ItemFn);
547
548 let output = quote! {
549 fn run() -> tanu::Runner {
550 let mut runner = tanu::Runner::new();
551
552 for test in ::tanu::inventory::iter::<::tanu::TestRegistration> {
554 runner.add_test(
555 test.name,
556 test.module,
557 test.serial_group,
558 test.line,
559 test.ordered,
560 std::sync::Arc::new(test.test_fn)
561 );
562 }
563
564 runner
565 }
566
567 #main_fn
568 };
569
570 output.into()
571}
572
#[cfg(test)]
mod test {
    use super::{ErrorCrate, Expr, Input};
    use test_case::test_case;

    #[test_case("fn foo() -> eyre::Result" => ErrorCrate::Eyre; "eyre")]
    #[test_case("fn foo() -> anyhow::Result" => ErrorCrate::AnythingElse; "anyhow")]
    #[test_case("fn foo() -> miette::Result" => ErrorCrate::AnythingElse; "miette")]
    #[test_case("fn foo() -> Result" => ErrorCrate::AnythingElse; "std_result")]
    fn inspect_error_crate(s: &str) -> ErrorCrate {
        let sig: syn::Signature = syn::parse_str(s).expect("failed to parse function signature");
        super::inspect_error_crate(&sig)
    }

    #[test_case("Some(1)" => Some("1".into()); "Some with int")]
    #[test_case("Some(\"test\")" => Some("test".into()); "Some with string")]
    #[test_case("Some(true)" => Some("true".into()); "Some with boolean")]
    #[test_case("Some(1.0)" => Some("1.0".into()); "Some with float")]
    #[test_case("Some(StatusCode::OK)" => Some("StatusCode :: OK".into()); "Some third party type")]
    #[test_case("Some(\"foo\".to_string())" => Some("\"foo\" . to_string ()".into()); "Some expression")]
    #[test_case("None" => Some("None".into()); "None")]
    #[test_case("Option::Some(2)" => Some("2".into()); "qualified Some")]
    #[test_case("StatusCode::OK" => None; "non option path")]
    #[test_case("foo(1)" => None; "non Some call")]
    fn extract_and_stringify_option(s: &str) -> Option<String> {
        let expr: Expr = syn::parse_str(s).expect("failed to parse expression");
        super::extract_and_stringify_option(&expr)
    }

    #[allow(clippy::erasing_op)]
    #[test_case("a, b; \"test_name\"" => "foo::test_name"; "with test name")]
    #[test_case("1+1" => "foo::1_add_1"; "with add expression")]
    #[test_case("1+=1" => "foo::1_add_1"; "with add assignment expression")]
    #[test_case("1-1" => "foo::1_sub_1"; "with sub expression")]
    #[test_case("1-=1" => "foo::1_sub_1"; "with sub assignment expression")]
    #[test_case("1/1" => "foo::1_div_1"; "with div expression")]
    #[test_case("1/=1" => "foo::1_div_1"; "with div assignment expression")]
    #[test_case("1*1" => "foo::1_mul_1"; "with mul expression")]
    #[test_case("1*=1" => "foo::1_mul_1"; "with mul assignment expression")]
    #[test_case("1%1" => "foo::1_mod_1"; "with mod expression")]
    #[test_case("1%=1" => "foo::1_mod_1"; "with mod assignment expression")]
    #[test_case("1==1" => "foo::1_eq_1"; "with eq expression")]
    #[test_case("1!=1" => "foo::1_nq_1"; "with neq expression")]
    #[test_case("1<1" => "foo::1_lt_1"; "with lt expression")]
    #[test_case("1>1" => "foo::1_gt_1"; "with gt expression")]
    #[test_case("1<=1" => "foo::1_le_1"; "with le expression")]
    #[test_case("1>=1" => "foo::1_ge_1"; "with ge expression")]
    #[test_case("true&&false" => "foo::true_and_false"; "with and expression")]
    #[test_case("true||false" => "foo::true_or_false"; "with or expression")]
    #[test_case("!true" => "foo::not_true"; "with not expression")]
    #[test_case("1&1" => "foo::1_and_1"; "with bitwise and expression")]
    #[test_case("1&=1" => "foo::1_and_1"; "with bitwise and assignment expression")]
    #[test_case("1|1" => "foo::1_or_1"; "with bitwise or expression")]
    #[test_case("1|=1" => "foo::1_or_1"; "with bitwise or assignment expression")]
    #[test_case("1^1" => "foo::1_xor_1"; "with xor expression")]
    #[test_case("1^=1" => "foo::1_xor_1"; "with xor assignment expression")]
    #[test_case("1<<1" => "foo::1_lshift_1"; "with left shift expression")]
    #[test_case("1<<=1" => "foo::1_lshift_1"; "with left shift assignment expression")]
    #[test_case("1>>1" => "foo::1_rshift_1"; "with right shift expression")]
    #[test_case("1>>=1" => "foo::1_rshift_1"; "with right shift assignment expression")]
    #[test_case("\"bar\".to_string()" => "foo::bar_to_string"; "to_string")]
    #[test_case("1+1*2" => "foo::1_add_1_mul_2"; "with add and mul expression")]
    #[test_case("1*(2+3)" => "foo::1_mul_2_add_3"; "with mul and add expression")]
    #[test_case("1+2-3" => "foo::1_add_2_sub_3"; "with add and sub expression")]
    #[test_case("1/2*3" => "foo::1_div_2_mul_3"; "with div and mul expression")]
    #[test_case("1%2+3" => "foo::1_mod_2_add_3"; "with mod and add expression")]
    #[test_case("1==2&&3!=4" => "foo::1_eq_2_and_3_nq_4"; "with eq and and expression")]
    #[test_case("true||false&&true" => "foo::true_or_false_and_true"; "with or and and expression")]
    #[test_case("!(1+2)" => "foo::not_1_add_2"; "with not and add expression")]
    #[test_case("1&2|3^4" => "foo::1_and_2_or_3_xor_4"; "with bitwise and, or, xor expression")]
    #[test_case("1<<2>>3" => "foo::1_lshift_2_rshift_3"; "with left shift and right shift expression")]
    #[test_case("Some(1+2)" => "foo::1_add_2"; "with Some and add expression")]
    #[test_case("None" => "foo::none"; "with None")]
    #[test_case("[1, 2]" => "foo::1_2"; "with array")]
    #[test_case("vec![1, 2]" => "foo::vecnot_1_2"; "with macro")]
    #[test_case("\"foo\".to_string().len()" => "foo::foo_to_string_len"; "with function call chain")]
    #[test_case("0.5+0.3" => "foo::0_5_add_0_3"; "with floating point add")]
    #[test_case("-10" => "foo::_sub_10"; "with negative number")]
    #[test_case("1.0e10" => "foo::1_0e10"; "with scientific notation")]
    #[test_case("0xff" => "foo::0xff"; "with hex literal")]
    #[test_case("0o777" => "foo::0o777"; "with octal literal")]
    #[test_case("0b1010" => "foo::0b1010"; "with binary literal")]
    #[test_case("\"hello\" + \"world\"" => "foo::hello_add_world"; "with string concatenation")]
    #[test_case("format!(\"{}{}\", 1, 2)" => "foo::formatnot__1_2"; "with format macro")]
    #[test_case("r#\"raw string\"#" => "foo::rawstring"; "with raw string")]
    #[test_case("(1, \"hello\", true)" => "foo::1_hello_true"; "with mixed tuple")]
    #[test_case("vec![1..5]" => "foo::vecnot_1__5"; "with range in macro")]
    #[test_case("x.map(|v| v+1)" => "foo::x_map_or_v_or_v_add_1"; "with closure")]
    #[test_case("a.into()" => "foo::a_into"; "with method call no args")]
    #[test_case("a.parse::<i32>().unwrap()" => "foo::a_parse__lt_i32_gt__unwrap"; "with turbofish syntax")]
    #[test_case("1..10" => "foo::1__10"; "with range")]
    fn generate_test_name(args: &str) -> String {
        let input_args: Input = syn::parse_str(args).expect("failed to parse input args");
        super::generate_test_name("foo", &input_args)
    }

    // Keyword handling of the argument parser: (serial_group, ordered, number of
    // positional args).
    #[test_case("" => (None, false, 0); "empty")]
    #[test_case("serial" => (Some(String::new()), false, 0); "bare serial")]
    #[test_case("serial = \"db\"" => (Some("db".to_string()), false, 0); "named serial group")]
    #[test_case("ordered" => (None, true, 0); "ordered only")]
    #[test_case("serial = \"db\", 1, 2" => (Some("db".to_string()), false, 2); "serial with args")]
    #[test_case("ordered, 1; \"name\"" => (None, true, 1); "ordered with args and name")]
    fn parse_input(args: &str) -> (Option<String>, bool, usize) {
        let input_args: Input = syn::parse_str(args).expect("failed to parse input args");
        (input_args.serial_group, input_args.ordered, input_args.args.len())
    }
}