use sled::IVec;
use std::borrow::Borrow;
use crate::{
database::{
CompareAndSwapTransaction, CompareAndSwapValue, CustomTransactionError, DatabaseEntry, Db,
Deleteable, Entry, TransactionError, deserialize_from_ivec,
},
id::Id,
};
/// Converts a raw sled value into an [`Id`].
///
/// Succeeds only when the value is exactly 32 bytes long; any other length
/// yields a [`std::array::TryFromSliceError`].
impl<T: DatabaseEntry> TryFrom<IVec> for Id<T> {
    type Error = std::array::TryFromSliceError;

    fn try_from(value: IVec) -> Result<Self, Self::Error> {
        // `IVec` derefs to `[u8]`; the slice -> array conversion checks the length.
        let bytes = <[u8; 32]>::try_from(&*value)?;
        Ok(Self::from(bytes))
    }
}
/// Per-id database helpers, available on any copyable value that borrows as
/// an [`Id<T>`].
///
/// Methods prefixed with `tx_` work inside a [`CompareAndSwapTransaction`] and
/// take values already staged in that transaction into account. Methods
/// prefixed with `db_` talk to the database directly.
pub trait DbIdExt<T: DatabaseEntry>: Borrow<Id<T>> + Sized + Copy {
    /// Reads the entry for this id as seen by `cas_tx`.
    ///
    /// If `cas_tx` has already staged a value for this id, that staged value is
    /// returned. It is `None` when a deletion was staged, e.g. by `tx_delete`.
    /// Otherwise the committed value is read from the entry tree.
    ///
    /// # Errors
    /// Propagates tree read and deserialization errors.
    fn tx_get(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Option<Entry<T>>, TransactionError> {
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self.borrow())
        {
            return Ok(cas_value.new.clone().map(|t| Entry::new(t, *self.borrow())));
        }
        // Nothing staged for this id: read the committed state with the same
        // lookup as `db_get` instead of duplicating it. Fully qualified so it can
        // never be confused with `DbIdIterExt::db_get`.
        <Self as DbIdExt<T>>::db_get(self, cas_tx.db())
    }

    /// Returns whether an entry exists for this id as seen by `cas_tx`.
    ///
    /// A value staged in the transaction takes precedence over the database.
    ///
    /// # Errors
    /// Propagates tree read errors.
    fn tx_check(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<bool, TransactionError> {
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self.borrow())
        {
            return Ok(cas_value.new.is_some());
        }
        self.db_check(cas_tx.db())
    }

    /// Reads the current value for this id and stages whatever `f` returns in
    /// its place. Returning `None` from `f` stages a deletion.
    ///
    /// The current value is the one already staged in `cas_tx`, if any.
    /// Otherwise it is the committed value read from the request's tree. `f`
    /// also receives the transaction, so it can stage related changes.
    ///
    /// Note: this id's staged value is taken out of the request while `f`
    /// runs. Reads of the *same* id from inside `f` (e.g. `tx_get`) therefore
    /// see the committed tree state. If `f` returns an error, the previously
    /// staged value for this id is dropped. Presumably callers abort the
    /// transaction in that case — TODO confirm.
    ///
    /// # Errors
    /// Propagates tree read errors, deserialization errors, and errors from `f`.
    fn tx_fetch_and_update<E>(
        self,
        f: impl FnOnce(
            Option<T>,
            &mut CompareAndSwapTransaction<T::DbInner>,
        ) -> Result<Option<T>, CustomTransactionError<E>>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), CustomTransactionError<E>> {
        let (current_state, mut cas_value) = {
            let request = cas_tx.get_or_new_request::<T>();
            if let Some(v) = request.take_out(self.borrow()) {
                (v.new.clone(), v)
            } else {
                let old = request.tree().get(self.borrow()).map_err(|err| {
                    CustomTransactionError::Transaction(TransactionError::Sled(err))
                })?;
                (
                    old.clone().map(deserialize_from_ivec).transpose()?,
                    CompareAndSwapValue::new(old, None),
                )
            }
        };
        cas_value.new = f(current_state, cas_tx)?;
        // `f` had mutable access to `cas_tx`, so fetch the request again rather
        // than unwrapping on the assumption that it is still present.
        cas_tx
            .get_or_new_request::<T>()
            .insert_raw(*self.borrow(), cas_value);
        Ok(())
    }

    /// Stages the deletion of this id's entry in `cas_tx`.
    ///
    /// [`Deleteable::unlink_references`] is called with the current entry
    /// first, so it can clean up references to it.
    ///
    /// # Errors
    /// Returns an error built from [`TransactionError::MissingEntry`] when there
    /// is no entry for this id as seen by `cas_tx`, e.g. when it was already
    /// deleted earlier in the same transaction. Also propagates lookup errors
    /// and errors from `unlink_references`.
    fn tx_delete(
        self,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        let id = *self.borrow();
        self.tx_fetch_and_update(
            |old, cas_tx| {
                // A missing entry is a runtime condition, not a broken invariant:
                // report it instead of panicking.
                let Some(old) = old else {
                    return Err(CustomTransactionError::Transaction(
                        TransactionError::MissingEntry,
                    ));
                };
                T::unlink_references(old.to_entry(id), cas_tx)?;
                Ok(None)
            },
            cas_tx,
        )
        .map_err(TransactionError::from)
    }

    /// Returns whether the entry tree for `T` contains this id.
    ///
    /// Staged transaction state is not consulted.
    fn db_check(self, db: &Db<T::DbInner>) -> Result<bool, TransactionError> {
        Ok(db.entry_tree::<T>().contains_key(self.borrow())?)
    }

    /// Reads and deserializes the entry for this id directly from the database.
    ///
    /// Returns `Ok(None)` when the id is absent.
    fn db_get(self, db: &Db<T::DbInner>) -> Result<Option<Entry<T>>, TransactionError> {
        db.entry_tree::<T>()
            .get(self.borrow())?
            .map(|ivec| deserialize_from_ivec(ivec).map(|t| Entry::new(t, *self.borrow())))
            .transpose()
    }

    /// Deletes this id's entry by running [`DbIdExt::tx_delete`] inside a new
    /// `db.transaction`.
    ///
    /// # Errors
    /// Propagates transaction errors, including those from `tx_delete`.
    fn db_delete(self, db: &Db<T::DbInner>) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        db.transaction(false, None, |cas_tx| {
            self.tx_delete(cas_tx).map_err(CustomTransactionError::from)
        })
        .map_err(TransactionError::from)
    }
}
// Blanket implementation: every copyable value that borrows as an `Id<T>`
// (e.g. `Id<T>` itself or `&Id<T>`) gets the per-id helpers.
impl<T, I> DbIdExt<T> for I
where
    T: DatabaseEntry,
    I: Borrow<Id<T>> + Copy,
{
}
/// Batch versions of the [`DbIdExt`] helpers for any iterable of ids.
pub trait DbIdIterExt<T>
where
    Self: IntoIterator + Sized,
    Self::Item: Borrow<Id<T>>,
    T: DatabaseEntry,
{
    /// Reads the entry for every id directly from the database, in iteration
    /// order.
    ///
    /// # Errors
    /// Stops at the first id without an entry and returns
    /// [`TransactionError::MissingEntry`]. Also propagates read and
    /// deserialization errors.
    fn db_get(self, db: &Db<T::DbInner>) -> Result<Vec<Entry<T>>, TransactionError> {
        self.into_iter()
            .map(|id| {
                id.borrow()
                    .db_get(db)?
                    .ok_or(TransactionError::MissingEntry)
            })
            .collect()
    }

    /// Reads the entry for every id as seen by `cas_tx`, in iteration order.
    ///
    /// # Errors
    /// Stops at the first id without an entry and returns
    /// [`TransactionError::MissingEntry`]. This includes ids whose deletion is
    /// staged in `cas_tx`. Also propagates read and deserialization errors.
    fn tx_get(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Vec<Entry<T>>, TransactionError> {
        self.into_iter()
            .map(|id| {
                id.borrow()
                    .tx_get(cas_tx)?
                    .ok_or(TransactionError::MissingEntry)
            })
            .collect()
    }

    /// For each id, passes its current value to `f` and stages the result in
    /// `cas_tx`. Returning `None` from `f` stages a deletion.
    ///
    /// The current value is the one already staged in `cas_tx`, if any.
    /// Otherwise it is the committed value read from the request's tree.
    /// Because each result is staged before the next id is processed, a
    /// repeated id sees the value produced for its earlier occurrence.
    ///
    /// # Errors
    /// Propagates tree read and deserialization errors. Ids processed before
    /// the error remain staged.
    fn tx_fetch_and_update(
        self,
        mut f: impl FnMut(Option<T>) -> Option<T>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError> {
        let request = cas_tx.get_or_new_request::<T>();
        for id in self {
            let id = *id.borrow();
            let (current_state, mut cas_value) = if let Some(v) = request.take_out(&id) {
                (v.new.clone(), v)
            } else {
                // Map straight to `TransactionError` rather than detouring through
                // `CustomTransactionError<E>` with an inferred `E`.
                let old = request.tree().get(id).map_err(TransactionError::Sled)?;
                (
                    old.clone().map(deserialize_from_ivec).transpose()?,
                    CompareAndSwapValue::new(old, None),
                )
            };
            cas_value.new = f(current_state);
            request.insert_raw(id, cas_value);
        }
        Ok(())
    }
}
// Blanket implementation: any iterable whose items borrow as `Id<T>` gets the
// batch helpers. (`I: Sized` is implicit for a generic parameter.)
impl<T, I> DbIdIterExt<T> for I
where
    T: DatabaseEntry,
    I: IntoIterator,
    I::Item: Borrow<Id<T>>,
{
}