round_based/mpc/party/
mod.rs

1//! Provides [`MpcParty`], default engine for MPC protocol execution that implements [`Mpc`] and [`MpcExecution`] traits
2
3use futures_util::{Sink, SinkExt, Stream, StreamExt};
4
5use crate::{
6    Incoming, Outgoing,
7    round::{RoundInfo, RoundStore},
8};
9
10use super::{Mpc, MpcExecution, ProtocolMsg, RoundMsg};
11
12mod router;
13pub mod runtime;
14
15pub use self::router::{Round, errors::RouterError};
16#[doc(no_inline)]
17pub use self::runtime::AsyncRuntime;
18
19/// MPC engine, carries out the protocol
20///
21/// Can be constructed via [`MpcParty::connected`] or [`MpcParty::connected_halves`], which wraps
22/// a channel of incoming and outgoing messages, and implements additional logic on top of this
23/// to facilitate the MPC protocol execution, such as routing incoming messages between round
24/// stores.
25///
26/// Implements [`Mpc`] and [`MpcExecution`].
27pub struct MpcParty<M, D, R = runtime::DefaultRuntime, const SETUP_COMPLETE: bool = false> {
28    router: router::RoundsRouter<M>,
29    io: D,
30    runtime: R,
31}
32
33impl<M, D, E> MpcParty<M, D>
34where
35    M: ProtocolMsg + 'static,
36    D: Stream<Item = Result<Incoming<M>, E>> + Unpin,
37    D: Sink<Outgoing<M>, Error = E> + Unpin,
38{
39    /// Constructs [`MpcParty`]
40    pub fn connected(delivery: D) -> Self {
41        Self {
42            router: router::RoundsRouter::new(),
43            io: delivery,
44            runtime: runtime::DefaultRuntime::default(),
45        }
46    }
47}
48
49impl<M, In, Out, E> MpcParty<M, Halves<In, Out>>
50where
51    M: ProtocolMsg + 'static,
52    In: Stream<Item = Result<Incoming<M>, E>> + Unpin,
53    Out: Sink<Outgoing<M>, Error = E> + Unpin,
54{
55    /// Constructs [`MpcParty`]
56    pub fn connected_halves(incomings: In, outgoings: Out) -> Self {
57        Self::connected(Halves::new(incomings, outgoings))
58    }
59}
60
61impl<M, D, X> MpcParty<M, D, X> {
62    /// Changes which async runtime to use
63    pub fn with_runtime<R>(self, runtime: R) -> MpcParty<M, D, R> {
64        MpcParty {
65            router: self.router,
66            io: self.io,
67            runtime,
68        }
69    }
70}
71
72impl<M, D, E, AsyncR> Mpc for MpcParty<M, D, AsyncR>
73where
74    M: ProtocolMsg + 'static,
75    D: Stream<Item = Result<Incoming<M>, E>> + Unpin,
76    D: Sink<Outgoing<M>, Error = E> + Unpin,
77    AsyncR: runtime::AsyncRuntime,
78{
79    type Msg = M;
80
81    type Exec = MpcParty<M, D, AsyncR, true>;
82
83    type SendErr = E;
84
85    fn add_round<R>(&mut self, round: R) -> <Self::Exec as MpcExecution>::Round<R>
86    where
87        R: RoundStore,
88        Self::Msg: RoundMsg<R::Msg>,
89    {
90        self.router.add_round(round)
91    }
92
93    fn finish_setup(self) -> Self::Exec {
94        MpcParty {
95            router: self.router,
96            io: self.io,
97            runtime: self.runtime,
98        }
99    }
100}
101
102impl<M, D, IoErr, AsyncR> MpcExecution for MpcParty<M, D, AsyncR, true>
103where
104    M: ProtocolMsg + 'static,
105    D: Stream<Item = Result<Incoming<M>, IoErr>> + Unpin,
106    D: Sink<Outgoing<M>, Error = IoErr> + Unpin,
107    AsyncR: runtime::AsyncRuntime,
108{
109    type Round<R: RoundInfo> = router::Round<R>;
110    type Msg = M;
111    type CompleteRoundErr<E> = CompleteRoundError<E, IoErr>;
112    type SendErr = IoErr;
113    type SendMany = SendMany<M, D, AsyncR>;
114
115    async fn complete<R>(
116        &mut self,
117        mut round: Self::Round<R>,
118    ) -> Result<R::Output, Self::CompleteRoundErr<R::Error>>
119    where
120        R: RoundInfo,
121        Self::Msg: RoundMsg<R::Msg>,
122    {
123        // Check if round is already completed
124        round = match self.router.complete_round(round) {
125            Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})),
126            Err(w) => w,
127        };
128
129        // Round is not completed - we need more messages
130        loop {
131            let incoming = self
132                .io
133                .next()
134                .await
135                .ok_or(CompleteRoundError::UnexpectedEof)?
136                .map_err(CompleteRoundError::Io)?;
137            self.router.received_msg(incoming)?;
138
139            // Check if round was just completed
140            round = match self.router.complete_round(round) {
141                Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})),
142                Err(w) => w,
143            };
144        }
145    }
146
147    async fn send(&mut self, msg: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
148        self.io.send(msg).await
149    }
150
151    fn send_many(self) -> Self::SendMany {
152        SendMany { party: self }
153    }
154
155    async fn yield_now(&self) {
156        self.runtime.yield_now().await
157    }
158}
159
160/// Returned by [`MpcParty::send_many()`]
161pub struct SendMany<M, D, R> {
162    party: MpcParty<M, D, R, true>,
163}
164
165impl<M, D, E, AsyncR> super::SendMany for SendMany<M, D, AsyncR>
166where
167    M: ProtocolMsg + 'static,
168    D: Stream<Item = Result<Incoming<M>, E>> + Unpin,
169    D: Sink<Outgoing<M>, Error = E> + Unpin,
170    AsyncR: runtime::AsyncRuntime,
171{
172    type Exec = MpcParty<M, D, AsyncR, true>;
173    type Msg = <MpcParty<M, D, AsyncR> as Mpc>::Msg;
174    type SendErr = <MpcParty<M, D, AsyncR> as Mpc>::SendErr;
175
176    async fn send(&mut self, msg: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
177        self.party.io.feed(msg).await
178    }
179
180    async fn flush(mut self) -> Result<Self::Exec, Self::SendErr> {
181        self.party.io.flush().await?;
182        Ok(self.party)
183    }
184}
185
186pin_project_lite::pin_project! {
187    /// Merges a stream and a sink into one structure that implements both [`Stream`] and [`Sink`]
188    pub struct Halves<In, Out> {
189        #[pin]
190        incomings: In,
191        #[pin]
192        outgoings: Out,
193    }
194}
195
196impl<In, Out> Halves<In, Out> {
197    /// Constructs `Halves`
198    pub fn new(incomings: In, outgoings: Out) -> Self {
199        Self {
200            incomings,
201            outgoings,
202        }
203    }
204
205    /// Deconstructs back into halves
206    pub fn into_inner(self) -> (In, Out) {
207        (self.incomings, self.outgoings)
208    }
209}
210
211impl<In, Out, M, E> Stream for Halves<In, Out>
212where
213    In: Stream<Item = Result<M, E>>,
214{
215    type Item = Result<M, E>;
216
217    fn poll_next(
218        self: core::pin::Pin<&mut Self>,
219        cx: &mut core::task::Context<'_>,
220    ) -> core::task::Poll<Option<Self::Item>> {
221        let this = self.project();
222        this.incomings.poll_next(cx)
223    }
224}
225
226impl<In, Out, M, E> Sink<M> for Halves<In, Out>
227where
228    Out: Sink<M, Error = E>,
229{
230    type Error = E;
231
232    fn poll_ready(
233        self: core::pin::Pin<&mut Self>,
234        cx: &mut core::task::Context<'_>,
235    ) -> core::task::Poll<Result<(), Self::Error>> {
236        let this = self.project();
237        this.outgoings.poll_ready(cx)
238    }
239    fn start_send(self: core::pin::Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
240        let this = self.project();
241        this.outgoings.start_send(item)
242    }
243    fn poll_flush(
244        self: core::pin::Pin<&mut Self>,
245        cx: &mut core::task::Context<'_>,
246    ) -> core::task::Poll<Result<(), Self::Error>> {
247        let this = self.project();
248        this.outgoings.poll_flush(cx)
249    }
250    fn poll_close(
251        self: core::pin::Pin<&mut Self>,
252        cx: &mut core::task::Context<'_>,
253    ) -> core::task::Poll<Result<(), Self::Error>> {
254        let this = self.project();
255        this.outgoings.poll_close(cx)
256    }
257}
258
259/// Error returned by [`MpcParty::complete`]
260///
261/// May indicate malicious behavior (e.g. adversary sent a message that aborts protocol execution)
262/// or some misconfiguration of the protocol network (e.g. received a message from the round that
263/// was not registered via [`Mpc::add_round`]).
264#[derive(Debug, thiserror::Error)]
265pub enum CompleteRoundError<ProcessErr, IoErr> {
266    /// [`RoundStore`] returned an error
267    ///
268    /// Refer to this rounds store documentation to understand why it could fail
269    #[error(transparent)]
270    ProcessMsg(ProcessErr),
271
272    /// Router error
273    ///
274    /// Indicates that for some reason router was not able to process a message. This can be the case of:
275    /// - Router API misuse \
276    ///   E.g. when received a message from the round that was not registered in the router
277    /// - Improper [`RoundStore`] implementation \
278    ///   Indicates that round store is not properly implemented and contains a flaw. \
279    ///   For instance, this error is returned when round store indicates that it doesn't need
280    ///   any more messages ([`RoundStore::wants_more`]
281    ///   returns `false`), but then it didn't output anything ([`RoundStore::output`]
282    ///   returns `Err(_)`)
283    /// - Bug in the router
284    ///
285    /// This error is always related to some implementation flaw or bug: either in the code that uses
286    /// the router, or in the round store implementation, or in the router itself. When implementation
287    /// is correct, this error never appears. Thus, it should not be possible for the adversary to "make
288    /// this error happen."
289    Router(router::errors::RouterError),
290
291    /// Receiving the next message resulted into I/O error
292    Io(IoErr),
293    /// Channel of incoming messages was closed before protocol completion
294    UnexpectedEof,
295}
296
297impl<ProcessErr, IoErr> CompleteRoundError<ProcessErr, IoErr> {
298    /// Maps I/O error
299    pub fn map_io_err<E>(self, f: impl FnOnce(IoErr) -> E) -> CompleteRoundError<ProcessErr, E> {
300        match self {
301            CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(e),
302            CompleteRoundError::Router(e) => CompleteRoundError::Router(e),
303            CompleteRoundError::Io(e) => CompleteRoundError::Io(f(e)),
304            CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof,
305        }
306    }
307    /// Maps [`CompleteRoundError::ProcessMsg`]
308    pub fn map_process_err<E>(
309        self,
310        f: impl FnOnce(ProcessErr) -> E,
311    ) -> CompleteRoundError<E, IoErr> {
312        match self {
313            CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(f(e)),
314            CompleteRoundError::Router(e) => CompleteRoundError::Router(e),
315            CompleteRoundError::Io(e) => CompleteRoundError::Io(e),
316            CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof,
317        }
318    }
319}