use crate::{client_cache::TableAppliedDiff, error::InternalError, spacetime_module::SpacetimeModule};
use bytes::Bytes;
use spacetimedb_data_structures::map::HashMap;
use std::{
any::Any,
sync::atomic::{AtomicUsize, Ordering},
};
/// Opaque handle for one registered callback; pass it back to the matching `remove_*` call.
///
/// Values come from a single process-wide counter, so each call to
/// [`CallbackId::get_next`] yields a distinct id (until the `usize` counter wraps).
#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CallbackId {
    id: usize,
}

impl CallbackId {
    /// Allocates a fresh `CallbackId`.
    ///
    /// `Relaxed` ordering is enough here: the counter only has to hand out distinct
    /// values, and `fetch_add` is atomic regardless of the ordering chosen.
    pub(crate) fn get_next() -> Self {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
        Self { id }
    }
}
/// Row-change callbacks for every table, keyed by the table's name.
pub struct DbCallbacks<M: SpacetimeModule> {
    table_callbacks: HashMap<&'static str, TableCallbacks<M>>,
}

impl<M: SpacetimeModule> Default for DbCallbacks<M> {
    fn default() -> Self {
        Self {
            table_callbacks: HashMap::default(),
        }
    }
}

impl<M: SpacetimeModule> DbCallbacks<M> {
    /// Returns the callbacks for `table_name`, creating an empty set the first time the
    /// table is seen.
    pub(crate) fn get_table_callbacks(&mut self, table_name: &'static str) -> &mut TableCallbacks<M> {
        self.table_callbacks.entry(table_name).or_default()
    }

    /// Runs the row callbacks registered for `table_name` against `applied_diff`.
    ///
    /// Every insert callback fires for each inserted row, then every delete callback for
    /// each deleted row, then every update callback for each `(old, new)` pair.
    /// This is a no-op if the diff is empty or if no callbacks were ever registered for the table.
    pub fn invoke_table_row_callbacks<Row: Any>(
        &mut self,
        table_name: &'static str,
        applied_diff: &TableAppliedDiff<Row>,
        event: &M::EventContext,
    ) {
        if applied_diff.is_empty() {
            return;
        }
        // Deliberately a plain lookup rather than `get_table_callbacks`. That method inserts
        // an empty entry, which would grow the map for every table that merely receives rows.
        let table_callbacks = match self.table_callbacks.get_mut(&table_name) {
            Some(table_callbacks) => table_callbacks,
            None => return,
        };
        for row in applied_diff.inserts() {
            table_callbacks.invoke_on_insert(event, row);
        }
        for row in applied_diff.deletes() {
            table_callbacks.invoke_on_delete(event, row);
        }
        for (del, ins) in applied_diff.updates() {
            table_callbacks.invoke_on_update(event, del, ins);
        }
    }
}
/// Type-erased row callback: receives the event context and a single row as `&dyn Any`.
pub(crate) type RowCallback<M> = Box<dyn FnMut(&<M as SpacetimeModule>::EventContext, &dyn Any) + Send + 'static>;
/// Insert callbacks for one table, keyed by the handle returned at registration.
type InsertCallbackMap<M> = HashMap<CallbackId, RowCallback<M>>;
/// Delete callbacks use exactly the same storage shape as insert callbacks.
type DeleteCallbackMap<M> = InsertCallbackMap<M>;
/// Type-erased update callback: receives the event context, then the old row, then the new row.
pub(crate) type UpdateCallback<M> =
    Box<dyn FnMut(&<M as SpacetimeModule>::EventContext, &dyn Any, &dyn Any) + Send + 'static>;
/// Update callbacks for one table, keyed by the handle returned at registration.
type UpdateCallbackMap<M> = HashMap<CallbackId, UpdateCallback<M>>;
pub(crate) struct TableCallbacks<M: SpacetimeModule> {
on_insert: InsertCallbackMap<M>,
on_delete: DeleteCallbackMap<M>,
on_update: UpdateCallbackMap<M>,
}
impl<M: SpacetimeModule> Default for TableCallbacks<M> {
fn default() -> Self {
Self {
on_insert: Default::default(),
on_delete: Default::default(),
on_update: Default::default(),
}
}
}
impl<M: SpacetimeModule> TableCallbacks<M> {
pub(crate) fn register_on_insert(&mut self, callback_id: CallbackId, callback: RowCallback<M>) {
self.on_insert.insert(callback_id, callback);
}
pub(crate) fn register_on_delete(&mut self, callback_id: CallbackId, callback: RowCallback<M>) {
self.on_delete.insert(callback_id, callback);
}
pub(crate) fn register_on_update(&mut self, callback_id: CallbackId, callback: UpdateCallback<M>) {
self.on_update.insert(callback_id, callback);
}
pub(crate) fn remove_on_insert(&mut self, callback_id: CallbackId) {
let _ = self
.on_insert
.remove(&callback_id)
.expect("Attempt to remove non-existent insert callback");
}
pub(crate) fn remove_on_delete(&mut self, callback_id: CallbackId) {
let _ = self
.on_delete
.remove(&callback_id)
.expect("Attempt to remove non-existent delete callback");
}
pub(crate) fn remove_on_update(&mut self, callback_id: CallbackId) {
let _ = self
.on_update
.remove(&callback_id)
.expect("Attempt to remove non-existent update callback");
}
fn invoke_on_insert(&mut self, ctx: &M::EventContext, row: &dyn Any) {
for callback in self.on_insert.values_mut() {
callback(ctx, row);
}
}
fn invoke_on_delete(&mut self, ctx: &M::EventContext, row: &dyn Any) {
for callback in self.on_delete.values_mut() {
callback(ctx, row);
}
}
fn invoke_on_update(&mut self, ctx: &M::EventContext, old: &dyn Any, new: &dyn Any) {
for callback in self.on_update.values_mut() {
callback(ctx, old, new);
}
}
}
pub(crate) type ReducerCallback<M> = Box<
dyn FnOnce(&<M as SpacetimeModule>::ReducerEventContext, Result<Result<(), String>, InternalError>)
+ Send
+ 'static,
>;
pub(crate) struct ReducerCallbacks<M: SpacetimeModule> {
callbacks: HashMap<u32, (M::Reducer, ReducerCallback<M>)>,
}
impl<M: SpacetimeModule> Default for ReducerCallbacks<M> {
fn default() -> Self {
Self {
callbacks: Default::default(),
}
}
}
impl<M: SpacetimeModule> ReducerCallbacks<M> {
pub(crate) fn pop_call_info(&mut self, request_id: u32) -> Option<(M::Reducer, ReducerCallback<M>)> {
self.callbacks.remove(&request_id)
}
pub(crate) fn store_call_info(&mut self, request_id: u32, args: M::Reducer, callback: ReducerCallback<M>) {
if self.callbacks.insert(request_id, (args, callback)).is_some() {
panic!("Re-used `request_id` {request_id} for multiple in-flight reducer requests.");
}
}
}
pub(crate) type ProcedureCallback<M> =
Box<dyn FnOnce(&<M as SpacetimeModule>::ProcedureEventContext, Result<Bytes, InternalError>) + Send + 'static>;
pub struct ProcedureCallbacks<M: SpacetimeModule> {
request_id_to_callback: HashMap<u32, ProcedureCallback<M>>,
}
impl<M: SpacetimeModule> Default for ProcedureCallbacks<M> {
fn default() -> Self {
Self {
request_id_to_callback: Default::default(),
}
}
}
impl<M: SpacetimeModule> ProcedureCallbacks<M> {
pub(crate) fn insert(&mut self, request_id: u32, callback: ProcedureCallback<M>) {
if self.request_id_to_callback.insert(request_id, callback).is_some() {
unreachable!("Request IDs are drawn from a global monotonic atomic counter and so are unique");
};
}
pub(crate) fn resolve(
&mut self,
ctx: &<M as SpacetimeModule>::ProcedureEventContext,
request_id: u32,
result: Result<Bytes, InternalError>,
) {
let callback = self
.request_id_to_callback
.remove(&request_id)
.expect("Attempting to resolve a non-existent procedure callback");
callback(ctx, result)
}
}