1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::RwLock;
5use tokio::sync::mpsc;
6
/// Timing parameters for the ping/pong keep-alive exchange.
///
/// Holds two durations: the configured interval between pings and the
/// configured maximum wait for a pong. Both are plain values; this type
/// performs no validation.
#[derive(Debug, Clone)]
pub struct PingPongConfig {
    /// Configured interval between pings.
    ping_interval: Duration,
    /// Configured maximum wait for a pong.
    pong_timeout: Duration,
}

impl Default for PingPongConfig {
    /// 30 second ping interval and 10 second pong timeout.
    fn default() -> Self {
        Self::new(Duration::from_secs(30), Duration::from_secs(10))
    }
}

impl PingPongConfig {
    /// Builds a config from an explicit ping interval and pong timeout.
    pub fn new(ping_interval: Duration, pong_timeout: Duration) -> Self {
        Self {
            ping_interval,
            pong_timeout,
        }
    }

    /// Configured interval between pings.
    pub fn ping_interval(&self) -> Duration {
        self.ping_interval
    }

    /// Configured maximum wait for a pong.
    pub fn pong_timeout(&self) -> Duration {
        self.pong_timeout
    }

    /// Returns this config with `ping_interval` replaced by `interval`.
    pub fn with_ping_interval(self, interval: Duration) -> Self {
        Self {
            ping_interval: interval,
            ..self
        }
    }

    /// Returns this config with `pong_timeout` replaced by `timeout`.
    pub fn with_pong_timeout(self, timeout: Duration) -> Self {
        Self {
            pong_timeout: timeout,
            ..self
        }
    }
}
121
122#[derive(Debug, Clone)]
149pub struct ConnectionConfig {
150 idle_timeout: Duration,
151 handshake_timeout: Duration,
152 cleanup_interval: Duration,
153 max_connections: Option<usize>,
155 ping_config: PingPongConfig,
157}
158
159impl Default for ConnectionConfig {
160 fn default() -> Self {
161 Self {
162 idle_timeout: Duration::from_secs(300), handshake_timeout: Duration::from_secs(10), cleanup_interval: Duration::from_secs(30), max_connections: None, ping_config: PingPongConfig::default(),
167 }
168 }
169}
170
171impl ConnectionConfig {
172 pub fn new() -> Self {
186 Self::default()
187 }
188
189 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
207 self.idle_timeout = timeout;
208 self
209 }
210
211 pub fn with_handshake_timeout(mut self, timeout: Duration) -> Self {
229 self.handshake_timeout = timeout;
230 self
231 }
232
233 pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
251 self.cleanup_interval = interval;
252 self
253 }
254
255 pub fn idle_timeout(&self) -> Duration {
257 self.idle_timeout
258 }
259
260 pub fn handshake_timeout(&self) -> Duration {
262 self.handshake_timeout
263 }
264
265 pub fn cleanup_interval(&self) -> Duration {
267 self.cleanup_interval
268 }
269
270 pub fn with_max_connections(mut self, max: Option<usize>) -> Self {
287 self.max_connections = max;
288 self
289 }
290
291 pub fn max_connections(&self) -> Option<usize> {
293 self.max_connections
294 }
295
296 pub fn with_ping_config(mut self, config: PingPongConfig) -> Self {
319 self.ping_config = config;
320 self
321 }
322
323 pub fn ping_config(&self) -> &PingPongConfig {
325 &self.ping_config
326 }
327
328 pub fn no_timeout() -> Self {
340 Self {
341 idle_timeout: Duration::MAX,
342 handshake_timeout: Duration::MAX,
343 cleanup_interval: Duration::from_secs(30),
344 max_connections: None,
345 ping_config: PingPongConfig::default(),
346 }
347 }
348
349 pub fn strict() -> Self {
367 Self {
368 idle_timeout: Duration::from_secs(30),
369 handshake_timeout: Duration::from_secs(5),
370 cleanup_interval: Duration::from_secs(10),
371 max_connections: None,
372 ping_config: PingPongConfig::new(Duration::from_secs(10), Duration::from_secs(5)),
373 }
374 }
375
376 pub fn permissive() -> Self {
394 Self {
395 idle_timeout: Duration::from_secs(3600),
396 handshake_timeout: Duration::from_secs(30),
397 cleanup_interval: Duration::from_secs(60),
398 max_connections: None,
399 ping_config: PingPongConfig::new(Duration::from_secs(60), Duration::from_secs(30)),
400 }
401 }
402}
403
404#[derive(Debug, thiserror::Error)]
405pub enum WebSocketError {
406 #[error("Connection error")]
407 Connection(String),
408 #[error("Send failed")]
409 Send(String),
410 #[error("Receive failed")]
411 Receive(String),
412 #[error("Protocol error")]
413 Protocol(String),
414 #[error("Internal error")]
415 Internal(String),
416 #[error("Connection timed out")]
417 Timeout(Duration),
418 #[error("Reconnection failed")]
419 ReconnectFailed(u32),
420 #[error("Invalid binary payload: {0}")]
421 BinaryPayload(String),
422 #[error("Heartbeat timeout: no pong received within {0:?}")]
423 HeartbeatTimeout(Duration),
424 #[error("Slow consumer: send timed out after {0:?}")]
425 SlowConsumer(Duration),
426}
427
428impl WebSocketError {
429 pub fn client_message(&self) -> &'static str {
434 match self {
435 Self::Connection(_) => "Connection error",
436 Self::Send(_) => "Failed to send message",
437 Self::Receive(_) => "Failed to receive message",
438 Self::Protocol(_) => "Protocol error",
439 Self::Internal(_) => "Internal server error",
440 Self::Timeout(_) => "Connection timed out",
441 Self::ReconnectFailed(_) => "Reconnection failed",
442 Self::BinaryPayload(_) => "Invalid message format",
443 Self::HeartbeatTimeout(_) => "Connection timed out",
444 Self::SlowConsumer(_) => "Server overloaded",
445 }
446 }
447
448 pub fn internal_detail(&self) -> String {
453 match self {
454 Self::Connection(msg) => format!("Connection error: {}", msg),
455 Self::Send(msg) => format!("Send error: {}", msg),
456 Self::Receive(msg) => format!("Receive error: {}", msg),
457 Self::Protocol(msg) => format!("Protocol error: {}", msg),
458 Self::Internal(msg) => format!("Internal error: {}", msg),
459 Self::Timeout(d) => format!("Connection timeout: idle for {:?}", d),
460 Self::ReconnectFailed(n) => format!("Reconnection failed after {} attempts", n),
461 Self::BinaryPayload(msg) => format!("Invalid binary payload: {}", msg),
462 Self::HeartbeatTimeout(d) => format!("Heartbeat timeout: no pong within {:?}", d),
463 Self::SlowConsumer(d) => format!("Slow consumer: send timed out after {:?}", d),
464 }
465 }
466}
467
468pub type WebSocketResult<T> = Result<T, WebSocketError>;
469
470#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
472#[serde(tag = "type")]
473pub enum Message {
474 Text { data: String },
475 Binary { data: Vec<u8> },
476 Ping,
477 Pong,
478 Close { code: u16, reason: String },
479}
480
481impl Message {
482 pub fn text(data: String) -> Self {
496 Self::Text { data }
497 }
498 pub fn binary(data: Vec<u8>) -> Self {
513 Self::Binary { data }
514 }
515 pub fn json<T: serde::Serialize>(data: &T) -> WebSocketResult<Self> {
544 let json =
545 serde_json::to_string(data).map_err(|e| WebSocketError::Protocol(e.to_string()))?;
546 Ok(Self::text(json))
547 }
548 pub fn parse_json<T: serde::de::DeserializeOwned>(&self) -> WebSocketResult<T> {
568 match self {
569 Message::Text { data } => {
570 serde_json::from_str(data).map_err(|e| WebSocketError::Protocol(e.to_string()))
571 }
572 _ => Err(WebSocketError::Protocol("Not a text message".to_string())),
573 }
574 }
575}
576
577pub struct WebSocketConnection {
579 id: String,
580 tx: mpsc::UnboundedSender<Message>,
581 closed: Arc<RwLock<bool>>,
582 subprotocol: Option<String>,
584 last_activity: Arc<RwLock<Instant>>,
586 config: ConnectionConfig,
588}
589
590impl WebSocketConnection {
591 pub fn new(id: String, tx: mpsc::UnboundedSender<Message>) -> Self {
606 Self {
607 id,
608 tx,
609 closed: Arc::new(RwLock::new(false)),
610 subprotocol: None,
611 last_activity: Arc::new(RwLock::new(Instant::now())),
612 config: ConnectionConfig::default(),
613 }
614 }
615
616 pub fn with_config(
634 id: String,
635 tx: mpsc::UnboundedSender<Message>,
636 config: ConnectionConfig,
637 ) -> Self {
638 Self {
639 id,
640 tx,
641 closed: Arc::new(RwLock::new(false)),
642 subprotocol: None,
643 last_activity: Arc::new(RwLock::new(Instant::now())),
644 config,
645 }
646 }
647
648 pub fn with_subprotocol(
666 id: String,
667 tx: mpsc::UnboundedSender<Message>,
668 subprotocol: Option<String>,
669 ) -> Self {
670 Self {
671 id,
672 tx,
673 closed: Arc::new(RwLock::new(false)),
674 subprotocol,
675 last_activity: Arc::new(RwLock::new(Instant::now())),
676 config: ConnectionConfig::default(),
677 }
678 }
679
680 pub fn subprotocol(&self) -> Option<&str> {
697 self.subprotocol.as_deref()
698 }
699
700 pub fn id(&self) -> &str {
713 &self.id
714 }
715
716 pub fn config(&self) -> &ConnectionConfig {
718 &self.config
719 }
720
721 pub async fn record_activity(&self) {
742 *self.last_activity.write().await = Instant::now();
743 }
744
745 pub async fn idle_duration(&self) -> Duration {
763 self.last_activity.read().await.elapsed()
764 }
765
766 pub async fn is_idle(&self) -> bool {
783 self.idle_duration().await > self.config.idle_timeout
784 }
785
786 pub async fn send(&self, message: Message) -> WebSocketResult<()> {
808 if *self.closed.read().await {
809 return Err(WebSocketError::Send("Connection closed".to_string()));
810 }
811
812 let result = self
813 .tx
814 .send(message)
815 .map_err(|e| WebSocketError::Send(e.to_string()));
816
817 if result.is_ok() {
818 self.record_activity().await;
819 }
820
821 result
822 }
823 pub async fn send_text(&self, text: String) -> WebSocketResult<()> {
845 self.send(Message::text(text)).await
846 }
847 pub async fn send_binary(&self, data: Vec<u8>) -> WebSocketResult<()> {
870 self.send(Message::binary(data)).await
871 }
872 pub async fn send_json<T: serde::Serialize>(&self, data: &T) -> WebSocketResult<()> {
902 let message = Message::json(data)?;
903 self.send(message).await
904 }
905 pub async fn close(&self) -> WebSocketResult<()> {
926 *self.closed.write().await = true;
928
929 self.tx
931 .send(Message::Close {
932 code: 1000,
933 reason: "Normal closure".to_string(),
934 })
935 .map_err(|e| WebSocketError::Send(e.to_string()))
936 }
937 pub async fn close_with_reason(&self, code: u16, reason: String) -> WebSocketResult<()> {
967 *self.closed.write().await = true;
969
970 self.tx
972 .send(Message::Close { code, reason })
973 .map_err(|e| WebSocketError::Send(e.to_string()))
974 }
975
976 pub async fn force_close(&self) {
996 *self.closed.write().await = true;
997 }
998
999 pub async fn is_closed(&self) -> bool {
1015 *self.closed.read().await
1016 }
1017}
1018
1019pub struct ConnectionTimeoutMonitor {
1049 connections: Arc<RwLock<HashMap<String, Arc<WebSocketConnection>>>>,
1050 config: ConnectionConfig,
1051}
1052
1053impl ConnectionTimeoutMonitor {
1054 pub fn new(config: ConnectionConfig) -> Self {
1056 Self {
1057 connections: Arc::new(RwLock::new(HashMap::new())),
1058 config,
1059 }
1060 }
1061
1062 pub async fn register(
1067 &self,
1068 connection: Arc<WebSocketConnection>,
1069 ) -> Result<(), WebSocketError> {
1070 let mut connections = self.connections.write().await;
1071
1072 if let Some(max) = self.config.max_connections
1073 && connections.len() >= max
1074 {
1075 return Err(WebSocketError::Connection(format!(
1076 "maximum connection limit reached ({})",
1077 max
1078 )));
1079 }
1080
1081 connections.insert(connection.id().to_string(), connection);
1082 Ok(())
1083 }
1084
1085 pub async fn unregister(&self, connection_id: &str) {
1087 self.connections.write().await.remove(connection_id);
1088 }
1089
1090 pub async fn connection_count(&self) -> usize {
1092 self.connections.read().await.len()
1093 }
1094
1095 pub async fn check_idle_connections(&self) -> Vec<String> {
1099 let connections = self.connections.read().await;
1100 let mut timed_out = Vec::new();
1101
1102 for (id, conn) in connections.iter() {
1103 if conn.is_closed().await {
1104 timed_out.push(id.clone());
1105 continue;
1106 }
1107
1108 let idle_duration = conn.idle_duration().await;
1109 if idle_duration > self.config.idle_timeout {
1110 let reason = format!(
1111 "Idle timeout: connection idle for {}s (limit: {}s)",
1112 idle_duration.as_secs(),
1113 self.config.idle_timeout.as_secs()
1114 );
1115 let _ = conn.close_with_reason(1001, reason).await;
1117 timed_out.push(id.clone());
1118 }
1119 }
1120
1121 drop(connections);
1122
1123 if !timed_out.is_empty() {
1125 let mut connections = self.connections.write().await;
1126 for id in &timed_out {
1127 connections.remove(id);
1128 }
1129 }
1130
1131 timed_out
1132 }
1133
1134 pub async fn shutdown_all(&self) -> Vec<String> {
1142 let mut connections = self.connections.write().await;
1143 let mut shut_down = Vec::with_capacity(connections.len());
1144
1145 for (id, conn) in connections.drain() {
1146 if !conn.is_closed().await {
1147 let _ = conn
1148 .close_with_reason(1001, "Server shutting down".to_string())
1149 .await;
1150 }
1151 shut_down.push(id);
1152 }
1153
1154 shut_down
1155 }
1156
1157 pub fn start(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
1183 let monitor = Arc::clone(self);
1184 tokio::spawn(async move {
1185 let mut interval = tokio::time::interval(monitor.config.cleanup_interval);
1186 loop {
1187 interval.tick().await;
1188 monitor.check_idle_connections().await;
1189 }
1190 })
1191 }
1192}
1193
/// Heartbeat timing: how often to ping and how long to wait for a pong.
///
/// Mirrors `PingPongConfig`, including its `with_*` builder methods, so both
/// configs can be constructed the same way.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Period between pings.
    ping_interval: Duration,
    /// Maximum time since the last pong before the connection is considered dead.
    pong_timeout: Duration,
}

impl HeartbeatConfig {
    /// Builds a config from an explicit ping interval and pong timeout.
    pub fn new(ping_interval: Duration, pong_timeout: Duration) -> Self {
        Self {
            ping_interval,
            pong_timeout,
        }
    }

    /// Period between pings.
    pub fn ping_interval(&self) -> Duration {
        self.ping_interval
    }

    /// Maximum time since the last pong before a timeout is declared.
    pub fn pong_timeout(&self) -> Duration {
        self.pong_timeout
    }

    /// Returns this config with `ping_interval` replaced by `interval`.
    #[must_use]
    pub fn with_ping_interval(mut self, interval: Duration) -> Self {
        self.ping_interval = interval;
        self
    }

    /// Returns this config with `pong_timeout` replaced by `timeout`.
    #[must_use]
    pub fn with_pong_timeout(mut self, timeout: Duration) -> Self {
        self.pong_timeout = timeout;
        self
    }
}

impl Default for HeartbeatConfig {
    /// 30 second ping interval and 10 second pong timeout.
    fn default() -> Self {
        Self {
            ping_interval: Duration::from_secs(30),
            pong_timeout: Duration::from_secs(10),
        }
    }
}
1249
1250pub struct HeartbeatMonitor {
1275 connection: Arc<WebSocketConnection>,
1276 config: HeartbeatConfig,
1277 last_pong: Arc<RwLock<Instant>>,
1278 timed_out: Arc<RwLock<bool>>,
1279}
1280
1281impl HeartbeatMonitor {
1282 pub fn new(connection: Arc<WebSocketConnection>, config: HeartbeatConfig) -> Self {
1284 Self {
1285 connection,
1286 config,
1287 last_pong: Arc::new(RwLock::new(Instant::now())),
1288 timed_out: Arc::new(RwLock::new(false)),
1289 }
1290 }
1291
1292 pub async fn record_pong(&self) {
1294 *self.last_pong.write().await = Instant::now();
1295 }
1296
1297 pub async fn time_since_last_pong(&self) -> Duration {
1299 self.last_pong.read().await.elapsed()
1300 }
1301
1302 pub async fn is_timed_out(&self) -> bool {
1304 *self.timed_out.read().await
1305 }
1306
1307 pub async fn check_heartbeat(&self) -> bool {
1312 let since_pong = self.time_since_last_pong().await;
1313
1314 if since_pong > self.config.pong_timeout {
1315 self.connection.force_close().await;
1316 *self.timed_out.write().await = true;
1317 return true;
1318 }
1319
1320 false
1321 }
1322
1323 pub async fn send_ping(&self) -> WebSocketResult<()> {
1328 self.connection.send(Message::Ping).await
1329 }
1330
1331 pub fn config(&self) -> &HeartbeatConfig {
1333 &self.config
1334 }
1335
1336 pub fn connection(&self) -> &Arc<WebSocketConnection> {
1338 &self.connection
1339 }
1340
1341 pub fn start(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
1348 let monitor = Arc::clone(self);
1349 tokio::spawn(async move {
1350 let mut interval = tokio::time::interval(monitor.config.ping_interval);
1351 loop {
1352 interval.tick().await;
1353
1354 if monitor.connection.is_closed().await {
1355 break;
1356 }
1357
1358 let _ = monitor.send_ping().await;
1360
1361 tokio::time::sleep(monitor.config.pong_timeout).await;
1363
1364 if monitor.check_heartbeat().await {
1365 break;
1366 }
1367 }
1368 })
1369 }
1370}
1371
#[cfg(test)]
mod tests {
    //! Unit tests for the connection, monitor and heartbeat types.
    //!
    //! Several tests are timing-based: they rely on real `tokio::time::sleep` delays
    //! of tens of milliseconds relative to equally short timeouts.
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_message_text() {
        let text = "Hello".to_string();

        let msg = Message::text(text);

        match msg {
            Message::Text { data } => assert_eq!(data, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[rstest]
    fn test_message_json() {
        #[derive(serde::Serialize)]
        struct TestData {
            value: i32,
        }
        let data = TestData { value: 42 };

        let msg = Message::json(&data).unwrap();

        match msg {
            Message::Text { data } => assert!(data.contains("42")),
            _ => panic!("Expected text message"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_send() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        conn.send_text("Hello".to_string()).await.unwrap();

        let received = rx.recv().await.unwrap();
        match received {
            Message::Text { data } => assert_eq!(data, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[rstest]
    fn test_connection_config_default() {
        let config = ConnectionConfig::new();

        assert_eq!(config.idle_timeout(), Duration::from_secs(300));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(10));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(30));
    }

    #[rstest]
    fn test_connection_config_strict() {
        let config = ConnectionConfig::strict();

        assert_eq!(config.idle_timeout(), Duration::from_secs(30));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(5));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_connection_config_permissive() {
        let config = ConnectionConfig::permissive();

        assert_eq!(config.idle_timeout(), Duration::from_secs(3600));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(30));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(60));
    }

    #[rstest]
    fn test_connection_config_no_timeout() {
        let config = ConnectionConfig::no_timeout();

        assert_eq!(config.idle_timeout(), Duration::MAX);
        assert_eq!(config.handshake_timeout(), Duration::MAX);
    }

    #[rstest]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new()
            .with_idle_timeout(Duration::from_secs(120))
            .with_handshake_timeout(Duration::from_secs(15))
            .with_cleanup_interval(Duration::from_secs(20));

        assert_eq!(config.idle_timeout(), Duration::from_secs(120));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(15));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(20));
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_with_config() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_secs(60));
        let (tx, _rx) = mpsc::unbounded_channel();

        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        assert_eq!(conn.config().idle_timeout(), Duration::from_secs(60));
        assert!(!conn.is_idle().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_record_activity_resets_idle() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50));
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        tokio::time::sleep(Duration::from_millis(60)).await;
        assert!(conn.is_idle().await);

        conn.record_activity().await;

        assert!(!conn.is_idle().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_becomes_idle_after_timeout() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50));
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        tokio::time::sleep(Duration::from_millis(60)).await;

        assert!(conn.is_idle().await);
        assert!(conn.idle_duration().await >= Duration::from_millis(50));
    }

    #[rstest]
    #[tokio::test]
    async fn test_send_resets_activity() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(100));
        // `_rx` only keeps the channel open; it is never read, so it needs no `mut`.
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        tokio::time::sleep(Duration::from_millis(50)).await;
        conn.send_text("ping".to_string()).await.unwrap();

        assert!(conn.idle_duration().await < Duration::from_millis(30));
        assert!(!conn.is_idle().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_close_with_reason() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        conn.close_with_reason(1001, "Idle timeout".to_string())
            .await
            .unwrap();

        assert!(conn.is_closed().await);
        let msg = rx.recv().await.unwrap();
        match msg {
            Message::Close { code, reason } => {
                assert_eq!(code, 1001);
                assert_eq!(reason, "Idle timeout");
            }
            _ => panic!("Expected close message"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_register_and_count() {
        let config = ConnectionConfig::new();
        let monitor = ConnectionTimeoutMonitor::new(config);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));

        monitor.register(conn).await.unwrap();

        assert_eq!(monitor.connection_count().await, 1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_unregister() {
        let config = ConnectionConfig::new();
        let monitor = ConnectionTimeoutMonitor::new(config);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
        monitor.register(conn).await.unwrap();

        monitor.unregister("conn_1").await;

        assert_eq!(monitor.connection_count().await, 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_closes_idle_connections() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50));
        let monitor = ConnectionTimeoutMonitor::new(config);

        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let conn1 = Arc::new(WebSocketConnection::with_config(
            "idle_conn".to_string(),
            tx1,
            ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50)),
        ));

        let (tx2, _rx2) = mpsc::unbounded_channel();
        let conn2 = Arc::new(WebSocketConnection::with_config(
            "active_conn".to_string(),
            tx2,
            ConnectionConfig::new().with_idle_timeout(Duration::from_secs(300)),
        ));

        monitor.register(conn1).await.unwrap();
        monitor.register(conn2.clone()).await.unwrap();

        tokio::time::sleep(Duration::from_millis(60)).await;
        conn2.record_activity().await;

        let timed_out = monitor.check_idle_connections().await;

        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], "idle_conn");
        assert_eq!(monitor.connection_count().await, 1);

        let msg = rx1.recv().await.unwrap();
        match msg {
            Message::Close { code, reason } => {
                assert_eq!(code, 1001);
                assert!(reason.contains("Idle timeout"));
            }
            _ => panic!("Expected close message for idle connection"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_removes_already_closed_connections() {
        let config = ConnectionConfig::new();
        let monitor = ConnectionTimeoutMonitor::new(config);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
        conn.close().await.unwrap();
        monitor.register(conn).await.unwrap();

        let timed_out = monitor.check_idle_connections().await;

        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], "conn_1");
        assert_eq!(monitor.connection_count().await, 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_background_task() {
        let config = ConnectionConfig::new()
            .with_idle_timeout(Duration::from_millis(30))
            .with_cleanup_interval(Duration::from_millis(20));
        let monitor = Arc::new(ConnectionTimeoutMonitor::new(config));

        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::with_config(
            "bg_conn".to_string(),
            tx,
            ConnectionConfig::new().with_idle_timeout(Duration::from_millis(30)),
        ));
        monitor.register(conn).await.unwrap();

        let handle = monitor.start();

        tokio::time::sleep(Duration::from_millis(120)).await;

        assert_eq!(monitor.connection_count().await, 0);

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Close { .. }));

        handle.abort();
    }

    #[rstest]
    fn test_ping_pong_config_default() {
        let config = PingPongConfig::default();

        assert_eq!(config.ping_interval(), Duration::from_secs(30));
        assert_eq!(config.pong_timeout(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_ping_pong_config_custom() {
        let config = PingPongConfig::new(Duration::from_secs(15), Duration::from_secs(5));

        assert_eq!(config.ping_interval(), Duration::from_secs(15));
        assert_eq!(config.pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    fn test_ping_pong_config_builder() {
        let config = PingPongConfig::default()
            .with_ping_interval(Duration::from_secs(60))
            .with_pong_timeout(Duration::from_secs(20));

        assert_eq!(config.ping_interval(), Duration::from_secs(60));
        assert_eq!(config.pong_timeout(), Duration::from_secs(20));
    }

    #[rstest]
    fn test_connection_config_has_default_ping_config() {
        let config = ConnectionConfig::new();

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(30)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_connection_config_with_custom_ping_config() {
        let ping_config = PingPongConfig::new(Duration::from_secs(15), Duration::from_secs(5));

        let config = ConnectionConfig::new().with_ping_config(ping_config);

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(15)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    fn test_strict_config_has_aggressive_ping() {
        let config = ConnectionConfig::strict();

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(10)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    fn test_permissive_config_has_relaxed_ping() {
        let config = ConnectionConfig::permissive();

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(60)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(30));
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_rejects_when_max_connections_reached() {
        let config = ConnectionConfig::new().with_max_connections(Some(1));
        let monitor = ConnectionTimeoutMonitor::new(config);

        let (tx1, _rx1) = mpsc::unbounded_channel();
        let conn1 = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx1));

        let (tx2, _rx2) = mpsc::unbounded_channel();
        let conn2 = Arc::new(WebSocketConnection::new("conn_2".to_string(), tx2));

        monitor.register(conn1).await.unwrap();
        let result = monitor.register(conn2).await;

        assert!(result.is_err());
        assert_eq!(monitor.connection_count().await, 1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_force_close_marks_connection_closed() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        conn.force_close().await;

        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_close_marks_closed_even_when_channel_dropped() {
        let (tx, rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        drop(rx);

        let result = conn.close().await;

        assert!(result.is_err());
        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_close_with_reason_marks_closed_even_when_channel_dropped() {
        let (tx, rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        drop(rx);

        let result = conn
            .close_with_reason(1006, "Abnormal close".to_string())
            .await;

        assert!(result.is_err());
        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_send_after_force_close_returns_error() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);
        conn.force_close().await;

        let result = conn.send_text("should fail".to_string()).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), WebSocketError::Send(_)));
    }

    #[rstest]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();

        assert_eq!(config.ping_interval(), Duration::from_secs(30));
        assert_eq!(config.pong_timeout(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_heartbeat_config_custom() {
        let config = HeartbeatConfig::new(Duration::from_secs(15), Duration::from_secs(5));

        assert_eq!(config.ping_interval(), Duration::from_secs(15));
        assert_eq!(config.pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_initial_state() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_test".to_string(), tx));
        let config = HeartbeatConfig::default();

        let monitor = HeartbeatMonitor::new(conn, config);

        assert!(!monitor.is_timed_out().await);
        assert!(monitor.time_since_last_pong().await < Duration::from_secs(1));
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_record_pong_resets_timer() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_pong".to_string(), tx));
        let config = HeartbeatConfig::new(Duration::from_millis(50), Duration::from_millis(30));
        let monitor = HeartbeatMonitor::new(conn, config);

        tokio::time::sleep(Duration::from_millis(20)).await;
        monitor.record_pong().await;

        assert!(monitor.time_since_last_pong().await < Duration::from_millis(10));
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_timeout_closes_connection() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_timeout".to_string(), tx));
        let config = HeartbeatConfig::new(Duration::from_millis(50), Duration::from_millis(30));
        let monitor = HeartbeatMonitor::new(conn.clone(), config);

        tokio::time::sleep(Duration::from_millis(40)).await;
        let timed_out = monitor.check_heartbeat().await;

        assert!(timed_out);
        assert!(monitor.is_timed_out().await);
        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_no_timeout_when_pong_received() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_ok".to_string(), tx));
        let config = HeartbeatConfig::new(Duration::from_millis(100), Duration::from_millis(50));
        let monitor = HeartbeatMonitor::new(conn.clone(), config);

        tokio::time::sleep(Duration::from_millis(20)).await;
        monitor.record_pong().await;
        let timed_out = monitor.check_heartbeat().await;

        assert!(!timed_out);
        assert!(!monitor.is_timed_out().await);
        assert!(!conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_send_ping() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_ping".to_string(), tx));
        let config = HeartbeatConfig::default();
        let monitor = HeartbeatMonitor::new(conn, config);

        monitor.send_ping().await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Ping));
    }

    #[rstest]
    fn test_websocket_error_binary_payload_variant() {
        let err = WebSocketError::BinaryPayload("invalid data".to_string());

        assert_eq!(err.to_string(), "Invalid binary payload: invalid data");
    }

    #[rstest]
    fn test_websocket_error_heartbeat_timeout_variant() {
        let err = WebSocketError::HeartbeatTimeout(Duration::from_secs(10));

        assert_eq!(
            err.to_string(),
            "Heartbeat timeout: no pong received within 10s"
        );
    }

    #[rstest]
    fn test_websocket_error_slow_consumer_variant() {
        let err = WebSocketError::SlowConsumer(Duration::from_secs(5));

        assert_eq!(err.to_string(), "Slow consumer: send timed out after 5s");
    }
}