use std::{
any::{Any, TypeId},
borrow::Borrow,
collections::{HashMap, hash_map},
rc::Rc,
};
use serde::Deserialize;
use sled::{
IVec, Transactional, Tree,
transaction::{ConflictableTransactionError, TransactionalTree},
};
use crate::{
database::{
CustomTransactionError, Database, DatabaseEntry, DatabaseError, DbKey, EntryId,
TransactionError, deserialize_from_ivec, serialize_to_ivec, sled_get_raw,
},
trace,
};
/// One pending compare-and-swap on a single key.
///
/// `old` holds the raw bytes the key is expected to contain when the swap is applied (`None`
/// meaning the key is expected to be absent). `new` is the entry to store in its place, or
/// `None` to remove the key.
#[derive(Clone, Debug)]
pub struct CompareAndSwapValue<Entry: DatabaseEntry> {
    /// Raw bytes expected in the tree at apply time; `None` means "key absent".
    pub old: Option<IVec>,
    /// Entry to write on success; `None` removes the key.
    pub new: Option<Entry>,
}

impl<Entry: DatabaseEntry> CompareAndSwapValue<Entry> {
    /// Builds a swap that expects `old` and replaces it with `new`.
    #[must_use]
    pub fn new(old: Option<IVec>, new: Option<Entry>) -> Self {
        CompareAndSwapValue { new, old }
    }
}
/// Every pending swap for one entry type, keyed by database key, together with the sled tree
/// those swaps target.
#[derive(Debug)]
pub struct TreeCompareAndSwap<Entry: DatabaseEntry> {
    tree: Tree,
    swaps: HashMap<DbKey, CompareAndSwapValue<Entry>>,
}

impl<T: DatabaseEntry> TreeCompareAndSwap<T> {
    /// Creates an empty swap set bound to `T`'s tree in `db`.
    fn new(db: &T::EntryDb) -> Self {
        let tree = T::tree(db);
        Self {
            tree,
            swaps: HashMap::default(),
        }
    }

    /// The sled tree these swaps will be applied to.
    #[must_use]
    pub fn tree(&self) -> &Tree {
        let Self { tree, .. } = self;
        tree
    }
}
/// Type-erased interface over a `TreeCompareAndSwap<E>`, so that swap sets for different
/// entry types can be stored in one map and applied together.
trait GenericCompareAndSwap: Any + std::fmt::Debug {
    /// The sled tree the swaps target.
    fn tree(&self) -> &Tree;
    /// Upcast to `&dyn Any`, used to downcast back to the concrete swap-set type.
    fn as_any(&self) -> &dyn Any;
    /// Upcast to `&mut dyn Any`, used to downcast back to the concrete swap-set type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Merges the swaps held by `other` into `self`.
    ///
    /// # Safety
    /// `other` must have the same concrete type as `self`; implementations may rely on this.
    unsafe fn merge_from(
        &mut self,
        other: Box<dyn GenericCompareAndSwap>,
    ) -> Result<(), TransactionError>;
    /// Applies every staged swap through `tx_tree`, the transactional view of `self.tree()`.
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError>;
}
impl<Entry: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<Entry> {
    fn tree(&self) -> &Tree {
        &self.tree
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Merges `other`'s per-key swaps into `self`.
    ///
    /// Keys staged only in `other` are moved over. For keys staged in both, the expected `old`
    /// values must match; the swap already held by `self` is then kept as-is.
    ///
    /// # Safety
    /// Callers must pass a `TreeCompareAndSwap<Entry>` with the same `Entry`. This
    /// implementation verifies it with a checked downcast and panics on a mismatch instead of
    /// relying on it.
    ///
    /// # Errors
    /// `TransactionError::CompareAndSwapError` if both sides expect different `old` values for
    /// the same key.
    unsafe fn merge_from(
        &mut self,
        mut other: Box<dyn GenericCompareAndSwap>,
    ) -> Result<(), TransactionError> {
        // Checked downcast instead of a raw-pointer reinterpretation: a wrong concrete type is
        // now a panic rather than undefined behaviour.
        let other_tree = other
            .as_any_mut()
            .downcast_mut::<TreeCompareAndSwap<Entry>>()
            .expect("merge_from called with a TreeCompareAndSwap of a different entry type");
        for (other_key, other_value) in std::mem::take(&mut other_tree.swaps) {
            match self.swaps.entry(other_key) {
                hash_map::Entry::Occupied(entry) => {
                    // NOTE(review): `other_value.new` is discarded here, so a write staged by
                    // `other` for a key `self` also staged is lost. Confirm this is intended.
                    if entry.get().old != other_value.old {
                        return Err(TransactionError::CompareAndSwapError);
                    }
                }
                hash_map::Entry::Vacant(entry) => {
                    entry.insert(other_value);
                }
            }
        }
        Ok(())
    }

    /// Applies every staged swap through `tx_tree`.
    ///
    /// For each key: abort if the value currently stored is flagged `read_only`. Abort if it
    /// differs from the expected `old` bytes. Otherwise write `new`, or remove the key when
    /// `new` is `None`.
    ///
    /// # Errors
    /// `DatabaseError::ReadOnly` (converted) for a read-only entry,
    /// `TransactionError::CompareAndSwapError` on an unexpected current value, and any sled
    /// error raised by the transactional reads/writes.
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError> {
        /// Only the `read_only` flag of a stored record matters here; other fields are ignored
        /// by serde's derived deserializer.
        #[derive(Deserialize)]
        struct ReadOnly {
            read_only: Option<bool>,
        }

        /// Whether the stored bytes carry `read_only: true`.
        ///
        /// Bytes that do not decode into a record with an optional `read_only` field cannot
        /// carry the flag. They count as writable rather than panicking inside the sled
        /// transaction.
        fn is_read_only(raw: &[u8]) -> bool {
            let decoded: Result<ReadOnly, _> = ciborium::from_reader(raw);
            decoded.is_ok_and(|flags| flags.read_only.unwrap_or(false))
        }

        for (key, swap) in &self.swaps {
            let current = tx_tree.get(key)?;
            if current.as_deref().is_some_and(is_read_only) {
                return Err(DatabaseError::ReadOnly.into());
            }
            if current != swap.old {
                trace!(
                    "CAS mismatch on key {:?}: Expected {:?}. Found {:?}",
                    key, swap.old, current
                );
                return Err(TransactionError::CompareAndSwapError);
            }
            match &swap.new {
                Some(new) => {
                    tx_tree.insert(key, serialize_to_ivec(&new))?;
                }
                None => {
                    tx_tree.remove(key)?;
                }
            }
        }
        Ok(())
    }
}
/// A set of pending compare-and-swap operations against one database, grouped by entry type.
///
/// Staging methods only record swaps in `swaps`; nothing is written to the database until the
/// transaction is handed to `apply_cas_tx`.
#[derive(Debug)]
pub struct CompareAndSwapTransaction<CasDb: Database> {
    // Keyed by `TypeId::of::<E>()`; the boxed value is always a `TreeCompareAndSwap<E>`
    // (see `get_or_new_request`).
    swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
    // Shared handle used for reads and for resolving each entry type's tree.
    database: Rc<CasDb>,
}
impl<CasDb: Database> CompareAndSwapTransaction<CasDb> {
#[must_use]
pub(crate) fn new() -> Result<Self, DatabaseError> {
Ok(Self {
swaps: HashMap::new(),
database: Rc::new(CasDb::open()?),
})
}
#[must_use]
pub(crate) fn with_db(database: Rc<CasDb>) -> Self {
Self {
swaps: HashMap::new(),
database,
}
}
pub fn merge(&mut self, other: Self) -> Result<(), TransactionError> {
for (type_id, other_tree) in other.swaps {
match self.swaps.entry(type_id) {
hash_map::Entry::Occupied(mut entry) => {
unsafe {
entry.get_mut().merge_from(other_tree)?;
}
}
hash_map::Entry::Vacant(entry) => {
entry.insert(other_tree);
}
}
}
Ok(())
}
pub fn tx_get<Id>(&self, id: Id) -> Result<Option<Id::Entry>, TransactionError>
where
Id: EntryId<IdDb = CasDb>,
{
if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
let cas_tree = boxed
.as_any()
.downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
.unwrap();
if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
return Ok(get.new.clone());
}
}
let tree = Id::Entry::tree(&*self.database);
let raw = sled_get_raw(&tree, id.as_bytes())?;
raw.map(deserialize_from_ivec)
.transpose()
.map_err(TransactionError::Database)
}
pub fn tx_check<Id>(&self, id: Id) -> Result<bool, TransactionError>
where
Id: EntryId<IdDb = CasDb>,
{
if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
let cas_tree = boxed
.as_any()
.downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
.unwrap();
if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
return Ok(get.new.is_some());
}
}
let tree = Id::Entry::tree(&*self.database);
Ok(tree.contains_key(id.as_bytes())?)
}
pub fn tx_get_batch<Entry, I>(&self, items: I) -> Result<Vec<Entry>, TransactionError>
where
Entry: DatabaseEntry<EntryDb = CasDb>,
I: IntoIterator<Item: Borrow<Entry::Id>>,
{
items
.into_iter()
.map(|id| {
self.tx_get(*id.borrow())?
.ok_or(DatabaseError::MissingEntry.into())
})
.collect()
}
pub fn tx_remove<Id>(&mut self, key: Id) -> Result<(), TransactionError>
where
Id: EntryId<IdDb = CasDb>,
{
let db = self.database.clone();
let request = self.get_or_new_request::<Id::Entry>();
let key = *key.as_bytes();
if let Some(get_mut) = request.swaps.get_mut(&key) {
get_mut.new = None;
} else {
let old = sled_get_raw(&Id::Entry::tree(&*db), &key)?;
request
.swaps
.insert(key, CompareAndSwapValue { old, new: None });
}
Ok(())
}
pub fn tx_upsert<Entry: DatabaseEntry<EntryDb = CasDb>>(
&mut self,
key: Entry::Id,
mut new: Option<Entry>,
) -> Result<(), TransactionError> {
let db = self.database.clone();
if let Some(new) = &mut new {
new.pre_upsert(self)?;
}
let request = self.get_or_new_request::<Entry>();
let key = *key.as_bytes();
if let Some(get_mut) = request.swaps.get_mut(&key) {
get_mut.new = new;
} else {
let old = sled_get_raw(&Entry::tree(&*db), &key)?;
request.swaps.insert(key, CompareAndSwapValue { old, new });
}
Ok(())
}
pub fn tx_insert<Entry: DatabaseEntry<EntryDb = CasDb>>(
&mut self,
item: Entry,
) -> Result<(), TransactionError> {
if self.tx_get(item.id())?.is_some() {
return Err(DatabaseError::AlreadyInDatabase.into());
}
self.tx_upsert(item.id(), Some(item))?;
Ok(())
}
pub fn get_or_new_request<Entry>(&mut self) -> &mut TreeCompareAndSwap<Entry>
where
Entry: DatabaseEntry<EntryDb = CasDb>,
{
self.swaps
.entry(TypeId::of::<Entry>())
.or_insert_with(|| Box::new(TreeCompareAndSwap::<Entry>::new(&*self.database)))
.as_any_mut()
.downcast_mut::<TreeCompareAndSwap<Entry>>()
.unwrap()
}
}
/// Atomically applies every swap staged in `tx` inside a single sled transaction spanning all
/// the trees involved.
///
/// When `flush` is true, each transactional tree is asked to flush the database before the
/// transaction returns.
///
/// # Errors
/// Propagates errors from applying each tree's swaps (e.g.
/// `TransactionError::CompareAndSwapError` when a key no longer holds its expected value) and
/// errors raised by sled while running the transaction.
pub fn apply_cas_tx<CasDb: Database>(
    tx: CompareAndSwapTransaction<CasDb>,
    flush: bool,
) -> Result<(), TransactionError> {
    // Nothing was staged (e.g. the caller only read), so there is nothing to commit.
    // NOTE(review): sled's multi-tree commit appears to index its first tree, so an empty tree
    // list should not reach `transaction` at all. Verify against the sled version in use.
    if tx.swaps.is_empty() {
        return Ok(());
    }
    // Collected in one pass so `trees[i]` and `swaps[i]` always describe the same entry type.
    let (trees, swaps): (Vec<&Tree>, Vec<&dyn GenericCompareAndSwap>) =
        tx.swaps.values().map(|cas| (cas.tree(), &**cas)).unzip();
    trees.transaction(|tx_trees| {
        for (tx_tree, cas) in tx_trees.iter().zip(&swaps) {
            cas.apply(tx_tree)
                .map_err(ConflictableTransactionError::Abort)?;
            if flush {
                tx_tree.flush();
            }
        }
        Ok(())
    })?;
    Ok(())
}
/// Runs `f` against a fresh `CompareAndSwapTransaction` and atomically applies what it staged.
/// The whole attempt is re-run whenever the commit fails with a CAS conflict.
///
/// `db` supplies the database handle. When it is `None`, the database is opened via
/// `CompareAndSwapTransaction::new` on the first attempt, and that handle is reused for every
/// retry. `flush` is forwarded to `apply_cas_tx`.
///
/// `f` may be called several times, so side effects it has outside the transaction are
/// repeated. There is no retry limit: this loops until a commit succeeds or fails with an error
/// other than a CAS conflict.
///
/// # Errors
/// Returns the first error produced by opening the database, by `f`, or by `apply_cas_tx`
/// (except `TransactionError::CompareAndSwapError`, which triggers a retry).
pub fn db_transaction<CasDb: Database, F, T, E>(
    mut f: F,
    db: Option<CasDb>,
    flush: bool,
) -> Result<T, CustomTransactionError<E>>
where
    F: FnMut(&mut CompareAndSwapTransaction<CasDb>) -> Result<T, CustomTransactionError<E>>,
{
    let mut shared_db = db.map(Rc::new);
    loop {
        let mut cas_tx = if let Some(db) = shared_db.clone() {
            CompareAndSwapTransaction::with_db(db)
        } else {
            let fresh =
                CompareAndSwapTransaction::<CasDb>::new().map_err(TransactionError::Database)?;
            // Keep the handle opened on the first attempt so retries do not re-open the database.
            shared_db = Some(Rc::clone(&fresh.database));
            fresh
        };
        let result = f(&mut cas_tx)?;
        match apply_cas_tx(cas_tx, flush) {
            Ok(()) => return Ok(result),
            Err(TransactionError::CompareAndSwapError) => {
                trace!("Transaction (Not sync) ran into a CAS error and is retrying.");
            }
            Err(err) => return Err(err.into()),
        }
    }
}