topic_stream/
topic_stream.rs

1use async_broadcast::{Receiver, SendError, Sender};
2use dashmap::DashMap;
3use futures::future::select_all;
4use std::{collections::HashSet, hash::Hash, sync::Arc};
5
6/// A topic-based publish-subscribe stream that allows multiple subscribers
7/// to listen to messages associated with specific topics.
8///
9/// # Type Parameters
10/// - T: The type representing a topic. Must be hashable, comparable, and clonable.
11/// - M: The message type that will be published and received. Must be clonable.
12#[derive(Debug, Clone)]
13pub struct TopicStream<T: Eq + Hash + Clone, M: Clone> {
14    /// Stores the active subscribers for each topic.
15    subscribers: Arc<DashMap<T, Sender<M>>>,
16    /// The maximum number of messages each topic can hold in its buffer.
17    buffer_size: usize,
18}
19
20impl<T: Eq + Hash + Clone, M: Clone> TopicStream<T, M> {
21    /// Creates a new TopicStream instance with the specified buffer size.
22    ///
23    /// # Arguments
24    /// - buffer_size: The maximum number of messages each topic can hold in its buffer.
25    ///
26    /// # Returns
27    /// A new TopicStream instance.
28    pub fn new(buffer_size: usize) -> Self {
29        Self {
30            subscribers: Arc::new(DashMap::new()),
31            buffer_size,
32        }
33    }
34
35    /// Subscribes to a list of topics and returns a MultiTopicReceiver
36    /// that can receive messages from them.
37    ///
38    /// # Arguments
39    /// - topics: A slice of topics to subscribe to.
40    ///
41    /// # Returns
42    /// A MultiTopicReceiver that listens to the specified topics.
43    pub fn subscribe(&self, topics: &[T]) -> MultiTopicReceiver<T, M> {
44        let mut receiver = MultiTopicReceiver::new(Arc::clone(&self.subscribers), self.buffer_size);
45        receiver.subscribe(topics);
46
47        receiver
48    }
49
50    /// Publishes a message to a specific topic. If the topic has no subscribers,
51    /// the message is ignored.
52    ///
53    /// # Arguments
54    /// - topic: The topic to publish the message to.
55    /// - message: The message to send.
56    ///
57    /// # Returns
58    /// - Ok(()): If the message was successfully sent or there were no subscribers.
59    /// - Err(SendError<M>): If there was an error sending the message.
60    pub async fn publish(&self, topic: &T, message: M) -> Result<(), SendError<M>> {
61        if let Some(sender) = self.subscribers.get(topic) {
62            sender.broadcast(message).await?;
63        };
64
65        Ok(())
66    }
67}
68
69/// A multi-topic receiver that listens to messages from multiple topics.
70///
71/// # Type Parameters
72/// - T: The type representing a topic.
73/// - M: The message type being received.
74#[derive(Debug)]
75pub struct MultiTopicReceiver<T: Eq + Hash + Clone, M: Clone> {
76    /// A reference to the associated TopicStream.
77    subscribers: Arc<DashMap<T, Sender<M>>>,
78    /// The list of active message receivers for the subscribed topics.
79    receivers: Vec<Receiver<M>>,
80    /// Tracks the topics this receiver is currently subscribed to.
81    subscribed_topics: HashSet<T>,
82    /// The message buffer for each topic's broadcast channel.
83    buffer_size: usize,
84}
85
86impl<T: Eq + Hash + Clone, M: Clone> MultiTopicReceiver<T, M> {
87    /// Creates a new MultiTopicReceiver for the given TopicStream.
88    ///
89    /// # Arguments
90    /// - subscribers: A reference to the DashMap containing the active subscribers.
91    /// - buffer_size: The size of the message buffer for each topic's broadcast channel.
92    ///
93    /// # Returns
94    /// A new MultiTopicReceiver instance.
95    pub fn new(subscribers: Arc<DashMap<T, Sender<M>>>, buffer_size: usize) -> Self {
96        Self {
97            subscribers,
98            receivers: Vec::new(),
99            subscribed_topics: HashSet::new(),
100            buffer_size,
101        }
102    }
103
104    /// Subscribes to the given list of topics. If already subscribed to a topic,
105    /// it is ignored.
106    ///
107    /// # Arguments
108    /// - topics: A slice of topics to subscribe to.
109    pub fn subscribe(&mut self, topics: &[T]) {
110        self.receivers.extend(
111            topics
112                .iter()
113                .filter(|topic| self.subscribed_topics.insert((*topic).clone()))
114                .map(|topic| {
115                    let topic = topic.clone();
116                    let (sender, _receiver) = async_broadcast::broadcast(self.buffer_size);
117
118                    self.subscribers
119                        .entry(topic)
120                        .or_insert_with(|| sender)
121                        .new_receiver()
122                }),
123        );
124    }
125
126    /// An Option<M> containing the received message, or None if all receivers are closed.
127    pub async fn recv(&mut self) -> Option<M> {
128        self.receivers.retain(|r| !r.is_closed());
129
130        if self.receivers.is_empty() {
131            return None;
132        }
133
134        let futures = self
135            .receivers
136            .iter_mut()
137            .map(|receiver| Box::pin(receiver.recv()))
138            .collect::<Vec<_>>();
139
140        let (result, _index, _remaining) = select_all(futures).await;
141
142        result.ok() // If a message is received, return it; otherwise, return None.
143    }
144}
145
146impl<T: Eq + Hash + Clone, M: Clone> Drop for MultiTopicReceiver<T, M> {
147    fn drop(&mut self) {
148        let mut to_remove = Vec::new();
149
150        for topic in &self.subscribed_topics {
151            if let Some(sender) = self.subscribers.get(topic) {
152                if sender.receiver_count() <= 1 {
153                    to_remove.push(topic.clone());
154                }
155            }
156        }
157
158        to_remove.into_iter().for_each(|topic| {
159            self.subscribers.remove(&topic);
160        });
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::Hash;

    #[derive(Debug, Clone, Hash, Eq, PartialEq)]
    struct Topic(String);

    #[derive(Debug, Clone, Eq, PartialEq)]
    struct Message(String);

    /// How long to wait before concluding that no message is going to arrive.
    /// Kept short so "nothing should happen" tests do not slow the suite down.
    const QUIET_PERIOD: tokio::time::Duration = tokio::time::Duration::from_millis(50);

    #[tokio::test]
    async fn test_subscribe_and_publish_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber subscribes to the topic
        let mut receiver = publisher.subscribe(&[topic.clone()]);

        // Publisher sends a message to the topic
        let message = Message("Hello, Subscriber!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        // Subscriber receives the message
        let received_message = receiver.recv().await.unwrap();
        assert_eq!(received_message, message);
    }

    #[tokio::test]
    async fn test_subscribe_multiple_subscribers() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Two independent subscribers on the same topic
        let mut receiver1 = publisher.subscribe(&[topic.clone()]);
        let mut receiver2 = publisher.subscribe(&[topic.clone()]);

        // Publisher sends a message to the topic
        let message = Message("Hello, Subscribers!".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();

        // Each subscriber receives its own copy
        let received_message1 = receiver1.recv().await.unwrap();
        assert_eq!(received_message1, message);

        let received_message2 = receiver2.recv().await.unwrap();
        assert_eq!(received_message2, message);
    }

    #[tokio::test]
    async fn test_publish_to_unsubscribed_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber listens to a different topic than the one published to
        let mut receiver = publisher.subscribe(&[Topic("invalid_topic".to_string())]);

        // Publishing to a topic with no subscribers succeeds and is a no-op
        let message = Message("Hello, World!".to_string());
        publisher.publish(&topic, message).await.unwrap();

        // Nothing was published on the subscribed topic, so `recv` must stay pending
        let outcome = tokio::time::timeout(QUIET_PERIOD, receiver.recv()).await;
        assert!(outcome.is_err(), "unexpected message received: {:?}", outcome);
    }

    #[tokio::test]
    async fn test_multiple_messages_for_single_subscriber() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // Subscriber subscribes to the topic
        let mut receiver = publisher.subscribe(&[topic.clone()]);

        // Publisher sends multiple messages (exactly filling the buffer)
        let message1 = Message("Message 1".to_string());
        let message2 = Message("Message 2".to_string());
        publisher.publish(&topic, message1.clone()).await.unwrap();
        publisher.publish(&topic, message2.clone()).await.unwrap();

        // Messages arrive in publish order
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_multiple_publishers() {
        // Two handles onto the same stream: clones share the topic registry
        let publisher1 = TopicStream::<Topic, Message>::new(2);
        let publisher2 = publisher1.clone();
        let topic = Topic("test_topic".to_string());

        // Subscriber subscribes through the first handle
        let mut receiver = publisher1.subscribe(&[topic.clone()]);

        // Publisher 1 sends a message
        let message1 = Message("Message from Publisher 1".to_string());
        publisher1.publish(&topic, message1.clone()).await.unwrap();

        // Publisher 2 sends a message
        let message2 = Message("Message from Publisher 2".to_string());
        publisher2.publish(&topic, message2.clone()).await.unwrap();

        // Subscriber receives both, in publish order
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_subscribe_to_different_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());

        // Subscriber subscribes to topic 1
        let mut receiver1 = publisher.subscribe(&[topic1.clone()]);

        // Publisher sends a message to topic 1
        let message1 = Message("Hello, Topic 1".to_string());
        publisher.publish(&topic1, message1.clone()).await.unwrap();

        // Subscriber 1 receives the message for topic 1
        let received_message1 = receiver1.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        // Another subscriber subscribes to topic 2
        let mut receiver2 = publisher.subscribe(&[topic2.clone()]);

        // Publisher sends a message to topic 2
        let message2 = Message("Hello, Topic 2".to_string());
        publisher.publish(&topic2, message2.clone()).await.unwrap();

        // Subscriber 2 receives the message for topic 2
        let received_message2 = receiver2.recv().await.unwrap();
        assert_eq!(received_message2, message2);
    }

    #[tokio::test]
    async fn test_single_receiver_multiple_topics() {
        let publisher = TopicStream::<Topic, Message>::new(2);

        // Define multiple topics
        let topic1 = Topic("test_topic_1".to_string());
        let topic2 = Topic("test_topic_2".to_string());
        let topic3 = Topic("test_topic_3".to_string());

        // Subscriber subscribes to multiple topics
        let mut receiver = publisher.subscribe(&[topic1.clone(), topic2.clone(), topic3.clone()]);

        // Publisher sends messages to each topic
        let message1 = Message("Message for Topic 1".to_string());
        let message2 = Message("Message for Topic 2".to_string());
        let message3 = Message("Message for Topic 3".to_string());

        publisher.publish(&topic1, message1.clone()).await.unwrap();
        publisher.publish(&topic2, message2.clone()).await.unwrap();
        publisher.publish(&topic3, message3.clone()).await.unwrap();

        // With all three ready, earlier-subscribed topics are drained first
        let received_message1 = receiver.recv().await.unwrap();
        assert_eq!(received_message1, message1);

        let received_message2 = receiver.recv().await.unwrap();
        assert_eq!(received_message2, message2);

        let received_message3 = receiver.recv().await.unwrap();
        assert_eq!(received_message3, message3);
    }

    #[tokio::test]
    async fn test_duplicate_subscriptions_are_ignored() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        // The same topic twice in one call, then once more via a second call
        let mut receiver = publisher.subscribe(&[topic.clone(), topic.clone()]);
        receiver.subscribe(&[topic.clone()]);
        assert_eq!(receiver.receivers.len(), 1);
        assert_eq!(publisher.subscribers.len(), 1);

        let message = Message("Only once".to_string());
        publisher.publish(&topic, message.clone()).await.unwrap();
        assert_eq!(receiver.recv().await, Some(message));

        // A duplicate subscription would deliver the same message a second time
        let second = tokio::time::timeout(QUIET_PERIOD, receiver.recv()).await;
        assert!(second.is_err(), "message delivered twice: {:?}", second);
    }

    #[tokio::test]
    async fn test_dropping_last_receiver_removes_topic() {
        let publisher = TopicStream::<Topic, Message>::new(2);
        let topic = Topic("test_topic".to_string());

        let receiver1 = publisher.subscribe(&[topic.clone()]);
        let receiver2 = publisher.subscribe(&[topic.clone()]);

        // Another receiver still listens, so the topic must survive
        drop(receiver1);
        assert!(publisher.subscribers.contains_key(&topic));

        // Last listener gone: the topic is cleaned up
        drop(receiver2);
        assert!(!publisher.subscribers.contains_key(&topic));

        // Publishing to the now-unsubscribed topic is a no-op
        publisher
            .publish(&topic, Message("ignored".to_string()))
            .await
            .unwrap();
    }
}