flowtest_macro/
lib.rs

1use std::mem;
2
3use attribute::Input;
4use quote::{format_ident, quote, ToTokens};
5use syn::{
6	parse_macro_input,
7	punctuated::Punctuated,
8	GenericArgument,
9	ItemFn,
10	PathArguments,
11	ReturnType,
12	Type,
13	TypePath,
14	TypeTuple,
15};
16
17use crate::attribute::AttributeOptions;
18
19mod attribute;
20
21#[proc_macro_attribute]
22pub fn flowtest(
23	attr: proc_macro::TokenStream,
24	item: proc_macro::TokenStream,
25) -> proc_macro::TokenStream {
26	let mut attr = parse_macro_input!(attr as AttributeOptions);
27
28	let ItemFn {
29		attrs,
30		vis,
31		mut sig,
32		block,
33	} = parse_macro_input!(item as ItemFn);
34
35	if let Some(asyncness) = sig.asyncness {
36		return syn::Error::new_spanned(asyncness, "`async` is not allowed on flowtest tests. If you are using an async test attribute make sure it is above the flowtest attribute.")
37			.into_compile_error()
38			.into();
39	}
40
41	let executor = attr
42		.executor
43		.take()
44		.map(ToTokens::into_token_stream)
45		.unwrap_or_else(|| {
46			quote! {
47				::flowtest::standard_executor::StandardExecutor
48			}
49		});
50
51	let continuation = format_ident!("__flowtest_test_{}", sig.ident);
52
53	let mut out_ty = Type::Tuple(TypeTuple {
54		paren_token: Default::default(),
55		elems: Punctuated::new(),
56	});
57
58	let mut is_result = false;
59
60	if let ReturnType::Type(_, ty) = &mut sig.output {
61		'ret: {
62			if attr.result_override != Some(false) {
63				if let Type::Path(TypePath { qself: _, path }) = &mut **ty {
64					if let Some(ty) = path.segments.last_mut() {
65						// hopefully this should catch result types used in std and libraries like anyhow
66						if attr.result_override == Some(true) || ty.ident.to_string() == "Result" {
67							if let PathArguments::AngleBracketed(args) = &mut ty.arguments {
68								if let Some(GenericArgument::Type(ty)) = args.args.first_mut() {
69									mem::swap(ty, &mut out_ty);
70									is_result = true;
71									break 'ret
72								}
73							}
74						}
75					}
76				}
77			}
78
79			mem::swap(&mut **ty, &mut out_ty);
80		}
81
82		if attr.result_override == Some(true) && !is_result {
83			return syn::Error::new_spanned(ty, "unable to parse return type as result (required due to `-> result` in flowtest attribute)")
84				.into_compile_error()
85				.into()
86		}
87	} else if attr.result_override == Some(true) {
88		return syn::Error::new_spanned(
89			sig.ident,
90			"function must return a result (required due to `-> result` in flowtest attribute)",
91		)
92		.into_compile_error()
93		.into()
94	}
95
96	let exec_fn = match is_result {
97		true => format_ident!("exec_result"),
98		false => format_ident!("exec_noresult"),
99	};
100
101	let dependencies = attr.inputs.iter().map(|input| {
102		let Input { from, pat } = input;
103
104		let from_continuation = format_ident!("__flowtest_test_{}", from);
105
106		quote! {
107			let #pat = match ::flowtest::Executor::wait(&mut __flowtest_executor, &#from_continuation) {
108				Ok(v) => v,
109				Err(::flowtest::__private::TestFailedError) => ::std::panic!(
110					concat!("flowtest dependency `", stringify!(#from), "` failed")
111				),
112			};
113		}
114	});
115
116	let new_fn = quote! {
117		#[allow(non_upper_case_globals)]
118		static #continuation: <#executor as ::flowtest::Executor>::Continuation<::std::result::Result<#out_ty, ::flowtest::__private::TestFailedError>> =
119			<<#executor as ::flowtest::Executor>::Continuation<::std::result::Result<#out_ty, ::flowtest::__private::TestFailedError>>
120				as ::flowtest::Continuation>::INITIAL;
121
122		#(#attrs)* #vis #sig {
123			let mut __flowtest_executor = <#executor as ::flowtest::Executor>::init();
124
125			#(#dependencies)*
126
127			::flowtest::__private::#exec_fn(__flowtest_executor, &#continuation, move || #block)
128		}
129	};
130
131	new_fn.into()
132}