mpc-macros 0.2.12

Arcium MPC Macros
Documentation
extern crate proc_macro;
use itertools::izip;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, GenericArgument, PathArguments, Type};

use crate::party_traits::{
    find_fields_in_struct,
    hasher::HASHER_STR,
    is_type,
    scribe::TRANSCRIPT_STR,
};

/// Lookup key for the session-id field: passed to `find_fields_in_struct` and used to index
/// its results. User-facing diagnostics refer to it as `#[session_id]` / `session_id`.
const SESSION_ID_STR: &str = "session_id";
/// Lookup key for nested party fields (reported to users as `#[party]`). Unlike the other
/// keys, any number of fields may match this one.
const PARTY_STR: &str = "party";

/// Builds a complete refresh expression for the value bound to `binding`, whose type is `ty`.
///
/// [`generate_refresh_for_iterable`] produces either a postfix chain (which needs a receiver
/// in front of it) or, for non-empty tuples, a block that destructures a binding named
/// `elem`. This helper bridges the two shapes so the result is always a valid expression.
fn refresh_expr_for_binding(binding: &proc_macro2::Ident, ty: &Type) -> proc_macro2::TokenStream {
    let refresh = generate_refresh_for_iterable(ty);
    match ty {
        Type::Tuple(tuple) if !tuple.elems.is_empty() => {
            if *binding == "elem" {
                // The tuple block already reads `elem`; no rebinding needed.
                refresh
            } else {
                // Rebind under the name the tuple block expects. The inner scope keeps the
                // shadowed `elem_N` names from leaking into the enclosing destructuring.
                quote! { { let elem = #binding; #refresh } }
            }
        }
        _ => quote! { #binding #refresh },
    }
}

/// Emits `.iter_mut().for_each(..)` that refreshes every element of a `Vec` or array whose
/// element type is `elem_ty`.
fn refresh_each_element(elem_ty: &Type) -> proc_macro2::TokenStream {
    let elem = quote::format_ident!("elem");
    let body = refresh_expr_for_binding(&elem, elem_ty);
    // The trailing `;` discards whatever the element's refresh evaluates to, so the closure
    // always returns `()` as `for_each` requires.
    quote! {
        .iter_mut().for_each(|#elem| { #body; })
    }
}

/// Helper function to generate refresh code for iterable types.
///
/// Recurses through `Vec<..>`, fixed-size arrays and tuples so that `.refresh()` ends up
/// being called on every leaf element.
///
/// The shape of the returned tokens depends on `ty`:
/// * non-empty tuple: a `{ .. }` block that destructures a binding named `elem`, which the
///   caller must have in scope (a `&mut` to the tuple);
/// * anything else (including `()`): a postfix chain starting with `.` that must be appended
///   to a receiver expression, e.g. `self.field #tokens;`.
fn generate_refresh_for_iterable(ty: &Type) -> proc_macro2::TokenStream {
    match ty {
        Type::Path(type_path) => {
            // Iterate over Vec types.
            if let Some(last_segment) = type_path.path.segments.last() {
                if last_segment.ident == "Vec" {
                    if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
                        if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
                            return refresh_each_element(inner_type);
                        }
                    }
                }
            }
        }
        Type::Tuple(tuple) if !tuple.elems.is_empty() => {
            // Tuples need to be destructured to access each element.
            let tuple_vars: Vec<_> = (0..tuple.elems.len())
                .map(|i| quote::format_ident!("elem_{}", i))
                .collect();

            let refresh_calls: Vec<_> = izip!(&tuple.elems, &tuple_vars)
                .map(|(elem_ty, var)| {
                    let call = refresh_expr_for_binding(var, elem_ty);
                    quote! { #call; }
                })
                .collect();

            // A trailing comma after every variable keeps the pattern a tuple pattern even
            // for 1-tuples: `(elem_0,)` rather than the parenthesised `(elem_0)`.
            return quote! {
                {
                    let (#(#tuple_vars,)*) = elem;
                    #(#refresh_calls)*
                }
            };
        }
        Type::Array(arr) => {
            // Arrays can be handled similarly to Vec.
            return refresh_each_element(&arr.elem);
        }
        _ => {}
    }

    // Base case: just call .refresh()
    quote! { .refresh() }
}

pub fn derive_party(input: TokenStream) -> TokenStream {
    // PARSING
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = &input.ident;
    let generics = &input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let field_results = match find_fields_in_struct(
        &input,
        [
            SESSION_ID_STR.to_string(),
            PARTY_STR.to_string(),
            TRANSCRIPT_STR.to_string(),
            HASHER_STR.to_string(),
        ],
    ) {
        Ok(results) => results,
        Err(e) => return e,
    };

    // VALIDATION
    // Validate session_id field. If there is one, ensure it's of type SessionId.
    let session_id_field = match &field_results[SESSION_ID_STR].as_slice() {
        [] => None,
        [(ident, ty)] if is_type(ty, "SessionId") => Some(ident.clone()),
        [(_, _)] => return quote! { compile_error!("Found a field marked with #[session_id] or named `session_id` but not of type `SessionId`.") }.into(),
        _ => return quote! { compile_error!("Found multiple fields marked with #[session_id] or named `session_id`. Only one is allowed.") }.into(),
    };

    // Validate transcript field
    let transcript_matches = &field_results[TRANSCRIPT_STR];
    if transcript_matches.len() > 1 {
        return quote! { compile_error!("Party derive macro found multiple fields marked with #[transcript] or named `transcript`. Only one is allowed.") }.into();
    }
    let transcript_field = transcript_matches.first().map(|(ident, _)| ident);

    // Validate hasher field
    let hasher_matches = &field_results[HASHER_STR];
    if hasher_matches.len() > 1 {
        return quote! { compile_error!("Party derive macro found multiple fields marked with #[hasher] or named `hasher`. Only one is allowed.") }.into();
    }
    let hasher_field = hasher_matches.first().map(|(ident, _)| ident.clone());

    // Extract party fields (can have multiple)
    let party_annotated_fields = &field_results[PARTY_STR];
    if session_id_field.is_none() && party_annotated_fields.is_empty() {
        return quote! {
            compile_error!("Party derive macro requires a field of type `SessionId` or a field annotated with #[party].");
        }.into();
    }

    // IMPLEMENTATIONS
    // Session ID
    let session_id_impl = match (session_id_field.as_ref(), party_annotated_fields.first()) {
        (Some(session_id), _) => quote! { &self.#session_id },
        (None, Some((party, _))) => quote! { self.#party.session_id() },
        (None, None) => quote! { unreachable!() },
    };

    // Protocol Name
    let protocol_name_impl = if session_id_field.is_some() {
        quote! { PROTOCOL_INFO.name().to_string() }
    } else {
        let (first_party_field, _) = party_annotated_fields.first().unwrap();
        let type_name = struct_name.to_string();
        quote! { format!("{} - {}", #type_name, self.#first_party_field.protocol_name()) }
    };

    // Refresh
    let refresh_session_id = if let Some(session_id) = session_id_field.as_ref() {
        if let Some(transcript) = transcript_field {
            quote! { self.#session_id.refresh_from(&mut self.#transcript); }
        } else {
            quote! {
                let protocol_name = crate::types::party::Party::protocol_name(&self);
                self.#session_id.refresh_with(&protocol_name);
            }
        }
    } else {
        quote! {}
    };

    let refresh_party_fields = party_annotated_fields.iter().map(|(ident, field_ty)| {
        let refresh_code = generate_refresh_for_iterable(field_ty);
        quote! {
            self.#ident #refresh_code;
        }
    });

    let refresh_hasher_fields = if let Some(hasher) = hasher_field.as_ref() {
        quote! { primitives::hashing::TweakableHasher::refresh(&mut self.#hasher, #session_id_impl); }
    } else {
        quote! {}
    };

    quote! {
        impl #impl_generics crate::types::party::Party for #struct_name #ty_generics #where_clause {
            fn session_id(&self) -> &primitives::types::SessionId {
                #session_id_impl
            }
            fn protocol_name(&self) -> String {
                #protocol_name_impl
            }
            fn refresh(&mut self) {
                #refresh_session_id
                #( #refresh_party_fields )*
                #refresh_hasher_fields
            }
        }
    }
    .into()
}