hyperstack_server/websocket/
client_manager.rs1use super::subscription::Subscription;
2use crate::compression::CompressedPayload;
3use bytes::Bytes;
4use dashmap::DashMap;
5use futures_util::stream::SplitSink;
6use futures_util::SinkExt;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::net::TcpStream;
11use tokio::sync::{mpsc, RwLock};
12use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
13use tokio_util::sync::CancellationToken;
14use tracing::{debug, info, warn};
15use uuid::Uuid;
16
17pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
18
/// Reasons a message could not be queued for a client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No client with the requested id is registered.
    ClientNotFound,
    /// The client's outbound queue was full; the client has been removed.
    ClientBackpressured,
    /// The client's outbound queue is closed (its writer task has stopped).
    ClientDisconnected,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match *self {
            Self::ClientNotFound => "client not found",
            Self::ClientBackpressured => "client backpressured and disconnected",
            Self::ClientDisconnected => "client disconnected",
        };
        f.write_str(text)
    }
}

impl std::error::Error for SendError {}
41
42#[derive(Debug)]
44pub struct ClientInfo {
45 pub id: Uuid,
46 pub subscription: Option<Subscription>,
47 pub last_seen: SystemTime,
48 pub sender: mpsc::Sender<Message>,
49 subscriptions: Arc<RwLock<HashMap<String, CancellationToken>>>,
50}
51
52impl ClientInfo {
53 pub fn new(id: Uuid, sender: mpsc::Sender<Message>) -> Self {
54 Self {
55 id,
56 subscription: None,
57 last_seen: SystemTime::now(),
58 sender,
59 subscriptions: Arc::new(RwLock::new(HashMap::new())),
60 }
61 }
62
63 pub fn update_last_seen(&mut self) {
64 self.last_seen = SystemTime::now();
65 }
66
67 pub fn is_stale(&self, timeout: Duration) -> bool {
68 self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
69 }
70
71 pub async fn add_subscription(&self, sub_key: String, token: CancellationToken) -> bool {
72 let mut subs = self.subscriptions.write().await;
73 if let Some(old_token) = subs.insert(sub_key.clone(), token) {
74 old_token.cancel();
75 debug!("Replaced existing subscription: {}", sub_key);
76 false
77 } else {
78 true
79 }
80 }
81
82 pub async fn remove_subscription(&self, sub_key: &str) -> bool {
83 let mut subs = self.subscriptions.write().await;
84 if let Some(token) = subs.remove(sub_key) {
85 token.cancel();
86 debug!("Cancelled subscription: {}", sub_key);
87 true
88 } else {
89 debug!("Subscription not found for cancellation: {}", sub_key);
90 false
91 }
92 }
93
94 pub async fn cancel_all_subscriptions(&self) {
95 let subs = self.subscriptions.read().await;
96 for (sub_key, token) in subs.iter() {
97 token.cancel();
98 debug!("Cancelled subscription on disconnect: {}", sub_key);
99 }
100 }
101
102 pub async fn subscription_count(&self) -> usize {
103 self.subscriptions.read().await.len()
104 }
105}
106
107#[derive(Clone)]
115pub struct ClientManager {
116 clients: Arc<DashMap<Uuid, ClientInfo>>,
117 client_timeout: Duration,
118 message_queue_size: usize,
119}
120
121impl ClientManager {
122 pub fn new() -> Self {
123 Self {
124 clients: Arc::new(DashMap::new()),
125 client_timeout: Duration::from_secs(300),
126 message_queue_size: 512,
127 }
128 }
129
130 pub fn with_timeout(mut self, timeout: Duration) -> Self {
131 self.client_timeout = timeout;
132 self
133 }
134
135 pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
136 self.message_queue_size = queue_size;
137 self
138 }
139
140 pub fn add_client(&self, client_id: Uuid, mut ws_sender: WebSocketSender) {
146 let (client_tx, mut client_rx) = mpsc::channel::<Message>(self.message_queue_size);
147 let client_info = ClientInfo::new(client_id, client_tx);
148
149 let clients_ref = self.clients.clone();
150 tokio::spawn(async move {
151 while let Some(message) = client_rx.recv().await {
152 if let Err(e) = ws_sender.send(message).await {
153 warn!("Failed to send message to client {}: {}", client_id, e);
154 break;
155 }
156 }
157 clients_ref.remove(&client_id);
158 debug!("WebSocket sender task for client {} stopped", client_id);
159 });
160
161 self.clients.insert(client_id, client_info);
162 info!("Client {} registered", client_id);
163 }
164
165 pub fn remove_client(&self, client_id: Uuid) {
167 if self.clients.remove(&client_id).is_some() {
168 info!("Client {} removed", client_id);
169 }
170 }
171
172 pub fn client_count(&self) -> usize {
177 self.clients.len()
178 }
179
180 pub fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<(), SendError> {
189 let sender = {
190 let client = self
191 .clients
192 .get(&client_id)
193 .ok_or(SendError::ClientNotFound)?;
194 client.sender.clone()
195 };
196
197 let msg = Message::Binary((*data).clone());
198 match sender.try_send(msg) {
199 Ok(()) => Ok(()),
200 Err(mpsc::error::TrySendError::Full(_)) => {
201 warn!(
202 "Client {} backpressured (queue full), disconnecting",
203 client_id
204 );
205 self.clients.remove(&client_id);
206 Err(SendError::ClientBackpressured)
207 }
208 Err(mpsc::error::TrySendError::Closed(_)) => {
209 debug!("Client {} channel closed", client_id);
210 self.clients.remove(&client_id);
211 Err(SendError::ClientDisconnected)
212 }
213 }
214 }
215
216 pub async fn send_to_client_async(
225 &self,
226 client_id: Uuid,
227 data: Arc<Bytes>,
228 ) -> Result<(), SendError> {
229 let sender = {
230 let client = self
231 .clients
232 .get(&client_id)
233 .ok_or(SendError::ClientNotFound)?;
234 client.sender.clone()
235 };
236
237 let msg = Message::Binary((*data).clone());
238 sender
239 .send(msg)
240 .await
241 .map_err(|_| SendError::ClientDisconnected)
242 }
243
244 pub async fn send_compressed_async(
249 &self,
250 client_id: Uuid,
251 payload: CompressedPayload,
252 ) -> Result<(), SendError> {
253 let sender = {
254 let client = self
255 .clients
256 .get(&client_id)
257 .ok_or(SendError::ClientNotFound)?;
258 client.sender.clone()
259 };
260
261 let msg = match payload {
262 CompressedPayload::Compressed(bytes) => Message::Binary(bytes),
263 CompressedPayload::Uncompressed(bytes) => Message::Binary(bytes),
264 };
265 sender
266 .send(msg)
267 .await
268 .map_err(|_| SendError::ClientDisconnected)
269 }
270
271 pub fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
273 if let Some(mut client) = self.clients.get_mut(&client_id) {
274 client.subscription = Some(subscription);
275 client.update_last_seen();
276 debug!("Updated subscription for client {}", client_id);
277 true
278 } else {
279 warn!(
280 "Failed to update subscription for unknown client {}",
281 client_id
282 );
283 false
284 }
285 }
286
287 pub fn update_client_last_seen(&self, client_id: Uuid) {
289 if let Some(mut client) = self.clients.get_mut(&client_id) {
290 client.update_last_seen();
291 }
292 }
293
294 pub fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
296 self.clients
297 .get(&client_id)
298 .and_then(|c| c.subscription.clone())
299 }
300
301 pub fn has_client(&self, client_id: Uuid) -> bool {
303 self.clients.contains_key(&client_id)
304 }
305
306 pub async fn add_client_subscription(
307 &self,
308 client_id: Uuid,
309 sub_key: String,
310 token: CancellationToken,
311 ) -> bool {
312 if let Some(client) = self.clients.get(&client_id) {
313 client.add_subscription(sub_key, token).await
314 } else {
315 false
316 }
317 }
318
319 pub async fn remove_client_subscription(&self, client_id: Uuid, sub_key: &str) -> bool {
320 if let Some(client) = self.clients.get(&client_id) {
321 client.remove_subscription(sub_key).await
322 } else {
323 false
324 }
325 }
326
327 pub async fn cancel_all_client_subscriptions(&self, client_id: Uuid) {
328 if let Some(client) = self.clients.get(&client_id) {
329 client.cancel_all_subscriptions().await;
330 }
331 }
332
333 pub fn cleanup_stale_clients(&self) -> usize {
335 let timeout = self.client_timeout;
336 let mut stale_clients = Vec::new();
337
338 for entry in self.clients.iter() {
339 if entry.value().is_stale(timeout) {
340 stale_clients.push(*entry.key());
341 }
342 }
343
344 let removed_count = stale_clients.len();
345 for client_id in stale_clients {
346 self.clients.remove(&client_id);
347 info!("Removed stale client {}", client_id);
348 }
349
350 removed_count
351 }
352
353 pub fn start_cleanup_task(&self) {
355 let client_manager = self.clone();
356
357 tokio::spawn(async move {
358 let mut interval = tokio::time::interval(Duration::from_secs(30));
359
360 loop {
361 interval.tick().await;
362 let removed = client_manager.cleanup_stale_clients();
363 if removed > 0 {
364 info!("Cleaned up {} stale clients", removed);
365 }
366 }
367 });
368 }
369}
370
371impl Default for ClientManager {
372 fn default() -> Self {
373 Self::new()
374 }
375}