1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use tokio::sync::RwLock;
5use tokio::sync::mpsc;
6
/// Ping/pong keep-alive timings, carried by [`ConnectionConfig`].
///
/// Defaults: 30 s ping interval, 10 s pong timeout.
#[derive(Debug, Clone)]
pub struct PingPongConfig {
    /// Time between consecutive pings.
    ping_interval: Duration,
    /// How long to wait for a pong after a ping.
    pong_timeout: Duration,
}

impl Default for PingPongConfig {
    fn default() -> Self {
        Self::new(Duration::from_secs(30), Duration::from_secs(10))
    }
}

impl PingPongConfig {
    /// Builds a config from an explicit ping interval and pong timeout.
    pub fn new(ping_interval: Duration, pong_timeout: Duration) -> Self {
        Self {
            ping_interval,
            pong_timeout,
        }
    }

    /// Returns the configured ping interval.
    pub fn ping_interval(&self) -> Duration {
        self.ping_interval
    }

    /// Returns the configured pong timeout.
    pub fn pong_timeout(&self) -> Duration {
        self.pong_timeout
    }

    /// Builder: replaces the ping interval, keeping the pong timeout.
    pub fn with_ping_interval(self, interval: Duration) -> Self {
        Self {
            ping_interval: interval,
            ..self
        }
    }

    /// Builder: replaces the pong timeout, keeping the ping interval.
    pub fn with_pong_timeout(self, timeout: Duration) -> Self {
        Self {
            pong_timeout: timeout,
            ..self
        }
    }
}
121
122#[derive(Debug, Clone)]
149pub struct ConnectionConfig {
150 idle_timeout: Duration,
151 handshake_timeout: Duration,
152 cleanup_interval: Duration,
153 max_connections: Option<usize>,
155 ping_config: PingPongConfig,
157}
158
159impl Default for ConnectionConfig {
160 fn default() -> Self {
161 Self {
162 idle_timeout: Duration::from_secs(300), handshake_timeout: Duration::from_secs(10), cleanup_interval: Duration::from_secs(30), max_connections: None, ping_config: PingPongConfig::default(),
167 }
168 }
169}
170
171impl ConnectionConfig {
172 pub fn new() -> Self {
186 Self::default()
187 }
188
189 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
207 self.idle_timeout = timeout;
208 self
209 }
210
211 pub fn with_handshake_timeout(mut self, timeout: Duration) -> Self {
229 self.handshake_timeout = timeout;
230 self
231 }
232
233 pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
251 self.cleanup_interval = interval;
252 self
253 }
254
255 pub fn idle_timeout(&self) -> Duration {
257 self.idle_timeout
258 }
259
260 pub fn handshake_timeout(&self) -> Duration {
262 self.handshake_timeout
263 }
264
265 pub fn cleanup_interval(&self) -> Duration {
267 self.cleanup_interval
268 }
269
270 pub fn with_max_connections(mut self, max: Option<usize>) -> Self {
287 self.max_connections = max;
288 self
289 }
290
291 pub fn max_connections(&self) -> Option<usize> {
293 self.max_connections
294 }
295
296 pub fn with_ping_config(mut self, config: PingPongConfig) -> Self {
319 self.ping_config = config;
320 self
321 }
322
323 pub fn ping_config(&self) -> &PingPongConfig {
325 &self.ping_config
326 }
327
328 pub fn no_timeout() -> Self {
340 Self {
341 idle_timeout: Duration::MAX,
342 handshake_timeout: Duration::MAX,
343 cleanup_interval: Duration::from_secs(30),
344 max_connections: None,
345 ping_config: PingPongConfig::default(),
346 }
347 }
348
349 pub fn strict() -> Self {
367 Self {
368 idle_timeout: Duration::from_secs(30),
369 handshake_timeout: Duration::from_secs(5),
370 cleanup_interval: Duration::from_secs(10),
371 max_connections: None,
372 ping_config: PingPongConfig::new(Duration::from_secs(10), Duration::from_secs(5)),
373 }
374 }
375
376 pub fn permissive() -> Self {
394 Self {
395 idle_timeout: Duration::from_secs(3600),
396 handshake_timeout: Duration::from_secs(30),
397 cleanup_interval: Duration::from_secs(60),
398 max_connections: None,
399 ping_config: PingPongConfig::new(Duration::from_secs(60), Duration::from_secs(30)),
400 }
401 }
402}
403
404#[derive(Debug, thiserror::Error)]
406pub enum WebSocketError {
407 #[error("Connection error")]
409 Connection(String),
410 #[error("Send failed")]
412 Send(String),
413 #[error("Receive failed")]
415 Receive(String),
416 #[error("Protocol error")]
418 Protocol(String),
419 #[error("Internal error")]
421 Internal(String),
422 #[error("Connection timed out")]
424 Timeout(Duration),
425 #[error("Reconnection failed")]
427 ReconnectFailed(u32),
428 #[error("Invalid binary payload: {0}")]
430 BinaryPayload(String),
431 #[error("Heartbeat timeout: no pong received within {0:?}")]
433 HeartbeatTimeout(Duration),
434 #[error("Slow consumer: send timed out after {0:?}")]
436 SlowConsumer(Duration),
437}
438
439impl WebSocketError {
440 pub fn client_message(&self) -> &'static str {
445 match self {
446 Self::Connection(_) => "Connection error",
447 Self::Send(_) => "Failed to send message",
448 Self::Receive(_) => "Failed to receive message",
449 Self::Protocol(_) => "Protocol error",
450 Self::Internal(_) => "Internal server error",
451 Self::Timeout(_) => "Connection timed out",
452 Self::ReconnectFailed(_) => "Reconnection failed",
453 Self::BinaryPayload(_) => "Invalid message format",
454 Self::HeartbeatTimeout(_) => "Connection timed out",
455 Self::SlowConsumer(_) => "Server overloaded",
456 }
457 }
458
459 pub fn internal_detail(&self) -> String {
464 match self {
465 Self::Connection(msg) => format!("Connection error: {}", msg),
466 Self::Send(msg) => format!("Send error: {}", msg),
467 Self::Receive(msg) => format!("Receive error: {}", msg),
468 Self::Protocol(msg) => format!("Protocol error: {}", msg),
469 Self::Internal(msg) => format!("Internal error: {}", msg),
470 Self::Timeout(d) => format!("Connection timeout: idle for {:?}", d),
471 Self::ReconnectFailed(n) => format!("Reconnection failed after {} attempts", n),
472 Self::BinaryPayload(msg) => format!("Invalid binary payload: {}", msg),
473 Self::HeartbeatTimeout(d) => format!("Heartbeat timeout: no pong within {:?}", d),
474 Self::SlowConsumer(d) => format!("Slow consumer: send timed out after {:?}", d),
475 }
476 }
477}
478
479pub type WebSocketResult<T> = Result<T, WebSocketError>;
481
482#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
484#[serde(tag = "type")]
485pub enum Message {
486 Text {
488 data: String,
490 },
491 Binary {
493 data: Vec<u8>,
495 },
496 Ping,
498 Pong,
500 Close {
502 code: u16,
504 reason: String,
506 },
507}
508
509impl Message {
510 pub fn text(data: String) -> Self {
524 Self::Text { data }
525 }
526 pub fn binary(data: Vec<u8>) -> Self {
541 Self::Binary { data }
542 }
543 pub fn json<T: serde::Serialize>(data: &T) -> WebSocketResult<Self> {
572 let json =
573 serde_json::to_string(data).map_err(|e| WebSocketError::Protocol(e.to_string()))?;
574 Ok(Self::text(json))
575 }
576 pub fn parse_json<T: serde::de::DeserializeOwned>(&self) -> WebSocketResult<T> {
596 match self {
597 Message::Text { data } => {
598 serde_json::from_str(data).map_err(|e| WebSocketError::Protocol(e.to_string()))
599 }
600 _ => Err(WebSocketError::Protocol("Not a text message".to_string())),
601 }
602 }
603}
604
605pub struct WebSocketConnection {
607 id: String,
608 tx: mpsc::UnboundedSender<Message>,
609 closed: Arc<RwLock<bool>>,
610 subprotocol: Option<String>,
612 last_activity: Arc<RwLock<Instant>>,
614 config: ConnectionConfig,
616}
617
618impl WebSocketConnection {
619 pub fn new(id: String, tx: mpsc::UnboundedSender<Message>) -> Self {
634 Self {
635 id,
636 tx,
637 closed: Arc::new(RwLock::new(false)),
638 subprotocol: None,
639 last_activity: Arc::new(RwLock::new(Instant::now())),
640 config: ConnectionConfig::default(),
641 }
642 }
643
644 pub fn with_config(
662 id: String,
663 tx: mpsc::UnboundedSender<Message>,
664 config: ConnectionConfig,
665 ) -> Self {
666 Self {
667 id,
668 tx,
669 closed: Arc::new(RwLock::new(false)),
670 subprotocol: None,
671 last_activity: Arc::new(RwLock::new(Instant::now())),
672 config,
673 }
674 }
675
676 pub fn with_subprotocol(
694 id: String,
695 tx: mpsc::UnboundedSender<Message>,
696 subprotocol: Option<String>,
697 ) -> Self {
698 Self {
699 id,
700 tx,
701 closed: Arc::new(RwLock::new(false)),
702 subprotocol,
703 last_activity: Arc::new(RwLock::new(Instant::now())),
704 config: ConnectionConfig::default(),
705 }
706 }
707
708 pub fn subprotocol(&self) -> Option<&str> {
725 self.subprotocol.as_deref()
726 }
727
728 pub fn id(&self) -> &str {
741 &self.id
742 }
743
744 pub fn config(&self) -> &ConnectionConfig {
746 &self.config
747 }
748
749 pub async fn record_activity(&self) {
770 *self.last_activity.write().await = Instant::now();
771 }
772
773 pub async fn idle_duration(&self) -> Duration {
791 self.last_activity.read().await.elapsed()
792 }
793
794 pub async fn is_idle(&self) -> bool {
811 self.idle_duration().await > self.config.idle_timeout
812 }
813
814 pub async fn send(&self, message: Message) -> WebSocketResult<()> {
836 if *self.closed.read().await {
837 return Err(WebSocketError::Send("Connection closed".to_string()));
838 }
839
840 let result = self
841 .tx
842 .send(message)
843 .map_err(|e| WebSocketError::Send(e.to_string()));
844
845 if result.is_ok() {
846 self.record_activity().await;
847 }
848
849 result
850 }
851 pub async fn send_text(&self, text: String) -> WebSocketResult<()> {
873 self.send(Message::text(text)).await
874 }
875 pub async fn send_binary(&self, data: Vec<u8>) -> WebSocketResult<()> {
898 self.send(Message::binary(data)).await
899 }
900 pub async fn send_json<T: serde::Serialize>(&self, data: &T) -> WebSocketResult<()> {
930 let message = Message::json(data)?;
931 self.send(message).await
932 }
933 pub async fn close(&self) -> WebSocketResult<()> {
954 *self.closed.write().await = true;
956
957 self.tx
959 .send(Message::Close {
960 code: 1000,
961 reason: "Normal closure".to_string(),
962 })
963 .map_err(|e| WebSocketError::Send(e.to_string()))
964 }
965 pub async fn close_with_reason(&self, code: u16, reason: String) -> WebSocketResult<()> {
995 *self.closed.write().await = true;
997
998 self.tx
1000 .send(Message::Close { code, reason })
1001 .map_err(|e| WebSocketError::Send(e.to_string()))
1002 }
1003
1004 pub async fn force_close(&self) {
1024 *self.closed.write().await = true;
1025 }
1026
1027 pub async fn is_closed(&self) -> bool {
1043 *self.closed.read().await
1044 }
1045}
1046
1047pub struct ConnectionTimeoutMonitor {
1077 connections: Arc<RwLock<HashMap<String, Arc<WebSocketConnection>>>>,
1078 config: ConnectionConfig,
1079}
1080
1081impl ConnectionTimeoutMonitor {
1082 pub fn new(config: ConnectionConfig) -> Self {
1084 Self {
1085 connections: Arc::new(RwLock::new(HashMap::new())),
1086 config,
1087 }
1088 }
1089
1090 pub async fn register(
1095 &self,
1096 connection: Arc<WebSocketConnection>,
1097 ) -> Result<(), WebSocketError> {
1098 let mut connections = self.connections.write().await;
1099
1100 if let Some(max) = self.config.max_connections
1101 && connections.len() >= max
1102 {
1103 return Err(WebSocketError::Connection(format!(
1104 "maximum connection limit reached ({})",
1105 max
1106 )));
1107 }
1108
1109 connections.insert(connection.id().to_string(), connection);
1110 Ok(())
1111 }
1112
1113 pub async fn unregister(&self, connection_id: &str) {
1115 self.connections.write().await.remove(connection_id);
1116 }
1117
1118 pub async fn connection_count(&self) -> usize {
1120 self.connections.read().await.len()
1121 }
1122
1123 pub async fn check_idle_connections(&self) -> Vec<String> {
1127 let connections = self.connections.read().await;
1128 let mut timed_out = Vec::new();
1129
1130 for (id, conn) in connections.iter() {
1131 if conn.is_closed().await {
1132 timed_out.push(id.clone());
1133 continue;
1134 }
1135
1136 let idle_duration = conn.idle_duration().await;
1137 if idle_duration > self.config.idle_timeout {
1138 let reason = format!(
1139 "Idle timeout: connection idle for {}s (limit: {}s)",
1140 idle_duration.as_secs(),
1141 self.config.idle_timeout.as_secs()
1142 );
1143 let _ = conn.close_with_reason(1001, reason).await;
1145 timed_out.push(id.clone());
1146 }
1147 }
1148
1149 drop(connections);
1150
1151 if !timed_out.is_empty() {
1153 let mut connections = self.connections.write().await;
1154 for id in &timed_out {
1155 connections.remove(id);
1156 }
1157 }
1158
1159 timed_out
1160 }
1161
1162 pub async fn shutdown_all(&self) -> Vec<String> {
1170 let mut connections = self.connections.write().await;
1171 let mut shut_down = Vec::with_capacity(connections.len());
1172
1173 for (id, conn) in connections.drain() {
1174 if !conn.is_closed().await {
1175 let _ = conn
1176 .close_with_reason(1001, "Server shutting down".to_string())
1177 .await;
1178 }
1179 shut_down.push(id);
1180 }
1181
1182 shut_down
1183 }
1184
1185 pub fn start(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
1211 let monitor = Arc::clone(self);
1212 tokio::spawn(async move {
1213 let mut interval = tokio::time::interval(monitor.config.cleanup_interval);
1214 loop {
1215 interval.tick().await;
1216 monitor.check_idle_connections().await;
1217 }
1218 })
1219 }
1220}
1221
/// Timing for [`HeartbeatMonitor`]: how often to ping and how long to wait
/// for a pong before force-closing the connection.
///
/// Defaults: 30 s ping interval, 10 s pong timeout. The API mirrors
/// `PingPongConfig`, including the `with_*` builders.
#[derive(Debug, Clone)]
pub struct HeartbeatConfig {
    /// Period of the ping ticker. Must be non-zero when used by
    /// `HeartbeatMonitor::start` (tokio's `interval` panics on zero).
    ping_interval: Duration,
    /// Maximum time without a pong before the connection is force-closed.
    pong_timeout: Duration,
}

impl HeartbeatConfig {
    /// Builds a config from an explicit ping interval and pong timeout.
    pub fn new(ping_interval: Duration, pong_timeout: Duration) -> Self {
        Self {
            ping_interval,
            pong_timeout,
        }
    }

    /// Returns the ping interval.
    pub fn ping_interval(&self) -> Duration {
        self.ping_interval
    }

    /// Returns the pong timeout.
    pub fn pong_timeout(&self) -> Duration {
        self.pong_timeout
    }

    /// Builder: replaces the ping interval.
    #[must_use]
    pub fn with_ping_interval(mut self, interval: Duration) -> Self {
        self.ping_interval = interval;
        self
    }

    /// Builder: replaces the pong timeout.
    #[must_use]
    pub fn with_pong_timeout(mut self, timeout: Duration) -> Self {
        self.pong_timeout = timeout;
        self
    }
}

impl Default for HeartbeatConfig {
    fn default() -> Self {
        Self {
            ping_interval: Duration::from_secs(30),
            pong_timeout: Duration::from_secs(10),
        }
    }
}
1277
1278pub struct HeartbeatMonitor {
1303 connection: Arc<WebSocketConnection>,
1304 config: HeartbeatConfig,
1305 last_pong: Arc<RwLock<Instant>>,
1306 timed_out: Arc<RwLock<bool>>,
1307 pong_notify: Arc<tokio::sync::Notify>,
1308}
1309
1310impl HeartbeatMonitor {
1311 pub fn new(connection: Arc<WebSocketConnection>, config: HeartbeatConfig) -> Self {
1313 Self {
1314 connection,
1315 config,
1316 last_pong: Arc::new(RwLock::new(Instant::now())),
1317 timed_out: Arc::new(RwLock::new(false)),
1318 pong_notify: Arc::new(tokio::sync::Notify::new()),
1319 }
1320 }
1321
1322 pub async fn record_pong(&self) {
1327 *self.last_pong.write().await = Instant::now();
1328 self.pong_notify.notify_one();
1329 }
1330
1331 pub async fn time_since_last_pong(&self) -> Duration {
1333 self.last_pong.read().await.elapsed()
1334 }
1335
1336 pub async fn is_timed_out(&self) -> bool {
1338 *self.timed_out.read().await
1339 }
1340
1341 pub async fn check_heartbeat(&self) -> bool {
1346 let since_pong = self.time_since_last_pong().await;
1347
1348 if since_pong > self.config.pong_timeout {
1349 self.connection.force_close().await;
1350 *self.timed_out.write().await = true;
1351 return true;
1352 }
1353
1354 false
1355 }
1356
1357 pub async fn send_ping(&self) -> WebSocketResult<()> {
1362 self.connection.send(Message::Ping).await
1363 }
1364
1365 pub fn config(&self) -> &HeartbeatConfig {
1367 &self.config
1368 }
1369
1370 pub fn connection(&self) -> &Arc<WebSocketConnection> {
1372 &self.connection
1373 }
1374
1375 pub fn start(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
1382 let monitor = Arc::clone(self);
1383 tokio::spawn(async move {
1384 let mut interval = tokio::time::interval(monitor.config.ping_interval);
1385 loop {
1386 interval.tick().await;
1387
1388 if monitor.connection.is_closed().await {
1389 break;
1390 }
1391
1392 let _ = monitor.send_ping().await;
1394
1395 tokio::select! {
1399 () = tokio::time::sleep(monitor.config.pong_timeout) => {
1400 if monitor.check_heartbeat().await {
1402 break;
1403 }
1404 }
1405 () = monitor.pong_notify.notified() => {
1406 }
1408 }
1409 }
1410 })
1411 }
1412}
1413
#[cfg(test)]
mod tests {
    //! Unit tests for messages, connection lifecycle, timeout monitoring and
    //! heartbeats. Several tests rely on real wall-clock sleeps (the module
    //! uses `std::time::Instant`), so their margins are intentionally small
    //! but non-zero.

    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_message_text() {
        let text = "Hello".to_string();

        let msg = Message::text(text);

        match msg {
            Message::Text { data } => assert_eq!(data, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[rstest]
    fn test_message_json() {
        #[derive(serde::Serialize)]
        struct TestData {
            value: i32,
        }
        let data = TestData { value: 42 };

        let msg = Message::json(&data).unwrap();

        match msg {
            Message::Text { data } => assert!(data.contains("42")),
            _ => panic!("Expected text message"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_send() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        conn.send_text("Hello".to_string()).await.unwrap();

        let received = rx.recv().await.unwrap();
        match received {
            Message::Text { data } => assert_eq!(data, "Hello"),
            _ => panic!("Expected text message"),
        }
    }

    #[rstest]
    fn test_connection_config_default() {
        let config = ConnectionConfig::new();

        assert_eq!(config.idle_timeout(), Duration::from_secs(300));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(10));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(30));
    }

    #[rstest]
    fn test_connection_config_strict() {
        let config = ConnectionConfig::strict();

        assert_eq!(config.idle_timeout(), Duration::from_secs(30));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(5));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_connection_config_permissive() {
        let config = ConnectionConfig::permissive();

        assert_eq!(config.idle_timeout(), Duration::from_secs(3600));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(30));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(60));
    }

    #[rstest]
    fn test_connection_config_no_timeout() {
        let config = ConnectionConfig::no_timeout();

        assert_eq!(config.idle_timeout(), Duration::MAX);
        assert_eq!(config.handshake_timeout(), Duration::MAX);
    }

    #[rstest]
    fn test_connection_config_builder() {
        let config = ConnectionConfig::new()
            .with_idle_timeout(Duration::from_secs(120))
            .with_handshake_timeout(Duration::from_secs(15))
            .with_cleanup_interval(Duration::from_secs(20));

        assert_eq!(config.idle_timeout(), Duration::from_secs(120));
        assert_eq!(config.handshake_timeout(), Duration::from_secs(15));
        assert_eq!(config.cleanup_interval(), Duration::from_secs(20));
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_with_config() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_secs(60));
        let (tx, _rx) = mpsc::unbounded_channel();

        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        assert_eq!(conn.config().idle_timeout(), Duration::from_secs(60));
        assert!(!conn.is_idle().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_record_activity_resets_idle() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50));
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        // Sleep just past the 50 ms idle timeout.
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert!(conn.is_idle().await);

        conn.record_activity().await;

        assert!(!conn.is_idle().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_connection_becomes_idle_after_timeout() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50));
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        tokio::time::sleep(Duration::from_millis(60)).await;

        assert!(conn.is_idle().await);
        assert!(conn.idle_duration().await >= Duration::from_millis(50));
    }

    #[rstest]
    #[tokio::test]
    async fn test_send_resets_activity() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(100));
        // `_rx` keeps the channel open; it is never read, so it need not be `mut`.
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::with_config("test".to_string(), tx, config);

        tokio::time::sleep(Duration::from_millis(50)).await;
        conn.send_text("ping".to_string()).await.unwrap();

        assert!(conn.idle_duration().await < Duration::from_millis(30));
        assert!(!conn.is_idle().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_close_with_reason() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        conn.close_with_reason(1001, "Idle timeout".to_string())
            .await
            .unwrap();

        assert!(conn.is_closed().await);
        let msg = rx.recv().await.unwrap();
        match msg {
            Message::Close { code, reason } => {
                assert_eq!(code, 1001);
                assert_eq!(reason, "Idle timeout");
            }
            _ => panic!("Expected close message"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_register_and_count() {
        let config = ConnectionConfig::new();
        let monitor = ConnectionTimeoutMonitor::new(config);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));

        monitor.register(conn).await.unwrap();

        assert_eq!(monitor.connection_count().await, 1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_unregister() {
        let config = ConnectionConfig::new();
        let monitor = ConnectionTimeoutMonitor::new(config);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
        monitor.register(conn).await.unwrap();

        monitor.unregister("conn_1").await;

        assert_eq!(monitor.connection_count().await, 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_closes_idle_connections() {
        let config = ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50));
        let monitor = ConnectionTimeoutMonitor::new(config);

        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let conn1 = Arc::new(WebSocketConnection::with_config(
            "idle_conn".to_string(),
            tx1,
            ConnectionConfig::new().with_idle_timeout(Duration::from_millis(50)),
        ));

        let (tx2, _rx2) = mpsc::unbounded_channel();
        let conn2 = Arc::new(WebSocketConnection::with_config(
            "active_conn".to_string(),
            tx2,
            ConnectionConfig::new().with_idle_timeout(Duration::from_secs(300)),
        ));

        monitor.register(conn1).await.unwrap();
        monitor.register(conn2.clone()).await.unwrap();

        // Let both go idle, then refresh only the active one.
        tokio::time::sleep(Duration::from_millis(60)).await;
        conn2.record_activity().await;

        let timed_out = monitor.check_idle_connections().await;

        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], "idle_conn");
        assert_eq!(monitor.connection_count().await, 1);

        let msg = rx1.recv().await.unwrap();
        match msg {
            Message::Close { code, reason } => {
                assert_eq!(code, 1001);
                assert!(reason.contains("Idle timeout"));
            }
            _ => panic!("Expected close message for idle connection"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_removes_already_closed_connections() {
        let config = ConnectionConfig::new();
        let monitor = ConnectionTimeoutMonitor::new(config);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx));
        conn.close().await.unwrap();
        monitor.register(conn).await.unwrap();

        let timed_out = monitor.check_idle_connections().await;

        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], "conn_1");
        assert_eq!(monitor.connection_count().await, 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_background_task() {
        let config = ConnectionConfig::new()
            .with_idle_timeout(Duration::from_millis(30))
            .with_cleanup_interval(Duration::from_millis(20));
        let monitor = Arc::new(ConnectionTimeoutMonitor::new(config));

        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::with_config(
            "bg_conn".to_string(),
            tx,
            ConnectionConfig::new().with_idle_timeout(Duration::from_millis(30)),
        ));
        monitor.register(conn).await.unwrap();

        let handle = monitor.start();

        // Several cleanup ticks elapse after the 30 ms idle timeout.
        tokio::time::sleep(Duration::from_millis(120)).await;

        assert_eq!(monitor.connection_count().await, 0);

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Close { .. }));

        handle.abort();
    }

    #[rstest]
    fn test_ping_pong_config_default() {
        let config = PingPongConfig::default();

        assert_eq!(config.ping_interval(), Duration::from_secs(30));
        assert_eq!(config.pong_timeout(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_ping_pong_config_custom() {
        let config = PingPongConfig::new(Duration::from_secs(15), Duration::from_secs(5));

        assert_eq!(config.ping_interval(), Duration::from_secs(15));
        assert_eq!(config.pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    fn test_ping_pong_config_builder() {
        let config = PingPongConfig::default()
            .with_ping_interval(Duration::from_secs(60))
            .with_pong_timeout(Duration::from_secs(20));

        assert_eq!(config.ping_interval(), Duration::from_secs(60));
        assert_eq!(config.pong_timeout(), Duration::from_secs(20));
    }

    #[rstest]
    fn test_connection_config_has_default_ping_config() {
        let config = ConnectionConfig::new();

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(30)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_connection_config_with_custom_ping_config() {
        let ping_config = PingPongConfig::new(Duration::from_secs(15), Duration::from_secs(5));

        let config = ConnectionConfig::new().with_ping_config(ping_config);

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(15)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    fn test_strict_config_has_aggressive_ping() {
        let config = ConnectionConfig::strict();

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(10)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    fn test_permissive_config_has_relaxed_ping() {
        let config = ConnectionConfig::permissive();

        assert_eq!(
            config.ping_config().ping_interval(),
            Duration::from_secs(60)
        );
        assert_eq!(config.ping_config().pong_timeout(), Duration::from_secs(30));
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_monitor_rejects_when_max_connections_reached() {
        let config = ConnectionConfig::new().with_max_connections(Some(1));
        let monitor = ConnectionTimeoutMonitor::new(config);

        let (tx1, _rx1) = mpsc::unbounded_channel();
        let conn1 = Arc::new(WebSocketConnection::new("conn_1".to_string(), tx1));

        let (tx2, _rx2) = mpsc::unbounded_channel();
        let conn2 = Arc::new(WebSocketConnection::new("conn_2".to_string(), tx2));

        monitor.register(conn1).await.unwrap();
        let result = monitor.register(conn2).await;

        assert!(result.is_err());
        assert_eq!(monitor.connection_count().await, 1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_force_close_marks_connection_closed() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        conn.force_close().await;

        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_close_marks_closed_even_when_channel_dropped() {
        let (tx, rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        // Simulate the writer side having gone away.
        drop(rx);

        let result = conn.close().await;

        assert!(result.is_err());
        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_close_with_reason_marks_closed_even_when_channel_dropped() {
        let (tx, rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);

        drop(rx);

        let result = conn
            .close_with_reason(1006, "Abnormal close".to_string())
            .await;

        assert!(result.is_err());
        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_send_after_force_close_returns_error() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = WebSocketConnection::new("test".to_string(), tx);
        conn.force_close().await;

        let result = conn.send_text("should fail".to_string()).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), WebSocketError::Send(_)));
    }

    #[rstest]
    fn test_heartbeat_config_default() {
        let config = HeartbeatConfig::default();

        assert_eq!(config.ping_interval(), Duration::from_secs(30));
        assert_eq!(config.pong_timeout(), Duration::from_secs(10));
    }

    #[rstest]
    fn test_heartbeat_config_custom() {
        let config = HeartbeatConfig::new(Duration::from_secs(15), Duration::from_secs(5));

        assert_eq!(config.ping_interval(), Duration::from_secs(15));
        assert_eq!(config.pong_timeout(), Duration::from_secs(5));
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_initial_state() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_test".to_string(), tx));
        let config = HeartbeatConfig::default();

        let monitor = HeartbeatMonitor::new(conn, config);

        assert!(!monitor.is_timed_out().await);
        assert!(monitor.time_since_last_pong().await < Duration::from_secs(1));
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_record_pong_resets_timer() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_pong".to_string(), tx));
        let config = HeartbeatConfig::new(Duration::from_millis(50), Duration::from_millis(30));
        let monitor = HeartbeatMonitor::new(conn, config);

        tokio::time::sleep(Duration::from_millis(20)).await;
        monitor.record_pong().await;

        assert!(monitor.time_since_last_pong().await < Duration::from_millis(10));
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_timeout_closes_connection() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_timeout".to_string(), tx));
        let config = HeartbeatConfig::new(Duration::from_millis(50), Duration::from_millis(30));
        let monitor = HeartbeatMonitor::new(conn.clone(), config);

        // Exceed the 30 ms pong timeout without recording a pong.
        tokio::time::sleep(Duration::from_millis(40)).await;
        let timed_out = monitor.check_heartbeat().await;

        assert!(timed_out);
        assert!(monitor.is_timed_out().await);
        assert!(conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_no_timeout_when_pong_received() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_ok".to_string(), tx));
        let config = HeartbeatConfig::new(Duration::from_millis(100), Duration::from_millis(50));
        let monitor = HeartbeatMonitor::new(conn.clone(), config);

        tokio::time::sleep(Duration::from_millis(20)).await;
        monitor.record_pong().await;
        let timed_out = monitor.check_heartbeat().await;

        assert!(!timed_out);
        assert!(!monitor.is_timed_out().await);
        assert!(!conn.is_closed().await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_send_ping() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_ping".to_string(), tx));
        let config = HeartbeatConfig::default();
        let monitor = HeartbeatMonitor::new(conn, config);

        monitor.send_ping().await.unwrap();

        let msg = rx.recv().await.unwrap();
        assert!(matches!(msg, Message::Ping));
    }

    #[rstest]
    #[tokio::test]
    async fn test_heartbeat_monitor_early_pong_skips_full_sleep() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("hb_early".to_string(), tx));
        let config = HeartbeatConfig {
            ping_interval: Duration::from_secs(60),
            pong_timeout: Duration::from_secs(10),
        };
        let monitor = Arc::new(HeartbeatMonitor::new(conn, config));

        // Deliver a pong shortly after the ping.
        let monitor_clone = Arc::clone(&monitor);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            monitor_clone.record_pong().await;
        });

        let _ = monitor.send_ping().await;
        let start = Instant::now();

        tokio::select! {
            () = tokio::time::sleep(monitor.config.pong_timeout) => {
                panic!("Should not reach full timeout");
            }
            () = monitor.pong_notify.notified() => {
            }
        }

        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(2),
            "Expected early wakeup but elapsed {:?}",
            elapsed
        );
    }

    #[rstest]
    fn test_websocket_error_binary_payload_variant() {
        let err = WebSocketError::BinaryPayload("invalid data".to_string());

        assert_eq!(err.to_string(), "Invalid binary payload: invalid data");
    }

    #[rstest]
    fn test_websocket_error_heartbeat_timeout_variant() {
        let err = WebSocketError::HeartbeatTimeout(Duration::from_secs(10));

        assert_eq!(
            err.to_string(),
            "Heartbeat timeout: no pong received within 10s"
        );
    }

    #[rstest]
    fn test_websocket_error_slow_consumer_variant() {
        let err = WebSocketError::SlowConsumer(Duration::from_secs(5));

        assert_eq!(err.to_string(), "Slow consumer: send timed out after 5s");
    }
}