ark-api-macros 0.11.0

Macro utilities for the Ark API
Documentation
use proc_macro2::*;
use quote::format_ident;
use quote::quote;
use quote::ToTokens;
use std::sync::atomic::*;
use syn::parse::Parse;
use syn::Error;
use syn::ItemFn;
use syn::LitStr;
use syn::ReturnType;
use syn::Token;

/// Process-wide counter that is incremented once per generated test wrapper and
/// embedded in the wrapper's symbol name, so two `#[ark_test]` functions with the
/// same name (e.g. in different modules) still produce distinct `#[no_mangle]` symbols.
static CNT: AtomicUsize = AtomicUsize::new(0);

/// Fallible core of the `#[ark_test]` attribute macro.
///
/// `attr` is the token stream inside `#[ark_test(...)]` and `body` is the item the
/// attribute is attached to. On success the returned tokens contain the original
/// function plus an exported wrapper that calls it (see `impl ToTokens for Parsed`).
///
/// # Errors
/// Returns a `syn::Error` when the attribute arguments or the function signature are
/// unsupported. Argument errors take priority over signature errors.
pub fn try_ark_test(
    attr: proc_macro::TokenStream,
    body: proc_macro::TokenStream,
) -> Result<TokenStream, Error> {
    // Field initializers run in source order: the attribute is parsed first.
    let parsed = Parsed {
        attr: syn::parse::<ParsedAttr>(attr)?,
        item: syn::parse::<ParsedItem>(body)?,
    };

    Ok(parsed.into_token_stream())
}

struct ParsedItem {
    item_fn: ItemFn,
}

impl Parse for ParsedItem {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let item_fn = input.parse::<ItemFn>()?;

        let sig = &item_fn.sig;

        if let Some(constness) = &sig.constness {
            return Err(Error::new_spanned(
                constness,
                "Test functions cannot be `const`",
            ));
        }

        if let Some(asyncness) = &sig.asyncness {
            return Err(Error::new_spanned(
                asyncness,
                "Test functions cannot be `async`",
            ));
        }

        if let Some(unsafety) = &sig.unsafety {
            return Err(Error::new_spanned(
                unsafety,
                "Test functions cannot be `unsafe`",
            ));
        }

        if let Some(abi) = &sig.abi {
            return Err(Error::new_spanned(
                abi,
                "Test functions cannot declare an ABI",
            ));
        }

        let generics = &sig.generics;

        if !generics.params.is_empty() {
            return Err(Error::new_spanned(
                generics,
                "Test functions cannot be generic",
            ));
        }

        if let Some(where_clause) = &generics.where_clause {
            return Err(Error::new_spanned(
                where_clause,
                "Test functions cannot be generic",
            ));
        }

        if !sig.inputs.is_empty() {
            return Err(Error::new_spanned(
                &sig.inputs,
                "Test functions cannot take inputs",
            ));
        }

        if let Some(variadic) = &sig.variadic {
            return Err(Error::new_spanned(
                variadic,
                "Test functions cannot be variadic",
            ));
        }

        if let ReturnType::Type(_, _) = &sig.output {
            return Err(Error::new_spanned(
                &sig.output,
                "Test functions cannot have outputs",
            ));
        }

        // we intentionally don't check `item_fn.vis` (visibility)
        // normal rust test functions are allowed to be `pub` so we might as well keep that

        Ok(Self { item_fn })
    }
}

/// Arguments given inside `#[ark_test(...)]`; `None` means the argument was absent.
struct ParsedAttr {
    ignored: Option<IgnoredAttr>,
    should_panic: Option<ShouldPanicAttr>,
}

impl ParsedAttr {
    /// Two-character tag appended to the generated wrapper's symbol name.
    ///
    /// The first character is `p` if `should_panic` was given (otherwise `x`); the
    /// second is `i` if `ignore` was given (otherwise `x`).
    /// NOTE(review): presumably the test runner decodes these flags from the symbol
    /// name — confirm against the runner before changing the encoding.
    fn suffix(&self) -> &'static str {
        match (&self.should_panic, &self.ignored) {
            (Some(_), Some(_)) => "pi",
            (Some(_), None) => "px",
            (None, Some(_)) => "xi",
            (None, None) => "xx",
        }
    }
}

/// Value of the `ignore` argument to `#[ark_test]`.
enum IgnoredAttr {
    /// Bare `ignore`.
    WithoutMessage,
    // `ignore = "..."`: this message isn't used currently, it's allowed for documentation purposes
    WithMessage(LitStr),
}

/// Value of the `should_panic` argument to `#[ark_test]`.
enum ShouldPanicAttr {
    /// Bare `should_panic`: no particular panic message is named.
    AnyMessage,
    /// `should_panic = "..."`: the named message is embedded in the wrapper's
    /// "expected ... to panic with '...'" failure text.
    WithMessage(LitStr),
}

// NOTE(review): kept from the original; it is unclear which warning this suppresses —
// consider narrowing or removing it.
#[allow(warnings)]
impl Parse for ParsedAttr {
    /// Parses the argument list of `#[ark_test(...)]`.
    ///
    /// Accepted arguments, each at most once and in any order:
    /// - `ignore` or `ignore = "message"`
    /// - `should_panic` or `should_panic = "expected message"`
    ///
    /// Arguments must be separated by commas; a single trailing comma is allowed.
    ///
    /// # Errors
    /// Returns an error for an unknown argument, a repeated argument, a non-string
    /// value after `=`, or two arguments that are not separated by a comma.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        /// Parses the optional `= "literal"` that may follow an argument name.
        fn parse_message(input: syn::parse::ParseStream<'_>) -> syn::Result<Option<LitStr>> {
            if input.peek(Token![=]) {
                input.parse::<Token![=]>()?;
                input.parse::<LitStr>().map(Some)
            } else {
                Ok(None)
            }
        }

        let mut ignored = None::<IgnoredAttr>;
        let mut should_panic = None::<ShouldPanicAttr>;

        while !input.is_empty() {
            let ident = input.parse::<Ident>()?;

            match &*ident.to_string() {
                "ignore" => {
                    // Previously a repeated argument silently overwrote the first one.
                    if ignored.is_some() {
                        return Err(Error::new_spanned(
                            ident,
                            "Duplicate `ignore` argument to `#[ark_test]`",
                        ));
                    }
                    ignored = Some(match parse_message(input)? {
                        Some(msg) => IgnoredAttr::WithMessage(msg),
                        None => IgnoredAttr::WithoutMessage,
                    });
                }

                "should_panic" => {
                    if should_panic.is_some() {
                        return Err(Error::new_spanned(
                            ident,
                            "Duplicate `should_panic` argument to `#[ark_test]`",
                        ));
                    }
                    should_panic = Some(match parse_message(input)? {
                        Some(msg) => ShouldPanicAttr::WithMessage(msg),
                        None => ShouldPanicAttr::AnyMessage,
                    });
                }

                _ => {
                    return Err(Error::new_spanned(
                        ident,
                        "Unsupported argument to `#[ark_test]`",
                    ));
                }
            }

            // Require a separator between arguments. The old code ignored the result of
            // parsing the comma, so `#[ark_test(ignore should_panic)]` was accepted.
            // A trailing comma leaves the input empty and ends the loop.
            if !input.is_empty() {
                input.parse::<Token![,]>()?;
            }
        }

        Ok(Self {
            ignored,
            should_panic,
        })
    }
}

struct Parsed {
    attr: ParsedAttr,
    item: ParsedItem,
}

impl ToTokens for Parsed {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let item_fn = &self.item.item_fn;
        let ident = &item_fn.sig.ident;

        let extern_name = format_ident!(
            "__at_test__{}__{}_{}",
            ident,
            CNT.fetch_add(1, Ordering::SeqCst),
            self.attr.suffix(),
        );

        let panic_assertion =
            self.attr
                .should_panic
                .as_ref()
                .map(|should_panic| match should_panic {
                    ShouldPanicAttr::AnyMessage => quote! {
                        panic!(
                            "expected {} to panic",
                            concat!(module_path!(), "::", stringify!(#ident))
                        );
                    },
                    ShouldPanicAttr::WithMessage(msg) => quote! {
                        panic!(
                            "expected {} to panic with '{}'",
                            concat!(module_path!(), "::", stringify!(#ident)),
                            #msg,
                        );
                    },
                });

        tokens.extend(quote! {
            #[no_mangle]
            pub extern "C" fn #extern_name() {
                #ident();
                #panic_assertion
            }

            #item_fn
        });
    }
}

/// We don't actually care about adding docs to this function. We're only using the docs as a
/// makeshift compile test for `#[ark_test]`.
///
/// These should all compile:
///
/// ```
/// #[ark_api_macros::ark_test]
/// fn test_1() {}
///
/// #[ark_api_macros::ark_test(ignore)]
/// fn test_2() {}
///
/// #[ark_api_macros::ark_test(should_panic)]
/// fn test_3() {}
///
/// #[ark_api_macros::ark_test(should_panic = "panic message")]
/// fn test_4() {}
///
/// #[ark_api_macros::ark_test(should_panic, ignore)]
/// fn test_5() {}
///
/// #[ark_api_macros::ark_test(should_panic = "panic message", ignore)]
/// fn test_6() {}
///
/// #[ark_api_macros::ark_test(ignore, should_panic)]
/// fn test_7() {}
///
/// #[ark_api_macros::ark_test(ignore, should_panic = "panic message")]
/// fn test_8() {}
///
/// #[ark_api_macros::ark_test(ignore, should_panic = "panic message",)]
/// fn test_9() {}
///
/// #[ark_api_macros::ark_test(ignore = "ignore message")]
/// fn test_10() {}
///
/// // Visibility is intentionally not checked, just like with `#[test]`.
/// #[ark_api_macros::ark_test]
/// pub fn test_11() {}
/// ```
///
/// And these should all fail to compile, one rejected feature per snippet:
///
/// ```compile_fail
/// #[ark_api_macros::ark_test]
/// async fn test() {}
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test(foo)]
/// fn test() {}
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test]
/// const fn test() {}
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test]
/// unsafe fn test() {}
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test]
/// extern "C" fn test() {}
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test]
/// fn test<T>() {}
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test]
/// fn test(x: u32) {}
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test]
/// fn test() -> u32 { 0 }
/// ```
///
/// ```compile_fail
/// #[ark_api_macros::ark_test(ignore = 1)]
/// fn test() {}
/// ```
#[allow(dead_code)]
fn poor_mans_compile_test() {}