use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use ark_ec::CurveGroup;
use crossbeam::queue::SegQueue;
use futures::stream::SplitSink;
use futures::SinkExt;
use futures::{stream::SplitStream, StreamExt};
use kanal::AsyncReceiver as KanalReceiver;
use tokio::sync::broadcast::Receiver as BroadcastReceiver;
use tracing::log;
use crate::error::MpcNetworkError;
use crate::network::{MpcNetwork, NetworkOutbound};
use super::executor::ExecutorMessage;
use super::result::OpResult;
/// Message carried by the `MpcNetworkError::RecvError` that the read and write
/// loops return when they fall out of their receive loop.
const ERR_STREAM_FINISHED_EARLY: &str = "stream finished early";
/// Running totals of network traffic, stored as atomics so the counters can be
/// updated through a shared reference (e.g. behind an `Arc`).
///
/// All counters start at zero via the derived `Default`.
#[derive(Debug, Default)]
pub struct NetworkStats {
/// Cumulative number of bytes recorded via `increment_bytes_sent`
pub bytes_sent: AtomicUsize,
/// Cumulative number of bytes recorded via `increment_bytes_received`
pub bytes_received: AtomicUsize,
/// Number of times `increment_messages_sent` has been called
pub messages_sent: AtomicUsize,
/// Number of times `increment_messages_received` has been called
pub messages_received: AtomicUsize,
}
#[allow(unused)]
impl NetworkStats {
pub fn increment_bytes_sent(&self, bytes: usize) {
self.bytes_sent
.fetch_add(bytes, std::sync::atomic::Ordering::SeqCst);
}
pub fn increment_bytes_received(&self, bytes: usize) {
self.bytes_received
.fetch_add(bytes, std::sync::atomic::Ordering::SeqCst);
}
pub fn increment_messages_sent(&self) {
self.messages_sent
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
pub fn increment_messages_received(&self) {
self.messages_received
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}
/// Bridges an `MpcNetwork` connection with the rest of the crate: outbound
/// messages are pulled from a channel and written to the network, and messages
/// read from the network are pushed onto a shared result queue.
pub(crate) struct NetworkSender<C: CurveGroup, N: MpcNetwork<C>> {
/// Receiving end of the channel carrying messages to write to the network
outbound: KanalReceiver<NetworkOutbound<C>>,
/// Shared queue onto which each message read from the network is pushed as
/// an `ExecutorMessage::Result`
result_queue: Arc<SegQueue<ExecutorMessage<C>>>,
/// The underlying network connection; split into a sink and a stream in `run`
network: N,
/// Broadcast receiver polled by `run`; `run` returns once a receive on it
/// resolves
shutdown: BroadcastReceiver<()>,
}
impl<C: CurveGroup, N: MpcNetwork<C> + 'static> NetworkSender<C, N> {
pub fn new(
outbound: KanalReceiver<NetworkOutbound<C>>,
result_queue: Arc<SegQueue<ExecutorMessage<C>>>,
network: N,
shutdown: BroadcastReceiver<()>,
) -> Self {
NetworkSender {
outbound,
result_queue,
network,
shutdown,
}
}
pub async fn run(self) {
let NetworkSender {
outbound,
result_queue,
network,
mut shutdown,
} = self;
let stats = Arc::new(NetworkStats::default());
let (send, recv): (SplitSink<N, NetworkOutbound<C>>, SplitStream<N>) = network.split();
let read_loop_fut = tokio::spawn(Self::read_loop(recv, result_queue, stats.clone()));
let write_loop_fut = tokio::spawn(Self::write_loop(outbound, send, stats.clone()));
tokio::select! {
err = read_loop_fut => {
log::error!("error in `NetworkSender::read_loop`: {err:?}");
},
err = write_loop_fut => {
log::error!("error in `NetworkSender::write_loop`: {err:?}")
},
_ = shutdown.recv() => {
log::info!("received shutdown signal")
},
}
#[cfg(feature = "stats")]
println!("Network stats: {:#?}", stats);
}
async fn read_loop(
mut network_stream: SplitStream<N>,
result_queue: Arc<SegQueue<ExecutorMessage<C>>>,
#[allow(unused)] stats: Arc<NetworkStats>,
) -> MpcNetworkError {
while let Some(Ok(msg)) = network_stream.next().await {
#[cfg(feature = "stats")]
{
let n_bytes = serde_json::to_vec(&msg).unwrap().len();
stats.increment_bytes_received(n_bytes);
stats.increment_messages_received();
}
result_queue.push(ExecutorMessage::Result(OpResult {
id: msg.result_id,
value: msg.payload.into(),
}));
}
MpcNetworkError::RecvError(ERR_STREAM_FINISHED_EARLY.to_string())
}
async fn write_loop(
outbound_stream: KanalReceiver<NetworkOutbound<C>>,
mut network: SplitSink<N, NetworkOutbound<C>>,
#[allow(unused)] stats: Arc<NetworkStats>,
) -> MpcNetworkError {
while let Ok(msg) = outbound_stream.recv().await {
#[cfg(feature = "stats")]
{
let n_bytes = serde_json::to_vec(&msg).unwrap().len();
stats.increment_bytes_sent(n_bytes);
stats.increment_messages_sent();
}
if let Err(e) = network.send(msg).await {
log::error!("error sending outbound: {e:?}");
return e;
}
}
MpcNetworkError::RecvError(ERR_STREAM_FINISHED_EARLY.to_string())
}
}