use core::time::Duration;
use std::sync::Arc;
use clock_lib::{Clock, Monotonic, SystemClock};
use iqdb_index::{IndexCore, IndexStats};
use iqdb_types::{DistanceMetric, Hit, Metadata, Result, SearchParams, VectorId};
use crate::config::{CacheConfig, EvictionPolicy};
use crate::key::ResultKey;
use crate::policy::PolicyCache;
use crate::stats::CacheStats;
use crate::sync::{AtomicU64, Mutex, MutexGuard, Ordering};
/// A single cached search result stored in the policy cache.
struct CacheEntry {
/// The hits returned by the inner index for the cached query.
hits: Box<[Hit]>,
/// Clock reading taken when the entry was stored; `None` when the cache
/// was built without a TTL, in which case the entry never expires.
stamp: Option<Monotonic>,
}
/// An index wrapper that caches the results of `search` calls made against
/// the wrapped index `I`.
pub struct CachedIndex<I> {
/// The wrapped index that actually answers queries.
inner: I,
/// Mutex-guarded result cache, keyed by `ResultKey`.
cache: Mutex<PolicyCache<ResultKey, CacheEntry>>,
/// Capacity taken from the configuration; `0` makes `search` bypass the cache.
capacity: usize,
/// Eviction policy taken from the configuration.
policy: EvictionPolicy,
/// Optional time-to-live for entries; `None` means entries do not expire.
ttl: Option<Duration>,
/// Clock used to stamp entries and to compute their age.
clock: Arc<dyn Clock>,
/// Number of searches answered from the cache.
hits: AtomicU64,
/// Number of searches forwarded to the inner index.
misses: AtomicU64,
/// Number of cache insertions for which `PolicyCache::put` reported an eviction.
evictions: AtomicU64,
}
impl<I: IndexCore> CachedIndex<I> {
#[must_use]
pub fn new(inner: I) -> Self {
Self::with_config(inner, CacheConfig::new())
}
#[must_use]
pub fn with_capacity(inner: I, capacity: usize) -> Self {
Self::with_config(inner, CacheConfig::new().capacity(capacity))
}
#[must_use]
pub fn with_config(inner: I, config: CacheConfig) -> Self {
Self::with_config_in(inner, config, Arc::new(SystemClock::new()))
}
pub(crate) fn with_config_in(inner: I, config: CacheConfig, clock: Arc<dyn Clock>) -> Self {
Self {
inner,
cache: Mutex::new(PolicyCache::new(config.policy, config.capacity)),
capacity: config.capacity,
policy: config.policy,
ttl: config.ttl,
clock,
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
evictions: AtomicU64::new(0),
}
}
#[inline]
#[must_use]
pub fn capacity(&self) -> usize {
self.capacity
}
#[inline]
#[must_use]
pub fn ttl(&self) -> Option<Duration> {
self.ttl
}
#[inline]
#[must_use]
pub fn policy(&self) -> EvictionPolicy {
self.policy
}
#[inline]
#[must_use]
pub fn is_enabled(&self) -> bool {
self.capacity > 0
}
#[inline]
#[must_use]
pub fn get_ref(&self) -> &I {
&self.inner
}
#[must_use]
pub fn into_inner(self) -> I {
self.inner
}
pub fn clear_cache(&mut self) {
self.lock_cache().clear();
}
#[must_use]
pub fn cache_stats(&self) -> CacheStats {
let len = self.lock_cache().len();
CacheStats {
hits: self.hits.load(Ordering::Relaxed),
misses: self.misses.load(Ordering::Relaxed),
evictions: self.evictions.load(Ordering::Relaxed),
len,
capacity: self.capacity,
}
}
fn lock_cache(&self) -> MutexGuard<'_, PolicyCache<ResultKey, CacheEntry>> {
self.cache
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[inline]
fn is_live(&self, entry: &CacheEntry) -> bool {
match (self.ttl, entry.stamp) {
(Some(ttl), Some(stamp)) => self.clock.now().saturating_duration_since(stamp) < ttl,
_ => true,
}
}
fn invalidate(&mut self) {
self.lock_cache().clear();
}
}
impl<I: IndexCore> IndexCore for CachedIndex<I> {
    /// Inserts into the inner index and clears the cache when the insert succeeds.
    ///
    /// A failed insert is returned as-is and leaves the cache untouched.
    fn insert(
        &mut self,
        id: VectorId,
        vector: std::sync::Arc<[f32]>,
        metadata: Option<Metadata>,
    ) -> Result<()> {
        self.inner.insert(id, vector, metadata)?;
        self.invalidate();
        Ok(())
    }

    /// Inserts a batch into the inner index and clears the cache afterwards,
    /// whether or not the batch succeeded.
    fn insert_batch(
        &mut self,
        items: Vec<(VectorId, std::sync::Arc<[f32]>, Option<Metadata>)>,
    ) -> Result<()> {
        let outcome = self.inner.insert_batch(items);
        // NOTE(review): unlike `insert`/`delete`, the cache is cleared even on
        // error, presumably because a failed batch may have been partially
        // applied. Confirm against the inner index's contract.
        self.invalidate();
        outcome
    }

    /// Deletes from the inner index and clears the cache when the delete succeeds.
    ///
    /// A failed delete is returned as-is and leaves the cache untouched.
    fn delete(&mut self, id: &VectorId) -> Result<()> {
        self.inner.delete(id)?;
        self.invalidate();
        Ok(())
    }

    /// Answers a search from the cache when a live entry exists, otherwise
    /// forwards it to the inner index and stores the result.
    ///
    /// With a capacity of zero every call is counted as a miss and forwarded
    /// directly. Errors from the inner index are returned without being cached.
    fn search(&self, query: &[f32], params: &SearchParams) -> Result<Vec<Hit>> {
        if self.capacity == 0 {
            let _ = self.misses.fetch_add(1, Ordering::Relaxed);
            return self.inner.search(query, params);
        }

        let key = ResultKey::new(query, params);

        // Lookup phase: copy the hits out while the lock is held, then release
        // it before doing anything else.
        let mut guard = self.lock_cache();
        let cached = match guard.get(&key) {
            Some(entry) if self.is_live(entry) => {
                let _ = self.hits.fetch_add(1, Ordering::Relaxed);
                Some(entry.hits.to_vec())
            }
            _ => None,
        };
        drop(guard);
        if let Some(found) = cached {
            return Ok(found);
        }

        // Miss phase: the inner search runs without the cache lock held.
        let hits = self.inner.search(query, params)?;
        let _ = self.misses.fetch_add(1, Ordering::Relaxed);

        // Entries are only stamped when a TTL is configured.
        let stamp = self.ttl.map(|_| self.clock.now());
        let entry = CacheEntry {
            hits: hits.clone().into_boxed_slice(),
            stamp,
        };
        let evicted = self.lock_cache().put(key, entry);
        if evicted {
            let _ = self.evictions.fetch_add(1, Ordering::Relaxed);
        }
        Ok(hits)
    }

    /// Forwards to the inner index's `len`.
    fn len(&self) -> usize {
        self.inner.len()
    }

    /// Forwards to the inner index's `is_empty`.
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Forwards to the inner index's `dim`.
    fn dim(&self) -> usize {
        self.inner.dim()
    }

    /// Forwards to the inner index's `metric`.
    fn metric(&self) -> DistanceMetric {
        self.inner.metric()
    }

    /// Forwards to the inner index's `flush`; the cache is not touched.
    fn flush(&mut self) -> Result<()> {
        self.inner.flush()
    }

    /// Forwards to the inner index's `stats` (not the cache statistics).
    fn stats(&self) -> IndexStats {
        self.inner.stats()
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use clock_lib::ManualClock;
    use super::*;
    use crate::doc_stub::stub_index;

    /// Search parameters shared by every test in this module.
    fn params() -> SearchParams {
        SearchParams::new(1, DistanceMetric::Cosine)
    }

    /// An entry is served until its TTL elapses, then recomputed and cached again.
    #[test]
    fn ttl_entry_is_recomputed_after_expiry() {
        let clock = Arc::new(ManualClock::new());
        let config = CacheConfig::new().capacity(8).ttl(Duration::from_secs(60));
        let cached = CachedIndex::with_config_in(stub_index(), config, clock.clone());

        let _miss = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        let _hit = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        assert_eq!(cached.cache_stats().hits, 1);

        // 59s old: still inside the 60s TTL.
        clock.advance(Duration::from_secs(59));
        let _hit2 = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        assert_eq!(cached.cache_stats().hits, 2);

        // 61s old: expired, so this search is a miss that refreshes the entry.
        clock.advance(Duration::from_secs(2));
        let _expired = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        assert_eq!(cached.cache_stats().hits, 2);
        assert_eq!(cached.cache_stats().misses, 2);

        let _hit3 = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        assert_eq!(cached.cache_stats().hits, 3);
    }

    /// An entry whose age equals the TTL exactly is already expired.
    #[test]
    fn ttl_boundary_is_exclusive() {
        let clock = Arc::new(ManualClock::new());
        let config = CacheConfig::new().capacity(8).ttl(Duration::from_secs(10));
        let cached = CachedIndex::with_config_in(stub_index(), config, clock.clone());

        let _miss = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        clock.advance(Duration::from_secs(10));
        let _again = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();

        assert_eq!(cached.cache_stats().hits, 0);
        assert_eq!(cached.cache_stats().misses, 2);
    }

    /// Without a TTL, entries stay live no matter how far the clock advances.
    #[test]
    fn no_ttl_never_expires_even_as_time_passes() {
        let clock = Arc::new(ManualClock::new());
        let config = CacheConfig::new().capacity(8);
        let cached = CachedIndex::with_config_in(stub_index(), config, clock.clone());

        let _miss = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        let _hit = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();
        clock.advance(Duration::from_secs(60 * 60 * 24 * 365));
        let _hit2 = cached.search(&[1.0, 0.0, 0.0], &params()).unwrap();

        assert_eq!(cached.cache_stats().hits, 2);
        assert_eq!(cached.cache_stats().misses, 1);
    }
}