use std::{
any::{Any, TypeId},
collections::HashMap,
};
use sled::{
IVec, Transactional, Tree,
transaction::{ConflictableTransactionError, TransactionalTree},
};
use crate::database::{Database, DatabaseEntry, Db, Id, TransactionError};
/// A single pending compare-and-swap on one key.
///
/// `old` is the raw value the key is expected to hold right now (`None` means
/// the key is expected to be absent). `new` is what to store if that
/// expectation holds (`None` means the key is removed).
pub struct CompareAndSwapValue<Entry>
where
    Entry: DatabaseEntry,
{
    /// Expected current raw value of the key; `None` expects the key to be absent.
    pub old: Option<IVec>,
    /// Replacement entry; `None` removes the key.
    pub new: Option<Entry>,
}
impl<Entry: DatabaseEntry> CompareAndSwapValue<Entry> {
#[must_use]
pub fn new(old: Option<IVec>, new: Option<Entry>) -> Self {
Self { old, new }
}
}
/// All pending compare-and-swaps for one entry type, together with the sled
/// tree that entry type is stored in.
pub struct TreeCompareAndSwap<Entry>
where
    Entry: DatabaseEntry,
{
    /// Tree the swaps will be applied to.
    tree: Tree,
    /// Pending swaps keyed by entry id; holds at most one swap per id.
    swaps: HashMap<Id<Entry>, CompareAndSwapValue<Entry>>,
}
impl<T: DatabaseEntry> TreeCompareAndSwap<T> {
#[must_use]
fn new(db: &Db<T::DbInner>) -> Self {
Self {
tree: db.entry_tree::<T>(),
swaps: HashMap::new(),
}
}
#[must_use]
pub fn tree(&self) -> &Tree {
&self.tree
}
#[must_use]
pub fn get(&self, id: &Id<T>) -> Option<&CompareAndSwapValue<T>> {
self.swaps.get(id)
}
#[must_use]
pub fn get_mut(&mut self, id: &Id<T>) -> Option<&mut CompareAndSwapValue<T>> {
self.swaps.get_mut(id)
}
#[must_use]
pub fn take_out(&mut self, id: &Id<T>) -> Option<CompareAndSwapValue<T>> {
self.swaps.remove(id)
}
pub fn insert(&mut self, id: Id<T>, old: Option<IVec>, new: Option<T>) {
self.swaps.insert(id, CompareAndSwapValue { old, new });
}
pub fn insert_raw(&mut self, id: Id<T>, cas_value: CompareAndSwapValue<T>) {
self.swaps.insert(id, cas_value);
}
}
/// Type-erased view of a `TreeCompareAndSwap<Entry>`.
///
/// This lets swaps for different entry types live in one collection as
/// `Box<dyn GenericCompareAndSwap>`.
trait GenericCompareAndSwap: Any {
    /// Returns a handle to the tree the swaps target.
    fn tree(&self) -> Tree;

    /// Checks and applies every queued swap inside `tx_tree`.
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError>;

    /// Upcasts to `&dyn Any` so callers can downcast back to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Upcasts to `&mut dyn Any` so callers can downcast back to the concrete type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
impl<Entry: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<Entry> {
fn tree(&self) -> Tree {
self.tree.clone()
}
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError> {
for (k, v) in &self.swaps {
let db_old = tx_tree.get(k)?;
if db_old == v.old {
if let Some(new) = &v.new {
let buffer = Vec::from(Entry::VERSION_NUMBER.to_be_bytes());
let buffer: IVec = cbor4ii::serde::to_vec(buffer, new)
.expect("Cbor4ii failed to serialize. This cannot happen unless a serializer failed")
.into();
tx_tree.insert(&**k, buffer)?;
} else {
tx_tree.remove(&**k)?;
}
} else {
return Err(TransactionError::CompareAndSwapError);
}
}
Ok(())
}
}
/// A batch of compare-and-swaps that may span several entry types of one
/// database.
///
/// All of the swaps are applied together in a single sled transaction.
pub struct CompareAndSwapTransaction<CasDb>
where
    CasDb: Database,
{
    /// Per-entry-type swaps, keyed by the `TypeId` of the entry type.
    swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
    /// Database whose entry trees the swaps are created against.
    database: Db<CasDb>,
}
impl<CasDb: Database> CompareAndSwapTransaction<CasDb> {
    /// Starts an empty transaction against `database`.
    #[must_use]
    pub(super) fn with_db(database: Db<CasDb>) -> Self {
        Self {
            swaps: HashMap::new(),
            database,
        }
    }

    /// The database this transaction operates on.
    #[must_use]
    pub(super) fn db(&self) -> &Db<CasDb> {
        &self.database
    }

    /// Returns the swaps queued for `Entry`.
    ///
    /// Returns `None` if nothing has been requested for that entry type yet.
    ///
    /// # Panics
    /// Panics only if the value stored under `TypeId::of::<Entry>()` is not a
    /// `TreeCompareAndSwap<Entry>`. `get_or_new_request` always stores that exact
    /// type under that key, so this indicates a broken invariant.
    #[must_use]
    pub fn get_request<Entry>(&self) -> Option<&TreeCompareAndSwap<Entry>>
    where
        Entry: DatabaseEntry<DbInner = CasDb>,
    {
        self.swaps.get(&TypeId::of::<Entry>()).map(|boxed| {
            boxed
                .as_any()
                .downcast_ref::<TreeCompareAndSwap<Entry>>()
                .expect("swap map must hold a TreeCompareAndSwap<Entry> under TypeId::of::<Entry>()")
        })
    }

    /// Mutable counterpart of [`Self::get_request`].
    ///
    /// # Panics
    /// Same invariant as [`Self::get_request`].
    #[must_use]
    pub fn get_request_mut<Entry>(&mut self) -> Option<&mut TreeCompareAndSwap<Entry>>
    where
        Entry: DatabaseEntry<DbInner = CasDb>,
    {
        self.swaps.get_mut(&TypeId::of::<Entry>()).map(|boxed| {
            boxed
                .as_any_mut()
                .downcast_mut::<TreeCompareAndSwap<Entry>>()
                .expect("swap map must hold a TreeCompareAndSwap<Entry> under TypeId::of::<Entry>()")
        })
    }

    /// Returns the swaps queued for `Entry`, creating them first if needed.
    ///
    /// If nothing has been requested for `Entry` yet, an empty set is created
    /// bound to `Entry`'s tree in this transaction's database.
    ///
    /// # Panics
    /// Same invariant as [`Self::get_request`].
    pub fn get_or_new_request<Entry>(&mut self) -> &mut TreeCompareAndSwap<Entry>
    where
        Entry: DatabaseEntry<DbInner = CasDb>,
    {
        self.swaps
            .entry(TypeId::of::<Entry>())
            .or_insert_with(|| Box::new(TreeCompareAndSwap::<Entry>::new(&self.database)))
            .as_any_mut()
            .downcast_mut::<TreeCompareAndSwap<Entry>>()
            .expect("swap map must hold a TreeCompareAndSwap<Entry> under TypeId::of::<Entry>()")
    }

    /// Applies every queued swap in one sled transaction spanning all involved trees.
    ///
    /// This consumes `self`, and returns `Ok(())` immediately when nothing was
    /// queued. When `flush` is true, `flush()` is called on each transactional
    /// tree, which requests a flush before the transaction returns.
    ///
    /// # Errors
    /// If any swap's expected `old` value does not match, or a per-tree
    /// operation fails, the transaction is aborted with that `TransactionError`.
    /// The resulting sled error is then converted into `TransactionError`
    /// through its `From` impl (defined outside this view).
    pub(super) fn apply(self, flush: bool) -> Result<(), TransactionError> {
        if self.swaps.is_empty() {
            return Ok(());
        }
        // sled opens a multi-tree transaction from a slice of trees. The zip
        // below relies on sled yielding the transactional trees in the same
        // order as `trees`.
        let (trees, swaps): (Vec<Tree>, Vec<Box<dyn GenericCompareAndSwap>>) = self
            .swaps
            .into_values()
            .map(|cas| (cas.tree(), cas))
            .unzip();
        trees.transaction(|tx_trees| {
            for (tree, cas) in tx_trees.iter().zip(swaps.iter()) {
                // NOTE(review): errors from `cas.apply` include sled's unabortable
                // errors, converted via `TransactionError`'s `From` impl. Wrapping
                // them all in `Abort` may turn a sled conflict into a hard failure
                // instead of a retry. Confirm this is intended.
                cas.apply(tree)
                    .map_err(ConflictableTransactionError::Abort)?;
                if flush {
                    tree.flush();
                }
            }
            Ok(())
        })?;
        Ok(())
    }
}