use crate::pool::ConnectionId;
#[path = "cache_types.rs"]
mod types;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tracing::{debug, trace, warn};
use types::{CachedStatementMetadata, ConnectionCache};
pub use types::{PreparedCacheConfig, PreparedCacheStats};
/// One hour, written as `60 * 60` seconds.
///
/// NOTE(review): no code in this file reads this constant; statement staleness is
/// decided by `needs_refresh(&self.config)` in `is_cached`. Presumably kept for the
/// tests or for future use. Confirm before relying on it.
const MAX_STATEMENT_AGE: Duration = Duration::from_secs(60 * 60);
/// Per-connection registry of prepared SQL statement metadata plus aggregate statistics.
///
/// Both pieces of mutable state sit behind `parking_lot::RwLock`s so the cache can be
/// used through `&self`. Methods that need both locks acquire `cache` before `stats`.
pub struct PreparedStatementCache {
/// Statement metadata for each connection, keyed by connection id.
cache: RwLock<HashMap<ConnectionId, ConnectionCache>>,
/// Limits such as `max_size` (statements per connection) and `max_connections`.
config: PreparedCacheConfig,
/// Hit/miss/eviction counters and size gauges, returned by `stats()`.
stats: RwLock<PreparedCacheStats>,
}
impl PreparedStatementCache {
    /// Creates an empty cache that keeps at most `max_size` statements per
    /// connection. Every other setting comes from `PreparedCacheConfig::default()`.
    pub fn new(max_size: usize) -> Self {
        Self::with_config(PreparedCacheConfig {
            max_size,
            ..Default::default()
        })
    }

    /// Creates an empty cache governed by `config`.
    pub fn with_config(config: PreparedCacheConfig) -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
            config,
            stats: RwLock::new(PreparedCacheStats::default()),
        }
    }

    /// Allocates a new connection id from a process-wide counter that starts at 1.
    ///
    /// Every call returns a different value. A caller that wants statements to be
    /// recognised as cached must call this once per connection and reuse the id.
    pub fn get_connection_id(&self) -> ConnectionId {
        use std::sync::atomic::{AtomicU64, Ordering};
        // Function-local static: shared by every cache instance in the process.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        // Relaxed is enough: only uniqueness of the returned value matters.
        NEXT_ID.fetch_add(1, Ordering::Relaxed)
    }

    /// Records a cache hit for `sql` on connection `conn_id`.
    ///
    /// Bumps the statement's use counter when it is cached, then records a hit in
    /// the stats. A connection with no cache entry is left alone. Creating one here
    /// would bypass the `max_connections` eviction performed by `record_miss`.
    pub fn record_hit(&self, conn_id: ConnectionId, sql: &str) {
        {
            let mut cache = self.cache.write();
            if let Some(conn_cache) = cache.get_mut(&conn_id) {
                if let Some(stmt) = conn_cache.get(sql) {
                    stmt.increment_use();
                }
            }
        } // cache lock released before touching stats
        self.stats.write().record_hit();
    }

    /// Records a cache miss by storing fresh metadata for `sql` on `conn_id`.
    ///
    /// Also updates the miss, prepare-time, size and connection-count statistics.
    ///
    /// - If `max_connections` connections are already tracked and `conn_id` is new,
    ///   the least recently accessed connection cache is evicted first.
    /// - If this connection's cache already holds `max_size` statements and `sql` is
    ///   not among them, its LRU statement is evicted before inserting.
    pub fn record_miss(&self, conn_id: ConnectionId, sql: &str, prepare_time_us: u64) {
        let mut cache = self.cache.write();
        if cache.len() >= self.config.max_connections && !cache.contains_key(&conn_id) {
            self.evict_lru_connection(&mut cache);
        }
        let conn_cache = cache.entry(conn_id).or_insert_with(|| {
            debug!("Creating new connection cache for {:?}", conn_id);
            ConnectionCache::new()
        });
        if conn_cache.len() >= self.config.max_size && !conn_cache.statements.contains_key(sql)
        {
            if let Some(evicted) = conn_cache.evict_lru() {
                debug!("Evicted cached statement: {}", evicted);
                self.stats.write().record_eviction();
            }
        }
        // An already-present `sql` (e.g. one `is_cached` reported as needing a
        // refresh) is inserted again with fresh metadata; presumably this replaces it.
        let metadata = CachedStatementMetadata::new(sql.to_string());
        conn_cache.insert(sql.to_string(), metadata);
        let total_size = cache.values().map(|c| c.len()).sum();
        let connection_count = cache.len();
        drop(cache);
        let mut stats = self.stats.write();
        stats.record_miss();
        stats.record_prepared(prepare_time_us);
        stats.update_size(total_size);
        stats.update_active_connections(connection_count);
        trace!("Recorded cache miss for SQL on {:?}: {}", conn_id, sql);
    }

    /// Returns `true` when `sql` is cached for `conn_id` and `needs_refresh` does not
    /// flag it under the current config. Takes the cache write lock.
    pub fn is_cached(&self, conn_id: ConnectionId, sql: &str) -> bool {
        let mut cache = self.cache.write();
        if let Some(conn_cache) = cache.get_mut(&conn_id) {
            if let Some(stmt) = conn_cache.get(sql) {
                return !stmt.needs_refresh(&self.config);
            }
        }
        false
    }

    /// Prepares `sql` on `conn` and records the lookup as a hit or a miss.
    ///
    /// The cache stores only statement metadata, not the `libsql::Statement`, so the
    /// statement is always prepared on `conn`. The cache only decides whether the call
    /// counts as a hit or as a miss (never both). The prepare time is recorded only
    /// for misses.
    ///
    /// # Errors
    /// Returns any error from `libsql::Connection::prepare`. In that case neither a
    /// hit nor a miss is recorded.
    pub async fn get_or_prepare(
        &self,
        conn: &libsql::Connection,
        sql: &str,
    ) -> Result<libsql::Statement, libsql::Error> {
        // Key the cache on the connection handle's address so repeated calls with the
        // same connection share an entry. `get_connection_id()` hands out a fresh id on
        // every call, which would make every lookup a miss and churn connection caches.
        // NOTE(review): an address can be reused after a connection is dropped, and
        // cloned handles get distinct keys. Only hit/miss accounting is affected,
        // because the statement is always re-prepared.
        let conn_id = conn as *const libsql::Connection as usize as ConnectionId;
        let cached = self.is_cached(conn_id, sql);

        let start = Instant::now();
        let stmt = conn.prepare(sql).await?;
        if cached {
            self.record_hit(conn_id, sql);
        } else {
            // Saturate instead of silently truncating the u128 microsecond count.
            let prepare_time_us = start.elapsed().as_micros().min(u128::from(u64::MAX)) as u64;
            self.record_miss(conn_id, sql, prepare_time_us);
        }
        Ok(stmt)
    }

    /// Removes the connection cache with the oldest `last_accessed` time, if any.
    ///
    /// Operates on the map the caller has already write-locked, and records a
    /// connection eviction in the stats when something was removed.
    fn evict_lru_connection(&self, cache: &mut HashMap<ConnectionId, ConnectionCache>) {
        // `min_by_key` always yields an entry for a non-empty map. The previous approach
        // compared against `Instant::now()` with `<`, which evicted nothing when every
        // timestamp equalled "now".
        let oldest = cache
            .iter()
            .min_by_key(|(_, conn_cache)| conn_cache.last_accessed)
            .map(|(id, _)| *id);
        if let Some(id) = oldest {
            if cache.remove(&id).is_some() {
                warn!(
                    "Evicted connection cache for {:?} (max connections exceeded)",
                    id
                );
                self.stats.write().record_connection_eviction();
            }
        }
    }

    /// Drops all cached statements for `conn_id` and returns how many there were.
    ///
    /// Refreshes the size and active-connection gauges in the stats.
    pub fn clear_connection(&self, conn_id: ConnectionId) -> usize {
        let mut cache = self.cache.write();
        let cleared = if let Some(conn_cache) = cache.remove(&conn_id) {
            let count = conn_cache.len();
            debug!(
                "Cleared {} cached statements for connection {:?}",
                count, conn_id
            );
            count
        } else {
            0
        };
        let total_size = cache.values().map(|c| c.len()).sum();
        let active_connections = cache.len();
        drop(cache);
        let mut stats = self.stats.write();
        stats.update_size(total_size);
        stats.update_active_connections(active_connections);
        cleared
    }

    /// Drops every connection cache and resets the size and connection gauges to 0.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        let total_statements: usize = cache.values().map(|c| c.len()).sum();
        // Capture before clearing: afterwards `cache.len()` is always 0.
        let total_connections = cache.len();
        cache.clear();
        let mut stats = self.stats.write();
        stats.update_size(0);
        stats.update_active_connections(0);
        debug!(
            "Cleared {} cached statements from {} connections",
            total_statements, total_connections
        );
    }

    /// Returns a snapshot (clone) of the current statistics.
    pub fn stats(&self) -> PreparedCacheStats {
        self.stats.read().clone()
    }

    /// Total number of cached statements across all connections.
    pub fn total_size(&self) -> usize {
        self.cache.read().values().map(|c| c.len()).sum()
    }

    /// Number of statements cached for `conn_id`, or 0 if it has no cache.
    pub fn connection_size(&self, conn_id: ConnectionId) -> usize {
        self.cache.read().get(&conn_id).map_or(0, |c| c.len())
    }

    /// Returns `true` when no statement is cached for any connection.
    pub fn is_empty(&self) -> bool {
        self.total_size() == 0
    }

    /// Number of connections that currently have a cache entry.
    pub fn connection_count(&self) -> usize {
        self.cache.read().len()
    }

    /// Removes `sql` from `conn_id`'s cache. Returns whether anything was removed.
    ///
    /// The size gauges are refreshed only when a removal happened.
    pub fn remove(&self, conn_id: ConnectionId, sql: &str) -> bool {
        let mut cache = self.cache.write();
        let removed = match cache.get_mut(&conn_id) {
            Some(conn_cache) => conn_cache.remove(sql),
            None => false,
        };
        if removed {
            let total_size = cache.values().map(|c| c.len()).sum();
            let active_connections = cache.len();
            drop(cache);
            let mut stats = self.stats.write();
            stats.update_size(total_size);
            stats.update_active_connections(active_connections);
        }
        removed
    }

    /// Drops every connection cache whose `idle_time()` exceeds `max_idle_duration`
    /// and returns how many were dropped.
    ///
    /// Dropped caches are added to the stats' `connection_evictions` counter, and the
    /// size gauges are refreshed when anything was dropped.
    pub fn cleanup_idle_connections(&self, max_idle_duration: Duration) -> usize {
        let mut cache = self.cache.write();
        let before = cache.len();
        cache.retain(|id, conn_cache| {
            let idle = conn_cache.idle_time() > max_idle_duration;
            if idle {
                debug!("Cleaned up idle connection cache for {:?}", id);
            }
            !idle
        });
        let count = before - cache.len();
        if count > 0 {
            let total_size = cache.values().map(|c| c.len()).sum();
            let active_connections = cache.len();
            drop(cache);
            let mut stats = self.stats.write();
            stats.update_size(total_size);
            stats.update_active_connections(active_connections);
            stats.connection_evictions += count as u64;
        }
        count
    }
}
impl Default for PreparedStatementCache {
fn default() -> Self {
Self::new(100)
}
}
#[cfg(test)]
#[path = "cache_tests.rs"]
mod tests;