round_based/state_machine/
mod.rs

//! Wraps a protocol defined as an async function and provides a sync API to execute it
//!
//! In the `round_based` framework, MPC protocols are defined as async functions. However, it is
//! sometimes impossible or undesirable to have an async runtime that drives the futures to
//! completion. For such use cases, we provide the [`wrap_protocol`] function, which wraps an MPC
//! protocol defined as an async function and returns a [`StateMachine`] that exposes a sync API
//! to carry out the protocol.
7//!
8//! ## Example
9//! ```rust,no_run
10//! # fn main() -> anyhow::Result<()> {
11//! use anyhow::{Result, Error, Context as _};
12//!
13//! # type Randomness = [u8; 32];
14//! # #[derive(round_based::ProtocolMsg, Clone)]
15//! # enum Msg {}
16//! // Any MPC protocol
17//! pub async fn protocol_of_random_generation<M>(
18//!     party: M,
19//!     i: u16,
20//!     n: u16
21//! ) -> Result<Randomness>
22//! where
23//!     M: round_based::Mpc<Msg = Msg>
24//! {
25//!     // ...
26//! # todo!()
27//! }
28//!
29//! // `state` implements `round_based::state_machine::StateMachine` trait.
30//! // Its methods can be used to advance protocol until completion.
31//! let mut state = round_based::state_machine::wrap_protocol(
32//!     |party| protocol_of_random_generation(party, 0, 3)
33//! );
34//!
35//! fn send(msg: round_based::Outgoing<Msg>) -> Result<()> {
36//!     // sends outgoing message...
37//! # unimplemented!()
38//! }
39//! fn recv() -> Result<round_based::Incoming<Msg>> {
40//!     // receives incoming message...
41//! # unimplemented!()
42//! }
43//!
44//! use round_based::state_machine::{StateMachine as _, ProceedResult};
45//! let output = loop {
46//!     match state.proceed() {
47//!         ProceedResult::SendMsg(msg) => {
48//!             send(msg)?
49//!         }
50//!         ProceedResult::NeedsOneMoreMessage => {
51//!             let msg = recv()?;
52//!             state.received_msg(msg)
53//!                 .map_err(|_| anyhow::format_err!("state machine rejected received message"))?;
54//!         }
55//!         ProceedResult::Yielded => {},
56//!         ProceedResult::Output(out) => break Ok(out),
57//!         ProceedResult::Error(err) => break Err(err),
58//!     }
59//! };
60//! # Ok(()) }
61//! ```
62
63mod delivery;
64mod noop_waker;
65mod runtime;
66mod shared_state;
67
68use core::{future::Future, task::Poll};
69
70use crate::ProtocolMsg;
71
72pub use self::{
73    delivery::{Delivery, DeliveryErr},
74    runtime::{Runtime, YieldNow},
75};
76
77/// Provides interface to execute the protocol
78pub trait StateMachine {
79    /// Output of the protocol
80    type Output;
81    /// Message of the protocol
82    type Msg;
83
84    /// Resumes protocol execution
85    ///
86    /// Returns [`ProceedResult`] which will indicate, for instance, if the protocol wants to send
87    /// or receive a message, or if it's finished.
88    ///
89    /// Calling `proceed` after protocol has finished (after it returned [`ProceedResult::Output`])
90    /// returns an error.
91    fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg>;
92    /// Saves received message to be picked up by the state machine on the next [`proceed`](Self::proceed) invocation
93    ///
94    /// This method should only be called if state machine returned [`ProceedResult::NeedsOneMoreMessage`] on previous
95    /// invocation of [`proceed`](Self::proceed) method. Calling this method when state machine did not request it
96    /// may return error.
97    ///
98    /// Calling this method must be followed up by calling [`proceed`](Self::proceed). Do not invoke this method
99    /// more than once in a row, even if you have available messages received from other parties. Instead, you
100    /// should call this method, then call `proceed`, and only if it returned [`ProceedResult::NeedsOneMoreMessage`]
101    /// you can call `received_msg` again.
102    fn received_msg(
103        &mut self,
104        msg: crate::Incoming<Self::Msg>,
105    ) -> Result<(), crate::Incoming<Self::Msg>>;
106}
107
108/// Tells why protocol execution stopped
109#[must_use = "ProceedResult must be used to correctly carry out the state machine"]
110pub enum ProceedResult<O, M> {
111    /// Protocol needs provided message to be sent
112    SendMsg(crate::Outgoing<M>),
113    /// Protocol needs one more message to be received
114    ///
115    /// After the state machine requested one more message, the next call to the state machine must
116    /// be [`StateMachine::received_msg`].
117    NeedsOneMoreMessage,
118    /// Protocol is finished
119    Output(O),
120    /// Protocol yielded the execution
121    ///
122    /// Protocol may yield at any point by calling `AsyncRuntime::yield_now`. Main motivation
123    /// for yielding is to break a long computation into smaller parts, so proceeding state
124    /// machine doesn't take too long.
125    ///
126    /// When protocol yields, you can resume the execution by calling [`proceed`](StateMachine::proceed)
127    /// immediately.
128    Yielded,
129    /// State machine failed to carry out the protocol
130    ///
131    /// Error likely means that either state machine is misused (e.g. when [`proceed`](StateMachine::proceed)
132    /// is called after protocol is finished) or protocol implementation is not supported by state machine
133    /// executor (e.g. it polls unknown future).
134    Error(ExecutionError),
135}
136
137impl<O, M> core::fmt::Debug for ProceedResult<O, M> {
138    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
139        match self {
140            ProceedResult::SendMsg(_) => f.write_str("SendMsg"),
141            ProceedResult::NeedsOneMoreMessage => f.write_str("NeedsOneMoreMessage"),
142            ProceedResult::Output(_) => f.write_str("Output"),
143            ProceedResult::Yielded => f.write_str("Yielded"),
144            ProceedResult::Error(_) => f.write_str("Error"),
145        }
146    }
147}
148
149/// Error type which indicates that state machine failed to carry out the protocol
150#[derive(Debug, thiserror::Error)]
151#[error(transparent)]
152pub struct ExecutionError(Reason);
153
154#[derive(Debug, thiserror::Error)]
155enum Reason {
156    #[error("resuming state machine when protocol is already finished")]
157    Exhausted,
158    #[error("protocol polls unknown (unsupported) future")]
159    PollingUnknownFuture,
160}
161
162impl<O, M> From<Reason> for ProceedResult<O, M> {
163    fn from(err: Reason) -> Self {
164        ProceedResult::Error(ExecutionError(err))
165    }
166}
167impl From<Reason> for ExecutionError {
168    fn from(err: Reason) -> Self {
169        ExecutionError(err)
170    }
171}
172
173struct StateMachineImpl<O, M, F: Future<Output = O>> {
174    shared_state: shared_state::SharedStateRef<M>,
175    exhausted: bool,
176    future: core::pin::Pin<alloc::boxed::Box<F>>,
177}
178
179impl<O, M, F> StateMachine for StateMachineImpl<O, M, F>
180where
181    F: Future<Output = O>,
182{
183    type Output = O;
184    type Msg = M;
185
186    fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg> {
187        if self.exhausted {
188            return Reason::Exhausted.into();
189        }
190        let future = self.future.as_mut();
191        let waker = noop_waker::noop_waker();
192        let mut cx = core::task::Context::from_waker(&waker);
193        match future.poll(&mut cx) {
194            Poll::Ready(output) => {
195                self.exhausted = true;
196                ProceedResult::Output(output)
197            }
198            Poll::Pending => {
199                // underlying future may `await` only on either:
200                // 1. Flushing outgoing message
201                // 2. Waiting for incoming message
202                // 3. Yielding
203
204                // Check if it's flushing outgoing message:
205                if let Some(outgoing_msg) = self.shared_state.executor_takes_outgoing_msg() {
206                    return ProceedResult::SendMsg(outgoing_msg);
207                }
208
209                // Check if it's waiting for a new message
210                if self.shared_state.protocol_wants_more_messages() {
211                    return ProceedResult::NeedsOneMoreMessage;
212                }
213
214                // Check if protocol yielded
215                if self.shared_state.executor_reads_and_resets_yielded_flag() {
216                    return ProceedResult::Yielded;
217                }
218
219                // If none of above conditions are met, then protocol is polling
220                // a future which we do not recognize
221                Reason::PollingUnknownFuture.into()
222            }
223        }
224    }
225
226    fn received_msg(&mut self, msg: crate::Incoming<Self::Msg>) -> Result<(), crate::Incoming<M>> {
227        self.shared_state.executor_received_msg(msg)
228    }
229}
230
231/// MpcParty instantiated with state machine implementation of delivery and async runtime
232pub type MpcParty<M> = crate::MpcParty<M, Delivery<M>, Runtime<M>>;
233
234/// Wraps the protocol and provides sync API to execute it
235///
236/// Protocol is an async function that takes [`MpcParty`] as input. `MpcParty` contains
237/// channels (of incoming and outgoing messages) that protocol is expected to use, and
238/// a [`Runtime`]. Protocol is only allowed to `.await` on futures provided in `MpcParty`,
239/// such as polling next message from provided steam of incoming messages. If protocol
240/// polls an unknown future, executor won't know what to do with that, the protocol will
241/// be aborted and error returned.
242pub fn wrap_protocol<'a, M, F>(
243    protocol: impl FnOnce(MpcParty<M>) -> F,
244) -> impl StateMachine<Output = F::Output, Msg = M> + 'a
245where
246    F: Future + 'a,
247    M: ProtocolMsg + 'static,
248{
249    let shared_state = shared_state::SharedStateRef::new();
250    let delivery = Delivery::new(shared_state.clone());
251    let runtime = Runtime::new(shared_state.clone());
252
253    let future = protocol(crate::mpc::connected(delivery).with_runtime(runtime));
254    let future = alloc::boxed::Box::pin(future);
255
256    StateMachineImpl {
257        shared_state,
258        exhausted: false,
259        future,
260    }
261}