use std::{
collections::{HashMap, hash_map},
convert::Infallible,
fs, io,
path::PathBuf,
sync::{Arc, RwLock},
};
use sled::{
IVec, Tree,
transaction::{TransactionError as SledTransactionError, UnabortableTransactionError},
};
use thiserror::Error;
mod entry;
pub use entry::*;
mod id;
pub use id::*;
mod extensions;
pub use extensions::*;
mod transactions;
pub use transactions::*;
mod index;
pub use index::*;
pub mod caching;
pub mod writer;
/// Top-level error type for database operations.
///
/// Each wrapped source error converts automatically through the `#[from]`
/// attribute, so `?` can be used on I/O, sled, and transaction results.
#[derive(Debug, Error)]
pub enum DatabaseError {
/// A filesystem or other I/O failure.
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
/// An error reported by the underlying sled storage engine.
#[error("Sled error: {0}")]
Sled(#[from] sled::Error),
/// A failure raised by the transaction layer; displayed using the inner
/// error's own message.
#[error("{0}")]
Transaction(#[from] TransactionError),
/// The database could not be obtained; the payload is the retry count
/// reported in the message.
#[error("The database failed to be obtained after {0} retries")]
TooManyOpenRetries(usize),
}
/// Errors produced while building or applying a transaction.
#[derive(Debug, Error)]
pub enum TransactionError {
/// An error reported by the underlying sled storage engine.
#[error("Sled error: {0}")]
Sled(#[from] sled::Error),
/// Internal retry signal: a compare-and-swap lost a race. It is produced
/// from sled's `UnabortableTransactionError::Conflict`, and
/// `Db::transaction` reacts to it by re-running the closure.
#[error(
"An internal compare and swap failed during a transaction. This variant will never be returned from a function"
)]
CompareAndSwapError,
/// The retry budget was exhausted without a successful apply.
#[error("An internal compare and swap was retried too many times")]
TooManyRetries,
/// A key that was expected to exist was absent.
#[error("A key was not found in the database when it was expected")]
MissingEntry,
/// A key that was expected to be absent already exists.
#[error("A key was found in the database when it was not expected")]
AlreadyInDatabase,
/// A stored record carries a different version than the type expects.
/// The fields are `(expected, found)`.
// NOTE(review): the variant name is misspelled ("Vesion"); it is part of the
// public API, so renaming it would break callers.
#[error("A record was out of date. v{0} was expected but v{1} was found")]
OutdatedVesion(u32, u32),
}
/// Error type for transactions whose closure can fail with a caller-defined
/// error `E` in addition to the database's own [`TransactionError`].
#[derive(Debug, Error)]
pub enum CustomTransactionError<E> {
/// The transaction machinery itself failed.
#[error("{0}")]
Transaction(TransactionError),
/// The user-supplied closure returned an error.
#[error("Transaction closure failed: {0}")]
Closure(E),
}
impl From<SledTransactionError<TransactionError>> for TransactionError {
fn from(value: SledTransactionError<TransactionError>) -> Self {
match value {
SledTransactionError::Abort(err) => err,
SledTransactionError::Storage(error) => Self::Sled(error),
}
}
}
impl From<UnabortableTransactionError> for TransactionError {
fn from(value: UnabortableTransactionError) -> Self {
match value {
UnabortableTransactionError::Conflict => TransactionError::CompareAndSwapError,
UnabortableTransactionError::Storage(error) => TransactionError::Sled(error),
}
}
}
impl<E> From<TransactionError> for CustomTransactionError<E> {
fn from(value: TransactionError) -> Self {
CustomTransactionError::Transaction(value)
}
}
impl From<CustomTransactionError<Infallible>> for TransactionError {
    /// Unwraps a custom transaction error whose closure cannot fail.
    ///
    /// With `E = Infallible` the `Closure` variant can never be constructed,
    /// so only the `Transaction` payload is ever returned.
    fn from(value: CustomTransactionError<Infallible>) -> Self {
        match value {
            // `Infallible` has no values; the empty match documents that this
            // arm can never execute.
            CustomTransactionError::Closure(never) => match never {},
            CustomTransactionError::Transaction(inner) => inner,
        }
    }
}
/// A typed handle to a sled database.
///
/// The `Inner` type parameter is a marker implementing [`Database`]; it
/// supplies the on-disk path and is never stored as a value.
pub struct Db<Inner: Database> {
// The underlying sled database handle (also exposed through `Deref`).
inner: sled::Db,
// Trees that have already been opened, keyed by tree name. The `Arc` means
// every clone of this `Db` shares the same map.
trees: Arc<RwLock<HashMap<String, Tree>>>,
// Ties the handle to its marker type without storing a value of it.
_marker: std::marker::PhantomData<Inner>,
}
impl<I: Database> Clone for Db<I> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
trees: self.trees.clone(),
_marker: std::marker::PhantomData,
}
}
}
impl<I: Database> std::ops::Deref for Db<I> {
type Target = sled::Db;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<I: Database> Db<I> {
    /// Wraps an already-opened sled database in a `Db` with an empty tree map.
    fn with_empty_tree_cache(inner: sled::Db) -> Self {
        Db {
            inner,
            trees: Arc::new(RwLock::new(HashMap::new())),
            _marker: std::marker::PhantomData,
        }
    }

    /// Opens the database described by the marker type `I`.
    ///
    /// When `I::path()` returns `Some(path)` the database is opened at that
    /// path; when it returns `None` a temporary database is opened instead.
    /// Both use sled's `HighThroughput` mode.
    ///
    /// Compilation fails if `I` is not a zero-sized type.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::Sled`] if sled fails to open the database.
    pub fn open() -> Result<Self, DatabaseError> {
        const { assert!(std::mem::size_of::<I>() == 0) }
        let config = sled::Config::new().mode(sled::Mode::HighThroughput);
        let config = match I::path() {
            Some(path) => config.path(path),
            None => config.temporary(true),
        };
        Ok(Self::with_empty_tree_cache(config.open()?))
    }

    /// Opens a temporary database, ignoring `I::path()`.
    ///
    /// Compilation fails if `I` is not a zero-sized type.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::Sled`] if sled fails to open the database.
    pub fn open_temp() -> Result<Self, DatabaseError> {
        const { assert!(std::mem::size_of::<I>() == 0) }
        let config = sled::Config::new()
            .temporary(true)
            .mode(sled::Mode::HighThroughput);
        Ok(Self::with_empty_tree_cache(config.open()?))
    }

    /// Returns the tree called `name`, opening it on first use.
    ///
    /// Opened trees are remembered in a map shared by all clones of this
    /// handle, so later calls with the same name return a clone of the stored
    /// `Tree` without going back to sled.
    ///
    /// # Panics
    ///
    /// Panics if the map's lock is poisoned or if sled fails to open the tree.
    pub fn tree(&self, name: String) -> Tree {
        // Fast path: look the tree up under the shared read lock. The guard is
        // scoped so it is released before the write lock is requested below.
        {
            let read = self.trees.read().unwrap();
            if let Some(tree) = read.get(&name) {
                return tree.clone();
            }
        }
        // Slow path: re-check under the write lock, because another thread may
        // have inserted the tree between the two lock acquisitions.
        let mut write = self.trees.write().unwrap();
        match write.entry(name) {
            hash_map::Entry::Occupied(e) => e.get().clone(),
            hash_map::Entry::Vacant(v) => {
                // Borrow the key from the entry instead of cloning the name.
                let tree = self.open_tree(v.key()).unwrap();
                v.insert(tree.clone());
                tree
            }
        }
    }

    /// Builds an [`Index`] named `name` over entries of type `T`.
    ///
    /// The index is backed by the tree called `index_{name}`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Db::tree`].
    pub fn index<'a, T: DatabaseEntry<DbInner = I>>(&'a self, name: &'static str) -> Index<'a, T> {
        Index {
            db: self,
            tree: self.tree(format!("index_{name}")),
            name,
        }
    }

    /// Consumes this handle and removes the database directory from disk.
    ///
    /// Does nothing (and succeeds) when `I::path()` is `None`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from removing the directory.
    pub fn delete(self) -> io::Result<()>
    where
        Self: Sized,
    {
        // Release this handle before touching the directory, so that nothing
        // sled does while the handle is dropped runs against already-deleted
        // files. NOTE(review): other clones of this `Db` may still keep the
        // database open; callers should drop them first.
        drop(self);
        match I::path() {
            Some(path) => fs::remove_dir_all(path),
            None => Ok(()),
        }
    }

    /// Runs `f` inside a compare-and-swap transaction, retrying on conflict.
    ///
    /// Each attempt creates a fresh [`CompareAndSwapTransaction`], runs `f`
    /// against it, and then applies it. If applying fails with
    /// [`TransactionError::CompareAndSwapError`] the attempt is discarded and
    /// `f` is run again, so `f` may execute several times.
    ///
    /// * `flush` - forwarded to the transaction's `apply` call.
    /// * `max_retries` - upper bound on the number of attempts; `None` means
    ///   `usize::MAX`. With `Some(0)` the closure is never run and
    ///   `TooManyRetries` is returned.
    /// * `f` - the transaction body.
    ///
    /// # Errors
    ///
    /// * Any error returned by `f` is propagated immediately, without a retry.
    /// * Any apply error other than a compare-and-swap conflict is returned as
    ///   [`CustomTransactionError::Transaction`].
    /// * [`TransactionError::TooManyRetries`] once every attempt has conflicted.
    pub fn transaction<T, E, F>(
        &self,
        flush: bool,
        max_retries: Option<usize>,
        f: F,
    ) -> Result<T, CustomTransactionError<E>>
    where
        F: Fn(&mut CompareAndSwapTransaction<I>) -> Result<T, CustomTransactionError<E>>,
    {
        for _ in 0..max_retries.unwrap_or(usize::MAX) {
            let mut cas_tx = CompareAndSwapTransaction::with_db(self.clone());
            let t = f(&mut cas_tx)?;
            match cas_tx.apply(flush) {
                Ok(()) => return Ok(t),
                Err(TransactionError::CompareAndSwapError) => {
                    // Lost a race with a concurrent writer: start over.
                    continue;
                }
                Err(err) => return Err(CustomTransactionError::Transaction(err)),
            }
        }
        Err(CustomTransactionError::Transaction(
            TransactionError::TooManyRetries,
        ))
    }
}
impl<I: Database> Db<I> {
    /// Returns the tree that stores entries of type `T`, named
    /// `entry_{T::TREE_NAME}`.
    pub fn entry_tree<T: DatabaseEntry<DbInner = I>>(&self) -> Tree {
        let tree_name = format!("entry_{}", T::TREE_NAME);
        self.tree(tree_name)
    }

    /// Iterates over every stored record of type `T`, converting each raw
    /// sled item with `Entry::from_sled_batch`.
    pub fn iter_entries<T: DatabaseEntry<DbInner = I>>(
        &self,
    ) -> impl Iterator<Item = Result<Entry<T>, TransactionError>> {
        let entries = self.entry_tree::<T>();
        entries.iter().map(Entry::from_sled_batch)
    }

    /// Returns the number of records in the entry tree for `T`.
    pub fn entry_count<T: DatabaseEntry<DbInner = I>>(&self) -> usize {
        let entries = self.entry_tree::<T>();
        entries.len()
    }
}
/// Marker trait describing a database.
///
/// Implementors are used purely at the type level: `Db::open` asserts at
/// compile time that the implementing type is zero-sized.
pub trait Database: Sized + 'static + Send + Sync {
/// Location of the database on disk, or `None` to have `Db::open` use a
/// temporary database.
fn path() -> Option<PathBuf>;
}
/// Declares a database marker type and implements `Database` for it.
///
/// `define_db!(Name { <items> })` expands to an empty `pub enum Name {}` plus
/// `impl $crate::database::Database for Name { <items> }`, so the braces must
/// contain the trait's items (its `path` function).
#[macro_export]
macro_rules! define_db {
( $name:ident { $($body:tt)* } ) => {
// An enum with no variants can never be instantiated; the type exists only
// as a type-level marker.
pub enum $name {}
// The trait is referenced through `$crate::database`, i.e. this expansion
// expects the trait to be reachable at that path in the defining crate.
impl $crate::database::Database for $name {
$($body)*
}
};
}
/// Decodes a stored record into `T`.
///
/// The stored layout is a 4-byte big-endian version number followed by the
/// CBOR-encoded payload.
///
/// # Errors
///
/// Returns [`TransactionError::OutdatedVesion`] with `(expected, found)` when
/// the stored version differs from `T::VERSION_NUMBER`.
///
/// # Panics
///
/// * If `raw` is shorter than the 4-byte version header.
/// * If the version matches but the payload fails to deserialize.
pub(crate) fn deserialize_from_ivec<T: DatabaseEntry>(raw: IVec) -> Result<T, TransactionError> {
    // Split off the header with a checked operation so that a truncated entry
    // fails with the intended message rather than a slice-index panic.
    let (version_bytes, payload) = raw
        .split_first_chunk::<4>()
        .expect("Corrupted database entry");
    let version = u32::from_be_bytes(*version_bytes);
    if version != T::VERSION_NUMBER {
        return Err(TransactionError::OutdatedVesion(T::VERSION_NUMBER, version));
    }
    match cbor4ii::serde::from_slice(payload) {
        Ok(t) => Ok(t),
        Err(err) => {
            panic!(
                "An item in the database failed to deserialize, but the version number has not changed.\nError:\n{err}\nThis could be due to a deserialization bug, database corruption, or from an internal data change without updating the version"
            )
        }
    }
}
#[cfg(test)]
mod tests;