round_based/sim/mod.rs
1//! Multiparty protocol simulation
2//!
//! The simulator is an essential developer tool for testing multiparty protocols locally.
4//! It covers most of the boilerplate of emulating MPC protocol execution.
5//!
6//! The entry point is either [`run`] or [`run_with_setup`] functions. They take a protocol
7//! defined as an async function, provide simulated networking, carry out the simulation,
8//! and return the result.
9//!
//! If you need more control over execution, you can use [`Simulation`]. For instance, it allows
//! creating a simulation in which parties are defined by different functions. This is helpful
//! when simulating the protocol in the presence of an adversary: one set of parties can run the
//! regular protocol implementation, while the other set of parties is defined by a function
//! that emulates adversarial behavior.
15//!
16//! ## Limitations
//! [`Simulation`] works by converting each party (defined as an async function) into a
//! [state machine](crate::state_machine). This works without problems in most cases and provides
//! better UX, as it doesn't require an async runtime (the simulation is entirely synchronous).
20//!
//! However, a protocol wrapped into a state machine cannot poll any futures except those provided
//! by [`MpcParty`](crate::MpcParty) (so it can only await on sending/receiving messages and yielding).
//! For instance, if the protocol implementation makes use of tokio timers, it will result in an
//! execution error.
25//!
26//! In general, we do not recommend awaiting on the futures that aren't provided by `MpcParty` in
27//! the MPC protocol implementation, to keep the protocol implementation runtime-agnostic.
28//!
//! If you really need to make use of unsupported futures, you can use [`async_env`] instead,
//! which provides a simulation on the tokio runtime, but has its own limitations.
31//!
32//! ## Example
33//! ```rust,no_run
34//! use round_based::{Mpc, PartyIndex};
35//!
36//! # type Result<T, E = ()> = std::result::Result<T, E>;
37//! # type Randomness = [u8; 32];
38//! # type Msg = ();
39//! // Any MPC protocol you want to test
40//! pub async fn protocol_of_random_generation<M>(
41//! party: M,
42//! i: PartyIndex,
43//! n: u16
44//! ) -> Result<Randomness>
45//! where
46//! M: Mpc<ProtocolMessage = Msg>
47//! {
48//! // ...
49//! # todo!()
50//! }
51//!
52//! let n = 3;
53//!
54//! let output = round_based::sim::run(
55//! n,
56//! |i, party| protocol_of_random_generation(party, i, n),
57//! )
58//! .unwrap()
59//! // unwrap `Result`s
60//! .expect_ok()
61//! // check that all parties produced the same response
62//! .expect_eq();
63//!
64//! println!("Output randomness: {}", hex::encode(output));
65//! ```
66
67use alloc::{boxed::Box, collections::VecDeque, string::ToString, vec::Vec};
68use core::future::Future;
69
70use crate::{state_machine::ProceedResult, Incoming, MessageDestination, MessageType, Outgoing};
71
72#[cfg(feature = "sim-async")]
73pub mod async_env;
74
/// Result of the simulation
///
/// Holds one output per party. When produced by a simulation, the `i`-th element is the
/// output of the party with index `i`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimResult<T>(pub Vec<T>);
77
78impl<T, E> SimResult<Result<T, E>>
79where
80 E: core::fmt::Debug,
81{
82 /// Unwraps `Result<T, E>` produced by each party
83 ///
84 /// Panics if at least one of the parties returned `Err(_)`. In this case,
85 /// a verbose error message will shown specifying which of the parties returned
86 /// an error.
87 pub fn expect_ok(self) -> SimResult<T> {
88 let mut oks = Vec::with_capacity(self.0.len());
89 let mut errs = Vec::with_capacity(self.0.len());
90
91 for (res, i) in self.0.into_iter().zip(0u16..) {
92 match res {
93 Ok(res) => oks.push(res),
94 Err(res) => errs.push((i, res)),
95 }
96 }
97
98 if !errs.is_empty() {
99 let mut msg = alloc::format!(
100 "Simulation output didn't match expectations.\n\
101 Expected: all parties succeed\n\
102 Actual : {success} parties succeeded, {failed} parties returned an error\n\
103 Failures:\n",
104 success = oks.len(),
105 failed = errs.len(),
106 );
107
108 for (i, err) in errs {
109 msg += &alloc::format!("- Party {i}: {err:?}\n");
110 }
111
112 panic!("{msg}");
113 }
114
115 SimResult(oks)
116 }
117}
118
119impl<T> SimResult<T>
120where
121 T: PartialEq + core::fmt::Debug,
122{
123 /// Checks that outputs of all parties are equally the same
124 ///
125 /// Returns the output on success (all the outputs are checked to be the same), otherwise
126 /// panics with a verbose error message.
127 ///
128 /// Panics if simulation contained zero parties.
129 pub fn expect_eq(mut self) -> T {
130 let Some(first) = self.0.first() else {
131 panic!("simulation contained zero parties");
132 };
133
134 if !self.0[1..].iter().all(|i| i == first) {
135 let mut msg = alloc::string::String::from(
136 "Simulation output didn't match expectations.\n\
137 Expected: all parties return the same output\n\
138 Actual : some of the parties returned a different output\n\
139 Outputs :\n",
140 );
141
142 let mut clusters: Vec<(&T, Vec<usize>)> = Vec::new();
143 for (i, value) in self.0.iter().enumerate() {
144 match clusters
145 .iter_mut()
146 .find(|(cluster_value, _)| *cluster_value == value)
147 .map(|(_, indexes)| indexes)
148 {
149 Some(indexes) => indexes.push(i),
150 None => clusters.push((value, alloc::vec![i])),
151 }
152 }
153
154 for (value, parties) in &clusters {
155 if parties.len() == 1 {
156 msg += "- Party ";
157 } else {
158 msg += "- Parties "
159 }
160
161 for (i, is_first) in parties
162 .iter()
163 .zip(core::iter::once(true).chain(core::iter::repeat(false)))
164 {
165 if !is_first {
166 msg += ", "
167 }
168 msg += &i.to_string();
169 }
170
171 msg += &alloc::format!(": {value:?}\n");
172 }
173
174 panic!("{msg}")
175 }
176
177 self.0
178 .pop()
179 .expect("we checked that the list contains at least one element")
180 }
181}
182
183impl<T> SimResult<T> {
184 /// Deconstructs the simulation result returning inner list of results
185 pub fn into_vec(self) -> Vec<T> {
186 self.0
187 }
188}
189
190impl<T> IntoIterator for SimResult<T> {
191 type Item = T;
192 type IntoIter = alloc::vec::IntoIter<T>;
193 fn into_iter(self) -> Self::IntoIter {
194 self.0.into_iter()
195 }
196}
197
198impl<T> core::ops::Deref for SimResult<T> {
199 type Target = [T];
200 fn deref(&self) -> &Self::Target {
201 &self.0
202 }
203}
204
205impl<T> From<Vec<T>> for SimResult<T> {
206 fn from(list: Vec<T>) -> Self {
207 Self(list)
208 }
209}
210
211impl<T> From<SimResult<T>> for Vec<T> {
212 fn from(res: SimResult<T>) -> Self {
213 res.0
214 }
215}
216
/// Simulates MPC protocol with parties defined as [state machines](crate::state_machine)
///
/// `O` is the output type of each party, and `M` is the protocol message type.
///
/// A simulation can be built via [`Simulation::empty`], [`Simulation::with_capacity`],
/// [`Simulation::from_fn`] or [`Simulation::from_async_fn`], extended with
/// [`.add_party()`](Simulation::add_party) / [`.add_async_party()`](Simulation::add_async_party),
/// and executed via [`.run()`](Simulation::run).
pub struct Simulation<'a, O, M> {
    /// All parties of the simulation; the party at position `i` has index `i`
    parties: Vec<Party<'a, O, M>>,
}
221
/// State of a single party within the simulation
enum Party<'a, O, M> {
    /// Party hasn't produced its output yet
    Active {
        /// State machine that executes the party's protocol
        party: Box<dyn crate::state_machine::StateMachine<Output = O, Msg = M> + 'a>,
        /// Set once the state machine asked for one more message
        /// (`ProceedResult::NeedsOneMoreMessage`); cleared when a message is delivered to it
        wants_one_more_msg: bool,
    },
    /// Party completed the protocol and produced this output
    Finished(O),
}
229
230impl<'a, O, M> Simulation<'a, O, M>
231where
232 M: Clone + 'static,
233{
234 /// Creates empty simulation containing no parties
235 ///
236 /// New parties can be added via [`.add_party()`](Self::add_party)
237 pub fn empty() -> Self {
238 Self {
239 parties: Vec::new(),
240 }
241 }
242
243 /// Constructs empty simulation containing no parties, with allocated memory that can fit up to `n` parties without re-allocations
244 pub fn with_capacity(n: u16) -> Self {
245 Self {
246 parties: Vec::with_capacity(n.into()),
247 }
248 }
249
250 /// Constructs a simulation with `n` parties from async function that defines the protocol
251 ///
252 /// Each party has index `0 <= i < n` and instantiated via provided `init` function
253 ///
254 /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
255 /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
256 /// to this function).
257 pub fn from_async_fn<F>(
258 n: u16,
259 mut init: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
260 ) -> Self
261 where
262 F: core::future::Future<Output = O> + 'a,
263 {
264 let mut sim = Self::with_capacity(n);
265 for i in 0..n {
266 sim.add_async_party(|party| init(i, party))
267 }
268 sim
269 }
270
271 /// Construct a simulation with `n` parties from `init` function that constructs state machine for each party
272 ///
273 /// Each party has index `0 <= i < n` and instantiated via provided `init` function
274 pub fn from_fn<S>(n: u16, mut init: impl FnMut(u16) -> S) -> Self
275 where
276 S: crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
277 {
278 let mut sim = Self::with_capacity(n);
279 for i in 0..n {
280 sim.add_party(init(i));
281 }
282 sim
283 }
284
285 /// Adds new party into the protocol
286 ///
287 /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
288 /// simulation after this party was added.
289 pub fn add_party(
290 &mut self,
291 party: impl crate::state_machine::StateMachine<Output = O, Msg = M> + 'a,
292 ) {
293 self.parties.push(Party::Active {
294 party: Box::new(party),
295 wants_one_more_msg: false,
296 })
297 }
298
299 /// Adds new party, defined as an async function, into the protocol
300 ///
301 /// New party will be assigned index `i = n - 1` where `n` is amount of parties in the
302 /// simulation after this party was added.
303 ///
304 /// Async function will be converted into a [state machine](crate::state_machine). Because of that,
305 /// it cannot await on any futures that aren't provided by `MpcParty` (that is given as an argument
306 /// to this function).
307 pub fn add_async_party<F>(&mut self, party: impl FnOnce(crate::state_machine::MpcParty<M>) -> F)
308 where
309 F: core::future::Future<Output = O> + 'a,
310 {
311 self.parties.push(Party::Active {
312 party: Box::new(crate::state_machine::wrap_protocol(party)),
313 wants_one_more_msg: false,
314 })
315 }
316
317 /// Returns amount of parties in the simulation
318 pub fn parties_amount(&self) -> usize {
319 self.parties.len()
320 }
321
322 /// Carries out the simulation
323 pub fn run(mut self) -> Result<SimResult<O>, SimError> {
324 let mut messages_queue = MessagesQueue::new(self.parties.len());
325 let mut parties_left = self.parties.len();
326
327 while parties_left > 0 {
328 'next_party: for (i, party_state) in (0..).zip(&mut self.parties) {
329 'this_party: loop {
330 let Party::Active {
331 party,
332 wants_one_more_msg,
333 } = party_state
334 else {
335 continue 'next_party;
336 };
337
338 if *wants_one_more_msg {
339 if let Some(message) = messages_queue.recv_next_msg(i) {
340 party
341 .received_msg(message)
342 .map_err(|_| Reason::SaveIncomingMsg)?;
343 *wants_one_more_msg = false;
344 } else {
345 continue 'next_party;
346 }
347 }
348
349 match party.proceed() {
350 ProceedResult::SendMsg(msg) => {
351 messages_queue.send_message(i, msg)?;
352 continue 'this_party;
353 }
354 ProceedResult::NeedsOneMoreMessage => {
355 *wants_one_more_msg = true;
356 continue 'this_party;
357 }
358 ProceedResult::Output(out) => {
359 *party_state = Party::Finished(out);
360 parties_left -= 1;
361 continue 'next_party;
362 }
363 ProceedResult::Yielded => {
364 continue 'this_party;
365 }
366 ProceedResult::Error(err) => {
367 return Err(Reason::ExecutionError(err).into());
368 }
369 }
370 }
371 }
372 }
373
374 Ok(SimResult(
375 self.parties
376 .into_iter()
377 .map(|party| match party {
378 Party::Active { .. } => {
379 unreachable!("there must be no active parties when `parties_left == 0`")
380 }
381 Party::Finished(out) => out,
382 })
383 .collect(),
384 ))
385 }
386}
387
388/// Error indicating that simulation failed
389#[derive(Debug, thiserror::Error)]
390#[error(transparent)]
391pub struct SimError(#[from] Reason);
392
393#[derive(Debug, thiserror::Error)]
394enum Reason {
395 #[error("save incoming message")]
396 SaveIncomingMsg,
397 #[error("execution error")]
398 ExecutionError(#[source] crate::state_machine::ExecutionError),
399 #[error("party #{sender} tried to send a message to non existing party #{recipient}")]
400 UnknownRecipient { sender: u16, recipient: u16 },
401}
402
403struct MessagesQueue<M> {
404 queue: Vec<VecDeque<Incoming<M>>>,
405 next_id: u64,
406}
407
408impl<M: Clone> MessagesQueue<M> {
409 fn new(n: usize) -> Self {
410 Self {
411 queue: alloc::vec![VecDeque::new(); n],
412 next_id: 0,
413 }
414 }
415
416 fn send_message(&mut self, sender: u16, msg: Outgoing<M>) -> Result<(), SimError> {
417 match msg.recipient {
418 MessageDestination::AllParties => {
419 let mut msg_ids = self.next_id..;
420 for (destination, msg_id) in (0..)
421 .zip(&mut self.queue)
422 .filter(|(recipient_index, _)| *recipient_index != sender)
423 .map(|(_, msg)| msg)
424 .zip(msg_ids.by_ref())
425 {
426 destination.push_back(Incoming {
427 id: msg_id,
428 sender,
429 msg_type: MessageType::Broadcast,
430 msg: msg.msg.clone(),
431 })
432 }
433 self.next_id = msg_ids.next().unwrap();
434 }
435 MessageDestination::OneParty(destination) => {
436 let next_id = self.next_id;
437 self.next_id += 1;
438
439 self.queue
440 .get_mut(usize::from(destination))
441 .ok_or(Reason::UnknownRecipient {
442 sender,
443 recipient: destination,
444 })?
445 .push_back(Incoming {
446 id: next_id,
447 sender,
448 msg_type: MessageType::P2P,
449 msg: msg.msg,
450 })
451 }
452 }
453
454 Ok(())
455 }
456
457 fn recv_next_msg(&mut self, recipient: u16) -> Option<Incoming<M>> {
458 self.queue[usize::from(recipient)].pop_front()
459 }
460}
461
462/// Simulates execution of the protocol
463///
464/// Takes amount of participants, and a function that carries out the protocol for
465/// one party. The function takes as input: index of the party, and [`MpcParty`](crate::MpcParty)
466/// that can be used to communicate with others.
467///
468/// ## Example
469/// ```rust,no_run
470/// use round_based::{Mpc, PartyIndex};
471///
472/// # type Result<T, E = ()> = std::result::Result<T, E>;
473/// # type Randomness = [u8; 32];
474/// # type Msg = ();
475/// // Any MPC protocol you want to test
476/// pub async fn protocol_of_random_generation<M>(
477/// party: M,
478/// i: PartyIndex,
479/// n: u16
480/// ) -> Result<Randomness>
481/// where
482/// M: Mpc<ProtocolMessage = Msg>
483/// {
484/// // ...
485/// # todo!()
486/// }
487///
488/// let n = 3;
489///
490/// let output = round_based::sim::run(
491/// n,
492/// |i, party| protocol_of_random_generation(party, i, n),
493/// )
494/// .unwrap()
495/// // unwrap `Result`s
496/// .expect_ok()
497/// // check that all parties produced the same response
498/// .expect_eq();
499///
500/// println!("Output randomness: {}", hex::encode(output));
501/// ```
502pub fn run<M, F>(
503 n: u16,
504 mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>) -> F,
505) -> Result<SimResult<F::Output>, SimError>
506where
507 M: Clone + 'static,
508 F: Future,
509{
510 run_with_setup(core::iter::repeat(()).take(n.into()), |i, party, ()| {
511 party_start(i, party)
512 })
513}
514
515/// Simulates execution of the protocol
516///
517/// Similar to [`run`], but allows some setup to be provided to the protocol execution
518/// function.
519///
520/// Simulation will have as many parties as `setups` iterator yields
521///
522/// ## Example
523/// ```rust,no_run
524/// use round_based::{Mpc, PartyIndex};
525///
526/// # type Result<T, E = ()> = std::result::Result<T, E>;
527/// # type Randomness = [u8; 32];
528/// # type Msg = ();
529/// // Any MPC protocol you want to test
530/// pub async fn protocol_of_random_generation<M>(
531/// rng: impl rand::RngCore,
532/// party: M,
533/// i: PartyIndex,
534/// n: u16
535/// ) -> Result<Randomness>
536/// where
537/// M: Mpc<ProtocolMessage = Msg>
538/// {
539/// // ...
540/// # todo!()
541/// }
542///
543/// let mut rng = rand_dev::DevRng::new();
544/// let n = 3;
545/// let output = round_based::sim::run_with_setup(
546/// core::iter::repeat_with(|| rng.fork()).take(n.into()),
547/// |i, party, rng| protocol_of_random_generation(rng, party, i, n),
548/// )
549/// .unwrap()
550/// // unwrap `Result`s
551/// .expect_ok()
552/// // check that all parties produced the same response
553/// .expect_eq();
554///
555/// println!("Output randomness: {}", hex::encode(output));
556/// ```
557pub fn run_with_setup<S, M, F>(
558 setups: impl IntoIterator<Item = S>,
559 mut party_start: impl FnMut(u16, crate::state_machine::MpcParty<M>, S) -> F,
560) -> Result<SimResult<F::Output>, SimError>
561where
562 M: Clone + 'static,
563 F: Future,
564{
565 let mut sim = Simulation::empty();
566
567 for (setup, i) in setups.into_iter().zip(0u16..) {
568 let party = crate::state_machine::wrap_protocol(|party| party_start(i, party, setup));
569 sim.add_party(party);
570 }
571
572 sim.run()
573}
574
#[cfg(test)]
mod tests {
    mod expect_eq {
        use crate::sim::SimResult;

        #[test]
        fn all_eq() {
            let res = SimResult::from(alloc::vec!["same string", "same string", "same string"])
                .expect_eq();
            assert_eq!(res, "same string")
        }

        #[test]
        fn single_party() {
            let res = SimResult::from(alloc::vec![42]).expect_eq();
            assert_eq!(res, 42)
        }

        #[test]
        #[should_panic(expected = "simulation contained zero parties")]
        fn empty_res() {
            SimResult::<()>::from(alloc::vec![]).expect_eq()
        }

        #[test]
        #[should_panic(expected = "- Parties 0, 1, 3: \"one result\"")]
        fn not_eq() {
            SimResult::from(alloc::vec![
                "one result",
                "one result",
                "another result",
                "one result",
                "and something else",
            ])
            .expect_eq();
        }
    }

    mod expect_ok {
        use crate::sim::SimResult;

        #[test]
        fn all_ok() {
            let res = SimResult::<Result<i32, core::convert::Infallible>>::from(alloc::vec![
                Ok(0),
                Ok(1),
                Ok(2)
            ])
            .expect_ok()
            .into_vec();

            assert_eq!(res, [0, 1, 2]);
        }

        #[test]
        #[should_panic(expected = "3 parties succeeded, 2 parties returned an error")]
        fn not_ok() {
            SimResult::from(alloc::vec![
                Ok(0),
                Err("i couldn't do what you asked :("),
                Ok(2),
                Ok(3),
                Err("something went wrong")
            ])
            .expect_ok();
        }

        #[test]
        #[should_panic(expected = "- Party 1: \"failed\"")]
        fn not_ok_names_party() {
            SimResult::from(alloc::vec![Ok(0), Err("failed")]).expect_ok();
        }
    }

    mod conversions {
        use crate::sim::SimResult;

        #[test]
        fn round_trip() {
            let res = SimResult::from(alloc::vec![1, 2, 3]);
            assert_eq!(res.len(), 3);
            assert_eq!(&res[..], &[1, 2, 3]);

            let list: alloc::vec::Vec<i32> = res.into();
            assert_eq!(list, [1, 2, 3]);

            let collected: alloc::vec::Vec<i32> = SimResult::from(list).into_iter().collect();
            assert_eq!(collected, [1, 2, 3]);
        }
    }
}