use crate::{PredicateInfo, SymbolTable};
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
/// Identifies one cached query result.
///
/// Each variant carries the parameters of its query, so distinct queries map to distinct
/// cache slots. The type derives `Eq` + `Hash` so it can be used as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CacheKey {
    /// NOTE(review): not used within this file; presumably a single-predicate lookup by name.
    PredicateByName(String),
    /// Predicates whose argument list has exactly this many entries.
    PredicatesByArity(usize),
    /// Predicates that list the named domain among their argument domains.
    PredicatesByDomain(String),
    /// NOTE(review): not used within this file; presumably an exact argument-domain signature — confirm.
    PredicatesBySignature(Vec<String>),
    /// NOTE(review): not used within this file; the pattern syntax is not defined here.
    PredicatesByPattern(String),
    /// Number of predicate-argument slots plus variables that reference the named domain.
    DomainUsageCount(String),
    /// The sorted list of all domain names.
    AllDomainNames,
    /// NOTE(review): not used within this file; presumably the list of all predicate names.
    AllPredicateNames,
    /// Free-form key for caller-defined queries.
    Custom(String),
}
/// A cached value together with the bookkeeping used for TTL expiry and access tracking.
#[derive(Debug, Clone)]
pub struct CachedResult<T> {
    /// The cached payload.
    pub value: T,
    /// Creation time; the TTL is measured from this instant.
    pub created_at: Instant,
    /// Time of creation or of the most recent [`CachedResult::update_access`] call.
    pub last_accessed: Instant,
    /// Number of accesses, counting creation as the first one.
    pub access_count: u64,
    /// Lifetime of the entry; `None` means the entry never expires.
    pub ttl: Option<Duration>,
}

impl<T> CachedResult<T> {
    /// Wraps `value` in a fresh entry stamped with the current time.
    ///
    /// The access count starts at 1 because creation counts as an access.
    pub fn new(value: T, ttl: Option<Duration>) -> Self {
        let stamp = Instant::now();
        Self {
            value,
            created_at: stamp,
            last_accessed: stamp,
            access_count: 1,
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has passed since creation.
    ///
    /// Entries without a TTL never expire.
    pub fn is_expired(&self) -> bool {
        match self.ttl {
            None => false,
            Some(limit) => self.age() > limit,
        }
    }

    /// Records an access: refreshes `last_accessed` and bumps `access_count`.
    pub fn update_access(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }

    /// Time elapsed since the entry was created.
    pub fn age(&self) -> Duration {
        Instant::now().duration_since(self.created_at)
    }
}
/// Tuning knobs for a `QueryCache`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Entry count at which inserting triggers an eviction attempt.
    pub max_entries: usize,
    /// TTL applied by `QueryCache::insert`; `None` means entries never expire.
    pub default_ttl: Option<Duration>,
    /// Whether entries are tracked in recency order for least-recently-used eviction.
    pub enable_lru: bool,
    /// Whether hit/miss/eviction/expiration/invalidation counters are maintained.
    pub enable_stats: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, 5-minute TTL, LRU and statistics enabled.
    fn default() -> Self {
        Self::preset(1000, Some(Duration::from_secs(300)))
    }
}

impl CacheConfig {
    /// Common shape of every preset: LRU and statistics are on; only size and TTL vary.
    fn preset(max_entries: usize, default_ttl: Option<Duration>) -> Self {
        Self {
            max_entries,
            default_ttl,
            enable_lru: true,
            enable_stats: true,
        }
    }

    /// 100 entries with a 1-minute TTL.
    pub fn small() -> Self {
        Self::preset(100, Some(Duration::from_secs(60)))
    }

    /// 10 000 entries with a 10-minute TTL.
    pub fn large() -> Self {
        Self::preset(10_000, Some(Duration::from_secs(600)))
    }

    /// 1000 entries that never expire.
    pub fn no_ttl() -> Self {
        Self::preset(1000, None)
    }
}
/// Counters describing how a `QueryCache` has been used.
#[derive(Debug, Clone, Default)]
pub struct QueryCacheStats {
    /// Lookups served from the cache.
    pub hits: u64,
    /// Lookups that found no live entry, including lookups of expired entries.
    pub misses: u64,
    /// Entries removed to make room for new ones.
    pub evictions: u64,
    /// Entries removed because their TTL elapsed.
    pub expirations: u64,
    /// Entries removed by explicit invalidation.
    pub invalidations: u64,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, or `0.0` if there have been no lookups.
    pub fn hit_rate(&self) -> f64 {
        self.fraction_of_accesses(self.hits)
    }

    /// Fraction of lookups that were misses, or `0.0` if there have been no lookups.
    ///
    /// An untouched cache has no misses, so it reports 0% rather than 100%.
    pub fn miss_rate(&self) -> f64 {
        self.fraction_of_accesses(self.misses)
    }

    /// Total number of lookups (hits plus misses), saturating at `u64::MAX`.
    pub fn total_accesses(&self) -> u64 {
        self.hits.saturating_add(self.misses)
    }

    /// `part / total_accesses()`, defined as `0.0` when there were no accesses.
    fn fraction_of_accesses(&self, part: u64) -> f64 {
        match self.total_accesses() {
            0 => 0.0,
            total => part as f64 / total as f64,
        }
    }
}
/// A bounded key/value cache with optional per-entry TTL.
///
/// `lru_queue` always holds exactly the keys stored in `cache`, each once; its front is
/// the next eviction candidate. When `config.enable_lru` is `true`, a hit moves the key to
/// the back, giving least-recently-used eviction. Otherwise the queue keeps insertion
/// order, giving first-in-first-out eviction. Either way `max_entries` is enforced.
pub struct QueryCache<T> {
    /// Stored entries.
    cache: HashMap<CacheKey, CachedResult<T>>,
    /// Eviction order (front = evicted first); mirrors the key set of `cache`.
    lru_queue: VecDeque<CacheKey>,
    config: CacheConfig,
    stats: QueryCacheStats,
}

impl<T: Clone> QueryCache<T> {
    /// Creates an empty cache using `CacheConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(CacheConfig::default())
    }

    /// Creates an empty cache governed by `config`.
    pub fn with_config(config: CacheConfig) -> Self {
        Self {
            cache: HashMap::new(),
            lru_queue: VecDeque::new(),
            config,
            stats: QueryCacheStats::default(),
        }
    }

    /// Returns a clone of the value stored under `key`, or `None` on a miss.
    ///
    /// An entry whose TTL has elapsed is removed and reported as a miss; it is also counted
    /// as an expiration. A hit updates the entry's access metadata. With LRU enabled, a hit
    /// also marks the key as most recently used.
    pub fn get(&mut self, key: &CacheKey) -> Option<T> {
        let entry = match self.cache.get_mut(key) {
            Some(entry) => entry,
            None => {
                if self.config.enable_stats {
                    self.stats.misses += 1;
                }
                return None;
            }
        };
        if entry.is_expired() {
            self.cache.remove(key);
            // Keep the order queue in sync. A stale key left here would later make
            // `evict_one` "evict" nothing and let the cache grow past `max_entries`.
            self.remove_from_queue(key);
            if self.config.enable_stats {
                self.stats.expirations += 1;
                self.stats.misses += 1;
            }
            return None;
        }
        entry.update_access();
        let value = entry.value.clone();
        if self.config.enable_stats {
            self.stats.hits += 1;
        }
        if self.config.enable_lru {
            self.update_lru(key);
        }
        Some(value)
    }

    /// Inserts `value` under `key` using the configured default TTL.
    pub fn insert(&mut self, key: CacheKey, value: T) {
        self.insert_with_ttl(key, value, self.config.default_ttl);
    }

    /// Inserts `value` under `key` with an explicit TTL (`None` = never expires).
    ///
    /// Replacing an existing key never evicts anything; the key simply moves to the back of
    /// the eviction order. Adding a new key to a full cache evicts from the front of the
    /// order first. A cache with `max_entries == 0` stores nothing.
    pub fn insert_with_ttl(&mut self, key: CacheKey, value: T, ttl: Option<Duration>) {
        if self.config.max_entries == 0 {
            return;
        }
        if self.cache.contains_key(&key) {
            // Same entry count after replacement; drop the old position so the key is
            // tracked only once.
            self.remove_from_queue(&key);
        } else {
            // `evict_one` returns false only if the queue is exhausted, so this terminates.
            while self.cache.len() >= self.config.max_entries && self.evict_one() {}
        }
        self.cache.insert(key.clone(), CachedResult::new(value, ttl));
        self.lru_queue.push_back(key);
    }

    /// Removes the entry for `key`. Returns `true` if an entry was present.
    pub fn invalidate(&mut self, key: &CacheKey) -> bool {
        if self.cache.remove(key).is_none() {
            return false;
        }
        self.remove_from_queue(key);
        if self.config.enable_stats {
            self.stats.invalidations += 1;
        }
        true
    }

    /// Removes every entry. Statistics are left untouched.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.lru_queue.clear();
    }

    /// Removes all expired entries and returns how many were removed.
    pub fn cleanup_expired(&mut self) -> usize {
        let before = self.cache.len();
        self.cache.retain(|_, entry| !entry.is_expired());
        let removed = before - self.cache.len();
        if removed > 0 {
            // One pass over the queue instead of one `retain` per removed key.
            let cache = &self.cache;
            self.lru_queue.retain(|key| cache.contains_key(key));
        }
        if self.config.enable_stats {
            self.stats.expirations += removed as u64;
        }
        removed
    }

    /// Usage counters accumulated so far (only updated when `enable_stats` is set).
    pub fn stats(&self) -> &QueryCacheStats {
        &self.stats
    }

    /// Number of stored entries, including expired ones not yet removed.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// `true` if no entries are stored (expired-but-unremoved entries count as stored).
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// The configuration this cache was built with.
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }

    /// Moves `key` to the back of the eviction order (most recently used).
    fn update_lru(&mut self, key: &CacheKey) {
        self.remove_from_queue(key);
        self.lru_queue.push_back(key.clone());
    }

    /// Drops `key` from the eviction order; keys appear at most once, so the first match
    /// is the only one.
    fn remove_from_queue(&mut self, key: &CacheKey) {
        if let Some(pos) = self.lru_queue.iter().position(|k| k == key) {
            self.lru_queue.remove(pos);
        }
    }

    /// Evicts the entry at the front of the eviction order.
    ///
    /// Returns `true` if an entry was actually removed. Queue keys with no backing entry
    /// are skipped and not counted as evictions.
    fn evict_one(&mut self) -> bool {
        while let Some(key) = self.lru_queue.pop_front() {
            if self.cache.remove(&key).is_some() {
                if self.config.enable_stats {
                    self.stats.evictions += 1;
                }
                return true;
            }
        }
        false
    }
}

impl<T: Clone> Default for QueryCache<T> {
    fn default() -> Self {
        Self::new()
    }
}
/// Memoizes common read-only queries against a `SymbolTable`.
///
/// Results are partitioned into three `QueryCache`s by result type. Cache keys encode only
/// the query parameters (arity, domain name, ...), not the table itself. An instance should
/// therefore be used with a single table, and callers must invalidate after mutating it.
pub struct SymbolTableCache {
    /// Predicate-list results (by arity and by domain).
    predicate_cache: QueryCache<Vec<PredicateInfo>>,
    /// Domain-name listings.
    domain_cache: QueryCache<Vec<String>>,
    /// Scalar results (per-domain usage counts).
    scalar_cache: QueryCache<usize>,
}

impl SymbolTableCache {
    /// Creates a cache whose three partitions use `CacheConfig::default()`.
    pub fn new() -> Self {
        Self {
            predicate_cache: QueryCache::new(),
            domain_cache: QueryCache::new(),
            scalar_cache: QueryCache::new(),
        }
    }

    /// Creates a cache whose three partitions each use a copy of `config`.
    pub fn with_config(config: CacheConfig) -> Self {
        Self {
            predicate_cache: QueryCache::with_config(config.clone()),
            domain_cache: QueryCache::with_config(config.clone()),
            scalar_cache: QueryCache::with_config(config),
        }
    }

    /// Returns every predicate in `table` whose argument list has exactly `arity` entries.
    ///
    /// Result order follows `table.predicates.values()` at the time it was computed.
    pub fn get_predicates_by_arity(
        &mut self,
        table: &SymbolTable,
        arity: usize,
    ) -> Vec<PredicateInfo> {
        let key = CacheKey::PredicatesByArity(arity);
        if let Some(result) = self.predicate_cache.get(&key) {
            return result;
        }
        let result: Vec<PredicateInfo> = table
            .predicates
            .values()
            .filter(|p| p.arg_domains.len() == arity)
            .cloned()
            .collect();
        self.predicate_cache.insert(key, result.clone());
        result
    }

    /// Returns every predicate in `table` that has `domain` among its argument domains.
    pub fn get_predicates_by_domain(
        &mut self,
        table: &SymbolTable,
        domain: &str,
    ) -> Vec<PredicateInfo> {
        let key = CacheKey::PredicatesByDomain(domain.to_string());
        if let Some(result) = self.predicate_cache.get(&key) {
            return result;
        }
        // Compare as `&str` so no `String` is allocated per predicate.
        let result: Vec<PredicateInfo> = table
            .predicates
            .values()
            .filter(|p| p.arg_domains.iter().any(|d| d.as_str() == domain))
            .cloned()
            .collect();
        self.predicate_cache.insert(key, result.clone());
        result
    }

    /// Returns the names of all domains in `table`, sorted ascending.
    pub fn get_domain_names(&mut self, table: &SymbolTable) -> Vec<String> {
        let key = CacheKey::AllDomainNames;
        if let Some(result) = self.domain_cache.get(&key) {
            return result;
        }
        let mut result: Vec<String> = table.domains.keys().cloned().collect();
        result.sort();
        self.domain_cache.insert(key, result.clone());
        result
    }

    /// Counts references to `domain`.
    ///
    /// The count includes every predicate argument slot whose domain is `domain`, plus every
    /// value in `table.variables` equal to `domain`.
    pub fn get_domain_usage_count(&mut self, table: &SymbolTable, domain: &str) -> usize {
        let key = CacheKey::DomainUsageCount(domain.to_string());
        if let Some(result) = self.scalar_cache.get(&key) {
            return result;
        }
        let mut count = 0;
        for predicate in table.predicates.values() {
            count += predicate
                .arg_domains
                .iter()
                .filter(|d| d.as_str() == domain)
                .count();
        }
        for var_domain in table.variables.values() {
            if var_domain == domain {
                count += 1;
            }
        }
        self.scalar_cache.insert(key, count);
        count
    }

    /// Drops every cached result. Statistics are retained.
    pub fn invalidate_all(&mut self) {
        self.predicate_cache.clear();
        self.domain_cache.clear();
        self.scalar_cache.clear();
    }

    /// Drops results keyed by `domain` and the cached domain-name listing.
    ///
    /// Results keyed by arity are not touched.
    pub fn invalidate_domain(&mut self, domain: &str) {
        self.predicate_cache
            .invalidate(&CacheKey::PredicatesByDomain(domain.to_string()));
        self.scalar_cache
            .invalidate(&CacheKey::DomainUsageCount(domain.to_string()));
        self.domain_cache.invalidate(&CacheKey::AllDomainNames);
    }

    /// Drops every result derived from the table's predicates.
    ///
    /// Domain usage counts are computed from predicate argument domains, so they are
    /// dropped as well. Otherwise they would go stale after a predicate change.
    pub fn invalidate_predicates(&mut self) {
        self.predicate_cache.clear();
        // `scalar_cache` only ever holds `DomainUsageCount` results (see above).
        self.scalar_cache.clear();
    }

    /// Sums the statistics of all three partitions.
    pub fn combined_stats(&self) -> QueryCacheStats {
        let pred_stats = self.predicate_cache.stats();
        let domain_stats = self.domain_cache.stats();
        let scalar_stats = self.scalar_cache.stats();
        QueryCacheStats {
            hits: pred_stats.hits + domain_stats.hits + scalar_stats.hits,
            misses: pred_stats.misses + domain_stats.misses + scalar_stats.misses,
            evictions: pred_stats.evictions + domain_stats.evictions + scalar_stats.evictions,
            expirations: pred_stats.expirations
                + domain_stats.expirations
                + scalar_stats.expirations,
            invalidations: pred_stats.invalidations
                + domain_stats.invalidations
                + scalar_stats.invalidations,
        }
    }

    /// Removes expired entries from all partitions and returns the total removed.
    pub fn cleanup_expired(&mut self) -> usize {
        self.predicate_cache.cleanup_expired()
            + self.domain_cache.cleanup_expired()
            + self.scalar_cache.cleanup_expired()
    }
}

impl Default for SymbolTableCache {
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for `QueryCache`, the `CacheConfig` presets, `QueryCacheStats` and
// `SymbolTableCache`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;

    // A stored value is returned on `get` and counted as one hit, zero misses.
    #[test]
    fn test_cache_basic_operations() {
        let mut cache: QueryCache<String> = QueryCache::new();
        let key = CacheKey::Custom("test".to_string());
        cache.insert(key.clone(), "value".to_string());
        assert_eq!(cache.get(&key), Some("value".to_string()));
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 0);
    }

    // Looking up an absent key yields `None` and records a miss.
    #[test]
    fn test_cache_miss() {
        let mut cache: QueryCache<String> = QueryCache::new();
        let key = CacheKey::Custom("nonexistent".to_string());
        assert_eq!(cache.get(&key), None);
        assert_eq!(cache.stats().misses, 1);
    }

    // `invalidate` reports that an entry existed, and the key is gone afterwards.
    #[test]
    fn test_cache_invalidation() {
        let mut cache: QueryCache<String> = QueryCache::new();
        let key = CacheKey::Custom("test".to_string());
        cache.insert(key.clone(), "value".to_string());
        assert!(cache.invalidate(&key));
        assert_eq!(cache.get(&key), None);
    }

    // An entry read after its 10 ms TTL is treated as absent and counted as an expiration.
    // Sleeps 20 ms to let the TTL elapse.
    #[test]
    fn test_cache_expiration() {
        let config = CacheConfig {
            default_ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        };
        let mut cache: QueryCache<String> = QueryCache::with_config(config);
        let key = CacheKey::Custom("test".to_string());
        cache.insert(key.clone(), "value".to_string());
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(cache.get(&key), None);
        assert_eq!(cache.stats().expirations, 1);
    }

    // A third distinct key in a 2-entry cache evicts exactly one entry.
    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            enable_lru: true,
            ..Default::default()
        };
        let mut cache: QueryCache<String> = QueryCache::with_config(config);
        cache.insert(CacheKey::Custom("key1".to_string()), "value1".to_string());
        cache.insert(CacheKey::Custom("key2".to_string()), "value2".to_string());
        cache.insert(CacheKey::Custom("key3".to_string()), "value3".to_string());
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.stats().evictions, 1);
    }

    // The first arity query is computed (a miss); the identical second query is served
    // from the predicate partition (a hit).
    #[test]
    fn test_symbol_table_cache() {
        let mut table = SymbolTable::new();
        table
            .add_domain(DomainInfo::new("Person", 100))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new(
                "knows",
                vec!["Person".to_string(), "Person".to_string()],
            ))
            .expect("unwrap");
        table
            .add_predicate(PredicateInfo::new("age", vec!["Person".to_string()]))
            .expect("unwrap");
        let mut cache = SymbolTableCache::new();
        let predicates = cache.get_predicates_by_arity(&table, 2);
        assert_eq!(predicates.len(), 1);
        assert_eq!(cache.predicate_cache.stats().misses, 1);
        let predicates = cache.get_predicates_by_arity(&table, 2);
        assert_eq!(predicates.len(), 1);
        assert_eq!(cache.predicate_cache.stats().hits, 1);
    }

    // Preset constructors expose the expected capacity / TTL settings.
    #[test]
    fn test_cache_config_presets() {
        let small = CacheConfig::small();
        assert_eq!(small.max_entries, 100);
        let large = CacheConfig::large();
        assert_eq!(large.max_entries, 10000);
        let no_ttl = CacheConfig::no_ttl();
        assert!(no_ttl.default_ttl.is_none());
    }

    // One hit (key1) plus one miss (key2) gives 50% hit and miss rates over 2 accesses.
    #[test]
    fn test_cache_stats() {
        let mut cache: QueryCache<String> = QueryCache::new();
        let key1 = CacheKey::Custom("key1".to_string());
        let key2 = CacheKey::Custom("key2".to_string());
        cache.insert(key1.clone(), "value1".to_string());
        // First lookup hits, second misses; the returned values are intentionally ignored.
        cache.get(&key1); cache.get(&key2);
        let stats = cache.stats();
        assert_eq!(stats.hit_rate(), 0.5);
        assert_eq!(stats.miss_rate(), 0.5);
        assert_eq!(stats.total_accesses(), 2);
    }

    // `cleanup_expired` removes both entries once their 10 ms TTL has passed.
    #[test]
    fn test_cleanup_expired() {
        let config = CacheConfig {
            default_ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        };
        let mut cache: QueryCache<String> = QueryCache::with_config(config);
        cache.insert(CacheKey::Custom("key1".to_string()), "value1".to_string());
        cache.insert(CacheKey::Custom("key2".to_string()), "value2".to_string());
        std::thread::sleep(Duration::from_millis(20));
        let removed = cache.cleanup_expired();
        assert_eq!(removed, 2);
        assert!(cache.is_empty());
    }
}