use futures::future::BoxFuture;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, broadcast};
/// Async handler invoked with each message delivered to a subscription.
///
/// The handler must be `Send + Sync`. It returns a boxed `'static` future, so that
/// future must own everything it uses rather than borrow from the caller.
pub type SpreadCallback<T> = Box<dyn Fn(T) -> BoxFuture<'static, ()> + Send + Sync>;
/// Registry of named, typed publish/subscribe hubs.
///
/// Each hub name maps to a type-erased broadcast sender. The message type of a hub is
/// fixed when the hub is created. Later calls with a different type are rejected
/// after a failed downcast.
#[derive(Debug)]
pub struct SpreadManager {
    /// Hub name -> boxed `broadcast::Sender<T>` for that hub's message type `T`.
    hubs: RwLock<HashMap<String, Box<dyn Any + Send + Sync>>>,
}
impl Default for SpreadManager {
fn default() -> Self {
Self::new()
}
}
impl SpreadManager {
pub fn new() -> Self {
Self {
hubs: RwLock::new(HashMap::new()),
}
}
pub async fn subscribe<T>(&self, name: &str, callback: SpreadCallback<T>) -> Result<(), String>
where
T: Clone + Send + Sync + 'static,
{
let mut rx = {
let mut map = self.hubs.write().await;
let tx = if let Some(any) = map.get(name) {
any.downcast_ref::<broadcast::Sender<T>>()
.ok_or_else(|| format!("Spread '{}' type mismatch", name))?
.clone()
} else {
let (tx, _) = broadcast::channel::<T>(1024);
map.insert(name.to_string(), Box::new(tx.clone()));
tx
};
tx.subscribe()
};
let callback = Arc::new(callback);
tokio::spawn(async move {
while let Ok(message) = rx.recv().await {
let cb = Arc::clone(&callback);
cb(message).await;
}
});
Ok(())
}
pub async fn publish<T>(&self, name: &str, message: T) -> Result<(), String>
where
T: Clone + Send + Sync + 'static,
{
let map = self.hubs.read().await;
if let Some(any) = map.get(name) {
if let Some(tx) = any.downcast_ref::<broadcast::Sender<T>>() {
let _ = tx.send(message);
Ok(())
} else {
Err(format!("Spread '{}' type mismatch", name))
}
} else {
Ok(())
}
}
}