use parking_lot::Mutex;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::time::{Duration, Instant};
use super::planner::ExecutionPlan;
use super::query::ResultRow;
/// Expiry rule attached to a cached result.
///
/// [`CachedResult::is_expired`] consults this value to decide whether an
/// entry may still be served. All fields are `Eq`, so the type derives `Eq`
/// in addition to `PartialEq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CachePolicy {
    /// The entry never expires by age.
    Permanent,
    /// The entry expires once at least `ttl` has elapsed since its
    /// `inserted_at` instant.
    TimeBound {
        /// Maximum age before the entry is reported as expired.
        ttl: Duration,
    },
}
/// The default policy is time-bound with a five second TTL.
impl Default for CachePolicy {
    fn default() -> Self {
        let ttl = Duration::from_secs(5);
        Self::TimeBound { ttl }
    }
}
/// Identity of a cached result.
///
/// Two keys are equal only when both the plan hash and the capability version
/// match, so a result cached under one capability version is never found
/// through a key carrying another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// 64-bit hash identifying the execution plan.
    pub plan_hash: u64,
    /// Caller-supplied capability version this key is scoped to.
    pub capability_version: u64,
}
impl CacheKey {
    /// Derives the cache key for `plan` under `capability_version`.
    ///
    /// The plan is serialized with `postcard` and the encoded bytes are hashed
    /// with the standard library's `DefaultHasher`. Returns `None` when the
    /// plan cannot be serialized.
    ///
    /// NOTE(review): the standard library does not specify `DefaultHasher`'s
    /// algorithm across releases, so `plan_hash` should be treated as an
    /// in-process identifier rather than something to persist. Distinct plans
    /// can also in principle collide on the 64-bit hash.
    pub fn for_plan(plan: &ExecutionPlan, capability_version: u64) -> Option<Self> {
        use std::collections::hash_map::DefaultHasher;

        let encoded = postcard::to_allocvec(plan).ok()?;
        let plan_hash = {
            let mut state = DefaultHasher::new();
            encoded.hash(&mut state);
            state.finish()
        };
        Some(Self {
            plan_hash,
            capability_version,
        })
    }
}
/// A cached set of result rows plus the bookkeeping used to expire and size it.
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// The cached rows.
    pub rows: Vec<ResultRow>,
    /// Instant from which the entry's age is measured.
    pub inserted_at: Instant,
    /// Expiry rule evaluated by `is_expired`.
    pub policy: CachePolicy,
    /// Size estimate captured once by `CachedResult::new`; it is not
    /// recomputed if `rows` is modified afterwards.
    approx_bytes: u64,
}
impl CachedResult {
    /// Wraps `rows` as a cache entry whose age is measured from `inserted_at`
    /// and whose expiry follows `policy`.
    ///
    /// The approximate size of `rows` is computed here, once, and stored.
    #[inline]
    pub fn new(rows: Vec<ResultRow>, inserted_at: Instant, policy: CachePolicy) -> Self {
        Self {
            // Evaluated before `rows` is moved into the struct below.
            approx_bytes: compute_approx_bytes(&rows),
            rows,
            inserted_at,
            policy,
        }
    }

    /// Reports whether the entry has outlived its policy.
    ///
    /// `Permanent` entries never expire; `TimeBound` entries expire once the
    /// time elapsed since `inserted_at` reaches the TTL.
    pub fn is_expired(&self) -> bool {
        let ttl = match self.policy {
            CachePolicy::Permanent => return false,
            CachePolicy::TimeBound { ttl } => ttl,
        };
        self.inserted_at.elapsed() >= ttl
    }

    /// Size estimate recorded when the entry was constructed.
    #[inline]
    fn approx_bytes(&self) -> u64 {
        self.approx_bytes
    }
}
/// Estimates the footprint of `rows`: for each row, its payload length plus
/// `size_of::<ResultRow>()`.
fn compute_approx_bytes(rows: &[ResultRow]) -> u64 {
    let per_row = std::mem::size_of::<ResultRow>() as u64;
    let mut total = 0u64;
    for row in rows {
        total += row.payload.len() as u64 + per_row;
    }
    total
}
/// Interface of a result cache that can be shared across threads
/// (implementors must be `Send + Sync`; every method takes `&self`).
pub trait ResultCache: Send + Sync {
    /// Looks up `key`, returning an owned result on a hit and `None` on a miss.
    fn get(&self, key: &CacheKey) -> Option<CachedResult>;

    /// Offers `result` to the cache under `key`.
    fn insert(&self, key: CacheKey, result: CachedResult);

    /// Drops every cached entry.
    fn invalidate_all(&self);

    /// Number of entries currently held.
    fn len(&self) -> usize;

    /// `true` when the cache holds no entries; defaults to `len() == 0`.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Entry-count bound used by `LruResultCache::default()`: 1024 entries.
pub const LRU_MAX_ENTRIES: usize = 1 << 10;
/// Approximate-byte bound used by `LruResultCache::default()`: 256 MiB.
pub const LRU_MAX_BYTES: u64 = 256 << 20;
/// `ResultCache` implementation with least-recently-used eviction, bounded by
/// both an entry count and an approximate byte total.
///
/// All state sits behind a single mutex, which is what lets every operation
/// take `&self`.
pub struct LruResultCache {
    /// Key index, recency list and byte accounting, guarded by one lock.
    inner: Mutex<LruInner>,
}
/// Lock-protected state of `LruResultCache`: a doubly linked recency list
/// whose nodes live in a `Vec` and refer to each other by index.
struct LruInner {
    /// Maps each live key to the index of its node in `nodes`.
    by_key: HashMap<CacheKey, usize>,
    /// Node storage; slots of removed entries are recycled through `free`.
    nodes: Vec<LruNode>,
    /// Index of the most recently used node, if any.
    head: Option<usize>,
    /// Index of the least recently used node (the next eviction victim), if any.
    tail: Option<usize>,
    /// Indices of vacated slots in `nodes` that `alloc_node` may reuse.
    free: Vec<usize>,
    /// Running sum of the `bytes` of all live nodes.
    total_bytes: u64,
    /// Maximum number of live entries.
    max_entries: usize,
    /// Maximum value `total_bytes` may reach before eviction kicks in.
    max_bytes: u64,
}
/// One slot of the recency list stored in `LruInner::nodes`.
struct LruNode {
    /// Key this node is indexed under; kept so the index entry can be removed
    /// when the node is dropped.
    key: CacheKey,
    /// The cached result itself.
    value: CachedResult,
    /// Index of the neighbour towards the head (more recently used), if any.
    prev: Option<usize>,
    /// Index of the neighbour towards the tail (less recently used), if any.
    next: Option<usize>,
    /// Size charged against the cache's byte budget for this node.
    bytes: u64,
}
impl Default for LruResultCache {
fn default() -> Self {
Self::new(LRU_MAX_ENTRIES, LRU_MAX_BYTES)
}
}
impl LruResultCache {
    /// Creates an empty cache that keeps at most `max_entries` entries and at
    /// most `max_bytes` of approximate result size.
    pub fn new(max_entries: usize, max_bytes: u64) -> Self {
        let state = LruInner {
            max_entries,
            max_bytes,
            total_bytes: 0,
            head: None,
            tail: None,
            by_key: HashMap::new(),
            nodes: Vec::new(),
            free: Vec::new(),
        };
        Self {
            inner: Mutex::new(state),
        }
    }
}
impl ResultCache for LruResultCache {
    /// Returns a clone of the entry stored under `key` and marks it as the
    /// most recently used one.
    ///
    /// An entry that reports itself expired is removed on the spot and the
    /// lookup is a miss.
    fn get(&self, key: &CacheKey) -> Option<CachedResult> {
        let mut g = self.inner.lock();
        let idx = *g.by_key.get(key)?;
        if g.nodes[idx].value.is_expired() {
            g.detach_and_drop(idx);
            return None;
        }
        g.move_to_head(idx);
        Some(g.nodes[idx].value.clone())
    }

    /// Stores `result` under `key` as the most recently used entry, then
    /// evicts from the least recently used end until both bounds hold again.
    ///
    /// A result whose approximate size exceeds the cache's byte bound is
    /// refused outright instead of flushing the cache to make room. Any entry
    /// already stored under the same `key` is removed in that case, so a
    /// later `get` never serves the value this insert was meant to replace.
    fn insert(&self, key: CacheKey, result: CachedResult) {
        let mut g = self.inner.lock();
        let bytes = result.approx_bytes();
        if bytes > g.max_bytes {
            // The caller has a newer result for this key; the one we hold is
            // superseded even though the replacement cannot be stored.
            if let Some(idx) = g.by_key.get(&key).copied() {
                g.detach_and_drop(idx);
            }
            return;
        }
        if let Some(&idx) = g.by_key.get(&key) {
            // Replace in place: swap the byte accounting from the old value to
            // the new one and refresh the entry's recency.
            let old_bytes = g.nodes[idx].bytes;
            g.total_bytes = g.total_bytes.saturating_sub(old_bytes);
            g.nodes[idx].value = result;
            g.nodes[idx].bytes = bytes;
            g.total_bytes = g.total_bytes.saturating_add(bytes);
            g.move_to_head(idx);
            g.evict_until_within_bounds();
            return;
        }
        // New key: link a fresh node in front of the current head.
        let prev_head = g.head;
        let idx = g.alloc_node(LruNode {
            key,
            value: result,
            prev: None,
            next: prev_head,
            bytes,
        });
        if let Some(h) = g.head {
            g.nodes[h].prev = Some(idx);
        }
        g.head = Some(idx);
        if g.tail.is_none() {
            // First node in an empty list is both head and tail.
            g.tail = Some(idx);
        }
        g.by_key.insert(key, idx);
        g.total_bytes = g.total_bytes.saturating_add(bytes);
        g.evict_until_within_bounds();
    }

    /// Drops every entry and resets the byte accounting; the configured
    /// bounds are left untouched.
    fn invalidate_all(&self) {
        let mut g = self.inner.lock();
        g.by_key.clear();
        g.nodes.clear();
        g.head = None;
        g.tail = None;
        g.free.clear();
        g.total_bytes = 0;
    }

    /// Number of keys currently indexed. Expired entries that have not been
    /// looked up since expiring are still counted.
    fn len(&self) -> usize {
        self.inner.lock().by_key.len()
    }
}
impl LruInner {
    /// Stores `node` in a recycled slot when one is available, otherwise
    /// appends it, and returns the slot index. The node is not linked into
    /// the list or the key index here.
    fn alloc_node(&mut self, node: LruNode) -> usize {
        if let Some(idx) = self.free.pop() {
            self.nodes[idx] = node;
            idx
        } else {
            self.nodes.push(node);
            self.nodes.len() - 1
        }
    }

    /// Unlinks node `idx` from the recency list, patching its neighbours (or
    /// `head` / `tail` when it sat at an end) and clearing its own links.
    fn detach(&mut self, idx: usize) {
        let (prev, next) = (self.nodes[idx].prev, self.nodes[idx].next);
        match prev {
            Some(p) => self.nodes[p].next = next,
            None => self.head = next,
        }
        match next {
            Some(n) => self.nodes[n].prev = prev,
            None => self.tail = prev,
        }
        self.nodes[idx].prev = None;
        self.nodes[idx].next = None;
    }

    /// Removes entry `idx` entirely: unlinks it, forgets its key, returns its
    /// bytes to the budget and hands the slot to the free list.
    fn detach_and_drop(&mut self, idx: usize) {
        let key = self.nodes[idx].key;
        let bytes = self.nodes[idx].bytes;
        self.detach(idx);
        self.by_key.remove(&key);
        self.total_bytes = self.total_bytes.saturating_sub(bytes);
        // The slot stays in `nodes` until `alloc_node` reuses it. Release the
        // rows now so memory that is no longer counted in `total_bytes` is
        // not kept alive by a vacated slot.
        self.nodes[idx].value.rows = Vec::new();
        self.nodes[idx].bytes = 0;
        self.free.push(idx);
    }

    /// Makes node `idx` the most recently used one.
    fn move_to_head(&mut self, idx: usize) {
        if self.head == Some(idx) {
            return;
        }
        // `detach` leaves both links of `idx` cleared, so only `next` needs
        // to be set when re-linking at the front.
        self.detach(idx);
        self.nodes[idx].next = self.head;
        if let Some(h) = self.head {
            self.nodes[h].prev = Some(idx);
        }
        self.head = Some(idx);
        // Defensive: keeps `tail` populated should the list have been empty
        // after the detach.
        if self.tail.is_none() {
            self.tail = Some(idx);
        }
    }

    /// Drops entries from the least recently used end until both the entry
    /// bound and the byte bound hold, or the list is empty.
    fn evict_until_within_bounds(&mut self) {
        while self.by_key.len() > self.max_entries || self.total_bytes > self.max_bytes {
            let Some(tail) = self.tail else { break };
            self.detach_and_drop(tail);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::super::query::SeqNum;
    use super::*;
    use std::thread::sleep;

    /// Builds `n` rows, each carrying an 8-byte payload.
    fn make_rows(n: usize) -> Vec<ResultRow> {
        (0..n)
            .map(|i| ResultRow {
                origin: i as u64,
                seq: SeqNum(i as u64),
                payload: vec![0u8; 8],
            })
            .collect()
    }

    /// Wraps `rows` as an entry inserted "now" under `policy`.
    fn make_result(rows: Vec<ResultRow>, policy: CachePolicy) -> CachedResult {
        CachedResult::new(rows, Instant::now(), policy)
    }

    /// Builds a key directly from its two components, bypassing plan hashing.
    fn key(plan: u64, version: u64) -> CacheKey {
        CacheKey {
            plan_hash: plan,
            capability_version: version,
        }
    }

    #[test]
    fn default_policy_is_timebound_5s() {
        assert_eq!(
            CachePolicy::default(),
            CachePolicy::TimeBound {
                ttl: Duration::from_secs(5)
            }
        );
    }

    #[test]
    fn permanent_entries_never_expire_by_time() {
        let r = make_result(vec![], CachePolicy::Permanent);
        assert!(!r.is_expired());
    }

    #[test]
    fn timebound_entries_expire_after_ttl() {
        // Backdate the insertion instead of sleeping so the check is instant.
        let r = CachedResult::new(
            vec![],
            Instant::now() - Duration::from_millis(50),
            CachePolicy::TimeBound {
                ttl: Duration::from_millis(10),
            },
        );
        assert!(r.is_expired());
    }

    #[test]
    fn lru_round_trips_a_simple_insert_then_get() {
        let cache = LruResultCache::default();
        let k = key(1, 1);
        cache.insert(k, make_result(make_rows(3), CachePolicy::Permanent));
        let got = cache.get(&k).expect("hit");
        assert_eq!(got.rows.len(), 3);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn lru_miss_on_unknown_key() {
        let cache = LruResultCache::default();
        assert!(cache.get(&key(42, 0)).is_none());
    }

    #[test]
    fn lru_miss_on_version_mismatch_by_construction() {
        let cache = LruResultCache::default();
        cache.insert(key(1, 1), make_result(vec![], CachePolicy::Permanent));
        assert!(cache.get(&key(1, 2)).is_none());
        assert!(cache.get(&key(1, 1)).is_some());
    }

    #[test]
    fn lru_expired_entries_miss_and_are_dropped_lazily() {
        let cache = LruResultCache::default();
        let k = key(1, 0);
        let stale = CachedResult::new(
            vec![],
            Instant::now() - Duration::from_millis(50),
            CachePolicy::TimeBound {
                ttl: Duration::from_millis(10),
            },
        );
        cache.insert(k, stale);
        assert!(cache.get(&k).is_none());
        assert_eq!(cache.len(), 0, "expired entry dropped on miss");
    }

    #[test]
    fn lru_evicts_least_recently_used_when_entry_bound_trips() {
        let cache = LruResultCache::new(2, u64::MAX);
        cache.insert(key(1, 0), make_result(make_rows(1), CachePolicy::Permanent));
        cache.insert(key(2, 0), make_result(make_rows(1), CachePolicy::Permanent));
        // Touch key 1 so that key 2 becomes the least recently used entry.
        let _ = cache.get(&key(1, 0));
        cache.insert(key(3, 0), make_result(make_rows(1), CachePolicy::Permanent));
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&key(1, 0)).is_some());
        assert!(cache.get(&key(2, 0)).is_none(), "evicted as LRU");
        assert!(cache.get(&key(3, 0)).is_some());
    }

    #[test]
    fn lru_evicts_when_byte_bound_trips() {
        // The budget fits one single-row entry plus one spare byte, so a
        // second single-row entry has to push the first one out.
        let row_bytes = std::mem::size_of::<ResultRow>() as u64 + 8;
        let cache = LruResultCache::new(usize::MAX, row_bytes + 1);
        cache.insert(key(1, 0), make_result(make_rows(1), CachePolicy::Permanent));
        assert_eq!(cache.len(), 1);
        cache.insert(key(2, 0), make_result(make_rows(1), CachePolicy::Permanent));
        assert_eq!(cache.len(), 1, "byte bound leaves room for exactly one entry");
        assert!(cache.get(&key(2, 0)).is_some(), "newest entry is retained");
        assert!(
            cache.get(&key(1, 0)).is_none(),
            "older entry is evicted to honour the byte bound"
        );
    }

    #[test]
    fn lru_rejects_oversized_entry_instead_of_self_evicting() {
        let row_bytes = std::mem::size_of::<ResultRow>() as u64 + 8;
        let cache = LruResultCache::new(usize::MAX, row_bytes + 1);
        cache.insert(key(1, 0), make_result(make_rows(1), CachePolicy::Permanent));
        assert_eq!(cache.len(), 1);
        cache.insert(key(2, 0), make_result(make_rows(4), CachePolicy::Permanent));
        assert!(
            cache.get(&key(2, 0)).is_none(),
            "oversized insert must not be observable via get"
        );
        assert!(
            cache.get(&key(1, 0)).is_some(),
            "prior entry must survive a refused oversized insert"
        );
    }

    #[test]
    fn lru_replace_at_same_key_updates_bytes_in_place() {
        let cache = LruResultCache::new(8, 10_000);
        let k = key(1, 0);
        cache.insert(k, make_result(make_rows(1), CachePolicy::Permanent));
        cache.insert(k, make_result(make_rows(5), CachePolicy::Permanent));
        assert_eq!(cache.len(), 1);
        let got = cache.get(&k).unwrap();
        assert_eq!(got.rows.len(), 5);
    }

    #[test]
    fn invalidate_all_drops_every_entry() {
        let cache = LruResultCache::default();
        for i in 0..5 {
            cache.insert(key(i, 0), make_result(make_rows(1), CachePolicy::Permanent));
        }
        assert_eq!(cache.len(), 5);
        cache.invalidate_all();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn key_for_plan_is_deterministic() {
        use super::super::planner::{CostEstimate, OperatorNode, OperatorPlan};
        let plan = ExecutionPlan {
            root: OperatorNode {
                operator: OperatorPlan::LatestRead {
                    origin: 0xABCD_EF01,
                },
                target_nodes: vec![1, 2, 3],
                cost: CostEstimate::default(),
            },
            total_cost: CostEstimate::default(),
        };
        let k1 = CacheKey::for_plan(&plan, 7).unwrap();
        let k2 = CacheKey::for_plan(&plan, 7).unwrap();
        assert_eq!(k1, k2);
    }

    #[test]
    fn key_for_plan_differs_on_version_change() {
        use super::super::planner::{CostEstimate, OperatorNode, OperatorPlan};
        let plan = ExecutionPlan {
            root: OperatorNode {
                operator: OperatorPlan::LatestRead { origin: 0x01 },
                target_nodes: vec![],
                cost: CostEstimate::default(),
            },
            total_cost: CostEstimate::default(),
        };
        let a = CacheKey::for_plan(&plan, 1).unwrap();
        let b = CacheKey::for_plan(&plan, 2).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn key_for_plan_handles_filter_plans_without_panicking() {
        use super::super::planner::{CostEstimate, OperatorNode, OperatorPlan};
        use crate::adapter::net::behavior::predicate::Predicate;
        use crate::adapter::net::behavior::tag::TagKey;
        use crate::adapter::net::behavior::TaxonomyAxis;
        let pred = Predicate::Equals {
            key: TagKey::new(TaxonomyAxis::Software, "any"),
            value: "v".to_string(),
        };
        let inner = OperatorNode {
            operator: OperatorPlan::LatestRead { origin: 0x01 },
            target_nodes: vec![],
            cost: CostEstimate::default(),
        };
        let plan = ExecutionPlan {
            root: OperatorNode {
                operator: OperatorPlan::Filter {
                    input: Box::new(inner),
                    predicate: pred.to_wire(),
                },
                target_nodes: vec![],
                cost: CostEstimate::default(),
            },
            total_cost: CostEstimate::default(),
        };
        let k = CacheKey::for_plan(&plan, 0).expect("filter plan is encodable today");
        assert_eq!(k, CacheKey::for_plan(&plan, 0).unwrap());
    }

    #[test]
    fn ttl_expiry_is_observable_through_get() {
        // The TTL is generous relative to the insert -> first-get gap so a
        // slow or heavily loaded runner cannot expire the entry before the
        // "still fresh" assertion runs.
        let ttl = Duration::from_millis(100);
        let cache = LruResultCache::default();
        let k = key(1, 0);
        let entry = make_result(make_rows(1), CachePolicy::TimeBound { ttl });
        cache.insert(k, entry);
        assert!(cache.get(&k).is_some());
        sleep(ttl + Duration::from_millis(25));
        assert!(cache.get(&k).is_none());
    }
}