1use alloc::{boxed::Box, collections::BTreeMap};
51use core::{any::Any, convert::Infallible, mem};
52
53use futures_util::{Stream, StreamExt};
54use phantom_type::PhantomType;
55use tracing::{debug, error, trace, trace_span, warn, Span};
56
57use crate::Incoming;
58
59#[doc(inline)]
60pub use self::errors::CompleteRoundError;
61pub use self::store::*;
62
63pub mod simple_store;
64mod store;
65
/// Routes incoming protocol messages to the message stores of the rounds they belong to.
///
/// `M` is the protocol message type and `S` is the stream of incoming messages. The `S = ()`
/// default lets `RoundsRouter::<M>::builder()` be named without specifying a stream type; a
/// router with a real stream is produced by [`RoundsRouterBuilder::listen`].
pub struct RoundsRouter<M, S = ()> {
    /// Source of incoming messages; polled while a round is being completed.
    incomings: S,
    /// Registered rounds keyed by round number.
    ///
    /// A slot holds `Some(_)` until the round's output is taken; after that it is `None`.
    rounds: BTreeMap<u16, Option<Box<dyn ProcessRoundMessage<Msg = M> + Send>>>,
}
73
74impl<M: ProtocolMessage + 'static> RoundsRouter<M> {
75 pub fn builder() -> RoundsRouterBuilder<M> {
77 RoundsRouterBuilder::new()
78 }
79}
80
81impl<M, S, E> RoundsRouter<M, S>
82where
83 M: ProtocolMessage,
84 S: Stream<Item = Result<Incoming<M>, E>> + Unpin,
85 E: core::error::Error,
86{
87 #[inline(always)]
92 pub async fn complete<R>(
93 &mut self,
94 round: Round<R>,
95 ) -> Result<R::Output, CompleteRoundError<R::Error, E>>
96 where
97 R: MessagesStore,
98 M: RoundMessage<R::Msg>,
99 {
100 let round_number = <M as RoundMessage<R::Msg>>::ROUND;
101 let span = trace_span!("Round", n = round_number);
102 debug!(parent: &span, "pending round to complete");
103
104 match self.complete_with_span(&span, round).await {
105 Ok(output) => {
106 trace!(parent: &span, "round successfully completed");
107 Ok(output)
108 }
109 Err(err) => {
110 error!(parent: &span, %err, "round terminated with error");
111 Err(err)
112 }
113 }
114 }
115
116 async fn complete_with_span<R>(
117 &mut self,
118 span: &Span,
119 _round: Round<R>,
120 ) -> Result<R::Output, CompleteRoundError<R::Error, E>>
121 where
122 R: MessagesStore,
123 M: RoundMessage<R::Msg>,
124 {
125 let pending_round = <M as RoundMessage<R::Msg>>::ROUND;
126 if let Some(output) = self.retrieve_round_output_if_its_completed::<R>() {
127 return output;
128 }
129
130 loop {
131 let incoming = match self.incomings.next().await {
132 Some(Ok(msg)) => msg,
133 Some(Err(err)) => return Err(errors::IoError::Io(err).into()),
134 None => return Err(errors::IoError::UnexpectedEof.into()),
135 };
136 let message_round_n = incoming.msg.round();
137
138 let message_round = match self.rounds.get_mut(&message_round_n) {
139 Some(Some(round)) => round,
140 Some(None) => {
141 warn!(
142 parent: span,
143 n = message_round_n,
144 "got message for the round that was already completed, ignoring it"
145 );
146 continue;
147 }
148 None => {
149 return Err(
150 errors::RoundsMisuse::UnregisteredRound { n: message_round_n }.into(),
151 )
152 }
153 };
154 if message_round.needs_more_messages().no() {
155 warn!(
156 parent: span,
157 n = message_round_n,
158 "received message for the round that was already completed, ignoring it"
159 );
160 continue;
161 }
162 message_round.process_message(incoming);
163
164 if pending_round == message_round_n {
165 if let Some(output) = self.retrieve_round_output_if_its_completed::<R>() {
166 return output;
167 }
168 }
169 }
170 }
171
172 #[allow(clippy::type_complexity)]
173 fn retrieve_round_output_if_its_completed<R>(
174 &mut self,
175 ) -> Option<Result<R::Output, CompleteRoundError<R::Error, E>>>
176 where
177 R: MessagesStore,
178 M: RoundMessage<R::Msg>,
179 {
180 let round_number = <M as RoundMessage<R::Msg>>::ROUND;
181 let round_slot = match self
182 .rounds
183 .get_mut(&round_number)
184 .ok_or(errors::RoundsMisuse::UnregisteredRound { n: round_number })
185 {
186 Ok(slot) => slot,
187 Err(err) => return Some(Err(err.into())),
188 };
189 let round = match round_slot
190 .as_mut()
191 .ok_or(errors::RoundsMisuse::RoundAlreadyCompleted)
192 {
193 Ok(round) => round,
194 Err(err) => return Some(Err(err.into())),
195 };
196 if round.needs_more_messages().no() {
197 Some(Self::retrieve_round_output::<R>(round_slot))
198 } else {
199 None
200 }
201 }
202
203 fn retrieve_round_output<R>(
204 slot: &mut Option<Box<dyn ProcessRoundMessage<Msg = M> + Send>>,
205 ) -> Result<R::Output, CompleteRoundError<R::Error, E>>
206 where
207 R: MessagesStore,
208 M: RoundMessage<R::Msg>,
209 {
210 let mut round = slot.take().ok_or(errors::RoundsMisuse::UnregisteredRound {
211 n: <M as RoundMessage<R::Msg>>::ROUND,
212 })?;
213 match round.take_output() {
214 Ok(Ok(any)) => Ok(*any
215 .downcast::<R::Output>()
216 .or(Err(CompleteRoundError::from(
217 errors::Bug::MismatchedOutputType,
218 )))?),
219 Ok(Err(any)) => Err(any
220 .downcast::<CompleteRoundError<R::Error, Infallible>>()
221 .or(Err(CompleteRoundError::from(
222 errors::Bug::MismatchedErrorType,
223 )))?
224 .map_io_err(|e| match e {})),
225 Err(err) => Err(errors::Bug::TakeRoundResult(err).into()),
226 }
227 }
228}
229
/// Builder of [`RoundsRouter`]: collects the rounds, then attaches the stream of incoming
/// messages via [`listen`](RoundsRouterBuilder::listen).
pub struct RoundsRouterBuilder<M> {
    /// Rounds registered so far, keyed by round number.
    rounds: BTreeMap<u16, Option<Box<dyn ProcessRoundMessage<Msg = M> + Send>>>,
}
234
235impl<M> Default for RoundsRouterBuilder<M>
236where
237 M: ProtocolMessage + 'static,
238{
239 fn default() -> Self {
240 Self::new()
241 }
242}
243
244impl<M> RoundsRouterBuilder<M>
245where
246 M: ProtocolMessage + 'static,
247{
248 pub fn new() -> Self {
252 Self {
253 rounds: BTreeMap::new(),
254 }
255 }
256
257 pub fn add_round<R>(&mut self, message_store: R) -> Round<R>
262 where
263 R: MessagesStore + Send + 'static,
264 R::Output: Send,
265 R::Error: Send,
266 M: RoundMessage<R::Msg>,
267 {
268 let overridden_round = self.rounds.insert(
269 M::ROUND,
270 Some(Box::new(ProcessRoundMessageImpl::new(message_store))),
271 );
272 if overridden_round.is_some() {
273 panic!("round {} is overridden", M::ROUND);
274 }
275 Round {
276 _ph: PhantomType::new(),
277 }
278 }
279
280 pub fn listen<S, E>(self, incomings: S) -> RoundsRouter<M, S>
284 where
285 S: Stream<Item = Result<Incoming<M>, E>>,
286 {
287 RoundsRouter {
288 incomings,
289 rounds: self.rounds,
290 }
291 }
292}
293
/// Handle of a round registered with [`RoundsRouterBuilder::add_round`].
///
/// It stores no data, only the type `S` of the round's message store; passing it to
/// [`RoundsRouter::complete`] selects which round to complete.
pub struct Round<S: MessagesStore> {
    /// Zero-sized marker tying the handle to the store type.
    _ph: PhantomType<S>,
}
301
/// Type-erased interface to the message store of a single round.
///
/// The router keeps rounds with different store types in one map as
/// `Box<dyn ProcessRoundMessage>`, which is why results are handed out as `Box<dyn Any>`.
trait ProcessRoundMessage {
    /// Protocol message type accepted by `process_message`.
    type Msg;

    /// Hands an incoming protocol message over to the round.
    fn process_message(&mut self, msg: Incoming<Self::Msg>);

    /// Tells whether the round still waits for messages.
    fn needs_more_messages(&self) -> NeedsMoreMessages;

    /// Takes the result of the round.
    ///
    /// The outer `Err` means there is nothing to take (see [`TakeOutputError`]). Otherwise the
    /// inner `Ok`/`Err` carries the round's boxed output or boxed error respectively; the
    /// caller is expected to downcast it to the concrete type.
    #[allow(clippy::type_complexity)]
    fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError>;
}
328
/// Reason why `ProcessRoundMessage::take_output` had no result to give.
#[derive(Debug, thiserror::Error)]
enum TakeOutputError {
    /// The result was taken by an earlier call.
    #[error("output is already taken")]
    AlreadyTaken,
    /// The round is still in progress.
    #[error("output is not ready yet, more messages are needed")]
    NotReady,
}
336
/// State of a single round: wraps a [`MessagesStore`] `S` and implements
/// `ProcessRoundMessage` for protocol message `M`.
enum ProcessRoundMessageImpl<S: MessagesStore, M: ProtocolMessage + RoundMessage<S::Msg>> {
    /// The store still wants messages.
    InProgress { store: S, _ph: PhantomType<fn(M)> },
    /// The round is finished (successfully or not) and its result waits to be taken.
    Completed(Result<S::Output, CompleteRoundError<S::Error, Infallible>>),
    /// The result has been taken; also used as a temporary placeholder while the state is
    /// being replaced.
    Gone,
}
342
343impl<S: MessagesStore, M: ProtocolMessage + RoundMessage<S::Msg>> ProcessRoundMessageImpl<S, M> {
344 pub fn new(store: S) -> Self {
345 if store.wants_more() {
346 Self::InProgress {
347 store,
348 _ph: Default::default(),
349 }
350 } else {
351 Self::Completed(
352 store
353 .output()
354 .map_err(|_| errors::ImproperStoreImpl::StoreDidntOutput.into()),
355 )
356 }
357 }
358}
359
360impl<S, M> ProcessRoundMessageImpl<S, M>
361where
362 S: MessagesStore,
363 M: ProtocolMessage + RoundMessage<S::Msg>,
364{
365 fn _process_message(
366 store: &mut S,
367 msg: Incoming<M>,
368 ) -> Result<(), CompleteRoundError<S::Error, Infallible>> {
369 let msg = msg.try_map(M::from_protocol_message).map_err(|msg| {
370 errors::Bug::MessageFromAnotherRound {
371 actual_number: msg.round(),
372 expected_round: M::ROUND,
373 }
374 })?;
375
376 store
377 .add_message(msg)
378 .map_err(CompleteRoundError::ProcessMessage)?;
379 Ok(())
380 }
381}
382
383impl<S, M> ProcessRoundMessage for ProcessRoundMessageImpl<S, M>
384where
385 S: MessagesStore,
386 M: ProtocolMessage + RoundMessage<S::Msg>,
387{
388 type Msg = M;
389
390 fn process_message(&mut self, msg: Incoming<Self::Msg>) {
391 let store = match self {
392 Self::InProgress { store, .. } => store,
393 _ => {
394 return;
395 }
396 };
397
398 match Self::_process_message(store, msg) {
399 Ok(()) => {
400 if store.wants_more() {
401 return;
402 }
403
404 let store = match mem::replace(self, Self::Gone) {
405 Self::InProgress { store, .. } => store,
406 _ => {
407 *self = Self::Completed(Err(errors::Bug::IncoherentState {
408 expected: "InProgress",
409 justification:
410 "we checked at beginning of the function that `state` is InProgress",
411 }
412 .into()));
413 return;
414 }
415 };
416
417 match store.output() {
418 Ok(output) => *self = Self::Completed(Ok(output)),
419 Err(_err) => {
420 *self =
421 Self::Completed(Err(errors::ImproperStoreImpl::StoreDidntOutput.into()))
422 }
423 }
424 }
425 Err(err) => {
426 *self = Self::Completed(Err(err));
427 }
428 }
429 }
430
431 fn needs_more_messages(&self) -> NeedsMoreMessages {
432 match self {
433 Self::InProgress { .. } => NeedsMoreMessages::Yes,
434 _ => NeedsMoreMessages::No,
435 }
436 }
437
438 fn take_output(&mut self) -> Result<Result<Box<dyn Any>, Box<dyn Any>>, TakeOutputError> {
439 match self {
440 Self::InProgress { .. } => return Err(TakeOutputError::NotReady),
441 Self::Gone => return Err(TakeOutputError::AlreadyTaken),
442 _ => (),
443 }
444 match mem::replace(self, Self::Gone) {
445 Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))),
446 Self::Completed(Err(err)) => Ok(Err(Box::new(err))),
447 _ => unreachable!("it's checked to be completed"),
448 }
449 }
450}
451
/// Whether a round still waits for messages.
enum NeedsMoreMessages {
    Yes,
    No,
}

#[allow(dead_code)]
impl NeedsMoreMessages {
    /// Returns `true` for [`NeedsMoreMessages::Yes`].
    pub fn yes(&self) -> bool {
        match self {
            Self::Yes => true,
            Self::No => false,
        }
    }

    /// Returns `true` for [`NeedsMoreMessages::No`].
    pub fn no(&self) -> bool {
        !self.yes()
    }
}
466
/// Error types produced while completing a round.
pub mod errors {
    use super::TakeOutputError;

    /// Error returned when a round could not be completed.
    ///
    /// `ProcessErr` is the error type of the round's message store; `IoErr` is the error type
    /// of the stream of incoming messages.
    #[derive(Debug, thiserror::Error)]
    pub enum CompleteRoundError<ProcessErr, IoErr> {
        /// The message store failed to process a message.
        #[error("failed to process the message")]
        ProcessMessage(#[source] ProcessErr),
        /// The next message could not be received.
        #[error("receive next message")]
        Io(#[from] IoError<IoErr>),
        /// Any other failure; the underlying reason is private and only visible through the
        /// error's `Display`/`source` chain.
        #[error("implementation error")]
        Other(#[source] OtherError),
    }

    /// Failure to receive the next incoming message.
    #[derive(Debug, thiserror::Error)]
    pub enum IoError<E> {
        /// The stream of incoming messages yielded an error.
        #[error("i/o error")]
        Io(#[source] E),
        /// The stream of incoming messages ended.
        #[error("unexpected eof")]
        UnexpectedEof,
    }

    /// Opaque wrapper around the private [`OtherReason`]; displays and sources transparently
    /// as the wrapped reason.
    #[derive(Debug, thiserror::Error)]
    #[error(transparent)]
    pub struct OtherError(OtherReason);

    /// Categories of failures carried by [`OtherError`].
    #[derive(Debug, thiserror::Error)]
    pub(super) enum OtherReason {
        /// The message store did not behave as expected.
        #[error("improper `MessagesStore` implementation")]
        ImproperStoreImpl(#[source] ImproperStoreImpl),
        /// The rounds API was used incorrectly.
        #[error("`Rounds` API misuse")]
        RoundsMisuse(#[source] RoundsMisuse),
        /// An internal invariant was violated.
        #[error("bug in `Rounds` (please, open a issue)")]
        Bug(#[source] Bug),
    }

    /// Misbehaviour of a `MessagesStore` implementation.
    #[derive(Debug, thiserror::Error)]
    pub(super) enum ImproperStoreImpl {
        /// The store reported that it wants no more messages, but then didn't produce an
        /// output.
        #[error("store didn't output")]
        StoreDidntOutput,
    }

    /// Incorrect use of the rounds API.
    #[derive(Debug, thiserror::Error)]
    pub(super) enum RoundsMisuse {
        /// The output of the round has already been taken.
        #[error("round is already completed")]
        RoundAlreadyCompleted,
        /// No round with number `n` is registered in the router.
        #[error("round {n} is not registered")]
        UnregisteredRound { n: u16 },
    }

    /// Violations of internal invariants.
    #[derive(Debug, thiserror::Error)]
    pub(super) enum Bug {
        /// A message could not be converted into the message type of the round it was routed
        /// to.
        #[error(
            "message originates from another round: we process messages from round \
            {expected_round}, got message from round {actual_number}"
        )]
        MessageFromAnotherRound {
            expected_round: u16,
            actual_number: u16,
        },
        /// The round state was not the one the code had just observed.
        #[error("state is incoherent, it's expected to be {expected}: {justification}")]
        IncoherentState {
            expected: &'static str,
            justification: &'static str,
        },
        /// The type-erased round output could not be downcast to the expected output type.
        #[error("mismatched output type")]
        MismatchedOutputType,
        /// The type-erased round error could not be downcast to the expected error type.
        #[error("mismatched error type")]
        MismatchedErrorType,
        /// The round result could not be taken from the round processor.
        #[error("take round result")]
        TakeRoundResult(#[source] TakeOutputError),
    }

    impl<ProcessErr, IoErr> CompleteRoundError<ProcessErr, IoErr> {
        /// Converts the i/o error type with `f`; the other variants are passed through
        /// unchanged.
        pub(super) fn map_io_err<E, F>(self, f: F) -> CompleteRoundError<ProcessErr, E>
        where
            F: FnOnce(IoErr) -> E,
        {
            match self {
                CompleteRoundError::Io(err) => CompleteRoundError::Io(err.map_err(f)),
                CompleteRoundError::ProcessMessage(err) => CompleteRoundError::ProcessMessage(err),
                CompleteRoundError::Other(err) => CompleteRoundError::Other(err),
            }
        }
    }

    impl<E> IoError<E> {
        /// Converts the wrapped error with `f`; `UnexpectedEof` is passed through unchanged.
        pub(super) fn map_err<B, F>(self, f: F) -> IoError<B>
        where
            F: FnOnce(E) -> B,
        {
            match self {
                IoError::Io(e) => IoError::Io(f(e)),
                IoError::UnexpectedEof => IoError::UnexpectedEof,
            }
        }
    }

    // For every listed type `T` (which must also be the name of an `OtherReason` variant),
    // generates `From<T> for CompleteRoundError<_, _>` that wraps the value into
    // `CompleteRoundError::Other`.
    macro_rules! impl_from_other_error {
        ($($err:ident),+,) => {$(
            impl<E1, E2> From<$err> for CompleteRoundError<E1, E2> {
                fn from(err: $err) -> Self {
                    Self::Other(OtherError(OtherReason::$err(err)))
                }
            }
        )+};
    }

    impl_from_other_error! {
        ImproperStoreImpl,
        RoundsMisuse,
        Bug,
    }
}
598
#[cfg(test)]
mod tests {
    /// Message store that wants no messages and outputs `()`.
    struct Store;

    #[derive(crate::ProtocolMessage)]
    #[protocol_message(root = crate)]
    enum FakeProtocolMsg {
        R1(Msg1),
    }
    struct Msg1;

    impl super::MessagesStore for Store {
        type Msg = Msg1;
        type Output = ();
        type Error = core::convert::Infallible;

        fn add_message(&mut self, _msg: crate::Incoming<Self::Msg>) -> Result<(), Self::Error> {
            Ok(())
        }
        fn wants_more(&self) -> bool {
            false
        }
        fn output(self) -> Result<Self::Output, Self> {
            Ok(())
        }
    }

    /// A round whose store wants no messages must complete without reading the stream: the
    /// stream used here never yields anything.
    #[tokio::test]
    async fn complete_round_that_expects_no_messages() {
        let never_yields = futures::stream::pending::<
            Result<crate::Incoming<FakeProtocolMsg>, core::convert::Infallible>,
        >();

        let mut builder = super::RoundsRouter::builder();
        let round1 = builder.add_round(Store);
        let mut router = builder.listen(never_yields);

        router.complete(round1).await.unwrap();
    }
}