// hyperstack_server/websocket/client_manager.rs

1use super::subscription::Subscription;
2use crate::compression::CompressedPayload;
3use bytes::Bytes;
4use dashmap::DashMap;
5use futures_util::stream::SplitSink;
6use futures_util::SinkExt;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::net::TcpStream;
11use tokio::sync::{mpsc, RwLock};
12use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
13use tokio_util::sync::CancellationToken;
14use tracing::{debug, info, warn};
15use uuid::Uuid;
16
17pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
18
/// Reasons a message could not be queued for a client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No client is registered under the given id.
    ClientNotFound,
    /// The client's outbound queue was full, so the client was dropped.
    ClientBackpressured,
    /// The client's outbound channel is closed; the client is gone.
    ClientDisconnected,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            Self::ClientNotFound => "client not found",
            Self::ClientBackpressured => "client backpressured and disconnected",
            Self::ClientDisconnected => "client disconnected",
        };
        f.write_str(text)
    }
}

impl std::error::Error for SendError {}

42/// Information about a connected client
43#[derive(Debug)]
44pub struct ClientInfo {
45    pub id: Uuid,
46    pub subscription: Option<Subscription>,
47    pub last_seen: SystemTime,
48    pub sender: mpsc::Sender<Message>,
49    subscriptions: Arc<RwLock<HashMap<String, CancellationToken>>>,
50}
51
52impl ClientInfo {
53    pub fn new(id: Uuid, sender: mpsc::Sender<Message>) -> Self {
54        Self {
55            id,
56            subscription: None,
57            last_seen: SystemTime::now(),
58            sender,
59            subscriptions: Arc::new(RwLock::new(HashMap::new())),
60        }
61    }
62
63    pub fn update_last_seen(&mut self) {
64        self.last_seen = SystemTime::now();
65    }
66
67    pub fn is_stale(&self, timeout: Duration) -> bool {
68        self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
69    }
70
71    pub async fn add_subscription(&self, sub_key: String, token: CancellationToken) -> bool {
72        let mut subs = self.subscriptions.write().await;
73        if let Some(old_token) = subs.insert(sub_key.clone(), token) {
74            old_token.cancel();
75            debug!("Replaced existing subscription: {}", sub_key);
76            false
77        } else {
78            true
79        }
80    }
81
82    pub async fn remove_subscription(&self, sub_key: &str) -> bool {
83        let mut subs = self.subscriptions.write().await;
84        if let Some(token) = subs.remove(sub_key) {
85            token.cancel();
86            debug!("Cancelled subscription: {}", sub_key);
87            true
88        } else {
89            debug!("Subscription not found for cancellation: {}", sub_key);
90            false
91        }
92    }
93
94    pub async fn cancel_all_subscriptions(&self) {
95        let subs = self.subscriptions.read().await;
96        for (sub_key, token) in subs.iter() {
97            token.cancel();
98            debug!("Cancelled subscription on disconnect: {}", sub_key);
99        }
100    }
101
102    pub async fn subscription_count(&self) -> usize {
103        self.subscriptions.read().await.len()
104    }
105}
106
107/// Manages all connected WebSocket clients using lock-free DashMap.
108///
109/// Key design decisions:
110/// - Uses DashMap for lock-free concurrent access to client registry
111/// - Uses try_send instead of send to never block on slow clients
112/// - Disconnects clients that are backpressured (queue full) to prevent cascade failures
113/// - All public methods are non-blocking or use fine-grained per-key locks
114#[derive(Clone)]
115pub struct ClientManager {
116    clients: Arc<DashMap<Uuid, ClientInfo>>,
117    client_timeout: Duration,
118    message_queue_size: usize,
119}
120
121impl ClientManager {
122    pub fn new() -> Self {
123        Self {
124            clients: Arc::new(DashMap::new()),
125            client_timeout: Duration::from_secs(300),
126            message_queue_size: 512,
127        }
128    }
129
130    pub fn with_timeout(mut self, timeout: Duration) -> Self {
131        self.client_timeout = timeout;
132        self
133    }
134
135    pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
136        self.message_queue_size = queue_size;
137        self
138    }
139
140    /// Add a new client connection.
141    ///
142    /// Spawns a dedicated sender task for this client that reads from its mpsc channel
143    /// and writes to the WebSocket. If the WebSocket write fails, the client is automatically
144    /// removed from the registry.
145    pub fn add_client(&self, client_id: Uuid, mut ws_sender: WebSocketSender) {
146        let (client_tx, mut client_rx) = mpsc::channel::<Message>(self.message_queue_size);
147        let client_info = ClientInfo::new(client_id, client_tx);
148
149        let clients_ref = self.clients.clone();
150        tokio::spawn(async move {
151            while let Some(message) = client_rx.recv().await {
152                if let Err(e) = ws_sender.send(message).await {
153                    warn!("Failed to send message to client {}: {}", client_id, e);
154                    break;
155                }
156            }
157            clients_ref.remove(&client_id);
158            debug!("WebSocket sender task for client {} stopped", client_id);
159        });
160
161        self.clients.insert(client_id, client_info);
162        info!("Client {} registered", client_id);
163    }
164
165    /// Remove a client from the registry.
166    pub fn remove_client(&self, client_id: Uuid) {
167        if self.clients.remove(&client_id).is_some() {
168            info!("Client {} removed", client_id);
169        }
170    }
171
172    /// Get the current number of connected clients.
173    ///
174    /// This is lock-free and returns an approximate count (may be slightly stale
175    /// under high concurrency, which is fine for max_clients checks).
176    pub fn client_count(&self) -> usize {
177        self.clients.len()
178    }
179
180    /// Send data to a specific client (non-blocking).
181    ///
182    /// This method NEVER blocks. If the client's queue is full, the client is
183    /// considered too slow and is disconnected to prevent cascade failures.
184    /// Use this for live streaming updates.
185    ///
186    /// For initial snapshots where you expect to send many messages at once,
187    /// use `send_to_client_async` instead which will wait for queue space.
188    pub fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<(), SendError> {
189        let sender = {
190            let client = self
191                .clients
192                .get(&client_id)
193                .ok_or(SendError::ClientNotFound)?;
194            client.sender.clone()
195        };
196
197        let msg = Message::Binary((*data).clone());
198        match sender.try_send(msg) {
199            Ok(()) => Ok(()),
200            Err(mpsc::error::TrySendError::Full(_)) => {
201                warn!(
202                    "Client {} backpressured (queue full), disconnecting",
203                    client_id
204                );
205                self.clients.remove(&client_id);
206                Err(SendError::ClientBackpressured)
207            }
208            Err(mpsc::error::TrySendError::Closed(_)) => {
209                debug!("Client {} channel closed", client_id);
210                self.clients.remove(&client_id);
211                Err(SendError::ClientDisconnected)
212            }
213        }
214    }
215
216    /// Send data to a specific client (async, waits for queue space).
217    ///
218    /// This method will wait if the client's queue is full, allowing the client
219    /// time to catch up. Use this for initial snapshots where you need to send
220    /// many messages at once.
221    ///
222    /// For live streaming updates, use `send_to_client` instead which will
223    /// disconnect slow clients rather than blocking.
224    pub async fn send_to_client_async(
225        &self,
226        client_id: Uuid,
227        data: Arc<Bytes>,
228    ) -> Result<(), SendError> {
229        let sender = {
230            let client = self
231                .clients
232                .get(&client_id)
233                .ok_or(SendError::ClientNotFound)?;
234            client.sender.clone()
235        };
236
237        let msg = Message::Binary((*data).clone());
238        sender
239            .send(msg)
240            .await
241            .map_err(|_| SendError::ClientDisconnected)
242    }
243
244    /// Send a potentially compressed payload to a client (async).
245    ///
246    /// Compressed payloads are sent as binary frames (raw gzip).
247    /// Uncompressed payloads are sent as text frames (JSON).
248    pub async fn send_compressed_async(
249        &self,
250        client_id: Uuid,
251        payload: CompressedPayload,
252    ) -> Result<(), SendError> {
253        let sender = {
254            let client = self
255                .clients
256                .get(&client_id)
257                .ok_or(SendError::ClientNotFound)?;
258            client.sender.clone()
259        };
260
261        let msg = match payload {
262            CompressedPayload::Compressed(bytes) => Message::Binary(bytes),
263            CompressedPayload::Uncompressed(bytes) => Message::Binary(bytes),
264        };
265        sender
266            .send(msg)
267            .await
268            .map_err(|_| SendError::ClientDisconnected)
269    }
270
271    /// Update the subscription for a client.
272    pub fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
273        if let Some(mut client) = self.clients.get_mut(&client_id) {
274            client.subscription = Some(subscription);
275            client.update_last_seen();
276            debug!("Updated subscription for client {}", client_id);
277            true
278        } else {
279            warn!(
280                "Failed to update subscription for unknown client {}",
281                client_id
282            );
283            false
284        }
285    }
286
287    /// Update the last_seen timestamp for a client.
288    pub fn update_client_last_seen(&self, client_id: Uuid) {
289        if let Some(mut client) = self.clients.get_mut(&client_id) {
290            client.update_last_seen();
291        }
292    }
293
294    /// Get the subscription for a client.
295    pub fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
296        self.clients
297            .get(&client_id)
298            .and_then(|c| c.subscription.clone())
299    }
300
301    /// Check if a client exists.
302    pub fn has_client(&self, client_id: Uuid) -> bool {
303        self.clients.contains_key(&client_id)
304    }
305
306    pub async fn add_client_subscription(
307        &self,
308        client_id: Uuid,
309        sub_key: String,
310        token: CancellationToken,
311    ) -> bool {
312        if let Some(client) = self.clients.get(&client_id) {
313            client.add_subscription(sub_key, token).await
314        } else {
315            false
316        }
317    }
318
319    pub async fn remove_client_subscription(&self, client_id: Uuid, sub_key: &str) -> bool {
320        if let Some(client) = self.clients.get(&client_id) {
321            client.remove_subscription(sub_key).await
322        } else {
323            false
324        }
325    }
326
327    pub async fn cancel_all_client_subscriptions(&self, client_id: Uuid) {
328        if let Some(client) = self.clients.get(&client_id) {
329            client.cancel_all_subscriptions().await;
330        }
331    }
332
333    /// Remove stale clients that haven't been seen within the timeout period.
334    pub fn cleanup_stale_clients(&self) -> usize {
335        let timeout = self.client_timeout;
336        let mut stale_clients = Vec::new();
337
338        for entry in self.clients.iter() {
339            if entry.value().is_stale(timeout) {
340                stale_clients.push(*entry.key());
341            }
342        }
343
344        let removed_count = stale_clients.len();
345        for client_id in stale_clients {
346            self.clients.remove(&client_id);
347            info!("Removed stale client {}", client_id);
348        }
349
350        removed_count
351    }
352
353    /// Start a background task that periodically cleans up stale clients.
354    pub fn start_cleanup_task(&self) {
355        let client_manager = self.clone();
356
357        tokio::spawn(async move {
358            let mut interval = tokio::time::interval(Duration::from_secs(30));
359
360            loop {
361                interval.tick().await;
362                let removed = client_manager.cleanup_stale_clients();
363                if removed > 0 {
364                    info!("Cleaned up {} stale clients", removed);
365                }
366            }
367        });
368    }
369}
370
371impl Default for ClientManager {
372    fn default() -> Self {
373        Self::new()
374    }
375}