use std::collections::HashSet;
use oxistore_core::{KvStore, StoreError};
use crate::Cache;
/// Cache layer over a [`KvStore`] that forwards every write to the store
/// synchronously, before touching the cache.
///
/// Reads are answered from the cache when it holds the key; otherwise the
/// store is consulted and a value found there is copied into the cache.
pub struct WriteThroughCache<S, C> {
    store: S,
    cache: C,
}

impl<S, C> WriteThroughCache<S, C>
where
    S: KvStore,
    C: Cache<Vec<u8>, Vec<u8>>,
{
    /// Builds a write-through layer from a backing `store` and a `cache`.
    pub fn new(store: S, cache: C) -> Self {
        Self { store, cache }
    }

    /// Shared access to the backing store.
    pub fn store(&self) -> &S {
        &self.store
    }

    /// Shared access to the cache.
    pub fn cache(&self) -> &C {
        &self.cache
    }

    /// Fetches the value stored under `key`.
    ///
    /// A cache hit is returned directly. On a miss the store is read; if it
    /// holds the key, the value is inserted into the cache before returning.
    /// Keys absent from the store are not cached.
    ///
    /// # Errors
    /// Returns any error produced by the store's `get`.
    pub fn get(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>, StoreError> {
        let owned_key = key.to_vec();
        if let Some(hit) = self.cache.get(&owned_key) {
            return Ok(Some(hit.clone()));
        }
        let fetched = self.store.get(key)?;
        if let Some(value) = &fetched {
            self.cache.put(owned_key, value.clone());
        }
        Ok(fetched)
    }

    /// Writes `value` under `key` to the store, then caches it.
    ///
    /// # Errors
    /// If the store rejects the write, the error is returned and the cache is
    /// left untouched.
    pub fn put(&mut self, key: Vec<u8>, value: Vec<u8>) -> Result<(), StoreError> {
        self.store.put(key.as_slice(), value.as_slice())?;
        self.cache.put(key, value);
        Ok(())
    }

    /// Drops `key` from the cache, then deletes it from the store.
    ///
    /// # Errors
    /// Returns any error from the store's `delete`; the cache entry has
    /// already been discarded by then, so a later read re-fetches from the
    /// store.
    pub fn remove(&mut self, key: &[u8]) -> Result<(), StoreError> {
        self.cache.remove(&key.to_vec());
        self.store.delete(key)
    }

    /// Number of entries currently held by the cache.
    pub fn cache_len(&self) -> usize {
        self.cache().len()
    }
}
/// A cache layered over a [`KvStore`] that defers writes to the store.
///
/// `put` updates only the cache and records the key as dirty. Dirty values
/// are written to the store by [`flush`](WriteBackCache::flush), and also just
/// before a `put` of a new key while the cache is at capacity. Assuming the
/// cache only evicts when inserting a new key at capacity, an evicted dirty
/// value is therefore persisted before it is dropped. `remove` deletes from
/// the store immediately.
pub struct WriteBackCache<S, C> {
    store: S,
    cache: C,
    /// Keys whose cached value has not yet been written to `store`.
    dirty: HashSet<Vec<u8>>,
}

impl<S, C> WriteBackCache<S, C>
where
    S: KvStore,
    C: Cache<Vec<u8>, Vec<u8>>,
{
    /// Wraps `store` with `cache`; no key is dirty initially.
    pub fn new(store: S, cache: C) -> Self {
        WriteBackCache {
            store,
            cache,
            dirty: HashSet::new(),
        }
    }

    /// Borrows the backing store.
    pub fn store(&self) -> &S {
        &self.store
    }

    /// Borrows the cache.
    pub fn cache(&self) -> &C {
        &self.cache
    }

    /// Returns how many keys have cached values not yet written to the store.
    pub fn dirty_count(&self) -> usize {
        self.dirty.len()
    }

    /// Looks up `key`, preferring the cached (possibly unflushed) value.
    ///
    /// On a cache miss the store is queried directly; the result is not
    /// inserted into the cache.
    ///
    /// # Errors
    /// Propagates errors from the store's `get`.
    pub fn get(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>, StoreError> {
        if let Some(v) = self.cache.get(&key.to_vec()) {
            return Ok(Some(v.clone()));
        }
        self.store.get(key)
    }

    /// Caches `value` under `key` and marks the key dirty.
    ///
    /// # Errors
    /// When `key` is new and the cache is at capacity, dirty entries are
    /// flushed first; a store error from that flush is returned and `value`
    /// is not cached.
    pub fn put(&mut self, key: Vec<u8>, value: Vec<u8>) -> Result<(), StoreError> {
        self.flush_if_eviction_imminent(&key)?;
        self.dirty.insert(key.clone());
        self.cache.put(key, value);
        Ok(())
    }

    /// Removes `key` from the cache and the dirty set, then deletes it from
    /// the store.
    ///
    /// # Errors
    /// Propagates errors from the store's `delete`. The cached value and any
    /// pending write for `key` have already been discarded at that point.
    pub fn remove(&mut self, key: &[u8]) -> Result<(), StoreError> {
        let key_vec = key.to_vec();
        self.cache.remove(&key_vec);
        self.dirty.remove(&key_vec);
        self.store.delete(key)?;
        Ok(())
    }

    /// Writes every dirty entry still held by the cache to the store.
    ///
    /// Each key stops being dirty as soon as its write succeeds. A dirty key
    /// whose value is no longer in the cache has nothing left to write and is
    /// simply dropped from the dirty set.
    ///
    /// # Errors
    /// Stops at the first store error. The failing key and any keys not yet
    /// attempted stay dirty, so a later flush retries them.
    pub fn flush(&mut self) -> Result<(), StoreError> {
        // Snapshot the keys: the loop body mutates `self.dirty`.
        let pending: Vec<Vec<u8>> = self.dirty.iter().cloned().collect();
        for key in pending {
            if let Some(value) = self.cache.peek(&key) {
                self.store.put(&key, value)?;
            }
            self.dirty.remove(&key);
        }
        Ok(())
    }

    /// Flushes dirty entries if caching `incoming_key` could evict one.
    fn flush_if_eviction_imminent(&mut self, incoming_key: &[u8]) -> Result<(), StoreError> {
        // Overwriting a cached key, or inserting below capacity, evicts nothing.
        if self.cache.contains_key(&incoming_key.to_vec())
            || self.cache.len() < self.cache.cap()
        {
            return Ok(());
        }
        // The entry the cache will evict is not visible here, so persist all
        // of them. Going through `flush` also clears their dirty flags, so the
        // same values are not rewritten on every later insertion.
        self.flush()
    }
}
/// A [`KvStore`] adapter that memoizes point reads in a mutex-guarded cache.
///
/// `get` consults the cache first and, on a miss, reads the store and caches
/// any value found (absent keys are not cached). `put` and `delete` write the
/// store first and then evict the key from the cache. `range`, `iter`,
/// `transaction`, `snapshot` and `flush` are forwarded to the store and never
/// touch the cache.
pub struct CacheableKvStore<S, C> {
    store: S,
    cache: std::sync::Mutex<C>,
    /// Number of cache invalidations performed by `put`/`delete`. It is only
    /// read or modified while `cache` is locked, so the mutex already orders
    /// every access and `Relaxed` atomics are sufficient.
    invalidations: std::sync::atomic::AtomicU64,
}

impl<S, C> CacheableKvStore<S, C> {
    /// Wraps `store`, memoizing its point reads in `cache`.
    pub fn new(store: S, cache: C) -> Self {
        CacheableKvStore {
            store,
            cache: std::sync::Mutex::new(cache),
            invalidations: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Locks the cache, reporting a poisoned mutex as `StoreError::Other`.
    fn lock_cache(&self) -> Result<std::sync::MutexGuard<'_, C>, StoreError> {
        self.cache
            .lock()
            .map_err(|e| StoreError::Other(format!("cache lock poisoned: {e}")))
    }

    /// Evicts `key` from the cache and records the invalidation, so that a
    /// concurrent `get` holding a value read before this write does not
    /// re-insert it.
    fn invalidate(&self, key: &[u8]) -> Result<(), StoreError>
    where
        C: Cache<Vec<u8>, Vec<u8>>,
    {
        let mut cache = self.lock_cache()?;
        cache.remove(&key.to_vec());
        self.invalidations
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        Ok(())
    }
}

impl<S, C> KvStore for CacheableKvStore<S, C>
where
    S: KvStore,
    C: Cache<Vec<u8>, Vec<u8>> + Send,
{
    /// Returns the cached value for `key`, or reads it from the store and
    /// caches it if present.
    ///
    /// # Errors
    /// Returns `StoreError::Other` if the cache mutex is poisoned, and
    /// propagates store read errors.
    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, StoreError> {
        let owned_key = key.to_vec();
        let seen_invalidations = {
            let mut cache = self.lock_cache()?;
            if let Some(hit) = cache.get(&owned_key) {
                return Ok(Some(hit.clone()));
            }
            // Read while still holding the lock, so it is ordered against
            // `invalidate`.
            self.invalidations
                .load(std::sync::atomic::Ordering::Relaxed)
        };
        // The store is read without the cache lock, so hits on other threads
        // are not blocked behind store I/O.
        let from_store = self.store.get(key)?;
        if let Some(value) = &from_store {
            let mut cache = self.lock_cache()?;
            // If a put/delete invalidated the cache while we were reading, our
            // value may predate that write. Caching it would serve stale data
            // until eviction, so skip the insert and let the next read refill it.
            let unchanged = self
                .invalidations
                .load(std::sync::atomic::Ordering::Relaxed)
                == seen_invalidations;
            if unchanged {
                cache.put(owned_key, value.clone());
            }
        }
        Ok(from_store)
    }

    /// Writes through to the store, then evicts `key` from the cache.
    ///
    /// # Errors
    /// Propagates store errors (the cache is left untouched in that case) and
    /// reports a poisoned cache mutex as `StoreError::Other`.
    fn put(&self, key: &[u8], value: &[u8]) -> Result<(), StoreError> {
        self.store.put(key, value)?;
        self.invalidate(key)
    }

    /// Deletes from the store, then evicts `key` from the cache.
    ///
    /// # Errors
    /// Same as [`put`](KvStore::put).
    fn delete(&self, key: &[u8]) -> Result<(), StoreError> {
        self.store.delete(key)?;
        self.invalidate(key)
    }

    /// Forwards to the store; the cache is not consulted.
    fn range<'a>(
        &'a self,
        lo: &[u8],
        hi: &[u8],
    ) -> Result<oxistore_core::RangeIter<'a>, StoreError> {
        self.store.range(lo, hi)
    }

    /// Forwards to the store; the cache is not consulted.
    fn iter<'a>(&'a self) -> Result<oxistore_core::RangeIter<'a>, StoreError> {
        self.store.iter()
    }

    /// Returns the store's own transaction.
    ///
    /// NOTE(review): writes committed through this transaction bypass
    /// `invalidate`, so previously cached values for those keys stay cached.
    /// Confirm callers do not mix transactional writes with cached reads.
    fn transaction(&self) -> Result<Box<dyn oxistore_core::KvTxn + '_>, StoreError> {
        self.store.transaction()
    }

    /// Forwards to the store; the cache is not consulted.
    fn snapshot(&self) -> Result<Box<dyn oxistore_core::KvSnapshot + '_>, StoreError> {
        self.store.snapshot()
    }

    /// Forwards to the store. The cache holds no unwritten data, since every
    /// write goes to the store first.
    fn flush(&self) -> Result<(), StoreError> {
        self.store.flush()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::LruCache;
    use oxistore_core::{KvSnapshot, KvStore, KvTxn, RangeIter, StoreError};
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory `KvStore` used as the backing store in these tests.
    /// Transactions and snapshots are unsupported.
    #[derive(Default, Debug)]
    struct MemStore(Mutex<HashMap<Vec<u8>, Vec<u8>>>);

    impl KvStore for MemStore {
        fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, StoreError> {
            Ok(self.0.lock().expect("lock").get(key).cloned())
        }
        fn put(&self, key: &[u8], value: &[u8]) -> Result<(), StoreError> {
            self.0
                .lock()
                .expect("lock")
                .insert(key.to_vec(), value.to_vec());
            Ok(())
        }
        fn delete(&self, key: &[u8]) -> Result<(), StoreError> {
            self.0.lock().expect("lock").remove(key);
            Ok(())
        }
        fn range<'a>(&'a self, lo: &[u8], hi: &[u8]) -> Result<RangeIter<'a>, StoreError> {
            let guard = self.0.lock().expect("lock");
            let mut pairs: Vec<(Vec<u8>, Vec<u8>)> = guard
                .iter()
                .filter(|(k, _)| k.as_slice() >= lo && k.as_slice() < hi)
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect();
            drop(guard);
            // HashMap order is arbitrary; return keys in order, like `iter`.
            pairs.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
            Ok(Box::new(pairs.into_iter().map(Ok)))
        }
        fn transaction(&self) -> Result<Box<dyn KvTxn + '_>, StoreError> {
            Err(StoreError::Other("MemStore: no txn".to_string()))
        }
        fn snapshot(&self) -> Result<Box<dyn KvSnapshot + '_>, StoreError> {
            Err(StoreError::Other("MemStore: no snapshot".to_string()))
        }
        fn iter<'a>(&'a self) -> Result<RangeIter<'a>, StoreError> {
            let guard = self.0.lock().expect("lock");
            let mut pairs: Vec<(Vec<u8>, Vec<u8>)> =
                guard.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
            drop(guard);
            pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
            Ok(Box::new(pairs.into_iter().map(Ok)))
        }
        fn flush(&self) -> Result<(), StoreError> {
            Ok(())
        }
    }

    #[test]
    fn write_through_put_flushes_to_store() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wt = WriteThroughCache::new(store, cache);
        wt.put(b"key".to_vec(), b"value".to_vec())
            .expect("put failed");
        let from_store = wt.store().get(b"key").expect("get failed");
        assert_eq!(from_store, Some(b"value".to_vec()));
    }

    #[test]
    fn write_through_get_hits_cache() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wt = WriteThroughCache::new(store, cache);
        wt.put(b"k".to_vec(), b"v".to_vec()).expect("put");
        let v = wt.get(b"k").expect("get");
        assert_eq!(v, Some(b"v".to_vec()));
    }

    #[test]
    fn write_through_get_miss_populates_from_store() {
        let store = MemStore::default();
        store.put(b"existing", b"from_store").expect("store put");
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wt = WriteThroughCache::new(store, cache);
        let v = wt.get(b"existing").expect("get");
        assert_eq!(v, Some(b"from_store".to_vec()));
        assert_eq!(wt.cache_len(), 1);
    }

    #[test]
    fn write_through_remove_clears_store() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wt = WriteThroughCache::new(store, cache);
        wt.put(b"rm_key".to_vec(), b"rm_val".to_vec()).expect("put");
        wt.remove(b"rm_key").expect("remove");
        let from_store = wt.store().get(b"rm_key").expect("store get");
        assert!(from_store.is_none());
    }

    #[test]
    fn write_through_get_miss_absent_in_store() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wt = WriteThroughCache::new(store, cache);
        let v = wt.get(b"no_such_key").expect("get");
        assert!(v.is_none());
    }

    #[test]
    fn write_back_put_deferred() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wb = WriteBackCache::new(store, cache);
        wb.put(b"lazy".to_vec(), b"write".to_vec()).expect("put");
        let from_store = wb.store().get(b"lazy").expect("store get");
        assert!(from_store.is_none());
        assert_eq!(wb.dirty_count(), 1);
    }

    #[test]
    fn write_back_flush_persists() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wb = WriteBackCache::new(store, cache);
        wb.put(b"a".to_vec(), b"1".to_vec()).expect("put");
        wb.put(b"b".to_vec(), b"2".to_vec()).expect("put");
        wb.flush().expect("flush");
        assert_eq!(wb.dirty_count(), 0);
        assert_eq!(
            wb.store().get(b"a").expect("store get a"),
            Some(b"1".to_vec())
        );
        assert_eq!(
            wb.store().get(b"b").expect("store get b"),
            Some(b"2".to_vec())
        );
    }

    #[test]
    fn write_back_get_hits_cache() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wb = WriteBackCache::new(store, cache);
        wb.put(b"key".to_vec(), b"val".to_vec()).expect("put");
        let v = wb.get(b"key").expect("get");
        assert_eq!(v, Some(b"val".to_vec()));
    }

    #[test]
    fn write_back_get_misses_to_store() {
        let store = MemStore::default();
        store.put(b"persistent", b"data").expect("store put");
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wb = WriteBackCache::new(store, cache);
        let v = wb.get(b"persistent").expect("get");
        assert_eq!(v, Some(b"data".to_vec()));
    }

    #[test]
    fn write_back_remove_deletes_from_store() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(8);
        let mut wb = WriteBackCache::new(store, cache);
        wb.put(b"del".to_vec(), b"gone".to_vec()).expect("put");
        wb.flush().expect("flush");
        wb.remove(b"del").expect("remove");
        assert!(wb.store().get(b"del").expect("store get").is_none());
        assert_eq!(wb.dirty_count(), 0);
    }

    #[test]
    fn write_back_dirty_eviction_flushes() {
        let store = MemStore::default();
        let cache = LruCache::<Vec<u8>, Vec<u8>>::new(2);
        let mut wb = WriteBackCache::new(store, cache);
        wb.put(b"first".to_vec(), b"v1".to_vec()).expect("put 1");
        wb.put(b"second".to_vec(), b"v2".to_vec()).expect("put 2");
        wb.put(b"third".to_vec(), b"v3".to_vec()).expect("put 3");
        // The entry displaced by "third" must have been persisted before eviction.
        let v1 = wb.store().get(b"first").expect("store get first");
        assert_eq!(v1, Some(b"v1".to_vec()));
        wb.flush().expect("flush");
        let v2 = wb.store().get(b"second").expect("store get second");
        let v3 = wb.store().get(b"third").expect("store get third");
        assert_eq!(v2, Some(b"v2".to_vec()));
        assert_eq!(v3, Some(b"v3".to_vec()));
    }

    #[test]
    fn cacheable_get_miss_populates_cache() {
        let store = MemStore::default();
        store.put(b"k", b"v1").expect("store put");
        let cs = CacheableKvStore::new(store, LruCache::<Vec<u8>, Vec<u8>>::new(8));
        assert_eq!(cs.get(b"k").expect("get"), Some(b"v1".to_vec()));
        // Change the backing store behind the wrapper: the cached copy still wins.
        cs.store.put(b"k", b"v2").expect("inner put");
        assert_eq!(cs.get(b"k").expect("get"), Some(b"v1".to_vec()));
    }

    #[test]
    fn cacheable_put_invalidates_cached_value() {
        let cs = CacheableKvStore::new(MemStore::default(), LruCache::<Vec<u8>, Vec<u8>>::new(8));
        cs.put(b"k", b"old").expect("put");
        assert_eq!(cs.get(b"k").expect("get"), Some(b"old".to_vec()));
        cs.put(b"k", b"new").expect("put");
        assert_eq!(cs.get(b"k").expect("get"), Some(b"new".to_vec()));
        assert_eq!(cs.store.get(b"k").expect("store get"), Some(b"new".to_vec()));
    }

    #[test]
    fn cacheable_delete_invalidates_cached_value() {
        let cs = CacheableKvStore::new(MemStore::default(), LruCache::<Vec<u8>, Vec<u8>>::new(8));
        cs.put(b"k", b"v").expect("put");
        assert_eq!(cs.get(b"k").expect("get"), Some(b"v".to_vec()));
        cs.delete(b"k").expect("delete");
        assert!(cs.get(b"k").expect("get").is_none());
    }

    #[test]
    fn cacheable_absent_key_is_not_cached() {
        let cs = CacheableKvStore::new(MemStore::default(), LruCache::<Vec<u8>, Vec<u8>>::new(8));
        assert!(cs.get(b"late").expect("get").is_none());
        cs.store.put(b"late", b"arrived").expect("inner put");
        assert_eq!(cs.get(b"late").expect("get"), Some(b"arrived".to_vec()));
    }
}