mpc_macros/
lib.rs

1/*!
2# MPC Macros
3
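This crate provides procedural macros for the async-mpc library:

1. `define_task!` - Generates async task implementations with dependency management
2. `TaskGetters` - Derives getter functions for task enum variants
3. `GateMethods` - Derives `Gate::map_labels` and `Gate::for_each_label`
4. `new_protocol!` - Generates a unique protocol tag, checked against a registry at compile time
5. `dump_protocol_tags!` - Dumps the protocol tag registry at compile time (for debugging)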
8
9## define_task! Macro
10
11The `define_task!` macro automates the creation of task wrappers, async execution logic,
12and dependency management for MPC computations.
13
14## TaskGetters Derive Macro
15
16The `TaskGetters` derive macro automatically generates `try_get_*_task` functions for
17each variant in a `TaskType` enum, eliminating repetitive boilerplate code.
18
19For detailed documentation and examples, see the individual implementation modules.
20*/
21
22use proc_macro::TokenStream;
23
24mod define_task;
25mod gate_methods;
26mod protocol_tags;
27mod task_getters;
28
29/// Generates async-mpc task implementations with automatic dependency management.
30///
31/// This macro takes a struct definition and a compute function, and generates:
32/// - An internal unresolved struct to hold task dependencies
33/// - A public type alias using the `TaskWrapper` type
34/// - A Task trait implementation with async execution logic
35/// - A `new` constructor for creating new task instances
36/// - A standalone `compute` method for direct calls
37///
38/// # Examples
39///
40/// ```rust
41/// define_task! {
42///     pub struct FieldAddTask<F: FieldExtension> {
43///         x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
44///         y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
45///     }
46///
47///     async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
48///         Ok(x + &y)
49///     }
50/// }
51/// ```
52#[proc_macro]
53pub fn define_task(input: TokenStream) -> TokenStream {
54    define_task::define_task(input)
55}
56
57/// Derives `try_get_*_task` functions for enum variants containing task types.
58///
59/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
60/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
61/// their index (encoded in the Label's first field) and return the appropriate task type.
62///
63/// # Examples
64///
65/// ```rust
66/// #[derive(TaskGetters)]
67/// pub enum TaskType<C: Curve> {
68///     ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
69///     BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
70///     // ... more variants
71/// }
72/// ```
73///
74/// This generates functions like:
75/// - `try_get_scalar_share_task`
76/// - `try_get_base_field_share_task`
77/// - `try_get_point_share_task`
78///
79/// Each function has the signature:
80/// ```rust
81/// pub fn try_get_<variant_name>_task<C: Curve>(
82///     label: Label,
83///     task_map: &[TaskType<C>],
84/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
85/// ```
86///
87/// ## Requirements
88///
89/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
90/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
91/// - Each variant must have a corresponding `try_as_*` method (typically generated by
92///   `EnumTryAsInner`)
93#[proc_macro_derive(TaskGetters)]
94pub fn derive_task_getters(input: TokenStream) -> TokenStream {
95    task_getters::derive_task_getters(input)
96}
97
98/// Derives `Gate::map_labels` and `Gate::for_each_label`
99#[proc_macro_derive(GateMethods)]
100pub fn derive_gate_methods(input: TokenStream) -> TokenStream {
101    gate_methods::derive_gate_methods(input)
102}
103
104/// Macro to ensure protocol tag uniqueness at compile time.
105/// This macro generates a unique tag for a given protocol name,
106/// and checks against a registry to prevent collisions.
107/// If a collision is detected, it finds the next available tag.
108#[proc_macro]
109pub fn new_protocol(input: TokenStream) -> TokenStream {
110    protocol_tags::new_protocol_info(input)
111}
112
113/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
114#[proc_macro]
115pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
116    protocol_tags::dump_tags(input)
117}