use parking_lot::RwLock;
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Tunable settings for a `QueryCache`.
///
/// Start from `QueryCacheConfig::default()` and adjust individual fields with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct QueryCacheConfig {
/// Number of entries the cache may hold; inserting beyond this evicts the least
/// recently used entries first.
pub max_entries: usize,
/// How long an entry stays valid after insertion; lookups of older entries are misses.
pub ttl: Duration,
/// Largest value, in bytes, that will be stored; larger values are not cached.
pub max_value_size: usize,
/// Policy value reported by `QueryCache::invalidation_policy`. The code in this file
/// only stores and returns it.
pub invalidation_policy: InvalidationPolicy,
}
impl Default for QueryCacheConfig {
    /// Returns the stock configuration: 1000 entries, a 60-second TTL, a 1 MiB
    /// per-value limit, and `InvalidationPolicy::OnWrite`.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        let one_mebibyte = 1024 * 1024;
        Self {
            invalidation_policy: InvalidationPolicy::OnWrite,
            max_value_size: one_mebibyte,
            ttl: one_minute,
            max_entries: 1000,
        }
    }
}
impl QueryCacheConfig {
    /// Returns this configuration with the entry capacity set to `max_entries`.
    #[must_use]
    pub fn with_max_entries(self, max_entries: usize) -> Self {
        Self { max_entries, ..self }
    }

    /// Returns this configuration with the entry time-to-live set to `ttl`.
    #[must_use]
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Returns this configuration with the per-value byte limit set to `max_value_size`.
    #[must_use]
    pub fn with_max_value_size(self, max_value_size: usize) -> Self {
        Self {
            max_value_size,
            ..self
        }
    }

    /// Returns this configuration with the invalidation policy set to `policy`.
    #[must_use]
    pub fn with_invalidation_policy(self, policy: InvalidationPolicy) -> Self {
        Self {
            invalidation_policy: policy,
            ..self
        }
    }
}
/// Invalidation strategy selectable through `QueryCacheConfig::invalidation_policy`.
///
/// The default configuration uses `OnWrite`.
///
/// NOTE(review): nothing in this file branches on the variant — the cache only stores
/// and returns it — so the per-variant notes below are inferred from the names and
/// should be confirmed against the callers that consult the policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InvalidationPolicy {
/// Presumably: callers invalidate cached results when the underlying data is written.
OnWrite,
/// Presumably: entries are only invalidated by explicit `invalidate*` calls.
Manual,
/// Presumably: no invalidation is requested at all — TODO confirm.
None,
}
/// Snapshot of the cache counters, as returned by `QueryCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Number of `get` calls that returned a cached value.
pub hits: u64,
/// Number of `get` calls that found no entry or found an expired one.
pub misses: u64,
/// Number of entries removed by least-recently-used eviction.
pub evictions: u64,
/// Number of entries currently stored, as tracked by the cache's own bookkeeping.
pub size: usize,
/// Number of values accepted by `put`/`put_with_collection`, overwrites included.
pub total_inserts: u64,
/// Number of entries removed by `invalidate`, `invalidate_collection`, or `clear`.
pub invalidations: u64,
}
impl CacheStats {
    /// Returns the fraction of recorded lookups that were hits.
    ///
    /// The result is `hits / (hits + misses)`, or `0.0` when no lookups have been
    /// recorded yet.
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding: `hits + misses` in `u64` panics in debug builds (and
        // wraps in release builds) once the two counters together exceed `u64::MAX`.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
/// A cached value together with the bookkeeping stored alongside it.
#[derive(Debug, Clone)]
struct CachedResult {
/// The cached bytes; a clone is handed back on every hit.
data: Vec<u8>,
/// When the entry was stored; its age is compared against the configured TTL on lookup.
inserted_at: Instant,
/// How many times this entry has been returned by `get`.
hit_count: u64,
/// Collection the entry was stored under, if any; used to keep the collection index
/// in sync when the entry is removed.
collection: Option<String>,
}
/// Links of one entry in the recency list, expressed as the keys of its neighbours
/// rather than as pointers.
#[derive(Debug, Clone)]
struct LruNode {
/// Neighbour closer to the head (more recently used); `None` when this entry is the
/// head or is not linked.
prev: Option<CacheKey>,
/// Neighbour closer to the tail (less recently used); `None` when this entry is the
/// tail or is not linked.
next: Option<CacheKey>,
}
/// Fixed-size map key: the 32-byte BLAKE3 digest of the caller-supplied key bytes
/// (see `QueryCache::hash_key`).
type CacheKey = [u8; 32];
/// Mutable cache state that `QueryCache` keeps behind its lock.
struct CacheInner {
/// Cached values by hashed key, each paired with its recency-list links.
entries: HashMap<CacheKey, (CachedResult, LruNode)>,
/// Most recently used entry, or `None` when the list is empty.
head: Option<CacheKey>,
/// Least recently used entry — the next eviction candidate — or `None` when empty.
tail: Option<CacheKey>,
/// Keys stored under each collection name; consulted by `invalidate_collection`.
collection_index: HashMap<String, Vec<CacheKey>>,
/// Running counters reported through `QueryCache::stats`.
stats: CacheStats,
}
impl CacheInner {
    /// Creates an empty state: no entries, an empty recency list, all counters zero.
    fn new() -> Self {
        Self {
            entries: HashMap::new(),
            head: None,
            tail: None,
            collection_index: HashMap::new(),
            stats: CacheStats {
                hits: 0,
                misses: 0,
                evictions: 0,
                size: 0,
                total_inserts: 0,
                invalidations: 0,
            },
        }
    }

    /// Unlinks `key` from the recency list without removing it from `entries`.
    ///
    /// The neighbours are stitched together, `head`/`tail` are updated when `key` sat
    /// at either end, and the entry's own links are cleared. Does nothing when `key`
    /// is unknown or is not currently linked into the list.
    fn detach(&mut self, key: &CacheKey) {
        let (prev, next) = match self.entries.get(key) {
            Some((_, node)) => (node.prev, node.next),
            None => return,
        };

        // An entry with no links that is not the head is not part of the list (for
        // example, it was just inserted and not yet pushed). Treating it as the sole
        // element would overwrite `head`/`tail` with `None` and orphan every linked entry.
        if prev.is_none() && next.is_none() && self.head != Some(*key) {
            return;
        }

        match prev {
            Some(prev_key) => {
                if let Some((_, prev_node)) = self.entries.get_mut(&prev_key) {
                    prev_node.next = next;
                }
            }
            // No predecessor: `key` was the head, so its successor takes over.
            None => self.head = next,
        }

        match next {
            Some(next_key) => {
                if let Some((_, next_node)) = self.entries.get_mut(&next_key) {
                    next_node.prev = prev;
                }
            }
            // No successor: `key` was the tail, so its predecessor takes over.
            None => self.tail = prev,
        }

        if let Some((_, node)) = self.entries.get_mut(key) {
            node.prev = None;
            node.next = None;
        }
    }

    /// Links `key` in as the most recently used entry.
    ///
    /// The entry must already be stored in `entries` and must not currently be linked
    /// anywhere else in the list (call [`detach`](Self::detach) first when moving it).
    /// Does nothing when `key` is already the head or is not stored at all.
    fn push_front(&mut self, key: CacheKey) {
        if self.head == Some(key) {
            return;
        }
        let old_head = self.head;

        match self.entries.get_mut(&key) {
            Some((_, node)) => {
                node.prev = None;
                node.next = old_head;
            }
            // Never let `head`/`tail` point at a key that has no entry: the list could
            // then no longer be drained by eviction.
            None => return,
        }

        if let Some(old_head_key) = old_head {
            if let Some((_, head_node)) = self.entries.get_mut(&old_head_key) {
                head_node.prev = Some(key);
            }
        }

        self.head = Some(key);
        // First element of an empty list is both head and tail.
        if self.tail.is_none() {
            self.tail = Some(key);
        }
    }

    /// Marks `key` as the most recently used entry.
    fn touch(&mut self, key: &CacheKey) {
        let moved = *key;
        self.detach(&moved);
        self.push_front(moved);
    }

    /// Removes the least recently used entry and counts it as an eviction.
    ///
    /// Returns the evicted key, or `None` when the recency list is empty.
    fn evict_lru(&mut self) -> Option<CacheKey> {
        let victim = self.tail?;
        self.remove_entry(&victim);
        self.stats.evictions += 1;
        Some(victim)
    }

    /// Removes `key` from the recency list, the entry map, and the collection index.
    ///
    /// `stats.size` is decremented only when an entry was actually removed; an unknown
    /// key leaves the state untouched.
    fn remove_entry(&mut self, key: &CacheKey) {
        self.detach(key);

        let (removed, _) = match self.entries.remove(key) {
            Some(pair) => pair,
            None => return,
        };
        self.stats.size = self.stats.size.saturating_sub(1);

        if let Some(coll) = removed.collection {
            let now_empty = match self.collection_index.get_mut(&coll) {
                Some(keys) => {
                    keys.retain(|k| k != key);
                    keys.is_empty()
                }
                None => false,
            };
            // Drop empty buckets so the index does not accumulate dead collection names.
            if now_empty {
                self.collection_index.remove(&coll);
            }
        }
    }
}
/// Query-result cache with least-recently-used eviction and per-entry TTL expiry.
///
/// All methods take `&self` and synchronise through the internal lock.
pub struct QueryCache {
/// Entries, recency list, collection index, and counters, behind a read–write lock.
inner: RwLock<CacheInner>,
/// Settings fixed when the cache is constructed.
config: QueryCacheConfig,
}
impl QueryCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: QueryCacheConfig) -> Self {
        Self {
            inner: RwLock::new(CacheInner::new()),
            config,
        }
    }

    /// Looks up the value cached under `key`.
    ///
    /// Returns a clone of the stored bytes on a hit and marks the entry as most
    /// recently used. Returns `None` — and records a miss — when there is no entry or
    /// the entry is older than the configured TTL; an expired entry is removed.
    pub fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
        let cache_key = Self::hash_key(key);
        let ttl = self.config.ttl;

        // Every outcome mutates state (counters, recency order, or removal), so the
        // write lock is taken once up front. Checking expiry under a read lock and then
        // re-locking for the removal would let another thread replace the entry in
        // between, and the fresh value would be deleted.
        let mut inner = self.inner.write();

        // Outer `None`: no entry. Inner `None`: entry present but expired.
        let lookup = inner.entries.get_mut(&cache_key).map(|(result, _)| {
            if result.inserted_at.elapsed() > ttl {
                None
            } else {
                result.hit_count += 1;
                Some(result.data.clone())
            }
        });

        match lookup {
            Some(Some(data)) => {
                inner.stats.hits += 1;
                inner.touch(&cache_key);
                Some(data)
            }
            Some(None) => {
                inner.remove_entry(&cache_key);
                inner.stats.misses += 1;
                None
            }
            None => {
                inner.stats.misses += 1;
                None
            }
        }
    }

    /// Stores `value` under `key` without associating it with a collection.
    ///
    /// See [`put_with_collection`](Self::put_with_collection) for the storage rules.
    pub fn put(&self, key: &[u8], value: Vec<u8>) {
        self.put_with_collection(key, value, None);
    }

    /// Stores `value` under `key`, optionally tagging it with `collection` so it can
    /// later be dropped by [`invalidate_collection`](Self::invalidate_collection).
    ///
    /// * A value longer than `max_value_size` is not stored; any entry already cached
    ///   under `key` is left untouched in that case.
    /// * With `max_entries == 0` nothing is ever stored.
    /// * An existing entry under `key` is replaced and becomes the most recently used.
    /// * Least recently used entries are evicted until there is room for the new one.
    pub fn put_with_collection(&self, key: &[u8], value: Vec<u8>, collection: Option<&str>) {
        if value.len() > self.config.max_value_size {
            return;
        }
        // A zero-capacity cache cannot hold anything. Without this check the eviction
        // loop below could never satisfy `len < 0` and would spin forever while
        // holding the write lock.
        if self.config.max_entries == 0 {
            return;
        }

        let cache_key = Self::hash_key(key);
        let mut inner = self.inner.write();

        // Replace semantics: drop any previous entry first (a no-op when absent) so its
        // recency links and collection-index slot are cleaned up.
        inner.remove_entry(&cache_key);

        while inner.entries.len() >= self.config.max_entries {
            // Stop rather than spin if the recency list has nothing left to evict.
            if inner.evict_lru().is_none() {
                break;
            }
        }

        let coll_string = collection.map(String::from);
        if let Some(coll) = &coll_string {
            inner
                .collection_index
                .entry(coll.clone())
                .or_default()
                .push(cache_key);
        }

        let result = CachedResult {
            data: value,
            inserted_at: Instant::now(),
            hit_count: 0,
            collection: coll_string,
        };
        let node = LruNode {
            prev: None,
            next: None,
        };
        inner.entries.insert(cache_key, (result, node));
        inner.stats.size += 1;
        inner.stats.total_inserts += 1;
        inner.push_front(cache_key);
    }

    /// Removes the entry cached under `key`, if any, and counts it as an invalidation.
    pub fn invalidate(&self, key: &[u8]) {
        let cache_key = Self::hash_key(key);
        let mut inner = self.inner.write();
        if inner.entries.contains_key(&cache_key) {
            inner.remove_entry(&cache_key);
            inner.stats.invalidations += 1;
        }
    }

    /// Removes every entry that was stored with the given `collection` tag.
    ///
    /// Each removed entry counts as one invalidation. Unknown collections are ignored.
    pub fn invalidate_collection(&self, collection: &str) {
        let mut inner = self.inner.write();
        // Take the whole bucket out of the index up front; the per-entry index cleanup
        // done by `remove_entry` would rescan the bucket for every key.
        let keys = match inner.collection_index.remove(collection) {
            Some(keys) => keys,
            None => return,
        };
        for key in &keys {
            inner.detach(key);
            // Only adjust the counters for entries that were really there.
            if inner.entries.remove(key).is_some() {
                inner.stats.size = inner.stats.size.saturating_sub(1);
                inner.stats.invalidations += 1;
            }
        }
    }

    /// Removes all entries; each one counts as an invalidation. Hit, miss, eviction,
    /// and insert counters are kept.
    pub fn clear(&self) {
        let mut inner = self.inner.write();
        let prev_size = inner.entries.len();
        inner.entries.clear();
        inner.head = None;
        inner.tail = None;
        inner.collection_index.clear();
        inner.stats.size = 0;
        inner.stats.invalidations += prev_size as u64;
    }

    /// Returns a snapshot of the current counters.
    pub fn stats(&self) -> CacheStats {
        self.inner.read().stats.clone()
    }

    /// Returns the number of entries currently stored (expired entries that have not
    /// been looked up yet are included).
    pub fn len(&self) -> usize {
        self.inner.read().entries.len()
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the configuration this cache was built with.
    pub fn config(&self) -> &QueryCacheConfig {
        &self.config
    }

    /// Returns the configured invalidation policy.
    pub fn invalidation_policy(&self) -> InvalidationPolicy {
        self.config.invalidation_policy
    }

    /// Builds a cache key of the form `<collection>:<query_key>`.
    ///
    /// The parts are joined with a single `:` and no escaping, so a collection name
    /// that itself contains `:` can produce the same key as a different
    /// collection/query pair.
    pub fn make_key(collection: &str, query_key: &[u8]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(collection.len() + 1 + query_key.len());
        buf.extend_from_slice(collection.as_bytes());
        buf.push(b':');
        buf.extend_from_slice(query_key);
        buf
    }

    /// Hashes caller-supplied key bytes into the fixed-size map key (BLAKE3 digest).
    fn hash_key(key: &[u8]) -> CacheKey {
        *blake3::hash(key).as_bytes()
    }
}
// Unit tests for `QueryCache`: hit/miss accounting, TTL expiry, LRU eviction order,
// key and collection invalidation, size limits, and concurrent use.
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
// Builds a cache with the default configuration.
fn default_cache() -> QueryCache {
QueryCache::new(QueryCacheConfig::default())
}
// A stored value is returned by `get` and counted as one hit.
#[test]
fn test_cache_hit() {
let cache = default_cache();
cache.put(b"key1", vec![10, 20, 30]);
let result = cache.get(b"key1");
assert!(result.is_some());
assert_eq!(result.expect("should have value"), vec![10, 20, 30]);
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 0);
}
// Looking up an unknown key returns `None` and is counted as one miss.
#[test]
fn test_cache_miss() {
let cache = default_cache();
let result = cache.get(b"nonexistent");
assert!(result.is_none());
let stats = cache.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 1);
}
// With a 50 ms TTL, an entry is readable right away and gone after sleeping 80 ms;
// the expired lookup is counted as a miss.
#[test]
fn test_ttl_expiry() {
let config = QueryCacheConfig::default().with_ttl(Duration::from_millis(50));
let cache = QueryCache::new(config);
cache.put(b"key1", vec![1, 2, 3]);
assert!(cache.get(b"key1").is_some());
thread::sleep(Duration::from_millis(80));
assert!(cache.get(b"key1").is_none());
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1); }
// With capacity 3, inserting a fourth entry evicts the oldest one ("a").
#[test]
fn test_lru_eviction() {
let config = QueryCacheConfig::default().with_max_entries(3);
let cache = QueryCache::new(config);
cache.put(b"a", vec![1]);
cache.put(b"b", vec![2]);
cache.put(b"c", vec![3]);
cache.put(b"d", vec![4]);
assert!(cache.get(b"a").is_none(), "a should have been evicted");
assert!(cache.get(b"b").is_some());
assert!(cache.get(b"c").is_some());
assert!(cache.get(b"d").is_some());
let stats = cache.stats();
assert_eq!(stats.evictions, 1);
}
// Reading "a" refreshes it, so the next insert at capacity evicts "b" instead.
#[test]
fn test_lru_access_order() {
let config = QueryCacheConfig::default().with_max_entries(3);
let cache = QueryCache::new(config);
cache.put(b"a", vec![1]);
cache.put(b"b", vec![2]);
cache.put(b"c", vec![3]);
let _ = cache.get(b"a");
cache.put(b"d", vec![4]);
assert!(
cache.get(b"a").is_some(),
"a was accessed and should not be evicted"
);
assert!(cache.get(b"b").is_none(), "b should have been evicted");
assert!(cache.get(b"c").is_some());
assert!(cache.get(b"d").is_some());
}
// `invalidate` removes only the named key and counts one invalidation.
#[test]
fn test_invalidate_key() {
let cache = default_cache();
cache.put(b"key1", vec![1]);
cache.put(b"key2", vec![2]);
cache.invalidate(b"key1");
assert!(cache.get(b"key1").is_none());
assert!(cache.get(b"key2").is_some());
let stats = cache.stats();
assert_eq!(stats.invalidations, 1);
}
// `invalidate_collection` drops every entry tagged "users" and leaves "orders" alone.
#[test]
fn test_invalidate_collection() {
let cache = default_cache();
let key1 = QueryCache::make_key("users", b"u1");
let key2 = QueryCache::make_key("users", b"u2");
let key3 = QueryCache::make_key("orders", b"o1");
cache.put_with_collection(&key1, vec![1], Some("users"));
cache.put_with_collection(&key2, vec![2], Some("users"));
cache.put_with_collection(&key3, vec![3], Some("orders"));
cache.invalidate_collection("users");
assert!(cache.get(&key1).is_none());
assert!(cache.get(&key2).is_none());
assert!(cache.get(&key3).is_some(), "orders entry should remain");
let stats = cache.stats();
assert_eq!(stats.invalidations, 2);
}
// Three inserts, two hits, one miss, and one invalidation are all reflected in the
// counters, and the hit rate comes out as 2/3.
#[test]
fn test_stats_accuracy() {
let cache = default_cache();
cache.put(b"a", vec![1]);
cache.put(b"b", vec![2]);
cache.put(b"c", vec![3]);
let _ = cache.get(b"a");
let _ = cache.get(b"b");
let _ = cache.get(b"z");
cache.invalidate(b"c");
let stats = cache.stats();
assert_eq!(stats.total_inserts, 3);
assert_eq!(stats.hits, 2);
assert_eq!(stats.misses, 1);
assert_eq!(stats.invalidations, 1);
assert_eq!(stats.size, 2);
let rate = stats.hit_rate();
assert!((rate - 2.0 / 3.0).abs() < 1e-9);
}
// With no lookups recorded the hit rate is 0.0 rather than a division by zero.
#[test]
fn test_hit_rate_no_lookups() {
let cache = default_cache();
let stats = cache.stats();
assert!((stats.hit_rate() - 0.0).abs() < f64::EPSILON);
}
// Four writer and four reader threads share one cache of capacity 500; the test
// checks that no thread panics and the size bound holds afterwards.
#[test]
fn test_concurrent_access() {
use std::sync::Arc;
let cache = Arc::new(QueryCache::new(
QueryCacheConfig::default().with_max_entries(500),
));
let mut handles = Vec::new();
for t in 0..4 {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..200 {
let key = format!("thread-{}-key-{}", t, i);
cache.put(key.as_bytes(), vec![t as u8; 64]);
}
}));
}
for t in 0..4 {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..200 {
let key = format!("thread-{}-key-{}", t, i);
let _ = cache.get(key.as_bytes());
}
}));
}
for h in handles {
h.join().expect("thread should not panic");
}
let stats = cache.stats();
assert!(stats.total_inserts > 0);
assert!(stats.size <= 500);
}
// A value exactly at the size limit is stored; one byte over is silently skipped
// and not counted as an insert.
#[test]
fn test_max_value_size_enforcement() {
let config = QueryCacheConfig::default().with_max_value_size(100);
let cache = QueryCache::new(config);
cache.put(b"small", vec![0u8; 100]);
assert!(cache.get(b"small").is_some());
cache.put(b"big", vec![0u8; 101]);
assert!(cache.get(b"big").is_none());
let stats = cache.stats();
assert_eq!(stats.total_inserts, 1); }
// `clear` empties the cache and counts every removed entry as an invalidation.
#[test]
fn test_clear() {
let cache = default_cache();
cache.put(b"a", vec![1]);
cache.put(b"b", vec![2]);
cache.put(b"c", vec![3]);
assert_eq!(cache.len(), 3);
cache.clear();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
assert!(cache.get(b"a").is_none());
let stats = cache.stats();
assert_eq!(stats.size, 0);
assert_eq!(stats.invalidations, 3);
}
// `make_key` joins collection and query key with a single ':'.
#[test]
fn test_make_key() {
let key = QueryCache::make_key("users", b"abc");
assert_eq!(key, b"users:abc");
}
// Writing the same key twice replaces the value, keeps one entry, and counts two inserts.
#[test]
fn test_overwrite_existing_key() {
let cache = default_cache();
cache.put(b"key", vec![1, 2, 3]);
assert_eq!(cache.get(b"key").expect("should exist"), vec![1, 2, 3]);
cache.put(b"key", vec![4, 5, 6]);
assert_eq!(cache.get(b"key").expect("should exist"), vec![4, 5, 6]);
assert_eq!(cache.len(), 1);
let stats = cache.stats();
assert_eq!(stats.total_inserts, 2);
}
// The policy chosen in the config is reported back by the cache.
#[test]
fn test_invalidation_policy_config() {
let config =
QueryCacheConfig::default().with_invalidation_policy(InvalidationPolicy::Manual);
let cache = QueryCache::new(config);
assert_eq!(cache.invalidation_policy(), InvalidationPolicy::Manual);
}
// With capacity 1, each new key evicts the previous one.
#[test]
fn test_single_entry_cache() {
let config = QueryCacheConfig::default().with_max_entries(1);
let cache = QueryCache::new(config);
cache.put(b"a", vec![1]);
assert!(cache.get(b"a").is_some());
cache.put(b"b", vec![2]);
assert!(cache.get(b"a").is_none());
assert!(cache.get(b"b").is_some());
let stats = cache.stats();
assert_eq!(stats.evictions, 1);
}
}