// mpc-macros 0.4.1
//
// Arcium MPC Macros
// Documentation
//! `#[protocol_trait]` attribute macro for adding debug logging to impl blocks.

use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{
    parse_macro_input,
    spanned::Spanned,
    visit::Visit,
    Expr,
    ExprMethodCall,
    FnArg,
    ImplItem,
    ImplItemFn,
    ItemImpl,
    Pat,
    Type,
    TypeImplTrait,
    TypeParamBound,
};

/// Kind of network access a protocol method has, inferred from the `impl Trait`
/// bounds on one of its parameters.
enum NetworkType {
    /// A bound names `MultipartyInterface`. Carries the parameter's name (as
    /// produced by `Ident::to_string`) so generated code can refer to it.
    MultipartyInterface(String),
    /// A bound names `IoSink`.
    Sink,
    /// A bound names `IoStream`.
    Stream,
}

/// Builds an expression (evaluating to a `String`) describing how a method communicates.
///
/// If any parameter is an `impl MultipartyInterface`, the first such parameter wins. The
/// generated expression then formats that interface's `context_id` and `local_peer_id`.
/// Otherwise the description is chosen from whether `IoSink` and/or `IoStream` parameters
/// are present, falling back to `"local execution"` when neither is.
fn determine_network_info(network_types: &[NetworkType]) -> TokenStream2 {
    let interface_param = network_types.iter().find_map(|t| match t {
        NetworkType::MultipartyInterface(name) => Some(name.as_str()),
        NetworkType::Sink | NetworkType::Stream => None,
    });

    if let Some(net_param) = interface_param {
        let span = proc_macro2::Span::call_site();
        // `Ident::to_string` keeps the `r#` prefix of raw identifiers (e.g. `r#type`),
        // and `Ident::new` panics on such input, so raw names are rebuilt with `new_raw`.
        let net_ident = match net_param.strip_prefix("r#") {
            Some(raw) => syn::Ident::new_raw(raw, span),
            None => syn::Ident::new(net_param, span),
        };
        return quote! {
            format!("network interface {:?} as peer {:?}",
                network::context::Peer::context_id(&#net_ident),
                network::context::Peer::local_peer_id(&#net_ident))
        };
    }

    let has_sink = network_types.iter().any(|t| matches!(t, NetworkType::Sink));
    let has_stream = network_types
        .iter()
        .any(|t| matches!(t, NetworkType::Stream));
    match (has_sink, has_stream) {
        (true, true) => quote! { "sink/stream channels".to_string() },
        (true, false) => quote! { "a sink channel".to_string() },
        (false, true) => quote! { "a stream channel".to_string() },
        (false, false) => quote! { "local execution".to_string() },
    }
}

/// Returns `true` when `bound` is a trait bound whose path mentions `trait_name` in any
/// segment (so both `IoSink` and `some::path::IoSink` match).
///
/// Lifetime and other non-trait bounds never match.
fn bound_contains_trait(bound: &TypeParamBound, trait_name: &str) -> bool {
    let TypeParamBound::Trait(trait_bound) = bound else {
        return false;
    };
    trait_bound
        .path
        .segments
        .iter()
        .any(|segment| segment.ident == trait_name)
}

/// Collects the network capabilities advertised by `func`'s parameters.
///
/// Only parameters of the form `name: impl Bound + ...` (a plain identifier pattern with an
/// `impl Trait` type) are inspected. Each bound contributes at most one entry, checked in
/// priority order: `MultipartyInterface`, then `IoSink`, then `IoStream`. Entries appear in
/// parameter order, then bound order. Receivers, destructuring patterns and non-`impl`
/// types are skipped.
fn extract_network_types(func: &ImplItemFn) -> Vec<NetworkType> {
    let mut found = Vec::new();
    for arg in &func.sig.inputs {
        let FnArg::Typed(pat_type) = arg else {
            continue;
        };
        let Pat::Ident(pat_ident) = &*pat_type.pat else {
            continue;
        };
        let Type::ImplTrait(TypeImplTrait { bounds, .. }) = &*pat_type.ty else {
            continue;
        };
        let param_name = pat_ident.ident.to_string();
        for bound in bounds {
            let kind = if bound_contains_trait(bound, "MultipartyInterface") {
                NetworkType::MultipartyInterface(param_name.clone())
            } else if bound_contains_trait(bound, "IoSink") {
                NetworkType::Sink
            } else if bound_contains_trait(bound, "IoStream") {
                NetworkType::Stream
            } else {
                continue;
            };
            found.push(kind);
        }
    }
    found
}

fn has_ref_receiver(func: &ImplItemFn) -> bool {
    func.sig
        .inputs
        .iter()
        .any(|arg| matches!(arg, FnArg::Receiver(r) if r.reference.is_some()))
}

/// Generates the `log::debug!` statement prepended to a protocol method.
///
/// The emitted message has the shape
/// `<Struct - protocol-name> [Trait.]method with <session id> over <network info>`.
/// `PROTOCOL_INFO` and `crate::types::party::Party` must be resolvable where the macro
/// is expanded.
fn generate_log_statement(
    struct_name: &syn::Ident,
    trait_name: Option<&str>,
    func: &ImplItemFn,
) -> TokenStream2 {
    let method = func.sig.ident.to_string();
    let func_display = match trait_name {
        Some(t) => format!("{t}.{method}"),
        None => method,
    };
    let struct_name_str = struct_name.to_string();
    let network_info = determine_network_info(&extract_network_types(func));

    // `self` is already a reference in `&self`/`&mut self` methods; a by-value `self`
    // has to be borrowed before being handed to `Party::session_id`.
    let self_ref = if has_ref_receiver(func) {
        quote!(self)
    } else {
        quote!(&self)
    };

    quote! {
        log::debug!("<{} - {}> {} with {:?} over {}",
            #struct_name_str, PROTOCOL_INFO.name(), #func_display,
            crate::types::party::Party::session_id(#self_ref), #network_info);
    }
}

/// Prepends the debug-log statement from `generate_log_statement` to `func`'s body.
///
/// Functions without a `self` receiver (associated functions such as constructors) are
/// returned unchanged. The generated statement reads the session id through `self`, so
/// injecting it there could never compile.
///
/// The statement is inserted into the existing block instead of rebuilding the block, so
/// the original brace spans (used by compiler diagnostics) are preserved.
fn inject_logging(
    mut func: ImplItemFn,
    struct_name: &syn::Ident,
    trait_name: Option<&str>,
) -> ImplItemFn {
    if func.sig.receiver().is_none() {
        return func;
    }
    let log_stmt = generate_log_statement(struct_name, trait_name, &func);
    let log_stmt: syn::Stmt = syn::parse_quote! { #log_stmt };
    func.block.stmts.insert(0, log_stmt);
    func
}

/// Returns the name of the trait being implemented, or `None` for an inherent impl.
///
/// Only the last path segment's identifier is used (`a::b::Trait<T>` yields `"Trait"`).
/// The whole path's token string is the fallback if the path has no segments.
fn extract_trait_name(impl_block: &ItemImpl) -> Option<String> {
    let (_, trait_path, _) = impl_block.trait_.as_ref()?;
    let name = match trait_path.segments.last() {
        Some(segment) => segment.ident.to_string(),
        None => trait_path.to_token_stream().to_string(),
    };
    Some(name)
}

/// Syntax visitor that flips its flag to `true` once it sees a `self.refresh()` call.
///
/// The flag is never reset to `false`. The receiver must be the bare path `self`.
/// NOTE: syn treats macro invocation bodies as opaque tokens, so a call written inside a
/// macro (e.g. `some_macro!(self.refresh())`) is not detected.
struct RefreshVisitor(bool);

impl<'ast> Visit<'ast> for RefreshVisitor {
    fn visit_expr_method_call(&mut self, call: &'ast ExprMethodCall) {
        let receiver_is_self =
            matches!(&*call.receiver, Expr::Path(path) if path.path.is_ident("self"));
        if call.method == "refresh" && receiver_is_self {
            self.0 = true;
        }
        // Keep descending: the receiver and arguments may contain further calls.
        syn::visit::visit_expr_method_call(self, call);
    }
}

fn has_ref_self_fn(impl_block: &ItemImpl) -> bool {
    impl_block.items.iter().any(|item| {
        matches!(item, ImplItem::Fn(f) if f.sig.inputs.iter().any(
        |a| matches!(a, FnArg::Receiver(r) if r.reference.is_some())))
    })
}

/// Returns `true` if any method body in the impl block contains a `self.refresh()` call.
///
/// Only `fn` items are searched; associated consts, types and macros are ignored.
fn contains_refresh_call(impl_block: &ItemImpl) -> bool {
    impl_block.items.iter().any(|item| {
        let ImplItem::Fn(method) = item else {
            return false;
        };
        let mut finder = RefreshVisitor(false);
        finder.visit_impl_item_fn(method);
        finder.0
    })
}

pub fn protocol_trait_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
    let _ = attr;
    let mut impl_block = parse_macro_input!(item as ItemImpl);

    if has_ref_self_fn(&impl_block) && !contains_refresh_call(&impl_block) {
        return syn::Error::new(
            impl_block.span(),
            "Protocol impl block must contain at least one `self.refresh()` call",
        )
        .to_compile_error()
        .into();
    }

    let trait_name = extract_trait_name(&impl_block);
    let struct_name = match &*impl_block.self_ty {
        Type::Path(tp) => tp
            .path
            .segments
            .last()
            .map(|s| s.ident.clone())
            .unwrap_or_else(|| syn::Ident::new("Unknown", proc_macro2::Span::call_site())),
        _ => syn::Ident::new("Unknown", proc_macro2::Span::call_site()),
    };

    impl_block.items = impl_block
        .items
        .into_iter()
        .map(|item| match item {
            ImplItem::Fn(f) => ImplItem::Fn(inject_logging(f, &struct_name, trait_name.as_deref())),
            other => other,
        })
        .collect();

    quote! { #impl_block }.into()
}