use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use super::GLOBAL_STMT_CACHE;
/// Bookkeeping record for one cached SQL statement.
///
/// The record only stores the SQL text plus usage statistics; nothing in this
/// type talks to a database.
#[derive(Debug, Clone)]
struct PreparedStatement {
    /// SQL text this entry was created for.
    sql: String,
    /// Moment the entry was created.
    prepared_at: Instant,
    /// Moment of the most recently recorded execution (creation time until
    /// the first execution is recorded).
    last_used: Instant,
    /// Number of executions recorded through `record_execution`.
    execution_count: u64,
    /// Running integer mean of the recorded execution times, in microseconds.
    avg_execution_time_us: u64,
}

impl PreparedStatement {
    /// Creates a fresh entry for `sql` with zeroed statistics and both
    /// timestamps set to "now".
    fn new(sql: String) -> Self {
        let now = Instant::now();
        Self {
            sql,
            prepared_at: now,
            last_used: now,
            execution_count: 0,
            avg_execution_time_us: 0,
        }
    }

    /// Records one execution that took `execution_time_us` microseconds.
    ///
    /// Updates `last_used`, bumps `execution_count` and folds the new sample
    /// into the running integer average.
    fn record_execution(&mut self, execution_time_us: u64) {
        self.last_used = Instant::now();

        // `avg * count + sample` does not fit in u64 for large values (the
        // original u64 arithmetic panicked in debug builds and wrapped in
        // release builds), so the intermediate sum is computed in u128.
        let total = u128::from(self.avg_execution_time_us) * u128::from(self.execution_count)
            + u128::from(execution_time_us);

        self.execution_count = self.execution_count.saturating_add(1);

        // The quotient fits in u64 whenever the count did not saturate; the
        // clamp only matters in that (practically unreachable) saturated case.
        let mean = total / u128::from(self.execution_count);
        self.avg_execution_time_us = mean.min(u128::from(u64::MAX)) as u64;
    }
}
/// Point-in-time snapshot of the cache counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PreparedStatementStats {
/// Lookups that found a matching, non-expired entry.
pub hits: u64,
/// Lookups that had to store a new entry.
pub misses: u64,
/// Number of cached entries as of the last cache mutation.
pub cached_count: usize,
/// Executions recorded while the cache was enabled.
pub total_executions: u64,
/// Entries removed to make room for new ones.
pub evictions: u64,
}
impl PreparedStatementStats {
    /// Returns the fraction of lookups that were hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been counted at all.
    pub fn hit_ratio(&self) -> f64 {
        // The fields are public and the type is deserializable, so the two
        // counters can hold arbitrary values; widen before adding so the sum
        // cannot overflow u64.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
/// Tunable settings for the prepared-statement cache.
#[derive(Debug, Clone)]
pub struct PreparedStatementConfig {
    /// Whether lookups and execution recording are active.
    pub enabled: bool,
    /// Upper bound on the number of cached entries.
    pub max_statements: usize,
    /// Entries older than this are treated as expired on lookup.
    pub max_age: Duration,
}

impl Default for PreparedStatementConfig {
    /// Disabled cache, room for 500 statements, one-hour entry lifetime.
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        PreparedStatementConfig {
            max_age: one_hour,
            max_statements: 500,
            enabled: false,
        }
    }
}
/// Cache of SQL statement texts keyed by a 64-bit hash of the SQL, with
/// hit/miss/eviction counters kept in atomics.
#[derive(Debug)]
pub struct PreparedStatementCache {
/// Current settings; mutated through the `set_*`, `enable` and `disable` methods.
config: RwLock<PreparedStatementConfig>,
/// Copy of `config.enabled` kept in an atomic so `is_enabled` can read it
/// without taking the config lock.
enabled: AtomicBool,
/// Cached entries, keyed by `hash_sql` of their SQL text.
statements: RwLock<HashMap<u64, PreparedStatement>>,
/// Number of lookups served from the cache.
hits: AtomicU64,
/// Number of lookups that stored a new entry.
misses: AtomicU64,
/// Size of `statements` as of the last mutation that updated it.
cached_count: AtomicUsize,
/// Number of executions recorded while enabled.
total_executions: AtomicU64,
/// Number of entries removed to stay within `max_statements`.
evictions: AtomicU64,
}
impl PreparedStatementCache {
pub fn new() -> Self {
Self {
config: RwLock::new(PreparedStatementConfig::default()),
enabled: AtomicBool::new(false),
statements: RwLock::new(HashMap::new()),
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
cached_count: AtomicUsize::new(0),
total_executions: AtomicU64::new(0),
evictions: AtomicU64::new(0),
}
}
/// Creates an empty cache with the given settings and all counters at zero.
///
/// The atomic `enabled` flag starts out equal to `config.enabled`.
pub fn with_config(config: PreparedStatementConfig) -> Self {
    // Capture the flag before `config` is moved into its lock.
    let enabled = AtomicBool::new(config.enabled);
    Self {
        enabled,
        config: RwLock::new(config),
        statements: RwLock::new(HashMap::default()),
        hits: AtomicU64::default(),
        misses: AtomicU64::default(),
        cached_count: AtomicUsize::default(),
        total_executions: AtomicU64::default(),
        evictions: AtomicU64::default(),
    }
}
/// Reads every counter with relaxed ordering and bundles the values.
///
/// The counters are loaded one after another, so under concurrent updates the
/// snapshot is not guaranteed to be mutually consistent.
fn snapshot_stats(&self) -> PreparedStatementStats {
    let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
    PreparedStatementStats {
        hits: read(&self.hits),
        misses: read(&self.misses),
        cached_count: self.cached_count.load(Ordering::Relaxed),
        total_executions: read(&self.total_executions),
        evictions: read(&self.evictions),
    }
}
pub fn global() -> &'static PreparedStatementCache {
GLOBAL_STMT_CACHE.get_or_init(PreparedStatementCache::new)
}
pub fn init_global(config: PreparedStatementConfig) -> &'static PreparedStatementCache {
let _ = GLOBAL_STMT_CACHE.set(PreparedStatementCache::with_config(config));
PreparedStatementCache::global()
}
/// Turns the cache on: sets `config.enabled` and then the atomic flag.
///
/// Returns `self` so calls can be chained.
pub fn enable(&self) -> &Self {
    {
        let mut cfg = self.config.write();
        cfg.enabled = true;
    }
    self.enabled.store(true, Ordering::Release);
    self
}
/// Turns the cache off: clears `config.enabled` and then the atomic flag.
///
/// Cached entries are left in place. Returns `self` so calls can be chained.
pub fn disable(&self) -> &Self {
    {
        let mut cfg = self.config.write();
        cfg.enabled = false;
    }
    self.enabled.store(false, Ordering::Release);
    self
}
pub fn is_enabled(&self) -> bool {
self.enabled.load(Ordering::Acquire)
}
/// Sets the maximum number of cached statements.
///
/// Only the setting is changed here; existing entries are not touched by
/// this call. Returns `self` so calls can be chained.
pub fn set_max_statements(&self, max: usize) -> &Self {
    {
        let mut cfg = self.config.write();
        cfg.max_statements = max;
    }
    self
}
/// Sets how long an entry stays valid after it was created.
///
/// Only the setting is changed here; existing entries are not touched by
/// this call. Returns `self` so calls can be chained.
pub fn set_max_age(&self, age: Duration) -> &Self {
    {
        let mut cfg = self.config.write();
        cfg.max_age = age;
    }
    self
}
/// Returns a copy of the current settings.
///
/// The result is always `Some`; the `Option` is part of the signature only.
pub fn config(&self) -> Option<PreparedStatementConfig> {
    let guard = self.config.read();
    Some(guard.clone())
}
/// Hashes `sql` with the standard library's `DefaultHasher` and returns the
/// 64-bit digest used as the cache key.
///
/// Equal strings always produce equal keys; distinct strings can, in
/// principle, collide.
pub fn hash_sql(sql: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    sql.hash(&mut state);
    state.finish()
}
pub fn get_or_prepare(&self, sql: &str) -> (String, bool) {
if !self.is_enabled() {
return (sql.to_string(), false);
}
let hash = Self::hash_sql(sql);
let max_age = self.config.read().max_age;
{
let statements = self.statements.read();
if let Some(stmt) = statements.get(&hash) {
if stmt.prepared_at.elapsed() < max_age {
let sql = stmt.sql.clone();
drop(statements);
self.hits.fetch_add(1, Ordering::Relaxed);
return (sql, true);
}
}
}
{
let mut statements = self.statements.write();
if let Some(stmt) = statements.get(&hash) {
if stmt.prepared_at.elapsed() < max_age {
let sql = stmt.sql.clone();
drop(statements);
self.hits.fetch_add(1, Ordering::Relaxed);
return (sql, true);
}
statements.remove(&hash);
}
}
self.cache_statement(sql);
self.misses.fetch_add(1, Ordering::Relaxed);
(sql.to_string(), false)
}
/// Stores a fresh entry for `sql`, first evicting least-recently-used entries
/// until the map is below `max_statements`.
///
/// Each eviction bumps the `evictions` counter. With `max_statements == 0`
/// nothing is stored. `cached_count` is refreshed before returning.
fn cache_statement(&self, sql: &str) {
    let hash = Self::hash_sql(sql);
    let max_statements = self.config.read().max_statements;
    let mut statements = self.statements.write();

    while statements.len() >= max_statements {
        let lru_key = statements
            .iter()
            .min_by_key(|(_, stmt)| stmt.last_used)
            .map(|(key, _)| *key);
        match lru_key {
            Some(key) => {
                statements.remove(&key);
                self.evictions.fetch_add(1, Ordering::Relaxed);
            }
            // Nothing left to evict. This is only reachable when
            // `max_statements` is 0; without the break the loop would spin
            // forever while holding the write lock.
            None => break,
        }
    }

    // A capacity of zero means "cache nothing".
    if max_statements > 0 {
        statements.insert(hash, PreparedStatement::new(sql.to_string()));
    }
    self.cached_count.store(statements.len(), Ordering::Relaxed);
}
/// Records that `sql` was executed in `execution_time_us` microseconds.
///
/// Does nothing while the cache is disabled. Otherwise the matching entry
/// (if one is cached) gets its statistics updated, and `total_executions`
/// is incremented whether or not an entry was found.
pub fn record_execution(&self, sql: &str, execution_time_us: u64) {
    if self.is_enabled() {
        let key = Self::hash_sql(sql);

        let mut guard = self.statements.write();
        if let Some(entry) = guard.get_mut(&key) {
            entry.record_execution(execution_time_us);
        }
        // Release the map before touching the global counter.
        drop(guard);

        self.total_executions.fetch_add(1, Ordering::Relaxed);
    }
}
/// Removes the entry stored under the key of `sql`, if any.
///
/// Returns `true` when an entry was removed; `cached_count` is refreshed in
/// that case. The enabled flag is not consulted.
pub fn invalidate(&self, sql: &str) -> bool {
    let key = Self::hash_sql(sql);
    let mut guard = self.statements.write();
    match guard.remove(&key) {
        Some(_) => {
            self.cached_count.store(guard.len(), Ordering::Relaxed);
            true
        }
        None => false,
    }
}
/// Drops every cached entry and sets `cached_count` to zero.
///
/// The hit/miss/execution/eviction counters are left untouched.
pub fn clear(&self) {
    let mut guard = self.statements.write();
    guard.clear();
    // Updated while the write lock is still held.
    self.cached_count.store(0, Ordering::Relaxed);
}
pub fn stats(&self) -> PreparedStatementStats {
self.snapshot_stats()
}
/// Zeroes the hit, miss, execution and eviction counters and re-syncs
/// `cached_count` with the current number of cached entries.
///
/// Cached entries themselves are not modified.
pub fn reset_stats(&self) {
    let counters = [
        &self.hits,
        &self.misses,
        &self.total_executions,
        &self.evictions,
    ];
    for counter in counters.iter() {
        counter.store(0, Ordering::Relaxed);
    }

    // Keep the read lock while publishing the size it was read from.
    let guard = self.statements.read();
    self.cached_count.store(guard.len(), Ordering::Relaxed);
}
/// Returns the number of entries currently in the map (read under the
/// shared lock, not taken from `cached_count`).
pub fn len(&self) -> usize {
    let guard = self.statements.read();
    guard.len()
}
/// Returns `true` when no entries are cached.
pub fn is_empty(&self) -> bool {
    self.statements.read().is_empty()
}
/// Returns one summary record per cached entry, in map iteration order.
///
/// SQL longer than 100 bytes is shortened to a prefix of at most 100 bytes
/// followed by `...`; shorter SQL is copied unchanged.
pub fn cached_statements_info(&self) -> Vec<CachedStatementInfo> {
    const PREVIEW_MAX_BYTES: usize = 100;

    let statements = self.statements.read();
    statements
        .iter()
        .map(|(hash, stmt)| {
            let sql_preview = if stmt.sql.len() > PREVIEW_MAX_BYTES {
                // Slicing at a fixed byte offset panics when that offset falls
                // inside a multi-byte UTF-8 character, so back off to the
                // nearest char boundary (index 0 always is one).
                let mut end = PREVIEW_MAX_BYTES;
                while !stmt.sql.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &stmt.sql[..end])
            } else {
                stmt.sql.clone()
            };

            CachedStatementInfo {
                hash: *hash,
                sql_preview,
                execution_count: stmt.execution_count,
                avg_execution_time_us: stmt.avg_execution_time_us,
                age_secs: stmt.prepared_at.elapsed().as_secs(),
            }
        })
        .collect()
}
}
impl Default for PreparedStatementCache {
fn default() -> Self {
Self::new()
}
}
/// Serializable summary of one cached statement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedStatementInfo {
/// Cache key of the entry (hash of its SQL text).
pub hash: u64,
/// The SQL text; long statements are cut to at most 100 bytes followed by `...`.
pub sql_preview: String,
/// Number of executions recorded for the entry.
pub execution_count: u64,
/// Running integer mean of the recorded execution times, in microseconds.
pub avg_execution_time_us: u64,
/// Whole seconds elapsed since the entry was created.
pub age_secs: u64,
}