use crate::{
alias::{ReadStdbConnectErrorMessage, ReadStdbDisconnectedMessage},
channel_bridge::{channel_sender, register_channel},
connection::StdbConnection,
message::{StdbSubscriptionAppliedMessage, StdbSubscriptionErrorMessage},
set::StdbSet,
};
use bevy_app::{App, Plugin, PreUpdate};
use bevy_ecs::prelude::{IntoScheduleConfigs, Res, ResMut, Resource, resource_exists};
use crossbeam_channel::Sender;
use spacetimedb_sdk::{
__codegen::{__query_builder::Query, DbConnection, SpacetimeModule, SubscriptionBuilder},
DbContext, Result as StdbResult, SubscriptionHandle as StdbSubscriptionHandle,
};
use std::{collections::HashMap, hash::Hash, marker::PhantomData};
/// Type-erased callback that receives `&mut App`; the `Send + Sync` bounds allow
/// it to be shared across threads.
/// NOTE(review): presumably used to install a typed `SubscriptionsPlugin` without
/// naming its generic parameters at the call site — confirm against the caller.
pub(crate) type SubscriptionsInitializer = dyn Fn(&mut App) + Send + Sync;
/// Per-key subscription state tracked by `StdbSubscriptions`.
struct SubscriptionEntry<H> {
    /// Handle of the most recently issued subscription for this key. `None` until
    /// a subscription has been sent, and reset to `None` after a disconnect or
    /// connect error.
    handle: Option<H>,
    /// SQL requested for this key; when `queued` is set it has not been sent yet.
    sql: String,
    /// `true` while `sql` still has to be (re)subscribed on the next apply pass.
    queued: bool,
}
/// Bevy resource tracking SpacetimeDB subscriptions, keyed by a caller-chosen `K`.
///
/// Subscriptions are requested with `subscribe_sql` / `subscribe_query` and queued.
/// A `PreUpdate` system sends them while a connection resource exists. Applied and
/// error callbacks are forwarded as `StdbSubscriptionAppliedMessage<K>` /
/// `StdbSubscriptionErrorMessage<K>` through the channel senders stored here.
#[derive(Resource)]
pub struct StdbSubscriptions<K, M>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    M: SpacetimeModule,
    M::SubscriptionHandle: StdbSubscriptionHandle + Send + Sync + 'static,
{
    /// Subscription state for every known key.
    entries: HashMap<K, SubscriptionEntry<M::SubscriptionHandle>>,
    /// Cloned into each subscription's `on_applied` callback.
    applied_sender: Sender<StdbSubscriptionAppliedMessage<K>>,
    /// Cloned into each subscription's `on_error` callback.
    error_sender: Sender<StdbSubscriptionErrorMessage<K>>,
}
impl<K, M> StdbSubscriptions<K, M>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    M: SpacetimeModule,
    M::SubscriptionHandle: StdbSubscriptionHandle + Send + Sync + 'static,
{
    /// Subscribes `key` to a query built with the module's typed query builder.
    ///
    /// `query` receives a default `M::QueryBuilder`. The resulting query is turned
    /// into SQL and forwarded to [`Self::subscribe_sql`], which decides whether the
    /// request is a duplicate.
    pub fn subscribe_query<T, Q>(&mut self, key: K, query: impl Fn(M::QueryBuilder) -> Q)
    where
        Q: Query<T>,
    {
        let res = query(M::QueryBuilder::default());
        let sql = Query::into_sql(res);
        self.subscribe_sql(key, sql);
    }

    /// Queues a subscription of `key` to `sql`.
    ///
    /// The request is ignored when `key` already targets the same `sql` and that
    /// subscription is either still queued or backed by a handle that has not
    /// ended. Any other case queues `sql`. This includes a new key, changed SQL, or
    /// a previous subscription that ended (e.g. after an error).
    ///
    /// Queued entries are sent by the plugin's apply system. When a live handle is
    /// replaced, the old one is unsubscribed after the new subscription is issued.
    pub fn subscribe_sql(&mut self, key: K, sql: impl Into<String>) {
        let sql = sql.into();
        if let Some(entry) = self.entries.get_mut(&key) {
            // An ended handle (errored or unsubscribed) does not count as live.
            // Otherwise a retry with the same SQL after an error would be a
            // silent no-op and the key could never be resubscribed.
            let live = entry
                .handle
                .as_ref()
                .is_some_and(|handle| !handle.is_ended());
            if entry.sql == sql && (entry.queued || live) {
                return;
            }
            entry.sql = sql;
            entry.queued = true;
            return;
        }
        self.entries.insert(
            key,
            SubscriptionEntry {
                handle: None,
                sql,
                queued: true,
            },
        );
    }

    /// Removes `key`, unsubscribing its handle if it has one.
    ///
    /// Unknown keys are a no-op. The entry is removed even if unsubscribing fails.
    ///
    /// # Errors
    /// Returns the error produced by the handle's `unsubscribe()`.
    pub fn unsubscribe(&mut self, key: &K) -> StdbResult<()> {
        match self.entries.remove(key).and_then(|entry| entry.handle) {
            Some(handle) => handle.unsubscribe(),
            None => Ok(()),
        }
    }

    /// Removes every entry and attempts to unsubscribe every stored handle.
    ///
    /// All entries are removed and all handles attempted, even after a failure.
    ///
    /// # Errors
    /// Returns the first `unsubscribe()` error encountered. Map iteration order is
    /// unspecified, so "first" is arbitrary among several failures.
    pub fn unsubscribe_all(&mut self) -> StdbResult<()> {
        let mut first_err = None;
        for handle in self.entries.drain().filter_map(|(_, entry)| entry.handle) {
            if let Err(err) = handle.unsubscribe() {
                if first_err.is_none() {
                    first_err = Some(err);
                }
            }
        }
        first_err.map_or(Ok(()), Err)
    }

    /// Returns the SQL most recently requested for `key`, whether or not it has
    /// been sent yet.
    pub fn sql_for(&self, key: &K) -> Option<&str> {
        self.entries.get(key).map(|entry| entry.sql.as_str())
    }

    /// Returns `true` if `key` has a (re)subscription waiting to be sent.
    pub fn is_queued(&self, key: &K) -> bool {
        self.entries.get(key).is_some_and(|entry| entry.queued)
    }

    /// Returns `true` if `key` has a handle whose `is_active()` reports `true`.
    pub fn is_active(&self, key: &K) -> bool {
        self.entries
            .get(key)
            .and_then(|entry| entry.handle.as_ref())
            .is_some_and(|handle| handle.is_active())
    }

    /// Returns `true` if any entry is waiting to be sent; used as a run condition.
    fn has_queued(&self) -> bool {
        self.entries.values().any(|entry| entry.queued)
    }

    /// Sends every queued entry over `conn` and clears its `queued` flag.
    ///
    /// Each subscription's `on_applied` / `on_error` callbacks forward a message
    /// tagged with the entry's key through the stored channel senders. A previous
    /// handle is unsubscribed only after its replacement has been issued. Errors
    /// from that best-effort unsubscribe are deliberately ignored.
    fn apply_queued<C>(&mut self, conn: &StdbConnection<C>)
    where
        C: DbConnection<Module = M>
            + DbContext<SubscriptionBuilder = SubscriptionBuilder<M>>
            + Send
            + Sync
            + 'static,
        M: SpacetimeModule<DbConnection = C>,
    {
        for (key, entry) in self.entries.iter_mut().filter(|(_, entry)| entry.queued) {
            let applied_key = key.clone();
            let applied_sender = self.applied_sender.clone();
            let error_key = key.clone();
            let error_sender = self.error_sender.clone();
            let handle = conn
                .subscription_builder()
                .on_applied(move |_ctx| {
                    let _ =
                        applied_sender.send(StdbSubscriptionAppliedMessage { key: applied_key });
                })
                .on_error(move |_ctx, err| {
                    let _ = error_sender.send(StdbSubscriptionErrorMessage {
                        key: error_key,
                        err,
                    });
                })
                .subscribe(entry.sql.as_str());
            if let Some(old_handle) = entry.handle.replace(handle) {
                let _ = old_handle.unsubscribe();
            }
            entry.queued = false;
        }
    }
}
/// Plugin that installs a `StdbSubscriptions<K, M>` resource and the systems
/// driving it for connection type `C`. It stores no data; the type parameters are
/// carried only through a zero-sized marker.
pub(crate) struct SubscriptionsPlugin<K, C, M>
where
    M: SpacetimeModule<DbConnection = C>,
    M::SubscriptionHandle: StdbSubscriptionHandle + Send + Sync + 'static,
    C: 'static
        + Send
        + Sync
        + DbConnection<Module = M>
        + DbContext<SubscriptionBuilder = SubscriptionBuilder<M>>,
    K: 'static + Eq + Hash + Clone + Send + Sync,
{
    /// Zero-sized marker tying the plugin to `K`, `C` and `M`.
    _marker: PhantomData<(K, C, M)>,
}
impl<K, C, M> Default for SubscriptionsPlugin<K, C, M>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
C: DbConnection<Module = M>
+ DbContext<SubscriptionBuilder = SubscriptionBuilder<M>>
+ Send
+ Sync
+ 'static,
M: SpacetimeModule<DbConnection = C>,
M::SubscriptionHandle: StdbSubscriptionHandle + Send + Sync + 'static,
{
fn default() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<K, C, M> Plugin for SubscriptionsPlugin<K, C, M>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    C: DbConnection<Module = M>
        + DbContext<SubscriptionBuilder = SubscriptionBuilder<M>>
        + Send
        + Sync
        + 'static,
    M: SpacetimeModule<DbConnection = C> + 'static,
    M::SubscriptionHandle: StdbSubscriptionHandle + Send + Sync + 'static,
{
    /// Sets up subscription handling for this plugin.
    ///
    /// It registers the applied/error message channels and inserts the
    /// `StdbSubscriptions<K, M>` resource wired to those channels. It then adds two
    /// `PreUpdate` systems in `StdbSet::Subscriptions`:
    /// - one re-queues subscriptions after a disconnect or connect-error message;
    /// - one sends queued subscriptions while a `StdbConnection<C>` exists.
    fn build(&self, app: &mut App) {
        register_channel::<StdbSubscriptionAppliedMessage<K>>(app);
        register_channel::<StdbSubscriptionErrorMessage<K>>(app);

        // Create the senders up front so no `&World` borrow is held across the
        // mutable `insert_resource` call.
        let applied_sender = channel_sender::<StdbSubscriptionAppliedMessage<K>>(app.world());
        let error_sender = channel_sender::<StdbSubscriptionErrorMessage<K>>(app.world());
        app.insert_resource(StdbSubscriptions::<K, M> {
            entries: HashMap::default(),
            applied_sender,
            error_sender,
        });

        app.add_systems(
            PreUpdate,
            (|mut disconnect_msgs: ReadStdbDisconnectedMessage,
              mut error_msgs: ReadStdbConnectErrorMessage,
              mut subs: ResMut<StdbSubscriptions<K, M>>| {
                // Drain BOTH readers completely. Consuming a single message (and
                // short-circuiting past the second reader) would leave the rest
                // unread. Those leftovers would trigger this reset again on a later
                // frame, possibly after fresh handles had already been created.
                let disconnected = disconnect_msgs.read().count() > 0;
                let connect_failed = error_msgs.read().count() > 0;
                if disconnected || connect_failed {
                    // Handles are dropped without calling `unsubscribe()`;
                    // presumably their connection is gone. Every entry that had a
                    // handle is queued for resubscription.
                    for entry in subs.entries.values_mut() {
                        if entry.handle.take().is_some() {
                            entry.queued = true;
                        }
                    }
                }
            })
            .in_set(StdbSet::Subscriptions),
        );

        app.add_systems(
            PreUpdate,
            (|conn: Res<StdbConnection<C>>, mut subs: ResMut<StdbSubscriptions<K, M>>| {
                subs.apply_queued(&conn);
            })
            .in_set(StdbSet::Subscriptions)
            .run_if(resource_exists::<StdbConnection<C>>)
            .run_if(|subs: Res<StdbSubscriptions<K, M>>| subs.has_queued()),
        );
    }
}