use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Tuning knobs for [`QueryCache`].
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate cannot build it with a
/// struct literal; start from `QueryCacheConfig::default()` and adjust the public fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct QueryCacheConfig {
    /// Capacity consulted by `QueryCache::set` when deciding whether to evict the oldest entry.
    pub max_size: usize,
    /// How long an entry stays valid after it was stored; older entries are treated as misses.
    pub ttl: Duration,
    /// NOTE(review): not read by any cache code visible in this file; presumably toggles
    /// query-plan caching elsewhere — confirm.
    pub cache_plans: bool,
}
impl Default for QueryCacheConfig {
fn default() -> Self {
Self {
max_size: 1000,
ttl: Duration::from_secs(5 * 60), cache_plans: true,
}
}
}
/// One stored query result plus the bookkeeping `QueryCache` needs for it.
#[derive(Clone, Debug)]
pub struct CachedQuery {
    /// SQL text of the query; `QueryCache` also uses it as the map key.
    pub sql: String,
    /// Parameter fingerprint supplied by the caller; a lookup with a different value misses.
    pub params_hash: u64,
    /// Cached result bytes, if any (encoding is decided by the caller).
    pub result: Option<Vec<u8>>,
    /// When the entry was stored; drives TTL expiry and oldest-first eviction.
    pub cached_at: Instant,
    /// Number of `QueryCache::record_hit` calls since the entry was stored.
    pub hit_count: usize,
}
/// In-memory cache of query results keyed by SQL text.
///
/// The map lives behind an `RwLock`, so every operation takes `&self` and the cache can be
/// shared between threads.
pub struct QueryCache {
    /// Capacity and TTL settings supplied at construction.
    config: QueryCacheConfig,
    /// SQL text -> the most recently stored entry for that text.
    cache: Arc<RwLock<HashMap<String, CachedQuery>>>,
}
impl QueryCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: QueryCacheConfig) -> Self {
        Self {
            config,
            cache: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a clone of the entry stored for `sql`.
    ///
    /// A hit requires three things: an entry exists for `sql`, it was stored with the same
    /// `params_hash`, and it is not older than the configured TTL.
    ///
    /// Expired entries are not removed here, because that would need the write lock. They stay
    /// in the map until overwritten, evicted by `set`, or dropped by `clear`. This method does
    /// not touch `hit_count`; callers use `record_hit` for that.
    pub fn get(&self, sql: &str, params_hash: u64) -> Option<CachedQuery> {
        let cache = self.read_map();
        let entry = cache.get(sql)?;
        if self.is_expired(entry) || entry.params_hash != params_hash {
            return None;
        }
        Some(entry.clone())
    }

    /// Stores `result` for `sql` with a fresh timestamp and a `hit_count` of zero.
    ///
    /// Entries are keyed by SQL text only, so storing the same SQL with a different
    /// `params_hash` replaces the previous entry.
    ///
    /// Eviction only happens when a *new* key is inserted into a full cache. In that case the
    /// entry with the oldest `cached_at` is removed first. Every entry shares the same TTL, so
    /// this is also the entry closest to (or furthest past) expiry.
    ///
    /// With `max_size == 0`, caching is disabled and nothing is stored.
    pub fn set(&self, sql: String, params_hash: u64, result: Option<Vec<u8>>) {
        if self.config.max_size == 0 {
            return;
        }
        let mut cache = self.write_map();
        // Replacing an existing key does not grow the map. Evicting in that case would drop an
        // unrelated entry for nothing.
        if !cache.contains_key(&sql) && cache.len() >= self.config.max_size {
            let oldest_key = cache
                .iter()
                .min_by_key(|(_, entry)| entry.cached_at)
                .map(|(key, _)| key.clone());
            if let Some(key) = oldest_key {
                cache.remove(&key);
            }
        }
        let cached = CachedQuery {
            sql: sql.clone(),
            params_hash,
            result,
            cached_at: Instant::now(),
            hit_count: 0,
        };
        cache.insert(sql, cached);
    }

    /// Increments `hit_count` on the entry stored for `sql`; does nothing if there is none.
    pub fn record_hit(&self, sql: &str) {
        let mut cache = self.write_map();
        if let Some(entry) = cache.get_mut(sql) {
            entry.hit_count += 1;
        }
    }

    /// Removes every entry.
    pub fn clear(&self) {
        self.write_map().clear();
    }

    /// Returns a snapshot of entry, hit and expiry counts.
    ///
    /// `total_entries` includes expired entries that are still stored.
    pub fn stats(&self) -> CacheStats {
        let cache = self.read_map();
        let total_hits: usize = cache.values().map(|entry| entry.hit_count).sum();
        let expired_entries = cache
            .values()
            .filter(|entry| self.is_expired(entry))
            .count();
        CacheStats {
            total_entries: cache.len(),
            total_hits,
            expired_entries,
        }
    }

    /// True when `entry` is older than the configured TTL; shared by `get` and `stats`.
    fn is_expired(&self, entry: &CachedQuery) -> bool {
        entry.cached_at.elapsed() > self.config.ttl
    }

    /// Acquires the read lock, recovering the guard if the lock is poisoned.
    ///
    /// Poisoning only means another thread panicked while holding the lock. The operations
    /// done under it (insert, remove, clear, counter increment) cannot leave the map
    /// half-updated. Recovering keeps the cache usable instead of silently disabling it forever.
    fn read_map(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, CachedQuery>> {
        self.cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the write lock, recovering the guard if the lock is poisoned (see `read_map`).
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, CachedQuery>> {
        self.cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Point-in-time counters reported by `QueryCache::stats`.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Entries currently stored, including expired ones that have not been evicted yet.
    pub total_entries: usize,
    /// Sum of `hit_count` across all stored entries.
    pub total_hits: usize,
    /// Stored entries whose age exceeds the configured TTL.
    pub expired_entries: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given capacity and TTL, with plan caching enabled.
    fn config(max_size: usize, ttl: Duration) -> QueryCacheConfig {
        QueryCacheConfig {
            max_size,
            ttl,
            cache_plans: true,
        }
    }

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        let sql = "SELECT * FROM users WHERE id = $1".to_string();
        let params_hash = 12345u64;
        assert!(cache.get(&sql, params_hash).is_none());
        cache.set(sql.clone(), params_hash, Some(vec![1, 2, 3]));
        let cached = cache.get(&sql, params_hash).unwrap();
        assert_eq!(cached.sql, sql);
        assert_eq!(cached.params_hash, params_hash);
        assert_eq!(cached.result, Some(vec![1, 2, 3]));
        assert_eq!(cached.hit_count, 0);
    }

    #[test]
    fn test_params_hash_mismatch_is_a_miss() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        cache.set("SELECT 1".to_string(), 1, None);
        assert!(cache.get("SELECT 1", 2).is_none());
        assert!(cache.get("SELECT 1", 1).is_some());
    }

    #[test]
    fn test_set_replaces_entry_for_same_sql() {
        // Entries are keyed by SQL text only: a new params hash overwrites the old entry
        // and resets its hit counter.
        let cache = QueryCache::new(QueryCacheConfig::default());
        cache.set("q".to_string(), 1, Some(vec![1]));
        cache.record_hit("q");
        cache.set("q".to_string(), 2, Some(vec![2]));
        assert!(cache.get("q", 1).is_none());
        let cached = cache.get("q", 2).unwrap();
        assert_eq!(cached.result, Some(vec![2]));
        assert_eq!(cached.hit_count, 0);
        assert_eq!(cache.stats().total_entries, 1);
    }

    #[test]
    fn test_cache_expiration() {
        let cache = QueryCache::new(config(10, Duration::from_millis(100)));
        let sql = "SELECT * FROM users".to_string();
        let params_hash = 0u64;
        cache.set(sql.clone(), params_hash, None);
        assert!(cache.get(&sql, params_hash).is_some());
        std::thread::sleep(Duration::from_millis(150));
        assert!(cache.get(&sql, params_hash).is_none());
        // Expired entries remain stored until evicted or cleared.
        let stats = cache.stats();
        assert_eq!(stats.total_entries, 1);
        assert_eq!(stats.expired_entries, 1);
    }

    #[test]
    fn test_eviction_keeps_size_at_capacity() {
        let cache = QueryCache::new(config(2, Duration::from_secs(60)));
        cache.set("a".to_string(), 1, None);
        cache.set("b".to_string(), 2, None);
        cache.set("c".to_string(), 3, None);
        assert_eq!(cache.stats().total_entries, 2);
        assert!(cache.get("c", 3).is_some());
    }

    #[test]
    fn test_cache_stats() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        cache.set("query1".to_string(), 1, None);
        cache.set("query2".to_string(), 2, None);
        cache.record_hit("query1");
        cache.record_hit("query1");
        cache.record_hit("query2");
        // Recording a hit for an unknown query is a no-op.
        cache.record_hit("missing");
        let stats = cache.stats();
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_hits, 3);
        assert_eq!(stats.expired_entries, 0);
    }

    #[test]
    fn test_clear_removes_everything() {
        let cache = QueryCache::new(QueryCacheConfig::default());
        cache.set("query1".to_string(), 1, None);
        cache.set("query2".to_string(), 2, None);
        cache.clear();
        assert!(cache.get("query1", 1).is_none());
        assert_eq!(cache.stats().total_entries, 0);
    }
}