round_based/state_machine/
mod.rs

//! Wraps a protocol defined as an async function and provides a sync API to execute it
//!
//! In the `round_based` framework, MPC protocols are defined as async functions. However, sometimes
//! it may not be possible or desirable to have an async runtime that drives the futures to
//! completion. For such use cases, we provide the [`wrap_protocol`] function, which wraps an MPC
//! protocol defined as an async function and returns a [`StateMachine`] that exposes a sync API to
//! carry out the protocol.
7//!
8//! ## Example
9//! ```rust,no_run
10//! # fn main() -> anyhow::Result<()> {
11//! use round_based::{Mpc, PartyIndex};
12//! use anyhow::{Result, Error, Context as _};
13//!
14//! # type Randomness = [u8; 32];
15//! # type Msg = ();
16//! // Any MPC protocol
17//! pub async fn protocol_of_random_generation<M>(
18//!     party: M,
19//!     i: PartyIndex,
20//!     n: u16
21//! ) -> Result<Randomness>
22//! where
23//!     M: Mpc<ProtocolMessage = Msg>
24//! {
25//!     // ...
26//! # todo!()
27//! }
28//!
29//! // `state` implements `round_based::state_machine::StateMachine` trait.
30//! // Its methods can be used to advance protocol until completion.
31//! let mut state = round_based::state_machine::wrap_protocol(
32//!     |party| protocol_of_random_generation(party, 0, 3)
33//! );
34//!
35//! fn send(msg: round_based::Outgoing<Msg>) -> Result<()> {
36//!     // sends outgoing message...
37//! # unimplemented!()
38//! }
39//! fn recv() -> Result<round_based::Incoming<Msg>> {
40//!     // receives incoming message...
41//! # unimplemented!()
42//! }
43//!
44//! use round_based::state_machine::{StateMachine as _, ProceedResult};
45//! let output = loop {
46//!     match state.proceed() {
47//!         ProceedResult::SendMsg(msg) => {
48//!             send(msg)?
49//!         }
50//!         ProceedResult::NeedsOneMoreMessage => {
51//!             let msg = recv()?;
52//!             state.received_msg(msg)
53//!                 .map_err(|_| anyhow::format_err!("state machine rejected received message"))?;
54//!         }
55//!         ProceedResult::Yielded => {},
56//!         ProceedResult::Output(out) => break Ok(out),
57//!         ProceedResult::Error(err) => break Err(err),
58//!     }
59//! };
60//! # Ok(()) }
61//! ```
62
63mod delivery;
64mod noop_waker;
65mod runtime;
66mod shared_state;
67
68use core::{future::Future, task::Poll};
69
70pub use self::{
71    delivery::{Incomings, Outgoings, SendErr},
72    runtime::{Runtime, YieldNow},
73};
74
75/// Provides interface to execute the protocol
76pub trait StateMachine {
77    /// Output of the protocol
78    type Output;
79    /// Message of the protocol
80    type Msg;
81
82    /// Resumes protocol execution
83    ///
84    /// Returns [`ProceedResult`] which will indicate, for instance, if the protocol wants to send
85    /// or receive a message, or if it's finished.
86    ///
87    /// Calling `proceed` after protocol has finished (after it returned [`ProceedResult::Output`])
88    /// returns an error.
89    fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg>;
90    /// Saves received message to be picked up by the state machine on the next [`proceed`](Self::proceed) invocation
91    ///
92    /// This method should only be called if state machine returned [`ProceedResult::NeedsOneMoreMessage`] on previous
93    /// invocation of [`proceed`](Self::proceed) method. Calling this method when state machine did not request it
94    /// may return error.
95    ///
96    /// Calling this method must be followed up by calling [`proceed`](Self::proceed). Do not invoke this method
97    /// more than once in a row, even if you have available messages received from other parties. Instead, you
98    /// should call this method, then call `proceed`, and only if it returned [`ProceedResult::NeedsOneMoreMessage`]
99    /// you can call `received_msg` again.
100    fn received_msg(
101        &mut self,
102        msg: crate::Incoming<Self::Msg>,
103    ) -> Result<(), crate::Incoming<Self::Msg>>;
104}
105
106/// Tells why protocol execution stopped
107#[must_use = "ProceedResult must be used to correctly carry out the state machine"]
108pub enum ProceedResult<O, M> {
109    /// Protocol needs provided message to be sent
110    SendMsg(crate::Outgoing<M>),
111    /// Protocol needs one more message to be received
112    ///
113    /// After the state machine requested one more message, the next call to the state machine must
114    /// be [`StateMachine::received_msg`].
115    NeedsOneMoreMessage,
116    /// Protocol is finished
117    Output(O),
118    /// Protocol yielded the execution
119    ///
120    /// Protocol may yield at any point by calling `AsyncRuntime::yield_now`. Main motivation
121    /// for yielding is to break a long computation into smaller parts, so proceeding state
122    /// machine doesn't take too long.
123    ///
124    /// When protocol yields, you can resume the execution by calling [`proceed`](StateMachine::proceed)
125    /// immediately.
126    Yielded,
127    /// State machine failed to carry out the protocol
128    ///
129    /// Error likely means that either state machine is misused (e.g. when [`proceed`](StateMachine::proceed)
130    /// is called after protocol is finished) or protocol implementation is not supported by state machine
131    /// executor (e.g. it polls unknown future).
132    Error(ExecutionError),
133}
134
135impl<O, M> core::fmt::Debug for ProceedResult<O, M> {
136    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137        match self {
138            ProceedResult::SendMsg(_) => f.write_str("SendMsg"),
139            ProceedResult::NeedsOneMoreMessage => f.write_str("NeedsOneMoreMessage"),
140            ProceedResult::Output(_) => f.write_str("Output"),
141            ProceedResult::Yielded => f.write_str("Yielded"),
142            ProceedResult::Error(_) => f.write_str("Error"),
143        }
144    }
145}
146
147/// Error type which indicates that state machine failed to carry out the protocol
148#[derive(Debug, thiserror::Error)]
149#[error(transparent)]
150pub struct ExecutionError(Reason);
151
152#[derive(Debug, thiserror::Error)]
153enum Reason {
154    #[error("resuming state machine when protocol is already finished")]
155    Exhausted,
156    #[error("protocol polls unknown (unsupported) future")]
157    PollingUnknownFuture,
158}
159
160impl<O, M> From<Reason> for ProceedResult<O, M> {
161    fn from(err: Reason) -> Self {
162        ProceedResult::Error(ExecutionError(err))
163    }
164}
165impl From<Reason> for ExecutionError {
166    fn from(err: Reason) -> Self {
167        ExecutionError(err)
168    }
169}
170
171struct StateMachineImpl<O, M, F: Future<Output = O>> {
172    shared_state: shared_state::SharedStateRef<M>,
173    exhausted: bool,
174    future: core::pin::Pin<alloc::boxed::Box<F>>,
175}
176
177impl<O, M, F> StateMachine for StateMachineImpl<O, M, F>
178where
179    F: Future<Output = O>,
180{
181    type Output = O;
182    type Msg = M;
183
184    fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg> {
185        if self.exhausted {
186            return Reason::Exhausted.into();
187        }
188        let future = self.future.as_mut();
189        let waker = noop_waker::noop_waker();
190        let mut cx = core::task::Context::from_waker(&waker);
191        match future.poll(&mut cx) {
192            Poll::Ready(output) => {
193                self.exhausted = true;
194                ProceedResult::Output(output)
195            }
196            Poll::Pending => {
197                // underlying future may `await` only on either:
198                // 1. Flushing outgoing message
199                // 2. Waiting for incoming message
200                // 3. Yielding
201
202                // Check if it's flushing outgoing message:
203                if let Some(outgoing_msg) = self.shared_state.executor_takes_outgoing_msg() {
204                    return ProceedResult::SendMsg(outgoing_msg);
205                }
206
207                // Check if it's waiting for a new message
208                if self.shared_state.protocol_wants_more_messages() {
209                    return ProceedResult::NeedsOneMoreMessage;
210                }
211
212                // Check if protocol yielded
213                if self.shared_state.executor_reads_and_resets_yielded_flag() {
214                    return ProceedResult::Yielded;
215                }
216
217                // If none of above conditions are met, then protocol is polling
218                // a future which we do not recognize
219                Reason::PollingUnknownFuture.into()
220            }
221        }
222    }
223
224    fn received_msg(&mut self, msg: crate::Incoming<Self::Msg>) -> Result<(), crate::Incoming<M>> {
225        self.shared_state.executor_received_msg(msg)
226    }
227}
228
229/// Delivery implementation used in the state machine
230pub type Delivery<M> = (Incomings<M>, Outgoings<M>);
231
232/// MpcParty instantiated with state machine implementation of delivery and async runtime
233pub type MpcParty<M> = crate::MpcParty<M, Delivery<M>, Runtime<M>>;
234
235/// Wraps the protocol and provides sync API to execute it
236///
237/// Protocol is an async function that takes [`MpcParty`] as input. `MpcParty` contains
238/// channels (of incoming and outgoing messages) that protocol is expected to use, and
239/// a [`Runtime`]. Protocol is only allowed to `.await` on futures provided in `MpcParty`,
240/// such as polling next message from provided steam of incoming messages. If protocol
241/// polls an unknown future, executor won't know what to do with that, the protocol will
242/// be aborted and error returned.
243pub fn wrap_protocol<'a, M, F>(
244    protocol: impl FnOnce(MpcParty<M>) -> F,
245) -> impl StateMachine<Output = F::Output, Msg = M> + 'a
246where
247    F: Future + 'a,
248    M: 'static,
249{
250    let shared_state = shared_state::SharedStateRef::new();
251    let incomings = Incomings::new(shared_state.clone());
252    let outgoings = Outgoings::new(shared_state.clone());
253    let delivery = (incomings, outgoings);
254    let runtime = Runtime::new(shared_state.clone());
255
256    let future = protocol(crate::MpcParty::connected(delivery).set_runtime(runtime));
257    let future = alloc::boxed::Box::pin(future);
258
259    StateMachineImpl {
260        shared_state,
261        exhausted: false,
262        future,
263    }
264}