use async_broadcast::{Receiver, SendError, Sender};
use dashmap::DashMap;
use futures::future::select_all;
use std::{collections::HashSet, hash::Hash, sync::Arc};
/// Topic-keyed publish/subscribe hub backed by one `async_broadcast` channel per topic.
///
/// Cloning is cheap: every clone shares the same subscriber map through an `Arc`, so all
/// clones publish to, and register subscribers on, the same set of topics.
#[derive(Debug, Clone)]
pub struct TopicStream<T, M>
where
    T: Eq + Hash + Clone,
    M: Clone,
{
    /// Topics that currently have a channel, mapped to the sender feeding that channel.
    subscribers: Arc<DashMap<T, Sender<M>>>,
    /// Capacity handed to `async_broadcast::broadcast` when a topic's channel is created.
    buffer_size: usize,
}
impl<T: Eq + Hash + Clone, M: Clone> TopicStream<T, M> {
    /// Creates an empty stream.
    ///
    /// `buffer_size` is the capacity of each per-topic channel. It is only used when a
    /// topic's channel is first created on subscribe. NOTE(review): `async_broadcast`
    /// presumably rejects a zero capacity at that point — confirm against its docs.
    pub fn new(buffer_size: usize) -> Self {
        Self {
            subscribers: Arc::new(DashMap::new()),
            buffer_size,
        }
    }

    /// Returns a receiver subscribed to every topic in `topics` (duplicates are ignored).
    pub fn subscribe(&self, topics: &[T]) -> MultiTopicReceiver<T, M> {
        let mut receiver =
            MultiTopicReceiver::new(Arc::clone(&self.subscribers), self.buffer_size);
        receiver.subscribe(topics);
        receiver
    }

    /// Broadcasts `message` to every current subscriber of `topic`.
    ///
    /// Publishing to a topic that has no channel (nobody subscribed) is a no-op returning
    /// `Ok(())`. The send is awaited, so this may wait for buffer space in the topic's
    /// channel.
    ///
    /// # Errors
    ///
    /// Returns the undelivered message inside [`SendError`] if the topic's channel is closed
    /// when the send happens, e.g. because its last subscriber was dropped concurrently.
    pub async fn publish(&self, topic: &T, message: M) -> Result<(), SendError<M>> {
        // Clone the sender so the DashMap shard guard is released *before* awaiting.
        // `broadcast` can park waiting for buffer space. Holding the guard across that
        // await would block any `subscribe` or subscriber drop that needs to write-lock
        // the same shard. On a single-threaded runtime, that is a deadlock.
        let sender = self.subscribers.get(topic).map(|entry| entry.value().clone());
        if let Some(sender) = sender {
            sender.broadcast(message).await?;
        }
        Ok(())
    }
}
/// Receiving side of a [`TopicStream`] subscription, merging several topics into one stream.
///
/// Dropping it detaches from its topics and prunes from the shared map the topics it was
/// the last subscriber of.
#[derive(Debug)]
pub struct MultiTopicReceiver<T, M>
where
    T: Eq + Hash + Clone,
    M: Clone,
{
    /// Map shared with the owning `TopicStream`: topic -> sender of that topic's channel.
    subscribers: Arc<DashMap<T, Sender<M>>>,
    /// Receivers for the subscribed topics (at most one per topic), polled together by `recv`.
    receivers: Vec<Receiver<M>>,
    /// Topics already subscribed to; used to skip duplicate subscriptions.
    subscribed_topics: HashSet<T>,
    /// Capacity used whenever this receiver has to create a topic's channel.
    buffer_size: usize,
}
impl<T: Eq + Hash + Clone, M: Clone> MultiTopicReceiver<T, M> {
pub fn new(subscribers: Arc<DashMap<T, Sender<M>>>, buffer_size: usize) -> Self {
Self {
subscribers,
receivers: Vec::new(),
subscribed_topics: HashSet::new(),
buffer_size,
}
}
pub fn subscribe(&mut self, topics: &[T]) {
self.receivers.extend(
topics
.iter()
.filter(|topic| self.subscribed_topics.insert((*topic).clone()))
.map(|topic| {
let topic = topic.clone();
let (sender, _receiver) = async_broadcast::broadcast(self.buffer_size);
self.subscribers
.entry(topic)
.or_insert_with(|| sender)
.new_receiver()
}),
);
}
pub async fn recv(&mut self) -> Option<M> {
self.receivers.retain(|r| !r.is_closed());
if self.receivers.is_empty() {
return None;
}
let futures = self
.receivers
.iter_mut()
.map(|receiver| Box::pin(receiver.recv()))
.collect::<Vec<_>>();
let (result, _index, _remaining) = select_all(futures).await;
result.ok() }
}
impl<T: Eq + Hash + Clone, M: Clone> Drop for MultiTopicReceiver<T, M> {
    /// Detaches from every subscribed topic and removes topics left without any subscriber.
    fn drop(&mut self) {
        // Drop our own receivers first, so each sender's `receiver_count` counts only other
        // subscribers. The previous approach checked `<= 1` while still holding them, and
        // decided outside the map lock. That let two receivers dropped at the same time
        // both observe a count of 2, both skip removal, and leave a closed channel in the map.
        self.receivers.clear();
        for topic in &self.subscribed_topics {
            // `remove_if` runs the predicate under the shard's write lock. A subscriber
            // attaching concurrently therefore cannot slip in between check and removal.
            self.subscribers
                .remove_if(topic, |_, sender| sender.receiver_count() == 0);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    struct Topic(String);

    #[derive(Debug, Clone, Eq, PartialEq)]
    struct Message(String);

    #[tokio::test]
    async fn test_subscribe_and_publish_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());
        let mut receiver = publisher.subscribe(&[topic.clone()]);
        let message = Message("Hello, Subscriber!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        let received_message = receiver.recv().await.unwrap();
        assert_eq!(received_message, message);
    }

    #[tokio::test]
    async fn test_subscribe_multiple_subscribers() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());
        let mut receiver1 = publisher.subscribe(&[topic.clone()]);
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);
        let message = Message("Hello, Subscribers!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        let received_message1 = receiver1.recv().await.unwrap();
        assert_eq!(received_message1, message);
        let received_message2 = receiver2.recv().await.unwrap();
        assert_eq!(received_message2, message);
    }

    #[tokio::test]
    async fn test_publish_to_unsubscribed_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());
        let mut receiver = publisher.subscribe(&[Topic("invalid_topic".to_string())]);
        let message = Message("Hello, World!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        // A short window is enough: the message was published before we start waiting.
        let timeout = tokio::time::sleep(tokio::time::Duration::from_millis(100));
        tokio::select! {
            _ = timeout => {}
            _ = receiver.recv() => {
                panic!("Unexpected message received before timeout");
            }
        }
    }

    #[tokio::test]
    async fn test_multiple_messages_for_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());
        let mut receiver = publisher.subscribe(&[topic.clone()]);
        let message1 = Message("Message 1".to_string());
        let message2 = Message("Message 2".to_string());
        publisher.publish(&topic, message1.clone()).await.unwrap();
        publisher.publish(&topic, message2.clone()).await.unwrap();
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);
        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_multiple_publishers() {
        // Clones share one subscriber map, so both handles reach the same subscriber.
        let publisher1 = TopicStream::<Topic, Message>::new(2);
        let publisher2 = publisher1.clone();
        let topic = Topic("test_topic".to_string());
        let mut receiver = publisher1.subscribe(&[topic.clone()]);
        let message1 = Message("Message from Publisher 1".to_string());
        publisher1.publish(&topic, message1.clone()).await.unwrap();
        let message2 = Message("Message from Publisher 2".to_string());
        publisher2.publish(&topic, message2.clone()).await.unwrap();
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);
        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_subscribe_to_different_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());
        let mut receiver1 = publisher.subscribe(&[topic1.clone()]);
        let message1 = Message("Hello, Topic 1".to_string());
        publisher.publish(&topic1, message1.clone()).await.unwrap();
        let received_message1 = receiver1.recv().await.unwrap();
        assert_eq!(received_message1, message1);
        let mut receiver2 = publisher.subscribe(&[topic2.clone()]);
        let message2 = Message("Hello, Topic 2".to_string());
        publisher.publish(&topic2, message2.clone()).await.unwrap();
        let received_message2 = receiver2.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_single_receiver_multiple_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());
        let topic3 = Topic("test_topic_3".to_string());
        let mut receiver =
            publisher.subscribe(&[topic1.clone(), topic2.clone(), topic3.clone()]);
        let message1 = Message("Message for Topic 1".to_string());
        let message2 = Message("Message for Topic 2".to_string());
        let message3 = Message("Message for Topic 3".to_string());
        publisher.publish(&topic1, message1.clone()).await.unwrap();
        publisher.publish(&topic2, message2.clone()).await.unwrap();
        publisher.publish(&topic3, message3.clone()).await.unwrap();
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);
        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);
        let received_message3 = receiver.recv().await.unwrap();
        assert_eq!(received_message3, message3);
    }

    #[tokio::test]
    async fn test_recv_without_subscriptions_returns_none() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let mut receiver = publisher.subscribe(&[]);
        assert_eq!(receiver.recv().await, None);
    }

    #[tokio::test]
    async fn test_duplicate_topics_subscribe_once() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());
        let mut receiver = publisher.subscribe(&[topic.clone(), topic.clone()]);
        receiver.subscribe(&[topic.clone()]);
        assert_eq!(receiver.receivers.len(), 1);
        assert_eq!(receiver.subscribed_topics.len(), 1);
    }

    #[tokio::test]
    async fn test_topic_pruned_after_last_subscriber_dropped() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());
        let first = publisher.subscribe(&[topic.clone()]);
        let mut second = publisher.subscribe(&[topic.clone()]);

        // Dropping one of two subscribers keeps the topic alive for the other.
        drop(first);
        assert!(publisher.subscribers.contains_key(&topic));
        let message = Message("still delivered".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        assert_eq!(second.recv().await, Some(message));

        // Dropping the last subscriber removes the topic entirely.
        drop(second);
        assert!(!publisher.subscribers.contains_key(&topic));
        publisher
            .publish(&topic, Message("nobody listening".to_string()))
            .await
            .unwrap();

        // Resubscribing afterwards gets a working channel.
        let mut third = publisher.subscribe(&[topic.clone()]);
        let message = Message("after resubscribe".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        assert_eq!(third.recv().await, Some(message));
    }
}