use std::{convert::Infallible, fs, ops::Deref, path::Path, thread, time::Duration};
use serde::{Deserialize, Serialize};
use sled::{
IVec,
transaction::{TransactionError as SledTransactionError, UnabortableTransactionError},
};
pub use sled::{Db, Transactional, Tree, transaction::ConflictableTransactionResult};
use thiserror::Error;
use crate::trace;
pub mod writer;
mod sled_ops;
pub(crate) use sled_ops::*;
mod database_entry;
pub use database_entry::*;
mod transaction_args;
pub use transaction_args::*;
/// Errors produced by the database layer.
#[derive(Debug, Error)]
pub enum DatabaseError {
/// Filesystem error, e.g. from removing the database directory in `Database::delete`.
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
/// Error reported by sled itself (opening, flushing, reading, writing).
#[error("Sled error: {0}")]
Sled(#[from] sled::Error),
/// `Database::open` found the database locked on every attempt; carries the attempt limit.
#[error("The database failed to be obtained after {0} retries")]
TooManyOpenRetries(usize),
/// A key that was expected to be present was absent.
#[error("A key was not found in the database when it was expected")]
MissingEntry,
/// A key that was expected to be absent was already present.
#[error("A key was found in the database when it was not expected")]
AlreadyInDatabase,
/// A stored record could not be decoded as the current schema: `(expected, found)` versions.
/// NOTE: the variant name is misspelled ("Vesion"); it is kept because renaming a public
/// variant is a breaking change.
#[error("A record was out of date. v{0} was expected but v{1} was found")]
OutdatedVesion(u32, u32),
/// Caller-supplied data was rejected; the string describes why.
#[error("Invalid input: {0}")]
InvalidInput(String),
/// An attempt was made to write or overwrite a value marked read-only.
#[error("Tried to write or overwrite a read-only value")]
ReadOnly,
}
/// Errors that can occur while running a sled transaction.
#[derive(Debug, Error)]
pub enum TransactionError {
/// A [`DatabaseError`] raised while performing the transaction's work.
#[error("{0}")]
Database(#[from] DatabaseError),
/// Storage error reported by sled.
#[error("Sled error: {0}")]
Sled(#[from] sled::Error),
/// Produced from `UnabortableTransactionError::Conflict`. Per its message this variant is
/// meant to be handled internally and never escape to callers — confirm at call sites.
#[error(
"An internal compare and swap failed during a transaction. This variant will never be returned from a function"
)]
CompareAndSwapError,
/// An internal compare-and-swap was retried more times than allowed.
#[error("An internal compare and swap was retried too many times")]
TooManyRetries,
}
/// Error from a transaction whose closure can fail with its own error type `E`.
#[derive(Debug, Error)]
pub enum CustomTransactionError<E> {
/// Failure from the transaction machinery itself.
#[error("{0}")]
Transaction(TransactionError),
/// Error returned by the caller-supplied transaction closure.
#[error("Transaction closure failed: {0}")]
Closure(E),
}
impl From<SledTransactionError<TransactionError>> for TransactionError {
    /// Flattens sled's transaction result: an explicit abort yields the inner
    /// error unchanged, a storage failure becomes [`TransactionError::Sled`].
    fn from(value: SledTransactionError<TransactionError>) -> Self {
        match value {
            SledTransactionError::Storage(storage_err) => storage_err.into(),
            SledTransactionError::Abort(aborted_with) => aborted_with,
        }
    }
}
impl From<UnabortableTransactionError> for TransactionError {
    /// Maps sled's non-abortable transaction errors: a conflict becomes
    /// [`TransactionError::CompareAndSwapError`], storage errors become
    /// [`TransactionError::Sled`].
    fn from(value: UnabortableTransactionError) -> Self {
        match value {
            UnabortableTransactionError::Storage(storage_err) => Self::Sled(storage_err),
            UnabortableTransactionError::Conflict => Self::CompareAndSwapError,
        }
    }
}
impl<E> From<TransactionError> for CustomTransactionError<E> {
fn from(value: TransactionError) -> Self {
CustomTransactionError::Transaction(value)
}
}
impl From<CustomTransactionError<Infallible>> for TransactionError {
    /// Unwraps the transaction error. The `Closure` arm is uninhabited because
    /// its payload is [`Infallible`], which the empty `match` proves at compile time.
    fn from(value: CustomTransactionError<Infallible>) -> Self {
        match value {
            CustomTransactionError::Closure(never) => match never {},
            CustomTransactionError::Transaction(inner) => inner,
        }
    }
}
/// A sled-backed database stored at a fixed directory on disk.
///
/// Implementors wrap a [`Db`] (exposed through `Deref`) and provide the
/// on-disk location plus the retry policy used by [`Database::open`] when the
/// sled directory lock is held elsewhere.
pub trait Database: Deref<Target = Db> {
    /// Maximum number of open attempts while the database is locked.
    /// `None` means up to `usize::MAX` attempts, i.e. effectively unbounded.
    const RETRY_MAX_ATTEMPTS: Option<usize>;
    /// How long to wait between two consecutive open attempts.
    const RETRY_DURATION: Duration;

    /// Wraps a freshly opened sled [`Db`] in the implementing type.
    fn new(db: Db) -> Self;

    /// Directory passed to `sled::open` for this database.
    fn path() -> &'static Path;

    /// Hook run on every successful [`Database::open`] before the handle is
    /// returned. The default does nothing.
    ///
    /// An error returned from this hook makes `open` panic.
    fn pre_open(_db: &Self) -> Result<(), Box<dyn std::error::Error>> {
        Ok(())
    }

    /// Opens the database at [`Database::path`], retrying while sled reports
    /// that the directory lock cannot be acquired.
    ///
    /// # Errors
    /// - [`DatabaseError::TooManyOpenRetries`] (carrying the attempt limit) when
    ///   every attempt found the database locked.
    /// - [`DatabaseError::Sled`] for any other sled failure; these are returned
    ///   immediately without retrying.
    ///
    /// # Panics
    /// Panics if [`Database::pre_open`] returns an error.
    fn open() -> Result<Self, DatabaseError>
    where
        Self: Sized,
    {
        let path = Self::path();
        let max_attempts = Self::RETRY_MAX_ATTEMPTS.unwrap_or(usize::MAX);
        for attempt in 1..=max_attempts {
            match sled::open(path) {
                Ok(db) => {
                    let db = Self::new(db);
                    if let Err(err) = Self::pre_open(&db) {
                        panic!(
                            "pre_open hook failed for database '{}': {err}",
                            path.display()
                        );
                    }
                    return Ok(db);
                }
                // sled reports lock contention as a generic io error, so the
                // message text is the only way to recognise it.
                Err(sled::Error::Io(err))
                    if matches!(err.kind(), std::io::ErrorKind::Other)
                        && err.to_string().contains("could not acquire lock on ") =>
                {
                    trace!(
                        "Database '{path}' is in use. Waiting",
                        path = path.display()
                    );
                    // No point waiting after the final attempt: the loop is about to give up.
                    if attempt < max_attempts {
                        thread::sleep(Self::RETRY_DURATION);
                    }
                }
                Err(err) => return Err(err.into()),
            }
        }
        Err(DatabaseError::TooManyOpenRetries(max_attempts))
    }

    /// Removes the database directory from disk.
    ///
    /// The database is first opened (waiting on the lock per the retry policy)
    /// and flushed; that temporary handle is dropped at the end of its
    /// statement, before the directory is removed. A directory that is already
    /// gone is not treated as an error.
    ///
    /// # Errors
    /// Any error from [`Database::open`], from flushing, or from removing the
    /// directory (except `NotFound`).
    fn delete() -> Result<(), DatabaseError>
    where
        Self: Sized,
    {
        Self::open()?.flush()?;
        match fs::remove_dir_all(Self::path()) {
            Err(err) if err.kind() != std::io::ErrorKind::NotFound => Err(err.into()),
            _ => Ok(()),
        }
    }

    /// Opens the database and calls sled's `flush` on it.
    ///
    /// # Errors
    /// Any error from [`Database::open`] or from flushing.
    fn flush() -> Result<(), DatabaseError>
    where
        Self: Sized,
    {
        Self::open()?.flush()?;
        Ok(())
    }
}
/// Decodes a CBOR-encoded record read from sled into `T`.
///
/// If the bytes do not decode as `T`, they are decoded again into a struct
/// holding only a `version: u32` field. The failure is then reported as
/// [`DatabaseError::OutdatedVesion`]`(T::VERSION_NUMBER, stored_version)`.
///
/// NOTE(review): if the stored version equals `T::VERSION_NUMBER`, the decode
/// failure was not caused by an outdated record, but the same error is returned.
///
/// # Panics
/// Panics if the bytes cannot even be decoded into the `version`-only struct.
pub(crate) fn deserialize_from_ivec<T: DatabaseEntry>(raw: IVec) -> Result<T, DatabaseError> {
    #[derive(Deserialize)]
    struct Version {
        version: u32,
    }

    let bytes: &[u8] = &raw;
    if let Ok(entry) = ciborium::from_reader(bytes) {
        return Ok(entry);
    }

    let stored: Version = ciborium::from_reader(bytes)
        .expect("Record was out of date and has no version struct");
    Err(DatabaseError::OutdatedVesion(T::VERSION_NUMBER, stored.version))
}
/// Encodes `item` as CBOR and wraps the bytes in an [`IVec`] for storage in sled.
///
/// # Panics
/// Panics if ciborium reports a serialization error for `item`.
pub(crate) fn serialize_to_ivec<T: Serialize>(item: &T) -> IVec {
    let encoded = {
        let mut bytes = Vec::<u8>::new();
        ciborium::into_writer(item, &mut bytes)
            .expect("Ciborium failed to serialize. This cannot happen unless a serializer failed");
        bytes
    };
    encoded.into()
}
/// A 32-byte identifier that addresses entries of type [`EntryId::Entry`]
/// stored in the database [`EntryId::IdDb`].
pub trait EntryId: Deref<Target = [u8; 32]> + Copy + Eq {
    /// Database in which the identified entries are stored.
    type IdDb: Database;
    /// Entry type addressed by this id; its `EntryDb` must be the same database as `IdDb`.
    type Entry: DatabaseEntry<EntryDb = Self::IdDb>;

    /// Borrows the raw 32 bytes of the id.
    fn as_bytes(&self) -> &[u8; 32] {
        &**self
    }
}