use std::{collections::HashMap, hash::Hash, sync::RwLock};
use crate::key_val_store::{
error::KeyValStoreError,
key_val_store::{IterationResult, KeyValueStore},
};
/// In-memory key-value store: a [`HashMap`] behind an [`RwLock`], so every
/// operation works through `&self` and takes the lock internally.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct itself.
pub struct HashmapDatabase<K, V> {
    db: RwLock<HashMap<K, V>>,
}

/// Creates an empty store.
///
/// Written by hand instead of `#[derive(Default)]`: the derive would demand
/// `K: Default` and `V: Default`, which constructing an empty map never needs.
impl<K, V> Default for HashmapDatabase<K, V> {
    fn default() -> Self {
        Self {
            db: RwLock::new(HashMap::new()),
        }
    }
}
impl<K: Clone + Eq + Hash, V: Clone> HashmapDatabase<K, V> {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            db: RwLock::new(HashMap::new()),
        }
    }

    /// Runs `f` with shared (read) access to the map.
    ///
    /// Converts a poisoned lock into `KeyValStoreError::PoisonedAccess`, so
    /// every read path reports poisoning the same way.
    fn with_read<R, F>(&self, f: F) -> Result<R, KeyValStoreError>
    where
        F: FnOnce(&HashMap<K, V>) -> R,
    {
        let guard = self.db.read().map_err(|_| KeyValStoreError::PoisonedAccess)?;
        Ok(f(&*guard))
    }

    /// Runs `f` with exclusive (write) access to the map.
    ///
    /// Converts a poisoned lock into `KeyValStoreError::PoisonedAccess`, so
    /// every write path reports poisoning the same way.
    fn with_write<R, F>(&self, f: F) -> Result<R, KeyValStoreError>
    where
        F: FnOnce(&mut HashMap<K, V>) -> R,
    {
        let mut guard = self.db.write().map_err(|_| KeyValStoreError::PoisonedAccess)?;
        Ok(f(&mut *guard))
    }

    /// Stores `value` under `key`, replacing (and dropping) any previous value.
    ///
    /// # Errors
    /// `KeyValStoreError::PoisonedAccess` if the lock is poisoned.
    pub fn insert(&self, key: K, value: V) -> Result<(), KeyValStoreError> {
        self.with_write(|map| {
            map.insert(key, value);
        })
    }

    /// Returns a clone of the value stored under `key`, or `None` if absent.
    ///
    /// # Errors
    /// `KeyValStoreError::PoisonedAccess` if the lock is poisoned.
    pub fn get(&self, key: &K) -> Result<Option<V>, KeyValStoreError> {
        self.with_read(|map| map.get(key).cloned())
    }

    /// Returns `true` if the store holds no entries.
    ///
    /// # Errors
    /// `KeyValStoreError::PoisonedAccess` if the lock is poisoned.
    pub fn is_empty(&self) -> Result<bool, KeyValStoreError> {
        self.with_read(|map| map.is_empty())
    }

    /// Returns the number of stored entries.
    ///
    /// # Errors
    /// `KeyValStoreError::PoisonedAccess` if the lock is poisoned.
    pub fn len(&self) -> Result<usize, KeyValStoreError> {
        self.with_read(|map| map.len())
    }

    /// Calls `f` with a clone of each `(key, value)` pair, in the map's
    /// arbitrary iteration order, until `f` returns `IterationResult::Break`.
    ///
    /// This implementation only ever passes `Ok` pairs to `f`.
    ///
    /// The read lock is held while `f` runs. `f` must therefore not call
    /// `insert`/`remove` on this same store: std's `RwLock` may deadlock or
    /// panic when the current thread already holds the lock.
    ///
    /// # Errors
    /// `KeyValStoreError::PoisonedAccess` if the lock is poisoned.
    pub fn for_each<F>(&self, mut f: F) -> Result<(), KeyValStoreError>
    where
        F: FnMut(Result<(K, V), KeyValStoreError>) -> IterationResult,
    {
        self.with_read(|map| {
            for (key, val) in map {
                match f(Ok((key.clone(), val.clone()))) {
                    IterationResult::Break => break,
                    IterationResult::Continue => {}
                }
            }
        })
    }

    /// Returns `true` if an entry exists for `key`.
    ///
    /// # Errors
    /// `KeyValStoreError::PoisonedAccess` if the lock is poisoned.
    pub fn contains_key(&self, key: &K) -> Result<bool, KeyValStoreError> {
        self.with_read(|map| map.contains_key(key))
    }

    /// Removes the entry for `key`.
    ///
    /// # Errors
    /// - `KeyValStoreError::KeyNotFound` if no entry existed for `key`.
    /// - `KeyValStoreError::PoisonedAccess` if the lock is poisoned.
    pub fn remove(&self, key: &K) -> Result<(), KeyValStoreError> {
        self.with_write(|map| map.remove(key))?
            .map(|_| ())
            .ok_or(KeyValStoreError::KeyNotFound)
    }
}
impl<K: Clone + Eq + Hash, V: Clone> KeyValueStore<K, V> for HashmapDatabase<K, V> {
fn insert(&self, key: K, value: V) -> Result<(), KeyValStoreError> {
self.insert(key, value)
}
fn get(&self, key: &K) -> Result<Option<V>, KeyValStoreError> {
self.get(key)
}
fn get_many(&self, keys: &[K]) -> Result<Vec<V>, KeyValStoreError> {
keys.iter()
.filter_map(|k| match self.get(k) {
Ok(Some(v)) => Some(Ok(v)),
Ok(None) => None,
Err(e) => Some(Err(e)),
})
.collect()
}
fn size(&self) -> Result<usize, KeyValStoreError> {
self.len()
}
fn for_each<F>(&self, f: F) -> Result<(), KeyValStoreError>
where F: FnMut(Result<(K, V), KeyValStoreError>) -> IterationResult {
self.for_each(f)
}
fn exists(&self, key: &K) -> Result<bool, KeyValStoreError> {
self.contains_key(key)
}
fn delete(&self, key: &K) -> Result<(), KeyValStoreError> {
self.remove(key)
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Small non-`Copy` value type used to check that values are cloned out intact.
    #[derive(Debug, Clone, PartialEq)]
    struct Foo {
        value: String,
    }

    fn foo(value: &str) -> Foo {
        Foo {
            value: value.to_string(),
        }
    }

    #[test]
    fn test_hmap_kvstore() {
        let db = HashmapDatabase::new();
        let val1 = foo("one");
        let val2 = foo("two");
        let val3 = foo("three");

        assert!(db.is_empty().unwrap());
        db.insert(1, val1.clone()).unwrap();
        db.insert(2, val2.clone()).unwrap();
        db.insert(3, val3.clone()).unwrap();
        assert!(!db.is_empty().unwrap());

        assert_eq!(db.get(&1).unwrap().unwrap(), val1);
        assert_eq!(db.get(&2).unwrap().unwrap(), val2);
        assert_eq!(db.get(&3).unwrap().unwrap(), val3);
        assert!(db.get(&4).unwrap().is_none());
        assert_eq!(db.size().unwrap(), 3);
        assert!(db.exists(&1).unwrap());
        assert!(db.exists(&2).unwrap());
        assert!(db.exists(&3).unwrap());
        assert!(!db.exists(&4).unwrap());

        db.remove(&2).unwrap();
        assert_eq!(db.get(&1).unwrap().unwrap(), val1);
        assert!(db.get(&2).unwrap().is_none());
        assert_eq!(db.get(&3).unwrap().unwrap(), val3);
        assert!(db.get(&4).unwrap().is_none());
        assert_eq!(db.size().unwrap(), 2);
        assert!(db.exists(&1).unwrap());
        assert!(!db.exists(&2).unwrap());
        assert!(db.exists(&3).unwrap());
        assert!(!db.exists(&4).unwrap());

        // Deleting an absent key is reported as an error, not a silent no-op.
        assert!(matches!(db.delete(&2), Err(KeyValStoreError::KeyNotFound)));

        let mut key1_found = false;
        let mut key3_found = false;
        // The iteration result is checked so a failure cannot pass unnoticed.
        db.for_each(|pair| {
            let (key, val) = pair.unwrap();
            if key == 1 {
                key1_found = true;
                assert_eq!(val, val1);
            } else if key == 3 {
                key3_found = true;
                assert_eq!(val, val3);
            } else {
                panic!("Should not be here")
            }
            IterationResult::Continue
        })
        .unwrap();
        assert!(key1_found);
        assert!(key3_found);
    }

    #[test]
    fn test_get_many_skips_missing_keys() {
        let db = HashmapDatabase::new();
        db.insert(1, foo("one")).unwrap();
        db.insert(3, foo("three")).unwrap();
        assert_eq!(
            db.get_many(&[1, 2, 3]).unwrap(),
            vec![foo("one"), foo("three")]
        );
        assert!(db.get_many(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_for_each_break_stops_iteration() {
        let db = HashmapDatabase::new();
        for k in 0..5 {
            db.insert(k, foo("v")).unwrap();
        }
        let mut visited = 0;
        db.for_each(|pair| {
            let (_key, _val) = pair.unwrap();
            visited += 1;
            IterationResult::Break
        })
        .unwrap();
        assert_eq!(visited, 1);
    }
}