lunar-lib 0.6.1

Common utilities for lunar applications
Documentation
use std::{
    any::{Any, TypeId},
    borrow::Borrow,
    collections::{HashMap, hash_map},
    rc::Rc,
};

use serde::Deserialize;
use sled::{
    IVec, Transactional, Tree,
    transaction::{ConflictableTransactionError, TransactionalTree},
};

use crate::{
    database::{
        CustomTransactionError, Database, DatabaseEntry, DatabaseError, DbKey, EntryId,
        TransactionError, deserialize_from_ivec, serialize_to_ivec, sled_get_raw,
    },
    trace,
};

/// Holds a single compare and swap value for one key.
///
/// `old` is the raw value the key is expected to hold in the database when the swap is applied,
/// and `new` is the data we want to overwrite it with.
#[derive(Debug, Clone)]
pub struct CompareAndSwapValue<Entry: DatabaseEntry> {
    /// Raw bytes expected to be stored under the key; `None` means the key is expected to be absent
    pub old: Option<IVec>,
    /// Entry to store under the key; `None` means the key is removed
    pub new: Option<Entry>,
}

impl<Entry: DatabaseEntry> CompareAndSwapValue<Entry> {
    #[must_use]
    pub fn new(old: Option<IVec>, new: Option<Entry>) -> Self {
        Self { old, new }
    }
}

/// Holds a group of compare and swap values for the given tree
#[derive(Debug)]
pub struct TreeCompareAndSwap<Entry: DatabaseEntry> {
    // Tree the swaps are applied to
    tree: Tree,
    // Pending swaps, keyed by the raw database key they target
    swaps: HashMap<DbKey, CompareAndSwapValue<Entry>>,
}

impl<T> TreeCompareAndSwap<T>
where
    T: DatabaseEntry,
{
    /// Creates an empty swap group bound to the tree that `T::tree` returns for `db`
    fn new(db: &T::EntryDb) -> Self {
        let tree = T::tree(db);
        Self {
            tree,
            swaps: HashMap::new(),
        }
    }

    /// Returns the tree this swap group targets
    #[must_use]
    pub fn tree(&self) -> &Tree {
        &self.tree
    }
}

/// Generic wrapper over [`TreeCompareAndSwap<T>`]
///
/// Erases the entry type so that swap groups for different entry types can be stored together
/// as trait objects.
trait GenericCompareAndSwap: Any + std::fmt::Debug {
    /// Returns the tree the held swaps target
    fn tree(&self) -> &Tree;
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type
    fn as_any(&self) -> &dyn Any;
    /// Mutable counterpart of [`GenericCompareAndSwap::as_any`]
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Merges the swaps held by `other` into `self`
    ///
    /// # Safety
    ///
    /// `other` must have the same concrete type as `self`
    ///
    /// # Errors
    ///
    /// Returns an error if the two groups of swaps cannot be merged
    unsafe fn merge_from(
        &mut self,
        other: Box<dyn GenericCompareAndSwap>,
    ) -> Result<(), TransactionError>;
    /// Applies every held swap to `tx_tree`
    ///
    /// # Errors
    ///
    /// Returns an error if a swap cannot be applied
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError>;
}

impl<Entry: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<Entry> {
    /// Returns the tree the held swaps target
    fn tree(&self) -> &Tree {
        &self.tree
    }

    /// Returns `self` as [`Any`] so callers can downcast to `TreeCompareAndSwap<Entry>`
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Mutable counterpart of `as_any`
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Merges two [`GenericCompareAndSwap`] structs of the same type
    ///
    /// Keys that are only staged in `other` are moved into `self`. For keys staged on both sides
    /// the swap already held by `self` is kept, provided both sides expect the same `old` value.
    ///
    /// # Safety
    ///
    /// The `other` input of this function must be the same type as `Self`. This implementation
    /// verifies that with a checked downcast, so a mismatch panics instead of being undefined
    /// behavior.
    ///
    /// # Errors
    ///
    /// Returns [`TransactionError::CompareAndSwapError`] if both sides stage a swap for the same
    /// key but expect different `old` values. Keys visited before the conflicting one have
    /// already been merged into `self` at that point.
    ///
    /// # Panics
    ///
    /// Panics if `other` is not a `TreeCompareAndSwap<Entry>`
    unsafe fn merge_from(
        &mut self,
        mut other: Box<dyn GenericCompareAndSwap>,
    ) -> Result<(), TransactionError> {
        // A checked downcast replaces the raw-pointer reinterpretation: a caller that breaks the
        // contract gets a deterministic panic rather than undefined behavior.
        let incoming = other
            .as_any_mut()
            .downcast_mut::<Self>()
            .expect("merge_from requires `other` to hold the same entry type as `self`");

        // Take the map out of the box so its entries can be moved without cloning.
        for (other_key, other_value) in std::mem::take(&mut incoming.swaps) {
            match self.swaps.entry(other_key) {
                hash_map::Entry::Occupied(entry) => {
                    // Both sides touched this key; they must agree on what the database holds.
                    if entry.get().old != other_value.old {
                        return Err(TransactionError::CompareAndSwapError);
                    }
                }
                hash_map::Entry::Vacant(entry) => {
                    entry.insert(other_value);
                }
            }
        }

        Ok(())
    }

    /// Applies all changes [`Self`] holds to the `tx_tree`
    ///
    /// For every staged key the current value is read from `tx_tree`. If that value carries a
    /// `read_only` field set to `true` the whole application is rejected. Otherwise the current
    /// value must equal the expected `old` value, after which `new` is inserted, or the key is
    /// removed when `new` is `None`.
    ///
    /// # Errors
    ///
    /// - [`DatabaseError::ReadOnly`] (converted into [`TransactionError`]) if the stored value is
    ///   marked read-only
    /// - [`TransactionError::CompareAndSwapError`] if the stored value differs from `old`
    /// - Any error produced by reading from or writing to `tx_tree`
    ///
    /// # Panics
    ///
    /// Panics if an existing stored value cannot be decoded far enough to read its optional
    /// `read_only` field
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError> {
        // Minimal view of a stored value: only the optional `read_only` flag is inspected.
        #[derive(Deserialize)]
        struct ReadOnly {
            read_only: Option<bool>,
        }

        for (k, v) in &self.swaps {
            let db_old = tx_tree.get(k)?;

            // The read-only guard runs before the CAS comparison, so a protected value is
            // reported as read-only even when the expected value is stale.
            let is_read_only = db_old.as_deref().is_some_and(|bytes| {
                let flag: ReadOnly = ciborium::from_reader(bytes)
                    .expect("stored value could not be decoded to check its `read_only` flag");
                flag.read_only.unwrap_or(false)
            });
            if is_read_only {
                return Err(DatabaseError::ReadOnly.into());
            }

            if db_old != v.old {
                trace!(
                    "CAS mismatch on key {:?}: Expected {:?}. Found {:?}",
                    k, v.old, db_old
                );
                return Err(TransactionError::CompareAndSwapError);
            }

            if let Some(new) = &v.new {
                tx_tree.insert(k, serialize_to_ivec(&new))?;
            } else {
                tx_tree.remove(k)?;
            }
        }
        Ok(())
    }
}

/// Defines an entire compare-and-swap transaction
#[derive(Debug)]
pub struct CompareAndSwapTransaction<CasDb: Database> {
    // One swap group per entry type, keyed by the `TypeId` of that entry type
    swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
    // Database handle used for reads and for resolving each entry type's tree
    database: Rc<CasDb>,
}

impl<CasDb: Database> CompareAndSwapTransaction<CasDb> {
    #[must_use]
    /// Opens a compare and swap transaction with [`Db::open()`]
    ///
    /// # Errors
    ///
    /// Errors if the database cannot be opened
    pub(crate) fn new() -> Result<Self, DatabaseError> {
        Ok(Self {
            swaps: HashMap::new(),
            database: Rc::new(CasDb::open()?),
        })
    }

    #[must_use]
    pub(crate) fn with_db(database: Rc<CasDb>) -> Self {
        Self {
            swaps: HashMap::new(),
            database,
        }
    }

    pub fn merge(&mut self, other: Self) -> Result<(), TransactionError> {
        for (type_id, other_tree) in other.swaps {
            match self.swaps.entry(type_id) {
                hash_map::Entry::Occupied(mut entry) => {
                    // SAFETY: The TreeCompareAndSwap types of other and self must be the same type, as we get the value of self using the type ID of other
                    unsafe {
                        entry.get_mut().merge_from(other_tree)?;
                    }
                }
                hash_map::Entry::Vacant(entry) => {
                    entry.insert(other_tree);
                }
            }
        }
        Ok(())
    }

    /// Gets the latest version of the item in the transaction context, looking it up in the database if its unmodified
    ///
    /// # Errors
    ///
    /// Errors if [`DatabaseEntry::db_get()`] errors
    pub fn tx_get<Id>(&self, id: Id) -> Result<Option<Id::Entry>, TransactionError>
    where
        Id: EntryId<IdDb = CasDb>,
    {
        if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
            let cas_tree = boxed
                .as_any()
                .downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
                .unwrap();
            if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
                return Ok(get.new.clone());
            }
        }

        let tree = Id::Entry::tree(&*self.database);
        let raw = sled_get_raw(&tree, id.as_bytes())?;
        raw.map(deserialize_from_ivec)
            .transpose()
            .map_err(TransactionError::Database)
    }

    /// Checks if the item exists in the latest version of the transaction context, looking it up in the database if its unmodified
    ///
    /// # Errors
    ///
    /// Errors if sled fails to check the tree
    pub fn tx_check<Id>(&self, id: Id) -> Result<bool, TransactionError>
    where
        Id: EntryId<IdDb = CasDb>,
    {
        if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
            let cas_tree = boxed
                .as_any()
                .downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
                .unwrap();
            if let Some(get) = cas_tree.swaps.get(id.as_bytes()) {
                return Ok(get.new.is_some());
            }
        }

        let tree = Id::Entry::tree(&*self.database);
        Ok(tree.contains_key(id.as_bytes())?)
    }

    pub fn tx_get_batch<Entry, I>(&self, items: I) -> Result<Vec<Entry>, TransactionError>
    where
        Entry: DatabaseEntry<EntryDb = CasDb>,
        I: IntoIterator<Item: Borrow<Entry::Id>>,
    {
        items
            .into_iter()
            .map(|id| {
                self.tx_get(*id.borrow())?
                    .ok_or(DatabaseError::MissingEntry.into())
            })
            .collect()
    }

    pub fn tx_remove<Id>(&mut self, key: Id) -> Result<(), TransactionError>
    where
        Id: EntryId<IdDb = CasDb>,
    {
        let db = self.database.clone();
        let request = self.get_or_new_request::<Id::Entry>();

        let key = *key.as_bytes();
        if let Some(get_mut) = request.swaps.get_mut(&key) {
            get_mut.new = None;
        } else {
            let old = sled_get_raw(&Id::Entry::tree(&*db), &key)?;
            request
                .swaps
                .insert(key, CompareAndSwapValue { old, new: None });
        }
        Ok(())
    }

    pub fn tx_upsert<Entry: DatabaseEntry<EntryDb = CasDb>>(
        &mut self,
        key: Entry::Id,
        mut new: Option<Entry>,
    ) -> Result<(), TransactionError> {
        let db = self.database.clone();
        if let Some(new) = &mut new {
            new.pre_upsert(self)?;
        }

        let request = self.get_or_new_request::<Entry>();

        let key = *key.as_bytes();

        if let Some(get_mut) = request.swaps.get_mut(&key) {
            get_mut.new = new;
        } else {
            let old = sled_get_raw(&Entry::tree(&*db), &key)?;
            request.swaps.insert(key, CompareAndSwapValue { old, new });
        }
        Ok(())
    }

    pub fn tx_insert<Entry: DatabaseEntry<EntryDb = CasDb>>(
        &mut self,
        item: Entry,
    ) -> Result<(), TransactionError> {
        if self.tx_get(item.id())?.is_some() {
            return Err(DatabaseError::AlreadyInDatabase.into());
        }

        self.tx_upsert(item.id(), Some(item))?;
        Ok(())
    }

    pub fn get_or_new_request<Entry>(&mut self) -> &mut TreeCompareAndSwap<Entry>
    where
        Entry: DatabaseEntry<EntryDb = CasDb>,
    {
        self.swaps
            .entry(TypeId::of::<Entry>())
            .or_insert_with(|| Box::new(TreeCompareAndSwap::<Entry>::new(&*self.database)))
            .as_any_mut()
            .downcast_mut::<TreeCompareAndSwap<Entry>>()
            .unwrap()
    }
}

/// Applies a [`CompareAndSwapTransaction`] atomically to the database
///
/// All swap groups of the transaction are applied inside a single sled transaction spanning
/// their trees. When `flush` is `true`, `flush()` is called on each transactional tree after its
/// swaps were applied. A transaction without any staged swaps is a no-op.
///
/// # Errors
///
/// This function will error if [`sled`] fails to get, insert, or remove a key OR abort with a
/// [`TransactionError::CompareAndSwapError`] if the current value does not match the expected value
pub fn apply_cas_tx<CasDb: Database>(
    tx: CompareAndSwapTransaction<CasDb>,
    flush: bool,
) -> Result<(), TransactionError> {
    // Nothing was staged (e.g. a read-only closure): skip opening a sled transaction over zero
    // trees altogether.
    if tx.swaps.is_empty() {
        return Ok(());
    }

    // Split into two parallel vectors so that `tx_trees[i]` is the transactional view of the
    // tree that `swaps[i]` targets.
    let (trees, swaps): (Vec<&Tree>, Vec<&Box<dyn GenericCompareAndSwap>>) =
        tx.swaps.values().map(|cas| (cas.tree(), cas)).unzip();

    trees.transaction(|tx_trees| {
        for (tree, cas) in tx_trees.iter().zip(swaps.iter()) {
            // Any failure aborts the whole sled transaction.
            cas.apply(tree)
                .map_err(ConflictableTransactionError::Abort)?;
            if flush {
                tree.flush();
            }
        }
        Ok(())
    })?;
    Ok(())
}

/// Calls a closure that stages changes on a [`CompareAndSwapTransaction`] and applies it atomically to the database
///
/// `f` receives a fresh, empty transaction on every attempt. When applying the transaction fails
/// with a [`TransactionError::CompareAndSwapError`], the staged changes are discarded and `f` is
/// called again; there is no upper bound on the number of attempts, so `f` must be safe to run
/// more than once.
///
/// If `db` is `None` the database is opened once and that handle is reused for every attempt.
/// `flush` is forwarded to [`apply_cas_tx()`].
///
/// # Errors
///
/// This function will error if the database cannot be opened, if `f()` returns an error, or if
/// [`apply_cas_tx()`] fails with an error other than [`TransactionError::CompareAndSwapError`]
pub fn db_transaction<CasDb: Database, F, T, E>(
    mut f: F,
    db: Option<CasDb>,
    flush: bool,
) -> Result<T, CustomTransactionError<E>>
where
    F: FnMut(&mut CompareAndSwapTransaction<CasDb>) -> Result<T, CustomTransactionError<E>>,
{
    // Resolve the database handle before the retry loop so a CAS conflict does not reopen it.
    let db = match db {
        Some(db) => Rc::new(db),
        None => {
            // Reuse the constructor's open logic and keep only the handle it produced.
            CompareAndSwapTransaction::<CasDb>::new()
                .map_err(TransactionError::Database)?
                .database
        }
    };

    loop {
        let mut cas_tx = CompareAndSwapTransaction::with_db(Rc::clone(&db));

        let t = f(&mut cas_tx)?;

        match apply_cas_tx(cas_tx, flush) {
            Ok(()) => return Ok(t),
            Err(TransactionError::CompareAndSwapError) => {
                trace!("Transaction (Not sync) ran into a CAS error and is retrying.");
            }
            Err(err) => return Err(err.into()),
        }
    }
}