1use futures_util::{Sink, SinkExt, Stream, StreamExt};
4
5use crate::{
6 Incoming, Outgoing,
7 round::{RoundInfo, RoundStore},
8};
9
10use super::{Mpc, MpcExecution, ProtocolMsg, RoundMsg};
11
12mod router;
13pub mod runtime;
14
15pub use self::router::{Round, errors::RouterError};
16#[doc(no_inline)]
17pub use self::runtime::AsyncRuntime;
18
/// A party of an MPC protocol that drives rounds over a delivery channel `D`.
///
/// The const parameter `SETUP_COMPLETE` encodes the party's phase:
/// - `false` (the default): setup phase. The party implements [`Mpc`]: rounds are registered
///   with `add_round`, and `finish_setup` moves the party into the execution phase.
/// - `true`: execution phase. The party implements [`MpcExecution`]: it completes rounds and
///   sends messages.
///
/// `R` is the async runtime (defaults to [`runtime::DefaultRuntime`]); it is used to implement
/// `yield_now`.
pub struct MpcParty<M, D, R = runtime::DefaultRuntime, const SETUP_COMPLETE: bool = false> {
    /// Rounds router: rounds are registered via `add_round`, received messages are passed to it
    /// via `received_msg`, and round completion is queried via `complete_round`.
    router: router::RoundsRouter<M>,
    /// Delivery channel. The constructors and trait impls require it to be a `Stream` of
    /// `Result<Incoming<M>, _>` and a `Sink<Outgoing<M>>`.
    io: D,
    /// Async runtime, used by `yield_now`.
    runtime: R,
}
32
33impl<M, D, E> MpcParty<M, D>
34where
35 M: ProtocolMsg + 'static,
36 D: Stream<Item = Result<Incoming<M>, E>> + Unpin,
37 D: Sink<Outgoing<M>, Error = E> + Unpin,
38{
39 pub fn connected(delivery: D) -> Self {
41 Self {
42 router: router::RoundsRouter::new(),
43 io: delivery,
44 runtime: runtime::DefaultRuntime::default(),
45 }
46 }
47}
48
49impl<M, In, Out, E> MpcParty<M, Halves<In, Out>>
50where
51 M: ProtocolMsg + 'static,
52 In: Stream<Item = Result<Incoming<M>, E>> + Unpin,
53 Out: Sink<Outgoing<M>, Error = E> + Unpin,
54{
55 pub fn connected_halves(incomings: In, outgoings: Out) -> Self {
57 Self::connected(Halves::new(incomings, outgoings))
58 }
59}
60
61impl<M, D, X> MpcParty<M, D, X> {
62 pub fn with_runtime<R>(self, runtime: R) -> MpcParty<M, D, R> {
64 MpcParty {
65 router: self.router,
66 io: self.io,
67 runtime,
68 }
69 }
70}
71
72impl<M, D, E, AsyncR> Mpc for MpcParty<M, D, AsyncR>
73where
74 M: ProtocolMsg + 'static,
75 D: Stream<Item = Result<Incoming<M>, E>> + Unpin,
76 D: Sink<Outgoing<M>, Error = E> + Unpin,
77 AsyncR: runtime::AsyncRuntime,
78{
79 type Msg = M;
80
81 type Exec = MpcParty<M, D, AsyncR, true>;
82
83 type SendErr = E;
84
85 fn add_round<R>(&mut self, round: R) -> <Self::Exec as MpcExecution>::Round<R>
86 where
87 R: RoundStore,
88 Self::Msg: RoundMsg<R::Msg>,
89 {
90 self.router.add_round(round)
91 }
92
93 fn finish_setup(self) -> Self::Exec {
94 MpcParty {
95 router: self.router,
96 io: self.io,
97 runtime: self.runtime,
98 }
99 }
100}
101
102impl<M, D, IoErr, AsyncR> MpcExecution for MpcParty<M, D, AsyncR, true>
103where
104 M: ProtocolMsg + 'static,
105 D: Stream<Item = Result<Incoming<M>, IoErr>> + Unpin,
106 D: Sink<Outgoing<M>, Error = IoErr> + Unpin,
107 AsyncR: runtime::AsyncRuntime,
108{
109 type Round<R: RoundInfo> = router::Round<R>;
110 type Msg = M;
111 type CompleteRoundErr<E> = CompleteRoundError<E, IoErr>;
112 type SendErr = IoErr;
113 type SendMany = SendMany<M, D, AsyncR>;
114
115 async fn complete<R>(
116 &mut self,
117 mut round: Self::Round<R>,
118 ) -> Result<R::Output, Self::CompleteRoundErr<R::Error>>
119 where
120 R: RoundInfo,
121 Self::Msg: RoundMsg<R::Msg>,
122 {
123 round = match self.router.complete_round(round) {
125 Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})),
126 Err(w) => w,
127 };
128
129 loop {
131 let incoming = self
132 .io
133 .next()
134 .await
135 .ok_or(CompleteRoundError::UnexpectedEof)?
136 .map_err(CompleteRoundError::Io)?;
137 self.router.received_msg(incoming)?;
138
139 round = match self.router.complete_round(round) {
141 Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})),
142 Err(w) => w,
143 };
144 }
145 }
146
147 async fn send(&mut self, msg: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
148 self.io.send(msg).await
149 }
150
151 fn send_many(self) -> Self::SendMany {
152 SendMany { party: self }
153 }
154
155 async fn yield_now(&self) {
156 self.runtime.yield_now().await
157 }
158}
159
/// Batched sender produced by `MpcExecution::send_many` on an execution-phase [`MpcParty`].
///
/// Its `send` enqueues messages with `SinkExt::feed`, which does not flush. Its `flush` flushes
/// the delivery channel and hands the party back.
pub struct SendMany<M, D, R> {
    /// The execution-phase party whose delivery channel is used for sending.
    party: MpcParty<M, D, R, true>,
}
164
165impl<M, D, E, AsyncR> super::SendMany for SendMany<M, D, AsyncR>
166where
167 M: ProtocolMsg + 'static,
168 D: Stream<Item = Result<Incoming<M>, E>> + Unpin,
169 D: Sink<Outgoing<M>, Error = E> + Unpin,
170 AsyncR: runtime::AsyncRuntime,
171{
172 type Exec = MpcParty<M, D, AsyncR, true>;
173 type Msg = <MpcParty<M, D, AsyncR> as Mpc>::Msg;
174 type SendErr = <MpcParty<M, D, AsyncR> as Mpc>::SendErr;
175
176 async fn send(&mut self, msg: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
177 self.party.io.feed(msg).await
178 }
179
180 async fn flush(mut self) -> Result<Self::Exec, Self::SendErr> {
181 self.party.io.flush().await?;
182 Ok(self.party)
183 }
184}
185
pin_project_lite::pin_project! {
    /// Joins a separate incoming-message stream and outgoing-message sink into one value that
    /// implements both `Stream` (forwarded to `incomings`) and `Sink` (forwarded to `outgoings`).
    ///
    /// Used by [`MpcParty::connected_halves`].
    pub struct Halves<In, Out> {
        // Source of incoming messages; all `Stream` calls are forwarded here.
        #[pin]
        incomings: In,
        // Destination of outgoing messages; all `Sink` calls are forwarded here.
        #[pin]
        outgoings: Out,
    }
}
195
196impl<In, Out> Halves<In, Out> {
197 pub fn new(incomings: In, outgoings: Out) -> Self {
199 Self {
200 incomings,
201 outgoings,
202 }
203 }
204
205 pub fn into_inner(self) -> (In, Out) {
207 (self.incomings, self.outgoings)
208 }
209}
210
211impl<In, Out, M, E> Stream for Halves<In, Out>
212where
213 In: Stream<Item = Result<M, E>>,
214{
215 type Item = Result<M, E>;
216
217 fn poll_next(
218 self: core::pin::Pin<&mut Self>,
219 cx: &mut core::task::Context<'_>,
220 ) -> core::task::Poll<Option<Self::Item>> {
221 let this = self.project();
222 this.incomings.poll_next(cx)
223 }
224}
225
226impl<In, Out, M, E> Sink<M> for Halves<In, Out>
227where
228 Out: Sink<M, Error = E>,
229{
230 type Error = E;
231
232 fn poll_ready(
233 self: core::pin::Pin<&mut Self>,
234 cx: &mut core::task::Context<'_>,
235 ) -> core::task::Poll<Result<(), Self::Error>> {
236 let this = self.project();
237 this.outgoings.poll_ready(cx)
238 }
239 fn start_send(self: core::pin::Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
240 let this = self.project();
241 this.outgoings.start_send(item)
242 }
243 fn poll_flush(
244 self: core::pin::Pin<&mut Self>,
245 cx: &mut core::task::Context<'_>,
246 ) -> core::task::Poll<Result<(), Self::Error>> {
247 let this = self.project();
248 this.outgoings.poll_flush(cx)
249 }
250 fn poll_close(
251 self: core::pin::Pin<&mut Self>,
252 cx: &mut core::task::Context<'_>,
253 ) -> core::task::Poll<Result<(), Self::Error>> {
254 let this = self.project();
255 this.outgoings.poll_close(cx)
256 }
257}
258
259#[derive(Debug, thiserror::Error)]
265pub enum CompleteRoundError<ProcessErr, IoErr> {
266 #[error(transparent)]
270 ProcessMsg(ProcessErr),
271
272 Router(router::errors::RouterError),
290
291 Io(IoErr),
293 UnexpectedEof,
295}
296
297impl<ProcessErr, IoErr> CompleteRoundError<ProcessErr, IoErr> {
298 pub fn map_io_err<E>(self, f: impl FnOnce(IoErr) -> E) -> CompleteRoundError<ProcessErr, E> {
300 match self {
301 CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(e),
302 CompleteRoundError::Router(e) => CompleteRoundError::Router(e),
303 CompleteRoundError::Io(e) => CompleteRoundError::Io(f(e)),
304 CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof,
305 }
306 }
307 pub fn map_process_err<E>(
309 self,
310 f: impl FnOnce(ProcessErr) -> E,
311 ) -> CompleteRoundError<E, IoErr> {
312 match self {
313 CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(f(e)),
314 CompleteRoundError::Router(e) => CompleteRoundError::Router(e),
315 CompleteRoundError::Io(e) => CompleteRoundError::Io(e),
316 CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof,
317 }
318 }
319}