use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tracing::{debug, trace};
/// Errors returned by fallible [`VersionedCache`] operations.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// A configured limit would be exceeded and eviction could not free an
    /// entry. Carries the entry count and total declared byte size recorded
    /// in the statistics at that moment.
    #[error("Cache is full: {0} entries, {1} bytes")]
    CacheFull(usize, usize),
    /// A single entry's declared size exceeds the configured byte budget.
    /// Carries the entry's size and the configured maximum.
    #[error("Entry too large: {0} bytes exceeds max {1}")]
    EntryTooLarge(usize, usize),
    /// An internal `RwLock` was poisoned (a thread panicked while holding it).
    #[error("Cache lock poisoned")]
    LockPoisoned,
}
/// Strategy used to pick a victim when the cache must make room.
///
/// Regardless of the policy, entries whose TTL has already elapsed are
/// reclaimed before any live entry is considered.
///
/// The enum is fieldless, so it is `Copy` and `Hash` and can be passed by
/// value or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CacheEvictionPolicy {
    /// Evict the entry that was least recently inserted or read.
    #[default]
    Lru,
    /// Evict the entry with the fewest successful reads.
    Lfu,
    /// Only reclaim expired entries; never evict a live entry.
    Ttl,
    /// Evict the entry with the largest declared size.
    Size,
}
/// Construction-time limits and defaults for a [`VersionedCache`].
#[derive(Debug, Clone)]
pub struct VersionedCacheConfig {
    /// Maximum number of stored entries; `0` disables the entry-count limit.
    pub max_entries: usize,
    /// Maximum sum of the caller-declared `size_bytes` of all entries; `0`
    /// disables both the total-size limit and the per-entry size check.
    pub max_size_bytes: usize,
    /// TTL applied to entries inserted without an explicit override; `None`
    /// means such entries never expire.
    pub default_ttl: Option<Duration>,
    /// How a victim is chosen when an insert would exceed a limit.
    pub eviction_policy: CacheEvictionPolicy,
}
impl Default for VersionedCacheConfig {
fn default() -> Self {
Self {
max_entries: 1024,
max_size_bytes: 512 * 1024 * 1024, default_ttl: None,
eviction_policy: CacheEvictionPolicy::Lru,
}
}
}
/// A cached value together with its version tag and bookkeeping metadata.
#[derive(Debug)]
pub struct VersionedCacheEntry<V> {
    /// The cached value; a clone of it is returned on every hit.
    pub value: V,
    /// Version tag supplied at insertion; compared verbatim on lookup.
    pub version: String,
    /// When the entry was inserted; the TTL is measured from this instant.
    pub created_at: Instant,
    /// Time of insertion or of the most recent successful lookup.
    pub last_accessed: Instant,
    /// Number of successful lookups; drives LFU eviction.
    pub access_count: u64,
    /// Lifetime of the entry, or `None` if it never expires.
    pub ttl: Option<Duration>,
    /// Caller-declared size, counted against `max_size_bytes`.
    pub size_bytes: usize,
}
impl<V> VersionedCacheEntry<V> {
    /// Returns `true` once strictly more than `ttl` has elapsed since
    /// `created_at`. An entry without a TTL never expires.
    pub fn is_expired(&self) -> bool {
        self.ttl
            .is_some_and(|lifetime| self.created_at.elapsed() > lifetime)
    }

    /// Returns `true` when no particular version is demanded, or when the
    /// demanded version equals this entry's tag exactly.
    pub fn is_valid_version(&self, required_version: Option<&str>) -> bool {
        let Some(wanted) = required_version else {
            return true;
        };
        self.version == wanted
    }
}
/// Snapshot of cache counters, as returned by `VersionedCache::stats`.
#[derive(Debug, Clone, Default)]
pub struct VersionedCacheStats {
    /// Lookups that returned a value.
    pub hits: u64,
    /// Lookups that returned nothing (absent, expired or version mismatch).
    pub misses: u64,
    /// Entries removed to make room for a new insert.
    pub evictions: u64,
    /// Entries removed because their TTL had elapsed.
    pub ttl_expirations: u64,
    /// Misses caused by a version mismatch on a live entry (these are also
    /// counted in `misses`).
    pub version_mismatches: u64,
    /// Sum of the declared sizes of stored entries, including expired entries
    /// that have not been reclaimed yet.
    pub total_size_bytes: usize,
    /// Number of stored entries, including expired entries that have not been
    /// reclaimed yet.
    pub entry_count: usize,
}
impl VersionedCacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`; `0.0` when no
    /// lookup has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
pub struct VersionedCache<K, V>
where
K: Hash + Eq + Clone,
{
entries: Arc<RwLock<HashMap<K, VersionedCacheEntry<V>>>>,
access_order: Arc<RwLock<Vec<K>>>,
config: VersionedCacheConfig,
stats: Arc<RwLock<VersionedCacheStats>>,
}
impl<K, V> VersionedCache<K, V>
where
K: Hash + Eq + Clone + std::fmt::Debug,
V: Clone,
{
pub fn new(config: VersionedCacheConfig) -> Self {
Self {
entries: Arc::new(RwLock::new(HashMap::new())),
access_order: Arc::new(RwLock::new(Vec::new())),
config,
stats: Arc::new(RwLock::new(VersionedCacheStats::default())),
}
}
fn update_access_order(&self, key: &K) {
if let Ok(mut order) = self.access_order.write() {
order.retain(|k| k != key);
order.push(key.clone());
}
}
fn remove_from_access_order(&self, key: &K) {
if let Ok(mut order) = self.access_order.write() {
order.retain(|k| k != key);
}
}
fn evict_one(&self) -> bool {
let entries_guard = match self.entries.write() {
Ok(g) => g,
Err(_) => return false,
};
let expired_key: Option<K> = entries_guard
.iter()
.find(|(_, e)| e.is_expired())
.map(|(k, _)| k.clone());
if let Some(key) = expired_key {
drop(entries_guard);
self.remove_entry_internal(&key, true);
return true;
}
let victim_key: Option<K> = match self.config.eviction_policy {
CacheEvictionPolicy::Ttl => {
None
}
CacheEvictionPolicy::Lru => {
let order = match self.access_order.read() {
Ok(g) => g,
Err(_) => return false,
};
order.first().cloned()
}
CacheEvictionPolicy::Lfu => entries_guard
.iter()
.min_by_key(|(_, e)| e.access_count)
.map(|(k, _)| k.clone()),
CacheEvictionPolicy::Size => entries_guard
.iter()
.max_by_key(|(_, e)| e.size_bytes)
.map(|(k, _)| k.clone()),
};
drop(entries_guard);
if let Some(key) = victim_key {
self.remove_entry_internal(&key, false);
true
} else {
false
}
}
fn remove_entry_internal(&self, key: &K, is_ttl_expiry: bool) {
let removed_size = {
let mut entries = match self.entries.write() {
Ok(g) => g,
Err(_) => return,
};
entries.remove(key).map(|e| e.size_bytes).unwrap_or(0)
};
self.remove_from_access_order(key);
if let Ok(mut stats) = self.stats.write() {
if removed_size > 0 {
stats.total_size_bytes = stats.total_size_bytes.saturating_sub(removed_size);
stats.entry_count = stats.entry_count.saturating_sub(1);
if is_ttl_expiry {
stats.ttl_expirations += 1;
} else {
stats.evictions += 1;
}
}
}
}
pub fn insert(
&self,
key: K,
value: V,
version: impl Into<String>,
size_bytes: usize,
ttl_override: Option<Duration>,
) -> Result<(), CacheError> {
if self.config.max_size_bytes > 0 && size_bytes > self.config.max_size_bytes {
return Err(CacheError::EntryTooLarge(
size_bytes,
self.config.max_size_bytes,
));
}
loop {
let (entry_count, total_bytes) = {
let s = self.stats.read().map_err(|_| CacheError::LockPoisoned)?;
(s.entry_count, s.total_size_bytes)
};
let over_entries = self.config.max_entries > 0 && entry_count >= self.config.max_entries;
let over_bytes =
self.config.max_size_bytes > 0 && total_bytes + size_bytes > self.config.max_size_bytes;
if !over_entries && !over_bytes {
break;
}
if !self.evict_one() {
return Err(CacheError::CacheFull(entry_count, total_bytes));
}
}
let effective_ttl = ttl_override.or(self.config.default_ttl);
let now = Instant::now();
let entry = VersionedCacheEntry {
value,
version: version.into(),
created_at: now,
last_accessed: now,
access_count: 0,
ttl: effective_ttl,
size_bytes,
};
{
let mut entries = self.entries.write().map_err(|_| CacheError::LockPoisoned)?;
if let Some(old) = entries.remove(&key) {
if let Ok(mut stats) = self.stats.write() {
stats.total_size_bytes =
stats.total_size_bytes.saturating_sub(old.size_bytes);
stats.entry_count = stats.entry_count.saturating_sub(1);
}
self.remove_from_access_order(&key);
}
entries.insert(key.clone(), entry);
}
self.update_access_order(&key);
if let Ok(mut stats) = self.stats.write() {
stats.total_size_bytes += size_bytes;
stats.entry_count += 1;
}
trace!(key = ?key, "Inserted entry into versioned cache");
Ok(())
}
pub fn get(&self, key: &K, required_version: Option<&str>) -> Option<V> {
let (value, valid) = {
let entries = self.entries.read().ok()?;
match entries.get(key) {
None => (None, false),
Some(entry) => {
if entry.is_expired() || !entry.is_valid_version(required_version) {
(None, false)
} else {
(Some(entry.value.clone()), true)
}
}
}
};
if !valid {
let is_version_mismatch = {
let entries = self.entries.read().ok()?;
entries.get(key).is_some_and(|e| {
!e.is_expired() && !e.is_valid_version(required_version)
})
};
if let Ok(mut stats) = self.stats.write() {
stats.misses += 1;
if is_version_mismatch {
stats.version_mismatches += 1;
}
}
{
let expired = self
.entries
.read()
.ok()?
.get(key)
.is_some_and(|e| e.is_expired());
if expired {
self.remove_entry_internal(key, true);
}
}
return None;
}
if let Ok(mut entries) = self.entries.write() {
if let Some(entry) = entries.get_mut(key) {
entry.last_accessed = Instant::now();
entry.access_count += 1;
}
}
self.update_access_order(key);
if let Ok(mut stats) = self.stats.write() {
stats.hits += 1;
}
debug!(key = ?key, "Cache hit");
value
}
pub fn contains(&self, key: &K) -> bool {
self.entries
.read()
.ok()
.and_then(|e| e.get(key).map(|entry| !entry.is_expired()))
.unwrap_or(false)
}
pub fn remove(&self, key: &K) -> bool {
let existed = {
let entries = match self.entries.read() {
Ok(g) => g,
Err(_) => return false,
};
entries.contains_key(key)
};
if existed {
self.remove_entry_internal(key, false);
if let Ok(mut stats) = self.stats.write() {
stats.evictions = stats.evictions.saturating_sub(1);
}
}
existed
}
pub fn evict_expired(&self) -> usize {
let expired_keys: Vec<K> = {
let entries = match self.entries.read() {
Ok(g) => g,
Err(_) => return 0,
};
entries
.iter()
.filter(|(_, e)| e.is_expired())
.map(|(k, _)| k.clone())
.collect()
};
let count = expired_keys.len();
for key in expired_keys {
self.remove_entry_internal(&key, true);
}
count
}
pub fn clear(&self) {
if let Ok(mut entries) = self.entries.write() {
entries.clear();
}
if let Ok(mut order) = self.access_order.write() {
order.clear();
}
if let Ok(mut stats) = self.stats.write() {
*stats = VersionedCacheStats::default();
}
}
pub fn stats(&self) -> VersionedCacheStats {
self.stats
.read()
.map(|s| s.clone())
.unwrap_or_default()
}
pub fn len(&self) -> usize {
self.entries
.read()
.map(|e| e.values().filter(|v| !v.is_expired()).count())
.unwrap_or(0)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn invalidate_version(&self, version: &str) -> usize {
let keys_to_remove: Vec<K> = {
let entries = match self.entries.read() {
Ok(g) => g,
Err(_) => return 0,
};
entries
.iter()
.filter(|(_, e)| e.version == version)
.map(|(k, _)| k.clone())
.collect()
};
let count = keys_to_remove.len();
for key in keys_to_remove {
self.remove_entry_internal(&key, false);
if let Ok(mut stats) = self.stats.write() {
stats.evictions = stats.evictions.saturating_sub(1);
}
}
count
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Cache with the default configuration (1024 entries, 512 MiB, LRU).
    fn default_cache() -> VersionedCache<String, Vec<u8>> {
        VersionedCache::new(VersionedCacheConfig::default())
    }

    /// LRU cache limited only by entry count (byte limit disabled).
    fn small_cache(max_entries: usize) -> VersionedCache<String, Vec<u8>> {
        VersionedCache::new(VersionedCacheConfig {
            max_entries,
            max_size_bytes: 0,
            default_ttl: None,
            eviction_policy: CacheEvictionPolicy::Lru,
        })
    }

    #[test]
    fn test_insert_and_get() {
        let cache = default_cache();
        cache
            .insert("key1".to_string(), vec![1, 2, 3], "1.0", 3, None)
            .unwrap();
        let v = cache.get(&"key1".to_string(), None).unwrap();
        assert_eq!(v, vec![1, 2, 3]);
        let stats = cache.stats();
        assert_eq!(stats.entry_count, 1);
        assert_eq!(stats.total_size_bytes, 3);
    }

    #[test]
    fn test_miss_returns_none() {
        let cache = default_cache();
        assert!(cache.get(&"missing".to_string(), None).is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_version_mismatch_returns_none() {
        let cache = default_cache();
        cache
            .insert("k".to_string(), vec![0], "1.0", 1, None)
            .unwrap();
        assert!(cache.get(&"k".to_string(), Some("2.0")).is_none());
        // A mismatch must not remove the entry.
        assert!(cache.contains(&"k".to_string()));
    }

    #[test]
    fn test_version_match_returns_value() {
        let cache = default_cache();
        cache
            .insert("k".to_string(), vec![42], "1.0", 1, None)
            .unwrap();
        let v = cache.get(&"k".to_string(), Some("1.0")).unwrap();
        assert_eq!(v, vec![42]);
    }

    #[test]
    fn test_ttl_expiry() {
        let cache: VersionedCache<String, u32> = VersionedCache::new(VersionedCacheConfig {
            default_ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        });
        cache.insert("k".to_string(), 99, "1.0", 4, None).unwrap();
        thread::sleep(Duration::from_millis(20));
        assert!(cache.get(&"k".to_string(), None).is_none());
        let stats = cache.stats();
        assert_eq!(stats.ttl_expirations, 1);
        assert_eq!(stats.entry_count, 0);
    }

    #[test]
    fn test_ttl_override_per_entry() {
        let cache: VersionedCache<String, u32> = VersionedCache::new(VersionedCacheConfig {
            default_ttl: Some(Duration::from_secs(3600)),
            ..Default::default()
        });
        cache
            .insert(
                "k".to_string(),
                99,
                "1.0",
                4,
                Some(Duration::from_millis(10)),
            )
            .unwrap();
        thread::sleep(Duration::from_millis(25));
        assert!(cache.get(&"k".to_string(), None).is_none());
    }

    #[test]
    fn test_contains() {
        let cache = default_cache();
        cache
            .insert("k".to_string(), vec![], "v1", 0, None)
            .unwrap();
        assert!(cache.contains(&"k".to_string()));
        assert!(!cache.contains(&"other".to_string()));
    }

    #[test]
    fn test_remove() {
        let cache = default_cache();
        cache
            .insert("k".to_string(), vec![1], "v1", 1, None)
            .unwrap();
        assert!(cache.remove(&"k".to_string()));
        assert!(!cache.contains(&"k".to_string()));
        assert!(!cache.remove(&"k".to_string()));
        // An explicit removal is not an eviction.
        let stats = cache.stats();
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.entry_count, 0);
    }

    #[test]
    fn test_evict_expired() {
        let cache: VersionedCache<String, u32> = VersionedCache::new(VersionedCacheConfig {
            default_ttl: Some(Duration::from_millis(10)),
            ..Default::default()
        });
        cache.insert("a".to_string(), 1, "v", 4, None).unwrap();
        cache.insert("b".to_string(), 2, "v", 4, None).unwrap();
        cache
            .insert(
                "c".to_string(),
                3,
                "v",
                4,
                Some(Duration::from_secs(3600)),
            )
            .unwrap();
        thread::sleep(Duration::from_millis(25));
        let evicted = cache.evict_expired();
        assert_eq!(evicted, 2);
        assert!(cache.contains(&"c".to_string()));
        let stats = cache.stats();
        assert_eq!(stats.ttl_expirations, 2);
        assert_eq!(stats.entry_count, 1);
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_clear() {
        let cache = default_cache();
        cache
            .insert("a".to_string(), vec![1], "v1", 1, None)
            .unwrap();
        cache
            .insert("b".to_string(), vec![2], "v1", 1, None)
            .unwrap();
        cache.clear();
        assert!(cache.is_empty());
        let stats = cache.stats();
        assert_eq!(stats.entry_count, 0);
        assert_eq!(stats.total_size_bytes, 0);
    }

    #[test]
    fn test_hit_rate() {
        let cache = default_cache();
        cache
            .insert("k".to_string(), vec![1], "v1", 1, None)
            .unwrap();
        cache.get(&"k".to_string(), None);
        cache.get(&"k".to_string(), None);
        cache.get(&"miss".to_string(), None);
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_lru_eviction() {
        let cache = small_cache(2);
        cache
            .insert("a".to_string(), vec![1], "v1", 1, None)
            .unwrap();
        cache
            .insert("b".to_string(), vec![2], "v1", 1, None)
            .unwrap();
        cache.get(&"a".to_string(), None);
        cache
            .insert("c".to_string(), vec![3], "v1", 1, None)
            .unwrap();
        assert!(cache.contains(&"a".to_string()));
        assert!(cache.contains(&"c".to_string()));
        assert!(!cache.contains(&"b".to_string()));
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_entry_too_large() {
        let cache: VersionedCache<String, Vec<u8>> = VersionedCache::new(VersionedCacheConfig {
            max_size_bytes: 10,
            ..Default::default()
        });
        let result = cache.insert("big".to_string(), vec![0; 100], "v1", 100, None);
        assert!(matches!(result, Err(CacheError::EntryTooLarge(100, 10))));
        assert!(cache.is_empty());
    }

    #[test]
    fn test_invalidate_version() {
        let cache = default_cache();
        cache
            .insert("a".to_string(), vec![1], "1.0", 1, None)
            .unwrap();
        cache
            .insert("b".to_string(), vec![2], "1.0", 1, None)
            .unwrap();
        cache
            .insert("c".to_string(), vec![3], "2.0", 1, None)
            .unwrap();
        let invalidated = cache.invalidate_version("1.0");
        assert_eq!(invalidated, 2);
        assert!(!cache.contains(&"a".to_string()));
        assert!(!cache.contains(&"b".to_string()));
        assert!(cache.contains(&"c".to_string()));
        let stats = cache.stats();
        assert_eq!(stats.entry_count, 1);
        assert_eq!(stats.evictions, 0);
    }

    #[test]
    fn test_len_and_is_empty() {
        let cache = default_cache();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache
            .insert("x".to_string(), vec![1], "v", 1, None)
            .unwrap();
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_stats_version_mismatch_counter() {
        let cache = default_cache();
        cache
            .insert("k".to_string(), vec![1], "1.0", 1, None)
            .unwrap();
        cache.get(&"k".to_string(), Some("2.0"));
        let stats = cache.stats();
        assert_eq!(stats.version_mismatches, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_overwrite_same_key() {
        let cache = default_cache();
        cache
            .insert("k".to_string(), vec![1], "1.0", 1, None)
            .unwrap();
        cache
            .insert("k".to_string(), vec![2], "2.0", 1, None)
            .unwrap();
        let v = cache.get(&"k".to_string(), Some("2.0")).unwrap();
        assert_eq!(v, vec![2]);
        let stats = cache.stats();
        assert_eq!(stats.entry_count, 1);
        assert_eq!(stats.total_size_bytes, 1);
    }

    #[test]
    fn test_lfu_eviction() {
        let cache: VersionedCache<String, u32> = VersionedCache::new(VersionedCacheConfig {
            max_entries: 2,
            eviction_policy: CacheEvictionPolicy::Lfu,
            ..Default::default()
        });
        cache.insert("a".to_string(), 1, "v", 4, None).unwrap();
        cache.insert("b".to_string(), 2, "v", 4, None).unwrap();
        cache.get(&"a".to_string(), None);
        cache.get(&"a".to_string(), None);
        cache.insert("c".to_string(), 3, "v", 4, None).unwrap();
        assert!(cache.contains(&"a".to_string()));
        assert!(cache.contains(&"c".to_string()));
        assert!(!cache.contains(&"b".to_string()));
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_size_eviction_policy() {
        let cache: VersionedCache<String, Vec<u8>> = VersionedCache::new(VersionedCacheConfig {
            max_entries: 3,
            max_size_bytes: 1000,
            eviction_policy: CacheEvictionPolicy::Size,
            ..Default::default()
        });
        cache
            .insert("small".to_string(), vec![0; 10], "v", 10, None)
            .unwrap();
        cache
            .insert("large".to_string(), vec![0; 500], "v", 500, None)
            .unwrap();
        cache
            .insert("medium".to_string(), vec![0; 100], "v", 100, None)
            .unwrap();
        cache
            .insert("new".to_string(), vec![0; 400], "v", 400, None)
            .unwrap();
        assert!(!cache.contains(&"large".to_string()));
        assert!(cache.contains(&"small".to_string()));
        assert!(cache.contains(&"medium".to_string()));
        assert!(cache.contains(&"new".to_string()));
        let stats = cache.stats();
        assert_eq!(stats.evictions, 1);
        assert_eq!(stats.total_size_bytes, 510);
    }
}