Skip to main content

aerosocket_server/
manager.rs

1//! Connection manager for WebSocket server
2//!
3//! This module provides connection management, monitoring, and cleanup functionality.
4
5use crate::config::ServerConfig;
6use crate::connection::{Connection, ConnectionHandle};
7use aerosocket_core::Result;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, Mutex};
12use tokio::time::interval;
13
/// Connection manager statistics.
///
/// `ConnectionManager::get_stats` hands out a copy of these counters.
/// `Default` is derived: every counter starts at zero.
#[derive(Debug, Clone, Default)]
pub struct ManagerStats {
    /// Number of currently registered connections.
    pub active_connections: usize,
    /// Total number of connections ever registered with the manager.
    pub total_connections: u64,
    /// Number of connections closed due to timeout.
    pub timeout_closures: u64,
    /// Number of connections closed due to errors.
    pub error_closures: u64,
    /// Number of connections closed normally.
    pub normal_closures: u64,
    /// Current memory usage in bytes.
    ///
    /// NOTE(review): nothing in this module writes this field, so it stays at
    /// 0 — confirm whether another component is meant to populate it.
    pub memory_usage: u64,
    /// Highest value `active_connections` has reached.
    pub peak_connections: usize,
}
46
47/// Connection manager
48#[derive(Debug)]
49pub struct ConnectionManager {
50    /// Server configuration
51    config: ServerConfig,
52    /// Active connections by ID
53    connections: Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
54    /// Connection statistics
55    stats: Arc<Mutex<ManagerStats>>,
56    /// Next connection ID
57    next_id: Arc<Mutex<u64>>,
58    /// Cleanup interval
59    cleanup_interval: Duration,
60    /// Sender for cleanup notifications
61    cleanup_tx: mpsc::Sender<u64>,
62    /// Receiver for cleanup notifications
63    cleanup_rx: Arc<Mutex<mpsc::Receiver<u64>>>,
64}
65
66impl ConnectionManager {
67    /// Create a new connection manager
68    pub fn new(config: ServerConfig) -> Self {
69        let (cleanup_tx, cleanup_rx) = mpsc::channel(1000);
70
71        Self {
72            cleanup_interval: Duration::from_secs(30), // Default cleanup interval
73            config,
74            connections: Arc::new(Mutex::new(HashMap::new())),
75            stats: Arc::new(Mutex::new(ManagerStats::default())),
76            next_id: Arc::new(Mutex::new(1)),
77            cleanup_tx,
78            cleanup_rx: Arc::new(Mutex::new(cleanup_rx)),
79        }
80    }
81
82    /// Set the cleanup interval
83    pub fn set_cleanup_interval(&mut self, interval: Duration) {
84        self.cleanup_interval = interval;
85    }
86
87    /// Add a new connection and return its assigned ID.
88    pub async fn add_connection(&self, connection: Connection) -> u64 {
89        let mut next_id = self.next_id.lock().await;
90        let id = *next_id;
91        *next_id += 1;
92
93        let handle = ConnectionHandle::new(id, connection);
94
95        let mut connections = self.connections.lock().await;
96        connections.insert(id, handle);
97
98        // Update statistics
99        let mut stats = self.stats.lock().await;
100        stats.active_connections = connections.len();
101        stats.total_connections += 1;
102        stats.peak_connections = stats.peak_connections.max(stats.active_connections);
103
104        id
105    }
106
107    /// Remove a connection
108    pub async fn remove_connection(&self, id: u64, reason: CloseReason) {
109        let mut connections = self.connections.lock().await;
110        if connections.remove(&id).is_some() {
111            // Update statistics
112            let mut stats = self.stats.lock().await;
113            stats.active_connections = connections.len();
114
115            match reason {
116                CloseReason::Timeout => stats.timeout_closures += 1,
117                CloseReason::Error => stats.error_closures += 1,
118                CloseReason::Normal => stats.normal_closures += 1,
119            }
120        }
121    }
122
123    /// Get connection by ID
124    pub async fn get_connection(&self, id: u64) -> Option<ConnectionHandle> {
125        let connections = self.connections.lock().await;
126        connections.get(&id).cloned()
127    }
128
129    /// Get all active connections
130    pub async fn get_all_connections(&self) -> Vec<ConnectionHandle> {
131        let connections = self.connections.lock().await;
132        connections.values().cloned().collect()
133    }
134
135    /// Get current connection count
136    pub async fn connection_count(&self) -> usize {
137        let connections = self.connections.lock().await;
138        connections.len()
139    }
140
141    /// Get connection manager statistics
142    pub async fn get_stats(&self) -> ManagerStats {
143        let stats = self.stats.lock().await;
144        ManagerStats {
145            active_connections: stats.active_connections,
146            total_connections: stats.total_connections,
147            timeout_closures: stats.timeout_closures,
148            error_closures: stats.error_closures,
149            normal_closures: stats.normal_closures,
150            memory_usage: stats.memory_usage,
151            peak_connections: stats.peak_connections,
152        }
153    }
154
155    /// Start the cleanup task
156    pub async fn start_cleanup_task(&self) {
157        let connections = self.connections.clone();
158        let stats = self.stats.clone();
159        let cleanup_rx = self.cleanup_rx.clone();
160        let cleanup_interval = self.cleanup_interval;
161        let idle_timeout = self.config.idle_timeout;
162
163        tokio::spawn(async move {
164            let mut cleanup_interval_timer = interval(cleanup_interval);
165            let mut cleanup_receiver = cleanup_rx.lock().await;
166
167            loop {
168                tokio::select! {
169                    _ = cleanup_interval_timer.tick() => {
170                        // Periodic cleanup
171                        Self::cleanup_idle_connections(&connections, &stats, idle_timeout).await;
172                    }
173                    Some(id) = cleanup_receiver.recv() => {
174                        // Immediate cleanup for specific connection
175                        Self::remove_connection_internal(&connections, &stats, id, CloseReason::Timeout).await;
176                    }
177                }
178            }
179        });
180    }
181
182    /// Cleanup idle connections
183    async fn cleanup_idle_connections(
184        connections: &Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
185        stats: &Arc<Mutex<ManagerStats>>,
186        _idle_timeout: Duration,
187    ) {
188        let mut connections_map = connections.lock().await;
189        let mut to_remove = Vec::new();
190
191        for (id, handle) in connections_map.iter() {
192            if let Ok(connection) = handle.try_lock().await {
193                if connection.is_timed_out() {
194                    to_remove.push(*id);
195                }
196            }
197        }
198
199        for id in to_remove {
200            connections_map.remove(&id);
201            let mut stats = stats.lock().await;
202            stats.active_connections = connections_map.len();
203            stats.timeout_closures += 1;
204        }
205    }
206
207    /// Internal connection removal
208    async fn remove_connection_internal(
209        connections: &Arc<Mutex<HashMap<u64, ConnectionHandle>>>,
210        stats: &Arc<Mutex<ManagerStats>>,
211        id: u64,
212        reason: CloseReason,
213    ) {
214        let mut connections_map = connections.lock().await;
215        if connections_map.remove(&id).is_some() {
216            let mut stats = stats.lock().await;
217            stats.active_connections = connections_map.len();
218
219            match reason {
220                CloseReason::Timeout => stats.timeout_closures += 1,
221                CloseReason::Error => stats.error_closures += 1,
222                CloseReason::Normal => stats.normal_closures += 1,
223            }
224        }
225    }
226
227    /// Monitor connection health
228    pub async fn monitor_connections(&self) -> Result<Vec<ConnectionHealth>> {
229        let connections = self.connections.lock().await;
230        let mut health_reports = Vec::new();
231
232        for (id, handle) in connections.iter() {
233            if let Ok(connection) = handle.try_lock().await {
234                let health = ConnectionHealth {
235                    id: *id,
236                    remote_addr: connection.remote_addr(),
237                    state: connection.state(),
238                    uptime: connection.metadata().established_at.elapsed(),
239                    last_activity: connection.metadata().last_activity_at.elapsed(),
240                    messages_sent: connection.metadata().messages_sent,
241                    messages_received: connection.metadata().messages_received,
242                    bytes_sent: connection.metadata().bytes_sent,
243                    bytes_received: connection.metadata().bytes_received,
244                    time_until_timeout: connection.time_until_timeout(),
245                };
246                health_reports.push(health);
247            }
248        }
249
250        Ok(health_reports)
251    }
252
253    /// Close all connections
254    pub async fn close_all_connections(&self) {        let connections = self.connections.lock().await;
255        let handles: Vec<_> = connections.values().cloned().collect();
256        let connection_count = connections.len();
257        drop(connections);
258
259        for handle in handles {
260            if let Ok(mut connection) = handle.try_lock().await {
261                let _ = connection.close(Some(1000), Some("Server shutdown")).await;
262            }
263        }
264
265        // Clear all connections and update stats
266        let mut connections_map = self.connections.lock().await;
267        connections_map.clear();
268
269        // Update statistics
270        let mut stats = self.stats.lock().await;
271        stats.active_connections = 0;
272        stats.normal_closures += connection_count as u64;
273    }
274
275    /// Broadcast binary to every active connection.
276    pub async fn broadcast_binary_to_all(&self, data: &[u8]) -> aerosocket_core::Result<()> {
277        let data = data.to_vec();
278        let handles = self.get_all_connections().await;
279        for handle in handles {
280            if let Ok(mut conn) = handle.try_lock().await {
281                let _ = conn.send_binary(data.clone()).await;
282            }
283        }
284        Ok(())
285    }
286
287    /// Broadcast text to every active connection.
288    pub async fn broadcast_text_to_all(&self, text: &str) -> aerosocket_core::Result<()> {
289        let handles = self.get_all_connections().await;
290        for handle in handles {
291            if let Ok(mut conn) = handle.try_lock().await {
292                let _ = conn.send_text(text).await;
293            }
294        }
295        Ok(())
296    }
297
298    /// Broadcast binary to all connections except the one with `except_id`.
299    pub async fn broadcast_binary_except(
300        &self,
301        data: &[u8],
302        except_id: u64,
303    ) -> aerosocket_core::Result<()> {
304        let data = data.to_vec();
305        let handles = self.get_all_connections().await;
306        for handle in handles {
307            if handle.id() != except_id {
308                if let Ok(mut conn) = handle.try_lock().await {
309                    let _ = conn.send_binary(data.clone()).await;
310                }
311            }
312        }
313        Ok(())
314    }
315
316    /// Broadcast text to all connections except the one with `except_id`.
317    pub async fn broadcast_text_except(
318        &self,
319        text: &str,
320        except_id: u64,
321    ) -> aerosocket_core::Result<()> {
322        let handles = self.get_all_connections().await;
323        for handle in handles {
324            if handle.id() != except_id {
325                if let Ok(mut conn) = handle.try_lock().await {
326                    let _ = conn.send_text(text).await;
327                }
328            }
329        }
330        Ok(())
331    }
332}
333
334impl Drop for ConnectionManager {
335    fn drop(&mut self) {
336        // Guard with try_current() so Drop does not panic when called outside a
337        // Tokio runtime (e.g. in synchronous unit tests).
338        if tokio::runtime::Handle::try_current().is_ok() {
339            let connections = self.connections.clone();
340            tokio::spawn(async move {
341                let manager = ConnectionManager {
342                    config: ServerConfig::default(),
343                    connections,
344                    stats: Arc::new(Mutex::new(ManagerStats::default())),
345                    next_id: Arc::new(Mutex::new(0)),
346                    cleanup_interval: Duration::ZERO,
347                    cleanup_tx: mpsc::channel(1).0,
348                    cleanup_rx: Arc::new(Mutex::new(mpsc::channel(1).1)),
349                };
350                manager.close_all_connections().await;
351            });
352        }
353    }
354}
355
/// Why a connection was removed from the manager.
///
/// Each variant selects which closure counter in `ManagerStats` is bumped.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CloseReason {
    /// The connection timed out.
    Timeout,
    /// The connection failed with an error.
    Error,
    /// The connection was closed normally.
    Normal,
}
366
367/// Connection health information
368#[derive(Debug, Clone)]
369pub struct ConnectionHealth {
370    /// Connection ID
371    pub id: u64,
372    /// Remote address
373    pub remote_addr: std::net::SocketAddr,
374    /// Connection state
375    pub state: crate::connection::ConnectionState,
376    /// How long the connection has been active
377    pub uptime: Duration,
378    /// Time since last activity
379    pub last_activity: Duration,
380    /// Number of messages sent
381    pub messages_sent: u64,
382    /// Number of messages received
383    pub messages_received: u64,
384    /// Number of bytes sent
385    pub bytes_sent: u64,
386    /// Number of bytes received
387    pub bytes_received: u64,
388    /// Time until connection times out
389    pub time_until_timeout: Option<Duration>,
390}