lunar-lib 0.10.0

Common utilities for lunar applications
Documentation
use std::{
    collections::{HashMap, hash_map},
    convert::Infallible,
    fs, io,
    path::PathBuf,
    sync::{Arc, RwLock},
};

use sled::{
    IVec, Tree,
    transaction::{TransactionError as SledTransactionError, UnabortableTransactionError},
};

use thiserror::Error;

mod entry;
pub use entry::*;

mod id;
pub use id::*;

mod extensions;
pub use extensions::*;

mod transactions;
pub use transactions::*;

mod index;
pub use index::*;

pub mod caching;

pub mod writer;

#[derive(Debug, Error)]
pub enum DatabaseError {
    #[error("IoError: {0}")]
    Io(#[from] std::io::Error),

    #[error("Sled error: {0}")]
    Sled(#[from] sled::Error),

    #[error("{0}")]
    Transaction(#[from] TransactionError),

    #[error("The database failed to be obtained after {0} retries")]
    TooManyOpenRetries(usize),
}

#[derive(Debug, Error)]
pub enum TransactionError {
    #[error("Sled error: {0}")]
    Sled(#[from] sled::Error),

    #[error(
        "An internal compare and swap failed during a transaction. This variant will never be returned from a function"
    )]
    CompareAndSwapError,

    #[error("An internal compare and swap was retried too many times")]
    TooManyRetries,

    #[error("A key was not found in the database when it was expected")]
    MissingEntry,

    #[error("A key was found in the database when it was not expected")]
    AlreadyInDatabase,

    #[error("A record was out of date. v{0} was expected but v{1} was found")]
    OutdatedVesion(u32, u32),
}

#[derive(Debug, Error)]
pub enum CustomTransactionError<E> {
    #[error("{0}")]
    Transaction(TransactionError),

    #[error("Transaction closure failed: {0}")]
    Closure(E),
}

impl From<SledTransactionError<TransactionError>> for TransactionError {
    fn from(value: SledTransactionError<TransactionError>) -> Self {
        match value {
            SledTransactionError::Abort(err) => err,
            SledTransactionError::Storage(error) => Self::Sled(error),
        }
    }
}

/// Converts sled's non-abortable transaction error into a [`TransactionError`].
impl From<UnabortableTransactionError> for TransactionError {
    fn from(value: UnabortableTransactionError) -> Self {
        match value {
            UnabortableTransactionError::Storage(sled_err) => Self::Sled(sled_err),
            // Reported as the same variant that `Db::transaction` retries on.
            UnabortableTransactionError::Conflict => Self::CompareAndSwapError,
        }
    }
}

impl<E> From<TransactionError> for CustomTransactionError<E> {
    fn from(value: TransactionError) -> Self {
        CustomTransactionError::Transaction(value)
    }
}

/// Unwraps a [`CustomTransactionError`] whose closure error type is [`Infallible`].
///
/// `Infallible` has no values, so the `Closure` variant can never be constructed and this
/// conversion always yields the inner [`TransactionError`].
impl From<CustomTransactionError<Infallible>> for TransactionError {
    fn from(value: CustomTransactionError<Infallible>) -> Self {
        match value {
            CustomTransactionError::Transaction(transaction_error) => transaction_error,
            // An empty match on the uninhabited value lets the compiler prove this arm is
            // dead, instead of relying on a runtime `unreachable!()` panic path.
            CustomTransactionError::Closure(never) => match never {},
        }
    }
}

/// Handle to a sled database whose location is described by the marker type `Inner`.
///
/// Clones share one tree cache through an [`Arc`], so a tree opened via one handle is
/// reused by all of them.
pub struct Db<Inner: Database> {
    /// The underlying sled database.
    inner: sled::Db,
    /// Trees opened through `Db::tree`, keyed by tree name.
    trees: Arc<RwLock<HashMap<String, Tree>>>,
    /// Binds the handle to its `Database` marker type without storing a value of it.
    _marker: core::marker::PhantomData<Inner>,
}

impl<I: Database> Clone for Db<I> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            trees: self.trees.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

/// Exposes the wrapped `sled::Db` API (such as `open_tree`) directly on a [`Db`].
impl<I: Database> std::ops::Deref for Db<I> {
    type Target = sled::Db;

    fn deref(&self) -> &sled::Db {
        let Self { inner, .. } = self;
        inner
    }
}

impl<I: Database> Db<I> {
    /// Opens the database described by `I`.
    ///
    /// When [`Database::path`] returns a path, the database is stored there. When it returns
    /// `None`, a temporary database is opened instead, configured exactly like
    /// [`Db::open_temp`]. Both cases use [`sled::Mode::HighThroughput`].
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::Sled`] if sled fails to open the database.
    pub fn open() -> Result<Self, DatabaseError> {
        let config = sled::Config::new().mode(sled::Mode::HighThroughput);
        let config = match I::path() {
            Some(path) => config.path(path),
            None => config.temporary(true),
        };

        Self::open_with_config(config)
    }

    /// Opens a fresh temporary database, ignoring [`Database::path`].
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::Sled`] if sled fails to open the database.
    pub fn open_temp() -> Result<Self, DatabaseError> {
        Self::open_with_config(
            sled::Config::new()
                .temporary(true)
                .mode(sled::Mode::HighThroughput),
        )
    }

    /// Shared tail of [`Db::open`] and [`Db::open_temp`]: opens `config` and pairs it with an
    /// empty tree cache.
    fn open_with_config(config: sled::Config) -> Result<Self, DatabaseError> {
        // Rejects, at compile time, any `Database` implementor that is not zero-sized
        // (`define_db!` generates empty enums, which are).
        const { assert!(std::mem::size_of::<I>() == 0) }

        Ok(Db {
            inner: config.open()?,
            trees: Arc::new(RwLock::new(HashMap::new())),
            _marker: std::marker::PhantomData,
        })
    }

    /// Returns the tree called `name`, opening it on first use and caching the handle.
    ///
    /// The cache is shared by all clones of this [`Db`]. A hit takes only the read lock. On a
    /// miss the write lock is taken and the entry is checked again, so callers that race on
    /// the same name end up sharing one cached handle.
    ///
    /// # Panics
    ///
    /// Panics if sled fails to open the tree.
    pub fn tree(&self, name: String) -> Tree {
        // The cache only ever receives fully built entries, so a lock poisoned by an earlier
        // panic (e.g. a failed `open_tree` below) still guards a consistent map. Recover it
        // instead of making every later call on every clone panic as well.
        {
            let cache = self
                .trees
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some(tree) = cache.get(&name) {
                return tree.clone();
            }
        }

        let mut cache = self
            .trees
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        match cache.entry(name) {
            hash_map::Entry::Occupied(entry) => entry.get().clone(),
            hash_map::Entry::Vacant(entry) => {
                // Use the entry's own key, so the name does not have to be cloned.
                let tree = self.open_tree(entry.key()).unwrap_or_else(|err| {
                    panic!("failed to open sled tree {:?}: {err}", entry.key())
                });
                entry.insert(tree).clone()
            }
        }
    }

    /// Returns an [`Index`] named `name` over entries of type `T`, backed by the tree
    /// `index_{name}`.
    pub fn index<'a, T: DatabaseEntry<DbInner = I>>(&'a self, name: &'static str) -> Index<'a, T> {
        Index {
            db: self,
            tree: self.tree(format!("index_{name}")),
            name,
        }
    }

    /// Deletes the database's on-disk directory, consuming this handle.
    ///
    /// This handle (its sled database and cached trees) is dropped *before* the directory is
    /// removed. When it is the last handle, sled therefore shuts down before its files are
    /// deleted rather than afterwards. Other live clones of this [`Db`] still keep the
    /// database open.
    ///
    /// If [`Database::path`] returns `None` (a temporary database), nothing is removed and
    /// `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns any error reported by [`fs::remove_dir_all`].
    pub fn delete(self) -> io::Result<()> {
        let path = I::path();
        drop(self);

        match path {
            Some(path) => fs::remove_dir_all(path),
            None => Ok(()),
        }
    }

    /// Runs `f` against a fresh [`CompareAndSwapTransaction`] and applies it atomically.
    ///
    /// If [`CompareAndSwapTransaction::apply`] fails with
    /// [`TransactionError::CompareAndSwapError`], the attempt is discarded and `f` is run
    /// again on a new transaction. `f` must therefore tolerate being called more than once.
    /// It may be `FnMut`, so it can, for example, count its attempts.
    ///
    /// `max_retries` bounds the number of *retries* after the first attempt: `Some(0)` makes
    /// exactly one attempt and `Some(n)` makes at most `n + 1`. `None` keeps retrying with no
    /// practical limit. `flush` is forwarded to [`CompareAndSwapTransaction::apply`].
    ///
    /// # Errors
    ///
    /// - An error returned by `f` is propagated immediately, without retrying.
    /// - Any apply error other than a compare-and-swap failure is returned as
    ///   [`CustomTransactionError::Transaction`].
    /// - [`TransactionError::TooManyRetries`] (wrapped in
    ///   [`CustomTransactionError::Transaction`]) once the retry budget is exhausted.
    pub fn transaction<T, E, F>(
        &self,
        flush: bool,
        max_retries: Option<usize>,
        mut f: F,
    ) -> Result<T, CustomTransactionError<E>>
    where
        F: FnMut(&mut CompareAndSwapTransaction<I>) -> Result<T, CustomTransactionError<E>>,
    {
        // `0..=retries` yields `retries + 1` attempts (the first try plus up to `retries`
        // retries), so even `Some(0)` runs `f` once.
        let retries = max_retries.unwrap_or(usize::MAX);
        for _ in 0..=retries {
            let mut cas_tx = CompareAndSwapTransaction::with_db(self.clone());

            let value = f(&mut cas_tx)?;

            match cas_tx.apply(flush) {
                Ok(()) => return Ok(value),
                // Compare-and-swap failure: discard this attempt and try again.
                Err(TransactionError::CompareAndSwapError) => continue,
                Err(err) => return Err(CustomTransactionError::Transaction(err)),
            }
        }

        Err(CustomTransactionError::Transaction(
            TransactionError::TooManyRetries,
        ))
    }
}

impl<I: Database> Db<I> {
    /// Returns the tree holding every entry of type `T`, named `entry_{T::TREE_NAME}`.
    pub fn entry_tree<T: DatabaseEntry<DbInner = I>>(&self) -> Tree {
        let tree_name = format!("entry_{}", T::TREE_NAME);
        self.tree(tree_name)
    }

    /// Iterates over every stored entry of type `T`.
    ///
    /// Items are converted with `Entry::from_sled_batch` lazily, as the iterator advances.
    ///
    /// # Errors
    ///
    /// An item is `Err` when `Entry::from_sled_batch` fails for it, which presumably includes
    /// sled failing to read the underlying key/value pair.
    pub fn iter_entries<T: DatabaseEntry<DbInner = I>>(
        &self,
    ) -> impl Iterator<Item = Result<Entry<T>, TransactionError>> {
        let entries = self.entry_tree::<T>();
        entries.iter().map(Entry::from_sled_batch)
    }

    /// Counts the stored entries of type `T`.
    ///
    /// sled implements `Tree::len` as a full scan, so this is linear in the number of entries.
    pub fn entry_count<T: DatabaseEntry<DbInner = I>>(&self) -> usize {
        let entries = self.entry_tree::<T>();
        entries.len()
    }
}

/// Marker trait describing where a `Db` is stored.
///
/// Opening a `Db<I>` asserts at compile time that `I` is zero-sized; the
/// [`define_db!`](crate::define_db) macro generates a suitable empty enum.
pub trait Database: 'static + Sized + Send + Sync {
    /// Directory the database is stored in, or `None` for a temporary database.
    fn path() -> Option<PathBuf>;
}

/// Declares a zero-sized marker type and implements
/// [`Database`](crate::database::Database) for it.
///
/// `define_db!(Name { ... })` expands to `pub enum Name {}` plus an
/// `impl $crate::database::Database for Name` whose body is the brace contents verbatim.
#[macro_export]
macro_rules! define_db {
    ($db:ident { $($items:tt)* }) => {
        // A variant-less enum has no values and is zero-sized, which satisfies the
        // compile-time size assertion made when the database is opened.
        pub enum $db {}

        impl $crate::database::Database for $db {
            $($items)*
        }
    };
}

/// Decodes a stored value of type `T` from its raw sled bytes.
///
/// The stored layout is a 4-byte big-endian version number followed by the CBOR-encoded
/// value.
///
/// # Errors
///
/// Returns [`TransactionError::OutdatedVesion`] with `(expected, found)` versions when the
/// stored version differs from `T::VERSION_NUMBER`.
///
/// # Panics
///
/// Panics if `raw` is shorter than the 4-byte version header, or if the payload fails to
/// decode even though its version matches.
pub(crate) fn deserialize_from_ivec<T: DatabaseEntry>(raw: IVec) -> Result<T, TransactionError> {
    // `split_first_chunk` checks the length itself, so a truncated record reaches this
    // message. Slicing `raw[0..4]` would panic with a bare index error before any `expect`.
    let (version_bytes, payload) = raw
        .split_first_chunk::<4>()
        .expect("Corrupted database entry");
    let version = u32::from_be_bytes(*version_bytes);

    if version != T::VERSION_NUMBER {
        return Err(TransactionError::OutdatedVesion(T::VERSION_NUMBER, version));
    }

    match cbor4ii::serde::from_slice(payload) {
        Ok(t) => Ok(t),
        Err(err) => {
            panic!(
                "An item in the database failed to deserialize, but the version number has not changed.\nError:\n{err}\nThis could be due to a deserialization bug, database corruption, or from an internal data change without updating the version"
            )
        }
    }
}

#[cfg(test)]
mod tests;