#[cfg(feature = "websocket")]
#[cfg(test)]
mod websocket_tests {
    //! Tests for `ConnectionManager`: registration, room membership,
    //! broadcasting, behaviour under connection-lock contention, and
    //! connection-limit enforcement.

    use crate::error::TidewayError;
    use crate::websocket::{Connection, ConnectionManager, Message};
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use tokio::time::{Duration, timeout};

    /// Upper bound for awaiting a message that should already have been sent,
    /// so a missing delivery fails the test instead of hanging it.
    const RECV_TIMEOUT: Duration = Duration::from_millis(250);

    /// Builds a connection with the given id, backed by a fresh channel of
    /// capacity 16.
    ///
    /// Returns the shared connection handle and the receiving end of its
    /// channel. Callers must keep the receiver bound (e.g. `_rx`, not `_`) for
    /// as long as the channel should stay open.
    fn new_conn(
        id: impl Into<String>,
    ) -> (Arc<tokio::sync::RwLock<Connection>>, mpsc::Receiver<Message>) {
        let (tx, rx) = mpsc::channel::<Message>(16);
        let conn = Arc::new(tokio::sync::RwLock::new(Connection::new(id.into(), tx)));
        (conn, rx)
    }

    /// Awaits the next message on `rx`.
    ///
    /// Panics if nothing arrives within [`RECV_TIMEOUT`] or if the channel has
    /// been closed.
    async fn recv_within_timeout(rx: &mut mpsc::Receiver<Message>) -> Message {
        timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("message should be delivered")
            .expect("channel should still be open")
    }

    /// Registering then unregistering a connection updates the count.
    #[tokio::test]
    async fn test_connection_manager_register_unregister() {
        let manager = ConnectionManager::new();
        let (conn, _rx) = new_conn("test-conn-1");

        assert!(manager.register(conn.clone()).await.is_ok());
        assert_eq!(manager.connection_count(), 1);

        manager.unregister("test-conn-1").await;
        assert_eq!(manager.connection_count(), 0);
    }

    /// Connections can join and leave a room, and membership reflects that.
    #[tokio::test]
    async fn test_room_operations() {
        let manager = Arc::new(ConnectionManager::new());
        let (conn1, _rx1) = new_conn("conn-1");
        let (conn2, _rx2) = new_conn("conn-2");

        assert!(manager.register(conn1.clone()).await.is_ok());
        assert!(manager.register(conn2.clone()).await.is_ok());

        manager.add_to_room("conn-1", "room-1");
        manager.add_to_room("conn-2", "room-1");

        let members = manager.room_members("room-1");
        assert_eq!(members.len(), 2);
        assert!(members.contains(&"conn-1".to_string()));
        assert!(members.contains(&"conn-2".to_string()));

        manager.remove_from_room("conn-1", "room-1");

        let members = manager.room_members("room-1");
        assert_eq!(members.len(), 1);
        assert_eq!(members[0], "conn-2");
    }

    /// A room broadcast reaches every member of the room.
    #[tokio::test]
    async fn test_broadcast_to_room() {
        let manager = Arc::new(ConnectionManager::new());
        let (conn1, mut rx1) = new_conn("conn-1");
        let (conn2, mut rx2) = new_conn("conn-2");

        assert!(manager.register(conn1.clone()).await.is_ok());
        assert!(manager.register(conn2.clone()).await.is_ok());

        manager.add_to_room("conn-1", "room-1");
        manager.add_to_room("conn-2", "room-1");

        manager
            .broadcast_text_to_room("room-1", "Hello room!")
            .await
            .unwrap();

        // Bounded waits: a missed delivery should fail, not hang the suite.
        let msg1 = recv_within_timeout(&mut rx1).await;
        let msg2 = recv_within_timeout(&mut rx2).await;
        assert!(matches!(msg1, Message::Text(ref t) if t == "Hello room!"));
        assert!(matches!(msg2, Message::Text(ref t) if t == "Hello room!"));
    }

    /// A broadcast issued while another task holds the connection's write
    /// lock still delivers the message and leaves the connection registered.
    #[tokio::test]
    async fn test_broadcast_does_not_unregister_on_lock_contention() {
        let manager = Arc::new(ConnectionManager::new());
        let (conn, mut rx) = new_conn("conn-1");

        assert!(manager.register(conn.clone()).await.is_ok());
        assert_eq!(manager.connection_count(), 1);

        let (lock_acquired_tx, lock_acquired_rx) = tokio::sync::oneshot::channel();
        let conn_for_lock = Arc::clone(&conn);
        let hold_lock = tokio::spawn(async move {
            let _guard = conn_for_lock.write().await;
            let _ = lock_acquired_tx.send(());
            tokio::time::sleep(Duration::from_millis(50)).await;
        });

        // Broadcast only once the write lock is really held. If the holder
        // task dies before signalling, fail here rather than silently running
        // the rest of the test without any contention.
        lock_acquired_rx
            .await
            .expect("lock-holder task should signal once the write lock is held");

        assert!(manager.broadcast_text("hello").await.is_ok());
        hold_lock.await.unwrap();

        let received = recv_within_timeout(&mut rx).await;
        assert!(matches!(received, Message::Text(ref t) if t == "hello"));
        assert_eq!(manager.connection_count(), 1);
    }

    /// Room membership added while the connection's write lock is held is
    /// still cleaned up on unregister, leaving no empty rooms behind.
    #[tokio::test]
    async fn test_unregister_cleans_stale_room_membership_after_lock_contention() {
        let manager = Arc::new(ConnectionManager::new());
        let (conn, _rx) = new_conn("conn-1");

        assert!(manager.register(conn.clone()).await.is_ok());

        let conn_write_guard = conn.write().await;
        manager.add_to_room("conn-1", "room-1");
        drop(conn_write_guard);

        assert_eq!(manager.room_members("room-1").len(), 1);

        manager.unregister("conn-1").await;
        assert!(manager.room_members("room-1").is_empty());
        assert_eq!(manager.room_count(), 0);
    }

    /// Registration beyond the configured maximum is rejected with
    /// `ServiceUnavailable` and does not change the count.
    #[tokio::test]
    async fn test_connection_limit_enforced() {
        let manager = ConnectionManager::with_max_connections(2);
        let (conn1, _rx1) = new_conn("conn-1");
        let (conn2, _rx2) = new_conn("conn-2");
        let (conn3, _rx3) = new_conn("conn-3");

        assert!(manager.register(conn1).await.is_ok());
        assert!(manager.register(conn2).await.is_ok());
        assert_eq!(manager.connection_count(), 2);

        let result = manager.register(conn3).await;
        assert_eq!(manager.connection_count(), 2);

        let Err(TidewayError::ServiceUnavailable(msg)) = result else {
            panic!("Expected ServiceUnavailable error");
        };
        assert!(msg.contains("connection limit"));
    }

    /// Unregistering frees a slot so a previously rejected connection can
    /// then register.
    #[tokio::test]
    async fn test_connection_limit_after_unregister() {
        let manager = ConnectionManager::with_max_connections(1);
        let (conn1, _rx1) = new_conn("conn-1");
        let (conn2, _rx2) = new_conn("conn-2");

        assert!(manager.register(conn1).await.is_ok());
        assert_eq!(manager.connection_count(), 1);
        assert!(manager.register(conn2.clone()).await.is_err());

        manager.unregister("conn-1").await;
        assert_eq!(manager.connection_count(), 0);

        assert!(manager.register(conn2).await.is_ok());
        assert_eq!(manager.connection_count(), 1);
    }

    /// Concurrent registrations never exceed the limit: exactly `max` succeed
    /// and the rest are rejected.
    #[tokio::test]
    async fn test_concurrent_registration_respects_limit() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let manager = Arc::new(ConnectionManager::with_max_connections(5));
        let success_count = Arc::new(AtomicUsize::new(0));
        let failure_count = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..20)
            .map(|i| {
                let manager = Arc::clone(&manager);
                let success_count = Arc::clone(&success_count);
                let failure_count = Arc::clone(&failure_count);
                tokio::spawn(async move {
                    let (conn, _rx) = new_conn(format!("conn-{}", i));
                    match manager.register(conn).await {
                        Ok(_) => {
                            success_count.fetch_add(1, Ordering::Relaxed);
                        }
                        Err(_) => {
                            failure_count.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();

        // Awaiting every handle orders all task writes before the loads
        // below, so `Relaxed` is sufficient for these counters.
        for handle in handles {
            handle.await.unwrap();
        }

        let successes = success_count.load(Ordering::Relaxed);
        let failures = failure_count.load(Ordering::Relaxed);

        assert_eq!(successes, 5, "Exactly 5 connections should have succeeded");
        assert_eq!(failures, 15, "15 connections should have been rejected");
        assert_eq!(
            manager.connection_count(),
            5,
            "Connection count should be exactly 5"
        );
    }
}