use std::fmt;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::atomic::{AtomicU64, Ordering};
use futures::future::{try_join, try_join_all};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::{
Mutex,
mpsc::{Receiver, Sender, channel},
};
#[cfg(not(target_arch = "wasm32"))]
use tracing::{Level, trace};
/// Error returned by the channel helper functions of this module.
#[derive(Debug)]
pub struct Error {
/// Protocol phase during which the error occurred. (De-)serialization errors
/// prefix the phase name with `sending ` / `receiving `.
pub phase: String,
/// What went wrong.
pub reason: ErrorKind,
}
/// The kinds of failure reported through [`Error`].
#[derive(Debug)]
pub enum ErrorKind {
/// Receiving bytes from the channel failed; holds the `Debug` output of the
/// channel's receive error.
RecvError(String),
/// Sending bytes over the channel failed; holds the `Debug` output of the
/// channel's send error.
SendError(String),
/// A chunk could not be serialized or deserialized; holds the `Debug` output
/// of the underlying (de-)serialization error.
SerdeError(String),
/// A message did not contain the expected number of elements.
InvalidLength,
}
/// Serialized form of one outgoing chunk of a message.
///
/// Borrows the elements from the message so that sending does not copy them.
// NOTE(review): the field order presumably has to mirror `RecvChunk`, since the
// chunk is (de-)serialized with bincode — confirm before reordering fields.
#[derive(Debug, Serialize)]
struct SendChunk<'a, T> {
/// The elements carried by this chunk.
chunk: &'a [T],
/// Number of chunks that follow this one (`0` for the last chunk).
remaining_chunks: usize,
}
/// Deserialized form of one incoming chunk of a message; the owned counterpart
/// of `SendChunk`.
// NOTE(review): the field order presumably has to mirror `SendChunk`, since the
// chunk is (de-)serialized with bincode — confirm before reordering fields.
#[derive(Debug, Deserialize)]
struct RecvChunk<T> {
/// The elements carried by this chunk.
chunk: Vec<T>,
/// Number of chunks that follow this one (`0` for the last chunk).
remaining_chunks: usize,
}
/// Progress metadata passed to [`Channel::send_bytes_to`] with every outgoing chunk.
#[derive(Debug, Clone)]
pub struct SendInfo {
    /// Name of the protocol phase the message belongs to.
    phase: String,
    /// Zero-based index of the chunk being sent.
    current_msg: usize,
    /// Number of chunks that follow the current one.
    remaining_msgs: usize,
}

impl SendInfo {
    /// Name of the protocol phase the message belongs to.
    pub fn phase(&self) -> &str {
        self.phase.as_str()
    }

    /// Number of chunks sent so far, counting the current chunk.
    pub fn sent(&self) -> usize {
        self.current_msg + 1
    }

    /// Number of chunks that still have to be sent after the current one.
    pub fn remaining(&self) -> usize {
        self.remaining_msgs
    }

    /// Total number of chunks the message consists of.
    pub fn total(&self) -> usize {
        let already_sent = self.sent();
        already_sent + self.remaining_msgs
    }
}
/// Progress metadata passed to [`Channel::recv_bytes_from`] with every requested chunk.
#[derive(Debug, Clone)]
pub struct RecvInfo {
    /// Name of the protocol phase the message belongs to.
    phase: String,
    /// Zero-based index of the chunk being requested.
    current_msg: usize,
    /// Number of outstanding chunks, or `None` while that number is not yet known.
    remaining_msgs: Option<usize>,
}

impl RecvInfo {
    /// Name of the protocol phase the message belongs to.
    pub fn phase(&self) -> &str {
        self.phase.as_str()
    }

    /// One-based position of the chunk being requested (the name mirrors
    /// `SendInfo::sent`).
    pub fn sent(&self) -> usize {
        self.current_msg + 1
    }

    /// Number of outstanding chunks, if already known.
    pub fn remaining(&self) -> Option<usize> {
        self.remaining_msgs
    }

    /// Sum of [`Self::sent`] and [`Self::remaining`], or `None` while the number of
    /// outstanding chunks is unknown.
    pub fn total(&self) -> Option<usize> {
        let outstanding = self.remaining_msgs?;
        Some(self.sent() + outstanding)
    }
}
/// A byte-oriented communication channel between the parties of a protocol.
///
/// The helper functions of this module serialize messages, split them into
/// chunks and hand each chunk to an implementation of this trait.
pub trait Channel {
/// Error produced when sending a chunk fails.
type SendError: fmt::Debug;
/// Error produced when receiving a chunk fails.
type RecvError: fmt::Debug;
/// Sends one serialized chunk to `party`.
///
/// `info` describes the phase and the position of the chunk within its message.
// The lint warns about `async fn` in public traits; it is silenced here.
// NOTE(review): presumably auto-trait bounds (e.g. `Send`) on the returned
// future are deliberately left unspecified — confirm.
#[allow(async_fn_in_trait)]
async fn send_bytes_to(
&self,
party: usize,
chunk: Vec<u8>,
info: SendInfo,
) -> Result<(), Self::SendError>;
/// Receives the next serialized chunk from `party`.
///
/// `info` describes the phase and the position of the chunk within its message.
// See the note on `send_bytes_to` regarding the silenced lint.
#[allow(async_fn_in_trait)]
async fn recv_bytes_from(
&self,
party: usize,
info: RecvInfo,
) -> Result<Vec<u8>, Self::RecvError>;
}
pub(crate) async fn send_to<S: Serialize + std::fmt::Debug>(
channel: &impl Channel,
party: usize,
phase: &str,
msg: &[S],
) -> Result<(), Error> {
let chunk_size = 5_000_000;
let mut chunks: Vec<_> = msg.chunks(chunk_size).collect();
if chunks.is_empty() {
chunks.push(&[]);
}
let length = chunks.len();
for (i, chunk) in chunks.into_iter().enumerate() {
let remaining_chunks = length - i - 1;
let chunk = SendChunk {
chunk,
remaining_chunks,
};
let chunk = bincode::serialize(&chunk).map_err(|e| Error {
phase: format!("sending {phase}"),
reason: ErrorKind::SerdeError(format!("{e:?}")),
})?;
let info = SendInfo {
phase: phase.to_string(),
current_msg: i,
remaining_msgs: remaining_chunks,
};
channel
.send_bytes_to(party, chunk, info)
.await
.map_err(|e| Error {
phase: phase.to_string(),
reason: ErrorKind::SendError(format!("{e:?}")),
})?;
}
Ok(())
}
/// Receives a (possibly chunked) message from `party` and reassembles it.
///
/// Chunks are requested from the channel one after the other and appended to the
/// result until a chunk reports that no further chunks follow.
///
/// The [`RecvInfo`] handed to the channel has no `remaining` count for the first
/// chunk (nothing is known yet). For every later chunk it holds the number of chunks
/// that follow the requested one, so that `RecvInfo::total()` equals the real number
/// of chunks, just like `SendInfo::total()` on the sending side.
///
/// # Errors
///
/// Returns [`ErrorKind::RecvError`] if the channel fails to deliver a chunk and
/// [`ErrorKind::SerdeError`] (with phase `receiving {phase}`) if a chunk cannot be
/// deserialized.
pub(crate) async fn recv_from<T: DeserializeOwned + std::fmt::Debug>(
    channel: &impl Channel,
    party: usize,
    phase: &str,
) -> Result<Vec<T>, Error> {
    let mut msg = vec![];
    let mut i = 0;
    let mut remaining = None;
    loop {
        let info = RecvInfo {
            phase: phase.to_string(),
            current_msg: i,
            remaining_msgs: remaining,
        };
        let chunk = channel
            .recv_bytes_from(party, info)
            .await
            .map_err(|e| Error {
                phase: phase.to_string(),
                reason: ErrorKind::RecvError(format!("{e:?}")),
            })?;
        let RecvChunk {
            chunk,
            remaining_chunks,
        }: RecvChunk<T> = bincode::deserialize(&chunk).map_err(|e| Error {
            phase: format!("receiving {phase}"),
            reason: ErrorKind::SerdeError(format!("{e:?}")),
        })?;
        msg.extend(chunk);
        if remaining_chunks == 0 {
            return Ok(msg);
        }
        // `remaining_chunks` counts the chunks after the one just received and thus
        // still includes the chunk requested next. `RecvInfo::sent()` already counts
        // that chunk, so passing the value on unchanged would make `total()` one too
        // large. The subtraction cannot underflow: `remaining_chunks` is non-zero here.
        remaining = Some(remaining_chunks - 1);
        i += 1;
    }
}
pub(crate) async fn recv_vec_from<T: DeserializeOwned + std::fmt::Debug>(
channel: &impl Channel,
party: usize,
phase: &str,
len: usize,
) -> Result<Vec<T>, Error> {
let v: Vec<T> = recv_from(channel, party, phase).await?;
if v.len() == len {
Ok(v)
} else {
Err(Error {
phase: phase.to_string(),
reason: ErrorKind::InvalidLength,
})
}
}
/// Sends `data` to every other party and receives one message from each of them.
///
/// Sending and receiving run concurrently. The result has one entry per party, indexed
/// by party; the entry at `own_party` is an empty vector. Nothing beyond the length of
/// the received messages is checked, i.e. parties are not verified to have sent the
/// same data to everyone.
///
/// # Errors
///
/// Forwards the first error of [`send_to`] or [`recv_vec_from`]; in particular
/// [`ErrorKind::InvalidLength`] if a received message does not have `data.len()`
/// elements.
pub(crate) async fn unverified_broadcast<T>(
    channel: &impl Channel,
    own_party: usize,
    num_parties: usize,
    phase: &str,
    data: &[T],
) -> Result<Vec<Vec<T>>, Error>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
{
    // Every party is expected to send as many elements as we send ourselves.
    let expected_len = data.len();

    let outgoing = (0..num_parties)
        .filter(|peer| *peer != own_party)
        .map(|peer| send_to(channel, peer, phase, data));
    let incoming = (0..num_parties).map(async |peer| {
        if peer == own_party {
            // Placeholder so that the result can be indexed by party.
            Ok(vec![])
        } else {
            recv_vec_from(channel, peer, phase, expected_len).await
        }
    });

    let (_, responses) = try_join(try_join_all(outgoing), try_join_all(incoming)).await?;
    Ok(responses)
}
/// Sends `data_per_party[p]` to every other party `p` and receives one message from
/// each of them.
///
/// The number of parties is `data_per_party.len()`; the entry at `own_party` is never
/// sent or inspected. The first non-empty entry for another party determines the
/// length that all later entries, and all received messages, must have. Empty entries
/// that come before that first non-empty entry are not checked. If no other party has
/// a non-empty entry, nothing is sent or received and an empty vector is returned.
///
/// Otherwise the result has one entry per party, indexed by party, with an empty
/// vector at `own_party`.
///
/// # Errors
///
/// Returns [`ErrorKind::InvalidLength`] if the entries of `data_per_party` disagree in
/// length as described above, and forwards the first error of [`send_to`] or
/// [`recv_vec_from`].
pub(crate) async fn scatter<T>(
    channel: &impl Channel,
    own_party: usize,
    phase: &str,
    data_per_party: &[Vec<T>],
) -> Result<Vec<Vec<T>>, Error>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
{
    let num_parties = data_per_party.len();

    let mut common_len: Option<usize> = None;
    let peer_payloads = data_per_party
        .iter()
        .enumerate()
        .filter(|(p, _)| *p != own_party);
    for (_, payload) in peer_payloads {
        match common_len {
            // The first non-empty payload fixes the expected length.
            None if !payload.is_empty() => {
                common_len = Some(payload.len());
            }
            // Every payload after that has to match it.
            Some(len) if len != payload.len() => {
                return Err(Error {
                    phase: phase.to_string(),
                    reason: ErrorKind::InvalidLength,
                });
            }
            _ => {}
        }
    }
    let expected_recv_len = match common_len {
        Some(len) => len,
        None => return Ok(vec![]),
    };

    let outgoing = (0..num_parties)
        .filter(|peer| *peer != own_party)
        .map(|peer| send_to(channel, peer, phase, &data_per_party[peer]));
    let incoming = (0..num_parties).map(async |peer| {
        if peer == own_party {
            // Placeholder so that the result can be indexed by party.
            Ok(vec![])
        } else {
            recv_vec_from(channel, peer, phase, expected_recv_len).await
        }
    });

    let (_, responses) = try_join(try_join_all(outgoing), try_join_all(incoming)).await?;
    Ok(responses)
}
/// An in-memory [`Channel`] implementation built on tokio mpsc channels.
///
/// Use [`SimpleChannel::channels`] to create one connected instance per party.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file —
// confirm whether it is still needed.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
#[allow(dead_code)]
#[doc(hidden)]
pub struct SimpleChannel {
/// `s[p]` is the sender for messages to party `p`; `None` at the own index.
s: Vec<Option<Sender<Vec<u8>>>>,
/// `r[p]` is the receiver for messages from party `p`; `None` at the own index.
/// The receiver sits behind a mutex because receiving only has `&self`.
r: Vec<Option<Mutex<Receiver<Vec<u8>>>>>,
/// Total number of bytes handed to `send_bytes_to` so far.
bytes_sent: AtomicU64,
}
#[cfg(not(target_arch = "wasm32"))]
impl SimpleChannel {
    /// Creates `parties` channels that are fully connected with each other.
    ///
    /// The returned vector is indexed by party: element `p` is the channel that party
    /// `p` uses to talk to all other parties. A channel has no sender or receiver for
    /// its own index.
    pub fn channels(parties: usize) -> Vec<Self> {
        let buffer_capacity = 1024;
        let mut channels: Vec<Self> = (0..parties)
            .map(|_| SimpleChannel {
                s: (0..parties).map(|_| None).collect(),
                r: (0..parties).map(|_| None).collect(),
                bytes_sent: AtomicU64::new(0),
            })
            .collect();
        for a in 0..parties {
            // Visit every unordered pair exactly once; both directions are wired up
            // below. Iterating over all ordered pairs would create each connection
            // twice and immediately throw the first one away.
            for b in (a + 1)..parties {
                let (send_a_to_b, recv_a_to_b) = channel(buffer_capacity);
                let (send_b_to_a, recv_b_to_a) = channel(buffer_capacity);
                channels[a].s[b] = Some(send_a_to_b);
                channels[b].s[a] = Some(send_b_to_a);
                channels[a].r[b] = Some(Mutex::new(recv_b_to_a));
                channels[b].r[a] = Some(Mutex::new(recv_a_to_b));
            }
        }
        channels
    }

    /// Returns the total number of bytes sent through this channel so far.
    pub fn bytes_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }
}
/// Errors that can occur when receiving from a [`SimpleChannel`].
#[derive(Debug)]
#[cfg(not(target_arch = "wasm32"))]
#[doc(hidden)]
pub enum AsyncRecvError {
/// The underlying receiver yielded no further message.
Closed,
/// No message arrived before the receive timeout elapsed.
TimeoutElapsed,
}
#[cfg(not(target_arch = "wasm32"))]
impl Channel for SimpleChannel {
    type SendError = tokio::sync::mpsc::error::SendError<Vec<u8>>;
    type RecvError = AsyncRecvError;

    /// Sends `msg` to party `p` and adds its size to the byte counter.
    ///
    /// # Panics
    ///
    /// Panics if `p` is out of range or if there is no sender for party `p`.
    #[tracing::instrument(level = Level::TRACE, skip(self, msg))]
    async fn send_bytes_to(
        &self,
        p: usize,
        msg: Vec<u8>,
        info: SendInfo,
    ) -> Result<(), tokio::sync::mpsc::error::SendError<Vec<u8>>> {
        let num_bytes = msg.len();
        self.bytes_sent
            .fetch_add(num_bytes as u64, Ordering::Relaxed);
        let mb = num_bytes as f64 / 1024.0 / 1024.0;
        // Only the first chunk of a message is logged as a new message.
        match info.sent() {
            1 => trace!(size = mb, "Sending msg"),
            _ => trace!(size = mb, " (continued sending msg)"),
        }
        match &self.s[p] {
            Some(sender) => sender.send(msg).await,
            None => panic!("No sender for party {p}"),
        }
    }

    /// Receives the next chunk from party `p`, waiting at most ten minutes for it.
    ///
    /// # Panics
    ///
    /// Panics if `p` is out of range or if there is no receiver for party `p`.
    #[tracing::instrument(level = Level::TRACE, skip(self), fields(info = ?_info))]
    async fn recv_bytes_from(&self, p: usize, _info: RecvInfo) -> Result<Vec<u8>, AsyncRecvError> {
        const RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);

        let slot = match &self.r[p] {
            Some(slot) => slot,
            None => panic!("No receiver for party {p}"),
        };
        // The timeout below covers only the wait for a message, not the lock.
        let mut receiver = slot.lock().await;
        match tokio::time::timeout(RECV_TIMEOUT, receiver.recv()).await {
            Err(_) => Err(AsyncRecvError::TimeoutElapsed),
            Ok(None) => Err(AsyncRecvError::Closed),
            Ok(Some(chunk)) => {
                let mb = chunk.len() as f64 / 1024.0 / 1024.0;
                trace!(size = mb, "Received chunk");
                Ok(chunk)
            }
        }
    }
}