// lunar-lib 0.11.0
//
// Common utilities for lunar applications
// Documentation
use sled::IVec;
use std::borrow::Borrow;

use crate::{
    database::{
        CompareAndSwapTransaction, CompareAndSwapValue, CustomTransactionError, DatabaseEntry, Db,
        Deleteable, Entry, TransactionError, deserialize_from_ivec,
    },
    id::Id,
};

/// Builds an [`Id`] from the raw bytes stored in a sled [`IVec`].
///
/// The conversion succeeds only when the byte slice is exactly 32 bytes long;
/// any other length yields the standard slice-to-array conversion error.
impl<T: DatabaseEntry> TryFrom<IVec> for Id<T> {
    type Error = std::array::TryFromSliceError;

    fn try_from(value: IVec) -> Result<Self, Self::Error> {
        // Deref the `IVec` to its byte slice and let the array conversion check the length.
        let raw = <[u8; 32]>::try_from(&*value)?;
        Ok(Self::from(raw))
    }
}

/// Extension methods for any `Copy` value that can be borrowed as an [`Id<T>`].
///
/// The `tx_*` methods operate through a [`CompareAndSwapTransaction`], the `db_*` methods
/// operate directly on a [`Db`].
pub trait DbIdExt<T: DatabaseEntry>: Borrow<Id<T>> + Sized + Copy {
    // ###############
    // TRANSACTION EXT
    // ###############

    /// Gets the latest version of the item in the transaction context, looking it up in the
    /// database if the transaction holds no pending value for it
    ///
    /// Returns `Ok(None)` when the pending value for this id is `None`, or when there is no
    /// pending value and the database has no entry for this id.
    ///
    /// # Errors
    ///
    /// Errors if [`DbIdExt::db_get()`] errors
    fn tx_get(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Option<Entry<T>>, TransactionError> {
        // A pending value recorded in the transaction takes precedence over the stored one.
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self.borrow())
        {
            return Ok(cas_value.new.clone().map(|t| Entry::new(t, *self.borrow())));
        }

        // Nothing pending for this id: fall back to the plain database lookup.
        self.db_get(cas_tx.db())
    }

    /// Checks the latest version of the item in the transaction context to see if an item exists
    ///
    /// # Errors
    ///
    /// Errors if [`DbIdExt::db_check()`] errors
    fn tx_check(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<bool, TransactionError> {
        // A pending value recorded in the transaction takes precedence over the stored one.
        if let Some(get_request) = cas_tx.get_request::<T>()
            && let Some(cas_value) = get_request.get(self.borrow())
        {
            return Ok(cas_value.new.is_some());
        }

        self.db_check(cas_tx.db())
    }

    /// Reads the current state of this id as seen by the transaction, hands it to `f`, and
    /// records the value `f` returns as the pending new value for this id
    ///
    /// If the transaction already holds a pending value for this id, that value is taken out of
    /// the request and its `new` state is passed to `f`. Otherwise the raw entry is read from
    /// the request's tree and deserialized. `f` also receives `cas_tx`, so it may read or
    /// modify other entries while computing the new state.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to read the entry, if the stored entry cannot be deserialized,
    /// or if `f` returns an error
    fn tx_fetch_and_update<E>(
        self,
        f: impl FnOnce(
            Option<T>,
            &mut CompareAndSwapTransaction<T::DbInner>,
        ) -> Result<Option<T>, CustomTransactionError<E>>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), CustomTransactionError<E>> {
        // Scoped so the borrow of the request ends before `cas_tx` is handed to `f`.
        let (current_state, mut cas_value) = {
            let request = cas_tx.get_or_new_request::<T>();
            if let Some(v) = request.take_out(self.borrow()) {
                (v.new.clone(), v)
            } else {
                let old = request.tree().get(self.borrow()).map_err(|err| {
                    CustomTransactionError::Transaction(TransactionError::Sled(err))
                })?;
                (
                    old.clone().map(deserialize_from_ivec).transpose()?,
                    CompareAndSwapValue::new(old, None),
                )
            }
        };

        // NOTE(review): if `f` fails, the value taken out of the request above is not put
        // back; presumably the transaction is abandoned on error — confirm.
        cas_value.new = f(current_state, cas_tx)?;

        // Re-acquire the request without a panic path, in case `f` altered the transaction.
        cas_tx
            .get_or_new_request::<T>()
            .insert_raw(*self.borrow(), cas_value);

        Ok(())
    }

    /// Deletes an entry from the database safely by unlinking references before its deletion
    ///
    /// # Errors
    ///
    /// Errors with [`TransactionError::MissingEntry`] if there is no entry to delete, or if
    /// [`DbIdExt::tx_fetch_and_update()`] or [`Deleteable::unlink_references()`] errors
    fn tx_delete(
        self,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        let id = *self.borrow();
        self.tx_fetch_and_update(
            |old, cas_tx| {
                // Deleting an entry that does not exist is reported as an error, not a panic.
                let Some(old) = old else {
                    return Err(CustomTransactionError::Transaction(
                        TransactionError::MissingEntry,
                    ));
                };
                T::unlink_references(old.to_entry(id), cas_tx)?;
                // `None` as the new value is what removes the entry.
                Ok(None)
            },
            cas_tx,
        )
        .map_err(TransactionError::from)
    }

    // ###############
    // DATABASE EXT
    // ###############
    /// Checks if this id exists in the database
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to read the entry at the `id`
    fn db_check(self, db: &Db<T::DbInner>) -> Result<bool, TransactionError> {
        Ok(db.entry_tree::<T>().contains_key(self.borrow())?)
    }

    /// Gets the entry of [`Self`] with the given `id` from the input `db`
    ///
    /// Returns `Ok(None)` when the entry tree has no value for this id.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to retrieve the entry or if the stored value cannot be
    /// deserialized
    fn db_get(self, db: &Db<T::DbInner>) -> Result<Option<Entry<T>>, TransactionError> {
        db.entry_tree::<T>()
            .get(self.borrow())?
            .map(|ivec| deserialize_from_ivec(ivec).map(|t| Entry::new(t, *self.borrow())))
            .transpose()
    }

    /// Wrapper around [`DbIdExt::tx_delete`] that runs the deletion in a transaction on `db`
    ///
    /// # Errors
    ///
    /// Errors if the transaction or [`DbIdExt::tx_delete`] errors
    fn db_delete(self, db: &Db<T::DbInner>) -> Result<(), TransactionError>
    where
        T: Deleteable,
    {
        db.transaction(false, None, |cas_tx| {
            self.tx_delete(cas_tx).map_err(CustomTransactionError::from)
        })
        .map_err(TransactionError::from)
    }
}

/// Blanket implementation: every `Copy` type that borrows as an [`Id<T>`] gets the
/// [`DbIdExt`] methods.
impl<T, B> DbIdExt<T> for B
where
    T: DatabaseEntry,
    B: Borrow<Id<T>> + Copy,
{
}

/// Extension methods for any collection whose items can be borrowed as [`Id<T>`], applying
/// the single-id operations to every id it yields.
pub trait DbIdIterExt<T>
where
    Self: IntoIterator + Sized,
    Self::Item: Borrow<Id<T>>,
    T: DatabaseEntry,
{
    /// Gets all entries of [`Self`] in the database with the matching `ids` from the input `db`
    ///
    /// # Errors
    ///
    /// Errors if retrieving any entry fails, or with [`TransactionError::MissingEntry`] if
    /// any id has no entry in the database
    fn db_get(self, db: &Db<T::DbInner>) -> Result<Vec<Entry<T>>, TransactionError> {
        self.into_iter()
            .map(|id| {
                id.borrow()
                    .db_get(db)?
                    .ok_or(TransactionError::MissingEntry)
            })
            .collect()
    }

    /// Gets all entries of [`Self`] in the database with the matching `ids` from the current transaction context
    ///
    /// # Errors
    ///
    /// Errors if retrieving any entry fails, or with [`TransactionError::MissingEntry`] if
    /// any entry is missing
    fn tx_get(
        self,
        cas_tx: &CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<Vec<Entry<T>>, TransactionError> {
        self.into_iter()
            .map(|id| {
                id.borrow()
                    .tx_get(cas_tx)?
                    .ok_or(TransactionError::MissingEntry)
            })
            .collect()
    }

    /// Applies `f` to the current transaction-visible state of every id and records the
    /// result as that id's pending new value in `cas_tx`
    ///
    /// For each id, a pending value already held by the transaction is taken out and its `new`
    /// state passed to `f`; otherwise the raw entry is read from the request's tree and
    /// deserialized. `f` receives `None` when no state exists for the id, and whatever it
    /// returns is stored as the new pending value.
    ///
    /// # Errors
    ///
    /// Errors if [`sled`] fails to read an entry or if a stored entry cannot be deserialized.
    /// Ids processed before the failing one have already been recorded in the request.
    fn tx_fetch_and_update(
        self,
        mut f: impl FnMut(Option<T>) -> Option<T>,
        cas_tx: &mut CompareAndSwapTransaction<T::DbInner>,
    ) -> Result<(), TransactionError> {
        let request = cas_tx.get_or_new_request::<T>();

        for id in self {
            let id = *id.borrow();

            let (current_state, mut cas_value) = {
                if let Some(v) = request.take_out(&id) {
                    (v.new.clone(), v)
                } else {
                    // Map straight to the function's own error type.
                    let old = request.tree().get(id).map_err(TransactionError::Sled)?;
                    (
                        old.clone().map(deserialize_from_ivec).transpose()?,
                        CompareAndSwapValue::new(old, None),
                    )
                }
            };

            cas_value.new = f(current_state);
            request.insert_raw(id, cas_value);
        }

        Ok(())
    }
}

/// Blanket implementation: every iterable whose items borrow as [`Id<T>`] gets the
/// [`DbIdIterExt`] methods.
impl<T, I> DbIdIterExt<T> for I
where
    T: DatabaseEntry,
    I: IntoIterator,
    I::Item: Borrow<Id<T>>,
{
}