mpc_macros/
lib.rs

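/*!
# Async-MPC Macros

This crate provides procedural macros for the async-mpc library: the `define_task!`,
`new_protocol!` and `dump_protocol_tags!` function-like macros, the `TaskGetters`,
`GateMethods`, `Party`, `Probabilistic`, `Scribe`, `HasTweakableHasher` and `Peer` derives, and
the `#[public]` attribute macro.
*/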
6
7mod define_task;
8mod gate_methods;
9mod party_traits;
10mod protocol_tags;
11mod public_fn;
12mod task_getters;
13
14use proc_macro::TokenStream;
15
16/// Generates async-mpc task implementations with automatic dependency management.
17///
18/// This macro takes a struct definition and a compute function, and generates:
19/// - An internal unresolved struct to hold task dependencies
20/// - A public type alias using the `TaskWrapper` type
21/// - A Task trait implementation with async execution logic
22/// - A `new` constructor for creating new task instances
23/// - A standalone `compute` method for direct calls
24///
25/// # Examples
26///
27/// ```rust
28/// define_task! {
29///     pub struct FieldAddTask<F: FieldExtension> {
30///         x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
31///         y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
32///     }
33///
34///     async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
35///         Ok(x + &y)
36///     }
37/// }
38/// ```
39#[proc_macro]
40pub fn define_task(input: TokenStream) -> TokenStream {
41    define_task::define_task(input)
42}
43
44/// Derives `try_get_*_task` functions for enum variants containing task types.
45///
46/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
47/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
48/// their index (encoded in the Label's first field) and return the appropriate task type.
49///
50/// # Examples
51///
52/// ```rust
53/// #[derive(TaskGetters)]
54/// pub enum TaskType<C: Curve> {
55///     ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
56///     BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
57///     // ... more variants
58/// }
59/// ```
60///
61/// This generates functions like:
62/// - `try_get_scalar_share_task`
63/// - `try_get_base_field_share_task`
64/// - `try_get_point_share_task`
65///
66/// Each function has the signature:
67/// ```rust
68/// pub fn try_get_<variant_name>_task<C: Curve>(
69///     label: Label,
70///     task_map: &[TaskType<C>],
71/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
72/// ```
73///
74/// ## Requirements
75///
76/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
77/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
78/// - Each variant must have a corresponding `try_as_*` method (typically generated by
79///   `EnumTryAsInner`)
80#[proc_macro_derive(TaskGetters)]
81pub fn derive_task_getters(input: TokenStream) -> TokenStream {
82    task_getters::derive_task_getters(input)
83}
84
85/// Derives `Gate::map_labels` and `Gate::for_each_label`
86#[proc_macro_derive(GateMethods)]
87pub fn derive_gate_methods(input: TokenStream) -> TokenStream {
88    gate_methods::derive_gate_methods(input)
89}
90
91/// Macro to ensure protocol tag uniqueness at compile time.
92/// This macro generates a unique tag for a given protocol name,
93/// and checks against a registry to prevent collisions.
94/// If a collision is detected, it finds the next available tag.
95#[proc_macro]
96pub fn new_protocol(input: TokenStream) -> TokenStream {
97    protocol_tags::new_protocol_info(input)
98}
99
100/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
101#[proc_macro]
102pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
103    protocol_tags::dump_tags(input)
104}
105/// Procedural macro to automatically implement the `Party` trait for a struct.
106///
107/// # Requirements
108/// - The struct must have either:
109///   - A field marked with `#[session_id]` or a field of type `SessionId`, or
110///   - At least one field annotated with `#[party]` (which must itself implement `Party`).
111///
112/// # Behavior
113/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
114///   `#[party]` field.
115/// - `protocol_name()`: Uses the `PROTOCOL_INFO` constant if a direct `SessionId` field is present,
116///   otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
117/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
118///   `#[party]` fields. If a `hasher` field is present, it is reset.
119///
120/// # Usage
121/// ```rust
122/// #[derive(Party)]
123/// pub struct MyParty {
124///     #[session_id]
125///     session_id: SessionId,
126///     #[party]
127///     inner: InnerPartyType,
128///     // ...
129/// }
130/// // or, with type/name matching fallback:
131/// #[derive(Party)]
132/// pub struct MyParty { session_id: SessionId, #[party] inner: InnerPartyType, ... }
133/// ```
134#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
135pub fn derive_party(input: TokenStream) -> TokenStream {
136    party_traits::derive_party(input)
137}
138
139/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
140///
141/// # Requirements
142/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
143///   implements `CryptoRngCore`.
144///
145/// # Behavior
146/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
147///
148/// # Usage
149/// ```rust
150/// #[derive(Probabilistic)]
151/// pub struct MyParty {
152///     #[rng]
153///     my_rng: MyRngType,
154///     // ...
155/// }
156/// // or, fallback:
157/// #[derive(Probabilistic)]
158/// pub struct MyParty { rng: MyRngType, ... }
159/// ```
160#[proc_macro_derive(Probabilistic, attributes(rng))]
161pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
162    party_traits::derive_probabilistic(input)
163}
164
165/// Procedural macro to automatically implement the `Scribe` trait for a struct.
166///
167/// # Requirements
168/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
169///   `transcript` that implements `Transcript`.
170///
171/// # Behavior
172/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
173///   field.
174///
175/// # Usage
176/// ```rust
177/// #[derive(Scribe)]
178/// pub struct MyParty {
179///     #[transcript]
180///     my_transcript: MyTranscriptType,
181///     // ...
182/// }
183/// // or, fallback:
184/// #[derive(Scribe)]
185/// pub struct MyParty { transcript: MyTranscriptType, ... }
186/// ```
187#[proc_macro_derive(Scribe, attributes(transcript))]
188pub fn derive_scribe(input: TokenStream) -> TokenStream {
189    party_traits::derive_scribe(input)
190}
191
192/// Procedural macro to automatically implement the `HasTweakableHasher` trait for a struct.
193///
194/// # Requirements
195/// - The struct must have a field marked with `#[hasher]` or, if not present, a field named
196///   `hasher` whose type implements `TweakableHasher`.
197///
198/// # Behavior
199/// - Implements the `get_hasher(&mut self)` method, returning a mutable reference to the selected
200///   field.
201/// - Sets the associated type `Hasher` to the type of the selected field.
202///
203/// # Usage
204/// ```rust
205/// #[derive(HasTweakableHasher)]
206/// pub struct MyParty {
207///     #[hasher]
208///     my_hasher: MyHasherType,
209///     // ...
210/// }
211/// // or, fallback:
212/// #[derive(HasTweakableHasher)]
213/// pub struct MyParty { hasher: MyHasherType, ... }
214/// ```
215#[proc_macro_derive(HasTweakableHasher, attributes(hasher))]
216pub fn derive_has_tweakable_hasher(input: TokenStream) -> TokenStream {
217    party_traits::derive_has_tweakable_hasher(input)
218}
219
220/// Procedural macro to automatically implement the `Peer` trait for a struct.
221///
222/// # Requirements
223/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
224///   `peer_ctx` whose type is PeerContext.
225///
226/// # Behavior
227/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
228///
229/// # Usage
230/// ```rust
231/// #[derive(Peer)]
232/// pub struct MyParty {
233///     #[peer_ctx]
234///     peer_ctx: PeerContext,
235///     // ...
236/// }
237/// // or, fallback to type matching:
238/// #[derive(Peer)]
239/// pub struct MyParty { peer_ctx: PeerContext, ... }
240/// ```
241#[proc_macro_derive(Peer, attributes(peer))]
242pub fn derive_peer(input: TokenStream) -> TokenStream {
243    party_traits::derive_peer(input)
244}
245
246/// Procedural macro to expose a private member function for given cfg tags.
247///
248/// This macro creates a public wrapper function with an underscore prefix that
249/// calls the original private function. The wrapper is only compiled when the
250/// specified cfg condition is met (e.g., feature flags or test mode).
251///
252/// # Usage
253/// ```rust
254/// use macros::public;
255///
256/// struct MyStruct;
257///
258/// impl MyStruct {
259///     #[public(feature = "dev")]
260///     fn private_method(&mut self, x: i32) -> i32 {
261///         x * 2
262///     }
263/// }
264///
265/// // Generates:
266/// // #[cfg(feature = "dev")]
267/// // pub fn _private_method(&mut self, x: i32) -> i32 {
268/// //     self.private_method(x)
269/// // }
270/// ```
271#[proc_macro_attribute]
272pub fn public(attr: TokenStream, item: TokenStream) -> TokenStream {
273    public_fn::public_fn(attr, item)
274}