use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, Instant};
use super::{
cache_key::CacheKey,
eviction::{EvictionPolicy, LRUEviction, SizeBasedEviction, TTLEviction},
metrics::CacheMetrics,
};
/// Tuning knobs for [`InferenceCache`].
///
/// Note: only one limit drives eviction at a time — the cache picks a single
/// policy from these fields (memory cap first, then entry cap, then TTL).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Maximum number of entries before LRU eviction kicks in.
pub max_entries: Option<usize>,
/// Approximate memory budget in bytes for size-based eviction.
pub max_memory_bytes: Option<usize>,
/// Time-to-live for entries under TTL eviction.
pub ttl: Option<Duration>,
/// When true, a `CacheMetrics` collector is attached to the cache.
pub enable_metrics: bool,
/// When true, values at or above `compression_threshold` are compressed.
pub compress_values: bool,
/// Minimum value size in bytes before compression is attempted.
pub compression_threshold: usize,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: Some(1000),
max_memory_bytes: Some(1024 * 1024 * 1024), ttl: Some(Duration::from_secs(3600)), enable_metrics: true,
compress_values: true,
compression_threshold: 1024, }
}
}
#[derive(Debug, Clone)]
pub struct CacheEntry {
pub value: Vec<u8>,
pub uncompressed_size: usize,
pub is_compressed: bool,
pub created_at: Instant,
pub last_accessed: Instant,
pub access_count: u64,
}
impl CacheEntry {
    /// Builds a fresh entry stamped with the current time and a zeroed
    /// access counter. `value` is the stored (possibly compressed) payload.
    fn new(value: Vec<u8>, is_compressed: bool, uncompressed_size: usize) -> Self {
        let timestamp = Instant::now();
        CacheEntry {
            value,
            uncompressed_size,
            is_compressed,
            created_at: timestamp,
            last_accessed: timestamp,
            access_count: 0,
        }
    }

    /// Records one read: bumps the access counter and refreshes the
    /// last-accessed timestamp.
    fn access(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }

    /// Approximate resident size: the struct's inline footprint plus the
    /// heap bytes of the stored payload.
    fn memory_size(&self) -> usize {
        std::mem::size_of::<Self>() + self.value.len()
    }
}
/// Concurrent cache for inference results, keyed by [`CacheKey`].
pub struct InferenceCache {
/// Sharded concurrent map holding the entries.
cache: Arc<DashMap<CacheKey, CacheEntry>>,
/// Single active eviction policy; locked separately from the map.
eviction_policy: Arc<parking_lot::Mutex<Box<dyn EvictionPolicy>>>,
config: CacheConfig,
/// Present only when `config.enable_metrics` was set at construction.
metrics: Option<Arc<CacheMetrics>>,
}
impl InferenceCache {
    /// Builds a cache from `config`, wiring up the eviction policy and
    /// (optionally) a metrics collector.
    pub fn new(config: CacheConfig) -> Self {
        let eviction_policy = Self::create_eviction_policy(&config);
        let metrics =
            if config.enable_metrics { Some(Arc::new(CacheMetrics::new())) } else { None };
        Self {
            cache: Arc::new(DashMap::new()),
            eviction_policy: Arc::new(parking_lot::Mutex::new(eviction_policy)),
            config,
            metrics,
        }
    }

    /// Selects exactly one eviction policy, by priority:
    /// memory cap > entry cap > TTL > fallback LRU(1000).
    ///
    /// NOTE(review): the limits are mutually exclusive here — e.g. `ttl` is
    /// silently ignored whenever a memory or entry cap is also set, yet
    /// `CacheConfig::default()` sets all three. Confirm this is intended.
    fn create_eviction_policy(config: &CacheConfig) -> Box<dyn EvictionPolicy> {
        if let Some(max_bytes) = config.max_memory_bytes {
            Box::new(SizeBasedEviction::new(max_bytes))
        } else if let Some(max_entries) = config.max_entries {
            Box::new(LRUEviction::new(max_entries))
        } else if let Some(ttl) = config.ttl {
            Box::new(TTLEviction::new(ttl))
        } else {
            // No limit configured at all: bound the cache with a default LRU.
            Box::new(LRUEviction::new(1000))
        }
    }

    /// Looks up `key`, returning the (decompressed) value if present.
    ///
    /// Updates the entry's access stats and notifies the eviction policy.
    /// An entry that is present but fails to decompress is recorded as a
    /// miss and `None` is returned.
    pub fn get(&self, key: &CacheKey) -> Option<Vec<u8>> {
        let start = Instant::now();
        if let Some(mut entry) = self.cache.get_mut(key) {
            entry.access();
            let value = entry.value.clone();
            let is_compressed = entry.is_compressed;
            // Release the shard guard before taking the policy lock so the
            // two locks are never held at the same time.
            drop(entry);
            self.eviction_policy.lock().on_access(key);
            let result = if is_compressed { self.decompress(&value).ok() } else { Some(value) };
            if let Some(metrics) = &self.metrics {
                let elapsed = start.elapsed();
                if result.is_some() {
                    metrics.record_hit(elapsed);
                } else {
                    metrics.record_miss(elapsed);
                }
            }
            result
        } else {
            if let Some(metrics) = &self.metrics {
                metrics.record_miss(start.elapsed());
            }
            None
        }
    }

    /// Inserts `value` under `key`.
    ///
    /// The value is compressed when compression is enabled, it meets the
    /// size threshold, and compression actually shrinks it. Inserting may
    /// trigger evictions of other entries.
    pub fn insert(&self, key: CacheKey, value: Vec<u8>) {
        let start = Instant::now();
        let uncompressed_size = value.len();
        let (stored_value, is_compressed) = if self.config.compress_values
            && uncompressed_size >= self.config.compression_threshold
        {
            match self.compress(&value) {
                // Keep the compressed form only when it is strictly smaller.
                Ok(compressed) if compressed.len() < uncompressed_size => (compressed, true),
                _ => (value, false),
            }
        } else {
            (value, false)
        };
        let entry = CacheEntry::new(stored_value, is_compressed, uncompressed_size);
        let memory_size = entry.memory_size();
        // NOTE(review): overwriting an existing key never tells the policy
        // about the replaced entry, so size-based accounting may double-count
        // re-inserted keys — verify `on_insert` handles duplicates.
        self.cache.insert(key.clone(), entry);
        self.eviction_policy.lock().on_insert(&key, memory_size);
        self.maybe_evict();
        if let Some(metrics) = &self.metrics {
            metrics.record_insert(memory_size, start.elapsed());
        }
    }

    /// Removes `key`, returning its (decompressed) value if it was present.
    ///
    /// NOTE(review): an explicit removal is counted as an eviction in the
    /// metrics; confirm that is the intended semantics.
    pub fn remove(&self, key: &CacheKey) -> Option<Vec<u8>> {
        if let Some((_, entry)) = self.cache.remove(key) {
            let memory_size = entry.memory_size();
            self.eviction_policy.lock().on_remove(key);
            if let Some(metrics) = &self.metrics {
                metrics.record_eviction(memory_size);
            }
            if entry.is_compressed {
                self.decompress(&entry.value).ok()
            } else {
                Some(entry.value)
            }
        } else {
            None
        }
    }

    /// Drops every entry and resets the metrics.
    ///
    /// Fix: the eviction policy is now told about each removal. Previously
    /// its bookkeeping was left stale after `clear()`, so it could keep
    /// nominating keys that no longer existed (and, for size-based policies,
    /// its memory accounting never shrank).
    pub fn clear(&self) {
        {
            let mut policy = self.eviction_policy.lock();
            for entry in self.cache.iter() {
                policy.on_remove(entry.key());
            }
        }
        self.cache.clear();
        if let Some(metrics) = &self.metrics {
            metrics.reset();
        }
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// True when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Shared handle to the metrics collector, if metrics are enabled.
    pub fn metrics(&self) -> Option<Arc<CacheMetrics>> {
        self.metrics.clone()
    }

    /// Removes a policy-chosen victim and records the eviction.
    fn handle_eviction(&self, key: &CacheKey) {
        if let Some((_, entry)) = self.cache.remove(key) {
            if let Some(metrics) = &self.metrics {
                metrics.record_eviction(entry.memory_size());
            }
        }
    }

    /// Evicts entries until the policy is satisfied or yields no candidate.
    fn maybe_evict(&self) {
        let mut policy = self.eviction_policy.lock();
        while policy.should_evict() {
            if let Some(key) = policy.next_eviction() {
                self.handle_eviction(&key);
            } else {
                // Policy wants eviction but has no candidate: bail out to
                // avoid spinning forever.
                break;
            }
        }
    }

    /// zstd-compresses `data` at level 3.
    fn compress(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
        use std::io::Write;
        let mut encoder = oxiarc_zstd::ZstdStreamEncoder::new(Vec::new(), 3);
        encoder.write_all(data)?;
        encoder.finish()
    }

    /// Decompresses a zstd frame produced by `compress`.
    fn decompress(&self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
        oxiarc_zstd::decode_all(data).map_err(|e| std::io::Error::other(e.to_string()))
    }
}
/// Fluent builder for [`InferenceCache`]; starts from `CacheConfig::default()`.
pub struct InferenceCacheBuilder {
config: CacheConfig,
}
impl InferenceCacheBuilder {
pub fn new() -> Self {
Self {
config: CacheConfig::default(),
}
}
pub fn max_entries(mut self, max_entries: usize) -> Self {
self.config.max_entries = Some(max_entries);
self.config.max_memory_bytes = None;
self
}
pub fn max_memory_mb(mut self, max_memory_mb: usize) -> Self {
self.config.max_memory_bytes = Some(max_memory_mb * 1024 * 1024);
self.config.max_entries = None;
self
}
pub fn ttl(mut self, ttl: Duration) -> Self {
self.config.ttl = Some(ttl);
self
}
pub fn enable_metrics(mut self, enable: bool) -> Self {
self.config.enable_metrics = enable;
self
}
pub fn enable_compression(mut self, enable: bool) -> Self {
self.config.compress_values = enable;
self
}
pub fn compression_threshold(mut self, threshold: usize) -> Self {
self.config.compression_threshold = threshold;
self
}
pub fn build(self) -> InferenceCache {
InferenceCache::new(self.config)
}
}
impl Default for InferenceCacheBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::cache_key::CacheKeyBuilder;

    #[test]
    fn test_basic_cache_operations() {
        let cache = InferenceCacheBuilder::new().max_entries(10).enable_metrics(true).build();
        let key = CacheKeyBuilder::new("test-model", "classification")
            .with_text("Hello world")
            .build();
        let payload = b"prediction result".to_vec();
        cache.insert(key.clone(), payload.clone());

        // Round-trip: the stored bytes come back unchanged.
        let fetched = cache.get(&key).expect("expected value not found");
        assert_eq!(fetched, payload);

        // One successful get => one hit, no misses, one resident entry.
        let metrics = cache.metrics().expect("operation failed in test");
        let snap = metrics.snapshot();
        assert_eq!(snap.hits, 1);
        assert_eq!(snap.misses, 0);
        assert_eq!(snap.total_entries, 1);
    }

    #[test]
    fn test_compression() {
        let cache = InferenceCacheBuilder::new()
            .enable_compression(true)
            .compression_threshold(10)
            .build();
        let key = CacheKeyBuilder::new("test-model", "generation")
            .with_text("Test prompt")
            .build();

        // Highly repetitive payload well above the threshold: must compress.
        let payload = vec![42u8; 1000];
        cache.insert(key.clone(), payload.clone());

        // Read path transparently decompresses.
        let fetched = cache.get(&key).expect("expected value not found");
        assert_eq!(fetched, payload);

        // Inspect the raw entry: stored compressed and smaller than input.
        let raw = cache.cache.get(&key).expect("expected value not found");
        assert!(raw.is_compressed);
        assert!(raw.value.len() < raw.uncompressed_size);
    }

    #[test]
    fn test_eviction() {
        let cache = InferenceCacheBuilder::new().max_entries(3).enable_metrics(true).build();

        // Five distinct keys, inserted oldest-first.
        let mut keys = Vec::with_capacity(5);
        for i in 0..5 {
            keys.push(
                CacheKeyBuilder::new("model", "task").with_text(&format!("text{}", i)).build(),
            );
        }
        for (idx, key) in keys.iter().enumerate() {
            cache.insert(key.clone(), vec![idx as u8; 100]);
        }

        // Capacity 3 with 5 inserts: the two oldest entries are gone.
        assert!(cache.get(&keys[0]).is_none());
        assert!(cache.get(&keys[1]).is_none());
        assert!(cache.get(&keys[2]).is_some());
        assert!(cache.get(&keys[3]).is_some());
        assert!(cache.get(&keys[4]).is_some());

        let metrics = cache.metrics().expect("operation failed in test");
        let snap = metrics.snapshot();
        assert_eq!(snap.evictions, 2);
    }
}