1use crate::ir::{Predicate, Term};
10use crate::reasoning::Substitution;
11use ipfrs_core::Cid;
12use lru::LruCache;
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::hash::{Hash, Hasher};
17use std::num::NonZeroUsize;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct QueryKey {
25 pub predicate_name: String,
27 pub ground_args: Vec<GroundArg>,
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33pub enum GroundArg {
34 String(String),
36 Int(i64),
38 Float(u64),
40 Variable,
42}
43
44impl QueryKey {
45 pub fn from_predicate(pred: &Predicate) -> Self {
47 let ground_args = pred
48 .args
49 .iter()
50 .map(|arg| match arg {
51 Term::Const(c) => match c {
52 crate::ir::Constant::String(s) => GroundArg::String(s.clone()),
53 crate::ir::Constant::Int(i) => GroundArg::Int(*i),
54 crate::ir::Constant::Float(f) => {
56 let hash = f.parse::<f64>().map(|v| v.to_bits()).unwrap_or(0);
57 GroundArg::Float(hash)
58 }
59 crate::ir::Constant::Bool(b) => GroundArg::Int(if *b { 1 } else { 0 }),
60 },
61 Term::Var(_) | Term::Fun(_, _) | Term::Ref(_) => GroundArg::Variable,
62 })
63 .collect();
64
65 Self {
66 predicate_name: pred.name.clone(),
67 ground_args,
68 }
69 }
70}
71
72#[derive(Debug, Clone)]
74pub struct CachedResult {
75 pub solutions: Vec<Substitution>,
77 pub cached_at: Instant,
79 pub ttl: Option<Duration>,
81}
82
83impl CachedResult {
84 pub fn new(solutions: Vec<Substitution>, ttl: Option<Duration>) -> Self {
86 Self {
87 solutions,
88 cached_at: Instant::now(),
89 ttl,
90 }
91 }
92
93 #[inline]
95 pub fn is_expired(&self) -> bool {
96 if let Some(ttl) = self.ttl {
97 self.cached_at.elapsed() > ttl
98 } else {
99 false
100 }
101 }
102
103 #[inline]
105 pub fn remaining_ttl(&self) -> Option<Duration> {
106 self.ttl
107 .map(|ttl| ttl.saturating_sub(self.cached_at.elapsed()))
108 }
109}
110
111#[derive(Debug, Default)]
113pub struct CacheStats {
114 pub hits: AtomicU64,
116 pub misses: AtomicU64,
118 pub evictions: AtomicU64,
120 pub expirations: AtomicU64,
122}
123
124impl CacheStats {
125 pub fn new() -> Self {
127 Self::default()
128 }
129
130 #[inline]
132 pub fn record_hit(&self) {
133 self.hits.fetch_add(1, Ordering::Relaxed);
134 }
135
136 #[inline]
138 pub fn record_miss(&self) {
139 self.misses.fetch_add(1, Ordering::Relaxed);
140 }
141
142 #[inline]
144 pub fn record_eviction(&self) {
145 self.evictions.fetch_add(1, Ordering::Relaxed);
146 }
147
148 #[inline]
150 pub fn record_expiration(&self) {
151 self.expirations.fetch_add(1, Ordering::Relaxed);
152 }
153
154 pub fn hit_rate(&self) -> f64 {
156 let hits = self.hits.load(Ordering::Relaxed);
157 let misses = self.misses.load(Ordering::Relaxed);
158 let total = hits + misses;
159 if total == 0 {
160 0.0
161 } else {
162 hits as f64 / total as f64
163 }
164 }
165
166 pub fn snapshot(&self) -> CacheStatsSnapshot {
168 CacheStatsSnapshot {
169 hits: self.hits.load(Ordering::Relaxed),
170 misses: self.misses.load(Ordering::Relaxed),
171 evictions: self.evictions.load(Ordering::Relaxed),
172 expirations: self.expirations.load(Ordering::Relaxed),
173 }
174 }
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct CacheStatsSnapshot {
180 pub hits: u64,
182 pub misses: u64,
184 pub evictions: u64,
186 pub expirations: u64,
188}
189
190impl CacheStatsSnapshot {
191 #[inline]
193 pub fn hit_rate(&self) -> f64 {
194 let total = self.hits + self.misses;
195 if total == 0 {
196 0.0
197 } else {
198 self.hits as f64 / total as f64
199 }
200 }
201}
202
203pub struct QueryCache {
205 cache: RwLock<LruCache<QueryKey, CachedResult>>,
207 default_ttl: Option<Duration>,
209 stats: Arc<CacheStats>,
211}
212
213impl QueryCache {
214 pub fn new(capacity: usize) -> Self {
216 Self {
217 cache: RwLock::new(LruCache::new(
218 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(100).unwrap()),
219 )),
220 default_ttl: None,
221 stats: Arc::new(CacheStats::new()),
222 }
223 }
224
225 pub fn with_ttl(capacity: usize, ttl: Duration) -> Self {
227 Self {
228 cache: RwLock::new(LruCache::new(
229 NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(100).unwrap()),
230 )),
231 default_ttl: Some(ttl),
232 stats: Arc::new(CacheStats::new()),
233 }
234 }
235
236 #[inline]
238 pub fn get(&self, key: &QueryKey) -> Option<Vec<Substitution>> {
239 let mut cache = self.cache.write();
240
241 if let Some(result) = cache.get(key) {
242 if result.is_expired() {
243 self.stats.record_expiration();
244 cache.pop(key);
245 self.stats.record_miss();
246 return None;
247 }
248 self.stats.record_hit();
249 Some(result.solutions.clone())
250 } else {
251 self.stats.record_miss();
252 None
253 }
254 }
255
256 pub fn insert(&self, key: QueryKey, solutions: Vec<Substitution>) {
258 let mut cache = self.cache.write();
259
260 if cache.len() >= cache.cap().get() {
262 self.stats.record_eviction();
263 }
264
265 let result = CachedResult::new(solutions, self.default_ttl);
266 cache.put(key, result);
267 }
268
269 pub fn insert_with_ttl(&self, key: QueryKey, solutions: Vec<Substitution>, ttl: Duration) {
271 let mut cache = self.cache.write();
272
273 if cache.len() >= cache.cap().get() {
274 self.stats.record_eviction();
275 }
276
277 let result = CachedResult::new(solutions, Some(ttl));
278 cache.put(key, result);
279 }
280
281 pub fn invalidate(&self, key: &QueryKey) -> bool {
283 let mut cache = self.cache.write();
284 cache.pop(key).is_some()
285 }
286
287 pub fn invalidate_predicate(&self, predicate_name: &str) {
289 let mut cache = self.cache.write();
290 let keys_to_remove: Vec<QueryKey> = cache
291 .iter()
292 .filter(|(k, _)| k.predicate_name == predicate_name)
293 .map(|(k, _)| k.clone())
294 .collect();
295
296 for key in keys_to_remove {
297 cache.pop(&key);
298 }
299 }
300
301 pub fn clear(&self) {
303 let mut cache = self.cache.write();
304 cache.clear();
305 }
306
307 #[inline]
309 pub fn stats(&self) -> Arc<CacheStats> {
310 self.stats.clone()
311 }
312
313 #[inline]
315 pub fn len(&self) -> usize {
316 self.cache.read().len()
317 }
318
319 #[inline]
321 pub fn is_empty(&self) -> bool {
322 self.cache.read().is_empty()
323 }
324
325 #[inline]
327 pub fn capacity(&self) -> usize {
328 self.cache.read().cap().get()
329 }
330
331 pub fn evict_expired(&self) -> usize {
333 let mut cache = self.cache.write();
334 let mut expired_keys = Vec::new();
335
336 for (key, result) in cache.iter() {
337 if result.is_expired() {
338 expired_keys.push(key.clone());
339 }
340 }
341
342 let count = expired_keys.len();
343 for key in expired_keys {
344 cache.pop(&key);
345 self.stats.record_expiration();
346 }
347
348 count
349 }
350}
351
352impl Default for QueryCache {
353 fn default() -> Self {
354 Self::new(1000)
355 }
356}
357
358#[derive(Debug, Clone)]
360pub struct RemoteFact {
361 pub fact: Predicate,
363 pub source: Option<Cid>,
365 pub fetched_at: Instant,
367 pub ttl: Duration,
369}
370
371impl RemoteFact {
372 pub fn new(fact: Predicate, source: Option<Cid>, ttl: Duration) -> Self {
374 Self {
375 fact,
376 source,
377 fetched_at: Instant::now(),
378 ttl,
379 }
380 }
381
382 #[inline]
384 pub fn is_expired(&self) -> bool {
385 self.fetched_at.elapsed() > self.ttl
386 }
387}
388
389#[derive(Debug, Clone, PartialEq, Eq, Hash)]
391pub struct FactKey {
392 pub predicate_name: String,
394 pub args_hash: u64,
396}
397
398impl FactKey {
399 pub fn from_predicate(pred: &Predicate) -> Self {
401 let mut hasher = std::collections::hash_map::DefaultHasher::new();
402 for arg in &pred.args {
403 arg.hash(&mut hasher);
404 }
405 Self {
406 predicate_name: pred.name.clone(),
407 args_hash: hasher.finish(),
408 }
409 }
410}
411
412pub struct RemoteFactCache {
414 facts: RwLock<HashMap<String, Vec<RemoteFact>>>,
416 max_per_predicate: usize,
418 default_ttl: Duration,
420 stats: Arc<CacheStats>,
422}
423
424impl RemoteFactCache {
425 pub fn new(max_per_predicate: usize, default_ttl: Duration) -> Self {
427 Self {
428 facts: RwLock::new(HashMap::new()),
429 max_per_predicate,
430 default_ttl,
431 stats: Arc::new(CacheStats::new()),
432 }
433 }
434
435 pub fn get_facts(&self, predicate_name: &str) -> Vec<Predicate> {
437 let facts = self.facts.read();
438
439 if let Some(remote_facts) = facts.get(predicate_name) {
440 let valid_facts: Vec<Predicate> = remote_facts
441 .iter()
442 .filter(|f| !f.is_expired())
443 .map(|f| f.fact.clone())
444 .collect();
445
446 if valid_facts.is_empty() {
447 self.stats.record_miss();
448 } else {
449 self.stats.record_hit();
450 }
451
452 valid_facts
453 } else {
454 self.stats.record_miss();
455 Vec::new()
456 }
457 }
458
459 pub fn add_fact(&self, fact: Predicate, source: Option<Cid>) {
461 self.add_fact_with_ttl(fact, source, self.default_ttl);
462 }
463
464 pub fn add_fact_with_ttl(&self, fact: Predicate, source: Option<Cid>, ttl: Duration) {
466 let mut facts = self.facts.write();
467 let name = fact.name.clone();
468
469 let remote_fact = RemoteFact::new(fact, source, ttl);
470
471 let entry = facts.entry(name).or_default();
472
473 entry.retain(|f| !f.is_expired());
475
476 if entry.len() >= self.max_per_predicate {
478 entry.sort_by_key(|f| f.fetched_at);
480 entry.remove(0);
481 self.stats.record_eviction();
482 }
483
484 entry.push(remote_fact);
485 }
486
487 pub fn add_facts(&self, facts: Vec<Predicate>, source: Option<Cid>) {
489 for fact in facts {
490 self.add_fact(fact, source);
491 }
492 }
493
494 pub fn invalidate_predicate(&self, predicate_name: &str) {
496 let mut facts = self.facts.write();
497 facts.remove(predicate_name);
498 }
499
500 pub fn clear(&self) {
502 let mut facts = self.facts.write();
503 facts.clear();
504 }
505
506 pub fn stats(&self) -> Arc<CacheStats> {
508 self.stats.clone()
509 }
510
511 pub fn evict_expired(&self) -> usize {
513 let mut facts = self.facts.write();
514 let mut count = 0;
515
516 for entry in facts.values_mut() {
517 let before = entry.len();
518 entry.retain(|f| !f.is_expired());
519 count += before - entry.len();
520 }
521
522 for _ in 0..count {
523 self.stats.record_expiration();
524 }
525
526 count
527 }
528
529 pub fn len(&self) -> usize {
531 let facts = self.facts.read();
532 facts.values().map(|v| v.len()).sum()
533 }
534
535 pub fn is_empty(&self) -> bool {
537 self.len() == 0
538 }
539}
540
541impl Default for RemoteFactCache {
542 fn default() -> Self {
543 Self::new(1000, Duration::from_secs(300))
544 }
545}
546
547pub struct CacheManager {
549 pub query_cache: QueryCache,
551 pub fact_cache: RemoteFactCache,
553}
554
555impl CacheManager {
556 pub fn new() -> Self {
558 Self {
559 query_cache: QueryCache::new(10000),
560 fact_cache: RemoteFactCache::new(1000, Duration::from_secs(300)),
561 }
562 }
563
564 pub fn with_config(
566 query_capacity: usize,
567 query_ttl: Option<Duration>,
568 fact_capacity: usize,
569 fact_ttl: Duration,
570 ) -> Self {
571 let query_cache = if let Some(ttl) = query_ttl {
572 QueryCache::with_ttl(query_capacity, ttl)
573 } else {
574 QueryCache::new(query_capacity)
575 };
576
577 Self {
578 query_cache,
579 fact_cache: RemoteFactCache::new(fact_capacity, fact_ttl),
580 }
581 }
582
583 pub fn evict_expired(&self) -> (usize, usize) {
585 let queries = self.query_cache.evict_expired();
586 let facts = self.fact_cache.evict_expired();
587 (queries, facts)
588 }
589
590 pub fn clear_all(&self) {
592 self.query_cache.clear();
593 self.fact_cache.clear();
594 }
595
596 pub fn stats(&self) -> CombinedCacheStats {
598 CombinedCacheStats {
599 query_stats: self.query_cache.stats().snapshot(),
600 fact_stats: self.fact_cache.stats().snapshot(),
601 query_cache_size: self.query_cache.len(),
602 fact_cache_size: self.fact_cache.len(),
603 }
604 }
605}
606
607impl Default for CacheManager {
608 fn default() -> Self {
609 Self::new()
610 }
611}
612
613#[derive(Debug, Clone, Serialize, Deserialize)]
615pub struct CombinedCacheStats {
616 pub query_stats: CacheStatsSnapshot,
618 pub fact_stats: CacheStatsSnapshot,
620 pub query_cache_size: usize,
622 pub fact_cache_size: usize,
624}
625
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Constant;
    use std::thread::sleep;

    /// Builds a query key for predicate `name` with the given argument summaries.
    fn key_for(name: &str, ground_args: Vec<GroundArg>) -> QueryKey {
        QueryKey {
            predicate_name: name.to_string(),
            ground_args,
        }
    }

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(100);
        let key = key_for("test", vec![GroundArg::String("value".to_string())]);

        cache.insert(key.clone(), vec![Substitution::new()]);

        let cached = cache.get(&key).expect("freshly inserted entry should be present");
        assert_eq!(cached.len(), 1);
    }

    #[test]
    fn test_query_cache_ttl() {
        let cache = QueryCache::with_ttl(100, Duration::from_millis(50));
        let key = key_for("test", Vec::new());

        cache.insert(key.clone(), vec![Substitution::new()]);
        assert!(cache.get(&key).is_some());

        // Wait past the 50 ms TTL.
        sleep(Duration::from_millis(100));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(100);
        let key = key_for("test", Vec::new());

        // Nothing cached yet: one miss.
        let _ = cache.get(&key);

        // An empty solution set is still a cached answer: one hit.
        cache.insert(key.clone(), Vec::new());
        let _ = cache.get(&key);

        let snapshot = cache.stats().snapshot();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
    }

    #[test]
    fn test_remote_fact_cache() {
        let cache = RemoteFactCache::new(100, Duration::from_secs(60));
        let fact = Predicate::new(
            "test".to_string(),
            vec![Term::Const(Constant::String("value".to_string()))],
        );

        cache.add_fact(fact, None);

        let facts = cache.get_facts("test");
        assert_eq!(facts.len(), 1);
        let only = &facts[0];
        assert_eq!(only.name, "test");
    }

    #[test]
    fn test_remote_fact_cache_ttl() {
        let cache = RemoteFactCache::new(100, Duration::from_millis(50));
        cache.add_fact(Predicate::new("test".to_string(), vec![]), None);

        assert_eq!(cache.get_facts("test").len(), 1);

        // Wait past the 50 ms TTL.
        sleep(Duration::from_millis(100));
        assert!(cache.get_facts("test").is_empty());
    }

    #[test]
    fn test_cache_manager() {
        let manager = CacheManager::new();

        let key = key_for("test", Vec::new());
        manager.query_cache.insert(key.clone(), Vec::new());
        assert!(manager.query_cache.get(&key).is_some());

        manager
            .fact_cache
            .add_fact(Predicate::new("fact".to_string(), vec![]), None);
        assert_eq!(manager.fact_cache.get_facts("fact").len(), 1);

        let combined = manager.stats();
        assert!(combined.query_cache_size > 0);
        assert!(combined.fact_cache_size > 0);
    }
}