round_based/sim/async_env.rs
1//! Fully async simulation
2//!
3//! Simulation provided in a [parent module](super) should be used in most cases. It works
4//! by converting all parties (defined as async functions) into [state machines](crate::state_machine),
5//! which has certain limitations. In particular, the protocol cannot await on any futures that
6//! aren't provided by [`MpcParty`], for instance, awaiting on the timer will cause a simulation error.
7//!
8//! We suggest to avoid awaiting on the futures that aren't provided by `MpcParty` in the MPC protocol
9//! implementation as it likely makes it runtime-dependent. However, if you do ultimately need to
10//! do that, then you can't use regular simulation for the tests.
11//!
//! This module provides a fully async simulation built for the tokio runtime, so the protocol can
//! await on any futures supported by tokio.
14//!
15//! ## Limitations
//! To implement the simulated [network](Network), we use [`tokio::sync::broadcast`] channels, which
//! have an internal buffer of stored messages. Once the simulated network receives more messages
//! than the internal buffer can fit, some of the parties will not receive some of the messages,
//! which will lead to an execution error.
//!
//! By default, the internal buffer is preallocated to fit 500 messages, which should be more than
//! sufficient for simulating protocols with a small number of parties (say, < 10).
23//!
24//! If you need to preallocate bigger buffer, use [`Network::with_capacity`].
25//!
26//! ## Example
//! Entry points to the simulation are the [`run`] and [`run_with_setup`] functions
28//!
29//! ```rust,no_run
30//! # #[tokio::main(flavor = "current_thread")]
31//! # async fn main() {
32//! use round_based::{Mpc, PartyIndex};
33//!
34//! # type Result<T, E = ()> = std::result::Result<T, E>;
35//! # type Randomness = [u8; 32];
36//! # type Msg = ();
37//! // Any MPC protocol you want to test
38//! pub async fn protocol_of_random_generation<M>(
39//! party: M,
40//! i: PartyIndex,
41//! n: u16
42//! ) -> Result<Randomness>
43//! where
44//! M: Mpc<ProtocolMessage = Msg>
45//! {
46//! // ...
47//! # todo!()
48//! }
49//!
50//! let n = 3;
51//!
52//! let output = round_based::sim::async_env::run(
53//! n,
54//! |i, party| protocol_of_random_generation(party, i, n),
55//! )
56//! .await
57//! // unwrap `Result`s
58//! .expect_ok()
59//! // check that all parties produced the same response
60//! .expect_eq();
61//!
62//! println!("Output randomness: {}", hex::encode(output));
63//! # }
64//! ```
65use alloc::sync::Arc;
66use core::{
67 future::Future,
68 pin::Pin,
69 sync::atomic::AtomicU64,
70 task::ready,
71 task::{Context, Poll},
72};
73
74use futures_util::{Sink, Stream};
75use tokio::sync::broadcast;
76use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
77
78use crate::delivery::{Delivery, Incoming, Outgoing};
79use crate::{MessageDestination, MessageType, MpcParty, MsgId, PartyIndex};
80
81use super::SimResult;
82
/// Default capacity (in messages) of the simulated network's internal buffer, used when the
/// caller does not specify one
const DEFAULT_CAPACITY: usize = 500;
84
85/// Simulated async network
86pub struct Network<M> {
87 channel: broadcast::Sender<Outgoing<Incoming<M>>>,
88 next_party_idx: PartyIndex,
89 next_msg_id: Arc<NextMessageId>,
90}
91
92impl<M> Network<M>
93where
94 M: Clone + Send + Unpin + 'static,
95{
96 /// Instantiates a new simulation
97 pub fn new() -> Self {
98 Self::with_capacity(500)
99 }
100
101 /// Instantiates a new simulation with given capacity
102 ///
103 /// `Simulation` stores internally all sent messages. Capacity limits size of the internal buffer.
104 /// Because of that you might run into error if you choose too small capacity. Choose capacity
105 /// that can fit all the messages sent by all the parties during entire protocol lifetime.
106 ///
107 /// Default capacity is 500 (i.e. if you call `Simulation::new()`)
108 pub fn with_capacity(capacity: usize) -> Self {
109 Self {
110 channel: broadcast::channel(capacity).0,
111 next_party_idx: 0,
112 next_msg_id: Default::default(),
113 }
114 }
115
116 /// Adds new party to the network
117 pub fn add_party(&mut self) -> MpcParty<M, MockedDelivery<M>> {
118 MpcParty::connected(self.connect_new_party())
119 }
120
121 /// Connects new party to the network
122 ///
123 /// Similar to [`.add_party()`](Self::add_party) but returns `MockedDelivery<M>` instead of
124 /// `MpcParty<M, MockedDelivery<M>>`
125 pub fn connect_new_party(&mut self) -> MockedDelivery<M> {
126 let local_party_idx = self.next_party_idx;
127 self.next_party_idx += 1;
128
129 MockedDelivery {
130 incoming: MockedIncoming {
131 local_party_idx,
132 receiver: BroadcastStream::new(self.channel.subscribe()),
133 },
134 outgoing: MockedOutgoing {
135 local_party_idx,
136 sender: self.channel.clone(),
137 next_msg_id: self.next_msg_id.clone(),
138 },
139 }
140 }
141}
142
143impl<M> Default for Network<M>
144where
145 M: Clone + Send + Unpin + 'static,
146{
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
152/// Mocked networking
153pub struct MockedDelivery<M> {
154 incoming: MockedIncoming<M>,
155 outgoing: MockedOutgoing<M>,
156}
157
158impl<M> Delivery<M> for MockedDelivery<M>
159where
160 M: Clone + Send + Unpin + 'static,
161{
162 type Send = MockedOutgoing<M>;
163 type Receive = MockedIncoming<M>;
164 type SendError = broadcast::error::SendError<()>;
165 type ReceiveError = BroadcastStreamRecvError;
166
167 fn split(self) -> (Self::Receive, Self::Send) {
168 (self.incoming, self.outgoing)
169 }
170}
171
172/// Incoming channel of mocked network
173pub struct MockedIncoming<M> {
174 local_party_idx: PartyIndex,
175 receiver: BroadcastStream<Outgoing<Incoming<M>>>,
176}
177
178impl<M> Stream for MockedIncoming<M>
179where
180 M: Clone + Send + 'static,
181{
182 type Item = Result<Incoming<M>, BroadcastStreamRecvError>;
183
184 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185 loop {
186 let msg = match ready!(Pin::new(&mut self.receiver).poll_next(cx)) {
187 Some(Ok(m)) => m,
188 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
189 None => return Poll::Ready(None),
190 };
191 if msg.recipient.is_p2p()
192 && msg.recipient != MessageDestination::OneParty(self.local_party_idx)
193 {
194 continue;
195 }
196 return Poll::Ready(Some(Ok(msg.msg)));
197 }
198 }
199}
200
201/// Outgoing channel of mocked network
202pub struct MockedOutgoing<M> {
203 local_party_idx: PartyIndex,
204 sender: broadcast::Sender<Outgoing<Incoming<M>>>,
205 next_msg_id: Arc<NextMessageId>,
206}
207
208impl<M> Sink<Outgoing<M>> for MockedOutgoing<M> {
209 type Error = broadcast::error::SendError<()>;
210
211 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212 Poll::Ready(Ok(()))
213 }
214
215 fn start_send(self: Pin<&mut Self>, msg: Outgoing<M>) -> Result<(), Self::Error> {
216 let msg_type = match msg.recipient {
217 MessageDestination::AllParties => MessageType::Broadcast,
218 MessageDestination::OneParty(_) => MessageType::P2P,
219 };
220 self.sender
221 .send(msg.map(|m| Incoming {
222 id: self.next_msg_id.next(),
223 sender: self.local_party_idx,
224 msg_type,
225 msg: m,
226 }))
227 .map_err(|_| broadcast::error::SendError(()))?;
228 Ok(())
229 }
230
231 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
232 Poll::Ready(Ok(()))
233 }
234
235 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
236 Poll::Ready(Ok(()))
237 }
238}
239
240#[derive(Default)]
241struct NextMessageId(AtomicU64);
242
243impl NextMessageId {
244 pub fn next(&self) -> MsgId {
245 self.0.fetch_add(1, core::sync::atomic::Ordering::Relaxed)
246 }
247}
248
249/// Simulates execution of the protocol
250///
251/// Takes amount of participants, and a function that carries out the protocol for
252/// one party. The function takes as input: index of the party, and [`MpcParty`]
253/// that can be used to communicate with others.
254///
255/// ## Example
256/// ```rust,no_run
257/// # #[tokio::main(flavor = "current_thread")]
258/// # async fn main() {
259/// use round_based::{Mpc, PartyIndex};
260///
261/// # type Result<T, E = ()> = std::result::Result<T, E>;
262/// # type Randomness = [u8; 32];
263/// # type Msg = ();
264/// // Any MPC protocol you want to test
265/// pub async fn protocol_of_random_generation<M>(
266/// party: M,
267/// i: PartyIndex,
268/// n: u16
269/// ) -> Result<Randomness>
270/// where
271/// M: Mpc<ProtocolMessage = Msg>
272/// {
273/// // ...
274/// # todo!()
275/// }
276///
277/// let n = 3;
278///
279/// let output = round_based::sim::async_env::run(
280/// n,
281/// |i, party| protocol_of_random_generation(party, i, n),
282/// )
283/// .await
284/// // unwrap `Result`s
285/// .expect_ok()
286/// // check that all parties produced the same response
287/// .expect_eq();
288///
289/// println!("Output randomness: {}", hex::encode(output));
290/// # }
291/// ```
292pub async fn run<M, F>(
293 n: u16,
294 party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
295) -> SimResult<F::Output>
296where
297 M: Clone + Send + Unpin + 'static,
298 F: Future,
299{
300 run_with_capacity(DEFAULT_CAPACITY, n, party_start).await
301}
302
303/// Simulates execution of the protocol
304///
305/// Same as [`run`] but also takes a capacity of internal buffer to be used
306/// within simulated network. Size of internal buffer should fit total amount of the
307/// messages sent by all participants during the whole protocol execution.
308pub async fn run_with_capacity<M, F>(
309 capacity: usize,
310 n: u16,
311 mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
312) -> SimResult<F::Output>
313where
314 M: Clone + Send + Unpin + 'static,
315 F: Future,
316{
317 run_with_capacity_and_setup(
318 capacity,
319 core::iter::repeat(()).take(n.into()),
320 |i, party, ()| party_start(i, party),
321 )
322 .await
323}
324
325/// Simulates execution of the protocol
326///
327/// Similar to [`run`], but allows some setup to be provided to the protocol execution
328/// function.
329///
330/// Simulation will have as many parties as `setups` iterator yields
331///
332/// ## Example
333/// ```rust,no_run
334/// # #[tokio::main(flavor = "current_thread")]
335/// # async fn main() {
336/// use round_based::{Mpc, PartyIndex};
337///
338/// # type Result<T, E = ()> = std::result::Result<T, E>;
339/// # type Randomness = [u8; 32];
340/// # type Msg = ();
341/// // Any MPC protocol you want to test
342/// pub async fn protocol_of_random_generation<M>(
343/// rng: impl rand::RngCore,
344/// party: M,
345/// i: PartyIndex,
346/// n: u16
347/// ) -> Result<Randomness>
348/// where
349/// M: Mpc<ProtocolMessage = Msg>
350/// {
351/// // ...
352/// # todo!()
353/// }
354///
355/// let mut rng = rand_dev::DevRng::new();
356/// let n = 3;
357/// let output = round_based::sim::async_env::run_with_setup(
358/// core::iter::repeat_with(|| rng.fork()).take(n.into()),
359/// |i, party, rng| protocol_of_random_generation(rng, party, i, n),
360/// )
361/// .await
362/// // unwrap `Result`s
363/// .expect_ok()
364/// // check that all parties produced the same response
365/// .expect_eq();
366///
367/// println!("Output randomness: {}", hex::encode(output));
368/// # }
369/// ```
370pub async fn run_with_setup<S, M, F>(
371 setups: impl IntoIterator<Item = S>,
372 party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
373) -> SimResult<F::Output>
374where
375 M: Clone + Send + Unpin + 'static,
376 F: Future,
377{
378 run_with_capacity_and_setup::<S, M, F>(DEFAULT_CAPACITY, setups, party_start).await
379}
380
381/// Simulates execution of the protocol
382///
383/// Same as [`run_with_setup`] but also takes a capacity of internal buffer to be used
384/// within simulated network. Size of internal buffer should fit total amount of the
385/// messages sent by all participants during the whole protocol execution.
386pub async fn run_with_capacity_and_setup<S, M, F>(
387 capacity: usize,
388 setups: impl IntoIterator<Item = S>,
389 mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
390) -> SimResult<F::Output>
391where
392 M: Clone + Send + Unpin + 'static,
393 F: Future,
394{
395 let mut network = Network::<M>::with_capacity(capacity);
396
397 let mut output = alloc::vec![];
398 for (setup, i) in setups.into_iter().zip(0u16..) {
399 output.push({
400 let party = network.add_party();
401 party_start(i, party, setup)
402 });
403 }
404
405 let result = futures_util::future::join_all(output).await;
406 SimResult(result)
407}