#[cfg(any(feature = "state_store", feature = "azure_device_registry"))]
#[allow(dead_code)]
pub mod dispatcher {
use std::{collections::HashMap, fmt::Debug, hash::Hash, sync::Mutex};
use thiserror::Error;
use tokio::sync::mpsc::{
UnboundedReceiver, UnboundedSender, error::SendError, unbounded_channel,
};
/// Receiving half of a channel handed out by `Dispatcher::register_receiver`.
pub type Receiver<T> = UnboundedReceiver<T>;
#[derive(Error, Debug)]
pub enum RegisterError {
#[error("receiver with id {0} already registered")]
AlreadyRegistered(String),
}
#[derive(PartialEq, Eq, Clone, Error, Debug)]
#[error("{kind}")]
pub struct DispatchError<T> {
pub data: T,
pub kind: DispatchErrorKind,
}
#[derive(Debug, Eq, PartialEq, Clone, Error)]
pub enum DispatchErrorKind {
#[error("Failed to send message")]
SendError,
#[error("Receiver with ID '{0}' not found")]
NotFound(String), }
impl<T> From<SendError<T>> for DispatchError<T> {
fn from(err: SendError<T>) -> Self {
Self {
data: err.0,
kind: DispatchErrorKind::SendError,
}
}
}
/// Routes messages of type `T` to per-id unbounded channels keyed by `H`.
///
/// Each successful [`Dispatcher::register_receiver`] call creates a new channel: the sending half
/// is kept in a mutex-protected map under the given id and the receiving half is returned.
pub struct Dispatcher<T, H>
where
    H: Eq + Hash + Debug + Clone,
{
    /// Sending halves of the registered channels, keyed by receiver id.
    tx_map: Mutex<HashMap<H, UnboundedSender<T>>>,
}

// Written by hand instead of `#[derive(Default)]`: the derive would add `T: Default` and
// `H: Default` bounds, which an empty map does not need and most message/id types lack.
impl<T, H> Default for Dispatcher<T, H>
where
    H: Eq + Hash + Debug + Clone,
{
    /// Creates a dispatcher with no registered receivers.
    fn default() -> Self {
        Self {
            tx_map: Mutex::new(HashMap::new()),
        }
    }
}
impl<T, H> Dispatcher<T, H>
where
H: Eq + Hash + Debug + Clone,
{
pub fn new() -> Self {
Self {
tx_map: Mutex::new(HashMap::new()),
}
}
pub fn register_receiver(&self, receiver_id: H) -> Result<Receiver<T>, RegisterError> {
let mut tx_map = self.tx_map.lock().unwrap();
if tx_map.get(&receiver_id).is_some() {
return Err(RegisterError::AlreadyRegistered(format!("{receiver_id:?}")));
}
let (tx, rx) = unbounded_channel();
tx_map.insert(receiver_id, tx);
Ok(rx)
}
pub fn unregister_receiver(&self, receiver_id: &H) -> bool {
self.tx_map.lock().unwrap().remove(receiver_id).is_some()
}
pub fn unregister_all(&self) -> usize {
self.tx_map.lock().unwrap().drain().count()
}
pub fn dispatch(&self, receiver_id: &H, message: T) -> Result<(), DispatchError<T>> {
if let Some(tx) = self.tx_map.lock().unwrap().get(receiver_id) {
Ok(tx.send(message)?)
} else {
Err(DispatchError {
data: message,
kind: DispatchErrorKind::NotFound(format!("{receiver_id:?}")),
})
}
}
pub fn get_all_receiver_ids(&self) -> Vec<H> {
let tx_map = self.tx_map.lock().unwrap();
tx_map.keys().cloned().collect()
}
}
}