use std::{marker::PhantomData, sync::Arc};
use lmdb_zero::traits::AsLmdbBytes;
use serde::{Serialize, de::DeserializeOwned};
use crate::{
key_val_store::{
KeyValStoreError,
key_val_store::{IterationResult, KeyValueStore},
},
lmdb_store::LMDBDatabase,
};
/// A typed view over a shared [`LMDBDatabase`] handle.
///
/// The wrapper owns no keys or values of its own. The `K` and `V` parameters only fix
/// the key and value types used when the database is accessed through [`KeyValueStore`].
pub struct LMDBWrapper<K, V> {
    inner: Arc<LMDBDatabase>,
    // Zero-sized marker tying the wrapper to its key/value types.
    _marker: PhantomData<(K, V)>,
}

impl<K, V> LMDBWrapper<K, V> {
    /// Wraps the shared database handle `db` without copying the database.
    pub fn new(db: Arc<LMDBDatabase>) -> LMDBWrapper<K, V> {
        Self {
            inner: db,
            _marker: PhantomData,
        }
    }

    /// Returns another shared reference to the wrapped database handle.
    ///
    /// Only the `Arc` reference count is bumped; the database itself is not duplicated.
    pub fn inner(&self) -> Arc<LMDBDatabase> {
        let handle = &self.inner;
        Arc::clone(handle)
    }
}
impl<K, V> KeyValueStore<K, V> for LMDBWrapper<K, V>
where
    K: AsLmdbBytes + DeserializeOwned,
    V: Serialize + DeserializeOwned,
{
    /// Stores `value` under `key` by delegating to [`LMDBDatabase::insert`].
    fn insert(&self, key: K, value: V) -> Result<(), KeyValStoreError> {
        self.inner.insert::<K, V>(&key, &value).map_err(Into::into)
    }

    /// Looks up `key`, returning `Ok(None)` when there is no entry for it.
    fn get(&self, key: &K) -> Result<Option<V>, KeyValStoreError> {
        self.inner.get::<K, V>(key).map_err(Into::into)
    }

    /// Looks up every key in `keys` inside a single `with_read_transaction` call.
    ///
    /// Values are returned in the order of `keys`. Keys with no entry are skipped, so the
    /// result may be shorter than `keys` and is not index-aligned with it. The first lookup
    /// error aborts the whole call.
    fn get_many(&self, keys: &[K]) -> Result<Vec<V>, KeyValStoreError> {
        let values = self.inner.with_read_transaction(|access| {
            keys.iter()
                // `transpose` maps Ok(None) -> None (skipped), Ok(Some(v)) -> Some(Ok(v)) and
                // Err(e) -> Some(Err(e)), so `collect` stops at the first error.
                .filter_map(|k| access.get::<K, V>(k).transpose())
                .collect::<Result<Vec<_>, _>>()
        })?;
        values.map_err(Into::into)
    }

    /// Returns the number of entries as reported by [`LMDBDatabase::len`].
    fn size(&self) -> Result<usize, KeyValStoreError> {
        self.inner.len().map_err(Into::into)
    }

    /// Calls `f` for each stored `(key, value)` pair (or per-entry error).
    ///
    /// Iteration order and how the returned [`IterationResult`] controls continuation are
    /// determined by [`LMDBDatabase::for_each`].
    fn for_each<F>(&self, f: F) -> Result<(), KeyValStoreError>
    where
        F: FnMut(Result<(K, V), KeyValStoreError>) -> IterationResult,
    {
        self.inner.for_each::<K, V, F>(f).map_err(Into::into)
    }

    /// Reports whether an entry for `key` exists.
    fn exists(&self, key: &K) -> Result<bool, KeyValStoreError> {
        self.inner.contains_key::<K>(key).map_err(Into::into)
    }

    /// Removes the entry for `key` via [`LMDBDatabase::remove`].
    ///
    /// NOTE(review): whether deleting a missing key is an error depends on
    /// `LMDBDatabase::remove`; confirm there.
    fn delete(&self, key: &K) -> Result<(), KeyValStoreError> {
        self.inner.remove::<K>(key).map_err(Into::into)
    }
}
#[cfg(test)]
mod test {
    use std::path::{Path, PathBuf};

    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::lmdb_store::{LMDBBuilder, LMDBConfig, LMDBError, LMDBStore};

    /// Returns `<CARGO_MANIFEST_DIR>/tests/data/<name>` as a string path.
    fn get_path(name: &str) -> String {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("tests/data");
        path.push(name);
        path.to_str()
            .expect("test data path is not valid UTF-8")
            .to_string()
    }

    /// Builds a fresh LMDB store at `get_path(name)` containing one database named `name`.
    ///
    /// Any directory left behind by an earlier aborted run is removed first, so the test
    /// always starts from an empty database.
    fn init_datastore(name: &str) -> Result<LMDBStore, LMDBError> {
        let path = get_path(name);
        // A panicking test never reaches `clean_up_datastore`, so stale data may exist.
        if Path::new(&path).exists() {
            std::fs::remove_dir_all(&path).expect("failed to remove stale test data directory");
        }
        std::fs::create_dir_all(&path).expect("failed to create test data directory");
        LMDBBuilder::new()
            .set_path(&path)
            .set_env_config(LMDBConfig::default())
            .set_max_number_of_databases(2)
            .add_database(name, lmdb_zero::db::CREATE)
            .build()
    }

    /// Deletes the on-disk data created by `init_datastore(name)`.
    fn clean_up_datastore(name: &str) {
        std::fs::remove_dir_all(get_path(name)).expect("failed to remove test data directory");
    }

    #[test]
    fn test_lmdb_kvstore() {
        #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
        struct Foo {
            value: String,
        }

        let database_name = "test_lmdb_kvstore";
        // Scope the store so every handle is dropped before its files are removed.
        {
            let datastore = init_datastore(database_name).unwrap();
            let db = datastore.get_handle(database_name).unwrap();
            let db = LMDBWrapper::new(Arc::new(db));

            let key1 = 1;
            let key2 = 2;
            let key3 = 3;
            let key4 = 4;
            let val1 = Foo {
                value: "one".to_string(),
            };
            let val2 = Foo {
                value: "two".to_string(),
            };
            let val3 = Foo {
                value: "three".to_string(),
            };

            db.insert(key1, val1.clone()).unwrap();
            db.insert(key2, val2.clone()).unwrap();
            db.insert(key3, val3.clone()).unwrap();

            assert_eq!(db.get(&key1).unwrap().unwrap(), val1);
            assert_eq!(db.get(&key2).unwrap().unwrap(), val2);
            assert_eq!(db.get(&key3).unwrap().unwrap(), val3);
            assert!(db.get(&key4).unwrap().is_none());
            assert_eq!(db.size().unwrap(), 3);
            assert!(db.exists(&key1).unwrap());
            assert!(db.exists(&key2).unwrap());
            assert!(db.exists(&key3).unwrap());
            assert!(!db.exists(&key4).unwrap());

            db.delete(&key2).unwrap();

            assert_eq!(db.get(&key1).unwrap().unwrap(), val1);
            assert!(db.get(&key2).unwrap().is_none());
            assert_eq!(db.get(&key3).unwrap().unwrap(), val3);
            assert!(db.get(&key4).unwrap().is_none());
            assert_eq!(db.size().unwrap(), 2);
            assert!(db.exists(&key1).unwrap());
            assert!(!db.exists(&key2).unwrap());
            assert!(db.exists(&key3).unwrap());
            assert!(!db.exists(&key4).unwrap());

            let mut key1_found = false;
            let mut key3_found = false;
            // Propagate iteration errors instead of silently discarding the result.
            db.for_each(|pair| {
                let (key, val) = pair.unwrap();
                if key == key1 {
                    key1_found = true;
                    assert_eq!(val, val1);
                } else if key == key3 {
                    key3_found = true;
                    assert_eq!(val, val3);
                } else {
                    panic!("Should not be here")
                }
                IterationResult::Continue
            })
            .unwrap();
            assert!(key1_found);
            assert!(key3_found);
        }
        clean_up_datastore(database_name);
    }
}