1pub mod events;
2
3pub mod channel;
4pub mod listener;
5
6pub use channel::ChangeChannel;
7pub use events::{ChangeEvent, ChangeType, EventFilter};
8pub use listener::ChangeListener;
9
10use crate::error::Result;
11use dashmap::DashMap;
12use std::sync::Arc;
13use tokio::sync::broadcast;
14
15pub struct PubSubSystem {
16 channels: Arc<DashMap<String, broadcast::Sender<ChangeEvent>>>,
17 global_channel: broadcast::Sender<ChangeEvent>,
18 buffer_size: usize,
19}
20
21impl PubSubSystem {
22 pub fn new(buffer_size: usize) -> Self {
23 let (global_tx, _) = broadcast::channel(buffer_size);
24
25 Self {
26 channels: Arc::new(DashMap::new()),
27 global_channel: global_tx,
28 buffer_size,
29 }
30 }
31
32 pub fn publish(&self, event: ChangeEvent) -> Result<()> {
33 if let Some(tx) = self.channels.get(&event.collection) {
34 let _ = tx.send(event.clone());
35 }
36
37 let _ = self.global_channel.send(event);
38
39 Ok(())
40 }
41
42 pub fn listen(&self, collection: impl Into<String>) -> ChangeListener {
43 let collection = collection.into();
44
45 if !self.channels.contains_key(&collection) {
46 self.channels
47 .retain(|_, sender| sender.receiver_count() > 0);
48 }
49
50 let tx = self
51 .channels
52 .entry(collection.clone())
53 .or_insert_with(|| broadcast::channel(self.buffer_size).0)
54 .clone();
55
56 ChangeListener::new(collection, tx.subscribe())
57 }
58
59 pub fn listen_all(&self) -> ChangeListener {
60 ChangeListener::new("*".to_string(), self.global_channel.subscribe())
61 }
62
63 pub fn listener_count(&self, collection: &str) -> usize {
64 self.channels
65 .get(collection)
66 .map(|tx| tx.receiver_count())
67 .unwrap_or(0)
68 }
69
70 pub fn total_listeners(&self) -> usize {
71 self.channels
72 .iter()
73 .map(|entry| entry.value().receiver_count())
74 .sum::<usize>()
75 + self.global_channel.receiver_count()
76 }
77}
78
79impl Clone for PubSubSystem {
80 fn clone(&self) -> Self {
81 Self {
82 channels: Arc::clone(&self.channels),
83 global_channel: self.global_channel.clone(),
84 buffer_size: self.buffer_size,
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Document, Value};
    use std::collections::HashMap;

    /// Builds an `Insert` event for `collection`/`id` with no documents attached.
    fn insert_event(collection: &str, id: &str) -> ChangeEvent {
        ChangeEvent {
            collection: collection.to_string(),
            change_type: ChangeType::Insert,
            id: id.to_string(),
            document: None,
            old_document: None,
        }
    }

    #[tokio::test]
    async fn test_pubsub_basic() {
        let pubsub = PubSubSystem::new(100);
        let mut listener = pubsub.listen("users");

        let data = HashMap::from([
            ("id".to_string(), Value::String("123".into())),
            ("name".to_string(), Value::String("Alice".into())),
        ]);

        let event = ChangeEvent {
            document: Some(Document {
                id: "123".to_string(),
                data,
            }),
            ..insert_event("users", "123")
        };

        pubsub.publish(event).unwrap();

        let received = listener.recv().await.unwrap();
        assert_eq!(received.collection, "users");
        assert_eq!(received.id, "123");
        assert!(matches!(received.change_type, ChangeType::Insert));
    }

    #[tokio::test]
    async fn test_pubsub_multiple_listeners() {
        let pubsub = PubSubSystem::new(100);
        let mut first = pubsub.listen("users");
        let mut second = pubsub.listen("users");

        assert_eq!(pubsub.listener_count("users"), 2);

        pubsub.publish(insert_event("users", "123")).unwrap();

        for listener in [&mut first, &mut second] {
            assert!(listener.recv().await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_pubsub_global_listener() {
        let pubsub = PubSubSystem::new(100);
        let mut global_listener = pubsub.listen_all();

        for (collection, id) in [("users", "1"), ("posts", "2")] {
            pubsub.publish(insert_event(collection, id)).unwrap();
        }

        let first = global_listener.recv().await.unwrap();
        let second = global_listener.recv().await.unwrap();

        assert_eq!(first.collection, "users");
        assert_eq!(second.collection, "posts");
    }
}