round_based/sim/
mod.rs

1//! Multiparty protocol simulation
2//!
3//! Simulator is an essential developer tool for testing the multiparty protocol locally.
4//! It covers most of the boilerplate of emulating MPC protocol execution.
5//!
6//! The entry point is either [`run`] or [`run_with_setup`] functions. They take a protocol
7//! defined as an async function, provide simulated networking, carry out the simulation,
8//! and return the result.
9//!
//! If you need more control over execution, you can use [`Simulation`]. For instance, it allows
//! creating a simulation whose parties are defined by different functions. This is helpful, e.g.,
//! when simulating a protocol in the presence of an adversary: one set of parties can be defined
//! by the regular protocol implementation, while the other set of parties is defined by a function
//! that emulates adversarial behavior.
15//!
16//! ## Limitations
//! [`Simulation`] works by converting each party (defined as an async function) into a
//! [state machine](crate::state_machine). This should work without problems in most cases and
//! provides better UX, as no async runtime is required (the simulation is entirely synchronous).
20//!
//! However, a protocol wrapped into a state machine cannot poll any futures except those provided
//! by [`MpcParty`](crate::MpcParty) (so it can only await on sending/receiving messages and
//! yielding). For instance, if the protocol implementation makes use of tokio timers, it will
//! result in an execution error.
25//!
//! In general, we recommend not awaiting on futures that aren't provided by `MpcParty` in the MPC
//! protocol implementation, so that the implementation stays runtime-agnostic.
28//!
//! If you really need to use unsupported futures, you can use [`async_env`] instead, which runs
//! the simulation on the tokio runtime but has its own limitations.
31//!
32//! ## Example
33//! ```rust,no_run
34//! use round_based::{Mpc, PartyIndex};
35//!
36//! # type Result<T, E = ()> = std::result::Result<T, E>;
37//! # type Randomness = [u8; 32];
38//! # #[derive(round_based::ProtocolMsg, Clone)]
39//! # enum Msg {}
40//! // Any MPC protocol you want to test
41//! pub async fn protocol_of_random_generation<M>(
42//!     party: M,
43//!     i: PartyIndex,
44//!     n: u16
45//! ) -> Result<Randomness>
46//! where
47//!     M: Mpc<Msg = Msg>
48//! {
49//!     // ...
50//! # todo!()
51//! }
52//!
53//! let n = 3;
54//!
55//! let output = round_based::sim::run(
56//!     n,
57//!     |i, party| protocol_of_random_generation(party, i, n),
58//! )
59//! .unwrap()
60//! // unwrap `Result`s
61//! .expect_ok()
62//! // check that all parties produced the same response
63//! .expect_eq();
64//!
65//! println!("Output randomness: {}", hex::encode(output));
66//! ```
67
68use alloc::{boxed::Box, collections::VecDeque, string::ToString, vec::Vec};
69use core::future::Future;
70
71use crate::{
72    Incoming, MessageDestination, MessageType, Outgoing, ProtocolMsg, state_machine::ProceedResult,
73};
74
75#[cfg(feature = "sim-async")]
76pub mod async_env;
77
/// Result of the simulation
///
/// Contains the output of every party, ordered by party index: `self.0[i]` is the output
/// of party `i`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimResult<T>(pub Vec<T>);

impl<T, E> SimResult<Result<T, E>>
where
    E: core::fmt::Debug,
{
    /// Unwraps `Result<T, E>` produced by each party
    ///
    /// Returns the unwrapped outputs, preserving the order of parties.
    ///
    /// # Panics
    /// Panics if at least one of the parties returned `Err(_)`. In this case, a verbose
    /// error message is shown, specifying which of the parties returned an error.
    pub fn expect_ok(self) -> SimResult<T> {
        let mut oks = Vec::with_capacity(self.0.len());
        let mut errs = Vec::new();

        // `enumerate` (rather than zipping with a `u16` counter) keeps party indexing
        // overflow-free for any number of results
        for (i, res) in self.0.into_iter().enumerate() {
            match res {
                Ok(out) => oks.push(out),
                Err(err) => errs.push((i, err)),
            }
        }

        if !errs.is_empty() {
            let mut msg = alloc::format!(
                "Simulation output didn't match expectations.\n\
                Expected: all parties succeed\n\
                Actual  : {success} parties succeeded, {failed} parties returned an error\n\
                Failures:\n",
                success = oks.len(),
                failed = errs.len(),
            );

            for (i, err) in errs {
                msg += &alloc::format!("- Party {i}: {err:?}\n");
            }

            panic!("{msg}");
        }

        SimResult(oks)
    }
}

impl<T> SimResult<T>
where
    T: PartialEq + core::fmt::Debug,
{
    /// Checks that all parties produced the same output
    ///
    /// Returns the common output on success (all the outputs are checked to be equal).
    ///
    /// # Panics
    /// * If the simulation contained zero parties
    /// * If some of the parties returned a different output. The panic message groups the
    ///   parties by the output they returned.
    pub fn expect_eq(mut self) -> T {
        let Some(first) = self.0.first() else {
            panic!("simulation contained zero parties");
        };

        if !self.0[1..].iter().all(|output| output == first) {
            let mut msg = alloc::string::String::from(
                "Simulation output didn't match expectations.\n\
                Expected: all parties return the same output\n\
                Actual  : some of the parties returned a different output\n\
                Outputs :\n",
            );

            // Group parties that returned equal outputs, in order of first appearance
            let mut clusters: Vec<(&T, Vec<usize>)> = Vec::new();
            for (i, value) in self.0.iter().enumerate() {
                match clusters
                    .iter_mut()
                    .find(|(cluster_value, _)| *cluster_value == value)
                {
                    Some((_, indexes)) => indexes.push(i),
                    None => clusters.push((value, alloc::vec![i])),
                }
            }

            for (value, parties) in &clusters {
                let label = if parties.len() == 1 { "Party" } else { "Parties" };
                let indexes = parties
                    .iter()
                    .map(ToString::to_string)
                    .collect::<Vec<_>>()
                    .join(", ");
                msg += &alloc::format!("- {label} {indexes}: {value:?}\n");
            }

            panic!("{msg}")
        }

        self.0
            .pop()
            .expect("we checked that the list contains at least one element")
    }
}

impl<T> SimResult<T> {
    /// Deconstructs the simulation result, returning the inner list of outputs
    pub fn into_vec(self) -> Vec<T> {
        self.0
    }
}

impl<T> IntoIterator for SimResult<T> {
    type Item = T;
    type IntoIter = alloc::vec::IntoIter<T>;
    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

impl<T> core::ops::Deref for SimResult<T> {
    type Target = [T];
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> From<Vec<T>> for SimResult<T> {
    fn from(list: Vec<T>) -> Self {
        Self(list)
    }
}

impl<T> From<SimResult<T>> for Vec<T> {
    fn from(res: SimResult<T>) -> Self {
        res.0
    }
}
219
220/// Simulates MPC protocol with parties defined as [state machines](crate::state_machine)
221pub struct Simulation<'a, O, M> {
222    parties: Vec<Party<'a, O, M>>,
223}
224
225enum Party<'a, O, M> {
226    Active {
227        party: Box<dyn crate::state_machine::StateMachine<Output = O, Msg = M> + 'a>,
228        wants_one_more_msg: bool,
229    },
230    Finished(O),
231}
232
233impl<'a, O, M> Simulation<'a, O, M>
234where
235    M: ProtocolMsg + Clone + 'static,
236{
237    /// Creates empty simulation containing no parties
238    ///
239    /// New parties can be added via [`.add_party()`](Self::add_party)
240    pub fn empty() -> Self {
241        Self {
242            parties: Vec::new(),
243        }
244    }
245
246    /// Constructs empty simulation containing no parties, with allocated memory that can fit up to `n` parties without re-allocations
247    pub fn with_capacity(n: u16) -> Self {
248        Self {
249            parties: Vec::with_capacity(n.into()),
250        }
251    }
252
253    /// Constructs a simulation with `n` parties from async function that defines the protocol
254    ///
255    /// Each party has index `0 <= i < n` and instantiated via provided `init` function
256    ///
257    /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
258    /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
259    /// to this function).
260    pub fn from_async_fn<F>(
261        n: u16,
262        mut init: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
263    ) -> Self
264    where
265        F: core::future::Future<Output = O> + 'a,
266    {
267        let mut sim = Self::with_capacity(n);
268        for i in 0..n {
269            sim.add_async_party(|party| init(i, party))
270        }
271        sim
272    }
273
274    /// Construct a simulation with `n` parties from `init` function that constructs state machine for each party
275    ///
276    /// Each party has index `0 <= i < n` and instantiated via provided `init` function
277    pub fn from_fn<S>(n: u16, mut init: impl FnMut(u16) -> S) -> Self
278    where
279        S: crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
280    {
281        let mut sim = Self::with_capacity(n);
282        for i in 0..n {
283            sim.add_party(init(i));
284        }
285        sim
286    }
287
288    /// Adds new party into the protocol
289    ///
290    /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
291    /// simulation after this party was added.
292    pub fn add_party(
293        &mut self,
294        party: impl crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
295    ) {
296        self.parties.push(Party::Active {
297            party: Box::new(party),
298            wants_one_more_msg: false,
299        })
300    }
301
302    /// Adds new party, defined as an async function, into the protocol
303    ///
304    /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
305    /// simulation after this party was added.
306    ///
307    /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
308    /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
309    /// to this function).
310    pub fn add_async_party<F>(&mut self, party: impl FnOnce(crate::state_machine::MpcParty<M>) -> F)
311    where
312        F: core::future::Future<Output = O> + 'a,
313    {
314        self.parties.push(Party::Active {
315            party: Box::new(crate::state_machine::wrap_protocol(party)),
316            wants_one_more_msg: false,
317        })
318    }
319
320    /// Returns amount of parties in the simulation
321    pub fn parties_amount(&self) -> usize {
322        self.parties.len()
323    }
324
325    /// Carries out the simulation
326    pub fn run(mut self) -> Result<SimResult<O>, SimError> {
327        let mut messages_queue = MessagesQueue::new(self.parties.len());
328        let mut parties_left = self.parties.len();
329
330        while parties_left > 0 {
331            'next_party: for (i, party_state) in (0..).zip(&mut self.parties) {
332                'this_party: loop {
333                    let Party::Active {
334                        party,
335                        wants_one_more_msg,
336                    } = party_state
337                    else {
338                        continue 'next_party;
339                    };
340
341                    if *wants_one_more_msg {
342                        if let Some(message) = messages_queue.recv_next_msg(i) {
343                            party
344                                .received_msg(message)
345                                .map_err(|_| Reason::SaveIncomingMsg)?;
346                            *wants_one_more_msg = false;
347                        } else {
348                            continue 'next_party;
349                        }
350                    }
351
352                    match party.proceed() {
353                        ProceedResult::SendMsg(msg) => {
354                            messages_queue.send_message(i, msg)?;
355                            continue 'this_party;
356                        }
357                        ProceedResult::NeedsOneMoreMessage => {
358                            *wants_one_more_msg = true;
359                            continue 'this_party;
360                        }
361                        ProceedResult::Output(out) => {
362                            *party_state = Party::Finished(out);
363                            parties_left -= 1;
364                            continue 'next_party;
365                        }
366                        ProceedResult::Yielded => {
367                            continue 'this_party;
368                        }
369                        ProceedResult::Error(err) => {
370                            return Err(Reason::ExecutionError(err).into());
371                        }
372                    }
373                }
374            }
375        }
376
377        Ok(SimResult(
378            self.parties
379                .into_iter()
380                .map(|party| match party {
381                    Party::Active { .. } => {
382                        unreachable!("there must be no active parties when `parties_left == 0`")
383                    }
384                    Party::Finished(out) => out,
385                })
386                .collect(),
387        ))
388    }
389}
390
391/// Error indicating that simulation failed
392#[derive(Debug, thiserror::Error)]
393#[error(transparent)]
394pub struct SimError(#[from] Reason);
395
396#[derive(Debug, thiserror::Error)]
397enum Reason {
398    #[error("save incoming message")]
399    SaveIncomingMsg,
400    #[error("execution error")]
401    ExecutionError(#[source] crate::state_machine::ExecutionError),
402    #[error("party #{sender} tried to send a message to non existing party #{recipient}")]
403    UnknownRecipient { sender: u16, recipient: u16 },
404}
405
406struct MessagesQueue<M> {
407    queue: Vec<VecDeque<Incoming<M>>>,
408    next_id: u64,
409}
410
411impl<M: Clone> MessagesQueue<M> {
412    fn new(n: usize) -> Self {
413        Self {
414            queue: alloc::vec![VecDeque::new(); n],
415            next_id: 0,
416        }
417    }
418
419    fn send_message(&mut self, sender: u16, msg: Outgoing<M>) -> Result<(), SimError> {
420        match msg.recipient {
421            MessageDestination::AllParties { reliable } => {
422                let mut msg_ids = self.next_id..;
423                for (destination, msg_id) in (0..)
424                    .zip(&mut self.queue)
425                    .filter(|(recipient_index, _)| *recipient_index != sender)
426                    .map(|(_, msg)| msg)
427                    .zip(msg_ids.by_ref())
428                {
429                    destination.push_back(Incoming {
430                        id: msg_id,
431                        sender,
432                        msg_type: MessageType::Broadcast { reliable },
433                        msg: msg.msg.clone(),
434                    })
435                }
436                self.next_id = msg_ids.next().unwrap();
437            }
438            MessageDestination::OneParty(destination) => {
439                let next_id = self.next_id;
440                self.next_id += 1;
441
442                self.queue
443                    .get_mut(usize::from(destination))
444                    .ok_or(Reason::UnknownRecipient {
445                        sender,
446                        recipient: destination,
447                    })?
448                    .push_back(Incoming {
449                        id: next_id,
450                        sender,
451                        msg_type: MessageType::P2P,
452                        msg: msg.msg,
453                    })
454            }
455        }
456
457        Ok(())
458    }
459
460    fn recv_next_msg(&mut self, recipient: u16) -> Option<Incoming<M>> {
461        self.queue[usize::from(recipient)].pop_front()
462    }
463}
464
465/// Simulates execution of the protocol
466///
467/// Takes amount of participants, and a function that carries out the protocol for
468/// one party. The function takes as input: index of the party, and [`MpcParty`](crate::MpcParty)
469/// that can be used to communicate with others.
470///
471/// ## Example
472/// ```rust,no_run
473/// use round_based::{Mpc, PartyIndex};
474///
475/// # type Result<T, E = ()> = std::result::Result<T, E>;
476/// # type Randomness = [u8; 32];
477/// # #[derive(round_based::ProtocolMsg, Clone)]
478/// # enum Msg {}
479/// // Any MPC protocol you want to test
480/// pub async fn protocol_of_random_generation<M>(
481///     party: M,
482///     i: PartyIndex,
483///     n: u16
484/// ) -> Result<Randomness>
485/// where
486///     M: Mpc<Msg = Msg>
487/// {
488///     // ...
489/// # todo!()
490/// }
491///
492/// let n = 3;
493///
494/// let output = round_based::sim::run(
495///     n,
496///     |i, party| protocol_of_random_generation(party, i, n),
497/// )
498/// .unwrap()
499/// // unwrap `Result`s
500/// .expect_ok()
501/// // check that all parties produced the same response
502/// .expect_eq();
503///
504/// println!("Output randomness: {}", hex::encode(output));
505/// ```
506pub fn run<M, F>(
507    n: u16,
508    mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
509) -> Result<SimResult<F::Output>, SimError>
510where
511    M: ProtocolMsg + Clone + 'static,
512    F: Future,
513{
514    run_with_setup(core::iter::repeat_n((), n.into()), |i, party, ()| {
515        party_start(i, party)
516    })
517}
518
519/// Simulates execution of the protocol
520///
521/// Similar to [`run`], but allows some setup to be provided to the protocol execution
522/// function.
523///
524/// Simulation will have as many parties as `setups` iterator yields
525///
526/// ## Example
527/// ```rust,no_run
528/// use round_based::{Mpc, PartyIndex};
529///
530/// # type Result<T, E = ()> = std::result::Result<T, E>;
531/// # type Randomness = [u8; 32];
532/// # #[derive(round_based::ProtocolMsg, Clone)]
533/// # enum Msg {}
534/// // Any MPC protocol you want to test
535/// pub async fn protocol_of_random_generation<M>(
536///     rng: impl rand::RngCore,
537///     party: M,
538///     i: PartyIndex,
539///     n: u16
540/// ) -> Result<Randomness>
541/// where
542///     M: Mpc<Msg = Msg>
543/// {
544///     // ...
545/// # todo!()
546/// }
547///
548/// let mut rng = rand_dev::DevRng::new();
549/// let n = 3;
550/// let output = round_based::sim::run_with_setup(
551///     core::iter::repeat_with(|| rng.fork()).take(n.into()),
552///     |i, party, rng| protocol_of_random_generation(rng, party, i, n),
553/// )
554/// .unwrap()
555/// // unwrap `Result`s
556/// .expect_ok()
557/// // check that all parties produced the same response
558/// .expect_eq();
559///
560/// println!("Output randomness: {}", hex::encode(output));
561/// ```
562pub fn run_with_setup<S, M, F>(
563    setups: impl IntoIterator<Item = S>,
564    mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>, S) -> F,
565) -> Result<SimResult<F::Output>, SimError>
566where
567    M: ProtocolMsg + Clone + 'static,
568    F: Future,
569{
570    let mut sim = Simulation::empty();
571
572    for (setup, i) in setups.into_iter().zip(0u16..) {
573        let party = crate::state_machine::wrap_protocol(|party| party_start(i, party, setup));
574        sim.add_party(party);
575    }
576
577    sim.run()
578}
579
#[cfg(test)]
mod tests {
    // Panic tests use `expected = ...` so that they only pass when the intended check panics,
    // not when something unrelated (e.g. an out-of-bounds index) does.

    mod expect_eq {
        use crate::sim::SimResult;

        #[test]
        fn all_eq() {
            let res = SimResult::from(alloc::vec!["same string", "same string", "same string"])
                .expect_eq();
            assert_eq!(res, "same string")
        }

        #[test]
        fn single_party() {
            let res = SimResult::from(alloc::vec![42]).expect_eq();
            assert_eq!(res, 42)
        }

        #[test]
        #[should_panic(expected = "simulation contained zero parties")]
        fn empty_res() {
            SimResult::from(alloc::vec![]).expect_eq()
        }

        #[test]
        #[should_panic(expected = "- Parties 0, 1, 3: \"one result\"")]
        fn not_eq() {
            SimResult::from(alloc::vec![
                "one result",
                "one result",
                "another result",
                "one result",
                "and something else",
            ])
            .expect_eq();
        }
    }

    mod expect_ok {
        use crate::sim::SimResult;

        #[test]
        fn all_ok() {
            let res = SimResult::<Result<i32, core::convert::Infallible>>::from(alloc::vec![
                Ok(0),
                Ok(1),
                Ok(2)
            ])
            .expect_ok()
            .into_vec();

            assert_eq!(res, [0, 1, 2]);
        }

        #[test]
        #[should_panic(expected = "- Party 1: \"i couldn't do what you asked :(\"")]
        fn not_ok() {
            SimResult::from(alloc::vec![
                Ok(0),
                Err("i couldn't do what you asked :("),
                Ok(2),
                Ok(3),
                Err("sorry I was pooping, what did you want?")
            ])
            .expect_ok();
        }
    }
}