use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use crate::backend::BackendDirection;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BfsCacheKey {
pub start: i64,
pub depth: u32,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KHopCacheKey {
pub start: i64,
pub depth: u32,
pub direction: BackendDirection,
}
#[derive(Debug, Clone)]
pub struct KHopFilteredCacheKey {
pub start: i64,
pub depth: u32,
pub direction: BackendDirection,
pub allowed_edge_types: Vec<String>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShortestPathCacheKey {
pub start: i64,
pub end: i64,
}
#[derive(Debug, Clone)]
pub enum QueryCacheKey {
Bfs(BfsCacheKey),
KHop(KHopCacheKey),
KHopFiltered(KHopFilteredCacheKey),
ShortestPath(ShortestPathCacheKey),
}
#[derive(Debug, Clone)]
pub struct QueryCacheEntry {
pub result: QueryResult,
}
#[derive(Debug, Clone)]
pub enum QueryResult {
Bfs(Vec<i64>),
KHop(Vec<i64>),
ShortestPath(Option<Vec<i64>>),
}
impl QueryCacheKey {
pub fn hash(&self) -> u64 {
let mut hasher = ahash::AHasher::default();
match self {
QueryCacheKey::Bfs(key) => {
0u8.hash(&mut hasher);
key.start.hash(&mut hasher);
key.depth.hash(&mut hasher);
}
QueryCacheKey::KHop(key) => {
1u8.hash(&mut hasher);
key.start.hash(&mut hasher);
key.depth.hash(&mut hasher);
(match key.direction {
BackendDirection::Outgoing => 0u8,
BackendDirection::Incoming => 1u8,
})
.hash(&mut hasher);
}
QueryCacheKey::KHopFiltered(key) => {
2u8.hash(&mut hasher);
key.start.hash(&mut hasher);
key.depth.hash(&mut hasher);
(match key.direction {
BackendDirection::Outgoing => 0u8,
BackendDirection::Incoming => 1u8,
})
.hash(&mut hasher);
key.allowed_edge_types.len().hash(&mut hasher);
for edge_type in &key.allowed_edge_types {
edge_type.hash(&mut hasher);
}
}
QueryCacheKey::ShortestPath(key) => {
3u8.hash(&mut hasher);
key.start.hash(&mut hasher);
key.end.hash(&mut hasher);
}
}
hasher.finish()
}
}
impl PartialEq for QueryCacheKey {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(QueryCacheKey::Bfs(a), QueryCacheKey::Bfs(b)) => a == b,
(QueryCacheKey::KHop(a), QueryCacheKey::KHop(b)) => a == b,
(QueryCacheKey::KHopFiltered(a), QueryCacheKey::KHopFiltered(b)) => {
a.start == b.start
&& a.depth == b.depth
&& a.direction == b.direction
&& a.allowed_edge_types == b.allowed_edge_types
}
(QueryCacheKey::ShortestPath(a), QueryCacheKey::ShortestPath(b)) => a == b,
_ => false,
}
}
}
impl Eq for QueryCacheKey {}
impl Hash for QueryCacheKey {
fn hash<H: Hasher>(&self, state: &mut H) {
self.hash().hash(state);
}
}
impl PartialEq for KHopFilteredCacheKey {
fn eq(&self, other: &Self) -> bool {
self.start == other.start
&& self.depth == other.depth
&& self.direction == other.direction
&& self.allowed_edge_types == other.allowed_edge_types
}
}
impl Eq for KHopFilteredCacheKey {}
impl Hash for KHopFilteredCacheKey {
fn hash<H: Hasher>(&self, state: &mut H) {
self.start.hash(state);
self.depth.hash(state);
(match self.direction {
BackendDirection::Outgoing => 0u8,
BackendDirection::Incoming => 1u8,
})
.hash(state);
self.allowed_edge_types.len().hash(state);
for edge_type in &self.allowed_edge_types {
edge_type.hash(state);
}
}
}
#[derive(Debug)]
pub struct QueryCache {
cache: Arc<RwLock<HashMap<QueryCacheKey, QueryCacheEntry>>>,
}
impl QueryCache {
pub fn new() -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn get_bfs(&self, start: i64, depth: u32) -> Option<Vec<i64>> {
let key = QueryCacheKey::Bfs(BfsCacheKey { start, depth });
let cache = match self.cache.read() {
Ok(cache) => cache,
Err(poisoned) => {
eprintln!(
"WARNING: Query cache read lock poisoned in get_bfs operation (start={}, depth={}). Treating as cache miss.",
start, depth
);
poisoned.into_inner()
}
};
cache.get(&key).and_then(|entry| match &entry.result {
QueryResult::Bfs(result) => Some(result.clone()),
_ => None,
})
}
pub fn put_bfs(&self, start: i64, depth: u32, result: Vec<i64>) {
let key = QueryCacheKey::Bfs(BfsCacheKey { start, depth });
let entry = QueryCacheEntry {
result: QueryResult::Bfs(result),
};
match self.cache.write() {
Ok(mut cache) => {
cache.insert(key, entry);
}
Err(poisoned) => {
eprintln!(
"WARNING: Query cache write lock poisoned in put_bfs operation (start={}, depth={}). Recovering and continuing.",
start, depth
);
let mut cache = poisoned.into_inner();
cache.insert(key, entry);
}
}
}
pub fn get_k_hop(
&self,
start: i64,
depth: u32,
direction: BackendDirection,
) -> Option<Vec<i64>> {
let key = QueryCacheKey::KHop(KHopCacheKey {
start,
depth,
direction,
});
let cache = match self.cache.read() {
Ok(cache) => cache,
Err(poisoned) => {
eprintln!(
"WARNING: Query cache read lock poisoned in get_k_hop operation (start={}, depth={}, direction={:?}). Treating as cache miss.",
start, depth, direction
);
poisoned.into_inner()
}
};
cache.get(&key).and_then(|entry| match &entry.result {
QueryResult::KHop(result) => Some(result.clone()),
_ => None,
})
}
pub fn put_k_hop(&self, start: i64, depth: u32, direction: BackendDirection, result: Vec<i64>) {
let key = QueryCacheKey::KHop(KHopCacheKey {
start,
depth,
direction,
});
let entry = QueryCacheEntry {
result: QueryResult::KHop(result),
};
match self.cache.write() {
Ok(mut cache) => {
cache.insert(key, entry);
}
Err(poisoned) => {
eprintln!(
"WARNING: Query cache write lock poisoned in put_k_hop operation (start={}, depth={}, direction={:?}). Recovering and continuing.",
start, depth, direction
);
let mut cache = poisoned.into_inner();
cache.insert(key, entry);
}
}
}
pub fn get_k_hop_filtered(
&self,
start: i64,
depth: u32,
direction: BackendDirection,
allowed_edge_types: &[&str],
) -> Option<Vec<i64>> {
let edge_types = allowed_edge_types.iter().map(|s| s.to_string()).collect();
let key = QueryCacheKey::KHopFiltered(KHopFilteredCacheKey {
start,
depth,
direction,
allowed_edge_types: edge_types,
});
let cache = match self.cache.read() {
Ok(cache) => cache,
Err(poisoned) => {
eprintln!(
"WARNING: Query cache read lock poisoned in get_k_hop_filtered operation (start={}, depth={}, direction={:?}, edge_types={:?}). Treating as cache miss.",
start, depth, direction, allowed_edge_types
);
poisoned.into_inner()
}
};
cache.get(&key).and_then(|entry| match &entry.result {
QueryResult::KHop(result) => Some(result.clone()),
_ => None,
})
}
pub fn put_k_hop_filtered(
&self,
start: i64,
depth: u32,
direction: BackendDirection,
allowed_edge_types: &[&str],
result: Vec<i64>,
) {
let edge_types = allowed_edge_types.iter().map(|s| s.to_string()).collect();
let key = QueryCacheKey::KHopFiltered(KHopFilteredCacheKey {
start,
depth,
direction,
allowed_edge_types: edge_types,
});
let entry = QueryCacheEntry {
result: QueryResult::KHop(result),
};
match self.cache.write() {
Ok(mut cache) => {
cache.insert(key, entry);
}
Err(poisoned) => {
eprintln!(
"WARNING: Query cache write lock poisoned in put_k_hop_filtered operation (start={}, depth={}, direction={:?}, edge_types={:?}). Recovering and continuing.",
start, depth, direction, allowed_edge_types
);
let mut cache = poisoned.into_inner();
cache.insert(key, entry);
}
}
}
pub fn get_shortest_path(&self, start: i64, end: i64) -> Option<Option<Vec<i64>>> {
let key = QueryCacheKey::ShortestPath(ShortestPathCacheKey { start, end });
let cache = match self.cache.read() {
Ok(cache) => cache,
Err(poisoned) => {
eprintln!(
"WARNING: Query cache read lock poisoned in get_shortest_path operation (start={}, end={}). Treating as cache miss.",
start, end
);
poisoned.into_inner()
}
};
cache.get(&key).and_then(|entry| match &entry.result {
QueryResult::ShortestPath(result) => Some(result.clone()),
_ => None,
})
}
pub fn put_shortest_path(&self, start: i64, end: i64, result: Option<Vec<i64>>) {
let key = QueryCacheKey::ShortestPath(ShortestPathCacheKey { start, end });
let entry = QueryCacheEntry {
result: QueryResult::ShortestPath(result),
};
match self.cache.write() {
Ok(mut cache) => {
cache.insert(key, entry);
}
Err(poisoned) => {
eprintln!(
"WARNING: Query cache write lock poisoned in put_shortest_path operation (start={}, end={}). Recovering and continuing.",
start, end
);
let mut cache = poisoned.into_inner();
cache.insert(key, entry);
}
}
}
pub fn invalidate_all(&self) {
match self.cache.write() {
Ok(mut cache) => {
cache.clear();
}
Err(poisoned) => {
eprintln!(
"WARNING: Query cache write lock poisoned in invalidate_all operation. Recovering and continuing."
);
let mut cache = poisoned.into_inner();
cache.clear();
}
}
}
pub fn size(&self) -> usize {
match self.cache.read() {
Ok(cache) => cache.len(),
Err(poisoned) => {
eprintln!(
"WARNING: Query cache read lock poisoned in size operation. Treating as empty cache."
);
poisoned.into_inner().len()
}
}
}
pub fn is_empty(&self) -> bool {
match self.cache.read() {
Ok(cache) => cache.is_empty(),
Err(poisoned) => {
eprintln!(
"WARNING: Query cache read lock poisoned in is_empty operation. Treating as empty cache."
);
poisoned.into_inner().is_empty()
}
}
}
}
impl Default for QueryCache {
fn default() -> Self {
Self::new()
}
}
impl Clone for QueryCache {
fn clone(&self) -> Self {
Self {
cache: Arc::clone(&self.cache),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cache_key_hashing() {
let key1 = QueryCacheKey::Bfs(BfsCacheKey {
start: 42,
depth: 3,
});
let key2 = QueryCacheKey::Bfs(BfsCacheKey {
start: 42,
depth: 3,
});
assert_eq!(key1.hash(), key2.hash());
let key3 = QueryCacheKey::Bfs(BfsCacheKey {
start: 42,
depth: 4,
});
assert_ne!(key1.hash(), key3.hash());
}
#[test]
fn test_cache_basic_operations() {
let cache = QueryCache::new();
assert_eq!(cache.get_bfs(1, 2), None);
cache.put_bfs(1, 2, vec![3, 4, 5]);
assert_eq!(cache.get_bfs(1, 2), Some(vec![3, 4, 5]));
assert_eq!(cache.size(), 1);
assert!(!cache.is_empty());
cache.invalidate_all();
assert_eq!(cache.get_bfs(1, 2), None);
assert_eq!(cache.size(), 0);
assert!(cache.is_empty());
}
#[test]
fn test_k_hop_filtered_cache() {
let cache = QueryCache::new();
let edge_types = vec!["friend", "colleague"];
assert_eq!(
cache.get_k_hop_filtered(1, 2, BackendDirection::Outgoing, &edge_types),
None
);
cache.put_k_hop_filtered(1, 2, BackendDirection::Outgoing, &edge_types, vec![3, 4]);
assert_eq!(
cache.get_k_hop_filtered(1, 2, BackendDirection::Outgoing, &edge_types),
Some(vec![3, 4])
);
assert_eq!(
cache.get_k_hop_filtered(1, 2, BackendDirection::Outgoing, &["enemy"]),
None
);
}
#[test]
fn test_shortest_path_cache() {
let cache = QueryCache::new();
cache.put_shortest_path(1, 5, None);
assert_eq!(cache.get_shortest_path(1, 5), Some(None));
cache.put_shortest_path(1, 3, Some(vec![1, 2, 3]));
assert_eq!(cache.get_shortest_path(1, 3), Some(Some(vec![1, 2, 3])));
assert_eq!(cache.size(), 2);
}
}