use std::fmt::Display;
use std::future::{poll_fn, Future};
use std::num::NonZeroUsize;
use std::panic;
use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use async_stream::stream;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use axum::Extension;
use axum_extra::TypedHeader;
use bytes::Bytes;
use bytestring::ByteString;
use derive_more::From;
use futures::{pin_mut, Sink, SinkExt, Stream, StreamExt};
use http::{HeaderValue, StatusCode};
use prometheus::IntGauge;
use scopeguard::{defer, ScopeGuard};
use serde::Deserialize;
use spacetimedb::client::messages::{
serialize, IdentityTokenMessage, SerializableMessage, SerializeBuffer, SwitchedServerMessage, ToProtocol,
};
use spacetimedb::client::{
ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageExecutionError, MessageHandleError,
MeteredReceiver, Protocol,
};
use spacetimedb::host::module_host::ClientConnectedError;
use spacetimedb::host::NoSuchModule;
use spacetimedb::util::spawn_rayon;
use spacetimedb::worker_metrics::WORKER_METRICS;
use spacetimedb::Identity;
use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
use spacetimedb_datastore::execution_context::WorkloadType;
use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl};
use std::time::Instant;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
use tokio::time::error::Elapsed;
use tokio::time::{sleep_until, timeout};
use tokio_tungstenite::tungstenite::Utf8Bytes;
use crate::auth::SpacetimeAuth;
use crate::util::serde::humantime_duration;
use crate::util::websocket::{
CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError,
};
use crate::util::{NameOrIdentity, XForwardedFor};
use crate::{log_and_500, ControlStateDelegate, NodeDelegate};
// clippy's `declare_interior_mutable_const` fires on `HeaderValue` (presumably
// because it wraps a `Bytes`, which uses atomics internally), warning that each
// use of such a const yields a fresh copy. That is intended here: the values are
// built with `HeaderValue::from_static` and no shared mutable state is involved.

/// Subprotocol header value for the text protocol; [`handle_websocket`] maps it
/// to [`Protocol::Text`].
#[allow(clippy::declare_interior_mutable_const)]
pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PROTOCOL);
/// Subprotocol header value for the binary protocol; [`handle_websocket`] maps it
/// to [`Protocol::Binary`].
#[allow(clippy::declare_interior_mutable_const)]
pub const BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::BIN_PROTOCOL);
/// Source of the [`WebSocketOptions`] applied to newly upgraded client sessions.
///
/// Required of the server state type `S` in [`handle_websocket`].
pub trait HasWebSocketOptions {
    /// Returns the websocket settings to use for a new connection.
    fn websocket_options(&self) -> WebSocketOptions;
}

/// State shared behind an `Arc` simply delegates to the wrapped value.
impl<T: HasWebSocketOptions> HasWebSocketOptions for Arc<T> {
    fn websocket_options(&self) -> WebSocketOptions {
        // `&Arc<T>` deref-coerces to `&T` at this call site.
        T::websocket_options(self)
    }
}
/// Path parameters of the subscribe endpoint.
#[derive(Deserialize)]
pub struct SubscribeParams {
    /// The database to connect to, by name or identity; resolved with
    /// `NameOrIdentity::resolve` in [`handle_websocket`].
    pub name_or_identity: NameOrIdentity,
}
/// Query-string parameters of the subscribe endpoint.
#[derive(Deserialize)]
pub struct SubscribeQueryParams {
    /// Caller-chosen connection ID. Internal: when absent a random one is
    /// generated, and [`handle_websocket`] logs that this parameter will be removed.
    pub connection_id: Option<ConnectionIdForUrl>,
    /// Compression for server messages; `Compression::default()` when omitted.
    #[serde(default)]
    pub compression: Compression,
    /// When `true`, the client config gets `tx_update_full: false`.
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub light: bool,
}
/// Builds a fresh [`ConnectionId`] from random little-endian bytes.
///
/// Used by [`handle_websocket`] when the client supplies no `connection_id`.
pub fn generate_random_connection_id() -> ConnectionId {
    let random_bytes = rand::random();
    ConnectionId::from_le_byte_array(random_bytes)
}
/// HTTP handler that upgrades a request into a websocket subscription session.
///
/// Validates the query parameters, resolves the target database and its leader,
/// and negotiates the websocket subprotocol (binary or text). The session itself
/// runs in a spawned task, which:
///
/// 1. completes the websocket upgrade;
/// 2. calls the module's `client_connected` reducer, which may reject the client;
/// 3. spawns the [`ClientConnection`] actor driven by [`ws_client_actor`];
/// 4. sends the client an [`IdentityTokenMessage`].
///
/// Failures inside the spawned task happen after the HTTP response and are only
/// logged.
///
/// # Errors
///
/// - `400 Bad Request` if the all-zeros connection ID is supplied, or no supported
///   subprotocol was requested.
/// - `404 Not Found` if the database or its leader does not exist.
/// - `500 Internal Server Error` (logged) if looking up the database, its leader,
///   or the module watcher fails.
/// - Any error returned by `NameOrIdentity::resolve`.
pub async fn handle_websocket<S>(
    State(ctx): State<S>,
    Path(SubscribeParams { name_or_identity }): Path<SubscribeParams>,
    Query(SubscribeQueryParams {
        connection_id,
        compression,
        light,
    }): Query<SubscribeQueryParams>,
    forwarded_for: Option<TypedHeader<XForwardedFor>>,
    Extension(auth): Extension<SpacetimeAuth>,
    ws: WebSocketUpgrade,
) -> axum::response::Result<impl IntoResponse>
where
    S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions,
{
    if connection_id.is_some() {
        log::debug!("The connection_id query parameter to the subscribe HTTP endpoint is internal and will be removed in a future version of SpacetimeDB.");
    }

    let connection_id = connection_id
        .map(ConnectionId::from)
        .unwrap_or_else(generate_random_connection_id);

    if connection_id == ConnectionId::ZERO {
        Err((
            StatusCode::BAD_REQUEST,
            "Invalid connection ID: the all-zeros ConnectionId is reserved.",
        ))?;
    }

    let db_identity = name_or_identity.resolve(&ctx).await?;

    let (res, ws_upgrade, protocol) =
        ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]);
    let protocol = protocol.ok_or((StatusCode::BAD_REQUEST, "no valid protocol selected"))?;
    let client_config = ClientConfig {
        protocol,
        compression,
        tx_update_full: !light,
    };

    // A failing control-state lookup is a server error: report it as a 500 like
    // the other lookups below, instead of panicking the request handler.
    let database = ctx
        .get_database_by_identity(&db_identity)
        .map_err(log_and_500)?
        .ok_or(StatusCode::NOT_FOUND)?;
    let leader = ctx
        .leader(database.id)
        .await
        .map_err(log_and_500)?
        .ok_or(StatusCode::NOT_FOUND)?;

    let identity_token = auth.creds.token().into();
    let mut module_rx = leader.module_watcher().await.map_err(log_and_500)?;

    let client_id = ClientActorId {
        identity: auth.identity,
        connection_id,
        name: ctx.client_actor_index().next_client_name(),
    };

    // Messages are capped at 0x2000000 bytes (32 MiB); frames get no separate
    // cap; unmasked frames from the client are refused.
    let ws_config = WebSocketConfig::default()
        .max_message_size(Some(0x2000000))
        .max_frame_size(None)
        .accept_unmasked_frames(false);
    let ws_opts = ctx.websocket_options();

    tokio::spawn(async move {
        let ws = match ws_upgrade.upgrade(ws_config).await {
            Ok(ws) => ws,
            Err(err) => {
                log::error!("websocket: WebSocket init error: {err}");
                return;
            }
        };

        let identity = client_id.identity;
        let client_log_string = match forwarded_for {
            Some(TypedHeader(XForwardedFor(ip))) => {
                format!("ip {ip} with Identity {identity} and ConnectionId {connection_id}")
            }
            None => format!("unknown ip with Identity {identity} and ConnectionId {connection_id}"),
        };

        log::debug!("websocket: New client connected from {client_log_string}");

        let connected = match ClientConnection::call_client_connected_maybe_reject(&mut module_rx, client_id).await {
            Ok(connected) => {
                log::debug!("websocket: client_connected returned Ok for {client_log_string}");
                connected
            }
            // The module refused the client: an expected outcome, logged at info.
            Err(e @ (ClientConnectedError::Rejected(_) | ClientConnectedError::OutOfEnergy)) => {
                log::info!(
                    "websocket: Rejecting connection for {client_log_string} due to error from client_connected reducer: {e}"
                );
                return;
            }
            Err(e @ (ClientConnectedError::DBError(_) | ClientConnectedError::ReducerCall(_))) => {
                log::warn!("websocket: ModuleHost died while {client_log_string} was connecting: {e:#}");
                return;
            }
        };

        log::debug!(
            "websocket: Database accepted connection from {client_log_string}; spawning ws_client_actor and ClientConnection"
        );

        let actor = |client, sendrx| ws_client_actor(ws_opts, client, ws, sendrx);
        let client =
            ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor, connected).await;

        // Tell the client its identity, token, and connection ID.
        let message = IdentityTokenMessage {
            identity: auth.identity,
            token: identity_token,
            connection_id,
        };
        if let Err(e) = client.send_message(message) {
            log::warn!("websocket: Error sending IdentityToken message to {client_log_string}: {e}");
        }
    });

    Ok(res)
}
/// State shared by the tasks of a single websocket session
/// (main loop, send loop, and receive loop).
struct ActorState {
    pub client_id: ClientActorId,
    pub database: Identity,
    config: WebSocketOptions,
    /// Set once the session starts closing: a close frame was sent or received,
    /// or the out-of-band send channel is gone.
    closed: AtomicBool,
    /// Whether a pong has been seen since the flag was last reset by a ping tick.
    got_pong: AtomicBool,
}

impl ActorState {
    pub fn new(database: Identity, client_id: ClientActorId, config: WebSocketOptions) -> Self {
        let closed = AtomicBool::new(false);
        // Start out "ponged" so that the very first ping tick sends a ping.
        let got_pong = AtomicBool::new(true);
        Self {
            client_id,
            database,
            config,
            closed,
            got_pong,
        }
    }

    /// Whether the session has started closing.
    pub fn closed(&self) -> bool {
        self.closed.load(Ordering::Relaxed)
    }

    /// Marks the session as closed and returns whether it already was.
    pub fn close(&self) -> bool {
        self.closed.fetch_or(true, Ordering::Relaxed)
    }

    /// Records that the client answered with a pong.
    pub fn set_ponged(&self) {
        self.got_pong.store(true, Ordering::Relaxed);
    }

    /// Clears the pong flag and returns whether a pong had been recorded.
    pub fn reset_ponged(&self) -> bool {
        self.got_pong.fetch_and(false, Ordering::Relaxed)
    }

    /// The instant at which the session counts as idle if nothing more arrives,
    /// i.e. now plus the configured `idle_timeout`.
    pub fn next_idle_deadline(&self) -> Instant {
        let idle_timeout = self.config.idle_timeout;
        Instant::now() + idle_timeout
    }
}
/// Tunables for a websocket client session.
///
/// (De)serialized with kebab-case keys; durations go through
/// `humantime_duration`. Every field falls back to its default when missing.
#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct WebSocketOptions {
    /// Period of the server's ping timer (default 15s).
    #[serde(
        with = "humantime_duration",
        default = "WebSocketOptions::default_ping_interval"
    )]
    pub ping_interval: Duration,
    /// How long the session may go without receiving a frame before it is
    /// dropped (default 30s).
    #[serde(
        with = "humantime_duration",
        default = "WebSocketOptions::default_idle_timeout"
    )]
    pub idle_timeout: Duration,
    /// How long to keep draining client frames, waiting for the client to end
    /// the stream, once the session is closing (default 250ms).
    #[serde(
        with = "humantime_duration",
        default = "WebSocketOptions::default_close_handshake_timeout"
    )]
    pub close_handshake_timeout: Duration,
    /// Capacity of the per-connection queue of incoming frames (default 2048).
    #[serde(default = "WebSocketOptions::default_incoming_queue_length")]
    pub incoming_queue_length: NonZeroUsize,
}

impl Default for WebSocketOptions {
    fn default() -> Self {
        WebSocketOptions::DEFAULT
    }
}

impl WebSocketOptions {
    const DEFAULT_PING_INTERVAL: Duration = Duration::from_secs(15);
    const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
    const DEFAULT_CLOSE_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(250);
    const DEFAULT_INCOMING_QUEUE_LENGTH: NonZeroUsize = NonZeroUsize::new(2048).expect("2048 > 0, qed");

    /// All-defaults value, usable in `const` contexts.
    const DEFAULT: Self = Self {
        ping_interval: Self::DEFAULT_PING_INTERVAL,
        idle_timeout: Self::DEFAULT_IDLE_TIMEOUT,
        close_handshake_timeout: Self::DEFAULT_CLOSE_HANDSHAKE_TIMEOUT,
        incoming_queue_length: Self::DEFAULT_INCOMING_QUEUE_LENGTH,
    };

    // Per-field default functions, referenced by the `#[serde(default = ...)]` attributes.

    const fn default_ping_interval() -> Duration {
        Self::DEFAULT.ping_interval
    }

    const fn default_idle_timeout() -> Duration {
        Self::DEFAULT.idle_timeout
    }

    const fn default_close_handshake_timeout() -> Duration {
        Self::DEFAULT.close_handshake_timeout
    }

    const fn default_incoming_queue_length() -> NonZeroUsize {
        Self::DEFAULT.incoming_queue_length
    }
}
/// Runs one websocket session for `client`, then disconnects it.
///
/// The client is disconnected even if this future is dropped early (cancelled or
/// unwinding). In that case the scope guard's handler spawns
/// `client.disconnect()` as a detached task.
async fn ws_client_actor(
    options: WebSocketOptions,
    client: ClientConnection,
    ws: WebSocketStream,
    sendrx: MeteredReceiver<SerializableMessage>,
) {
    let disconnect_on_drop = |c: ClientConnection| {
        tokio::spawn(c.disconnect());
    };
    let mut guarded = scopeguard::guard(client, disconnect_on_drop);

    ws_client_actor_inner(&mut guarded, options, ws, sendrx).await;

    // Normal completion: defuse the guard and disconnect inline instead.
    let client = ScopeGuard::into_inner(guarded);
    client.disconnect().await;
}
/// Runs the websocket session for `client` until [`ws_main_loop`] returns.
///
/// Splits `ws` into its send and receive halves and spawns two tasks:
///
/// - [`ws_send_loop`] writes messages from `sendrx`, plus out-of-band
///   [`UnorderedWsMessage`]s from an internal unbounded channel;
/// - [`ws_recv_task`] reads client frames and passes data messages to
///   `ClientConnection::handle_message`.
///
/// The main loop supervises both tasks and also drives the idle timer, pings,
/// and module hotswap watching.
async fn ws_client_actor_inner(
    client: &mut ClientConnection,
    config: WebSocketOptions,
    ws: WebSocketStream,
    sendrx: MeteredReceiver<SerializableMessage>,
) {
    let database = client.module.info().database_identity;
    let client_id = client.id;
    // Incremented by the message handler the first time this client sends a Close frame.
    let client_closed_metric = WORKER_METRICS.ws_clients_closed_connection.with_label_values(&database);
    let state = Arc::new(ActorState::new(database, client_id, config));
    // Out-of-band messages (close frames, pings, execution errors) for the send loop.
    let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
    let (ws_send, ws_recv) = ws.split();
    // The receive side publishes fresh idle deadlines; the idle timer watches them.
    let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
    let idle_timer = ws_idle_timer(idle_rx);
    let send_task = tokio::spawn(ws_send_loop(
        state.clone(),
        client.config,
        ws_send,
        sendrx,
        unordered_rx,
    ));
    let recv_task = tokio::spawn(ws_recv_task(
        state.clone(),
        idle_tx,
        client_closed_metric,
        {
            let client = client.clone();
            // Each message gets its own clone so the returned future owns it.
            move |data, timer| {
                let client = client.clone();
                async move { client.handle_message(data, timer).await }
            }
        },
        unordered_tx.clone(),
        ws_recv,
    ));
    // Factory for a future that completes when the module is hotswapped or has exited.
    let hotswap = {
        let client = client.clone();
        move || {
            let mut client = client.clone();
            async move { client.watch_module_host().await }
        }
    };
    ws_main_loop(state, hotswap, idle_timer, send_task, recv_task, move |msg| {
        // Send errors are ignored: the receiver belongs to the send task.
        let _ = unordered_tx.send(msg);
    })
    .await;
    log::info!("Client connection ended: {client_id}");
}
/// Supervises one websocket session until it ends.
///
/// Returns when the send task or the receive task finishes (re-raising a panic
/// from either), or when `idle_timer` completes. While the session is not closed,
/// it also:
///
/// - sends a ping every `ping_interval`, but only if the client answered the
///   previous one (see [`ActorState::reset_ponged`]);
/// - watches for module hotswaps via `hotswap`. If the module has exited
///   (`NoSuchModule`), it queues a close frame with [`CloseCode::Away`] and stops
///   watching.
///
/// Both tasks are aborted when this function returns, however it returns.
async fn ws_main_loop<HotswapWatcher>(
    state: Arc<ActorState>,
    hotswap: impl Fn() -> HotswapWatcher,
    idle_timer: impl Future<Output = ()>,
    mut send_task: JoinHandle<()>,
    mut recv_task: JoinHandle<()>,
    unordered_tx: impl Fn(UnorderedWsMessage),
) where
    HotswapWatcher: Future<Output = Result<(), NoSuchModule>>,
{
    let abort_send = send_task.abort_handle();
    let abort_recv = recv_task.abort_handle();
    defer! {
        abort_send.abort();
        abort_recv.abort();
    };

    let mut ping_interval = tokio::time::interval(state.config.ping_interval);
    let watch_hotswap = hotswap();
    pin_mut!(watch_hotswap);
    pin_mut!(idle_timer);
    // Once the module is gone there is nothing left to watch. Re-arming the watcher
    // after `NoSuchModule` could make it fail again immediately, spinning this loop
    // until the send loop marks the session closed.
    let mut module_exited = false;

    loop {
        let closed = state.closed();
        tokio::select! {
            res = &mut send_task => {
                if let Err(e) = res {
                    if e.is_panic() {
                        panic::resume_unwind(e.into_panic())
                    }
                }
                break;
            },

            res = &mut recv_task => {
                if let Err(e) = res {
                    if e.is_panic() {
                        panic::resume_unwind(e.into_panic())
                    }
                }
                break;
            },

            _ = &mut idle_timer => {
                log::warn!("Client {} timed out", state.client_id);
                break;
            },

            // Disabled once we are closing, or after the module has exited.
            res = &mut watch_hotswap, if !closed && !module_exited => {
                match res {
                    // Hotswapped: keep watching for the next change.
                    Ok(()) => watch_hotswap.set(hotswap()),
                    Err(NoSuchModule) => {
                        module_exited = true;
                        let close = CloseFrame {
                            code: CloseCode::Away,
                            reason: "module exited".into()
                        };
                        unordered_tx(close.into());
                    }
                }
            },

            _ = ping_interval.tick(), if !closed => {
                // Only ping again if the previous ping was answered.
                let was_ponged = state.reset_ponged();
                if was_ponged {
                    unordered_tx(UnorderedWsMessage::Ping(Bytes::new()));
                }
            }
        }
    }
}
/// Completes once the most recent deadline received on `activity` has passed.
///
/// Each value sent on the channel moves the deadline. If the sender is dropped,
/// the last deadline seen still applies.
async fn ws_idle_timer(mut activity: watch::Receiver<Instant>) {
    let mut current = *activity.borrow();
    let timer = sleep_until(current.into());
    pin_mut!(timer);

    loop {
        tokio::select! {
            // Look at deadline updates before the timer, so an extension that is
            // already pending wins over an expiry seen in the same poll.
            biased;

            Ok(()) = activity.changed() => {
                let latest = *activity.borrow_and_update();
                if latest == current {
                    continue;
                }
                current = latest;
                timer.as_mut().reset(current.into());
            },

            () = &mut timer => return,
        }
    }
}
async fn ws_recv_task<MessageHandler>(
state: Arc<ActorState>,
idle_tx: watch::Sender<Instant>,
client_closed_metric: IntGauge,
message_handler: impl Fn(DataMessage, Instant) -> MessageHandler,
unordered_tx: mpsc::UnboundedSender<UnorderedWsMessage>,
ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin + Send + 'static,
) where
MessageHandler: Future<Output = Result<(), MessageHandleError>>,
{
let recv_queue = ws_recv_queue(state.clone(), unordered_tx.clone(), ws);
let recv_loop = pin!(ws_recv_loop(state.clone(), idle_tx, recv_queue));
let recv_handler = ws_client_message_handler(state.clone(), client_closed_metric, recv_loop);
pin_mut!(recv_handler);
while let Some((data, timer)) = recv_handler.next().await {
let result = message_handler(data, timer).await;
if let Err(e) = result {
if let MessageHandleError::Execution(err) = e {
log::error!("{err:#}");
if unordered_tx.send(err.into()).is_err() {
break;
}
continue;
}
log::debug!("Client caused error: {e}");
let close = CloseFrame {
code: CloseCode::Error,
reason: format!("{e:#}").into(),
};
if unordered_tx.send(close.into()).is_err() {
break;
};
}
}
}
/// Turns the raw websocket stream into [`ClientMessage`]s.
///
/// Every frame obtained from `ws` pushes a new idle deadline into `idle_tx`.
/// While the session is open, frames are converted and yielded. Once the session
/// is closed, further frames are discarded while waiting, for at most
/// `close_handshake_timeout`, for the client to end the stream.
///
/// The stream ends when the input ends, when that drain times out, or when the
/// input yields an error.
fn ws_recv_loop(
    state: Arc<ActorState>,
    idle_tx: watch::Sender<Instant>,
    mut ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin,
) -> impl Stream<Item = ClientMessage> {
    /// Returns the next item from `ws`.
    ///
    /// If the session is already closed, `Ok` frames are dropped. The function then
    /// returns only the first error, or `None` once `ws` ends or
    /// `close_handshake_timeout` elapses.
    async fn next_message(
        state: &ActorState,
        ws: &mut (impl Stream<Item = Result<WsMessage, WsError>> + Unpin),
    ) -> Option<Result<WsMessage, WsError>> {
        if state.closed() {
            log::trace!("drain websocket waiting for client close");
            let res: Result<Option<Result<WsMessage, WsError>>, Elapsed> =
                timeout(state.config.close_handshake_timeout, async {
                    while let Some(item) = ws.next().await {
                        match item {
                            Ok(message) => drop(message),
                            Err(e) => return Some(Err(e)),
                        }
                    }
                    None
                })
                .await;
            match res {
                Err(_elapsed) => {
                    log::warn!("timeout waiting for client close");
                    None
                }
                Ok(item) => item,
            }
        } else {
            log::trace!("await next client message without timeout");
            ws.next().await
        }
    }

    stream! {
        loop {
            let Some(res) = next_message(&state, &mut ws).await else {
                log::trace!("recv stream exhausted");
                break;
            };
            match res {
                Ok(m) => {
                    // Every frame returned here counts as activity, including
                    // one dropped below.
                    idle_tx.send(state.next_idle_deadline()).ok();
                    if state.closed() {
                        // The session closed while this frame was awaited: drop it.
                        // `next_message` drains any later frames.
                        log::trace!("message received while already closed");
                    } else {
                        yield ClientMessage::from_message(m);
                    }
                }
                // No wildcard arm, so a new `WsError` variant fails to compile
                // here and forces a decision. All current variants end the stream.
                Err(e) => match e {
                    e @ (WsError::ConnectionClosed
                    | WsError::AlreadyClosed
                    | WsError::Io(_)
                    | WsError::Tls(_)
                    | WsError::Capacity(_)
                    | WsError::Protocol(_)
                    | WsError::WriteBufferFull(_)
                    | WsError::Utf8(_)
                    | WsError::AttackAttempt
                    | WsError::Url(_)
                    | WsError::Http(_)
                    | WsError::HttpFormat(_)) => {
                        log::warn!("Websocket receive error: {e}");
                        break;
                    }
                },
            }
        }
    }
}
fn ws_recv_queue(
state: Arc<ActorState>,
unordered_tx: mpsc::UnboundedSender<UnorderedWsMessage>,
mut ws: impl Stream<Item = Result<WsMessage, WsError>> + Unpin + Send + 'static,
) -> impl Stream<Item = Result<WsMessage, WsError>> {
const CLOSE: UnorderedWsMessage = UnorderedWsMessage::Close(CloseFrame {
code: CloseCode::Again,
reason: Utf8Bytes::from_static("too many requests"),
});
let on_message_after_close = move |client_id| {
log::warn!("client {client_id} sent message after close or error");
};
let (tx, rx) = mpsc::channel(state.config.incoming_queue_length.get());
let rx = MeteredReceiverStream {
inner: MeteredReceiver::with_gauge(
rx,
WORKER_METRICS
.total_incoming_queue_length
.with_label_values(&state.database),
),
};
tokio::spawn(async move {
while let Some(item) = ws.next().await {
if let Err(e) = tx.try_send(item) {
match e {
mpsc::error::TrySendError::Full(item) => {
if unordered_tx.send(CLOSE).is_err() {
state.close();
break;
}
if tx.send(item).await.is_err() {
on_message_after_close(state.client_id);
break;
}
}
mpsc::error::TrySendError::Closed(_item) => {
on_message_after_close(state.client_id);
break;
}
}
}
}
});
rx
}
struct MeteredReceiverStream<T> {
inner: MeteredReceiver<T>,
}
impl<T> Stream for MeteredReceiverStream<T> {
type Item = T;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.poll_recv(cx)
}
}
/// Splits client frames into data messages, which are yielded together with
/// their receive time, and control frames, which are handled inline:
///
/// - `Ping`: only logged here;
/// - `Pong`: recorded with [`ActorState::set_ponged`];
/// - `Close`: marks the session closed, and bumps `client_closed_metric` the
///   first time.
fn ws_client_message_handler(
    state: Arc<ActorState>,
    client_closed_metric: IntGauge,
    mut messages: impl Stream<Item = ClientMessage> + Unpin,
) -> impl Stream<Item = (DataMessage, Instant)> {
    stream! {
        while let Some(incoming) = messages.next().await {
            match incoming {
                ClientMessage::Message(data) => {
                    log::trace!("Received client message");
                    let received_at = Instant::now();
                    yield (data, received_at);
                }
                ClientMessage::Ping(_) => {
                    log::trace!("Received ping from client {}", state.client_id);
                }
                ClientMessage::Pong(_) => {
                    log::trace!("Received pong from client {}", state.client_id);
                    state.set_ponged();
                }
                ClientMessage::Close(close_frame) => {
                    log::trace!("Received Close frame from client {}: {:?}", state.client_id, close_frame);
                    let already_closed = state.close();
                    if !already_closed {
                        client_closed_metric.inc();
                    }
                }
            }
        }
        log::trace!("client message handler done");
    }
}
/// Out-of-band messages for the send loop.
///
/// They travel over an unbounded channel that [`ws_send_loop`] polls ahead of
/// the ordered [`SerializableMessage`] queue. `From` impls for each payload type
/// allow `close.into()` / `err.into()`.
#[derive(Debug, From)]
enum UnorderedWsMessage {
    /// Send this close frame, then mark the session closed.
    Close(CloseFrame),
    /// Send a ping frame with this payload.
    Ping(Bytes),
    /// Serialize this execution error and send it to the client.
    Error(MessageExecutionError),
}
/// Send side of a session: writes outgoing frames to `ws`.
///
/// Out-of-band messages from `unordered` (close frames, pings, execution errors)
/// take priority over the ordered `messages` queue. The ordered queue is drained
/// in batches of up to [`BATCH_SIZE`](#) messages, and the sink is flushed after
/// each handled message or batch.
///
/// Once the session is closed, ordered messages are no longer read and
/// out-of-band messages are discarded. Returns when `unordered` is closed, or
/// when any send or flush fails.
async fn ws_send_loop(
    state: Arc<ActorState>,
    config: ClientConfig,
    mut ws: impl Sink<WsMessage, Error: Display> + Unpin,
    mut messages: MeteredReceiver<SerializableMessage>,
    mut unordered: mpsc::UnboundedReceiver<UnorderedWsMessage>,
) {
    /// Maximum number of ordered messages written per flush.
    const BATCH_SIZE: usize = 32;

    let mut messages_buf = Vec::with_capacity(BATCH_SIZE);
    let mut serialize_buf = SerializeBuffer::new(config);
    // `recv_many` returns 0 immediately once `messages` is closed and empty.
    // Remember that and disable the branch, instead of spinning this loop.
    let mut messages_exhausted = false;

    loop {
        let closed = state.closed();
        tokio::select! {
            // Out-of-band messages first.
            biased;

            maybe_msg = unordered.recv() => {
                let Some(msg) = maybe_msg else {
                    break;
                };
                if closed {
                    continue;
                }
                match msg {
                    UnorderedWsMessage::Close(close_frame) => {
                        log::trace!("sending close frame");
                        if let Err(e) = ws.send(WsMessage::Close(Some(close_frame))).await {
                            log::warn!("error sending close frame: {e:#}");
                            break;
                        }
                        state.close();
                        messages.close();
                    },
                    UnorderedWsMessage::Ping(bytes) => {
                        log::trace!("sending ping");
                        if let Err(e) = ws.feed(WsMessage::Ping(bytes)).await {
                            log::warn!("error sending ping: {e:#}");
                            break;
                        }
                    },
                    UnorderedWsMessage::Error(err) => {
                        log::trace!("sending error result");
                        let (msg_alloc, res) = send_message(
                            &state.database,
                            config,
                            serialize_buf,
                            None,
                            &mut ws,
                            err
                        ).await;
                        serialize_buf = msg_alloc;
                        if let Err(e) = res {
                            log::warn!("websocket send error: {e}");
                            break;
                        }
                    },
                }
            },

            n = messages.recv_many(&mut messages_buf, BATCH_SIZE), if !closed && !messages_exhausted => {
                if n == 0 {
                    messages_exhausted = true;
                    continue;
                }
                log::trace!("sending {n} outgoing messages");
                for msg in messages_buf.drain(..n) {
                    let (msg_alloc, res) = send_message(
                        &state.database,
                        config,
                        serialize_buf,
                        msg.workload().zip(msg.num_rows()),
                        &mut ws,
                        msg
                    ).await;
                    serialize_buf = msg_alloc;
                    if let Err(e) = res {
                        log::warn!("websocket send error: {e}");
                        // `return` rather than `break`: we are inside the `for` loop.
                        return;
                    }
                }
            },
        }

        if let Err(e) = ws.flush().await {
            log::warn!("error flushing websocket: {e}");
            break;
        }
    }
}
/// Serializes `message`, records send metrics, and feeds the frame to `ws`.
///
/// Serialization reuses `serialize_buf`. Messages reported to carry more than
/// 1024 rows are serialized on the rayon pool via `spawn_rayon`; smaller ones are
/// serialized on the current task. After feeding the frame, this waits until the
/// sink is ready for another item (it does not flush).
///
/// Returns a buffer for the next call, which is the reclaimed buffer or a fresh
/// one if reclaiming fails, together with the send result.
async fn send_message<S: Sink<WsMessage> + Unpin>(
    database_identity: &Identity,
    config: ClientConfig,
    serialize_buf: SerializeBuffer,
    metrics_metadata: Option<(WorkloadType, usize)>,
    ws: &mut S,
    message: impl ToProtocol<Encoded = SwitchedServerMessage> + Send + 'static,
) -> (SerializeBuffer, Result<(), S::Error>) {
    /// Row count above which serialization is moved off the async task.
    const OFFLOAD_ROW_THRESHOLD: usize = 1024;

    let (workload, num_rows) = metrics_metadata.unzip();
    let timed_serialize = |buf, msg, cfg| {
        let started = Instant::now();
        let (alloc, data) = serialize(buf, msg, cfg);
        (started.elapsed(), alloc, data)
    };

    let offload = num_rows.is_some_and(|rows| rows > OFFLOAD_ROW_THRESHOLD);
    let (timing, msg_alloc, msg_data) = if offload {
        spawn_rayon(move || timed_serialize(serialize_buf, message, config)).await
    } else {
        timed_serialize(serialize_buf, message, config)
    };

    report_ws_sent_metrics(database_identity, workload, num_rows, timing, &msg_data);

    let fed = ws.feed(datamsg_to_wsmsg(msg_data)).await;
    let res = match fed {
        Ok(()) => poll_fn(|cx| ws.poll_ready_unpin(cx)).await,
        Err(e) => Err(e),
    };

    let buf = msg_alloc.try_reclaim().unwrap_or_else(|| SerializeBuffer::new(config));
    (buf, res)
}
/// A frame received from the client, classified by kind.
#[derive(Debug)]
enum ClientMessage {
    /// A text or binary data message.
    Message(DataMessage),
    /// A ping control frame and its payload.
    Ping(Bytes),
    /// A pong control frame and its payload.
    Pong(Bytes),
    /// A close frame, with the client's close frame if it sent one.
    Close(Option<CloseFrame>),
}

impl ClientMessage {
    /// Classifies a received websocket message.
    ///
    /// # Panics
    ///
    /// Panics on `WsMessage::Frame`, which is not expected from the read side.
    fn from_message(msg: WsMessage) -> Self {
        let data = match msg {
            WsMessage::Text(text) => DataMessage::Text(utf8bytes_to_bytestring(text)),
            WsMessage::Binary(bytes) => DataMessage::Binary(bytes),
            WsMessage::Ping(payload) => return Self::Ping(payload),
            WsMessage::Pong(payload) => return Self::Pong(payload),
            WsMessage::Close(frame) => return Self::Close(frame),
            WsMessage::Frame(_) => unreachable!(),
        };
        Self::Message(data)
    }
}
/// Records metrics for one outgoing message.
///
/// Serialization time is always recorded. Row count and message size are
/// recorded only when both the workload and the row count are known.
fn report_ws_sent_metrics(
    addr: &Identity,
    workload: Option<WorkloadType>,
    num_rows: Option<usize>,
    serialize_duration: Duration,
    msg_ws: &DataMessage,
) {
    if let Some((workload, num_rows)) = workload.zip(num_rows) {
        let rows = num_rows as f64;
        let size = msg_ws.len() as f64;
        WORKER_METRICS
            .websocket_sent_num_rows
            .with_label_values(addr, &workload)
            .observe(rows);
        WORKER_METRICS
            .websocket_sent_msg_size
            .with_label_values(addr, &workload)
            .observe(size);
    }

    let serialize_secs = serialize_duration.as_secs_f64();
    WORKER_METRICS
        .websocket_serialize_secs
        .with_label_values(addr)
        .observe(serialize_secs);
}
/// Wraps an outgoing [`DataMessage`] in the matching websocket frame type
/// (binary or text), reusing the payload buffer.
fn datamsg_to_wsmsg(msg: DataMessage) -> WsMessage {
    match msg {
        DataMessage::Binary(payload) => WsMessage::Binary(payload),
        DataMessage::Text(text) => WsMessage::Text(bytestring_to_utf8bytes(text)),
    }
}
/// Reinterprets a tungstenite [`Utf8Bytes`] as a [`ByteString`] without re-validating UTF-8.
fn utf8bytes_to_bytestring(s: Utf8Bytes) -> ByteString {
    // SAFETY: `Utf8Bytes` only ever holds valid UTF-8, and `Bytes::from` hands back
    // those same bytes, which satisfies `ByteString`'s UTF-8 invariant.
    unsafe { ByteString::from_bytes_unchecked(Bytes::from(s)) }
}
/// Reinterprets a [`ByteString`] as a tungstenite [`Utf8Bytes`] without re-validating UTF-8.
fn bytestring_to_utf8bytes(s: ByteString) -> Utf8Bytes {
    // SAFETY: a `ByteString` is always valid UTF-8, and `into_bytes` returns its
    // underlying bytes unchanged, which satisfies `Utf8Bytes`'s UTF-8 invariant.
    unsafe { Utf8Bytes::from_bytes_unchecked(s.into_bytes()) }
}
#[cfg(test)]
mod tests {
use std::{
future::Future,
pin::Pin,
sync::atomic::AtomicUsize,
task::{Context, Poll},
};
use anyhow::anyhow;
use futures::{
future::{self, Either, FutureExt as _},
sink, stream,
};
use pretty_assertions::assert_matches;
use spacetimedb::client::ClientName;
use tokio::time::sleep;
use super::*;
fn dummy_client_id() -> ClientActorId {
ClientActorId {
identity: Identity::ZERO,
connection_id: ConnectionId::ZERO,
name: ClientName(0),
}
}
fn dummy_actor_state() -> ActorState {
dummy_actor_state_with_config(<_>::default())
}
fn dummy_actor_state_with_config(config: WebSocketOptions) -> ActorState {
ActorState::new(Identity::ZERO, dummy_client_id(), config)
}
#[tokio::test]
async fn idle_timer_extends_sleep() {
let timeout = Duration::from_millis(10);
let start = Instant::now();
let (tx, rx) = watch::channel(start + timeout);
tokio::join!(ws_idle_timer(rx), async {
for _ in 0..5 {
sleep(Duration::from_millis(1)).await;
tx.send(Instant::now() + timeout).unwrap();
}
});
let elapsed = start.elapsed();
let expected = timeout + Duration::from_millis(5);
assert!(
elapsed >= expected,
"{}ms elapsed, expected >= {}ms",
elapsed.as_millis(),
expected.as_millis(),
);
}
#[tokio::test]
async fn recv_loop_terminates_when_input_exhausted() {
let state = Arc::new(dummy_actor_state());
let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
let input = stream::iter(vec![Ok(WsMessage::Ping(Bytes::new()))]);
pin_mut!(input);
let recv_loop = ws_recv_loop(state, idle_tx, input);
pin_mut!(recv_loop);
assert_matches!(recv_loop.next().await, Some(ClientMessage::Ping(_)));
assert_matches!(recv_loop.next().await, None);
}
#[tokio::test]
async fn recv_loop_terminates_when_input_yields_err() {
let state = Arc::new(dummy_actor_state());
let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
let input = stream::iter(vec![
Ok(WsMessage::Ping(Bytes::new())),
Err(WsError::ConnectionClosed),
Ok(WsMessage::Pong(Bytes::new())),
]);
pin_mut!(input);
let recv_loop = ws_recv_loop(state, idle_tx, input);
pin_mut!(recv_loop);
assert_matches!(recv_loop.next().await, Some(ClientMessage::Ping(_)));
assert_matches!(recv_loop.next().await, None);
}
#[tokio::test]
async fn recv_loop_drains_remaining_messages_when_closed() {
let state = Arc::new(dummy_actor_state());
let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
let input = stream::iter(vec![
Ok(WsMessage::Ping(Bytes::new())),
Ok(WsMessage::Pong(Bytes::new())),
]);
pin_mut!(input);
{
let recv_loop = ws_recv_loop(state.clone(), idle_tx, &mut input);
pin_mut!(recv_loop);
state.close();
assert_matches!(recv_loop.next().await, None);
}
assert_matches!(input.next().await, None);
}
#[tokio::test]
async fn recv_loop_stops_at_error_while_draining() {
let state = Arc::new(dummy_actor_state());
let (idle_tx, _idle_rx) = watch::channel(Instant::now() + state.config.idle_timeout);
let input = stream::iter(vec![
Ok(WsMessage::Ping(Bytes::new())),
Err(WsError::ConnectionClosed),
Ok(WsMessage::Pong(Bytes::new())),
]);
pin_mut!(input);
{
let recv_loop = ws_recv_loop(state.clone(), idle_tx, &mut input);
pin_mut!(recv_loop);
state.close();
assert_matches!(recv_loop.next().await, None);
}
assert_matches!(input.next().await, Some(Ok(WsMessage::Pong(_))));
}
#[tokio::test]
async fn recv_loop_updates_idle_channel() {
let state = Arc::new(dummy_actor_state());
let idle_deadline = Instant::now() + state.config.idle_timeout;
let (idle_tx, mut idle_rx) = watch::channel(idle_deadline);
let input = stream::iter(vec![
Ok(WsMessage::Ping(Bytes::new())),
Ok(WsMessage::Pong(Bytes::new())),
]);
let recv_loop = ws_recv_loop(state, idle_tx, input);
pin_mut!(recv_loop);
let mut new_idle_deadline = *idle_rx.borrow();
while let Some(message) = recv_loop.next().await {
drop(message);
assert!(idle_rx.has_changed().unwrap());
new_idle_deadline = *idle_rx.borrow_and_update();
}
assert!(new_idle_deadline > idle_deadline);
}
#[tokio::test]
async fn client_message_handler_terminates_when_input_exhausted() {
let state = Arc::new(dummy_actor_state());
let metric = IntGauge::new("bleep", "unhelpful").unwrap();
let input = stream::iter(vec![
ClientMessage::Ping(Bytes::new()),
ClientMessage::Message(DataMessage::from("hello".to_owned())),
]);
let handler = ws_client_message_handler(state, metric, input);
pin_mut!(handler);
assert_matches!(
handler.next().await,
Some((DataMessage::Text(data), _instant)) if data == "hello"
);
assert_matches!(handler.next().await, None);
}
#[tokio::test]
async fn client_message_handler_updates_pong_and_closed_states_and_metric() {
let state = Arc::new(dummy_actor_state());
state.reset_ponged();
let metric = IntGauge::new("bleep", "unhelpful").unwrap();
let input = stream::iter(vec![ClientMessage::Pong(Bytes::new()), ClientMessage::Close(None)]);
let handler = ws_client_message_handler(state.clone(), metric.clone(), input);
handler.map(drop).for_each(future::ready).await;
assert!(state.closed());
assert!(state.reset_ponged());
assert_eq!(metric.get(), 1);
}
#[tokio::test]
async fn send_loop_terminates_when_unordered_closed() {
let state = Arc::new(dummy_actor_state());
let (messages_tx, messages_rx) = mpsc::channel(64);
let messages = MeteredReceiver::new(messages_rx);
let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
let send_loop = ws_send_loop(state, ClientConfig::for_test(), sink::drain(), messages, unordered_rx);
pin_mut!(send_loop);
assert!(is_pending(&mut send_loop).await);
drop(messages_tx);
assert!(is_pending(&mut send_loop).await);
drop(unordered_tx);
send_loop.await;
}
#[tokio::test]
async fn send_loop_close_message_closes_state_and_messages() {
let state = Arc::new(dummy_actor_state());
let (messages_tx, messages_rx) = mpsc::channel(64);
let messages = MeteredReceiver::new(messages_rx);
let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
let send_loop = ws_send_loop(
state.clone(),
ClientConfig::for_test(),
sink::drain(),
messages,
unordered_rx,
);
pin_mut!(send_loop);
unordered_tx
.send(UnorderedWsMessage::Close(CloseFrame {
code: CloseCode::Away,
reason: "done".into(),
}))
.unwrap();
assert!(is_pending(&mut send_loop).await);
assert!(state.closed());
assert!(messages_tx.is_closed());
}
#[tokio::test]
async fn send_loop_terminates_if_sink_cant_be_fed() {
let input = [
Either::Left(UnorderedWsMessage::Close(CloseFrame {
code: CloseCode::Away,
reason: "bah!".into(),
})),
Either::Left(UnorderedWsMessage::Ping(Bytes::new())),
Either::Left(UnorderedWsMessage::Error(MessageExecutionError {
reducer: None,
reducer_id: None,
caller_identity: Identity::ZERO,
caller_connection_id: None,
err: anyhow!("it did not work"),
})),
Either::Right(SerializableMessage::Identity(IdentityTokenMessage {
identity: Identity::ZERO,
token: "macaron".into(),
connection_id: ConnectionId::ZERO,
})),
];
for msg in input {
let state = Arc::new(dummy_actor_state());
let (messages_tx, messages_rx) = mpsc::channel(64);
let messages = MeteredReceiver::new(messages_rx);
let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
let send_loop = ws_send_loop(
state.clone(),
ClientConfig::for_test(),
UnfeedableSink,
messages,
unordered_rx,
);
pin_mut!(send_loop);
match msg {
Either::Left(unordered) => unordered_tx.send(unordered).unwrap(),
Either::Right(msg) => messages_tx.send(msg).await.unwrap(),
}
send_loop.await;
}
}
#[tokio::test]
async fn send_loop_terminates_if_sink_cant_be_flushed() {
let input = [
Either::Left(UnorderedWsMessage::Close(CloseFrame {
code: CloseCode::Away,
reason: "bah!".into(),
})),
Either::Left(UnorderedWsMessage::Ping(Bytes::new())),
Either::Left(UnorderedWsMessage::Error(MessageExecutionError {
reducer: None,
reducer_id: None,
caller_identity: Identity::ZERO,
caller_connection_id: None,
err: anyhow!("it did not work"),
})),
Either::Right(SerializableMessage::Identity(IdentityTokenMessage {
identity: Identity::ZERO,
token: "macaron".into(),
connection_id: ConnectionId::ZERO,
})),
];
for msg in input {
let state = Arc::new(dummy_actor_state());
let (messages_tx, messages_rx) = mpsc::channel(64);
let messages = MeteredReceiver::new(messages_rx);
let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
let send_loop = ws_send_loop(
state.clone(),
ClientConfig::for_test(),
UnflushableSink,
messages,
unordered_rx,
);
pin_mut!(send_loop);
match msg {
Either::Left(unordered) => unordered_tx.send(unordered).unwrap(),
Either::Right(msg) => messages_tx.send(msg).await.unwrap(),
}
send_loop.await;
}
}
#[tokio::test]
async fn main_loop_terminates_if_either_send_or_recv_terminates() {
let state = Arc::new(dummy_actor_state());
ws_main_loop(
state.clone(),
future::pending,
future::pending(),
tokio::spawn(sleep(Duration::from_millis(10))),
tokio::spawn(future::pending()),
drop,
)
.await;
ws_main_loop(
state,
future::pending,
future::pending(),
tokio::spawn(future::pending()),
tokio::spawn(sleep(Duration::from_millis(10))),
drop,
)
.await;
}
#[tokio::test]
async fn main_loop_terminates_on_idle_timeout() {
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
idle_timeout: Duration::from_millis(10),
..<_>::default()
}));
let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
let start = Instant::now();
let mut t = tokio::spawn({
let state = state.clone();
async move {
ws_main_loop(
state,
future::pending,
ws_idle_timer(idle_rx),
tokio::spawn(future::pending()),
tokio::spawn(future::pending()),
drop,
)
.await
}
});
let loop_start = Instant::now();
for _ in 0..5 {
sleep(Duration::from_millis(5)).await;
idle_tx.send(state.next_idle_deadline()).unwrap();
assert!(is_pending(&mut t).await);
}
let timeout = loop_start.elapsed() + Duration::from_millis(10);
t.await.unwrap();
let elapsed = start.elapsed();
assert!(elapsed >= timeout);
assert!(elapsed < timeout + Duration::from_millis(10));
}
#[tokio::test]
async fn main_loop_keepalive_keeps_alive() {
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
ping_interval: Duration::from_millis(5),
idle_timeout: Duration::from_millis(10),
..<_>::default()
}));
let (idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
let unordered_tx = {
let state = state.clone();
let pings = AtomicUsize::new(0);
move |m| {
if let UnorderedWsMessage::Ping(_) = m {
let n = pings.fetch_add(1, Ordering::Relaxed);
if n < 5 {
state.set_ponged();
idle_tx.send(state.next_idle_deadline()).ok();
}
}
}
};
let start = Instant::now();
let t = tokio::spawn({
let state = state.clone();
async move {
ws_main_loop(
state,
future::pending,
ws_idle_timer(idle_rx),
tokio::spawn(future::pending()),
tokio::spawn(future::pending()),
unordered_tx,
)
.await
}
});
let expected_timeout = (5 * state.config.ping_interval) + state.config.idle_timeout;
let res = timeout(expected_timeout, t).await;
let elapsed = start.elapsed();
assert_matches!(res, Ok(Ok(())));
assert!(elapsed >= expected_timeout - state.config.ping_interval);
}
#[tokio::test]
async fn main_loop_terminates_when_module_exits() {
let state = Arc::new(dummy_actor_state());
let (_idle_tx, idle_rx) = watch::channel(state.next_idle_deadline());
let unordered_tx = {
let state = state.clone();
move |m| {
if let UnorderedWsMessage::Close(_) = m {
state.close();
}
}
};
let start = Instant::now();
tokio::spawn(async move {
let hotswap = || async {
sleep(Duration::from_millis(5)).await;
Err(NoSuchModule)
};
ws_main_loop(
state.clone(),
hotswap,
ws_idle_timer(idle_rx),
tokio::spawn(async move {
loop {
if state.closed() {
break;
}
sleep(Duration::from_millis(1)).await
}
}),
tokio::spawn(future::pending()),
unordered_tx,
)
.await
})
.await
.unwrap();
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(5));
assert!(elapsed < Duration::from_millis(10));
}
#[tokio::test]
async fn recv_queue_sends_close_when_at_capacity() {
    let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
        incoming_queue_length: 10.try_into().unwrap(),
        ..<_>::default()
    }));
    let (unordered_tx, mut unordered_rx) = mpsc::unbounded_channel();

    // Twice as many text messages as the incoming queue is configured to hold.
    let messages = (0..20).map(|i| Ok(WsMessage::text(format!("message {i}"))));
    let received: Vec<_> = ws_recv_queue(state, unordered_tx, stream::iter(messages))
        .collect()
        .await;

    // Overflowing the queue must produce a `Close` on the unordered channel...
    assert_matches!(unordered_rx.recv().await, Some(UnorderedWsMessage::Close(_)));
    // ...while the returned stream still yields 20 items.
    assert_eq!(received.len(), 20);
}
#[tokio::test]
async fn recv_queue_closes_state_if_sender_gone() {
    let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
        incoming_queue_length: 10.try_into().unwrap(),
        ..<_>::default()
    }));
    // Drop the receiving half immediately, so every send on `unordered_tx` fails.
    let (unordered_tx, unordered_rx) = mpsc::unbounded_channel();
    drop(unordered_rx);

    let messages = (0..20).map(|i| Ok(WsMessage::text(format!("message {i}"))));
    let received: Vec<_> = ws_recv_queue(Arc::clone(&state), unordered_tx, stream::iter(messages))
        .collect()
        .await;

    // With nobody listening, the connection state ends up closed and the
    // returned stream yields only 10 items.
    assert!(state.closed());
    assert_eq!(received.len(), 10);
}
/// Polls `fut` exactly once, using the calling task's context, and reports
/// whether that poll returned `Poll::Pending`.
///
/// The future is borrowed, not consumed, so it can be awaited later. If it
/// completes during this poll, its output is discarded and the future must
/// not be polled again.
async fn is_pending(fut: &mut (impl Future + Unpin)) -> bool {
    poll_fn(|cx| {
        let still_pending = Pin::new(&mut *fut).poll(cx).is_pending();
        Poll::Ready(still_pending)
    })
    .await
}
/// Test sink whose readiness, flush and close always succeed, but which
/// rejects every item passed to `start_send`.
struct UnfeedableSink;

impl<T> Sink<T> for UnfeedableSink {
    type Error = &'static str;

    /// Always ready to accept an item.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Always fails, regardless of the item.
    fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
        Err("don't feed the sink")
    }

    /// Always reports a successful flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Always reports a successful close.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
}
/// Test sink that accepts every item and closes cleanly, but whose flush
/// always fails.
struct UnflushableSink;

impl<T> Sink<T> for UnflushableSink {
    type Error = &'static str;

    /// Always ready to accept an item.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Accepts (and discards) any item.
    fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Always fails immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Err("don't flush the sink"))
    }

    /// Always reports a successful close.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
}
#[test]
fn options_toml_roundtrip() {
    // The default options must survive a TOML serialize/deserialize cycle unchanged.
    let original = WebSocketOptions::default();
    let encoded = toml::to_string(&original).unwrap();
    let decoded: WebSocketOptions = toml::from_str(&encoded).unwrap();
    assert_eq!(original, decoded);
}
#[test]
fn options_from_partial_toml() {
    // Only two keys are present, written as human-readable durations
    // ("1m 3s" == 63s); every other field is expected to keep its default.
    let input = r#"
ping-interval = "53s"
idle-timeout = "1m 3s"
"#;
    let expected = WebSocketOptions {
        ping_interval: Duration::from_secs(53),
        idle_timeout: Duration::from_secs(63),
        ..<_>::default()
    };
    let parsed: WebSocketOptions = toml::from_str(input).unwrap();
    assert_eq!(expected, parsed);
}
}