use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::Cache;
/// Early-returns `Err(..)` from the enclosing function or `async` block when the
/// given `Mutex<Option<String>>` currently holds an injected error message.
///
/// `$error_mode` must be an expression that derefs to `Mutex<Option<String>>`
/// (e.g. an `Arc<Mutex<Option<String>>>`). The lock is only held for the duration
/// of this check.
macro_rules! check_error_mode {
    ($error_mode:expr) => {{
        if let Some(msg) = $error_mode.lock().unwrap().as_deref() {
            return Err(anyhow::anyhow!("{}", msg));
        }
    }};
}
/// A stored cache value paired with its optional expiry deadline.
#[derive(Clone, Debug)]
struct MockEntry {
    /// Raw value bytes.
    value: Vec<u8>,
    /// Point in time from which the entry counts as expired; `None` = never expires.
    expires_at: Option<Instant>,
}

impl MockEntry {
    /// Returns `true` once the current time has reached (or passed) the deadline.
    ///
    /// Entries without a deadline are never expired.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline <= Instant::now(),
            None => false,
        }
    }
}
/// In-memory mock implementation of [`Cache`].
///
/// Every field is an `Arc<Mutex<..>>`, so clones share the same stored entries,
/// error-injection switch and operation counters.
#[derive(Clone, Debug)]
pub struct MockCache {
    /// Stored entries keyed by raw key bytes. Expired entries are not purged; they
    /// stay until overwritten, deleted or cleared, and are filtered out on read.
    data: Arc<Mutex<HashMap<Vec<u8>, MockEntry>>>,
    /// When `Some(msg)`, every `Cache` operation fails with an error carrying `msg`.
    error_mode: Arc<Mutex<Option<String>>>,
    /// Per-operation call counters; see [`OperationCounts`].
    operation_count: Arc<Mutex<OperationCounts>>,
}

/// Number of `Cache` calls made against a [`MockCache`], per operation.
///
/// A call is counted only once it has passed the error-mode check, so calls that
/// fail because error mode is enabled are not included.
#[derive(Clone, Debug, Default)]
pub struct OperationCounts {
    /// Calls to `get`.
    pub gets: usize,
    /// Calls to `set`.
    pub sets: usize,
    /// Calls to `set_nx_px`.
    pub set_nx_px: usize,
    /// Calls to `del`.
    pub deletes: usize,
}

impl MockCache {
    /// Creates an empty cache with error mode disabled and all counters at zero.
    pub fn new() -> Self {
        Self {
            data: Arc::new(Mutex::new(HashMap::new())),
            error_mode: Arc::new(Mutex::new(None)),
            operation_count: Arc::new(Mutex::new(OperationCounts::default())),
        }
    }

    /// Creates a cache pre-populated with `entries`, none of which expire.
    ///
    /// If a key appears more than once, the last value wins. Seeding does not
    /// touch the operation counters.
    pub fn with_data<I, K, V>(entries: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<[u8]>,
        V: AsRef<[u8]>,
    {
        let cache = Self::new();
        cache
            .data
            .lock()
            .unwrap()
            .extend(entries.into_iter().map(|(key, value)| {
                (
                    key.as_ref().to_vec(),
                    MockEntry {
                        value: value.as_ref().to_vec(),
                        expires_at: None,
                    },
                )
            }));
        cache
    }

    /// Makes every subsequent `Cache` operation fail with `message`.
    pub fn enable_error_mode(&self, message: &str) {
        *self.error_mode.lock().unwrap() = Some(message.to_owned());
    }

    /// Turns error injection off again.
    pub fn disable_error_mode(&self) {
        *self.error_mode.lock().unwrap() = None;
    }

    /// Returns a snapshot of the operation counters.
    pub fn operation_counts(&self) -> OperationCounts {
        self.operation_count.lock().unwrap().clone()
    }

    /// Resets all operation counters to zero.
    pub fn reset_counts(&self) {
        *self.operation_count.lock().unwrap() = OperationCounts::default();
    }

    /// Number of stored entries that have not expired.
    pub fn len(&self) -> usize {
        let data = self.data.lock().unwrap();
        data.values().filter(|e| !e.is_expired()).count()
    }

    /// `true` when no unexpired entries remain.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every entry, expired or not. Error mode and counters are untouched.
    pub fn clear(&self) {
        self.data.lock().unwrap().clear();
    }

    /// Expires the entry for `key` immediately; does nothing if `key` is absent.
    ///
    /// The entry stays stored (so `set_nx_px` may overwrite it), but `get` no
    /// longer returns it and `len` no longer counts it.
    pub fn force_expire(&self, key: &[u8]) {
        let mut data = self.data.lock().unwrap();
        // `Vec<u8>: Borrow<[u8]>`, so the slice works as a lookup key without
        // allocating a temporary `Vec`.
        if let Some(entry) = data.get_mut(key) {
            // `is_expired` compares with `>=`, so a deadline of "now" is already
            // reached. This avoids `Instant - Duration`, which may panic when the
            // resulting instant is not representable.
            entry.expires_at = Some(Instant::now());
        }
    }
}

impl Default for MockCache {
    fn default() -> Self {
        Self::new()
    }
}
/// Deadline for an entry written now with time-to-live `ttl`.
///
/// A zero `ttl` gives a deadline of "now", so the entry is already expired when
/// written. A `ttl` too large to add to the current `Instant` (e.g.
/// `Duration::MAX`) gives `None` — the entry never expires — instead of
/// panicking on overflow.
fn expiry_deadline(ttl: Duration) -> Option<Instant> {
    Instant::now().checked_add(ttl)
}

impl Cache for MockCache {
    /// Stores `value` under `key` with expiry `ttl`, but only if no unexpired
    /// entry exists for `key`. Resolves to `true` if the write happened.
    ///
    /// An expired entry is treated as absent and is overwritten.
    fn set_nx_px(
        &self,
        key: &[u8],
        value: &[u8],
        ttl: Duration,
    ) -> impl Future<Output = anyhow::Result<bool>> + Send {
        // The future owns copies of the inputs plus clones of the shared handles.
        let key_vec = key.to_vec();
        let value_vec = value.to_vec();
        let data = Arc::clone(&self.data);
        let error_mode = Arc::clone(&self.error_mode);
        let operation_count = Arc::clone(&self.operation_count);
        async move {
            check_error_mode!(error_mode);
            operation_count.lock().unwrap().set_nx_px += 1;
            let mut data = data.lock().unwrap();
            if let Some(entry) = data.get(&key_vec)
                && !entry.is_expired()
            {
                return Ok(false);
            }
            let entry = MockEntry {
                value: value_vec,
                expires_at: expiry_deadline(ttl),
            };
            data.insert(key_vec, entry);
            Ok(true)
        }
    }

    /// Stores `value` under `key` with expiry `ttl`, replacing any existing entry.
    fn set(
        &self,
        key: &[u8],
        value: &[u8],
        ttl: Duration,
    ) -> impl Future<Output = anyhow::Result<()>> + Send {
        let key_vec = key.to_vec();
        let value_vec = value.to_vec();
        let data = Arc::clone(&self.data);
        let error_mode = Arc::clone(&self.error_mode);
        let operation_count = Arc::clone(&self.operation_count);
        async move {
            check_error_mode!(error_mode);
            operation_count.lock().unwrap().sets += 1;
            let entry = MockEntry {
                value: value_vec,
                expires_at: expiry_deadline(ttl),
            };
            data.lock().unwrap().insert(key_vec, entry);
            Ok(())
        }
    }

    /// Resolves to a copy of the value for `key`, or `None` if it is missing or expired.
    fn get(&self, key: &[u8]) -> impl Future<Output = anyhow::Result<Option<Vec<u8>>>> + Send {
        let key_vec = key.to_vec();
        let data = Arc::clone(&self.data);
        let error_mode = Arc::clone(&self.error_mode);
        let operation_count = Arc::clone(&self.operation_count);
        async move {
            check_error_mode!(error_mode);
            operation_count.lock().unwrap().gets += 1;
            let data = data.lock().unwrap();
            match data.get(&key_vec) {
                Some(entry) if !entry.is_expired() => Ok(Some(entry.value.clone())),
                _ => Ok(None),
            }
        }
    }

    /// Removes the entry for `key` if present; deleting a missing key is not an error.
    fn del(&self, key: &[u8]) -> impl Future<Output = anyhow::Result<()>> + Send {
        let key_vec = key.to_vec();
        let data = Arc::clone(&self.data);
        let error_mode = Arc::clone(&self.error_mode);
        let operation_count = Arc::clone(&self.operation_count);
        async move {
            check_error_mode!(error_mode);
            operation_count.lock().unwrap().deletes += 1;
            data.lock().unwrap().remove(&key_vec);
            Ok(())
        }
    }
}
/// In-memory mock of a set of strings, with an error-injection switch.
///
/// Both fields are `Arc<Mutex<..>>`, so clones share the same members and
/// error-mode state.
#[derive(Clone, Debug)]
pub struct MockSet {
    /// Current members.
    data: Arc<Mutex<std::collections::HashSet<String>>>,
    /// When `Some(msg)`, the async operations fail with an error carrying `msg`.
    error_mode: Arc<Mutex<Option<String>>>,
}

impl MockSet {
    /// Creates an empty set with error mode disabled.
    pub fn new() -> Self {
        Self {
            data: Arc::new(Mutex::new(std::collections::HashSet::new())),
            error_mode: Arc::new(Mutex::new(None)),
        }
    }

    /// Makes every subsequent async operation fail with `message`.
    pub fn enable_error_mode(&self, message: &str) {
        *self.error_mode.lock().unwrap() = Some(message.to_owned());
    }

    /// Turns error injection off again.
    pub fn disable_error_mode(&self) {
        *self.error_mode.lock().unwrap() = None;
    }

    /// Inserts every item of `items`; items already present are left as is.
    ///
    /// # Errors
    /// Fails with the injected message when error mode is enabled, even if
    /// `items` is empty.
    pub async fn add_items(&self, items: &[String]) -> anyhow::Result<()> {
        check_error_mode!(self.error_mode);
        if items.is_empty() {
            return Ok(());
        }
        self.data.lock().unwrap().extend(items.iter().cloned());
        Ok(())
    }

    /// Removes every item of `items`; items not present are ignored.
    ///
    /// # Errors
    /// Fails with the injected message when error mode is enabled, even if
    /// `items` is empty.
    pub async fn remove_items(&self, items: &[String]) -> anyhow::Result<()> {
        check_error_mode!(self.error_mode);
        if items.is_empty() {
            return Ok(());
        }
        let mut data = self.data.lock().unwrap();
        for item in items {
            data.remove(item);
        }
        Ok(())
    }

    /// Single-item convenience wrapper around [`MockSet::add_items`].
    pub async fn add_item(&self, item: &str) -> anyhow::Result<()> {
        self.add_items(&[item.to_owned()]).await
    }

    /// Single-item convenience wrapper around [`MockSet::remove_items`].
    pub async fn remove_item(&self, item: &str) -> anyhow::Result<()> {
        self.remove_items(&[item.to_owned()]).await
    }

    /// Returns a copy of all members, in unspecified (hash-set) order.
    ///
    /// # Errors
    /// Fails with the injected message when error mode is enabled.
    pub async fn load_items(&self) -> anyhow::Result<Vec<String>> {
        check_error_mode!(self.error_mode);
        let data = self.data.lock().unwrap();
        Ok(data.iter().cloned().collect())
    }

    /// Shrinks the set to at most `max_entries` members by evicting arbitrary items.
    ///
    /// A `max_entries` of `0` is a no-op: the set is left untouched.
    ///
    /// # Errors
    /// Fails with the injected message when error mode is enabled.
    pub async fn trim_to(&self, max_entries: usize) -> anyhow::Result<()> {
        check_error_mode!(self.error_mode);
        if max_entries == 0 {
            return Ok(());
        }
        let mut data = self.data.lock().unwrap();
        let mut excess = data.len().saturating_sub(max_entries);
        if excess == 0 {
            return Ok(());
        }
        // One O(n) pass that drops the first `excess` items visited. Popping the
        // "first" item in a loop would rescan the table on every iteration.
        data.retain(|_| {
            if excess == 0 {
                true
            } else {
                excess -= 1;
                false
            }
        });
        Ok(())
    }

    /// Number of members.
    pub fn len(&self) -> usize {
        self.data.lock().unwrap().len()
    }

    /// `true` when the set has no members.
    pub fn is_empty(&self) -> bool {
        self.data.lock().unwrap().is_empty()
    }

    /// Removes all members. Not subject to error mode.
    pub fn clear(&self) {
        self.data.lock().unwrap().clear();
    }

    /// `true` if `item` is a member. Not subject to error mode.
    pub fn contains(&self, item: &str) -> bool {
        self.data.lock().unwrap().contains(item)
    }
}

impl Default for MockSet {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_cache_basic_operations() {
        let cache = MockCache::new();
        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value1".to_vec()));
        let result = cache.get(b"nonexistent").await.unwrap();
        assert_eq!(result, None);
        cache.del(b"key1").await.unwrap();
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_mock_cache_set_nx_px() {
        let cache = MockCache::new();
        let was_set = cache
            .set_nx_px(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(was_set);
        let was_set = cache
            .set_nx_px(b"key1", b"value2", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(!was_set);
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value1".to_vec()));
    }

    /// An expired entry must not block `set_nx_px`.
    #[tokio::test]
    async fn test_mock_cache_set_nx_px_overwrites_expired() {
        let cache = MockCache::new();
        cache
            .set(b"key1", b"old", Duration::from_secs(60))
            .await
            .unwrap();
        cache.force_expire(b"key1");
        let was_set = cache
            .set_nx_px(b"key1", b"new", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(was_set);
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"new".to_vec()));
    }

    #[tokio::test]
    async fn test_mock_cache_error_mode() {
        let cache = MockCache::new();
        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        cache.enable_error_mode("Redis connection failed");
        let result = cache.get(b"key1").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("connection failed")
        );
        let result = cache.set(b"key2", b"value2", Duration::from_secs(60)).await;
        assert!(result.is_err());
        let result = cache
            .set_nx_px(b"key3", b"value3", Duration::from_secs(60))
            .await;
        assert!(result.is_err());
        let result = cache.del(b"key1").await;
        assert!(result.is_err());
        cache.disable_error_mode();
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value1".to_vec()));
    }

    /// Calls rejected by error mode are not counted.
    #[tokio::test]
    async fn test_mock_cache_error_mode_skips_counts() {
        let cache = MockCache::new();
        cache.enable_error_mode("boom");
        assert!(cache.get(b"k").await.is_err());
        assert!(cache.set(b"k", b"v", Duration::from_secs(60)).await.is_err());
        assert!(
            cache
                .set_nx_px(b"k", b"v", Duration::from_secs(60))
                .await
                .is_err()
        );
        assert!(cache.del(b"k").await.is_err());
        let counts = cache.operation_counts();
        assert_eq!(counts.gets, 0);
        assert_eq!(counts.sets, 0);
        assert_eq!(counts.set_nx_px, 0);
        assert_eq!(counts.deletes, 0);
    }

    #[tokio::test]
    async fn test_mock_cache_operation_counts() {
        let cache = MockCache::new();
        cache
            .set(b"k1", b"v1", Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set(b"k2", b"v2", Duration::from_secs(60))
            .await
            .unwrap();
        cache.get(b"k1").await.unwrap();
        cache.get(b"k2").await.unwrap();
        cache.get(b"k3").await.unwrap();
        cache
            .set_nx_px(b"k4", b"v4", Duration::from_secs(60))
            .await
            .unwrap();
        cache.del(b"k1").await.unwrap();
        let counts = cache.operation_counts();
        assert_eq!(counts.sets, 2);
        assert_eq!(counts.gets, 3);
        assert_eq!(counts.set_nx_px, 1);
        assert_eq!(counts.deletes, 1);
        cache.reset_counts();
        let counts = cache.operation_counts();
        assert_eq!(counts.sets, 0);
        assert_eq!(counts.gets, 0);
    }

    #[tokio::test]
    async fn test_mock_cache_ttl_expiration() {
        let cache = MockCache::new();
        cache
            .set(b"key1", b"value1", Duration::from_millis(50))
            .await
            .unwrap();
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, Some(b"value1".to_vec()));
        tokio::time::sleep(Duration::from_millis(100)).await;
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_mock_cache_force_expire() {
        let cache = MockCache::new();
        cache
            .set(b"key1", b"value1", Duration::from_secs(60))
            .await
            .unwrap();
        cache.force_expire(b"key1");
        let result = cache.get(b"key1").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_mock_cache_with_data() {
        let cache = MockCache::with_data([
            (b"key1".as_slice(), b"value1".as_slice()),
            (b"key2", b"value2"),
        ]);
        let result1 = cache.get(b"key1").await.unwrap();
        assert_eq!(result1, Some(b"value1".to_vec()));
        let result2 = cache.get(b"key2").await.unwrap();
        assert_eq!(result2, Some(b"value2".to_vec()));
    }

    #[tokio::test]
    async fn test_mock_cache_len_and_clear() {
        let cache = MockCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache
            .set(b"k1", b"v1", Duration::from_secs(60))
            .await
            .unwrap();
        cache
            .set(b"k2", b"v2", Duration::from_secs(60))
            .await
            .unwrap();
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    /// `len` counts only unexpired entries; expiring a missing key is a no-op.
    #[tokio::test]
    async fn test_mock_cache_len_ignores_expired() {
        let cache = MockCache::with_data([
            (b"a".as_slice(), b"1".as_slice()),
            (b"b".as_slice(), b"2".as_slice()),
        ]);
        assert_eq!(cache.len(), 2);
        cache.force_expire(b"a");
        assert_eq!(cache.len(), 1);
        cache.force_expire(b"missing");
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());
    }

    /// Clones share entries and counters.
    #[tokio::test]
    async fn test_mock_cache_clones_share_state() {
        let cache = MockCache::new();
        let other = cache.clone();
        other
            .set(b"k", b"v", Duration::from_secs(60))
            .await
            .unwrap();
        assert_eq!(cache.get(b"k").await.unwrap(), Some(b"v".to_vec()));
        assert_eq!(cache.operation_counts().sets, 1);
        assert_eq!(other.operation_counts().gets, 1);
    }

    #[tokio::test]
    async fn test_mock_set_basic_operations() {
        let set = MockSet::new();
        assert!(set.is_empty());
        set.add_item("item1").await.unwrap();
        set.add_item("item2").await.unwrap();
        assert_eq!(set.len(), 2);
        assert!(set.contains("item1"));
        assert!(set.contains("item2"));
        assert!(!set.contains("item3"));
        let items = set.load_items().await.unwrap();
        assert_eq!(items.len(), 2);
        set.remove_item("item1").await.unwrap();
        assert!(!set.contains("item1"));
        assert_eq!(set.len(), 1);
        set.clear();
        assert!(set.is_empty());
    }

    #[tokio::test]
    async fn test_mock_set_batch_operations() {
        let set = MockSet::new();
        set.add_items(&["a".to_owned(), "b".to_owned(), "c".to_owned()])
            .await
            .unwrap();
        assert_eq!(set.len(), 3);
        set.remove_items(&["a".to_owned(), "c".to_owned()])
            .await
            .unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains("b"));
    }

    #[tokio::test]
    async fn test_mock_set_trim_to() {
        let set = MockSet::new();
        for i in 0..10 {
            set.add_item(&format!("item{i}")).await.unwrap();
        }
        assert_eq!(set.len(), 10);
        let before: std::collections::HashSet<String> =
            set.load_items().await.unwrap().into_iter().collect();
        set.trim_to(5).await.unwrap();
        assert_eq!(set.len(), 5);
        // Survivors must all be original members.
        for item in set.load_items().await.unwrap() {
            assert!(before.contains(&item));
        }
        set.trim_to(0).await.unwrap();
        assert_eq!(set.len(), 5);
        // A limit above the current size leaves the set untouched.
        set.trim_to(100).await.unwrap();
        assert_eq!(set.len(), 5);
    }

    #[tokio::test]
    async fn test_mock_set_error_mode() {
        let set = MockSet::new();
        set.add_item("item1").await.unwrap();
        set.enable_error_mode("Redis error");
        assert!(set.add_item("item2").await.is_err());
        assert!(set.remove_item("item1").await.is_err());
        assert!(set.load_items().await.is_err());
        assert!(set.trim_to(1).await.is_err());
        set.disable_error_mode();
        assert!(set.load_items().await.is_ok());
    }

    #[tokio::test]
    async fn test_mock_set_empty_operations() {
        let set = MockSet::new();
        set.add_items(&[]).await.unwrap();
        set.remove_items(&[]).await.unwrap();
        assert!(set.is_empty());
    }
}