use crate::types::{Value, RowId};
use lru::LruCache;
use parking_lot::Mutex;
use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Thread-safe LRU cache mapping index keys to the row ids stored under them.
///
/// Keys are held in serialized form (the byte output of `bincode::serialize`
/// on a `Value`), and each cached id list is wrapped in an `Arc` so lookups
/// can hand it out without copying the vector.
pub struct CachedIndex {
/// LRU map from serialized key bytes to the shared row-id list; guarded by a
/// mutex because even lookups update the recency order.
cache: Mutex<LruCache<Vec<u8>, Arc<Vec<RowId>>>>,
/// Number of `get` calls that found an entry.
hit_count: AtomicU64,
/// Number of misses reported through `record_miss`; `get` itself never
/// increments this counter.
miss_count: AtomicU64,
}
impl CachedIndex {
    /// Creates an empty cache that holds at most `capacity` entries.
    ///
    /// A `capacity` of zero is not representable by the underlying LRU cache,
    /// so it is silently replaced by a fallback capacity of 1000.
    pub fn new(capacity: usize) -> Self {
        const FALLBACK_CAPACITY: usize = 1000;
        let fallback =
            NonZeroUsize::new(FALLBACK_CAPACITY).expect("fallback capacity is non-zero");
        let capacity = NonZeroUsize::new(capacity).unwrap_or(fallback);
        Self {
            cache: Mutex::new(LruCache::new(capacity)),
            hit_count: AtomicU64::new(0),
            miss_count: AtomicU64::new(0),
        }
    }

    /// Looks up the row ids cached for `key`.
    ///
    /// Returns `None` when the key cannot be serialized or is not cached.
    /// A successful lookup marks the entry as most recently used and bumps
    /// the hit counter. A failed lookup does **not** bump the miss counter;
    /// callers report misses explicitly through [`CachedIndex::record_miss`].
    pub fn get(&self, key: &Value) -> Option<Arc<Vec<RowId>>> {
        let key_bytes = bincode::serialize(key).ok()?;
        let mut cache = self.cache.lock();
        let ids = Arc::clone(cache.get(&key_bytes)?);
        // Counted while the lock is still held so a concurrent `clear` cannot
        // interleave between the lookup and the increment.
        self.hit_count.fetch_add(1, Ordering::Relaxed);
        Some(ids)
    }

    /// Stores `ids` under `key`, replacing any previous entry.
    ///
    /// If the cache is full the least recently used entry is evicted. Keys
    /// that fail to serialize are silently ignored.
    pub fn put(&self, key: Value, ids: Vec<RowId>) {
        let Ok(key_bytes) = bincode::serialize(&key) else {
            return;
        };
        // Allocate the shared wrapper before taking the lock.
        let ids = Arc::new(ids);
        self.cache.lock().put(key_bytes, ids);
    }

    /// Records one cache miss for the hit-rate statistics.
    pub fn record_miss(&self) {
        self.miss_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns `hits / (hits + misses)`, or `0.0` when nothing has been
    /// recorded yet.
    pub fn hit_rate(&self) -> f64 {
        Self::ratio(
            self.hit_count.load(Ordering::Relaxed),
            self.miss_count.load(Ordering::Relaxed),
        )
    }

    /// Returns a snapshot of the cache's size and hit/miss counters.
    ///
    /// The reported `hit_rate` is derived from the same `hits` and `misses`
    /// values stored in the snapshot, so the three fields always agree with
    /// each other.
    pub fn stats(&self) -> CacheStats {
        let (capacity, size, hits, misses) = {
            // `clear` resets the counters under this lock, so reading them
            // here keeps size and counters from straddling a `clear`.
            let cache = self.cache.lock();
            (
                cache.cap().get(),
                cache.len(),
                self.hit_count.load(Ordering::Relaxed),
                self.miss_count.load(Ordering::Relaxed),
            )
        };
        CacheStats {
            capacity,
            size,
            hits,
            misses,
            hit_rate: Self::ratio(hits, misses),
        }
    }

    /// Removes every entry and resets the hit and miss counters to zero.
    pub fn clear(&self) {
        let mut cache = self.cache.lock();
        cache.clear();
        self.hit_count.store(0, Ordering::Relaxed);
        self.miss_count.store(0, Ordering::Relaxed);
    }

    /// Drops the entry for `key`, if any. Keys that fail to serialize are
    /// silently ignored.
    pub fn invalidate(&self, key: &Value) {
        let Ok(key_bytes) = bincode::serialize(key) else {
            return;
        };
        self.cache.lock().pop(&key_bytes);
    }

    /// Drops every entry whose serialized key lies in the inclusive byte
    /// range `[serialize(start), serialize(end)]`.
    ///
    /// Does nothing if either bound fails to serialize.
    ///
    /// NOTE(review): the comparison is lexicographic over the *serialized
    /// bytes*, not over the logical values. `bincode::serialize` writes
    /// fixed-width integers little-endian, so byte order generally does not
    /// follow value order (for example the encoding of 256 sorts before the
    /// encoding of 1). Confirm callers only rely on byte-range semantics, or
    /// move to an order-preserving key encoding.
    pub fn invalidate_range(&self, start: &Value, end: &Value) {
        let Ok(start_bytes) = bincode::serialize(start) else {
            return;
        };
        let Ok(end_bytes) = bincode::serialize(end) else {
            return;
        };
        let mut cache = self.cache.lock();
        // Matching keys are collected first because the cache cannot be
        // mutated while it is being iterated.
        let doomed: Vec<Vec<u8>> = cache
            .iter()
            .map(|(key_bytes, _)| key_bytes)
            .filter(|key_bytes| **key_bytes >= start_bytes && **key_bytes <= end_bytes)
            .cloned()
            .collect();
        for key_bytes in &doomed {
            cache.pop(key_bytes);
        }
    }

    /// Drops the entries for all `keys` under a single lock acquisition.
    ///
    /// Keys that fail to serialize are skipped.
    pub fn invalidate_batch(&self, keys: &[Value]) {
        // Serialize before locking so other threads are not blocked on
        // encoding work.
        let encoded: Vec<Vec<u8>> = keys
            .iter()
            .filter_map(|key| bincode::serialize(key).ok())
            .collect();
        if encoded.is_empty() {
            return;
        }
        let mut cache = self.cache.lock();
        for key_bytes in &encoded {
            cache.pop(key_bytes);
        }
    }

    /// Computes the hit ratio for a given pair of counter values; `0.0` when
    /// both are zero.
    fn ratio(hits: u64, misses: u64) -> f64 {
        let hits = hits as f64;
        let misses = misses as f64;
        let total = hits + misses;
        if total == 0.0 {
            0.0
        } else {
            hits / total
        }
    }
}
/// Point-in-time snapshot of a cache's occupancy and hit/miss counters.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheStats {
    /// Maximum number of entries the cache can hold.
    pub capacity: usize,
    /// Number of entries currently cached.
    pub size: usize,
    /// Number of recorded cache hits.
    pub hits: u64,
    /// Number of recorded cache misses.
    pub misses: u64,
    /// Fraction of recorded lookups that were hits, in the range `0.0..=1.0`.
    pub hit_rate: f64,
}

impl std::fmt::Display for CacheStats {
    /// Renders a one-line human-readable summary, with the hit rate shown as
    /// a percentage to one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hit_percent = self.hit_rate * 100.0;
        write!(
            f,
            "Cache: {}/{} entries, {:.1}% hit rate ({} hits, {} misses)",
            self.size, self.capacity, hit_percent, self.hits, self.misses
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stores two keys, reads them back, and checks the hit/miss bookkeeping.
    #[test]
    fn test_cache_basic() {
        let index = CachedIndex::new(100);
        assert_eq!(index.stats().size, 0);
        assert_eq!(index.hit_rate(), 0.0);

        index.put(Value::Integer(1), vec![100, 200]);
        index.put(Value::Integer(2), vec![300]);

        let first = index.get(&Value::Integer(1)).unwrap();
        assert_eq!(*first, vec![100, 200]);
        let second = index.get(&Value::Integer(2)).unwrap();
        assert_eq!(*second, vec![300]);

        // A failed lookup is not counted automatically; the miss is reported
        // explicitly.
        assert!(index.get(&Value::Integer(3)).is_none());
        index.record_miss();

        let snapshot = index.stats();
        assert_eq!(snapshot.size, 2);
        assert_eq!(snapshot.hits, 2);
        assert_eq!(snapshot.misses, 1);
        assert!((snapshot.hit_rate - 0.666).abs() < 0.01);
    }

    /// With capacity 2, touching key 1 makes key 2 the eviction victim.
    #[test]
    fn test_cache_lru_eviction() {
        let index = CachedIndex::new(2);
        index.put(Value::Integer(1), vec![100]);
        index.put(Value::Integer(2), vec![200]);

        // Refresh key 1 so that key 2 becomes least recently used.
        index.get(&Value::Integer(1));
        index.put(Value::Integer(3), vec![300]);

        let kept = index.get(&Value::Integer(1)).unwrap();
        assert_eq!(*kept, vec![100]);
        let newest = index.get(&Value::Integer(3)).unwrap();
        assert_eq!(*newest, vec![300]);
        assert!(index.get(&Value::Integer(2)).is_none());
    }

    /// Invalidating a key removes its entry.
    #[test]
    fn test_cache_invalidate() {
        let index = CachedIndex::new(100);
        index.put(Value::Integer(1), vec![100]);

        let cached = index.get(&Value::Integer(1)).unwrap();
        assert_eq!(*cached, vec![100]);

        index.invalidate(&Value::Integer(1));
        assert!(index.get(&Value::Integer(1)).is_none());
    }

    /// Clearing empties the cache and zeroes both counters.
    #[test]
    fn test_cache_clear() {
        let index = CachedIndex::new(100);
        index.put(Value::Integer(1), vec![100]);
        index.put(Value::Integer(2), vec![200]);
        index.get(&Value::Integer(1));

        let before = index.stats();
        assert_eq!(before.size, 2);
        assert_eq!(before.hits, 1);

        index.clear();

        let after = index.stats();
        assert_eq!(after.size, 0);
        assert_eq!(after.hits, 0);
        assert_eq!(after.misses, 0);
    }
}