mpc_macros/
lib.rs

/*!
# Async-MPC Macros

This crate provides the procedural macros (function-like, derive, and attribute macros) used by
the async-mpc library.
*/
6
7mod define_task;
8mod gate_methods;
9mod op_variants;
10mod party_traits;
11mod protocol_tags;
12mod public_fn;
13mod task_getters;
14
15use proc_macro::TokenStream;
16
17/// Generates async-mpc task implementations with automatic dependency management.
18///
19/// This macro takes a struct definition and a compute function, and generates:
20/// - An internal unresolved struct to hold task dependencies
21/// - A public type alias using the `TaskWrapper` type
22/// - A Task trait implementation with async execution logic
23/// - A `new` constructor for creating new task instances
24/// - A standalone `compute` method for direct calls
25///
26/// # Examples
27///
28/// ```rust
29/// define_task! {
30///     pub struct FieldAddTask<F: FieldExtension> {
31///         x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
32///         y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
33///     }
34///
35///     async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
36///         Ok(x + &y)
37///     }
38/// }
39/// ```
40#[proc_macro]
41pub fn define_task(input: TokenStream) -> TokenStream {
42    define_task::define_task(input)
43}
44
45/// Derives `try_get_*_task` functions for enum variants containing task types.
46///
47/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
48/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
49/// their index (encoded in the Label's first field) and return the appropriate task type.
50///
51/// # Examples
52///
53/// ```rust
54/// #[derive(TaskGetters)]
55/// pub enum TaskType<C: Curve> {
56///     ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
57///     BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
58///     // ... more variants
59/// }
60/// ```
61///
62/// This generates functions like:
63/// - `try_get_scalar_share_task`
64/// - `try_get_base_field_share_task`
65/// - `try_get_point_share_task`
66///
67/// Each function has the signature:
68/// ```rust
69/// pub fn try_get_<variant_name>_task<C: Curve>(
70///     label: Label,
71///     task_map: &[TaskType<C>],
72/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
73/// ```
74///
75/// ## Requirements
76///
77/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
78/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
79/// - Each variant must have a corresponding `try_as_*` method (typically generated by
80///   `EnumTryAsInner`)
81#[proc_macro_derive(TaskGetters)]
82pub fn derive_task_getters(input: TokenStream) -> TokenStream {
83    task_getters::derive_task_getters(input)
84}
85
86/// Derives `Gate::map_labels` and `Gate::for_each_label`
87#[proc_macro_derive(GateMethods)]
88pub fn derive_gate_methods(input: TokenStream) -> TokenStream {
89    gate_methods::derive_gate_methods(input)
90}
91
92/// Macro to ensure protocol tag uniqueness at compile time.
93/// This macro generates a unique tag for a given protocol name,
94/// and checks against a registry to prevent collisions.
95/// If a collision is detected, it finds the next available tag.
96#[proc_macro]
97pub fn new_protocol(input: TokenStream) -> TokenStream {
98    protocol_tags::new_protocol_info(input)
99}
100
101/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
102#[proc_macro]
103pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
104    protocol_tags::dump_tags(input)
105}
106/// Procedural macro to automatically implement the `Party` trait for a struct.
107///
108/// # Requirements
109/// - The struct must have either:
110///   - A field marked with `#[session_id]` or a field of type `SessionId`, or
111///   - At least one field annotated with `#[party]` (which must itself implement `Party`).
112///
113/// # Behavior
114/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
115///   `#[party]` field.
116/// - `protocol_name()`: Uses the `PROTOCOL_INFO` constant if a direct `SessionId` field is present,
117///   otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
118/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
119///   `#[party]` fields. If a `hasher` field is present, it is reset.
120///
121/// # Usage
122/// ```rust
123/// #[derive(Party)]
124/// pub struct MyParty {
125///     #[session_id]
126///     session_id: SessionId,
127///     #[party]
128///     inner: InnerPartyType,
129///     // ...
130/// }
131/// // or, with type/name matching fallback:
132/// #[derive(Party)]
133/// pub struct MyParty { session_id: SessionId, #[party] inner: InnerPartyType, ... }
134/// ```
135#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
136pub fn derive_party(input: TokenStream) -> TokenStream {
137    party_traits::derive_party(input)
138}
139
140/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
141///
142/// # Requirements
143/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
144///   implements `CryptoRngCore`.
145///
146/// # Behavior
147/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
148///
149/// # Usage
150/// ```rust
151/// #[derive(Probabilistic)]
152/// pub struct MyParty {
153///     #[rng]
154///     my_rng: MyRngType,
155///     // ...
156/// }
157/// // or, fallback:
158/// #[derive(Probabilistic)]
159/// pub struct MyParty { rng: MyRngType, ... }
160/// ```
161#[proc_macro_derive(Probabilistic, attributes(rng))]
162pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
163    party_traits::derive_probabilistic(input)
164}
165
166/// Procedural macro to automatically implement the `Scribe` trait for a struct.
167///
168/// # Requirements
169/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
170///   `transcript` that implements `Transcript`.
171///
172/// # Behavior
173/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
174///   field.
175///
176/// # Usage
177/// ```rust
178/// #[derive(Scribe)]
179/// pub struct MyParty {
180///     #[transcript]
181///     my_transcript: MyTranscriptType,
182///     // ...
183/// }
184/// // or, fallback:
185/// #[derive(Scribe)]
186/// pub struct MyParty { transcript: MyTranscriptType, ... }
187/// ```
188#[proc_macro_derive(Scribe, attributes(transcript))]
189pub fn derive_scribe(input: TokenStream) -> TokenStream {
190    party_traits::derive_scribe(input)
191}
192
193/// Procedural macro to automatically implement the `HasTweakableHasher` trait for a struct.
194///
195/// # Requirements
196/// - The struct must have a field marked with `#[hasher]` or, if not present, a field named
197///   `hasher` whose type implements `TweakableHasher`.
198///
199/// # Behavior
200/// - Implements the `get_hasher(&mut self)` method, returning a mutable reference to the selected
201///   field.
202/// - Sets the associated type `Hasher` to the type of the selected field.
203///
204/// # Usage
205/// ```rust
206/// #[derive(HasTweakableHasher)]
207/// pub struct MyParty {
208///     #[hasher]
209///     my_hasher: MyHasherType,
210///     // ...
211/// }
212/// // or, fallback:
213/// #[derive(HasTweakableHasher)]
214/// pub struct MyParty { hasher: MyHasherType, ... }
215/// ```
216#[proc_macro_derive(HasTweakableHasher, attributes(hasher))]
217pub fn derive_has_tweakable_hasher(input: TokenStream) -> TokenStream {
218    party_traits::derive_has_tweakable_hasher(input)
219}
220
221/// Procedural macro to automatically implement the `Peer` trait for a struct.
222///
223/// # Requirements
224/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
225///   `peer_ctx` whose type is PeerContext.
226///
227/// # Behavior
228/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
229///
230/// # Usage
231/// ```rust
232/// #[derive(Peer)]
233/// pub struct MyParty {
234///     #[peer_ctx]
235///     peer_ctx: PeerContext,
236///     // ...
237/// }
238/// // or, fallback to type matching:
239/// #[derive(Peer)]
240/// pub struct MyParty { peer_ctx: PeerContext, ... }
241/// ```
242#[proc_macro_derive(Peer, attributes(peer))]
243pub fn derive_peer(input: TokenStream) -> TokenStream {
244    party_traits::derive_peer(input)
245}
246
247/// Procedural macro to expose a private member function for given cfg tags.
248///
249/// This macro creates a public wrapper function with an underscore prefix that
250/// calls the original private function. The wrapper is only compiled when the
251/// specified cfg condition is met (e.g., feature flags or test mode).
252///
253/// # Usage
254/// ```rust
255/// use macros::public;
256///
257/// struct MyStruct;
258///
259/// impl MyStruct {
260///     #[public(feature = "dev")]
261///     fn private_method(&mut self, x: i32) -> i32 {
262///         x * 2
263///     }
264/// }
265///
266/// // Generates:
267/// // #[cfg(feature = "dev")]
268/// // pub fn _private_method(&mut self, x: i32) -> i32 {
269/// //     self.private_method(x)
270/// // }
271/// ```
272#[proc_macro_attribute]
273pub fn public(attr: TokenStream, item: TokenStream) -> TokenStream {
274    public_fn::public_fn(attr, item)
275}
276
277/// Procedural macro to generate variants for arithmetic operations.
278///
279/// Supports binary operations (like `Add`), binary assign operations (like
280/// `AddAssign`), and unary operations (like `Neg`).
281///
282/// For **binary operations** (impl with `-> Self::Output`), it requires `Self op &RHS`;
283/// for **unary operations** it assumes `op Self`.
284///
285/// There are three supported variants:
286/// - `owned`: All operands owned - implemented by referencing rhs.
287/// - `borrowed`: All operands borrowed - implemented by cloning Self.
288/// - `flipped`: `&Self op RHS` - implemented by cloning lhs and referencing rhs (binary ops only).
289/// - `flipped_commutative`:  `&Self op RHS` - implemented as `rhs.op(lhs)`. To be used only in
290///   commutative ops (e.g., Add & Mul, but not Sub | Div).
291///
292/// # Usage
293/// ```rust
294/// #[derive(Clone, Debug)]
295/// struct Point {
296///     x: i32,
297///     y: i32,
298/// }
299///
300/// // Binary operation: Point + &Point (base impl)
301/// #[macros::arithmetic_variants(owned, borrowed, flipped)]
302/// impl std::ops::Add<&Point> for Point {
303///     type Output = Point;
304///     fn add(self, other: &Point) -> Point {
305///         Point {
306///             x: self.x + other.x,
307///             y: self.y + other.y,
308///         }
309///     }
310/// }
311/// // Generates: Point + Point, &Point + &Point, &Point + Point
312///
313/// // Binary assign operation: Point += &Point (base impl)
314/// #[macros::arithmetic_variants(owned)]
315/// impl std::ops::AddAssign<&Point> for Point {
316///     fn add_assign(&mut self, other: &Point) {
317///         self.x += other.x;
318///         self.y += other.y;
319///     }
320/// }
321/// // Generates: Point += Point
322///
323/// // Unary operation: -&Point (base impl)
324/// #[macros::arithmetic_variants(owned)]
325/// impl std::ops::Neg for &Point {
326///     type Output = Point;
327///     fn neg(self) -> Point {
328///         Point {
329///             x: -self.x,
330///             y: -self.y,
331///         }
332///     }
333/// }
334/// // Generates: -Point
335/// ```
336#[proc_macro_attribute]
337pub fn op_variants(attr: TokenStream, item: TokenStream) -> TokenStream {
338    op_variants::op_variants(attr, item)
339}