// round_based/echo_broadcast/mod.rs

1//! Reliable broadcast for any protocol via echo messages
2//!
3//! Broadcast message is a message meant to be received by all participants of the protocol.
4//!
//! We say that a message is reliably broadcast if, upon reception, it is guaranteed that all
//! honest participants of the protocol have received the same message.
7//!
//! One way to achieve reliable broadcast is to add an echo round: when we receive the
//! messages of a reliable broadcast round, we hash all of them and send the hash to all
//! other participants. If a party receives the same hash from everyone else, it can be
//! assured that the messages in the round were reliably broadcast.
12//!
//! This module provides a mechanism that automatically adds an echo round for each
//! round of the protocol that requires reliable broadcast.
15//!
16//! ## Example
17//!
18//! ```rust
19//! # #[derive(round_based::ProtocolMsg, Clone, udigest::Digestable)]
20//! # enum KeygenMsg {}
21//! # struct KeyShare;
22//! # struct Error;
23//! # type Result<T> = std::result::Result<T, Error>;
24//! # async fn doc() -> Result<()> {
25//! // protocol to be executed that **requires** reliable broadcast
26//! async fn keygen<M>(mpc: M, i: u16, n: u16) -> Result<KeyShare>
27//! where
28//!     M: round_based::Mpc<Msg = KeygenMsg>
29//! {
30//!     // ...
31//! # unimplemented!()
32//! }
33//! // The full message type, which corresponds to keygen msg + echo broadcast msg
34//! type Msg = round_based::echo_broadcast::Msg<sha2::Sha256, KeygenMsg>;
35//! // establishes network connection(s) to other parties, but
36//! // **does not** support reliable broadcast
37//! async fn connect() ->
38//!     impl futures::Stream<Item = Result<round_based::Incoming<Msg>>>
39//!         + futures::Sink<round_based::Outgoing<Msg>, Error = Error>
40//!         + Unpin
41//! {
42//!     // ...
43//! # round_based::_docs::fake_delivery()
44//! }
45//! let delivery = connect().await;
46//!
47//! # let (i, n) = (1, 3);
48//! // constructs an MPC engine as usual
49//! let mpc = round_based::mpc::connected(delivery);
50//! // wraps an engine to add reliable broadcast support
51//! let mpc = round_based::echo_broadcast::wrap(mpc, i, n);
52//!
53//! // execute the protocol
54//! let keyshare = keygen(mpc, i, n).await?;
55//! # Ok(()) }
56//! ```
57
58use core::marker::PhantomData;
59
60use alloc::collections::btree_map::BTreeMap;
61use digest::Digest;
62
63use crate::{
64    Mpc, MpcExecution, Outgoing, ProtocolMsg, RoundMsg,
65    round::{RoundInfo, RoundStore, RoundStoreExt},
66};
67
68mod error;
69mod store;
70
71pub use self::error::{CompleteRoundError, EchoError, Error};
72
73/// Message of the protocol with echo broadcast round(s)
74pub enum Msg<D: Digest, M> {
75    /// Message from echo broadcast sub-protocol
76    Echo {
77        /// Indicates for which round of main protocol this echo message is transmitted
78        ///
79        /// Note that this field is controlled by potential malicious party. If it sets it
80        /// to the round that doesn't exist, the protocol will likely be aborted with an error
81        /// that we received a message from unregistered round, which may appear as implementation
82        /// error (i.e. API misuse), but in fact it's a malicious abort.
83        round: u16,
84        /// Hash of all messages received in `round`
85        hash: digest::Output<D>,
86    },
87    /// Message from the main protocol
88    Main(M),
89}
90
91/// Sub-messages of [`Msg`]
92///
93/// Sub-messages implement [`RoundMsg`] trait for [`Msg`]
94mod sub_msg {
95    pub struct EchoMsg<D: digest::Digest, R> {
96        pub hash: digest::Output<D>,
97        pub _round: core::marker::PhantomData<R>,
98    }
99    #[derive(Debug, Clone)]
100    pub struct Main<M>(pub M);
101
102    impl<D: digest::Digest, R> Clone for EchoMsg<D, R> {
103        fn clone(&self) -> Self {
104            Self {
105                hash: self.hash.clone(),
106                _round: core::marker::PhantomData,
107            }
108        }
109    }
110}
111
112// `D` doesn't implement traits like `Clone`, `Eq`, etc. so we have to implement those traits by hand
113
114impl<D: Digest, M: Clone> Clone for Msg<D, M> {
115    fn clone(&self) -> Self {
116        match self {
117            Self::Echo { round, hash } => Self::Echo {
118                round: *round,
119                hash: hash.clone(),
120            },
121            Self::Main(msg) => Self::Main(msg.clone()),
122        }
123    }
124}
125
126impl<D: Digest, M: PartialEq> PartialEq for Msg<D, M> {
127    fn eq(&self, other: &Self) -> bool {
128        match self {
129            Self::Echo { round, hash } => {
130                matches!(other, Self::Echo { round: r2, hash: h2 } if round == r2 && hash == h2)
131            }
132            Self::Main(msg) => matches!(other, Self::Main(m2) if msg == m2),
133        }
134    }
135}
136
137impl<D: Digest, M: PartialEq> Eq for Msg<D, M> {}
138
139impl<D: Digest, M: core::fmt::Debug> core::fmt::Debug for Msg<D, M> {
140    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
141        match self {
142            Self::Echo { round, hash } => f
143                .debug_struct("Msg::Echo")
144                .field("round", round)
145                .field("hash", hash)
146                .finish(),
147            Self::Main(msg) => f.debug_tuple("Msg::Main").field(msg).finish(),
148        }
149    }
150}
151
152impl<D: Digest, M: ProtocolMsg> ProtocolMsg for Msg<D, M> {
153    fn round(&self) -> u16 {
154        match self {
155            Self::Echo { round, .. } => 2 * round + 1,
156            Self::Main(m) => 2 * m.round(),
157        }
158    }
159}
160
161impl<D: Digest, M: ProtocolMsg, R> RoundMsg<sub_msg::EchoMsg<D, R>> for Msg<D, M>
162where
163    M: RoundMsg<R>,
164{
165    const ROUND: u16 = 2 * M::ROUND + 1;
166    fn to_protocol_msg(round_msg: sub_msg::EchoMsg<D, R>) -> Self {
167        Self::Echo {
168            round: M::ROUND,
169            hash: round_msg.hash,
170        }
171    }
172    fn from_protocol_msg(protocol_msg: Self) -> Result<sub_msg::EchoMsg<D, R>, Self> {
173        match protocol_msg {
174            Self::Echo { round, hash } if round == M::ROUND => Ok(sub_msg::EchoMsg {
175                hash,
176                _round: PhantomData,
177            }),
178            _ => Err(protocol_msg),
179        }
180    }
181}
182
183impl<D: Digest, ProtoM, RoundM> RoundMsg<sub_msg::Main<RoundM>> for Msg<D, ProtoM>
184where
185    ProtoM: ProtocolMsg + RoundMsg<RoundM>,
186{
187    const ROUND: u16 = 2 * <ProtoM as RoundMsg<RoundM>>::ROUND;
188    fn to_protocol_msg(round_msg: sub_msg::Main<RoundM>) -> Self {
189        Self::Main(ProtoM::to_protocol_msg(round_msg.0))
190    }
191    fn from_protocol_msg(protocol_msg: Self) -> Result<sub_msg::Main<RoundM>, Self> {
192        if let Self::Main(msg) = protocol_msg {
193            ProtoM::from_protocol_msg(msg)
194                .map(sub_msg::Main)
195                .map_err(|m| Self::Main(m))
196        } else {
197            Err(protocol_msg)
198        }
199    }
200}
201
202/// Wraps an [`Mpc`] engine and provides echo broadcast capabilities
203pub fn wrap<D, M, MainMsg>(party: M, i: u16, n: u16) -> WithEchoBroadcast<D, M, MainMsg>
204where
205    D: Digest,
206    M: Mpc<Msg = Msg<D, MainMsg>>,
207    MainMsg: udigest::Digestable,
208{
209    WithEchoBroadcast {
210        party,
211        i,
212        n,
213        sent_reliable_msgs: Default::default(),
214        _ph: PhantomData,
215    }
216}
217
218/// [`Mpc`] engine with echo-broadcast capabilities
219pub struct WithEchoBroadcast<D: Digest, M, Msg> {
220    party: M,
221    i: u16,
222    n: u16,
223    sent_reliable_msgs: BTreeMap<u16, Option<Msg>>,
224    _ph: PhantomData<D>,
225}
226
227impl<D: Digest, M, Msg> WithEchoBroadcast<D, M, Msg> {
228    fn map_party<P>(self, f: impl FnOnce(M) -> P) -> WithEchoBroadcast<D, P, Msg> {
229        let party = f(self.party);
230        WithEchoBroadcast {
231            party,
232            i: self.i,
233            n: self.n,
234            sent_reliable_msgs: self.sent_reliable_msgs,
235            _ph: PhantomData,
236        }
237    }
238}
239
240impl<D, M, MainMsg> Mpc for WithEchoBroadcast<D, M, MainMsg>
241where
242    D: Digest + 'static,
243    M: Mpc<Msg = Msg<D, MainMsg>>,
244    MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static,
245{
246    type Msg = MainMsg;
247
248    type Exec = WithEchoBroadcast<D, M::Exec, MainMsg>;
249
250    type SendErr = error::Error<M::SendErr>;
251
252    fn add_round<R>(&mut self, round: R) -> <Self::Exec as MpcExecution>::Round<R>
253    where
254        R: RoundStore,
255        Self::Msg: RoundMsg<R::Msg>,
256    {
257        let reliable_broadcast_required = round
258            .read_prop::<crate::round::props::RequiresReliableBroadcast>()
259            .map(|x| x.0);
260        if reliable_broadcast_required == Some(true) {
261            let (main_round, echo_round) = store::new::<D, MainMsg, _>(self.i, self.n, round);
262            let main_round = self.party.add_round(store::WithMainMsg(main_round));
263            let echo_round = self.party.add_round(store::WithEchoError::from(echo_round));
264
265            self.sent_reliable_msgs.insert(Self::Msg::ROUND, None);
266
267            Round(Inner::WithReliabilityCheck {
268                main_round,
269                echo_round,
270            })
271        } else {
272            let round = self
273                .party
274                .add_round(store::WithError(store::WithMainMsg(round)));
275            Round(Inner::Unmodified(round))
276        }
277    }
278
279    fn finish_setup(self) -> Self::Exec {
280        self.map_party(|p| p.finish_setup())
281    }
282}
283
284impl<D, M, MainMsg> WithEchoBroadcast<D, M, MainMsg>
285where
286    D: Digest,
287    MainMsg: ProtocolMsg + Clone,
288{
289    fn on_send(&mut self, outgoing: &mut Outgoing<MainMsg>) -> Result<(), error::EchoError> {
290        if let Some(slot) = self.sent_reliable_msgs.get_mut(&outgoing.msg.round()) {
291            if !outgoing.recipient.is_reliable_broadcast() {
292                // it's reliable broadcast round, but message is not reliable broadcast
293                return Err(error::Reason::SentNonReliableMsgInReliableRound {
294                    dest: outgoing.recipient,
295                    round: outgoing.msg.round(),
296                }
297                .into());
298            }
299            // Message delivery layer doesn't need to know that protocol wants this message to be
300            // reliably broadcasted - echo broadcast takes care of it
301            outgoing.recipient = crate::MessageDestination::AllParties { reliable: false };
302            if slot.is_some() {
303                return Err(error::Reason::SendTwice.into());
304            }
305            *slot = Some(outgoing.msg.clone())
306        } else if outgoing.recipient.is_reliable_broadcast() {
307            // it's not a reliable broadcast round, but message is a reliable broadcast
308            return Err(error::Reason::SentReliableMsgInNonReliableRound {
309                round: outgoing.msg.round(),
310            }
311            .into());
312        }
313
314        Ok(())
315    }
316}
317
318impl<D, M, MainMsg> MpcExecution for WithEchoBroadcast<D, M, MainMsg>
319where
320    D: Digest + 'static,
321    M: MpcExecution<Msg = Msg<D, MainMsg>>,
322    MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static,
323{
324    type Round<R: RoundInfo> = Round<M, D, MainMsg, R>;
325    type Msg = MainMsg;
326    type CompleteRoundErr<E> =
327        error::CompleteRoundError<M::CompleteRoundErr<error::Error<E>>, M::SendErr>;
328    type SendErr = error::Error<M::SendErr>;
329    type SendMany = WithEchoBroadcast<D, M::SendMany, MainMsg>;
330
331    async fn complete<R>(
332        &mut self,
333        round: Self::Round<R>,
334    ) -> Result<R::Output, Self::CompleteRoundErr<R::Error>>
335    where
336        R: RoundInfo,
337        Self::Msg: RoundMsg<R::Msg>,
338    {
339        match round.0 {
340            Inner::Unmodified(round) => {
341                // regular round that doesn't need reliable broadcast
342                let output = self
343                    .party
344                    .complete(round)
345                    .await
346                    .map_err(error::CompleteRoundError::CompleteRound)?;
347                Ok(output)
348            }
349            Inner::WithReliabilityCheck {
350                main_round,
351                echo_round,
352            } => {
353                // receive all messages in the main round
354                let main_output = self
355                    .party
356                    .complete(main_round)
357                    .await
358                    .map_err(error::CompleteRoundError::CompleteRound)?;
359                // retrieve a msg that we sent in this round
360                let sent_msg =
361                    if let Some(Some(msg)) = self.sent_reliable_msgs.remove(&Self::Msg::ROUND) {
362                        let msg: R::Msg = Self::Msg::from_protocol_msg(msg)
363                            .map_err(|_| error::Reason::SentMsgFromProto)?;
364                        Some(msg)
365                    } else {
366                        None
367                    };
368                // calculate a hash and send it to all other parties
369                let (main_output, hash) = main_output.with_my_msg(sent_msg)?;
370                self.party
371                    .send_to_all(Msg::Echo {
372                        round: Self::Msg::ROUND,
373                        hash,
374                    })
375                    .await
376                    .map_err(error::CompleteRoundError::Send)?;
377                // receive echoes from other parties
378                let echoes = self
379                    .party
380                    .complete(echo_round)
381                    .await
382                    .map_err(error::CompleteRoundError::CompleteRound)?;
383                // check that everyone sent the same hash
384                let main_output = main_output.with_echo_output(echoes)?;
385
386                Ok(main_output)
387            }
388        }
389    }
390
391    async fn send(&mut self, mut outgoing: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
392        self.on_send(&mut outgoing)?;
393
394        self.party
395            .send(outgoing.map(Msg::Main))
396            .await
397            .map_err(error::Error::Main)
398    }
399
400    fn send_many(self) -> Self::SendMany {
401        self.map_party(|p| p.send_many())
402    }
403
404    async fn yield_now(&self) {
405        self.party.yield_now().await
406    }
407}
408
409/// Round registration witness returned by [`WithEchoBroadcast::add_round()`]
410pub struct Round<M, D, ProtoMsg, R>(Inner<M, D, ProtoMsg, R>)
411where
412    M: MpcExecution,
413    D: Digest + 'static,
414    ProtoMsg: 'static,
415    R: RoundInfo;
416
417enum Inner<M, D, ProtoMsg, R>
418where
419    M: MpcExecution,
420    D: Digest + 'static,
421    ProtoMsg: 'static,
422    R: RoundInfo,
423{
424    /// Round that we do not modify (round that doesn't require reliable broadcast)
425    Unmodified(M::Round<store::WithError<store::WithMainMsg<R>>>),
426    WithReliabilityCheck {
427        main_round: M::Round<store::WithMainMsg<store::MainRound<D, ProtoMsg, R>>>,
428        echo_round: M::Round<store::WithEchoError<store::EchoRound<D, R>, R::Error>>,
429    },
430}
431
432impl<D, M, MainMsg> crate::mpc::SendMany for WithEchoBroadcast<D, M, MainMsg>
433where
434    D: Digest + 'static,
435    M: crate::mpc::SendMany<Msg = Msg<D, MainMsg>>,
436    MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static,
437{
438    type Exec = WithEchoBroadcast<D, M::Exec, MainMsg>;
439    type Msg = MainMsg;
440    type SendErr = error::Error<M::SendErr>;
441
442    async fn send(&mut self, mut outgoing: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
443        self.on_send(&mut outgoing)?;
444        self.party
445            .send(outgoing.map(Msg::Main))
446            .await
447            .map_err(error::Error::Main)
448    }
449
450    async fn flush(self) -> Result<Self::Exec, Self::SendErr> {
451        let party = self.party.flush().await.map_err(error::Error::Main)?;
452        Ok(WithEchoBroadcast {
453            party,
454            i: self.i,
455            n: self.n,
456            sent_reliable_msgs: self.sent_reliable_msgs,
457            _ph: PhantomData,
458        })
459    }
460}