use std::marker::PhantomData;
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use tracing::debug;
use crate::{
errorhandling::MalstromFatal,
types::{OperatorId, WorkerId},
};
/// Marker trait for values that can be both serialized and deserialized
/// into an owned value via `serde`.
///
/// It has no methods of its own; it only bundles the two `serde` bounds
/// under a single name.
pub trait Distributable: Serialize + DeserializeOwned {}

/// Blanket implementation: every type that is `Serialize + DeserializeOwned`
/// is automatically `Distributable`.
impl<T: Serialize + DeserializeOwned> Distributable for T {}
/// Backend capability for obtaining a transport between operators.
pub trait OperatorOperatorComm {
    /// Requests a bidirectional transport for `operator` on worker `to_worker`.
    ///
    /// # Errors
    /// Returns a [`CommunicationBackendError`] if the backend cannot provide
    /// the transport.
    fn operator_to_operator(
        &self,
        to_worker: WorkerId,
        operator: OperatorId,
    ) -> Result<Box<dyn BiStreamTransport>, CommunicationBackendError>;
}
/// Backend capability for obtaining a transport from the coordinator to a
/// worker.
pub trait CoordinatorWorkerComm {
    /// Requests a bidirectional transport to the worker identified by
    /// `worker_id`.
    ///
    /// # Errors
    /// Returns a [`CommunicationBackendError`] if the backend cannot provide
    /// the transport.
    fn coordinator_to_worker(
        &self,
        worker_id: WorkerId,
    ) -> Result<Box<dyn BiStreamTransport>, CommunicationBackendError>;
}
/// Backend capability for obtaining a transport from a worker to the
/// coordinator.
pub trait WorkerCoordinatorComm {
    /// Requests a bidirectional transport to the coordinator.
    ///
    /// # Errors
    /// Returns a [`CommunicationBackendError`] if the backend cannot provide
    /// the transport.
    fn worker_to_coordinator(
        &self,
    ) -> Result<Box<dyn BiStreamTransport>, CommunicationBackendError>;
}
/// A bidirectional, byte-oriented message transport.
///
/// Messages are opaque `Vec<u8>` payloads; any encoding is left to the
/// caller.
#[async_trait]
pub trait BiStreamTransport: Send + Sync {
    /// Sends a single message.
    fn send(&self, msg: Vec<u8>) -> Result<(), TransportError>;

    /// Receives a single message, if there is one.
    ///
    /// NOTE(review): `Ok(None)` presumably means "no message available right
    /// now" rather than "stream closed" — confirm against implementations.
    fn recv(&self) -> Result<Option<Vec<u8>>, TransportError>;

    /// Asynchronous receive that resolves to exactly one message or an error.
    async fn recv_async(&self) -> Result<Vec<u8>, TransportError>;

    /// Returns an iterator that repeatedly calls [`recv`](Self::recv).
    ///
    /// Every received message is yielded as `Ok(bytes)` and every transport
    /// error as `Err(_)`; the iterator returns `None` as soon as `recv`
    /// returns `Ok(None)`. An error does not end the iteration by itself.
    fn recv_all<'a>(&'a self) -> Box<dyn Iterator<Item = Result<Vec<u8>, TransportError>> + 'a> {
        let drain = std::iter::from_fn(move || match self.recv() {
            Ok(Some(bytes)) => Some(Ok(bytes)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        });
        Box::new(drain)
    }
}
/// Typed wrapper around a [`BiStreamTransport`].
///
/// `TSend` is the message type accepted by the sending methods and `TRecv`
/// the type produced by the receiving methods. Neither is stored; both are
/// only tracked at the type level.
pub struct CommunicationClient<TSend, TRecv> {
    /// Underlying byte transport used for all sends and receives.
    transport: Box<dyn BiStreamTransport>,
    /// Zero-sized marker tying the client to its message types.
    message_type: PhantomData<(TSend, TRecv)>,
}

/// A [`CommunicationClient`] that sends and receives the same message type.
pub type BiCommunicationClient<T> = CommunicationClient<T, T>;
impl<T: Distributable> CommunicationClient<T, T> {
    /// Creates a client for operator-to-operator communication.
    ///
    /// Emits a debug-level tracing event, then asks `backend` for a transport
    /// to `operator` on worker `to_worker` and wraps it.
    ///
    /// # Errors
    /// Propagates the [`CommunicationBackendError`] returned by the backend.
    pub(crate) fn new(
        to_worker: WorkerId,
        operator: OperatorId,
        backend: &dyn OperatorOperatorComm,
    ) -> Result<Self, CommunicationBackendError> {
        debug!(
            message = "Creating operator-operator communication client",
            ?to_worker,
            ?operator
        );
        backend
            .operator_to_operator(to_worker, operator)
            .map(|transport| Self {
                transport,
                message_type: PhantomData,
            })
    }
}
impl<TSend, TRecv> CommunicationClient<TSend, TRecv> {
    /// Creates a client using the transport `backend` provides for the worker
    /// identified by `worker_id`.
    ///
    /// # Errors
    /// Propagates the [`CommunicationBackendError`] returned by the backend.
    pub(crate) fn coordinator_to_worker(
        worker_id: WorkerId,
        backend: &dyn CoordinatorWorkerComm,
    ) -> Result<Self, CommunicationBackendError> {
        backend
            .coordinator_to_worker(worker_id)
            .map(|transport| Self {
                transport,
                message_type: PhantomData,
            })
    }

    /// Creates a client using the transport `backend` provides towards the
    /// coordinator.
    ///
    /// # Errors
    /// Propagates the [`CommunicationBackendError`] returned by the backend.
    pub(crate) fn worker_to_coordinator(
        backend: &dyn WorkerCoordinatorComm,
    ) -> Result<Self, CommunicationBackendError> {
        backend.worker_to_coordinator().map(|transport| Self {
            transport,
            message_type: PhantomData,
        })
    }
}
impl<TSend, TRecv> CommunicationClient<TSend, TRecv>
where
    TSend: Distributable,
{
    /// Encodes `msg` and hands the resulting bytes to the transport.
    ///
    /// A transport error is not returned to the caller; it is passed to
    /// `malstrom_fatal` (see the `errorhandling` module for its exact
    /// behaviour).
    pub fn send(&self, msg: TSend) {
        let payload = Self::encode(msg);
        self.transport.send(payload).malstrom_fatal()
    }

    /// Serializes `msg` into bytes with `rmp_serde`.
    ///
    /// A serialization failure is passed to `malstrom_fatal` instead of being
    /// returned.
    pub(crate) fn encode(msg: TSend) -> Vec<u8> {
        let encoded = rmp_serde::encode::to_vec(&msg);
        encoded.malstrom_fatal()
    }
}
impl<TSend, TRecv> CommunicationClient<TSend, TRecv>
where
    TRecv: Distributable,
{
    /// Receives and decodes one message, if the transport yields one.
    ///
    /// Returns `None` when the transport's `recv` returns `Ok(None)`.
    /// Transport and decoding errors are passed to `malstrom_fatal` rather
    /// than returned.
    pub fn recv(&self) -> Option<TRecv> {
        self.transport
            .recv()
            .malstrom_fatal()
            .map(|bytes| Self::decode(&bytes))
    }

    /// Awaits one message from the transport and decodes it.
    ///
    /// Transport and decoding errors are passed to `malstrom_fatal` rather
    /// than returned.
    pub async fn recv_async(&self) -> TRecv {
        let bytes = self.transport.recv_async().await.malstrom_fatal();
        Self::decode(&bytes)
    }

    /// Deserializes `msg` into a `TRecv` with `rmp_serde`.
    ///
    /// On failure the error is wrapped together with the name of the target
    /// type and passed to `malstrom_fatal`.
    pub(crate) fn decode(msg: &[u8]) -> TRecv {
        let decoded: Result<TRecv, _> = rmp_serde::decode::from_slice(msg);
        decoded
            .map_err(|err| DecodeError::Serde(err, std::any::type_name::<TRecv>()))
            .malstrom_fatal()
    }
}
/// Internal error describing a failed message deserialization.
#[derive(Debug, Error)]
enum DecodeError {
    /// Deserialization failed; the second field is the name of the type the
    /// bytes were expected to decode into, the first is the underlying
    /// `rmp_serde` error (exposed as the error source).
    #[error("Expected to decode to type {1}")]
    Serde(#[source] rmp_serde::decode::Error, &'static str),
}
pub fn broadcast<'a, TSend: Distributable + Clone + 'a, TRecv: 'a>(
clients: impl Iterator<Item = &'a CommunicationClient<TSend, TRecv>>,
msg: TSend,
) {
for c in clients {
c.send(msg.clone());
}
}
#[derive(thiserror::Error, Debug)]
pub enum CommunicationBackendError {
#[error("Error building Client: {0:?}")]
ClientBuildError(Box<dyn std::error::Error + Send + Sync>),
}
#[derive(thiserror::Error, Debug)]
pub enum TransportError {
#[error("Error sending message: {0}")]
SendError(#[from] Box<dyn std::error::Error + Send + Sync>),
#[error("Error receiving message: {0}")]
RecvError(Box<dyn std::error::Error + Send + Sync>),
}