use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;
use parking_lot::Mutex;
use crate::{CacheEntry, CacheTier, Error, SizeError};
/// One call made against a [`MockCache`] through its [`CacheTier`] implementation, as kept in
/// the operation log returned by `MockCache::operations`.
///
/// `len` calls are not logged and therefore have no variant here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CacheOp<K, V> {
    /// A `get` lookup of the contained key.
    Get(K),
    /// An `insert` call; key and entry are clones of the call's arguments.
    Insert {
        /// Key passed to `insert`.
        key: K,
        /// Entry passed to `insert`.
        entry: CacheEntry<V>,
    },
    /// An `invalidate` call for the contained key.
    Invalidate(K),
    /// A `clear` call.
    Clear,
}
/// Boxed failure rule installed via `MockCache::fail_when`; returning `true` for an operation
/// makes that call return an error instead of touching the stored data.
type FailPredicate<K, V> = Box<dyn Fn(&CacheOp<K, V>) -> bool + Send + Sync>;
/// In-memory [`CacheTier`] test double backed by a `HashMap`.
///
/// It logs the operations performed through the trait and can be told to fail selected
/// operations. All state lives behind `Arc`s, so clones of a `MockCache` share the same data,
/// operation log and failure rule.
pub struct MockCache<K, V> {
    /// Backing store for cached entries.
    data: Arc<Mutex<HashMap<K, CacheEntry<V>>>>,
    /// Operations in the order they were recorded, including ones that were made to fail.
    operations: Arc<Mutex<Vec<CacheOp<K, V>>>>,
    /// Optional failure rule; `None` means no operation is forced to fail.
    fail_when: Arc<Mutex<Option<FailPredicate<K, V>>>>,
}
impl<K, V> std::fmt::Debug for MockCache<K, V>
where
    K: std::fmt::Debug,
    V: std::fmt::Debug,
{
    /// Formats the stored data, the operation log and whether a failure rule is installed.
    ///
    /// Never blocks. `data` and `operations` are printed through `parking_lot::Mutex`'s own
    /// `Debug`, which shows `<locked>` for a held lock. `fail_when` is probed with `try_lock`
    /// for the same reason: the failure predicate runs while that lock is held. A blocking
    /// `lock()` here would therefore deadlock if a predicate formatted this mock.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MockCache");
        out.field("data", &self.data)
            .field("operations", &self.operations);
        match self.fail_when.try_lock() {
            Some(guard) => out.field("fail_when", &guard.is_some()),
            None => out.field("fail_when", &format_args!("<locked>")),
        };
        out.finish()
    }
}
// Written by hand rather than derived: `#[derive(Clone)]` would require `K: Clone` and
// `V: Clone`, even though only the `Arc` handles are cloned here.
impl<K, V> Clone for MockCache<K, V> {
    /// Returns a new handle onto the same underlying state.
    ///
    /// The data, operation log and failure rule are shared, not copied.
    fn clone(&self) -> Self {
        let Self {
            data,
            operations,
            fail_when,
        } = self;
        Self {
            data: Arc::clone(data),
            operations: Arc::clone(operations),
            fail_when: Arc::clone(fail_when),
        }
    }
}
impl<K, V> Default for MockCache<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<K, V> MockCache<K, V> {
    /// Creates an empty mock.
    ///
    /// It starts with no entries, an empty operation log and no failure rule installed.
    #[must_use]
    pub fn new() -> Self {
        let data = Arc::new(Mutex::new(HashMap::new()));
        let operations = Arc::new(Mutex::new(Vec::new()));
        let fail_when = Arc::new(Mutex::new(None));
        Self {
            data,
            operations,
            fail_when,
        }
    }
}
impl<K, V> MockCache<K, V>
where
    K: Eq + Hash,
{
    /// Creates a mock pre-populated with `data`.
    ///
    /// The operation log starts empty and no failure rule is installed. Seeding bypasses the
    /// [`CacheTier`] methods, so nothing is logged for these entries.
    #[must_use]
    pub fn with_data(data: HashMap<K, CacheEntry<V>>) -> Self {
        let operations = Arc::new(Mutex::new(Vec::new()));
        let fail_when = Arc::new(Mutex::new(None));
        Self {
            data: Arc::new(Mutex::new(data)),
            operations,
            fail_when,
        }
    }

    /// Returns how many entries are currently stored. Does not log an operation.
    #[must_use]
    pub fn entry_count(&self) -> usize {
        let map = self.data.lock();
        map.len()
    }

    /// Returns whether `key` currently has a stored entry. Does not log an operation.
    #[must_use]
    pub fn contains_key(&self, key: &K) -> bool {
        let map = self.data.lock();
        map.contains_key(key)
    }
}
impl<K, V> MockCache<K, V>
where
K: Clone,
V: Clone,
{
pub fn fail_when<F>(&self, predicate: F)
where
F: Fn(&CacheOp<K, V>) -> bool + Send + Sync + 'static,
{
*self.fail_when.lock() = Some(Box::new(predicate));
}
pub fn clear_failures(&self) {
*self.fail_when.lock() = None;
}
#[must_use]
pub fn operations(&self) -> Vec<CacheOp<K, V>> {
self.operations.lock().clone()
}
pub fn clear_operations(&self) {
self.operations.lock().clear();
}
fn record(&self, op: CacheOp<K, V>) {
self.operations.lock().push(op);
}
fn should_fail(&self, op: &CacheOp<K, V>) -> bool {
self.fail_when.lock().as_ref().is_some_and(|predicate| predicate(op))
}
}
impl<K, V> CacheTier<K, V> for MockCache<K, V>
where
K: Clone + Eq + Hash + Send + Sync,
V: Clone + Send + Sync,
{
async fn get(&self, key: &K) -> Result<Option<CacheEntry<V>>, Error> {
let op = CacheOp::Get(key.clone());
if self.should_fail(&op) {
self.record(op);
return Err(Error::from_message("mock: get failed"));
}
self.record(op);
Ok(self.data.lock().get(key).cloned())
}
async fn insert(&self, key: K, entry: CacheEntry<V>) -> Result<(), Error> {
let op = CacheOp::Insert {
key: key.clone(),
entry: entry.clone(),
};
if self.should_fail(&op) {
self.record(op);
return Err(Error::from_message("mock: insert failed"));
}
self.record(op);
self.data.lock().insert(key, entry);
Ok(())
}
async fn invalidate(&self, key: &K) -> Result<(), Error> {
let op = CacheOp::Invalidate(key.clone());
if self.should_fail(&op) {
self.record(op);
return Err(Error::from_message("mock: invalidate failed"));
}
self.record(op);
self.data.lock().remove(key);
Ok(())
}
async fn clear(&self) -> Result<(), Error> {
let op = CacheOp::Clear;
if self.should_fail(&op) {
self.record(op);
return Err(Error::from_message("mock: clear failed"));
}
self.record(op);
self.data.lock().clear();
Ok(())
}
async fn len(&self) -> Result<u64, SizeError> {
Ok(self.data.lock().len() as u64)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each failure test pins the whole failure contract of the mock:
    // the call returns `Err`, the op is still logged, and stored data is left untouched.

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn insert_failure() {
        let cache = MockCache::<String, i32>::new();
        cache.fail_when(|op| matches!(op, CacheOp::Insert { .. }));
        let result = cache.insert("key".to_string(), CacheEntry::new(42)).await;
        result.unwrap_err();
        assert!(!cache.contains_key(&"key".to_string()));
        assert!(matches!(
            cache.operations().as_slice(),
            [CacheOp::Insert { .. }]
        ));
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn invalidate_failure() {
        let key = "key".to_string();
        // Seed an entry so we can observe that a failed invalidate does not remove it.
        let cache = MockCache::<String, i32>::with_data(HashMap::from([(
            key.clone(),
            CacheEntry::new(42),
        )]));
        cache.fail_when(|op| matches!(op, CacheOp::Invalidate(_)));
        let result = cache.invalidate(&key).await;
        result.unwrap_err();
        assert!(cache.contains_key(&key));
        assert!(matches!(
            cache.operations().as_slice(),
            [CacheOp::Invalidate(_)]
        ));
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn clear_failure() {
        // Seed an entry so we can observe that a failed clear does not empty the map.
        let cache = MockCache::<String, i32>::with_data(HashMap::from([(
            "key".to_string(),
            CacheEntry::new(42),
        )]));
        cache.fail_when(|op| matches!(op, CacheOp::Clear));
        let result = cache.clear().await;
        result.unwrap_err();
        assert_eq!(cache.entry_count(), 1);
        assert!(matches!(cache.operations().as_slice(), [CacheOp::Clear]));
    }

    #[cfg_attr(miri, ignore)]
    #[tokio::test]
    async fn clear_failures_restores_success() {
        let cache = MockCache::<String, i32>::new();
        cache.fail_when(|_| true);
        cache.clear_failures();
        let result = cache.insert("key".to_string(), CacheEntry::new(42)).await;
        assert!(result.is_ok());
        assert!(cache.contains_key(&"key".to_string()));
    }
}