use proc_macro2::*;
use quote::format_ident;
use quote::quote;
use quote::ToTokens;
use std::sync::atomic::*;
use syn::parse::Parse;
use syn::Error;
use syn::ItemFn;
use syn::LitStr;
use syn::ReturnType;
use syn::Token;
/// Counter embedded in every generated wrapper's exported symbol name.
///
/// `Parsed::to_tokens` increments it once per expansion. Its value becomes part of
/// `__at_test__<name>__<n>_<suffix>`, so two `#[ark_test]` functions sharing an
/// identifier (e.g. in different modules) still get distinct `#[no_mangle]` symbols.
///
/// NOTE(review): this state lives in the proc-macro process, so numbering is only
/// unique within one expansion session (presumably one crate). If wrappers from
/// several crates are ever linked into one binary, their symbols may collide; confirm.
static CNT: AtomicUsize = AtomicUsize::new(0);
/// Expands `#[ark_test(...)]` applied to `body`.
///
/// `attr` holds the tokens inside the attribute's parentheses; `body` is the
/// annotated item. On success the returned stream contains an exported
/// `extern "C"` wrapper followed by the original test function.
///
/// # Errors
///
/// Returns the first parse error encountered. Invalid attribute arguments are
/// reported before any problem with the function itself.
pub fn try_ark_test(
    attr: proc_macro::TokenStream,
    body: proc_macro::TokenStream,
) -> Result<TokenStream, Error> {
    // Struct-expression fields are evaluated in source order, so `attr` is
    // parsed (and can fail) before `body`.
    let parsed = Parsed {
        attr: syn::parse::<ParsedAttr>(attr)?,
        item: syn::parse::<ParsedItem>(body)?,
    };
    Ok(parsed.into_token_stream())
}
struct ParsedItem {
item_fn: ItemFn,
}
impl Parse for ParsedItem {
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
let item_fn = input.parse::<ItemFn>()?;
let sig = &item_fn.sig;
if let Some(constness) = &sig.constness {
return Err(Error::new_spanned(
constness,
"Test functions cannot be `const`",
));
}
if let Some(asyncness) = &sig.asyncness {
return Err(Error::new_spanned(
asyncness,
"Test functions cannot be `async`",
));
}
if let Some(unsafety) = &sig.unsafety {
return Err(Error::new_spanned(
unsafety,
"Test functions cannot be `unsafe`",
));
}
if let Some(abi) = &sig.abi {
return Err(Error::new_spanned(
abi,
"Test functions cannot declare an ABI",
));
}
let generics = &sig.generics;
if !generics.params.is_empty() {
return Err(Error::new_spanned(
generics,
"Test functions cannot be generic",
));
}
if let Some(where_clause) = &generics.where_clause {
return Err(Error::new_spanned(
where_clause,
"Test functions cannot be generic",
));
}
if !sig.inputs.is_empty() {
return Err(Error::new_spanned(
&sig.inputs,
"Test functions cannot take inputs",
));
}
if let Some(variadic) = &sig.variadic {
return Err(Error::new_spanned(
variadic,
"Test functions cannot be variadic",
));
}
if let ReturnType::Type(_, _) = &sig.output {
return Err(Error::new_spanned(
&sig.output,
"Test functions cannot have outputs",
));
}
Ok(Self { item_fn })
}
}
/// Arguments given inside `#[ark_test(...)]`. Each field is `None` when the
/// corresponding argument was absent.
struct ParsedAttr {
    ignored: Option<IgnoredAttr>,
    should_panic: Option<ShouldPanicAttr>,
}
impl ParsedAttr {
    /// Two-letter code appended to the wrapper's exported symbol name.
    ///
    /// The first letter is `p` if `should_panic` was given, else `x`.
    /// The second letter is `i` if `ignore` was given, else `x`.
    /// Attached messages do not influence the code.
    ///
    /// NOTE(review): presumably decoded by whatever runner discovers the
    /// exported symbols; keep in sync with it.
    fn suffix(&self) -> &'static str {
        let panics = self.should_panic.is_some();
        let ignored = self.ignored.is_some();
        if panics {
            if ignored {
                "pi"
            } else {
                "px"
            }
        } else if ignored {
            "xi"
        } else {
            "xx"
        }
    }
}
/// Value of the `ignore` argument: either a bare `ignore` or `ignore = "reason"`.
///
/// NOTE(review): in the code visible here, only `ParsedAttr::suffix` looks at this,
/// and it only checks whether `ignore` was present. The reason string is never
/// read; confirm whether it should be surfaced to the runner.
enum IgnoredAttr {
    /// Bare `ignore`.
    WithoutMessage,
    /// `ignore = "reason"`.
    WithMessage(LitStr),
}
/// Value of the `should_panic` argument: either a bare `should_panic` or
/// `should_panic = "message"`.
///
/// NOTE(review): `Parsed::to_tokens` only interpolates the message into the
/// failure text printed when the test does *not* panic. The actual panic
/// message is never compared against it.
enum ShouldPanicAttr {
    /// Bare `should_panic`: any panic counts.
    AnyMessage,
    /// `should_panic = "message"`.
    WithMessage(LitStr),
}
// NOTE(review): the reason for this blanket `allow` is not visible here;
// consider narrowing it to the specific lint it is meant to silence.
#[allow(warnings)]
impl Parse for ParsedAttr {
    /// Parses the comma-separated argument list of `#[ark_test(...)]`.
    ///
    /// Each argument may appear at most once, in any order:
    /// - `ignore` or `ignore = "reason"`
    /// - `should_panic` or `should_panic = "message"`
    ///
    /// A trailing comma is allowed. The following are errors, spanned at the
    /// offending token:
    /// - an unknown argument name;
    /// - a repeated argument;
    /// - a non-string value after `=`;
    /// - two arguments not separated by a comma.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        // Parses the optional `= "literal"` tail that may follow an argument name.
        fn optional_message(input: syn::parse::ParseStream<'_>) -> syn::Result<Option<LitStr>> {
            if input.peek(Token![=]) {
                input.parse::<Token![=]>()?;
                input.parse::<LitStr>().map(Some)
            } else {
                Ok(None)
            }
        }

        let mut ignored = None::<IgnoredAttr>;
        let mut should_panic = None::<ShouldPanicAttr>;
        while !input.is_empty() {
            let ident = input.parse::<Ident>()?;
            match &*ident.to_string() {
                "ignore" => {
                    // Reject repeats instead of letting the last one silently win.
                    if ignored.is_some() {
                        return Err(Error::new_spanned(
                            ident,
                            "Duplicate `ignore` argument to `#[ark_test]`",
                        ));
                    }
                    ignored = Some(
                        optional_message(input)?
                            .map_or(IgnoredAttr::WithoutMessage, IgnoredAttr::WithMessage),
                    );
                }
                "should_panic" => {
                    if should_panic.is_some() {
                        return Err(Error::new_spanned(
                            ident,
                            "Duplicate `should_panic` argument to `#[ark_test]`",
                        ));
                    }
                    should_panic = Some(
                        optional_message(input)?
                            .map_or(ShouldPanicAttr::AnyMessage, ShouldPanicAttr::WithMessage),
                    );
                }
                _ => {
                    return Err(Error::new_spanned(
                        ident,
                        "Unsupported argument to `#[ark_test]`",
                    ));
                }
            }
            // Arguments must be comma-separated. A missing comma used to be
            // ignored, so `ignore should_panic` was accepted; now it is an error.
            // A comma directly before the end of input (trailing comma) is still fine.
            if !input.is_empty() {
                input.parse::<Token![,]>()?;
            }
        }
        Ok(Self {
            ignored,
            should_panic,
        })
    }
}
struct Parsed {
attr: ParsedAttr,
item: ParsedItem,
}
impl ToTokens for Parsed {
fn to_tokens(&self, tokens: &mut TokenStream) {
let item_fn = &self.item.item_fn;
let ident = &item_fn.sig.ident;
let extern_name = format_ident!(
"__at_test__{}__{}_{}",
ident,
CNT.fetch_add(1, Ordering::SeqCst),
self.attr.suffix(),
);
let panic_assertion =
self.attr
.should_panic
.as_ref()
.map(|should_panic| match should_panic {
ShouldPanicAttr::AnyMessage => quote! {
panic!(
"expected {} to panic",
concat!(module_path!(), "::", stringify!(#ident))
);
},
ShouldPanicAttr::WithMessage(msg) => quote! {
panic!(
"expected {} to panic with '{}'",
concat!(module_path!(), "::", stringify!(#ident)),
#msg,
);
},
});
tokens.extend(quote! {
#[no_mangle]
pub extern "C" fn #extern_name() {
#ident();
#panic_assertion
}
#item_fn
});
}
}
/// Intentionally empty; nothing in the visible code calls it (hence `dead_code`).
///
/// NOTE(review): the name suggests a placeholder for compile-time checks of the
/// macro's output. Confirm whether it is still needed.
#[allow(dead_code)]
fn poor_mans_compile_test() {}