use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use crate::error::{Error, Result};
// Process-wide singletons, lazily initialized on first access through
// `QueryCache::global()` / `PreparedStatementCache::global()` (or seeded
// once via the respective `init_global`).
static GLOBAL_QUERY_CACHE: OnceLock<QueryCache> = OnceLock::new();
static GLOBAL_STMT_CACHE: OnceLock<PreparedStatementCache> = OnceLock::new();
/// Tunable settings for [`QueryCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
/// Master switch; when `false`, gets and sets are no-ops.
pub enabled: bool,
/// Upper bound on the number of cached entries before eviction kicks in.
pub max_entries: usize,
/// Time-to-live applied to entries stored without an explicit TTL.
pub default_ttl: Duration,
/// Which entry is evicted when the cache is full.
pub strategy: CacheStrategy,
/// Whether empty result arrays are cached (avoids re-running empty queries).
pub cache_empty_results: bool,
/// Optional prefix prepended to keys produced by `generate_key`.
pub key_prefix: Option<String>,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: false,
max_entries: 1000,
default_ttl: Duration::from_secs(60),
strategy: CacheStrategy::LRU,
cache_empty_results: true,
key_prefix: None,
}
}
}
/// Eviction policy used when the cache reaches its entry limit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStrategy {
    /// Evict the least-recently-accessed entry first.
    LRU,
    /// Evict the oldest entry (by insertion time) first.
    FIFO,
    /// Evict the entry closest to expiring first.
    TTL,
}

impl std::fmt::Display for CacheStrategy {
    /// Renders the strategy as its uppercase short name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::LRU => "LRU",
            Self::FIFO => "FIFO",
            Self::TTL => "TTL",
        };
        f.write_str(name)
    }
}
/// A single cached query result plus its bookkeeping metadata.
#[derive(Debug, Clone)]
struct CacheEntry {
/// Serialized result payload (cloned and deserialized on every hit).
data: serde_json::Value,
/// Approximate size: length of the JSON text at insert time.
size_bytes: usize,
/// Insertion timestamp; expiry is measured from here.
created_at: Instant,
/// Last read timestamp; drives LRU eviction.
last_accessed: Instant,
/// Per-entry time-to-live.
ttl: Duration,
/// Model/table the entry belongs to, for bulk invalidation.
model_name: String,
/// Number of times this entry was served (LRU path only updates this).
hit_count: u64,
}
impl CacheEntry {
    /// Builds a fresh entry whose creation and access timestamps both
    /// start at "now".
    fn new(data: serde_json::Value, size_bytes: usize, ttl: Duration, model_name: &str) -> Self {
        let stamped = Instant::now();
        Self {
            data,
            size_bytes,
            created_at: stamped,
            last_accessed: stamped,
            ttl,
            model_name: model_name.to_owned(),
            hit_count: 0,
        }
    }

    /// True once the entry has outlived its TTL.
    fn is_expired(&self) -> bool {
        self.created_at.elapsed() > self.ttl
    }

    /// Marks the entry as just used (for LRU ordering) and counts the hit.
    fn touch(&mut self) {
        self.hit_count += 1;
        self.last_accessed = Instant::now();
    }
}
/// Point-in-time snapshot of [`QueryCache`] counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
/// Lookups served from cache.
pub hits: u64,
/// Lookups that found nothing usable (absent or expired).
pub misses: u64,
/// Number of live entries at snapshot time.
pub entries: usize,
/// Approximate total payload size of all entries.
pub size_bytes: usize,
/// Entries removed to make room or because they expired.
pub evictions: u64,
/// Entries removed explicitly (invalidate / invalidate_model / clear).
pub invalidations: u64,
}
impl CacheStats {
    /// Fraction of lookups served from cache, in `[0.0, 1.0]`;
    /// returns `0.0` when no lookups have happened yet.
    pub fn hit_ratio(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
/// Thread-safe, in-process cache for serialized query results.
///
/// Statistics live in atomics so they can be read without taking the map
/// lock; the entry map itself is guarded by a `parking_lot` `RwLock`.
#[derive(Debug)]
pub struct QueryCache {
/// Current configuration (mutable at runtime via the setter methods).
config: RwLock<CacheConfig>,
/// Fast-path mirror of `config.enabled`.
enabled: AtomicBool,
/// Key -> cached entry.
cache: RwLock<HashMap<String, CacheEntry>>,
// The remaining fields are counters mirrored into `CacheStats` snapshots.
hits: AtomicU64,
misses: AtomicU64,
/// Mirrors `cache.len()` after every mutation.
entries: AtomicUsize,
/// Approximate total payload size of all entries.
size_bytes: AtomicUsize,
evictions: AtomicU64,
invalidations: AtomicU64,
}
impl QueryCache {
    /// Creates a disabled cache with [`CacheConfig::default`] settings.
    pub fn new() -> Self {
        // Delegate so default construction and explicit construction can
        // never drift apart.
        Self::with_config(CacheConfig::default())
    }

    /// Creates a cache from an explicit configuration.
    ///
    /// The runtime enabled flag is seeded from `config.enabled`.
    pub fn with_config(config: CacheConfig) -> Self {
        let enabled = config.enabled;
        Self {
            config: RwLock::new(config),
            enabled: AtomicBool::new(enabled),
            cache: RwLock::new(HashMap::new()),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            entries: AtomicUsize::new(0),
            size_bytes: AtomicUsize::new(0),
            evictions: AtomicU64::new(0),
            invalidations: AtomicU64::new(0),
        }
    }

    /// Loads every counter into a plain `CacheStats` value.
    ///
    /// Counters are loaded individually with relaxed ordering, so the
    /// snapshot is not guaranteed to be mutually consistent under
    /// concurrent writers — fine for monitoring purposes.
    fn snapshot_stats(&self) -> CacheStats {
        CacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            entries: self.entries.load(Ordering::Relaxed),
            size_bytes: self.size_bytes.load(Ordering::Relaxed),
            evictions: self.evictions.load(Ordering::Relaxed),
            invalidations: self.invalidations.load(Ordering::Relaxed),
        }
    }

    /// Publishes the current entry count (call while holding the map lock).
    fn record_entries_len(&self, entries: usize) {
        self.entries.store(entries, Ordering::Relaxed);
    }

    fn add_size_bytes(&self, bytes: usize) {
        self.size_bytes.fetch_add(bytes, Ordering::Relaxed);
    }

    fn subtract_size_bytes(&self, bytes: usize) {
        self.size_bytes.fetch_sub(bytes, Ordering::Relaxed);
    }

    fn overwrite_size_bytes(&self, bytes: usize) {
        self.size_bytes.store(bytes, Ordering::Relaxed);
    }

    /// Returns the process-wide cache, creating a default (disabled) one
    /// on first use.
    pub fn global() -> &'static QueryCache {
        GLOBAL_QUERY_CACHE.get_or_init(QueryCache::new)
    }

    /// Initializes the global cache with `config`.
    ///
    /// If the global cache was already initialized the config is silently
    /// ignored and the existing instance is returned.
    pub fn init_global(config: CacheConfig) -> &'static QueryCache {
        let _ = GLOBAL_QUERY_CACHE.set(QueryCache::with_config(config));
        QueryCache::global()
    }

    /// Turns the cache on (both in config and the fast-path flag).
    pub fn enable(&self) -> &Self {
        self.config.write().enabled = true;
        self.enabled.store(true, Ordering::Release);
        self
    }

    /// Turns the cache off; existing entries are kept but not served.
    pub fn disable(&self) -> &Self {
        self.config.write().enabled = false;
        self.enabled.store(false, Ordering::Release);
        self
    }

    /// Fast lock-free check of the enabled flag.
    pub fn is_enabled(&self) -> bool {
        self.enabled.load(Ordering::Acquire)
    }

    /// Sets the maximum number of entries (applies to future inserts).
    pub fn set_max_entries(&self, max: usize) -> &Self {
        self.config.write().max_entries = max;
        self
    }

    /// Sets the TTL used when `set` is called without an explicit TTL.
    pub fn set_default_ttl(&self, ttl: Duration) -> &Self {
        self.config.write().default_ttl = ttl;
        self
    }

    /// Sets the eviction strategy.
    pub fn set_strategy(&self, strategy: CacheStrategy) -> &Self {
        self.config.write().strategy = strategy;
        self
    }

    /// Sets the prefix prepended to keys from `generate_key`.
    pub fn set_key_prefix(&self, prefix: &str) -> &Self {
        self.config.write().key_prefix = Some(prefix.to_string());
        self
    }

    /// Controls whether empty result arrays are cached.
    pub fn set_cache_empty_results(&self, cache_empty: bool) -> &Self {
        self.config.write().cache_empty_results = cache_empty;
        self
    }

    /// Returns a copy of the current configuration.
    ///
    /// Always `Some`; the `Option` is kept for API compatibility.
    pub fn config(&self) -> Option<CacheConfig> {
        Some(self.config.read().clone())
    }

    /// Builds a cache key of the form `[prefix:]table:query_hash`.
    pub fn generate_key(&self, table: &str, query_hash: u64) -> String {
        let prefix = self.config.read().key_prefix.clone().unwrap_or_default();
        if prefix.is_empty() {
            format!("{}:{}", table, query_hash)
        } else {
            format!("{}:{}:{}", prefix, table, query_hash)
        }
    }

    /// Looks up `key` and deserializes the stored value into `T`.
    ///
    /// Returns `None` when caching is disabled, the key is absent or
    /// expired, or deserialization fails. Non-LRU hits are served under
    /// the shared lock; LRU hits and expired entries take the exclusive
    /// lock (double-checked) to touch or remove the entry.
    pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
        if !self.is_enabled() {
            return None;
        }
        let strategy = self.config.read().strategy;
        {
            // Fast path: shared lock, no mutation required.
            let cache = self.cache.read();
            match cache.get(key) {
                Some(entry) if !entry.is_expired() && strategy != CacheStrategy::LRU => {
                    self.hits.fetch_add(1, Ordering::Relaxed);
                    return serde_json::from_value(entry.data.clone()).ok();
                }
                // Expired entry, or LRU bookkeeping needed: fall through
                // to the write path below.
                Some(_) => {}
                None => {
                    self.misses.fetch_add(1, Ordering::Relaxed);
                    return None;
                }
            }
        }
        // Slow path: exclusive lock. Re-inspect the entry because it may
        // have changed between dropping the read lock and acquiring this
        // one.
        let mut cache = self.cache.write();
        match cache.get(key) {
            Some(entry) if entry.is_expired() => {
                if let Some(expired_entry) = cache.remove(key) {
                    self.record_entries_len(cache.len());
                    self.subtract_size_bytes(expired_entry.size_bytes);
                }
                self.misses.fetch_add(1, Ordering::Relaxed);
                None
            }
            Some(_) if strategy == CacheStrategy::LRU => {
                let entry = cache
                    .get_mut(key)
                    .expect("entry must exist after successful immutable lookup");
                entry.touch();
                self.hits.fetch_add(1, Ordering::Relaxed);
                serde_json::from_value(entry.data.clone()).ok()
            }
            Some(entry) => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                serde_json::from_value(entry.data.clone()).ok()
            }
            None => {
                self.misses.fetch_add(1, Ordering::Relaxed);
                None
            }
        }
    }

    /// Serializes `value` and stores it under `key`.
    ///
    /// Uses the configured default TTL when `ttl` is `None`. Empty result
    /// arrays are skipped unless `cache_empty_results` is set. With
    /// `max_entries == 0` (or nothing evictable) the value is silently
    /// not cached.
    ///
    /// # Errors
    ///
    /// Returns an internal error if `value` cannot be serialized to JSON.
    pub fn set<T: Serialize>(
        &self,
        key: &str,
        value: &T,
        ttl: Option<Duration>,
        model_name: &str,
    ) -> Result<()> {
        if !self.is_enabled() {
            return Ok(());
        }
        let config = self.config.read();
        let ttl = ttl.unwrap_or(config.default_ttl);
        let max_entries = config.max_entries;
        drop(config);
        let data = serde_json::to_value(value)
            .map_err(|e| Error::internal(format!("Failed to serialize cache value: {}", e)))?;
        if let serde_json::Value::Array(arr) = &data {
            if arr.is_empty() {
                let should_cache = self.config.read().cache_empty_results;
                if !should_cache {
                    return Ok(());
                }
            }
        }
        let entry_size = data.to_string().len();
        let entry = CacheEntry::new(data, entry_size, ttl, model_name);
        let mut cache = self.cache.write();
        // Only evict when inserting a brand-new key: replacing an existing
        // entry does not grow the map. `evict_one` returning false means
        // nothing was evictable, so we break instead of spinning forever
        // (the old unconditional `while len >= max_entries` loop never
        // terminated when `max_entries == 0`).
        if !cache.contains_key(key) {
            while cache.len() >= max_entries {
                if !self.evict_one(&mut cache) {
                    break;
                }
            }
            if cache.len() >= max_entries {
                // Zero capacity (or nothing evictable): skip caching.
                return Ok(());
            }
        }
        let replaced_entry = cache.insert(key.to_string(), entry);
        self.record_entries_len(cache.len());
        // Adjust the size counter by the delta between the old and new
        // payloads so it never transiently underflows.
        match replaced_entry {
            Some(previous) if previous.size_bytes >= entry_size => {
                self.subtract_size_bytes(previous.size_bytes - entry_size);
            }
            Some(previous) => {
                self.add_size_bytes(entry_size - previous.size_bytes);
            }
            None => {
                self.add_size_bytes(entry_size);
            }
        }
        Ok(())
    }

    /// Removes one key; returns whether anything was removed.
    pub fn invalidate(&self, key: &str) -> bool {
        let mut cache = self.cache.write();
        if let Some(removed) = cache.remove(key) {
            self.invalidations.fetch_add(1, Ordering::Relaxed);
            self.record_entries_len(cache.len());
            self.subtract_size_bytes(removed.size_bytes);
            true
        } else {
            false
        }
    }

    /// Removes every entry tagged with `model_name`.
    pub fn invalidate_model(&self, model_name: &str) {
        let mut cache = self.cache.write();
        let keys_to_remove: Vec<String> = cache
            .iter()
            .filter(|(_, entry)| entry.model_name == model_name)
            .map(|(key, _)| key.clone())
            .collect();
        let count = keys_to_remove.len();
        let mut removed_size = 0;
        for key in keys_to_remove {
            if let Some(entry) = cache.remove(&key) {
                removed_size += entry.size_bytes;
            }
        }
        if count > 0 {
            self.invalidations
                .fetch_add(count as u64, Ordering::Relaxed);
            self.record_entries_len(cache.len());
            self.subtract_size_bytes(removed_size);
        }
    }

    /// Removes every entry; each removal counts as an invalidation.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        let count = cache.len();
        let removed_size = cache.values().map(|entry| entry.size_bytes).sum::<usize>();
        cache.clear();
        if count > 0 {
            self.invalidations
                .fetch_add(count as u64, Ordering::Relaxed);
            self.record_entries_len(0);
            self.subtract_size_bytes(removed_size);
        }
    }

    /// Returns a snapshot of the current statistics.
    pub fn stats(&self) -> CacheStats {
        self.snapshot_stats()
    }

    /// Zeroes the event counters and resynchronizes the entry/size gauges
    /// from the live map.
    pub fn reset_stats(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
        self.evictions.store(0, Ordering::Relaxed);
        self.invalidations.store(0, Ordering::Relaxed);
        let cache = self.cache.read();
        self.record_entries_len(cache.len());
        self.overwrite_size_bytes(cache.values().map(|entry| entry.size_bytes).sum());
    }

    /// Removes every expired entry; each removal counts as an eviction.
    pub fn evict_expired(&self) {
        let mut cache = self.cache.write();
        let keys_to_remove: Vec<String> = cache
            .iter()
            .filter(|(_, entry)| entry.is_expired())
            .map(|(key, _)| key.clone())
            .collect();
        let count = keys_to_remove.len();
        let mut removed_size = 0;
        for key in keys_to_remove {
            if let Some(entry) = cache.remove(&key) {
                removed_size += entry.size_bytes;
            }
        }
        if count > 0 {
            self.evictions.fetch_add(count as u64, Ordering::Relaxed);
            self.record_entries_len(cache.len());
            self.subtract_size_bytes(removed_size);
        }
    }

    /// Evicts one victim chosen by the configured strategy.
    ///
    /// Returns `true` when an entry was actually removed, `false` when the
    /// map was empty — callers use this to avoid looping forever.
    fn evict_one(&self, cache: &mut HashMap<String, CacheEntry>) -> bool {
        let strategy = self.config.read().strategy;
        let key_to_remove = match strategy {
            CacheStrategy::LRU => cache
                .iter()
                .min_by_key(|(_, entry)| entry.last_accessed)
                .map(|(key, _)| key.clone()),
            CacheStrategy::FIFO => cache
                .iter()
                .min_by_key(|(_, entry)| entry.created_at)
                .map(|(key, _)| key.clone()),
            CacheStrategy::TTL => {
                // Victim is the entry with the least remaining lifetime
                // (already-expired entries saturate to zero).
                cache
                    .iter()
                    .min_by_key(|(_, entry)| {
                        entry
                            .ttl
                            .checked_sub(entry.created_at.elapsed())
                            .unwrap_or(Duration::ZERO)
                    })
                    .map(|(key, _)| key.clone())
            }
        };
        if let Some(key) = key_to_remove {
            if let Some(entry) = cache.remove(&key) {
                self.evictions.fetch_add(1, Ordering::Relaxed);
                self.record_entries_len(cache.len());
                self.subtract_size_bytes(entry.size_bytes);
                return true;
            }
        }
        false
    }

    /// True when `key` is present and not expired (does not touch stats).
    pub fn contains(&self, key: &str) -> bool {
        let cache = self.cache.read();
        if let Some(entry) = cache.get(key) {
            return !entry.is_expired();
        }
        false
    }

    /// Number of entries, from the lock-free gauge.
    pub fn len(&self) -> usize {
        self.entries.load(Ordering::Relaxed)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl Default for QueryCache {
fn default() -> Self {
Self::new()
}
}
/// Metadata for a single cached SQL statement.
#[derive(Debug, Clone)]
struct PreparedStatement {
    /// The exact SQL text that was prepared.
    sql: String,
    /// When the statement was first cached; staleness is measured from here.
    prepared_at: Instant,
    /// Last time the statement was served or executed (LRU eviction key).
    last_used: Instant,
    /// Number of recorded executions.
    execution_count: u64,
    /// Running average execution time in microseconds.
    avg_execution_time_us: u64,
}

impl PreparedStatement {
    /// Creates a never-executed statement stamped "now".
    fn new(sql: String) -> Self {
        let now = Instant::now();
        Self {
            sql,
            prepared_at: now,
            last_used: now,
            execution_count: 0,
            avg_execution_time_us: 0,
        }
    }

    /// Folds one execution time into the running average and refreshes
    /// `last_used`.
    ///
    /// The running total is computed in `u128` so that
    /// `avg * count + sample` cannot overflow `u64` for long-lived,
    /// frequently-executed statements; the resulting mean of `u64`
    /// samples always fits back into `u64`.
    fn record_execution(&mut self, execution_time_us: u64) {
        self.last_used = Instant::now();
        let total = self.avg_execution_time_us as u128 * self.execution_count as u128
            + execution_time_us as u128;
        self.execution_count += 1;
        self.avg_execution_time_us = (total / self.execution_count as u128) as u64;
    }
}
/// Counters describing [`PreparedStatementCache`] activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PreparedStatementStats {
/// Lookups that found a fresh cached statement.
pub hits: u64,
/// Lookups that had to (re-)cache the statement.
pub misses: u64,
/// Number of statements currently cached.
pub cached_count: usize,
/// Total executions recorded across all statements.
pub total_executions: u64,
/// Statements dropped to respect `max_statements`.
pub evictions: u64,
}
impl PreparedStatementStats {
    /// Fraction of lookups that hit the cache, in `[0.0, 1.0]`;
    /// returns `0.0` when nothing has been looked up yet.
    pub fn hit_ratio(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
/// Tunable settings for [`PreparedStatementCache`].
#[derive(Debug, Clone)]
pub struct PreparedStatementConfig {
/// Master switch; when `false` the cache is bypassed entirely.
pub enabled: bool,
/// Upper bound on cached statements before LRU eviction.
pub max_statements: usize,
/// Statements older than this are re-prepared on next use.
pub max_age: Duration,
}
impl Default for PreparedStatementConfig {
fn default() -> Self {
Self {
enabled: false,
max_statements: 500,
max_age: Duration::from_secs(3600), }
}
}
/// Thread-safe cache of prepared SQL statements keyed by SQL hash.
#[derive(Debug)]
pub struct PreparedStatementCache {
/// Current configuration (mutable at runtime via the setter methods).
config: RwLock<PreparedStatementConfig>,
/// Fast-path mirror of `config.enabled`.
enabled: AtomicBool,
/// SQL hash -> statement metadata.
statements: RwLock<HashMap<u64, PreparedStatement>>,
/// Hit/miss/eviction counters.
stats: RwLock<PreparedStatementStats>,
}
impl PreparedStatementCache {
    /// Creates a disabled cache with default settings.
    pub fn new() -> Self {
        // Delegate so default construction and explicit construction can
        // never drift apart.
        Self::with_config(PreparedStatementConfig::default())
    }

    /// Creates a cache from an explicit configuration.
    ///
    /// The runtime enabled flag is seeded from `config.enabled`.
    pub fn with_config(config: PreparedStatementConfig) -> Self {
        let enabled = config.enabled;
        Self {
            config: RwLock::new(config),
            enabled: AtomicBool::new(enabled),
            statements: RwLock::new(HashMap::new()),
            stats: RwLock::new(PreparedStatementStats::default()),
        }
    }

    /// Returns the process-wide statement cache, creating a default
    /// (disabled) one on first use.
    pub fn global() -> &'static PreparedStatementCache {
        GLOBAL_STMT_CACHE.get_or_init(PreparedStatementCache::new)
    }

    /// Initializes the global statement cache with `config`.
    ///
    /// If it was already initialized the config is silently ignored and
    /// the existing instance is returned.
    pub fn init_global(config: PreparedStatementConfig) -> &'static PreparedStatementCache {
        let _ = GLOBAL_STMT_CACHE.set(PreparedStatementCache::with_config(config));
        PreparedStatementCache::global()
    }

    /// Turns the cache on (both in config and the fast-path flag).
    pub fn enable(&self) -> &Self {
        self.config.write().enabled = true;
        self.enabled.store(true, Ordering::Release);
        self
    }

    /// Turns the cache off; existing statements are kept but not served.
    pub fn disable(&self) -> &Self {
        self.config.write().enabled = false;
        self.enabled.store(false, Ordering::Release);
        self
    }

    /// Fast lock-free check of the enabled flag.
    pub fn is_enabled(&self) -> bool {
        self.enabled.load(Ordering::Acquire)
    }

    /// Sets the maximum number of cached statements.
    pub fn set_max_statements(&self, max: usize) -> &Self {
        self.config.write().max_statements = max;
        self
    }

    /// Sets how long a prepared statement may be reused before being
    /// re-prepared.
    pub fn set_max_age(&self, age: Duration) -> &Self {
        self.config.write().max_age = age;
        self
    }

    /// Returns a copy of the current configuration.
    ///
    /// Always `Some`; the `Option` is kept for API compatibility.
    pub fn config(&self) -> Option<PreparedStatementConfig> {
        Some(self.config.read().clone())
    }

    /// Hashes SQL text with the std `DefaultHasher` to form the cache key.
    pub fn hash_sql(sql: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        sql.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns `(sql, was_cached)`, caching the statement on a miss.
    ///
    /// A cached statement older than `max_age` is treated as a miss: it
    /// is removed and a fresh copy is cached. The fresh-entry check is
    /// done once under the shared lock and re-done under the exclusive
    /// lock (double-checked) before the stale entry is dropped.
    pub fn get_or_prepare(&self, sql: &str) -> (String, bool) {
        if !self.is_enabled() {
            return (sql.to_string(), false);
        }
        let hash = Self::hash_sql(sql);
        let max_age = self.config.read().max_age;
        {
            // Fast path: shared lock only.
            let statements = self.statements.read();
            if let Some(stmt) = statements.get(&hash) {
                if stmt.prepared_at.elapsed() < max_age {
                    let sql = stmt.sql.clone();
                    drop(statements);
                    self.stats.write().hits += 1;
                    return (sql, true);
                }
            }
        }
        {
            // Slow path: exclusive lock; drop the stale entry if present.
            let mut statements = self.statements.write();
            if let Some(stmt) = statements.get(&hash) {
                if stmt.prepared_at.elapsed() < max_age {
                    let sql = stmt.sql.clone();
                    drop(statements);
                    self.stats.write().hits += 1;
                    return (sql, true);
                }
                statements.remove(&hash);
            }
        }
        self.cache_statement(sql);
        self.stats.write().misses += 1;
        (sql.to_string(), false)
    }

    /// Inserts `sql` into the map, evicting least-recently-used
    /// statements to respect `max_statements`.
    fn cache_statement(&self, sql: &str) {
        let hash = Self::hash_sql(sql);
        let max_statements = self.config.read().max_statements;
        // A zero capacity means "cache nothing": bail out instead of
        // looping forever trying to evict from an empty map (the old
        // unconditional `while len >= max_statements` loop never
        // terminated in that case).
        if max_statements == 0 {
            return;
        }
        let mut statements = self.statements.write();
        let mut evicted = 0u64;
        while statements.len() >= max_statements {
            let oldest_key = statements
                .iter()
                .min_by_key(|(_, stmt)| stmt.last_used)
                .map(|(key, _)| *key);
            match oldest_key {
                Some(key) => {
                    statements.remove(&key);
                    evicted += 1;
                }
                // Nothing left to evict: give up rather than spin.
                None => break,
            }
        }
        statements.insert(hash, PreparedStatement::new(sql.to_string()));
        // Single stats-lock acquisition (lock order: statements, then
        // stats — consistent with every other method).
        let mut stats = self.stats.write();
        stats.evictions += evicted;
        stats.cached_count = statements.len();
    }

    /// Records one execution's wall time against the cached statement
    /// (no-op for the per-statement average if the statement is absent).
    pub fn record_execution(&self, sql: &str, execution_time_us: u64) {
        if !self.is_enabled() {
            return;
        }
        let hash = Self::hash_sql(sql);
        {
            let mut statements = self.statements.write();
            if let Some(stmt) = statements.get_mut(&hash) {
                stmt.record_execution(execution_time_us);
            }
        }
        self.stats.write().total_executions += 1;
    }

    /// Removes the cached statement for `sql`; returns whether one existed.
    pub fn invalidate(&self, sql: &str) -> bool {
        let hash = Self::hash_sql(sql);
        let mut statements = self.statements.write();
        let removed = statements.remove(&hash).is_some();
        if removed {
            self.stats.write().cached_count = statements.len();
        }
        removed
    }

    /// Drops every cached statement.
    pub fn clear(&self) {
        let mut statements = self.statements.write();
        statements.clear();
        self.stats.write().cached_count = 0;
    }

    /// Returns a copy of the current statistics.
    pub fn stats(&self) -> PreparedStatementStats {
        self.stats.read().clone()
    }

    /// Zeroes the counters while keeping `cached_count` accurate.
    pub fn reset_stats(&self) {
        // Read the map size BEFORE taking the stats lock so the lock
        // order (statements, then stats) matches every other method; the
        // old order (stats held while acquiring statements) could
        // deadlock against `cache_statement`.
        let cached_count = self.statements.read().len();
        let mut stats = self.stats.write();
        *stats = PreparedStatementStats::default();
        stats.cached_count = cached_count;
    }

    /// Number of cached statements.
    pub fn len(&self) -> usize {
        self.statements.read().len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns a diagnostic snapshot of every cached statement.
    ///
    /// SQL longer than 100 bytes is truncated at the nearest UTF-8
    /// character boundary at or below byte 100 with `...` appended, so
    /// multi-byte text can never cause a slicing panic.
    pub fn cached_statements_info(&self) -> Vec<CachedStatementInfo> {
        let statements = self.statements.read();
        statements
            .iter()
            .map(|(hash, stmt)| {
                let sql_preview = if stmt.sql.len() > 100 {
                    // Back off to a char boundary: `&sql[..100]` panics
                    // if byte 100 falls inside a multi-byte sequence.
                    let mut end = 100;
                    while !stmt.sql.is_char_boundary(end) {
                        end -= 1;
                    }
                    format!("{}...", &stmt.sql[..end])
                } else {
                    stmt.sql.clone()
                };
                CachedStatementInfo {
                    hash: *hash,
                    sql_preview,
                    execution_count: stmt.execution_count,
                    avg_execution_time_us: stmt.avg_execution_time_us,
                    age_secs: stmt.prepared_at.elapsed().as_secs(),
                }
            })
            .collect()
    }
}
impl Default for PreparedStatementCache {
fn default() -> Self {
Self::new()
}
}
/// Read-only snapshot of one cached statement, for diagnostics/monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedStatementInfo {
/// Cache key: hash of the SQL text (see `PreparedStatementCache::hash_sql`).
pub hash: u64,
/// SQL text, possibly truncated for display.
pub sql_preview: String,
/// Number of recorded executions.
pub execution_count: u64,
/// Running average execution time in microseconds.
pub avg_execution_time_us: u64,
/// Seconds since the statement was prepared.
pub age_secs: u64,
}
/// Incrementally assembles a deterministic cache key from query parts.
///
/// Segments are joined with `:` in insertion order, so identical builder
/// chains always produce identical keys.
#[derive(Debug, Default)]
pub struct CacheKeyBuilder {
    parts: Vec<String>,
}

impl CacheKeyBuilder {
    /// Starts an empty builder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one segment and hands the builder back for chaining.
    fn push(mut self, segment: String) -> Self {
        self.parts.push(segment);
        self
    }

    /// Tags the key with the table being queried (`t:<table>`).
    pub fn table(self, table: &str) -> Self {
        self.push(format!("t:{}", table))
    }

    /// Records a `column=value` filter.
    pub fn condition(self, column: &str, value: impl std::fmt::Display) -> Self {
        self.push(format!("{}={}", column, value))
    }

    /// Records an ordering clause (`o:<column>:<direction>`).
    pub fn order(self, column: &str, direction: &str) -> Self {
        self.push(format!("o:{}:{}", column, direction))
    }

    /// Records a LIMIT (`l:<limit>`).
    pub fn limit(self, limit: u64) -> Self {
        self.push(format!("l:{}", limit))
    }

    /// Records an OFFSET (`off:<offset>`).
    pub fn offset(self, offset: u64) -> Self {
        self.push(format!("off:{}", offset))
    }

    /// Appends an arbitrary pre-formatted segment verbatim.
    pub fn raw(self, part: &str) -> Self {
        self.push(part.to_string())
    }

    /// Joins all segments with `:` into the final key.
    pub fn build(self) -> String {
        self.parts.join(":")
    }

    /// Builds the key and hashes it with the std `DefaultHasher`.
    pub fn build_hash(self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        self.build().hash(&mut hasher);
        hasher.finish()
    }
}
/// Per-query caching directives: optional key override, TTL, and tags.
#[derive(Debug, Clone)]
pub struct CacheOptions {
    /// Explicit cache key; `None` means the caller derives one.
    pub key: Option<String>,
    /// How long the cached result stays valid.
    pub ttl: Duration,
    /// Labels usable for group invalidation.
    pub tags: Vec<String>,
}

impl CacheOptions {
    /// Creates options with the given TTL, no key, and no tags.
    pub fn new(ttl: Duration) -> Self {
        CacheOptions {
            key: None,
            ttl,
            tags: Vec::new(),
        }
    }

    /// Overrides the cache key; chainable.
    pub fn with_key(mut self, key: &str) -> Self {
        self.key = Some(key.to_owned());
        self
    }

    /// Appends a single tag; chainable.
    pub fn with_tag(mut self, tag: &str) -> Self {
        self.tags.push(tag.to_owned());
        self
    }

    /// Appends several tags at once, preserving order; chainable.
    pub fn with_tags(mut self, tags: &[&str]) -> Self {
        for tag in tags {
            self.tags.push((*tag).to_owned());
        }
        self
    }
}
/// Collects queries intended to pre-populate a cache before serving traffic.
#[derive(Debug, Clone)]
pub struct CacheWarmer {
    queries: Vec<WarmQuery>,
}

/// One query scheduled for warm-up: target key, SQL text, and entry TTL.
#[derive(Debug, Clone)]
struct WarmQuery {
    key: String,
    sql: String,
    ttl: Duration,
}

impl CacheWarmer {
    /// Starts with an empty warm-up list.
    pub fn new() -> Self {
        CacheWarmer {
            queries: Vec::new(),
        }
    }

    /// Schedules one query for warm-up; chainable.
    pub fn add_query(mut self, key: &str, sql: &str, ttl: Duration) -> Self {
        let scheduled = WarmQuery {
            key: key.to_owned(),
            sql: sql.to_owned(),
            ttl,
        };
        self.queries.push(scheduled);
        self
    }

    /// Number of scheduled queries.
    pub fn query_count(&self) -> usize {
        self.queries.len()
    }

    /// Iterates `(key, sql, ttl)` triples in insertion order.
    pub fn queries(&self) -> impl Iterator<Item = (&str, &str, Duration)> {
        self.queries
            .iter()
            .map(|query| (query.key.as_str(), query.sql.as_str(), query.ttl))
    }
}

impl Default for CacheWarmer {
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests live in a separate file; `#[path]` points the module at it.
#[cfg(test)]
#[path = "testing/cache_tests.rs"]
mod tests;