round_based/sim/
mod.rs

1//! Multiparty protocol simulation
2//!
3//! Simulator is an essential developer tool for testing the multiparty protocol locally.
4//! It covers most of the boilerplate of emulating MPC protocol execution.
5//!
6//! The entry point is either [`run`] or [`run_with_setup`] functions. They take a protocol
7//! defined as an async function, provide simulated networking, carry out the simulation,
8//! and return the result.
9//!
//! If you need more control over execution, you can use [`Simulation`]. For instance, it allows
//! creating a simulation whose parties are defined by different functions, which is helpful when
//! simulating the protocol in the presence of an adversary (e.g. one set of parties can be defined
//! with a regular function/protocol implementation, while the other set of parties may be defined
//! by another function which emulates adversary behavior).
15//!
16//! ## Limitations
//! [`Simulation`] works by converting each party (defined as an async function) into a
//! [state machine](crate::state_machine). That should work without problems in most cases,
//! providing better UX without requiring an async runtime (simulation is entirely sync).
20//!
//! However, a protocol wrapped into a state machine cannot poll any futures except those provided
//! by [`MpcParty`](crate::MpcParty) (so it can only await on sending/receiving messages and yielding).
//! For instance, if the protocol implementation makes use of tokio timers, it will result in an
//! execution error.
25//!
26//! In general, we do not recommend awaiting on the futures that aren't provided by `MpcParty` in
27//! the MPC protocol implementation, to keep the protocol implementation runtime-agnostic.
28//!
//! If you really need to make use of unsupported futures, you can use [`async_env`] instead,
30//! which provides a simulation on tokio runtime, but has its own limitations.
31//!
32//! ## Example
33//! ```rust,no_run
34//! use round_based::{Mpc, PartyIndex};
35//!
36//! # type Result<T, E = ()> = std::result::Result<T, E>;
37//! # type Randomness = [u8; 32];
38//! # type Msg = ();
39//! // Any MPC protocol you want to test
40//! pub async fn protocol_of_random_generation<M>(
41//!     party: M,
42//!     i: PartyIndex,
43//!     n: u16
44//! ) -> Result<Randomness>
45//! where
46//!     M: Mpc<ProtocolMessage = Msg>
47//! {
48//!     // ...
49//! # todo!()
50//! }
51//!
52//! let n = 3;
53//!
54//! let output = round_based::sim::run(
55//!     n,
56//!     |i, party| protocol_of_random_generation(party, i, n),
57//! )
58//! .unwrap()
59//! // unwrap `Result`s
60//! .expect_ok()
61//! // check that all parties produced the same response
62//! .expect_eq();
63//!
64//! println!("Output randomness: {}", hex::encode(output));
65//! ```
66
67use alloc::{boxed::Box, collections::VecDeque, string::ToString, vec::Vec};
68use core::future::Future;
69
70use crate::{state_machine::ProceedResult, Incoming, MessageDestination, MessageType, Outgoing};
71
72#[cfg(feature = "sim-async")]
73pub mod async_env;
74
/// Result of the simulation: the output of every party
///
/// When produced by [`Simulation::run`], the `i`-th element is the output of party `i`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimResult<T>(pub Vec<T>);

impl<T, E> SimResult<Result<T, E>>
where
    E: core::fmt::Debug,
{
    /// Unwraps `Result<T, E>` produced by each party
    ///
    /// # Panics
    ///
    /// Panics if at least one of the parties returned `Err(_)`. In this case,
    /// a verbose error message will be shown specifying which of the parties returned
    /// an error.
    pub fn expect_ok(self) -> SimResult<T> {
        let mut oks = Vec::with_capacity(self.0.len());
        let mut errs = Vec::new();

        // `enumerate` yields `usize` indexes. Zipping with `0u16..` would overflow past
        // `u16::MAX` elements (panicking in debug builds, wrapping in release builds).
        for (i, res) in self.0.into_iter().enumerate() {
            match res {
                Ok(out) => oks.push(out),
                Err(err) => errs.push((i, err)),
            }
        }

        if !errs.is_empty() {
            let mut msg = alloc::format!(
                "Simulation output didn't match expectations.\n\
                Expected: all parties succeed\n\
                Actual  : {success} parties succeeded, {failed} parties returned an error\n\
                Failures:\n",
                success = oks.len(),
                failed = errs.len(),
            );

            for (i, err) in errs {
                msg += &alloc::format!("- Party {i}: {err:?}\n");
            }

            panic!("{msg}");
        }

        SimResult(oks)
    }
}

impl<T> SimResult<T>
where
    T: PartialEq + core::fmt::Debug,
{
    /// Checks that all parties produced the same output
    ///
    /// Returns the output on success (all the outputs are checked to be the same), otherwise
    /// panics with a verbose error message that groups parties by the output they returned.
    ///
    /// # Panics
    ///
    /// Panics if the outputs differ, or if the simulation contained zero parties.
    pub fn expect_eq(mut self) -> T {
        let Some(first) = self.0.first() else {
            panic!("simulation contained zero parties");
        };

        if !self.0[1..].iter().all(|output| output == first) {
            let mut msg = alloc::string::String::from(
                "Simulation output didn't match expectations.\n\
                Expected: all parties return the same output\n\
                Actual  : some of the parties returned a different output\n\
                Outputs :\n",
            );

            // Group parties by the output they returned; clusters keep the order in which
            // each distinct output first appeared
            let mut clusters: Vec<(&T, Vec<usize>)> = Vec::new();
            for (i, value) in self.0.iter().enumerate() {
                match clusters
                    .iter_mut()
                    .find(|(cluster_value, _)| *cluster_value == value)
                {
                    Some((_, indexes)) => indexes.push(i),
                    None => clusters.push((value, alloc::vec![i])),
                }
            }

            for (value, parties) in &clusters {
                msg += if parties.len() == 1 {
                    "- Party "
                } else {
                    "- Parties "
                };
                msg += &parties
                    .iter()
                    .map(ToString::to_string)
                    .collect::<Vec<_>>()
                    .join(", ");
                msg += &alloc::format!(": {value:?}\n");
            }

            panic!("{msg}")
        }

        self.0
            .pop()
            .expect("we checked that the list contains at least one element")
    }
}

impl<T> SimResult<T> {
    /// Deconstructs the simulation result returning inner list of results
    pub fn into_vec(self) -> Vec<T> {
        self.0
    }
}

impl<T> IntoIterator for SimResult<T> {
    type Item = T;
    type IntoIter = alloc::vec::IntoIter<T>;
    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

impl<T> core::ops::Deref for SimResult<T> {
    type Target = [T];
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T> From<Vec<T>> for SimResult<T> {
    fn from(list: Vec<T>) -> Self {
        Self(list)
    }
}

impl<T> From<SimResult<T>> for Vec<T> {
    fn from(res: SimResult<T>) -> Self {
        res.0
    }
}
216
217/// Simulates MPC protocol with parties defined as [state machines](crate::state_machine)
218pub struct Simulation<'a, O, M> {
219    parties: Vec<Party<'a, O, M>>,
220}
221
222enum Party<'a, O, M> {
223    Active {
224        party: Box<dyn crate::state_machine::StateMachine<Output = O, Msg = M> + 'a>,
225        wants_one_more_msg: bool,
226    },
227    Finished(O),
228}
229
230impl<'a, O, M> Simulation<'a, O, M>
231where
232    M: Clone + 'static,
233{
234    /// Creates empty simulation containing no parties
235    ///
236    /// New parties can be added via [`.add_party()`](Self::add_party)
237    pub fn empty() -> Self {
238        Self {
239            parties: Vec::new(),
240        }
241    }
242
243    /// Constructs empty simulation containing no parties, with allocated memory that can fit up to `n` parties without re-allocations
244    pub fn with_capacity(n: u16) -> Self {
245        Self {
246            parties: Vec::with_capacity(n.into()),
247        }
248    }
249
250    /// Constructs a simulation with `n` parties from async function that defines the protocol
251    ///
252    /// Each party has index `0 <= i < n` and instantiated via provided `init` function
253    ///
254    /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
255    /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
256    /// to this function).
257    pub fn from_async_fn<F>(
258        n: u16,
259        mut init: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
260    ) -> Self
261    where
262        F: core::future::Future<Output = O> + 'a,
263    {
264        let mut sim = Self::with_capacity(n);
265        for i in 0..n {
266            sim.add_async_party(|party| init(i, party))
267        }
268        sim
269    }
270
271    /// Construct a simulation with `n` parties from `init` function that constructs state machine for each party
272    ///
273    /// Each party has index `0 <= i < n` and instantiated via provided `init` function
274    pub fn from_fn<S>(n: u16, mut init: impl FnMut(u16) -> S) -> Self
275    where
276        S: crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
277    {
278        let mut sim = Self::with_capacity(n);
279        for i in 0..n {
280            sim.add_party(init(i));
281        }
282        sim
283    }
284
285    /// Adds new party into the protocol
286    ///
287    /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
288    /// simulation after this party was added.
289    pub fn add_party(
290        &mut self,
291        party: impl crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
292    ) {
293        self.parties.push(Party::Active {
294            party: Box::new(party),
295            wants_one_more_msg: false,
296        })
297    }
298
299    /// Adds new party, defined as an async function, into the protocol
300    ///
301    /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
302    /// simulation after this party was added.
303    ///
304    /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
305    /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
306    /// to this function).
307    pub fn add_async_party<F>(&mut self, party: impl FnOnce(crate::state_machine::MpcParty<M>) -> F)
308    where
309        F: core::future::Future<Output = O> + 'a,
310    {
311        self.parties.push(Party::Active {
312            party: Box::new(crate::state_machine::wrap_protocol(party)),
313            wants_one_more_msg: false,
314        })
315    }
316
317    /// Returns amount of parties in the simulation
318    pub fn parties_amount(&self) -> usize {
319        self.parties.len()
320    }
321
322    /// Carries out the simulation
323    pub fn run(mut self) -> Result<SimResult<O>, SimError> {
324        let mut messages_queue = MessagesQueue::new(self.parties.len());
325        let mut parties_left = self.parties.len();
326
327        while parties_left > 0 {
328            'next_party: for (i, party_state) in (0..).zip(&mut self.parties) {
329                'this_party: loop {
330                    let Party::Active {
331                        party,
332                        wants_one_more_msg,
333                    } = party_state
334                    else {
335                        continue 'next_party;
336                    };
337
338                    if *wants_one_more_msg {
339                        if let Some(message) = messages_queue.recv_next_msg(i) {
340                            party
341                                .received_msg(message)
342                                .map_err(|_| Reason::SaveIncomingMsg)?;
343                            *wants_one_more_msg = false;
344                        } else {
345                            continue 'next_party;
346                        }
347                    }
348
349                    match party.proceed() {
350                        ProceedResult::SendMsg(msg) => {
351                            messages_queue.send_message(i, msg)?;
352                            continue 'this_party;
353                        }
354                        ProceedResult::NeedsOneMoreMessage => {
355                            *wants_one_more_msg = true;
356                            continue 'this_party;
357                        }
358                        ProceedResult::Output(out) => {
359                            *party_state = Party::Finished(out);
360                            parties_left -= 1;
361                            continue 'next_party;
362                        }
363                        ProceedResult::Yielded => {
364                            continue 'this_party;
365                        }
366                        ProceedResult::Error(err) => {
367                            return Err(Reason::ExecutionError(err).into());
368                        }
369                    }
370                }
371            }
372        }
373
374        Ok(SimResult(
375            self.parties
376                .into_iter()
377                .map(|party| match party {
378                    Party::Active { .. } => {
379                        unreachable!("there must be no active parties when `parties_left == 0`")
380                    }
381                    Party::Finished(out) => out,
382                })
383                .collect(),
384        ))
385    }
386}
387
388/// Error indicating that simulation failed
389#[derive(Debug, thiserror::Error)]
390#[error(transparent)]
391pub struct SimError(#[from] Reason);
392
393#[derive(Debug, thiserror::Error)]
394enum Reason {
395    #[error("save incoming message")]
396    SaveIncomingMsg,
397    #[error("execution error")]
398    ExecutionError(#[source] crate::state_machine::ExecutionError),
399    #[error("party #{sender} tried to send a message to non existing party #{recipient}")]
400    UnknownRecipient { sender: u16, recipient: u16 },
401}
402
403struct MessagesQueue<M> {
404    queue: Vec<VecDeque<Incoming<M>>>,
405    next_id: u64,
406}
407
408impl<M: Clone> MessagesQueue<M> {
409    fn new(n: usize) -> Self {
410        Self {
411            queue: alloc::vec![VecDeque::new(); n],
412            next_id: 0,
413        }
414    }
415
416    fn send_message(&mut self, sender: u16, msg: Outgoing<M>) -> Result<(), SimError> {
417        match msg.recipient {
418            MessageDestination::AllParties => {
419                let mut msg_ids = self.next_id..;
420                for (destination, msg_id) in (0..)
421                    .zip(&mut self.queue)
422                    .filter(|(recipient_index, _)| *recipient_index != sender)
423                    .map(|(_, msg)| msg)
424                    .zip(msg_ids.by_ref())
425                {
426                    destination.push_back(Incoming {
427                        id: msg_id,
428                        sender,
429                        msg_type: MessageType::Broadcast,
430                        msg: msg.msg.clone(),
431                    })
432                }
433                self.next_id = msg_ids.next().unwrap();
434            }
435            MessageDestination::OneParty(destination) => {
436                let next_id = self.next_id;
437                self.next_id += 1;
438
439                self.queue
440                    .get_mut(usize::from(destination))
441                    .ok_or(Reason::UnknownRecipient {
442                        sender,
443                        recipient: destination,
444                    })?
445                    .push_back(Incoming {
446                        id: next_id,
447                        sender,
448                        msg_type: MessageType::P2P,
449                        msg: msg.msg,
450                    })
451            }
452        }
453
454        Ok(())
455    }
456
457    fn recv_next_msg(&mut self, recipient: u16) -> Option<Incoming<M>> {
458        self.queue[usize::from(recipient)].pop_front()
459    }
460}
461
462/// Simulates execution of the protocol
463///
464/// Takes amount of participants, and a function that carries out the protocol for
465/// one party. The function takes as input: index of the party, and [`MpcParty`](crate::MpcParty)
466/// that can be used to communicate with others.
467///
468/// ## Example
469/// ```rust,no_run
470/// use round_based::{Mpc, PartyIndex};
471///
472/// # type Result<T, E = ()> = std::result::Result<T, E>;
473/// # type Randomness = [u8; 32];
474/// # type Msg = ();
475/// // Any MPC protocol you want to test
476/// pub async fn protocol_of_random_generation<M>(
477///     party: M,
478///     i: PartyIndex,
479///     n: u16
480/// ) -> Result<Randomness>
481/// where
482///     M: Mpc<ProtocolMessage = Msg>
483/// {
484///     // ...
485/// # todo!()
486/// }
487///
488/// let n = 3;
489///
490/// let output = round_based::sim::run(
491///     n,
492///     |i, party| protocol_of_random_generation(party, i, n),
493/// )
494/// .unwrap()
495/// // unwrap `Result`s
496/// .expect_ok()
497/// // check that all parties produced the same response
498/// .expect_eq();
499///
500/// println!("Output randomness: {}", hex::encode(output));
501/// ```
502pub fn run<M, F>(
503    n: u16,
504    mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
505) -> Result<SimResult<F::Output>, SimError>
506where
507    M: Clone + 'static,
508    F: Future,
509{
510    run_with_setup(core::iter::repeat(()).take(n.into()), |i, party, ()| {
511        party_start(i, party)
512    })
513}
514
515/// Simulates execution of the protocol
516///
517/// Similar to [`run`], but allows some setup to be provided to the protocol execution
518/// function.
519///
520/// Simulation will have as many parties as `setups` iterator yields
521///
522/// ## Example
523/// ```rust,no_run
524/// use round_based::{Mpc, PartyIndex};
525///
526/// # type Result<T, E = ()> = std::result::Result<T, E>;
527/// # type Randomness = [u8; 32];
528/// # type Msg = ();
529/// // Any MPC protocol you want to test
530/// pub async fn protocol_of_random_generation<M>(
531///     rng: impl rand::RngCore,
532///     party: M,
533///     i: PartyIndex,
534///     n: u16
535/// ) -> Result<Randomness>
536/// where
537///     M: Mpc<ProtocolMessage = Msg>
538/// {
539///     // ...
540/// # todo!()
541/// }
542///
543/// let mut rng = rand_dev::DevRng::new();
544/// let n = 3;
545/// let output = round_based::sim::run_with_setup(
546///     core::iter::repeat_with(|| rng.fork()).take(n.into()),
547///     |i, party, rng| protocol_of_random_generation(rng, party, i, n),
548/// )
549/// .unwrap()
550/// // unwrap `Result`s
551/// .expect_ok()
552/// // check that all parties produced the same response
553/// .expect_eq();
554///
555/// println!("Output randomness: {}", hex::encode(output));
556/// ```
557pub fn run_with_setup<S, M, F>(
558    setups: impl IntoIterator<Item = S>,
559    mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>, S) -> F,
560) -> Result<SimResult<F::Output>, SimError>
561where
562    M: Clone + 'static,
563    F: Future,
564{
565    let mut sim = Simulation::empty();
566
567    for (setup, i) in setups.into_iter().zip(0u16..) {
568        let party = crate::state_machine::wrap_protocol(|party| party_start(i, party, setup));
569        sim.add_party(party);
570    }
571
572    sim.run()
573}
574
#[cfg(test)]
mod tests {
    mod expect_eq {
        use crate::sim::SimResult;

        #[test]
        fn all_eq() {
            let res = SimResult::from(alloc::vec!["same string", "same string", "same string"])
                .expect_eq();
            assert_eq!(res, "same string")
        }

        #[test]
        fn single_party() {
            assert_eq!(SimResult::from(alloc::vec![42]).expect_eq(), 42);
        }

        #[test]
        #[should_panic(expected = "simulation contained zero parties")]
        fn empty_res() {
            SimResult::from(alloc::vec![]).expect_eq()
        }

        // Pin the message so that an unrelated panic (e.g. an out-of-bounds index)
        // would not make this test pass
        #[test]
        #[should_panic(expected = "- Parties 0, 1, 3: \"one result\"")]
        fn not_eq() {
            SimResult::from(alloc::vec![
                "one result",
                "one result",
                "another result",
                "one result",
                "and something else",
            ])
            .expect_eq();
        }
    }

    mod expect_ok {
        use crate::sim::SimResult;

        #[test]
        fn all_ok() {
            let res = SimResult::<Result<i32, core::convert::Infallible>>::from(alloc::vec![
                Ok(0),
                Ok(1),
                Ok(2)
            ])
            .expect_ok()
            .into_vec();

            assert_eq!(res, [0, 1, 2]);
        }

        #[test]
        fn no_parties() {
            let res = SimResult::<Result<i32, core::convert::Infallible>>::from(alloc::vec![])
                .expect_ok()
                .into_vec();

            assert!(res.is_empty());
        }

        #[test]
        #[should_panic(expected = "3 parties succeeded, 2 parties returned an error")]
        fn not_ok() {
            SimResult::from(alloc::vec![
                Ok(0),
                Err("i couldn't do what you asked :("),
                Ok(2),
                Ok(3),
                Err("sorry I was pooping, what did you want?")
            ])
            .expect_ok();
        }
    }
}