round_based/sim/async_env.rs
1//! Fully async simulation
2//!
//! The simulation provided in the [parent module](super) should be used in most cases. It works
//! by converting all parties (defined as async functions) into [state machines](crate::state_machine),
//! which has certain limitations. In particular, the protocol cannot await on any futures that
//! aren't provided by [`MpcParty`]; for instance, awaiting on a timer will cause a simulation error.
7//!
//! We suggest avoiding awaiting on futures that aren't provided by `MpcParty` in the MPC protocol
//! implementation, as it likely makes the implementation runtime-dependent. However, if you
//! ultimately do need that, you can't use the regular simulation for the tests.
11//!
//! This module provides a fully async simulation built for the tokio runtime, so the protocol can
//! await on any futures supported by tokio.
14//!
15//! ## Limitations
//! To implement the simulated [network](Network), we use [`tokio::sync::broadcast`] channels, which
//! have an internal buffer of stored messages. Once the simulated network receives more messages
//! than the internal buffer can fit (before all parties have read them), some of the parties will
//! miss some of the messages, which leads to an execution error.
20//!
//! By default, the internal buffer is preallocated to fit 500 messages, which should be more than
//! sufficient for simulating protocols with a small number of parties (say, < 10).
23//!
//! If you need a bigger buffer, use [`Network::with_capacity`], [`run_with_capacity`], or
//! [`run_with_capacity_and_setup`].
25//!
26//! ## Example
//! Entry points to the simulation are the [`run`] and [`run_with_setup`] functions:
28//!
29//! ```rust,no_run
30//! # #[tokio::main(flavor = "current_thread")]
31//! # async fn main() {
32//! use round_based::{Mpc, PartyIndex};
33//!
34//! # type Result<T, E = ()> = std::result::Result<T, E>;
35//! # type Randomness = [u8; 32];
36//! # #[derive(round_based::ProtocolMsg, Clone)]
37//! # enum Msg {}
38//! // Any MPC protocol you want to test
39//! pub async fn protocol_of_random_generation<M>(
40//! party: M,
41//! i: PartyIndex,
42//! n: u16
43//! ) -> Result<Randomness>
44//! where
45//! M: Mpc<Msg = Msg>
46//! {
47//! // ...
48//! # todo!()
49//! }
50//!
51//! let n = 3;
52//!
53//! let output = round_based::sim::async_env::run(
54//! n,
55//! |i, party| protocol_of_random_generation(party, i, n),
56//! )
57//! .await
58//! // unwrap `Result`s
59//! .expect_ok()
60//! // check that all parties produced the same response
61//! .expect_eq();
62//!
63//! println!("Output randomness: {}", hex::encode(output));
64//! # }
65//! ```
66use alloc::sync::Arc;
67use core::{
68 future::Future,
69 pin::Pin,
70 sync::atomic::AtomicU64,
71 task::ready,
72 task::{Context, Poll},
73};
74
75use futures_util::{Sink, Stream};
76use tokio::sync::broadcast;
77use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};
78
79use crate::{MessageDestination, MessageType, MpcParty, MsgId, PartyIndex};
80use crate::{
81 ProtocolMsg,
82 delivery::{Incoming, Outgoing},
83};
84
85use super::SimResult;
86
/// Number of messages the simulated network's broadcast buffer can hold unless a custom
/// capacity is requested
const DEFAULT_CAPACITY: usize = 500;
88
89/// Simulated async network
90pub struct Network<M> {
91 channel: broadcast::Sender<Outgoing<Incoming<M>>>,
92 next_party_idx: PartyIndex,
93 next_msg_id: Arc<NextMessageId>,
94}
95
96impl<M> Network<M>
97where
98 M: ProtocolMsg + Clone + Send + Unpin + 'static,
99{
100 /// Instantiates a new simulation
101 pub fn new() -> Self {
102 Self::with_capacity(500)
103 }
104
105 /// Instantiates a new simulation with given capacity
106 ///
107 /// `Simulation` stores internally all sent messages. Capacity limits size of the internal buffer.
108 /// Because of that you might run into error if you choose too small capacity. Choose capacity
109 /// that can fit all the messages sent by all the parties during entire protocol lifetime.
110 ///
111 /// Default capacity is 500 (i.e. if you call `Simulation::new()`)
112 pub fn with_capacity(capacity: usize) -> Self {
113 Self {
114 channel: broadcast::channel(capacity).0,
115 next_party_idx: 0,
116 next_msg_id: Default::default(),
117 }
118 }
119
120 /// Adds new party to the network
121 pub fn add_party(&mut self) -> MpcParty<M, MockedDelivery<M>> {
122 MpcParty::connected(self.connect_new_party())
123 }
124
125 /// Connects new party to the network
126 ///
127 /// Similar to [`.add_party()`](Self::add_party) but returns `MockedDelivery<M>` instead of
128 /// `MpcParty<M, MockedDelivery<M>>`
129 pub fn connect_new_party(&mut self) -> MockedDelivery<M> {
130 let local_party_idx = self.next_party_idx;
131 self.next_party_idx += 1;
132
133 MockedDelivery::new(
134 MockedIncoming {
135 local_party_idx,
136 receiver: BroadcastStream::new(self.channel.subscribe()),
137 },
138 MockedOutgoing {
139 local_party_idx,
140 sender: self.channel.clone(),
141 next_msg_id: self.next_msg_id.clone(),
142 },
143 )
144 }
145}
146
147impl<M> Default for Network<M>
148where
149 M: ProtocolMsg + Clone + Send + Unpin + 'static,
150{
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156/// Mocked networking
157pub type MockedDelivery<M> = crate::mpc::Halves<MockedIncoming<M>, MockedOutgoing<M>>;
158
159/// Delivery error
160#[derive(Debug, thiserror::Error)]
161pub enum MockedDeliveryError {
162 /// Error occurred when sending a message
163 #[error(transparent)]
164 Recv(BroadcastStreamRecvError),
165 /// Error occurred when receiving a message
166 #[error(transparent)]
167 Send(broadcast::error::SendError<()>),
168}
169
170/// Incoming channel of mocked network
171pub struct MockedIncoming<M> {
172 local_party_idx: PartyIndex,
173 receiver: BroadcastStream<Outgoing<Incoming<M>>>,
174}
175
176impl<M> Stream for MockedIncoming<M>
177where
178 M: Clone + Send + 'static,
179{
180 type Item = Result<Incoming<M>, MockedDeliveryError>;
181
182 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183 loop {
184 let msg = match ready!(Pin::new(&mut self.receiver).poll_next(cx)) {
185 Some(Ok(m)) => m,
186 Some(Err(e)) => return Poll::Ready(Some(Err(MockedDeliveryError::Recv(e)))),
187 None => return Poll::Ready(None),
188 };
189 if msg.recipient.is_p2p()
190 && msg.recipient != MessageDestination::OneParty(self.local_party_idx)
191 {
192 continue;
193 }
194 return Poll::Ready(Some(Ok(msg.msg)));
195 }
196 }
197}
198
199/// Outgoing channel of mocked network
200pub struct MockedOutgoing<M> {
201 local_party_idx: PartyIndex,
202 sender: broadcast::Sender<Outgoing<Incoming<M>>>,
203 next_msg_id: Arc<NextMessageId>,
204}
205
206impl<M> Sink<Outgoing<M>> for MockedOutgoing<M> {
207 type Error = MockedDeliveryError;
208
209 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
210 Poll::Ready(Ok(()))
211 }
212
213 fn start_send(self: Pin<&mut Self>, msg: Outgoing<M>) -> Result<(), Self::Error> {
214 let msg_type = match msg.recipient {
215 MessageDestination::AllParties { reliable } => MessageType::Broadcast { reliable },
216 MessageDestination::OneParty(_) => MessageType::P2P,
217 };
218 self.sender
219 .send(msg.map(|m| Incoming {
220 id: self.next_msg_id.next(),
221 sender: self.local_party_idx,
222 msg_type,
223 msg: m,
224 }))
225 .map_err(|_| MockedDeliveryError::Send(broadcast::error::SendError(())))?;
226 Ok(())
227 }
228
229 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
230 Poll::Ready(Ok(()))
231 }
232
233 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
234 Poll::Ready(Ok(()))
235 }
236}
237
238#[derive(Default)]
239struct NextMessageId(AtomicU64);
240
241impl NextMessageId {
242 pub fn next(&self) -> MsgId {
243 self.0.fetch_add(1, core::sync::atomic::Ordering::Relaxed)
244 }
245}
246
247/// Simulates execution of the protocol
248///
249/// Takes amount of participants, and a function that carries out the protocol for
250/// one party. The function takes as input: index of the party, and [`MpcParty`]
251/// that can be used to communicate with others.
252///
253/// ## Example
254/// ```rust,no_run
255/// # #[tokio::main(flavor = "current_thread")]
256/// # async fn main() {
257/// use round_based::{Mpc, PartyIndex};
258///
259/// # type Result<T, E = ()> = std::result::Result<T, E>;
260/// # type Randomness = [u8; 32];
261/// # #[derive(round_based::ProtocolMsg, Clone)]
262/// # enum Msg {}
263/// // Any MPC protocol you want to test
264/// pub async fn protocol_of_random_generation<M>(
265/// party: M,
266/// i: PartyIndex,
267/// n: u16
268/// ) -> Result<Randomness>
269/// where
270/// M: Mpc<Msg = Msg>
271/// {
272/// // ...
273/// # todo!()
274/// }
275///
276/// let n = 3;
277///
278/// let output = round_based::sim::async_env::run(
279/// n,
280/// |i, party| protocol_of_random_generation(party, i, n),
281/// )
282/// .await
283/// // unwrap `Result`s
284/// .expect_ok()
285/// // check that all parties produced the same response
286/// .expect_eq();
287///
288/// println!("Output randomness: {}", hex::encode(output));
289/// # }
290/// ```
291pub async fn run<M, F>(
292 n: u16,
293 party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
294) -> SimResult<F::Output>
295where
296 M: ProtocolMsg + Clone + Send + Unpin + 'static,
297 F: Future,
298{
299 run_with_capacity(DEFAULT_CAPACITY, n, party_start).await
300}
301
302/// Simulates execution of the protocol
303///
304/// Same as [`run`] but also takes a capacity of internal buffer to be used
305/// within simulated network. Size of internal buffer should fit total amount of the
306/// messages sent by all participants during the whole protocol execution.
307pub async fn run_with_capacity<M, F>(
308 capacity: usize,
309 n: u16,
310 mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
311) -> SimResult<F::Output>
312where
313 M: ProtocolMsg + Clone + Send + Unpin + 'static,
314 F: Future,
315{
316 run_with_capacity_and_setup(
317 capacity,
318 core::iter::repeat_n((), n.into()),
319 |i, party, ()| party_start(i, party),
320 )
321 .await
322}
323
324/// Simulates execution of the protocol
325///
326/// Similar to [`run`], but allows some setup to be provided to the protocol execution
327/// function.
328///
329/// Simulation will have as many parties as `setups` iterator yields
330///
331/// ## Example
332/// ```rust,no_run
333/// # #[tokio::main(flavor = "current_thread")]
334/// # async fn main() {
335/// use round_based::{Mpc, PartyIndex};
336///
337/// # type Result<T, E = ()> = std::result::Result<T, E>;
338/// # type Randomness = [u8; 32];
339/// # #[derive(round_based::ProtocolMsg, Clone)]
340/// # enum Msg {}
341/// // Any MPC protocol you want to test
342/// pub async fn protocol_of_random_generation<M>(
343/// rng: impl rand::RngCore,
344/// party: M,
345/// i: PartyIndex,
346/// n: u16
347/// ) -> Result<Randomness>
348/// where
349/// M: Mpc<Msg = Msg>
350/// {
351/// // ...
352/// # todo!()
353/// }
354///
355/// let mut rng = rand_dev::DevRng::new();
356/// let n = 3;
357/// let output = round_based::sim::async_env::run_with_setup(
358/// core::iter::repeat_with(|| rng.fork()).take(n.into()),
359/// |i, party, rng| protocol_of_random_generation(rng, party, i, n),
360/// )
361/// .await
362/// // unwrap `Result`s
363/// .expect_ok()
364/// // check that all parties produced the same response
365/// .expect_eq();
366///
367/// println!("Output randomness: {}", hex::encode(output));
368/// # }
369/// ```
370pub async fn run_with_setup<S, M, F>(
371 setups: impl IntoIterator<Item = S>,
372 party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
373) -> SimResult<F::Output>
374where
375 M: ProtocolMsg + Clone + Send + Unpin + 'static,
376 F: Future,
377{
378 run_with_capacity_and_setup::<S, M, F>(DEFAULT_CAPACITY, setups, party_start).await
379}
380
381/// Simulates execution of the protocol
382///
383/// Same as [`run_with_setup`] but also takes a capacity of internal buffer to be used
384/// within simulated network. Size of internal buffer should fit total amount of the
385/// messages sent by all participants during the whole protocol execution.
386pub async fn run_with_capacity_and_setup<S, M, F>(
387 capacity: usize,
388 setups: impl IntoIterator<Item = S>,
389 mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
390) -> SimResult<F::Output>
391where
392 M: ProtocolMsg + Clone + Send + Unpin + 'static,
393 F: Future,
394{
395 let mut network = Network::<M>::with_capacity(capacity);
396
397 let mut output = alloc::vec![];
398 for (setup, i) in setups.into_iter().zip(0u16..) {
399 output.push({
400 let party = network.add_party();
401 party_start(i, party, setup)
402 });
403 }
404
405 let result = futures_util::future::join_all(output).await;
406 SimResult(result)
407}