use parking_lot::RwLock;
use rustc_hash::FxHashMap;
use std::collections::VecDeque;
use std::hash::{BuildHasher, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use super::ast::Query;
use super::error::ParseError;
use super::Parser;
/// Point-in-time copy of the cache counters.
#[derive(Debug, Clone, Copy, Default)]
pub struct CacheStats {
    /// Number of lookups that were answered from the cache.
    pub hits: u64,
    /// Number of lookups that had to be parsed and were then cached.
    pub misses: u64,
    /// Number of entries removed to make room for newer ones.
    pub evictions: u64,
}

impl CacheStats {
    /// Returns the share of recorded lookups that were hits, as a percentage
    /// in the range `0.0..=100.0`.
    ///
    /// Returns `0.0` when neither a hit nor a miss has been recorded.
    #[must_use]
    // A percentage does not need integer-exact counters.
    #[allow(clippy::cast_precision_loss)]
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding: `hits + misses` in `u64` overflows near
        // `u64::MAX`, which panics in debug builds and wraps (yielding a
        // nonsensical rate) in release builds.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            return 0.0;
        }
        (self.hits as f64 / total as f64) * 100.0
    }
}
/// Cache of parsed queries with least-recently-used eviction.
///
/// All state sits behind `RwLock`s or atomic counters, so every method takes
/// `&self`.
pub struct QueryCache {
/// Cached entries grouped by the hash of their canonical query text; entries
/// whose hashes collide share one bucket.
cache: RwLock<FxHashMap<u64, Vec<CacheEntry>>>,
/// Recency queue: the front is the next entry to evict, the back is the most
/// recently inserted or hit entry.
order: RwLock<VecDeque<CacheKey>>,
/// Maximum number of cached entries; the constructor raises it to at least 1.
max_size: usize,
/// Function used to hash the canonical query text into a bucket key.
hash_fn: fn(&str) -> u64,
/// Hit, miss and eviction counters.
stats: AtomicCacheStats,
}
/// Identifies one cached entry inside the recency queue.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
/// Bucket key of the entry (hash of the canonical query text).
hash: u64,
/// Exact query text as passed by the caller; tells apart entries that share
/// a bucket.
original_query: String,
}
/// One cached parse result together with the text it was produced from.
#[derive(Debug, Clone)]
struct CacheEntry {
/// Exact query text as passed by the caller.
original_query: String,
/// Whitespace-normalised form of the query; this is the text that gets hashed.
canonical_query: String,
/// Parse result returned (as a clone) on a cache hit.
parsed: Query,
}
/// Atomic counters that can be bumped through a shared reference and copied
/// out as a plain [`CacheStats`].
#[derive(Debug, Default)]
struct AtomicCacheStats {
    hits: AtomicU64,
    misses: AtomicU64,
    evictions: AtomicU64,
}

impl AtomicCacheStats {
    /// Reads every counter (relaxed ordering) into a [`CacheStats`] value.
    ///
    /// The three loads are independent, so the result is not a single atomic
    /// snapshot across counters.
    fn snapshot(&self) -> CacheStats {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        let hits = read(&self.hits);
        let misses = read(&self.misses);
        let evictions = read(&self.evictions);
        CacheStats {
            hits,
            misses,
            evictions,
        }
    }

    /// Resets every counter to zero.
    fn clear(&self) {
        for counter in [&self.hits, &self.misses, &self.evictions] {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
impl QueryCache {
    /// Creates a cache that keeps at most `max_size` parsed queries.
    ///
    /// A `max_size` of zero is raised to one, so the cache can always hold the
    /// most recently parsed query.
    #[must_use]
    pub fn new(max_size: usize) -> Self {
        Self::new_with_hasher(max_size, default_query_hash)
    }

    /// Creates a cache that uses `hash_fn` to derive bucket keys from the
    /// canonical query text.
    fn new_with_hasher(max_size: usize, hash_fn: fn(&str) -> u64) -> Self {
        // Cap the eager reservation: the queue grows on demand anyway, and
        // reserving `max_size` slots up front would over-allocate (or panic
        // with a capacity overflow) for very large limits.
        const MAX_PREALLOCATED_KEYS: usize = 1024;

        let max_size = max_size.max(1);
        Self {
            cache: RwLock::new(FxHashMap::default()),
            order: RwLock::new(VecDeque::with_capacity(
                max_size.min(MAX_PREALLOCATED_KEYS),
            )),
            max_size,
            hash_fn,
            stats: AtomicCacheStats::default(),
        }
    }

    /// Returns the parsed form of `query`, reusing a cached result when the
    /// exact same query text has been parsed before.
    ///
    /// A hit marks the entry as most recently used. A miss parses the query,
    /// stores the result and evicts the least recently used entries needed to
    /// stay within the size limit. Hits and misses are recorded in the
    /// statistics.
    ///
    /// # Errors
    ///
    /// Returns the parser's `ParseError` when `query` cannot be parsed. Failed
    /// queries are not cached and are not counted as misses.
    pub fn parse(&self, query: &str) -> Result<Query, ParseError> {
        self.parse_impl(query, true)
    }

    /// Same as [`QueryCache::parse`] but leaves the statistics untouched.
    #[cfg(feature = "internal-bench")]
    pub(crate) fn parse_without_stats(&self, query: &str) -> Result<Query, ParseError> {
        self.parse_impl(query, false)
    }

    /// Shared implementation of the cached parse; `record_stats` controls
    /// whether the hit/miss/eviction counters are updated.
    fn parse_impl(&self, query: &str, record_stats: bool) -> Result<Query, ParseError> {
        let canonical_query = canonicalize_query(query);
        let hash = (self.hash_fn)(&canonical_query);
        if let Some(cached) = self.try_cache_hit(hash, query, &canonical_query, record_stats) {
            return Ok(cached);
        }
        let parsed = Parser::parse(query)?;
        // The owned copy of the query text is only needed once we know the
        // entry has to be stored.
        self.insert_into_cache(
            hash,
            query.to_string(),
            canonical_query,
            &parsed,
            record_stats,
        );
        Ok(parsed)
    }

    /// Looks up `original_query` in bucket `hash`; on success moves the entry
    /// to the back of the recency queue and returns a clone of the parse
    /// result.
    fn try_cache_hit(
        &self,
        hash: u64,
        original_query: &str,
        canonical_query: &str,
        record_stats: bool,
    ) -> Option<Query> {
        // A plain read lock is enough: the map is never mutated here, and an
        // upgradable read would exclude every other upgradable reader.
        let cache = self.cache.read();
        let parsed = cache
            .get(&hash)?
            .iter()
            .find(|entry| {
                entry.original_query == original_query
                    && entry.canonical_query == canonical_query
            })
            .map(|entry| entry.parsed.clone())?;

        // The map guard stays alive while the queue is updated so the entry
        // cannot be evicted between the lookup and the refresh.
        let mut order = self.order.write();
        let position = order
            .iter()
            .position(|key| key.hash == hash && key.original_query == original_query);
        // Reuse the queued key instead of allocating a fresh copy of the text.
        let key = position
            .and_then(|pos| order.remove(pos))
            .unwrap_or_else(|| CacheKey {
                hash,
                original_query: original_query.to_string(),
            });
        order.push_back(key);
        drop(order);
        drop(cache);

        if record_stats {
            self.stats.hits.fetch_add(1, Ordering::Relaxed);
        }
        Some(parsed)
    }

    /// Stores a freshly parsed query as the most recently used entry,
    /// evicting older entries if the cache is full.
    fn insert_into_cache(
        &self,
        hash: u64,
        original_query: String,
        canonical_query: String,
        parsed: &Query,
        record_stats: bool,
    ) {
        // Lock order (map, then queue) matches `try_cache_hit` and `clear`.
        let mut cache = self.cache.write();
        let mut order = self.order.write();
        if record_stats {
            self.stats.misses.fetch_add(1, Ordering::Relaxed);
        }

        // Another thread may have cached the same query between our failed
        // lookup and this insert. Drop that copy first so it neither counts
        // towards the capacity check nor causes an unrelated entry to be
        // evicted.
        let stale = order
            .iter()
            .position(|key| key.hash == hash && key.original_query == original_query);
        if let Some(pos) = stale {
            order.remove(pos);
        }
        Self::remove_entry(&mut cache, hash, &original_query);

        self.evict_oldest(&mut cache, &mut order, record_stats);

        order.push_back(CacheKey {
            hash,
            original_query: original_query.clone(),
        });
        cache.entry(hash).or_default().push(CacheEntry {
            original_query,
            canonical_query,
            parsed: parsed.clone(),
        });
        debug_assert_eq!(Self::entry_count(&cache), order.len());
    }

    /// Evicts entries from the front of the recency queue until there is room
    /// for one more entry.
    fn evict_oldest(
        &self,
        cache: &mut FxHashMap<u64, Vec<CacheEntry>>,
        order: &mut VecDeque<CacheKey>,
        record_stats: bool,
    ) {
        // Count once and keep the tally up to date instead of re-summing every
        // bucket on each iteration.
        let mut live_entries = Self::entry_count(cache);
        while live_entries >= self.max_size {
            if let Some(oldest) = order.pop_front() {
                live_entries -= Self::remove_entry(cache, oldest.hash, &oldest.original_query);
                if record_stats {
                    self.stats.evictions.fetch_add(1, Ordering::Relaxed);
                }
            } else {
                // Nothing left to evict; stop rather than spin forever.
                break;
            }
        }
    }

    /// Removes the entry for `original_query` from bucket `hash`, dropping the
    /// bucket when it becomes empty. Returns how many entries were removed.
    fn remove_entry(
        cache: &mut FxHashMap<u64, Vec<CacheEntry>>,
        hash: u64,
        original_query: &str,
    ) -> usize {
        let (removed, bucket_is_empty) = cache.get_mut(&hash).map_or((0, false), |bucket| {
            let before = bucket.len();
            bucket.retain(|entry| entry.original_query != original_query);
            (before - bucket.len(), bucket.is_empty())
        });
        if bucket_is_empty {
            cache.remove(&hash);
        }
        removed
    }

    /// Returns a copy of the current hit, miss and eviction counters.
    #[must_use]
    pub fn stats(&self) -> CacheStats {
        self.stats.snapshot()
    }

    /// Returns the number of cached queries.
    #[must_use]
    pub fn len(&self) -> usize {
        Self::entry_count(&self.cache.read())
    }

    /// Returns `true` when no query is cached.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every cached query and resets the statistics to zero.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        let mut order = self.order.write();
        cache.clear();
        order.clear();
        self.stats.clear();
    }

    /// Sums the sizes of all buckets.
    fn entry_count(cache: &FxHashMap<u64, Vec<CacheEntry>>) -> usize {
        cache.values().map(std::vec::Vec::len).sum()
    }
}
impl Default for QueryCache {
    /// Builds a cache with room for 1000 parsed queries.
    fn default() -> Self {
        const DEFAULT_MAX_SIZE: usize = 1000;
        Self::new(DEFAULT_MAX_SIZE)
    }
}
/// Hashes the raw bytes of `query` with an Fx hasher; this is the hash
/// function installed by `QueryCache::new`.
fn default_query_hash(query: &str) -> u64 {
    let builder = rustc_hash::FxBuildHasher;
    let mut state = builder.build_hasher();
    Hasher::write(&mut state, query.as_bytes());
    Hasher::finish(&state)
}
/// Returns `query` with leading and trailing whitespace removed and every
/// inner run of whitespace replaced by a single space.
fn canonicalize_query(query: &str) -> String {
    let mut canonical = String::with_capacity(query.len());
    for token in query.split_whitespace() {
        // Tokens are never empty, so an empty buffer means "first token".
        if !canonical.is_empty() {
            canonical.push(' ');
        }
        canonical.push_str(token);
    }
    canonical
}
// Unit tests for the statistics arithmetic and for the cache's hit, eviction,
// recency-order and hash-collision behaviour.
#[cfg(test)]
mod tests {
use super::*;
// With no recorded lookups the hit rate is reported as 0.
#[test]
fn test_cache_stats_hit_rate_empty() {
let stats = CacheStats::default();
assert!((stats.hit_rate() - 0.0).abs() < 1e-5);
}
// Only hits recorded: the rate is 100 percent.
#[test]
fn test_cache_stats_hit_rate_all_hits() {
let stats = CacheStats {
hits: 10,
misses: 0,
evictions: 0,
};
assert!((stats.hit_rate() - 100.0).abs() < 1e-5);
}
// Equal hits and misses: the rate is 50 percent.
#[test]
fn test_cache_stats_hit_rate_half() {
let stats = CacheStats {
hits: 5,
misses: 5,
evictions: 0,
};
assert!((stats.hit_rate() - 50.0).abs() < 1e-5);
}
// A freshly constructed cache holds no entries.
#[test]
fn test_query_cache_new() {
let cache = QueryCache::new(100);
assert!(cache.is_empty());
assert_eq!(cache.len(), 0);
}
// The `Default` implementation also starts out empty.
#[test]
fn test_query_cache_default() {
let cache = QueryCache::default();
assert!(cache.is_empty());
}
// The first parse of a query is a miss; parsing the same text again is a hit.
#[test]
fn test_query_cache_parse_and_hit() {
let cache = QueryCache::new(10);
let query = "SELECT * FROM docs LIMIT 5";
let result1 = cache.parse(query);
assert!(result1.is_ok());
assert_eq!(cache.stats().misses, 1);
assert_eq!(cache.stats().hits, 0);
let result2 = cache.parse(query);
assert!(result2.is_ok());
assert_eq!(cache.stats().hits, 1);
}
// `clear` drops all entries and resets the counters.
#[test]
fn test_query_cache_clear() {
let cache = QueryCache::new(10);
let _ = cache.parse("SELECT * FROM docs LIMIT 1");
assert!(!cache.is_empty());
cache.clear();
assert!(cache.is_empty());
assert_eq!(cache.stats().hits, 0);
assert_eq!(cache.stats().misses, 0);
}
// Inserting a third distinct query into a cache of size 2 evicts an entry and
// keeps the size at the limit.
#[test]
fn test_query_cache_eviction() {
let cache = QueryCache::new(2);
let _ = cache.parse("SELECT * FROM docs LIMIT 1");
let _ = cache.parse("SELECT * FROM docs LIMIT 2");
assert_eq!(cache.len(), 2);
let _ = cache.parse("SELECT * FROM docs LIMIT 3");
assert_eq!(cache.len(), 2);
assert!(cache.stats().evictions >= 1);
}
// A hit must move the query to the back of the recency queue without leaving
// a second copy of its key behind.
#[test]
fn test_query_cache_hit_refreshes_mru_without_duplicates() {
let cache = QueryCache::new(3);
let q1 = "SELECT * FROM docs LIMIT 1";
let q2 = "SELECT * FROM docs LIMIT 2";
let q3 = "SELECT * FROM docs LIMIT 3";
let _ = cache.parse(q1);
let _ = cache.parse(q2);
let _ = cache.parse(q3);
// Fourth call re-parses q1, which is already cached.
let _ = cache.parse(q1);
let order = cache.order.read();
assert_eq!(order.len(), cache.len());
assert_eq!(
order
.iter()
.filter(|v| v.original_query.as_str() == q1)
.count(),
1
);
assert_eq!(order.back().map(|v| v.original_query.as_str()), Some(q1));
}
// Eight threads parse the same five queries concurrently; afterwards the
// recency queue must contain no duplicate keys and must match the number of
// cached entries.
#[test]
fn test_query_cache_concurrent_invariant_no_order_duplicates() {
use std::sync::Arc;
use std::thread;
let cache = Arc::new(QueryCache::new(32));
let queries = [
"SELECT * FROM docs LIMIT 1",
"SELECT * FROM docs LIMIT 2",
"SELECT * FROM docs LIMIT 3",
"SELECT * FROM docs LIMIT 4",
"SELECT * FROM docs LIMIT 5",
];
let mut handles = Vec::new();
for _ in 0..8 {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..200 {
let q = queries[i % queries.len()];
let _ = cache.parse(q);
}
}));
}
for h in handles {
h.join().expect("thread must complete");
}
let order = cache.order.read();
let mut uniq = std::collections::HashSet::new();
for key in order.iter() {
assert!(uniq.insert(key.clone()), "duplicate query in LRU order");
}
assert_eq!(order.len(), cache.len());
}
// Forces every query into the same bucket (constant hash) and checks that
// distinct queries are still stored and returned separately.
#[test]
fn test_query_cache_collision_safe_with_forced_hash_collision() {
let cache = QueryCache::new_with_hasher(10, |_| 42);
let q1 = "SELECT * FROM docs LIMIT 1";
let q2 = "SELECT id FROM docs LIMIT 2";
let r1 = cache.parse(q1).expect("q1 should parse");
let r2 = cache.parse(q2).expect("q2 should parse");
let r1_again = cache.parse(q1).expect("q1 should be cache hit");
assert_eq!(r1, r1_again);
assert_ne!(r1, r2);
assert_eq!(cache.len(), 2);
}
// A requested size of 0 is raised to 1 by the constructor, so one entry is
// still cached.
#[test]
fn test_query_cache_min_size() {
let cache = QueryCache::new(0);
let _ = cache.parse("SELECT * FROM docs LIMIT 1");
assert!(!cache.is_empty());
}
// A query the parser rejects surfaces as an error from `parse`.
#[test]
fn test_query_cache_invalid_query() {
let cache = QueryCache::new(10);
let result = cache.parse("INVALID QUERY SYNTAX!!!");
assert!(result.is_err());
}
}