use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{RwLock, OnceLock};
use std::time::{Duration, Instant};
use serde::{Serialize, Deserialize};
use crate::error::{Error, Result};
/// Process-wide query-result cache, lazily initialized on first access.
static GLOBAL_QUERY_CACHE: OnceLock<QueryCache> = OnceLock::new();
/// Process-wide prepared-statement cache, lazily initialized on first access.
static GLOBAL_STMT_CACHE: OnceLock<PreparedStatementCache> = OnceLock::new();
/// Configuration for a [`QueryCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
/// Master switch: when `false`, `get`/`set` become no-ops.
pub enabled: bool,
/// Maximum number of entries kept before eviction kicks in.
pub max_entries: usize,
/// TTL applied to entries stored without an explicit TTL.
pub default_ttl: Duration,
/// Eviction strategy used when the cache is full.
pub strategy: CacheStrategy,
/// Whether empty array results are cached at all.
pub cache_empty_results: bool,
/// Optional prefix prepended to generated cache keys.
pub key_prefix: Option<String>,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: false,
max_entries: 1000,
default_ttl: Duration::from_secs(60),
strategy: CacheStrategy::LRU,
cache_empty_results: true,
key_prefix: None,
}
}
}
/// Eviction policy used by [`QueryCache`] when it reaches capacity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStrategy {
/// Evict the least-recently-used entry first.
LRU,
/// Evict the oldest-inserted entry first.
FIFO,
/// Evict the entry closest to expiring first.
TTL,
}
impl std::fmt::Display for CacheStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let label = match self {
Self::LRU => "LRU",
Self::FIFO => "FIFO",
Self::TTL => "TTL",
};
f.write_str(label)
}
}
/// A single cached query result plus bookkeeping used by eviction decisions.
#[derive(Debug, Clone)]
struct CacheEntry {
/// Cached payload, stored as JSON for type-erased retrieval.
data: serde_json::Value,
/// Insertion time; drives TTL expiry and FIFO eviction.
created_at: Instant,
/// Last read time; drives LRU eviction.
last_accessed: Instant,
/// Per-entry time-to-live.
ttl: Duration,
/// Model that produced the data, used for bulk invalidation.
model_name: String,
/// Number of cache hits served by this entry.
hit_count: u64,
}
impl CacheEntry {
fn new(data: serde_json::Value, ttl: Duration, model_name: &str) -> Self {
let now = Instant::now();
Self {
data,
created_at: now,
last_accessed: now,
ttl,
model_name: model_name.to_string(),
hit_count: 0,
}
}
fn is_expired(&self) -> bool {
self.created_at.elapsed() > self.ttl
}
fn touch(&mut self) {
self.last_accessed = Instant::now();
self.hit_count += 1;
}
}
/// Aggregate counters describing query-cache behavior.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
/// Lookups served from the cache.
pub hits: u64,
/// Lookups that found nothing usable.
pub misses: u64,
/// Entries currently stored.
pub entries: usize,
/// Approximate total bytes of cached JSON payloads.
pub size_bytes: usize,
/// Entries removed to make room or because they expired.
pub evictions: u64,
/// Entries removed explicitly (invalidate / clear).
pub invalidations: u64,
}
impl CacheStats {
/// Fraction of lookups that were hits, in `[0.0, 1.0]`; 0.0 when no
/// lookups have happened yet.
pub fn hit_ratio(&self) -> f64 {
match self.hits + self.misses {
0 => 0.0,
total => self.hits as f64 / total as f64,
}
}
}
/// Thread-safe, in-process query-result cache with configurable eviction.
///
/// All interior state sits behind `RwLock`s so a shared reference suffices
/// for every operation; a poisoned lock is treated as a soft failure.
#[derive(Debug)]
pub struct QueryCache {
/// Runtime-tunable configuration.
config: RwLock<CacheConfig>,
/// Key → entry map holding the cached results.
cache: RwLock<HashMap<String, CacheEntry>>,
/// Hit/miss/eviction counters.
stats: RwLock<CacheStats>,
}
impl QueryCache {
pub fn new() -> Self {
Self {
config: RwLock::new(CacheConfig::default()),
cache: RwLock::new(HashMap::new()),
stats: RwLock::new(CacheStats::default()),
}
}
pub fn with_config(config: CacheConfig) -> Self {
Self {
config: RwLock::new(config),
cache: RwLock::new(HashMap::new()),
stats: RwLock::new(CacheStats::default()),
}
}
pub fn global() -> &'static QueryCache {
GLOBAL_QUERY_CACHE.get_or_init(QueryCache::new)
}
pub fn init_global(config: CacheConfig) -> &'static QueryCache {
let _ = GLOBAL_QUERY_CACHE.set(QueryCache::with_config(config));
QueryCache::global()
}
/// Turns caching on. Chainable.
pub fn enable(&self) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.enabled = true);
self
}
/// Turns caching off. Chainable.
pub fn disable(&self) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.enabled = false);
self
}
/// True when caching is enabled (false if the config lock is poisoned).
pub fn is_enabled(&self) -> bool {
self.config.read().ok().map_or(false, |cfg| cfg.enabled)
}
/// Caps the number of entries kept before eviction. Chainable.
pub fn set_max_entries(&self, max: usize) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.max_entries = max);
self
}
/// Sets the TTL used when `set` is called without one. Chainable.
pub fn set_default_ttl(&self, ttl: Duration) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.default_ttl = ttl);
self
}
/// Chooses the eviction strategy. Chainable.
pub fn set_strategy(&self, strategy: CacheStrategy) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.strategy = strategy);
self
}
/// Sets a prefix for generated cache keys. Chainable.
pub fn set_key_prefix(&self, prefix: &str) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.key_prefix = Some(prefix.to_string()));
self
}
/// Controls whether empty result sets are cached. Chainable.
pub fn set_cache_empty_results(&self, cache_empty: bool) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.cache_empty_results = cache_empty);
self
}
/// Clones the current configuration, or `None` if the lock is poisoned.
pub fn config(&self) -> Option<CacheConfig> {
match self.config.read() {
Ok(guard) => Some(guard.clone()),
Err(_) => None,
}
}
/// Builds the cache key `[prefix:]<table>:<query_hash>`; the prefix part
/// is omitted when no (or an empty) prefix is configured.
pub fn generate_key(&self, table: &str, query_hash: u64) -> String {
let prefix = self.config.read().ok().and_then(|cfg| cfg.key_prefix.clone());
match prefix.as_deref() {
Some(p) if !p.is_empty() => format!("{}:{}:{}", p, table, query_hash),
_ => format!("{}:{}", table, query_hash),
}
}
/// Looks up `key` and deserializes the cached JSON payload into `T`.
///
/// Returns `None` when caching is disabled, the key is missing, the entry
/// has expired (it is removed on the spot and counted as an eviction), or
/// the payload does not deserialize into `T`. Updates hit/miss statistics.
pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
if !self.is_enabled() {
return None;
}
// A hit mutates the entry (LRU timestamp + hit count), so take the write lock.
let mut cache = self.cache.write().ok()?;
let expired = match cache.get(key) {
Some(entry) => entry.is_expired(),
None => {
if let Ok(mut stats) = self.stats.write() {
stats.misses += 1;
}
return None;
}
};
if expired {
// Lazy expiry: drop the stale entry and keep the byte and eviction
// accounting consistent with evict_expired().
if let Some(old) = cache.remove(key) {
if let Ok(mut stats) = self.stats.write() {
stats.misses += 1;
stats.evictions += 1;
stats.size_bytes = stats.size_bytes.saturating_sub(old.data.to_string().len());
stats.entries = cache.len();
}
}
return None;
}
let entry = cache.get_mut(key)?;
entry.touch();
if let Ok(mut stats) = self.stats.write() {
stats.hits += 1;
}
serde_json::from_value(entry.data.clone()).ok()
}
/// Serializes `value` and stores it under `key`.
///
/// No-op when caching is disabled. Empty array results are skipped when
/// `cache_empty_results` is false. Evicts entries (per the configured
/// strategy) until there is room, then inserts and updates statistics.
///
/// # Errors
/// Returns an error if the config/cache locks are poisoned or the value
/// cannot be serialized to JSON.
pub fn set<T: Serialize>(&self, key: &str, value: &T, ttl: Option<Duration>, model_name: &str) -> Result<()> {
if !self.is_enabled() {
return Ok(());
}
// Snapshot everything we need from the config in a single read.
let (ttl, max_entries, cache_empty) = {
let config = self.config.read()
.map_err(|_| Error::internal("Failed to read cache config"))?;
(ttl.unwrap_or(config.default_ttl), config.max_entries, config.cache_empty_results)
};
let data = serde_json::to_value(value)
.map_err(|e| Error::internal(format!("Failed to serialize cache value: {}", e)))?;
// Optionally skip caching empty array results.
if !cache_empty && matches!(&data, serde_json::Value::Array(arr) if arr.is_empty()) {
return Ok(());
}
let entry = CacheEntry::new(data, ttl, model_name);
let entry_size = entry.data.to_string().len();
let mut cache = self.cache.write()
.map_err(|_| Error::internal("Failed to acquire cache write lock"))?;
// Make room; bail out if nothing could be evicted (e.g. max_entries == 0)
// so this loop can never spin forever.
while cache.len() >= max_entries {
let before = cache.len();
self.evict_one(&mut cache);
if cache.len() == before {
break;
}
}
// Overwriting an existing key must not double-count its bytes.
let replaced_size = cache.insert(key.to_string(), entry)
.map(|old| old.data.to_string().len())
.unwrap_or(0);
if let Ok(mut stats) = self.stats.write() {
stats.entries = cache.len();
stats.size_bytes = stats.size_bytes.saturating_sub(replaced_size) + entry_size;
}
Ok(())
}
/// Removes a single key; returns whether anything was removed.
/// Keeps the `entries` and `size_bytes` statistics in sync.
pub fn invalidate(&self, key: &str) -> bool {
if let Ok(mut cache) = self.cache.write() {
if let Some(old) = cache.remove(key) {
if let Ok(mut stats) = self.stats.write() {
stats.invalidations += 1;
stats.entries = cache.len();
// Release the removed entry's bytes from the running total.
stats.size_bytes = stats.size_bytes.saturating_sub(old.data.to_string().len());
}
true
} else {
false
}
} else {
false
}
}
/// Removes every entry produced by `model_name` (bulk invalidation after
/// a write to that model's table).
pub fn invalidate_model(&self, model_name: &str) {
if let Ok(mut cache) = self.cache.write() {
let keys_to_remove: Vec<String> = cache.iter()
.filter(|(_, entry)| entry.model_name == model_name)
.map(|(key, _)| key.clone())
.collect();
let count = keys_to_remove.len();
let mut freed = 0usize;
for key in keys_to_remove {
if let Some(old) = cache.remove(&key) {
freed += old.data.to_string().len();
}
}
if let Ok(mut stats) = self.stats.write() {
stats.invalidations += count as u64;
stats.entries = cache.len();
// Release the removed entries' bytes from the running total.
stats.size_bytes = stats.size_bytes.saturating_sub(freed);
}
}
}
pub fn clear(&self) {
if let Ok(mut cache) = self.cache.write() {
let count = cache.len();
cache.clear();
if let Ok(mut stats) = self.stats.write() {
stats.invalidations += count as u64;
stats.entries = 0;
stats.size_bytes = 0;
}
}
}
pub fn stats(&self) -> CacheStats {
self.stats.read().map(|s| s.clone()).unwrap_or_default()
}
pub fn reset_stats(&self) {
if let Ok(mut stats) = self.stats.write() {
*stats = CacheStats::default();
if let Ok(cache) = self.cache.read() {
stats.entries = cache.len();
}
}
}
/// Removes all expired entries, counting them as evictions and releasing
/// their bytes from `size_bytes`.
pub fn evict_expired(&self) {
if let Ok(mut cache) = self.cache.write() {
let keys_to_remove: Vec<String> = cache.iter()
.filter(|(_, entry)| entry.is_expired())
.map(|(key, _)| key.clone())
.collect();
let count = keys_to_remove.len();
let mut freed = 0usize;
for key in keys_to_remove {
if let Some(old) = cache.remove(&key) {
freed += old.data.to_string().len();
}
}
if let Ok(mut stats) = self.stats.write() {
stats.evictions += count as u64;
stats.entries = cache.len();
// Release the evicted entries' bytes from the running total.
stats.size_bytes = stats.size_bytes.saturating_sub(freed);
}
}
}
/// Evicts a single entry chosen by the configured strategy:
/// LRU → least recently accessed, FIFO → oldest insertion,
/// TTL → closest to expiry. No-op on an empty cache.
fn evict_one(&self, cache: &mut HashMap<String, CacheEntry>) {
let strategy = self.config.read()
.map(|c| c.strategy)
.unwrap_or(CacheStrategy::LRU);
let victim = match strategy {
CacheStrategy::LRU => {
cache.iter()
.min_by_key(|(_, entry)| entry.last_accessed)
.map(|(key, _)| key.clone())
}
CacheStrategy::FIFO => {
cache.iter()
.min_by_key(|(_, entry)| entry.created_at)
.map(|(key, _)| key.clone())
}
CacheStrategy::TTL => {
cache.iter()
.min_by_key(|(_, entry)| {
// Remaining lifetime; already-expired entries saturate to zero.
entry.ttl.checked_sub(entry.created_at.elapsed())
.unwrap_or(Duration::ZERO)
})
.map(|(key, _)| key.clone())
}
};
if let Some(key) = victim {
if let Some(old) = cache.remove(&key) {
if let Ok(mut stats) = self.stats.write() {
stats.evictions += 1;
// Release the evicted entry's bytes from the running total.
stats.size_bytes = stats.size_bytes.saturating_sub(old.data.to_string().len());
}
}
}
}
/// True when `key` is present and not yet expired (does not touch stats).
pub fn contains(&self, key: &str) -> bool {
self.cache.read()
.ok()
.and_then(|cache| cache.get(key).map(|entry| !entry.is_expired()))
.unwrap_or(false)
}
/// Number of entries currently stored (including any not-yet-evicted
/// expired ones).
pub fn len(&self) -> usize {
match self.cache.read() {
Ok(cache) => cache.len(),
Err(_) => 0,
}
}
/// True when no entries are stored.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for QueryCache {
fn default() -> Self {
Self::new()
}
}
/// A cached SQL statement with usage metadata for eviction and reporting.
#[derive(Debug, Clone)]
struct PreparedStatement {
/// The statement's SQL text.
sql: String,
/// When the statement was first cached; drives max-age expiry.
prepared_at: Instant,
/// Last time the statement was fetched; drives LRU eviction.
last_used: Instant,
/// Number of recorded executions.
execution_count: u64,
/// Running average execution time in microseconds.
avg_execution_time_us: u64,
}
impl PreparedStatement {
/// Creates a freshly-prepared statement record with zeroed execution stats.
fn new(sql: String) -> Self {
let now = Instant::now();
Self {
sql,
prepared_at: now,
last_used: now,
execution_count: 0,
avg_execution_time_us: 0,
}
}
/// Records one execution, updating the running average execution time.
///
/// The running total is computed in `u128` so `avg * count` cannot
/// overflow `u64` for long-lived, frequently-executed statements.
fn record_execution(&mut self, execution_time_us: u64) {
self.last_used = Instant::now();
let total = self.avg_execution_time_us as u128 * self.execution_count as u128
+ execution_time_us as u128;
self.execution_count += 1;
self.avg_execution_time_us = (total / self.execution_count as u128) as u64;
}
}
/// Aggregate counters describing prepared-statement cache behavior.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PreparedStatementStats {
/// Lookups that found a fresh cached statement.
pub hits: u64,
/// Lookups that had to (re-)cache the statement.
pub misses: u64,
/// Statements currently cached.
pub cached_count: usize,
/// Executions recorded across all statements.
pub total_executions: u64,
/// Statements evicted to respect the size limit.
pub evictions: u64,
}
impl PreparedStatementStats {
/// Fraction of lookups that were hits, in `[0.0, 1.0]`; 0.0 when no
/// lookups have happened yet.
pub fn hit_ratio(&self) -> f64 {
match self.hits + self.misses {
0 => 0.0,
total => self.hits as f64 / total as f64,
}
}
}
/// Configuration for a [`PreparedStatementCache`].
#[derive(Debug, Clone)]
pub struct PreparedStatementConfig {
/// Master switch for statement caching.
pub enabled: bool,
/// Maximum number of statements kept before LRU eviction.
pub max_statements: usize,
/// Maximum age before a cached statement is considered stale.
pub max_age: Duration,
}
impl Default for PreparedStatementConfig {
fn default() -> Self {
Self {
enabled: false,
max_statements: 500,
max_age: Duration::from_secs(3600), }
}
}
/// Thread-safe cache of prepared SQL statements keyed by the SQL's hash.
#[derive(Debug)]
pub struct PreparedStatementCache {
/// Runtime-tunable configuration.
config: RwLock<PreparedStatementConfig>,
/// SQL-hash → statement map.
statements: RwLock<HashMap<u64, PreparedStatement>>,
/// Hit/miss/eviction counters.
stats: RwLock<PreparedStatementStats>,
}
impl PreparedStatementCache {
pub fn new() -> Self {
Self {
config: RwLock::new(PreparedStatementConfig::default()),
statements: RwLock::new(HashMap::new()),
stats: RwLock::new(PreparedStatementStats::default()),
}
}
pub fn with_config(config: PreparedStatementConfig) -> Self {
Self {
config: RwLock::new(config),
statements: RwLock::new(HashMap::new()),
stats: RwLock::new(PreparedStatementStats::default()),
}
}
pub fn global() -> &'static PreparedStatementCache {
GLOBAL_STMT_CACHE.get_or_init(PreparedStatementCache::new)
}
pub fn init_global(config: PreparedStatementConfig) -> &'static PreparedStatementCache {
let _ = GLOBAL_STMT_CACHE.set(PreparedStatementCache::with_config(config));
PreparedStatementCache::global()
}
/// Turns statement caching on. Chainable.
pub fn enable(&self) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.enabled = true);
self
}
/// Turns statement caching off. Chainable.
pub fn disable(&self) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.enabled = false);
self
}
/// True when statement caching is enabled (false if the lock is poisoned).
pub fn is_enabled(&self) -> bool {
self.config.read().ok().map_or(false, |cfg| cfg.enabled)
}
/// Caps the number of cached statements. Chainable.
pub fn set_max_statements(&self, max: usize) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.max_statements = max);
self
}
/// Sets the maximum age before a statement is re-prepared. Chainable.
pub fn set_max_age(&self, age: Duration) -> &Self {
let _ = self.config.write().map(|mut cfg| cfg.max_age = age);
self
}
/// Clones the current configuration, or `None` if the lock is poisoned.
pub fn config(&self) -> Option<PreparedStatementConfig> {
match self.config.read() {
Ok(guard) => Some(guard.clone()),
Err(_) => None,
}
}
/// Hashes a SQL string with the std `DefaultHasher` for use as a cache key.
/// Deterministic within a process run, but not across processes or Rust
/// versions — suitable only for in-memory keys.
pub fn hash_sql(sql: &str) -> u64 {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
sql.hash(&mut hasher);
hasher.finish()
}
/// Returns the SQL to execute and whether it was served from the cache.
///
/// A cached statement younger than `max_age` counts as a hit; a stale one
/// is dropped and re-cached, which (like an absent one) counts as a miss.
pub fn get_or_prepare(&self, sql: &str) -> (String, bool) {
if !self.is_enabled() {
return (sql.to_string(), false);
}
let hash = Self::hash_sql(sql);
let max_age = self.config.read()
.map(|cfg| cfg.max_age)
.unwrap_or(Duration::from_secs(3600));
let mut fresh_sql: Option<String> = None;
let mut stale = false;
if let Ok(mut statements) = self.statements.write() {
if let Some(stmt) = statements.get_mut(&hash) {
if stmt.prepared_at.elapsed() < max_age {
fresh_sql = Some(stmt.sql.clone());
} else {
stale = true;
}
}
// Drop stale entries so the miss path below re-caches a fresh copy.
if stale {
statements.remove(&hash);
}
}
if let Some(found) = fresh_sql {
if let Ok(mut stats) = self.stats.write() {
stats.hits += 1;
}
return (found, true);
}
self.cache_statement(sql);
if let Ok(mut stats) = self.stats.write() {
stats.misses += 1;
}
(sql.to_string(), false)
}
/// Inserts `sql` into the cache, evicting least-recently-used statements
/// to respect `max_statements`.
///
/// The eviction loop stops as soon as no victim can be found, so a
/// `max_statements` of 0 can no longer spin forever.
fn cache_statement(&self, sql: &str) {
let hash = Self::hash_sql(sql);
let max_statements = self.config.read()
.map(|c| c.max_statements)
.unwrap_or(500);
if let Ok(mut statements) = self.statements.write() {
while statements.len() >= max_statements {
let oldest_key = statements.iter()
.min_by_key(|(_, stmt)| stmt.last_used)
.map(|(key, _)| *key);
match oldest_key {
Some(key) => {
statements.remove(&key);
if let Ok(mut stats) = self.stats.write() {
stats.evictions += 1;
}
}
// Nothing left to evict (empty map with max_statements == 0).
None => break,
}
}
statements.insert(hash, PreparedStatement::new(sql.to_string()));
if let Ok(mut stats) = self.stats.write() {
stats.cached_count = statements.len();
}
}
}
/// Records timing for one execution of `sql`.
///
/// Updates the per-statement running average when the statement is cached,
/// and always bumps the global execution counter. No-op when disabled.
pub fn record_execution(&self, sql: &str, execution_time_us: u64) {
if !self.is_enabled() {
return;
}
let hash = Self::hash_sql(sql);
if let Ok(mut statements) = self.statements.write() {
if let Some(entry) = statements.get_mut(&hash) {
entry.record_execution(execution_time_us);
}
}
let _ = self.stats.write().map(|mut stats| stats.total_executions += 1);
}
/// Removes the cached statement matching `sql`; true if one was removed.
pub fn invalidate(&self, sql: &str) -> bool {
let hash = Self::hash_sql(sql);
let Ok(mut statements) = self.statements.write() else {
return false;
};
let removed = statements.remove(&hash).is_some();
if removed {
if let Ok(mut stats) = self.stats.write() {
stats.cached_count = statements.len();
}
}
removed
}
pub fn clear(&self) {
if let Ok(mut statements) = self.statements.write() {
statements.clear();
if let Ok(mut stats) = self.stats.write() {
stats.cached_count = 0;
}
}
}
pub fn stats(&self) -> PreparedStatementStats {
self.stats.read().map(|s| s.clone()).unwrap_or_default()
}
pub fn reset_stats(&self) {
if let Ok(mut stats) = self.stats.write() {
*stats = PreparedStatementStats::default();
if let Ok(statements) = self.statements.read() {
stats.cached_count = statements.len();
}
}
}
pub fn len(&self) -> usize {
self.statements.read().map(|s| s.len()).unwrap_or(0)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns a diagnostic snapshot of every cached statement.
///
/// SQL text is truncated to at most 100 bytes for the preview, always on
/// a UTF-8 character boundary so multi-byte SQL (string literals,
/// comments) can never cause a slice panic.
pub fn cached_statements_info(&self) -> Vec<CachedStatementInfo> {
let Ok(statements) = self.statements.read() else {
return Vec::new();
};
statements.iter().map(|(hash, stmt)| {
let sql_preview = if stmt.sql.len() > 100 {
// Walk back from byte 100 to the nearest char boundary.
let mut cut = 100;
while !stmt.sql.is_char_boundary(cut) {
cut -= 1;
}
format!("{}...", &stmt.sql[..cut])
} else {
stmt.sql.clone()
};
CachedStatementInfo {
hash: *hash,
sql_preview,
execution_count: stmt.execution_count,
avg_execution_time_us: stmt.avg_execution_time_us,
age_secs: stmt.prepared_at.elapsed().as_secs(),
}
}).collect()
}
}
impl Default for PreparedStatementCache {
fn default() -> Self {
Self::new()
}
}
/// Read-only snapshot of one cached statement, for diagnostics/reporting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedStatementInfo {
/// Hash key the statement is stored under.
pub hash: u64,
/// SQL text, truncated for display.
pub sql_preview: String,
/// Number of recorded executions.
pub execution_count: u64,
/// Average execution time in microseconds.
pub avg_execution_time_us: u64,
/// Seconds since the statement was prepared.
pub age_secs: u64,
}
/// Fluent builder that assembles deterministic cache keys from query parts.
///
/// Fragments are joined with `:` in insertion order, so the same sequence
/// of calls always yields the same key (and the same hash).
#[derive(Debug, Default)]
pub struct CacheKeyBuilder {
parts: Vec<String>,
}
impl CacheKeyBuilder {
/// Starts an empty builder.
pub fn new() -> Self {
CacheKeyBuilder::default()
}
/// Internal: appends one fragment and returns the builder for chaining.
fn push(mut self, part: String) -> Self {
self.parts.push(part);
self
}
/// Tags the key with the table name (`t:<table>`).
pub fn table(self, table: &str) -> Self {
self.push(format!("t:{}", table))
}
/// Adds a `column=value` condition fragment.
pub fn condition(self, column: &str, value: impl std::fmt::Display) -> Self {
self.push(format!("{}={}", column, value))
}
/// Adds an ordering fragment (`o:<column>:<direction>`).
pub fn order(self, column: &str, direction: &str) -> Self {
self.push(format!("o:{}:{}", column, direction))
}
/// Adds a limit fragment (`l:<limit>`).
pub fn limit(self, limit: u64) -> Self {
self.push(format!("l:{}", limit))
}
/// Adds an offset fragment (`off:<offset>`).
pub fn offset(self, offset: u64) -> Self {
self.push(format!("off:{}", offset))
}
/// Adds a caller-supplied fragment verbatim.
pub fn raw(self, part: &str) -> Self {
self.push(part.to_string())
}
/// Joins all fragments into the final key string.
pub fn build(self) -> String {
self.parts.join(":")
}
/// Builds the key and hashes it with the std `DefaultHasher`.
pub fn build_hash(self) -> u64 {
use std::collections::hash_map::DefaultHasher;
let key = self.build();
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
hasher.finish()
}
}
/// Per-query caching options: optional key override, TTL, and tags.
#[derive(Debug, Clone)]
pub struct CacheOptions {
/// Explicit cache key; `None` means the key is derived elsewhere.
pub key: Option<String>,
/// How long the cached result should live.
pub ttl: Duration,
/// Tags usable for group invalidation.
pub tags: Vec<String>,
}
impl CacheOptions {
/// Creates options with the given TTL, no explicit key, and no tags.
pub fn new(ttl: Duration) -> Self {
CacheOptions {
ttl,
key: None,
tags: Vec::new(),
}
}
/// Sets an explicit cache key. Chainable.
pub fn with_key(mut self, key: &str) -> Self {
self.key = Some(key.to_string());
self
}
/// Appends a single tag. Chainable.
pub fn with_tag(mut self, tag: &str) -> Self {
self.tags.push(tag.to_string());
self
}
/// Appends several tags at once, preserving order. Chainable.
pub fn with_tags(mut self, tags: &[&str]) -> Self {
for tag in tags {
self.tags.push((*tag).to_string());
}
self
}
}
/// Holds a list of queries to pre-execute when warming the query cache.
#[derive(Debug, Clone)]
pub struct CacheWarmer {
queries: Vec<WarmQuery>,
}
/// One registered warm-up query: target cache key, SQL text, and TTL.
#[derive(Debug, Clone)]
struct WarmQuery {
key: String,
sql: String,
ttl: Duration,
}
impl CacheWarmer {
/// Starts with no registered queries.
pub fn new() -> Self {
CacheWarmer { queries: Vec::new() }
}
/// Registers a query to warm under `key` with the given TTL. Chainable.
pub fn add_query(mut self, key: &str, sql: &str, ttl: Duration) -> Self {
let query = WarmQuery {
key: key.to_string(),
sql: sql.to_string(),
ttl,
};
self.queries.push(query);
self
}
/// Number of registered queries.
pub fn query_count(&self) -> usize {
self.queries.len()
}
/// Iterates over `(key, sql, ttl)` triples in registration order.
pub fn queries(&self) -> impl Iterator<Item = (&str, &str, Duration)> {
self.queries.iter().map(|query| (query.key.as_str(), query.sql.as_str(), query.ttl))
}
}
impl Default for CacheWarmer {
/// Equivalent to [`CacheWarmer::new`].
fn default() -> Self {
CacheWarmer::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
// Defaults ship with caching disabled and documented limits.
#[test]
fn test_cache_config_default() {
let config = CacheConfig::default();
assert!(!config.enabled);
assert_eq!(config.max_entries, 1000);
assert_eq!(config.default_ttl, Duration::from_secs(60));
}
// Round-trip: set then get, with hit/entry counters updated.
#[test]
fn test_query_cache_basic() {
let cache = QueryCache::new();
cache.enable();
assert!(cache.is_enabled());
assert!(cache.is_empty());
cache.set("test_key", &vec![1, 2, 3], None, "test_model").unwrap();
assert!(!cache.is_empty());
assert_eq!(cache.len(), 1);
let result: Option<Vec<i32>> = cache.get("test_key");
assert_eq!(result, Some(vec![1, 2, 3]));
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.entries, 1);
}
// Single-key, per-model, and full invalidation paths.
#[test]
fn test_query_cache_invalidation() {
let cache = QueryCache::new();
cache.enable();
cache.set("key1", &"value1", None, "model1").unwrap();
cache.set("key2", &"value2", None, "model1").unwrap();
cache.set("key3", &"value3", None, "model2").unwrap();
assert_eq!(cache.len(), 3);
cache.invalidate("key1");
assert_eq!(cache.len(), 2);
cache.invalidate_model("model1");
assert_eq!(cache.len(), 1);
cache.clear();
assert!(cache.is_empty());
}
// First lookup is a miss (and caches the SQL); second is a hit.
#[test]
fn test_prepared_statement_cache() {
let cache = PreparedStatementCache::new();
cache.enable();
let sql = "SELECT * FROM users WHERE id = $1";
let (_, cached) = cache.get_or_prepare(sql);
assert!(!cached);
let (_, cached) = cache.get_or_prepare(sql);
assert!(cached);
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
}
// Each builder call contributes its prefixed fragment to the key.
#[test]
fn test_cache_key_builder() {
let key = CacheKeyBuilder::new()
.table("users")
.condition("active", true)
.condition("role", "admin")
.order("created_at", "desc")
.limit(10)
.build();
assert!(key.contains("t:users"));
assert!(key.contains("active=true"));
assert!(key.contains("role=admin"));
assert!(key.contains("o:created_at:desc"));
assert!(key.contains("l:10"));
}
// hit_ratio is 0.0 with no traffic, hits/(hits+misses) otherwise.
#[test]
fn test_cache_stats_hit_ratio() {
let mut stats = CacheStats::default();
assert_eq!(stats.hit_ratio(), 0.0);
stats.hits = 75;
stats.misses = 25;
assert!((stats.hit_ratio() - 0.75).abs() < 0.001);
}
// Display renders the strategy's canonical name.
#[test]
fn test_cache_strategy_display() {
assert_eq!(format!("{}", CacheStrategy::LRU), "LRU");
assert_eq!(format!("{}", CacheStrategy::FIFO), "FIFO");
assert_eq!(format!("{}", CacheStrategy::TTL), "TTL");
}
}