use crate::causal::CausalGraph;
use crate::long_term::LongTermStore;
use crate::types::{PatternId, Query, SearchResult};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::VecDeque;
use std::sync::Arc;
/// Context from which likely upcoming queries can be derived and prefetched
/// by [`anticipate`].
#[derive(Clone, Debug)]
pub enum AnticipationHint {
    /// Recently observed patterns; `anticipate` predicts successors of the
    /// last element.
    SequentialPattern { recent: Vec<PatternId> },
    /// A position within a recurring cycle, turned into a synthetic query.
    TemporalCycle { phase: TemporalPhase },
    /// A pattern whose causal successors (per the causal graph) are prefetched.
    CausalChain { context: PatternId },
}
/// A position within a periodic cycle, used to build time-based prefetch
/// queries.
#[derive(Clone, Copy, Debug)]
pub enum TemporalPhase {
    /// Hour of the day; `anticipate` maps it to the cycle ratio `h / 24`.
    HourOfDay(u8),
    /// Day of the week; `anticipate` maps it to the cycle ratio `d / 7`.
    DayOfWeek(u8),
    /// Arbitrary phase value; `anticipate` uses `(c mod 1000) / 1000`.
    Custom(u32),
}
/// Bounded cache of prefetched search results, keyed by a query hash.
///
/// Eviction is insertion-ordered: [`PrefetchCache::get`] does not refresh an
/// entry, so once `capacity` is reached the entry inserted (or re-inserted)
/// longest ago is dropped first.
pub struct PrefetchCache {
    /// Cached results keyed by query hash.
    cache: DashMap<u64, Vec<SearchResult>>,
    /// Maximum number of retained entries; `0` disables caching entirely.
    capacity: usize,
    /// Keys in eviction order (front = next to evict). Each key appears at most
    /// once. Its write lock also serialises `insert`/`clear`, so this queue and
    /// `cache` are updated together.
    lru: Arc<RwLock<VecDeque<u64>>>,
}

impl PrefetchCache {
    /// Creates an empty cache that holds at most `capacity` entries.
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: DashMap::new(),
            capacity,
            lru: Arc::new(RwLock::new(VecDeque::with_capacity(capacity))),
        }
    }

    /// Stores `results` under `query_hash`, evicting the oldest entries if the
    /// cache would otherwise exceed its capacity.
    ///
    /// Re-inserting an existing key replaces its results and moves it to the
    /// back of the eviction queue. It does not add a duplicate queue slot, and
    /// it never evicts another entry. With a capacity of `0` this is a no-op.
    pub fn insert(&self, query_hash: u64, results: Vec<SearchResult>) {
        if self.capacity == 0 {
            return;
        }
        // Hold the queue lock for the whole update so two concurrent inserts
        // cannot both pass the capacity check and overfill the cache.
        let mut lru = self.lru.write();
        if self.cache.insert(query_hash, results).is_some() {
            // Existing key: drop its old queue slot so it is evicted only once.
            if let Some(pos) = lru.iter().position(|&key| key == query_hash) {
                lru.remove(pos);
            }
        } else {
            // New key (not yet queued): evict from the front until within capacity.
            while self.cache.len() > self.capacity {
                if !self.evict_lru(&mut lru) {
                    break;
                }
            }
        }
        lru.push_back(query_hash);
    }

    /// Returns a clone of the results cached for `query_hash`, if any.
    ///
    /// Does not affect eviction order.
    pub fn get(&self, query_hash: u64) -> Option<Vec<SearchResult>> {
        self.cache
            .get(&query_hash)
            .map(|entry| entry.value().clone())
    }

    /// Removes the key at the front of `lru` from the cache.
    ///
    /// Returns `false` if the queue was empty. The caller must already hold the
    /// `lru` write lock, which is why the guard's contents are passed in.
    fn evict_lru(&self, lru: &mut VecDeque<u64>) -> bool {
        match lru.pop_front() {
            Some(key) => {
                self.cache.remove(&key);
                true
            }
            None => false,
        }
    }

    /// Removes every cached entry.
    pub fn clear(&self) {
        // Same lock order as `insert` (queue, then cache) keeps the two in sync.
        let mut lru = self.lru.write();
        self.cache.clear();
        lru.clear();
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// `true` when no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }
}

impl Default for PrefetchCache {
    /// A cache holding up to 1000 entries.
    fn default() -> Self {
        Self::new(1000)
    }
}
/// Learns first-order transition counts between patterns and predicts the
/// most frequent successors of a given pattern.
///
/// Sorted successor lists are cached per source pattern. They are rebuilt
/// lazily the next time they are queried after new transitions from that
/// pattern have been recorded.
pub struct SequentialPatternTracker {
    /// Successors of each source pattern as `(count, to)`, sorted by descending count.
    frequency_cache: DashMap<PatternId, Vec<(usize, PatternId)>>,
    /// Transition counts grouped by source pattern (`from -> to -> count`), so
    /// rebuilding one pattern's list only touches that pattern's successors.
    counts: DashMap<PatternId, std::collections::HashMap<PatternId, usize>>,
    /// `false` (or absent) when `frequency_cache` for that pattern may be out of date.
    cache_valid: DashMap<PatternId, bool>,
    /// Total number of transitions recorded through either `record_*` method.
    total_sequences: std::sync::atomic::AtomicUsize,
}

impl SequentialPatternTracker {
    /// Creates a tracker with no recorded transitions.
    pub fn new() -> Self {
        Self {
            frequency_cache: DashMap::new(),
            counts: DashMap::new(),
            cache_valid: DashMap::new(),
            total_sequences: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Records one observed transition `from -> to`.
    pub fn record_sequence(&self, from: PatternId, to: PatternId) {
        self.bump_count(from, to);
        self.cache_valid.insert(from, false);
        self.total_sequences
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns up to `top_k` successors of `current`, most frequent first.
    ///
    /// Successors with equal counts come back in an unspecified order. Returns
    /// an empty vector if no transitions from `current` have been recorded.
    pub fn predict_next(&self, current: PatternId, top_k: usize) -> Vec<PatternId> {
        self.refresh_if_stale(current);
        self.frequency_cache
            .get(&current)
            .map(|sorted| {
                sorted
                    .iter()
                    .take(top_k)
                    .map(|&(_, id)| id)
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default()
    }

    /// Total number of transitions recorded so far.
    pub fn total_sequences(&self) -> usize {
        self.total_sequences
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Fraction of `pattern`'s recorded transitions that went to its most
    /// frequent successor, in `[0, 1]`. Returns `0.0` if there are none.
    ///
    /// The cached successor list is refreshed first, so the result reflects
    /// every transition recorded so far, even if `predict_next` was never called.
    pub fn prediction_confidence(&self, pattern: PatternId) -> f32 {
        self.refresh_if_stale(pattern);
        match self.frequency_cache.get(&pattern) {
            Some(sorted) => {
                let total: usize = sorted.iter().map(|&(count, _)| count).sum();
                match sorted.first() {
                    Some(&(top, _)) if total > 0 => top as f32 / total as f32,
                    _ => 0.0,
                }
            }
            None => 0.0,
        }
    }

    /// Records many transitions at once.
    ///
    /// Each distinct source pattern is invalidated only once.
    pub fn record_sequences_batch(&self, sequences: &[(PatternId, PatternId)]) {
        let mut touched = std::collections::HashSet::new();
        for &(from, to) in sequences {
            self.bump_count(from, to);
            touched.insert(from);
        }
        for pattern in touched {
            self.cache_valid.insert(pattern, false);
        }
        self.total_sequences
            .fetch_add(sequences.len(), std::sync::atomic::Ordering::Relaxed);
    }

    /// Increments the count for the transition `from -> to`.
    fn bump_count(&self, from: PatternId, to: PatternId) {
        *self
            .counts
            .entry(from)
            .or_insert_with(std::collections::HashMap::new)
            .entry(to)
            .or_insert(0) += 1;
    }

    /// Rebuilds `pattern`'s cached successor list if it is marked (or defaults to) stale.
    fn refresh_if_stale(&self, pattern: PatternId) {
        let valid = self.cache_valid.get(&pattern).map_or(false, |flag| *flag);
        if !valid {
            self.rebuild_cache(pattern);
        }
    }

    /// Recomputes `pattern`'s successor list from `counts`, sorted by descending count.
    fn rebuild_cache(&self, pattern: PatternId) {
        // Mark valid *before* reading counts. If a concurrent `record_*` bumps
        // a count after our read, its `false` write then lands after this
        // `true` and the cache is re-invalidated, instead of being silently
        // overwritten. NOTE(review): two rebuilds racing each other can still
        // publish an older snapshot last.
        self.cache_valid.insert(pattern, true);
        let mut freq_vec: Vec<(usize, PatternId)> = match self.counts.get(&pattern) {
            Some(successors) => successors
                .iter()
                .map(|(&to, &count)| (count, to))
                .collect(),
            None => Vec::new(),
        };
        freq_vec.sort_by(|a, b| b.0.cmp(&a.0));
        self.frequency_cache.insert(pattern, freq_vec);
    }
}

impl Default for SequentialPatternTracker {
    fn default() -> Self {
        Self::new()
    }
}
pub fn anticipate(
hints: &[AnticipationHint],
long_term: &LongTermStore,
causal_graph: &CausalGraph,
prefetch_cache: &PrefetchCache,
sequential_tracker: &SequentialPatternTracker,
) -> usize {
let mut num_prefetched = 0;
for hint in hints {
match hint {
AnticipationHint::SequentialPattern { recent } => {
if let Some(&last) = recent.last() {
let predicted = sequential_tracker.predict_next(last, 5);
for pattern_id in predicted {
if let Some(temporal_pattern) = long_term.get(&pattern_id) {
let query =
Query::from_embedding(temporal_pattern.pattern.embedding.clone());
let query_hash = query.hash();
if prefetch_cache.get(query_hash).is_none() {
let results = long_term.search(&query);
prefetch_cache.insert(query_hash, results);
num_prefetched += 1;
}
}
}
}
}
AnticipationHint::TemporalCycle { phase } => {
let phase_ratio = match phase {
TemporalPhase::HourOfDay(h) => *h as f64 / 24.0,
TemporalPhase::DayOfWeek(d) => *d as f64 / 7.0,
TemporalPhase::Custom(c) => (*c as f64 % 1000.0) / 1000.0,
};
let dim = 32usize;
let query_vec: Vec<f32> = (0..dim)
.map(|i| {
let angle =
2.0 * std::f64::consts::PI * phase_ratio * (i + 1) as f64 / dim as f64;
angle.sin() as f32
})
.collect();
let query = Query::from_embedding(query_vec);
let query_hash = query.hash();
if prefetch_cache.get(query_hash).is_none() {
let results = long_term.search(&query);
if !results.is_empty() {
prefetch_cache.insert(query_hash, results);
num_prefetched += 1;
}
}
}
AnticipationHint::CausalChain { context } => {
let downstream = causal_graph.causal_future(*context);
for pattern_id in downstream.into_iter().take(5) {
if let Some(temporal_pattern) = long_term.get(&pattern_id) {
let query =
Query::from_embedding(temporal_pattern.pattern.embedding.clone());
let query_hash = query.hash();
if prefetch_cache.get(query_hash).is_none() {
let results = long_term.search(&query);
prefetch_cache.insert(query_hash, results);
num_prefetched += 1;
}
}
}
}
}
}
num_prefetched
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefetch_cache() {
        let cache = PrefetchCache::new(2);
        cache.insert(1, Vec::new());
        cache.insert(2, Vec::new());
        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_some());

        // Inserting past capacity drops the oldest key; the `get` above does
        // not refresh key 1, so it is the one evicted.
        cache.insert(3, Vec::new());
        assert_eq!(cache.len(), 2);
        assert!(cache.get(1).is_none());
    }

    #[test]
    fn test_sequential_tracker() {
        let tracker = SequentialPatternTracker::new();
        let (p1, p2, p3) = (PatternId::new(), PatternId::new(), PatternId::new());
        for &(from, to) in &[(p1, p2), (p1, p2), (p1, p3)] {
            tracker.record_sequence(from, to);
        }

        let next = tracker.predict_next(p1, 2);
        assert_eq!(next.len(), 2);
        assert_eq!(next[0], p2);
        assert_eq!(tracker.total_sequences(), 3);

        // p2 accounts for two of p1's three recorded transitions (~0.67).
        let confidence = tracker.prediction_confidence(p1);
        assert!(confidence > 0.6);
    }

    #[test]
    fn test_batch_recording() {
        let tracker = SequentialPatternTracker::new();
        let (p1, p2, p3) = (PatternId::new(), PatternId::new(), PatternId::new());
        let batch = vec![(p1, p2), (p1, p2), (p1, p3), (p2, p3)];
        tracker.record_sequences_batch(&batch);
        assert_eq!(tracker.total_sequences(), 4);

        let best = tracker.predict_next(p1, 1);
        assert_eq!(best[0], p2);
    }
}