lunatic_test/
lib.rs

1#[allow(unused_extern_crates)]
2extern crate proc_macro;
3
4use proc_macro::TokenStream;
5use quote::{quote, quote_spanned};
6use syn::spanned::Spanned;
7
8/// Marks function to be executed by the lunatic runtime as a unit test. This is
9/// a drop-in replacement for the standard `#[test]` attribute macro.
10#[proc_macro_attribute]
11pub fn test(_args: TokenStream, item: TokenStream) -> TokenStream {
12    let input = syn::parse_macro_input!(item as syn::ItemFn);
13    let original_input = input.clone();
14    let attributes = &input.attrs;
15    let span = input.span();
16
17    // Check if #[should_panic] attribute is present.
18    let mut should_panic = None;
19    let mut ignore = "";
20    for attribute in attributes.iter() {
21        if let Some(ident) = attribute.path.get_ident() {
22            if ident == "ignore" {
23                ignore = "#ignore_";
24            }
25            if ident == "should_panic" {
26                // Common error message
27                let error = syn::Error::new_spanned(
28                    &attribute.tokens,
29                    "argument must be of the form: `expected = \"error message\"`",
30                )
31                .to_compile_error()
32                .into();
33
34                let attribute_args = match attribute.parse_meta() {
35                    Ok(args) => args,
36                    Err(_) => return error,
37                };
38
39                let attribute_args = match attribute_args {
40                    syn::Meta::List(attribute_args) => attribute_args,
41                    syn::Meta::Path(_) => {
42                        // Match any panic if no expected value is provided
43                        should_panic = Some("".to_string());
44                        continue;
45                    }
46                    _ => return error,
47                };
48
49                // should_panic can have at most one argument
50                if attribute_args.nested.len() > 1 {
51                    return error;
52                }
53                // The first argument can only be in the format 'expected = "partial matcher"'
54                if let Some(argument) = attribute_args.nested.iter().next() {
55                    match argument {
56                        syn::NestedMeta::Meta(syn::Meta::NameValue(name_value)) => {
57                            if let Some(ident) = name_value.path.get_ident() {
58                                if ident != "expected" {
59                                    return error;
60                                }
61                                match &name_value.lit {
62                                    syn::Lit::Str(lit) => {
63                                        // Mark function as should_panic
64                                        should_panic = Some(lit.value())
65                                    }
66                                    _ => return error,
67                                }
68                            } else {
69                                return error;
70                            }
71                        }
72                        _ => return error,
73                    }
74                };
75            }
76        }
77    }
78
79    let mut export_name = format!("#lunatic_test_{}", ignore);
80    if let Some(ref panic_str) = should_panic {
81        // Escape # in panic_str
82        let panic_str = panic_str.replace('#', "\\#");
83        export_name = format!("{}#panic_{}#", export_name, panic_str,);
84    }
85    let function_name = input.sig.ident.to_string();
86
87    let name = input.sig.ident;
88    let arguments = input.sig.inputs;
89    let output = input.sig.output;
90    let block = input.block;
91
92    // `#[should_panic]` can't be combined with `Result`.
93    match output {
94        syn::ReturnType::Type(_, _) => {
95            if should_panic.is_some() {
96                return quote_spanned! {
97                    span => compile_error!("functions using `#[should_panic]` must return `()`");
98                }
99                .into();
100            }
101        }
102        syn::ReturnType::Default => (),
103    }
104
105    let mailbox = if !arguments.is_empty() {
106        quote! { lunatic::Mailbox::new() }
107    } else {
108        quote! {}
109    };
110
111    let wasm32_test = quote! {
112        fn #name() {
113            fn __with_mailbox(#arguments) #output {
114                #block
115            }
116            let result = unsafe { __with_mailbox(#mailbox) };
117            lunatic::test::assert_test_result(result);
118        }
119    };
120
121    quote! {
122        // If not compiling for wasm32, fall back to #[test]
123        #[cfg_attr(not(target_arch = "wasm32"), ::core::prelude::v1::test)]
124        #[cfg(not(target_arch = "wasm32"))]
125        #original_input
126
127        #[cfg_attr(
128            target_arch = "wasm32",
129            export_name = concat!(#export_name, module_path!(), "::", #function_name)
130        )]
131        #[cfg(target_arch = "wasm32")]
132        #wasm32_test
133    }
134    .into()
135}