mod cert_verifier;
mod config;
pub mod dummy_network;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use curve25519_dalek::{
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
};
use quinn::{Endpoint, RecvStream, SendStream};
use std::{convert::TryInto, net::SocketAddr};
use tracing::log;
use crate::error::{MpcNetworkError, SetupError};
/// Identifier of a party in the MPC network; the party with ID 0 is treated
/// as the "king" (see `MpcNetwork::am_king`).
pub type PartyId = u64;
/// Number of bytes in one serialized (compressed) Ristretto point.
const BYTES_PER_POINT: usize = 32;
/// Number of bytes in one serialized scalar.
const BYTES_PER_SCALAR: usize = 32;
/// Number of bytes in a `u64`; the size of the length prefix read by
/// `receive_bytes`.
const BYTES_PER_U64: usize = 8;
fn scalars_to_bytes(scalars: &[Scalar]) -> Bytes {
let mut payload = BytesMut::new();
scalars.iter().for_each(|scalar| {
let bytes = scalar.to_bytes();
payload.extend_from_slice(&bytes);
});
payload.freeze()
}
/// Deserializes a byte buffer into scalars, `BYTES_PER_SCALAR` bytes each.
///
/// # Errors
///
/// Returns `MpcNetworkError::SerializationError` if the buffer length is not
/// a multiple of `BYTES_PER_SCALAR`, or if any chunk is not a canonical
/// scalar encoding.
fn bytes_to_scalars(bytes: &[u8]) -> Result<Vec<Scalar>, MpcNetworkError> {
    // Reject a trailing partial chunk up front rather than panicking on it.
    if bytes.len() % BYTES_PER_SCALAR != 0 {
        return Err(MpcNetworkError::SerializationError);
    }

    bytes
        .chunks_exact(BYTES_PER_SCALAR)
        .map(|bytes_chunk| {
            Scalar::from_canonical_bytes(
                bytes_chunk
                    .try_into()
                    .expect("chunks_exact yields exactly BYTES_PER_SCALAR bytes"),
            )
            .ok_or(MpcNetworkError::SerializationError)
        })
        .collect::<Result<Vec<Scalar>, MpcNetworkError>>()
}
fn points_to_bytes(points: &[RistrettoPoint]) -> Bytes {
let mut payload = BytesMut::new();
points.iter().for_each(|point| {
let bytes = point.compress().to_bytes();
payload.extend_from_slice(&bytes);
});
payload.freeze()
}
/// Deserializes a byte buffer into Ristretto points, `BYTES_PER_POINT` bytes
/// each.
///
/// # Errors
///
/// Returns `MpcNetworkError::SerializationError` if the buffer length is not
/// a multiple of `BYTES_PER_POINT`, or if any chunk fails to decompress into
/// a valid point.
fn bytes_to_points(bytes: &[u8]) -> Result<Vec<RistrettoPoint>, MpcNetworkError> {
    // Reject a trailing partial chunk up front rather than panicking on it.
    if bytes.len() % BYTES_PER_POINT != 0 {
        return Err(MpcNetworkError::SerializationError);
    }

    bytes
        .chunks_exact(BYTES_PER_POINT)
        .map(|bytes_chunk| {
            CompressedRistretto(
                bytes_chunk
                    .try_into()
                    .expect("chunks_exact yields exactly BYTES_PER_POINT bytes"),
            )
            .decompress()
            .ok_or(MpcNetworkError::SerializationError)
        })
        .collect::<Result<Vec<RistrettoPoint>, MpcNetworkError>>()
}
#[async_trait]
pub trait MpcNetwork {
fn party_id(&self) -> u64;
fn am_king(&self) -> bool {
self.party_id() == 0
}
async fn send_bytes(&mut self, bytes: &[u8]) -> Result<(), MpcNetworkError>;
async fn receive_bytes(&mut self) -> Result<Vec<u8>, MpcNetworkError>;
async fn send_scalars(&mut self, scalars: &[Scalar]) -> Result<(), MpcNetworkError>;
async fn send_single_scalar(&mut self, scalar: Scalar) -> Result<(), MpcNetworkError> {
self.send_scalars(&[scalar]).await
}
async fn receive_scalars(
&mut self,
num_expected: usize,
) -> Result<Vec<Scalar>, MpcNetworkError>;
async fn receive_single_scalar(&mut self) -> Result<Scalar, MpcNetworkError> {
Ok(self.receive_scalars(1).await?[0])
}
async fn broadcast_scalars(
&mut self,
scalars: &[Scalar],
) -> Result<Vec<Scalar>, MpcNetworkError>;
async fn broadcast_single_scalar(&mut self, scalar: Scalar) -> Result<Scalar, MpcNetworkError> {
Ok(self.broadcast_scalars(&[scalar]).await?[0])
}
async fn send_points(&mut self, points: &[RistrettoPoint]) -> Result<(), MpcNetworkError>;
async fn send_single_point(&mut self, point: RistrettoPoint) -> Result<(), MpcNetworkError> {
Ok(self.send_points(&[point]).await?)
}
async fn receive_points(
&mut self,
num_expected: usize,
) -> Result<Vec<RistrettoPoint>, MpcNetworkError>;
async fn receive_single_point(&mut self) -> Result<RistrettoPoint, MpcNetworkError> {
Ok(self.receive_points(1).await?[0])
}
async fn broadcast_points(
&mut self,
points: &[RistrettoPoint],
) -> Result<Vec<RistrettoPoint>, MpcNetworkError>;
async fn broadcast_single_point(
&mut self,
point: RistrettoPoint,
) -> Result<RistrettoPoint, MpcNetworkError> {
Ok(self.broadcast_points(&[point]).await?[0])
}
async fn close(&mut self) -> Result<(), MpcNetworkError>;
}
/// The order in which a party performs its write and its read when both
/// parties exchange a payload (see `write_then_read_bytes`).
#[derive(Clone, Debug)]
pub enum ReadWriteOrder {
/// Read the peer's payload first, then write the local payload.
ReadFirst,
/// Write the local payload first, then read the peer's payload.
WriteFirst,
}
/// A two-party MPC network that communicates over a single bidirectional
/// QUIC stream.
#[derive(Debug)]
pub struct QuicTwoPartyNet {
/// ID of the local party.
party_id: PartyId,
/// Set to `true` once `connect` has completed successfully.
connected: bool,
/// Address the local QUIC endpoint binds to.
local_addr: SocketAddr,
/// Address of the remote party.
peer_addr: SocketAddr,
/// Sending half of the stream; `None` until `connect` succeeds.
send_stream: Option<SendStream>,
/// Receiving half of the stream; `None` until `connect` succeeds.
recv_stream: Option<RecvStream>,
}
#[allow(clippy::redundant_closure)] impl<'a> QuicTwoPartyNet {
pub fn new(party_id: PartyId, local_addr: SocketAddr, peer_addr: SocketAddr) -> Self {
Self {
party_id,
local_addr,
peer_addr,
connected: false,
send_stream: None,
recv_stream: None,
}
}
fn read_order(&self) -> ReadWriteOrder {
if self.am_king() {
ReadWriteOrder::WriteFirst
} else {
ReadWriteOrder::ReadFirst
}
}
fn assert_connected(&self) -> Result<(), MpcNetworkError> {
if self.connected {
Ok(())
} else {
Err(MpcNetworkError::NetworkUninitialized)
}
}
pub async fn connect(&mut self) -> Result<(), MpcNetworkError> {
let (client_config, server_config) =
config::build_configs().map_err(|err| MpcNetworkError::ConnectionSetupError(err))?;
let mut local_endpoint = Endpoint::server(server_config, self.local_addr).map_err(|e| {
log::error!("error setting up quinn server: {e:?}");
MpcNetworkError::ConnectionSetupError(SetupError::ServerSetupError)
})?;
local_endpoint.set_default_client_config(client_config);
let connection = {
if self.am_king() {
local_endpoint
.connect(self.peer_addr, config::SERVER_NAME)
.map_err(|err| {
log::error!("error setting up quic endpoint connection: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectError(err))
})?
.await
.map_err(|err| {
log::error!("error connecting to the remote quic endpoint: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
} else {
local_endpoint
.accept()
.await
.ok_or_else(|| {
log::error!("no incoming connection while awaiting quic endpoint");
MpcNetworkError::ConnectionSetupError(SetupError::NoIncomingConnection)
})?
.await
.map_err(|err| {
log::error!("error while establishing remote connection as listener");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
}
};
let (send, recv) = {
if self.am_king() {
connection.open_bi().await.map_err(|err| {
log::error!("error opening bidirectional stream: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
} else {
connection.accept_bi().await.map_err(|err| {
log::error!("error accepting bidirectional stream: {err}");
MpcNetworkError::ConnectionSetupError(SetupError::ConnectionError(err))
})?
}
};
self.connected = true;
self.send_stream = Some(send);
self.recv_stream = Some(recv);
Ok(())
}
async fn write_bytes(&mut self, payload: &[u8]) -> Result<(), MpcNetworkError> {
self.send_stream
.as_mut()
.unwrap()
.write_all(payload)
.await
.map_err(|_| MpcNetworkError::SendError)
}
async fn read_bytes(&mut self, num_bytes: usize) -> Result<Vec<u8>, MpcNetworkError> {
let mut read_buffer = vec![0u8; num_bytes];
self.recv_stream
.as_mut()
.unwrap()
.read_exact(&mut read_buffer)
.await
.map_err(|_| MpcNetworkError::RecvError)?;
Ok(read_buffer.to_vec())
}
async fn write_then_read_bytes(
&mut self,
order: ReadWriteOrder,
payload: &[u8],
) -> Result<Vec<u8>, MpcNetworkError> {
let payload_length = payload.len();
Ok(match order {
ReadWriteOrder::ReadFirst => {
let bytes_read = self.read_bytes(payload_length).await?;
self.write_bytes(payload).await?;
bytes_read
}
ReadWriteOrder::WriteFirst => {
self.write_bytes(payload).await?;
self.read_bytes(payload_length).await?
}
})
}
}
#[async_trait]
impl MpcNetwork for QuicTwoPartyNet {
fn party_id(&self) -> u64 {
self.party_id
}
async fn send_bytes(&mut self, bytes: &[u8]) -> Result<(), MpcNetworkError> {
self.assert_connected()?;
let length = (bytes.len() as u64).to_le_bytes();
self.write_bytes(&length).await?;
self.write_bytes(bytes).await
}
async fn receive_bytes(&mut self) -> Result<Vec<u8>, MpcNetworkError> {
self.assert_connected()?;
let length = u64::from_le_bytes(self.read_bytes(BYTES_PER_U64).await?.try_into().unwrap());
self.read_bytes(length as usize).await
}
async fn send_scalars(&mut self, scalars: &[Scalar]) -> Result<(), MpcNetworkError> {
self.assert_connected()?;
let payload = scalars_to_bytes(scalars);
self.write_bytes(&payload).await?;
Ok(())
}
async fn receive_scalars(
&mut self,
num_scalars: usize,
) -> Result<Vec<Scalar>, MpcNetworkError> {
self.assert_connected()?;
let bytes_read = self.read_bytes(num_scalars * BYTES_PER_SCALAR).await?;
bytes_to_scalars(&bytes_read)
}
async fn broadcast_scalars(
&mut self,
scalars: &[Scalar],
) -> Result<Vec<Scalar>, MpcNetworkError> {
self.assert_connected()?;
let payload = scalars_to_bytes(scalars);
let read_buffer = self
.write_then_read_bytes(self.read_order(), &payload)
.await?;
bytes_to_scalars(&read_buffer)
}
async fn send_points(&mut self, points: &[RistrettoPoint]) -> Result<(), MpcNetworkError> {
let payload = points_to_bytes(points);
self.write_bytes(&payload).await
}
async fn receive_points(
&mut self,
num_points: usize,
) -> Result<Vec<RistrettoPoint>, MpcNetworkError> {
let read_buffer = self.read_bytes(BYTES_PER_POINT * num_points).await?;
bytes_to_points(&read_buffer)
}
async fn broadcast_points(
&mut self,
points: &[RistrettoPoint],
) -> Result<Vec<RistrettoPoint>, MpcNetworkError> {
self.assert_connected()?;
let payload = points_to_bytes(points);
let read_buffer = self
.write_then_read_bytes(self.read_order(), &payload)
.await?;
bytes_to_points(&read_buffer)
}
async fn close(&mut self) -> Result<(), MpcNetworkError> {
self.assert_connected()?;
self.send_stream
.as_mut()
.unwrap()
.finish()
.await
.map_err(|_| MpcNetworkError::ConnectionTeardownError)
}
}
#[cfg(test)]
mod test {
    use std::net::SocketAddr;

    use curve25519_dalek::ristretto::RistrettoPoint;
    use rand_core::OsRng;
    use tokio;

    use super::{MpcNetwork, QuicTwoPartyNet};

    /// Broadcasting on a network that was never connected must return an
    /// error, both for an empty slice and for a single point.
    #[tokio::test]
    async fn test_errors() {
        let addr: SocketAddr = "127.0.0.1:8000".parse().unwrap();
        // `connect` is deliberately never called on this handle.
        let mut unconnected_net = QuicTwoPartyNet::new(0, addr, addr);

        let empty_broadcast = unconnected_net.broadcast_points(&[]).await;
        assert!(empty_broadcast.is_err());

        let mut rng = OsRng {};
        let random_point = RistrettoPoint::random(&mut rng);
        let single_broadcast = unconnected_net.broadcast_single_point(random_point).await;
        assert!(single_broadcast.is_err())
    }
}