// aurora_db/pubsub/mod.rs

1pub mod events;
2
3pub mod channel;
4pub mod listener;
5
6pub use channel::ChangeChannel;
7pub use events::{ChangeEvent, ChangeType, EventFilter};
8pub use listener::ChangeListener;
9
10use crate::error::Result;
11use dashmap::DashMap;
12use std::sync::Arc;
13use tokio::sync::broadcast;
14
15pub struct PubSubSystem {
16    channels: Arc<DashMap<String, broadcast::Sender<ChangeEvent>>>,
17    global_channel: broadcast::Sender<ChangeEvent>,
18    buffer_size: usize,
19}
20
21impl PubSubSystem {
22    pub fn new(buffer_size: usize) -> Self {
23        let (global_tx, _) = broadcast::channel(buffer_size);
24
25        Self {
26            channels: Arc::new(DashMap::new()),
27            global_channel: global_tx,
28            buffer_size,
29        }
30    }
31
32    pub fn publish(&self, event: ChangeEvent) -> Result<()> {
33        if let Some(tx) = self.channels.get(&event.collection) {
34            let _ = tx.send(event.clone());
35        }
36
37        let _ = self.global_channel.send(event);
38
39        Ok(())
40    }
41
42    pub fn listen(&self, collection: impl Into<String>) -> ChangeListener {
43        let collection = collection.into();
44
45        if !self.channels.contains_key(&collection) {
46            self.channels
47                .retain(|_, sender| sender.receiver_count() > 0);
48        }
49
50        let tx = self
51            .channels
52            .entry(collection.clone())
53            .or_insert_with(|| broadcast::channel(self.buffer_size).0)
54            .clone();
55
56        ChangeListener::new(collection, tx.subscribe())
57    }
58
59    pub fn listen_all(&self) -> ChangeListener {
60        ChangeListener::new("*".to_string(), self.global_channel.subscribe())
61    }
62
63    pub fn listener_count(&self, collection: &str) -> usize {
64        self.channels
65            .get(collection)
66            .map(|tx| tx.receiver_count())
67            .unwrap_or(0)
68    }
69
70    pub fn total_listeners(&self) -> usize {
71        self.channels
72            .iter()
73            .map(|entry| entry.value().receiver_count())
74            .sum::<usize>()
75            + self.global_channel.receiver_count()
76    }
77}
78
79impl Clone for PubSubSystem {
80    fn clone(&self) -> Self {
81        Self {
82            channels: Arc::clone(&self.channels),
83            global_channel: self.global_channel.clone(),
84            buffer_size: self.buffer_size,
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Document, Value};
    use std::collections::HashMap;

    /// Builds an `Insert` event for `collection`/`id` carrying `document` and no old document.
    fn insert_event(collection: &str, id: &str, document: Option<Document>) -> ChangeEvent {
        ChangeEvent {
            collection: collection.to_string(),
            change_type: ChangeType::Insert,
            id: id.to_string(),
            document,
            old_document: None,
        }
    }

    #[tokio::test]
    async fn test_pubsub_basic() {
        let pubsub = PubSubSystem::new(100);
        let mut listener = pubsub.listen("users");

        let fields = HashMap::from([
            ("id".to_string(), Value::String("123".into())),
            ("name".to_string(), Value::String("Alice".into())),
        ]);
        let doc = Document {
            id: "123".to_string(),
            data: fields,
        };

        pubsub
            .publish(insert_event("users", "123", Some(doc)))
            .unwrap();

        let got = listener.recv().await.unwrap();
        assert_eq!(got.collection, "users");
        assert_eq!(got.id, "123");
        assert!(matches!(got.change_type, ChangeType::Insert));
    }

    #[tokio::test]
    async fn test_pubsub_multiple_listeners() {
        let pubsub = PubSubSystem::new(100);

        let mut first = pubsub.listen("users");
        let mut second = pubsub.listen("users");
        assert_eq!(pubsub.listener_count("users"), 2);

        pubsub.publish(insert_event("users", "123", None)).unwrap();

        assert!(first.recv().await.is_ok());
        assert!(second.recv().await.is_ok());
    }

    #[tokio::test]
    async fn test_pubsub_global_listener() {
        let pubsub = PubSubSystem::new(100);
        let mut global_listener = pubsub.listen_all();

        for (collection, id) in [("users", "1"), ("posts", "2")] {
            pubsub.publish(insert_event(collection, id, None)).unwrap();
        }

        let first = global_listener.recv().await.unwrap();
        let second = global_listener.recv().await.unwrap();

        assert_eq!(first.collection, "users");
        assert_eq!(second.collection, "posts");
    }
}