mpc-macros 0.4.1-rc.0

Arcium MPC Macros
Documentation
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};

use crate::party_traits::{find_fields_in_struct, is_type};

/// Lookup key handed to `find_fields_in_struct` to locate the peer-context field.
///
/// Judging by the diagnostics emitted in `derive_peer`, this matches either a
/// `#[peer_ctx]` attribute or a field literally named `peer_ctx`; the exact
/// matching rules live in `find_fields_in_struct` (not visible here).
const PEER_STR: &str = "peer_ctx";

pub fn derive_peer(input: TokenStream) -> TokenStream {
    // Parsing
    let input = parse_macro_input!(input as DeriveInput);
    let peer_ctx_fields = match find_fields_in_struct(&input, [PEER_STR.to_string()]) {
        Ok(mut results) => results.remove(PEER_STR).unwrap_or_default(),
        Err(e) => return e,
    };

    // Validation
    let (peer_ctx_ident, is_direct) = match peer_ctx_fields.as_slice() {
        [(ident, ty)] if is_type(ty, "PeerContext") => (ident, true),
        [(ident, _)] => (ident, false), // Delegate to field implementing Peer
        [] => return quote! { compile_error!("Found no field marked with #[peer_ctx] or a field named `peer_ctx`.") }.into(),
        _ => return quote! { compile_error!("Found multiple fields marked with #[peer_ctx] or named `peer_ctx`. Only one is allowed.") }.into(),
    };

    // Determine the crate path based on whether we're inside the network crate
    let network_crate = if std::env::var("CARGO_PKG_NAME")
        .map(|s| s == "network")
        .unwrap_or(false)
    {
        quote! { crate }
    } else {
        quote! { network }
    };

    // Implementation
    let struct_name = &input.ident;
    let generics = &input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let peer_context_impl = if is_direct {
        quote! { &self.#peer_ctx_ident }
    } else {
        quote! { self.#peer_ctx_ident.peer_context() }
    };

    quote! {
        impl #impl_generics #network_crate::context::Peer for #struct_name #ty_generics #where_clause {
            fn peer_context(&self) -> &#network_crate::context::PeerContext {
                #peer_context_impl
            }
        }
    }
    .into()
}