lunar-lib 0.10.0

Common utilities for lunar applications
Documentation
use std::{borrow::Borrow, fmt};

use serde::{Deserialize, Serialize};
use sled::IVec;

use crate::database::{
    CompareAndSwapTransaction, CompareAndSwapValue, CustomTransactionError, DatabaseEntry, Db,
    Deleteable, Entry, TransactionError, deserialize_from_ivec,
};

#[derive(Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Id<T: DatabaseEntry>([u8; 32], std::marker::PhantomData<T>);

// `Copy` and `Clone` are implemented by hand because deriving them would demand `T: Copy` /
// `T: Clone`, even though only the byte array and a zero-sized marker are duplicated.
impl<T: DatabaseEntry> Copy for Id<T> {}

impl<T: DatabaseEntry> Clone for Id<T> {
    fn clone(&self) -> Self {
        // Canonical clone for a `Copy` type: a plain bitwise copy.
        *self
    }
}

/// Ids compare equal when their raw bytes match; the `T` marker takes no part.
impl<T: DatabaseEntry> PartialEq for Id<T> {
    fn eq(&self, other: &Self) -> bool {
        let Self(lhs, _) = self;
        let Self(rhs, _) = other;
        lhs == rhs
    }
}

impl<T: DatabaseEntry> Eq for Id<T> {}

/// Hashes only the raw id bytes, consistent with the `PartialEq` impl.
impl<T: DatabaseEntry> std::hash::Hash for Id<T> {
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        std::hash::Hash::hash(&self.0, hasher);
    }
}

/// Renders the id as a lowercase hexadecimal string.
///
/// Formatting flags such as width and alignment are forwarded to the encoded string.
impl<T: DatabaseEntry> fmt::Display for Id<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = hex::encode(self.0);
        fmt::Display::fmt(&encoded, f)
    }
}

/// Parses the hexadecimal form written by the `Display` impl.
///
/// Fails with [`hex::FromHexError`] when `s` does not decode to exactly 32 bytes.
impl<T: DatabaseEntry> std::str::FromStr for Id<T> {
    type Err = hex::FromHexError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut bytes = [0u8; 32];
        hex::decode_to_slice(s, &mut bytes).map(|()| Self(bytes, std::marker::PhantomData))
    }
}

/// Exposes the raw id bytes, so an `Id` can be passed directly where a byte key is expected.
impl<T: DatabaseEntry> AsRef<[u8]> for Id<T> {
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl<T: DatabaseEntry> Id<T> {
    /// Reinterprets this id as an id for entry type `U`, keeping the same bytes.
    ///
    /// No check is made that an entry of type `U` exists under these bytes.
    #[must_use]
    pub fn cast<U: DatabaseEntry>(self) -> Id<U> {
        let Self(bytes, _) = self;
        Id(bytes, std::marker::PhantomData)
    }
}

impl<T: DatabaseEntry> From<[u8; 32]> for Id<T> {
    fn from(value: [u8; 32]) -> Self {
        Self(value, std::marker::PhantomData)
    }
}

/// Converts a `sled` value into an id; fails unless the value is exactly 32 bytes long.
impl<T: DatabaseEntry> TryFrom<IVec> for Id<T> {
    type Error = std::array::TryFromSliceError;

    fn try_from(value: IVec) -> Result<Self, Self::Error> {
        <[u8; 32]>::try_from(&*value).map(|bytes| Self(bytes, std::marker::PhantomData))
    }
}

/// Lets an `Id<T>` be used wherever a `&[u8; 32]` is expected.
impl<T: DatabaseEntry> std::ops::Deref for Id<T> {
    type Target = [u8; 32];

    fn deref(&self) -> &[u8; 32] {
        let Self(bytes, _) = self;
        bytes
    }
}

// TX
impl<T: DatabaseEntry> Id<T> {
    /// Gets the latest version of this entry as seen by the transaction.
    ///
    /// If the transaction already holds a pending value for this id, that value is returned.
    /// A pending `None` means the entry was removed in this transaction. Otherwise the entry
    /// is read from the database.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to read the entry or if the stored bytes cannot be
    /// deserialized.
    pub fn tx_get(
        &self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Option<Entry<T>>, TransactionError> {
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self)
        {
            return Ok(cas_value.new.clone().map(|t| Entry::new(t, *self)));
        }

        cas_tx
            .db()
            .entry_tree::<T>()
            .get(self)?
            .map(|ivec| deserialize_from_ivec(ivec).map(|t| Entry::new(t, *self)))
            .transpose()
    }

    /// Checks whether this entry exists as seen by the transaction.
    ///
    /// A pending value in the transaction takes precedence; otherwise the database is queried.
    ///
    /// # Errors
    ///
    /// Errors if [`Id::db_check()`] errors
    pub fn tx_check(
        &self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<bool, TransactionError> {
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self)
        {
            return Ok(cas_value.new.is_some());
        }

        self.db_check(cas_tx.db())
    }

    /// Reads this entry as seen by the transaction, passes it to `f`, and records the value
    /// `f` returns as the entry's new pending value. Returning `None` removes the entry.
    ///
    /// Any pending value for this id is taken out of the transaction while `f` runs, so `f`
    /// does not observe it through `cas_tx`. If `f` fails, that previously pending value is
    /// put back unchanged.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to read the entry, if the stored bytes cannot be deserialized,
    /// or if `f` errors (its error is propagated unchanged).
    pub fn tx_fetch_and_update<E>(
        &self,
        f: impl FnOnce(
            Option<T>,
            &mut CompareAndSwapTransaction<T::DbInner>,
        ) -> Result<Option<T>, CustomTransactionError<E>>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), CustomTransactionError<E>> {
        // `was_pending` records whether `cas_value` was taken out of the transaction, so it can
        // be restored if `f` fails. A value freshly read from the tree must NOT be inserted on
        // failure: its `new` is `None`, which would register a deletion.
        let (current_state, mut cas_value, was_pending) = {
            let request = cas_tx.get_or_new_request::<T>();
            if let Some(v) = request.take_out(self) {
                (v.new.clone(), v, true)
            } else {
                let old = request.tree().get(self).map_err(|err| {
                    CustomTransactionError::Transaction(TransactionError::Sled(err))
                })?;
                (
                    old.clone().map(deserialize_from_ivec).transpose()?,
                    CompareAndSwapValue::new(old, None),
                    false,
                )
            }
        };

        match f(current_state, cas_tx) {
            Ok(new) => cas_value.new = new,
            Err(err) => {
                if was_pending {
                    cas_tx.get_or_new_request::<T>().insert_raw(*self, cas_value);
                }
                return Err(err);
            }
        }

        // `f` had mutable access to `cas_tx`, so look the request up again. Use the creating
        // accessor instead of unwrapping an optional lookup, which would be a panic path.
        cas_tx.get_or_new_request::<T>().insert_raw(*self, cas_value);

        Ok(())
    }

    /// Deletes this entry within the transaction, first unlinking references to it via
    /// [`Deleteable::unlink_references`].
    ///
    /// # Errors
    ///
    /// Errors if the entry does not exist as seen by the transaction. This is raised as
    /// [`TransactionError::MissingEntry`] inside the update. Also errors if
    /// [`Id::tx_fetch_and_update()`] or unlinking references errors.
    pub fn tx_delete(
        &self,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        self.tx_fetch_and_update(
            |old, cas_tx| {
                // Report a missing entry as an error rather than panicking. This matches how
                // `IdExt::db_get` / `IdExt::tx_get` report missing entries.
                let Some(old) = old else {
                    return Err(CustomTransactionError::Transaction(
                        TransactionError::MissingEntry,
                    ));
                };
                T::unlink_references(old.to_entry(*self), cas_tx)?;
                Ok(None)
            },
            cas_tx,
        )
        .map_err(TransactionError::from)
    }
}

// DB
impl<T: DatabaseEntry> Id<T> {
    /// Returns whether an entry with this id is stored in `db`. Reads the entry tree directly,
    /// without any transaction.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to read the entry at this id
    pub fn db_check(&self, db: &Db<T::DbInner>) -> Result<bool, TransactionError> {
        let present = db.entry_tree::<T>().contains_key(self)?;
        Ok(present)
    }

    /// Loads the entry stored under this id in `db`, or `None` if there is none.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to retrieve the entry or if its bytes cannot be deserialized
    pub fn db_get(&self, db: &Db<T::DbInner>) -> Result<Option<Entry<T>>, TransactionError> {
        match db.entry_tree::<T>().get(*self)? {
            Some(raw) => deserialize_from_ivec(raw).map(|inner| Some(Entry::new(inner, *self))),
            None => Ok(None),
        }
    }

    /// Deletes this entry by running [`Id::tx_delete`] inside a transaction on `db`.
    ///
    /// # Errors
    ///
    /// Errors if the transaction fails, including any error from [`Id::tx_delete`]
    pub fn db_delete(&self, db: &Db<T::DbInner>) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        let outcome = db.transaction(false, None, |cas_tx| {
            self.tx_delete(cas_tx).map_err(CustomTransactionError::from)
        });
        outcome.map_err(TransactionError::from)
    }
}

/// Batch operations over any iterable of [`Id`]s, whether owned or borrowed.
pub trait IdExt<T>
where
    Self: IntoIterator + Sized,
    Self::Item: Borrow<Id<T>>,
    T: DatabaseEntry,
{
    /// Gets all entries with the given ids from `db`, in iteration order.
    ///
    /// Stops at the first failure.
    ///
    /// # Errors
    ///
    /// Errors if [`Id::db_get()`] errors, or with [`TransactionError::MissingEntry`] if any id
    /// has no entry.
    fn db_get(self, db: &Db<T::DbInner>) -> Result<Vec<Entry<T>>, TransactionError> {
        self.into_iter()
            .map(|id| {
                id.borrow()
                    .db_get(db)?
                    .ok_or(TransactionError::MissingEntry)
            })
            .collect()
    }

    /// Gets all entries with the given ids as seen by the transaction, in iteration order.
    ///
    /// Stops at the first failure.
    ///
    /// # Errors
    ///
    /// Errors if [`Id::tx_get()`] errors, or with [`TransactionError::MissingEntry`] if any
    /// entry is missing.
    fn tx_get(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Vec<Entry<T>>, TransactionError> {
        self.into_iter()
            .map(|id| {
                id.borrow()
                    .tx_get(cas_tx)?
                    .ok_or(TransactionError::MissingEntry)
            })
            .collect()
    }

    /// For each id, passes its entry as seen by the transaction to `f` and records the result
    /// as the new pending value. Returning `None` removes the entry.
    ///
    /// Ids are processed in iteration order. Updates recorded for ids processed before an
    /// error remain in the transaction.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to read an entry or if stored bytes cannot be deserialized.
    fn tx_fetch_and_update(
        self,
        mut f: impl FnMut(Option<T>) -> Option<T>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError> {
        let request = cas_tx.get_or_new_request::<T>();

        for id in self {
            let id = *id.borrow();

            let (current_state, mut cas_value) = if let Some(v) = request.take_out(&id) {
                (v.new.clone(), v)
            } else {
                // Map the sled error straight to `TransactionError::Sled`. Wrapping it in
                // `CustomTransactionError` only to convert it back was a needless round trip.
                let old = request.tree().get(id).map_err(TransactionError::Sled)?;
                (
                    old.clone().map(deserialize_from_ivec).transpose()?,
                    CompareAndSwapValue::new(old, None),
                )
            };

            cas_value.new = f(current_state);
            request.insert_raw(id, cas_value);
        }

        Ok(())
    }
}

/// Every iterable whose items borrow as an [`Id<T>`] gets the [`IdExt`] batch operations.
impl<T, I> IdExt<T> for I
where
    T: DatabaseEntry,
    I: IntoIterator,
    I::Item: Borrow<Id<T>>,
{
}