use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{Mutex, mpsc, oneshot};
use super::async_ibus_node::IBusMessage;
use crate::subscription::filter::SubscriptionFilter;
/// Reserved routing id: a broadcast addressed to this id is delivered to every
/// registered node. It can never be registered as a node id itself.
pub const BROADCAST_ID: &'static str = "@BROADCAST";
#[derive(Debug, Error)]
pub enum AsyncNodeRegistryError {
#[error("Node id {0} is already registered")]
NodeAlreadyExists(String),
#[error("{0} is an invalid node id.")]
InvalidNodeId(String),
#[error("The node {0} is not registered")]
UnknownNode(String),
#[error("Error receiving response from Registry Instance: {0}.")]
ReceiveError(tokio::sync::oneshot::error::RecvError),
#[error("Error sending response to node {0} Registry Instance: {1}.")]
SendError(String, String),
#[error("Error subscribing to node {0} with Error: {1}.")]
SubscriptionError(String, String),
#[error("Error unsubscribing to ID {0} with Error: {1}.")]
UnsubscriptionError(usize, String),
#[error("An unknown error occurred")]
Unknown,
}
/// One-shot channel on which the registry actor answers a control request.
type RegistryReply<T> = oneshot::Sender<Result<T, AsyncNodeRegistryError>>;

/// Requests sent from `AsyncNodeRegistryHandle` to the registry actor task.
///
/// Every variant except `Shutdown` carries a `respond_to` channel on which the
/// actor reports the outcome.
enum AsyncNodeRegistryControlMessage<S, R> {
    /// Register `sender` as the inbox of node `node_id`.
    RegisterNode {
        node_id: String,
        sender: mpsc::Sender<IBusMessage<S, R>>,
        respond_to: RegistryReply<()>,
    },
    /// Route `payload`. For broadcasts `node_id` is either `BROADCAST_ID` or
    /// the source node whose subscribers receive it; otherwise it is the
    /// destination node.
    SendMessage {
        node_id: String,
        payload: IBusMessage<S, R>,
        respond_to: RegistryReply<()>,
    },
    /// Obtain a clone of the inbox sender registered for `node_id`.
    RequestNodeChannel {
        node_id: String,
        respond_to: RegistryReply<mpsc::Sender<IBusMessage<S, R>>>,
    },
    /// Forward broadcasts from `source_node_id` that pass `filter` to
    /// `target_node_id`; replies with the new subscription id.
    Subscribe {
        source_node_id: String,
        target_node_id: String,
        filter: Box<dyn SubscriptionFilter<R> + Send + Sync>,
        respond_to: RegistryReply<usize>,
    },
    /// Cancel subscription `subscription_id`, which must target `target_node_id`.
    Unsubscribe {
        subscription_id: usize,
        target_node_id: String,
        respond_to: RegistryReply<usize>,
    },
    /// Stop the actor loop.
    Shutdown,
}
/// A single subscription stored under its source node's id.
struct SubscriptionItem<R> {
    /// Registry-assigned id, unique for the lifetime of the registry.
    id: usize,
    /// Node that receives matching broadcasts.
    target_node_id: String,
    /// Predicate deciding which broadcast payloads are forwarded.
    filter: Box<dyn SubscriptionFilter<R> + Send + Sync>,
}
/// State owned by the registry actor task.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct definition, so merely naming the type imposes no requirements.
struct AsyncNodeRegistry<S, R> {
    /// Registered nodes: node id -> sender for that node's inbox.
    nodes: HashMap<String, mpsc::Sender<IBusMessage<S, R>>>,
    /// Subscriptions keyed by the *source* node id whose broadcasts they follow.
    // NOTE(review): in this file only the actor task touches this map, so the
    // Arc<Mutex<..>> looks unnecessary — confirm nothing else shares it.
    subscriptions: Arc<Mutex<HashMap<String, Vec<SubscriptionItem<R>>>>>,
    /// Last subscription id handed out (ids start at 1 and only increase).
    last_subscription_id: usize,
}
impl<S, R> AsyncNodeRegistry<S, R>
where
R: Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
return Self {
nodes: HashMap::new(),
subscriptions: Arc::new(Mutex::new(HashMap::new())),
last_subscription_id: 0,
};
}
async fn handle_message(&mut self, msg: AsyncNodeRegistryControlMessage<S, R>) {
match msg {
AsyncNodeRegistryControlMessage::RegisterNode {
node_id,
sender,
respond_to,
} => {
if node_id == BROADCAST_ID {
let _ = respond_to.send(Err(AsyncNodeRegistryError::InvalidNodeId(node_id)));
return;
}
if !self.nodes.contains_key(node_id.as_str()) {
self.nodes.insert(node_id, sender);
let _ = respond_to.send(Ok(()));
} else {
let _ =
respond_to.send(Err(AsyncNodeRegistryError::NodeAlreadyExists(node_id)));
}
}
AsyncNodeRegistryControlMessage::SendMessage {
node_id,
payload,
respond_to,
} => match payload {
IBusMessage::Broadcast { payload } => {
if node_id == BROADCAST_ID {
for (_, node_sender) in &self.nodes {
let _ = node_sender.send(IBusMessage::Broadcast { payload: payload.clone() }).await;
}
}
else {
self.check_subscriptions(node_id.as_str(), &payload).await;
}
let _ = respond_to.send(Ok(()));
}
_ => match self.nodes.get(node_id.as_str()) {
Some(res) => match res.send(payload).await {
Ok(_) => {
let _ = respond_to.send(Ok(()));
}
Err(err) => {
let _ = respond_to.send(Err(AsyncNodeRegistryError::SendError(
node_id,
format!("{}", err),
)));
}
},
_ => {
let _ = respond_to.send(Err(AsyncNodeRegistryError::UnknownNode(node_id)));
}
},
},
AsyncNodeRegistryControlMessage::RequestNodeChannel {
node_id,
respond_to,
} => match self.nodes.get(node_id.as_str()) {
Some(res) => {
let _ = respond_to.send(Ok(res.clone()));
}
_ => {
let _ = respond_to.send(Err(AsyncNodeRegistryError::UnknownNode(node_id)));
}
},
AsyncNodeRegistryControlMessage::Subscribe {
source_node_id,
target_node_id,
filter,
respond_to,
} => {
let mut subscriptions = self.subscriptions.lock().await;
self.last_subscription_id += 1;
let item = SubscriptionItem {
id: self.last_subscription_id,
target_node_id,
filter: filter,
};
if let Some(subscribers) = subscriptions.get_mut(&source_node_id) {
subscribers.push(item);
} else {
subscriptions.insert(source_node_id, vec![item]);
}
let _ = respond_to.send(Ok(self.last_subscription_id));
}
AsyncNodeRegistryControlMessage::Unsubscribe {
subscription_id,
target_node_id,
respond_to,
} => {
let mut subscriptions = self.subscriptions.lock().await;
let mut found = false;
for (_, subscribers) in subscriptions.iter_mut() {
let mut index = 0;
for subscriber in subscribers.iter() {
if subscriber.id == subscription_id
&& subscriber.target_node_id == target_node_id
{
subscribers.remove(index);
found = true;
break;
}
index += 1;
}
}
if found {
let _ = respond_to.send(Ok(subscription_id));
} else {
let _ = respond_to.send(Err(AsyncNodeRegistryError::UnsubscriptionError(
subscription_id,
"Subscription not found".to_string(),
)));
}
}
_ => {}
}
}
async fn check_subscriptions(&self, node_id: &str, payload: &R) {
let mut removals = Vec::new();
{
let subscriptions = self.subscriptions.lock().await;
if let Some(subscribers) = subscriptions.get(node_id) {
for subscriber in subscribers.iter() {
if subscriber.filter.matches(&payload) {
match self.nodes.get(subscriber.target_node_id.as_str()) {
Some(res) => match res
.send(IBusMessage::Broadcast {
payload: payload.clone(),
})
.await
{
Ok(_) => {
}
Err(err) => {
log::error!(
"Error sending broadcast message to node {}: {}",
subscriber.target_node_id,
err
);
removals.push(subscriber.id);
}
},
_ => {
log::error!(
"Error sending broadcast message to node {}: Node not found",
subscriber.target_node_id
);
removals.push(subscriber.id);
}
}
}
}
}
}
if removals.len() > 0 {
for id in removals {
self.remove_subscription(id).await;
}
}
}
async fn remove_subscription(&self, id: usize) {
let mut subscriptions = self.subscriptions.lock().await;
for (_, subscribers) in subscriptions.iter_mut() {
let mut index = 0;
for subscriber in subscribers.iter() {
if subscriber.id == id {
subscribers.remove(index);
break;
}
index += 1;
}
}
}
}
async fn run_my_registry<S, R>(
mut registry: AsyncNodeRegistry<S, R>,
mut receiver: mpsc::Receiver<AsyncNodeRegistryControlMessage<S, R>>,
) where
R: Clone + Send + Sync + 'static,
{
loop {
tokio::select! {
msg = receiver.recv() => {
if let Some(m) = msg {
match m {
AsyncNodeRegistryControlMessage::Shutdown => {
break;
},
_ => {
let _ = registry.handle_message(m).await;
}
}
}
}
}
}
}
/// Cheaply cloneable front-end for a running registry actor; every method
/// sends a control message to the actor task and awaits its reply.
pub struct AsyncNodeRegistryHandle<S, R> {
    sender: mpsc::Sender<AsyncNodeRegistryControlMessage<S, R>>,
}

// Implemented by hand: `#[derive(Clone)]` would require `S: Clone` and
// `R: Clone`, but cloning only duplicates the channel sender, which needs neither.
impl<S, R> Clone for AsyncNodeRegistryHandle<S, R> {
    fn clone(&self) -> Self {
        Self {
            sender: self.sender.clone(),
        }
    }
}
impl<S, R> AsyncNodeRegistryHandle<S, R>
where
S: Clone + Send + Sync + 'static,
R: Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
let (sender, receiver) = mpsc::channel(32);
let actor = AsyncNodeRegistry::new();
tokio::spawn(run_my_registry(actor, receiver));
Self { sender: sender }
}
pub async fn register_node(
&self,
node_id: &str,
sender: mpsc::Sender<IBusMessage<S, R>>,
) -> Result<(), AsyncNodeRegistryError> {
let (send, recv) = oneshot::channel();
let msg = AsyncNodeRegistryControlMessage::RegisterNode {
node_id: node_id.to_string(),
sender: sender,
respond_to: send,
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(res) => {
return res;
}
Err(err) => {
return Err(AsyncNodeRegistryError::ReceiveError(err));
}
}
}
pub async fn send_message(
&self,
target_node_id: &str,
payload: IBusMessage<S, R>,
) -> Result<(), AsyncNodeRegistryError> {
let (send, recv) = oneshot::channel();
let msg = AsyncNodeRegistryControlMessage::SendMessage {
node_id: target_node_id.to_string(),
payload: payload,
respond_to: send,
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(res) => {
return res;
}
Err(err) => {
return Err(AsyncNodeRegistryError::ReceiveError(err));
}
}
}
pub async fn request(&self, target_node_id: &str, payload: S) -> Result<R, AsyncNodeRegistryError> {
let (res_send, res_recv) = oneshot::channel();
let ibus_msg :IBusMessage<S,R> = IBusMessage::Request {
payload: payload,
respond_to: res_send,
};
let (send, recv) = oneshot::channel();
let msg = AsyncNodeRegistryControlMessage::SendMessage {
node_id: target_node_id.to_string(),
payload: ibus_msg,
respond_to: send,
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(_) => {
match res_recv.await {
Ok(res) => {
return Ok(res);
}
Err(err) => {
return Err(AsyncNodeRegistryError::ReceiveError(err));
}
}
}
Err(err) => {
return Err(AsyncNodeRegistryError::ReceiveError(err));
}
}
}
pub async fn broadcast(
&self,
target_node_id: &str,
payload: R
) -> Result<(), AsyncNodeRegistryError> {
let (send, recv) = oneshot::channel();
let local_payload : IBusMessage<S,R> = IBusMessage::Broadcast { payload: payload };
let msg = AsyncNodeRegistryControlMessage::SendMessage {
node_id: target_node_id.to_string(),
payload: local_payload,
respond_to: send,
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(res) => {
return res;
}
Err(err) => {
return Err(AsyncNodeRegistryError::ReceiveError(err));
}
}
}
pub async fn subscribe(
&self,
source_node_id: &str,
target_node_id: &str,
filter: Box<dyn SubscriptionFilter<R> + Send + Sync>,
) -> Result<usize, AsyncNodeRegistryError> {
let (send, recv) = oneshot::channel();
let msg = AsyncNodeRegistryControlMessage::Subscribe {
source_node_id: source_node_id.to_string(),
target_node_id: target_node_id.to_string(),
filter: filter,
respond_to: send,
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(res) => match res {
Ok(id) => {
return Ok(id);
}
Err(err) => {
return Err(err);
}
},
Err(err) => {
return Err(AsyncNodeRegistryError::ReceiveError(err));
}
}
}
pub async fn unsubscribe(
&self,
subscription_id: usize,
target_node_id: &str,
filter: Box<dyn SubscriptionFilter<R> + Send + Sync>,
) -> Result<usize, AsyncNodeRegistryError> {
let (send, recv) = oneshot::channel();
let msg = AsyncNodeRegistryControlMessage::Unsubscribe {
subscription_id: subscription_id,
target_node_id: target_node_id.to_string(),
respond_to: send,
};
let _ = self.sender.send(msg).await;
match recv.await {
Ok(res) => match res {
Ok(id) => {
return Ok(id);
}
Err(err) => {
return Err(err);
}
},
Err(err) => {
return Err(AsyncNodeRegistryError::ReceiveError(err));
}
}
}
}