1use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6
7use tokio::sync::RwLock;
8
9use ormdb_proto::{ChangeEvent, ChangeType, Subscription};
10
11use super::subscription::{SubscriptionEntry, SubscriptionFilter};
12use crate::error::Error;
13
/// In-memory publish/subscribe hub for entity change events.
///
/// Clients register interest in an entity with [`PubSubManager::subscribe`].
/// [`PubSubManager::publish_event`] creates one [`ChangeEvent`] per subscription
/// registered for the changed entity and buffers them until
/// [`PubSubManager::drain_events`] takes them.
pub struct PubSubManager {
    /// All live subscriptions, keyed by subscription id.
    subscriptions: RwLock<HashMap<u64, SubscriptionEntry>>,
    /// Entity name -> ids of the subscriptions registered for that entity.
    /// Updated alongside `subscriptions` by `subscribe`/`unsubscribe`; an entity
    /// whose last subscription is removed is dropped from the map.
    entity_index: RwLock<HashMap<String, Vec<u64>>>,
    /// Next id handed out by `subscribe` (starts at 1, see `new`).
    next_subscription_id: AtomicU64,
    /// Events produced by `publish_event`, waiting to be taken by `drain_events`.
    event_queue: RwLock<Vec<ChangeEvent>>,
}
29
30impl PubSubManager {
31 pub fn new() -> Self {
33 Self {
34 subscriptions: RwLock::new(HashMap::new()),
35 entity_index: RwLock::new(HashMap::new()),
36 next_subscription_id: AtomicU64::new(1),
37 event_queue: RwLock::new(Vec::new()),
38 }
39 }
40
41 pub async fn subscribe(
45 &self,
46 client_id: &str,
47 subscription: &Subscription,
48 ) -> Result<u64, Error> {
49 let subscription_id = self.next_subscription_id.fetch_add(1, Ordering::SeqCst);
50
51 let filter = SubscriptionFilter::from_subscription(subscription);
52 let entry = SubscriptionEntry::new(
53 subscription_id,
54 client_id,
55 &subscription.entity,
56 filter,
57 );
58
59 {
61 let mut subs = self.subscriptions.write().await;
62 subs.insert(subscription_id, entry);
63 }
64
65 {
67 let mut index = self.entity_index.write().await;
68 index
69 .entry(subscription.entity.clone())
70 .or_default()
71 .push(subscription_id);
72 }
73
74 tracing::debug!(
75 subscription_id,
76 client_id,
77 entity = %subscription.entity,
78 "subscription created"
79 );
80
81 Ok(subscription_id)
82 }
83
84 pub async fn unsubscribe(&self, subscription_id: u64) -> Result<(), Error> {
86 let entry = {
88 let mut subs = self.subscriptions.write().await;
89 subs.remove(&subscription_id)
90 };
91
92 let entry = match entry {
93 Some(e) => e,
94 None => {
95 return Err(Error::Database(format!(
96 "subscription {} not found",
97 subscription_id
98 )));
99 }
100 };
101
102 {
104 let mut index = self.entity_index.write().await;
105 if let Some(ids) = index.get_mut(&entry.entity) {
106 ids.retain(|&id| id != subscription_id);
107 if ids.is_empty() {
108 index.remove(&entry.entity);
109 }
110 }
111 }
112
113 tracing::debug!(
114 subscription_id,
115 client_id = %entry.client_id,
116 entity = %entry.entity,
117 events_sent = entry.events_sent,
118 "subscription removed"
119 );
120
121 Ok(())
122 }
123
124 pub async fn publish_event(
129 &self,
130 entity: &str,
131 entity_id: [u8; 16],
132 change_type: ChangeType,
133 changed_fields: Vec<String>,
134 schema_version: u64,
135 ) {
136 let subscription_ids = {
138 let index = self.entity_index.read().await;
139 match index.get(entity) {
140 Some(ids) => ids.clone(),
141 None => return, }
143 };
144
145 if subscription_ids.is_empty() {
146 return;
147 }
148
149 let mut events = Vec::new();
151 for subscription_id in subscription_ids {
152 let event = ChangeEvent {
153 subscription_id,
154 change_type,
155 entity: entity.to_string(),
156 entity_id,
157 changed_fields: changed_fields.clone(),
158 schema_version,
159 };
160 events.push(event);
161 }
162
163 {
165 let mut queue = self.event_queue.write().await;
166 queue.extend(events);
167 }
168
169 tracing::trace!(
170 entity,
171 change_type = ?change_type,
172 "published change event"
173 );
174 }
175
176 pub async fn drain_events(&self) -> Vec<ChangeEvent> {
180 let mut queue = self.event_queue.write().await;
181 std::mem::take(&mut *queue)
182 }
183
184 pub async fn subscription_count(&self) -> usize {
186 self.subscriptions.read().await.len()
187 }
188
189 pub async fn subscriptions_for_entity(&self, entity: &str) -> Vec<u64> {
191 let index = self.entity_index.read().await;
192 index.get(entity).cloned().unwrap_or_default()
193 }
194
195 pub async fn get_subscription(&self, subscription_id: u64) -> Option<SubscriptionEntry> {
197 let subs = self.subscriptions.read().await;
198 subs.get(&subscription_id).cloned()
199 }
200
201 pub async fn remove_client_subscriptions(&self, client_id: &str) {
203 let to_remove: Vec<u64> = {
204 let subs = self.subscriptions.read().await;
205 subs.iter()
206 .filter(|(_, entry)| entry.client_id == client_id)
207 .map(|(&id, _)| id)
208 .collect()
209 };
210
211 for subscription_id in to_remove {
212 let _ = self.unsubscribe(subscription_id).await;
213 }
214 }
215}
216
217impl Default for PubSubManager {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
223pub type SharedPubSubManager = Arc<PubSubManager>;
225
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let manager = PubSubManager::new();

        let sub = Subscription::new("User");
        let id = manager.subscribe("client-1", &sub).await.unwrap();

        assert_eq!(manager.subscription_count().await, 1);
        assert_eq!(manager.subscriptions_for_entity("User").await, vec![id]);

        manager.unsubscribe(id).await.unwrap();

        assert_eq!(manager.subscription_count().await, 0);
        assert!(manager.subscriptions_for_entity("User").await.is_empty());
    }

    #[tokio::test]
    async fn test_unsubscribe_unknown_id_fails() {
        let manager = PubSubManager::new();

        // Never-issued id.
        assert!(manager.unsubscribe(42).await.is_err());

        // Removing the same subscription twice must fail the second time.
        let id = manager.subscribe("client-1", &Subscription::new("User")).await.unwrap();
        manager.unsubscribe(id).await.unwrap();
        assert!(manager.unsubscribe(id).await.is_err());
    }

    #[tokio::test]
    async fn test_publish_event() {
        let manager = PubSubManager::new();

        let sub = Subscription::new("User");
        let id = manager.subscribe("client-1", &sub).await.unwrap();

        manager
            .publish_event("User", [1u8; 16], ChangeType::Insert, vec!["name".to_string()], 1)
            .await;

        let events = manager.drain_events().await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].subscription_id, id);
        assert_eq!(events[0].change_type, ChangeType::Insert);
        assert_eq!(events[0].entity, "User");

        // Draining empties the queue.
        assert!(manager.drain_events().await.is_empty());
    }

    #[tokio::test]
    async fn test_events_routed_by_entity() {
        let manager = PubSubManager::new();

        let user_sub = manager.subscribe("client-1", &Subscription::new("User")).await.unwrap();
        manager.subscribe("client-1", &Subscription::new("Post")).await.unwrap();

        manager
            .publish_event("User", [7u8; 16], ChangeType::Update, vec!["email".to_string()], 3)
            .await;

        // Only the "User" subscription receives the event, with the payload intact.
        let events = manager.drain_events().await;
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].subscription_id, user_sub);
        assert_eq!(events[0].entity_id, [7u8; 16]);
        assert_eq!(events[0].changed_fields, vec!["email".to_string()]);
        assert_eq!(events[0].schema_version, 3);
    }

    #[tokio::test]
    async fn test_no_event_without_subscription() {
        let manager = PubSubManager::new();

        manager
            .publish_event("User", [1u8; 16], ChangeType::Insert, vec![], 1)
            .await;

        assert!(manager.drain_events().await.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_subscriptions() {
        let manager = PubSubManager::new();

        let sub = Subscription::new("User");
        let id1 = manager.subscribe("client-1", &sub).await.unwrap();
        let id2 = manager.subscribe("client-2", &sub).await.unwrap();

        assert_eq!(manager.subscription_count().await, 2);

        manager
            .publish_event("User", [1u8; 16], ChangeType::Update, vec![], 1)
            .await;

        let events = manager.drain_events().await;
        assert_eq!(events.len(), 2);

        let ids: Vec<u64> = events.iter().map(|e| e.subscription_id).collect();
        assert!(ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[tokio::test]
    async fn test_remove_client_subscriptions() {
        let manager = PubSubManager::new();

        manager.subscribe("client-1", &Subscription::new("User")).await.unwrap();
        manager.subscribe("client-1", &Subscription::new("Post")).await.unwrap();

        let id3 = manager.subscribe("client-2", &Subscription::new("User")).await.unwrap();

        assert_eq!(manager.subscription_count().await, 3);

        manager.remove_client_subscriptions("client-1").await;

        assert_eq!(manager.subscription_count().await, 1);
        assert!(manager.get_subscription(id3).await.is_some());

        // The entity index must be cleaned up as well, not just the entries.
        assert_eq!(manager.subscriptions_for_entity("User").await, vec![id3]);
        assert!(manager.subscriptions_for_entity("Post").await.is_empty());
    }
}