lunar-lib 0.11.0

Common utilities for lunar applications
Documentation
use std::{
    any::{Any, TypeId},
    collections::HashMap,
};

use sled::{
    IVec, Transactional, Tree,
    transaction::{ConflictableTransactionError, TransactionalTree},
};

use crate::{
    database::{Database, DatabaseEntry, Db, TransactionError},
    id::Id,
};

/// One pending compare-and-swap operation on a single key.
///
/// `old` is the raw value the key is expected to hold right now (`None` meaning the key is
/// expected to be absent), and `new` is the entry that should replace it (`None` meaning the key
/// should be removed).
pub struct CompareAndSwapValue<Entry: DatabaseEntry> {
    /// Expected current raw value stored under the key.
    pub old: Option<IVec>,
    /// Entry to write in place of `old`, or `None` to delete the key.
    pub new: Option<Entry>,
}

impl<E: DatabaseEntry> CompareAndSwapValue<E> {
    /// Builds an operation that expects `old` and writes `new`.
    #[must_use]
    pub fn new(old: Option<IVec>, new: Option<E>) -> Self {
        CompareAndSwapValue { old, new }
    }
}

/// Every compare-and-swap operation queued against a single [`Tree`], keyed by entry id.
pub struct TreeCompareAndSwap<Entry: DatabaseEntry> {
    /// The tree all queued operations target.
    tree: Tree,
    /// Pending operations, at most one per id.
    swaps: HashMap<Id<Entry>, CompareAndSwapValue<Entry>>,
}

impl<E: DatabaseEntry> TreeCompareAndSwap<E> {
    /// Creates an empty operation set for the tree returned by `db.entry_tree::<E>()`.
    #[must_use]
    fn new(db: &Db<E::DbInner>) -> Self {
        let tree = db.entry_tree::<E>();
        Self {
            tree,
            swaps: HashMap::new(),
        }
    }

    /// Returns the tree these operations target.
    #[must_use]
    pub fn tree(&self) -> &Tree {
        &self.tree
    }

    /// Looks up the operation queued for `id`, if any.
    #[must_use]
    pub fn get(&self, id: &Id<E>) -> Option<&CompareAndSwapValue<E>> {
        self.swaps.get(id)
    }

    /// Mutable counterpart of [`Self::get`].
    #[must_use]
    pub fn get_mut(&mut self, id: &Id<E>) -> Option<&mut CompareAndSwapValue<E>> {
        self.swaps.get_mut(id)
    }

    /// Removes and returns the operation queued for `id`, if any.
    #[must_use]
    pub fn take_out(&mut self, id: &Id<E>) -> Option<CompareAndSwapValue<E>> {
        self.swaps.remove(id)
    }

    /// Queues an operation for `id` that expects `old` and writes `new`.
    ///
    /// Any operation previously queued for the same id is replaced.
    pub fn insert(&mut self, id: Id<E>, old: Option<IVec>, new: Option<E>) {
        self.insert_raw(id, CompareAndSwapValue { old, new });
    }

    /// Queues an already-built operation for `id`, replacing any previously queued one.
    pub fn insert_raw(&mut self, id: Id<E>, cas_value: CompareAndSwapValue<E>) {
        self.swaps.insert(id, cas_value);
    }
}

/// Generic wrapper over [`TreeCompareAndSwap<T>`]
trait GenericCompareAndSwap: Any {
    fn tree(&self) -> Tree;
    fn as_any(&self) -> &dyn Any;
    fn as_any_mut(&mut self) -> &mut dyn Any;
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError>;
}

impl<Entry: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<Entry> {
    fn tree(&self) -> Tree {
        self.tree.clone()
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Applies all changes [`Self`] holds to the `tx_tree`
    fn apply(&self, tx_tree: &TransactionalTree) -> Result<(), TransactionError> {
        for (k, v) in &self.swaps {
            let db_old = tx_tree.get(k)?;

            if db_old == v.old {
                if let Some(new) = &v.new {
                    let buffer = Vec::from(Entry::VERSION_NUMBER.to_be_bytes());
                    let buffer: IVec = cbor4ii::serde::to_vec(buffer, new)
                        .expect("Cbor4ii failed to serialize. This cannot happen unless a serializer failed")
                        .into();

                    tx_tree.insert(&**k, buffer)?;
                } else {
                    tx_tree.remove(&**k)?;
                }
            } else {
                return Err(TransactionError::CompareAndSwapError);
            }
        }
        Ok(())
    }
}

/// Defines an entire compare-and-swap transaction
pub struct CompareAndSwapTransaction<CasDb: Database> {
    swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
    database: Db<CasDb>,
}

impl<CasDb: Database> CompareAndSwapTransaction<CasDb> {
    #[must_use]
    pub(super) fn with_db(database: Db<CasDb>) -> Self {
        Self {
            swaps: HashMap::new(),
            database,
        }
    }

    pub(super) fn db(&self) -> &Db<CasDb> {
        &self.database
    }

    /// Returns an immutable reference to an entries compare and swap tree
    pub fn get_request<Entry>(&self) -> Option<&TreeCompareAndSwap<Entry>>
    where
        Entry: DatabaseEntry<DbInner = CasDb>,
    {
        self.swaps.get(&TypeId::of::<Entry>()).map(|boxed| {
            boxed
                .as_any()
                .downcast_ref::<TreeCompareAndSwap<Entry>>()
                .unwrap()
        })
    }

    /// Returns a mutable reference to an entries compare and swap tree
    pub fn get_request_mut<Entry>(&mut self) -> Option<&mut TreeCompareAndSwap<Entry>>
    where
        Entry: DatabaseEntry<DbInner = CasDb>,
    {
        self.swaps.get_mut(&TypeId::of::<Entry>()).map(|boxed| {
            boxed
                .as_any_mut()
                .downcast_mut::<TreeCompareAndSwap<Entry>>()
                .unwrap()
        })
    }

    /// Returns a mutable reference to an entries compare and swap tree, opening a new one if one has not already been opened
    pub fn get_or_new_request<Entry>(&mut self) -> &mut TreeCompareAndSwap<Entry>
    where
        Entry: DatabaseEntry<DbInner = CasDb>,
    {
        self.swaps
            .entry(TypeId::of::<Entry>())
            .or_insert_with(|| Box::new(TreeCompareAndSwap::<Entry>::new(&self.database)))
            .as_any_mut()
            .downcast_mut::<TreeCompareAndSwap<Entry>>()
            .unwrap()
    }

    /// Applies a [`CompareAndSwapTransaction`] atomically to the database
    ///
    /// # Errors
    ///
    /// This function will error if [`sled`] fails get, insert, or remove a key OR abort with a [`CompareAndSwapError`] if the current value does not match the expected value
    pub(super) fn apply(self, flush: bool) -> Result<(), TransactionError> {
        if self.swaps.is_empty() {
            return Ok(());
        }

        let (trees, swaps): (Vec<Tree>, Vec<Box<dyn GenericCompareAndSwap>>) = self
            .swaps
            .into_values()
            .map(|cas| (cas.tree(), cas))
            .unzip();

        trees.transaction(|tx_trees| {
            for (tree, cas) in tx_trees.iter().zip(swaps.iter()) {
                cas.apply(tree)
                    .map_err(ConflictableTransactionError::Abort)?;
                if flush {
                    tree.flush();
                }
            }
            Ok(())
        })?;
        Ok(())
    }
}