round_based/sim/
async_env.rs

//! Fully async simulation
//!
//! The simulation provided in the [parent module](super) should be used in most cases. It works
//! by converting all parties (defined as async functions) into [state machines](crate::state_machine),
//! which has certain limitations. In particular, the protocol cannot await any futures that
//! aren't provided by [`MpcParty`]; for instance, awaiting a timer will cause a simulation error.
//!
//! We suggest not awaiting futures that aren't provided by `MpcParty` in the MPC protocol
//! implementation, as doing so likely makes the implementation runtime-dependent. However, if you
//! ultimately do need to do that, then you can't use the regular simulation for the tests.
//!
//! This module provides a fully async simulation built for the tokio runtime, so the protocol can
//! await any futures supported by tokio.
//!
//! ## Limitations
//! To implement the simulated [network](Network), we use [`tokio::sync::broadcast`] channels, which
//! have an internal buffer of stored messages. Once the simulated network receives more messages
//! than the internal buffer can fit, some of the parties will not receive some of the messages,
//! which will lead to an execution error.
//!
//! By default, the internal buffer is preallocated to fit 500 messages, which should be more than
//! sufficient for simulating protocols with a small number of parties (say, < 10).
//!
//! If you need to preallocate a bigger buffer, use [`Network::with_capacity`] (or the
//! [`run_with_capacity`] / [`run_with_capacity_and_setup`] functions).
//!
//! ## Example
//! The entry points to the simulation are the [`run`] and [`run_with_setup`] functions:
//!
//! ```rust,no_run
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! use round_based::{Mpc, PartyIndex};
//!
//! # type Result<T, E = ()> = std::result::Result<T, E>;
//! # type Randomness = [u8; 32];
//! # #[derive(round_based::ProtocolMsg, Clone)]
//! # enum Msg {}
//! // Any MPC protocol you want to test
//! pub async fn protocol_of_random_generation<M>(
//!     party: M,
//!     i: PartyIndex,
//!     n: u16
//! ) -> Result<Randomness>
//! where
//!     M: Mpc<Msg = Msg>
//! {
//!     // ...
//! # todo!()
//! }
//!
//! let n = 3;
//!
//! let output = round_based::sim::async_env::run(
//!     n,
//!     |i, party| protocol_of_random_generation(party, i, n),
//! )
//! .await
//! // unwrap `Result`s
//! .expect_ok()
//! // check that all parties produced the same response
//! .expect_eq();
//!
//! println!("Output randomness: {}", hex::encode(output));
//! # }
//! ```
66use alloc::sync::Arc;
67use core::{
68    future::Future,
69    pin::Pin,
70    sync::atomic::AtomicU64,
71    task::ready,
72    task::{Context, Poll},
73};
74
75use futures_util::{Sink, Stream};
76use tokio::sync::broadcast;
77use tokio_stream::wrappers::{BroadcastStream, errors::BroadcastStreamRecvError};
78
79use crate::{MessageDestination, MessageType, MpcParty, MsgId, PartyIndex};
80use crate::{
81    ProtocolMsg,
82    delivery::{Incoming, Outgoing},
83};
84
85use super::SimResult;
86
87const DEFAULT_CAPACITY: usize = 500;
88
89/// Simulated async network
90pub struct Network<M> {
91    channel: broadcast::Sender<Outgoing<Incoming<M>>>,
92    next_party_idx: PartyIndex,
93    next_msg_id: Arc<NextMessageId>,
94}
95
96impl<M> Network<M>
97where
98    M: ProtocolMsg + Clone + Send + Unpin + 'static,
99{
100    /// Instantiates a new simulation
101    pub fn new() -> Self {
102        Self::with_capacity(500)
103    }
104
105    /// Instantiates a new simulation with given capacity
106    ///
107    /// `Simulation` stores internally all sent messages. Capacity limits size of the internal buffer.
108    /// Because of that you might run into error if you choose too small capacity. Choose capacity
109    /// that can fit all the messages sent by all the parties during entire protocol lifetime.
110    ///
111    /// Default capacity is 500 (i.e. if you call `Simulation::new()`)
112    pub fn with_capacity(capacity: usize) -> Self {
113        Self {
114            channel: broadcast::channel(capacity).0,
115            next_party_idx: 0,
116            next_msg_id: Default::default(),
117        }
118    }
119
120    /// Adds new party to the network
121    pub fn add_party(&mut self) -> MpcParty<M, MockedDelivery<M>> {
122        MpcParty::connected(self.connect_new_party())
123    }
124
125    /// Connects new party to the network
126    ///
127    /// Similar to [`.add_party()`](Self::add_party) but returns `MockedDelivery<M>` instead of
128    /// `MpcParty<M, MockedDelivery<M>>`
129    pub fn connect_new_party(&mut self) -> MockedDelivery<M> {
130        let local_party_idx = self.next_party_idx;
131        self.next_party_idx += 1;
132
133        MockedDelivery::new(
134            MockedIncoming {
135                local_party_idx,
136                receiver: BroadcastStream::new(self.channel.subscribe()),
137            },
138            MockedOutgoing {
139                local_party_idx,
140                sender: self.channel.clone(),
141                next_msg_id: self.next_msg_id.clone(),
142            },
143        )
144    }
145}
146
147impl<M> Default for Network<M>
148where
149    M: ProtocolMsg + Clone + Send + Unpin + 'static,
150{
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156/// Mocked networking
157pub type MockedDelivery<M> = crate::mpc::Halves<MockedIncoming<M>, MockedOutgoing<M>>;
158
159/// Delivery error
160#[derive(Debug, thiserror::Error)]
161pub enum MockedDeliveryError {
162    /// Error occurred when sending a message
163    #[error(transparent)]
164    Recv(BroadcastStreamRecvError),
165    /// Error occurred when receiving a message
166    #[error(transparent)]
167    Send(broadcast::error::SendError<()>),
168}
169
170/// Incoming channel of mocked network
171pub struct MockedIncoming<M> {
172    local_party_idx: PartyIndex,
173    receiver: BroadcastStream<Outgoing<Incoming<M>>>,
174}
175
176impl<M> Stream for MockedIncoming<M>
177where
178    M: Clone + Send + 'static,
179{
180    type Item = Result<Incoming<M>, MockedDeliveryError>;
181
182    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183        loop {
184            let msg = match ready!(Pin::new(&mut self.receiver).poll_next(cx)) {
185                Some(Ok(m)) => m,
186                Some(Err(e)) => return Poll::Ready(Some(Err(MockedDeliveryError::Recv(e)))),
187                None => return Poll::Ready(None),
188            };
189            if msg.recipient.is_p2p()
190                && msg.recipient != MessageDestination::OneParty(self.local_party_idx)
191            {
192                continue;
193            }
194            return Poll::Ready(Some(Ok(msg.msg)));
195        }
196    }
197}
198
199/// Outgoing channel of mocked network
200pub struct MockedOutgoing<M> {
201    local_party_idx: PartyIndex,
202    sender: broadcast::Sender<Outgoing<Incoming<M>>>,
203    next_msg_id: Arc<NextMessageId>,
204}
205
206impl<M> Sink<Outgoing<M>> for MockedOutgoing<M> {
207    type Error = MockedDeliveryError;
208
209    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
210        Poll::Ready(Ok(()))
211    }
212
213    fn start_send(self: Pin<&mut Self>, msg: Outgoing<M>) -> Result<(), Self::Error> {
214        let msg_type = match msg.recipient {
215            MessageDestination::AllParties { reliable } => MessageType::Broadcast { reliable },
216            MessageDestination::OneParty(_) => MessageType::P2P,
217        };
218        self.sender
219            .send(msg.map(|m| Incoming {
220                id: self.next_msg_id.next(),
221                sender: self.local_party_idx,
222                msg_type,
223                msg: m,
224            }))
225            .map_err(|_| MockedDeliveryError::Send(broadcast::error::SendError(())))?;
226        Ok(())
227    }
228
229    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
230        Poll::Ready(Ok(()))
231    }
232
233    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
234        Poll::Ready(Ok(()))
235    }
236}
237
238#[derive(Default)]
239struct NextMessageId(AtomicU64);
240
241impl NextMessageId {
242    pub fn next(&self) -> MsgId {
243        self.0.fetch_add(1, core::sync::atomic::Ordering::Relaxed)
244    }
245}
246
247/// Simulates execution of the protocol
248///
249/// Takes amount of participants, and a function that carries out the protocol for
250/// one party. The function takes as input: index of the party, and [`MpcParty`]
251/// that can be used to communicate with others.
252///
253/// ## Example
254/// ```rust,no_run
255/// # #[tokio::main(flavor = "current_thread")]
256/// # async fn main() {
257/// use round_based::{Mpc, PartyIndex};
258///
259/// # type Result<T, E = ()> = std::result::Result<T, E>;
260/// # type Randomness = [u8; 32];
261/// # #[derive(round_based::ProtocolMsg, Clone)]
262/// # enum Msg {}
263/// // Any MPC protocol you want to test
264/// pub async fn protocol_of_random_generation<M>(
265///     party: M,
266///     i: PartyIndex,
267///     n: u16
268/// ) -> Result<Randomness>
269/// where
270///     M: Mpc<Msg = Msg>
271/// {
272///     // ...
273/// # todo!()
274/// }
275///
276/// let n = 3;
277///
278/// let output = round_based::sim::async_env::run(
279///     n,
280///     |i, party| protocol_of_random_generation(party, i, n),
281/// )
282/// .await
283/// // unwrap `Result`s
284/// .expect_ok()
285/// // check that all parties produced the same response
286/// .expect_eq();
287///
288/// println!("Output randomness: {}", hex::encode(output));
289/// # }  
290/// ```
291pub async fn run<M, F>(
292    n: u16,
293    party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
294) -> SimResult<F::Output>
295where
296    M: ProtocolMsg + Clone + Send + Unpin + 'static,
297    F: Future,
298{
299    run_with_capacity(DEFAULT_CAPACITY, n, party_start).await
300}
301
302/// Simulates execution of the protocol
303///
304/// Same as [`run`] but also takes a capacity of internal buffer to be used
305/// within simulated network. Size of internal buffer should fit total amount of the
306/// messages sent by all participants during the whole protocol execution.
307pub async fn run_with_capacity<M, F>(
308    capacity: usize,
309    n: u16,
310    mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>) -> F,
311) -> SimResult<F::Output>
312where
313    M: ProtocolMsg + Clone + Send + Unpin + 'static,
314    F: Future,
315{
316    run_with_capacity_and_setup(
317        capacity,
318        core::iter::repeat_n((), n.into()),
319        |i, party, ()| party_start(i, party),
320    )
321    .await
322}
323
324/// Simulates execution of the protocol
325///
326/// Similar to [`run`], but allows some setup to be provided to the protocol execution
327/// function.
328///
329/// Simulation will have as many parties as `setups` iterator yields
330///
331/// ## Example
332/// ```rust,no_run
333/// # #[tokio::main(flavor = "current_thread")]
334/// # async fn main() {
335/// use round_based::{Mpc, PartyIndex};
336///
337/// # type Result<T, E = ()> = std::result::Result<T, E>;
338/// # type Randomness = [u8; 32];
339/// # #[derive(round_based::ProtocolMsg, Clone)]
340/// # enum Msg {}
341/// // Any MPC protocol you want to test
342/// pub async fn protocol_of_random_generation<M>(
343///     rng: impl rand::RngCore,
344///     party: M,
345///     i: PartyIndex,
346///     n: u16
347/// ) -> Result<Randomness>
348/// where
349///     M: Mpc<Msg = Msg>
350/// {
351///     // ...
352/// # todo!()
353/// }
354///
355/// let mut rng = rand_dev::DevRng::new();
356/// let n = 3;
357/// let output = round_based::sim::async_env::run_with_setup(
358///     core::iter::repeat_with(|| rng.fork()).take(n.into()),
359///     |i, party, rng| protocol_of_random_generation(rng, party, i, n),
360/// )
361/// .await
362/// // unwrap `Result`s
363/// .expect_ok()
364/// // check that all parties produced the same response
365/// .expect_eq();
366///
367/// println!("Output randomness: {}", hex::encode(output));
368/// # }  
369/// ```
370pub async fn run_with_setup<S, M, F>(
371    setups: impl IntoIterator<Item = S>,
372    party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
373) -> SimResult<F::Output>
374where
375    M: ProtocolMsg + Clone + Send + Unpin + 'static,
376    F: Future,
377{
378    run_with_capacity_and_setup::<S, M, F>(DEFAULT_CAPACITY, setups, party_start).await
379}
380
381/// Simulates execution of the protocol
382///
383/// Same as [`run_with_setup`] but also takes a capacity of internal buffer to be used
384/// within simulated network. Size of internal buffer should fit total amount of the
385/// messages sent by all participants during the whole protocol execution.
386pub async fn run_with_capacity_and_setup<S, M, F>(
387    capacity: usize,
388    setups: impl IntoIterator<Item = S>,
389    mut party_start: impl FnMut(u16, MpcParty<M, MockedDelivery<M>>, S) -> F,
390) -> SimResult<F::Output>
391where
392    M: ProtocolMsg + Clone + Send + Unpin + 'static,
393    F: Future,
394{
395    let mut network = Network::<M>::with_capacity(capacity);
396
397    let mut output = alloc::vec![];
398    for (setup, i) in setups.into_iter().zip(0u16..) {
399        output.push({
400            let party = network.add_party();
401            party_start(i, party, setup)
402        });
403    }
404
405    let result = futures_util::future::join_all(output).await;
406    SimResult(result)
407}