use crate::ir::{Predicate, Term};
use crate::reasoning::Substitution;
use ipfrs_core::Cid;
use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Hashable cache key for a query: the predicate name plus one [`GroundArg`]
/// per argument, in argument order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryKey {
/// Name of the queried predicate.
pub predicate_name: String,
/// Per-argument signature; see [`GroundArg`] for how terms are encoded.
pub ground_args: Vec<GroundArg>,
}
/// One argument position of a [`QueryKey`], as produced by `QueryKey::from_predicate`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GroundArg {
/// A string constant.
String(String),
/// An integer constant; boolean constants are also encoded here as `1` / `0`.
Int(i64),
/// Raw IEEE-754 bits (`f64::to_bits`) of a float constant, so the key can derive
/// `Eq`/`Hash`. A float literal that fails to parse is stored as `0`.
Float(u64),
/// Any non-constant term (variable, function term, or ref), without further detail.
Variable,
}
impl QueryKey {
pub fn from_predicate(pred: &Predicate) -> Self {
let ground_args = pred
.args
.iter()
.map(|arg| match arg {
Term::Const(c) => match c {
crate::ir::Constant::String(s) => GroundArg::String(s.clone()),
crate::ir::Constant::Int(i) => GroundArg::Int(*i),
crate::ir::Constant::Float(f) => {
let hash = f.parse::<f64>().map(|v| v.to_bits()).unwrap_or(0);
GroundArg::Float(hash)
}
crate::ir::Constant::Bool(b) => GroundArg::Int(if *b { 1 } else { 0 }),
},
Term::Var(_) | Term::Fun(_, _) | Term::Ref(_) => GroundArg::Variable,
})
.collect();
Self {
predicate_name: pred.name.clone(),
ground_args,
}
}
}
/// Solutions stored by a `QueryCache`, stamped with their insertion time and an
/// optional time-to-live.
#[derive(Debug, Clone)]
pub struct CachedResult {
    pub solutions: Vec<Substitution>,
    pub cached_at: Instant,
    pub ttl: Option<Duration>,
}

impl CachedResult {
    /// Wraps `solutions`, recording the current instant as the insertion time.
    pub fn new(solutions: Vec<Substitution>, ttl: Option<Duration>) -> Self {
        let cached_at = Instant::now();
        Self {
            solutions,
            cached_at,
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since insertion.
    /// An entry without a TTL never expires.
    #[inline]
    pub fn is_expired(&self) -> bool {
        matches!(self.ttl, Some(ttl) if self.cached_at.elapsed() > ttl)
    }

    /// Time left before expiry (saturating at zero), or `None` when there is no TTL.
    #[inline]
    pub fn remaining_ttl(&self) -> Option<Duration> {
        let ttl = self.ttl?;
        Some(ttl.saturating_sub(self.cached_at.elapsed()))
    }
}
/// Atomic hit/miss/eviction/expiration counters, shareable across threads.
///
/// Every counter is updated and read with `Ordering::Relaxed`. Each is an
/// independent tally, so reading several counters does not give an atomic view
/// of all of them together.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Number of lookups counted as hits.
    pub hits: AtomicU64,
    /// Number of lookups counted as misses.
    pub misses: AtomicU64,
    /// Number of entries dropped to make room for new ones.
    pub evictions: AtomicU64,
    /// Number of entries dropped because their TTL elapsed.
    pub expirations: AtomicU64,
}

impl CacheStats {
    /// Creates a set of counters, all starting at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one to `counter`.
    #[inline]
    fn bump(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one cache hit.
    #[inline]
    pub fn record_hit(&self) {
        Self::bump(&self.hits);
    }

    /// Counts one cache miss.
    #[inline]
    pub fn record_miss(&self) {
        Self::bump(&self.misses);
    }

    /// Counts one capacity eviction.
    #[inline]
    pub fn record_eviction(&self) {
        Self::bump(&self.evictions);
    }

    /// Counts one TTL expiration.
    #[inline]
    pub fn record_expiration(&self) {
        Self::bump(&self.expirations);
    }

    /// Fraction of lookups that were hits; `0.0` when no lookups were recorded.
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits.load(Ordering::Relaxed);
        let lookups = hits + self.misses.load(Ordering::Relaxed);
        if lookups == 0 {
            return 0.0;
        }
        hits as f64 / lookups as f64
    }

    /// Copies the current counter values into a plain, serializable struct.
    pub fn snapshot(&self) -> CacheStatsSnapshot {
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        let evictions = self.evictions.load(Ordering::Relaxed);
        let expirations = self.expirations.load(Ordering::Relaxed);
        CacheStatsSnapshot {
            hits,
            misses,
            evictions,
            expirations,
        }
    }
}
/// Plain, serializable copy of the counters held by a `CacheStats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatsSnapshot {
    /// Number of lookups counted as hits.
    pub hits: u64,
    /// Number of lookups counted as misses.
    pub misses: u64,
    /// Number of entries dropped to make room for new ones.
    pub evictions: u64,
    /// Number of entries dropped because their TTL elapsed.
    pub expirations: u64,
}

impl CacheStatsSnapshot {
    /// Fraction of lookups that were hits; `0.0` when no lookups were recorded.
    #[inline]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
/// Thread-safe, bounded LRU cache from [`QueryKey`] to previously computed solutions.
///
/// Entries can carry a time-to-live. Expired entries are removed lazily when looked
/// up with [`QueryCache::get`], or in bulk with [`QueryCache::evict_expired`].
/// Hit, miss, eviction and expiration counts go to a shared [`CacheStats`].
pub struct QueryCache {
    cache: RwLock<LruCache<QueryKey, CachedResult>>,
    /// TTL applied by [`QueryCache::insert`]; `None` means those entries never expire.
    default_ttl: Option<Duration>,
    stats: Arc<CacheStats>,
}

impl QueryCache {
    /// Capacity substituted when a caller asks for `0`, which `LruCache` cannot represent.
    const FALLBACK_CAPACITY: usize = 100;

    /// Converts a requested capacity into the non-zero value `LruCache` requires.
    fn lru_capacity(capacity: usize) -> NonZeroUsize {
        let effective = if capacity == 0 {
            Self::FALLBACK_CAPACITY
        } else {
            capacity
        };
        NonZeroUsize::new(effective).expect("effective capacity is never zero")
    }

    /// Creates a cache holding at most `capacity` entries, with no default TTL.
    ///
    /// A `capacity` of `0` is replaced by 100.
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: RwLock::new(LruCache::new(Self::lru_capacity(capacity))),
            default_ttl: None,
            stats: Arc::new(CacheStats::new()),
        }
    }

    /// Like [`QueryCache::new`], but entries added with [`QueryCache::insert`]
    /// expire `ttl` after insertion.
    pub fn with_ttl(capacity: usize, ttl: Duration) -> Self {
        Self {
            default_ttl: Some(ttl),
            ..Self::new(capacity)
        }
    }

    /// Returns a clone of the solutions cached for `key`, if present and unexpired.
    ///
    /// A successful lookup marks the entry most-recently-used, so this takes the
    /// write lock. An expired entry is removed and counted as both an expiration
    /// and a miss.
    #[inline]
    pub fn get(&self, key: &QueryKey) -> Option<Vec<Substitution>> {
        let mut cache = self.cache.write();
        if let Some(result) = cache.get(key) {
            if result.is_expired() {
                self.stats.record_expiration();
                cache.pop(key);
                self.stats.record_miss();
                return None;
            }
            self.stats.record_hit();
            Some(result.solutions.clone())
        } else {
            self.stats.record_miss();
            None
        }
    }

    /// Caches `solutions` under `key` using the cache's default TTL.
    pub fn insert(&self, key: QueryKey, solutions: Vec<Substitution>) {
        self.put_entry(key, solutions, self.default_ttl);
    }

    /// Caches `solutions` under `key` with an explicit `ttl`, overriding the default.
    pub fn insert_with_ttl(&self, key: QueryKey, solutions: Vec<Substitution>, ttl: Duration) {
        self.put_entry(key, solutions, Some(ttl));
    }

    /// Shared insert path used by both public insert methods.
    ///
    /// An eviction is recorded only when the cache is full AND `key` is new.
    /// Overwriting an existing key replaces it in place and evicts nothing, so
    /// counting it would inflate the eviction statistic.
    fn put_entry(&self, key: QueryKey, solutions: Vec<Substitution>, ttl: Option<Duration>) {
        let mut cache = self.cache.write();
        // `contains` does not touch LRU order, unlike `get`.
        if cache.len() >= cache.cap().get() && !cache.contains(&key) {
            self.stats.record_eviction();
        }
        cache.put(key, CachedResult::new(solutions, ttl));
    }

    /// Removes the entry for `key`; returns whether one was present.
    pub fn invalidate(&self, key: &QueryKey) -> bool {
        self.cache.write().pop(key).is_some()
    }

    /// Removes every entry whose key names `predicate_name`. No statistics are recorded.
    pub fn invalidate_predicate(&self, predicate_name: &str) {
        let mut cache = self.cache.write();
        let doomed: Vec<QueryKey> = cache
            .iter()
            .filter(|(k, _)| k.predicate_name == predicate_name)
            .map(|(k, _)| k.clone())
            .collect();
        for key in &doomed {
            cache.pop(key);
        }
    }

    /// Removes all entries. Statistics are left untouched.
    pub fn clear(&self) {
        self.cache.write().clear();
    }

    /// Returns a shared handle to the live statistics. The counters keep updating
    /// after this call.
    #[inline]
    pub fn stats(&self) -> Arc<CacheStats> {
        Arc::clone(&self.stats)
    }

    /// Number of entries currently stored, including expired ones not yet removed.
    #[inline]
    pub fn len(&self) -> usize {
        self.cache.read().len()
    }

    /// Returns `true` when no entries are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.cache.read().is_empty()
    }

    /// Maximum number of entries (after the zero-capacity fallback was applied).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.cache.read().cap().get()
    }

    /// Removes every expired entry, counts each as an expiration, and returns how
    /// many were removed.
    pub fn evict_expired(&self) -> usize {
        let mut cache = self.cache.write();
        let expired: Vec<QueryKey> = cache
            .iter()
            .filter(|(_, result)| result.is_expired())
            .map(|(k, _)| k.clone())
            .collect();
        for key in &expired {
            cache.pop(key);
            self.stats.record_expiration();
        }
        expired.len()
    }
}

impl Default for QueryCache {
    /// A 1000-entry cache with no default TTL.
    fn default() -> Self {
        Self::new(1000)
    }
}
/// A fact obtained from a remote source, stamped with its fetch time and lifetime.
#[derive(Debug, Clone)]
pub struct RemoteFact {
    pub fact: Predicate,
    pub source: Option<Cid>,
    pub fetched_at: Instant,
    pub ttl: Duration,
}

impl RemoteFact {
    /// Wraps `fact`, recording the current instant as its fetch time.
    pub fn new(fact: Predicate, source: Option<Cid>, ttl: Duration) -> Self {
        let fetched_at = Instant::now();
        Self {
            fact,
            source,
            fetched_at,
            ttl,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since the fetch.
    #[inline]
    pub fn is_expired(&self) -> bool {
        let age = self.fetched_at.elapsed();
        age > self.ttl
    }
}
/// Identity of a single fact: its predicate name plus a 64-bit digest of its arguments.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FactKey {
    pub predicate_name: String,
    pub args_hash: u64,
}

impl FactKey {
    /// Builds a key by feeding each argument of `pred`, in order, into a fresh
    /// `DefaultHasher`.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is unspecified and may change
    /// between Rust releases, so `args_hash` should not be persisted or compared
    /// across builds. Distinct argument lists may also collide.
    pub fn from_predicate(pred: &Predicate) -> Self {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        pred.args.iter().for_each(|arg| arg.hash(&mut hasher));
        let args_hash = hasher.finish();
        Self {
            predicate_name: pred.name.clone(),
            args_hash,
        }
    }
}
pub struct RemoteFactCache {
facts: RwLock<HashMap<String, Vec<RemoteFact>>>,
max_per_predicate: usize,
default_ttl: Duration,
stats: Arc<CacheStats>,
}
impl RemoteFactCache {
pub fn new(max_per_predicate: usize, default_ttl: Duration) -> Self {
Self {
facts: RwLock::new(HashMap::new()),
max_per_predicate,
default_ttl,
stats: Arc::new(CacheStats::new()),
}
}
pub fn get_facts(&self, predicate_name: &str) -> Vec<Predicate> {
let facts = self.facts.read();
if let Some(remote_facts) = facts.get(predicate_name) {
let valid_facts: Vec<Predicate> = remote_facts
.iter()
.filter(|f| !f.is_expired())
.map(|f| f.fact.clone())
.collect();
if valid_facts.is_empty() {
self.stats.record_miss();
} else {
self.stats.record_hit();
}
valid_facts
} else {
self.stats.record_miss();
Vec::new()
}
}
pub fn add_fact(&self, fact: Predicate, source: Option<Cid>) {
self.add_fact_with_ttl(fact, source, self.default_ttl);
}
pub fn add_fact_with_ttl(&self, fact: Predicate, source: Option<Cid>, ttl: Duration) {
let mut facts = self.facts.write();
let name = fact.name.clone();
let remote_fact = RemoteFact::new(fact, source, ttl);
let entry = facts.entry(name).or_default();
entry.retain(|f| !f.is_expired());
if entry.len() >= self.max_per_predicate {
entry.sort_by_key(|f| f.fetched_at);
entry.remove(0);
self.stats.record_eviction();
}
entry.push(remote_fact);
}
pub fn add_facts(&self, facts: Vec<Predicate>, source: Option<Cid>) {
for fact in facts {
self.add_fact(fact, source);
}
}
pub fn invalidate_predicate(&self, predicate_name: &str) {
let mut facts = self.facts.write();
facts.remove(predicate_name);
}
pub fn clear(&self) {
let mut facts = self.facts.write();
facts.clear();
}
pub fn stats(&self) -> Arc<CacheStats> {
self.stats.clone()
}
pub fn evict_expired(&self) -> usize {
let mut facts = self.facts.write();
let mut count = 0;
for entry in facts.values_mut() {
let before = entry.len();
entry.retain(|f| !f.is_expired());
count += before - entry.len();
}
for _ in 0..count {
self.stats.record_expiration();
}
count
}
pub fn len(&self) -> usize {
let facts = self.facts.read();
facts.values().map(|v| v.len()).sum()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for RemoteFactCache {
fn default() -> Self {
Self::new(1000, Duration::from_secs(300))
}
}
/// Bundles the query-result cache and the remote-fact cache behind one handle.
pub struct CacheManager {
    pub query_cache: QueryCache,
    pub fact_cache: RemoteFactCache,
}

impl CacheManager {
    /// Creates a manager with two caches:
    /// - a 10000-entry query cache with no TTL;
    /// - a fact cache keeping up to 1000 facts per predicate for 300 seconds.
    pub fn new() -> Self {
        Self::with_config(10000, None, 1000, Duration::from_secs(300))
    }

    /// Creates a manager from explicit settings.
    ///
    /// - `query_capacity` / `query_ttl`: size and optional default TTL of the query cache.
    /// - `fact_capacity`: maximum facts kept *per predicate* in the fact cache.
    /// - `fact_ttl`: default lifetime of cached facts.
    pub fn with_config(
        query_capacity: usize,
        query_ttl: Option<Duration>,
        fact_capacity: usize,
        fact_ttl: Duration,
    ) -> Self {
        let query_cache = match query_ttl {
            Some(ttl) => QueryCache::with_ttl(query_capacity, ttl),
            None => QueryCache::new(query_capacity),
        };
        let fact_cache = RemoteFactCache::new(fact_capacity, fact_ttl);
        Self {
            query_cache,
            fact_cache,
        }
    }

    /// Purges expired entries from both caches.
    ///
    /// Returns `(queries_removed, facts_removed)`.
    pub fn evict_expired(&self) -> (usize, usize) {
        (
            self.query_cache.evict_expired(),
            self.fact_cache.evict_expired(),
        )
    }

    /// Empties both caches.
    pub fn clear_all(&self) {
        self.query_cache.clear();
        self.fact_cache.clear();
    }

    /// Collects counters and sizes from both caches. The four values are read one
    /// after another, not atomically together.
    pub fn stats(&self) -> CombinedCacheStats {
        let query_stats = self.query_cache.stats().snapshot();
        let fact_stats = self.fact_cache.stats().snapshot();
        let query_cache_size = self.query_cache.len();
        let fact_cache_size = self.fact_cache.len();
        CombinedCacheStats {
            query_stats,
            fact_stats,
            query_cache_size,
            fact_cache_size,
        }
    }
}

impl Default for CacheManager {
    /// Same as [`CacheManager::new`].
    fn default() -> Self {
        CacheManager::new()
    }
}
/// Statistics for both caches owned by a `CacheManager`, as returned by
/// `CacheManager::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CombinedCacheStats {
/// Counter snapshot of the query-result cache.
pub query_stats: CacheStatsSnapshot,
/// Counter snapshot of the remote-fact cache.
pub fact_stats: CacheStatsSnapshot,
/// Number of entries currently stored in the query cache.
pub query_cache_size: usize,
/// Total number of facts stored across all predicates in the fact cache.
pub fact_cache_size: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Constant;
    use std::thread::sleep;

    /// Shorthand for a query key on `name` with the given argument signature.
    fn query_key(name: &str, ground_args: Vec<GroundArg>) -> QueryKey {
        QueryKey {
            predicate_name: name.to_string(),
            ground_args,
        }
    }

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(100);
        let key = query_key("test", vec![GroundArg::String("value".to_string())]);
        cache.insert(key.clone(), vec![Substitution::new()]);
        let cached = cache
            .get(&key)
            .expect("freshly inserted entry should be present");
        assert_eq!(cached.len(), 1);
    }

    #[test]
    fn test_query_cache_ttl() {
        let cache = QueryCache::with_ttl(100, Duration::from_millis(50));
        let key = query_key("test", Vec::new());
        cache.insert(key.clone(), vec![Substitution::new()]);
        assert!(cache.get(&key).is_some());
        // Sleep well past the 50 ms TTL so the entry counts as expired.
        sleep(Duration::from_millis(100));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(100);
        let key = query_key("test", Vec::new());
        // The first lookup misses; the lookup after the insert hits.
        let _ = cache.get(&key);
        cache.insert(key.clone(), Vec::new());
        let _ = cache.get(&key);
        let stats = cache.stats().snapshot();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_remote_fact_cache() {
        let cache = RemoteFactCache::new(100, Duration::from_secs(60));
        let fact = Predicate::new(
            "test".to_string(),
            vec![Term::Const(Constant::String("value".to_string()))],
        );
        cache.add_fact(fact, None);
        let facts = cache.get_facts("test");
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].name, "test");
    }

    #[test]
    fn test_remote_fact_cache_ttl() {
        let cache = RemoteFactCache::new(100, Duration::from_millis(50));
        cache.add_fact(Predicate::new("test".to_string(), Vec::new()), None);
        assert_eq!(cache.get_facts("test").len(), 1);
        // Sleep well past the 50 ms TTL so the fact is filtered out.
        sleep(Duration::from_millis(100));
        assert!(cache.get_facts("test").is_empty());
    }

    #[test]
    fn test_cache_manager() {
        let manager = CacheManager::new();
        let key = query_key("test", Vec::new());
        manager.query_cache.insert(key.clone(), Vec::new());
        assert!(manager.query_cache.get(&key).is_some());

        manager
            .fact_cache
            .add_fact(Predicate::new("fact".to_string(), Vec::new()), None);
        assert_eq!(manager.fact_cache.get_facts("fact").len(), 1);

        let stats = manager.stats();
        assert!(stats.query_cache_size > 0);
        assert!(stats.fact_cache_size > 0);
    }
}