mpc_macros/lib.rs
/*!
# Async-MPC Macros

This crate provides procedural macros for the async-mpc library.
*/
6
7mod define_task;
8mod gate_methods;
9mod party_traits;
10mod protocol_tags;
11mod public_fn;
12mod task_getters;
13
14use proc_macro::TokenStream;
15
16/// Generates async-mpc task implementations with automatic dependency management.
17///
18/// This macro takes a struct definition and a compute function, and generates:
19/// - An internal unresolved struct to hold task dependencies
20/// - A public type alias using the `TaskWrapper` type
21/// - A Task trait implementation with async execution logic
22/// - A `new` constructor for creating new task instances
23/// - A standalone `compute` method for direct calls
24///
25/// # Examples
26///
27/// ```rust
28/// define_task! {
29/// pub struct FieldAddTask<F: FieldExtension> {
30/// x: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
31/// y: Arc<dyn Task<Output = Arc<FieldShare<F>>>>,
32/// }
33///
34/// async fn compute(x: FieldShare<F>, y: FieldShare<F>) -> Result<FieldShare<F>, AbortError> {
35/// Ok(x + &y)
36/// }
37/// }
38/// ```
39#[proc_macro]
40pub fn define_task(input: TokenStream) -> TokenStream {
41 define_task::define_task(input)
42}
43
44/// Derives `try_get_*_task` functions for enum variants containing task types.
45///
46/// This macro automatically generates getter functions for each variant in a `TaskType` enum,
47/// eliminating the need for repetitive boilerplate code. These functions retrieve tasks based on
48/// their index (encoded in the Label's first field) and return the appropriate task type.
49///
50/// # Examples
51///
52/// ```rust
53/// #[derive(TaskGetters)]
54/// pub enum TaskType<C: Curve> {
55/// ScalarShareTask(Arc<dyn Task<Output = Arc<ScalarShare<C>>>>),
56/// BaseFieldShareTask(Arc<dyn Task<Output = Arc<BaseFieldShare<C>>>>),
57/// // ... more variants
58/// }
59/// ```
60///
61/// This generates functions like:
62/// - `try_get_scalar_share_task`
63/// - `try_get_base_field_share_task`
64/// - `try_get_point_share_task`
65///
66/// Each function has the signature:
67/// ```rust
68/// pub fn try_get_<variant_name>_task<C: Curve>(
69/// label: Label,
70/// task_map: &[TaskType<C>],
71/// ) -> Result<Arc<dyn Task<Output = Arc<T>>>, ProtocolError>
72/// ```
73///
74/// ## Requirements
75///
76/// - The enum must have variants with the pattern `*Task(Arc<dyn Task<Output = Arc<T>>>)`
77/// - The macro extracts the inner type `T` from `Arc<dyn Task<Output = Arc<T>>>`
78/// - Each variant must have a corresponding `try_as_*` method (typically generated by
79/// `EnumTryAsInner`)
80#[proc_macro_derive(TaskGetters)]
81pub fn derive_task_getters(input: TokenStream) -> TokenStream {
82 task_getters::derive_task_getters(input)
83}
84
85/// Derives `Gate::map_labels` and `Gate::for_each_label`
86#[proc_macro_derive(GateMethods)]
87pub fn derive_gate_methods(input: TokenStream) -> TokenStream {
88 gate_methods::derive_gate_methods(input)
89}
90
91/// Macro to ensure protocol tag uniqueness at compile time.
92/// This macro generates a unique tag for a given protocol name,
93/// and checks against a registry to prevent collisions.
94/// If a collision is detected, it finds the next available tag.
95#[proc_macro]
96pub fn new_protocol(input: TokenStream) -> TokenStream {
97 protocol_tags::new_protocol_info(input)
98}
99
100/// Macro to dump the current state of the protocol tag registry at compile time (for debugging)
101#[proc_macro]
102pub fn dump_protocol_tags(input: TokenStream) -> TokenStream {
103 protocol_tags::dump_tags(input)
104}
105/// Procedural macro to automatically implement the `Party` trait for a struct.
106///
107/// # Requirements
108/// - The struct must have either:
109/// - A field marked with `#[session_id]` or a field of type `SessionId`, or
110/// - At least one field annotated with `#[party]` (which must itself implement `Party`).
111///
112/// # Behavior
113/// - `session_id()`: Returns a reference to the `SessionId` field, or delegates to the first
114/// `#[party]` field.
115/// - `protocol_name()`: Uses the `PROTOCOL_INFO` constant if a direct `SessionId` field is present,
116/// otherwise formats as `"<TypeName> - <InnerProtocolName>"`.
117/// - `refresh()`: Calls `refresh_from` or `refresh_with` on the session id, and `refresh()` on all
118/// `#[party]` fields. If a `hasher` field is present, it is reset.
119///
120/// # Usage
121/// ```rust
122/// #[derive(Party)]
123/// pub struct MyParty {
124/// #[session_id]
125/// session_id: SessionId,
126/// #[party]
127/// inner: InnerPartyType,
128/// // ...
129/// }
130/// // or, with type/name matching fallback:
131/// #[derive(Party)]
132/// pub struct MyParty { session_id: SessionId, #[party] inner: InnerPartyType, ... }
133/// ```
134#[proc_macro_derive(Party, attributes(party, session_id, transcript, hasher))]
135pub fn derive_party(input: TokenStream) -> TokenStream {
136 party_traits::derive_party(input)
137}
138
139/// Procedural macro to automatically implement the `Probabilistic` trait for a struct.
140///
141/// # Requirements
142/// - The struct must have a field marked with `#[rng]` or, if not present, a field named `rng` that
143/// implements `CryptoRngCore`.
144///
145/// # Behavior
146/// - Implements the `rng(&mut self)` method, returning a mutable reference to the selected field.
147///
148/// # Usage
149/// ```rust
150/// #[derive(Probabilistic)]
151/// pub struct MyParty {
152/// #[rng]
153/// my_rng: MyRngType,
154/// // ...
155/// }
156/// // or, fallback:
157/// #[derive(Probabilistic)]
158/// pub struct MyParty { rng: MyRngType, ... }
159/// ```
160#[proc_macro_derive(Probabilistic, attributes(rng))]
161pub fn derive_probabilistic(input: TokenStream) -> TokenStream {
162 party_traits::derive_probabilistic(input)
163}
164
165/// Procedural macro to automatically implement the `Scribe` trait for a struct.
166///
167/// # Requirements
168/// - The struct must have a field marked with `#[transcript]` or, if not present, a field named
169/// `transcript` that implements `Transcript`.
170///
171/// # Behavior
172/// - Implements the `transcript(&mut self)` method, returning a mutable reference to the selected
173/// field.
174///
175/// # Usage
176/// ```rust
177/// #[derive(Scribe)]
178/// pub struct MyParty {
179/// #[transcript]
180/// my_transcript: MyTranscriptType,
181/// // ...
182/// }
183/// // or, fallback:
184/// #[derive(Scribe)]
185/// pub struct MyParty { transcript: MyTranscriptType, ... }
186/// ```
187#[proc_macro_derive(Scribe, attributes(transcript))]
188pub fn derive_scribe(input: TokenStream) -> TokenStream {
189 party_traits::derive_scribe(input)
190}
191
192/// Procedural macro to automatically implement the `HasTweakableHasher` trait for a struct.
193///
194/// # Requirements
195/// - The struct must have a field marked with `#[hasher]` or, if not present, a field named
196/// `hasher` whose type implements `TweakableHasher`.
197///
198/// # Behavior
199/// - Implements the `get_hasher(&mut self)` method, returning a mutable reference to the selected
200/// field.
201/// - Sets the associated type `Hasher` to the type of the selected field.
202///
203/// # Usage
204/// ```rust
205/// #[derive(HasTweakableHasher)]
206/// pub struct MyParty {
207/// #[hasher]
208/// my_hasher: MyHasherType,
209/// // ...
210/// }
211/// // or, fallback:
212/// #[derive(HasTweakableHasher)]
213/// pub struct MyParty { hasher: MyHasherType, ... }
214/// ```
215#[proc_macro_derive(HasTweakableHasher, attributes(hasher))]
216pub fn derive_has_tweakable_hasher(input: TokenStream) -> TokenStream {
217 party_traits::derive_has_tweakable_hasher(input)
218}
219
220/// Procedural macro to automatically implement the `Peer` trait for a struct.
221///
222/// # Requirements
223/// - The struct must have a field marked with `#[peer_ctx]` or, if not present, a field named
224/// `peer_ctx` whose type is PeerContext.
225///
226/// # Behavior
227/// - Implements the `peer_context(&self)` method, returning a reference to the selected field.
228///
229/// # Usage
230/// ```rust
231/// #[derive(Peer)]
232/// pub struct MyParty {
233/// #[peer_ctx]
234/// peer_ctx: PeerContext,
235/// // ...
236/// }
237/// // or, fallback to type matching:
238/// #[derive(Peer)]
239/// pub struct MyParty { peer_ctx: PeerContext, ... }
240/// ```
241#[proc_macro_derive(Peer, attributes(peer))]
242pub fn derive_peer(input: TokenStream) -> TokenStream {
243 party_traits::derive_peer(input)
244}
245
246/// Procedural macro to expose a private member function for given cfg tags.
247///
248/// This macro creates a public wrapper function with an underscore prefix that
249/// calls the original private function. The wrapper is only compiled when the
250/// specified cfg condition is met (e.g., feature flags or test mode).
251///
252/// # Usage
253/// ```rust
254/// use macros::public;
255///
256/// struct MyStruct;
257///
258/// impl MyStruct {
259/// #[public(feature = "dev")]
260/// fn private_method(&mut self, x: i32) -> i32 {
261/// x * 2
262/// }
263/// }
264///
265/// // Generates:
266/// // #[cfg(feature = "dev")]
267/// // pub fn _private_method(&mut self, x: i32) -> i32 {
268/// // self.private_method(x)
269/// // }
270/// ```
271#[proc_macro_attribute]
272pub fn public(attr: TokenStream, item: TokenStream) -> TokenStream {
273 public_fn::public_fn(attr, item)
274}