round_based/rounds_router/
mod.rs

//! Routes incoming MPC messages between rounds
//!
//! [`RoundsRouter`] is an essential building block of MPC protocols: it processes incoming messages,
//! groups them by rounds, and provides a convenient API for retrieving the messages received at a
//! certain round.
//!
//! ## Example
//!
//! ```rust
//! use round_based::{Mpc, MpcParty, ProtocolMessage, Delivery, PartyIndex};
//! use round_based::rounds_router::{RoundsRouter, simple_store::{RoundInput, RoundMsgs}};
//!
//! #[derive(ProtocolMessage)]
//! pub enum Msg {
//!     Round1(Msg1),
//!     Round2(Msg2),
//! }
//!
//! pub struct Msg1 { /* ... */ }
//! pub struct Msg2 { /* ... */ }
//!
//! pub async fn some_mpc_protocol<M>(party: M, i: PartyIndex, n: u16) -> Result<Output, Error>
//! where
//!     M: Mpc<ProtocolMessage = Msg>,
//! {
//!     let MpcParty{ delivery, .. } = party.into_party();
//!
//!     let (incomings, _outgoings) = delivery.split();
//!
//!     // Build `Rounds`
//!     let mut rounds = RoundsRouter::builder();
//!     let round1 = rounds.add_round(RoundInput::<Msg1>::broadcast(i, n));
//!     let round2 = rounds.add_round(RoundInput::<Msg2>::p2p(i, n));
//!     let mut rounds = rounds.listen(incomings);
//!
//!     // Receive messages from round 1
//!     let msgs: RoundMsgs<Msg1> = rounds.complete(round1).await?;
//!
//!     // ... process received messages
//!
//!     // Receive messages from round 2
//!     let msgs = rounds.complete(round2).await?;
//!
//!     // ...
//!     # todo!()
//! }
//! # type Output = ();
//! # type Error = Box<dyn std::error::Error>;
//! ```
49
50use alloc::{boxed::Box, collections::BTreeMap};
51use core::{any::Any, convert::Infallible, mem};
52
53use futures_util::{Stream, StreamExt};
54use phantom_type::PhantomType;
55use tracing::{debug, error, trace, trace_span, warn, Span};
56
57use crate::Incoming;
58
59#[doc(inline)]
60pub use self::errors::CompleteRoundError;
61pub use self::store::*;
62
63pub mod simple_store;
64mod store;
65
66/// Routes received messages between protocol rounds
67///
68/// See [module level](self) documentation to learn more about it.
69pub struct RoundsRouter<M, S = ()> {
70    incomings: S,
71    rounds: BTreeMap<u16, Option<Box<dyn ProcessRoundMessage<Msg = M> + Send>>>,
72}
73
74impl<M: ProtocolMessage + 'static> RoundsRouter<M> {
75    /// Instantiates [`RoundsRouterBuilder`]
76    pub fn builder() -> RoundsRouterBuilder<M> {
77        RoundsRouterBuilder::new()
78    }
79}
80
81impl<M, S, E> RoundsRouter<M, S>
82where
83    M: ProtocolMessage,
84    S: Stream<Item = Result<Incoming<M>, E>> + Unpin,
85    E: core::error::Error,
86{
87    /// Completes specified round
88    ///
89    /// Waits until all messages at specified round are received. Returns received
90    /// messages if round is successfully completed, or error otherwise.
91    #[inline(always)]
92    pub async fn complete<R>(
93        &mut self,
94        round: Round<R>,
95    ) -> Result<R::Output, CompleteRoundError<R::Error, E>>
96    where
97        R: MessagesStore,
98        M: RoundMessage<R::Msg>,
99    {
100        let round_number = <M as RoundMessage<R::Msg>>::ROUND;
101        let span = trace_span!("Round", n = round_number);
102        debug!(parent: &span, "pending round to complete");
103
104        match self.complete_with_span(&span, round).await {
105            Ok(output) => {
106                trace!(parent: &span, "round successfully completed");
107                Ok(output)
108            }
109            Err(err) => {
110                error!(parent: &span, %err, "round terminated with error");
111                Err(err)
112            }
113        }
114    }
115
116    async fn complete_with_span<R>(
117        &mut self,
118        span: &Span,
119        _round: Round<R>,
120    ) -> Result<R::Output, CompleteRoundError<R::Error, E>>
121    where
122        R: MessagesStore,
123        M: RoundMessage<R::Msg>,
124    {
125        let pending_round = <M as RoundMessage<R::Msg>>::ROUND;
126        if let Some(output) = self.retrieve_round_output_if_its_completed::<R>() {
127            return output;
128        }
129
130        loop {
131            let incoming = match self.incomings.next().await {
132                Some(Ok(msg)) => msg,
133                Some(Err(err)) => return Err(errors::IoError::Io(err).into()),
134                None => return Err(errors::IoError::UnexpectedEof.into()),
135            };
136            let message_round_n = incoming.msg.round();
137
138            let message_round = match self.rounds.get_mut(&message_round_n) {
139                Some(Some(round)) => round,
140                Some(None) => {
141                    warn!(
142                        parent: span,
143                        n = message_round_n,
144                        "got message for the round that was already completed, ignoring it"
145                    );
146                    continue;
147                }
148                None => {
149                    return Err(
150                        errors::RoundsMisuse::UnregisteredRound { n: message_round_n }.into(),
151                    )
152                }
153            };
154            if message_round.needs_more_messages().no() {
155                warn!(
156                    parent: span,
157                    n = message_round_n,
158                    "received message for the round that was already completed, ignoring it"
159                );
160                continue;
161            }
162            message_round.process_message(incoming);
163
164            if pending_round == message_round_n {
165                if let Some(output) = self.retrieve_round_output_if_its_completed::<R>() {
166                    return output;
167                }
168            }
169        }
170    }
171
172    #[allow(clippy::type_complexity)]
173    fn retrieve_round_output_if_its_completed<R>(
174        &mut self,
175    ) -> Option<Result<R::Output, CompleteRoundError<R::Error, E>>>
176    where
177        R: MessagesStore,
178        M: RoundMessage<R::Msg>,
179    {
180        let round_number = <M as RoundMessage<R::Msg>>::ROUND;
181        let round_slot = match self
182            .rounds
183            .get_mut(&round_number)
184            .ok_or(errors::RoundsMisuse::UnregisteredRound { n: round_number })
185        {
186            Ok(slot) => slot,
187            Err(err) => return Some(Err(err.into())),
188        };
189        let round = match round_slot
190            .as_mut()
191            .ok_or(errors::RoundsMisuse::RoundAlreadyCompleted)
192        {
193            Ok(round) => round,
194            Err(err) => return Some(Err(err.into())),
195        };
196        if round.needs_more_messages().no() {
197            Some(Self::retrieve_round_output::<R>(round_slot))
198        } else {
199            None
200        }
201    }
202
203    fn retrieve_round_output<R>(
204        slot: &mut Option<Box<dyn ProcessRoundMessage<Msg = M> + Send>>,
205    ) -> Result<R::Output, CompleteRoundError<R::Error, E>>
206    where
207        R: MessagesStore,
208        M: RoundMessage<R::Msg>,
209    {
210        let mut round = slot.take().ok_or(errors::RoundsMisuse::UnregisteredRound {
211            n: <M as RoundMessage<R::Msg>>::ROUND,
212        })?;
213        match round.take_output() {
214            Ok(Ok(any)) => Ok(*any
215                .downcast::<R::Output>()
216                .or(Err(CompleteRoundError::from(
217                    errors::Bug::MismatchedOutputType,
218                )))?),
219            Ok(Err(any)) => Err(any
220                .downcast::<CompleteRoundError<R::Error, Infallible>>()
221                .or(Err(CompleteRoundError::from(
222                    errors::Bug::MismatchedErrorType,
223                )))?
224                .map_io_err(|e| match e {})),
225            Err(err) => Err(errors::Bug::TakeRoundResult(err).into()),
226        }
227    }
228}
229
230/// Builds [`RoundsRouter`]
231pub struct RoundsRouterBuilder<M> {
232    rounds: BTreeMap<u16, Option<Box<dyn ProcessRoundMessage<Msg = M> + Send>>>,
233}
234
235impl<M> Default for RoundsRouterBuilder<M>
236where
237    M: ProtocolMessage + 'static,
238{
239    fn default() -> Self {
240        Self::new()
241    }
242}
243
244impl<M> RoundsRouterBuilder<M>
245where
246    M: ProtocolMessage + 'static,
247{
248    /// Constructs [`RoundsRouterBuilder`]
249    ///
250    /// Alias to [`RoundsRouter::builder`]
251    pub fn new() -> Self {
252        Self {
253            rounds: BTreeMap::new(),
254        }
255    }
256
257    /// Registers new round
258    ///
259    /// ## Panics
260    /// Panics if round `R` was already registered
261    pub fn add_round<R>(&mut self, message_store: R) -> Round<R>
262    where
263        R: MessagesStore + Send + 'static,
264        R::Output: Send,
265        R::Error: Send,
266        M: RoundMessage<R::Msg>,
267    {
268        let overridden_round = self.rounds.insert(
269            M::ROUND,
270            Some(Box::new(ProcessRoundMessageImpl::new(message_store))),
271        );
272        if overridden_round.is_some() {
273            panic!("round {} is overridden", M::ROUND);
274        }
275        Round {
276            _ph: PhantomType::new(),
277        }
278    }
279
280    /// Builds [`RoundsRouter`]
281    ///
282    /// Takes a stream of incoming messages which will be routed between registered rounds
283    pub fn listen<S, E>(self, incomings: S) -> RoundsRouter<M, S>
284    where
285        S: Stream<Item = Result<Incoming<M>, E>>,
286    {
287        RoundsRouter {
288            incomings,
289            rounds: self.rounds,
290        }
291    }
292}
293
294/// A round of MPC protocol
295///
296/// `Round` can be used to retrieve messages received at this round by calling [`RoundsRouter::complete`]. See
297/// [module level](self) documentation to see usage.
298pub struct Round<S: MessagesStore> {
299    _ph: PhantomType<S>,
300}
301
302trait ProcessRoundMessage {
303    type Msg;
304
305    /// Processes round message
306    ///
307    /// Before calling this method you must ensure that `.needs_more_messages()` returns `Yes`,
308    /// otherwise calling this method is unexpected.
309    fn process_message(&mut self, msg: Incoming<Self::Msg>);
310
311    /// Indicated whether the store needs more messages
312    ///
313    /// If it returns `Yes`, then you need to collect more messages to complete round. If it's `No`
314    /// then you need to take the round output by calling `.take_output()`.
315    fn needs_more_messages(&self) -> NeedsMoreMessages;
316
317    /// Tries to obtain round output
318    ///
319    /// Can be called once `process_message()` returned `NeedMoreMessages::No`.
320    ///
321    /// Returns:
322    /// * `Ok(Ok(any))` — round is successfully completed, `any` needs to be downcasted to `MessageStore::Output`
323    /// * `Ok(Err(any))` — round has terminated with an error, `any` needs to be downcasted to `CompleteRoundError<MessageStore::Error>`
324    /// * `Err(err)` — couldn't retrieve the output, see [`TakeOutputError`]
325    #[allow(clippy::type_complexity)]
326    fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError>;
327}
328
329#[derive(Debug, thiserror::Error)]
330enum TakeOutputError {
331    #[error("output is already taken")]
332    AlreadyTaken,
333    #[error("output is not ready yet, more messages are needed")]
334    NotReady,
335}
336
337enum ProcessRoundMessageImpl<S: MessagesStore, M: ProtocolMessage + RoundMessage<S::Msg>> {
338    InProgress { store: S, _ph: PhantomType<fn(M)> },
339    Completed(Result<S::Output, CompleteRoundError<S::Error, Infallible>>),
340    Gone,
341}
342
343impl<S: MessagesStore, M: ProtocolMessage + RoundMessage<S::Msg>> ProcessRoundMessageImpl<S, M> {
344    pub fn new(store: S) -> Self {
345        if store.wants_more() {
346            Self::InProgress {
347                store,
348                _ph: Default::default(),
349            }
350        } else {
351            Self::Completed(
352                store
353                    .output()
354                    .map_err(|_| errors::ImproperStoreImpl::StoreDidntOutput.into()),
355            )
356        }
357    }
358}
359
360impl<S, M> ProcessRoundMessageImpl<S, M>
361where
362    S: MessagesStore,
363    M: ProtocolMessage + RoundMessage<S::Msg>,
364{
365    fn _process_message(
366        store: &mut S,
367        msg: Incoming<M>,
368    ) -> Result<(), CompleteRoundError<S::Error, Infallible>> {
369        let msg = msg.try_map(M::from_protocol_message).map_err(|msg| {
370            errors::Bug::MessageFromAnotherRound {
371                actual_number: msg.round(),
372                expected_round: M::ROUND,
373            }
374        })?;
375
376        store
377            .add_message(msg)
378            .map_err(CompleteRoundError::ProcessMessage)?;
379        Ok(())
380    }
381}
382
383impl<S, M> ProcessRoundMessage for ProcessRoundMessageImpl<S, M>
384where
385    S: MessagesStore,
386    M: ProtocolMessage + RoundMessage<S::Msg>,
387{
388    type Msg = M;
389
390    fn process_message(&mut self, msg: Incoming<Self::Msg>) {
391        let store = match self {
392            Self::InProgress { store, .. } => store,
393            _ => {
394                return;
395            }
396        };
397
398        match Self::_process_message(store, msg) {
399            Ok(()) => {
400                if store.wants_more() {
401                    return;
402                }
403
404                let store = match mem::replace(self, Self::Gone) {
405                    Self::InProgress { store, .. } => store,
406                    _ => {
407                        *self = Self::Completed(Err(errors::Bug::IncoherentState {
408                            expected: "InProgress",
409                            justification:
410                                "we checked at beginning of the function that `state` is InProgress",
411                        }
412                        .into()));
413                        return;
414                    }
415                };
416
417                match store.output() {
418                    Ok(output) => *self = Self::Completed(Ok(output)),
419                    Err(_err) => {
420                        *self =
421                            Self::Completed(Err(errors::ImproperStoreImpl::StoreDidntOutput.into()))
422                    }
423                }
424            }
425            Err(err) => {
426                *self = Self::Completed(Err(err));
427            }
428        }
429    }
430
431    fn needs_more_messages(&self) -> NeedsMoreMessages {
432        match self {
433            Self::InProgress { .. } => NeedsMoreMessages::Yes,
434            _ => NeedsMoreMessages::No,
435        }
436    }
437
438    fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError> {
439        match self {
440            Self::InProgress { .. } => return Err(TakeOutputError::NotReady),
441            Self::Gone => return Err(TakeOutputError::AlreadyTaken),
442            _ => (),
443        }
444        match mem::replace(self, Self::Gone) {
445            Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))),
446            Self::Completed(Err(err)) => Ok(Err(Box::new(err))),
447            _ => unreachable!("it's checked to be completed"),
448        }
449    }
450}
451
/// Answer of `ProcessRoundMessage::needs_more_messages()`
enum NeedsMoreMessages {
    /// The round can't be completed yet
    Yes,
    /// The round is over (successfully or not) and its result can be taken
    No,
}
456
457#[allow(dead_code)]
458impl NeedsMoreMessages {
459    pub fn yes(&self) -> bool {
460        matches!(self, Self::Yes)
461    }
462    pub fn no(&self) -> bool {
463        matches!(self, Self::No)
464    }
465}
466
467/// When something goes wrong
468pub mod errors {
469    use super::TakeOutputError;
470
471    /// Error indicating that `Rounds` failed to complete certain round
472    #[derive(Debug, thiserror::Error)]
473    pub enum CompleteRoundError<ProcessErr, IoErr> {
474        /// [`MessagesStore`](super::MessagesStore) failed to process this message
475        #[error("failed to process the message")]
476        ProcessMessage(#[source] ProcessErr),
477        /// Receiving next message resulted into i/o error
478        #[error("receive next message")]
479        Io(#[from] IoError<IoErr>),
480        /// Some implementation specific error
481        ///
482        /// Error may be result of improper `MessagesStore` implementation, API misuse, or bug
483        /// in `Rounds` implementation
484        #[error("implementation error")]
485        Other(#[source] OtherError),
486    }
487
488    /// Error indicating that receiving next message resulted into i/o error
489    #[derive(Debug, thiserror::Error)]
490    pub enum IoError<E> {
491        /// I/O error
492        #[error("i/o error")]
493        Io(#[source] E),
494        /// Encountered unexpected EOF
495        #[error("unexpected eof")]
496        UnexpectedEof,
497    }
498
499    /// Some implementation specific error
500    ///
501    /// Error may be result of improper `MessagesStore` implementation, API misuse, or bug
502    /// in `Rounds` implementation
503    #[derive(Debug, thiserror::Error)]
504    #[error(transparent)]
505    pub struct OtherError(OtherReason);
506
507    #[derive(Debug, thiserror::Error)]
508    pub(super) enum OtherReason {
509        #[error("improper `MessagesStore` implementation")]
510        ImproperStoreImpl(#[source] ImproperStoreImpl),
511        #[error("`Rounds` API misuse")]
512        RoundsMisuse(#[source] RoundsMisuse),
513        #[error("bug in `Rounds` (please, open a issue)")]
514        Bug(#[source] Bug),
515    }
516
517    #[derive(Debug, thiserror::Error)]
518    pub(super) enum ImproperStoreImpl {
519        /// Store indicated that it received enough messages but didn't output
520        ///
521        /// I.e. [`store.wants_more()`] returned `false`, but `store.output()` returned `Err(_)`.
522        #[error("store didn't output")]
523        StoreDidntOutput,
524    }
525
526    #[derive(Debug, thiserror::Error)]
527    pub(super) enum RoundsMisuse {
528        #[error("round is already completed")]
529        RoundAlreadyCompleted,
530        #[error("round {n} is not registered")]
531        UnregisteredRound { n: u16 },
532    }
533
534    #[derive(Debug, thiserror::Error)]
535    pub(super) enum Bug {
536        #[error(
537            "message originates from another round: we process messages from round \
538            {expected_round}, got message from round {actual_number}"
539        )]
540        MessageFromAnotherRound {
541            expected_round: u16,
542            actual_number: u16,
543        },
544        #[error("state is incoherent, it's expected to be {expected}: {justification}")]
545        IncoherentState {
546            expected: &'static str,
547            justification: &'static str,
548        },
549        #[error("mismatched output type")]
550        MismatchedOutputType,
551        #[error("mismatched error type")]
552        MismatchedErrorType,
553        #[error("take round result")]
554        TakeRoundResult(#[source] TakeOutputError),
555    }
556
557    impl<ProcessErr, IoErr> CompleteRoundError<ProcessErr, IoErr> {
558        pub(super) fn map_io_err<E, F>(self, f: F) -> CompleteRoundError<ProcessErr, E>
559        where
560            F: FnOnce(IoErr) -> E,
561        {
562            match self {
563                CompleteRoundError::Io(err) => CompleteRoundError::Io(err.map_err(f)),
564                CompleteRoundError::ProcessMessage(err) => CompleteRoundError::ProcessMessage(err),
565                CompleteRoundError::Other(err) => CompleteRoundError::Other(err),
566            }
567        }
568    }
569
570    impl<E> IoError<E> {
571        pub(super) fn map_err<B, F>(self, f: F) -> IoError<B>
572        where
573            F: FnOnce(E) -> B,
574        {
575            match self {
576                IoError::Io(e) => IoError::Io(f(e)),
577                IoError::UnexpectedEof => IoError::UnexpectedEof,
578            }
579        }
580    }
581
582    macro_rules! impl_from_other_error {
583        ($($err:ident),+,) => {$(
584            impl<E1, E2> From<$err> for CompleteRoundError<E1, E2> {
585                fn from(err: $err) -> Self {
586                    Self::Other(OtherError(OtherReason::$err(err)))
587                }
588            }
589        )+};
590    }
591
592    impl_from_other_error! {
593        ImproperStoreImpl,
594        RoundsMisuse,
595        Bug,
596    }
597}
598
#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use crate::Incoming;

    /// Store that is satisfied from the start: it never wants any messages
    struct Store;

    #[derive(crate::ProtocolMessage)]
    #[protocol_message(root = crate)]
    enum FakeProtocolMsg {
        R1(Msg1),
    }
    struct Msg1;

    impl super::MessagesStore for Store {
        type Msg = Msg1;
        type Output = ();
        type Error = Infallible;

        fn add_message(&mut self, _msg: Incoming<Self::Msg>) -> Result<(), Self::Error> {
            Ok(())
        }
        fn wants_more(&self) -> bool {
            false
        }
        fn output(self) -> Result<Self::Output, Self> {
            Ok(())
        }
    }

    /// A round whose store needs no messages must complete without reading the stream
    #[tokio::test]
    async fn complete_round_that_expects_no_messages() {
        // `pending()` never yields an item, so reading from it would never finish
        let incomings =
            futures::stream::pending::<Result<Incoming<FakeProtocolMsg>, Infallible>>();

        let mut builder = super::RoundsRouter::builder();
        let round1 = builder.add_round(Store);
        let mut router = builder.listen(incomings);

        router.complete(round1).await.unwrap();
    }
}