use std::collections::HashMap;
use std::sync::Mutex;
use crate::parser::ParsedStatement;
/// Size-bounded, least-recently-used cache mapping SQL text to its parsed form.
///
/// All mutable state lives behind a single [`Mutex`]; `max_size` is fixed at
/// construction and is read without taking the lock.
pub(crate) struct ParseCache {
    /// Upper bound on the number of cached statements; `0` disables caching.
    max_size: usize,
    /// Cached entries plus hit/miss/eviction counters, guarded by one lock.
    inner: Mutex<ParseCacheInner>,
}
// Debug output shows the configured limit alongside a fresh stats snapshot.
// Taking the snapshot goes through `stats()`, which locks the inner mutex.
impl std::fmt::Debug for ParseCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let snapshot = self.stats();
        let mut out = f.debug_struct("ParseCache");
        out.field("max_size", &self.max_size);
        out.field("size", &snapshot.size);
        out.field("hits", &snapshot.hits);
        out.field("misses", &snapshot.misses);
        out.field("evictions", &snapshot.evictions);
        out.finish()
    }
}
/// Lock-protected state of a [`ParseCache`].
struct ParseCacheInner {
    /// Logical clock advanced by lookups and insertions; used to stamp entries
    /// so the smallest stamp identifies the least recently used one.
    tick: u64,
    /// SQL text -> (tick of last access or insertion, parsed statement).
    entries: HashMap<String, (u64, ParsedStatement)>,
    /// Number of lookups that found a cached statement.
    hits: u64,
    /// Number of lookups that found nothing.
    misses: u64,
    /// Number of entries dropped to make room for a new key.
    evictions: u64,
}
impl ParseCache {
pub(crate) fn new(max_size: usize) -> Self {
Self {
max_size,
inner: Mutex::new(ParseCacheInner {
entries: HashMap::with_capacity(max_size.min(1024)),
tick: 0,
hits: 0,
misses: 0,
evictions: 0,
}),
}
}
pub(crate) fn get(&self, sql: &str) -> Option<ParsedStatement> {
let mut guard = self.inner.lock().ok()?;
let next = guard.tick.wrapping_add(1);
guard.tick = next;
if let Some((counter, stmt)) = guard.entries.get_mut(sql) {
let cloned = stmt.clone();
*counter = next;
guard.hits += 1;
Some(cloned)
} else {
guard.misses += 1;
None
}
}
pub(crate) fn insert(&self, sql: String, stmt: ParsedStatement) {
if self.max_size == 0 {
return;
}
let Ok(mut guard) = self.inner.lock() else {
return;
};
if guard.entries.len() >= self.max_size && !guard.entries.contains_key(&sql) {
if let Some(oldest_key) = guard
.entries
.iter()
.min_by_key(|(_, (c, _))| *c)
.map(|(k, _)| k.clone())
{
guard.entries.remove(&oldest_key);
guard.evictions += 1;
}
}
let next = guard.tick.wrapping_add(1);
guard.tick = next;
guard.entries.insert(sql, (next, stmt));
}
pub(crate) fn stats(&self) -> ParseCacheStats {
match self.inner.lock() {
Ok(g) => ParseCacheStats {
size: g.entries.len(),
hits: g.hits,
misses: g.misses,
evictions: g.evictions,
},
Err(_) => ParseCacheStats::default(),
}
}
pub(crate) fn clear(&self) {
if let Ok(mut g) = self.inner.lock() {
g.entries.clear();
}
}
}
/// Point-in-time snapshot of a parse cache's occupancy and counters.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ParseCacheStats {
    /// Number of statements currently cached.
    pub size: usize,
    /// Lookups that returned a cached statement.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Entries removed to make room for new keys.
    pub evictions: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_statement;

    /// Parses a known-good SQL fixture, panicking if the parser rejects it.
    fn parse(sql: &str) -> ParsedStatement {
        parse_statement(sql).unwrap()
    }

    const SQL1: &str = "SELECT id FROM users WHERE id = 1";
    const SQL2: &str = "SELECT id FROM users WHERE id = 2";
    const SQL3: &str = "SELECT id FROM users WHERE id = 3";

    #[test]
    fn get_miss_on_empty_cache() {
        let c = ParseCache::new(8);
        assert!(c.get(SQL1).is_none());
        assert_eq!(c.stats().misses, 1);
        assert_eq!(c.stats().hits, 0);
    }

    #[test]
    fn insert_then_get_hits() {
        let c = ParseCache::new(8);
        c.insert(SQL1.into(), parse(SQL1));
        assert!(c.get(SQL1).is_some());
        assert_eq!(c.stats().hits, 1);
        assert_eq!(c.stats().size, 1);
    }

    #[test]
    fn distinct_sql_does_not_collide() {
        let c = ParseCache::new(8);
        c.insert(SQL1.into(), parse(SQL1));
        c.insert(SQL2.into(), parse(SQL2));
        assert!(c.get(SQL1).is_some());
        assert!(c.get(SQL2).is_some());
        assert_eq!(c.stats().size, 2);
    }

    #[test]
    fn lru_evicts_least_recently_used() {
        let c = ParseCache::new(2);
        c.insert(SQL1.into(), parse(SQL1));
        c.insert(SQL2.into(), parse(SQL2));
        c.get(SQL1);
        c.insert(SQL3.into(), parse(SQL3));
        assert!(c.get(SQL1).is_some());
        assert!(c.get(SQL2).is_none());
        assert!(c.get(SQL3).is_some());
        assert_eq!(c.stats().evictions, 1);
    }

    #[test]
    fn zero_max_size_disables_cache() {
        let c = ParseCache::new(0);
        c.insert(SQL1.into(), parse(SQL1));
        assert!(c.get(SQL1).is_none());
        assert_eq!(c.stats().size, 0);
    }

    #[test]
    fn clear_drops_all_entries() {
        let c = ParseCache::new(8);
        c.insert(SQL1.into(), parse(SQL1));
        c.insert(SQL2.into(), parse(SQL2));
        c.clear();
        assert_eq!(c.stats().size, 0);
        assert!(c.get(SQL1).is_none());
    }

    #[test]
    fn insert_updates_recency_on_existing_key() {
        let c = ParseCache::new(2);
        c.insert(SQL1.into(), parse(SQL1));
        c.insert(SQL2.into(), parse(SQL2));
        // Re-inserting an existing key at capacity must neither grow the cache
        // nor evict anything, and must make SQL1 the most recently used entry.
        c.insert(SQL1.into(), parse(SQL1));
        assert_eq!(c.stats().size, 2);
        assert_eq!(c.stats().evictions, 0);
        // The next new key therefore evicts SQL2, not SQL1.
        c.insert(SQL3.into(), parse(SQL3));
        assert_eq!(c.stats().evictions, 1);
        assert!(c.get(SQL1).is_some());
        assert!(c.get(SQL2).is_none());
        assert!(c.get(SQL3).is_some());
    }
}