use std::fs;
use std::sync::Arc;
use lmdb_zero as lmdb;
use lmdb_zero::traits::CreateCursor;
use lmdb_zero::LmdbResultExt;
use crate::grin_core::global;
use crate::grin_core::ser::{self, DeserializationMode, ProtocolVersion};
use crate::util::RwLock;
/// Amount (in bytes) by which the LMDB map grows per resize step in
/// production mode: 128 MiB.
pub const ALLOC_CHUNK_SIZE_DEFAULT: usize = 134_217_728;

/// Smaller growth chunk used when not in production (e.g. tests): 1 MiB.
pub const ALLOC_CHUNK_SIZE_DEFAULT_TEST: usize = 1_048_576;

/// Fraction of the map that may be in use before a resize is triggered.
const RESIZE_PERCENT: f32 = 0.9;

/// After a resize, usage should fall below this fraction of the new map size.
const RESIZE_MIN_TARGET_PERCENT: f32 = 0.65;
/// Errors surfaced by the lmdb storage layer.
#[derive(Clone, Eq, PartialEq, Debug, thiserror::Error)]
pub enum Error {
	/// Requested key was not found (message describes the missing item).
	#[error("DB Not Found Error: {0}")]
	NotFoundErr(String),
	/// An error bubbled up from the underlying lmdb_zero library.
	#[error("LMDB error: {0}")]
	LmdbErr(lmdb::error::Error),
	/// (De)serialization of a stored value failed.
	#[error("Serialization Error: {0}")]
	SerErr(ser::Error),
	/// Filesystem-level failure (e.g. creating the db directory).
	#[error("File handling Error: {0}")]
	FileErr(String),
	/// Catch-all for errors that fit none of the above.
	#[error("Other Error: {0}")]
	OtherErr(String),
}
// Allow `?` to convert raw lmdb errors into our storage Error.
impl From<lmdb::error::Error> for Error {
	fn from(e: lmdb::error::Error) -> Error {
		Error::LmdbErr(e)
	}
}

// Allow `?` to convert serialization errors into our storage Error.
impl From<ser::Error> for Error {
	fn from(e: ser::Error) -> Error {
		Error::SerErr(e)
	}
}
pub fn option_to_not_found<T, F>(res: Result<Option<T>, Error>, field_name: F) -> Result<T, Error>
where
F: Fn() -> String,
{
match res {
Ok(None) => Err(Error::NotFoundErr(field_name())),
Ok(Some(o)) => Ok(o),
Err(e) => Err(e),
}
}
/// Protocol version used by default when (de)serializing db entries.
const DEFAULT_DB_VERSION: ProtocolVersion = ProtocolVersion(3);

/// LMDB-backed key/value store handle.
pub struct Store {
	// Shared LMDB environment.
	env: Arc<lmdb::Environment>,
	// Database handle; set to `None` while the map is being resized and
	// reopened afterwards (see `do_resize`).
	db: Arc<RwLock<Option<Arc<lmdb::Database<'static>>>>>,
	// Name of the database within the environment.
	name: String,
	// Protocol version this instance uses for ser/deser.
	version: ProtocolVersion,
	// Step size (bytes) by which the map grows on resize.
	alloc_chunk_size: usize,
}
impl Store {
	/// Create (or open) an LMDB environment at `root_path/<env_name>`
	/// (both names default to "lmdb") and open/create the named database
	/// inside it.
	///
	/// `max_readers`, when provided, caps the number of concurrent LMDB
	/// read transactions.
	pub fn new(
		root_path: &str,
		env_name: Option<&str>,
		db_name: Option<&str>,
		max_readers: Option<u32>,
	) -> Result<Store, Error> {
		let name = env_name.unwrap_or("lmdb").to_owned();
		let db_name = db_name.unwrap_or("lmdb").to_owned();
		let full_path = [root_path.to_owned(), name].join("/");
		fs::create_dir_all(&full_path).map_err(|e| {
			Error::FileErr(format!(
				"Unable to create directory 'db_root' to store chain_data: {:?}",
				e
			))
		})?;
		let mut env_builder = lmdb::EnvBuilder::new()?;
		env_builder.set_maxdbs(8)?;
		if let Some(max_readers) = max_readers {
			env_builder.set_maxreaders(max_readers)?;
		}
		// Use a much smaller growth chunk outside production so tests don't
		// repeatedly map 128 MiB environments.
		let alloc_chunk_size = if global::is_production_mode() {
			ALLOC_CHUNK_SIZE_DEFAULT
		} else {
			ALLOC_CHUNK_SIZE_DEFAULT_TEST
		};
		// SAFETY: lmdb_zero marks env open as unsafe — the caller must ensure
		// the same path is not opened twice within this process.
		let env = unsafe { env_builder.open(&full_path, lmdb::open::NOTLS, 0o600)? };
		debug!("DB Mapsize for {} is {}", full_path, env.info()?.mapsize);
		let res = Store {
			env: Arc::new(env),
			db: Arc::new(RwLock::new(None)),
			name: db_name,
			version: DEFAULT_DB_VERSION,
			alloc_chunk_size,
		};
		{
			// Open (creating if missing) the named database up front.
			let mut w = res.db.write();
			*w = Some(Arc::new(lmdb::Database::open(
				res.env.clone(),
				Some(&res.name),
				&lmdb::DatabaseOptions::new(lmdb::db::CREATE),
			)?));
		}
		Ok(res)
	}

	/// Construct a new store handle using the given protocol version while
	/// sharing the underlying env/db with `self`.
	pub fn with_version(&self, version: ProtocolVersion) -> Store {
		let alloc_chunk_size = if global::is_production_mode() {
			ALLOC_CHUNK_SIZE_DEFAULT
		} else {
			ALLOC_CHUNK_SIZE_DEFAULT_TEST
		};
		Store {
			env: self.env.clone(),
			db: self.db.clone(),
			name: self.name.clone(),
			version,
			alloc_chunk_size,
		}
	}

	/// Protocol version this store uses for (de)serialization.
	pub fn protocol_version(&self) -> ProtocolVersion {
		self.version
	}

	/// (Re)open the named database, creating it if necessary.
	pub fn open(&self) -> Result<(), Error> {
		let mut w = self.db.write();
		*w = Some(Arc::new(lmdb::Database::open(
			self.env.clone(),
			Some(&self.name),
			&lmdb::DatabaseOptions::new(lmdb::db::CREATE),
		)?));
		Ok(())
	}

	/// Determine whether the environment map needs to grow: either more than
	/// `RESIZE_PERCENT` of the map is in use, or the map is still smaller
	/// than a single allocation chunk.
	pub fn needs_resize(&self) -> Result<bool, Error> {
		let env_info = self.env.info()?;
		let stat = self.env.stat()?;
		let size_used = stat.psize as usize * env_info.last_pgno;
		trace!("DB map size: {}", env_info.mapsize);
		trace!("Space used: {}", size_used);
		// saturating_sub: guard against a usize underflow (a panic in debug
		// builds) if used pages momentarily exceed the current map size.
		trace!(
			"Space remaining: {}",
			env_info.mapsize.saturating_sub(size_used)
		);
		let resize_percent = RESIZE_PERCENT;
		trace!(
			"Percent used: {:.*} Percent threshold: {:.*}",
			4,
			size_used as f64 / env_info.mapsize as f64,
			4,
			resize_percent
		);
		if size_used as f32 / env_info.mapsize as f32 > resize_percent
			|| env_info.mapsize < self.alloc_chunk_size
		{
			trace!("Resize threshold met (percent-based)");
			Ok(true)
		} else {
			trace!("Resize threshold not met (percent-based)");
			Ok(false)
		}
	}

	/// Grow the LMDB map in `alloc_chunk_size` steps until usage falls below
	/// `RESIZE_MIN_TARGET_PERCENT`. The db handle is dropped while the map is
	/// remapped and reopened afterwards.
	pub fn do_resize(&self) -> Result<(), Error> {
		let env_info = self.env.info()?;
		let stat = self.env.stat()?;
		let size_used = stat.psize as usize * env_info.last_pgno;
		let new_mapsize = if env_info.mapsize < self.alloc_chunk_size {
			// Map has never reached its first full chunk yet.
			self.alloc_chunk_size
		} else {
			let mut tot = env_info.mapsize;
			while size_used as f32 / tot as f32 > RESIZE_MIN_TARGET_PERCENT {
				tot += self.alloc_chunk_size;
			}
			tot
		};
		// Hold the write lock for the whole operation so no reader observes
		// the db while the map is being remapped.
		let mut w = self.db.write();
		*w = None;
		// SAFETY: per LMDB, mdb_env_set_mapsize may only be called while
		// there are no active transactions in this process; the db handle is
		// dropped above and the write lock excludes new readers.
		unsafe {
			self.env.set_mapsize(new_mapsize)?;
		}
		*w = Some(Arc::new(lmdb::Database::open(
			self.env.clone(),
			Some(&self.name),
			&lmdb::DatabaseOptions::new(lmdb::db::CREATE),
		)?));
		info!(
			"Resized database from {} to {}",
			env_info.mapsize, new_mapsize
		);
		Ok(())
	}

	/// Low-level get: look `key` up in `db` through `access` and run
	/// `deserialize` over the raw bytes. Returns `Ok(None)` when absent.
	pub fn get_with<F, T>(
		&self,
		key: &[u8],
		access: &lmdb::ConstAccessor<'_>,
		db: &lmdb::Database<'_>,
		deserialize: F,
	) -> Result<Option<T>, Error>
	where
		F: Fn(&[u8], &[u8]) -> Result<T, Error>,
	{
		let res: Option<&[u8]> = access.get(db, key).to_opt()?;
		match res {
			None => Ok(None),
			Some(res) => deserialize(key, res).map(Some),
		}
	}

	/// Get a value by key and deserialize it as `T` using this store's
	/// protocol version; `deser_mode` defaults when `None`.
	pub fn get_ser<T: ser::Readable>(
		&self,
		key: &[u8],
		deser_mode: Option<DeserializationMode>,
	) -> Result<Option<T>, Error> {
		let lock = self.db.read();
		let db = lock
			.as_ref()
			.ok_or_else(|| Error::NotFoundErr("chain db is None".to_string()))?;
		let txn = lmdb::ReadTransaction::new(self.env.clone())?;
		let access = txn.access();
		let d = deser_mode.unwrap_or_else(DeserializationMode::default);
		self.get_with(key, &access, &db, |_, mut data| {
			ser::deserialize(&mut data, self.protocol_version(), d).map_err(From::from)
		})
	}

	/// Whether the given key exists in the db.
	pub fn exists(&self, key: &[u8]) -> Result<bool, Error> {
		let lock = self.db.read();
		let db = lock
			.as_ref()
			.ok_or_else(|| Error::NotFoundErr("chain db is None".to_string()))?;
		let txn = lmdb::ReadTransaction::new(self.env.clone())?;
		let access = txn.access();
		// `Ignore` skips copying the value out — we only test presence.
		let res: Option<&lmdb::Ignore> = access.get(db, key).to_opt()?;
		Ok(res.is_some())
	}

	/// Produce an iterator over all entries whose key starts with `prefix`,
	/// deserializing each entry with the provided function.
	pub fn iter<F, T>(&self, prefix: &[u8], deserialize: F) -> Result<PrefixIterator<F, T>, Error>
	where
		F: Fn(&[u8], &[u8]) -> Result<T, Error>,
	{
		let lock = self.db.read();
		let db = lock
			.as_ref()
			.ok_or_else(|| Error::NotFoundErr("chain db is None".to_string()))?;
		let tx = Arc::new(lmdb::ReadTransaction::new(self.env.clone())?);
		let cursor = Arc::new(tx.cursor(db.clone())?);
		Ok(PrefixIterator::new(tx, cursor, prefix, deserialize))
	}

	/// Build a write batch (a single LMDB write transaction), growing the
	/// map first if the resize threshold has been reached.
	pub fn batch(&self) -> Result<Batch<'_>, Error> {
		if self.needs_resize()? {
			self.do_resize()?;
		}
		let tx = lmdb::WriteTransaction::new(self.env.clone())?;
		Ok(Batch { store: self, tx })
	}
}
/// A write batch: wraps a single LMDB write transaction against a `Store`.
/// Writes become visible only once `commit` is called.
pub struct Batch<'a> {
	// Store the batch writes into.
	store: &'a Store,
	// The underlying LMDB write transaction.
	tx: lmdb::WriteTransaction<'a>,
}
impl<'a> Batch<'a> {
	/// Write a raw key/value pair within this batch's transaction.
	pub fn put(&self, key: &[u8], value: &[u8]) -> Result<(), Error> {
		let lock = self.store.db.read();
		let db = lock
			.as_ref()
			.ok_or_else(|| Error::NotFoundErr("chain db is None".to_string()))?;
		self.tx
			.access()
			.put(db, key, value, lmdb::put::Flags::empty())?;
		Ok(())
	}

	/// Serialize `value` with the store's protocol version and write it
	/// under `key`.
	pub fn put_ser<W: ser::Writeable>(&self, key: &[u8], value: &W) -> Result<(), Error> {
		self.put_ser_with_version(key, value, self.store.protocol_version())
	}

	/// Protocol version of the underlying store.
	pub fn protocol_version(&self) -> ProtocolVersion {
		self.store.protocol_version()
	}

	/// Serialize `value` with the given protocol version and write it
	/// under `key`.
	pub fn put_ser_with_version<W: ser::Writeable>(
		&self,
		key: &[u8],
		value: &W,
		version: ProtocolVersion,
	) -> Result<(), Error> {
		// `?` converts ser::Error into Error via the From impl.
		let data = ser::ser_vec(value, version)?;
		self.put(key, &data)
	}

	/// Get a value by key within this transaction's view, running
	/// `deserialize` over the raw bytes. Returns `Ok(None)` when absent.
	pub fn get_with<F, T>(&self, key: &[u8], deserialize: F) -> Result<Option<T>, Error>
	where
		F: Fn(&[u8], &[u8]) -> Result<T, Error>,
	{
		let access = self.tx.access();
		let lock = self.store.db.read();
		let db = lock
			.as_ref()
			.ok_or_else(|| Error::NotFoundErr("chain db is None".to_string()))?;
		self.store.get_with(key, &access, &db, deserialize)
	}

	/// Whether the given key exists within this transaction's view.
	pub fn exists(&self, key: &[u8]) -> Result<bool, Error> {
		let access = self.tx.access();
		let lock = self.store.db.read();
		let db = lock
			.as_ref()
			.ok_or_else(|| Error::NotFoundErr("chain db is None".to_string()))?;
		let res: Option<&lmdb::Ignore> = access.get(db, key).to_opt()?;
		Ok(res.is_some())
	}

	/// Produce an iterator over entries with the given key prefix.
	/// NOTE(review): delegates to the store, which opens its own read
	/// transaction — writes pending in this batch are presumably not visible
	/// to the iterator; confirm against callers.
	pub fn iter<F, T>(&self, prefix: &[u8], deserialize: F) -> Result<PrefixIterator<F, T>, Error>
	where
		F: Fn(&[u8], &[u8]) -> Result<T, Error>,
	{
		self.store.iter(prefix, deserialize)
	}

	/// Get a value by key and deserialize it as `T`; `deser_mode` defaults
	/// when `None`.
	pub fn get_ser<T: ser::Readable>(
		&self,
		key: &[u8],
		deser_mode: Option<DeserializationMode>,
	) -> Result<Option<T>, Error> {
		let d = deser_mode.unwrap_or_else(DeserializationMode::default);
		self.get_with(key, |_, mut data| {
			ser::deserialize(&mut data, self.protocol_version(), d).map_err(From::from)
		})
	}

	/// Delete the value at `key` within this transaction.
	pub fn delete(&self, key: &[u8]) -> Result<(), Error> {
		let lock = self.store.db.read();
		let db = lock
			.as_ref()
			.ok_or_else(|| Error::NotFoundErr("chain db is None".to_string()))?;
		self.tx.access().del_key(db, key)?;
		Ok(())
	}

	/// Commit the transaction, making all of the batch's writes durable.
	pub fn commit(self) -> Result<(), Error> {
		self.tx.commit()?;
		Ok(())
	}

	/// Create a child batch (nested write transaction) from this one.
	pub fn child(&mut self) -> Result<Batch<'_>, Error> {
		Ok(Batch {
			store: self.store,
			tx: self.tx.child_tx()?,
		})
	}
}
/// Iterator over db entries whose key starts with a given prefix, yielding
/// values deserialized by a caller-supplied function.
pub struct PrefixIterator<F, T>
where
	F: Fn(&[u8], &[u8]) -> Result<T, Error>,
{
	// Read transaction kept alive for the lifetime of the cursor.
	tx: Arc<lmdb::ReadTransaction<'static>>,
	// Cursor used to walk the keys; exclusively owned in practice (see
	// `Arc::get_mut` in `next`).
	cursor: Arc<lmdb::Cursor<'static, 'static>>,
	// False until the initial seek to the prefix has been performed.
	seek: bool,
	// Key prefix all yielded entries must start with.
	prefix: Vec<u8>,
	// Maps (key, raw value bytes) to the item type.
	deserialize: F,
}
impl<F, T> Iterator for PrefixIterator<F, T>
where
	F: Fn(&[u8], &[u8]) -> Result<T, Error>,
{
	type Item = T;

	/// Advance the cursor and yield the next deserialized entry under the
	/// prefix, or `None` once past the prefix range.
	fn next(&mut self) -> Option<Self::Item> {
		let access = self.tx.access();
		// The cursor Arc is never shared outside this iterator, so exclusive
		// access is an invariant; failure here would be a bug.
		let cursor = Arc::get_mut(&mut self.cursor).expect("failed to get cursor");
		// First call seeks to the first key >= prefix; later calls advance.
		let kv: Result<(&[u8], &[u8]), _> = if self.seek {
			cursor.next(&access)
		} else {
			self.seek = true;
			cursor.seek_range_k(&access, &self.prefix[..])
		};
		// Iteration ends on a cursor error, on the first key that no longer
		// carries the prefix, or on an entry that fails to deserialize.
		kv.ok()
			.filter(|(k, _)| k.starts_with(self.prefix.as_slice()))
			.and_then(|(k, v)| (self.deserialize)(k, v).ok())
	}
}
impl<F, T> PrefixIterator<F, T>
where
	F: Fn(&[u8], &[u8]) -> Result<T, Error>,
{
	/// Build a prefix iterator over the given transaction/cursor pair.
	/// The initial seek to `prefix` is deferred to the first `next()` call.
	pub fn new(
		tx: Arc<lmdb::ReadTransaction<'static>>,
		cursor: Arc<lmdb::Cursor<'static, 'static>>,
		prefix: &[u8],
		deserialize: F,
	) -> PrefixIterator<F, T> {
		PrefixIterator {
			tx,
			cursor,
			// Not yet positioned; `next()` performs the initial seek.
			seek: false,
			prefix: prefix.to_vec(),
			deserialize,
		}
	}
}