mpc_stark/network/
quic.rs

1//! Defines the central implementation of an `MpcNetwork` over the QUIC transport
2
3use async_trait::async_trait;
4use futures::{Future, Sink, Stream};
5use quinn::{Endpoint, RecvStream, SendStream};
6use std::{
7    net::SocketAddr,
8    pin::Pin,
9    task::{Context, Poll},
10};
11use tracing::log;
12
13use crate::{
14    error::{MpcNetworkError, SetupError},
15    PARTY0,
16};
17
18use super::{config, stream_buffer::BufferWithCursor, MpcNetwork, NetworkOutbound, PartyId};
19
// -------------
// | Constants |
// -------------

/// The number of bytes in a u64, i.e. the width of the little-endian length prefix that
/// frames every message on the stream
const BYTES_PER_U64: usize = std::mem::size_of::<u64>();

/// Error message emitted when the peer finishes the stream before a full read completes
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";
/// Error message emitted when the bytes read for a message length cannot be decoded as a u64
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
/// Error message emitted by `start_send` when a previously queued message has not yet been
/// flushed, i.e. the send `Sink` is not ready
const ERR_SEND_BUFFER_FULL: &str = "send buffer full";
33
34// -----------------------
35// | Quic Implementation |
36// -----------------------
37
38/// Implements an MpcNetwork on top of QUIC
39pub struct QuicTwoPartyNet {
40    /// The index of the local party in the participants
41    party_id: PartyId,
42    /// Whether the network has been bootstrapped yet
43    connected: bool,
44    /// The address of the local peer
45    local_addr: SocketAddr,
46    /// Addresses of the counterparties in the MPC
47    peer_addr: SocketAddr,
48    /// A buffered message length read from the stream
49    ///
50    /// In the case that the whole message is not available yet, reads may block
51    /// and the `read_message` future may be cancelled by the executor.
52    /// We buffer the message length to avoid re-reading the message length incorrectly from
53    /// the stream
54    buffered_message_length: Option<u64>,
55    /// A buffered partial message read from the stream
56    ///
57    /// This buffer exists to provide cancellation safety to a `read` future as the underlying `quinn`
58    /// stream is not cancellation safe, i.e. if a `ReadBuf` future is dropped, the buffer is dropped with
59    /// it and the partially read data is skipped
60    buffered_inbound: Option<BufferWithCursor>,
61    /// A buffered partial message written to the stream
62    buffered_outbound: Option<BufferWithCursor>,
63    /// The send side of the bidirectional stream
64    send_stream: Option<SendStream>,
65    /// The receive side of the bidirectional stream
66    recv_stream: Option<RecvStream>,
67}
68
69#[allow(clippy::redundant_closure)] // For readability of error handling
70impl<'a> QuicTwoPartyNet {
71    /// Create a new network, do not connect the network yet
72    pub fn new(party_id: PartyId, local_addr: SocketAddr, peer_addr: SocketAddr) -> Self {
73        // Construct the QUIC net
74        Self {
75            party_id,
76            local_addr,
77            peer_addr,
78            connected: false,
79            buffered_message_length: None,
80            buffered_inbound: None,
81            buffered_outbound: None,
82            send_stream: None,
83            recv_stream: None,
84        }
85    }
86
87    /// Returns true if the local party is party 0
88    fn local_party0(&self) -> bool {
89        self.party_id() == PARTY0
90    }
91
92    /// Returns an error if the network is not connected
93    fn assert_connected(&self) -> Result<(), MpcNetworkError> {
94        if self.connected {
95            Ok(())
96        } else {
97            Err(MpcNetworkError::NetworkUninitialized)
98        }
99    }
100
101    /// Establishes connections to the peer
102    pub async fn connect(&mut self) -> Result<(), MpcNetworkError> {
103        // Build the client and server configs
104        let (client_config, server_config) =
105            config::build_configs().map_err(|err| MpcNetworkError::ConnectionSetupError(err))?;
106
107        // Create a quinn server
108        let mut local_endpoint = Endpoint::server(server_config, self.local_addr).map_err(|e| {
109            log::error!("error setting up quinn server: {e:?}");
110            MpcNetworkError::ConnectionSetupError(SetupError::ServerSetupError)
111        })?;
112        local_endpoint.set_default_client_config(client_config);
113
114        // The king dials the peer who awaits connection
115        let connection = {
116            if self.local_party0() {
117                local_endpoint
118                    .connect(self.peer_addr, config::SERVER_NAME)
119                    .map_err(|err| {
120                        log::error!("error setting up quic endpoint connection: {err}");
121                        MpcNetworkError::ConnectionSetupError(SetupError::ConnectError(err))
122                    })?
123                    .await
124                    .map_err(|err| {
125                        log::error!("error connecting to the remote quic endpoint: {err}");
126                        MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
127                    })?
128            } else {
129                local_endpoint
130                    .accept()
131                    .await
132                    .ok_or_else(|| {
133                        log::error!("no incoming connection while awaiting quic endpoint");
134                        MpcNetworkError::ConnectionSetupError(SetupError::NoIncomingConnection)
135                    })?
136                    .await
137                    .map_err(|err| {
138                        log::error!("error while establishing remote connection as listener");
139                        MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
140                    })?
141            }
142        };
143
144        // King opens a bidirectional stream on top of the connection
145        let (send, recv) = {
146            if self.local_party0() {
147                connection.open_bi().await.map_err(|err| {
148                    log::error!("error opening bidirectional stream: {err}");
149                    MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
150                })?
151            } else {
152                connection.accept_bi().await.map_err(|err| {
153                    log::error!("error accepting bidirectional stream: {err}");
154                    MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
155                })?
156            }
157        };
158
159        // Update MpcNet state
160        self.connected = true;
161        self.send_stream = Some(send);
162        self.recv_stream = Some(recv);
163
164        Ok(())
165    }
166
167    /// Write the current buffer to the stream
168    async fn write_bytes(&mut self) -> Result<(), MpcNetworkError> {
169        // If no pending writes are available, return
170        if self.buffered_outbound.is_none() {
171            return Ok(());
172        }
173
174        // While the outbound buffer has elements remaining, write them
175        let buf = self.buffered_outbound.as_mut().unwrap();
176        while !buf.is_depleted() {
177            let bytes_written = self
178                .send_stream
179                .as_mut()
180                .unwrap()
181                .write(buf.get_remaining())
182                .await
183                .map_err(|e| MpcNetworkError::SendError(e.to_string()))?;
184
185            buf.advance_cursor(bytes_written);
186        }
187
188        self.buffered_outbound = None;
189        Ok(())
190    }
191
192    /// Read exactly `n` bytes from the stream
193    async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
194        // Allocate a buffer for the next message if one does not already exist
195        if self.buffered_inbound.is_none() {
196            self.buffered_inbound = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
197        }
198
199        // Read until the buffer is full
200        let read_buffer = self.buffered_inbound.as_mut().unwrap();
201        while !read_buffer.is_depleted() {
202            let bytes_read = self
203                .recv_stream
204                .as_mut()
205                .unwrap()
206                .read(read_buffer.get_remaining())
207                .await
208                .map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
209                .ok_or(MpcNetworkError::RecvError(
210                    ERR_STREAM_FINISHED_EARLY.to_string(),
211                ))?;
212
213            read_buffer.advance_cursor(bytes_read);
214        }
215
216        // Take ownership of the buffer, and reset the buffered message to `None`
217        Ok(self.buffered_inbound.take().unwrap().into_vec())
218    }
219
220    /// Read a message length from the stream
221    async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
222        let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
223        Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
224            |_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
225        )?))
226    }
227
228    /// Receive a message from the peer
229    async fn receive_message(&mut self) -> Result<NetworkOutbound, MpcNetworkError> {
230        // Read the message length from the buffer if available
231        if self.buffered_message_length.is_none() {
232            self.buffered_message_length = Some(self.read_message_length().await?);
233        }
234
235        // Read the data from the stream
236        let len = self.buffered_message_length.unwrap();
237        let bytes = self.read_bytes(len as usize).await?;
238
239        // Reset the message length buffer after the data has been pulled from the stream
240        self.buffered_message_length = None;
241
242        // Deserialize the message
243        serde_json::from_slice(&bytes)
244            .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))
245    }
246}
247
248#[async_trait]
249impl MpcNetwork for QuicTwoPartyNet {
250    fn party_id(&self) -> PartyId {
251        self.party_id
252    }
253
254    async fn close(&mut self) -> Result<(), MpcNetworkError> {
255        self.assert_connected()?;
256
257        self.send_stream
258            .as_mut()
259            .unwrap()
260            .finish()
261            .await
262            .map_err(|_| MpcNetworkError::ConnectionTeardownError)
263    }
264}
265
266impl Stream for QuicTwoPartyNet {
267    type Item = Result<NetworkOutbound, MpcNetworkError>;
268
269    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270        Box::pin(self.receive_message()).as_mut().poll(cx).map(Some)
271    }
272}
273
274impl Sink<NetworkOutbound> for QuicTwoPartyNet {
275    type Error = MpcNetworkError;
276
277    fn start_send(mut self: Pin<&mut Self>, msg: NetworkOutbound) -> Result<(), Self::Error> {
278        if !self.connected {
279            return Err(MpcNetworkError::NetworkUninitialized);
280        }
281
282        // Must call `poll_flush` before calling `start_send` again
283        if self.buffered_outbound.is_some() {
284            return Err(MpcNetworkError::SendError(ERR_SEND_BUFFER_FULL.to_string()));
285        }
286
287        // Serialize the message and buffer it for writing
288        let bytes = serde_json::to_vec(&msg)
289            .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))?;
290        let mut payload = (bytes.len() as u64).to_le_bytes().to_vec();
291        payload.extend_from_slice(&bytes);
292
293        self.buffered_outbound = Some(BufferWithCursor::new(payload));
294        Ok(())
295    }
296
297    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
298        // Poll the write future
299        Box::pin(self.write_bytes()).as_mut().poll(cx)
300    }
301
302    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303        // The network is always ready to send
304        self.poll_flush(cx)
305    }
306
307    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
308        // The network is always ready to close
309        self.poll_flush(cx)
310    }
311}