use std::{collections::HashMap, sync::Arc};
use futures::{future, stream::SplitStream, FutureExt};
use futures_util::stream::StreamExt;
use tokio::{
net::TcpStream,
sync::{
mpsc::{self, UnboundedReceiver, UnboundedSender},
Mutex,
},
};
use tokio_util::codec::Framed;
use crate::{
communication::{
CommunicationError, ControlMessage, ControlMessageCodec, ControlMessageHandler,
InterProcessMessage, MessageCodec, PusherT,
},
dataflow::stream::StreamId,
node::NodeId,
scheduler::endpoints_manager::ChannelsToReceivers,
};
/// Reads serialized data messages from one framed TCP connection and hands
/// their payloads to the pusher registered for each message's stream id.
pub(crate) struct DataReceiver {
    /// Node this receiver is associated with; reported to the control handler
    /// on start-up (presumably the peer on the other end of `stream`).
    node_id: NodeId,
    /// Receiving half of the framed TCP connection.
    stream: SplitStream<Framed<TcpStream, MessageCodec>>,
    /// Channel on which new `(StreamId, pusher)` registrations arrive.
    rx: UnboundedReceiver<(StreamId, Box<dyn PusherT>)>,
    /// Pushers registered so far, keyed by the stream they deliver to.
    stream_id_to_pusher: HashMap<StreamId, Box<dyn PusherT>>,
    /// Sender used to notify the control message handler.
    control_tx: UnboundedSender<ControlMessage>,
    /// Receiving end of the channel registered with the control handler.
    // Not read anywhere yet, but holding it keeps that channel open (dropping
    // it would make the handler's sends to this receiver fail).
    #[allow(dead_code)]
    control_rx: UnboundedReceiver<ControlMessage>,
}

impl DataReceiver {
    /// Creates a receiver for `node_id` that reads from `stream`.
    ///
    /// Registers a pusher-registration sender with `channels_to_receivers` and a
    /// control channel with `control_handler`. It also keeps a sender back to the
    /// control handler.
    pub(crate) async fn new(
        node_id: NodeId,
        stream: SplitStream<Framed<TcpStream, MessageCodec>>,
        channels_to_receivers: Arc<Mutex<ChannelsToReceivers>>,
        control_handler: &mut ControlMessageHandler,
    ) -> Self {
        // Channel over which `(StreamId, pusher)` pairs are delivered to us.
        let (pusher_tx, rx) = mpsc::unbounded_channel();
        channels_to_receivers.lock().await.add_sender(pusher_tx);
        // Channel over which the control handler can reach this receiver.
        let (handler_to_receiver_tx, control_rx) = mpsc::unbounded_channel();
        control_handler.add_channel_to_data_receiver(node_id, handler_to_receiver_tx);
        Self {
            node_id,
            stream,
            rx,
            stream_id_to_pusher: HashMap::new(),
            control_tx: control_handler.get_channel_to_handler(),
            control_rx,
        }
    }

    /// Announces initialization to the control handler, then forwards the bytes
    /// of every message read from the TCP stream to the pusher registered for
    /// the message's stream id. Returns `Ok(())` once the TCP stream ends.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - notifying the control handler fails;
    /// - reading or decoding a message from the stream fails;
    /// - a pusher fails to accept the message bytes.
    ///
    /// # Panics
    ///
    /// Panics in two cases:
    /// - a message arrives in `Deserialized` form;
    /// - a message arrives for a stream id with no registered pusher.
    pub(crate) async fn run(&mut self) -> Result<(), CommunicationError> {
        self.control_tx
            .send(ControlMessage::DataReceiverInitialized(self.node_id))?;
        while let Some(res) = self.stream.next().await {
            let msg = res?;
            // Pick up any pushers registered since the previous message before
            // looking one up for this message.
            self.update_pushers().await;
            let (metadata, bytes) = match msg {
                InterProcessMessage::Serialized { metadata, bytes } => (metadata, bytes),
                // Treated as impossible: messages decoded by `MessageCodec` are
                // presumably always in serialized form.
                InterProcessMessage::Deserialized { .. } => unreachable!(),
            };
            match self.stream_id_to_pusher.get_mut(&metadata.stream_id) {
                Some(pusher) => {
                    pusher.send_from_bytes(bytes)?;
                }
                None => panic!(
                    "Receiver does not have any pushers. \
                     Race condition during data-flow reconfiguration."
                ),
            }
        }
        Ok(())
    }

    /// Moves every pusher registration already queued on `rx` into
    /// `stream_id_to_pusher`, without waiting for new ones.
    async fn update_pushers(&mut self) {
        // `now_or_never` polls `recv` exactly once. `Some(Some(_))` is a ready
        // registration. `None` (nothing queued) and `Some(None)` (channel closed)
        // both end the loop.
        while let Some(Some((stream_id, pusher))) = self.rx.recv().now_or_never() {
            self.stream_id_to_pusher.insert(stream_id, pusher);
        }
    }
}
/// Runs all data receivers concurrently until every one of them has finished.
///
/// Every receiver is driven to completion even if some fail, so one broken
/// connection does not stop the others from being served.
///
/// # Errors
///
/// Once all receivers have finished, returns the first error (in `receivers`
/// order) that any of them reported.
pub(crate) async fn run_receivers(
    mut receivers: Vec<DataReceiver>,
) -> Result<(), CommunicationError> {
    let results =
        future::join_all(receivers.iter_mut().map(|receiver| receiver.run())).await;
    // Surface receiver failures to the caller instead of discarding them.
    for result in results {
        result?;
    }
    Ok(())
}
#[allow(dead_code)]
pub(crate) struct ControlReceiver {
node_id: NodeId,
stream: SplitStream<Framed<TcpStream, ControlMessageCodec>>,
control_tx: UnboundedSender<ControlMessage>,
control_rx: UnboundedReceiver<ControlMessage>,
}
impl ControlReceiver {
pub(crate) fn new(
node_id: NodeId,
stream: SplitStream<Framed<TcpStream, ControlMessageCodec>>,
control_handler: &mut ControlMessageHandler,
) -> Self {
let (tx, control_rx) = tokio::sync::mpsc::unbounded_channel();
control_handler.add_channel_to_control_receiver(node_id, tx);
Self {
node_id,
stream,
control_tx: control_handler.get_channel_to_handler(),
control_rx,
}
}
pub(crate) async fn run(&mut self) -> Result<(), CommunicationError> {
self.control_tx
.send(ControlMessage::ControlReceiverInitialized(self.node_id))
.map_err(CommunicationError::from)?;
while let Some(res) = self.stream.next().await {
match res {
Ok(msg) => {
self.control_tx
.send(msg)
.map_err(CommunicationError::from)?;
}
Err(e) => return Err(CommunicationError::from(e)),
}
}
Ok(())
}
}
/// Runs all control receivers concurrently until every one of them has finished.
///
/// Every receiver is driven to completion even if some fail, so one broken
/// connection does not stop the others from being served.
///
/// # Errors
///
/// Once all receivers have finished, returns the first error (in `receivers`
/// order) that any of them reported.
pub(crate) async fn run_control_receivers(
    mut receivers: Vec<ControlReceiver>,
) -> Result<(), CommunicationError> {
    let results =
        future::join_all(receivers.iter_mut().map(|receiver| receiver.run())).await;
    // Surface receiver failures to the caller instead of discarding them.
    for result in results {
        result?;
    }
    Ok(())
}