mpc_macros/lib.rs
/*!
# MPC Macros

This crate provides procedural macros for the async-mpc library:

1. `define_task!` - Generates async task implementations with dependency management
2. `TaskGetters` - Derives getter functions for task enum variants
3. `GateMethods` - Derives `Gate::map_labels` and `Gate::for_each_label`
4. `new_protocol!` / `dump_protocol_tags!` - Compile-time protocol tag uniqueness checks and
   registry inspection
5. `Party`, `Probabilistic`, `Scribe`, `HasTweakableHasher`, `Peer` - Derive the corresponding
   party traits by delegating to annotated (or conventionally named) fields

## define_task! Macro

The `define_task!` macro automates the creation of task wrappers, async execution logic,
and dependency management for MPC computations.

## TaskGetters Derive Macro

The `TaskGetters` derive macro automatically generates `try_get_*_task` functions for
each variant in a `TaskType` enum, eliminating repetitive boilerplate code.

For detailed documentation and examples, see the individual implementation modules.
*/
21
22use proc_macro::TokenStream;
23
24mod define_task;
25mod gate_methods;
26mod party_traits;
27mod protocol_tags;
28mod task_getters;
29
30/// Generates async-mpc task implementations with automatic dependency management.
31///
32/// This macro takes a struct definition and a compute function, and generates:
33/// - An internal unresolved struct to hold task dependencies
34/// - A public type alias using the `TaskWrapper` type
35/// - A Task trait implementation with async execution logic
36/// - A `new` constructor for creating new task instances
37/// - A standalone `compute` method for direct calls
38///
39/// # Examples
40///
41/// ```rust
42/// define_task! {
43/// pub struct FieldAddTask<F: FieldExtension> {
44/// x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
45/// y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
46/// }
47///
48/// async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
49/// Ok(x + &y)
50/// }
51/// }
52/// ```
53#[proc_macro]
54pub fn define_task(input: TokenStream) -> TokenStream {
55 define_task::define_task(input)
56}
57
58/// Derives `try_get_*_task` functions for enum variants containing task types.
59///
60/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
61/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
62/// their index (encoded in the Label's first field) and return the appropriate task type.
63///
64/// # Examples
65///
66/// ```rust
67/// #[derive(TaskGetters)]
68/// pub enum TaskType<C: Curve> {
69/// ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
70/// BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
71/// // ... more variants
72/// }
73/// ```
74///
75/// This generates functions like:
76/// - `try_get_scalar_share_task`
77/// - `try_get_base_field_share_task`
78/// - `try_get_point_share_task`
79///
80/// Each function has the signature:
81/// ```rust
82/// pub fn try_get_<variant_name>_task<C: Curve>(
83/// label: Label,
84/// task_map: &[TaskType<C>],
85/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
86/// ```
87///
88/// ## Requirements
89///
90/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
91/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
92/// - Each variant must have a corresponding `try_as_*` method (typically generated by
93/// `EnumTryAsInner`)
94#[proc_macro_derive(TaskGetters)]
95pub fn derive_task_getters(input: TokenStream) -> TokenStream {
96 task_getters::derive_task_getters(input)
97}
98
99/// Derives `Gate::map_labels` and `Gate::for_each_label`
100#[proc_macro_derive(GateMethods)]
101pub fn derive_gate_methods(input: TokenStream) -> TokenStream {
102 gate_methods::derive_gate_methods(input)
103}
104
105/// Macro to ensure protocol tag uniqueness at compile time.
106/// This macro generates a unique tag for a given protocol name,
107/// and checks against a registry to prevent collisions.
108/// If a collision is detected, it finds the next available tag.
109#[proc_macro]
110pub fn new_protocol(input: TokenStream) -> TokenStream {
111 protocol_tags::new_protocol_info(input)
112}
113
114/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
115#[proc_macro]
116pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
117 protocol_tags::dump_tags(input)
118}
119/// Procedural macro to automatically implement the `Party` trait for a struct.
120///
121/// # Requirements
122/// - The struct must have either:
123/// - A field marked with `#[session_id]` or a field of type `SessionId`, or
124/// - At least one field annotated with `#[party]` (which must itself implement `Party`).
125///
126/// # Behavior
127/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
128/// `#[party]` field.
129/// - `protocol_name()`: Uses the `PROTOCOL_NAME` constant if a direct `SessionId` field is present,
130/// otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
131/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
132/// `#[party]` fields. If a `hasher` field is present, it is reset.
133///
134/// # Usage
135/// ```rust
136/// #[derive(Party)]
137/// pub struct MyParty {
138/// #[session_id]
139/// session_id: SessionId,
140/// #[party]
141/// inner: InnerPartyType,
142/// // ...
143/// }
144/// // or, with type/name matching fallback:
145/// #[derive(Party)]
146/// pub struct MyParty { session_id: SessionId, #[party] inner: InnerPartyType, ... }
147/// ```
148#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
149pub fn derive_party(input: TokenStream) -> TokenStream {
150 party_traits::derive_party(input)
151}
152
153/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
154///
155/// # Requirements
156/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
157/// implements `CryptoRngCore`.
158///
159/// # Behavior
160/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
161///
162/// # Usage
163/// ```rust
164/// #[derive(Probabilistic)]
165/// pub struct MyParty {
166/// #[rng]
167/// my_rng: MyRngType,
168/// // ...
169/// }
170/// // or, fallback:
171/// #[derive(Probabilistic)]
172/// pub struct MyParty { rng: MyRngType, ... }
173/// ```
174#[proc_macro_derive(Probabilistic, attributes(rng))]
175pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
176 party_traits::derive_probabilistic(input)
177}
178
179/// Procedural macro to automatically implement the `Scribe` trait for a struct.
180///
181/// # Requirements
182/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
183/// `transcript` that implements `Transcript`.
184///
185/// # Behavior
186/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
187/// field.
188///
189/// # Usage
190/// ```rust
191/// #[derive(Scribe)]
192/// pub struct MyParty {
193/// #[transcript]
194/// my_transcript: MyTranscriptType,
195/// // ...
196/// }
197/// // or, fallback:
198/// #[derive(Scribe)]
199/// pub struct MyParty { transcript: MyTranscriptType, ... }
200/// ```
201#[proc_macro_derive(Scribe, attributes(transcript))]
202pub fn derive_scribe(input: TokenStream) -> TokenStream {
203 party_traits::derive_scribe(input)
204}
205
206/// Procedural macro to automatically implement the `HasTweakableHasher` trait for a struct.
207///
208/// # Requirements
209/// - The struct must have a field marked with `#[hasher]` or, if not present, a field named
210/// `hasher` whose type implements `TweakableHasher`.
211///
212/// # Behavior
213/// - Implements the `get_hasher(&mut self)` method, returning a mutable reference to the selected
214/// field.
215/// - Sets the associated type `Hasher` to the type of the selected field.
216///
217/// # Usage
218/// ```rust
219/// #[derive(HasTweakableHasher)]
220/// pub struct MyParty {
221/// #[hasher]
222/// my_hasher: MyHasherType,
223/// // ...
224/// }
225/// // or, fallback:
226/// #[derive(HasTweakableHasher)]
227/// pub struct MyParty { hasher: MyHasherType, ... }
228/// ```
229#[proc_macro_derive(HasTweakableHasher, attributes(hasher))]
230pub fn derive_has_tweakable_hasher(input: TokenStream) -> TokenStream {
231 party_traits::derive_has_tweakable_hasher(input)
232}
233
234/// Procedural macro to automatically implement the `Peer` trait for a struct.
235///
236/// # Requirements
237/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
238/// `peer_ctx` whose type is PeerContext.
239///
240/// # Behavior
241/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
242///
243/// # Usage
244/// ```rust
245/// #[derive(Peer)]
246/// pub struct MyParty {
247/// #[peer_ctx]
248/// peer_ctx: PeerContext,
249/// // ...
250/// }
251/// // or, fallback to type matching:
252/// #[derive(Peer)]
253/// pub struct MyParty { peer_ctx: PeerContext, ... }
254/// ```
255#[proc_macro_derive(Peer, attributes(peer))]
256pub fn derive_peer(input: TokenStream) -> TokenStream {
257 party_traits::derive_peer(input)
258}