use crate::core::models::openai::{ChatCompletionRequest, ChatCompletionResponse};
use crate::utils::error::Result;
use crate::utils::string_pool::intern_string;
use dashmap::DashMap;
use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info};
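/// Tunables for the response cache: tier capacity, entry TTL, and the
/// optional semantic-similarity tier.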
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
pub max_entries: usize,
pub default_ttl: Duration,
pub enable_semantic: bool,
pub similarity_threshold: f32,
pub min_prompt_length: usize,
pub enable_compression: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 10000,
            default_ttl: Duration::from_secs(3600),
            enable_semantic: true,
similarity_threshold: 0.95,
min_prompt_length: 10,
enable_compression: true,
}
}
}
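/// A cached value together with the bookkeeping needed for TTL expiry and
/// access tracking.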
#[derive(Debug, Clone)]
pub struct CacheEntry<T> {
pub value: T,
pub created_at: Instant,
pub expires_at: Instant,
pub access_count: u64,
pub last_accessed: Instant,
pub size_bytes: usize,
}
impl<T> CacheEntry<T> {
pub fn new(value: T, ttl: Duration, size_bytes: usize) -> Self {
let now = Instant::now();
Self {
value,
created_at: now,
expires_at: now + ttl,
access_count: 0,
last_accessed: now,
size_bytes,
}
}
pub fn is_expired(&self) -> bool {
Instant::now() > self.expires_at
}
pub fn mark_accessed(&mut self) {
self.access_count += 1;
self.last_accessed = Instant::now();
}
pub fn age(&self) -> Duration {
Instant::now().duration_since(self.created_at)
}
}
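/// Lookup key: the interned model name, a hash of the request parameters
/// that influence the completion, and an optional per-user scope.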
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
pub model: Arc<str>,
pub request_hash: u64,
pub user_id: Option<Arc<str>>,
}
impl CacheKey {
pub fn from_request(request: &ChatCompletionRequest, user_id: Option<&str>) -> Self {
let model = intern_string(&request.model);
let request_hash = Self::hash_request(request);
let user_id = user_id.map(intern_string);
Self {
model,
request_hash,
user_id,
}
}
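    /// Hashes the message contents and sampling parameters that affect the
    /// completion. Uses the std `DefaultHasher`, whose algorithm is
    /// unspecified and may change between Rust releases, so these hashes
    /// must not be persisted across runs.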
fn hash_request(request: &ChatCompletionRequest) -> u64 {
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
for message in &request.messages {
message.role.hash(&mut hasher);
if let Some(content) = &message.content {
content.hash(&mut hasher);
}
}
request.temperature.map(|t| t.to_bits()).hash(&mut hasher);
request.max_tokens.hash(&mut hasher);
request.top_p.map(|p| p.to_bits()).hash(&mut hasher);
request
.frequency_penalty
.map(|p| p.to_bits())
.hash(&mut hasher);
request
.presence_penalty
.map(|p| p.to_bits())
.hash(&mut hasher);
request.stop.hash(&mut hasher);
hasher.finish()
}
}
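/// Two-tier response cache with an optional semantic layer.
///
/// - L1: a small in-process LRU behind a `parking_lot::RwLock`, for hot
///   entries.
/// - L2: a larger concurrent `DashMap` holding the full working set.
/// - Semantic: a placeholder similarity index (see `semantic_lookup`).
///
/// Writes land in L2; reads promote L2 hits into L1. A minimal usage sketch
/// (`call_upstream` is a hypothetical stand-in for the real provider call):
///
/// ```ignore
/// let cache = CacheManager::new(CacheConfig::default())?;
/// let key = CacheKey::from_request(&request, Some("user-1"));
/// let response = match cache.get(&key).await? {
///     Some(hit) => hit,
///     None => {
///         let fresh = call_upstream(&request).await?;
///         cache.put(key, fresh.clone()).await?;
///         fresh
///     }
/// };
/// ```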
pub struct CacheManager {
l1_cache: Arc<RwLock<LruCache<CacheKey, CacheEntry<ChatCompletionResponse>>>>,
l2_cache: Arc<DashMap<CacheKey, CacheEntry<ChatCompletionResponse>>>,
semantic_cache: Arc<RwLock<HashMap<String, Vec<(CacheKey, f32)>>>>,
config: CacheConfig,
stats: Arc<RwLock<CacheStats>>,
}
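/// Per-tier hit/miss counters. A request that falls through several tiers
/// increments the miss counter of each tier it consulted.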
#[derive(Debug, Default)]
pub struct CacheStats {
pub l1_hits: u64,
pub l1_misses: u64,
pub l2_hits: u64,
pub l2_misses: u64,
pub semantic_hits: u64,
pub semantic_misses: u64,
pub evictions: u64,
pub total_size_bytes: usize,
}
impl CacheStats {
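    /// Combined hits divided by combined lookups. Because one request can
    /// miss several tiers, this is a per-lookup rate, not a per-request rate.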
pub fn hit_rate(&self) -> f64 {
let total_hits = self.l1_hits + self.l2_hits + self.semantic_hits;
let total_requests = total_hits + self.l1_misses + self.l2_misses + self.semantic_misses;
if total_requests == 0 {
0.0
} else {
total_hits as f64 / total_requests as f64
}
}
}
impl CacheManager {
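    /// Builds the cache tiers from `config`. L1 is sized to roughly a tenth
    /// of `max_entries`, with a floor of 100 entries.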
pub fn new(config: CacheConfig) -> Result<Self> {
let l1_capacity = NonZeroUsize::new(config.max_entries / 10)
.or_else(|| NonZeroUsize::new(100))
.ok_or_else(|| {
crate::utils::error::GatewayError::Config(
"Invalid cache configuration: max_entries must be greater than 0".to_string(),
)
})?;
Ok(Self {
l1_cache: Arc::new(RwLock::new(LruCache::new(l1_capacity))),
l2_cache: Arc::new(DashMap::new()),
semantic_cache: Arc::new(RwLock::new(HashMap::new())),
config,
stats: Arc::new(RwLock::new(CacheStats::default())),
})
}
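    /// Looks up a response, checking L1, then L2 (promoting fresh hits into
    /// L1), then the semantic tier. Expired entries are evicted on the way.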
pub async fn get(&self, key: &CacheKey) -> Result<Option<ChatCompletionResponse>> {
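        // L1: hot-path LRU. The bare scope bounds the write-lock guard so it
        // is released before the slower tiers are consulted.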
{
let mut l1 = self.l1_cache.write();
if let Some(entry) = l1.get_mut(key) {
if !entry.is_expired() {
entry.mark_accessed();
self.stats.write().l1_hits += 1;
debug!("L1 cache hit for key: {:?}", key);
return Ok(Some(entry.value.clone()));
} else {
l1.pop(key);
}
}
}
self.stats.write().l1_misses += 1;
        // L2: concurrent map holding the full working set.
        if let Some(mut entry) = self.l2_cache.get_mut(key) {
            if !entry.is_expired() {
                entry.mark_accessed();
                let mut l1 = self.l1_cache.write();
                l1.put(key.clone(), entry.clone());
                self.stats.write().l2_hits += 1;
                debug!("L2 cache hit for key: {:?}", key);
                return Ok(Some(entry.value.clone()));
            }
            // Expired: release the shard guard before removing. Calling
            // `remove` while `entry` is still alive would deadlock, since
            // `get_mut` already holds a lock on the same DashMap shard.
            drop(entry);
            self.l2_cache.remove(key);
        }
self.stats.write().l2_misses += 1;
        // The semantic tier is consulted (and its miss counted) only when it
        // is enabled, so disabled deployments do not skew the hit rate.
        if self.config.enable_semantic {
            if let Some(response) = self.semantic_lookup(key).await? {
                self.stats.write().semantic_hits += 1;
                debug!("Semantic cache hit for key: {:?}", key);
                return Ok(Some(response));
            }
            self.stats.write().semantic_misses += 1;
        }
        Ok(None)
}
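    /// Inserts a response into L2 with the default TTL. Entries reach L1
    /// only when they are subsequently read.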
pub async fn put(&self, key: CacheKey, response: ChatCompletionResponse) -> Result<()> {
let size_bytes = self.estimate_size(&response);
let entry = CacheEntry::new(response, self.config.default_ttl, size_bytes);
        self.l2_cache.insert(key.clone(), entry);
if self.config.enable_semantic {
self.update_semantic_cache(&key).await?;
}
{
let mut stats = self.stats.write();
stats.total_size_bytes += size_bytes;
}
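        // Opportunistic maintenance: sweep expired entries roughly once
        // every 1000 inserts.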
if self.l2_cache.len() % 1000 == 0 {
self.cleanup_expired().await;
}
debug!("Cached response for key: {:?}", key);
Ok(())
}
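    /// Semantic similarity lookup. Currently a stub that always misses; a
    /// real implementation would embed the prompt and search
    /// `semantic_cache` against `similarity_threshold`.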
async fn semantic_lookup(&self, _key: &CacheKey) -> Result<Option<ChatCompletionResponse>> {
Ok(None)
}
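    /// Records the key in the semantic index. Currently a no-op stub.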
async fn update_semantic_cache(&self, _key: &CacheKey) -> Result<()> {
Ok(())
}
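    /// Approximates an entry's size from its JSON serialization, falling
    /// back to 1 KiB if serialization fails.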
    fn estimate_size(&self, response: &ChatCompletionResponse) -> usize {
        serde_json::to_string(response)
            .map(|s| s.len())
            .unwrap_or(1024)
    }
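    /// Removes expired L2 entries and updates the eviction and size counters.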
async fn cleanup_expired(&self) {
let mut removed_count = 0;
let mut removed_size = 0;
self.l2_cache.retain(|_, entry| {
if entry.is_expired() {
removed_count += 1;
removed_size += entry.size_bytes;
false
} else {
true
}
});
{
let mut stats = self.stats.write();
stats.evictions += removed_count;
stats.total_size_bytes = stats.total_size_bytes.saturating_sub(removed_size);
}
if removed_count > 0 {
info!(
"Cleaned up {} expired cache entries, freed {} bytes",
removed_count, removed_size
);
}
}
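    /// Returns a point-in-time copy of the counters.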
    pub fn stats(&self) -> CacheStats {
        let stats = self.stats.read();
        // `CacheStats` is not `Clone`, so copy the counters out of the guard.
        CacheStats {
            l1_hits: stats.l1_hits,
            l1_misses: stats.l1_misses,
            l2_hits: stats.l2_hits,
            l2_misses: stats.l2_misses,
            semantic_hits: stats.semantic_hits,
            semantic_misses: stats.semantic_misses,
            evictions: stats.evictions,
            total_size_bytes: stats.total_size_bytes,
        }
    }
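    /// Empties every tier and resets the statistics.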
pub async fn clear(&self) {
self.l1_cache.write().clear();
self.l2_cache.clear();
self.semantic_cache.write().clear();
let mut stats = self.stats.write();
*stats = CacheStats::default();
info!("All caches cleared");
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::models::openai::*;
#[tokio::test]
async fn test_cache_manager() -> Result<()> {
let config = CacheConfig::default();
let cache = CacheManager::new(config)?;
let request = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text("Hello".to_string())),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
}],
..Default::default()
};
let key = CacheKey::from_request(&request, None);
let initial_result = cache.get(&key).await?;
assert!(initial_result.is_none());
let response = ChatCompletionResponse {
id: "test".to_string(),
object: "chat.completion".to_string(),
created: 1234567890,
model: "gpt-4".to_string(),
choices: vec![],
usage: None,
system_fingerprint: None,
};
cache.put(key.clone(), response.clone()).await?;
let cached = cache.get(&key).await?;
assert!(cached.is_some());
if let Some(cached_response) = cached {
assert_eq!(cached_response.id, response.id);
}
Ok(())
}
#[test]
fn test_cache_key_generation() {
let request1 = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text("Hello".to_string())),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
}],
..Default::default()
};
let request2 = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text("Hello".to_string())),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
}],
..Default::default()
};
let key1 = CacheKey::from_request(&request1, None);
let key2 = CacheKey::from_request(&request2, None);
assert_eq!(key1, key2);
}
}