use std::{
any::{Any, TypeId},
collections::HashMap,
};
use sled::{
IVec, Transactional, Tree,
transaction::{ConflictableTransactionError, TransactionalTree},
};
use crate::{
database::{
Createable, CustomTransactionError, Database, DatabaseEntry, DbHandle, DbKey, Deleteable,
EntryId, Mergeable, TransactionError, deserialize_from_ivec, serialize_to_ivec,
sled_get_raw,
},
trace,
};
/// One pending compare-and-swap for a single key: the raw value the key is
/// expected to hold when the swap is applied, and what should replace it.
#[derive(Debug, Clone)]
pub struct CompareAndSwapValue<Entry: DatabaseEntry> {
/// Raw bytes the key is expected to hold when the swap is applied;
/// `None` means the key is expected to be absent.
pub old: Option<IVec>,
/// Entry to write when the expectation holds; `None` means the key is removed.
pub new: Option<Entry>,
}
impl<Entry: DatabaseEntry> CompareAndSwapValue<Entry> {
#[must_use]
pub fn new(old: Option<IVec>, new: Option<Entry>) -> Self {
Self { old, new }
}
}
/// The set of pending compare-and-swap operations for the tree that stores
/// one entry type.
#[derive(Debug)]
pub struct TreeCompareAndSwap<Entry: DatabaseEntry> {
/// Tree the swaps target, obtained from `Entry::__tree` when the set is created.
tree: Tree,
/// Pending swaps, keyed by database key.
swaps: HashMap<DbKey, CompareAndSwapValue<Entry>>,
}
impl<Entry: DatabaseEntry> TreeCompareAndSwap<Entry> {
    /// Creates an empty swap set bound to the tree that `Entry::__tree` returns for `db`.
    fn new(db: &Entry::Db) -> Self {
        let tree = Entry::__tree(db);
        let swaps = HashMap::new();
        Self { tree, swaps }
    }

    /// Returns the tree this swap set targets.
    #[must_use]
    pub fn tree(&self) -> &Tree {
        &self.tree
    }
}
/// Object-safe view of a `TreeCompareAndSwap<Entry>`, so that swap sets for
/// different entry types can be stored together as `Box<dyn GenericCompareAndSwap>`.
trait GenericCompareAndSwap: Any + std::fmt::Debug {
/// Returns the tree the pending swaps target.
fn tree(&self) -> &Tree;
/// Exposes the value as `&dyn Any` so callers can downcast to the concrete type.
fn as_any(&self) -> &dyn Any;
/// Mutable counterpart of `as_any`.
fn as_any_mut(&mut self) -> &mut dyn Any;
/// Applies every pending swap through the given transactional tree.
fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError>;
}
impl<Entry: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<Entry> {
    fn tree(&self) -> &Tree {
        &self.tree
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Applies every pending swap through `tx_tree`.
    ///
    /// For each key the current value is read from `tx_tree`. The write is
    /// performed only if that value equals the expected `old` bytes: `new`
    /// is serialized and inserted, or the key is removed when `new` is `None`.
    ///
    /// # Errors
    ///
    /// * `TransactionError::ReadOnly` if the stored value has a non-zero byte
    ///   at index 4. This check runs before the comparison.
    /// * `TransactionError::CompareAndSwapError` if the stored value differs
    ///   from the expected one.
    /// * Whatever the `tx_tree` operations themselves report.
    ///
    /// The method returns on the first failing key. Writes already issued to
    /// `tx_tree` for earlier keys are not undone here; that is left to the
    /// surrounding transaction.
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError> {
        // Index of the byte inspected for the read-only check.
        // NOTE(review): the value layout is defined by the serializer, which is
        // not visible here — confirm that index 4 is the intended marker.
        const READ_ONLY_MARKER_INDEX: usize = 4;

        for (key, swap) in &self.swaps {
            let current = tx_tree.get(key)?;

            // Checked access: a stored value shorter than the marker index is
            // treated as not read-only instead of panicking on the index.
            let read_only = current
                .as_deref()
                .and_then(|bytes| bytes.get(READ_ONLY_MARKER_INDEX))
                .is_some_and(|marker| *marker != 0);
            if read_only {
                return Err(TransactionError::ReadOnly);
            }

            if current != swap.old {
                trace!(
                    "CAS mismatch on key {:?}: Expected {:?}. Found {:?}",
                    key, swap.old, current
                );
                return Err(TransactionError::CompareAndSwapError);
            }

            match &swap.new {
                Some(new) => {
                    tx_tree.insert(key, serialize_to_ivec(new))?;
                }
                None => {
                    tx_tree.remove(key)?;
                }
            }
        }
        Ok(())
    }
}
/// A transaction that stages writes in memory as compare-and-swap operations
/// and applies them later in one step.
///
/// Reads made through the `tx_*` methods see writes staged earlier in the
/// same transaction.
pub struct CompareAndSwapTransaction<CasDb: Database> {
/// Staged swap sets, one per entry type, keyed by that entry type's `TypeId`.
swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
/// Handle to the database the staged operations read from.
database: DbHandle<CasDb>,
}
impl<CasDb: Database> CompareAndSwapTransaction<CasDb> {
#[must_use]
pub(crate) fn with_db(database: DbHandle<CasDb>) -> Self {
Self {
swaps: HashMap::new(),
database,
}
}
pub fn get_or_new_request<Entry>(&mut self) -> &mut TreeCompareAndSwap<Entry>
where
Entry: DatabaseEntry<Db = CasDb>,
{
self.swaps
.entry(TypeId::of::<Entry>())
.or_insert_with(|| Box::new(TreeCompareAndSwap::<Entry>::new(&*self.database)))
.as_any_mut()
.downcast_mut::<TreeCompareAndSwap<Entry>>()
.unwrap()
}
pub fn tx_get<Id>(&self, id: Id) -> Result<Option<Id::Entry>, TransactionError>
where
Id: EntryId<IdDb = CasDb>,
{
if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
let cas_tree = boxed
.as_any()
.downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
.unwrap();
if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
return Ok(get.new.clone());
}
}
let tree = Id::Entry::__tree(&*self.database);
let raw = sled_get_raw(&tree, id.as_bytes())?;
raw.map(deserialize_from_ivec).transpose()
}
pub fn tx_check<Id>(&self, id: Id) -> Result<bool, TransactionError>
where
Id: EntryId<IdDb = CasDb>,
{
if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
let cas_tree = boxed
.as_any()
.downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
.unwrap();
if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
return Ok(get.new.is_some());
}
}
let tree = Id::Entry::__tree(&*self.database);
Ok(tree.contains_key(id.as_bytes())?)
}
pub fn tx_get_batch<Entry, I>(&self, items: I) -> Result<Vec<Entry>, TransactionError>
where
Entry: DatabaseEntry<Db = CasDb>,
I: IntoIterator<Item = Entry::Id>,
{
items
.into_iter()
.map(|id| self.tx_get(id)?.ok_or(TransactionError::MissingEntry))
.collect()
}
pub fn tx_remove<Id>(&mut self, key: Id) -> Result<(), TransactionError>
where
Id: EntryId<IdDb = CasDb>,
{
let db = self.database.clone();
let request = self.get_or_new_request::<Id::Entry>();
let key = *key.as_bytes();
if let Some(get_mut) = request.swaps.get_mut(&key) {
get_mut.new = None;
} else {
let old = sled_get_raw(&Id::Entry::__tree(&*db), &key)?;
request
.swaps
.insert(key, CompareAndSwapValue { old, new: None });
}
Ok(())
}
pub fn tx_upsert<Entry: DatabaseEntry<Db = CasDb>>(
&mut self,
mut item: Entry,
) -> Result<(), TransactionError> {
let db = self.database.clone();
item.pre_upsert(self)?;
let request = self.get_or_new_request::<Entry>();
let key = *item.id().as_bytes();
if let Some(get_mut) = request.swaps.get_mut(&key) {
get_mut.new = Some(item);
} else {
let old = sled_get_raw(&Entry::__tree(&*db), &key)?;
request.swaps.insert(
key,
CompareAndSwapValue {
old,
new: Some(item),
},
);
}
Ok(())
}
pub fn tx_insert<Entry: DatabaseEntry<Db = CasDb>>(
&mut self,
item: Entry,
) -> Result<(), TransactionError> {
if self.tx_get(item.id())?.is_some() {
return Err(TransactionError::AlreadyInDatabase);
}
self.tx_upsert(item)?;
Ok(())
}
pub fn tx_merge<Entry>(&mut self, item: Entry, from: Entry) -> Result<Entry, TransactionError>
where
Entry: DatabaseEntry<Db = CasDb> + Mergeable,
{
item.tx_merge(from, self)
}
pub fn tx_patch<Entry>(&mut self, item: Entry) -> Result<Entry, TransactionError>
where
Entry: DatabaseEntry<Db = CasDb> + Mergeable,
{
if let Some(current) = self.tx_get(item.id())? {
current.tx_merge(item, self)
} else {
self.tx_upsert(item.clone())?;
Ok(item)
}
}
pub fn tx_merge_batch<Entry>(
&mut self,
item: Entry,
from: impl IntoIterator<Item = Entry>,
) -> Result<Entry, TransactionError>
where
Entry: DatabaseEntry<Db = CasDb> + Mergeable,
{
item.tx_merge_batch(from, self)
}
pub fn tx_delete<Entry>(&mut self, item: Entry) -> Result<(), TransactionError>
where
Entry: DatabaseEntry<Db = CasDb> + Deleteable,
{
item.tx_delete(self)
}
pub fn tx_delete_id<Id>(&mut self, id: Id) -> Result<(), TransactionError>
where
Id: EntryId<IdDb = CasDb, Entry: Deleteable>,
{
if let Some(item) = self.tx_get(id)? {
item.tx_delete(self)
} else {
Ok(())
}
}
pub fn tx_create<Entry>(
&mut self,
args: <Entry as Createable>::CreateArgs,
) -> Result<Entry, CustomTransactionError<<Entry as Createable>::Err>>
where
Entry: DatabaseEntry<Db = CasDb> + Createable,
{
Entry::tx_create(self, args)
}
}
/// Applies every swap set staged in `tx` inside a single sled transaction
/// spanning all involved trees.
///
/// Returns `Ok(())` immediately when nothing was staged. When `flush` is
/// set, `flush()` is called on each transactional tree after its swaps have
/// been applied.
///
/// # Errors
///
/// Any error from applying a swap set aborts the sled transaction and is
/// returned, as are errors reported by the transaction itself.
pub(crate) fn apply_cas_tx<CasDb: Database>(
    tx: CompareAndSwapTransaction<CasDb>,
    flush: bool,
) -> Result<(), TransactionError> {
    if tx.swaps.is_empty() {
        return Ok(());
    }

    // Build both vectors in one pass so that index i of `trees` and index i
    // of `pending` always refer to the same swap set.
    let mut trees: Vec<&Tree> = Vec::with_capacity(tx.swaps.len());
    let mut pending: Vec<&dyn GenericCompareAndSwap> = Vec::with_capacity(tx.swaps.len());
    for cas in tx.swaps.values() {
        trees.push(cas.tree());
        pending.push(&**cas);
    }

    trees.transaction(|tx_trees| {
        for (tx_tree, cas) in tx_trees.iter().zip(pending.iter()) {
            cas.apply(tx_tree)
                .map_err(ConflictableTransactionError::Abort)?;
            if flush {
                tx_tree.flush();
            }
        }
        Ok(())
    })?;
    Ok(())
}
/// Runs `f` against a fresh `CompareAndSwapTransaction` and applies the
/// staged swaps, repeating the whole attempt while the apply step reports a
/// compare-and-swap mismatch.
///
/// `f` may be called several times, so it should tolerate being re-run.
/// `flush` is forwarded to `apply_cas_tx`.
///
/// # Errors
///
/// Returns the error produced by `f` without retrying, or any apply error
/// other than `TransactionError::CompareAndSwapError`, wrapped in
/// `CustomTransactionError::Transaction`.
pub fn db_transaction<CasDb: Database, F, T, E>(
    mut f: F,
    db: DbHandle<CasDb>,
    flush: bool,
) -> Result<T, CustomTransactionError<E>>
where
    F: FnMut(&mut CompareAndSwapTransaction<CasDb>) -> Result<T, CustomTransactionError<E>>,
{
    loop {
        // Every attempt starts from an empty set of staged swaps.
        let mut attempt = CompareAndSwapTransaction::with_db(db.clone());
        let outcome = f(&mut attempt)?;

        let Err(failure) = apply_cas_tx(attempt, flush) else {
            return Ok(outcome);
        };
        if !matches!(failure, TransactionError::CompareAndSwapError) {
            return Err(CustomTransactionError::Transaction(failure));
        }
        trace!("Transaction (Not sync) ran into a CAS error and is retrying.");
    }
}