use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// A single cached value together with its expiry deadline.
#[derive(Clone, Debug)]
struct CacheEntry<V> {
    /// The cached payload handed back to callers on a hit.
    value: V,
    /// Point in time from which the entry is no longer considered live.
    expires_at: Instant,
}
/// Size-bounded, time-limited cache that can be shared between async tasks.
///
/// The storage lives behind an [`Arc`], so cloning a `QueryCache` produces
/// another handle to the *same* underlying entries rather than a copy of them.
#[derive(Clone)]
pub struct QueryCache<K, V>
where
    K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Upper bound on the number of keys kept at once; `0` disables caching.
    max_size: usize,
    /// Lifetime given to every entry at the moment it is inserted.
    ttl: Duration,
    /// Shared, lock-protected storage.
    inner: Arc<RwLock<CacheInner<K, V>>>,
}
/// Lock-protected state of a `QueryCache`.
///
/// `entries` and `keys` are expected to describe the same set of keys: the
/// map holds the data, the vector holds the eviction order.
#[derive(Debug)]
struct CacheInner<K: Clone + Eq + std::hash::Hash, V: Clone> {
    /// Cached values, addressed by key.
    entries: HashMap<K, CacheEntry<V>>,
    /// Keys in recency order: index 0 is evicted first, the last element was
    /// used most recently.
    keys: Vec<K>,
}
impl<K, V> QueryCache<K, V>
where
    K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Creates an empty cache.
    ///
    /// * `max_size` - maximum number of entries kept at once. With `0` the
    ///   cache stores nothing and every [`insert`](Self::insert) is a no-op.
    /// * `ttl` - how long an entry stays valid after it is inserted.
    pub fn new(max_size: usize, ttl: Duration) -> Self {
        Self {
            max_size,
            ttl,
            inner: Arc::new(RwLock::new(CacheInner {
                entries: HashMap::new(),
                keys: Vec::new(),
            })),
        }
    }

    /// Removes `key` from the recency list and hands back the stored key,
    /// or `None` when the key is not tracked.
    fn detach_key(keys: &mut Vec<K>, key: &K) -> Option<K> {
        let pos = keys.iter().position(|k| k == key)?;
        Some(keys.remove(pos))
    }

    /// Moves `key` to the most-recently-used end of the recency list.
    /// Does nothing when the key is not tracked.
    fn touch(keys: &mut Vec<K>, key: &K) {
        // Re-push the key we just removed instead of cloning the caller's.
        if let Some(owned) = Self::detach_key(keys, key) {
            keys.push(owned);
        }
    }

    /// Looks up `key` and returns a clone of its value if the entry is still
    /// live.
    ///
    /// A hit marks the key as most recently used. An entry whose deadline has
    /// passed is removed and reported as a miss. Because both outcomes mutate
    /// the cache, this always takes the write lock.
    pub async fn get(&self, key: &K) -> Option<V> {
        let mut guard = self.inner.write().await;
        // Reborrow as a plain `&mut` so the two fields can be borrowed
        // independently.
        let inner = &mut *guard;
        let now = Instant::now();

        let entry = inner.entries.get(key)?;
        if now < entry.expires_at {
            // Clone only on a live hit; misses and expired entries cost no clone.
            let value = entry.value.clone();
            Self::touch(&mut inner.keys, key);
            Some(value)
        } else {
            inner.entries.remove(key);
            Self::detach_key(&mut inner.keys, key);
            None
        }
    }

    /// Stores `value` under `key` with a fresh time-to-live.
    ///
    /// * If `key` is already cached, its value and deadline are replaced and
    ///   the key becomes the most recently used one. No other entry is
    ///   evicted, because the number of entries does not change.
    /// * Otherwise the least recently used entries are evicted until there is
    ///   room, and the new entry is added as the most recently used one.
    ///
    /// Does nothing when the cache was created with `max_size == 0`.
    ///
    /// # Panics
    ///
    /// May panic if `Instant::now() + ttl` cannot be represented by the
    /// platform's `Instant` (for example with a `ttl` of `Duration::MAX`).
    pub async fn insert(&self, key: K, value: V) {
        // Checked before locking: a disabled cache never needs the lock.
        if self.max_size == 0 {
            return;
        }

        let mut guard = self.inner.write().await;
        let inner = &mut *guard;
        let expires_at = Instant::now() + self.ttl;

        // Overwrite in place. This must come before the eviction loop:
        // otherwise updating a key in a full cache would throw out an
        // unrelated live entry and shrink the cache.
        if let Some(existing) = inner.entries.get_mut(&key) {
            existing.value = value;
            existing.expires_at = expires_at;
            Self::touch(&mut inner.keys, &key);
            return;
        }

        // `max_size >= 1` here, so `len >= max_size` implies the list is
        // non-empty and `remove(0)` cannot panic.
        while inner.keys.len() >= self.max_size {
            let evicted = inner.keys.remove(0);
            inner.entries.remove(&evicted);
        }

        inner.keys.push(key.clone());
        inner.entries.insert(key, CacheEntry { value, expires_at });
    }

    /// Removes `key` from the cache, if present.
    pub async fn invalidate(&self, key: &K) {
        let mut guard = self.inner.write().await;
        let inner = &mut *guard;
        inner.entries.remove(key);
        Self::detach_key(&mut inner.keys, key);
    }

    /// Removes every entry from the cache.
    pub async fn clear(&self) {
        let mut inner = self.inner.write().await;
        inner.entries.clear();
        inner.keys.clear();
    }

    /// Returns the number of stored entries.
    ///
    /// Expired entries are only purged when they are looked up, so the count
    /// may include entries that would no longer be returned by
    /// [`get`](Self::get).
    pub async fn len(&self) -> usize {
        let inner = self.inner.read().await;
        inner.entries.len()
    }

    /// Returns `true` when no entries are stored (see [`len`](Self::len) for
    /// how expired entries are counted).
    pub async fn is_empty(&self) -> bool {
        let inner = self.inner.read().await;
        inner.entries.is_empty()
    }
}
/// Manual `Debug` that prints the configuration only; the stored entries are
/// replaced by a fixed placeholder, so formatting never touches the lock.
impl<K, V> std::fmt::Debug for QueryCache<K, V>
where
    K: Clone + Eq + std::hash::Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("QueryCache");
        out.field("max_size", &self.max_size);
        out.field("ttl", &self.ttl);
        // Placeholder instead of the real contents.
        out.field("inner", &"<cache>");
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Cache flavour used by every test: string keys mapped to string values.
    type StringCache = QueryCache<String, String>;

    /// TTL long enough that nothing expires while a test is running.
    const LONG_TTL: Duration = Duration::from_secs(60);

    /// Owned `String` from a literal.
    fn s(text: &str) -> String {
        text.to_string()
    }

    /// Name of the `n`-th numbered key.
    fn key(n: usize) -> String {
        format!("key{}", n)
    }

    /// Value stored under the `n`-th numbered key.
    fn value(n: usize) -> String {
        format!("value{}", n)
    }

    /// A stored value can be read back.
    #[tokio::test]
    async fn test_cache_insert_get() {
        let cache = StringCache::new(10, LONG_TTL);
        cache.insert(s("key1"), s("value1")).await;
        assert_eq!(cache.get(&s("key1")).await, Some(s("value1")));
    }

    /// Looking up an unknown key is a miss.
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = StringCache::new(10, LONG_TTL);
        let found: Option<String> = cache.get(&s("nonexistent")).await;
        assert!(found.is_none());
    }

    /// An entry is no longer returned once its TTL has elapsed.
    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = StringCache::new(10, Duration::from_millis(50));
        cache.insert(s("key"), s("value")).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(cache.get(&s("key")).await.is_none());
    }

    /// Inserting beyond capacity drops the oldest key.
    #[tokio::test]
    async fn test_cache_eviction() {
        let cache = StringCache::new(2, LONG_TTL);
        for n in 1..=3 {
            cache.insert(key(n), value(n)).await;
        }
        assert_eq!(cache.len().await, 2);
        assert!(cache.get(&key(1)).await.is_none());
        for n in 2..=3 {
            assert_eq!(cache.get(&key(n)).await, Some(value(n)));
        }
    }

    /// `invalidate` removes exactly the requested key.
    #[tokio::test]
    async fn test_cache_invalidate() {
        let cache = StringCache::new(10, LONG_TTL);
        cache.insert(s("key"), s("value")).await;
        cache.invalidate(&s("key")).await;
        assert!(cache.get(&s("key")).await.is_none());
        assert_eq!(cache.len().await, 0);
    }

    /// `clear` empties the cache.
    #[tokio::test]
    async fn test_cache_clear() {
        let cache = StringCache::new(10, LONG_TTL);
        for n in 1..=2 {
            cache.insert(key(n), value(n)).await;
        }
        cache.clear().await;
        assert!(cache.is_empty().await);
    }

    /// Reading a key protects it from the next eviction.
    #[tokio::test]
    async fn test_cache_lru_touch() {
        let cache = StringCache::new(3, LONG_TTL);
        for n in 1..=3 {
            cache.insert(key(n), value(n)).await;
        }
        // Touch key1 so that key2 becomes the eviction candidate.
        let _ = cache.get(&key(1)).await;
        cache.insert(key(4), value(4)).await;

        assert_eq!(cache.len().await, 3);
        assert!(cache.get(&key(2)).await.is_none());
        for n in [1, 3, 4] {
            assert_eq!(cache.get(&key(n)).await, Some(value(n)));
        }
    }

    /// Re-inserting a key replaces the value and restarts its TTL.
    #[tokio::test]
    async fn test_cache_update_existing() {
        let cache = StringCache::new(10, Duration::from_millis(100));
        cache.insert(s("key"), s("valueA")).await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        cache.insert(s("key"), s("valueB")).await;
        assert_eq!(cache.get(&s("key")).await, Some(s("valueB")));

        // 110 ms after the first insert, but only 60 ms after the second.
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert_eq!(cache.get(&s("key")).await, Some(s("valueB")));
    }

    /// Clones of the cache share storage across spawned tasks.
    #[tokio::test]
    async fn test_cache_concurrent_access() {
        let cache = StringCache::new(20, LONG_TTL);
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = cache.clone();
                tokio::spawn(async move {
                    shared.insert(key(i), value(i)).await;
                })
            })
            .collect();
        for worker in workers {
            worker.await.unwrap();
        }

        assert_eq!(cache.len().await, 10);
        for i in 0..10 {
            assert_eq!(cache.get(&key(i)).await, Some(value(i)));
        }
    }

    /// After many inserts only the newest `max_size` keys survive.
    #[tokio::test]
    async fn test_cache_stress_eviction() {
        let cache = StringCache::new(5, LONG_TTL);
        for i in 0..100 {
            cache.insert(key(i), value(i)).await;
        }

        assert_eq!(cache.len().await, 5);
        for gone in [0, 94] {
            assert!(cache.get(&key(gone)).await.is_none());
        }
        for kept in 95..=99 {
            assert_eq!(cache.get(&key(kept)).await, Some(value(kept)));
        }
    }

    /// A cache with capacity zero never stores anything.
    #[tokio::test]
    async fn test_cache_zero_max_size() {
        let cache = StringCache::new(0, LONG_TTL);
        for n in 1..=2 {
            cache.insert(key(n), value(n)).await;
        }
        assert_eq!(cache.len().await, 0);
        assert!(cache.is_empty().await);
        for n in 1..=2 {
            assert!(cache.get(&key(n)).await.is_none());
        }
    }

    /// An expired entry is purged by the lookup that discovers it.
    #[tokio::test]
    async fn test_cache_ttl_expiration() {
        let cache = StringCache::new(10, Duration::from_millis(100));
        cache.insert(s("key"), s("value")).await;
        assert_eq!(cache.len().await, 1);
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(cache.get(&s("key")).await.is_none());
        assert_eq!(cache.len().await, 0);
    }
}