use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::Instant;
#[derive(Debug, Clone)]
pub struct CachedStatement {
pub sql: String,
pub last_used: Instant,
pub hit_count: u64,
}
#[derive(Debug)]
pub struct StatementCache {
cache: HashMap<u64, CachedStatement>,
max_size: usize,
}
impl StatementCache {
pub fn new(max_size: usize) -> Self {
Self {
cache: HashMap::with_capacity(max_size.min(256)),
max_size,
}
}
pub fn get_or_insert(&mut self, key: u64, builder: impl FnOnce() -> String) -> &str {
if !self.cache.contains_key(&key) && self.cache.len() >= self.max_size {
self.evict_lru();
}
let entry = self.cache.entry(key).or_insert_with(|| CachedStatement {
sql: builder(),
last_used: Instant::now(),
hit_count: 0,
});
entry.last_used = Instant::now();
entry.hit_count += 1;
&entry.sql
}
pub fn contains(&self, key: u64) -> bool {
self.cache.contains_key(&key)
}
pub fn len(&self) -> usize {
self.cache.len()
}
pub fn is_empty(&self) -> bool {
self.cache.is_empty()
}
pub fn clear(&mut self) {
self.cache.clear();
}
fn evict_lru(&mut self) {
if let Some((&lru_key, _)) = self.cache.iter().min_by_key(|(_, entry)| entry.last_used) {
self.cache.remove(&lru_key);
}
}
}
pub fn cache_key(value: &impl Hash) -> u64 {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
value.hash(&mut hasher);
hasher.finish()
}
impl Default for StatementCache {
fn default() -> Self {
Self::new(1024)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cache_hit() {
let mut cache = StatementCache::new(10);
let sql = cache
.get_or_insert(1, || "SELECT 1".to_string())
.to_string();
assert_eq!(sql, "SELECT 1");
let called = std::cell::Cell::new(false);
let sql2 = cache
.get_or_insert(1, || {
called.set(true);
"SELECT 1".to_string()
})
.to_string();
assert_eq!(sql2, "SELECT 1");
assert!(!called.get());
}
#[test]
fn test_cache_miss() {
let mut cache = StatementCache::new(10);
let sql1 = cache
.get_or_insert(1, || "SELECT 1".to_string())
.to_string();
let sql2 = cache
.get_or_insert(2, || "SELECT 2".to_string())
.to_string();
assert_eq!(sql1, "SELECT 1");
assert_eq!(sql2, "SELECT 2");
assert_eq!(cache.len(), 2);
}
#[test]
fn test_eviction() {
let mut cache = StatementCache::new(2);
cache.get_or_insert(1, || "SELECT 1".to_string());
cache.get_or_insert(2, || "SELECT 2".to_string());
cache.get_or_insert(3, || "SELECT 3".to_string());
assert_eq!(cache.len(), 2);
assert!(!cache.contains(1));
assert!(cache.contains(2));
assert!(cache.contains(3));
}
#[test]
fn test_lru_ordering() {
let mut cache = StatementCache::new(2);
cache.get_or_insert(1, || "SELECT 1".to_string());
cache.get_or_insert(2, || "SELECT 2".to_string());
let called = std::cell::Cell::new(false);
cache.get_or_insert(1, || {
called.set(true);
"SELECT 1".to_string()
});
assert!(!called.get());
cache.get_or_insert(3, || "SELECT 3".to_string());
assert!(cache.contains(1));
assert!(!cache.contains(2));
assert!(cache.contains(3));
}
#[test]
fn test_cache_key_function() {
let key1 = cache_key(&"SELECT * FROM users");
let key2 = cache_key(&"SELECT * FROM users");
let key3 = cache_key(&"SELECT * FROM orders");
assert_eq!(key1, key2);
assert_ne!(key1, key3);
}
#[test]
fn test_clear() {
let mut cache = StatementCache::new(10);
cache.get_or_insert(1, || "SELECT 1".to_string());
cache.get_or_insert(2, || "SELECT 2".to_string());
assert_eq!(cache.len(), 2);
cache.clear();
assert!(cache.is_empty());
}
}