//! # lunar-lib 0.8.0
//!
//! Common utilities for lunar applications.
use std::{
    any::{Any, TypeId},
    collections::HashMap,
    convert::Infallible,
    fs,
    ops::Deref,
    path::{Path, PathBuf},
    sync::{Arc, LazyLock, OnceLock, RwLock, Weak},
    thread,
    time::Duration,
};

pub use sled::{Db as SledDb, Transactional, Tree, transaction::ConflictableTransactionResult};

use sled::{
    IVec,
    transaction::{TransactionError as SledTransactionError, UnabortableTransactionError},
};

use thiserror::Error;

/// Module related to the dedicated writer thread
pub mod writer;

mod sled_ops;
pub(crate) use sled_ops::*;

mod database_entry;
pub use database_entry::*;

mod transaction_args;
pub use transaction_args::*;

pub mod caching;

/// Errors produced by database-level operations such as [`DbHandle::open`],
/// [`DbHandle::flush`] and [`DbHandle::delete`].
#[derive(Debug, Error)]
pub enum DatabaseError {
    /// A filesystem error, e.g. from [`DbHandle::delete`] removing the database directory.
    #[error("IoError: {0}")]
    Io(#[from] std::io::Error),

    /// The [`Database::pre_open`] hook returned an error while opening the database.
    #[error("pre_open() function failed with error: {0}")]
    PreOpenError(Box<dyn std::error::Error + Send + Sync>),

    /// An error reported by sled (for example while opening or flushing the database).
    #[error("Sled error: {0}")]
    Sled(#[from] sled::Error),

    /// A [`TransactionError`] converted into a database-level error.
    #[error("{0}")]
    Transaction(#[from] TransactionError),

    /// [`DbHandle::open`] kept finding the database directory locked.
    ///
    /// The value is [`Database::RETRY_MAX_ATTEMPTS`] (`usize::MAX` when that is `None`), i.e.
    /// the total number of attempts, despite the word "retries" in the message.
    #[error("The database failed to be obtained after {0} retries")]
    TooManyOpenRetries(usize),
}

/// Errors produced by transactional operations on database entries.
#[derive(Debug, Error)]
pub enum TransactionError {
    /// A storage-level error reported by sled.
    #[error("Sled error: {0}")]
    Sled(#[from] sled::Error),

    /// Produced from a sled transaction conflict (see the `From<UnabortableTransactionError>`
    /// impl). Per its message it is handled internally and never returned to callers.
    /// NOTE(review): the retry logic that consumes it is not in this file section; confirm there.
    #[error(
        "An internal compare and swap failed during a transaction. This variant will never be returned from a function"
    )]
    CompareAndSwapError,

    /// An internal compare-and-swap was retried more times than allowed.
    #[error("An internal compare and swap was retried too many times")]
    TooManyRetries,

    /// A key that was expected to exist was not found.
    #[error("A key was not found in the database when it was expected")]
    MissingEntry,

    /// A key that was expected to be absent already exists.
    #[error("A key was found in the database when it was not expected")]
    AlreadyInDatabase,

    /// A stored record could not be decoded and carries a different version:
    /// `(expected, found)`, where `expected` is the entry type's current `VERSION_NUMBER` and
    /// `found` is the version stored with the record.
    ///
    /// The variant name is misspelled ("Vesion"); it is kept because renaming a public variant
    /// would break callers.
    #[error("A record was out of date. v{0} was expected but v{1} was found")]
    OutdatedVesion(u32, u32),

    /// An attempt was made to write or overwrite a value marked read-only.
    #[error("Tried to write or overwrite a read-only value")]
    ReadOnly,
}

/// Error for a transaction that runs a fallible closure with its own error type `E`.
///
/// A [`TransactionError`] converts into it via `From`. With `E = Infallible` it converts back
/// into a plain [`TransactionError`].
#[derive(Debug, Error)]
pub enum CustomTransactionError<E> {
    /// An error from the transaction machinery itself.
    #[error("{0}")]
    Transaction(TransactionError),

    /// An error returned by the caller-supplied closure.
    #[error("Transaction closure failed: {0}")]
    Closure(E),
}

impl From<SledTransactionError<TransactionError>> for TransactionError {
    fn from(value: SledTransactionError<TransactionError>) -> Self {
        match value {
            SledTransactionError::Abort(err) => err,
            SledTransactionError::Storage(error) => Self::Sled(error),
        }
    }
}

impl From<UnabortableTransactionError> for TransactionError {
    /// Maps an unabortable sled transaction error onto [`TransactionError`].
    ///
    /// A storage failure becomes [`TransactionError::Sled`]. A conflict becomes
    /// [`TransactionError::CompareAndSwapError`].
    fn from(err: UnabortableTransactionError) -> Self {
        match err {
            UnabortableTransactionError::Storage(storage) => Self::Sled(storage),
            UnabortableTransactionError::Conflict => Self::CompareAndSwapError,
        }
    }
}

impl<E> From<TransactionError> for CustomTransactionError<E> {
    fn from(value: TransactionError) -> Self {
        CustomTransactionError::Transaction(value)
    }
}

impl From<CustomTransactionError<Infallible>> for TransactionError {
    /// Unwraps a [`CustomTransactionError`] whose closure can never fail.
    ///
    /// The `Closure` variant holds an [`Infallible`], so it is uninhabited. An empty `match`
    /// lets the compiler prove that arm impossible instead of relying on a runtime
    /// `unreachable!()` panic.
    fn from(value: CustomTransactionError<Infallible>) -> Self {
        match value {
            CustomTransactionError::Transaction(transaction_error) => transaction_error,
            CustomTransactionError::Closure(never) => match never {},
        }
    }
}

/// A shared handle to the process-wide instance of a [`Database`].
///
/// Obtain one with [`DbHandle::open`]; it dereferences to `Db`. Cloning a handle clones the
/// inner `Arc`, so all clones refer to the same `Db` instance.
#[derive(Clone)]
pub struct DbHandle<Db: Database> {
    // Strong reference; the global handle cache only keeps a `Weak` to this allocation.
    inner: Arc<Db>,
}

impl<Db: Database> Deref for DbHandle<Db> {
    type Target = Db;

    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}

impl<Db: Database> DbHandle<Db> {
    pub fn open() -> Result<Self, DatabaseError> {
        let type_id = TypeId::of::<Self>();

        {
            let cache = DB_CACHE.read().unwrap();
            if let Some(entry) = cache.get(&type_id)
                && let Some(upgrade) = entry.upgrade()
            {
                return Ok(DbHandle {
                    inner: upgrade.downcast::<Db>().unwrap(),
                });
            }
        }

        let mut cache = DB_CACHE.write().unwrap();
        if let Some(entry) = cache.get(&type_id)
            && let Some(upgrade) = entry.upgrade()
        {
            return Ok(DbHandle {
                inner: upgrade.downcast::<Db>().unwrap(),
            });
        }

        let path = Db::__path();

        for _ in 0..Db::RETRY_MAX_ATTEMPTS.unwrap_or(usize::MAX) {
            match sled::open(path) {
                Ok(db) => {
                    let db = Db::new(db);
                    Db::pre_open(&db).map_err(DatabaseError::PreOpenError)?;

                    let inner = Arc::new(db);
                    cache.insert(
                        type_id,
                        Arc::downgrade(&(inner.clone() as Arc<dyn Any + Send + Sync>)),
                    );

                    return Ok(DbHandle { inner });
                }
                Err(sled::Error::Io(err))
                    if matches!(err.kind(), std::io::ErrorKind::Other)
                        && err.to_string().contains("could not acquire lock on ") =>
                {
                    thread::sleep(Db::RETRY_DURATION);
                    continue;
                }
                Err(err) => return Err(err.into()),
            }
        }
        Err(DatabaseError::TooManyOpenRetries(
            Db::RETRY_MAX_ATTEMPTS.unwrap_or(usize::MAX),
        ))
    }

    /// Attempts to delete the entire contents of the database by deleting the path its stored at
    ///
    /// # Warning
    ///
    /// This will attempt to aquire the database using [`Self::Open`]
    pub fn delete(&self) -> Result<(), DatabaseError>
    where
        Self: Sized,
    {
        if let Err(err) = fs::remove_dir_all(Db::__path())
            && !matches!(err.kind(), std::io::ErrorKind::NotFound)
        {
            return Err(err.into());
        };

        Ok(())
    }

    /// Attempts to flush the database
    ///
    /// # Warning
    ///
    /// This will attempt to aquire the database using [`Self::Open`]
    pub fn flush(&self) -> Result<(), DatabaseError>
    where
        Self: Sized,
    {
        self.inner.db().flush()?;
        Ok(())
    }
}

pub trait Database: Clone + Sized + 'static + Send + Sync {
    /// The maximum number of times [`Self::open()`] is allowed to retry before returning an error. Must be atleast `1`
    ///
    /// If [`None`], this function will retry essentially forever
    const RETRY_MAX_ATTEMPTS: Option<usize>;

    /// The amount of time between retries
    const RETRY_DURATION: Duration;

    /// Creates a new instance of self with the input db
    fn new(db: SledDb) -> Self;

    fn db(&self) -> &SledDb;

    /// Path this database is stored at
    fn path() -> PathBuf;

    fn __path() -> &'static Path {
        static STABLE_PATH: OnceLock<PathBuf> = OnceLock::new();
        STABLE_PATH.get_or_init(Self::path)
    }

    /// A function to be applied before the database is returned with [`Self::open()`]
    ///
    /// Also see [`Self::open()`]
    fn pre_open(_db: &Self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }
}

type DatabaseCacheRegistry = HashMap<TypeId, Weak<dyn Any + Send + Sync>>;

/// Process-wide registry of open databases, keyed by the [`TypeId`] of `DbHandle<Db>`.
///
/// Values are [`Weak`] references, so the registry never keeps a database alive by itself.
/// An entry whose handles have all been dropped simply fails to upgrade.
static DB_CACHE: LazyLock<RwLock<DatabaseCacheRegistry>> =
    LazyLock::new(|| RwLock::new(DatabaseCacheRegistry::new()));

/// Decodes an entry written by `serialize_to_ivec`.
///
/// The stored layout is a 5-byte header followed by the postcard-encoded entry. The header holds
/// the entry's `VERSION_NUMBER` as a big-endian `u32`, then the read-only flag byte. The
/// read-only flag is not inspected here.
///
/// The version is only consulted when decoding fails. A payload that still decodes under a
/// different version is returned as-is.
///
/// # Errors
///
/// Returns [`TransactionError::OutdatedVesion`]`(expected, found)` when the payload fails to
/// decode and the stored version differs from `T::VERSION_NUMBER`.
///
/// # Panics
///
/// - If the value is shorter than the 5-byte header (a corrupted entry).
/// - If the payload fails to decode even though the stored version equals `T::VERSION_NUMBER`.
pub(crate) fn deserialize_from_ivec<T: DatabaseEntry>(raw: IVec) -> Result<T, TransactionError> {
    // Check the header length up front so a truncated entry reports corruption instead of an
    // opaque slice-index panic.
    let Some((&[v0, v1, v2, v3, _read_only], payload)) = raw.split_first_chunk::<5>() else {
        panic!(
            "Corrupted database entry: expected at least a 5 byte header but found {} bytes",
            raw.len()
        );
    };

    match postcard::from_bytes(payload) {
        Ok(t) => Ok(t),
        Err(err) => {
            let version = u32::from_be_bytes([v0, v1, v2, v3]);

            if version == T::VERSION_NUMBER {
                panic!(
                    "An item in the database failed to deserialize, but the version number has not changed.\nError:\n{err}\nThis could be due to a deserialization bug, database corruption, or from an internal data change without updating the version"
                )
            } else {
                Err(TransactionError::OutdatedVesion(T::VERSION_NUMBER, version))
            }
        }
    }
}

/// Encodes `item` as a 5-byte header followed by its postcard encoding.
///
/// The header holds `T::VERSION_NUMBER` as a big-endian `u32`, then `item.read_only()` as a
/// single byte.
///
/// # Panics
///
/// Panics if postcard fails to serialize `item`.
pub(crate) fn serialize_to_ivec<T: DatabaseEntry>(item: &T) -> IVec {
    let mut header = Vec::with_capacity(5);
    header.extend_from_slice(&T::VERSION_NUMBER.to_be_bytes());
    header.push(item.read_only() as u8);

    let encoded = postcard::to_extend(item, header)
        .expect("Postcard failed to serialize. This cannot happen unless a serializer failed");
    IVec::from(encoded)
}

/// A 32-byte identifier for entries of [`Self::Entry`], stored in [`Self::IdDb`].
pub trait EntryId: Deref<Target = [u8; 32]> + Copy + Eq {
    /// The database associated with this id; it must match the entry type's `Db`.
    type IdDb: Database;
    /// The entry type this id identifies.
    type Entry: DatabaseEntry<Db = Self::IdDb>;

    /// Borrows the raw 32 id bytes (the [`Deref`] target) without copying them.
    fn as_bytes(&self) -> &[u8; 32] {
        Deref::deref(self)
    }
}