use std::collections::HashMap;
use std::fmt;
/// Failure reported by a [`BackingStore`] operation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StoreError {
    /// The store has reached its entry limit and cannot accept a new key.
    StoreFull,
    /// The requested key is not present in the store.
    KeyNotFound,
    /// The write could not be completed; the payload is a human-readable reason.
    WriteFailure(String),
}

impl fmt::Display for StoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed messages go straight to the formatter; only the variant with a
        // payload needs interpolation.
        match self {
            Self::StoreFull => f.write_str("backing store is full"),
            Self::KeyNotFound => f.write_str("key not found in backing store"),
            Self::WriteFailure(detail) => write!(f, "write failure: {detail}"),
        }
    }
}

/// `StoreError` has no underlying cause, so the default `source()` is used.
impl std::error::Error for StoreError {}
/// Storage layer behind a write-through cache.
///
/// Entries are written here before they are cached, and read from here when the
/// cache does not hold them.
pub trait BackingStore<K, V> {
    /// Writes `value` under `key`.
    ///
    /// # Errors
    ///
    /// Returns a [`StoreError`] if the value could not be stored.
    fn store(&mut self, key: &K, value: &V) -> Result<(), StoreError>;

    /// Returns an owned copy of the value stored under `key`, or `None` if there is none.
    fn load(&self, key: &K) -> Option<V>;

    /// Deletes the entry for `key`, returning `true` if one existed.
    fn remove(&mut self, key: &K) -> bool;
}
/// A [`BackingStore`] that keeps every entry in a process-local `HashMap`,
/// optionally limited to a maximum number of distinct keys.
#[derive(Debug, Clone)]
pub struct InMemoryStore<K, V>
where
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
{
    /// The stored entries.
    data: HashMap<K, V>,
    /// Maximum number of distinct keys; 0 means no limit.
    max_entries: usize,
    /// Number of `store` calls made, including rejected ones.
    write_count: u64,
    /// Read counter exposed through `read_count()`.
    ///
    /// NOTE(review): `load` takes `&self` and does not update this field, so it
    /// appears to stay 0 — confirm whether reads were meant to be counted.
    read_count: u64,
}

/// An empty store with no entry limit.
///
/// Implemented by hand rather than derived: `#[derive(Default)]` would add
/// `K: Default` and `V: Default` bounds, which the store never needs and which
/// would make `default()` unavailable for key/value types without a default.
impl<K, V> Default for InMemoryStore<K, V>
where
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
{
    fn default() -> Self {
        Self {
            data: HashMap::new(),
            max_entries: 0,
            write_count: 0,
            read_count: 0,
        }
    }
}
impl<K, V> InMemoryStore<K, V>
where
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
{
    /// Most slots `with_capacity` reserves up front.
    ///
    /// `max_entries` is a limit, not an expected size: reserving all of it eagerly
    /// panics with a capacity overflow for very large limits (e.g. `usize::MAX`) and
    /// otherwise allocates memory that may never be used. Beyond this many entries the
    /// map grows on demand.
    const MAX_PREALLOC: usize = 1024;

    /// Creates an empty store with no entry limit.
    pub fn new() -> Self {
        Self {
            data: HashMap::new(),
            max_entries: 0,
            write_count: 0,
            read_count: 0,
        }
    }

    /// Creates an empty store that refuses new keys once it holds `max_entries` entries.
    ///
    /// A limit of 0 means unlimited, the same as [`new`](Self::new). At most
    /// `MAX_PREALLOC` slots are reserved immediately, regardless of the limit.
    pub fn with_capacity(max_entries: usize) -> Self {
        Self {
            data: HashMap::with_capacity(max_entries.min(Self::MAX_PREALLOC)),
            max_entries,
            write_count: 0,
            read_count: 0,
        }
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the store holds no entries.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Number of `store` calls made so far, including calls rejected as full.
    pub fn write_count(&self) -> u64 {
        self.write_count
    }

    /// Current value of the read counter.
    ///
    /// NOTE(review): `load` takes `&self` and nothing visible here increments the
    /// counter, so this appears to always return 0 — confirm before relying on it.
    pub fn read_count(&self) -> u64 {
        self.read_count
    }
}
impl<K, V> BackingStore<K, V> for InMemoryStore<K, V>
where
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
{
    /// Inserts or overwrites `key`, cloning both key and value into the map.
    ///
    /// Every call bumps `write_count`, including rejected ones. When a limit is set
    /// (`max_entries > 0`) and already reached, a *new* key is refused with
    /// [`StoreError::StoreFull`]; overwriting an existing key is always allowed.
    fn store(&mut self, key: &K, value: &V) -> Result<(), StoreError> {
        self.write_count += 1;

        let at_limit = self.max_entries > 0 && self.data.len() >= self.max_entries;
        if at_limit && !self.data.contains_key(key) {
            return Err(StoreError::StoreFull);
        }

        self.data.insert(key.clone(), value.clone());
        Ok(())
    }

    /// Returns a clone of the value stored under `key`, if any.
    fn load(&self, key: &K) -> Option<V> {
        self.data.get(key).map(V::clone)
    }

    /// Removes `key`, reporting whether an entry was actually present.
    fn remove(&mut self, key: &K) -> bool {
        let evicted = self.data.remove(key);
        evicted.is_some()
    }
}
/// Counters describing how a write-through cache has been used.
///
/// A snapshot is returned by value; all counters start at zero.
#[derive(Clone, Debug, Default)]
pub struct WriteThroughStats {
    /// Lookups answered directly from the in-memory cache.
    pub cache_hits: u64,
    /// Lookups that missed the cache but were found in the backing store.
    pub backing_store_hits: u64,
    /// Lookups found in neither the cache nor the backing store.
    pub misses: u64,
    /// Writes that reached the backing store successfully.
    pub writes: u64,
    /// Writes the backing store rejected.
    pub write_errors: u64,
}
/// A bounded write-through cache in front of a [`BackingStore`].
///
/// Every [`put`](Self::put) goes to the backing store first and is cached only
/// if that write succeeds. Reads are served from memory when possible and
/// otherwise fall back to the backing store, caching what they load. When the
/// cache is full, the entry inserted (or re-put) longest ago is evicted; cache
/// hits do not change an entry's position.
pub struct WriteThroughCache<K, V, S>
where
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
    S: BackingStore<K, V>,
{
    /// Maximum number of entries held in `cache`; always non-zero.
    capacity: usize,
    /// Values currently held in memory.
    cache: HashMap<K, V>,
    /// Keys of `cache`, oldest at the front. A deque makes evicting the
    /// oldest key O(1) instead of shifting a whole `Vec`.
    insertion_order: std::collections::VecDeque<K>,
    /// The store every write goes through.
    store: S,
    /// Running hit/miss/write counters.
    stats: WriteThroughStats,
}

impl<K, V, S> WriteThroughCache<K, V, S>
where
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
    S: BackingStore<K, V>,
{
    /// Most slots reserved up front, so a very large `capacity` does not
    /// trigger a large allocation before anything is cached.
    const MAX_PREALLOC: usize = 1024;

    /// Creates an empty cache holding at most `capacity` entries in front of `store`.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    pub fn new(capacity: usize, store: S) -> Self {
        assert!(capacity > 0, "WriteThroughCache capacity must be non-zero");
        let prealloc = capacity.min(Self::MAX_PREALLOC);
        Self {
            capacity,
            cache: HashMap::with_capacity(prealloc),
            insertion_order: std::collections::VecDeque::with_capacity(prealloc),
            store,
            stats: WriteThroughStats::default(),
        }
    }

    /// Writes `key`/`value` to the backing store, then caches it.
    ///
    /// Re-putting a cached key replaces its value and makes it the newest entry.
    /// Putting a new key into a full cache first evicts the oldest entry.
    ///
    /// # Errors
    ///
    /// Returns the backing store's error unchanged if the write fails. In that
    /// case the cache is left untouched and `write_errors` is incremented.
    pub fn put(&mut self, key: K, value: V) -> Result<(), StoreError> {
        if let Err(e) = self.store.store(&key, &value) {
            self.stats.write_errors += 1;
            return Err(e);
        }
        if self.cache.contains_key(&key) {
            // Drop the stale position; the key is re-queued as newest below.
            self.insertion_order.retain(|k| k != &key);
        } else if self.cache.len() >= self.capacity {
            self.evict_oldest();
        }
        self.cache.insert(key.clone(), value);
        self.insertion_order.push_back(key);
        self.stats.writes += 1;
        Ok(())
    }

    /// Looks up `key`, preferring the in-memory cache.
    ///
    /// On a cache miss the backing store is consulted. A value found there is
    /// cached (evicting the oldest entry if full) and returned. Takes `&mut self`
    /// because it updates statistics and may populate the cache.
    pub fn get(&mut self, key: &K) -> Option<&V> {
        // `contains_key` + `get` (rather than `if let Some(v) = get`) keeps the
        // early-return borrow from conflicting with the mutations further down.
        if self.cache.contains_key(key) {
            self.stats.cache_hits += 1;
            return self.cache.get(key);
        }
        match self.store.load(key) {
            Some(value) => {
                self.stats.backing_store_hits += 1;
                if self.cache.len() >= self.capacity {
                    self.evict_oldest();
                }
                self.insertion_order.push_back(key.clone());
                // The key is known to be absent, so this inserts; a single lookup
                // both stores the value and yields the reference to return.
                let cached: &V = self.cache.entry(key.clone()).or_insert(value);
                Some(cached)
            }
            None => {
                self.stats.misses += 1;
                None
            }
        }
    }

    /// Removes `key` from both the cache and the backing store.
    ///
    /// Returns `true` if it was present in either. The backing store is always
    /// asked to remove the key, even when the cache already held it.
    pub fn invalidate(&mut self, key: &K) -> bool {
        let in_cache = self.cache.remove(key).is_some();
        if in_cache {
            self.insertion_order.retain(|k| k != key);
        }
        let in_store = self.store.remove(key);
        in_cache || in_store
    }

    /// Returns a snapshot of the usage counters.
    pub fn stats(&self) -> WriteThroughStats {
        self.stats.clone()
    }

    /// Number of entries currently held in memory.
    pub fn cache_len(&self) -> usize {
        self.cache.len()
    }

    /// Maximum number of entries held in memory.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Shared access to the underlying backing store.
    pub fn backing_store(&self) -> &S {
        &self.store
    }

    /// Drops the oldest cached entry, if any, in O(1).
    fn evict_oldest(&mut self) {
        if let Some(oldest) = self.insertion_order.pop_front() {
            self.cache.remove(&oldest);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Vec<u8>`-valued cache holding `cap` entries over an unlimited in-memory store.
    fn make_cache(
        cap: usize,
    ) -> WriteThroughCache<String, Vec<u8>, InMemoryStore<String, Vec<u8>>> {
        let store = InMemoryStore::new();
        WriteThroughCache::new(cap, store)
    }

    /// Owned `String` key from a literal, to keep the assertions short.
    fn key(s: &str) -> String {
        s.to_string()
    }

    #[test]
    fn test_put_and_get_from_cache() {
        let mut cache = make_cache(10);
        cache.put(key("key1"), vec![1, 2, 3]).expect("put");
        let v = cache.get(&key("key1"));
        assert_eq!(v, Some(&vec![1u8, 2, 3]));
        assert_eq!(cache.stats().cache_hits, 1);
    }

    #[test]
    fn test_put_persists_to_backing_store() {
        let mut cache = make_cache(10);
        cache.put(key("k"), vec![9]).expect("put");
        let loaded = cache.backing_store().load(&key("k"));
        assert_eq!(loaded, Some(vec![9u8]));
    }

    #[test]
    fn test_get_fallback_to_backing_store() {
        let store = InMemoryStore::new();
        let mut cache: WriteThroughCache<String, u32, InMemoryStore<String, u32>> =
            WriteThroughCache::new(2, store);
        cache.put(key("target"), 42).expect("put");
        // Two newer puts push "target" out of the two-slot cache; it now lives
        // only in the backing store.
        cache.put(key("a"), 1).expect("put");
        cache.put(key("b"), 2).expect("put");
        let v = cache.get(&key("target"));
        assert_eq!(v, Some(&42));
        assert_eq!(cache.stats().backing_store_hits, 1);
        assert_eq!(cache.stats().cache_hits, 0);
    }

    #[test]
    fn test_invalidate_removes_from_both() {
        let mut cache = make_cache(10);
        cache.put(key("x"), vec![0]).expect("put");
        let removed = cache.invalidate(&key("x"));
        assert!(removed);
        assert_eq!(cache.get(&key("x")), None);
        assert_eq!(cache.backing_store().load(&key("x")), None);
    }

    #[test]
    fn test_invalidate_absent_key_returns_false() {
        let mut cache = make_cache(10);
        assert!(!cache.invalidate(&key("ghost")));
    }

    #[test]
    fn test_invalidate_key_only_in_backing_store() {
        let mut cache = make_cache(1);
        cache.put(key("old"), vec![1]).expect("put");
        // Evicts "old" from the one-slot cache; it remains in the store.
        cache.put(key("new"), vec![2]).expect("put");
        assert!(cache.invalidate(&key("old")));
        assert_eq!(cache.backing_store().load(&key("old")), None);
        assert_eq!(cache.cache_len(), 1);
    }

    #[test]
    fn test_stats_writes_and_errors() {
        let store = InMemoryStore::<String, u32>::with_capacity(1);
        let mut cache: WriteThroughCache<String, u32, InMemoryStore<String, u32>> =
            WriteThroughCache::new(10, store);
        cache.put(key("first"), 1).expect("first put ok");
        let result = cache.put(key("second"), 2);
        assert_eq!(result, Err(StoreError::StoreFull));
        let s = cache.stats();
        assert_eq!(s.writes, 1);
        assert_eq!(s.write_errors, 1);
    }

    #[test]
    fn test_miss_when_not_in_either() {
        let mut cache = make_cache(10);
        assert_eq!(cache.get(&key("nonexistent")), None);
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_capacity_eviction() {
        let mut cache = make_cache(3);
        cache.put(key("a"), vec![1]).expect("put");
        cache.put(key("b"), vec![2]).expect("put");
        cache.put(key("c"), vec![3]).expect("put");
        cache.put(key("d"), vec![4]).expect("put");
        assert_eq!(cache.cache_len(), 3);
        assert_eq!(cache.stats().cache_hits, 0);
        let v = cache.get(&key("a"));
        assert_eq!(v, Some(&vec![1u8]));
        assert_eq!(cache.stats().backing_store_hits, 1);
    }

    #[test]
    fn test_overwrite_refreshes_eviction_order() {
        let mut cache = make_cache(2);
        cache.put(key("a"), vec![1]).expect("put");
        cache.put(key("b"), vec![2]).expect("put");
        // Re-putting "a" makes it the newest entry, so "b" is evicted next.
        cache.put(key("a"), vec![3]).expect("put");
        cache.put(key("c"), vec![4]).expect("put");
        assert_eq!(cache.get(&key("a")), Some(&vec![3u8]));
        assert_eq!(cache.stats().cache_hits, 1);
        assert_eq!(cache.get(&key("b")), Some(&vec![2u8]));
        assert_eq!(cache.stats().backing_store_hits, 1);
    }

    #[test]
    fn test_overwrite_existing_key() {
        let mut cache = make_cache(5);
        cache.put(key("k"), vec![1]).expect("put");
        cache.put(key("k"), vec![2]).expect("put");
        assert_eq!(cache.get(&key("k")), Some(&vec![2u8]));
        assert_eq!(cache.cache_len(), 1);
    }

    #[test]
    fn test_write_failure_does_not_pollute_cache() {
        let store = InMemoryStore::<String, i32>::with_capacity(1);
        let mut cache: WriteThroughCache<String, i32, InMemoryStore<String, i32>> =
            WriteThroughCache::new(10, store);
        cache.put(key("key1"), 1).expect("first ok");
        let result = cache.put(key("key2"), 2);
        assert!(result.is_err(), "second put should fail");
        assert_eq!(cache.get(&key("key2")), None);
    }
}