round_based/mpc/party/
router.rs

1//! Routes incoming MPC messages between rounds
2//!
3//! Router is a building block, used in MpcParty to register rounds and route
4//! incoming messages between them
5
6use alloc::{boxed::Box, collections::BTreeMap};
7use core::{any::Any, convert::Infallible, mem};
8
9use phantom_type::PhantomType;
10use tracing::{error, trace_span, warn};
11
12use crate::{
13    Incoming, ProtocolMsg, RoundMsg,
14    round::{RoundInfo, RoundStore},
15};
16
17/// Routes received messages between protocol rounds
18pub struct RoundsRouter<M> {
19    rounds: BTreeMap<u16, Option<Box<dyn ProcessRoundMessage<Msg = M>>>>,
20}
21
22impl<M> RoundsRouter<M>
23where
24    M: ProtocolMsg + 'static,
25{
26    pub fn new() -> Self {
27        Self {
28            rounds: Default::default(),
29        }
30    }
31
32    /// Registers new round
33    ///
34    /// ## Panics
35    /// Panics if round `R` was already registered
36    pub fn add_round<R>(&mut self, message_store: R) -> Round<R>
37    where
38        R: RoundStore,
39        M: RoundMsg<R::Msg>,
40    {
41        let overridden_round = self.rounds.insert(
42            M::ROUND,
43            Some(Box::new(ProcessRoundMessageImpl::new(message_store))),
44        );
45        if overridden_round.is_some() {
46            panic!("round {} is overridden", M::ROUND);
47        }
48        Round {
49            _ph: PhantomType::new(),
50        }
51    }
52
53    pub fn received_msg(&mut self, incoming: Incoming<M>) -> Result<(), errors::UnregisteredRound> {
54        let msg_round_n = incoming.msg.round();
55        let span = trace_span!(
56            "Round::received_msg",
57            round = %msg_round_n,
58            sender = %incoming.sender,
59            ty = ?incoming.msg_type
60        );
61        let _guard = span.enter();
62
63        let message_round = match self.rounds.get_mut(&msg_round_n) {
64            Some(Some(round)) => round,
65            Some(None) => {
66                warn!("got message for the round that was already completed, ignoring it");
67                return Ok(());
68            }
69            None => {
70                return Err(errors::UnregisteredRound {
71                    n: msg_round_n,
72                    witness_provided: false,
73                });
74            }
75        };
76        if message_round.needs_more_messages().no() {
77            warn!("received message for the round that was already completed, ignoring it");
78            return Ok(());
79        }
80        message_round.process_message(incoming);
81        Ok(())
82    }
83
84    #[allow(clippy::type_complexity)]
85    pub fn complete_round<R>(
86        &mut self,
87        round: Round<R>,
88    ) -> Result<Result<R::Output, errors::CompleteRoundError<R::Error, Infallible>>, Round<R>>
89    where
90        R: RoundInfo,
91        M: RoundMsg<R::Msg>,
92    {
93        let message_round = match self.rounds.get_mut(&M::ROUND) {
94            Some(Some(round)) => round,
95            Some(None) => {
96                return Ok(Err(
97                    errors::Bug::RoundGoneButWitnessExists { n: M::ROUND }.into()
98                ));
99            }
100            None => {
101                return Ok(Err(errors::UnregisteredRound {
102                    n: M::ROUND,
103                    witness_provided: true,
104                }
105                .into()));
106            }
107        };
108        if message_round.needs_more_messages().yes() {
109            return Err(round);
110        }
111        Ok(Self::retrieve_round_output::<R>(message_round))
112    }
113
114    fn retrieve_round_output<R>(
115        round: &mut Box<dyn ProcessRoundMessage<Msg = M>>,
116    ) -> Result<R::Output, errors::CompleteRoundError<R::Error, Infallible>>
117    where
118        R: RoundInfo,
119    {
120        match round.take_output() {
121            Ok(Ok(any)) => Ok(*any
122                .downcast::<R::Output>()
123                .or(Err(errors::Bug::MismatchedOutputType))?),
124            Ok(Err(any)) => Err(*any
125                .downcast::<errors::CompleteRoundError<R::Error, Infallible>>()
126                .or(Err(errors::Bug::MismatchedErrorType))?),
127            Err(err) => Err(errors::Bug::TakeRoundResult(err).into()),
128        }
129    }
130}
131
132/// A witness that round has been registered in the router
133///
134/// Can be used later to claim messages received in this round
135pub struct Round<S> {
136    _ph: PhantomType<S>,
137}
138
139impl<S> core::fmt::Debug for Round<S> {
140    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
141        f.debug_struct("Round").finish_non_exhaustive()
142    }
143}
144
145trait ProcessRoundMessage {
146    type Msg;
147
148    /// Processes round message
149    ///
150    /// Before calling this method you must ensure that `.needs_more_messages()` returns `Yes`,
151    /// otherwise calling this method is unexpected.
152    fn process_message(&mut self, msg: Incoming<Self::Msg>);
153
154    /// Indicated whether the store needs more messages
155    ///
156    /// If it returns `Yes`, then you need to collect more messages to complete round. If it's `No`
157    /// then you need to take the round output by calling `.take_output()`.
158    fn needs_more_messages(&self) -> NeedsMoreMessages;
159
160    /// Tries to obtain round output
161    ///
162    /// Can be called once `process_message()` returned `NeedMoreMessages::No`.
163    ///
164    /// Returns:
165    /// * `Ok(Ok(any))` — round is successfully completed, `any` needs to be downcasted to `MessageStore::Output`
166    /// * `Ok(Err(any))` — round has terminated with an error, `any` needs to be downcasted to `CompleteRoundError<MessageStore::Error>`
167    /// * `Err(err)` — couldn't retrieve the output, see [`TakeOutputError`]
168    #[allow(clippy::type_complexity)]
169    fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError>;
170}
171
172#[derive(Debug, thiserror::Error)]
173enum TakeOutputError {
174    #[error("output is already taken")]
175    AlreadyTaken,
176    #[error("output is not ready yet, more messages are needed")]
177    NotReady,
178}
179
180enum ProcessRoundMessageImpl<S: RoundStore, M: ProtocolMsg + RoundMsg<S::Msg>> {
181    InProgress { store: S, _ph: PhantomType<fn(M)> },
182    Completed(Result<S::Output, errors::CompleteRoundError<S::Error, Infallible>>),
183    Gone,
184}
185
186impl<S: RoundStore, M: ProtocolMsg + RoundMsg<S::Msg>> ProcessRoundMessageImpl<S, M> {
187    pub fn new(store: S) -> Self {
188        if store.wants_more() {
189            Self::InProgress {
190                store,
191                _ph: Default::default(),
192            }
193        } else {
194            Self::Completed(
195                store
196                    .output()
197                    .map_err(|_| errors::ImproperRoundStore::StoreDidntOutput.into()),
198            )
199        }
200    }
201}
202
203impl<S, M> ProcessRoundMessageImpl<S, M>
204where
205    S: RoundStore,
206    M: ProtocolMsg + RoundMsg<S::Msg>,
207{
208    fn _process_message(
209        store: &mut S,
210        msg: Incoming<M>,
211    ) -> Result<(), errors::CompleteRoundError<S::Error, Infallible>> {
212        let msg = msg.try_map(M::from_protocol_msg).map_err(|msg| {
213            errors::Bug::MessageFromAnotherRound {
214                actual_number: msg.round(),
215                expected_round: M::ROUND,
216            }
217        })?;
218
219        store
220            .add_message(msg)
221            .map_err(errors::CompleteRoundError::ProcessMsg)?;
222        Ok(())
223    }
224}
225
226impl<S, M> ProcessRoundMessage for ProcessRoundMessageImpl<S, M>
227where
228    S: RoundStore,
229    M: ProtocolMsg + RoundMsg<S::Msg>,
230{
231    type Msg = M;
232
233    fn process_message(&mut self, msg: Incoming<Self::Msg>) {
234        let store = match self {
235            Self::InProgress { store, .. } => store,
236            _ => {
237                return;
238            }
239        };
240
241        match Self::_process_message(store, msg) {
242            Ok(()) => {
243                if store.wants_more() {
244                    return;
245                }
246
247                let store = match mem::replace(self, Self::Gone) {
248                    Self::InProgress { store, .. } => store,
249                    _ => {
250                        *self = Self::Completed(Err(errors::Bug::IncoherentState {
251                            expected: "InProgress",
252                            justification:
253                                "we checked at beginning of the function that `state` is InProgress",
254                        }.into()));
255                        return;
256                    }
257                };
258
259                match store.output() {
260                    Ok(output) => *self = Self::Completed(Ok(output)),
261                    Err(_err) => {
262                        *self = Self::Completed(Err(
263                            errors::ImproperRoundStore::StoreDidntOutput.into()
264                        ))
265                    }
266                }
267            }
268            Err(err) => {
269                *self = Self::Completed(Err(err));
270            }
271        }
272    }
273
274    fn needs_more_messages(&self) -> NeedsMoreMessages {
275        match self {
276            Self::InProgress { .. } => NeedsMoreMessages::Yes,
277            _ => NeedsMoreMessages::No,
278        }
279    }
280
281    fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError> {
282        match self {
283            Self::InProgress { .. } => return Err(TakeOutputError::NotReady),
284            Self::Gone => return Err(TakeOutputError::AlreadyTaken),
285            _ => (),
286        }
287        match mem::replace(self, Self::Gone) {
288            Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))),
289            Self::Completed(Err(err)) => Ok(Err(Box::new(err))),
290            _ => unreachable!("it's checked to be completed"),
291        }
292    }
293}
294
/// Answer of [`ProcessRoundMessage::needs_more_messages`]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NeedsMoreMessages {
    /// The round is still collecting messages
    Yes,
    /// The round is over: its output (or error) can be taken
    No,
}

impl NeedsMoreMessages {
    /// Returns `true` if the round needs more messages
    pub fn yes(&self) -> bool {
        matches!(self, Self::Yes)
    }

    /// Returns `true` if the round doesn't need any more messages
    pub fn no(&self) -> bool {
        matches!(self, Self::No)
    }
}
309
310/// When something goes wrong
311pub mod errors {
312    pub use crate::mpc::party::CompleteRoundError;
313
314    use super::TakeOutputError;
315
316    #[derive(Debug, thiserror::Error)]
317    #[error("received a message for unregistered round")]
318    pub(in crate::mpc) struct UnregisteredRound {
319        pub n: u16,
320        pub(super) witness_provided: bool,
321    }
322
323    /// Router error
324    ///
325    /// Refer to [`CompleteRoundError::Router`] docs
326    #[derive(Debug, thiserror::Error)]
327    #[error(transparent)]
328    pub struct RouterError(Reason);
329
330    #[derive(Debug, thiserror::Error)]
331    pub(super) enum Reason {
332        /// Router API has been misused
333        ///
334        /// For instance, this error is returned when protocol implementation does not register
335        /// certain round of the protocol, but then a message from this round is received. In
336        /// this case, router doesn't have anywhere to route the message to, so an [`ApiMisuse`]
337        /// error is returned.
338        #[error("api misuse")]
339        ApiMisuse(#[source] ApiMisuse),
340        /// Improper [`RoundStore`](crate::round::RoundStore) implementation
341        ///
342        /// For instance, this error is returned when round store indicates that it doesn't need
343        /// any more messages ([`RoundStore::wants_more`](crate::round::RoundStore::wants_more)
344        /// returns `false`), but then it didn't output anything ([`RoundStore::output`](crate::round::RoundStore::output)
345        /// returns `Err(_)`)
346        #[error("improper round store")]
347        ImproperRoundStore(#[source] ImproperRoundStore),
348        /// Indicates that there's a bug in the router implementation
349        #[error("bug (please, open an issue)")]
350        Bug(#[source] Bug),
351    }
352
353    #[derive(Debug, thiserror::Error)]
354    pub(super) enum ApiMisuse {
355        #[error(transparent)]
356        UnregisteredRound(#[from] UnregisteredRound),
357    }
358
359    #[derive(Debug, thiserror::Error)]
360    pub(super) enum ImproperRoundStore {
361        /// Store indicated that it received enough messages but didn't output
362        ///
363        /// I.e. [`store.wants_more()`] returned `false`, but `store.output()` returned `Err(_)`.
364        #[error("store didn't output")]
365        StoreDidntOutput,
366    }
367
368    #[derive(Debug, thiserror::Error)]
369    pub(super) enum Bug {
370        #[error("round is gone, but witness exists")]
371        RoundGoneButWitnessExists { n: u16 },
372        #[error(
373            "message originates from another round: we process messages from round \
374            {expected_round}, got message from round {actual_number}"
375        )]
376        MessageFromAnotherRound {
377            expected_round: u16,
378            actual_number: u16,
379        },
380        #[error("state is incoherent, it's expected to be {expected}: {justification}")]
381        IncoherentState {
382            expected: &'static str,
383            justification: &'static str,
384        },
385        #[error("take round result")]
386        TakeRoundResult(#[source] TakeOutputError),
387        #[error("mismatched output type")]
388        MismatchedOutputType,
389        #[error("mismatched error type")]
390        MismatchedErrorType,
391    }
392
393    macro_rules! impl_round_complete_from {
394        ($(|$err:ident: $err_ty:ty| $err_fn:expr),+$(,)?) => {$(
395            impl<E, IoErr> From<$err_ty> for CompleteRoundError<E, IoErr> {
396                fn from($err: $err_ty) -> Self {
397                    $err_fn
398                }
399            }
400        )+};
401    }
402
403    impl_round_complete_from! {
404        |err: ApiMisuse| CompleteRoundError::Router(RouterError(Reason::ApiMisuse(err))),
405        |err: ImproperRoundStore| CompleteRoundError::Router(RouterError(Reason::ImproperRoundStore(err))),
406        |err: Bug| CompleteRoundError::Router(RouterError(Reason::Bug(err))),
407        |err: UnregisteredRound| ApiMisuse::UnregisteredRound(err).into(),
408    }
409}
410
#[cfg(test)]
mod tests {
    /// Store of a round that doesn't expect any messages
    struct Store;

    #[derive(crate::ProtocolMsg)]
    #[protocol_msg(root = crate)]
    enum FakeProtocolMsg {
        R1(Msg1),
    }
    struct Msg1;

    impl crate::round::RoundInfo for Store {
        type Msg = Msg1;
        type Output = ();
        type Error = core::convert::Infallible;
    }
    impl crate::round::RoundStore for Store {
        fn add_message(&mut self, _msg: crate::Incoming<Self::Msg>) -> Result<(), Self::Error> {
            Ok(())
        }
        fn wants_more(&self) -> bool {
            false
        }
        fn output(self) -> Result<Self::Output, Self> {
            Ok(())
        }
    }

    #[test]
    fn complete_round_that_expects_no_messages() {
        let mut rounds = super::RoundsRouter::<FakeProtocolMsg>::new();
        let round1 = rounds.add_round(Store);

        rounds.complete_round(round1).unwrap().unwrap();
    }

    /// `add_round` documents that registering the same round twice panics
    #[test]
    #[should_panic(expected = "is overridden")]
    fn registering_same_round_twice_panics() {
        let mut rounds = super::RoundsRouter::<FakeProtocolMsg>::new();
        let _round1 = rounds.add_round(Store);
        let _round1_again = rounds.add_round(Store);
    }
}