exo_temporal/anticipation.rs

//! Predictive anticipation and pre-fetching

use crate::causal::CausalGraph;
use crate::long_term::LongTermStore;
use crate::types::{PatternId, Query, SearchResult};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::VecDeque;
use std::sync::Arc;

/// Anticipation hint types
#[derive(Debug, Clone)]
pub enum AnticipationHint {
    /// Sequential pattern: if A then B
    SequentialPattern {
        /// Recent query patterns
        recent: Vec<PatternId>,
    },
    /// Temporal cycle (time-of-day patterns)
    TemporalCycle {
        /// Current temporal phase
        phase: TemporalPhase,
    },
    /// Causal chain prediction
    CausalChain {
        /// Current context pattern
        context: PatternId,
    },
}

/// Temporal phase for cyclic patterns
#[derive(Debug, Clone, Copy)]
pub enum TemporalPhase {
    /// Hour of day (0-23)
    HourOfDay(u8),
    /// Day of week (0-6)
    DayOfWeek(u8),
    /// Custom phase
    Custom(u32),
}

/// Prefetch cache for anticipated queries
pub struct PrefetchCache {
    /// Cached query results
    cache: DashMap<u64, Vec<SearchResult>>,
    /// Cache capacity
    capacity: usize,
    /// Eviction queue (oldest insertions at the front; reads do not update recency)
    lru: Arc<RwLock<VecDeque<u64>>>,
}

impl PrefetchCache {
    /// Create new prefetch cache
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: DashMap::new(),
            capacity,
            lru: Arc::new(RwLock::new(VecDeque::with_capacity(capacity))),
        }
    }

    /// Insert into cache
    pub fn insert(&self, query_hash: u64, results: Vec<SearchResult>) {
        if self.cache.contains_key(&query_hash) {
            // Re-insertion: refresh this key's position in the eviction queue
            // instead of leaving a stale duplicate entry behind
            self.lru.write().retain(|&k| k != query_hash);
        } else if self.cache.len() >= self.capacity {
            // At capacity: make room before adding a new key
            self.evict_lru();
        }

        // Insert (or overwrite) the cached results
        self.cache.insert(query_hash, results);

        // Record insertion order for eviction
        self.lru.write().push_back(query_hash);
    }

    /// Get from cache
    pub fn get(&self, query_hash: u64) -> Option<Vec<SearchResult>> {
        self.cache.get(&query_hash).map(|v| v.clone())
    }

    /// Evict the least recently inserted entry
    fn evict_lru(&self) {
        let mut lru = self.lru.write();
        if let Some(key) = lru.pop_front() {
            self.cache.remove(&key);
        }
    }

    /// Clear cache
    pub fn clear(&self) {
        self.cache.clear();
        self.lru.write().clear();
    }

    /// Get cache size
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Check if cache is empty
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }
}

impl Default for PrefetchCache {
    fn default() -> Self {
        Self::new(1000)
    }
}

/// Optimized sequential pattern tracker with pre-computed frequencies
pub struct SequentialPatternTracker {
    /// Per-source frequency lists, pre-sorted for fast prediction lookups.
    /// Key: source pattern; value: vector of (count, target pattern), sorted by count descending
    frequency_cache: DashMap<PatternId, Vec<(usize, PatternId)>>,
    /// Raw counts for incremental updates
    counts: DashMap<(PatternId, PatternId), usize>,
    /// Cache validity flags
    cache_valid: DashMap<PatternId, bool>,
    /// Total sequences recorded (for statistics)
    total_sequences: std::sync::atomic::AtomicUsize,
}

impl SequentialPatternTracker {
    /// Create new tracker
    pub fn new() -> Self {
        Self {
            frequency_cache: DashMap::new(),
            counts: DashMap::new(),
            cache_valid: DashMap::new(),
            total_sequences: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Record sequence: A followed by B (lazy cache invalidation keeps updates cheap)
    pub fn record_sequence(&self, from: PatternId, to: PatternId) {
        // Increment the (from, to) transition count; the DashMap entry guard
        // keeps the read-modify-write thread-safe
        *self.counts.entry((from, to)).or_insert(0) += 1;

        // Invalidate cache for this source pattern
        self.cache_valid.insert(from, false);

        // Track total sequences
        self.total_sequences.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Predict next pattern given current, served from a lazily rebuilt, pre-sorted cache
    pub fn predict_next(&self, current: PatternId, top_k: usize) -> Vec<PatternId> {
        // Check if cache is valid
        let cache_valid = self.cache_valid.get(&current).map(|v| *v).unwrap_or(false);

        if !cache_valid {
            // Rebuild cache for this pattern
            self.rebuild_cache(current);
        }

        // Cheap lookup from the pre-sorted cache (O(top_k) on the hit path)
        if let Some(sorted) = self.frequency_cache.get(&current) {
            sorted.iter()
                .take(top_k)
                .map(|(_, id)| *id)
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Rebuild frequency cache for a specific pattern
    fn rebuild_cache(&self, pattern: PatternId) {
        let mut freq_vec: Vec<(usize, PatternId)> = Vec::new();

        // Collect all (pattern, target) pairs for this source
        for entry in self.counts.iter() {
            let (from, to) = *entry.key();
            if from == pattern {
                freq_vec.push((*entry.value(), to));
            }
        }

        // Sort by count descending (higher frequency first)
        freq_vec.sort_by(|a, b| b.0.cmp(&a.0));

        // Update cache
        self.frequency_cache.insert(pattern, freq_vec);
        self.cache_valid.insert(pattern, true);
    }

    /// Get total number of recorded sequences
    pub fn total_sequences(&self) -> usize {
        self.total_sequences.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Get prediction accuracy estimate (based on frequency distribution)
    pub fn prediction_confidence(&self, pattern: PatternId) -> f32 {
        // Rebuild the per-pattern cache if new sequences invalidated it, so the
        // estimate reflects the latest counts even before predict_next runs
        let cache_valid = self.cache_valid.get(&pattern).map(|v| *v).unwrap_or(false);
        if !cache_valid {
            self.rebuild_cache(pattern);
        }

        if let Some(sorted) = self.frequency_cache.get(&pattern) {
            if sorted.is_empty() {
                return 0.0;
            }
            let total: usize = sorted.iter().map(|(c, _)| c).sum();
            if total == 0 {
                return 0.0;
            }
            // Confidence = top prediction count / total count
            sorted[0].0 as f32 / total as f32
        } else {
            0.0
        }
    }

    /// Batch record multiple sequences (optimized for bulk operations)
    pub fn record_sequences_batch(&self, sequences: &[(PatternId, PatternId)]) {
        let mut invalidated = std::collections::HashSet::new();

        for (from, to) in sequences {
            *self.counts.entry((*from, *to)).or_insert(0) += 1;
            invalidated.insert(*from);
        }

        // Batch invalidate caches
        for pattern in invalidated {
            self.cache_valid.insert(pattern, false);
        }

        self.total_sequences.fetch_add(sequences.len(), std::sync::atomic::Ordering::Relaxed);
    }
}

impl Default for SequentialPatternTracker {
    fn default() -> Self {
        Self::new()
    }
}

/// Anticipate future queries and pre-fetch
pub fn anticipate(
    hints: &[AnticipationHint],
    long_term: &LongTermStore,
    causal_graph: &CausalGraph,
    prefetch_cache: &PrefetchCache,
    sequential_tracker: &SequentialPatternTracker,
) -> usize {
    let mut num_prefetched = 0;

    for hint in hints {
        match hint {
            AnticipationHint::SequentialPattern { recent } => {
                // Predict next based on recent patterns
                if let Some(&last) = recent.last() {
                    let predicted = sequential_tracker.predict_next(last, 5);

                    for pattern_id in predicted {
                        if let Some(temporal_pattern) = long_term.get(&pattern_id) {
                            // Create query from pattern
                            let query = Query::from_embedding(temporal_pattern.pattern.embedding.clone());
                            let query_hash = query.hash();

                            // Pre-fetch if not cached
                            if prefetch_cache.get(query_hash).is_none() {
                                let results = long_term.search(&query);
                                prefetch_cache.insert(query_hash, results);
                                num_prefetched += 1;
                            }
                        }
                    }
                }
            }

            AnticipationHint::TemporalCycle { phase: _ } => {
                // TODO: Implement temporal cycle prediction
                // Would track queries by time-of-day/day-of-week
                // and pre-fetch commonly accessed patterns for current phase
            }

            AnticipationHint::CausalChain { context } => {
                // Predict downstream patterns in causal graph
                let downstream = causal_graph.causal_future(*context);

                for pattern_id in downstream.into_iter().take(5) {
                    if let Some(temporal_pattern) = long_term.get(&pattern_id) {
                        let query = Query::from_embedding(temporal_pattern.pattern.embedding.clone());
                        let query_hash = query.hash();

                        // Pre-fetch if not cached
                        if prefetch_cache.get(query_hash).is_none() {
                            let results = long_term.search(&query);
                            prefetch_cache.insert(query_hash, results);
                            num_prefetched += 1;
                        }
                    }
                }
            }
        }
    }

    num_prefetched
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefetch_cache() {
        let cache = PrefetchCache::new(2);

        let results1 = vec![];
        let results2 = vec![];

        cache.insert(1, results1);
        cache.insert(2, results2);

        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_some());

        // Inserting a third key should evict the first (oldest) entry
        cache.insert(3, vec![]);
        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_none());
    }
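
    // Additional usage sketch: re-inserting an existing key refreshes its position
    // in the eviction queue instead of leaving a stale duplicate behind. As in the
    // test above, empty result vectors stand in for real `SearchResult`s.
    #[test]
    fn test_prefetch_cache_reinsert() {
        let cache = PrefetchCache::new(2);

        cache.insert(1, vec![]);
        cache.insert(2, vec![]);

        // Re-insert key 1: it becomes the newest entry in the eviction queue
        cache.insert(1, vec![]);
        assert_eq!(cache.len(), 2);

        // Inserting a third key now evicts key 2, the oldest entry
        cache.insert(3, vec![]);
        assert_eq!(cache.len(), 2);
        assert!(cache.get(2).is_none());
        assert!(cache.get(1).is_some());
        assert!(cache.get(3).is_some());
    }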

    #[test]
    fn test_sequential_tracker() {
        let tracker = SequentialPatternTracker::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();

        // p1 -> p2 (twice)
        tracker.record_sequence(p1, p2);
        tracker.record_sequence(p1, p2);

        // p1 -> p3 (once)
        tracker.record_sequence(p1, p3);

        let predicted = tracker.predict_next(p1, 2);

        // p2 should be first (more frequent)
        assert_eq!(predicted.len(), 2);
        assert_eq!(predicted[0], p2);

        // Test total sequences tracking
        assert_eq!(tracker.total_sequences(), 3);

        // Test prediction confidence
        let confidence = tracker.prediction_confidence(p1);
        assert!(confidence > 0.6); // p2 appears 2 out of 3 times
    }
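
    // Additional sketch: prediction_confidence rebuilds its per-pattern cache
    // lazily, so it can be queried directly without calling predict_next first.
    #[test]
    fn test_prediction_confidence_standalone() {
        let tracker = SequentialPatternTracker::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();

        // p1 -> p2 three times, p1 -> p3 once
        tracker.record_sequence(p1, p2);
        tracker.record_sequence(p1, p2);
        tracker.record_sequence(p1, p2);
        tracker.record_sequence(p1, p3);

        // 3 of the 4 recorded transitions go to p2
        let confidence = tracker.prediction_confidence(p1);
        assert!((confidence - 0.75).abs() < 1e-6);

        // A pattern with no recorded transitions has zero confidence
        assert_eq!(tracker.prediction_confidence(PatternId::new()), 0.0);
    }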

    #[test]
    fn test_batch_recording() {
        let tracker = SequentialPatternTracker::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();

        let sequences = vec![
            (p1, p2),
            (p1, p2),
            (p1, p3),
            (p2, p3),
        ];

        tracker.record_sequences_batch(&sequences);

        assert_eq!(tracker.total_sequences(), 4);

        let predicted = tracker.predict_next(p1, 1);
        assert_eq!(predicted[0], p2);
    }
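
    // Additional sketch: constructing the hint types consumed by `anticipate`.
    // This only exercises AnticipationHint and TemporalPhase; driving `anticipate`
    // itself would also need a LongTermStore and a CausalGraph.
    #[test]
    fn test_anticipation_hint_construction() {
        let context = PatternId::new();

        let hints = vec![
            AnticipationHint::SequentialPattern {
                recent: vec![PatternId::new(), PatternId::new()],
            },
            AnticipationHint::TemporalCycle {
                phase: TemporalPhase::HourOfDay(9),
            },
            AnticipationHint::CausalChain { context },
        ];

        assert_eq!(hints.len(), 3);
        match &hints[1] {
            AnticipationHint::TemporalCycle { phase: TemporalPhase::HourOfDay(h) } => {
                assert_eq!(*h, 9)
            }
            _ => panic!("expected a TemporalCycle hint"),
        }
    }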
}