use super::CypherQuery;
use crate::error::KgError;
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
/// Maximum number of parsed queries the cache holds; once this many entries are stored,
/// adding a new query evicts the oldest one.
const CACHE_CAPACITY: usize = 256;
/// One cache slot: the exact query text that produced `ast`.
///
/// The text is kept so lookups can tell a genuine hit apart from a different query that
/// merely collides on the 64-bit hash key.
struct CachedParse {
    query: Box<str>,
    ast: CypherQuery,
}

/// Bounded, process-wide cache of parsed Cypher ASTs.
///
/// Eviction is first-in-first-out: `order` records each key in insertion order and the
/// oldest entry is dropped when a new slot is needed and the cache is full. Lookups do not
/// move an entry, so this is not LRU.
struct ParseCache {
    /// `hash_query(text)` -> slot holding that text and its AST.
    map: HashMap<u64, CachedParse>,
    /// Keys of `map` in insertion order; each key present in `map` appears here exactly once.
    order: std::collections::VecDeque<u64>,
}

impl ParseCache {
    /// Creates an empty cache pre-sized for `CACHE_CAPACITY` entries.
    fn new() -> Self {
        Self {
            map: HashMap::with_capacity(CACHE_CAPACITY),
            order: std::collections::VecDeque::with_capacity(CACHE_CAPACITY),
        }
    }

    /// Returns the cached AST for `query` (whose hash is `key`), if present.
    ///
    /// A slot under `key` whose stored text differs from `query` is a hash collision and is
    /// reported as a miss instead of returning another query's AST.
    fn lookup(&self, key: u64, query: &str) -> Option<&CypherQuery> {
        self.map
            .get(&key)
            .filter(|slot| &*slot.query == query)
            .map(|slot| &slot.ast)
    }

    /// Stores `ast` for `query` under `key`.
    ///
    /// If `key` is new and the cache is full, the oldest entry is evicted first. If `key`
    /// already has a slot (the same query inserted by a racing thread, or a colliding query),
    /// that slot is overwritten in place and keeps its FIFO position, so `order` never holds
    /// duplicates.
    fn store(&mut self, key: u64, query: &str, ast: CypherQuery) {
        if !self.map.contains_key(&key) {
            if self.map.len() >= CACHE_CAPACITY {
                if let Some(oldest) = self.order.pop_front() {
                    self.map.remove(&oldest);
                }
            }
            self.order.push_back(key);
        }
        self.map.insert(
            key,
            CachedParse {
                query: query.into(),
                ast,
            },
        );
    }
}

static CACHE: OnceLock<RwLock<ParseCache>> = OnceLock::new();

/// Returns the global cache, creating it on first use.
fn cache() -> &'static RwLock<ParseCache> {
    CACHE.get_or_init(|| RwLock::new(ParseCache::new()))
}

/// Hashes query text into the cache's bucket key.
///
/// Every `DefaultHasher::new()` instance hashes identically, so a given text always maps to
/// the same key within the process. Distinct texts can still collide, which is why slots
/// also keep the full text.
fn hash_query(query: &str) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    query.hash(&mut hasher);
    hasher.finish()
}

/// Parses `query`, reusing a cached AST when exactly the same text was parsed before.
///
/// Successful parses are stored in a global bounded cache (`CACHE_CAPACITY` entries, FIFO
/// eviction). Failed parses are not cached. The returned value is a clone of the cached AST.
///
/// # Errors
/// Propagates the error from `super::parser::parse_cypher` when `query` does not parse.
///
/// # Panics
/// Panics if the cache lock is poisoned.
pub fn parse_cypher_cached(query: &str) -> Result<CypherQuery, KgError> {
    let key = hash_query(query);
    // Fast path under the shared lock. The guard is dropped at the end of this scope, so the
    // parse below runs without holding any lock.
    {
        let guard = cache().read().expect("parse_cache RwLock poisoned");
        if let Some(ast) = guard.lookup(key, query) {
            return Ok(ast.clone());
        }
    }
    let parsed = super::parser::parse_cypher(query)?;
    cache()
        .write()
        .expect("parse_cache RwLock poisoned")
        .store(key, query, parsed.clone());
    Ok(parsed)
}
/// Test-only: empties the global parse cache, dropping both the entries and their
/// insertion-order bookkeeping.
#[cfg(test)]
pub fn clear_for_tests() {
    let mut cache_guard = cache().write().expect("parse_cache RwLock poisoned");
    let ParseCache { map, order } = &mut *cache_guard;
    map.clear();
    order.clear();
}
/// Test-only: number of ASTs currently held in the global parse cache.
#[cfg(test)]
pub fn entry_count_for_tests() -> usize {
    let snapshot = cache().read().expect("parse_cache RwLock poisoned");
    snapshot.map.len()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// All tests share the process-global cache, so they must not run concurrently.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the test lock (tolerating poisoning left by an earlier failed test) and resets
    /// the cache. Hold the returned guard for the whole test.
    fn isolated() -> MutexGuard<'static, ()> {
        let serial = TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        clear_for_tests();
        serial
    }

    #[test]
    fn cache_hit_returns_equivalent_ast() {
        let _serial = isolated();
        let query = "MATCH (n:Person) RETURN n.name";
        let on_miss = parse_cypher_cached(query).unwrap();
        let on_hit = parse_cypher_cached(query).unwrap();
        assert_eq!(on_miss.clauses.len(), on_hit.clauses.len());
        assert_eq!(entry_count_for_tests(), 1);
    }

    #[test]
    fn cache_evicts_at_capacity() {
        let _serial = isolated();
        (0..CACHE_CAPACITY + 5)
            .map(|i| format!("MATCH (n:T{}) RETURN n", i))
            .for_each(|query| {
                parse_cypher_cached(&query).unwrap();
            });
        assert_eq!(entry_count_for_tests(), CACHE_CAPACITY);
    }

    #[test]
    fn parse_errors_are_not_cached() {
        let _serial = isolated();
        let bad_query = "MATCH NOT VALID CYPHER";
        let attempts: Vec<_> = (0..2).map(|_| parse_cypher_cached(bad_query)).collect();
        for attempt in &attempts {
            assert!(attempt.is_err());
        }
        assert_eq!(entry_count_for_tests(), 0);
    }
}