use futures::{SinkExt, StreamExt};
use serde_json::{json, Value};
use std::time::Duration;
use tokio_tungstenite::tungstenite::Message;
#[path = "harness/src/daemon.rs"]
mod daemon;
use daemon::DaemonFixture;
async fn daemon() -> DaemonFixture {
DaemonFixture::start("ws-test").await
}
fn client_with_auth(d: &DaemonFixture) -> reqwest::Client {
d.authed_client(Duration::from_secs(10))
}
async fn ws_connect(
d: &DaemonFixture,
path: &str,
) -> tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
let (ws, _) = tokio_tungstenite::connect_async(&d.ws_url(path))
.await
.expect("WS connect failed");
ws
}
async fn ws_recv_text(
ws: &mut tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
timeout_secs: u64,
) -> Option<String> {
let deadline = Duration::from_secs(timeout_secs);
match tokio::time::timeout(deadline, ws.next()).await {
Ok(Some(Ok(Message::Text(t)))) => Some(t.to_string()),
_ => None,
}
}
async fn ws_send(
ws: &mut tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
msg: &str,
) {
ws.send(Message::Text(msg.into()))
.await
.expect("WS send failed");
}
async fn ws_subscribe_topic(
ws: &mut tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
topic: &str,
) {
ws_send(
ws,
&format!(r#"{{"type":"subscribe","topics":["{topic}"]}}"#),
)
.await;
for _ in 0..5 {
let msg = ws_recv_text(ws, 5)
.await
.expect("should receive subscription confirmation");
let frame: Value = serde_json::from_str(&msg).expect("parse subscription confirmation");
match frame["type"].as_str() {
Some("subscribed") => return,
Some("pong") | Some("connected") => continue,
_ => panic!("unexpected subscription frame: {msg}"),
}
}
panic!("did not receive subscribed frame after subscribe command");
}
#[tokio::test]
#[ignore]
async fn ws_connection_lifecycle() {
let d = daemon().await;
let mut ws = ws_connect(&d, "/ws").await;
let msg = ws_recv_text(&mut ws, 5)
.await
.expect("should receive connected msg");
let frame: Value = serde_json::from_str(&msg).expect("parse JSON");
assert_eq!(
frame["type"], "connected",
"first message should be 'connected'"
);
assert!(frame["session_id"].is_string(), "should have session_id");
ws.close(None).await.expect("close");
}
#[tokio::test]
#[ignore]
async fn ws_ping_pong() {
let d = daemon().await;
let mut ws = ws_connect(&d, "/ws").await;
let _ = ws_recv_text(&mut ws, 5).await;
ws_send(&mut ws, r#"{"type":"ping"}"#).await;
let msg = ws_recv_text(&mut ws, 5).await.expect("should receive pong");
let frame: Value = serde_json::from_str(&msg).expect("parse JSON");
assert_eq!(frame["type"], "pong");
ws.close(None).await.expect("close");
}
#[tokio::test]
#[ignore]
async fn ws_subscribe_publish_receive() {
let d = daemon().await;
let client = client_with_auth(&d);
let mut ws = ws_connect(&d, "/ws").await;
let _ = ws_recv_text(&mut ws, 5).await;
let topic = format!("ws-test-{}", rand::random::<u32>());
ws_subscribe_topic(&mut ws, &topic).await;
let payload =
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, b"hello-ws-test");
let resp = client
.post(d.url("/publish"))
.json(&json!({"topic": topic, "payload": payload}))
.send()
.await
.expect("publish");
assert_eq!(resp.status(), 200);
let recv_msg = ws_recv_text(&mut ws, 10).await;
assert!(
recv_msg.is_some(),
"should receive published message via WS"
);
ws.close(None).await.expect("close");
}
#[tokio::test]
#[ignore]
async fn ws_session_tracking() {
let d = daemon().await;
let client = client_with_auth(&d);
let initial: Value = client
.get(d.url("/ws/sessions"))
.send()
.await
.expect("sessions")
.json()
.await
.expect("parse");
let initial_count = initial
.as_array()
.or_else(|| initial["sessions"].as_array())
.map_or(0, |a| a.len());
let mut ws = ws_connect(&d, "/ws").await;
let _ = ws_recv_text(&mut ws, 5).await;
tokio::time::sleep(Duration::from_millis(500)).await;
let after: Value = client
.get(d.url("/ws/sessions"))
.send()
.await
.expect("sessions")
.json()
.await
.expect("parse");
let after_count = after
.as_array()
.or_else(|| after["sessions"].as_array())
.map_or(0, |a| a.len());
assert!(
after_count > initial_count,
"session count should increase: before={initial_count}, after={after_count}"
);
ws.close(None).await.expect("close");
tokio::time::sleep(Duration::from_millis(1000)).await;
let final_resp: Value = client
.get(d.url("/ws/sessions"))
.send()
.await
.expect("sessions")
.json()
.await
.expect("parse");
let final_count = final_resp
.as_array()
.or_else(|| final_resp["sessions"].as_array())
.map_or(0, |a| a.len());
assert!(
final_count <= initial_count,
"session count should decrease after close: initial={initial_count}, final={final_count}"
);
}
#[tokio::test]
#[ignore]
async fn ws_requires_auth() {
let d = daemon().await;
let url = format!("ws://{}/ws", d.api_addr());
let result = tokio_tungstenite::connect_async(&url).await;
match result {
Err(_) => {} Ok((mut ws, _)) => {
let msg = ws_recv_text(&mut ws, 5).await;
if let Some(text) = msg {
let frame: Value = serde_json::from_str(&text).unwrap_or_default();
if frame["type"] == "error" {
} else {
}
}
ws.close(None).await.ok();
}
}
}
#[tokio::test]
#[ignore]
async fn ws_direct_endpoint() {
let d = daemon().await;
let mut ws = ws_connect(&d, "/ws/direct").await;
let msg = ws_recv_text(&mut ws, 5)
.await
.expect("should receive connected msg");
let frame: Value = serde_json::from_str(&msg).expect("parse JSON");
assert_eq!(
frame["type"], "connected",
"/ws/direct should send connected message"
);
ws.close(None).await.expect("close");
}
#[tokio::test]
#[ignore]
async fn ws_concurrent_subscribers() {
let d = daemon().await;
let client = client_with_auth(&d);
let topic = format!("concurrent-test-{}", rand::random::<u32>());
let n_clients = 5;
let n_messages = 3;
let mut clients = Vec::new();
for _ in 0..n_clients {
let mut ws = ws_connect(&d, "/ws").await;
let _ = ws_recv_text(&mut ws, 5).await; ws_subscribe_topic(&mut ws, &topic).await;
clients.push(ws);
}
tokio::time::sleep(Duration::from_millis(500)).await;
for i in 0..n_messages {
let payload = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
format!("msg-{i}").as_bytes(),
);
client
.post(d.url("/publish"))
.json(&json!({"topic": &topic, "payload": payload}))
.send()
.await
.expect("publish");
}
for (idx, ws) in clients.iter_mut().enumerate() {
let mut received = 0;
for _ in 0..n_messages {
if ws_recv_text(ws, 5).await.is_some() {
received += 1;
}
}
assert!(
received >= 1,
"Client {idx} should receive at least 1 message, got {received}"
);
}
for mut ws in clients {
ws.close(None).await.ok();
}
}
#[tokio::test]
#[ignore]
async fn ws_message_ordering() {
let d = daemon().await;
let client = client_with_auth(&d);
let topic = format!("order-test-{}", rand::random::<u32>());
let mut ws = ws_connect(&d, "/ws").await;
let _ = ws_recv_text(&mut ws, 5).await; ws_subscribe_topic(&mut ws, &topic).await;
tokio::time::sleep(Duration::from_millis(300)).await;
let n = 10;
for i in 0..n {
let payload = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
format!("seq-{i:04}").as_bytes(),
);
client
.post(d.url("/publish"))
.json(&json!({"topic": &topic, "payload": payload}))
.send()
.await
.expect("publish");
}
let mut received = Vec::new();
for _ in 0..n {
if let Some(msg) = ws_recv_text(&mut ws, 10).await {
received.push(msg);
}
}
for window in received.windows(2) {
assert!(
!window[0].is_empty() && !window[1].is_empty(),
"messages should not be empty"
);
}
assert!(
received.len() >= n / 2,
"should receive at least half of {} messages, got {}",
n,
received.len()
);
ws.close(None).await.expect("close");
}