mpc-macros 0.2.11

Arcium MPC Macros
Documentation
/*!
# MPC Macros

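This crate provides procedural macros for the async-mpc library:

1. `define_task!` - Generates async task implementations with dependency management
2. `TaskGetters` - Derives getter functions for task enum variants
3. `GateMethods` - Derives `Gate::map_labels` and `Gate::for_each_label`
4. `new_protocol!` / `dump_protocol_tags!` - Assign compile-time-unique protocol tags and dump
   the tag registry for debugging
5. `Party`, `Probabilistic`, `Scribe`, `HasTweakableHasher`, `Peer` - Derive the corresponding
   party traits by delegating to an annotated (or conventionally named) field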

## define_task! Macro

The `define_task!` macro automates the creation of task wrappers, async execution logic,
and dependency management for MPC computations.

## TaskGetters Derive Macro

The `TaskGetters` derive macro automatically generates `try_get_*_task` functions for
each variant in a `TaskType` enum, eliminating repetitive boilerplate code.

For detailed documentation and examples, see the documentation of each macro in this crate.
*/

use proc_macro::TokenStream;

mod define_task;
mod gate_methods;
mod party_traits;
mod protocol_tags;
mod task_getters;

/// Generates async-mpc task implementations with automatic dependency management.
///
/// This macro takes a struct definition and a compute function, and generates:
/// - An internal unresolved struct to hold task dependencies
/// - A public type alias using the `TaskWrapper` type
/// - A `Task` trait implementation with async execution logic
/// - A `new` constructor for creating new task instances
/// - A standalone `compute` method for direct calls
///
/// # Examples
///
/// The example uses types such as `Task`, `FieldShare` and `AbortError` that a proc-macro crate
/// cannot export, so it is marked `ignore` instead of being compiled as a doctest.
///
/// ```ignore
/// define_task! {
///     pub struct FieldAddTask<F: FieldExtension> {
///         x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
///         y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
///     }
///
///     async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
///         Ok(x + &y)
///     }
/// }
/// ```
#[proc_macro]
pub fn define_task(input: TokenStream) -> TokenStream {
    define_task::define_task(input)
}

/// Derives `try_get_*_task` functions for enum variants containing task types.
///
/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
/// their index (encoded in the Label's first field) and return the appropriate task type.
///
/// # Examples
///
/// The snippets below use types from outside this crate (a proc-macro crate can only export
/// macros), so they are marked `ignore`/`text` instead of being compiled as doctests.
///
/// ```ignore
/// #[derive(TaskGetters)]
/// pub enum TaskType<C: Curve> {
///     ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
///     BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
///     // ... more variants
/// }
/// ```
///
/// This generates functions like:
/// - `try_get_scalar_share_task` (for `ScalarShareTask`)
/// - `try_get_base_field_share_task` (for `BaseFieldShareTask`)
/// - `try_get_point_share_task` (assuming a `PointShareTask` variant among the elided ones)
///
/// Each function has the signature below. Here `<snake_case_variant>` is the variant name in
/// snake_case (e.g. `ScalarShareTask` becomes `scalar_share_task`), and `T` is the inner type
/// extracted from the variant:
///
/// ```text
/// pub fn try_get_<snake_case_variant><C: Curve>(
///     label: Label,
///     task_map: &[TaskType<C>],
/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
/// ```
///
/// ## Requirements
///
/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
/// - Each variant must have a corresponding `try_as_*` method (typically generated by
///   `EnumTryAsInner`)
#[proc_macro_derive(TaskGetters)]
pub fn derive_task_getters(input: TokenStream) -> TokenStream {
    task_getters::derive_task_getters(input)
}

/// Derives `Gate::map_labels` and `Gate::for_each_label`
#[proc_macro_derive(GateMethods)]
pub fn derive_gate_methods(input: TokenStream) -> TokenStream {
    gate_methods::derive_gate_methods(input)
}

/// Macro to ensure protocol tag uniqueness at compile time.
/// This macro generates a unique tag for a given protocol name,
/// and checks against a registry to prevent collisions.
/// If a collision is detected, it finds the next available tag.
#[proc_macro]
pub fn new_protocol(input: TokenStream) -> TokenStream {
    protocol_tags::new_protocol_info(input)
}

/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
#[proc_macro]
pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
    protocol_tags::dump_tags(input)
}
/// Procedural macro to automatically implement the `Party` trait for a struct.
///
/// # Requirements
/// - The struct must have either:
///   - A field marked with `#[session_id]` or a field of type `SessionId`, or
///   - At least one field annotated with `#[party]` (which must itself implement `Party`).
///
/// # Behavior
/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
///   `#[party]` field.
/// - `protocol_name()`: Uses the `PROTOCOL_NAME` constant if a direct `SessionId` field is present,
///   otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
///   `#[party]` fields. If a `hasher` field is present, it is reset.
///
/// Besides `#[party]` and `#[session_id]`, this derive also registers the `#[transcript]` and
/// `#[hasher]` helper attributes, so they may appear on the same struct.
///
/// # Usage
///
/// The example uses types defined outside this crate, so it is marked `ignore` rather than
/// compiled as a doctest.
///
/// ```ignore
/// #[derive(Party)]
/// pub struct MyParty {
///     #[session_id]
///     session_id: SessionId,
///     #[party]
///     inner: InnerPartyType,
///     // ...
/// }
/// // or, with type/name matching fallback:
/// #[derive(Party)]
/// pub struct MyParty {
///     session_id: SessionId,
///     #[party]
///     inner: InnerPartyType,
///     // ...
/// }
/// ```
#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
pub fn derive_party(input: TokenStream) -> TokenStream {
    party_traits::derive_party(input)
}

/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
///
/// # Requirements
/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
///   implements `CryptoRngCore`.
///
/// # Behavior
/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
///
/// # Usage
///
/// The example uses types defined outside this crate, so it is marked `ignore` rather than
/// compiled as a doctest.
///
/// ```ignore
/// #[derive(Probabilistic)]
/// pub struct MyParty {
///     #[rng]
///     my_rng: MyRngType,
///     // ...
/// }
/// // or, fallback (field named `rng`, no attribute):
/// #[derive(Probabilistic)]
/// pub struct MyParty {
///     rng: MyRngType,
///     // ...
/// }
/// ```
#[proc_macro_derive(Probabilistic, attributes(rng))]
pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
    party_traits::derive_probabilistic(input)
}

/// Procedural macro to automatically implement the `Scribe` trait for a struct.
///
/// # Requirements
/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
///   `transcript` that implements `Transcript`.
///
/// # Behavior
/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
///   field.
///
/// # Usage
///
/// The example uses types defined outside this crate, so it is marked `ignore` rather than
/// compiled as a doctest.
///
/// ```ignore
/// #[derive(Scribe)]
/// pub struct MyParty {
///     #[transcript]
///     my_transcript: MyTranscriptType,
///     // ...
/// }
/// // or, fallback (field named `transcript`, no attribute):
/// #[derive(Scribe)]
/// pub struct MyParty {
///     transcript: MyTranscriptType,
///     // ...
/// }
/// ```
#[proc_macro_derive(Scribe, attributes(transcript))]
pub fn derive_scribe(input: TokenStream) -> TokenStream {
    party_traits::derive_scribe(input)
}

/// Procedural macro to automatically implement the `HasTweakableHasher` trait for a struct.
///
/// # Requirements
/// - The struct must have a field marked with `#[hasher]` or, if not present, a field named
///   `hasher` whose type implements `TweakableHasher`.
///
/// # Behavior
/// - Implements the `get_hasher(&mut self)` method, returning a mutable reference to the selected
///   field.
/// - Sets the associated type `Hasher` to the type of the selected field.
///
/// # Usage
///
/// The example uses types defined outside this crate, so it is marked `ignore` rather than
/// compiled as a doctest.
///
/// ```ignore
/// #[derive(HasTweakableHasher)]
/// pub struct MyParty {
///     #[hasher]
///     my_hasher: MyHasherType,
///     // ...
/// }
/// // or, fallback (field named `hasher`, no attribute):
/// #[derive(HasTweakableHasher)]
/// pub struct MyParty {
///     hasher: MyHasherType,
///     // ...
/// }
/// ```
#[proc_macro_derive(HasTweakableHasher, attributes(hasher))]
pub fn derive_has_tweakable_hasher(input: TokenStream) -> TokenStream {
    party_traits::derive_has_tweakable_hasher(input)
}

/// Procedural macro to automatically implement the `Peer` trait for a struct.
///
/// # Requirements
/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
///   `peer_ctx` whose type is `PeerContext`.
///
/// Both `#[peer_ctx]` and `#[peer]` are registered as helper attributes, so either spelling
/// is accepted by the compiler on fields of a `#[derive(Peer)]` struct.
///
/// # Behavior
/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
///
/// # Usage
///
/// The example uses types defined outside this crate, so it is marked `ignore` rather than
/// compiled as a doctest.
///
/// ```ignore
/// #[derive(Peer)]
/// pub struct MyParty {
///     #[peer_ctx]
///     peer_ctx: PeerContext,
///     // ...
/// }
/// // or, fallback (field named `peer_ctx`, no attribute):
/// #[derive(Peer)]
/// pub struct MyParty {
///     peer_ctx: PeerContext,
///     // ...
/// }
/// ```
// NOTE(review): only `peer` used to be registered, so the documented `#[peer_ctx]` was rejected
// with "cannot find attribute `peer_ctx`". Both names are registered now to stay
// backward-compatible. Confirm which one `party_traits::derive_peer` actually inspects.
#[proc_macro_derive(Peer, attributes(peer, peer_ctx))]
pub fn derive_peer(input: TokenStream) -> TokenStream {
    party_traits::derive_peer(input)
}