hyperstack_server/websocket/
client_manager.rs1use super::subscription::Subscription;
2use anyhow::Result;
3use bytes::Bytes;
4use futures_util::stream::SplitSink;
5use futures_util::SinkExt;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use tokio::net::TcpStream;
10use tokio::sync::{mpsc, RwLock};
11use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
12use tracing::{debug, info, warn};
13use uuid::Uuid;
14
/// Write half of a WebSocket connection running over a TCP stream, as produced
/// by splitting a `WebSocketStream<TcpStream>` into sink and stream halves.
pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
16
/// Per-client state held in the [`ClientManager`] registry.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    /// Identifier of this client; also its key in the manager's client map.
    pub id: Uuid,
    /// Current subscription; `None` until `ClientManager::update_subscription` sets it.
    pub subscription: Option<Subscription>,
    /// Wall-clock time of the last recorded activity, used for staleness checks.
    pub last_seen: SystemTime,
    /// Outbound queue drained by this client's WebSocket forwarding task.
    pub sender: mpsc::Sender<Message>,
}
25
26impl ClientInfo {
27 pub fn new(id: Uuid, sender: mpsc::Sender<Message>) -> Self {
28 Self {
29 id,
30 subscription: None,
31 last_seen: SystemTime::now(),
32 sender,
33 }
34 }
35
36 pub fn update_last_seen(&mut self) {
37 self.last_seen = SystemTime::now();
38 }
39
40 pub fn is_stale(&self, timeout: Duration) -> bool {
41 self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
42 }
43}
44
/// Registry of connected WebSocket clients.
///
/// Cloning is cheap: every clone shares the same `Arc`-wrapped client map.
#[derive(Clone)]
pub struct ClientManager {
    /// All registered clients, keyed by client id.
    clients: Arc<RwLock<HashMap<Uuid, ClientInfo>>>,
    /// Inactivity period after which a client counts as stale.
    client_timeout: Duration,
    /// Capacity of each client's outbound message channel.
    message_queue_size: usize,
}
52
53impl ClientManager {
54 pub fn new() -> Self {
55 Self {
56 clients: Arc::new(RwLock::new(HashMap::new())),
57 client_timeout: Duration::from_secs(300),
58 message_queue_size: 1000,
59 }
60 }
61
62 pub fn with_timeout(mut self, timeout: Duration) -> Self {
63 self.client_timeout = timeout;
64 self
65 }
66
67 pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
68 self.message_queue_size = queue_size;
69 self
70 }
71
72 pub async fn add_client(&self, client_id: Uuid, mut ws_sender: WebSocketSender) -> Result<()> {
74 let (client_tx, mut client_rx) = mpsc::channel::<Message>(self.message_queue_size);
75
76 let client_info = ClientInfo::new(client_id, client_tx);
77
78 let clients_ref = self.clients.clone();
80 tokio::spawn(async move {
81 while let Some(message) = client_rx.recv().await {
86 if let Err(e) = ws_sender.send(message).await {
87 warn!("Failed to send message to client {}: {}", client_id, e);
88
89 clients_ref.write().await.remove(&client_id);
91 break;
92 }
93 }
94
95 debug!("WebSocket sender for client {} stopped", client_id);
96 });
97
98 self.clients.write().await.insert(client_id, client_info);
100 info!("Client {} registered", client_id);
101
102 Ok(())
103 }
104
105 pub async fn remove_client(&self, client_id: Uuid) {
106 if self.clients.write().await.remove(&client_id).is_some() {
107 info!("Client {} removed", client_id);
108 }
109 }
110
111 pub async fn client_count(&self) -> usize {
112 self.clients.read().await.len()
113 }
114
115 pub async fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<()> {
116 let clients = self.clients.read().await;
117 if let Some(client) = clients.get(&client_id) {
118 let msg = Message::Binary((*data).clone());
119 client.sender.send(msg).await?;
120 }
121 Ok(())
122 }
123
124 pub async fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
125 if let Some(client_info) = self.clients.write().await.get_mut(&client_id) {
126 client_info.subscription = Some(subscription);
127 client_info.update_last_seen();
128 debug!("Updated subscription for client {}", client_id);
129 true
130 } else {
131 warn!(
132 "Failed to update subscription for unknown client {}",
133 client_id
134 );
135 false
136 }
137 }
138
139 pub async fn update_client_last_seen(&self, client_id: Uuid) {
140 if let Some(client_info) = self.clients.write().await.get_mut(&client_id) {
141 client_info.update_last_seen();
142 }
143 }
144
145 pub async fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
146 let clients = self.clients.read().await;
147 clients.get(&client_id).and_then(|c| c.subscription.clone())
148 }
149
150 pub async fn cleanup_stale_clients(&self) -> usize {
151 let mut clients = self.clients.write().await;
152 let mut stale_clients = Vec::new();
153
154 for (client_id, client_info) in clients.iter() {
155 if client_info.is_stale(self.client_timeout) {
156 stale_clients.push(*client_id);
157 }
158 }
159
160 let removed_count = stale_clients.len();
161 for client_id in stale_clients {
162 clients.remove(&client_id);
163 info!("Removed stale client {}", client_id);
164 }
165
166 removed_count
167 }
168
169 pub async fn start_cleanup_task(&self) {
170 let client_manager = self.clone();
171
172 tokio::spawn(async move {
173 let mut interval = tokio::time::interval(Duration::from_secs(30));
174
175 loop {
176 interval.tick().await;
177 let removed = client_manager.cleanup_stale_clients().await;
178 if removed > 0 {
179 info!("Cleaned up {} stale clients", removed);
180 }
181 }
182 });
183 }
184}
185
186impl Default for ClientManager {
187 fn default() -> Self {
188 Self::new()
189 }
190}