use ark_ec::CurveGroup;
use async_trait::async_trait;
use futures::{Future, Sink, Stream};
use quinn::{Endpoint, RecvStream, SendStream};
use std::{
marker::PhantomData,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use tracing::log;
use crate::{
error::{MpcNetworkError, SetupError},
PARTY0,
};
use super::{config, stream_buffer::BufferWithCursor, MpcNetwork, NetworkOutbound, PartyId};
/// Width in bytes of the little-endian `u64` length prefix that precedes every message
const BYTES_PER_U64: usize = std::mem::size_of::<u64>();
/// Reported when the peer finishes its stream before a pending read could be filled
const ERR_STREAM_FINISHED_EARLY: &'static str = "stream finished early";
/// Reported when the bytes read for the length prefix cannot be decoded into a `u64`
const ERR_READ_MESSAGE_LENGTH: &'static str = "error reading message length from stream";
/// Reported when `start_send` is called while a previous message is still waiting to be written
const ERR_SEND_BUFFER_FULL: &'static str = "send buffer full";
/// A two-party MPC transport carried over a single bidirectional QUIC stream.
///
/// Each message on the wire is a little-endian `u64` length prefix followed by the
/// serde_json encoding of a `NetworkOutbound<C>`. The curve `C` appears only through
/// `PhantomData`; no curve values are stored.
pub struct QuicTwoPartyNet<C: CurveGroup> {
    // --- identity and addressing ---
    /// The local party's identifier; compared against `PARTY0` to pick dialer vs. listener
    party_id: PartyId,
    /// Address the local QUIC endpoint binds to
    local_addr: SocketAddr,
    /// Address of the remote party (dialed by party 0)
    peer_addr: SocketAddr,

    // --- connection state ---
    /// Set to `true` once `connect` has established the stream pair
    connected: bool,
    /// Outbound half of the bidirectional stream; `None` until connected
    send_stream: Option<SendStream>,
    /// Inbound half of the bidirectional stream; `None` until connected
    recv_stream: Option<RecvStream>,

    // --- in-flight buffers, kept on `self` so progress survives a dropped pending future ---
    /// Decoded length prefix of an inbound message whose body has not been fully read
    buffered_message_length: Option<u64>,
    /// Partially filled read buffer for the prefix or body currently being received
    buffered_inbound: Option<BufferWithCursor>,
    /// Length-prefixed message queued by `start_send` and not yet fully written
    buffered_outbound: Option<BufferWithCursor>,

    /// Marker binding this network to curve `C` without storing a value of it
    _phantom: PhantomData<C>,
}
#[allow(clippy::redundant_closure)] impl<'a, C: CurveGroup> QuicTwoPartyNet<C> {
pub fn new(party_id: PartyId, local_addr: SocketAddr, peer_addr: SocketAddr) -> Self {
Self {
party_id,
local_addr,
peer_addr,
connected: false,
buffered_message_length: None,
buffered_inbound: None,
buffered_outbound: None,
send_stream: None,
recv_stream: None,
_phantom: PhantomData,
}
}
fn local_party0(&self) -> bool {
self.party_id == PARTY0
}
fn assert_connected(&self) -> Result<(), MpcNetworkError> {
if self.connected {
Ok(())
} else {
Err(MpcNetworkError::NetworkUninitialized)
}
}
pub async fn connect(&mut self) -> Result<(), MpcNetworkError> {
let (client_config, server_config) =
config::build_configs().map_err(|err| MpcNetworkError::ConnectionSetupError(err))?;
let mut local_endpoint = Endpoint::server(server_config, self.local_addr).map_err(|e| {
log::error!("error setting up quinn server: {e:?}");
MpcNetworkError::ConnectionSetupError(SetupError::ServerSetupError)
})?;
local_endpoint.set_default_client_config(client_config);
let connection = {
if self.local_party0() {
local_endpoint
.connect(self.peer_addr, config::SERVER_NAME)
.map_err(|err| {
log::error!("error setting up quic endpoint connection: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectError(err))
})?
.await
.map_err(|err| {
log::error!("error connecting to the remote quic endpoint: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
} else {
local_endpoint
.accept()
.await
.ok_or_else(|| {
log::error!("no incoming connection while awaiting quic endpoint");
MpcNetworkError::ConnectionSetupError(SetupError::NoIncomingConnection)
})?
.await
.map_err(|err| {
log::error!("error while establishing remote connection as listener");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
}
};
let (send, recv) = {
if self.local_party0() {
connection.open_bi().await.map_err(|err| {
log::error!("error opening bidirectional stream: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
} else {
connection.accept_bi().await.map_err(|err| {
log::error!("error accepting bidirectional stream: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
}
};
self.connected = true;
self.send_stream = Some(send);
self.recv_stream = Some(recv);
Ok(())
}
async fn write_bytes(&mut self) -> Result<(), MpcNetworkError> {
if self.buffered_outbound.is_none() {
return Ok(());
}
let buf = self.buffered_outbound.as_mut().unwrap();
while !buf.is_depleted() {
let bytes_written = self
.send_stream
.as_mut()
.unwrap()
.write(buf.get_remaining())
.await
.map_err(|e| MpcNetworkError::SendError(e.to_string()))?;
buf.advance_cursor(bytes_written);
}
self.buffered_outbound = None;
Ok(())
}
async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
if self.buffered_inbound.is_none() {
self.buffered_inbound = Some(BufferWithCursor::new(vec![0u8; num_bytes]));
}
let read_buffer = self.buffered_inbound.as_mut().unwrap();
while !read_buffer.is_depleted() {
let bytes_read = self
.recv_stream
.as_mut()
.unwrap()
.read(read_buffer.get_remaining())
.await
.map_err(|e| MpcNetworkError::RecvError(e.to_string()))?
.ok_or(MpcNetworkError::RecvError(
ERR_STREAM_FINISHED_EARLY.to_string(),
))?;
read_buffer.advance_cursor(bytes_read);
}
Ok(self.buffered_inbound.take().unwrap().into_vec())
}
async fn read_message_length(&mut self) -> Result<u64, MpcNetworkError> {
let read_buffer = self.read_bytes(BYTES_PER_U64).await?;
Ok(u64::from_le_bytes(read_buffer.try_into().map_err(
|_| MpcNetworkError::SerializationError(ERR_READ_MESSAGE_LENGTH.to_string()),
)?))
}
async fn receive_message(&mut self) -> Result<NetworkOutbound<C>, MpcNetworkError> {
if self.buffered_message_length.is_none() {
self.buffered_message_length = Some(self.read_message_length().await?);
}
let len = self.buffered_message_length.unwrap();
let bytes = self.read_bytes(len as usize).await?;
self.buffered_message_length = None;
serde_json::from_slice(&bytes)
.map_err(|err| MpcNetworkError::SerializationError(err.to_string()))
}
}
#[async_trait]
impl<C> MpcNetwork<C> for QuicTwoPartyNet<C>
where
    C: CurveGroup + Unpin,
{
    /// The local party's identifier, exactly as passed to `new`
    fn party_id(&self) -> PartyId {
        self.party_id
    }

    /// Finish the outbound half of the QUIC stream.
    ///
    /// # Errors
    /// `NetworkUninitialized` if called before `connect`. `ConnectionTeardownError` if
    /// finishing the stream fails; the underlying quinn error is discarded.
    async fn close(&mut self) -> Result<(), MpcNetworkError> {
        self.assert_connected()?;

        let outbound = self.send_stream.as_mut().unwrap();
        match outbound.finish().await {
            Ok(value) => Ok(value),
            Err(_) => Err(MpcNetworkError::ConnectionTeardownError),
        }
    }
}
impl<C: CurveGroup> Stream for QuicTwoPartyNet<C>
where
C: Unpin,
{
type Item = Result<NetworkOutbound<C>, MpcNetworkError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Box::pin(self.get_mut().receive_message())
.as_mut()
.poll(cx)
.map(Some)
}
}
impl<C: CurveGroup> Sink<NetworkOutbound<C>> for QuicTwoPartyNet<C>
where
C: Unpin,
{
type Error = MpcNetworkError;
fn start_send(self: Pin<&mut Self>, msg: NetworkOutbound<C>) -> Result<(), Self::Error> {
if !self.connected {
return Err(MpcNetworkError::NetworkUninitialized);
}
if self.buffered_outbound.is_some() {
return Err(MpcNetworkError::SendError(ERR_SEND_BUFFER_FULL.to_string()));
}
let bytes = serde_json::to_vec(&msg)
.map_err(|err| MpcNetworkError::SerializationError(err.to_string()))?;
let mut payload = (bytes.len() as u64).to_le_bytes().to_vec();
payload.extend_from_slice(&bytes);
self.get_mut().buffered_outbound = Some(BufferWithCursor::new(payload));
Ok(())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Box::pin(self.write_bytes()).as_mut().poll(cx)
}
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
}