use std::collections::HashMap;
use std::time::Instant;
/// Storage layer that sits behind a [`WriteBehindCache`].
///
/// The cache calls [`read`](BackingStore::read) on a miss,
/// [`write`](BackingStore::write) when it flushes or evicts a dirty entry,
/// and [`delete`](BackingStore::delete) when a key is deleted.
pub trait BackingStore {
    /// Key type; it is used as a `HashMap` key and cloned by the cache.
    type Key: Clone + Eq + std::hash::Hash;

    /// Value type stored under each key.
    type Value: Clone;

    /// Failure type reported by the store; only `Debug` is required.
    type Error: std::fmt::Debug;

    /// Stores `value` under `key`.
    fn write(&mut self, key: &Self::Key, value: &Self::Value) -> Result<(), Self::Error>;

    /// Looks up `key`, yielding `Ok(None)` when the store has no value for it.
    fn read(&self, key: &Self::Key) -> Result<Option<Self::Value>, Self::Error>;

    /// Removes whatever is stored under `key`.
    fn delete(&mut self, key: &Self::Key) -> Result<(), Self::Error>;
}
/// One cached value together with its write-back bookkeeping.
struct CacheEntry<V> {
    /// The cached value itself.
    value: V,
    /// `true` while the value has not yet been written to the backing store.
    dirty: bool,
    /// When the value was last inserted or overwritten in the cache.
    last_modified: Instant,
}
/// Bounded cache that buffers writes and pushes them to a [`BackingStore`]
/// later (on flush or eviction).
pub struct WriteBehindCache<S>
where
    S: BackingStore,
{
    /// Cached values keyed by store key.
    entries: HashMap<S::Key, CacheEntry<S::Value>>,
    /// Keys in insertion order; the front is the next eviction candidate.
    order: Vec<S::Key>,
    /// Maximum number of entries kept in `entries`.
    capacity: usize,
    /// The store that receives written-back values.
    store: S,
    /// Number of entries currently flagged dirty.
    dirty_count: usize,
    /// Number of flush passes recorded so far.
    total_flushes: u64,
    /// Number of entries written to the store so far.
    total_entries_flushed: u64,
}
/// Point-in-time counters describing a [`WriteBehindCache`].
#[derive(Clone, Debug)]
pub struct WriteBehindStats {
    /// Number of entries currently held in the cache.
    pub entry_count: usize,
    /// Number of cached entries flagged dirty.
    pub dirty_count: usize,
    /// Maximum number of entries the cache keeps.
    pub capacity: usize,
    /// Number of flush passes recorded.
    pub total_flushes: u64,
    /// Number of entries written to the backing store.
    pub total_entries_flushed: u64,
}
/// Error returned by fallible [`WriteBehindCache`] operations.
#[derive(Debug)]
pub enum WriteBehindError<E>
where
    E: std::fmt::Debug,
{
    /// The backing store reported the wrapped error.
    StoreError(E),
}
/// Human-readable rendering; the wrapped store error is shown via its
/// `Debug` representation.
impl<E> std::fmt::Display for WriteBehindError<E>
where
    E: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so the pattern is irrefutable.
        let Self::StoreError(e) = self;
        write!(f, "backing store error: {e:?}")
    }
}
impl<S: BackingStore> WriteBehindCache<S> {
pub fn new(capacity: usize, store: S) -> Self {
Self {
entries: HashMap::new(),
order: Vec::new(),
capacity: capacity.max(1),
store,
dirty_count: 0,
total_flushes: 0,
total_entries_flushed: 0,
}
}
pub fn put(&mut self, key: S::Key, value: S::Value) -> Result<(), WriteBehindError<S::Error>> {
if self.entries.contains_key(&key) {
if let Some(entry) = self.entries.get_mut(&key) {
if !entry.dirty {
self.dirty_count += 1;
}
entry.value = value;
entry.dirty = true;
entry.last_modified = Instant::now();
}
return Ok(());
}
while self.entries.len() >= self.capacity {
self.evict_oldest()?;
}
self.order.push(key.clone());
self.entries.insert(
key,
CacheEntry {
value,
dirty: true,
last_modified: Instant::now(),
},
);
self.dirty_count += 1;
Ok(())
}
pub fn get(&mut self, key: &S::Key) -> Result<Option<&S::Value>, WriteBehindError<S::Error>> {
if self.entries.contains_key(key) {
return Ok(self.entries.get(key).map(|e| &e.value));
}
let value = self.store.read(key).map_err(WriteBehindError::StoreError)?;
if let Some(v) = value {
while self.entries.len() >= self.capacity {
self.evict_oldest().map_err(|e| match e {
WriteBehindError::StoreError(se) => WriteBehindError::StoreError(se),
})?;
}
self.order.push(key.clone());
self.entries.insert(
key.clone(),
CacheEntry {
value: v,
dirty: false,
last_modified: Instant::now(),
},
);
return Ok(self.entries.get(key).map(|e| &e.value));
}
Ok(None)
}
pub fn delete(&mut self, key: &S::Key) -> Result<bool, WriteBehindError<S::Error>> {
if let Some(entry) = self.entries.remove(key) {
self.order.retain(|k| k != key);
if entry.dirty {
self.dirty_count = self.dirty_count.saturating_sub(1);
}
self.store
.delete(key)
.map_err(WriteBehindError::StoreError)?;
return Ok(true);
}
Ok(false)
}
pub fn flush(&mut self) -> Result<usize, WriteBehindError<S::Error>> {
let dirty_keys: Vec<S::Key> = self
.entries
.iter()
.filter(|(_, e)| e.dirty)
.map(|(k, _)| k.clone())
.collect();
let count = dirty_keys.len();
for key in &dirty_keys {
if let Some(entry) = self.entries.get(key) {
self.store
.write(key, &entry.value)
.map_err(WriteBehindError::StoreError)?;
}
if let Some(entry) = self.entries.get_mut(key) {
entry.dirty = false;
}
}
self.dirty_count = 0;
self.total_flushes += 1;
self.total_entries_flushed += count as u64;
Ok(count)
}
pub fn flush_if_needed(
&mut self,
threshold: usize,
) -> Result<usize, WriteBehindError<S::Error>> {
if self.dirty_count >= threshold {
self.flush()
} else {
Ok(0)
}
}
pub fn dirty_count(&self) -> usize {
self.dirty_count
}
pub fn stats(&self) -> WriteBehindStats {
WriteBehindStats {
entry_count: self.entries.len(),
dirty_count: self.dirty_count,
capacity: self.capacity,
total_flushes: self.total_flushes,
total_entries_flushed: self.total_entries_flushed,
}
}
pub fn store(&self) -> &S {
&self.store
}
pub fn store_mut(&mut self) -> &mut S {
&mut self.store
}
pub fn is_dirty(&self, key: &S::Key) -> bool {
self.entries.get(key).map(|e| e.dirty).unwrap_or(false)
}
pub fn contains(&self, key: &S::Key) -> bool {
self.entries.contains_key(key)
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn flush_older_than(
&mut self,
max_age: std::time::Duration,
) -> Result<usize, WriteBehindError<S::Error>> {
let now = Instant::now();
let old_dirty_keys: Vec<S::Key> = self
.entries
.iter()
.filter(|(_, e)| e.dirty && now.duration_since(e.last_modified) >= max_age)
.map(|(k, _)| k.clone())
.collect();
let count = old_dirty_keys.len();
for key in &old_dirty_keys {
if let Some(entry) = self.entries.get(key) {
self.store
.write(key, &entry.value)
.map_err(WriteBehindError::StoreError)?;
}
if let Some(entry) = self.entries.get_mut(key) {
entry.dirty = false;
}
}
self.dirty_count = self.dirty_count.saturating_sub(count);
if count > 0 {
self.total_flushes += 1;
self.total_entries_flushed += count as u64;
}
Ok(count)
}
pub fn dirty_keys(&self) -> Vec<S::Key> {
self.entries
.iter()
.filter(|(_, e)| e.dirty)
.map(|(k, _)| k.clone())
.collect()
}
pub fn mark_clean(&mut self, key: &S::Key) -> bool {
if let Some(entry) = self.entries.get_mut(key) {
if entry.dirty {
entry.dirty = false;
self.dirty_count = self.dirty_count.saturating_sub(1);
return true;
}
}
false
}
pub fn capacity(&self) -> usize {
self.capacity
}
fn evict_oldest(&mut self) -> Result<(), WriteBehindError<S::Error>> {
if self.order.is_empty() {
return Ok(());
}
let key = self.order.remove(0);
if let Some(entry) = self.entries.remove(&key) {
if entry.dirty {
self.store
.write(&key, &entry.value)
.map_err(WriteBehindError::StoreError)?;
self.dirty_count = self.dirty_count.saturating_sub(1);
self.total_entries_flushed += 1;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    /// In-memory store whose clones share one map, so a test can keep a
    /// handle and inspect what the cache wrote.
    #[derive(Clone)]
    struct MemStore {
        data: Arc<Mutex<HashMap<String, String>>>,
    }

    impl MemStore {
        fn new() -> Self {
            MemStore {
                data: Arc::default(),
            }
        }

        /// Copy of the current store contents.
        fn snapshot(&self) -> HashMap<String, String> {
            self.data
                .lock()
                .unwrap_or_else(|p| p.into_inner())
                .clone()
        }
    }

    impl BackingStore for MemStore {
        type Key = String;
        type Value = String;
        type Error = String;

        fn write(&mut self, key: &String, value: &String) -> Result<(), String> {
            self.data
                .lock()
                .unwrap_or_else(|p| p.into_inner())
                .insert(key.clone(), value.clone());
            Ok(())
        }

        fn read(&self, key: &String) -> Result<Option<String>, String> {
            let map = self.data.lock().unwrap_or_else(|p| p.into_inner());
            Ok(map.get(key).cloned())
        }

        fn delete(&mut self, key: &String) -> Result<(), String> {
            self.data
                .lock()
                .unwrap_or_else(|p| p.into_inner())
                .remove(key);
            Ok(())
        }
    }

    /// Shorthand for building owned keys and values.
    fn s(text: &str) -> String {
        text.to_string()
    }

    /// A store handle plus a cache of the given capacity sharing that store.
    fn fresh(capacity: usize) -> (MemStore, WriteBehindCache<MemStore>) {
        let store = MemStore::new();
        let cache = WriteBehindCache::new(capacity, store.clone());
        (store, cache)
    }

    /// Inserts into the cache, ignoring the result like the tests always did.
    fn put(cache: &mut WriteBehindCache<MemStore>, key: &str, value: &str) {
        let _ = cache.put(s(key), s(value));
    }

    /// Places a value directly in the store, bypassing the cache.
    fn seed(store: &MemStore, key: &str, value: &str) {
        store
            .data
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .insert(s(key), s(value));
    }

    #[test]
    fn test_put_and_get() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "k1", "v1");
        let found = cache.get(&s("k1")).ok().flatten();
        assert_eq!(found.map(String::as_str), Some("v1"));
    }

    #[test]
    fn test_dirty_tracking() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "a", "1");
        assert!(cache.is_dirty(&s("a")));
        assert_eq!(cache.dirty_count(), 1);
    }

    #[test]
    fn test_flush_writes_to_store() {
        let (store, mut cache) = fresh(10);
        put(&mut cache, "x", "42");
        assert_eq!(cache.flush().ok(), Some(1));
        assert!(!cache.is_dirty(&s("x")));
        let contents = store.snapshot();
        assert_eq!(contents.get("x").map(String::as_str), Some("42"));
    }

    #[test]
    fn test_flush_clears_dirty() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "a", "1");
        put(&mut cache, "b", "2");
        let _ = cache.flush();
        assert_eq!(cache.dirty_count(), 0);
    }

    #[test]
    fn test_flush_if_needed() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "a", "1");
        // One dirty entry is below the threshold of five.
        assert_eq!(cache.flush_if_needed(5).ok(), Some(0));
        put(&mut cache, "b", "2");
        put(&mut cache, "c", "3");
        // Three dirty entries reach the threshold of two.
        assert_eq!(cache.flush_if_needed(2).ok(), Some(3));
    }

    #[test]
    fn test_eviction_flushes_dirty() {
        let (store, mut cache) = fresh(2);
        for (key, value) in [("a", "1"), ("b", "2"), ("c", "3")] {
            put(&mut cache, key, value);
        }
        let contents = store.snapshot();
        assert_eq!(contents.get("a").map(String::as_str), Some("1"));
    }

    #[test]
    fn test_delete() {
        let (store, mut cache) = fresh(10);
        put(&mut cache, "k", "v");
        let _ = cache.flush();
        assert_eq!(cache.delete(&s("k")).ok(), Some(true));
        assert!(!store.snapshot().contains_key("k"));
    }

    #[test]
    fn test_read_through() {
        let (store, mut cache) = fresh(10);
        seed(&store, "pre", "existing");
        let found = cache.get(&s("pre")).ok().flatten();
        assert_eq!(found.map(String::as_str), Some("existing"));
        assert!(!cache.is_dirty(&s("pre")));
    }

    #[test]
    fn test_update_re_dirties() {
        let (_, mut cache) = fresh(10);
        let key = s("a");
        put(&mut cache, "a", "1");
        let _ = cache.flush();
        assert!(!cache.is_dirty(&key));
        put(&mut cache, "a", "2");
        assert!(cache.is_dirty(&key));
    }

    #[test]
    fn test_stats() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "a", "1");
        put(&mut cache, "b", "2");
        let _ = cache.flush();
        let WriteBehindStats {
            entry_count,
            dirty_count,
            total_flushes,
            total_entries_flushed,
            ..
        } = cache.stats();
        assert_eq!(entry_count, 2);
        assert_eq!(dirty_count, 0);
        assert_eq!(total_flushes, 1);
        assert_eq!(total_entries_flushed, 2);
    }

    #[test]
    fn test_delete_absent() {
        let (_, mut cache) = fresh(10);
        assert_eq!(cache.delete(&s("ghost")).ok(), Some(false));
    }

    #[test]
    fn test_get_absent() {
        let (_, mut cache) = fresh(10);
        let found = cache.get(&s("nope")).ok().flatten();
        assert!(found.is_none());
    }

    #[test]
    fn test_contains() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "x", "val");
        assert!(cache.contains(&s("x")));
        assert!(!cache.contains(&s("y")));
    }

    #[test]
    fn test_len_and_is_empty() {
        let (_, mut cache) = fresh(10);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        put(&mut cache, "a", "1");
        put(&mut cache, "b", "2");
        assert_eq!(cache.len(), 2);
        assert!(!cache.is_empty());
    }

    #[test]
    fn test_flush_older_than() {
        use std::time::Duration;

        let (store, mut cache) = fresh(10);
        put(&mut cache, "old", "old_val");
        std::thread::sleep(Duration::from_millis(50));
        put(&mut cache, "new", "new_val");
        let flushed = cache.flush_older_than(Duration::from_millis(30)).ok();
        assert_eq!(flushed, Some(1));
        assert!(!cache.is_dirty(&s("old")));
        assert!(cache.is_dirty(&s("new")));
        assert!(store.snapshot().contains_key("old"));
    }

    #[test]
    fn test_flush_older_than_zero() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "a", "1");
        put(&mut cache, "b", "2");
        let flushed = cache
            .flush_older_than(std::time::Duration::from_millis(0))
            .ok();
        assert_eq!(flushed, Some(2));
        assert_eq!(cache.dirty_count(), 0);
    }

    #[test]
    fn test_dirty_keys() {
        let (_, mut cache) = fresh(10);
        for (key, value) in [("a", "1"), ("b", "2"), ("c", "3")] {
            put(&mut cache, key, value);
        }
        let _ = cache.flush();
        put(&mut cache, "b", "updated");
        let dirty = cache.dirty_keys();
        assert_eq!(dirty.len(), 1);
        assert_eq!(dirty.first().map(String::as_str), Some("b"));
    }

    #[test]
    fn test_mark_clean() {
        let (store, mut cache) = fresh(10);
        let key = s("x");
        put(&mut cache, "x", "val");
        assert!(cache.is_dirty(&key));
        assert!(cache.mark_clean(&key));
        assert!(!cache.is_dirty(&key));
        assert_eq!(cache.dirty_count(), 0);
        // Marking clean must not have written anything to the store.
        assert!(!store.snapshot().contains_key("x"));
    }

    #[test]
    fn test_mark_clean_already_clean() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "a", "1");
        let _ = cache.flush();
        assert!(!cache.mark_clean(&s("a")));
    }

    #[test]
    fn test_mark_clean_absent() {
        let (_, mut cache) = fresh(10);
        assert!(!cache.mark_clean(&s("ghost")));
    }

    #[test]
    fn test_capacity() {
        let (_, cache) = fresh(42);
        assert_eq!(cache.capacity(), 42);
    }

    #[test]
    fn test_multiple_flushes_stats() {
        let (_, mut cache) = fresh(10);
        put(&mut cache, "a", "1");
        let _ = cache.flush();
        put(&mut cache, "b", "2");
        let _ = cache.flush();
        let stats = cache.stats();
        assert_eq!(stats.total_flushes, 2);
        assert_eq!(stats.total_entries_flushed, 2);
    }

    #[test]
    fn test_eviction_cascade() {
        let (store, mut cache) = fresh(3);
        (0..5).for_each(|i| {
            let _ = cache.put(format!("k{i}"), format!("v{i}"));
        });
        let contents = store.snapshot();
        assert!(contents.contains_key("k0"), "evicted k0 should be in store");
        assert!(contents.contains_key("k1"), "evicted k1 should be in store");
    }

    #[test]
    fn test_read_through_is_clean() {
        let (store, mut cache) = fresh(10);
        seed(&store, "existing", "value");
        let _ = cache.get(&s("existing"));
        assert!(!cache.is_dirty(&s("existing")));
        assert_eq!(cache.dirty_count(), 0);
    }

    #[test]
    fn test_store_accessors() {
        let (_, cache) = fresh(10);
        let _store_ref = cache.store();
    }
}