use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
/// Tunable parameters for the response cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of cached entries; 0 disables the cache.
    pub max_entries: usize,
    /// Upper bound on the total size of cached bodies, in bytes.
    pub max_memory_bytes: usize,
    /// How long an entry stays valid after insertion.
    pub ttl: Duration,
    /// Highest sampling temperature still considered cacheable.
    pub cacheable_temp_max: f32,
}

impl Default for CacheConfig {
    /// 1000 entries, 100 MiB, one-hour TTL, only temperature 0.0 cacheable.
    fn default() -> Self {
        Self {
            max_entries: 1000,
            max_memory_bytes: 100 * 1024 * 1024,
            ttl: Duration::from_secs(3600),
            cacheable_temp_max: 0.0,
        }
    }
}

impl CacheConfig {
    /// Builds a config with the given entry budget and defaults elsewhere.
    #[must_use]
    pub fn new(max_entries: usize) -> Self {
        Self {
            max_entries,
            ..Self::default()
        }
    }

    /// Builds a config for which `is_enabled` is always `false`.
    #[must_use]
    pub fn disabled() -> Self {
        Self {
            max_entries: 0,
            max_memory_bytes: 0,
            ttl: Duration::ZERO,
            // Negative ceiling: no temperature qualifies for caching.
            cacheable_temp_max: -1.0,
        }
    }

    /// Builder-style setter for the entry time-to-live.
    #[must_use]
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.ttl = ttl;
        self
    }

    /// Builder-style setter for the cacheable-temperature ceiling.
    #[must_use]
    pub fn with_cacheable_temp(mut self, temp: f32) -> Self {
        self.cacheable_temp_max = temp;
        self
    }

    /// Caching is active only with a positive entry budget and non-zero TTL.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.max_entries > 0 && !self.ttl.is_zero()
    }
}
/// Identity of a cacheable request: the model name plus a hash of the
/// request content and the token limit.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    model: String,
    content_hash: u64,
    max_tokens: Option<u32>,
}

impl CacheKey {
    /// Assembles a key from its raw parts.
    #[must_use]
    pub fn new(model: &str, content_hash: u64, max_tokens: Option<u32>) -> Self {
        Self {
            model: model.to_string(),
            content_hash,
            max_tokens,
        }
    }

    /// Derives a key for a chat request by hashing every message in order.
    #[must_use]
    pub fn from_chat_request(
        model: &str,
        messages: &[impl AsRef<str>],
        max_tokens: Option<u32>,
    ) -> Self {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        messages.iter().for_each(|m| m.as_ref().hash(&mut hasher));
        Self::new(model, hasher.finish(), max_tokens)
    }

    /// Derives a key for a plain completion request from its prompt.
    #[must_use]
    pub fn from_completion_request(model: &str, prompt: &str, max_tokens: Option<u32>) -> Self {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        prompt.hash(&mut hasher);
        Self::new(model, hasher.finish(), max_tokens)
    }

    /// The model name this key was built for.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }
}
/// A cached response body plus local bookkeeping.
///
/// `created_at` and `access_count` are `#[serde(skip)]`; a deserialized
/// value therefore has `created_at == None`, which makes `age()` report
/// zero so the value never looks expired — NOTE(review): confirm this is
/// intended if responses are ever persisted and reloaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    pub body: String,
    pub size_bytes: usize,
    #[serde(skip)]
    created_at: Option<Instant>,
    #[serde(skip)]
    access_count: u64,
}

impl CachedResponse {
    /// Wraps `body`, recording its byte length and the creation instant.
    #[must_use]
    pub fn new(body: String) -> Self {
        Self {
            size_bytes: body.len(),
            body,
            created_at: Some(Instant::now()),
            access_count: 0,
        }
    }

    /// Time elapsed since creation; zero when no creation instant is known
    /// (e.g. after deserialization).
    #[must_use]
    pub fn age(&self) -> Duration {
        match self.created_at {
            Some(created) => created.elapsed(),
            None => Duration::ZERO,
        }
    }

    /// Whether the response is strictly older than `ttl`.
    #[must_use]
    pub fn is_expired(&self, ttl: Duration) -> bool {
        self.age() > ttl
    }

    /// How many times `record_access` has been called.
    #[must_use]
    pub fn access_count(&self) -> u64 {
        self.access_count
    }

    /// Increments the access counter and returns the new value.
    pub fn record_access(&mut self) -> u64 {
        self.access_count += 1;
        self.access_count
    }
}
/// Internal map value: the stored response plus LRU/TTL bookkeeping.
struct CacheEntry {
    response: CachedResponse,
    created_at: Instant,
    last_accessed: Instant,
    access_count: u64,
}

impl CacheEntry {
    /// Wraps a freshly inserted response; the insert counts as one access.
    fn new(response: CachedResponse) -> Self {
        let now = Instant::now();
        Self {
            response,
            created_at: now,
            last_accessed: now,
            access_count: 1,
        }
    }

    /// Marks the entry as just used, for LRU ordering.
    fn touch(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }

    /// True once the entry is strictly older than `ttl`.
    fn is_expired(&self, ttl: Duration) -> bool {
        self.created_at.elapsed() > ttl
    }

    /// Size of the cached body in bytes.
    fn size_bytes(&self) -> usize {
        self.response.size_bytes
    }
}
/// Thread-safe cache statistics, maintained with relaxed atomics
/// (independent counters; no cross-counter ordering is required).
#[derive(Debug, Default)]
pub struct CacheMetrics {
    hits: AtomicU64,
    misses: AtomicU64,
    evictions: AtomicU64,
    expirations: AtomicU64,
    entry_count: AtomicU64,
    memory_bytes: AtomicU64,
}

impl CacheMetrics {
    /// All counters start at zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Total lookups that found a live entry.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits.load(Ordering::Relaxed)
    }

    /// Total lookups that found nothing (or an expired entry).
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses.load(Ordering::Relaxed)
    }

    /// Total LRU evictions performed to stay within budget.
    #[must_use]
    pub fn evictions(&self) -> u64 {
        self.evictions.load(Ordering::Relaxed)
    }

    /// Total entries removed because their TTL ran out.
    #[must_use]
    pub fn expirations(&self) -> u64 {
        self.expirations.load(Ordering::Relaxed)
    }

    /// Current number of live entries (gauge).
    #[must_use]
    pub fn entry_count(&self) -> u64 {
        self.entry_count.load(Ordering::Relaxed)
    }

    /// Current total size of cached bodies in bytes (gauge).
    #[must_use]
    pub fn memory_bytes(&self) -> u64 {
        self.memory_bytes.load(Ordering::Relaxed)
    }

    /// Fraction of lookups that hit; 0.0 before any lookup has happened.
    #[must_use]
    pub fn hit_ratio(&self) -> f64 {
        let hits = self.hits() as f64;
        let total = hits + self.misses() as f64;
        if total > 0.0 { hits / total } else { 0.0 }
    }

    /// Renders all metrics in the Prometheus text exposition format.
    #[must_use]
    pub fn render_prometheus(&self) -> String {
        format!(
            r#"# HELP infernum_cache_hits_total Total cache hits
# TYPE infernum_cache_hits_total counter
infernum_cache_hits_total {}
# HELP infernum_cache_misses_total Total cache misses
# TYPE infernum_cache_misses_total counter
infernum_cache_misses_total {}
# HELP infernum_cache_hit_ratio Cache hit ratio
# TYPE infernum_cache_hit_ratio gauge
infernum_cache_hit_ratio {:.4}
# HELP infernum_cache_evictions_total Total evictions
# TYPE infernum_cache_evictions_total counter
infernum_cache_evictions_total {}
# HELP infernum_cache_entries Current cache entry count
# TYPE infernum_cache_entries gauge
infernum_cache_entries {}
# HELP infernum_cache_memory_bytes Current cache memory usage
# TYPE infernum_cache_memory_bytes gauge
infernum_cache_memory_bytes {}
"#,
            self.hits(),
            self.misses(),
            self.hit_ratio(),
            self.evictions(),
            self.entry_count(),
            self.memory_bytes(),
        )
    }

    // --- crate-internal mutators, called by the cache itself ---

    fn record_hit(&self) {
        self.hits.fetch_add(1, Ordering::Relaxed);
    }

    fn record_miss(&self) {
        self.misses.fetch_add(1, Ordering::Relaxed);
    }

    fn record_eviction(&self) {
        self.evictions.fetch_add(1, Ordering::Relaxed);
    }

    fn record_expiration(&self) {
        self.expirations.fetch_add(1, Ordering::Relaxed);
    }

    fn set_entry_count(&self, count: u64) {
        self.entry_count.store(count, Ordering::Relaxed);
    }

    fn set_memory_bytes(&self, bytes: u64) {
        self.memory_bytes.store(bytes, Ordering::Relaxed);
    }
}
/// Outcome of a cache lookup, reported to clients via a response header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheResult {
    Hit,
    Miss,
}

impl CacheResult {
    /// Conventional header value for this outcome.
    #[must_use]
    pub fn header_value(&self) -> &'static str {
        match *self {
            Self::Hit => "HIT",
            Self::Miss => "MISS",
        }
    }
}
/// Thread-safe in-memory response cache with TTL expiry, LRU eviction,
/// entry-count and memory budgets, and Prometheus-style metrics.
pub struct ResponseCache {
// Entry map behind an RwLock; lookups also mutate LRU state, so most
// operations take the write lock.
cache: RwLock<HashMap<CacheKey, CacheEntry>>,
// Immutable after construction.
config: CacheConfig,
// Updated atomically; readable without any lock.
metrics: CacheMetrics,
}
impl ResponseCache {
    /// Creates a cache with the given configuration.
    #[must_use]
    pub fn new(config: CacheConfig) -> Self {
        Self {
            // Cap the up-front allocation: `max_entries` may be arbitrarily
            // large, and the map grows on demand past this point anyway.
            cache: RwLock::new(HashMap::with_capacity(config.max_entries.min(1024))),
            config,
            metrics: CacheMetrics::new(),
        }
    }

    /// Creates a cache using `CacheConfig::default()`.
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(CacheConfig::default())
    }

    /// Creates a cache that never stores or returns entries.
    #[must_use]
    pub fn disabled() -> Self {
        Self::new(CacheConfig::disabled())
    }

    /// Whether the configuration allows caching at all.
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        self.config.is_enabled()
    }

    /// Whether a request at this sampling temperature may be cached
    /// (at or below the configured ceiling).
    #[must_use]
    pub fn is_cacheable_temp(&self, temperature: f32) -> bool {
        temperature <= self.config.cacheable_temp_max
    }

    /// Looks up `key`, returning a clone of the cached response on a hit.
    ///
    /// Takes the write lock because a hit mutates the entry (LRU timestamp,
    /// access count) and an expired entry is removed in place. Records
    /// hit/miss/expiration metrics as a side effect. Returns `None` on a
    /// miss, on expiry, when the cache is disabled, or if the lock is
    /// poisoned.
    pub fn get(&self, key: &CacheKey) -> Option<CachedResponse> {
        if !self.is_enabled() {
            return None;
        }
        let mut cache = self.cache.write().ok()?;
        if let Some(entry) = cache.get_mut(key) {
            if entry.is_expired(self.config.ttl) {
                cache.remove(key);
                self.metrics.record_expiration();
                self.metrics.record_miss();
                self.update_metrics(&cache);
                return None;
            }
            entry.touch();
            self.metrics.record_hit();
            return Some(entry.response.clone());
        }
        self.metrics.record_miss();
        None
    }

    /// Inserts `response` under `key`, evicting least-recently-used entries
    /// as needed to respect both the entry-count and memory budgets.
    ///
    /// Replacing an existing key never triggers eviction. A response larger
    /// than `max_memory_bytes` empties the cache and is still inserted
    /// (preserves prior behavior — callers that care should size-check
    /// first). No-op when the cache is disabled or the lock is poisoned.
    pub fn put(&self, key: CacheKey, response: CachedResponse) {
        if !self.is_enabled() {
            return;
        }
        let Ok(mut cache) = self.cache.write() else {
            return;
        };
        if cache.contains_key(&key) {
            cache.insert(key, CacheEntry::new(response));
            self.update_metrics(&cache);
            return;
        }
        // Enforce the entry-count budget.
        while cache.len() >= self.config.max_entries {
            if self.evict_one_lru(&mut cache).is_none() {
                break; // nothing left to evict
            }
        }
        // Enforce the memory budget. BUG FIX: the running total must shrink
        // as entries are evicted; previously `current_memory` was computed
        // once and never updated, so a single over-budget insert evicted
        // *every* entry in the cache instead of just enough of them.
        let new_size = response.size_bytes;
        let mut current_memory: usize = cache.values().map(|e| e.size_bytes()).sum();
        while current_memory + new_size > self.config.max_memory_bytes {
            match self.evict_one_lru(&mut cache) {
                Some(freed) => current_memory = current_memory.saturating_sub(freed),
                None => break, // cache empty; oversized response goes in anyway
            }
        }
        cache.insert(key, CacheEntry::new(response));
        self.update_metrics(&cache);
    }

    /// Removes and returns the entry for `key`, if present.
    pub fn remove(&self, key: &CacheKey) -> Option<CachedResponse> {
        let Ok(mut cache) = self.cache.write() else {
            return None;
        };
        let entry = cache.remove(key);
        self.update_metrics(&cache);
        entry.map(|e| e.response)
    }

    /// Drops every entry; cumulative counters (hits, misses, …) are kept.
    pub fn clear(&self) {
        if let Ok(mut cache) = self.cache.write() {
            cache.clear();
            self.update_metrics(&cache);
        }
    }

    /// Current number of entries (0 if the lock is poisoned).
    #[must_use]
    pub fn len(&self) -> usize {
        self.cache.read().map_or(0, |c| c.len())
    }

    /// Whether the cache currently holds no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Read-only access to the cache's metrics.
    #[must_use]
    pub fn metrics(&self) -> &CacheMetrics {
        &self.metrics
    }

    /// Read-only access to the active configuration.
    #[must_use]
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }

    /// Removes every expired entry, recording one expiration per removal.
    /// Intended to be called periodically by a maintenance task.
    pub fn cleanup_expired(&self) {
        let Ok(mut cache) = self.cache.write() else {
            return;
        };
        let before = cache.len();
        cache.retain(|_, entry| !entry.is_expired(self.config.ttl));
        let removed = before - cache.len();
        for _ in 0..removed {
            self.metrics.record_expiration();
        }
        self.update_metrics(&cache);
    }

    /// Evicts the least-recently-used entry and returns the number of bytes
    /// it freed; `None` when the map is empty. O(n) scan — acceptable for
    /// the configured entry counts.
    fn evict_one_lru(&self, cache: &mut HashMap<CacheKey, CacheEntry>) -> Option<usize> {
        let lru_key = cache
            .iter()
            .min_by_key(|(_, v)| v.last_accessed)
            .map(|(k, _)| k.clone())?;
        let freed = cache.remove(&lru_key)?.size_bytes();
        self.metrics.record_eviction();
        Some(freed)
    }

    /// Refreshes the entry-count and memory gauges from the map's state.
    fn update_metrics(&self, cache: &HashMap<CacheKey, CacheEntry>) {
        self.metrics.set_entry_count(cache.len() as u64);
        let total_bytes: usize = cache.values().map(|e| e.size_bytes()).sum();
        self.metrics.set_memory_bytes(total_bytes as u64);
    }
}
impl std::fmt::Debug for ResponseCache {
    // Manual impl: the entry map sits behind a lock and may be large, so
    // only summary fields are rendered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ResponseCache");
        dbg.field("enabled", &self.is_enabled());
        dbg.field("entries", &self.len());
        dbg.field("config", &self.config);
        dbg.finish()
    }
}
/// Name of the response header reporting cache status — presumably paired
/// with `CacheResult::header_value` ("HIT"/"MISS"); confirm against callers.
pub const CACHE_HEADER: &str = "x-cache";
// Unit tests covering config builders, key hashing, response aging, TTL
// expiry, LRU eviction, memory budgeting, metrics, and Debug rendering.
#[cfg(test)]
mod tests {
use super::*;
// Defaults: 1000 entries and enabled out of the box.
#[test]
fn test_cache_config_default() {
let config = CacheConfig::default();
assert_eq!(config.max_entries, 1000);
assert!(config.is_enabled());
}
// The disabled() constructor must report as not enabled.
#[test]
fn test_cache_config_disabled() {
let config = CacheConfig::disabled();
assert!(!config.is_enabled());
}
// Builder-style setters compose and land in the right fields.
#[test]
fn test_cache_config_builder() {
let config = CacheConfig::new(500)
.with_ttl(Duration::from_secs(300))
.with_cacheable_temp(0.1);
assert_eq!(config.max_entries, 500);
assert_eq!(config.ttl, Duration::from_secs(300));
assert!((config.cacheable_temp_max - 0.1).abs() < 0.001);
}
// Same chat messages hash to equal keys; different content diverges.
#[test]
fn test_cache_key_from_chat() {
let key1 = CacheKey::from_chat_request("gpt-4", &["Hello", "World"], Some(100));
let key2 = CacheKey::from_chat_request("gpt-4", &["Hello", "World"], Some(100));
let key3 = CacheKey::from_chat_request("gpt-4", &["Hello", "Different"], Some(100));
assert_eq!(key1, key2);
assert_ne!(key1, key3);
}
// Completion keys behave the same way for prompts.
#[test]
fn test_cache_key_from_completion() {
let key1 = CacheKey::from_completion_request("llama", "Hello world", Some(50));
let key2 = CacheKey::from_completion_request("llama", "Hello world", Some(50));
let key3 = CacheKey::from_completion_request("llama", "Different", Some(50));
assert_eq!(key1, key2);
assert_ne!(key1, key3);
}
// max_tokens is part of the key identity.
#[test]
fn test_cache_key_different_max_tokens() {
let key1 = CacheKey::from_completion_request("llama", "Hello", Some(50));
let key2 = CacheKey::from_completion_request("llama", "Hello", Some(100));
assert_ne!(key1, key2);
}
// age() tracks wall-clock time since construction.
#[test]
fn test_cached_response_age() {
let response = CachedResponse::new("test".to_string());
std::thread::sleep(Duration::from_millis(10));
assert!(response.age() >= Duration::from_millis(10));
}
// Expiry is strict: fresh within a long TTL, expired against a zero TTL.
#[test]
fn test_cached_response_expiry() {
let response = CachedResponse::new("test".to_string());
assert!(!response.is_expired(Duration::from_secs(60)));
assert!(response.is_expired(Duration::ZERO));
}
// Round trip: put then get returns the stored body.
#[test]
fn test_cache_put_and_get() {
let cache = ResponseCache::with_defaults();
let key = CacheKey::from_completion_request("test", "hello", None);
let response = CachedResponse::new("response".to_string());
cache.put(key.clone(), response);
let cached = cache.get(&key);
assert!(cached.is_some());
assert_eq!(cached.unwrap().body, "response");
}
// A lookup on an empty cache is a miss and is counted as one.
#[test]
fn test_cache_miss() {
let cache = ResponseCache::with_defaults();
let key = CacheKey::from_completion_request("test", "hello", None);
assert!(cache.get(&key).is_none());
assert_eq!(cache.metrics().misses(), 1);
}
// A successful lookup increments the hit counter.
#[test]
fn test_cache_hit_metrics() {
let cache = ResponseCache::with_defaults();
let key = CacheKey::from_completion_request("test", "hello", None);
cache.put(key.clone(), CachedResponse::new("response".to_string()));
let _ = cache.get(&key);
assert_eq!(cache.metrics().hits(), 1);
}
// remove() yields the stored response and empties the cache.
#[test]
fn test_cache_remove() {
let cache = ResponseCache::with_defaults();
let key = CacheKey::from_completion_request("test", "hello", None);
cache.put(key.clone(), CachedResponse::new("response".to_string()));
assert!(!cache.is_empty());
let removed = cache.remove(&key);
assert!(removed.is_some());
assert!(cache.is_empty());
}
// clear() drops every entry at once.
#[test]
fn test_cache_clear() {
let cache = ResponseCache::with_defaults();
for i in 0..10 {
let key = CacheKey::from_completion_request("test", &format!("hello{}", i), None);
cache.put(key, CachedResponse::new("response".to_string()));
}
assert_eq!(cache.len(), 10);
cache.clear();
assert!(cache.is_empty());
}
// An entry past its TTL turns into a miss and records one expiration.
#[test]
fn test_cache_ttl_expiry() {
let config = CacheConfig::new(100).with_ttl(Duration::from_millis(50));
let cache = ResponseCache::new(config);
let key = CacheKey::from_completion_request("test", "hello", None);
cache.put(key.clone(), CachedResponse::new("response".to_string()));
assert!(cache.get(&key).is_some());
std::thread::sleep(Duration::from_millis(100));
assert!(cache.get(&key).is_none());
assert_eq!(cache.metrics().expirations(), 1);
}
// When full, the least-recently-used entry (msg0) is evicted; the sleeps
// separate insertion timestamps and the get() refreshes msg1's recency.
#[test]
fn test_cache_lru_eviction() {
let config = CacheConfig {
max_entries: 3,
max_memory_bytes: 1024 * 1024,
ttl: Duration::from_secs(3600),
cacheable_temp_max: 0.0,
};
let cache = ResponseCache::new(config);
for i in 0..3 {
let key = CacheKey::from_completion_request("test", &format!("msg{}", i), None);
cache.put(key, CachedResponse::new("response".to_string()));
std::thread::sleep(Duration::from_millis(10));
}
assert_eq!(cache.len(), 3);
let key1 = CacheKey::from_completion_request("test", "msg1", None);
let _ = cache.get(&key1);
let key3 = CacheKey::from_completion_request("test", "msg3", None);
cache.put(key3.clone(), CachedResponse::new("new".to_string()));
assert_eq!(cache.len(), 3);
let key0 = CacheKey::from_completion_request("test", "msg0", None);
assert!(cache.get(&key0).is_none());
assert!(cache.get(&key1).is_some());
}
// A disabled cache silently ignores puts and always misses.
#[test]
fn test_cache_disabled() {
let cache = ResponseCache::disabled();
assert!(!cache.is_enabled());
let key = CacheKey::from_completion_request("test", "hello", None);
cache.put(key.clone(), CachedResponse::new("response".to_string()));
assert!(cache.is_empty());
assert!(cache.get(&key).is_none());
}
// Temperature ceiling is inclusive at the boundary.
#[test]
fn test_cacheable_temp() {
let config = CacheConfig::new(100).with_cacheable_temp(0.1);
let cache = ResponseCache::new(config);
assert!(cache.is_cacheable_temp(0.0));
assert!(cache.is_cacheable_temp(0.1));
assert!(!cache.is_cacheable_temp(0.2));
assert!(!cache.is_cacheable_temp(1.0));
}
// One hit + one miss yields a 0.5 ratio in the Prometheus rendering
// (formatted as 0.5000; contains() matches the prefix).
#[test]
fn test_cache_metrics_prometheus() {
let cache = ResponseCache::with_defaults();
let key = CacheKey::from_completion_request("test", "hello", None);
cache.put(key.clone(), CachedResponse::new("response".to_string()));
let _ = cache.get(&key);
let _ = cache.get(&CacheKey::from_completion_request("test", "miss", None));
let output = cache.metrics().render_prometheus();
assert!(output.contains("infernum_cache_hits_total 1"));
assert!(output.contains("infernum_cache_misses_total 1"));
assert!(output.contains("infernum_cache_hit_ratio 0.5"));
}
// Header values for the x-cache response header.
#[test]
fn test_cache_result_header() {
assert_eq!(CacheResult::Hit.header_value(), "HIT");
assert_eq!(CacheResult::Miss.header_value(), "MISS");
}
// Inserting past the byte budget (50 + 60 > 100) must evict.
#[test]
fn test_cache_memory_limit() {
let config = CacheConfig {
max_entries: 1000,
max_memory_bytes: 100, ttl: Duration::from_secs(3600),
cacheable_temp_max: 0.0,
};
let cache = ResponseCache::new(config);
let key1 = CacheKey::from_completion_request("test", "msg1", None);
cache.put(
key1.clone(),
CachedResponse::new("x".repeat(50)), );
assert_eq!(cache.len(), 1);
let key2 = CacheKey::from_completion_request("test", "msg2", None);
cache.put(
key2.clone(),
CachedResponse::new("y".repeat(60)), );
assert!(cache.metrics().evictions() > 0);
}
// cleanup_expired() sweeps everything past the TTL and counts each one.
#[test]
fn test_cleanup_expired() {
let config = CacheConfig::new(100).with_ttl(Duration::from_millis(50));
let cache = ResponseCache::new(config);
for i in 0..5 {
let key = CacheKey::from_completion_request("test", &format!("msg{}", i), None);
cache.put(key, CachedResponse::new("response".to_string()));
}
assert_eq!(cache.len(), 5);
std::thread::sleep(Duration::from_millis(100));
cache.cleanup_expired();
assert_eq!(cache.len(), 0);
assert_eq!(cache.metrics().expirations(), 5);
}
// The manual Debug impl exposes summary fields only.
#[test]
fn test_cache_debug() {
let cache = ResponseCache::with_defaults();
let debug_str = format!("{:?}", cache);
assert!(debug_str.contains("ResponseCache"));
assert!(debug_str.contains("enabled"));
}
// With zero lookups the ratio is defined as 0.0, not NaN.
#[test]
fn test_hit_ratio_empty() {
let metrics = CacheMetrics::new();
assert_eq!(metrics.hit_ratio(), 0.0);
}
}