// mpc-macros 0.2.12
//
// Arcium MPC Macros
// Documentation
/*!
# Procedural Macros for Party Traits

This file provides one procedural macro per trait for the async-mpc framework. Each macro implements a single trait for a struct, based on the presence and type of its fields.

## Macros
- `Party`: Implements the `Party` trait.
- `Probabilistic`: Implements the `Probabilistic` trait.
- `Scribe`: Implements the `Scribe` trait.
- `HasTweakableHasher`: Implements the `HasTweakableHasher` trait.
- `Peer`: Implements the `Peer` trait.

See each macro's docstring for details and usage.
*/

pub mod hasher;
pub mod party;
pub mod peer;
pub mod probabilistic;
pub mod scribe;

pub use hasher::*;
pub use party::*;
pub use peer::*;
pub use probabilistic::*;
use proc_macro::TokenStream;
use quote::quote;
pub use scribe::*;
use syn::{Data, DeriveInput, Fields};

/// Collects the named fields of a struct that match a set of attribute or field names.
///
/// A field is a hit for a given name when either
/// - the field's own identifier equals the name, or
/// - the field carries an attribute whose path is exactly that single identifier
///   (e.g. `#[name]` or `#[name(...)]`).
///
/// A field is recorded at most once per name, even when it matches both ways.
///
/// # Arguments
/// - `input`: the parsed derive input to inspect.
/// - `attrs_or_names`: the attribute/field names to look for.
///
/// # Returns
/// A map with one entry per requested name (possibly with an empty `Vec`), holding the
/// `(Ident, Type)` of every matching field in declaration order.
///
/// # Errors
/// Returns a `TokenStream` containing a `compile_error!` invocation when `input` is not a
/// struct, or is a struct without named fields (tuple or unit struct).
fn find_fields_in_struct<I>(
    input: &DeriveInput,
    attrs_or_names: I,
) -> Result<std::collections::HashMap<String, Vec<(syn::Ident, syn::Type)>>, TokenStream>
where
    I: IntoIterator<Item = String>,
{
    // Validation.
    //
    // The `compile_error!` invocations use brace delimiters on purpose: derive output is parsed
    // in item position, where a parenthesised macro invocation without a trailing `;` triggers
    // an additional, unrelated parse error. The brace form is valid both as an item and as an
    // expression.
    //
    // NOTE(review): the messages name the `Party` trait although this helper looks shared by
    // the other derives of this crate — confirm against the callers before rewording.
    let named_struct_fields = match &input.data {
        Data::Struct(data_struct) => match &data_struct.fields {
            Fields::Named(fields) => fields,
            _ => {
                return Err(quote! {
                    compile_error! { "Party trait derive macro only supports structs with named fields." }
                }
                .into())
            }
        },
        _ => {
            return Err(quote! {
                compile_error! { "Party trait derive macro only supports structs." }
            }
            .into())
        }
    };

    // Seed the result map so that every requested name has an entry, even without hits.
    let mut results = attrs_or_names
        .into_iter()
        .map(|name| (name, Vec::new()))
        .collect::<std::collections::HashMap<_, _>>();

    // Check each field against all attributes/names.
    for field in &named_struct_fields.named {
        // The identifier does not depend on the name being tested, so resolve it once per field.
        let ident = match &field.ident {
            Some(ident) => ident,
            None => continue,
        };

        for (attr_or_name, matches) in results.iter_mut() {
            let has_attr = field
                .attrs
                .iter()
                .any(|attr| attr.path().is_ident(attr_or_name));

            if has_attr || ident == attr_or_name {
                matches.push((ident.clone(), field.ty.clone()));
            }
        }
    }

    Ok(results)
}

/// Reports whether `field_ty` is a path type whose final segment is named `ident_str`.
///
/// Only the identifier of the last path segment is compared; leading segments and any generic
/// arguments are ignored. Every non-path type, and a path without segments, yields `false`.
fn is_type(field_ty: &syn::Type, ident_str: &str) -> bool {
    match field_ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == ident_str),
        _ => false,
    }
}