#![allow(clippy::expect_used)]
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::time::Instant;
use lru::LruCache;
/// Capacity used by [`StatementCache::with_default_size`] and by
/// [`StatementCacheConfig::default`].
pub const DEFAULT_MAX_STATEMENTS: usize = 256;
/// A prepared-statement handle paired with the SQL text it was prepared from.
///
/// The [`hash_sql`] digest of the text and the construction time are captured
/// once, in [`PreparedStatement::new`].
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// Handle value supplied to [`PreparedStatement::new`].
    handle: i32,
    /// `hash_sql(&sql)`, stored so the cache can key this statement without rehashing.
    sql_hash: u64,
    /// The SQL text this statement was prepared from.
    sql: String,
    /// Moment this value was constructed.
    created_at: Instant,
}

impl PreparedStatement {
    /// Pairs `handle` with `sql`, hashing the text and stamping the current time.
    #[must_use]
    pub fn new(handle: i32, sql: String) -> Self {
        Self {
            handle,
            // Hash before `sql` is moved into the struct below.
            sql_hash: hash_sql(&sql),
            sql,
            created_at: Instant::now(),
        }
    }

    /// The handle passed to [`PreparedStatement::new`].
    #[must_use]
    pub fn handle(&self) -> i32 {
        self.handle
    }

    /// The [`hash_sql`] value of this statement's SQL text.
    #[must_use]
    pub fn sql_hash(&self) -> u64 {
        self.sql_hash
    }

    /// The SQL text this statement was prepared from.
    #[must_use]
    pub fn sql(&self) -> &str {
        &self.sql
    }

    /// The instant at which this value was constructed.
    #[must_use]
    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    /// Time elapsed since this value was constructed.
    #[must_use]
    pub fn age(&self) -> std::time::Duration {
        self.created_at.elapsed()
    }
}
/// Bounded LRU cache of [`PreparedStatement`]s keyed by the [`hash_sql`] value
/// of their SQL text, with hit/miss counters for diagnostics.
pub struct StatementCache {
    /// Successful lookups since construction or the last `reset_stats`.
    hits: u64,
    /// Failed lookups since construction or the last `reset_stats`.
    misses: u64,
    /// Capacity the cache was constructed with.
    max_size: usize,
    /// Cached statements, keyed by the hash of their SQL text.
    cache: LruCache<u64, PreparedStatement>,
}
impl StatementCache {
    /// Creates an empty cache holding at most `max_size` statements.
    ///
    /// # Panics
    ///
    /// Panics if `max_size` is 0.
    #[must_use]
    pub fn new(max_size: usize) -> Self {
        assert!(max_size > 0, "max_size must be greater than 0");
        Self {
            cache: LruCache::new(NonZeroUsize::new(max_size).expect("max_size > 0")),
            max_size,
            hits: 0,
            misses: 0,
        }
    }

    /// Creates an empty cache with capacity [`DEFAULT_MAX_STATEMENTS`].
    #[must_use]
    pub fn with_default_size() -> Self {
        Self::new(DEFAULT_MAX_STATEMENTS)
    }

    /// Returns the cached handle for `sql`, marking the entry most recently used.
    ///
    /// Every call counts as either a hit or a miss. Entries are keyed by the
    /// 64-bit [`hash_sql`] value, so the stored SQL text is compared too: a
    /// hash collision is reported as a miss instead of returning the handle of
    /// a different statement.
    pub fn get(&mut self, sql: &str) -> Option<i32> {
        let hash = hash_sql(sql);
        let handle = self
            .cache
            .get(&hash)
            .filter(|stmt| stmt.sql == sql)
            .map(|stmt| stmt.handle);
        if let Some(handle) = handle {
            self.hits += 1;
            tracing::trace!(sql = sql, handle = handle, "statement cache hit");
        } else {
            self.misses += 1;
            tracing::trace!(sql = sql, "statement cache miss");
        }
        handle
    }

    /// Returns the cached entry for `sql` without touching LRU order or stats.
    ///
    /// An entry whose hash matches but whose SQL text differs is not returned.
    pub fn peek(&self, sql: &str) -> Option<&PreparedStatement> {
        let hash = hash_sql(sql);
        self.cache.peek(&hash).filter(|stmt| stmt.sql == sql)
    }

    /// Caches `stmt` as the most recently used entry.
    ///
    /// Returns the statement that `stmt` displaced, if any. That is either the
    /// previous entry stored under the same key, or the least recently used
    /// entry evicted to stay within `max_size`. The displaced entry is not
    /// returned when it carries the same handle as `stmt`, because that handle
    /// is still in the cache.
    pub fn insert(&mut self, stmt: PreparedStatement) -> Option<PreparedStatement> {
        tracing::debug!(
            sql = stmt.sql(),
            handle = stmt.handle,
            "caching prepared statement"
        );
        let handle = stmt.handle;
        // `push` only evicts when the key is new and the cache is full. On an
        // existing key it replaces the value and hands back the old one, so
        // nothing is silently dropped and no unrelated entry is evicted.
        self.cache
            .push(stmt.sql_hash, stmt)
            .map(|(_, displaced)| displaced)
            .filter(|displaced| displaced.handle != handle)
    }

    /// Removes and returns the entry for `sql`, if present.
    ///
    /// An entry whose hash matches but whose SQL text differs is left in place.
    pub fn remove(&mut self, sql: &str) -> Option<PreparedStatement> {
        let hash = hash_sql(sql);
        if matches!(self.cache.peek(&hash), Some(stmt) if stmt.sql == sql) {
            self.cache.pop(&hash)
        } else {
            None
        }
    }

    /// Empties the cache and yields every removed statement.
    ///
    /// Removal happens eagerly, before the iterator is returned. Statements are
    /// yielded least-recently-used first. Hit/miss counters are not reset.
    pub fn clear(&mut self) -> impl Iterator<Item = PreparedStatement> + '_ {
        let mut statements = Vec::with_capacity(self.cache.len());
        while let Some((_, stmt)) = self.cache.pop_lru() {
            statements.push(stmt);
        }
        tracing::debug!(count = statements.len(), "cleared statement cache");
        statements.into_iter()
    }

    /// Number of statements currently cached.
    #[must_use]
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Whether the cache holds no statements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Capacity the cache was created with.
    #[must_use]
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Number of lookups via [`StatementCache::get`] that found an entry.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of lookups via [`StatementCache::get`] that found no entry.
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Fraction of lookups that were hits, or `0.0` before any lookup.
    #[must_use]
    pub fn hit_ratio(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            // Precision loss above 2^53 lookups is irrelevant for a ratio.
            self.hits as f64 / total as f64
        }
    }

    /// Zeroes the hit and miss counters; cached entries are untouched.
    pub fn reset_stats(&mut self) {
        self.hits = 0;
        self.misses = 0;
    }
}
impl Default for StatementCache {
fn default() -> Self {
Self::with_default_size()
}
}
impl std::fmt::Debug for StatementCache {
    /// Summarises size, capacity and hit/miss counters without printing the
    /// cached statements themselves.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            cache,
            max_size,
            hits,
            misses,
        } = self;
        let mut summary = formatter.debug_struct("StatementCache");
        summary.field("len", &cache.len());
        summary.field("max_size", max_size);
        summary.field("hits", hits);
        summary.field("misses", misses);
        summary.finish()
    }
}
/// Hashes SQL text into the 64-bit key used by [`StatementCache`].
///
/// Uses a freshly constructed [`DefaultHasher`], so equal strings produce
/// equal keys within a process. The standard library does not promise the
/// algorithm is stable across Rust releases, so the value should not be
/// persisted.
#[must_use]
pub fn hash_sql(sql: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(sql, &mut state);
    state.finish()
}
/// Settings describing whether a statement cache is enabled and how large it is.
#[derive(Debug, Clone)]
pub struct StatementCacheConfig {
    /// Whether statement caching is enabled.
    pub enabled: bool,
    /// Intended cache capacity. Note that `StatementCache::new` panics on 0.
    pub max_size: usize,
}
impl Default for StatementCacheConfig {
    /// Caching enabled with capacity [`DEFAULT_MAX_STATEMENTS`].
    fn default() -> Self {
        let max_size = DEFAULT_MAX_STATEMENTS;
        Self {
            max_size,
            enabled: true,
        }
    }
}
impl StatementCacheConfig {
    /// Configuration with caching turned off.
    ///
    /// `max_size` is 0 here; passing it to `StatementCache::new` would panic.
    #[must_use]
    pub fn disabled() -> Self {
        Self {
            max_size: 0,
            enabled: false,
        }
    }

    /// Configuration with caching turned on and capacity `max_size`.
    ///
    /// The value is not validated here; `StatementCache::new` panics if it is 0.
    #[must_use]
    pub fn with_max_size(max_size: usize) -> Self {
        let enabled = true;
        Self { enabled, max_size }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    // Unit tests for the statement cache, the SQL hashing helper and the config.
    use super::*;

    // Builds a statement from a borrowed SQL string.
    fn statement(handle: i32, sql: &str) -> PreparedStatement {
        PreparedStatement::new(handle, sql.to_string())
    }

    #[test]
    fn test_statement_cache_new() {
        let fresh = StatementCache::new(10);
        assert_eq!(fresh.max_size(), 10);
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    #[test]
    fn test_statement_cache_insert_and_get() {
        let mut cache = StatementCache::new(10);
        cache.insert(statement(1, "SELECT * FROM users"));
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("SELECT * FROM users"), Some(1));
        assert_eq!((cache.hits(), cache.misses()), (1, 0));
    }

    #[test]
    fn test_statement_cache_miss() {
        let mut cache = StatementCache::new(10);
        assert_eq!(cache.get("SELECT 1"), None);
        assert_eq!((cache.hits(), cache.misses()), (0, 1));
    }

    #[test]
    fn test_statement_cache_lru_eviction() {
        let mut cache = StatementCache::new(2);
        cache.insert(statement(1, "SELECT 1"));
        cache.insert(statement(2, "SELECT 2"));
        assert_eq!(cache.len(), 2);

        // Touching "SELECT 1" leaves "SELECT 2" as the least recently used entry.
        cache.get("SELECT 1");
        let evicted = cache
            .insert(statement(3, "SELECT 3"))
            .expect("inserting into a full cache must evict an entry");
        assert_eq!(evicted.handle(), 2);

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get("SELECT 1"), Some(1));
        assert_eq!(cache.get("SELECT 2"), None);
        assert_eq!(cache.get("SELECT 3"), Some(3));
    }

    #[test]
    fn test_statement_cache_clear() {
        let mut cache = StatementCache::new(10);
        cache.insert(statement(1, "SELECT 1"));
        cache.insert(statement(2, "SELECT 2"));
        assert_eq!(cache.clear().count(), 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_remove() {
        let mut cache = StatementCache::new(10);
        cache.insert(statement(1, "SELECT 1"));
        assert_eq!(cache.len(), 1);
        let removed = cache
            .remove("SELECT 1")
            .expect("the entry was just inserted");
        assert_eq!(removed.handle(), 1);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_hit_ratio() {
        let mut cache = StatementCache::new(10);
        cache.insert(statement(1, "SELECT 1"));
        for sql in &["SELECT 1", "SELECT 1", "SELECT 2"] {
            cache.get(sql);
        }
        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        let expected = 2.0 / 3.0;
        assert!((cache.hit_ratio() - expected).abs() < 0.001);
    }

    #[test]
    fn test_hash_sql_consistency() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        assert_eq!(hash_sql(sql), hash_sql(sql));
    }

    #[test]
    fn test_hash_sql_different() {
        assert_ne!(hash_sql("SELECT 1"), hash_sql("SELECT 2"));
    }

    #[test]
    fn test_prepared_statement_age() {
        let pause = std::time::Duration::from_millis(10);
        let prepared = statement(1, "SELECT 1");
        std::thread::sleep(pause);
        assert!(prepared.age() >= pause);
    }

    #[test]
    fn test_statement_cache_config_default() {
        let StatementCacheConfig { enabled, max_size } = StatementCacheConfig::default();
        assert!(enabled);
        assert_eq!(max_size, DEFAULT_MAX_STATEMENTS);
    }

    #[test]
    fn test_statement_cache_config_disabled() {
        assert!(!StatementCacheConfig::disabled().enabled);
    }
}