mpc_macros/lib.rs
1/*!
2# Async-MPC Macros
3
4This crate provides procedural macros for the async-mpc library.
5*/
6
7mod define_task;
8mod op_variants;
9mod party_traits;
10mod protocol_tags;
11mod protocol_trait;
12mod public_fn;
13mod task_getters;
14
15use proc_macro::TokenStream;
16
17/// Generates async-mpc task implementations with automatic dependency management.
18///
19/// This macro takes a struct definition and a compute function, and generates:
20/// - An internal unresolved struct to hold task dependencies
21/// - A public type alias using the `TaskWrapper` type
22/// - A Task trait implementation with async execution logic
23/// - A `new` constructor for creating new task instances
24/// - A standalone `compute` method for direct calls
25///
26/// # Examples
27///
28/// ```rust
29/// define_task! {
30/// pub struct FieldAddTask<F: FieldExtension> {
31/// x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
32/// y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
33/// }
34///
35/// async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
36/// Ok(x + &y)
37/// }
38/// }
39/// ```
40#[proc_macro]
41pub fn define_task(input: TokenStream) -> TokenStream {
42 define_task::define_task(input)
43}
44
45/// Derives `try_get_*_task` functions for enum variants containing task types.
46///
47/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
48/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
49/// their index (encoded in the Label's first field) and return the appropriate task type.
50///
51/// # Examples
52///
53/// ```rust
54/// #[derive(TaskGetters)]
55/// pub enum TaskType<C: Curve> {
56/// ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
57/// BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
58/// // ... more variants
59/// }
60/// ```
61///
62/// This generates functions like:
63/// - `try_get_scalar_share_task`
64/// - `try_get_base_field_share_task`
65/// - `try_get_point_share_task`
66///
67/// Each function has the signature:
68/// ```rust
69/// pub fn try_get_<variant_name>_task<C: Curve>(
70/// label: Label,
71/// task_map: &[TaskType<C>],
72/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
73/// ```
74///
75/// ## Requirements
76///
77/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
78/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
79/// - Each variant must have a corresponding `try_as_*` method (typically generated by
80/// `EnumTryAsInner`)
81#[proc_macro_derive(TaskGetters)]
82pub fn derive_task_getters(input: TokenStream) -> TokenStream {
83 task_getters::derive_task_getters(input)
84}
85
86/// Macro to ensure protocol tag uniqueness at compile time.
87/// This macro generates a unique tag for a given protocol name,
88/// and checks against a registry to prevent collisions.
89/// If a collision is detected, it finds the next available tag.
90#[proc_macro]
91pub fn new_protocol(input: TokenStream) -> TokenStream {
92 protocol_tags::new_protocol_info(input)
93}
94
95/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
96#[proc_macro]
97pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
98 protocol_tags::dump_tags(input)
99}
100
101/// Attribute macro to add debug logging to impl blocks of Party structs.
102///
103/// Injects `log::debug!` at the start of each function, logging:
104/// `<{StructName} - {PROTOCOL_INFO.name()}> {fn_name} with {session_id} over {network_info}`
105///
106/// Network detection: `IoSink`/`IoStream` → channel info, `MultipartyInterface` → peer info.
107///
108/// Requires `self.refresh()` call if impl has any `&self`/`&mut self` functions.
109#[proc_macro_attribute]
110pub fn protocol_trait(attr: TokenStream, item: TokenStream) -> TokenStream {
111 protocol_trait::protocol_trait_impl(attr, item)
112}
113
114/// Procedural macro to automatically implement the `Party` trait for a struct.
115///
116/// # Requirements
117/// - The struct must have either:
118/// - A field marked with `#[session_id]` or a field of type `SessionId`, or
119/// - At least one field annotated with `#[party]` (which must itself implement `Party`).
120///
121/// # Behavior
122/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
123/// `#[party]` field.
124/// - `protocol_name()`: Uses the `PROTOCOL_INFO` constant if a direct `SessionId` field is present,
125/// otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
126/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
127/// `#[party]` fields. If a `hasher` field is present, it is reset.
128///
129/// # Usage
130/// ```rust
131/// #[derive(Party)]
132/// pub struct MyParty {
133/// #[session_id]
134/// session_id: SessionId,
135/// #[party]
136/// inner: InnerPartyType,
137/// // ...
138/// }
139/// // or, with type/name matching fallback:
140/// #[derive(Party)]
141/// pub struct MyParty { session_id: SessionId, #[party] inner: InnerPartyType, ... }
142/// ```
143#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
144pub fn derive_party(input: TokenStream) -> TokenStream {
145 party_traits::derive_party(input)
146}
147
148/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
149///
150/// # Requirements
151/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
152/// implements `CryptoRngCore`.
153///
154/// # Behavior
155/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
156///
157/// # Usage
158/// ```rust
159/// #[derive(Probabilistic)]
160/// pub struct MyParty {
161/// #[rng]
162/// my_rng: MyRngType,
163/// // ...
164/// }
165/// // or, fallback:
166/// #[derive(Probabilistic)]
167/// pub struct MyParty { rng: MyRngType, ... }
168/// ```
169#[proc_macro_derive(Probabilistic, attributes(rng))]
170pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
171 party_traits::derive_probabilistic(input)
172}
173
174/// Procedural macro to automatically implement the `Scribe` trait for a struct.
175///
176/// # Requirements
177/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
178/// `transcript` that implements `Transcript`.
179///
180/// # Behavior
181/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
182/// field.
183///
184/// # Usage
185/// ```rust
186/// #[derive(Scribe)]
187/// pub struct MyParty {
188/// #[transcript]
189/// my_transcript: MyTranscriptType,
190/// // ...
191/// }
192/// // or, fallback:
193/// #[derive(Scribe)]
194/// pub struct MyParty { transcript: MyTranscriptType, ... }
195/// ```
196#[proc_macro_derive(Scribe, attributes(transcript))]
197pub fn derive_scribe(input: TokenStream) -> TokenStream {
198 party_traits::derive_scribe(input)
199}
200
201/// Procedural macro to automatically implement the `Peer` trait for a struct.
202///
203/// # Requirements
204/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
205/// `peer_ctx` whose type is PeerContext.
206///
207/// # Behavior
208/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
209///
210/// # Usage
211/// ```rust
212/// #[derive(Peer)]
213/// pub struct MyParty {
214/// #[peer_ctx]
215/// peer_ctx: PeerContext,
216/// // ...
217/// }
218/// // or, fallback to type matching:
219/// #[derive(Peer)]
220/// pub struct MyParty { peer_ctx: PeerContext, ... }
221/// ```
222#[proc_macro_derive(Peer, attributes(peer_ctx))]
223pub fn derive_peer(input: TokenStream) -> TokenStream {
224 party_traits::derive_peer(input)
225}
226
227/// Procedural macro to expose a private member function for given cfg tags.
228///
229/// This macro creates a public wrapper function with an underscore prefix that
230/// calls the original private function. The wrapper is only compiled when the
231/// specified cfg condition is met (e.g., feature flags or test mode).
232///
233/// # Usage
234/// ```rust
235/// use macros::public;
236///
237/// struct MyStruct;
238///
239/// impl MyStruct {
240/// #[public(feature = "dev")]
241/// fn private_method(&mut self, x: i32) -> i32 {
242/// x * 2
243/// }
244/// }
245///
246/// // Generates:
247/// // #[cfg(feature = "dev")]
248/// // pub fn _private_method(&mut self, x: i32) -> i32 {
249/// // self.private_method(x)
250/// // }
251/// ```
252#[proc_macro_attribute]
253pub fn public(attr: TokenStream, item: TokenStream) -> TokenStream {
254 public_fn::public_fn(attr, item)
255}
256
257/// Procedural macro to generate variants for arithmetic operations.
258///
259/// Supports binary operations (like `Add`), binary assign operations (like
260/// `AddAssign`), and unary operations (like `Neg`).
261///
262/// For **binary operations** (impl with `-> Self::Output`), it requires `Self op &RHS`;
263/// for **unary operations** it assumes `op Self`.
264///
265/// There are three supported variants:
266/// - `owned`: All operands owned - implemented by referencing rhs.
267/// - `borrowed`: All operands borrowed - implemented by cloning Self.
268/// - `flipped`: `&Self op RHS` - implemented by cloning lhs and referencing rhs (binary ops only).
269/// - `flipped_commutative`: `&Self op RHS` - implemented as `rhs.op(lhs)`. To be used only in
270/// commutative ops (e.g., Add & Mul, but not Sub | Div).
271///
272/// # Usage
273/// ```rust
274/// #[derive(Clone, Debug)]
275/// struct Point {
276/// x: i32,
277/// y: i32,
278/// }
279///
280/// // Binary operation: Point + &Point (base impl)
281/// #[macros::op_variants(owned, borrowed, flipped)]
282/// impl std::ops::Add<&Point> for Point {
283/// type Output = Point;
284/// fn add(self, other: &Point) -> Point {
285/// Point {
286/// x: self.x + other.x,
287/// y: self.y + other.y,
288/// }
289/// }
290/// }
291/// // Generates: Point + Point, &Point + &Point, &Point + Point
292///
293/// // Binary assign operation: Point += &Point (base impl)
294/// #[macros::op_variants(owned)]
295/// impl std::ops::AddAssign<&Point> for Point {
296/// fn add_assign(&mut self, other: &Point) {
297/// self.x += other.x;
298/// self.y += other.y;
299/// }
300/// }
301/// // Generates: Point += Point
302///
303/// // Unary operation: -&Point (base impl)
304/// #[macros::op_variants(owned)]
305/// impl std::ops::Neg for &Point {
306/// type Output = Point;
307/// fn neg(self) -> Point {
308/// Point {
309/// x: -self.x,
310/// y: -self.y,
311/// }
312/// }
313/// }
314/// // Generates: -Point
315/// ```
316#[proc_macro_attribute]
317pub fn op_variants(attr: TokenStream, item: TokenStream) -> TokenStream {
318 op_variants::op_variants(attr, item)
319}