1use core::marker::PhantomData;
59
60use alloc::collections::btree_map::BTreeMap;
61use digest::Digest;
62
63use crate::{
64 Mpc, MpcExecution, Outgoing, ProtocolMsg, RoundMsg,
65 round::{RoundInfo, RoundStore, RoundStoreExt},
66};
67
68mod error;
69mod store;
70
71pub use self::error::{CompleteRoundError, EchoError, Error};
72
73pub enum Msg<D: Digest, M> {
75 Echo {
77 round: u16,
84 hash: digest::Output<D>,
86 },
87 Main(M),
89}
90
91mod sub_msg {
95 pub struct EchoMsg<D: digest::Digest, R> {
96 pub hash: digest::Output<D>,
97 pub _round: core::marker::PhantomData<R>,
98 }
99 #[derive(Debug, Clone)]
100 pub struct Main<M>(pub M);
101
102 impl<D: digest::Digest, R> Clone for EchoMsg<D, R> {
103 fn clone(&self) -> Self {
104 Self {
105 hash: self.hash.clone(),
106 _round: core::marker::PhantomData,
107 }
108 }
109 }
110}
111
112impl<D: Digest, M: Clone> Clone for Msg<D, M> {
115 fn clone(&self) -> Self {
116 match self {
117 Self::Echo { round, hash } => Self::Echo {
118 round: *round,
119 hash: hash.clone(),
120 },
121 Self::Main(msg) => Self::Main(msg.clone()),
122 }
123 }
124}
125
126impl<D: Digest, M: PartialEq> PartialEq for Msg<D, M> {
127 fn eq(&self, other: &Self) -> bool {
128 match self {
129 Self::Echo { round, hash } => {
130 matches!(other, Self::Echo { round: r2, hash: h2 } if round == r2 && hash == h2)
131 }
132 Self::Main(msg) => matches!(other, Self::Main(m2) if msg == m2),
133 }
134 }
135}
136
137impl<D: Digest, M: PartialEq> Eq for Msg<D, M> {}
138
139impl<D: Digest, M: core::fmt::Debug> core::fmt::Debug for Msg<D, M> {
140 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
141 match self {
142 Self::Echo { round, hash } => f
143 .debug_struct("Msg::Echo")
144 .field("round", round)
145 .field("hash", hash)
146 .finish(),
147 Self::Main(msg) => f.debug_tuple("Msg::Main").field(msg).finish(),
148 }
149 }
150}
151
152impl<D: Digest, M: ProtocolMsg> ProtocolMsg for Msg<D, M> {
153 fn round(&self) -> u16 {
154 match self {
155 Self::Echo { round, .. } => 2 * round + 1,
156 Self::Main(m) => 2 * m.round(),
157 }
158 }
159}
160
161impl<D: Digest, M: ProtocolMsg, R> RoundMsg<sub_msg::EchoMsg<D, R>> for Msg<D, M>
162where
163 M: RoundMsg<R>,
164{
165 const ROUND: u16 = 2 * M::ROUND + 1;
166 fn to_protocol_msg(round_msg: sub_msg::EchoMsg<D, R>) -> Self {
167 Self::Echo {
168 round: M::ROUND,
169 hash: round_msg.hash,
170 }
171 }
172 fn from_protocol_msg(protocol_msg: Self) -> Result<sub_msg::EchoMsg<D, R>, Self> {
173 match protocol_msg {
174 Self::Echo { round, hash } if round == M::ROUND => Ok(sub_msg::EchoMsg {
175 hash,
176 _round: PhantomData,
177 }),
178 _ => Err(protocol_msg),
179 }
180 }
181}
182
183impl<D: Digest, ProtoM, RoundM> RoundMsg<sub_msg::Main<RoundM>> for Msg<D, ProtoM>
184where
185 ProtoM: ProtocolMsg + RoundMsg<RoundM>,
186{
187 const ROUND: u16 = 2 * <ProtoM as RoundMsg<RoundM>>::ROUND;
188 fn to_protocol_msg(round_msg: sub_msg::Main<RoundM>) -> Self {
189 Self::Main(ProtoM::to_protocol_msg(round_msg.0))
190 }
191 fn from_protocol_msg(protocol_msg: Self) -> Result<sub_msg::Main<RoundM>, Self> {
192 if let Self::Main(msg) = protocol_msg {
193 ProtoM::from_protocol_msg(msg)
194 .map(sub_msg::Main)
195 .map_err(|m| Self::Main(m))
196 } else {
197 Err(protocol_msg)
198 }
199 }
200}
201
202pub fn wrap<D, M, MainMsg>(party: M, i: u16, n: u16) -> WithEchoBroadcast<D, M, MainMsg>
204where
205 D: Digest,
206 M: Mpc<Msg = Msg<D, MainMsg>>,
207 MainMsg: udigest::Digestable,
208{
209 WithEchoBroadcast {
210 party,
211 i,
212 n,
213 sent_reliable_msgs: Default::default(),
214 _ph: PhantomData,
215 }
216}
217
218pub struct WithEchoBroadcast<D: Digest, M, Msg> {
220 party: M,
221 i: u16,
222 n: u16,
223 sent_reliable_msgs: BTreeMap<u16, Option<Msg>>,
224 _ph: PhantomData<D>,
225}
226
227impl<D: Digest, M, Msg> WithEchoBroadcast<D, M, Msg> {
228 fn map_party<P>(self, f: impl FnOnce(M) -> P) -> WithEchoBroadcast<D, P, Msg> {
229 let party = f(self.party);
230 WithEchoBroadcast {
231 party,
232 i: self.i,
233 n: self.n,
234 sent_reliable_msgs: self.sent_reliable_msgs,
235 _ph: PhantomData,
236 }
237 }
238}
239
240impl<D, M, MainMsg> Mpc for WithEchoBroadcast<D, M, MainMsg>
241where
242 D: Digest + 'static,
243 M: Mpc<Msg = Msg<D, MainMsg>>,
244 MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static,
245{
246 type Msg = MainMsg;
247
248 type Exec = WithEchoBroadcast<D, M::Exec, MainMsg>;
249
250 type SendErr = error::Error<M::SendErr>;
251
252 fn add_round<R>(&mut self, round: R) -> <Self::Exec as MpcExecution>::Round<R>
253 where
254 R: RoundStore,
255 Self::Msg: RoundMsg<R::Msg>,
256 {
257 let reliable_broadcast_required = round
258 .read_prop::<crate::round::props::RequiresReliableBroadcast>()
259 .map(|x| x.0);
260 if reliable_broadcast_required == Some(true) {
261 let (main_round, echo_round) = store::new::<D, MainMsg, _>(self.i, self.n, round);
262 let main_round = self.party.add_round(store::WithMainMsg(main_round));
263 let echo_round = self.party.add_round(store::WithEchoError::from(echo_round));
264
265 self.sent_reliable_msgs.insert(Self::Msg::ROUND, None);
266
267 Round(Inner::WithReliabilityCheck {
268 main_round,
269 echo_round,
270 })
271 } else {
272 let round = self
273 .party
274 .add_round(store::WithError(store::WithMainMsg(round)));
275 Round(Inner::Unmodified(round))
276 }
277 }
278
279 fn finish_setup(self) -> Self::Exec {
280 self.map_party(|p| p.finish_setup())
281 }
282}
283
284impl<D, M, MainMsg> WithEchoBroadcast<D, M, MainMsg>
285where
286 D: Digest,
287 MainMsg: ProtocolMsg + Clone,
288{
289 fn on_send(&mut self, outgoing: &mut Outgoing<MainMsg>) -> Result<(), error::EchoError> {
290 if let Some(slot) = self.sent_reliable_msgs.get_mut(&outgoing.msg.round()) {
291 if !outgoing.recipient.is_reliable_broadcast() {
292 return Err(error::Reason::SentNonReliableMsgInReliableRound {
294 dest: outgoing.recipient,
295 round: outgoing.msg.round(),
296 }
297 .into());
298 }
299 outgoing.recipient = crate::MessageDestination::AllParties { reliable: false };
302 if slot.is_some() {
303 return Err(error::Reason::SendTwice.into());
304 }
305 *slot = Some(outgoing.msg.clone())
306 } else if outgoing.recipient.is_reliable_broadcast() {
307 return Err(error::Reason::SentReliableMsgInNonReliableRound {
309 round: outgoing.msg.round(),
310 }
311 .into());
312 }
313
314 Ok(())
315 }
316}
317
318impl<D, M, MainMsg> MpcExecution for WithEchoBroadcast<D, M, MainMsg>
319where
320 D: Digest + 'static,
321 M: MpcExecution<Msg = Msg<D, MainMsg>>,
322 MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static,
323{
324 type Round<R: RoundInfo> = Round<M, D, MainMsg, R>;
325 type Msg = MainMsg;
326 type CompleteRoundErr<E> =
327 error::CompleteRoundError<M::CompleteRoundErr<error::Error<E>>, M::SendErr>;
328 type SendErr = error::Error<M::SendErr>;
329 type SendMany = WithEchoBroadcast<D, M::SendMany, MainMsg>;
330
331 async fn complete<R>(
332 &mut self,
333 round: Self::Round<R>,
334 ) -> Result<R::Output, Self::CompleteRoundErr<R::Error>>
335 where
336 R: RoundInfo,
337 Self::Msg: RoundMsg<R::Msg>,
338 {
339 match round.0 {
340 Inner::Unmodified(round) => {
341 let output = self
343 .party
344 .complete(round)
345 .await
346 .map_err(error::CompleteRoundError::CompleteRound)?;
347 Ok(output)
348 }
349 Inner::WithReliabilityCheck {
350 main_round,
351 echo_round,
352 } => {
353 let main_output = self
355 .party
356 .complete(main_round)
357 .await
358 .map_err(error::CompleteRoundError::CompleteRound)?;
359 let sent_msg =
361 if let Some(Some(msg)) = self.sent_reliable_msgs.remove(&Self::Msg::ROUND) {
362 let msg: R::Msg = Self::Msg::from_protocol_msg(msg)
363 .map_err(|_| error::Reason::SentMsgFromProto)?;
364 Some(msg)
365 } else {
366 None
367 };
368 let (main_output, hash) = main_output.with_my_msg(sent_msg)?;
370 self.party
371 .send_to_all(Msg::Echo {
372 round: Self::Msg::ROUND,
373 hash,
374 })
375 .await
376 .map_err(error::CompleteRoundError::Send)?;
377 let echoes = self
379 .party
380 .complete(echo_round)
381 .await
382 .map_err(error::CompleteRoundError::CompleteRound)?;
383 let main_output = main_output.with_echo_output(echoes)?;
385
386 Ok(main_output)
387 }
388 }
389 }
390
391 async fn send(&mut self, mut outgoing: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
392 self.on_send(&mut outgoing)?;
393
394 self.party
395 .send(outgoing.map(Msg::Main))
396 .await
397 .map_err(error::Error::Main)
398 }
399
400 fn send_many(self) -> Self::SendMany {
401 self.map_party(|p| p.send_many())
402 }
403
404 async fn yield_now(&self) {
405 self.party.yield_now().await
406 }
407}
408
409pub struct Round<M, D, ProtoMsg, R>(Inner<M, D, ProtoMsg, R>)
411where
412 M: MpcExecution,
413 D: Digest + 'static,
414 ProtoMsg: 'static,
415 R: RoundInfo;
416
417enum Inner<M, D, ProtoMsg, R>
418where
419 M: MpcExecution,
420 D: Digest + 'static,
421 ProtoMsg: 'static,
422 R: RoundInfo,
423{
424 Unmodified(M::Round<store::WithError<store::WithMainMsg<R>>>),
426 WithReliabilityCheck {
427 main_round: M::Round<store::WithMainMsg<store::MainRound<D, ProtoMsg, R>>>,
428 echo_round: M::Round<store::WithEchoError<store::EchoRound<D, R>, R::Error>>,
429 },
430}
431
432impl<D, M, MainMsg> crate::mpc::SendMany for WithEchoBroadcast<D, M, MainMsg>
433where
434 D: Digest + 'static,
435 M: crate::mpc::SendMany<Msg = Msg<D, MainMsg>>,
436 MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static,
437{
438 type Exec = WithEchoBroadcast<D, M::Exec, MainMsg>;
439 type Msg = MainMsg;
440 type SendErr = error::Error<M::SendErr>;
441
442 async fn send(&mut self, mut outgoing: Outgoing<Self::Msg>) -> Result<(), Self::SendErr> {
443 self.on_send(&mut outgoing)?;
444 self.party
445 .send(outgoing.map(Msg::Main))
446 .await
447 .map_err(error::Error::Main)
448 }
449
450 async fn flush(self) -> Result<Self::Exec, Self::SendErr> {
451 let party = self.party.flush().await.map_err(error::Error::Main)?;
452 Ok(WithEchoBroadcast {
453 party,
454 i: self.i,
455 n: self.n,
456 sent_reliable_msgs: self.sent_reliable_msgs,
457 _ph: PhantomData,
458 })
459 }
460}