Skip to main content

mpc_macros/
lib.rs

1/*!
2# Async-MPC Macros
3
4This crate provides procedural macros for the async-mpc library.
5*/
6
7mod define_task;
8mod op_variants;
9mod party_traits;
10mod protocol_tags;
11mod protocol_trait;
12mod public_fn;
13mod task_getters;
14mod task_transform;
15
16use proc_macro::TokenStream;
17
18/// Generates async-mpc task implementations with automatic dependency management.
19///
20/// This macro takes a struct definition and a compute function, and generates:
21/// - An internal unresolved struct to hold task dependencies
22/// - A public type alias using the `TaskWrapper` type
23/// - A Task trait implementation with async execution logic
24/// - A `new` constructor for creating new task instances
25/// - A standalone `compute` method for direct calls
26///
27/// # Examples
28///
29/// ```rust
30/// define_task! {
31///     pub struct FieldAddTask<F: FieldExtension> {
32///         x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
33///         y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
34///     }
35///
36///     async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
37///         Ok(x + &y)
38///     }
39/// }
40/// ```
41#[proc_macro]
42pub fn define_task(input: TokenStream) -> TokenStream {
43    define_task::define_task(input)
44}
45
46/// Derives `try_get_*_task` functions for enum variants containing task types.
47///
48/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
49/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
50/// their index (encoded in the Label's first field) and return the appropriate task type.
51///
52/// # Examples
53///
54/// ```rust
55/// #[derive(TaskGetters)]
56/// pub enum TaskType<C: Curve> {
57///     ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
58///     BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
59///     // ... more variants
60/// }
61/// ```
62///
63/// This generates functions like:
64/// - `try_get_scalar_share_task`
65/// - `try_get_base_field_share_task`
66/// - `try_get_point_share_task`
67///
68/// Each function has the signature:
69/// ```rust
70/// pub fn try_get_<variant_name>_task<C: Curve>(
71///     label: Label,
72///     task_map: &[TaskType<C>],
73/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
74/// ```
75///
76/// ## Requirements
77///
78/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
79/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
80/// - Each variant must have a corresponding `try_as_*` method (typically generated by
81///   `EnumTryAsInner`)
82#[proc_macro_derive(TaskGetters)]
83pub fn derive_task_getters(input: TokenStream) -> TokenStream {
84    task_getters::derive_task_getters(input)
85}
86
87/// Macro to ensure protocol tag uniqueness at compile time.
88/// This macro generates a unique tag for a given protocol name,
89/// and checks against a registry to prevent collisions.
90/// If a collision is detected, it finds the next available tag.
91#[proc_macro]
92pub fn new_protocol(input: TokenStream) -> TokenStream {
93    protocol_tags::new_protocol_info(input)
94}
95
96/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
97#[proc_macro]
98pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
99    protocol_tags::dump_tags(input)
100}
101
102/// Attribute macro to add debug logging to impl blocks of Party structs.
103///
104/// Injects `log::debug!` at the start of each function, logging:
105/// `<{StructName} - {PROTOCOL_INFO.name()}> {fn_name} with {session_id} over {network_info}`
106///
107/// Network detection: `IoSink`/`IoStream` → channel info, `MultipartyInterface` → peer info.
108///
109/// Requires `self.refresh()` call if impl has any `&self`/`&mut self` functions.
110#[proc_macro_attribute]
111pub fn protocol_trait(attr: TokenStream, item: TokenStream) -> TokenStream {
112    protocol_trait::protocol_trait_impl(attr, item)
113}
114
115/// Procedural macro to automatically implement the `Party` trait for a struct.
116///
117/// # Requirements
118/// - The struct must have either:
119///   - A field marked with `#[session_id]` or a field of type `SessionId`, or
120///   - At least one field annotated with `#[party]` (which must itself implement `Party`).
121///
122/// # Behavior
123/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
124///   `#[party]` field.
125/// - `protocol_name()`: Uses the `PROTOCOL_INFO` constant if a direct `SessionId` field is present,
126///   otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
127/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
128///   `#[party]` fields. If a `hasher` field is present, it is reset.
129///
130/// # Usage
131/// ```rust
132/// #[derive(Party)]
133/// pub struct MyParty {
134///     #[session_id]
135///     session_id: SessionId,
136///     #[party]
137///     inner: InnerPartyType,
138///     // ...
139/// }
140/// // or, with type/name matching fallback:
141/// #[derive(Party)]
142/// pub struct MyParty { session_id: SessionId, #[party] inner: InnerPartyType, ... }
143/// ```
144#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
145pub fn derive_party(input: TokenStream) -> TokenStream {
146    party_traits::derive_party(input)
147}
148
149/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
150///
151/// # Requirements
152/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
153///   implements `CryptoRngCore`.
154///
155/// # Behavior
156/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
157///
158/// # Usage
159/// ```rust
160/// #[derive(Probabilistic)]
161/// pub struct MyParty {
162///     #[rng]
163///     my_rng: MyRngType,
164///     // ...
165/// }
166/// // or, fallback:
167/// #[derive(Probabilistic)]
168/// pub struct MyParty { rng: MyRngType, ... }
169/// ```
170#[proc_macro_derive(Probabilistic, attributes(rng))]
171pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
172    party_traits::derive_probabilistic(input)
173}
174
175/// Procedural macro to automatically implement the `Scribe` trait for a struct.
176///
177/// # Requirements
178/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
179///   `transcript` that implements `Transcript`.
180///
181/// # Behavior
182/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
183///   field.
184///
185/// # Usage
186/// ```rust
187/// #[derive(Scribe)]
188/// pub struct MyParty {
189///     #[transcript]
190///     my_transcript: MyTranscriptType,
191///     // ...
192/// }
193/// // or, fallback:
194/// #[derive(Scribe)]
195/// pub struct MyParty { transcript: MyTranscriptType, ... }
196/// ```
197#[proc_macro_derive(Scribe, attributes(transcript))]
198pub fn derive_scribe(input: TokenStream) -> TokenStream {
199    party_traits::derive_scribe(input)
200}
201
202/// Procedural macro to automatically implement the `Peer` trait for a struct.
203///
204/// # Requirements
205/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
206///   `peer_ctx` whose type is PeerContext.
207///
208/// # Behavior
209/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
210///
211/// # Usage
212/// ```rust
213/// #[derive(Peer)]
214/// pub struct MyParty {
215///     #[peer_ctx]
216///     peer_ctx: PeerContext,
217///     // ...
218/// }
219/// // or, fallback to type matching:
220/// #[derive(Peer)]
221/// pub struct MyParty { peer_ctx: PeerContext, ... }
222/// ```
223#[proc_macro_derive(Peer, attributes(peer_ctx))]
224pub fn derive_peer(input: TokenStream) -> TokenStream {
225    party_traits::derive_peer(input)
226}
227
228/// Procedural macro to expose a private member function for given cfg tags.
229///
230/// This macro creates a public wrapper function with an underscore prefix that
231/// calls the original private function. The wrapper is only compiled when the
232/// specified cfg condition is met (e.g., feature flags or test mode).
233///
234/// # Usage
235/// ```rust
236/// use macros::public;
237///
238/// struct MyStruct;
239///
240/// impl MyStruct {
241///     #[public(feature = "dev")]
242///     fn private_method(&mut self, x: i32) -> i32 {
243///         x * 2
244///     }
245/// }
246///
247/// // Generates:
248/// // #[cfg(feature = "dev")]
249/// // pub fn _private_method(&mut self, x: i32) -> i32 {
250/// //     self.private_method(x)
251/// // }
252/// ```
253#[proc_macro_attribute]
254pub fn public(attr: TokenStream, item: TokenStream) -> TokenStream {
255    public_fn::public_fn(attr, item)
256}
257
258/// Procedural macro to generate variants for arithmetic operations.
259///
260/// Supports binary operations (like `Add`), binary assign operations (like
261/// `AddAssign`), and unary operations (like `Neg`).
262///
263/// For **binary operations** (impl with `-> Self::Output`), it requires `Self op &RHS`;
264/// for **unary operations** it assumes `op Self`.
265///
266/// There are three supported variants:
267/// - `owned`: All operands owned - implemented by referencing rhs.
268/// - `borrowed`: All operands borrowed - implemented by cloning Self.
269/// - `flipped`: `&Self op RHS` - implemented by cloning lhs and referencing rhs (binary ops only).
270/// - `flipped_commutative`:  `&Self op RHS` - implemented as `rhs.op(lhs)`. To be used only in
271///   commutative ops (e.g., Add & Mul, but not Sub | Div).
272///
273/// # Usage
274/// ```rust
275/// #[derive(Clone, Debug)]
276/// struct Point {
277///     x: i32,
278///     y: i32,
279/// }
280///
281/// // Binary operation: Point + &Point (base impl)
282/// #[macros::op_variants(owned, borrowed, flipped)]
283/// impl std::ops::Add<&Point> for Point {
284///     type Output = Point;
285///     fn add(self, other: &Point) -> Point {
286///         Point {
287///             x: self.x + other.x,
288///             y: self.y + other.y,
289///         }
290///     }
291/// }
292/// // Generates: Point + Point, &Point + &Point, &Point + Point
293///
294/// // Binary assign operation: Point += &Point (base impl)
295/// #[macros::op_variants(owned)]
296/// impl std::ops::AddAssign<&Point> for Point {
297///     fn add_assign(&mut self, other: &Point) {
298///         self.x += other.x;
299///         self.y += other.y;
300///     }
301/// }
302/// // Generates: Point += Point
303///
304/// // Unary operation: -&Point (base impl)
305/// #[macros::op_variants(owned)]
306/// impl std::ops::Neg for &Point {
307///     type Output = Point;
308///     fn neg(self) -> Point {
309///         Point {
310///             x: -self.x,
311///             y: -self.y,
312///         }
313///     }
314/// }
315/// // Generates: -Point
316/// ```
317#[proc_macro_attribute]
318pub fn op_variants(attr: TokenStream, item: TokenStream) -> TokenStream {
319    op_variants::op_variants(attr, item)
320}
321
322/// Attribute macro to transform a sequential function into a multi-round task state machine.
323///
324/// This macro allows writing multi-round MPC tasks as sequential functions, using `open!`
325/// macro calls to mark round boundaries. The macro transforms the function into:
326/// - Round state structs for each round
327/// - A task enum with Round variants + Resolving
328/// - Constructor from function parameters
329/// - State transition methods
330/// - Round logic methods
331/// - Task trait implementation
332///
333/// # Arguments
334///
335/// The macro takes the task name as its argument: `#[task(TaskName)]`
336///
337/// # Example
338///
339/// ```ignore
340/// #[task(MulTask)]
341/// fn mul<F: FieldExtension>(
342///     x: FieldShare<F>,
343///     y: FieldShare<F>,
344///     triple: Triple<F>,
345/// ) -> FieldShare<F>
346/// where
347///     FieldShare<F>: Into<Share>,
348///     Secret: TryInto<SubfieldElement<F>>,
349/// {
350///     let Triple { a, b, c: _ } = &triple;
351///     let epsilon = &x - a;
352///     let delta = &y - b;
353///
354///     // open_field! marks a round boundary - variables used after become state
355///     let (delta, epsilon): (SubfieldElement<F>, SubfieldElement<F>) = open_field!(delta, epsilon);
356///
357///     let Triple { a, b: _, c } = &triple;
358///     c + &y * epsilon + a * delta
359/// }
360/// ```
361///
362/// # Communication Macros
363///
364/// - `open_field!(share)` - Opens a single field share, returning the reconstructed value
365/// - `open_field!(share1, share2)` - Opens multiple field shares, returning a tuple
366/// - `open_binary!(share1, share2)` - Opens binary (Gf2_128) shares; emits `Share::Binary` directly
367///   with no `Into<Share<F>>` or `Secret<F>: Into<T>` bounds
368///
369/// Type annotations on the `let` binding tell the macro what types to extract from `Secret`.
370#[proc_macro_attribute]
371pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
372    task_transform::task_transform(attr, item)
373}