use crate::{
Event, ReducerEvent, Status,
__codegen::{InternalError, Reducer},
callbacks::{
CallbackId, DbCallbacks, ProcedureCallback, ProcedureCallbacks, ReducerCallback, ReducerCallbacks, RowCallback,
UpdateCallback,
},
client_cache::{ClientCache, TableHandle},
spacetime_module::{AbstractEventContext, AppliedDiff, DbConnection, DbUpdate, InModule, SpacetimeModule},
subscription::{PendingUnsubscribeResult, SubscriptionHandleImpl, SubscriptionManager},
websocket::{WsConnection, WsParams},
};
use bytes::Bytes;
use futures::StreamExt;
#[cfg(feature = "browser")]
use futures::{pin_mut, FutureExt};
use futures_channel::mpsc;
use http::Uri;
use spacetimedb_client_api_messages::websocket::{self as ws, common::QuerySetId};
use spacetimedb_lib::{bsatn, ser::Serialize, ConnectionId, Identity, Timestamp};
use spacetimedb_sats::Deserialize;
#[cfg(not(feature = "browser"))]
use std::fs::OpenOptions;
use std::{
fs::File,
io::Write,
path::PathBuf,
sync::{atomic::AtomicU32, Arc, Mutex as StdMutex, OnceLock},
};
#[cfg(not(feature = "browser"))]
use tokio::{
runtime::{self, Runtime},
sync::Mutex as TokioMutex,
};
/// A value shared between all clones of a [`DbContextImpl`], guarded by a synchronous `std` mutex.
pub(crate) type SharedCell<T> = Arc<StdMutex<T>>;
/// A shared value that is locked via `get_lock_sync` / `get_lock_async`.
///
/// Without the `browser` feature this is backed by a Tokio mutex.
#[cfg(not(feature = "browser"))]
type SharedAsyncCell<T> = Arc<TokioMutex<T>>;
/// With the `browser` feature this is simply a [`SharedCell`], i.e. a `std` mutex.
#[cfg(feature = "browser")]
type SharedAsyncCell<T> = SharedCell<T>;
/// Implementation of a connection to a remote database for module `M`.
///
/// Every field is either reference-counted, a channel sender or a runtime handle,
/// so clones of this struct all refer to the same underlying connection state.
pub struct DbContextImpl<M: SpacetimeModule> {
/// Handle to the Tokio runtime, used by `advance_one_message_blocking` to block on incoming messages.
#[cfg(not(feature = "browser"))]
runtime: runtime::Handle,
/// Callbacks and subscription bookkeeping; see [`DbContextImplInner`].
pub(crate) inner: SharedCell<DbContextImplInner<M>>,
/// Sender for outgoing WebSocket messages.
///
/// Set to `None` on disconnect; `is_active` reports whether it is still `Some`.
pub(crate) send_chan: SharedCell<Option<mpsc::UnboundedSender<ws::v2::ClientMessage>>>,
/// The client cache that incoming database updates are applied to.
cache: SharedCell<ClientCache<M>>,
/// Receiver for messages that have already been parsed by `parse_loop`.
recv: SharedAsyncCell<mpsc::UnboundedReceiver<ParsedMessage<M>>>,
/// Sender used to queue a [`PendingMutation`] for later application.
pub(crate) pending_mutations_send: mpsc::UnboundedSender<PendingMutation<M>>,
/// Receiver drained by `apply_pending_mutations` and `get_message`.
pending_mutations_recv: SharedAsyncCell<mpsc::UnboundedReceiver<PendingMutation<M>>>,
/// Our `Identity`, stored when the `IdentityToken` message is processed.
identity: SharedCell<Option<Identity>>,
/// Our `ConnectionId`; initialised from the process-wide override if one was set,
/// and stored when the `IdentityToken` message is processed.
connection_id: SharedCell<Option<ConnectionId>>,
/// Optional file that `debug_log` writes extra diagnostics to.
pub(crate) extra_logging: Option<SharedCell<File>>,
}
impl<M: SpacetimeModule> Clone for DbContextImpl<M> {
    /// Produces another handle to the same connection.
    ///
    /// Only reference counts, the mutation sender and the runtime handle are cloned,
    /// so the returned value shares all of its state with `self`.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        let send_chan = Arc::clone(&self.send_chan);
        let cache = Arc::clone(&self.cache);
        let recv = Arc::clone(&self.recv);
        let pending_mutations_send = self.pending_mutations_send.clone();
        let pending_mutations_recv = Arc::clone(&self.pending_mutations_recv);
        let identity = Arc::clone(&self.identity);
        let connection_id = Arc::clone(&self.connection_id);
        let extra_logging = self.extra_logging.as_ref().map(Arc::clone);
        Self {
            #[cfg(not(feature = "browser"))]
            runtime: self.runtime.clone(),
            inner,
            send_chan,
            cache,
            recv,
            pending_mutations_send,
            pending_mutations_recv,
            identity,
            connection_id,
            extra_logging,
        }
    }
}
impl<M: SpacetimeModule> DbContextImpl<M> {
/// Runs `body` against this connection's extra debug log file, if one was configured.
///
/// Forwards to the free function `debug_log` with this context's `extra_logging` handle.
pub(crate) fn debug_log(&self, body: impl FnOnce(&mut File) -> std::result::Result<(), std::io::Error>) {
    let log_file = &self.extra_logging;
    debug_log(log_file, body)
}
fn process_message(&self, msg: ParsedMessage<M>) -> crate::Result<()> {
self.debug_log(|out| writeln!(out, "`process_message`: {msg:?}"));
match msg {
ParsedMessage::Error(e) => {
let disconnect_ctx = self.make_event_ctx(Some(e.clone()));
self.invoke_disconnected(&disconnect_ctx);
Err(e)
}
ParsedMessage::IdentityToken(identity, token, conn_id) => {
{
let mut ident_store = self.identity.lock().unwrap();
if let Some(prev_identity) = *ident_store {
assert_eq!(prev_identity, identity);
}
*ident_store = Some(identity);
}
{
let mut conn_id_store = self.connection_id.lock().unwrap();
if let Some(prev_conn_id) = *conn_id_store {
assert_eq!(prev_conn_id, conn_id);
}
*conn_id_store = Some(conn_id);
}
let mut inner = self.inner.lock().unwrap();
if let Some(on_connect) = inner.on_connect.take() {
let ctx = <M::DbConnection as DbConnection>::new(self.clone());
on_connect(&ctx, identity, &token);
}
Ok(())
}
ParsedMessage::TransactionUpdate(update) => {
self.apply_update(update, |_| Event::Transaction);
Ok(())
}
ParsedMessage::ReducerResult {
request_id,
timestamp,
result: Ok(Ok(update)),
} => {
let (reducer, callback) = {
let mut inner = self.inner.lock().unwrap();
inner.reducer_callbacks.pop_call_info(request_id).ok_or_else(|| {
InternalError::new(format!("Reducer result for unknown request_id {request_id}"))
})?
};
let reducer_event = ReducerEvent {
reducer,
timestamp,
status: Status::Committed,
};
self.apply_update(update, |_| Event::Reducer(reducer_event.clone()));
let reducer_event_ctx = self.make_event_ctx(reducer_event);
callback(&reducer_event_ctx, Ok(Ok(())));
Ok(())
}
ParsedMessage::ReducerResult {
request_id,
timestamp,
result,
} => {
let (status, result) = match result {
Ok(Ok(_)) => {
unreachable!("This pattern handled by an earlier branch in the match on the `ParsedMessage`")
}
Ok(Err(message)) => (Status::Err(message.clone()), Ok(Err(message))),
Err(internal_error) => (Status::Panic(internal_error.clone()), Err(internal_error)),
};
let (reducer, callback) = {
let mut inner = self.inner.lock().unwrap();
inner.reducer_callbacks.pop_call_info(request_id).ok_or_else(|| {
InternalError::new(format!("Reducer result for unknown request_id {request_id}"))
})?
};
let reducer_event = ReducerEvent {
reducer,
timestamp,
status,
};
let reducer_event_ctx = self.make_event_ctx(reducer_event);
callback(&reducer_event_ctx, result);
Ok(())
}
ParsedMessage::SubscribeApplied {
query_set_id,
initial_update,
} => {
self.apply_update(initial_update, |inner| {
let sub_event_ctx = self.make_event_ctx(());
inner.subscriptions.subscription_applied(&sub_event_ctx, query_set_id);
Event::SubscribeApplied
});
Ok(())
}
ParsedMessage::UnsubscribeApplied {
query_set_id,
initial_update,
} => {
self.apply_update(initial_update, |inner| {
let sub_event_ctx = self.make_event_ctx(());
inner.subscriptions.unsubscribe_applied(&sub_event_ctx, query_set_id);
Event::UnsubscribeApplied
});
Ok(())
}
ParsedMessage::SubscriptionError { query_set_id, error } => {
let error = crate::Error::SubscriptionError { error };
let ctx = self.make_event_ctx(Some(error));
let mut inner = self.inner.lock().unwrap();
inner.subscriptions.subscription_error(&ctx, query_set_id);
Ok(())
}
ParsedMessage::ProcedureResult { request_id, result } => {
let ctx = self.make_event_ctx(());
self.inner
.lock()
.unwrap()
.procedure_callbacks
.resolve(&ctx, request_id, result);
Ok(())
}
}
}
/// Applies `update` to the client cache, then runs the row callbacks for the resulting diff.
///
/// `get_event` is called with the locked inner state to produce the `Event`
/// that the row callbacks will observe through their event context.
fn apply_update(
    &self,
    update: M::DbUpdate,
    get_event: impl FnOnce(&mut DbContextImplInner<M>) -> Event<M::Reducer>,
) {
    // The cache guard is a temporary of this statement,
    // so the cache is unlocked again before any callback runs.
    let diff = update.apply_to_client_cache(&mut *self.cache.lock().unwrap());

    let mut inner_guard = self.inner.lock().unwrap();
    let event = get_event(&mut inner_guard);
    let row_ctx = self.make_event_ctx(event);
    diff.invoke_row_callbacks(&row_ctx, &mut inner_guard.db_callbacks);
}
/// Marks the connection as closed and notifies disconnect listeners.
///
/// Clears `send_chan` (which is what `is_active` inspects), runs the `on_disconnect`
/// callback if it has not been taken yet, then notifies the subscription manager.
fn invoke_disconnected(&self, ctx: &M::ErrorContext) {
    let mut inner_guard = self.inner.lock().unwrap();

    // Dropping the sender is what makes `is_active` report `false`.
    self.send_chan.lock().unwrap().take();

    // `take` guarantees the user callback is invoked at most once.
    let user_callback = inner_guard.on_disconnect.take();
    if let Some(on_disconnect) = user_callback {
        on_disconnect(ctx, ctx.event().clone());
    }

    inner_guard.subscriptions.on_disconnect(ctx);
}
/// Builds an event context of type `Ctx` carrying `event` and a clone of this connection.
fn make_event_ctx<E, Ctx: AbstractEventContext<Module = M, Event = E>>(&self, event: E) -> Ctx {
    Ctx::new(self.clone(), event)
}
/// Applies every mutation currently waiting in the pending-mutations queue.
///
/// Stops as soon as the queue has nothing ready, and returns the first error
/// produced by `apply_mutation`, if any.
fn apply_pending_mutations(&self) -> crate::Result<()> {
    loop {
        // The receiver stays locked while the mutation is applied,
        // and is released at the end of each iteration.
        let mut queue = get_lock_sync(&self.pending_mutations_recv);
        match queue.try_next() {
            Ok(Some(pending_mutation)) => self.apply_mutation(pending_mutation)?,
            _ => return Ok(()),
        }
    }
}
/// Applies a single queued [`PendingMutation`] to the connection state.
///
/// # Errors
///
/// - `Error::Disconnected` if the mutation needs to send a WebSocket message
///   but `send_chan` has already been cleared.
/// - An `InternalError` if a reducer's arguments cannot be BSATN-serialized.
///
/// # Panics
///
/// Panics if the WebSocket sender loop has dropped its receiving end of the channel.
fn apply_mutation(&self, mutation: PendingMutation<M>) -> crate::Result<()> {
    self.debug_log(|out| writeln!(out, "`apply_mutation`: {mutation:?}"));

    // Sends `msg` to the WebSocket loop, or fails with `Disconnected` if the sender is gone.
    // `closed_channel_msg` is the panic message used when the receiving end was dropped.
    let send = |msg: ws::v2::ClientMessage, closed_channel_msg: &str| -> crate::Result<()> {
        self.send_chan
            .lock()
            .unwrap()
            .as_mut()
            .ok_or(crate::Error::Disconnected)?
            .unbounded_send(msg)
            .expect(closed_channel_msg);
        Ok(())
    };

    match mutation {
        PendingMutation::Subscribe { query_set_id, handle } => {
            let mut inner = self.inner.lock().unwrap();
            // Register first, so that later messages about this query set can be routed.
            inner.subscriptions.register_subscription(query_set_id, handle.clone());
            // `start` yields no message when there is nothing to send for this handle.
            if let Some(msg) = handle.start() {
                send(
                    ws::v2::ClientMessage::Subscribe(msg),
                    "Unable to send subscribe message: WS sender loop has dropped its recv channel",
                )?;
            }
        }
        PendingMutation::Unsubscribe { query_set_id } => {
            let mut inner = self.inner.lock().unwrap();
            match inner.subscriptions.handle_pending_unsubscribe(query_set_id) {
                PendingUnsubscribeResult::DoNothing => {}
                PendingUnsubscribeResult::RunCallback(callback) => {
                    callback(&self.make_event_ctx(()));
                }
                PendingUnsubscribeResult::SendUnsubscribe(m) => {
                    send(
                        ws::v2::ClientMessage::Unsubscribe(m),
                        "Unable to send unsubscribe message: WS sender loop has dropped its recv channel",
                    )?;
                }
            }
        }
        PendingMutation::InvokeReducerWithCallback { reducer, callback } => {
            let request_id = next_request_id();
            let reducer_name = reducer.reducer_name();
            let args = reducer
                .args_bsatn()
                .map_err(|e| InternalError::new("Failed to BSATN-serialize reducer arguments").with_cause(e))?;
            // Remember the call so that the matching `ReducerResult` can find its callback.
            {
                let mut inner = self.inner.lock().unwrap();
                inner.reducer_callbacks.store_call_info(request_id, reducer, callback);
            }
            let call = ws::v2::CallReducer {
                reducer: reducer_name.into(),
                args: args.into(),
                request_id,
                flags: ws::v2::CallReducerFlags::Default,
            };
            send(
                ws::v2::ClientMessage::CallReducer(call),
                "Unable to send reducer call message: WS sender loop has dropped its recv channel",
            )?;
        }
        PendingMutation::InvokeProcedureWithCallback {
            procedure,
            args,
            callback,
        } => {
            let request_id = next_request_id();
            {
                let mut inner = self.inner.lock().unwrap();
                inner.procedure_callbacks.insert(request_id, callback);
            }
            let call = ws::v2::CallProcedure {
                procedure: procedure.into(),
                args: args.into(),
                request_id,
                flags: ws::v2::CallProcedureFlags::Default,
            };
            send(
                ws::v2::ClientMessage::CallProcedure(call),
                "Unable to send procedure call message: WS sender loop has dropped its recv channel",
            )?;
        }
        PendingMutation::Disconnect => {
            // Dropping the sender marks the connection inactive.
            self.send_chan.lock().unwrap().take();
        }
        PendingMutation::AddInsertCallback {
            table,
            callback_id,
            callback,
        } => {
            let mut inner = self.inner.lock().unwrap();
            inner
                .db_callbacks
                .get_table_callbacks(table)
                .register_on_insert(callback_id, callback);
        }
        PendingMutation::AddDeleteCallback {
            table,
            callback_id,
            callback,
        } => {
            let mut inner = self.inner.lock().unwrap();
            inner
                .db_callbacks
                .get_table_callbacks(table)
                .register_on_delete(callback_id, callback);
        }
        PendingMutation::AddUpdateCallback {
            table,
            callback_id,
            callback,
        } => {
            let mut inner = self.inner.lock().unwrap();
            inner
                .db_callbacks
                .get_table_callbacks(table)
                .register_on_update(callback_id, callback);
        }
        PendingMutation::RemoveInsertCallback { table, callback_id } => {
            let mut inner = self.inner.lock().unwrap();
            inner
                .db_callbacks
                .get_table_callbacks(table)
                .remove_on_insert(callback_id);
        }
        PendingMutation::RemoveDeleteCallback { table, callback_id } => {
            let mut inner = self.inner.lock().unwrap();
            inner
                .db_callbacks
                .get_table_callbacks(table)
                .remove_on_delete(callback_id);
        }
        PendingMutation::RemoveUpdateCallback { table, callback_id } => {
            let mut inner = self.inner.lock().unwrap();
            inner
                .db_callbacks
                .get_table_callbacks(table)
                .remove_on_update(callback_id);
        }
    }
    Ok(())
}
/// Processes at most one already-received WebSocket message without waiting for one.
///
/// Pending mutations are applied both before and after the message is handled.
/// Returns `Ok(true)` if a message was processed and `Ok(false)` if none was ready.
///
/// # Errors
///
/// Returns `Error::Disconnected` (after running the disconnect callbacks) when the
/// incoming channel has been closed, or any error from processing the message
/// or applying mutations.
pub fn advance_one_message(&self) -> crate::Result<bool> {
    self.apply_pending_mutations()?;

    let mut incoming = get_lock_sync(&self.recv);
    let outcome = match incoming.try_next() {
        Ok(Some(msg)) => self.process_message(msg).map(|()| true),
        // `try_next` yielding `Ok(None)` is treated as the channel having been closed.
        Ok(None) => {
            let disconnect_ctx = self.make_event_ctx(None);
            self.invoke_disconnected(&disconnect_ctx);
            Err(crate::Error::Disconnected)
        }
        // `Err(_)` means no message is waiting right now.
        Err(_) => Ok(false),
    };
    drop(incoming);

    self.apply_pending_mutations()?;
    outcome
}
/// Waits for the next thing to act on: either a locally queued [`PendingMutation`]
/// or a parsed WebSocket message.
///
/// A pending mutation that is already queued is always returned in preference to a
/// WebSocket message. `Message::Ws(None)` is returned when the WebSocket receiver yields `None`.
async fn get_message(&self) -> Message<M> {
#![allow(clippy::await_holding_lock)]
// Both receivers stay locked for the whole wait below (hence the `allow` above).
// NOTE(review): presumably only a problem if several callers wait concurrently — confirm.
let mut pending_mutations = get_lock_async(&self.pending_mutations_recv).await;
let mut recv = get_lock_async(&self.recv).await;
// Fast path: a mutation is already queued, so return it without waiting.
// The `unwrap` panics if the channel reports itself closed;
// presumably unreachable while `self` still owns `pending_mutations_send`.
if let Ok(pending_mutation) = pending_mutations.try_next() {
return Message::Local(pending_mutation.unwrap());
}
// Otherwise wait on both channels and return whichever produces a value first.
#[cfg(not(feature = "browser"))]
tokio::select! {
pending_mutation = pending_mutations.next() => Message::Local(pending_mutation.unwrap()),
incoming_message = recv.next() => Message::Ws(incoming_message),
}
// Browser build: same race, using `futures::select!` on fused, pinned futures.
#[cfg(feature = "browser")]
{
let (pending_fut, recv_fut) = (pending_mutations.next().fuse(), recv.next().fuse());
pin_mut!(pending_fut, recv_fut);
futures::select! {
pending_mutation = pending_fut => Message::Local(pending_mutation.unwrap()),
incoming_message = recv_fut => Message::Ws(incoming_message),
}
}
}
#[cfg(not(feature = "browser"))]
/// Blocks the calling thread until a pending mutation or a WebSocket message arrives, then handles it.
///
/// # Errors
///
/// Returns `Error::Disconnected` (after running the disconnect callbacks) when the
/// WebSocket channel has closed, or any error from handling the message or mutation.
pub fn advance_one_message_blocking(&self) -> crate::Result<()> {
    let next = self.runtime.block_on(self.get_message());
    match next {
        Message::Ws(Some(msg)) => self.process_message(msg),
        Message::Ws(None) => {
            let disconnect_ctx = self.make_event_ctx(None);
            self.invoke_disconnected(&disconnect_ctx);
            Err(crate::Error::Disconnected)
        }
        Message::Local(pending) => self.apply_mutation(pending),
    }
}
/// Asynchronously waits for a pending mutation or a WebSocket message, then handles it.
///
/// # Errors
///
/// Returns `Error::Disconnected` (after running the disconnect callbacks) when the
/// WebSocket channel has closed, or any error from handling the message or mutation.
pub async fn advance_one_message_async(&self) -> crate::Result<()> {
    let next = self.get_message().await;
    match next {
        Message::Ws(Some(msg)) => self.process_message(msg),
        Message::Ws(None) => {
            let disconnect_ctx = self.make_event_ctx(None);
            self.invoke_disconnected(&disconnect_ctx);
            Err(crate::Error::Disconnected)
        }
        Message::Local(pending) => self.apply_mutation(pending),
    }
}
/// Processes every message that is already available, returning once none is ready.
///
/// Stops early with the first error returned by `advance_one_message`.
pub fn frame_tick(&self) -> crate::Result<()> {
    loop {
        let processed_one = self.advance_one_message()?;
        if !processed_one {
            return Ok(());
        }
    }
}
#[cfg(not(feature = "browser"))]
/// Spawns a thread that keeps advancing the connection until it disconnects.
///
/// The thread returns on a normal disconnect and panics on any other error.
pub fn run_threaded(&self) -> std::thread::JoinHandle<()> {
    let ctx = self.clone();
    std::thread::spawn(move || loop {
        if let Err(err) = ctx.advance_one_message_blocking() {
            if error_is_normal_disconnect(&err) {
                return;
            }
            panic!("{err:?}");
        }
    })
}
#[cfg(feature = "browser")]
/// Spawns a local task that keeps advancing the connection until it disconnects.
///
/// The task finishes on a normal disconnect and panics on any other error.
pub fn run_background_task(&self) {
    let ctx = self.clone();
    wasm_bindgen_futures::spawn_local(async move {
        loop {
            if let Err(err) = ctx.advance_one_message_async().await {
                if error_is_normal_disconnect(&err) {
                    return;
                }
                panic!("{err:?}");
            }
        }
    })
}
/// Keeps advancing the connection until it disconnects or an error occurs.
///
/// Returns `Ok(())` on a normal disconnect.
///
/// # Errors
///
/// Returns the first error from `advance_one_message_async` that is not a normal disconnect.
pub async fn run_async(&self) -> crate::Result<()> {
    // `self` is borrowed for the whole lifetime of the returned future,
    // so there is no need to clone the context before looping.
    loop {
        match self.advance_one_message_async().await {
            Ok(()) => {}
            Err(e) if error_is_normal_disconnect(&e) => return Ok(()),
            Err(e) => return Err(e),
        }
    }
}
/// Returns `true` while the outgoing WebSocket sender is still present,
/// i.e. until the connection has been disconnected.
pub fn is_active(&self) -> bool {
    let sender = self.send_chan.lock().unwrap();
    sender.is_some()
}
pub fn disconnect(&self) -> crate::Result<()> {
if !self.is_active() {
return Err(crate::Error::Disconnected);
}
self.pending_mutations_send
.unbounded_send(PendingMutation::Disconnect)
.unwrap();
Ok(())
}
/// Enqueues `mutation` to be applied the next time the connection is advanced.
///
/// Panics if the pending-mutations channel rejects the message.
fn queue_mutation(&self, mutation: PendingMutation<M>) {
    let queued = self.pending_mutations_send.unbounded_send(mutation);
    queued.unwrap();
}
/// Creates a [`TableHandle`] for the table called `table_name`.
///
/// The handle shares this connection's client cache and pending-mutations sender.
pub fn get_table<Row: InModule<Module = M> + Send + Sync + 'static>(
    &self,
    table_name: &'static str,
) -> TableHandle<Row> {
    TableHandle {
        client_cache: Arc::clone(&self.cache),
        pending_mutations: self.pending_mutations_send.clone(),
        table_name,
    }
}
/// Queues a call to a reducer, registering `callback` to run when its result arrives.
///
/// `reducer` is converted into the module's `Reducer` type before being queued.
/// The request is only sent once the queued mutation is applied; this method itself
/// always returns `Ok(())`.
pub fn invoke_reducer_with_callback<Args>(
    &self,
    reducer: Args,
    callback: impl FnOnce(&<M as SpacetimeModule>::ReducerEventContext, Result<Result<(), String>, InternalError>)
        + Send
        + 'static,
) -> crate::Result<()>
where
    <M as SpacetimeModule>::Reducer: From<Args>,
{
    let reducer: <M as SpacetimeModule>::Reducer = reducer.into();
    let mutation = PendingMutation::InvokeReducerWithCallback {
        reducer,
        callback: Box::new(callback),
    };
    self.queue_mutation(mutation);
    Ok(())
}
/// Returns our `Identity`, or `None` if it has not been stored yet.
pub fn try_identity(&self) -> Option<Identity> {
    let stored = self.identity.lock().unwrap();
    *stored
}
/// Returns this connection's `ConnectionId`.
///
/// # Panics
///
/// Panics if no `ConnectionId` has been stored yet. Use `try_connection_id`
/// for a non-panicking variant.
pub fn connection_id(&self) -> ConnectionId {
    self.try_connection_id()
        .expect("`connection_id` called before a `ConnectionId` was stored; use `try_connection_id` instead")
}
/// Returns this connection's `ConnectionId`, or `None` if none has been stored yet.
pub fn try_connection_id(&self) -> Option<ConnectionId> {
    let stored = self.connection_id.lock().unwrap();
    *stored
}
pub fn invoke_procedure_with_callback<
Args: Serialize + InModule<Module = M>,
RetVal: for<'a> Deserialize<'a> + 'static,
>(
&self,
procedure_name: &'static str,
args: Args,
callback: impl FnOnce(&<M as SpacetimeModule>::ProcedureEventContext, Result<RetVal, InternalError>)
+ Send
+ 'static,
) {
self.queue_mutation(PendingMutation::InvokeProcedureWithCallback {
procedure: procedure_name,
args: bsatn::to_vec(&args).expect("Failed to BSATN serialize procedure args"),
callback: Box::new(move |ctx, ret| {
callback(
ctx,
ret.map(|ret| {
bsatn::from_slice::<RetVal>(&ret[..])
.expect("Failed to BSATN deserialize procedure return value")
}),
)
}),
});
}
}
/// Callback run once when the `IdentityToken` message is processed;
/// receives the connection, our `Identity` and the token string.
type OnConnectCallback<M> = Box<dyn FnOnce(&<M as SpacetimeModule>::DbConnection, Identity, &str) + Send + 'static>;
/// Callback type for connection failures; receives an error context and the error.
type OnConnectErrorCallback<M> = Box<dyn FnOnce(&<M as SpacetimeModule>::ErrorContext, crate::Error) + Send + 'static>;
/// Callback run once on disconnect; the `Option` carries the error that caused
/// the disconnect, if there was one.
type OnDisconnectCallback<M> =
Box<dyn FnOnce(&<M as SpacetimeModule>::ErrorContext, Option<crate::Error>) + Send + 'static>;
/// Mutable state of a [`DbContextImpl`] that lives behind a single mutex:
/// registered callbacks and subscription bookkeeping.
pub(crate) struct DbContextImplInner<M: SpacetimeModule> {
/// `Some` only when `enter_or_create_runtime` had to create a new Tokio runtime.
/// Never read; presumably stored so the runtime lives as long as the connection — TODO confirm.
#[allow(unused)]
#[cfg(not(feature = "browser"))]
runtime: Option<Runtime>,
/// Row callbacks (insert / delete / update), looked up per table.
db_callbacks: DbCallbacks<M>,
/// Call info and callbacks for in-flight reducer calls, keyed by request id.
reducer_callbacks: ReducerCallbacks<M>,
/// Bookkeeping for active and pending subscriptions.
pub(crate) subscriptions: SubscriptionManager<M>,
/// Taken and invoked when the `IdentityToken` message is processed.
on_connect: Option<OnConnectCallback<M>>,
/// Stored from the builder, but not invoked anywhere in this file.
#[allow(unused)]
on_connect_error: Option<OnConnectErrorCallback<M>>,
/// Taken and invoked by `invoke_disconnected`.
on_disconnect: Option<OnDisconnectCallback<M>>,
/// Callbacks for in-flight procedure calls, keyed by request id.
procedure_callbacks: ProcedureCallbacks<M>,
}
/// Builder that collects the configuration for a connection to a remote database.
///
/// Configure it with the `with_*` and `on_*` methods, then call `build`.
pub struct DbConnectionBuilder<M: SpacetimeModule> {
/// URI of the host to connect to; set by `with_uri`.
uri: Option<Uri>,
/// Name or identity of the remote database; set by `with_database_name`.
database_name: Option<String>,
/// Optional authentication token; set by `with_token`.
token: Option<String>,
/// Callback to run once connected; set by `on_connect`.
on_connect: Option<OnConnectCallback<M>>,
/// Callback for connection errors; set by `on_connect_error`.
on_connect_error: Option<OnConnectErrorCallback<M>>,
/// Callback to run on disconnect; set by `on_disconnect`.
on_disconnect: Option<OnDisconnectCallback<M>>,
/// Path of a file to append extra debug logging to; set by `with_debug_to_file`.
additional_logging_path: Option<PathBuf>,
/// WebSocket parameters such as compression and confirmed reads.
params: WsParams,
}
/// Process-wide `ConnectionId` override, initialised at most once by `set_connection_id`.
static CONNECTION_ID: OnceLock<ConnectionId> = OnceLock::new();

/// Returns the `ConnectionId` override, if one has been set.
fn get_connection_id_override() -> Option<ConnectionId> {
    let current: Option<&ConnectionId> = CONNECTION_ID.get();
    current.copied()
}
#[doc(hidden)]
/// Sets the process-wide `ConnectionId` override used by connections built afterwards.
///
/// Calling this again with the same `id` succeeds.
///
/// # Errors
///
/// Returns an `InternalError` if the override was already initialised to a different value.
pub fn set_connection_id(id: ConnectionId) -> crate::Result<()> {
    let stored = CONNECTION_ID.get_or_init(|| id);
    if *stored == id {
        Ok(())
    } else {
        Err(InternalError::new(
            "Call to set_connection_id after CONNECTION_ID was initialized to a different value ",
        )
        .into())
    }
}
/// Runs `body` against the extra debug log file, if one is configured.
///
/// Does nothing when `extra_logging` is `None`.
///
/// # Panics
///
/// Panics if the file's mutex is poisoned or if `body` returns an I/O error.
pub(crate) fn debug_log(
    extra_logging: &Option<SharedCell<File>>,
    body: impl FnOnce(&mut File) -> std::result::Result<(), std::io::Error>,
) {
    let Some(shared_file) = extra_logging else {
        return;
    };
    let mut log_file = shared_file.lock().expect("`extra_logging` file Mutex is poisoned");
    body(&mut log_file).expect("Writing debug log failed");
}
impl<M: SpacetimeModule> DbConnectionBuilder<M> {
#[doc(hidden)]
/// Creates a builder with nothing configured and default WebSocket parameters.
pub fn new() -> Self {
    Self {
        // Connection target and credentials.
        uri: None,
        database_name: None,
        token: None,
        // Lifecycle callbacks.
        on_connect: None,
        on_connect_error: None,
        on_disconnect: None,
        // Diagnostics and transport settings.
        additional_logging_path: None,
        params: Default::default(),
    }
}
#[must_use = "
You must explicitly advance the connection by calling any one of:
- `DbConnection::frame_tick`.
- `DbConnection::run_threaded`.
- `DbConnection::run_background_task`.
- `DbConnection::run_async`.
- `DbConnection::advance_one_message`.
- `DbConnection::advance_one_message_blocking`.
- `DbConnection::advance_one_message_async`.
Which of these methods you should call depends on the specific needs of your application,
but you must call one of them, or else the connection will never progress.
"]
#[cfg(not(feature = "browser"))]
/// Opens the connection described by this builder and wraps it in the module's `DbConnection`.
///
/// # Errors
///
/// Returns any error produced while building the underlying connection.
pub fn build(self) -> crate::Result<M::DbConnection> {
    let connection = <M::DbConnection as DbConnection>::new(self.build_impl()?);
    Ok(connection)
}
#[cfg(feature = "browser")]
/// Opens the connection described by this builder and wraps it in the module's `DbConnection`.
///
/// # Errors
///
/// Returns any error produced while building the underlying connection.
pub async fn build(self) -> crate::Result<M::DbConnection> {
    let connection = <M::DbConnection as DbConnection>::new(self.build_impl().await?);
    Ok(connection)
}
#[cfg(not(feature = "browser"))]
fn build_impl(self) -> crate::Result<DbContextImpl<M>> {
let extra_logging = self
.additional_logging_path
.map(|path| {
OpenOptions::new().append(true).create(true).open(&path).map_err(|e| {
InternalError::new(format!("Failed to open file '{path:?}' for additional logging")).with_cause(e)
})
})
.transpose()?
.map(|file| Arc::new(StdMutex::new(file)));
let (runtime, handle) = enter_or_create_runtime()?;
let connection_id_override = get_connection_id_override();
let ws_connection = tokio::task::block_in_place(|| {
handle.block_on(WsConnection::connect(
self.uri.unwrap(),
self.database_name.as_ref().unwrap(),
self.token.as_deref(),
connection_id_override,
self.params,
))
})
.map_err(|source| crate::Error::FailedToConnect {
source: InternalError::new("Failed to initiate WebSocket connection").with_cause(source),
})?;
let (_websocket_loop_handle, raw_msg_recv, raw_msg_send) =
ws_connection.spawn_message_loop(&handle, extra_logging.clone());
let (_parse_loop_handle, parsed_recv_chan) =
spawn_parse_loop::<M>(raw_msg_recv, &handle, extra_logging.clone());
let parsed_recv_chan = Arc::new(TokioMutex::new(parsed_recv_chan));
let (pending_mutations_send, pending_mutations_recv) = mpsc::unbounded();
let pending_mutations_recv = Arc::new(TokioMutex::new(pending_mutations_recv));
let inner_ctx = build_db_ctx_inner(runtime, self.on_connect, self.on_connect_error, self.on_disconnect);
Ok(build_db_ctx(
handle,
inner_ctx,
raw_msg_send,
parsed_recv_chan,
pending_mutations_send,
pending_mutations_recv,
connection_id_override,
extra_logging,
))
}
#[cfg(feature = "browser")]
/// Opens the WebSocket connection and assembles the [`DbContextImpl`] around it.
///
/// Extra file logging is never enabled in the browser build.
///
/// # Errors
///
/// Returns `Error::FailedToConnect` if no URI or database name was supplied,
/// or if the WebSocket connection cannot be established.
async fn build_impl(self) -> crate::Result<DbContextImpl<M>> {
    // A misconfigured builder is reported as an error rather than a panic.
    let uri = self.uri.ok_or_else(|| crate::Error::FailedToConnect {
        source: InternalError::new("No URI supplied: call `DbConnectionBuilder::with_uri` before `build`"),
    })?;
    let database_name = self.database_name.ok_or_else(|| crate::Error::FailedToConnect {
        source: InternalError::new(
            "No database name supplied: call `DbConnectionBuilder::with_database_name` before `build`",
        ),
    })?;

    let extra_logging = None;
    let connection_id_override = get_connection_id_override();
    let ws_connection = WsConnection::connect(
        uri,
        &database_name,
        self.token.as_deref(),
        connection_id_override,
        self.params,
    )
    .await
    .map_err(|source| crate::Error::FailedToConnect {
        source: InternalError::new("Failed to initiate WebSocket connection").with_cause(source),
    })?;

    // Raw messages flow from the WebSocket loop into the parse loop,
    // which produces the parsed messages the context consumes.
    let (raw_msg_recv, raw_msg_send) = ws_connection.spawn_message_loop();
    let parsed_recv_chan = spawn_parse_loop::<M>(raw_msg_recv, extra_logging.clone());
    let parsed_recv_chan = Arc::new(StdMutex::new(parsed_recv_chan));

    let (pending_mutations_send, pending_mutations_recv) = mpsc::unbounded();
    let pending_mutations_recv = Arc::new(StdMutex::new(pending_mutations_recv));

    let inner_ctx = build_db_ctx_inner(self.on_connect, self.on_connect_error, self.on_disconnect);
    Ok(build_db_ctx(
        inner_ctx,
        raw_msg_send,
        parsed_recv_chan,
        pending_mutations_send,
        pending_mutations_recv,
        connection_id_override,
        extra_logging,
    ))
}
pub fn with_uri<E: std::fmt::Debug>(mut self, uri: impl TryInto<Uri, Error = E>) -> Self {
let uri = uri.try_into().expect("Unable to parse supplied URI");
self.uri = Some(uri);
self
}
pub fn with_database_name(mut self, name_or_identity: impl Into<String>) -> Self {
self.database_name = Some(name_or_identity.into());
self
}
pub fn with_token(mut self, token: Option<impl Into<String>>) -> Self {
self.token = token.map(|token| token.into());
self
}
/// Sets the compression setting stored in the WebSocket parameters.
pub fn with_compression(self, compression: ws::common::Compression) -> Self {
    let mut builder = self;
    builder.params.compression = compression;
    builder
}
/// Sets the `confirmed` flag stored in the WebSocket parameters.
pub fn with_confirmed_reads(self, confirmed: bool) -> Self {
    let mut builder = self;
    builder.params.confirmed = Some(confirmed);
    builder
}
pub fn with_debug_to_file(mut self, path: impl Into<PathBuf>) -> Self {
self.additional_logging_path = Some(path.into());
self
}
/// Registers the callback to run once the connection has been established.
///
/// # Panics
///
/// Panics if an `on_connect` callback has already been registered.
pub fn on_connect(mut self, callback: impl FnOnce(&M::DbConnection, Identity, &str) + Send + 'static) -> Self {
    assert!(
        self.on_connect.is_none(),
        "DbConnectionBuilder can only register a single `on_connect` callback.
Instead of registering multiple `on_connect` callbacks, register a single callback which does multiple operations."
    );
    self.on_connect = Some(Box::new(callback));
    self
}
/// Registers the callback for connection errors.
///
/// # Panics
///
/// Panics if an `on_connect_error` callback has already been registered.
pub fn on_connect_error(mut self, callback: impl FnOnce(&M::ErrorContext, crate::Error) + Send + 'static) -> Self {
    assert!(
        self.on_connect_error.is_none(),
        "DbConnectionBuilder can only register a single `on_connect_error` callback.
Instead of registering multiple `on_connect_error` callbacks, register a single callback which does multiple operations."
    );
    self.on_connect_error = Some(Box::new(callback));
    self
}
/// Registers the callback to run when the connection is closed.
///
/// # Panics
///
/// Panics if an `on_disconnect` callback has already been registered.
pub fn on_disconnect(
    mut self,
    callback: impl FnOnce(&M::ErrorContext, Option<crate::Error>) + Send + 'static,
) -> Self {
    assert!(
        self.on_disconnect.is_none(),
        "DbConnectionBuilder can only register a single `on_disconnect` callback.
Instead of registering multiple `on_disconnect` callbacks, register a single callback which does multiple operations."
    );
    self.on_disconnect = Some(Box::new(callback));
    self
}
}
/// Builds the shared inner state of a connection from the builder's callbacks.
///
/// All callback registries and the subscription manager start out in their default state.
fn build_db_ctx_inner<M: SpacetimeModule>(
    #[cfg(not(feature = "browser"))] runtime: Option<Runtime>,
    on_connect_cb: Option<OnConnectCallback<M>>,
    on_connect_error_cb: Option<OnConnectErrorCallback<M>>,
    on_disconnect_cb: Option<OnDisconnectCallback<M>>,
) -> Arc<StdMutex<DbContextImplInner<M>>> {
    let inner = DbContextImplInner {
        #[cfg(not(feature = "browser"))]
        runtime,
        db_callbacks: Default::default(),
        reducer_callbacks: Default::default(),
        subscriptions: Default::default(),
        on_connect: on_connect_cb,
        on_connect_error: on_connect_error_cb,
        on_disconnect: on_disconnect_cb,
        procedure_callbacks: Default::default(),
    };
    Arc::new(StdMutex::new(inner))
}
#[allow(clippy::too_many_arguments)]
/// Assembles a [`DbContextImpl`] from its already-constructed parts.
///
/// Creates the client cache and registers the module's tables in it.
/// The identity starts out unknown; the connection id starts as `connection_id`.
fn build_db_ctx<M: SpacetimeModule>(
    #[cfg(not(feature = "browser"))] runtime_handle: runtime::Handle,
    inner_ctx: Arc<StdMutex<DbContextImplInner<M>>>,
    raw_msg_send: mpsc::UnboundedSender<ws::v2::ClientMessage>,
    parsed_msg_recv: SharedAsyncCell<mpsc::UnboundedReceiver<ParsedMessage<M>>>,
    pending_mutations_send: mpsc::UnboundedSender<PendingMutation<M>>,
    pending_mutations_recv: SharedAsyncCell<mpsc::UnboundedReceiver<PendingMutation<M>>>,
    connection_id: Option<ConnectionId>,
    extra_logging: Option<SharedCell<File>>,
) -> DbContextImpl<M> {
    let cache = {
        let mut table_cache = ClientCache::new(extra_logging.clone());
        M::register_tables(&mut table_cache);
        Arc::new(StdMutex::new(table_cache))
    };
    let send_chan = Arc::new(StdMutex::new(Some(raw_msg_send)));
    let identity = Arc::new(StdMutex::new(None));
    let connection_id = Arc::new(StdMutex::new(connection_id));

    DbContextImpl {
        #[cfg(not(feature = "browser"))]
        runtime: runtime_handle,
        inner: inner_ctx,
        send_chan,
        cache,
        recv: parsed_msg_recv,
        pending_mutations_send,
        pending_mutations_recv,
        identity,
        connection_id,
        extra_logging,
    }
}
#[cfg(not(feature = "browser"))]
/// Returns a handle to the current Tokio runtime, creating a new runtime if there is none.
///
/// The first tuple element is `Some` only when a new runtime was created here.
///
/// # Errors
///
/// Returns an `InternalError` if a new runtime cannot be built, or if looking up
/// the current runtime fails for a reason other than a missing context.
fn enter_or_create_runtime() -> crate::Result<(Option<Runtime>, runtime::Handle)> {
    let lookup_error = match runtime::Handle::try_current() {
        // Already inside a runtime: reuse it and own nothing.
        Ok(handle) => return Ok((None, handle)),
        Err(e) => e,
    };

    if !lookup_error.is_missing_context() {
        return Err(
            InternalError::new("Unexpected error when getting current Tokio runtime")
                .with_cause(lookup_error)
                .into(),
        );
    }

    // No runtime in scope: build a single-worker multi-thread runtime for the connection.
    let created = runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(1)
        .thread_name("spacetimedb-background-connection")
        .build()
        .map_err(|source| InternalError::new("Failed to create Tokio runtime").with_cause(source))?;
    let handle = created.handle().clone();
    Ok((Some(created), handle))
}
/// Locks a Tokio mutex from synchronous code by blocking the current thread.
#[cfg(not(feature = "browser"))]
fn get_lock_sync<T>(mutex: &TokioMutex<T>) -> tokio::sync::MutexGuard<'_, T> {
mutex.blocking_lock()
}
/// Browser variant: locks the `std` mutex, panicking if it is poisoned.
#[cfg(feature = "browser")]
fn get_lock_sync<T>(mutex: &StdMutex<T>) -> std::sync::MutexGuard<'_, T> {
mutex.lock().unwrap()
}
/// Locks a Tokio mutex from async code, awaiting until the lock is available.
#[cfg(not(feature = "browser"))]
async fn get_lock_async<T>(mutex: &TokioMutex<T>) -> tokio::sync::MutexGuard<'_, T> {
mutex.lock().await
}
/// Browser variant: locks the `std` mutex directly without awaiting,
/// panicking if it is poisoned.
#[cfg(feature = "browser")]
pub async fn get_lock_async<T>(mutex: &StdMutex<T>) -> std::sync::MutexGuard<'_, T> {
mutex.lock().unwrap()
}
/// A server message after it has been parsed by `parse_loop` into module-specific types.
#[derive(Debug)]
enum ParsedMessage<M: SpacetimeModule> {
/// A database update that is not the result of one of our own reducer calls.
TransactionUpdate(M::DbUpdate),
/// The server's initial message: our identity, token and connection id.
IdentityToken(Identity, Box<str>, ConnectionId),
/// A subscription was applied; `initial_update` holds its rows.
SubscribeApplied {
query_set_id: QuerySetId,
initial_update: M::DbUpdate,
},
/// An unsubscribe was applied; `initial_update` holds the rows returned with it.
UnsubscribeApplied {
query_set_id: QuerySetId,
initial_update: M::DbUpdate,
},
/// The server reported an error for the given query set.
SubscriptionError {
query_set_id: QuerySetId,
error: String,
},
/// An error, e.g. a message that failed to parse; processing it ends the connection.
Error(crate::Error),
/// Result of a reducer call we made, identified by `request_id`.
///
/// - `Ok(Ok(update))`: the reducer succeeded, producing `update`.
/// - `Ok(Err(message))`: the reducer returned an error message.
/// - `Err(_)`: the server reported an internal error.
ReducerResult {
request_id: u32,
timestamp: Timestamp,
result: Result<Result<M::DbUpdate, String>, InternalError>,
},
/// Result of a procedure call we made: the still-serialized return value, or an internal error.
ProcedureResult {
request_id: u32,
result: Result<Bytes, InternalError>,
},
}
#[cfg(not(feature = "browser"))]
/// Spawns `parse_loop` on the runtime behind `handle`.
///
/// Returns the task's join handle together with the receiver for parsed messages.
fn spawn_parse_loop<M: SpacetimeModule>(
    raw_message_recv: mpsc::UnboundedReceiver<ws::v2::ServerMessage>,
    handle: &runtime::Handle,
    extra_logging: Option<SharedCell<File>>,
) -> (tokio::task::JoinHandle<()>, mpsc::UnboundedReceiver<ParsedMessage<M>>) {
    let (parsed_tx, parsed_rx) = mpsc::unbounded();
    let task = handle.spawn(parse_loop(raw_message_recv, parsed_tx, extra_logging));
    (task, parsed_rx)
}
#[cfg(feature = "browser")]
/// Spawns `parse_loop` as a local task and returns the receiver for parsed messages.
fn spawn_parse_loop<M: SpacetimeModule>(
    raw_message_recv: mpsc::UnboundedReceiver<ws::v2::ServerMessage>,
    extra_logging: Option<SharedCell<File>>,
) -> mpsc::UnboundedReceiver<ParsedMessage<M>> {
    let (parsed_tx, parsed_rx) = mpsc::unbounded();
    let task = parse_loop(raw_message_recv, parsed_tx, extra_logging);
    wasm_bindgen_futures::spawn_local(task);
    parsed_rx
}
/// Converts raw server messages from `recv` into [`ParsedMessage`]s and forwards them on `send`.
///
/// Runs until `recv` is exhausted. A payload that fails to parse is forwarded as
/// `ParsedMessage::Error` rather than ending the loop.
///
/// # Panics
///
/// Panics if a reducer returns a non-empty return value, if a one-off query result
/// is received, if an unsubscribe response carries no rows, or if the parsed message
/// cannot be sent on `send`.
async fn parse_loop<M: SpacetimeModule>(
    mut recv: mpsc::UnboundedReceiver<ws::v2::ServerMessage>,
    send: mpsc::UnboundedSender<ParsedMessage<M>>,
    extra_logging: Option<SharedCell<File>>,
) {
    while let Some(msg) = recv.next().await {
        debug_log(&extra_logging, |file| {
            writeln!(file, "`parse_loop`: Got raw message: {msg:?}")
        });

        let parsed: ParsedMessage<M> = match msg {
            ws::v2::ServerMessage::TransactionUpdate(transaction_update) => {
                M::DbUpdate::parse_update(transaction_update)
                    .map(|db_update| ParsedMessage::TransactionUpdate(db_update))
                    .unwrap_or_else(|e| {
                        ParsedMessage::Error(
                            InternalError::failed_parse("TransactionUpdate", "TransactionUpdate")
                                .with_cause(e)
                                .into(),
                        )
                    })
            }
            ws::v2::ServerMessage::ReducerResult(ws::v2::ReducerResult {
                request_id,
                result,
                timestamp,
            }) => match result {
                // No rows changed: report an empty update.
                ws::v2::ReducerOutcome::OkEmpty => ParsedMessage::ReducerResult {
                    request_id,
                    timestamp,
                    result: Ok(Ok(M::DbUpdate::default())),
                },
                ws::v2::ReducerOutcome::Ok(ws::v2::ReducerOk {
                    ret_value,
                    transaction_update,
                }) => {
                    assert!(
                        ret_value.is_empty(),
                        "Reducer return value should be unit, i.e. 0 bytes, but got {ret_value:?}"
                    );
                    M::DbUpdate::parse_update(transaction_update)
                        .map(|db_update| ParsedMessage::ReducerResult {
                            request_id,
                            timestamp,
                            result: Ok(Ok(db_update)),
                        })
                        .unwrap_or_else(|e| {
                            ParsedMessage::Error(
                                InternalError::failed_parse("TransactionUpdate", "ReducerResult")
                                    .with_cause(e)
                                    .into(),
                            )
                        })
                }
                // The reducer's error is a BSATN-encoded string.
                ws::v2::ReducerOutcome::Err(error_return) => bsatn::from_slice::<String>(&error_return)
                    .map(|error_message| ParsedMessage::ReducerResult {
                        request_id,
                        timestamp,
                        result: Ok(Err(error_message)),
                    })
                    .unwrap_or_else(|e| {
                        ParsedMessage::Error(
                            InternalError::failed_parse("String", "ReducerResult")
                                .with_cause(e)
                                .into(),
                        )
                    }),
                ws::v2::ReducerOutcome::InternalError(error_message) => ParsedMessage::ReducerResult {
                    request_id,
                    timestamp,
                    result: Err(InternalError::new(error_message)),
                },
            },
            ws::v2::ServerMessage::InitialConnection(ws::v2::InitialConnection {
                identity,
                token,
                connection_id,
            }) => ParsedMessage::IdentityToken(identity, token, connection_id),
            ws::v2::ServerMessage::OneOffQueryResult(_) => {
                unreachable!("The Rust SDK does not implement one-off queries")
            }
            ws::v2::ServerMessage::SubscribeApplied(subscribe_applied) => {
                let query_set_id = subscribe_applied.query_set_id;
                M::DbUpdate::parse_initial_rows(subscribe_applied.rows)
                    .map(|initial_update| ParsedMessage::SubscribeApplied {
                        query_set_id,
                        initial_update,
                    })
                    .unwrap_or_else(|e| {
                        ParsedMessage::Error(
                            InternalError::failed_parse("DbUpdate", "SubscribeApplied")
                                .with_cause(e)
                                .into(),
                        )
                    })
            }
            ws::v2::ServerMessage::UnsubscribeApplied(ws::v2::UnsubscribeApplied {
                query_set_id,
                rows,
                ..
            }) => {
                let rows = match rows {
                    Some(rows) => rows,
                    None => unreachable!("The Rust SDK always requests rows to delete when unsubscribing"),
                };
                M::DbUpdate::parse_unsubscribe_rows(rows)
                    .map(|initial_update| ParsedMessage::UnsubscribeApplied {
                        query_set_id,
                        initial_update,
                    })
                    .unwrap_or_else(|e| {
                        ParsedMessage::Error(
                            InternalError::failed_parse("DbUpdate", "UnsubscribeApplied")
                                .with_cause(e)
                                .into(),
                        )
                    })
            }
            ws::v2::ServerMessage::SubscriptionError(e) => ParsedMessage::SubscriptionError {
                query_set_id: e.query_set_id,
                error: e.error.to_string(),
            },
            ws::v2::ServerMessage::ProcedureResult(procedure_result) => {
                // The return value stays serialized here; it is decoded by the procedure's callback.
                let result = match procedure_result.status {
                    ws::v2::ProcedureStatus::Returned(val) => Ok(val),
                    ws::v2::ProcedureStatus::InternalError(msg) => Err(InternalError::new(msg)),
                };
                ParsedMessage::ProcedureResult {
                    request_id: procedure_result.request_id,
                    result,
                }
            }
        };

        debug_log(&extra_logging, |file| {
            writeln!(file, "`parse_loop`: Parsed as: {parsed:?}")
        });
        send.unbounded_send(parsed)
            .expect("Failed to send ParsedMessage to main thread");
    }
}
/// A locally requested change to the connection, queued on the pending-mutations
/// channel and applied later by `apply_mutation`.
pub(crate) enum PendingMutation<M: SpacetimeModule> {
/// End the subscription identified by `query_set_id`.
Unsubscribe {
query_set_id: QuerySetId,
},
/// Register `handle` under `query_set_id` and start the subscription.
Subscribe {
query_set_id: QuerySetId,
handle: SubscriptionHandleImpl<M>,
},
/// Register an on-insert row callback for `table`.
AddInsertCallback {
table: &'static str,
callback_id: CallbackId,
callback: RowCallback<M>,
},
/// Remove a previously registered on-insert callback.
RemoveInsertCallback {
table: &'static str,
callback_id: CallbackId,
},
/// Register an on-delete row callback for `table`.
AddDeleteCallback {
table: &'static str,
callback_id: CallbackId,
callback: RowCallback<M>,
},
/// Remove a previously registered on-delete callback.
RemoveDeleteCallback {
table: &'static str,
callback_id: CallbackId,
},
/// Register an on-update row callback for `table`.
AddUpdateCallback {
table: &'static str,
callback_id: CallbackId,
callback: UpdateCallback<M>,
},
/// Remove a previously registered on-update callback.
RemoveUpdateCallback {
table: &'static str,
callback_id: CallbackId,
},
/// Drop the outgoing WebSocket sender, marking the connection inactive.
Disconnect,
/// Call `reducer` and run `callback` when its result arrives.
InvokeReducerWithCallback {
reducer: M::Reducer,
callback: ReducerCallback<M>,
},
/// Call the named procedure with already-serialized `args`
/// and run `callback` when its result arrives.
InvokeProcedureWithCallback {
procedure: &'static str,
args: Vec<u8>,
callback: ProcedureCallback<M>,
},
}
impl<M: SpacetimeModule> std::fmt::Debug for PendingMutation<M> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
PendingMutation::Unsubscribe { query_set_id } => f
.debug_struct("PendingMutation::Unsubscribe")
.field("query_set_id", query_set_id)
.finish(),
PendingMutation::Subscribe { query_set_id, .. } => f
.debug_struct("PendingMutation::Subscribe")
.field("query_set_id", query_set_id)
.finish_non_exhaustive(),
PendingMutation::AddInsertCallback { table, callback_id, .. } => f
.debug_struct("PendingMutation::AddInsertCallback")
.field("table", table)
.field("callback_id", callback_id)
.finish_non_exhaustive(),
PendingMutation::RemoveInsertCallback { table, callback_id } => f
.debug_struct("PendingMutation::RemoveInsertCallback")
.field("table", table)
.field("callback_id", callback_id)
.finish(),
PendingMutation::AddDeleteCallback { table, callback_id, .. } => f
.debug_struct("PendingMutation::AddDeleteCallback")
.field("table", table)
.field("callback_id", callback_id)
.finish_non_exhaustive(),
PendingMutation::RemoveDeleteCallback { table, callback_id } => f
.debug_struct("PendingMutation::RemoveDeleteCallback")
.field("table", table)
.field("callback_id", callback_id)
.finish(),
PendingMutation::AddUpdateCallback { table, callback_id, .. } => f
.debug_struct("PendingMutation::AddUpdateCallback")
.field("table", table)
.field("callback_id", callback_id)
.finish_non_exhaustive(),
PendingMutation::RemoveUpdateCallback { table, callback_id } => f
.debug_struct("PendingMutation::RemoveUpdateCallback")
.field("table", table)
.field("callback_id", callback_id)
.finish(),
PendingMutation::Disconnect => write!(f, "PendingMutation::Disconnect"),
PendingMutation::InvokeReducerWithCallback { reducer, .. } => f
.debug_struct("PendingMutation::InvokeReducerWithCallback")
.field("reducer", reducer)
.finish_non_exhaustive(),
PendingMutation::InvokeProcedureWithCallback { procedure, args, .. } => f
.debug_struct("PendingMutation::InvokeProcedureWithCallback")
.field("procedure", procedure)
.field("args", args)
.finish_non_exhaustive(),
}
}
}
/// The next thing for the connection to act on, as produced by `get_message`.
enum Message<M: SpacetimeModule> {
/// A parsed WebSocket message, or `None` if the incoming channel yielded nothing more.
Ws(Option<ParsedMessage<M>>),
/// A locally queued mutation.
Local(PendingMutation<M>),
}
fn error_is_normal_disconnect(e: &crate::Error) -> bool {
matches!(e, crate::Error::Disconnected)
}
/// Process-wide counter backing `next_request_id`; the first id handed out is 1.
static NEXT_REQUEST_ID: AtomicU32 = AtomicU32::new(1);

/// Returns a fresh request id, incrementing the process-wide counter.
pub(crate) fn next_request_id() -> u32 {
    use std::sync::atomic::Ordering::Relaxed;
    NEXT_REQUEST_ID.fetch_add(1, Relaxed)
}
/// Process-wide counter backing `next_query_set_id`; the first id handed out is 1.
static NEXT_QUERY_SET_ID: AtomicU32 = AtomicU32::new(1);

/// Returns a fresh `QuerySetId`, incrementing the process-wide counter.
pub(crate) fn next_query_set_id() -> QuerySetId {
    use std::sync::atomic::Ordering::Relaxed;
    let id = NEXT_QUERY_SET_ID.fetch_add(1, Relaxed);
    QuerySetId { id }
}