mpc_macros/lib.rs
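/*!
# Async-MPC Macros

This crate provides the procedural macros used by the async-mpc library: task definition
(`define_task!`, `#[task]`, `#[derive(TaskGetters)]`), protocol tag registration
(`new_protocol!`, `dump_protocol_tags!`), party trait derives (`Party`, `Probabilistic`,
`Scribe`, `Peer`) and helper attributes (`#[protocol_trait]`, `#[public]`, `#[op_variants]`).
*/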
6
7mod define_task;
8mod op_variants;
9mod party_traits;
10mod protocol_tags;
11mod protocol_trait;
12mod public_fn;
13mod task_getters;
14mod task_transform;
15
16use proc_macro::TokenStream;
17
18/// Generates async-mpc task implementations with automatic dependency management.
19///
20/// This macro takes a struct definition and a compute function, and generates:
21/// - An internal unresolved struct to hold task dependencies
22/// - A public type alias using the `TaskWrapper` type
23/// - A Task trait implementation with async execution logic
24/// - A `new` constructor for creating new task instances
25/// - A standalone `compute` method for direct calls
26///
27/// # Examples
28///
29/// ```rust
30/// define_task! {
31/// pub struct FieldAddTask<F: FieldExtension> {
32/// x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
33/// y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
34/// }
35///
36/// async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
37/// Ok(x + &y)
38/// }
39/// }
40/// ```
41#[proc_macro]
42pub fn define_task(input: TokenStream) -> TokenStream {
43 define_task::define_task(input)
44}
45
46/// Derives `try_get_*_task` functions for enum variants containing task types.
47///
48/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
49/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
50/// their index (encoded in the Label's first field) and return the appropriate task type.
51///
52/// # Examples
53///
54/// ```rust
55/// #[derive(TaskGetters)]
56/// pub enum TaskType<C: Curve> {
57/// ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
58/// BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
59/// // ... more variants
60/// }
61/// ```
62///
63/// This generates functions like:
64/// - `try_get_scalar_share_task`
65/// - `try_get_base_field_share_task`
66/// - `try_get_point_share_task`
67///
68/// Each function has the signature:
69/// ```rust
70/// pub fn try_get_<variant_name>_task<C: Curve>(
71/// label: Label,
72/// task_map: &[TaskType<C>],
73/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
74/// ```
75///
76/// ## Requirements
77///
78/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
79/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
80/// - Each variant must have a corresponding `try_as_*` method (typically generated by
81/// `EnumTryAsInner`)
82#[proc_macro_derive(TaskGetters)]
83pub fn derive_task_getters(input: TokenStream) -> TokenStream {
84 task_getters::derive_task_getters(input)
85}
86
87/// Macro to ensure protocol tag uniqueness at compile time.
88/// This macro generates a unique tag for a given protocol name,
89/// and checks against a registry to prevent collisions.
90/// If a collision is detected, it finds the next available tag.
91#[proc_macro]
92pub fn new_protocol(input: TokenStream) -> TokenStream {
93 protocol_tags::new_protocol_info(input)
94}
95
96/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
97#[proc_macro]
98pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
99 protocol_tags::dump_tags(input)
100}
101
102/// Attribute macro to add debug logging to impl blocks of Party structs.
103///
104/// Injects `log::debug!` at the start of each function, logging:
105/// `<{StructName} - {PROTOCOL_INFO.name()}> {fn_name} with {session_id} over {network_info}`
106///
107/// Network detection: `IoSink`/`IoStream` → channel info, `MultipartyInterface` → peer info.
108///
109/// Requires `self.refresh()` call if impl has any `&self`/`&mut self` functions.
110#[proc_macro_attribute]
111pub fn protocol_trait(attr: TokenStream, item: TokenStream) -> TokenStream {
112 protocol_trait::protocol_trait_impl(attr, item)
113}
114
115/// Procedural macro to automatically implement the `Party` trait for a struct.
116///
117/// # Requirements
118/// - The struct must have either:
119/// - A field marked with `#[session_id]` or a field of type `SessionId`, or
120/// - At least one field annotated with `#[party]` (which must itself implement `Party`).
121///
122/// # Behavior
123/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
124/// `#[party]` field.
125/// - `protocol_name()`: Uses the `PROTOCOL_INFO` constant if a direct `SessionId` field is present,
126/// otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
127/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
128/// `#[party]` fields. If a `hasher` field is present, it is reset.
129///
130/// # Usage
131/// ```rust
132/// #[derive(Party)]
133/// pub struct MyParty {
134/// #[session_id]
135/// session_id: SessionId,
136/// #[party]
137/// inner: InnerPartyType,
138/// // ...
139/// }
140/// // or, with type/name matching fallback:
141/// #[derive(Party)]
142/// pub struct MyParty { session_id: SessionId, #[party] inner: InnerPartyType, ... }
143/// ```
144#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
145pub fn derive_party(input: TokenStream) -> TokenStream {
146 party_traits::derive_party(input)
147}
148
149/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
150///
151/// # Requirements
152/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
153/// implements `CryptoRngCore`.
154///
155/// # Behavior
156/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
157///
158/// # Usage
159/// ```rust
160/// #[derive(Probabilistic)]
161/// pub struct MyParty {
162/// #[rng]
163/// my_rng: MyRngType,
164/// // ...
165/// }
166/// // or, fallback:
167/// #[derive(Probabilistic)]
168/// pub struct MyParty { rng: MyRngType, ... }
169/// ```
170#[proc_macro_derive(Probabilistic, attributes(rng))]
171pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
172 party_traits::derive_probabilistic(input)
173}
174
175/// Procedural macro to automatically implement the `Scribe` trait for a struct.
176///
177/// # Requirements
178/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
179/// `transcript` that implements `Transcript`.
180///
181/// # Behavior
182/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
183/// field.
184///
185/// # Usage
186/// ```rust
187/// #[derive(Scribe)]
188/// pub struct MyParty {
189/// #[transcript]
190/// my_transcript: MyTranscriptType,
191/// // ...
192/// }
193/// // or, fallback:
194/// #[derive(Scribe)]
195/// pub struct MyParty { transcript: MyTranscriptType, ... }
196/// ```
197#[proc_macro_derive(Scribe, attributes(transcript))]
198pub fn derive_scribe(input: TokenStream) -> TokenStream {
199 party_traits::derive_scribe(input)
200}
201
202/// Procedural macro to automatically implement the `Peer` trait for a struct.
203///
204/// # Requirements
205/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
206/// `peer_ctx` whose type is PeerContext.
207///
208/// # Behavior
209/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
210///
211/// # Usage
212/// ```rust
213/// #[derive(Peer)]
214/// pub struct MyParty {
215/// #[peer_ctx]
216/// peer_ctx: PeerContext,
217/// // ...
218/// }
219/// // or, fallback to type matching:
220/// #[derive(Peer)]
221/// pub struct MyParty { peer_ctx: PeerContext, ... }
222/// ```
223#[proc_macro_derive(Peer, attributes(peer_ctx))]
224pub fn derive_peer(input: TokenStream) -> TokenStream {
225 party_traits::derive_peer(input)
226}
227
228/// Procedural macro to expose a private member function for given cfg tags.
229///
230/// This macro creates a public wrapper function with an underscore prefix that
231/// calls the original private function. The wrapper is only compiled when the
232/// specified cfg condition is met (e.g., feature flags or test mode).
233///
234/// # Usage
235/// ```rust
236/// use macros::public;
237///
238/// struct MyStruct;
239///
240/// impl MyStruct {
241/// #[public(feature = "dev")]
242/// fn private_method(&mut self, x: i32) -> i32 {
243/// x * 2
244/// }
245/// }
246///
247/// // Generates:
248/// // #[cfg(feature = "dev")]
249/// // pub fn _private_method(&mut self, x: i32) -> i32 {
250/// // self.private_method(x)
251/// // }
252/// ```
253#[proc_macro_attribute]
254pub fn public(attr: TokenStream, item: TokenStream) -> TokenStream {
255 public_fn::public_fn(attr, item)
256}
257
258/// Procedural macro to generate variants for arithmetic operations.
259///
260/// Supports binary operations (like `Add`), binary assign operations (like
261/// `AddAssign`), and unary operations (like `Neg`).
262///
263/// For **binary operations** (impl with `-> Self::Output`), it requires `Self op &RHS`;
264/// for **unary operations** it assumes `op Self`.
265///
266/// There are three supported variants:
267/// - `owned`: All operands owned - implemented by referencing rhs.
268/// - `borrowed`: All operands borrowed - implemented by cloning Self.
269/// - `flipped`: `&Self op RHS` - implemented by cloning lhs and referencing rhs (binary ops only).
270/// - `flipped_commutative`: `&Self op RHS` - implemented as `rhs.op(lhs)`. To be used only in
271/// commutative ops (e.g., Add & Mul, but not Sub | Div).
272///
273/// # Usage
274/// ```rust
275/// #[derive(Clone, Debug)]
276/// struct Point {
277/// x: i32,
278/// y: i32,
279/// }
280///
281/// // Binary operation: Point + &Point (base impl)
282/// #[macros::op_variants(owned, borrowed, flipped)]
283/// impl std::ops::Add<&Point> for Point {
284/// type Output = Point;
285/// fn add(self, other: &Point) -> Point {
286/// Point {
287/// x: self.x + other.x,
288/// y: self.y + other.y,
289/// }
290/// }
291/// }
292/// // Generates: Point + Point, &Point + &Point, &Point + Point
293///
294/// // Binary assign operation: Point += &Point (base impl)
295/// #[macros::op_variants(owned)]
296/// impl std::ops::AddAssign<&Point> for Point {
297/// fn add_assign(&mut self, other: &Point) {
298/// self.x += other.x;
299/// self.y += other.y;
300/// }
301/// }
302/// // Generates: Point += Point
303///
304/// // Unary operation: -&Point (base impl)
305/// #[macros::op_variants(owned)]
306/// impl std::ops::Neg for &Point {
307/// type Output = Point;
308/// fn neg(self) -> Point {
309/// Point {
310/// x: -self.x,
311/// y: -self.y,
312/// }
313/// }
314/// }
315/// // Generates: -Point
316/// ```
317#[proc_macro_attribute]
318pub fn op_variants(attr: TokenStream, item: TokenStream) -> TokenStream {
319 op_variants::op_variants(attr, item)
320}
321
322/// Attribute macro to transform a sequential function into a multi-round task state machine.
323///
324/// This macro allows writing multi-round MPC tasks as sequential functions, using `open!`
325/// macro calls to mark round boundaries. The macro transforms the function into:
326/// - Round state structs for each round
327/// - A task enum with Round variants + Resolving
328/// - Constructor from function parameters
329/// - State transition methods
330/// - Round logic methods
331/// - Task trait implementation
332///
333/// # Arguments
334///
335/// The macro takes the task name as its argument: `#[task(TaskName)]`
336///
337/// # Example
338///
339/// ```ignore
340/// #[task(MulTask)]
341/// fn mul<F: FieldExtension>(
342/// x: FieldShare<F>,
343/// y: FieldShare<F>,
344/// triple: Triple<F>,
345/// ) -> FieldShare<F>
346/// where
347/// FieldShare<F>: Into<Share>,
348/// Secret: TryInto<SubfieldElement<F>>,
349/// {
350/// let Triple { a, b, c: _ } = &triple;
351/// let epsilon = &x - a;
352/// let delta = &y - b;
353///
354/// // open_field! marks a round boundary - variables used after become state
355/// let (delta, epsilon): (SubfieldElement<F>, SubfieldElement<F>) = open_field!(delta, epsilon);
356///
357/// let Triple { a, b: _, c } = &triple;
358/// c + &y * epsilon + a * delta
359/// }
360/// ```
361///
362/// # Communication Macros
363///
364/// - `open_field!(share)` - Opens a single field share, returning the reconstructed value
365/// - `open_field!(share1, share2)` - Opens multiple field shares, returning a tuple
366/// - `open_binary!(share1, share2)` - Opens binary (Gf2_128) shares; emits `Share::Binary` directly
367/// with no `Into<Share<F>>` or `Secret<F>: Into<T>` bounds
368///
369/// Type annotations on the `let` binding tell the macro what types to extract from `Secret`.
370#[proc_macro_attribute]
371pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
372 task_transform::task_transform(attr, item)
373}