use sled::{
transaction::{ConflictableTransactionError, TransactionError},
IVec,
};
use thiserror::Error;
use crate::{
storage::{HeightStorage, InvoiceStorage, OutputId, OutputKeyStorage, OutputPubKey, Storage},
Invoice, InvoiceId, SubIndex,
};
/// Storage backend made of three `sled` trees opened from one database.
pub struct Sled {
/// Invoices, each stored under its bincode-encoded invoice ID.
invoices: sled::Tree,
/// Bincode-encoded output IDs, each stored under the inner value (`.0`) of
/// its output public key.
output_keys: sled::Tree,
/// Tree whose `"height"` entry holds the bincode-encoded height.
height: sled::Tree,
}
impl Sled {
    /// Opens (or creates) the sled database at `path` and the three named
    /// trees used for invoices, output keys, and the height.
    ///
    /// The database is configured with `flush_every_ms(None)`; see the sled
    /// `Config` documentation for the exact effect on background flushing.
    ///
    /// # Errors
    ///
    /// Returns [`SledStorageError::Database`] if the database or any of the
    /// trees cannot be opened.
    pub fn new(
        path: &str,
        invoice_tree: &str,
        output_key_tree: &str,
        height_tree: &str,
    ) -> Result<Sled, SledStorageError> {
        let config = sled::Config::default().path(path).flush_every_ms(None);
        let db = config.open().map_err(DatabaseError::from)?;

        // All three trees are opened the same way; only the name differs.
        let open = |name: &str| db.open_tree(name).map_err(DatabaseError::from);
        let invoices = open(invoice_tree)?;
        let output_keys = open(output_key_tree)?;
        let height = open(height_tree)?;

        invoices.set_merge_operator(Sled::update_merge);

        Ok(Sled {
            invoices,
            output_keys,
            height,
        })
    }

    /// Merge operator registered on the invoice tree.
    ///
    /// Yields a copy of `new_value` when the key already has a value, and
    /// `None` when it does not, so a merge never creates a new entry.
    fn update_merge(_key: &[u8], old_value: Option<&[u8]>, new_value: &[u8]) -> Option<Vec<u8>> {
        old_value.map(|_| new_value.to_vec())
    }
}
impl InvoiceStorage for Sled {
    type Error = SledStorageError;

    /// Stores `invoice` under its bincode-encoded ID.
    ///
    /// # Errors
    ///
    /// * [`SledStorageError::DuplicateInvoiceId`] if an entry already exists
    ///   for this ID.
    /// * [`SledStorageError::Serialize`] if the ID or invoice cannot be encoded.
    /// * [`SledStorageError::Database`] if the database operation fails.
    fn insert(&mut self, invoice: Invoice) -> Result<(), SledStorageError> {
        let invoice_id = invoice.id();
        let key = bincode::encode_to_vec(invoice_id, bincode::config::standard())?;
        let value = bincode::encode_to_vec(invoice, bincode::config::standard())?;

        // Swapping against an expected value of `None` only succeeds when the
        // key is absent, so an existing invoice is never overwritten.
        self.invoices
            .compare_and_swap(key, None::<IVec>, Some(value))
            .map_err(DatabaseError::from)?
            .map_err(|_| SledStorageError::DuplicateInvoiceId)
    }

    /// Removes the invoice with the given ID and returns it, or `None` if no
    /// such invoice was stored.
    ///
    /// # Errors
    ///
    /// Returns an error if the ID cannot be encoded, the database operation
    /// fails, or the removed value cannot be decoded.
    fn remove(&mut self, invoice_id: InvoiceId) -> Result<Option<Invoice>, SledStorageError> {
        let key = bincode::encode_to_vec(invoice_id, bincode::config::standard())?;
        match self.invoices.remove(key).map_err(DatabaseError::from)? {
            Some(ivec) => {
                let (invoice, _): (Invoice, usize) =
                    bincode::decode_from_slice(&ivec, bincode::config::standard())?;
                Ok(Some(invoice))
            }
            None => Ok(None),
        }
    }

    /// Replaces a stored invoice with `invoice` and returns the previous
    /// value. Returns `None`, and stores nothing, if no invoice with the same
    /// ID exists.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails, the database operation fails, or
    /// the previous value cannot be decoded.
    fn update(&mut self, invoice: Invoice) -> Result<Option<Invoice>, SledStorageError> {
        let key = bincode::encode_to_vec(invoice.id(), bincode::config::standard())?;
        let new_ivec = bincode::encode_to_vec(invoice, bincode::config::standard())?;

        // The closure may be invoked more than once, hence the clone. Returning
        // `None` for a missing key leaves it missing.
        let maybe_old = self
            .invoices
            .fetch_and_update(key, move |old| old.map(|_| new_ivec.clone()))
            .map_err(DatabaseError::from)?;

        match maybe_old {
            Some(ivec) => {
                let (previous, _): (Invoice, usize) =
                    bincode::decode_from_slice(&ivec, bincode::config::standard())?;
                Ok(Some(previous))
            }
            None => Ok(None),
        }
    }

    /// Returns the invoice stored under the given ID, or `None` if there is
    /// none.
    ///
    /// # Errors
    ///
    /// Returns an error if the ID cannot be encoded, the database read fails,
    /// or the stored value cannot be decoded.
    fn get(&self, invoice_id: InvoiceId) -> Result<Option<Invoice>, SledStorageError> {
        let key = bincode::encode_to_vec(invoice_id, bincode::config::standard())?;
        match self.invoices.get(key).map_err(DatabaseError::from)? {
            Some(ivec) => {
                let (invoice, _): (Invoice, usize) =
                    bincode::decode_from_slice(&ivec, bincode::config::standard())?;
                Ok(Some(invoice))
            }
            None => Ok(None),
        }
    }

    /// Returns the IDs of all stored invoices, decoded from the tree's keys.
    ///
    /// # Errors
    ///
    /// Returns an error if reading a key fails or a key cannot be decoded.
    fn get_ids(&self) -> Result<Vec<InvoiceId>, SledStorageError> {
        // Keys are read in full before any decoding, so a database error is
        // reported in preference to a decode error.
        let keys = self
            .invoices
            .iter()
            .keys()
            .collect::<Result<Vec<IVec>, sled::Error>>()
            .map_err(DatabaseError::from)?;

        keys.iter()
            .map(|ivec| Ok(bincode::decode_from_slice(ivec, bincode::config::standard())?.0))
            .collect::<Result<Vec<InvoiceId>, SledStorageError>>()
    }

    /// Reports whether any invoice key starts with the encoded `sub_index`.
    ///
    /// NOTE(review): this assumes the encoding of an invoice ID begins with
    /// the encoding of its sub index -- confirm against `InvoiceId`.
    ///
    /// # Errors
    ///
    /// Returns an error if `sub_index` cannot be encoded or the database read
    /// fails.
    fn contains_sub_index(&self, sub_index: SubIndex) -> Result<bool, SledStorageError> {
        let key = bincode::encode_to_vec(sub_index, bincode::config::standard())?;

        // The iterator yields `Result`s. Checking `is_some()` directly on
        // `next()` would report a failed read as "found", so surface the
        // error before testing for presence.
        let first_match = self
            .invoices
            .scan_prefix(key)
            .next()
            .transpose()
            .map_err(DatabaseError::from)?;
        Ok(first_match.is_some())
    }

    /// Calls `f` once per stored entry, passing the decoded invoice or the
    /// error hit while reading or decoding it. Iteration stops at the first
    /// error returned by `f`.
    ///
    /// # Errors
    ///
    /// Returns the first error returned by `f`.
    fn try_for_each<F>(&self, mut f: F) -> Result<(), Self::Error>
    where
        F: FnMut(Result<Invoice, Self::Error>) -> Result<(), Self::Error>,
    {
        self.invoices.iter().try_for_each(move |row| {
            let invoice_or_err = match row {
                Ok((_id, ivec)) => bincode::decode_from_slice(&ivec, bincode::config::standard())
                    .map(|v| v.0)
                    .map_err(SledStorageError::Deserialize),
                Err(e) => Err(SledStorageError::Database(e.into())),
            };
            f(invoice_or_err)
        })
    }

    /// Reports whether the invoice tree contains no entries.
    fn is_empty(&self) -> Result<bool, SledStorageError> {
        Ok(self.invoices.is_empty())
    }
}
impl OutputKeyStorage for Sled {
    type Error = SledStorageError;

    /// Stores `output_id` under `key` inside a transaction, aborting if the
    /// key already has a value.
    ///
    /// # Errors
    ///
    /// Every failure is returned as [`SledStorageError::Database`] wrapping
    /// [`DatabaseError::Transaction`]. An abort carries the boxed reason:
    /// [`SledStorageError::DuplicateOutputKey`] when the key already existed,
    /// or [`SledStorageError::Serialize`] when `output_id` could not be encoded.
    ///
    /// NOTE(review): callers therefore never see a bare `DuplicateOutputKey`,
    /// unlike the bare `DuplicateInvoiceId` returned by invoice insertion --
    /// confirm whether that wrapping is intended.
    fn insert(&mut self, key: OutputPubKey, output_id: OutputId) -> Result<(), Self::Error> {
        let outcome = self.output_keys.transaction(move |tx| {
            let encoded = bincode::encode_to_vec(output_id, bincode::config::standard())
                .map_err(|e| {
                    ConflictableTransactionError::Abort(Box::new(SledStorageError::Serialize(e)))
                })?;

            // A previous value means the key was already taken; abort so the
            // overwrite is not committed.
            if tx.insert(&key.0, encoded)?.is_some() {
                return Err(ConflictableTransactionError::Abort(Box::new(
                    SledStorageError::DuplicateOutputKey,
                )));
            }
            Ok(())
        });

        outcome.map_err(DatabaseError::from)?;
        Ok(())
    }

    /// Returns the output ID stored under `key`, or `None` if there is none.
    ///
    /// # Errors
    ///
    /// Returns an error if the database read fails or the stored value cannot
    /// be decoded.
    fn get(&self, key: OutputPubKey) -> Result<Option<OutputId>, Self::Error> {
        match self.output_keys.get(key).map_err(DatabaseError::from)? {
            Some(bytes) => {
                let (output_id, _): (OutputId, usize) =
                    bincode::decode_from_slice(&bytes, bincode::config::standard())?;
                Ok(Some(output_id))
            }
            None => Ok(None),
        }
    }
}
impl HeightStorage for Sled {
    type Error = SledStorageError;

    /// Writes `height` under the `"height"` key and returns the value it
    /// replaced, or `None` if nothing was stored before.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails, the database write fails, or the
    /// previous value cannot be decoded.
    fn upsert(&mut self, height: u64) -> Result<Option<u64>, Self::Error> {
        let new_bytes = bincode::encode_to_vec(height, bincode::config::standard())?;
        let previous = self
            .height
            .insert("height", new_bytes)
            .map_err(DatabaseError::from)?;

        match previous {
            Some(bytes) => {
                let (old_height, _): (u64, usize) =
                    bincode::decode_from_slice(&bytes, bincode::config::standard())?;
                Ok(Some(old_height))
            }
            None => Ok(None),
        }
    }

    /// Returns the height stored under the `"height"` key, or `None` if it
    /// has never been written.
    ///
    /// # Errors
    ///
    /// Returns an error if the database read fails or the stored value cannot
    /// be decoded.
    fn get(&self) -> Result<Option<u64>, Self::Error> {
        match self.height.get("height").map_err(DatabaseError::from)? {
            Some(bytes) => {
                let (stored, _): (u64, usize) =
                    bincode::decode_from_slice(&bytes, bincode::config::standard())?;
                Ok(Some(stored))
            }
            None => Ok(None),
        }
    }
}
impl Storage for Sled {
    type Error = SledStorageError;

    /// Flushes the invoice, output key, and height trees, in that order,
    /// stopping at the first failure.
    ///
    /// # Errors
    ///
    /// Returns [`SledStorageError::Database`] if any flush fails.
    fn flush(&self) -> Result<(), SledStorageError> {
        for tree in [&self.invoices, &self.output_keys, &self.height].iter() {
            tree.flush().map_err(DatabaseError::from)?;
        }
        Ok(())
    }
}
/// Errors returned by the `Sled` storage implementation.
#[derive(Error, Debug)]
pub enum SledStorageError {
/// The underlying database, or a transaction on it, failed.
#[error("database error: {0}")]
Database(#[from] DatabaseError),
/// An invoice with the same ID is already stored.
#[error("duplicate invoice ID")]
DuplicateInvoiceId,
/// The output public key is already stored.
#[error("duplicate output public key")]
DuplicateOutputKey,
/// A key or value could not be encoded with bincode.
#[error("serialization error: {0}")]
Serialize(#[from] bincode::error::EncodeError),
/// A stored key or value could not be decoded with bincode.
#[error("deserialization error: {0}")]
Deserialize(#[from] bincode::error::DecodeError),
}
/// Errors originating from `sled` itself.
#[derive(Error, Debug)]
pub enum DatabaseError {
/// A non-transactional `sled` operation failed.
#[error("internal error: {0}")]
General(#[from] sled::Error),
/// A `sled` transaction failed. An aborted transaction carries the boxed
/// `SledStorageError` it was aborted with.
#[error("transaction error: {0}")]
Transaction(#[from] TransactionError<Box<SledStorageError>>),
}