#[cfg(not(feature = "browser"))]
use bytes::Bytes;
#[cfg(not(feature = "browser"))]
use futures::TryStreamExt;
use futures::{SinkExt, StreamExt as _};
use futures_channel::mpsc;
use http::uri::{InvalidUri, Scheme, Uri};
use spacetimedb_client_api_messages::websocket as ws;
use spacetimedb_lib::{bsatn, ConnectionId};
#[cfg(not(feature = "browser"))]
use std::fs::File;
#[cfg(not(feature = "browser"))]
use std::io::Write;
#[cfg(not(feature = "browser"))]
use std::mem;
use std::sync::Arc;
#[cfg(not(feature = "browser"))]
use std::sync::Mutex;
#[cfg(not(feature = "browser"))]
use std::time::Duration;
use thiserror::Error;
#[cfg(not(feature = "browser"))]
use tokio::{net::TcpStream, runtime, task::JoinHandle, time::Instant};
#[cfg(not(feature = "browser"))]
use tokio_tungstenite::{
connect_async_with_config,
tungstenite::client::IntoClientRequest,
tungstenite::protocol::{Message as WebSocketMessage, WebSocketConfig},
MaybeTlsStream, WebSocketStream,
};
#[cfg(feature = "browser")]
use tokio_tungstenite_wasm::{Message as WebSocketMessage, WebSocketStream};
use crate::compression::decompress_server_message;
#[cfg(not(feature = "browser"))]
use crate::db_connection::debug_log;
use crate::metrics::CLIENT_METRICS;
/// Error type of the WebSocket library used on native (non-`browser`) builds;
/// stored in [`WsError::Tungstenite`].
#[cfg(not(feature = "browser"))]
type TokioTungsteniteError = tokio_tungstenite::tungstenite::Error;
/// Error type of the WebSocket library used when the `browser` feature is enabled;
/// stored in [`WsError::Tungstenite`].
#[cfg(feature = "browser")]
type TokioTungsteniteError = tokio_tungstenite_wasm::Error;
/// Errors raised while turning a user-supplied host [`Uri`] into the WebSocket
/// URI used to connect to a database.
///
/// The `http` error sources are wrapped in `Arc` so that this enum can derive
/// `Clone` (presumably because those error types are not `Clone` themselves).
#[derive(Error, Debug, Clone)]
pub enum UriError {
    /// The host URI used a scheme other than `http`, `https`, `ws` or `wss`.
    #[error("Unknown URI scheme {scheme}, expected http, https, ws or wss")]
    UnknownUriScheme { scheme: String },
    /// The host URI already had a query string; the connection URI's query is
    /// generated from scratch, so a pre-existing one is rejected.
    #[error("Expected a URI without a query part, but found {query}")]
    UnexpectedQuery { query: String },
    /// The assembled path-and-query could not be parsed.
    #[error(transparent)]
    InvalidUri {
        source: Arc<http::uri::InvalidUri>,
    },
    /// The assembled URI parts could not be combined into a valid [`Uri`].
    #[error(transparent)]
    InvalidUriParts {
        source: Arc<http::uri::InvalidUriParts>,
    },
}
/// Errors produced while establishing or using a database WebSocket connection.
///
/// Non-`Clone` error sources are wrapped in `Arc` so the enum can derive `Clone`.
#[derive(Error, Debug, Clone)]
pub enum WsError {
    /// Building the connection URI failed.
    #[error(transparent)]
    UriError(#[from] UriError),
    /// The WebSocket library reported an error for the connection to `uri`.
    #[error("Error in WebSocket connection with {uri}: {source}")]
    Tungstenite {
        uri: Uri,
        #[source]
        source: Arc<TokioTungsteniteError>,
    },
    /// A raw server message had no bytes at all, so it lacked even the leading
    /// compression flag byte.
    #[error("Received empty raw message, but valid messages always start with a one-byte compression flag")]
    EmptyMessage,
    /// A (decompressed) server message could not be BSATN-decoded.
    #[error("Failed to deserialize WebSocket message: {source}")]
    DeserializeMessage {
        #[source]
        source: bsatn::DecodeError,
    },
    /// A server message could not be decompressed; `scheme` names the algorithm.
    #[error("Failed to decompress WebSocket message with {scheme}: {source}")]
    Decompress {
        scheme: &'static str,
        #[source]
        source: Arc<std::io::Error>,
    },
    /// A server message announced a compression scheme this client does not know
    /// (presumably via its leading flag byte).
    #[error("Unrecognized compression scheme: {scheme:#x}")]
    UnknownCompressionScheme { scheme: u8 },
    /// Exchanging the auth token for a WebSocket token failed (`browser` only).
    #[cfg(feature = "browser")]
    #[error("Token verification error: {0}")]
    TokenVerification(String),
}
/// An open WebSocket connection to a single database.
pub(crate) struct WsConnection {
    /// Database name; used as the label for this connection's client metrics.
    db_name: Box<str>,
    /// Underlying socket on native targets (plain TCP or TLS).
    #[cfg(not(feature = "browser"))]
    sock: WebSocketStream<MaybeTlsStream<TcpStream>>,
    /// Underlying socket when the `browser` feature is enabled.
    #[cfg(feature = "browser")]
    sock: WebSocketStream,
}
/// Normalize a host URI scheme to a WebSocket scheme.
///
/// `ws`/`wss` pass through unchanged, `http` becomes `ws`, `https` becomes `wss`,
/// and a missing scheme defaults to `ws`.
///
/// # Errors
///
/// [`UriError::UnknownUriScheme`] for any other scheme.
fn parse_scheme(scheme: Option<Scheme>) -> Result<Scheme, UriError> {
    let Some(given) = scheme else {
        return Ok("ws".parse().unwrap());
    };
    match given.as_str() {
        "ws" | "wss" => Ok(given),
        "http" => Ok("ws".parse().unwrap()),
        "https" => Ok("wss".parse().unwrap()),
        other => Err(UriError::UnknownUriScheme { scheme: other.into() }),
    }
}
/// Options that are encoded into the query string of the subscribe URI.
#[derive(Clone, Copy, Default)]
pub(crate) struct WsParams {
    /// Requested compression, sent as the `compression` query parameter
    /// (presumably applied by the server to the messages it sends).
    pub compression: ws::common::Compression,
    /// Sent as `confirmed=true` / `confirmed=false` when `Some`; the parameter
    /// is omitted from the URI entirely when `None`.
    pub confirmed: Option<bool>,
}
/// Build the subscribe URI for the `browser` build.
///
/// The short-lived WebSocket `token`, if any, travels in the URI's query string.
#[cfg(feature = "browser")]
fn make_uri(
    host: Uri,
    db_name: &str,
    connection_id: Option<ConnectionId>,
    params: WsParams,
    token: Option<&str>,
) -> Result<Uri, UriError> {
    make_uri_impl(host, db_name, connection_id, params, token)
}

/// Build the subscribe URI for native targets.
///
/// No token is placed in the URI; native connections authenticate through an
/// `Authorization` header added by `make_request` instead.
#[cfg(not(feature = "browser"))]
fn make_uri(host: Uri, db_name: &str, connection_id: Option<ConnectionId>, params: WsParams) -> Result<Uri, UriError> {
    let token_in_query: Option<&str> = None;
    make_uri_impl(host, db_name, connection_id, params, token_in_query)
}
/// Build the WebSocket URI for subscribing to `db_name` on `host`.
///
/// The scheme is normalized by [`parse_scheme`] (`http` → `ws`, `https` → `wss`,
/// missing → `ws`). Any path already on `host` is kept as a base path, and
/// `v1/database/<db_name>/subscribe` is appended to it. The query string always
/// contains `compression`, plus `connection_id`, `confirmed` and `token` when the
/// corresponding argument is present.
///
/// `token` is percent-encoded so that characters such as `+`, `&`, `=` or `#`
/// cannot corrupt the query string. JWT-style tokens contain only unreserved
/// characters and are emitted unchanged. `db_name` is inserted verbatim.
///
/// # Errors
///
/// - [`UriError::UnknownUriScheme`] for schemes other than `ws`, `wss`, `http`, `https`.
/// - [`UriError::UnexpectedQuery`] if `host` already carries a query string.
/// - [`UriError::InvalidUri`] / [`UriError::InvalidUriParts`] if the `http` crate
///   rejects the assembled path-and-query or URI.
fn make_uri_impl(
    host: Uri,
    db_name: &str,
    connection_id: Option<ConnectionId>,
    params: WsParams,
    token: Option<&str>,
) -> Result<Uri, UriError> {
    let mut parts = host.into_parts();
    let scheme = parse_scheme(parts.scheme.take())?;
    parts.scheme = Some(scheme);

    // Keep the host's path as a base path; the query is generated below, so an
    // existing one would be ambiguous and is rejected.
    let mut path = if let Some(path_and_query) = parts.path_and_query {
        if let Some(query) = path_and_query.query() {
            return Err(UriError::UnexpectedQuery { query: query.into() });
        }
        path_and_query.path().to_string()
    } else {
        "/".to_string()
    };
    if !path.ends_with('/') {
        path.push('/');
    }
    path.push_str("v1/database/");
    path.push_str(db_name);
    path.push_str("/subscribe");

    // `compression` is always present, so every later parameter starts with `&`.
    match params.compression {
        ws::common::Compression::None => path.push_str("?compression=None"),
        ws::common::Compression::Gzip => path.push_str("?compression=Gzip"),
        ws::common::Compression::Brotli => path.push_str("?compression=Brotli"),
    };
    if let Some(cid) = connection_id {
        path.push_str("&connection_id=");
        path.push_str(&cid.to_hex());
    }
    if let Some(confirmed) = params.confirmed {
        path.push_str("&confirmed=");
        path.push_str(if confirmed { "true" } else { "false" });
    }
    if let Some(token) = token {
        path.push_str("&token=");
        push_percent_encoded(&mut path, token);
    }

    parts.path_and_query = Some(path.parse().map_err(|source: InvalidUri| UriError::InvalidUri {
        source: Arc::new(source),
    })?);
    Uri::from_parts(parts).map_err(|source| UriError::InvalidUriParts {
        source: Arc::new(source),
    })
}

/// Append `value` to `out`, percent-encoding (RFC 3986 §2.1) every byte outside
/// the "unreserved" set `A-Z a-z 0-9 - . _ ~`, so it is safe as a query value.
fn push_percent_encoded(out: &mut String, value: &str) {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
}
/// Build the WebSocket handshake request for native targets.
///
/// The request targets the URI from `make_uri`, advertises the binary v2
/// sub-protocol, and carries `token` (if any) as a bearer `Authorization` header.
///
/// # Errors
///
/// [`WsError::UriError`] if the URI cannot be built, or [`WsError::Tungstenite`]
/// if the WebSocket library cannot turn the URI into a client request.
#[cfg(not(feature = "browser"))]
fn make_request(
    host: Uri,
    db_name: &str,
    token: Option<&str>,
    connection_id: Option<ConnectionId>,
    params: WsParams,
) -> Result<http::Request<()>, WsError> {
    let uri = make_uri(host, db_name, connection_id, params)?;
    let mut request = match uri.clone().into_client_request() {
        Ok(request) => request,
        Err(source) => {
            return Err(WsError::Tungstenite {
                uri,
                source: Arc::new(source),
            })
        }
    };
    request_insert_protocol_header(&mut request);
    request_insert_auth_header(&mut request, token);
    Ok(request)
}
/// Set `Sec-WebSocket-Protocol` on `req` to the binary v2 client protocol,
/// replacing any value already present.
#[cfg(not(feature = "browser"))]
fn request_insert_protocol_header(req: &mut http::Request<()>) {
    let protocol = const { http::HeaderValue::from_static(ws::v2::BIN_PROTOCOL) };
    req.headers_mut().insert(http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
/// Attach `token`, if any, to `req` as an `Authorization: Bearer <token>` header.
///
/// The header value is marked sensitive so the bearer token is not revealed by
/// its `Debug` output. This does not change the bytes sent on the wire.
///
/// # Panics
///
/// Panics if `"Bearer <token>"` is not a valid HTTP header value, e.g. when the
/// token contains a control character such as a trailing newline.
#[cfg(not(feature = "browser"))]
fn request_insert_auth_header(req: &mut http::Request<()>, token: Option<&str>) {
    let Some(token) = token else {
        return;
    };
    let mut auth = http::HeaderValue::try_from(format!("Bearer {token}"))
        .expect("auth token must only contain characters that are valid in an HTTP header value");
    auth.set_sensitive(true);
    req.headers_mut().insert(http::header::AUTHORIZATION, auth);
}
/// Exchange `auth_token` for a WebSocket token (`browser` feature only).
///
/// The function POSTs to the `v1/identity/websocket-token` endpoint under `host`,
/// sending `auth_token` as a bearer `Authorization` header. It then returns the
/// string `token` field of the JSON response. The caller puts that token in the
/// connection URI's query string rather than in a header.
///
/// `host` may use any scheme accepted by [`parse_scheme`]. It is mapped to
/// `https` for `wss`/`https` and to `http` otherwise, because this is a plain
/// HTTP request.
///
/// # Errors
///
/// - [`WsError::UriError`] if `host` has an unknown scheme or a query part.
/// - [`WsError::TokenVerification`] if the request fails, the response status is
///   not OK, or the body is not JSON with a string `token` field.
#[cfg(feature = "browser")]
async fn fetch_ws_token(host: &Uri, auth_token: &str) -> Result<String, WsError> {
    use gloo_net::http::{Method, RequestBuilder};
    use js_sys::{Reflect, JSON};
    use wasm_bindgen::{JsCast, JsValue};

    let url = websocket_token_url(host)?;

    let gloo_to_ws_err = |e: gloo_net::Error| match e {
        gloo_net::Error::JsError(js_err) => WsError::TokenVerification(js_err.message),
        gloo_net::Error::SerdeError(e) => WsError::TokenVerification(e.to_string()),
        gloo_net::Error::GlooError(msg) => WsError::TokenVerification(msg),
    };
    // Prefer the JS `Error.message`, then a plain string, then the debug form.
    let js_to_ws_err = |e: JsValue| {
        if let Some(err) = e.dyn_ref::<js_sys::Error>() {
            WsError::TokenVerification(err.message().into())
        } else if let Some(s) = e.as_string() {
            WsError::TokenVerification(s)
        } else {
            WsError::TokenVerification(format!("{e:?}"))
        }
    };

    let res = RequestBuilder::new(&url)
        .method(Method::POST)
        .header("Authorization", &format!("Bearer {auth_token}"))
        .send()
        .await
        .map_err(gloo_to_ws_err)?;
    if !res.ok() {
        return Err(WsError::TokenVerification(format!(
            "HTTP error: {} {}",
            res.status(),
            res.status_text()
        )));
    }
    let body = res.text().await.map_err(gloo_to_ws_err)?;
    let json = JSON::parse(&body).map_err(js_to_ws_err)?;
    let token_js = Reflect::get(&json, &JsValue::from_str("token")).map_err(js_to_ws_err)?;
    token_js
        .as_string()
        .ok_or_else(|| WsError::TokenVerification("`token` parsing failed".into()))
}

/// Build the HTTP(S) URL of the `websocket-token` endpoint for `host`.
///
/// The scheme is validated with [`parse_scheme`] so this request accepts the same
/// hosts as the WebSocket connection. It is then mapped back to HTTP: `wss` gives
/// `https`, and anything else gives `http`. A base path on `host` is preserved,
/// with a `/` guaranteed before the endpoint.
#[cfg(feature = "browser")]
fn websocket_token_url(host: &Uri) -> Result<String, UriError> {
    let http_scheme = match parse_scheme(host.scheme().cloned())?.as_str() {
        "wss" => "https",
        _ => "http",
    };
    if let Some(query) = host.query() {
        return Err(UriError::UnexpectedQuery { query: query.into() });
    }
    let authority = host.authority().map_or("", |authority| authority.as_str());
    let mut base_path = host.path().to_owned();
    if !base_path.ends_with('/') {
        base_path.push('/');
    }
    Ok(format!("{http_scheme}://{authority}{base_path}v1/identity/websocket-token"))
}
/// Report the error in `$res`, if there is one, and otherwise do nothing.
///
/// The error is written as `"<cause>: <error:?>"` to the optional extra-logging
/// file through `debug_log`, and also emitted with `log::warn!`. `$cause` is
/// evaluated only when `$res` is an `Err`.
#[cfg(not(feature = "browser"))]
macro_rules! maybe_log_error {
    ($extra_logging:expr, $cause:expr, $res:expr) => {
        match $res {
            Ok(_) => {}
            Err(err) => {
                let cause = $cause;
                debug_log($extra_logging, |file| writeln!(file, "{}: {:?}", cause, err));
                log::warn!("{}: {:?}", cause, err);
            }
        }
    };
}
impl WsConnection {
    /// Open a WebSocket connection to database `db_name` on `host` (native targets).
    ///
    /// The handshake request comes from `make_request`. Its URI is built by
    /// `make_uri`, it advertises `ws::v2::BIN_PROTOCOL` as the sub-protocol, and
    /// `token`, when present, is sent as an `Authorization: Bearer` header.
    /// Frame- and message-size limits are disabled for this connection.
    ///
    /// # Errors
    ///
    /// [`WsError::UriError`] if the URI cannot be built; [`WsError::Tungstenite`]
    /// (carrying the request URI) if the request or handshake fails.
    ///
    /// # Panics
    ///
    /// Panics if `token` cannot be encoded as an HTTP header value.
    #[cfg(not(feature = "browser"))]
    pub(crate) async fn connect(
        host: Uri,
        db_name: &str,
        token: Option<&str>,
        connection_id: Option<ConnectionId>,
        params: WsParams,
    ) -> Result<Self, WsError> {
        let req = make_request(host, db_name, token, connection_id, params)?;
        // Keep a copy of the URI for the error, since `req` is consumed below.
        let uri = req.uri().clone();
        let (sock, _): (WebSocketStream<MaybeTlsStream<TcpStream>>, _) = connect_async_with_config(
            req,
            Some(WebSocketConfig::default().max_frame_size(None).max_message_size(None)),
            false,
        )
        .await
        .map_err(|source| WsError::Tungstenite {
            uri,
            source: Arc::new(source),
        })?;
        Ok(WsConnection {
            db_name: db_name.into(),
            sock,
        })
    }

    /// Open a WebSocket connection to database `db_name` on `host` (`browser` feature).
    ///
    /// When `token` is present, it is first exchanged for a WebSocket token via
    /// `fetch_ws_token`, and that token is placed in the URI's query string. The
    /// connection requests the `ws::v2::BIN_PROTOCOL` sub-protocol.
    ///
    /// # Errors
    ///
    /// Any error from `fetch_ws_token`, [`WsError::UriError`] if the URI cannot
    /// be built, or [`WsError::Tungstenite`] if the connection fails.
    #[cfg(feature = "browser")]
    pub(crate) async fn connect(
        host: Uri,
        db_name: &str,
        token: Option<&str>,
        connection_id: Option<ConnectionId>,
        params: WsParams,
    ) -> Result<Self, WsError> {
        let token = if let Some(auth_token) = token {
            Some(fetch_ws_token(&host, auth_token).await?)
        } else {
            None
        };
        let uri = make_uri(host, db_name, connection_id, params, token.as_deref())?;
        let sock = tokio_tungstenite_wasm::connect_with_protocols(&uri.to_string(), &[ws::v2::BIN_PROTOCOL])
            .await
            .map_err(|source| WsError::Tungstenite {
                uri,
                source: Arc::new(source),
            })?;
        Ok(WsConnection {
            db_name: db_name.into(),
            sock,
        })
    }

    /// Decode one binary frame received from the server.
    ///
    /// The payload is passed through `decompress_server_message`, and the result
    /// is then BSATN-decoded into a [`ws::v2::ServerMessage`].
    ///
    /// # Errors
    ///
    /// Errors from `decompress_server_message`, or [`WsError::DeserializeMessage`]
    /// if BSATN decoding fails.
    pub(crate) fn parse_response(bytes: &[u8]) -> Result<ws::v2::ServerMessage, WsError> {
        let bytes = &*decompress_server_message(bytes)?;
        bsatn::from_slice(bytes).map_err(|source| WsError::DeserializeMessage { source })
    }

    /// BSATN-encode `msg` as a binary WebSocket frame.
    ///
    /// # Panics
    ///
    /// Panics if BSATN serialization of `msg` fails.
    pub(crate) fn encode_message(msg: ws::v2::ClientMessage) -> WebSocketMessage {
        WebSocketMessage::Binary(bsatn::to_vec(&msg).unwrap().into())
    }

    /// Drive the socket until it closes, fails, or times out (native targets).
    ///
    /// - Binary frames are decoded with [`Self::parse_response`] and forwarded to
    ///   `incoming_messages`. Decode and forwarding failures are logged but are
    ///   not fatal.
    /// - Messages from `outgoing_messages` are encoded and sent. When that channel
    ///   ends, a close frame is sent, and the loop keeps reading until the stream ends.
    /// - Keepalive: on each `IDLE_TIMEOUT` tick, if nothing was received since the
    ///   previous tick, a ping is sent. If a later tick again finds the connection
    ///   idle while that ping is still unanswered, the connection is considered
    ///   timed out.
    ///
    /// Every received frame is counted in this database's receive metrics.
    /// Failures are written both to `extra_logging` (via `debug_log`) and to
    /// `log::warn!`.
    #[cfg(not(feature = "browser"))]
    async fn message_loop(
        mut self,
        incoming_messages: mpsc::UnboundedSender<ws::v2::ServerMessage>,
        outgoing_messages: mpsc::UnboundedReceiver<ws::v2::ClientMessage>,
        extra_logging: Option<Arc<Mutex<File>>>,
    ) {
        let websocket_received = CLIENT_METRICS.websocket_received.with_label_values(&self.db_name);
        let websocket_received_msg_size = CLIENT_METRICS
            .websocket_received_msg_size
            .with_label_values(&self.db_name);
        let record_metrics = |msg_size: usize| {
            websocket_received.inc();
            websocket_received_msg_size.observe(msg_size as f64);
        };

        const IDLE_TIMEOUT: Duration = Duration::from_secs(30);
        // The first tick fires one full interval after start, not immediately.
        let mut idle_timeout_interval = tokio::time::interval_at(Instant::now() + IDLE_TIMEOUT, IDLE_TIMEOUT);
        // `true` until any frame is received during the current interval.
        let mut idle = true;
        // `true` while a client ping is outstanding; cleared by a pong.
        let mut want_pong = false;

        // Becomes `None` once the sender side is dropped, which disables that
        // `select!` branch (its pattern no longer matches).
        let mut outgoing_messages = Some(outgoing_messages);
        loop {
            tokio::select! {
                incoming = self.sock.try_next() => match incoming {
                    Err(tokio_tungstenite::tungstenite::error::Error::ConnectionClosed) | Ok(None) => {
                        log::info!("Connection closed");
                        break;
                    },
                    Err(e) => {
                        maybe_log_error!(
                            &extra_logging,
                            "Error reading message from read WebSocket stream",
                            Result::<(), _>::Err(e)
                        );
                        break;
                    },
                    Ok(Some(WebSocketMessage::Binary(bytes))) => {
                        idle = false;
                        record_metrics(bytes.len());
                        match Self::parse_response(&bytes) {
                            Err(e) => maybe_log_error!(
                                &extra_logging,
                                "Error decoding WebSocketMessage::Binary payload",
                                Result::<(), _>::Err(e)
                            ),
                            Ok(msg) => maybe_log_error!(
                                &extra_logging,
                                "Error sending decoded message to incoming_messages queue",
                                incoming_messages.unbounded_send(msg)
                            ),
                        }
                    }
                    Ok(Some(WebSocketMessage::Ping(payload))) => {
                        log::trace!("received ping");
                        idle = false;
                        record_metrics(payload.len());
                    },
                    Ok(Some(WebSocketMessage::Pong(payload))) => {
                        log::trace!("received pong");
                        idle = false;
                        want_pong = false;
                        record_metrics(payload.len());
                    },
                    Ok(Some(other)) => {
                        debug_log(&extra_logging, |file| writeln!(file, "Unexpected WebSocket message {other:?}"));
                        log::warn!("Unexpected WebSocket message {other:?}");
                        idle = false;
                        record_metrics(other.len());
                    },
                },

                _ = idle_timeout_interval.tick() => {
                    // Start the next interval as idle, and act only if this one was.
                    if mem::replace(&mut idle, true) {
                        if want_pong {
                            debug_log(&extra_logging, |file| writeln!(file, "Connection timed out"));
                            log::warn!("Connection timed out");
                            break;
                        }
                        log::trace!("sending client ping");
                        let ping = WebSocketMessage::Ping(Bytes::new());
                        if let Err(e) = self.sock.send(ping).await {
                            debug_log(&extra_logging, |file| writeln!(file, "Error sending ping: {e:?}"));
                            log::warn!("Error sending ping: {e:?}");
                            break;
                        }
                        want_pong = true;
                    }
                },

                Some(outgoing) = async { Some(outgoing_messages.as_mut()?.next().await) } => match outgoing {
                    Some(outgoing) => {
                        let msg = Self::encode_message(outgoing);
                        if let Err(e) = self.sock.send(msg).await {
                            debug_log(&extra_logging, |file| writeln!(file, "Error sending outgoing message: {e:?}"));
                            log::warn!("Error sending outgoing message: {e:?}");
                            break;
                        }
                    }
                    None => {
                        // Sender dropped: start the close handshake, but keep
                        // reading until the server finishes closing the stream.
                        maybe_log_error!(&extra_logging, "Error sending close frame", SinkExt::close(&mut self.sock).await);
                        outgoing_messages = None;
                    }
                },
            }
        }
    }

    /// Spawn [`Self::message_loop`] on `runtime` (native targets).
    ///
    /// Returns the task's handle, the receiver of decoded server messages, and
    /// the sender for client messages. Dropping the sender makes the loop send a
    /// close frame.
    #[cfg(not(feature = "browser"))]
    pub(crate) fn spawn_message_loop(
        self,
        runtime: &runtime::Handle,
        extra_logging: Option<Arc<Mutex<File>>>,
    ) -> (
        JoinHandle<()>,
        mpsc::UnboundedReceiver<ws::v2::ServerMessage>,
        mpsc::UnboundedSender<ws::v2::ClientMessage>,
    ) {
        let (outgoing_send, outgoing_recv) = mpsc::unbounded();
        let (incoming_send, incoming_recv) = mpsc::unbounded();
        let handle = runtime.spawn(self.message_loop(incoming_send, outgoing_recv, extra_logging));
        (handle, incoming_recv, outgoing_send)
    }

    /// Split the socket and drive it from a task started with
    /// `wasm_bindgen_futures::spawn_local` (`browser` feature).
    ///
    /// Returns the receiver of decoded server messages and the sender for client
    /// messages. The task ends in any of these cases:
    /// - the connection closes or a read fails;
    /// - the incoming receiver is dropped;
    /// - a send fails;
    /// - the outgoing sender is dropped, after an attempt to send a close frame.
    ///
    /// Undecodable binary frames are logged and skipped.
    #[cfg(feature = "browser")]
    pub(crate) fn spawn_message_loop(
        self,
    ) -> (
        mpsc::UnboundedReceiver<ws::v2::ServerMessage>,
        mpsc::UnboundedSender<ws::v2::ClientMessage>,
    ) {
        let websocket_received = CLIENT_METRICS.websocket_received.with_label_values(&self.db_name);
        let websocket_received_msg_size = CLIENT_METRICS
            .websocket_received_msg_size
            .with_label_values(&self.db_name);
        let record_metrics = move |msg_size: usize| {
            websocket_received.inc();
            websocket_received_msg_size.observe(msg_size as f64);
        };

        let (outgoing_tx, outgoing_rx) = mpsc::unbounded::<ws::v2::ClientMessage>();
        let (incoming_tx, incoming_rx) = mpsc::unbounded::<ws::v2::ServerMessage>();

        let (mut ws_writer, ws_reader) = self.sock.split();

        wasm_bindgen_futures::spawn_local(async move {
            // `futures::select!` requires fused streams.
            let mut incoming = ws_reader.fuse();
            let mut outgoing = outgoing_rx.fuse();

            loop {
                futures::select! {
                    inbound = incoming.next() => match inbound {
                        Some(Err(tokio_tungstenite_wasm::Error::ConnectionClosed)) | None => {
                            gloo_console::log!("Connection closed");
                            break;
                        },
                        Some(Ok(WebSocketMessage::Binary(bytes))) => {
                            record_metrics(bytes.len());
                            match Self::parse_response(&bytes) {
                                Ok(msg) => if incoming_tx.unbounded_send(msg).is_err() {
                                    gloo_console::warn!("Incoming receiver dropped.");
                                    break;
                                },
                                Err(e) => {
                                    gloo_console::warn!(
                                        "Error decoding WebSocketMessage::Binary payload: ",
                                        format!("{:?}", e)
                                    );
                                },
                            }
                        },
                        Some(Ok(WebSocketMessage::Close(r))) => {
                            let reason: String = r
                                .map(|r| format!("{}:{:?}", r, r.code))
                                .unwrap_or_default();
                            gloo_console::warn!("Connection Closed.", reason);
                            let _ = ws_writer.close().await;
                            break;
                        },
                        Some(Err(e)) => {
                            gloo_console::warn!(
                                "Error reading message from read WebSocket stream: ",
                                format!("{:?}",e)
                            );
                            break;
                        },
                        Some(Ok(other)) => {
                            record_metrics(other.len());
                            gloo_console::warn!("Unexpected WebSocket message: ", format!("{:?}",other));
                        }
                    },

                    outbound = outgoing.next() => if let Some(client_msg) = outbound {
                        let raw = Self::encode_message(client_msg);
                        if let Err(e) = ws_writer.send(raw).await {
                            gloo_console::warn!("Error sending outgoing message:", format!("{:?}",e));
                            break;
                        }
                    } else {
                        // Sender dropped: try to close politely, then stop.
                        if let Err(e) = ws_writer.close().await {
                            gloo_console::warn!("Error sending close frame:", format!("{:?}", e));
                        }
                        break;
                    },
                }
            }
        });

        (incoming_rx, outgoing_tx)
    }
}