/// Backpressure tests for the WebSocket write path and `ChannelManager` fan-out.
///
/// Each test sizes the outgoing `mpsc` channel (and, where applicable,
/// `max_write_queue_size`) so the queue is full after a known number of sends.
/// It then checks that further sends fail fast with `ErrorKind::WouldBlock`
/// rather than waiting, and that sending works again once a message is drained.
#[cfg(feature = "websocket")]
mod websocket_backpressure_tests {
    use std::io::ErrorKind;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::time;
    use ultimo::websocket::test_helpers::*;

    /// Upper bound on how long a test waits for a message that should already
    /// be queued; exceeding it fails the test instead of hanging the suite.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    /// Receives the next queued message from `rx`.
    ///
    /// # Panics
    ///
    /// Panics if no message arrives within [`RECV_TIMEOUT`] or if the channel
    /// is closed, so a missing delivery shows up as a test failure rather than
    /// an indefinitely pending `recv().await`.
    async fn recv_or_timeout<T>(rx: &mut mpsc::Receiver<T>) -> T {
        time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("timed out waiting for a queued message")
            .expect("channel closed before a message arrived")
    }

    /// Once the write queue is full, `send` fails immediately with `WouldBlock`
    /// (it does not wait); draining one message frees a slot for the next send.
    #[tokio::test]
    async fn test_send_blocks_when_buffer_full() {
        let (tx, mut rx) = mpsc::channel(2);
        let channel_manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let config = Arc::new(ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 2,
            ..Default::default()
        });
        let ws = create_websocket((), tx, channel_manager, conn_id, None, config);

        assert!(ws.send("message1").await.is_ok());
        assert!(ws.send("message2").await.is_ok());

        let err = ws
            .send("message3")
            .await
            .expect_err("send into a full queue should fail");
        assert_eq!(err.kind(), ErrorKind::WouldBlock);

        let _ = recv_or_timeout(&mut rx).await;
        assert!(ws.send("message4").await.is_ok());
    }

    /// Capacity accessors report the configured maximum and still show free
    /// space after filling half of a 10-slot queue.
    #[tokio::test]
    async fn test_buffer_capacity_tracking() {
        let (tx, _rx) = mpsc::channel(10);
        let channel_manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let config = Arc::new(ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 10,
            ..Default::default()
        });
        let ws = create_websocket((), tx, channel_manager, conn_id, None, config);

        assert_eq!(ws.max_capacity(), 10);
        assert!(ws.has_capacity());

        for i in 0..5 {
            ws.send(format!("message{}", i)).await.unwrap();
        }

        assert!(ws.has_capacity());
        assert!(ws.capacity() > 0);
    }

    /// `send_binary` obeys the same full-queue rule as text sends.
    #[tokio::test]
    async fn test_binary_send_respects_backpressure() {
        let (tx, mut rx) = mpsc::channel(1);
        let channel_manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let config = Arc::new(ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 1,
            ..Default::default()
        });
        let ws = create_websocket((), tx, channel_manager, conn_id, None, config);

        assert!(ws.send_binary(vec![1, 2, 3, 4]).await.is_ok());

        let err = ws
            .send_binary(vec![5, 6, 7, 8])
            .await
            .expect_err("binary send into a full queue should fail");
        assert_eq!(err.kind(), ErrorKind::WouldBlock);

        let _ = recv_or_timeout(&mut rx).await;
        assert!(ws.send_binary(vec![9, 10]).await.is_ok());
    }

    /// `publish` delivers to subscribers with free space and counts only them,
    /// skipping a subscriber whose queue is already full.
    #[tokio::test]
    async fn test_publish_skips_backpressured_connections() {
        let manager = ChannelManager::new();

        let (tx1, mut rx1) = mpsc::channel(10);
        let conn1 = uuid::Uuid::new_v4();
        manager.subscribe(conn1, "topic", tx1).await.unwrap();

        // Capacity-1 channel pre-filled so it is full at publish time; `_rx2`
        // stays alive so the channel is full rather than closed.
        let (tx2, _rx2) = mpsc::channel(1);
        let conn2 = uuid::Uuid::new_v4();
        tx2.try_send(Message::Text("blocking".to_string())).unwrap();
        manager.subscribe(conn2, "topic", tx2).await.unwrap();

        let sent = manager
            .publish("topic", Message::Text("test".to_string()))
            .await
            .unwrap();
        assert_eq!(sent, 1);

        match recv_or_timeout(&mut rx1).await {
            Message::Text(t) => assert_eq!(t, "test"),
            _ => panic!("Expected text message"),
        }
    }

    /// `broadcast_all` reaches every connection across topics except those
    /// whose queue is full, and returns the number actually reached.
    #[tokio::test]
    async fn test_broadcast_all_respects_backpressure() {
        let manager = ChannelManager::new();

        let (tx1, mut rx1) = mpsc::channel(10);
        let conn1 = uuid::Uuid::new_v4();
        manager.subscribe(conn1, "topic1", tx1).await.unwrap();

        // Pre-filled capacity-1 channel: full (not closed) at broadcast time.
        let (tx2, _rx2) = mpsc::channel(1);
        let conn2 = uuid::Uuid::new_v4();
        tx2.try_send(Message::Text("blocking".to_string())).unwrap();
        manager.subscribe(conn2, "topic2", tx2).await.unwrap();

        let count = manager
            .broadcast_all(Message::Text("broadcast".to_string()))
            .await;
        assert_eq!(count, 1);

        // Bounded wait: the manager holds a live sender for this channel, so an
        // undelivered broadcast would otherwise leave `recv()` pending forever.
        match recv_or_timeout(&mut rx1).await {
            Message::Text(t) => assert_eq!(t, "broadcast"),
            _ => panic!("Expected text message"),
        }
    }

    /// A close frame is also subject to backpressure: it fails with
    /// `WouldBlock` on a full queue and succeeds once space is freed.
    #[tokio::test]
    async fn test_close_respects_backpressure() {
        let (tx, mut rx) = mpsc::channel(1);
        let channel_manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let config = Arc::new(ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 1,
            ..Default::default()
        });
        let ws = create_websocket((), tx, channel_manager, conn_id, None, config);

        assert!(ws.send("message").await.is_ok());

        let err = ws
            .close(Some(1000), Some("Normal closure"))
            .await
            .expect_err("close into a full queue should fail");
        assert_eq!(err.kind(), ErrorKind::WouldBlock);

        let _ = recv_or_timeout(&mut rx).await;
        assert!(ws.close(Some(1000), Some("Normal closure")).await.is_ok());
    }

    /// `is_writable` is true while the receiving side of the outgoing queue is
    /// alive and false once that receiver has been dropped.
    #[tokio::test]
    async fn test_is_writable_reflects_connection_state() {
        let (tx, rx) = mpsc::channel(10);
        let channel_manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let config = Arc::new(ultimo::websocket::WebSocketConfig::default());
        let ws = create_websocket((), tx, channel_manager, conn_id, None, config);

        assert!(ws.is_writable());

        // Dropping the receiver closes the channel. `ws` owns the only sender,
        // so the test's own handles play no further part.
        drop(rx);
        assert!(!ws.is_writable());
    }

    /// Repeated sends into a full queue keep failing with `WouldBlock` without
    /// corrupting state; draining the queue makes sends succeed again.
    #[tokio::test]
    async fn test_multiple_send_failures_with_full_buffer() {
        let (tx, mut rx) = mpsc::channel(2);
        let channel_manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let config = Arc::new(ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 2,
            ..Default::default()
        });
        let ws = create_websocket((), tx, channel_manager, conn_id, None, config);

        ws.send("msg1").await.unwrap();
        ws.send("msg2").await.unwrap();

        for i in 0..5 {
            let err = ws
                .send(format!("failed{}", i))
                .await
                .expect_err("send into a full queue should fail");
            assert_eq!(err.kind(), ErrorKind::WouldBlock);
        }

        let _ = recv_or_timeout(&mut rx).await;
        let _ = recv_or_timeout(&mut rx).await;
        assert!(ws.send("success").await.is_ok());
    }

    /// `max_write_queue_size` can be overridden via struct update syntax and
    /// defaults to 1024. Purely synchronous, so no async runtime is needed.
    #[test]
    fn test_config_max_write_queue_size() {
        let config1 = ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 1,
            ..Default::default()
        };
        assert_eq!(config1.max_write_queue_size, 1);

        let config2 = ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 10000,
            ..Default::default()
        };
        assert_eq!(config2.max_write_queue_size, 10000);

        let config_default = ultimo::websocket::WebSocketConfig::default();
        assert_eq!(config_default.max_write_queue_size, 1024);
    }

    /// `send_json` serializes and then obeys the same full-queue rule as
    /// other sends.
    #[tokio::test]
    async fn test_send_json_respects_backpressure() {
        use serde::Serialize;

        #[derive(Serialize)]
        struct TestData {
            message: String,
            count: u32,
        }

        let (tx, mut rx) = mpsc::channel(1);
        let channel_manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let config = Arc::new(ultimo::websocket::WebSocketConfig {
            max_write_queue_size: 1,
            ..Default::default()
        });
        let ws = create_websocket((), tx, channel_manager, conn_id, None, config);

        let data1 = TestData {
            message: "first".to_string(),
            count: 1,
        };
        assert!(ws.send_json(&data1).await.is_ok());

        let data2 = TestData {
            message: "second".to_string(),
            count: 2,
        };
        let err = ws
            .send_json(&data2)
            .await
            .expect_err("JSON send into a full queue should fail");
        assert_eq!(err.kind(), ErrorKind::WouldBlock);

        let _ = recv_or_timeout(&mut rx).await;

        let data3 = TestData {
            message: "third".to_string(),
            count: 3,
        };
        assert!(ws.send_json(&data3).await.is_ok());
    }
}