Skip to main content

oxigdal_websocket/server/
manager.rs

1//! Connection manager for WebSocket server
2
3use crate::error::{Error, Result};
4use crate::protocol::message::Message;
5use crate::server::connection::{Connection, ConnectionId};
6use dashmap::DashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use tokio::sync::broadcast;
10
11/// Connection manager
12pub struct ConnectionManager {
13    /// Active connections
14    connections: Arc<DashMap<ConnectionId, Arc<Connection>>>,
15    /// Connection event broadcaster
16    event_tx: broadcast::Sender<ConnectionEvent>,
17    /// Total connections counter
18    total_connections: Arc<AtomicU64>,
19    /// Maximum connections allowed
20    max_connections: usize,
21}
22
23/// Connection event
24#[derive(Debug, Clone)]
25pub enum ConnectionEvent {
26    /// Connection opened
27    Connected(ConnectionId),
28    /// Connection closed
29    Disconnected(ConnectionId),
30    /// Message received
31    MessageReceived(ConnectionId, Message),
32    /// Error occurred
33    Error(ConnectionId, String),
34}
35
36impl ConnectionManager {
37    /// Create a new connection manager
38    pub fn new(max_connections: usize) -> Self {
39        let (event_tx, _) = broadcast::channel(1000);
40
41        Self {
42            connections: Arc::new(DashMap::new()),
43            event_tx,
44            total_connections: Arc::new(AtomicU64::new(0)),
45            max_connections,
46        }
47    }
48
49    /// Add a connection
50    pub fn add(&self, connection: Arc<Connection>) -> Result<()> {
51        // Check if we've reached the limit
52        if self.connections.len() >= self.max_connections {
53            return Err(Error::ResourceExhausted(format!(
54                "Maximum connections ({}) reached",
55                self.max_connections
56            )));
57        }
58
59        let id = connection.id();
60        self.connections.insert(id, connection);
61        self.total_connections.fetch_add(1, Ordering::Relaxed);
62
63        // Broadcast connection event
64        let _ = self.event_tx.send(ConnectionEvent::Connected(id));
65
66        tracing::info!("Connection {} added", id);
67        Ok(())
68    }
69
70    /// Remove a connection
71    pub fn remove(&self, id: &ConnectionId) -> Option<Arc<Connection>> {
72        let conn = self.connections.remove(id).map(|(_, v)| v);
73
74        if conn.is_some() {
75            let _ = self.event_tx.send(ConnectionEvent::Disconnected(*id));
76            tracing::info!("Connection {} removed", id);
77        }
78
79        conn
80    }
81
82    /// Get a connection by ID
83    pub fn get(&self, id: &ConnectionId) -> Option<Arc<Connection>> {
84        self.connections.get(id).map(|r| r.value().clone())
85    }
86
87    /// Get all connections
88    pub fn all(&self) -> Vec<Arc<Connection>> {
89        self.connections.iter().map(|r| r.value().clone()).collect()
90    }
91
92    /// Get connection count
93    pub fn count(&self) -> usize {
94        self.connections.len()
95    }
96
97    /// Get total connections served
98    pub fn total_connections(&self) -> u64 {
99        self.total_connections.load(Ordering::Relaxed)
100    }
101
102    /// Broadcast a message to all connections
103    pub async fn broadcast(&self, message: Message) -> Result<usize> {
104        let connections = self.all();
105        let mut sent = 0;
106
107        for conn in connections {
108            if let Err(e) = conn.send(message.clone()).await {
109                tracing::error!("Failed to broadcast to {}: {}", conn.id(), e);
110            } else {
111                sent += 1;
112            }
113        }
114
115        Ok(sent)
116    }
117
118    /// Broadcast a message to specific connections
119    pub async fn broadcast_to(&self, ids: &[ConnectionId], message: Message) -> Result<usize> {
120        let mut sent = 0;
121
122        for id in ids {
123            if let Some(conn) = self.get(id) {
124                if let Err(e) = conn.send(message.clone()).await {
125                    tracing::error!("Failed to send to {}: {}", id, e);
126                } else {
127                    sent += 1;
128                }
129            }
130        }
131
132        Ok(sent)
133    }
134
135    /// Broadcast to connections matching a filter
136    pub async fn broadcast_filtered<F>(&self, message: Message, filter: F) -> Result<usize>
137    where
138        F: Fn(&Arc<Connection>) -> bool,
139    {
140        let connections: Vec<_> = self.all().into_iter().filter(|c| filter(c)).collect();
141
142        let mut sent = 0;
143        for conn in connections {
144            if let Err(e) = conn.send(message.clone()).await {
145                tracing::error!("Failed to broadcast to {}: {}", conn.id(), e);
146            } else {
147                sent += 1;
148            }
149        }
150
151        Ok(sent)
152    }
153
154    /// Close all connections
155    pub async fn close_all(&self) -> Result<()> {
156        let connections = self.all();
157
158        for conn in connections {
159            if let Err(e) = conn.close().await {
160                tracing::error!("Failed to close connection {}: {}", conn.id(), e);
161            }
162        }
163
164        self.connections.clear();
165        Ok(())
166    }
167
168    /// Close idle connections
169    pub async fn close_idle(&self, timeout_secs: u64) -> Result<usize> {
170        let mut closed = 0;
171        let to_close: Vec<_> = self
172            .all()
173            .into_iter()
174            .filter(|c| c.is_idle(timeout_secs))
175            .collect();
176
177        for conn in to_close {
178            let id = conn.id();
179            if let Err(e) = conn.close().await {
180                tracing::error!("Failed to close idle connection {}: {}", id, e);
181            } else {
182                self.remove(&id);
183                closed += 1;
184            }
185        }
186
187        Ok(closed)
188    }
189
190    /// Get connections by room
191    pub async fn get_by_room(&self, room: &str) -> Vec<Arc<Connection>> {
192        let mut result = Vec::new();
193
194        for conn in self.all() {
195            let metadata = conn.metadata().await;
196            if metadata.rooms.contains(room) {
197                result.push(conn);
198            }
199        }
200
201        result
202    }
203
204    /// Get connections by topic
205    pub async fn get_by_topic(&self, topic: &str) -> Vec<Arc<Connection>> {
206        let mut result = Vec::new();
207
208        for conn in self.all() {
209            let metadata = conn.metadata().await;
210            if metadata.subscriptions.contains(topic) {
211                result.push(conn);
212            }
213        }
214
215        result
216    }
217
218    /// Subscribe to connection events
219    pub fn subscribe(&self) -> broadcast::Receiver<ConnectionEvent> {
220        self.event_tx.subscribe()
221    }
222
223    /// Get manager statistics
224    pub fn stats(&self) -> ConnectionManagerStats {
225        let connections = self.all();
226
227        let mut total_messages_sent = 0u64;
228        let mut total_messages_received = 0u64;
229        let mut total_bytes_sent = 0u64;
230        let mut total_bytes_received = 0u64;
231        let mut total_errors = 0u64;
232
233        for conn in &connections {
234            let stats = conn.stats();
235            total_messages_sent += stats.messages_sent;
236            total_messages_received += stats.messages_received;
237            total_bytes_sent += stats.bytes_sent;
238            total_bytes_received += stats.bytes_received;
239            total_errors += stats.errors;
240        }
241
242        ConnectionManagerStats {
243            active_connections: connections.len(),
244            total_connections: self.total_connections(),
245            messages_sent: total_messages_sent,
246            messages_received: total_messages_received,
247            bytes_sent: total_bytes_sent,
248            bytes_received: total_bytes_received,
249            errors: total_errors,
250        }
251    }
252}
253
/// Point-in-time aggregate of the manager's counters.
///
/// `Default` yields an all-zero snapshot, and `PartialEq`/`Eq` allow two
/// snapshots to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnectionManagerStats {
    /// Number of connections registered when the snapshot was taken.
    pub active_connections: usize,
    /// Number of connections accepted over the manager's lifetime.
    pub total_connections: u64,
    /// Messages sent, summed over the active connections.
    pub messages_sent: u64,
    /// Messages received, summed over the active connections.
    pub messages_received: u64,
    /// Bytes sent, summed over the active connections.
    pub bytes_sent: u64,
    /// Bytes received, summed over the active connections.
    pub bytes_received: u64,
    /// Errors recorded, summed over the active connections.
    pub errors: u64,
}
272
273/// Connection statistics (re-export for convenience)
274pub use crate::server::connection::ConnectionStats;
275
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created manager holds no connections and has served none.
    #[test]
    fn test_connection_manager_new() {
        let manager = ConnectionManager::new(1000);
        assert_eq!(manager.count(), 0);
        assert_eq!(manager.total_connections(), 0);
        assert!(manager.all().is_empty());
    }

    /// With no connections registered, every aggregated counter is zero.
    #[test]
    fn test_connection_manager_stats() {
        let stats = ConnectionManager::new(1000).stats();

        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_received, 0);
        assert_eq!(stats.bytes_sent, 0);
        assert_eq!(stats.bytes_received, 0);
        assert_eq!(stats.errors, 0);
    }

    /// A new subscriber sees an empty (not closed, not lagged) channel.
    #[tokio::test]
    async fn test_connection_manager_events() {
        let manager = ConnectionManager::new(1000);
        let mut rx = manager.subscribe();

        // `is_err()` alone would also accept `Closed`/`Lagged`; pin the
        // exact "nothing published yet" state.
        assert!(matches!(
            rx.try_recv(),
            Err(tokio::sync::broadcast::error::TryRecvError::Empty)
        ));
    }

    /// Maintenance operations on an empty manager succeed and do nothing.
    #[tokio::test]
    async fn test_connection_manager_empty_maintenance() {
        let manager = ConnectionManager::new(1000);
        let mut rx = manager.subscribe();

        assert!(manager.close_all().await.is_ok());
        assert_eq!(manager.close_idle(0).await.ok(), Some(0));
        assert!(manager.get_by_room("lobby").await.is_empty());
        assert!(manager.get_by_topic("updates").await.is_empty());

        assert_eq!(manager.count(), 0);
        // Nothing was registered, so no lifecycle event may have been sent.
        assert!(matches!(
            rx.try_recv(),
            Err(tokio::sync::broadcast::error::TryRecvError::Empty)
        ));
    }
}
305}