use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use lru::LruCache;
use std::num::NonZeroUsize;
use crate::statement::Statement;
/// Number of entries the ad-hoc (LRU) tier holds when the cache is built with
/// `StatementCache::new`.
const DEFAULT_LRU_CAPACITY: usize = 256;
/// Two-tier cache of `Statement`s.
///
/// Tier 1 (`registered`) is an unbounded map of statements registered explicitly by
/// name. Tier 2 (`adhoc`) is a bounded LRU cache keyed by SQL text, whose entries are
/// given generated names.
pub struct StatementCache {
/// Tier 1: statements added via `register`, keyed by the caller-supplied name.
registered: HashMap<String, CachedStatement>,
/// Tier 2: statements added via `insert_adhoc`, keyed by SQL text; least recently
/// used entries are dropped when capacity is reached.
adhoc: LruCache<String, CachedStatement>,
/// Counter feeding `generate_name`; atomic so that method can take `&self`.
name_counter: AtomicU64,
/// Hit / miss / eviction counters exposed through `metrics`.
metrics: CacheMetrics,
}
/// A statement stored in a `StatementCache`, together with the name it is stored under.
///
/// `#[non_exhaustive]` prevents code outside this crate from building one with a
/// struct literal.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CachedStatement {
/// For registered entries, the name passed to `StatementCache::register`; for ad-hoc
/// entries, a name from `StatementCache::generate_name` (`_sentinel_s{n}`).
pub name: String,
/// The cached statement.
pub statement: Statement,
}
/// Usage counters for a `StatementCache`. All counters start at zero (`Default`).
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct CacheMetrics {
/// Successful lookups in the registered tier (`get_registered`).
pub tier1_hits: u64,
/// Successful lookups in the ad-hoc tier (`get_adhoc`, `lookup_or_miss`).
pub tier2_hits: u64,
/// Misses recorded via `record_miss` or `lookup_or_miss`.
pub misses: u64,
/// Ad-hoc entries reported as evicted by `insert_adhoc`.
pub evictions: u64,
}
impl CacheMetrics {
    /// Returns the combined hit count of both tiers (`tier1_hits + tier2_hits`).
    pub fn total_hits(&self) -> u64 {
        let Self {
            tier1_hits,
            tier2_hits,
            ..
        } = self;
        tier1_hits + tier2_hits
    }

    /// Returns `hits / (hits + misses)` as a fraction in `[0, 1]`.
    ///
    /// Yields `0.0` when no hit or miss has been recorded yet, instead of dividing by zero.
    pub fn hit_rate(&self) -> f64 {
        let hits = self.total_hits();
        match hits + self.misses {
            0 => 0.0,
            lookups => hits as f64 / lookups as f64,
        }
    }
}
impl StatementCache {
    /// Creates an empty cache whose ad-hoc tier holds up to `DEFAULT_LRU_CAPACITY` entries.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_LRU_CAPACITY)
    }

    /// Creates an empty cache whose ad-hoc (LRU) tier holds up to `lru_capacity` entries.
    ///
    /// A capacity of `0` is clamped to `1`, because `LruCache` needs a non-zero capacity.
    pub fn with_capacity(lru_capacity: usize) -> Self {
        Self {
            registered: HashMap::new(),
            adhoc: LruCache::new(NonZeroUsize::new(lru_capacity).unwrap_or(NonZeroUsize::MIN)),
            name_counter: AtomicU64::new(0),
            metrics: CacheMetrics::default(),
        }
    }

    /// Stores `statement` in the registered tier under `name`.
    ///
    /// Any existing registration with the same name is overwritten. The entry's
    /// `CachedStatement::name` is `name` itself; no name is generated.
    pub fn register(&mut self, name: &str, statement: Statement) {
        self.registered.insert(
            name.to_string(),
            CachedStatement {
                name: name.to_string(),
                statement,
            },
        );
    }

    /// Looks up a registered statement by name and counts a tier-1 hit if it exists.
    ///
    /// Misses are not counted here; callers report them with [`Self::record_miss`].
    pub fn get_registered(&mut self, name: &str) -> Option<&CachedStatement> {
        // `registered` and `metrics` are disjoint fields, so the lookup result can be
        // held while the counter is updated.
        let result = self.registered.get(name);
        if result.is_some() {
            self.metrics.tier1_hits += 1;
        }
        result
    }

    /// Looks up an ad-hoc statement by SQL text and counts a tier-2 hit if it exists.
    ///
    /// A hit promotes the entry to most-recently-used. Misses are not counted here.
    pub fn get_adhoc(&mut self, sql: &str) -> Option<&CachedStatement> {
        let result = self.adhoc.get(sql);
        if result.is_some() {
            self.metrics.tier2_hits += 1;
        }
        result
    }

    /// Caches `statement` in the ad-hoc tier under `sql` with a freshly generated name.
    ///
    /// Returns the name of the statement that this insert pushed out of the cache, so
    /// the caller can release it:
    /// - if `sql` was already cached, the name of the entry it replaced;
    /// - otherwise, if the tier was full, the name of the least-recently-used entry
    ///   that was evicted (this case also increments `CacheMetrics::evictions`);
    /// - otherwise `None`.
    pub fn insert_adhoc(&mut self, sql: String, statement: Statement) -> Option<String> {
        let name = self.generate_name();
        // `push` hands back the displaced entry both when it overwrites an existing key
        // and when it evicts the LRU entry for capacity. Check which case applies up
        // front; `contains` does not change recency order.
        let replacing = self.adhoc.contains(sql.as_str());
        let (_, displaced) = self.adhoc.push(sql, CachedStatement { name, statement })?;
        if !replacing {
            self.metrics.evictions += 1;
        }
        Some(displaced.name)
    }

    /// Records one cache miss.
    pub fn record_miss(&mut self) {
        self.metrics.misses += 1;
    }

    /// Looks up an ad-hoc statement by SQL text, counting either a tier-2 hit or a miss.
    ///
    /// Only the ad-hoc tier is consulted. Unlike [`Self::get_adhoc`], a miss is
    /// recorded here.
    pub fn lookup_or_miss(&mut self, sql: &str) -> Option<&CachedStatement> {
        // A single probe suffices: `adhoc` and `metrics` are borrowed as separate fields.
        let result = self.adhoc.get(sql);
        match result {
            Some(_) => self.metrics.tier2_hits += 1,
            None => self.metrics.misses += 1,
        }
        result
    }

    /// Returns the counters accumulated so far.
    pub fn metrics(&self) -> &CacheMetrics {
        &self.metrics
    }

    /// Returns the number of statements in the registered tier.
    pub fn registered_count(&self) -> usize {
        self.registered.len()
    }

    /// Returns the number of statements currently held in the ad-hoc tier.
    pub fn adhoc_count(&self) -> usize {
        self.adhoc.len()
    }

    /// Returns a new name of the form `_sentinel_s{n}`.
    ///
    /// `n` comes from this cache's own counter. The counter starts at 0 and advances on
    /// every call, including the calls made by `insert_adhoc`, so names from one cache
    /// instance do not repeat.
    pub fn generate_name(&self) -> String {
        // Relaxed ordering is enough: only the uniqueness of each fetched value matters,
        // and `fetch_add` is atomic under every ordering.
        let id = self.name_counter.fetch_add(1, Ordering::Relaxed);
        format!("_sentinel_s{id}")
    }
}
impl Default for StatementCache {
fn default() -> Self {
Self::new()
}
}