use super::lru::LruCache;
use crate::types::{ETypeId, Edge, NodeId, TraversalCacheConfig};
use std::collections::{HashMap, HashSet};
/// Which side of a node's edges a cached traversal covers.
///
/// The direction is part of the cache key, so `Out` and `In` lookups for the
/// same node and edge type are stored as separate entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalDirection {
/// Outgoing traversal; the neighbor of each cached edge is its `dst`.
Out,
/// Incoming traversal; the neighbor of each cached edge is its `src`.
In,
}
/// Packed cache key. `TraversalCache::make_key` builds it by OR-ing together
/// the node id (shifted left by 11), the edge-type value (shifted left by 1)
/// and a direction bit in bit 0.
type TraversalKey = u64;
/// Edge-type value written into a key when a lookup passes `etype: None`,
/// i.e. when the traversal is not filtered by edge type.
// NOTE(review): a real `ETypeId` equal to 0x3FF would produce the same key as
// `None` — confirm that this value is reserved.
const ALL_ETYPES: u64 = 0x3FF;
/// Neighbor list stored in the cache for one (node, edge type, direction) key.
#[derive(Debug, Clone)]
pub struct CachedNeighbors {
/// The cached edges, holding at most the cache's `max_neighbors_per_entry` items.
pub neighbors: Vec<Edge>,
/// `true` when the list handed to `TraversalCache::set` was longer than the
/// per-entry limit and only its leading edges were kept.
pub truncated: bool,
}
/// Snapshot of the traversal cache's counters and occupancy.
#[derive(Debug, Clone, Default)]
pub struct TraversalCacheStats {
    /// Number of `get` calls that found an entry.
    pub hits: u64,
    /// Number of `get` calls that found nothing.
    pub misses: u64,
    /// Number of entries held by the cache when the snapshot was taken.
    pub cache_size: usize,
    /// Maximum number of entries the cache can hold.
    pub max_cache_size: usize,
}

impl TraversalCacheStats {
    /// Returns the fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, so callers never
    /// observe a `NaN` from a division by zero.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding: `hits + misses` in `u64` would panic in debug
        // builds (and wrap in release builds) once the counters get close to
        // `u64::MAX`. For all non-overflowing inputs the result is unchanged.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
/// Cache of per-node neighbor lists backed by an `LruCache`, with hit/miss
/// accounting and a reverse index used for invalidation.
pub struct TraversalCache {
/// Cached neighbor lists, keyed by the packed (node, edge type, direction) key.
cache: LruCache<TraversalKey, CachedNeighbors>,
/// Upper bound on the number of edges stored per entry; longer lists are cut
/// down to this length and flagged as truncated.
max_neighbors_per_entry: usize,
/// Reverse index: for each node id, the keys of entries that were stored for
/// that node or that list it as a neighbor. `invalidate_node` and
/// `invalidate_edge` use it to find the entries to delete.
node_key_index: HashMap<NodeId, HashSet<TraversalKey>>,
/// Number of `get` calls that found an entry.
hits: u64,
/// Number of `get` calls that found nothing.
misses: u64,
}
impl TraversalCache {
    /// Creates an empty cache.
    ///
    /// `config.max_entries` is handed to the underlying `LruCache` as its
    /// capacity and `config.max_neighbors_per_entry` caps the length of every
    /// stored neighbor list.
    pub fn new(config: TraversalCacheConfig) -> Self {
        Self {
            cache: LruCache::new(config.max_entries),
            max_neighbors_per_entry: config.max_neighbors_per_entry,
            node_key_index: HashMap::new(),
            hits: 0,
            misses: 0,
        }
    }

    /// Looks up the cached neighbors of `node_id` for the given edge type
    /// (`None` = all edge types) and direction.
    ///
    /// The lookup goes through `LruCache::get` and is counted as a hit or a
    /// miss in the statistics.
    pub fn get(
        &mut self,
        node_id: NodeId,
        etype: Option<ETypeId>,
        direction: TraversalDirection,
    ) -> Option<&CachedNeighbors> {
        let key = Self::make_key(node_id, etype, direction);
        let result = self.cache.get(&key);
        if result.is_some() {
            self.hits += 1;
        } else {
            self.misses += 1;
        }
        result
    }

    /// Same lookup as [`get`](Self::get), but via `LruCache::peek` and without
    /// touching the hit/miss counters.
    pub fn peek(
        &self,
        node_id: NodeId,
        etype: Option<ETypeId>,
        direction: TraversalDirection,
    ) -> Option<&CachedNeighbors> {
        let key = Self::make_key(node_id, etype, direction);
        self.cache.peek(&key)
    }

    /// Stores `neighbors` as the traversal result for `node_id`, replacing any
    /// entry already stored under the same key.
    ///
    /// Lists longer than `max_neighbors_per_entry` are cut down to that length
    /// and stored with `truncated = true`. The owning node and the far endpoint
    /// of every stored edge are linked to the entry in the reverse index so
    /// that invalidating any of those nodes drops the entry.
    pub fn set(
        &mut self,
        node_id: NodeId,
        etype: Option<ETypeId>,
        direction: TraversalDirection,
        mut neighbors: Vec<Edge>,
    ) {
        let key = Self::make_key(node_id, etype, direction);

        // When an entry is being replaced, unlink the neighbors of the old list
        // first. Otherwise nodes that are no longer part of the entry would
        // keep pointing at it forever (index growth and needless invalidation).
        if let Some(previous) = self.cache.peek(&key) {
            for edge in &previous.neighbors {
                Self::remove_from_node_index(
                    &mut self.node_key_index,
                    Self::far_endpoint(edge, direction),
                    key,
                );
            }
        }

        let truncated = neighbors.len() > self.max_neighbors_per_entry;
        if truncated {
            neighbors.truncate(self.max_neighbors_per_entry);
            // The cap exists to bound memory, so do not keep the oversized
            // allocation alive inside the cache.
            neighbors.shrink_to_fit();
        }

        self.add_to_node_index(node_id, key);
        for edge in &neighbors {
            self.add_to_node_index(Self::far_endpoint(edge, direction), key);
        }

        self.cache.set(
            key,
            CachedNeighbors {
                neighbors,
                truncated,
            },
        );

        // Inserting may have evicted older entries from the LRU; their index
        // links are cleaned up lazily once enough of them have piled up.
        self.prune_node_index_if_oversized();
    }

    /// Deletes every cached entry linked to `node_id`: the entries stored for
    /// the node itself and the entries that list it as a neighbor.
    pub fn invalidate_node(&mut self, node_id: NodeId) {
        if let Some(keys) = self.node_key_index.remove(&node_id) {
            for key in keys {
                self.cache.delete(&key);
            }
        }
    }

    /// Deletes the entries that can contain the edge `src -[etype]-> dst`:
    /// the outgoing traversals of `src` and the incoming traversals of `dst`,
    /// each for both `etype` and the all-edge-types variant.
    pub fn invalidate_edge(&mut self, src: NodeId, etype: ETypeId, dst: NodeId) {
        self.invalidate_node_traversals(src, TraversalDirection::Out, etype);
        self.invalidate_node_traversals(dst, TraversalDirection::In, etype);
    }

    /// Removes every entry and index link and resets the hit/miss counters.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.node_key_index.clear();
        self.reset_stats();
    }

    /// Returns a snapshot of the counters together with the current and
    /// maximum number of entries.
    pub fn stats(&self) -> TraversalCacheStats {
        TraversalCacheStats {
            hits: self.hits,
            misses: self.misses,
            cache_size: self.cache.len(),
            max_cache_size: self.cache.max_size(),
        }
    }

    /// Number of entries currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` when no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Resets the hit/miss counters without touching the cached entries.
    pub fn reset_stats(&mut self) {
        self.hits = 0;
        self.misses = 0;
    }

    /// Packs node id, edge type and direction into a single key:
    /// bit 0 holds the direction, the bits from 1 upwards the edge type
    /// (`ALL_ETYPES` when `etype` is `None`) and the bits from 11 upwards the
    /// node id.
    // NOTE(review): the fields are OR-ed together without masking, so an edge
    // type >= 0x3FF or a node id that needs more than 53 bits would produce
    // keys that collide with other lookups — confirm the id ranges used by
    // callers stay below those limits.
    fn make_key(
        node_id: NodeId,
        etype: Option<ETypeId>,
        direction: TraversalDirection,
    ) -> TraversalKey {
        const NODE_SHIFT: u32 = 11;
        const ETYPE_SHIFT: u32 = 1;

        let etype_bits = etype.map_or(ALL_ETYPES, |e| e as u64);
        let direction_bit = match direction {
            TraversalDirection::Out => 0u64,
            TraversalDirection::In => 1u64,
        };
        (node_id << NODE_SHIFT) | (etype_bits << ETYPE_SHIFT) | direction_bit
    }

    /// Returns the endpoint of `edge` that is the neighbor for a traversal in
    /// `direction`: `dst` for outgoing traversals, `src` for incoming ones.
    fn far_endpoint(edge: &Edge, direction: TraversalDirection) -> NodeId {
        match direction {
            TraversalDirection::Out => edge.dst,
            TraversalDirection::In => edge.src,
        }
    }

    /// Links `key` to `node_id` in the reverse index.
    fn add_to_node_index(&mut self, node_id: NodeId, key: TraversalKey) {
        self.node_key_index.entry(node_id).or_default().insert(key);
    }

    /// Unlinks `key` from `node_id`, dropping the node's slot once it is empty.
    ///
    /// Takes the index instead of `&mut self` so it can be called while an
    /// entry borrowed from `self.cache` is still being read.
    fn remove_from_node_index(
        index: &mut HashMap<NodeId, HashSet<TraversalKey>>,
        node_id: NodeId,
        key: TraversalKey,
    ) {
        if let Some(keys) = index.get_mut(&node_id) {
            keys.remove(&key);
            if keys.is_empty() {
                index.remove(&node_id);
            }
        }
    }

    /// Drops index links whose entry is no longer in the cache, but only once
    /// the index has clearly outgrown what the live entries can account for.
    ///
    /// A live entry links at most `max_neighbors_per_entry + 1` nodes (its
    /// neighbors plus the owning node). The sweep runs when the index holds
    /// more than twice that bound, which keeps its cost amortized over many
    /// insertions. Only links to keys that are absent from the cache are
    /// removed, so no live entry can become unreachable for invalidation.
    fn prune_node_index_if_oversized(&mut self) {
        let live_bound = self
            .cache
            .max_size()
            .saturating_mul(self.max_neighbors_per_entry.saturating_add(1));
        if self.node_key_index.len() <= live_bound.saturating_mul(2) {
            return;
        }

        let cache = &self.cache;
        self.node_key_index.retain(|_, keys| {
            keys.retain(|key| cache.peek(key).is_some());
            !keys.is_empty()
        });
    }

    /// Deletes the `etype`-specific and the all-edge-types entry of `node_id`
    /// for `direction`, provided they are linked to the node in the index.
    fn invalidate_node_traversals(
        &mut self,
        node_id: NodeId,
        direction: TraversalDirection,
        etype: ETypeId,
    ) {
        let Some(keys) = self.node_key_index.get_mut(&node_id) else {
            return;
        };

        let candidates = [
            Self::make_key(node_id, Some(etype), direction),
            Self::make_key(node_id, None, direction),
        ];
        for key in candidates {
            // `remove` reports whether the key was linked, which replaces a
            // separate `contains` pass.
            if keys.remove(&key) {
                self.cache.delete(&key);
            }
        }

        if keys.is_empty() {
            self.node_key_index.remove(&node_id);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::TraversalDirection::{In, Out};
    use super::*;

    /// Cache with room for 100 entries of up to 10 neighbors each.
    fn make_cache() -> TraversalCache {
        let config = TraversalCacheConfig {
            max_entries: 100,
            max_neighbors_per_entry: 10,
        };
        TraversalCache::new(config)
    }

    /// Shorthand for building an edge.
    fn make_edge(src: NodeId, etype: ETypeId, dst: NodeId) -> Edge {
        Edge { src, etype, dst }
    }

    // A fresh cache holds nothing and has zeroed counters.
    #[test]
    fn test_new_cache() {
        let cache = make_cache();
        assert!(cache.is_empty());
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    // Looking up an unknown key counts as a miss.
    #[test]
    fn test_cache_miss() {
        let mut cache = make_cache();
        let looked_up = cache.get(1, Some(1), Out);
        assert!(looked_up.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    // A stored entry is returned intact and counted as a hit.
    #[test]
    fn test_cache_hit() {
        let mut cache = make_cache();
        let edges = vec![make_edge(1, 1, 2), make_edge(1, 1, 3)];
        cache.set(1, Some(1), Out, edges.clone());

        let looked_up = cache.get(1, Some(1), Out);
        assert!(looked_up.is_some());
        let entry = looked_up.expect("expected value");
        assert_eq!(entry.neighbors.len(), 2);
        assert!(!entry.truncated);
        assert_eq!(cache.stats().hits, 1);
    }

    // The all-edge-types entry is distinct from any specific edge type.
    #[test]
    fn test_cache_all_etypes() {
        let mut cache = make_cache();
        cache.set(1, None, Out, vec![make_edge(1, 1, 2), make_edge(1, 2, 3)]);

        assert!(cache.get(1, None, Out).is_some());
        assert!(cache.get(1, Some(1), Out).is_none());
    }

    // Outgoing and incoming traversals of one node are separate entries.
    #[test]
    fn test_different_directions() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);
        cache.set(1, Some(1), In, vec![make_edge(3, 1, 1)]);

        let outgoing = cache.get(1, Some(1), Out);
        assert!(outgoing.is_some());
        assert_eq!(outgoing.expect("expected value").neighbors[0].dst, 2);

        let incoming = cache.get(1, Some(1), In);
        assert!(incoming.is_some());
        assert_eq!(incoming.expect("expected value").neighbors[0].src, 3);
    }

    // Lists longer than the per-entry limit are cut and flagged.
    #[test]
    fn test_truncation() {
        let mut cache = TraversalCache::new(TraversalCacheConfig {
            max_entries: 100,
            max_neighbors_per_entry: 3,
        });
        let five_edges = vec![
            make_edge(1, 1, 2),
            make_edge(1, 1, 3),
            make_edge(1, 1, 4),
            make_edge(1, 1, 5),
            make_edge(1, 1, 6),
        ];
        cache.set(1, Some(1), Out, five_edges);

        let entry = cache.get(1, Some(1), Out).expect("expected value");
        assert_eq!(entry.neighbors.len(), 3);
        assert!(entry.truncated);
    }

    // Invalidating a node drops its own entries and leaves others alone.
    #[test]
    fn test_invalidate_node() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);
        cache.set(2, Some(1), Out, vec![make_edge(2, 1, 3)]);
        assert_eq!(cache.len(), 2);

        cache.invalidate_node(1);
        assert!(cache.get(1, Some(1), Out).is_none());

        let hits_before = cache.stats().hits;
        assert!(cache.get(2, Some(1), Out).is_some());
        assert_eq!(cache.stats().hits, hits_before + 1);
    }

    // Invalidating a node also drops entries that list it as a neighbor.
    #[test]
    fn test_invalidate_node_as_destination() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2), make_edge(1, 1, 3)]);

        cache.invalidate_node(2);
        assert!(cache.get(1, Some(1), Out).is_none());
    }

    // Invalidating an edge drops src's outgoing and dst's incoming entries only.
    #[test]
    fn test_invalidate_edge() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);
        cache.set(2, Some(1), In, vec![make_edge(1, 1, 2)]);
        cache.set(3, Some(1), Out, vec![make_edge(3, 1, 4)]);

        cache.invalidate_edge(1, 1, 2);

        assert!(cache.peek(1, Some(1), Out).is_none());
        assert!(cache.peek(2, Some(1), In).is_none());
        assert!(cache.peek(3, Some(1), Out).is_some());
    }

    // Edge invalidation also hits the all-edge-types entry.
    #[test]
    fn test_invalidate_edge_all_etypes() {
        let mut cache = make_cache();
        cache.set(1, None, Out, vec![make_edge(1, 1, 2), make_edge(1, 2, 3)]);
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);

        cache.invalidate_edge(1, 1, 2);

        assert!(cache.peek(1, Some(1), Out).is_none());
        assert!(cache.peek(1, None, Out).is_none());
    }

    // `clear` empties the cache and zeroes the counters.
    #[test]
    fn test_clear() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);
        cache.get(1, Some(1), Out);
        cache.get(999, Some(1), Out);

        cache.clear();

        assert!(cache.is_empty());
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    // Peeking at the oldest entry does not save it from eviction.
    #[test]
    fn test_peek_does_not_update_lru() {
        let mut cache = TraversalCache::new(TraversalCacheConfig {
            max_entries: 2,
            max_neighbors_per_entry: 10,
        });
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 10)]);
        cache.set(2, Some(1), Out, vec![make_edge(2, 1, 20)]);

        cache.peek(1, Some(1), Out);
        cache.set(3, Some(1), Out, vec![make_edge(3, 1, 30)]);

        assert!(cache.peek(1, Some(1), Out).is_none());
        assert!(cache.peek(2, Some(1), Out).is_some());
        assert!(cache.peek(3, Some(1), Out).is_some());
    }

    // Reading the oldest entry with `get` makes the other one the eviction victim.
    #[test]
    fn test_updates_lru() {
        let mut cache = TraversalCache::new(TraversalCacheConfig {
            max_entries: 2,
            max_neighbors_per_entry: 10,
        });
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 10)]);
        cache.set(2, Some(1), Out, vec![make_edge(2, 1, 20)]);

        cache.get(1, Some(1), Out);
        cache.set(3, Some(1), Out, vec![make_edge(3, 1, 30)]);

        assert!(cache.peek(1, Some(1), Out).is_some());
        assert!(cache.peek(2, Some(1), Out).is_none());
        assert!(cache.peek(3, Some(1), Out).is_some());
    }

    // Two hits and two misses give a 50% hit rate.
    #[test]
    fn test_stats() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);

        cache.get(1, Some(1), Out);
        cache.get(1, Some(1), Out);
        cache.get(999, Some(1), Out);
        cache.get(1, Some(2), Out);

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 2);
        assert_eq!(snapshot.misses, 2);
        assert_eq!(snapshot.hit_rate(), 0.5);
    }

    // Resetting the counters keeps the cached entries.
    #[test]
    fn test_reset_stats() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);
        cache.get(1, Some(1), Out);
        cache.get(999, Some(1), Out);

        cache.reset_stats();

        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
        assert_eq!(cache.len(), 1);
    }

    // Node, edge type and direction each contribute to the key.
    #[test]
    fn test_key_uniqueness() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 100)]);
        cache.set(2, Some(1), Out, vec![make_edge(2, 1, 200)]);
        cache.set(1, Some(2), Out, vec![make_edge(1, 2, 300)]);
        cache.set(1, Some(1), In, vec![make_edge(100, 1, 1)]);
        assert_eq!(cache.len(), 4);

        let node1_type1_out = cache.peek(1, Some(1), Out).expect("expected value");
        assert_eq!(node1_type1_out.neighbors[0].dst, 100);

        let node2_type1_out = cache.peek(2, Some(1), Out).expect("expected value");
        assert_eq!(node2_type1_out.neighbors[0].dst, 200);

        let node1_type2_out = cache.peek(1, Some(2), Out).expect("expected value");
        assert_eq!(node1_type2_out.neighbors[0].dst, 300);

        let node1_type1_in = cache.peek(1, Some(1), In).expect("expected value");
        assert_eq!(node1_type1_in.neighbors[0].src, 100);
    }

    // An empty neighbor list is a valid, non-truncated entry.
    #[test]
    fn test_empty_neighbors() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![]);

        let looked_up = cache.get(1, Some(1), Out);
        assert!(looked_up.is_some());
        let entry = looked_up.expect("expected value");
        assert!(entry.neighbors.is_empty());
        assert!(!entry.truncated);
    }

    // Invalidating an unknown node leaves existing entries untouched.
    #[test]
    fn test_invalidate_nonexistent_node() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);

        cache.invalidate_node(999);
        assert!(cache.peek(1, Some(1), Out).is_some());
    }

    // Invalidating an unknown edge leaves existing entries untouched.
    #[test]
    fn test_invalidate_nonexistent_edge() {
        let mut cache = make_cache();
        cache.set(1, Some(1), Out, vec![make_edge(1, 1, 2)]);

        cache.invalidate_edge(999, 1, 888);
        assert!(cache.peek(1, Some(1), Out).is_some());
    }
}