helios_subscriptions/channels/
ws_manager.rs1use dashmap::DashMap;
8use tokio::sync::mpsc;
9use tracing::{debug, warn};
10use uuid::Uuid;
11
/// Sending half of the unbounded channel that carries JSON notifications to a
/// single registered client.
pub type WsClientSender = mpsc::UnboundedSender<serde_json::Value>;
14
15pub struct WebSocketManager {
17 clients: DashMap<(String, String), Vec<(String, WsClientSender)>>,
20}
21
22impl WebSocketManager {
23 pub fn new() -> Self {
25 Self {
26 clients: DashMap::new(),
27 }
28 }
29
30 pub fn register_client(
35 &self,
36 tenant_id: &str,
37 subscription_id: &str,
38 ) -> (String, mpsc::UnboundedReceiver<serde_json::Value>) {
39 let (tx, rx) = mpsc::unbounded_channel();
40 let client_id = Uuid::new_v4().to_string();
41
42 self.clients
43 .entry((tenant_id.to_string(), subscription_id.to_string()))
44 .or_default()
45 .push((client_id.clone(), tx));
46
47 debug!(
48 tenant_id,
49 subscription_id,
50 client_id = %client_id,
51 "WebSocket client registered"
52 );
53
54 (client_id, rx)
55 }
56
57 pub fn remove_client(&self, tenant_id: &str, subscription_id: &str, client_id: &str) {
59 let key = (tenant_id.to_string(), subscription_id.to_string());
60
61 if let Some(mut entry) = self.clients.get_mut(&key) {
62 entry.retain(|(id, _)| id != client_id);
63 let is_empty = entry.is_empty();
64 drop(entry);
65
66 if is_empty {
68 self.clients.remove(&key);
69 }
70 }
71
72 debug!(
73 tenant_id,
74 subscription_id, client_id, "WebSocket client removed"
75 );
76 }
77
78 pub fn send_to_subscription(
83 &self,
84 tenant_id: &str,
85 subscription_id: &str,
86 notification: &serde_json::Value,
87 ) -> usize {
88 let key = (tenant_id.to_string(), subscription_id.to_string());
89
90 let Some(mut entry) = self.clients.get_mut(&key) else {
91 return 0;
92 };
93
94 let mut delivered = 0;
95 let initial_count = entry.len();
96
97 entry.retain(|(client_id, sender)| {
99 if sender.send(notification.clone()).is_ok() {
100 delivered += 1;
101 true
102 } else {
103 debug!(client_id = %client_id, "Pruning disconnected WebSocket client");
104 false
105 }
106 });
107
108 let pruned = initial_count - entry.len();
109 if pruned > 0 {
110 warn!(
111 tenant_id,
112 subscription_id, pruned, "Pruned disconnected WebSocket clients"
113 );
114 }
115
116 if entry.is_empty() {
118 drop(entry);
119 self.clients.remove(&key);
120 }
121
122 delivered
123 }
124
125 pub fn client_count(&self, tenant_id: &str, subscription_id: &str) -> usize {
127 let key = (tenant_id.to_string(), subscription_id.to_string());
128 self.clients.get(&key).map(|e| e.len()).unwrap_or(0)
129 }
130
131 pub fn remove_all_clients(&self, tenant_id: &str, subscription_id: &str) {
136 let key = (tenant_id.to_string(), subscription_id.to_string());
137 if let Some((_, clients)) = self.clients.remove(&key) {
138 debug!(
139 tenant_id,
140 subscription_id,
141 count = clients.len(),
142 "Removed all WebSocket clients for subscription"
143 );
144 }
145 }
146}
147
148impl Default for WebSocketManager {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A single registered client receives the notification that was sent.
    #[tokio::test]
    async fn test_register_and_receive() {
        let mgr = WebSocketManager::new();
        let (_client_id, mut rx) = mgr.register_client("t1", "sub-1");

        let notification = json!({"resourceType": "Bundle", "type": "history"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        assert_eq!(count, 1);
        let received = rx.recv().await.unwrap();
        assert_eq!(received, notification);
    }

    /// Every client of the same subscription gets its own copy.
    #[tokio::test]
    async fn test_multi_client_broadcast() {
        let mgr = WebSocketManager::new();
        let (_id1, mut rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, mut rx2) = mgr.register_client("t1", "sub-1");

        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        assert_eq!(count, 2);
        assert_eq!(rx1.recv().await.unwrap(), notification);
        assert_eq!(rx2.recv().await.unwrap(), notification);
    }

    /// A client whose receiver was dropped is pruned on the next send.
    #[tokio::test]
    async fn test_disconnected_client_pruned() {
        let mgr = WebSocketManager::new();
        let (_id1, rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, mut rx2) = mgr.register_client("t1", "sub-1");

        // Dropping the receiver simulates a disconnected client.
        drop(rx1);

        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        assert_eq!(count, 1);
        assert_eq!(rx2.recv().await.unwrap(), notification);
        assert_eq!(mgr.client_count("t1", "sub-1"), 1);
    }

    /// Removing one client leaves the others registered.
    #[test]
    fn test_remove_client() {
        let mgr = WebSocketManager::new();
        let (id1, _rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, _rx2) = mgr.register_client("t1", "sub-1");

        assert_eq!(mgr.client_count("t1", "sub-1"), 2);

        mgr.remove_client("t1", "sub-1", &id1);
        assert_eq!(mgr.client_count("t1", "sub-1"), 1);
    }

    /// Removing the last client also drops the subscription's map entry.
    #[test]
    fn test_remove_last_client_drops_entry() {
        let mgr = WebSocketManager::new();
        let (id, _rx) = mgr.register_client("t1", "sub-1");

        mgr.remove_client("t1", "sub-1", &id);

        assert_eq!(mgr.client_count("t1", "sub-1"), 0);
        assert!(mgr.clients.is_empty());
    }

    /// `remove_all_clients` empties the subscription.
    #[test]
    fn test_remove_all_clients() {
        let mgr = WebSocketManager::new();
        let (_id1, _rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, _rx2) = mgr.register_client("t1", "sub-1");

        mgr.remove_all_clients("t1", "sub-1");
        assert_eq!(mgr.client_count("t1", "sub-1"), 0);
    }

    /// Dropping the senders closes the receivers.
    #[tokio::test]
    async fn test_remove_all_closes_receivers() {
        let mgr = WebSocketManager::new();
        let (_id1, mut rx1) = mgr.register_client("t1", "sub-1");

        mgr.remove_all_clients("t1", "sub-1");

        // With every sender gone, `recv` resolves to `None`.
        assert!(rx1.recv().await.is_none());
    }

    /// Sending to an unknown subscription delivers to nobody.
    #[test]
    fn test_send_to_nonexistent_subscription() {
        let mgr = WebSocketManager::new();
        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);
        assert_eq!(count, 0);
    }

    /// An unknown subscription reports zero clients.
    #[test]
    fn test_client_count_empty() {
        let mgr = WebSocketManager::new();
        assert_eq!(mgr.client_count("t1", "sub-1"), 0);
    }
}