round_based/state_machine/mod.rs
//! Wraps a protocol defined as an async function and provides a sync API to execute it
//!
//! In the `round_based` framework, MPC protocols are defined as async functions. However,
//! sometimes it may not be possible or desirable to have an async runtime that drives the
//! futures to completion. For such use cases, we provide the [`wrap_protocol`] function, which
//! wraps an MPC protocol defined as an async function and returns a [`StateMachine`] that
//! exposes a sync API to carry out the protocol.
7//!
8//! ## Example
9//! ```rust,no_run
10//! # fn main() -> anyhow::Result<()> {
11//! use anyhow::{Result, Error, Context as _};
12//!
13//! # type Randomness = [u8; 32];
14//! # #[derive(round_based::ProtocolMsg, Clone)]
15//! # enum Msg {}
16//! // Any MPC protocol
17//! pub async fn protocol_of_random_generation<M>(
18//! party: M,
19//! i: u16,
20//! n: u16
21//! ) -> Result<Randomness>
22//! where
23//! M: round_based::Mpc<Msg = Msg>
24//! {
25//! // ...
26//! # todo!()
27//! }
28//!
29//! // `state` implements `round_based::state_machine::StateMachine` trait.
30//! // Its methods can be used to advance protocol until completion.
31//! let mut state = round_based::state_machine::wrap_protocol(
32//! |party| protocol_of_random_generation(party, 0, 3)
33//! );
34//!
35//! fn send(msg: round_based::Outgoing<Msg>) -> Result<()> {
36//! // sends outgoing message...
37//! # unimplemented!()
38//! }
39//! fn recv() -> Result<round_based::Incoming<Msg>> {
40//! // receives incoming message...
41//! # unimplemented!()
42//! }
43//!
44//! use round_based::state_machine::{StateMachine as _, ProceedResult};
45//! let output = loop {
46//! match state.proceed() {
47//! ProceedResult::SendMsg(msg) => {
48//! send(msg)?
49//! }
50//! ProceedResult::NeedsOneMoreMessage => {
51//! let msg = recv()?;
52//! state.received_msg(msg)
53//! .map_err(|_| anyhow::format_err!("state machine rejected received message"))?;
54//! }
55//! ProceedResult::Yielded => {},
56//! ProceedResult::Output(out) => break Ok(out),
57//! ProceedResult::Error(err) => break Err(err),
58//! }
59//! };
60//! # Ok(()) }
61//! ```
62
63mod delivery;
64mod noop_waker;
65mod runtime;
66mod shared_state;
67
68use core::{future::Future, task::Poll};
69
70use crate::ProtocolMsg;
71
72pub use self::{
73 delivery::{Delivery, DeliveryErr},
74 runtime::{Runtime, YieldNow},
75};
76
77/// Provides interface to execute the protocol
78pub trait StateMachine {
79 /// Output of the protocol
80 type Output;
81 /// Message of the protocol
82 type Msg;
83
84 /// Resumes protocol execution
85 ///
86 /// Returns [`ProceedResult`] which will indicate, for instance, if the protocol wants to send
87 /// or receive a message, or if it's finished.
88 ///
89 /// Calling `proceed` after protocol has finished (after it returned [`ProceedResult::Output`])
90 /// returns an error.
91 fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg>;
92 /// Saves received message to be picked up by the state machine on the next [`proceed`](Self::proceed) invocation
93 ///
94 /// This method should only be called if state machine returned [`ProceedResult::NeedsOneMoreMessage`] on previous
95 /// invocation of [`proceed`](Self::proceed) method. Calling this method when state machine did not request it
96 /// may return error.
97 ///
98 /// Calling this method must be followed up by calling [`proceed`](Self::proceed). Do not invoke this method
99 /// more than once in a row, even if you have available messages received from other parties. Instead, you
100 /// should call this method, then call `proceed`, and only if it returned [`ProceedResult::NeedsOneMoreMessage`]
101 /// you can call `received_msg` again.
102 fn received_msg(
103 &mut self,
104 msg: crate::Incoming<Self::Msg>,
105 ) -> Result<(), crate::Incoming<Self::Msg>>;
106}
107
108/// Tells why protocol execution stopped
109#[must_use = "ProceedResult must be used to correctly carry out the state machine"]
110pub enum ProceedResult<O, M> {
111 /// Protocol needs provided message to be sent
112 SendMsg(crate::Outgoing<M>),
113 /// Protocol needs one more message to be received
114 ///
115 /// After the state machine requested one more message, the next call to the state machine must
116 /// be [`StateMachine::received_msg`].
117 NeedsOneMoreMessage,
118 /// Protocol is finished
119 Output(O),
120 /// Protocol yielded the execution
121 ///
122 /// Protocol may yield at any point by calling `AsyncRuntime::yield_now`. Main motivation
123 /// for yielding is to break a long computation into smaller parts, so proceeding state
124 /// machine doesn't take too long.
125 ///
126 /// When protocol yields, you can resume the execution by calling [`proceed`](StateMachine::proceed)
127 /// immediately.
128 Yielded,
129 /// State machine failed to carry out the protocol
130 ///
131 /// Error likely means that either state machine is misused (e.g. when [`proceed`](StateMachine::proceed)
132 /// is called after protocol is finished) or protocol implementation is not supported by state machine
133 /// executor (e.g. it polls unknown future).
134 Error(ExecutionError),
135}
136
137impl<O, M> core::fmt::Debug for ProceedResult<O, M> {
138 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
139 match self {
140 ProceedResult::SendMsg(_) => f.write_str("SendMsg"),
141 ProceedResult::NeedsOneMoreMessage => f.write_str("NeedsOneMoreMessage"),
142 ProceedResult::Output(_) => f.write_str("Output"),
143 ProceedResult::Yielded => f.write_str("Yielded"),
144 ProceedResult::Error(_) => f.write_str("Error"),
145 }
146 }
147}
148
149/// Error type which indicates that state machine failed to carry out the protocol
150#[derive(Debug, thiserror::Error)]
151#[error(transparent)]
152pub struct ExecutionError(Reason);
153
154#[derive(Debug, thiserror::Error)]
155enum Reason {
156 #[error("resuming state machine when protocol is already finished")]
157 Exhausted,
158 #[error("protocol polls unknown (unsupported) future")]
159 PollingUnknownFuture,
160}
161
162impl<O, M> From<Reason> for ProceedResult<O, M> {
163 fn from(err: Reason) -> Self {
164 ProceedResult::Error(ExecutionError(err))
165 }
166}
167impl From<Reason> for ExecutionError {
168 fn from(err: Reason) -> Self {
169 ExecutionError(err)
170 }
171}
172
173struct StateMachineImpl<O, M, F: Future<Output = O>> {
174 shared_state: shared_state::SharedStateRef<M>,
175 exhausted: bool,
176 future: core::pin::Pin<alloc::boxed::Box<F>>,
177}
178
179impl<O, M, F> StateMachine for StateMachineImpl<O, M, F>
180where
181 F: Future<Output = O>,
182{
183 type Output = O;
184 type Msg = M;
185
186 fn proceed(&mut self) -> ProceedResult<Self::Output, Self::Msg> {
187 if self.exhausted {
188 return Reason::Exhausted.into();
189 }
190 let future = self.future.as_mut();
191 let waker = noop_waker::noop_waker();
192 let mut cx = core::task::Context::from_waker(&waker);
193 match future.poll(&mut cx) {
194 Poll::Ready(output) => {
195 self.exhausted = true;
196 ProceedResult::Output(output)
197 }
198 Poll::Pending => {
199 // underlying future may `await` only on either:
200 // 1. Flushing outgoing message
201 // 2. Waiting for incoming message
202 // 3. Yielding
203
204 // Check if it's flushing outgoing message:
205 if let Some(outgoing_msg) = self.shared_state.executor_takes_outgoing_msg() {
206 return ProceedResult::SendMsg(outgoing_msg);
207 }
208
209 // Check if it's waiting for a new message
210 if self.shared_state.protocol_wants_more_messages() {
211 return ProceedResult::NeedsOneMoreMessage;
212 }
213
214 // Check if protocol yielded
215 if self.shared_state.executor_reads_and_resets_yielded_flag() {
216 return ProceedResult::Yielded;
217 }
218
219 // If none of above conditions are met, then protocol is polling
220 // a future which we do not recognize
221 Reason::PollingUnknownFuture.into()
222 }
223 }
224 }
225
226 fn received_msg(&mut self, msg: crate::Incoming<Self::Msg>) -> Result<(), crate::Incoming<M>> {
227 self.shared_state.executor_received_msg(msg)
228 }
229}
230
231/// MpcParty instantiated with state machine implementation of delivery and async runtime
232pub type MpcParty<M> = crate::MpcParty<M, Delivery<M>, Runtime<M>>;
233
234/// Wraps the protocol and provides sync API to execute it
235///
236/// Protocol is an async function that takes [`MpcParty`] as input. `MpcParty` contains
237/// channels (of incoming and outgoing messages) that protocol is expected to use, and
238/// a [`Runtime`]. Protocol is only allowed to `.await` on futures provided in `MpcParty`,
239/// such as polling next message from provided steam of incoming messages. If protocol
240/// polls an unknown future, executor won't know what to do with that, the protocol will
241/// be aborted and error returned.
242pub fn wrap_protocol<'a, M, F>(
243 protocol: impl FnOnce(MpcParty<M>) -> F,
244) -> impl StateMachine<Output = F::Output, Msg = M> + 'a
245where
246 F: Future + 'a,
247 M: ProtocolMsg + 'static,
248{
249 let shared_state = shared_state::SharedStateRef::new();
250 let delivery = Delivery::new(shared_state.clone());
251 let runtime = Runtime::new(shared_state.clone());
252
253 let future = protocol(crate::mpc::connected(delivery).with_runtime(runtime));
254 let future = alloc::boxed::Box::pin(future);
255
256 StateMachineImpl {
257 shared_state,
258 exhausted: false,
259 future,
260 }
261}