// mpc-macros 0.2.10: Arcium MPC Macros
/*!
# MPC Macros

This crate provides procedural macros for the async-mpc library:

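1. `define_task!` - Generates async task implementations with dependency management
2. `TaskGetters` - Derives getter functions for task enum variants
3. `GateMethods` - Derives the `Gate::map_labels` and `Gate::for_each_label` methods
4. `new_protocol!` - Assigns each protocol name a tag that is unique at compile time
5. `dump_protocol_tags!` - Dumps the compile-time protocol tag registry for debugging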

## define_task! Macro

The `define_task!` macro automates the creation of task wrappers, async execution logic,
and dependency management for MPC computations.
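
As a rough illustration of that pattern, the sketch below hand-writes the kind of code the macro saves you from writing. It is not the macro's actual expansion: `Task`, `TaskFuture`, `AbortError`, `Constant`, `AddTask`, `resolve`, and `block_on` are hypothetical stand-ins, plain `u64` values replace field shares, and the `TaskWrapper` alias (with whatever sharing or caching it performs) is omitted. It shows the struct of unresolved dependencies, a `new` constructor, a standalone `compute`, and a `Task` impl that awaits each dependency before calling `compute`.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical stand-ins for the async-mpc types; not the crate's real API.
#[derive(Debug)]
struct AbortError;

// Boxed future so that `Task` stays object-safe (`Arc<dyn Task<...>>`).
type TaskFuture<'a, T> = Pin<Box<dyn Future<Output = Result<Arc<T>, AbortError>> + 'a>>;

trait Task {
    type Output;
    fn execute(&self) -> TaskFuture<'_, Self::Output>;
}

// Leaf task that yields an already-known value.
struct Constant(Arc<u64>);

impl Task for Constant {
    type Output = u64;
    fn execute(&self) -> TaskFuture<'_, u64> {
        let value = Arc::clone(&self.0);
        Box::pin(async move { Ok(value) })
    }
}

// Hand-written analogue of a generated task: the struct holds the
// unresolved dependencies.
struct AddTask {
    x: Arc<dyn Task<Output = u64>>,
    y: Arc<dyn Task<Output = u64>>,
}

impl AddTask {
    // `new` constructor.
    fn new(x: Arc<dyn Task<Output = u64>>, y: Arc<dyn Task<Output = u64>>) -> Self {
        Self { x, y }
    }

    // Standalone `compute`, callable directly on resolved inputs.
    async fn compute(x: u64, y: u64) -> Result<u64, AbortError> {
        Ok(x + y)
    }

    // Awaits each dependency in turn, then passes the values to `compute`.
    async fn resolve(&self) -> Result<Arc<u64>, AbortError> {
        let x = self.x.execute().await?;
        let y = self.y.execute().await?;
        Ok(Arc::new(Self::compute(*x, *y).await?))
    }
}

impl Task for AddTask {
    type Output = u64;
    fn execute(&self) -> TaskFuture<'_, u64> {
        Box::pin(self.resolve())
    }
}

// Minimal executor so the sketch runs without an async runtime. It busy-polls,
// which is acceptable here only because every future above is immediately ready.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn block_on<F: Future>(future: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut future = Box::pin(future);
    loop {
        if let Poll::Ready(output) = future.as_mut().poll(&mut cx) {
            return output;
        }
    }
}
```

Because `AddTask` is itself a `Task`, it can be wrapped in an `Arc` and used as a dependency of another `AddTask`, which is how task graphs compose. This sketch awaits the dependencies one after another; it does not attempt to reproduce how the generated code schedules them.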

## TaskGetters Derive Macro

The `TaskGetters` derive macro automatically generates `try_get_*_task` functions for
each variant in a `TaskType` enum, eliminating repetitive boilerplate code.
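
A hand-written equivalent of one generated getter might look like the sketch below. It is an approximation under assumptions: `Task`, `Constant`, `Label`, `ProtocolError` (and its variants), and `DynTask` are hypothetical simplifications (no `C: Curve` generic, plain integers instead of shares), and `try_as_scalar_share_task` stands in for the accessor that `EnumTryAsInner` would provide. The `getter_name` helper shows how a variant name such as `ScalarShareTask` maps to `try_get_scalar_share_task`; the macro's own case conversion may differ in edge cases.

```rust
use std::sync::Arc;

// Hypothetical, simplified stand-ins for the async-mpc types.
trait Task {
    type Output;
    fn output(&self) -> Self::Output;
}

// Leaf task that always yields a clone of its value.
struct Constant<T>(T);

impl<T: Clone> Task for Constant<T> {
    type Output = Arc<T>;
    fn output(&self) -> Arc<T> {
        Arc::new(self.0.clone())
    }
}

// The first field carries the task's index into the task map;
// the second is a placeholder for any other label data.
#[derive(Clone, Copy, Debug)]
struct Label(usize, usize);

#[derive(Debug, PartialEq)]
enum ProtocolError {
    MissingTask(usize),
    WrongTaskType(usize),
}

type DynTask<T> = Arc<dyn Task<Output = Arc<T>>>;

enum TaskType {
    ScalarShareTask(DynTask<u64>),
    BaseFieldShareTask(DynTask<u32>),
}

impl TaskType {
    // Stand-in for an `EnumTryAsInner`-style accessor; the real signature may differ.
    fn try_as_scalar_share_task(&self) -> Option<&DynTask<u64>> {
        match self {
            TaskType::ScalarShareTask(task) => Some(task),
            _ => None,
        }
    }
}

// What the derive might emit for the `ScalarShareTask` variant.
fn try_get_scalar_share_task(
    label: Label,
    task_map: &[TaskType],
) -> Result<DynTask<u64>, ProtocolError> {
    let index = label.0;
    task_map
        .get(index)
        .ok_or(ProtocolError::MissingTask(index))?
        .try_as_scalar_share_task()
        .cloned()
        .ok_or(ProtocolError::WrongTaskType(index))
}

// Variant name to getter name: insert `_` before each inner uppercase letter,
// lowercase everything, and prefix `try_get_`.
fn getter_name(variant: &str) -> String {
    let mut snake = String::new();
    for (i, ch) in variant.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                snake.push('_');
            }
            snake.extend(ch.to_lowercase());
        } else {
            snake.push(ch);
        }
    }
    format!("try_get_{snake}")
}
```

Looking up an index that holds a different variant yields `WrongTaskType`, and an out-of-range index yields `MissingTask`; the derive exists so that this match-and-clone body does not have to be repeated for every variant.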

For detailed documentation and examples, see the documentation of each macro.
*/

use proc_macro::TokenStream;

mod define_task;
mod gate_methods;
mod protocol_tags;
mod task_getters;

/// Generates async-mpc task implementations with automatic dependency management.
///
/// This macro takes a struct definition and a compute function, and generates:
/// - An internal unresolved struct to hold task dependencies
/// - A public type alias using the `TaskWrapper` type
/// - A `Task` trait implementation with async execution logic
/// - A `new` constructor for creating new task instances
/// - A standalone `compute` method for direct calls
///
/// # Examples
///
/// ```rust,ignore
/// define_task! {
///     pub struct FieldAddTask<F: FieldExtension> {
///         x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
///         y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
///     }
///
///     async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
///         Ok(x + &y)
///     }
/// }
/// ```
#[proc_macro]
pub fn define_task(input: TokenStream) -> TokenStream {
    define_task::define_task(input)
}

/// Derives `try_get_*_task` functions for enum variants containing task types.
///
/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks by
/// their index (encoded in the first field of the `Label`) and return the appropriate task type.
///
/// # Examples
///
/// ```rust,ignore
/// #[derive(TaskGetters)]
/// pub enum TaskType<C: Curve> {
///     ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
///     BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
///     // ... more variants
/// }
/// ```
///
/// This generates one function per variant, for example:
/// - `try_get_scalar_share_task`
/// - `try_get_base_field_share_task`
/// - `try_get_point_share_task` (from a `PointShareTask` variant among those elided above)
///
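/// Each function has a signature of the following shape, where `<variant_name>` is the
/// snake_case variant name without its `Task` suffix and `T` is the extracted inner type:
/// ```rust,ignore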
/// pub fn try_get_<variant_name>_task<C: Curve>(
///     label: Label,
///     task_map: &[TaskType<C>],
/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
/// ```
///
/// # Requirements
///
/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
/// - Each variant must have a corresponding `try_as_*` method (typically generated by
///   `EnumTryAsInner`)
#[proc_macro_derive(TaskGetters)]
pub fn derive_task_getters(input: TokenStream) -> TokenStream {
    task_getters::derive_task_getters(input)
}

/// Derives the `Gate::map_labels` and `Gate::for_each_label` methods.
#[proc_macro_derive(GateMethods)]
pub fn derive_gate_methods(input: TokenStream) -> TokenStream {
    gate_methods::derive_gate_methods(input)
}

/// Assigns a protocol tag that is unique at compile time.
///
/// Generates a tag for the given protocol name and checks it against a compile-time
/// registry of tags already assigned. If the tag collides with an existing entry, the
/// macro moves on to the next available tag.
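///
/// The collision handling can be modelled at runtime with the sketch below. It is
/// hypothetical: the real macro's tag type, tag derivation, and registry storage are not
/// described here, so `TagRegistry`, the `u8` tag space, the byte-sum `initial_tag`, and
/// the choice to return the existing tag when a name is registered twice are all
/// assumptions. `dump` is only a rough analogue of `dump_protocol_tags!`.
///
/// ```rust
/// use std::collections::BTreeMap;
///
/// // Hypothetical runtime model; the real macro keeps its registry at compile time.
/// #[derive(Default)]
/// struct TagRegistry {
///     tags: BTreeMap<u8, String>,
/// }
///
/// impl TagRegistry {
///     // Hypothetical starting tag: a wrapping byte sum, chosen so collisions are easy to show.
///     fn initial_tag(name: &str) -> u8 {
///         name.bytes().fold(0u8, |acc, b| acc.wrapping_add(b))
///     }
///
///     // Returns the tag for `name`, stepping to the next tag while the slot
///     // belongs to a different name.
///     fn register(&mut self, name: &str) -> Option<u8> {
///         let mut tag = Self::initial_tag(name);
///         for _ in 0..256 {
///             match self.tags.get(&tag) {
///                 None => {
///                     self.tags.insert(tag, name.to_string());
///                     return Some(tag);
///                 }
///                 // Assumption: registering the same name again returns its existing tag.
///                 Some(existing) if existing == name => return Some(tag),
///                 Some(_) => tag = tag.wrapping_add(1),
///             }
///         }
///         None // every one of the 256 tags is taken
///     }
///
///     // Snapshot of the registry in tag order.
///     fn dump(&self) -> Vec<(u8, &str)> {
///         self.tags.iter().map(|(tag, name)| (*tag, name.as_str())).collect()
///     }
/// }
/// ```
///
/// With this byte-sum rule, `"ab"` and `"ba"` both start at tag 195, so the second
/// name to register is moved to 196.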
#[proc_macro]
pub fn new_protocol(input: TokenStream) -> TokenStream {
    protocol_tags::new_protocol_info(input)
}

/// Dumps the current state of the compile-time protocol tag registry, for debugging.
#[proc_macro]
pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
    protocol_tags::dump_tags(input)
}