round_based/sim/
async_env.rs

1//! Fully async simulation
2//!
//! The simulation provided in the [parent module](super) should be used in most cases. It works
//! by converting all parties (defined as async functions) into [state machines](crate::state_machine),
//! which has certain limitations. In particular, the protocol cannot await any futures that
//! aren't provided by [`MpcParty`]; for instance, awaiting a timer will cause a simulation error.
7//!
//! We suggest avoiding awaiting futures that aren't provided by `MpcParty` in the MPC protocol
//! implementation, as doing so likely makes it runtime-dependent. However, if you ultimately need
//! to do that, then you can't use the regular simulation in your tests.
11//!
//! This module provides a fully async simulation built for the tokio runtime, so the protocol can
//! await any futures supported by tokio.
14//!
15//! ## Limitations
//! To implement the simulated [network](Network), we use [`tokio::sync::broadcast`] channels, which
//! have an internal buffer of stored messages. Once the simulated network receives more messages
//! than the internal buffer can fit, some of the parties will miss some of the messages, which
//! leads to an execution error.
20//!
//! By default, the internal buffer is preallocated to fit 500 messages, which should be more than
//! sufficient for simulating protocols with a small number of parties (say, < 10).
23//!
//! If you need to preallocate a bigger buffer, use [`Network::with_capacity`], [`run_with_capacity`]
//! or [`run_with_capacity_and_setup`].
25//!
26//! ## Example
//! The entry points to the simulation are the [`run`] and [`run_with_setup`] functions.
28//!
29//! ```rust,no_run
30//! # #[tokio::main(flavor = "current_thread")]
31//! # async fn main() {
32//! use round_based::{Mpc, PartyIndex};
33//!
34//! # type Result<T, E = ()> = std::result::Result<T, E>;
35//! # type Randomness = [u8; 32];
36//! # type Msg = ();
37//! // Any MPC protocol you want to test
38//! pub async fn protocol_of_random_generation<M>(
39//!     party: M,
40//!     i: PartyIndex,
41//!     n: u16
42//! ) -> Result<Randomness>
43//! where
44//!     M: Mpc<ProtocolMessage = Msg>
45//! {
46//!     // ...
47//! # todo!()
48//! }
49//!
50//! let n = 3;
51//!
52//! let output = round_based::sim::async_env::run(
53//!     n,
54//!     |i, party| protocol_of_random_generation(party, i, n),
55//! )
56//! .await
57//! // unwrap `Result`s
58//! .expect_ok()
59//! // check that all parties produced the same response
60//! .expect_eq();
61//!
62//! println!("Output randomness: {}", hex::encode(output));
63//! # }  
64//! ```
65use alloc::sync::Arc;
66use core::{
67    future::Future,
68    pin::Pin,
69    sync::atomic::AtomicU64,
70    task::ready,
71    task::{Context, Poll},
72};
73
74use futures_util::{Sink, Stream};
75use tokio::sync::broadcast;
76use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream};
77
78use crate::delivery::{Delivery, Incoming, Outgoing};
79use crate::{MessageDestination, MessageType, MpcParty, MsgId, PartyIndex};
80
81use super::SimResult;
82
/// Default capacity (in messages) of the simulated network's internal buffer, used by [`run`]
/// and [`run_with_setup`]
const DEFAULT_CAPACITY: usize = 500;
84
/// Simulated async network
///
/// Every message sent by any party is pushed into a single shared broadcast channel. Each
/// connected party subscribes to that channel and keeps only the messages visible to it
/// (broadcasts, plus p2p messages addressed to it).
pub struct Network<M> {
    /// Sending half of the shared broadcast channel; parties publish to it and subscribe from it
    channel: broadcast::Sender<Outgoing<Incoming<M>>>,
    /// Index that will be assigned to the next party connected to the network
    next_party_idx: PartyIndex,
    /// Message id counter shared with the outgoing half of every connected party
    next_msg_id: Arc<NextMessageId>,
}
91
92impl<M> Network<M>
93where
94    M: Clone + Send + Unpin + 'static,
95{
96    /// Instantiates a new simulation
97    pub fn new() -> Self {
98        Self::with_capacity(500)
99    }
100
101    /// Instantiates a new simulation with given capacity
102    ///
103    /// `Simulation` stores internally all sent messages. Capacity limits size of the internal buffer.
104    /// Because of that you might run into error if you choose too small capacity. Choose capacity
105    /// that can fit all the messages sent by all the parties during entire protocol lifetime.
106    ///
107    /// Default capacity is 500 (i.e. if you call `Simulation::new()`)
108    pub fn with_capacity(capacity: usize) -> Self {
109        Self {
110            channel: broadcast::channel(capacity).0,
111            next_party_idx: 0,
112            next_msg_id: Default::default(),
113        }
114    }
115
116    /// Adds new party to the network
117    pub fn add_party(&mut self) -> MpcParty<M, MockedDelivery<M>> {
118        MpcParty::connected(self.connect_new_party())
119    }
120
121    /// Connects new party to the network
122    ///
123    /// Similar to [`.add_party()`](Self::add_party) but returns `MockedDelivery<M>` instead of
124    /// `MpcParty<M, MockedDelivery<M>>`
125    pub fn connect_new_party(&mut self) -> MockedDelivery<M> {
126        let local_party_idx = self.next_party_idx;
127        self.next_party_idx += 1;
128
129        MockedDelivery {
130            incoming: MockedIncoming {
131                local_party_idx,
132                receiver: BroadcastStream::new(self.channel.subscribe()),
133            },
134            outgoing: MockedOutgoing {
135                local_party_idx,
136                sender: self.channel.clone(),
137                next_msg_id: self.next_msg_id.clone(),
138            },
139        }
140    }
141}
142
143impl<M> Default for Network<M>
144where
145    M: Clone + Send + Unpin + 'static,
146{
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
/// Mocked networking
///
/// Connection of a single party to the simulated [`Network`], made of an incoming and an
/// outgoing half. Obtained via [`Network::connect_new_party`].
pub struct MockedDelivery<M> {
    /// Receiving half: stream of messages visible to this party
    incoming: MockedIncoming<M>,
    /// Sending half: sink publishing this party's messages to the network
    outgoing: MockedOutgoing<M>,
}
157
158impl<M> Delivery<M> for MockedDelivery<M>
159where
160    M: Clone + Send + Unpin + 'static,
161{
162    type Send = MockedOutgoing<M>;
163    type Receive = MockedIncoming<M>;
164    type SendError = broadcast::error::SendError<()>;
165    type ReceiveError = BroadcastStreamRecvError;
166
167    fn split(self) -> (Self::Receive, Self::Send) {
168        (self.incoming, self.outgoing)
169    }
170}
171
/// Incoming channel of mocked network
///
/// Stream of the messages visible to one party: every broadcast message plus the p2p messages
/// addressed to this party. Nothing filters by sender, so the party also receives its own
/// broadcast messages.
pub struct MockedIncoming<M> {
    /// Index of the party owning this channel, used to skip p2p messages meant for others
    local_party_idx: PartyIndex,
    /// Subscription to the network's shared broadcast channel
    receiver: BroadcastStream<Outgoing<Incoming<M>>>,
}
177
178impl<M> Stream for MockedIncoming<M>
179where
180    M: Clone + Send + 'static,
181{
182    type Item = Result<Incoming<M>, BroadcastStreamRecvError>;
183
184    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185        loop {
186            let msg = match ready!(Pin::new(&mut self.receiver).poll_next(cx)) {
187                Some(Ok(m)) => m,
188                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
189                None => return Poll::Ready(None),
190            };
191            if msg.recipient.is_p2p()
192                && msg.recipient != MessageDestination::OneParty(self.local_party_idx)
193            {
194                continue;
195            }
196            return Poll::Ready(Some(Ok(msg.msg)));
197        }
198    }
199}
200
/// Outgoing channel of mocked network
///
/// Sink that tags each outgoing message with a fresh message id, the sender's index and the
/// message type, then publishes it on the network's shared broadcast channel.
pub struct MockedOutgoing<M> {
    /// Index of the party owning this channel, recorded as the sender of every message
    local_party_idx: PartyIndex,
    /// Handle to the network's shared broadcast channel
    sender: broadcast::Sender<Outgoing<Incoming<M>>>,
    /// Message id counter shared with all other parties of the same network
    next_msg_id: Arc<NextMessageId>,
}
207
208impl<M> Sink<Outgoing<M>> for MockedOutgoing<M> {
209    type Error = broadcast::error::SendError<()>;
210
211    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212        Poll::Ready(Ok(()))
213    }
214
215    fn start_send(self: Pin<&mut Self>, msg: Outgoing<M>) -> Result<(), Self::Error> {
216        let msg_type = match msg.recipient {
217            MessageDestination::AllParties => MessageType::Broadcast,
218            MessageDestination::OneParty(_) => MessageType::P2P,
219        };
220        self.sender
221            .send(msg.map(|m| Incoming {
222                id: self.next_msg_id.next(),
223                sender: self.local_party_idx,
224                msg_type,
225                msg: m,
226            }))
227            .map_err(|_| broadcast::error::SendError(()))?;
228        Ok(())
229    }
230
231    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
232        Poll::Ready(Ok(()))
233    }
234
235    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
236        Poll::Ready(Ok(()))
237    }
238}
239
240#[derive(Default)]
241struct NextMessageId(AtomicU64);
242
243impl NextMessageId {
244    pub fn next(&self) -> MsgId {
245        self.0.fetch_add(1, core::sync::atomic::Ordering::Relaxed)
246    }
247}
248
249/// Simulates execution of the protocol
250///
251/// Takes amount of participants, and a function that carries out the protocol for
252/// one party. The function takes as input: index of the party, and [`MpcParty`]
253/// that can be used to communicate with others.
254///
255/// ## Example
256/// ```rust,no_run
257/// # #[tokio::main(flavor = "current_thread")]
258/// # async fn main() {
259/// use round_based::{Mpc, PartyIndex};
260///
261/// # type Result<T, E = ()> = std::result::Result<T, E>;
262/// # type Randomness = [u8; 32];
263/// # type Msg = ();
264/// // Any MPC protocol you want to test
265/// pub async fn protocol_of_random_generation<M>(
266///     party: M,
267///     i: PartyIndex,
268///     n: u16
269/// ) -> Result<Randomness>
270/// where
271///     M: Mpc<ProtocolMessage = Msg>
272/// {
273///     // ...
274/// # todo!()
275/// }
276///
277/// let n = 3;
278///
279/// let output = round_based::sim::async_env::run(
280///     n,
281///     |i, party| protocol_of_random_generation(party, i, n),
282/// )
283/// .await
284/// // unwrap `Result`s
285/// .expect_ok()
286/// // check that all parties produced the same response
287/// .expect_eq();
288///
289/// println!("Output randomness: {}", hex::encode(output));
290/// # }  
291/// ```
292pub async fn run<M, F>(
293    n: u16,
294    party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
295) -> SimResult<F::Output>
296where
297    M: Clone + Send + Unpin + 'static,
298    F: Future,
299{
300    run_with_capacity(DEFAULT_CAPACITY, n, party_start).await
301}
302
303/// Simulates execution of the protocol
304///
305/// Same as [`run`] but also takes a capacity of internal buffer to be used
306/// within simulated network. Size of internal buffer should fit total amount of the
307/// messages sent by all participants during the whole protocol execution.
308pub async fn run_with_capacity<M, F>(
309    capacity: usize,
310    n: u16,
311    mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
312) -> SimResult<F::Output>
313where
314    M: Clone + Send + Unpin + 'static,
315    F: Future,
316{
317    run_with_capacity_and_setup(
318        capacity,
319        core::iter::repeat(()).take(n.into()),
320        |i, party, ()| party_start(i, party),
321    )
322    .await
323}
324
325/// Simulates execution of the protocol
326///
327/// Similar to [`run`], but allows some setup to be provided to the protocol execution
328/// function.
329///
330/// Simulation will have as many parties as `setups` iterator yields
331///
332/// ## Example
333/// ```rust,no_run
334/// # #[tokio::main(flavor = "current_thread")]
335/// # async fn main() {
336/// use round_based::{Mpc, PartyIndex};
337///
338/// # type Result<T, E = ()> = std::result::Result<T, E>;
339/// # type Randomness = [u8; 32];
340/// # type Msg = ();
341/// // Any MPC protocol you want to test
342/// pub async fn protocol_of_random_generation<M>(
343///     rng: impl rand::RngCore,
344///     party: M,
345///     i: PartyIndex,
346///     n: u16
347/// ) -> Result<Randomness>
348/// where
349///     M: Mpc<ProtocolMessage = Msg>
350/// {
351///     // ...
352/// # todo!()
353/// }
354///
355/// let mut rng = rand_dev::DevRng::new();
356/// let n = 3;
357/// let output = round_based::sim::async_env::run_with_setup(
358///     core::iter::repeat_with(|| rng.fork()).take(n.into()),
359///     |i, party, rng| protocol_of_random_generation(rng, party, i, n),
360/// )
361/// .await
362/// // unwrap `Result`s
363/// .expect_ok()
364/// // check that all parties produced the same response
365/// .expect_eq();
366///
367/// println!("Output randomness: {}", hex::encode(output));
368/// # }  
369/// ```
370pub async fn run_with_setup<S, M, F>(
371    setups: impl IntoIterator<Item = S>,
372    party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
373) -> SimResult<F::Output>
374where
375    M: Clone + Send + Unpin + 'static,
376    F: Future,
377{
378    run_with_capacity_and_setup::<S, M, F>(DEFAULT_CAPACITY, setups, party_start).await
379}
380
381/// Simulates execution of the protocol
382///
383/// Same as [`run_with_setup`] but also takes a capacity of internal buffer to be used
384/// within simulated network. Size of internal buffer should fit total amount of the
385/// messages sent by all participants during the whole protocol execution.
386pub async fn run_with_capacity_and_setup<S, M, F>(
387    capacity: usize,
388    setups: impl IntoIterator<Item = S>,
389    mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
390) -> SimResult<F::Output>
391where
392    M: Clone + Send + Unpin + 'static,
393    F: Future,
394{
395    let mut network = Network::<M>::with_capacity(capacity);
396
397    let mut output = alloc::vec![];
398    for (setup, i) in setups.into_iter().zip(0u16..) {
399        output.push({
400            let party = network.add_party();
401            party_start(i, party, setup)
402        });
403    }
404
405    let result = futures_util::future::join_all(output).await;
406    SimResult(result)
407}