use quote::quote;
use syn::{parse_macro_input, Attribute, ItemFn, ReturnType};
#[proc_macro_attribute]
pub fn test(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let mut parsed_fn = parse_macro_input!(input as ItemFn);
let attrs = parsed_fn.attrs.drain(..).collect::<Vec<_>>();
let (mut sig, block) = (parsed_fn.sig, parsed_fn.block);
let (outer_return_type, trailer) =
if attrs.iter().any(|attr| attr.path().is_ident("should_panic")) {
(quote! { () }, quote! { .unwrap(); })
} else {
(
quote! { std::result::Result<(), googletest::internal::test_outcome::TestFailure> },
quote! {},
)
};
let output_type = match sig.output.clone() {
ReturnType::Type(_, output_type) => Some(output_type),
ReturnType::Default => None,
};
sig.output = ReturnType::Default;
let (maybe_closure, invocation) = if sig.asyncness.is_some() {
(
quote! {},
quote! {
async { #block }.await
},
)
} else {
(
quote! {
let test = move || #block;
},
quote! {
test()
},
)
};
let function = if let Some(output_type) = output_type {
quote! {
#(#attrs)*
#sig -> #outer_return_type {
#maybe_closure
use googletest::internal::test_outcome::TestOutcome;
TestOutcome::init_current_test_outcome();
let result: #output_type = #invocation;
TestOutcome::close_current_test_outcome(result)
#trailer
}
}
} else {
quote! {
#(#attrs)*
#sig -> #outer_return_type {
#maybe_closure
use googletest::internal::test_outcome::TestOutcome;
TestOutcome::init_current_test_outcome();
#invocation;
TestOutcome::close_current_test_outcome(googletest::Result::Ok(()))
#trailer
}
}
};
let output = if attrs.iter().any(is_test_attribute) {
function
} else {
quote! {
#[::core::prelude::v1::test]
#function
}
};
output.into()
}
/// Reports whether `attr` already marks the function as a test.
///
/// An attribute qualifies when the last segment of its path is `test`
/// (e.g. `#[test]`, `#[some_crate::test]`), or when the path has at most two
/// segments and both its first and last segments are `rstest`
/// (i.e. `#[rstest]` or `#[rstest::rstest]`). An attribute with an empty path
/// never qualifies.
fn is_test_attribute(attr: &Attribute) -> bool {
    let segments = &attr.path().segments;
    match (segments.first(), segments.last()) {
        (Some(head), Some(tail)) => {
            let ends_in_test = tail.ident == "test";
            let is_rstest =
                segments.len() <= 2 && head.ident == "rstest" && tail.ident == "rstest";
            ends_in_test || is_rstest
        }
        _ => false,
    }
}