use std::sync::Arc;
use deepsize::DeepSizeOf;
use lance_index::{scalar::ScalarIndex, vector::VectorIndex};
use lance_table::format::Index;
use moka::sync::{Cache, ConcurrentCacheExt};
use std::sync::atomic::{AtomicU64, Ordering};
use crate::dataset::DEFAULT_INDEX_CACHE_SIZE;
/// Hit and miss counters for cache lookups.
///
/// Both counters are atomics, so they can be incremented through a shared
/// reference without additional locking.
#[derive(Debug, Default, DeepSizeOf)]
struct CacheStats {
/// Number of lookups recorded as hits.
hits: AtomicU64,
/// Number of lookups recorded as misses.
misses: AtomicU64,
}
impl CacheStats {
fn record_hit(&self) {
self.hits.fetch_add(1, Ordering::Relaxed);
}
fn record_miss(&self) {
self.misses.fetch_add(1, Ordering::Relaxed);
}
}
/// Cache of scalar indices, vector indices, and index metadata.
///
/// Every field is behind an `Arc`, so the derived `Clone` yields a handle that
/// shares the same underlying caches and statistics with the original.
#[derive(Clone)]
pub struct IndexCache {
/// Scalar indices, keyed by a caller-supplied string.
scalar_cache: Arc<Cache<String, Arc<dyn ScalarIndex>>>,
/// Vector indices, keyed by a caller-supplied string.
vector_cache: Arc<Cache<String, Arc<dyn VectorIndex>>>,
/// Index metadata lists, keyed by `"{dataset_uuid}:{version}"`.
metadata_cache: Arc<Cache<String, Arc<Vec<Index>>>>,
/// Hit/miss counters updated by lookups on all three caches.
cache_stats: Arc<CacheStats>,
}
impl DeepSizeOf for IndexCache {
    /// Sums the deep size of every cached value, plus the statistics block.
    ///
    /// Only the cached values are visited; the `String` keys are not included
    /// in the total. The caches are walked in a fixed order (scalar, vector,
    /// metadata, then stats) and all share the same `context`.
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        let mut total: usize = 0;

        for (_, scalar_index) in self.scalar_cache.iter() {
            total += scalar_index.deep_size_of_children(context);
        }
        for (_, vector_index) in self.vector_cache.iter() {
            total += vector_index.deep_size_of_children(context);
        }
        for (_, index_metadata) in self.metadata_cache.iter() {
            total += index_metadata.deep_size_of_children(context);
        }

        total + self.cache_stats.deep_size_of_children(context)
    }
}
impl IndexCache {
    /// Creates an empty cache.
    ///
    /// The scalar, vector and metadata caches are each constructed with
    /// `capacity` as their maximum capacity; the hit/miss counters start at 0.
    pub(crate) fn new(capacity: usize) -> Self {
        let max_capacity = capacity as u64;
        Self {
            scalar_cache: Arc::new(Cache::new(max_capacity)),
            vector_cache: Arc::new(Cache::new(max_capacity)),
            metadata_cache: Arc::new(Cache::new(max_capacity)),
            cache_stats: Arc::default(),
        }
    }

    /// Returns the maximum capacity reported by the vector cache's policy,
    /// falling back to `DEFAULT_INDEX_CACHE_SIZE` when the policy has none.
    pub(crate) fn capacity(&self) -> u64 {
        match self.vector_cache.policy().max_capacity() {
            Some(limit) => limit,
            None => DEFAULT_INDEX_CACHE_SIZE as u64,
        }
    }

    /// Returns the number of entries currently in the vector cache.
    ///
    /// `sync()` is called first so that `entry_count()` is read after pending
    /// cache maintenance has been applied.
    // NOTE(review): no caller is visible in this file, hence the lint allowance —
    // confirm whether this is still needed.
    #[allow(dead_code)]
    pub(crate) fn len_vector(&self) -> usize {
        let vectors = &self.vector_cache;
        vectors.sync();
        vectors.entry_count() as usize
    }

    /// Returns the total number of entries across the scalar, vector and
    /// metadata caches.
    pub(crate) fn get_size(&self) -> usize {
        // Sync all three caches before reading any of the counts.
        self.scalar_cache.sync();
        self.vector_cache.sync();
        self.metadata_cache.sync();

        let entries = self.scalar_cache.entry_count()
            + self.vector_cache.entry_count()
            + self.metadata_cache.entry_count();
        entries as usize
    }

    /// Records `lookup` as a hit (`Some`) or a miss (`None`) and hands it back
    /// unchanged.
    fn tracked<T>(&self, lookup: Option<T>) -> Option<T> {
        if lookup.is_some() {
            self.cache_stats.record_hit();
        } else {
            self.cache_stats.record_miss();
        }
        lookup
    }

    /// Looks up the scalar index stored under `key`, updating the hit/miss
    /// counters.
    pub(crate) fn get_scalar(&self, key: &str) -> Option<Arc<dyn ScalarIndex>> {
        self.tracked(self.scalar_cache.get(key))
    }

    /// Looks up the vector index stored under `key`, updating the hit/miss
    /// counters.
    pub(crate) fn get_vector(&self, key: &str) -> Option<Arc<dyn VectorIndex>> {
        self.tracked(self.vector_cache.get(key))
    }

    /// Stores `index` in the scalar cache under an owned copy of `key`.
    pub(crate) fn insert_scalar(&self, key: &str, index: Arc<dyn ScalarIndex>) {
        let owned_key = String::from(key);
        self.scalar_cache.insert(owned_key, index);
    }

    /// Stores `index` in the vector cache under an owned copy of `key`.
    pub(crate) fn insert_vector(&self, key: &str, index: Arc<dyn VectorIndex>) {
        let owned_key = String::from(key);
        self.vector_cache.insert(owned_key, index);
    }

    /// Builds the metadata cache key: `dataset_uuid` and `version` joined by
    /// a colon.
    fn metadata_key(dataset_uuid: &str, version: u64) -> String {
        format!("{}:{}", dataset_uuid, version)
    }

    /// Looks up the index metadata cached for `key` at `version`, updating the
    /// hit/miss counters.
    pub(crate) fn get_metadata(&self, key: &str, version: u64) -> Option<Arc<Vec<Index>>> {
        let cache_key = Self::metadata_key(key, version);
        self.tracked(self.metadata_cache.get(&cache_key))
    }

    /// Caches `indices` as the index metadata for `key` at `version`.
    pub(crate) fn insert_metadata(&self, key: &str, version: u64, indices: Arc<Vec<Index>>) {
        let cache_key = Self::metadata_key(key, version);
        self.metadata_cache.insert(cache_key, indices);
    }

    /// Returns the fraction of recorded lookups that were hits.
    ///
    /// Yields `1.0` when no lookups have been recorded yet, which also avoids
    /// dividing by zero.
    // NOTE(review): no caller is visible in this file, hence the lint allowance —
    // confirm whether this is still needed.
    #[allow(dead_code)]
    pub(crate) fn hit_rate(&self) -> f32 {
        let stats = &self.cache_stats;
        let hits = stats.hits.load(Ordering::Relaxed) as f32;
        let misses = stats.misses.load(Ordering::Relaxed) as f32;

        let lookups = hits + misses;
        if lookups == 0.0 {
            return 1.0;
        }
        hits / lookups
    }
}