use crate::spacetime_module::AbstractEventContext;
use crate::{
db_connection::{next_query_set_id, next_request_id, DbContextImpl, PendingMutation},
spacetime_module::{SpacetimeModule, SubscriptionHandle},
};
use futures_channel::mpsc;
use spacetimedb_client_api_messages::websocket::{self as ws, common::QuerySetId};
use spacetimedb_data_structures::map::HashMap;
use spacetimedb_query_builder::Query;
use std::sync::{Arc, Mutex};
/// Tracks the subscriptions known to a connection, keyed by their [`QuerySetId`].
///
/// Entries are added by `register_subscription` and removed when a subscription
/// is cancelled before being sent, ends, or reports an error.
pub struct SubscriptionManager<M: SpacetimeModule> {
    /// Handles of all registered subscriptions, indexed by query set id.
    subscriptions: HashMap<QuerySetId, SubscriptionHandleImpl<M>>,
}
// Written by hand rather than derived: `#[derive(Default)]` would add an
// `M: Default` bound that this impl does not require.
impl<M: SpacetimeModule> Default for SubscriptionManager<M> {
    /// Creates a manager that tracks no subscriptions.
    fn default() -> Self {
        let subscriptions = HashMap::default();
        Self { subscriptions }
    }
}
/// Boxed one-shot callback taking a subscription event context;
/// stored as the `on_applied` callback of a subscription.
pub(crate) type OnAppliedCallback<M> =
    Box<dyn FnOnce(&<M as SpacetimeModule>::SubscriptionEventContext) + Send + 'static>;
/// Boxed one-shot callback taking an error context and the error itself;
/// stored as the `on_error` callback of a subscription.
pub(crate) type OnErrorCallback<M> =
    Box<dyn FnOnce(&<M as SpacetimeModule>::ErrorContext, crate::Error) + Send + 'static>;
/// Boxed one-shot callback taking a subscription event context;
/// supplied through `unsubscribe_then` and handed back when the subscription ends.
pub type OnEndedCallback<M> = Box<dyn FnOnce(&<M as SpacetimeModule>::SubscriptionEventContext) + Send + 'static>;
/// Outcome of processing a queued unsubscribe request, as returned by
/// `SubscriptionManager::handle_pending_unsubscribe`.
pub(crate) enum PendingUnsubscribeResult<M: SpacetimeModule> {
    /// The subscription is neither cancelled nor ended; holds the `Unsubscribe`
    /// message built for it (presumably for the caller to send to the server).
    SendUnsubscribe(ws::v2::Unsubscribe),
    /// The subscription was cancelled while still pending; holds the
    /// `on_ended` callback that was taken from it.
    RunCallback(OnEndedCallback<M>),
    /// There is no message to send and no callback to run.
    DoNothing,
}
impl<M: SpacetimeModule> SubscriptionManager<M> {
pub(crate) fn on_disconnect(&mut self, _ctx: &M::ErrorContext) {
}
pub(crate) fn register_subscription(&mut self, query_set_id: QuerySetId, handle: SubscriptionHandleImpl<M>) {
self.subscriptions
.try_insert(query_set_id, handle.clone())
.unwrap_or_else(|_| unreachable!("Duplicate subscription id {query_set_id:?}"));
}
pub(crate) fn subscription_applied(&mut self, ctx: &M::SubscriptionEventContext, query_set_id: QuerySetId) {
let Some(sub) = self.subscriptions.get_mut(&query_set_id) else {
return;
};
if let Some(callback) = sub.on_applied() {
callback(ctx)
}
}
pub(crate) fn handle_pending_unsubscribe(&mut self, query_set_id: QuerySetId) -> PendingUnsubscribeResult<M> {
let Some(sub) = self.subscriptions.get(&query_set_id) else {
return PendingUnsubscribeResult::DoNothing;
};
let mut sub = sub.clone();
if sub.is_cancelled() {
self.subscriptions.remove(&query_set_id);
match sub.on_ended() {
Some(callback) => {
return PendingUnsubscribeResult::RunCallback(callback);
}
_ => {
return PendingUnsubscribeResult::DoNothing;
}
}
}
if sub.is_ended() {
self.subscriptions.remove(&query_set_id);
return PendingUnsubscribeResult::DoNothing;
}
PendingUnsubscribeResult::SendUnsubscribe(ws::v2::Unsubscribe {
query_set_id,
request_id: next_request_id(),
flags: ws::v2::UnsubscribeFlags::SendDroppedRows,
})
}
pub(crate) fn unsubscribe_applied(&mut self, ctx: &M::SubscriptionEventContext, query_set_id: QuerySetId) {
let Some(mut sub) = self.subscriptions.remove(&query_set_id) else {
log::debug!("Unsubscribe applied called for missing query {query_set_id:?}");
return;
};
if let Some(callback) = sub.on_ended() {
callback(ctx)
}
}
pub(crate) fn subscription_error(&mut self, ctx: &M::ErrorContext, query_set_id: QuerySetId) {
let Some(mut sub) = self.subscriptions.remove(&query_set_id) else {
log::warn!("Unsubscribe applied called for missing query {query_set_id:?}");
return;
};
if let Some(callback) = sub.on_error() {
callback(ctx, ctx.event().clone().unwrap());
}
}
}
/// Builder for a subscription: collects optional callbacks, then creates the
/// subscription through `subscribe` or `subscribe_to_all_tables`.
pub struct SubscriptionBuilder<M: SpacetimeModule> {
    /// Callback passed on to the subscription as its `on_applied` callback, if set.
    on_applied: Option<OnAppliedCallback<M>>,
    /// Callback passed on to the subscription as its `on_error` callback, if set.
    on_error: Option<OnErrorCallback<M>>,
    /// Connection whose pending-mutation channel receives the subscribe request.
    conn: DbContextImpl<M>,
}
impl<M: SpacetimeModule> SubscriptionBuilder<M> {
#[doc(hidden)]
pub fn new(imp: &DbContextImpl<M>) -> Self {
Self {
on_applied: None,
on_error: None,
conn: imp.clone(),
}
}
pub fn on_applied(mut self, callback: impl FnOnce(&M::SubscriptionEventContext) + Send + 'static) -> Self {
self.on_applied = Some(Box::new(callback));
self
}
pub fn on_error(mut self, callback: impl FnOnce(&M::ErrorContext, crate::Error) + Send + 'static) -> Self {
self.on_error = Some(Box::new(callback));
self
}
pub fn subscribe<Queries: IntoQueries>(self, query_sql: Queries) -> M::SubscriptionHandle {
let query_set_id = next_query_set_id();
let handle = SubscriptionHandleImpl::new(SubscriptionState::new(
query_set_id,
query_sql.into_queries(),
self.conn.pending_mutations_send.clone(),
self.on_applied,
self.on_error,
));
self.conn
.pending_mutations_send
.unbounded_send(PendingMutation::Subscribe {
query_set_id,
handle: handle.clone(),
})
.unwrap();
M::SubscriptionHandle::new(handle)
}
pub fn subscribe_to_all_tables(self) -> M::SubscriptionHandle {
let all_subs = M::ALL_TABLE_NAMES
.iter()
.map(|table_name| format!("SELECT * FROM {table_name}"))
.collect::<Vec<_>>();
log::info!("Subscribing to queries: {all_subs:#?}");
self.subscribe(all_subs)
}
pub fn add_query<T, Q: Query<T>>(self, build: impl Fn(M::QueryBuilder) -> Q) -> TypedSubscriptionBuilder<M> {
let query = build(M::QueryBuilder::default());
TypedSubscriptionBuilder {
builder: self,
queries: vec![query.into_sql()],
}
}
}
/// Builder that accumulates queries produced by the module's query builder
/// before subscribing to all of them at once.
pub struct TypedSubscriptionBuilder<M: SpacetimeModule> {
    /// Underlying builder holding the callbacks and the connection.
    builder: SubscriptionBuilder<M>,
    /// SQL text of every query added so far, in insertion order.
    queries: Vec<String>,
}
impl<M: SpacetimeModule> TypedSubscriptionBuilder<M> {
    /// Appends one more query to the list.
    ///
    /// `build` is called once with a default `M::QueryBuilder`; the query it
    /// returns is converted to SQL and pushed onto the end of the list.
    pub fn add_query<T, Q: Query<T>>(mut self, build: impl Fn(M::QueryBuilder) -> Q) -> Self {
        let sql = build(M::QueryBuilder::default()).into_sql();
        self.queries.push(sql);
        self
    }

    /// Subscribes to every accumulated query via the wrapped
    /// [`SubscriptionBuilder`].
    pub fn subscribe(self) -> M::SubscriptionHandle {
        let Self { builder, queries } = self;
        builder.subscribe(queries)
    }
}
/// Conversion of a single query into an owned `Box<str>`.
pub trait IntoQueryString {
    /// Consumes `self` and returns the query text.
    fn into_query_string(self) -> Box<str>;
}

/// Implements [`IntoQueryString`] for each listed type by converting the value
/// with `Box::<str>::from`.
///
/// Accepts one or more comma-separated types, optionally followed by a
/// trailing comma.
macro_rules! impl_into_query_string_via_into {
    ($($ty:ty),+ $(,)?) => {
        $(
            impl IntoQueryString for $ty {
                fn into_query_string(self) -> Box<str> {
                    Box::<str>::from(self)
                }
            }
        )+
    };
}

impl_into_query_string_via_into!(&str, String, Box<str>);
pub trait IntoQueries {
fn into_queries(self) -> Box<[Box<str>]>;
}
impl<T: IntoQueryString> IntoQueries for T {
fn into_queries(self) -> Box<[Box<str>]> {
Box::new([self.into_query_string()])
}
}
impl<T: IntoQueryString, const N: usize> IntoQueries for [T; N] {
fn into_queries(self) -> Box<[Box<str>]> {
self.into_iter().map(IntoQueryString::into_query_string).collect()
}
}
impl<T: IntoQueryString + Clone> IntoQueries for &[T] {
fn into_queries(self) -> Box<[Box<str>]> {
self.iter().cloned().map(IntoQueryString::into_query_string).collect()
}
}
impl<T: IntoQueryString> IntoQueries for Vec<T> {
fn into_queries(self) -> Box<[Box<str>]> {
self.into_iter().map(IntoQueryString::into_query_string).collect()
}
}
impl IntoQueries for Box<[Box<str>]> {
fn into_queries(self) -> Box<[Box<str>]> {
self
}
}
/// Lifecycle stage of a subscription, as tracked by `SubscriptionState`.
#[derive(Clone, Debug, Eq, PartialEq)]
enum SubscriptionServerState {
    /// Initial stage: `start` has not produced a `Subscribe` message yet.
    Pending,
    /// `start` has produced the `Subscribe` message.
    Sent,
    /// `on_applied` was called while the subscription was `Sent`.
    Applied,
    /// The subscription was ended through `on_ended`.
    Ended,
    /// The subscription was ended through `on_error`.
    Error,
}
/// Client-side state of a single subscription: its queries, lifecycle stage
/// and the callbacks still waiting to be run.
pub(crate) struct SubscriptionState<M: SpacetimeModule> {
    /// Identifier placed in the `Subscribe` message and in unsubscribe requests.
    query_set_id: QuerySetId,
    /// Query strings placed in the `Subscribe` message.
    query_sql: Box<[Box<str>]>,
    /// Set once `unsubscribe_then` has accepted an unsubscribe request.
    unsubscribe_called: bool,
    /// Current lifecycle stage.
    status: SubscriptionServerState,
    /// Callback handed out by `on_applied`; taken at most once.
    on_applied: Option<OnAppliedCallback<M>>,
    /// Callback handed out by `on_error`; taken at most once.
    on_error: Option<OnErrorCallback<M>>,
    /// Callback stored by `unsubscribe_then` and handed out by `on_ended`.
    on_ended: Option<OnEndedCallback<M>>,
    /// Channel on which `unsubscribe_then` queues the unsubscribe request.
    pending_mutation_sender: mpsc::UnboundedSender<PendingMutation<M>>,
}
impl<M: SpacetimeModule> SubscriptionState<M> {
    /// Creates the state of a subscription that has not been sent yet:
    /// status `Pending`, no unsubscribe requested, no `on_ended` callback.
    pub(crate) fn new(
        query_set_id: QuerySetId,
        query_sql: Box<[Box<str>]>,
        pending_mutation_sender: mpsc::UnboundedSender<PendingMutation<M>>,
        on_applied: Option<OnAppliedCallback<M>>,
        on_error: Option<OnErrorCallback<M>>,
    ) -> Self {
        Self {
            query_set_id,
            query_sql,
            unsubscribe_called: false,
            status: SubscriptionServerState::Pending,
            on_applied,
            on_error,
            on_ended: None,
            pending_mutation_sender,
        }
    }

    /// Moves the subscription from `Pending` to `Sent` and builds the
    /// `Subscribe` message for it, using a fresh request id.
    ///
    /// Returns `None` without changing the status if unsubscribe was already
    /// requested.
    ///
    /// # Panics
    ///
    /// Panics if unsubscribe was not requested and the subscription is not
    /// `Pending`, i.e. it was started before.
    pub(crate) fn start(&mut self) -> Option<ws::v2::Subscribe> {
        if self.unsubscribe_called {
            return None;
        }
        if self.status != SubscriptionServerState::Pending {
            unreachable!("Subscription already started");
        }
        self.status = SubscriptionServerState::Sent;
        Some(ws::v2::Subscribe {
            query_set_id: self.query_set_id,
            query_strings: self.query_sql.clone(),
            request_id: next_request_id(),
        })
    }

    /// Records an unsubscribe request, stores `on_end` as the `on_ended`
    /// callback and queues a `PendingMutation::Unsubscribe` for this
    /// subscription.
    ///
    /// # Errors
    ///
    /// * `Error::AlreadyEnded` if the subscription has ended or errored.
    /// * `Error::AlreadyUnsubscribed` if unsubscribe was requested before.
    ///
    /// # Panics
    ///
    /// Panics if the pending-mutation channel rejects the message.
    pub fn unsubscribe_then(&mut self, on_end: Option<OnEndedCallback<M>>) -> crate::Result<()> {
        if self.is_ended() {
            return Err(crate::Error::AlreadyEnded);
        }
        if self.unsubscribe_called {
            return Err(crate::Error::AlreadyUnsubscribed);
        }
        self.unsubscribe_called = true;
        self.on_ended = on_end;
        // The request is queued regardless of the current status, including
        // `Pending`.
        self.pending_mutation_sender
            .unbounded_send(PendingMutation::Unsubscribe {
                query_set_id: self.query_set_id,
            })
            .unwrap();
        Ok(())
    }

    /// Returns `true` if unsubscribe was requested while the subscription was
    /// still `Pending`.
    pub fn is_cancelled(&self) -> bool {
        self.status == SubscriptionServerState::Pending && self.unsubscribe_called
    }

    /// Returns `true` if the status is `Ended` or `Error`.
    pub fn is_ended(&self) -> bool {
        matches!(
            self.status,
            SubscriptionServerState::Ended | SubscriptionServerState::Error
        )
    }

    /// Returns `true` if the status is `Applied` and unsubscribe has not been
    /// requested.
    pub fn is_active(&self) -> bool {
        self.status == SubscriptionServerState::Applied && !self.unsubscribe_called
    }

    /// Moves the subscription from `Sent` to `Applied` and hands out the
    /// `on_applied` callback, if one is still stored.
    ///
    /// In any other status nothing changes and `None` is returned.
    pub fn on_applied(&mut self) -> Option<OnAppliedCallback<M>> {
        if self.status != SubscriptionServerState::Sent {
            log::debug!(
                "on_applied called for query {:?} with status: {:?}",
                self.query_set_id,
                self.status
            );
            return None;
        }
        log::debug!("on_applied called for query {:?}", self.query_set_id);
        self.status = SubscriptionServerState::Applied;
        self.on_applied.take()
    }

    /// Sets the status to `Ended` and hands out the `on_ended` callback, if
    /// one is stored.
    ///
    /// Returns `None` without changing anything if the subscription has
    /// already ended or errored.
    pub fn on_ended(&mut self) -> Option<OnEndedCallback<M>> {
        if self.is_ended() {
            return None;
        }
        self.status = SubscriptionServerState::Ended;
        self.on_ended.take()
    }

    /// Sets the status to `Error` and hands out the `on_error` callback, if
    /// one is still stored.
    ///
    /// Returns `None` without changing anything if the subscription has
    /// already ended or errored.
    pub fn on_error(&mut self) -> Option<OnErrorCallback<M>> {
        if self.is_ended() {
            return None;
        }
        self.status = SubscriptionServerState::Error;
        self.on_error.take()
    }
}
/// Cloneable handle to a subscription's state; all clones share the same
/// `SubscriptionState` behind an `Arc<Mutex<_>>`.
#[doc(hidden)]
pub struct SubscriptionHandleImpl<M: SpacetimeModule> {
    /// Shared, lock-protected subscription state.
    pub(crate) inner: Arc<Mutex<SubscriptionState<M>>>,
}
// Written by hand rather than derived: `#[derive(Clone)]` would add an
// `M: Clone` bound that this impl does not require.
impl<M: SpacetimeModule> Clone for SubscriptionHandleImpl<M> {
    /// Returns another handle to the same shared state.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
// Every method below locks the shared state and forwards to the method of the
// same name on `SubscriptionState`. Each of them panics if the mutex is
// poisoned.
impl<M: SpacetimeModule> SubscriptionHandleImpl<M> {
    /// Wraps `inner` in a fresh `Arc<Mutex<_>>`.
    pub(crate) fn new(inner: SubscriptionState<M>) -> Self {
        let shared = Arc::new(Mutex::new(inner));
        Self { inner: shared }
    }

    /// Forwards to `SubscriptionState::start`.
    pub(crate) fn start(&self) -> Option<ws::v2::Subscribe> {
        self.inner.lock().unwrap().start()
    }

    /// Forwards to `SubscriptionState::is_cancelled`.
    pub(crate) fn is_cancelled(&self) -> bool {
        let state = self.inner.lock().unwrap();
        state.is_cancelled()
    }

    /// Forwards to `SubscriptionState::is_ended`.
    pub fn is_ended(&self) -> bool {
        let state = self.inner.lock().unwrap();
        state.is_ended()
    }

    /// Forwards to `SubscriptionState::is_active`.
    pub fn is_active(&self) -> bool {
        let state = self.inner.lock().unwrap();
        state.is_active()
    }

    /// Forwards to `SubscriptionState::unsubscribe_then`, consuming this
    /// handle (other clones remain usable).
    pub fn unsubscribe_then(self, on_end: Option<OnEndedCallback<M>>) -> crate::Result<()> {
        let mut state = self.inner.lock().unwrap();
        state.unsubscribe_then(on_end)
    }

    /// Forwards to `SubscriptionState::on_applied`.
    pub(crate) fn on_applied(&mut self) -> Option<OnAppliedCallback<M>> {
        self.inner.lock().unwrap().on_applied()
    }

    /// Forwards to `SubscriptionState::on_ended`.
    pub(crate) fn on_ended(&mut self) -> Option<OnEndedCallback<M>> {
        self.inner.lock().unwrap().on_ended()
    }

    /// Forwards to `SubscriptionState::on_error`.
    pub(crate) fn on_error(&mut self) -> Option<OnErrorCallback<M>> {
        self.inner.lock().unwrap().on_error()
    }
}
// Compile-time checks for the `IntoQueryString` / `IntoQueries` impls.
//
// None of the functions below carries `#[test]` and none is called (hence
// `#[allow(unused)]`): they only have to type-check. Each one exercises every
// supported container shape for one string type, once through method syntax
// and once through fully-qualified syntax.
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that `Box<str>` works as a query on its own and inside an
    /// array, a slice and a `Vec`.
    #[allow(unused)]
    fn into_queries_box_str(query: Box<str>) {
        // Single query string.
        let _ = query.clone().into_query_string();
        let _ = <Box<str> as IntoQueryString>::into_query_string(query.clone());
        // Single query treated as a one-element list.
        let _ = query.clone().into_queries();
        // Fixed-size array.
        let _ = <[Box<str>; 1] as IntoQueries>::into_queries([query.clone()]);
        let _ = [query.clone()].into_queries();
        // Borrowed slice.
        let slice: &[Box<str>] = std::slice::from_ref(&query);
        let _ = <&[Box<str>] as IntoQueries>::into_queries(slice);
        let _ = slice.into_queries();
        // Vector.
        let _ = <Vec<Box<str>> as IntoQueries>::into_queries(vec![query.clone()]);
        let _ = vec![query.clone()].into_queries();
    }

    /// Checks that `String` works as a query on its own and inside an array,
    /// a slice and a `Vec`.
    #[allow(unused)]
    fn into_queries_string(query: String) {
        // Single query string.
        let _ = query.clone().into_query_string();
        let _ = <String as IntoQueryString>::into_query_string(query.clone());
        // Single query treated as a one-element list.
        let _ = query.clone().into_queries();
        // Fixed-size array.
        let _ = <[String; 1] as IntoQueries>::into_queries([query.clone()]);
        let _ = [query.clone()].into_queries();
        // Borrowed slice.
        let slice: &[String] = std::slice::from_ref(&query);
        let _ = <&[String] as IntoQueries>::into_queries(slice);
        let _ = slice.into_queries();
        // Vector.
        let _ = <Vec<String> as IntoQueries>::into_queries(vec![query.clone()]);
        let _ = vec![query.clone()].into_queries();
    }

    /// Checks that `&str` works as a query on its own and inside an array,
    /// a slice and a `Vec`.
    #[allow(unused)]
    fn into_queries_str(query: &str) {
        // Single query string.
        let _ = query.into_query_string();
        let _ = <&str as IntoQueryString>::into_query_string(query);
        // Single query treated as a one-element list.
        let _ = query.into_queries();
        // Fixed-size array.
        let _ = <[&str; 1] as IntoQueries>::into_queries([query]);
        let _ = [query].into_queries();
        // Borrowed slice.
        let slice: &[&str] = &[query];
        let _ = <&[&str] as IntoQueries>::into_queries(slice);
        let _ = slice.into_queries();
        // Vector.
        let _ = <Vec<&str> as IntoQueries>::into_queries(vec![query]);
        let _ = vec![query].into_queries();
    }
}