elif_http/websocket/registry.rs

1//! Connection registry for managing WebSocket connections
2
3use super::connection::WebSocketConnection;
4use super::types::{ConnectionId, ConnectionState, WebSocketMessage, WebSocketResult};
5use super::channel::{ChannelManager, ChannelId};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11/// Events that can occur in the connection registry
12#[derive(Debug, Clone)]
13pub enum ConnectionEvent {
14    /// New connection was added
15    Connected(ConnectionId),
16    /// Connection was removed
17    Disconnected(ConnectionId, ConnectionState),
18    /// Message was broadcast to all connections
19    Broadcast(WebSocketMessage),
20    /// Message was sent to specific connection
21    MessageSent(ConnectionId, WebSocketMessage),
22}
23
24/// High-performance connection registry using Arc<RwLock<>> for concurrent access
25pub struct ConnectionRegistry {
26    /// Active connections
27    connections: Arc<RwLock<HashMap<ConnectionId, Arc<WebSocketConnection>>>>,
28    /// Channel manager for channel-based messaging
29    channel_manager: Arc<ChannelManager>,
30    /// Event subscribers (for future extensibility)
31    event_handlers: Arc<RwLock<Vec<Box<dyn Fn(ConnectionEvent) + Send + Sync>>>>,
32}
33
34impl ConnectionRegistry {
35    /// Create a new connection registry
36    pub fn new() -> Self {
37        Self {
38            connections: Arc::new(RwLock::new(HashMap::new())),
39            channel_manager: Arc::new(ChannelManager::new()),
40            event_handlers: Arc::new(RwLock::new(Vec::new())),
41        }
42    }
43
44    /// Create a new connection registry with existing channel manager
45    pub fn with_channel_manager(channel_manager: Arc<ChannelManager>) -> Self {
46        Self {
47            connections: Arc::new(RwLock::new(HashMap::new())),
48            channel_manager,
49            event_handlers: Arc::new(RwLock::new(Vec::new())),
50        }
51    }
52
53    /// Get the channel manager
54    pub fn channel_manager(&self) -> &Arc<ChannelManager> {
55        &self.channel_manager
56    }
57
58    /// Add a connection to the registry
59    pub async fn add_connection(&self, connection: WebSocketConnection) -> ConnectionId {
60        let id = connection.id;
61        let arc_connection = Arc::new(connection);
62        
63        {
64            let mut connections = self.connections.write().await;
65            connections.insert(id, arc_connection);
66        }
67        
68        info!("Added connection to registry: {}", id);
69        self.emit_event(ConnectionEvent::Connected(id)).await;
70        
71        id
72    }
73
74    /// Remove a connection from the registry
75    pub async fn remove_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
76        let connection = {
77            let mut connections = self.connections.write().await;
78            connections.remove(&id)
79        };
80
81        if let Some(conn) = &connection {
82            let state = conn.state().await;
83            
84            // Clean up channel memberships
85            self.channel_manager.leave_all_channels(id).await;
86            
87            info!("Removed connection from registry: {} (state: {:?})", id, state);
88            self.emit_event(ConnectionEvent::Disconnected(id, state)).await;
89        }
90
91        connection
92    }
93
94    /// Get a connection by ID
95    pub async fn get_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
96        let connections = self.connections.read().await;
97        connections.get(&id).cloned()
98    }
99
100    /// Get all active connections
101    pub async fn get_all_connections(&self) -> Vec<Arc<WebSocketConnection>> {
102        let connections = self.connections.read().await;
103        connections.values().cloned().collect()
104    }
105
106    /// Get all connection IDs
107    pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
108        let connections = self.connections.read().await;
109        connections.keys().copied().collect()
110    }
111
112    /// Get the number of active connections
113    pub async fn connection_count(&self) -> usize {
114        let connections = self.connections.read().await;
115        connections.len()
116    }
117
118    /// Send a message to a specific connection
119    pub async fn send_to_connection(
120        &self,
121        id: ConnectionId,
122        message: WebSocketMessage,
123    ) -> WebSocketResult<()> {
124        let connection = self.get_connection(id).await
125            .ok_or(WebSocketError::ConnectionNotFound(id))?;
126
127        let result = connection.send(message.clone()).await;
128        
129        if result.is_ok() {
130            self.emit_event(ConnectionEvent::MessageSent(id, message)).await;
131        }
132        
133        result
134    }
135
136    /// Send a text message to a specific connection
137    pub async fn send_text_to_connection<T: Into<String>>(
138        &self,
139        id: ConnectionId,
140        text: T,
141    ) -> WebSocketResult<()> {
142        self.send_to_connection(id, WebSocketMessage::text(text)).await
143    }
144
145    /// Send a binary message to a specific connection
146    pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
147        &self,
148        id: ConnectionId,
149        data: T,
150    ) -> WebSocketResult<()> {
151        self.send_to_connection(id, WebSocketMessage::binary(data)).await
152    }
153
154    /// Broadcast a message to all active connections
155    pub async fn broadcast(&self, message: WebSocketMessage) -> BroadcastResult {
156        let connections = self.get_all_connections().await;
157        let mut results = BroadcastResult::new();
158
159        for connection in connections {
160            if connection.is_active().await {
161                match connection.send(message.clone()).await {
162                    Ok(_) => results.success_count += 1,
163                    Err(e) => {
164                        results.failed_connections.push((connection.id, e));
165                    }
166                }
167            } else {
168                results.inactive_connections.push(connection.id);
169            }
170        }
171
172        self.emit_event(ConnectionEvent::Broadcast(message)).await;
173        results
174    }
175
176    /// Broadcast a text message to all active connections
177    pub async fn broadcast_text<T: Into<String>>(&self, text: T) -> BroadcastResult {
178        self.broadcast(WebSocketMessage::text(text)).await
179    }
180
181    /// Broadcast a binary message to all active connections
182    pub async fn broadcast_binary<T: Into<Vec<u8>>>(&self, data: T) -> BroadcastResult {
183        self.broadcast(WebSocketMessage::binary(data)).await
184    }
185
186    /// Send a message to a specific channel
187    pub async fn send_to_channel(
188        &self,
189        channel_id: ChannelId,
190        sender_id: ConnectionId,
191        message: WebSocketMessage,
192    ) -> WebSocketResult<BroadcastResult> {
193        // Get the member IDs from the channel manager
194        let member_ids = self.channel_manager
195            .send_to_channel(channel_id, sender_id, message.clone())
196            .await?;
197
198        // Broadcast to all channel members
199        let mut results = BroadcastResult::new();
200        
201        for member_id in member_ids {
202            if let Some(connection) = self.get_connection(member_id).await {
203                if connection.is_active().await {
204                    match connection.send(message.clone()).await {
205                        Ok(_) => results.success_count += 1,
206                        Err(e) => {
207                            results.failed_connections.push((member_id, e));
208                        }
209                    }
210                } else {
211                    results.inactive_connections.push(member_id);
212                }
213            } else {
214                // Connection not in registry but still in channel - clean up
215                let _ = self.channel_manager.leave_channel(channel_id, member_id).await;
216            }
217        }
218
219        Ok(results)
220    }
221
222    /// Send a text message to a specific channel
223    pub async fn send_text_to_channel<T: Into<String>>(
224        &self,
225        channel_id: ChannelId,
226        sender_id: ConnectionId,
227        text: T,
228    ) -> WebSocketResult<BroadcastResult> {
229        self.send_to_channel(channel_id, sender_id, WebSocketMessage::text(text)).await
230    }
231
232    /// Send a binary message to a specific channel
233    pub async fn send_binary_to_channel<T: Into<Vec<u8>>>(
234        &self,
235        channel_id: ChannelId,
236        sender_id: ConnectionId,
237        data: T,
238    ) -> WebSocketResult<BroadcastResult> {
239        self.send_to_channel(channel_id, sender_id, WebSocketMessage::binary(data)).await
240    }
241
242    /// Close a specific connection
243    pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
244        let connection = self.get_connection(id).await
245            .ok_or(WebSocketError::ConnectionNotFound(id))?;
246
247        connection.close().await?;
248        self.remove_connection(id).await;
249        
250        Ok(())
251    }
252
253    /// Close all connections
254    pub async fn close_all_connections(&self) -> CloseAllResult {
255        let connections = self.get_all_connections().await;
256        let mut results = CloseAllResult::new();
257        let mut to_remove = Vec::new();
258
259        for connection in connections {
260            match connection.close().await {
261                Ok(_) => {
262                    to_remove.push(connection.id);
263                    results.closed_count += 1;
264                }
265                Err(e) => {
266                    results.failed_connections.push((connection.id, e));
267                }
268            }
269        }
270
271        // Batch removal: remove all closed connections under a single write lock
272        if !to_remove.is_empty() {
273            let mut connections = self.connections.write().await;
274            for id in to_remove {
275                if let Some(conn) = connections.remove(&id) {
276                    let state = conn.state().await;
277                    info!("Removed connection from registry: {} (state: {:?})", id, state);
278                    // Note: We can't emit events here while holding the write lock
279                    // to avoid potential deadlocks. Consider restructuring if events are critical.
280                }
281            }
282        }
283
284        results
285    }
286
287    /// Clean up inactive connections
288    pub async fn cleanup_inactive_connections(&self) -> usize {
289        let connections = self.get_all_connections().await;
290        let mut to_remove = Vec::new();
291
292        // First pass: identify inactive connections
293        for connection in connections {
294            if connection.is_closed().await {
295                to_remove.push((connection.id, connection));
296            }
297        }
298
299        let cleaned_up = to_remove.len();
300
301        // Batch removal: remove all inactive connections under a single write lock
302        if !to_remove.is_empty() {
303            let mut registry_connections = self.connections.write().await;
304            for (id, _connection) in to_remove {
305                if registry_connections.remove(&id).is_some() {
306                    debug!("Cleaned up inactive connection: {}", id);
307                    // Note: We can't emit Disconnected events here while holding the write lock
308                    // to avoid potential deadlocks. Consider restructuring if events are critical.
309                }
310            }
311        }
312
313        if cleaned_up > 0 {
314            info!("Cleaned up {} inactive connections", cleaned_up);
315        }
316
317        cleaned_up
318    }
319
320    /// Get registry statistics
321    pub async fn stats(&self) -> RegistryStats {
322        let connections = self.get_all_connections().await;
323        let mut stats = RegistryStats::default();
324        
325        stats.total_connections = connections.len();
326        
327        for connection in connections {
328            match connection.state().await {
329                ConnectionState::Connected => stats.active_connections += 1,
330                ConnectionState::Connecting => stats.connecting_connections += 1,
331                ConnectionState::Closing => stats.closing_connections += 1,
332                ConnectionState::Closed => stats.closed_connections += 1,
333                ConnectionState::Failed(_) => stats.failed_connections += 1,
334            }
335            
336            let conn_stats = connection.stats().await;
337            stats.total_messages_sent += conn_stats.messages_sent;
338            stats.total_messages_received += conn_stats.messages_received;
339            stats.total_bytes_sent += conn_stats.bytes_sent;
340            stats.total_bytes_received += conn_stats.bytes_received;
341        }
342
343        stats
344    }
345
346    /// Add an event handler (for future extensibility)
347    pub async fn add_event_handler<F>(&self, handler: F)
348    where
349        F: Fn(ConnectionEvent) + Send + Sync + 'static,
350    {
351        let mut handlers = self.event_handlers.write().await;
352        handlers.push(Box::new(handler));
353    }
354
355    /// Emit an event to all handlers
356    async fn emit_event(&self, event: ConnectionEvent) {
357        let handlers = self.event_handlers.read().await;
358        for handler in handlers.iter() {
359            handler(event.clone());
360        }
361    }
362}
363
364impl Default for ConnectionRegistry {
365    fn default() -> Self {
366        Self::new()
367    }
368}
369
370/// Result of broadcasting a message to multiple connections
371#[derive(Debug)]
372pub struct BroadcastResult {
373    pub success_count: usize,
374    pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
375    pub inactive_connections: Vec<ConnectionId>,
376}
377
378impl BroadcastResult {
379    fn new() -> Self {
380        Self {
381            success_count: 0,
382            failed_connections: Vec::new(),
383            inactive_connections: Vec::new(),
384        }
385    }
386
387    pub fn total_attempted(&self) -> usize {
388        self.success_count + self.failed_connections.len() + self.inactive_connections.len()
389    }
390
391    pub fn has_failures(&self) -> bool {
392        !self.failed_connections.is_empty()
393    }
394}
395
396/// Result of closing all connections
397#[derive(Debug)]
398pub struct CloseAllResult {
399    pub closed_count: usize,
400    pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
401}
402
403impl CloseAllResult {
404    fn new() -> Self {
405        Self {
406            closed_count: 0,
407            failed_connections: Vec::new(),
408        }
409    }
410}
411
/// Point-in-time aggregate statistics over all registered connections.
///
/// The per-state counters together add up to `total_connections`. The message and byte
/// totals are sums of each connection's own counters. The type is `Clone` and comparable,
/// so snapshots can be stored and diffed.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RegistryStats {
    /// Number of registered connections when the snapshot was taken.
    pub total_connections: usize,
    /// Connections in the `Connected` state.
    pub active_connections: usize,
    /// Connections in the `Connecting` state.
    pub connecting_connections: usize,
    /// Connections in the `Closing` state.
    pub closing_connections: usize,
    /// Connections in the `Closed` state.
    pub closed_connections: usize,
    /// Connections in a `Failed(_)` state.
    pub failed_connections: usize,
    /// Sum of messages sent by all connections.
    pub total_messages_sent: u64,
    /// Sum of messages received by all connections.
    pub total_messages_received: u64,
    /// Sum of bytes sent by all connections.
    pub total_bytes_sent: u64,
    /// Sum of bytes received by all connections.
    pub total_bytes_received: u64,
}
426
427// Re-export WebSocketError for convenience
428use super::types::WebSocketError;