use std::{collections::HashSet, time::Duration};
use bytes::Bytes;
use rand::Rng;
use tokio::{sync::mpsc, task::JoinSet};
use tokio_stream::StreamExt;
use tracing::info;
use msg_socket::{PubSocket, SubSocket};
use msg_transport::{Address, Transport, quic::Quic, tcp::Tcp};
/// Topic name that every publisher in these tests publishes on and every subscriber subscribes to.
const TOPIC: &str = "test";
/// One publisher / one subscriber round trip, first over TCP and then over QUIC.
///
/// Both runs use port 9879. Presumably this does not conflict because QUIC runs over UDP while
/// the first run uses TCP — confirm if the transports change.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn pubsub_channel() {
    let _ = tracing_subscriber::fmt::try_init();

    // `expect` (rather than `assert!(result.is_ok())`) prints the underlying error on failure.
    pubsub_channel_transport(build_tcp, "127.0.0.1:9879".parse().unwrap())
        .await
        .expect("pubsub_channel over TCP failed");

    pubsub_channel_transport(build_quic, "127.0.0.1:9879".parse().unwrap())
        .await
        .expect("pubsub_channel over QUIC failed");
}
/// Runs a single-publisher / single-subscriber exchange over the transport built by
/// `new_transport`.
///
/// The subscriber connects to `addr` and subscribes to [`TOPIC`] *before* the publisher binds
/// to `addr`. A background task then publishes `"WORLD"` every 500 ms until the subscriber has
/// received one message. That task is stopped before returning.
///
/// # Errors
///
/// Returns any error from `connect_inner`, `subscribe` or `try_bind`.
///
/// # Panics
///
/// Panics if the subscriber stream yields `None`, or if the received message has an unexpected
/// topic or payload.
async fn pubsub_channel_transport<F, T, A>(
    new_transport: F,
    addr: A,
) -> Result<(), Box<dyn std::error::Error>>
where
    F: Fn() -> T,
    T: Transport<A>,
    A: Address,
{
    let mut publisher = PubSocket::new(new_transport());
    let mut subscriber = SubSocket::new(new_transport());

    subscriber.connect_inner(addr.clone()).await?;
    subscriber.subscribe(TOPIC).await?;

    // Random 0..400 ms wait before the publisher binds.
    inject_delay(400).await;
    publisher.try_bind(vec![addr]).await?;

    // Publish repeatedly rather than once. Presumably this is so the subscriber still gets a
    // message if its connection/subscription completes after an earlier publish.
    let publish_task = tokio::spawn(async move {
        loop {
            tokio::time::sleep(Duration::from_millis(500)).await;
            publisher.publish(TOPIC, Bytes::from("WORLD")).await.unwrap();
        }
    });

    let msg = subscriber.next().await.unwrap();

    // The loop above never exits on its own. Abort it and wait for cancellation so the
    // publisher (and the socket it bound) is dropped now, not at runtime shutdown.
    publish_task.abort();
    let _ = publish_task.await;

    info!("Received message: {:?}", msg);
    assert_eq!(TOPIC, msg.topic());
    assert_eq!("WORLD", msg.payload());
    Ok(())
}
/// One publisher fanning out to 10 subscribers, first over TCP and then over QUIC.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn pubsub_fan_out() {
    let _ = tracing_subscriber::fmt::try_init();

    // `expect` (rather than `assert!(result.is_ok())`) prints the underlying error on failure.
    pubsub_fan_out_transport(build_tcp, 10, "127.0.0.1:9880".parse().unwrap())
        .await
        .expect("pubsub_fan_out over TCP failed");

    pubsub_fan_out_transport(build_quic, 10, "127.0.0.1:9880".parse().unwrap())
        .await
        .expect("pubsub_fan_out over QUIC failed");
}
/// Fan-out scenario: one publisher bound to `addr` and `subscribers` subscriber tasks.
///
/// Each subscriber task connects to `addr`, subscribes to [`TOPIC`] and waits for one
/// `"WORLD"` message. Subscriber `i` (0-based) first waits a random `0..100 * (i + 1)` ms,
/// then connects. It then waits a random `0..1000 / (i + 1)` ms and subscribes. The publisher
/// binds after its own random `0..400` ms wait. After that it publishes every 500 ms until all
/// subscriber tasks have finished, and is then stopped.
///
/// # Errors
///
/// Returns an error if the publisher fails to bind.
///
/// # Panics
///
/// Panics if any subscriber task panics (failed connect/subscribe, stream ending, or a message
/// with an unexpected topic or payload).
async fn pubsub_fan_out_transport<
    F: Fn() -> T + Send + Copy + 'static,
    T: Transport<A>,
    A: Address,
>(
    new_transport: F,
    subscribers: usize,
    addr: A,
) -> Result<(), Box<dyn std::error::Error>> {
    let mut publisher = PubSocket::new(new_transport());

    let mut sub_tasks = JoinSet::new();
    for i in 0..subscribers {
        let sub_addr = addr.clone();
        sub_tasks.spawn(async move {
            let mut subscriber = SubSocket::new(new_transport());
            inject_delay((100 * (i + 1)) as u64).await;
            subscriber.connect_inner(sub_addr).await.unwrap();
            inject_delay((1000 / (i + 1)) as u64).await;
            subscriber.subscribe(TOPIC).await.unwrap();

            let msg = subscriber.next().await.unwrap();
            info!("Received message: {:?}", msg);
            assert_eq!(TOPIC, msg.topic());
            assert_eq!("WORLD", msg.payload());
        });
    }

    inject_delay(400).await;
    publisher.try_bind(vec![addr]).await?;

    let publish_task = tokio::spawn(async move {
        loop {
            tokio::time::sleep(Duration::from_millis(500)).await;
            publisher.publish(TOPIC, Bytes::from("WORLD")).await.unwrap();
        }
    });

    // Wait for every subscriber; a panic inside a task surfaces here as a `JoinError`.
    while let Some(joined) = sub_tasks.join_next().await {
        joined.expect("subscriber task failed");
    }

    // The publish loop never exits on its own. Stop it so the publisher and its bound socket
    // are released before the next transport run.
    publish_task.abort();
    let _ = publish_task.await;

    Ok(())
}
/// Twenty publishers fanning in to one subscriber, first over TCP and then over QUIC.
///
/// Every publisher binds the address passed here. It must use port 0 so that each publisher gets
/// its own ephemeral port. With a fixed port only the first bind could succeed, and the others
/// would panic on `try_bind(..).unwrap()`.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore]
async fn pubsub_fan_in() {
    let _ = tracing_subscriber::fmt::try_init();

    // `expect` (rather than `assert!(result.is_ok())`) prints the underlying error on failure.
    pubsub_fan_in_transport(build_tcp, 20, "127.0.0.1:0".parse().unwrap())
        .await
        .expect("pubsub_fan_in over TCP failed");

    pubsub_fan_in_transport(build_quic, 20, "127.0.0.1:0".parse().unwrap())
        .await
        .expect("pubsub_fan_in over QUIC failed");
}
/// Fan-in scenario: `publishers` publishers, each bound to `addr`, and a single subscriber
/// connected to all of them.
///
/// Every publisher task binds, then reports its `local_addr()` over a channel. It then hands off
/// to a loop that publishes `"WORLD"` on [`TOPIC`] every 500 ms. Because each publisher binds
/// `addr` itself, `addr` is expected to use port 0 so the reported addresses are distinct.
/// The subscriber connects and subscribes to every reported address, with a random `0..500` ms
/// wait before each. It keeps reading until it has seen a message whose `source()` matches each
/// address. All publish loops are then stopped.
///
/// # Panics
///
/// Panics if a publisher fails to bind or report its address, if the subscriber fails to
/// connect/subscribe or its stream ends, or if a message has an unexpected topic or payload.
async fn pubsub_fan_in_transport<
    F: Fn() -> T + Send + Copy + 'static,
    T: Transport<A>,
    A: Address,
>(
    new_transport: F,
    publishers: usize,
    addr: A,
) -> Result<(), Box<dyn std::error::Error>> {
    let mut pub_tasks = JoinSet::new();
    // `mpsc::channel` panics on a capacity of 0, so keep at least one slot.
    let (tx, mut rx) = mpsc::channel(publishers.max(1));

    for i in 0..publishers {
        let tx = tx.clone();
        let addr = addr.clone();
        pub_tasks.spawn(async move {
            let mut publisher = PubSocket::new(new_transport());
            inject_delay((100 * (i + 1)) as u64).await;
            publisher.try_bind(vec![addr]).await.unwrap();
            let local_addr = publisher.local_addr().unwrap().clone();
            tx.send(local_addr).await.unwrap();

            // Return the publish loop's handle so the caller can stop it later.
            tokio::spawn(async move {
                loop {
                    tokio::time::sleep(Duration::from_millis(500)).await;
                    publisher.publish(TOPIC, Bytes::from("WORLD")).await.unwrap();
                }
            })
        });
    }
    // Drop our own sender so `rx.recv()` returns `None` once every publisher task is done.
    drop(tx);

    let mut subscriber = SubSocket::new(new_transport());

    let mut pending = HashSet::with_capacity(publishers);
    while let Some(publisher_addr) = rx.recv().await {
        pending.insert(publisher_addr);
    }

    for publisher_addr in &pending {
        inject_delay(500).await;
        subscriber.connect_inner(publisher_addr.clone()).await.unwrap();
        subscriber.subscribe(TOPIC).await.unwrap();
    }

    // Read until at least one message has arrived from every publisher address.
    while !pending.is_empty() {
        let msg = subscriber.next().await.unwrap();
        info!("Received message: {:?}", msg);
        assert_eq!(TOPIC, msg.topic());
        assert_eq!("WORLD", msg.payload());
        pending.remove(msg.source());
    }

    // Each publisher task yields its publish loop's handle. The loops never exit on their own,
    // so abort them to release the publishers' sockets before returning.
    while let Some(joined) = pub_tasks.join_next().await {
        let publish_loop = joined.expect("publisher task failed");
        publish_loop.abort();
        let _ = publish_loop.await;
    }

    Ok(())
}
fn build_tcp() -> Tcp {
Tcp::default()
}
fn build_quic() -> Quic {
Quic::default()
}
/// Returns a random duration of `0..upper_ms` milliseconds (upper bound exclusive).
///
/// When `upper_ms` is `0` the range `0..0` is empty and `random_range` would panic on it, so
/// [`Duration::ZERO`] is returned instead. Callers such as `1000 / (i + 1)` can reach 0 for
/// large `i`.
fn random_delay(upper_ms: u64) -> Duration {
    if upper_ms == 0 {
        return Duration::ZERO;
    }
    Duration::from_millis(rand::rng().random_range(0..upper_ms))
}
/// Asynchronously sleeps for a jittered duration obtained from `random_delay(upper_ms)`.
async fn inject_delay(upper_ms: u64) {
    let jitter = random_delay(upper_ms);
    tokio::time::sleep(jitter).await
}