mpc_stark/network/
quic.rs1use async_trait::async_trait;
4use futures::{Future, Sink, Stream};
5use quinn::{Endpoint, RecvStream, SendStream};
6use std::{
7 net::SocketAddr,
8 pin::Pin,
9 task::{Context, Poll},
10};
11use tracing::log;
12
13use crate::{
14 error::{MpcNetworkError, SetupError},
15 PARTY0,
16};
17
18use super::{config, stream_buffer::BufferWithCursor, MpcNetwork, NetworkOutbound, PartyId};
19
/// Width in bytes of the little-endian `u64` length prefix that frames every message.
const BYTES_PER_U64: usize = std::mem::size_of::<u64>();

/// Error text used when the receive stream ends before a read buffer has been filled.
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";
/// Error text used when the bytes read for a length prefix cannot be turned into a `u64`.
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
/// Error text used when a send is started while a previous message is still unflushed.
const ERR_SEND_BUFFER_FULL: &str = "send buffer full";
33
34pub struct QuicTwoPartyNet {
40 party_id: PartyId,
42 connected: bool,
44 local_addr: SocketAddr,
46 peer_addr: SocketAddr,
48 buffered_message_length: Option<u64>,
55 buffered_inbound: Option<BufferWithCursor>,
61 buffered_outbound: Option<BufferWithCursor>,
63 send_stream: Option<SendStream>,
65 recv_stream: Option<RecvStream>,
67}
68
69#[allow(clippy::redundant_closure)] impl<'a> QuicTwoPartyNet {
71 pub fn new(party_id: PartyId, local_addr: SocketAddr, peer_addr: SocketAddr) -> Self {
73 Self {
75 party_id,
76 local_addr,
77 peer_addr,
78 connected: false,
79 buffered_message_length: None,
80 buffered_inbound: None,
81 buffered_outbound: None,
82 send_stream: None,
83 recv_stream: None,
84 }
85 }
86
87 fn local_party0(&self) -> bool {
89 self.party_id() == PARTY0
90 }
91
92 fn assert_connected(&self) -> Result<(), MpcNetworkError> {
94 if self.connected {
95 Ok(())
96 } else {
97 Err(MpcNetworkError::NetworkUninitialized)
98 }
99 }
100
101 pub async fn connect(&mut self) -> Result<(), MpcNetworkError> {
103 let (client_config, server_config) =
105 config::build_configs().map_err(|err| MpcNetworkError::ConnectionSetupError(err))?;
106
107 let mut local_endpoint = Endpoint::server(server_config, self.local_addr).map_err(|e| {
109 log::error!("error setting up quinn server: {e:?}");
110 MpcNetworkError::ConnectionSetupError(SetupError::ServerSetupError)
111 })?;
112 local_endpoint.set_default_client_config(client_config);
113
114 let connection = {
116 if self.local_party0() {
117 local_endpoint
118 .connect(self.peer_addr, config::SERVER_NAME)
119 .map_err(|err| {
120 log::error!("error setting up quic endpoint connection: {err}");
121 MpcNetworkError::ConnectionSetupError(SetupError::ConnectError(err))
122 })?
123 .await
124 .map_err(|err| {
125 log::error!("error connecting to the remote quic endpoint: {err}");
126 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
127 })?
128 } else {
129 local_endpoint
130 .accept()
131 .await
132 .ok_or_else(|| {
133 log::error!("no incoming connection while awaiting quic endpoint");
134 MpcNetworkError::ConnectionSetupError(SetupError::NoIncomingConnection)
135 })?
136 .await
137 .map_err(|err| {
138 log::error!("error while establishing remote connection as listener");
139 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
140 })?
141 }
142 };
143
144 let (send, recv) = {
146 if self.local_party0() {
147 connection.open_bi().await.map_err(|err| {
148 log::error!("error opening bidirectional stream: {err}");
149 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
150 })?
151 } else {
152 connection.accept_bi().await.map_err(|err| {
153 log::error!("error accepting bidirectional stream: {err}");
154 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
155 })?
156 }
157 };
158
159 self.connected = true;
161 self.send_stream = Some(send);
162 self.recv_stream = Some(recv);
163
164 Ok(())
165 }
166
167 async fn write_bytes(&mut self) -> Result<(), MpcNetworkError> {
169 if self.buffered_outbound.is_none() {
171 return Ok(());
172 }
173
174 let buf = self.buffered_outbound.as_mut().unwrap();
176 while !buf.is_depleted() {
177 let bytes_written = self
178 .send_stream
179 .as_mut()
180 .unwrap()
181 .write(buf.get_remaining())
182 .await
183 .map_err(|e| MpcNetworkError::SendError(e.to_string()))?;
184
185 buf.advance_cursor(bytes_written);
186 }
187
188 self.buffered_outbound = None;
189 Ok(())
190 }
191
192 async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
194 if self.buffered_inbound.is_none() {
196 self.buffered_inbound = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
197 }
198
199 let read_buffer = self.buffered_inbound.as_mut().unwrap();
201 while !read_buffer.is_depleted() {
202 let bytes_read = self
203 .recv_stream
204 .as_mut()
205 .unwrap()
206 .read(read_buffer.get_remaining())
207 .await
208 .map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
209 .ok_or(MpcNetworkError::RecvError(
210 ERR_STREAM_FINISHED_EARLY.to_string(),
211 ))?;
212
213 read_buffer.advance_cursor(bytes_read);
214 }
215
216 Ok(self.buffered_inbound.take().unwrap().into_vec())
218 }
219
220 async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
222 let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
223 Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
224 |_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
225 )?))
226 }
227
228 async fn receive_message(&mut self) -> Result<NetworkOutbound, MpcNetworkError> {
230 if self.buffered_message_length.is_none() {
232 self.buffered_message_length = Some(self.read_message_length().await?);
233 }
234
235 let len = self.buffered_message_length.unwrap();
237 let bytes = self.read_bytes(len as usize).await?;
238
239 self.buffered_message_length = None;
241
242 serde_json::from_slice(&bytes)
244 .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))
245 }
246}
247
248#[async_trait]
249impl MpcNetwork for QuicTwoPartyNet {
250 fn party_id(&self) -> PartyId {
251 self.party_id
252 }
253
254 async fn close(&mut self) -> Result<(), MpcNetworkError> {
255 self.assert_connected()?;
256
257 self.send_stream
258 .as_mut()
259 .unwrap()
260 .finish()
261 .await
262 .map_err(|_| MpcNetworkError::ConnectionTeardownError)
263 }
264}
265
266impl Stream for QuicTwoPartyNet {
267 type Item = Result<NetworkOutbound, MpcNetworkError>;
268
269 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270 Box::pin(self.receive_message()).as_mut().poll(cx).map(Some)
271 }
272}
273
274impl Sink<NetworkOutbound> for QuicTwoPartyNet {
275 type Error = MpcNetworkError;
276
277 fn start_send(mut self: Pin<&mut Self>, msg: NetworkOutbound) -> Result<(), Self::Error> {
278 if !self.connected {
279 return Err(MpcNetworkError::NetworkUninitialized);
280 }
281
282 if self.buffered_outbound.is_some() {
284 return Err(MpcNetworkError::SendError(ERR_SEND_BUFFER_FULL.to_string()));
285 }
286
287 let bytes = serde_json::to_vec(&msg)
289 .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))?;
290 let mut payload = (bytes.len() as u64).to_le_bytes().to_vec();
291 payload.extend_from_slice(&bytes);
292
293 self.buffered_outbound = Some(BufferWithCursor::new(payload));
294 Ok(())
295 }
296
297 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
298 Box::pin(self.write_bytes()).as_mut().poll(cx)
300 }
301
302 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
303 self.poll_flush(cx)
305 }
306
307 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
308 self.poll_flush(cx)
310 }
311}