use crate::callbacks::CallbackId;
use crate::db_connection::{debug_log, PendingMutation, SharedCell};
use crate::spacetime_module::{InModule, SpacetimeModule, TableUpdate, WithBsatn};
use anymap3::Map;
use bytes::Bytes;
use core::any::type_name;
use core::hash::Hash;
use futures_channel::mpsc;
use spacetimedb_data_structures::map::{hash_map::Entry, HashCollectionExt, HashMap};
use std::any::Any;
use std::fmt::Debug;
use std::fs::File;
use std::io::Write;
use std::marker::PhantomData;
use std::sync::Arc;
/// In-memory store for the rows of a single table, all of row type `Row`.
pub struct TableCache<Row> {
/// Cached rows. `handle_insert` and `handle_delete` key this map by the `bsatn`
/// field of the incoming `WithBsatn<Row>`; each entry pairs the row with a reference count.
pub(crate) entries: HashMap<Bytes, RowEntry<Row>>,
/// Unique indices over the cached rows, keyed by index name.
/// Populated through `add_unique_constraint`.
pub(crate) unique_indices: HashMap<&'static str, Box<dyn UniqueIndexDyn<Row = Row>>>,
/// Optional file handle that is passed on to `debug_log` for extra diagnostics.
extra_logging: Option<SharedCell<File>>,
}
/// A cached row together with its reference count.
pub(crate) struct RowEntry<Row> {
/// The row value, cloned from the insert that first created this entry.
row: Row,
/// Incremented by every insert of this row and decremented by every delete;
/// the entry is removed from the cache when the count reaches zero.
ref_count: u32,
}
impl<Row> TableCache<Row> {
fn new(extra_logging: Option<SharedCell<File>>) -> Self {
Self {
entries: Default::default(),
unique_indices: Default::default(),
extra_logging,
}
}
}
/// Borrowed rows keyed by a byte slice; in this file the key is always the row's `bsatn` bytes.
type RowEventMap<'r, Row> = HashMap<&'r [u8], &'r Row>;
/// The row events that result from applying a `TableUpdate` to a `TableCache`.
pub struct TableAppliedDiff<'r, Row> {
/// Rows reported as deleted, keyed by their `bsatn` bytes.
deletes: RowEventMap<'r, Row>,
/// Rows reported as inserted, keyed by their `bsatn` bytes.
inserts: RowEventMap<'r, Row>,
/// Old halves of derived updates; element `i` pairs with `update_inserts[i]`.
update_deletes: Vec<&'r Row>,
/// New halves of derived updates; element `i` pairs with `update_deletes[i]`.
update_inserts: Vec<&'r Row>,
}
impl<Row> Default for TableAppliedDiff<'_, Row> {
fn default() -> Self {
Self {
deletes: <_>::default(),
inserts: <_>::default(),
update_deletes: <_>::default(),
update_inserts: <_>::default(),
}
}
}
impl<'r, Row> TableAppliedDiff<'r, Row> {
    /// Builds a diff consisting solely of insert events, one per element of `inserts`,
    /// keyed by each element's `bsatn` bytes.
    pub(crate) fn from_event_inserts(inserts: &'r [WithBsatn<Row>]) -> Self {
        let by_bsatn = inserts.iter().map(|wb| (wb.bsatn.as_ref(), &wb.row)).collect();
        Self {
            inserts: by_bsatn,
            ..Self::default()
        }
    }

    /// Consumes the diff and returns it with delete/insert pairs that share a primary key
    /// converted into updates. `derive_pk` extracts the primary key from a row.
    pub fn with_updates_by_pk<Pk: Eq + Hash>(mut self, derive_pk: impl Fn(&Row) -> &Pk) -> Self {
        self.derive_updates(derive_pk);
        self
    }

    /// Moves every insert whose primary key matches a delete out of `inserts`/`deletes`
    /// and into `update_inserts`/`update_deletes`.
    ///
    /// Both update vectors are filled in the same traversal order, so element `i`
    /// of one corresponds to element `i` of the other.
    fn derive_updates<Pk: Eq + Hash>(&mut self, derive_pk: impl Fn(&Row) -> &Pk) {
        // Without deletes there is nothing an insert could pair up with.
        if self.deletes.is_empty() {
            return;
        }

        // Index the deleted rows by primary key, remembering the key under which
        // each one lives in `self.deletes`.
        let mut deletes_by_pk = HashMap::with_capacity(self.deletes.len());
        for (&bsatn, &row) in &self.deletes {
            deletes_by_pk.insert(derive_pk(row), (bsatn, row));
        }

        // Pull out every insert that has a matching delete; the matching delete is
        // recorded as the old half of the update and dropped from `self.deletes`.
        let paired_inserts = self
            .inserts
            .extract_if(|_, new_row| match deletes_by_pk.get(derive_pk(new_row)) {
                Some((old_bsatn, old_row)) => {
                    self.update_deletes.push(old_row);
                    let _deleted = self.deletes.remove(old_bsatn);
                    debug_assert!(_deleted.is_some());
                    true
                }
                None => false,
            })
            .map(|(_, new_row)| new_row)
            .collect::<Vec<_>>();
        self.update_inserts = paired_inserts;
    }

    /// Returns `true` when the diff carries no deletes, inserts or updates.
    pub(super) fn is_empty(&self) -> bool {
        let no_row_events = self.deletes.is_empty() && self.inserts.is_empty();
        let no_updates = self.update_deletes.is_empty() && self.update_inserts.is_empty();
        no_row_events && no_updates
    }

    /// Iterates over the rows currently recorded as deleted.
    pub(super) fn deletes(&self) -> impl '_ + Iterator<Item = &'r Row> {
        let deleted_rows = self.deletes.values();
        deleted_rows.copied()
    }

    /// Iterates over the rows currently recorded as inserted.
    pub(super) fn inserts(&self) -> impl '_ + Iterator<Item = &'r Row> {
        let inserted_rows = self.inserts.values();
        inserted_rows.copied()
    }

    /// Iterates over derived updates as `(old_row, new_row)` pairs.
    pub(super) fn updates(&self) -> impl '_ + Iterator<Item = (&'r Row, &'r Row)> {
        let old_rows = self.update_deletes.iter().copied();
        let new_rows = self.update_inserts.iter().copied();
        old_rows.zip(new_rows)
    }
}
impl<Row> TableCache<Row> {
    /// Hands `body` to the crate-level `debug_log`, together with this table's
    /// `extra_logging` handle.
    fn debug_log(&self, body: impl FnOnce(&mut File) -> std::result::Result<(), std::io::Error>) {
        let sink = &self.extra_logging;
        debug_log(sink, body);
    }
}
impl<Row: Clone + Debug + Send + Sync + 'static> TableCache<Row> {
    /// Applies a single row deletion to the cache.
    ///
    /// Decrements the row's reference count and evicts the entry once the count
    /// reaches zero. On eviction:
    ///
    /// - if `inserts` still holds a pending insert event for the same row, that event
    ///   is cancelled and no delete event is produced: the row was added and removed
    ///   within the same diff, so it was never visible;
    /// - otherwise a delete event is recorded in `deletes`.
    ///
    /// # Panics
    ///
    /// Panics if the row is not present in the cache, after passing a dump of the
    /// table contents to `debug_log`.
    fn handle_delete<'r>(
        &mut self,
        inserts: &mut RowEventMap<'_, Row>,
        deletes: &mut RowEventMap<'r, Row>,
        delete: &'r WithBsatn<Row>,
    ) {
        let Entry::Occupied(mut entry) = self.entries.entry(delete.bsatn.clone()) else {
            self.debug_log(|file| {
                writeln!(file, "`handle_delete` for table with row type {}: a delete update should correspond to an existing row in the table cache, but the row {delete:?} was not present", std::any::type_name::<Row>())?;
                writeln!(file, "table contents:")?;
                for (bsatn, RowEntry { row, ref_count }) in self.entries.iter() {
                    writeln!(file, "\t{bsatn:?}\n\t\t{row:?}\n\t\tref_count {ref_count}")?;
                }
                Ok(())
            });
            unreachable!("a delete update should correspond to an existing row in the table cache, but the row {delete:?} was not present");
        };

        let ref_count = &mut entry.get_mut().ref_count;
        *ref_count -= 1;
        if *ref_count == 0 {
            entry.remove();
            // A row with a pending insert event has not been added to the unique
            // indices yet (`apply_diff` does that after all deletes are handled), so
            // reporting a delete for it would make `remove_row` panic and would notify
            // callers about a row they never saw. Cancel the insert instead.
            if inserts.remove(&*delete.bsatn).is_none() {
                deletes.insert(&delete.bsatn, &delete.row);
            }
        }
    }

    /// Applies a single row insertion to the cache.
    ///
    /// If the row is not cached yet, a clone of it is stored and an insert event is
    /// recorded in `inserts`. In every case the row's reference count is incremented.
    fn handle_insert<'r>(&mut self, inserts: &mut RowEventMap<'r, Row>, insert: &'r WithBsatn<Row>) {
        let entry = self.entries.entry(insert.bsatn.clone());
        let entry = entry.or_insert_with(|| {
            // First occurrence of this row: report it as inserted.
            inserts.insert(&insert.bsatn, &insert.row);
            RowEntry {
                row: insert.row.clone(),
                // Starts at zero because the increment below applies to new entries too.
                ref_count: 0,
            }
        });
        entry.ref_count += 1;
    }

    /// Applies `diff` to the cache and returns the resulting insert and delete events.
    ///
    /// All inserts are processed before any delete. Afterwards every unique index is
    /// brought up to date: deleted rows are removed first, then inserted rows are added.
    /// The returned diff has no updates; see `TableAppliedDiff::with_updates_by_pk`.
    fn apply_diff<'r>(&mut self, diff: &'r TableUpdate<Row>) -> TableAppliedDiff<'r, Row> {
        let mut insert_events = <_>::default();
        for insert in &diff.inserts {
            self.handle_insert(&mut insert_events, insert);
        }

        let mut delete_events = <_>::default();
        for delete in &diff.deletes {
            self.handle_delete(&mut insert_events, &mut delete_events, delete);
        }

        for index in self.unique_indices.values_mut() {
            // Remove before adding so that a row replacing another one under the
            // same unique key does not trip the duplicate-key check in `add_row`.
            for row in delete_events.values() {
                index.remove_row(row);
            }
            for &row in insert_events.values() {
                index.add_row(row.clone());
            }
        }

        TableAppliedDiff {
            deletes: delete_events,
            inserts: insert_events,
            update_deletes: Vec::new(),
            update_inserts: Vec::new(),
        }
    }

    /// Looks up a row through the unique index named `unique_index_name`.
    ///
    /// `key` is forwarded to the index's `find_row`.
    ///
    /// # Panics
    ///
    /// Panics if no unique index with that name has been registered.
    fn find_by_unique_index<'this>(
        &'this self,
        unique_index_name: &'static str,
        key: &'_ dyn std::any::Any,
    ) -> Option<&'this Row> {
        let index = self
            .unique_indices
            .get(unique_index_name)
            .unwrap_or_else(|| panic!("No such unique index: {unique_index_name}"));
        index.find_row(key)
    }

    /// Registers a unique index named `unique_index_name` whose key is obtained from a
    /// row by `get_unique_col`.
    ///
    /// # Panics
    ///
    /// Panics if the table already contains rows, or if an index with the same name
    /// has already been registered.
    pub fn add_unique_constraint<Col>(&mut self, unique_index_name: &'static str, get_unique_col: fn(&Row) -> &Col)
    where
        Col: Any + Clone + std::hash::Hash + Eq + Send + Sync + std::fmt::Debug + 'static,
    {
        assert!(self.entries.is_empty(), "Cannot add a unique constraint to a populated table; constraints should only be added during initialization, before subscribing to any rows.");
        if self
            .unique_indices
            .insert(
                unique_index_name,
                Box::new(UniqueIndexImpl {
                    get_unique_col,
                    rows: Default::default(),
                }),
            )
            .is_some()
        {
            panic!("Duplicate unique constraint name {unique_index_name}");
        }
    }
}
/// Client-side cache of the tables of module `M`.
pub struct ClientCache<M: SpacetimeModule + ?Sized> {
/// Type-keyed map; for each row type it stores a
/// `HashMap<&'static str, TableCache<Row>>` from table name to table cache.
tables: Map<dyn Any + Send + Sync>,
/// Optional log file handle, cloned into every `TableCache` created by this cache.
extra_logging: Option<SharedCell<File>>,
/// Ties the cache to `M` without storing a value of that type.
_module: PhantomData<M>,
}
impl<M: SpacetimeModule> ClientCache<M> {
    /// Creates a cache without any tables.
    ///
    /// `extra_logging` is kept and cloned into each table cache created later.
    pub(crate) fn new(extra_logging: Option<SharedCell<File>>) -> Self {
        let tables = Map::new();
        Self {
            tables,
            extra_logging,
            _module: PhantomData,
        }
    }

    /// Returns the table named `table_name` with row type `Row`,
    /// or `None` if no such table has been created.
    pub(crate) fn get_table<Row: InModule<Module = M> + Send + Sync + 'static>(
        &self,
        table_name: &'static str,
    ) -> Option<&TableCache<Row>> {
        let tables_of_row_type = self.tables.get::<HashMap<&'static str, TableCache<Row>>>()?;
        tables_of_row_type.get(table_name)
    }

    /// Returns the table named `table_name` with row type `Row`,
    /// creating an empty one first if it does not exist yet.
    pub fn get_or_make_table<Row: InModule<Module = M> + Send + Sync + 'static>(
        &mut self,
        table_name: &'static str,
    ) -> &mut TableCache<Row> {
        let extra_logging = &self.extra_logging;
        let tables_of_row_type = self
            .tables
            .entry::<HashMap<&'static str, TableCache<Row>>>()
            .or_default();
        tables_of_row_type
            .entry(table_name)
            .or_insert_with(|| TableCache::new(extra_logging.clone()))
    }

    /// Applies `diff` to the table named `table_name` and returns the resulting row events.
    ///
    /// An empty `diff` yields an empty result and does not create the table.
    pub fn apply_diff_to_table<'r, Row: InModule<Module = M> + Clone + Debug + Send + Sync + 'static>(
        &mut self,
        table_name: &'static str,
        diff: &'r TableUpdate<Row>,
    ) -> TableAppliedDiff<'r, Row> {
        if diff.is_empty() {
            <_>::default()
        } else {
            self.get_or_make_table::<Row>(table_name).apply_diff(diff)
        }
    }
}
/// Handle to one table of row type `Row` in the client cache.
pub struct TableHandle<Row: InModule> {
/// Shared client cache that rows are read from.
pub(crate) client_cache: SharedCell<ClientCache<Row::Module>>,
/// Channel on which callback registrations and removals are queued as `PendingMutation`s.
pub(crate) pending_mutations: mpsc::UnboundedSender<PendingMutation<Row::Module>>,
/// Name used to look the table up in the client cache and to address callbacks.
pub(crate) table_name: &'static str,
}
// Written by hand so that cloning a handle only requires `Row: InModule`.
impl<Row: InModule> Clone for TableHandle<Row> {
    /// Returns a handle that shares the same client cache and mutation channel
    /// and refers to the same table name.
    fn clone(&self) -> Self {
        let Self {
            client_cache,
            pending_mutations,
            table_name,
        } = self;
        Self {
            client_cache: Arc::clone(client_cache),
            pending_mutations: pending_mutations.clone(),
            table_name: *table_name,
        }
    }
}
impl<Row: InModule + Send + Sync + Clone + 'static> TableHandle<Row> {
    /// Locks the client cache and runs `get` on this handle's table, returning its result.
    ///
    /// # Panics
    ///
    /// Panics if the lock cannot be acquired or if the cache has no table with this
    /// handle's name and row type.
    fn with_table_cache<Res>(&self, get: impl FnOnce(&TableCache<Row>) -> Res) -> Res {
        let client_cache = self.client_cache.lock().unwrap();
        match client_cache.get_table::<Row>(self.table_name) {
            Some(table) => get(table),
            None => panic!("No such table: {}", self.table_name),
        }
    }

    /// Returns the number of rows currently cached for this table.
    pub fn count(&self) -> u64 {
        self.with_table_cache(|table| {
            let cached_rows = table.entries.len();
            cached_rows as u64
        })
    }

    /// Returns an iterator over clones of all rows currently cached for this table.
    ///
    /// The rows are copied into a `Vec` while the cache is locked; the iterator itself
    /// does not hold the lock.
    pub fn iter(&self) -> impl Iterator<Item = Row> + use<Row> {
        let snapshot = self.with_table_cache(|table| {
            table
                .entries
                .values()
                .map(|entry| entry.row.clone())
                .collect::<Vec<_>>()
        });
        snapshot.into_iter()
    }

    /// Sends `mutation` over the pending-mutations channel.
    ///
    /// # Panics
    ///
    /// Panics if the send fails.
    fn queue_mutation(&self, mutation: PendingMutation<Row::Module>) {
        let sent = self.pending_mutations.unbounded_send(mutation);
        sent.unwrap();
    }

    /// Queues registration of `callback` as an insert callback for this table and
    /// returns the id under which it can later be removed with `remove_on_insert`.
    pub fn on_insert(
        &self,
        mut callback: impl FnMut(&<Row::Module as SpacetimeModule>::EventContext, &Row) + Send + 'static,
    ) -> CallbackId {
        let callback_id = CallbackId::get_next();
        self.queue_mutation(PendingMutation::AddInsertCallback {
            table: self.table_name,
            // The row arrives type-erased; recover the concrete row type before
            // calling the user's callback.
            callback: Box::new(move |ctx, erased_row| {
                callback(ctx, erased_row.downcast_ref::<Row>().unwrap());
            }),
            callback_id,
        });
        callback_id
    }

    /// Queues removal of the insert callback identified by `callback`.
    pub fn remove_on_insert(&self, callback: CallbackId) {
        let table = self.table_name;
        self.queue_mutation(PendingMutation::RemoveInsertCallback {
            table,
            callback_id: callback,
        });
    }

    /// Queues registration of `callback` as a delete callback for this table and
    /// returns the id under which it can later be removed with `remove_on_delete`.
    pub fn on_delete(
        &self,
        mut callback: impl FnMut(&<Row::Module as SpacetimeModule>::EventContext, &Row) + Send + 'static,
    ) -> CallbackId {
        let callback_id = CallbackId::get_next();
        self.queue_mutation(PendingMutation::AddDeleteCallback {
            table: self.table_name,
            // The row arrives type-erased; recover the concrete row type before
            // calling the user's callback.
            callback: Box::new(move |ctx, erased_row| {
                callback(ctx, erased_row.downcast_ref::<Row>().unwrap());
            }),
            callback_id,
        });
        callback_id
    }

    /// Queues removal of the delete callback identified by `callback`.
    pub fn remove_on_delete(&self, callback: CallbackId) {
        let table = self.table_name;
        self.queue_mutation(PendingMutation::RemoveDeleteCallback {
            table,
            callback_id: callback,
        });
    }

    /// Queues registration of `callback` as an update callback for this table and
    /// returns the id under which it can later be removed with `remove_on_update`.
    ///
    /// The callback receives the old row followed by the new row.
    pub fn on_update(
        &self,
        mut callback: impl FnMut(&<Row::Module as SpacetimeModule>::EventContext, &Row, &Row) + Send + 'static,
    ) -> CallbackId {
        let callback_id = CallbackId::get_next();
        self.queue_mutation(PendingMutation::AddUpdateCallback {
            table: self.table_name,
            // Both rows arrive type-erased; recover the concrete row type before
            // calling the user's callback.
            callback: Box::new(move |ctx, erased_old, erased_new| {
                let old_row = erased_old.downcast_ref::<Row>().unwrap();
                let new_row = erased_new.downcast_ref::<Row>().unwrap();
                callback(ctx, old_row, new_row);
            }),
            callback_id,
        });
        callback_id
    }

    /// Queues removal of the update callback identified by `callback`.
    pub fn remove_on_update(&self, callback: CallbackId) {
        let table = self.table_name;
        self.queue_mutation(PendingMutation::RemoveUpdateCallback {
            table,
            callback_id: callback,
        });
    }

    /// Returns a handle for lookups through the unique index named `constraint_name`.
    ///
    /// No validation happens here; the name is only resolved when the returned handle
    /// is used.
    pub fn get_unique_constraint<Col>(&self, constraint_name: &'static str) -> UniqueConstraintHandle<Row, Col> {
        let table_handle = self.clone();
        UniqueConstraintHandle {
            table_handle,
            unique_index_name: constraint_name,
            _phantom: PhantomData,
        }
    }
}
/// Handle for looking up rows of a table by the value of a unique column of type `Col`.
pub struct UniqueConstraintHandle<Row: InModule, Col> {
/// Handle to the table whose cache is searched.
table_handle: TableHandle<Row>,
/// Name of the unique index, passed to `TableCache::find_by_unique_index`.
unique_index_name: &'static str,
/// Marker for the `Col` and `Row` type parameters; holds no data.
_phantom: PhantomData<HashMap<Col, Row>>,
}
impl<
        Row: Clone + Debug + InModule + Send + Sync + 'static,
        Col: std::any::Any + Eq + std::hash::Hash + Clone + Send + Sync + std::fmt::Debug + 'static,
    > UniqueConstraintHandle<Row, Col>
{
    /// Returns a clone of the cached row whose unique column equals `col_val`,
    /// or `None` if there is no such row.
    pub fn find(&self, col_val: &Col) -> Option<Row> {
        let index_name = self.unique_index_name;
        self.table_handle.with_table_cache(|table| {
            let found = table.find_by_unique_index(index_name, col_val);
            found.cloned()
        })
    }
}
pub trait UniqueIndexDyn: Send + Sync + 'static {
type Row: Clone + Send + Sync + 'static;
fn add_row(&mut self, row: Self::Row);
fn remove_row(&mut self, row: &Self::Row);
fn find_row<'this>(&'this self, key: &'_ dyn std::any::Any) -> Option<&'this Self::Row>;
}
pub struct UniqueIndexImpl<Row, Col> {
rows: HashMap<Col, Row>,
get_unique_col: fn(&Row) -> &Col,
}
impl<Row, Col> UniqueIndexDyn for UniqueIndexImpl<Row, Col>
where
Row: Clone + Send + Sync + 'static,
Col: Any + Clone + std::hash::Hash + Eq + Send + Sync + std::fmt::Debug + 'static,
{
type Row = Row;
fn add_row(&mut self, row: Self::Row) {
let col = (self.get_unique_col)(&row).clone();
if let Some(prev_row) = self.rows.insert(col, row) {
panic!(
"Duplicated entry in unique index at key {:?}, for type {}",
(self.get_unique_col)(&prev_row),
type_name::<Row>()
);
}
}
fn remove_row(&mut self, row: &Self::Row) {
let col = (self.get_unique_col)(row);
self.rows
.remove(col)
.expect("UniqueIndexDyn::remove_row for non-present row");
}
fn find_row<'this>(&'this self, key: &'_ dyn std::any::Any) -> Option<&'this Self::Row> {
let col = key
.downcast_ref::<Col>()
.expect("UniqueIndexDyn::find_row with key of incorrect type");
self.rows.get(col)
}
}