rtest_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use rtest_util::TestArguments;
4use std::sync::{Arc, Mutex, OnceLock};
5
6struct TestInfo {
7    fn_name: String,
8    args:    TestArguments,
9    input:   String,
10}
11
12static FUNCTION_NAMES: OnceLock<Arc<Mutex<Vec<TestInfo>>>> = OnceLock::new();
13
14#[proc_macro_derive(Resource)]
15pub fn resource(input: TokenStream) -> TokenStream {
16    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
17    let name = &ast.ident;
18    let name_str = name.to_string();
19
20    let code = quote!(
21        impl rtest::Resource for #name {
22            type Context = rtest::Context;
23
24            fn from_context(context: &Self::Context) -> Option<Self> {
25                context.extract::<#name>()
26            }
27
28            fn into_context(context: &Self::Context, resource: #name) {
29                context.inject(resource)
30            }
31
32            fn get_resource_id() -> rtest::ResourceId {
33                #name_str.to_string()
34            }
35        }
36    );
37    code.into()
38}
39
40#[proc_macro_attribute]
41pub fn rtest(attr: TokenStream, orig_input: TokenStream) -> TokenStream {
42    use darling::FromMeta;
43
44    let input = orig_input.clone();
45    let fn_ast = syn::parse_macro_input!(input as syn::ItemFn);
46
47    let attr_args = match darling::ast::NestedMeta::parse_meta_list(attr.into()) {
48        Ok(v) => v,
49        Err(e) => {
50            return TokenStream::from(darling::Error::from(e).write_errors());
51        },
52    };
53
54    let args = match TestArguments::from_list(&attr_args) {
55        Ok(v) => v,
56        Err(e) => {
57            return TokenStream::from(e.write_errors());
58        },
59    };
60
61    {
62        let mutex = FUNCTION_NAMES.get_or_init(|| Arc::new(Mutex::new(Vec::new())));
63        let mut list = mutex.lock().unwrap();
64        list.push(TestInfo {
65            fn_name: fn_ast.sig.ident.to_string(),
66            args,
67            input: orig_input.to_string(),
68        })
69    }
70    orig_input
71}
72
73#[proc_macro]
74/// input must be the RunConfig
75pub fn run(input: TokenStream) -> TokenStream {
76    use std::str::FromStr;
77    let list: Vec<_> = {
78        let mutex = FUNCTION_NAMES.get_or_init(|| Arc::new(Mutex::new(Vec::new())));
79        mutex.lock().unwrap().drain(..).collect()
80    };
81
82    let input_struct = syn::parse_macro_input!(input as syn::Path);
83
84    let mut commands = quote!();
85    for (i, l) in list.into_iter().enumerate() {
86        let name = l.fn_name.to_string();
87        let input = TokenStream::from_str(&l.input).unwrap();
88        let ast = syn::parse_macro_input!(input as syn::ItemFn);
89        let fn_ident = ast.sig.ident;
90        let relative_fn_ident = l
91            .args
92            .module
93            .as_ref()
94            .and_then(|module| module.split_once("::"))
95            .map(|module| format!("crate::{}::{}", module.1, &fn_ident.to_string()))
96            .unwrap_or(fn_ident.to_string());
97        let relative_fn_ident = TokenStream::from_str(&relative_fn_ident).unwrap();
98        let relative_fn_ident = syn::parse_macro_input!(relative_fn_ident as syn::Path);
99        let testargs = l.args;
100
101        let f = quote! {
102            let c = #input_struct.context.clone();
103            let handler_params = rtest::HandlerParams::from(&#input_struct);
104            let (reference, inputs, outputs) = rtest::describe_handler(& #relative_fn_ident);
105            test_repo.add(#i, #name, Box::new(move || {
106                rtest::call_handler(c.clone(), &mut #relative_fn_ident, &handler_params)
107            }), #testargs, reference, inputs, outputs);
108        };
109        commands = quote!(
110            #commands
111            #f
112        );
113    }
114
115    let code = quote!(
116        {
117            use rtest::{Runner, Printer, TestArguments, Persister};
118            let #input_struct: rtest::RunConfig<_> = #input_struct;
119
120            let mut test_repo = rtest::TestRepo::default();
121            #commands
122            rtest::main_run(test_repo, #input_struct)
123        }
124    );
125
126    code.into()
127}