use crate::causal::CausalGraph;
use crate::long_term::LongTermStore;
use crate::types::{PatternId, Query, SearchResult};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::VecDeque;
use std::sync::Arc;

/// A hint describing why related patterns are likely to be needed soon.
#[derive(Debug, Clone)]
pub enum AnticipationHint {
    /// Recently accessed patterns suggest the next item in a learned sequence.
    SequentialPattern {
        /// Recently accessed pattern IDs.
        recent: Vec<PatternId>,
    },
    /// Accesses follow a recurring temporal cycle.
    TemporalCycle {
        /// The current phase of the cycle.
        phase: TemporalPhase,
    },
    /// Causal relationships suggest downstream patterns will be queried.
    CausalChain {
        /// The pattern providing the causal context.
        context: PatternId,
    },
}

/// A coarse position within a recurring time cycle.
#[derive(Debug, Clone, Copy)]
pub enum TemporalPhase {
    /// Hour of the day.
    HourOfDay(u8),
    /// Day of the week.
    DayOfWeek(u8),
    /// An application-defined phase identifier.
    Custom(u32),
}

/// A concurrent cache of prefetched search results, keyed by query hash.
pub struct PrefetchCache {
    /// Prefetched results keyed by query hash.
    cache: DashMap<u64, Vec<SearchResult>>,
    /// Maximum number of entries held before eviction kicks in.
    capacity: usize,
    /// Insertion-order queue used to choose eviction victims.
    lru: Arc<RwLock<VecDeque<u64>>>,
}

impl PrefetchCache {
    /// Creates an empty cache that holds at most `capacity` entries.
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: DashMap::new(),
            capacity,
            lru: Arc::new(RwLock::new(VecDeque::with_capacity(capacity))),
        }
    }

    /// Stores prefetched results for a query hash, evicting the oldest entry
    /// when the cache is full.
    pub fn insert(&self, query_hash: u64, results: Vec<SearchResult>) {
        if self.cache.len() >= self.capacity && !self.cache.contains_key(&query_hash) {
            self.evict_lru();
        }

        // Only track a key once in the eviction queue; overwriting an existing
        // entry must not create duplicates that would skew future evictions.
        let replaced = self.cache.insert(query_hash, results).is_some();
        if !replaced {
            let mut lru = self.lru.write();
            lru.push_back(query_hash);
        }
    }

    /// Returns the cached results for a query hash, if any.
    pub fn get(&self, query_hash: u64) -> Option<Vec<SearchResult>> {
        self.cache.get(&query_hash).map(|v| v.value().clone())
    }

    /// Removes the oldest cached entry, if there is one.
    fn evict_lru(&self) {
        let mut lru = self.lru.write();
        if let Some(key) = lru.pop_front() {
            self.cache.remove(&key);
        }
    }

    /// Drops all cached entries.
    pub fn clear(&self) {
        self.cache.clear();
        self.lru.write().clear();
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` if nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }
}

impl Default for PrefetchCache {
    fn default() -> Self {
        Self::new(1000)
    }
}

/// Tracks observed pattern-to-pattern transitions and predicts likely
/// successors from their frequencies.
pub struct SequentialPatternTracker {
    /// Per-pattern successor lists, sorted by descending frequency.
    frequency_cache: DashMap<PatternId, Vec<(usize, PatternId)>>,
    /// Raw transition counts keyed by `(from, to)`.
    counts: DashMap<(PatternId, PatternId), usize>,
    /// Whether the cached successor list for a pattern is up to date.
    cache_valid: DashMap<PatternId, bool>,
    /// Total number of transitions recorded.
    total_sequences: std::sync::atomic::AtomicUsize,
}

impl SequentialPatternTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self {
            frequency_cache: DashMap::new(),
            counts: DashMap::new(),
            cache_valid: DashMap::new(),
            total_sequences: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Records a single `from -> to` transition and invalidates the cached
    /// predictions for `from`.
    pub fn record_sequence(&self, from: PatternId, to: PatternId) {
        *self.counts.entry((from, to)).or_insert(0) += 1;

        self.cache_valid.insert(from, false);

        self.total_sequences.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns up to `top_k` likely successors of `current`, most frequent first.
    pub fn predict_next(&self, current: PatternId, top_k: usize) -> Vec<PatternId> {
        let cache_valid = self.cache_valid.get(&current).map(|v| *v).unwrap_or(false);

        if !cache_valid {
            self.rebuild_cache(current);
        }

        if let Some(sorted) = self.frequency_cache.get(&current) {
            sorted.iter()
                .take(top_k)
                .map(|(_, id)| *id)
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Rebuilds the sorted successor list for `pattern` from the raw counts.
    fn rebuild_cache(&self, pattern: PatternId) {
        let mut freq_vec: Vec<(usize, PatternId)> = Vec::new();

        for entry in self.counts.iter() {
            let (from, to) = *entry.key();
            if from == pattern {
                freq_vec.push((*entry.value(), to));
            }
        }

        // Most frequent successors first.
        freq_vec.sort_by(|a, b| b.0.cmp(&a.0));

        self.frequency_cache.insert(pattern, freq_vec);
        self.cache_valid.insert(pattern, true);
    }

    /// Total number of transitions recorded so far.
    pub fn total_sequences(&self) -> usize {
        self.total_sequences.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Returns the fraction of recorded transitions from `pattern` that lead to
    /// its most frequent successor, or 0.0 when nothing has been recorded.
    pub fn prediction_confidence(&self, pattern: PatternId) -> f32 {
        // Make sure the cached successor list reflects the latest counts.
        let cache_valid = self.cache_valid.get(&pattern).map(|v| *v).unwrap_or(false);
        if !cache_valid {
            self.rebuild_cache(pattern);
        }

        if let Some(sorted) = self.frequency_cache.get(&pattern) {
            if sorted.is_empty() {
                return 0.0;
            }
            let total: usize = sorted.iter().map(|(c, _)| c).sum();
            if total == 0 {
                return 0.0;
            }
            sorted[0].0 as f32 / total as f32
        } else {
            0.0
        }
    }

    /// Records a batch of transitions, invalidating each affected pattern's
    /// cached predictions only once.
    pub fn record_sequences_batch(&self, sequences: &[(PatternId, PatternId)]) {
        let mut invalidated = std::collections::HashSet::new();

        for (from, to) in sequences {
            *self.counts.entry((*from, *to)).or_insert(0) += 1;
            invalidated.insert(*from);
        }

        for pattern in invalidated {
            self.cache_valid.insert(pattern, false);
        }

        self.total_sequences.fetch_add(sequences.len(), std::sync::atomic::Ordering::Relaxed);
    }
}

impl Default for SequentialPatternTracker {
    fn default() -> Self {
        Self::new()
    }
}

/// Prefetches search results that the given hints suggest will be needed soon,
/// storing them in `prefetch_cache`. Returns how many queries were prefetched.
pub fn anticipate(
    hints: &[AnticipationHint],
    long_term: &LongTermStore,
    causal_graph: &CausalGraph,
    prefetch_cache: &PrefetchCache,
    sequential_tracker: &SequentialPatternTracker,
) -> usize {
    let mut num_prefetched = 0;

    for hint in hints {
        match hint {
            AnticipationHint::SequentialPattern { recent } => {
                // Predict what tends to follow the most recent access and
                // prefetch search results for those patterns.
                if let Some(&last) = recent.last() {
                    let predicted = sequential_tracker.predict_next(last, 5);

                    for pattern_id in predicted {
                        if let Some(temporal_pattern) = long_term.get(&pattern_id) {
                            let query = Query::from_embedding(temporal_pattern.pattern.embedding.clone());
                            let query_hash = query.hash();

                            if prefetch_cache.get(query_hash).is_none() {
                                let results = long_term.search(&query);
                                prefetch_cache.insert(query_hash, results);
                                num_prefetched += 1;
                            }
                        }
                    }
                }
            }

            AnticipationHint::TemporalCycle { phase: _ } => {
                // Temporal-cycle driven prefetching is not handled yet.
            }

            AnticipationHint::CausalChain { context } => {
                // Prefetch search results for patterns that lie causally
                // downstream of the given context.
                let downstream = causal_graph.causal_future(*context);

                for pattern_id in downstream.into_iter().take(5) {
                    if let Some(temporal_pattern) = long_term.get(&pattern_id) {
                        let query = Query::from_embedding(temporal_pattern.pattern.embedding.clone());
                        let query_hash = query.hash();

                        if prefetch_cache.get(query_hash).is_none() {
                            let results = long_term.search(&query);
                            prefetch_cache.insert(query_hash, results);
                            num_prefetched += 1;
                        }
                    }
                }
            }
        }
    }

    num_prefetched
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefetch_cache() {
        let cache = PrefetchCache::new(2);

        let results1 = vec![];
        let results2 = vec![];

        cache.insert(1, results1);
        cache.insert(2, results2);

        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_some());

        cache.insert(3, vec![]);
        // The oldest entry (key 1) should have been evicted.
        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_none());
    }

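    // Additional sketch of a check for clear() and is_empty(); like the test
    // above, it assumes empty result vectors are valid cache values.
    #[test]
    fn test_prefetch_cache_clear() {
        let cache = PrefetchCache::new(4);
        assert!(cache.is_empty());

        cache.insert(1, vec![]);
        cache.insert(2, vec![]);
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
        assert!(cache.get(1).is_none());
    }
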
    #[test]
    fn test_sequential_tracker() {
        let tracker = SequentialPatternTracker::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();

        tracker.record_sequence(p1, p2);
        tracker.record_sequence(p1, p2);

        tracker.record_sequence(p1, p3);

        let predicted = tracker.predict_next(p1, 2);

        assert_eq!(predicted.len(), 2);
        assert_eq!(predicted[0], p2);

        assert_eq!(tracker.total_sequences(), 3);

        let confidence = tracker.prediction_confidence(p1);
        assert!(confidence > 0.6);
    }

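    // Minimal sketch covering the no-data case: a pattern with no recorded
    // transitions should yield no predictions and zero confidence.
    #[test]
    fn test_unseen_pattern_has_no_predictions() {
        let tracker = SequentialPatternTracker::new();
        let unseen = PatternId::new();

        assert!(tracker.predict_next(unseen, 3).is_empty());
        assert_eq!(tracker.prediction_confidence(unseen), 0.0);
    }
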
    #[test]
    fn test_batch_recording() {
        let tracker = SequentialPatternTracker::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();

        let sequences = vec![
            (p1, p2),
            (p1, p2),
            (p1, p3),
            (p2, p3),
        ];

        tracker.record_sequences_batch(&sequences);

        assert_eq!(tracker.total_sequences(), 4);

        let predicted = tracker.predict_next(p1, 1);
        assert_eq!(predicted[0], p2);
    }
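
    // Sketch of a cache-invalidation check: after new transitions are recorded,
    // predict_next should reflect the updated frequencies rather than the stale
    // cached ordering.
    #[test]
    fn test_predictions_update_after_new_recordings() {
        let tracker = SequentialPatternTracker::new();

        let p1 = PatternId::new();
        let p2 = PatternId::new();
        let p3 = PatternId::new();

        tracker.record_sequence(p1, p2);
        tracker.record_sequence(p1, p2);
        assert_eq!(tracker.predict_next(p1, 1)[0], p2);

        // p3 overtakes p2 as the most frequent successor of p1.
        tracker.record_sequence(p1, p3);
        tracker.record_sequence(p1, p3);
        tracker.record_sequence(p1, p3);
        assert_eq!(tracker.predict_next(p1, 1)[0], p3);
    }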
}