round_based/sim/mod.rs
1//! Multiparty protocol simulation
2//!
//! The simulator is an essential developer tool for testing multiparty protocols locally.
//! It covers most of the boilerplate involved in emulating MPC protocol execution.
5//!
6//! The entry point is either [`run`] or [`run_with_setup`] functions. They take a protocol
7//! defined as an async function, provide simulated networking, carry out the simulation,
8//! and return the result.
9//!
//! If you need more control over execution, you can use [`Simulation`]. For instance, it allows
//! creating a simulation whose parties are defined by different functions, which is helpful when
//! simulating a protocol in the presence of an adversary (e.g. one set of parties runs the regular
//! protocol implementation, while another set runs a function that emulates adversarial
//! behavior).
15//!
16//! ## Limitations
//! [`Simulation`] works by converting each party (defined as an async function) into a
//! [state machine](crate::state_machine). This works without problems in most cases and provides
//! better UX, as no async runtime is required (the simulation is entirely synchronous).
20//!
//! However, a protocol wrapped into a state machine cannot poll any futures except those provided
//! by [`MpcParty`](crate::MpcParty) (so it can only await sending/receiving messages and
//! yielding). For instance, if the protocol implementation makes use of tokio timers, this will
//! result in an execution error.
25//!
//! In general, we do not recommend awaiting futures that aren't provided by `MpcParty` in an MPC
//! protocol implementation, so that the implementation stays runtime-agnostic.
28//!
//! If you really need to use unsupported futures, you can use [`async_env`] instead, which
//! runs the simulation on the tokio runtime but has its own limitations.
31//!
32//! ## Example
33//! ```rust,no_run
34//! use round_based::{Mpc, PartyIndex};
35//!
36//! # type Result<T, E = ()> = std::result::Result<T, E>;
37//! # type Randomness = [u8; 32];
38//! # #[derive(round_based::ProtocolMsg, Clone)]
39//! # enum Msg {}
40//! // Any MPC protocol you want to test
41//! pub async fn protocol_of_random_generation<M>(
42//! party: M,
43//! i: PartyIndex,
44//! n: u16
45//! ) -> Result<Randomness>
46//! where
47//! M: Mpc<Msg = Msg>
48//! {
49//! // ...
50//! # todo!()
51//! }
52//!
53//! let n = 3;
54//!
55//! let output = round_based::sim::run(
56//! n,
57//! |i, party| protocol_of_random_generation(party, i, n),
58//! )
59//! .unwrap()
60//! // unwrap `Result`s
61//! .expect_ok()
62//! // check that all parties produced the same response
63//! .expect_eq();
64//!
65//! println!("Output randomness: {}", hex::encode(output));
66//! ```
67
68use alloc::{boxed::Box, collections::VecDeque, string::ToString, vec::Vec};
69use core::future::Future;
70
71use crate::{
72 Incoming, MessageDestination, MessageType, Outgoing, ProtocolMsg, state_machine::ProceedResult,
73};
74
75#[cfg(feature = "sim-async")]
76pub mod async_env;
77
/// Result of the simulation
///
/// Holds one output per party. When produced by [`Simulation::run`], the output of
/// party `i` is stored at position `i`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimResult<T>(pub Vec<T>);
80
81impl<T, E> SimResult<Result<T, E>>
82where
83 E: core::fmt::Debug,
84{
85 /// Unwraps `Result<T, E>` produced by each party
86 ///
87 /// Panics if at least one of the parties returned `Err(_)`. In this case,
88 /// a verbose error message will shown specifying which of the parties returned
89 /// an error.
90 pub fn expect_ok(self) -> SimResult<T> {
91 let mut oks = Vec::with_capacity(self.0.len());
92 let mut errs = Vec::with_capacity(self.0.len());
93
94 for (res, i) in self.0.into_iter().zip(0u16..) {
95 match res {
96 Ok(res) => oks.push(res),
97 Err(res) => errs.push((i, res)),
98 }
99 }
100
101 if !errs.is_empty() {
102 let mut msg = alloc::format!(
103 "Simulation output didn't match expectations.\n\
104 Expected: all parties succeed\n\
105 Actual : {success} parties succeeded, {failed} parties returned an error\n\
106 Failures:\n",
107 success = oks.len(),
108 failed = errs.len(),
109 );
110
111 for (i, err) in errs {
112 msg += &alloc::format!("- Party {i}: {err:?}\n");
113 }
114
115 panic!("{msg}");
116 }
117
118 SimResult(oks)
119 }
120}
121
122impl<T> SimResult<T>
123where
124 T: PartialEq + core::fmt::Debug,
125{
126 /// Checks that outputs of all parties are equally the same
127 ///
128 /// Returns the output on success (all the outputs are checked to be the same), otherwise
129 /// panics with a verbose error message.
130 ///
131 /// Panics if simulation contained zero parties.
132 pub fn expect_eq(mut self) -> T {
133 let Some(first) = self.0.first() else {
134 panic!("simulation contained zero parties");
135 };
136
137 if !self.0[1..].iter().all(|i| i == first) {
138 let mut msg = alloc::string::String::from(
139 "Simulation output didn't match expectations.\n\
140 Expected: all parties return the same output\n\
141 Actual : some of the parties returned a different output\n\
142 Outputs :\n",
143 );
144
145 let mut clusters: Vec<(&T, Vec<usize>)> = Vec::new();
146 for (i, value) in self.0.iter().enumerate() {
147 match clusters
148 .iter_mut()
149 .find(|(cluster_value, _)| *cluster_value == value)
150 .map(|(_, indexes)| indexes)
151 {
152 Some(indexes) => indexes.push(i),
153 None => clusters.push((value, alloc::vec![i])),
154 }
155 }
156
157 for (value, parties) in &clusters {
158 if parties.len() == 1 {
159 msg += "- Party ";
160 } else {
161 msg += "- Parties "
162 }
163
164 for (i, is_first) in parties
165 .iter()
166 .zip(core::iter::once(true).chain(core::iter::repeat(false)))
167 {
168 if !is_first {
169 msg += ", "
170 }
171 msg += &i.to_string();
172 }
173
174 msg += &alloc::format!(": {value:?}\n");
175 }
176
177 panic!("{msg}")
178 }
179
180 self.0
181 .pop()
182 .expect("we checked that the list contains at least one element")
183 }
184}
185
186impl<T> SimResult<T> {
187 /// Deconstructs the simulation result returning inner list of results
188 pub fn into_vec(self) -> Vec<T> {
189 self.0
190 }
191}
192
193impl<T> IntoIterator for SimResult<T> {
194 type Item = T;
195 type IntoIter = alloc::vec::IntoIter<T>;
196 fn into_iter(self) -> Self::IntoIter {
197 self.0.into_iter()
198 }
199}
200
201impl<T> core::ops::Deref for SimResult<T> {
202 type Target = [T];
203 fn deref(&self) -> &Self::Target {
204 &self.0
205 }
206}
207
208impl<T> From<Vec<T>> for SimResult<T> {
209 fn from(list: Vec<T>) -> Self {
210 Self(list)
211 }
212}
213
214impl<T> From<SimResult<T>> for Vec<T> {
215 fn from(res: SimResult<T>) -> Self {
216 res.0
217 }
218}
219
220/// Simulates MPC protocol with parties defined as [state machines](crate::state_machine)
221pub struct Simulation<'a, O, M> {
222 parties: Vec<Party<'a, O, M>>,
223}
224
225enum Party<'a, O, M> {
226 Active {
227 party: Box<dyn crate::state_machine::StateMachine<Output = O, Msg = M> + 'a>,
228 wants_one_more_msg: bool,
229 },
230 Finished(O),
231}
232
233impl<'a, O, M> Simulation<'a, O, M>
234where
235 M: ProtocolMsg + Clone + 'static,
236{
237 /// Creates empty simulation containing no parties
238 ///
239 /// New parties can be added via [`.add_party()`](Self::add_party)
240 pub fn empty() -> Self {
241 Self {
242 parties: Vec::new(),
243 }
244 }
245
246 /// Constructs empty simulation containing no parties, with allocated memory that can fit up to `n` parties without re-allocations
247 pub fn with_capacity(n: u16) -> Self {
248 Self {
249 parties: Vec::with_capacity(n.into()),
250 }
251 }
252
253 /// Constructs a simulation with `n` parties from async function that defines the protocol
254 ///
255 /// Each party has index `0 <= i < n` and instantiated via provided `init` function
256 ///
257 /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
258 /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
259 /// to this function).
260 pub fn from_async_fn<F>(
261 n: u16,
262 mut init: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
263 ) -> Self
264 where
265 F: core::future::Future<Output = O> + 'a,
266 {
267 let mut sim = Self::with_capacity(n);
268 for i in 0..n {
269 sim.add_async_party(|party| init(i, party))
270 }
271 sim
272 }
273
274 /// Construct a simulation with `n` parties from `init` function that constructs state machine for each party
275 ///
276 /// Each party has index `0 <= i < n` and instantiated via provided `init` function
277 pub fn from_fn<S>(n: u16, mut init: impl FnMut(u16) -> S) -> Self
278 where
279 S: crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
280 {
281 let mut sim = Self::with_capacity(n);
282 for i in 0..n {
283 sim.add_party(init(i));
284 }
285 sim
286 }
287
288 /// Adds new party into the protocol
289 ///
290 /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
291 /// simulation after this party was added.
292 pub fn add_party(
293 &mut self,
294 party: impl crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
295 ) {
296 self.parties.push(Party::Active {
297 party: Box::new(party),
298 wants_one_more_msg: false,
299 })
300 }
301
302 /// Adds new party, defined as an async function, into the protocol
303 ///
304 /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
305 /// simulation after this party was added.
306 ///
307 /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
308 /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
309 /// to this function).
310 pub fn add_async_party<F>(&mut self, party: impl FnOnce(crate::state_machine::MpcParty<M>) -> F)
311 where
312 F: core::future::Future<Output = O> + 'a,
313 {
314 self.parties.push(Party::Active {
315 party: Box::new(crate::state_machine::wrap_protocol(party)),
316 wants_one_more_msg: false,
317 })
318 }
319
320 /// Returns amount of parties in the simulation
321 pub fn parties_amount(&self) -> usize {
322 self.parties.len()
323 }
324
325 /// Carries out the simulation
326 pub fn run(mut self) -> Result<SimResult<O>, SimError> {
327 let mut messages_queue = MessagesQueue::new(self.parties.len());
328 let mut parties_left = self.parties.len();
329
330 while parties_left > 0 {
331 'next_party: for (i, party_state) in (0..).zip(&mut self.parties) {
332 'this_party: loop {
333 let Party::Active {
334 party,
335 wants_one_more_msg,
336 } = party_state
337 else {
338 continue 'next_party;
339 };
340
341 if *wants_one_more_msg {
342 if let Some(message) = messages_queue.recv_next_msg(i) {
343 party
344 .received_msg(message)
345 .map_err(|_| Reason::SaveIncomingMsg)?;
346 *wants_one_more_msg = false;
347 } else {
348 continue 'next_party;
349 }
350 }
351
352 match party.proceed() {
353 ProceedResult::SendMsg(msg) => {
354 messages_queue.send_message(i, msg)?;
355 continue 'this_party;
356 }
357 ProceedResult::NeedsOneMoreMessage => {
358 *wants_one_more_msg = true;
359 continue 'this_party;
360 }
361 ProceedResult::Output(out) => {
362 *party_state = Party::Finished(out);
363 parties_left -= 1;
364 continue 'next_party;
365 }
366 ProceedResult::Yielded => {
367 continue 'this_party;
368 }
369 ProceedResult::Error(err) => {
370 return Err(Reason::ExecutionError(err).into());
371 }
372 }
373 }
374 }
375 }
376
377 Ok(SimResult(
378 self.parties
379 .into_iter()
380 .map(|party| match party {
381 Party::Active { .. } => {
382 unreachable!("there must be no active parties when `parties_left == 0`")
383 }
384 Party::Finished(out) => out,
385 })
386 .collect(),
387 ))
388 }
389}
390
391/// Error indicating that simulation failed
392#[derive(Debug, thiserror::Error)]
393#[error(transparent)]
394pub struct SimError(#[from] Reason);
395
396#[derive(Debug, thiserror::Error)]
397enum Reason {
398 #[error("save incoming message")]
399 SaveIncomingMsg,
400 #[error("execution error")]
401 ExecutionError(#[source] crate::state_machine::ExecutionError),
402 #[error("party #{sender} tried to send a message to non existing party #{recipient}")]
403 UnknownRecipient { sender: u16, recipient: u16 },
404}
405
406struct MessagesQueue<M> {
407 queue: Vec<VecDeque<Incoming<M>>>,
408 next_id: u64,
409}
410
411impl<M: Clone> MessagesQueue<M> {
412 fn new(n: usize) -> Self {
413 Self {
414 queue: alloc::vec![VecDeque::new(); n],
415 next_id: 0,
416 }
417 }
418
419 fn send_message(&mut self, sender: u16, msg: Outgoing<M>) -> Result<(), SimError> {
420 match msg.recipient {
421 MessageDestination::AllParties { reliable } => {
422 let mut msg_ids = self.next_id..;
423 for (destination, msg_id) in (0..)
424 .zip(&mut self.queue)
425 .filter(|(recipient_index, _)| *recipient_index != sender)
426 .map(|(_, msg)| msg)
427 .zip(msg_ids.by_ref())
428 {
429 destination.push_back(Incoming {
430 id: msg_id,
431 sender,
432 msg_type: MessageType::Broadcast { reliable },
433 msg: msg.msg.clone(),
434 })
435 }
436 self.next_id = msg_ids.next().unwrap();
437 }
438 MessageDestination::OneParty(destination) => {
439 let next_id = self.next_id;
440 self.next_id += 1;
441
442 self.queue
443 .get_mut(usize::from(destination))
444 .ok_or(Reason::UnknownRecipient {
445 sender,
446 recipient: destination,
447 })?
448 .push_back(Incoming {
449 id: next_id,
450 sender,
451 msg_type: MessageType::P2P,
452 msg: msg.msg,
453 })
454 }
455 }
456
457 Ok(())
458 }
459
460 fn recv_next_msg(&mut self, recipient: u16) -> Option<Incoming<M>> {
461 self.queue[usize::from(recipient)].pop_front()
462 }
463}
464
465/// Simulates execution of the protocol
466///
467/// Takes amount of participants, and a function that carries out the protocol for
468/// one party. The function takes as input: index of the party, and [`MpcParty`](crate::MpcParty)
469/// that can be used to communicate with others.
470///
471/// ## Example
472/// ```rust,no_run
473/// use round_based::{Mpc, PartyIndex};
474///
475/// # type Result<T, E = ()> = std::result::Result<T, E>;
476/// # type Randomness = [u8; 32];
477/// # #[derive(round_based::ProtocolMsg, Clone)]
478/// # enum Msg {}
479/// // Any MPC protocol you want to test
480/// pub async fn protocol_of_random_generation<M>(
481/// party: M,
482/// i: PartyIndex,
483/// n: u16
484/// ) -> Result<Randomness>
485/// where
486/// M: Mpc<Msg = Msg>
487/// {
488/// // ...
489/// # todo!()
490/// }
491///
492/// let n = 3;
493///
494/// let output = round_based::sim::run(
495/// n,
496/// |i, party| protocol_of_random_generation(party, i, n),
497/// )
498/// .unwrap()
499/// // unwrap `Result`s
500/// .expect_ok()
501/// // check that all parties produced the same response
502/// .expect_eq();
503///
504/// println!("Output randomness: {}", hex::encode(output));
505/// ```
506pub fn run<M, F>(
507 n: u16,
508 mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
509) -> Result<SimResult<F::Output>, SimError>
510where
511 M: ProtocolMsg + Clone + 'static,
512 F: Future,
513{
514 run_with_setup(core::iter::repeat_n((), n.into()), |i, party, ()| {
515 party_start(i, party)
516 })
517}
518
519/// Simulates execution of the protocol
520///
521/// Similar to [`run`], but allows some setup to be provided to the protocol execution
522/// function.
523///
524/// Simulation will have as many parties as `setups` iterator yields
525///
526/// ## Example
527/// ```rust,no_run
528/// use round_based::{Mpc, PartyIndex};
529///
530/// # type Result<T, E = ()> = std::result::Result<T, E>;
531/// # type Randomness = [u8; 32];
532/// # #[derive(round_based::ProtocolMsg, Clone)]
533/// # enum Msg {}
534/// // Any MPC protocol you want to test
535/// pub async fn protocol_of_random_generation<M>(
536/// rng: impl rand::RngCore,
537/// party: M,
538/// i: PartyIndex,
539/// n: u16
540/// ) -> Result<Randomness>
541/// where
542/// M: Mpc<Msg = Msg>
543/// {
544/// // ...
545/// # todo!()
546/// }
547///
548/// let mut rng = rand_dev::DevRng::new();
549/// let n = 3;
550/// let output = round_based::sim::run_with_setup(
551/// core::iter::repeat_with(|| rng.fork()).take(n.into()),
552/// |i, party, rng| protocol_of_random_generation(rng, party, i, n),
553/// )
554/// .unwrap()
555/// // unwrap `Result`s
556/// .expect_ok()
557/// // check that all parties produced the same response
558/// .expect_eq();
559///
560/// println!("Output randomness: {}", hex::encode(output));
561/// ```
562pub fn run_with_setup<S, M, F>(
563 setups: impl IntoIterator<Item = S>,
564 mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>, S) -> F,
565) -> Result<SimResult<F::Output>, SimError>
566where
567 M: ProtocolMsg + Clone + 'static,
568 F: Future,
569{
570 let mut sim = Simulation::empty();
571
572 for (setup, i) in setups.into_iter().zip(0u16..) {
573 let party = crate::state_machine::wrap_protocol(|party| party_start(i, party, setup));
574 sim.add_party(party);
575 }
576
577 sim.run()
578}
579
#[cfg(test)]
mod tests {
    mod expect_eq {
        use crate::sim::SimResult;

        #[test]
        fn all_eq() {
            let res = SimResult::from(alloc::vec!["same string", "same string", "same string"])
                .expect_eq();
            assert_eq!(res, "same string")
        }

        #[test]
        fn single_party() {
            let res = SimResult::from(alloc::vec!["the only output"]).expect_eq();
            assert_eq!(res, "the only output")
        }

        // `expected` pins the panic to the intended check, so an unrelated panic
        // (e.g. an out-of-bounds index) can't make the test pass by accident
        #[test]
        #[should_panic(expected = "simulation contained zero parties")]
        fn empty_res() {
            SimResult::from(alloc::vec![]).expect_eq()
        }

        #[test]
        #[should_panic(expected = "Expected: all parties return the same output")]
        fn not_eq() {
            SimResult::from(alloc::vec![
                "one result",
                "one result",
                "another result",
                "one result",
                "and something else",
            ])
            .expect_eq();
        }

        #[test]
        #[should_panic(expected = "- Parties 0, 1, 3: \"one result\"")]
        fn not_eq_groups_parties_by_output() {
            SimResult::from(alloc::vec![
                "one result",
                "one result",
                "another result",
                "one result",
            ])
            .expect_eq();
        }
    }

    mod expect_ok {
        use crate::sim::SimResult;

        #[test]
        fn all_ok() {
            let res = SimResult::<Result<i32, core::convert::Infallible>>::from(alloc::vec![
                Ok(0),
                Ok(1),
                Ok(2)
            ])
            .expect_ok()
            .into_vec();

            assert_eq!(res, [0, 1, 2]);
        }

        #[test]
        #[should_panic(expected = "Expected: all parties succeed")]
        fn not_ok() {
            SimResult::from(alloc::vec![
                Ok(0),
                Err("party 1 failed"),
                Ok(2),
                Ok(3),
                Err("party 4 failed")
            ])
            .expect_ok();
        }

        #[test]
        #[should_panic(expected = "- Party 4: \"party 4 failed\"")]
        fn not_ok_reports_failed_parties() {
            SimResult::from(alloc::vec![
                Ok(0),
                Err("party 1 failed"),
                Ok(2),
                Ok(3),
                Err("party 4 failed")
            ])
            .expect_ok();
        }
    }
}