1use crate::causal::CausalGraph;
4use crate::long_term::LongTermStore;
5use crate::types::{PatternId, Query, SearchResult};
6use dashmap::DashMap;
7use parking_lot::RwLock;
8use std::collections::VecDeque;
9use std::sync::Arc;
10
/// A prefetch strategy hint consumed by `anticipate`.
#[derive(Debug, Clone)]
pub enum AnticipationHint {
    /// Prefetch the likely successors of the most recently seen pattern.
    SequentialPattern {
        /// Recently observed patterns; only the last entry is consulted.
        recent: Vec<PatternId>,
    },
    /// Prefetch patterns associated with a recurring temporal phase.
    TemporalCycle {
        phase: TemporalPhase,
    },
    /// Prefetch patterns causally downstream of `context`.
    CausalChain {
        context: PatternId,
    },
}
30
/// A point within a recurring temporal cycle, mapped onto [0, 1) by `anticipate`.
#[derive(Debug, Clone, Copy)]
pub enum TemporalPhase {
    /// Hour of the day; expected 0..24 (not validated here — TODO confirm at call sites).
    HourOfDay(u8),
    /// Day of the week; expected 0..7 (not validated here).
    DayOfWeek(u8),
    /// Application-defined phase; folded into [0, 1) via `value % 1000 / 1000`.
    Custom(u32),
}
41
/// Concurrent cache of pre-executed search results keyed by query hash.
pub struct PrefetchCache {
    // Results keyed by `Query::hash()`.
    cache: DashMap<u64, Vec<SearchResult>>,
    // Maximum number of cached entries before eviction kicks in.
    capacity: usize,
    // Insertion-order queue used to pick the eviction victim.
    // NOTE(review): `get` never refreshes recency, so eviction is FIFO
    // despite the `lru` name — confirm whether true LRU was intended.
    lru: Arc<RwLock<VecDeque<u64>>>,
}
51
52impl PrefetchCache {
53 pub fn new(capacity: usize) -> Self {
55 Self {
56 cache: DashMap::new(),
57 capacity,
58 lru: Arc::new(RwLock::new(VecDeque::with_capacity(capacity))),
59 }
60 }
61
62 pub fn insert(&self, query_hash: u64, results: Vec<SearchResult>) {
64 if self.cache.len() >= self.capacity {
66 self.evict_lru();
67 }
68
69 self.cache.insert(query_hash, results);
71
72 let mut lru = self.lru.write();
74 lru.push_back(query_hash);
75 }
76
77 pub fn get(&self, query_hash: u64) -> Option<Vec<SearchResult>> {
79 self.cache.get(&query_hash).map(|v| v.clone())
80 }
81
82 fn evict_lru(&self) {
84 let mut lru = self.lru.write();
85 if let Some(key) = lru.pop_front() {
86 self.cache.remove(&key);
87 }
88 }
89
90 pub fn clear(&self) {
92 self.cache.clear();
93 self.lru.write().clear();
94 }
95
96 pub fn len(&self) -> usize {
98 self.cache.len()
99 }
100
101 pub fn is_empty(&self) -> bool {
103 self.cache.is_empty()
104 }
105}
106
impl Default for PrefetchCache {
    fn default() -> Self {
        // Default prefetch budget: 1000 cached queries.
        Self::new(1000)
    }
}
112
/// Learns `from -> to` transition frequencies between patterns and predicts
/// likely successors of a given pattern.
pub struct SequentialPatternTracker {
    // Per-pattern successor lists as (count, successor), sorted by descending
    // count; rebuilt lazily when the corresponding `cache_valid` flag is false.
    frequency_cache: DashMap<PatternId, Vec<(usize, PatternId)>>,
    // Raw transition counts keyed by (from, to).
    counts: DashMap<(PatternId, PatternId), usize>,
    // false (or absent) means `frequency_cache[pattern]` is stale.
    cache_valid: DashMap<PatternId, bool>,
    // Total number of transitions ever recorded.
    total_sequences: std::sync::atomic::AtomicUsize,
}
125
126impl SequentialPatternTracker {
127 pub fn new() -> Self {
129 Self {
130 frequency_cache: DashMap::new(),
131 counts: DashMap::new(),
132 cache_valid: DashMap::new(),
133 total_sequences: std::sync::atomic::AtomicUsize::new(0),
134 }
135 }
136
137 pub fn record_sequence(&self, from: PatternId, to: PatternId) {
139 *self.counts.entry((from, to)).or_insert(0) += 1;
141
142 self.cache_valid.insert(from, false);
144
145 self.total_sequences
147 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
148 }
149
150 pub fn predict_next(&self, current: PatternId, top_k: usize) -> Vec<PatternId> {
152 let cache_valid = self.cache_valid.get(¤t).map(|v| *v).unwrap_or(false);
154
155 if !cache_valid {
156 self.rebuild_cache(current);
158 }
159
160 if let Some(sorted) = self.frequency_cache.get(¤t) {
162 sorted.iter().take(top_k).map(|(_, id)| *id).collect()
163 } else {
164 Vec::new()
165 }
166 }
167
168 fn rebuild_cache(&self, pattern: PatternId) {
170 let mut freq_vec: Vec<(usize, PatternId)> = Vec::new();
171
172 for entry in self.counts.iter() {
174 let (from, to) = *entry.key();
175 if from == pattern {
176 freq_vec.push((*entry.value(), to));
177 }
178 }
179
180 freq_vec.sort_by(|a, b| b.0.cmp(&a.0));
182
183 self.frequency_cache.insert(pattern, freq_vec);
185 self.cache_valid.insert(pattern, true);
186 }
187
188 pub fn total_sequences(&self) -> usize {
190 self.total_sequences
191 .load(std::sync::atomic::Ordering::Relaxed)
192 }
193
194 pub fn prediction_confidence(&self, pattern: PatternId) -> f32 {
196 if let Some(sorted) = self.frequency_cache.get(&pattern) {
197 if sorted.is_empty() {
198 return 0.0;
199 }
200 let total: usize = sorted.iter().map(|(c, _)| c).sum();
201 if total == 0 {
202 return 0.0;
203 }
204 sorted[0].0 as f32 / total as f32
206 } else {
207 0.0
208 }
209 }
210
211 pub fn record_sequences_batch(&self, sequences: &[(PatternId, PatternId)]) {
213 let mut invalidated = std::collections::HashSet::new();
214
215 for (from, to) in sequences {
216 *self.counts.entry((*from, *to)).or_insert(0) += 1;
217 invalidated.insert(*from);
218 }
219
220 for pattern in invalidated {
222 self.cache_valid.insert(pattern, false);
223 }
224
225 self.total_sequences
226 .fetch_add(sequences.len(), std::sync::atomic::Ordering::Relaxed);
227 }
228}
229
impl Default for SequentialPatternTracker {
    fn default() -> Self {
        // `new` takes no configuration, so Default simply delegates to it.
        Self::new()
    }
}
235
236pub fn anticipate(
238 hints: &[AnticipationHint],
239 long_term: &LongTermStore,
240 causal_graph: &CausalGraph,
241 prefetch_cache: &PrefetchCache,
242 sequential_tracker: &SequentialPatternTracker,
243) -> usize {
244 let mut num_prefetched = 0;
245
246 for hint in hints {
247 match hint {
248 AnticipationHint::SequentialPattern { recent } => {
249 if let Some(&last) = recent.last() {
251 let predicted = sequential_tracker.predict_next(last, 5);
252
253 for pattern_id in predicted {
254 if let Some(temporal_pattern) = long_term.get(&pattern_id) {
255 let query =
257 Query::from_embedding(temporal_pattern.pattern.embedding.clone());
258 let query_hash = query.hash();
259
260 if prefetch_cache.get(query_hash).is_none() {
262 let results = long_term.search(&query);
263 prefetch_cache.insert(query_hash, results);
264 num_prefetched += 1;
265 }
266 }
267 }
268 }
269 }
270
271 AnticipationHint::TemporalCycle { phase } => {
272 let phase_ratio = match phase {
275 TemporalPhase::HourOfDay(h) => *h as f64 / 24.0,
276 TemporalPhase::DayOfWeek(d) => *d as f64 / 7.0,
277 TemporalPhase::Custom(c) => (*c as f64 % 1000.0) / 1000.0,
278 };
279
280 let dim = 32usize;
282 let query_vec: Vec<f32> = (0..dim)
283 .map(|i| {
284 let angle =
285 2.0 * std::f64::consts::PI * phase_ratio * (i + 1) as f64 / dim as f64;
286 angle.sin() as f32
287 })
288 .collect();
289
290 let query = Query::from_embedding(query_vec);
291 let query_hash = query.hash();
292
293 if prefetch_cache.get(query_hash).is_none() {
294 let results = long_term.search(&query);
295 if !results.is_empty() {
296 prefetch_cache.insert(query_hash, results);
297 num_prefetched += 1;
298 }
299 }
300 }
301
302 AnticipationHint::CausalChain { context } => {
303 let downstream = causal_graph.causal_future(*context);
305
306 for pattern_id in downstream.into_iter().take(5) {
307 if let Some(temporal_pattern) = long_term.get(&pattern_id) {
308 let query =
309 Query::from_embedding(temporal_pattern.pattern.embedding.clone());
310 let query_hash = query.hash();
311
312 if prefetch_cache.get(query_hash).is_none() {
314 let results = long_term.search(&query);
315 prefetch_cache.insert(query_hash, results);
316 num_prefetched += 1;
317 }
318 }
319 }
320 }
321 }
322 }
323
324 num_prefetched
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// A capacity-bounded cache evicts its oldest key first.
    #[test]
    fn test_prefetch_cache() {
        let cache = PrefetchCache::new(2);

        cache.insert(1, Vec::new());
        cache.insert(2, Vec::new());

        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_some());

        // A third distinct key exceeds capacity and pushes out the oldest (1).
        cache.insert(3, Vec::new());
        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_none());
    }

    /// The tracker ranks successors by observed frequency.
    #[test]
    fn test_sequential_tracker() {
        let tracker = SequentialPatternTracker::new();

        let a = PatternId::new();
        let b = PatternId::new();
        let c = PatternId::new();

        // a -> b twice, a -> c once: b must rank first.
        tracker.record_sequence(a, b);
        tracker.record_sequence(a, b);
        tracker.record_sequence(a, c);

        let predicted = tracker.predict_next(a, 2);
        assert_eq!(predicted.len(), 2);
        assert_eq!(predicted[0], b);

        assert_eq!(tracker.total_sequences(), 3);

        // b accounts for 2 of the 3 transitions out of a (~0.67).
        assert!(tracker.prediction_confidence(a) > 0.6);
    }

    /// Batch recording counts every transition and feeds predictions.
    #[test]
    fn test_batch_recording() {
        let tracker = SequentialPatternTracker::new();

        let a = PatternId::new();
        let b = PatternId::new();
        let c = PatternId::new();

        tracker.record_sequences_batch(&[(a, b), (a, b), (a, c), (b, c)]);

        assert_eq!(tracker.total_sequences(), 4);
        assert_eq!(tracker.predict_next(a, 1)[0], b);
    }
}