#![allow(clippy::result_large_err)]
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures_util::{SinkExt, StreamExt};
use parking_lot::Mutex;
use serde_json::json;
use studio_worker::config::{self, Config};
use studio_worker::runtime::WorkerObservers;
use studio_worker::types::LogEntry;
use studio_worker::ws::session::{spawn_ws_session, SessionSchedule};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
/// Upper bound used by these tests when waiting on the fake studio task and on the
/// frame reads inside `collect_frames`.
const TIMEOUT: Duration = Duration::from_millis(10_000);
/// Handshake callback for `accept_hdr_async`.
///
/// Always answers the upgrade with `sec-websocket-protocol: studio-worker-v1`, presumably
/// the subprotocol the worker requests (the incoming request is not inspected). It never
/// rejects a handshake.
fn echo_subprotocol(_req: &Request, resp: Response) -> Result<Response, ErrorResponse> {
    let mut response = resp;
    let protocol = HeaderValue::from_static("studio-worker-v1");
    response
        .headers_mut()
        .insert("sec-websocket-protocol", protocol);
    Ok(response)
}
async fn spawn_studio_ws() -> (SocketAddr, tokio::task::JoinHandle<Result<()>>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await?;
let mut ws = tokio_tungstenite::accept_hdr_async(stream, echo_subprotocol).await?;
let hello = ws
.next()
.await
.ok_or_else(|| anyhow::anyhow!("hello missing"))??
.into_text()
.map_err(|_| anyhow::anyhow!("hello not text"))?;
let hello_json: serde_json::Value = serde_json::from_str(&hello)?;
assert_eq!(hello_json["type"], "hello");
ws.send(Message::Text(
serde_json::to_string(
&json!({"type":"welcome","workerId":"w-test","serverTime":"now"}),
)?
.into(),
))
.await?;
let synthetic_source = json!({
"engine": "synthetic",
"files": [],
"cliDefaults": {
"cfgScale": 1.0,
"steps": 8,
"width": 1024,
"height": 1024
}
});
ws.send(Message::Text(
serde_json::to_string(&json!({
"type": "offer",
"claim": {
"jobId": "job-llm",
"gameId": "g",
"assetName": "g/dialogue/scribe",
"model": "synthetic",
"vramGbEstimate": 1.0,
"task": {
"kind": "llm",
"messages": [{"role": "user", "content": "hi"}],
"maxTokens": 4,
"temperature": 0.5
},
"modelSource": synthetic_source
}
}))?
.into(),
))
.await?;
let frames = collect_frames(&mut ws, &["accept", "completeJson"]).await?;
assert_eq!(frames["accept"]["jobId"], "job-llm");
assert_eq!(frames["completeJson"]["jobId"], "job-llm");
ws.send(Message::Text(
serde_json::to_string(&json!({
"type": "offer",
"claim": {
"jobId": "job-stt",
"gameId": "g",
"assetName": "g/dialogue/transcript",
"model": "synthetic",
"vramGbEstimate": 1.0,
"task": {
"kind": "audio_stt",
"inputUrl": "https://example/audio.wav",
"language": null
},
"modelSource": synthetic_source
}
}))?
.into(),
))
.await?;
let frames = collect_frames(&mut ws, &["accept", "completeJson"]).await?;
assert_eq!(frames["accept"]["jobId"], "job-stt");
assert_eq!(frames["completeJson"]["jobId"], "job-stt");
ws.close(None).await?;
Ok(())
});
(addr, handle)
}
async fn spawn_handshake_only_ws() -> (SocketAddr, tokio::task::JoinHandle<Result<()>>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await?;
let mut ws = tokio_tungstenite::accept_hdr_async(stream, echo_subprotocol).await?;
let hello = ws
.next()
.await
.ok_or_else(|| anyhow::anyhow!("hello missing"))??
.into_text()
.map_err(|_| anyhow::anyhow!("hello not text"))?;
let hello_json: serde_json::Value = serde_json::from_str(&hello)?;
assert_eq!(hello_json["type"], "hello");
ws.send(Message::Text(
serde_json::to_string(
&json!({"type":"welcome","workerId":"w-test","serverTime":"now"}),
)?
.into(),
))
.await?;
ws.close(None).await?;
Ok(())
});
(addr, handle)
}
async fn spawn_silent_after_welcome_ws() -> (SocketAddr, tokio::task::JoinHandle<Result<()>>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await?;
drop(listener);
let mut ws = tokio_tungstenite::accept_hdr_async(stream, echo_subprotocol).await?;
let _hello = ws
.next()
.await
.ok_or_else(|| anyhow::anyhow!("hello missing"))??;
ws.send(Message::Text(
serde_json::to_string(
&json!({"type":"welcome","workerId":"w-test","serverTime":"now"}),
)?
.into(),
))
.await?;
tokio::time::sleep(Duration::from_secs(30)).await;
drop(ws);
Ok(())
});
(addr, handle)
}
/// Reads frames from `ws` until a text frame of every type in `expected` has arrived.
///
/// Text frames are parsed as JSON and keyed by their `"type"` field. Frames whose type is
/// not in `expected`, and non-text frames, are skipped. If a type arrives more than once,
/// the latest frame wins. An empty `expected` returns an empty map immediately.
///
/// # Errors
/// Fails in any of these cases:
/// - the stream ends or yields a read error;
/// - a text frame is not valid JSON;
/// - not every expected type has arrived within [`TIMEOUT`] of the call. The error names
///   the missing types.
async fn collect_frames(
    ws: &mut WebSocketStream<TcpStream>,
    expected: &[&str],
) -> Result<HashMap<String, serde_json::Value>> {
    // One deadline for the whole collection. A per-frame timeout would be reset by every
    // unrelated frame and could therefore wait indefinitely.
    let deadline = tokio::time::Instant::now() + TIMEOUT;
    let mut bucket: HashMap<String, serde_json::Value> = HashMap::new();
    // Track completion by membership rather than by count, so a duplicated entry in
    // `expected` cannot make the loop wait for a frame that can never arrive.
    while !expected.iter().all(|kind| bucket.contains_key(*kind)) {
        let next = tokio::time::timeout_at(deadline, ws.next())
            .await
            .map_err(|_| {
                let missing: Vec<&str> = expected
                    .iter()
                    .copied()
                    .filter(|kind| !bucket.contains_key(*kind))
                    .collect();
                anyhow::anyhow!("timed out waiting for frames {missing:?}")
            })?;
        let item = next.ok_or_else(|| anyhow::anyhow!("stream ended early"))??;
        if let Message::Text(t) = item {
            let frame: serde_json::Value = serde_json::from_str(&t)?;
            if let Some(kind) = frame["type"].as_str() {
                if expected.contains(&kind) {
                    bucket.insert(kind.to_string(), frame);
                }
            }
        }
    }
    Ok(bucket)
}
/// Drives the worker through two JSON-result offers against `spawn_studio_ws`, then stops
/// the session once the fake studio has closed the socket.
///
/// The offers are an LLM job followed by an STT job. The protocol assertions live in the
/// fake studio.
///
/// This test fails if any of the following happens:
/// - the studio task panics, errors, or exceeds [`TIMEOUT`];
/// - the session task panics;
/// - the session task does not return within 5 s of `stop` being raised.
#[tokio::test]
async fn ws_session_walks_through_two_json_offers_and_then_disconnects() {
    let (ws_addr, server_handle) = spawn_studio_ws().await;
    let cfg = Config {
        api_base_url: format!("http://{ws_addr}"),
        worker_id: Some("w-test".into()),
        auth_token: Some("tok-test".into()),
        auto_update_enabled: false,
        ws_reconnect_attempts: Some(1),
        ..Config::default()
    };
    let shared = config::shared(cfg);
    let stop = Arc::new(AtomicBool::new(false));
    let logs = Arc::new(Mutex::new(Vec::<LogEntry>::new()));
    let busy = Arc::new(AtomicBool::new(false));
    let paused = Arc::new(AtomicBool::new(false));
    let session_handle = tokio::spawn({
        let shared = shared.clone();
        let stop = stop.clone();
        let logs = logs.clone();
        let busy = busy.clone();
        let paused = paused.clone();
        async move {
            spawn_ws_session(
                shared,
                stop,
                logs,
                busy,
                paused,
                WorkerObservers::default(),
                SessionSchedule::fast_for_tests(),
            )
            .await
        }
    });
    tokio::time::timeout(TIMEOUT, server_handle)
        .await
        .expect("server timed out")
        .expect("server task panicked")
        .expect("server returned err");
    // Presumably gives the session time to observe the studio's close before stopping it.
    tokio::time::sleep(Duration::from_millis(200)).await;
    stop.store(true, std::sync::atomic::Ordering::SeqCst);
    // The session's return value is not part of this test's contract. A panic inside the
    // session task is a bug, however, and must fail the test rather than be discarded.
    let _ = tokio::time::timeout(Duration::from_secs(5), session_handle)
        .await
        .expect("session loop timed out")
        .expect("session task panicked");
}
/// After the two JSON jobs from `spawn_studio_ws` complete, the observers' recent logs
/// must contain one info-level "json result sent" breadcrumb per job, tagged with that
/// job's id.
///
/// A panic inside the session task also fails the test.
#[tokio::test]
async fn ws_session_logs_a_breadcrumb_when_json_result_is_sent() {
    let (ws_addr, server_handle) = spawn_studio_ws().await;
    let cfg = Config {
        api_base_url: format!("http://{ws_addr}"),
        worker_id: Some("w-test".into()),
        auth_token: Some("tok-test".into()),
        auto_update_enabled: false,
        ws_reconnect_attempts: Some(1),
        ..Config::default()
    };
    let shared = config::shared(cfg);
    let stop = Arc::new(AtomicBool::new(false));
    let logs = Arc::new(Mutex::new(Vec::<LogEntry>::new()));
    let busy = Arc::new(AtomicBool::new(false));
    let paused = Arc::new(AtomicBool::new(false));
    let observers = WorkerObservers::default();
    let session_handle = tokio::spawn({
        let shared = shared.clone();
        let stop = stop.clone();
        let logs = logs.clone();
        let busy = busy.clone();
        let paused = paused.clone();
        let observers = observers.clone();
        async move {
            spawn_ws_session(
                shared,
                stop,
                logs,
                busy,
                paused,
                observers,
                SessionSchedule::fast_for_tests(),
            )
            .await
        }
    });
    tokio::time::timeout(TIMEOUT, server_handle)
        .await
        .expect("server timed out")
        .expect("server task panicked")
        .expect("server returned err");
    tokio::time::sleep(Duration::from_millis(200)).await;
    stop.store(true, std::sync::atomic::Ordering::SeqCst);
    // Surface a panic inside the session task instead of discarding the join result.
    let _ = tokio::time::timeout(Duration::from_secs(5), session_handle)
        .await
        .expect("session loop timed out")
        .expect("session task panicked");
    let json_completion = observers
        .recent_logs
        .lock()
        .iter()
        .filter(|e| e.message.contains("json result sent"))
        .map(|e| (e.level.clone(), e.job_id.clone()))
        .collect::<Vec<_>>();
    assert_eq!(
        json_completion.len(),
        2,
        "expected a completion breadcrumb per JSON job, got {json_completion:?}"
    );
    assert!(
        json_completion.iter().all(|(level, _)| level == "info"),
        "json completion breadcrumbs must be info-level: {json_completion:?}"
    );
    assert!(
        json_completion
            .iter()
            .any(|(_, job_id)| job_id.as_deref() == Some("job-llm")),
        "missing breadcrumb for the LLM job: {json_completion:?}"
    );
    assert!(
        json_completion
            .iter()
            .any(|(_, job_id)| job_id.as_deref() == Some("job-stt")),
        "missing breadcrumb for the STT job: {json_completion:?}"
    );
}
/// A studio that welcomes the worker and then goes silent must not wedge the session.
///
/// With a 300 ms read-idle timeout the session loop is expected to return within 8 s on
/// its own, without `stop` being raised first. This is presumably because
/// `ws_reconnect_attempts` is 1 and the fake studio stops listening after the first
/// connection. The worker must also log a reconnect attempt.
///
/// A panic inside the session task also ends it early, but that is not recovery and
/// fails the test.
#[tokio::test]
async fn ws_session_recovers_from_a_silent_half_open_connection() {
    let (ws_addr, server) = spawn_silent_after_welcome_ws().await;
    let cfg = Config {
        api_base_url: format!("http://{ws_addr}"),
        worker_id: Some("w-test".into()),
        auth_token: Some("tok-test".into()),
        auto_update_enabled: false,
        ws_reconnect_attempts: Some(1),
        ..Config::default()
    };
    let shared = config::shared(cfg);
    let stop = Arc::new(AtomicBool::new(false));
    let logs = Arc::new(Mutex::new(Vec::<LogEntry>::new()));
    let busy = Arc::new(AtomicBool::new(false));
    let paused = Arc::new(AtomicBool::new(false));
    let observers = WorkerObservers::default();
    let schedule = SessionSchedule {
        read_idle_timeout: Duration::from_millis(300),
        ..SessionSchedule::fast_for_tests()
    };
    let session_handle = tokio::spawn({
        let shared = shared.clone();
        let stop = stop.clone();
        let logs = logs.clone();
        let busy = busy.clone();
        let paused = paused.clone();
        let observers = observers.clone();
        async move { spawn_ws_session(shared, stop, logs, busy, paused, observers, schedule).await }
    });
    let outcome = tokio::time::timeout(Duration::from_secs(8), session_handle).await;
    stop.store(true, std::sync::atomic::Ordering::SeqCst);
    // The fake studio is otherwise parked in its 30 s sleep; tear it down explicitly.
    server.abort();
    let joined = outcome
        .expect("session hung on a silent connection instead of detecting the read-idle-timeout");
    let _ = joined.expect("session task panicked");
    assert!(
        observers
            .recent_logs
            .lock()
            .iter()
            .any(|e| e.message.contains("reconnect attempt")),
        "worker must log a reconnect attempt after the idle timeout"
    );
}
/// During the handshake the worker must log a capability summary ("advertising engine=...").
///
/// The summary must list the `image` kind, the `synthetic` model and `auto_enabled=true`
/// for an unpaused worker. The fake studio only exchanges hello/welcome and then closes.
///
/// A panic inside the session task also fails the test.
#[tokio::test]
async fn ws_session_logs_advertised_capabilities_on_handshake() {
    let (ws_addr, server_handle) = spawn_handshake_only_ws().await;
    let cfg = Config {
        api_base_url: format!("http://{ws_addr}"),
        worker_id: Some("w-test".into()),
        auth_token: Some("tok-test".into()),
        auto_update_enabled: false,
        ws_reconnect_attempts: Some(1),
        ..Config::default()
    };
    let shared = config::shared(cfg);
    let stop = Arc::new(AtomicBool::new(false));
    let logs = Arc::new(Mutex::new(Vec::<LogEntry>::new()));
    let busy = Arc::new(AtomicBool::new(false));
    let paused = Arc::new(AtomicBool::new(false));
    let observers = WorkerObservers::default();
    let session_handle = tokio::spawn({
        let shared = shared.clone();
        let stop = stop.clone();
        let logs = logs.clone();
        let busy = busy.clone();
        let paused = paused.clone();
        let observers = observers.clone();
        async move {
            spawn_ws_session(
                shared,
                stop,
                logs,
                busy,
                paused,
                observers,
                SessionSchedule::fast_for_tests(),
            )
            .await
        }
    });
    tokio::time::timeout(TIMEOUT, server_handle)
        .await
        .expect("server timed out")
        .expect("server task panicked")
        .expect("server returned err");
    tokio::time::sleep(Duration::from_millis(200)).await;
    stop.store(true, std::sync::atomic::Ordering::SeqCst);
    // Surface a panic inside the session task instead of discarding the join result.
    let _ = tokio::time::timeout(Duration::from_secs(5), session_handle)
        .await
        .expect("session loop timed out")
        .expect("session task panicked");
    let summary = observers
        .recent_logs
        .lock()
        .iter()
        .find(|e| e.message.contains("advertising engine="))
        .map(|e| e.message.clone())
        .expect("capability summary must be logged on the handshake");
    assert!(summary.contains("kinds=["), "missing kinds: {summary}");
    assert!(summary.contains("image"), "missing image kind: {summary}");
    assert!(summary.contains("synthetic"), "missing model id: {summary}");
    assert!(
        summary.contains("auto_enabled=true"),
        "unpaused worker must advertise auto_enabled=true: {summary}"
    );
}