1use std::mem;
2
3use attribute::Input;
4use quote::{format_ident, quote, ToTokens};
5use syn::{
6 parse_macro_input,
7 punctuated::Punctuated,
8 GenericArgument,
9 ItemFn,
10 PathArguments,
11 ReturnType,
12 Type,
13 TypePath,
14 TypeTuple,
15};
16
17use crate::attribute::AttributeOptions;
18
19mod attribute;
20
21#[proc_macro_attribute]
22pub fn flowtest(
23 attr: proc_macro::TokenStream,
24 item: proc_macro::TokenStream,
25) -> proc_macro::TokenStream {
26 let mut attr = parse_macro_input!(attr as AttributeOptions);
27
28 let ItemFn {
29 attrs,
30 vis,
31 mut sig,
32 block,
33 } = parse_macro_input!(item as ItemFn);
34
35 if let Some(asyncness) = sig.asyncness {
36 return syn::Error::new_spanned(asyncness, "`async` is not allowed on flowtest tests. If you are using an async test attribute make sure it is above the flowtest attribute.")
37 .into_compile_error()
38 .into();
39 }
40
41 let executor = attr
42 .executor
43 .take()
44 .map(ToTokens::into_token_stream)
45 .unwrap_or_else(|| {
46 quote! {
47 ::flowtest::standard_executor::StandardExecutor
48 }
49 });
50
51 let continuation = format_ident!("__flowtest_test_{}", sig.ident);
52
53 let mut out_ty = Type::Tuple(TypeTuple {
54 paren_token: Default::default(),
55 elems: Punctuated::new(),
56 });
57
58 let mut is_result = false;
59
60 if let ReturnType::Type(_, ty) = &mut sig.output {
61 'ret: {
62 if attr.result_override != Some(false) {
63 if let Type::Path(TypePath { qself: _, path }) = &mut **ty {
64 if let Some(ty) = path.segments.last_mut() {
65 if attr.result_override == Some(true) || ty.ident.to_string() == "Result" {
67 if let PathArguments::AngleBracketed(args) = &mut ty.arguments {
68 if let Some(GenericArgument::Type(ty)) = args.args.first_mut() {
69 mem::swap(ty, &mut out_ty);
70 is_result = true;
71 break 'ret
72 }
73 }
74 }
75 }
76 }
77 }
78
79 mem::swap(&mut **ty, &mut out_ty);
80 }
81
82 if attr.result_override == Some(true) && !is_result {
83 return syn::Error::new_spanned(ty, "unable to parse return type as result (required due to `-> result` in flowtest attribute)")
84 .into_compile_error()
85 .into()
86 }
87 } else if attr.result_override == Some(true) {
88 return syn::Error::new_spanned(
89 sig.ident,
90 "function must return a result (required due to `-> result` in flowtest attribute)",
91 )
92 .into_compile_error()
93 .into()
94 }
95
96 let exec_fn = match is_result {
97 true => format_ident!("exec_result"),
98 false => format_ident!("exec_noresult"),
99 };
100
101 let dependencies = attr.inputs.iter().map(|input| {
102 let Input { from, pat } = input;
103
104 let from_continuation = format_ident!("__flowtest_test_{}", from);
105
106 quote! {
107 let #pat = match ::flowtest::Executor::wait(&mut __flowtest_executor, &#from_continuation) {
108 Ok(v) => v,
109 Err(::flowtest::__private::TestFailedError) => ::std::panic!(
110 concat!("flowtest dependency `", stringify!(#from), "` failed")
111 ),
112 };
113 }
114 });
115
116 let new_fn = quote! {
117 #[allow(non_upper_case_globals)]
118 static #continuation: <#executor as ::flowtest::Executor>::Continuation<::std::result::Result<#out_ty, ::flowtest::__private::TestFailedError>> =
119 <<#executor as ::flowtest::Executor>::Continuation<::std::result::Result<#out_ty, ::flowtest::__private::TestFailedError>>
120 as ::flowtest::Continuation>::INITIAL;
121
122 #(#attrs)* #vis #sig {
123 let mut __flowtest_executor = <#executor as ::flowtest::Executor>::init();
124
125 #(#dependencies)*
126
127 ::flowtest::__private::#exec_fn(__flowtest_executor, &#continuation, move || #block)
128 }
129 };
130
131 new_fn.into()
132}