use std::{borrow::Borrow, fmt};
use serde::{Deserialize, Serialize};
use sled::IVec;
use crate::database::{
CompareAndSwapTransaction, CompareAndSwapValue, CustomTransactionError, DatabaseEntry, Db,
Deleteable, Entry, TransactionError, deserialize_from_ivec,
};
#[derive(Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Id<T: DatabaseEntry>([u8; 32], std::marker::PhantomData<T>);
// Copy/Clone are written by hand instead of derived so that they do not
// demand `T: Copy`/`T: Clone`; the phantom `T` is never actually stored.
impl<T: DatabaseEntry> Copy for Id<T> {}

impl<T: DatabaseEntry> Clone for Id<T> {
    /// Bitwise copy, exactly like the `Copy` impl.
    fn clone(&self) -> Self {
        *self
    }
}
/// Two ids are equal exactly when their 32 bytes are equal; `T` plays no
/// part and needs no `PartialEq` bound.
impl<T: DatabaseEntry> PartialEq for Id<T> {
    fn eq(&self, other: &Self) -> bool {
        let Self(lhs, _) = self;
        let Self(rhs, _) = other;
        lhs == rhs
    }
}

impl<T: DatabaseEntry> Eq for Id<T> {}
/// Hashes only the 32 id bytes, consistent with `PartialEq` comparing only
/// those bytes; no `T: Hash` bound is needed.
impl<T: DatabaseEntry> std::hash::Hash for Id<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Fully qualified so this does not depend on the `Hash` trait being
        // imported into this module (method syntax `self.0.hash(..)` does).
        std::hash::Hash::hash(&self.0, state);
    }
}
/// Displays the id as the hex text produced by `hex::encode`.
impl<T: DatabaseEntry> fmt::Display for Id<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fully qualified: only the `fmt` module is imported here, not the
        // `Display` trait, so method syntax `.fmt(f)` must not be relied on.
        // Delegating to `String`'s `Display` keeps width/fill/alignment flags.
        fmt::Display::fmt(&hex::encode(self.0), f)
    }
}
/// Parses an id from hex text (the format produced by `Display`).
impl<T: DatabaseEntry> std::str::FromStr for Id<T> {
    type Err = hex::FromHexError;

    /// Decodes `s` into a 32-byte buffer with `hex::decode_to_slice`; any
    /// error it reports is returned unchanged.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut bytes = [0u8; 32];
        hex::decode_to_slice(s, &mut bytes).map(|()| Self(bytes, std::marker::PhantomData))
    }
}
/// Borrows the raw id bytes; this is what lets an id be used directly as a
/// key wherever `AsRef<[u8]>` is accepted (e.g. the tree lookups in this file).
impl<T: DatabaseEntry> AsRef<[u8]> for Id<T> {
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
impl<T: DatabaseEntry> Id<T> {
    /// Re-labels this id as an id of entry type `U`, keeping the same bytes.
    ///
    /// No check is made that an entry of type `U` exists under these bytes.
    #[must_use]
    pub fn cast<U: DatabaseEntry>(self) -> Id<U> {
        let Self(bytes, _) = self;
        Id(bytes, std::marker::PhantomData)
    }
}
/// Wraps raw bytes as an id of `T`; no validation is performed.
impl<T: DatabaseEntry> From<[u8; 32]> for Id<T> {
    fn from(value: [u8; 32]) -> Self {
        Self(value, Default::default())
    }
}
/// Rebuilds an id from a stored sled value.
///
/// Fails with `TryFromSliceError` unless the value is exactly 32 bytes long.
impl<T: DatabaseEntry> TryFrom<IVec> for Id<T> {
    type Error = std::array::TryFromSliceError;

    fn try_from(value: IVec) -> Result<Self, Self::Error> {
        let bytes: &[u8] = &value;
        <[u8; 32]>::try_from(bytes).map(Self::from)
    }
}
/// Read-only access to the underlying byte array (e.g. `id.len()`, `id[0]`).
impl<T: DatabaseEntry> std::ops::Deref for Id<T> {
    type Target = [u8; 32];

    fn deref(&self) -> &Self::Target {
        let Self(bytes, _) = self;
        bytes
    }
}
impl<T: DatabaseEntry> Id<T> {
    /// Reads the entry for this id as seen from inside `cas_tx`.
    ///
    /// If the transaction's request for `T` already holds a value for this id,
    /// that staged value wins (a staged `None` yields `Ok(None)`). Otherwise
    /// the entry tree of `cas_tx.db()` is read and deserialized.
    ///
    /// # Errors
    /// Propagates read and deserialization failures as [`TransactionError`].
    pub fn tx_get(
        &self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Option<Entry<T>>, TransactionError> {
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self)
        {
            return Ok(cas_value.new.clone().map(|t| Entry::new(t, *self)));
        }
        cas_tx
            .db()
            .entry_tree::<T>()
            .get(self)?
            .map(|ivec| deserialize_from_ivec(ivec).map(|t| Entry::new(t, *self)))
            .transpose()
    }

    /// Returns whether an entry exists for this id as seen from inside `cas_tx`.
    ///
    /// A staged value for this id decides the answer (`Some` means `true`,
    /// `None` means `false`); otherwise this falls back to [`Id::db_check`].
    ///
    /// # Errors
    /// Propagates database read failures.
    pub fn tx_check(
        &self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<bool, TransactionError> {
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self)
        {
            return Ok(cas_value.new.is_some());
        }
        self.db_check(cas_tx.db())
    }

    /// Read-modify-write of this id's entry within `cas_tx`.
    ///
    /// `f` receives two arguments:
    /// - the current state: the staged new value if this id already has one
    ///   in the request for `T`, otherwise the value read from the request's
    ///   tree;
    /// - the transaction itself.
    ///
    /// Its `Ok` result is staged as this id's new value (`None` meaning
    /// "no entry").
    ///
    /// NOTE(review): the staged value is taken out of the request while `f`
    /// runs. Presumably reads of this same id from inside `f` therefore do not
    /// see it. Confirm against `take_out`.
    ///
    /// # Errors
    /// Returns the error from the tree read, from deserialization, or from
    /// `f`. When `f` fails, a value that was already staged for this id is put
    /// back unchanged, so the failed update does not discard an earlier write.
    /// Nothing new is staged.
    ///
    /// # Panics
    /// Panics if the request for `T`, created before `f` is called, is no
    /// longer present in `cas_tx` afterwards.
    pub fn tx_fetch_and_update<E>(
        &self,
        f: impl FnOnce(
            Option<T>,
            &mut CompareAndSwapTransaction<T::DbInner>,
        ) -> Result<Option<T>, CustomTransactionError<E>>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), CustomTransactionError<E>> {
        // `was_staged` records whether `take_out` removed a pending value for
        // this id: only such a value has to be restored if `f` fails.
        let (current_state, mut cas_value, was_staged) = {
            let request = cas_tx.get_or_new_request::<T>();
            if let Some(v) = request.take_out(self) {
                (v.new.clone(), v, true)
            } else {
                let old = request.tree().get(self).map_err(|err| {
                    CustomTransactionError::Transaction(TransactionError::Sled(err))
                })?;
                (
                    old.clone().map(deserialize_from_ivec).transpose()?,
                    CompareAndSwapValue::new(old, None),
                    false,
                )
            }
        };
        let outcome = f(current_state, cas_tx).map(|new_state| cas_value.new = new_state);
        // On success, stage the new value. On failure, only re-insert a value
        // that was staged before: a freshly read one (whose `new` is `None`)
        // was never in the request, and inserting it would stage a deletion.
        if outcome.is_ok() || was_staged {
            cas_tx
                .get_request_mut()
                .expect("request for this entry type is created before `f` runs")
                .insert_raw(*self, cas_value);
        }
        outcome
    }

    /// Deletes this id's entry within `cas_tx`.
    ///
    /// `T::unlink_references` is called with the current entry first, then
    /// the entry is staged as removed.
    ///
    /// # Errors
    /// Returns [`TransactionError::MissingEntry`] if there is no entry for this
    /// id (including one already deleted earlier in the same transaction).
    /// Otherwise propagates errors from reading or from `unlink_references`.
    pub fn tx_delete(
        &self,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        self.tx_fetch_and_update(
            |old, cas_tx| {
                // A missing entry is a recoverable condition, not a reason to
                // panic inside the transaction.
                let Some(old) = old else {
                    return Err(CustomTransactionError::Transaction(
                        TransactionError::MissingEntry,
                    ));
                };
                T::unlink_references(old.to_entry(*self), cas_tx)?;
                Ok(None)
            },
            cas_tx,
        )
        .map_err(TransactionError::from)
    }
}
impl<T: DatabaseEntry> Id<T> {
    /// Returns whether the entry tree of `db` has a key for this id.
    ///
    /// This reads `db` directly; no transaction or staged value is consulted.
    ///
    /// # Errors
    /// Propagates the lookup failure as [`TransactionError`].
    pub fn db_check(&self, db: &Db<T::DbInner>) -> Result<bool, TransactionError> {
        let tree = db.entry_tree::<T>();
        let present = tree.contains_key(self)?;
        Ok(present)
    }

    /// Loads and deserializes this id's entry directly from `db`.
    ///
    /// Returns `Ok(None)` when the entry tree has no value for this id.
    ///
    /// # Errors
    /// Propagates read and deserialization failures.
    pub fn db_get(&self, db: &Db<T::DbInner>) -> Result<Option<Entry<T>>, TransactionError> {
        let Some(raw) = db.entry_tree::<T>().get(*self)? else {
            return Ok(None);
        };
        deserialize_from_ivec(raw).map(|value| Some(Entry::new(value, *self)))
    }

    /// Deletes this id's entry in a transaction of its own on `db`, by running
    /// [`Id::tx_delete`] inside `db.transaction(false, None, ..)`.
    ///
    /// # Errors
    /// Propagates any error from the transaction or from `tx_delete`.
    pub fn db_delete(&self, db: &Db<T::DbInner>) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        db.transaction(false, None, |tx| {
            self.tx_delete(tx).map_err(CustomTransactionError::from)
        })
        .map_err(TransactionError::from)
    }
}
/// Batch helpers for any collection of ids that all refer to entries of `T`.
///
/// Blanket-implemented for every `IntoIterator` whose items borrow as
/// [`Id<T>`] (for example `Vec<Id<T>>` or `&[Id<T>]`).
pub trait IdExt<T>
where
    Self: IntoIterator + Sized,
    Self::Item: Borrow<Id<T>>,
    T: DatabaseEntry,
{
    /// Loads every id directly from `db`, preserving iteration order.
    ///
    /// # Errors
    /// Stops at the first failure: either a read/deserialization error, or
    /// [`TransactionError::MissingEntry`] for an id that has no stored entry.
    fn db_get(self, db: &Db<T::DbInner>) -> Result<Vec<Entry<T>>, TransactionError> {
        let ids = self.into_iter();
        let mut entries = Vec::with_capacity(ids.size_hint().0);
        for id in ids {
            let Some(entry) = id.borrow().db_get(db)? else {
                return Err(TransactionError::MissingEntry);
            };
            entries.push(entry);
        }
        Ok(entries)
    }

    /// Loads every id as seen from inside `cas_tx` (staged values first),
    /// preserving iteration order.
    ///
    /// # Errors
    /// Stops at the first failure: either a read/deserialization error, or
    /// [`TransactionError::MissingEntry`] for an id with no (or a deleted) entry.
    fn tx_get(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Vec<Entry<T>>, TransactionError> {
        let ids = self.into_iter();
        let mut entries = Vec::with_capacity(ids.size_hint().0);
        for id in ids {
            let Some(entry) = id.borrow().tx_get(cas_tx)? else {
                return Err(TransactionError::MissingEntry);
            };
            entries.push(entry);
        }
        Ok(entries)
    }

    /// Applies the infallible update `f` to every id, staging each result in
    /// the request for `T`.
    ///
    /// For each id, `f` sees one of two things:
    /// - the already-staged new value, when there is one (so repeated ids see
    ///   earlier updates);
    /// - otherwise the value read from the request's tree.
    ///
    /// # Errors
    /// Stops at the first tree-read or deserialization failure. Updates staged
    /// for earlier ids are kept in the request.
    fn tx_fetch_and_update(
        self,
        mut f: impl FnMut(Option<T>) -> Option<T>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError> {
        let request = cas_tx.get_or_new_request::<T>();
        for item in self {
            let id = *item.borrow();
            let (current, mut cas_value) = match request.take_out(&id) {
                Some(pending) => (pending.new.clone(), pending),
                None => {
                    let stored = request.tree().get(id).map_err(|err| {
                        CustomTransactionError::Transaction(TransactionError::Sled(err))
                    })?;
                    let decoded = stored.clone().map(deserialize_from_ivec).transpose()?;
                    (decoded, CompareAndSwapValue::new(stored, None))
                }
            };
            cas_value.new = f(current);
            request.insert_raw(id, cas_value);
        }
        Ok(())
    }
}
/// Every iterable whose items borrow as `Id<T>` gets the [`IdExt`] helpers.
impl<T, I> IdExt<T> for I
where
    I: IntoIterator + Sized,
    I::Item: Borrow<Id<T>>,
    T: DatabaseEntry,
{
}