ipfrs_tensorlogic/
cache.rs

1//! Caching support for query results and remote facts
2//!
3//! Provides:
4//! - LRU cache for query results
5//! - TTL-based cache for remote facts
6//! - Thread-safe caching primitives
7//! - Cache statistics
8
9use crate::ir::{Predicate, Term};
10use crate::reasoning::Substitution;
11use ipfrs_core::Cid;
12use lru::LruCache;
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::hash::{Hash, Hasher};
17use std::num::NonZeroUsize;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21
22/// Query key for cache lookups
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct QueryKey {
25    /// Predicate name
26    pub predicate_name: String,
27    /// Ground arguments (for filtering)
28    pub ground_args: Vec<GroundArg>,
29}
30
/// One argument position inside a query cache key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum GroundArg {
    /// A string constant.
    String(String),
    /// An integer constant.
    Int(i64),
    /// A float constant, held as a `u64` bit pattern so it can be hashed.
    Float(u64),
    /// An argument that is not a ground constant.
    Variable,
}
43
44impl QueryKey {
45    /// Create a query key from a predicate
46    pub fn from_predicate(pred: &Predicate) -> Self {
47        let ground_args = pred
48            .args
49            .iter()
50            .map(|arg| match arg {
51                Term::Const(c) => match c {
52                    crate::ir::Constant::String(s) => GroundArg::String(s.clone()),
53                    crate::ir::Constant::Int(i) => GroundArg::Int(*i),
54                    // Float is stored as String for deterministic hashing
55                    crate::ir::Constant::Float(f) => {
56                        let hash = f.parse::<f64>().map(|v| v.to_bits()).unwrap_or(0);
57                        GroundArg::Float(hash)
58                    }
59                    crate::ir::Constant::Bool(b) => GroundArg::Int(if *b { 1 } else { 0 }),
60                },
61                Term::Var(_) | Term::Fun(_, _) | Term::Ref(_) => GroundArg::Variable,
62            })
63            .collect();
64
65        Self {
66            predicate_name: pred.name.clone(),
67            ground_args,
68        }
69    }
70}
71
72/// Cached query result
73#[derive(Debug, Clone)]
74pub struct CachedResult {
75    /// Query solutions (substitutions)
76    pub solutions: Vec<Substitution>,
77    /// When the result was cached
78    pub cached_at: Instant,
79    /// Time-to-live for this result
80    pub ttl: Option<Duration>,
81}
82
83impl CachedResult {
84    /// Create a new cached result
85    pub fn new(solutions: Vec<Substitution>, ttl: Option<Duration>) -> Self {
86        Self {
87            solutions,
88            cached_at: Instant::now(),
89            ttl,
90        }
91    }
92
93    /// Check if the cached result has expired
94    #[inline]
95    pub fn is_expired(&self) -> bool {
96        if let Some(ttl) = self.ttl {
97            self.cached_at.elapsed() > ttl
98        } else {
99            false
100        }
101    }
102
103    /// Get remaining TTL
104    #[inline]
105    pub fn remaining_ttl(&self) -> Option<Duration> {
106        self.ttl
107            .map(|ttl| ttl.saturating_sub(self.cached_at.elapsed()))
108    }
109}
110
/// Live cache counters, updatable through a shared reference.
#[derive(Default, Debug)]
pub struct CacheStats {
    /// Lookups that were answered from the cache.
    pub hits: AtomicU64,
    /// Lookups that found nothing usable.
    pub misses: AtomicU64,
    /// Entries evicted from the cache.
    pub evictions: AtomicU64,
    /// Entries dropped because they expired.
    pub expirations: AtomicU64,
}
123
124impl CacheStats {
125    /// Create new stats
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    /// Record a hit
131    #[inline]
132    pub fn record_hit(&self) {
133        self.hits.fetch_add(1, Ordering::Relaxed);
134    }
135
136    /// Record a miss
137    #[inline]
138    pub fn record_miss(&self) {
139        self.misses.fetch_add(1, Ordering::Relaxed);
140    }
141
142    /// Record an eviction
143    #[inline]
144    pub fn record_eviction(&self) {
145        self.evictions.fetch_add(1, Ordering::Relaxed);
146    }
147
148    /// Record an expiration
149    #[inline]
150    pub fn record_expiration(&self) {
151        self.expirations.fetch_add(1, Ordering::Relaxed);
152    }
153
154    /// Get hit rate
155    pub fn hit_rate(&self) -> f64 {
156        let hits = self.hits.load(Ordering::Relaxed);
157        let misses = self.misses.load(Ordering::Relaxed);
158        let total = hits + misses;
159        if total == 0 {
160            0.0
161        } else {
162            hits as f64 / total as f64
163        }
164    }
165
166    /// Get a snapshot of stats
167    pub fn snapshot(&self) -> CacheStatsSnapshot {
168        CacheStatsSnapshot {
169            hits: self.hits.load(Ordering::Relaxed),
170            misses: self.misses.load(Ordering::Relaxed),
171            evictions: self.evictions.load(Ordering::Relaxed),
172            expirations: self.expirations.load(Ordering::Relaxed),
173        }
174    }
175}
176
177/// Snapshot of cache statistics
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct CacheStatsSnapshot {
180    /// Number of cache hits
181    pub hits: u64,
182    /// Number of cache misses
183    pub misses: u64,
184    /// Number of evictions
185    pub evictions: u64,
186    /// Number of expirations
187    pub expirations: u64,
188}
189
190impl CacheStatsSnapshot {
191    /// Get hit rate
192    #[inline]
193    pub fn hit_rate(&self) -> f64 {
194        let total = self.hits + self.misses;
195        if total == 0 {
196            0.0
197        } else {
198            self.hits as f64 / total as f64
199        }
200    }
201}
202
203/// LRU cache for query results
204pub struct QueryCache {
205    /// The underlying LRU cache
206    cache: RwLock<LruCache<QueryKey, CachedResult>>,
207    /// Default TTL for cached results
208    default_ttl: Option<Duration>,
209    /// Cache statistics
210    stats: Arc<CacheStats>,
211}
212
213impl QueryCache {
214    /// Create a new query cache with the given capacity
215    pub fn new(capacity: usize) -> Self {
216        Self {
217            cache: RwLock::new(LruCache::new(
218                NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(100).unwrap()),
219            )),
220            default_ttl: None,
221            stats: Arc::new(CacheStats::new()),
222        }
223    }
224
225    /// Create a new query cache with TTL
226    pub fn with_ttl(capacity: usize, ttl: Duration) -> Self {
227        Self {
228            cache: RwLock::new(LruCache::new(
229                NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::new(100).unwrap()),
230            )),
231            default_ttl: Some(ttl),
232            stats: Arc::new(CacheStats::new()),
233        }
234    }
235
236    /// Get a cached result
237    #[inline]
238    pub fn get(&self, key: &QueryKey) -> Option<Vec<Substitution>> {
239        let mut cache = self.cache.write();
240
241        if let Some(result) = cache.get(key) {
242            if result.is_expired() {
243                self.stats.record_expiration();
244                cache.pop(key);
245                self.stats.record_miss();
246                return None;
247            }
248            self.stats.record_hit();
249            Some(result.solutions.clone())
250        } else {
251            self.stats.record_miss();
252            None
253        }
254    }
255
256    /// Insert a result into the cache
257    pub fn insert(&self, key: QueryKey, solutions: Vec<Substitution>) {
258        let mut cache = self.cache.write();
259
260        // Check if we need to evict
261        if cache.len() >= cache.cap().get() {
262            self.stats.record_eviction();
263        }
264
265        let result = CachedResult::new(solutions, self.default_ttl);
266        cache.put(key, result);
267    }
268
269    /// Insert a result with custom TTL
270    pub fn insert_with_ttl(&self, key: QueryKey, solutions: Vec<Substitution>, ttl: Duration) {
271        let mut cache = self.cache.write();
272
273        if cache.len() >= cache.cap().get() {
274            self.stats.record_eviction();
275        }
276
277        let result = CachedResult::new(solutions, Some(ttl));
278        cache.put(key, result);
279    }
280
281    /// Invalidate a cached result
282    pub fn invalidate(&self, key: &QueryKey) -> bool {
283        let mut cache = self.cache.write();
284        cache.pop(key).is_some()
285    }
286
287    /// Invalidate all results for a predicate
288    pub fn invalidate_predicate(&self, predicate_name: &str) {
289        let mut cache = self.cache.write();
290        let keys_to_remove: Vec<QueryKey> = cache
291            .iter()
292            .filter(|(k, _)| k.predicate_name == predicate_name)
293            .map(|(k, _)| k.clone())
294            .collect();
295
296        for key in keys_to_remove {
297            cache.pop(&key);
298        }
299    }
300
301    /// Clear the entire cache
302    pub fn clear(&self) {
303        let mut cache = self.cache.write();
304        cache.clear();
305    }
306
307    /// Get cache statistics
308    #[inline]
309    pub fn stats(&self) -> Arc<CacheStats> {
310        self.stats.clone()
311    }
312
313    /// Get current cache size
314    #[inline]
315    pub fn len(&self) -> usize {
316        self.cache.read().len()
317    }
318
319    /// Check if cache is empty
320    #[inline]
321    pub fn is_empty(&self) -> bool {
322        self.cache.read().is_empty()
323    }
324
325    /// Get cache capacity
326    #[inline]
327    pub fn capacity(&self) -> usize {
328        self.cache.read().cap().get()
329    }
330
331    /// Remove expired entries
332    pub fn evict_expired(&self) -> usize {
333        let mut cache = self.cache.write();
334        let mut expired_keys = Vec::new();
335
336        for (key, result) in cache.iter() {
337            if result.is_expired() {
338                expired_keys.push(key.clone());
339            }
340        }
341
342        let count = expired_keys.len();
343        for key in expired_keys {
344            cache.pop(&key);
345            self.stats.record_expiration();
346        }
347
348        count
349    }
350}
351
352impl Default for QueryCache {
353    fn default() -> Self {
354        Self::new(1000)
355    }
356}
357
358/// Remote fact with metadata
359#[derive(Debug, Clone)]
360pub struct RemoteFact {
361    /// The fact predicate
362    pub fact: Predicate,
363    /// Source peer CID
364    pub source: Option<Cid>,
365    /// When the fact was fetched
366    pub fetched_at: Instant,
367    /// Time-to-live
368    pub ttl: Duration,
369}
370
371impl RemoteFact {
372    /// Create a new remote fact
373    pub fn new(fact: Predicate, source: Option<Cid>, ttl: Duration) -> Self {
374        Self {
375            fact,
376            source,
377            fetched_at: Instant::now(),
378            ttl,
379        }
380    }
381
382    /// Check if the fact has expired
383    #[inline]
384    pub fn is_expired(&self) -> bool {
385        self.fetched_at.elapsed() > self.ttl
386    }
387}
388
/// Compact key identifying a remote fact by predicate name and arguments.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct FactKey {
    /// Name of the fact's predicate.
    pub predicate_name: String,
    /// Hash computed over the fact's arguments.
    pub args_hash: u64,
}
397
398impl FactKey {
399    /// Create a fact key from a predicate
400    pub fn from_predicate(pred: &Predicate) -> Self {
401        let mut hasher = std::collections::hash_map::DefaultHasher::new();
402        for arg in &pred.args {
403            arg.hash(&mut hasher);
404        }
405        Self {
406            predicate_name: pred.name.clone(),
407            args_hash: hasher.finish(),
408        }
409    }
410}
411
412/// Cache for remote facts
413pub struct RemoteFactCache {
414    /// Facts by predicate name
415    facts: RwLock<HashMap<String, Vec<RemoteFact>>>,
416    /// Maximum facts per predicate
417    max_per_predicate: usize,
418    /// Default TTL
419    default_ttl: Duration,
420    /// Statistics
421    stats: Arc<CacheStats>,
422}
423
424impl RemoteFactCache {
425    /// Create a new remote fact cache
426    pub fn new(max_per_predicate: usize, default_ttl: Duration) -> Self {
427        Self {
428            facts: RwLock::new(HashMap::new()),
429            max_per_predicate,
430            default_ttl,
431            stats: Arc::new(CacheStats::new()),
432        }
433    }
434
435    /// Get facts for a predicate
436    pub fn get_facts(&self, predicate_name: &str) -> Vec<Predicate> {
437        let facts = self.facts.read();
438
439        if let Some(remote_facts) = facts.get(predicate_name) {
440            let valid_facts: Vec<Predicate> = remote_facts
441                .iter()
442                .filter(|f| !f.is_expired())
443                .map(|f| f.fact.clone())
444                .collect();
445
446            if valid_facts.is_empty() {
447                self.stats.record_miss();
448            } else {
449                self.stats.record_hit();
450            }
451
452            valid_facts
453        } else {
454            self.stats.record_miss();
455            Vec::new()
456        }
457    }
458
459    /// Add a fact to the cache
460    pub fn add_fact(&self, fact: Predicate, source: Option<Cid>) {
461        self.add_fact_with_ttl(fact, source, self.default_ttl);
462    }
463
464    /// Add a fact with custom TTL
465    pub fn add_fact_with_ttl(&self, fact: Predicate, source: Option<Cid>, ttl: Duration) {
466        let mut facts = self.facts.write();
467        let name = fact.name.clone();
468
469        let remote_fact = RemoteFact::new(fact, source, ttl);
470
471        let entry = facts.entry(name).or_default();
472
473        // Remove expired facts
474        entry.retain(|f| !f.is_expired());
475
476        // Check capacity
477        if entry.len() >= self.max_per_predicate {
478            // Remove oldest
479            entry.sort_by_key(|f| f.fetched_at);
480            entry.remove(0);
481            self.stats.record_eviction();
482        }
483
484        entry.push(remote_fact);
485    }
486
487    /// Add multiple facts
488    pub fn add_facts(&self, facts: Vec<Predicate>, source: Option<Cid>) {
489        for fact in facts {
490            self.add_fact(fact, source);
491        }
492    }
493
494    /// Invalidate facts for a predicate
495    pub fn invalidate_predicate(&self, predicate_name: &str) {
496        let mut facts = self.facts.write();
497        facts.remove(predicate_name);
498    }
499
500    /// Clear all facts
501    pub fn clear(&self) {
502        let mut facts = self.facts.write();
503        facts.clear();
504    }
505
506    /// Get statistics
507    pub fn stats(&self) -> Arc<CacheStats> {
508        self.stats.clone()
509    }
510
511    /// Remove expired facts
512    pub fn evict_expired(&self) -> usize {
513        let mut facts = self.facts.write();
514        let mut count = 0;
515
516        for entry in facts.values_mut() {
517            let before = entry.len();
518            entry.retain(|f| !f.is_expired());
519            count += before - entry.len();
520        }
521
522        for _ in 0..count {
523            self.stats.record_expiration();
524        }
525
526        count
527    }
528
529    /// Get total number of cached facts
530    pub fn len(&self) -> usize {
531        let facts = self.facts.read();
532        facts.values().map(|v| v.len()).sum()
533    }
534
535    /// Check if empty
536    pub fn is_empty(&self) -> bool {
537        self.len() == 0
538    }
539}
540
541impl Default for RemoteFactCache {
542    fn default() -> Self {
543        Self::new(1000, Duration::from_secs(300))
544    }
545}
546
547/// Combined cache manager
548pub struct CacheManager {
549    /// Query result cache
550    pub query_cache: QueryCache,
551    /// Remote fact cache
552    pub fact_cache: RemoteFactCache,
553}
554
555impl CacheManager {
556    /// Create a new cache manager with default settings
557    pub fn new() -> Self {
558        Self {
559            query_cache: QueryCache::new(10000),
560            fact_cache: RemoteFactCache::new(1000, Duration::from_secs(300)),
561        }
562    }
563
564    /// Create with custom settings
565    pub fn with_config(
566        query_capacity: usize,
567        query_ttl: Option<Duration>,
568        fact_capacity: usize,
569        fact_ttl: Duration,
570    ) -> Self {
571        let query_cache = if let Some(ttl) = query_ttl {
572            QueryCache::with_ttl(query_capacity, ttl)
573        } else {
574            QueryCache::new(query_capacity)
575        };
576
577        Self {
578            query_cache,
579            fact_cache: RemoteFactCache::new(fact_capacity, fact_ttl),
580        }
581    }
582
583    /// Evict all expired entries
584    pub fn evict_expired(&self) -> (usize, usize) {
585        let queries = self.query_cache.evict_expired();
586        let facts = self.fact_cache.evict_expired();
587        (queries, facts)
588    }
589
590    /// Clear all caches
591    pub fn clear_all(&self) {
592        self.query_cache.clear();
593        self.fact_cache.clear();
594    }
595
596    /// Get combined statistics
597    pub fn stats(&self) -> CombinedCacheStats {
598        CombinedCacheStats {
599            query_stats: self.query_cache.stats().snapshot(),
600            fact_stats: self.fact_cache.stats().snapshot(),
601            query_cache_size: self.query_cache.len(),
602            fact_cache_size: self.fact_cache.len(),
603        }
604    }
605}
606
607impl Default for CacheManager {
608    fn default() -> Self {
609        Self::new()
610    }
611}
612
613/// Combined cache statistics
614#[derive(Debug, Clone, Serialize, Deserialize)]
615pub struct CombinedCacheStats {
616    /// Query cache statistics
617    pub query_stats: CacheStatsSnapshot,
618    /// Fact cache statistics
619    pub fact_stats: CacheStatsSnapshot,
620    /// Current query cache size
621    pub query_cache_size: usize,
622    /// Current fact cache size
623    pub fact_cache_size: usize,
624}
625
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Constant;
    use std::thread::sleep;

    /// Builds a query key for the predicate `test` with the given arguments.
    fn test_key(ground_args: Vec<GroundArg>) -> QueryKey {
        QueryKey {
            predicate_name: "test".to_string(),
            ground_args,
        }
    }

    /// A stored result can be read back.
    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(100);
        let key = test_key(vec![GroundArg::String("value".to_string())]);

        let solutions = vec![Substitution::new()];
        cache.insert(key.clone(), solutions.clone());

        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);
    }

    /// An entry is served before its TTL elapses and dropped afterwards.
    #[test]
    fn test_query_cache_ttl() {
        let cache = QueryCache::with_ttl(100, Duration::from_millis(50));
        let key = test_key(vec![]);

        cache.insert(key.clone(), vec![Substitution::new()]);

        // Still fresh right after insertion.
        assert!(cache.get(&key).is_some());

        // Sleep past the 50 ms TTL.
        sleep(Duration::from_millis(100));

        // The entry must now be treated as gone.
        assert!(cache.get(&key).is_none());
    }

    /// One miss followed by one hit is reflected in the counters.
    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(100);
        let key = test_key(vec![]);

        // Lookup before insertion: miss.
        cache.get(&key);

        // Lookup after insertion: hit.
        cache.insert(key.clone(), vec![]);
        cache.get(&key);

        let snapshot = cache.stats().snapshot();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
    }

    /// A stored fact is returned under its predicate name.
    #[test]
    fn test_remote_fact_cache() {
        let cache = RemoteFactCache::new(100, Duration::from_secs(60));

        let args = vec![Term::Const(Constant::String("value".to_string()))];
        let fact = Predicate::new("test".to_string(), args);

        cache.add_fact(fact.clone(), None);

        let stored = cache.get_facts("test");
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].name, "test");
    }

    /// A fact is served before its TTL elapses and hidden afterwards.
    #[test]
    fn test_remote_fact_cache_ttl() {
        let cache = RemoteFactCache::new(100, Duration::from_millis(50));

        cache.add_fact(Predicate::new("test".to_string(), vec![]), None);

        // Still fresh right after insertion.
        assert_eq!(cache.get_facts("test").len(), 1);

        // Sleep past the 50 ms TTL.
        sleep(Duration::from_millis(100));

        // The fact must no longer be returned.
        assert!(cache.get_facts("test").is_empty());
    }

    /// The manager exposes both caches and reports their sizes.
    #[test]
    fn test_cache_manager() {
        let manager = CacheManager::new();

        // Query cache round trip.
        let key = test_key(vec![]);
        manager.query_cache.insert(key.clone(), vec![]);
        assert!(manager.query_cache.get(&key).is_some());

        // Fact cache round trip.
        let fact = Predicate::new("fact".to_string(), vec![]);
        manager.fact_cache.add_fact(fact, None);
        assert_eq!(manager.fact_cache.get_facts("fact").len(), 1);

        // Combined statistics see both entries.
        let combined = manager.stats();
        assert!(combined.query_cache_size > 0);
        assert!(combined.fact_cache_size > 0);
    }
}
744        assert!(stats.fact_cache_size > 0);
745    }
746}