use scirs2_core::metrics::MetricsRegistry;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Tuning knobs for `QueryResultCache`: capacity limits, default entry
/// lifetime, and eviction policy.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of cached entries before eviction kicks in.
    pub max_entries: usize,
    /// Maximum total estimated size of cached values, in bytes.
    pub max_memory_bytes: u64,
    /// Time-to-live applied by `put` (when no explicit TTL is supplied).
    pub default_ttl: Duration,
    /// When true, evict least-recently-used entries first; otherwise an
    /// arbitrary entry is evicted.
    pub enable_lru: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 10000,
max_memory_bytes: 1024 * 1024 * 1024, default_ttl: Duration::from_secs(300), enable_lru: true,
}
}
}
/// A single cached value plus the bookkeeping needed for TTL expiry and
/// LRU / statistics tracking.
#[derive(Debug, Clone)]
struct CacheEntry<V> {
    // The cached value itself.
    value: V,
    // Estimated size charged against the cache's memory budget.
    size_bytes: u64,
    // Insertion time; currently unused (kept for diagnostics), hence the allow.
    #[allow(dead_code)]
    created_at: Instant,
    // Instant at (and after) which the entry counts as expired.
    expires_at: Instant,
    // Most recent read access; refreshed by `touch`.
    last_accessed: Instant,
    // Number of successful reads of this entry.
    access_count: u64,
}
impl<V> CacheEntry<V> {
    /// Whether this entry's TTL has elapsed.
    fn is_expired(&self) -> bool {
        self.expires_at <= Instant::now()
    }

    /// Records a read: bumps the access counter and refreshes recency.
    fn touch(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }
}
/// Thread-safe, TTL-aware result cache with optional LRU eviction and a
/// configurable memory budget.
pub struct QueryResultCache<V: Clone> {
    // Limits and eviction policy.
    config: CacheConfig,
    // Key -> entry map, guarded by a read/write lock.
    entries: Arc<RwLock<HashMap<String, CacheEntry<V>>>>,
    // Keys ordered least- to most-recently used; only maintained when
    // `config.enable_lru` is set.
    lru_queue: Arc<RwLock<VecDeque<String>>>,
    // Running total of estimated bytes held; kept in sync with `entries`.
    current_memory: Arc<AtomicU64>,
    // Shared hit/miss/eviction counters.
    stats: CacheStats,
    // Metrics registry handle; not yet wired to any reporting, hence the allow.
    #[allow(dead_code)]
    metrics: Arc<MetricsRegistry>,
}
/// Shared cache counters. Cloning is cheap: each counter is an `Arc`, so all
/// clones observe (and mutate) the same underlying atomics.
#[derive(Clone)]
pub struct CacheStats {
    /// Lookups that returned a live value.
    pub hits: Arc<AtomicU64>,
    /// Lookups that found nothing (or only an expired entry).
    pub misses: Arc<AtomicU64>,
    /// Entries removed to stay within capacity/memory limits.
    pub evictions: Arc<AtomicU64>,
    /// Entries removed because their TTL elapsed.
    pub expirations: Arc<AtomicU64>,
    /// Successful insertions.
    pub puts: Arc<AtomicU64>,
    /// Explicit removals via `invalidate`.
    pub invalidations: Arc<AtomicU64>,
}
impl CacheStats {
    /// All counters start at zero.
    fn new() -> Self {
        let zero = || Arc::new(AtomicU64::new(0));
        Self {
            hits: zero(),
            misses: zero(),
            evictions: zero(),
            expirations: zero(),
            puts: zero(),
            invalidations: zero(),
        }
    }

    /// Fraction of lookups that were hits; 0.0 when nothing has been looked up.
    pub fn hit_rate(&self) -> f64 {
        let hit_count = self.hits.load(Ordering::Relaxed);
        let total = hit_count + self.misses.load(Ordering::Relaxed);
        match total {
            0 => 0.0,
            n => hit_count as f64 / n as f64,
        }
    }

    /// Zeroes every counter.
    pub fn reset(&self) {
        let counters = [
            &self.hits,
            &self.misses,
            &self.evictions,
            &self.expirations,
            &self.puts,
            &self.invalidations,
        ];
        for counter in counters {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
impl<V: Clone> QueryResultCache<V> {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: CacheConfig) -> Self {
        let metrics = MetricsRegistry::new();
        Self {
            config,
            entries: Arc::new(RwLock::new(HashMap::new())),
            lru_queue: Arc::new(RwLock::new(VecDeque::new())),
            current_memory: Arc::new(AtomicU64::new(0)),
            stats: CacheStats::new(),
            metrics: Arc::new(metrics),
        }
    }

    /// Inserts `value` under `key` using the configured default TTL.
    pub fn put(&self, key: String, value: V) {
        self.put_with_ttl(key, value, self.config.default_ttl);
    }

    /// Inserts `value` under `key` with an explicit `ttl`.
    ///
    /// Older entries are evicted first so the cache stays within
    /// `max_entries` / `max_memory_bytes`. Capacity is checked before
    /// insertion, so concurrent writers may briefly overshoot the limits;
    /// subsequent puts restore them.
    pub fn put_with_ttl(&self, key: String, value: V, ttl: Duration) {
        let now = Instant::now();
        let size_bytes = self.estimate_size(&value);
        let entry = CacheEntry {
            value,
            size_bytes,
            created_at: now,
            expires_at: now + ttl,
            last_accessed: now,
            access_count: 0,
        };
        self.ensure_capacity(size_bytes);
        {
            let mut entries = self.entries.write().expect("entries lock poisoned");
            // Replacing an existing key: release the old entry's memory first.
            if let Some(old_entry) = entries.remove(&key) {
                self.current_memory
                    .fetch_sub(old_entry.size_bytes, Ordering::Relaxed);
            }
            entries.insert(key.clone(), entry);
            self.current_memory.fetch_add(size_bytes, Ordering::Relaxed);
        }
        if self.config.enable_lru {
            // Move (or append) the key to the most-recently-used end.
            let mut lru = self.lru_queue.write().expect("lru_queue lock poisoned");
            lru.retain(|k| k != &key);
            lru.push_back(key);
        }
        self.stats.puts.fetch_add(1, Ordering::Relaxed);
    }

    /// Looks up `key`, returning a clone of the cached value if present and
    /// not expired. Updates hit/miss statistics and LRU recency.
    pub fn get(&self, key: &str) -> Option<V> {
        self.clean_expired();
        let mut entries = self.entries.write().expect("entries lock poisoned");
        if let Some(entry) = entries.get_mut(key) {
            if entry.is_expired() {
                // Entry expired between clean_expired() and this lookup.
                self.current_memory
                    .fetch_sub(entry.size_bytes, Ordering::Relaxed);
                entries.remove(key);
                if self.config.enable_lru {
                    // Keep the LRU queue in sync with the entry map
                    // (previously this branch left a stale key queued).
                    let mut lru = self.lru_queue.write().expect("lru_queue lock poisoned");
                    lru.retain(|k| k != key);
                }
                self.stats.expirations.fetch_add(1, Ordering::Relaxed);
                self.stats.misses.fetch_add(1, Ordering::Relaxed);
                return None;
            }
            entry.touch();
            if self.config.enable_lru {
                // Refresh recency: move the key to the back of the queue.
                let mut lru = self.lru_queue.write().expect("lru_queue lock poisoned");
                lru.retain(|k| k != key);
                lru.push_back(key.to_string());
            }
            self.stats.hits.fetch_add(1, Ordering::Relaxed);
            Some(entry.value.clone())
        } else {
            self.stats.misses.fetch_add(1, Ordering::Relaxed);
            None
        }
    }

    /// Removes `key` from the cache. Returns `true` if an entry existed.
    pub fn invalidate(&self, key: &str) -> bool {
        let mut entries = self.entries.write().expect("entries lock poisoned");
        if let Some(entry) = entries.remove(key) {
            self.current_memory
                .fetch_sub(entry.size_bytes, Ordering::Relaxed);
            if self.config.enable_lru {
                let mut lru = self.lru_queue.write().expect("lru_queue lock poisoned");
                lru.retain(|k| k != key);
            }
            self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
            true
        } else {
            false
        }
    }

    /// Removes every entry and resets memory accounting to zero.
    pub fn clear(&self) {
        let mut entries = self.entries.write().expect("entries lock poisoned");
        entries.clear();
        if self.config.enable_lru {
            let mut lru = self.lru_queue.write().expect("lru_queue lock poisoned");
            lru.clear();
        }
        self.current_memory.store(0, Ordering::Relaxed);
    }

    /// Number of entries currently stored (may include expired entries that
    /// have not been reaped yet).
    pub fn len(&self) -> usize {
        self.entries.read().expect("entries lock poisoned").len()
    }

    /// Whether the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Estimated bytes currently attributed to cached values.
    pub fn memory_usage(&self) -> u64 {
        self.current_memory.load(Ordering::Relaxed)
    }

    /// Returns a handle to the shared statistics counters (cheap `Arc` clone).
    pub fn stats(&self) -> CacheStats {
        self.stats.clone()
    }

    /// Drops every expired entry, updating memory accounting, the LRU queue,
    /// and the expiration counter.
    fn clean_expired(&self) {
        let mut entries = self.entries.write().expect("entries lock poisoned");
        // Collect first: can't remove from the map while iterating it.
        let expired: Vec<(String, u64)> = entries
            .iter()
            .filter(|(_, entry)| entry.is_expired())
            .map(|(key, entry)| (key.clone(), entry.size_bytes))
            .collect();
        for (key, size) in expired {
            entries.remove(&key);
            self.current_memory.fetch_sub(size, Ordering::Relaxed);
            self.stats.expirations.fetch_add(1, Ordering::Relaxed);
            if self.config.enable_lru {
                let mut lru = self.lru_queue.write().expect("lru_queue lock poisoned");
                lru.retain(|k| k != &key);
            }
        }
    }

    /// Evicts entries until one more entry of `new_entry_size` bytes fits
    /// within `max_entries` and `max_memory_bytes`.
    ///
    /// Bug fix: stop as soon as eviction makes no progress. Previously, an
    /// entry larger than `max_memory_bytes` (or any insert once the cache was
    /// empty but still "over budget") spun these loops forever, because
    /// evicting from an empty cache changes nothing.
    fn ensure_capacity(&self, new_entry_size: u64) {
        while self.len() >= self.config.max_entries {
            if !self.evict_one() {
                break;
            }
        }
        while self.memory_usage() + new_entry_size > self.config.max_memory_bytes {
            if !self.evict_one() {
                break;
            }
        }
    }

    /// Evicts a single entry; returns `true` if an entry was actually removed.
    ///
    /// With LRU enabled the least-recently-used key goes first; otherwise an
    /// arbitrary key is chosen.
    fn evict_one(&self) -> bool {
        if !self.config.enable_lru {
            // Single write lock: the previous read-then-write sequence could
            // race with a concurrent removal between the two locks.
            let mut entries = self.entries.write().expect("entries lock poisoned");
            // HashMap iteration order is unspecified; any key will do.
            let key = match entries.keys().next().cloned() {
                Some(key) => key,
                None => return false,
            };
            if let Some(entry) = entries.remove(&key) {
                self.current_memory
                    .fetch_sub(entry.size_bytes, Ordering::Relaxed);
                self.stats.evictions.fetch_add(1, Ordering::Relaxed);
                return true;
            }
            return false;
        }
        // Pop queued keys until one still maps to a live entry; keys left
        // stale by a concurrent removal are simply skipped.
        loop {
            let key = {
                let mut lru = self.lru_queue.write().expect("lru_queue lock poisoned");
                match lru.pop_front() {
                    Some(key) => key,
                    None => return false,
                }
            };
            let mut entries = self.entries.write().expect("entries lock poisoned");
            if let Some(entry) = entries.remove(&key) {
                self.current_memory
                    .fetch_sub(entry.size_bytes, Ordering::Relaxed);
                self.stats.evictions.fetch_add(1, Ordering::Relaxed);
                return true;
            }
        }
    }

    /// Rough per-value size estimate used for memory accounting.
    ///
    /// NOTE(review): flat placeholder — every value is charged 1 KiB
    /// regardless of its real footprint. Replace with a real estimator
    /// (e.g. a size-reporting trait bound on `V`) if accurate memory
    /// limits matter.
    fn estimate_size(&self, _value: &V) -> u64 {
        1024
    }
}
impl<V: Clone> Default for QueryResultCache<V> {
fn default() -> Self {
Self::new(CacheConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip put/get and verify hit/miss accounting.
    #[test]
    fn test_basic_cache_operations() {
        let cache = QueryResultCache::<String>::new(CacheConfig::default());
        cache.put("key1".to_string(), "value1".to_string());
        assert_eq!(cache.get("key1"), Some("value1".to_string()));
        assert_eq!(cache.get("key2"), None);
        assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 1);
        assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 1);
    }

    // An entry must disappear once its TTL elapses and be counted as expired.
    #[test]
    fn test_ttl_expiration() {
        let config = CacheConfig {
            default_ttl: Duration::from_millis(100),
            ..Default::default()
        };
        let cache = QueryResultCache::<String>::new(config);
        cache.put("key1".to_string(), "value1".to_string());
        assert_eq!(cache.get("key1"), Some("value1".to_string()));
        // Sleep past the 100 ms TTL so the entry is expired on the next get.
        std::thread::sleep(Duration::from_millis(150));
        assert_eq!(cache.get("key1"), None);
        assert_eq!(cache.stats().expirations.load(Ordering::Relaxed), 1);
    }

    // With max_entries = 3, inserting a fourth key must evict the
    // least-recently-used one (key2, because key1 was refreshed by get).
    #[test]
    fn test_lru_eviction() {
        let config = CacheConfig {
            max_entries: 3,
            enable_lru: true,
            ..Default::default()
        };
        let cache = QueryResultCache::<String>::new(config);
        cache.put("key1".to_string(), "value1".to_string());
        cache.put("key2".to_string(), "value2".to_string());
        cache.put("key3".to_string(), "value3".to_string());
        // Touch key1 so key2 becomes the least recently used.
        cache.get("key1");
        cache.put("key4".to_string(), "value4".to_string());
        assert_eq!(cache.get("key1"), Some("value1".to_string()));
        assert_eq!(cache.get("key2"), None);
        assert_eq!(cache.get("key3"), Some("value3".to_string()));
        assert_eq!(cache.get("key4"), Some("value4".to_string()));
    }

    // invalidate removes the entry and reports whether anything was removed.
    #[test]
    fn test_cache_invalidation() {
        let cache = QueryResultCache::<String>::new(CacheConfig::default());
        cache.put("key1".to_string(), "value1".to_string());
        assert!(cache.invalidate("key1"));
        assert_eq!(cache.get("key1"), None);
        // Second invalidation of the same key is a no-op.
        assert!(!cache.invalidate("key1"));
    }

    // clear drops all entries and resets the memory accounting.
    #[test]
    fn test_cache_clear() {
        let cache = QueryResultCache::<String>::new(CacheConfig::default());
        cache.put("key1".to_string(), "value1".to_string());
        cache.put("key2".to_string(), "value2".to_string());
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.memory_usage(), 0);
    }

    // hit_rate = hits / (hits + misses): two hits, one miss here.
    #[test]
    fn test_hit_rate() {
        let cache = QueryResultCache::<String>::new(CacheConfig::default());
        cache.put("key1".to_string(), "value1".to_string());
        cache.get("key1");
        cache.get("key2");
        cache.get("key1");
        assert_eq!(cache.stats().hit_rate(), 2.0 / 3.0);
    }

    // Hammer the cache from ten threads to check lock usage doesn't deadlock
    // or corrupt the entry count.
    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;
        let cache = Arc::new(QueryResultCache::<String>::new(CacheConfig::default()));
        let mut handles = vec![];
        for i in 0..10 {
            let cache_clone = Arc::clone(&cache);
            let handle = thread::spawn(move || {
                for j in 0..100 {
                    // 1000 unique keys overall: no two threads collide.
                    let key = format!("key_{}", i * 100 + j);
                    let value = format!("value_{}", i * 100 + j);
                    cache_clone.put(key.clone(), value.clone());
                    cache_clone.get(&key);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().expect("thread should not panic");
        }
        assert!(cache.len() <= 1000);
    }

    // With a 5 KiB budget and a flat 1 KiB charged per entry (see
    // estimate_size), only five of the ten puts can be resident at once;
    // the rest must have been evicted.
    #[test]
    fn test_memory_aware_eviction() {
        let config = CacheConfig {
            max_entries: 1000,
            max_memory_bytes: 5120,
            enable_lru: true,
            ..Default::default()
        };
        let max_memory = config.max_memory_bytes;
        let cache = QueryResultCache::<String>::new(config);
        for i in 0..10 {
            cache.put(format!("key{}", i), format!("value{}", i));
        }
        assert!(cache.memory_usage() <= max_memory);
        assert!(cache.stats().evictions.load(Ordering::Relaxed) > 0);
    }
}