hyperstack_server/websocket/
client_manager.rs

1use super::subscription::Subscription;
2use bytes::Bytes;
3use dashmap::DashMap;
4use futures_util::stream::SplitSink;
5use futures_util::SinkExt;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use tokio::net::TcpStream;
10use tokio::sync::{mpsc, RwLock};
11use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
12use tokio_util::sync::CancellationToken;
13use tracing::{debug, info, warn};
14use uuid::Uuid;
15
16pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
17
/// Reasons a message could not be queued for a client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SendError {
    /// No client with the requested id is present in the registry.
    ClientNotFound,
    /// The client's outbound queue was full; the client was disconnected.
    ClientBackpressured,
    /// The client's outbound channel was closed; the client was disconnected.
    ClientDisconnected,
}
28
29impl std::fmt::Display for SendError {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            SendError::ClientNotFound => write!(f, "client not found"),
33            SendError::ClientBackpressured => write!(f, "client backpressured and disconnected"),
34            SendError::ClientDisconnected => write!(f, "client disconnected"),
35        }
36    }
37}
38
39impl std::error::Error for SendError {}
40
41/// Information about a connected client
42#[derive(Debug)]
43pub struct ClientInfo {
44    pub id: Uuid,
45    pub subscription: Option<Subscription>,
46    pub last_seen: SystemTime,
47    pub sender: mpsc::Sender<Message>,
48    subscriptions: Arc<RwLock<HashMap<String, CancellationToken>>>,
49}
50
51impl ClientInfo {
52    pub fn new(id: Uuid, sender: mpsc::Sender<Message>) -> Self {
53        Self {
54            id,
55            subscription: None,
56            last_seen: SystemTime::now(),
57            sender,
58            subscriptions: Arc::new(RwLock::new(HashMap::new())),
59        }
60    }
61
62    pub fn update_last_seen(&mut self) {
63        self.last_seen = SystemTime::now();
64    }
65
66    pub fn is_stale(&self, timeout: Duration) -> bool {
67        self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
68    }
69
70    pub async fn add_subscription(&self, sub_key: String, token: CancellationToken) {
71        let mut subs = self.subscriptions.write().await;
72        if let Some(old_token) = subs.insert(sub_key.clone(), token) {
73            old_token.cancel();
74            debug!("Replaced existing subscription: {}", sub_key);
75        }
76    }
77
78    pub async fn remove_subscription(&self, sub_key: &str) -> bool {
79        let mut subs = self.subscriptions.write().await;
80        if let Some(token) = subs.remove(sub_key) {
81            token.cancel();
82            debug!("Cancelled subscription: {}", sub_key);
83            true
84        } else {
85            debug!("Subscription not found for cancellation: {}", sub_key);
86            false
87        }
88    }
89
90    pub async fn cancel_all_subscriptions(&self) {
91        let subs = self.subscriptions.read().await;
92        for (sub_key, token) in subs.iter() {
93            token.cancel();
94            debug!("Cancelled subscription on disconnect: {}", sub_key);
95        }
96    }
97
98    pub async fn subscription_count(&self) -> usize {
99        self.subscriptions.read().await.len()
100    }
101}
102
103/// Manages all connected WebSocket clients using lock-free DashMap.
104///
105/// Key design decisions:
106/// - Uses DashMap for lock-free concurrent access to client registry
107/// - Uses try_send instead of send to never block on slow clients
108/// - Disconnects clients that are backpressured (queue full) to prevent cascade failures
109/// - All public methods are non-blocking or use fine-grained per-key locks
110#[derive(Clone)]
111pub struct ClientManager {
112    clients: Arc<DashMap<Uuid, ClientInfo>>,
113    client_timeout: Duration,
114    message_queue_size: usize,
115}
116
117impl ClientManager {
118    pub fn new() -> Self {
119        Self {
120            clients: Arc::new(DashMap::new()),
121            client_timeout: Duration::from_secs(300),
122            message_queue_size: 512,
123        }
124    }
125
126    pub fn with_timeout(mut self, timeout: Duration) -> Self {
127        self.client_timeout = timeout;
128        self
129    }
130
131    pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
132        self.message_queue_size = queue_size;
133        self
134    }
135
136    /// Add a new client connection.
137    ///
138    /// Spawns a dedicated sender task for this client that reads from its mpsc channel
139    /// and writes to the WebSocket. If the WebSocket write fails, the client is automatically
140    /// removed from the registry.
141    pub fn add_client(&self, client_id: Uuid, mut ws_sender: WebSocketSender) {
142        let (client_tx, mut client_rx) = mpsc::channel::<Message>(self.message_queue_size);
143        let client_info = ClientInfo::new(client_id, client_tx);
144
145        let clients_ref = self.clients.clone();
146        tokio::spawn(async move {
147            while let Some(message) = client_rx.recv().await {
148                if let Err(e) = ws_sender.send(message).await {
149                    warn!("Failed to send message to client {}: {}", client_id, e);
150                    break;
151                }
152            }
153            clients_ref.remove(&client_id);
154            debug!("WebSocket sender task for client {} stopped", client_id);
155        });
156
157        self.clients.insert(client_id, client_info);
158        info!("Client {} registered", client_id);
159    }
160
161    /// Remove a client from the registry.
162    pub fn remove_client(&self, client_id: Uuid) {
163        if self.clients.remove(&client_id).is_some() {
164            info!("Client {} removed", client_id);
165        }
166    }
167
168    /// Get the current number of connected clients.
169    ///
170    /// This is lock-free and returns an approximate count (may be slightly stale
171    /// under high concurrency, which is fine for max_clients checks).
172    pub fn client_count(&self) -> usize {
173        self.clients.len()
174    }
175
176    /// Send data to a specific client (non-blocking).
177    ///
178    /// This method NEVER blocks. If the client's queue is full, the client is
179    /// considered too slow and is disconnected to prevent cascade failures.
180    /// Use this for live streaming updates.
181    ///
182    /// For initial snapshots where you expect to send many messages at once,
183    /// use `send_to_client_async` instead which will wait for queue space.
184    pub fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<(), SendError> {
185        let sender = {
186            let client = self
187                .clients
188                .get(&client_id)
189                .ok_or(SendError::ClientNotFound)?;
190            client.sender.clone()
191        };
192
193        let msg = Message::Binary((*data).clone());
194        match sender.try_send(msg) {
195            Ok(()) => Ok(()),
196            Err(mpsc::error::TrySendError::Full(_)) => {
197                warn!(
198                    "Client {} backpressured (queue full), disconnecting",
199                    client_id
200                );
201                self.clients.remove(&client_id);
202                Err(SendError::ClientBackpressured)
203            }
204            Err(mpsc::error::TrySendError::Closed(_)) => {
205                debug!("Client {} channel closed", client_id);
206                self.clients.remove(&client_id);
207                Err(SendError::ClientDisconnected)
208            }
209        }
210    }
211
212    /// Send data to a specific client (async, waits for queue space).
213    ///
214    /// This method will wait if the client's queue is full, allowing the client
215    /// time to catch up. Use this for initial snapshots where you need to send
216    /// many messages at once.
217    ///
218    /// For live streaming updates, use `send_to_client` instead which will
219    /// disconnect slow clients rather than blocking.
220    pub async fn send_to_client_async(
221        &self,
222        client_id: Uuid,
223        data: Arc<Bytes>,
224    ) -> Result<(), SendError> {
225        let sender = {
226            let client = self
227                .clients
228                .get(&client_id)
229                .ok_or(SendError::ClientNotFound)?;
230            client.sender.clone()
231        };
232
233        let msg = Message::Binary((*data).clone());
234        sender
235            .send(msg)
236            .await
237            .map_err(|_| SendError::ClientDisconnected)
238    }
239
240    /// Update the subscription for a client.
241    pub fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
242        if let Some(mut client) = self.clients.get_mut(&client_id) {
243            client.subscription = Some(subscription);
244            client.update_last_seen();
245            debug!("Updated subscription for client {}", client_id);
246            true
247        } else {
248            warn!(
249                "Failed to update subscription for unknown client {}",
250                client_id
251            );
252            false
253        }
254    }
255
256    /// Update the last_seen timestamp for a client.
257    pub fn update_client_last_seen(&self, client_id: Uuid) {
258        if let Some(mut client) = self.clients.get_mut(&client_id) {
259            client.update_last_seen();
260        }
261    }
262
263    /// Get the subscription for a client.
264    pub fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
265        self.clients
266            .get(&client_id)
267            .and_then(|c| c.subscription.clone())
268    }
269
270    /// Check if a client exists.
271    pub fn has_client(&self, client_id: Uuid) -> bool {
272        self.clients.contains_key(&client_id)
273    }
274
275    pub async fn add_client_subscription(
276        &self,
277        client_id: Uuid,
278        sub_key: String,
279        token: CancellationToken,
280    ) {
281        if let Some(client) = self.clients.get(&client_id) {
282            client.add_subscription(sub_key, token).await;
283        }
284    }
285
286    pub async fn remove_client_subscription(&self, client_id: Uuid, sub_key: &str) -> bool {
287        if let Some(client) = self.clients.get(&client_id) {
288            client.remove_subscription(sub_key).await
289        } else {
290            false
291        }
292    }
293
294    pub async fn cancel_all_client_subscriptions(&self, client_id: Uuid) {
295        if let Some(client) = self.clients.get(&client_id) {
296            client.cancel_all_subscriptions().await;
297        }
298    }
299
300    /// Remove stale clients that haven't been seen within the timeout period.
301    pub fn cleanup_stale_clients(&self) -> usize {
302        let timeout = self.client_timeout;
303        let mut stale_clients = Vec::new();
304
305        for entry in self.clients.iter() {
306            if entry.value().is_stale(timeout) {
307                stale_clients.push(*entry.key());
308            }
309        }
310
311        let removed_count = stale_clients.len();
312        for client_id in stale_clients {
313            self.clients.remove(&client_id);
314            info!("Removed stale client {}", client_id);
315        }
316
317        removed_count
318    }
319
320    /// Start a background task that periodically cleans up stale clients.
321    pub fn start_cleanup_task(&self) {
322        let client_manager = self.clone();
323
324        tokio::spawn(async move {
325            let mut interval = tokio::time::interval(Duration::from_secs(30));
326
327            loop {
328                interval.tick().await;
329                let removed = client_manager.cleanup_stale_clients();
330                if removed > 0 {
331                    info!("Cleaned up {} stale clients", removed);
332                }
333            }
334        });
335    }
336}
337
338impl Default for ClientManager {
339    fn default() -> Self {
340        Self::new()
341    }
342}