use async_trait::async_trait;
use serde::Deserialize;
use std::{any::Any, collections::HashMap, sync::Arc};
use tokio::sync::{mpsc, Mutex};
use crate::{
communication::{Pusher, PusherT, RecvEndpoint, SendEndpoint},
dataflow::{
graph::{Job, JobGraph},
stream::StreamId,
Data, Message, ReadStream, WriteStream,
},
node::NodeId,
scheduler::endpoints_manager::{ChannelsToReceivers, ChannelsToSenders},
OperatorId,
};
#[async_trait]
pub trait StreamEndpointsT: Send {
fn as_any(&mut self) -> &mut dyn Any;
fn add_inter_thread_channel(&mut self);
async fn add_inter_node_send_endpoint(
&mut self,
other_node_id: NodeId,
channels_to_senders: Arc<Mutex<ChannelsToSenders>>,
) -> Result<(), String>;
fn add_inter_node_recv_endpoint(
&mut self,
receiver_pushers: &mut HashMap<StreamId, Box<dyn PusherT>>,
) -> Result<(), String>;
}
pub struct StreamEndpoints<D>
where
for<'a> D: Data + Deserialize<'a>,
{
stream_id: StreamId,
recv_endpoints: Vec<RecvEndpoint<Arc<Message<D>>>>,
send_endpoints: Vec<SendEndpoint<Arc<Message<D>>>>,
}
impl<D> StreamEndpoints<D>
where
for<'a> D: Data + Deserialize<'a>,
{
pub fn new(stream_id: StreamId) -> Self {
Self {
stream_id,
recv_endpoints: Vec::new(),
send_endpoints: Vec::new(),
}
}
fn take_recv_endpoint(&mut self) -> Result<RecvEndpoint<Arc<Message<D>>>, &'static str> {
match self.recv_endpoints.pop() {
Some(recv_endpoint) => Ok(recv_endpoint),
None => Err("No more recv endpoints available"),
}
}
fn get_send_endpoints(&mut self) -> Result<Vec<SendEndpoint<Arc<Message<D>>>>, &'static str> {
let mut result: Vec<SendEndpoint<Arc<Message<D>>>> = Vec::new();
result.append(&mut self.send_endpoints);
Ok(result)
}
fn add_send_endpoint(&mut self, endpoint: SendEndpoint<Arc<Message<D>>>) {
self.send_endpoints.push(endpoint);
}
fn add_recv_endpoint(&mut self, endpoint: RecvEndpoint<Arc<Message<D>>>) {
self.recv_endpoints.push(endpoint);
}
}
#[async_trait]
impl<D> StreamEndpointsT for StreamEndpoints<D>
where
for<'a> D: Data + Deserialize<'a>,
{
fn as_any(&mut self) -> &mut dyn Any {
self
}
fn add_inter_thread_channel(&mut self) {
let (tx, rx) = mpsc::unbounded_channel();
self.add_send_endpoint(SendEndpoint::InterThread(tx));
self.add_recv_endpoint(RecvEndpoint::InterThread(rx));
}
async fn add_inter_node_send_endpoint(
&mut self,
other_node_id: NodeId,
channels_to_senders: Arc<Mutex<ChannelsToSenders>>,
) -> Result<(), String> {
let channels_to_senders = channels_to_senders.lock().await;
if let Some(tx) = channels_to_senders.clone_channel(other_node_id) {
self.add_send_endpoint(SendEndpoint::InterProcess(self.stream_id, tx));
Ok(())
} else {
Err(format!("Unable to clone channel to node {}", other_node_id))
}
}
fn add_inter_node_recv_endpoint(
&mut self,
receiver_pushers: &mut HashMap<StreamId, Box<dyn PusherT>>,
) -> Result<(), String> {
let pusher: &mut Box<dyn PusherT> = receiver_pushers
.entry(self.stream_id)
.or_insert_with(|| Box::new(Pusher::<Arc<Message<D>>>::new()));
if let Some(pusher) = pusher.as_any().downcast_mut::<Pusher<Arc<Message<D>>>>() {
let (tx, rx) = mpsc::unbounded_channel();
pusher.add_endpoint(SendEndpoint::InterThread(tx));
self.add_recv_endpoint(RecvEndpoint::InterThread(rx));
Ok(())
} else {
Err(format!(
"Error casting pusher when adding inter node recv endpoint for stream {}",
self.stream_id
))
}
}
}
pub(crate) struct ChannelManager {
node_id: NodeId,
stream_entries: HashMap<StreamId, Box<dyn StreamEndpointsT>>,
}
#[allow(dead_code)]
impl ChannelManager {
#[allow(clippy::needless_collect)]
pub async fn new(
job_graph: &JobGraph,
node_id: NodeId,
channels_to_receivers: Arc<Mutex<ChannelsToReceivers>>,
channels_to_senders: Arc<Mutex<ChannelsToSenders>>,
) -> Self {
let mut channel_manager = Self {
node_id,
stream_entries: HashMap::new(),
};
let mut receiver_pushers: HashMap<StreamId, Box<dyn PusherT>> = HashMap::new();
let local_operator_ids: Vec<OperatorId> = job_graph
.operators()
.into_iter()
.filter(|o| o.config.node_id == node_id)
.map(|o| o.id)
.collect();
let operators: HashMap<_, _> = job_graph
.operators()
.into_iter()
.map(|o| (o.config.id, o))
.collect();
for (stream, source, destinations) in job_graph.get_streams() {
let contains_source = match source {
Job::Operator(operator_id) => local_operator_ids.contains(&operator_id),
Job::Driver => node_id == 0,
};
if contains_source {
let stream_endpoint_t = channel_manager
.stream_entries
.entry(stream.id())
.or_insert_with(|| stream.to_stream_endpoints_t());
for destination in destinations {
let destination_node_id = match destination {
Job::Operator(operator_id) => {
operators.get(&operator_id).unwrap().config.node_id
}
Job::Driver => 0,
};
if destination_node_id == node_id {
stream_endpoint_t.add_inter_thread_channel();
} else {
stream_endpoint_t
.add_inter_node_send_endpoint(
destination_node_id,
channels_to_senders.clone(),
)
.await
.unwrap();
}
}
} else {
let contains_destination = destinations.iter().any(|destination| {
let destination_node_id = match destination {
Job::Operator(operator_id) => {
operators.get(operator_id).unwrap().config.node_id
}
Job::Driver => 0,
};
node_id == destination_node_id
});
if contains_destination {
let stream_endpoint_t = channel_manager
.stream_entries
.entry(stream.id())
.or_insert_with(|| stream.to_stream_endpoints_t());
stream_endpoint_t
.add_inter_node_recv_endpoint(&mut receiver_pushers)
.unwrap();
}
}
}
for (k, v) in receiver_pushers.into_iter() {
channels_to_receivers.lock().await.send(k, v);
}
channel_manager
}
pub fn node_id(&self) -> NodeId {
self.node_id
}
pub fn take_recv_endpoint<D>(
&mut self,
stream_id: StreamId,
) -> Result<RecvEndpoint<Arc<Message<D>>>, String>
where
for<'a> D: Data + Deserialize<'a>,
{
if let Some(stream_entry_t) = self.stream_entries.get_mut(&stream_id) {
if let Some(stream_entry) = stream_entry_t.as_any().downcast_mut::<StreamEndpoints<D>>()
{
match stream_entry.take_recv_endpoint() {
Ok(recv_endpoint) => Ok(recv_endpoint),
Err(msg) => Err(format!(
"Could not get recv endpoint with id {}: {}",
stream_id, msg
)),
}
} else {
Err(format!(
"Type mismatch for recv endpoint with ID {}",
stream_id
))
}
} else {
Err(format!("No recv endpoints found with ID {}", stream_id))
}
}
pub fn get_send_endpoints<D>(
&mut self,
stream_id: StreamId,
) -> Result<Vec<SendEndpoint<Arc<Message<D>>>>, String>
where
for<'a> D: Data + Deserialize<'a>,
{
if let Some(stream_entry_t) = self.stream_entries.get_mut(&stream_id) {
if let Some(stream_entry) = stream_entry_t.as_any().downcast_mut::<StreamEndpoints<D>>()
{
match stream_entry.get_send_endpoints() {
Ok(send_endpoints) => Ok(send_endpoints),
Err(msg) => Err(format!(
"Could not get recv endpoint with id {}: {}",
stream_id, msg
)),
}
} else {
Err(format!(
"Type mismatch for recv endpoint with ID {}",
stream_id
))
}
} else {
Err(format!("No recv endpoints found with ID {}", stream_id))
}
}
pub fn take_read_stream<D>(&mut self, stream_id: StreamId) -> Result<ReadStream<D>, String>
where
D: Data + for<'a> Deserialize<'a>,
{
self.take_recv_endpoint(stream_id)
.map(|endpoint| ReadStream::new(stream_id, &stream_id.to_string(), Some(endpoint)))
}
pub fn write_stream<D>(&mut self, stream_id: StreamId) -> Result<WriteStream<D>, String>
where
D: Data + for<'a> Deserialize<'a>,
{
self.get_send_endpoints(stream_id)
.map(|endpoints| WriteStream::from_endpoints(endpoints, stream_id))
}
}