use std::fmt::{self, Debug};
use std::future::Future;
use futures::future::{Either, FutureExt};
use futures::sink::Sink;
use futures::stream::{self, FusedStream, Stream, StreamExt};
use futures::SinkExt;
use tokio::time::{self, timeout_at};
use crate::{IsCritical, Msg, StateMachine};
use watcher::{BlindWatcher, ProtocolWatcher, When};
pub mod watcher;
/// Drives a `StateMachine` over a stream of incoming messages and a sink of outgoing messages.
///
/// Build it with `AsyncProtocol::new` (which uses `BlindWatcher`), optionally replace the
/// watcher with `set_watcher`, then call `run` once to execute the protocol.
pub struct AsyncProtocol<SM, I, O, W = BlindWatcher> {
/// The protocol state machine. Kept in an `Option` so it can be moved out while `proceed`
/// runs on a blocking task and put back afterwards.
state: Option<SM>,
/// Stream of incoming messages (or receive errors).
incoming: I,
/// Sink that outgoing messages are sent to.
outcoming: O,
/// Instant at which waiting for the next incoming message in the current round times out;
/// `None` means the current round has no timeout.
deadline: Option<time::Instant>,
/// Round number last observed by the runner. Being `Some` also marks that `run` has
/// already been started on this instance.
current_round: Option<u16>,
/// Notified of non-critical errors reported by the state machine.
watcher: W,
}
impl<SM, I, O> AsyncProtocol<SM, I, O, BlindWatcher> {
pub fn new(state: SM, incoming: I, outcoming: O) -> Self {
Self {
state: Some(state),
incoming,
outcoming,
deadline: None,
current_round: None,
watcher: BlindWatcher,
}
}
}
impl<SM, I, O, W> AsyncProtocol<SM, I, O, W> {
    /// Returns the same runner with its watcher replaced by `watcher`.
    ///
    /// Every other part of the runner (state, channels, round and deadline bookkeeping) is
    /// carried over unchanged; the previous watcher is dropped.
    pub fn set_watcher<WR>(self, watcher: WR) -> AsyncProtocol<SM, I, O, WR> {
        let AsyncProtocol {
            state,
            incoming,
            outcoming,
            deadline,
            current_round,
            ..
        } = self;
        AsyncProtocol {
            state,
            incoming,
            outcoming,
            deadline,
            current_round,
            watcher,
        }
    }
}
impl<SM, I, O, IErr, W> AsyncProtocol<SM, I, O, W>
where
SM: StateMachine,
SM::Err: Send,
SM: Send + 'static,
I: Stream<Item = Result<Msg<SM::MessageBody>, IErr>> + FusedStream + Unpin,
O: Sink<Msg<SM::MessageBody>> + Unpin,
W: ProtocolWatcher<SM>,
{
pub async fn run(&mut self) -> Result<SM::Output, Error<SM::Err, IErr, O::Error>> {
if self.current_round.is_some() {
return Err(Error::Exhausted);
}
self.refresh_timer()?;
self.proceed_if_needed().await?;
self.send_outcoming().await?;
self.refresh_timer()?;
if let Some(result) = self.finish_if_possible() {
return result;
}
loop {
self.handle_incoming().await?;
self.send_outcoming().await?;
self.refresh_timer()?;
self.proceed_if_needed().await?;
self.send_outcoming().await?;
self.refresh_timer()?;
if let Some(result) = self.finish_if_possible() {
return result;
}
}
}
async fn handle_incoming(&mut self) -> Result<(), Error<SM::Err, IErr, O::Error>> {
let state = self.state.as_mut().ok_or(InternalError::MissingState)?;
match Self::enforce_timeout(self.deadline, self.incoming.next()).await {
Ok(Some(Ok(msg))) => match state.handle_incoming(msg) {
Ok(()) => (),
Err(err) if err.is_critical() => return Err(Error::HandleIncoming(err)),
Err(err) => self
.watcher
.caught_non_critical_error(When::HandleIncoming, err),
},
Ok(Some(Err(err))) => return Err(Error::Recv(err)),
Ok(None) => return Err(Error::RecvEof),
Err(_) => {
let err = state.round_timeout_reached();
return Err(Error::HandleIncomingTimeout(err));
}
}
Ok(())
}
async fn proceed_if_needed(&mut self) -> Result<(), Error<SM::Err, IErr, O::Error>> {
let mut state = self.state.take().ok_or(InternalError::MissingState)?;
if state.wants_to_proceed() {
let (result, s) = tokio::task::spawn_blocking(move || (state.proceed(), state))
.await
.map_err(Error::ProceedPanicked)?;
state = s;
match result {
Ok(()) => (),
Err(err) if err.is_critical() => return Err(Error::Proceed(err)),
Err(err) => self.watcher.caught_non_critical_error(When::Proceed, err),
}
}
self.state = Some(state);
Ok(())
}
async fn send_outcoming(&mut self) -> Result<(), Error<SM::Err, IErr, O::Error>> {
let state = self.state.as_mut().ok_or(InternalError::MissingState)?;
if !state.message_queue().is_empty() {
let mut msgs = stream::iter(state.message_queue().drain(..).map(Ok));
self.outcoming
.send_all(&mut msgs)
.await
.map_err(Error::Send)?;
}
Ok(())
}
fn finish_if_possible(&mut self) -> Option<Result<SM::Output, Error<SM::Err, IErr, O::Error>>> {
let state = match self.state.as_mut() {
Some(s) => s,
None => return Some(Err(InternalError::MissingState.into())),
};
if !state.is_finished() {
None
} else {
match state.pick_output() {
Some(Ok(result)) => Some(Ok(result)),
Some(Err(err)) => Some(Err(Error::Finish(err))),
None => Some(Err(
BadStateMachineReason::ProtocolFinishedButNoResult.into()
)),
}
}
}
fn refresh_timer(&mut self) -> Result<(), Error<SM::Err, IErr, O::Error>> {
let state = self.state.as_mut().ok_or(InternalError::MissingState)?;
let round_n = state.current_round();
if self.current_round != Some(round_n) {
self.current_round = Some(round_n);
self.deadline = match state.round_timeout() {
Some(timeout) => Some(time::Instant::now() + timeout),
None => None,
}
}
Ok(())
}
fn enforce_timeout<F>(
deadline: Option<time::Instant>,
f: F,
) -> impl Future<Output = Result<F::Output, time::error::Elapsed>>
where
F: Future,
{
match deadline {
Some(deadline) => Either::Right(timeout_at(deadline, f)),
None => Either::Left(f.map(Ok)),
}
}
}
/// Error returned by `AsyncProtocol::run`.
///
/// Type parameters: `E` is the state machine's error type, `RE` is the error yielded by the
/// incoming stream, and `SE` is the error type of the outgoing sink.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error<E, RE, SE> {
/// The incoming stream yielded an error.
Recv(RE),
/// The incoming stream ended before the protocol finished.
RecvEof,
/// Sending queued messages to the outgoing sink failed.
Send(SE),
/// The state machine returned a critical error while handling a received message.
HandleIncoming(E),
/// The round deadline elapsed while waiting for the next message; carries the error
/// returned by the state machine's `round_timeout_reached`.
HandleIncomingTimeout(E),
/// The blocking task that ran `proceed` could not be joined (e.g. it panicked).
ProceedPanicked(tokio::task::JoinError),
/// `proceed` returned a critical error.
Proceed(E),
/// The protocol finished, but the state machine returned an error instead of an output.
Finish(E),
/// `run` was called on a runner whose `run` had already been started.
Exhausted,
/// The state machine violated the runner's expectations.
BadStateMachine(BadStateMachineReason),
/// An invariant of the runner itself was broken.
InternalError(InternalError),
}
impl<E, RE, SE> From<BadStateMachineReason> for Error<E, RE, SE> {
fn from(reason: BadStateMachineReason) -> Self {
Error::BadStateMachine(reason)
}
}
impl<E, RE, SE> From<InternalError> for Error<E, RE, SE> {
fn from(err: InternalError) -> Self {
Error::InternalError(err)
}
}
impl<E, RE, SE> fmt::Display for Error<E, RE, SE>
where
E: fmt::Display,
RE: fmt::Display,
SE: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Recv(err) => {
write!(f, "receive next message: {}", err)
}
Self::RecvEof => {
write!(f, "receive next message: unexpected eof")
}
Self::Send(err) => {
write!(f, "send a message: {}", err)
}
Self::HandleIncoming(err) => {
write!(f, "handle received message: {}", err)
}
Self::HandleIncomingTimeout(err) => {
write!(f, "round timeout reached: {}", err)
}
Self::ProceedPanicked(err) => {
write!(f, "proceed round panicked: {}", err)
}
Self::Proceed(err) => {
write!(f, "round proceed error: {}", err)
}
Self::Finish(err) => {
write!(f, "couldn't finish protocol: {}", err)
}
Self::Exhausted => {
write!(f, "async runtime is exhausted")
}
Self::BadStateMachine(err) => {
write!(f, "buggy state machine implementation: {}", err)
}
Self::InternalError(err) => {
write!(f, "internal error: {:?}", err)
}
}
}
}
/// Exposes the wrapped error, when a variant carries one, as the `source()` of this error.
impl<E, RE, SE> std::error::Error for Error<E, RE, SE>
where
    E: std::error::Error + 'static,
    RE: std::error::Error + 'static,
    SE: std::error::Error + 'static,
{
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Recv(cause) => Some(cause),
            Self::Send(cause) => Some(cause),
            Self::ProceedPanicked(cause) => Some(cause),
            Self::HandleIncoming(cause)
            | Self::HandleIncomingTimeout(cause)
            | Self::Proceed(cause)
            | Self::Finish(cause) => Some(cause),
            // These variants carry no error that implements `std::error::Error` here.
            Self::RecvEof | Self::Exhausted | Self::BadStateMachine(_) | Self::InternalError(_) => {
                None
            }
        }
    }
}
/// Ways in which a state machine implementation can break the contract the async runner
/// relies on.
#[derive(Debug)]
#[non_exhaustive]
pub enum BadStateMachineReason {
    /// `is_finished()` returned `true`, but `pick_output()` returned `None`.
    ProtocolFinishedButNoResult,
}

impl fmt::Display for BadStateMachineReason {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ProtocolFinishedButNoResult => write!(
                f,
                "couldn't obtain protocol output although it is completed"
            ),
        }
    }
}

/// Makes the reason usable wherever a `std::error::Error` is expected (e.g. boxed or as a
/// `source()`). It wraps no underlying cause.
impl std::error::Error for BadStateMachineReason {}
/// Invariant violations inside the async runner itself, as opposed to errors from the state
/// machine or from the message channels.
#[derive(Debug)]
#[non_exhaustive]
pub enum InternalError {
    /// The runner's state slot was empty when the state machine was needed.
    MissingState,
}

impl fmt::Display for InternalError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingState => f.write_str("state machine is missing"),
        }
    }
}

/// Makes the error usable wherever a `std::error::Error` is expected. It wraps no underlying
/// cause.
impl std::error::Error for InternalError {}