use async_trait::async_trait;
use std::pin::Pin;
use std::sync::Arc;
use std::collections::HashMap;
use tokio::sync::Mutex;
use super::subscription_handler::{
SubscriptionHandler, SubscriptionResponse, SubscriptionResponseTx,
};
use super::filter::{DefaultSubscriptionFilter, SubscriptionFilter};
pub type BroadcastCallback<R> = Box<
dyn Fn(String, R) -> Pin<Box<dyn Future<Output = Result<(), anyhow::Error>> + Send + Sync>>
+ Send
+ Sync,
>;
/// A single registered subscription on some topic.
struct SubscriptionItem<R> {
    /// Id sent back to the subscriber on subscribe; `unsubscribe` removes items by this id.
    id: usize,
    /// Predicate checked against each message broadcast on this item's topic.
    filter: Box<dyn SubscriptionFilter<R> + Send + Sync>,
}
/// Keeps a per-topic list of subscriptions and forwards broadcast messages
/// through a user-supplied [`BroadcastCallback`].
pub struct SubscriptionManager<R>
where
    R: Clone + Send + Sync + 'static,
{
    /// Topic name -> subscriptions registered on that topic. `unsubscribe`
    /// removes a topic once its last subscription is gone.
    subscriptions: Arc<Mutex<HashMap<String, Vec<SubscriptionItem<R>>>>>,
    /// Last id handed out; the next subscription is assigned `id_count + 1`.
    id_count: usize,
    /// Called with `(topic, message)` when a broadcast matches at least one
    /// subscriber's filter on that topic.
    broadcast_callback: BroadcastCallback<R>,
}
impl<R> SubscriptionManager<R>
where
    R: Clone + Send + Sync + 'static,
{
    /// Creates a manager with no subscriptions and an id counter of 0.
    ///
    /// Because ids are assigned as `id_count + 1`, the first subscription
    /// receives id 1. `broadcast_callback` is stored and used to deliver
    /// messages that pass a subscriber's filter.
    pub fn new(broadcast_callback: BroadcastCallback<R>) -> Self {
        Self {
            subscriptions: Arc::new(Mutex::new(HashMap::new())),
            id_count: 0,
            broadcast_callback,
        }
    }
}
#[async_trait]
impl<R> SubscriptionHandler<R> for SubscriptionManager<R>
where
    R: Clone + Send + Sync + 'static,
{
    /// Subscribes to `topic` using the [`DefaultSubscriptionFilter`].
    ///
    /// Delegates to `subscribe_with_filter`, so id allocation, logging and the
    /// response sent through `respond_to` are identical for both entry points.
    async fn subscribe(&mut self, topic: String, respond_to: SubscriptionResponseTx) {
        self.subscribe_with_filter(topic, Box::new(DefaultSubscriptionFilter), respond_to)
            .await;
    }

    /// Registers `filter` under `topic` with a freshly allocated id.
    ///
    /// The id (`id_count + 1`) is sent back through `respond_to`. The result of
    /// that send is discarded.
    async fn subscribe_with_filter(
        &mut self,
        topic: String,
        filter: Box<dyn SubscriptionFilter<R> + Send + Sync>,
        respond_to: SubscriptionResponseTx,
    ) {
        let id = self.id_count + 1;
        log::info!(
            "Subcription Manager subscribing to topic: {} with id {}",
            topic,
            id
        );
        let mut subscriptions = self.subscriptions.lock().await;
        // `entry` hands back the topic's Vec, so the count needs no second lookup.
        let subscribers = subscriptions.entry(topic.clone()).or_default();
        subscribers.push(SubscriptionItem { id, filter });
        log::info!(
            "Subcription Manager topic: {} has {} subscribers",
            topic,
            subscribers.len()
        );
        // Only commit the counter once the item is actually stored.
        self.id_count = id;
        let _ = respond_to.send(SubscriptionResponse { id });
    }

    /// Removes the subscription with `id` from every topic and drops topics
    /// that end up with no subscribers.
    ///
    /// `id` is echoed back through `respond_to` whether or not it was found.
    /// An unknown id is logged as a warning.
    async fn unsubscribe(&mut self, id: usize, respond_to: SubscriptionResponseTx) {
        let mut subscriptions = self.subscriptions.lock().await;
        let mut removed = false;
        // Single pass: filter out the id, log the remaining count, and prune
        // the topic if it became empty.
        subscriptions.retain(|topic, subscribers| {
            let before = subscribers.len();
            subscribers.retain(|subscriber| subscriber.id != id);
            removed |= subscribers.len() != before;
            log::info!(
                "Subcription Manager topic: {} has {} subscribers",
                topic,
                subscribers.len()
            );
            !subscribers.is_empty()
        });
        if !removed {
            log::warn!(
                "Subscription Manager unsubscribe: no subscription with id {}",
                id
            );
        }
        let _ = respond_to.send(SubscriptionResponse { id });
    }

    /// Delivers `message` on `topic` through the broadcast callback if at least
    /// one subscriber's filter on that topic matches it.
    ///
    /// The callback is invoked at most once per call. Returns `Ok(())` when
    /// nothing matches; otherwise returns the callback's result.
    async fn broadcast(&self, topic: String, message: R) -> Result<(), anyhow::Error> {
        log::debug!("Subcription Manager checking topic: {}", topic);
        // Decide under the lock, but release the guard before awaiting the
        // callback. Otherwise concurrent `broadcast` calls (which only need
        // `&self`) would be serialized behind each other's delivery.
        let should_broadcast = {
            let subscriptions = self.subscriptions.lock().await;
            // Diagnostic dump of every topic; skipped entirely unless debug is on.
            if log::log_enabled!(log::Level::Debug) {
                for (k, v) in subscriptions.iter() {
                    log::debug!(
                        "Subcription Manager topic: {} has {} subscribers",
                        k,
                        v.len()
                    );
                    if k == &topic {
                        log::debug!("k {} == topic {}", k, topic);
                    } else {
                        log::debug!("k {} != topic {}", k, topic);
                    }
                }
            }
            match subscriptions.get(&topic) {
                // `any` stops at the first matching filter, like the original early return.
                Some(subscribers) => subscribers.iter().any(|subscriber| {
                    log::debug!("Subcription Checking filter on: {}", topic);
                    subscriber.filter.matches(&message)
                }),
                None => {
                    log::debug!("Subcription Manager no subscribers for topic: {}", topic);
                    false
                }
            }
        };
        if !should_broadcast {
            return Ok(());
        }
        log::debug!("Subcription Manager broadcasting topic: {}", topic);
        // The lock is released here, so topic/message can be moved rather than cloned.
        (self.broadcast_callback)(topic, message).await
    }
}