use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use crate::{Result, Tuple};
/// A query result stored in a `QueryCache`, together with its expiry metadata.
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// The cached rows; a cache hit returns a clone of these.
    pub tuples: Vec<Tuple>,
    /// Moment the entry was created, captured by [`CachedResult::new`].
    pub cached_at: Instant,
    /// How long after `cached_at` the entry remains valid.
    pub ttl: Duration,
    /// Table names associated with this result; `QueryCache::invalidate_table`
    /// drops the entry when any of them matches.
    pub tables: Vec<String>,
}

impl CachedResult {
    /// Wraps `tuples` as a fresh entry timestamped with the current instant.
    pub fn new(tuples: Vec<Tuple>, ttl: Duration, tables: Vec<String>) -> Self {
        let cached_at = Instant::now();
        CachedResult {
            tuples,
            cached_at,
            ttl,
            tables,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    pub fn is_expired(&self) -> bool {
        self.age() > self.ttl
    }

    /// Time elapsed since the entry was created.
    pub fn age(&self) -> Duration {
        self.cached_at.elapsed()
    }
}
/// Identifies a cached query result by a hash of its SQL text plus an optional branch.
///
/// Only a 64-bit hash of the SQL is kept, so two distinct statements whose hashes
/// collide would share a cache slot. `DefaultHasher`'s algorithm is not guaranteed
/// to be stable across Rust releases, so `query_hash` values should not be persisted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Hash of the SQL text produced by a freshly constructed `DefaultHasher`.
    pub query_hash: u64,
    /// Optional branch name; keys that differ only in branch are distinct.
    pub branch: Option<String>,
}

impl CacheKey {
    /// Builds a key by hashing `sql` verbatim (no whitespace or case normalization).
    pub fn new(sql: &str, branch: Option<String>) -> Self {
        let query_hash = {
            let mut state = std::collections::hash_map::DefaultHasher::new();
            sql.hash(&mut state);
            state.finish()
        };
        CacheKey { query_hash, branch }
    }
}
/// Size-bounded cache of query results with per-entry TTLs.
///
/// Entries and statistics sit behind separate `RwLock`s, so every operation
/// takes `&self`.
pub struct QueryCache {
    /// Capacity checked by `put_with_ttl` before inserting a new key.
    max_entries: usize,
    /// TTL applied to entries stored through `put`.
    default_ttl: Duration,
    /// Cached results keyed by query hash and branch.
    cache: Arc<RwLock<HashMap<CacheKey, CachedResult>>>,
    /// Hit/miss/invalidation/eviction counters, snapshotted by `stats`.
    stats: Arc<RwLock<CacheStats>>,
}
/// Running counters describing how a `QueryCache` has been used.
///
/// Snapshots are obtained with `QueryCache::stats`. Equality is derived so that
/// snapshots can be compared directly (for example in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups that found a live (non-expired) entry.
    pub hits: u64,
    /// Lookups that found no entry, or only an expired one.
    pub misses: u64,
    /// Entries removed by `invalidate_table` or `invalidate_all`.
    pub invalidations: u64,
    /// Entries evicted by a put to make room; expired entries purged along the
    /// way are not counted here.
    pub evictions: u64,
}
impl QueryCache {
pub fn new() -> Self {
Self::with_config(1000, Duration::from_secs(60))
}
pub fn with_config(max_entries: usize, default_ttl: Duration) -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
max_entries,
default_ttl,
stats: Arc::new(RwLock::new(CacheStats::default())),
}
}
pub fn get(&self, key: &CacheKey) -> Result<Option<Vec<Tuple>>> {
use crate::error::LockResultExt;
let cache = self.cache.read()
.map_lock_err("Failed to acquire read lock on query cache")?;
if let Some(cached) = cache.get(key) {
if cached.is_expired() {
drop(cache);
let mut stats = self.stats.write()
.map_lock_err("Failed to acquire write lock on cache stats")?;
stats.misses += 1;
return Ok(None);
}
drop(cache);
let mut stats = self.stats.write()
.map_lock_err("Failed to acquire write lock on cache stats")?;
stats.hits += 1;
let cache = self.cache.read()
.map_lock_err("Failed to acquire read lock on query cache")?;
return Ok(cache.get(key).map(|c| c.tuples.clone()));
}
let mut stats = self.stats.write()
.map_lock_err("Failed to acquire write lock on cache stats")?;
stats.misses += 1;
Ok(None)
}
pub fn put(&self, key: CacheKey, tuples: Vec<Tuple>, tables: Vec<String>) -> Result<()> {
self.put_with_ttl(key, tuples, tables, self.default_ttl)
}
pub fn put_with_ttl(
&self,
key: CacheKey,
tuples: Vec<Tuple>,
tables: Vec<String>,
ttl: Duration,
) -> Result<()> {
use crate::error::LockResultExt;
let mut cache = self.cache.write()
.map_lock_err("Failed to acquire write lock on query cache")?;
if cache.len() >= self.max_entries && !cache.contains_key(&key) {
let expired_keys: Vec<CacheKey> = cache
.iter()
.filter(|(_, v)| v.is_expired())
.map(|(k, _)| k.clone())
.collect();
for k in expired_keys {
cache.remove(&k);
}
if cache.len() >= self.max_entries {
let oldest_key = cache
.iter()
.min_by_key(|(_, v)| v.cached_at)
.map(|(k, _)| k.clone());
if let Some(key) = oldest_key {
cache.remove(&key);
let mut stats = self.stats.write()
.map_lock_err("Failed to acquire write lock on cache stats")?;
stats.evictions += 1;
}
}
}
cache.insert(key, CachedResult::new(tuples, ttl, tables));
Ok(())
}
pub fn invalidate_table(&self, table_name: &str) -> Result<u64> {
use crate::error::LockResultExt;
let mut cache = self.cache.write()
.map_lock_err("Failed to acquire write lock on query cache")?;
let keys_to_remove: Vec<CacheKey> = cache
.iter()
.filter(|(_, v)| v.tables.iter().any(|t| t == table_name))
.map(|(k, _)| k.clone())
.collect();
let count = keys_to_remove.len() as u64;
for key in keys_to_remove {
cache.remove(&key);
}
if count > 0 {
let mut stats = self.stats.write()
.map_lock_err("Failed to acquire write lock on cache stats")?;
stats.invalidations += count;
}
Ok(count)
}
pub fn invalidate_all(&self) -> Result<u64> {
use crate::error::LockResultExt;
let mut cache = self.cache.write()
.map_lock_err("Failed to acquire write lock on query cache")?;
let count = cache.len() as u64;
cache.clear();
if count > 0 {
let mut stats = self.stats.write()
.map_lock_err("Failed to acquire write lock on cache stats")?;
stats.invalidations += count;
}
Ok(count)
}
pub fn stats(&self) -> Result<CacheStats> {
use crate::error::LockResultExt;
let stats = self.stats.read()
.map_lock_err("Failed to acquire read lock on cache stats")?;
Ok(stats.clone())
}
pub fn len(&self) -> Result<usize> {
use crate::error::LockResultExt;
let cache = self.cache.read()
.map_lock_err("Failed to acquire read lock on query cache")?;
Ok(cache.len())
}
pub fn is_empty(&self) -> Result<bool> {
self.len().map(|len| len == 0)
}
pub fn hit_rate(&self) -> Result<f64> {
let stats = self.stats()?;
let total = stats.hits + stats.misses;
if total == 0 {
return Ok(0.0);
}
Ok(stats.hits as f64 / total as f64)
}
}
impl Default for QueryCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_hit_miss() {
        let cache = QueryCache::new();
        let key = CacheKey::new("SELECT * FROM users", None);
        let result = cache.get(&key).unwrap();
        assert!(result.is_none());
        let tuples = vec![Tuple::new(vec![crate::Value::Int4(1)])];
        cache.put(key.clone(), tuples.clone(), vec!["users".to_string()]).unwrap();
        let result = cache.get(&key).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[test]
    fn test_cache_expiration() {
        let cache = QueryCache::with_config(100, Duration::from_millis(10));
        let key = CacheKey::new("SELECT * FROM users", None);
        let tuples = vec![Tuple::new(vec![crate::Value::Int4(1)])];
        cache.put(key.clone(), tuples, vec!["users".to_string()]).unwrap();
        assert!(cache.get(&key).unwrap().is_some());
        std::thread::sleep(Duration::from_millis(20));
        assert!(cache.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = QueryCache::new();
        let key1 = CacheKey::new("SELECT * FROM users", None);
        let key2 = CacheKey::new("SELECT * FROM orders", None);
        cache.put(key1.clone(), vec![], vec!["users".to_string()]).unwrap();
        cache.put(key2.clone(), vec![], vec!["orders".to_string()]).unwrap();
        let count = cache.invalidate_table("users").unwrap();
        assert_eq!(count, 1);
        assert!(cache.get(&key1).unwrap().is_none());
        assert!(cache.get(&key2).unwrap().is_some());
    }

    #[test]
    fn test_stats_and_hit_rate() {
        let cache = QueryCache::new();
        let key = CacheKey::new("SELECT 1", None);
        assert_eq!(cache.hit_rate().unwrap(), 0.0);
        assert!(cache.get(&key).unwrap().is_none());
        cache.put(key.clone(), vec![], vec![]).unwrap();
        assert!(cache.get(&key).unwrap().is_some());
        let stats = cache.stats().unwrap();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        // 1 hit out of 2 lookups; 0.5 is exactly representable.
        assert_eq!(cache.hit_rate().unwrap(), 0.5);
    }

    #[test]
    fn test_eviction_at_capacity() {
        let cache = QueryCache::with_config(2, Duration::from_secs(60));
        for sql in ["SELECT 1", "SELECT 2", "SELECT 3"].iter() {
            cache.put(CacheKey::new(sql, None), vec![], vec![]).unwrap();
        }
        // Which entry is evicted depends on Instant resolution, so only the
        // size bound and the eviction count are asserted.
        assert_eq!(cache.len().unwrap(), 2);
        assert_eq!(cache.stats().unwrap().evictions, 1);
    }

    #[test]
    fn test_replacing_existing_key_does_not_evict() {
        let cache = QueryCache::with_config(1, Duration::from_secs(60));
        let key = CacheKey::new("SELECT 1", None);
        cache.put(key.clone(), vec![], vec![]).unwrap();
        let tuples = vec![Tuple::new(vec![crate::Value::Int4(1)])];
        cache.put(key.clone(), tuples, vec![]).unwrap();
        assert_eq!(cache.len().unwrap(), 1);
        assert_eq!(cache.stats().unwrap().evictions, 0);
        assert_eq!(cache.get(&key).unwrap().unwrap().len(), 1);
    }

    #[test]
    fn test_invalidate_all() {
        let cache = QueryCache::new();
        cache
            .put(CacheKey::new("SELECT 1", None), vec![], vec!["a".to_string()])
            .unwrap();
        cache
            .put(
                CacheKey::new("SELECT 1", Some("dev".to_string())),
                vec![],
                vec!["a".to_string()],
            )
            .unwrap();
        assert_eq!(cache.invalidate_all().unwrap(), 2);
        assert!(cache.is_empty().unwrap());
        assert_eq!(cache.stats().unwrap().invalidations, 2);
    }
}