use crate::Result;
use async_trait::async_trait;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// A serializable, owned message value that can report the topic it belongs to.
pub trait Event: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static {
    /// Returns the topic string associated with this event.
    fn topic(&self) -> String;
}
/// An asynchronous message handler that can be registered with a
/// [`SubscriptionRegistry`].
#[async_trait]
pub trait Subscriber: Send + Sync + 'static {
    /// The topic pattern this subscriber listens on.
    ///
    /// Published topics are compared against it with `topic_matches`, which splits
    /// on `/` and understands the `+` and `#` wildcards. The registry indexes the
    /// subscriber under the value returned when it is registered.
    fn topic_pattern(&self) -> &str;

    /// Handles one published message.
    ///
    /// `topic` is the concrete topic the message was published on; `payload` is the
    /// raw byte buffer handed to `SubscriptionRegistry::publish`. An error returned
    /// here is not propagated back to the publisher.
    async fn handle(&mut self, topic: &str, payload: Vec<u8>) -> Result<()>;
}
/// A blocking (non-async) message handler.
///
/// Wrap an implementation in [`SyncSubscriberAdapter`] to use it where a
/// [`Subscriber`] is required.
pub trait SyncSubscriber: Send + Sync + 'static {
    /// The topic pattern this subscriber listens on (same syntax as
    /// [`Subscriber::topic_pattern`]).
    fn topic_pattern(&self) -> &str;

    /// Handles one published message synchronously.
    fn handle_sync(&mut self, topic: &str, payload: Vec<u8>) -> Result<()>;
}
/// Handle to a live registration in a [`SubscriptionRegistry`].
///
/// Dropping the handle unsubscribes it (see the `Drop` impl for `Subscription`),
/// so it must be kept alive for as long as messages should be delivered.
/// `#[must_use]` turns an accidentally discarded handle, e.g.
/// `registry.subscribe(s).await?;`, into a compiler warning instead of a silent
/// immediate unsubscribe.
#[must_use = "dropping a Subscription unsubscribes it immediately"]
pub struct Subscription {
    /// Identifier generated by the registry when the subscriber was registered.
    pub id: Uuid,
    /// The pattern the subscriber was registered under.
    pub topic_pattern: String,
    /// Registry this subscription is removed from on unsubscribe or drop.
    registry: Arc<SubscriptionRegistry>,
}
impl Subscription {
pub async fn unsubscribe(self) -> Result<()> {
self.registry.unsubscribe(&self.id).await
}
}
impl Drop for Subscription {
fn drop(&mut self) {
let id = self.id;
let registry = self.registry.clone();
tokio::spawn(async move {
let _ = registry.unsubscribe(&id).await;
});
}
}
pub struct SubscriptionRegistry {
subscribers: Arc<DashMap<Uuid, Box<dyn Subscriber>>>,
topic_subscriptions: Arc<RwLock<HashMap<String, Vec<Uuid>>>>,
}
impl SubscriptionRegistry {
pub fn new() -> Self {
Self {
subscribers: Arc::new(DashMap::new()),
topic_subscriptions: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn subscribe<S: Subscriber>(&self, subscriber: S) -> Result<Subscription> {
let id = Uuid::new_v4();
let topic_pattern = subscriber.topic_pattern().to_string();
self.subscribers.insert(id, Box::new(subscriber));
let mut topics = self.topic_subscriptions.write().await;
topics
.entry(topic_pattern.clone())
.or_insert_with(Vec::new)
.push(id);
Ok(Subscription {
id,
topic_pattern,
registry: Arc::new(self.clone()),
})
}
pub async fn unsubscribe(&self, id: &Uuid) -> Result<()> {
if let Some((_, subscriber)) = self.subscribers.remove(id) {
let topic_pattern = subscriber.topic_pattern();
let mut topics = self.topic_subscriptions.write().await;
if let Some(subs) = topics.get_mut(topic_pattern) {
subs.retain(|sub_id| sub_id != id);
if subs.is_empty() {
topics.remove(topic_pattern);
}
}
}
Ok(())
}
pub async fn publish(&self, topic: &str, payload: Vec<u8>) -> Result<()> {
let topics = self.topic_subscriptions.read().await;
let mut matching_ids = Vec::new();
for (pattern, ids) in topics.iter() {
if topic_matches(topic, pattern) {
matching_ids.extend(ids.iter().copied());
}
}
drop(topics);
for id in matching_ids {
if let Some(mut subscriber) = self.subscribers.get_mut(&id) {
let _ = subscriber.handle(topic, payload.clone()).await;
}
}
Ok(())
}
}
impl Clone for SubscriptionRegistry {
fn clone(&self) -> Self {
Self {
subscribers: self.subscribers.clone(),
topic_subscriptions: self.topic_subscriptions.clone(),
}
}
}
impl Default for SubscriptionRegistry {
fn default() -> Self {
Self::new()
}
}
/// Returns `true` if `topic` matches the subscription `pattern`.
///
/// Both strings are split into levels on `/`. Matching follows the MQTT
/// topic-filter rules:
///
/// - `+` matches exactly one level, which may be empty.
/// - `#` matches all remaining levels, including none. So `"a/#"` matches both `"a"`
///   and `"a/b/c"`, and `"#"` matches every topic.
/// - Every other level must equal the corresponding topic level.
///
/// A `#` ends matching wherever it appears; the pattern is not checked to have `#`
/// as its last level.
fn topic_matches(topic: &str, pattern: &str) -> bool {
    if topic == pattern {
        return true;
    }
    let mut topic_levels = topic.split('/');
    let mut pattern_levels = pattern.split('/');
    loop {
        match (pattern_levels.next(), topic_levels.next()) {
            // Multi-level wildcard: matches the rest of the topic, or nothing at all.
            (Some("#"), _) => return true,
            // Single-level wildcard: consumes exactly one topic level.
            (Some("+"), Some(_)) => {}
            (Some(expected), Some(actual)) if expected != actual => return false,
            (Some(_), Some(_)) => {}
            // Both exhausted together: every level matched.
            (None, None) => return true,
            // One side has levels left over.
            _ => return false,
        }
    }
}
/// Adapts a blocking [`SyncSubscriber`] to the async [`Subscriber`] interface.
///
/// The trait bound lives on the `impl` blocks rather than on the struct, which is
/// idiomatic and keeps the type nameable without repeating the bound everywhere.
pub struct SyncSubscriberAdapter<S> {
    /// The wrapped blocking subscriber.
    inner: S,
}
impl<S: SyncSubscriber> SyncSubscriberAdapter<S> {
    /// Wraps `subscriber` so it can be used wherever a [`Subscriber`] is expected.
    pub fn new(subscriber: S) -> Self {
        let inner = subscriber;
        Self { inner }
    }
}
#[async_trait]
impl<S: SyncSubscriber> Subscriber for SyncSubscriberAdapter<S> {
    /// Forwards to the wrapped subscriber's pattern.
    fn topic_pattern(&self) -> &str {
        self.inner.topic_pattern()
    }

    /// Delivers the message to the wrapped subscriber's `handle_sync`.
    ///
    /// The previous implementation discarded every message and returned `Ok(())`
    /// without calling the inner subscriber. The call runs inline on the task that
    /// awaits this future, so a slow `handle_sync` stalls that task while it runs.
    async fn handle(&mut self, topic: &str, payload: Vec<u8>) -> Result<()> {
        self.inner.handle_sync(topic, payload)
    }
}