1pub mod events;
2
3pub mod channel;
4pub mod listener;
5
6pub use channel::ChangeChannel;
7pub use events::{ChangeEvent, ChangeType, EventFilter};
8pub use listener::ChangeListener;
9
10use crate::error::Result;
11use dashmap::DashMap;
12use std::sync::Arc;
13use tokio::sync::broadcast;
14
15pub struct PubSubSystem {
16 channels: Arc<DashMap<String, broadcast::Sender<ChangeEvent>>>,
17 global_channel: broadcast::Sender<ChangeEvent>,
18 buffer_size: usize,
19}
20
21impl PubSubSystem {
22 pub fn new(buffer_size: usize) -> Self {
23 let (global_tx, _) = broadcast::channel(buffer_size);
24
25 Self {
26 channels: Arc::new(DashMap::new()),
27 global_channel: global_tx,
28 buffer_size,
29 }
30 }
31
32 pub fn publish(&self, event: ChangeEvent) -> Result<()> {
33 if let Some(tx) = self.channels.get(&event.collection) {
34 let _ = tx.send(event.clone());
35 }
36
37 let _ = self.global_channel.send(event);
38
39 Ok(())
40 }
41
42 pub fn listen(&self, collection: impl Into<String>) -> ChangeListener {
43 let collection = collection.into();
44
45 if !self.channels.contains_key(&collection) {
46 self.channels
47 .retain(|_, sender| sender.receiver_count() > 0);
48 }
49
50 let tx = self
51 .channels
52 .entry(collection.clone())
53 .or_insert_with(|| broadcast::channel(self.buffer_size).0)
54 .clone();
55
56 ChangeListener::new(collection, tx.subscribe())
57 }
58
59 pub fn listen_all(&self) -> ChangeListener {
60 ChangeListener::new("*".to_string(), self.global_channel.subscribe())
61 }
62
63 pub fn listener_count(&self, collection: &str) -> usize {
64 self.channels
65 .get(collection)
66 .map(|tx| tx.receiver_count())
67 .unwrap_or(0)
68 }
69
70 pub fn total_listeners(&self) -> usize {
71 self.channels
72 .iter()
73 .map(|entry| entry.value().receiver_count())
74 .sum::<usize>()
75 + self.global_channel.receiver_count()
76 }
77}
78
79impl Clone for PubSubSystem {
80 fn clone(&self) -> Self {
81 Self {
82 channels: Arc::clone(&self.channels),
83 global_channel: self.global_channel.clone(),
84 buffer_size: self.buffer_size,
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Document, Value};
    use std::collections::HashMap;

    /// A collection listener receives an insert published to that collection.
    #[tokio::test]
    async fn test_pubsub_basic() {
        let hub = PubSubSystem::new(100);
        let mut users_listener = hub.listen("users");

        // Build the document payload from (key, value) pairs.
        let data: HashMap<String, Value> = [("id", "123"), ("name", "Alice")]
            .iter()
            .map(|&(key, value)| (key.to_string(), Value::String(value.into())))
            .collect();

        let document = Document {
            _sid: "123".to_string(),
            data,
        };
        let event = ChangeEvent::insert("users", "123", document);

        hub.publish(event).unwrap();

        let received = users_listener.recv().await.unwrap();
        assert_eq!(received.collection, "users");
        assert_eq!(received._sid, "123");
        assert!(matches!(received.change_type, ChangeType::Insert));
    }

    /// Two listeners on the same collection are both counted and both receive
    /// the published event.
    #[tokio::test]
    async fn test_pubsub_multiple_listeners() {
        let hub = PubSubSystem::new(100);

        let mut first = hub.listen("users");
        let mut second = hub.listen("users");
        assert_eq!(hub.listener_count("users"), 2);

        hub.publish(ChangeEvent::delete("users", "123")).unwrap();

        let first_result = first.recv().await;
        assert!(first_result.is_ok());
        let second_result = second.recv().await;
        assert!(second_result.is_ok());
    }

    /// The global listener sees events from every collection, in publish order.
    #[tokio::test]
    async fn test_pubsub_global_listener() {
        let hub = PubSubSystem::new(100);
        let mut everything = hub.listen_all();

        let users_event = ChangeEvent::delete("users", "1");
        let posts_event = ChangeEvent::delete("posts", "2");
        hub.publish(users_event).unwrap();
        hub.publish(posts_event).unwrap();

        let first_seen = everything.recv().await.unwrap();
        let second_seen = everything.recv().await.unwrap();

        assert_eq!(first_seen.collection, "users");
        assert_eq!(second_seen.collection, "posts");
    }
}