tanu_derive/
lib.rs

1//! # Tanu Derive
2//!
3//! Procedural macros for the tanu WebAPI testing framework.
4//!
5//! This crate provides the `#[tanu::test]` and `#[tanu::main]` procedural macros
6//! that enable the core functionality of tanu's test discovery and execution system.
7//!
8//! ## Macros
9//!
10//! - `#[tanu::test]` - Marks async functions as tanu test cases
11//! - `#[tanu::test(param)]` - Creates parameterized test cases  
12//! - `#[tanu::main]` - Generates the main function for test discovery
13//!
14//! These macros are automatically re-exported by the main `tanu` crate,
15//! so users typically don't need to import this crate directly.
16
17extern crate proc_macro;
18
19use proc_macro::TokenStream;
20use quote::{quote, ToTokens};
21use syn::{
22    parse::Parse, parse_macro_input, punctuated::Punctuated, Expr, ExprCall, ExprLit, ExprPath,
23    ItemFn, Lit, LitStr, ReturnType, Signature, Token, Type,
24};
25
26/// Represents arguments in the test attribute #[test(a, b; c)].
27struct Input {
28    /// Test arguments specified in the test attribute.
29    args: Punctuated<Expr, Token![,]>,
30    /// Test name specified in the test attribute.
31    name: Option<LitStr>,
32}
33
34impl Parse for Input {
35    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
36        if input.is_empty() {
37            Ok(Input {
38                args: Default::default(),
39                name: None,
40            })
41        } else {
42            let args: Punctuated<Expr, Token![,]> =
43                Punctuated::parse_separated_nonempty_with(input, Expr::parse)?;
44
45            let name = if input.parse::<Token![;]>().is_ok() {
46                input.parse::<LitStr>().ok()
47            } else {
48                None
49            };
50
51            Ok(Input { args, name })
52        }
53    }
54}
55
56/// - If a test name argument is provided (e.g., `#[test(a; xxx)]`), use it as the function name.
57/// - Otherwise, generate a function name by concatenating the test parameters with `_`.
58fn generate_test_name(org_func_name: &str, input: &Input) -> String {
59    let func_name = org_func_name.to_string();
60
61    if input.args.is_empty() {
62        return func_name.to_string();
63    }
64
65    let stringified_args = match &input.name {
66        Some(name_argument) => name_argument.value(),
67        _ => input
68            .args
69            .iter()
70            .filter_map(|expr| match expr {
71                Expr::Lit(ExprLit { lit, .. }) => match lit {
72                    Lit::Str(lit_str) => Some(lit_str.value()),
73                    other_literal => Some(quote!(#other_literal).to_string()),
74                },
75                expr @ Expr::Path(_) | expr @ Expr::Call(_) => extract_and_stringify_option(expr),
76                other_expr => Some(quote!(#other_expr).to_string()),
77            })
78            .map(|s| {
79                s.replace("+=", "_add_")
80                    .replace("+", "_add_")
81                    .replace("-=", "_sub_")
82                    .replace("-", "_sub_")
83                    .replace("/=", "_div_")
84                    .replace("/", "_div_")
85                    .replace("*=", "_mul_")
86                    .replace("*", "_mul_")
87                    .replace("%=", "_mod_")
88                    .replace("%", "_mod_")
89                    .replace("==", "_eq_")
90                    .replace("!=", "_nq_")
91                    .replace("&&", "_and_")
92                    .replace("||", "_or_")
93                    .replace("!", "not_")
94                    .replace("&=", "_and_")
95                    .replace("&", "_and_")
96                    .replace("|=", "_or_")
97                    .replace("|", "_or_")
98                    .replace("^=", "_xor_")
99                    .replace("^", "_xor_")
100                    .replace("<<=", "_lshift_")
101                    .replace("<<", "_lshift_")
102                    .replace("<=", "_le_")
103                    .replace("<", "_lt_")
104                    .replace(">>=", "_rshift_")
105                    .replace(">>", "_rshift_")
106                    .replace(">=", "_ge_")
107                    .replace(">", "_gt_")
108                    .replace("&mut ", "")
109                    .replace("*mut ", "")
110                    .replace("&", "")
111                    .replace("*", "")
112                    .replace(" :: ", "_")
113                    .replace("\\", "")
114                    .replace("/", "")
115                    .replace("\"", "")
116                    .replace("(", "")
117                    .replace(")", "")
118                    .replace("{", "")
119                    .replace("}", "")
120                    .replace("[", "")
121                    .replace("]", "")
122                    .replace(" ", "")
123                    .replace(",", "_")
124                    .replace(".", "_")
125                    .to_lowercase()
126            })
127            .collect::<Vec<_>>()
128            .join("_"),
129    };
130
131    format!("{func_name}::{stringified_args}")
132}
133
/// Error-handling flavour of a test function's return type, as classified by
/// `inspect_error_crate`.
#[derive(PartialEq, Eq, Debug)]
enum ErrorCrate {
    /// The return type was recognised as an `eyre::Result`; the generated wrapper returns
    /// the test's result unchanged.
    Eyre,
    /// Any other return type (e.g. `anyhow::Result` or a bare `Result`). The generated
    /// wrapper converts its error into an eyre error.
    AnythingElse,
}
139
140/// Inspects the current function's signature to determine which error crate is being used.
141///
142/// This function analyzes the return type of the function to detect whether it is using
143/// `eyre::Result` or another error result type. It then enables conditional handling based
144/// on the error crate in use (e.g., wrapping non-`eyre::Result` types in an `eyre::Result`).
145///
146/// **Limitation:**
147/// Due to the inherent limitations of proc macros, this function can only detect error types
148/// when `eyre` is referenced using its fully qualified path (for example, `eyre::Result`).
149///
150/// For further details and discussion on this limitation, see:
151/// https://users.rust-lang.org/t/in-a-proc-macro-attribute-procedural-macro-how-to-get-the-full-typepath-of-some-type/107713/2
152fn inspect_error_crate(sig: &Signature) -> ErrorCrate {
153    match &sig.output {
154        ReturnType::Default => panic!("return type needs to be other than ()"),
155        ReturnType::Type(_, ty) => {
156            let Type::Path(type_path) = ty.as_ref() else {
157                panic!("failed to get return type path");
158            };
159
160            let path = &type_path.path;
161            match (path.segments.first(), path.segments.last()) {
162                (Some(first), Some(last)) => {
163                    if first.ident == "eyre" && last.ident == "Result" {
164                        ErrorCrate::Eyre
165                    } else {
166                        ErrorCrate::AnythingElse
167                    }
168                }
169                _ => {
170                    panic!("unexpected return type");
171                }
172            }
173        }
174    }
175}
176
177#[allow(dead_code)]
178/// Returns the name of the variant of the given expression.
179fn get_expr_variant_name(expr: &Expr) -> &'static str {
180    match expr {
181        Expr::Array(_) => "Array",
182        Expr::Assign(_) => "Assign",
183        Expr::Async(_) => "Async",
184        Expr::Await(_) => "Await",
185        Expr::Binary(_) => "Binary",
186        Expr::Block(_) => "Block",
187        Expr::Break(_) => "Break",
188        Expr::Call(_) => "Call",
189        Expr::Cast(_) => "Cast",
190        Expr::Closure(_) => "Closure",
191        Expr::Continue(_) => "Continue",
192        Expr::Field(_) => "Field",
193        Expr::ForLoop(_) => "ForLoop",
194        Expr::Group(_) => "Group",
195        Expr::If(_) => "If",
196        Expr::Index(_) => "Index",
197        Expr::Let(_) => "Let",
198        Expr::Lit(_) => "Lit",
199        Expr::Loop(_) => "Loop",
200        Expr::Macro(_) => "Macro",
201        Expr::Match(_) => "Match",
202        Expr::MethodCall(_) => "MethodCall",
203        Expr::Paren(_) => "Paren",
204        Expr::Path(_) => "Path",
205        Expr::Range(_) => "Range",
206        Expr::Reference(_) => "Reference",
207        Expr::Repeat(_) => "Repeat",
208        Expr::Return(_) => "Return",
209        Expr::Struct(_) => "Struct",
210        Expr::Try(_) => "Try",
211        Expr::TryBlock(_) => "TryBlock",
212        Expr::Tuple(_) => "Tuple",
213        Expr::Unary(_) => "Unary",
214        Expr::Unsafe(_) => "Unsafe",
215        Expr::Verbatim(_) => "Verbatim",
216        Expr::While(_) => "While",
217        Expr::Yield(_) => "Yield",
218        _ => "Unknown",
219    }
220}
221
222fn extract_and_stringify_option(expr: &Expr) -> Option<String> {
223    match expr {
224        Expr::Call(ExprCall { func, args, .. }) => {
225            if let Expr::Path(ExprPath { path, .. }) = &**func {
226                let segment = path.segments.last()?;
227                if segment.ident == "Some" {
228                    match args.first()? {
229                        Expr::Lit(ExprLit { lit, .. }) => match lit {
230                            Lit::Str(lit_str) => {
231                                return Some(lit_str.value());
232                            }
233                            other_type_of_literal => {
234                                return Some(other_type_of_literal.to_token_stream().to_string());
235                            }
236                        },
237                        first_arg => {
238                            return Some(quote!(#first_arg).to_string());
239                        }
240                    }
241                }
242            }
243        }
244        Expr::Path(ExprPath { path, .. }) => {
245            if path.get_ident()? == "None" {
246                return Some("None".into());
247            }
248        }
249        _ => {}
250    }
251
252    None
253}
254
255/// Marks an async function as a tanu test case.
256///
257/// This attribute registers the function with tanu's test discovery system,
258/// making it available for execution via the test runner.
259///
260/// # Basic Usage
261///
262/// ```rust,ignore
263/// #[tanu::test]
264/// async fn my_test() -> eyre::Result<()> {
265///     // Test implementation
266///     Ok(())
267/// }
268/// ```
269///
270/// # Parameterized Tests
271///
272/// The macro supports parameterized testing by accepting arguments:
273///
274/// ```rust,ignore
275/// #[tanu::test(200)]
276/// #[tanu::test(404)]
277/// #[tanu::test(500)]
278/// async fn test_status_codes(status: u16) -> eyre::Result<()> {
279///     // Test with different status codes
280///     Ok(())
281/// }
282/// ```
283///
284/// # Requirements
285///
286/// - Function must be `async`
287/// - Function must return a `Result<T, E>` type
288/// - Supported Result types: `eyre::Result`, `anyhow::Result`, `std::result::Result`
289///
290/// # Error Handling
291///
292/// The macro automatically handles different Result types and integrates
293/// with tanu's error reporting system for enhanced error messages and backtraces.
294#[proc_macro_attribute]
295pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
296    let input_args = parse_macro_input!(args as Input);
297    let input_fn = parse_macro_input!(input as ItemFn);
298
299    let func_name_inner = &input_fn.sig.ident;
300    let test_name_str = generate_test_name(&func_name_inner.to_string(), &input_args);
301
302    let args = input_args.args.to_token_stream();
303
304    // tanu internally relies on the `eyre` and `color-eyre` crates for error handling.
305    // since `tanu::Runner` expects test functions to return an `eyre::Result`, the macro
306    // generates two types of code.
307    //
308    // - If a test function explicitly returns `eyre::Result`, the macro will generate
309    //   a function that also returns `eyre::Result` without modification.
310    //
311    // - If the test function returns another result type (e.g., `anyhow::Result`),
312    //   the macro will automatically wrap the return value in an `eyre::Result`.
313    let error_crate = inspect_error_crate(&input_fn.sig);
314    let output = if error_crate == ErrorCrate::Eyre {
315        quote! {
316            #input_fn
317
318            // Submit test to inventory for discovery
319            ::tanu::inventory::submit! {
320                ::tanu::TestRegistration {
321                    module: module_path!(),
322                    name: #test_name_str,
323                    test_fn: || {
324                        Box::pin(async move {
325                            #func_name_inner(#args).await
326                        })
327                    },
328                }
329            }
330        }
331    } else {
332        quote! {
333            #input_fn
334
335            // Submit test to inventory for discovery
336            ::tanu::inventory::submit! {
337                ::tanu::TestRegistration {
338                    module: module_path!(),
339                    name: #test_name_str,
340                    test_fn: || {
341                        Box::pin(async move {
342                            #func_name_inner(#args).await.map_err(|e| ::tanu::eyre::eyre!(Box::new(e)))
343                        })
344                    },
345                }
346            }
347        }
348    };
349
350    output.into()
351}
352
353/// Generates the test discovery and registration code for tanu.
354///
355/// This attribute should be applied to your main function alongside `#[tokio::main]`.
356/// It automatically discovers all functions marked with `#[tanu::test]` and registers
357/// them with the test runner.
358///
359/// # Usage
360///
361/// ```rust,ignore
362/// #[tanu::main]
363/// #[tokio::main]
364/// async fn main() -> eyre::Result<()> {
365///     let runner = run();
366///     let app = tanu::App::new();
367///     app.run(runner).await?;
368///     Ok(())
369/// }
370/// ```
371///
372/// # What It Does
373///
374/// The macro performs compile-time test discovery by:
375/// 1. Scanning the codebase for `#[tanu::test]` annotated functions
376/// 2. Generating a `run()` function that returns a configured `Runner`
377/// 3. Registering all discovered tests with the runner
378/// 4. Setting up proper module organization and test metadata
379///
380/// # Requirements
381///
382/// - Must be used with `#[tokio::main]` for async support
383/// - The main function should return a `Result` type
384/// - All test functions must be marked with `#[tanu::test]`
385///
386/// # Generated Code
387///
388/// The macro generates a `run()` function that you can use to obtain
389/// a pre-configured test runner with all your tests registered.
390#[proc_macro_attribute]
391pub fn main(_args: TokenStream, input: TokenStream) -> TokenStream {
392    let main_fn = parse_macro_input!(input as ItemFn);
393
394    let output = quote! {
395        fn run() -> tanu::Runner {
396            let mut runner = tanu::Runner::new();
397
398            // Use inventory to discover all registered tests
399            for test in ::tanu::inventory::iter::<::tanu::TestRegistration> {
400                runner.add_test(
401                    test.name,
402                    test.module,
403                    std::sync::Arc::new(test.test_fn)
404                );
405            }
406
407            runner
408        }
409
410        #main_fn
411    };
412
413    output.into()
414}
415
// Unit tests for return-type inspection, `Option` stringification, and test-name generation.
// Commented-out cases document inputs whose generated names are known to be wrong or
// unsupported (see the TODOs).
#[cfg(test)]
mod test {
    use super::{ErrorCrate, Expr, Input};
    use test_case::test_case;

    #[test_case("fn foo() -> eyre::Result" => ErrorCrate::Eyre; "eyre")]
    #[test_case("fn foo() -> anyhow::Result" => ErrorCrate::AnythingElse; "anyhow")]
    #[test_case("fn foo() -> miette::Result" => ErrorCrate::AnythingElse; "miette")]
    #[test_case("fn foo() -> Result" => ErrorCrate::AnythingElse; "std_result")]
    fn inspect_error_crate(s: &str) -> ErrorCrate {
        let sig: syn::Signature = syn::parse_str(s).expect("failed to parse function signature");
        super::inspect_error_crate(&sig)
    }

    #[test_case("Some(1)" => Some("1".into()); "Some with int")]
    #[test_case("Some(\"test\")" => Some("test".into()); "Some with string")]
    #[test_case("Some(true)" => Some("true".into()); "Some with boolean")]
    #[test_case("Some(1.0)" => Some("1.0".into()); "Some with float")]
    #[test_case("Some(StatusCode::OK)" => Some("StatusCode :: OK".into()); "Some third party type")]
    #[test_case("Some(\"foo\".to_string())" => Some("\"foo\" . to_string ()".into()); "Some expression")]
    #[test_case("None" => Some("None".into()); "None")]
    fn extract_and_stringify_option(s: &str) -> Option<String> {
        let expr: Expr = syn::parse_str(s).expect("failed to parse expression");
        super::extract_and_stringify_option(&expr)
    }

    #[allow(clippy::erasing_op)]
    #[test_case("a, b; \"test_name\"" => "foo::test_name"; "with test name")]
    #[test_case("1+1" => "foo::1_add_1"; "with add expression")]
    #[test_case("1+=1" => "foo::1_add_1"; "with add assignment expression")]
    #[test_case("1-1" => "foo::1_sub_1"; "with sub expression")]
    #[test_case("1-=1" => "foo::1_sub_1"; "with sub assignment expression")]
    #[test_case("1/1" => "foo::1_div_1"; "with div expression")]
    #[test_case("1/=1" => "foo::1_div_1"; "with div assignment expression")]
    #[test_case("1*1" => "foo::1_mul_1"; "with mul expression")]
    #[test_case("1*=1" => "foo::1_mul_1"; "with mul assignment expression")]
    #[test_case("1%1" => "foo::1_mod_1"; "with mod expression")]
    #[test_case("1%=1" => "foo::1_mod_1"; "with mod assignment expression")]
    #[test_case("1==1" => "foo::1_eq_1"; "with eq expression")]
    #[test_case("1!=1" => "foo::1_nq_1"; "with neq expression")]
    #[test_case("1<1" => "foo::1_lt_1"; "with lt expression")]
    #[test_case("1>1" => "foo::1_gt_1"; "with gt expression")]
    #[test_case("1<=1" => "foo::1_le_1"; "with le expression")]
    #[test_case("1>=1" => "foo::1_ge_1"; "with ge expression")]
    #[test_case("true&&false" => "foo::true_and_false"; "with and expression")]
    #[test_case("true||false" => "foo::true_or_false"; "with or expression")]
    #[test_case("!true" => "foo::not_true"; "with not expression")]
    #[test_case("1&1" => "foo::1_and_1"; "with bitwise and expression")]
    #[test_case("1&=1" => "foo::1_and_1"; "with bitwise and assignment expression")]
    #[test_case("1|1" => "foo::1_or_1"; "with bitwise or expression")]
    #[test_case("1|=1" => "foo::1_or_1"; "with bitwise or assignment expression")]
    #[test_case("1^1" => "foo::1_xor_1"; "with xor expression")]
    #[test_case("1^=1" => "foo::1_xor_1"; "with xor assignment expression")]
    #[test_case("1<<1" => "foo::1_lshift_1"; "with left shift expression")]
    #[test_case("1<<=1" => "foo::1_lshift_1"; "with left shift assignment expression")]
    #[test_case("1>>1" => "foo::1_rshift_1"; "with right shift expression")]
    #[test_case("1>>=1" => "foo::1_rshift_1"; "with right shift assignment expression")]
    #[test_case("\"bar\".to_string()" => "foo::bar_to_string"; "to_string")]
    #[test_case("1+1*2" => "foo::1_add_1_mul_2"; "with add and mul expression")]
    #[test_case("1*(2+3)" => "foo::1_mul_2_add_3"; "with mul and add expression")]
    #[test_case("1+2-3" => "foo::1_add_2_sub_3"; "with add and sub expression")]
    #[test_case("1/2*3" => "foo::1_div_2_mul_3"; "with div and mul expression")]
    #[test_case("1%2+3" => "foo::1_mod_2_add_3"; "with mod and add expression")]
    #[test_case("1==2&&3!=4" => "foo::1_eq_2_and_3_nq_4"; "with eq and and expression")]
    #[test_case("true||false&&true" => "foo::true_or_false_and_true"; "with or and and expression")]
    #[test_case("!(1+2)" => "foo::not_1_add_2"; "with not and add expression")]
    #[test_case("1&2|3^4" => "foo::1_and_2_or_3_xor_4"; "with bitwise and, or, xor expression")]
    #[test_case("1<<2>>3" => "foo::1_lshift_2_rshift_3"; "with left shift and right shift expression")]
    #[test_case("Some(1+2)" => "foo::1_add_2"; "with Some and add expression")]
    #[test_case("None" => "foo::none"; "with None")]
    #[test_case("[1, 2]" => "foo::1_2"; "with array")]
    #[test_case("vec![1, 2]" => "foo::vecnot_1_2"; "with macro")] // TODO should parse macro so that it won't have "not"
    #[test_case("\"foo\".to_string().len()" => "foo::foo_to_string_len"; "with function call chain")]
    #[test_case("0.5+0.3" => "foo::0_5_add_0_3"; "with floating point add")] // TODO should be foo::05_add_03
    #[test_case("-10" => "foo::_sub_10"; "with negative number")] // TODO should be neg_10
    #[test_case("1.0e10" => "foo::1_0e10"; "with scientific notation")] // TODO should be foo::10e10
    #[test_case("0xff" => "foo::0xff"; "with hex literal")]
    #[test_case("0o777" => "foo::0o777"; "with octal literal")]
    #[test_case("0b1010" => "foo::0b1010"; "with binary literal")]
    #[test_case("\"hello\" + \"world\"" => "foo::hello_add_world"; "with string concatenation")]
    #[test_case("format!(\"{}{}\", 1, 2)" => "foo::formatnot__1_2"; "with format macro")] // TODO should be format_1_2
    #[test_case("r#\"raw string\"#" => "foo::rawstring"; "with raw string")]
    //#[test_case("\n\t\r" => "foo::n_t_r"; "with escape sequences")] // TODO this does not work yet
    #[test_case("(1, \"hello\", true)" => "foo::1_hello_true"; "with mixed tuple")]
    //#[test_case("HashSet::from([1, 2, 3])" => "foo::hashsetfrom_1_2_3"; "with collection construction")] // TODO should be 1_2_3
    //#[test_case("add(1, 2)" => "foo::add1_2"; "with function call")] // This does not work
    #[test_case("vec![1..5]" => "foo::vecnot_1__5"; "with range in macro")]
    #[test_case("x.map(|v| v+1)" => "foo::x_map_or_v_or_v_add_1"; "with closure")]
    #[test_case("a.into()" => "foo::a_into"; "with method call no args")]
    // should be a_parse_i32_unwrap
    #[test_case("a.parse::<i32>().unwrap()" => "foo::a_parse__lt_i32_gt__unwrap"; "with turbofish syntax")]
    // #[test_case("if x { 1 } else { 2 }" => "foo::if_x_1_else_2"; "with if expression")]
    // #[test_case("match x { Some(v) => v, None => 0 }" => "foo::match_x_somev_v_none_0"; "with match expression")]
    //#[test_case("Box::new(1)" => "foo::boxnew_1"; "with box allocation")]
    //#[test_case("Rc::new(vec![1, 2])" => "foo::rcnew_vecnot_1_2"; "with reference counting")]
    //#[test_case("<Vec<i32> as IntoIterator>::into_iter" => "foo::veci32_as_intoiterator_into_iter"; "with type casting")]
    // TODO should be 1_10
    #[test_case("1..10" => "foo::1__10"; "with range")]
    //#[test_case("1..=10" => "foo::1_10"; "with inclusive range")]
    //#[test_case("..10" => "foo::_10"; "with range to")]
    //#[test_case("10.." => "foo::10_"; "with range from")]
    fn generate_test_name(args: &str) -> String {
        let input_args: Input = syn::parse_str(args).expect("failed to parse input args");
        super::generate_test_name("foo", &input_args)
    }
}