use std::collections::HashMap;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use crate::node::NodeId;
use super::{CommunicationError, ControlMessage};
/// Routes [`ControlMessage`]s between a handler-owned inbox and per-node
/// outbound channels.
///
/// The handler owns one unbounded channel (`tx`/`rx`) on which it receives
/// messages, plus five maps of outbound senders keyed by [`NodeId`], one map
/// per kind of peer registered through the `add_channel_to_*` methods.
pub struct ControlMessageHandler {
/// Sending half of the handler's own channel; clones of it are handed out
/// by `get_channel_to_handler`, and the `read_*` helpers use it to re-queue
/// messages they are not interested in.
tx: UnboundedSender<ControlMessage>,
/// Receiving half of the handler's own channel; drained by `read`.
rx: UnboundedReceiver<ControlMessage>,
/// Outbound channels registered via `add_channel_to_control_sender`.
channels_to_control_senders: HashMap<NodeId, UnboundedSender<ControlMessage>>,
/// Outbound channels registered via `add_channel_to_control_receiver`.
channels_to_control_receivers: HashMap<NodeId, UnboundedSender<ControlMessage>>,
/// Outbound channels registered via `add_channel_to_data_sender`.
channels_to_data_senders: HashMap<NodeId, UnboundedSender<ControlMessage>>,
/// Outbound channels registered via `add_channel_to_data_receiver`.
channels_to_data_receivers: HashMap<NodeId, UnboundedSender<ControlMessage>>,
/// Outbound channels registered via `add_channel_to_node`.
channels_to_nodes: HashMap<NodeId, UnboundedSender<ControlMessage>>,
}
#[allow(dead_code)]
impl ControlMessageHandler {
/// Creates a handler with a fresh unbounded inbox channel and no
/// registered outbound channels.
pub fn new() -> Self {
    let (inbox_tx, inbox_rx) = mpsc::unbounded_channel();
    Self {
        tx: inbox_tx,
        rx: inbox_rx,
        channels_to_control_senders: HashMap::default(),
        channels_to_control_receivers: HashMap::default(),
        channels_to_data_senders: HashMap::default(),
        channels_to_data_receivers: HashMap::default(),
        channels_to_nodes: HashMap::default(),
    }
}
/// Registers `tx` as the channel to the control sender for `node_id`.
///
/// If a channel was already registered for that node it is replaced and an
/// error is logged.
pub fn add_channel_to_control_sender(
    &mut self,
    node_id: NodeId,
    tx: UnboundedSender<ControlMessage>,
) {
    let overwrote_existing = self
        .channels_to_control_senders
        .insert(node_id, tx)
        .is_some();
    if overwrote_existing {
        tracing::error!(
            "ControlMessageHandler: overwrote channel to control sender for node {}",
            node_id
        );
    }
}
/// Sends `msg` to the control sender registered for `node_id`.
///
/// # Errors
///
/// Returns [`CommunicationError::Disconnected`] when no channel is
/// registered for `node_id`; a failed send is converted through
/// `CommunicationError::from`.
pub fn send_to_control_sender(
    &mut self,
    node_id: NodeId,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    let channel = self
        .channels_to_control_senders
        .get(&node_id)
        .ok_or(CommunicationError::Disconnected)?;
    channel.send(msg).map_err(CommunicationError::from)
}
/// Sends a clone of `msg` to every registered control sender.
///
/// # Errors
///
/// Stops at the first failed send and returns that error; channels that
/// come later in the map's iteration order do not receive the message.
pub fn broadcast_to_control_senders(
    &mut self,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    self.channels_to_control_senders
        .values()
        .try_for_each(|channel| channel.send(msg.clone()).map_err(CommunicationError::from))
}
/// Registers `tx` as the channel to the control receiver for `node_id`.
///
/// If a channel was already registered for that node it is replaced and an
/// error is logged.
pub fn add_channel_to_control_receiver(
    &mut self,
    node_id: NodeId,
    tx: UnboundedSender<ControlMessage>,
) {
    let overwrote_existing = self
        .channels_to_control_receivers
        .insert(node_id, tx)
        .is_some();
    if overwrote_existing {
        tracing::error!(
            "ControlMessageHandler: overwrote channel to control receiver for node {}",
            node_id
        );
    }
}
/// Sends `msg` to the control receiver registered for `node_id`.
///
/// # Errors
///
/// Returns [`CommunicationError::Disconnected`] when no channel is
/// registered for `node_id`; a failed send is converted through
/// `CommunicationError::from`.
pub fn send_to_control_receiver(
    &mut self,
    node_id: NodeId,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    let channel = self
        .channels_to_control_receivers
        .get(&node_id)
        .ok_or(CommunicationError::Disconnected)?;
    channel.send(msg).map_err(CommunicationError::from)
}
/// Sends a clone of `msg` to every registered control receiver.
///
/// # Errors
///
/// Stops at the first failed send and returns that error; channels that
/// come later in the map's iteration order do not receive the message.
pub fn broadcast_to_control_receivers(
    &mut self,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    self.channels_to_control_receivers
        .values()
        .try_for_each(|channel| channel.send(msg.clone()).map_err(CommunicationError::from))
}
/// Registers `tx` as the channel to the data sender for `node_id`.
///
/// If a channel was already registered for that node it is replaced and an
/// error is logged.
pub fn add_channel_to_data_sender(
    &mut self,
    node_id: NodeId,
    tx: UnboundedSender<ControlMessage>,
) {
    let overwrote_existing = self
        .channels_to_data_senders
        .insert(node_id, tx)
        .is_some();
    if overwrote_existing {
        tracing::error!(
            "ControlMessageHandler: overwrote channel to data sender for node {}",
            node_id
        );
    }
}
/// Sends `msg` to the data sender registered for `node_id`.
///
/// # Errors
///
/// Returns [`CommunicationError::Disconnected`] when no channel is
/// registered for `node_id`; a failed send is converted through
/// `CommunicationError::from`.
pub fn send_to_data_sender(
    &mut self,
    node_id: NodeId,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    let channel = self
        .channels_to_data_senders
        .get(&node_id)
        .ok_or(CommunicationError::Disconnected)?;
    channel.send(msg).map_err(CommunicationError::from)
}
/// Sends a clone of `msg` to every registered data sender.
///
/// # Errors
///
/// Stops at the first failed send and returns that error; channels that
/// come later in the map's iteration order do not receive the message.
pub fn broadcast_to_data_senders(
    &mut self,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    self.channels_to_data_senders
        .values()
        .try_for_each(|channel| channel.send(msg.clone()).map_err(CommunicationError::from))
}
/// Registers `tx` as the channel to the data receiver for `node_id`.
///
/// If a channel was already registered for that node it is replaced and an
/// error is logged.
pub fn add_channel_to_data_receiver(
    &mut self,
    node_id: NodeId,
    tx: UnboundedSender<ControlMessage>,
) {
    let overwrote_existing = self
        .channels_to_data_receivers
        .insert(node_id, tx)
        .is_some();
    if overwrote_existing {
        tracing::error!(
            "ControlMessageHandler: overwrote channel to data receiver for node {}",
            node_id
        );
    }
}
/// Sends `msg` to the data receiver registered for `node_id`.
///
/// # Errors
///
/// Returns [`CommunicationError::Disconnected`] when no channel is
/// registered for `node_id`; a failed send is converted through
/// `CommunicationError::from`.
pub fn send_to_data_receiver(
    &mut self,
    node_id: NodeId,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    let channel = self
        .channels_to_data_receivers
        .get(&node_id)
        .ok_or(CommunicationError::Disconnected)?;
    channel.send(msg).map_err(CommunicationError::from)
}
/// Sends a clone of `msg` to every registered data receiver.
///
/// # Errors
///
/// Stops at the first failed send and returns that error; channels that
/// come later in the map's iteration order do not receive the message.
pub fn broadcast_to_data_receivers(
    &mut self,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    self.channels_to_data_receivers
        .values()
        .try_for_each(|channel| channel.send(msg.clone()).map_err(CommunicationError::from))
}
pub fn add_channel_to_node(&mut self, node_id: NodeId, tx: UnboundedSender<ControlMessage>) {
self.channels_to_nodes.insert(node_id, tx);
}
/// Sends `msg` to the node registered under `node_id`.
///
/// # Errors
///
/// Returns [`CommunicationError::Disconnected`] when no channel is
/// registered for `node_id`; a failed send is converted through
/// `CommunicationError::from`.
pub fn send_to_node(
    &mut self,
    node_id: NodeId,
    msg: ControlMessage,
) -> Result<(), CommunicationError> {
    let channel = self
        .channels_to_nodes
        .get(&node_id)
        .ok_or(CommunicationError::Disconnected)?;
    channel.send(msg).map_err(CommunicationError::from)
}
/// Sends a clone of `msg` to every registered node channel.
///
/// # Errors
///
/// Stops at the first failed send and returns that error; channels that
/// come later in the map's iteration order do not receive the message.
pub fn broadcast_to_nodes(&mut self, msg: ControlMessage) -> Result<(), CommunicationError> {
    self.channels_to_nodes
        .values()
        .try_for_each(|channel| channel.send(msg.clone()).map_err(CommunicationError::from))
}
pub fn get_channel_to_handler(&self) -> UnboundedSender<ControlMessage> {
self.tx.clone()
}
/// Waits for the next message on the handler's own channel.
///
/// # Errors
///
/// Returns [`CommunicationError::Disconnected`] when the receiver yields
/// `None`.
pub async fn read(&mut self) -> Result<ControlMessage, CommunicationError> {
    match self.rx.recv().await {
        Some(msg) => Ok(msg),
        None => Err(CommunicationError::Disconnected),
    }
}
/// Reads messages until a `ControlMessage::AllOperatorsInitializedOnNode`
/// arrives and returns the node id it carries.
///
/// Every other message read along the way is set aside and, once the loop
/// ends, sent back through `self.tx` in the order it was read so that later
/// reads see it again.
///
/// # Errors
///
/// Returns the error from [`read`](Self::read) if reading fails. A failure
/// while re-queuing a set-aside message is returned instead of the read
/// outcome.
pub async fn read_all_operators_initialized_on_node_msg(
    &mut self,
) -> Result<NodeId, CommunicationError> {
    let mut deferred = Vec::new();
    let outcome = loop {
        match self.read().await {
            Ok(ControlMessage::AllOperatorsInitializedOnNode(node_id)) => break Ok(node_id),
            Ok(other) => deferred.push(other),
            Err(e) => break Err(e),
        }
    };
    // Re-queue on both the success and the error path.
    for pending in deferred {
        self.tx.send(pending).map_err(CommunicationError::from)?;
    }
    outcome
}
/// Reads messages until one of the four "initialized" notifications
/// (`ControlSenderInitialized`, `ControlReceiverInitialized`,
/// `DataSenderInitialized`, `DataReceiverInitialized`) arrives and returns
/// that message.
///
/// Every other message read along the way is set aside and, once the loop
/// ends, sent back through `self.tx` in the order it was read so that later
/// reads see it again.
///
/// # Errors
///
/// Returns the error from [`read`](Self::read) if reading fails. A failure
/// while re-queuing a set-aside message is returned instead of the read
/// outcome.
pub async fn read_sender_or_receiver_initialized(
    &mut self,
) -> Result<ControlMessage, CommunicationError> {
    let mut deferred = Vec::new();
    let outcome = loop {
        match self.read().await {
            Ok(control_msg) => {
                // Classify the message in place; no clone is needed just to
                // look at its variant.
                let is_initialized_msg = matches!(
                    control_msg,
                    ControlMessage::ControlSenderInitialized(_)
                        | ControlMessage::ControlReceiverInitialized(_)
                        | ControlMessage::DataSenderInitialized(_)
                        | ControlMessage::DataReceiverInitialized(_)
                );
                if is_initialized_msg {
                    break Ok(control_msg);
                }
                deferred.push(control_msg);
            }
            Err(e) => break Err(e),
        }
    };
    // Re-queue on both the success and the error path.
    for pending in deferred {
        self.tx.send(pending).map_err(CommunicationError::from)?;
    }
    outcome
}
}