mod reconnect;
use crate::{
alias::{ReadStdbConnectedMessage, ReadStdbDisconnectedMessage},
channel_bridge::{channel_sender, register_channel},
message::{StdbConnectErrorMessage, StdbConnectedMessage, StdbDisconnectedMessage},
set::StdbSet,
};
use bevy_app::{App, Plugin, PreUpdate};
use bevy_ecs::prelude::{Commands, IntoScheduleConfigs, Res, Resource, World, resource_exists};
use bevy_tasks::{IoTaskPool, Task, block_on, poll_once};
use crossbeam_channel::Sender;
pub(crate) use reconnect::ReconnectPlugin;
pub use reconnect::StdbReconnectOptions;
use spacetimedb_sdk::{
__codegen::{DbConnection, SpacetimeModule},
Compression, ConnectionId, DbConnectionBuilder, DbContext, Identity, Result,
};
use std::sync::Arc;
/// Resource wrapping an in-flight connection attempt.
///
/// While this resource exists, `poll_pending_connection` polls the task once per
/// `PreUpdate` run. When the task resolves, the resource is not re-inserted. The outcome
/// is published as a `StdbConnection<C>` resource on success, or as a
/// `StdbConnectErrorMessage` on failure.
#[derive(Resource)]
pub(crate) struct PendingConnection<C>(pub(crate) Task<Result<Arc<C>>>)
where
    C: DbContext + Send + Sync + 'static;
/// Strategy used to drive an established connection.
pub(crate) enum ConnectionDriver<C>
where
    C: DbContext + Send + Sync + 'static,
{
    /// Called with the live connection by a `PreUpdate` system on every run while a
    /// `StdbConnection<C>` resource exists. The returned `Result` is discarded.
    FrameTick(fn(&C) -> Result<()>),
    /// Called once with each newly established connection, before it is installed as the
    /// `StdbConnection<C>` resource. Presumably this starts processing outside the Bevy
    /// schedule; confirm against whoever constructs this variant.
    Background(Arc<dyn Fn(&C) + Send + Sync>),
}
impl<C> Clone for ConnectionDriver<C>
where
C: DbContext + Send + Sync + 'static,
{
fn clone(&self) -> Self {
match self {
Self::FrameTick(frame_tick) => Self::FrameTick(*frame_tick),
Self::Background(background_driver) => Self::Background(background_driver.clone()),
}
}
}
/// Resource holding everything needed to build a SpacetimeDB connection.
///
/// `StdbConnectionPlugin::build` inserts it. It is cloned into tasks spawned on the
/// `IoTaskPool` when a connection is built eagerly.
#[derive(Resource)]
pub(crate) struct StdbConnectionConfig<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Database name handed to the SDK builder.
    pub(crate) database_name: String,
    /// Server URI handed to the SDK builder.
    pub(crate) uri: String,
    /// Optional auth token handed to the SDK builder.
    pub(crate) token: Option<String>,
    /// How an established connection is driven; `None` means no driver is installed.
    driver: Option<ConnectionDriver<C>>,
    /// Compression setting handed to the SDK builder.
    compression: Compression,
    /// Receives a message from the SDK's `on_connect` callback.
    connected_tx: Sender<StdbConnectedMessage>,
    /// Receives a message from the SDK's `on_disconnect` callback.
    disconnected_tx: Sender<StdbDisconnectedMessage>,
    /// Receives a message from the SDK's `on_connect_error` callback.
    connect_error_tx: Sender<StdbConnectErrorMessage>,
}
impl<C, M> Clone for StdbConnectionConfig<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Clones each field individually.
    ///
    /// This is hand-written because `#[derive(Clone)]` would add `C: Clone` and `M: Clone`
    /// bounds, which are not otherwise required of the connection and module types.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct becomes a compile error
        // here instead of a silently skipped clone.
        let Self {
            database_name,
            uri,
            token,
            driver,
            compression,
            connected_tx,
            disconnected_tx,
            connect_error_tx,
        } = self;
        Self {
            database_name: database_name.clone(),
            uri: uri.clone(),
            token: token.clone(),
            driver: driver.clone(),
            compression: *compression,
            connected_tx: connected_tx.clone(),
            disconnected_tx: disconnected_tx.clone(),
            connect_error_tx: connect_error_tx.clone(),
        }
    }
}
impl<C, M> StdbConnectionConfig<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Returns an SDK `DbConnectionBuilder` preconfigured from this config.
    ///
    /// The builder's connect, disconnect and connect-error callbacks forward into the
    /// matching channels. Send failures are ignored; a crossbeam send fails only once
    /// every receiver has been dropped.
    fn connection_builder(&self) -> DbConnectionBuilder<M> {
        let on_connected = self.connected_tx.clone();
        let on_disconnected = self.disconnected_tx.clone();
        let on_connect_error = self.connect_error_tx.clone();

        let configured = DbConnectionBuilder::<M>::new()
            .with_database_name(self.database_name.clone())
            .with_uri(self.uri.clone())
            .with_token(self.token.clone())
            .with_compression(self.compression);

        configured
            .on_connect(move |_ctx, identity, access_token| {
                let message = StdbConnectedMessage {
                    identity,
                    access_token: access_token.to_string(),
                };
                let _ = on_connected.send(message);
            })
            .on_disconnect(move |_ctx, err| {
                let _ = on_disconnected.send(StdbDisconnectedMessage { err });
            })
            .on_connect_error(move |_ctx, err| {
                let _ = on_connect_error.send(StdbConnectErrorMessage { err });
            })
    }

    /// Builds a connection from this config and wraps it in an `Arc`.
    ///
    /// Without the `browser` feature, the builder's `build()` is called directly and
    /// nothing is awaited. With the feature, `build()` yields a future that is awaited.
    pub(crate) async fn build_connection(&self) -> Result<Arc<C>> {
        let builder = self.connection_builder();
        #[cfg(not(feature = "browser"))]
        let built = builder.build();
        #[cfg(feature = "browser")]
        let built = builder.build().await;
        built.map(Arc::new)
    }
}
/// Resource exposing the live SpacetimeDB connection to Bevy systems.
///
/// `poll_pending_connection` inserts it when a connection attempt succeeds.
/// `sync_connection_resource` removes it when a connect or disconnect message arrives
/// and the connection reports itself inactive.
#[derive(Resource)]
pub struct StdbConnection<T>
where
    T: DbContext + 'static,
{
    conn: Arc<T>,
}
impl<T: DbContext> StdbConnection<T> {
fn new(conn: Arc<T>) -> Self {
Self { conn }
}
}
impl<T: DbContext> StdbConnection<T> {
pub fn db(&self) -> &T::DbView {
self.conn.db()
}
pub fn reducers(&self) -> &T::Reducers {
self.conn.reducers()
}
pub fn procedures(&self) -> &T::Procedures {
self.conn.procedures()
}
pub fn is_active(&self) -> bool {
self.conn.is_active()
}
pub fn disconnect(&self) -> Result<()> {
self.conn.disconnect()
}
pub fn subscription_builder(&self) -> T::SubscriptionBuilder {
self.conn.subscription_builder()
}
pub fn identity(&self) -> Identity {
self.conn.identity()
}
pub fn try_identity(&self) -> Option<Identity> {
self.conn.try_identity()
}
pub fn connection_id(&self) -> ConnectionId {
self.conn.connection_id()
}
pub fn try_connection_id(&self) -> Option<ConnectionId> {
self.conn.try_connection_id()
}
}
/// Internal plugin that wires the SpacetimeDB connection lifecycle into a Bevy app.
///
/// It registers the message channels, inserts the connection config, and schedules the
/// systems that install, tick and remove `StdbConnection<C>`.
pub(crate) struct StdbConnectionPlugin<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Database name handed to the SDK builder.
    pub database_name: String,
    /// Server URI handed to the SDK builder.
    pub uri: String,
    /// Optional auth token handed to the SDK builder.
    pub token: Option<String>,
    /// When `true`, a connection attempt is spawned while the plugin is being built.
    pub eager_connection: bool,
    /// How an established connection is driven; `None` installs no driver.
    pub driver: Option<ConnectionDriver<C>>,
    /// Compression setting handed to the SDK builder.
    pub compression: Compression,
}
impl<
C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
M: SpacetimeModule<DbConnection = C> + 'static,
> Plugin for StdbConnectionPlugin<C, M>
{
fn build(&self, app: &mut App) {
register_channel::<StdbConnectedMessage>(app);
register_channel::<StdbDisconnectedMessage>(app);
register_channel::<StdbConnectErrorMessage>(app);
let world = app.world();
app.insert_resource(StdbConnectionConfig::<C, M> {
database_name: self.database_name.clone(),
uri: self.uri.clone(),
token: self.token.clone(),
driver: self.driver.clone(),
compression: self.compression,
connected_tx: channel_sender::<StdbConnectedMessage>(world),
disconnected_tx: channel_sender::<StdbDisconnectedMessage>(world),
connect_error_tx: channel_sender::<StdbConnectErrorMessage>(world),
});
app.add_systems(
PreUpdate,
sync_connection_resource::<C>.in_set(StdbSet::StateSync),
);
app.add_systems(
PreUpdate,
poll_pending_connection::<C, M>
.run_if(resource_exists::<PendingConnection<C>>)
.in_set(StdbSet::Connection),
);
if matches!(self.driver, Some(ConnectionDriver::FrameTick(_))) {
app.add_systems(
PreUpdate,
(|conn: Res<StdbConnection<C>>, config: Res<StdbConnectionConfig<C, M>>| {
if let Some(ConnectionDriver::FrameTick(frame_tick)) = config.driver {
let _ = frame_tick(conn.conn.as_ref());
}
})
.in_set(StdbSet::Connection)
.run_if(resource_exists::<StdbConnection<C>>),
);
}
if self.eager_connection {
let config = app.world().resource::<StdbConnectionConfig<C, M>>().clone();
let task = IoTaskPool::get().spawn(async move { config.build_connection().await });
app.insert_resource(PendingConnection::<C>(task));
}
}
}
/// Polls the pending connection task once and, if it has finished, publishes the outcome.
///
/// * Still running: the `PendingConnection<C>` resource is put back for the next frame.
/// * `Ok(conn)`:
///   1. If the config's driver is `Background`, it is invoked with the new connection.
///   2. Any existing `StdbConnection<C>` is disconnected on a best-effort basis.
///   3. The new connection replaces it.
/// * `Err(err)`: a `StdbConnectErrorMessage` is written, and any existing connection is
///   left untouched.
///
/// # Panics
///
/// Panics if the task succeeded but no `StdbConnectionConfig<C, M>` resource exists.
/// The plugin always inserts that resource, so a panic here indicates a setup bug.
fn poll_pending_connection<
    C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
    M: SpacetimeModule<DbConnection = C> + 'static,
>(
    world: &mut World,
) {
    // The resource is taken out of the world so the task can be polled mutably. It is
    // re-inserted only if the task is still pending.
    let Some(PendingConnection(mut task)) = world.remove_resource::<PendingConnection<C>>() else {
        return;
    };
    let Some(result) = block_on(poll_once(&mut task)) else {
        world.insert_resource(PendingConnection::<C>(task));
        return;
    };
    let conn = match result {
        Ok(conn) => conn,
        Err(err) => {
            world.write_message(StdbConnectErrorMessage { err });
            return;
        }
    };

    // Clone the driver so the borrow of the config resource ends before `world` is
    // mutated below.
    let driver = world
        .get_resource::<StdbConnectionConfig<C, M>>()
        .expect("StdbConnectionConfig should exist when activating a connection")
        .driver
        .clone();
    if let Some(ConnectionDriver::Background(background_driver)) = driver {
        background_driver(conn.as_ref());
    }
    if let Some(prev_conn) = world.get_resource::<StdbConnection<C>>() {
        // Best-effort: the old connection is being replaced regardless of the outcome.
        let _ = prev_conn.disconnect();
    }
    world.insert_resource(StdbConnection::new(conn));
}
/// Removes the `StdbConnection<C>` resource once the underlying connection is no longer active.
///
/// The activity check runs only on frames where at least one connected or disconnected
/// message arrived. Both message readers are drained completely on every run.
fn sync_connection_resource<C: DbContext + Send + Sync + 'static>(
    mut connected_msgs: ReadStdbConnectedMessage,
    mut disconnected_msgs: ReadStdbDisconnectedMessage,
    conn: Option<Res<StdbConnection<C>>>,
    mut commands: Commands,
) {
    // Consume every message from both readers, evaluated separately so that neither read
    // is skipped. Taking only the first message (or short-circuiting with `||`) would
    // leave the rest queued and re-trigger this check on later frames. It would also skip
    // the disconnected reader entirely on frames that saw a connect.
    let saw_connected = connected_msgs.read().count() > 0;
    let saw_disconnected = disconnected_msgs.read().count() > 0;
    if !(saw_connected || saw_disconnected) {
        return;
    }
    if conn.is_some_and(|conn| !conn.is_active()) {
        commands.remove_resource::<StdbConnection<C>>();
    }
}