use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::Barrier;
use ultimo::websocket::{ChannelManager, Message};
#[cfg(test)]
mod concurrency_tests {
    //! Concurrency and stress tests for [`ChannelManager`]: many spawned tasks
    //! subscribing, unsubscribing, disconnecting and publishing at once.

    use super::*;
    use tokio::task::JoinHandle;

    /// Awaits every spawned task and returns their outputs in spawn order.
    ///
    /// # Panics
    /// Panics if any task panicked or was cancelled, so a failure inside a
    /// spawned task fails the test instead of being silently discarded.
    async fn join_all_ok<T>(handles: Vec<JoinHandle<T>>) -> Vec<T> {
        futures_util::future::join_all(handles)
            .await
            .into_iter()
            .map(|joined| joined.expect("spawned task panicked or was cancelled"))
            .collect()
    }

    /// Pulls every message currently buffered in `rx` without waiting and
    /// returns how many were received.
    fn drain_count<T>(rx: &mut mpsc::Receiver<T>) -> usize {
        let mut count = 0;
        while rx.try_recv().is_ok() {
            count += 1;
        }
        count
    }

    /// 100 tasks each subscribe a fresh connection to one of 10 topics.
    /// Afterwards every topic must hold exactly 10 subscribers, and
    /// disconnecting every connection must empty all topics again.
    #[tokio::test]
    async fn test_concurrent_subscriptions() {
        let manager = Arc::new(ChannelManager::new());
        let mut handles = Vec::with_capacity(100);
        for i in 0..100 {
            let manager = Arc::clone(&manager);
            handles.push(tokio::spawn(async move {
                let conn_id = uuid::Uuid::new_v4();
                let topic = format!("topic_{}", i % 10);
                // The receiver is dropped right away; only registration is under test.
                let (tx, _rx) = mpsc::channel(1000);
                manager.subscribe(conn_id, &topic, tx).await.unwrap();
                let count = manager.subscriber_count(&topic).await;
                assert!(count > 0);
                conn_id
            }));
        }
        let conn_ids = join_all_ok(handles).await;

        // 100 subscriptions spread round-robin over 10 topics => 10 per topic.
        for t in 0..10 {
            let topic = format!("topic_{}", t);
            assert_eq!(manager.subscriber_count(&topic).await, 10);
        }

        // Clean up: disconnecting every connection must drop all its subscriptions.
        for conn_id in conn_ids {
            manager.disconnect(conn_id).await;
        }
        for t in 0..10 {
            let topic = format!("topic_{}", t);
            assert_eq!(manager.subscriber_count(&topic).await, 0);
        }
    }

    /// 20 concurrent publishes to a topic with 50 subscribers: every publish
    /// must report all 50 subscribers and every subscriber must receive all
    /// 20 messages.
    #[tokio::test]
    async fn test_concurrent_publish() {
        let manager = Arc::new(ChannelManager::new());
        let topic = "broadcast_topic";
        let num_subscribers = 50;
        let mut receivers = Vec::with_capacity(num_subscribers);
        for _ in 0..num_subscribers {
            let (tx, rx) = mpsc::channel(1000);
            let conn_id = uuid::Uuid::new_v4();
            manager.subscribe(conn_id, topic, tx).await.unwrap();
            receivers.push(rx);
        }

        let num_messages = 20;
        let mut publish_handles = Vec::with_capacity(num_messages);
        for i in 0..num_messages {
            let manager = Arc::clone(&manager);
            let topic = topic.to_string();
            publish_handles.push(tokio::spawn(async move {
                let msg = Message::Text(format!("message_{}", i));
                manager.publish(&topic, msg).await
            }));
        }

        for result in join_all_ok(publish_handles).await {
            assert_eq!(result.unwrap(), num_subscribers);
        }
        for rx in receivers.iter_mut() {
            assert_eq!(drain_count(rx), num_messages);
        }
    }

    /// 100 tasks repeatedly subscribe and unsubscribe the *same* connection id.
    /// Every task ends with an unsubscribe, so the topic must end up empty.
    #[tokio::test]
    async fn test_subscribe_unsubscribe_race() {
        let manager = Arc::new(ChannelManager::new());
        let topic = "race_topic";
        let conn_id = uuid::Uuid::new_v4();
        let mut handles = Vec::with_capacity(100);
        for _ in 0..100 {
            let manager = Arc::clone(&manager);
            let topic = topic.to_string();
            handles.push(tokio::spawn(async move {
                let (tx, _rx) = mpsc::channel(1000);
                // Errors are tolerated: duplicate subscribe / missing unsubscribe
                // are expected outcomes of the race.
                manager.subscribe(conn_id, &topic, tx).await.ok();
                // Yield so other tasks interleave between subscribe and unsubscribe.
                tokio::time::sleep(Duration::from_micros(10)).await;
                manager.unsubscribe(conn_id, &topic).await.ok();
            }));
        }
        join_all_ok(handles).await;

        let count = manager.subscriber_count(topic).await;
        assert_eq!(count, 0);
    }

    /// Subscribes 100 connections, then disconnects them all concurrently;
    /// the topic must go from 100 subscribers to 0.
    #[tokio::test]
    async fn test_concurrent_disconnect() {
        let manager = Arc::new(ChannelManager::new());
        let topic = "disconnect_topic";
        let mut conn_ids = Vec::with_capacity(100);
        for _ in 0..100 {
            let conn_id = uuid::Uuid::new_v4();
            let (tx, _rx) = mpsc::channel(1000);
            manager.subscribe(conn_id, topic, tx).await.unwrap();
            conn_ids.push(conn_id);
        }

        let count_before = manager.subscriber_count(topic).await;
        assert_eq!(count_before, 100);

        let handles: Vec<_> = conn_ids
            .into_iter()
            .map(|conn_id| {
                let manager = Arc::clone(&manager);
                tokio::spawn(async move {
                    manager.disconnect(conn_id).await;
                })
            })
            .collect();
        join_all_ok(handles).await;

        let count_after = manager.subscriber_count(topic).await;
        assert_eq!(count_after, 0);
    }

    /// A single connection subscribed to 1000 distinct topics must register
    /// 1000 topics.
    #[tokio::test]
    async fn test_many_topics_subscription() {
        let manager = Arc::new(ChannelManager::new());
        let conn_id = uuid::Uuid::new_v4();
        let (tx, _rx) = mpsc::channel(1000);
        let num_topics = 1000;
        for i in 0..num_topics {
            let topic = format!("topic_{}", i);
            manager
                .subscribe(conn_id, &topic, tx.clone())
                .await
                .unwrap();
        }
        let topic_count = manager.topic_count().await;
        assert_eq!(topic_count, num_topics);
    }

    /// 10 publishers released simultaneously by a barrier publish to a topic
    /// with one subscriber. Each publish must succeed and reach that subscriber,
    /// and the subscriber must receive all 10 messages.
    #[tokio::test]
    async fn test_barrier_synchronized_publish() {
        let manager = Arc::new(ChannelManager::new());
        let topic = "sync_topic";
        let num_publishers = 10;
        let barrier = Arc::new(Barrier::new(num_publishers));
        let (tx, mut rx) = mpsc::channel(1000);
        let conn_id = uuid::Uuid::new_v4();
        manager.subscribe(conn_id, topic, tx).await.unwrap();

        let mut handles = Vec::with_capacity(num_publishers);
        for i in 0..num_publishers {
            let manager = Arc::clone(&manager);
            let barrier = Arc::clone(&barrier);
            let topic = topic.to_string();
            handles.push(tokio::spawn(async move {
                // All publishers start only once every one of them has arrived.
                barrier.wait().await;
                let msg = Message::Text(format!("msg_{}", i));
                manager.publish(&topic, msg).await
            }));
        }

        for result in join_all_ok(handles).await {
            assert_eq!(result.unwrap(), 1);
        }
        assert_eq!(drain_count(&mut rx), num_publishers);
    }

    /// Publishing to a topic nobody subscribed to succeeds and reaches 0 subscribers.
    #[tokio::test]
    async fn test_publish_to_nonexistent_topic() {
        let manager = Arc::new(ChannelManager::new());
        let msg = Message::Text("hello".to_string());
        let result = manager.publish("nonexistent", msg).await;
        assert_eq!(result.unwrap(), 0);
    }

    /// Publishing to a subscriber whose receiver has been dropped must not
    /// return an error.
    // NOTE(review): this only checks that `publish` returns `Ok`; it does not
    // verify whether the closed subscriber is pruned from the topic.
    #[tokio::test]
    async fn test_connection_cleanup_on_channel_close() {
        let manager = Arc::new(ChannelManager::new());
        let topic = "cleanup_topic";
        let conn_id = uuid::Uuid::new_v4();
        let (tx, rx) = mpsc::channel(1000);
        manager.subscribe(conn_id, topic, tx).await.unwrap();
        drop(rx);
        let msg = Message::Text("test".to_string());
        let result = manager.publish(topic, msg).await;
        assert!(result.is_ok());
    }

    /// 10 000 sequential publishes to one subscriber must all be delivered.
    /// The channel capacity (15 000) exceeds the message count, so no
    /// message is lost to a full buffer.
    #[tokio::test]
    async fn test_high_frequency_publish() {
        let manager = Arc::new(ChannelManager::new());
        let topic = "high_freq";
        let (tx, mut rx) = mpsc::channel(15000);
        let conn_id = uuid::Uuid::new_v4();
        manager.subscribe(conn_id, topic, tx).await.unwrap();

        let num_messages = 10000;
        for i in 0..num_messages {
            let msg = Message::Text(format!("msg_{}", i));
            manager.publish(topic, msg).await.unwrap();
        }
        assert_eq!(drain_count(&mut rx), num_messages);
    }

    /// One connection subscribed to three topics through clones of one sender
    /// receives one message per topic published.
    #[tokio::test]
    async fn test_multiple_topics_single_connection() {
        let manager = Arc::new(ChannelManager::new());
        let (tx, mut rx) = mpsc::channel(1000);
        let conn_id = uuid::Uuid::new_v4();
        let topics = vec!["topic_a", "topic_b", "topic_c"];
        for topic in &topics {
            manager.subscribe(conn_id, topic, tx.clone()).await.unwrap();
        }
        for topic in &topics {
            let msg = Message::Text(format!("msg_for_{}", topic));
            // Each topic has exactly this one subscriber.
            assert_eq!(manager.publish(topic, msg).await.unwrap(), 1);
        }
        assert_eq!(drain_count(&mut rx), topics.len());
    }

    /// 200 tasks on a 4-worker runtime each run subscribe, 5 publishes,
    /// unsubscribe and disconnect against one of 20 topics. No task may panic,
    /// and no more than 20 topics may remain registered.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_stress_many_concurrent_operations() {
        let manager = Arc::new(ChannelManager::new());
        let mut handles = Vec::with_capacity(200);
        for i in 0..200 {
            let manager = Arc::clone(&manager);
            handles.push(tokio::spawn(async move {
                let conn_id = uuid::Uuid::new_v4();
                let (tx, _rx) = mpsc::channel(1000);
                let topic = format!("topic_{}", i % 20);
                manager.subscribe(conn_id, &topic, tx).await.ok();
                for j in 0..5 {
                    let msg = Message::Text(format!("msg_{}_{}", i, j));
                    manager.publish(&topic, msg).await.ok();
                }
                manager.unsubscribe(conn_id, &topic).await.ok();
                manager.disconnect(conn_id).await;
            }));
        }
        join_all_ok(handles).await;

        let topic_count = manager.topic_count().await;
        assert!(topic_count <= 20);
    }
}