use std::time::Duration;
use thiserror::Error;
use crate::channel::{
self, Receiver, RecvError, RecvTimeoutError, SelectResult, Sender, TryRecvError,
};
use crate::network::{NetworkMessage, ReceiverEndpoint};
use crate::operator::ExchangeData;
use crate::profiler::{get_profiler, Profiler};
/// Capacity passed to `channel::bounded` when building an in-process channel in
/// [`local_channel`].
const CHANNEL_CAPACITY: usize = 16;

/// Create a connected sender/receiver pair addressed to `receiver_endpoint`.
///
/// Both halves share one in-process channel created with
/// `channel::bounded(CHANNEL_CAPACITY)`. The sender wraps it as `SenderInner::Local`, so
/// messages are pushed directly, without the endpoint tag used by mux senders.
pub(crate) fn local_channel<T: ExchangeData>(
    receiver_endpoint: ReceiverEndpoint,
) -> (NetworkSender<T>, NetworkReceiver<T>) {
    let (tx, rx) = channel::bounded(CHANNEL_CAPACITY);

    let sending_half = NetworkSender {
        receiver_endpoint,
        sender: SenderInner::Local(tx),
    };
    let receiving_half = NetworkReceiver {
        receiver_endpoint,
        receiver: rx,
    };

    (sending_half, receiving_half)
}
/// Wrap the sending side of a shared, multiplexed channel into a [`NetworkSender`] for
/// `receiver_endpoint`.
///
/// Messages sent through the returned sender are pushed into `tx` as a
/// `(receiver_endpoint, message)` pair (see `NetworkSender::send`), so that whoever drains
/// `tx` can tell which endpoint each message is addressed to.
pub(crate) fn mux_sender<T: ExchangeData>(
    receiver_endpoint: ReceiverEndpoint,
    tx: Sender<(ReceiverEndpoint, NetworkMessage<T>)>,
) -> NetworkSender<T> {
    let inner = SenderInner::Mux(tx);
    NetworkSender {
        sender: inner,
        receiver_endpoint,
    }
}
/// Receiving half of a network channel, bound to a single `ReceiverEndpoint`.
///
/// Wraps a `crate::channel::Receiver` of `NetworkMessage<In>`. The receive methods in the
/// `impl` below delegate to that receiver.
#[derive(Derivative)]
#[derivative(Debug)]
pub(crate) struct NetworkReceiver<In: Send + 'static> {
    /// Endpoint this receiver is bound to; its `coord` is used when reporting incoming items
    /// to the profiler.
    pub receiver_endpoint: ReceiverEndpoint,
    /// Underlying channel receiver; excluded from the derived `Debug` output.
    #[derivative(Debug = "ignore")]
    receiver: Receiver<NetworkMessage<In>>,
}
impl<In: Send + 'static> NetworkReceiver<In> {
    /// Report a successfully received message to the profiler as incoming items, then return
    /// the result unchanged.
    ///
    /// An `Err` is passed straight through and nothing is reported for it.
    #[inline]
    fn profile_message<E>(
        &self,
        message: Result<NetworkMessage<In>, E>,
    ) -> Result<NetworkMessage<In>, E> {
        if let Ok(received) = &message {
            get_profiler().items_in(
                received.sender,
                self.receiver_endpoint.coord,
                received.num_items(),
            );
        }
        message
    }

    /// Receive a message via the underlying channel's `recv`, reporting it to the profiler on
    /// success.
    pub fn recv(&self) -> Result<NetworkMessage<In>, RecvError> {
        let outcome = self.receiver.recv();
        self.profile_message(outcome)
    }

    /// Receive a message via the underlying channel's `try_recv`, reporting it to the profiler
    /// on success.
    pub fn try_recv(&self) -> Result<NetworkMessage<In>, TryRecvError> {
        let outcome = self.receiver.try_recv();
        self.profile_message(outcome)
    }

    /// Receive a message via the underlying channel's `recv_timeout` with the given `timeout`,
    /// reporting it to the profiler on success.
    pub fn recv_timeout(&self, timeout: Duration) -> Result<NetworkMessage<In>, RecvTimeoutError> {
        let outcome = self.receiver.recv_timeout(timeout);
        self.profile_message(outcome)
    }

    /// Delegate to the underlying channel's `select` over this receiver and `other`.
    ///
    /// NOTE(review): unlike `recv`/`try_recv`/`recv_timeout`, messages obtained here are not
    /// passed through `profile_message`, so they are not reported to the profiler. Confirm
    /// whether this is intentional.
    pub fn select<In2: ExchangeData>(
        &self,
        other: &NetworkReceiver<In2>,
    ) -> SelectResult<NetworkMessage<In>, NetworkMessage<In2>> {
        let (mine, theirs) = (&self.receiver, &other.receiver);
        mine.select(theirs)
    }

    /// Delegate to the underlying channel's `select_timeout` over this receiver and `other`,
    /// with the given `timeout`.
    ///
    /// NOTE(review): as with `select`, received messages are not reported to the profiler here.
    pub fn select_timeout<In2: ExchangeData>(
        &self,
        other: &NetworkReceiver<In2>,
        timeout: Duration,
    ) -> Result<SelectResult<NetworkMessage<In>, NetworkMessage<In2>>, RecvTimeoutError> {
        let (mine, theirs) = (&self.receiver, &other.receiver);
        mine.select_timeout(theirs, timeout)
    }
}
/// Sending half of a network channel, addressed to a single `ReceiverEndpoint`.
///
/// The actual transport is either a dedicated local channel or a shared multiplexed channel;
/// see `SenderInner`.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub(crate) struct NetworkSender<Out: Send + 'static> {
    /// Destination endpoint of every message sent through this sender.
    pub receiver_endpoint: ReceiverEndpoint,
    /// Underlying transport; excluded from the derived `Debug` output.
    #[derivative(Debug = "ignore")]
    sender: SenderInner<Out>,
}
/// Transport backing a `NetworkSender`.
#[derive(Clone)]
enum SenderInner<Out: Send + 'static> {
    /// Shared channel: each message is sent together with its destination endpoint.
    Mux(Sender<(ReceiverEndpoint, NetworkMessage<Out>)>),
    /// Dedicated channel: messages are sent as-is.
    Local(Sender<NetworkMessage<Out>>),
}
impl<Out: ExchangeData> NetworkSender<Out> {
    /// Send `message` to `self.receiver_endpoint`.
    ///
    /// The message's items are reported to the profiler as outgoing *before* the send is
    /// attempted, so they are counted even when the send fails. A mux sender pushes the pair
    /// `(receiver_endpoint, message)`; a local sender pushes the message alone.
    ///
    /// # Errors
    ///
    /// Returns [`NetworkSendError::Disconnected`] with the destination endpoint if the
    /// underlying channel's `send` returns an error.
    pub fn send(&self, message: NetworkMessage<Out>) -> Result<(), NetworkSendError> {
        let destination = self.receiver_endpoint;
        get_profiler().items_out(message.sender, destination.coord, message.num_items());

        let delivered = match &self.sender {
            SenderInner::Mux(mux) => mux.send((destination, message)).is_ok(),
            SenderInner::Local(local) => local.send(message).is_ok(),
        };

        if delivered {
            Ok(())
        } else {
            Err(NetworkSendError::Disconnected(destination))
        }
    }

    /// Return a clone of the underlying local channel sender.
    ///
    /// # Panics
    ///
    /// Panics if this sender is backed by a multiplexed channel (`SenderInner::Mux`).
    pub fn clone_inner(&self) -> Sender<NetworkMessage<Out>> {
        match &self.sender {
            SenderInner::Local(local) => local.clone(),
            SenderInner::Mux(_) => panic!("Trying to clone mux channel. Not supported"),
        }
    }
}
/// Error returned by `NetworkSender::send`.
#[derive(Debug, Error)]
pub enum NetworkSendError {
    /// The send on the underlying channel failed. Carries the endpoint the message was
    /// addressed to; note the `Display` message does not include it.
    #[error("channel disconnected")]
    Disconnected(ReceiverEndpoint),
}