Skip to main content

aurora_db/pubsub/
mod.rs

1pub mod events;
2
3pub mod channel;
4pub mod listener;
5
6pub use channel::ChangeChannel;
7pub use events::{ChangeEvent, ChangeType, EventFilter};
8pub use listener::ChangeListener;
9
10use crate::error::Result;
11use dashmap::DashMap;
12use std::sync::Arc;
13use tokio::sync::broadcast;
14
15pub struct PubSubSystem {
16    channels: Arc<DashMap<String, broadcast::Sender<ChangeEvent>>>,
17    global_channel: broadcast::Sender<ChangeEvent>,
18    buffer_size: usize,
19}
20
21impl PubSubSystem {
22    pub fn new(buffer_size: usize) -> Self {
23        let (global_tx, _) = broadcast::channel(buffer_size);
24
25        Self {
26            channels: Arc::new(DashMap::new()),
27            global_channel: global_tx,
28            buffer_size,
29        }
30    }
31
32    pub fn publish(&self, event: ChangeEvent) -> Result<()> {
33        if let Some(tx) = self.channels.get(&event.collection) {
34            let _ = tx.send(event.clone());
35        }
36
37        let _ = self.global_channel.send(event);
38
39        Ok(())
40    }
41
42    pub fn listen(&self, collection: impl Into<String>) -> ChangeListener {
43        let collection = collection.into();
44
45        if !self.channels.contains_key(&collection) {
46            self.channels
47                .retain(|_, sender| sender.receiver_count() > 0);
48        }
49
50        let tx = self
51            .channels
52            .entry(collection.clone())
53            .or_insert_with(|| broadcast::channel(self.buffer_size).0)
54            .clone();
55
56        ChangeListener::new(collection, tx.subscribe())
57    }
58
59    pub fn listen_all(&self) -> ChangeListener {
60        ChangeListener::new("*".to_string(), self.global_channel.subscribe())
61    }
62
63    pub fn listener_count(&self, collection: &str) -> usize {
64        self.channels
65            .get(collection)
66            .map(|tx| tx.receiver_count())
67            .unwrap_or(0)
68    }
69
70    pub fn total_listeners(&self) -> usize {
71        self.channels
72            .iter()
73            .map(|entry| entry.value().receiver_count())
74            .sum::<usize>()
75            + self.global_channel.receiver_count()
76    }
77}
78
79impl Clone for PubSubSystem {
80    fn clone(&self) -> Self {
81        Self {
82            channels: Arc::clone(&self.channels),
83            global_channel: self.global_channel.clone(),
84            buffer_size: self.buffer_size,
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    //! Behavioural tests for `PubSubSystem`: delivery, fan-out, global
    //! listening, pruning of abandoned channels, and listener accounting.

    use super::*;
    use crate::types::{Document, Value};
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_pubsub_basic() {
        let pubsub = PubSubSystem::new(100);

        let mut listener = pubsub.listen("users");

        let mut data = HashMap::new();
        data.insert("id".to_string(), Value::String("123".into()));
        data.insert("name".to_string(), Value::String("Alice".into()));

        let event = ChangeEvent::insert(
            "users",
            "123",
            Document {
                _sid: "123".to_string(),
                data,
            },
        );

        // `event` is not used again, so it can be moved into `publish`.
        pubsub.publish(event).unwrap();

        let received = listener.recv().await.unwrap();
        assert_eq!(received.collection, "users");
        assert_eq!(received._sid, "123");
        assert!(matches!(received.change_type, ChangeType::Insert));
    }

    #[tokio::test]
    async fn test_pubsub_multiple_listeners() {
        let pubsub = PubSubSystem::new(100);

        let mut listener1 = pubsub.listen("users");
        let mut listener2 = pubsub.listen("users");

        assert_eq!(pubsub.listener_count("users"), 2);

        let event = ChangeEvent::delete("users", "123");

        pubsub.publish(event).unwrap();

        assert!(listener1.recv().await.is_ok());
        assert!(listener2.recv().await.is_ok());
    }

    #[tokio::test]
    async fn test_pubsub_global_listener() {
        let pubsub = PubSubSystem::new(100);

        let mut global_listener = pubsub.listen_all();

        pubsub.publish(ChangeEvent::delete("users", "1")).unwrap();
        pubsub.publish(ChangeEvent::delete("posts", "2")).unwrap();

        let event1 = global_listener.recv().await.unwrap();
        let event2 = global_listener.recv().await.unwrap();

        assert_eq!(event1.collection, "users");
        assert_eq!(event2.collection, "posts");
    }

    #[test]
    fn test_pubsub_prunes_abandoned_channels() {
        let pubsub = PubSubSystem::new(16);

        {
            let _temp = pubsub.listen("temp");
            assert_eq!(pubsub.listener_count("temp"), 1);
        }
        // The channel still exists but nobody listens on it any more.
        assert_eq!(pubsub.listener_count("temp"), 0);

        // Listening on a new collection sweeps receiver-less channels.
        let _other = pubsub.listen("other");
        assert!(!pubsub.channels.contains_key("temp"));
        assert_eq!(pubsub.listener_count("other"), 1);
    }

    #[test]
    fn test_pubsub_total_listeners_includes_global() {
        let pubsub = PubSubSystem::new(16);

        let _users = pubsub.listen("users");
        let _posts = pubsub.listen("posts");
        let _all = pubsub.listen_all();

        assert_eq!(pubsub.total_listeners(), 3);
    }
}