use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
/// Tunable limits for a [`VectorCache`].
#[derive(Debug, Clone)]
pub struct VectorCacheConfig {
/// Maximum number of entries kept in the cache; `0` disables the entry-count limit.
pub max_entries: usize,
/// Upper bound, in bytes, on the cache's estimated memory usage; `0` disables the memory limit.
pub max_memory_bytes: usize,
/// Time-to-live applied to entries that carry no TTL of their own.
/// `None` means such entries never expire.
pub default_ttl: Option<Duration>,
}
impl Default for VectorCacheConfig {
fn default() -> Self {
Self {
max_entries: 10_000,
max_memory_bytes: 0,
default_ttl: None,
}
}
}
/// Running counters describing cache activity.
#[derive(Debug, Clone, Default)]
pub struct CacheStatistics {
/// Lookups that returned a live entry.
pub hits: u64,
/// Lookups that found no entry, or found one that had expired.
pub misses: u64,
/// Number of `put` operations performed, including overwrites of an existing key.
pub inserts: u64,
/// Entries removed to satisfy the entry-count or memory limit.
pub evictions: u64,
/// Entries removed because their time-to-live had elapsed.
pub expirations: u64,
/// Entries removed explicitly (single key, prefix, or `clear`).
pub invalidations: u64,
}
impl CacheStatistics {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, so the result is
    /// never `NaN`.
    pub fn hit_ratio(&self) -> f64 {
        match self.total_requests() {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }

    /// Total number of lookups recorded (hits plus misses).
    pub fn total_requests(&self) -> u64 {
        self.misses + self.hits
    }
}
/// A single cached vector together with its bookkeeping timestamps.
#[derive(Debug, Clone)]
struct CacheEntry {
/// The cached vector data.
vector: Vec<f64>,
/// When the entry was stored; expiry is measured from this instant.
inserted_at: Instant,
/// When the entry was last stored or returned by a successful lookup.
last_accessed: Instant,
/// Per-entry time-to-live; when `None`, the cache-wide default TTL (if any) applies.
ttl: Option<Duration>,
}
impl CacheEntry {
    /// Estimated footprint of this entry: the struct itself plus the vector's
    /// elements (by length, not capacity).
    fn memory_bytes(&self) -> usize {
        let payload = self.vector.len() * std::mem::size_of::<f64>();
        std::mem::size_of::<Self>() + payload
    }

    /// Reports whether the entry has outlived its time-to-live.
    ///
    /// The entry's own `ttl` takes precedence over `default_ttl`. With
    /// neither set the entry never expires.
    fn is_expired(&self, default_ttl: Option<Duration>) -> bool {
        match self.ttl.or(default_ttl) {
            Some(limit) => self.inserted_at.elapsed() > limit,
            None => false,
        }
    }
}
/// A point-in-time copy of cache contents, detached from the cache it came from.
#[derive(Debug, Clone)]
pub struct CacheSnapshot {
/// `(key, vector)` pairs captured from the cache.
pub entries: Vec<(String, Vec<f64>)>,
/// Number of pairs in `entries` at the time the snapshot was taken.
pub entry_count: usize,
}
/// String-keyed cache of `f64` vectors with least-recently-used eviction,
/// optional time-to-live expiry, and usage statistics.
pub struct VectorCache {
/// Key to cached entry.
store: HashMap<String, CacheEntry>,
/// Keys ordered from least recently used (front) to most recently used (back).
lru_order: VecDeque<String>,
/// Limits and default TTL supplied at construction.
config: VectorCacheConfig,
/// Counters updated by cache operations.
stats: CacheStatistics,
}
impl VectorCache {
pub fn new() -> Self {
Self::with_config(VectorCacheConfig::default())
}
pub fn with_config(config: VectorCacheConfig) -> Self {
Self {
store: HashMap::new(),
lru_order: VecDeque::new(),
config,
stats: CacheStatistics::default(),
}
}
pub fn get(&mut self, key: &str) -> Option<Vec<f64>> {
if let Some(entry) = self.store.get(key) {
if entry.is_expired(self.config.default_ttl) {
let key_owned = key.to_string();
self.store.remove(&key_owned);
self.lru_order.retain(|k| k != &key_owned);
self.stats.expirations += 1;
self.stats.misses += 1;
return None;
}
} else {
self.stats.misses += 1;
return None;
}
self.touch(key);
self.stats.hits += 1;
if let Some(entry) = self.store.get_mut(key) {
entry.last_accessed = Instant::now();
Some(entry.vector.clone())
} else {
None
}
}
pub fn put(&mut self, key: impl Into<String>, vector: Vec<f64>) {
self.put_with_ttl(key, vector, None);
}
pub fn put_with_ttl(
&mut self,
key: impl Into<String>,
vector: Vec<f64>,
ttl: Option<Duration>,
) {
let key = key.into();
let now = Instant::now();
let entry = CacheEntry {
vector,
inserted_at: now,
last_accessed: now,
ttl,
};
if self.store.contains_key(&key) {
self.lru_order.retain(|k| k != &key);
}
self.store.insert(key.clone(), entry);
self.lru_order.push_back(key);
self.stats.inserts += 1;
self.enforce_entry_limit();
self.enforce_memory_limit();
}
pub fn batch_get(&mut self, keys: &[&str]) -> HashMap<String, Vec<f64>> {
let mut result = HashMap::new();
for &key in keys {
if let Some(vec) = self.get(key) {
result.insert(key.to_string(), vec);
}
}
result
}
pub fn batch_put(&mut self, entries: Vec<(String, Vec<f64>)>) {
for (key, vector) in entries {
self.put(key, vector);
}
}
pub fn warm(&mut self, entries: Vec<(String, Vec<f64>)>) -> usize {
let mut loaded = 0;
for (key, vector) in entries {
if !self.store.contains_key(&key) {
self.put(key, vector);
loaded += 1;
}
}
loaded
}
pub fn invalidate(&mut self, key: &str) -> bool {
if self.store.remove(key).is_some() {
self.lru_order.retain(|k| k != key);
self.stats.invalidations += 1;
true
} else {
false
}
}
pub fn invalidate_prefix(&mut self, prefix: &str) -> usize {
let keys_to_remove: Vec<String> = self
.store
.keys()
.filter(|k| k.starts_with(prefix))
.cloned()
.collect();
let count = keys_to_remove.len();
for key in &keys_to_remove {
self.store.remove(key);
}
self.lru_order.retain(|k| !k.starts_with(prefix));
self.stats.invalidations += count as u64;
count
}
pub fn clear(&mut self) {
let count = self.store.len() as u64;
self.store.clear();
self.lru_order.clear();
self.stats.invalidations += count;
}
pub fn statistics(&self) -> &CacheStatistics {
&self.stats
}
pub fn reset_statistics(&mut self) {
self.stats = CacheStatistics::default();
}
pub fn len(&self) -> usize {
self.store.len()
}
pub fn is_empty(&self) -> bool {
self.store.is_empty()
}
pub fn memory_usage(&self) -> usize {
self.store.values().map(|e| e.memory_bytes()).sum::<usize>()
+ self.lru_order.len() * std::mem::size_of::<String>()
}
pub fn contains_key(&self, key: &str) -> bool {
self.store.contains_key(key)
}
pub fn snapshot(&self) -> CacheSnapshot {
let entries: Vec<(String, Vec<f64>)> = self
.store
.iter()
.filter(|(_, e)| !e.is_expired(self.config.default_ttl))
.map(|(k, e)| (k.clone(), e.vector.clone()))
.collect();
let entry_count = entries.len();
CacheSnapshot {
entries,
entry_count,
}
}
pub fn load_snapshot(&mut self, snapshot: CacheSnapshot) -> usize {
let count = snapshot.entries.len();
for (key, vector) in snapshot.entries {
self.put(key, vector);
}
count
}
pub fn sweep_expired(&mut self) -> usize {
let expired_keys: Vec<String> = self
.store
.iter()
.filter(|(_, e)| e.is_expired(self.config.default_ttl))
.map(|(k, _)| k.clone())
.collect();
let count = expired_keys.len();
for key in &expired_keys {
self.store.remove(key);
}
self.lru_order.retain(|k| !expired_keys.contains(k));
self.stats.expirations += count as u64;
count
}
fn touch(&mut self, key: &str) {
self.lru_order.retain(|k| k != key);
self.lru_order.push_back(key.to_string());
}
fn enforce_entry_limit(&mut self) {
if self.config.max_entries == 0 {
return;
}
while self.store.len() > self.config.max_entries {
if let Some(oldest) = self.lru_order.pop_front() {
self.store.remove(&oldest);
self.stats.evictions += 1;
} else {
break;
}
}
}
fn enforce_memory_limit(&mut self) {
if self.config.max_memory_bytes == 0 {
return;
}
while self.memory_usage() > self.config.max_memory_bytes {
if let Some(oldest) = self.lru_order.pop_front() {
self.store.remove(&oldest);
self.stats.evictions += 1;
} else {
break;
}
}
}
}
impl Default for VectorCache {
fn default() -> Self {
Self::new()
}
}
/// Unit tests for [`VectorCache`] and its supporting types.
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a three-element vector.
    fn vec3(a: f64, b: f64, c: f64) -> Vec<f64> {
        vec![a, b, c]
    }

    // A stored vector is returned unchanged.
    #[test]
    fn test_put_and_get() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        let v = cache.get("k1");
        assert!(v.is_some());
        assert_eq!(v.expect("should exist"), vec3(1.0, 2.0, 3.0));
    }

    // Looking up an unknown key records a miss.
    #[test]
    fn test_get_miss() {
        let mut cache = VectorCache::new();
        assert!(cache.get("nonexistent").is_none());
        assert_eq!(cache.statistics().misses, 1);
    }

    // A second put on the same key replaces the value without growing the cache.
    #[test]
    fn test_put_overwrite() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        cache.put("k1", vec3(4.0, 5.0, 6.0));
        let v = cache.get("k1");
        assert_eq!(v.expect("should exist"), vec3(4.0, 5.0, 6.0));
        assert_eq!(cache.len(), 1);
    }

    // Exceeding max_entries evicts the least recently used key.
    #[test]
    fn test_lru_eviction() {
        let config = VectorCacheConfig {
            max_entries: 2,
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        cache.put("c", vec3(0.0, 0.0, 1.0));
        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
        assert_eq!(cache.statistics().evictions, 1);
    }

    // A successful get protects the key from the next eviction.
    #[test]
    fn test_lru_access_refreshes() {
        let config = VectorCacheConfig {
            max_entries: 2,
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        let _ = cache.get("a");
        cache.put("c", vec3(0.0, 0.0, 1.0));
        assert!(cache.get("a").is_some());
        assert!(cache.get("b").is_none());
        assert!(cache.get("c").is_some());
    }

    // The budget holds two entries (plus bookkeeping slack) but not three, so
    // the third put must evict exactly the oldest entry.
    #[test]
    fn test_memory_limit_eviction() {
        let entry_size = std::mem::size_of::<CacheEntry>() + 3 * std::mem::size_of::<f64>();
        let limit = entry_size * 2 + 100;
        let config = VectorCacheConfig {
            max_entries: 0,
            max_memory_bytes: limit,
            default_ttl: None,
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        cache.put("c", vec3(0.0, 0.0, 1.0));
        assert_eq!(cache.len(), 2);
        assert!(!cache.contains_key("a"));
        assert!(cache.contains_key("b"));
        assert!(cache.contains_key("c"));
        assert_eq!(cache.statistics().evictions, 1);
        assert!(cache.memory_usage() <= limit);
    }

    // Entries older than the default TTL are treated as missing.
    #[test]
    fn test_ttl_expiry() {
        let config = VectorCacheConfig {
            default_ttl: Some(Duration::from_millis(1)),
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        std::thread::sleep(Duration::from_millis(10));
        assert!(cache.get("k1").is_none());
        assert!(cache.statistics().expirations >= 1);
    }

    // A per-entry TTL expires only that entry.
    #[test]
    fn test_per_entry_ttl() {
        let mut cache = VectorCache::new();
        cache.put_with_ttl("k1", vec3(1.0, 2.0, 3.0), Some(Duration::from_millis(1)));
        cache.put("k2", vec3(4.0, 5.0, 6.0));
        std::thread::sleep(Duration::from_millis(10));
        assert!(cache.get("k1").is_none());
        assert!(cache.get("k2").is_some());
    }

    // sweep_expired removes every expired entry in one pass.
    #[test]
    fn test_sweep_expired() {
        let config = VectorCacheConfig {
            default_ttl: Some(Duration::from_millis(1)),
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        std::thread::sleep(Duration::from_millis(10));
        let swept = cache.sweep_expired();
        assert_eq!(swept, 2);
        assert!(cache.is_empty());
    }

    // Invalidating an existing key removes it and is counted.
    #[test]
    fn test_invalidate_key() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        assert!(cache.invalidate("k1"));
        assert!(cache.get("k1").is_none());
        assert_eq!(cache.statistics().invalidations, 1);
    }

    // Invalidating an unknown key reports false.
    #[test]
    fn test_invalidate_nonexistent() {
        let mut cache = VectorCache::new();
        assert!(!cache.invalidate("nope"));
    }

    // Prefix invalidation removes only matching keys.
    #[test]
    fn test_invalidate_prefix() {
        let mut cache = VectorCache::new();
        cache.put("user:1", vec3(1.0, 0.0, 0.0));
        cache.put("user:2", vec3(0.0, 1.0, 0.0));
        cache.put("item:1", vec3(0.0, 0.0, 1.0));
        let removed = cache.invalidate_prefix("user:");
        assert_eq!(removed, 2);
        assert_eq!(cache.len(), 1);
        assert!(cache.contains_key("item:1"));
    }

    // clear empties the cache.
    #[test]
    fn test_clear() {
        let mut cache = VectorCache::new();
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        cache.clear();
        assert!(cache.is_empty());
    }

    // batch_get returns only the keys that hit.
    #[test]
    fn test_batch_put_and_get() {
        let mut cache = VectorCache::new();
        cache.batch_put(vec![
            ("k1".to_string(), vec3(1.0, 0.0, 0.0)),
            ("k2".to_string(), vec3(0.0, 1.0, 0.0)),
            ("k3".to_string(), vec3(0.0, 0.0, 1.0)),
        ]);
        assert_eq!(cache.len(), 3);
        let results = cache.batch_get(&["k1", "k3", "missing"]);
        assert_eq!(results.len(), 2);
        assert!(results.contains_key("k1"));
        assert!(results.contains_key("k3"));
    }

    // warm inserts new keys but never overwrites a live entry.
    #[test]
    fn test_warm() {
        let mut cache = VectorCache::new();
        cache.put("existing", vec3(9.0, 9.0, 9.0));
        let loaded = cache.warm(vec![
            ("existing".to_string(), vec3(0.0, 0.0, 0.0)),
            ("new1".to_string(), vec3(1.0, 0.0, 0.0)),
            ("new2".to_string(), vec3(0.0, 1.0, 0.0)),
        ]);
        assert_eq!(loaded, 2);
        assert_eq!(cache.len(), 3);
        let v = cache.get("existing").expect("should exist");
        assert_eq!(v, vec3(9.0, 9.0, 9.0));
    }

    // One hit and one miss give a ratio of one half.
    #[test]
    fn test_hit_ratio() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        let _ = cache.get("k1");
        let _ = cache.get("k2");
        let stats = cache.statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_ratio() - 0.5).abs() < f64::EPSILON);
        assert_eq!(stats.total_requests(), 2);
    }

    // With no lookups the ratio is zero rather than NaN.
    #[test]
    fn test_hit_ratio_no_requests() {
        let cache = VectorCache::new();
        assert!((cache.statistics().hit_ratio() - 0.0).abs() < f64::EPSILON);
    }

    // reset_statistics zeroes the counters.
    #[test]
    fn test_reset_statistics() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        let _ = cache.get("k1");
        cache.reset_statistics();
        assert_eq!(cache.statistics().hits, 0);
        assert_eq!(cache.statistics().inserts, 0);
    }

    // A snapshot can be loaded into a second cache.
    #[test]
    fn test_snapshot_and_load() {
        let mut cache = VectorCache::new();
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        let snap = cache.snapshot();
        assert_eq!(snap.entry_count, 2);
        let mut cache2 = VectorCache::new();
        let loaded = cache2.load_snapshot(snap);
        assert_eq!(loaded, 2);
        assert!(cache2.get("a").is_some());
        assert!(cache2.get("b").is_some());
    }

    // Expired entries are left out of snapshots.
    #[test]
    fn test_snapshot_excludes_expired() {
        let config = VectorCacheConfig {
            default_ttl: Some(Duration::from_millis(1)),
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("x", vec3(1.0, 2.0, 3.0));
        std::thread::sleep(Duration::from_millis(10));
        let snap = cache.snapshot();
        assert_eq!(snap.entry_count, 0);
    }

    // len and is_empty track insertions.
    #[test]
    fn test_len_and_empty() {
        let mut cache = VectorCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);
    }

    // contains_key distinguishes stored from unknown keys.
    #[test]
    fn test_contains_key() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        assert!(cache.contains_key("k1"));
        assert!(!cache.contains_key("k2"));
    }

    // Inserting data increases the reported memory usage.
    #[test]
    fn test_memory_usage_grows() {
        let mut cache = VectorCache::new();
        let m0 = cache.memory_usage();
        cache.put("k1", vec![0.0; 100]);
        let m1 = cache.memory_usage();
        assert!(m1 > m0);
    }

    // The default configuration has the documented values.
    #[test]
    fn test_default_config() {
        let c = VectorCacheConfig::default();
        assert_eq!(c.max_entries, 10_000);
        assert_eq!(c.max_memory_bytes, 0);
        assert!(c.default_ttl.is_none());
    }

    // Default::default yields an empty cache.
    #[test]
    fn test_default_cache() {
        let cache = VectorCache::default();
        assert!(cache.is_empty());
    }

    // An empty vector is a valid cached value.
    #[test]
    fn test_empty_vector() {
        let mut cache = VectorCache::new();
        cache.put("empty", vec![]);
        let v = cache.get("empty");
        assert_eq!(v.expect("should exist"), Vec::<f64>::new());
    }

    // Large vectors round-trip intact.
    #[test]
    fn test_large_vector() {
        let mut cache = VectorCache::new();
        let big = vec![1.0; 10_000];
        cache.put("big", big);
        let v = cache.get("big").expect("should exist");
        assert_eq!(v.len(), 10_000);
    }

    // A prefix that matches nothing removes nothing.
    #[test]
    fn test_invalidate_prefix_no_match() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        let removed = cache.invalidate_prefix("zzz:");
        assert_eq!(removed, 0);
        assert_eq!(cache.len(), 1);
    }

    // Repeated overflow keeps the cache at its entry limit.
    #[test]
    fn test_multiple_evictions() {
        let config = VectorCacheConfig {
            max_entries: 3,
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        for i in 0..10 {
            cache.put(format!("k{i}"), vec![i as f64]);
        }
        assert_eq!(cache.len(), 3);
        assert!(cache.statistics().evictions >= 7);
    }

    // Every missing key in a batch is counted as a miss.
    #[test]
    fn test_batch_get_all_miss() {
        let mut cache = VectorCache::new();
        let results = cache.batch_get(&["a", "b", "c"]);
        assert!(results.is_empty());
        assert_eq!(cache.statistics().misses, 3);
    }

    // Invalidating one batch-inserted key leaves the others.
    #[test]
    fn test_batch_put_then_invalidate() {
        let mut cache = VectorCache::new();
        cache.batch_put(vec![
            ("a".to_string(), vec3(1.0, 0.0, 0.0)),
            ("b".to_string(), vec3(0.0, 1.0, 0.0)),
        ]);
        cache.invalidate("a");
        assert!(!cache.contains_key("a"));
        assert!(cache.contains_key("b"));
    }

    // Warming with nothing loads nothing.
    #[test]
    fn test_warm_empty_list() {
        let mut cache = VectorCache::new();
        let loaded = cache.warm(vec![]);
        assert_eq!(loaded, 0);
    }

    // Snapshot of an empty cache is empty.
    #[test]
    fn test_snapshot_empty_cache() {
        let cache = VectorCache::new();
        let snap = cache.snapshot();
        assert_eq!(snap.entry_count, 0);
        assert!(snap.entries.is_empty());
    }

    // Loading a snapshot merges with entries already present.
    #[test]
    fn test_load_snapshot_into_non_empty_cache() {
        let mut cache1 = VectorCache::new();
        cache1.put("a", vec3(1.0, 0.0, 0.0));
        let snap = cache1.snapshot();
        let mut cache2 = VectorCache::new();
        cache2.put("b", vec3(0.0, 1.0, 0.0));
        cache2.load_snapshot(snap);
        assert!(cache2.contains_key("a"));
        assert!(cache2.contains_key("b"));
        assert_eq!(cache2.len(), 2);
    }

    // Each put increments the insert counter.
    #[test]
    fn test_put_updates_insert_count() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 0.0, 0.0));
        cache.put("k2", vec3(0.0, 1.0, 0.0));
        assert_eq!(cache.statistics().inserts, 2);
    }

    // After clear the length is zero.
    #[test]
    fn test_clear_resets_len() {
        let mut cache = VectorCache::new();
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        cache.put("c", vec3(0.0, 0.0, 1.0));
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    // Each successful invalidate is counted once.
    #[test]
    fn test_stats_invalidation_count() {
        let mut cache = VectorCache::new();
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        cache.invalidate("a");
        cache.invalidate("b");
        assert_eq!(cache.statistics().invalidations, 2);
    }

    // Lookups after clear miss.
    #[test]
    fn test_get_after_clear() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        cache.clear();
        assert!(cache.get("k1").is_none());
    }

    // Sweeping leaves entries without a TTL alone.
    #[test]
    fn test_sweep_expired_none_expired() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec3(1.0, 2.0, 3.0));
        let swept = cache.sweep_expired();
        assert_eq!(swept, 0);
        assert_eq!(cache.len(), 1);
    }

    // An evicted key is no longer reported as present.
    #[test]
    fn test_contains_key_after_eviction() {
        let config = VectorCacheConfig {
            max_entries: 1,
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("first", vec3(1.0, 0.0, 0.0));
        cache.put("second", vec3(0.0, 1.0, 0.0));
        assert!(!cache.contains_key("first"));
        assert!(cache.contains_key("second"));
    }

    // A prefix shared by every key empties the cache.
    #[test]
    fn test_invalidate_prefix_all() {
        let mut cache = VectorCache::new();
        cache.put("x:1", vec3(1.0, 0.0, 0.0));
        cache.put("x:2", vec3(0.0, 1.0, 0.0));
        cache.put("x:3", vec3(0.0, 0.0, 1.0));
        let removed = cache.invalidate_prefix("x:");
        assert_eq!(removed, 3);
        assert!(cache.is_empty());
    }

    // Reported memory usage returns to zero once the cache is cleared.
    #[test]
    fn test_memory_usage_after_clear() {
        let mut cache = VectorCache::new();
        cache.put("k1", vec![0.0; 1000]);
        let before = cache.memory_usage();
        assert!(before > 0);
        cache.clear();
        assert_eq!(cache.memory_usage(), 0);
    }

    // A zero TTL expires the entry as soon as any time has passed.
    #[test]
    fn test_put_with_zero_ttl() {
        let mut cache = VectorCache::new();
        cache.put_with_ttl("k1", vec3(1.0, 2.0, 3.0), Some(Duration::from_secs(0)));
        std::thread::sleep(Duration::from_millis(2));
        assert!(cache.get("k1").is_none());
    }

    // Overwriting a key makes it most recently used.
    #[test]
    fn test_lru_order_after_overwrite() {
        let config = VectorCacheConfig {
            max_entries: 2,
            ..Default::default()
        };
        let mut cache = VectorCache::with_config(config);
        cache.put("a", vec3(1.0, 0.0, 0.0));
        cache.put("b", vec3(0.0, 1.0, 0.0));
        cache.put("a", vec3(9.0, 9.0, 9.0));
        cache.put("c", vec3(0.0, 0.0, 1.0));
        assert!(cache.contains_key("a"));
        assert!(!cache.contains_key("b"));
        assert!(cache.contains_key("c"));
    }
}