use crate::algebra::Algebra;
use crate::cache::CacheCoordinator;
use crate::optimizer::Statistics;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
/// Concurrent cache of optimized query plans keyed by [`QuerySignature`].
///
/// Shared state lives behind `Arc` in `DashMap`/`DashSet` containers and atomics.
pub struct QueryPlanCache {
    /// Cached plans keyed by query signature.
    cache: Arc<DashMap<QuerySignature, CachedPlan>>,
    /// Capacity, TTL, and staleness settings.
    config: CachingConfig,
    /// Hit/miss/eviction/invalidation counters.
    stats: Arc<CacheStatistics>,
    /// Logical clock whose ticks stamp `CachedPlan::last_accessed` for LRU eviction.
    access_counter: Arc<AtomicU64>,
    /// Optional external coordinator. NOTE(review): it is stored, but none of the methods
    /// visible alongside this struct read it — confirm where it is consumed.
    invalidation_coordinator: Option<Arc<CacheCoordinator>>,
    /// Signatures flagged as stale; `get` treats a flagged signature as a miss.
    invalidated_entries: Arc<dashmap::DashSet<QuerySignature>>,
}
/// Tuning knobs for [`QueryPlanCache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachingConfig {
    /// Master switch: when `false`, lookups return nothing and inserts are ignored.
    pub enabled: bool,
    /// Capacity in entries; inserts evict least-recently-used entries to stay within it.
    pub max_cache_size: usize,
    /// Maximum age, in seconds, of a cached plan; older entries are dropped on lookup.
    pub ttl_seconds: u64,
    /// NOTE(review): not read by the cache code in this file — confirm whether literal
    /// normalization of queries is meant to honor this flag.
    pub parameterized_queries: bool,
    /// Whether lookups compare an entry's statistics snapshot with the current statistics.
    pub invalidate_on_stats_change: bool,
    /// Relative change `|current - old| / old` above which a plan counts as stale
    /// (e.g. `0.2` means more than 20%).
    pub stats_change_threshold: f64,
}

impl Default for CachingConfig {
    /// Caching on, 10 000 entries, one-hour TTL, parameterized queries, and
    /// statistics-driven invalidation at a 20% change threshold.
    fn default() -> Self {
        const ONE_HOUR_SECS: u64 = 60 * 60;
        Self {
            enabled: true,
            max_cache_size: 10_000,
            ttl_seconds: ONE_HOUR_SECS,
            parameterized_queries: true,
            invalidate_on_stats_change: true,
            stats_change_threshold: 0.2,
        }
    }
}
/// Cache key for a query plan: the normalized query text, the caller-supplied
/// parameter types, and a hash of the statistics the plan was built against.
///
/// All three fields take part in the derived `Hash`/`Eq`, so any difference in
/// any of them produces a distinct key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct QuerySignature {
    /// Query text after literal and whitespace normalization.
    normalized_query: String,
    /// Parameter-type strings supplied by the caller, compared verbatim.
    parameter_types: Vec<String>,
    /// Hash of the statistics at signature-creation time.
    stats_hash: u64,
}
impl QuerySignature {
    /// Builds the cache key for `query` with the given parameter types under `stats`.
    ///
    /// NOTE(review): queries that differ only in literal values share a key, yet the
    /// cached plan still embeds the literals of whichever query was inserted first —
    /// confirm that callers rebind parameters before executing a cached plan.
    pub fn new(query: &str, params: Vec<String>, stats: &Statistics) -> Self {
        Self {
            normalized_query: Self::normalize_query(query),
            parameter_types: params,
            stats_hash: Self::hash_statistics(stats),
        }
    }

    /// Canonicalizes query text so that queries differing only in literal values
    /// share a key.
    ///
    /// - Double-quoted string literals become `"?"`.
    /// - Standalone numbers (`42`, `3.14`) become `?`.
    /// - Whitespace runs collapse to a single space.
    /// - IRIs written as `<...>` (no whitespace inside) are copied verbatim. Otherwise
    ///   digits inside IRIs such as `<http://example.org/item/1>` and
    ///   `<http://example.org/item/2>` would be masked, and two different resources
    ///   would share one cached plan.
    ///
    /// The patterns are compiled once and reused, because this runs on every lookup
    /// and every insert.
    fn normalize_query(query: &str) -> String {
        static LITERALS: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
        static WHITESPACE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();

        let literals = LITERALS.get_or_init(|| {
            // The engine always reports the leftmost match. An IRI or string therefore
            // consumes any digits it contains before the number alternative can see them.
            regex::Regex::new(r#"<[^<>\s]*>|"[^"]*"|\b\d+(\.\d+)?\b"#)
                .expect("regex pattern should be valid")
        });
        let whitespace = WHITESPACE
            .get_or_init(|| regex::Regex::new(r"\s+").expect("regex pattern should be valid"));

        let masked = literals.replace_all(query, |caps: &regex::Captures<'_>| {
            let matched = &caps[0];
            if matched.starts_with('<') {
                // IRI: keep it, since distinct IRIs denote distinct resources.
                matched.to_string()
            } else if matched.starts_with('"') {
                "\"?\"".to_string()
            } else {
                "?".to_string()
            }
        });
        whitespace.replace_all(&masked, " ").into_owned()
    }

    /// Hashes the statistics' cardinalities and predicate frequencies.
    ///
    /// Entries are sorted before hashing, so the result depends only on the contents
    /// and not on the maps' iteration order. If those maps are hash-ordered, two equal
    /// instances may iterate differently, which would otherwise give unequal keys.
    /// Each list is hashed as a `Vec`, which prefixes its length, so the two sections
    /// cannot run into each other.
    fn hash_statistics(stats: &Statistics) -> u64 {
        use std::collections::hash_map::DefaultHasher;

        let mut cardinalities: Vec<_> = stats.cardinalities.iter().collect();
        cardinalities.sort_unstable();
        let mut predicate_frequency: Vec<_> = stats.predicate_frequency.iter().collect();
        predicate_frequency.sort_unstable();

        let mut hasher = DefaultHasher::new();
        cardinalities.hash(&mut hasher);
        predicate_frequency.hash(&mut hasher);
        hasher.finish()
    }
}
/// A cached plan plus the bookkeeping used for TTL expiry, LRU eviction, and
/// statistics-drift checks.
///
/// `hit_count` and `last_accessed` sit behind `Arc`, so clones of a `CachedPlan`
/// share the same counters.
#[derive(Debug, Clone)]
pub struct CachedPlan {
    /// The plan returned on a cache hit (cloned per hit).
    pub plan: Algebra,
    /// Insertion time; its elapsed time is compared against the configured TTL.
    pub cached_at: Instant,
    /// Number of hits served by this entry.
    pub hit_count: Arc<AtomicUsize>,
    /// Logical tick from the cache's access counter (not wall-clock time); the entry
    /// with the smallest tick is evicted first.
    pub last_accessed: Arc<AtomicU64>,
    /// Cost estimate supplied by the caller at insert time; the visible cache logic
    /// stores it but does not consult it.
    pub estimated_cost: f64,
    /// Statistics captured at insert time, used to detect significant drift.
    pub stats_snapshot: StatisticsSnapshot,
}
/// Ordered copy of optimizer statistics taken when a plan is cached.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatisticsSnapshot {
    /// Cardinality per pattern key at capture time.
    pub cardinalities: BTreeMap<String, usize>,
    /// Frequency per predicate key at capture time.
    pub predicate_frequency: BTreeMap<String, usize>,
    /// Capture time in whole seconds since the Unix epoch.
    pub snapshot_time: u64,
}
impl StatisticsSnapshot {
    /// Copies the cardinalities and predicate frequencies of `stats` into ordered
    /// maps, stamped with the current Unix time in whole seconds.
    ///
    /// If the system clock reports a time before the Unix epoch, `snapshot_time` is
    /// `0` instead of panicking.
    pub fn from_statistics(stats: &Statistics) -> Self {
        Self {
            cardinalities: stats
                .cardinalities
                .iter()
                .map(|(k, v)| (k.clone(), *v))
                .collect(),
            predicate_frequency: stats
                .predicate_frequency
                .iter()
                .map(|(k, v)| (k.clone(), *v))
                .collect(),
            snapshot_time: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
        }
    }

    /// Reports whether `current_stats` has drifted from this snapshot by more than
    /// `threshold`.
    ///
    /// Only keys recorded in the snapshot are compared, and a key missing from
    /// `current_stats` counts as 0. For each key, the relative change
    /// `|current - old| / old` is compared against `threshold` (e.g. `0.2` = 20%).
    /// Growth from 0 to a non-zero value always counts as significant. Cardinalities
    /// and predicate frequencies follow the same rule.
    pub fn has_changed_significantly(&self, current_stats: &Statistics, threshold: f64) -> bool {
        let cardinality_drift = self.cardinalities.iter().any(|(pattern, &old)| {
            let current = current_stats
                .cardinalities
                .get(pattern)
                .copied()
                .unwrap_or(0);
            Self::count_changed_significantly(old, current, threshold)
        });
        cardinality_drift
            || self.predicate_frequency.iter().any(|(pred, &old)| {
                let current = current_stats
                    .predicate_frequency
                    .get(pred)
                    .copied()
                    .unwrap_or(0);
                Self::count_changed_significantly(old, current, threshold)
            })
    }

    /// Decides whether a single counter moved by more than `threshold` relative to
    /// its snapshot value.
    ///
    /// The relative change from 0 is unbounded, so any growth from 0 is significant.
    fn count_changed_significantly(old: usize, current: usize, threshold: f64) -> bool {
        if old == 0 {
            return current > 0;
        }
        let change_ratio = (current as f64 - old as f64).abs() / old as f64;
        change_ratio > threshold
    }
}
/// Lock-free counters describing how effective the cache is.
///
/// Each counter is an independent `AtomicU64`. Reading several of them is not an
/// atomic snapshot.
#[derive(Debug, Default)]
pub struct CacheStatistics {
    /// Lookups that returned a plan.
    pub hits: AtomicU64,
    /// Lookups that returned nothing.
    pub misses: AtomicU64,
    /// Entries removed to make room for new ones.
    pub evictions: AtomicU64,
    /// Entries dropped or flagged as stale.
    pub invalidations: AtomicU64,
    /// Size estimate in bytes. NOTE(review): nothing in this file updates it.
    pub size_bytes: AtomicU64,
}

impl CacheStatistics {
    /// Fraction of recorded lookups that hit; `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        let hit_total = self.hits.load(Ordering::Relaxed);
        let miss_total = self.misses.load(Ordering::Relaxed);
        match hit_total + miss_total {
            0 => 0.0,
            lookups => hit_total as f64 / lookups as f64,
        }
    }

    /// Number of lookups recorded so far (hits plus misses).
    pub fn total_requests(&self) -> u64 {
        [&self.hits, &self.misses]
            .iter()
            .map(|counter| counter.load(Ordering::Relaxed))
            .sum()
    }
}
impl QueryPlanCache {
    /// Creates an empty cache with [`CachingConfig::default`] and no coordinator.
    pub fn new() -> Self {
        Self::with_config(CachingConfig::default())
    }

    /// Creates an empty cache with `config` and no invalidation coordinator.
    pub fn with_config(config: CachingConfig) -> Self {
        Self {
            cache: Arc::new(DashMap::new()),
            config,
            stats: Arc::new(CacheStatistics::default()),
            access_counter: Arc::new(AtomicU64::new(0)),
            invalidation_coordinator: None,
            invalidated_entries: Arc::new(dashmap::DashSet::new()),
        }
    }

    /// Creates an empty cache with `config` and `coordinator` attached.
    pub fn with_invalidation_coordinator(
        config: CachingConfig,
        coordinator: Arc<CacheCoordinator>,
    ) -> Self {
        let mut cache = Self::with_config(config);
        cache.invalidation_coordinator = Some(coordinator);
        cache
    }

    /// Attaches `coordinator`, replacing any previously attached one.
    pub fn attach_coordinator(&mut self, coordinator: Arc<CacheCoordinator>) {
        self.invalidation_coordinator = Some(coordinator);
    }

    /// Looks up the cached plan for `query` with `params` under `current_stats`.
    ///
    /// Returns `None` without touching any counter when caching is disabled.
    ///
    /// Otherwise returns `None` and records a miss in these cases:
    /// - the signature was flagged stale (the flag is consumed and the entry dropped);
    /// - no entry exists;
    /// - the entry is older than the TTL;
    /// - the entry's statistics snapshot drifted past the configured threshold.
    ///
    /// The last two cases also remove the entry and count an invalidation. On a hit,
    /// the entry's hit count and LRU tick are bumped and a clone of the plan is returned.
    ///
    /// NOTE(review): the signature includes a hash of the full statistics, so any
    /// statistics change already yields a different key (and a miss). That makes the
    /// threshold-based check below effectively unreachable — confirm whether
    /// `stats_hash` belongs in the key.
    pub fn get(
        &self,
        query: &str,
        params: Vec<String>,
        current_stats: &Statistics,
    ) -> Option<Algebra> {
        if !self.config.enabled {
            return None;
        }
        let signature = QuerySignature::new(query, params, current_stats);

        // Consume a pending stale flag and drop the flagged entry. This keeps
        // `invalidated_entries` bounded and lets a later `insert` be served again.
        // The invalidation was already counted when the signature was flagged.
        if self.invalidated_entries.remove(&signature).is_some() {
            self.cache.remove(&signature);
            self.stats.misses.fetch_add(1, Ordering::Relaxed);
            return None;
        }

        // Read access is enough because the per-entry counters are atomics. `get_mut`
        // would take the shard's write lock on every lookup.
        let Some(entry) = self.cache.get(&signature) else {
            self.stats.misses.fetch_add(1, Ordering::Relaxed);
            return None;
        };

        // Compare whole `Duration`s so an entry expires as soon as it exceeds the TTL;
        // comparing `as_secs()` would drop the sub-second part.
        let ttl = std::time::Duration::from_secs(self.config.ttl_seconds);
        let is_stale = entry.cached_at.elapsed() > ttl
            || (self.config.invalidate_on_stats_change
                && entry
                    .stats_snapshot
                    .has_changed_significantly(current_stats, self.config.stats_change_threshold));
        if is_stale {
            // Release the shard guard before removing; DashMap may deadlock if a
            // reference is held across a write to the same shard.
            drop(entry);
            self.cache.remove(&signature);
            self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
            self.stats.misses.fetch_add(1, Ordering::Relaxed);
            return None;
        }

        entry.hit_count.fetch_add(1, Ordering::Relaxed);
        let access_time = self.access_counter.fetch_add(1, Ordering::Relaxed);
        entry.last_accessed.store(access_time, Ordering::Relaxed);
        self.stats.hits.fetch_add(1, Ordering::Relaxed);
        Some(entry.plan.clone())
    }

    /// Caches `plan` for `query` with `params` under `current_stats`.
    ///
    /// Does nothing when caching is disabled or the capacity is 0. Inserting a new
    /// key evicts least-recently-used entries until there is room. Replacing an
    /// existing key evicts nothing. A fresh insert also clears any pending stale flag
    /// for the same signature.
    pub fn insert(
        &self,
        query: &str,
        params: Vec<String>,
        plan: Algebra,
        estimated_cost: f64,
        current_stats: &Statistics,
    ) {
        if !self.config.enabled || self.config.max_cache_size == 0 {
            return;
        }
        let signature = QuerySignature::new(query, params, current_stats);

        // Overwriting does not grow the map, so only make room for new keys. This loops
        // rather than evicting once because `update_config` may have lowered the capacity
        // below the current size.
        if !self.cache.contains_key(&signature) {
            while self.cache.len() >= self.config.max_cache_size {
                if !self.evict_lru() {
                    break;
                }
            }
        }

        // The new plan supersedes any pending invalidation of this signature.
        self.invalidated_entries.remove(&signature);

        let cached_plan = CachedPlan {
            plan,
            cached_at: Instant::now(),
            hit_count: Arc::new(AtomicUsize::new(0)),
            // Take a fresh tick so insertion counts as a use. Otherwise every entry that
            // was never read shares one tick, and LRU picks among them arbitrarily,
            // possibly evicting the newest entry first.
            last_accessed: Arc::new(AtomicU64::new(
                self.access_counter.fetch_add(1, Ordering::Relaxed),
            )),
            estimated_cost,
            stats_snapshot: StatisticsSnapshot::from_statistics(current_stats),
        };
        self.cache.insert(signature, cached_plan);
    }

    /// Evicts the entry with the smallest `last_accessed` tick.
    ///
    /// Returns `false` only when the cache was empty, so capacity loops terminate.
    fn evict_lru(&self) -> bool {
        let mut oldest: Option<(u64, QuerySignature)> = None;
        for entry in self.cache.iter() {
            let tick = entry.last_accessed.load(Ordering::Relaxed);
            let is_older = match &oldest {
                Some((best, _)) => tick < *best,
                None => true,
            };
            if is_older {
                oldest = Some((tick, entry.key().clone()));
            }
        }
        // The iteration guards are released here, before the write below.
        let Some((_, key)) = oldest else {
            return false;
        };
        if self.cache.remove(&key).is_some() {
            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
        }
        true
    }

    /// Drops every entry and pending stale flag, counting one invalidation per entry.
    pub fn clear(&self) {
        let count = self.cache.len();
        self.cache.clear();
        // With the entries gone, pending flags no longer refer to anything.
        self.invalidated_entries.clear();
        self.stats
            .invalidations
            .fetch_add(count as u64, Ordering::Relaxed);
    }

    /// Removes every entry whose statistics snapshot has a cardinality for `pattern`,
    /// counting one invalidation per removed entry.
    pub fn invalidate_pattern(&self, pattern: &str) {
        let keys_to_remove: Vec<_> = self
            .cache
            .iter()
            .filter(|entry| entry.stats_snapshot.cardinalities.contains_key(pattern))
            .map(|entry| entry.key().clone())
            .collect();
        for key in keys_to_remove {
            // Removing the entry is enough: a later `get` misses and a later `insert`
            // stores a fresh plan. Recording a stale flag as well would only accumulate
            // signatures that may never be looked up again.
            if self.cache.remove(&key).is_some() {
                self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Flags `signature` as stale without removing its entry, and counts one
    /// invalidation.
    ///
    /// The next `get` for it misses and drops the entry. A later `insert` for it
    /// clears the flag.
    pub fn mark_invalidated(&self, signature: QuerySignature) {
        self.invalidated_entries.insert(signature);
        self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
    }

    /// Removes the entry for `signature` (if any) and counts one invalidation.
    pub fn invalidate_signature(&self, signature: &QuerySignature) {
        self.cache.remove(signature);
        self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns a point-in-time copy of the cache counters, size, and capacity.
    pub fn statistics(&self) -> CacheStats {
        CacheStats {
            hits: self.stats.hits.load(Ordering::Relaxed),
            misses: self.stats.misses.load(Ordering::Relaxed),
            evictions: self.stats.evictions.load(Ordering::Relaxed),
            invalidations: self.stats.invalidations.load(Ordering::Relaxed),
            size: self.cache.len(),
            capacity: self.config.max_cache_size,
            hit_rate: self.stats.hit_rate(),
        }
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &CachingConfig {
        &self.config
    }

    /// Replaces the configuration.
    ///
    /// A lowered capacity is enforced on the next insert of a new key.
    pub fn update_config(&mut self, config: CachingConfig) {
        self.config = config;
    }
}
impl Default for QueryPlanCache {
fn default() -> Self {
Self::new()
}
}
/// Serializable, point-in-time copy of a [`QueryPlanCache`]'s counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Lookups that returned a plan.
    pub hits: u64,
    /// Lookups that returned nothing while caching was enabled.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries dropped or flagged as stale (TTL, statistics drift, explicit
    /// invalidation, `clear`).
    pub invalidations: u64,
    /// Number of cached entries when the snapshot was taken.
    pub size: usize,
    /// Configured `max_cache_size`.
    pub capacity: usize,
    /// `hits / (hits + misses)`, or `0.0` before any lookup.
    pub hit_rate: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The trivial plan every test caches; the cache never inspects its contents.
    fn empty_bgp() -> Algebra {
        Algebra::Bgp(vec![])
    }

    #[test]
    fn test_query_plan_cache_basic() {
        let cache = QueryPlanCache::new();
        let stats = Statistics::new();
        let query = "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10";

        // A cold lookup misses; after inserting, the same lookup hits.
        assert!(cache.get(query, vec![], &stats).is_none());
        cache.insert(query, vec![], empty_bgp(), 100.0, &stats);
        assert!(cache.get(query, vec![], &stats).is_some());

        let summary = cache.statistics();
        assert_eq!(summary.hits, 1);
        assert_eq!(summary.misses, 1);
    }

    #[test]
    fn test_cache_normalization() {
        let stats = Statistics::new();
        let alice = QuerySignature::new(
            "SELECT ?s WHERE { ?s <http://example.org/p> \"Alice\" }",
            vec![],
            &stats,
        );
        let bob = QuerySignature::new(
            "SELECT ?s WHERE { ?s <http://example.org/p> \"Bob\" }",
            vec![],
            &stats,
        );
        // Queries that differ only in a string literal normalize to the same text.
        assert_eq!(alice.normalized_query, bob.normalized_query);
    }

    #[test]
    #[ignore = "inherently slow: requires wall-clock TTL expiry (use nextest --ignored to run)"]
    fn test_cache_ttl() {
        let cache = QueryPlanCache::with_config(CachingConfig {
            ttl_seconds: 1,
            ..Default::default()
        });
        let stats = Statistics::new();
        let query = "SELECT ?s WHERE { ?s ?p ?o }";

        cache.insert(query, vec![], empty_bgp(), 100.0, &stats);
        assert!(cache.get(query, vec![], &stats).is_some());

        // Sleep comfortably past the one-second TTL.
        std::thread::sleep(Duration::from_secs(2));
        assert!(cache.get(query, vec![], &stats).is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let cache = QueryPlanCache::with_config(CachingConfig {
            max_cache_size: 2,
            ..Default::default()
        });
        let stats = Statistics::new();

        // Three distinct queries in a two-slot cache force exactly one eviction.
        for query in ["query1", "query2", "query3"] {
            cache.insert(query, vec![], empty_bgp(), 100.0, &stats);
        }

        assert_eq!(cache.cache.len(), 2);
        assert_eq!(cache.statistics().evictions, 1);
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryPlanCache::new();
        let stats = Statistics::new();
        (0..10).for_each(|i| {
            let query = format!("SELECT ?s ?var{} WHERE {{ ?s ?p{} ?o{} }}", i, i, i);
            cache.insert(&query, vec![], empty_bgp(), 100.0, &stats);
        });

        let populated = cache.cache.len();
        assert!(populated > 0, "Cache should have entries");

        cache.clear();
        assert_eq!(cache.cache.len(), 0);
        // Every entry dropped by `clear` is reported as one invalidation.
        assert_eq!(cache.statistics().invalidations, populated as u64);
    }

    #[test]
    fn test_statistics_snapshot() {
        let stats = Statistics::new();
        let snapshot = StatisticsSnapshot::from_statistics(&stats);
        // A snapshot compared against its own source statistics shows no drift.
        assert!(!snapshot.has_changed_significantly(&stats, 0.2));
    }

    #[test]
    fn test_cache_disabled() {
        let cache = QueryPlanCache::with_config(CachingConfig {
            enabled: false,
            ..Default::default()
        });
        let stats = Statistics::new();

        cache.insert("query", vec![], empty_bgp(), 100.0, &stats);
        assert!(cache.get("query", vec![], &stats).is_none());
    }

    #[test]
    fn test_hit_rate_calculation() {
        let cache = QueryPlanCache::new();
        let stats = Statistics::new();
        assert_eq!(cache.statistics().hit_rate, 0.0);

        let cached_query = "SELECT ?s WHERE { ?s ?p ?o }";
        cache.insert(cached_query, vec![], empty_bgp(), 100.0, &stats);

        // One hit on the cached query and one miss on a structurally different query.
        let _ = cache.get(cached_query, vec![], &stats);
        let _ = cache.get("SELECT ?x WHERE { ?x ?y ?z }", vec![], &stats);

        let summary = cache.statistics();
        assert_eq!(summary.hits, 1);
        assert_eq!(summary.misses, 1);
        assert!((summary.hit_rate - 0.5).abs() < 0.01);
    }
}