use std::hash::{Hash, Hasher};
use std::time::{Duration, Instant};
use arrow_array::RecordBatch;
use equivalent::Equivalent;
use quick_cache::sync::{Cache, DefaultLifecycle};
use quick_cache::{DefaultHashBuilder, Weighter};
use crate::lookup::table::LookupResult;
/// Owned key under which an entry is stored in the lookup cache.
///
/// Combines the table identifier with the raw key bytes. `Hash` is derived,
/// so the fields are hashed in declaration order (`table_id`, then `key`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct LookupCacheKey {
/// Identifier of the table the entry belongs to.
pub table_id: u32,
/// Raw key bytes, owned by the cache entry.
pub key: Vec<u8>,
}
/// Borrowed counterpart of [`LookupCacheKey`].
///
/// Holds the key bytes by reference, so a cache probe can be built from a
/// `&[u8]` without first copying the bytes into a `Vec<u8>`.
pub(crate) struct LookupCacheKeyRef<'a> {
    /// Identifier of the table being probed.
    pub(crate) table_id: u32,
    /// Borrowed raw key bytes.
    pub(crate) key: &'a [u8],
}

impl Hash for LookupCacheKeyRef<'_> {
    /// Feeds `table_id` and then the key bytes into `state`.
    ///
    /// The order mirrors the field order of [`LookupCacheKey`], whose `Hash`
    /// is derived, so a borrowed key and an owned key with the same contents
    /// hash identically. Keep the two in sync when changing either type.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self { table_id, key } = self;
        Hash::hash(table_id, state);
        // `*key` is the `&[u8]` itself, so this uses the slice hash (length
        // prefix + bytes), the same one `Vec<u8>` delegates to.
        Hash::hash(*key, state);
    }
}
impl Equivalent<LookupCacheKey> for LookupCacheKeyRef<'_> {
    /// Returns `true` when this borrowed key denotes the same entry as the
    /// owned key `other`: equal table ids and identical key bytes.
    fn equivalent(&self, other: &LookupCacheKey) -> bool {
        if self.table_id != other.table_id {
            return false;
        }
        self.key == other.key.as_slice()
    }
}
/// Settings used to build a [`LookupMemoryCache`].
#[derive(Debug, Clone, Copy)]
pub struct LookupMemoryCacheConfig {
    /// Total weight budget of the cache, in bytes.
    pub capacity_bytes: usize,
    /// Maximum age of an entry before it is treated as expired;
    /// `None` disables age-based expiry.
    pub ttl: Option<Duration>,
}

impl Default for LookupMemoryCacheConfig {
    /// Defaults to a 64 MiB budget with no TTL.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            capacity_bytes: 64 * MIB,
            ttl: None,
        }
    }
}
/// Value stored in the cache: a record batch plus its insertion time.
#[derive(Clone)]
struct CachedBatch {
/// The cached lookup result.
batch: RecordBatch,
/// When the entry was inserted; compared against the configured TTL.
inserted_at: Instant,
}
/// Weighs cache entries by the memory size reported by their record batch.
#[derive(Debug, Clone)]
struct BatchWeighter;

impl Weighter<LookupCacheKey, CachedBatch> for BatchWeighter {
    /// Returns the batch's `get_array_memory_size()`, floored at 1.
    ///
    /// NOTE(review): the key's bytes are not included in the weight, so the
    /// cache's byte budget only accounts for values — confirm this is intended.
    fn weight(&self, _key: &LookupCacheKey, entry: &CachedBatch) -> u64 {
        let batch_bytes = entry.batch.get_array_memory_size();
        // Never report a zero weight (presumably because quick_cache treats
        // zero-weight entries specially — confirm against its docs).
        std::cmp::max(batch_bytes, 1) as u64
    }
}
/// Concrete cache type: owned `LookupCacheKey` keys, `CachedBatch` values,
/// weighed by `BatchWeighter`.
type BatchCache = Cache<LookupCacheKey, CachedBatch, BatchWeighter>;
/// In-memory cache of lookup results for a single table, bounded by weight
/// and optionally by entry age.
pub struct LookupMemoryCache {
/// Underlying weight-bounded cache.
cache: BatchCache,
/// Table id stamped into every key this cache builds.
table_id: u32,
/// Maximum entry age; `None` means entries never expire by age.
ttl: Option<Duration>,
}
impl LookupMemoryCache {
#[must_use]
pub fn new(table_id: u32, config: LookupMemoryCacheConfig) -> Self {
let estimated_items = (config.capacity_bytes / 1024).max(64);
let cache = BatchCache::with(
estimated_items,
config.capacity_bytes as u64,
BatchWeighter,
DefaultHashBuilder::default(),
DefaultLifecycle::default(),
);
Self {
cache,
table_id,
ttl: config.ttl,
}
}
#[must_use]
pub fn with_defaults(table_id: u32) -> Self {
Self::new(table_id, LookupMemoryCacheConfig::default())
}
#[must_use]
pub fn table_id(&self) -> u32 {
self.table_id
}
fn make_key(&self, key: &[u8]) -> LookupCacheKey {
LookupCacheKey {
table_id: self.table_id,
key: key.to_vec(),
}
}
#[must_use]
pub fn get_cached(&self, key: &[u8]) -> LookupResult {
let ref_key = LookupCacheKeyRef {
table_id: self.table_id,
key,
};
match self.cache.get(&ref_key) {
Some(cached) if self.is_expired(&cached) => {
self.cache.remove_if(&ref_key, |v| self.is_expired(v));
LookupResult::NotFound
}
Some(cached) => LookupResult::Hit(cached.batch),
None => LookupResult::NotFound,
}
}
fn is_expired(&self, entry: &CachedBatch) -> bool {
self.ttl
.is_some_and(|ttl| entry.inserted_at.elapsed() >= ttl)
}
pub fn insert(&self, key: &[u8], value: RecordBatch) {
let cache_key = self.make_key(key);
self.cache.insert(
cache_key,
CachedBatch {
batch: value,
inserted_at: Instant::now(),
},
);
}
pub fn invalidate(&self, key: &[u8]) {
let ref_key = LookupCacheKeyRef {
table_id: self.table_id,
key,
};
self.cache.remove(&ref_key);
}
#[must_use]
pub fn len(&self) -> usize {
self.cache.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.cache.is_empty()
}
}
impl std::fmt::Debug for LookupMemoryCache {
    /// Formats the cache as its table id, TTL, and current entry count;
    /// the cached contents themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entries = self.cache.len();
        let mut out = f.debug_struct("LookupMemoryCache");
        out.field("table_id", &self.table_id);
        out.field("ttl", &self.ttl);
        out.field("entries", &entries);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Capacity used by the helper-built caches in these tests.
    const TEST_CAPACITY_BYTES: usize = 64 * 1024;

    /// Builds a one-row batch with a single non-nullable `Utf8` column `v`.
    fn test_batch(val: &str) -> RecordBatch {
        let field = Field::new("v", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![field]));
        RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(vec![val]))]).unwrap()
    }

    /// Shared constructor behind `small_cache` and `ttl_cache`.
    fn build_cache(table_id: u32, ttl: Option<Duration>) -> LookupMemoryCache {
        let config = LookupMemoryCacheConfig {
            capacity_bytes: TEST_CAPACITY_BYTES,
            ttl,
        };
        LookupMemoryCache::new(table_id, config)
    }

    /// 64 KiB cache without a TTL.
    fn small_cache(table_id: u32) -> LookupMemoryCache {
        build_cache(table_id, None)
    }

    /// A miss before insertion, a hit with the stored row afterwards.
    #[test]
    fn test_lookup_cache_hit_miss() {
        let cache = small_cache(1);
        let key: &[u8] = b"key1";

        assert!(cache.get_cached(key).is_not_found());

        cache.insert(key, test_batch("value1"));
        let found = cache.get_cached(key);
        assert!(found.is_hit());
        let batch = found.into_batch().unwrap();
        assert_eq!(batch.num_rows(), 1);
    }

    /// With a 512-byte budget, 200 inserts must not all be retained.
    #[test]
    fn test_lookup_cache_eviction() {
        let config = LookupMemoryCacheConfig {
            capacity_bytes: 512,
            ttl: None,
        };
        let cache = LookupMemoryCache::new(1, config);

        for byte in 0..200u8 {
            let value = format!("v{byte}");
            cache.insert(&[byte], test_batch(&value));
        }

        let remaining = cache.len();
        assert!(
            remaining < 200,
            "byte bound did not evict: len {}",
            remaining
        );
    }

    /// An invalidated key is no longer served.
    #[test]
    fn test_lookup_cache_invalidation() {
        let cache = small_cache(1);
        let key: &[u8] = b"key1";

        cache.insert(key, test_batch("value1"));
        assert!(cache.get_cached(key).is_hit());

        cache.invalidate(key);
        assert!(cache.get_cached(key).is_not_found());
    }

    /// Two caches with different table ids return their own value for the
    /// same key bytes.
    #[test]
    fn test_lookup_cache_table_id_isolation() {
        let cache_a = small_cache(1);
        let cache_b = small_cache(2);
        let key: &[u8] = b"key1";

        cache_a.insert(key, test_batch("from_a"));
        cache_b.insert(key, test_batch("from_b"));

        let batch_a = cache_a.get_cached(key).into_batch().unwrap();
        let batch_b = cache_b.get_cached(key).into_batch().unwrap();

        assert_eq!(batch_a.num_rows(), 1);
        assert_eq!(batch_b.num_rows(), 1);
        assert_ne!(batch_a, batch_b);
    }

    /// 64 KiB cache for table 1 with the given TTL.
    fn ttl_cache(ttl: Duration) -> LookupMemoryCache {
        build_cache(1, Some(ttl))
    }

    /// A zero TTL makes every entry expired on first read, which also
    /// removes it.
    #[test]
    fn test_ttl_zero_expires_immediately() {
        let cache = ttl_cache(Duration::ZERO);
        let key: &[u8] = b"k";

        cache.insert(key, test_batch("v"));

        assert!(cache.get_cached(key).is_not_found());
        assert!(cache.is_empty());
    }

    /// An entry is served within its TTL and dropped once the TTL has passed.
    #[test]
    fn test_ttl_hit_then_expire() {
        let cache = ttl_cache(Duration::from_millis(20));
        let key: &[u8] = b"k";

        cache.insert(key, test_batch("v"));
        assert!(cache.get_cached(key).is_hit());

        std::thread::sleep(Duration::from_millis(40));

        assert!(cache.get_cached(key).is_not_found());
        assert!(cache.is_empty());
    }

    /// Without a TTL an entry is still served after a delay.
    #[test]
    fn test_no_ttl_entry_survives() {
        let cache = small_cache(1);
        let key: &[u8] = b"k";

        cache.insert(key, test_batch("v"));
        std::thread::sleep(Duration::from_millis(10));

        assert!(cache.get_cached(key).is_hit());
    }
}