1use super::connection::WebSocketConnection;
4use super::types::{ConnectionId, ConnectionState, WebSocketMessage, WebSocketResult};
5use super::channel::{ChannelManager, ChannelId};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11#[derive(Debug, Clone)]
13pub enum ConnectionEvent {
14 Connected(ConnectionId),
16 Disconnected(ConnectionId, ConnectionState),
18 Broadcast(WebSocketMessage),
20 MessageSent(ConnectionId, WebSocketMessage),
22}
23
24pub struct ConnectionRegistry {
26 connections: Arc<RwLock<HashMap<ConnectionId, Arc<WebSocketConnection>>>>,
28 channel_manager: Arc<ChannelManager>,
30 event_handlers: Arc<RwLock<Vec<Box<dyn Fn(ConnectionEvent) + Send + Sync>>>>,
32}
33
34impl ConnectionRegistry {
35 pub fn new() -> Self {
37 Self {
38 connections: Arc::new(RwLock::new(HashMap::new())),
39 channel_manager: Arc::new(ChannelManager::new()),
40 event_handlers: Arc::new(RwLock::new(Vec::new())),
41 }
42 }
43
44 pub fn with_channel_manager(channel_manager: Arc<ChannelManager>) -> Self {
46 Self {
47 connections: Arc::new(RwLock::new(HashMap::new())),
48 channel_manager,
49 event_handlers: Arc::new(RwLock::new(Vec::new())),
50 }
51 }
52
53 pub fn channel_manager(&self) -> &Arc<ChannelManager> {
55 &self.channel_manager
56 }
57
58 pub async fn add_connection(&self, connection: WebSocketConnection) -> ConnectionId {
60 let id = connection.id;
61 let arc_connection = Arc::new(connection);
62
63 {
64 let mut connections = self.connections.write().await;
65 connections.insert(id, arc_connection);
66 }
67
68 info!("Added connection to registry: {}", id);
69 self.emit_event(ConnectionEvent::Connected(id)).await;
70
71 id
72 }
73
74 pub async fn remove_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
76 let connection = {
77 let mut connections = self.connections.write().await;
78 connections.remove(&id)
79 };
80
81 if let Some(conn) = &connection {
82 let state = conn.state().await;
83
84 self.channel_manager.leave_all_channels(id).await;
86
87 info!("Removed connection from registry: {} (state: {:?})", id, state);
88 self.emit_event(ConnectionEvent::Disconnected(id, state)).await;
89 }
90
91 connection
92 }
93
94 pub async fn get_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
96 let connections = self.connections.read().await;
97 connections.get(&id).cloned()
98 }
99
100 pub async fn get_all_connections(&self) -> Vec<Arc<WebSocketConnection>> {
102 let connections = self.connections.read().await;
103 connections.values().cloned().collect()
104 }
105
106 pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
108 let connections = self.connections.read().await;
109 connections.keys().copied().collect()
110 }
111
112 pub async fn connection_count(&self) -> usize {
114 let connections = self.connections.read().await;
115 connections.len()
116 }
117
118 pub async fn send_to_connection(
120 &self,
121 id: ConnectionId,
122 message: WebSocketMessage,
123 ) -> WebSocketResult<()> {
124 let connection = self.get_connection(id).await
125 .ok_or(WebSocketError::ConnectionNotFound(id))?;
126
127 let result = connection.send(message.clone()).await;
128
129 if result.is_ok() {
130 self.emit_event(ConnectionEvent::MessageSent(id, message)).await;
131 }
132
133 result
134 }
135
136 pub async fn send_text_to_connection<T: Into<String>>(
138 &self,
139 id: ConnectionId,
140 text: T,
141 ) -> WebSocketResult<()> {
142 self.send_to_connection(id, WebSocketMessage::text(text)).await
143 }
144
145 pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
147 &self,
148 id: ConnectionId,
149 data: T,
150 ) -> WebSocketResult<()> {
151 self.send_to_connection(id, WebSocketMessage::binary(data)).await
152 }
153
154 pub async fn broadcast(&self, message: WebSocketMessage) -> BroadcastResult {
156 let connections = self.get_all_connections().await;
157 let mut results = BroadcastResult::new();
158
159 for connection in connections {
160 if connection.is_active().await {
161 match connection.send(message.clone()).await {
162 Ok(_) => results.success_count += 1,
163 Err(e) => {
164 results.failed_connections.push((connection.id, e));
165 }
166 }
167 } else {
168 results.inactive_connections.push(connection.id);
169 }
170 }
171
172 self.emit_event(ConnectionEvent::Broadcast(message)).await;
173 results
174 }
175
176 pub async fn broadcast_text<T: Into<String>>(&self, text: T) -> BroadcastResult {
178 self.broadcast(WebSocketMessage::text(text)).await
179 }
180
181 pub async fn broadcast_binary<T: Into<Vec<u8>>>(&self, data: T) -> BroadcastResult {
183 self.broadcast(WebSocketMessage::binary(data)).await
184 }
185
186 pub async fn send_to_channel(
188 &self,
189 channel_id: ChannelId,
190 sender_id: ConnectionId,
191 message: WebSocketMessage,
192 ) -> WebSocketResult<BroadcastResult> {
193 let member_ids = self.channel_manager
195 .send_to_channel(channel_id, sender_id, message.clone())
196 .await?;
197
198 let mut results = BroadcastResult::new();
200
201 for member_id in member_ids {
202 if let Some(connection) = self.get_connection(member_id).await {
203 if connection.is_active().await {
204 match connection.send(message.clone()).await {
205 Ok(_) => results.success_count += 1,
206 Err(e) => {
207 results.failed_connections.push((member_id, e));
208 }
209 }
210 } else {
211 results.inactive_connections.push(member_id);
212 }
213 } else {
214 let _ = self.channel_manager.leave_channel(channel_id, member_id).await;
216 }
217 }
218
219 Ok(results)
220 }
221
222 pub async fn send_text_to_channel<T: Into<String>>(
224 &self,
225 channel_id: ChannelId,
226 sender_id: ConnectionId,
227 text: T,
228 ) -> WebSocketResult<BroadcastResult> {
229 self.send_to_channel(channel_id, sender_id, WebSocketMessage::text(text)).await
230 }
231
232 pub async fn send_binary_to_channel<T: Into<Vec<u8>>>(
234 &self,
235 channel_id: ChannelId,
236 sender_id: ConnectionId,
237 data: T,
238 ) -> WebSocketResult<BroadcastResult> {
239 self.send_to_channel(channel_id, sender_id, WebSocketMessage::binary(data)).await
240 }
241
242 pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
244 let connection = self.get_connection(id).await
245 .ok_or(WebSocketError::ConnectionNotFound(id))?;
246
247 connection.close().await?;
248 self.remove_connection(id).await;
249
250 Ok(())
251 }
252
253 pub async fn close_all_connections(&self) -> CloseAllResult {
255 let connections = self.get_all_connections().await;
256 let mut results = CloseAllResult::new();
257 let mut to_remove = Vec::new();
258
259 for connection in connections {
260 match connection.close().await {
261 Ok(_) => {
262 to_remove.push(connection.id);
263 results.closed_count += 1;
264 }
265 Err(e) => {
266 results.failed_connections.push((connection.id, e));
267 }
268 }
269 }
270
271 if !to_remove.is_empty() {
273 let mut connections = self.connections.write().await;
274 for id in to_remove {
275 if let Some(conn) = connections.remove(&id) {
276 let state = conn.state().await;
277 info!("Removed connection from registry: {} (state: {:?})", id, state);
278 }
281 }
282 }
283
284 results
285 }
286
287 pub async fn cleanup_inactive_connections(&self) -> usize {
289 let connections = self.get_all_connections().await;
290 let mut to_remove = Vec::new();
291
292 for connection in connections {
294 if connection.is_closed().await {
295 to_remove.push((connection.id, connection));
296 }
297 }
298
299 let cleaned_up = to_remove.len();
300
301 if !to_remove.is_empty() {
303 let mut registry_connections = self.connections.write().await;
304 for (id, _connection) in to_remove {
305 if registry_connections.remove(&id).is_some() {
306 debug!("Cleaned up inactive connection: {}", id);
307 }
310 }
311 }
312
313 if cleaned_up > 0 {
314 info!("Cleaned up {} inactive connections", cleaned_up);
315 }
316
317 cleaned_up
318 }
319
320 pub async fn stats(&self) -> RegistryStats {
322 let connections = self.get_all_connections().await;
323 let mut stats = RegistryStats::default();
324
325 stats.total_connections = connections.len();
326
327 for connection in connections {
328 match connection.state().await {
329 ConnectionState::Connected => stats.active_connections += 1,
330 ConnectionState::Connecting => stats.connecting_connections += 1,
331 ConnectionState::Closing => stats.closing_connections += 1,
332 ConnectionState::Closed => stats.closed_connections += 1,
333 ConnectionState::Failed(_) => stats.failed_connections += 1,
334 }
335
336 let conn_stats = connection.stats().await;
337 stats.total_messages_sent += conn_stats.messages_sent;
338 stats.total_messages_received += conn_stats.messages_received;
339 stats.total_bytes_sent += conn_stats.bytes_sent;
340 stats.total_bytes_received += conn_stats.bytes_received;
341 }
342
343 stats
344 }
345
346 pub async fn add_event_handler<F>(&self, handler: F)
348 where
349 F: Fn(ConnectionEvent) + Send + Sync + 'static,
350 {
351 let mut handlers = self.event_handlers.write().await;
352 handlers.push(Box::new(handler));
353 }
354
355 async fn emit_event(&self, event: ConnectionEvent) {
357 let handlers = self.event_handlers.read().await;
358 for handler in handlers.iter() {
359 handler(event.clone());
360 }
361 }
362}
363
364impl Default for ConnectionRegistry {
365 fn default() -> Self {
366 Self::new()
367 }
368}
369
370#[derive(Debug)]
372pub struct BroadcastResult {
373 pub success_count: usize,
374 pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
375 pub inactive_connections: Vec<ConnectionId>,
376}
377
378impl BroadcastResult {
379 fn new() -> Self {
380 Self {
381 success_count: 0,
382 failed_connections: Vec::new(),
383 inactive_connections: Vec::new(),
384 }
385 }
386
387 pub fn total_attempted(&self) -> usize {
388 self.success_count + self.failed_connections.len() + self.inactive_connections.len()
389 }
390
391 pub fn has_failures(&self) -> bool {
392 !self.failed_connections.is_empty()
393 }
394}
395
396#[derive(Debug)]
398pub struct CloseAllResult {
399 pub closed_count: usize,
400 pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
401}
402
403impl CloseAllResult {
404 fn new() -> Self {
405 Self {
406 closed_count: 0,
407 failed_connections: Vec::new(),
408 }
409 }
410}
411
/// Aggregate snapshot of registry state, as produced by `ConnectionRegistry::stats`.
///
/// All counters start at zero via `Default`. The snapshot is plain data, so it is
/// `Copy` and comparable: callers can keep earlier snapshots and diff them.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct RegistryStats {
    /// Number of connections registered when the snapshot was taken.
    pub total_connections: usize,
    /// Connections in the `Connected` state.
    pub active_connections: usize,
    /// Connections in the `Connecting` state.
    pub connecting_connections: usize,
    /// Connections in the `Closing` state.
    pub closing_connections: usize,
    /// Connections in the `Closed` state.
    pub closed_connections: usize,
    /// Connections in any `Failed(_)` state.
    pub failed_connections: usize,
    /// Sum of `messages_sent` over all connections.
    pub total_messages_sent: u64,
    /// Sum of `messages_received` over all connections.
    pub total_messages_received: u64,
    /// Sum of `bytes_sent` over all connections.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` over all connections.
    pub total_bytes_received: u64,
}
426
427use super::types::WebSocketError;