1use alloc::{boxed::Box, collections::BTreeMap};
7use core::{any::Any, convert::Infallible, mem};
8
9use phantom_type::PhantomType;
10use tracing::{error, trace_span, warn};
11
12use crate::{
13 Incoming, ProtocolMsg, RoundMsg,
14 round::{RoundInfo, RoundStore},
15};
16
17pub struct RoundsRouter<M> {
19 rounds: BTreeMap<u16, Option<Box<dyn ProcessRoundMessage<Msg = M>>>>,
20}
21
22impl<M> RoundsRouter<M>
23where
24 M: ProtocolMsg + 'static,
25{
26 pub fn new() -> Self {
27 Self {
28 rounds: Default::default(),
29 }
30 }
31
32 pub fn add_round<R>(&mut self, message_store: R) -> Round<R>
37 where
38 R: RoundStore,
39 M: RoundMsg<R::Msg>,
40 {
41 let overridden_round = self.rounds.insert(
42 M::ROUND,
43 Some(Box::new(ProcessRoundMessageImpl::new(message_store))),
44 );
45 if overridden_round.is_some() {
46 panic!("round {} is overridden", M::ROUND);
47 }
48 Round {
49 _ph: PhantomType::new(),
50 }
51 }
52
53 pub fn received_msg(&mut self, incoming: Incoming<M>) -> Result<(), errors::UnregisteredRound> {
54 let msg_round_n = incoming.msg.round();
55 let span = trace_span!(
56 "Round::received_msg",
57 round = %msg_round_n,
58 sender = %incoming.sender,
59 ty = ?incoming.msg_type
60 );
61 let _guard = span.enter();
62
63 let message_round = match self.rounds.get_mut(&msg_round_n) {
64 Some(Some(round)) => round,
65 Some(None) => {
66 warn!("got message for the round that was already completed, ignoring it");
67 return Ok(());
68 }
69 None => {
70 return Err(errors::UnregisteredRound {
71 n: msg_round_n,
72 witness_provided: false,
73 });
74 }
75 };
76 if message_round.needs_more_messages().no() {
77 warn!("received message for the round that was already completed, ignoring it");
78 return Ok(());
79 }
80 message_round.process_message(incoming);
81 Ok(())
82 }
83
84 #[allow(clippy::type_complexity)]
85 pub fn complete_round<R>(
86 &mut self,
87 round: Round<R>,
88 ) -> Result<Result<R::Output, errors::CompleteRoundError<R::Error, Infallible>>, Round<R>>
89 where
90 R: RoundInfo,
91 M: RoundMsg<R::Msg>,
92 {
93 let message_round = match self.rounds.get_mut(&M::ROUND) {
94 Some(Some(round)) => round,
95 Some(None) => {
96 return Ok(Err(
97 errors::Bug::RoundGoneButWitnessExists { n: M::ROUND }.into()
98 ));
99 }
100 None => {
101 return Ok(Err(errors::UnregisteredRound {
102 n: M::ROUND,
103 witness_provided: true,
104 }
105 .into()));
106 }
107 };
108 if message_round.needs_more_messages().yes() {
109 return Err(round);
110 }
111 Ok(Self::retrieve_round_output::<R>(message_round))
112 }
113
114 fn retrieve_round_output<R>(
115 round: &mut Box<dyn ProcessRoundMessage<Msg = M>>,
116 ) -> Result<R::Output, errors::CompleteRoundError<R::Error, Infallible>>
117 where
118 R: RoundInfo,
119 {
120 match round.take_output() {
121 Ok(Ok(any)) => Ok(*any
122 .downcast::<R::Output>()
123 .or(Err(errors::Bug::MismatchedOutputType))?),
124 Ok(Err(any)) => Err(*any
125 .downcast::<errors::CompleteRoundError<R::Error, Infallible>>()
126 .or(Err(errors::Bug::MismatchedErrorType))?),
127 Err(err) => Err(errors::Bug::TakeRoundResult(err).into()),
128 }
129 }
130}
131
132pub struct Round<S> {
136 _ph: PhantomType<S>,
137}
138
139impl<S> core::fmt::Debug for Round<S> {
140 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
141 f.debug_struct("Round").finish_non_exhaustive()
142 }
143}
144
145trait ProcessRoundMessage {
146 type Msg;
147
148 fn process_message(&mut self, msg: Incoming<Self::Msg>);
153
154 fn needs_more_messages(&self) -> NeedsMoreMessages;
159
160 #[allow(clippy::type_complexity)]
169 fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError>;
170}
171
172#[derive(Debug, thiserror::Error)]
173enum TakeOutputError {
174 #[error("output is already taken")]
175 AlreadyTaken,
176 #[error("output is not ready yet, more messages are needed")]
177 NotReady,
178}
179
180enum ProcessRoundMessageImpl<S: RoundStore, M: ProtocolMsg + RoundMsg<S::Msg>> {
181 InProgress { store: S, _ph: PhantomType<fn(M)> },
182 Completed(Result<S::Output, errors::CompleteRoundError<S::Error, Infallible>>),
183 Gone,
184}
185
186impl<S: RoundStore, M: ProtocolMsg + RoundMsg<S::Msg>> ProcessRoundMessageImpl<S, M> {
187 pub fn new(store: S) -> Self {
188 if store.wants_more() {
189 Self::InProgress {
190 store,
191 _ph: Default::default(),
192 }
193 } else {
194 Self::Completed(
195 store
196 .output()
197 .map_err(|_| errors::ImproperRoundStore::StoreDidntOutput.into()),
198 )
199 }
200 }
201}
202
203impl<S, M> ProcessRoundMessageImpl<S, M>
204where
205 S: RoundStore,
206 M: ProtocolMsg + RoundMsg<S::Msg>,
207{
208 fn _process_message(
209 store: &mut S,
210 msg: Incoming<M>,
211 ) -> Result<(), errors::CompleteRoundError<S::Error, Infallible>> {
212 let msg = msg.try_map(M::from_protocol_msg).map_err(|msg| {
213 errors::Bug::MessageFromAnotherRound {
214 actual_number: msg.round(),
215 expected_round: M::ROUND,
216 }
217 })?;
218
219 store
220 .add_message(msg)
221 .map_err(errors::CompleteRoundError::ProcessMsg)?;
222 Ok(())
223 }
224}
225
226impl<S, M> ProcessRoundMessage for ProcessRoundMessageImpl<S, M>
227where
228 S: RoundStore,
229 M: ProtocolMsg + RoundMsg<S::Msg>,
230{
231 type Msg = M;
232
233 fn process_message(&mut self, msg: Incoming<Self::Msg>) {
234 let store = match self {
235 Self::InProgress { store, .. } => store,
236 _ => {
237 return;
238 }
239 };
240
241 match Self::_process_message(store, msg) {
242 Ok(()) => {
243 if store.wants_more() {
244 return;
245 }
246
247 let store = match mem::replace(self, Self::Gone) {
248 Self::InProgress { store, .. } => store,
249 _ => {
250 *self = Self::Completed(Err(errors::Bug::IncoherentState {
251 expected: "InProgress",
252 justification:
253 "we checked at beginning of the function that `state` is InProgress",
254 }.into()));
255 return;
256 }
257 };
258
259 match store.output() {
260 Ok(output) => *self = Self::Completed(Ok(output)),
261 Err(_err) => {
262 *self = Self::Completed(Err(
263 errors::ImproperRoundStore::StoreDidntOutput.into()
264 ))
265 }
266 }
267 }
268 Err(err) => {
269 *self = Self::Completed(Err(err));
270 }
271 }
272 }
273
274 fn needs_more_messages(&self) -> NeedsMoreMessages {
275 match self {
276 Self::InProgress { .. } => NeedsMoreMessages::Yes,
277 _ => NeedsMoreMessages::No,
278 }
279 }
280
281 fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError> {
282 match self {
283 Self::InProgress { .. } => return Err(TakeOutputError::NotReady),
284 Self::Gone => return Err(TakeOutputError::AlreadyTaken),
285 _ => (),
286 }
287 match mem::replace(self, Self::Gone) {
288 Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))),
289 Self::Completed(Err(err)) => Ok(Err(Box::new(err))),
290 _ => unreachable!("it's checked to be completed"),
291 }
292 }
293}
294
/// Whether a round is still waiting for incoming messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NeedsMoreMessages {
    /// The round still accepts messages.
    Yes,
    /// The round accepts no further messages.
    No,
}

impl NeedsMoreMessages {
    /// Returns `true` for [`NeedsMoreMessages::Yes`].
    pub fn yes(&self) -> bool {
        matches!(self, Self::Yes)
    }

    /// Returns `true` for [`NeedsMoreMessages::No`].
    pub fn no(&self) -> bool {
        matches!(self, Self::No)
    }
}
309
310pub mod errors {
312 pub use crate::mpc::party::CompleteRoundError;
313
314 use super::TakeOutputError;
315
316 #[derive(Debug, thiserror::Error)]
317 #[error("received a message for unregistered round")]
318 pub(in crate::mpc) struct UnregisteredRound {
319 pub n: u16,
320 pub(super) witness_provided: bool,
321 }
322
323 #[derive(Debug, thiserror::Error)]
327 #[error(transparent)]
328 pub struct RouterError(Reason);
329
330 #[derive(Debug, thiserror::Error)]
331 pub(super) enum Reason {
332 #[error("api misuse")]
339 ApiMisuse(#[source] ApiMisuse),
340 #[error("improper round store")]
347 ImproperRoundStore(#[source] ImproperRoundStore),
348 #[error("bug (please, open an issue)")]
350 Bug(#[source] Bug),
351 }
352
353 #[derive(Debug, thiserror::Error)]
354 pub(super) enum ApiMisuse {
355 #[error(transparent)]
356 UnregisteredRound(#[from] UnregisteredRound),
357 }
358
359 #[derive(Debug, thiserror::Error)]
360 pub(super) enum ImproperRoundStore {
361 #[error("store didn't output")]
365 StoreDidntOutput,
366 }
367
368 #[derive(Debug, thiserror::Error)]
369 pub(super) enum Bug {
370 #[error("round is gone, but witness exists")]
371 RoundGoneButWitnessExists { n: u16 },
372 #[error(
373 "message originates from another round: we process messages from round \
374 {expected_round}, got message from round {actual_number}"
375 )]
376 MessageFromAnotherRound {
377 expected_round: u16,
378 actual_number: u16,
379 },
380 #[error("state is incoherent, it's expected to be {expected}: {justification}")]
381 IncoherentState {
382 expected: &'static str,
383 justification: &'static str,
384 },
385 #[error("take round result")]
386 TakeRoundResult(#[source] TakeOutputError),
387 #[error("mismatched output type")]
388 MismatchedOutputType,
389 #[error("mismatched error type")]
390 MismatchedErrorType,
391 }
392
393 macro_rules! impl_round_complete_from {
394 ($(|$err:ident: $err_ty:ty| $err_fn:expr),+$(,)?) => {$(
395 impl<E, IoErr> From<$err_ty> for CompleteRoundError<E, IoErr> {
396 fn from($err: $err_ty) -> Self {
397 $err_fn
398 }
399 }
400 )+};
401 }
402
403 impl_round_complete_from! {
404 |err: ApiMisuse| CompleteRoundError::Router(RouterError(Reason::ApiMisuse(err))),
405 |err: ImproperRoundStore| CompleteRoundError::Router(RouterError(Reason::ImproperRoundStore(err))),
406 |err: Bug| CompleteRoundError::Router(RouterError(Reason::Bug(err))),
407 |err: UnregisteredRound| ApiMisuse::UnregisteredRound(err).into(),
408 }
409}
410
#[cfg(test)]
mod tests {
    // Round store that needs no messages and immediately outputs `()`.
    struct Store;

    // Single-round protocol used to exercise the router.
    #[derive(crate::ProtocolMsg)]
    #[protocol_msg(root = crate)]
    enum FakeProtocolMsg {
        R1(Msg1),
    }
    struct Msg1;

    impl crate::round::RoundInfo for Store {
        type Msg = Msg1;
        type Output = ();
        type Error = core::convert::Infallible;
    }
    impl crate::round::RoundStore for Store {
        fn add_message(&mut self, _msg: crate::Incoming<Self::Msg>) -> Result<(), Self::Error> {
            Ok(())
        }
        fn wants_more(&self) -> bool {
            false
        }
        fn output(self) -> Result<Self::Output, Self> {
            Ok(())
        }
    }

    #[test]
    fn complete_round_that_expects_no_messages() {
        let mut rounds = super::RoundsRouter::<FakeProtocolMsg>::new();
        let round1 = rounds.add_round(Store);

        rounds.complete_round(round1).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "is overridden")]
    fn registering_same_round_twice_panics() {
        let mut rounds = super::RoundsRouter::<FakeProtocolMsg>::new();
        let _round1 = rounds.add_round(Store);
        let _round1_again = rounds.add_round(Store);
    }

    #[test]
    fn completing_round_registered_in_another_router_is_an_error() {
        let mut rounds = super::RoundsRouter::<FakeProtocolMsg>::new();
        let mut other_rounds = super::RoundsRouter::<FakeProtocolMsg>::new();
        let foreign_round1 = other_rounds.add_round(Store);

        // The witness is accepted (no `Err(round)`), but the round result is an error.
        let result = rounds.complete_round(foreign_round1).unwrap();
        assert!(result.is_err());
    }
}