// exo_temporal/anticipation.rs
//! Predictive anticipation and pre-fetching

use std::collections::VecDeque;
use std::sync::Arc;

use dashmap::DashMap;
use parking_lot::RwLock;

use crate::causal::CausalGraph;
use crate::long_term::LongTermStore;
use crate::types::{PatternId, Query, SearchResult};

/// Hint describing why a future query is expected; consumed by [`anticipate`]
/// to decide which patterns to pre-fetch into the prefetch cache.
#[derive(Debug, Clone)]
pub enum AnticipationHint {
    /// Sequential pattern: if A was just queried, B is likely next.
    SequentialPattern {
        /// Recently observed query patterns; only the last element is used
        /// as the prediction seed by `anticipate`.
        recent: Vec<PatternId>,
    },
    /// Temporal cycle (time-of-day / day-of-week recurring patterns).
    TemporalCycle {
        /// Current temporal phase to match against recurring activity.
        phase: TemporalPhase,
    },
    /// Causal chain prediction: pre-fetch patterns downstream of a context.
    CausalChain {
        /// Current context pattern whose causal future is explored.
        context: PatternId,
    },
}

/// Temporal phase for cyclic patterns.
///
/// Derives `PartialEq`/`Eq`/`Hash` in addition to the original traits so
/// phases can be compared and used as map/set keys; all payloads are plain
/// integers, so these derives are free and backward compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TemporalPhase {
    /// Hour of day (expected range 0-23; not enforced here).
    HourOfDay(u8),
    /// Day of week (expected range 0-6; not enforced here).
    DayOfWeek(u8),
    /// Application-defined phase; `anticipate` folds it modulo 1000.
    Custom(u32),
}

/// Cache of pre-fetched query results, keyed by query hash.
///
/// Eviction victims are chosen from a separate insertion-order queue;
/// see `PrefetchCache::insert` / `PrefetchCache::get` for the policy.
pub struct PrefetchCache {
    /// Cached query results keyed by `Query::hash()`.
    cache: DashMap<u64, Vec<SearchResult>>,
    /// Maximum number of cached entries before eviction kicks in.
    capacity: usize,
    /// Queue of keys used to pick eviction victims (front = oldest).
    lru: Arc<RwLock<VecDeque<u64>>>,
}

52impl PrefetchCache {
53    /// Create new prefetch cache
54    pub fn new(capacity: usize) -> Self {
55        Self {
56            cache: DashMap::new(),
57            capacity,
58            lru: Arc::new(RwLock::new(VecDeque::with_capacity(capacity))),
59        }
60    }
61
62    /// Insert into cache
63    pub fn insert(&self, query_hash: u64, results: Vec<SearchResult>) {
64        // Check capacity
65        if self.cache.len() >= self.capacity {
66            self.evict_lru();
67        }
68
69        // Insert
70        self.cache.insert(query_hash, results);
71
72        // Update LRU
73        let mut lru = self.lru.write();
74        lru.push_back(query_hash);
75    }
76
77    /// Get from cache
78    pub fn get(&self, query_hash: u64) -> Option<Vec<SearchResult>> {
79        self.cache.get(&query_hash).map(|v| v.clone())
80    }
81
82    /// Evict least recently used entry
83    fn evict_lru(&self) {
84        let mut lru = self.lru.write();
85        if let Some(key) = lru.pop_front() {
86            self.cache.remove(&key);
87        }
88    }
89
90    /// Clear cache
91    pub fn clear(&self) {
92        self.cache.clear();
93        self.lru.write().clear();
94    }
95
96    /// Get cache size
97    pub fn len(&self) -> usize {
98        self.cache.len()
99    }
100
101    /// Check if cache is empty
102    pub fn is_empty(&self) -> bool {
103        self.cache.is_empty()
104    }
105}
106
107impl Default for PrefetchCache {
108    fn default() -> Self {
109        Self::new(1000)
110    }
111}
112
/// Optimized sequential pattern tracker with pre-computed frequencies.
///
/// Recording a transition only bumps a counter and flags the source
/// pattern's cached ranking as stale; the sorted frequency list is
/// rebuilt lazily on the next prediction for that pattern.
pub struct SequentialPatternTracker {
    /// Pre-computed per-source prediction lists for O(1) lookup.
    /// Key: source pattern; value: (count, target) pairs sorted by count
    /// descending.
    frequency_cache: DashMap<PatternId, Vec<(usize, PatternId)>>,
    /// Raw transition counts keyed by (from, to), updated incrementally.
    counts: DashMap<(PatternId, PatternId), usize>,
    /// Per-source staleness flags for `frequency_cache` entries.
    cache_valid: DashMap<PatternId, bool>,
    /// Total sequences recorded (for statistics only).
    total_sequences: std::sync::atomic::AtomicUsize,
}

126impl SequentialPatternTracker {
127    /// Create new tracker
128    pub fn new() -> Self {
129        Self {
130            frequency_cache: DashMap::new(),
131            counts: DashMap::new(),
132            cache_valid: DashMap::new(),
133            total_sequences: std::sync::atomic::AtomicUsize::new(0),
134        }
135    }
136
137    /// Record sequence: A followed by B (optimized with lazy cache invalidation)
138    pub fn record_sequence(&self, from: PatternId, to: PatternId) {
139        // Increment count atomically
140        *self.counts.entry((from, to)).or_insert(0) += 1;
141
142        // Invalidate cache for this source pattern
143        self.cache_valid.insert(from, false);
144
145        // Track total sequences
146        self.total_sequences
147            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
148    }
149
150    /// Predict next pattern given current (optimized O(1) cache lookup)
151    pub fn predict_next(&self, current: PatternId, top_k: usize) -> Vec<PatternId> {
152        // Check if cache is valid
153        let cache_valid = self.cache_valid.get(&current).map(|v| *v).unwrap_or(false);
154
155        if !cache_valid {
156            // Rebuild cache for this pattern
157            self.rebuild_cache(current);
158        }
159
160        // Fast O(1) lookup from pre-sorted cache
161        if let Some(sorted) = self.frequency_cache.get(&current) {
162            sorted.iter().take(top_k).map(|(_, id)| *id).collect()
163        } else {
164            Vec::new()
165        }
166    }
167
168    /// Rebuild frequency cache for a specific pattern
169    fn rebuild_cache(&self, pattern: PatternId) {
170        let mut freq_vec: Vec<(usize, PatternId)> = Vec::new();
171
172        // Collect all (pattern, target) pairs for this source
173        for entry in self.counts.iter() {
174            let (from, to) = *entry.key();
175            if from == pattern {
176                freq_vec.push((*entry.value(), to));
177            }
178        }
179
180        // Sort by count descending (higher frequency first)
181        freq_vec.sort_by(|a, b| b.0.cmp(&a.0));
182
183        // Update cache
184        self.frequency_cache.insert(pattern, freq_vec);
185        self.cache_valid.insert(pattern, true);
186    }
187
188    /// Get total number of recorded sequences
189    pub fn total_sequences(&self) -> usize {
190        self.total_sequences
191            .load(std::sync::atomic::Ordering::Relaxed)
192    }
193
194    /// Get prediction accuracy estimate (based on frequency distribution)
195    pub fn prediction_confidence(&self, pattern: PatternId) -> f32 {
196        if let Some(sorted) = self.frequency_cache.get(&pattern) {
197            if sorted.is_empty() {
198                return 0.0;
199            }
200            let total: usize = sorted.iter().map(|(c, _)| c).sum();
201            if total == 0 {
202                return 0.0;
203            }
204            // Confidence = top prediction count / total count
205            sorted[0].0 as f32 / total as f32
206        } else {
207            0.0
208        }
209    }
210
211    /// Batch record multiple sequences (optimized for bulk operations)
212    pub fn record_sequences_batch(&self, sequences: &[(PatternId, PatternId)]) {
213        let mut invalidated = std::collections::HashSet::new();
214
215        for (from, to) in sequences {
216            *self.counts.entry((*from, *to)).or_insert(0) += 1;
217            invalidated.insert(*from);
218        }
219
220        // Batch invalidate caches
221        for pattern in invalidated {
222            self.cache_valid.insert(pattern, false);
223        }
224
225        self.total_sequences
226            .fetch_add(sequences.len(), std::sync::atomic::Ordering::Relaxed);
227    }
228}
229
230impl Default for SequentialPatternTracker {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
236/// Anticipate future queries and pre-fetch
237pub fn anticipate(
238    hints: &[AnticipationHint],
239    long_term: &LongTermStore,
240    causal_graph: &CausalGraph,
241    prefetch_cache: &PrefetchCache,
242    sequential_tracker: &SequentialPatternTracker,
243) -> usize {
244    let mut num_prefetched = 0;
245
246    for hint in hints {
247        match hint {
248            AnticipationHint::SequentialPattern { recent } => {
249                // Predict next based on recent patterns
250                if let Some(&last) = recent.last() {
251                    let predicted = sequential_tracker.predict_next(last, 5);
252
253                    for pattern_id in predicted {
254                        if let Some(temporal_pattern) = long_term.get(&pattern_id) {
255                            // Create query from pattern
256                            let query =
257                                Query::from_embedding(temporal_pattern.pattern.embedding.clone());
258                            let query_hash = query.hash();
259
260                            // Pre-fetch if not cached
261                            if prefetch_cache.get(query_hash).is_none() {
262                                let results = long_term.search(&query);
263                                prefetch_cache.insert(query_hash, results);
264                                num_prefetched += 1;
265                            }
266                        }
267                    }
268                }
269            }
270
271            AnticipationHint::TemporalCycle { phase } => {
272                // Encode the temporal phase as a sinusoidal query vector and
273                // pre-fetch high-salience patterns for this recurring time slot.
274                let phase_ratio = match phase {
275                    TemporalPhase::HourOfDay(h) => *h as f64 / 24.0,
276                    TemporalPhase::DayOfWeek(d) => *d as f64 / 7.0,
277                    TemporalPhase::Custom(c) => (*c as f64 % 1000.0) / 1000.0,
278                };
279
280                // Build a 32-dim sinusoidal embedding for the phase
281                let dim = 32usize;
282                let query_vec: Vec<f32> = (0..dim)
283                    .map(|i| {
284                        let angle =
285                            2.0 * std::f64::consts::PI * phase_ratio * (i + 1) as f64 / dim as f64;
286                        angle.sin() as f32
287                    })
288                    .collect();
289
290                let query = Query::from_embedding(query_vec);
291                let query_hash = query.hash();
292
293                if prefetch_cache.get(query_hash).is_none() {
294                    let results = long_term.search(&query);
295                    if !results.is_empty() {
296                        prefetch_cache.insert(query_hash, results);
297                        num_prefetched += 1;
298                    }
299                }
300            }
301
302            AnticipationHint::CausalChain { context } => {
303                // Predict downstream patterns in causal graph
304                let downstream = causal_graph.causal_future(*context);
305
306                for pattern_id in downstream.into_iter().take(5) {
307                    if let Some(temporal_pattern) = long_term.get(&pattern_id) {
308                        let query =
309                            Query::from_embedding(temporal_pattern.pattern.embedding.clone());
310                        let query_hash = query.hash();
311
312                        // Pre-fetch if not cached
313                        if prefetch_cache.get(query_hash).is_none() {
314                            let results = long_term.search(&query);
315                            prefetch_cache.insert(query_hash, results);
316                            num_prefetched += 1;
317                        }
318                    }
319                }
320            }
321        }
322    }
323
324    num_prefetched
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefetch_cache() {
        let cache = PrefetchCache::new(2);

        cache.insert(1, Vec::new());
        cache.insert(2, Vec::new());

        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_some());

        // A third distinct insert must evict the oldest key (1).
        cache.insert(3, Vec::new());
        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_none());
    }

    #[test]
    fn test_sequential_tracker() {
        let tracker = SequentialPatternTracker::new();

        let a = PatternId::new();
        let b = PatternId::new();
        let c = PatternId::new();

        // a -> b twice, a -> c once: b should dominate predictions.
        tracker.record_sequence(a, b);
        tracker.record_sequence(a, b);
        tracker.record_sequence(a, c);

        let predicted = tracker.predict_next(a, 2);
        assert_eq!(predicted.len(), 2);
        assert_eq!(predicted[0], b);

        // Every recorded transition is counted.
        assert_eq!(tracker.total_sequences(), 3);

        // b accounts for 2 of the 3 transitions out of a.
        assert!(tracker.prediction_confidence(a) > 0.6);
    }

    #[test]
    fn test_batch_recording() {
        let tracker = SequentialPatternTracker::new();

        let a = PatternId::new();
        let b = PatternId::new();
        let c = PatternId::new();

        tracker.record_sequences_batch(&[(a, b), (a, b), (a, c), (b, c)]);

        assert_eq!(tracker.total_sequences(), 4);
        assert_eq!(tracker.predict_next(a, 1)[0], b);
    }
}