use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use crate::types::{sequence::OpType, value::Row};
use lru::LruCache;
/// Value stored in the row cache: a row plus the operation type recorded with it.
///
/// The type is `Clone` because lookups return an owned copy rather than a
/// reference into the cache.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Operation type stored alongside the row (for example `OpType::Put`).
    pub op_type: OpType,
    /// The cached row itself.
    pub row: Row,
}
/// Row cache keyed by user-key bytes, shareable through `&self`.
///
/// Entries live in an [`LruCache`] behind a [`Mutex`]; the generation and the
/// hit/miss counters are atomics that are read and updated without the lock.
pub struct RowCache {
    /// Key -> entry map with LRU eviction; every access goes through the mutex.
    inner: Mutex<LruCache<Vec<u8>, CacheEntry>>,
    /// Incremented by `invalidate` and `clear`. `insert_if_fresh` compares it
    /// with the value the caller captured earlier and drops the insert when
    /// the two differ.
    generation: AtomicU64,
    /// Number of `get` calls that found an entry.
    hits: AtomicU64,
    /// Number of `get` calls that found no entry.
    misses: AtomicU64,
}
impl RowCache {
pub fn new(capacity: usize) -> Self {
Self {
inner: Mutex::new(LruCache::new(
NonZeroUsize::new(capacity).expect("row cache capacity must be > 0"),
)),
generation: AtomicU64::new(0),
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
}
}
pub fn get(&self, user_key: &[u8]) -> Option<CacheEntry> {
let mut guard = self.inner.lock().unwrap();
match guard.get(user_key) {
Some(entry) => {
self.hits.fetch_add(1, Ordering::Relaxed);
crate::engine::metrics::inc(crate::engine::metrics::ROW_CACHE_HITS_TOTAL);
Some(entry.clone())
}
None => {
self.misses.fetch_add(1, Ordering::Relaxed);
crate::engine::metrics::inc(crate::engine::metrics::ROW_CACHE_MISSES_TOTAL);
None
}
}
}
#[inline]
pub fn snapshot_generation(&self) -> u64 {
self.generation.load(Ordering::Acquire)
}
pub fn insert_if_fresh(&self, user_key: Vec<u8>, entry: CacheEntry, captured_gen: u64) {
let mut guard = self.inner.lock().unwrap();
if self.generation.load(Ordering::Acquire) != captured_gen {
return;
}
guard.put(user_key, entry);
}
pub fn insert(&self, user_key: Vec<u8>, entry: CacheEntry) {
let mut guard = self.inner.lock().unwrap();
guard.put(user_key, entry);
}
pub fn invalidate(&self, user_key: &[u8]) {
self.generation.fetch_add(1, Ordering::AcqRel);
let mut guard = self.inner.lock().unwrap();
guard.pop(user_key);
}
pub fn clear(&self) {
self.generation.fetch_add(1, Ordering::AcqRel);
let mut guard = self.inner.lock().unwrap();
guard.clear();
}
pub fn len(&self) -> usize {
self.inner.lock().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn cap(&self) -> usize {
self.inner.lock().unwrap().cap().into()
}
pub fn hit_count(&self) -> u64 {
self.hits.load(Ordering::Relaxed)
}
pub fn miss_count(&self) -> u64 {
self.misses.load(Ordering::Relaxed)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::value::FieldValue;

    /// Builds a `Put` entry whose row holds a single `Int64` field.
    fn make_entry(val: i64) -> CacheEntry {
        CacheEntry {
            op_type: OpType::Put,
            row: Row::new(vec![Some(FieldValue::Int64(val))]),
        }
    }

    #[test]
    fn insert_and_get() {
        let cache = RowCache::new(10);
        cache.insert(b"key1".to_vec(), make_entry(1));

        let hit = cache.get(b"key1");
        assert!(hit.is_some());
        assert_eq!(cache.hit_count(), 1);
        assert_eq!(cache.miss_count(), 0);
    }

    #[test]
    fn miss_returns_none() {
        let cache = RowCache::new(10);
        assert!(cache.get(b"missing").is_none());
        assert_eq!(cache.miss_count(), 1);
    }

    #[test]
    fn invalidate_removes_entry() {
        let cache = RowCache::new(10);
        cache.insert(b"key1".to_vec(), make_entry(1));
        cache.invalidate(b"key1");
        assert!(cache.get(b"key1").is_none());
    }

    #[test]
    fn lru_eviction() {
        let cache = RowCache::new(2);
        cache.insert(b"a".to_vec(), make_entry(1));
        cache.insert(b"b".to_vec(), make_entry(2));
        // The third insert exceeds the capacity of two.
        cache.insert(b"c".to_vec(), make_entry(3));

        // "a" was the least recently used entry, so it is the one evicted.
        assert!(cache.get(b"a").is_none());
        assert!(cache.get(b"b").is_some());
        assert!(cache.get(b"c").is_some());
    }

    #[test]
    fn insert_if_fresh_rejects_after_invalidate() {
        let cache = RowCache::new(10);
        cache.insert(b"k".to_vec(), make_entry(1));

        // Named `captured` rather than `gen`: `gen` is a reserved keyword in
        // the 2024 edition.
        let captured = cache.snapshot_generation();
        cache.invalidate(b"k");
        assert!(cache.get(b"k").is_none());

        cache.insert_if_fresh(b"k".to_vec(), make_entry(42), captured);
        assert!(
            cache.get(b"k").is_none(),
            "insert_if_fresh must drop when generation advanced"
        );
    }

    #[test]
    fn insert_if_fresh_rejects_after_clear() {
        let cache = RowCache::new(10);
        let captured = cache.snapshot_generation();
        cache.clear();

        cache.insert_if_fresh(b"k".to_vec(), make_entry(42), captured);
        assert!(cache.get(b"k").is_none());
    }

    #[test]
    fn insert_if_fresh_succeeds_without_races() {
        let cache = RowCache::new(10);
        let captured = cache.snapshot_generation();

        cache.insert_if_fresh(b"k".to_vec(), make_entry(42), captured);
        assert!(cache.get(b"k").is_some());
    }

    #[test]
    fn clear_drops_all_entries() {
        let cache = RowCache::new(10);
        cache.insert(b"a".to_vec(), make_entry(1));
        cache.insert(b"b".to_vec(), make_entry(2));
        cache.insert(b"c".to_vec(), make_entry(3));
        assert_eq!(cache.len(), 3);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.get(b"a").is_none());
        assert!(cache.get(b"b").is_none());
        assert!(cache.get(b"c").is_none());
    }

    #[test]
    fn len_and_cap() {
        let cache = RowCache::new(5);
        assert_eq!(cache.cap(), 5);
        assert_eq!(cache.len(), 0);

        cache.insert(b"x".to_vec(), make_entry(1));
        assert_eq!(cache.len(), 1);
    }
}