use std::fmt;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::atomic::{AtomicU64, Ordering};
use futures::future::{try_join, try_join_all};
use serde::{Serialize, de::DeserializeOwned};
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::{
Mutex,
mpsc::{Receiver, Sender, channel},
};
#[cfg(not(target_arch = "wasm32"))]
use tracing::{Level, trace};
/// Error returned by the message-exchange helpers in this module.
#[derive(Debug)]
pub struct Error {
    /// Label of the protocol phase in which the error occurred. The helpers in this module pass
    /// the caller's phase name, in some cases prefixed with "sending " or "receiving ".
    pub phase: String,
    /// What went wrong.
    pub reason: ErrorKind,
}
/// Cause of an [`Error`].
#[derive(Debug)]
pub enum ErrorKind {
    /// The channel failed to receive; holds the `Debug` rendering of the channel's error.
    RecvError(String),
    /// The channel failed to send; holds the `Debug` rendering of the channel's error.
    SendError(String),
    /// bincode (de)serialization failed; holds the `Debug` rendering of the bincode error.
    SerdeError(String),
    /// A received message, or a set of outgoing messages, had an unexpected number of elements.
    InvalidLength,
}
/// Byte-level transport used by the helpers in this module to exchange messages between parties.
///
/// Parties are addressed by index. `phase` is a label for the current protocol step that an
/// implementation may use, e.g. for tracing. The helpers call `recv_bytes_from` once per message
/// they expect, so presumably each `send_bytes_to` call must be matched by exactly one
/// `recv_bytes_from` on the peer (NOTE(review): confirm for non-in-memory implementations).
pub trait Channel {
    /// Error produced when sending fails; only its `Debug` rendering is used by the helpers.
    type SendError: fmt::Debug;
    /// Error produced when receiving fails; only its `Debug` rendering is used by the helpers.
    type RecvError: fmt::Debug;
    /// Sends the raw bytes `data` to `party`.
    // Silences the `async_fn_in_trait` lint, which warns that callers of an async fn in a public
    // trait cannot require the returned future to be `Send`.
    #[allow(async_fn_in_trait)]
    async fn send_bytes_to(
        &self,
        party: usize,
        data: Vec<u8>,
        phase: &str,
    ) -> Result<(), Self::SendError>;
    /// Receives one raw byte message sent by `party`.
    #[allow(async_fn_in_trait)]
    async fn recv_bytes_from(&self, party: usize, phase: &str) -> Result<Vec<u8>, Self::RecvError>;
}
pub(crate) async fn send_to<S: Serialize + std::fmt::Debug>(
channel: &impl Channel,
party: usize,
phase: &str,
msg: &[S],
) -> Result<(), Error> {
let data = bincode::serialize(&msg).map_err(|e| Error {
phase: format!("sending {phase}"),
reason: ErrorKind::SerdeError(format!("{e:?}")),
})?;
channel
.send_bytes_to(party, data, phase)
.await
.map_err(|e| Error {
phase: phase.to_string(),
reason: ErrorKind::SendError(format!("{e:?}")),
})?;
Ok(())
}
pub(crate) async fn recv_from<T: DeserializeOwned + std::fmt::Debug>(
channel: &impl Channel,
party: usize,
phase: &str,
) -> Result<Vec<T>, Error> {
let data = channel
.recv_bytes_from(party, phase)
.await
.map_err(|e| Error {
phase: phase.to_string(),
reason: ErrorKind::RecvError(format!("{e:?}")),
})?;
let msg: Vec<T> = bincode::deserialize(&data).map_err(|e| Error {
phase: format!("receiving {phase}"),
reason: ErrorKind::SerdeError(format!("{e:?}")),
})?;
Ok(msg)
}
pub(crate) async fn recv_vec_from<T: DeserializeOwned + std::fmt::Debug>(
channel: &impl Channel,
party: usize,
phase: &str,
len: usize,
) -> Result<Vec<T>, Error> {
let v: Vec<T> = recv_from(channel, party, phase).await?;
if v.len() == len {
Ok(v)
} else {
Err(Error {
phase: phase.to_string(),
reason: ErrorKind::InvalidLength,
})
}
}
/// Sends `data` to every other party and concurrently collects the vector each of them sends.
///
/// The function performs no cross-party consistency check on what peers send. It only requires
/// every received vector to have exactly `data.len()` elements.
///
/// Returns one entry per party in `0..num_parties`; the entry at `own_party` is empty.
///
/// # Errors
/// The first send, receive, (de)serialization or `ErrorKind::InvalidLength` error encountered.
pub(crate) async fn unverified_broadcast<T>(
    channel: &impl Channel,
    own_party: usize,
    num_parties: usize,
    phase: &str,
    data: &[T],
) -> Result<Vec<Vec<T>>, Error>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
{
    let expected_len = data.len();
    let outgoing = (0..num_parties)
        .filter(|&peer| peer != own_party)
        .map(|peer| send_to(channel, peer, phase, data));
    // Keep a slot for our own index so the result can be indexed by party number.
    let incoming = (0..num_parties).map(move |peer| async move {
        if peer == own_party {
            Ok(Vec::new())
        } else {
            recv_vec_from(channel, peer, phase, expected_len).await
        }
    });
    let (_, received) = try_join(try_join_all(outgoing), try_join_all(incoming)).await?;
    Ok(received)
}
/// Sends `data_per_party[p]` to each party `p != own_party` and receives one vector from each
/// of those parties.
///
/// All outgoing vectors (ignoring the entry at `own_party`) must have the same length. That
/// length is also the length required of every received vector.
///
/// Returns one entry per index of `data_per_party`, with an empty vector in the `own_party`
/// slot. If there are no other parties, or every outgoing vector is empty, no messages are
/// exchanged and an empty `Vec` is returned.
///
/// # Errors
/// - `ErrorKind::InvalidLength` if the outgoing vectors differ in length or a received vector
///   has the wrong length.
/// - Otherwise, the first send, receive or (de)serialization error encountered.
pub(crate) async fn scatter<T>(
    channel: &impl Channel,
    own_party: usize,
    phase: &str,
    data_per_party: &[Vec<T>],
) -> Result<Vec<Vec<T>>, Error>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
{
    let num_parties = data_per_party.len();
    // The length we send to each peer is also the length we expect back from it, so all
    // outgoing vectors must agree. Every peer's length, empty or not, is compared against the
    // first peer's length; otherwise an empty vector could slip through and leave that peer
    // and us waiting on mismatched message sizes.
    let mut outgoing_lens = data_per_party
        .iter()
        .enumerate()
        .filter(|(p, _)| *p != own_party)
        .map(|(_, data)| data.len());
    let Some(expected_recv_len) = outgoing_lens.next() else {
        // No other parties: nothing to exchange.
        return Ok(vec![]);
    };
    if outgoing_lens.any(|len| len != expected_recv_len) {
        return Err(Error {
            phase: phase.to_string(),
            reason: ErrorKind::InvalidLength,
        });
    }
    if expected_recv_len == 0 {
        // Every outgoing vector is empty: nothing to exchange.
        return Ok(vec![]);
    }
    let send_fut = try_join_all(
        (0..num_parties)
            .filter(|p| *p != own_party)
            .map(|p| send_to(channel, p, phase, &data_per_party[p])),
    );
    // Keep a slot for our own index so the result can be indexed by party number.
    let recv_fut = try_join_all((0..num_parties).map(async |p| {
        if p != own_party {
            recv_vec_from(channel, p, phase, expected_recv_len).await
        } else {
            Ok(vec![])
        }
    }));
    let (_, responses) = try_join(send_fut, recv_fut).await?;
    Ok(responses)
}
/// In-memory [`Channel`] implementation that links parties in the same process through tokio
/// mpsc channels. It is `#[doc(hidden)]`, presumably because it is meant for tests/benchmarks
/// (NOTE(review): confirm).
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
#[allow(dead_code)]
#[doc(hidden)]
pub struct SimpleChannel {
    /// `s[p]` sends messages to party `p`; `None` where no link exists (a party's own index).
    s: Vec<Option<Sender<Vec<u8>>>>,
    /// `r[p]` receives messages sent by party `p`. The `Mutex` lets `recv` (which needs
    /// `&mut Receiver`) be called through the `&self` receiver of `Channel::recv_bytes_from`.
    r: Vec<Option<Mutex<Receiver<Vec<u8>>>>>,
    /// Running byte counter for outgoing messages, read via `bytes_sent()`.
    bytes_sent: AtomicU64,
}
#[cfg(not(target_arch = "wasm32"))]
impl SimpleChannel {
    /// Creates a fully connected set of in-memory channels, one endpoint per party.
    ///
    /// Entry `i` of the returned vector is party `i`'s endpoint. For every pair of distinct
    /// parties `a` and `b`, bytes sent by `a` to `b` are received by `b` from `a`, and vice versa.
    /// Each directed link is a bounded tokio mpsc channel with room for 1024 messages. A party has
    /// no link to itself.
    pub fn channels(parties: usize) -> Vec<Self> {
        let buffer_capacity = 1024;
        let mut channels: Vec<Self> = (0..parties)
            .map(|_| SimpleChannel {
                s: (0..parties).map(|_| None).collect(),
                r: (0..parties).map(|_| None).collect(),
                bytes_sent: AtomicU64::new(0),
            })
            .collect();
        // Each unordered pair is visited once and wires both directions, so every directed
        // link is created exactly once (visiting both (a, b) and (b, a) would build and then
        // discard a redundant set of channels).
        for a in 0..parties {
            for b in (a + 1)..parties {
                let (send_a_to_b, recv_a_to_b) = channel(buffer_capacity);
                let (send_b_to_a, recv_b_to_a) = channel(buffer_capacity);
                channels[a].s[b] = Some(send_a_to_b);
                channels[b].s[a] = Some(send_b_to_a);
                channels[a].r[b] = Some(Mutex::new(recv_b_to_a));
                channels[b].r[a] = Some(Mutex::new(recv_a_to_b));
            }
        }
        channels
    }

    /// Total number of bytes recorded by `send_bytes_to` on this endpoint so far.
    pub fn bytes_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }
}
/// Receive error produced by `SimpleChannel`.
#[derive(Debug)]
#[cfg(not(target_arch = "wasm32"))]
#[doc(hidden)]
pub enum AsyncRecvError {
    /// The link is closed: the underlying `recv` returned `None`.
    Closed,
    /// No message arrived before the receive timeout elapsed.
    TimeoutElapsed,
}
#[cfg(not(target_arch = "wasm32"))]
impl Channel for SimpleChannel {
    type SendError = tokio::sync::mpsc::error::SendError<Vec<u8>>;
    type RecvError = AsyncRecvError;

    /// Sends `msg` to party `p`, waiting while that link's buffer is full.
    ///
    /// The byte counter is only advanced after the message has been accepted by the channel, so
    /// sends that fail because the receiver is gone are not counted.
    ///
    /// # Panics
    /// Panics if `p` is out of range or there is no sender for `p` (as for a party's own index
    /// when built by `SimpleChannel::channels`).
    ///
    /// # Errors
    /// Returns tokio's `SendError`, which carries the unsent message, if the receiving end is
    /// closed.
    #[tracing::instrument(level = Level::TRACE, skip(self, msg))]
    async fn send_bytes_to(
        &self,
        p: usize,
        msg: Vec<u8>,
        phase: &str,
    ) -> Result<(), tokio::sync::mpsc::error::SendError<Vec<u8>>> {
        let len = msg.len();
        let mb = len as f64 / 1024.0 / 1024.0;
        trace!(size = mb, "Sending msg");
        self.s[p]
            .as_ref()
            .unwrap_or_else(|| panic!("No sender for party {p}"))
            .send(msg)
            .await?;
        self.bytes_sent.fetch_add(len as u64, Ordering::Relaxed);
        Ok(())
    }

    /// Receives the next message from party `p`, waiting at most 10 minutes for it.
    ///
    /// The per-link receiver sits behind a tokio `Mutex`, so concurrent receives from the same
    /// party on this endpoint are serialized. The timeout only covers waiting for the message,
    /// not waiting for the lock.
    ///
    /// # Panics
    /// Panics if `p` is out of range or there is no receiver for `p`.
    ///
    /// # Errors
    /// `AsyncRecvError::Closed` if the link is closed, `AsyncRecvError::TimeoutElapsed` if no
    /// message arrives in time.
    #[tracing::instrument(level = Level::TRACE, skip(self), fields(phase = ?_phase))]
    async fn recv_bytes_from(&self, p: usize, _phase: &str) -> Result<Vec<u8>, AsyncRecvError> {
        const RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
        let mut receiver = self.r[p]
            .as_ref()
            .unwrap_or_else(|| panic!("No receiver for party {p}"))
            .lock()
            .await;
        match tokio::time::timeout(RECV_TIMEOUT, receiver.recv()).await {
            Ok(Some(data)) => {
                let mb = data.len() as f64 / 1024.0 / 1024.0;
                trace!(size = mb, "Received data");
                Ok(data)
            }
            Ok(None) => Err(AsyncRecvError::Closed),
            Err(_) => Err(AsyncRecvError::TimeoutElapsed),
        }
    }
}