ark_mpc/network/
quic.rs

//! Defines the central implementation of an `MpcNetwork` over the QUIC
//! transport.
3
4use ark_ec::CurveGroup;
5use async_trait::async_trait;
6use futures::{Future, Sink, Stream};
7use quinn::{Endpoint, RecvStream, SendStream};
8use std::{
9    marker::PhantomData,
10    net::SocketAddr,
11    pin::Pin,
12    task::{Context, Poll},
13};
14use tracing::log;
15
16use crate::{
17    error::{MpcNetworkError, SetupError},
18    PARTY0,
19};
20
21use super::{config, stream_buffer::BufferWithCursor, MpcNetwork, NetworkOutbound, PartyId};
22
// -------------
// | Constants |
// -------------

/// The size in bytes of the little-endian `u64` length prefix that precedes
/// every message written to (and read from) the QUIC stream
const BYTES_PER_U64: usize = std::mem::size_of::<u64>();

/// Error thrown when the peer's stream finishes before a full message arrived
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";
/// Error message emitted when reading a message length from the stream fails
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
/// Error message emitted when the send `Sink` is not ready, i.e. `start_send`
/// was called while a previously staged message had not yet been flushed
const ERR_SEND_BUFFER_FULL: &str = "send buffer full";
36
// -----------------------
// | Quic Implementation |
// -----------------------

/// Implements an MpcNetwork on top of QUIC
///
/// The two parties share a single bidirectional QUIC stream. Each message is
/// framed as a little-endian `u64` length prefix followed by the
/// `serde_json`-encoded payload.
///
/// NOTE(review): fields are dropped in declaration order, so reordering them
/// changes the order in which the quinn stream handles are released.
pub struct QuicTwoPartyNet<C: CurveGroup> {
    /// The identifier of the local party; party 0 dials the peer during
    /// `connect`, any other party listens for the incoming connection
    party_id: PartyId,
    /// Whether the network has been bootstrapped yet; set to `true` by
    /// `connect` at the same time `send_stream` and `recv_stream` are populated
    connected: bool,
    /// The local address the QUIC endpoint binds to
    local_addr: SocketAddr,
    /// The address of the single counterparty in the MPC
    peer_addr: SocketAddr,
    /// A buffered message length read from the stream
    ///
    /// In the case that the whole message is not available yet, reads may block
    /// and the `read_message` future may be cancelled by the executor.
    /// We buffer the message length to avoid re-reading the message length
    /// incorrectly from the stream
    buffered_message_length: Option<u64>,
    /// A buffered partial message read from the stream
    ///
    /// This buffer exists to provide cancellation safety to a `read` future as
    /// the underlying `quinn` stream is not cancellation safe, i.e. if a
    /// `ReadBuf` future is dropped, the buffer is dropped with it and the
    /// partially read data is skipped
    buffered_inbound: Option<BufferWithCursor>,
    /// A buffered partial message written to the stream
    ///
    /// Holds at most one framed message staged by `start_send`; the cursor
    /// records how much of it has already been written so a cancelled flush
    /// can resume without re-sending bytes
    buffered_outbound: Option<BufferWithCursor>,
    /// The send side of the bidirectional stream (`None` until `connect`)
    send_stream: Option<SendStream>,
    /// The receive side of the bidirectional stream (`None` until `connect`)
    recv_stream: Option<RecvStream>,
    /// The phantom on the curve group
    _phantom: PhantomData<C>,
}
74
75#[allow(clippy::redundant_closure)] // For readability of error handling
76impl<'a, C: CurveGroup> QuicTwoPartyNet<C> {
77    /// Create a new network, do not connect the network yet
78    pub fn new(party_id: PartyId, local_addr: SocketAddr, peer_addr: SocketAddr) -> Self {
79        // Construct the QUIC net
80        Self {
81            party_id,
82            local_addr,
83            peer_addr,
84            connected: false,
85            buffered_message_length: None,
86            buffered_inbound: None,
87            buffered_outbound: None,
88            send_stream: None,
89            recv_stream: None,
90            _phantom: PhantomData,
91        }
92    }
93
94    /// Returns true if the local party is party 0
95    fn local_party0(&self) -> bool {
96        self.party_id == PARTY0
97    }
98
99    /// Returns an error if the network is not connected
100    fn assert_connected(&self) -> Result<(), MpcNetworkError> {
101        if self.connected {
102            Ok(())
103        } else {
104            Err(MpcNetworkError::NetworkUninitialized)
105        }
106    }
107
108    /// Establishes connections to the peer
109    pub async fn connect(&mut self) -> Result<(), MpcNetworkError> {
110        // Build the client and server configs
111        let (client_config, server_config) =
112            config::build_configs().map_err(|err| MpcNetworkError::ConnectionSetupError(err))?;
113
114        // Create a quinn server
115        let mut local_endpoint = Endpoint::server(server_config, self.local_addr).map_err(|e| {
116            log::error!("error setting up quinn server: {e:?}");
117            MpcNetworkError::ConnectionSetupError(SetupError::ServerSetupError)
118        })?;
119        local_endpoint.set_default_client_config(client_config);
120
121        // The king dials the peer who awaits connection
122        let connection = {
123            if self.local_party0() {
124                local_endpoint
125                    .connect(self.peer_addr, config::SERVER_NAME)
126                    .map_err(|err| {
127                        log::error!("error setting up quic endpoint connection: {err}");
128                        MpcNetworkError::ConnectionSetupError(SetupError::ConnectError(err))
129                    })?
130                    .await
131                    .map_err(|err| {
132                        log::error!("error connecting to the remote quic endpoint: {err}");
133                        MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
134                    })?
135            } else {
136                local_endpoint
137                    .accept()
138                    .await
139                    .ok_or_else(|| {
140                        log::error!("no incoming connection while awaiting quic endpoint");
141                        MpcNetworkError::ConnectionSetupError(SetupError::NoIncomingConnection)
142                    })?
143                    .await
144                    .map_err(|err| {
145                        log::error!("error while establishing remote connection as listener");
146                        MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
147                    })?
148            }
149        };
150
151        // King opens a bidirectional stream on top of the connection
152        let (send, recv) = {
153            if self.local_party0() {
154                connection.open_bi().await.map_err(|err| {
155                    log::error!("error opening bidirectional stream: {err}");
156                    MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
157                })?
158            } else {
159                connection.accept_bi().await.map_err(|err| {
160                    log::error!("error accepting bidirectional stream: {err}");
161                    MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
162                })?
163            }
164        };
165
166        // Update MpcNet state
167        self.connected = true;
168        self.send_stream = Some(send);
169        self.recv_stream = Some(recv);
170
171        Ok(())
172    }
173
174    /// Write the current buffer to the stream
175    async fn write_bytes(&mut self) -> Result<(), MpcNetworkError> {
176        // If no pending writes are available, return
177        if self.buffered_outbound.is_none() {
178            return Ok(());
179        }
180
181        // While the outbound buffer has elements remaining, write them
182        let buf = self.buffered_outbound.as_mut().unwrap();
183        while !buf.is_depleted() {
184            let bytes_written = self
185                .send_stream
186                .as_mut()
187                .unwrap()
188                .write(buf.get_remaining())
189                .await
190                .map_err(|e| MpcNetworkError::SendError(e.to_string()))?;
191
192            buf.advance_cursor(bytes_written);
193        }
194
195        self.buffered_outbound = None;
196        Ok(())
197    }
198
199    /// Read exactly `n` bytes from the stream
200    async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
201        // Allocate a buffer for the next message if one does not already exist
202        if self.buffered_inbound.is_none() {
203            self.buffered_inbound = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
204        }
205
206        // Read until the buffer is full
207        let read_buffer = self.buffered_inbound.as_mut().unwrap();
208        while !read_buffer.is_depleted() {
209            let bytes_read = self
210                .recv_stream
211                .as_mut()
212                .unwrap()
213                .read(read_buffer.get_remaining())
214                .await
215                .map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
216                .ok_or(MpcNetworkError::RecvError(
217                    ERR_STREAM_FINISHED_EARLY.to_string(),
218                ))?;
219
220            read_buffer.advance_cursor(bytes_read);
221        }
222
223        // Take ownership of the buffer, and reset the buffered message to `None`
224        Ok(self.buffered_inbound.take().unwrap().into_vec())
225    }
226
227    /// Read a message length from the stream
228    async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
229        let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
230        Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
231            |_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
232        )?))
233    }
234
235    /// Receive a message from the peer
236    async fn receive_message(&mut self) -> Result<NetworkOutbound<C>, MpcNetworkError> {
237        // Read the message length from the buffer if available
238        if self.buffered_message_length.is_none() {
239            self.buffered_message_length = Some(self.read_message_length().await?);
240        }
241
242        // Read the data from the stream
243        let len = self.buffered_message_length.unwrap();
244        let bytes = self.read_bytes(len as usize).await?;
245
246        // Reset the message length buffer after the data has been pulled from the
247        // stream
248        self.buffered_message_length = None;
249
250        // Deserialize the message
251        serde_json::from_slice(&bytes)
252            .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))
253    }
254}
255
256#[async_trait]
257impl<C: CurveGroup> MpcNetwork<C> for QuicTwoPartyNet<C>
258where
259    C: Unpin,
260{
261    fn party_id(&self) -> PartyId {
262        self.party_id
263    }
264
265    async fn close(&mut self) -> Result<(), MpcNetworkError> {
266        self.assert_connected()?;
267
268        self.send_stream
269            .as_mut()
270            .unwrap()
271            .finish()
272            .await
273            .map_err(|_| MpcNetworkError::ConnectionTeardownError)
274    }
275}
276
277impl<C: CurveGroup> Stream for QuicTwoPartyNet<C>
278where
279    C: Unpin,
280{
281    type Item = Result<NetworkOutbound<C>, MpcNetworkError>;
282
283    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
284        Box::pin(self.get_mut().receive_message())
285            .as_mut()
286            .poll(cx)
287            .map(Some)
288    }
289}
290
291impl<C: CurveGroup> Sink<NetworkOutbound<C>> for QuicTwoPartyNet<C>
292where
293    C: Unpin,
294{
295    type Error = MpcNetworkError;
296
297    fn start_send(self: Pin<&mut Self>, msg: NetworkOutbound<C>) -> Result<(), Self::Error> {
298        if !self.connected {
299            return Err(MpcNetworkError::NetworkUninitialized);
300        }
301
302        // Must call `poll_flush` before calling `start_send` again
303        if self.buffered_outbound.is_some() {
304            return Err(MpcNetworkError::SendError(ERR_SEND_BUFFER_FULL.to_string()));
305        }
306
307        // Serialize the message and buffer it for writing
308        let bytes = serde_json::to_vec(&msg)
309            .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))?;
310        let mut payload = (bytes.len() as u64).to_le_bytes().to_vec();
311        payload.extend_from_slice(&bytes);
312
313        self.get_mut().buffered_outbound = Some(BufferWithCursor::new(payload));
314        Ok(())
315    }
316
317    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318        // Poll the write future
319        Box::pin(self.write_bytes()).as_mut().poll(cx)
320    }
321
322    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
323        // The network is always ready to send
324        self.poll_flush(cx)
325    }
326
327    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
328        // The network is always ready to close
329        self.poll_flush(cx)
330    }
331}