use std::collections::HashMap;
use sled::{
CompareAndSwapError, IVec, Transactional, Tree,
transaction::{ConflictableTransactionError, TransactionError},
};
use crate::database::DbKey;
/// A single compare-and-swap: the value a key is expected to hold (`old`) and
/// the value that should replace it (`new`).
///
/// `None` for `old` means the key is expected to be absent; `None` for `new`
/// means the key is removed when the swap is applied (see [`apply_cas_tx`]).
#[derive(Debug)]
pub struct CompareAndSwapValue {
    /// Value the key must currently hold for the swap to go through.
    pub old: Option<IVec>,
    /// Replacement value; `None` deletes the key.
    pub new: Option<IVec>,
}

impl CompareAndSwapValue {
    /// Builds a swap from the expected value `old` to the replacement `new`.
    #[must_use]
    pub fn new(old: Option<IVec>, new: Option<IVec>) -> Self {
        CompareAndSwapValue { new, old }
    }
}
/// The pending compare-and-swap operations that target one sled [`Tree`].
#[derive(Debug)]
pub struct TreeCompareAndSwap {
    /// Tree that every swap in `swaps` is applied to.
    tree: Tree,
    /// Expected/replacement value pairs, keyed by database key.
    pub swaps: HashMap<DbKey, CompareAndSwapValue>,
}

impl TreeCompareAndSwap {
    /// Starts an empty set of swaps for `tree`.
    fn new(tree: Tree) -> Self {
        let swaps = HashMap::default();
        TreeCompareAndSwap { tree, swaps }
    }

    /// Borrows the tree these swaps target.
    #[must_use]
    pub fn tree(&self) -> &Tree {
        let Self { tree, .. } = self;
        tree
    }
}
#[derive(Debug, Default)]
pub struct CompareAndSwapTransaction {
swaps: Vec<TreeCompareAndSwap>,
}
impl CompareAndSwapTransaction {
#[must_use]
pub fn new() -> Self {
Self { swaps: Vec::new() }
}
pub fn get_or_new_request(&mut self, tree: Tree) -> &mut TreeCompareAndSwap {
if let Some(i) = self.swaps.iter().position(|a| a.tree.name() == tree.name()) {
&mut self.swaps[i]
} else {
self.swaps.push(TreeCompareAndSwap::new(tree));
self.swaps.last_mut().unwrap()
}
}
#[must_use]
pub fn trees(&self) -> Vec<&Tree> {
self.swaps.iter().map(|a| &a.tree).collect()
}
#[must_use]
pub fn swaps(&self) -> Vec<&HashMap<[u8; 32], CompareAndSwapValue>> {
self.swaps.iter().map(|a| &a.swaps).collect()
}
}
/// Applies every compare-and-swap in `tx` inside a single sled transaction.
///
/// For each key, the value currently stored in its tree must equal the recorded
/// `old` value. If it does, the key is set to `new`, or removed when `new` is `None`.
/// An empty transaction (no trees) is a no-op and returns `Ok(())` without
/// opening a sled transaction.
///
/// # Errors
///
/// Returns `TransactionError::Abort` carrying a [`CompareAndSwapError`] for the
/// first mismatching key encountered. `HashMap` iteration order is unspecified,
/// so which key is reported when several mismatch is unspecified. The abort
/// means nothing from this attempt is committed. Other failures are propagated
/// as returned by sled's `Transactional::transaction`.
pub fn apply_cas_tx(
    tx: CompareAndSwapTransaction,
) -> Result<(), TransactionError<CompareAndSwapError>> {
    // Nothing to apply. Avoid handing sled a multi-tree transaction over zero trees.
    // NOTE(review): sled's commit path appears to index the first tree; confirm whether
    // an empty slice would panic there.
    if tx.swaps.is_empty() {
        return Ok(());
    }

    let trees = tx.trees();
    trees.transaction(|tx_trees| {
        // `tx_trees` follows the order of `tx.swaps`, because `trees()` maps over it.
        for (tree, request) in tx_trees.iter().zip(&tx.swaps) {
            for (key, swap) in &request.swaps {
                let current = tree.get(key)?;
                if current != swap.old {
                    return Err(ConflictableTransactionError::Abort(CompareAndSwapError {
                        current,
                        proposed: swap.new.clone(),
                    }));
                }
                match &swap.new {
                    Some(new) => {
                        tree.insert(key, new)?;
                    }
                    None => {
                        tree.remove(key)?;
                    }
                }
            }
        }
        Ok(())
    })
}
/// Builds and applies a compare-and-swap transaction, retrying until it commits.
///
/// `f` is called with a fresh, empty [`CompareAndSwapTransaction`] on every
/// attempt and records the expected (`old`) and desired (`new`) values. When
/// [`apply_cas_tx`] aborts because an expected value no longer matches, the
/// attempt is discarded and `f` runs again.
///
/// Note: there is no retry limit. If the expectations keep going stale, this
/// loops indefinitely.
///
/// # Errors
///
/// Returns the first error produced by `f`, or any non-abort error from
/// [`apply_cas_tx`] converted into `E`.
pub fn db_transaction<F, E>(mut f: F) -> Result<(), E>
where
    F: FnMut(&mut CompareAndSwapTransaction) -> Result<(), E>,
    E: From<TransactionError<CompareAndSwapError>>,
{
    loop {
        let mut attempt = CompareAndSwapTransaction::new();
        f(&mut attempt)?;
        let outcome = apply_cas_tx(attempt);
        // An abort means a compare-and-swap expectation was stale: rebuild and retry.
        if !matches!(outcome, Err(TransactionError::Abort(_))) {
            return outcome.map_err(E::from);
        }
    }
}