mpc-macros 0.2.15

Arcium MPC Macros
Documentation
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Attribute, ItemFn, Meta};

/// Re-emits the annotated function unchanged and adds a `pub` forwarding wrapper named
/// `_<fn_name>`, compiled only when the `cfg` predicate passed as `attr` holds.
///
/// For example, with `attr` = `feature = "bench"` on `fn add(a: u8, b: u8) -> u8`, this
/// additionally emits
/// `#[cfg(feature = "bench")] #[inline] pub fn _add(a: u8, b: u8) -> u8 { add(a, b) }`.
///
/// The wrapper:
/// - copies the original generics, `where` clause, return type and the `const` / `async` /
///   `unsafe` qualifiers (an `unsafe` original is called inside an `unsafe` block, an `async`
///   one is `.await`ed);
/// - calls `self.<fn_name>(..)` when the function has a receiver and `<fn_name>(..)`
///   otherwise. An associated function without a receiver is therefore not supported: the
///   call would resolve to a free function of that name;
/// - copies every attribute of the original except `#[public(..)]`. Any `#[cfg(..)]` is kept,
///   so the wrapper is never compiled in a configuration where the original is absent;
/// - gets `#[inline]` unless the original already has an `#[inline..]` attribute.
///
/// # Errors
/// If `item` is not a function or `attr` is not a single meta item, `parse_macro_input!`
/// makes this return the corresponding `compile_error!` tokens.
pub fn public_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input_fn = parse_macro_input!(item as ItemFn);
    let cfg_content = parse_macro_input!(attr as Meta);

    let sig = &input_fn.sig;
    let fn_name = &sig.ident;
    let public_fn_name = syn::Ident::new(&format!("_{fn_name}"), fn_name.span());

    let generics = &sig.generics;
    let where_clause = &sig.generics.where_clause;
    let output = &sig.output;
    let constness = &sig.constness;
    let asyncness = &sig.asyncness;
    let unsafety = &sig.unsafety;

    // Parameter list for the wrapper plus the identifiers it forwards, one per typed parameter.
    let (wrapper_inputs, arg_names) = forwarding_inputs(&sig.inputs, fn_name);

    // Methods (`self`, `&self`, `&mut self`, `self: Box<Self>`, ...) are called through `self`.
    let has_self = sig
        .inputs
        .iter()
        .any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
    let call = if has_self {
        quote! { self.#fn_name(#(#arg_names),*) }
    } else {
        quote! { #fn_name(#(#arg_names),*) }
    };
    let call = if asyncness.is_some() {
        quote! { #call.await }
    } else {
        call
    };
    // An explicit block keeps the call valid under `unsafe_op_in_unsafe_fn` (warn in 2024).
    let body = if unsafety.is_some() {
        quote! { unsafe { #call } }
    } else {
        call
    };

    // Forward every attribute except the macro's own `#[public(..)]`; keeping `#[cfg(..)]`
    // ensures the wrapper never outlives the function it calls.
    let wrapper_attrs: Vec<&Attribute> = input_fn
        .attrs
        .iter()
        .filter(|attr| !attr.path().is_ident("public"))
        .collect();

    // Respect an explicit `#[inline]` / `#[inline(..)]` (already copied above); default to `#[inline]`.
    let has_inline = input_fn
        .attrs
        .iter()
        .any(|attr| attr.path().is_ident("inline"));
    let inline_attr = if has_inline {
        quote! {}
    } else {
        quote! { #[inline] }
    };

    let expanded = quote! {
        #input_fn

        #[cfg(#cfg_content)]
        #inline_attr
        #(#wrapper_attrs)*
        pub #constness #asyncness #unsafety fn #public_fn_name #generics(#wrapper_inputs) #output #where_clause {
            #body
        }
    };

    TokenStream::from(expanded)
}

/// Clones `inputs` into a parameter list for the forwarding wrapper. Returns it together with
/// the identifiers to pass, in order, to the original function. The receiver is not among them.
///
/// A `name: T` or `mut name: T` parameter keeps its name. The `mut` is removed, as it is from a
/// by-value `mut self`, because the wrapper only moves the value on and keeping it would
/// trigger `unused_mut`.
///
/// Any other pattern (`_`, tuple/struct destructuring, `ref name`, `name @ ..`) is replaced by
/// a generated `__public_fn_arg<index>` binding. This way every parameter is forwarded exactly
/// once, with its declared type.
fn forwarding_inputs(
    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
    fn_name: &syn::Ident,
) -> (
    syn::punctuated::Punctuated<syn::FnArg, syn::Token![,]>,
    Vec<syn::Ident>,
) {
    let mut wrapper_inputs = inputs.clone();
    let mut arg_names = Vec::with_capacity(wrapper_inputs.len());

    for (index, arg) in wrapper_inputs.iter_mut().enumerate() {
        match arg {
            syn::FnArg::Receiver(receiver) => {
                // Only by-value `mut self`; for `&mut self` the `mut` belongs to the reference.
                if receiver.reference.is_none() {
                    receiver.mutability = None;
                }
            }
            syn::FnArg::Typed(pat_type) => {
                let kept = match &mut *pat_type.pat {
                    syn::Pat::Ident(pat_ident)
                        if pat_ident.by_ref.is_none() && pat_ident.subpat.is_none() =>
                    {
                        pat_ident.mutability = None;
                        Some(pat_ident.ident.clone())
                    }
                    _ => None,
                };
                let name = match kept {
                    Some(name) => name,
                    None => {
                        let generated =
                            syn::Ident::new(&format!("__public_fn_arg{index}"), fn_name.span());
                        *pat_type.pat = syn::Pat::Ident(syn::PatIdent {
                            attrs: Vec::new(),
                            by_ref: None,
                            mutability: None,
                            ident: generated.clone(),
                            subpat: None,
                        });
                        generated
                    }
                };
                arg_names.push(name);
            }
        }
    }

    (wrapper_inputs, arg_names)
}