1use ark_ec::CurveGroup;
5use async_trait::async_trait;
6use futures::{Future, Sink, Stream};
7use quinn::{Endpoint, RecvStream, SendStream};
8use std::{
9 marker::PhantomData,
10 net::SocketAddr,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tracing::log;
15
16use crate::{
17 error::{MpcNetworkError, SetupError},
18 PARTY0,
19};
20
21use super::{config, stream_buffer::BufferWithCursor, MpcNetwork, NetworkOutbound, PartyId};
22
/// Width in bytes of a `u64`; this is the size of the length prefix that
/// precedes every message on the wire.
const BYTES_PER_U64: usize = std::mem::size_of::<u64>();

/// Error text reported when the receive stream yields no more data while a
/// read is still incomplete.
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";
/// Error text reported when the bytes read for a length prefix cannot be
/// converted into a `u64`.
const ERR_READ_MESSAGE_LENGTH: &str = "error reading message length from stream";
/// Error text reported when a send is started while a previous message is
/// still waiting to be flushed.
const ERR_SEND_BUFFER_FULL: &str = "send buffer full";
36
37pub struct QuicTwoPartyNet<C: CurveGroup> {
43 party_id: PartyId,
45 connected: bool,
47 local_addr: SocketAddr,
49 peer_addr: SocketAddr,
51 buffered_message_length: Option<u64>,
58 buffered_inbound: Option<BufferWithCursor>,
65 buffered_outbound: Option<BufferWithCursor>,
67 send_stream: Option<SendStream>,
69 recv_stream: Option<RecvStream>,
71 _phantom: PhantomData<C>,
73}
74
75#[allow(clippy::redundant_closure)] impl<'a, C: CurveGroup> QuicTwoPartyNet<C> {
77 pub fn new(party_id: PartyId, local_addr: SocketAddr, peer_addr: SocketAddr) -> Self {
79 Self {
81 party_id,
82 local_addr,
83 peer_addr,
84 connected: false,
85 buffered_message_length: None,
86 buffered_inbound: None,
87 buffered_outbound: None,
88 send_stream: None,
89 recv_stream: None,
90 _phantom: PhantomData,
91 }
92 }
93
94 fn local_party0(&self) -> bool {
96 self.party_id == PARTY0
97 }
98
99 fn assert_connected(&self) -> Result<(), MpcNetworkError> {
101 if self.connected {
102 Ok(())
103 } else {
104 Err(MpcNetworkError::NetworkUninitialized)
105 }
106 }
107
108 pub async fn connect(&mut self) -> Result<(), MpcNetworkError> {
110 let (client_config, server_config) =
112 config::build_configs().map_err(|err| MpcNetworkError::ConnectionSetupError(err))?;
113
114 let mut local_endpoint = Endpoint::server(server_config, self.local_addr).map_err(|e| {
116 log::error!("error setting up quinn server: {e:?}");
117 MpcNetworkError::ConnectionSetupError(SetupError::ServerSetupError)
118 })?;
119 local_endpoint.set_default_client_config(client_config);
120
121 let connection = {
123 if self.local_party0() {
124 local_endpoint
125 .connect(self.peer_addr, config::SERVER_NAME)
126 .map_err(|err| {
127 log::error!("error setting up quic endpoint connection: {err}");
128 MpcNetworkError::ConnectionSetupError(SetupError::ConnectError(err))
129 })?
130 .await
131 .map_err(|err| {
132 log::error!("error connecting to the remote quic endpoint: {err}");
133 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
134 })?
135 } else {
136 local_endpoint
137 .accept()
138 .await
139 .ok_or_else(|| {
140 log::error!("no incoming connection while awaiting quic endpoint");
141 MpcNetworkError::ConnectionSetupError(SetupError::NoIncomingConnection)
142 })?
143 .await
144 .map_err(|err| {
145 log::error!("error while establishing remote connection as listener");
146 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
147 })?
148 }
149 };
150
151 let (send, recv) = {
153 if self.local_party0() {
154 connection.open_bi().await.map_err(|err| {
155 log::error!("error opening bidirectional stream: {err}");
156 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
157 })?
158 } else {
159 connection.accept_bi().await.map_err(|err| {
160 log::error!("error accepting bidirectional stream: {err}");
161 MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
162 })?
163 }
164 };
165
166 self.connected = true;
168 self.send_stream = Some(send);
169 self.recv_stream = Some(recv);
170
171 Ok(())
172 }
173
174 async fn write_bytes(&mut self) -> Result<(), MpcNetworkError> {
176 if self.buffered_outbound.is_none() {
178 return Ok(());
179 }
180
181 let buf = self.buffered_outbound.as_mut().unwrap();
183 while !buf.is_depleted() {
184 let bytes_written = self
185 .send_stream
186 .as_mut()
187 .unwrap()
188 .write(buf.get_remaining())
189 .await
190 .map_err(|e| MpcNetworkError::SendError(e.to_string()))?;
191
192 buf.advance_cursor(bytes_written);
193 }
194
195 self.buffered_outbound = None;
196 Ok(())
197 }
198
199 async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
201 if self.buffered_inbound.is_none() {
203 self.buffered_inbound = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
204 }
205
206 let read_buffer = self.buffered_inbound.as_mut().unwrap();
208 while !read_buffer.is_depleted() {
209 let bytes_read = self
210 .recv_stream
211 .as_mut()
212 .unwrap()
213 .read(read_buffer.get_remaining())
214 .await
215 .map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
216 .ok_or(MpcNetworkError::RecvError(
217 ERR_STREAM_FINISHED_EARLY.to_string(),
218 ))?;
219
220 read_buffer.advance_cursor(bytes_read);
221 }
222
223 Ok(self.buffered_inbound.take().unwrap().into_vec())
225 }
226
227 async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
229 let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
230 Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
231 |_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
232 )?))
233 }
234
235 async fn receive_message(&mut self) -> Result<NetworkOutbound<C>, MpcNetworkError> {
237 if self.buffered_message_length.is_none() {
239 self.buffered_message_length = Some(self.read_message_length().await?);
240 }
241
242 let len = self.buffered_message_length.unwrap();
244 let bytes = self.read_bytes(len as usize).await?;
245
246 self.buffered_message_length = None;
249
250 serde_json::from_slice(&bytes)
252 .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))
253 }
254}
255
256#[async_trait]
257impl<C: CurveGroup> MpcNetwork<C> for QuicTwoPartyNet<C>
258where
259 C: Unpin,
260{
261 fn party_id(&self) -> PartyId {
262 self.party_id
263 }
264
265 async fn close(&mut self) -> Result<(), MpcNetworkError> {
266 self.assert_connected()?;
267
268 self.send_stream
269 .as_mut()
270 .unwrap()
271 .finish()
272 .await
273 .map_err(|_| MpcNetworkError::ConnectionTeardownError)
274 }
275}
276
277impl<C: CurveGroup> Stream for QuicTwoPartyNet<C>
278where
279 C: Unpin,
280{
281 type Item = Result<NetworkOutbound<C>, MpcNetworkError>;
282
283 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
284 Box::pin(self.get_mut().receive_message())
285 .as_mut()
286 .poll(cx)
287 .map(Some)
288 }
289}
290
291impl<C: CurveGroup> Sink<NetworkOutbound<C>> for QuicTwoPartyNet<C>
292where
293 C: Unpin,
294{
295 type Error = MpcNetworkError;
296
297 fn start_send(self: Pin<&mut Self>, msg: NetworkOutbound<C>) -> Result<(), Self::Error> {
298 if !self.connected {
299 return Err(MpcNetworkError::NetworkUninitialized);
300 }
301
302 if self.buffered_outbound.is_some() {
304 return Err(MpcNetworkError::SendError(ERR_SEND_BUFFER_FULL.to_string()));
305 }
306
307 let bytes = serde_json::to_vec(&msg)
309 .map_err(|err| MpcNetworkError::SerializationError(err.to_string()))?;
310 let mut payload = (bytes.len() as u64).to_le_bytes().to_vec();
311 payload.extend_from_slice(&bytes);
312
313 self.get_mut().buffered_outbound = Some(BufferWithCursor::new(payload));
314 Ok(())
315 }
316
317 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318 Box::pin(self.write_bytes()).as_mut().poll(cx)
320 }
321
322 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
323 self.poll_flush(cx)
325 }
326
327 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
328 self.poll_flush(cx)
330 }
331}