use std::{
any::{Any, TypeId},
collections::HashMap,
convert::Infallible,
fs,
ops::Deref,
path::{Path, PathBuf},
sync::{Arc, LazyLock, OnceLock, RwLock, Weak},
thread,
time::Duration,
};
pub use sled::{Db as SledDb, Transactional, Tree, transaction::ConflictableTransactionResult};
use sled::{
IVec,
transaction::{TransactionError as SledTransactionError, UnabortableTransactionError},
};
use thiserror::Error;
pub mod writer;
mod sled_ops;
pub(crate) use sled_ops::*;
mod database_entry;
pub use database_entry::*;
mod transaction_args;
pub use transaction_args::*;
pub mod caching;
/// Errors surfaced by the handle-level database operations in this module
/// (opening, deleting and flushing a database).
#[derive(Debug, Error)]
pub enum DatabaseError {
/// A standard-library I/O operation failed (for example removing the
/// database directory in `DbHandle::delete`).
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
/// The `Database::pre_open` hook returned an error; the boxed error is the
/// value that hook produced.
#[error("pre_open() function failed with error: {0}")]
PreOpenError(Box<dyn std::error::Error + Send + Sync>),
/// An error reported by sled itself.
#[error("Sled error: {0}")]
Sled(#[from] sled::Error),
/// A transaction-level failure, forwarded unchanged (its own message is
/// used as the display text).
#[error("{0}")]
Transaction(#[from] TransactionError),
/// `DbHandle::open` ran out of attempts; the payload is the attempt limit
/// that was in effect.
#[error("The database failed to be obtained after {0} retries")]
TooManyOpenRetries(usize),
}
/// Errors that can occur while reading or writing entries inside a
/// transaction.
#[derive(Debug, Error)]
pub enum TransactionError {
/// An error reported by sled's storage layer.
#[error("Sled error: {0}")]
Sled(#[from] sled::Error),
/// Internal marker produced when sled reports a transaction conflict
/// (see the `From<UnabortableTransactionError>` conversion). Per its
/// message it is not meant to reach callers.
#[error(
"An internal compare and swap failed during a transaction. This variant will never be returned from a function"
)]
CompareAndSwapError,
/// An internal compare-and-swap loop gave up after too many attempts.
#[error("An internal compare and swap was retried too many times")]
TooManyRetries,
/// A key that was expected to exist was absent.
#[error("A key was not found in the database when it was expected")]
MissingEntry,
/// A key that was expected to be absent already existed.
#[error("A key was found in the database when it was not expected")]
AlreadyInDatabase,
/// A stored record carries a different version than the code expects.
/// Field 0 is the expected version, field 1 the version found on disk.
/// NOTE: the variant name is misspelled ("Vesion"); it is kept as-is
/// because renaming a public variant would break callers.
#[error("A record was out of date. v{0} was expected but v{1} was found")]
OutdatedVesion(u32, u32),
/// A write targeted a value that is marked read-only.
#[error("Tried to write or overwrite a read-only value")]
ReadOnly,
}
/// Error type for transactions that run a caller-supplied closure: either the
/// transaction machinery failed, or the closure itself returned an error of
/// type `E`.
#[derive(Debug, Error)]
pub enum CustomTransactionError<E> {
/// The failure came from the transaction layer.
#[error("{0}")]
Transaction(TransactionError),
/// The failure was returned by the caller's closure.
#[error("Transaction closure failed: {0}")]
Closure(E),
}
/// Flattens sled's transaction error wrapper into a [`TransactionError`].
///
/// An aborted transaction already carries a `TransactionError`, which is
/// returned unchanged; a storage failure is wrapped in
/// [`TransactionError::Sled`].
impl From<SledTransactionError<TransactionError>> for TransactionError {
    fn from(value: SledTransactionError<TransactionError>) -> Self {
        match value {
            SledTransactionError::Storage(storage_failure) => Self::Sled(storage_failure),
            SledTransactionError::Abort(abort_reason) => abort_reason,
        }
    }
}
/// Maps sled's unabortable transaction errors onto [`TransactionError`].
///
/// A conflict becomes the internal [`TransactionError::CompareAndSwapError`]
/// marker; a storage failure is wrapped in [`TransactionError::Sled`].
impl From<UnabortableTransactionError> for TransactionError {
    fn from(value: UnabortableTransactionError) -> Self {
        match value {
            UnabortableTransactionError::Storage(storage_failure) => Self::Sled(storage_failure),
            UnabortableTransactionError::Conflict => Self::CompareAndSwapError,
        }
    }
}
impl<E> From<TransactionError> for CustomTransactionError<E> {
fn from(value: TransactionError) -> Self {
CustomTransactionError::Transaction(value)
}
}
/// Collapses a [`CustomTransactionError`] whose closure error is
/// [`Infallible`] back into a plain [`TransactionError`].
///
/// The `Closure` variant cannot be constructed for `Infallible`, so only the
/// `Transaction` payload can ever be observed here.
impl From<CustomTransactionError<Infallible>> for TransactionError {
    fn from(value: CustomTransactionError<Infallible>) -> Self {
        match value {
            // `Infallible` has no values; the empty match proves this arm is dead.
            CustomTransactionError::Closure(never) => match never {},
            CustomTransactionError::Transaction(inner) => inner,
        }
    }
}
/// Cloneable handle to an opened database of type `Db`.
///
/// The database value is held behind an `Arc`, so cloning a handle shares the
/// same underlying `Db` rather than copying it.
#[derive(Clone)]
pub struct DbHandle<Db: Database> {
// Shared ownership of the opened database.
inner: Arc<Db>,
}
/// Lets a handle be used wherever a `&Db` is expected.
impl<Db: Database> Deref for DbHandle<Db> {
    type Target = Db;

    fn deref(&self) -> &Self::Target {
        // `&Arc<Db>` deref-coerces to `&Db`.
        &self.inner
    }
}
impl<Db: Database> DbHandle<Db> {
pub fn open() -> Result<Self, DatabaseError> {
let type_id = TypeId::of::<Self>();
{
let cache = DB_CACHE.read().unwrap();
if let Some(entry) = cache.get(&type_id)
&& let Some(upgrade) = entry.upgrade()
{
return Ok(DbHandle {
inner: upgrade.downcast::<Db>().unwrap(),
});
}
}
let mut cache = DB_CACHE.write().unwrap();
if let Some(entry) = cache.get(&type_id)
&& let Some(upgrade) = entry.upgrade()
{
return Ok(DbHandle {
inner: upgrade.downcast::<Db>().unwrap(),
});
}
let path = Db::__path();
for _ in 0..Db::RETRY_MAX_ATTEMPTS.unwrap_or(usize::MAX) {
match sled::open(path) {
Ok(db) => {
let db = Db::new(db);
Db::pre_open(&db).map_err(DatabaseError::PreOpenError)?;
let inner = Arc::new(db);
cache.insert(
type_id,
Arc::downgrade(&(inner.clone() as Arc<dyn Any + Send + Sync>)),
);
return Ok(DbHandle { inner });
}
Err(sled::Error::Io(err))
if matches!(err.kind(), std::io::ErrorKind::Other)
&& err.to_string().contains("could not acquire lock on ") =>
{
thread::sleep(Db::RETRY_DURATION);
continue;
}
Err(err) => return Err(err.into()),
}
}
Err(DatabaseError::TooManyOpenRetries(
Db::RETRY_MAX_ATTEMPTS.unwrap_or(usize::MAX),
))
}
pub fn delete(&self) -> Result<(), DatabaseError>
where
Self: Sized,
{
if let Err(err) = fs::remove_dir_all(Db::__path())
&& !matches!(err.kind(), std::io::ErrorKind::NotFound)
{
return Err(err.into());
};
Ok(())
}
pub fn flush(&self) -> Result<(), DatabaseError>
where
Self: Sized,
{
self.inner.db().flush()?;
Ok(())
}
}
pub trait Database: Clone + Sized + 'static + Send + Sync {
const RETRY_MAX_ATTEMPTS: Option<usize>;
const RETRY_DURATION: Duration;
fn new(db: SledDb) -> Self;
fn db(&self) -> &SledDb;
fn path() -> PathBuf;
fn __path() -> &'static Path {
static STABLE_PATH: OnceLock<PathBuf> = OnceLock::new();
STABLE_PATH.get_or_init(Self::path)
}
fn pre_open(_db: &Self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}
/// Maps a handle's `TypeId` to a weak, type-erased reference to its opened
/// database.
type DatabaseCacheRegistry = HashMap<TypeId, Weak<dyn Any + Send + Sync>>;

/// Process-wide registry of opened databases, created empty on first use.
/// Entries are weak, so the registry alone never keeps a database alive.
static DB_CACHE: LazyLock<RwLock<DatabaseCacheRegistry>> =
    LazyLock::new(|| RwLock::new(DatabaseCacheRegistry::new()));
/// Decodes a stored entry of type `T` from its raw database bytes.
///
/// The payload is read from byte 5 onwards; bytes `0..4` hold the entry's
/// big-endian version number (the byte in between is the read-only flag
/// written by `serialize_to_ivec`). The payload is decoded first, and the
/// stored version is only inspected if decoding fails.
///
/// # Errors
///
/// Returns `TransactionError::OutdatedVesion(expected, found)` when decoding
/// fails and the stored version differs from `T::VERSION_NUMBER`.
///
/// # Panics
///
/// Panics if decoding fails even though the stored version matches
/// `T::VERSION_NUMBER`, or if `raw` is too short to slice.
pub(crate) fn deserialize_from_ivec<T: DatabaseEntry>(raw: IVec) -> Result<T, TransactionError> {
    let decode_failure = match postcard::from_bytes(&raw[5..]) {
        Ok(entry) => return Ok(entry),
        Err(failure) => failure,
    };

    let stored_version =
        u32::from_be_bytes(raw[0..4].try_into().expect("Corrupted database entry"));
    if stored_version != T::VERSION_NUMBER {
        return Err(TransactionError::OutdatedVesion(
            T::VERSION_NUMBER,
            stored_version,
        ));
    }

    // Same version but undecodable: the data layout and the code disagree.
    panic!(
        "An item in the database failed to deserialize, but the version number has not changed.\nError:\n{decode_failure}\nThis could be due to a deserialization bug, database corruption, or from an internal data change without updating the version"
    )
}
/// Encodes `item` into the on-disk representation.
///
/// Layout: 4 bytes of big-endian `T::VERSION_NUMBER`, 1 byte holding
/// `item.read_only()`, then the postcard encoding of `item`.
///
/// # Panics
///
/// Panics if postcard fails to serialize `item`.
pub(crate) fn serialize_to_ivec<T: DatabaseEntry>(item: &T) -> IVec {
    let version_bytes: [u8; 4] = T::VERSION_NUMBER.to_be_bytes();

    let mut framed = version_bytes.to_vec();
    framed.push(item.read_only() as u8);

    postcard::to_extend(item, framed)
        .map(IVec::from)
        .expect("Postcard failed to serialize. This cannot happen unless a serializer failed")
}
/// A 32-byte identifier that is tied to one entry type and the database that
/// stores it.
pub trait EntryId: Deref<Target = [u8; 32]> + Copy + Eq {
    /// The database this identifier belongs to.
    type IdDb: Database;
    /// The entry type this identifier refers to; it must live in
    /// [`EntryId::IdDb`].
    type Entry: DatabaseEntry<Db = Self::IdDb>;

    /// Borrows the identifier's 32 raw bytes (the `Deref` target).
    fn as_bytes(&self) -> &[u8; 32] {
        self.deref()
    }
}