use std::convert::TryFrom;
use std::time::{Duration, Instant};
use std::sync::Arc;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::http::StatusCode;
use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
use tokio_tungstenite::tungstenite::{Error as TError, Message};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tracing::{debug, warn};
use url::Url;
use crate::ws::types::{WorkerInbound, WorkerOutbound};
/// WebSocket subprotocol name sent in the `Sec-WebSocket-Protocol` header on connect.
pub const SUBPROTOCOL: &str = "studio-worker-v1";
/// `tracing` target attached to the log events emitted by this module.
const TRACE_TARGET: &str = "studio_worker::ws::client";
/// Path prefix that `build_connect_url` appends to the base URL when the path
/// does not already end with it.
const API_PREFIX: &str = "/graphics/api";
/// Timeout that `connect` hands to `connect_inner` for the connect attempt.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(15);
/// Result alias whose error type is always [`WsClientError`].
pub type WsResult<T> = Result<T, WsClientError>;
/// Errors surfaced by the worker WebSocket client.
#[derive(Debug, thiserror::Error)]
pub enum WsClientError {
/// Authentication was rejected: produced for a 401 on the upgrade request
/// and for a close frame carrying code 4001.
#[error("auth failed: {reason}")]
AuthFailed { reason: String },
/// The connection is closed: produced for tungstenite's `ConnectionClosed` /
/// `AlreadyClosed` errors and for close frames other than 4001.
#[error("connection closed by server")]
ConnectionClosed,
/// Any other tungstenite error (stringified), plus URL, header and timeout
/// failures raised while connecting.
#[error("ws transport error: {0}")]
Transport(String),
/// A frame could not be serialized/deserialized as JSON, or an unexpected
/// binary frame was received.
#[error("protocol error: {0}")]
Protocol(String),
}
impl From<TError> for WsClientError {
fn from(value: TError) -> Self {
match value {
TError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
WsClientError::AuthFailed {
reason: "401 on websocket upgrade".to_string(),
}
}
TError::ConnectionClosed | TError::AlreadyClosed => WsClientError::ConnectionClosed,
other => WsClientError::Transport(other.to_string()),
}
}
}
/// Derives the worker "connect" WebSocket URL from a configured base URL.
///
/// * `http` / `https` are rewritten to `ws` / `wss`; `ws` / `wss` pass through
///   unchanged; any other scheme is rejected.
/// * Trailing slashes are trimmed from the base path and [`API_PREFIX`] is
///   appended unless the path already ends with it.
/// * `workers/<worker_id>/connect` is appended, with `worker_id` written as a
///   single percent-encoded path segment so characters such as `/` or `%`
///   cannot alter the path structure.
///
/// # Errors
///
/// Returns [`WsClientError::Transport`] when `base_url` does not parse, uses
/// an unsupported scheme, or when `worker_id` is empty, `.` or `..` (none of
/// which can be represented as a path segment).
fn build_connect_url(base_url: &str, worker_id: &str) -> WsResult<Url> {
    let mut url = Url::parse(base_url)
        .map_err(|e| WsClientError::Transport(format!("invalid base url: {e}")))?;

    let new_scheme = match url.scheme() {
        "http" => Some("ws"),
        "https" => Some("wss"),
        "ws" | "wss" => None,
        other => {
            return Err(WsClientError::Transport(format!(
                "unsupported scheme: {other}"
            )))
        }
    };
    if let Some(scheme) = new_scheme {
        url.set_scheme(scheme)
            .map_err(|_| WsClientError::Transport("set_scheme failed".to_string()))?;
    }

    // Dot segments are skipped by the URL path machinery and an empty id would
    // produce `workers//connect`; either way the id would silently vanish from
    // the URL, so refuse them up front.
    if matches!(worker_id, "" | "." | "..") {
        return Err(WsClientError::Transport(format!(
            "invalid worker id: {worker_id:?}"
        )));
    }

    let trimmed_path = url.path().trim_end_matches('/');
    let prefixed = if trimmed_path.ends_with(API_PREFIX) {
        trimmed_path.to_string()
    } else {
        format!("{trimmed_path}{API_PREFIX}")
    };
    url.set_path(&prefixed);

    // Append the trailing segments one by one: unlike string concatenation,
    // this percent-encodes `/` and `%` inside `worker_id`.
    url.path_segments_mut()
        .map_err(|()| {
            WsClientError::Transport("base url cannot carry path segments".to_string())
        })?
        .extend(["workers", worker_id, "connect"]);

    Ok(url)
}
/// Opens the worker WebSocket using [`CONNECT_TIMEOUT`] and logs the outcome.
///
/// Delegates the actual work to `connect_inner`, then emits a `debug` event on
/// success or a `warn` event on failure, each carrying the worker id and the
/// elapsed time in milliseconds. The result of `connect_inner` is returned
/// unchanged.
pub async fn connect(base_url: &str, worker_id: &str, auth_token: &str) -> WsResult<WsClient> {
    let clock = Instant::now();
    let outcome = connect_inner(base_url, worker_id, auth_token, CONNECT_TIMEOUT).await;
    let elapsed_ms = clock.elapsed().as_millis() as u64;

    if let Err(failure) = &outcome {
        warn!(
            target: TRACE_TARGET,
            op = "connect",
            worker_id,
            elapsed_ms,
            error = %failure,
            "websocket connect failed"
        );
    } else {
        debug!(
            target: TRACE_TARGET,
            op = "connect",
            worker_id,
            elapsed_ms,
            "websocket established"
        );
    }

    outcome
}
/// Performs the WebSocket handshake against the worker connect endpoint.
///
/// Builds the URL with `build_connect_url`, attaches the `Authorization`
/// bearer header and the [`SUBPROTOCOL`] header, and runs
/// `tokio_tungstenite::connect_async` under `connect_timeout`.
///
/// # Errors
///
/// * whatever `build_connect_url` rejects;
/// * [`WsClientError::Transport`] when the token cannot form a header value or
///   the handshake does not finish within `connect_timeout`;
/// * any handshake error, converted through `From<TError>` (a 401 becomes
///   [`WsClientError::AuthFailed`]).
async fn connect_inner(
    base_url: &str,
    worker_id: &str,
    auth_token: &str,
    connect_timeout: Duration,
) -> WsResult<WsClient> {
    let url = build_connect_url(base_url, worker_id)?;
    debug!(
        target: TRACE_TARGET,
        op = "connect",
        worker_id,
        url = %url,
        "opening websocket"
    );

    // `?` already converts the tungstenite error through `From<TError>`.
    let mut request = url.as_str().into_client_request()?;

    let mut authorization = HeaderValue::try_from(format!("Bearer {auth_token}"))
        .map_err(|e| WsClientError::Transport(format!("invalid auth header: {e}")))?;
    // Keep the bearer token out of any `Debug` rendering of the request.
    authorization.set_sensitive(true);

    let headers = request.headers_mut();
    headers.insert("Authorization", authorization);
    headers.insert(
        "Sec-WebSocket-Protocol",
        HeaderValue::from_static(SUBPROTOCOL),
    );

    let handshake = tokio::time::timeout(
        connect_timeout,
        tokio_tungstenite::connect_async(request),
    );
    let (stream, _response) = match handshake.await {
        Ok(result) => result?,
        Err(_elapsed) => {
            return Err(WsClientError::Transport(format!(
                "connect timed out after {connect_timeout:?}"
            )))
        }
    };

    let (sink, source) = stream.split();
    Ok(WsClient {
        sink,
        source,
        closed: false,
    })
}
/// Write half of the (possibly TLS-wrapped) TCP WebSocket stream, accepting `Message`s.
type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of the (possibly TLS-wrapped) TCP WebSocket stream.
type WsSource = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// Un-split WebSocket connection: owns both the write half (`sink`) and the
/// read half (`source`) of the stream produced by `connect_inner`.
// NOTE(review): this file also contains a manual `Debug` impl for `WsClient`,
// so the `allow` below looks redundant — confirm before removing it.
#[allow(missing_debug_implementations)]
pub struct WsClient {
// Write half of the socket.
sink: WsSink,
// Read half of the socket.
source: WsSource,
// Set by `recv` (close frame or end of stream) and by `close`; once set,
// `recv` returns `Ok(None)` and `close` becomes a no-op.
closed: bool,
}
impl WsClient {
pub fn split(self) -> (WsSender, WsReceiver) {
let sink = Arc::new(Mutex::new(self.sink));
(
WsSender { sink },
WsReceiver {
source: self.source,
closed: false,
},
)
}
}
/// Sending half produced by `WsClient::split`.
///
/// Cloning is cheap: every clone shares the same underlying sink through the
/// `Arc<Mutex<_>>`, and the mutex is taken for each send.
#[derive(Clone)]
#[allow(missing_debug_implementations)]
pub struct WsSender {
// Shared, mutex-guarded write half of the socket.
sink: Arc<Mutex<WsSink>>,
}
impl WsSender {
    /// Serializes `frame` as JSON and sends it as a text message.
    ///
    /// # Errors
    ///
    /// * [`WsClientError::Protocol`] when JSON serialization fails;
    /// * the converted tungstenite error when the write fails.
    ///
    /// Both failures are also logged through `log_send_error`.
    pub async fn send(&self, frame: &WorkerInbound) -> WsResult<()> {
        let text = serde_json::to_string(frame).map_err(|e| {
            let err = WsClientError::Protocol(e.to_string());
            log_send_error(frame, &err);
            err
        })?;
        // The sink is shared between clones; hold the lock for this one write.
        let mut guard = self.sink.lock().await;
        guard.send(Message::Text(text.into())).await.map_err(|e| {
            let err = WsClientError::from(e);
            log_send_error(frame, &err);
            err
        })
    }

    /// Best-effort close: sends a close frame with `code` and `reason`.
    ///
    /// Always returns `Ok(())`. The whole attempt — acquiring the sink lock
    /// and writing the frame — is bounded by a 5 second timeout, so a sender
    /// stuck in `send` on another clone cannot block shutdown indefinitely.
    /// A timeout is logged at `warn`; a failed write is logged at `debug`.
    pub async fn close(&self, code: u16, reason: &str) -> WsResult<()> {
        debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
        let frame = CloseFrame {
            code: CloseCode::from(code),
            reason: reason.to_owned().into(),
        };
        // Lock acquisition lives inside the timed future on purpose.
        let attempt = async {
            let mut guard = self.sink.lock().await;
            guard.send(Message::Close(Some(frame))).await
        };
        match tokio::time::timeout(Duration::from_secs(5), attempt).await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                debug!(
                    target: TRACE_TARGET,
                    op = "close",
                    code,
                    error = %e,
                    "failed to send close frame"
                );
            }
            Err(_elapsed) => {
                warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
            }
        }
        Ok(())
    }
}
/// Receiving half produced by `WsClient::split`.
#[allow(missing_debug_implementations)]
pub struct WsReceiver {
// Read half of the socket.
source: WsSource,
// Set by `recv` after a close frame or end of stream; once set, `recv`
// returns `Ok(None)` without touching the socket.
closed: bool,
}
impl WsReceiver {
    /// Waits for the next application frame from the server.
    ///
    /// Returns `Ok(Some(frame))` for a decoded frame, `Ok(None)` once the
    /// receiver is closed or the stream ends, and `Err(_)` for failures
    /// reported by `classify_incoming`. Messages classified as skippable are
    /// consumed silently. A close classification marks the receiver closed
    /// before the error is returned.
    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
        if self.closed {
            return Ok(None);
        }
        loop {
            let incoming = match self.source.next().await {
                Some(incoming) => incoming,
                None => {
                    // The stream finished without a close frame being seen.
                    self.closed = true;
                    debug!(target: TRACE_TARGET, op = "recv", "stream ended (no close frame)");
                    return Ok(None);
                }
            };
            match classify_incoming(incoming) {
                RecvStep::Skip => {}
                RecvStep::Yield(frame) => return Ok(Some(frame)),
                RecvStep::Fail(err) => return Err(err),
                RecvStep::Closed(err) => {
                    self.closed = true;
                    return Err(err);
                }
            }
        }
    }
}
/// Manual `Debug` that prints only the `closed` flag; the socket halves are
/// left out of the output.
impl std::fmt::Debug for WsClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("WsClient");
        repr.field("closed", &self.closed);
        repr.finish()
    }
}
impl WsClient {
    /// Serializes `frame` as JSON and sends it as a text message.
    ///
    /// # Errors
    ///
    /// * [`WsClientError::Protocol`] when JSON serialization fails;
    /// * the converted tungstenite error when the write fails.
    ///
    /// Both failures are also logged through `log_send_error`.
    pub async fn send(&mut self, frame: &WorkerInbound) -> WsResult<()> {
        let text = serde_json::to_string(frame).map_err(|e| {
            let err = WsClientError::Protocol(e.to_string());
            log_send_error(frame, &err);
            err
        })?;
        self.sink
            .send(Message::Text(text.into()))
            .await
            .map_err(|e| {
                let err = WsClientError::from(e);
                log_send_error(frame, &err);
                err
            })
    }

    /// Waits for the next application frame from the server.
    ///
    /// Returns `Ok(Some(frame))` for a decoded frame, `Ok(None)` once the
    /// client is closed or the stream ends, and `Err(_)` for failures reported
    /// by `classify_incoming`. Skippable messages are consumed silently. A
    /// close classification marks the client closed before the error is
    /// returned.
    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
        if self.closed {
            return Ok(None);
        }
        while let Some(item) = self.source.next().await {
            match classify_incoming(item) {
                RecvStep::Yield(frame) => return Ok(Some(frame)),
                RecvStep::Skip => continue,
                RecvStep::Fail(e) => return Err(e),
                RecvStep::Closed(e) => {
                    self.closed = true;
                    return Err(e);
                }
            }
        }
        // The stream finished without a close frame being seen.
        self.closed = true;
        debug!(target: TRACE_TARGET, op = "recv", "stream ended (no close frame)");
        Ok(None)
    }

    /// Best-effort close: sends a close frame with `code` and `reason`.
    ///
    /// Idempotent — a second call (or a call after `recv` observed the end of
    /// the connection) returns immediately. Always returns `Ok(())`; the write
    /// is bounded by a 5 second timeout. A timeout is logged at `warn`, a
    /// failed write at `debug`, so the failure is no longer silently dropped.
    pub async fn close(&mut self, code: u16, reason: &str) -> WsResult<()> {
        if self.closed {
            return Ok(());
        }
        // Latch first so the client counts as closed whatever the write does.
        self.closed = true;
        debug!(target: TRACE_TARGET, op = "close", code, reason, "closing websocket");
        let frame = CloseFrame {
            code: CloseCode::from(code),
            reason: reason.to_owned().into(),
        };
        let attempt = tokio::time::timeout(
            Duration::from_secs(5),
            self.sink.send(Message::Close(Some(frame))),
        );
        match attempt.await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                debug!(
                    target: TRACE_TARGET,
                    op = "close",
                    code,
                    error = %e,
                    "failed to send close frame"
                );
            }
            Err(_elapsed) => {
                warn!(target: TRACE_TARGET, op = "close", code, "timed out sending close frame");
            }
        }
        Ok(())
    }
}
fn frame_label(frame: &WorkerInbound) -> &'static str {
match frame {
WorkerInbound::Hello(_) => "hello",
WorkerInbound::Heartbeat { .. } => "heartbeat",
WorkerInbound::Accept { .. } => "accept",
WorkerInbound::Reject { .. } => "reject",
WorkerInbound::CompleteJson { .. } => "completeJson",
WorkerInbound::Fail { .. } => "fail",
WorkerInbound::LogBatch { .. } => "logBatch",
WorkerInbound::ReadyForMore => "readyForMore",
}
}
fn log_send_error(frame: &WorkerInbound, err: &WsClientError) {
warn!(
target: TRACE_TARGET,
op = "send",
frame = frame_label(frame),
error = %err,
"failed to send frame"
);
}
/// Outcome of classifying one item read from the socket; consumed by the
/// `recv` loops.
enum RecvStep {
// A decoded application frame to hand to the caller.
Yield(WorkerOutbound),
// A message `recv` ignores; it keeps reading.
Skip,
// An error to return from `recv` without marking the connection closed.
Fail(WsClientError),
// An error to return from `recv` after marking the connection closed.
Closed(WsClientError),
}
/// Turns one item read from the socket into the next step for `recv`.
///
/// * Text that parses as `WorkerOutbound` → [`RecvStep::Yield`].
/// * Unparseable text or any binary frame → [`RecvStep::Fail`] with a
///   [`WsClientError::Protocol`] (logged at `warn`).
/// * Close frame → [`RecvStep::Closed`] carrying the error derived by
///   `close_frame_to_error` (`warn` for auth failures, `debug` otherwise).
/// * Any other message → [`RecvStep::Skip`].
/// * Transport error → converted through `From<TError>`. When the conversion
///   yields [`WsClientError::ConnectionClosed`] the step is
///   [`RecvStep::Closed`], so the caller latches its `closed` flag just as it
///   does for a close frame; every other transport error stays
///   [`RecvStep::Fail`].
fn classify_incoming(item: Result<Message, TError>) -> RecvStep {
    match item {
        Ok(Message::Text(text)) => match serde_json::from_str::<WorkerOutbound>(&text) {
            Ok(frame) => RecvStep::Yield(frame),
            Err(e) => {
                warn!(
                    target: TRACE_TARGET,
                    op = "recv",
                    error = %e,
                    "dropping unparseable text frame"
                );
                RecvStep::Fail(WsClientError::Protocol(e.to_string()))
            }
        },
        Ok(Message::Binary(_)) => {
            warn!(
                target: TRACE_TARGET,
                op = "recv",
                "rejecting unexpected binary frame"
            );
            RecvStep::Fail(WsClientError::Protocol(
                "unexpected binary frame".to_string(),
            ))
        }
        Ok(Message::Close(frame)) => {
            let err = close_frame_to_error(frame);
            match &err {
                WsClientError::AuthFailed { reason } => warn!(
                    target: TRACE_TARGET,
                    op = "recv",
                    reason = %reason,
                    "server closed connection: auth failed"
                ),
                _ => debug!(
                    target: TRACE_TARGET,
                    op = "recv",
                    "server closed connection"
                ),
            }
            RecvStep::Closed(err)
        }
        Ok(_) => RecvStep::Skip,
        Err(e) => {
            let mapped = WsClientError::from(e);
            if matches!(mapped, WsClientError::ConnectionClosed) {
                debug!(
                    target: TRACE_TARGET,
                    op = "recv",
                    "connection closed by peer"
                );
                // The connection is gone: report it as a close so `recv`
                // stops polling the socket on later calls.
                RecvStep::Closed(mapped)
            } else {
                warn!(
                    target: TRACE_TARGET,
                    op = "recv",
                    error = %mapped,
                    "transport error while reading frame"
                );
                RecvStep::Fail(mapped)
            }
        }
    }
}
/// Converts the payload of a received close message into a client error.
///
/// A frame whose numeric code is 4001 yields [`WsClientError::AuthFailed`]
/// with the frame's reason embedded; any other frame — or no frame at all —
/// yields [`WsClientError::ConnectionClosed`].
fn close_frame_to_error(frame: Option<CloseFrame>) -> WsClientError {
    match frame {
        Some(details) if u16::from(details.code) == 4001 => WsClientError::AuthFailed {
            reason: format!("server closed 4001: {}", details.reason),
        },
        _ => WsClientError::ConnectionClosed,
    }
}
// Unit tests for URL construction, error mapping, frame classification and
// the log events emitted by this module.
#[cfg(test)]
mod tests {
use super::*;
// --- build_connect_url: scheme rewriting and path handling ---
#[test]
fn build_connect_url_http_to_ws() {
let url = build_connect_url("http://api.example/graphics/api", "w-1").unwrap();
assert_eq!(url.scheme(), "ws");
assert!(url.path().ends_with("/workers/w-1/connect"));
}
// A trailing slash on the base URL must not produce a doubled slash.
#[test]
fn build_connect_url_https_to_wss() {
let url = build_connect_url("https://api.example/graphics/api/", "w-2").unwrap();
assert_eq!(url.scheme(), "wss");
assert_eq!(url.path(), "/graphics/api/workers/w-2/connect");
}
#[test]
fn build_connect_url_appends_graphics_api_prefix_when_missing() {
let url = build_connect_url("http://localhost:9790", "w-3").unwrap();
assert_eq!(url.scheme(), "ws");
assert_eq!(url.path(), "/graphics/api/workers/w-3/connect");
}
#[test]
fn build_connect_url_preserves_existing_ws_scheme() {
let url = build_connect_url("ws://localhost:9790/x", "w").unwrap();
assert_eq!(url.scheme(), "ws");
}
#[test]
fn build_connect_url_rejects_unknown_scheme() {
let err = build_connect_url("ftp://nope/", "w").unwrap_err();
assert!(matches!(err, WsClientError::Transport(_)));
}
#[test]
fn build_connect_url_rejects_invalid_url() {
let err = build_connect_url("not a url", "w").unwrap_err();
assert!(matches!(err, WsClientError::Transport(_)));
}
// --- close_frame_to_error: close code mapping ---
#[test]
fn close_frame_4001_maps_to_auth_failed() {
let frame = CloseFrame {
code: CloseCode::Library(4001),
reason: "bad token".into(),
};
let err = close_frame_to_error(Some(frame));
assert!(matches!(err, WsClientError::AuthFailed { .. }));
}
#[test]
fn close_frame_other_codes_map_to_connection_closed() {
let frame = CloseFrame {
code: CloseCode::Normal,
reason: "bye".into(),
};
let err = close_frame_to_error(Some(frame));
assert!(matches!(err, WsClientError::ConnectionClosed));
}
#[test]
fn close_frame_missing_maps_to_connection_closed() {
let err = close_frame_to_error(None);
assert!(matches!(err, WsClientError::ConnectionClosed));
}
// --- From<TError> conversion ---
#[test]
fn transport_error_round_trips_through_from_impl() {
let inner = TError::AlreadyClosed;
let mapped: WsClientError = inner.into();
assert!(matches!(mapped, WsClientError::ConnectionClosed));
}
// `capture` runs the closure and returns the log output it produced as text
// (inferred from how the tests below use it; defined in `crate::test_support`).
use crate::test_support::capture;
// --- classify_incoming: classification result and log level ---
#[test]
fn classify_rejects_binary_frame_with_warn() {
let logs = capture(|| {
let step = classify_incoming(Ok(Message::Binary(vec![1, 2, 3].into())));
assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
});
assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
assert!(
logs.contains("studio_worker::ws::client"),
"expected target, got: {logs}"
);
assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
assert!(logs.contains("binary"), "expected reason: {logs}");
}
#[test]
fn classify_warns_on_unparseable_text_frame() {
let logs = capture(|| {
let step = classify_incoming(Ok(Message::Text("not json".into())));
assert!(matches!(step, RecvStep::Fail(WsClientError::Protocol(_))));
});
assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
assert!(logs.contains("op=\"recv\""), "expected op field: {logs}");
}
#[test]
fn classify_warns_on_4001_close_frame() {
let logs = capture(|| {
let frame = CloseFrame {
code: CloseCode::Library(4001),
reason: "invalid auth token".into(),
};
let step = classify_incoming(Ok(Message::Close(Some(frame))));
assert!(matches!(
step,
RecvStep::Closed(WsClientError::AuthFailed { .. })
));
});
assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
assert!(logs.contains("auth failed"), "expected reason: {logs}");
}
// A normal close is routine and must be logged at DEBUG only.
#[test]
fn classify_debug_logs_on_normal_close_frame() {
let logs = capture(|| {
let frame = CloseFrame {
code: CloseCode::Normal,
reason: "bye".into(),
};
let step = classify_incoming(Ok(Message::Close(Some(frame))));
assert!(matches!(
step,
RecvStep::Closed(WsClientError::ConnectionClosed)
));
});
assert!(logs.contains("DEBUG"), "expected DEBUG, got: {logs}");
assert!(!logs.contains("WARN"), "normal close must not warn: {logs}");
assert!(logs.contains("server closed"), "expected message: {logs}");
}
#[test]
fn classify_yields_valid_frame_without_warning() {
let logs = capture(|| {
let json = serde_json::json!({ "type": "heartbeatAck" }).to_string();
let step = classify_incoming(Ok(Message::Text(json.into())));
assert!(matches!(
step,
RecvStep::Yield(WorkerOutbound::HeartbeatAck)
));
});
assert!(
!logs.contains("WARN"),
"a valid frame should not warn: {logs}"
);
}
// Ping/Pong fall through to the catch-all arm and are skipped.
#[test]
fn classify_skips_control_frames() {
assert!(matches!(
classify_incoming(Ok(Message::Ping(Vec::new().into()))),
RecvStep::Skip
));
assert!(matches!(
classify_incoming(Ok(Message::Pong(Vec::new().into()))),
RecvStep::Skip
));
}
// --- frame_label / log_send_error ---
// Builds one value of every `WorkerInbound` variant and checks its label.
#[test]
fn frame_label_names_every_inbound_variant() {
use crate::types::WorkerCapabilities;
let caps = WorkerCapabilities {
machine_name: String::new(),
username: String::new(),
agent_version: String::new(),
engine: String::new(),
vram_total_gb: 0.0,
vram_threshold_gb: 0.0,
auto_enabled: false,
auto_start: false,
supported_models: vec![],
task_kinds: vec![],
supported_models_per_kind: Default::default(),
};
assert_eq!(
frame_label(&WorkerInbound::Hello(crate::ws::types::HelloFrame {
auth_token: String::new(),
capabilities: caps.clone(),
})),
"hello"
);
assert_eq!(
frame_label(&WorkerInbound::Heartbeat {
capabilities: caps,
current_job_id: None,
}),
"heartbeat"
);
assert_eq!(
frame_label(&WorkerInbound::Accept { job_id: "j".into() }),
"accept"
);
assert_eq!(
frame_label(&WorkerInbound::Reject {
job_id: "j".into(),
reason: "r".into(),
}),
"reject"
);
assert_eq!(
frame_label(&WorkerInbound::CompleteJson {
job_id: "j".into(),
result: serde_json::Value::Null,
prompt: None,
}),
"completeJson"
);
assert_eq!(
frame_label(&WorkerInbound::Fail {
job_id: "j".into(),
error: "e".into(),
retryable: true,
}),
"fail"
);
assert_eq!(
frame_label(&WorkerInbound::LogBatch { entries: vec![] }),
"logBatch"
);
assert_eq!(frame_label(&WorkerInbound::ReadyForMore), "readyForMore");
}
#[test]
fn send_error_logs_warn_with_frame_label() {
let logs = capture(|| {
log_send_error(
&WorkerInbound::Accept {
job_id: "j-1".into(),
},
&WsClientError::ConnectionClosed,
);
});
assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
assert!(logs.contains("op=\"send\""), "expected op field: {logs}");
assert!(
logs.contains("frame=\"accept\""),
"expected frame label: {logs}"
);
}
// --- connect / connect_inner against real sockets ---
// The listener accepts the TCP connection but never answers the upgrade, so
// the handshake can only end through the 150 ms timeout passed below.
#[tokio::test]
async fn connect_times_out_against_a_stalling_upgrade() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
let _accepted = listener.accept().await; tokio::time::sleep(Duration::from_secs(30)).await;
});
let url = format!("http://{addr}/graphics/api");
let started = Instant::now();
let result = connect_inner(&url, "w", "tok", Duration::from_millis(150)).await;
assert!(
matches!(result, Err(WsClientError::Transport(_))),
"expected a transport timeout, got {result:?}"
);
assert!(
started.elapsed() < Duration::from_secs(2),
"connect must time out promptly, took {:?}",
started.elapsed()
);
}
// Uses a current-thread runtime inside `capture` so the log events of the
// async `connect` call are emitted within the captured closure.
// NOTE(review): relies on nothing listening on 127.0.0.1:1 — confirm for CI.
#[test]
fn connect_failure_logs_warn_breadcrumb() {
let logs = capture(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let result = rt.block_on(connect("http://127.0.0.1:1/graphics/api", "w-err", "tok"));
assert!(result.is_err(), "connect to a dead port should fail");
});
assert!(logs.contains("WARN"), "expected WARN, got: {logs}");
assert!(logs.contains("op=\"connect\""), "expected op field: {logs}");
assert!(
logs.contains("websocket connect failed"),
"expected message: {logs}"
);
assert!(
logs.contains("worker_id=\"w-err\""),
"expected worker_id field: {logs}"
);
}
}