elif_http/websocket/registry.rs

1//! Connection registry for managing WebSocket connections
2
3use super::channel::{ChannelId, ChannelManager};
4use super::connection::WebSocketConnection;
5use super::types::{ConnectionId, ConnectionState, WebSocketMessage, WebSocketResult};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11/// Events that can occur in the connection registry
12#[derive(Debug, Clone)]
13pub enum ConnectionEvent {
14    /// New connection was added
15    Connected(ConnectionId),
16    /// Connection was removed
17    Disconnected(ConnectionId, ConnectionState),
18    /// Message was broadcast to all connections
19    Broadcast(WebSocketMessage),
20    /// Message was sent to specific connection
21    MessageSent(ConnectionId, WebSocketMessage),
22}
23
24/// High-performance connection registry using Arc<RwLock<>> for concurrent access
25pub struct ConnectionRegistry {
26    /// Active connections
27    connections: Arc<RwLock<HashMap<ConnectionId, Arc<WebSocketConnection>>>>,
28    /// Channel manager for channel-based messaging
29    channel_manager: Arc<ChannelManager>,
30    /// Event subscribers (for future extensibility)
31    event_handlers: Arc<RwLock<Vec<Box<dyn Fn(ConnectionEvent) + Send + Sync>>>>,
32}
33
34impl ConnectionRegistry {
35    /// Create a new connection registry
36    pub fn new() -> Self {
37        Self {
38            connections: Arc::new(RwLock::new(HashMap::new())),
39            channel_manager: Arc::new(ChannelManager::new()),
40            event_handlers: Arc::new(RwLock::new(Vec::new())),
41        }
42    }
43
44    /// Create a new connection registry with existing channel manager
45    pub fn with_channel_manager(channel_manager: Arc<ChannelManager>) -> Self {
46        Self {
47            connections: Arc::new(RwLock::new(HashMap::new())),
48            channel_manager,
49            event_handlers: Arc::new(RwLock::new(Vec::new())),
50        }
51    }
52
53    /// Get the channel manager
54    pub fn channel_manager(&self) -> &Arc<ChannelManager> {
55        &self.channel_manager
56    }
57
58    /// Add a connection to the registry
59    pub async fn add_connection(&self, connection: WebSocketConnection) -> ConnectionId {
60        let id = connection.id;
61        let arc_connection = Arc::new(connection);
62
63        {
64            let mut connections = self.connections.write().await;
65            connections.insert(id, arc_connection);
66        }
67
68        info!("Added connection to registry: {}", id);
69        self.emit_event(ConnectionEvent::Connected(id)).await;
70
71        id
72    }
73
74    /// Remove a connection from the registry
75    pub async fn remove_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
76        let connection = {
77            let mut connections = self.connections.write().await;
78            connections.remove(&id)
79        };
80
81        if let Some(conn) = &connection {
82            let state = conn.state().await;
83
84            // Clean up channel memberships
85            self.channel_manager.leave_all_channels(id).await;
86
87            info!(
88                "Removed connection from registry: {} (state: {:?})",
89                id, state
90            );
91            self.emit_event(ConnectionEvent::Disconnected(id, state))
92                .await;
93        }
94
95        connection
96    }
97
98    /// Get a connection by ID
99    pub async fn get_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
100        let connections = self.connections.read().await;
101        connections.get(&id).cloned()
102    }
103
104    /// Get all active connections
105    pub async fn get_all_connections(&self) -> Vec<Arc<WebSocketConnection>> {
106        let connections = self.connections.read().await;
107        connections.values().cloned().collect()
108    }
109
110    /// Get all connection IDs
111    pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
112        let connections = self.connections.read().await;
113        connections.keys().copied().collect()
114    }
115
116    /// Get the number of active connections
117    pub async fn connection_count(&self) -> usize {
118        let connections = self.connections.read().await;
119        connections.len()
120    }
121
122    /// Send a message to a specific connection
123    pub async fn send_to_connection(
124        &self,
125        id: ConnectionId,
126        message: WebSocketMessage,
127    ) -> WebSocketResult<()> {
128        let connection = self
129            .get_connection(id)
130            .await
131            .ok_or(WebSocketError::ConnectionNotFound(id))?;
132
133        let result = connection.send(message.clone()).await;
134
135        if result.is_ok() {
136            self.emit_event(ConnectionEvent::MessageSent(id, message))
137                .await;
138        }
139
140        result
141    }
142
143    /// Send a text message to a specific connection
144    pub async fn send_text_to_connection<T: Into<String>>(
145        &self,
146        id: ConnectionId,
147        text: T,
148    ) -> WebSocketResult<()> {
149        self.send_to_connection(id, WebSocketMessage::text(text))
150            .await
151    }
152
153    /// Send a binary message to a specific connection
154    pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
155        &self,
156        id: ConnectionId,
157        data: T,
158    ) -> WebSocketResult<()> {
159        self.send_to_connection(id, WebSocketMessage::binary(data))
160            .await
161    }
162
163    /// Broadcast a message to all active connections
164    pub async fn broadcast(&self, message: WebSocketMessage) -> BroadcastResult {
165        let connections = self.get_all_connections().await;
166        let mut results = BroadcastResult::new();
167
168        for connection in connections {
169            if connection.is_active().await {
170                match connection.send(message.clone()).await {
171                    Ok(_) => results.success_count += 1,
172                    Err(e) => {
173                        results.failed_connections.push((connection.id, e));
174                    }
175                }
176            } else {
177                results.inactive_connections.push(connection.id);
178            }
179        }
180
181        self.emit_event(ConnectionEvent::Broadcast(message)).await;
182        results
183    }
184
185    /// Broadcast a text message to all active connections
186    pub async fn broadcast_text<T: Into<String>>(&self, text: T) -> BroadcastResult {
187        self.broadcast(WebSocketMessage::text(text)).await
188    }
189
190    /// Broadcast a binary message to all active connections
191    pub async fn broadcast_binary<T: Into<Vec<u8>>>(&self, data: T) -> BroadcastResult {
192        self.broadcast(WebSocketMessage::binary(data)).await
193    }
194
195    /// Send a message to a specific channel
196    pub async fn send_to_channel(
197        &self,
198        channel_id: ChannelId,
199        sender_id: ConnectionId,
200        message: WebSocketMessage,
201    ) -> WebSocketResult<BroadcastResult> {
202        // Get the member IDs from the channel manager
203        let member_ids = self
204            .channel_manager
205            .send_to_channel(channel_id, sender_id, message.clone())
206            .await?;
207
208        // Broadcast to all channel members
209        let mut results = BroadcastResult::new();
210
211        for member_id in member_ids {
212            if let Some(connection) = self.get_connection(member_id).await {
213                if connection.is_active().await {
214                    match connection.send(message.clone()).await {
215                        Ok(_) => results.success_count += 1,
216                        Err(e) => {
217                            results.failed_connections.push((member_id, e));
218                        }
219                    }
220                } else {
221                    results.inactive_connections.push(member_id);
222                }
223            } else {
224                // Connection not in registry but still in channel - clean up
225                let _ = self
226                    .channel_manager
227                    .leave_channel(channel_id, member_id)
228                    .await;
229            }
230        }
231
232        Ok(results)
233    }
234
235    /// Send a text message to a specific channel
236    pub async fn send_text_to_channel<T: Into<String>>(
237        &self,
238        channel_id: ChannelId,
239        sender_id: ConnectionId,
240        text: T,
241    ) -> WebSocketResult<BroadcastResult> {
242        self.send_to_channel(channel_id, sender_id, WebSocketMessage::text(text))
243            .await
244    }
245
246    /// Send a binary message to a specific channel
247    pub async fn send_binary_to_channel<T: Into<Vec<u8>>>(
248        &self,
249        channel_id: ChannelId,
250        sender_id: ConnectionId,
251        data: T,
252    ) -> WebSocketResult<BroadcastResult> {
253        self.send_to_channel(channel_id, sender_id, WebSocketMessage::binary(data))
254            .await
255    }
256
257    /// Close a specific connection
258    pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
259        let connection = self
260            .get_connection(id)
261            .await
262            .ok_or(WebSocketError::ConnectionNotFound(id))?;
263
264        connection.close().await?;
265        self.remove_connection(id).await;
266
267        Ok(())
268    }
269
270    /// Close all connections
271    pub async fn close_all_connections(&self) -> CloseAllResult {
272        let connections = self.get_all_connections().await;
273        let mut results = CloseAllResult::new();
274        let mut to_remove = Vec::new();
275
276        for connection in connections {
277            match connection.close().await {
278                Ok(_) => {
279                    to_remove.push(connection.id);
280                    results.closed_count += 1;
281                }
282                Err(e) => {
283                    results.failed_connections.push((connection.id, e));
284                }
285            }
286        }
287
288        // Batch removal: remove all closed connections under a single write lock
289        if !to_remove.is_empty() {
290            let mut connections = self.connections.write().await;
291            for id in to_remove {
292                if let Some(conn) = connections.remove(&id) {
293                    let state = conn.state().await;
294                    info!(
295                        "Removed connection from registry: {} (state: {:?})",
296                        id, state
297                    );
298                    // Note: We can't emit events here while holding the write lock
299                    // to avoid potential deadlocks. Consider restructuring if events are critical.
300                }
301            }
302        }
303
304        results
305    }
306
307    /// Clean up inactive connections
308    pub async fn cleanup_inactive_connections(&self) -> usize {
309        let connections = self.get_all_connections().await;
310        let mut to_remove = Vec::new();
311
312        // First pass: identify inactive connections
313        for connection in connections {
314            if connection.is_closed().await {
315                to_remove.push((connection.id, connection));
316            }
317        }
318
319        let cleaned_up = to_remove.len();
320
321        // Batch removal: remove all inactive connections under a single write lock
322        if !to_remove.is_empty() {
323            let mut registry_connections = self.connections.write().await;
324            for (id, _connection) in to_remove {
325                if registry_connections.remove(&id).is_some() {
326                    debug!("Cleaned up inactive connection: {}", id);
327                    // Note: We can't emit Disconnected events here while holding the write lock
328                    // to avoid potential deadlocks. Consider restructuring if events are critical.
329                }
330            }
331        }
332
333        if cleaned_up > 0 {
334            info!("Cleaned up {} inactive connections", cleaned_up);
335        }
336
337        cleaned_up
338    }
339
340    /// Get registry statistics
341    pub async fn stats(&self) -> RegistryStats {
342        let connections = self.get_all_connections().await;
343        let mut stats = RegistryStats::default();
344
345        stats.total_connections = connections.len();
346
347        for connection in connections {
348            match connection.state().await {
349                ConnectionState::Connected => stats.active_connections += 1,
350                ConnectionState::Connecting => stats.connecting_connections += 1,
351                ConnectionState::Closing => stats.closing_connections += 1,
352                ConnectionState::Closed => stats.closed_connections += 1,
353                ConnectionState::Failed(_) => stats.failed_connections += 1,
354            }
355
356            let conn_stats = connection.stats().await;
357            stats.total_messages_sent += conn_stats.messages_sent;
358            stats.total_messages_received += conn_stats.messages_received;
359            stats.total_bytes_sent += conn_stats.bytes_sent;
360            stats.total_bytes_received += conn_stats.bytes_received;
361        }
362
363        stats
364    }
365
366    /// Add an event handler (for future extensibility)
367    pub async fn add_event_handler<F>(&self, handler: F)
368    where
369        F: Fn(ConnectionEvent) + Send + Sync + 'static,
370    {
371        let mut handlers = self.event_handlers.write().await;
372        handlers.push(Box::new(handler));
373    }
374
375    /// Emit an event to all handlers
376    async fn emit_event(&self, event: ConnectionEvent) {
377        let handlers = self.event_handlers.read().await;
378        for handler in handlers.iter() {
379            handler(event.clone());
380        }
381    }
382}
383
384impl Default for ConnectionRegistry {
385    fn default() -> Self {
386        Self::new()
387    }
388}
389
390/// Result of broadcasting a message to multiple connections
391#[derive(Debug)]
392pub struct BroadcastResult {
393    pub success_count: usize,
394    pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
395    pub inactive_connections: Vec<ConnectionId>,
396}
397
398impl BroadcastResult {
399    fn new() -> Self {
400        Self {
401            success_count: 0,
402            failed_connections: Vec::new(),
403            inactive_connections: Vec::new(),
404        }
405    }
406
407    pub fn total_attempted(&self) -> usize {
408        self.success_count + self.failed_connections.len() + self.inactive_connections.len()
409    }
410
411    pub fn has_failures(&self) -> bool {
412        !self.failed_connections.is_empty()
413    }
414}
415
416/// Result of closing all connections
417#[derive(Debug)]
418pub struct CloseAllResult {
419    pub closed_count: usize,
420    pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
421}
422
423impl CloseAllResult {
424    fn new() -> Self {
425        Self {
426            closed_count: 0,
427            failed_connections: Vec::new(),
428        }
429    }
430}
431
/// Point-in-time aggregate statistics over all registered connections.
///
/// The `*_connections` counters (other than `total_connections`) split the
/// registered connections by their `ConnectionState`. The `total_*` message
/// and byte fields sum the per-connection counters. Derives `Clone`/`PartialEq`
/// so callers can keep snapshots and compare them.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RegistryStats {
    /// Number of registered connections, whatever their state.
    pub total_connections: usize,
    /// Connections in the `Connected` state.
    pub active_connections: usize,
    /// Connections in the `Connecting` state.
    pub connecting_connections: usize,
    /// Connections in the `Closing` state.
    pub closing_connections: usize,
    /// Connections in the `Closed` state.
    pub closed_connections: usize,
    /// Connections in the `Failed(_)` state.
    pub failed_connections: usize,
    /// Sum of messages sent across all connections.
    pub total_messages_sent: u64,
    /// Sum of messages received across all connections.
    pub total_messages_received: u64,
    /// Sum of bytes sent across all connections.
    pub total_bytes_sent: u64,
    /// Sum of bytes received across all connections.
    pub total_bytes_received: u64,
}
446
// Bring WebSocketError into scope for this module (a private `use`, not a `pub use` re-export)
448use super::types::WebSocketError;