use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
/// Fixed-size cache key: the 32-byte BLAKE3 digest of the data that identifies a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CacheKey([u8; 32]);

impl CacheKey {
    /// Derives a key by hashing `data` with BLAKE3.
    pub fn from_bytes(data: &[u8]) -> Self {
        Self(*blake3::hash(data).as_bytes())
    }

    /// Derives a key from a query type and its serialized parameters.
    ///
    /// The hasher is fed `query_type`, a single `:` byte, and then `params`, in that order.
    // NOTE(review): the `:` separator is neither escaped nor length-prefixed, so
    // ("a:b", "c") and ("a", "b:c") feed identical bytes to the hasher and produce
    // the same key. Harmless while query types never contain ':' — confirm with callers.
    pub fn from_query(query_type: &str, params: &[u8]) -> Self {
        let digest = blake3::Hasher::new()
            .update(query_type.as_bytes())
            .update(b":")
            .update(params)
            .finalize();
        Self(*digest.as_bytes())
    }

    /// Borrows the raw 32-byte digest backing this key.
    pub fn as_bytes(&self) -> &[u8; 32] {
        let Self(raw) = self;
        raw
    }
}
/// One cached value plus the bookkeeping stored next to it.
struct CacheEntry {
    /// The cached payload bytes.
    result: Vec<u8>,
    /// Instant from which `ttl` is measured.
    created_at: Instant,
    /// Maximum age; the entry counts as expired once its age exceeds this.
    ttl: Duration,
    /// Per-entry access counter.
    access_count: AtomicU64,
    /// Instant of the most recent access recorded for this entry.
    last_accessed: Instant,
    /// Size of the payload in bytes, recorded when the entry is built.
    size_bytes: usize,
    /// Optional collection name the entry is grouped under.
    collection: Option<String>,
}

impl CacheEntry {
    /// Returns `true` once the time elapsed since `created_at` is strictly greater than `ttl`.
    fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed();
        age > self.ttl
    }
}
/// Live cache counters, stored as atomics so they can be bumped through a shared reference.
pub struct CacheStats {
    /// Number of lookups that returned a value.
    hits: AtomicU64,
    /// Number of lookups that returned nothing.
    misses: AtomicU64,
    /// Number of entries removed from the cache.
    evictions: AtomicU64,
    /// Number of entries stored in the cache.
    insertions: AtomicU64,
}

impl CacheStats {
    /// Creates a counter set with every counter at zero.
    fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            hits: zero(),
            misses: zero(),
            evictions: zero(),
            insertions: zero(),
        }
    }

    /// Copies the current counter values into a plain-data snapshot.
    ///
    /// Each counter is read with a separate relaxed load, so the four values are
    /// not captured as one atomic unit.
    pub fn snapshot(&self) -> CacheStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        let hits = read(&self.hits);
        let misses = read(&self.misses);
        let evictions = read(&self.evictions);
        let insertions = read(&self.insertions);
        CacheStatsSnapshot {
            hits,
            misses,
            evictions,
            insertions,
        }
    }
}
/// Plain-data copy of the cache counters taken at one point in time.
#[derive(Debug, Clone)]
pub struct CacheStatsSnapshot {
    /// Lookups that returned a value.
    pub hits: u64,
    /// Lookups that returned nothing.
    pub misses: u64,
    /// Entries removed from the cache.
    pub evictions: u64,
    /// Entries stored in the cache.
    pub insertions: u64,
}

impl CacheStatsSnapshot {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding: `hits + misses` in `u64` would panic on overflow in
        // debug builds and wrap in release builds. For sums that fit in `u64` the
        // result is identical to the narrow computation.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    /// Estimated number of live entries: `insertions - evictions`, clamped at zero.
    pub fn approx_size(&self) -> u64 {
        self.insertions.saturating_sub(self.evictions)
    }
}
struct CacheInner {
entries: HashMap<CacheKey, CacheEntry>,
lru_order: Vec<CacheKey>,
collection_index: HashMap<String, Vec<CacheKey>>,
}
impl CacheInner {
fn new() -> Self {
Self {
entries: HashMap::new(),
lru_order: Vec::new(),
collection_index: HashMap::new(),
}
}
fn touch(&mut self, key: &CacheKey) {
if let Some(pos) = self.lru_order.iter().position(|k| k == key) {
self.lru_order.remove(pos);
}
self.lru_order.push(*key);
}
fn evict_lru(&mut self) -> Option<CacheKey> {
if self.lru_order.is_empty() {
return None;
}
let key = self.lru_order.remove(0);
self.remove_entry_inner(&key);
Some(key)
}
fn remove_entry(&mut self, key: &CacheKey) {
if let Some(pos) = self.lru_order.iter().position(|k| k == key) {
self.lru_order.remove(pos);
}
self.remove_entry_inner(key);
}
fn remove_entry_inner(&mut self, key: &CacheKey) {
if let Some(entry) = self.entries.remove(key) {
if let Some(ref coll) = entry.collection {
if let Some(keys) = self.collection_index.get_mut(coll) {
keys.retain(|k| k != key);
if keys.is_empty() {
self.collection_index.remove(coll);
}
}
}
}
}
}
/// Bounded in-memory cache of serialized query results.
///
/// All mutable state sits behind one [`RwLock`]. Entries are evicted in
/// least-recently-used order when the cache is full, and an entry whose TTL has
/// elapsed is removed the next time it is looked up.
pub struct QueryCache {
    /// Entry map, recency list, and collection index, guarded together.
    inner: RwLock<CacheInner>,
    /// Current capacity; atomic so `resize` can change it through `&self`.
    max_entries: AtomicUsize,
    /// TTL applied by [`QueryCache::put`].
    default_ttl: Duration,
    /// Values longer than this many bytes are never stored.
    max_value_size: usize,
    /// Hit / miss / eviction / insertion counters.
    stats: CacheStats,
}

impl QueryCache {
    /// Creates an empty cache.
    ///
    /// * `max_entries` - maximum number of entries held at once; `0` means nothing is stored.
    /// * `default_ttl` - TTL used by [`QueryCache::put`].
    /// * `max_value_size` - largest value, in bytes, that will be stored.
    pub fn new(max_entries: usize, default_ttl: Duration, max_value_size: usize) -> Self {
        Self {
            inner: RwLock::new(CacheInner::new()),
            max_entries: AtomicUsize::new(max_entries),
            default_ttl,
            max_value_size,
            stats: CacheStats::new(),
        }
    }

    /// Looks up `key` and returns a copy of the cached bytes on a hit.
    ///
    /// An expired entry is removed and reported as a miss (it also counts as an
    /// eviction). A hit moves the key to the most-recently-used position, which is
    /// why the write lock is taken even though this is a read operation.
    pub fn get(&self, key: &CacheKey) -> Option<Vec<u8>> {
        let mut inner = self.inner.write();

        // Single map lookup: `None` below means "present but expired".
        let fresh = match inner.entries.get_mut(key) {
            None => {
                self.stats.misses.fetch_add(1, Ordering::Relaxed);
                return None;
            }
            Some(entry) if entry.is_expired() => None,
            Some(entry) => {
                entry.access_count.fetch_add(1, Ordering::Relaxed);
                entry.last_accessed = Instant::now();
                Some(entry.result.clone())
            }
        };

        match fresh {
            Some(result) => {
                inner.touch(key);
                self.stats.hits.fetch_add(1, Ordering::Relaxed);
                Some(result)
            }
            None => {
                inner.remove_entry(key);
                self.stats.misses.fetch_add(1, Ordering::Relaxed);
                self.stats.evictions.fetch_add(1, Ordering::Relaxed);
                None
            }
        }
    }

    /// Stores `result` under `key` with the default TTL and no collection tag.
    pub fn put(&self, key: CacheKey, result: Vec<u8>) {
        self.put_with_options(key, result, self.default_ttl, None);
    }

    /// Stores `result` under `key` with an explicit `ttl` and no collection tag.
    pub fn put_with_ttl(&self, key: CacheKey, result: Vec<u8>, ttl: Duration) {
        self.put_with_options(key, result, ttl, None);
    }

    /// Stores `result` under `key` with an explicit `ttl` and optional `collection` tag.
    ///
    /// Nothing is stored (and no counter changes) when `result` is longer than
    /// `max_value_size` or when the cache capacity is zero. An existing entry under
    /// the same key is replaced; that replacement is not counted as an eviction.
    /// Least-recently-used entries are evicted as needed to make room.
    pub fn put_with_options(
        &self,
        key: CacheKey,
        result: Vec<u8>,
        ttl: Duration,
        collection: Option<&str>,
    ) {
        if result.len() > self.max_value_size {
            return;
        }
        let size_bytes = result.len();
        let now = Instant::now();
        let entry = CacheEntry {
            result,
            created_at: now,
            ttl,
            access_count: AtomicU64::new(0),
            last_accessed: now,
            size_bytes,
            collection: collection.map(String::from),
        };

        let mut inner = self.inner.write();
        // Read the capacity under the lock so it is ordered against `resize`.
        let max = self.max_entries.load(Ordering::Relaxed);
        if max == 0 {
            // A zero-capacity cache must stay empty; inserting here would leave
            // `len()` above `max_entries`.
            return;
        }

        // Cheap hash check first: `remove_entry` scans the recency list linearly.
        if inner.entries.contains_key(&key) {
            inner.remove_entry(&key);
        }
        // Leave room for exactly one new entry.
        self.evict_until_at_most(&mut inner, max - 1);

        if let Some(ref coll) = entry.collection {
            inner
                .collection_index
                .entry(coll.clone())
                .or_default()
                .push(key);
        }
        inner.entries.insert(key, entry);
        inner.lru_order.push(key);
        self.stats.insertions.fetch_add(1, Ordering::Relaxed);
    }

    /// Removes every entry tagged with `collection`, counting each as an eviction.
    ///
    /// Does nothing when no entry carries that tag.
    pub fn invalidate(&self, collection: &str) {
        let mut inner = self.inner.write();
        let keys = match inner.collection_index.remove(collection) {
            Some(keys) => keys,
            None => return,
        };

        // One pass over the recency list instead of a linear search per key.
        let doomed: std::collections::HashSet<CacheKey> = keys.into_iter().collect();
        inner.lru_order.retain(|k| !doomed.contains(k));

        let mut evicted = 0u64;
        for key in &doomed {
            if inner.entries.remove(key).is_some() {
                evicted += 1;
            }
        }
        self.stats.evictions.fetch_add(evicted, Ordering::Relaxed);
    }

    /// Removes every entry, counting each as an eviction.
    pub fn invalidate_all(&self) {
        let mut inner = self.inner.write();
        let evicted = inner.entries.len() as u64;
        inner.entries.clear();
        inner.lru_order.clear();
        inner.collection_index.clear();
        self.stats.evictions.fetch_add(evicted, Ordering::Relaxed);
    }

    /// Returns a snapshot of the cache counters.
    pub fn stats(&self) -> CacheStatsSnapshot {
        self.stats.snapshot()
    }

    /// Changes the capacity to `new_max`, evicting least-recently-used entries
    /// until at most `new_max` remain.
    pub fn resize(&self, new_max: usize) {
        let mut inner = self.inner.write();
        // Publish the new limit while still holding the write lock. Storing it
        // after the lock is released would let a concurrent `put` evict against
        // the old, larger limit and leave the cache over capacity.
        self.max_entries.store(new_max, Ordering::SeqCst);
        self.evict_until_at_most(&mut inner, new_max);
    }

    /// Number of entries currently stored (expired entries not yet looked up included).
    pub fn len(&self) -> usize {
        let inner = self.inner.read();
        inner.entries.len()
    }

    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Sum of the recorded payload sizes of all stored entries, in bytes.
    pub fn total_size_bytes(&self) -> usize {
        let inner = self.inner.read();
        inner.entries.values().map(|e| e.size_bytes).sum()
    }

    /// Evicts least-recently-used entries until at most `limit` remain, counting
    /// each eviction. Stops early if the recency list runs out.
    fn evict_until_at_most(&self, inner: &mut CacheInner, limit: usize) {
        while inner.entries.len() > limit {
            if inner.evict_lru().is_none() {
                break;
            }
            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
        }
    }
}
/// Debug output shows the configuration, the current entry count, and a
/// counter snapshot; cached payloads are not printed.
impl std::fmt::Debug for QueryCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Take the counter snapshot first, then the entry count (which briefly
        // takes the read lock).
        let counters = self.stats.snapshot();
        let entry_count = self.len();

        let mut out = f.debug_struct("QueryCache");
        out.field("max_entries", &self.max_entries);
        out.field("default_ttl", &self.default_ttl);
        out.field("max_value_size", &self.max_value_size);
        out.field("len", &entry_count);
        out.field("stats", &counters);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
// Unit tests covering `QueryCache` behavior (hits, misses, TTL expiry, LRU
// eviction, invalidation, resizing, statistics) and `CacheKey` derivation.
use super::*;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
// Builds a cache with the given capacity, a 60 s default TTL, and a 1 MiB value limit.
fn test_cache(max_entries: usize) -> QueryCache {
QueryCache::new(max_entries, Duration::from_secs(60), 1024 * 1024)
}
// A stored value comes back unchanged from a subsequent `get`.
#[test]
fn test_cache_put_get() {
let cache = test_cache(100);
let key = CacheKey::from_bytes(b"select * from users");
cache.put(key, vec![1, 2, 3, 4]);
let result = cache.get(&key);
assert!(result.is_some());
assert_eq!(result.expect("should have value"), vec![1, 2, 3, 4]);
}
// Looking up an unknown key returns `None` and records exactly one miss.
#[test]
fn test_cache_miss() {
let cache = test_cache(100);
let key = CacheKey::from_bytes(b"nonexistent query");
let result = cache.get(&key);
assert!(result.is_none());
let snap = cache.stats();
assert_eq!(snap.hits, 0);
assert_eq!(snap.misses, 1);
}
// With a 50 ms default TTL, the entry is readable immediately but is gone after
// sleeping 100 ms; the expired lookup counts as one miss and one eviction.
#[test]
fn test_cache_ttl_expiry() {
let cache = QueryCache::new(100, Duration::from_millis(50), 1024 * 1024);
let key = CacheKey::from_bytes(b"expiring query");
cache.put(key, vec![10, 20]);
assert!(cache.get(&key).is_some());
thread::sleep(Duration::from_millis(100));
assert!(cache.get(&key).is_none());
let snap = cache.stats();
assert_eq!(snap.hits, 1);
assert_eq!(snap.misses, 1);
assert_eq!(snap.evictions, 1); }
// Five successful lookups of the same key record five hits and no misses.
#[test]
fn test_cache_hit_updates_stats() {
let cache = test_cache(100);
let key = CacheKey::from_bytes(b"stats query");
cache.put(key, vec![1]);
for _ in 0..5 {
let _ = cache.get(&key);
}
let snap = cache.stats();
assert_eq!(snap.hits, 5);
assert_eq!(snap.misses, 0);
}
// Three lookups of distinct unknown keys record three misses and no hits.
#[test]
fn test_cache_miss_updates_stats() {
let cache = test_cache(100);
for i in 0..3u8 {
let key = CacheKey::from_bytes(&[i]);
let _ = cache.get(&key);
}
let snap = cache.stats();
assert_eq!(snap.hits, 0);
assert_eq!(snap.misses, 3);
}
// Capacity 3: after filling the cache, reading key[0] makes key[1] the least
// recently used, so inserting a fourth key evicts key[1] and keeps the rest.
#[test]
fn test_cache_lru_eviction() {
let cache = test_cache(3);
let keys: Vec<CacheKey> = (0..3u8).map(|i| CacheKey::from_bytes(&[i])).collect();
for (i, key) in keys.iter().enumerate() {
cache.put(*key, vec![i as u8]);
}
assert_eq!(cache.len(), 3);
let _ = cache.get(&keys[0]);
let key3 = CacheKey::from_bytes(&[3u8]);
cache.put(key3, vec![3]);
assert_eq!(cache.len(), 3);
assert!(
cache.get(&keys[0]).is_some(),
"key[0] was accessed and should survive"
);
assert!(
cache.get(&keys[1]).is_none(),
"key[1] should have been evicted"
);
assert!(
cache.get(&keys[2]).is_some(),
"key[2] should still be present"
);
assert!(cache.get(&key3).is_some(), "key[3] was just inserted");
let snap = cache.stats();
assert!(snap.evictions >= 1);
}
// Invalidating the "users" collection removes both entries tagged with it and
// leaves the entry tagged "orders" in place.
#[test]
fn test_cache_invalidate_collection() {
let cache = test_cache(100);
let k1 = CacheKey::from_query("filter", b"users:age>18");
let k2 = CacheKey::from_query("get", b"users:id=1");
let k3 = CacheKey::from_query("filter", b"orders:total>100");
cache.put_with_options(k1, vec![1], Duration::from_secs(60), Some("users"));
cache.put_with_options(k2, vec![2], Duration::from_secs(60), Some("users"));
cache.put_with_options(k3, vec![3], Duration::from_secs(60), Some("orders"));
assert_eq!(cache.len(), 3);
cache.invalidate("users");
assert_eq!(cache.len(), 1);
assert!(cache.get(&k1).is_none());
assert!(cache.get(&k2).is_none());
assert!(cache.get(&k3).is_some(), "orders entry should remain");
}
// `invalidate_all` empties the cache and counts every removed entry as an eviction.
#[test]
fn test_cache_invalidate_all() {
let cache = test_cache(100);
for i in 0..10u8 {
let key = CacheKey::from_bytes(&[i]);
cache.put(key, vec![i]);
}
assert_eq!(cache.len(), 10);
cache.invalidate_all();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
let snap = cache.stats();
assert_eq!(snap.evictions, 10);
}
// Three hits and one miss give a hit rate of 0.75; a cache with no lookups reports 0.0.
#[test]
fn test_cache_hit_rate() {
let cache = test_cache(100);
let key = CacheKey::from_bytes(b"rate query");
cache.put(key, vec![1]);
for _ in 0..3 {
let _ = cache.get(&key);
}
let missing = CacheKey::from_bytes(b"no such key");
let _ = cache.get(&missing);
let snap = cache.stats();
assert!((snap.hit_rate() - 0.75).abs() < 1e-9);
let empty_cache = test_cache(10);
let snap = empty_cache.stats();
assert!((snap.hit_rate() - 0.0).abs() < f64::EPSILON);
}
// Four writer threads and four reader threads share one cache of capacity 500;
// no thread may panic and the cache must not exceed its capacity afterwards.
#[test]
fn test_cache_concurrent_access() {
let cache = Arc::new(test_cache(500));
let mut handles = Vec::new();
for t in 0..4 {
let c = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..200u64 {
let key_bytes = format!("thread-{}-key-{}", t, i);
let key = CacheKey::from_bytes(key_bytes.as_bytes());
c.put(key, vec![t as u8; 64]);
}
}));
}
for t in 0..4 {
let c = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..200u64 {
let key_bytes = format!("thread-{}-key-{}", t, i);
let key = CacheKey::from_bytes(key_bytes.as_bytes());
let _ = c.get(&key);
}
}));
}
for h in handles {
h.join().expect("thread should not panic");
}
let snap = cache.stats();
assert!(snap.insertions > 0);
assert!(cache.len() <= 500);
}
// With a 100-byte value limit, a 100-byte value is stored while a 101-byte value
// is rejected and not counted as an insertion.
#[test]
fn test_cache_max_value_size() {
let cache = QueryCache::new(100, Duration::from_secs(60), 100);
let k1 = CacheKey::from_bytes(b"small");
cache.put(k1, vec![0u8; 100]);
assert!(cache.get(&k1).is_some());
let k2 = CacheKey::from_bytes(b"big");
cache.put(k2, vec![0u8; 101]);
assert!(cache.get(&k2).is_none());
let snap = cache.stats();
assert_eq!(snap.insertions, 1); }
// Shrinking a full cache from 10 to 5 evicts five entries, and later inserts
// keep the entry count at or below the new capacity.
#[test]
fn test_cache_resize() {
let cache = test_cache(10);
for i in 0..10u8 {
let key = CacheKey::from_bytes(&[i]);
cache.put(key, vec![i]);
}
assert_eq!(cache.len(), 10);
cache.resize(5);
assert_eq!(cache.len(), 5);
let snap = cache.stats();
assert_eq!(snap.evictions, 5);
for i in 100..106u8 {
let key = CacheKey::from_bytes(&[i]);
cache.put(key, vec![i]);
}
assert!(cache.len() <= 5);
}
// Key derivation is deterministic: identical inputs produce equal keys.
#[test]
fn test_cache_key_generation() {
let k1 = CacheKey::from_query("filter", b"users:age>18");
let k2 = CacheKey::from_query("filter", b"users:age>18");
assert_eq!(k1, k2, "same query should produce the same key");
let k3 = CacheKey::from_bytes(b"hello world");
let k4 = CacheKey::from_bytes(b"hello world");
assert_eq!(k3, k4);
}
// Changing either the parameters or the query type produces a different key.
#[test]
fn test_cache_different_queries() {
let k1 = CacheKey::from_query("filter", b"users:age>18");
let k2 = CacheKey::from_query("filter", b"users:age>21");
assert_ne!(k1, k2, "different params should produce different keys");
let k3 = CacheKey::from_query("filter", b"users:age>18");
let k4 = CacheKey::from_query("get", b"users:age>18");
assert_ne!(
k3, k4,
"different query types should produce different keys"
);
}
// `total_size_bytes` reports the summed payload sizes of the stored entries.
#[test]
fn test_total_size_bytes() {
let cache = test_cache(100);
let k1 = CacheKey::from_bytes(b"a");
let k2 = CacheKey::from_bytes(b"b");
cache.put(k1, vec![0u8; 100]);
cache.put(k2, vec![0u8; 200]);
assert_eq!(cache.total_size_bytes(), 300);
}
// A per-entry 50 ms TTL overrides the 300 s default: the entry is gone after 100 ms.
#[test]
fn test_put_with_custom_ttl() {
let cache = QueryCache::new(100, Duration::from_secs(300), 1024 * 1024);
let key = CacheKey::from_bytes(b"short lived");
cache.put_with_ttl(key, vec![1, 2], Duration::from_millis(50));
assert!(cache.get(&key).is_some());
thread::sleep(Duration::from_millis(100));
assert!(cache.get(&key).is_none());
}
// The `Debug` output names the type and includes the `max_entries` field.
#[test]
fn test_debug_format() {
let cache = test_cache(10);
let dbg = format!("{:?}", cache);
assert!(dbg.contains("QueryCache"));
assert!(dbg.contains("max_entries"));
}
}