use std::convert::TryFrom;
use std::time::Duration;
use std::sync::Arc;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::http::StatusCode;
use tokio_tungstenite::tungstenite::protocol::{frame::coding::CloseCode, CloseFrame};
use tokio_tungstenite::tungstenite::{Error as TError, Message};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use url::Url;
use crate::ws::types::{WorkerInbound, WorkerOutbound};
/// Subprotocol requested via the `Sec-WebSocket-Protocol` header on the upgrade request.
pub const SUBPROTOCOL: &str = "studio-worker-v1";

/// Path prefix under which the worker endpoints live; appended to the base
/// URL's path by `build_connect_url` unless the path already ends with it.
const API_PREFIX: &str = "/graphics/api";

/// Failures surfaced by the worker websocket client.
#[derive(Debug, thiserror::Error)]
pub enum WsClientError {
    /// The server rejected the worker: a 401 answer to the upgrade request,
    /// or a close frame carrying code 4001.
    #[error("auth failed: {reason}")]
    AuthFailed {
        /// Human-readable description of how the rejection was observed.
        reason: String,
    },

    /// The websocket is closed: a close frame other than 4001 arrived, or
    /// tungstenite reported the connection as closed / already closed.
    #[error("connection closed by server")]
    ConnectionClosed,

    /// URL/header construction failed, or any other tungstenite error (stringified).
    #[error("ws transport error: {0}")]
    Transport(String),

    /// Frame-level problem: JSON (de)serialization failed or an unexpected binary frame arrived.
    #[error("protocol error: {0}")]
    Protocol(String),
}

/// Shorthand for results produced by this websocket client.
pub type WsResult<T> = std::result::Result<T, WsClientError>;
impl From<TError> for WsClientError {
fn from(value: TError) -> Self {
match value {
TError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => {
WsClientError::AuthFailed {
reason: "401 on websocket upgrade".to_string(),
}
}
TError::ConnectionClosed | TError::AlreadyClosed => WsClientError::ConnectionClosed,
other => WsClientError::Transport(other.to_string()),
}
}
}
/// Builds the websocket URL a worker connects to.
///
/// * `http`/`https` base URLs are rewritten to `ws`/`wss`; `ws`/`wss` are kept
///   unchanged; any other scheme is rejected.
/// * [`API_PREFIX`] is appended to the base path (trailing slashes trimmed)
///   unless the path already ends with it.
/// * `workers/{worker_id}/connect` is appended below the prefix. `worker_id` is
///   pushed as a single percent-encoded path segment, so characters such as
///   `/` or `%` inside the id cannot change which route is addressed.
///
/// Only the scheme and path are rewritten; host, port and query of `base_url`
/// are carried over.
///
/// # Errors
///
/// Returns [`WsClientError::Transport`] if `base_url` cannot be parsed, uses an
/// unsupported scheme, or cannot be rewritten into a websocket URL.
fn build_connect_url(base_url: &str, worker_id: &str) -> WsResult<Url> {
    let mut url = Url::parse(base_url)
        .map_err(|e| WsClientError::Transport(format!("invalid base url: {e}")))?;

    let new_scheme = match url.scheme() {
        "http" => Some("ws"),
        "https" => Some("wss"),
        "ws" | "wss" => None,
        other => {
            return Err(WsClientError::Transport(format!(
                "unsupported scheme: {other}"
            )))
        }
    };
    if let Some(scheme) = new_scheme {
        url.set_scheme(scheme)
            .map_err(|_| WsClientError::Transport("set_scheme failed".to_string()))?;
    }

    let trimmed_path = url.path().trim_end_matches('/');
    let prefix_path = if trimmed_path.ends_with(API_PREFIX) {
        trimmed_path.to_string()
    } else {
        format!("{trimmed_path}{API_PREFIX}")
    };
    url.set_path(&prefix_path);

    // `extend` percent-encodes every segment (including `/` and `%`), whereas
    // splicing `worker_id` into a `set_path` string would let a `/` in the id
    // introduce extra path segments.
    url.path_segments_mut()
        .map_err(|_| {
            WsClientError::Transport("base url has no hierarchical path".to_string())
        })?
        .extend(["workers", worker_id, "connect"]);
    Ok(url)
}
/// Opens the worker websocket for `worker_id` on the server at `base_url`.
///
/// The upgrade request carries `Authorization: Bearer {auth_token}` and asks
/// for the [`SUBPROTOCOL`] subprotocol. The target URL is derived with
/// `build_connect_url`.
///
/// # Errors
///
/// * [`WsClientError::Transport`] if `base_url` is invalid, the token is not a
///   valid header value, or the connection/handshake fails;
/// * [`WsClientError::AuthFailed`] if the server answers the upgrade with HTTP 401.
pub async fn connect(base_url: &str, worker_id: &str, auth_token: &str) -> WsResult<WsClient> {
    let url = build_connect_url(base_url, worker_id)?;
    let mut request = url.as_str().into_client_request()?;

    let mut auth_value = HeaderValue::try_from(format!("Bearer {auth_token}"))
        .map_err(|e| WsClientError::Transport(format!("invalid auth header: {e}")))?;
    // Keep the bearer token out of `Debug` output of the request/header map
    // (e.g. if the request or its headers are ever logged).
    auth_value.set_sensitive(true);

    let headers = request.headers_mut();
    headers.insert("Authorization", auth_value);
    headers.insert(
        "Sec-WebSocket-Protocol",
        HeaderValue::from_static(SUBPROTOCOL),
    );

    let (stream, _response) = tokio_tungstenite::connect_async(request).await?;
    let (sink, source) = stream.split();
    Ok(WsClient {
        sink,
        source,
        closed: false,
    })
}
/// Write half of the underlying (possibly TLS) websocket stream.
type WsSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read half of the underlying (possibly TLS) websocket stream.
type WsSource = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

/// A connected worker websocket, created by [`connect`].
///
/// Offers `send` / `recv` / `close` on the combined stream, or can be turned
/// into independent halves with [`WsClient::split`].
// No `missing_debug_implementations` suppression is needed: `WsClient` has a
// hand-written `Debug` impl that reports only the `closed` flag.
pub struct WsClient {
    sink: WsSink,
    source: WsSource,
    // Once true, `recv` returns `Ok(None)` without polling and `close` is a no-op.
    closed: bool,
}
impl WsClient {
pub fn split(self) -> (WsSender, WsReceiver) {
let sink = Arc::new(Mutex::new(self.sink));
(
WsSender { sink },
WsReceiver {
source: self.source,
closed: false,
},
)
}
}
/// Cloneable sending half produced by [`WsClient::split`].
///
/// All clones share one sink behind an async mutex, so concurrent senders are
/// serialized one message at a time.
#[derive(Clone)]
// No `Debug` impl is provided for this handle.
#[allow(missing_debug_implementations)]
pub struct WsSender {
    sink: Arc<Mutex<WsSink>>,
}

impl WsSender {
    /// Serializes `frame` to JSON and sends it as one text message.
    ///
    /// # Errors
    ///
    /// [`WsClientError::Protocol`] if serialization fails; otherwise any
    /// transport error, converted through `From<TError>`.
    pub async fn send(&self, frame: &WorkerInbound) -> WsResult<()> {
        // Encode before taking the lock so other senders are not held up by JSON work.
        let json = match serde_json::to_string(frame) {
            Ok(json) => json,
            Err(err) => return Err(WsClientError::Protocol(err.to_string())),
        };
        let message = Message::Text(json.into());
        Ok(self.sink.lock().await.send(message).await?)
    }

    /// Best-effort close: sends a close frame with `code` and `reason`.
    ///
    /// The write is bounded to five seconds (acquiring the shared sink is not).
    /// Both a timeout and a send failure are ignored, and `Ok(())` is always returned.
    pub async fn close(&self, code: u16, reason: &str) -> WsResult<()> {
        let goodbye = Message::Close(Some(CloseFrame {
            code: CloseCode::from(code),
            reason: reason.to_owned().into(),
        }));
        let mut sink = self.sink.lock().await;
        let _ = tokio::time::timeout(Duration::from_secs(5), sink.send(goodbye)).await;
        Ok(())
    }
}
/// Receiving half produced by [`WsClient::split`].
// No `Debug` impl is provided for this handle.
#[allow(missing_debug_implementations)]
pub struct WsReceiver {
    source: WsSource,
    // Once true, `recv` returns `Ok(None)` without polling the stream.
    closed: bool,
}

impl WsReceiver {
    /// Waits for the next JSON text message from the server.
    ///
    /// Other non-text, non-binary, non-close messages (ping/pong etc.) are
    /// skipped. Returns `Ok(None)` once the stream has ended, or on any call
    /// after a close was already reported.
    ///
    /// # Errors
    ///
    /// * [`WsClientError::Protocol`] for a binary frame or a text payload that
    ///   does not decode as `WorkerOutbound`;
    /// * the result of `close_frame_to_error` when a close frame arrives. This is
    ///   reported once, and later calls return `Ok(None)`;
    /// * transport errors converted through `From<TError>`.
    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
        while !self.closed {
            match self.source.next().await {
                None => self.closed = true,
                Some(Err(err)) => return Err(err.into()),
                Some(Ok(Message::Text(body))) => {
                    let decoded = serde_json::from_str::<WorkerOutbound>(&body)
                        .map_err(|err| WsClientError::Protocol(err.to_string()))?;
                    return Ok(Some(decoded));
                }
                Some(Ok(Message::Binary(_))) => {
                    return Err(WsClientError::Protocol(
                        "unexpected binary frame".to_string(),
                    ));
                }
                Some(Ok(Message::Close(close))) => {
                    self.closed = true;
                    return Err(close_frame_to_error(close));
                }
                Some(Ok(_)) => {}
            }
        }
        Ok(None)
    }
}
impl std::fmt::Debug for WsClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WsClient")
.field("closed", &self.closed)
.finish()
}
}
impl WsClient {
    /// Serializes `frame` to JSON and sends it as one text message.
    ///
    /// # Errors
    ///
    /// [`WsClientError::Protocol`] if serialization fails; otherwise any
    /// transport error, converted through `From<TError>`.
    pub async fn send(&mut self, frame: &WorkerInbound) -> WsResult<()> {
        let json = serde_json::to_string(frame)
            .map_err(|err| WsClientError::Protocol(err.to_string()))?;
        Ok(self.sink.send(Message::Text(json.into())).await?)
    }

    /// Waits for the next JSON text message from the server.
    ///
    /// Messages other than text, binary and close are skipped. Returns
    /// `Ok(None)` when the stream ends, or on any call after the client was
    /// closed or a close frame was reported.
    ///
    /// # Errors
    ///
    /// * [`WsClientError::Protocol`] for a binary frame or undecodable payload;
    /// * the result of `close_frame_to_error` for a close frame (reported once);
    /// * transport errors converted through `From<TError>`.
    pub async fn recv(&mut self) -> WsResult<Option<WorkerOutbound>> {
        if self.closed {
            return Ok(None);
        }
        loop {
            let message = match self.source.next().await {
                Some(result) => result?,
                None => {
                    self.closed = true;
                    return Ok(None);
                }
            };
            match message {
                Message::Text(body) => {
                    return serde_json::from_str::<WorkerOutbound>(&body)
                        .map(Some)
                        .map_err(|err| WsClientError::Protocol(err.to_string()));
                }
                Message::Binary(_) => {
                    return Err(WsClientError::Protocol(
                        "unexpected binary frame".to_string(),
                    ));
                }
                Message::Close(close) => {
                    self.closed = true;
                    return Err(close_frame_to_error(close));
                }
                _ => {}
            }
        }
    }

    /// Best-effort close: sends a close frame with `code` and `reason` once.
    ///
    /// A no-op if the client is already closed. The write is bounded to five
    /// seconds, and both a timeout and a send failure are ignored. Always
    /// returns `Ok(())`.
    pub async fn close(&mut self, code: u16, reason: &str) -> WsResult<()> {
        if !self.closed {
            // Mark closed up front so the close is attempted at most once, even
            // if the send below fails or times out.
            self.closed = true;
            let goodbye = Message::Close(Some(CloseFrame {
                code: CloseCode::from(code),
                reason: reason.to_owned().into(),
            }));
            let _ = tokio::time::timeout(Duration::from_secs(5), self.sink.send(goodbye)).await;
        }
        Ok(())
    }
}
/// Translates a close frame received from the server into the error reported to callers.
///
/// Close code 4001 becomes [`WsClientError::AuthFailed`], with the server's
/// reason included (the server presumably uses this application code for
/// rejected credentials — confirm against the server). Any other code, or a
/// close without a frame, becomes [`WsClientError::ConnectionClosed`].
fn close_frame_to_error(frame: Option<CloseFrame>) -> WsClientError {
    const AUTH_REJECTED: u16 = 4001;
    match frame {
        Some(close) if u16::from(close.code) == AUTH_REJECTED => WsClientError::AuthFailed {
            reason: format!("server closed 4001: {}", close.reason),
        },
        _ => WsClientError::ConnectionClosed,
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for URL construction and error mapping; no network access is needed.
    use super::*;

    #[test]
    fn build_connect_url_http_to_ws() {
        let url = build_connect_url("http://api.example/graphics/api", "w-1").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/graphics/api/workers/w-1/connect");
    }

    #[test]
    fn build_connect_url_https_to_wss() {
        let url = build_connect_url("https://api.example/graphics/api/", "w-2").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.path(), "/graphics/api/workers/w-2/connect");
    }

    #[test]
    fn build_connect_url_appends_graphics_api_prefix_when_missing() {
        let url = build_connect_url("http://localhost:9790", "w-3").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/graphics/api/workers/w-3/connect");
    }

    #[test]
    fn build_connect_url_preserves_existing_ws_scheme() {
        let url = build_connect_url("ws://localhost:9790/x", "w").unwrap();
        assert_eq!(url.scheme(), "ws");
        assert_eq!(url.path(), "/x/graphics/api/workers/w/connect");
    }

    #[test]
    fn build_connect_url_preserves_existing_wss_scheme() {
        let url = build_connect_url("wss://api.example/graphics/api", "w").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.path(), "/graphics/api/workers/w/connect");
    }

    #[test]
    fn build_connect_url_keeps_host_port_and_query() {
        let url = build_connect_url("https://api.example:8443/?tenant=a", "w").unwrap();
        assert_eq!(url.scheme(), "wss");
        assert_eq!(url.host_str(), Some("api.example"));
        assert_eq!(url.port(), Some(8443));
        assert_eq!(url.query(), Some("tenant=a"));
        assert_eq!(url.path(), "/graphics/api/workers/w/connect");
    }

    #[test]
    fn build_connect_url_rejects_unknown_scheme() {
        let err = build_connect_url("ftp://nope/", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    #[test]
    fn build_connect_url_rejects_invalid_url() {
        let err = build_connect_url("not a url", "w").unwrap_err();
        assert!(matches!(err, WsClientError::Transport(_)));
    }

    #[test]
    fn close_frame_4001_maps_to_auth_failed() {
        let frame = CloseFrame {
            code: CloseCode::Library(4001),
            reason: "bad token".into(),
        };
        match close_frame_to_error(Some(frame)) {
            WsClientError::AuthFailed { reason } => {
                assert_eq!(reason, "server closed 4001: bad token");
            }
            other => panic!("expected AuthFailed, got {other:?}"),
        }
    }

    #[test]
    fn close_frame_other_codes_map_to_connection_closed() {
        let frame = CloseFrame {
            code: CloseCode::Normal,
            reason: "bye".into(),
        };
        let err = close_frame_to_error(Some(frame));
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn close_frame_missing_maps_to_connection_closed() {
        let err = close_frame_to_error(None);
        assert!(matches!(err, WsClientError::ConnectionClosed));
    }

    #[test]
    fn transport_error_round_trips_through_from_impl() {
        let inner = TError::AlreadyClosed;
        let mapped: WsClientError = inner.into();
        assert!(matches!(mapped, WsClientError::ConnectionClosed));
    }

    #[test]
    fn connection_closed_error_maps_to_connection_closed() {
        let mapped: WsClientError = TError::ConnectionClosed.into();
        assert!(matches!(mapped, WsClientError::ConnectionClosed));
    }

    #[test]
    fn io_error_maps_to_transport() {
        let io = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "boom");
        let mapped: WsClientError = TError::Io(io).into();
        assert!(matches!(mapped, WsClientError::Transport(_)));
    }

    #[test]
    fn error_display_messages_are_stable() {
        assert_eq!(
            WsClientError::ConnectionClosed.to_string(),
            "connection closed by server"
        );
        assert_eq!(
            WsClientError::AuthFailed {
                reason: "x".to_string()
            }
            .to_string(),
            "auth failed: x"
        );
        assert_eq!(
            WsClientError::Protocol("p".to_string()).to_string(),
            "protocol error: p"
        );
        assert_eq!(
            WsClientError::Transport("t".to_string()).to_string(),
            "ws transport error: t"
        );
    }
}