use crate::{PartyId, Result, SessionId};
use serde::{Serialize, de::DeserializeOwned};
pub use async_trait::async_trait;
pub mod memory;
pub use memory::MemoryRelay;
/// Message transport shared by the parties of a protocol session.
///
/// Every message is addressed by a [`SessionId`] and a `round` number, and is
/// either a broadcast or a direct message to a single [`PartyId`]. Payload types
/// are chosen per call: senders pass any `T: Serialize` and receivers request
/// any `T: DeserializeOwned`, so the wire encoding is left to the
/// implementation. NOTE(review): `MemoryRelay` (re-exported from `memory`) is
/// presumably the in-process implementation — confirm in `memory.rs`.
///
/// The `Send + Sync` supertraits allow one relay value to be shared between
/// async tasks. Because every method is generic over its payload type, the
/// trait is not dyn-compatible: take it as a generic `R: Relay` rather than
/// `dyn Relay`.
///
/// Timeout-bounded wrappers are provided by [`RelayExt`], which is implemented
/// for every `Relay`.
#[async_trait]
pub trait Relay: Send + Sync {
/// Broadcasts `message` for `round` of `session_id`.
///
/// `T: Sync` is required because the boxed future produced by `async_trait`
/// captures `message: &T` and must be `Send`, and `&T: Send` needs `T: Sync`.
///
/// # Errors
/// Implementation-defined. NOTE(review): document which failures (encoding,
/// transport) each implementation reports.
async fn broadcast<T: Serialize + Send + Sync>(
&self,
session_id: &SessionId,
round: u32,
message: &T,
) -> Result<()>;
/// Sends `message` for `round` of `session_id` addressed only to party `to`.
///
/// The `Sync` bound exists for the same reason as on [`Relay::broadcast`].
///
/// # Errors
/// Implementation-defined (see [`Relay::broadcast`]).
async fn send_direct<T: Serialize + Send + Sync>(
&self,
session_id: &SessionId,
round: u32,
to: PartyId,
message: &T,
) -> Result<()>;
/// Collects `count` broadcast payloads for `round` of `session_id`, decoded as `T`.
///
/// NOTE(review): presumably this waits until `count` messages are available,
/// and it has no built-in deadline (see `RelayExt` for a bounded variant).
/// Confirm whether the caller's own broadcast is included in `count`.
///
/// # Errors
/// Implementation-defined, e.g. a payload that fails to decode as `T` —
/// confirm per implementation.
async fn collect_broadcasts<T: DeserializeOwned + Send>(
&self,
session_id: &SessionId,
round: u32,
count: usize,
) -> Result<Vec<T>>;
/// Collects `count` direct messages addressed to `my_id` for `round` of
/// `session_id`, decoded as `T`.
///
/// NOTE(review): presumably this waits until `count` messages are available,
/// with no built-in deadline — confirm against implementations.
///
/// # Errors
/// Implementation-defined (see [`Relay::collect_broadcasts`]).
async fn collect_direct<T: DeserializeOwned + Send>(
&self,
session_id: &SessionId,
round: u32,
my_id: PartyId,
count: usize,
) -> Result<Vec<T>>;
}
#[async_trait]
pub trait RelayExt: Relay {
async fn broadcast_with_timeout<T: Serialize + Send + Sync>(
&self,
session_id: &SessionId,
round: u32,
message: &T,
timeout: std::time::Duration,
) -> Result<()>;
async fn collect_broadcasts_with_timeout<T: DeserializeOwned + Send>(
&self,
session_id: &SessionId,
round: u32,
count: usize,
timeout: std::time::Duration,
) -> Result<Vec<T>>;
}
#[async_trait]
impl<R: Relay + ?Sized> RelayExt for R {
async fn broadcast_with_timeout<T: Serialize + Send + Sync>(
&self,
session_id: &SessionId,
round: u32,
message: &T,
timeout: std::time::Duration,
) -> Result<()> {
tokio::time::timeout(timeout, self.broadcast(session_id, round, message))
.await
.map_err(|_| crate::Error::Timeout("broadcast".to_string()))?
}
async fn collect_broadcasts_with_timeout<T: DeserializeOwned + Send>(
&self,
session_id: &SessionId,
round: u32,
count: usize,
timeout: std::time::Duration,
) -> Result<Vec<T>> {
tokio::time::timeout(timeout, self.collect_broadcasts(session_id, round, count))
.await
.map_err(|_| crate::Error::Timeout("collect_broadcasts".to_string()))?
}
}