use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use ark_ec::CurveGroup;
use async_trait::async_trait;
use futures::{future::pending, Future, Sink, Stream};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use crate::{error::MpcNetworkError, PARTY0};
use super::{MpcNetwork, NetworkOutbound, PartyId};
#[derive(Default)]
pub struct NoRecvNetwork<C: CurveGroup>(PhantomData<C>);
#[async_trait]
impl<C: CurveGroup> MpcNetwork<C> for NoRecvNetwork<C> {
    /// There is no underlying connection to tear down, so closing always succeeds.
    async fn close(&mut self) -> Result<(), MpcNetworkError> {
        Ok(())
    }

    /// This mock always identifies as [`PARTY0`].
    fn party_id(&self) -> PartyId {
        PARTY0
    }
}
impl<C: CurveGroup> Stream for NoRecvNetwork<C> {
    type Item = Result<NetworkOutbound<C>, MpcNetworkError>;

    /// Always returns `Poll::Pending`: this network never delivers a message.
    ///
    /// No waker is registered, so this stream never wakes the polling task.
    /// This matches polling `futures::future::pending()`, but without
    /// heap-allocating a boxed future on every call.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Pending
    }
}
impl<C: CurveGroup> Sink<NetworkOutbound<C>> for NoRecvNetwork<C> {
    type Error = MpcNetworkError;

    /// Accepts the message and discards it; nothing is ever transmitted.
    fn start_send(self: Pin<&mut Self>, _item: NetworkOutbound<C>) -> Result<(), Self::Error> {
        Ok(())
    }

    /// The sink never applies backpressure, so it is always ready.
    fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Nothing is buffered, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// There is no resource to release, so closing completes immediately.
    fn poll_close(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
}
/// An in-memory, bidirectional message pipe built from two tokio unbounded
/// channels. Instances are created in connected pairs by `new_duplex_pair`.
pub struct UnboundedDuplexStream<C: CurveGroup> {
    /// Outgoing half: messages sent here are received by the paired stream.
    send: UnboundedSender<NetworkOutbound<C>>,
    /// Incoming half: messages the paired stream sends arrive here.
    recv: UnboundedReceiver<NetworkOutbound<C>>,
}
impl<C: CurveGroup> UnboundedDuplexStream<C> {
    /// Creates two connected duplex streams.
    ///
    /// Each stream's sending half feeds the other stream's receiving half, so a
    /// message sent on one end is received on the other.
    pub fn new_duplex_pair() -> (Self, Self) {
        // One channel per direction: first -> second and second -> first.
        let (first_to_second_send, first_to_second_recv) = unbounded_channel();
        let (second_to_first_send, second_to_first_recv) = unbounded_channel();

        let first = Self {
            send: first_to_second_send,
            recv: second_to_first_recv,
        };
        let second = Self {
            send: second_to_first_send,
            recv: first_to_second_recv,
        };
        (first, second)
    }

    /// Sends `msg` to the paired stream without blocking.
    ///
    /// # Panics
    /// Panics if the paired stream (which owns this channel's receiving half)
    /// has been dropped.
    pub fn send(&mut self, msg: NetworkOutbound<C>) {
        self.send
            .send(msg)
            .expect("UnboundedDuplexStream::send: paired stream was dropped");
    }

    /// Waits for the next message from the paired stream.
    ///
    /// # Panics
    /// Panics if the paired stream (which owns this channel's sending half) has
    /// been dropped and no buffered messages remain.
    pub async fn recv(&mut self) -> NetworkOutbound<C> {
        self.recv
            .recv()
            .await
            .expect("UnboundedDuplexStream::recv: paired stream was dropped")
    }
}
/// A mock network that exchanges messages with a peer over an in-memory
/// [`UnboundedDuplexStream`].
pub struct MockNetwork<C: CurveGroup> {
    /// The party id this network reports via `party_id`.
    party_id: PartyId,
    /// The in-memory connection to the peer.
    mock_conn: UnboundedDuplexStream<C>,
}
impl<C: CurveGroup> MockNetwork<C> {
    /// Builds a mock network for `party_id` that talks to its peer through
    /// `stream`, typically one half of `UnboundedDuplexStream::new_duplex_pair`.
    pub fn new(party_id: PartyId, stream: UnboundedDuplexStream<C>) -> Self {
        let mock_conn = stream;
        Self { party_id, mock_conn }
    }
}
#[async_trait]
impl<C: CurveGroup> MpcNetwork<C> for MockNetwork<C> {
    /// Closing the mock connection is a no-op and always succeeds.
    async fn close(&mut self) -> Result<(), MpcNetworkError> {
        Ok(())
    }

    /// Returns the party id supplied at construction.
    fn party_id(&self) -> PartyId {
        let id = self.party_id;
        id
    }
}
impl<C: CurveGroup> Stream for MockNetwork<C> {
    type Item = Result<NetworkOutbound<C>, MpcNetworkError>;

    /// Yields the next message sent by the peer.
    ///
    /// Returns `Poll::Ready(None)` once the peer's sending half has been dropped
    /// and all buffered messages are drained, ending the stream. This follows the
    /// `Stream` contract instead of panicking.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Poll the receiver directly rather than boxing a fresh `recv()` future on
        // every call. `poll_recv` registers `cx`'s waker, so the task is woken when
        // a message arrives.
        self.mock_conn.recv.poll_recv(cx).map(|msg| msg.map(Ok))
    }
}
impl<C: CurveGroup> Sink<NetworkOutbound<C>> for MockNetwork<C> {
    type Error = MpcNetworkError;

    /// Forwards `item` to the peer immediately over the in-memory connection.
    ///
    /// # Panics
    /// Panics if the peer has been dropped, because the underlying
    /// `UnboundedDuplexStream::send` unwraps the channel send result.
    fn start_send(mut self: Pin<&mut Self>, item: NetworkOutbound<C>) -> Result<(), Self::Error> {
        let conn = &mut self.mock_conn;
        conn.send(item);
        Ok(())
    }

    /// The underlying channel is unbounded, so the sink is always ready.
    fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Messages are handed to the channel in `start_send`, so there is nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Closing the sink is a no-op and completes immediately.
    fn poll_close(self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
}