1use super::channel::{ChannelId, ChannelManager};
4use super::connection::WebSocketConnection;
5use super::types::{ConnectionId, ConnectionState, WebSocketMessage, WebSocketResult};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11#[derive(Debug, Clone)]
13pub enum ConnectionEvent {
14 Connected(ConnectionId),
16 Disconnected(ConnectionId, ConnectionState),
18 Broadcast(WebSocketMessage),
20 MessageSent(ConnectionId, WebSocketMessage),
22}
23
24pub struct ConnectionRegistry {
26 connections: Arc<RwLock<HashMap<ConnectionId, Arc<WebSocketConnection>>>>,
28 channel_manager: Arc<ChannelManager>,
30 event_handlers: Arc<RwLock<Vec<Box<dyn Fn(ConnectionEvent) + Send + Sync>>>>,
32}
33
34impl ConnectionRegistry {
35 pub fn new() -> Self {
37 Self {
38 connections: Arc::new(RwLock::new(HashMap::new())),
39 channel_manager: Arc::new(ChannelManager::new()),
40 event_handlers: Arc::new(RwLock::new(Vec::new())),
41 }
42 }
43
44 pub fn with_channel_manager(channel_manager: Arc<ChannelManager>) -> Self {
46 Self {
47 connections: Arc::new(RwLock::new(HashMap::new())),
48 channel_manager,
49 event_handlers: Arc::new(RwLock::new(Vec::new())),
50 }
51 }
52
53 pub fn channel_manager(&self) -> &Arc<ChannelManager> {
55 &self.channel_manager
56 }
57
58 pub async fn add_connection(&self, connection: WebSocketConnection) -> ConnectionId {
60 let id = connection.id;
61 let arc_connection = Arc::new(connection);
62
63 {
64 let mut connections = self.connections.write().await;
65 connections.insert(id, arc_connection);
66 }
67
68 info!("Added connection to registry: {}", id);
69 self.emit_event(ConnectionEvent::Connected(id)).await;
70
71 id
72 }
73
74 pub async fn remove_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
76 let connection = {
77 let mut connections = self.connections.write().await;
78 connections.remove(&id)
79 };
80
81 if let Some(conn) = &connection {
82 let state = conn.state().await;
83
84 self.channel_manager.leave_all_channels(id).await;
86
87 info!(
88 "Removed connection from registry: {} (state: {:?})",
89 id, state
90 );
91 self.emit_event(ConnectionEvent::Disconnected(id, state))
92 .await;
93 }
94
95 connection
96 }
97
98 pub async fn get_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
100 let connections = self.connections.read().await;
101 connections.get(&id).cloned()
102 }
103
104 pub async fn get_all_connections(&self) -> Vec<Arc<WebSocketConnection>> {
106 let connections = self.connections.read().await;
107 connections.values().cloned().collect()
108 }
109
110 pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
112 let connections = self.connections.read().await;
113 connections.keys().copied().collect()
114 }
115
116 pub async fn connection_count(&self) -> usize {
118 let connections = self.connections.read().await;
119 connections.len()
120 }
121
122 pub async fn send_to_connection(
124 &self,
125 id: ConnectionId,
126 message: WebSocketMessage,
127 ) -> WebSocketResult<()> {
128 let connection = self
129 .get_connection(id)
130 .await
131 .ok_or(WebSocketError::ConnectionNotFound(id))?;
132
133 let result = connection.send(message.clone()).await;
134
135 if result.is_ok() {
136 self.emit_event(ConnectionEvent::MessageSent(id, message))
137 .await;
138 }
139
140 result
141 }
142
143 pub async fn send_text_to_connection<T: Into<String>>(
145 &self,
146 id: ConnectionId,
147 text: T,
148 ) -> WebSocketResult<()> {
149 self.send_to_connection(id, WebSocketMessage::text(text))
150 .await
151 }
152
153 pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
155 &self,
156 id: ConnectionId,
157 data: T,
158 ) -> WebSocketResult<()> {
159 self.send_to_connection(id, WebSocketMessage::binary(data))
160 .await
161 }
162
163 pub async fn broadcast(&self, message: WebSocketMessage) -> BroadcastResult {
165 let connections = self.get_all_connections().await;
166 let mut results = BroadcastResult::new();
167
168 for connection in connections {
169 if connection.is_active().await {
170 match connection.send(message.clone()).await {
171 Ok(_) => results.success_count += 1,
172 Err(e) => {
173 results.failed_connections.push((connection.id, e));
174 }
175 }
176 } else {
177 results.inactive_connections.push(connection.id);
178 }
179 }
180
181 self.emit_event(ConnectionEvent::Broadcast(message)).await;
182 results
183 }
184
185 pub async fn broadcast_text<T: Into<String>>(&self, text: T) -> BroadcastResult {
187 self.broadcast(WebSocketMessage::text(text)).await
188 }
189
190 pub async fn broadcast_binary<T: Into<Vec<u8>>>(&self, data: T) -> BroadcastResult {
192 self.broadcast(WebSocketMessage::binary(data)).await
193 }
194
195 pub async fn send_to_channel(
197 &self,
198 channel_id: ChannelId,
199 sender_id: ConnectionId,
200 message: WebSocketMessage,
201 ) -> WebSocketResult<BroadcastResult> {
202 let member_ids = self
204 .channel_manager
205 .send_to_channel(channel_id, sender_id, message.clone())
206 .await?;
207
208 let mut results = BroadcastResult::new();
210
211 for member_id in member_ids {
212 if let Some(connection) = self.get_connection(member_id).await {
213 if connection.is_active().await {
214 match connection.send(message.clone()).await {
215 Ok(_) => results.success_count += 1,
216 Err(e) => {
217 results.failed_connections.push((member_id, e));
218 }
219 }
220 } else {
221 results.inactive_connections.push(member_id);
222 }
223 } else {
224 let _ = self
226 .channel_manager
227 .leave_channel(channel_id, member_id)
228 .await;
229 }
230 }
231
232 Ok(results)
233 }
234
235 pub async fn send_text_to_channel<T: Into<String>>(
237 &self,
238 channel_id: ChannelId,
239 sender_id: ConnectionId,
240 text: T,
241 ) -> WebSocketResult<BroadcastResult> {
242 self.send_to_channel(channel_id, sender_id, WebSocketMessage::text(text))
243 .await
244 }
245
246 pub async fn send_binary_to_channel<T: Into<Vec<u8>>>(
248 &self,
249 channel_id: ChannelId,
250 sender_id: ConnectionId,
251 data: T,
252 ) -> WebSocketResult<BroadcastResult> {
253 self.send_to_channel(channel_id, sender_id, WebSocketMessage::binary(data))
254 .await
255 }
256
257 pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
259 let connection = self
260 .get_connection(id)
261 .await
262 .ok_or(WebSocketError::ConnectionNotFound(id))?;
263
264 connection.close().await?;
265 self.remove_connection(id).await;
266
267 Ok(())
268 }
269
270 pub async fn close_all_connections(&self) -> CloseAllResult {
272 let connections = self.get_all_connections().await;
273 let mut results = CloseAllResult::new();
274 let mut to_remove = Vec::new();
275
276 for connection in connections {
277 match connection.close().await {
278 Ok(_) => {
279 to_remove.push(connection.id);
280 results.closed_count += 1;
281 }
282 Err(e) => {
283 results.failed_connections.push((connection.id, e));
284 }
285 }
286 }
287
288 if !to_remove.is_empty() {
290 let mut connections = self.connections.write().await;
291 for id in to_remove {
292 if let Some(conn) = connections.remove(&id) {
293 let state = conn.state().await;
294 info!(
295 "Removed connection from registry: {} (state: {:?})",
296 id, state
297 );
298 }
301 }
302 }
303
304 results
305 }
306
307 pub async fn cleanup_inactive_connections(&self) -> usize {
309 let connections = self.get_all_connections().await;
310 let mut to_remove = Vec::new();
311
312 for connection in connections {
314 if connection.is_closed().await {
315 to_remove.push((connection.id, connection));
316 }
317 }
318
319 let cleaned_up = to_remove.len();
320
321 if !to_remove.is_empty() {
323 let mut registry_connections = self.connections.write().await;
324 for (id, _connection) in to_remove {
325 if registry_connections.remove(&id).is_some() {
326 debug!("Cleaned up inactive connection: {}", id);
327 }
330 }
331 }
332
333 if cleaned_up > 0 {
334 info!("Cleaned up {} inactive connections", cleaned_up);
335 }
336
337 cleaned_up
338 }
339
340 pub async fn stats(&self) -> RegistryStats {
342 let connections = self.get_all_connections().await;
343 let mut stats = RegistryStats::default();
344
345 stats.total_connections = connections.len();
346
347 for connection in connections {
348 match connection.state().await {
349 ConnectionState::Connected => stats.active_connections += 1,
350 ConnectionState::Connecting => stats.connecting_connections += 1,
351 ConnectionState::Closing => stats.closing_connections += 1,
352 ConnectionState::Closed => stats.closed_connections += 1,
353 ConnectionState::Failed(_) => stats.failed_connections += 1,
354 }
355
356 let conn_stats = connection.stats().await;
357 stats.total_messages_sent += conn_stats.messages_sent;
358 stats.total_messages_received += conn_stats.messages_received;
359 stats.total_bytes_sent += conn_stats.bytes_sent;
360 stats.total_bytes_received += conn_stats.bytes_received;
361 }
362
363 stats
364 }
365
366 pub async fn add_event_handler<F>(&self, handler: F)
368 where
369 F: Fn(ConnectionEvent) + Send + Sync + 'static,
370 {
371 let mut handlers = self.event_handlers.write().await;
372 handlers.push(Box::new(handler));
373 }
374
375 async fn emit_event(&self, event: ConnectionEvent) {
377 let handlers = self.event_handlers.read().await;
378 for handler in handlers.iter() {
379 handler(event.clone());
380 }
381 }
382}
383
384impl Default for ConnectionRegistry {
385 fn default() -> Self {
386 Self::new()
387 }
388}
389
390#[derive(Debug)]
392pub struct BroadcastResult {
393 pub success_count: usize,
394 pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
395 pub inactive_connections: Vec<ConnectionId>,
396}
397
398impl BroadcastResult {
399 fn new() -> Self {
400 Self {
401 success_count: 0,
402 failed_connections: Vec::new(),
403 inactive_connections: Vec::new(),
404 }
405 }
406
407 pub fn total_attempted(&self) -> usize {
408 self.success_count + self.failed_connections.len() + self.inactive_connections.len()
409 }
410
411 pub fn has_failures(&self) -> bool {
412 !self.failed_connections.is_empty()
413 }
414}
415
416#[derive(Debug)]
418pub struct CloseAllResult {
419 pub closed_count: usize,
420 pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
421}
422
423impl CloseAllResult {
424 fn new() -> Self {
425 Self {
426 closed_count: 0,
427 failed_connections: Vec::new(),
428 }
429 }
430}
431
/// Point-in-time aggregate over all registered connections.
///
/// The per-state counters partition `total_connections`. The traffic totals
/// are sums of each connection's own counters.
#[derive(Default, Debug)]
pub struct RegistryStats {
    /// Number of connections in the registry when the snapshot was taken.
    pub total_connections: usize,
    /// Connections in the `Connected` state.
    pub active_connections: usize,
    /// Connections in the `Connecting` state.
    pub connecting_connections: usize,
    /// Connections in the `Closing` state.
    pub closing_connections: usize,
    /// Connections in the `Closed` state.
    pub closed_connections: usize,
    /// Connections in a `Failed(_)` state.
    pub failed_connections: usize,
    /// Sum of messages sent across all connections.
    pub total_messages_sent: u64,
    /// Sum of messages received across all connections.
    pub total_messages_received: u64,
    /// Sum of bytes sent across all connections.
    pub total_bytes_sent: u64,
    /// Sum of bytes received across all connections.
    pub total_bytes_received: u64,
}
446
447use super::types::WebSocketError;