hyperstack_server/websocket/client_manager.rs

1use super::subscription::Subscription;
2use anyhow::Result;
3use bytes::Bytes;
4use futures_util::stream::SplitSink;
5use futures_util::SinkExt;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use tokio::net::TcpStream;
10use tokio::sync::{mpsc, RwLock};
11use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
12use tracing::{debug, info, warn};
13use uuid::Uuid;
14
15pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
16
17/// Information about a connected client
18#[derive(Debug, Clone)]
19pub struct ClientInfo {
20    pub id: Uuid,
21    pub subscription: Option<Subscription>,
22    pub last_seen: SystemTime,
23    pub sender: mpsc::Sender<Message>,
24}
25
26impl ClientInfo {
27    pub fn new(id: Uuid, sender: mpsc::Sender<Message>) -> Self {
28        Self {
29            id,
30            subscription: None,
31            last_seen: SystemTime::now(),
32            sender,
33        }
34    }
35
36    pub fn update_last_seen(&mut self) {
37        self.last_seen = SystemTime::now();
38    }
39
40    pub fn is_stale(&self, timeout: Duration) -> bool {
41        self.last_seen.elapsed().unwrap_or(Duration::MAX) > timeout
42    }
43}
44
45/// Manages all connected WebSocket clients
46#[derive(Clone)]
47pub struct ClientManager {
48    clients: Arc<RwLock<HashMap<Uuid, ClientInfo>>>,
49    client_timeout: Duration,
50    message_queue_size: usize,
51}
52
53impl ClientManager {
54    pub fn new() -> Self {
55        Self {
56            clients: Arc::new(RwLock::new(HashMap::new())),
57            client_timeout: Duration::from_secs(300),
58            message_queue_size: 1000,
59        }
60    }
61
62    pub fn with_timeout(mut self, timeout: Duration) -> Self {
63        self.client_timeout = timeout;
64        self
65    }
66
67    pub fn with_message_queue_size(mut self, queue_size: usize) -> Self {
68        self.message_queue_size = queue_size;
69        self
70    }
71
72    /// Add a new client connection
73    pub async fn add_client(&self, client_id: Uuid, mut ws_sender: WebSocketSender) -> Result<()> {
74        let (client_tx, mut client_rx) = mpsc::channel::<Message>(self.message_queue_size);
75
76        let client_info = ClientInfo::new(client_id, client_tx);
77
78        // Spawn task to handle WebSocket sending for this client
79        let clients_ref = self.clients.clone();
80        tokio::spawn(async move {
81            // The client info struct is given client_tx. The below listens to this channel
82            // This gives us clean decoupling and handles backpressure without blocking
83            // We also get natural cleanup of failed clients without putting that complexity into
84            // the ClientInfo struct
85            while let Some(message) = client_rx.recv().await {
86                if let Err(e) = ws_sender.send(message).await {
87                    warn!("Failed to send message to client {}: {}", client_id, e);
88
89                    // Remove failed client
90                    clients_ref.write().await.remove(&client_id);
91                    break;
92                }
93            }
94
95            debug!("WebSocket sender for client {} stopped", client_id);
96        });
97
98        // Register client
99        self.clients.write().await.insert(client_id, client_info);
100        info!("Client {} registered", client_id);
101
102        Ok(())
103    }
104
105    pub async fn remove_client(&self, client_id: Uuid) {
106        if self.clients.write().await.remove(&client_id).is_some() {
107            info!("Client {} removed", client_id);
108        }
109    }
110
111    pub async fn client_count(&self) -> usize {
112        self.clients.read().await.len()
113    }
114
115    pub async fn send_to_client(&self, client_id: Uuid, data: Arc<Bytes>) -> Result<()> {
116        let clients = self.clients.read().await;
117        if let Some(client) = clients.get(&client_id) {
118            let msg = Message::Binary((*data).clone());
119            client.sender.send(msg).await?;
120        }
121        Ok(())
122    }
123
124    pub async fn update_subscription(&self, client_id: Uuid, subscription: Subscription) -> bool {
125        if let Some(client_info) = self.clients.write().await.get_mut(&client_id) {
126            client_info.subscription = Some(subscription);
127            client_info.update_last_seen();
128            debug!("Updated subscription for client {}", client_id);
129            true
130        } else {
131            warn!(
132                "Failed to update subscription for unknown client {}",
133                client_id
134            );
135            false
136        }
137    }
138
139    pub async fn update_client_last_seen(&self, client_id: Uuid) {
140        if let Some(client_info) = self.clients.write().await.get_mut(&client_id) {
141            client_info.update_last_seen();
142        }
143    }
144
145    pub async fn get_subscription(&self, client_id: Uuid) -> Option<Subscription> {
146        let clients = self.clients.read().await;
147        clients.get(&client_id).and_then(|c| c.subscription.clone())
148    }
149
150    pub async fn cleanup_stale_clients(&self) -> usize {
151        let mut clients = self.clients.write().await;
152        let mut stale_clients = Vec::new();
153
154        for (client_id, client_info) in clients.iter() {
155            if client_info.is_stale(self.client_timeout) {
156                stale_clients.push(*client_id);
157            }
158        }
159
160        let removed_count = stale_clients.len();
161        for client_id in stale_clients {
162            clients.remove(&client_id);
163            info!("Removed stale client {}", client_id);
164        }
165
166        removed_count
167    }
168
169    pub async fn start_cleanup_task(&self) {
170        let client_manager = self.clone();
171
172        tokio::spawn(async move {
173            let mut interval = tokio::time::interval(Duration::from_secs(30));
174
175            loop {
176                interval.tick().await;
177                let removed = client_manager.cleanup_stale_clients().await;
178                if removed > 0 {
179                    info!("Cleaned up {} stale clients", removed);
180                }
181            }
182        });
183    }
184}
185
186impl Default for ClientManager {
187    fn default() -> Self {
188        Self::new()
189    }
190}