use futures::future::BoxFuture;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
/// Asynchronous message handler for a pipe.
///
/// It is called once per received message of type `T` and returns a boxed `'static` future.
/// The pipe's worker task awaits that future before it takes the next message.
pub type PipeCallback<T> = Box<dyn Fn(T) -> BoxFuture<'static, ()> + Send + Sync>;
/// Type-erased sending half of a pipe's channel.
///
/// The concrete message type is recovered by downcasting when a message is sent.
type ErasedSender = Box<dyn Any + Send + Sync>;

/// Registry of named, typed message pipes.
///
/// Each registered name maps to the sending half of that pipe's channel. The map sits behind an
/// async `RwLock`: sends only need shared access, while registration needs exclusive access to
/// insert.
pub struct PipeManager {
    senders: RwLock<HashMap<String, ErasedSender>>,
}
impl Default for PipeManager {
fn default() -> Self {
Self::new()
}
}
impl PipeManager {
    /// Creates a manager with no registered pipes.
    pub fn new() -> Self {
        Self {
            senders: RwLock::new(HashMap::new()),
        }
    }

    /// Registers a pipe called `name` whose messages of type `T` are handled by `callback`.
    ///
    /// A background Tokio task is spawned. It receives messages from an unbounded channel and
    /// awaits `callback` for each one, so a pipe's messages are handled one at a time, in the
    /// order they were sent.
    ///
    /// # Errors
    ///
    /// Returns `Err` if a pipe with the same `name` is already registered. The existing pipe is
    /// left untouched and `callback` is dropped without being called.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
    pub async fn register<T>(&self, name: &str, callback: PipeCallback<T>) -> Result<(), String>
    where
        T: Send + 'static,
    {
        // A single write guard covers both the existence check and the insert. No other
        // registration can run between them, so no separate read-lock pre-check is needed.
        let mut map = self.senders.write().await;
        if map.contains_key(name) {
            return Err(format!(
                "Pipe registration failed: name '{}' is already in use",
                name
            ));
        }

        let (tx, mut rx) = mpsc::unbounded_channel::<T>();
        // The task owns `callback` outright, so no shared ownership is needed.
        // The only sender is the one stored in the map below. The loop therefore ends when
        // this manager is dropped, or earlier if `callback` panics and takes the task down.
        tokio::spawn(async move {
            while let Some(message) = rx.recv().await {
                callback(message).await;
            }
        });
        map.insert(name.to_string(), Box::new(tx));
        Ok(())
    }

    /// Queues `message` on the pipe registered under `name`.
    ///
    /// The message is pushed onto the pipe's unbounded channel and handled later by the
    /// pipe's worker task. This method does not wait for the callback to run.
    ///
    /// # Errors
    ///
    /// * `"Pipe not registered"` if no pipe is registered under `name`.
    /// * `"Type mismatch"` if the pipe was registered with a message type other than `T`.
    /// * The channel's send-error text if the pipe's worker task is no longer receiving
    ///   (for example because its callback panicked).
    pub async fn send<T>(&self, name: &str, message: T) -> Result<(), String>
    where
        T: Send + 'static,
    {
        let map = self.senders.read().await;
        let erased = map
            .get(name)
            .ok_or_else(|| String::from("Pipe not registered"))?;
        // The stored value is an `UnboundedSender<T>` for whatever `T` the pipe was
        // registered with; downcasting fails if the caller's `T` differs.
        let tx = erased
            .downcast_ref::<mpsc::UnboundedSender<T>>()
            .ok_or_else(|| String::from("Type mismatch"))?;
        tx.send(message).map_err(|e| e.to_string())
    }
}