use std::{
collections::HashSet,
sync::{
Arc,
atomic::{AtomicU64, AtomicUsize, Ordering},
},
time::Duration,
};
use dashmap::{DashMap, DashSet};
use moka::sync::Cache as MokaCache;
use serde::{Deserialize, Serialize};
use super::config::CacheConfig;
use crate::{db::types::JsonbValue, error::Result};
/// A cached query result plus the metadata needed for per-entry TTL
/// expiry and index-driven invalidation.
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// The query rows; wrapped in `Arc` so cache hits hand out O(1) clones.
    pub result: Arc<Vec<JsonbValue>>,
    /// Views read while producing this result (drives view invalidation).
    pub accessed_views: Box<[String]>,
    /// Unix timestamp (seconds) at insertion; 0 if the clock predates the epoch.
    pub cached_at: u64,
    /// Per-entry TTL in seconds; 0 means the entry never expires.
    pub ttl_seconds: u64,
    /// `(entity_type, entity_id)` pairs found in the rows (drives entity invalidation).
    pub entity_refs: Box<[(String, String)]>,
    /// True when the result contains more than one row.
    pub is_list_query: bool,
}
/// Expiry policy that honors each entry's own TTL instead of a single
/// cache-wide time-to-live.
struct CacheEntryExpiry;

impl moka::Expiry<u64, Arc<CachedResult>> for CacheEntryExpiry {
    /// Returns the entry's lifetime after creation. A TTL of zero maps to
    /// `None`, which moka interprets as "never expire".
    fn expire_after_create(
        &self,
        _key: &u64,
        value: &Arc<CachedResult>,
        _created_at: std::time::Instant,
    ) -> Option<Duration> {
        let ttl = value.ttl_seconds;
        (ttl != 0).then(|| Duration::from_secs(ttl))
    }
}
/// Concurrent query-result cache with TTL expiry, capacity-based
/// eviction, and three secondary indexes used for targeted invalidation.
pub struct QueryResultCache {
    /// Backing store; eviction and expiry are handled by moka.
    store: MokaCache<u64, Arc<CachedResult>>,
    /// Static configuration (enable flag, TTL, size limits).
    config: CacheConfig,
    /// Lookup hit counter.
    hits: AtomicU64,
    /// Lookup miss counter.
    misses: AtomicU64,
    /// Total number of successful inserts.
    total_cached: AtomicU64,
    /// Number of entries explicitly invalidated.
    invalidations: AtomicU64,
    /// Approximate bookkeeping bytes; incremented on insert, decremented
    /// by the eviction listener. Shared with the store's listener.
    memory_bytes: Arc<AtomicUsize>,
    /// view name -> cache keys whose results read that view.
    view_index: Arc<DashMap<String, DashSet<u64>>>,
    /// entity type -> entity id -> cache keys containing that entity.
    entity_index: Arc<DashMap<String, DashMap<String, DashSet<u64>>>>,
    /// view name -> cache keys of multi-row (list) results only.
    list_index: Arc<DashMap<String, DashSet<u64>>>,
}
/// Point-in-time snapshot of cache counters, serializable for reporting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheMetrics {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Total inserts since creation (not the current size).
    pub total_cached: u64,
    /// Entries removed by explicit invalidation calls.
    pub invalidations: u64,
    /// Current entry count reported by the store (may lag until pending
    /// maintenance tasks run).
    pub size: usize,
    /// Approximate bookkeeping bytes currently charged.
    pub memory_bytes: usize,
}
/// Fixed per-entry bookkeeping charge applied to `memory_bytes`.
/// NOTE(review): this is a coarse constant — it covers the entry struct
/// plus key/hash words but not the serialized payload; confirm that the
/// approximation is intentional.
const fn entry_overhead() -> usize {
    let key_and_hash = 2 * std::mem::size_of::<u64>();
    std::mem::size_of::<CachedResult>() + key_and_hash
}
/// Builds the moka store with per-entry TTL expiry and an eviction
/// listener that keeps the secondary indexes and the memory counter
/// consistent with the store's contents.
///
/// Fix: the listener now prunes index sets that become empty. Previously
/// every distinct view name and entity id left a permanently-retained
/// empty `DashSet` behind, growing the indexes without bound. Pruning via
/// `remove_if` is race-free: the predicate runs under the shard's write
/// lock, and `put_arc`'s `entry().or_default().insert()` inserts while
/// holding that same lock, so a non-empty set is never removed.
fn build_store(
    config: &CacheConfig,
    memory_bytes: Arc<AtomicUsize>,
    view_index: Arc<DashMap<String, DashSet<u64>>>,
    entity_index: Arc<DashMap<String, DashMap<String, DashSet<u64>>>>,
    list_index: Arc<DashMap<String, DashSet<u64>>>,
) -> MokaCache<u64, Arc<CachedResult>> {
    let max_cap = config.max_entries as u64;
    MokaCache::builder()
        .max_capacity(max_cap)
        .expire_after(CacheEntryExpiry)
        .eviction_listener(move |key: Arc<u64>, value: Arc<CachedResult>, _cause| {
            // Refund the fixed charge taken in `put_arc`.
            memory_bytes.fetch_sub(entry_overhead(), Ordering::Relaxed);
            for view in &value.accessed_views {
                if let Some(keys) = view_index.get(view) {
                    keys.remove(&*key);
                }
                // Prune the set if this was its last key.
                view_index.remove_if(view, |_, keys| keys.is_empty());
            }
            if value.is_list_query {
                for view in &value.accessed_views {
                    if let Some(keys) = list_index.get(view) {
                        keys.remove(&*key);
                    }
                    list_index.remove_if(view, |_, keys| keys.is_empty());
                }
            }
            for (et, id) in &*value.entity_refs {
                if let Some(by_type) = entity_index.get(et) {
                    if let Some(keys) = by_type.get(id) {
                        keys.remove(&*key);
                    }
                    // Only the inner (id -> keys) map is pruned here:
                    // removing the outer entry while holding `by_type`'s
                    // read guard could deadlock on the same shard.
                    by_type.remove_if(id, |_, keys| keys.is_empty());
                }
            }
        })
        .build()
}
impl QueryResultCache {
    /// Creates a cache sized and configured per `config`.
    ///
    /// # Panics
    /// Panics if `config.max_entries` is zero.
    #[must_use]
    pub fn new(config: CacheConfig) -> Self {
        assert!(config.max_entries > 0, "max_entries must be > 0");
        let memory_bytes = Arc::new(AtomicUsize::new(0));
        let view_index: Arc<DashMap<String, DashSet<u64>>> = Arc::new(DashMap::new());
        let entity_index: Arc<DashMap<String, DashMap<String, DashSet<u64>>>> =
            Arc::new(DashMap::new());
        let list_index: Arc<DashMap<String, DashSet<u64>>> = Arc::new(DashMap::new());
        // The store's eviction listener shares the counter and indexes.
        let store = build_store(
            &config,
            Arc::clone(&memory_bytes),
            Arc::clone(&view_index),
            Arc::clone(&entity_index),
            Arc::clone(&list_index),
        );
        Self {
            store,
            config,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            total_cached: AtomicU64::new(0),
            invalidations: AtomicU64::new(0),
            memory_bytes,
            view_index,
            entity_index,
            list_index,
        }
    }

    /// Whether caching is enabled at all.
    #[must_use]
    pub const fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Looks up a cached result, recording a hit or a miss. A disabled
    /// cache always returns `Ok(None)` without touching the counters.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn get(&self, cache_key: u64) -> Result<Option<Arc<Vec<JsonbValue>>>> {
        if !self.config.enabled {
            return Ok(None);
        }
        if let Some(cached) = self.store.get(&cache_key) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            Ok(Some(Arc::clone(&cached.result)))
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
            Ok(None)
        }
    }

    /// Inserts `result` under `cache_key` and registers it in the view,
    /// list, and entity indexes. Silently no-ops when the cache is
    /// disabled or an admission limit rejects the entry.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn put_arc(
        &self,
        cache_key: u64,
        result: Arc<Vec<JsonbValue>>,
        accessed_views: Vec<String>,
        ttl_override: Option<u64>,
        entity_type: Option<&str>,
    ) -> Result<()> {
        if !self.config.enabled {
            return Ok(());
        }
        let ttl_seconds = ttl_override.unwrap_or(self.config.ttl_seconds);
        // Multi-row results may be excluded from caching entirely.
        if !self.config.cache_list_queries && result.len() > 1 {
            return Ok(());
        }
        // Admission control on serialized size; a serialization failure
        // estimates 0 bytes and therefore admits the entry (best effort).
        if let Some(max_entry) = self.config.max_entry_bytes {
            let estimated = serde_json::to_vec(&*result).map_or(0, |v| v.len());
            if estimated > max_entry {
                return Ok(());
            }
        }
        // NOTE(review): the budget is compared against the fixed per-entry
        // overhead counter, not real payload bytes — confirm that coarse
        // accounting is intentional.
        if let Some(max_total) = self.config.max_total_bytes {
            if self.memory_bytes.load(Ordering::Relaxed) >= max_total {
                return Ok(());
            }
        }
        let is_list_query = result.len() > 1;
        // Collect (entity_type, id) for every row exposing a string "id".
        let entity_refs: Box<[(String, String)]> = if let Some(et) = entity_type {
            result
                .iter()
                .filter_map(|row| {
                    row.as_value()
                        .as_object()?
                        .get("id")?
                        .as_str()
                        .map(|id| (et.to_string(), id.to_string()))
                })
                .collect::<Vec<_>>()
                .into_boxed_slice()
        } else {
            Box::default()
        };
        // Register in the indexes before inserting into the store so an
        // invalidation can never miss a stored entry.
        for view in &accessed_views {
            self.view_index.entry(view.clone()).or_default().insert(cache_key);
        }
        if is_list_query {
            for view in &accessed_views {
                self.list_index.entry(view.clone()).or_default().insert(cache_key);
            }
        }
        for (et, id) in &*entity_refs {
            self.entity_index
                .entry(et.clone())
                .or_default()
                .entry(id.clone())
                .or_default()
                .insert(cache_key);
        }
        let cached = CachedResult {
            result,
            accessed_views: accessed_views.into_boxed_slice(),
            cached_at: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_or(0, |d| d.as_secs()),
            ttl_seconds,
            entity_refs,
            is_list_query,
        };
        // Charge the fixed overhead; the eviction listener refunds it.
        self.memory_bytes.fetch_add(entry_overhead(), Ordering::Relaxed);
        self.store.insert(cache_key, Arc::new(cached));
        self.total_cached.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Convenience wrapper around [`Self::put_arc`] taking an owned `Vec`.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn put(
        &self,
        cache_key: u64,
        result: Vec<JsonbValue>,
        accessed_views: Vec<String>,
        ttl_override: Option<u64>,
        entity_type: Option<&str>,
    ) -> Result<()> {
        self.put_arc(cache_key, Arc::new(result), accessed_views, ttl_override, entity_type)
    }

    /// Invalidates every store key registered under `views` in `index`
    /// and bumps the invalidation counter. Shared implementation for
    /// [`Self::invalidate_views`] and [`Self::invalidate_list_queries`],
    /// which were previously copy-paste duplicates.
    fn invalidate_keys_for_views(
        &self,
        index: &DashMap<String, DashSet<u64>>,
        views: &[String],
    ) -> u64 {
        // De-duplicate keys that appear under several views so each entry
        // is counted and invalidated once.
        let mut keys_to_invalidate: HashSet<u64> = HashSet::new();
        for view in views {
            if let Some(keys) = index.get(view) {
                keys_to_invalidate.extend(keys.iter().map(|k| *k));
            }
        }
        #[allow(clippy::cast_possible_truncation)]
        let count = keys_to_invalidate.len() as u64;
        for key in keys_to_invalidate {
            self.store.invalidate(&key);
        }
        self.invalidations.fetch_add(count, Ordering::Relaxed);
        count
    }

    /// Invalidates every cached result that read any of `views`.
    /// Returns the number of distinct entries invalidated.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn invalidate_views(&self, views: &[String]) -> Result<u64> {
        if !self.config.enabled {
            return Ok(0);
        }
        Ok(self.invalidate_keys_for_views(&self.view_index, views))
    }

    /// Invalidates only multi-row (list) results that read any of
    /// `views`, leaving point lookups intact.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn invalidate_list_queries(&self, views: &[String]) -> Result<u64> {
        if !self.config.enabled {
            return Ok(0);
        }
        Ok(self.invalidate_keys_for_views(&self.list_index, views))
    }

    /// Invalidates every cached result containing the given entity.
    /// Returns the number of entries invalidated.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn invalidate_by_entity(&self, entity_type: &str, entity_id: &str) -> Result<u64> {
        if !self.config.enabled {
            return Ok(0);
        }
        // Single index probe; the previous `contains_key` + `get` pair
        // looked the type up twice for no benefit.
        let keys_to_invalidate: Vec<u64> = self
            .entity_index
            .get(entity_type)
            .and_then(|by_type| {
                by_type.get(entity_id).map(|keys| keys.iter().map(|k| *k).collect())
            })
            .unwrap_or_default();
        #[allow(clippy::cast_possible_truncation)]
        let count = keys_to_invalidate.len() as u64;
        for key in keys_to_invalidate {
            self.store.invalidate(&key);
        }
        self.invalidations.fetch_add(count, Ordering::Relaxed);
        Ok(count)
    }

    /// Snapshots the current counters. `size` reflects the store's entry
    /// count, which may lag until pending maintenance tasks run.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn metrics(&self) -> Result<CacheMetrics> {
        Ok(CacheMetrics {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            total_cached: self.total_cached.load(Ordering::Relaxed),
            invalidations: self.invalidations.load(Ordering::Relaxed),
            #[allow(clippy::cast_possible_truncation)]
            size: self.store.entry_count() as usize,
            memory_bytes: self.memory_bytes.load(Ordering::Relaxed),
        })
    }

    /// Drops every entry, empties all indexes, and resets the memory
    /// counter. Hit/miss/insert counters are intentionally preserved.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for interface stability.
    pub fn clear(&self) -> Result<()> {
        self.store.invalidate_all();
        self.view_index.clear();
        self.entity_index.clear();
        self.list_index.clear();
        self.memory_bytes.store(0, Ordering::Relaxed);
        Ok(())
    }
}
impl CacheMetrics {
    /// Fraction of lookups that hit, in `[0.0, 1.0]`; defined as `0.0`
    /// when no lookups have happened yet.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            #[allow(clippy::cast_precision_loss)]
            let rate = self.hits as f64 / total as f64;
            rate
        }
    }

    /// The cache is considered healthy when strictly more than 60% of
    /// lookups hit.
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        self.hit_rate() > 0.6
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering hit/miss accounting, TTL expiry, capacity
    //! eviction, view/list/entity invalidation, metrics, memory
    //! accounting, and concurrent access.
    #![allow(clippy::unwrap_used)]
    use serde_json::json;
    use super::*;

    // Single-row fixture used by most tests.
    fn test_result() -> Vec<JsonbValue> {
        vec![JsonbValue::new(json!({"id": 1, "name": "test"}))]
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        let result = cache.get(999_u64).unwrap();
        assert!(result.is_none(), "Should be cache miss");
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.misses, 1);
        assert_eq!(metrics.hits, 0);
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        let result = test_result();
        cache.put(1_u64, result, vec!["v_user".to_string()], None, None).unwrap();
        let cached = cache.get(1_u64).unwrap();
        assert!(cached.is_some(), "Should be cache hit");
        assert_eq!(cached.unwrap().len(), 1);
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.hits, 1);
        assert_eq!(metrics.misses, 0);
        assert_eq!(metrics.total_cached, 1);
    }

    #[test]
    fn test_cache_hit_updates_hit_count() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.get(1_u64).unwrap();
        cache.get(1_u64).unwrap();
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.hits, 2);
    }

    // Wall-clock test: a 1s TTL entry must be gone after sleeping 2s.
    #[test]
    fn test_ttl_expiry() {
        let config = CacheConfig {
            ttl_seconds: 1,
            enabled: true,
            ..Default::default()
        };
        let cache = QueryResultCache::new(config);
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        std::thread::sleep(std::time::Duration::from_secs(2));
        cache.store.run_pending_tasks();
        let result = cache.get(1_u64).unwrap();
        assert!(result.is_none(), "Entry should be expired");
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.misses, 1); }

    // A per-entry TTL override shorter than the configured TTL wins.
    #[test]
    fn test_per_entry_ttl_override_expires_early() {
        let config = CacheConfig {
            ttl_seconds: 3600,
            enabled: true,
            ..Default::default()
        };
        let cache = QueryResultCache::new(config);
        cache
            .put(
                1_u64,
                test_result(),
                vec!["v_ref".to_string()],
                Some(1), None,
            )
            .unwrap();
        std::thread::sleep(std::time::Duration::from_secs(2));
        cache.store.run_pending_tasks();
        let result = cache.get(1_u64).unwrap();
        assert!(result.is_none(), "Entry with per-entry TTL=1s should have expired");
    }

    // TTL=0 is the "never expire" sentinel.
    #[test]
    fn test_per_entry_ttl_zero_cached_indefinitely() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache
            .put(1_u64, test_result(), vec!["v_live".to_string()], Some(0), None)
            .unwrap();
        let result = cache.get(1_u64).unwrap();
        assert!(result.is_some(), "Entry with TTL=0 should be cached indefinitely");
    }

    #[test]
    fn test_ttl_not_expired() {
        let config = CacheConfig {
            ttl_seconds: 3600, enabled: true,
            ..Default::default()
        };
        let cache = QueryResultCache::new(config);
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        let result = cache.get(1_u64).unwrap();
        assert!(result.is_some(), "Entry should not be expired");
    }

    // Inserting past max_entries must trigger moka's capacity eviction.
    #[test]
    fn test_capacity_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            enabled: true,
            ..Default::default()
        };
        let cache = QueryResultCache::new(config);
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.put(2_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.put(3_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.store.run_pending_tasks();
        let metrics = cache.metrics().unwrap();
        assert!(metrics.size <= 2, "Cache size should not exceed max capacity");
    }

    #[test]
    fn test_cache_disabled() {
        let config = CacheConfig::disabled();
        let cache = QueryResultCache::new(config);
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        assert!(cache.get(1_u64).unwrap().is_none(), "Cache disabled should always miss");
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.total_cached, 0);
    }

    #[test]
    fn test_invalidate_single_view() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.put(2_u64, test_result(), vec!["v_post".to_string()], None, None).unwrap();
        let invalidated = cache.invalidate_views(&["v_user".to_string()]).unwrap();
        assert_eq!(invalidated, 1);
        assert!(cache.get(1_u64).unwrap().is_none());
        assert!(cache.get(2_u64).unwrap().is_some());
    }

    #[test]
    fn test_invalidate_multiple_views() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.put(2_u64, test_result(), vec!["v_post".to_string()], None, None).unwrap();
        cache
            .put(3_u64, test_result(), vec!["v_product".to_string()], None, None)
            .unwrap();
        let invalidated =
            cache.invalidate_views(&["v_user".to_string(), "v_post".to_string()]).unwrap();
        assert_eq!(invalidated, 2);
        assert!(cache.get(1_u64).unwrap().is_none());
        assert!(cache.get(2_u64).unwrap().is_none());
        assert!(cache.get(3_u64).unwrap().is_some());
    }

    // An entry under two views must be counted once when one view is hit.
    #[test]
    fn test_invalidate_entry_with_multiple_views() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache
            .put(
                1_u64,
                test_result(),
                vec!["v_user".to_string(), "v_post".to_string()],
                None,
                None,
            )
            .unwrap();
        let invalidated = cache.invalidate_views(&["v_user".to_string()]).unwrap();
        assert_eq!(invalidated, 1);
        assert!(cache.get(1_u64).unwrap().is_none());
    }

    #[test]
    fn test_invalidate_nonexistent_view() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        let invalidated = cache.invalidate_views(&["v_nonexistent".to_string()]).unwrap();
        assert_eq!(invalidated, 0);
        assert!(cache.get(1_u64).unwrap().is_some());
    }

    #[test]
    fn test_clear() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.put(2_u64, test_result(), vec!["v_post".to_string()], None, None).unwrap();
        cache.clear().unwrap();
        cache.store.run_pending_tasks();
        assert!(cache.get(1_u64).unwrap().is_none());
        assert!(cache.get(2_u64).unwrap().is_none());
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.size, 0);
    }

    #[test]
    fn test_metrics_tracking() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.get(999_u64).unwrap();
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        cache.get(1_u64).unwrap();
        cache.store.run_pending_tasks();
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.hits, 1);
        assert_eq!(metrics.misses, 1);
        assert_eq!(metrics.size, 1);
        assert_eq!(metrics.total_cached, 1);
    }

    #[test]
    fn test_metrics_hit_rate() {
        let metrics = CacheMetrics {
            hits: 80,
            misses: 20,
            total_cached: 100,
            invalidations: 5,
            size: 95,
            memory_bytes: 1_000_000,
        };
        assert!((metrics.hit_rate() - 0.8).abs() < f64::EPSILON);
        assert!(metrics.is_healthy());
    }

    #[test]
    fn test_metrics_hit_rate_zero_requests() {
        let metrics = CacheMetrics {
            hits: 0,
            misses: 0,
            total_cached: 0,
            invalidations: 0,
            size: 0,
            memory_bytes: 0,
        };
        assert!((metrics.hit_rate() - 0.0).abs() < f64::EPSILON);
        assert!(!metrics.is_healthy());
    }

    // 70% hit rate is healthy; exactly 50% is not (threshold is 0.6).
    #[test]
    fn test_metrics_is_healthy() {
        let good = CacheMetrics {
            hits: 70,
            misses: 30,
            total_cached: 100,
            invalidations: 5,
            size: 95,
            memory_bytes: 1_000_000,
        };
        assert!(good.is_healthy());
        let bad = CacheMetrics {
            hits: 50,
            misses: 50,
            total_cached: 100,
            invalidations: 5,
            size: 95,
            memory_bytes: 1_000_000,
        };
        assert!(!bad.is_healthy()); }

    // Fixture with a caller-chosen string id (entity-index tests).
    fn entity_result(id: &str) -> Vec<JsonbValue> {
        vec![JsonbValue::new(
            serde_json::json!({"id": id, "name": "test"}),
        )]
    }

    #[test]
    fn test_invalidate_by_entity_only_removes_matching_entries() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache
            .put(1_u64, entity_result("uuid-a"), vec!["v_user".to_string()], None, Some("User"))
            .unwrap();
        cache
            .put(2_u64, entity_result("uuid-b"), vec!["v_user".to_string()], None, Some("User"))
            .unwrap();
        let evicted = cache.invalidate_by_entity("User", "uuid-a").unwrap();
        assert_eq!(evicted, 1);
        assert!(cache.get(1_u64).unwrap().is_none(), "User A should be evicted");
        assert!(cache.get(2_u64).unwrap().is_some(), "User B should remain");
    }

    #[test]
    fn test_invalidate_by_entity_removes_list_containing_entity() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache
            .put(1_u64, entity_result("uuid-a"), vec!["v_user".to_string()], None, Some("User"))
            .unwrap();
        let evicted = cache.invalidate_by_entity("User", "uuid-a").unwrap();
        assert_eq!(evicted, 1);
        assert!(cache.get(1_u64).unwrap().is_none(), "Entry for A should be evicted");
    }

    #[test]
    fn test_invalidate_by_entity_leaves_unrelated_types() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache
            .put(
                1_u64,
                entity_result("uuid-user"),
                vec!["v_user".to_string()],
                None,
                Some("User"),
            )
            .unwrap();
        cache
            .put(
                2_u64,
                entity_result("uuid-post"),
                vec!["v_post".to_string()],
                None,
                Some("Post"),
            )
            .unwrap();
        let evicted = cache.invalidate_by_entity("User", "uuid-user").unwrap();
        assert_eq!(evicted, 1);
        assert!(cache.get(1_u64).unwrap().is_none(), "User entry should be evicted");
        assert!(cache.get(2_u64).unwrap().is_some(), "Post entry should remain");
    }

    #[test]
    fn test_put_builds_entity_id_index() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache
            .put(1_u64, entity_result("uuid-1"), vec!["v_user".to_string()], None, Some("User"))
            .unwrap();
        let evicted = cache.invalidate_by_entity("User", "uuid-1").unwrap();
        assert_eq!(evicted, 1);
        assert!(cache.get(1_u64).unwrap().is_none());
    }

    #[test]
    fn test_put_without_entity_type_not_indexed() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache
            .put(
                1_u64,
                entity_result("uuid-1"),
                vec!["v_user".to_string()],
                None,
                None, )
            .unwrap();
        let evicted = cache.invalidate_by_entity("User", "uuid-1").unwrap();
        assert_eq!(evicted, 0);
        assert!(cache.get(1_u64).unwrap().is_some(), "Non-indexed entry should remain");
    }

    // Multi-row fixture (list-query tests).
    fn list_result(ids: &[&str]) -> Vec<JsonbValue> {
        ids.iter()
            .map(|id| JsonbValue::new(serde_json::json!({"id": id, "name": "test"})))
            .collect()
    }

    #[test]
    fn test_put_indexes_all_entities_in_list() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        let rows = list_result(&["uuid-A", "uuid-B", "uuid-C"]);
        cache.put(0xABC, rows, vec!["v_user".to_string()], None, Some("User")).unwrap();
        let evicted_a = cache.invalidate_by_entity("User", "uuid-A").unwrap();
        assert_eq!(evicted_a, 1, "uuid-A must be indexed and evictable");
        let rows2 = list_result(&["uuid-A", "uuid-B", "uuid-C"]);
        cache.put(0xDEF, rows2, vec!["v_user".to_string()], None, Some("User")).unwrap();
        let evicted_c = cache.invalidate_by_entity("User", "uuid-C").unwrap();
        assert_eq!(evicted_c, 1, "uuid-C at position 2 must also be indexed");
    }

    #[test]
    fn test_update_evicts_list_query_via_non_first_entity() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        let rows = list_result(&["uuid-A", "uuid-B"]);
        cache.put(0x111, rows, vec!["v_user".to_string()], None, Some("User")).unwrap();
        let evicted = cache.invalidate_by_entity("User", "uuid-B").unwrap();
        assert_eq!(evicted, 1);
        assert!(cache.get(0x111).unwrap().is_none(), "list entry containing uuid-B must be gone");
    }

    #[test]
    fn test_invalidate_list_queries_spares_point_lookups() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        let single = vec![JsonbValue::new(serde_json::json!({"id": "uuid-X"}))];
        cache
            .put(0x001, single, vec!["v_user".to_string()], None, Some("User"))
            .unwrap();
        let list = list_result(&["uuid-A", "uuid-B"]);
        cache.put(0x002, list, vec!["v_user".to_string()], None, Some("User")).unwrap();
        let evicted = cache.invalidate_list_queries(&["v_user".to_string()]).unwrap();
        assert_eq!(evicted, 1, "only the list entry should be evicted");
        assert!(cache.get(0x001).unwrap().is_some(), "point lookup must survive");
        assert!(cache.get(0x002).unwrap().is_none(), "list entry must be evicted");
    }

    #[test]
    fn test_invalidate_by_entity_short_circuits_on_empty_index() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        let count = cache.invalidate_by_entity("User", "uuid-X").unwrap();
        assert_eq!(count, 0);
    }

    // After eviction the entity index must not retain any of the entry's refs.
    #[test]
    fn test_eviction_listener_cleans_all_entity_refs() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        let rows = list_result(&["uuid-A", "uuid-B"]);
        cache.put(0x001, rows, vec!["v_user".to_string()], None, Some("User")).unwrap();
        cache.invalidate_views(&["v_user".to_string()]).unwrap();
        cache.store.run_pending_tasks();
        let count_a = cache.invalidate_by_entity("User", "uuid-A").unwrap();
        let count_b = cache.invalidate_by_entity("User", "uuid-B").unwrap();
        assert_eq!(count_a, 0, "entity_index must be clean after eviction");
        assert_eq!(count_b, 0, "entity_index must be clean after eviction");
    }

    // Ten threads put+get disjoint keys; counters must add up exactly.
    #[test]
    fn test_concurrent_access() {
        use std::{sync::Arc, thread};
        let cache = Arc::new(QueryResultCache::new(CacheConfig::enabled()));
        let handles: Vec<_> = (0_u64..10)
            .map(|key| {
                let cache_clone = cache.clone();
                thread::spawn(move || {
                    cache_clone
                        .put(key, test_result(), vec!["v_user".to_string()], None, None)
                        .unwrap();
                    cache_clone.get(key).unwrap();
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.total_cached, 10);
        assert_eq!(metrics.hits, 10);
    }

    #[test]
    fn test_cache_list_queries_false_skips_multi_row() {
        let config = CacheConfig {
            enabled: true,
            cache_list_queries: false,
            ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        let two_rows = vec![
            JsonbValue::new(json!({"id": 1})),
            JsonbValue::new(json!({"id": 2})),
        ];
        cache.put(1_u64, two_rows, vec!["v_user".to_string()], None, None).unwrap();
        assert!(
            cache.get(1_u64).unwrap().is_none(),
            "multi-row result must not be cached when cache_list_queries=false"
        );
    }

    #[test]
    fn test_cache_list_queries_false_allows_single_row() {
        let config = CacheConfig {
            enabled: true,
            cache_list_queries: false,
            ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        let one_row = vec![JsonbValue::new(json!({"id": 1}))];
        cache.put(1_u64, one_row, vec!["v_user".to_string()], None, None).unwrap();
        assert!(
            cache.get(1_u64).unwrap().is_some(),
            "single-row result must be cached even when cache_list_queries=false"
        );
    }

    #[test]
    fn test_max_entry_bytes_skips_oversized_entry() {
        let config = CacheConfig {
            enabled: true,
            max_entry_bytes: Some(10), ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        assert!(cache.get(1_u64).unwrap().is_none(), "oversized entry must be silently skipped");
    }

    #[test]
    fn test_max_entry_bytes_allows_small_entry() {
        let config = CacheConfig {
            enabled: true,
            max_entry_bytes: Some(100_000), ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        assert!(
            cache.get(1_u64).unwrap().is_some(),
            "small entry must be cached when within max_entry_bytes"
        );
    }

    #[test]
    fn test_max_total_bytes_skips_when_budget_exhausted() {
        let config = CacheConfig {
            enabled: true,
            max_total_bytes: Some(0), ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        assert!(
            cache.get(1_u64).unwrap().is_none(),
            "entry must be skipped when max_total_bytes budget is already exhausted"
        );
    }

    // Even keys under v_user, odd under v_post; only evens are invalidated.
    #[test]
    fn test_cross_key_view_invalidation() {
        let config = CacheConfig {
            max_entries: 10_000,
            enabled: true,
            ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        for i in 0_u64..200 {
            let view = if i % 2 == 0 { "v_user" } else { "v_post" };
            cache.put(i, test_result(), vec![view.to_string()], None, None).unwrap();
        }
        let invalidated = cache.invalidate_views(&["v_user".to_string()]).unwrap();
        assert_eq!(invalidated, 100);
        for i in 0_u64..200 {
            if i % 2 == 0 {
                assert!(cache.get(i).unwrap().is_none(), "v_user entry should be invalidated");
            } else {
                assert!(cache.get(i).unwrap().is_some(), "v_post entry should remain");
            }
        }
    }

    #[test]
    fn test_cross_key_entity_invalidation() {
        let config = CacheConfig {
            max_entries: 10_000,
            enabled: true,
            ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        for i in 0_u64..50 {
            cache
                .put(
                    i,
                    entity_result("uuid-target"),
                    vec!["v_user".to_string()],
                    None,
                    Some("User"),
                )
                .unwrap();
        }
        cache
            .put(
                999_u64,
                entity_result("uuid-other"),
                vec!["v_user".to_string()],
                None,
                Some("User"),
            )
            .unwrap();
        let evicted = cache.invalidate_by_entity("User", "uuid-target").unwrap();
        assert_eq!(evicted, 50);
        assert!(cache.get(999_u64).unwrap().is_some(), "unrelated entity should remain");
    }

    #[test]
    fn test_clear_all() {
        let config = CacheConfig {
            max_entries: 10_000,
            enabled: true,
            ..CacheConfig::default()
        };
        let cache = QueryResultCache::new(config);
        for i in 0_u64..200 {
            cache.put(i, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        }
        cache.clear().unwrap();
        cache.store.run_pending_tasks();
        let metrics = cache.metrics().unwrap();
        assert_eq!(metrics.size, 0);
        for i in 0_u64..200 {
            assert!(cache.get(i).unwrap().is_none());
        }
    }

    #[test]
    fn test_memory_bytes_tracked() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.put(1_u64, test_result(), vec!["v".to_string()], None, None).unwrap();
        cache.put(2_u64, test_result(), vec!["v".to_string()], None, None).unwrap();
        let before = cache.metrics().unwrap().memory_bytes;
        assert!(before > 0, "memory_bytes should be tracked");
    }

    #[test]
    fn test_memory_bytes_decreases_on_clear() {
        let cache = QueryResultCache::new(CacheConfig::enabled());
        cache.put(1_u64, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        let before = cache.metrics().unwrap().memory_bytes;
        assert!(before > 0);
        cache.clear().unwrap();
        let after = cache.metrics().unwrap().memory_bytes;
        assert_eq!(after, 0, "memory_bytes should be zero after clear()");
    }

    // Timing-sensitive scaling check; ignored by default (see attribute).
    #[test]
    #[ignore = "wall-clock dependent — run manually to confirm lock-free read scaling"]
    fn test_concurrent_reads_do_not_serialize() {
        const ITERS: usize = 10_000;
        let config = CacheConfig::enabled();
        let cache = Arc::new(QueryResultCache::new(config));
        let key = 42_u64;
        cache.put(key, test_result(), vec!["v_user".to_string()], None, None).unwrap();
        let start = std::time::Instant::now();
        for _ in 0..ITERS {
            let _ = cache.get(key).unwrap();
        }
        let single_elapsed = start.elapsed();
        let start = std::time::Instant::now();
        let handles: Vec<_> = (0..40)
            .map(|_| {
                let c = Arc::clone(&cache);
                std::thread::spawn(move || {
                    for _ in 0..ITERS {
                        let _ = c.get(key).unwrap();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        let multi_elapsed = start.elapsed();
        assert!(
            multi_elapsed <= single_elapsed * 2,
            "40-thread ({:?}) was more than 2× single-thread ({:?}) — suggests serialization",
            multi_elapsed,
            single_elapsed,
        );
    }
}