use super::cluster::EdgeCluster;
use parking_lot::RwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;
/// Key identifying one cached edge cluster: the node it belongs to plus the
/// edge direction it was built for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Node the cached cluster belongs to.
    pub node_id: i64,
    /// Edge direction of the cached cluster.
    pub direction: super::cluster_trace::Direction,
}

impl CacheKey {
    /// Builds the key for `node_id` in `direction`.
    pub fn new(node_id: i64, direction: super::cluster_trace::Direction) -> Self {
        CacheKey { direction, node_id }
    }
}
/// Sliding window of recently requested node ids, used to classify each new
/// request as a repeat visit ([`AccessType::Traversal`]) or a first visit
/// within the window ([`AccessType::Lookup`]).
#[derive(Debug, Clone)]
pub struct AccessPatternTracker {
    /// Most recent node ids, oldest at the front; never longer than `max_history`.
    access_history: VecDeque<i64>,
    /// Number of past accesses remembered.
    max_history: usize,
}

impl AccessPatternTracker {
    /// Creates a tracker that remembers the last `max_history` accesses.
    pub fn new(max_history: usize) -> Self {
        Self {
            access_history: VecDeque::with_capacity(max_history),
            max_history,
        }
    }

    /// Records an access to `node_id` and classifies it.
    ///
    /// Returns [`AccessType::Traversal`] when `node_id` is already among the
    /// last `max_history` recorded accesses, otherwise [`AccessType::Lookup`].
    /// With `max_history == 0` nothing is remembered and every access is a
    /// lookup.
    pub fn record_access(&mut self, node_id: i64) -> AccessType {
        // Classify *before* recording: searching after the push would always
        // find the id just added and report every access as a traversal.
        let access_type = if self.is_traversal_pattern(node_id) {
            AccessType::Traversal
        } else {
            AccessType::Lookup
        };
        self.access_history.push_back(node_id);
        // `pop_front` is O(1), unlike `Vec::remove(0)`, which shifts the whole window.
        while self.access_history.len() > self.max_history {
            self.access_history.pop_front();
        }
        access_type
    }

    /// Whether `node_id` is present in the current history window.
    fn is_traversal_pattern(&self, node_id: i64) -> bool {
        self.access_history.contains(&node_id)
    }
}

/// Classification of a cache access produced by [`AccessPatternTracker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessType {
    /// The node was requested again within the tracker's history window.
    Traversal,
    /// The node was not seen within the tracker's history window.
    Lookup,
}
/// A cached edge cluster plus the bookkeeping used to rank it for eviction.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// The cached cluster, handed out to callers as cheap `Arc` clones.
    pub data: Arc<EdgeCluster>,
    /// Number of recorded accesses; saturates at `u32::MAX` instead of overflowing.
    pub access_count: u32,
    /// Time of the most recent access, or of creation if never accessed.
    pub last_access: Instant,
    /// Accumulated access weight: +1.0 per traversal access, +0.1 per lookup.
    pub traversal_score: f64,
}

impl CacheEntry {
    /// Wraps `data` in a fresh entry with no recorded accesses.
    pub fn new(data: Arc<EdgeCluster>) -> Self {
        Self {
            data,
            access_count: 0,
            last_access: Instant::now(),
            traversal_score: 0.0,
        }
    }

    /// Credits the entry with one access of the given kind and refreshes its recency.
    pub fn record_access(&mut self, access_type: AccessType) {
        self.access_count = self.access_count.saturating_add(1);
        self.last_access = Instant::now();
        self.traversal_score += match access_type {
            AccessType::Traversal => 1.0,
            AccessType::Lookup => 0.1,
        };
    }

    /// Retention score; the cache evicts the entry with the lowest value.
    ///
    /// Combines `traversal_score * 10` with a recency term in `(0, 1]` that
    /// decays with seconds since `last_access`. Recency is measured from
    /// `last_access` (initialised at creation), so a freshly inserted entry is
    /// treated as recent rather than scoring 0 and becoming the next victim.
    pub fn eviction_score(&self) -> f64 {
        let recency_score = 1.0 / (self.last_access.elapsed().as_secs_f64() + 1.0);
        self.traversal_score * 10.0 + recency_score
    }

    /// Whether the cached cluster holds more than 100 edges.
    pub fn is_high_degree(&self) -> bool {
        self.data.edge_count() > 100
    }
}
/// Capacity-bounded map from [`CacheKey`] to edge clusters. When full, it evicts
/// the entry with the lowest access-weighted score.
pub struct TraversalAwareCache {
    /// Cached clusters, keyed by node and direction.
    entries: HashMap<CacheKey, CacheEntry>,
    /// Classifies each `get` as a traversal or a lookup.
    access_pattern: AccessPatternTracker,
    /// Upper bound on the number of retained entries.
    max_capacity: usize,
    /// Running hit/miss and access-type counters.
    stats: CacheStats,
}
/// Running counters maintained by the cache.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// `get` calls that found an entry.
    pub hits: u64,
    /// `get` calls that found no entry.
    pub misses: u64,
    /// `get` calls classified as traversals.
    pub traversals: u64,
    /// `get` calls classified as lookups.
    pub lookups: u64,
}
impl TraversalAwareCache {
pub fn new(max_capacity: usize) -> Self {
Self {
entries: HashMap::with_capacity(max_capacity),
access_pattern: AccessPatternTracker::new(100),
max_capacity,
stats: CacheStats::default(),
}
}
pub fn get(&mut self, key: CacheKey) -> Option<Arc<EdgeCluster>> {
let access_type = self.access_pattern.record_access(key.node_id);
match access_type {
AccessType::Traversal => self.stats.traversals += 1,
AccessType::Lookup => self.stats.lookups += 1,
}
if let Some(entry) = self.entries.get_mut(&key) {
self.stats.hits += 1;
entry.record_access(access_type);
return Some(Arc::clone(&entry.data));
}
self.stats.misses += 1;
None
}
pub fn insert(&mut self, key: CacheKey, cluster: Arc<EdgeCluster>) {
if let Some(entry) = self.entries.get_mut(&key) {
entry.data = cluster;
entry.record_access(AccessType::Lookup);
return;
}
if self.entries.len() >= self.max_capacity {
self.evict_one();
}
let entry = CacheEntry::new(cluster);
self.entries.insert(key, entry);
}
pub fn remove(&mut self, key: &CacheKey) -> Option<Arc<EdgeCluster>> {
self.entries.remove(key).map(|entry| entry.data)
}
fn evict_one(&mut self) {
if self.entries.is_empty() {
return;
}
let mut worst_key = None;
let mut worst_score = f64::MAX;
for (key, entry) in &self.entries {
let score = entry.eviction_score();
let adjusted_score = if entry.is_high_degree() {
score * 2.0 } else {
score
};
if adjusted_score < worst_score {
worst_score = adjusted_score;
worst_key = Some(*key);
}
}
if let Some(key) = worst_key {
self.entries.remove(&key);
}
}
pub fn clear(&mut self) {
self.entries.clear();
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn stats(&self) -> &CacheStats {
&self.stats
}
pub fn hit_ratio(&self) -> f64 {
let total = self.stats.hits + self.stats.misses;
if total == 0 {
0.0
} else {
self.stats.hits as f64 / total as f64
}
}
}
/// Shared, lock-protected handle to a [`TraversalAwareCache`].
///
/// Cloning is cheap: every clone shares the same underlying cache through an `Arc`.
#[derive(Clone)]
pub struct ThreadSafeCache {
    inner: Arc<RwLock<TraversalAwareCache>>,
}

impl ThreadSafeCache {
    /// Creates a shared cache holding at most `max_capacity` entries.
    pub fn new(max_capacity: usize) -> Self {
        Self {
            inner: Arc::new(RwLock::new(TraversalAwareCache::new(max_capacity))),
        }
    }

    /// Looks up `key`.
    ///
    /// Takes the write lock because a lookup updates statistics and access-pattern state.
    pub fn get(&self, key: CacheKey) -> Option<Arc<EdgeCluster>> {
        self.inner.write().get(key)
    }

    /// Stores `cluster` under `key`, evicting if the cache is full.
    pub fn insert(&self, key: CacheKey, cluster: Arc<EdgeCluster>) {
        self.inner.write().insert(key, cluster);
    }

    /// Removes and returns the cluster stored under `key`, if any.
    pub fn remove(&self, key: &CacheKey) -> Option<Arc<EdgeCluster>> {
        self.inner.write().remove(key)
    }

    /// Snapshot of the current statistics.
    pub fn stats(&self) -> CacheStats {
        self.inner.read().stats().clone()
    }

    /// Fraction of `get` calls that hit, or `0.0` before any call.
    pub fn hit_ratio(&self) -> f64 {
        self.inner.read().hit_ratio()
    }

    /// Direct access to the shared lock, for operations not wrapped here.
    pub fn inner(&self) -> &Arc<RwLock<TraversalAwareCache>> {
        &self.inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::native::v2::edge_cluster::cluster_trace::Direction;

    /// An empty outgoing cluster for node 1, shared so it can back several keys.
    fn empty_cluster() -> Arc<EdgeCluster> {
        let cluster =
            EdgeCluster::create_from_compact_edges(vec![], 1, Direction::Outgoing).unwrap();
        Arc::new(cluster)
    }

    /// Cache key for the outgoing edges of `node_id`.
    fn outgoing(node_id: i64) -> CacheKey {
        CacheKey::new(node_id, Direction::Outgoing)
    }

    #[test]
    fn test_cache_basics() {
        let mut cache = TraversalAwareCache::new(3);
        let (present, absent) = (outgoing(1), outgoing(2));
        let cluster = empty_cluster();

        cache.insert(present, Arc::clone(&cluster));

        assert!(cache.get(present).is_some());
        assert!(cache.get(absent).is_none());
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_eviction() {
        let mut cache = TraversalAwareCache::new(2);
        let cluster = empty_cluster();

        for node_id in 1..=3 {
            cache.insert(outgoing(node_id), Arc::clone(&cluster));
        }

        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_hit_ratio() {
        let mut cache = TraversalAwareCache::new(10);
        let hot = outgoing(1);
        let cluster = empty_cluster();
        cache.insert(hot, Arc::clone(&cluster));

        // Five hits on the cached key...
        for _ in 0..5 {
            let _ = cache.get(hot);
        }
        // ...then five misses on keys that were never inserted.
        for key in (2..7).map(outgoing) {
            let _ = cache.get(key);
        }

        assert!((cache.hit_ratio() - 0.5).abs() < 0.01);
    }
}