mod delivery;
mod noop_waker;
mod runtime;
mod shared_state;
use core::{future::Future, task::Poll};
pub use self::{
delivery::{Incomings, Outgoings, SendErr},
runtime::{Runtime, YieldNow},
};
/// A protocol exposed as a synchronous, caller-driven state machine.
///
/// The caller repeatedly invokes [`proceed`](Self::proceed) and reacts to each
/// [`ProceedResult`]. When the machine reports that it needs a message, the caller
/// supplies one through [`received_msg`](Self::received_msg).
pub trait StateMachine {
    /// Value produced when the protocol completes.
    type Output;
    /// Message type exchanged by the protocol.
    type Msg;

    /// Advances the protocol and reports what it needs next (or its final output).
    fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg>;

    /// Delivers an incoming message to the protocol.
    ///
    /// A message that cannot be accepted is handed back in `Err`.
    fn received_msg(
        &mut self,
        msg: crate::Incoming<Self::Msg>,
    ) -> Result<(), crate::Incoming<Self::Msg>>;
}
/// Outcome of a single [`StateMachine::proceed`] call.
#[must_use = "ProceedResult must be used to correctly carry out the state machine"]
pub enum ProceedResult<O, M> {
    /// An outgoing message taken from the protocol; the caller is expected to deliver it.
    SendMsg(crate::Outgoing<M>),
    /// The protocol is waiting for an incoming message (see `StateMachine::received_msg`).
    NeedsOneMoreMessage,
    /// The protocol has finished and produced this output.
    Output(O),
    /// The protocol yielded control without waiting on a message.
    Yielded,
    /// The executor could not drive the protocol.
    Error(ExecutionError),
}
impl<O, M> core::fmt::Debug for ProceedResult<O, M> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
ProceedResult::SendMsg(_) => f.write_str("SendMsg"),
ProceedResult::NeedsOneMoreMessage => f.write_str("NeedsOneMoreMessage"),
ProceedResult::Output(_) => f.write_str("Output"),
ProceedResult::Yielded => f.write_str("Yielded"),
ProceedResult::Error(_) => f.write_str("Error"),
}
}
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ExecutionError(Reason);
#[derive(Debug, thiserror::Error)]
enum Reason {
#[error("resuming state machine when protocol is already finished")]
Exhausted,
#[error("protocol polls unknown (unsupported) future")]
PollingUnknownFuture,
}
impl<O, M> From<Reason> for ProceedResult<O, M> {
fn from(err: Reason) -> Self {
ProceedResult::Error(ExecutionError(err))
}
}
impl From<Reason> for ExecutionError {
fn from(err: Reason) -> Self {
ExecutionError(err)
}
}
/// [`StateMachine`] that drives a heap-pinned protocol future by polling it with a
/// no-op waker.
struct StateMachineImpl<O, M, F>
where
    F: Future<Output = O>,
{
    /// Handle to the state shared with the protocol's delivery and runtime.
    shared_state: shared_state::SharedStateRef<M>,
    /// Set once the future has returned `Poll::Ready`; further polling is refused.
    exhausted: bool,
    /// The protocol itself, boxed and pinned so it can be polled through `&mut self`.
    future: core::pin::Pin<alloc::boxed::Box<F>>,
}
impl<O, M, F> StateMachine for StateMachineImpl<O, M, F>
where
    F: Future<Output = O>,
{
    type Output = O;
    type Msg = M;

    /// Polls the protocol future exactly once and translates the outcome into a
    /// [`ProceedResult`].
    ///
    /// If the future is pending, the shared state is queried to find out what the
    /// protocol is blocked on. The order is fixed: a queued outgoing message first,
    /// then a request for more incoming messages, then a yield. The first match
    /// wins. If nothing matches, the protocol awaited a future this executor does
    /// not recognise, and that is reported as an error.
    fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg> {
        if self.exhausted {
            // The future already completed; polling a finished future is not allowed.
            return Reason::Exhausted.into();
        }

        let noop = noop_waker::noop_waker();
        let mut context = core::task::Context::from_waker(&noop);

        if let Poll::Ready(output) = self.future.as_mut().poll(&mut context) {
            self.exhausted = true;
            return ProceedResult::Output(output);
        }

        // Pending. Keep this exact query order: judging by their names, the first
        // and last queries consume state ("takes", "resets").
        if let Some(msg) = self.shared_state.executor_takes_outgoing_msg() {
            ProceedResult::SendMsg(msg)
        } else if self.shared_state.protocol_wants_more_messages() {
            ProceedResult::NeedsOneMoreMessage
        } else if self.shared_state.executor_reads_and_resets_yielded_flag() {
            ProceedResult::Yielded
        } else {
            Reason::PollingUnknownFuture.into()
        }
    }

    /// Forwards an incoming message to the shared state unchanged.
    ///
    /// The shared state's result is returned as-is. Per the signature, a rejected
    /// message comes back in `Err`.
    fn received_msg(
        &mut self,
        msg: crate::Incoming<Self::Msg>,
    ) -> Result<(), crate::Incoming<Self::Msg>> {
        self.shared_state.executor_received_msg(msg)
    }
}
/// The `(incoming, outgoing)` pair through which a wrapped protocol exchanges messages.
pub type Delivery<M> = (Incomings<M>, Outgoings<M>);
/// [`crate::MpcParty`] specialised to the state-machine [`Delivery`] and [`Runtime`].
pub type MpcParty<M> = crate::MpcParty<M, Delivery<M>, Runtime<M>>;
/// Turns an async protocol into a synchronous [`StateMachine`].
///
/// `protocol` is called once, immediately. It receives an [`MpcParty`] whose
/// incoming, outgoing and runtime handles all refer to a single shared state.
/// The future it returns is pinned on the heap. Each
/// [`StateMachine::proceed`] call then polls it once.
pub fn wrap_protocol<'a, M, F>(
    protocol: impl FnOnce(MpcParty<M>) -> F,
) -> impl StateMachine<Output = F::Output, Msg = M> + 'a
where
    F: Future + 'a,
    M: 'static,
{
    let shared_state = shared_state::SharedStateRef::new();

    // Every component gets its own handle to the same state. This lets the
    // executor see what the protocol is currently waiting on.
    let delivery = (
        Incomings::new(shared_state.clone()),
        Outgoings::new(shared_state.clone()),
    );
    let runtime = Runtime::new(shared_state.clone());
    let party = crate::MpcParty::connected(delivery).set_runtime(runtime);

    StateMachineImpl {
        future: alloc::boxed::Box::pin(protocol(party)),
        exhausted: false,
        shared_state,
    }
}