// oxigdal_websocket/server/manager.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use dashmap::DashMap;
use tokio::sync::broadcast;

use crate::error::{Error, Result};
use crate::protocol::message::Message;
use crate::server::connection::{Connection, ConnectionId};
10
11pub struct ConnectionManager {
13 connections: Arc<DashMap<ConnectionId, Arc<Connection>>>,
15 event_tx: broadcast::Sender<ConnectionEvent>,
17 total_connections: Arc<AtomicU64>,
19 max_connections: usize,
21}
22
23#[derive(Debug, Clone)]
25pub enum ConnectionEvent {
26 Connected(ConnectionId),
28 Disconnected(ConnectionId),
30 MessageReceived(ConnectionId, Message),
32 Error(ConnectionId, String),
34}
35
36impl ConnectionManager {
37 pub fn new(max_connections: usize) -> Self {
39 let (event_tx, _) = broadcast::channel(1000);
40
41 Self {
42 connections: Arc::new(DashMap::new()),
43 event_tx,
44 total_connections: Arc::new(AtomicU64::new(0)),
45 max_connections,
46 }
47 }
48
49 pub fn add(&self, connection: Arc<Connection>) -> Result<()> {
51 if self.connections.len() >= self.max_connections {
53 return Err(Error::ResourceExhausted(format!(
54 "Maximum connections ({}) reached",
55 self.max_connections
56 )));
57 }
58
59 let id = connection.id();
60 self.connections.insert(id, connection);
61 self.total_connections.fetch_add(1, Ordering::Relaxed);
62
63 let _ = self.event_tx.send(ConnectionEvent::Connected(id));
65
66 tracing::info!("Connection {} added", id);
67 Ok(())
68 }
69
70 pub fn remove(&self, id: &ConnectionId) -> Option<Arc<Connection>> {
72 let conn = self.connections.remove(id).map(|(_, v)| v);
73
74 if conn.is_some() {
75 let _ = self.event_tx.send(ConnectionEvent::Disconnected(*id));
76 tracing::info!("Connection {} removed", id);
77 }
78
79 conn
80 }
81
82 pub fn get(&self, id: &ConnectionId) -> Option<Arc<Connection>> {
84 self.connections.get(id).map(|r| r.value().clone())
85 }
86
87 pub fn all(&self) -> Vec<Arc<Connection>> {
89 self.connections.iter().map(|r| r.value().clone()).collect()
90 }
91
92 pub fn count(&self) -> usize {
94 self.connections.len()
95 }
96
97 pub fn total_connections(&self) -> u64 {
99 self.total_connections.load(Ordering::Relaxed)
100 }
101
102 pub async fn broadcast(&self, message: Message) -> Result<usize> {
104 let connections = self.all();
105 let mut sent = 0;
106
107 for conn in connections {
108 if let Err(e) = conn.send(message.clone()).await {
109 tracing::error!("Failed to broadcast to {}: {}", conn.id(), e);
110 } else {
111 sent += 1;
112 }
113 }
114
115 Ok(sent)
116 }
117
118 pub async fn broadcast_to(&self, ids: &[ConnectionId], message: Message) -> Result<usize> {
120 let mut sent = 0;
121
122 for id in ids {
123 if let Some(conn) = self.get(id) {
124 if let Err(e) = conn.send(message.clone()).await {
125 tracing::error!("Failed to send to {}: {}", id, e);
126 } else {
127 sent += 1;
128 }
129 }
130 }
131
132 Ok(sent)
133 }
134
135 pub async fn broadcast_filtered<F>(&self, message: Message, filter: F) -> Result<usize>
137 where
138 F: Fn(&Arc<Connection>) -> bool,
139 {
140 let connections: Vec<_> = self.all().into_iter().filter(|c| filter(c)).collect();
141
142 let mut sent = 0;
143 for conn in connections {
144 if let Err(e) = conn.send(message.clone()).await {
145 tracing::error!("Failed to broadcast to {}: {}", conn.id(), e);
146 } else {
147 sent += 1;
148 }
149 }
150
151 Ok(sent)
152 }
153
154 pub async fn close_all(&self) -> Result<()> {
156 let connections = self.all();
157
158 for conn in connections {
159 if let Err(e) = conn.close().await {
160 tracing::error!("Failed to close connection {}: {}", conn.id(), e);
161 }
162 }
163
164 self.connections.clear();
165 Ok(())
166 }
167
168 pub async fn close_idle(&self, timeout_secs: u64) -> Result<usize> {
170 let mut closed = 0;
171 let to_close: Vec<_> = self
172 .all()
173 .into_iter()
174 .filter(|c| c.is_idle(timeout_secs))
175 .collect();
176
177 for conn in to_close {
178 let id = conn.id();
179 if let Err(e) = conn.close().await {
180 tracing::error!("Failed to close idle connection {}: {}", id, e);
181 } else {
182 self.remove(&id);
183 closed += 1;
184 }
185 }
186
187 Ok(closed)
188 }
189
190 pub async fn get_by_room(&self, room: &str) -> Vec<Arc<Connection>> {
192 let mut result = Vec::new();
193
194 for conn in self.all() {
195 let metadata = conn.metadata().await;
196 if metadata.rooms.contains(room) {
197 result.push(conn);
198 }
199 }
200
201 result
202 }
203
204 pub async fn get_by_topic(&self, topic: &str) -> Vec<Arc<Connection>> {
206 let mut result = Vec::new();
207
208 for conn in self.all() {
209 let metadata = conn.metadata().await;
210 if metadata.subscriptions.contains(topic) {
211 result.push(conn);
212 }
213 }
214
215 result
216 }
217
218 pub fn subscribe(&self) -> broadcast::Receiver<ConnectionEvent> {
220 self.event_tx.subscribe()
221 }
222
223 pub fn stats(&self) -> ConnectionManagerStats {
225 let connections = self.all();
226
227 let mut total_messages_sent = 0u64;
228 let mut total_messages_received = 0u64;
229 let mut total_bytes_sent = 0u64;
230 let mut total_bytes_received = 0u64;
231 let mut total_errors = 0u64;
232
233 for conn in &connections {
234 let stats = conn.stats();
235 total_messages_sent += stats.messages_sent;
236 total_messages_received += stats.messages_received;
237 total_bytes_sent += stats.bytes_sent;
238 total_bytes_received += stats.bytes_received;
239 total_errors += stats.errors;
240 }
241
242 ConnectionManagerStats {
243 active_connections: connections.len(),
244 total_connections: self.total_connections(),
245 messages_sent: total_messages_sent,
246 messages_received: total_messages_received,
247 bytes_sent: total_bytes_sent,
248 bytes_received: total_bytes_received,
249 errors: total_errors,
250 }
251 }
252}
253
/// Point-in-time summary produced by `ConnectionManager::stats`.
///
/// The per-connection counters are summed over the connections that were
/// registered when the snapshot was taken.
#[derive(Debug, Clone)]
pub struct ConnectionManagerStats {
    /// Connections registered at snapshot time.
    pub active_connections: usize,
    /// Connections ever accepted since the manager was created.
    pub total_connections: u64,
    /// Sum of messages sent by the active connections.
    pub messages_sent: u64,
    /// Sum of messages received by the active connections.
    pub messages_received: u64,
    /// Sum of bytes sent by the active connections.
    pub bytes_sent: u64,
    /// Sum of bytes received by the active connections.
    pub bytes_received: u64,
    /// Sum of errors recorded by the active connections.
    pub errors: u64,
}
272
273pub use crate::server::connection::ConnectionStats;
275
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::broadcast::error::TryRecvError;

    #[test]
    fn test_connection_manager_new() {
        let manager = ConnectionManager::new(1000);
        assert_eq!(manager.count(), 0);
        assert_eq!(manager.total_connections(), 0);
        assert!(manager.all().is_empty());
    }

    #[test]
    fn test_connection_manager_stats() {
        let manager = ConnectionManager::new(1000);
        let stats = manager.stats();

        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_received, 0);
        assert_eq!(stats.bytes_sent, 0);
        assert_eq!(stats.bytes_received, 0);
        assert_eq!(stats.errors, 0);
    }

    #[tokio::test]
    async fn test_connection_manager_events() {
        let manager = ConnectionManager::new(1000);
        let mut rx = manager.subscribe();

        // The sender is still alive, so the only acceptable error is `Empty`.
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
    }

    #[tokio::test]
    async fn test_close_all_on_empty_manager_emits_nothing() {
        let manager = ConnectionManager::new(1000);
        let mut rx = manager.subscribe();

        assert!(manager.close_all().await.is_ok());
        assert_eq!(manager.count(), 0);
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
    }

    #[tokio::test]
    async fn test_close_idle_on_empty_manager() {
        let manager = ConnectionManager::new(1000);
        assert!(matches!(manager.close_idle(0).await, Ok(0)));
    }
}