use crate::{
channel_bridge::ChannelBridgePlugin,
connection::{ConnectionDriver, ReconnectPlugin, StdbConnectionPlugin, StdbReconnectOptions},
message::RowEvent,
set::StdbSet,
subscription::{SubscriptionsInitializer, SubscriptionsPlugin},
table::{
EventTableBinder, StdbTablePlugin, TableBindCallback, TableBinder,
TableRegistrationCallback, TableWithoutPkBinder, ViewBinder, register_event_table,
register_table, register_table_without_pk, register_view,
},
};
use bevy_app::{App, Plugin, PreStartup, PreUpdate};
use bevy_ecs::prelude::IntoScheduleConfigs;
use spacetimedb_sdk::{
__codegen::{DbConnection, InModule, SpacetimeModule, SubscriptionBuilder},
Compression, DbContext, SubscriptionHandle,
};
use std::{hash::Hash, sync::Arc};
/// Bevy plugin that connects an app to a SpacetimeDB module `M` through the
/// generated connection type `C`.
///
/// Configure it with the `with_*` / `add_*` builder methods. The database name,
/// the URI and a connection driver must be set before the plugin is built.
pub struct StdbPlugin<C, M>
where
    C: DbConnection<Module = M> + DbContext + Send + Sync,
    M: SpacetimeModule<DbConnection = C>,
{
    /// Database to connect to (required, set by `with_database_name`).
    database_name: Option<String>,
    /// Server URI (required, set by `with_uri`).
    uri: Option<String>,
    /// Optional token forwarded to the connection plugin.
    token: Option<String>,
    /// Optional compression; `Compression::default()` is used when unset.
    compression: Option<Compression>,
    /// Flag forwarded as `eager_connection` to the connection plugin.
    eager_connection: bool,
    /// How the connection is driven (required; frame tick or background).
    driver: Option<ConnectionDriver<C>>,
    /// When present, a `ReconnectPlugin` is installed with these options.
    reconnect_options: Option<StdbReconnectOptions>,
    /// When present, invoked at build time to install subscription support.
    subscriptions_initializer: Option<Arc<SubscriptionsInitializer>>,
    /// Per-table registration callbacks handed to `StdbTablePlugin`.
    table_registrations: Vec<Arc<TableRegistrationCallback>>,
    /// Per-table binding callbacks handed to `StdbTablePlugin`.
    table_bindings: Vec<Arc<TableBindCallback<C>>>,
}
impl<C: DbConnection<Module = M> + DbContext + Send + Sync, M: SpacetimeModule<DbConnection = C>>
Default for StdbPlugin<C, M>
{
fn default() -> Self {
Self {
database_name: None,
uri: None,
token: None,
compression: None,
eager_connection: false,
driver: None,
reconnect_options: None,
subscriptions_initializer: None,
table_registrations: Vec::new(),
table_bindings: Vec::new(),
}
}
}
impl<C: DbConnection<Module = M> + DbContext + Send + Sync, M: SpacetimeModule<DbConnection = C>>
StdbPlugin<C, M>
{
pub fn with_eager_connection(mut self) -> Self {
self.eager_connection = true;
self
}
pub fn with_frame_driver(mut self, frame_tick: fn(&C) -> spacetimedb_sdk::Result<()>) -> Self {
assert!(
self.driver.is_none(),
"only one connection driver may be configured"
);
self.driver = Some(ConnectionDriver::FrameTick(frame_tick));
self
}
pub fn with_background_driver<R>(mut self, background_driver: fn(&C) -> R) -> Self
where
R: 'static,
{
assert!(
self.driver.is_none(),
"only one connection driver may be configured"
);
self.driver = Some(ConnectionDriver::Background(Arc::new(move |conn: &C| {
let _ = background_driver(conn);
})));
self
}
pub fn with_database_name(mut self, name: impl Into<String>) -> Self {
assert!(
self.database_name.is_none(),
"`with_database_name()` may only be called once"
);
self.database_name = Some(name.into());
self
}
pub fn with_uri(mut self, uri: impl Into<String>) -> Self {
assert!(self.uri.is_none(), "`with_uri()` may only be called once");
self.uri = Some(uri.into());
self
}
pub fn with_token(mut self, token: impl Into<String>) -> Self {
assert!(
self.token.is_none(),
"`with_token()` may only be called once"
);
self.token = Some(token.into());
self
}
pub fn with_compression(mut self, compression: Compression) -> Self {
assert!(
self.compression.is_none(),
"`with_compression()` may only be called once"
);
self.compression = Some(compression);
self
}
pub fn add_table<TRow>(
mut self,
bind: impl for<'db> Fn(TableBinder<'_, TRow>, &'db C::DbView) + Send + Sync + 'static,
) -> Self
where
TRow: Send + Sync + Clone + InModule + 'static,
RowEvent<TRow>: Send + Sync,
{
self.table_registrations
.push(Arc::new(register_table::<TRow>));
self.table_bindings.push(Arc::new(move |world, db| {
let reg = TableBinder::<TRow>::new(world);
bind(reg, db);
}));
self
}
pub fn add_table_without_pk<TRow>(
mut self,
bind: impl for<'db> Fn(TableWithoutPkBinder<'_, TRow>, &'db C::DbView) + Send + Sync + 'static,
) -> Self
where
TRow: Send + Sync + Clone + InModule + 'static,
RowEvent<TRow>: Send + Sync,
{
self.table_registrations
.push(Arc::new(register_table_without_pk::<TRow>));
self.table_bindings.push(Arc::new(move |world, db| {
let reg = TableWithoutPkBinder::<TRow>::new(world);
bind(reg, db);
}));
self
}
pub fn add_view<TRow>(
mut self,
bind: impl for<'db> Fn(ViewBinder<'_, TRow>, &'db C::DbView) + Send + Sync + 'static,
) -> Self
where
TRow: Send + Sync + Clone + InModule + 'static,
RowEvent<TRow>: Send + Sync,
{
self.table_registrations
.push(Arc::new(register_view::<TRow>));
self.table_bindings.push(Arc::new(move |world, db| {
let reg = ViewBinder::<TRow>::new(world);
bind(reg, db);
}));
self
}
pub fn add_event_table<TRow>(
mut self,
bind: impl for<'db> Fn(EventTableBinder<'_, TRow>, &'db C::DbView) + Send + Sync + 'static,
) -> Self
where
TRow: Send + Sync + Clone + InModule + 'static,
RowEvent<TRow>: Send + Sync,
{
self.table_registrations
.push(Arc::new(register_event_table::<TRow>));
self.table_bindings.push(Arc::new(move |world, db| {
let reg = EventTableBinder::<TRow>::new(world);
bind(reg, db);
}));
self
}
pub fn with_subscriptions<K>(mut self) -> Self
where
K: Eq + Hash + Clone + Send + Sync + 'static,
M::SubscriptionHandle: SubscriptionHandle + Send + Sync + 'static,
C: DbConnection<Module = M>
+ DbContext<SubscriptionBuilder = SubscriptionBuilder<M>>
+ Send
+ Sync
+ 'static,
{
assert!(
self.subscriptions_initializer.is_none(),
"`with_subscriptions()` may only be called once"
);
self.subscriptions_initializer = Some(Arc::new(|app: &mut App| {
app.add_plugins(SubscriptionsPlugin::<K, C, M>::default());
}));
self
}
pub fn with_reconnect(mut self, reconnect_config: StdbReconnectOptions) -> Self {
assert!(
self.reconnect_options.is_none(),
"`with_reconnect()` may only be called once"
);
self.reconnect_options = Some(reconnect_config);
self
}
}
impl<
    C: DbConnection<Module = M> + DbContext + Send + Sync + 'static,
    M: SpacetimeModule<DbConnection = C> + 'static,
> Plugin for StdbPlugin<C, M>
{
    /// Installs the SpacetimeDB plugins and system-set ordering into `app`.
    ///
    /// # Panics
    ///
    /// Panics if no database name, URI or connection driver was configured.
    /// These are checked, in that order, before anything is added to `app`.
    fn build(&self, app: &mut App) {
        // Validate required configuration first so a misconfigured plugin fails
        // before it has partially registered plugins and sets on `app`.
        let database_name = self
            .database_name
            .clone()
            .expect("No database name set. Use with_database_name()");
        let uri = self.uri.clone().expect("No uri set. Use with_uri()");
        let driver = self.driver.clone().expect(
            "No connection driver set. Use with_background_driver() or with_frame_driver()",
        );

        app.add_plugins(ChannelBridgePlugin);
        app.configure_sets(PreStartup, StdbSet::Connection);
        // Within PreUpdate the SpacetimeDB sets run strictly in this order.
        app.configure_sets(
            PreUpdate,
            (
                StdbSet::Flush,
                StdbSet::StateSync,
                StdbSet::Connection,
                StdbSet::Subscriptions,
            )
                .chain(),
        );
        if let Some(reconnect_options) = self.reconnect_options.clone() {
            app.add_plugins(ReconnectPlugin::<C, M>::new(reconnect_options));
        }
        if let Some(init) = &self.subscriptions_initializer {
            init(app);
        }
        app.add_plugins(StdbConnectionPlugin::<C, M> {
            database_name,
            uri,
            token: self.token.clone(),
            eager_connection: self.eager_connection,
            driver: Some(driver),
            compression: self.compression.unwrap_or_default(),
        });
        app.add_plugins(StdbTablePlugin::<C, M>::new(
            self.table_bindings.clone(),
            self.table_registrations.clone(),
        ));
    }
}