// hyperstack_server/websocket/client_manager.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};

use bytes::Bytes;
use dashmap::DashMap;
use futures_util::stream::SplitSink;
use futures_util::SinkExt;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, RwLock};
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use uuid::Uuid;

use super::subscription::Subscription;

16pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
17
/// Reasons a message could not be queued for a client.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// No client is registered under the requested id.
    ClientNotFound,
    /// The client's outbound queue was full; the client has been removed.
    ClientBackpressured,
    /// The client's outbound queue is closed.
    ClientDisconnected,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::ClientNotFound => "client not found",
            Self::ClientBackpressured => "client backpressured and disconnected",
            Self::ClientDisconnected => "client disconnected",
        };
        f.write_str(text)
    }
}

impl std::error::Error for SendError {}

41#[derive(Debug)]
43pub struct ClientInfo {
44 pub id: Uuid,
45 pub subscription: Option<Subscription>,
46 pub last_seen: SystemTime,
47 pub sender: mpsc::Sender<Message>,
48 subscriptions: Arc<RwLock<HashMap<String, CancellationToken>>>,
49}
50
51impl ClientInfo {
52 pub fn new(id: Uuid, sender: mpsc::Sender<Message>) -> Self {
53 Self {
54 id,
55 subscription: None,
56 last_seen: SystemTime::now(),
57 sender,
58 subscriptions: Arc::new(RwLock::new(HashMap::new())),
59 }
60 }
61
62 pub fn update_last_seen(&mut self) {
63 self.last_seen = SystemTime::now();
64 }
65
66 pub fn is_stale(&self, timeout: Duration) -> bool {
67 self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
68 }
69
70 pub async fn add_subscription(&self, sub_key: String, token: CancellationToken) {
71 let mut subs = self.subscriptions.write().await;
72 if let Some(old_token) = subs.insert(sub_key.clone(), token) {
73 old_token.cancel();
74 debug!("Replaced existing subscription: {}", sub_key);
75 }
76 }
77
78 pub async fn remove_subscription(&self, sub_key: &str) -> bool {
79 let mut subs = self.subscriptions.write().await;
80 if let Some(token) = subs.remove(sub_key) {
81 token.cancel();
82 debug!("Cancelled subscription: {}", sub_key);
83 true
84 } else {
85 debug!("Subscription not found for cancellation: {}", sub_key);
86 false
87 }
88 }
89
90 pub async fn cancel_all_subscriptions(&self) {
91 let subs = self.subscriptions.read().await;
92 for (sub_key, token) in subs.iter() {
93 token.cancel();
94 debug!("Cancelled subscription on disconnect: {}", sub_key);
95 }
96 }
97
98 pub async fn subscription_count(&self) -> usize {
99 self.subscriptions.read().await.len()
100 }
101}
102
103#[derive(Clone)]
111pub struct ClientManager {
112 clients: Arc<DashMap<Uuid, ClientInfo>>,
113 client_timeout: Duration,
114 message_queue_size: usize,
115}
116
117impl ClientManager {
118 pub fn new() -> Self {
119 Self {
120 clients: Arc::new(DashMap::new()),
121 client_timeout: Duration::from_secs(300),
122 message_queue_size: 512,
123 }
124 }
125
126 pub fn with_timeout(mut self, timeout: Duration) -> Self {
127 self.client_timeout = timeout;
128 self
129 }
130
131 pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
132 self.message_queue_size = queue_size;
133 self
134 }
135
136 pub fn add_client(&self, client_id: Uuid, mut ws_sender: WebSocketSender) {
142 let (client_tx, mut client_rx) = mpsc::channel::<Message>(self.message_queue_size);
143 let client_info = ClientInfo::new(client_id, client_tx);
144
145 let clients_ref = self.clients.clone();
146 tokio::spawn(async move {
147 while let Some(message) = client_rx.recv().await {
148 if let Err(e) = ws_sender.send(message).await {
149 warn!("Failed to send message to client {}: {}", client_id, e);
150 break;
151 }
152 }
153 clients_ref.remove(&client_id);
154 debug!("WebSocket sender task for client {} stopped", client_id);
155 });
156
157 self.clients.insert(client_id, client_info);
158 info!("Client {} registered", client_id);
159 }
160
161 pub fn remove_client(&self, client_id: Uuid) {
163 if self.clients.remove(&client_id).is_some() {
164 info!("Client {} removed", client_id);
165 }
166 }
167
168 pub fn client_count(&self) -> usize {
173 self.clients.len()
174 }
175
176 pub fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<(), SendError> {
185 let sender = {
186 let client = self
187 .clients
188 .get(&client_id)
189 .ok_or(SendError::ClientNotFound)?;
190 client.sender.clone()
191 };
192
193 let msg = Message::Binary((*data).clone());
194 match sender.try_send(msg) {
195 Ok(()) => Ok(()),
196 Err(mpsc::error::TrySendError::Full(_)) => {
197 warn!(
198 "Client {} backpressured (queue full), disconnecting",
199 client_id
200 );
201 self.clients.remove(&client_id);
202 Err(SendError::ClientBackpressured)
203 }
204 Err(mpsc::error::TrySendError::Closed(_)) => {
205 debug!("Client {} channel closed", client_id);
206 self.clients.remove(&client_id);
207 Err(SendError::ClientDisconnected)
208 }
209 }
210 }
211
212 pub async fn send_to_client_async(
221 &self,
222 client_id: Uuid,
223 data: Arc<Bytes>,
224 ) -> Result<(), SendError> {
225 let sender = {
226 let client = self
227 .clients
228 .get(&client_id)
229 .ok_or(SendError::ClientNotFound)?;
230 client.sender.clone()
231 };
232
233 let msg = Message::Binary((*data).clone());
234 sender
235 .send(msg)
236 .await
237 .map_err(|_| SendError::ClientDisconnected)
238 }
239
240 pub fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
242 if let Some(mut client) = self.clients.get_mut(&client_id) {
243 client.subscription = Some(subscription);
244 client.update_last_seen();
245 debug!("Updated subscription for client {}", client_id);
246 true
247 } else {
248 warn!(
249 "Failed to update subscription for unknown client {}",
250 client_id
251 );
252 false
253 }
254 }
255
256 pub fn update_client_last_seen(&self, client_id: Uuid) {
258 if let Some(mut client) = self.clients.get_mut(&client_id) {
259 client.update_last_seen();
260 }
261 }
262
263 pub fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
265 self.clients
266 .get(&client_id)
267 .and_then(|c| c.subscription.clone())
268 }
269
270 pub fn has_client(&self, client_id: Uuid) -> bool {
272 self.clients.contains_key(&client_id)
273 }
274
275 pub async fn add_client_subscription(
276 &self,
277 client_id: Uuid,
278 sub_key: String,
279 token: CancellationToken,
280 ) {
281 if let Some(client) = self.clients.get(&client_id) {
282 client.add_subscription(sub_key, token).await;
283 }
284 }
285
286 pub async fn remove_client_subscription(&self, client_id: Uuid, sub_key: &str) -> bool {
287 if let Some(client) = self.clients.get(&client_id) {
288 client.remove_subscription(sub_key).await
289 } else {
290 false
291 }
292 }
293
294 pub async fn cancel_all_client_subscriptions(&self, client_id: Uuid) {
295 if let Some(client) = self.clients.get(&client_id) {
296 client.cancel_all_subscriptions().await;
297 }
298 }
299
300 pub fn cleanup_stale_clients(&self) -> usize {
302 let timeout = self.client_timeout;
303 let mut stale_clients = Vec::new();
304
305 for entry in self.clients.iter() {
306 if entry.value().is_stale(timeout) {
307 stale_clients.push(*entry.key());
308 }
309 }
310
311 let removed_count = stale_clients.len();
312 for client_id in stale_clients {
313 self.clients.remove(&client_id);
314 info!("Removed stale client {}", client_id);
315 }
316
317 removed_count
318 }
319
320 pub fn start_cleanup_task(&self) {
322 let client_manager = self.clone();
323
324 tokio::spawn(async move {
325 let mut interval = tokio::time::interval(Duration::from_secs(30));
326
327 loop {
328 interval.tick().await;
329 let removed = client_manager.cleanup_stale_clients();
330 if removed > 0 {
331 info!("Cleaned up {} stale clients", removed);
332 }
333 }
334 });
335 }
336}
337
338impl Default for ClientManager {
339 fn default() -> Self {
340 Self::new()
341 }
342}