mod common;
use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use tokio_tungstenite::tungstenite::Message;
/// Connecting via `common::ws_connect` yields a `ready` handshake describing the server.
///
/// Checks the protocol version, the advertised 48 kHz sample rate, that the
/// model name mentions "zipformer", and that `supported_rates` lists at least
/// five rates including 8000 and 48000.
#[tokio::test]
#[ignore]
async fn test_ws_connect_receives_ready() {
    let model_dir = common::model_dir();
    let (port, _shutdown) = common::start_server(&model_dir).await;
    let (_sink, _stream, ready) = common::ws_connect(port).await;

    assert_eq!(ready["type"], "ready");
    assert_eq!(ready["version"], "1.0");
    assert_eq!(ready["sample_rate"], 48000);

    // Don't unwrap: a missing or non-string `model` must fail through this
    // assertion so its message shows what the server actually sent.
    assert!(
        matches!(ready["model"].as_str(), Some(model) if model.contains("zipformer")),
        "model field should contain 'zipformer', got: {:?}",
        ready["model"]
    );

    let rates = ready["supported_rates"]
        .as_array()
        .expect("supported_rates should be an array");
    assert!(
        rates.len() >= 5,
        "supported_rates should have >=5 entries, got {}",
        rates.len()
    );
    assert!(
        rates.contains(&serde_json::json!(8000)),
        "supported_rates should contain 8000"
    );
    assert!(
        rates.contains(&serde_json::json!(48000)),
        "supported_rates should contain 48000"
    );
}
/// Streaming the first test WAV in 9600-byte binary frames and then sending
/// `stop` must eventually yield a `final` with a non-empty transcription.
///
/// Any number of `partial` messages may arrive first; any other message type
/// fails the test. Each individual message is awaited for at most 30 s.
#[tokio::test]
#[ignore]
async fn test_ws_audio_produces_final() {
    let model_dir = common::model_dir();
    let (port, _shutdown) = common::start_server(&model_dir).await;
    let (mut sink, mut stream, _ready) = common::ws_connect(port).await;

    let wav_path = common::test_wav_path(0);
    let pcm = common::pcm16_from_wav(&wav_path);
    for piece in pcm.chunks(9600) {
        let frame = Message::Binary(piece.to_vec().into());
        sink.send(frame).await.unwrap();
    }

    let stop = serde_json::json!({"type": "stop"}).to_string();
    sink.send(Message::Text(stop.into())).await.unwrap();

    // Drain partials; the loop yields the text of the first `final`.
    let transcript = loop {
        let next = tokio::time::timeout(Duration::from_secs(30), stream.next()).await;
        let frame = next
            .expect("timeout waiting for Final")
            .expect("stream ended")
            .expect("ws error");
        let raw = frame.into_text().expect("expected text message");
        let parsed: serde_json::Value = serde_json::from_str(&raw).expect("expected JSON");
        match parsed["type"].as_str().unwrap_or("") {
            "partial" => {}
            "final" => {
                break parsed["text"]
                    .as_str()
                    .expect("Final message should have a text field")
                    .to_owned();
            }
            other => panic!("Unexpected message type: {other}, full: {raw}"),
        }
    };

    assert!(
        !transcript.trim().is_empty(),
        "Expected non-empty Vietnamese transcription, got: {transcript}"
    );
}
/// Sending `stop` before any audio must be answered, within 10 s, by a
/// `final` whose text is empty.
///
/// A missing or non-string `text` field is treated as empty.
#[tokio::test]
#[ignore]
async fn test_ws_stop_without_audio() {
    let model_dir = common::model_dir();
    let (port, _shutdown) = common::start_server(&model_dir).await;
    let (mut sink, mut stream, _ready) = common::ws_connect(port).await;

    let stop = serde_json::json!({"type": "stop"}).to_string();
    sink.send(Message::Text(stop.into())).await.unwrap();

    let wait = tokio::time::timeout(Duration::from_secs(10), stream.next());
    let reply = wait
        .await
        .expect("timeout waiting for Final")
        .expect("stream ended")
        .expect("ws error");
    let final_msg = common::assert_msg_type(reply, "final");
    let text = final_msg["text"].as_str().unwrap_or("");
    assert_eq!(text, "", "Expected empty text for stop-without-audio");
}
/// A `configure` with a supported sample rate (16 kHz) is accepted.
///
/// After reconfiguring, up to one second of the test clip is sent, followed by
/// `stop`. The server must reply with any number of `partial`s and then a
/// `final`. Any other message type, in particular an `error`, fails the test.
/// The panic includes the full message so the server's error code is visible.
#[tokio::test]
#[ignore]
async fn test_ws_configure_valid_sample_rate() {
    let model_dir = common::model_dir();
    let (port, _shutdown) = common::start_server(&model_dir).await;
    let (mut sink, mut stream, _ready) = common::ws_connect(port).await;
    sink.send(Message::Text(
        serde_json::to_string(&serde_json::json!({"type": "configure", "sample_rate": 16000}))
            .unwrap()
            .into(),
    ))
    .await
    .unwrap();

    let audio = common::pcm16_from_wav(&common::test_wav_path(0));
    // 16000 samples x 2 bytes = one second of 16-bit PCM at 16 kHz (assuming
    // a mono clip); send less if the clip is shorter than that.
    let one_sec_bytes = 16000usize * 2;
    let chunk = &audio[..audio.len().min(one_sec_bytes)];
    sink.send(Message::Binary(chunk.to_vec().into()))
        .await
        .unwrap();
    sink.send(Message::Text(
        serde_json::to_string(&serde_json::json!({"type": "stop"}))
            .unwrap()
            .into(),
    ))
    .await
    .unwrap();

    loop {
        let msg = tokio::time::timeout(Duration::from_secs(20), stream.next())
            .await
            .expect("timeout waiting for Final")
            .expect("stream ended")
            .expect("ws error");
        let text = msg.into_text().expect("expected text message");
        let v: serde_json::Value = serde_json::from_str(&text).expect("expected JSON");
        match v["type"].as_str().unwrap_or("") {
            "partial" => continue,
            "final" => break,
            // Include the raw message: if the server rejected the configure,
            // its error code is only visible here.
            other => panic!(
                "Unexpected message type: {other} (expected final, not error), full: {text}"
            ),
        }
    }
}
#[tokio::test]
#[ignore] async fn test_ws_configure_invalid_sample_rate() {
let model_dir = common::model_dir();
let (port, _shutdown) = common::start_server(&model_dir).await;
let (mut sink, mut stream, _ready) = common::ws_connect(port).await;
sink.send(Message::Text(
serde_json::to_string(&serde_json::json!({"type": "configure", "sample_rate": 7000}))
.unwrap()
.into(),
))
.await
.unwrap();
let msg = tokio::time::timeout(Duration::from_secs(5), stream.next())
.await
.expect("timeout waiting for Error")
.expect("stream ended")
.expect("ws error");
let v = common::assert_msg_type(msg, "error");
assert_eq!(
v["code"], "invalid_sample_rate",
"Expected code=invalid_sample_rate, got: {:?}",
v["code"]
);
}
#[tokio::test]
#[ignore] async fn test_ws_configure_after_audio() {
let model_dir = common::model_dir();
let (port, _shutdown) = common::start_server(&model_dir).await;
let (mut sink, mut stream, _ready) = common::ws_connect(port).await;
let silence = common::generate_pcm16_silence(0.1, 48000);
sink.send(Message::Binary(silence.into())).await.unwrap();
sink.send(Message::Text(
serde_json::to_string(&serde_json::json!({"type": "configure", "sample_rate": 16000}))
.unwrap()
.into(),
))
.await
.unwrap();
let msg = tokio::time::timeout(Duration::from_secs(5), stream.next())
.await
.expect("timeout waiting for Error")
.expect("stream ended")
.expect("ws error");
let v = common::assert_msg_type(msg, "error");
assert_eq!(
v["code"], "configure_too_late",
"Expected code=configure_too_late, got: {:?}",
v["code"]
);
}
/// A text frame that is not valid JSON must not end the session: a following
/// `stop` must still produce a `final`.
///
/// After the bad frame, only `partial` and `final` messages are tolerated.
/// Any other type, including an `error` reply, fails the test, and the panic
/// shows the full message so the reply can be diagnosed.
#[tokio::test]
#[ignore]
async fn test_ws_malformed_json() {
    let model_dir = common::model_dir();
    let (port, _shutdown) = common::start_server(&model_dir).await;
    let (mut sink, mut stream, _ready) = common::ws_connect(port).await;
    sink.send(Message::Text("not json at all {{".to_string().into()))
        .await
        .unwrap();
    sink.send(Message::Text(
        serde_json::to_string(&serde_json::json!({"type": "stop"}))
            .unwrap()
            .into(),
    ))
    .await
    .unwrap();

    loop {
        let msg = tokio::time::timeout(Duration::from_secs(10), stream.next())
            .await
            .expect("timeout — connection may have been closed by malformed JSON")
            .expect("stream ended unexpectedly after malformed JSON")
            .expect("ws error");
        let text = msg.into_text().expect("expected text message");
        let v: serde_json::Value = serde_json::from_str(&text).expect("expected JSON");
        match v["type"].as_str().unwrap_or("") {
            "partial" => continue,
            "final" => break,
            // Include the raw payload (e.g. an error code) for diagnosis.
            other => {
                panic!("Unexpected message type after malformed JSON: {other}, full: {text}")
            }
        }
    }
}
/// A client that streams some audio and then drops its connection without a
/// close handshake must not prevent later clients from connecting.
#[tokio::test]
#[ignore]
async fn test_ws_client_disconnect_midstream() {
    let model_dir = common::model_dir();
    let (port, _shutdown) = common::start_server(&model_dir).await;

    // First client: send some silence, then vanish by dropping both halves of
    // the socket (stream first, then sink) without sending a close frame.
    let (mut sink, stream, _ready) = common::ws_connect(port).await;
    let silence = common::generate_pcm16_silence(0.5, 48000);
    sink.send(Message::Binary(silence.into())).await.unwrap();
    drop(stream);
    drop(sink);

    // Pause briefly so the server has a chance to observe the disconnect
    // before the next client arrives.
    tokio::time::sleep(Duration::from_millis(200)).await;

    let (_sink2, _stream2, ready2) = common::ws_connect(port).await;
    assert_eq!(
        ready2["type"], "ready",
        "Server should still be healthy after abrupt client disconnect"
    );
}
/// Four clients connect concurrently, each in its own spawned task.
///
/// Every client must receive `ready` and then get a `final` in reply to an
/// immediate `stop`. All tasks share a single 30 s deadline, and a failure
/// reports the index of the client that failed.
#[tokio::test]
#[ignore]
async fn test_ws_concurrent_4_clients() {
    const CLIENTS: usize = 4;

    let model_dir = common::model_dir();
    let (port, _shutdown) = common::start_server(&model_dir).await;
    let url = format!("ws://127.0.0.1:{port}/ws");

    let handles: Vec<_> = (0..CLIENTS)
        .map(|i| {
            let url = url.clone();
            tokio::spawn(async move {
                let (ws, _) = tokio_tungstenite::connect_async(&url)
                    .await
                    .unwrap_or_else(|e| panic!("Client {i} failed to connect: {e}"));
                let (mut sink, mut stream) = ws.split();

                let msg = tokio::time::timeout(Duration::from_secs(10), stream.next())
                    .await
                    .expect("timeout waiting for Ready")
                    .expect("stream ended")
                    .expect("ws error");
                let text = msg.into_text().expect("expected text message");
                let v: serde_json::Value = serde_json::from_str(&text).expect("expected JSON");
                assert_eq!(v["type"], "ready", "Client {i} did not receive Ready");

                let stop = serde_json::json!({"type": "stop"}).to_string();
                sink.send(Message::Text(stop.into())).await.unwrap();

                let msg = tokio::time::timeout(Duration::from_secs(10), stream.next())
                    .await
                    .expect("timeout waiting for Final")
                    .expect("stream ended")
                    .expect("ws error");
                let text = msg.into_text().expect("expected text message");
                let v: serde_json::Value = serde_json::from_str(&text).expect("expected JSON");
                assert_eq!(
                    v["type"], "final",
                    "Client {i} did not receive Final after Stop"
                );
                i
            })
        })
        .collect();

    // One deadline for the whole group. A fresh 30 s timeout per handle would
    // let the total wait stretch to CLIENTS x 30 s.
    let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
    for (expected, handle) in handles.into_iter().enumerate() {
        let client_id = tokio::time::timeout_at(deadline, handle)
            .await
            .unwrap_or_else(|_| panic!("client {expected} task timed out"))
            .unwrap_or_else(|e| panic!("client {expected} task panicked: {e}"));
        // Each task returns its own index; pin the handle <-> client mapping.
        assert_eq!(
            client_id, expected,
            "handle {expected} returned client {client_id}"
        );
    }
}