use crate::graph::VertexId;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::RwLock;
/// Tunable parameters for a [`PathDistanceCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of distance entries kept; inserts beyond this evict
    /// the oldest queued entries.
    pub max_entries: usize,
    /// When true, lookups are recorded and used to predict future queries.
    pub enable_prefetch: bool,
    /// Upper bound on how many recent queries are retained for prediction.
    pub prefetch_history_size: usize,
    /// Maximum number of predicted queries produced per prediction pass.
    pub prefetch_lookahead: usize,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 10_000,
enable_prefetch: true,
prefetch_history_size: 100,
prefetch_lookahead: 4,
}
}
}
/// Snapshot of cache activity counters, produced by [`PathDistanceCache::stats`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found no cached entry.
    pub misses: u64,
    /// Number of entries currently stored.
    pub size: usize,
    /// Hits whose entry had been inserted speculatively (prefetched).
    pub prefetch_hits: u64,
    /// Entries removed to stay within `max_entries`.
    pub evictions: u64,
}
impl CacheStats {
    /// Fraction of lookups served from the cache, in `[0.0, 1.0]`.
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
/// A suggestion that distances from `source` to `targets` are likely to be
/// queried soon, derived from the recorded query history.
#[derive(Debug, Clone)]
pub struct PrefetchHint {
    /// Vertex the predicted queries originate from.
    pub source: VertexId,
    /// Vertices predicted to be queried against `source`.
    pub targets: Vec<VertexId>,
    /// Heuristic confidence in `[0.0, 1.0]`: co-occurrence count relative
    /// to the history length.
    pub confidence: f64,
}
/// A cached point-to-point distance plus bookkeeping metadata.
#[derive(Debug, Clone)]
struct CacheEntry {
    // Original (un-canonicalized) endpoints. NOTE(review): no visible code
    // reads these back — presumably kept for diagnostics; confirm before
    // relying on them.
    source: VertexId,
    target: VertexId,
    /// Cached distance between the two endpoints.
    distance: f64,
    // Logical timestamp at insertion time. NOTE(review): never read by the
    // visible eviction logic, which evicts in queue order instead.
    last_access: u64,
    /// True when this entry was inserted speculatively via `insert_prefetch`.
    prefetched: bool,
}
/// Canonical, order-insensitive key for a vertex pair: `CacheKey::new`
/// always stores the smaller id in `source`, so (a, b) and (b, a) map to
/// the same entry.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
struct CacheKey {
    source: VertexId,
    target: VertexId,
}
impl CacheKey {
    /// Builds the canonical key for an unordered vertex pair; the smaller
    /// id always lands in `source`, making lookups symmetric.
    fn new(a: VertexId, b: VertexId) -> Self {
        let (source, target) = if a <= b { (a, b) } else { (b, a) };
        Self { source, target }
    }
}
/// Thread-safe cache of pairwise path distances with queue-based eviction
/// and optional query-pattern prefetch prediction.
pub struct PathDistanceCache {
    config: CacheConfig,
    /// Canonical pair key -> cached entry.
    cache: RwLock<HashMap<CacheKey, CacheEntry>>,
    /// Eviction queue: oldest keys at the front are evicted first.
    lru_order: RwLock<VecDeque<CacheKey>>,
    /// Monotonic logical clock used to timestamp entries.
    access_counter: AtomicU64,
    hits: AtomicU64,
    misses: AtomicU64,
    prefetch_hits: AtomicU64,
    evictions: AtomicU64,
    /// Recent queries, bounded by `config.prefetch_history_size`.
    query_history: RwLock<VecDeque<CacheKey>>,
    /// Most recently computed query predictions.
    predicted_queries: RwLock<Vec<CacheKey>>,
}
impl PathDistanceCache {
    /// Creates a cache with the default configuration.
    pub fn new() -> Self {
        Self::with_config(CacheConfig::default())
    }

    /// Creates an empty cache using `config`.
    pub fn with_config(config: CacheConfig) -> Self {
        Self {
            config,
            cache: RwLock::new(HashMap::new()),
            lru_order: RwLock::new(VecDeque::new()),
            access_counter: AtomicU64::new(0),
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            prefetch_hits: AtomicU64::new(0),
            evictions: AtomicU64::new(0),
            query_history: RwLock::new(VecDeque::new()),
            predicted_queries: RwLock::new(Vec::new()),
        }
    }

    /// Looks up the cached distance between `source` and `target`.
    ///
    /// The pair is order-insensitive: `get(a, b)` and `get(b, a)` hit the
    /// same entry. Updates hit/miss counters and, when prefetching is
    /// enabled, records the query for pattern prediction.
    pub fn get(&self, source: VertexId, target: VertexId) -> Option<f64> {
        let key = CacheKey::new(source, target);
        // Copy what we need out of the entry and drop the read lock before
        // doing prediction bookkeeping, so the map lock is never held
        // while other locks are taken.
        let found = {
            let cache = self.cache.read().unwrap();
            cache.get(&key).map(|e| (e.distance, e.prefetched))
        };
        match found {
            Some((distance, prefetched)) => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                if prefetched {
                    self.prefetch_hits.fetch_add(1, Ordering::Relaxed);
                }
                if self.config.enable_prefetch {
                    self.record_query(key);
                }
                Some(distance)
            }
            None => {
                self.misses.fetch_add(1, Ordering::Relaxed);
                if self.config.enable_prefetch {
                    self.record_query(key);
                }
                None
            }
        }
    }

    /// Caches `distance` for the (order-insensitive) vertex pair.
    pub fn insert(&self, source: VertexId, target: VertexId, distance: f64) {
        let key = CacheKey::new(source, target);
        let entry = self.make_entry(source, target, distance, false);
        self.insert_entry(key, entry);
    }

    /// Caches a speculatively computed distance. Hits on such entries are
    /// additionally counted as prefetch hits.
    pub fn insert_prefetch(&self, source: VertexId, target: VertexId, distance: f64) {
        let key = CacheKey::new(source, target);
        let entry = self.make_entry(source, target, distance, true);
        self.insert_entry(key, entry);
    }

    /// Inserts many distances while acquiring the locks only once.
    pub fn insert_batch(&self, entries: &[(VertexId, VertexId, f64)]) {
        let mut cache = self.cache.write().unwrap();
        let mut lru = self.lru_order.write().unwrap();
        for &(source, target, distance) in entries {
            let key = CacheKey::new(source, target);
            let entry = self.make_entry(source, target, distance, false);
            Self::insert_locked(
                &mut cache,
                &mut lru,
                &self.evictions,
                self.config.max_entries,
                key,
                entry,
            );
        }
    }

    /// Builds an entry stamped with the next logical timestamp.
    fn make_entry(
        &self,
        source: VertexId,
        target: VertexId,
        distance: f64,
        prefetched: bool,
    ) -> CacheEntry {
        CacheEntry {
            source,
            target,
            distance,
            last_access: self.access_counter.fetch_add(1, Ordering::Relaxed),
            prefetched,
        }
    }

    /// Single-entry insert: takes both locks, then delegates to the shared
    /// locked insert path (also used by `insert_batch`).
    fn insert_entry(&self, key: CacheKey, entry: CacheEntry) {
        let mut cache = self.cache.write().unwrap();
        let mut lru = self.lru_order.write().unwrap();
        Self::insert_locked(
            &mut cache,
            &mut lru,
            &self.evictions,
            self.config.max_entries,
            key,
            entry,
        );
    }

    /// Core insert, called with both write locks held.
    ///
    /// BUG FIX: the previous code unconditionally pushed `key` onto the
    /// eviction queue, so re-inserting an existing pair left duplicate
    /// queue entries — the queue grew without bound and a later
    /// `pop_front` could evict a pair that had just been refreshed.
    /// Now an existing pair only has its queue position refreshed, and
    /// eviction runs solely for genuinely new keys.
    fn insert_locked(
        cache: &mut HashMap<CacheKey, CacheEntry>,
        lru: &mut VecDeque<CacheKey>,
        evictions: &AtomicU64,
        max_entries: usize,
        key: CacheKey,
        entry: CacheEntry,
    ) {
        if cache.insert(key, entry).is_some() {
            // Existing pair: move it to the back (most recent) instead of
            // duplicating it in the queue.
            if let Some(pos) = lru.iter().position(|k| *k == key) {
                lru.remove(pos);
            }
            lru.push_back(key);
            return;
        }
        // New pair: shrink back to capacity, evicting the oldest keys.
        while cache.len() > max_entries {
            match lru.pop_front() {
                Some(victim) => {
                    // Only count evictions that actually removed an entry.
                    if cache.remove(&victim).is_some() {
                        evictions.fetch_add(1, Ordering::Relaxed);
                    }
                }
                None => break,
            }
        }
        lru.push_back(key);
    }

    /// Removes every cached pair that touches `vertex` (e.g. after the
    /// graph changed around it).
    pub fn invalidate_vertex(&self, vertex: VertexId) {
        let mut cache = self.cache.write().unwrap();
        let mut lru = self.lru_order.write().unwrap();
        // One retain pass per structure instead of one O(queue) scan per
        // removed key, which was quadratic when many pairs touched `vertex`.
        cache.retain(|k, _| k.source != vertex && k.target != vertex);
        lru.retain(|k| k.source != vertex && k.target != vertex);
    }

    /// Drops all entries; the statistics counters are left untouched.
    pub fn clear(&self) {
        let mut cache = self.cache.write().unwrap();
        let mut lru = self.lru_order.write().unwrap();
        cache.clear();
        lru.clear();
    }

    /// Appends `key` to the bounded query history. Uses `try_write` so the
    /// lookup fast path never blocks on the history lock (a contended call
    /// simply skips recording). Every 10th recorded length triggers a
    /// prediction refresh.
    fn record_query(&self, key: CacheKey) {
        if let Ok(mut history) = self.query_history.try_write() {
            history.push_back(key);
            while history.len() > self.config.prefetch_history_size {
                history.pop_front();
            }
            if history.len() % 10 == 0 {
                self.update_predictions(&history);
            }
        }
    }

    /// Recomputes `predicted_queries` from the recent history.
    ///
    /// Heuristic: pair the sources of the 5 most recent queries with any
    /// vertex appearing more than twice in the history, capped at
    /// `prefetch_lookahead` predictions. Does nothing until at least 10
    /// queries have been recorded.
    fn update_predictions(&self, history: &VecDeque<CacheKey>) {
        if history.len() < 10 {
            return;
        }
        // Count how often each vertex appears anywhere in the history.
        let mut vertex_frequency: HashMap<VertexId, usize> = HashMap::new();
        for key in history {
            *vertex_frequency.entry(key.source).or_insert(0) += 1;
            *vertex_frequency.entry(key.target).or_insert(0) += 1;
        }
        let mut predictions = Vec::new();
        'recent: for key in history.iter().rev().take(5) {
            for (&vertex, &freq) in &vertex_frequency {
                if freq > 2 && vertex != key.source && vertex != key.target {
                    predictions.push(CacheKey::new(key.source, vertex));
                    if predictions.len() >= self.config.prefetch_lookahead {
                        break 'recent;
                    }
                }
            }
        }
        // try_write: skip this refresh rather than stall a concurrent reader.
        if let Ok(mut pred) = self.predicted_queries.try_write() {
            *pred = predictions;
        }
    }

    /// Derives prefetch hints from the query history: vertices that
    /// co-occurred in more than two recorded queries become a hint, with
    /// confidence scaled by the history length.
    pub fn get_prefetch_hints(&self) -> Vec<PrefetchHint> {
        let history = self.query_history.read().unwrap();
        if history.is_empty() {
            return Vec::new();
        }
        // Undirected adjacency: each query contributes both directions.
        let mut neighbors: HashMap<VertexId, Vec<VertexId>> = HashMap::new();
        for key in history.iter() {
            neighbors.entry(key.source).or_default().push(key.target);
            neighbors.entry(key.target).or_default().push(key.source);
        }
        neighbors
            .into_iter()
            .filter(|(_, targets)| targets.len() > 2)
            .map(|(source, targets)| {
                let confidence = (targets.len() as f64 / history.len() as f64).min(1.0);
                PrefetchHint {
                    source,
                    targets,
                    confidence,
                }
            })
            .collect()
    }

    /// Returns the most recently predicted `(source, target)` queries.
    pub fn get_predicted_queries(&self) -> Vec<(VertexId, VertexId)> {
        let pred = self.predicted_queries.read().unwrap();
        pred.iter().map(|key| (key.source, key.target)).collect()
    }

    /// Snapshot of the current statistics counters.
    pub fn stats(&self) -> CacheStats {
        CacheStats {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            size: self.cache.read().unwrap().len(),
            prefetch_hits: self.prefetch_hits.load(Ordering::Relaxed),
            evictions: self.evictions.load(Ordering::Relaxed),
        }
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.cache.read().unwrap().len()
    }

    /// True when no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.cache.read().unwrap().is_empty()
    }
}
impl Default for PathDistanceCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_cache_operations() {
        let cache = PathDistanceCache::new();
        cache.insert(1, 2, 10.0);
        // Lookups are symmetric in the endpoints.
        assert_eq!(cache.get(1, 2), Some(10.0));
        assert_eq!(cache.get(2, 1), Some(10.0));
        assert_eq!(cache.get(1, 3), None);
    }

    #[test]
    fn test_lru_eviction() {
        let config = CacheConfig {
            max_entries: 3,
            ..Default::default()
        };
        let cache = PathDistanceCache::with_config(config);
        cache.insert(1, 2, 1.0);
        cache.insert(2, 3, 2.0);
        cache.insert(3, 4, 3.0);
        assert_eq!(cache.len(), 3);
        // A fourth distinct pair must push out the oldest one.
        cache.insert(4, 5, 4.0);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(1, 2), None);
        assert_eq!(cache.get(4, 5), Some(4.0));
    }

    #[test]
    fn test_batch_insert() {
        let cache = PathDistanceCache::new();
        let batch = vec![(1, 2, 1.0), (2, 3, 2.0), (3, 4, 3.0)];
        cache.insert_batch(&batch);
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(1, 2), Some(1.0));
        assert_eq!(cache.get(2, 3), Some(2.0));
        assert_eq!(cache.get(3, 4), Some(3.0));
    }

    #[test]
    fn test_invalidate_vertex() {
        let cache = PathDistanceCache::new();
        cache.insert(1, 2, 1.0);
        cache.insert(1, 3, 2.0);
        cache.insert(2, 3, 3.0);
        // Dropping vertex 1 removes only the pairs touching it.
        cache.invalidate_vertex(1);
        assert_eq!(cache.get(1, 2), None);
        assert_eq!(cache.get(1, 3), None);
        assert_eq!(cache.get(2, 3), Some(3.0));
    }

    #[test]
    fn test_statistics() {
        let cache = PathDistanceCache::new();
        cache.insert(1, 2, 1.0);
        cache.get(1, 2);
        cache.get(1, 2);
        cache.get(3, 4);
        let stats = cache.stats();
        // Two hits on the cached pair, one miss on the unknown pair.
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
        assert!(stats.hit_rate() > 0.5);
    }

    #[test]
    fn test_prefetch_hints() {
        let config = CacheConfig {
            enable_prefetch: true,
            prefetch_history_size: 50,
            ..Default::default()
        };
        let cache = PathDistanceCache::with_config(config);
        for i in 0..20 {
            cache.insert(1, i as u64, i as f64);
            let _ = cache.get(1, i as u64);
        }
        let hints = cache.get_prefetch_hints();
        assert!(!hints.is_empty() || cache.stats().hits > 0);
    }

    #[test]
    fn test_clear() {
        let cache = PathDistanceCache::new();
        cache.insert(1, 2, 1.0);
        cache.insert(2, 3, 2.0);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }
}