use crate::{
alias::{
ReadStdbConnectedMessage, ReadStdbConnectionErrorMessage, ReadStdbDisconnectedMessage,
},
channel_bridge::{channel_sender, register_channel},
message::{
ConnectionBuildFinishedMessage, RequestStdbConnectionMessage, StdbConnectedMessage,
StdbConnectionErrorMessage, StdbDisconnectedMessage,
},
set::StdbSet,
table::TableBindCallback,
};
use bevy_app::{App, Plugin, PreUpdate};
use bevy_ecs::prelude::{
Commands, IntoScheduleConfigs, Messages, Res, ResMut, Resource, World, not,
};
use bevy_state::prelude::{AppExtStates, NextState, OnEnter, States, in_state};
use crossbeam_channel::Sender;
use spacetimedb_sdk::{
__codegen::{DbConnection, SpacetimeModule},
Compression, ConnectionId, DbConnectionBuilder, DbContext, Identity, Result,
};
use std::sync::Arc;
/// Lifecycle of the SpacetimeDB connection, exposed as a Bevy state.
///
/// Systems can gate on these values with `in_state(...)` or react to
/// transitions with `OnEnter(...)`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, States)]
pub enum StdbConnectionState {
    /// Initial state: no connection attempt has been made yet.
    #[default]
    Uninitialized,
    /// A connection request was accepted and the connection is being built or
    /// is waiting for the server to confirm it.
    Connecting,
    /// The server confirmed the connection (the SDK `on_connect` callback fired).
    Connected,
    /// The connection was lost, failed to build, or reported a connection error.
    Disconnected,
    /// Never entered by this module; presumably set by a reconnect/retry layer
    /// once its attempts are used up. TODO confirm.
    Exhausted,
}
/// Strategy used to process incoming traffic on an established connection.
pub(crate) enum ConnectionDriver<C>
where
    C: DbContext + Send + Sync + 'static,
{
    /// Called repeatedly from a `PreUpdate` system registered by the plugin.
    /// The returned error is ignored.
    FrameTick(fn(&C) -> Result<()>),
    /// Called once with every newly built connection. Presumably starts its own
    /// processing loop (e.g. a background thread). TODO confirm.
    Background(Arc<dyn Fn(&C) + Send + Sync>),
}
impl<C> Clone for ConnectionDriver<C>
where
C: DbContext + Send + Sync + 'static,
{
fn clone(&self) -> Self {
match self {
Self::FrameTick(frame_tick) => Self::FrameTick(*frame_tick),
Self::Background(background_driver) => Self::Background(background_driver.clone()),
}
}
}
/// Resource holding everything needed to (re)build a connection.
///
/// The `token`, `uri` and `module_name` fields may be overwritten by later
/// connection requests.
#[derive(Resource)]
pub(crate) struct StdbConnectionConfig<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Database name handed to `DbConnectionBuilder::with_database_name`.
    module_name: String,
    /// Server URI handed to `DbConnectionBuilder::with_uri`.
    uri: String,
    /// Optional auth token handed to `DbConnectionBuilder::with_token`.
    token: Option<String>,
    /// How a built connection is driven; `None` means no driver is started here.
    driver: Option<ConnectionDriver<C>>,
    /// Compression setting handed to `DbConnectionBuilder::with_compression`.
    compression: Compression,
    /// When `false`, the plugin queues a connection request while building the app.
    delayed_connection: bool,
    /// Callbacks run against the table view each time the state enters `Connected`.
    table_bindings: Vec<Arc<TableBindCallback<C>>>,
    /// Channel fed by the SDK `on_connect` callback.
    connected_tx: Sender<StdbConnectedMessage>,
    /// Channel fed by the SDK `on_disconnect` callback.
    disconnected_tx: Sender<StdbDisconnectedMessage>,
    /// Channel fed by the SDK `on_connect_error` callback.
    error_tx: Sender<StdbConnectionErrorMessage>,
}
// Hand-written so that neither `C` nor `M` has to be `Clone`; every field is
// individually cloneable (channel senders and `Arc`s are cheap to clone).
impl<C, M> Clone for StdbConnectionConfig<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field without cloning it here
        // becomes a compile error instead of a silent omission.
        let Self {
            module_name,
            uri,
            token,
            driver,
            compression,
            delayed_connection,
            table_bindings,
            connected_tx,
            disconnected_tx,
            error_tx,
        } = self;
        Self {
            module_name: module_name.clone(),
            uri: uri.clone(),
            token: token.clone(),
            driver: driver.clone(),
            compression: *compression,
            delayed_connection: *delayed_connection,
            table_bindings: table_bindings.clone(),
            connected_tx: connected_tx.clone(),
            disconnected_tx: disconnected_tx.clone(),
            error_tx: error_tx.clone(),
        }
    }
}
impl<C, M> StdbConnectionConfig<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Creates a builder filled in from this config.
    ///
    /// Its connect, disconnect and connect-error callbacks forward into the
    /// plugin's channels. Send failures are deliberately ignored (best effort).
    fn connection_builder(&self) -> DbConnectionBuilder<M> {
        let (on_connected, on_disconnected, on_error) = (
            self.connected_tx.clone(),
            self.disconnected_tx.clone(),
            self.error_tx.clone(),
        );
        let base = DbConnectionBuilder::<M>::new()
            .with_database_name(self.module_name.clone())
            .with_uri(self.uri.clone())
            .with_token(self.token.clone())
            .with_compression(self.compression);
        base.on_connect(move |_ctx, identity, access_token| {
            let message = StdbConnectedMessage {
                identity,
                access_token: access_token.to_string(),
            };
            let _ = on_connected.send(message);
        })
        .on_disconnect(move |_ctx, err| {
            let _ = on_disconnected.send(StdbDisconnectedMessage { err });
        })
        .on_connect_error(move |_ctx, err| {
            let _ = on_error.send(StdbConnectionErrorMessage { err });
        })
    }

    /// Builds the connection synchronously and wraps it in an `Arc`.
    ///
    /// # Errors
    /// Returns whatever error `DbConnectionBuilder::build` reports.
    #[cfg(not(feature = "browser"))]
    pub(crate) fn build_connection(&self) -> Result<Arc<C>> {
        let connection = self.connection_builder().build()?;
        Ok(Arc::new(connection))
    }

    /// Builds the connection asynchronously (browser targets) and wraps it in an `Arc`.
    ///
    /// # Errors
    /// Returns whatever error `DbConnectionBuilder::build` reports.
    #[cfg(feature = "browser")]
    pub(crate) async fn build_connection(&self) -> Result<Arc<C>> {
        let connection = self.connection_builder().build().await?;
        Ok(Arc::new(connection))
    }
}
/// Bevy resource wrapping a live SpacetimeDB connection.
///
/// The plugin inserts it once a connection has been built. It removes it again
/// when a disconnect or connection error is reported.
#[derive(Resource)]
pub struct StdbConnection<T>
where
    T: DbContext + 'static,
{
    conn: Arc<T>,
}
impl<T: DbContext> StdbConnection<T> {
fn new(conn: Arc<T>) -> Self {
Self { conn }
}
}
// Thin forwarding layer: every method delegates to the same-named
// `DbContext` method on the wrapped connection.
impl<T> StdbConnection<T>
where
    T: DbContext,
{
    /// Delegates to [`DbContext::db`].
    pub fn db(&self) -> &T::DbView {
        T::db(&self.conn)
    }

    /// Delegates to [`DbContext::reducers`].
    pub fn reducers(&self) -> &T::Reducers {
        T::reducers(&self.conn)
    }

    /// Delegates to [`DbContext::procedures`].
    pub fn procedures(&self) -> &T::Procedures {
        T::procedures(&self.conn)
    }

    /// Delegates to [`DbContext::is_active`].
    pub fn is_active(&self) -> bool {
        T::is_active(&self.conn)
    }

    /// Delegates to [`DbContext::disconnect`].
    ///
    /// # Errors
    /// Propagates the error returned by the underlying connection.
    pub fn disconnect(&self) -> Result<()> {
        T::disconnect(&self.conn)
    }

    /// Delegates to [`DbContext::subscription_builder`].
    pub fn subscription_builder(&self) -> T::SubscriptionBuilder {
        T::subscription_builder(&self.conn)
    }

    /// Delegates to [`DbContext::identity`].
    pub fn identity(&self) -> Identity {
        T::identity(&self.conn)
    }

    /// Delegates to [`DbContext::try_identity`].
    pub fn try_identity(&self) -> Option<Identity> {
        T::try_identity(&self.conn)
    }

    /// Delegates to [`DbContext::connection_id`].
    pub fn connection_id(&self) -> ConnectionId {
        T::connection_id(&self.conn)
    }

    /// Delegates to [`DbContext::try_connection_id`].
    pub fn try_connection_id(&self) -> Option<ConnectionId> {
        T::try_connection_id(&self.conn)
    }
}
/// Plugin that owns the connection lifecycle for one SpacetimeDB module.
///
/// Its fields seed the `StdbConnectionConfig` resource inserted while the app
/// is being built.
pub(crate) struct StdbConnectionPlugin<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Database name to connect to.
    pub module_name: String,
    /// Server URI.
    pub uri: String,
    /// Optional auth token.
    pub token: Option<String>,
    /// How a built connection is driven (per-frame tick or background driver).
    pub driver: Option<ConnectionDriver<C>>,
    /// Compression setting for the connection.
    pub compression: Compression,
    /// When `true`, no connection request is queued automatically at startup.
    pub delayed_connection: bool,
    /// Callbacks run against the table view each time the state enters `Connected`.
    pub table_bindings: Vec<Arc<TableBindCallback<C>>>,
}
impl<C, M> Plugin for StdbConnectionPlugin<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
    M: SpacetimeModule<DbConnection = C> + 'static,
{
    /// Registers the connection state, messages, channels, config resource and
    /// connection systems.
    ///
    /// Unless `delayed_connection` is set, it also queues an initial connection
    /// request.
    fn build(&self, app: &mut App) {
        app.init_state::<StdbConnectionState>();
        app.add_message::<RequestStdbConnectionMessage>();
        // The SDK callbacks installed by `StdbConnectionConfig::connection_builder`
        // send into these channels.
        register_channel::<StdbConnectedMessage>(app);
        register_channel::<StdbDisconnectedMessage>(app);
        register_channel::<StdbConnectionErrorMessage>(app);
        // In the browser the connection is built inside a spawned future, so its
        // result comes back through a channel. Natively it is built inline and
        // written as a plain message.
        #[cfg(feature = "browser")]
        register_channel::<ConnectionBuildFinishedMessage<C>>(app);
        #[cfg(not(feature = "browser"))]
        app.add_message::<ConnectionBuildFinishedMessage<C>>();

        let world = app.world();
        let config = StdbConnectionConfig::<C, M> {
            module_name: self.module_name.clone(),
            uri: self.uri.clone(),
            token: self.token.clone(),
            driver: self.driver.clone(),
            compression: self.compression,
            delayed_connection: self.delayed_connection,
            table_bindings: self.table_bindings.clone(),
            connected_tx: channel_sender::<StdbConnectedMessage>(world),
            disconnected_tx: channel_sender::<StdbDisconnectedMessage>(world),
            error_tx: channel_sender::<StdbConnectionErrorMessage>(world),
        };
        app.insert_resource(config);

        app.add_systems(
            PreUpdate,
            sync_connection_state::<C>.in_set(StdbSet::StateSync),
        );
        app.add_systems(
            PreUpdate,
            handle_connection_request::<C, M>
                .in_set(StdbSet::Connection)
                .run_if(not(in_state(StdbConnectionState::Connected)))
                .run_if(not(in_state(StdbConnectionState::Connecting))),
        );
        app.add_systems(
            PreUpdate,
            finalize_pending_connection::<C, M>.in_set(StdbSet::Connection),
        );
        app.add_systems(
            OnEnter(StdbConnectionState::Connected),
            on_connected_bind::<C, M>,
        );

        if matches!(self.driver, Some(ConnectionDriver::FrameTick(_))) {
            app.add_systems(
                PreUpdate,
                (|conn: Option<Res<StdbConnection<C>>>,
                  config: Res<StdbConnectionConfig<C, M>>| {
                    // Tick whenever a connection resource exists, not only in
                    // `Connected`. The SDK invokes connection callbacks such as
                    // `on_connect` while a driver processes incoming messages, and
                    // `on_connect` is what moves the state to `Connected`. Gating
                    // the tick on `Connected` would leave a frame-ticked connection
                    // stuck in `Connecting` forever.
                    let Some(conn) = conn else {
                        return;
                    };
                    let Some(ConnectionDriver::FrameTick(frame_tick)) = config.driver.as_ref()
                    else {
                        panic!("frame tick system should only be added when the frame tick driver is configured");
                    };
                    // Best effort: tick errors are ignored; disconnects are
                    // reported through the `on_disconnect` callback channel.
                    let _ = frame_tick(conn.conn.as_ref());
                })
                .in_set(StdbSet::Connection),
            );
        }

        if !self.delayed_connection {
            app.world_mut()
                .write_message_default::<RequestStdbConnectionMessage>();
        }
    }
}
/// Exclusive system that turns the newest pending `RequestStdbConnectionMessage`
/// into a connection attempt.
///
/// If a `StdbConnection<C>` already exists, all pending requests are discarded.
/// Otherwise:
/// - the newest request's `token`/`uri`/`module_name` overrides are written into
///   `StdbConnectionConfig`;
/// - the state moves to `Connecting`;
/// - the connection is built inline natively, or in a spawned future with the
///   `browser` feature.
///
/// The outcome is delivered as a `ConnectionBuildFinishedMessage<C>`.
fn handle_connection_request<C, M>(world: &mut World)
where
    C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
    M: SpacetimeModule<DbConnection = C> + 'static,
{
    if world.contains_resource::<StdbConnection<C>>() {
        // Already connected: drop the requests so they don't fire later.
        world
            .resource_mut::<Messages<RequestStdbConnectionMessage>>()
            .clear();
        return;
    }

    // Requests are coalesced: only the most recent one is honoured.
    let Some(request) = world
        .resource_mut::<Messages<RequestStdbConnectionMessage>>()
        .drain()
        .last()
    else {
        return;
    };

    let snapshot = {
        let mut config = world.resource_mut::<StdbConnectionConfig<C, M>>();
        // Only write fields the request actually overrides, so the config is
        // mutably touched (and change-detected) only when something changes.
        if request.token.is_some() {
            config.token = request.token;
        }
        if let Some(uri) = request.uri {
            config.uri = uri;
        }
        if let Some(module_name) = request.module_name {
            config.module_name = module_name;
        }
        config.clone()
    };

    world
        .resource_mut::<NextState<StdbConnectionState>>()
        .set(StdbConnectionState::Connecting);

    #[cfg(not(feature = "browser"))]
    {
        let result = snapshot.build_connection();
        world.write_message(ConnectionBuildFinishedMessage { result });
    }

    #[cfg(feature = "browser")]
    {
        let finished_tx = channel_sender::<ConnectionBuildFinishedMessage<C>>(world);
        wasm_bindgen_futures::spawn_local(async move {
            let result = snapshot.build_connection().await;
            let _ = finished_tx.send(ConnectionBuildFinishedMessage { result });
        });
    }
}
/// Exclusive system that consumes finished connection builds.
///
/// When a build succeeded, it starts the background driver (if one is
/// configured) and then inserts the connection as a `StdbConnection<C>`
/// resource. When a build failed, it moves the state to `Disconnected`.
fn finalize_pending_connection<C, M>(world: &mut World)
where
    C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
    M: SpacetimeModule<DbConnection = C> + 'static,
{
    let finished: Vec<ConnectionBuildFinishedMessage<C>> = world
        .resource_mut::<Messages<ConnectionBuildFinishedMessage<C>>>()
        .drain()
        .collect();

    for message in finished {
        // NOTE(review): the build error itself is discarded here; only the state
        // change is visible. Confirm the SDK also reports it via `on_connect_error`.
        let Ok(conn) = message.result else {
            world
                .resource_mut::<NextState<StdbConnectionState>>()
                .set(StdbConnectionState::Disconnected);
            continue;
        };

        let configured_driver = world
            .get_resource::<StdbConnectionConfig<C, M>>()
            .expect("StdbConnectionConfig should exist when activating a connection")
            .driver
            .clone();
        if let Some(ConnectionDriver::Background(start_driver)) = configured_driver {
            start_driver(conn.as_ref());
        }

        world.insert_resource(StdbConnection::new(conn));
    }
}
/// Maps the channel-bridged SDK callback messages onto `StdbConnectionState`.
///
/// A connect moves the state to `Connected`. A disconnect or a connection error
/// removes the `StdbConnection<C>` resource and moves the state to
/// `Disconnected`. If both kinds arrive in the same frame, `Disconnected` wins
/// because it is set last.
fn sync_connection_state<C>(
    mut connected_msgs: ReadStdbConnectedMessage,
    mut disconnected_msgs: ReadStdbDisconnectedMessage,
    mut connection_error_msgs: ReadStdbConnectionErrorMessage,
    mut next_state: ResMut<NextState<StdbConnectionState>>,
    mut commands: Commands,
) where
    C: DbContext + Send + Sync + 'static,
{
    // Fully drain every reader each run (counting consumes all messages), so
    // nothing lingers into the next frame.
    let saw_connect = connected_msgs.read().count() > 0;
    let saw_disconnect = disconnected_msgs.read().count() > 0;
    let saw_error = connection_error_msgs.read().count() > 0;

    if saw_connect {
        next_state.set(StdbConnectionState::Connected);
    }
    if saw_disconnect || saw_error {
        commands.remove_resource::<StdbConnection<C>>();
        next_state.set(StdbConnectionState::Disconnected);
    }
}
/// Runs on entering `Connected`: calls every configured table-binding callback
/// with the world and the live connection's table view.
///
/// # Panics
/// Panics if either `StdbConnectionConfig<C, M>` or `StdbConnection<C>` is missing.
fn on_connected_bind<C, M>(world: &mut World)
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    // Everything below only reads the world, so work through a shared reborrow.
    let world: &World = world;
    let config = world
        .get_resource::<StdbConnectionConfig<C, M>>()
        .expect("StdbConnectionConfig should exist before Connected bind phase");
    let conn = world
        .get_resource::<StdbConnection<C>>()
        .expect("StdbConnection should exist before Connected bind phase");
    let tables = conn.db();
    config
        .table_bindings
        .iter()
        .for_each(|bind| bind(world, tables));
}