Skip to main content

helios_subscriptions/channels/
ws_manager.rs

1//! WebSocket client connection manager.
2//!
3//! Tracks connected WebSocket clients per subscription. The subscriptions
4//! crate uses this to broadcast notifications, while the REST crate uses
5//! it to register new connections from the WebSocket upgrade handler.
6
7use dashmap::DashMap;
8use tokio::sync::mpsc;
9use tracing::{debug, warn};
10use uuid::Uuid;
11
12/// Sender half given to a connected WebSocket client.
13pub type WsClientSender = mpsc::UnboundedSender<serde_json::Value>;
14
15/// Manages connected WebSocket clients for all active subscriptions.
16pub struct WebSocketManager {
17    /// Connected clients keyed by (tenant_id, subscription_id).
18    /// Each entry is a vec of (client_id, sender).
19    clients: DashMap<(String, String), Vec<(String, WsClientSender)>>,
20}
21
22impl WebSocketManager {
23    /// Creates a new empty manager.
24    pub fn new() -> Self {
25        Self {
26            clients: DashMap::new(),
27        }
28    }
29
30    /// Registers a new WebSocket client for a subscription.
31    ///
32    /// Returns `(client_id, receiver)`. The caller (WebSocket handler) reads
33    /// notifications from the receiver and forwards them over the socket.
34    pub fn register_client(
35        &self,
36        tenant_id: &str,
37        subscription_id: &str,
38    ) -> (String, mpsc::UnboundedReceiver<serde_json::Value>) {
39        let (tx, rx) = mpsc::unbounded_channel();
40        let client_id = Uuid::new_v4().to_string();
41
42        self.clients
43            .entry((tenant_id.to_string(), subscription_id.to_string()))
44            .or_default()
45            .push((client_id.clone(), tx));
46
47        debug!(
48            tenant_id,
49            subscription_id,
50            client_id = %client_id,
51            "WebSocket client registered"
52        );
53
54        (client_id, rx)
55    }
56
57    /// Removes a specific client by ID (called on disconnect).
58    pub fn remove_client(&self, tenant_id: &str, subscription_id: &str, client_id: &str) {
59        let key = (tenant_id.to_string(), subscription_id.to_string());
60
61        if let Some(mut entry) = self.clients.get_mut(&key) {
62            entry.retain(|(id, _)| id != client_id);
63            let is_empty = entry.is_empty();
64            drop(entry);
65
66            // Clean up empty entries.
67            if is_empty {
68                self.clients.remove(&key);
69            }
70        }
71
72        debug!(
73            tenant_id,
74            subscription_id, client_id, "WebSocket client removed"
75        );
76    }
77
78    /// Broadcasts a notification to all connected clients for a subscription.
79    ///
80    /// Automatically prunes disconnected clients (closed channels).
81    /// Returns the number of clients that received the message.
82    pub fn send_to_subscription(
83        &self,
84        tenant_id: &str,
85        subscription_id: &str,
86        notification: &serde_json::Value,
87    ) -> usize {
88        let key = (tenant_id.to_string(), subscription_id.to_string());
89
90        let Some(mut entry) = self.clients.get_mut(&key) else {
91            return 0;
92        };
93
94        let mut delivered = 0;
95        let initial_count = entry.len();
96
97        // Send to all clients, removing those with closed channels.
98        entry.retain(|(client_id, sender)| {
99            if sender.send(notification.clone()).is_ok() {
100                delivered += 1;
101                true
102            } else {
103                debug!(client_id = %client_id, "Pruning disconnected WebSocket client");
104                false
105            }
106        });
107
108        let pruned = initial_count - entry.len();
109        if pruned > 0 {
110            warn!(
111                tenant_id,
112                subscription_id, pruned, "Pruned disconnected WebSocket clients"
113            );
114        }
115
116        // Clean up empty entries.
117        if entry.is_empty() {
118            drop(entry);
119            self.clients.remove(&key);
120        }
121
122        delivered
123    }
124
125    /// Returns the number of connected clients for a subscription.
126    pub fn client_count(&self, tenant_id: &str, subscription_id: &str) -> usize {
127        let key = (tenant_id.to_string(), subscription_id.to_string());
128        self.clients.get(&key).map(|e| e.len()).unwrap_or(0)
129    }
130
131    /// Removes all clients for a subscription (called on deregistration).
132    ///
133    /// Dropping the senders causes receivers to return `None`, which
134    /// triggers graceful close in the WebSocket handlers.
135    pub fn remove_all_clients(&self, tenant_id: &str, subscription_id: &str) {
136        let key = (tenant_id.to_string(), subscription_id.to_string());
137        if let Some((_, clients)) = self.clients.remove(&key) {
138            debug!(
139                tenant_id,
140                subscription_id,
141                count = clients.len(),
142                "Removed all WebSocket clients for subscription"
143            );
144        }
145    }
146}
147
148impl Default for WebSocketManager {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_register_and_receive() {
        let mgr = WebSocketManager::new();
        let (_client_id, mut rx) = mgr.register_client("t1", "sub-1");

        let notification = json!({"resourceType": "Bundle", "type": "history"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        assert_eq!(count, 1);
        let received = rx.recv().await.unwrap();
        assert_eq!(received, notification);
    }

    #[tokio::test]
    async fn test_multi_client_broadcast() {
        let mgr = WebSocketManager::new();
        let (_id1, mut rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, mut rx2) = mgr.register_client("t1", "sub-1");

        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        assert_eq!(count, 2);
        assert_eq!(rx1.recv().await.unwrap(), notification);
        assert_eq!(rx2.recv().await.unwrap(), notification);
    }

    #[tokio::test]
    async fn test_disconnected_client_pruned() {
        let mgr = WebSocketManager::new();
        let (_id1, rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, mut rx2) = mgr.register_client("t1", "sub-1");

        // Drop rx1 to simulate disconnect.
        drop(rx1);

        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        assert_eq!(count, 1);
        assert_eq!(rx2.recv().await.unwrap(), notification);
        assert_eq!(mgr.client_count("t1", "sub-1"), 1);
    }

    #[test]
    fn test_all_disconnected_clients_pruned_removes_entry() {
        let mgr = WebSocketManager::new();
        let (_id1, rx1) = mgr.register_client("t1", "sub-1");
        drop(rx1);

        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        assert_eq!(count, 0);
        // The now-empty entry must not linger in the map.
        assert!(mgr.clients.is_empty());
    }

    #[test]
    fn test_remove_client() {
        let mgr = WebSocketManager::new();
        let (id1, _rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, _rx2) = mgr.register_client("t1", "sub-1");

        assert_eq!(mgr.client_count("t1", "sub-1"), 2);

        mgr.remove_client("t1", "sub-1", &id1);
        assert_eq!(mgr.client_count("t1", "sub-1"), 1);
    }

    #[test]
    fn test_remove_last_client_cleans_up_entry() {
        let mgr = WebSocketManager::new();
        let (id1, _rx1) = mgr.register_client("t1", "sub-1");

        mgr.remove_client("t1", "sub-1", &id1);

        assert_eq!(mgr.client_count("t1", "sub-1"), 0);
        assert!(mgr.clients.is_empty());
    }

    #[test]
    fn test_remove_unknown_client_is_noop() {
        let mgr = WebSocketManager::new();
        let (_id1, _rx1) = mgr.register_client("t1", "sub-1");

        // Unknown client on a known subscription leaves it untouched.
        mgr.remove_client("t1", "sub-1", "no-such-client");
        assert_eq!(mgr.client_count("t1", "sub-1"), 1);

        // Unknown subscription neither panics nor creates an entry.
        mgr.remove_client("t1", "sub-2", "no-such-client");
        assert_eq!(mgr.client_count("t1", "sub-2"), 0);
        assert_eq!(mgr.clients.len(), 1);
    }

    #[tokio::test]
    async fn test_register_after_last_client_removed() {
        let mgr = WebSocketManager::new();
        let (id1, _rx1) = mgr.register_client("t1", "sub-1");
        mgr.remove_client("t1", "sub-1", &id1);

        let (_id2, mut rx2) = mgr.register_client("t1", "sub-1");
        let notification = json!({"type": "event-notification"});

        assert_eq!(mgr.send_to_subscription("t1", "sub-1", &notification), 1);
        assert_eq!(rx2.recv().await.unwrap(), notification);
    }

    #[test]
    fn test_tenant_isolation() {
        let mgr = WebSocketManager::new();
        let (_id_a, _rx_a) = mgr.register_client("t1", "sub-1");
        let (_id_b, mut rx_b) = mgr.register_client("t2", "sub-1");

        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);

        // Same subscription id under another tenant must not receive it.
        assert_eq!(count, 1);
        assert!(rx_b.try_recv().is_err());
        assert_eq!(mgr.client_count("t2", "sub-1"), 1);
    }

    #[test]
    fn test_remove_all_clients() {
        let mgr = WebSocketManager::new();
        let (_id1, _rx1) = mgr.register_client("t1", "sub-1");
        let (_id2, _rx2) = mgr.register_client("t1", "sub-1");

        mgr.remove_all_clients("t1", "sub-1");
        assert_eq!(mgr.client_count("t1", "sub-1"), 0);
    }

    #[tokio::test]
    async fn test_remove_all_closes_receivers() {
        let mgr = WebSocketManager::new();
        let (_id1, mut rx1) = mgr.register_client("t1", "sub-1");

        mgr.remove_all_clients("t1", "sub-1");

        // Receiver should return None (channel closed).
        assert!(rx1.recv().await.is_none());
    }

    #[test]
    fn test_send_to_nonexistent_subscription() {
        let mgr = WebSocketManager::new();
        let notification = json!({"type": "event-notification"});
        let count = mgr.send_to_subscription("t1", "sub-1", &notification);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_client_count_empty() {
        let mgr = WebSocketManager::new();
        assert_eq!(mgr.client_count("t1", "sub-1"), 0);
    }
}